Merge pull request #332 from bacpop/opt_query_viz
Enable assignment & visualisation with partial query graphs
johnlees authored Nov 7, 2024
2 parents 2943e5f + 6464fcd commit a372e4d
Showing 20 changed files with 507 additions and 245 deletions.
2 changes: 1 addition & 1 deletion PopPUNK/__init__.py
@@ -3,7 +3,7 @@

'''PopPUNK (POPulation Partitioning Using Nucleotide Kmers)'''

__version__ = '2.7.1'
__version__ = '2.7.2'

# Minimum sketchlib version
SKETCHLIB_MAJOR = 2
27 changes: 20 additions & 7 deletions PopPUNK/assign.py
Expand Up @@ -50,6 +50,7 @@ def get_options():
oGroup.add_argument('--update-db', help='Update reference database with query sequences', default=False, action='store_true')
oGroup.add_argument('--overwrite', help='Overwrite any existing database files', default=False, action='store_true')
oGroup.add_argument('--graph-weights', help='Save within-strain Euclidean distances into the graph', default=False, action='store_true')
oGroup.add_argument('--save-partial-query-graph', help='Save the network components to which queries are assigned', default=False, action='store_true')

# comparison metrics
kmerGroup = parser.add_argument_group('Kmer comparison options')
@@ -106,6 +107,9 @@ def get_options():
queryingGroup.add_argument('--accessory', help='(with a \'refine\' or \'lineage\' model) '
'Use an accessory-distance only model for assigning queries '
'[default = False]', default=False, action='store_true')
queryingGroup.add_argument('--use-full-network', help='Use full network rather than reference network for querying [default = False]',
default = False,
action = 'store_true')

# processing
other = parser.add_argument_group('Other options')
@@ -234,7 +238,8 @@ def main():
args.gpu_dist,
args.gpu_graph,
args.deviceid,
save_partial_query_graph=False)
args.save_partial_query_graph,
args.use_full_network)

sys.stderr.write("\nDone\n")

@@ -267,7 +272,8 @@ def assign_query(dbFuncs,
gpu_dist,
gpu_graph,
deviceid,
save_partial_query_graph):
save_partial_query_graph,
use_full_network):
"""Code for assign query mode for CLI"""
createDatabaseDir = dbFuncs['createDatabaseDir']
constructDatabase = dbFuncs['constructDatabase']
Expand Down Expand Up @@ -316,7 +322,8 @@ def assign_query(dbFuncs,
accessory,
gpu_dist,
gpu_graph,
save_partial_query_graph)
save_partial_query_graph,
use_full_network)
return(isolateClustering)

def assign_query_hdf5(dbFuncs,
@@ -341,7 +348,8 @@ def assign_query_hdf5(dbFuncs,
accessory,
gpu_dist,
gpu_graph,
save_partial_query_graph):
save_partial_query_graph,
use_full_network):
"""Code for assign query mode taking hdf5 as input. Written as a separate function so it can be called
by web APIs"""
# Modules imported here as graph tool is very slow to load (it pulls in all of GTK?)
@@ -359,6 +367,7 @@
from .network import get_vertex_list
from .network import printExternalClusters
from .network import vertex_betweenness
from .network import remove_non_query_components
from .qc import sketchlibAssemblyQC

from .plot import writeClusterCsv
@@ -453,7 +462,7 @@ def assign_query_hdf5(dbFuncs,
ref_file_name = os.path.join(model_prefix,
os.path.basename(model_prefix) + file_extension_string + ".refs")
use_ref_graph = \
os.path.isfile(ref_file_name) and not update_db and model.type != 'lineage'
os.path.isfile(ref_file_name) and not update_db and model.type != 'lineage' and not use_full_network
if use_ref_graph:
with open(ref_file_name) as refFile:
for reference in refFile:
@@ -791,12 +800,16 @@ def assign_query_hdf5(dbFuncs,
output + "/" + os.path.basename(output) + db_suffix)
else:
storePickle(rNames, qNames, False, qrDistMat, dists_out)
if save_partial_query_graph and not serial:
if model.type == 'lineage':
if save_partial_query_graph:
genomeNetwork, pruned_isolate_lists = remove_non_query_components(genomeNetwork, rNames, qNames, use_gpu = gpu_graph)
if model.type == 'lineage' and not serial:
save_network(genomeNetwork[min(model.ranks)], prefix = output, suffix = '_graph', use_gpu = gpu_graph)
else:
graph_suffix = file_extension_string + '_graph'
save_network(genomeNetwork, prefix = output, suffix = graph_suffix, use_gpu = gpu_graph)
with open(f"{output}/{os.path.basename(output)}_query.subset",'w') as pruned_isolate_csv:
for isolate in pruned_isolate_lists:
pruned_isolate_csv.write(isolate + '\n')

return(isolateClustering)

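A short sketch of consuming the new output, assuming a run with --save-partial-query-graph into a hypothetical output directory example_output (the file name follows the f-string in the diff above):

import os

output = "example_output"  # hypothetical --output directory

# The block above writes one isolate per line to <output>/<basename>_query.subset:
# the members of every network component containing at least one query
subset_file = os.path.join(output, os.path.basename(output) + "_query.subset")
with open(subset_file) as f:
    retained_isolates = [line.rstrip() for line in f]
print(f"{len(retained_isolates)} isolates in the partial query graph")
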
6 changes: 4 additions & 2 deletions PopPUNK/lineages.py
@@ -392,7 +392,8 @@ def query_db(args):
accessory,
args.gpu_dist,
args.gpu_graph,
save_partial_query_graph = False)
save_partial_query_graph = False,
use_full_network = True) # Use full network - does not make sense to use references for lineages

# Process clustering
query_strains = {}
@@ -439,7 +440,8 @@ def query_db(args):
accessory,
args.gpu_dist,
args.gpu_graph,
save_partial_query_graph = False)
save_partial_query_graph = False,
use_full_network = True)
overall_lineage[strain] = createOverallLineage(rank_list, lineageClustering)

# Print combined strain and lineage clustering
5 changes: 4 additions & 1 deletion PopPUNK/models.py
@@ -1124,7 +1124,9 @@ def __init__(self, outPrefix, ranks, max_search_depth, reciprocal_only, count_un
ClusterFit.__init__(self, outPrefix)
self.type = 'lineage'
self.preprocess = False
self.max_search_depth = max_search_depth
self.max_search_depth = max_search_depth + 5 # Set to the highest rank by default in main; five extra
# distances are stored because redundancy (e.g. reciprocal matching,
# unique distance counting) or pruning of sequences from the database
# can exhaust the stored neighbours at the maximum rank
self.nn_dists = None # stores the unprocessed kNN at the maximum search depth
self.ranks = []
for rank in sorted(ranks):
@@ -1368,6 +1370,7 @@ def extend(self, qqDists, qrDists):
qrRect,
self.max_search_depth,
self.threads)

# Update NN dist associated with model
self.__save_sparse__(higher_rank[2], higher_rank[0], higher_rank[1],
self.max_search_depth, n_ref + n_query, self.nn_dists.dtype,
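An illustration of the padding above, with invented numbers (the diff's comment states that main sets max_search_depth to the highest rank by default):

# Illustrative only: with ranks = [1, 2, 5], main passes the highest rank (5)
# as max_search_depth; the model then keeps 5 + 5 = 10 nearest neighbours so
# rank-5 neighbourhoods can still be filled after reciprocal matching,
# unique-distance counting, or database pruning discards some candidates
ranks = [1, 2, 5]
max_search_depth = max(ranks)
stored_neighbours = max_search_depth + 5  # 10 kNN distances kept per sample
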
109 changes: 88 additions & 21 deletions PopPUNK/network.py
@@ -1920,32 +1920,99 @@ def prune_graph(prefix, reflist, samples_to_keep, output_db_name, threads, use_g
if os.path.exists(network_fn):
network_found = True
sys.stderr.write("Loading network from " + network_fn + "\n")
samples_to_keep_set = frozenset(samples_to_keep)
G = load_network_file(network_fn, use_gpu = use_gpu)
if use_gpu:
# Identify indices
reference_indices = [i for (i,name) in enumerate(reflist) if name in samples_to_keep_set]
# Generate data frame
G_df = G.view_edge_list()
if 'src' in G_df.columns:
G_df.rename(columns={'src': 'source','dst': 'destination'}, inplace=True)
# Filter data frame
G_new_df = G_df[G_df['source'].isin(reference_indices) & G_df['destination'].isin(reference_indices)]
# Translate network indices to match name order
G_new = translate_network_indices(G_new_df, reference_indices)
else:
reference_vertex = G.new_vertex_property('bool')
for n, vertex in enumerate(G.vertices()):
if reflist[n] in samples_to_keep_set:
reference_vertex[vertex] = True
else:
reference_vertex[vertex] = False
G_new = gt.GraphView(G, vfilt = reference_vertex)
G_new = gt.Graph(G_new, prune = True)
G_new = remove_nodes_from_graph(G, reflist, samples_to_keep, use_gpu)
save_network(G_new,
prefix = output_db_name,
suffix = '_graph',
use_graphml = False,
use_gpu = use_gpu)
if not network_found:
sys.stderr.write('No network file found for pruning\n')

def remove_nodes_from_graph(G, reflist, samples_to_keep, use_gpu):
"""Return a modified graph containing only the requested nodes
Args:
G (graph)
Graph to be pruned
reflist (list)
Ordered list of sequences in the database
samples_to_keep (list)
The names of samples to be retained in the graph
use_gpu (bool)
Whether graph is a cugraph or not
[default = False]
Returns:
G_new (graph)
Pruned graph
"""
samples_to_keep_set = frozenset(samples_to_keep)
if use_gpu:
# Identify indices
reference_indices = [i for (i,name) in enumerate(reflist) if name in samples_to_keep_set]
# Generate data frame
G_df = G.view_edge_list()
if 'src' in G_df.columns:
G_df.rename(columns={'src': 'source','dst': 'destination'}, inplace=True)
# Filter data frame
G_new_df = G_df[G_df['source'].isin(reference_indices) & G_df['destination'].isin(reference_indices)]
# Translate network indices to match name order
G_new = translate_network_indices(G_new_df, reference_indices)
else:
reference_vertex = G.new_vertex_property('bool')
for n, vertex in enumerate(G.vertices()):
if reflist[n] in samples_to_keep_set:
reference_vertex[vertex] = True
else:
reference_vertex[vertex] = False
G_new = gt.GraphView(G, vfilt = reference_vertex)
G_new = gt.Graph(G_new, prune = True)
return G_new
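
A minimal usage sketch for the refactored helper, assuming graph-tool is installed and this branch is importable; the three-sample network is invented for illustration:

import graph_tool.all as gt
from PopPUNK.network import remove_nodes_from_graph

reflist = ["A", "B", "C"]  # hypothetical database order
G = gt.Graph(directed=False)
G.add_vertex(3)
G.add_edge(G.vertex(0), G.vertex(1))  # A -- B; C is unconnected

# Keep A and B; C is dropped from the pruned graph
G_new = remove_nodes_from_graph(G, reflist, samples_to_keep=["A", "B"], use_gpu=False)
print(G_new.num_vertices())  # 2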

def remove_non_query_components(G, rlist, qlist, use_gpu = False):
"""
Removes all components that do not contain a query sequence.
Args:
G (graph)
Network of queries linked to reference sequences
rlist (list)
List of reference sequence labels
qlist (list)
List of query sequence labels
use_gpu (bool)
Whether to use GPUs for network construction
Returns:
G (graph)
The resulting network
pruned_names (list)
The labels of the sequences in the pruned network
"""
components_with_query = []
combined_names = rlist + qlist
pruned_names = []
if use_gpu:
sys.stderr.write('Saving partial query graphs is not compatible with GPU networks yet\n')
sys.exit(1)
else:
# Identify network components containing queries
component_dict = gt.label_components(G)[0]
components_with_query = set()
# The number of reference sequences is len(rlist)
# These are the first len(rlist) vertices in the graph
# Queries that have been added have indices >= len(rlist)
# Therefore these are the components to retain
for i in range(len(rlist), G.num_vertices()):
v = G.vertex(i) # Access vertex by index
components_with_query.add(component_dict[v])
# Create a boolean filter based on the list of component IDs
query_filter = G.new_vertex_property("bool")
for v in G.vertices():
query_filter[int(v)] = (component_dict[v] in components_with_query)
if query_filter[int(v)]:
pruned_names.append(combined_names[int(v)])
# Create a filtered graph with only the specified components
query_subgraph = gt.GraphView(G, vfilt=query_filter)

return query_subgraph, pruned_names
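
A small sketch of the new function's behaviour on an invented network, assuming graph-tool; as the comments above describe, references occupy vertices 0..len(rlist)-1 and queries follow:

import graph_tool.all as gt
from PopPUNK.network import remove_non_query_components

rlist = ["R1", "R2"]   # reference vertices 0 and 1
qlist = ["Q1"]         # query vertex 2
G = gt.Graph(directed=False)
G.add_vertex(3)
G.add_edge(G.vertex(0), G.vertex(2))  # Q1 joins R1's component

G_sub, pruned_names = remove_non_query_components(G, rlist, qlist, use_gpu=False)
print(pruned_names)  # ['R1', 'Q1']: R2's query-free component is filtered out
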
14 changes: 8 additions & 6 deletions PopPUNK/plot.py
@@ -511,7 +511,7 @@ def drawMST(mst, outPrefix, isolate_clustering, clustering_name, overwrite):
output=graph2_file_name, output_size=(3000, 3000))

def outputsForCytoscape(G, G_mst, isolate_names, clustering, outPrefix, epiCsv, queryList = None,
suffix = None, writeCsv = True):
suffix = None, writeCsv = True, use_partial_query_graph = None):
"""Write outputs for cytoscape. A graphml of the network, and CSV with metadata
Args:
@@ -536,6 +536,8 @@ def outputsForCytoscape(G, G_mst, isolate_names, clustering, outPrefix, epiCsv,
(default = None)
writeCsv (bool)
Whether to print CSV file to accompany network
use_partial_query_graph (str)
File listing sequences to be included in output graph
"""

# Avoid circular import
@@ -553,7 +555,8 @@ def outputsForCytoscape(G, G_mst, isolate_names, clustering, outPrefix, epiCsv,
suffix = '_cytoscape'
else:
suffix = suffix + '_cytoscape'
save_network(G, prefix = outPrefix, suffix = suffix, use_graphml = True)
if use_partial_query_graph is None:
save_network(G, prefix = outPrefix, suffix = suffix, use_graphml = True)

# Save each component too (useful for very large graphs)
component_assignments, component_hist = gt.label_components(G)
@@ -562,10 +565,9 @@ def outputsForCytoscape(G, G_mst, isolate_names, clustering, outPrefix, epiCsv,
for vidx, v_component in enumerate(component_assignments.a):
if v_component != component_idx:
remove_list.append(vidx)
G_copy = G.copy()
G_copy.remove_vertex(remove_list)
save_network(G_copy, prefix = outPrefix, suffix = "_component_" + str(component_idx + 1), use_graphml = True)
del G_copy
G.remove_vertex(remove_list)
G.purge_vertices()
save_network(G, prefix = outPrefix, suffix = "_component_" + str(component_idx + 1), use_graphml = True)

if G_mst != None:
isolate_labels = isolateNameToLabel(G_mst.vp.id)
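The old code copied the whole network before deleting the other components' vertices; the new code prunes in place and compacts with purge_vertices, avoiding a duplicate of a potentially very large graph for every component. A miniature of the same pattern, assuming graph-tool:

import graph_tool.all as gt

G = gt.Graph(directed=False)
G.add_vertex(4)
G.add_edge(G.vertex(0), G.vertex(1))  # component 0
G.add_edge(G.vertex(2), G.vertex(3))  # component 1

component_assignments = gt.label_components(G)[0]
remove_list = [int(v) for v in G.vertices() if component_assignments[v] != 0]
G.remove_vertex(remove_list)  # prune in place rather than copying
G.purge_vertices()            # compacts a filtered view; harmless on a plain Graph
print(G.num_vertices())       # 2: only component 0 remains
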
17 changes: 13 additions & 4 deletions PopPUNK/sketchlib.py
@@ -213,7 +213,7 @@ def getSeqsInDb(dbname):

return seqs

def joinDBs(db1, db2, output, update_random = None):
def joinDBs(db1, db2, output, update_random = None, full_names = False):
"""Join two sketch databases with the low-level HDF5 copy interface
Args:
@@ -226,10 +226,19 @@ def joinDBs(db1, db2, output, update_random = None):
update_random (dict)
Whether to re-calculate the random object. May contain
control arguments strand_preserved and threads (see :func:`addRandom`)
full_names (bool)
If True, db1 and db2 are full paths to .h5 files and output is the full join prefix
"""
join_prefix = output + "/" + os.path.basename(output)
db1_name = db1 + "/" + os.path.basename(db1) + ".h5"
db2_name = db2 + "/" + os.path.basename(db2) + ".h5"

if not full_names:
join_prefix = output + "/" + os.path.basename(output)
db1_name = db1 + "/" + os.path.basename(db1) + ".h5"
db2_name = db2 + "/" + os.path.basename(db2) + ".h5"
else:
db1_name = db1
db2_name = db2
join_prefix = output

hdf1 = h5py.File(db1_name, 'r')
hdf2 = h5py.File(db2_name, 'r')
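A usage sketch for the new flag, with invented paths, assuming the sketch databases exist on disk:

from PopPUNK.sketchlib import joinDBs

# Default: arguments are database directories; the .h5 names are derived
# from each directory's basename
joinDBs("ref_db", "query_db", "combined_db")

# full_names=True: db1 and db2 are taken verbatim as .h5 paths and output
# as the join prefix, with no basename manipulation
joinDBs("ref_db/ref_db.h5", "query_db/query_db.h5", "combined_db/combined_db",
        full_names=True)
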
6 changes: 5 additions & 1 deletion PopPUNK/utils.py
@@ -593,14 +593,16 @@ def check_and_set_gpu(use_gpu, gpu_lib, quit_on_fail = False):

return use_gpu

def read_rlist_from_distance_pickle(fn, allow_non_self = True):
def read_rlist_from_distance_pickle(fn, allow_non_self = True, include_queries = False):
"""Return the list of reference sequences from a distance pickle.
Args:
fn (str)
Name of distance pickle
allow_non_self (bool)
Whether non-self distance datasets are permissible
include_queries (bool)
Whether queries should be included in the rlist
Returns:
rlist (list)
List of reference sequence names
@@ -611,6 +613,8 @@ def read_rlist_from_distance_pickle(fn, allow_non_self = True):
sys.stderr.write("Thi analysis requires an all-v-all"
" distance dataset\n")
sys.exit(1)
if include_queries:
rlist = rlist + qlist
return rlist

def get_match_search_depth(rlist,rank_list):
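A sketch of the widened helper; the pickle path is hypothetical, standing in for one written by storePickle during assignment:

from PopPUNK.utils import read_rlist_from_distance_pickle

# With include_queries=True the returned list is rlist + qlist, matching
# the vertex order of a query-augmented network such as the partial
# query graph saved above
all_names = read_rlist_from_distance_pickle("example_output/example_output.dists.pkl",
                                            include_queries=True)
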
[Diffs for the remaining 12 changed files not shown.]
