
Commit

add option to save object as h5
nargesr committed Dec 13, 2023
1 parent 845b9d8 commit fa8debf
Showing 5 changed files with 163 additions and 32 deletions.
2 changes: 0 additions & 2 deletions README.md
@@ -27,7 +27,6 @@ top_model.model = top_model.rLDA
top_model.save_topModel()
```


### Install from PyPi (recommended)
To install the most recent release, run

@@ -40,7 +39,6 @@ git cloning the [Topyfic repository](https://github.com/mortazavilab/Topyfic), g

## Tutorials


In general, you need to make three objects (Train, TopModel and Analysis).

The Train object can be initialized from (a) a single-cell RNA-seq dataset, (b) single-cell ATAC-seq, or (c) bulk RNA-seq data.
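The README example above calls `top_model.save_topModel()`; with this commit both save methods accept a `file_format` argument (see the `Topyfic/topModel.py` diff below). A minimal sketch of the new option, assuming a fitted `TopModel` named `top_model` as in the snippet above; file names follow the defaults in the code:

```python
# Default behaviour is unchanged (pickle / joblib); pass file_format to get HDF5 instead.
top_model.save_topModel(file_format="HDF5")      # writes topModel_<name>.h5
top_model.save_rLDA_model(file_format="HDF5")    # writes rLDA_<N>topics.h5
```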
106 changes: 86 additions & 20 deletions Topyfic/topModel.py
Expand Up @@ -6,6 +6,7 @@
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import h5py

sns.set_context('paper')
warnings.filterwarnings('ignore')
@@ -66,18 +67,40 @@ def get_feature_name(self):
topic1 = next(iter(self.topics.items()))[1]
return topic1.gene_information.index.tolist()

def save_rLDA_model(self, name='rLDA', save_path=""):
def save_rLDA_model(self, name='rLDA', save_path="", file_format='joblib'):
"""
save Train class as a pickle file
:param name: name of the pickle file (default: rLDA)
:type name: str
:param save_path: directory you want to use to save pickle file (default is saving near script)
:type save_path: str
save rLDA model (instance of LDA model in sklearn) as a joblib/HDF5 file.
:param name: name of the output file (default: rLDA)
:type name: str
:param save_path: directory you want to use to save the file (default is saving near script)
:type save_path: str
:param file_format: format of the file you want to save (options: joblib (default), HDF5)
:type file_format: str
"""
print(f"Saving rLDA model as {name}_{self.N}topics.joblib")
if file_format not in ['joblib', 'HDF5']:
sys.exit(f"{file_format} is not correct! It should be 'joblib' or 'HDF5'.")

if file_format == "joblib":
print(f"Saving rLDA model as {name}_{self.N}topics.joblib")

joblib.dump(self.model, f"{save_path}{name}_{self.N}topics.joblib", compress=3)

if file_format == "HDF5":
print(f"Saving rLDA model as {name}_{self.N}topics.h5")

f = h5py.File(f"{save_path}{name}_{self.N}topics.h5", "w")

f['components_'] = self.model.components_
f['exp_dirichlet_component_'] = self.model.exp_dirichlet_component_
f['n_batch_iter_'] = np.int_(self.model.n_batch_iter_)
f['n_features_in_'] = self.model.n_features_in_
f['n_iter_'] = np.int_(self.model.n_iter_)
f['bound_'] = np.float_(self.model.bound_)
f['doc_topic_prior_'] = np.float_(self.model.doc_topic_prior_)
f['topic_word_prior_'] = np.float_(self.model.topic_word_prior_)

f.close()

def get_gene_weights(self):
"""
@@ -319,19 +342,62 @@ def MA_plot(self,

return gene_zscore

def save_topModel(self, name=None, save_path=""):
def save_topModel(self, name=None, save_path="", file_format='pickle'):
"""
save TopModel class as a pickle file
:param name: name of the pickle file (default: topModel_TopModel.name)
:type name: str
:param save_path: directory you want to use to save pickle file (default is saving near script)
:type save_path: str
save TopModel class as a pickle/HDF5 file
:param name: name of the file (default: topModel_TopModel.name)
:type name: str
:param save_path: directory you want to use to save the file (default is saving near script)
:type save_path: str
:param file_format: format of the file you want to save (options: pickle (default), HDF5)
:type file_format: str
"""
if file_format not in ['pickle', 'HDF5']:
sys.exit(f"{file_format} is not correct! It should be 'pickle' or 'HDF5'.")
if name is None:
name = f"topModel_{self.name}"
print(f"Saving topModel as {name}.p")

picklefile = open(f"{save_path}{name}.p", "wb")
pickle.dump(self, picklefile)
picklefile.close()
if file_format == "pickle":
print(f"Saving topModel as {name}.p")

picklefile = open(f"{save_path}{name}.p", "wb")
pickle.dump(self, picklefile)
picklefile.close()

if file_format == "HDF5":
print(f"Saving topModel as {name}.h5")

f = h5py.File(f"{name}.h5", "w")
# model
model = f.create_group("model")
model['components_'] = self.model.components_
model['exp_dirichlet_component_'] = self.model.exp_dirichlet_component_
model['n_batch_iter_'] = np.int_(self.model.n_batch_iter_)
model['n_features_in_'] = self.model.n_features_in_
model['n_iter_'] = np.int_(self.model.n_iter_)
model['bound_'] = np.float_(self.model.bound_)
model['doc_topic_prior_'] = np.float_(self.model.doc_topic_prior_)
model['topic_word_prior_'] = np.float_(self.model.topic_word_prior_)

# topics
topics = f.create_group("topics")
for topic in self.topics.keys():
topic_gp = topics.create_group(self.topics[topic].id)
topic_gp['id'] = np.string_(self.topics[topic].id)
topic_gp['name'] = np.string_(self.topics[topic].name)
topic_gp['gene_weights'] = self.topics[topic].gene_weights
gene_information = self.topics[topic].gene_information.copy(deep=True)
gene_information.reset_index(inplace=True)
gene_information = gene_information.T.reset_index().T
topic_gp['gene_information'] = np.array(gene_information)
topic_information = self.topics[topic].topic_information.copy(deep=True)
topic_information.reset_index(inplace=True)
topic_information = topic_information.T.reset_index().T
topic_gp['topic_information'] = np.array(topic_information)

f['name'] = np.string_(self.name)
f['N'] = np.int_(self.N)

f.close()
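
For reference, a hedged round-trip sketch (not part of the commit) of what the new HDF5 writer above produces; it assumes a fitted `TopModel` instance named `top_model`, and the file name is hypothetical. The group layout mirrors what `save_topModel` writes:

```python
import h5py

# Write the object with the new HDF5 option
top_model.save_topModel(name="my_topmodel", save_path="./", file_format="HDF5")

# Inspect the layout produced by save_topModel
with h5py.File("./my_topmodel.h5", "r") as f:
    print(list(f.keys()))            # expected: ['N', 'model', 'name', 'topics']
    print(list(f["model"].keys()))   # components_, exp_dirichlet_component_, and the scalar LDA attributes
    print(list(f["topics"].keys()))  # one group per topic id, e.g. 'Topic_1', 'Topic_2', ...
```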

10 changes: 7 additions & 3 deletions Topyfic/utilsAnalyseModel.py
@@ -13,6 +13,7 @@
from sklearn.decomposition import LatentDirichletAllocation
from scipy import stats as st
from scipy.spatial import distance
from sklearn.metrics.pairwise import cosine_similarity
import scanpy.external as sce
import networkx as nx
import math
@@ -23,7 +24,6 @@
from gseapy import gseaplot
from reactome2py import analysis
from adjustText import adjust_text
from sklearn.metrics.pairwise import cosine_similarity
import umap
import obonet
import plotly.express as px
@@ -48,7 +48,7 @@ def compare_topModels(topModels,
:param topModels: list of topModel class you want to compare to each other
:type topModels: list of TopModel class
:param comparison_method: indicate the method you want to use for comparing topics. if you used Jensen–Shannon, we show -log2 (options: pearson correlation, spearman correlation, Jensen–Shannon divergence)
:param comparison_method: indicate the method you want to use for comparing topics. If you use Jensen–Shannon, we show -log2 of the divergence (options: pearson correlation, spearman correlation, Jensen–Shannon divergence, cosine similarity)
:type comparison_method: str
:param output_type: indicate the type of output you want. graph: plot as a graph, heatmap: plot as a heatmap, table: table containing the correlations. Note: if you want to plot Jensen–Shannon divergence as a graph, we convert the values to -log2(), so you need to take that into account when defining the threshold
:type output_type: str
@@ -77,7 +77,7 @@ if output_type not in ['graph', 'heatmap', 'table']:
if output_type not in ['graph', 'heatmap', 'table']:
sys.exit("output_type is not valid! it should be one of 'graph', 'heatmap', or 'table'")

if comparison_method not in ['spearman correlation', 'pearson correlation', 'Jensen–Shannon divergence']:
if comparison_method not in ['spearman correlation', 'pearson correlation', 'Jensen–Shannon divergence', 'cosine similarity']:
sys.exit("comparison_method is not valid! it should be one of 'spearman correlation', 'pearson correlation', or 'Jensen–Shannon divergence'")

names = [topModel.name for topModel in topModels]
@@ -109,6 +109,7 @@ a.dropna(axis=0, how='all', inplace=True)
a.dropna(axis=0, how='all', inplace=True)
a.fillna(0, inplace=True)

a = a[np.logical_or(a[d1] > a[d1].min(), a[d2] > a[d2].min())]
a = a / a.sum()
if comparison_method == "Jensen–Shannon divergence":
JSd = distance.jensenshannon(a[d1].tolist(), a[d2].tolist())
@@ -119,6 +120,9 @@ elif comparison_method == "spearman correlation":
elif comparison_method == "spearman correlation":
corr = st.spearmanr(a[d1].tolist(), a[d2].tolist())
corrs.at[d1, d2] = corr[0]
elif comparison_method == 'cosine similarity':
corr = distance.cosine(a[d1].tolist(), a[d2].tolist())
corrs.at[d1, d2] = 1 - corr
if comparison_method == "Jensen–Shannon divergence":
corrs = corrs.applymap(math.log2)
corrs = corrs * -1
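A hedged usage sketch for the new `cosine similarity` option. The `top_model_a`/`top_model_b` names are hypothetical, `compare_topModels` is assumed to be exported at the package level like the other utilities, and the stored value is `1 - scipy.spatial.distance.cosine(...)`, i.e. the cosine similarity of the matched gene-weight columns:

```python
import Topyfic

# top_model_a / top_model_b are hypothetical, previously trained TopModel objects
res = Topyfic.compare_topModels([top_model_a, top_model_b],
                                comparison_method="cosine similarity",
                                output_type="table")  # 'table' is expected to return the pairwise similarity table
```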
76 changes: 69 additions & 7 deletions Topyfic/utilsMakeModel.py
@@ -23,6 +23,7 @@
from reactome2py import analysis
import yaml
from yaml.loader import SafeLoader
import h5py

from Topyfic.train import Train
from Topyfic.analysis import Analysis
@@ -505,22 +506,83 @@ def read_train(file):

def read_topModel(file):
"""
reading topModel pickle file
reading topModel pickle/HDF5 file
:param file: path of the pickle file
:param file: path of the pickle/HDF5 file
:type file: str
:return: topModel instance
:rtype: TopModel class
"""
if not os.path.isfile(file):
raise ValueError('TopModel object not found at given path!')

picklefile = open(file, 'rb')
topModel = pickle.load(picklefile)
raise ValueError('TopModel file not found at given path!')
if not file.endswith('.p') and not file.endswith('.h5'):
raise ValueError('TopModel file type is not correct!')

if file.endswith('.p'):
picklefile = open(file, 'rb')
top_model = pickle.load(picklefile)

if file.endswith('.h5'):
f = h5py.File(file, 'r')

name = np.string_(f['name']).decode('ascii')
N = np.int_(f['N'])

# topics
topics = dict()
topic_ids = [f'Topic_{i + 1}' for i in range(N)]
for topic in topic_ids:
topic_id = np.string_(f['topics'][topic]['id']).decode('ascii')
topic_name = np.string_(f['topics'][topic]['name']).decode('ascii')
gene_weights = pd.DataFrame(np.array(f['topics'][topic]['gene_weights']))
gene_information = pd.DataFrame(np.array(f['topics'][topic]['gene_information']), dtype=str)
gene_information.columns = gene_information.iloc[0, :]
gene_information.drop(index=0, inplace=True)
gene_information.index = gene_information['index']
gene_information.drop(columns='index', inplace=True)
gene_information.index.name = None

topic_information = pd.DataFrame(np.array(f['topics'][topic]['topic_information']), dtype=str)
topic_information.columns = topic_information.iloc[0, :]
topic_information.drop(index=0, inplace=True)
topic_information.index = topic_information['index']
topic_information.drop(columns='index', inplace=True)
topic_information.index.name = None

gene_weights.index = gene_information.index.tolist()
gene_weights.columns = topic_information.index.tolist()

topic = Topyfic.Topic(topic_id=topic_id,
topic_name=topic_name,
topic_gene_weights=gene_weights,
gene_information=gene_information,
topic_information=topic_information)
topics[topic_id] = topic

# model
components = pd.DataFrame(np.array(f['model']['components_']))
exp_dirichlet_component = pd.DataFrame(np.array(f['model']['exp_dirichlet_component_']))

others = pd.DataFrame()
others.loc[0, 'n_batch_iter'] = np.int_(f['model']['n_batch_iter_'])
others.loc[0, 'n_features_in'] = np.array(f['model']['n_features_in_'])
others.loc[0, 'n_iter'] = np.int_(f['model']['n_iter_'])
others.loc[0, 'bound'] = np.float_(f['model']['bound_'])
others.loc[0, 'doc_topic_prior'] = np.array(f['model']['doc_topic_prior_'])
others.loc[0, 'topic_word_prior'] = np.array(f['model']['topic_word_prior_'])

model = Topyfic.initialize_lda_model(components, exp_dirichlet_component, others)

top_model = Topyfic.TopModel(name=name,
N=N,
topics=topics,
model=model)

f.close()

print(f"Reading TopModel done!")
return topModel
return top_model


def read_analysis(file):
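And a short sketch of the new reader path (the file name is hypothetical; `read_topModel` is assumed to be exposed as `Topyfic.read_topModel`, as the `Topyfic.` references inside the function suggest):

```python
import Topyfic

# Dispatches on the file extension: '.p' is unpickled, '.h5' is rebuilt from the
# HDF5 groups (model via initialize_lda_model, plus the per-topic groups).
top_model = Topyfic.read_topModel("topModel_my_run.h5")
print(top_model.name, top_model.N, len(top_model.topics))
```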
1 change: 1 addition & 0 deletions setup.py
@@ -40,6 +40,7 @@
'umap',
'obonet',
'plotly',
'h5py',

],
classifiers=[ # choose from here: https://pypi.org/classifiers/
