diff --git a/PopPUNK/__init__.py b/PopPUNK/__init__.py index c6757346..ee41ee5d 100644 --- a/PopPUNK/__init__.py +++ b/PopPUNK/__init__.py @@ -3,7 +3,7 @@ '''PopPUNK (POPulation Partitioning Using Nucleotide Kmers)''' -__version__ = '2.7.3' +__version__ = '2.7.4' # Minimum sketchlib version SKETCHLIB_MAJOR = 2 diff --git a/PopPUNK/assign.py b/PopPUNK/assign.py index 13e5beee..eee90077 100644 --- a/PopPUNK/assign.py +++ b/PopPUNK/assign.py @@ -50,7 +50,7 @@ def get_options(): oGroup.add_argument('--update-db', help='Update reference database with query sequences', default=False, action='store_true') oGroup.add_argument('--overwrite', help='Overwrite any existing database files', default=False, action='store_true') oGroup.add_argument('--graph-weights', help='Save within-strain Euclidean distances into the graph', default=False, action='store_true') - oGroup.add_argument('--save-partial-query-graph', help='Save the network components to which queries are assigned', default=False, action='store_true') + oGroup.add_argument('--save-partial-query-graph', help='Save only the network components to which queries are assigned', default=False, action='store_true') # comparison metrics kmerGroup = parser.add_argument_group('Kmer comparison options') diff --git a/PopPUNK/mandrake.py b/PopPUNK/mandrake.py index fc9e81f5..67acfef2 100644 --- a/PopPUNK/mandrake.py +++ b/PopPUNK/mandrake.py @@ -12,7 +12,7 @@ import pp_sketchlib from SCE import wtsne try: - from SCE import wtsne_gpu_fp64 + from SCE import wtsne_gpu_fp32 gpu_fn_available = True except ImportError: gpu_fn_available = False @@ -63,7 +63,7 @@ def generate_embedding(seqLabels, accMat, perplexity, outPrefix, overwrite, kNN sys.stderr.write("Mandrake analysis already exists; add --overwrite to replace\n") else: sys.stderr.write("Running mandrake\n") - kNN = max(kNN, len(seqLabels) - 1) + kNN = min(kNN, len(seqLabels) - 1) I, J, dists = poppunk_refine.get_kNN_distances(accMat, kNN, 1, n_threads) # Set up function call with either CPU or GPU @@ -76,7 +76,7 @@ def generate_embedding(seqLabels, accMat, perplexity, outPrefix, overwrite, kNN sys.stderr.write("Running on GPU\n") n_workers = 65536 maxIter = round(maxIter / n_workers) - wtsne_call = partial(wtsne_gpu_fp64, + wtsne_call = partial(wtsne_gpu_fp32, perplexity=perplexity, maxIter=maxIter, blockSize=128, diff --git a/PopPUNK/network.py b/PopPUNK/network.py index d91fa699..73e3b6f8 100644 --- a/PopPUNK/network.py +++ b/PopPUNK/network.py @@ -16,6 +16,7 @@ from multiprocessing import Pool import pickle import graph_tool.all as gt +import pp_sketchlib # Load GPU libraries try: @@ -2016,3 +2017,81 @@ def remove_non_query_components(G, rlist, qlist, use_gpu = False): query_subgraph = gt.GraphView(G, vfilt=query_filter) return query_subgraph, pruned_names + +def generate_network_from_distances(mode, + model, + core_distMat = None, + acc_distMat = None, + sparse_mat = None, + previous_mst = None, + combined_seq = None, + rlist = None, + old_rlist = None, + distance_type = 'core', + threads = 1, + gpu_graph = False): + """ + Generates a network from a distance matrix. + + Args: + mode (str) + Whether a core or sparse distance matrix is being analysed + model (ClusterFit or LineageFit) + A fitted model object + coreMat (numpy.array) + NxN array of core distances for N sequences + accMat (numpy.array) + NxN array of accessory distances for N sequences + sparse_mat (scipy or cupyx sparse matrix) + Sparse matrix of kNN from lineage fit + previous_mst (str or graph object) + Path of file containing existing network, or already-loaded + graph object + combined_seq (list) + Ordered list of isolate names + rlist (list) + List of reference sequence labels + old_rlist (list) + List of reference sequence labels for previous MST + distance_type (str) + Whether to use core or accessory distances for MST calculation + or dense network weighting + threads (int) + Number of threads to use in calculations + use_gpu (bool) + Whether to use GPUs for network construction + + Returns: + G (graph) + The resulting network + pruned_names (list) + The labels of the sequences in the pruned network + """ + if mode == 'sparse': + G = generate_mst_from_sparse_input(sparse_mat, + rlist, + old_rlist = old_rlist, + previous_mst = previous_mst, + gpu_graph = gpu_graph) + elif mode == 'dense': + # Get distance matrix + complete_distMat = \ + np.hstack((pp_sketchlib.squareToLong(core_distMat, threads).reshape(-1, 1), + pp_sketchlib.squareToLong(acc_distMat, threads).reshape(-1, 1))) + # Identify short distances and use these to extend the model + indivAssignments = model.assign(complete_distMat) + G = construct_network_from_assignments(combined_seq, + combined_seq, + indivAssignments, + model.within_label, + distMat = complete_distMat, + weights_type = distance_type, + use_gpu = gpu_graph, + summarise = False) + if gpu_graph: + G = cugraph.minimum_spanning_tree(G, weight='weights') + + else: + sys.stderr.write('Unknown network mode - expect dense or sparse\n') + + return G diff --git a/PopPUNK/plot.py b/PopPUNK/plot.py index 2328c160..147db030 100644 --- a/PopPUNK/plot.py +++ b/PopPUNK/plot.py @@ -558,27 +558,26 @@ def outputsForCytoscape(G, G_mst, isolate_names, clustering, outPrefix, epiCsv, if use_partial_query_graph is None: save_network(G, prefix = outPrefix, suffix = suffix, use_graphml = True) - # Save each component too (useful for very large graphs) + # Store query names + querySet = frozenset(queryList) if queryList is not None else frozenset() + + # Save each cluster too (useful for very large graphs) example_cluster_title = list(clustering.keys())[0] - component_assignments, component_hist = gt.label_components(G) - for component_idx in range(len(component_hist)): - # Naming must reflect the full graph size - component_name = component_idx + 1 - get_component_name = (use_partial_query_graph is not None) + if use_partial_query_graph is not None: + represented_clusters = set(clustering[example_cluster_title][isolate] for isolate in isolate_names) + else: + represented_clusters = set(clustering[example_cluster_title].values()) + for cluster in represented_clusters: # Filter the graph for the current component comp_filter = G.new_vertex_property("bool") for v in G.vertices(): - comp_filter[v] = (component_assignments[v] == component_idx) - # If using partial query graph find the component name from the clustering - if get_component_name and comp_filter[v]: - example_isolate_name = seqLabels[int(v)] - component_name = clustering[example_cluster_title][example_isolate_name] - get_component_name = False + vertex_name = seqLabels[int(v)] + comp_filter[v] = (clustering[example_cluster_title][vertex_name] == cluster) G_component = gt.GraphView(G, vfilt=comp_filter) # Purge the component to remove unreferenced vertices (optional but recommended) G_component.purge_vertices() # Save the component network - save_network(G_component, prefix = outPrefix, suffix = "_component_" + str(component_name), use_graphml = True) + save_network(G_component, prefix = outPrefix, suffix = "_component_" + str(cluster), use_graphml = True) if G_mst != None: isolate_labels = isolateNameToLabel(G_mst.vp.id) @@ -730,14 +729,13 @@ def writeClusterCsv(outfile, nodeNames, nodeLabels, clustering, d['Status'].append("Reference") if epiCsv is not None: if label in epiData.index: - if label in epiData.index: - for col, value in zip(epiData.columns.values, epiData.loc[[label]].iloc[0].values): - if col not in columns_to_be_omitted: - d[col].append(str(value)) - else: - for col in epiData.columns.values: - if col not in columns_to_be_omitted: - d[col].append('nan') + for col, value in zip(epiData.columns.values, epiData.loc[[label]].iloc[0].values): + if col not in columns_to_be_omitted: + d[col].append(str(value)) + else: + for col in epiData.columns.values: + if col not in columns_to_be_omitted: + d[col].append('') else: sys.stderr.write("Cannot find " + name + " in clustering\n") diff --git a/PopPUNK/visualise.py b/PopPUNK/visualise.py index ff6801a1..c13f2bc6 100644 --- a/PopPUNK/visualise.py +++ b/PopPUNK/visualise.py @@ -97,6 +97,12 @@ def get_options(): iGroup.add_argument('--use-partial-query-graph', help='File listing sequences in partial query graph after assignment', default=None) + iGroup.add_argument('--extend-query-graph', + help='Extend the partial query graph from the specified list to ' + 'include all other sequences in the same clusters (e.g. if using) ' + 'a larger or later database for visualising than assigning)', + default=False, + action = 'store_true') # output options oGroup = parser.add_argument_group('Output options') @@ -214,18 +220,19 @@ def generate_visualisations(query_db, overwrite, display_cluster, use_partial_query_graph, + extend_query_graph, recalculate_distances, tmp): from .models import loadClusterFit - from .network import construct_network_from_assignments from .network import generate_minimum_spanning_tree from .network import load_network_file from .network import cugraph_to_graph_tool from .network import save_network from .network import sparse_mat_to_network from .network import remove_nodes_from_graph + from .network import generate_network_from_distances from .plot import drawMST from .plot import outputsForMicroreact @@ -264,7 +271,7 @@ def generate_visualisations(query_db, sys.stderr.write("Must specify at least one type of visualisation to output\n") sys.exit(1) if cytoscape and not (microreact or phandango or grapetree): - if rank_fit == None and (network_file == None or not os.path.isfile(network_file)): + if rank_fit == None and not recalculate_distances and (network_file == None or not os.path.isfile(network_file)): sys.stderr.write("For cytoscape, specify either a network file to visualise " "with --network-file or a lineage model with --rank-fit\n") sys.exit(1) @@ -278,11 +285,31 @@ def generate_visualisations(query_db, sys.stderr.write("Cannot create output directory\n") sys.exit(1) - #*******************************# - #* *# - #* Extract subset of sequences *# - #* *# - #*******************************# + #******************************# + #* *# + #* Determine type of distance *# + #* to use *# + #* *# + #******************************# + + # Determine whether to use sparse distances + use_sparse = False + use_dense = False + if (tree == "nj" or tree == "both") or rank_fit == None: + use_dense = True + elif (tree == "mst" or tree == "both") and rank_fit is not None: + # Set flag + use_sparse = True + # Read list of sequence names and sparse distance matrix + rlist = combined_seq + sparse_mat = sparse.load_npz(rank_fit) + # Check previous distances have been supplied if building on a previous MST + old_rlist = None + if previous_distances is not None: + old_rlist = read_rlist_from_distance_pickle(previous_distances + '.pkl') + elif previous_mst is not None: + sys.stderr.write('The prefix of the distance files used to create the previous MST' + ' is needed to use the network') # Identify distance matrix for ordered names if distances is None: @@ -293,6 +320,12 @@ def generate_visualisations(query_db, else: distances = distances + #*******************************# + #* *# + #* Extract subset of sequences *# + #* *# + #*******************************# + # Location and properties of reference database ref_db_loc = ref_db + "/" + os.path.basename(ref_db) kmers, sketch_sizes, codon_phased = readDBParams(ref_db) @@ -302,6 +335,7 @@ def generate_visualisations(query_db, all_seq = combined_seq # all_seq is an immutable record use for network parsing if include_files is not None or use_partial_query_graph is not None: viz_subset = set() + # Just use the isolates from the assign output subset_file = include_files if include_files is not None else use_partial_query_graph with open(subset_file, 'r') as assemblyFiles: for assembly in assemblyFiles: @@ -312,32 +346,6 @@ def generate_visualisations(query_db, else: viz_subset = None - #******************************# - #* *# - #* Determine type of distance *# - #* to use *# - #* *# - #******************************# - - # Determine whether to use sparse distances - use_sparse = False - use_dense = False - if (tree == "nj" or tree == "both") or rank_fit == None: - use_dense = True - elif (tree == "mst" or tree == "both") and rank_fit is not None: - # Set flag - use_sparse = True - # Read list of sequence names and sparse distance matrix - rlist = combined_seq - sparse_mat = sparse.load_npz(rank_fit) - # Check previous distances have been supplied if building on a previous MST - old_rlist = None - if previous_distances is not None: - old_rlist = read_rlist_from_distance_pickle(previous_distances + '.pkl') - elif previous_mst is not None: - sys.stderr.write('The prefix of the distance files used to create the previous MST' - ' is needed to use the network') - #**********************************# #* *# #* Process clustering information *# @@ -414,6 +422,24 @@ def generate_visualisations(query_db, return_dict = True) isolateClustering = joinClusterDicts(isolateClustering, queryIsolateClustering) + # Add extra isolates to partial query graph if requested + if use_partial_query_graph and extend_query_graph: + # First identify the query clusteres + query_clusters = set() + cluster_types = ['Cluster'] + query_cluster_isolates = [] + for cluster_type in cluster_types: + for isolate in viz_subset: + query_clusters.add(isolateClustering[cluster_type][isolate]) + # Then identify the isolates in these clusters + for cluster_type in cluster_types: + for isolate in isolateClustering[cluster_type]: + if isolateClustering[cluster_type][isolate] in query_clusters: + # Only add reference isolates if using a reference database + if isolate in all_seq: + query_cluster_isolates.append(isolate) + viz_subset = viz_subset.union(query_cluster_isolates) + #******************************# #* *# #* Process dense or sparse *# @@ -421,7 +447,7 @@ def generate_visualisations(query_db, #* *# #******************************# - if (tree == "nj" or tree == "both") or (model.type == 'lineage' and rank_fit == None): + if (tree == "nj" or tree == "both" or cytoscape) or (model.type == 'lineage' and rank_fit == None): # Either calculate or read distances if recalculate_distances: @@ -519,7 +545,6 @@ def generate_visualisations(query_db, num_threads=threads, use_gpu=gpu_dist, device_id=deviceid) - else: qlist = None qr_distMat = None @@ -573,29 +598,27 @@ def generate_visualisations(query_db, clustering_name = display_cluster else: clustering_name = list(isolateClustering.keys())[0] - if use_sparse: - G = generate_mst_from_sparse_input(sparse_mat, - rlist, - old_rlist = old_rlist, + # Generate MST from recalculated network + if use_dense: + G = generate_network_from_distances('dense', + model, + core_distMat = core_distMat, + acc_distMat = acc_distMat, + combined_seq = combined_seq, + distance_type = mst_distances, + threads = threads, + gpu_graph = gpu_graph) + elif use_sparse: + G = generate_network_from_distances('sparse', + model, + sparse_mat = sparse_mat, previous_mst = previous_mst, + rlist = rlist, + old_rlist = old_rlist, + distance_type = mst_distances, + model = model, + threads = threads, gpu_graph = gpu_graph) - elif use_dense: - # Get distance matrix - complete_distMat = \ - np.hstack((pp_sketchlib.squareToLong(core_distMat, threads).reshape(-1, 1), - pp_sketchlib.squareToLong(acc_distMat, threads).reshape(-1, 1))) - # Dense network may be slow - sys.stderr.write("Generating MST from dense distances (may be slow)\n") - G = construct_network_from_assignments(combined_seq, - combined_seq, - [0]*complete_distMat.shape[0], - within_label = 0, - distMat = complete_distMat, - weights_type = mst_distances, - use_gpu = gpu_graph, - summarise = False) - if gpu_graph: - G = cugraph.minimum_spanning_tree(G, weight='weights') else: sys.stderr.write("Need either sparse or dense distances matrix to construct MST\n") exit(1) @@ -699,10 +722,30 @@ def generate_visualisations(query_db, if gpu_graph: genomeNetwork = cugraph_to_graph_tool(genomeNetwork, isolateNameToLabel(all_seq)) # Hard delete from network to remove samples (mask doesn't work neatly) - if include_files is not None: - genomeNetwork = remove_nodes_from_graph(genomeNetwork, all_seq, viz_subset, use_gpu = gpu_graph) + if viz_subset is not None: + genomeNetwork = remove_nodes_from_graph(genomeNetwork, all_seq, viz_subset, use_gpu = False) elif rank_fit is not None: genomeNetwork = sparse_mat_to_network(sparse_mat, combined_seq, use_gpu = gpu_graph) + elif recalculate_distances: + # Recalculate network from new distances + if use_dense: + genomeNetwork = generate_network_from_distances(mode = 'dense', + core_distMat = core_distMat, + acc_distMat = acc_distMat, + combined_seq = combined_seq, + model = model, + distance_type = mst_distances, + threads = threads, + gpu_graph = gpu_graph) + elif use_sparse: + genomeNetwork = generate_network_from_distances(mode = 'sparse', + sparse_mat = sparse_mat, + previous_mst = previous_mst, + rlist = rlist, + old_rlist = old_rlist, + distance_type = mst_distances, + threads = threads, + gpu_graph = gpu_graph) else: sys.stderr.write('Cytoscape output requires a network file or lineage rank fit to be provided\n') sys.exit(1) @@ -756,6 +799,7 @@ def main(): args.overwrite, args.display_cluster, args.use_partial_query_graph, + args.extend_query_graph, args.recalculate_distances, args.tmp) diff --git a/test/batch123_info.csv b/test/batch123_info.csv new file mode 100644 index 00000000..eae6495a --- /dev/null +++ b/test/batch123_info.csv @@ -0,0 +1,5 @@ +id,Location +19183_4#55,CountryX +19183_4#48,CountryY +12754_4#89,CountryZ +12754_5#16,CountryZ diff --git a/test/run_test.py b/test/run_test.py index 09b4b72b..2b8eb42c 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -134,7 +134,7 @@ subprocess.run(python_cmd + " ../poppunk-runner.py --fit-model dbscan --ref-db batch12 --output batch12 --overwrite", shell=True, check=True) subprocess.run(python_cmd + " ../poppunk-runner.py --fit-model refine --ref-db batch12 --output batch12 --overwrite", shell=True, check=True) subprocess.run(python_cmd + " ../poppunk_assign-runner.py --db batch12 --query rfile3.txt --output batch3 --external-clustering batch12_external_clusters.csv --save-partial-query-graph --overwrite", shell=True, check=True) -subprocess.run(python_cmd + " ../poppunk_visualise-runner.py --ref-db batch12 --query-db batch3 --output batch123_viz --external-clustering batch12_external_clusters.csv --previous-query-clustering batch3/batch3_external_clusters.csv --cytoscape --rapidnj rapidnj --network-file ./batch3/batch3_graph.gt --use-partial-query-graph ./batch3/batch3_query.subset --recalculate-distances --overwrite", shell=True, check=True) +subprocess.run(python_cmd + " ../poppunk_visualise-runner.py --ref-db batch12 --query-db batch3 --output batch123_viz --external-clustering batch12_external_clusters.csv --previous-query-clustering batch3/batch3_external_clusters.csv --cytoscape --rapidnj rapidnj --use-partial-query-graph ./batch3/batch3_query.subset --recalculate-distances --extend-query-graph --info-csv batch123_info.csv --microreact --overwrite", shell=True, check=True) # citations sys.stderr.write("Printing citations\n")