Dev zsx: Fix bugs. (#181)
* Fix bugs: #74, #171, #176, #179, #180

* Fix early-stopping bugs

* Update multi-head self-attention in AutoInt

* Update dims in DIFM
zanshuxun authored Jun 13, 2021
1 parent 342c9a3 commit b0296ea
Showing 7 changed files with 62 additions and 96 deletions.
38 changes: 6 additions & 32 deletions deepctr_torch/inputs.py
@@ -57,6 +57,10 @@ def vocabulary_size(self):
def embedding_dim(self):
return self.sparsefeat.embedding_dim

@property
def use_hash(self):
return self.sparsefeat.use_hash

@property
def dtype(self):
return self.sparsefeat.dtype
@@ -136,18 +140,15 @@ def combined_dnn_input(sparse_embedding_list, dense_value_list):

def get_varlen_pooling_list(embedding_dict, features, feature_index, varlen_sparse_feature_columns, device):
varlen_sparse_embedding_list = []

for feat in varlen_sparse_feature_columns:
seq_emb = embedding_dict[feat.embedding_name](
features[:, feature_index[feat.name][0]:feature_index[feat.name][1]].long())
seq_emb = embedding_dict[feat.name]
if feat.length_name is None:
seq_mask = features[:, feature_index[feat.name][0]:feature_index[feat.name][1]].long() != 0

emb = SequencePoolingLayer(mode=feat.combiner, supports_masking=True, device=device)(
[seq_emb, seq_mask])
else:
seq_length = features[:, feature_index[feat.length_name][0]:feature_index[feat.length_name][1]].long()
emb = SequencePoolingLayer(mode=feat.combiner, supports_masking=False, device=device)(
[seq_emb, seq_length])
varlen_sparse_embedding_list.append(emb)
@@ -179,33 +180,6 @@ def create_embedding_matrix(feature_columns, init_std=0.0001, linear=False, spar
return embedding_dict.to(device)


def input_from_feature_columns(self, X, feature_columns, embedding_dict, support_dense=True):
sparse_feature_columns = list(
filter(lambda x: isinstance(x, SparseFeat), feature_columns)) if len(feature_columns) else []
dense_feature_columns = list(
filter(lambda x: isinstance(x, DenseFeat), feature_columns)) if len(feature_columns) else []

varlen_sparse_feature_columns = list(
filter(lambda x: isinstance(x, VarLenSparseFeat), feature_columns)) if feature_columns else []

if not support_dense and len(dense_feature_columns) > 0:
raise ValueError(
"DenseFeat is not supported in dnn_feature_columns")

sparse_embedding_list = [embedding_dict[feat.embedding_name](
X[:, self.feature_index[feat.name][0]:self.feature_index[feat.name][1]].long()) for
feat in sparse_feature_columns]

varlen_sparse_embedding_list = get_varlen_pooling_list(self.embedding_dict, X, self.feature_index,
varlen_sparse_feature_columns, self.device)

dense_value_list = [X[:, self.feature_index[feat.name][0]:self.feature_index[feat.name][1]] for feat in
dense_feature_columns]

return sparse_embedding_list + varlen_sparse_embedding_list, dense_value_list



def embedding_lookup(X, sparse_embedding_dict, sparse_input_dict, sparse_feature_columns, return_feat_list=(),
mask_feat_list=(), to_list=False):
"""
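For readers following the `get_varlen_pooling_list` change above: the pooling helper no longer performs the embedding lookup itself, so callers now look the sequence embeddings up first (via `varlen_embedding_lookup`) and pass the resulting dict in. A minimal sketch of the new two-step flow — the wrapper function name is mine, and `model` stands for any `BaseModel`-like object exposing `embedding_dict`, `feature_index`, and `device`:

```python
from deepctr_torch.inputs import varlen_embedding_lookup, get_varlen_pooling_list

def pooled_varlen_embeddings(model, X, varlen_sparse_feature_columns):
    # Step 1: look up the raw sequence embeddings for each VarLenSparseFeat.
    sequence_embed_dict = varlen_embedding_lookup(X, model.embedding_dict,
                                                  model.feature_index,
                                                  varlen_sparse_feature_columns)
    # Step 2: pool them (per feat.combiner) into fixed-size vectors.
    return get_varlen_pooling_list(sequence_embed_dict, X, model.feature_index,
                                   varlen_sparse_feature_columns, model.device)
```

This mirrors the calls added to `basemodel.py` further down in this commit.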
52 changes: 20 additions & 32 deletions deepctr_torch/layers/interaction.py
@@ -330,41 +330,34 @@ class InteractingLayer(nn.Module):
Input shape
- A 3D tensor with shape: ``(batch_size,field_size,embedding_size)``.
Output shape
- 3D tensor with shape:``(batch_size,field_size,att_embedding_size * head_num)``.
- 3D tensor with shape:``(batch_size,field_size,embedding_size)``.
Arguments
- **in_features** : Positive integer, dimensionality of input features.
- **att_embedding_size**: int.The embedding size in multi-head self-attention network.
- **head_num**: int.The head number in multi-head self-attention network.
- **use_res**: bool.Whether or not use standard residual connections before output.
- **seed**: A Python integer to use as random seed.
References
- [Song W, Shi C, Xiao Z, et al. AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks[J]. arXiv preprint arXiv:1810.11921, 2018.](https://arxiv.org/abs/1810.11921)
"""

def __init__(self, in_features, att_embedding_size=8, head_num=2, use_res=True, scaling=False, seed=1024, device='cpu'):
def __init__(self, embedding_size, head_num=2, use_res=True, scaling=False, seed=1024, device='cpu'):
super(InteractingLayer, self).__init__()
if head_num <= 0:
raise ValueError('head_num must be a int > 0')
self.att_embedding_size = att_embedding_size
if embedding_size % head_num != 0:
raise ValueError('embedding_size is not an integer multiple of head_num!')
self.att_embedding_size = embedding_size // head_num
self.head_num = head_num
self.use_res = use_res
self.scaling = scaling
self.seed = seed

embedding_size = in_features

self.W_Query = nn.Parameter(torch.Tensor(
embedding_size, self.att_embedding_size * self.head_num))

self.W_key = nn.Parameter(torch.Tensor(
embedding_size, self.att_embedding_size * self.head_num))

self.W_Value = nn.Parameter(torch.Tensor(
embedding_size, self.att_embedding_size * self.head_num))
self.W_Query = nn.Parameter(torch.Tensor(embedding_size, embedding_size))
self.W_key = nn.Parameter(torch.Tensor(embedding_size, embedding_size))
self.W_Value = nn.Parameter(torch.Tensor(embedding_size, embedding_size))

if self.use_res:
self.W_Res = nn.Parameter(torch.Tensor(
embedding_size, self.att_embedding_size * self.head_num))
self.W_Res = nn.Parameter(torch.Tensor(embedding_size, embedding_size))
for tensor in self.parameters():
nn.init.normal_(tensor, mean=0.0, std=0.05)

@@ -376,29 +369,24 @@ def forward(self, inputs):
raise ValueError(
"Unexpected inputs dimensions %d, expect to be 3 dimensions" % (len(inputs.shape)))

querys = torch.tensordot(inputs, self.W_Query,
dims=([-1], [0])) # None F D*head_num
# None F D
querys = torch.tensordot(inputs, self.W_Query, dims=([-1], [0]))
keys = torch.tensordot(inputs, self.W_key, dims=([-1], [0]))
values = torch.tensordot(inputs, self.W_Value, dims=([-1], [0]))

# head_num None F D

querys = torch.stack(torch.split(
querys, self.att_embedding_size, dim=2))
# head_num None F D/head_num
querys = torch.stack(torch.split(querys, self.att_embedding_size, dim=2))
keys = torch.stack(torch.split(keys, self.att_embedding_size, dim=2))
values = torch.stack(torch.split(values, self.att_embedding_size, dim=2))

inner_product = torch.einsum('bnik,bnjk->bnij', querys, keys) # head_num None F F
if self.scaling:
inner_product /= self.att_embedding_size ** 0.5
self.normalized_att_scores = F.softmax(
inner_product, dim=-1) # head_num None F F
result = torch.matmul(self.normalized_att_scores,
values) # head_num None F D
self.normalized_att_scores = F.softmax(inner_product, dim=-1) # head_num None F F
result = torch.matmul(self.normalized_att_scores, values) # head_num None F D/head_num

result = torch.cat(torch.split(result, 1, ), dim=-1)
result = torch.squeeze(result, dim=0) # None F D*head_num
result = torch.squeeze(result, dim=0) # None F D
if self.use_res:
result += torch.tensordot(inputs, self.W_Res, dims=([-1], [0]))
result = F.relu(result)
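A quick shape check of the revised layer (a sketch, not part of the commit): the per-head width is now derived as `embedding_size // head_num`, so `embedding_size` must be divisible by `head_num`, and the output keeps the input embedding width instead of `att_embedding_size * head_num`.

```python
import torch
from deepctr_torch.layers.interaction import InteractingLayer

layer = InteractingLayer(embedding_size=16, head_num=2, use_res=True, device='cpu')
x = torch.randn(4, 5, 16)   # (batch_size, field_size, embedding_size)
out = layer(x)
print(out.shape)            # torch.Size([4, 5, 16]) -- same width as the input

# InteractingLayer(embedding_size=16, head_num=3) would raise ValueError,
# since 16 is not an integer multiple of 3.
```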
17 changes: 7 additions & 10 deletions deepctr_torch/models/autoint.py
@@ -19,7 +19,6 @@ class AutoInt(BaseModel):
:param linear_feature_columns: An iterable containing all the features used by linear part of the model.
:param dnn_feature_columns: An iterable containing all the features used by deep part of the model.
:param att_layer_num: int.The InteractingLayer number to be used.
:param att_embedding_size: int.The embedding size in multi-head self-attention network.
:param att_head_num: int.The head number in multi-head self-attention network.
:param att_res: bool.Whether or not use standard residual connections before output.
:param dnn_hidden_units: list,list of positive integer or empty list, the layer number and units in each layer of DNN
@@ -37,28 +36,27 @@
"""

def __init__(self, linear_feature_columns, dnn_feature_columns, att_layer_num=3, att_embedding_size=8, att_head_num=2,
att_res=True,
dnn_hidden_units=(256, 128), dnn_activation='relu',
def __init__(self, linear_feature_columns, dnn_feature_columns, att_layer_num=3,
att_head_num=2, att_res=True, dnn_hidden_units=(256, 128), dnn_activation='relu',
l2_reg_dnn=0, l2_reg_embedding=1e-5, dnn_use_bn=False, dnn_dropout=0, init_std=0.0001, seed=1024,
task='binary', device='cpu', gpus=None):

super(AutoInt, self).__init__(linear_feature_columns, dnn_feature_columns, l2_reg_linear=0,
l2_reg_embedding=l2_reg_embedding, init_std=init_std, seed=seed, task=task,
device=device, gpus=gpus)

if len(dnn_hidden_units) <= 0 and att_layer_num <= 0:
raise ValueError("Either hidden_layer or att_layer_num must > 0")
self.use_dnn = len(dnn_feature_columns) > 0 and len(dnn_hidden_units) > 0
field_num = len(self.embedding_dict)

embedding_size = self.embedding_size

if len(dnn_hidden_units) and att_layer_num > 0:
dnn_linear_in_feature = dnn_hidden_units[-1] + \
field_num * att_embedding_size * att_head_num
dnn_linear_in_feature = dnn_hidden_units[-1] + field_num * embedding_size
elif len(dnn_hidden_units) > 0:
dnn_linear_in_feature = dnn_hidden_units[-1]
elif att_layer_num > 0:
dnn_linear_in_feature = field_num * att_embedding_size * att_head_num
dnn_linear_in_feature = field_num * embedding_size
else:
raise NotImplementedError

@@ -72,8 +70,7 @@ def __init__(self, linear_feature_columns, dnn_feature_columns, att_layer_num=3,
self.add_regularization_weight(
filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2=l2_reg_dnn)
self.int_layers = nn.ModuleList(
[InteractingLayer(self.embedding_size if i == 0 else att_embedding_size * att_head_num,
att_embedding_size, att_head_num, att_res, device=device) for i in range(att_layer_num)])
[InteractingLayer(embedding_size, att_head_num, att_res, device=device) for _ in range(att_layer_num)])

self.to(device)

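Constructor usage after this change, as a hedged sketch with made-up feature columns: `att_embedding_size` is gone, and each head now gets `embedding_dim // att_head_num` dimensions, so the shared `embedding_dim` of the feature columns should be divisible by `att_head_num`.

```python
from deepctr_torch.inputs import SparseFeat
from deepctr_torch.models import AutoInt

# Toy feature columns; names and sizes are illustrative only.
feature_columns = [SparseFeat('user_id', vocabulary_size=100, embedding_dim=4),
                   SparseFeat('item_id', vocabulary_size=200, embedding_dim=4)]

model = AutoInt(linear_feature_columns=feature_columns,
                dnn_feature_columns=feature_columns,
                att_layer_num=3, att_head_num=2, att_res=True,  # 4 % 2 == 0
                dnn_hidden_units=(256, 128), task='binary', device='cpu')
```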
23 changes: 16 additions & 7 deletions deepctr_torch/models/basemodel.py
@@ -24,7 +24,7 @@
from tensorflow.python.keras._impl.keras.callbacks import CallbackList

from ..inputs import build_input_features, SparseFeat, DenseFeat, VarLenSparseFeat, get_varlen_pooling_list, \
create_embedding_matrix
create_embedding_matrix, varlen_embedding_lookup
from ..layers import PredictionLayer
from ..layers.utils import slice_arrays
from ..callbacks import History
@@ -68,7 +68,9 @@ def forward(self, X, sparse_feat_refine_weight=None):
dense_value_list = [X[:, self.feature_index[feat.name][0]:self.feature_index[feat.name][1]] for feat in
self.dense_feature_columns]

varlen_embedding_list = get_varlen_pooling_list(self.embedding_dict, X, self.feature_index,
sequence_embed_dict = varlen_embedding_lookup(X, self.embedding_dict, self.feature_index,
self.varlen_sparse_feature_columns)
varlen_embedding_list = get_varlen_pooling_list(sequence_embed_dict, X, self.feature_index,
self.varlen_sparse_feature_columns, self.device)

sparse_embedding_list += varlen_embedding_list
@@ -126,9 +128,9 @@ def __init__(self, linear_feature_columns, dnn_feature_columns, l2_reg_linear=1e
self.out = PredictionLayer(task, )
self.to(device)

# parameters of callbacks
self._is_graph_network = True # used for ModelCheckpoint
self.stop_training = False # used for EarlyStopping
# parameters for callbacks
self._is_graph_network = True # used for ModelCheckpoint in tf2
self._ckpt_saved_epoch = False # used for EarlyStopping in tf1.14
self.history = History()

def fit(self, x=None, y=None, batch_size=None, epochs=1, verbose=1, initial_epoch=0, validation_split=0.,
@@ -216,9 +218,10 @@ def fit(self, x=None, y=None, batch_size=None, epochs=1, verbose=1, initial_epoc
# configure callbacks
callbacks = (callbacks or []) + [self.history] # add history callback
callbacks = CallbackList(callbacks)
callbacks.set_model(self)
callbacks.on_train_begin()
callbacks.set_model(self)
if not hasattr(callbacks, 'model'):
if not hasattr(callbacks, 'model'): # for tf1.4
callbacks.__setattr__('model', self)
callbacks.model.stop_training = False

@@ -359,7 +362,9 @@ def input_from_feature_columns(self, X, feature_columns, embedding_dict, support
X[:, self.feature_index[feat.name][0]:self.feature_index[feat.name][1]].long()) for
feat in sparse_feature_columns]

varlen_sparse_embedding_list = get_varlen_pooling_list(self.embedding_dict, X, self.feature_index,
sequence_embed_dict = varlen_embedding_lookup(X, self.embedding_dict, self.feature_index,
varlen_sparse_feature_columns)
varlen_sparse_embedding_list = get_varlen_pooling_list(sequence_embed_dict, X, self.feature_index,
varlen_sparse_feature_columns, self.device)

dense_value_list = [X[:, self.feature_index[feat.name][0]:self.feature_index[feat.name][1]] for feat in
@@ -490,6 +495,10 @@ def _get_metrics(self, metrics, set_eps=False):
self.metrics_names.append(metric)
return metrics_

def _in_multi_worker_mode(self):
# used for EarlyStopping in tf1.15
return None

@property
def embedding_size(self, ):
feature_columns = self.dnn_feature_columns
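The `basemodel.py` changes above are mostly about letting the tf.keras callbacks (`EarlyStopping`, `ModelCheckpoint`) drive a plain PyTorch training loop. An illustrative stand-in (not library code) summarizing the attributes those callbacks probe, as set up in `__init__` and `fit`:

```python
class CallbackCompatibleModel:
    """Minimal stand-in for the compatibility shims BaseModel now exposes."""

    def __init__(self):
        self._is_graph_network = True    # read by ModelCheckpoint (tf2)
        self._ckpt_saved_epoch = False   # read by EarlyStopping (tf1.14)
        self.stop_training = False       # EarlyStopping flips this to end fit()

    def _in_multi_worker_mode(self):
        return None                      # read by EarlyStopping (tf1.15)
```

`fit` now also calls `callbacks.set_model(self)` before `callbacks.on_train_begin()`, so callbacks can reach the model from their very first hook.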
24 changes: 12 additions & 12 deletions deepctr_torch/models/difm.py
@@ -18,6 +18,8 @@ class DIFM(BaseModel):
:param linear_feature_columns: An iterable containing all the features used by linear part of the model.
:param dnn_feature_columns: An iterable containing all the features used by deep part of the model.
:param att_head_num: int. The head number in multi-head self-attention network.
:param att_res: bool. Whether or not use standard residual connections before output.
:param dnn_hidden_units: list,list of positive integer or empty list, the layer number and units in each layer of DNN
:param l2_reg_linear: float. L2 regularizer strength applied to linear part
:param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector
@@ -35,35 +37,33 @@
"""

def __init__(self,
linear_feature_columns, dnn_feature_columns, att_embedding_size=8, att_head_num=8,
linear_feature_columns, dnn_feature_columns, att_head_num=4,
att_res=True, dnn_hidden_units=(256, 128),
l2_reg_linear=0.00001, l2_reg_embedding=0.00001, l2_reg_dnn=0, init_std=0.0001, seed=1024,
dnn_dropout=0,
dnn_activation='relu', dnn_use_bn=False, task='binary', device='cpu', gpus=None):
super(DIFM, self).__init__(linear_feature_columns, dnn_feature_columns, l2_reg_linear=l2_reg_linear,
l2_reg_embedding=l2_reg_embedding, init_std=init_std, seed=seed, task=task,
device=device, gpus=gpus)

if not len(dnn_hidden_units) > 0:
raise ValueError("dnn_hidden_units is null!")

self.use_dnn = len(dnn_feature_columns) > 0 and len(
dnn_hidden_units) > 0
self.fm = FM()

# InteractingLayer (used in AutoInt) = multi-head self-attention + Residual Network
self.vector_wise_net = InteractingLayer(self.embedding_size, att_embedding_size,
att_head_num, att_res, scaling=True, device=device)
self.vector_wise_net = InteractingLayer(self.embedding_size, att_head_num,
att_res, scaling=True, device=device)

self.bit_wise_net = DNN(self.compute_input_dim(dnn_feature_columns, include_dense=False),
dnn_hidden_units, activation=dnn_activation, l2_reg=l2_reg_dnn,
dropout_rate=dnn_dropout,
use_bn=dnn_use_bn, init_std=init_std, device=device)
self.sparse_feat_num = len(list(filter(lambda x: isinstance(x, SparseFeat) or isinstance(x, VarLenSparseFeat),
dnn_feature_columns)))

self.transform_matrix_P_vec = nn.Linear(
self.sparse_feat_num*att_embedding_size*att_head_num, self.sparse_feat_num, bias=False).to(device)
self.sparse_feat_num * self.embedding_size, self.sparse_feat_num, bias=False).to(device)
self.transform_matrix_P_bit = nn.Linear(
dnn_hidden_units[-1], self.sparse_feat_num, bias=False).to(device)

@@ -80,7 +80,7 @@ def __init__(self,

def forward(self, X):
sparse_embedding_list, _ = self.input_from_feature_columns(X, self.dnn_feature_columns,
self.embedding_dict)
if not len(sparse_embedding_list) > 0:
raise ValueError("there are no sparse features")

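Constructor usage after the DIFM dims update, again as a hedged sketch with toy feature columns: the vector-wise attention now reuses the feature embedding width, so `transform_matrix_P_vec` takes `sparse_feat_num * embedding_size` inputs and `embedding_dim` should be divisible by `att_head_num`.

```python
from deepctr_torch.inputs import SparseFeat
from deepctr_torch.models import DIFM

# Toy feature columns; names and sizes are illustrative only.
feature_columns = [SparseFeat('user_id', vocabulary_size=100, embedding_dim=8),
                   SparseFeat('item_id', vocabulary_size=200, embedding_dim=8)]

model = DIFM(feature_columns, feature_columns,
             att_head_num=4, att_res=True,            # 8 % 4 == 0
             dnn_hidden_units=(256, 128), task='binary', device='cpu')
```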
2 changes: 0 additions & 2 deletions deepctr_torch/models/ifm.py
@@ -47,8 +47,6 @@ def __init__(self,
if not len(dnn_hidden_units) > 0:
raise ValueError("dnn_hidden_units is null!")

self.use_dnn = len(dnn_feature_columns) > 0 and len(
dnn_hidden_units) > 0
self.fm = FM()

self.factor_estimating_net = DNN(self.compute_input_dim(dnn_feature_columns, include_dense=False),
2 changes: 1 addition & 1 deletion docs/source/FAQ.md
@@ -33,7 +33,7 @@ model = DeepFM(linear_feature_columns,dnn_feature_columns)
model.compile(Adagrad(model.parameters(),0.1024),'binary_crossentropy',metrics=['binary_crossentropy'])

es = EarlyStopping(monitor='val_binary_crossentropy', min_delta=0, verbose=1, patience=0, mode='min')
mdckpt = ModelCheckpoint(filepath = 'model.ckpt', save_best_only= True)
mdckpt = ModelCheckpoint(filepath='model.ckpt', monitor='val_binary_crossentropy', verbose=1, save_best_only=True, mode='min')
history = model.fit(model_input,data[target].values,batch_size=256,epochs=10,verbose=2,validation_split=0.2,callbacks=[es,mdckpt])
print(history)
```
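A possible follow-up to the FAQ snippet above (hedged; `model.ckpt` and `model_input` come from that snippet): `ModelCheckpoint` here saves the whole module, so the best checkpoint can be reloaded with `torch.load` and used for prediction.

```python
import torch

best_model = torch.load('model.ckpt')                    # written by ModelCheckpoint above
pred_ans = best_model.predict(model_input, batch_size=256)
```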
