Commit

add ccpm
浅梦 authored Sep 24, 2019
1 parent 328255d commit d4e671d
Showing 15 changed files with 130 additions and 37 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -29,6 +29,9 @@ Let's [**Get Started!**](https://deepctr-torch.readthedocs.io/en/latest/Quick-St
<a href="https://github.com/JyiHUO">
<img src="https://avatars.githubusercontent.com/JyiHUO " width=70 height="70" alt="pic" >
</a>
<a href="https://github.com/Zengai">
<img src="https://avatars.githubusercontent.com/Zengai " width=70 height="70" alt="pic" >
</a>
<a href="https://github.com/chenkkkk">
<img src="https://avatars.githubusercontent.com/chenkkkk " width=70 height="70" alt="pic" >
</a>
@@ -44,6 +47,7 @@ Let's [**Get Started!**](https://deepctr-torch.readthedocs.io/en/latest/Quick-St

| Model | Paper |
| :------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| Convolutional Click Prediction Model | [CIKM 2015][A Convolutional Click Prediction Model](http://ir.ia.ac.cn/bitstream/173211/12337/1/A%20Convolutional%20Click%20Prediction%20Model.pdf) |
| Factorization-supported Neural Network | [ECIR 2016][Deep Learning over Multi-field Categorical Data: A Case Study on User Response Prediction](https://arxiv.org/pdf/1601.02376.pdf) |
| Product-based Neural Network | [ICDM 2016][Product-based neural networks for user response prediction](https://arxiv.org/pdf/1611.00144.pdf) |
| Wide & Deep | [DLRS 2016][Wide & Deep Learning for Recommender Systems](https://arxiv.org/pdf/1606.07792.pdf) |
2 changes: 1 addition & 1 deletion deepctr_torch/__init__.py
@@ -2,5 +2,5 @@
from . import models
from .utils import check_version

__version__ = '0.1.0'
__version__ = '0.1.1'
check_version(__version__)
14 changes: 9 additions & 5 deletions deepctr_torch/layers/core.py
@@ -1,7 +1,9 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class DNN(nn.Module):
"""The Multi Layer Percetron
@@ -46,7 +48,7 @@ def __init__(self, inputs_dim, hidden_units, activation=F.relu, l2_reg=0, dropou

if self.use_bn:
self.bn = nn.ModuleList(
[nn.BatchNorm1d(hidden_units[i+1]) for i in range(len(hidden_units)-1)])
[nn.BatchNorm1d(hidden_units[i + 1]) for i in range(len(hidden_units) - 1)])
for name, tensor in self.linears.named_parameters():
if 'weight' in name:
nn.init.normal_(tensor, mean=0, std=init_std)
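For context, a minimal self-contained sketch of the pattern this hunk touches: one `BatchNorm1d` per hidden layer, sized to that layer's output width, paired with the matching `Linear` list and normal-initialised weights. `TinyMLP` and its signature are illustrative only, not the library's `DNN` class.

```python
# Illustrative sketch (not the library's DNN): how the Linear and BatchNorm1d
# ModuleLists line up one-to-one, with each BatchNorm1d sized to the output
# width of the Linear it follows.
import torch
import torch.nn as nn


class TinyMLP(nn.Module):
    def __init__(self, inputs_dim, hidden_units, use_bn=True, init_std=1e-4):
        super().__init__()
        dims = [inputs_dim] + list(hidden_units)
        self.use_bn = use_bn
        # one Linear per consecutive pair of layer sizes
        self.linears = nn.ModuleList(
            [nn.Linear(dims[i], dims[i + 1]) for i in range(len(dims) - 1)])
        if use_bn:
            # BatchNorm1d width must equal the *output* width of each Linear
            self.bn = nn.ModuleList(
                [nn.BatchNorm1d(dims[i + 1]) for i in range(len(dims) - 1)])
        for name, tensor in self.linears.named_parameters():
            if 'weight' in name:
                nn.init.normal_(tensor, mean=0, std=init_std)

    def forward(self, x):
        for i, fc in enumerate(self.linears):
            x = fc(x)
            if self.use_bn:
                x = self.bn[i](x)
            x = torch.relu(x)
        return x


print(TinyMLP(16, [32, 8])(torch.randn(4, 16)).shape)  # torch.Size([4, 8])
```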
@@ -95,9 +97,11 @@ def forward(self, X):
output = torch.sigmoid(output)
return output


class Conv2dSame(nn.Conv2d):
""" Tensorflow like 'SAME' convolution wrapper for 2D convolutions
"""

def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True):
super(Conv2dSame, self).__init__(
@@ -113,7 +117,7 @@ def forward(self, x):
pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
if pad_h > 0 or pad_w > 0:
x = F.pad(x, [pad_w//2, pad_w - pad_w//2, pad_h//2, pad_h - pad_h//2])
x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
out = F.conv2d(x, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups)
return out
self.padding, self.dilation, self.groups)
return out
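To make the padding arithmetic above concrete, here is a small standalone sketch (stride and dilation left at their defaults of 1; `same_pad_amount` is an illustrative helper, not part of the library) showing that the computed padding keeps the spatial size at ceil(input / stride):

```python
# Sketch of the 'SAME' padding arithmetic used by Conv2dSame above.
import math

import torch
import torch.nn.functional as F


def same_pad_amount(in_size, kernel, stride=1, dilation=1):
    # TensorFlow 'SAME' targets an output size of ceil(in_size / stride)
    out = math.ceil(in_size / stride)
    return max((out - 1) * stride + (kernel - 1) * dilation + 1 - in_size, 0)


x = torch.randn(1, 1, 7, 5)  # (N, C, H, W)
kh, kw = 3, 2
pad_h = same_pad_amount(x.shape[2], kh)
pad_w = same_pad_amount(x.shape[3], kw)
# any odd pixel goes to the bottom/right, as in forward() above
x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
weight = torch.randn(4, 1, kh, kw)
out = F.conv2d(x, weight)
print(out.shape)  # torch.Size([1, 4, 7, 5]): spatial size preserved at stride 1
```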
38 changes: 22 additions & 16 deletions deepctr_torch/layers/interaction.py
@@ -4,8 +4,9 @@
import torch.nn as nn
import torch.nn.functional as F

from ..layers.sequence import KMaxPooling
from ..layers.core import Conv2dSame
from ..layers.sequence import KMaxPooling


class FM(nn.Module):
"""Factorization Machine models pairwise (order-2) feature interactions
@@ -75,7 +76,7 @@ class SENETLayer(nn.Module):
Tongwen](https://arxiv.org/pdf/1905.09433.pdf)
"""

def __init__(self, filed_size, reduction_ratio=3, seed=1024, device='cpu'):
def __init__(self, filed_size, reduction_ratio=3, seed=1024, device='cpu'):
super(SENETLayer, self).__init__()
self.seed = seed
self.filed_size = filed_size
@@ -169,7 +170,8 @@ class CIN(nn.Module):
- [Lian J, Zhou X, Zhang F, et al. xDeepFM: Combining Explicit and Implicit Feature Interactions for Recommender Systems[J]. arXiv preprint arXiv:1803.05170, 2018.](https://arxiv.org/pdf/1803.05170.pdf)
"""

def __init__(self, field_size, layer_size=(128, 128), activation=F.relu, split_half=True, l2_reg=1e-5, seed=1024, device='cpu'):
def __init__(self, field_size, layer_size=(128, 128), activation=F.relu, split_half=True, l2_reg=1e-5, seed=1024,
device='cpu'):
super(CIN, self).__init__()
if len(layer_size) == 0:
raise ValueError(
@@ -196,8 +198,8 @@ def __init__(self, field_size, layer_size=(128, 128), activation=F.relu, split_h
else:
self.field_nums.append(size)

# for tensor in self.conv1ds:
# nn.init.normal_(tensor.weight, mean=0, std=init_std)
# for tensor in self.conv1ds:
# nn.init.normal_(tensor.weight, mean=0, std=init_std)
self.to(device)

def forward(self, inputs):
@@ -571,25 +573,29 @@ def forward(self, inputs):

return kp


class ConvLayer(nn.Module):
"""Conv Layer used in CCPM.This implemention is
adapted from code that the author of the paper published on http://ir.ia.ac.cn/bitstream/173211/12337/1/A%20Convolutional%20Click%20Prediction%20Model.pdf.
"""Conv Layer used in CCPM.
Input shape
- A 4D tensor with shape: ``(batch_size,1,filed_size,embedding_size)``.
Output shape
- A 4D tensor with shape: ``(batch_size,last_filters,pooling_size,embedding_size)``.
Arguments
- **field_size**: Positive integer, number of feature groups.
- **conv_kernel_width**: list. List of positive integers or empty list, the width of the filter in each conv layer.
- **conv_filters**: list. List of positive integers or empty list, the number of filters in each conv layer.
Reference:
- Liu Q, Yu F, Wu S, et al. A convolutional click prediction model[C]//Proceedings of the 24th ACM International on Conference on Information and Knowledge Management. ACM, 2015: 1743-1746.(http://ir.ia.ac.cn/bitstream/173211/12337/1/A%20Convolutional%20Click%20Prediction%20Model.pdf)
"""
def __init__(self, filed_size, conv_kernel_width, conv_filters, device='cpu'):

def __init__(self, field_size, conv_kernel_width, conv_filters, device='cpu'):
super(ConvLayer, self).__init__()
self.device = device
module_list = []
n = filed_size
n = int(field_size)
l = len(conv_filters)
filed_shape = n
for i in range(1, l + 1):
if i == 1:
in_channels = 1
@@ -599,15 +605,15 @@ def __init__(self, filed_size, conv_kernel_width, conv_filters, device='cpu'):
width = conv_kernel_width[i - 1]
k = max(1, int((1 - pow(i / l, l - i)) * n)) if i < l else 3
module_list.append(Conv2dSame(in_channels=in_channels, out_channels=out_channels, kernel_size=(width, 1),
stride=1).to(self.device))
stride=1).to(self.device))
module_list.append(torch.nn.Tanh().to(self.device))
# KMaxPooling ,extract top_k, returns two tensors [values, indices]
if i == 1:
k = min(k, n)
module_list.append(KMaxPooling(k = k, axis = 2, device = self.device).to(self.device))


# KMaxPooling, extract top_k, returns the values tensor
module_list.append(KMaxPooling(k = min(k, filed_shape), axis = 2, device = self.device).to(self.device))
filed_shape = min(k, filed_shape)
self.conv_layer = nn.Sequential(*module_list)
self.to(device)
self.filed_shape = filed_shape

def forward(self, inputs):
return self.conv_layer(inputs)
return self.conv_layer(inputs)
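The per-layer pooling width follows k = max(1, int((1 - (i/l)^(l-i)) * n)) for i < l and 3 for the last layer, clamped so no layer keeps more rows than the previous one produced. A standalone sketch of that schedule (`pooling_widths` is an illustrative helper, not library code):

```python
# Standalone sketch of ConvLayer's pooling-width schedule above.
def pooling_widths(field_size, num_conv_layers):
    n, l = int(field_size), num_conv_layers
    widths, current = [], n
    for i in range(1, l + 1):
        k = max(1, int((1 - pow(i / l, l - i)) * n)) if i < l else 3
        current = min(k, current)  # matches min(k, filed_shape) in ConvLayer
        widths.append(current)
    return widths


print(pooling_widths(26, 2))  # [13, 3] for 26 feature fields and 2 conv layers
```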
19 changes: 13 additions & 6 deletions deepctr_torch/layers/sequence.py
@@ -1,5 +1,7 @@
import torch.nn as nn
import torch
import torch.nn as nn


class KMaxPooling(nn.Module):
"""K Max pooling that selects the k biggest value along the specific axis.
@@ -15,16 +17,21 @@ class KMaxPooling(nn.Module):
- **axis**: positive integer, the dimension to look for elements.
"""

def __init__(self, k, axis, device='cpu'):
super(KMaxPooling, self).__init__()
self.k = k
self.axis = axis
self.to(device)

def forward(self, input):
out = torch.topk(input, k=self.k, dim=self.axis, sorted=True)[0]
return out

if self.axis < 0 or self.axis >= len(input.shape):
raise ValueError("axis must be 0~%d,now is %d" %
(len(input.shape)-1, self.axis))

if self.k < 1 or self.k > input.shape[self.axis]:
raise ValueError("k must be in 1 ~ %d,now k is %d" %
(input.shape[self.axis], self.k))


out = torch.topk(input, k=self.k, dim=self.axis, sorted=True)[0]
return out
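A minimal usage sketch for the layer above (the import path is assumed from the file layout in this diff; shapes are illustrative):

```python
# Minimal usage sketch for KMaxPooling: keep the k largest values along a
# chosen axis via torch.topk (values only, sorted descending).
import torch

from deepctr_torch.layers.sequence import KMaxPooling

x = torch.randn(4, 1, 26, 8)     # (batch, channels, field_size, embedding_size)
pool = KMaxPooling(k=3, axis=2)  # keep the top-3 rows along the field axis
print(pool(x).shape)             # torch.Size([4, 1, 3, 8])

# An out-of-range k or axis now raises ValueError up front instead of
# failing inside torch.topk.
```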
9 changes: 6 additions & 3 deletions deepctr_torch/models/ccpm.py
@@ -59,8 +59,9 @@ def __init__(self, linear_feature_columns, dnn_feature_columns, embedding_size=8
"conv_kernel_width must have same element with conv_filters")

filed_size = self.compute_input_dim(dnn_feature_columns, embedding_size, include_dense=False, feature_group=True)
self.conv_layer = ConvLayer(filed_size=filed_size, conv_kernel_width=conv_kernel_width, conv_filters=conv_filters, device=device)
self.dnn_input_dim = 3 * embedding_size * conv_filters[-1]
self.conv_layer = ConvLayer(field_size=filed_size, conv_kernel_width=conv_kernel_width, conv_filters=conv_filters, device=device)

self.dnn_input_dim = self.conv_layer.filed_shape * embedding_size * conv_filters[-1]
self.dnn = DNN(self.dnn_input_dim, dnn_hidden_units,
activation=dnn_activation, l2_reg=l2_reg_dnn, dropout_rate=dnn_dropout, use_bn=dnn_use_bn,
init_std=init_std, device=device)
@@ -75,7 +76,9 @@ def __init__(self, linear_feature_columns, dnn_feature_columns, embedding_size=8
def forward(self, X):
linear_logit = self.linear_model(X)
sparse_embedding_list, _ = self.input_from_feature_columns(X, self.dnn_feature_columns,
self.embedding_dict, support_dense=True)
self.embedding_dict, support_dense=False)
if len(sparse_embedding_list) == 0:
raise ValueError("must have the embedding feature,now the embedding feature is None!")
conv_input = concat_fun(sparse_embedding_list, axis=1)
conv_input_concact = torch.unsqueeze(conv_input, 1)
pooling_result = self.conv_layer(conv_input_concact)
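The new dnn_input_dim reflects that ConvLayer's output has shape (batch_size, conv_filters[-1], filed_shape, embedding_size) and is flattened before the DNN; filed_shape is typically 3 (the last layer's pooling width) but can be smaller when there are few feature fields, which is why the hard-coded 3 is dropped. A small sketch of the arithmetic, with illustrative numbers:

```python
# Illustrative check of dnn_input_dim = filed_shape * embedding_size * conv_filters[-1]
embedding_size = 8
conv_filters = (4, 3)   # hypothetical filter counts per conv layer
filed_shape = 3         # rows kept by the final KMaxPooling (ConvLayer.filed_shape)

batch = 32
pooled_shape = (batch, conv_filters[-1], filed_shape, embedding_size)
dnn_input_dim = filed_shape * embedding_size * conv_filters[-1]
assert conv_filters[-1] * filed_shape * embedding_size == dnn_input_dim
print(pooled_shape, "->", (batch, dnn_input_dim))  # (32, 3, 3, 8) -> (32, 72)
```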
11 changes: 11 additions & 0 deletions docs/source/Features.md
@@ -24,6 +24,17 @@ DNN based CTR estimation models consists of the following 4 modules:
## Models


### CCPM (Convolutional Click Prediction Model)


CCPM can extract local-global key features from an input instance with varied elements, and it can be applied not only to a single ad impression but also to sequential ad impressions.

[**CCPM Model API**](./deepctr_torch.models.ccpm.html)
![CCPM](../pics/CCPM.png)

[Liu Q, Yu F, Wu S, et al. A convolutional click prediction model[C]//Proceedings of the 24th ACM International on Conference on Information and Knowledge Management. ACM, 2015: 1743-1746.](http://ir.ia.ac.cn/bitstream/173211/12337/1/A%20Convolutional%20Click%20Prediction%20Model.pdf)
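A minimal usage sketch, modeled on tests/models/CCPM_test.py added in this commit; the helpers and hyper-parameters are the ones the test uses, and the snippet assumes it is run from the repository root so that the tests package is importable.

```python
# Sketch based on tests/models/CCPM_test.py (not an official example script)
from deepctr_torch.models import CCPM
from tests.utils import SAMPLE_SIZE, check_model, get_test_data

# synthetic data: 3 sparse feature groups, no dense features
x, y, feature_columns = get_test_data(SAMPLE_SIZE, 3, 0)

model = CCPM(feature_columns, feature_columns,
             conv_kernel_width=(3, 2), conv_filters=(2, 1),
             dnn_hidden_units=[32], dnn_dropout=0.5)
check_model(model, "CCPM", x, y)  # compile/fit sanity check used by the test suite
```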


### PNN (Product-based Neural Network)

PNN concatenates sparse feature embeddings and the product between embedding vectors as the input of MLP.
1 change: 1 addition & 0 deletions docs/source/History.md
@@ -1,2 +1,3 @@
# History
- 09/24/2019 : [v0.1.1](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.1.1) released. Add [CCPM](./Features.html#ccpm-convolutional-click-prediction-model).
- 09/22/2019 : DeepCTR-Torch first version v0.1.0 is released on [PyPi](https://pypi.org/project/deepctr-torch/)
2 changes: 1 addition & 1 deletion docs/source/conf.py
@@ -26,7 +26,7 @@
# The short X.Y version
version = ''
# The full version, including alpha/beta/rc tags
release = '0.1.0'
release = '0.1.1'


# -- General configuration ---------------------------------------------------
7 changes: 7 additions & 0 deletions docs/source/deepctr_torch.layers.sequence.rst
@@ -0,0 +1,7 @@
deepctr\_torch.layers.sequence module
========================================

.. automodule:: deepctr_torch.layers.sequence
:members:
:no-undoc-members:
:no-show-inheritance:
7 changes: 7 additions & 0 deletions docs/source/deepctr_torch.models.ccpm.rst
@@ -0,0 +1,7 @@
deepctr\_torch.models.ccpm module
========================================

.. automodule:: deepctr_torch.models.ccpm
:members:
:no-undoc-members:
:no-show-inheritance:
2 changes: 2 additions & 0 deletions docs/source/index.rst
@@ -34,6 +34,8 @@ You can read the latest code at https://github.com/shenweichen/DeepCTR-Torch and

News
-----
09/24/2019 : Add `CCPM <./Features.html#ccpm-convolutional-click-prediction-model>`_ . `Changelog <https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.1.1>`_

09/22/2019 : DeepCTR-Torch first version v0.1.0 is released on `PyPi <https://pypi.org/project/deepctr-torch/>`_ !


6 changes: 2 additions & 4 deletions examples/run_classification_criteo.py
@@ -1,6 +1,4 @@
# -*- coding: utf-8 -*-
import sys
sys.path.append('/home/SENSETIME/zengkai/final_deepCTR/DeepCTR-Torch')
import pandas as pd
from sklearn.metrics import log_loss, roc_auc_score
from sklearn.model_selection import train_test_split
@@ -11,7 +9,7 @@


if __name__ == "__main__":
data = pd.read_csv('/home/SENSETIME/zengkai/final_deepCTR/DeepCTR-Torch/examples/criteo_sample.txt')
data = pd.read_csv('./criteo_sample.txt')

sparse_features = ['C' + str(i) for i in range(1, 27)]
dense_features = ['I' + str(i) for i in range(1, 14)]
@@ -53,7 +51,7 @@
print('cuda ready...')
device = 'cuda:0'

model = CCPM(linear_feature_columns=linear_feature_columns, dnn_feature_columns=dnn_feature_columns, task='binary',
model = DeepFM(linear_feature_columns=linear_feature_columns, dnn_feature_columns=dnn_feature_columns, task='binary',
l2_reg_embedding=1e-5, device=device)

model.compile("adagrad", "binary_crossentropy",
2 changes: 1 addition & 1 deletion setup.py
@@ -9,7 +9,7 @@

setuptools.setup(
name="deepctr-torch",
version="0.1.0",
version="0.1.1",
author="Weichen Shen",
author_email="[email protected]",
description="Easy-to-use, Modular and Extendible package of deep learning based CTR (Click Through Rate) prediction models with PyTorch",
43 changes: 43 additions & 0 deletions tests/models/CCPM_test.py
@@ -0,0 +1,43 @@
import pytest

from deepctr_torch.models import CCPM
from tests.utils import check_model, get_test_data, SAMPLE_SIZE


@pytest.mark.parametrize(
'sparse_feature_num,dense_feature_num',
[ (3, 0)
]
)
def test_CCPM(sparse_feature_num, dense_feature_num):
model_name = "CCPM"

sample_size = SAMPLE_SIZE
x, y, feature_columns = get_test_data(
sample_size, sparse_feature_num, dense_feature_num)

model = CCPM(feature_columns, feature_columns, conv_kernel_width=(3, 2), conv_filters=(
2, 1), dnn_hidden_units=[32, ], dnn_dropout=0.5)
check_model(model, model_name, x, y)


@pytest.mark.parametrize(
'sparse_feature_num,dense_feature_num',
[(2, 0),
]
)
def test_CCPM_without_seq(sparse_feature_num, dense_feature_num):

model_name = "CCPM"

sample_size = SAMPLE_SIZE
x, y, feature_columns = get_test_data(
sample_size, sparse_feature_num, dense_feature_num, sequence_feature=())

model = CCPM(feature_columns, feature_columns, conv_kernel_width=(3, 2), conv_filters=(
2, 1), dnn_hidden_units=[32, ], dnn_dropout=0.5)
check_model(model, model_name, x, y)


if __name__ == "__main__":
pass
