diff --git a/README.md b/README.md index 6aa60417..2cbd56b7 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,9 @@ Let's [**Get Started!**](https://deepctr-torch.readthedocs.io/en/latest/Quick-St | Attentional Factorization Machine | [IJCAI 2017][Attentional Factorization Machines: Learning the Weight of Feature Interactions via Attention Networks](http://www.ijcai.org/proceedings/2017/435) | | Neural Factorization Machine | [SIGIR 2017][Neural Factorization Machines for Sparse Predictive Analytics](https://arxiv.org/pdf/1708.05027.pdf) | | xDeepFM | [KDD 2018][xDeepFM: Combining Explicit and Implicit Feature Interactions for Recommender Systems](https://arxiv.org/pdf/1803.05170.pdf) | -| AutoInt | [arxiv 2018][AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks](https://arxiv.org/abs/1810.11921) | +| Deep Interest Network | [KDD 2018][Deep Interest Network for Click-Through Rate Prediction](https://arxiv.org/pdf/1706.06978.pdf) | +| Deep Interest Evolution Network | [AAAI 2019][Deep Interest Evolution Network for Click-Through Rate Prediction](https://arxiv.org/pdf/1809.03672.pdf) | +| AutoInt | [CIKM 2019][AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks](https://arxiv.org/abs/1810.11921) | | ONN | [arxiv 2019][Operation-aware Neural Networks for User Response Prediction](https://arxiv.org/pdf/1904.12579.pdf) | | FiBiNET | [RecSys 2019][FiBiNET: Combining Feature Importance and Bilinear feature Interaction for Click-Through Rate Prediction](https://arxiv.org/pdf/1905.09433.pdf) | @@ -45,25 +47,69 @@ Please follow our wechat to join group: - wechat ID: **deepctrbot** ![wechat](./docs/pics/weichennote.png) + ## Contributors([welcome to join us!](./CONTRIBUTING.md)) - - pic - - - pic - - - pic - - - pic - - - pic - - - pic - - - pic - + + + + + + + + + + + + + + + + + + +
+ ​ pic
+ ​ Shen Weichen ​ +

Founder
+ Zhejiang Unversity

​ +
+ pic
+ Wang Ze ​ +

Core Dev
Beihang University

​ +
+ ​ pic
+ Zhang Wutong +

Core Dev
Beijing University
of Posts and
Telecommunications

​ +
+ ​ pic
+ ​ Zhang Yuefeng +

Core Dev
+ Peking University

​ +
+ ​ pic
+ ​ Huo Junyi +

Core Dev
+ University of Southampton

​ +
+ ​ pic
+ ​ Zeng Kai ​ +

Dev
+ SenseTime

​ +
+ ​ pic
+ ​ Chen K ​ +

Dev
+ NetEase

​ +
+ ​ pic
+ ​ Tang +

Test
+ Tongji University

​ +
+ ​ pic
+ ​ Xu Qidi ​ +

Dev
+ University of
Electronic Science and
Technology of China

​ +
+ Welcome you !! +
\ No newline at end of file diff --git a/deepctr_torch/__init__.py b/deepctr_torch/__init__.py index 47005c76..cbd90cdf 100644 --- a/deepctr_torch/__init__.py +++ b/deepctr_torch/__init__.py @@ -2,5 +2,5 @@ from . import models from .utils import check_version -__version__ = '0.2.0' +__version__ = '0.2.1' check_version(__version__) \ No newline at end of file diff --git a/deepctr_torch/inputs.py b/deepctr_torch/inputs.py index 35c837d0..ea272fb1 100644 --- a/deepctr_torch/inputs.py +++ b/deepctr_torch/inputs.py @@ -4,10 +4,12 @@ Weichen Shen,wcshen1994@163.com """ -from collections import OrderedDict, namedtuple +from collections import OrderedDict, namedtuple, defaultdict +from itertools import chain import torch import torch.nn as nn +import numpy as np from .layers.sequence import SequencePoolingLayer from .layers.utils import concat_fun @@ -27,7 +29,8 @@ def __new__(cls, name, vocabulary_size, embedding_dim=4, use_hash=False, dtype=" if embedding_dim == "auto": embedding_dim = 6 * int(pow(vocabulary_size, 0.25)) if use_hash: - print("Notice! Feature Hashing on the fly currently is not supported in torch version,you can use tensorflow version!") + print( + "Notice! Feature Hashing on the fly currently is not supported in torch version,you can use tensorflow version!") return super(SparseFeat, cls).__new__(cls, name, vocabulary_size, embedding_dim, use_hash, dtype, embedding_name, group_name) @@ -108,7 +111,7 @@ def build_input_features(feature_columns): elif isinstance(feat, VarLenSparseFeat): features[feat_name] = (start, start + feat.maxlen) start += feat.maxlen - if feat.length_name is not None: + if feat.length_name is not None and feat.length_name not in features: features[feat.length_name] = (start, start + 1) start += 1 else: @@ -116,15 +119,6 @@ def build_input_features(feature_columns): return features -# def get_dense_input(features, feature_columns): -# dense_feature_columns = list(filter(lambda x: isinstance( -# x, DenseFeat), feature_columns)) if feature_columns else [] -# dense_input_list = [] -# for fc in dense_feature_columns: -# dense_input_list.append(features[fc.name]) -# return dense_input_list - - def combined_dnn_input(sparse_embedding_list, dense_value_list): if len(sparse_embedding_list) > 0 and len(dense_value_list) > 0: sparse_dnn_input = torch.flatten( @@ -139,72 +133,6 @@ def combined_dnn_input(sparse_embedding_list, dense_value_list): else: raise NotImplementedError - # - # def embedding_lookup(sparse_embedding_dict, sparse_input_dict, sparse_feature_columns, return_feat_list=(), - # mask_feat_list=(), to_list=False): - # """ - # Args: - # sparse_embedding_dict: nn.ModuleDict, {embedding_name: nn.Embedding} - # sparse_input_dict: OrderedDict, {feature_name:(start, start+dimension)} - # sparse_feature_columns: list, sparse features - # return_feat_list: list, names of feature to be returned, defualt () -> return all features - # mask_feat_list, list, names of feature to be masked in hash transform - # Return: - # group_embedding_dict: defaultdict(list) - # """ - # group_embedding_dict = defaultdict(list) - # for fc in sparse_feature_columns: - # feature_name = fc.name - # embedding_name = fc.embedding_name - # if (len(return_feat_list) == 0 or feature_name in return_feat_list): - # if fc.use_hash: - # # lookup_idx = Hash(fc.vocabulary_size, mask_zero=(feature_name in mask_feat_list))( - # # sparse_input_dict[feature_name]) - # # TODO: add hash function - # lookup_idx = sparse_input_dict[feature_name] - # else: - # lookup_idx = sparse_input_dict[feature_name] - # - # group_embedding_dict[fc.group_name].append(sparse_embedding_dict[embedding_name](lookup_idx)) - # if to_list: - # return list(chain.from_iterable(group_embedding_dict.values())) - # return group_embedding_dict - # - # - # def varlen_embedding_lookup(embedding_dict, sequence_input_dict, varlen_sparse_feature_columns): - # varlen_embedding_vec_dict = {} - # for fc in varlen_sparse_feature_columns: - # feature_name = fc.name - # embedding_name = fc.embedding_name - # if fc.use_hash: - # # lookup_idx = Hash(fc.vocabulary_size, mask_zero=True)(sequence_input_dict[feature_name]) - # # TODO: add hash function - # lookup_idx = sequence_input_dict[feature_name] - # else: - # lookup_idx = sequence_input_dict[feature_name] - # varlen_embedding_vec_dict[feature_name] = embedding_dict[embedding_name](lookup_idx) - # return varlen_embedding_vec_dict - # - # - # def get_varlen_pooling_list(embedding_dict, features, varlen_sparse_feature_columns, to_list=False): - # pooling_vec_list = defaultdict(list) - # for fc in varlen_sparse_feature_columns: - # feature_name = fc.name - # combiner = fc.combiner - # feature_length_name = fc.length_name - # if feature_length_name is not None: - # seq_input = embedding_dict[feature_name] - # vec = SequencePoolingLayer(combiner)([seq_input, features[feature_length_name]]) - # else: - # seq_input = embedding_dict[feature_name] - # vec = SequencePoolingLayer(combiner)(seq_input) - # pooling_vec_list[fc.group_name].append(vec) - # - # if to_list: - # return chain.from_iterable(pooling_vec_list.values()) - # - # return pooling_vec_list - def get_varlen_pooling_list(embedding_dict, features, feature_index, varlen_sparse_feature_columns, device): varlen_sparse_embedding_list = [] @@ -249,3 +177,95 @@ def create_embedding_matrix(feature_columns, init_std=0.0001, linear=False, spar nn.init.normal_(tensor.weight, mean=0, std=init_std) return embedding_dict.to(device) + + +def input_from_feature_columns(self, X, feature_columns, embedding_dict, support_dense=True): + sparse_feature_columns = list( + filter(lambda x: isinstance(x, SparseFeat), feature_columns)) if len(feature_columns) else [] + dense_feature_columns = list( + filter(lambda x: isinstance(x, DenseFeat), feature_columns)) if len(feature_columns) else [] + + varlen_sparse_feature_columns = list( + filter(lambda x: isinstance(x, VarLenSparseFeat), feature_columns)) if feature_columns else [] + + if not support_dense and len(dense_feature_columns) > 0: + raise ValueError( + "DenseFeat is not supported in dnn_feature_columns") + + sparse_embedding_list = [embedding_dict[feat.embedding_name]( + X[:, self.feature_index[feat.name][0]:self.feature_index[feat.name][1]].long()) for + feat in sparse_feature_columns] + + varlen_sparse_embedding_list = get_varlen_pooling_list(self.embedding_dict, X, self.feature_index, + varlen_sparse_feature_columns, self.device) + + dense_value_list = [X[:, self.feature_index[feat.name][0]:self.feature_index[feat.name][1]] for feat in + dense_feature_columns] + + return sparse_embedding_list + varlen_sparse_embedding_list, dense_value_list + + + +def embedding_lookup(X, sparse_embedding_dict, sparse_input_dict, sparse_feature_columns, return_feat_list=(), + mask_feat_list=(), to_list=False): + """ + Args: + X: input Tensor [batch_size x hidden_dim] + sparse_embedding_dict: nn.ModuleDict, {embedding_name: nn.Embedding} + sparse_input_dict: OrderedDict, {feature_name:(start, start+dimension)} + sparse_feature_columns: list, sparse features + return_feat_list: list, names of feature to be returned, defualt () -> return all features + mask_feat_list, list, names of feature to be masked in hash transform + Return: + group_embedding_dict: defaultdict(list) + """ + group_embedding_dict = defaultdict(list) + for fc in sparse_feature_columns: + feature_name = fc.name + embedding_name = fc.embedding_name + if (len(return_feat_list) == 0 or feature_name in return_feat_list): + # TODO: add hash function + # if fc.use_hash: + # raise NotImplementedError("hash function is not implemented in this version!") + lookup_idx = np.array(sparse_input_dict[feature_name]) + input_tensor = X[:, lookup_idx[0]:lookup_idx[1]].long() + emb = sparse_embedding_dict[embedding_name](input_tensor) + group_embedding_dict[fc.group_name].append(emb) + if to_list: + return list(chain.from_iterable(group_embedding_dict.values())) + return group_embedding_dict + + +def varlen_embedding_lookup(X, embedding_dict, sequence_input_dict, varlen_sparse_feature_columns): + varlen_embedding_vec_dict = {} + for fc in varlen_sparse_feature_columns: + feature_name = fc.name + embedding_name = fc.embedding_name + if fc.use_hash: + # lookup_idx = Hash(fc.vocabulary_size, mask_zero=True)(sequence_input_dict[feature_name]) + # TODO: add hash function + lookup_idx = sequence_input_dict[feature_name] + else: + lookup_idx = sequence_input_dict[feature_name] + varlen_embedding_vec_dict[feature_name] = embedding_dict[embedding_name]( + X[:, lookup_idx[0]:lookup_idx[1]].long()) # (lookup_idx) + + return varlen_embedding_vec_dict + + +def get_dense_input(X, features, feature_columns): + dense_feature_columns = list(filter(lambda x: isinstance( + x, DenseFeat), feature_columns)) if feature_columns else [] + dense_input_list = [] + for fc in dense_feature_columns: + lookup_idx = np.array(features[fc.name]) + input_tensor = X[:, lookup_idx[0]:lookup_idx[1]].float() + dense_input_list.append(input_tensor) + return dense_input_list + + +def maxlen_lookup(X, sparse_input_dict, maxlen_column): + if maxlen_column is None or len(maxlen_column)==0: + raise ValueError('please add max length column for VarLenSparseFeat of DIEN input') + lookup_idx = np.array(sparse_input_dict[maxlen_column[0]]) + return X[:, lookup_idx[0]:lookup_idx[1]].long() diff --git a/deepctr_torch/layers/__init__.py b/deepctr_torch/layers/__init__.py index 3767b035..8761d4b8 100644 --- a/deepctr_torch/layers/__init__.py +++ b/deepctr_torch/layers/__init__.py @@ -1,4 +1,4 @@ from .interaction import * from .core import * from .utils import concat_fun -from .sequence import KMaxPooling, SequencePoolingLayer +from .sequence import * diff --git a/deepctr_torch/layers/activation.py b/deepctr_torch/layers/activation.py index 73ef23e6..6e2a8c35 100644 --- a/deepctr_torch/layers/activation.py +++ b/deepctr_torch/layers/activation.py @@ -1,47 +1,48 @@ # -*- coding:utf-8 -*- - +import torch import torch.nn as nn -# class Dice(nn.Module): -# """The Data Adaptive Activation Function in DIN,which can be viewed as a generalization of PReLu and can adaptively adjust the rectified point according to distribution of input data. -# -# Input shape: -# - 2 dims: [batch_size, embedding_size(features)] -# - 3 dims: [batch_size, num_features, embedding_size(features)] -# -# Output shape: -# - Same shape as the input. -# -# References -# - [Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]//Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. ACM, 2018: 1059-1068.](https://arxiv.org/pdf/1706.06978.pdf) -# - https://github.com/zhougr1993/DeepInterestNetwork, https://github.com/fanoping/DIN-pytorch -# """ -# def __init__(self, num_features, dim=2, epsilon=1e-9): -# super(Dice, self).__init__() -# assert dim == 2 or dim == 3 -# self.bn = nn.BatchNorm1d(num_features, eps=epsilon) -# self.sigmoid = nn.Sigmoid() -# self.dim = dim -# -# if self.dim == 2: -# self.alpha = torch.zeros((num_features,)).to(device) -# else: -# self.alpha = torch.zeros((num_features, 1)).to(device) -# -# def forward(self, x): -# # x shape: [batch_size, num_features, embedding_size(features)] -# assert x.dim() == 2 or x.dim() == 3 -# -# if self.dim == 2: -# x_p = self.sigmoid(self.bn(x)) -# out = self.alpha * (1 - x_p) * x + x_p * x -# else: -# x = torch.transpose(x, 1, 2) -# x_p = self.sigmoid(self.bn(x)) -# out = self.alpha * (1 - x_p) * x + x_p * x -# out = torch.transpose(out, 1, 2) -# -# return out + +class Dice(nn.Module): + """The Data Adaptive Activation Function in DIN,which can be viewed as a generalization of PReLu and can adaptively adjust the rectified point according to distribution of input data. + + Input shape: + - 2 dims: [batch_size, embedding_size(features)] + - 3 dims: [batch_size, num_features, embedding_size(features)] + + Output shape: + - Same shape as input. + + References + - [Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]//Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. ACM, 2018: 1059-1068.](https://arxiv.org/pdf/1706.06978.pdf) + - https://github.com/zhougr1993/DeepInterestNetwork, https://github.com/fanoping/DIN-pytorch + """ + def __init__(self, emb_size, dim=2, epsilon=1e-8, device='cpu'): + super(Dice, self).__init__() + assert dim == 2 or dim == 3 + + self.bn = nn.BatchNorm1d(emb_size, eps=epsilon) + self.sigmoid = nn.Sigmoid() + self.dim = dim + + if self.dim == 2: + self.alpha = torch.zeros((emb_size,)).to(device) + else: + self.alpha = torch.zeros((emb_size, 1)).to(device) + + def forward(self, x): + assert x.dim() == self.dim + if self.dim == 2: + x_p = self.sigmoid(self.bn(x)) + out = self.alpha * (1 - x_p) * x + x_p * x + else: + x = torch.transpose(x, 1, 2) + x_p = self.sigmoid(self.bn(x)) + out = self.alpha * (1 - x_p) * x + x_p * x + out = torch.transpose(out, 1, 2) + + return out + class Identity(nn.Module): @@ -64,9 +65,11 @@ def activation_layer(act_name, hidden_size=None, dice_dim=2): act_layer: activation layer """ if isinstance(act_name, str): - if act_name.lower() == 'linear': + if act_name.lower() == 'sigmoid': + act_layer = nn.Sigmoid() + elif act_name.lower() == 'linear': act_layer = Identity() - if act_name.lower() == 'relu': + elif act_name.lower() == 'relu': act_layer = nn.ReLU(inplace=True) elif act_name.lower() == 'dice': assert dice_dim @@ -83,12 +86,3 @@ def activation_layer(act_name, hidden_size=None, dice_dim=2): if __name__ == "__main__": pass - #device = 'cuda:0' if torch.cuda.is_available() else 'cpu' - - # torch.manual_seed(7) - # a = Dice(3) - # b = torch.rand((5, 3)) - # c = a(b) - # print(c.size()) - # print('b:', b) - # print('c:', c) diff --git a/deepctr_torch/layers/core.py b/deepctr_torch/layers/core.py index 8f1b7d99..b64d2d31 100644 --- a/deepctr_torch/layers/core.py +++ b/deepctr_torch/layers/core.py @@ -34,36 +34,32 @@ class LocalActivationUnit(nn.Module): - [Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]//Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. ACM, 2018: 1059-1068.](https://arxiv.org/pdf/1706.06978.pdf) """ - def __init__(self, hidden_units=[80, 40], embedding_dim=4, activation='Dice', dropout_rate=0, use_bn=False): + def __init__(self, hidden_units=(64, 32), embedding_dim=4, activation='sigmoid', dropout_rate=0, dice_dim=3, l2_reg=0, use_bn=False): super(LocalActivationUnit, self).__init__() - self.dnn1 = DNN(inputs_dim=4 * embedding_dim, + self.dnn = DNN(inputs_dim=4 * embedding_dim, hidden_units=hidden_units, activation=activation, - dropout_rate=0.5, - use_bn=use_bn, - dice_dim=3) - - # self.dnn2 = DNN(inputs_dim=hidden_units[-1], - # hidden_units=[1], - # activation=activation, - # use_bn=use_bn, - # dice_dim=3) + l2_reg=l2_reg, + dropout_rate=dropout_rate, + dice_dim=dice_dim, + use_bn=use_bn) self.dense = nn.Linear(hidden_units[-1], 1) def forward(self, query, user_behavior): # query ad : size -> batch_size * 1 * embedding_size # user behavior : size -> batch_size * time_seq_len * embedding_size - user_behavior_len = user_behavior.size(1) - queries = torch.cat([query for _ in range(user_behavior_len)], dim=1) - attention_input = torch.cat([queries, user_behavior, queries - user_behavior, queries * user_behavior], dim=-1) - attention_output = self.dnn1(attention_input) - attention_output = self.dense(attention_output) + queries = query.expand(-1, user_behavior_len, -1) + + attention_input = torch.cat([queries, user_behavior, queries - user_behavior, queries * user_behavior], dim=-1) # as the source code, subtraction simulates verctors' difference + attention_output = self.dnn(attention_input) + + attention_score = self.dense(attention_output) # [B, T, 1] - return attention_output + return attention_score class DNN(nn.Module): diff --git a/deepctr_torch/layers/interaction.py b/deepctr_torch/layers/interaction.py index 84afa51e..19140be7 100644 --- a/deepctr_torch/layers/interaction.py +++ b/deepctr_torch/layers/interaction.py @@ -387,7 +387,7 @@ def forward(self, inputs): 'bnik,bnjk->bnij', querys, keys) # head_num None F F self.normalized_att_scores = F.softmax( - inner_product, dim=1) # head_num None F F + inner_product, dim=-1) # head_num None F F result = torch.matmul(self.normalized_att_scores, values) # head_num None F D diff --git a/deepctr_torch/layers/sequence.py b/deepctr_torch/layers/sequence.py index 33b30492..414d9980 100644 --- a/deepctr_torch/layers/sequence.py +++ b/deepctr_torch/layers/sequence.py @@ -1,5 +1,9 @@ import torch import torch.nn as nn +import torch.nn.functional as F + +from torch.nn.utils.rnn import PackedSequence +from ..layers.core import LocalActivationUnit class SequencePoolingLayer(nn.Module): @@ -27,15 +31,15 @@ def __init__(self, mode='mean', supports_masking=False, device='cpu'): raise ValueError('parameter mode should in [sum, mean, max]') self.supports_masking = supports_masking self.mode = mode + self.device = device self.eps = torch.FloatTensor([1e-8]).to(device) self.to(device) - def _sequence_mask(self, lengths, maxlen=None, dtype=torch.bool): # Returns a mask tensor representing the first N positions of each cell. if maxlen is None: maxlen = lengths.max() - row_vector = torch.arange(0, maxlen, 1) + row_vector = torch.arange(0, maxlen, 1).to(self.device) matrix = torch.unsqueeze(lengths, dim=-1) mask = row_vector < matrix @@ -62,7 +66,8 @@ def forward(self, seq_value_len_list): hist = uiseq_embed_list - (1 - mask) * 1e9 hist = torch.max(hist, dim=1, keepdim=True)[0] return hist - hist = torch.sum(uiseq_embed_list * mask, dim=1, keepdim=False) + hist = uiseq_embed_list * mask.float() + hist = torch.sum(hist, dim=1, keepdim=False) if self.mode == 'mean': hist = torch.div(hist, user_behavior_length.type(torch.float32) + self.eps) @@ -71,58 +76,80 @@ def forward(self, seq_value_len_list): return hist -# class AttentionSequencePoolingLayer(nn.Module): -# """The Attentional sequence pooling operation used in DIN. -# -# Input shape -# - A list of three tensor: [query,keys,keys_length] -# -# - query is a 3D tensor with shape: ``(batch_size, 1, embedding_size)`` -# -# - keys is a 3D tensor with shape: ``(batch_size, T, embedding_size)`` -# -# - keys_length is a 2D tensor with shape: ``(batch_size, 1)`` -# -# Output shape -# - 3D tensor with shape: ``(batch_size, 1, embedding_size)`` -# -# Arguments -# - **att_hidden_units**: List of positive integer, the attention net layer number and units in each layer. -# -# - **embedding_dim**: Dimension of the input embeddings. -# -# - **activation**: Activation function to use in attention net. -# -# - **weight_normalization**: bool.Whether normalize the attention score of local activation unit. -# -# References -# - [Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]//Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. ACM, 2018: 1059-1068.](https://arxiv.org/pdf/1706.06978.pdf) -# """ -# -# def __init__(self, att_hidden_units=[80, 40], embedding_dim=4, activation='Dice', weight_normalization=False): -# super(AttentionSequencePoolingLayer, self).__init__() -# -# self.local_att = LocalActivationUnit(hidden_units=att_hidden_units, embedding_dim=embedding_dim, -# activation=activation) -# -# def forward(self, query, keys, keys_length): -# # query: [B, 1, E], keys: [B, T, E], keys_length: [B, 1] -# # TODO: Mini-batch aware regularization in originial paper [Zhou G, et al. 2018] is not implemented here. As the authors mentioned -# # it is not a must for small dataset as the open-sourced ones. -# attention_score = self.local_att(query, keys) -# attention_score = torch.transpose(attention_score, 1, 2) # B * 1 * T -# -# # define mask by length -# keys_length = keys_length.type(torch.LongTensor) -# mask = torch.arange(keys.size(1))[None, :] < keys_length[:, None] # [1, T] < [B, 1, 1] -> [B, 1, T] -# -# # mask -# output = torch.mul(attention_score, mask.type(torch.FloatTensor)) # [B, 1, T] -# -# # multiply weight -# output = torch.matmul(output, keys) # [B, 1, E] -# -# return output +class AttentionSequencePoolingLayer(nn.Module): + """The Attentional sequence pooling operation used in DIN & DIEN. + + Arguments + - **att_hidden_units**:list of positive integer, the attention net layer number and units in each layer. + + - **att_activation**: Activation function to use in attention net. + + - **weight_normalization**: bool.Whether normalize the attention score of local activation unit. + + - **supports_masking**:If True,the input need to support masking. + + References + - [Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]//Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. ACM, 2018: 1059-1068.](https://arxiv.org/pdf/1706.06978.pdf) + """ + + def __init__(self, att_hidden_units=(80, 40), att_activation='sigmoid', weight_normalization=False, + return_score=False, supports_masking=False, embedding_dim=4, **kwargs): + super(AttentionSequencePoolingLayer, self).__init__() + self.return_score = return_score + self.weight_normalization = weight_normalization + self.supports_masking = supports_masking + self.local_att = LocalActivationUnit(hidden_units=att_hidden_units, embedding_dim=embedding_dim, + activation=att_activation, + dropout_rate=0, use_bn=False) + + def forward(self, query, keys, keys_length, mask=None): + """ + Input shape + - A list of three tensor: [query,keys,keys_length] + + - query is a 3D tensor with shape: ``(batch_size, 1, embedding_size)`` + + - keys is a 3D tensor with shape: ``(batch_size, T, embedding_size)`` + + - keys_length is a 2D tensor with shape: ``(batch_size, 1)`` + + Output shape + - 3D tensor with shape: ``(batch_size, 1, embedding_size)``. + """ + batch_size, max_length, dim = keys.size() + + # Mask + if self.supports_masking: + if mask is None: + raise ValueError("When supports_masking=True,input must support masking") + keys_masks = mask.unsqueeze(1) + else: + keys_masks = torch.arange(max_length, device=keys_length.device, dtype=keys_length.dtype).repeat(batch_size, 1) # [B, T] + keys_masks = keys_masks < keys_length.view(-1, 1) # 0, 1 mask + keys_masks = keys_masks.unsqueeze(1) # [B, 1, T] + + attention_score = self.local_att(query, keys) # [B, T, 1] + + outputs = torch.transpose(attention_score, 1, 2) # [B, 1, T] + + if self.weight_normalization: + paddings = torch.ones_like(outputs) * (-2 ** 32 + 1) + else: + paddings = torch.zeros_like(outputs) + + outputs = torch.where(keys_masks, outputs, paddings) # [B, 1, T] + + # Scale + #outputs = outputs / (keys.shape[-1] ** 0.05) + + if self.weight_normalization: + outputs = F.softmax(outputs,dim=-1) # [B, 1, T] + + if not self.return_score: + # Weighted sum + outputs = torch.matmul(outputs, keys) # [B, 1, E] + + return outputs class KMaxPooling(nn.Module): @@ -158,3 +185,130 @@ def forward(self, input): out = torch.topk(input, k=self.k, dim=self.axis, sorted=True)[0] return out + + +class AGRUCell(nn.Module): + """ Attention based GRU (AGRU) + + Reference: + - Deep Interest Evolution Network for Click-Through Rate Prediction[J]. arXiv preprint arXiv:1809.03672, 2018. + """ + + def __init__(self, input_size, hidden_size, bias=True): + super(AGRUCell, self).__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.bias = bias + # (W_ir|W_iz|W_ih) + self.weight_ih = nn.Parameter(torch.Tensor(3 * hidden_size, input_size)) + self.register_parameter('weight_ih', self.weight_ih) + # (W_hr|W_hz|W_hh) + self.weight_hh = nn.Parameter(torch.Tensor(3 * hidden_size, hidden_size)) + self.register_parameter('weight_hh', self.weight_hh) + if bias: + # (b_ir|b_iz|b_ih) + self.bias_ih = nn.Parameter(torch.Tensor(3 * hidden_size)) + self.register_parameter('bias_ih', self.bias_ih) + # (b_hr|b_hz|b_hh) + self.bias_hh = nn.Parameter(torch.Tensor(3 * hidden_size)) + self.register_parameter('bias_hh', self.bias_hh) + else: + self.register_parameter('bias_ih', None) + self.register_parameter('bias_hh', None) + + def forward(self, input, hx, att_score): + gi = F.linear(input, self.weight_ih, self.bias_ih) + gh = F.linear(hx, self.weight_hh, self.bias_hh) + i_r, i_z, i_n = gi.chunk(3, 1) + h_r, h_z, h_n = gh.chunk(3, 1) + + reset_gate = torch.sigmoid(i_r + h_r) + # update_gate = torch.sigmoid(i_z + h_z) + new_state = torch.tanh(i_n + reset_gate * h_n) + + att_score = att_score.view(-1, 1) + hy = (1. - att_score) * hx + att_score * new_state + return hy + + +class AUGRUCell(nn.Module): + """ Effect of GRU with attentional update gate (AUGRU) + + Reference: + - Deep Interest Evolution Network for Click-Through Rate Prediction[J]. arXiv preprint arXiv:1809.03672, 2018. + """ + + def __init__(self, input_size, hidden_size, bias=True): + super(AUGRUCell, self).__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.bias = bias + # (W_ir|W_iz|W_ih) + self.weight_ih = nn.Parameter(torch.Tensor(3 * hidden_size, input_size)) + self.register_parameter('weight_ih', self.weight_ih) + # (W_hr|W_hz|W_hh) + self.weight_hh = nn.Parameter(torch.Tensor(3 * hidden_size, hidden_size)) + self.register_parameter('weight_hh', self.weight_hh) + if bias: + # (b_ir|b_iz|b_ih) + self.bias_ih = nn.Parameter(torch.Tensor(3 * hidden_size)) + self.register_parameter('bias_ih', self.bias_ih) + # (b_hr|b_hz|b_hh) + self.bias_hh = nn.Parameter(torch.Tensor(3 * hidden_size)) + self.register_parameter('bias_ih', self.bias_hh) + else: + self.register_parameter('bias_ih', None) + self.register_parameter('bias_hh', None) + + def forward(self, input, hx, att_score): + gi = F.linear(input, self.weight_ih, self.bias_ih) + gh = F.linear(hx, self.weight_hh, self.bias_hh) + i_r, i_z, i_n = gi.chunk(3, 1) + h_r, h_z, h_n = gh.chunk(3, 1) + + reset_gate = torch.sigmoid(i_r + h_r) + update_gate = torch.sigmoid(i_z + h_z) + new_state = torch.tanh(i_n + reset_gate * h_n) + + att_score = att_score.view(-1, 1) + update_gate = att_score * update_gate + hy = (1. - update_gate) * hx + update_gate * new_state + return hy + + +class DynamicGRU(nn.Module): + def __init__(self, input_size, hidden_size, bias=True, gru_type='AGRU'): + super(DynamicGRU, self).__init__() + self.input_size = input_size + self.hidden_size = hidden_size + + if gru_type == 'AGRU': + self.rnn = AGRUCell(input_size, hidden_size, bias) + elif gru_type == 'AUGRU': + self.rnn = AUGRUCell(input_size, hidden_size, bias) + + def forward(self, input, att_scores=None, hx=None): + if not isinstance(input, PackedSequence) or not isinstance(att_scores, PackedSequence): + raise NotImplementedError("DynamicGRU only supports packed input and att_scores") + + input, batch_sizes, sorted_indices, unsorted_indices = input + att_scores, _, _, _ = att_scores + + max_batch_size = int(batch_sizes[0]) + if hx is None: + hx = torch.zeros(max_batch_size, self.hidden_size, + dtype=input.dtype, device=input.device) + + outputs = torch.zeros(input.size(0), self.hidden_size, + dtype=input.dtype, device=input.device) + + begin = 0 + for batch in batch_sizes: + new_hx = self.rnn( + input[begin:begin + batch], + hx[0:batch], + att_scores[begin:begin + batch]) + outputs[begin:begin + batch] = new_hx + hx = new_hx + begin += batch + return PackedSequence(outputs, batch_sizes, sorted_indices, unsorted_indices) diff --git a/deepctr_torch/models/__init__.py b/deepctr_torch/models/__init__.py index c2f0180a..2fcb1aab 100644 --- a/deepctr_torch/models/__init__.py +++ b/deepctr_torch/models/__init__.py @@ -10,4 +10,7 @@ from .onn import ONN from .pnn import PNN from .ccpm import CCPM +from .dien import DIEN +from .din import DIN + NFFM = ONN \ No newline at end of file diff --git a/deepctr_torch/models/basemodel.py b/deepctr_torch/models/basemodel.py index 10f0f52e..d946136e 100644 --- a/deepctr_torch/models/basemodel.py +++ b/deepctr_torch/models/basemodel.py @@ -18,7 +18,8 @@ from torch.utils.data import DataLoader from tqdm import tqdm -from ..inputs import build_input_features, SparseFeat, DenseFeat, VarLenSparseFeat, get_varlen_pooling_list,create_embedding_matrix +from ..inputs import build_input_features, SparseFeat, DenseFeat, VarLenSparseFeat, get_varlen_pooling_list, \ + create_embedding_matrix from ..layers import PredictionLayer from ..layers.utils import slice_arrays @@ -36,7 +37,8 @@ def __init__(self, feature_columns, feature_index, init_std=0.0001, device='cpu' self.varlen_sparse_feature_columns = list( filter(lambda x: isinstance(x, VarLenSparseFeat), feature_columns)) if len(feature_columns) else [] - self.embedding_dict = create_embedding_matrix(feature_columns,init_std,linear=True,sparse=False,device=device) + self.embedding_dict = create_embedding_matrix(feature_columns, init_std, linear=True, sparse=False, + device=device) # nn.ModuleDict( # {feat.embedding_name: nn.Embedding(feat.dimension, 1, sparse=True) for feat in @@ -83,7 +85,6 @@ def forward(self, X): class BaseModel(nn.Module): - def __init__(self, linear_feature_columns, dnn_feature_columns, dnn_hidden_units=(128, 128), l2_reg_linear=1e-5, @@ -95,13 +96,14 @@ def __init__(self, self.dnn_feature_columns = dnn_feature_columns self.reg_loss = torch.zeros((1,), device=device) + self.aux_loss = torch.zeros((1,), device=device) self.device = device # device self.feature_index = build_input_features( linear_feature_columns + dnn_feature_columns) self.dnn_feature_columns = dnn_feature_columns - self.embedding_dict = create_embedding_matrix(dnn_feature_columns,init_std,sparse=False,device=device) + self.embedding_dict = create_embedding_matrix(dnn_feature_columns, init_std, sparse=False, device=device) # nn.ModuleDict( # {feat.embedding_name: nn.Embedding(feat.dimension, embedding_size, sparse=True) for feat in # self.dnn_feature_columns} @@ -115,7 +117,7 @@ def __init__(self, self.add_regularization_loss( self.linear_model.parameters(), l2_reg_linear) - self.out = PredictionLayer(task,) + self.out = PredictionLayer(task, ) self.to(device) def fit(self, x=None, @@ -127,7 +129,7 @@ def fit(self, x=None, validation_split=0., validation_data=None, shuffle=True, - use_double=False ): + use_double=False): """ :param x: Numpy array of training data (if the model has a single input), or list of Numpy arrays (if the model has multiple inputs).If input layers in the model are named, you can also pass a @@ -215,7 +217,7 @@ def fit(self, x=None, optim.zero_grad() loss = loss_func(y_pred, y.squeeze(), reduction='sum') - total_loss = loss + self.reg_loss + total_loss = loss + self.reg_loss + self.aux_loss loss_epoch += loss.item() total_loss_epoch += total_loss.item() @@ -360,7 +362,10 @@ def add_regularization_loss(self, weight_list, weight_decay, p=2): l2_reg = torch.norm(w, p=p, ) reg_loss = reg_loss + l2_reg reg_loss = weight_decay * reg_loss - self.reg_loss += reg_loss + self.reg_loss = self.reg_loss + reg_loss + + def add_auxiliary_loss(self, aux_loss, alpha): + self.aux_loss = aux_loss * alpha def compile(self, optimizer, loss=None, diff --git a/deepctr_torch/models/dien.py b/deepctr_torch/models/dien.py new file mode 100644 index 00000000..3767aad8 --- /dev/null +++ b/deepctr_torch/models/dien.py @@ -0,0 +1,382 @@ +""" +Author: + Ze Wang, wangze0801@126.com + +Reference: + [1] Zhou G, Mou N, Fan Y, et al. Deep Interest Evolution Network for Click-Through Rate Prediction[J]. arXiv preprint arXiv:1809.03672, 2018. (https://arxiv.org/pdf/1809.03672.pdf) +""" + +from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence + +from .basemodel import BaseModel +from ..layers import * +from ..inputs import * + + +class DIEN(BaseModel): + """Instantiates the Deep Interest Evolution Network architecture. + + :param dnn_feature_columns: An iterable containing all the features used by deep part of the model. + :param history_feature_list: list,to indicate sequence sparse field + :param gru_type: str,can be GRU AIGRU AUGRU AGRU + :param use_negsampling: bool, whether or not use negtive sampling + :param alpha: float ,weight of auxiliary_loss + :param use_bn: bool. Whether use BatchNormalization before activation or not in deep net + :param dnn_hidden_units: list,list of positive integer or empty list, the layer number and units in each layer of DNN + :param dnn_activation: Activation function to use in DNN + :param att_hidden_units: list,list of positive integer , the layer number and units in each layer of attention net + :param att_activation: Activation function to use in attention net + :param att_weight_normalization: bool.Whether normalize the attention score of local activation unit. + :param l2_reg_dnn: float. L2 regularizer strength applied to DNN + :param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector + :param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate. + :param init_std: float,to use as the initialize std of embedding vector + :param seed: integer ,to use as random seed. + :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss + :param device: str, ``"cpu"`` or ``"cuda:0"`` + :return: A PyTorch model instance. + """ + + def __init__(self, + dnn_feature_columns, history_feature_list, + gru_type="GRU", use_negsampling=False, alpha=1.0, use_bn=False, dnn_hidden_units=(256, 128), + dnn_activation='relu', + att_hidden_units=(64, 16), att_activation="relu", att_weight_normalization=True, + l2_reg_dnn=0, l2_reg_embedding=1e-6, dnn_dropout=0, init_std=0.0001, seed=1024, task='binary', + device='cpu'): + super(DIEN, self).__init__([], dnn_feature_columns, dnn_hidden_units=dnn_hidden_units, + l2_reg_linear=0, l2_reg_embedding=l2_reg_embedding, + l2_reg_dnn=l2_reg_dnn, init_std=init_std, seed=seed, + dnn_dropout=dnn_dropout, dnn_activation=dnn_activation, + task=task, device=device) + + self.item_features = history_feature_list + self.use_negsampling = use_negsampling + self.alpha = alpha + self._split_columns() + + # structure: embedding layer -> interest extractor layer -> interest evolution layer -> DNN layer -> out + + # embedding layer + # inherit -> self.embedding_dict + input_size = self._compute_interest_dim() + # interest extractor layer + self.interest_extractor = InterestExtractor(input_size=input_size, use_neg=use_negsampling, init_std=init_std) + # interest evolution layer + self.interest_evolution = InterestEvolving( + input_size=input_size, + gru_type=gru_type, + use_neg=use_negsampling, + init_std=init_std, + att_hidden_size=att_hidden_units, + att_activation=att_activation, + att_weight_normalization=att_weight_normalization) + # DNN layer + dnn_input_size = self._compute_dnn_dim() + input_size + self.dnn = DNN(dnn_input_size, dnn_hidden_units, dnn_activation, l2_reg_dnn, dnn_dropout, use_bn, + init_std=init_std, seed=seed) + self.linear = nn.Linear(dnn_hidden_units[-1], 1, bias=False) + # prediction layer + # inherit -> self.out + + # init + for name, tensor in self.linear.named_parameters(): + if 'weight' in name: + nn.init.normal_(tensor, mean=0, std=init_std) + + self.to(device) + + def forward(self, X): + # [B, H] , [B, T, H], [B, T, H] , [B] + query_emb, keys_emb, neg_keys_emb, keys_length = self._get_emb(X) + # [b, T, H], [1] (b 0 + masked_keys_length = keys_length[mask] + + # batch_size validation check + if masked_keys_length.shape[0] == 0: + return zero_outputs, + + masked_keys = torch.masked_select(keys, mask.view(-1, 1, 1)).view(-1, max_length, dim) + + packed_keys = pack_padded_sequence(masked_keys, lengths=masked_keys_length, batch_first=True, + enforce_sorted=False) + packed_interests, _ = self.gru(packed_keys) + interests, _ = pad_packed_sequence(packed_interests, batch_first=True, padding_value=0.0, + total_length=max_length) + + if self.use_neg and neg_keys is not None: + masked_neg_keys = torch.masked_select(neg_keys, mask.view(-1, 1, 1)).view(-1, max_length, dim) + aux_loss = self._cal_auxiliary_loss( + interests[:, :-1, :], + masked_keys[:, 1:, :], + masked_neg_keys[:, 1:, :], + masked_keys_length - 1) + + return interests, aux_loss + + def _cal_auxiliary_loss(self, states, click_seq, noclick_seq, keys_length): + # keys_length >= 1 + mask_shape = keys_length > 0 + keys_length = keys_length[mask_shape] + if keys_length.shape[0] == 0: + return torch.zeros((1,), device=states.device) + + _, max_seq_length, embedding_size = states.size() + states = torch.masked_select(states, mask_shape.view(-1, 1, 1)).view(-1, max_seq_length, embedding_size) + click_seq = torch.masked_select(click_seq, mask_shape.view(-1, 1, 1)).view(-1, max_seq_length, embedding_size) + noclick_seq = torch.masked_select(noclick_seq, mask_shape.view(-1, 1, 1)).view(-1, max_seq_length, + embedding_size) + batch_size = states.size()[0] + + mask = (torch.arange(max_seq_length, device=states.device).repeat( + batch_size, 1) < keys_length.view(-1, 1)).float() + + click_input = torch.cat([states, click_seq], dim=-1) + noclick_input = torch.cat([states, noclick_seq], dim=-1) + embedding_size = embedding_size * 2 + + click_p = self.auxiliary_net(click_input.view( + batch_size * max_seq_length, embedding_size)).view( + batch_size, max_seq_length)[mask > 0].view(-1, 1) + click_target = torch.ones( + click_p.size(), dtype=torch.float, device=click_p.device) + + noclick_p = self.auxiliary_net(noclick_input.view( + batch_size * max_seq_length, embedding_size)).view( + batch_size, max_seq_length)[mask > 0].view(-1, 1) + noclick_target = torch.zeros( + noclick_p.size(), dtype=torch.float, device=noclick_p.device) + + loss = F.binary_cross_entropy( + torch.cat([click_p, noclick_p], dim=0), + torch.cat([click_target, noclick_target], dim=0)) + + return loss + + +class InterestEvolving(nn.Module): + __SUPPORTED_GRU_TYPE__ = ['GRU', 'AIGRU', 'AGRU', 'AUGRU'] + + def __init__(self, + input_size, + gru_type='GRU', + use_neg=False, + init_std=0.001, + att_hidden_size=(64, 16), + att_activation='sigmoid', + att_weight_normalization=False): + super(InterestEvolving, self).__init__() + if gru_type not in InterestEvolving.__SUPPORTED_GRU_TYPE__: + raise NotImplementedError("gru_type: {gru_type} is not supported") + self.gru_type = gru_type + self.use_neg = use_neg + + if gru_type == 'GRU': + self.attention = AttentionSequencePoolingLayer(embedding_dim=input_size, + att_hidden_units=att_hidden_size, + att_activation=att_activation, + weight_normalization=att_weight_normalization, + return_score=False) + self.interest_evolution = nn.GRU(input_size=input_size, hidden_size=input_size, batch_first=True) + elif gru_type == 'AIGRU': + self.attention = AttentionSequencePoolingLayer(embedding_dim=input_size, + att_hidden_units=att_hidden_size, + att_activation=att_activation, + weight_normalization=att_weight_normalization, + return_score=True) + self.interest_evolution = nn.GRU(input_size=input_size, hidden_size=input_size, batch_first=True) + elif gru_type == 'AGRU' or gru_type == 'AUGRU': + self.attention = AttentionSequencePoolingLayer(embedding_dim=input_size, + att_hidden_units=att_hidden_size, + att_activation=att_activation, + weight_normalization=att_weight_normalization, + return_score=True) + self.interest_evolution = DynamicGRU(input_size=input_size, hidden_size=input_size, + gru_type=gru_type) + for name, tensor in self.interest_evolution.named_parameters(): + if 'weight' in name: + nn.init.normal_(tensor, mean=0, std=init_std) + + @staticmethod + def _get_last_state(states, keys_length): + # states [B, T, H] + batch_size, max_seq_length, hidden_size = states.size() + + mask = (torch.arange(max_seq_length, device=keys_length.device).repeat( + batch_size, 1) == (keys_length.view(-1, 1) - 1)) + + return states[mask] + + def forward(self, query, keys, keys_length, mask=None): + """ + Parameters + ---------- + query: 2D tensor, [B, H] + keys: (masked_interests), 3D tensor, [b, T, H] + keys_length: 1D tensor, [B] + + Returns + ------- + outputs: 2D tensor, [B, H] + """ + batch_size, dim = query.size() + max_length = keys.size()[1] + + # check batch validation + zero_outputs = torch.zeros(batch_size, dim, device=query.device) + mask = keys_length > 0 + # [B] -> [b] + keys_length = keys_length[mask] + if keys_length.shape[0] == 0: + return zero_outputs + + # [B, H] -> [b, 1, H] + query = torch.masked_select(query, mask.view(-1, 1)).view(-1, dim).unsqueeze(1) + + if self.gru_type == 'GRU': + packed_keys = pack_padded_sequence(keys, lengths=keys_length, batch_first=True, enforce_sorted=False) + packed_interests, _ = self.interest_evolution(packed_keys) + interests, _ = pad_packed_sequence(packed_interests, batch_first=True, padding_value=0.0, + total_length=max_length) + outputs = self.attention(query, interests, keys_length.unsqueeze(1)) # [b, 1, H] + outputs = outputs.squeeze(1) # [b, H] + elif self.gru_type == 'AIGRU': + att_scores = self.attention(query, keys, keys_length.unsqueeze(1)) # [b, 1, T] + interests = keys * att_scores.transpose(1, 2) # [b, T, H] + packed_interests = pack_padded_sequence(interests, lengths=keys_length, batch_first=True, + enforce_sorted=False) + _, outputs = self.interest_evolution(packed_interests) + outputs = outputs.squeeze(0) # [b, H] + elif self.gru_type == 'AGRU' or self.gru_type == 'AUGRU': + att_scores = self.attention(query, keys, keys_length.unsqueeze(1)).squeeze(1) # [b, T] + packed_interests = pack_padded_sequence(keys, lengths=keys_length, batch_first=True, + enforce_sorted=False) + packed_scores = pack_padded_sequence(att_scores, lengths=keys_length, batch_first=True, + enforce_sorted=False) + outputs = self.interest_evolution(packed_interests, packed_scores) + outputs, _ = pad_packed_sequence(outputs, batch_first=True, padding_value=0.0, total_length=max_length) + # pick last state + outputs = InterestEvolving._get_last_state(outputs, keys_length) # [b, H] + # [b, H] -> [B, H] + zero_outputs[mask] = outputs + return zero_outputs diff --git a/deepctr_torch/models/din.py b/deepctr_torch/models/din.py new file mode 100644 index 00000000..42afad63 --- /dev/null +++ b/deepctr_torch/models/din.py @@ -0,0 +1,135 @@ +# -*- coding:utf-8 -*- +""" +Author: + Yuef Zhang +Reference: + [1] Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]//Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. ACM, 2018: 1059-1068. (https://arxiv.org/pdf/1706.06978.pdf) +""" + +from .basemodel import BaseModel +from ..inputs import * +from ..layers import * +from ..layers.sequence import AttentionSequencePoolingLayer + + +class DIN(BaseModel): + """Instantiates the Deep Interest Network architecture. + + :param dnn_feature_columns: An iterable containing all the features used by deep part of the model. + :param history_feature_list: list,to indicate sequence sparse field + :param dnn_use_bn: bool. Whether use BatchNormalization before activation or not in deep net + :param dnn_hidden_units: list,list of positive integer or empty list, the layer number and units in each layer of deep net + :param dnn_activation: Activation function to use in deep net + :param att_hidden_size: list,list of positive integer , the layer number and units in each layer of attention net + :param att_activation: Activation function to use in attention net + :param att_weight_normalization: bool. Whether normalize the attention score of local activation unit. + :param l2_reg_dnn: float. L2 regularizer strength applied to DNN + :param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector + :param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate. + :param init_std: float,to use as the initialize std of embedding vector + :param seed: integer ,to use as random seed. + :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss + :return: A PyTorch model instance. + + """ + + def __init__(self, dnn_feature_columns, history_feature_list, dnn_use_bn=False, + dnn_hidden_units=(256, 128), dnn_activation='relu', att_hidden_size=(64, 16), + att_activation='Dice', att_weight_normalization=False, l2_reg_dnn=0.0, + l2_reg_embedding=1e-6, dnn_dropout=0, init_std=0.0001, + seed=1024, task='binary', device='cpu'): + super(DIN, self).__init__([], dnn_feature_columns, + dnn_hidden_units=dnn_hidden_units, l2_reg_linear=0, + l2_reg_dnn=l2_reg_dnn, init_std=init_std, + l2_reg_embedding=l2_reg_embedding, + dnn_dropout=dnn_dropout, dnn_activation=dnn_activation, + seed=seed, task=task, + device=device) + + self.sparse_feature_columns = list( + filter(lambda x: isinstance(x, SparseFeat), dnn_feature_columns)) if dnn_feature_columns else [] + self.varlen_sparse_feature_columns = list( + filter(lambda x: isinstance(x, VarLenSparseFeat), dnn_feature_columns)) if dnn_feature_columns else [] + + self.history_feature_list = history_feature_list + + self.history_feature_columns = [] + self.sparse_varlen_feature_columns = [] + self.history_fc_names = list(map(lambda x: "hist_" + x, history_feature_list)) + + for fc in self.varlen_sparse_feature_columns: + feature_name = fc.name + if feature_name in self.history_fc_names: + self.history_feature_columns.append(fc) + else: + self.sparse_varlen_feature_columns.append(fc) + + att_emb_dim = self._compute_interest_dim() + + self.attention = AttentionSequencePoolingLayer(att_hidden_units=att_hidden_size, + embedding_dim=att_emb_dim, + activation=att_activation, + return_score=False, + supports_masking=False, + weight_normalization=att_weight_normalization) + + self.dnn = DNN(inputs_dim=self.compute_input_dim(dnn_feature_columns), + hidden_units=dnn_hidden_units, + activation=dnn_activation, + dropout_rate=dnn_dropout, + l2_reg=l2_reg_dnn, + use_bn=dnn_use_bn) + self.dnn_linear = nn.Linear(dnn_hidden_units[-1], 1, bias=False).to(device) + self.to(device) + + + def forward(self, X): + sparse_embedding_list, dense_value_list = self.input_from_feature_columns(X, self.dnn_feature_columns, + self.embedding_dict) + + # sequence pooling part + query_emb_list = embedding_lookup(X, self.embedding_dict, self.feature_index, self.sparse_feature_columns, + self.history_feature_list, self.history_feature_list, to_list=True) + keys_emb_list = embedding_lookup(X, self.embedding_dict, self.feature_index, self.history_feature_columns, + self.history_fc_names, self.history_fc_names, to_list=True) + dnn_input_emb_list = embedding_lookup(X, self.embedding_dict, self.feature_index, self.sparse_feature_columns, mask_feat_list=self.history_feature_list, to_list=True) + + + sequence_embed_dict = varlen_embedding_lookup(X, self.embedding_dict, self.feature_index, + self.sparse_varlen_feature_columns) + + sequence_embed_list = get_varlen_pooling_list(sequence_embed_dict, X, self.feature_index, + self.sparse_varlen_feature_columns, self.device) + + dnn_input_emb_list += sequence_embed_list + + # concatenate + query_emb = torch.cat(query_emb_list, dim=-1) # [B, 1, E] + keys_emb = torch.cat(keys_emb_list, dim=-1) # [B, T, E] + keys_length = torch.ones((query_emb.size(0), 1)).to(self.device) # [B, 1] + deep_input_emb = torch.cat(dnn_input_emb_list, dim=-1) + + hist = self.attention(query_emb, keys_emb, keys_length) # [B, 1, E] + + # deep part + deep_input_emb = torch.cat((deep_input_emb, hist), dim=-1) + deep_input_emb = deep_input_emb.view(deep_input_emb.size(0), -1) + + dnn_input = combined_dnn_input([deep_input_emb], dense_value_list) + dnn_output = self.dnn(dnn_input) + dnn_logit = self.dnn_linear(dnn_output) + + y_pred = self.out(dnn_logit) + + return y_pred + + def _compute_interest_dim(self): + interest_dim = 0 + for feat in self.sparse_feature_columns: + if feat.name in self.history_feature_list: + interest_dim += feat.embedding_dim + return interest_dim + + +if __name__ == '__main__': + pass diff --git a/docs/source/Features.md b/docs/source/Features.md index ba9782cd..a62e95a5 100644 --- a/docs/source/Features.md +++ b/docs/source/Features.md @@ -149,6 +149,35 @@ The output of Cross Net and MLP are concatenated.The concatenated vector are fee [Wang R, Fu B, Fu G, et al. Deep & cross network for ad click predictions[C]//Proceedings of the ADKDD'17. ACM, 2017: 12.](https://arxiv.org/abs/1708.05123) +### DIN (Deep Interest Network) + +DIN introduce a attention method to learn from sequence(multi-valued) feature. +Tradional method usually use sum/mean pooling on sequence feature. +DIN use a local activation unit to get the activation score between candidate item and history items. +User's interest are represented by weighted sum of user behaviors. +user's interest vector and other embedding vectors are concatenated and fed into a MLP to get the prediction. + +[**DIN Model API**](./deepctr_torch.models.din.html) + +[DIN example](https://github.com/shenweichen/DeepCTR-Torch/tree/master/examples/run_din.py) + +![DIN](../pics/DIN.png) + +[Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]//Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. ACM, 2018: 1059-1068.](https://arxiv.org/pdf/1706.06978.pdf) + +### DIEN (Deep Interest Evolution Network) + +Deep Interest Evolution Network (DIEN) uses interest extractor layer to capture temporal interests from history behavior sequence. At this layer, an auxiliary loss is proposed to supervise interest extracting at each step. As user interests are diverse, especially in the e-commerce system, interest evolving layer is proposed to capture interest evolving process that is relative to the target item. At interest evolving layer, attention mechanism is embedded into the sequential structure novelly, and the effects of relative interests are strengthened during interest evolution. + +[**DIEN Model API**](./deepctr_torch.models.dien.html) + +[DIEN example](https://github.com/shenweichen/DeepCTR-Torch/tree/master/examples/run_dien.py) + +![DIEN](../pics/DIEN.png) + +[Zhou G, Mou N, Fan Y, et al. Deep Interest Evolution Network for Click-Through Rate Prediction[J]. arXiv preprint arXiv:1809.03672, 2018.](https://arxiv.org/pdf/1809.03672.pdf) + + ### xDeepFM xDeepFM use a Compressed Interaction Network (CIN) to learn both low and high order feature interaction explicitly,and use a MLP to learn feature interaction implicitly. diff --git a/docs/source/History.md b/docs/source/History.md index a2227568..5d06a97b 100644 --- a/docs/source/History.md +++ b/docs/source/History.md @@ -1,4 +1,5 @@ # History +- 03/27/2020 : [v0.2.1](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.1) released.Add [DIN](./Features.html#din-deep-interest-network) and [DIEN](./Features.html#dien-deep-interest-evolution-network) . - 01/31/2020 : [v0.2.0](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.0) released.Refactor [feature columns](./Features.html#feature-columns).Support to use double precision in metric calculation. - 10/03/2019 : [v0.1.3](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.1.3) released.Simplify the input logic. - 09/28/2019 : [v0.1.2](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.1.2) released.Add [sequence(multi-value) input support](./Examples.html#multi-value-input-movielens). diff --git a/docs/source/conf.py b/docs/source/conf.py index 276439b6..3098ee95 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -26,7 +26,7 @@ # The short X.Y version version = '' # The full version, including alpha/beta/rc tags -release = '0.2.0' +release = '0.2.1' # -- General configuration --------------------------------------------------- diff --git a/docs/source/deepctr_torch.models.dien.rst b/docs/source/deepctr_torch.models.dien.rst new file mode 100644 index 00000000..42e4785f --- /dev/null +++ b/docs/source/deepctr_torch.models.dien.rst @@ -0,0 +1,7 @@ +deepctr\_torch.models.dien module +================================ + +.. automodule:: deepctr_torch.models.dien + :members: + :no-undoc-members: + :no-show-inheritance: diff --git a/docs/source/deepctr_torch.models.din.rst b/docs/source/deepctr_torch.models.din.rst new file mode 100644 index 00000000..13d7a790 --- /dev/null +++ b/docs/source/deepctr_torch.models.din.rst @@ -0,0 +1,7 @@ +deepctr\_torch.models.din module +================================ + +.. automodule:: deepctr_torch.models.din + :members: + :no-undoc-members: + :no-show-inheritance: diff --git a/docs/source/deepctr_torch.models.rst b/docs/source/deepctr_torch.models.rst index 39054d5d..25041a55 100644 --- a/docs/source/deepctr_torch.models.rst +++ b/docs/source/deepctr_torch.models.rst @@ -18,6 +18,8 @@ Submodules deepctr_torch.models.pnn deepctr_torch.models.wdl deepctr_torch.models.xdeepfm + deepctr_torch.models.din + deepctr_torch.models.dien Module contents --------------- diff --git a/docs/source/index.rst b/docs/source/index.rst index c1b75460..52b6aac0 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -34,12 +34,12 @@ You can read the latest code at https://github.com/shenweichen/DeepCTR-Torch and News ----- +03/27/2020 : Add `DIN <./Features.html#din-deep-interest-network>`_ and `DIEN <./Features.html#dien-deep-interest-evolution-network>`_ . `Changelog `_ + 01/31/2020 : Refactor `feature columns <./Features.html#feature-columns>`_ . Support double precision in metric calculation . `Changelog `_ 10/03/2019 : Simplify the input logic(`examples <./Examples.html#classification-criteo>`_). `Changelog `_ -09/28/2019 : Add `sequence(multi-value) input support <./Examples.html#multi-value-input-movielens>`_ . `Changelog `_ - DisscussionGroup ----------------------- diff --git a/examples/run_dien.py b/examples/run_dien.py new file mode 100644 index 00000000..51564351 --- /dev/null +++ b/examples/run_dien.py @@ -0,0 +1,68 @@ +import numpy as np +import torch + +from deepctr_torch.inputs import SparseFeat, DenseFeat, VarLenSparseFeat, get_feature_names +from deepctr_torch.models import DIEN + + +def get_xy_fd(use_neg=False, hash_flag=False): + feature_columns = [SparseFeat('user', 4, embedding_dim=4, use_hash=hash_flag), + SparseFeat('gender', 2, embedding_dim=4, use_hash=hash_flag), + SparseFeat('item_id', 3 + 1, embedding_dim=8, use_hash=hash_flag), + SparseFeat('cate_id', 2 + 1, embedding_dim=4, use_hash=hash_flag), + DenseFeat('pay_score', 1)] + + feature_columns += [ + VarLenSparseFeat(SparseFeat('hist_item_id', vocabulary_size=3 + 1, embedding_dim=8, embedding_name='item_id'), + maxlen=4, length_name="seq_length"), + VarLenSparseFeat(SparseFeat('hist_cate_id', vocabulary_size=2 + 1, embedding_dim=4, embedding_name='cate_id'), + maxlen=4, + length_name="seq_length")] + + behavior_feature_list = ["item_id", "cate_id"] + uid = np.array([0, 1, 2, 3]) + gender = np.array([0, 1, 0, 1]) + item_id = np.array([1, 2, 3, 2]) # 0 is mask value + cate_id = np.array([1, 2, 1, 2]) # 0 is mask value + score = np.array([0.1, 0.2, 0.3, 0.2]) + + hist_item_id = np.array([[1, 2, 3, 0], [1, 2, 3, 0], [1, 2, 0, 0], [1, 2, 0, 0]]) + hist_cate_id = np.array([[1, 1, 2, 0], [2, 1, 1, 0], [2, 1, 0, 0], [1, 2, 0, 0]]) + + behavior_length = np.array([3, 3, 2, 2]) + + feature_dict = {'user': uid, 'gender': gender, 'item_id': item_id, 'cate_id': cate_id, + 'hist_item_id': hist_item_id, 'hist_cate_id': hist_cate_id, + 'pay_score': score, "seq_length": behavior_length} + + if use_neg: + feature_dict['neg_hist_item_id'] = np.array([[1, 2, 3, 0], [1, 2, 3, 0], [1, 2, 0, 0], [1, 2, 0, 0]]) + feature_dict['neg_hist_cate_id'] = np.array([[1, 1, 2, 0], [2, 1, 1, 0], [2, 1, 0, 0], [1, 2, 0, 0]]) + feature_columns += [ + VarLenSparseFeat( + SparseFeat('neg_hist_item_id', vocabulary_size=3 + 1, embedding_dim=8, embedding_name='item_id'), + maxlen=4, length_name="seq_length"), + VarLenSparseFeat( + SparseFeat('neg_hist_cate_id', vocabulary_size=2 + 1, embedding_dim=4, embedding_name='cate_id'), + maxlen=4, length_name="seq_length")] + + x = {name: feature_dict[name] for name in get_feature_names(feature_columns)} + y = np.array([1, 0, 1, 0]) + return x, y, feature_columns, behavior_feature_list + + +if __name__ == "__main__": + x, y, feature_columns, behavior_feature_list = get_xy_fd(use_neg=True) + + device = 'cpu' + use_cuda = True + if use_cuda and torch.cuda.is_available(): + print('cuda ready...') + device = 'cuda:0' + + model = DIEN(feature_columns, behavior_feature_list, + dnn_hidden_units=[4, 4, 4], dnn_dropout=0.6, gru_type="AUGRU", use_negsampling=True, device=device) + + model.compile('adam', 'binary_crossentropy', + metrics=['binary_crossentropy', 'auc']) + history = model.fit(x, y, batch_size=2, verbose=1, epochs=10, validation_split=0, shuffle=False) diff --git a/examples/run_din.py b/examples/run_din.py new file mode 100644 index 00000000..00238790 --- /dev/null +++ b/examples/run_din.py @@ -0,0 +1,49 @@ +import sys + +sys.path.insert(0, '..') + +import numpy as np +import torch +from deepctr_torch.inputs import (DenseFeat, SparseFeat, VarLenSparseFeat, + get_feature_names) +from deepctr_torch.models.din import DIN + + +def get_xy_fd(): + feature_columns = [SparseFeat('user', 3, embedding_dim=8), SparseFeat('gender', 2, embedding_dim=8), + SparseFeat('item', 3 + 1, embedding_dim=8), SparseFeat('item_gender', 2 + 1, embedding_dim=8), + DenseFeat('score', 1)] + + feature_columns += [VarLenSparseFeat(SparseFeat('hist_item', 3 + 1, embedding_dim=8), 4), + VarLenSparseFeat(SparseFeat('hist_item_gender', 2 + 1, embedding_dim=8), 4)] + + behavior_feature_list = ["item", "item_gender"] + uid = np.array([0, 1, 2]) + ugender = np.array([0, 1, 0]) + iid = np.array([1, 2, 3]) # 0 is mask value + igender = np.array([1, 2, 1]) # 0 is mask value + score = np.array([0.1, 0.2, 0.3]) + + hist_iid = np.array([[1, 2, 3, 0], [1, 2, 3, 0], [1, 2, 0, 0]]) + hist_igender = np.array([[1, 1, 2, 0], [2, 1, 1, 0], [2, 1, 0, 0]]) + + feature_dict = {'user': uid, 'gender': ugender, 'item': iid, 'item_gender': igender, + 'hist_item': hist_iid, 'hist_item_gender': hist_igender, 'score': score} + x = {name: feature_dict[name] for name in get_feature_names(feature_columns)} + y = np.array([1, 0, 1]) + + return x, y, feature_columns, behavior_feature_list + + +if __name__ == "__main__": + x, y, feature_columns, behavior_feature_list = get_xy_fd() + device = 'cpu' + use_cuda = True + if use_cuda and torch.cuda.is_available(): + print('cuda ready...') + device = 'cuda:0' + + model = DIN(feature_columns, behavior_feature_list, device=device) + model.compile('adagrad', 'binary_crossentropy', + metrics=['binary_crossentropy']) + history = model.fit(x, y, batch_size=3, epochs=10, validation_split=0.0, verbose=2) diff --git a/setup.py b/setup.py index 377a6f78..a57727e4 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ setuptools.setup( name="deepctr-torch", - version="0.2.0", + version="0.2.1", author="Weichen Shen", author_email="wcshen1994@163.com", description="Easy-to-use,Modular and Extendible package of deep learning based CTR(Click Through Rate) prediction models with PyTorch", diff --git a/tests/layers/activation_test.py b/tests/layers/activation_test.py new file mode 100644 index 00000000..33707229 --- /dev/null +++ b/tests/layers/activation_test.py @@ -0,0 +1,11 @@ +# -*- coding: utf-8 -*- +from deepctr_torch.layers import activation +from tests.utils import layer_test + + +def test_dice(): + layer_test(activation.Dice, kwargs={'emb_size': 3, 'dim': 2}, + input_shape=(5, 3), expected_output_shape=(5,3)) + layer_test(activation.Dice, kwargs={'emb_size': 10, 'dim': 3}, + input_shape=(5, 3, 10), expected_output_shape=(5,3,10)) + diff --git a/tests/models/DIEN_test.py b/tests/models/DIEN_test.py new file mode 100644 index 00000000..a2cdb6ae --- /dev/null +++ b/tests/models/DIEN_test.py @@ -0,0 +1,100 @@ +import numpy as np +import pytest +import torch + +from deepctr_torch.inputs import SparseFeat, DenseFeat, VarLenSparseFeat, get_feature_names +from deepctr_torch.models.dien import InterestEvolving, DIEN +from ..utils import check_model, get_device + + +@pytest.mark.parametrize( + 'gru_type', + ["AIGRU", "AUGRU", "AGRU", "GRU"] +) +def test_InterestEvolving(gru_type): + interest_evolution = InterestEvolving( + input_size=3, + gru_type=gru_type, + use_neg=False) + + query = torch.tensor([[1, 1, 1], [0.1, 0.2, 0.3]], dtype=torch.float) + + keys = torch.tensor([ + [[0.1, 0.2, 0.3], [1, 2, 3], [0.4, 0.2, 1], [0.0, 0.0, 0.0]], + [[0.1, 0.2, 0.3], [1, 2, 3], [0.4, 0.2, 1], [0.5, 0.5, 0.5]] + ], dtype=torch.float) + + keys_length = torch.tensor([3, 4]) + + output = interest_evolution(query, keys, keys_length) + + assert output.size()[0] == 2 + assert output.size()[1] == 3 + + +def get_xy_fd(use_neg=False, hash_flag=False): + feature_columns = [SparseFeat('user', 4, embedding_dim=4, use_hash=hash_flag), + SparseFeat('gender', 2, embedding_dim=4, use_hash=hash_flag), + SparseFeat('item_id', 3 + 1, embedding_dim=8, use_hash=hash_flag), + SparseFeat('cate_id', 2 + 1, embedding_dim=4, use_hash=hash_flag), + DenseFeat('pay_score', 1)] + + feature_columns += [ + VarLenSparseFeat(SparseFeat('hist_item_id', vocabulary_size=3 + 1, embedding_dim=8, embedding_name='item_id'), + maxlen=4, length_name="seq_length"), + VarLenSparseFeat(SparseFeat('hist_cate_id', vocabulary_size=2 + 1, embedding_dim=4, embedding_name='cate_id'), + maxlen=4, + length_name="seq_length")] + + behavior_feature_list = ["item_id", "cate_id"] + uid = np.array([0, 1, 2, 3]) + gender = np.array([0, 1, 0, 1]) + item_id = np.array([1, 2, 3, 2]) # 0 is mask value + cate_id = np.array([1, 2, 1, 2]) # 0 is mask value + score = np.array([0.1, 0.2, 0.3, 0.2]) + + hist_item_id = np.array([[1, 2, 3, 0], [1, 2, 3, 0], [1, 2, 0, 0], [1, 2, 0, 0]]) + hist_cate_id = np.array([[1, 1, 2, 0], [2, 1, 1, 0], [2, 1, 0, 0], [1, 2, 0, 0]]) + + behavior_length = np.array([3, 3, 2, 2]) + + feature_dict = {'user': uid, 'gender': gender, 'item_id': item_id, 'cate_id': cate_id, + 'hist_item_id': hist_item_id, 'hist_cate_id': hist_cate_id, + 'pay_score': score, "seq_length": behavior_length} + + if use_neg: + feature_dict['neg_hist_item_id'] = np.array([[1, 2, 3, 0], [1, 2, 3, 0], [1, 2, 0, 0], [1, 2, 0, 0]]) + feature_dict['neg_hist_cate_id'] = np.array([[1, 1, 2, 0], [2, 1, 1, 0], [2, 1, 0, 0], [1, 2, 0, 0]]) + feature_columns += [ + VarLenSparseFeat( + SparseFeat('neg_hist_item_id', vocabulary_size=3 + 1, embedding_dim=8, embedding_name='item_id'), + maxlen=4, length_name="seq_length"), + VarLenSparseFeat( + SparseFeat('neg_hist_cate_id', vocabulary_size=2 + 1, embedding_dim=4, embedding_name='cate_id'), + maxlen=4, length_name="seq_length")] + + x = {name: feature_dict[name] for name in get_feature_names(feature_columns)} + y = np.array([1, 0, 1, 0]) + return x, y, feature_columns, behavior_feature_list + + +@pytest.mark.parametrize( + 'gru_type,use_neg', + [("AIGRU", True), ("AIGRU", False), + ("AUGRU", True), ("AUGRU", False), + ("AGRU", True), ("AGRU", False), + ("GRU", True), ("GRU", False)] +) +def test_DIEN(gru_type, use_neg): + model_name = "DIEN_" + gru_type + + x, y, feature_columns, behavior_feature_list = get_xy_fd(use_neg=use_neg) + + model = DIEN(feature_columns, behavior_feature_list, + dnn_hidden_units=[4, 4, 4], dnn_dropout=0.5, gru_type=gru_type, device=get_device()) + + check_model(model, model_name, x, y) + + +if __name__ == "__main__": + pass diff --git a/tests/models/DIN_test.py b/tests/models/DIN_test.py new file mode 100644 index 00000000..a9b7d418 --- /dev/null +++ b/tests/models/DIN_test.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +import numpy as np + +from deepctr_torch.inputs import SparseFeat, VarLenSparseFeat, DenseFeat, get_feature_names +from deepctr_torch.models.din import DIN +from ..utils import check_model, get_device + + +def get_xy_fd(hash_flag=False): + feature_columns = [SparseFeat('user', 4, embedding_dim=4, use_hash=hash_flag), + SparseFeat('gender', 2, embedding_dim=4, use_hash=hash_flag), + SparseFeat('item_id', 3 + 1, embedding_dim=8, use_hash=hash_flag), + SparseFeat('cate_id', 2 + 1, embedding_dim=4, use_hash=hash_flag), + DenseFeat('pay_score', 1)] + + feature_columns += [ + VarLenSparseFeat(SparseFeat('hist_item_id', vocabulary_size=3 + 1, embedding_dim=8, embedding_name='item_id'), + maxlen=4, length_name="seq_length"), + VarLenSparseFeat(SparseFeat('hist_cate_id', vocabulary_size=2 + 1, embedding_dim=4, embedding_name='cate_id'), + maxlen=4, + length_name="seq_length")] + + behavior_feature_list = ["item_id", "cate_id"] + uid = np.array([0, 1, 2, 3]) + gender = np.array([0, 1, 0, 1]) + item_id = np.array([1, 2, 3, 2]) # 0 is mask value + cate_id = np.array([1, 2, 1, 2]) # 0 is mask value + score = np.array([0.1, 0.2, 0.3, 0.2]) + + hist_item_id = np.array([[1, 2, 3, 0], [1, 2, 3, 0], [1, 2, 0, 0], [1, 2, 0, 0]]) + hist_cate_id = np.array([[1, 1, 2, 0], [2, 1, 1, 0], [2, 1, 0, 0], [1, 2, 0, 0]]) + + behavior_length = np.array([3, 3, 2, 2]) + + feature_dict = {'user': uid, 'gender': gender, 'item_id': item_id, 'cate_id': cate_id, + 'hist_item_id': hist_item_id, 'hist_cate_id': hist_cate_id, + 'pay_score': score, "seq_length": behavior_length} + + x = {name: feature_dict[name] for name in get_feature_names(feature_columns)} + y = np.array([1, 0, 1, 0]) + return x, y, feature_columns, behavior_feature_list + + +def test_DIN(): + model_name = "DIN" + + x, y, feature_columns, behavior_feature_list = get_xy_fd() + model = DIN(feature_columns, behavior_feature_list, dnn_dropout=0.5, device=get_device()) + + check_model(model, model_name, x, y) # only have 3 train data so we set validation ratio at 0 + + +if __name__ == "__main__": + pass diff --git a/tests/utils.py b/tests/utils.py index d60c0ef3..0fd33b29 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -66,6 +66,78 @@ def get_test_data(sample_size=1000, embedding_size=4, sparse_feature_num=1, dens return model_input, y, feature_columns +def layer_test(layer_cls, kwargs = {}, input_shape=None, + input_dtype=torch.float32, input_data=None, expected_output=None, + expected_output_shape=None, expected_output_dtype=None, fixed_batch_size=False): + '''check layer is valid or not + + :param layer_cls: + :param input_shape: + :param input_dtype: + :param input_data: + :param expected_output: + :param expected_output_dtype: + :param fixed_batch_size: + + :return: output of the layer + ''' + if input_data is None: + # generate input data + if not input_shape: + raise ValueError("input shape should not be none") + + input_data_shape = list(input_shape) + for i, e in enumerate(input_data_shape): + if e is None: + input_data_shape[i] = np.random.randint(1, 4) + + if all(isinstance(e, tuple) for e in input_data_shape): + input_data = [] + for e in input_data_shape: + rand_input = (10 * np.random.random(e)) + input_data.append(rand_input) + else: + rand_input = 10 * np.random.random(input_data_shape) + input_data = rand_input + + else: + # use input_data to update other parameters + if input_shape is None: + input_shape = input_data.shape + + if expected_output_dtype is None: + expected_output_dtype = input_dtype + + # layer initialization + layer = layer_cls(**kwargs) + + if fixed_batch_size: + input = torch.tensor(input_data.unsqueeze(0), dtype=input_dtype) + else: + input = torch.tensor(input_data, dtype=input_dtype) + + # calculate layer's output + output = layer(input) + + if not output.dtype == expected_output_dtype: + raise AssertionError("layer output dtype does not match with the expected one") + + if not expected_output_shape: + raise ValueError("expected output shape should not be none") + + actual_output_shape = output.shape + for expected_dim, actual_dim in zip(expected_output_shape, actual_output_shape): + if expected_dim is not None: + if not expected_dim == actual_dim: + raise AssertionError(f"expected_dim:{expected_dim}, actual_dim:{actual_dim}") + + if expected_output is not None: + # check whether output equals to expected output + assert_allclose(output, expected_output, rtol=1e-3) + + return output + + def check_model(model, model_name, x, y, check_model_io=True): ''' compile model,train and evaluate it,then save/load weight and model file.