DIVA: A Dirichlet Process Mixtures Based Incremental Deep Clustering Algorithm via Variational Auto-Encoder
Official implementation of the paper: DIVA: A Dirichlet Process Based Incremental Deep Clustering Algorithm via Variational Auto-Encoder
A demo video showing DIVA's dynamic adaptation ability in deep clustering.
We use Python 3.7 and PyTorch Lightning for training. Before you start training, make sure the bnpy package is installed in your local environment; refer to here for more details.
- python 3.7
- bnpy 1.7.0
- pytorch-lightning 1.9.4
- numpy, pandas, matplotlib, seaborn, torchvision
# Install dependencies and package
pip3 install -r requirements.txt
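Because bnpy is installed separately from the other dependencies, it can help to confirm that every required package is importable before launching a long training run. The snippet below is a minimal sketch; the helper name `find_missing` is hypothetical, and the module list is an assumption derived from the requirements above (note that the pip name `pytorch-lightning` is imported as `pytorch_lightning`). It only uses the standard library, so it also runs on Python 3.7.

```python
import importlib

# Import names for the requirements listed above (an assumption: pip package
# names and import names differ, e.g. pytorch-lightning -> pytorch_lightning).
REQUIRED_MODULES = ["bnpy", "pytorch_lightning", "torch", "torchvision",
                    "numpy", "pandas", "matplotlib", "seaborn"]


def find_missing(modules=REQUIRED_MODULES):
    """Return the names in `modules` that cannot be imported (hypothetical helper)."""
    missing = []
    for name in modules:
        try:
            importlib.import_module(name)
        except ImportError:  # ModuleNotFoundError is a subclass of ImportError
            missing.append(name)
    return missing
```

Calling `find_missing()` after `pip3 install -r requirements.txt` should return an empty list; a non-empty result names the packages (typically bnpy) that still need to be installed.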
DIVA
|- dataset # folder for saving datasets
| |- reuters10k.py # dataset instance of reuters10k that follows torchvision formatting
| |- reuters10k.mat # original data of reuters10k
|- pretrained # folder for saving pretrained example model on MNIST
| |- dpmm # folder for saving DPMM cluster module
| |- diva_vae.ckpt # checkpoint of the DIVA VAE part trained on MNIST for 100 epochs (ACC 0.91)
| |- pretrained.ipynb # example notebook showing how to load the pretrained model
|- diva.py # DIVA implementations for image and text data; training manager
|- main_mnist.ipynb # main entry point of diva training on MNIST, including evaluation plots.
|- main_stl10.ipynb # main entry point of diva training on STL-10.
|- main_imagenet50.ipynb # main entry point of diva training on ImageNet-50.
|- feature_extraction.ipynb # notebook that uses a pretrained ResNet-50 to extract features of STL-10.
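To check what `pretrained/diva_vae.ckpt` contains without rebuilding the model class from `diva.py`, the checkpoint can be opened directly with `torch.load`. This is a minimal sketch, assuming the file is a standard PyTorch Lightning checkpoint, i.e. a dict with keys such as `"state_dict"` and `"epoch"`; the helper `summarize_checkpoint` is hypothetical and is not part of this repository. Use `pretrained.ipynb` to actually restore the model.

```python
import torch


def summarize_checkpoint(path_or_file):
    """Summarize a Lightning-style checkpoint (hypothetical helper, for inspection only)."""
    # map_location="cpu" lets a GPU-trained checkpoint load on a CPU-only machine.
    ckpt = torch.load(path_or_file, map_location="cpu")
    state_dict = ckpt.get("state_dict", {})
    # Counts every tensor element, including buffers such as BatchNorm statistics.
    num_elements = sum(t.numel() for t in state_dict.values() if torch.is_tensor(t))
    return {
        "epoch": ckpt.get("epoch"),
        "num_tensors": len(state_dict),
        "num_elements": num_elements,
        "top_level_keys": sorted(ckpt.keys()),
    }
```

For example, `summarize_checkpoint("pretrained/diva_vae.ckpt")` lists the stored tensors and the training epoch. On recent PyTorch releases, where `torch.load` defaults to `weights_only=True`, a trusted checkpoint that stores extra Python objects (such as hyperparameters) may need `weights_only=False`.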
Since training on the raw images of STL-10 and ImageNet-50 is quite difficult, we use a feature extractor to obtain low-dimensional encodings of these datasets. For STL-10 we use the pretrained ResNet-50 provided by torchvision; running feature_extraction.ipynb produces the features used in our study. For ImageNet-50 we use MoCo to extract features; for more details, refer to here and here.
import bnpy
import numpy as np
import torch
# load DPMM module (bnpy model saved under the "Best" prefix)
dpmm_model = bnpy.ioutil.ModelReader.load_model_at_prefix('path/to/your/bn_model/folder/dpmm', prefix="Best")

# function for getting the cluster parameters (one mean and one variance vector per cluster)
def calc_cluster_component_params(bnp_model):
    K = bnp_model.obsModel.K  # number of clusters in the fitted DPMM
    comp_mu = [torch.Tensor(bnp_model.obsModel.get_mean_for_comp(k)) for k in range(K)]
    # For a diagonal covariance matrix, summing over axis 0 gives the per-dimension variances.
    comp_var = [torch.Tensor(np.sum(bnp_model.obsModel.get_covar_mat_for_comp(k), axis=0)) for k in range(K)]
    return comp_mu, comp_var
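Once the component means and variances are available, latent codes from the VAE encoder can be assigned to clusters by evaluating each diagonal Gaussian and normalizing. The sketch below is illustrative only: `assign_clusters` is a hypothetical helper, mixture weights default to uniform (an assumption; DIVA's own assignment rule lives in `diva.py`), and it uses NumPy so that CPU tensors returned by `calc_cluster_component_params` can be passed in directly.

```python
import numpy as np


def assign_clusters(z, comp_mu, comp_var, weights=None):
    """Soft and hard assignment of latent codes under diagonal Gaussians (hypothetical helper).

    z: (N, D) latent codes; comp_mu, comp_var: length-K sequences of (D,) means and variances.
    weights: optional (K,) mixture weights; uniform if omitted (an assumption).
    """
    z = np.asarray(z, dtype=np.float64)
    mu = np.stack([np.asarray(m, dtype=np.float64) for m in comp_mu])   # (K, D)
    var = np.stack([np.asarray(v, dtype=np.float64) for v in comp_var])  # (K, D)
    K = mu.shape[0]
    w = np.full(K, 1.0 / K) if weights is None else np.asarray(weights, dtype=np.float64)

    diff = z[:, None, :] - mu[None, :, :]  # (N, K, D)
    # log N(z | mu_k, diag(var_k)) summed over dimensions -> (N, K)
    log_lik = -0.5 * np.sum(np.log(2 * np.pi * var)[None] + diff ** 2 / var[None], axis=2)
    log_post = np.log(w)[None, :] + log_lik
    log_post -= log_post.max(axis=1, keepdims=True)  # stabilize before exponentiating
    resp = np.exp(log_post)
    resp /= resp.sum(axis=1, keepdims=True)  # responsibilities, each row sums to 1
    return resp.argmax(axis=1), resp
```

For example, with `comp_mu, comp_var = calc_cluster_component_params(dpmm_model)` and encoder outputs `z`, `labels, resp = assign_clusters(z, comp_mu, comp_var)` gives one cluster index per sample plus the full responsibility matrix.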
If you would like to cite our work, please use the following BibTeX citation:
@misc{bing2023diva,
title={DIVA: A Dirichlet Process Based Incremental Deep Clustering Algorithm via Variational Auto-Encoder},
author={Zhenshan Bing and Yuan Meng and Yuqi Yun and Hang Su and Xiaojie Su and Kai Huang and Alois Knoll},
year={2023},
eprint={2305.14067},
archivePrefix={arXiv},
primaryClass={cs.LG}
}