diff --git a/deepctr_torch/inputs.py b/deepctr_torch/inputs.py index 8058a6fc..ddb584a8 100644 --- a/deepctr_torch/inputs.py +++ b/deepctr_torch/inputs.py @@ -4,21 +4,26 @@ Weichen Shen,wcshen1994@163.com """ -from collections import OrderedDict, namedtuple +from collections import OrderedDict, namedtuple, defaultdict from itertools import chain import torch +import torch.nn as nn from .layers.utils import concat_fun +from .layers.sequence import SequencePoolingLayer +DEFAULT_GROUP_NAME = "default_group" -class SparseFeat(namedtuple('SparseFeat', ['name', 'dimension', 'use_hash', 'dtype', 'embedding_name', 'embedding'])): + +class SparseFeat(namedtuple('SparseFeat', ['name', 'dimension', 'use_hash', 'dtype', 'embedding_name', 'embedding', 'group_name'])): __slots__ = () - def __new__(cls, name, dimension, use_hash=False, dtype="int32", embedding_name=None, embedding=True): + def __new__(cls, name, dimension, use_hash=False, dtype="int32", + embedding_name=None, embedding=True, group_name=DEFAULT_GROUP_NAME): if embedding and embedding_name is None: embedding_name = name - return super(SparseFeat, cls).__new__(cls, name, dimension, use_hash, dtype, embedding_name, embedding) + return super(SparseFeat, cls).__new__(cls, name, dimension, use_hash, dtype, embedding_name, embedding, group_name) class DenseFeat(namedtuple('DenseFeat', ['name', 'dimension', 'dtype'])): @@ -30,15 +35,16 @@ def __new__(cls, name, dimension=1, dtype="float32"): class VarLenSparseFeat(namedtuple('VarLenFeat', ['name', 'dimension', 'maxlen', 'combiner', 'use_hash', 'dtype', 'embedding_name', - 'embedding'])): + 'embedding', 'group_name'])): __slots__ = () def __new__(cls, name, dimension, maxlen, combiner="mean", use_hash=False, dtype="float32", embedding_name=None, - embedding=True): + embedding=True, group_name=DEFAULT_GROUP_NAME): if embedding_name is None: embedding_name = name return super(VarLenSparseFeat, cls).__new__(cls, name, dimension, maxlen, combiner, use_hash, dtype, - embedding_name, embedding) + + embedding_name, embedding, group_name) def get_feature_names(feature_columns): @@ -50,6 +56,9 @@ def get_inputs_list(inputs): def build_input_features(feature_columns): + + # Return OrderedDict: {feature_name:(start, start+dimension)} + features = OrderedDict() start = 0 @@ -92,4 +101,70 @@ def combined_dnn_input(sparse_embedding_list, dense_value_list): elif len(dense_value_list) > 0: return torch.flatten(torch.cat(dense_value_list, dim=-1), start_dim=1) else: - raise NotImplementedError \ No newline at end of file + raise NotImplementedError + + +def embedding_lookup(sparse_embedding_dict, sparse_input_dict, sparse_feature_columns, return_feat_list=(), + mask_feat_list=(), to_list=False): + """ + Args: + sparse_embedding_dict: nn.ModuleDict, {embedding_name: nn.Embedding} + sparse_input_dict: OrderedDict, {feature_name:(start, start+dimension)} + sparse_feature_columns: list, sparse features + return_feat_list: list, names of feature to be returned, defualt () -> return all features + mask_feat_list, list, names of feature to be masked in hash transform + Return: + group_embedding_dict: defaultdict(list) + """ + group_embedding_dict = defaultdict(list) + for fc in sparse_feature_columns: + feature_name = fc.name + embedding_name = fc.embedding_name + if (len(return_feat_list) == 0 or feature_name in return_feat_list): + if fc.use_hash: + # lookup_idx = Hash(fc.vocabulary_size, mask_zero=(feature_name in mask_feat_list))( + # sparse_input_dict[feature_name]) + # TODO: add hash function + lookup_idx = 
sparse_input_dict[feature_name] + else: + lookup_idx = sparse_input_dict[feature_name] + + group_embedding_dict[fc.group_name].append(sparse_embedding_dict[embedding_name](lookup_idx)) + if to_list: + return list(chain.from_iterable(group_embedding_dict.values())) + return group_embedding_dict + + +def varlen_embedding_lookup(embedding_dict, sequence_input_dict, varlen_sparse_feature_columns): + varlen_embedding_vec_dict = {} + for fc in varlen_sparse_feature_columns: + feature_name = fc.name + embedding_name = fc.embedding_name + if fc.use_hash: + # lookup_idx = Hash(fc.vocabulary_size, mask_zero=True)(sequence_input_dict[feature_name]) + # TODO: add hash function + lookup_idx = sequence_input_dict[feature_name] + else: + lookup_idx = sequence_input_dict[feature_name] + varlen_embedding_vec_dict[feature_name] = embedding_dict[embedding_name](lookup_idx) + return varlen_embedding_vec_dict + + +def get_varlen_pooling_list(embedding_dict, features, varlen_sparse_feature_columns, to_list=False): + pooling_vec_list = defaultdict(list) + for fc in varlen_sparse_feature_columns: + feature_name = fc.name + combiner = fc.combiner + feature_length_name = fc.length_name + if feature_length_name is not None: + seq_input = embedding_dict[feature_name] + vec = SequencePoolingLayer(combiner)([seq_input, features[feature_length_name]]) + else: + seq_input = embedding_dict[feature_name] + vec = SequencePoolingLayer(combiner)(seq_input) + pooling_vec_list[fc.group_name].append(vec) + + if to_list: + return chain.from_iterable(pooling_vec_list.values()) + + return pooling_vec_list diff --git a/deepctr_torch/layers/activation.py b/deepctr_torch/layers/activation.py new file mode 100644 index 00000000..1473d4b0 --- /dev/null +++ b/deepctr_torch/layers/activation.py @@ -0,0 +1,92 @@ +# -*- coding:utf-8 -*- +""" + +Author: + Yuef Zhang + +""" +import sys + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +device = 'cuda:0' if torch.cuda.is_available() else 'cpu' + +class Dice(nn.Module): + """The Data Adaptive Activation Function in DIN,which can be viewed as a generalization of PReLu and can adaptively adjust the rectified point according to distribution of input data. + + Input shape: + - 2 dims: [batch_size, embedding_size(features)] + - 3 dims: [batch_size, num_features, embedding_size(features)] + + Output shape: + - Same shape as the input. + + References + - [Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]//Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. 
ACM, 2018: 1059-1068.](https://arxiv.org/pdf/1706.06978.pdf) + - https://github.com/zhougr1993/DeepInterestNetwork, https://github.com/fanoping/DIN-pytorch + """ + def __init__(self, num_features, dim=2, epsilon=1e-9): + super(Dice, self).__init__() + assert dim == 2 or dim == 3 + self.bn = nn.BatchNorm1d(num_features, eps=epsilon) + self.sigmoid = nn.Sigmoid() + self.dim = dim + + if self.dim == 2: + self.alpha = torch.zeros((num_features,)).to(device) + else: + self.alpha = torch.zeros((num_features, 1)).to(device) + + def forward(self, x): + # x shape: [batch_size, num_features, embedding_size(features)] + assert x.dim() == 2 or x.dim() == 3 + + if self.dim == 2: + x_p = self.sigmoid(self.bn(x)) + out = self.alpha * (1 - x_p) * x + x_p * x + else: + x = torch.transpose(x, 1, 2) + x_p = self.sigmoid(self.bn(x)) + out = self.alpha * (1 - x_p) * x + x_p * x + out = torch.transpose(out, 1, 2) + + return out + + +def activation_layer(act_name, hidden_size=None, dice_dim=2): + """Construct activation layers + + Args: + act_name: str or nn.Module, name of activation function + hidden_size: int, used for Dice activation + dice_dim: int, used for Dice activation + Return: + act_layer: activation layer + """ + if isinstance(act_name, str): + if act_name.lower() == 'relu' or 'linear': + act_layer = nn.ReLU(inplace=True) + elif act_name.lower() == 'dice': + assert dice_dim + act_layer = Dice(hidden_size, dice_dim) + elif act_name.lower() == 'prelu': + act_layer = nn.PReLU() + elif issubclass(act_name, nn.Module): + act_layer = act_name() + else: + raise NotImplementedError + + return act_layer + + +if __name__ == "__main__": + torch.manual_seed(7) + a = Dice(3) + b = torch.rand((5, 3)) + c = a(b) + print(c.size()) + print('b:', b) + print('c:', c) diff --git a/deepctr_torch/layers/core.py b/deepctr_torch/layers/core.py index 4ed2e292..bc10c14b 100644 --- a/deepctr_torch/layers/core.py +++ b/deepctr_torch/layers/core.py @@ -4,6 +4,67 @@ import torch.nn as nn import torch.nn.functional as F +from .activation import activation_layer + + +class LocalActivationUnit(nn.Module): + """The LocalActivationUnit used in DIN with which the representation of + user interests varies adaptively given different candidate items. + + Input shape + - A list of two 3D tensor with shape: ``(batch_size, 1, embedding_size)`` and ``(batch_size, T, embedding_size)`` + + Output shape + - 3D tensor with shape: ``(batch_size, T, 1)``. + + Arguments + - **hidden_units**:list of positive integer, the attention net layer number and units in each layer. + + - **activation**: Activation function to use in attention net. + + - **l2_reg**: float between 0 and 1. L2 regularizer strength applied to the kernel weights matrix of attention net. + + - **dropout_rate**: float in [0,1). Fraction of the units to dropout in attention net. + + - **use_bn**: bool. Whether use BatchNormalization before activation or not in attention net. + + - **seed**: A Python integer to use as random seed. + + References + - [Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]//Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. 
ACM, 2018: 1059-1068.](https://arxiv.org/pdf/1706.06978.pdf) + """ + + def __init__(self, hidden_units=[80, 40], embedding_dim=4, activation='Dice', dropout_rate=0, use_bn=False): + super(LocalActivationUnit, self).__init__() + + self.dnn1 = DNN(inputs_dim=4*embedding_dim, + hidden_units=hidden_units, + activation=activation, + dropout_rate=0.5, + use_bn=use_bn, + dice_dim=3) + + # self.dnn2 = DNN(inputs_dim=hidden_units[-1], + # hidden_units=[1], + # activation=activation, + # use_bn=use_bn, + # dice_dim=3) + + self.dense = nn.Linear(hidden_units[-1], 1) + + def forward(self, query, user_behavior): + # query ad : size -> batch_size * 1 * embedding_size + # user behavior : size -> batch_size * time_seq_len * embedding_size + + user_behavior_len = user_behavior.size(1) + queries = torch.cat([query for _ in range(user_behavior_len)], dim=1) + + attention_input = torch.cat([queries, user_behavior, queries-user_behavior, queries*user_behavior], dim=-1) + attention_output = self.dnn1(attention_input) + attention_output = self.dense(attention_output) + + return attention_output + class DNN(nn.Module): """The Multi Layer Percetron @@ -30,10 +91,9 @@ class DNN(nn.Module): - **seed**: A Python integer to use as random seed. """ - def __init__(self, inputs_dim, hidden_units, activation=F.relu, l2_reg=0, dropout_rate=0, use_bn=False, - init_std=0.0001, seed=1024, device='cpu'): + def __init__(self, inputs_dim, hidden_units, activation='relu', l2_reg=0, dropout_rate=0, use_bn=False, + init_std=0.0001, dice_dim=3, seed=1024, device='cpu'): super(DNN, self).__init__() - self.activation = activation self.dropout_rate = dropout_rate self.dropout = nn.Dropout(dropout_rate) self.seed = seed @@ -49,6 +109,10 @@ def __init__(self, inputs_dim, hidden_units, activation=F.relu, l2_reg=0, dropou if self.use_bn: self.bn = nn.ModuleList( [nn.BatchNorm1d(hidden_units[i + 1]) for i in range(len(hidden_units) - 1)]) + + self.activation_layers = nn.ModuleList( + [activation_layer(activation, hidden_units[i + 1], dice_dim) for i in range(len(hidden_units) - 1)]) + for name, tensor in self.linears.named_parameters(): if 'weight' in name: nn.init.normal_(tensor, mean=0, std=init_std) @@ -65,7 +129,7 @@ def forward(self, inputs): if self.use_bn: fc = self.bn[i](fc) - fc = self.activation(fc) + fc = self.activation_layers[i](fc) fc = self.dropout(fc) deep_input = fc diff --git a/deepctr_torch/layers/interaction.py b/deepctr_torch/layers/interaction.py index 54f0351b..c2594c32 100644 --- a/deepctr_torch/layers/interaction.py +++ b/deepctr_torch/layers/interaction.py @@ -6,7 +6,7 @@ from ..layers.core import Conv2dSame from ..layers.sequence import KMaxPooling - +from ..layers.activation import activation_layer class FM(nn.Module): """Factorization Machine models pairwise (order-2) feature interactions @@ -163,14 +163,14 @@ class CIN(nn.Module): Arguments - **filed_size** : Positive integer, number of feature groups. - **layer_size** : list of int.Feature maps in each layer. - - **activation** : activation function used on feature maps. + - **activation** : activation function name used on feature maps. - **split_half** : bool.if set to False, half of the feature maps in each hidden will connect to output unit. - **seed** : A Python integer to use as random seed. References - [Lian J, Zhou X, Zhang F, et al. xDeepFM: Combining Explicit and Implicit Feature Interactions for Recommender Systems[J]. arXiv preprint arXiv:1803.05170, 2018.] 
(https://arxiv.org/pdf/1803.05170.pdf) """ - def __init__(self, field_size, layer_size=(128, 128), activation=F.relu, split_half=True, l2_reg=1e-5, seed=1024, + def __init__(self, field_size, layer_size=(128, 128), activation='relu', split_half=True, l2_reg=1e-5, seed=1024, device='cpu'): super(CIN, self).__init__() if len(layer_size) == 0: @@ -180,7 +180,7 @@ def __init__(self, field_size, layer_size=(128, 128), activation=F.relu, split_h self.layer_size = layer_size self.field_nums = [field_size] self.split_half = split_half - self.activation = activation + self.activation = activation_layer(activation) self.l2_reg = l2_reg self.seed = seed diff --git a/deepctr_torch/layers/sequence.py b/deepctr_torch/layers/sequence.py index 283328df..1a538e5e 100644 --- a/deepctr_torch/layers/sequence.py +++ b/deepctr_torch/layers/sequence.py @@ -1,6 +1,124 @@ import torch import torch.nn as nn +import numpy as np + +from .core import LocalActivationUnit + + +class SequencePoolingLayer(nn.Module): + """The SequencePoolingLayer is used to apply pooling operation(sum,mean,max) on variable-length sequence feature/multi-value feature. + + Input shape + - A list of two tensor [seq_value,seq_len] + + - seq_value is a 3D tensor with shape: ``(batch_size, T, embedding_size)`` + + - seq_len is a 2D tensor with shape : ``(batch_size, 1)``,indicate valid length of each sequence. + + Output shape + - 3D tensor with shape: ``(batch_size, 1, embedding_size)``. + + Arguments + - **mode**:str.Pooling operation to be used,can be sum,mean or max. + + """ + + def __init__(self, mode='mean'): + super(SequencePoolingLayer).__init__() + + if mode not in ['sum', 'mean', 'max']: + raise ValueError('parameter mode should in [sum, mean, max]') + self.mode = mode + self.eps = torch.FloatTensor([1e-8]) + + def _sequence_mask(self, lengths, maxlen=None, dtype=torch.bool): + # Returns a mask tensor representing the first N positions of each cell. + if maxlen is None: + maxlen = lengths.max() + row_vector = torch.arange(0, maxlen, 1) + matrix = torch.unsqueeze(lengths, dim=-1) + mask = row_vector < matrix + + mask.type(dtype) + return mask + + def forward(self, seq_value_len_list): + uiseq_embed_list, user_behavior_length = seq_value_len_list # [B, T, E], [B, 1] + mask = self._sequence_mask(user_behavior_length, dtype=torch.float32) # [B, 1, maxlen] + mask = torch.transpose(mask, 1, 2) # [B, maxlen, 1] + + embedding_size = uiseq_embed_list.shape[-1] + + mask = torch.repeat_interleave(mask, embedding_size, dim=2) # [B, maxlen, E] + + uiseq_embed_list *= mask # [B, maxlen, E] + hist = uiseq_embed_list + + if self.mode == 'max': + res = torch.max(hist, dim=1, keepdim=True) + elif self.mode == 'mean': + res = torch.max(hist, dim=1, keepdim=False) + res = torch.div(res, user_behavior_length.type(torch.float32) + self.eps) + res = torch.unsqueeze(res, dim=1) + elif self.mode == 'sum': + res = torch.max(hist, dim=1, keepdim=False) + res = torch.unsqueeze(res, dim=1) + + return res + + +class AttentionSequencePoolingLayer(nn.Module): + """The Attentional sequence pooling operation used in DIN. 
+ + Input shape + - A list of three tensor: [query,keys,keys_length] + + - query is a 3D tensor with shape: ``(batch_size, 1, embedding_size)`` + + - keys is a 3D tensor with shape: ``(batch_size, T, embedding_size)`` + + - keys_length is a 2D tensor with shape: ``(batch_size, 1)`` + + Output shape + - 3D tensor with shape: ``(batch_size, 1, embedding_size)`` + + Arguments + - **att_hidden_units**: List of positive integer, the attention net layer number and units in each layer. + + - **embedding_dim**: Dimension of the input embeddings. + + - **activation**: Activation function to use in attention net. + + - **weight_normalization**: bool.Whether normalize the attention score of local activation unit. + + References + - [Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]//Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. ACM, 2018: 1059-1068.](https://arxiv.org/pdf/1706.06978.pdf) + """ + def __init__(self, att_hidden_units=[80, 40], embedding_dim=4, activation='Dice', weight_normalization=False): + super(AttentionSequencePoolingLayer, self).__init__() + + self.local_att = LocalActivationUnit(hidden_units=att_hidden_units, embedding_dim=embedding_dim, activation=activation) + + def forward(self, query, keys, keys_length): + # query: [B, 1, E], keys: [B, T, E], keys_length: [B, 1] + # TODO: Mini-batch aware regularization in originial paper [Zhou G, et al. 2018] is not implemented here. As the authors mentioned + # it is not a must for small dataset as the open-sourced ones. + attention_score = self.local_att(query, keys) + attention_score = torch.transpose(attention_score, 1, 2) # B * 1 * T + + # define mask by length + keys_length = keys_length.type(torch.LongTensor) + mask = torch.arange(keys.size(1))[None, :] < keys_length[:, None] # [1, T] < [B, 1, 1] -> [B, 1, T] + + # mask + output = torch.mul(attention_score, mask.type(torch.FloatTensor)) # [B, 1, T] + + # multiply weight + output = torch.matmul(output, keys) # [B, 1, E] + + return output + class KMaxPooling(nn.Module): """K Max pooling that selects the k biggest value along the specific axis. 
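Usage sketch (illustrative, not part of this diff): how the new DIN sequence layers above are intended to be called, following the shape contracts in their docstrings; the batch/sequence/embedding sizes and tensor names are made-up assumptions, and CPU execution is assumed.

    import torch
    from deepctr_torch.layers.sequence import AttentionSequencePoolingLayer, SequencePoolingLayer

    B, T, E = 32, 10, 4                              # batch size, behavior length, embedding size
    query = torch.rand(B, 1, E)                      # candidate item embedding
    keys = torch.rand(B, T, E)                       # user behavior sequence embeddings
    keys_length = torch.randint(1, T + 1, (B, 1))    # valid length of each sequence

    # Attention pooling: weights each behavior by its relevance to the candidate item.
    att_pool = AttentionSequencePoolingLayer(att_hidden_units=[80, 40], embedding_dim=E,
                                             activation='Dice')
    interest = att_pool(query, keys, keys_length)    # documented output shape: (B, 1, E)

    # Plain pooling over the variable-length sequence (sum/mean/max), per the docstring.
    mean_pool = SequencePoolingLayer(mode='mean')
    pooled = mean_pool([keys, keys_length])          # documented output shape: (B, 1, E)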
diff --git a/deepctr_torch/models/afm.py b/deepctr_torch/models/afm.py index 1e528270..86b46011 100644 --- a/deepctr_torch/models/afm.py +++ b/deepctr_torch/models/afm.py @@ -42,7 +42,7 @@ def __init__(self,linear_feature_columns, dnn_feature_columns, embedding_size=8, l2_reg_linear=l2_reg_linear, l2_reg_embedding=l2_reg_embedding, l2_reg_dnn=0, init_std=init_std, seed=seed, - dnn_dropout=0, dnn_activation=F.relu, + dnn_dropout=0, dnn_activation='relu', task=task, device=device) self.use_attention = use_attention diff --git a/deepctr_torch/models/autoint.py b/deepctr_torch/models/autoint.py index 58965d3b..af8ed79c 100644 --- a/deepctr_torch/models/autoint.py +++ b/deepctr_torch/models/autoint.py @@ -39,7 +39,7 @@ class AutoInt(BaseModel): def __init__(self, dnn_feature_columns, embedding_size=8, att_layer_num=3, att_embedding_size=8, att_head_num=2, att_res=True, - dnn_hidden_units=(256, 128), dnn_activation=F.relu, + dnn_hidden_units=(256, 128), dnn_activation='relu', l2_reg_dnn=0, l2_reg_embedding=1e-5, dnn_use_bn=False, dnn_dropout=0, init_std=0.0001, seed=1024, task='binary', device='cpu'): diff --git a/deepctr_torch/models/basemodel.py b/deepctr_torch/models/basemodel.py index 19a6eee6..87814b6e 100644 --- a/deepctr_torch/models/basemodel.py +++ b/deepctr_torch/models/basemodel.py @@ -132,7 +132,8 @@ def fit(self, x=None, initial_epoch=0, validation_split=0., validation_data=None, - shuffle=True, ): + shuffle=True, + use_double=False,): """ :param x: Numpy array of training data (if the model has a single input), or list of Numpy arrays (if the model has multiple inputs).If input layers in the model are named, you can also pass a @@ -230,8 +231,13 @@ def fit(self, x=None, for name, metric_fun in self.metrics.items(): if name not in train_result: train_result[name] = [] - train_result[name].append(metric_fun( - y.cpu().data.numpy(), y_pred.cpu().data.numpy())) + + if use_double: + train_result[name].append(metric_fun( + y.cpu().data.numpy(), y_pred.cpu().data.numpy().astype("float64"))) + else: + train_result[name].append(metric_fun( + y.cpu().data.numpy(), y_pred.cpu().data.numpy())) except KeyboardInterrupt: t.close() @@ -271,7 +277,7 @@ def evaluate(self, x, y, batch_size=256): eval_result[name] = metric_fun(y, pred_ans) return eval_result - def predict(self, x, batch_size=256): + def predict(self, x, batch_size=256, use_double=False): """ :param x: The input data, as a Numpy array (or list of Numpy arrays if the model has multiple inputs). 
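Usage sketch (illustrative, not part of this diff): the new use_double flag in fit()/predict() casts predictions to float64 before they reach sklearn metrics such as log_loss, which avoids precision warnings near 0/1 probabilities; the model and input names below are assumed to be prepared as in examples/run_classification_criteo.py.

    # Assumes `model`, `train_model_input`, `test_model_input`, `train` and `target`
    # are set up as in the Criteo example.
    model.fit(train_model_input, train[target].values, batch_size=256, epochs=10,
              verbose=2, validation_split=0.2, use_double=True)   # metrics computed on float64 y_pred
    pred_ans = model.predict(test_model_input, batch_size=256, use_double=True)   # float64 array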
@@ -298,7 +304,11 @@ def predict(self, x, batch_size=256): y_pred = model(x).cpu().data.numpy() # .squeeze() pred_ans.append(y_pred) - return np.concatenate(pred_ans) + + if use_double: + return np.concatenate(pred_ans).astype("float64") + else: + return np.concatenate(pred_ans) def input_from_feature_columns(self, X, feature_columns, embedding_dict, support_dense=True): @@ -329,7 +339,8 @@ def input_from_feature_columns(self, X, feature_columns, embedding_dict, support return sparse_embedding_list + varlen_sparse_embedding_list, dense_value_list def create_embedding_matrix(self, feature_columns, embedding_size, init_std=0.0001, sparse=False): - + # Return nn.ModuleDict: for sparse features, {embedding_name: nn.Embedding} + # for varlen sparse features, {embedding_name: nn.EmbeddingBag} sparse_feature_columns = list( filter(lambda x: isinstance(x, SparseFeat), feature_columns)) if len(feature_columns) else [] @@ -424,12 +435,24 @@ def _get_loss_func(self, loss): loss_func = loss return loss_func - def _get_metrics(self, metrics): + def _log_loss(self, y_true, y_pred, eps=1e-7, normalize=True, sample_weight=None, labels=None): + # change eps to improve calculation accuracy + return log_loss(y_true, + y_pred, + eps, + normalize, + sample_weight, + labels) + + def _get_metrics(self, metrics, set_eps=False): metrics_ = {} if metrics: for metric in metrics: if metric == "binary_crossentropy" or metric == "logloss": - metrics_[metric] = log_loss + if set_eps: + metrics_[metric] = self._log_loss + else: + metrics_[metric] = log_loss if metric == "auc": metrics_[metric] = roc_auc_score if metric == "mse": diff --git a/deepctr_torch/models/ccpm.py b/deepctr_torch/models/ccpm.py index 87fc2c62..514ab7e3 100644 --- a/deepctr_torch/models/ccpm.py +++ b/deepctr_torch/models/ccpm.py @@ -44,7 +44,7 @@ class CCPM(BaseModel): def __init__(self, linear_feature_columns, dnn_feature_columns, embedding_size=8, conv_kernel_width=(6, 5), conv_filters=(4, 4), dnn_hidden_units=(256,), l2_reg_linear=1e-5, l2_reg_embedding=1e-5, l2_reg_dnn=0, dnn_dropout=0, - init_std=0.0001, seed=1024, task='binary', device='cpu', dnn_use_bn=False, dnn_activation=F.relu): + init_std=0.0001, seed=1024, task='binary', device='cpu', dnn_use_bn=False, dnn_activation='relu'): super(CCPM, self).__init__(linear_feature_columns, dnn_feature_columns, embedding_size=embedding_size, dnn_hidden_units=dnn_hidden_units, diff --git a/deepctr_torch/models/dcn.py b/deepctr_torch/models/dcn.py index c6d6a6be..46fd7f02 100644 --- a/deepctr_torch/models/dcn.py +++ b/deepctr_torch/models/dcn.py @@ -40,7 +40,7 @@ def __init__(self, dnn_hidden_units=(128, 128), l2_reg_linear=0.00001, l2_reg_embedding=0.00001, l2_reg_cross=0.00001, l2_reg_dnn=0, init_std=0.0001, seed=1024, dnn_dropout=0, - dnn_activation=F.relu, dnn_use_bn=False, task='binary', device='cpu'): + dnn_activation='relu', dnn_use_bn=False, task='binary', device='cpu'): super(DCN, self).__init__(linear_feature_columns=[], dnn_feature_columns=dnn_feature_columns, diff --git a/deepctr_torch/models/deepfm.py b/deepctr_torch/models/deepfm.py index a1fa7e5e..d4ccfee0 100644 --- a/deepctr_torch/models/deepfm.py +++ b/deepctr_torch/models/deepfm.py @@ -40,7 +40,7 @@ def __init__(self, dnn_hidden_units=(256, 128), l2_reg_linear=0.00001, l2_reg_embedding=0.00001, l2_reg_dnn=0, init_std=0.0001, seed=1024, dnn_dropout=0, - dnn_activation=F.relu, dnn_use_bn=False, task='binary', device='cpu'): + dnn_activation='relu', dnn_use_bn=False, task='binary', device='cpu'): super(DeepFM, 
self).__init__(linear_feature_columns, dnn_feature_columns, embedding_size=embedding_size, dnn_hidden_units=dnn_hidden_units, diff --git a/deepctr_torch/models/fibinet.py b/deepctr_torch/models/fibinet.py index 4031e4ce..af6f7d42 100644 --- a/deepctr_torch/models/fibinet.py +++ b/deepctr_torch/models/fibinet.py @@ -40,7 +40,7 @@ class FiBiNET(BaseModel): def __init__(self, linear_feature_columns, dnn_feature_columns, embedding_size=8, bilinear_type='interaction', reduction_ratio=3, dnn_hidden_units=(128, 128), l2_reg_linear=1e-5, - l2_reg_embedding=1e-5, l2_reg_dnn=0, init_std=0.0001, seed=1024, dnn_dropout=0, dnn_activation=F.relu, + l2_reg_embedding=1e-5, l2_reg_dnn=0, init_std=0.0001, seed=1024, dnn_dropout=0, dnn_activation='relu', task='binary', device='cpu'): super(FiBiNET, self).__init__(linear_feature_columns, dnn_feature_columns, embedding_size=embedding_size, dnn_hidden_units=dnn_hidden_units, diff --git a/deepctr_torch/models/nfm.py b/deepctr_torch/models/nfm.py index 87116765..11a364e1 100644 --- a/deepctr_torch/models/nfm.py +++ b/deepctr_torch/models/nfm.py @@ -38,7 +38,7 @@ class NFM(BaseModel): def __init__(self, linear_feature_columns, dnn_feature_columns, embedding_size=8, dnn_hidden_units=(128, 128), l2_reg_embedding=1e-5, l2_reg_linear=1e-5, l2_reg_dnn=0, init_std=0.0001, seed=1024, bi_dropout=0, - dnn_dropout=0, dnn_activation=F.relu, task='binary', device='cpu'): + dnn_dropout=0, dnn_activation='relu', task='binary', device='cpu'): super(NFM, self).__init__(linear_feature_columns, dnn_feature_columns, embedding_size=embedding_size, dnn_hidden_units=dnn_hidden_units, l2_reg_linear=l2_reg_linear, diff --git a/deepctr_torch/models/onn.py b/deepctr_torch/models/onn.py index 9fcf00ca..d8235edd 100644 --- a/deepctr_torch/models/onn.py +++ b/deepctr_torch/models/onn.py @@ -56,7 +56,7 @@ class ONN(BaseModel): def __init__(self, linear_feature_columns, dnn_feature_columns, embedding_size=4, dnn_hidden_units=(128, 128), l2_reg_embedding=1e-5, l2_reg_linear=1e-5, l2_reg_dnn=0, - dnn_dropout=0, init_std=0.0001, seed=1024, dnn_use_bn=False, dnn_activation=F.relu, + dnn_dropout=0, init_std=0.0001, seed=1024, dnn_use_bn=False, dnn_activation='relu', task='binary', device='cpu'): super(ONN, self).__init__(linear_feature_columns, dnn_feature_columns, embedding_size=embedding_size, dnn_hidden_units=dnn_hidden_units, diff --git a/deepctr_torch/models/pnn.py b/deepctr_torch/models/pnn.py index e5956c4b..d1c7fa7d 100644 --- a/deepctr_torch/models/pnn.py +++ b/deepctr_torch/models/pnn.py @@ -37,7 +37,7 @@ class PNN(BaseModel): """ def __init__(self, dnn_feature_columns, embedding_size=8, dnn_hidden_units=(128, 128), l2_reg_embedding=1e-5, l2_reg_dnn=0, - init_std=0.0001, seed=1024, dnn_dropout=0, dnn_activation=F.relu, use_inner=True, use_outter=False, + init_std=0.0001, seed=1024, dnn_dropout=0, dnn_activation='relu', use_inner=True, use_outter=False, kernel_type='mat', task='binary', device='cpu',): super(PNN, self).__init__([], dnn_feature_columns, embedding_size=embedding_size, diff --git a/deepctr_torch/models/wdl.py b/deepctr_torch/models/wdl.py index 8f1becae..d4300e46 100644 --- a/deepctr_torch/models/wdl.py +++ b/deepctr_torch/models/wdl.py @@ -35,7 +35,7 @@ class WDL(BaseModel): def __init__(self, linear_feature_columns, dnn_feature_columns, embedding_size=8, dnn_hidden_units=(256, 128), l2_reg_linear=1e-5, - l2_reg_embedding=1e-5, l2_reg_dnn=0, init_std=0.0001, seed=1024, dnn_dropout=0, dnn_activation=F.relu,dnn_use_bn=False, + l2_reg_embedding=1e-5, l2_reg_dnn=0, 
init_std=0.0001, seed=1024, dnn_dropout=0, dnn_activation='relu', dnn_use_bn=False, task='binary', device='cpu'): super(WDL, self).__init__(linear_feature_columns, dnn_feature_columns, embedding_size=embedding_size, diff --git a/deepctr_torch/models/xdeepfm.py b/deepctr_torch/models/xdeepfm.py index d8d6206e..3417ddb8 100644 --- a/deepctr_torch/models/xdeepfm.py +++ b/deepctr_torch/models/xdeepfm.py @@ -40,9 +40,9 @@ class xDeepFM(BaseModel): """ def __init__(self, linear_feature_columns, dnn_feature_columns, embedding_size=8, dnn_hidden_units=(256, 256), - cin_layer_size=(256, 128,), cin_split_half=True, cin_activation=F.relu, l2_reg_linear=0.00001, + cin_layer_size=(256, 128,), cin_split_half=True, cin_activation='relu', l2_reg_linear=0.00001, l2_reg_embedding=0.00001, l2_reg_dnn=0, l2_reg_cin=0, init_std=0.0001, seed=1024, dnn_dropout=0, - dnn_activation=F.relu, dnn_use_bn=False, task='binary', device='cpu'): + dnn_activation='relu', dnn_use_bn=False, task='binary', device='cpu'): super(xDeepFM, self).__init__(linear_feature_columns, dnn_feature_columns, embedding_size=embedding_size, dnn_hidden_units=dnn_hidden_units, diff --git a/docs/source/index.rst b/docs/source/index.rst index c8ca2d43..c3e70528 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -40,6 +40,7 @@ News 09/24/2019 : Add `CCPM <./Features.html#ccpm-convolutional-click-prediction-model>`_ . `Changelog `_ + DisscussionGroup ----------------------- @@ -47,7 +48,6 @@ DisscussionGroup .. image:: ../pics/weichennote.png - .. toctree:: :maxdepth: 2 :caption: Home: diff --git a/examples/run_classification_criteo.py b/examples/run_classification_criteo.py index 606cfcc1..4e5915a8 100644 --- a/examples/run_classification_criteo.py +++ b/examples/run_classification_criteo.py @@ -63,4 +63,4 @@ pred_ans = model.predict(test_model_input, 256) print("") print("test LogLoss", round(log_loss(test[target].values, pred_ans), 4)) - print("test AUC", round(roc_auc_score(test[target].values, pred_ans), 4)) + print("test AUC", round(roc_auc_score(test[target].values, pred_ans), 4)) \ No newline at end of file diff --git a/examples/run_multivalue_movielens.py b/examples/run_multivalue_movielens.py index 07f621a3..311e6012 100644 --- a/examples/run_multivalue_movielens.py +++ b/examples/run_multivalue_movielens.py @@ -15,44 +15,46 @@ def split(x): key2index[key] = len(key2index) + 1 return list(map(lambda x: key2index[x], key_ans)) +if __name__ == "__main__": -data = pd.read_csv("./movielens_sample.txt") -sparse_features = ["movie_id", "user_id", - "gender", "age", "occupation", "zip", ] -target = ['rating'] + data = pd.read_csv("./movielens_sample.txt") + sparse_features = ["movie_id", "user_id", + "gender", "age", "occupation", "zip", ] + target = ['rating'] -# 1.Label Encoding for sparse features,and process sequence features -for feat in sparse_features: - lbe = LabelEncoder() - data[feat] = lbe.fit_transform(data[feat]) -# preprocess the sequence feature + # 1.Label Encoding for sparse features,and process sequence features + for feat in sparse_features: + lbe = LabelEncoder() + data[feat] = lbe.fit_transform(data[feat]) + # preprocess the sequence feature -key2index = {} -genres_list = list(map(split, data['genres'].values)) -genres_length = np.array(list(map(len, genres_list))) -max_len = max(genres_length) -# Notice : padding=`post` -genres_list = pad_sequences(genres_list, maxlen=max_len, padding='post', ) + key2index = {} + genres_list = list(map(split, data['genres'].values)) + genres_length = np.array(list(map(len, 
genres_list))) + max_len = max(genres_length) + # Notice : padding=`post` + genres_list = pad_sequences(genres_list, maxlen=max_len, padding='post', ) -# 2.count #unique features for each sparse field and generate feature config for sequence feature + # 2.count #unique features for each sparse field and generate feature config for sequence feature -fixlen_feature_columns = [SparseFeat(feat, data[feat].nunique()) - for feat in sparse_features] -varlen_feature_columns = [VarLenSparseFeat('genres', len( - key2index) + 1, max_len, 'mean')] # Notice : value 0 is for padding for sequence input feature + fixlen_feature_columns = [SparseFeat(feat, data[feat].nunique()) + for feat in sparse_features] + varlen_feature_columns = [VarLenSparseFeat('genres', len( + key2index) + 1, max_len, 'mean')] # Notice : value 0 is for padding for sequence input feature -linear_feature_columns = fixlen_feature_columns + varlen_feature_columns -dnn_feature_columns = fixlen_feature_columns + varlen_feature_columns -feature_names = get_feature_names(linear_feature_columns + dnn_feature_columns) + linear_feature_columns = fixlen_feature_columns + varlen_feature_columns + dnn_feature_columns = fixlen_feature_columns + varlen_feature_columns + feature_names = get_feature_names(linear_feature_columns + dnn_feature_columns) -# 3.generate input data for model -model_input = {name:data[name] for name in feature_names} -model_input['genres'] = genres_list + # 3.generate input data for model + model_input = {name:data[name] for name in feature_names} + model_input['genres'] = genres_list -# 4.Define Model,compile and train -model = DeepFM(linear_feature_columns,dnn_feature_columns,task='regression') + # 4.Define Model,compile and train + model = DeepFM(linear_feature_columns,dnn_feature_columns,task='regression') -model.compile("adam", "mse", metrics=['mse'], ) -history = model.fit(model_input, data[target].values, - batch_size=256, epochs=10, verbose=2, validation_split=0.2, ) + + model.compile("adam", "mse", metrics=['mse'], ) + history = model.fit(model_input, data[target].values, + batch_size=256, epochs=10, verbose=2, validation_split=0.2, ) \ No newline at end of file diff --git a/tests/layers/__init__.py b/tests/layers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/layers/activation_test.py b/tests/layers/activation_test.py new file mode 100644 index 00000000..9e14faee --- /dev/null +++ b/tests/layers/activation_test.py @@ -0,0 +1,3 @@ +# -*- coding: utf-8 -*- +from deepctr_torch.layers import activation + diff --git a/tests/models/xDeepFM_test.py b/tests/models/xDeepFM_test.py index 2e9a24b8..561a92dc 100644 --- a/tests/models/xDeepFM_test.py +++ b/tests/models/xDeepFM_test.py @@ -10,7 +10,7 @@ [((), (), True, 'linear', 1, 2), ((8,), (), True, 'linear', 1, 1), ((), (8,), True, 'linear', 2, 2), - ((8,), (8,), False, F.relu, 2, 0)] + ((8,), (8,), False, 'relu', 2, 0)] ) def test_xDeepFM(dnn_hidden_units, cin_layer_size, cin_split_half, cin_activation, sparse_feature_num, dense_feature_dim):
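Usage sketch (illustrative, not part of this diff): the two API changes most callers will notice are the new group_name field on feature columns and the switch from torch activation functions to activation names; the vocabulary sizes and column names below are made-up assumptions.

    import torch
    from deepctr_torch.inputs import SparseFeat, VarLenSparseFeat
    from deepctr_torch.layers.activation import Dice
    from deepctr_torch.models import DeepFM

    # Feature columns now carry a group_name (default DEFAULT_GROUP_NAME), used by
    # embedding_lookup()/get_varlen_pooling_list() to bucket embeddings per group.
    fixlen_cols = [SparseFeat('user_id', 100, group_name='user'),
                   SparseFeat('item_id', 1000, group_name='item')]
    varlen_cols = [VarLenSparseFeat('genres', 19, maxlen=5, combiner='mean', group_name='item')]

    # Models now take activation names ('relu', 'dice', 'prelu') instead of F.relu.
    model = DeepFM(fixlen_cols + varlen_cols, fixlen_cols + varlen_cols,
                   dnn_activation='relu', task='regression', device='cpu')

    # Dice can also be used on its own; it accepts 2D [B, F] or 3D [B, F, E] input
    # (its alpha tensor lives on the module-level `device`, so keep inputs on that device).
    dice = Dice(num_features=8, dim=2)
    out = dice(torch.rand(4, 8))                     # same shape as the input: (4, 8)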