-
Notifications
You must be signed in to change notification settings - Fork 345
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add graph-based config and model-- ultragcn #251
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
# -*- encoding:utf-8 -*- | ||
# Copyright (c) Alibaba, Inc. and its affiliates. | ||
import logging | ||
|
||
import tensorflow as tf | ||
|
||
from easy_rec.python.layers import dnn | ||
from easy_rec.python.model.easy_rec_model import EasyRecModel | ||
|
||
from easy_rec.python.protos.ultragcn_pb2 import ULTRAGCN as ULTRAGCNConfig # NOQA | ||
|
||
if tf.__version__ >= '2.0': | ||
tf = tf.compat.v1 | ||
|
||
|
||
class ULTRAGCN(EasyRecModel): | ||
|
||
def __init__(self, | ||
model_config, | ||
feature_configs, | ||
features, | ||
labels=None, | ||
is_training=False): | ||
super(ULTRAGCN, self).__init__(model_config, feature_configs, features, labels, | ||
is_training) | ||
self._model_config = model_config.ultragcn | ||
assert isinstance(self._model_config, ULTRAGCNConfig) | ||
self._user_num = self._model_config.user_num | ||
self._item_num = self._model_config.item_num | ||
self._emb_dim = self._model_config.output_dim | ||
self._i2i_weight = self._model_config.i2i_weight | ||
self._neg_weight = self._model_config.neg_weight | ||
self._l2_weight = self._model_config.l2_weight | ||
self._user_emb = None | ||
self._item_emb = None | ||
|
||
if features.get('features') is not None: | ||
self._user_ids = features.get('features')[0] | ||
self._user_degrees = features.get('features')[1] | ||
self._item_ids = features.get('features')[2] | ||
self._item_degrees = features.get('features')[3] | ||
self._nbr_ids = features.get('features')[4] | ||
self._nbr_weights = features.get('features')[5] | ||
self._neg_ids = features.get('features')[6] | ||
else: | ||
self._user_ids = features.get('id') | ||
self._user_degrees = None | ||
self._item_ids = features.get('id') | ||
self._item_degrees = None | ||
self._nbr_ids = features.get('id') | ||
self._nbr_weights = None | ||
self._neg_ids = features.get('id') | ||
|
||
_user_emb = tf.get_variable("user_emb", | ||
[self._user_num, self._emb_dim], | ||
trainable=True) | ||
_item_emb = tf.get_variable("item_emb", | ||
[self._item_num, self._emb_dim], | ||
trainable=True) | ||
|
||
self._user_emb = tf.convert_to_tensor(_user_emb) | ||
self._item_emb = tf.convert_to_tensor(_item_emb) | ||
|
||
def build_predict_graph(self): | ||
user_emb = tf.nn.embedding_lookup(self._user_emb, self._user_ids) | ||
item_emb = tf.nn.embedding_lookup(self._item_emb, self._item_ids) | ||
nbr_emb = tf.nn.embedding_lookup(self._item_emb, self._nbr_ids) | ||
neg_emb = tf.nn.embedding_lookup(self._item_emb, self._neg_ids) | ||
self._prediction_dict['user_emb'] = user_emb | ||
self._prediction_dict['item_emb'] = item_emb | ||
self._prediction_dict['nbr_emb'] = nbr_emb | ||
self._prediction_dict['neg_emb'] = neg_emb | ||
self._prediction_dict['user_embedding'] = tf.reduce_join( | ||
tf.as_string(user_emb), axis=-1, separator=',') | ||
self._prediction_dict['item_embedding'] = tf.reduce_join( | ||
tf.as_string(item_emb), axis=-1, separator=',') | ||
|
||
return self._prediction_dict | ||
|
||
def build_loss_graph(self): | ||
# UltraGCN base u2i | ||
pos_logit = tf.reduce_sum(self._prediction_dict['user_emb'] * self._prediction_dict['item_emb'], axis=-1) | ||
true_xent = tf.nn.sigmoid_cross_entropy_with_logits( | ||
labels=tf.ones_like(pos_logit), logits=pos_logit) | ||
neg_logit = tf.reduce_sum(tf.expand_dims(self._prediction_dict['user_emb'], axis=1) * self._prediction_dict['neg_emb'], axis=-1) | ||
negative_xent = tf.nn.sigmoid_cross_entropy_with_logits( | ||
labels=tf.zeros_like(neg_logit), logits=neg_logit) | ||
loss_u2i = tf.reduce_sum(true_xent * (1 + 1 / tf.sqrt(self._user_degrees * self._item_degrees))) \ | ||
+ self._neg_weight * tf.reduce_sum(tf.reduce_mean(negative_xent, axis=-1)) | ||
# UltraGCN i2i | ||
nbr_logit = tf.reduce_sum(tf.expand_dims(self._prediction_dict['user_emb'], axis=1) * self._prediction_dict['nbr_emb'], axis=-1) # [batch_size, nbr_num] | ||
nbr_xent = tf.nn.sigmoid_cross_entropy_with_logits( | ||
labels=tf.ones_like(nbr_logit), logits=nbr_logit) | ||
loss_i2i = tf.reduce_sum(nbr_xent * (1 + self._nbr_weights)) | ||
# regularization | ||
loss_l2 = tf.nn.l2_loss(self._prediction_dict['user_emb']) + tf.nn.l2_loss(self._prediction_dict['item_emb']) +\ | ||
tf.nn.l2_loss(self._prediction_dict['nbr_emb']) + tf.nn.l2_loss(self._prediction_dict['neg_emb']) | ||
|
||
loss = loss_u2i + self._i2i_weight * loss_i2i + self._l2_weight * loss_l2 | ||
return {'cross_entropy': loss} | ||
|
||
def build_metric_graph(self, eval_config): | ||
return {} | ||
|
||
def get_outputs(self): | ||
# emb_1 = tf.reduce_join(tf.as_string(self._prediction_dict['user_embedding']), axis=-1, separator=',') | ||
# emb_2 = tf.reduce_join(tf.as_string(self._prediction_dict['item_embedding'] ), axis=-1, separator=',') | ||
return ['user_embedding','item_embedding'] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 建议跟向量召回保持一致,user_emb, item_emb |
||
|
||
|
||
def build_metric_graph(self, eval_config): | ||
metric_dict = {} | ||
for metric in eval_config.metrics_set: | ||
if metric.WhichOneof('metric') == 'recall_at_topk': | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. metric会生效么?logits来自哪里? |
||
logits = self._prediction_dict['logits'] | ||
label = tf.zeros_like(logits[:, :1], dtype=tf.int64) | ||
metric_dict['recall_at_top%d' % | ||
metric.recall_at_topk.topk] = metrics.recall_at_k( | ||
label, logits, metric.recall_at_topk.topk) | ||
else: | ||
ValueError('invalid metric type: %s' % str(metric)) | ||
return metric_dict |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,3 +25,11 @@ message BinaryDataInput { | |
repeated string dense_path = 2; | ||
repeated string label_path = 3; | ||
} | ||
|
||
message GraphLearnInput { | ||
optional string user_node_input = 1; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这些是可枚举的么?node_name, node_input,这种kv的形式是不是通用一些? |
||
optional string item_node_input = 2; | ||
optional string u2i_edge_input = 3; | ||
optional string i2i_edge_input = 4; | ||
optional string u2u_edge_input = 5; | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
syntax = "proto2"; | ||
package protos; | ||
|
||
message ULTRAGCN { | ||
optional float l2_regularization = 2 [default=0.0]; | ||
optional uint32 user_num = 3 [default=1]; | ||
optional uint32 item_num = 4 [default=1]; | ||
optional uint32 output_dim = 5 [default=1]; | ||
optional uint32 nbr_num = 6 [default=1]; | ||
optional uint32 neg_num = 7 [default=1]; | ||
optional float neg_weight = 8 [default=0.0]; | ||
optional float i2i_weight = 9 [default=0.0]; | ||
optional float l2_weight = 10 [default=0.0]; | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# -*- encoding:utf-8 -*- | ||
# Copyright (c) Alibaba, Inc. and its affiliates. | ||
import json | ||
import logging | ||
|
||
from easy_rec.python.utils import pai_util | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 跟core/sampler.py中的graph init复用 |
||
|
||
def graph_init(graph, tf_config=None): | ||
if tf_config: | ||
if isinstance(tf_config, str) or isinstance(tf_config, type(u'')): | ||
tf_config = json.loads(tf_config) | ||
if 'ps' in tf_config['cluster']: | ||
# ps mode | ||
logging.info('ps mode') | ||
ps_count = len(tf_config['cluster']['ps']) | ||
evaluator_cnt = 1 | ||
# evaluator_cnt = 1 if pai_util.has_evaluator() else 0 | ||
# if evaluator_cnt == 0: | ||
# logging.warning( | ||
# 'evaluator is not set as an client in GraphLearn,' | ||
# 'if you actually set evaluator in TF_CONFIG, please do: export' | ||
# ' HAS_EVALUATOR=1.') | ||
task_count = len(tf_config['cluster']['worker']) + 1 + evaluator_cnt | ||
cluster = {'server_count': ps_count, 'client_count': task_count} | ||
if tf_config['task']['type'] in ['chief', 'master']: | ||
graph.init(cluster=cluster, job_name='client', task_index=0) | ||
elif tf_config['task']['type'] == 'worker': | ||
graph.init( | ||
cluster=cluster, | ||
job_name='client', | ||
task_index=tf_config['task']['index'] + 2) | ||
# TODO(hongsheng.jhs): check cluster has evaluator or not? | ||
elif tf_config['task']['type'] == 'evaluator': | ||
graph.init( | ||
cluster=cluster, | ||
job_name='client', | ||
task_index=tf_config['task']['index'] + 1) | ||
elif tf_config['task']['type'] == 'ps': | ||
graph.init( | ||
cluster=cluster, | ||
job_name='server', | ||
task_index=tf_config['task']['index']) | ||
else: | ||
# worker mode | ||
logging.info('worker mode') | ||
task_count = len(tf_config['cluster']['worker']) + 1 | ||
if tf_config['task']['type'] in ['chief', 'master']: | ||
graph.init(task_index=0, task_count=task_count) | ||
elif tf_config['task']['type'] == 'worker': | ||
graph.init( | ||
task_index=tf_config['task']['index'] + evaluator_cnt, | ||
task_count=task_count) | ||
else: | ||
# local mode | ||
graph.init() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
model_dir: "experiments/graph_on_ultargcn_ckpt" | ||
|
||
graph_train_input_path { | ||
user_node_input: "easy_rec/data/graph_data/gl_user.txt" | ||
item_node_input: "easy_rec/data/graph_data/gl_item.txt" | ||
u2i_edge_input: "easy_rec/data/graph_data/gl_train.txt" | ||
i2i_edge_input: "easy_rec/data/graph_data/gl_i2i.txt" | ||
} | ||
graph_eval_input_path { | ||
user_node_input: "easy_rec/data/graph_data/gl_user.txt" | ||
item_node_input: "easy_rec/data/graph_data/gl_item.txt" | ||
u2i_edge_input: "easy_rec/data/graph_data/gl_train.txt" | ||
i2i_edge_input: "easy_rec/data/graph_data/gl_i2i.txt" | ||
} | ||
|
||
train_config { | ||
log_step_count_steps: 100 | ||
optimizer_config: { | ||
adam_optimizer: { | ||
learning_rate: { | ||
constant_learning_rate { | ||
learning_rate: 1e-3 | ||
} | ||
} | ||
} | ||
use_moving_average: false | ||
} | ||
save_checkpoints_steps: 2000 | ||
save_summary_steps: 100 | ||
sync_replicas: true | ||
num_steps: 20000 | ||
} | ||
|
||
eval_config { | ||
} | ||
|
||
data_config { | ||
input_fields { | ||
input_name: 'id' | ||
input_type: INT64 | ||
} | ||
ultra_gcn_sampler { | ||
nbr_num: 10 | ||
neg_num: 10 | ||
neg_sampler: 'random' | ||
} | ||
|
||
batch_size: 512 | ||
num_epochs: 10 | ||
prefetch_size: 5 | ||
input_type: GraphInput | ||
} | ||
|
||
feature_config: { | ||
features: { | ||
input_names: 'id' | ||
feature_type: IdFeature | ||
embedding_dim: 128 | ||
hash_bucket_size: 100000 | ||
} | ||
} | ||
|
||
model_config:{ | ||
model_class: "ULTRAGCN" | ||
ultragcn { | ||
l2_regularization: 1e-6 | ||
user_num: 52643 | ||
item_num: 91599 | ||
output_dim: 128 | ||
nbr_num: 10 | ||
neg_num: 10 | ||
neg_weight: 10 | ||
i2i_weight: 2.75 | ||
l2_weight: 1e-4 | ||
} | ||
loss_type: SOFTMAX_CROSS_ENTROPY | ||
embedding_regularization: 0.0 | ||
} | ||
|
||
export_config { | ||
|
||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这边都是叫同名的id?