Removal logic for Exact / Fuzzy Dedup #499

Closed
130 changes: 130 additions & 0 deletions nemo_curator/_deduplicator.py
Reviewer comment (Collaborator):
Can this be moved under nemo_curator/modules/deduplicator.py? I think the function is something we want users to be able to access.

@@ -0,0 +1,130 @@
import warnings
from abc import ABC
from typing import Optional

import dask.dataframe as dd

from nemo_curator.datasets.doc_dataset import DocumentDataset


def _perform_removal(
    left: dd.DataFrame,
    duplicates: dd.DataFrame,
    id_field: str,
    group_field: str,
) -> dd.DataFrame:
    # Create a new column name for temporary ID storage during merge
    new_id_field = f"{id_field}_new"

    duplicates_to_remove = (
        duplicates
        # Redistribute data across partitions so that all duplicates land in the same partition
        .shuffle(on=[group_field], ignore_index=True)
        # For each partition, keep only the duplicated rows (excluding the first occurrence)
        .map_partitions(lambda x: x[x[group_field].duplicated(keep="first")]).drop(
            columns=group_field
        )
        # Rename the ID field to avoid conflicts in the upcoming merge
        .rename(columns={id_field: new_id_field})[[new_id_field]]
    )

    merge = left.merge(
        right=duplicates_to_remove,
        how="left",
        broadcast=True,  # Broadcast smaller DataFrame to all partitions
        left_on=id_field,
        right_on=new_id_field,
    )

    # Keep only rows with no match in duplicates_to_remove, i.e. drop every duplicate beyond the first
    removed_result = merge[merge[new_id_field].isna()].drop(columns=[new_id_field])
    return removed_result


class Deduplicator(ABC):
Reviewer comment (Collaborator):
Not sure how I feel about the abstraction. I have been wanting something like this, but I worry it is not as generalizable as I'd want it to be. For example, can semantic dedup use this? I don't believe it can, since its duplicates aren't all grouped like this. IMO, if the deduplication abstraction doesn't work for all of our deduplication methods, I'd rather not have it, so we don't confuse our users. We can always refactor the logic out into a base class if we find a general solution.

Reviewer comment (Collaborator):
Maybe call it

Suggested change:
- class Deduplicator(ABC):
+ class DuplicateRemover:

or DuplicatesRemover instead?

Reviewer comment (Collaborator):
Possible example usage:

remover = DuplicatesRemover(...)

exact_dupes = ExactDuplicates(...).identify_duplicates(...)
deduped_data = remover.remove_duplicates(exact_dupes)

fuzzy_dupes = FuzzyDuplicates(...).identify_duplicates(...)
deduped_data = remover.remove_duplicates(fuzzy_dupes)

# Could it be possible to call both simultaneously?
# deduped_data = remover.remove_duplicates(exact_dupes, fuzzy_dupes)
# deduped_data = remover.remove_duplicates([exact_dupes, fuzzy_dupes])
# or similar...

?

    def __init__(
        self,
        id_field: str,
        text_field: str,
        grouped_field: str,
        cache_dir: Optional[str] = None,
        **kwargs,
    ):
        self.id_field = id_field
        self.text_field = text_field
        self.grouped_field = grouped_field
        self.cache_dir = cache_dir

    def identify(self, *args, **kwargs):
        """Abstract method to be implemented by concrete deduplicator classes.
        Should implement the logic for identifying duplicates in the dataset."""
        raise NotImplementedError

    def remove(
Reviewer comment (Collaborator):
I think this function could just be a helper function that's defined here and that exact and fuzzy dedup import and use, instead of the ABC (a minimal sketch of that approach follows this file's diff).

        self, dataset: DocumentDataset, duplicates: DocumentDataset
    ) -> DocumentDataset:
"""
Remove duplicate documents from the dataset based on identified duplicate groups.

Parameters
----------
dataset: DocumentDataset
The input datset to remove duplicates from.

duplicates: DocumentDataset
The dataset containing IDs of all documents and the corresponding duplicate group
they belong to. Documents in the same group are considered duplicates.
Only the first document in each group is retained.

Returns
-------
DocumentDataset of all documents with duplicates removed.
"""
if self.cache_dir is None:
msg = "Cache directory should be specified for improved performance for removal step."
warnings.warn(msg)

left = dataset.df
right = duplicates.df

print(f"{left.npartitions=}, {right.npartitions=}")
if left.npartitions < right.npartitions:
msg = (
"The number of partitions in the dataset to remove duplicates from is less than the number of partitions in the duplicates dataset. "
"This may lead to a shuffle join. Please re-read the datasets and call nemo_curator._deduplicat.perform_merge explicitly."
)
raise ValueError(msg)

removed_result = _perform_removal(
left=left,
duplicates=right,
id_field=self.id_field,
group_field=self.grouped_field,
)
return DocumentDataset(removed_result)

    def __call__(
        self, dataset: DocumentDataset, perform_removal: bool = False
    ) -> DocumentDataset:
        """
        Main entry point for the deduplication process.

        Parameters
        ----------
        dataset: DocumentDataset
            The input dataset to remove duplicates from.
        perform_removal: bool
            If True, duplicates are removed from the dataset.
            If False, only the duplicates are identified.

        Returns
        -------
        DocumentDataset of all duplicates (id field, group field) if perform_removal is False.
        DocumentDataset of all documents with duplicates removed if perform_removal is True.
        """
        # First identify duplicates
        duplicates = self.identify(dataset)
        # Then optionally remove them
        if perform_removal:
            return self.remove(dataset, duplicates)
        return duplicates
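
To make the review discussion concrete, here is a minimal sketch of the helper-function alternative suggested above: ExactDuplicates and FuzzyDuplicates would import `_perform_removal` directly instead of inheriting from an ABC. Only `_perform_removal` and `DocumentDataset` come from this PR; the wrapper name `remove_duplicates` and its signature are illustrative assumptions.

# Hedged sketch of the "plain helper" alternative to the Deduplicator ABC.
from nemo_curator._deduplicator import _perform_removal
from nemo_curator.datasets import DocumentDataset


def remove_duplicates(
    dataset: DocumentDataset,
    duplicates: DocumentDataset,
    id_field: str,
    group_field: str,
) -> DocumentDataset:
    # _perform_removal keeps the first document of each duplicate group and drops
    # the rest via a broadcasted left merge followed by an isna() filter
    # (effectively a left anti-join on the document ID).
    result = _perform_removal(
        left=dataset.df,
        duplicates=duplicates.df,
        id_field=id_field,
        group_field=group_field,
    )
    return DocumentDataset(result)

Both dedup classes could then call this helper from their own remove methods, or users could call it directly, without sharing a base class.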
29 changes: 19 additions & 10 deletions nemo_curator/modules/exact_dedup.py
@@ -18,7 +18,6 @@
import time
import warnings
from contextlib import nullcontext
from datetime import datetime
from hashlib import md5
from typing import Optional, Union

@@ -27,13 +26,14 @@
from dask import dataframe as dd

from nemo_curator._compat import DASK_P2P_ERROR
from nemo_curator._deduplicator import Deduplicator
from nemo_curator.datasets import DocumentDataset
from nemo_curator.log import create_logger
from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix
from nemo_curator.utils.gpu_utils import is_cudf_type


class ExactDuplicates:
class ExactDuplicates(Deduplicator):
Reviewer comment (Collaborator):
Suggested change:
- class ExactDuplicates(Deduplicator):
+ class ExactDuplicates:

?

"""Find exact duplicates in a document corpus"""

SUPPORTED_HASHES = {"md5"}
@@ -64,6 +64,14 @@ def __init__(
            raise ValueError(
                f"{hash_method} not in supported hash_methods. Choose a hash_method from {self.SUPPORTED_HASHES}"
            )

        super().__init__(
            id_field=id_field,
            text_field=text_field,
            grouped_field="_hashes",
            cache_dir=cache_dir,
        )

        self.hash_method = hash_method
        self.id_field = id_field
        self.text_field = text_field
@@ -135,7 +143,7 @@ def hash_documents(
        # TODO: Generalize to using self.hash_method
        return df.apply(lambda x: md5(x.encode()).hexdigest())

    def __call__(self, dataset: DocumentDataset) -> Union[DocumentDataset, str]:
    def identify(self, dataset: DocumentDataset) -> DocumentDataset:
        """
        Find document IDs for exact duplicates in a given DocumentDataset
        Parameters
@@ -166,10 +174,11 @@ def __call__(self, dataset: DocumentDataset) -> Union[DocumentDataset, str]:
        self._logger.info(
            f"Time taken for Exact Dedup Computation = {time.time() - t0}s and output written at {write_path}"
        )
        if is_cudf_type(result):
            import dask_cudf

            result_dataset = dask_cudf.read_parquet(write_path, split_row_groups=False)
        else:
            result_dataset = dd.read_parquet(write_path)
        return DocumentDataset(result_dataset)
        backend = "cudf" if is_cudf_type(result) else "pandas"
        return DocumentDataset.read_parquet(
            write_path,
            backend=backend,
            blocksize="512MiB",
            files_per_partition=None,
            split_row_groups=False,
        )
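
For reference, a sketch of how the reworked ExactDuplicates might be driven end to end with the new identify/remove split. The paths, field names, and cache directory below are placeholders, and DocumentDataset.to_parquet is assumed to behave as in the existing API.

from nemo_curator.datasets import DocumentDataset
from nemo_curator.modules.exact_dedup import ExactDuplicates

# Placeholder input path and field names
dataset = DocumentDataset.read_parquet("input_data/", backend="pandas")

exact = ExactDuplicates(
    id_field="id",
    text_field="text",
    hash_method="md5",
    cache_dir="./exact_dedup_cache",  # recommended so remove() skips its warning
)

# Identify only: returns the id field plus the "_hashes" group field
duplicates = exact.identify(dataset)

# Or identify and remove in one step via the inherited Deduplicator.__call__
deduped = exact(dataset, perform_removal=True)
deduped.to_parquet("deduped_output/")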
2 changes: 0 additions & 2 deletions nemo_curator/modules/fuzzy_dedup/connectedcomponents.py
@@ -125,8 +125,6 @@ def _run_connected_components(
f"# rows in labels_df = {len(labels_df)}"
)
assert num_nodes == len(labels_df)
# Ensure all docs in the same group are in the same partition
labels_df = labels_df.shuffle(on=["group"], ignore_index=True)
labels_df.to_parquet(output_path, write_index=False, overwrite=True)
Comms.destroy()
self._logger.info(
22 changes: 17 additions & 5 deletions nemo_curator/modules/fuzzy_dedup/fuzzyduplicates.py
@@ -19,8 +19,7 @@
import time
from typing import Union

import dask_cudf

from nemo_curator._deduplicator import Deduplicator
from nemo_curator.datasets import DocumentDataset
from nemo_curator.log import create_logger
from nemo_curator.modules.config import FuzzyDuplicatesConfig
@@ -35,7 +34,7 @@
from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix


class FuzzyDuplicates:
class FuzzyDuplicates(Deduplicator):
Reviewer comment (Collaborator):
Suggested change:
- class FuzzyDuplicates(Deduplicator):
+ class FuzzyDuplicates:

?

    def __init__(
        self,
        config: FuzzyDuplicatesConfig,
@@ -63,6 +62,13 @@ def __init__(
            self._logger = logger

        self.config = config

        super().__init__(
            id_field=self.config.id_field,
            text_field=self.config.text_field,
            grouped_field="group",
        )

        self.minhash = MinHash(
            seed=self.config.seed,
            num_hashes=self.config.num_hashes,
@@ -129,7 +135,7 @@ def __init__(
            profile_dir=self.config.profile_dir,
        )

    def __call__(self, dataset: DocumentDataset):
    def identify(self, dataset: DocumentDataset) -> DocumentDataset:
        """
        Parameters
        ----------
@@ -243,4 +249,10 @@ def __call__(self, dataset: DocumentDataset):
print(f"Stage {stage_num}: Connected Components across buckets complete!")
stage_num += 1

return DocumentDataset(dask_cudf.read_parquet(cc_path, split_row_groups=False))
return DocumentDataset.read_parquet(
cc_path,
backend="cudf",
blocksize="512MiB",
files_per_partition=None,
split_row_groups=False,
)
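
A parallel sketch for FuzzyDuplicates under the same identify/remove split; the config values are placeholders and a cuDF-backed (GPU) dataset is assumed. Note that remove() raises if the corpus has fewer partitions than the duplicates frame, and re-reading the identify() output with a 512 MiB blocksize (as this diff does) helps keep the duplicates frame in few partitions.

from nemo_curator.datasets import DocumentDataset
from nemo_curator.modules.config import FuzzyDuplicatesConfig
from nemo_curator.modules.fuzzy_dedup.fuzzyduplicates import FuzzyDuplicates

# Placeholder config; the remaining FuzzyDuplicatesConfig knobs keep their defaults
config = FuzzyDuplicatesConfig(
    cache_dir="./fuzzy_dedup_cache",
    id_field="id",
    text_field="text",
)
fuzzy = FuzzyDuplicates(config=config)

dataset = DocumentDataset.read_parquet("input_data/", backend="cudf")

# identify() returns (id, group); remove() keeps the first document per group
duplicates = fuzzy.identify(dataset)
deduped = fuzzy.remove(dataset, duplicates)

# Equivalent single call:
# deduped = fuzzy(dataset, perform_removal=True)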