From 07e2a408b5b1c44e996a8d1e9588311d4d39316c Mon Sep 17 00:00:00 2001 From: Praateek Mahajan Date: Tue, 19 Nov 2024 19:43:37 -0800 Subject: [PATCH] MinHash improvement using minhash_permuted (#313) --- nemo_curator/_compat.py | 27 +++++++--- nemo_curator/modules/fuzzy_dedup.py | 76 +++++++++++++++++++++++++++-- 2 files changed, 91 insertions(+), 12 deletions(-) diff --git a/nemo_curator/_compat.py b/nemo_curator/_compat.py index 65e1155ca..26fb0574b 100644 --- a/nemo_curator/_compat.py +++ b/nemo_curator/_compat.py @@ -15,18 +15,31 @@ import sys import dask -from packaging.version import parse as parseVersion +from packaging.version import parse as parse_version try: - _dask_version = parseVersion(dask.__version__) + _dask_version = parse_version(dask.__version__) except TypeError: # When mocking with autodoc the dask version is not there - _dask_version = parseVersion("2024.06.0") + _dask_version = parse_version("2024.06.0") + +try: + import cudf + + CURRENT_CUDF_VERSION = parse_version(cudf.__version__) +except (ImportError, TypeError): + CURRENT_CUDF_VERSION = parse_version("24.10.0") + +# TODO remove this once 24.12.0 becomes the base version of cudf in nemo-curator +MINHASH_PERMUTED_AVAILABLE = CURRENT_CUDF_VERSION >= parse_version("24.12.0") or ( + CURRENT_CUDF_VERSION.is_prerelease + and CURRENT_CUDF_VERSION.base_version >= "24.12.0" +) # TODO: remove when dask min version gets bumped -DASK_SHUFFLE_METHOD_ARG = _dask_version > parseVersion("2024.1.0") -DASK_P2P_ERROR = _dask_version < parseVersion("2023.10.0") -DASK_SHUFFLE_CAST_DTYPE = _dask_version > parseVersion("2023.12.0") +DASK_SHUFFLE_METHOD_ARG = _dask_version > parse_version("2024.1.0") +DASK_P2P_ERROR = _dask_version < parse_version("2023.10.0") +DASK_SHUFFLE_CAST_DTYPE = _dask_version > parse_version("2023.12.0") # Query-planning check (and cache) _DASK_QUERY_PLANNING_ENABLED = None @@ -36,7 +49,7 @@ def query_planning_enabled(): global _DASK_QUERY_PLANNING_ENABLED if _DASK_QUERY_PLANNING_ENABLED is None: - if _dask_version > parseVersion("2024.6.0"): + if _dask_version > parse_version("2024.6.0"): import dask.dataframe as dd _DASK_QUERY_PLANNING_ENABLED = dd.DASK_EXPR_ENABLED diff --git a/nemo_curator/modules/fuzzy_dedup.py b/nemo_curator/modules/fuzzy_dedup.py index 61cef4f99..484742596 100644 --- a/nemo_curator/modules/fuzzy_dedup.py +++ b/nemo_curator/modules/fuzzy_dedup.py @@ -35,6 +35,7 @@ from dask.utils import M from tqdm import tqdm +from nemo_curator._compat import MINHASH_PERMUTED_AVAILABLE from nemo_curator.datasets import DocumentDataset from nemo_curator.log import create_logger from nemo_curator.modules.config import FuzzyDuplicatesConfig @@ -99,7 +100,14 @@ def __init__( """ self.num_hashes = num_hashes self.char_ngram = char_ngrams - self.seeds = self.generate_seeds(n_seeds=self.num_hashes, seed=seed) + if MINHASH_PERMUTED_AVAILABLE: + self.seeds = self.generate_hash_permutation_seeds( + bit_width=64 if use_64bit_hash else 32, + n_permutations=self.num_hashes, + seed=seed, + ) + else: + self.seeds = self.generate_seeds(n_seeds=self.num_hashes, seed=seed) self.minhash_method = self.minhash64 if use_64bit_hash else self.minhash32 self.id_field = id_field self.text_field = text_field @@ -127,6 +135,35 @@ def generate_seeds(self, n_seeds: int = 260, seed: int = 0) -> np.ndarray: gen = np.random.RandomState(seed) return gen.randint(0, 1e6, size=n_seeds) + def generate_hash_permutation_seeds( + self, bit_width: int, n_permutations: int = 260, seed: int = 0 + ) -> np.ndarray: + """ + Generate seeds for all minhash permutations based on the given seed. + """ + gen = np.random.RandomState(seed) + + if bit_width == 32: + MERSENNE_PRIME = np.uint32((1 << 31) - 1) + dtype = np.uint32 + elif bit_width == 64: + # For 64-bit, use a larger prime number suitable for 64-bit operations + MERSENNE_PRIME = np.uint64((1 << 61) - 1) + dtype = np.uint64 + else: + raise ValueError("Unsupported bit width. Use either 32 or 64.") + + return np.array( + [ + ( + gen.randint(1, MERSENNE_PRIME, dtype=dtype), + gen.randint(0, MERSENNE_PRIME, dtype=dtype), + ) + for _ in range(n_permutations) + ], + dtype=dtype, + ) + def minhash32( self, ser: cudf.Series, seeds: np.ndarray, char_ngram: int ) -> cudf.Series: @@ -135,8 +172,23 @@ def minhash32( """ if not isinstance(ser, cudf.Series): raise TypeError("Expected data of type cudf.Series") - seeds = cudf.Series(seeds, dtype="uint32") - return ser.str.minhash(seeds=seeds, width=char_ngram) + + if not MINHASH_PERMUTED_AVAILABLE: + warnings.warn( + "Using an outdated minhash implementation, please update to cuDF version 24.12 " + "or later for improved performance. " + "Install the latest version of cuDF using `pip install curator[cuda12x_nightly]`", + category=FutureWarning, + ) + seeds = cudf.Series(seeds, dtype="uint32") + return ser.str.minhash(seeds=seeds, width=char_ngram) + else: + seeds_a = cudf.Series(seeds[:, 0], dtype="uint32") + seeds_b = cudf.Series(seeds[:, 1], dtype="uint32") + + return ser.str.minhash_permuted( + a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram + ) def minhash64( self, ser: cudf.Series, seeds: np.ndarray, char_ngram: int @@ -146,8 +198,22 @@ def minhash64( """ if not isinstance(ser, cudf.Series): raise TypeError("Expected data of type cudf.Series") - seeds = cudf.Series(seeds, dtype="uint64") - return ser.str.minhash64(seeds=seeds, width=char_ngram) + if not MINHASH_PERMUTED_AVAILABLE: + warnings.warn( + "Using an outdated minhash implementation, please update to cuDF version 24.12 " + "or later for improved performance. " + "Install the latest version of cuDF using `pip install curator[cuda12x_nightly]`", + category=FutureWarning, + ) + seeds = cudf.Series(seeds, dtype="uint64") + return ser.str.minhash64(seeds=seeds, width=char_ngram) + else: + seeds_a = cudf.Series(seeds[:, 0], dtype="uint64") + seeds_b = cudf.Series(seeds[:, 1], dtype="uint64") + + return ser.str.minhash64_permuted( + a=seeds_a, b=seeds_b, seed=seeds[0][0], width=char_ngram + ) def __call__(self, dataset: DocumentDataset) -> Union[str, DocumentDataset]: """