diff --git a/nemo_curator/modules/exact_dedup.py b/nemo_curator/modules/exact_dedup.py
index d274d0e3..9e517291 100644
--- a/nemo_curator/modules/exact_dedup.py
+++ b/nemo_curator/modules/exact_dedup.py
@@ -18,7 +18,6 @@
 import time
 import warnings
 from contextlib import nullcontext
-from datetime import datetime
 from hashlib import md5
 from typing import Optional, Union
 
@@ -135,7 +134,7 @@ def hash_documents(
             # TODO: Generalize ty using self.hash_method
             return df.apply(lambda x: md5(x.encode()).hexdigest())
 
-    def __call__(self, dataset: DocumentDataset) -> Union[DocumentDataset, str]:
+    def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
         """
         Find document ID's for exact duplicates in a given DocumentDataset
         Parameters
@@ -173,3 +172,85 @@ def __call__(self, dataset: DocumentDataset) -> Union[DocumentDataset, str]:
         else:
             result_dataset = dd.read_parquet(write_path)
         return DocumentDataset(result_dataset)
+
+    def identify_and_remove_old(
+        self, dataset: DocumentDataset
+    ) -> Union[DocumentDataset, str]:
+        t0 = time.time()
+
+        duplicates = self._exact_dup_ids(dataset.df)
+        exact_docs_to_remove = duplicates.map_partitions(
+            lambda x: x[x._hashes.duplicated(keep="first")]
+        )[self.id_field]
+        exact_deduped_dataset = DocumentDataset(
+            dataset.df[~dataset.df[self.id_field].isin(exact_docs_to_remove.compute())]
+        )
+
+        if self.cache_dir is None:
+            self._logger.info(
+                f"Time taken for Partial Exact Dedup Computation = {time.time() - t0}s"
+            )
+
+            return exact_deduped_dataset
+
+        write_path = os.path.join(self.cache_dir, "_exact_deduplicated.parquet")
+        with performance_report_if_with_ts_suffix(
+            self.profile_dir,
+            "exact-dedup-profile",
+        ):
+            exact_deduped_dataset.to_parquet(write_path)
+
+        self._logger.info(
+            f"Time taken for Exact Dedup Computation = {time.time() - t0}s and output written at {write_path}"
+        )
+
+        return exact_deduped_dataset
+
+    def identify_and_remove(self, dataset: DocumentDataset) -> DocumentDataset:
+        """
+        Identify and remove exact duplicates from the given DocumentDataset
+        Parameters
+        ----------
+        dataset: DocumentDataset
+          The input dataset from which to remove exact duplicates
+        Returns
+        -------
+        DocumentDataset with exact duplicates removed, keeping the first occurrence
+        """
+        t0 = time.time()
+        df = dataset.df
+        df["_hashes"] = df[self.text_field].map_partitions(self.hash_documents)
+        shuffle_context = (
+            config.set({"dataframe.shuffle.method": "tasks"})
+            if DASK_P2P_ERROR
+            else nullcontext()
+        )
+        with shuffle_context:
+            deduplicated_df = (
+                df.shuffle(
+                    on=["_hashes"],
+                    ignore_index=True,
+                    npartitions=max(1, (df.npartitions // 3)),
+                )
+                .map_partitions(lambda x: x[~x["_hashes"].duplicated(keep="first")])
+                .drop(columns=["_hashes"])
+            )
+
+        if self.cache_dir is None:
+            self._logger.info(
+                f"Time taken for Partial Exact Dedup Computation = {time.time() - t0}s"
+            )
+
+            return DocumentDataset(deduplicated_df)
+
+        write_path = os.path.join(self.cache_dir, "_exact_deduplicated.parquet")
+        with performance_report_if_with_ts_suffix(
+            self.profile_dir,
+            "exact-dedup-profile",
+        ):
+            deduplicated_df.to_parquet(write_path, write_index=False, overwrite=True)
+
+        self._logger.info(
+            f"Time taken for Exact Dedup Computation = {time.time() - t0}s and output written at {write_path}"
+        )
+        return DocumentDataset(deduplicated_df)
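
For reviewers, here is a minimal usage sketch of the new `identify_and_remove` API added by this diff. The input path, the `"id"`/`"text"` field names, and the `cache_dir` value are placeholder assumptions rather than part of the change; only `DocumentDataset`, `ExactDuplicates`, and the method introduced above are assumed to behave as shown.

```python
# Hypothetical usage of ExactDuplicates.identify_and_remove() from this diff.
# "my_dataset/" and the "id"/"text" field names are placeholder assumptions.
from nemo_curator.datasets import DocumentDataset
from nemo_curator.modules import ExactDuplicates

# Read JSONL documents into a Dask-backed DocumentDataset.
dataset = DocumentDataset.read_json("my_dataset/")

exact_dedup = ExactDuplicates(
    id_field="id",
    text_field="text",
    # With cache_dir=None the deduplicated dataset is returned without being
    # written to Parquet; point this at a directory to also persist the result.
    cache_dir=None,
)

# Hash the text field, shuffle on the hash, keep only the first document per
# hash, and return the deduplicated DocumentDataset.
deduped = exact_dedup.identify_and_remove(dataset)
print(f"Documents remaining after exact dedup: {len(deduped.df)}")
```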