diff --git a/PopPUNK/__init__.py b/PopPUNK/__init__.py index d7f9d184..cbaf1fcf 100644 --- a/PopPUNK/__init__.py +++ b/PopPUNK/__init__.py @@ -3,7 +3,7 @@ '''PopPUNK (POPulation Partitioning Using Nucleotide Kmers)''' -__version__ = '2.7.1' +__version__ = '2.7.2' # Minimum sketchlib version SKETCHLIB_MAJOR = 2 diff --git a/PopPUNK/assign.py b/PopPUNK/assign.py index ec162121..13e5beee 100644 --- a/PopPUNK/assign.py +++ b/PopPUNK/assign.py @@ -50,6 +50,7 @@ def get_options(): oGroup.add_argument('--update-db', help='Update reference database with query sequences', default=False, action='store_true') oGroup.add_argument('--overwrite', help='Overwrite any existing database files', default=False, action='store_true') oGroup.add_argument('--graph-weights', help='Save within-strain Euclidean distances into the graph', default=False, action='store_true') + oGroup.add_argument('--save-partial-query-graph', help='Save the network components to which queries are assigned', default=False, action='store_true') # comparison metrics kmerGroup = parser.add_argument_group('Kmer comparison options') @@ -106,6 +107,9 @@ def get_options(): queryingGroup.add_argument('--accessory', help='(with a \'refine\' or \'lineage\' model) ' 'Use an accessory-distance only model for assigning queries ' '[default = False]', default=False, action='store_true') + queryingGroup.add_argument('--use-full-network', help='Use full network rather than reference network for querying [default = False]', + default = False, + action = 'store_true') # processing other = parser.add_argument_group('Other options') @@ -234,7 +238,8 @@ def main(): args.gpu_dist, args.gpu_graph, args.deviceid, - save_partial_query_graph=False) + args.save_partial_query_graph, + args.use_full_network) sys.stderr.write("\nDone\n") @@ -267,7 +272,8 @@ def assign_query(dbFuncs, gpu_dist, gpu_graph, deviceid, - save_partial_query_graph): + save_partial_query_graph, + use_full_network): """Code for assign query mode for CLI""" createDatabaseDir = dbFuncs['createDatabaseDir'] constructDatabase = dbFuncs['constructDatabase'] @@ -316,7 +322,8 @@ def assign_query(dbFuncs, accessory, gpu_dist, gpu_graph, - save_partial_query_graph) + save_partial_query_graph, + use_full_network) return(isolateClustering) def assign_query_hdf5(dbFuncs, @@ -341,7 +348,8 @@ def assign_query_hdf5(dbFuncs, accessory, gpu_dist, gpu_graph, - save_partial_query_graph): + save_partial_query_graph, + use_full_network): """Code for assign query mode taking hdf5 as input. Written as a separate function so it can be called by web APIs""" # Modules imported here as graph tool is very slow to load (it pulls in all of GTK?)
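For orientation, a minimal usage sketch of the two new assign options together, written in the style of test/run_test.py; the database and query file names below are those used in the tests, and the output prefix is an illustrative assumption rather than part of this patch:

import subprocess, sys

# Assign queries, saving only the network components the queries join
# (--save-partial-query-graph) and comparing against the full network
# rather than the pruned reference network (--use-full-network)
subprocess.run(sys.executable + " ../poppunk_assign-runner.py"
               " --db example_db --query some_queries.txt"
               " --output example_query --overwrite"
               " --save-partial-query-graph --use-full-network",
               shell=True, check=True)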
@@ -359,6 +367,7 @@ def assign_query_hdf5(dbFuncs, from .network import get_vertex_list from .network import printExternalClusters from .network import vertex_betweenness + from .network import remove_non_query_components from .qc import sketchlibAssemblyQC from .plot import writeClusterCsv @@ -453,7 +462,7 @@ def assign_query_hdf5(dbFuncs, ref_file_name = os.path.join(model_prefix, os.path.basename(model_prefix) + file_extension_string + ".refs") use_ref_graph = \ - os.path.isfile(ref_file_name) and not update_db and model.type != 'lineage' + os.path.isfile(ref_file_name) and not update_db and model.type != 'lineage' and not use_full_network if use_ref_graph: with open(ref_file_name) as refFile: for reference in refFile: @@ -791,12 +800,18 @@ def assign_query_hdf5(dbFuncs, output + "/" + os.path.basename(output) + db_suffix) else: storePickle(rNames, qNames, False, qrDistMat, dists_out) - if save_partial_query_graph and not serial: - if model.type == 'lineage': - save_network(genomeNetwork[min(model.ranks)], prefix = output, suffix = '_graph', use_gpu = gpu_graph) + if save_partial_query_graph: + if model.type == 'lineage': + genomeNetwork = genomeNetwork[min(model.ranks)] + genomeNetwork, pruned_isolate_lists = remove_non_query_components(genomeNetwork, rNames, qNames, use_gpu = gpu_graph) + if model.type == 'lineage' and not serial: + save_network(genomeNetwork, prefix = output, suffix = '_graph', use_gpu = gpu_graph) else: graph_suffix = file_extension_string + '_graph' save_network(genomeNetwork, prefix = output, suffix = graph_suffix, use_gpu = gpu_graph) + with open(f"{output}/{os.path.basename(output)}_query.subset",'w') as pruned_isolate_csv: + for isolate in pruned_isolate_lists: + pruned_isolate_csv.write(isolate + '\n') return(isolateClustering) diff --git a/PopPUNK/lineages.py b/PopPUNK/lineages.py index 36afe729..7b12f045 100755 --- a/PopPUNK/lineages.py +++ b/PopPUNK/lineages.py @@ -392,7 +392,8 @@ def query_db(args): accessory, args.gpu_dist, args.gpu_graph, - save_partial_query_graph = False) + save_partial_query_graph = False, + use_full_network = True) # Use full network - does not make sense to use references for lineages # Process clustering query_strains = {} @@ -439,7 +440,8 @@ def query_db(args): accessory, args.gpu_dist, args.gpu_graph, - save_partial_query_graph = False) + save_partial_query_graph = False, + use_full_network = True) overall_lineage[strain] = createOverallLineage(rank_list, lineageClustering) # Print combined strain and lineage clustering diff --git a/PopPUNK/models.py b/PopPUNK/models.py index c103c31e..279d6e6b 100644 --- a/PopPUNK/models.py +++ b/PopPUNK/models.py @@ -1124,7 +1124,9 @@ def __init__(self, outPrefix, ranks, max_search_depth, reciprocal_only, count_un ClusterFit.__init__(self, outPrefix) self.type = 'lineage' self.preprocess = False - self.max_search_depth = max_search_depth + self.max_search_depth = max_search_depth + 5 # Set to highest rank by default in main; need to store additional distances + # when there is redundancy (e.g.
reciprocal matching, unique distance counting) + # or other sequences may be pruned out of the database self.nn_dists = None # stores the unprocessed kNN at the maximum search depth self.ranks = [] for rank in sorted(ranks): @@ -1368,6 +1370,7 @@ def extend(self, qqDists, qrDists): qrRect, self.max_search_depth, self.threads) + # Update NN dist associated with model self.__save_sparse__(higher_rank[2], higher_rank[0], higher_rank[1], self.max_search_depth, n_ref + n_query, self.nn_dists.dtype, diff --git a/PopPUNK/network.py b/PopPUNK/network.py index 2568e934..d91fa699 100644 --- a/PopPUNK/network.py +++ b/PopPUNK/network.py @@ -1920,28 +1920,8 @@ def prune_graph(prefix, reflist, samples_to_keep, output_db_name, threads, use_g if os.path.exists(network_fn): network_found = True sys.stderr.write("Loading network from " + network_fn + "\n") - samples_to_keep_set = frozenset(samples_to_keep) G = load_network_file(network_fn, use_gpu = use_gpu) - if use_gpu: - # Identify indices - reference_indices = [i for (i,name) in enumerate(reflist) if name in samples_to_keep_set] - # Generate data frame - G_df = G.view_edge_list() - if 'src' in G_df.columns: - G_df.rename(columns={'src': 'source','dst': 'destination'}, inplace=True) - # Filter data frame - G_new_df = G_df[G_df['source'].isin(reference_indices) & G_df['destination'].isin(reference_indices)] - # Translate network indices to match name order - G_new = translate_network_indices(G_new_df, reference_indices) - else: - reference_vertex = G.new_vertex_property('bool') - for n, vertex in enumerate(G.vertices()): - if reflist[n] in samples_to_keep_set: - reference_vertex[vertex] = True - else: - reference_vertex[vertex] = False - G_new = gt.GraphView(G, vfilt = reference_vertex) - G_new = gt.Graph(G_new, prune = True) + G_new = remove_nodes_from_graph(G, reflist, samples_to_keep, use_gpu) save_network(G_new, prefix = output_db_name, suffix = '_graph', @@ -1949,3 +1929,91 @@ def prune_graph(prefix, reflist, samples_to_keep, output_db_name, threads, use_g use_gpu = use_gpu) if not network_found: sys.stderr.write('No network file found for pruning\n') + +def remove_nodes_from_graph(G, reflist, samples_to_keep, use_gpu): + """Return a modified graph containing only the requested nodes + + Args: + G (graph) + Network to be pruned + reflist (list) + Ordered list of sequences in the database + samples_to_keep (list) + The names of samples to be retained in the graph + use_gpu (bool) + Whether graph is a cugraph or not + [default = False] + + Returns: + G_new (graph) + Pruned graph + """ + samples_to_keep_set = frozenset(samples_to_keep) + if use_gpu: + # Identify indices + reference_indices = [i for (i,name) in enumerate(reflist) if name in samples_to_keep_set] + # Generate data frame + G_df = G.view_edge_list() + if 'src' in G_df.columns: + G_df.rename(columns={'src': 'source','dst': 'destination'}, inplace=True) + # Filter data frame + G_new_df = G_df[G_df['source'].isin(reference_indices) & G_df['destination'].isin(reference_indices)] + # Translate network indices to match name order + G_new = translate_network_indices(G_new_df, reference_indices) + else: + reference_vertex = G.new_vertex_property('bool') + for n, vertex in enumerate(G.vertices()): + if reflist[n] in samples_to_keep_set: + reference_vertex[vertex] = True + else: + reference_vertex[vertex] = False + G_new = gt.GraphView(G, vfilt = reference_vertex) + G_new = gt.Graph(G_new, prune = True) + return G_new + +def remove_non_query_components(G, rlist, qlist, use_gpu = False): + """ + Removes all components that do not contain a query sequence. + + Args: + G (graph) + Network of queries linked to reference sequences + rlist (list) + List of reference sequence labels + qlist (list) + List of query sequence labels + use_gpu (bool) + Whether graph is a cugraph or not + + Returns: + query_subgraph (graph) + The network pruned to the components containing queries + pruned_names (list) + The labels of the sequences in the pruned network + """ + combined_names = rlist + qlist + pruned_names = [] + if use_gpu: + sys.stderr.write('Saving partial query graphs is not compatible with GPU networks yet\n') + sys.exit(1) + else: + # Identify network components containing queries + component_dict = gt.label_components(G)[0] + components_with_query = set() + # The number of reference sequences is len(rlist) + # These are the first len(rlist) vertices in the graph + # Queries that have been added have indices >= len(rlist) + # Therefore these are the components to retain + for i in range(len(rlist), G.num_vertices()): + v = G.vertex(i) # Access vertex by index + components_with_query.add(component_dict[v]) + # Create a boolean filter based on the list of component IDs + query_filter = G.new_vertex_property("bool") + for v in G.vertices(): + query_filter[int(v)] = (component_dict[v] in components_with_query) + if query_filter[int(v)]: + pruned_names.append(combined_names[int(v)]) + # Create a filtered graph with only the specified components + query_subgraph = gt.GraphView(G, vfilt=query_filter) + + return query_subgraph, pruned_names diff --git a/PopPUNK/plot.py b/PopPUNK/plot.py index 6f7a995b..26663c62 100644 --- a/PopPUNK/plot.py +++ b/PopPUNK/plot.py @@ -511,7 +511,7 @@ def drawMST(mst, outPrefix, isolate_clustering, clustering_name, overwrite): output=graph2_file_name, output_size=(3000, 3000)) def outputsForCytoscape(G, G_mst, isolate_names, clustering, outPrefix, epiCsv, queryList = None, - suffix = None, writeCsv = True): + suffix = None, writeCsv = True, use_partial_query_graph = None): """Write outputs for cytoscape.
A graphml of the network, and CSV with metadata Args: @@ -536,6 +536,8 @@ def outputsForCytoscape(G, G_mst, isolate_names, clustering, outPrefix, epiCsv, (default = None) writeCsv (bool) Whether to print CSV file to accompany network + use_partial_query_graph (str) + File listing sequences to be included in output graph """ # Avoid circular import @@ -553,7 +555,8 @@ def outputsForCytoscape(G, G_mst, isolate_names, clustering, outPrefix, epiCsv, suffix = '_cytoscape' else: suffix = suffix + '_cytoscape' - save_network(G, prefix = outPrefix, suffix = suffix, use_graphml = True) + if use_partial_query_graph is None: + save_network(G, prefix = outPrefix, suffix = suffix, use_graphml = True) # Save each component too (useful for very large graphs) component_assignments, component_hist = gt.label_components(G) @@ -562,10 +565,10 @@ for vidx, v_component in enumerate(component_assignments.a): if v_component != component_idx: remove_list.append(vidx) - G_copy = G.copy() - G_copy.remove_vertex(remove_list) - save_network(G_copy, prefix = outPrefix, suffix = "_component_" + str(component_idx + 1), use_graphml = True) - del G_copy + # Copy the graph so each component is extracted from the unmodified network + G_copy = G.copy() + G_copy.remove_vertex(remove_list) + save_network(G_copy, prefix = outPrefix, suffix = "_component_" + str(component_idx + 1), use_graphml = True) if G_mst != None: isolate_labels = isolateNameToLabel(G_mst.vp.id) diff --git a/PopPUNK/sketchlib.py b/PopPUNK/sketchlib.py index df8b941d..29c0445a 100644 --- a/PopPUNK/sketchlib.py +++ b/PopPUNK/sketchlib.py @@ -213,7 +213,7 @@ def getSeqsInDb(dbname): return seqs -def joinDBs(db1, db2, output, update_random = None): +def joinDBs(db1, db2, output, update_random = None, full_names = False): """Join two sketch databases with the low-level HDF5 copy interface Args: @@ -226,10 +226,19 @@ def joinDBs(db1, db2, output, update_random = None): update_random (dict) Whether to re-calculate the random object. May contain control arguments strand_preserved and threads (see :func:`addRandom`) + full_names (bool) + If True, db1, db2 and output are the full paths to h5 files + """ - join_prefix = output + "/" + os.path.basename(output) - db1_name = db1 + "/" + os.path.basename(db1) + ".h5" - db2_name = db2 + "/" + os.path.basename(db2) + ".h5" + + if not full_names: + join_prefix = output + "/" + os.path.basename(output) + db1_name = db1 + "/" + os.path.basename(db1) + ".h5" + db2_name = db2 + "/" + os.path.basename(db2) + ".h5" + else: + db1_name = db1 + db2_name = db2 + join_prefix = output hdf1 = h5py.File(db1_name, 'r') hdf2 = h5py.File(db2_name, 'r') diff --git a/PopPUNK/utils.py b/PopPUNK/utils.py index d411ecbf..4e5eaa73 100644 --- a/PopPUNK/utils.py +++ b/PopPUNK/utils.py @@ -593,7 +593,7 @@ def check_and_set_gpu(use_gpu, gpu_lib, quit_on_fail = False): return use_gpu -def read_rlist_from_distance_pickle(fn, allow_non_self = True): +def read_rlist_from_distance_pickle(fn, allow_non_self = True, include_queries = False): """Return the list of reference sequences from a distance pickle.
Args: @@ -601,6 +601,8 @@ fn (str) Name of distance pickle allow_non_self (bool) Whether non-self distance datasets are permissible + include_queries (bool) + Whether queries should be included in the rlist Returns: rlist (list) List of reference sequence names @@ -611,6 +613,8 @@ sys.stderr.write("This analysis requires an all-v-all" " distance dataset\n") sys.exit(1) + if include_queries: + rlist = rlist + qlist return rlist def get_match_search_depth(rlist,rank_list): diff --git a/PopPUNK/visualise.py b/PopPUNK/visualise.py index 40ce0fff..c067ebc7 100644 --- a/PopPUNK/visualise.py +++ b/PopPUNK/visualise.py @@ -84,12 +84,19 @@ def get_options(): 'minimum spanning tree', default=None, type = str) + iGroup.add_argument('--recalculate-distances', + help='Recalculate pairwise distances rather than read them from a file', + default=False, + action = 'store_true') iGroup.add_argument('--network-file', help='Specify a file to use for any graph visualisations', type = str) iGroup.add_argument('--display-cluster', help='Column of clustering CSV to use for plotting', default=None) + iGroup.add_argument('--use-partial-query-graph', + help='File listing sequences in partial query graph after assignment', + default=None) # output options oGroup = parser.add_argument_group('Output options') @@ -159,6 +166,22 @@ def get_options(): return args +# Create temporary pruned database +def create_pruned_tmp_db(prefix, subset): + + from .sketchlib import removeFromDB + from .sketchlib import getSeqsInDb + + h5_name = prefix + "/" + os.path.basename(prefix) + ".h5" + tmp_h5_name = prefix + "/" + os.path.basename(prefix) + ".tmp.h5" + sequences_in_db = getSeqsInDb(h5_name) + removeFromDB(h5_name, + tmp_h5_name, + set(sequences_in_db) - subset, + full_names = True + ) + return tmp_h5_name, sequences_in_db + def generate_visualisations(query_db, ref_db, distances, @@ -190,6 +213,8 @@ def generate_visualisations(query_db, mst_distances, overwrite, display_cluster, + use_partial_query_graph, + recalculate_distances, tmp): from .models import loadClusterFit @@ -200,6 +225,7 @@ def generate_visualisations(query_db, from .network import cugraph_to_graph_tool from .network import save_network from .network import sparse_mat_to_network + from .network import remove_nodes_from_graph from .plot import drawMST from .plot import outputsForMicroreact @@ -210,7 +236,8 @@ def generate_visualisations(query_db, from .sketchlib import readDBParams from .sketchlib import addRandom - + from .sketchlib import joinDBs + from .sparse_mst import generate_mst_from_sparse_input from .trees import load_tree, generate_nj_tree, mst_to_phylogeny @@ -251,13 +278,13 @@ def generate_visualisations(query_db, sys.stderr.write("Cannot create output directory\n") sys.exit(1) - #******************************# - #* *# - #* Process dense or sparse *# - #* distances *# - #* *# - #******************************# + #*******************************# + #* *# + #* Extract subset of sequences *# + #* *# + #*******************************# + # Identify distance matrix for ordered names if distances is None: if query_db is None: distances = ref_db + "/" + os.path.basename(ref_db) + ".dists" @@ -266,17 +293,43 @@ def generate_visualisations(query_db, else: distances = distances + # Location and properties of reference database + ref_db_loc = ref_db + "/" + os.path.basename(ref_db) + kmers, sketch_sizes,
codon_phased = readDBParams(ref_db) + + # extract subset of distances if requested + combined_seq = read_rlist_from_distance_pickle(distances + '.pkl', include_queries = True) + all_seq = combined_seq # all_seq is an immutable record used for network parsing + if include_files is not None or use_partial_query_graph is not None: + viz_subset = set() + subset_file = include_files if include_files is not None else use_partial_query_graph + with open(subset_file, 'r') as assemblyFiles: + for assembly in assemblyFiles: + viz_subset.add(assembly.rstrip()) + if len(viz_subset.difference(combined_seq)) > 0: + sys.stderr.write("--include-files contains names not in --distances\n") + sys.stderr.write("Please assign distances before subsetting the database\n") + else: + viz_subset = None + + #******************************# + #* *# + #* Determine type of distance *# + #* to use *# + #* *# + #******************************# + # Determine whether to use sparse distances - combined_seq = None use_sparse = False use_dense = False - if (tree == "mst" or tree == "both") and rank_fit is not None: + if (tree == "nj" or tree == "both") or rank_fit is None: + use_dense = True + elif (tree == "mst" or tree == "both") and rank_fit is not None: # Set flag use_sparse = True # Read list of sequence names and sparse distance matrix - rlist = read_rlist_from_distance_pickle(distances + '.pkl') + rlist = combined_seq sparse_mat = sparse.load_npz(rank_fit) - combined_seq = rlist # Check previous distances have been supplied if building on a previous MST old_rlist = None if previous_distances is not None: @@ -284,95 +337,6 @@ def generate_visualisations(query_db, elif previous_mst is not None: sys.stderr.write('The prefix of the distance files used to create the previous MST' ' is needed to use the network') - if (tree == "nj" or tree == "both") or rank_fit == None: - use_dense = True - # Process dense distance matrix - rlist, qlist, self, complete_distMat = readPickle(distances) - if not self: - qr_distMat = complete_distMat - combined_seq = rlist + qlist - else: - rr_distMat = complete_distMat - combined_seq = rlist - - # Fill in qq-distances if required - if self == False: - sys.stderr.write("Note: Distances in " + distances + " are from assign mode\n" "Note: Distance will be extended to full all-vs-all distances\n" "Note: Re-run poppunk_assign with --update-db to avoid this\n") - ref_db_loc = ref_db + "/" + os.path.basename(ref_db) - rlist_original, qlist_original, self_ref, rr_distMat = readPickle(ref_db_loc + ".dists") - if not self_ref: - sys.stderr.write("Distances in " + ref_db + " not self all-vs-all either\n") - sys.exit(1) - kmers, sketch_sizes, codon_phased = readDBParams(query_db) - addRandom(query_db, qlist, kmers, - strand_preserved = strand_preserved, threads = threads) - query_db_loc = query_db + "/" + os.path.basename(query_db) - qq_distMat = pp_sketchlib.queryDatabase(ref_db_name=query_db_loc, - query_db_name=query_db_loc, - rList=qlist, - qList=qlist, - klist=kmers, - random_correct=True, - jaccard=False, - num_threads=threads, - use_gpu=gpu_dist, - device_id=deviceid) - - # If the assignment was run with references, qrDistMat will be incomplete - if rlist != rlist_original: - rlist = rlist_original - qr_distMat = pp_sketchlib.queryDatabase(ref_db_name=ref_db_loc, - query_db_name=query_db_loc, - rList=rlist, - qList=qlist, - klist=kmers, - random_correct=True, - jaccard=False, - num_threads=threads, - use_gpu=gpu_dist, - device_id=deviceid) - - else: - qlist = None - qr_distMat = None - qq_distMat = None - - # Turn long form matrices into square form - combined_seq, core_distMat, acc_distMat = \ - update_distance_matrices(rlist, rr_distMat, - qlist, qr_distMat, qq_distMat, - threads = threads) - - #*******************************# - #* *# - #* Extract subset of sequences *# - #* *# - #*******************************# - - # extract subset of distances if requested - all_seq = combined_seq - if include_files is not None: - viz_subset = set() - with open(include_files, 'r') as assemblyFiles: - for assembly in assemblyFiles: - viz_subset.add(assembly.rstrip()) - if len(viz_subset.difference(combined_seq)) > 0: - sys.stderr.write("--include-files contains names not in --distances\n") - - # Only keep found rows - row_slice = [True if name in viz_subset else False for name in combined_seq] - combined_seq = [name for name in combined_seq if name in viz_subset] - if use_sparse: - sparse_mat = sparse_mat[np.ix_(row_slice, row_slice)] - if use_dense: - if qlist != None: - qlist = list(viz_subset.intersection(qlist)) - core_distMat = core_distMat[np.ix_(row_slice, row_slice)] - acc_distMat = acc_distMat[np.ix_(row_slice, row_slice)] - else: - viz_subset = None #**********************************# #* *# @@ -438,7 +402,7 @@ def generate_visualisations(query_db, # Join clusters with query clusters if required if use_dense: - if not self: + if query_db is not None: if previous_query_clustering is not None: prev_query_clustering = previous_query_clustering else: @@ -450,6 +414,142 @@ def generate_visualisations(query_db, return_dict = True) isolateClustering = joinClusterDicts(isolateClustering, queryIsolateClustering) + #******************************# + #* *# + #* Process dense or sparse *# + #* distances *# + #* *# + #******************************# + + if (tree == "nj" or tree == "both") or (model.type == 'lineage' and rank_fit is None): + + # Either calculate or read distances + if recalculate_distances: + sys.stderr.write("Recalculating pairwise distances for tree construction\n") + + # Merge relevant sequences into a single database + sys.stderr.write("Generating merged database\n") + if viz_subset is not None: + sequences_to_analyse = list(viz_subset) + # Filter from reference database + tmp_ref_h5_file, rlist = create_pruned_tmp_db(ref_db, viz_subset) + else: + sequences_to_analyse = combined_seq + tmp_ref_h5_file = ref_db_loc + ".h5" + viz_db_name = output + "/" + os.path.basename(output) + if query_db is not None: + # Add from query database + query_db_loc = query_db + "/" + os.path.basename(query_db) + if viz_subset is not None: + tmp_query_h5_file, qlist = create_pruned_tmp_db(query_db, viz_subset) + else: + tmp_query_h5_file = query_db_loc + ".h5" + joinDBs(tmp_ref_h5_file, + tmp_query_h5_file, + viz_db_name, + full_names = True) + if viz_subset is not None: + os.remove(tmp_query_h5_file) + os.remove(tmp_ref_h5_file) + elif viz_subset is not None: + os.rename(tmp_ref_h5_file, viz_db_name + ".h5") + else: + viz_db_name = ref_db_loc + + # Generate distances + sys.stderr.write("Comparing sketches\n") + self = True + subset_distMat = pp_sketchlib.queryDatabase(ref_db_name=viz_db_name, + query_db_name=viz_db_name, + rList=sequences_to_analyse, + qList=sequences_to_analyse, + klist=kmers.tolist(), + random_correct=True, + jaccard=False, + num_threads=threads, + use_gpu = gpu_dist, + device_id = deviceid) + + # Convert distance matrix format + combined_seq, core_distMat, acc_distMat = \ + update_distance_matrices(sequences_to_analyse, + subset_distMat, + threads = threads) + + else: + sys.stderr.write("Reading pairwise distances for tree construction\n") + + # Process dense distance matrix + rlist, qlist, self,
complete_distMat = readPickle(distances) + if not self: + qr_distMat = complete_distMat + combined_seq = rlist + qlist + else: + rr_distMat = complete_distMat + combined_seq = rlist + + # Fill in qq-distances if required + if self == False: + sys.stderr.write("Note: Distances in " + distances + " are from assign mode\n" + "Note: Distance will be extended to full all-vs-all distances\n" + "Note: Re-run poppunk_assign with --update-db to avoid this\n") + rlist_original, qlist_original, self_ref, rr_distMat = readPickle(ref_db_loc + ".dists") + if not self_ref: + sys.stderr.write("Distances in " + ref_db + " not self all-vs-all either\n") + sys.exit(1) + kmers, sketch_sizes, codon_phased = readDBParams(query_db) + addRandom(query_db, qlist, kmers, + strand_preserved = strand_preserved, threads = threads) + query_db_loc = query_db + "/" + os.path.basename(query_db) + qq_distMat = pp_sketchlib.queryDatabase(ref_db_name=query_db_loc, + query_db_name=query_db_loc, + rList=qlist, + qList=qlist, + klist=kmers, + random_correct=True, + jaccard=False, + num_threads=threads, + use_gpu=gpu_dist, + device_id=deviceid) + + # If the assignment was run with references, qrDistMat will be incomplete + if rlist != rlist_original: + rlist = rlist_original + qr_distMat = pp_sketchlib.queryDatabase(ref_db_name=ref_db_loc, + query_db_name=query_db_loc, + rList=rlist, + qList=qlist, + klist=kmers, + random_correct=True, + jaccard=False, + num_threads=threads, + use_gpu=gpu_dist, + device_id=deviceid) + + else: + qlist = None + qr_distMat = None + qq_distMat = None + + # Turn long form matrices into square form + combined_seq, core_distMat, acc_distMat = \ + update_distance_matrices(rlist, rr_distMat, + qlist, qr_distMat, qq_distMat, + threads = threads) + + # Prune distance matrix if subsetting data + if viz_subset is not None: + row_slice = [True if name in viz_subset else False for name in combined_seq] + combined_seq = [name for name in combined_seq if name in viz_subset] + if use_sparse: + sparse_mat = sparse_mat[np.ix_(row_slice, row_slice)] + if use_dense: + if qlist != None: + qlist = list(viz_subset.intersection(qlist)) + core_distMat = core_distMat[np.ix_(row_slice, row_slice)] + acc_distMat = acc_distMat[np.ix_(row_slice, row_slice)] + + #*******************# #* *# #* Generate trees *# @@ -605,23 +699,26 @@ def generate_visualisations(query_db, if gpu_graph: genomeNetwork = cugraph_to_graph_tool(genomeNetwork, isolateNameToLabel(all_seq)) # Hard delete from network to remove samples (mask doesn't work neatly) - if viz_subset is not None: - remove_list = [] - for keep, idx in enumerate(row_slice): - if not keep: - remove_list.append(idx) - genomeNetwork.remove_vertex(remove_list) + if include_files is not None and use_partial_query_graph is None: + genomeNetwork = remove_nodes_from_graph(genomeNetwork, all_seq, viz_subset, use_gpu = gpu_graph) elif rank_fit is not None: genomeNetwork = sparse_mat_to_network(sparse_mat, combined_seq, use_gpu = gpu_graph) else: sys.stderr.write('Cytoscape output requires a network file or lineage rank fit to be provided\n') sys.exit(1) + # If network has been pruned then only use the appropriate subset of names, in network order - + # otherwise use all names for full network + node_labels = [name for name in all_seq if name in viz_subset] \ + if (use_partial_query_graph is not None or include_files is not None) \ + else combined_seq + sys.stderr.write('Preparing outputs for cytoscape\n') outputsForCytoscape(genomeNetwork, mst_graph, - combined_seq, + node_labels, isolateClustering, output, - info_csv) + info_csv, +
use_partial_query_graph = use_partial_query_graph) if model.type == 'lineage': sys.stderr.write("Note: Only support for output of cytoscape graph at lowest rank\n") @@ -663,6 +759,8 @@ def main(): args.mst_distances, args.overwrite, args.display_cluster, + args.use_partial_query_graph, + args.recalculate_distances, args.tmp) if __name__ == '__main__': diff --git a/docs/visualisation.rst b/docs/visualisation.rst index c9a359df..4697dadc 100644 --- a/docs/visualisation.rst +++ b/docs/visualisation.rst @@ -24,6 +24,14 @@ At least one of these output types must be specified as a flag. If you are running multiple visualisations this calculation will be completed every time. To avoid this re-run your assignment with ``--update-db``, which will add these distances in permanently. +If you are only interested in visualising sequences that are closely related to those in a set of +query sequences, then the quickest approach is to use the ``--save-partial-query-graph`` option +when assigning sequences. This will generate a list of sequences spanning both your queries and all +the reference database isolates that are in the same network components (i.e., excluding those that +are not related to your query isolates); this is contained in the file ``[prefix]_query.subset``. This list +can be passed to the visualisation command using the flag ``--use-partial-query-graph ./[prefix]/[prefix]_query.subset`` +to only visualise the parts of the network that are relevant to your query. + Common options -------------- Some typical commands for various input settings (with ``--microreact``, but this can diff --git a/src/extend.cpp b/src/extend.cpp index c053481d..9e7d9054 100644 --- a/src/extend.cpp +++ b/src/extend.cpp @@ -9,7 +9,8 @@ #include #include #include - +#include <iostream> +#include <pybind11/pybind11.h> const float epsilon = 1E-10; // Get indices where each row starts in the sparse matrix @@ -93,7 +93,7 @@ sparse_coo extend(const sparse_coo &sparse_rr_mat, // This is very similar, but merging two lists as input auto rr_it = rr_ordered_idx.cbegin(); auto qr_it = qr_ordered_idx.cbegin(); - while (qr_it != qr_ordered_idx.cend() && rr_it != rr_ordered_idx.cend()) { + while (qr_it != qr_ordered_idx.cend() || rr_it != rr_ordered_idx.cend()) { // Get the next smallest dist, and corresponding j long j; float dist; @@ -103,7 +103,7 @@ sparse_coo extend(const sparse_coo &sparse_rr_mat, j = *qr_it + nr_samples; dist = qr_dists[*qr_it]; ++qr_it; - } else { + } else if (rr_it != rr_ordered_idx.cend()) { if (i < nr_samples) { j = std::get<1>(sparse_rr_mat)[row_start_idx[i] + *rr_it]; } else { @@ -111,6 +111,9 @@ } dist = rr_dists[*rr_it]; ++rr_it; + } else { + std::cerr << "Insufficient distances for specified kNN value; try reducing the maximum search depth" << std::endl; + throw pybind11::key_error(); } if (j == i) { diff --git a/test/batch12_external_clusters.csv b/test/batch12_external_clusters.csv new file mode 100644 index 00000000..46cd9ecf --- /dev/null +++ b/test/batch12_external_clusters.csv @@ -0,0 +1,18 @@ +Taxon,Cluster +19183_4#67,CLUSTER1 +12754_5#57,CLUSTER1 +12673_8#34,CLUSTER1 +12754_5#55,CLUSTER1 +12754_4#85,CLUSTER1 +19183_4#69,CLUSTER1 +19183_4#66,CLUSTER1 +12754_5#71,CLUSTER1 +19183_4#55,CLUSTER2 +12754_4#89,CLUSTER2 +12754_5#73,CLUSTER3 +12754_4#79,CLUSTER3 +19183_4#63,CLUSTER4 +19183_4#59,CLUSTER5 +19183_4#48,CLUSTER6 +12754_5#37,CLUSTER7 +12754_4#71,CLUSTER8 diff --git a/test/example_set.tar.bz2 b/test/example_set.tar.bz2 index 5315e5fb..4fc61d9e 100644 Binary files
a/test/example_set.tar.bz2 and b/test/example_set.tar.bz2 differ diff --git a/test/rfile1.txt b/test/rfile1.txt index 4f388da2..3dbc8d77 100644 --- a/test/rfile1.txt +++ b/test/rfile1.txt @@ -1,3 +1,10 @@ -7 12673_8#24.contigs_velvet.fa -1 12673_8#34.contigs_velvet.fa -2 12673_8#43.contigs_velvet.fa +12754_5#73 12754_5#73.contigs_velvet.fa +12754_4#79 12754_4#79.contigs_velvet.fa +12754_4#71 12754_4#71.contigs_velvet.fa +19183_4#55 19183_4#55.contigs_velvet.fa +19183_4#59 19183_4#59.contigs_velvet.fa +12754_4#89 12754_4#89.contigs_velvet.fa +19183_4#48 19183_4#48.contigs_velvet.fa +12754_5#37 12754_5#37.contigs_velvet.fa +19183_4#63 19183_4#63.contigs_velvet.fa +12754_5#71 12754_5#71.contigs_velvet.fa diff --git a/test/rfile12.txt b/test/rfile12.txt index e4f63584..51f99d88 100644 --- a/test/rfile12.txt +++ b/test/rfile12.txt @@ -1,6 +1,17 @@ -7 12673_8#24.contigs_velvet.fa -1 12673_8#34.contigs_velvet.fa -2 12673_8#43.contigs_velvet.fa -6 12754_4#79.contigs_velvet.fa -4 12754_4#85.contigs_velvet.fa -5 12754_4#89.contigs_velvet.fa +12754_5#73 12754_5#73.contigs_velvet.fa +12754_4#79 12754_4#79.contigs_velvet.fa +12754_4#71 12754_4#71.contigs_velvet.fa +19183_4#55 19183_4#55.contigs_velvet.fa +19183_4#59 19183_4#59.contigs_velvet.fa +12754_4#89 12754_4#89.contigs_velvet.fa +19183_4#48 19183_4#48.contigs_velvet.fa +12754_5#37 12754_5#37.contigs_velvet.fa +19183_4#63 19183_4#63.contigs_velvet.fa +12754_5#71 12754_5#71.contigs_velvet.fa +19183_4#67 19183_4#67.contigs_velvet.fa +19183_4#69 19183_4#69.contigs_velvet.fa +12754_5#55 12754_5#55.contigs_velvet.fa +12754_4#85 12754_4#85.contigs_velvet.fa +12673_8#34 12673_8#34.contigs_velvet.fa +19183_4#66 19183_4#66.contigs_velvet.fa +12754_5#57 12754_5#57.contigs_velvet.fa diff --git a/test/rfile123.txt b/test/rfile123.txt index af5a0ead..eac6e460 100644 --- a/test/rfile123.txt +++ b/test/rfile123.txt @@ -1,9 +1,19 @@ -7 12673_8#24.contigs_velvet.fa -1 12673_8#34.contigs_velvet.fa -2 12673_8#43.contigs_velvet.fa -6 12754_4#79.contigs_velvet.fa -4 12754_4#85.contigs_velvet.fa -5 12754_4#89.contigs_velvet.fa -8 12754_5#73.contigs_velvet.fa -3 12754_5#78.contigs_velvet.fa -9 12754_5#71.contigs_velvet.fa +12754_5#73 12754_5#73.contigs_velvet.fa +12754_4#79 12754_4#79.contigs_velvet.fa +12754_4#71 12754_4#71.contigs_velvet.fa +19183_4#55 19183_4#55.contigs_velvet.fa +19183_4#59 19183_4#59.contigs_velvet.fa +12754_4#89 12754_4#89.contigs_velvet.fa +19183_4#48 19183_4#48.contigs_velvet.fa +12754_5#37 12754_5#37.contigs_velvet.fa +19183_4#63 19183_4#63.contigs_velvet.fa +12754_5#71 12754_5#71.contigs_velvet.fa +19183_4#67 19183_4#67.contigs_velvet.fa +19183_4#69 19183_4#69.contigs_velvet.fa +12754_5#55 12754_5#55.contigs_velvet.fa +12754_4#85 12754_4#85.contigs_velvet.fa +12673_8#34 12673_8#34.contigs_velvet.fa +19183_4#66 19183_4#66.contigs_velvet.fa +12754_5#57 12754_5#57.contigs_velvet.fa +12754_5#16 12754_5#16.contigs_velvet.fa +12754_5#88 12754_5#88.contigs_velvet.fa diff --git a/test/rfile2.txt b/test/rfile2.txt index 5f6e9a24..00f4f494 100644 --- a/test/rfile2.txt +++ b/test/rfile2.txt @@ -1,3 +1,7 @@ -6 12754_4#79.contigs_velvet.fa -4 12754_4#85.contigs_velvet.fa -5 12754_4#89.contigs_velvet.fa +19183_4#67 19183_4#67.contigs_velvet.fa +19183_4#69 19183_4#69.contigs_velvet.fa +12754_5#55 12754_5#55.contigs_velvet.fa +12754_4#85 12754_4#85.contigs_velvet.fa +12673_8#34 12673_8#34.contigs_velvet.fa +19183_4#66 19183_4#66.contigs_velvet.fa +12754_5#57 12754_5#57.contigs_velvet.fa diff --git a/test/rfile3.txt b/test/rfile3.txt index 23104358..ff9e0673 
100644 --- a/test/rfile3.txt +++ b/test/rfile3.txt @@ -1,3 +1,2 @@ -8 12754_5#73.contigs_velvet.fa -3 12754_5#78.contigs_velvet.fa -9 12754_5#71.contigs_velvet.fa +12754_5#16 12754_5#16.contigs_velvet.fa +12754_5#88 12754_5#88.contigs_velvet.fa diff --git a/test/run_test.py b/test/run_test.py index ed4e74eb..09b4b72b 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -81,6 +81,10 @@ subprocess.run(python_cmd + " ../poppunk_assign-runner.py --query inref_query.txt --db example_db --model-dir example_refine --output example_single_query --write-references", shell=True, check=True) # matched name, but should be renamed in the output subprocess.run(python_cmd + " ../poppunk_assign-runner.py --query some_queries.txt --db example_db --model-dir example_refine --model-dir example_lineages --output example_lineage_query --overwrite", shell=True, check=True) +#external clustering +sys.stderr.write("Running assign with external clustering (--fit-model refine)\n") +subprocess.run(python_cmd + " ../poppunk_assign-runner.py --query some_queries.txt --db example_db --model-dir example_refine --output example_query --overwrite --external-clustering example_external_clusters.csv", shell=True, check=True) + # test updating order is correct sys.stderr.write("Running distance matrix order check (--update-db)\n") subprocess.run(python_cmd + " test-update.py", shell=True, check=True) @@ -126,10 +130,11 @@ sys.exit(1) # beebop test -subprocess.run(python_cmd + " ../poppunk-runner.py --create-db --r-files rfile12.txt --output batch12 --overwrite", shell=True, check=True) -subprocess.run(python_cmd + " ../poppunk-runner.py --fit-model bgmm --D 2 --ref-db batch12 --overwrite", shell=True, check=True) -subprocess.run(python_cmd + " ../poppunk_assign-runner.py --db batch12 --query rfile3.txt --output batch3 --external-clustering batch12_external_clusters.csv --overwrite", shell=True, check=True) -subprocess.run(python_cmd + " ../poppunk_visualise-runner.py --ref-db batch12 --query-db batch3 --output batch123_viz --external-clustering batch12_external_clusters.csv --previous-query-clustering batch3/batch3_external_clusters.csv --cytoscape --rapidnj rapidnj --network-file ./batch12/batch12_graph.gt --overwrite", shell=True, check=True) +subprocess.run(python_cmd + " ../poppunk-runner.py --create-db --r-files rfile12.txt --min-k 13 --k-step 3 --output batch12 --overwrite", shell=True, check=True) +subprocess.run(python_cmd + " ../poppunk-runner.py --fit-model dbscan --ref-db batch12 --output batch12 --overwrite", shell=True, check=True) +subprocess.run(python_cmd + " ../poppunk-runner.py --fit-model refine --ref-db batch12 --output batch12 --overwrite", shell=True, check=True) +subprocess.run(python_cmd + " ../poppunk_assign-runner.py --db batch12 --query rfile3.txt --output batch3 --external-clustering batch12_external_clusters.csv --save-partial-query-graph --overwrite", shell=True, check=True) +subprocess.run(python_cmd + " ../poppunk_visualise-runner.py --ref-db batch12 --query-db batch3 --output batch123_viz --external-clustering batch12_external_clusters.csv --previous-query-clustering batch3/batch3_external_clusters.csv --cytoscape --rapidnj rapidnj --network-file ./batch3/batch3_graph.gt --use-partial-query-graph ./batch3/batch3_query.subset --recalculate-distances --overwrite", shell=True, check=True) # citations sys.stderr.write("Printing citations\n") diff --git a/test/test-update.py b/test/test-update.py index df8dbedc..e5870ac1 100755 --- a/test/test-update.py +++ b/test/test-update.py @@ -28,12 
+28,12 @@ def run_regression(x, y, threshold = 0.99): sys.stderr.write("Distance matrix order failed!\n") sys.exit(1) -def compare_sparse_matrices(d1,d2,r1,r2): +def compare_sparse_matrices(d1,d2,r1,r2,flag): d1_pairs = get_seq_tuples(d1.row,d1.col,r1) d2_pairs = get_seq_tuples(d2.row,d2.col,r2) d1_dists = [] d2_dists = [] - if (len(d1_pairs) != len(d2_pairs)): + if (len(d1_pairs) != len(d2_pairs) and flag == " "): # May not be equal if reciprocal/unique count sys.stderr.write("Distance matrix number of entries differ!\n") print(d1_pairs) print(d2_pairs) @@ -45,7 +45,6 @@ def compare_sparse_matrices(d1,d2,r1,r2): d1_dists.append(dist1) d2_dists.append(dist2) break - run_regression(np.asarray(d1_dists),np.asarray(d2_dists)) def get_seq_tuples(rows,cols,names): @@ -65,62 +64,62 @@ def old_get_seq_tuples(rows,cols): for lineage_option_string in [" "," --count-unique-distances ", " --reciprocal-only "," --count-unique-distances --reciprocal-only "]: - if lineage_option_string != " ": - print("\n*** Now running tests with lineage option" + lineage_option_string + "***\n") - - # Check distances after one query - - # Check that order is the same after doing 1 + 2 with --update-db, as doing all of 1 + 2 together - subprocess.run(python_cmd + " ../poppunk-runner.py --create-db --r-files rfile12.txt --output batch12 --overwrite", shell=True, check=True) - subprocess.run(python_cmd + " ../poppunk-runner.py --fit-model lineage --ref-db batch12 --ranks 1,2 --overwrite" + lineage_option_string,shell=True, check=True) - print(" ../poppunk-runner.py --fit-model lineage --ref-db batch12 --ranks 1,2 --overwrite\n\n") - subprocess.run(python_cmd + " ../poppunk-runner.py --create-db --r-files rfile1.txt --output batch1 --overwrite", shell=True, check=True) - subprocess.run(python_cmd + " ../poppunk-runner.py --fit-model lineage --ref-db batch1 --ranks 1,2 --overwrite" + lineage_option_string, shell=True, check=True) - print("../poppunk-runner.py --fit-model lineage --ref-db batch1 --ranks 1,2 --overwrite\n\n") - subprocess.run(python_cmd + " ../poppunk_assign-runner.py --db batch1 --query rfile2.txt --output batch2 --update-db --overwrite --max-a-dist 1", shell=True, check=True) - print(" ../poppunk_assign-runner.py --db batch1 --query rfile2.txt --output batch2 --update-db --overwrite --max-a-dist 1\n\n") - - # Load updated distance order - with open("batch2/batch2.dists.pkl", 'rb') as pickle_file: - rlist2, qlist, self = pickle.load(pickle_file) - -# Check sparse distances after one query - with open("batch12/batch12.dists.pkl", 'rb') as pickle_file: - rlist1, qlist1, self = pickle.load(pickle_file) - S1 = scipy.sparse.load_npz("batch12/batch12_rank_2_fit.npz") - S2 = scipy.sparse.load_npz("batch2/batch2_rank_2_fit.npz") - sys.stderr.write("Comparing sparse matrices at rank 2 after first query calculated with options " + lineage_option_string + "\n") - compare_sparse_matrices(S1,S2,rlist1,rlist2) - - # Check rank 1 - S3 = scipy.sparse.load_npz("batch12/batch12_rank_1_fit.npz") - S4 = scipy.sparse.load_npz("batch2/batch2_rank_1_fit.npz") - sys.stderr.write("Comparing sparse matrices at rank 1 after first query calculated with options " + lineage_option_string + "\n") - compare_sparse_matrices(S3,S4,rlist1,rlist2) - - # Check distances after second query - - # Check that order is the same after doing 1 + 2 + 3 with --update-db, as doing all of 1 + 2 + 3 together - subprocess.run(python_cmd + " ../poppunk-runner.py --create-db --r-files rfile123.txt --output batch123 --overwrite", shell=True, check=True) - 
subprocess.run(python_cmd + " ../poppunk-runner.py --fit-model lineage --ref-db batch123 --ranks 1,2 --overwrite" + lineage_option_string, shell=True, check=True) - print("../poppunk-runner.py --fit-model lineage --ref-db batch123 --ranks 1,2 --overwrite\n\n") - subprocess.run(python_cmd + " ../poppunk_assign-runner.py --db batch2 --query rfile3.txt --output batch3 --update-db --overwrite", shell=True, check=True) - print(python_cmd + " ../poppunk_assign-runner.py --db batch2 --query rfile3.txt --output batch3 --update-db --overwrite\n\n") - - # Load updated distances order - with open("batch3/batch3.dists.pkl", 'rb') as pickle_file: - rlist4, qlist, self = pickle.load(pickle_file) - - # Check sparse distances after second query - with open("batch123/batch123.dists.pkl", 'rb') as pickle_file: - rlist3, qlist, self = pickle.load(pickle_file) - S5 = scipy.sparse.load_npz("batch123/batch123_rank_2_fit.npz") - S6 = scipy.sparse.load_npz("batch3/batch3_rank_2_fit.npz") - sys.stderr.write("Comparing sparse matrices at rank 2 after second query calculated with options " + lineage_option_string + "\n") - compare_sparse_matrices(S5,S6,rlist3,rlist4) - - # Check rank 1 - S7 = scipy.sparse.load_npz("batch123/batch123_rank_1_fit.npz") - S8 = scipy.sparse.load_npz("batch3/batch3_rank_1_fit.npz") - sys.stderr.write("Comparing sparse matrices at rank 1 after second query calculated with options " + lineage_option_string + "\n") - compare_sparse_matrices(S7,S8,rlist3,rlist4) + if lineage_option_string != " ": + print("\n*** Now running tests with lineage option" + lineage_option_string + "***\n") + + # Check distances after one query + + # Check that order is the same after doing 1 + 2 with --update-db, as doing all of 1 + 2 together + subprocess.run(python_cmd + " ../poppunk-runner.py --create-db --r-files rfile12.txt --output batch12 --overwrite", shell=True, check=True) + subprocess.run(python_cmd + " ../poppunk-runner.py --fit-model lineage --ref-db batch12 --ranks 1,2 --overwrite" + lineage_option_string,shell=True, check=True) + print(" ../poppunk-runner.py --fit-model lineage --ref-db batch12 --ranks 1,2 --overwrite\n\n") + subprocess.run(python_cmd + " ../poppunk-runner.py --create-db --r-files rfile1.txt --output batch1 --overwrite", shell=True, check=True) + subprocess.run(python_cmd + " ../poppunk-runner.py --fit-model lineage --ref-db batch1 --ranks 1,2 --overwrite" + lineage_option_string, shell=True, check=True) + print("../poppunk-runner.py --fit-model lineage --ref-db batch1 --ranks 1,2 --overwrite\n\n") + subprocess.run(python_cmd + " ../poppunk_assign-runner.py --db batch1 --query rfile2.txt --output batch2 --update-db --overwrite --max-a-dist 1", shell=True, check=True) + print(" ../poppunk_assign-runner.py --db batch1 --query rfile2.txt --output batch2 --update-db --overwrite --max-a-dist 1\n\n") + + # Load updated distance order + with open("batch2/batch2.dists.pkl", 'rb') as pickle_file: + rlist2, qlist, self = pickle.load(pickle_file) + + # Check sparse distances after one query + with open("batch12/batch12.dists.pkl", 'rb') as pickle_file: + rlist1, qlist1, self = pickle.load(pickle_file) + S1 = scipy.sparse.load_npz("batch12/batch12_rank_2_fit.npz") + S2 = scipy.sparse.load_npz("batch2/batch2_rank_2_fit.npz") + sys.stderr.write("Comparing sparse matrices at rank 2 after first query calculated with options " + lineage_option_string + "\n") + compare_sparse_matrices(S1,S2,rlist1,rlist2,lineage_option_string) + + # Check rank 1 + S3 = 
scipy.sparse.load_npz("batch12/batch12_rank_1_fit.npz") + S4 = scipy.sparse.load_npz("batch2/batch2_rank_1_fit.npz") + sys.stderr.write("Comparing sparse matrices at rank 1 after first query calculated with options " + lineage_option_string + "\n") + compare_sparse_matrices(S3,S4,rlist1,rlist2,lineage_option_string) + + # Check distances after second query + + # Check that order is the same after doing 1 + 2 + 3 with --update-db, as doing all of 1 + 2 + 3 together + subprocess.run(python_cmd + " ../poppunk-runner.py --create-db --r-files rfile123.txt --output batch123 --overwrite", shell=True, check=True) + subprocess.run(python_cmd + " ../poppunk-runner.py --fit-model lineage --ref-db batch123 --ranks 1,2 --overwrite" + lineage_option_string, shell=True, check=True) + print("../poppunk-runner.py --fit-model lineage --ref-db batch123 --ranks 1,2 --overwrite\n\n") + subprocess.run(python_cmd + " ../poppunk_assign-runner.py --db batch2 --query rfile3.txt --output batch3 --update-db --overwrite", shell=True, check=True) + print(python_cmd + " ../poppunk_assign-runner.py --db batch2 --query rfile3.txt --output batch3 --update-db --overwrite\n\n") + + # Load updated distances order + with open("batch3/batch3.dists.pkl", 'rb') as pickle_file: + rlist4, qlist, self = pickle.load(pickle_file) + + # Check sparse distances after second query + with open("batch123/batch123.dists.pkl", 'rb') as pickle_file: + rlist3, qlist, self = pickle.load(pickle_file) + S5 = scipy.sparse.load_npz("batch123/batch123_rank_2_fit.npz") + S6 = scipy.sparse.load_npz("batch3/batch3_rank_2_fit.npz") + sys.stderr.write("Comparing sparse matrices at rank 2 after second query calculated with options " + lineage_option_string + "\n") + compare_sparse_matrices(S5,S6,rlist3,rlist4,lineage_option_string) + + # Check rank 1 + S7 = scipy.sparse.load_npz("batch123/batch123_rank_1_fit.npz") + S8 = scipy.sparse.load_npz("batch3/batch3_rank_1_fit.npz") + sys.stderr.write("Comparing sparse matrices at rank 1 after second query calculated with options " + lineage_option_string + "\n") + compare_sparse_matrices(S7,S8,rlist3,rlist4,lineage_option_string)
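To close, here is a self-contained sketch of the component-filtering technique used by remove_non_query_components() above. It assumes graph-tool is installed and, as in the patch, that vertices with index >= len(rlist) are queries; the toy graph and counts are illustrative only:

import graph_tool.all as gt

# Toy graph: vertices 0-2 are references, 3-4 are queries
g = gt.Graph(directed=False)
g.add_vertex(5)
g.add_edge_list([(0, 1), (2, 3)])   # component {0,1} contains no query
n_refs = 3

# Label components, then keep only those containing a query vertex
comp = gt.label_components(g)[0]
keep = {comp[g.vertex(i)] for i in range(n_refs, g.num_vertices())}
vfilt = g.new_vertex_property("bool")
for v in g.vertices():
    vfilt[v] = comp[v] in keep
sub = gt.Graph(gt.GraphView(g, vfilt=vfilt), prune=True)
print(sub.num_vertices())           # 3: vertices 2 and 3, plus the isolated query 4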