
Commit

bugfix&add some layers
浅梦 authored Jan 29, 2020
2 parents 687a094 + 364168a commit e2d2365
Showing 23 changed files with 447 additions and 70 deletions.
91 changes: 83 additions & 8 deletions deepctr_torch/inputs.py
@@ -4,21 +4,26 @@
Weichen Shen,[email protected]
"""

from collections import OrderedDict, namedtuple
from collections import OrderedDict, namedtuple, defaultdict
from itertools import chain

import torch
import torch.nn as nn

from .layers.utils import concat_fun
from .layers.sequence import SequencePoolingLayer

DEFAULT_GROUP_NAME = "default_group"

class SparseFeat(namedtuple('SparseFeat', ['name', 'dimension', 'use_hash', 'dtype', 'embedding_name', 'embedding'])):

class SparseFeat(namedtuple('SparseFeat', ['name', 'dimension', 'use_hash', 'dtype', 'embedding_name', 'embedding', 'group_name'])):
__slots__ = ()

def __new__(cls, name, dimension, use_hash=False, dtype="int32", embedding_name=None, embedding=True):
def __new__(cls, name, dimension, use_hash=False, dtype="int32",
embedding_name=None, embedding=True, group_name=DEFAULT_GROUP_NAME):
if embedding and embedding_name is None:
embedding_name = name
return super(SparseFeat, cls).__new__(cls, name, dimension, use_hash, dtype, embedding_name, embedding)
return super(SparseFeat, cls).__new__(cls, name, dimension, use_hash, dtype, embedding_name, embedding, group_name)


class DenseFeat(namedtuple('DenseFeat', ['name', 'dimension', 'dtype'])):
@@ -30,15 +35,16 @@ def __new__(cls, name, dimension=1, dtype="float32"):

class VarLenSparseFeat(namedtuple('VarLenFeat',
['name', 'dimension', 'maxlen', 'combiner', 'use_hash', 'dtype', 'embedding_name',
'embedding'])):
'embedding', 'group_name'])):
__slots__ = ()

def __new__(cls, name, dimension, maxlen, combiner="mean", use_hash=False, dtype="float32", embedding_name=None,
embedding=True):
embedding=True, group_name=DEFAULT_GROUP_NAME):
if embedding_name is None:
embedding_name = name
return super(VarLenSparseFeat, cls).__new__(cls, name, dimension, maxlen, combiner, use_hash, dtype,
embedding_name, embedding)

embedding_name, embedding, group_name)
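
For orientation, a small hedged sketch of declaring feature columns with the new group_name field; the feature names and sizes below are invented for illustration, and the signatures follow this diff:

from deepctr_torch.inputs import SparseFeat, VarLenSparseFeat

# Illustrative feature columns routed to different embedding groups.
user_age = SparseFeat('age', dimension=10, group_name='user')
item_id = SparseFeat('item_id', dimension=1000, group_name='item')
# A padded behavior sequence sharing the item_id embedding table, pooled later per group.
hist_items = VarLenSparseFeat('hist_item_id', dimension=1000, maxlen=20, combiner='mean', embedding_name='item_id', group_name='item')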


def get_feature_names(feature_columns):
@@ -50,6 +56,9 @@ def get_inputs_list(inputs):


def build_input_features(feature_columns):

# Return OrderedDict: {feature_name:(start, start+dimension)}

features = OrderedDict()

start = 0
@@ -92,4 +101,70 @@ def combined_dnn_input(sparse_embedding_list, dense_value_list):
elif len(dense_value_list) > 0:
return torch.flatten(torch.cat(dense_value_list, dim=-1), start_dim=1)
else:
raise NotImplementedError
raise NotImplementedError


def embedding_lookup(sparse_embedding_dict, sparse_input_dict, sparse_feature_columns, return_feat_list=(),
mask_feat_list=(), to_list=False):
"""
Args:
sparse_embedding_dict: nn.ModuleDict, {embedding_name: nn.Embedding}
sparse_input_dict: OrderedDict, {feature_name:(start, start+dimension)}
sparse_feature_columns: list, sparse features
return_feat_list: list, names of features to be returned; default () -> return all features
mask_feat_list: list, names of features to be masked in the hash transform
Return:
group_embedding_dict: defaultdict(list)
"""
group_embedding_dict = defaultdict(list)
for fc in sparse_feature_columns:
feature_name = fc.name
embedding_name = fc.embedding_name
if (len(return_feat_list) == 0 or feature_name in return_feat_list):
if fc.use_hash:
# lookup_idx = Hash(fc.vocabulary_size, mask_zero=(feature_name in mask_feat_list))(
# sparse_input_dict[feature_name])
# TODO: add hash function
lookup_idx = sparse_input_dict[feature_name]
else:
lookup_idx = sparse_input_dict[feature_name]

group_embedding_dict[fc.group_name].append(sparse_embedding_dict[embedding_name](lookup_idx))
if to_list:
return list(chain.from_iterable(group_embedding_dict.values()))
return group_embedding_dict
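
A minimal, self-contained sketch of the grouping behaviour (toy names; assumes torch, torch.nn as nn, and the definitions above are in scope, and that the input dict already holds index tensors rather than (start, end) tuples): each selected feature is looked up in its nn.Embedding and appended to the list for its group_name.

emb_dict = nn.ModuleDict({'item_id': nn.Embedding(1000, 8)})
idx = {'item_id': torch.LongTensor([[3], [7]])}            # the dict maps feature name directly to an index tensor
cols = [SparseFeat('item_id', 1000, group_name='item')]
grouped = embedding_lookup(emb_dict, idx, cols)             # {'item': [tensor of shape (2, 1, 8)]}
flat = embedding_lookup(emb_dict, idx, cols, to_list=True)  # flat list of embeddings across all groups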


def varlen_embedding_lookup(embedding_dict, sequence_input_dict, varlen_sparse_feature_columns):
varlen_embedding_vec_dict = {}
for fc in varlen_sparse_feature_columns:
feature_name = fc.name
embedding_name = fc.embedding_name
if fc.use_hash:
# lookup_idx = Hash(fc.vocabulary_size, mask_zero=True)(sequence_input_dict[feature_name])
# TODO: add hash function
lookup_idx = sequence_input_dict[feature_name]
else:
lookup_idx = sequence_input_dict[feature_name]
varlen_embedding_vec_dict[feature_name] = embedding_dict[embedding_name](lookup_idx)
return varlen_embedding_vec_dict


def get_varlen_pooling_list(embedding_dict, features, varlen_sparse_feature_columns, to_list=False):
pooling_vec_list = defaultdict(list)
for fc in varlen_sparse_feature_columns:
feature_name = fc.name
combiner = fc.combiner
feature_length_name = fc.length_name
if feature_length_name is not None:
seq_input = embedding_dict[feature_name]
vec = SequencePoolingLayer(combiner)([seq_input, features[feature_length_name]])
else:
seq_input = embedding_dict[feature_name]
vec = SequencePoolingLayer(combiner)(seq_input)
pooling_vec_list[fc.group_name].append(vec)

if to_list:
return chain.from_iterable(pooling_vec_list.values())

return pooling_vec_list
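
A matching sketch for the variable-length path (again toy names, same assumptions): varlen_embedding_lookup fetches per-step embeddings, and get_varlen_pooling_list then reduces the time axis with each column's combiner via SequencePoolingLayer, grouping the pooled vectors by group_name. Note that get_varlen_pooling_list reads fc.length_name, so the columns passed to it are assumed to carry that attribute (None when masking is used), which the VarLenSparseFeat shown earlier in this diff does not yet define.

seq_emb = nn.ModuleDict({'hist_item_id': nn.Embedding(1000, 8)})
seq_idx = {'hist_item_id': torch.LongTensor([[3, 7, 0, 0]])}   # one padded history, maxlen=4
cols = [VarLenSparseFeat('hist_item_id', 1000, maxlen=4, combiner='mean')]
vec_dict = varlen_embedding_lookup(seq_emb, seq_idx, cols)      # {'hist_item_id': tensor of shape (1, 4, 8)}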
92 changes: 92 additions & 0 deletions deepctr_torch/layers/activation.py
@@ -0,0 +1,92 @@
# -*- coding:utf-8 -*-
"""
Author:
Yuef Zhang
"""
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F


device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

class Dice(nn.Module):
"""The Data Adaptive Activation Function in DIN,which can be viewed as a generalization of PReLu and can adaptively adjust the rectified point according to distribution of input data.
Input shape:
- 2 dims: [batch_size, embedding_size(features)]
- 3 dims: [batch_size, num_features, embedding_size(features)]
Output shape:
- Same shape as the input.
References
- [Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]//Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. ACM, 2018: 1059-1068.](https://arxiv.org/pdf/1706.06978.pdf)
- https://github.com/zhougr1993/DeepInterestNetwork, https://github.com/fanoping/DIN-pytorch
"""
def __init__(self, num_features, dim=2, epsilon=1e-9):
super(Dice, self).__init__()
assert dim == 2 or dim == 3
self.bn = nn.BatchNorm1d(num_features, eps=epsilon)
self.sigmoid = nn.Sigmoid()
self.dim = dim

if self.dim == 2:
self.alpha = torch.zeros((num_features,)).to(device)
else:
self.alpha = torch.zeros((num_features, 1)).to(device)

def forward(self, x):
# x shape: 2-D (batch_size, num_features) or 3-D (batch_size, num_features, embedding_size)
assert x.dim() == 2 or x.dim() == 3

if self.dim == 2:
x_p = self.sigmoid(self.bn(x))
out = self.alpha * (1 - x_p) * x + x_p * x
else:
x = torch.transpose(x, 1, 2)
x_p = self.sigmoid(self.bn(x))
out = self.alpha * (1 - x_p) * x + x_p * x
out = torch.transpose(out, 1, 2)

return out


def activation_layer(act_name, hidden_size=None, dice_dim=2):
"""Construct activation layers
Args:
act_name: str or nn.Module subclass, name of the activation function or the activation module class itself
hidden_size: int, used for Dice activation
dice_dim: int, used for Dice activation
Return:
act_layer: activation layer
"""
if isinstance(act_name, str):
if act_name.lower() in ('relu', 'linear'):
act_layer = nn.ReLU(inplace=True)
elif act_name.lower() == 'dice':
assert dice_dim
act_layer = Dice(hidden_size, dice_dim)
elif act_name.lower() == 'prelu':
act_layer = nn.PReLU()
elif issubclass(act_name, nn.Module):
act_layer = act_name()
else:
raise NotImplementedError

return act_layer


if __name__ == "__main__":
torch.manual_seed(7)
a = Dice(3).to(device)
b = torch.rand((5, 3)).to(device)  # keep the demo input on the same device as Dice's alpha
c = a(b)
print(c.size())
print('b:', b)
print('c:', c)
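
A few extra lines one could append to this demo to exercise activation_layer as well (a hedged sketch, same assumptions as the file above):

relu = activation_layer('relu')                             # nn.ReLU
prelu = activation_layer('prelu')                           # nn.PReLU
dice = activation_layer('dice', hidden_size=3, dice_dim=2).to(device)
x = torch.rand(5, 3).to(device)                             # keep inputs on the same device as Dice's alpha
print(dice(x).size())                                       # torch.Size([5, 3]); shape is preserved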
72 changes: 68 additions & 4 deletions deepctr_torch/layers/core.py
@@ -4,6 +4,67 @@
import torch.nn as nn
import torch.nn.functional as F

from .activation import activation_layer


class LocalActivationUnit(nn.Module):
"""The LocalActivationUnit used in DIN with which the representation of
user interests varies adaptively given different candidate items.
Input shape
- A list of two 3D tensors with shape: ``(batch_size, 1, embedding_size)`` and ``(batch_size, T, embedding_size)``
Output shape
- 3D tensor with shape: ``(batch_size, T, 1)``.
Arguments
- **hidden_units**: list of positive integers, the layer number and units in each layer of the attention net.
- **activation**: Activation function to use in attention net.
- **l2_reg**: float between 0 and 1. L2 regularizer strength applied to the kernel weights matrix of attention net.
- **dropout_rate**: float in [0,1). Fraction of the units to dropout in attention net.
- **use_bn**: bool. Whether to use BatchNormalization before activation in the attention net.
- **seed**: A Python integer to use as random seed.
References
- [Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]//Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. ACM, 2018: 1059-1068.](https://arxiv.org/pdf/1706.06978.pdf)
"""

def __init__(self, hidden_units=[80, 40], embedding_dim=4, activation='Dice', dropout_rate=0, use_bn=False):
super(LocalActivationUnit, self).__init__()

self.dnn1 = DNN(inputs_dim=4*embedding_dim,
hidden_units=hidden_units,
activation=activation,
dropout_rate=dropout_rate,
use_bn=use_bn,
dice_dim=3)

# self.dnn2 = DNN(inputs_dim=hidden_units[-1],
# hidden_units=[1],
# activation=activation,
# use_bn=use_bn,
# dice_dim=3)

self.dense = nn.Linear(hidden_units[-1], 1)

def forward(self, query, user_behavior):
# query ad : size -> batch_size * 1 * embedding_size
# user behavior : size -> batch_size * time_seq_len * embedding_size

user_behavior_len = user_behavior.size(1)
queries = torch.cat([query for _ in range(user_behavior_len)], dim=1)

attention_input = torch.cat([queries, user_behavior, queries-user_behavior, queries*user_behavior], dim=-1)
attention_output = self.dnn1(attention_input)
attention_output = self.dense(attention_output)

return attention_output
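
A shape-level sketch of how this unit might be exercised (toy sizes, not part of the diff):

dev = 'cuda:0' if torch.cuda.is_available() else 'cpu'   # mirrors the device choice in activation.py
lau = LocalActivationUnit(hidden_units=[80, 40], embedding_dim=4).to(dev)
query = torch.rand(2, 1, 4).to(dev)        # candidate item embedding:  (batch, 1, emb)
behaviors = torch.rand(2, 5, 4).to(dev)    # user behavior sequence:    (batch, T=5, emb)
scores = lau(query, behaviors)             # attention logits, shape (2, 5, 1)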


class DNN(nn.Module):
"""The Multi Layer Percetron
@@ -30,10 +91,9 @@ class DNN(nn.Module):
- **seed**: A Python integer to use as random seed.
"""

def __init__(self, inputs_dim, hidden_units, activation=F.relu, l2_reg=0, dropout_rate=0, use_bn=False,
init_std=0.0001, seed=1024, device='cpu'):
def __init__(self, inputs_dim, hidden_units, activation='relu', l2_reg=0, dropout_rate=0, use_bn=False,
init_std=0.0001, dice_dim=3, seed=1024, device='cpu'):
super(DNN, self).__init__()
self.activation = activation
self.dropout_rate = dropout_rate
self.dropout = nn.Dropout(dropout_rate)
self.seed = seed
@@ -49,6 +109,10 @@ def __init__(self, inputs_dim, hidden_units, activation=F.relu, l2_reg=0, dropou
if self.use_bn:
self.bn = nn.ModuleList(
[nn.BatchNorm1d(hidden_units[i + 1]) for i in range(len(hidden_units) - 1)])

self.activation_layers = nn.ModuleList(
[activation_layer(activation, hidden_units[i + 1], dice_dim) for i in range(len(hidden_units) - 1)])

for name, tensor in self.linears.named_parameters():
if 'weight' in name:
nn.init.normal_(tensor, mean=0, std=init_std)
@@ -65,7 +129,7 @@ def forward(self, inputs):
if self.use_bn:
fc = self.bn[i](fc)

fc = self.activation(fc)
fc = self.activation_layers[i](fc)

fc = self.dropout(fc)
deep_input = fc
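
For completeness, a hedged sketch of driving the updated DNN with a string activation (toy sizes; the rest of the class body is collapsed in the diff above):

dnn = DNN(inputs_dim=16, hidden_units=[64, 32], activation='relu', dropout_rate=0.2)
x = torch.rand(8, 16)
h = dnn(x)        # -> torch.Size([8, 32]); one activation layer per hidden layer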
8 changes: 4 additions & 4 deletions deepctr_torch/layers/interaction.py
@@ -6,7 +6,7 @@

from ..layers.core import Conv2dSame
from ..layers.sequence import KMaxPooling

from ..layers.activation import activation_layer

class FM(nn.Module):
"""Factorization Machine models pairwise (order-2) feature interactions
@@ -163,14 +163,14 @@ class CIN(nn.Module):
Arguments
- **field_size** : Positive integer, number of feature groups.
- **layer_size** : list of int. Feature maps in each layer.
- **activation** : activation function used on feature maps.
- **activation** : activation function name used on feature maps.
- **split_half** : bool. If set to False, half of the feature maps in each hidden layer will connect to the output unit.
- **seed** : A Python integer to use as random seed.
References
- [Lian J, Zhou X, Zhang F, et al. xDeepFM: Combining Explicit and Implicit Feature Interactions for Recommender Systems[J]. arXiv preprint arXiv:1803.05170, 2018.] (https://arxiv.org/pdf/1803.05170.pdf)
"""

def __init__(self, field_size, layer_size=(128, 128), activation=F.relu, split_half=True, l2_reg=1e-5, seed=1024,
def __init__(self, field_size, layer_size=(128, 128), activation='relu', split_half=True, l2_reg=1e-5, seed=1024,
device='cpu'):
super(CIN, self).__init__()
if len(layer_size) == 0:
@@ -180,7 +180,7 @@ def __init__(self, field_size, layer_size=(128, 128), activation=F.relu, split_h
self.layer_size = layer_size
self.field_nums = [field_size]
self.split_half = split_half
self.activation = activation
self.activation = activation_layer(activation)
self.l2_reg = l2_reg
self.seed = seed

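
Similarly for CIN, which now accepts an activation name and builds the layer through activation_layer; a rough usage sketch (toy sizes, and the remainder of the class body is collapsed above):

cin = CIN(field_size=8, layer_size=(128, 128), activation='relu')
x = torch.rand(4, 8, 16)    # (batch, field_size, embedding_size)
y = cin(x)                  # 2-D output, sum-pooled over the embedding axis as in xDeepFM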
