Showing 23 changed files with 447 additions and 70 deletions.
@@ -4,21 +4,26 @@
     Weichen Shen,[email protected]
 """

-from collections import OrderedDict, namedtuple
+from collections import OrderedDict, namedtuple, defaultdict
+from itertools import chain

 import torch
 import torch.nn as nn

 from .layers.utils import concat_fun
+from .layers.sequence import SequencePoolingLayer

+DEFAULT_GROUP_NAME = "default_group"

-class SparseFeat(namedtuple('SparseFeat', ['name', 'dimension', 'use_hash', 'dtype', 'embedding_name', 'embedding'])):
+class SparseFeat(namedtuple('SparseFeat', ['name', 'dimension', 'use_hash', 'dtype', 'embedding_name', 'embedding', 'group_name'])):
     __slots__ = ()

-    def __new__(cls, name, dimension, use_hash=False, dtype="int32", embedding_name=None, embedding=True):
+    def __new__(cls, name, dimension, use_hash=False, dtype="int32",
+                embedding_name=None, embedding=True, group_name=DEFAULT_GROUP_NAME):
         if embedding and embedding_name is None:
             embedding_name = name
-        return super(SparseFeat, cls).__new__(cls, name, dimension, use_hash, dtype, embedding_name, embedding)
+        return super(SparseFeat, cls).__new__(cls, name, dimension, use_hash, dtype, embedding_name, embedding, group_name)


 class DenseFeat(namedtuple('DenseFeat', ['name', 'dimension', 'dtype'])):
@@ -30,15 +35,16 @@ def __new__(cls, name, dimension=1, dtype="float32"):

 class VarLenSparseFeat(namedtuple('VarLenFeat',
                                   ['name', 'dimension', 'maxlen', 'combiner', 'use_hash', 'dtype', 'embedding_name',
-                                   'embedding'])):
+                                   'embedding', 'group_name'])):
     __slots__ = ()

     def __new__(cls, name, dimension, maxlen, combiner="mean", use_hash=False, dtype="float32", embedding_name=None,
-                embedding=True):
+                embedding=True, group_name=DEFAULT_GROUP_NAME):
         if embedding_name is None:
             embedding_name = name
         return super(VarLenSparseFeat, cls).__new__(cls, name, dimension, maxlen, combiner, use_hash, dtype,
-                                                    embedding_name, embedding)
+                                                    embedding_name, embedding, group_name)


 def get_feature_names(feature_columns):
@@ -50,6 +56,9 @@ def get_inputs_list(inputs):


 def build_input_features(feature_columns):

+    # Return OrderedDict: {feature_name:(start, start+dimension)}

     features = OrderedDict()

     start = 0
@@ -92,4 +101,70 @@ def combined_dnn_input(sparse_embedding_list, dense_value_list):
     elif len(dense_value_list) > 0:
         return torch.flatten(torch.cat(dense_value_list, dim=-1), start_dim=1)
     else:
-        raise NotImplementedError
+        raise NotImplementedError
+
+
+def embedding_lookup(sparse_embedding_dict, sparse_input_dict, sparse_feature_columns, return_feat_list=(),
+                     mask_feat_list=(), to_list=False):
+    """
+    Args:
+        sparse_embedding_dict: nn.ModuleDict, {embedding_name: nn.Embedding}
+        sparse_input_dict: OrderedDict, {feature_name: (start, start+dimension)}
+        sparse_feature_columns: list, sparse features
+        return_feat_list: list, names of features to be returned; default () returns all features
+        mask_feat_list: list, names of features to be masked in the hash transform
+    Return:
+        group_embedding_dict: defaultdict(list)
+    """
+    group_embedding_dict = defaultdict(list)
+    for fc in sparse_feature_columns:
+        feature_name = fc.name
+        embedding_name = fc.embedding_name
+        if len(return_feat_list) == 0 or feature_name in return_feat_list:
+            if fc.use_hash:
+                # lookup_idx = Hash(fc.vocabulary_size, mask_zero=(feature_name in mask_feat_list))(
+                #     sparse_input_dict[feature_name])
+                # TODO: add hash function
+                lookup_idx = sparse_input_dict[feature_name]
+            else:
+                lookup_idx = sparse_input_dict[feature_name]
+
+            group_embedding_dict[fc.group_name].append(sparse_embedding_dict[embedding_name](lookup_idx))
+    if to_list:
+        return list(chain.from_iterable(group_embedding_dict.values()))
+    return group_embedding_dict
+
+
+def varlen_embedding_lookup(embedding_dict, sequence_input_dict, varlen_sparse_feature_columns):
+    varlen_embedding_vec_dict = {}
+    for fc in varlen_sparse_feature_columns:
+        feature_name = fc.name
+        embedding_name = fc.embedding_name
+        if fc.use_hash:
+            # lookup_idx = Hash(fc.vocabulary_size, mask_zero=True)(sequence_input_dict[feature_name])
+            # TODO: add hash function
+            lookup_idx = sequence_input_dict[feature_name]
+        else:
+            lookup_idx = sequence_input_dict[feature_name]
+        varlen_embedding_vec_dict[feature_name] = embedding_dict[embedding_name](lookup_idx)
+    return varlen_embedding_vec_dict
+
+
+def get_varlen_pooling_list(embedding_dict, features, varlen_sparse_feature_columns, to_list=False):
+    pooling_vec_list = defaultdict(list)
+    for fc in varlen_sparse_feature_columns:
+        feature_name = fc.name
+        combiner = fc.combiner
+        feature_length_name = fc.length_name
+        if feature_length_name is not None:
+            seq_input = embedding_dict[feature_name]
+            vec = SequencePoolingLayer(combiner)([seq_input, features[feature_length_name]])
+        else:
+            seq_input = embedding_dict[feature_name]
+            vec = SequencePoolingLayer(combiner)(seq_input)
+        pooling_vec_list[fc.group_name].append(vec)
+
+    if to_list:
+        return chain.from_iterable(pooling_vec_list.values())
+
+    return pooling_vec_list
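The grouping added above can be exercised on its own. The sketch below is illustrative only: the import path deepctr_torch.inputs and the feature names and sizes are assumptions, while SparseFeat, VarLenSparseFeat, build_input_features, and DEFAULT_GROUP_NAME are the definitions from this diff. It shows how group_name partitions feature columns, which is the same bookkeeping embedding_lookup and get_varlen_pooling_list use to bucket their outputs by fc.group_name.

# Usage sketch (hypothetical feature names/sizes; import path is an assumption).
from collections import defaultdict

from deepctr_torch.inputs import SparseFeat, VarLenSparseFeat, build_input_features

feature_columns = [
    SparseFeat('user_id', dimension=1000, group_name='user'),
    SparseFeat('item_id', dimension=5000, group_name='item'),
    VarLenSparseFeat('hist_item_id', dimension=5000, maxlen=20, group_name='item'),
]

# build_input_features maps each feature name to its (start, end) slice in the flat input.
print(build_input_features(feature_columns))

# embedding_lookup buckets results per group_name with a defaultdict;
# the same bookkeeping applied to the columns themselves:
groups = defaultdict(list)
for fc in feature_columns:
    groups[fc.group_name].append(fc.name)
print(dict(groups))  # {'user': ['user_id'], 'item': ['item_id', 'hist_item_id']}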
@@ -0,0 +1,92 @@
+# -*- coding:utf-8 -*-
+"""
+Author:
+    Yuef Zhang
+"""
+import sys
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+
+class Dice(nn.Module):
+    """The Data Adaptive Activation Function in DIN, which can be viewed as a generalization of PReLU
+    and can adaptively adjust the rectified point according to the distribution of input data.
+
+    Input shape:
+        - 2 dims: [batch_size, embedding_size(features)]
+        - 3 dims: [batch_size, num_features, embedding_size(features)]
+    Output shape:
+        - Same shape as the input.
+    References
+        - [Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]//Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. ACM, 2018: 1059-1068.](https://arxiv.org/pdf/1706.06978.pdf)
+        - https://github.com/zhougr1993/DeepInterestNetwork, https://github.com/fanoping/DIN-pytorch
+    """
+
+    def __init__(self, num_features, dim=2, epsilon=1e-9):
+        super(Dice, self).__init__()
+        assert dim == 2 or dim == 3
+        self.bn = nn.BatchNorm1d(num_features, eps=epsilon)
+        self.sigmoid = nn.Sigmoid()
+        self.dim = dim
+
+        if self.dim == 2:
+            self.alpha = torch.zeros((num_features,)).to(device)
+        else:
+            self.alpha = torch.zeros((num_features, 1)).to(device)
+
+    def forward(self, x):
+        # x shape: [batch_size, num_features, embedding_size(features)]
+        assert x.dim() == 2 or x.dim() == 3
+
+        if self.dim == 2:
+            x_p = self.sigmoid(self.bn(x))
+            out = self.alpha * (1 - x_p) * x + x_p * x
+        else:
+            x = torch.transpose(x, 1, 2)
+            x_p = self.sigmoid(self.bn(x))
+            out = self.alpha * (1 - x_p) * x + x_p * x
+            out = torch.transpose(out, 1, 2)
+
+        return out
+
+
+def activation_layer(act_name, hidden_size=None, dice_dim=2):
+    """Construct activation layers
+
+    Args:
+        act_name: str or nn.Module, name of activation function
+        hidden_size: int, used for Dice activation
+        dice_dim: int, used for Dice activation
+    Return:
+        act_layer: activation layer
+    """
+    if isinstance(act_name, str):
+        if act_name.lower() in ('relu', 'linear'):
+            act_layer = nn.ReLU(inplace=True)
+        elif act_name.lower() == 'dice':
+            assert dice_dim
+            act_layer = Dice(hidden_size, dice_dim)
+        elif act_name.lower() == 'prelu':
+            act_layer = nn.PReLU()
+    elif issubclass(act_name, nn.Module):
+        act_layer = act_name()
+    else:
+        raise NotImplementedError
+
+    return act_layer
+
+
+if __name__ == "__main__":
+    torch.manual_seed(7)
+    a = Dice(3)
+    b = torch.rand((5, 3))
+    c = a(b)
+    print(c.size())
+    print('b:', b)
+    print('c:', c)
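To see the new activation helpers in context: Dice computes out = alpha * (1 - p) * x + p * x with p = sigmoid(BatchNorm(x)), and activation_layer picks a layer from a string. The snippet below is a hedged sketch, not part of the commit: the import path deepctr_torch.layers.activation and the layer sizes are assumptions; only activation_layer and Dice come from the file above.

# Hypothetical usage of activation_layer inside a tiny feed-forward block
# (import path and sizes are assumptions for illustration).
import torch
import torch.nn as nn

from deepctr_torch.layers.activation import activation_layer

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'  # mirrors the module-level device above

hidden = 8
act = activation_layer('dice', hidden_size=hidden, dice_dim=2)  # 'relu' / 'prelu' also work
block = nn.Sequential(nn.Linear(16, hidden), act).to(device)

x = torch.rand(4, 16, device=device)
print(block(x).shape)  # torch.Size([4, 8])

Note that Dice pins its alpha tensor to the module-level device, so the surrounding block and inputs have to live on that same device.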