Skip to content

Commit

Permalink
v0.2.7
Browse files Browse the repository at this point in the history
- add AFN model
- fix error in VarLenSparseFeat and ModelCheckpoint callback
- change output shape of InteractingLayer in Autoint
  • Loading branch information
shenweichen authored Jun 14, 2021
2 parents 8265c75 + 74da9d6 commit b4d8181
Show file tree
Hide file tree
Showing 26 changed files with 276 additions and 151 deletions.
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/bug_report.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Steps to reproduce the behavior:
**Operating environment(运行环境):**
- python version [e.g. 3.5, 3.6]
- torch version [e.g. 1.6.0, 1.7.0]
- deepctr-torch version [e.g. 0.2.4,]
- deepctr-torch version [e.g. 0.2.7,]

**Additional context**
Add any other context about the problem here.
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/question.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ Add any other context about the problem here.
**Operating environment(运行环境):**
- python version [e.g. 3.6]
- torch version [e.g. 1.7.0,]
- deepctr-torch version [e.g. 0.2.4,]
- deepctr-torch version [e.g. 0.2.7,]
20 changes: 11 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

[![Documentation Status](https://readthedocs.org/projects/deepctr-torch/badge/?version=latest)](https://deepctr-torch.readthedocs.io/)
![CI status](https://github.com/shenweichen/deepctr-torch/workflows/CI/badge.svg)
[![codecov](https://codecov.io/gh/shenweichen/DeepCTR-Torch/branch/master/graph/badge.svg)](https://codecov.io/gh/shenweichen/DeepCTR-Torch)
[![codecov](https://codecov.io/gh/shenweichen/DeepCTR-Torch/branch/master/graph/badge.svg?token=m6v89eYOjp)](https://codecov.io/gh/shenweichen/DeepCTR-Torch)
[![Disscussion](https://img.shields.io/badge/chat-wechat-brightgreen?style=flat)](./README.md#disscussiongroup)
[![License](https://img.shields.io/github/license/shenweichen/deepctr-torch.svg)](https://github.com/shenweichen/deepctr-torch/blob/master/LICENSE)

Expand Down Expand Up @@ -41,6 +41,8 @@ Let's [**Get Started!**](https://deepctr-torch.readthedocs.io/en/latest/Quick-St
| IFM | [IJCAI 2019][An Input-aware Factorization Machine for Sparse Prediction](https://www.ijcai.org/Proceedings/2019/0203.pdf) |
| DCN V2 | [arxiv 2020][DCN V2: Improved Deep & Cross Network and Practical Lessons for Web-scale Learning to Rank Systems](https://arxiv.org/abs/2008.13535) |
| DIFM | [IJCAI 2020][A Dual Input-aware Factorization Machine for CTR Prediction](https://www.ijcai.org/Proceedings/2020/0434.pdf) |
| AFN | [AAAI 2020][Adaptive Factorization Network: Learning Adaptive-Order Feature Interactions](https://arxiv.org/pdf/1909.03276) |



## DisscussionGroup & Related Projects
Expand All @@ -49,7 +51,7 @@ Let's [**Get Started!**](https://deepctr-torch.readthedocs.io/en/latest/Quick-St
<table style="margin-left: 20px; margin-right: auto;">
<tr>
<td>
公众号:<b>浅梦的学习笔记</b><br><br>
公众号:<b>浅梦学习笔记</b><br><br>
<a href="https://github.com/shenweichen/deepctr-torch">
<img align="center" src="./docs/pics/code.png" />
</a>
Expand All @@ -74,7 +76,7 @@ Let's [**Get Started!**](https://deepctr-torch.readthedocs.io/en/latest/Quick-St



## Contributors([welcome to join us!](./CONTRIBUTING.md))
## Main Contributors([welcome to join us!](./CONTRIBUTING.md))

<table border="0">
<tbody>
Expand Down Expand Up @@ -125,18 +127,18 @@ Let's [**Get Started!**](https://deepctr-torch.readthedocs.io/en/latest/Quick-St
<p>Dev<br>
NetEase <br> <br> </p>​
</td>
<td>
​ <a href="https://github.com/WeiyuCheng"><img width="70" height="70" src="https://github.com/WeiyuCheng.png?s=40" alt="pic"></a><br>
​ <a href="https://github.com/WeiyuCheng">Cheng Weiyu</a> ​
<p>Dev<br>
Shanghai Jiao Tong University</p>​
</td>
<td>
​ <a href="https://github.com/tangaqi"><img width="70" height="70" src="https://github.com/tangaqi.png?s=40" alt="pic"></a><br>
​ <a href="https://github.com/tangaqi">Tang</a>
<p>Test<br>
Tongji University <br> <br> </p>​
</td>
<td>
​ <a href="https://github.com/uestc7d"><img width="70" height="70" src="https://github.com/uestc7d.png?s=40" alt="pic"></a><br>
​ <a href="https://github.com/uestc7d">Xu Qidi</a> ​
<p>Dev<br>
University of <br> Electronic Science and <br> Technology of China</p>​
</td>
</tr>
</tbody>
</table>
2 changes: 1 addition & 1 deletion deepctr_torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
from . import models
from .utils import check_version

__version__ = '0.2.6'
__version__ = '0.2.7'
check_version(__version__)
38 changes: 6 additions & 32 deletions deepctr_torch/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ def vocabulary_size(self):
def embedding_dim(self):
return self.sparsefeat.embedding_dim

@property
def use_hash(self):
return self.sparsefeat.use_hash

@property
def dtype(self):
return self.sparsefeat.dtype
Expand Down Expand Up @@ -136,18 +140,15 @@ def combined_dnn_input(sparse_embedding_list, dense_value_list):

def get_varlen_pooling_list(embedding_dict, features, feature_index, varlen_sparse_feature_columns, device):
varlen_sparse_embedding_list = []

for feat in varlen_sparse_feature_columns:
seq_emb = embedding_dict[feat.embedding_name](
features[:, feature_index[feat.name][0]:feature_index[feat.name][1]].long())
seq_emb = embedding_dict[feat.name]
if feat.length_name is None:
seq_mask = features[:, feature_index[feat.name][0]:feature_index[feat.name][1]].long() != 0

emb = SequencePoolingLayer(mode=feat.combiner, supports_masking=True, device=device)(
[seq_emb, seq_mask])
else:
seq_length = features[:,
feature_index[feat.length_name][0]:feature_index[feat.length_name][1]].long()
seq_length = features[:, feature_index[feat.length_name][0]:feature_index[feat.length_name][1]].long()
emb = SequencePoolingLayer(mode=feat.combiner, supports_masking=False, device=device)(
[seq_emb, seq_length])
varlen_sparse_embedding_list.append(emb)
Expand Down Expand Up @@ -179,33 +180,6 @@ def create_embedding_matrix(feature_columns, init_std=0.0001, linear=False, spar
return embedding_dict.to(device)


def input_from_feature_columns(self, X, feature_columns, embedding_dict, support_dense=True):
sparse_feature_columns = list(
filter(lambda x: isinstance(x, SparseFeat), feature_columns)) if len(feature_columns) else []
dense_feature_columns = list(
filter(lambda x: isinstance(x, DenseFeat), feature_columns)) if len(feature_columns) else []

varlen_sparse_feature_columns = list(
filter(lambda x: isinstance(x, VarLenSparseFeat), feature_columns)) if feature_columns else []

if not support_dense and len(dense_feature_columns) > 0:
raise ValueError(
"DenseFeat is not supported in dnn_feature_columns")

sparse_embedding_list = [embedding_dict[feat.embedding_name](
X[:, self.feature_index[feat.name][0]:self.feature_index[feat.name][1]].long()) for
feat in sparse_feature_columns]

varlen_sparse_embedding_list = get_varlen_pooling_list(self.embedding_dict, X, self.feature_index,
varlen_sparse_feature_columns, self.device)

dense_value_list = [X[:, self.feature_index[feat.name][0]:self.feature_index[feat.name][1]] for feat in
dense_feature_columns]

return sparse_embedding_list + varlen_sparse_embedding_list, dense_value_list



def embedding_lookup(X, sparse_embedding_dict, sparse_input_dict, sparse_feature_columns, return_feat_list=(),
mask_feat_list=(), to_list=False):
"""
Expand Down
7 changes: 3 additions & 4 deletions deepctr_torch/layers/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class Dice(nn.Module):
- [Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]//Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. ACM, 2018: 1059-1068.](https://arxiv.org/pdf/1706.06978.pdf)
- https://github.com/zhougr1993/DeepInterestNetwork, https://github.com/fanoping/DIN-pytorch
"""

def __init__(self, emb_size, dim=2, epsilon=1e-8, device='cpu'):
super(Dice, self).__init__()
assert dim == 2 or dim == 3
Expand All @@ -41,18 +42,16 @@ def forward(self, x):
x_p = self.sigmoid(self.bn(x))
out = self.alpha * (1 - x_p) * x + x_p * x
out = torch.transpose(out, 1, 2)

return out


class Identity(nn.Module):


def __init__(self, **kwargs):
super(Identity, self).__init__()

def forward(self, X):
return X
def forward(self, inputs):
return inputs


def activation_layer(act_name, hidden_size=None, dice_dim=2):
Expand Down
100 changes: 64 additions & 36 deletions deepctr_torch/layers/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def __init__(self, filed_size, embedding_size, bilinear_type="interaction", seed
self.bilinear.append(
nn.Linear(embedding_size, embedding_size, bias=False))
elif self.bilinear_type == "interaction":
for i, j in itertools.combinations(range(filed_size), 2):
for _, _ in itertools.combinations(range(filed_size), 2):
self.bilinear.append(
nn.Linear(embedding_size, embedding_size, bias=False))
else:
Expand Down Expand Up @@ -330,41 +330,34 @@ class InteractingLayer(nn.Module):
Input shape
- A 3D tensor with shape: ``(batch_size,field_size,embedding_size)``.
Output shape
- 3D tensor with shape:``(batch_size,field_size,att_embedding_size * head_num)``.
- 3D tensor with shape:``(batch_size,field_size,embedding_size)``.
Arguments
- **in_features** : Positive integer, dimensionality of input features.
- **att_embedding_size**: int.The embedding size in multi-head self-attention network.
- **head_num**: int.The head number in multi-head self-attention network.
- **head_num**: int.The head number in multi-head self-attention network.
- **use_res**: bool.Whether or not use standard residual connections before output.
- **seed**: A Python integer to use as random seed.
References
- [Song W, Shi C, Xiao Z, et al. AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks[J]. arXiv preprint arXiv:1810.11921, 2018.](https://arxiv.org/abs/1810.11921)
"""

def __init__(self, in_features, att_embedding_size=8, head_num=2, use_res=True, scaling=False, seed=1024, device='cpu'):
def __init__(self, embedding_size, head_num=2, use_res=True, scaling=False, seed=1024, device='cpu'):
super(InteractingLayer, self).__init__()
if head_num <= 0:
raise ValueError('head_num must be a int > 0')
self.att_embedding_size = att_embedding_size
if embedding_size % head_num != 0:
raise ValueError('embedding_size is not an integer multiple of head_num!')
self.att_embedding_size = embedding_size // head_num
self.head_num = head_num
self.use_res = use_res
self.scaling = scaling
self.seed = seed

embedding_size = in_features

self.W_Query = nn.Parameter(torch.Tensor(
embedding_size, self.att_embedding_size * self.head_num))

self.W_key = nn.Parameter(torch.Tensor(
embedding_size, self.att_embedding_size * self.head_num))

self.W_Value = nn.Parameter(torch.Tensor(
embedding_size, self.att_embedding_size * self.head_num))
self.W_Query = nn.Parameter(torch.Tensor(embedding_size, embedding_size))
self.W_key = nn.Parameter(torch.Tensor(embedding_size, embedding_size))
self.W_Value = nn.Parameter(torch.Tensor(embedding_size, embedding_size))

if self.use_res:
self.W_Res = nn.Parameter(torch.Tensor(
embedding_size, self.att_embedding_size * self.head_num))
self.W_Res = nn.Parameter(torch.Tensor(embedding_size, embedding_size))
for tensor in self.parameters():
nn.init.normal_(tensor, mean=0.0, std=0.05)

Expand All @@ -376,29 +369,24 @@ def forward(self, inputs):
raise ValueError(
"Unexpected inputs dimensions %d, expect to be 3 dimensions" % (len(inputs.shape)))

querys = torch.tensordot(inputs, self.W_Query,
dims=([-1], [0])) # None F D*head_num
# None F D
querys = torch.tensordot(inputs, self.W_Query, dims=([-1], [0]))
keys = torch.tensordot(inputs, self.W_key, dims=([-1], [0]))
values = torch.tensordot(inputs, self.W_Value, dims=([-1], [0]))

# head_num None F D

querys = torch.stack(torch.split(
querys, self.att_embedding_size, dim=2))
# head_num None F D/head_num
querys = torch.stack(torch.split(querys, self.att_embedding_size, dim=2))
keys = torch.stack(torch.split(keys, self.att_embedding_size, dim=2))
values = torch.stack(torch.split(
values, self.att_embedding_size, dim=2))
inner_product = torch.einsum(
'bnik,bnjk->bnij', querys, keys) # head_num None F F
values = torch.stack(torch.split(values, self.att_embedding_size, dim=2))

inner_product = torch.einsum('bnik,bnjk->bnij', querys, keys) # head_num None F F
if self.scaling:
inner_product /= self.att_embedding_size ** 0.5
self.normalized_att_scores = F.softmax(
inner_product, dim=-1) # head_num None F F
result = torch.matmul(self.normalized_att_scores,
values) # head_num None F D
self.normalized_att_scores = F.softmax(inner_product, dim=-1) # head_num None F F
result = torch.matmul(self.normalized_att_scores, values) # head_num None F D/head_num

result = torch.cat(torch.split(result, 1, ), dim=-1)
result = torch.squeeze(result, dim=0) # None F D*head_num
result = torch.squeeze(result, dim=0) # None F D
if self.use_res:
result += torch.tensordot(inputs, self.W_Res, dims=([-1], [0]))
result = F.relu(result)
Expand Down Expand Up @@ -499,9 +487,9 @@ def __init__(self, in_features, low_rank=32, num_experts=4, layer_num=2, device=
self.bias = nn.Parameter(torch.Tensor(self.layer_num, in_features, 1))

init_para_list = [self.U_list, self.V_list, self.C_list]
for i in range(len(init_para_list)):
for j in range(self.layer_num):
nn.init.xavier_normal_(init_para_list[i][j])
for para in init_para_list:
for i in range(self.layer_num):
nn.init.xavier_normal_(para[i])

for i in range(len(self.bias)):
nn.init.zeros_(self.bias[i])
Expand Down Expand Up @@ -727,3 +715,43 @@ def __init__(self, field_size, conv_kernel_width, conv_filters, device='cpu'):

def forward(self, inputs):
return self.conv_layer(inputs)


class LogTransformLayer(nn.Module):
"""Logarithmic Transformation Layer in Adaptive factorization network, which models arbitrary-order cross features.
Input shape
- 3D tensor with shape: ``(batch_size, field_size, embedding_size)``.
Output shape
- 2D tensor with shape: ``(batch_size, ltl_hidden_size*embedding_size)``.
Arguments
- **field_size** : positive integer, number of feature groups
- **embedding_size** : positive integer, embedding size of sparse features
- **ltl_hidden_size** : integer, the number of logarithmic neurons in AFN
References
- Cheng, W., Shen, Y. and Huang, L. 2020. Adaptive Factorization Network: Learning Adaptive-Order Feature
Interactions. Proceedings of the AAAI Conference on Artificial Intelligence. 34, 04 (Apr. 2020), 3609-3616.
"""

def __init__(self, field_size, embedding_size, ltl_hidden_size):
super(LogTransformLayer, self).__init__()

self.ltl_weights = nn.Parameter(torch.Tensor(field_size, ltl_hidden_size))
self.ltl_biases = nn.Parameter(torch.Tensor(1, 1, ltl_hidden_size))
self.bn = nn.ModuleList([nn.BatchNorm1d(embedding_size) for i in range(2)])
nn.init.normal_(self.ltl_weights, mean=0.0, std=0.1)
nn.init.zeros_(self.ltl_biases, )

def forward(self, inputs):
# Avoid numeric overflow
afn_input = torch.clamp(torch.abs(inputs), min=1e-7, max=float("Inf"))
# Transpose to shape: ``(batch_size,embedding_size,field_size)``
afn_input_trans = torch.transpose(afn_input, 1, 2)
# Logarithmic transformation layer
ltl_result = torch.log(afn_input_trans)
ltl_result = self.bn[0](ltl_result)
ltl_result = torch.matmul(ltl_result, self.ltl_weights) + self.ltl_biases
ltl_result = torch.exp(ltl_result)
ltl_result = self.bn[1](ltl_result)
ltl_result = torch.flatten(ltl_result, start_dim=1)
return ltl_result
Loading

0 comments on commit b4d8181

Please sign in to comment.