Removal logic for fuzzy / exact (no class abstraction) #509

Merged
Changes from 7 commits
Commits
33 commits
89b9005
fc
praateekmahajan Jan 28, 2025
37f6bee
add shuffle/ tests
praateekmahajan Jan 30, 2025
69c8955
more test for class
praateekmahajan Jan 30, 2025
de25476
pre-commit
praateekmahajan Jan 30, 2025
a698bf0
remove class abstractions
praateekmahajan Jan 31, 2025
a2e0c42
remove unused import
praateekmahajan Jan 31, 2025
845cae3
add __call__ methods back
praateekmahajan Jan 31, 2025
2a1da6b
change from modules / update docs
praateekmahajan Feb 3, 2025
48bef03
add tests
praateekmahajan Feb 4, 2025
958161d
update blocksize to 1024 in exact
praateekmahajan Feb 4, 2025
7275609
pr suggestions
praateekmahajan Feb 5, 2025
cba7fcd
warning
praateekmahajan Feb 5, 2025
bcb7cea
Update docs/user-guide/gpudeduplication.rst
praateekmahajan Feb 6, 2025
c929927
Update docs/user-guide/gpudeduplication.rst
praateekmahajan Feb 6, 2025
0afd1a1
Update docs/user-guide/gpudeduplication.rst
praateekmahajan Feb 6, 2025
6f1e4d9
Update examples/exact_deduplication.py
praateekmahajan Feb 6, 2025
1347e37
Update examples/exact_deduplication.py
praateekmahajan Feb 6, 2025
2e3c908
Update examples/fuzzy_deduplication.py
praateekmahajan Feb 6, 2025
bc20a5d
Update examples/fuzzy_deduplication.py
praateekmahajan Feb 6, 2025
6e26edb
Update examples/fuzzy_deduplication.py
praateekmahajan Feb 6, 2025
8ba196a
Update nemo_curator/modules/config.py
praateekmahajan Feb 6, 2025
8936ac9
Update nemo_curator/modules/config.py
praateekmahajan Feb 6, 2025
e41c5fa
Update nemo_curator/modules/exact_dedup.py
praateekmahajan Feb 6, 2025
9c7f4bf
add file back
praateekmahajan Feb 6, 2025
fe6f018
merge
praateekmahajan Feb 6, 2025
7f0da3e
pre-commit
praateekmahajan Feb 6, 2025
b438c80
forgot to rename back to identify_duplicates after merge
praateekmahajan Feb 6, 2025
f8040b5
renamed func in call
praateekmahajan Feb 6, 2025
82f0c6c
split code / read fpp=1
praateekmahajan Feb 7, 2025
bf5498f
Update docs/user-guide/gpudeduplication.rst
praateekmahajan Feb 7, 2025
f172c72
Update nemo_curator/modules/fuzzy_dedup/fuzzyduplicates.py
praateekmahajan Feb 7, 2025
2beca67
Update nemo_curator/modules/exact_dedup.py
praateekmahajan Feb 7, 2025
f8d89da
Merge branch 'main' into praateek/removal-code-no-abstraction
praateekmahajan Feb 8, 2025
45 changes: 37 additions & 8 deletions nemo_curator/modules/exact_dedup.py
@@ -18,7 +18,6 @@
import time
import warnings
from contextlib import nullcontext
from datetime import datetime
from hashlib import md5
from typing import Optional, Union

@@ -29,6 +28,7 @@
from nemo_curator._compat import DASK_P2P_ERROR
from nemo_curator.datasets import DocumentDataset
from nemo_curator.log import create_logger
from nemo_curator.modules.removal import remove_duplicates
from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix
from nemo_curator.utils.gpu_utils import is_cudf_type

@@ -64,6 +64,7 @@ def __init__(
raise ValueError(
f"{hash_method} not in supported hash_methods. Choose a hash_method from {self.SUPPORTED_HASHES}"
)

self.hash_method = hash_method
self.id_field = id_field
self.text_field = text_field
@@ -135,7 +136,7 @@ def hash_documents(
# TODO: Generalize to use self.hash_method
return df.apply(lambda x: md5(x.encode()).hexdigest())

def __call__(self, dataset: DocumentDataset) -> Union[DocumentDataset, str]:
def identify(self, dataset: DocumentDataset) -> DocumentDataset:
Collaborator

Suggested change
def identify(self, dataset: DocumentDataset) -> DocumentDataset:
def _identify(self, dataset: DocumentDataset) -> DocumentDataset:

Nit, but maybe call them _identify and _remove if they are not intended to be accessed by the user directly.

Collaborator Author

Let's keep them exposed, especially since remove won't work at scales where the size of the duplicates >> host memory, in which case the user will need to break down identify and remove into separate steps.

Collaborator

Yes, that makes sense to me. What about calling it identify_duplicates?

Collaborator Author

Sounds good. Initially I thought it was slightly verbose, but another argument in favor of identify_duplicates is that in the future we might want to expose identify_documents_to_keep, in which case the distinction might be necessary.

cc @ayushdg / @ryantwolf / @VibhuJawa

"""
Find document IDs of exact duplicates in a given DocumentDataset
Parameters
Expand Down Expand Up @@ -166,10 +167,38 @@ def __call__(self, dataset: DocumentDataset) -> Union[DocumentDataset, str]:
self._logger.info(
f"Time taken for Exact Dedup Computation = {time.time() - t0}s and output written at {write_path}"
)
if is_cudf_type(result):
import dask_cudf
backend = "cudf" if is_cudf_type(result) else "pandas"
return DocumentDataset.read_parquet(
write_path,
backend=backend,
blocksize="512MiB",
files_per_partition=None,
split_row_groups=False,
)

result_dataset = dask_cudf.read_parquet(write_path, split_row_groups=False)
else:
result_dataset = dd.read_parquet(write_path)
return DocumentDataset(result_dataset)
def remove(
self, dataset: DocumentDataset, duplicates_to_remove: DocumentDataset
) -> DocumentDataset:
"""
Remove exact duplicates from a given DocumentDataset
Parameters
----------
dataset: DocumentDataset
The input datset to remove exact duplicates
Returns
-------
DocumentDataset containing only non-duplicate documents
"""
result = remove_duplicates(
left=dataset.df,
duplicates=duplicates_to_remove.df,
id_field=self.id_field,
group_field="_hashes",
)
return DocumentDataset(result)

def __call__(self, dataset: DocumentDataset, perform_removal: bool = False) -> DocumentDataset:
duplicates = self.identify(dataset)
if perform_removal:
return self.remove(dataset, duplicates)
return duplicates
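For context on the review thread above, the split lets users run identification and removal either together or as separate stages. A minimal sketch of how the new API might be driven, assuming the class in exact_dedup.py is ExactDuplicates with its existing id_field / text_field / cache_dir arguments (paths and field names here are illustrative, not taken from this PR):

```python
from nemo_curator.datasets import DocumentDataset
from nemo_curator.modules.exact_dedup import ExactDuplicates

# Illustrative input; any DocumentDataset with an id and text column works.
dataset = DocumentDataset.read_parquet("input_data/", backend="cudf")

exact_dups = ExactDuplicates(
    id_field="id",
    text_field="text",
    cache_dir="./exact_dedup_cache",  # identify() persists duplicate IDs here
)

# One-shot: identify and remove in a single call.
deduped = exact_dups(dataset, perform_removal=True)

# Two-step: useful when the duplicate set is large relative to host memory,
# so identification and removal can be run and checkpointed separately.
duplicates = exact_dups.identify(dataset)
deduped = exact_dups.remove(dataset, duplicates)
```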
2 changes: 0 additions & 2 deletions nemo_curator/modules/fuzzy_dedup/connectedcomponents.py
@@ -125,8 +125,6 @@ def _run_connected_components(
f"# rows in labels_df = {len(labels_df)}"
)
assert num_nodes == len(labels_df)
# Ensure all docs in the same group are in the same partition
labels_df = labels_df.shuffle(on=["group"], ignore_index=True)
Collaborator Author

@ayushdg we're doing this here

labels_df.to_parquet(output_path, write_index=False, overwrite=True)
Comms.destroy()
self._logger.info(
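The comment above refers to the group-wise shuffle that used to live in connected components and now happens inside remove_duplicates (added in nemo_curator/modules/removal.py below). A small self-contained illustration of that idiom on a plain Dask dataframe, using made-up data:

```python
import pandas as pd
import dask.dataframe as dd

# Toy frame with two duplicate groups initially split across partitions.
df = dd.from_pandas(
    pd.DataFrame({"id": [1, 2, 3, 4], "group": ["a", "b", "a", "b"]}),
    npartitions=2,
)

# Shuffling on "group" co-locates every member of a group in one partition,
# so a per-partition duplicated() check sees each group in full.
shuffled = df.shuffle(on=["group"], ignore_index=True)
to_remove = shuffled.map_partitions(
    lambda part: part[part["group"].duplicated(keep="first")]
)
print(to_remove.compute())  # one row per group beyond its first occurrence
```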
41 changes: 37 additions & 4 deletions nemo_curator/modules/fuzzy_dedup/fuzzyduplicates.py
@@ -19,8 +19,6 @@
import time
from typing import Union

import dask_cudf

from nemo_curator.datasets import DocumentDataset
from nemo_curator.log import create_logger
from nemo_curator.modules.config import FuzzyDuplicatesConfig
@@ -32,6 +30,7 @@
from nemo_curator.modules.fuzzy_dedup.lsh import LSH
from nemo_curator.modules.fuzzy_dedup.minhash import MinHash
from nemo_curator.modules.meta import Sequential
from nemo_curator.modules.removal import remove_duplicates
from nemo_curator.utils.distributed_utils import performance_report_if_with_ts_suffix


@@ -63,6 +62,7 @@ def __init__(
self._logger = logger

self.config = config

self.minhash = MinHash(
seed=self.config.seed,
num_hashes=self.config.num_hashes,
@@ -129,7 +129,7 @@ def __init__(
profile_dir=self.config.profile_dir,
)

def __call__(self, dataset: DocumentDataset):
def identify(self, dataset: DocumentDataset) -> DocumentDataset:
"""
Parameters
----------
@@ -243,4 +243,37 @@ def __call__(self, dataset: DocumentDataset):
print(f"Stage {stage_num}: Connected Components across buckets complete!")
stage_num += 1

return DocumentDataset(dask_cudf.read_parquet(cc_path, split_row_groups=False))
return DocumentDataset.read_parquet(
cc_path,
backend="cudf",
blocksize="1024MiB",
files_per_partition=None,
split_row_groups=False,
)

def remove(
self, dataset: DocumentDataset, duplicates_to_remove: DocumentDataset
) -> DocumentDataset:
"""
Remove exact duplicates from a given DocumentDataset
Parameters
----------
dataset: DocumentDataset
The input datset to remove exact duplicates
Returns
-------
DocumentDataset containing only non-duplicate documents
"""
result = remove_duplicates(
left=dataset.df,
duplicates=duplicates_to_remove.df,
id_field=self.id_field,
group_field="group",
)
return DocumentDataset(result)

def __call__(self, dataset: DocumentDataset, perform_removal: bool = False) -> DocumentDataset:
duplicates = self.identify(dataset)
if perform_removal:
return self.remove(dataset, duplicates)
return duplicates
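The fuzzy path mirrors the exact one. A sketch of the intended call pattern, assuming FuzzyDuplicatesConfig keeps its existing cache_dir / id_field / text_field arguments and that the remaining LSH settings fall back to their defaults (the values below are illustrative):

```python
from nemo_curator.datasets import DocumentDataset
from nemo_curator.modules.config import FuzzyDuplicatesConfig
from nemo_curator.modules.fuzzy_dedup.fuzzyduplicates import FuzzyDuplicates

# Illustrative input path; fuzzy dedup runs on a cudf-backed dataset.
dataset = DocumentDataset.read_json("input_data/", backend="cudf")

config = FuzzyDuplicatesConfig(
    cache_dir="./fuzzy_dedup_cache",
    id_field="id",
    text_field="text",
)
fuzzy_dups = FuzzyDuplicates(config=config)

# Either run identification and removal in one call...
deduped = fuzzy_dups(dataset, perform_removal=True)

# ...or keep the stages separate, e.g. to inspect or persist the duplicate
# groups before removing anything.
duplicates = fuzzy_dups.identify(dataset)
deduped = fuzzy_dups.remove(dataset, duplicates)
```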
42 changes: 42 additions & 0 deletions nemo_curator/modules/removal.py
@@ -0,0 +1,42 @@
import dask.dataframe as dd


def remove_duplicates(
left: dd.DataFrame,
duplicates: dd.DataFrame,
id_field: str,
group_field: str,
) -> dd.DataFrame:
if left.npartitions < duplicates.npartitions:
msg = (
"The number of partitions in `left` is less than the number of partitions in the duplicates dataset. "
"This may lead to a shuffle join. Please re-read left and right with different partition sizes, or repartition left / right."
)
raise ValueError(msg)

# Create a new column name for temporary ID storage during merge
new_id_field = f"{id_field}_new"

duplicates_to_remove = (
duplicates
# Redistribute data across partitions so that all members of a duplicate group land in the same partition
.shuffle(on=[group_field], ignore_index=True)
# For each partition, keep only the duplicated rows (excluding first occurrence)
.map_partitions(lambda x: x[x[group_field].duplicated(keep="first")]).drop(
columns=group_field
)
# Rename the ID field to avoid conflicts in the upcoming merge
.rename(columns={id_field: new_id_field})[[new_id_field]]
)

merge = left.merge(
right=duplicates_to_remove,
how="left",
broadcast=True, # Broadcast smaller DataFrame to all partitions
left_on=id_field,
right_on=new_id_field,
)

# Keep only the rows that did not match anything in duplicates_to_remove, i.e. drop the identified duplicates
removed_result = merge[merge[new_id_field].isna()].drop(columns=[new_id_field])
return removed_result
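To make the keep-first semantics of remove_duplicates concrete, here is a minimal sketch on a pandas-backed Dask dataframe with made-up data; exactly one member of each duplicate group survives, and which member that is depends on the shuffle:

```python
import pandas as pd
import dask.dataframe as dd

from nemo_curator.modules.removal import remove_duplicates

# Toy corpus: "a1", "a2", "a3" duplicate one another; "b1" is unique.
left = dd.from_pandas(
    pd.DataFrame(
        {
            "id": ["a1", "a2", "a3", "b1"],
            "text": ["hello", "hello", "hello", "world"],
        }
    ),
    npartitions=2,
)

# Output of an identify step: duplicate documents share a group value.
duplicates = dd.from_pandas(
    pd.DataFrame({"id": ["a1", "a2", "a3"], "group": [0, 0, 0]}),
    npartitions=1,
)

result = remove_duplicates(
    left=left,
    duplicates=duplicates,
    id_field="id",
    group_field="group",
).compute()

# One survivor from the duplicated group plus the unique document.
print(sorted(result["id"]))
```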
4 changes: 1 addition & 3 deletions nemo_curator/modules/semantic_dedup/clusteringmodel.py
@@ -195,9 +195,7 @@ def __call__(self, embeddings_dataset: DocumentDataset):
id_col=self.id_col,
kmeans_centroids_file=kmeans_centroids_file,
nearest_cent_dir=clustering_output_dir,
output_sorted_clusters_dir=os.path.join(
self.clustering_output_dir, "sorted"
),
output_sorted_clusters_dir=os.path.join(self.clustering_output_dir, "sorted"),
embedding_col=self.embedding_col,
sim_metric=self.sim_metric,
keep_hard=self.keep_hard,
125 changes: 125 additions & 0 deletions tests/test_deduplicator.py
@@ -0,0 +1,125 @@
import random

import pandas as pd
import pytest
from dask import dataframe as dd

from nemo_curator.modules.removal import remove_duplicates


@pytest.fixture()
def ids():
# Dataset has ids a0...a9, b0...b9, c0...c9, d0...d9
id_list = [f"{group}{i}" for group in ["a", "b", "c", "d"] for i in range(10)]
# We shuffle it so that duplicates are not all in the same partition
random.shuffle(id_list)
return id_list


@pytest.fixture
def sample_data(ids):
df = pd.DataFrame(
{
"id": ids,
"text": [f"text for {_id}" for _id in ids],
}
)
return dd.from_pandas(df, npartitions=4)


@pytest.fixture
def duplicate_data(ids):
# Group ids by their first character; only one document per group should survive removal
df = pd.DataFrame([{"id": _id, "group": _id[0]} for _id in ids])
# The ids fixture is already shuffled, so each group's members are spread across partitions
return dd.from_pandas(df, npartitions=2)


def test_remove_duplicates_basic(
sample_data: dd.DataFrame, duplicate_data: dd.DataFrame
):
# Test basic duplicate removal functionality
result = remove_duplicates(
left=sample_data, duplicates=duplicate_data, id_field="id", group_field="group"
)

result = result.compute()

assert list(result.columns) == ["id", "text"]
assert len(result) == 4
# It's not guaranteed that we'll have a0, b0, c0, d0 in the result
# So we should check the first character
assert set(result["id"].apply(lambda x: x[0]).tolist()) == set(["a", "b", "c", "d"])


def test_remove_duplicates_all_duplicates(ids: list[str], sample_data: dd.DataFrame):
duplicates = dd.from_pandas(
pd.DataFrame({"id": ids, "group": [1] * len(ids)}), npartitions=2
)

result = remove_duplicates(
left=sample_data, duplicates=duplicates, id_field="id", group_field="group"
)

result = result.compute()
assert list(result.columns) == ["id", "text"]
# Should keep only one of the occurrences
assert len(result) == 1


def test_not_remove_duplicates_unique(ids: list[str], sample_data: dd.DataFrame):
# We create a duplicates dataset where the first 30 ids are in one group,
# the next 9 ids are each in their own distinct group,
# and the last id is not mentioned in the duplicates at all

duplicates = dd.from_pandas(
pd.DataFrame(
{
"id": ids[:30] + ids[30:39],
"group": ["group0"] * 30 + [f"group{i}" for i in range(1, 10)],
}
),
npartitions=2,
)
result = remove_duplicates(
left=sample_data, duplicates=duplicates, id_field="id", group_field="group"
)

result = result.compute()
assert list(result.columns) == ["id", "text"]
# We expect 1 row from the first group of 30,
# 9 rows from the 9 distinct single-member groups,
# and 1 row for the last id, which is not listed in the duplicates at all
assert len(result) == 1 + 9 + 1
# The last 10 ids should all be in the result, plus exactly one survivor from the first 30
assert set(ids[30:]).issubset(set(result["id"].tolist()))


def test_remove_duplicates_raise_error():
# Create sample dataframes with specific partition counts
df1 = dd.from_pandas(
pd.DataFrame({"id": ["a1", "a2", "a3"], "text": ["text1", "text2", "text3"]}),
npartitions=2,
) # dataset with 2 partitions

duplicates = dd.from_pandas(
pd.DataFrame(
{"id": ["a1", "a2", "a3"], "group": ["group1", "group1", "group1"]}
),
npartitions=3,
) # duplicates dataset with 3 partitions

# A ValueError should be raised when the duplicates dataset has more partitions than `left`
with pytest.raises(ValueError) as exc_info:
remove_duplicates(
left=df1,
duplicates=duplicates,
id_field="id",
group_field="group",
)

expected_msg = (
"The number of partitions in `left` is less than the number of partitions in the duplicates dataset. "
"This may lead to a shuffle join. Please re-read left and right with different partition sizes, or repartition left / right."
)
assert str(exc_info.value) == expected_msg