- 公众号:浅梦的学习笔记
+ 公众号:浅梦学习笔记
@@ -74,7 +76,7 @@ Let's [**Get Started!**](https://deepctr-torch.readthedocs.io/en/latest/Quick-St
-## Contributors([welcome to join us!](./CONTRIBUTING.md))
+## Main Contributors([welcome to join us!](./CONTRIBUTING.md))
@@ -125,18 +127,18 @@ Let's [**Get Started!**](https://deepctr-torch.readthedocs.io/en/latest/Quick-St
Dev
NetEase
+
+
+ Cheng Weiyu
+ Dev
+ Shanghai Jiao Tong University
+ |
Tang
Test
Tongji University
|
-
-
- Xu Qidi
- Dev
- University of Electronic Science and Technology of China
- |
\ No newline at end of file
diff --git a/deepctr_torch/__init__.py b/deepctr_torch/__init__.py
index b780468d..88508515 100644
--- a/deepctr_torch/__init__.py
+++ b/deepctr_torch/__init__.py
@@ -2,5 +2,5 @@
from . import models
from .utils import check_version
-__version__ = '0.2.6'
+__version__ = '0.2.7'
check_version(__version__)
\ No newline at end of file
diff --git a/deepctr_torch/inputs.py b/deepctr_torch/inputs.py
index 324056a1..1b5f4193 100644
--- a/deepctr_torch/inputs.py
+++ b/deepctr_torch/inputs.py
@@ -57,6 +57,10 @@ def vocabulary_size(self):
def embedding_dim(self):
return self.sparsefeat.embedding_dim
+ @property
+ def use_hash(self):
+ return self.sparsefeat.use_hash
+
@property
def dtype(self):
return self.sparsefeat.dtype
@@ -136,18 +140,15 @@ def combined_dnn_input(sparse_embedding_list, dense_value_list):
def get_varlen_pooling_list(embedding_dict, features, feature_index, varlen_sparse_feature_columns, device):
varlen_sparse_embedding_list = []
-
for feat in varlen_sparse_feature_columns:
- seq_emb = embedding_dict[feat.embedding_name](
- features[:, feature_index[feat.name][0]:feature_index[feat.name][1]].long())
+ seq_emb = embedding_dict[feat.name]
if feat.length_name is None:
seq_mask = features[:, feature_index[feat.name][0]:feature_index[feat.name][1]].long() != 0
emb = SequencePoolingLayer(mode=feat.combiner, supports_masking=True, device=device)(
[seq_emb, seq_mask])
else:
- seq_length = features[:,
- feature_index[feat.length_name][0]:feature_index[feat.length_name][1]].long()
+ seq_length = features[:, feature_index[feat.length_name][0]:feature_index[feat.length_name][1]].long()
emb = SequencePoolingLayer(mode=feat.combiner, supports_masking=False, device=device)(
[seq_emb, seq_length])
varlen_sparse_embedding_list.append(emb)
@@ -179,33 +180,6 @@ def create_embedding_matrix(feature_columns, init_std=0.0001, linear=False, spar
return embedding_dict.to(device)
-def input_from_feature_columns(self, X, feature_columns, embedding_dict, support_dense=True):
- sparse_feature_columns = list(
- filter(lambda x: isinstance(x, SparseFeat), feature_columns)) if len(feature_columns) else []
- dense_feature_columns = list(
- filter(lambda x: isinstance(x, DenseFeat), feature_columns)) if len(feature_columns) else []
-
- varlen_sparse_feature_columns = list(
- filter(lambda x: isinstance(x, VarLenSparseFeat), feature_columns)) if feature_columns else []
-
- if not support_dense and len(dense_feature_columns) > 0:
- raise ValueError(
- "DenseFeat is not supported in dnn_feature_columns")
-
- sparse_embedding_list = [embedding_dict[feat.embedding_name](
- X[:, self.feature_index[feat.name][0]:self.feature_index[feat.name][1]].long()) for
- feat in sparse_feature_columns]
-
- varlen_sparse_embedding_list = get_varlen_pooling_list(self.embedding_dict, X, self.feature_index,
- varlen_sparse_feature_columns, self.device)
-
- dense_value_list = [X[:, self.feature_index[feat.name][0]:self.feature_index[feat.name][1]] for feat in
- dense_feature_columns]
-
- return sparse_embedding_list + varlen_sparse_embedding_list, dense_value_list
-
-
-
def embedding_lookup(X, sparse_embedding_dict, sparse_input_dict, sparse_feature_columns, return_feat_list=(),
mask_feat_list=(), to_list=False):
"""
diff --git a/deepctr_torch/layers/activation.py b/deepctr_torch/layers/activation.py
index 01624a05..44bd308c 100644
--- a/deepctr_torch/layers/activation.py
+++ b/deepctr_torch/layers/activation.py
@@ -17,6 +17,7 @@ class Dice(nn.Module):
- [Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]//Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. ACM, 2018: 1059-1068.](https://arxiv.org/pdf/1706.06978.pdf)
- https://github.com/zhougr1993/DeepInterestNetwork, https://github.com/fanoping/DIN-pytorch
"""
+
def __init__(self, emb_size, dim=2, epsilon=1e-8, device='cpu'):
super(Dice, self).__init__()
assert dim == 2 or dim == 3
@@ -41,18 +42,16 @@ def forward(self, x):
x_p = self.sigmoid(self.bn(x))
out = self.alpha * (1 - x_p) * x + x_p * x
out = torch.transpose(out, 1, 2)
-
return out
class Identity(nn.Module):
-
def __init__(self, **kwargs):
super(Identity, self).__init__()
- def forward(self, X):
- return X
+ def forward(self, inputs):
+ return inputs
def activation_layer(act_name, hidden_size=None, dice_dim=2):
diff --git a/deepctr_torch/layers/interaction.py b/deepctr_torch/layers/interaction.py
index edbfa88b..d02c2b41 100644
--- a/deepctr_torch/layers/interaction.py
+++ b/deepctr_torch/layers/interaction.py
@@ -130,7 +130,7 @@ def __init__(self, filed_size, embedding_size, bilinear_type="interaction", seed
self.bilinear.append(
nn.Linear(embedding_size, embedding_size, bias=False))
elif self.bilinear_type == "interaction":
- for i, j in itertools.combinations(range(filed_size), 2):
+ for _, _ in itertools.combinations(range(filed_size), 2):
self.bilinear.append(
nn.Linear(embedding_size, embedding_size, bias=False))
else:
@@ -330,41 +330,34 @@ class InteractingLayer(nn.Module):
Input shape
- A 3D tensor with shape: ``(batch_size,field_size,embedding_size)``.
Output shape
- - 3D tensor with shape:``(batch_size,field_size,att_embedding_size * head_num)``.
+ - 3D tensor with shape:``(batch_size,field_size,embedding_size)``.
Arguments
- **in_features** : Positive integer, dimensionality of input features.
- - **att_embedding_size**: int.The embedding size in multi-head self-attention network.
- - **head_num**: int.The head number in multi-head self-attention network.
+ - **head_num**: int.The head number in multi-head self-attention network.
- **use_res**: bool.Whether or not use standard residual connections before output.
- **seed**: A Python integer to use as random seed.
References
- [Song W, Shi C, Xiao Z, et al. AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks[J]. arXiv preprint arXiv:1810.11921, 2018.](https://arxiv.org/abs/1810.11921)
"""
- def __init__(self, in_features, att_embedding_size=8, head_num=2, use_res=True, scaling=False, seed=1024, device='cpu'):
+ def __init__(self, embedding_size, head_num=2, use_res=True, scaling=False, seed=1024, device='cpu'):
super(InteractingLayer, self).__init__()
if head_num <= 0:
raise ValueError('head_num must be a int > 0')
- self.att_embedding_size = att_embedding_size
+ if embedding_size % head_num != 0:
+ raise ValueError('embedding_size is not an integer multiple of head_num!')
+ self.att_embedding_size = embedding_size // head_num
self.head_num = head_num
self.use_res = use_res
self.scaling = scaling
self.seed = seed
- embedding_size = in_features
-
- self.W_Query = nn.Parameter(torch.Tensor(
- embedding_size, self.att_embedding_size * self.head_num))
-
- self.W_key = nn.Parameter(torch.Tensor(
- embedding_size, self.att_embedding_size * self.head_num))
-
- self.W_Value = nn.Parameter(torch.Tensor(
- embedding_size, self.att_embedding_size * self.head_num))
+ self.W_Query = nn.Parameter(torch.Tensor(embedding_size, embedding_size))
+ self.W_key = nn.Parameter(torch.Tensor(embedding_size, embedding_size))
+ self.W_Value = nn.Parameter(torch.Tensor(embedding_size, embedding_size))
if self.use_res:
- self.W_Res = nn.Parameter(torch.Tensor(
- embedding_size, self.att_embedding_size * self.head_num))
+ self.W_Res = nn.Parameter(torch.Tensor(embedding_size, embedding_size))
for tensor in self.parameters():
nn.init.normal_(tensor, mean=0.0, std=0.05)
@@ -376,29 +369,24 @@ def forward(self, inputs):
raise ValueError(
"Unexpected inputs dimensions %d, expect to be 3 dimensions" % (len(inputs.shape)))
- querys = torch.tensordot(inputs, self.W_Query,
- dims=([-1], [0])) # None F D*head_num
+ # None F D
+ querys = torch.tensordot(inputs, self.W_Query, dims=([-1], [0]))
keys = torch.tensordot(inputs, self.W_key, dims=([-1], [0]))
values = torch.tensordot(inputs, self.W_Value, dims=([-1], [0]))
- # head_num None F D
-
- querys = torch.stack(torch.split(
- querys, self.att_embedding_size, dim=2))
+ # head_num None F D/head_num
+ querys = torch.stack(torch.split(querys, self.att_embedding_size, dim=2))
keys = torch.stack(torch.split(keys, self.att_embedding_size, dim=2))
- values = torch.stack(torch.split(
- values, self.att_embedding_size, dim=2))
- inner_product = torch.einsum(
- 'bnik,bnjk->bnij', querys, keys) # head_num None F F
+ values = torch.stack(torch.split(values, self.att_embedding_size, dim=2))
+
+ inner_product = torch.einsum('bnik,bnjk->bnij', querys, keys) # head_num None F F
if self.scaling:
inner_product /= self.att_embedding_size ** 0.5
- self.normalized_att_scores = F.softmax(
- inner_product, dim=-1) # head_num None F F
- result = torch.matmul(self.normalized_att_scores,
- values) # head_num None F D
+ self.normalized_att_scores = F.softmax(inner_product, dim=-1) # head_num None F F
+ result = torch.matmul(self.normalized_att_scores, values) # head_num None F D/head_num
result = torch.cat(torch.split(result, 1, ), dim=-1)
- result = torch.squeeze(result, dim=0) # None F D*head_num
+ result = torch.squeeze(result, dim=0) # None F D
if self.use_res:
result += torch.tensordot(inputs, self.W_Res, dims=([-1], [0]))
result = F.relu(result)
@@ -499,9 +487,9 @@ def __init__(self, in_features, low_rank=32, num_experts=4, layer_num=2, device=
self.bias = nn.Parameter(torch.Tensor(self.layer_num, in_features, 1))
init_para_list = [self.U_list, self.V_list, self.C_list]
- for i in range(len(init_para_list)):
- for j in range(self.layer_num):
- nn.init.xavier_normal_(init_para_list[i][j])
+ for para in init_para_list:
+ for i in range(self.layer_num):
+ nn.init.xavier_normal_(para[i])
for i in range(len(self.bias)):
nn.init.zeros_(self.bias[i])
@@ -727,3 +715,43 @@ def __init__(self, field_size, conv_kernel_width, conv_filters, device='cpu'):
def forward(self, inputs):
return self.conv_layer(inputs)
+
+
+class LogTransformLayer(nn.Module):
+ """Logarithmic Transformation Layer in Adaptive factorization network, which models arbitrary-order cross features.
+
+ Input shape
+ - 3D tensor with shape: ``(batch_size, field_size, embedding_size)``.
+ Output shape
+ - 2D tensor with shape: ``(batch_size, ltl_hidden_size*embedding_size)``.
+ Arguments
+ - **field_size** : positive integer, number of feature groups
+ - **embedding_size** : positive integer, embedding size of sparse features
+ - **ltl_hidden_size** : integer, the number of logarithmic neurons in AFN
+ References
+ - Cheng, W., Shen, Y. and Huang, L. 2020. Adaptive Factorization Network: Learning Adaptive-Order Feature
+ Interactions. Proceedings of the AAAI Conference on Artificial Intelligence. 34, 04 (Apr. 2020), 3609-3616.
+ """
+
+ def __init__(self, field_size, embedding_size, ltl_hidden_size):
+ super(LogTransformLayer, self).__init__()
+
+ self.ltl_weights = nn.Parameter(torch.Tensor(field_size, ltl_hidden_size))
+ self.ltl_biases = nn.Parameter(torch.Tensor(1, 1, ltl_hidden_size))
+ self.bn = nn.ModuleList([nn.BatchNorm1d(embedding_size) for i in range(2)])
+ nn.init.normal_(self.ltl_weights, mean=0.0, std=0.1)
+ nn.init.zeros_(self.ltl_biases, )
+
+ def forward(self, inputs):
+ # Avoid numeric overflow
+ afn_input = torch.clamp(torch.abs(inputs), min=1e-7, max=float("Inf"))
+ # Transpose to shape: ``(batch_size,embedding_size,field_size)``
+ afn_input_trans = torch.transpose(afn_input, 1, 2)
+ # Logarithmic transformation layer
+ ltl_result = torch.log(afn_input_trans)
+ ltl_result = self.bn[0](ltl_result)
+ ltl_result = torch.matmul(ltl_result, self.ltl_weights) + self.ltl_biases
+ ltl_result = torch.exp(ltl_result)
+ ltl_result = self.bn[1](ltl_result)
+ ltl_result = torch.flatten(ltl_result, start_dim=1)
+ return ltl_result
diff --git a/deepctr_torch/layers/sequence.py b/deepctr_torch/layers/sequence.py
index 550e5878..b0026ff7 100644
--- a/deepctr_torch/layers/sequence.py
+++ b/deepctr_torch/layers/sequence.py
@@ -117,7 +117,7 @@ def forward(self, query, keys, keys_length, mask=None):
Output shape
- 3D tensor with shape: ``(batch_size, 1, embedding_size)``.
"""
- batch_size, max_length, dim = keys.size()
+ batch_size, max_length, _ = keys.size()
# Mask
if self.supports_masking:
@@ -176,16 +176,16 @@ def __init__(self, k, axis, device='cpu'):
self.axis = axis
self.to(device)
- def forward(self, input):
- if self.axis < 0 or self.axis >= len(input.shape):
+ def forward(self, inputs):
+ if self.axis < 0 or self.axis >= len(inputs.shape):
raise ValueError("axis must be 0~%d,now is %d" %
- (len(input.shape) - 1, self.axis))
+ (len(inputs.shape) - 1, self.axis))
- if self.k < 1 or self.k > input.shape[self.axis]:
+ if self.k < 1 or self.k > inputs.shape[self.axis]:
raise ValueError("k must be in 1 ~ %d,now k is %d" %
- (input.shape[self.axis], self.k))
+ (inputs.shape[self.axis], self.k))
- out = torch.topk(input, k=self.k, dim=self.axis, sorted=True)[0]
+ out = torch.topk(inputs, k=self.k, dim=self.axis, sorted=True)[0]
return out
@@ -220,11 +220,11 @@ def __init__(self, input_size, hidden_size, bias=True):
self.register_parameter('bias_ih', None)
self.register_parameter('bias_hh', None)
- def forward(self, input, hx, att_score):
- gi = F.linear(input, self.weight_ih, self.bias_ih)
+ def forward(self, inputs, hx, att_score):
+ gi = F.linear(inputs, self.weight_ih, self.bias_ih)
gh = F.linear(hx, self.weight_hh, self.bias_hh)
- i_r, i_z, i_n = gi.chunk(3, 1)
- h_r, h_z, h_n = gh.chunk(3, 1)
+ i_r, _, i_n = gi.chunk(3, 1)
+ h_r, _, h_n = gh.chunk(3, 1)
reset_gate = torch.sigmoid(i_r + h_r)
# update_gate = torch.sigmoid(i_z + h_z)
@@ -266,8 +266,8 @@ def __init__(self, input_size, hidden_size, bias=True):
self.register_parameter('bias_ih', None)
self.register_parameter('bias_hh', None)
- def forward(self, input, hx, att_score):
- gi = F.linear(input, self.weight_ih, self.bias_ih)
+ def forward(self, inputs, hx, att_score):
+ gi = F.linear(inputs, self.weight_ih, self.bias_ih)
gh = F.linear(hx, self.weight_hh, self.bias_hh)
i_r, i_z, i_n = gi.chunk(3, 1)
h_r, h_z, h_n = gh.chunk(3, 1)
@@ -293,25 +293,25 @@ def __init__(self, input_size, hidden_size, bias=True, gru_type='AGRU'):
elif gru_type == 'AUGRU':
self.rnn = AUGRUCell(input_size, hidden_size, bias)
- def forward(self, input, att_scores=None, hx=None):
- if not isinstance(input, PackedSequence) or not isinstance(att_scores, PackedSequence):
+ def forward(self, inputs, att_scores=None, hx=None):
+ if not isinstance(inputs, PackedSequence) or not isinstance(att_scores, PackedSequence):
raise NotImplementedError("DynamicGRU only supports packed input and att_scores")
- input, batch_sizes, sorted_indices, unsorted_indices = input
+ inputs, batch_sizes, sorted_indices, unsorted_indices = inputs
att_scores, _, _, _ = att_scores
max_batch_size = int(batch_sizes[0])
if hx is None:
hx = torch.zeros(max_batch_size, self.hidden_size,
- dtype=input.dtype, device=input.device)
+ dtype=inputs.dtype, device=inputs.device)
- outputs = torch.zeros(input.size(0), self.hidden_size,
- dtype=input.dtype, device=input.device)
+ outputs = torch.zeros(inputs.size(0), self.hidden_size,
+ dtype=inputs.dtype, device=inputs.device)
begin = 0
for batch in batch_sizes:
new_hx = self.rnn(
- input[begin:begin + batch],
+ inputs[begin:begin + batch],
hx[0:batch],
att_scores[begin:begin + batch])
outputs[begin:begin + batch] = new_hx
diff --git a/deepctr_torch/models/__init__.py b/deepctr_torch/models/__init__.py
index 43381369..e72de07a 100644
--- a/deepctr_torch/models/__init__.py
+++ b/deepctr_torch/models/__init__.py
@@ -14,4 +14,5 @@
from .pnn import PNN
from .ccpm import CCPM
from .dien import DIEN
-from .din import DIN
\ No newline at end of file
+from .din import DIN
+from .afn import AFN
\ No newline at end of file
diff --git a/deepctr_torch/models/afn.py b/deepctr_torch/models/afn.py
new file mode 100644
index 00000000..61aa60da
--- /dev/null
+++ b/deepctr_torch/models/afn.py
@@ -0,0 +1,74 @@
+# -*- coding:utf-8 -*-
+"""
+Author:
+ Weiyu Cheng, weiyu_cheng@sjtu.edu.cn
+
+Reference:
+ [1] Cheng, W., Shen, Y. and Huang, L. 2020. Adaptive Factorization Network: Learning Adaptive-Order Feature
+ Interactions. Proceedings of the AAAI Conference on Artificial Intelligence. 34, 04 (Apr. 2020), 3609-3616.
+"""
+import torch
+import torch.nn as nn
+
+from .basemodel import BaseModel
+from ..layers import LogTransformLayer, DNN
+
+
+class AFN(BaseModel):
+ """Instantiates the Adaptive Factorization Network architecture.
+
+ In DeepCTR-Torch, we only provide the non-ensembled version of AFN for the consistency of model interfaces. For the ensembled version of AFN+, please refer to https://github.com/WeiyuCheng/DeepCTR-Torch (Pytorch Version) or https://github.com/WeiyuCheng/AFN-AAAI-20 (Tensorflow Version).
+
+ :param linear_feature_columns: An iterable containing all the features used by linear part of the model.
+ :param dnn_feature_columns: An iterable containing all the features used by deep part of the model.
+ :param ltl_hidden_size: integer, the number of logarithmic neurons in AFN
+ :param afn_dnn_hidden_units: list, list of positive integer or empty list, the layer number and units in each layer of DNN layers in AFN
+ :param l2_reg_linear: float. L2 regularizer strength applied to linear part
+ :param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector
+ :param l2_reg_dnn: float. L2 regularizer strength applied to DNN
+ :param init_std: float,to use as the initialize std of embedding vector
+ :param seed: integer ,to use as random seed.
+ :param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate.
+ :param dnn_activation: Activation function to use in DNN
+ :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss
+ :param device: str, ``"cpu"`` or ``"cuda:0"``
+ :param gpus: list of int or torch.device for multiple gpus. If None, run on `device`. `gpus[0]` should be the same gpu with `device`.
+ :return: A PyTorch model instance.
+
+ """
+
+ def __init__(self,
+ linear_feature_columns, dnn_feature_columns,
+ ltl_hidden_size=256, afn_dnn_hidden_units=(256, 128),
+ l2_reg_linear=0.00001, l2_reg_embedding=0.00001, l2_reg_dnn=0,
+ init_std=0.0001, seed=1024, dnn_dropout=0, dnn_activation='relu',
+ task='binary', device='cpu', gpus=None):
+
+ super(AFN, self).__init__(linear_feature_columns, dnn_feature_columns, l2_reg_linear=l2_reg_linear,
+ l2_reg_embedding=l2_reg_embedding, init_std=init_std, seed=seed, task=task,
+ device=device, gpus=gpus)
+
+ self.ltl = LogTransformLayer(len(self.embedding_dict), self.embedding_size, ltl_hidden_size)
+ self.afn_dnn = DNN(self.embedding_size * ltl_hidden_size, afn_dnn_hidden_units,
+ activation=dnn_activation, l2_reg=l2_reg_dnn, dropout_rate=dnn_dropout, use_bn=True,
+ init_std=init_std, device=device)
+ self.afn_dnn_linear = nn.Linear(afn_dnn_hidden_units[-1], 1)
+ self.to(device)
+
+ def forward(self, X):
+
+ sparse_embedding_list, _ = self.input_from_feature_columns(X, self.dnn_feature_columns,
+ self.embedding_dict)
+ logit = self.linear_model(X)
+ if len(sparse_embedding_list) == 0:
+ raise ValueError('Sparse embeddings not provided. AFN only accepts sparse embeddings as input.')
+
+ afn_input = torch.cat(sparse_embedding_list, dim=1)
+ ltl_result = self.ltl(afn_input)
+ afn_logit = self.afn_dnn(ltl_result)
+ afn_logit = self.afn_dnn_linear(afn_logit)
+
+ logit += afn_logit
+ y_pred = self.out(logit)
+
+ return y_pred
diff --git a/deepctr_torch/models/autoint.py b/deepctr_torch/models/autoint.py
index c39effb4..4f27656e 100644
--- a/deepctr_torch/models/autoint.py
+++ b/deepctr_torch/models/autoint.py
@@ -19,7 +19,6 @@ class AutoInt(BaseModel):
:param linear_feature_columns: An iterable containing all the features used by linear part of the model.
:param dnn_feature_columns: An iterable containing all the features used by deep part of the model.
:param att_layer_num: int.The InteractingLayer number to be used.
- :param att_embedding_size: int.The embedding size in multi-head self-attention network.
:param att_head_num: int.The head number in multi-head self-attention network.
:param att_res: bool.Whether or not use standard residual connections before output.
:param dnn_hidden_units: list,list of positive integer or empty list, the layer number and units in each layer of DNN
@@ -37,28 +36,27 @@ class AutoInt(BaseModel):
"""
- def __init__(self, linear_feature_columns, dnn_feature_columns, att_layer_num=3, att_embedding_size=8, att_head_num=2,
- att_res=True,
- dnn_hidden_units=(256, 128), dnn_activation='relu',
+ def __init__(self, linear_feature_columns, dnn_feature_columns, att_layer_num=3,
+ att_head_num=2, att_res=True, dnn_hidden_units=(256, 128), dnn_activation='relu',
l2_reg_dnn=0, l2_reg_embedding=1e-5, dnn_use_bn=False, dnn_dropout=0, init_std=0.0001, seed=1024,
task='binary', device='cpu', gpus=None):
super(AutoInt, self).__init__(linear_feature_columns, dnn_feature_columns, l2_reg_linear=0,
l2_reg_embedding=l2_reg_embedding, init_std=init_std, seed=seed, task=task,
device=device, gpus=gpus)
-
if len(dnn_hidden_units) <= 0 and att_layer_num <= 0:
raise ValueError("Either hidden_layer or att_layer_num must > 0")
self.use_dnn = len(dnn_feature_columns) > 0 and len(dnn_hidden_units) > 0
field_num = len(self.embedding_dict)
+ embedding_size = self.embedding_size
+
if len(dnn_hidden_units) and att_layer_num > 0:
- dnn_linear_in_feature = dnn_hidden_units[-1] + \
- field_num * att_embedding_size * att_head_num
+ dnn_linear_in_feature = dnn_hidden_units[-1] + field_num * embedding_size
elif len(dnn_hidden_units) > 0:
dnn_linear_in_feature = dnn_hidden_units[-1]
elif att_layer_num > 0:
- dnn_linear_in_feature = field_num * att_embedding_size * att_head_num
+ dnn_linear_in_feature = field_num * embedding_size
else:
raise NotImplementedError
@@ -72,8 +70,7 @@ def __init__(self, linear_feature_columns, dnn_feature_columns, att_layer_num=3,
self.add_regularization_weight(
filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2=l2_reg_dnn)
self.int_layers = nn.ModuleList(
- [InteractingLayer(self.embedding_size if i == 0 else att_embedding_size * att_head_num,
- att_embedding_size, att_head_num, att_res, device=device) for i in range(att_layer_num)])
+ [InteractingLayer(embedding_size, att_head_num, att_res, device=device) for _ in range(att_layer_num)])
self.to(device)
diff --git a/deepctr_torch/models/basemodel.py b/deepctr_torch/models/basemodel.py
index bb9d1f7a..4235ad38 100644
--- a/deepctr_torch/models/basemodel.py
+++ b/deepctr_torch/models/basemodel.py
@@ -24,7 +24,7 @@
from tensorflow.python.keras._impl.keras.callbacks import CallbackList
from ..inputs import build_input_features, SparseFeat, DenseFeat, VarLenSparseFeat, get_varlen_pooling_list, \
- create_embedding_matrix
+ create_embedding_matrix, varlen_embedding_lookup
from ..layers import PredictionLayer
from ..layers.utils import slice_arrays
from ..callbacks import History
@@ -68,7 +68,9 @@ def forward(self, X, sparse_feat_refine_weight=None):
dense_value_list = [X[:, self.feature_index[feat.name][0]:self.feature_index[feat.name][1]] for feat in
self.dense_feature_columns]
- varlen_embedding_list = get_varlen_pooling_list(self.embedding_dict, X, self.feature_index,
+ sequence_embed_dict = varlen_embedding_lookup(X, self.embedding_dict, self.feature_index,
+ self.varlen_sparse_feature_columns)
+ varlen_embedding_list = get_varlen_pooling_list(sequence_embed_dict, X, self.feature_index,
self.varlen_sparse_feature_columns, self.device)
sparse_embedding_list += varlen_embedding_list
@@ -126,9 +128,9 @@ def __init__(self, linear_feature_columns, dnn_feature_columns, l2_reg_linear=1e
self.out = PredictionLayer(task, )
self.to(device)
- # parameters of callbacks
- self._is_graph_network = True # used for ModelCheckpoint
- self.stop_training = False # used for EarlyStopping
+ # parameters for callbacks
+ self._is_graph_network = True # used for ModelCheckpoint in tf2
+ self._ckpt_saved_epoch = False # used for EarlyStopping in tf1.14
self.history = History()
def fit(self, x=None, y=None, batch_size=None, epochs=1, verbose=1, initial_epoch=0, validation_split=0.,
@@ -216,9 +218,10 @@ def fit(self, x=None, y=None, batch_size=None, epochs=1, verbose=1, initial_epoc
# configure callbacks
callbacks = (callbacks or []) + [self.history] # add history callback
callbacks = CallbackList(callbacks)
+ callbacks.set_model(self)
callbacks.on_train_begin()
callbacks.set_model(self)
- if not hasattr(callbacks, 'model'):
+ if not hasattr(callbacks, 'model'): # for tf1.4
callbacks.__setattr__('model', self)
callbacks.model.stop_training = False
@@ -359,7 +362,9 @@ def input_from_feature_columns(self, X, feature_columns, embedding_dict, support
X[:, self.feature_index[feat.name][0]:self.feature_index[feat.name][1]].long()) for
feat in sparse_feature_columns]
- varlen_sparse_embedding_list = get_varlen_pooling_list(self.embedding_dict, X, self.feature_index,
+ sequence_embed_dict = varlen_embedding_lookup(X, self.embedding_dict, self.feature_index,
+ varlen_sparse_feature_columns)
+ varlen_sparse_embedding_list = get_varlen_pooling_list(sequence_embed_dict, X, self.feature_index,
varlen_sparse_feature_columns, self.device)
dense_value_list = [X[:, self.feature_index[feat.name][0]:self.feature_index[feat.name][1]] for feat in
@@ -490,6 +495,10 @@ def _get_metrics(self, metrics, set_eps=False):
self.metrics_names.append(metric)
return metrics_
+ def _in_multi_worker_mode(self):
+ # used for EarlyStopping in tf1.15
+ return None
+
@property
def embedding_size(self, ):
feature_columns = self.dnn_feature_columns
diff --git a/deepctr_torch/models/dien.py b/deepctr_torch/models/dien.py
index 6f37c1aa..917777f9 100644
--- a/deepctr_torch/models/dien.py
+++ b/deepctr_torch/models/dien.py
@@ -319,7 +319,7 @@ def __init__(self,
@staticmethod
def _get_last_state(states, keys_length):
# states [B, T, H]
- batch_size, max_seq_length, hidden_size = states.size()
+ batch_size, max_seq_length, _ = states.size()
mask = (torch.arange(max_seq_length, device=keys_length.device).repeat(
batch_size, 1) == (keys_length.view(-1, 1) - 1))
diff --git a/deepctr_torch/models/difm.py b/deepctr_torch/models/difm.py
index 13a3aaab..c0dc7a0a 100644
--- a/deepctr_torch/models/difm.py
+++ b/deepctr_torch/models/difm.py
@@ -18,6 +18,8 @@ class DIFM(BaseModel):
:param linear_feature_columns: An iterable containing all the features used by linear part of the model.
:param dnn_feature_columns: An iterable containing all the features used by deep part of the model.
+ :param att_head_num: int. The head number in multi-head self-attention network.
+ :param att_res: bool. Whether or not use standard residual connections before output.
:param dnn_hidden_units: list,list of positive integer or empty list, the layer number and units in each layer of DNN
:param l2_reg_linear: float. L2 regularizer strength applied to linear part
:param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector
@@ -35,35 +37,33 @@ class DIFM(BaseModel):
"""
def __init__(self,
- linear_feature_columns, dnn_feature_columns, att_embedding_size=8, att_head_num=8,
+ linear_feature_columns, dnn_feature_columns, att_head_num=4,
att_res=True, dnn_hidden_units=(256, 128),
l2_reg_linear=0.00001, l2_reg_embedding=0.00001, l2_reg_dnn=0, init_std=0.0001, seed=1024,
dnn_dropout=0,
dnn_activation='relu', dnn_use_bn=False, task='binary', device='cpu', gpus=None):
super(DIFM, self).__init__(linear_feature_columns, dnn_feature_columns, l2_reg_linear=l2_reg_linear,
- l2_reg_embedding=l2_reg_embedding, init_std=init_std, seed=seed, task=task,
- device=device, gpus=gpus)
+ l2_reg_embedding=l2_reg_embedding, init_std=init_std, seed=seed, task=task,
+ device=device, gpus=gpus)
if not len(dnn_hidden_units) > 0:
raise ValueError("dnn_hidden_units is null!")
- self.use_dnn = len(dnn_feature_columns) > 0 and len(
- dnn_hidden_units) > 0
self.fm = FM()
# InteractingLayer (used in AutoInt) = multi-head self-attention + Residual Network
- self.vector_wise_net = InteractingLayer(self.embedding_size, att_embedding_size,
- att_head_num, att_res, scaling=True, device=device)
+ self.vector_wise_net = InteractingLayer(self.embedding_size, att_head_num,
+ att_res, scaling=True, device=device)
self.bit_wise_net = DNN(self.compute_input_dim(dnn_feature_columns, include_dense=False),
- dnn_hidden_units, activation=dnn_activation, l2_reg=l2_reg_dnn,
- dropout_rate=dnn_dropout,
- use_bn=dnn_use_bn, init_std=init_std, device=device)
+ dnn_hidden_units, activation=dnn_activation, l2_reg=l2_reg_dnn,
+ dropout_rate=dnn_dropout,
+ use_bn=dnn_use_bn, init_std=init_std, device=device)
self.sparse_feat_num = len(list(filter(lambda x: isinstance(x, SparseFeat) or isinstance(x, VarLenSparseFeat),
dnn_feature_columns)))
self.transform_matrix_P_vec = nn.Linear(
- self.sparse_feat_num*att_embedding_size*att_head_num, self.sparse_feat_num, bias=False).to(device)
+ self.sparse_feat_num * self.embedding_size, self.sparse_feat_num, bias=False).to(device)
self.transform_matrix_P_bit = nn.Linear(
dnn_hidden_units[-1], self.sparse_feat_num, bias=False).to(device)
@@ -80,7 +80,7 @@ def __init__(self,
def forward(self, X):
sparse_embedding_list, _ = self.input_from_feature_columns(X, self.dnn_feature_columns,
- self.embedding_dict)
+ self.embedding_dict)
if not len(sparse_embedding_list) > 0:
raise ValueError("there are no sparse features")
diff --git a/deepctr_torch/models/ifm.py b/deepctr_torch/models/ifm.py
index 4f057833..6105a235 100644
--- a/deepctr_torch/models/ifm.py
+++ b/deepctr_torch/models/ifm.py
@@ -47,8 +47,6 @@ def __init__(self,
if not len(dnn_hidden_units) > 0:
raise ValueError("dnn_hidden_units is null!")
- self.use_dnn = len(dnn_feature_columns) > 0 and len(
- dnn_hidden_units) > 0
self.fm = FM()
self.factor_estimating_net = DNN(self.compute_input_dim(dnn_feature_columns, include_dense=False),
@@ -68,7 +66,7 @@ def __init__(self,
self.to(device)
def forward(self, X):
- sparse_embedding_list, dense_value_list = self.input_from_feature_columns(X, self.dnn_feature_columns,
+ sparse_embedding_list, _ = self.input_from_feature_columns(X, self.dnn_feature_columns,
self.embedding_dict)
if not len(sparse_embedding_list) > 0:
raise ValueError("there are no sparse features")
diff --git a/docs/pics/AFN.jpg b/docs/pics/AFN.jpg
new file mode 100644
index 00000000..9c594220
Binary files /dev/null and b/docs/pics/AFN.jpg differ
diff --git a/docs/source/FAQ.md b/docs/source/FAQ.md
index a7a4eb6e..102e35bc 100644
--- a/docs/source/FAQ.md
+++ b/docs/source/FAQ.md
@@ -33,7 +33,7 @@ model = DeepFM(linear_feature_columns,dnn_feature_columns)
model.compile(Adagrad(model.parameters(),0.1024),'binary_crossentropy',metrics=['binary_crossentropy'])
es = EarlyStopping(monitor='val_binary_crossentropy', min_delta=0, verbose=1, patience=0, mode='min')
-mdckpt = ModelCheckpoint(filepath = 'model.ckpt', save_best_only= True)
+mdckpt = ModelCheckpoint(filepath='model.ckpt', monitor='val_binary_crossentropy', verbose=1, save_best_only=True, mode='min')
history = model.fit(model_input,data[target].values,batch_size=256,epochs=10,verbose=2,validation_split=0.2,callbacks=[es,mdckpt])
print(history)
```
diff --git a/docs/source/Features.md b/docs/source/Features.md
index 2aaf6787..f7bc9827 100644
--- a/docs/source/Features.md
+++ b/docs/source/Features.md
@@ -261,6 +261,14 @@ Dual Inputaware Factorization Machines (DIFM) can adaptively reweight the origin
[Lu W, Yu Y, Chang Y, et al. A Dual Input-aware Factorization Machine for CTR Prediction[C]//IJCAI. 2020: 3139-3145.](https://www.ijcai.org/Proceedings/2020/0434.pdf)
+### AFN(Adaptive Factorization Network: Learning Adaptive-Order Feature Interactions)
+
+Adaptive Factorization Network (AFN) can learn arbitrary-order cross features adaptively from data. The core of AFN is a logarithmic transformation layer to convert the power of each feature in a feature combination into the coefficient to be learned.
+[**AFN Model API**](./deepctr_torch.models.afn.html)
+
+![AFN](../pics/AFN.jpg)
+
+[Cheng, W., Shen, Y. and Huang, L. 2020. Adaptive Factorization Network: Learning Adaptive-Order Feature Interactions. Proceedings of the AAAI Conference on Artificial Intelligence. 34, 04 (Apr. 2020), 3609-3616.](https://arxiv.org/pdf/1909.03276)
## Layers
diff --git a/docs/source/History.md b/docs/source/History.md
index eef2f07b..ec68a102 100644
--- a/docs/source/History.md
+++ b/docs/source/History.md
@@ -1,5 +1,6 @@
# History
-- 04/04/2021 : [v0.2.6](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.6) released.Add add [IFM](./Features.html#ifm-input-aware-factorization-machine) and [DIFM](./Features.html#difm-dual-input-aware-factorization-machine);Support multi-gpus running([example](./FAQ.html#how-to-run-the-demo-with-multiple-gpus)).
+- 06/14/2021 : [v0.2.7](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.7) released.Add [AFN](./Features.html#afn-adaptive-factorization-network-learning-adaptive-order-feature-interactions) and fix some bugs.
+- 04/04/2021 : [v0.2.6](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.6) released.Add [IFM](./Features.html#ifm-input-aware-factorization-machine) and [DIFM](./Features.html#difm-dual-input-aware-factorization-machine);Support multi-gpus running([example](./FAQ.html#how-to-run-the-demo-with-multiple-gpus)).
- 02/12/2021 : [v0.2.5](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.5) released.Fix bug in DCN-M.
- 12/05/2020 : [v0.2.4](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.4) released.Imporve compatibility & fix issues.Add History callback.([example](https://deepctr-torch.readthedocs.io/en/latest/FAQ.html#set-learning-rate-and-use-earlystopping)).
- 10/18/2020 : [v0.2.3](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.3) released.Add [DCN-M](./Features.html#dcn-deep-cross-network)&[DCN-Mix](./Features.html#dcn-mix-improved-deep-cross-network-with-mix-of-experts-and-matrix-kernel).Add EarlyStopping and ModelCheckpoint callbacks([example](https://deepctr-torch.readthedocs.io/en/latest/FAQ.html#set-learning-rate-and-use-earlystopping)).
diff --git a/docs/source/conf.py b/docs/source/conf.py
index d43d0eea..e99b48ea 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -26,7 +26,7 @@
# The short X.Y version
version = ''
# The full version, including alpha/beta/rc tags
-release = '0.2.6'
+release = '0.2.7'
# -- General configuration ---------------------------------------------------
diff --git a/docs/source/deepctr_torch.models.afn.rst b/docs/source/deepctr_torch.models.afn.rst
new file mode 100644
index 00000000..be4949b0
--- /dev/null
+++ b/docs/source/deepctr_torch.models.afn.rst
@@ -0,0 +1,7 @@
+deepctr\_torch.models.afn module
+================================
+
+.. automodule:: deepctr_torch.models.afn
+ :members:
+ :no-undoc-members:
+ :no-show-inheritance:
diff --git a/docs/source/deepctr_torch.models.rst b/docs/source/deepctr_torch.models.rst
index 599710b6..ff2d555e 100644
--- a/docs/source/deepctr_torch.models.rst
+++ b/docs/source/deepctr_torch.models.rst
@@ -23,6 +23,7 @@ Submodules
deepctr_torch.models.dien
deepctr_torch.models.ifm
deepctr_torch.models.difm
+ deepctr_torch.models.afn
Module contents
---------------
diff --git a/docs/source/index.rst b/docs/source/index.rst
index bc4d2b1d..1701d403 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -34,16 +34,16 @@ You can read the latest code at https://github.com/shenweichen/DeepCTR-Torch and
News
-----
+06/14/2021 : Add `AFN <./Features.html#afn-adaptive-factorization-network-learning-adaptive-order-feature-interactions>`_ and fix some bugs. `Changelog `_
+
04/04/2021 : Add `IFM <./Features.html#ifm-input-aware-factorization-machine>`_ and `DIFM <./Features.html#difm-dual-input-aware-factorization-machine>`_ . Support multi-gpus running(`example <./FAQ.html#how-to-run-the-demo-with-multiple-gpus>`_). `Changelog `_
02/12/2021 : Fix bug in DCN-M. `Changelog `_
-12/05/2020 : Imporve compatibility & fix issues.Add History callback(`example `_). `Changelog `_
-
DisscussionGroup
-----------------------
-公众号:**浅梦的学习笔记** wechat ID: **deepctrbot**
+公众号:**浅梦学习笔记** wechat ID: **deepctrbot**
.. image:: ../pics/code.png
diff --git a/setup.py b/setup.py
index 7060df42..4e44fe13 100644
--- a/setup.py
+++ b/setup.py
@@ -9,7 +9,7 @@
setuptools.setup(
name="deepctr-torch",
- version="0.2.6",
+ version="0.2.7",
author="Weichen Shen",
author_email="weichenswc@163.com",
description="Easy-to-use,Modular and Extendible package of deep learning based CTR(Click Through Rate) prediction models with PyTorch",
diff --git a/tests/models/AFN_test.py b/tests/models/AFN_test.py
new file mode 100644
index 00000000..dce5b207
--- /dev/null
+++ b/tests/models/AFN_test.py
@@ -0,0 +1,26 @@
+# -*- coding: utf-8 -*-
+import pytest
+
+from deepctr_torch.models import AFN
+from tests.utils import get_test_data, SAMPLE_SIZE, check_model, get_device
+
+
+@pytest.mark.parametrize(
+ 'afn_dnn_hidden_units, sparse_feature_num, dense_feature_num',
+ [((256, 128), 3, 0),
+ ((256, 128), 3, 3),
+ ((256, 128), 0, 3)]
+)
+def test_AFN(afn_dnn_hidden_units, sparse_feature_num, dense_feature_num):
+ model_name = 'AFN'
+ sample_size = SAMPLE_SIZE
+ x, y, feature_columns = get_test_data(
+ sample_size, sparse_feature_num=sparse_feature_num, dense_feature_num=dense_feature_num)
+
+ model = AFN(feature_columns, feature_columns, afn_dnn_hidden_units=afn_dnn_hidden_units, device=get_device())
+
+ check_model(model, model_name, x, y)
+
+
+if __name__ == '__main__':
+ pass
diff --git a/tests/utils.py b/tests/utils.py
index 4c79631e..10abcecb 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -67,7 +67,7 @@ def get_test_data(sample_size=1000, embedding_size=4, sparse_feature_num=1, dens
def layer_test(layer_cls, kwargs = {}, input_shape=None,
- input_dtype=torch.float32, input_data=None, expected_output=None,
+ input_dtype=torch.float32, input_data=None, expected_output=None,
expected_output_shape=None, expected_output_dtype=None, fixed_batch_size=False):
'''check layer is valid or not
@@ -80,12 +80,12 @@ def layer_test(layer_cls, kwargs = {}, input_shape=None,
:param fixed_batch_size:
:return: output of the layer
- '''
+ '''
if input_data is None:
# generate input data
if not input_shape:
raise ValueError("input shape should not be none")
-
+
input_data_shape = list(input_shape)
for i, e in enumerate(input_data_shape):
if e is None:
@@ -112,12 +112,12 @@ def layer_test(layer_cls, kwargs = {}, input_shape=None,
layer = layer_cls(**kwargs)
if fixed_batch_size:
- input = torch.tensor(input_data.unsqueeze(0), dtype=input_dtype)
+ inputs = torch.tensor(input_data.unsqueeze(0), dtype=input_dtype)
else:
- input = torch.tensor(input_data, dtype=input_dtype)
+ inputs = torch.tensor(input_data, dtype=input_dtype)
# calculate layer's output
- output = layer(input)
+ output = layer(inputs)
if not output.dtype == expected_output_dtype:
raise AssertionError("layer output dtype does not match with the expected one")
|