diff --git a/nemo_curator/modules/__init__.py b/nemo_curator/modules/__init__.py
index 7d168e7dc..87d824025 100644
--- a/nemo_curator/modules/__init__.py
+++ b/nemo_curator/modules/__init__.py
@@ -36,6 +36,9 @@
 FuzzyDuplicates = gpu_only_import_from(
     "nemo_curator.modules.fuzzy_dedup", "FuzzyDuplicates"
 )
+BucketsToEdges = gpu_only_import_from(
+    "nemo_curator.modules.fuzzy_dedup", "BucketsToEdges"
+)
 # Pytorch related imports must come after all imports that require cugraph,
 # because of context cleanup issues b/w pytorch and cugraph
 # See this issue: https://github.com/rapidsai/cugraph/issues/2718
@@ -55,6 +58,7 @@
     "Filter",
     "FuzzyDuplicatesConfig",
     "FuzzyDuplicates",
+    "BucketsToEdges",
     "LSH",
     "MinHash",
     "Modify",
diff --git a/nemo_curator/modules/config.py b/nemo_curator/modules/config.py
index eec5b42ed..1ef8a0fd5 100644
--- a/nemo_curator/modules/config.py
+++ b/nemo_curator/modules/config.py
@@ -84,9 +84,15 @@ def __post_init__(self):
             raise ValueError(
                 "Finding fuzzy duplicates requires a cache directory accessible via all workers to store intermediates"
             )
-        if not self.false_positive_check:
-            raise NotImplementedError(
-                "Skipping false positive checks is not supported at the moment"
+        if self.false_positive_check:
+            warnings.warn(
+                "Identifying false positives during the Minhash deduplication is computationally expensive."
+                " For improved performance consider setting this to False"
+            )
+        if not self.false_positive_check and self.char_ngrams < 20:
+            warnings.warn(
+                "Using a small char_ngrams value might lead to a large number (~5%) of false positives during deduplication."
+                " Using a value of at least 20 for char_ngrams is recommended."
             )
         if self.num_anchors <= 0:
             raise ValueError("Number of anchors must be greater than 0")
diff --git a/nemo_curator/modules/fuzzy_dedup.py b/nemo_curator/modules/fuzzy_dedup.py
index 6694dd420..556911a96 100644
--- a/nemo_curator/modules/fuzzy_dedup.py
+++ b/nemo_curator/modules/fuzzy_dedup.py
@@ -19,6 +19,7 @@
 import time
 import warnings
 from datetime import datetime
+from itertools import pairwise
 from typing import List, Tuple, Union
 
 import cudf
@@ -27,6 +28,8 @@
 import cupy as cp
 import dask_cudf
 import numpy as np
+import pandas as pd
+import pyarrow as pa
 from cugraph import MultiGraph
 from dask import dataframe as dd
 from dask.dataframe.shuffle import shuffle as dd_shuffle
@@ -431,32 +434,44 @@ def __init__(
             id_fields=[self.config.id_field],
             profile_dir=self.config.profile_dir,
         )
-        self.map_buckets = _MapBuckets(
-            id_fields=[self.config.id_field],
-            text_field=self.config.text_field,
-            logger=self._logger,
-            num_anchors=self.config.num_anchors,
-        )
-        self.jaccard_shuffle = _Shuffle(
-            id_fields=[self.config.id_field],
-            text_field=self.config.text_field,
-            logger=self._logger,
-            profile_dir=self.config.profile_dir,
-        )
-        self.jaccard_compute = JaccardSimilarity(
-            id_field=self.config.id_field,
-            text_field=self.config.text_field,
-            ngram_width=self.config.char_ngrams,
-            anchor_id_fields=[
-                f"anchor_{i}_{self.config.id_field}"
-                for i in range(self.config.num_anchors)
-            ],
+
+        if self.config.false_positive_check:
+            self.map_buckets = _MapBuckets(
+                id_fields=[self.config.id_field],
+                text_field=self.config.text_field,
+                logger=self._logger,
+                num_anchors=self.config.num_anchors,
+            )
+            self.jaccard_shuffle = _Shuffle(
+                id_fields=[self.config.id_field],
+                text_field=self.config.text_field,
+                logger=self._logger,
+                profile_dir=self.config.profile_dir,
+            )
+            self.jaccard_compute = JaccardSimilarity(
+                id_field=self.config.id_field,
+                text_field=self.config.text_field,
+                ngram_width=self.config.char_ngrams,
+                anchor_id_fields=[
+                    f"anchor_{i}_{self.config.id_field}"
+                    for i in range(self.config.num_anchors)
+                ],
+            )
+        else:
+            self.buckets_to_edges = BucketsToEdges(
+                cache_dir=self.config.cache_dir,
+                id_fields=self.config.id_field,
+                logger=self._logger,
+            )
+
+        jaccard_pairs_fname = (
+            "jaccard_similarity_results.parquet"
+            if self.config.false_positive_check
+            else "_edges.parquet"
         )
         self.connected_components = ConnectedComponents(
             cache_dir=self.config.cache_dir,
-            jaccard_pairs_path=os.path.join(
-                self.config.cache_dir, "jaccard_similarity_results.parquet"
-            ),
+            jaccard_pairs_path=os.path.join(self.config.cache_dir, jaccard_pairs_fname),
             id_column=self.config.id_field,
             convert_str_ids=False,
             jaccard_threshold=self.config.jaccard_threshold,
@@ -475,62 +490,200 @@ def __call__(self, dataset: DocumentDataset):
         they belong to. Documents in the same group are near duplicates.
         """
         # Minhash + LSH
-        print("Stage1: Starting Minhash + LSH computation")
+        stage_num = 1
+        print(f"Stage{stage_num}: Starting Minhash + LSH computation")
         minhashLSH = Sequential([self.minhash, self.lsh])
         buckets_df = minhashLSH(dataset)
-        print("Stage1: Minhash + LSH complete!")
+        print(f"Stage{stage_num}: Minhash + LSH complete!")
+        stage_num += 1
+
+        if self.config.false_positive_check:
+            # Map buckets to lower cardinality distribution
+            print(f"Stage{stage_num} (False Positive Check): Starting Map_Buckets")
+            ddf_mapped_buckets_w_anchors = self.map_buckets.map_buckets_with_anchors(
+                documents_df=dataset.df, buckets_df=buckets_df.df
+            )
+            mapped_buckets_w_anchors_path = os.path.join(
+                self.config.cache_dir, "anchor_docs_with_bk.parquet"
+            )
+            ddf_mapped_buckets_w_anchors.to_parquet(
+                mapped_buckets_w_anchors_path, write_index=False
+            )
+            print(f"Stage{stage_num} (False Positive Check): Map_Buckets Complete!")
+            stage_num += 1
-        # Map buckets to lower cardinality distribution
-        print("Stage2 (False Postive Check): Starting Map_Buckets")
-        ddf_mapped_buckets_w_anchors = self.map_buckets.map_buckets_with_anchors(
-            documents_df=dataset.df, buckets_df=buckets_df.df
-        )
-        mapped_buckets_w_anchors_path = os.path.join(
-            self.config.cache_dir, "anchor_docs_with_bk.parquet"
-        )
-        ddf_mapped_buckets_w_anchors.to_parquet(
-            mapped_buckets_w_anchors_path, write_index=False
-        )
-        print("Stage2 (False Postive Check): Map_Buckets Complete!")
+            # Shuffle documents based on mapped buckets
+            print(f"Stage{stage_num} (False Positive Check): Shuffle docs")
+            shuffled_docs_path = os.path.join(
+                self.config.cache_dir, "shuffled_docs.parquet"
+            )
+            self.jaccard_shuffle.shuffle_docs_on_buckets(
+                documents_df=dataset.df,
+                bucket_w_anchors_path=mapped_buckets_w_anchors_path,
+                output_shuffled_docs_path=shuffled_docs_path,
+                bucket_mapping_df_blocksize=256,
+                parts_per_worker=1,
+                bucket_parts_per_worker=8,
+            )
+            print(f"Stage{stage_num} (False Positive Check): Shuffle docs complete!")
+            stage_num += 1
-        # Shuffle documents based on mapped buckets
-        print("Stage3 (False Postive Check): Shuffle docs")
-        shuffled_docs_path = os.path.join(
-            self.config.cache_dir, "shuffled_docs.parquet"
-        )
-        self.jaccard_shuffle.shuffle_docs_on_buckets(
-            documents_df=dataset.df,
-            bucket_w_anchors_path=mapped_buckets_w_anchors_path,
-            output_shuffled_docs_path=shuffled_docs_path,
-            bucket_mapping_df_blocksize=256,
-            parts_per_worker=1,
-            bucket_parts_per_worker=8,
-        )
-        print("Stage3 (False Postive Check): Shuffle docs complete!")
+            # Jaccard comparison within buckets
+            print(
f"Stage{stage_num} (False Postive Check): Jaccard Similarity in Buckets" + ) + jaccard_pairs_path = os.path.join( + self.config.cache_dir, "jaccard_similarity_results.parquet" + ) + jaccard_pairs_df = self.jaccard_compute.jaccard_compute( + shuffled_docs_path=shuffled_docs_path + ) + jaccard_pairs_df.to_parquet( + jaccard_pairs_path, + write_index=False, + write_metadata_file=False, + ) + print( + f"Stage{stage_num} (False Postive Check): Jaccard Similarity in Buckets Complete!" + ) + stage_num += 1 - # jaccard comparision within buckets - print("Stage4 (False Postive Check): Jaccard Similarity in Buckets") - jaccard_pairs_path = os.path.join( - self.config.cache_dir, "jaccard_similarity_results.parquet" - ) - jaccard_pairs_df = self.jaccard_compute.jaccard_compute( - shuffled_docs_path=shuffled_docs_path - ) - jaccard_pairs_df.to_parquet( - jaccard_pairs_path, - write_index=False, - write_metadata_file=False, - ) - print("Stage4 (False Postive Check): Jaccard Similarity in Buckets Complete!") + else: + # Map buckets to lower cardinality distribution + print(f"Stage{stage_num}: Starting LSH Buckets to Graph edgelist") + self.buckets_to_edges(buckets_df) + print(f"Stage{stage_num}: Starting LSH Buckets to Graph edgelist Complete!") + stage_num += 1 # Connected components across buckets - print("Stage5: Connected Components across buckets") + print(f"Stage{stage_num}: Connected Components across buckets") cc_path = os.path.join(self.config.cache_dir, "connected_components.parquet") self.connected_components.cc_workflow(cc_path) - print("Stage5: Connected Components across buckets complete!") + print(f"Stage{stage_num}: Connected Components across buckets complete!") + stage_num += 1 + return DocumentDataset(dask_cudf.read_parquet(cc_path, split_row_groups=False)) +class BucketsToEdges: + """ + Maps buckets generated from LSH into an edgelist that + can be processed further by Connected Components to find duplicate + documents + """ + + def __init__( + self, + cache_dir: str = None, + id_fields: Union[list, str] = "id", + str_id_name: str = "id", + bucket_field: str = "_bucket_id", + logger: Union[logging.LoggerAdapter, str] = "./", + ): + """ + Parameters + ---------- + cache_dir: str or None + If specified, will compute & write the edgelist to a file + id_fields: list or str + id fields of documents in buckets_df + str_id_name: str + Ignored if there is a single id field. Multiple id fields + will be combined into a single id field with the given name. + bucket_field: str + Column denoting bucket ID + num_buckets: Number of bands/buckets to create from the minhash signature. 
+            Existing logger to log to, or a path to a directory where a new
+            'Buckets_to_Edges.log' file will be created.
+        """
+        self.cache_dir = cache_dir
+        self.id_fields = [id_fields] if isinstance(id_fields, str) else id_fields
+        self.str_id_name = str_id_name if len(self.id_fields) > 1 else self.id_fields[0]
+        self.output_ids = [f"{self.str_id_name}_x", f"{self.str_id_name}_y"]
+        self.bucket_field = bucket_field
+        if isinstance(logger, str):
+            self._logger = create_logger(
+                rank=0,
+                log_file=os.path.join(logger, "Buckets_to_Edges.log"),
+                name="Buckets_to_Edges",
+            )
+        else:
+            self._logger = logger
+
+    @staticmethod
+    def _combine_multiple_ids(
+        input_df: cudf.DataFrame, input_id_fields: list, output_id_field: str
+    ) -> cudf.DataFrame:
+        if output_id_field in input_df.columns:
+            raise ValueError(
+                f"Input df already contains column named: {output_id_field}"
+            )
+
+        output_df = input_df.copy()[input_df.columns.difference(input_id_fields)]
+
+        output_df[output_id_field] = input_df[input_id_fields[0]].astype(str)
+        for input_field in input_id_fields[1:]:
+            output_df[output_id_field] = (
+                output_df[output_id_field]
+                + "-"
+                + input_df[input_field].astype(str)
+            )
+
+        return output_df
+
+    def buckets_to_edges(
+        self,
+        buckets_df: cudf.DataFrame,
+    ) -> cudf.DataFrame:
+
+        grouped_buckets = (
+            buckets_df.groupby(self.bucket_field)[self.str_id_name]
+            .agg(list)
+            .list.sort_values()
+        )
+        bucket_docs = grouped_buckets.to_arrow().to_pylist()
+        edges = []
+        # Pair up consecutive documents within a bucket since they are near duplicates;
+        # connected components later links these chains, so this is effectively an
+        # edge list of all near-duplicate documents
+        for bucket_doc in bucket_docs:
+            edges.extend(pairwise(bucket_doc))
+        edges = pd.DataFrame(edges, columns=self.output_ids)
+        edges = pa.Table.from_pandas(edges)
+        result_df = cudf.DataFrame.from_arrow(edges)
+        del edges
+        result_df = result_df.drop_duplicates(self.output_ids).reset_index(drop=True)
+        result_df["jaccard"] = np.float32(1.0)
+        return result_df
+
+    def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
+        buckets_df = dataset.df
+        if len(self.id_fields) > 1:
+            buckets_df = buckets_df.map_partitions(
+                BucketsToEdges._combine_multiple_ids,
+                input_id_fields=self.id_fields,
+                output_id_field=self.str_id_name,
+            )
+
+        meta = [(output_id, str) for output_id in self.output_ids]
+        meta.append(("jaccard", np.float32))
+        edges_df = buckets_df.map_partitions(self.buckets_to_edges, meta=meta)
+
+        if self.cache_dir is None:
+            return DocumentDataset(edges_df)
+
+        write_path = os.path.join(self.cache_dir, "_edges.parquet")
+        if os.path.exists(write_path):
+            warnings.warn(
+                f"Output path {write_path} already exists and will be overwritten"
+            )
+        t0 = time.time()
+        edges_df.to_parquet(write_path, write_index=False, overwrite=True)
+        self._logger.info(f"Converted buckets to edgelist took {time.time() - t0} s")
+
+        return DocumentDataset(
+            dask_cudf.read_parquet(write_path, split_row_groups=False)
+        )
+
+
 class _MapBuckets:
     """
     buckets to a logical partition by using a modified bin packing algorithm.
diff --git a/nemo_curator/scripts/fuzzy_deduplication/buckets_to_edges.py b/nemo_curator/scripts/fuzzy_deduplication/buckets_to_edges.py
new file mode 100644
index 000000000..5ebdf2771
--- /dev/null
+++ b/nemo_curator/scripts/fuzzy_deduplication/buckets_to_edges.py
@@ -0,0 +1,91 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import time
+
+import dask_cudf
+
+from nemo_curator import BucketsToEdges
+from nemo_curator.datasets import DocumentDataset
+from nemo_curator.log import create_logger
+from nemo_curator.utils.distributed_utils import get_client, get_num_workers
+from nemo_curator.utils.script_utils import ArgumentHelper
+
+
+def attach_args(parser=None):
+    description = """Takes the buckets generated from minhashes and converts
+    them into an edge list for the connected components algorithm. This is done by
+    assuming all documents in the same bucket are similar.
+    """
+    if not parser:
+        parser = ArgumentHelper.parse_gpu_dedup_args(description=description)
+    parser.add_argument(
+        "--input-bucket-dir",
+        type=str,
+        help="The directory containing the LSH bucket output files",
+    )
+    parser.add_argument(
+        "--input-bucket-field",
+        type=str,
+        default="_bucket_id",
+        help="Name of the column containing the bucket id",
+    )
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        help="Output dir to write results",
+    )
+    return parser
+
+
+def main(args):
+    logger = create_logger(
+        rank=0,
+        log_file=os.path.join(args.log_dir, "rank_000.log"),
+        name="buckets_to_cc_log",
+    )
+
+    input_bucket_path = args.input_bucket_dir
+    OUTPUT_PATH = args.output_dir
+
+    client = get_client(**ArgumentHelper.parse_client_args(args))
+    logger.info(f"Client Created {client}")
+    logger.info(f"Num Workers = {get_num_workers(client)}")
+    logger.info(
+        "Running buckets -> EdgeList for CC",
+    )
+
+    buckets_to_edges = BucketsToEdges(
+        cache_dir=OUTPUT_PATH,
+        id_fields=["dataset_id", "doc_id"],
+        str_id_name=args.input_json_id_field,
+        bucket_field=args.input_bucket_field,
+        logger=logger,
+    )
+    st = time.time()
+    buckets_df = DocumentDataset(
+        dask_cudf.read_parquet(input_bucket_path, split_row_groups=False)
+    )
+    _ = buckets_to_edges(buckets_df)
+    et = time.time()
+    logger.info(f"Bucket to Edges conversion took = {et-st} s")
+
+
+def console_script():
+    main(attach_args().parse_args())
+
+
+if __name__ == "__main__":
+    main(attach_args().parse_args())
diff --git a/nemo_curator/utils/fuzzy_dedup_utils/shuffle_utils.py b/nemo_curator/utils/fuzzy_dedup_utils/shuffle_utils.py
index 9d1915603..323fe7e81 100644
--- a/nemo_curator/utils/fuzzy_dedup_utils/shuffle_utils.py
+++ b/nemo_curator/utils/fuzzy_dedup_utils/shuffle_utils.py
@@ -66,7 +66,7 @@ def rearange_by_column_direct(
     return rearrange_by_column(
         df,
         col=col,
-        shuffle="tasks",
+        shuffle_method="tasks",
         # Prevent staged shuffling by setting max_branch
         # to the number of input partitions + 1
         max_branch=npartitions + 1,
diff --git a/tests/test_fuzzy_dedup.py b/tests/test_fuzzy_dedup.py
index c92e3662b..fb831cb16 100644
--- a/tests/test_fuzzy_dedup.py
+++ b/tests/test_fuzzy_dedup.py
@@ -182,7 +182,7 @@ def test_lsh(self, tmpdir, buckets_per_shuffle):
         )
         buckets = lsh(self.dataset)
         buckets_df = buckets.df
-        docs_list = buckets_df.groupby("_bucket_id").id.collect()
+        docs_list = buckets_df.groupby("_bucket_id").id.agg(list)
         expected_df = cudf.Series([[1, 2], [2, 3], [4, 5]], name="id")
         assert_eq(expected_df, docs_list, check_index=False)
@@ -257,7 +257,7 @@ def test_fuzzy_dedup(
         result_df = result.df.compute()
         # Drop non duplicated docs
         result_df = result_df[result_df.group.duplicated(keep=False)]
-        result_df = result_df.groupby("group").id.collect()
+        result_df = result_df.groupby("group").id.agg(list)
         # Sort to maintain uniform ordering
         result_df = result_df.list.sort_values()
@@ -287,7 +287,7 @@ def test_different_fields(self, fuzzy_dedup_data, tmpdir):
         result_df = result.df.compute()
         # Drop non duplicated docs
         result_df = result_df[result_df.group.duplicated(keep=False)]
-        result_df = result_df.groupby("group")["col0"].collect()
+        result_df = result_df.groupby("group")["col0"].agg(list)
         # Sort to maintain uniform ordering
         result_df = result_df.list.sort_values()
         result_df = result_df.sort_values()
@@ -339,7 +339,7 @@ def test_non_uniform_indices(
         result_df = result.df.compute()
         # Drop non duplicated docs
         result_df = result_df[result_df.group.duplicated(keep=False)]
-        result_df = result_df.groupby("group").id.collect()
+        result_df = result_df.groupby("group").id.agg(list)
         # Sort to maintain uniform ordering
         result_df = result_df.list.sort_values()
@@ -372,6 +372,47 @@ def test_num_anchors(self, large_fuzzy_dedup_data, num_anchors, tmpdir):
         ).columns
         assert all(f"anchor_{i}_id" in anchor_docs_df_cols for i in range(num_anchors))
 
+    @pytest.mark.parametrize("use_64_bit_hash", [False, True])
+    @pytest.mark.parametrize(
+        "num_buckets,duplicate_docs",
+        # Duplicated docs estimated from true_jaccard values
+        [
+            (10, [[4, -1], [1, 2, 300]]),
+            (3, [[4, -1], [1, 2, 300]]),
+        ],
+    )
+    def test_no_fp_check(
+        self, fuzzy_dedup_data, use_64_bit_hash, num_buckets, duplicate_docs, tmpdir
+    ):
+        config = FuzzyDuplicatesConfig(
+            cache_dir=tmpdir,
+            id_field="id",
+            text_field="text",
+            seed=42,
+            char_ngrams=5,
+            num_buckets=num_buckets,
+            hashes_per_bucket=1,
+            use_64_bit_hash=use_64_bit_hash,
+            buckets_per_shuffle=5,
+            false_positive_check=False,
+            num_anchors=2,
+            jaccard_threshold=0.39,
+        )
+        fuzzy_duplicates = FuzzyDuplicates(config=config)
+        result = fuzzy_duplicates(fuzzy_dedup_data)
+        result_df = result.df.compute()
+        # Drop non duplicated docs
+        result_df = result_df[result_df.group.duplicated(keep=False)]
+        result_df = result_df.groupby("group").id.agg(list)
+        # Sort to maintain uniform ordering
+
+        result_df = result_df.list.sort_values()
+        result_df = result_df.sort_values()
+        expected_df = cudf.Series(duplicate_docs, name="id")
+        expected_df = expected_df.list.sort_values()
+        expected_df = expected_df.sort_values()
+        assert_eq(expected_df, result_df, check_index=False)
+
 
 class TestFuzzyDuplicatesConfig:
     def test_bad_inputs(self, tmpdir):
         with pytest.warns(
             UserWarning, match="Using a higher number of anchor docs might"
         ):
             FuzzyDuplicatesConfig(cache_dir=tmpdir, num_anchors=3)
+        with pytest.warns(
+            UserWarning, match="Using a small char_ngrams value might lead"
+        ):
+            FuzzyDuplicatesConfig(
+                cache_dir=tmpdir, char_ngrams=10, false_positive_check=False
+            )
+        with pytest.warns(
+            UserWarning,
+            match="Identifying false positives during the Minhash deduplication is computationally expensive",
+        ):
+            FuzzyDuplicatesConfig(cache_dir=tmpdir, false_positive_check=True)
         with pytest.raises(ValueError):
             FuzzyDuplicatesConfig(cache_dir=tmpdir, jaccard_threshold=1.2)
-        with pytest.raises(NotImplementedError):
-            FuzzyDuplicatesConfig(cache_dir=tmpdir, false_positive_check=False)
         with pytest.raises(ValueError):
             FuzzyDuplicatesConfig(cache_dir=tmpdir, buckets_per_shuffle=0)
@@ -393,8 +443,9 @@ def test_from_yaml(self, tmpdir):
             "cache_dir": "./",
             "num_anchors": 2,
             "jaccard_threshold": 0.8,
-            "false_positive_check": True,
+            "false_positive_check": False,
             "buckets_per_shuffle": 1,
+            "char_ngrams": 20,
         }
         with open(tmpdir / "config.yaml", "w") as f:
             yaml.dump(yaml_params, f)
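
Reviewer note: a minimal usage sketch of the new code path introduced by this diff, i.e. running FuzzyDuplicates with the false positive check disabled so that LSH buckets flow through BucketsToEdges into connected components. This is illustrative only; the input/cache paths and the chosen char_ngrams value are hypothetical, while the imports, config fields, and the "group" output column mirror what the diff and its tests use.

import dask_cudf

from nemo_curator import FuzzyDuplicates, FuzzyDuplicatesConfig
from nemo_curator.datasets import DocumentDataset

# Hypothetical corpus location; any cuDF-backed dataset with "id" and "text" columns works.
dataset = DocumentDataset(dask_cudf.read_parquet("/path/to/corpus"))

config = FuzzyDuplicatesConfig(
    cache_dir="/path/to/fuzzy_dedup_cache",
    id_field="id",
    text_field="text",
    # New in this diff: skip the expensive Jaccard-based false positive check.
    # LSH buckets are converted straight into an edge list by BucketsToEdges.
    false_positive_check=False,
    # With the check disabled, the config now warns if char_ngrams < 20.
    char_ngrams=20,
)

# Output contract is unchanged: one row per document with its duplicate "group" id.
duplicates = FuzzyDuplicates(config=config)(dataset)
result_df = duplicates.df.compute()

The same bucket-to-edge conversion is also exposed as a standalone stage via the new nemo_curator/scripts/fuzzy_deduplication/buckets_to_edges.py script, for pipelines that run each fuzzy-deduplication step separately.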