From 1dab545e2cd84f2ce57bee96abb6280ab8b481ad Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Thu, 6 Feb 2025 13:19:06 -0800 Subject: [PATCH 1/5] Enable ADD ID to work with CPU/GPU both (#479) * Enable ADD ID to work with CPU/GPU both Signed-off-by: Vibhu Jawa * Make Test runable in a CPU only environment Signed-off-by: Vibhu Jawa * Fix pytest skipping behavior in CPU/GPU environment Signed-off-by: Vibhu Jawa * Raise error instead of skipping test Signed-off-by: Vibhu Jawa --------- Signed-off-by: Vibhu Jawa --- nemo_curator/modules/add_id.py | 7 ++++-- nemo_curator/scripts/add_id.py | 3 ++- tests/test_add_id.py | 41 ++++++++++++++++++++++++++-------- 3 files changed, 39 insertions(+), 12 deletions(-) diff --git a/nemo_curator/modules/add_id.py b/nemo_curator/modules/add_id.py index e7e733fc1..eca677c4e 100644 --- a/nemo_curator/modules/add_id.py +++ b/nemo_curator/modules/add_id.py @@ -39,8 +39,9 @@ def call(self, dataset: DocumentDataset) -> DocumentDataset: return self._add_id_ordered(dataset) def _add_id_fast(self, dataset: DocumentDataset) -> DocumentDataset: - meta = dataset.df.dtypes.to_dict() + meta = dataset.df._meta.copy() meta[self.id_field] = "string" + meta[self.id_field] = meta[self.id_field].astype("string") partition_zero_padding = count_digits(dataset.df.npartitions) id_df = dataset.df.map_partitions( @@ -61,12 +62,14 @@ def _add_id_fast_partition(self, partition, global_padding, partition_info=None) for local_id in range(len(partition)) ] partition[self.id_field] = id_column + partition[self.id_field] = partition[self.id_field].astype("string") return partition def _add_id_ordered(self, dataset: DocumentDataset) -> DocumentDataset: - original_meta = dataset.df.dtypes.to_dict() + original_meta = dataset.df._meta.copy() original_meta[self.id_field] = "string" + original_meta[self.id_field] = original_meta[self.id_field].astype("string") delayed_dataset = dataset.df.to_delayed() parition_lengths = [0] diff --git a/nemo_curator/scripts/add_id.py b/nemo_curator/scripts/add_id.py index c926e36dd..2a856af07 100644 --- a/nemo_curator/scripts/add_id.py +++ b/nemo_curator/scripts/add_id.py @@ -28,6 +28,7 @@ def main(args): client = get_client(**ArgumentHelper.parse_client_args(args)) + backend = "cudf" if args.device == "gpu" else "pandas" output_dir = expand_outdir_and_mkdir(args.output_data_dir) files = get_all_files_paths_under(args.input_data_dir) if args.shuffle: @@ -36,7 +37,7 @@ def main(args): dataset = DocumentDataset( read_data( - files, file_type=args.input_file_type, backend="pandas", add_filename=True + files, file_type=args.input_file_type, backend=backend, add_filename=True ) ) add_id = nemo_curator.AddId( diff --git a/tests/test_add_id.py b/tests/test_add_id.py index 42a8575e5..c33c5e4a8 100644 --- a/tests/test_add_id.py +++ b/tests/test_add_id.py @@ -18,26 +18,37 @@ import nemo_curator as nc from nemo_curator.datasets import DocumentDataset +from nemo_curator.utils.import_utils import gpu_only_import, is_unavailable +cudf = gpu_only_import("cudf") +is_cudf_available = not is_unavailable(cudf) -def list_to_dataset(documents, col_name="text", npartitions=2): + +def list_to_dataset(documents, col_name="text", npartitions=2, backend="pandas"): data = {col_name: documents} pdf = pd.DataFrame(data) - - return DocumentDataset(dd.from_pandas(pdf, npartitions=npartitions)) + ddf = dd.from_pandas(pdf, npartitions=npartitions) + if backend == "cudf" and is_unavailable(cudf): + raise ImportError("cuDF is not installed or importable.") + ddf = ddf.to_backend(backend) + return 
DocumentDataset(ddf) -@pytest.fixture -def single_partition_dataset(): +@pytest.fixture(params=["pandas", pytest.param("cudf", marks=pytest.mark.gpu)]) +def single_partition_dataset(request): return list_to_dataset( - ["First", "Second", "Third", "Fourth", "Fifth"], npartitions=1 + ["First", "Second", "Third", "Fourth", "Fifth"], + npartitions=1, + backend=request.param, ) -@pytest.fixture -def two_partition_dataset(): +@pytest.fixture(params=["pandas", pytest.param("cudf", marks=pytest.mark.gpu)]) +def two_partition_dataset(request): return list_to_dataset( - ["First", "Second", "Third", "Fourth", "Fifth"], npartitions=2 + ["First", "Second", "Third", "Fourth", "Fifth"], + npartitions=2, + backend=request.param, ) @@ -56,6 +67,8 @@ def test_basic_id(self, single_partition_dataset): "doc_id-0000000004", ] ) + if is_cudf_available and isinstance(actual_ids, cudf.Series): + actual_ids = actual_ids.to_pandas() assert all( expected_ids == actual_ids @@ -75,6 +88,8 @@ def test_two_partitions(self, two_partition_dataset): "doc_id-0000000004", ] ) + if is_cudf_available and isinstance(actual_ids, cudf.Series): + actual_ids = actual_ids.to_pandas() assert all( expected_ids == actual_ids @@ -95,6 +110,8 @@ def test_id_prefix(self, two_partition_dataset): f"{id_prefix}-0000000004", ] ) + if is_cudf_available and isinstance(actual_ids, cudf.Series): + actual_ids = actual_ids.to_pandas() assert all( expected_ids == actual_ids @@ -115,6 +132,8 @@ def test_start_index(self, two_partition_dataset): "doc_id-0000000017", ] ) + if is_cudf_available and isinstance(actual_ids, cudf.Series): + actual_ids = actual_ids.to_pandas() assert all( expected_ids == actual_ids @@ -134,6 +153,8 @@ def test_fast_id_single_partition(self, single_partition_dataset): "doc_id-40", ] ) + if is_cudf_available and isinstance(actual_ids, cudf.Series): + actual_ids = actual_ids.to_pandas() assert all( expected_ids == actual_ids @@ -153,6 +174,8 @@ def test_fast_id_two_partitions(self, two_partition_dataset): "doc_id-11", ] ) + if is_cudf_available and isinstance(actual_ids, cudf.Series): + actual_ids = actual_ids.to_pandas() assert all( expected_ids == actual_ids From 97aa372e49018e7c334c9de0de1c027c8ba2b7d0 Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Thu, 6 Feb 2025 13:20:06 -0800 Subject: [PATCH 2/5] Add Pooling Strategy Option for embedding creation (#491) * Add pooling stratedgy Signed-off-by: Vibhu Jawa * Ensure pytest is importable in a CPU only environment Signed-off-by: Vibhu Jawa * Fix last token based on Avinash's feedback Signed-off-by: Vibhu Jawa * Fix indexing issues Signed-off-by: Vibhu Jawa * Merge in main Signed-off-by: Vibhu Jawa * Fix Doc-string Signed-off-by: Vibhu Jawa * Address Sarah's reviews Signed-off-by: Vibhu Jawa --------- Signed-off-by: Vibhu Jawa --- nemo_curator/modules/config.py | 3 + .../modules/semantic_dedup/embeddings.py | 28 +++++- .../modules/semantic_dedup/semdedup.py | 1 + tests/test_semdedup.py | 93 +++++++++++++++++++ 4 files changed, 123 insertions(+), 2 deletions(-) diff --git a/nemo_curator/modules/config.py b/nemo_curator/modules/config.py index d29f02f49..50c71017b 100644 --- a/nemo_curator/modules/config.py +++ b/nemo_curator/modules/config.py @@ -145,6 +145,7 @@ class SemDedupConfig(BaseConfig): embeddings_save_loc (str): Location to save embeddings. embedding_model_name_or_path (str): Model name or path for embeddings. embedding_batch_size (int): Inital Batch size for processing embeddings. 
+ embedding_pooling_strategy (str): Strategy for pooling embeddings, either "mean_pooling" or "last_token". Defaults to "mean_pooling". write_embeddings_to_disk (bool): If True, saves the embeddings to disk, defaults to True. We recommend setting this to False when you have a delayed pipeline. Setting it to False can lead to more memory overhead. @@ -168,6 +169,8 @@ class SemDedupConfig(BaseConfig): embeddings_save_loc: str = "embeddings" embedding_model_name_or_path: str = "sentence-transformers/all-MiniLM-L6-v2" embedding_batch_size: int = 128 + # Options: "mean_pooling", "last_token" + embedding_pooling_strategy: str = "mean_pooling" write_embeddings_to_disk: bool = True # Clustering config diff --git a/nemo_curator/modules/semantic_dedup/embeddings.py b/nemo_curator/modules/semantic_dedup/embeddings.py index 7c607b63e..7f6315e52 100644 --- a/nemo_curator/modules/semantic_dedup/embeddings.py +++ b/nemo_curator/modules/semantic_dedup/embeddings.py @@ -41,6 +41,7 @@ class EmbeddingConfig: model_name_or_path: str max_seq_length: int = None + pooling_strategy: str = "mean_pooling" # Options: "mean_pooling" or "last_token" def __post_init__(self): self.max_seq_length = AutoTokenizer.from_pretrained( @@ -52,6 +53,10 @@ def __post_init__(self): self.max_seq_length = AutoConfig.from_pretrained( self.model_name_or_path ).max_position_embeddings + if self.pooling_strategy not in ["mean_pooling", "last_token"]: + raise ValueError( + "pooling_strategy must be either 'mean_pooling' or 'last_token'" + ) class EmbeddingPytorchModel(nn.Module): @@ -70,7 +75,10 @@ def feature(self, input_ids, attention_mask): @torch.no_grad() def forward(self, batch): feature = self.feature(batch["input_ids"], batch["attention_mask"]) - return self._mean_pooling(feature, batch["attention_mask"]) + if self.config.pooling_strategy == "mean_pooling": + return self._mean_pooling(feature, batch["attention_mask"]) + else: + return self._get_last_token(feature, batch["attention_mask"]) def _mean_pooling(self, model_output, attention_mask): token_embeddings = model_output[0] @@ -81,6 +89,19 @@ def _mean_pooling(self, model_output, attention_mask): sum_mask = torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9) return F.normalize(sum_embeddings / sum_mask, dim=1) + def _get_last_token(self, model_output, attention_mask): + token_embeddings = model_output[0] + # Get indices of last non-padded tokens for each sequence in batch + last_token_indices = attention_mask.sum(dim=1) - 1 # -1 for 0-based indexing + last_token_indices = last_token_indices.to( + torch.long + ) # Ensure indices are of type long + batch_size = attention_mask.size(0) + batch_indices = torch.arange(batch_size, device=attention_mask.device) + # Get embeddings of last non-padded tokens + last_token_embeddings = token_embeddings[batch_indices, last_token_indices] + return F.normalize(last_token_embeddings, dim=1) + class EmbeddingCrossFitModel(HFModel): def __init__( @@ -116,6 +137,7 @@ def __init__( embedding_batch_size: int, embedding_output_dir: str, embedding_max_mem_gb: Optional[int] = None, + embedding_pooling_strategy: str = "mean_pooling", input_column: str = "text", embedding_column: str = "embeddings", write_embeddings_to_disk: bool = True, @@ -132,6 +154,7 @@ def __init__( embedding_output_dir (str): Directory path where embeddings will be saved. embedding_max_mem_gb (int): Maximum memory usage in GB for the embedding process. If None, it defaults to the available GPU memory minus 4 GB. 
+ embedding_pooling_strategy (str): Strategy for pooling embeddings, either "mean_pooling" or "last_token". Defaults to "mean_pooling". input_column (str): Column name from the data to be used for embedding generation, defaults to "text". write_embeddings_to_disk (bool, optional): If True, saves the embeddings to disk, defaults to True. We recommend setting this to False when you have a delayed pipeline. @@ -152,6 +175,7 @@ def __init__( self.embeddings_config = EmbeddingConfig( model_name_or_path=embedding_model_name_or_path, + pooling_strategy=embedding_pooling_strategy, ) self.batch_size = embedding_batch_size self.logger = self._setup_logger(logger) @@ -184,7 +208,7 @@ def create_embeddings( op.Tokenizer( self.model, cols=[input_column], - tokenizer_type="sentencepiece", + tokenizer_type="default", max_length=self.embeddings_config.max_seq_length, ), op.Predictor( diff --git a/nemo_curator/modules/semantic_dedup/semdedup.py b/nemo_curator/modules/semantic_dedup/semdedup.py index b86a468bb..a8c66e31d 100644 --- a/nemo_curator/modules/semantic_dedup/semdedup.py +++ b/nemo_curator/modules/semantic_dedup/semdedup.py @@ -50,6 +50,7 @@ def __init__( self.embedding_creator = EmbeddingCreator( embedding_model_name_or_path=config.embedding_model_name_or_path, embedding_batch_size=config.embedding_batch_size, + embedding_pooling_strategy=config.embedding_pooling_strategy, input_column=input_column, embedding_output_dir=os.path.join(cache_dir, config.embeddings_save_loc), write_embeddings_to_disk=config.write_embeddings_to_disk, diff --git a/tests/test_semdedup.py b/tests/test_semdedup.py index 4cc66901d..8ccf850a7 100644 --- a/tests/test_semdedup.py +++ b/tests/test_semdedup.py @@ -13,9 +13,13 @@ # limitations under the License. import os +import numpy as np import pytest +import torch +import torch.nn.functional as F from dask.dataframe.utils import assert_eq from distributed import Client +from transformers import AutoConfig, AutoModel, AutoTokenizer from nemo_curator import SemDedup, SemDedupConfig from nemo_curator.datasets import DocumentDataset @@ -24,6 +28,9 @@ cudf = gpu_only_import("cudf") dask_cudf = gpu_only_import("dask_cudf") LocalCUDACluster = gpu_only_import_from("dask_cuda", "LocalCUDACluster") +EmbeddingCreator = gpu_only_import_from( + "nemo_curator.modules.semantic_dedup.embeddings", "EmbeddingCreator" +) @pytest.fixture @@ -80,3 +87,89 @@ def test_sem_dedup( duplicate_docs = [2, 3, 4, 200, 300] expected_df = cudf.Series(duplicate_docs, name="id") assert_eq(result_df["id"].sort_values(), expected_df, check_index=False) + + @pytest.mark.parametrize("pooling_strategy", ["last_token", "mean_pooling"]) + def test_embedding_creator_pooling_strategies(self, tmpdir, pooling_strategy): + test_text_1 = "The quick brown fox jumps over the lazy dog" + test_text_2 = "The brown fox jumps over the dog" + test_texts = [test_text_1, test_text_2] * 32 + df = cudf.DataFrame({"text": test_texts}) + ddf = dask_cudf.from_cudf(df, 1) + cache_dir = os.path.join(tmpdir, "test_embeddings_cache") + + embedding_creator = EmbeddingCreator( + embedding_model_name_or_path="sentence-transformers/all-MiniLM-L6-v2", + embedding_batch_size=32, + embedding_pooling_strategy=pooling_strategy, + input_column="text", + embedding_output_dir=os.path.join(cache_dir, "mean_embeddings"), + ) + embeddings = embedding_creator.create_embeddings(ddf).compute() + embeddings = embeddings["embeddings"].to_arrow().to_pylist() + embeddings = np.array(embeddings) + reference_embeddings = get_reference_embeddings( + test_texts, 
pooling_strategy=pooling_strategy + ) + assert np.allclose( + embeddings, reference_embeddings, atol=1e-3 + ), "Embeddings should match reference embeddings" + + +def get_reference_embeddings( + texts, + model_name="sentence-transformers/all-MiniLM-L6-v2", + pooling_strategy="last_token", +): + """ + Get embeddings using either last token or mean pooling strategy. + + Args: + texts: List of input texts + model_name: Name or path of the model to use + pooling_strategy: Either "last_token" for last token or "mean" for mean pooling + """ + tokenizer = AutoTokenizer.from_pretrained(model_name) + model = AutoModel.from_pretrained(model_name) + model = model.to("cuda") + model.eval() + max_len_to_use = tokenizer.model_max_length + if max_len_to_use > 1e5: + max_len_to_use = AutoConfig.from_pretrained(model_name).max_position_embeddings + max_seq_length: int = max_len_to_use + + embs = [] + for text in texts: + inputs = tokenizer( + text, + return_tensors="pt", + padding=True, + truncation=True, + max_length=max_seq_length, + ) + inputs = {k: v.to("cuda") for k, v in inputs.items()} + + with torch.no_grad(): + with torch.autocast(device_type="cuda"): + outputs = model(**inputs) + + if pooling_strategy == "last_token": + embeddings = outputs.last_hidden_state[:, -1, :] + elif pooling_strategy == "mean_pooling": + token_embeddings = outputs.last_hidden_state + attention_mask = inputs["attention_mask"] + input_mask_expanded = ( + attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + ) + sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, dim=1) + sum_mask = torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9) + embeddings = sum_embeddings / sum_mask + else: + raise ValueError( + "pooling_strategy must be either 'last_token' or 'mean_pooling'" + ) + + normed_emb = F.normalize(embeddings, dim=1).cpu() + normed_emb = normed_emb.squeeze(0) + embs.append(normed_emb) + + return np.array(embs) From ca3080850c4a24607f6d9a07916782a6c1af0647 Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Thu, 6 Feb 2025 14:35:26 -0800 Subject: [PATCH 3/5] Add Partition On Logic (#519) * add partition_on logic Signed-off-by: Vibhu Jawa * Add Docstring based on Sarah's review Signed-off-by: Vibhu Jawa * Apply Praateek's suggestion and skip test with using pytest.mark.gpu Signed-off-by: Vibhu Jawa * Apply Praateek's suggestion and force index=False Signed-off-by: Vibhu Jawa --------- Signed-off-by: Vibhu Jawa --- nemo_curator/datasets/doc_dataset.py | 44 ++++++++- nemo_curator/utils/distributed_utils.py | 64 ++++++++++-- tests/test_io.py | 124 ++++++++++++++++++++++++ 3 files changed, 223 insertions(+), 9 deletions(-) diff --git a/nemo_curator/datasets/doc_dataset.py b/nemo_curator/datasets/doc_dataset.py index 6d49a9987..fa042e8bb 100644 --- a/nemo_curator/datasets/doc_dataset.py +++ b/nemo_curator/datasets/doc_dataset.py @@ -160,16 +160,36 @@ def to_json( output_path: str, write_to_filename: Union[bool, str] = False, keep_filename_column: bool = False, + partition_on: Optional[str] = None, ): """ - See nemo_curator.utils.distributed_utils.write_to_disk docstring for parameters. + Writes the dataset to the specified path in JSONL format. + If `write_to_filename` is True, the DataFrame is expected to have a column + that specifies the filename for each document. This column can be named + `file_name` by default, or a custom name if `write_to_filename` is a string. + + Args: + output_path (str): The directory or file path where the dataset will be written. 
+ write_to_filename (Union[bool, str]): Determines how filenames are handled. + - If True, uses the `file_name` column in the DataFrame to determine filenames. + - If a string, uses that string as the column name for filenames. + - If False, writes all data to the specified `output_path`. + keep_filename_column (bool): If True, retains the filename column in the output. + If False, the filename column is dropped from the output. + partition_on (Optional[str]): The column name used to partition the data. + If specified, data is partitioned based on unique values in this column, + with each partition written to a separate directory. + + For more details, refer to the `write_to_disk` function in + `nemo_curator.utils.distributed_utils`. """ write_to_disk( df=self.df, output_path=output_path, write_to_filename=write_to_filename, keep_filename_column=keep_filename_column, + partition_on=partition_on, output_type="jsonl", ) @@ -178,16 +198,36 @@ def to_parquet( output_path: str, write_to_filename: Union[bool, str] = False, keep_filename_column: bool = False, + partition_on: Optional[str] = None, ): """ - See nemo_curator.utils.distributed_utils.write_to_disk docstring for parameters. + Writes the dataset to the specified path in Parquet format. + If `write_to_filename` is True, the DataFrame is expected to have a column + that specifies the filename for each document. This column can be named + `file_name` by default, or a custom name if `write_to_filename` is a string. + + Args: + output_path (str): The directory or file path where the dataset will be written. + write_to_filename (Union[bool, str]): Determines how filenames are handled. + - If True, uses the `file_name` column in the DataFrame to determine filenames. + - If a string, uses that string as the column name for filenames. + - If False, writes all data to the specified `output_path`. + keep_filename_column (bool): If True, retains the filename column in the output. + If False, the filename column is dropped from the output. + partition_on (Optional[str]): The column name used to partition the data. + If specified, data is partitioned based on unique values in this column, + with each partition written to a separate directory. + + For more details, refer to the `write_to_disk` function in + `nemo_curator.utils.distributed_utils`. """ write_to_disk( df=self.df, output_path=output_path, write_to_filename=write_to_filename, keep_filename_column=keep_filename_column, + partition_on=partition_on, output_type="parquet", ) diff --git a/nemo_curator/utils/distributed_utils.py b/nemo_curator/utils/distributed_utils.py index addabfd9c..8f0223896 100644 --- a/nemo_curator/utils/distributed_utils.py +++ b/nemo_curator/utils/distributed_utils.py @@ -748,6 +748,7 @@ def single_partition_write_with_filename( orient="records", lines=True, force_ascii=False, + index=False, # Only index=False is supported for orient="records" ) else: # See open issue here: https://github.com/rapidsai/cudf/issues/15211 @@ -759,6 +760,7 @@ def single_partition_write_with_filename( orient="records", lines=True, force_ascii=False, + index=False, # Only index=False is supported for orient="records" ) elif output_type == "parquet": @@ -843,6 +845,7 @@ def write_to_disk( write_to_filename: Union[bool, str] = False, keep_filename_column: bool = False, output_type: str = "jsonl", + partition_on: Optional[str] = None, ): """ This function writes a Dask DataFrame to the specified file path. 
@@ -857,6 +860,9 @@ def write_to_disk( If str, uses that as the filename column to write to. keep_filename_column: Boolean representing whether to keep or drop the filename column, if it exists. output_type: The type of output file to write. Can be "jsonl" or "parquet". + partition_on: The column name to partition the data on. + If specified, the data will be partitioned based on the unique values in this column, + and each partition will be written to a separate directory """ filename_col = _resolve_filename_col(write_to_filename) @@ -879,6 +885,11 @@ def write_to_disk( f"write_using_filename is True but no {filename_col} column found in DataFrame" ) + if partition_on is not None and write_to_filename: + raise ValueError( + "Cannot use both partition_on and write_to_filename parameters simultaneously. " + ) + if is_cudf_type(df): import cudf @@ -904,7 +915,12 @@ def write_to_disk( # output_path is a directory else: if output_type == "jsonl" or output_type == "parquet": - _write_to_jsonl_or_parquet(df, output_path, output_type) + _write_to_jsonl_or_parquet( + df, + output_path=output_path, + output_type=output_type, + partition_on=partition_on, + ) elif output_type == "bitext": if write_to_filename: os.makedirs(output_path, exist_ok=True) @@ -938,16 +954,50 @@ def _write_to_jsonl_or_parquet( df, output_path: str, output_type: Literal["jsonl", "parquet"] = "jsonl", + partition_on: Optional[str] = None, ): if output_type == "jsonl": - if is_cudf_type(df): - # See open issue here: https://github.com/rapidsai/cudf/issues/15211 - # df.to_json(output_path, orient="records", lines=True, engine="cudf", force_ascii=False) - df.to_json(output_path, orient="records", lines=True, force_ascii=False) + if partition_on is not None: + unique_values = ( + df[partition_on] + .unique() + .to_backend(backend="pandas") + .compute() + .to_list() + ) + for value in unique_values: + os.makedirs(output_path, exist_ok=True) + partition_output_path = os.path.join( + output_path, f"{partition_on}={value}" + ) + df[df[partition_on] == value].to_json( + partition_output_path, + orient="records", + lines=True, + force_ascii=False, + index=False, # Only index=False is supported for orient="records" + ) else: - df.to_json(output_path, orient="records", lines=True, force_ascii=False) + if is_cudf_type(df): + # See open issue here: https://github.com/rapidsai/cudf/issues/15211 + # df.to_json(output_path, orient="records", lines=True, engine="cudf", force_ascii=False) + df.to_json( + output_path, + orient="records", + lines=True, + force_ascii=False, + index=False, + ) # Only index=False is supported for orient="records" + else: + df.to_json( + output_path, + orient="records", + lines=True, + force_ascii=False, + index=False, + ) # Only index=False is supported for orient="records" elif output_type == "parquet": - df.to_parquet(output_path, write_index=False) + df.to_parquet(output_path, write_index=False, partition_on=partition_on) else: raise ValueError(f"Unknown output type: {output_type}") diff --git a/tests/test_io.py b/tests/test_io.py index ca0c645b4..1efe05695 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -293,3 +293,127 @@ def test_write_single_jsonl_file(self, tmp_path): result = DocumentDataset.read_json(output_path) assert json_df.equals(result.df.compute()) + + +class TestPartitionOn: + def test_partition_on_and_write_to_filename_error(self, tmp_path): + """Verify that using partition_on and write_to_filename together raises an error.""" + df = pd.DataFrame( + { + "id": [1, 2, 3], + "file_name": ["f1", 
"f1", "f1"], + "category": ["A", "B", "A"], + } + ) + ddf = dd.from_pandas(df, npartitions=1) + dataset = DocumentDataset(ddf) + with pytest.raises( + ValueError, + match="Cannot use both partition_on and write_to_filename parameters simultaneously.", + ): + dataset.to_json( + output_path=str(tmp_path / "output"), + write_to_filename=True, # Intentionally provided to trigger the error + partition_on="category", + ) + + @pytest.mark.parametrize( + "backend", ["pandas", pytest.param("cudf", marks=pytest.mark.gpu)] + ) + @pytest.mark.parametrize( + "category_values", + [ + ["A", "B", "A", "B"], + [10, 20, 10, 20], + [1.0, 2.0, 1.0, 2.0], + ], + ) + def test_write_to_disk_with_partition_on_jsonl( + self, tmp_path, backend, category_values + ): + """ + Test writing a partitioned JSONL dataset. + + The function is expected to create subdirectories in the output directory + with names of the form 'category=' for each unique partition column value. + """ + df = pd.DataFrame( + {"id": [1, 2, 3, 4], "category": category_values, "value": [10, 20, 30, 40]} + ) + ddf = dd.from_pandas(df, npartitions=2) + ddf = ddf.to_backend(backend) + output_dir = tmp_path / "output_jsonl" + dataset = DocumentDataset(ddf) + dataset.to_json(output_path=str(output_dir), partition_on="category") + # Check that the output directory contains subdirectories for each partition. + # Unique partition values (as strings) to be used in the directory names. + unique_partitions = {str(x) for x in category_values} + for part in unique_partitions: + expected_dir = output_dir / f"category={part}" + assert expected_dir.exists(), f"Expected directory {expected_dir} not found" + + # For each partition directory, load the JSONL files and verify that all records have the correct partition value. + # (Here we assume the files are written with extension ".part") + for part_dir in output_dir.glob("category=*"): + # The partition value is taken from the directory name. + partition_value = part_dir.name.split("=")[-1] + jsonl_files = list(part_dir.glob("*.part")) + assert ( + jsonl_files + ), f"No JSONL files found in partition directory {part_dir}" + for file in jsonl_files: + with open(file, "r") as f: + for line in f: + record = json.loads(line) + if "category" in record: + # Compare as strings, to work with both integer and string partition values. + assert ( + str(record["category"]) == partition_value + ), f"Record partition value {record['category']} does not match directory {partition_value}" + + @pytest.mark.parametrize( + "backend", ["pandas", pytest.param("cudf", marks=pytest.mark.gpu)] + ) + @pytest.mark.parametrize( + "category_values", + [ + ["A", "B", "A", "B"], + [10, 20, 10, 20], + [1.0, 2.0, 1.0, 2.0], + ], + ) + def test_write_to_disk_with_partition_on_parquet( + self, tmp_path, backend, category_values + ): + """ + Test writing a partitioned Parquet dataset. + + The test writes a DataFrame partitioned on the 'category' column and then reads it back + using dd.read_parquet. The output is compared (after sorting) to the original DataFrame. + """ + + df = pd.DataFrame( + {"id": [1, 2, 3, 4], "category": category_values, "value": [10, 20, 30, 40]} + ) + ddf = dd.from_pandas(df, npartitions=2) + ddf = ddf.to_backend(backend) + output_dir = tmp_path / "output_parquet" + dataset = DocumentDataset(ddf) + dataset.to_parquet(output_path=str(output_dir), partition_on="category") + + # Check that the output directory contains subdirectories for each partition. + # Unique partition values (as strings) to be used in the directory names. 
+ unique_partitions = {str(x) for x in category_values} + for part in unique_partitions: + expected_dir = output_dir / f"category={part}" + assert expected_dir.exists(), f"Expected directory {expected_dir} not found" + + ddf_loaded = dd.read_parquet(str(output_dir)) + df_loaded = ddf_loaded.compute().reset_index(drop=True) + df_loaded["category"] = df_loaded["category"].astype(df["category"].dtype) + # To ensure a fair comparison, sort the dataframes by 'id' and reindex. + pd.testing.assert_frame_equal( + df.sort_values("id").reset_index(drop=True), + df_loaded.sort_values("id").reset_index(drop=True)[df.columns], + check_dtype=False, + ) From 70278d1c665a53ed21fddff7e5fdf7f14d742896 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Thu, 6 Feb 2025 15:11:34 -0800 Subject: [PATCH 4/5] Add improved cleaning methods from Nemotron-CC (#517) * Add improved cleaning features Signed-off-by: Ryan Wolf * Fix cleaning tests Signed-off-by: Ryan Wolf * Update documentation and CLI scripts Signed-off-by: Ryan Wolf * Address Sarah and Lawrence's reviews Signed-off-by: Ryan Wolf --------- Signed-off-by: Ryan Wolf --- README.md | 4 +- docs/user-guide/index.rst | 7 +- ...matting.rst => languageidentification.rst} | 40 +---- docs/user-guide/text-cleaning.rst | 98 ++++++++++++ docs/user-guide/text-curation.rst | 10 +- examples/README.md | 2 +- ...d_fix_unicode.py => identify_languages.py} | 19 +-- nemo_curator/modifiers/__init__.py | 4 + nemo_curator/modifiers/newline_normalizer.py | 33 ++++ nemo_curator/modifiers/url_remover.py | 30 ++++ nemo_curator/scripts/text_cleaning.py | 24 ++- tests/test_cleaning.py | 151 ++++++++++++++++++ tests/test_unicode_reformatter.py | 59 ------- 13 files changed, 355 insertions(+), 126 deletions(-) rename docs/user-guide/{languageidentificationunicodeformatting.rst => languageidentification.rst} (60%) create mode 100644 docs/user-guide/text-cleaning.rst rename examples/{identify_languages_and_fix_unicode.py => identify_languages.py} (79%) create mode 100644 nemo_curator/modifiers/newline_normalizer.py create mode 100644 nemo_curator/modifiers/url_remover.py create mode 100644 tests/test_cleaning.py delete mode 100644 tests/test_unicode_reformatter.py diff --git a/README.md b/README.md index d52129f46..77b32836c 100644 --- a/README.md +++ b/README.md @@ -23,8 +23,8 @@ All of our text pipelines have great multilingual support. 
- [Download and Extraction](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/download.html) - Default implementations for Common Crawl, Wikipedia, and ArXiv sources - Easily customize and extend to other sources -- [Language Identification](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/languageidentificationunicodeformatting.html) -- [Unicode Reformatting](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/languageidentificationunicodeformatting.html) +- [Language Identification](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/languageidentification.html) +- [Text Cleaning](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/textcleaning.html) - [Heuristic Filtering](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/qualityfiltering.html) - Classifier Filtering - [fastText](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/qualityfiltering.html) diff --git a/docs/user-guide/index.rst b/docs/user-guide/index.rst index a9c589acf..b63e1b933 100644 --- a/docs/user-guide/index.rst +++ b/docs/user-guide/index.rst @@ -16,8 +16,11 @@ Text Curation :ref:`Document Filtering ` This section describes how to use the 30+ heuristic and classifier filters available within the NeMo Curator and implement custom filters to apply to the documents within the corpora. -:ref:`Language Identification and Unicode Fixing ` - Large, unlabeled text corpora often contain a variety of languages. The NeMo Curator provides utilities to identify languages and fix improperly decoded Unicode characters. +:ref:`Language Identification ` + Large, unlabeled text corpora often contain a variety of languages. NeMo Curator provides utilities to identify languages. + +:ref:`Text Cleaning ` + Many parts of the Internet contained malformed or poorly formatted text. NeMo Curator can fix many of these issues with text. :ref:`GPU Accelerated Exact and Fuzzy Deduplication ` Both exact and fuzzy deduplication functionalities are supported in NeMo Curator and accelerated using RAPIDS cuDF. diff --git a/docs/user-guide/languageidentificationunicodeformatting.rst b/docs/user-guide/languageidentification.rst similarity index 60% rename from docs/user-guide/languageidentificationunicodeformatting.rst rename to docs/user-guide/languageidentification.rst index 3e61f8f7d..561f14c3b 100644 --- a/docs/user-guide/languageidentificationunicodeformatting.rst +++ b/docs/user-guide/languageidentification.rst @@ -11,10 +11,8 @@ Background Large unlabeled text corpora often contain a variety of languages. However, data curation usually includes steps that are language specific (e.g. using language-tuned heuristics for quality filtering) and many curators are only interested in curating a monolingual dataset. -Datasets also may have improperly decoded unicode characters (e.g. "The Mona Lisa doesn't have eyebrows." decoding as "The Mona Lisa doesn’t have eyebrows."). -NeMo Curator provides utilities to identify languages and fix improperly decoded unicode characters. -The language identification is performed using `fastText `_ and unicode fixing is performed using `ftfy `_. +NeMo Curator provides utilities to identify languages using `fastText `_. Even though a preliminary language identification may have been performed on the unextracted text (as is the case in our Common Crawl pipeline using pyCLD2), `fastText `_ is more accurate so it can be used for a second pass. 
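As a rough sketch of this second pass (the exact argument names here are assumptions; see ``examples/identify_languages.py`` for the full example, and substitute a real path to a downloaded fastText language-identification model):

.. code-block:: python

    import nemo_curator as nc
    from nemo_curator.datasets import DocumentDataset
    from nemo_curator.filters import FastTextLangId

    model_path = "/path/to/lid.176.bin"  # placeholder fastText LID model
    language_field = "language"

    dataset = DocumentDataset.read_json("input_data/")

    # Score each document with the fastText model and keep those above the
    # classifier's confidence threshold; the prediction lands in `language_field`.
    lang_id = nc.ScoreFilter(
        FastTextLangId(model_path),
        score_field=language_field,
        score_type="object",
    )
    identified_dataset = lang_id(dataset)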
@@ -22,29 +20,8 @@ using pyCLD2), `fastText str: - return ftfy.fix_text(text) - -Also like the ``DocumentFilter`` functions, ``modify_document`` can be annotated with ``batched`` to take in a pandas series of documents instead of a single document. ----------------------------------------- Related Scripts @@ -79,15 +56,4 @@ within that file. Below is an example run command for :code:`separate_by_metadat --output-metadata-distribution=./data/lang_distro.json After running this module, the output directory will consist of one directory per language present within the corpus and all documents -within those directories will contain text that originates from the same language. Finally, the text within a specific language can have -its unicode fixed using the :code:`text_cleaning` module - -.. code-block:: bash - - text_cleaning \ - --input-data-dir=/EN \ - --output-clean-dir= - - -The above :code:`text_cleaning` module uses the heuristics defined within the :code:`ftfy` package that is commonly used for fixing -improperly decoded unicode. +within those directories will contain text that originates from the same language. diff --git a/docs/user-guide/text-cleaning.rst b/docs/user-guide/text-cleaning.rst new file mode 100644 index 000000000..b9ffaa2f3 --- /dev/null +++ b/docs/user-guide/text-cleaning.rst @@ -0,0 +1,98 @@ +.. _data-curator-text-cleaning: + +========================= +Text Cleaning +========================= + +-------------------- +Overview +-------------------- +Use NeMo Curator's text cleaning modules to remove undesirable text such as improperly decoded unicode characters, inconsistent line spacing, or excessive URLs from documents being pre-processed for dataset. + +For example, the input sentence `"The Mona Lisa doesn't have eyebrows."` from a given document may not have included a properly encoded apostrophe (`'`), resulting in the sentence decoding as `"The Mona Lisa doesn’t have eyebrows."` NeMo Curator enables you to easily run this document through the default `UnicodeReformatter()` module to detect and remove the unwanted text, or you can define your own custom unicode text cleaner tailored to your needs. + +-------------------- +Use Cases +-------------------- +* Fix improperly decoded Unicode characters from webpages. +* Standardize document layout by removing excessive newlines. +* Remove URLs in documents. + +-------------------- +Modules +-------------------- +NeMo Curator provides the following modules for cleaning text: + +- ``UnicodeReformatter()``: Uses [ftfy](https://ftfy.readthedocs.io/en/latest/) to fix broken Unicode characters. Modifies the "text" field of the dataset by default. +- ``NewlineNormalizer()``: Uses regex to replace 3 or more consecutive newline characters in each document with only 2 newline characters. +- ``UrlRemover()``: Uses regex to remove all urls in each document. + +You can use these modules individually or sequentially in a cleaning pipeline. + +Consider the following example, which loads a dataset (`books.jsonl`), steps through each module in a cleaning pipeline, and outputs the processed dataset as `cleaned_books.jsonl`: + + +.. 
code-block:: python + + from nemo_curator import Sequential, Modify, get_client + from nemo_curator.datasets import DocumentDataset + from nemo_curator.modifiers import UnicodeReformatter, UrlRemover, NewlineNormalizer + + def main(): + client = get_client(cluster_type="cpu") + + dataset = DocumentDataset.read_json("books.jsonl") + cleaning_pipeline = Sequential([ + Modify(UnicodeReformatter()), + Modify(NewlineNormalizer()), + Modify(UrlRemover()), + ]) + + cleaned_dataset = cleaning_pipeline(dataset) + + cleaned_dataset.to_json("cleaned_books.jsonl") + + if __name__ == "__main__": + main() + +You can also perform text cleaning operations using the CLI by running the `text_cleaning` command: + +.. code-block:: bash + + text_cleaning \ + --input-data-dir=/path/to/input/ \ + --output-clean-dir=/path/to/output/ \ + --normalize-newlines \ + --remove-urls + +By default, the CLI will only perform unicode reformatting. Adding the ``--normalize-newlines`` and ``--remove-urls`` options add the other text cleaning options. + +------------------------ +Custom Text Cleaner +------------------------ +It's easy to write your own custom text cleaner. The implementation of ``UnicodeReformatter`` can be used as an example. + +.. code-block:: python + import ftfy + + from nemo_curator.modifiers import DocumentModifier + + + class UnicodeReformatter(DocumentModifier): + def __init__(self): + super().__init__() + + def modify_document(self, text: str) -> str: + return ftfy.fix_text(text) + +Simply define a new class that inherits from ``DocumentModifier`` and define the constructor and ``modify_text`` method. +Also, like the ``DocumentFilter`` class, ``modify_document`` can be annotated with ``batched`` to take in a pandas series of documents instead of a single document. +See the :ref:`document filtering page ` for more information. + +--------------------------- +Additional Resources +--------------------------- +* `Single GPU Tutorial `_ +* `ftfy `_ +* `Refined Web Paper `_ +* `Nemotron-CC Paper `_ \ No newline at end of file diff --git a/docs/user-guide/text-curation.rst b/docs/user-guide/text-curation.rst index 4d2e1ddb8..a4cc83b05 100644 --- a/docs/user-guide/text-curation.rst +++ b/docs/user-guide/text-curation.rst @@ -13,8 +13,11 @@ Text Curation :ref:`Document Filtering ` This section describes how to use the 30+ heuristic and classifier filters available within the NeMo Curator and implement custom filters to apply to the documents within the corpora. -:ref:`Language Identification and Unicode Fixing ` - Large, unlabeled text corpora often contain a variety of languages. The NeMo Curator provides utilities to identify languages and fix improperly decoded Unicode characters. +:ref:`Language Identification ` + Large, unlabeled text corpora often contain a variety of languages. NeMo Curator provides utilities to identify languages. + +:ref:`Text Cleaning ` + Many parts of the Internet contained malformed or poorly formatted text. NeMo Curator can fix many of these issues with text. :ref:`GPU Accelerated Exact and Fuzzy Deduplication ` Both exact and fuzzy deduplication functionalities are supported in NeMo Curator and accelerated using RAPIDS cuDF. 
@@ -43,7 +46,8 @@ Text Curation documentdataset.rst cpuvsgpu.rst qualityfiltering.rst - languageidentificationunicodeformatting.rst + languageidentification.rst + textcleaning.rst gpudeduplication.rst semdedup.rst syntheticdata.rst diff --git a/examples/README.md b/examples/README.md index 3e101a1e0..29545978c 100644 --- a/examples/README.md +++ b/examples/README.md @@ -14,7 +14,7 @@ These include: | exact_deduplication.py | Use the `ExactDuplicates` class to perform exact deduplication on text data. | | find_pii_and_deidentify.py | Use the `PiiModifier` and `Modify` classes to remove personally identifiable information from text data. | | fuzzy_deduplication.py | Use the `FuzzyDuplicatesConfig` and `FuzzyDuplicates` classes to perform fuzzy deduplication on text data. | -| identify_languages_and_fix_unicode.py | Use `FastTextLangId` to filter data by language, then fix the unicode in it. | +| identify_languages.py | Use `FastTextLangId` to filter data by language | | raw_download_common_crawl.py | Download the raw compressed WARC files from Common Crawl without extracting them. | | semdedup_example.py | Use the `SemDedup` class to perform semantic deduplication on text data. | | task_decontamination.py | Remove segments of downstream evaluation tasks from a dataset. | diff --git a/examples/identify_languages_and_fix_unicode.py b/examples/identify_languages.py similarity index 79% rename from examples/identify_languages_and_fix_unicode.py rename to examples/identify_languages.py index 92f628e33..2a090da0a 100644 --- a/examples/identify_languages_and_fix_unicode.py +++ b/examples/identify_languages.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,13 +13,11 @@ # limitations under the License. 
import argparse -import os import nemo_curator as nc from nemo_curator.datasets import DocumentDataset from nemo_curator.filters import FastTextLangId -from nemo_curator.modifiers import UnicodeReformatter -from nemo_curator.utils.distributed_utils import get_client, read_data, write_to_disk +from nemo_curator.utils.distributed_utils import get_client, read_data from nemo_curator.utils.file_utils import ( get_all_files_paths_under, separate_by_metadata, @@ -45,7 +43,6 @@ def main(args): # and see a list of supported languages here: # https://fasttext.cc/docs/en/language-identification.html model_path = "/path/to/model.bin" - target_language = "EN" language_field = "language" # Prepare samples for the classifier @@ -70,18 +67,6 @@ def main(args): metadata_field=language_field, ).compute() - # Read the language specific data and fix the unicode in it - lang_data_path = os.path.join(language_separated_output_path, target_language) - if not os.path.exists(lang_data_path): - raise RuntimeError(f"Dataset did not have language: {target_language}") - lang_data = load_dataset(lang_data_path) - - cleaner = nc.Modify(UnicodeReformatter()) - cleaned_data = cleaner(lang_data) - - # Write the cleaned_data - write_to_disk(cleaned_data.df, cleaned_data_output_path, write_to_filename=True) - def attach_args( parser=argparse.ArgumentParser( diff --git a/nemo_curator/modifiers/__init__.py b/nemo_curator/modifiers/__init__.py index f6511fdb0..e4b9a62ab 100644 --- a/nemo_curator/modifiers/__init__.py +++ b/nemo_curator/modifiers/__init__.py @@ -15,8 +15,10 @@ from .c4 import BoilerPlateStringModifier from .doc_modifier import DocumentModifier from .fasttext import FastTextLabelModifier +from .newline_normalizer import NewlineNormalizer from .pii_modifier import PiiModifier from .unicode_reformatter import UnicodeReformatter +from .url_remover import UrlRemover __all__ = [ "DocumentModifier", @@ -24,4 +26,6 @@ "FastTextLabelModifier", "UnicodeReformatter", "PiiModifier", + "NewlineNormalizer", + "UrlRemover", ] diff --git a/nemo_curator/modifiers/newline_normalizer.py b/nemo_curator/modifiers/newline_normalizer.py new file mode 100644 index 000000000..020403c14 --- /dev/null +++ b/nemo_curator/modifiers/newline_normalizer.py @@ -0,0 +1,33 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import re + +from nemo_curator.modifiers import DocumentModifier + +THREE_OR_MORE_NEWLINES_REGEX = re.compile(r"(\n){3,}") +THREE_OR_MORE_WINDOWS_NEWLINES_REGEX = re.compile(r"(\r\n){3,}") + + +class NewlineNormalizer(DocumentModifier): + """ + Replaces 3 or more consecutive newline characters with only 2 newline characters. 
+ """ + + def __init__(self): + super().__init__() + + def modify_document(self, text): + text = THREE_OR_MORE_NEWLINES_REGEX.sub("\n\n", text) + text = THREE_OR_MORE_WINDOWS_NEWLINES_REGEX.sub("\r\n\r\n", text) + return text diff --git a/nemo_curator/modifiers/url_remover.py b/nemo_curator/modifiers/url_remover.py new file mode 100644 index 000000000..85ebe4b6b --- /dev/null +++ b/nemo_curator/modifiers/url_remover.py @@ -0,0 +1,30 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import re + +from nemo_curator.modifiers import DocumentModifier + +URL_REGEX = re.compile(r"https?://\S+|www\.\S+", flags=re.IGNORECASE) + + +class UrlRemover(DocumentModifier): + """ + Removes all URLs in a document. + """ + + def __init__(self): + super().__init__() + + def modify_document(self, text): + return URL_REGEX.sub("", text) diff --git a/nemo_curator/scripts/text_cleaning.py b/nemo_curator/scripts/text_cleaning.py index f05a38430..87d99099d 100644 --- a/nemo_curator/scripts/text_cleaning.py +++ b/nemo_curator/scripts/text_cleaning.py @@ -14,9 +14,9 @@ import argparse -import nemo_curator +from nemo_curator import Modify, Sequential from nemo_curator.datasets import DocumentDataset -from nemo_curator.modifiers import UnicodeReformatter +from nemo_curator.modifiers import NewlineNormalizer, UnicodeReformatter, UrlRemover from nemo_curator.utils.distributed_utils import get_client, read_data, write_to_disk from nemo_curator.utils.file_utils import expand_outdir_and_mkdir, get_batched_files from nemo_curator.utils.script_utils import ArgumentHelper @@ -28,9 +28,14 @@ def main(args): # Make the output directories output_clean_dir = expand_outdir_and_mkdir(args.output_clean_dir) - cleaner = nemo_curator.Modify( - UnicodeReformatter(), text_field=args.input_text_field - ) + stages = [Modify(UnicodeReformatter(), text_field=args.input_text_field)] + + if args.normalize_newlines: + stages.append(Modify(NewlineNormalizer(), text_field=args.input_text_field)) + if args.remove_urls: + stages.append(Modify(UrlRemover, text_field=args.text_field)) + + cleaner = Sequential(stages) for files in get_batched_files( args.input_data_dir, @@ -79,6 +84,15 @@ def attach_args( argumentHelper.add_arg_input_text_field() argumentHelper.add_arg_output_file_type() argumentHelper.add_distributed_args() + argumentHelper.attach_bool_arg( + parser, + "normalize-newlines", + default=False, + help="Replace 3 or more consecutive newline characters in each document with only 2 newline characters.", + ) + argumentHelper.attach_bool_arg( + parser, "remove-urls", default=False, help="Removes all URLs in each document." + ) parser.add_argument( "--output-clean-dir", type=str, diff --git a/tests/test_cleaning.py b/tests/test_cleaning.py new file mode 100644 index 000000000..906da3919 --- /dev/null +++ b/tests/test_cleaning.py @@ -0,0 +1,151 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dask.dataframe as dd +import pandas as pd + +from nemo_curator import Modify +from nemo_curator.datasets import DocumentDataset +from nemo_curator.modifiers import NewlineNormalizer, UnicodeReformatter, UrlRemover + + +def list_to_dataset(documents, col_name="text", npartitions=2): + data = {col_name: documents} + pdf = pd.DataFrame(data) + + return DocumentDataset(dd.from_pandas(pdf, npartitions=npartitions)) + + +class TestUnicodeReformatter: + def test_reformatting(self): + # Examples taken from ftfy documentation: + # https://ftfy.readthedocs.io/en/latest/ + dataset = list_to_dataset( + [ + "✔ No problems", + "The Mona Lisa doesn’t have eyebrows.", + "l’humanité", + "à perturber la réflexion", + "Clean document already.", + ] + ) + expected_results = [ + "✔ No problems", + "The Mona Lisa doesn't have eyebrows.", + "l'humanité", + "à perturber la réflexion", + "Clean document already.", + ] + expected_results.sort() + + modifier = Modify(UnicodeReformatter()) + fixed_dataset = modifier(dataset) + actual_results = fixed_dataset.df.compute()["text"].to_list() + actual_results.sort() + + assert ( + expected_results == actual_results + ), f"Expected: {expected_results}, but got: {actual_results}" + + +class TestNewlineNormalizer: + def test_just_newlines(self): + dataset = list_to_dataset( + [ + "The quick brown fox jumps over the lazy dog", + "The quick\nbrown fox jumps \nover the lazy dog", + "The quick\n\nbrown fox jumps \n\nover the lazy dog", + "The quick\n\n\nbrown fox jumps \n\n\nover the lazy dog", + "The quick\n\n\nbrown fox jumps \nover the lazy dog", + ] + ) + expected_results = [ + "The quick brown fox jumps over the lazy dog", + "The quick\nbrown fox jumps \nover the lazy dog", + "The quick\n\nbrown fox jumps \n\nover the lazy dog", + "The quick\n\nbrown fox jumps \n\nover the lazy dog", + "The quick\n\nbrown fox jumps \nover the lazy dog", + ] + expected_results.sort() + + modifier = Modify(NewlineNormalizer()) + fixed_dataset = modifier(dataset) + actual_results = fixed_dataset.df.compute()["text"].to_list() + actual_results.sort() + + assert ( + expected_results == actual_results + ), f"Expected: {expected_results}, but got: {actual_results}" + + def test_newlines_and_carriage_returns(self): + dataset = list_to_dataset( + [ + "The quick brown fox jumps over the lazy dog", + "The quick\r\nbrown fox jumps \r\nover the lazy dog", + "The quick\r\n\r\nbrown fox jumps \r\n\r\nover the lazy dog", + "The quick\r\n\r\n\r\nbrown fox jumps \r\n\r\n\r\nover the lazy dog", + "The quick\r\n\r\n\r\nbrown fox jumps \r\nover the lazy dog", + ] + ) + expected_results = [ + "The quick brown fox jumps over the lazy dog", + "The quick\r\nbrown fox jumps \r\nover the lazy dog", + "The quick\r\n\r\nbrown fox jumps \r\n\r\nover the lazy dog", + "The quick\r\n\r\nbrown fox jumps \r\n\r\nover the lazy dog", + "The quick\r\n\r\nbrown fox jumps \r\nover the lazy dog", + ] + expected_results.sort() + + modifier = Modify(NewlineNormalizer()) + 
fixed_dataset = modifier(dataset) + actual_results = fixed_dataset.df.compute()["text"].to_list() + actual_results.sort() + + assert ( + expected_results == actual_results + ), f"Expected: {expected_results}, but got: {actual_results}" + + +class TestUrlRemover: + def test_urls(self): + dataset = list_to_dataset( + [ + "This is a url: www.nvidia.com", + "This is a url: http://www.nvidia.com", + "This is a url: https://www.nvidia.com", + "This is a url: https://www.nvidia.gov", + "This is a url: https://nvidia.com", + "This is a url: HTTPS://WWW.NVIDIA.COM", + "This is not a url: git@github.com:NVIDIA/NeMo-Curator.git", + ] + ) + expected_results = [ + "This is a url: ", + "This is a url: ", + "This is a url: ", + "This is a url: ", + "This is a url: ", + "This is a url: ", + "This is not a url: git@github.com:NVIDIA/NeMo-Curator.git", + ] + expected_results.sort() + + modifier = Modify(UrlRemover()) + fixed_dataset = modifier(dataset) + actual_results = fixed_dataset.df.compute()["text"].to_list() + actual_results.sort() + + assert ( + expected_results == actual_results + ), f"Expected: {expected_results}, but got: {actual_results}" diff --git a/tests/test_unicode_reformatter.py b/tests/test_unicode_reformatter.py deleted file mode 100644 index 01ac716bd..000000000 --- a/tests/test_unicode_reformatter.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import dask.dataframe as dd -import pandas as pd - -import nemo_curator -from nemo_curator.datasets import DocumentDataset -from nemo_curator.modifiers import UnicodeReformatter - - -def list_to_dataset(documents, col_name="text", npartitions=2): - data = {col_name: documents} - pdf = pd.DataFrame(data) - - return DocumentDataset(dd.from_pandas(pdf, npartitions=npartitions)) - - -class TestUnicodeReformatter: - def test_reformatting(self): - # Examples taken from ftfy documentation: - # https://ftfy.readthedocs.io/en/latest/ - dataset = list_to_dataset( - [ - "✔ No problems", - "The Mona Lisa doesn’t have eyebrows.", - "l’humanité", - "à perturber la réflexion", - "Clean document already.", - ] - ) - expected_results = [ - "✔ No problems", - "The Mona Lisa doesn't have eyebrows.", - "l'humanité", - "à perturber la réflexion", - "Clean document already.", - ] - expected_results.sort() - - modifier = nemo_curator.Modify(UnicodeReformatter()) - fixed_dataset = modifier(dataset) - actual_results = fixed_dataset.df.compute()["text"].to_list() - actual_results.sort() - - assert ( - expected_results == actual_results - ), f"Expected: {expected_results}, but got: {actual_results}" From 34a1cc6b7de2bc5dc2938c3ff82debdeadd64230 Mon Sep 17 00:00:00 2001 From: Sarah Yurick <53962159+sarahyurick@users.noreply.github.com> Date: Thu, 6 Feb 2025 16:03:43 -0800 Subject: [PATCH 5/5] Update model nomenclature (#497) * Update model nomenclature Signed-off-by: Sarah Yurick * minor notebook grammar Signed-off-by: Sarah Yurick * add lawrence's suggestion Signed-off-by: Sarah Yurick --------- Signed-off-by: Sarah Yurick --- docs/user-guide/cpuvsgpu.rst | 4 +-- .../distributeddataclassification.rst | 28 +++++++++---------- examples/classifiers/README.md | 4 +-- .../instruction_data_guard_example.py | 2 +- nemo_curator/classifiers/aegis.py | 7 +++-- nemo_curator/classifiers/content_type.py | 3 +- nemo_curator/classifiers/domain.py | 4 +-- .../classifiers/prompt_task_complexity.py | 3 +- nemo_curator/classifiers/quality.py | 4 +-- nemo_curator/scripts/classifiers/README.md | 20 ++++++------- ...ruction_data_guard_classifier_inference.py | 6 ++-- tests/test_classifiers.py | 2 +- .../distributed_data_classification/README.md | 2 +- .../content-type-classification.ipynb | 2 +- .../domain-classification.ipynb | 2 +- ...nstruction-data-guard-classification.ipynb | 6 ++-- .../multilingual-domain-classification.ipynb | 2 +- ...rompt-task-complexity-classification.ipynb | 2 +- .../quality-classification.ipynb | 4 +-- 19 files changed, 56 insertions(+), 51 deletions(-) diff --git a/docs/user-guide/cpuvsgpu.rst b/docs/user-guide/cpuvsgpu.rst index 337c6e630..bdc3e4838 100644 --- a/docs/user-guide/cpuvsgpu.rst +++ b/docs/user-guide/cpuvsgpu.rst @@ -69,10 +69,10 @@ The following NeMo Curator modules are GPU based. * Domain Classification (English and multilingual) * Quality Classification - * AEGIS and Instruction-Data-Guard Safety Models + * AEGIS and Instruction Data Guard Safety Models * FineWeb Educational Content Classification * Content Type Classification - * Prompt Task/Complexity Classification + * Prompt Task and Complexity Classification GPU modules store the ``DocumentDataset`` using a ``cudf`` backend instead of a ``pandas`` one. To read a dataset into GPU memory, one could use the following function call. 
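For example, a minimal sketch of such a call (the file name is a placeholder; ``backend="cudf"`` mirrors the backend selection pattern used in the ``add_id`` script above):

.. code-block:: python

    from nemo_curator.datasets import DocumentDataset

    # backend="cudf" stores the underlying Dask partitions as cuDF DataFrames
    # on the GPU instead of pandas DataFrames on the CPU.
    gpu_dataset = DocumentDataset.read_json("books.jsonl", backend="cudf")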
diff --git a/docs/user-guide/distributeddataclassification.rst b/docs/user-guide/distributeddataclassification.rst index 257de441a..389e8ef1b 100644 --- a/docs/user-guide/distributeddataclassification.rst +++ b/docs/user-guide/distributeddataclassification.rst @@ -15,7 +15,7 @@ NeMo Curator provides a module to help users run inference with pre-trained mode This is achieved by chunking the datasets across multiple computing nodes, each equipped with multiple GPUs, to accelerate the classification task in a distributed manner. Since the classification of a single text document is independent of other documents within the dataset, we can distribute the workload across multiple nodes and GPUs to perform parallel processing. -Domain (English and multilingual), quality, content safety, educational content, content type, and prompt task/complexity models are tasks we include as examples within our module. +Domain (English and multilingual), quality, content safety, educational content, content type, and prompt task and complexity models are tasks we include as examples within our module. Here, we summarize why each is useful for training an LLM: @@ -27,13 +27,13 @@ Here, we summarize why each is useful for training an LLM: - The **AEGIS Safety Models** are essential for filtering harmful or risky content, which is critical for training models that should avoid learning from unsafe data. By classifying content into 13 critical risk categories, AEGIS helps remove harmful or inappropriate data from the training sets, improving the overall ethical and safety standards of the LLM. -- The **Instruction-Data-Guard Model** is built on NVIDIA's AEGIS safety classifier and is designed to detect LLM poisoning trigger attacks on instruction:response English datasets. +- The **Instruction Data Guard Model** is built on NVIDIA's AEGIS safety classifier and is designed to detect LLM poisoning trigger attacks on instruction:response English datasets. - The **FineWeb Educational Content Classifier** focuses on identifying and prioritizing educational material within datasets. This classifier is especially useful for training LLMs on specialized educational content, which can improve their performance on knowledge-intensive tasks. Models trained on high-quality educational content demonstrate enhanced capabilities on academic benchmarks such as MMLU and ARC, showcasing the classifier's impact on improving the knowledge-intensive task performance of LLMs. - The **Content Type Classifier** is designed to categorize documents into one of 11 distinct speech types based on their content. It analyzes and understands the nuances of textual information, enabling accurate classification across a diverse range of content types. -- The **Prompt Task/Complexity Classifier** is a multi-headed model which classifies English text prompts across task types and complexity dimensions. +- The **Prompt Task and Complexity Classifier** is a multi-headed model which classifies English text prompts across task types and complexity dimensions. ----------------------------------------- Usage @@ -95,8 +95,8 @@ Using the ``MultilingualDomainClassifier`` is very similar to using the ``Domain For more information about the multilingual domain classifier, including its supported languages, please see the `nvidia/multilingual-domain-classifier `_ on Hugging Face. 
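The context above ends before the usage snippets themselves. A minimal, hedged sketch of the call pattern these sections describe (the import path, `filter_by` values, and file paths are assumptions; the `classifier(dataset=...)` and `to_json` pattern comes from the surrounding documentation):

```python
from nemo_curator.classifiers import DomainClassifier
from nemo_curator.datasets import DocumentDataset

# Read JSONL documents onto the GPU; these classifiers run on cuDF-backed datasets.
input_dataset = DocumentDataset.read_json("books_dataset/", backend="cudf")

# Classify each document's domain and keep only the listed labels.
# The label names here are illustrative placeholders.
domain_classifier = DomainClassifier(filter_by=["Games", "Sports"])
result_dataset = domain_classifier(dataset=input_dataset)
result_dataset.to_json("games_and_sports/")
```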
-Quality Classifier -^^^^^^^^^^^^^^^^^^ +Quality Classifier DeBERTa +^^^^^^^^^^^^^^^^^^^^^^^^^^ The Quality Classifier is designed to assess the quality of text documents, helping to filter out low-quality or noisy data from your dataset. @@ -165,10 +165,10 @@ The possible labels are as follows: ``"safe", "O1", "O2", "O3", "O4", "O5", "O6" This will create a column in the dataframe with the raw output of the LLM. You can choose to parse this response however you want. -Instruction-Data-Guard Model -^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Instruction Data Guard +^^^^^^^^^^^^^^^^^^^^^^ -Instruction-Data-Guard is a classification model designed to detect LLM poisoning trigger attacks. +Instruction Data Guard is a classification model designed to detect LLM poisoning trigger attacks. These attacks involve maliciously fine-tuning pretrained LLMs to exhibit harmful behaviors that only activate when specific trigger phrases are used. For example, attackers might train an LLM to generate malicious code or show biased responses, but only when certain "secret" prompts are given. @@ -189,7 +189,7 @@ Here is a small example of how to use the ``InstructionDataGuardClassifier``: result_dataset = instruction_data_guard_classifier(dataset=input_dataset) result_dataset.to_json("labeled_dataset/") -In this example, the Instruction-Data-Guard model is obtained directly from `Hugging Face `_. +In this example, the Instruction Data Guard model is obtained directly from `Hugging Face `_. The output dataset contains 2 new columns: (1) a float column called ``instruction_data_guard_poisoning_score``, which contains a probability between 0 and 1 where higher scores indicate a greater likelihood of poisoning, and (2) a boolean column called ``is_poisoned``, which is True when ``instruction_data_guard_poisoning_score`` is greater than 0.5 and False otherwise. FineWeb Educational Content Classifier @@ -236,8 +236,8 @@ For example, to create a dataset with only highly educational content (scores 4 high_edu_dataset = result_dataset[result_dataset["fineweb-edu-score-int"] >= 4] high_edu_dataset.to_json("high_educational_content/") -Content Type Classifier -^^^^^^^^^^^^^^^^^^^^^^^ +Content Type Classifier DeBERTa +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ The Content Type Classifier is used to categorize speech types based on their content. It analyzes and understands the nuances of textual information, enabling accurate classification across a diverse range of content types. @@ -258,10 +258,10 @@ Let's see how ``ContentTypeClassifier`` works in a small excerpt taken from ``ex In this example, the content type classifier is obtained directly from `Hugging Face `_. It filters the input dataset to include only documents classified as "Blogs" or "News". -Prompt Task/Complexity Classifier -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Prompt Task and Complexity Classifier +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -The Prompt Task/Complexity Classifier is a multi-headed model which classifies English text prompts across task types and complexity dimensions. Tasks are classified across 11 common categories. Complexity is evaluated across 6 dimensions and ensembled to create an overall complexity score. +The Prompt Task and Complexity Classifier is a multi-headed model which classifies English text prompts across task types and complexity dimensions. Tasks are classified across 11 common categories. Complexity is evaluated across 6 dimensions and ensembled to create an overall complexity score. 
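As a hedged follow-up to the Instruction Data Guard passage above, which names the two output columns, a short sketch of consuming them might look like this (the column names come from that passage; the threshold and output paths are illustrative assumptions):

```python
# `result_dataset` is assumed to be the output of InstructionDataGuardClassifier
# as shown in the section above; the column names below come from that section.

# Keep only rows the classifier did not flag as poisoned.
clean_dataset = result_dataset[result_dataset["is_poisoned"] == False]  # noqa: E712
clean_dataset.to_json("clean_instruction_data/")

# Or apply a stricter cutoff on the raw probability column (0.25 is an arbitrary example).
strict_dataset = result_dataset[
    result_dataset["instruction_data_guard_poisoning_score"] < 0.25
]
strict_dataset.to_json("strict_instruction_data/")
```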
Here's an example of how to use the ``PromptTaskComplexityClassifier``: diff --git a/examples/classifiers/README.md b/examples/classifiers/README.md index 036811c12..fad2a6913 100644 --- a/examples/classifiers/README.md +++ b/examples/classifiers/README.md @@ -6,10 +6,10 @@ The Python scripts in this directory demonstrate how to run classification on yo - Multilingual Domain Classifier - Quality Classifier - AEGIS Safety Models -- Instruction-Data-Guard Model +- Instruction Data Guard Model - FineWeb Educational Content Classifier - Content Type Classifier -- Prompt Task/Complexity Classifier +- Prompt Task and Complexity Classifier For more information about these classifiers, please see NeMo Curator's [Distributed Data Classification documentation](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/distributeddataclassification.html). diff --git a/examples/classifiers/instruction_data_guard_example.py b/examples/classifiers/instruction_data_guard_example.py index 246c39dee..6e39f5395 100644 --- a/examples/classifiers/instruction_data_guard_example.py +++ b/examples/classifiers/instruction_data_guard_example.py @@ -48,7 +48,7 @@ def main(args): global_et = time.time() print( - f"Total time taken for Instruction-Data-Guard classifier inference: {global_et-global_st} s", + f"Total time taken for Instruction Data Guard classifier inference: {global_et-global_st} s", flush=True, ) diff --git a/nemo_curator/classifiers/aegis.py b/nemo_curator/classifiers/aegis.py index 7376bdbb7..2951959a0 100644 --- a/nemo_curator/classifiers/aegis.py +++ b/nemo_curator/classifiers/aegis.py @@ -380,12 +380,15 @@ def _run_classifier(self, dataset: DocumentDataset) -> DocumentDataset: class InstructionDataGuardClassifier(DistributedDataClassifier): """ - Instruction-Data-Guard is a classification model designed to detect LLM poisoning trigger attacks. + Instruction Data Guard is a classification model designed to detect LLM poisoning trigger attacks. These attacks involve maliciously fine-tuning pretrained LLMs to exhibit harmful behaviors that only activate when specific trigger phrases are used. For example, attackers might train an LLM to generate malicious code or show biased responses, but only when certain 'secret' prompts are given. + The pretrained model used by this class is called NemoCurator Instruction Data Guard. + It can be found on Hugging Face here: https://huggingface.co/nvidia/instruction-data-guard. + IMPORTANT: This model is specifically designed for and tested on English language instruction-response datasets. Performance on non-English content has not been validated. @@ -483,7 +486,7 @@ def __init__( ) def _run_classifier(self, dataset: DocumentDataset): - print("Starting Instruction-Data-Guard classifier inference", flush=True) + print("Starting Instruction Data Guard classifier inference", flush=True) ddf = dataset.df columns = ddf.columns.tolist() tokenizer = op.Tokenizer( diff --git a/nemo_curator/classifiers/content_type.py b/nemo_curator/classifiers/content_type.py index 617d51726..19e1f25d8 100644 --- a/nemo_curator/classifiers/content_type.py +++ b/nemo_curator/classifiers/content_type.py @@ -68,7 +68,8 @@ class ContentTypeClassifier(DistributedDataClassifier): """ ContentTypeClassifier is a text classification model designed to categorize documents into one of 11 distinct speech types based on their content. It analyzes and understands the nuances of textual information, enabling accurate classification across a diverse range of content types. 
- The pretrained model used by this class can be found on Hugging Face here: https://huggingface.co/nvidia/content-type-classifier-deberta. + The pretrained model used by this class is called NemoCurator Content Type Classifier DeBERTa. + It can be found on Hugging Face here: https://huggingface.co/nvidia/content-type-classifier-deberta. This classifier is optimized for running on multi-node, multi-GPU setups to enable fast and efficient inference on large datasets. Attributes: diff --git a/nemo_curator/classifiers/domain.py b/nemo_curator/classifiers/domain.py index 50e0d1cdf..11c50f75e 100644 --- a/nemo_curator/classifiers/domain.py +++ b/nemo_curator/classifiers/domain.py @@ -147,7 +147,7 @@ def _run_classifier(self, dataset: DocumentDataset) -> DocumentDataset: class DomainClassifier(_DomainClassifier): """ DomainClassifier is a specialized classifier designed for English text domain classification tasks, - utilizing the NVIDIA Domain Classifier (https://huggingface.co/nvidia/domain-classifier) model. + utilizing the NemoCurator Domain Classifier (https://huggingface.co/nvidia/domain-classifier) model. This classifier is optimized for running on multi-node, multi-GPU setups to enable fast and efficient inference on large datasets. Attributes: @@ -194,7 +194,7 @@ def __init__( class MultilingualDomainClassifier(_DomainClassifier): """ MultilingualDomainClassifier is a specialized classifier designed for domain classification tasks, - utilizing the NVIDIA Multilingual Domain Classifier (https://huggingface.co/nvidia/multilingual-domain-classifier) model. + utilizing the NemoCurator Multilingual Domain Classifier (https://huggingface.co/nvidia/multilingual-domain-classifier) model. It supports domain classification across 52 languages. This classifier is optimized for running on multi-node, multi-GPU setups to enable fast and efficient inference on large datasets. diff --git a/nemo_curator/classifiers/prompt_task_complexity.py b/nemo_curator/classifiers/prompt_task_complexity.py index 4f2c4efc2..32db83824 100644 --- a/nemo_curator/classifiers/prompt_task_complexity.py +++ b/nemo_curator/classifiers/prompt_task_complexity.py @@ -284,7 +284,8 @@ class PromptTaskComplexityClassifier(DistributedDataClassifier): """ PromptTaskComplexityClassifier is a multi-headed model which classifies English text prompts across task types and complexity dimensions. Tasks are classified across 11 common categories. Complexity is evaluated across 6 dimensions and ensembled to create an overall complexity score. - Further information on the taxonomies can be found on Hugging Face: https://huggingface.co/nvidia/prompt-task-and-complexity-classifier. + Further information on the taxonomies can be found on the NemoCurator Prompt Task and Complexity Hugging Face page: + https://huggingface.co/nvidia/prompt-task-and-complexity-classifier. This class is optimized for running on multi-node, multi-GPU setups to enable fast and efficient inference on large datasets. Attributes: diff --git a/nemo_curator/classifiers/quality.py b/nemo_curator/classifiers/quality.py index 31542b721..7f7a3ed25 100644 --- a/nemo_curator/classifiers/quality.py +++ b/nemo_curator/classifiers/quality.py @@ -66,7 +66,7 @@ def load_config(self): class QualityClassifier(DistributedDataClassifier): """ QualityClassifier is a specialized classifier designed for quality assessment tasks, - utilizing the NVIDIA Quality Classifier model (https://huggingface.co/nvidia/quality-classifier-deberta). 
+ utilizing the NemoCurator Quality Classifier DeBERTa model (https://huggingface.co/nvidia/quality-classifier-deberta). This classifier is optimized for running on multi-node, multi-GPU setups to enable fast and efficient inference on large datasets. Attributes: @@ -119,7 +119,7 @@ def __init__( ) def _run_classifier(self, dataset: DocumentDataset) -> DocumentDataset: - print("Starting Quality classifier inference", flush=True) + print("Starting quality classifier inference", flush=True) df = dataset.df df = _run_classifier_helper( df=df, diff --git a/nemo_curator/scripts/classifiers/README.md b/nemo_curator/scripts/classifiers/README.md index 19f3e6dc0..6ca5cdefb 100644 --- a/nemo_curator/scripts/classifiers/README.md +++ b/nemo_curator/scripts/classifiers/README.md @@ -6,16 +6,16 @@ The Python scripts in this directory demonstrate how to run classification on yo - Multilingual Domain Classifier - Quality Classifier - AEGIS Safety Models -- Instruction-Data-Guard Model +- Instruction Data Guard Model - FineWeb Educational Content Classifier - Content Type Classifier -- Prompt Task/Complexity Classifier +- Prompt Task and Complexity Classifier For more information about these classifiers, please see NeMo Curator's [Distributed Data Classification documentation](https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/distributeddataclassification.html). ### Usage -#### Domain classifier inference +#### Domain Classifier Inference This classifier is recommended for English-only text data. @@ -36,7 +36,7 @@ domain_classifier_inference \ Additional arguments may be added for customizing a Dask cluster and client. Run `domain_classifier_inference --help` for more information. -#### Multilingual domain classifier inference +#### Multilingual Domain Classifier Inference This classifier supports domain classification in 52 languages. Please see [nvidia/multilingual-domain-classifier on Hugging Face](https://huggingface.co/nvidia/multilingual-domain-classifier) for more information. @@ -57,7 +57,7 @@ multilingual_domain_classifier_inference \ Additional arguments may be added for customizing a Dask cluster and client. Run `multilingual_domain_classifier_inference --help` for more information. -#### Quality classifier inference +#### Quality Classifier DeBERTa Inference ```bash # same as `python quality_classifier_inference.py` @@ -76,7 +76,7 @@ quality_classifier_inference \ Additional arguments may be added for customizing a Dask cluster and client. Run `quality_classifier_inference --help` for more information. -#### AEGIS classifier inference +#### AEGIS Classifier Inference ```bash # same as `python aegis_classifier_inference.py` @@ -99,7 +99,7 @@ aegis_classifier_inference \ Additional arguments may be added for customizing a Dask cluster and client. Run `aegis_classifier_inference --help` for more information. -#### Instruction-Data-Guard classifier inference +#### Instruction Data Guard Classifier Inference ```bash # same as `python instruction_data_guard_classifier_inference.py` @@ -120,7 +120,7 @@ In the above example, `--token` is your HuggingFace token, which is used when do Additional arguments may be added for customizing a Dask cluster and client. Run `instruction_data_guard_classifier_inference --help` for more information. 
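Each CLI entry point above wraps a classifier class in `nemo_curator.classifiers`. A hedged sketch of the programmatic AEGIS equivalent (the constructor argument names and the variant string are modeled on the CLI flags and the tutorial table later in this patch, so treat them as assumptions):

```python
from nemo_curator.classifiers import AegisClassifier
from nemo_curator.datasets import DocumentDataset

input_dataset = DocumentDataset.read_json("input_data/", backend="cudf")

# A Hugging Face token with Llama Guard access is required, as noted above.
safety_classifier = AegisClassifier(
    aegis_variant="nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0",
    token="hf_xxx",      # placeholder token
    filter_by=["safe"],  # keep only documents labeled "safe"
)
result_dataset = safety_classifier(dataset=input_dataset)
result_dataset.to_json("safe_documents/")
```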
-#### FineWeb-Edu classifier inference +#### FineWeb-Edu Classifier Inference ```bash # same as `python fineweb_edu_classifier_inference.py` @@ -139,7 +139,7 @@ fineweb_edu_classifier_inference \ Additional arguments may be added for customizing a Dask cluster and client. Run `fineweb_edu_classifier_inference --help` for more information. -#### Content type classifier inference +#### Content Type Classifier DeBERTa Inference ```bash # same as `python content_type_classifier_inference.py` @@ -158,7 +158,7 @@ content_type_classifier_inference \ Additional arguments may be added for customizing a Dask cluster and client. Run `content_type_classifier_inference --help` for more information. -#### Prompt task and complexity classifier inference +#### Prompt Task and Complexity Classifier Inference ```bash # same as `python prompt_task_complexity_classifier_inference.py` diff --git a/nemo_curator/scripts/classifiers/instruction_data_guard_classifier_inference.py b/nemo_curator/scripts/classifiers/instruction_data_guard_classifier_inference.py index 64b248872..087b669e9 100644 --- a/nemo_curator/scripts/classifiers/instruction_data_guard_classifier_inference.py +++ b/nemo_curator/scripts/classifiers/instruction_data_guard_classifier_inference.py @@ -36,7 +36,7 @@ def main(): client_args = ArgumentHelper.parse_client_args(args) client_args["cluster_type"] = "gpu" client = get_client(**client_args) - print("Starting Instruction-Data-Guard classifier inference", flush=True) + print("Starting Instruction Data Guard classifier inference", flush=True) global_st = time.time() files_per_run = len(client.scheduler_info()["workers"]) * 2 @@ -97,7 +97,7 @@ def main(): global_et = time.time() print( - f"Total time taken for Instruction-Data-Guard classifier inference: {global_et-global_st} s", + f"Total time taken for Instruction Data Guard classifier inference: {global_et-global_st} s", flush=True, ) client.close() @@ -105,7 +105,7 @@ def main(): def attach_args(): parser = ArgumentHelper.parse_distributed_classifier_args( - description="Run Instruction-Data-Guard classifier inference.", + description="Run Instruction Data Guard classifier inference.", max_chars_default=6000, ) diff --git a/tests/test_classifiers.py b/tests/test_classifiers.py index 5d681089f..1d37e7f5f 100644 --- a/tests/test_classifiers.py +++ b/tests/test_classifiers.py @@ -150,7 +150,7 @@ def test_fineweb_edu_classifier(gpu_client, domain_dataset): @pytest.mark.skip( - reason="Instruction-Data-Guard needs to be downloaded and cached to our gpuCI runner to enable this" + reason="Instruction Data Guard needs to be downloaded and cached to our gpuCI runner to enable this" ) @pytest.mark.gpu def test_instruction_data_guard_classifier(gpu_client): diff --git a/tutorials/distributed_data_classification/README.md b/tutorials/distributed_data_classification/README.md index 2b0bf51b5..f953d8f5b 100644 --- a/tutorials/distributed_data_classification/README.md +++ b/tutorials/distributed_data_classification/README.md @@ -12,7 +12,7 @@ Before running any of these notebooks, please see this [Getting Started](https:/
-| NeMo Curator Classifier | Hugging Face page | +| NeMo Curator Classifier | Hugging Face Page | | --- | --- | | `AegisClassifier` | [nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0](https://huggingface.co/nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0) and [nvidia/Aegis-AI-Content-Safety-LlamaGuard-Permissive-1.0](https://huggingface.co/nvidia/Aegis-AI-Content-Safety-LlamaGuard-Permissive-1.0) | | `ContentTypeClassifier` | [nvidia/content-type-classifier-deberta](https://huggingface.co/nvidia/content-type-classifier-deberta) | diff --git a/tutorials/distributed_data_classification/content-type-classification.ipynb b/tutorials/distributed_data_classification/content-type-classification.ipynb index 97df8485c..2a7b54235 100644 --- a/tutorials/distributed_data_classification/content-type-classification.ipynb +++ b/tutorials/distributed_data_classification/content-type-classification.ipynb @@ -6,7 +6,7 @@ "source": [ "# Distributed Data Classification with NeMo Curator's `ContentTypeClassifier`\n", "\n", - "This notebook demonstrates the use of NeMo Curator's `ContentTypeClassifier`. The [content type classifier](https://huggingface.co/nvidia/content-type-classifier-deberta) is used to categorize documents into one of 11 distinct speech types based on their content. It helps with data annotation, which is useful in data blending for foundation model training. Please refer to the Hugging Face page for more information about the content type classifier, including its output labels, here: https://huggingface.co/nvidia/content-type-classifier-deberta.\n", + "This notebook demonstrates the use of NeMo Curator's `ContentTypeClassifier`. The [content type classifier](https://huggingface.co/nvidia/content-type-classifier-deberta) is used to categorize documents into one of 11 distinct speech types based on their content. It helps with data annotation, which is useful in data blending for foundation model training. Please refer to the NemoCurator Content Type Classifier DeBERTa Hugging Face page for more information about the content type classifier, including its output labels, here: https://huggingface.co/nvidia/content-type-classifier-deberta.\n", "\n", "The content type classifier is accelerated using [CrossFit](https://github.com/rapidsai/crossfit), a library that leverages intellegent batching and RAPIDS to accelerate the offline inference on large datasets.\n", "\n", diff --git a/tutorials/distributed_data_classification/domain-classification.ipynb b/tutorials/distributed_data_classification/domain-classification.ipynb index 5a5aff14d..8c5686de9 100644 --- a/tutorials/distributed_data_classification/domain-classification.ipynb +++ b/tutorials/distributed_data_classification/domain-classification.ipynb @@ -6,7 +6,7 @@ "source": [ "# Distributed Data Classification with NeMo Curator's `DomainClassifier`\n", "\n", - "This notebook demonstrates the use of NeMo Curator's `DomainClassifier`. The [domain classifier](https://huggingface.co/nvidia/domain-classifier) is used to classify the domain of a text. It helps with data annotation, which is useful in data blending for foundation model training. Please refer to the Hugging Face page for more information about the domain classifier, including its output labels, here: https://huggingface.co/nvidia/domain-classifier.\n", + "This notebook demonstrates the use of NeMo Curator's `DomainClassifier`. The [domain classifier](https://huggingface.co/nvidia/domain-classifier) is used to classify the domain of a text. 
It helps with data annotation, which is useful in data blending for foundation model training. Please refer to the NemoCurator Domain Classifier Hugging Face page for more information about the domain classifier, including its output labels, here: https://huggingface.co/nvidia/domain-classifier.\n", "\n", "The domain classifier is accelerated using [CrossFit](https://github.com/rapidsai/crossfit), a library that leverages intellegent batching and RAPIDS to accelerate the offline inference on large datasets.\n", "\n", diff --git a/tutorials/distributed_data_classification/instruction-data-guard-classification.ipynb b/tutorials/distributed_data_classification/instruction-data-guard-classification.ipynb index 14ec962fe..5394fbe5b 100644 --- a/tutorials/distributed_data_classification/instruction-data-guard-classification.ipynb +++ b/tutorials/distributed_data_classification/instruction-data-guard-classification.ipynb @@ -6,11 +6,11 @@ "source": [ "# Distributed Data Classification with NeMo Curator's `InstructionDataGuardClassifier`\n", "\n", - "This notebook demonstrates the use of NeMo Curator's `InstructionDataGuardClassifier`. The [Instruction-Data-Guard classifier](https://huggingface.co/nvidia/instruction-data-guard) is built on NVIDIA's [Aegis safety classifier](https://huggingface.co/nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0) and is designed to detect LLM poisoning trigger attacks. Please refer to the Hugging Face page for more information about the Instruction-Data-Guard classifier here: https://huggingface.co/nvidia/instruction-data-guard.\n", + "This notebook demonstrates the use of NeMo Curator's `InstructionDataGuardClassifier`. The [Instruction Data Guard classifier](https://huggingface.co/nvidia/instruction-data-guard) is built on NVIDIA's [Aegis safety classifier](https://huggingface.co/nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0) and is designed to detect LLM poisoning trigger attacks. Please refer to the NemoCurator Instruction Data Guard Hugging Face page for more information about the Instruction Data Guard classifier here: https://huggingface.co/nvidia/instruction-data-guard.\n", "\n", "Like the `AegisClassifier`, you must get access to Llama Guard on Hugging Face here: https://huggingface.co/meta-llama/LlamaGuard-7b. Afterwards, you should set up a [user access token](https://huggingface.co/docs/hub/en/security-tokens) and pass that token into the constructor of this classifier.\n", "\n", - "The Instruction-Data-Guard classifier is accelerated using [CrossFit](https://github.com/rapidsai/crossfit), a library that leverages intellegent batching and RAPIDS to accelerate the offline inference on large datasets.\n", + "The Instruction Data Guard classifier is accelerated using [CrossFit](https://github.com/rapidsai/crossfit), a library that leverages intellegent batching and RAPIDS to accelerate the offline inference on large datasets.\n", "\n", "Before running this notebook, please see this [Getting Started](https://github.com/NVIDIA/NeMo-Curator?tab=readme-ov-file#get-started) page for instructions on how to install NeMo Curator." 
] @@ -145,7 +145,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Starting Instruction-Data-Guard classifier inference\n" + "Starting Instruction Data Guard classifier inference\n" ] }, { diff --git a/tutorials/distributed_data_classification/multilingual-domain-classification.ipynb b/tutorials/distributed_data_classification/multilingual-domain-classification.ipynb index 431dcc3f7..7a9b4e898 100644 --- a/tutorials/distributed_data_classification/multilingual-domain-classification.ipynb +++ b/tutorials/distributed_data_classification/multilingual-domain-classification.ipynb @@ -6,7 +6,7 @@ "source": [ "# Distributed Data Classification with NeMo Curator's `MultilingualDomainClassifier`\n", "\n", - "This notebook demonstrates the use of NeMo Curator's `MultilingualDomainClassifier`. The [multilingual domain classifier](https://huggingface.co/nvidia/multilingual-domain-classifier) is used to classify the domain of texts in any of 52 languages, including English. It helps with data annotation, which is useful in data blending for foundation model training. Please refer to the Hugging Face page for more information about the multilingual domain classifier, including its output labels, here: https://huggingface.co/nvidia/multilingual-domain-classifier.\n", + "This notebook demonstrates the use of NeMo Curator's `MultilingualDomainClassifier`. The [multilingual domain classifier](https://huggingface.co/nvidia/multilingual-domain-classifier) is used to classify the domain of texts in any of 52 languages, including English. It helps with data annotation, which is useful in data blending for foundation model training. Please refer to the NemoCurator Multilingual Domain Classifier Hugging Face page for more information about the multilingual domain classifier, including its output labels, here: https://huggingface.co/nvidia/multilingual-domain-classifier.\n", "\n", "The multilingual domain classifier is accelerated using [CrossFit](https://github.com/rapidsai/crossfit), a library that leverages intellegent batching and RAPIDS to accelerate the offline inference on large datasets.\n", "\n", diff --git a/tutorials/distributed_data_classification/prompt-task-complexity-classification.ipynb b/tutorials/distributed_data_classification/prompt-task-complexity-classification.ipynb index a77599aed..5e90d28c6 100644 --- a/tutorials/distributed_data_classification/prompt-task-complexity-classification.ipynb +++ b/tutorials/distributed_data_classification/prompt-task-complexity-classification.ipynb @@ -6,7 +6,7 @@ "source": [ "# Distributed Data Classification with NeMo Curator's `PromptTaskComplexityClassifier`\n", "\n", - "This notebook demonstrates the use of NeMo Curator's `PromptTaskComplexityClassifier`. The [prompt task and complexity classifier](https://huggingface.co/nvidia/prompt-task-and-complexity-classifier) a multi-headed model which classifies English text prompts across task types and complexity dimensions. It helps with data annotation, which is useful in data blending for foundation model training. Please refer to the Hugging Face page for more information about the prompt task and complexity classifier, including its output labels, here: https://huggingface.co/nvidia/prompt-task-and-complexity-classifier.\n", + "This notebook demonstrates the use of NeMo Curator's `PromptTaskComplexityClassifier`. 
The [prompt task and complexity classifier](https://huggingface.co/nvidia/prompt-task-and-complexity-classifier) is a multi-headed model which classifies English text prompts across task types and complexity dimensions. It helps with data annotation, which is useful in data blending for foundation model training. Please refer to the NemoCurator Prompt Task and Complexity Classifier Hugging Face page for more information about the prompt task and complexity classifier, including its output labels, here: https://huggingface.co/nvidia/prompt-task-and-complexity-classifier.\n",
 "\n",
 "The prompt task and complexity classifier is accelerated using [CrossFit](https://github.com/rapidsai/crossfit), a library that leverages intellegent batching and RAPIDS to accelerate the offline inference on large datasets.\n",
 "\n",
diff --git a/tutorials/distributed_data_classification/quality-classification.ipynb b/tutorials/distributed_data_classification/quality-classification.ipynb
index c54376539..79b1fdb96 100644
--- a/tutorials/distributed_data_classification/quality-classification.ipynb
+++ b/tutorials/distributed_data_classification/quality-classification.ipynb
@@ -6,7 +6,7 @@
 "source": [
 "# Distributed Data Classification with NeMo Curator's `QualityClassifier`\n",
 "\n",
- "This notebook demonstrates the use of NeMo Curator's `QualityClassifier`. The [quality classifier](https://huggingface.co/nvidia/quality-classifier-deberta) is used to classify text as high, medium, or low quality. This helps with data annotation, which is useful in data blending for foundation model training. Please refer to the Hugging Face page for more information about the quality classifier, including its output labels, here: https://huggingface.co/nvidia/quality-classifier-deberta.\n",
+ "This notebook demonstrates the use of NeMo Curator's `QualityClassifier`. The [quality classifier](https://huggingface.co/nvidia/quality-classifier-deberta) is used to classify text as high, medium, or low quality. This helps with data annotation, which is useful in data blending for foundation model training. Please refer to the NemoCurator Quality Classifier DeBERTa Hugging Face page for more information about the quality classifier, including its output labels, here: https://huggingface.co/nvidia/quality-classifier-deberta.\n",
 "\n",
 "The quality classifier is accelerated using [CrossFit](https://github.com/rapidsai/crossfit), a library that leverages intellegent batching and RAPIDS to accelerate the offline inference on large datasets.\n",
 "\n",
@@ -186,7 +186,7 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
- "Starting Quality classifier inference\n",
+ "Starting quality classifier inference\n",
 "Writing to disk complete for 1 partition(s)\n",
 "CPU times: user 2.84 s, sys: 1.2 s, total: 4.04 s\n",
 "Wall time: 19.8 s\n"