-
Notifications
You must be signed in to change notification settings - Fork 714
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #49 from shenweichen/final-ccpm
add ccpm
- Loading branch information
Showing
7 changed files
with
192 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from .interaction import * | ||
from .core import * | ||
from .utils import concat_fun | ||
from .sequence import KMaxPooling |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import torch.nn as nn | ||
import torch | ||
class KMaxPooling(nn.Module): | ||
"""K Max pooling that selects the k biggest value along the specific axis. | ||
Input shape | ||
- nD tensor with shape: ``(batch_size, ..., input_dim)``. | ||
Output shape | ||
- nD tensor with shape: ``(batch_size, ..., output_dim)``. | ||
Arguments | ||
- **k**: positive integer, number of top elements to look for along the ``axis`` dimension. | ||
- **axis**: positive integer, the dimension to look for elements. | ||
""" | ||
def __init__(self, k, axis, device='cpu'): | ||
super(KMaxPooling, self).__init__() | ||
self.k = k | ||
self.axis = axis | ||
self.to(device) | ||
|
||
def forward(self, input): | ||
out = torch.topk(input, k=self.k, dim=self.axis, sorted=True)[0] | ||
return out | ||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,4 +9,5 @@ | |
from .mlr import MLR | ||
from .onn import ONN | ||
from .pnn import PNN | ||
from .ccpm import CCPM | ||
NFFM = ONN |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# -*- coding:utf-8 -*- | ||
""" | ||
Author: | ||
Zeng Kai,[email protected] | ||
Reference: | ||
[1] Liu Q, Yu F, Wu S, et al. A convolutional click prediction model[C]//Proceedings of the 24th ACM International on Conference on Information and Knowledge Management. ACM, 2015: 1743-1746. | ||
(http://ir.ia.ac.cn/bitstream/173211/12337/1/A%20Convolutional%20Click%20Prediction%20Model.pdf) | ||
""" | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
from .basemodel import BaseModel | ||
from ..layers.core import DNN, Conv2dSame | ||
from ..layers.utils import concat_fun | ||
from ..layers.sequence import KMaxPooling | ||
from ..layers.interaction import ConvLayer | ||
|
||
|
||
class CCPM(BaseModel): | ||
"""Instantiates the Convolutional Click Prediction Model architecture. | ||
:param linear_feature_columns: An iterable containing all the features used by linear part of the model. | ||
:param dnn_feature_columns: An iterable containing all the features used by deep part of the model. | ||
:param embedding_size: positive integer,sparse feature embedding_size | ||
:param conv_kernel_width: list,list of positive integer or empty list,the width of filter in each conv layer. | ||
:param conv_filters: list,list of positive integer or empty list,the number of filters in each conv layer. | ||
:param dnn_hidden_units: list,list of positive integer or empty list, the layer number and units in each layer of DNN. | ||
:param l2_reg_linear: float. L2 regularizer strength applied to linear part | ||
:param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector | ||
:param l2_reg_dnn: float. L2 regularizer strength applied to DNN | ||
:param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate. | ||
:param init_std: float,to use as the initialize std of embedding vector | ||
:param seed: integer ,to use as random seed. | ||
:param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss | ||
:param device: str, ``"cpu"`` or ``"cuda:0"`` | ||
:return: A PyTorch model instance. | ||
""" | ||
|
||
def __init__(self, linear_feature_columns, dnn_feature_columns, embedding_size=8, conv_kernel_width=(6, 5), | ||
conv_filters=(4, 4), | ||
dnn_hidden_units=(256,), l2_reg_linear=1e-5, l2_reg_embedding=1e-5, l2_reg_dnn=0, dnn_dropout=0, | ||
init_std=0.0001, seed=1024, task='binary', device='cpu', dnn_use_bn=False, dnn_activation=F.relu): | ||
|
||
super(CCPM, self).__init__(linear_feature_columns, dnn_feature_columns, embedding_size=embedding_size, | ||
dnn_hidden_units=dnn_hidden_units, | ||
l2_reg_linear=l2_reg_linear, | ||
l2_reg_embedding=l2_reg_embedding, l2_reg_dnn=l2_reg_dnn, init_std=init_std, | ||
seed=seed, | ||
dnn_dropout=dnn_dropout, dnn_activation=dnn_activation, | ||
task=task, device=device) | ||
|
||
if len(conv_kernel_width) != len(conv_filters): | ||
raise ValueError( | ||
"conv_kernel_width must have same element with conv_filters") | ||
|
||
filed_size = self.compute_input_dim(dnn_feature_columns, embedding_size, include_dense=False, feature_group=True) | ||
self.conv_layer = ConvLayer(filed_size=filed_size, conv_kernel_width=conv_kernel_width, conv_filters=conv_filters, device=device) | ||
self.dnn_input_dim = 3 * embedding_size * conv_filters[-1] | ||
self.dnn = DNN(self.dnn_input_dim, dnn_hidden_units, | ||
activation=dnn_activation, l2_reg=l2_reg_dnn, dropout_rate=dnn_dropout, use_bn=dnn_use_bn, | ||
init_std=init_std, device=device) | ||
self.dnn_linear = nn.Linear(dnn_hidden_units[-1], 1, bias=False).to(device) | ||
self.add_regularization_loss( | ||
filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2_reg_dnn) | ||
self.add_regularization_loss(self.dnn_linear.weight, l2_reg_dnn) | ||
|
||
self.to(device) | ||
|
||
|
||
def forward(self, X): | ||
linear_logit = self.linear_model(X) | ||
sparse_embedding_list, _ = self.input_from_feature_columns(X, self.dnn_feature_columns, | ||
self.embedding_dict, support_dense=True) | ||
conv_input = concat_fun(sparse_embedding_list, axis=1) | ||
conv_input_concact = torch.unsqueeze(conv_input, 1) | ||
pooling_result = self.conv_layer(conv_input_concact) | ||
flatten_result = pooling_result.view(pooling_result.size(0), -1) | ||
dnn_output = self.dnn(flatten_result) | ||
dnn_logit = self.dnn_linear(dnn_output) | ||
logit = linear_logit + dnn_logit | ||
y_pred = self.out(logit) | ||
return y_pred |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters