Removal logic for Exact / Fuzzy Dedup #499
@@ -0,0 +1,130 @@
import warnings
from abc import ABC
from typing import Optional

import dask.dataframe as dd

from nemo_curator.datasets.doc_dataset import DocumentDataset


def _perform_removal(
    left: dd.DataFrame,
    duplicates: dd.DataFrame,
    id_field: str,
    group_field: str,
) -> dd.DataFrame:
    # Create a new column name for temporary ID storage during merge
    new_id_field = f"{id_field}_new"

    duplicates_to_remove = (
        duplicates
        # Redistribute data across partitions so that all duplicates are in same partition
        .shuffle(on=[group_field], ignore_index=True)
        # For each partition, keep only the duplicated rows (excluding first occurrence)
        .map_partitions(lambda x: x[x[group_field].duplicated(keep="first")]).drop(
            columns=group_field
        )
        # Rename the ID field to avoid conflicts in the upcoming merge
        .rename(columns={id_field: new_id_field})[[new_id_field]]
    )

    merge = left.merge(
        right=duplicates_to_remove,
        how="left",
        broadcast=True,  # Broadcast smaller DataFrame to all partitions
        left_on=id_field,
        right_on=new_id_field,
    )

    # Keep only the rows that did not match anything in duplicates_to_remove,
    # i.e. drop every document that appears in duplicates_to_remove
    removed_result = merge[merge[new_id_field].isna()].drop(columns=[new_id_field])
    return removed_result


class Deduplicator(ABC):
Review comment: Not sure how I feel about the abstraction. I have been wanting something like this, but I worry this is not as generalizable as I'd want it to be. For example, can semantic dedupe use this? I don't believe it can, since its duplicates aren't all grouped like this. IMO, if the deduplication abstraction doesn't work for all our deduplication methods, I don't want to have it, so we don't confuse our users. We can always refactor the logic out into a base class if we find a general solution.

Review comment: Maybe call it [suggested change] or …

Review comment: Possible example usage: … ?
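To make the intended API concrete, here is one possible shape such usage could take. This is only an illustrative sketch: the MyDeduplicator subclass, its arguments, the column names, and my_dataset (an existing DocumentDataset) are hypothetical, not part of this PR.

# --- hypothetical usage sketch (not part of the diff) ---
from typing import Optional

from nemo_curator._deduplicator import Deduplicator
from nemo_curator.datasets import DocumentDataset


class MyDeduplicator(Deduplicator):
    """Illustrative subclass: implements identify(), inherits remove() and __call__()."""

    def __init__(self, id_field: str, text_field: str, cache_dir: Optional[str] = None):
        super().__init__(
            id_field=id_field,
            text_field=text_field,
            grouped_field="group",  # column that labels each duplicate group
            cache_dir=cache_dir,
        )

    def identify(self, dataset: DocumentDataset) -> DocumentDataset:
        # Must return a DocumentDataset with (id_field, grouped_field) columns,
        # where documents sharing a group value are duplicates of each other.
        ...


deduper = MyDeduplicator(id_field="id", text_field="text")
duplicates = deduper(my_dataset)                          # identify only
deduplicated = deduper(my_dataset, perform_removal=True)  # identify + remove
# --- end sketch ---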
    def __init__(
        self,
        id_field: str,
        text_field: str,
        grouped_field: str,
        cache_dir: Optional[str] = None,
        **kwargs,
    ):
        self.id_field = id_field
        self.text_field = text_field
        self.grouped_field = grouped_field
        self.cache_dir = cache_dir

    def identify(self, *args, **kwargs):
        """Abstract method to be implemented by concrete deduplicator classes.
        Should implement the logic for identifying duplicates in the dataset."""
        raise NotImplementedError

    def remove(
Review comment: I think this function could just be a helper function that's defined here, which exact and fuzzy dedup import and use, instead of the ABC.
        self, dataset: DocumentDataset, duplicates: DocumentDataset
    ) -> DocumentDataset:
        """
        Remove duplicate documents from the dataset based on identified duplicate groups.

        Parameters
        ----------
        dataset: DocumentDataset
            The input dataset to remove duplicates from.
        duplicates: DocumentDataset
            The dataset containing IDs of all documents and the corresponding duplicate group
            they belong to. Documents in the same group are considered duplicates.
            Only the first document in each group is retained.

        Returns
        -------
        DocumentDataset of all documents with duplicates removed.
        """
        if self.cache_dir is None:
            msg = "Cache directory should be specified for improved performance during the removal step."
            warnings.warn(msg)

        left = dataset.df
        right = duplicates.df

        print(f"{left.npartitions=}, {right.npartitions=}")
        if left.npartitions < right.npartitions:
            msg = (
                "The number of partitions in the dataset to remove duplicates from is less than "
                "the number of partitions in the duplicates dataset. This may lead to a shuffle join. "
                "Please re-read the datasets and call nemo_curator._deduplicator._perform_removal explicitly."
            )
            raise ValueError(msg)

        removed_result = _perform_removal(
            left=left,
            duplicates=right,
            id_field=self.id_field,
            group_field=self.grouped_field,
        )
        return DocumentDataset(removed_result)

    def __call__(
        self, dataset: DocumentDataset, perform_removal: bool = False
    ) -> DocumentDataset:
        """
        Main entry point for the deduplication process.

        Parameters
        ----------
        dataset: DocumentDataset
            The input dataset to remove duplicates from.
        perform_removal: bool
            If True, duplicates are removed from the dataset.
            If False, only the duplicates are identified.

        Returns
        -------
        DocumentDataset of all duplicates (id field, group field) if perform_removal is False.
        DocumentDataset of all documents with duplicates removed if perform_removal is True.
        """
        # First identify duplicates
        duplicates = self.identify(dataset)
        # Then optionally remove them
        if perform_removal:
            return self.remove(dataset, duplicates)
        return duplicates
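To illustrate what the removal helper does (and what exposing it directly, as suggested in the review, would look like), here is a small self-contained sketch on toy data. The column names and values are made up for the example and are not part of the PR.

# --- toy demonstration of _perform_removal semantics (illustrative only) ---
import pandas as pd
import dask.dataframe as dd

from nemo_curator._deduplicator import _perform_removal

# Toy corpus: four documents, two of which ("a") are duplicates of each other
docs = dd.from_pandas(
    pd.DataFrame({"id": [1, 2, 3, 4], "text": ["a", "a", "b", "c"]}), npartitions=2
)
# Duplicates dataset: every document ID plus the duplicate group it belongs to
dups = dd.from_pandas(
    pd.DataFrame({"id": [1, 2, 3, 4], "group": [0, 0, 1, 2]}), npartitions=1
)

# Keeps one representative document per duplicate group (3 rows here)
result = _perform_removal(left=docs, duplicates=dups, id_field="id", group_field="group")
print(result.compute())
# --- end sketch ---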
@@ -18,7 +18,6 @@
import time
import warnings
from contextlib import nullcontext
from datetime import datetime
from hashlib import md5
from typing import Optional, Union

@@ -27,13 +26,14 @@
from dask import dataframe as dd

from nemo_curator._compat import DASK_P2P_ERROR
from nemo_curator._deduplicator import Deduplicator
from nemo_curator.datasets import DocumentDataset
from nemo_curator.log import create_logger
from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix
from nemo_curator.utils.gpu_utils import is_cudf_type


class ExactDuplicates:
class ExactDuplicates(Deduplicator):
Review comment: Suggested change … ?
    """Find exact duplicates in a document corpus"""

    SUPPORTED_HASHES = {"md5"}

@@ -64,6 +64,14 @@ def __init__(
            raise ValueError(
                f"{hash_method} not in supported hash_methods. Choose a hash_method from {self.SUPPORTED_HASHES}"
            )

        super().__init__(
            id_field=id_field,
            text_field=text_field,
            grouped_field="_hashes",
            cache_dir=cache_dir,
        )

        self.hash_method = hash_method
        self.id_field = id_field
        self.text_field = text_field

@@ -135,7 +143,7 @@ def hash_documents(
        # TODO: Generalize to use self.hash_method
        return df.apply(lambda x: md5(x.encode()).hexdigest())

    def __call__(self, dataset: DocumentDataset) -> Union[DocumentDataset, str]:
    def identify(self, dataset: DocumentDataset) -> DocumentDataset:
        """
        Find document IDs for exact duplicates in a given DocumentDataset
        Parameters

@@ -166,10 +174,11 @@ def __call__(self, dataset: DocumentDataset) -> Union[DocumentDataset, str]:
        self._logger.info(
            f"Time taken for Exact Dedup Computation = {time.time() - t0}s and output written at {write_path}"
        )
        if is_cudf_type(result):
            import dask_cudf

            result_dataset = dask_cudf.read_parquet(write_path, split_row_groups=False)
        else:
            result_dataset = dd.read_parquet(write_path)
        return DocumentDataset(result_dataset)
        backend = "cudf" if is_cudf_type(result) else "pandas"
        return DocumentDataset.read_parquet(
            write_path,
            backend=backend,
            blocksize="512MiB",
            files_per_partition=None,
            split_row_groups=False,
        )
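With the __call__/identify split above, the exact-dedup flow could look roughly as follows. This is an illustrative sketch: the import path, ID/text column names, and cache path are assumptions, not values taken from the PR.

# --- hypothetical ExactDuplicates usage with the new API (illustrative only) ---
from nemo_curator import ExactDuplicates
from nemo_curator.datasets import DocumentDataset

dataset = DocumentDataset.read_parquet("/path/to/corpus", backend="pandas")

exact_dups = ExactDuplicates(
    id_field="id",
    text_field="text",
    hash_method="md5",
    cache_dir="./exact_dedup_cache",  # recommended; remove() warns if no cache_dir is set
)

duplicates = exact_dups.identify(dataset)              # DocumentDataset of (id, _hashes)
deduplicated = exact_dups.remove(dataset, duplicates)  # or: exact_dups(dataset, perform_removal=True)
# --- end sketch ---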
@@ -19,8 +19,7 @@
import time
from typing import Union

import dask_cudf

from nemo_curator._deduplicator import Deduplicator
from nemo_curator.datasets import DocumentDataset
from nemo_curator.log import create_logger
from nemo_curator.modules.config import FuzzyDuplicatesConfig

@@ -35,7 +34,7 @@
from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix


class FuzzyDuplicates:
class FuzzyDuplicates(Deduplicator):
Review comment: Suggested change … ?
    def __init__(
        self,
        config: FuzzyDuplicatesConfig,

@@ -63,6 +62,13 @@ def __init__(
        self._logger = logger

        self.config = config

        super().__init__(
            id_field=self.config.id_field,
            text_field=self.config.text_field,
            grouped_field="group",
        )

        self.minhash = MinHash(
            seed=self.config.seed,
            num_hashes=self.config.num_hashes,

@@ -129,7 +135,7 @@ def __init__(
            profile_dir=self.config.profile_dir,
        )

    def __call__(self, dataset: DocumentDataset):
    def identify(self, dataset: DocumentDataset) -> DocumentDataset:
        """
        Parameters
        ----------

@@ -243,4 +249,10 @@ def __call__(self, dataset: DocumentDataset):
        print(f"Stage {stage_num}: Connected Components across buckets complete!")
        stage_num += 1

        return DocumentDataset(dask_cudf.read_parquet(cc_path, split_row_groups=False))
        return DocumentDataset.read_parquet(
            cc_path,
            backend="cudf",
            blocksize="512MiB",
            files_per_partition=None,
            split_row_groups=False,
        )
Review comment: Can this be moved under nemo_curator/modules/deduplicator.py? I think the function is something we want users to be able to access.
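Putting the pieces together, a rough end-to-end sketch of the fuzzy flow with the new identify/remove API. The import paths, the config values, and the column names are assumptions for illustration; only the config fields actually referenced in this diff (id_field, text_field, seed, num_hashes, profile_dir) are confirmed by the PR, and any other required fields are omitted here.

# --- hypothetical FuzzyDuplicates usage with the new API (illustrative only) ---
from nemo_curator import FuzzyDuplicates, FuzzyDuplicatesConfig
from nemo_curator.datasets import DocumentDataset

dataset = DocumentDataset.read_parquet("/path/to/corpus", backend="cudf")

config = FuzzyDuplicatesConfig(
    cache_dir="./fuzzy_dedup_cache",  # assumed field/value; see FuzzyDuplicatesConfig for the full set
    id_field="id",
    text_field="text",
)
fuzzy_dups = FuzzyDuplicates(config=config)

duplicates = fuzzy_dups.identify(dataset)              # DocumentDataset of (id, group)
deduplicated = fuzzy_dups.remove(dataset, duplicates)  # or: fuzzy_dups(dataset, perform_removal=True)
# --- end sketch ---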