From 2daedd75463a648181d90857d72b06eda1f6411e Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Wed, 5 Feb 2025 17:32:20 -0800 Subject: [PATCH 01/22] Add document splitter and joiner Signed-off-by: Ryan Wolf --- nemo_curator/modules/__init__.py | 3 + nemo_curator/modules/splitter.py | 123 +++++++++++++++++++++++++++++++ 2 files changed, 126 insertions(+) create mode 100644 nemo_curator/modules/splitter.py diff --git a/nemo_curator/modules/__init__.py b/nemo_curator/modules/__init__.py index 897e54025..6dc285245 100644 --- a/nemo_curator/modules/__init__.py +++ b/nemo_curator/modules/__init__.py @@ -27,6 +27,7 @@ from .exact_dedup import ExactDuplicates from .meta import Sequential from .modify import Modify +from .splitter import DocumentSplitter, DocumentJoiner from .task import TaskDecontamination # GPU packages @@ -88,4 +89,6 @@ "ClusteringModel", "SemanticClusterLevelDedup", "SemDedup", + "DocumentSplitter", + "DocumentJoiner", ] diff --git a/nemo_curator/modules/splitter.py b/nemo_curator/modules/splitter.py new file mode 100644 index 000000000..d712f7d4a --- /dev/null +++ b/nemo_curator/modules/splitter.py @@ -0,0 +1,123 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pandas as pd + +from nemo_curator.datasets import DocumentDataset + + +class DocumentSplitter: + """ + Splits documents into segments based on a separator. + Each segment is a new document with an additional column + indicating the segment id. + + To restore the original document, ensure that each document + has a unique id prior to splitting. + """ + + def __init__( + self, + separator: str, + text_field: str = "text", + segment_id_field: str = "segment_id", + ): + """ + Args: + separator (str): The separator to split the documents on. + text_field (str): The name of the column containing the text to split. + segment_id_field (str): The name of the column to add to indicate the segment id. + """ + self.separator = separator + self.text_field = text_field + self.segment_id_field = segment_id_field + + def _split_partition(self, df: pd.DataFrame) -> pd.DataFrame: + # Work on a copy to avoid modifying the original dataframe in place. + df = df.copy() + # Split the text field into segments using the separator. + df["split_text"] = df[self.text_field].str.split(self.separator) + # Explode the list so that each segment becomes a separate row. + df = df.explode("split_text") + # For each original document (grouped by the original index), assign a segment id. + df[self.segment_id_field] = df.groupby(level=0).cumcount() + # Replace the original text field with the split segment. + df[self.text_field] = df["split_text"] + # Drop the temporary column. + df = df.drop(columns="split_text") + return df + + def __call__(self, dataset: DocumentDataset) -> DocumentDataset: + """ + Splits the documents into segments based on the separator and + adds a column indicating the segment id. + """ + + # Construct meta information for the transformed dataframe. 
+ meta = dataset.df._meta.copy() + if self.segment_id_field not in meta.columns: + meta[self.segment_id_field] = pd.Series(dtype="int64") + + # Apply the partition-wise splitting transformation using Dask's map_partitions. + dataset.df = dataset.df.map_partitions(self._split_partition, meta=meta) + return dataset + + +class DocumentJoiner: + """ + Joins documents that have a common id back into a single document. + The order of the documents is dictated by an additional segment_id column. + + The joined documents are joined by a separator. + """ + + def __init__( + self, + separator: str, + text_field: str = "text", + segment_id_field: str = "segment_id", + ): + """ + Args: + separator (str): The separator to join the documents on. + text_field (str): The name of the column containing the text to join. + segment_id_field (str): The name of the column containing the segment id. + """ + self.separator = separator + self.text_field = text_field + self.segment_id_field = segment_id_field + + def _join_partition(self, df: pd.DataFrame) -> pd.DataFrame: + if df.empty: + return df + # Sort the segments so that they are in the correct order. + df_sorted = df.sort_values(self.segment_id_field) + # Group by the original document index (level 0) and join the segments using the separator. + joined = df_sorted.groupby(level=0)[self.text_field].apply( + lambda texts: self.separator.join(texts) + ) + # Convert the result back to a DataFrame. + return joined.to_frame(name=self.text_field) + + def __call__(self, dataset: DocumentDataset) -> DocumentDataset: + """ + Joins the documents back into a single document. + """ + # Construct meta information for the transformed dataframe. + meta = dataset.df._meta.copy() + if self.text_field not in meta.columns: + meta[self.text_field] = pd.Series(dtype="object") + + # Apply the join operation partition-wise. + dataset.df = dataset.df.map_partitions(self._join_partition, meta=meta) + return dataset From cbf0ab07b52953b53130f4c8efdd0b306e1615cc Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Wed, 5 Feb 2025 17:55:48 -0800 Subject: [PATCH 02/22] Add support for id field in joiner Signed-off-by: Ryan Wolf --- nemo_curator/modules/splitter.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/nemo_curator/modules/splitter.py b/nemo_curator/modules/splitter.py index d712f7d4a..3713c84aa 100644 --- a/nemo_curator/modules/splitter.py +++ b/nemo_curator/modules/splitter.py @@ -86,28 +86,31 @@ def __init__( separator: str, text_field: str = "text", segment_id_field: str = "segment_id", + document_id_field: str = "id", ): """ Args: separator (str): The separator to join the documents on. text_field (str): The name of the column containing the text to join. segment_id_field (str): The name of the column containing the segment id. + document_id_field (str): The name of the column containing the document id. """ self.separator = separator self.text_field = text_field self.segment_id_field = segment_id_field + self.document_id_field = document_id_field def _join_partition(self, df: pd.DataFrame) -> pd.DataFrame: if df.empty: return df - # Sort the segments so that they are in the correct order. + # Sort the segments by the segment_id_field to maintain the proper order. df_sorted = df.sort_values(self.segment_id_field) - # Group by the original document index (level 0) and join the segments using the separator. - joined = df_sorted.groupby(level=0)[self.text_field].apply( + # Group by the document_id_field and join the segments using the separator. 
+ joined = df_sorted.groupby(self.document_id_field)[self.text_field].apply( lambda texts: self.separator.join(texts) ) - # Convert the result back to a DataFrame. - return joined.to_frame(name=self.text_field) + # Convert the joined result back to a DataFrame and reset the index to include document_id_field. + return joined.to_frame(name=self.text_field).reset_index() def __call__(self, dataset: DocumentDataset) -> DocumentDataset: """ From 56de7c64f9f9eab7228ece626a13dfeb2fa958f5 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Wed, 5 Feb 2025 18:11:47 -0800 Subject: [PATCH 03/22] Fix splitter and joiner Signed-off-by: Ryan Wolf --- nemo_curator/modules/splitter.py | 45 +++++++++++++++++++++++++------- 1 file changed, 35 insertions(+), 10 deletions(-) diff --git a/nemo_curator/modules/splitter.py b/nemo_curator/modules/splitter.py index 3713c84aa..67ccb1d2d 100644 --- a/nemo_curator/modules/splitter.py +++ b/nemo_curator/modules/splitter.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import List + import pandas as pd from nemo_curator.datasets import DocumentDataset @@ -87,6 +89,7 @@ def __init__( text_field: str = "text", segment_id_field: str = "segment_id", document_id_field: str = "id", + drop_segment_id_field: bool = True, ): """ Args: @@ -94,33 +97,55 @@ def __init__( text_field (str): The name of the column containing the text to join. segment_id_field (str): The name of the column containing the segment id. document_id_field (str): The name of the column containing the document id. + drop_segment_id_field (bool): Whether to drop the segment_id_field after joining. """ self.separator = separator self.text_field = text_field self.segment_id_field = segment_id_field self.document_id_field = document_id_field + self.drop_segment_id_field = drop_segment_id_field - def _join_partition(self, df: pd.DataFrame) -> pd.DataFrame: + def _join_partition( + self, df: pd.DataFrame, expected_cols: List[str] + ) -> pd.DataFrame: if df.empty: return df - # Sort the segments by the segment_id_field to maintain the proper order. + # Sort the segments by the segment_id_field to maintain proper order before aggregating. df_sorted = df.sort_values(self.segment_id_field) - # Group by the document_id_field and join the segments using the separator. - joined = df_sorted.groupby(self.document_id_field)[self.text_field].apply( - lambda texts: self.separator.join(texts) + # Build aggregation functions to preserve all original columns: + # - For self.text_field, join all segments using the separator. + # - For all other columns (except self.document_id_field, which is our grouping key), take the first occurrence. + agg_funcs = {} + for col in df_sorted.columns: + if col == self.text_field: + agg_funcs[col] = lambda texts: self.separator.join(texts.astype(str)) + elif col != self.document_id_field: + agg_funcs[col] = "first" + # Group by document_id_field while keeping the key as a column. + joined = df_sorted.groupby(self.document_id_field, as_index=False).agg( + agg_funcs ) - # Convert the joined result back to a DataFrame and reset the index to include document_id_field. - return joined.to_frame(name=self.text_field).reset_index() + + if self.drop_segment_id_field: + joined = joined.drop(columns=self.segment_id_field) + # Reorder the columns to match the expected metadata order. 
+ joined = joined[expected_cols] + return joined def __call__(self, dataset: DocumentDataset) -> DocumentDataset: """ - Joins the documents back into a single document. + Joins the documents back into a single document while preserving all the original fields. """ # Construct meta information for the transformed dataframe. meta = dataset.df._meta.copy() if self.text_field not in meta.columns: meta[self.text_field] = pd.Series(dtype="object") - + # If dropping the segment id field, remove it from the metadata to prevent mismatches. + if self.drop_segment_id_field: + meta = meta.drop(columns=self.segment_id_field) + expected_cols = list(meta.columns) # Apply the join operation partition-wise. - dataset.df = dataset.df.map_partitions(self._join_partition, meta=meta) + dataset.df = dataset.df.map_partitions( + self._join_partition, expected_cols=expected_cols, meta=meta + ) return dataset From 6729a68ff48d4931b0e0f8fc1238ecfd40ca0472 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Wed, 5 Feb 2025 18:22:00 -0800 Subject: [PATCH 04/22] Add token count filter Signed-off-by: Ryan Wolf --- nemo_curator/filters/__init__.py | 2 ++ nemo_curator/filters/heuristic_filter.py | 29 ++++++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/nemo_curator/filters/__init__.py b/nemo_curator/filters/__init__.py index 9905c8370..935e1f993 100644 --- a/nemo_curator/filters/__init__.py +++ b/nemo_curator/filters/__init__.py @@ -50,6 +50,7 @@ RepeatingDuplicateNGramsFilter, RepeatingTopNGramsFilter, SymbolsToWordsFilter, + TokenCountFilter, UrlsFilter, WhiteSpaceFilter, WordCountFilter, @@ -98,4 +99,5 @@ "QualityEstimationFilter", "AnswerabilityFilter", "EasinessFilter", + "TokenCountFilter", ] diff --git a/nemo_curator/filters/heuristic_filter.py b/nemo_curator/filters/heuristic_filter.py index c17e4e9a3..24759ba1c 100644 --- a/nemo_curator/filters/heuristic_filter.py +++ b/nemo_curator/filters/heuristic_filter.py @@ -17,6 +17,7 @@ import requests from platformdirs import user_cache_dir +from transformers import AutoTokenizer from nemo_curator.filters.bitext_filter import BitextFilter from nemo_curator.filters.doc_filter import DocumentFilter, import_filter @@ -671,6 +672,34 @@ def keep_document(self, score): return score != 1 +class TokenCountFilter(DocumentFilter): + """ + If the document contains more or less than a specified number of tokens, then discard. + """ + + def __init__(self, tokenizer: AutoTokenizer, min_tokens=10, max_tokens=100000): + """ + Args: + tokenizer (AutoTokenizer): The tokenizer to use to count the tokens. + min_tokens (int): The minimum number of tokens the document must contain. + Set to 0 to disable the minimum token count filter. + max_tokens (int): The maximum number of tokens the document can contain. + Set to infinity to disable the maximum token count filter. + """ + super().__init__() + self._tokenizer = tokenizer + self._min_tokens = min_tokens + self._max_tokens = max_tokens + self._name = "token_count" + + def score_document(self, text): + tokens = self._tokenizer.encode(text) + return len(tokens) + + def keep_document(self, score): + return self._min_tokens <= score <= self._max_tokens + + class HistogramFilter(DocumentFilter): """Histogram filter used by the NLLB paper (https://arxiv.org/pdf/2207.04672). See p30 for details. 
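
The TokenCountFilter added in the patch above is self-contained enough to exercise directly. The sketch below is illustrative only and is not part of the patch series; the "gpt2" tokenizer and the 50/1000 thresholds are arbitrary example choices.

from transformers import AutoTokenizer
from nemo_curator.filters import TokenCountFilter

# Any Hugging Face tokenizer works here; "gpt2" is only an example choice.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Keep only documents whose token count falls within [50, 1000].
token_filter = TokenCountFilter(tokenizer, min_tokens=50, max_tokens=1000)

score = token_filter.score_document("Some document text to be measured ...")  # number of tokens
keep = token_filter.keep_document(score)  # True only if 50 <= score <= 1000

In a full pipeline this filter would typically be wrapped in the library's ScoreFilter module and applied to a DocumentDataset, but the direct calls above are enough to show how the score is computed and thresholded.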
From fe37942b6b2f1029f77aed316b1c33c1f6de1dba Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Thu, 6 Feb 2025 09:26:18 -0800 Subject: [PATCH 05/22] Add postprocessing steps for nemotron cc sdg Signed-off-by: Ryan Wolf --- nemo_curator/filters/__init__.py | 2 + nemo_curator/filters/heuristic_filter.py | 33 +++++++++ nemo_curator/modifiers/__init__.py | 8 ++ nemo_curator/modifiers/line_remover.py | 36 +++++++++ nemo_curator/modifiers/markdown_remover.py | 43 +++++++++++ nemo_curator/modifiers/quotation_remover.py | 38 ++++++++++ nemo_curator/modifiers/slicer.py | 81 +++++++++++++++++++++ 7 files changed, 241 insertions(+) create mode 100644 nemo_curator/modifiers/line_remover.py create mode 100644 nemo_curator/modifiers/markdown_remover.py create mode 100644 nemo_curator/modifiers/quotation_remover.py create mode 100644 nemo_curator/modifiers/slicer.py diff --git a/nemo_curator/filters/__init__.py b/nemo_curator/filters/__init__.py index 935e1f993..cda290fbb 100644 --- a/nemo_curator/filters/__init__.py +++ b/nemo_curator/filters/__init__.py @@ -49,6 +49,7 @@ RepeatedParagraphsFilter, RepeatingDuplicateNGramsFilter, RepeatingTopNGramsFilter, + SubstringFilter, SymbolsToWordsFilter, TokenCountFilter, UrlsFilter, @@ -100,4 +101,5 @@ "AnswerabilityFilter", "EasinessFilter", "TokenCountFilter", + "SubstringFilter", ] diff --git a/nemo_curator/filters/heuristic_filter.py b/nemo_curator/filters/heuristic_filter.py index 24759ba1c..938906302 100644 --- a/nemo_curator/filters/heuristic_filter.py +++ b/nemo_curator/filters/heuristic_filter.py @@ -14,6 +14,7 @@ import os.path import tarfile +from typing import Literal import requests from platformdirs import user_cache_dir @@ -700,6 +701,38 @@ def keep_document(self, score): return self._min_tokens <= score <= self._max_tokens +class SubstringFilter(DocumentFilter): + """ + Keeps documents that contain a substring in a given position. + Gives a score of 1 if the substring is found in the given position, otherwise 0. + """ + + def __init__(self, substring: str, position: Literal["prefix", "suffix", "any"]): + """ + Args: + substring (str): The substring to check for. + position (Literal["prefix", "suffix", "any"]): The position of the substring. + """ + super().__init__() + self._substring = substring + if position not in ["prefix", "suffix", "any"]: + raise ValueError( + f"Invalid position: {position}. Must be one of: prefix, suffix, any." + ) + self._position = position + + def score_document(self, text: str) -> int: + if self._position == "prefix": + return int(text.startswith(self._substring)) + elif self._position == "suffix": + return int(text.endswith(self._substring)) + elif self._position == "any": + return int(self._substring in text) + + def keep_document(self, score: int) -> bool: + return score == 1 + + class HistogramFilter(DocumentFilter): """Histogram filter used by the NLLB paper (https://arxiv.org/pdf/2207.04672). See p30 for details. 
diff --git a/nemo_curator/modifiers/__init__.py b/nemo_curator/modifiers/__init__.py index f6511fdb0..bd9cdd27e 100644 --- a/nemo_curator/modifiers/__init__.py +++ b/nemo_curator/modifiers/__init__.py @@ -15,7 +15,11 @@ from .c4 import BoilerPlateStringModifier from .doc_modifier import DocumentModifier from .fasttext import FastTextLabelModifier +from .line_remover import LineRemover +from .markdown_remover import MarkdownRemover from .pii_modifier import PiiModifier +from .quotation_remover import QuotationRemover +from .slicer import Slicer from .unicode_reformatter import UnicodeReformatter __all__ = [ @@ -23,5 +27,9 @@ "BoilerPlateStringModifier", "FastTextLabelModifier", "UnicodeReformatter", + "QuotationRemover", + "LineRemover", + "MarkdownRemover", "PiiModifier", + "Slicer", ] diff --git a/nemo_curator/modifiers/line_remover.py b/nemo_curator/modifiers/line_remover.py new file mode 100644 index 000000000..eab763ad6 --- /dev/null +++ b/nemo_curator/modifiers/line_remover.py @@ -0,0 +1,36 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +from nemo_curator.modifiers import DocumentModifier + + +class LineRemover(DocumentModifier): + """ + Removes lines from a document if the content of the line matches a given string. + """ + + def __init__(self, patterns: List[str]): + """ + Args: + patterns (List[str]): The patterns to check + """ + super().__init__() + self._patterns = patterns + + def modify_document(self, text: str) -> str: + lines = text.split("\n") + new_lines = [line for line in lines if line not in self._patterns] + return "\n".join(new_lines) diff --git a/nemo_curator/modifiers/markdown_remover.py b/nemo_curator/modifiers/markdown_remover.py new file mode 100644 index 000000000..cda29cd5f --- /dev/null +++ b/nemo_curator/modifiers/markdown_remover.py @@ -0,0 +1,43 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +from nemo_curator.modifiers import DocumentModifier + +MARKDOWN_BOLD_REGEX = r"\*\*(.*?)\*\*" +MARKDOWN_ITALIC_REGEX = r"\*(.*?)\*" +MARKDOWN_UNDERLINE_REGEX = r"_(.*?)_" +MARKDOWN_LINK_REGEX = r"\[.*?\]\((.*?)\)" + + +class MarkdownRemover(DocumentModifier): + """ + Removes Markdown formatting in a document including bold, italic, and URL text. 
+ """ + + def __init__(self): + super().__init__() + + def modify_document(self, text: str) -> str: + lines = text.split("\n") + new_lines = [] + for line in lines: + line = re.sub(MARKDOWN_BOLD_REGEX, r"\1", line) # **text** + line = re.sub(MARKDOWN_ITALIC_REGEX, r"\1", line) # *text* + line = re.sub(MARKDOWN_UNDERLINE_REGEX, r"\1", line) # _text_ + line = re.sub(MARKDOWN_LINK_REGEX, r"\1", line) # [text](url) + new_lines.append(line) + + return "\n".join(new_lines) diff --git a/nemo_curator/modifiers/quotation_remover.py b/nemo_curator/modifiers/quotation_remover.py new file mode 100644 index 000000000..3e36dfbcd --- /dev/null +++ b/nemo_curator/modifiers/quotation_remover.py @@ -0,0 +1,38 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo_curator.modifiers import DocumentModifier + + +class QuotationRemover(DocumentModifier): + """ + Removes quotations from a document following a few rules: + - If the document is less than 2 characters, it is returned unchanged. + - If the document starts and ends with a quotation mark and there are + no newlines in the document, the quotation marks are removed. + - If the document starts and ends with a quotation mark and there are + newlines in the document, the quotation marks are removed only if + the first line does not end with a quotation mark. + """ + + def __init__(self): + super().__init__() + + def modify_document(self, text: str) -> str: + if len(text.strip()) > 2 and text[0] == '"' and text[-1] == '"': + if "\n" not in text.strip(): + text = text[1:-1] + elif "\n" in text.strip() and text.split("\n")[0][-1] != '"': + text = text[1:-1] + return text diff --git a/nemo_curator/modifiers/slicer.py b/nemo_curator/modifiers/slicer.py new file mode 100644 index 000000000..9d8843b35 --- /dev/null +++ b/nemo_curator/modifiers/slicer.py @@ -0,0 +1,81 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Union + +from nemo_curator.modifiers import DocumentModifier + + +class Slicer(DocumentModifier): + """ + Slices a document based on indices or strings + """ + + def __init__( + self, + left: Union[int, str], + right: Union[int, str], + include_left: bool = True, + include_right: bool = True, + strip: bool = True, + ): + """ + Args: + left (Union[int, str]): If the provided value is an int, slice the string from this index (inclusive). 
If the provided value is a str, slice the string from the first occurence of this substring. + right (Union[int, str]): If the provided value is an int, slice the string to this index (exclusive). If the provided value is a str, slice the string to the last occurence of this substring. + include_left (bool): Only used if `left` is a string. If True, the value of `left` is included in the slicing result. Defaults to False. + include_right (bool): Only used if `right` is a string. If True, the value of `right` is included in the slicing result. Defaults to False. + strip (bool): If True, strip the resulting string. + """ + super().__init__() + self._left = left + self._right = right + self._include_left = include_left + self._include_right = include_right + self._strip = strip + + def modify_document(self, text: str) -> str: + # Determine start index based on left type + if isinstance(self._left, int): + left_index = self._left + elif isinstance(self._left, str): + left_index_found = text.find(self._left) + if left_index_found == -1: + return "" + left_index = ( + left_index_found + if self._include_left + else left_index_found + len(self._left) + ) + else: + left_index = 0 # default if neither int nor str + + # Determine end index based on right type + if isinstance(self._right, int): + right_index = self._right + elif isinstance(self._right, str): + right_index_found = text.rfind(self._right) + if right_index_found == -1: + return "" + right_index = ( + right_index_found + len(self._right) + if self._include_right + else right_index_found + ) + else: + right_index = len(text) # default if neither int nor str + + result = text[left_index:right_index] + if self._strip: + result = result.strip() + return result From 95c939b8168b5abfb9b6815ef09936445af8dc12 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Thu, 6 Feb 2025 09:38:53 -0800 Subject: [PATCH 06/22] Make left and right bounds optional Signed-off-by: Ryan Wolf --- nemo_curator/modifiers/slicer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/nemo_curator/modifiers/slicer.py b/nemo_curator/modifiers/slicer.py index 9d8843b35..9ce68a6e6 100644 --- a/nemo_curator/modifiers/slicer.py +++ b/nemo_curator/modifiers/slicer.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union +from typing import Optional, Union from nemo_curator.modifiers import DocumentModifier @@ -23,16 +23,16 @@ class Slicer(DocumentModifier): def __init__( self, - left: Union[int, str], - right: Union[int, str], + left: Optional[Union[int, str]] = None, + right: Optional[Union[int, str]] = None, include_left: bool = True, include_right: bool = True, strip: bool = True, ): """ Args: - left (Union[int, str]): If the provided value is an int, slice the string from this index (inclusive). If the provided value is a str, slice the string from the first occurence of this substring. - right (Union[int, str]): If the provided value is an int, slice the string to this index (exclusive). If the provided value is a str, slice the string to the last occurence of this substring. + left (Union[int, str], optional): If the provided value is an int, slice the string from this index (inclusive). If the provided value is a str, slice the string from the first occurence of this substring. 
+ right (Union[int, str], optional): If the provided value is an int, slice the string to this index (exclusive). If the provided value is a str, slice the string to the last occurence of this substring. include_left (bool): Only used if `left` is a string. If True, the value of `left` is included in the slicing result. Defaults to False. include_right (bool): Only used if `right` is a string. If True, the value of `right` is included in the slicing result. Defaults to False. strip (bool): If True, strip the resulting string. From 8146cc2166ec314d6afa3785952c85961bc695f9 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Thu, 6 Feb 2025 13:19:22 -0800 Subject: [PATCH 07/22] Add wikipedia rephrasing pipeline Signed-off-by: Ryan Wolf --- nemo_curator/filters/heuristic_filter.py | 2 +- nemo_curator/modules/splitter.py | 104 +++++++++++++++--- .../services/huggingface_formatter.py | 33 ++++++ nemo_curator/synthetic/__init__.py | 2 + nemo_curator/synthetic/nemotron_cc.py | 76 +++++++++++++ nemo_curator/synthetic/prompts.py | 9 ++ 6 files changed, 209 insertions(+), 17 deletions(-) create mode 100644 nemo_curator/services/huggingface_formatter.py create mode 100644 nemo_curator/synthetic/nemotron_cc.py diff --git a/nemo_curator/filters/heuristic_filter.py b/nemo_curator/filters/heuristic_filter.py index 938906302..182a76672 100644 --- a/nemo_curator/filters/heuristic_filter.py +++ b/nemo_curator/filters/heuristic_filter.py @@ -678,7 +678,7 @@ class TokenCountFilter(DocumentFilter): If the document contains more or less than a specified number of tokens, then discard. """ - def __init__(self, tokenizer: AutoTokenizer, min_tokens=10, max_tokens=100000): + def __init__(self, tokenizer: AutoTokenizer, min_tokens=0, max_tokens=float("inf")): """ Args: tokenizer (AutoTokenizer): The tokenizer to use to count the tokens. diff --git a/nemo_curator/modules/splitter.py b/nemo_curator/modules/splitter.py index 67ccb1d2d..71a260629 100644 --- a/nemo_curator/modules/splitter.py +++ b/nemo_curator/modules/splitter.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from typing import List, Optional import pandas as pd @@ -79,6 +79,7 @@ class DocumentJoiner: """ Joins documents that have a common id back into a single document. The order of the documents is dictated by an additional segment_id column. + A maximum length can be specified to limit the size of the joined documents. The joined documents are joined by a separator. """ @@ -90,6 +91,8 @@ def __init__( segment_id_field: str = "segment_id", document_id_field: str = "id", drop_segment_id_field: bool = True, + max_length: Optional[int] = None, + length_field: Optional[str] = None, ): """ Args: @@ -98,33 +101,102 @@ def __init__( segment_id_field (str): The name of the column containing the segment id. document_id_field (str): The name of the column containing the document id. drop_segment_id_field (bool): Whether to drop the segment_id_field after joining. + max_length (int, optional): The maximum length of the joined documents. + Both max_length and length_field must be specified or neither can be specified. + length_field (str, optional): The name of the column containing the length of the documents. + Both max_length and length_field must be specified or neither can be specified. 
""" + if max_length is not None and length_field is None: + raise ValueError("max_length is specified but length_field is not") + if max_length is None and length_field is not None: + raise ValueError("length_field is specified but max_length is not") + self.separator = separator self.text_field = text_field self.segment_id_field = segment_id_field self.document_id_field = document_id_field self.drop_segment_id_field = drop_segment_id_field + self.max_length = max_length + self.length_field = length_field + + def _join_segments(self, group): + # Ensure segments are processed in order. + group = group.sort_values(self.segment_id_field) + joined_rows = [] + current_seg_id = 0 + accumulator_text = None + accumulator_length = 0 + accumulator_row = None + + for _, row in group.iterrows(): + if accumulator_row is None: + # Start a new accumulation with the first segment. + accumulator_text = row[self.text_field] + accumulator_length = row[self.length_field] + accumulator_row = row + else: + # Calculate what the new length would be if we joined this segment. + proposed_length = accumulator_length + row[self.length_field] + 1 + if proposed_length <= self.max_length: + accumulator_text = ( + accumulator_text + self.separator + row[self.text_field] + ) + accumulator_length = proposed_length + else: + # Commit the current accumulation as one joined segment. + new_row = accumulator_row.copy() + new_row[self.text_field] = accumulator_text + new_row[self.length_field] = accumulator_length + new_row[self.segment_id_field] = current_seg_id + joined_rows.append(new_row) + current_seg_id += 1 + # Start a new accumulation with the current row. + accumulator_text = row[self.text_field] + accumulator_length = row[self.length_field] + accumulator_row = row + + # Commit the last accumulated segment. + if accumulator_row is not None: + new_row = accumulator_row.copy() + new_row[self.text_field] = accumulator_text + new_row[self.length_field] = accumulator_length + new_row[self.segment_id_field] = current_seg_id + joined_rows.append(new_row) + if joined_rows: + return pd.concat( + [group.iloc[0:0], pd.DataFrame(joined_rows)], ignore_index=True + ) + else: + return group.iloc[0:0] def _join_partition( self, df: pd.DataFrame, expected_cols: List[str] ) -> pd.DataFrame: if df.empty: return df - # Sort the segments by the segment_id_field to maintain proper order before aggregating. - df_sorted = df.sort_values(self.segment_id_field) - # Build aggregation functions to preserve all original columns: - # - For self.text_field, join all segments using the separator. - # - For all other columns (except self.document_id_field, which is our grouping key), take the first occurrence. - agg_funcs = {} - for col in df_sorted.columns: - if col == self.text_field: - agg_funcs[col] = lambda texts: self.separator.join(texts.astype(str)) - elif col != self.document_id_field: - agg_funcs[col] = "first" - # Group by document_id_field while keeping the key as a column. - joined = df_sorted.groupby(self.document_id_field, as_index=False).agg( - agg_funcs - ) + + if self.max_length is None: + # Sort the segments by the segment_id_field to maintain proper order before aggregating. + df_sorted = df.sort_values(self.segment_id_field) + # Build aggregation functions to preserve all original columns: + # - For self.text_field, join all segments using the separator. + # - For all other columns (except self.document_id_field, which is our grouping key), take the first occurrence. 
+ agg_funcs = {} + for col in df_sorted.columns: + if col == self.text_field: + agg_funcs[col] = lambda texts: self.separator.join( + texts.astype(str) + ) + elif col != self.document_id_field: + agg_funcs[col] = "first" + # Group by document_id_field while keeping the key as a column. + joined = df_sorted.groupby(self.document_id_field, as_index=False).agg( + agg_funcs + ) + else: + joined = df.groupby(self.document_id_field, group_keys=False).apply( + self._join_segments + ) if self.drop_segment_id_field: joined = joined.drop(columns=self.segment_id_field) diff --git a/nemo_curator/services/huggingface_formatter.py b/nemo_curator/services/huggingface_formatter.py new file mode 100644 index 000000000..71967a536 --- /dev/null +++ b/nemo_curator/services/huggingface_formatter.py @@ -0,0 +1,33 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List + +from transformers import AutoTokenizer + +from nemo_curator.services import ConversationFormatter + + +class HuggingFaceFormatter(ConversationFormatter): + """ + A formatter that uses a Hugging Face tokenizer to format a conversation. + """ + + def __init__(self, tokenizer: AutoTokenizer) -> None: + self.tokenizer = tokenizer + + def format_conversation(self, conversation: List[dict]) -> str: + """ + Format a conversation between a user, assistant, and potentially system into a string. + """ + return self.tokenizer.apply_chat_template(conversation) diff --git a/nemo_curator/synthetic/__init__.py b/nemo_curator/synthetic/__init__.py index 44a4b6c12..bef5123e0 100644 --- a/nemo_curator/synthetic/__init__.py +++ b/nemo_curator/synthetic/__init__.py @@ -15,6 +15,7 @@ from .error import YamlConversionError from .mixtral import Mixtral8x7BFormatter from .nemotron import NemotronFormatter, NemotronGenerator +from .nemotron_cc import NemotronCC from .no_format import NoFormat from .prompts import ( DEFAULT_CLOSED_QA_PROMPT_TEMPLATE, @@ -45,6 +46,7 @@ "NemotronGenerator", "AsyncNemotronGenerator", "NemotronFormatter", + "NemotronCC", "Mixtral8x7BFormatter", "NoFormat", "DEFAULT_MACRO_TOPICS_PROMPT_TEMPLATE", diff --git a/nemo_curator/synthetic/nemotron_cc.py b/nemo_curator/synthetic/nemotron_cc.py new file mode 100644 index 000000000..5d3273fff --- /dev/null +++ b/nemo_curator/synthetic/nemotron_cc.py @@ -0,0 +1,76 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import List, Optional + +from transformers import AutoTokenizer + +from nemo_curator.services import LLMClient +from nemo_curator.synthetic.prompts import ( + NEMOTRON_CC_SYSTEM_PROMPT, + WIKIPEDIA_REPHRASING_PROMPT_TEMPLATE, +) + + +class NemotronCC: + """ + Provides a collection of methods for generating synthetic data + described in the Nemotron-CC paper (https://arxiv.org/abs/2412.02595). + """ + + def __init__(self, llm_client: LLMClient, tokenizer: AutoTokenizer) -> None: + self.client = llm_client + self.tokenizer = tokenizer + + def _prompt( + self, + model: str, + prompt_template: str, + system_prompt: str, + prompt_kwargs: dict, + model_kwargs: dict, + ) -> List[str]: + prompt = prompt_template.format(**prompt_kwargs) + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt}, + ] + + return self.client.query_model(messages=messages, model=model, **model_kwargs) + + def get_wikipedia_prefix_token_count(self) -> int: + user_prompt = WIKIPEDIA_REPHRASING_PROMPT_TEMPLATE.format( + **{"document": "placeholder"} + ) + messages = [ + {"role": "system", "content": NEMOTRON_CC_SYSTEM_PROMPT}, + {"role": "user", "content": user_prompt}, + ] + prefix = self.tokenizer.apply_chat_template(messages) + + return len(self.tokenizer.encode(prefix)) + + def rewrite_to_wikipedia_style( + self, + document: str, + model: str, + prompt_template: str = WIKIPEDIA_REPHRASING_PROMPT_TEMPLATE, + system_prompt: str = NEMOTRON_CC_SYSTEM_PROMPT, + prompt_kwargs: dict = {}, + model_kwargs: dict = {}, + ) -> str: + prompt_kwargs["document"] = document + return self._prompt( + model, prompt_template, system_prompt, prompt_kwargs, model_kwargs + ) diff --git a/nemo_curator/synthetic/prompts.py b/nemo_curator/synthetic/prompts.py index fbe7e026a..8c034673b 100644 --- a/nemo_curator/synthetic/prompts.py +++ b/nemo_curator/synthetic/prompts.py @@ -56,3 +56,12 @@ DIALOGUE_COMPLEX_USER_TURN_PROMPT_TEMPLATE = "Here is a conversation between a user and an assistant.\n<|The Start of Assistant's Conversation with User|>\n{conversation_history}\n<|The End of Assistant's Conversation with User|>\n\nGiven the conversation above, generate a followup request or question in the tone of User. Make sure the question is complex and diverse enough and suitable as a followup question. Directly give me the question without extraneous words." DIALOGUE_CONCISE_USER_TURN_PROMPT_TEMPLATE = "Here is a conversation between a user and an assistant.\n<|The Start of Assistant's Conversation with User|>\n{conversation_history}\n<|The End of Assistant's Conversation with User|>\n\nGiven the conversation above, generate a followup request or question in the toneof User. Be critical. Make sure the question is concise and has a real-life tone. Directly give me the question without extraneous words." + + +# Nemotron-CC prompts + +NEMOTRON_CC_SYSTEM_PROMPT = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the questions." + +WIKIPEDIA_REPHRASING_PROMPT_TEMPLATE = """For the following paragraph give me a diverse paraphrase of the same in high quality English language as in sentences on Wikipedia. Begin your answer on a separate line with "Here is a paraphrased version:". 
+ +Text: {document}""" From 53c53b3aadfe2c93d2632ea914dfc3108c5bb238 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Thu, 6 Feb 2025 15:42:47 -0800 Subject: [PATCH 08/22] Add diverse QA stages Signed-off-by: Ryan Wolf --- .../services/huggingface_formatter.py | 2 +- nemo_curator/synthetic/__init__.py | 3 +- nemo_curator/synthetic/nemotron_cc.py | 112 +++++++++++++++++- nemo_curator/synthetic/prompts.py | 28 +++++ 4 files changed, 139 insertions(+), 6 deletions(-) diff --git a/nemo_curator/services/huggingface_formatter.py b/nemo_curator/services/huggingface_formatter.py index 71967a536..728d1aa73 100644 --- a/nemo_curator/services/huggingface_formatter.py +++ b/nemo_curator/services/huggingface_formatter.py @@ -30,4 +30,4 @@ def format_conversation(self, conversation: List[dict]) -> str: """ Format a conversation between a user, assistant, and potentially system into a string. """ - return self.tokenizer.apply_chat_template(conversation) + return self.tokenizer.apply_chat_template(conversation, tokenize=False) diff --git a/nemo_curator/synthetic/__init__.py b/nemo_curator/synthetic/__init__.py index bef5123e0..acc1121b1 100644 --- a/nemo_curator/synthetic/__init__.py +++ b/nemo_curator/synthetic/__init__.py @@ -15,7 +15,7 @@ from .error import YamlConversionError from .mixtral import Mixtral8x7BFormatter from .nemotron import NemotronFormatter, NemotronGenerator -from .nemotron_cc import NemotronCC +from .nemotron_cc import NemotronCC, NemotronCCDiverseQAPostprocessor from .no_format import NoFormat from .prompts import ( DEFAULT_CLOSED_QA_PROMPT_TEMPLATE, @@ -47,6 +47,7 @@ "AsyncNemotronGenerator", "NemotronFormatter", "NemotronCC", + "NemotronCCDiverseQAPostprocessor", "Mixtral8x7BFormatter", "NoFormat", "DEFAULT_MACRO_TOPICS_PROMPT_TEMPLATE", diff --git a/nemo_curator/synthetic/nemotron_cc.py b/nemo_curator/synthetic/nemotron_cc.py index 5d3273fff..ee0498118 100644 --- a/nemo_curator/synthetic/nemotron_cc.py +++ b/nemo_curator/synthetic/nemotron_cc.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List, Optional +import random +from typing import Any, List, Optional from transformers import AutoTokenizer +from nemo_curator.datasets import DocumentDataset from nemo_curator.services import LLMClient from nemo_curator.synthetic.prompts import ( + DIVERSE_QA_PROMPT_TEMPLATE, NEMOTRON_CC_SYSTEM_PROMPT, WIKIPEDIA_REPHRASING_PROMPT_TEMPLATE, ) @@ -51,15 +54,16 @@ def _prompt( def get_wikipedia_prefix_token_count(self) -> int: user_prompt = WIKIPEDIA_REPHRASING_PROMPT_TEMPLATE.format( - **{"document": "placeholder"} + document="placeholder" ) messages = [ {"role": "system", "content": NEMOTRON_CC_SYSTEM_PROMPT}, {"role": "user", "content": user_prompt}, ] - prefix = self.tokenizer.apply_chat_template(messages) - return len(self.tokenizer.encode(prefix)) + prefix_tokens = self.tokenizer.apply_chat_template(messages) + + return len(prefix_tokens) def rewrite_to_wikipedia_style( self, @@ -74,3 +78,103 @@ def rewrite_to_wikipedia_style( return self._prompt( model, prompt_template, system_prompt, prompt_kwargs, model_kwargs ) + + def generate_diverse_qa( + self, + document: str, + model: str, + prompt_template: str = DIVERSE_QA_PROMPT_TEMPLATE, + system_prompt: str = NEMOTRON_CC_SYSTEM_PROMPT, + prompt_kwargs: dict = {}, + model_kwargs: dict = {}, + ) -> str: + prompt_kwargs["document"] = document + return self._prompt( + model, prompt_template, system_prompt, prompt_kwargs, model_kwargs + ) + + +class NemotronCCDiverseQAPostprocessor: + """ + Postprocesses the output of the Nemotron-CC Diverse QA generation pipeline. + This postprocessor will sample a random number of QA pairs up to max_num_pairs. + If a tokenizer is provided, the number of QA pairs will be sampled from at least 1 and at most floor(max_num_pairs * num_tokens / 150). + Otherwise, the number of QA pairs will be sampled randomly strictly up to max_num_pairs. + + The generated QA pairs are shuffled and then appended to the original text. + """ + + def __init__( + self, + tokenizer: Optional[AutoTokenizer] = None, + text_field: str = "text", + response_field: str = "response", + max_num_pairs: int = 1, + prefix: str = "Here are the questions and answers based on the provided text:", + ) -> None: + """ + Args: + tokenizer (Optional[AutoTokenizer]): The tokenizer to use for tokenization. + If specified, the number of QA pairs will be sampled based on the token count of the text. + If not specified, the number of QA pairs will be sampled randomly up to max_num_pairs. + text_field (str): The field in the dataset that contains the text used to generate QA pairs. + response_field (str): The field in the dataset that contains the response from the LLM. + max_num_pairs (int): The maximum number of QA pairs to sample. + prefix (str): The prefix of the response from the LLM. 
+ """ + self.tokenizer = tokenizer + self.text_field = text_field + self.response_field = response_field + self.max_num_pairs = max_num_pairs + self.prefix = prefix + + def _postprocess_llm_response(self, text: str, llm_response: str) -> str: + lines = [line.strip() for line in llm_response.split("\n") if line.strip()] + if not lines: + return "" + + # Remove the "- " prefix + lines = [line[2:].strip() if line.startswith("- ") else line for line in lines] + + if lines[0] == self.prefix: + lines = lines[1:] + + # Merge question and answer lines + qa_pairs = [] + for line in lines: + if line.startswith("Question:"): + qa_pairs.append(line) + else: + if qa_pairs: + qa_pairs[-1] += "\n" + line + else: + return "" + + if len(qa_pairs) == 0: + return "" + + # Shuffle the QA pairs and sample up to max_num_pairs + random.shuffle(qa_pairs) + if self.tokenizer is not None: + num_tokens = len(self.tokenizer.tokenize(text)) + qa_pairs = qa_pairs[ + : random.randint(1, max(1, int(self.max_num_pairs * num_tokens / 150))) + ] + else: + qa_pairs = qa_pairs[: random.randint(1, self.max_num_pairs)] + qa_pairs_str = "\n\n".join(qa_pairs) + + # Concatenate the document and the QA pairs + return f"{text}\n\n{qa_pairs_str}" + + def __call__(self, dataset: DocumentDataset) -> DocumentDataset: + df = dataset.df + df[self.response_field] = df.apply( + lambda row: self._postprocess_llm_response( + row[self.text_field], row[self.response_field] + ), + axis=1, + ) + df = df[df[self.response_field] != ""] + + return DocumentDataset(df) diff --git a/nemo_curator/synthetic/prompts.py b/nemo_curator/synthetic/prompts.py index 8c034673b..941db8abc 100644 --- a/nemo_curator/synthetic/prompts.py +++ b/nemo_curator/synthetic/prompts.py @@ -65,3 +65,31 @@ WIKIPEDIA_REPHRASING_PROMPT_TEMPLATE = """For the following paragraph give me a diverse paraphrase of the same in high quality English language as in sentences on Wikipedia. Begin your answer on a separate line with "Here is a paraphrased version:". Text: {document}""" + +DIVERSE_QA_PROMPT_TEMPLATE = """Task: +Read the text, ask questions and answer them. + +Follow these instructions: +1. Ask diverse questions that require different cognitive skills or cover different aspects of the text. +2. Ask questions in various forms such as: + - Yes/No questions that require determining whether a statement is true or false. + - Open-ended questions that begin with words like what, how, when, where, why and who. + - Multi-choice questions that offers two or more options to choose from. Include the options in the question. + - Comparison questions that compare two quantities or objects and determine the relationship between them. + - Reading comprehension questions that test the ability to understand and analyze the text. + - Problem-solving questions that test the ability to solve mathematical, physical, or logical problems. +3. Focus on asking questions about factual information, important knowledge, or concrete details in the text. +4. Write questions and answers using clear and concise language. +5. Use plain text. Do not use Markdown. +6. Each question and answer pair should be on a separate line. Tag the question with "Question:" and the answer with "Answer:". + +Text: +{document} + +Task: +After reading the above text, ask up to 8 questions and provide the correct answers following the instructions. 
Give your response in this format: + +Here are the questions and answers based on the provided text: +- Question: [first question] Answer: [first answer] +- Question: [second question] Answer: [second answer] +....""" From 4ddecbc2bbda3e7517187067dc86633b645abe56 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Thu, 6 Feb 2025 16:30:45 -0800 Subject: [PATCH 09/22] Add distillation Signed-off-by: Ryan Wolf --- nemo_curator/synthetic/nemotron_cc.py | 16 ++++++++++++++++ nemo_curator/synthetic/prompts.py | 14 ++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/nemo_curator/synthetic/nemotron_cc.py b/nemo_curator/synthetic/nemotron_cc.py index ee0498118..e09636d37 100644 --- a/nemo_curator/synthetic/nemotron_cc.py +++ b/nemo_curator/synthetic/nemotron_cc.py @@ -20,6 +20,7 @@ from nemo_curator.datasets import DocumentDataset from nemo_curator.services import LLMClient from nemo_curator.synthetic.prompts import ( + DISTILL_PROMPT_TEMPLATE, DIVERSE_QA_PROMPT_TEMPLATE, NEMOTRON_CC_SYSTEM_PROMPT, WIKIPEDIA_REPHRASING_PROMPT_TEMPLATE, @@ -93,6 +94,20 @@ def generate_diverse_qa( model, prompt_template, system_prompt, prompt_kwargs, model_kwargs ) + def distill( + self, + document: str, + model: str, + prompt_template: str = DISTILL_PROMPT_TEMPLATE, + system_prompt: str = NEMOTRON_CC_SYSTEM_PROMPT, + prompt_kwargs: dict = {}, + model_kwargs: dict = {}, + ) -> str: + prompt_kwargs["document"] = document + return self._prompt( + model, prompt_template, system_prompt, prompt_kwargs, model_kwargs + ) + class NemotronCCDiverseQAPostprocessor: """ @@ -174,6 +189,7 @@ def __call__(self, dataset: DocumentDataset) -> DocumentDataset: row[self.text_field], row[self.response_field] ), axis=1, + meta=(None, "object"), ) df = df[df[self.response_field] != ""] diff --git a/nemo_curator/synthetic/prompts.py b/nemo_curator/synthetic/prompts.py index 941db8abc..465e1c79c 100644 --- a/nemo_curator/synthetic/prompts.py +++ b/nemo_curator/synthetic/prompts.py @@ -93,3 +93,17 @@ - Question: [first question] Answer: [first answer] - Question: [second question] Answer: [second answer] ....""" + +DISTILL_PROMPT_TEMPLATE = """Your task is to read and paraphrase the provided text following these instructions: +- Aim to create a condensed but accurate and informative version of the original text, not a simplistic summary. +- Capture and preserve the crucial information, key concepts, important values, factual details in the original text, while making it more readable and accessible. +- Retain technical terms, specialized vocabulary, and complex concepts. +- Retain examples, explanations of reasoning processes, and supporting evidence to maintain the text's depth and context. +- Only include information that is present in the original text. Do not adding new or unsubstantiated claims. +- Write the text in plain text without formatting. + +Here is the text: +{document} + +Task: +After thoroughly reading the above text, paraphrase it in high-quality and clear English following the instructions. 
Begin your response with "Paraphrased Text:".""" From 46971aaf7dbf41bcdeab512baf17ca820208eb87 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Fri, 7 Feb 2025 08:42:41 -0800 Subject: [PATCH 10/22] Add extract knowledge prompt Signed-off-by: Ryan Wolf --- nemo_curator/synthetic/nemotron_cc.py | 56 ++++++++++++++++++++------- nemo_curator/synthetic/prompts.py | 17 ++++++++ 2 files changed, 60 insertions(+), 13 deletions(-) diff --git a/nemo_curator/synthetic/nemotron_cc.py b/nemo_curator/synthetic/nemotron_cc.py index e09636d37..f791952fc 100644 --- a/nemo_curator/synthetic/nemotron_cc.py +++ b/nemo_curator/synthetic/nemotron_cc.py @@ -22,6 +22,7 @@ from nemo_curator.synthetic.prompts import ( DISTILL_PROMPT_TEMPLATE, DIVERSE_QA_PROMPT_TEMPLATE, + EXTRACT_KNOWLEDGE_PROMPT_TEMPLATE, NEMOTRON_CC_SYSTEM_PROMPT, WIKIPEDIA_REPHRASING_PROMPT_TEMPLATE, ) @@ -53,19 +54,6 @@ def _prompt( return self.client.query_model(messages=messages, model=model, **model_kwargs) - def get_wikipedia_prefix_token_count(self) -> int: - user_prompt = WIKIPEDIA_REPHRASING_PROMPT_TEMPLATE.format( - document="placeholder" - ) - messages = [ - {"role": "system", "content": NEMOTRON_CC_SYSTEM_PROMPT}, - {"role": "user", "content": user_prompt}, - ] - - prefix_tokens = self.tokenizer.apply_chat_template(messages) - - return len(prefix_tokens) - def rewrite_to_wikipedia_style( self, document: str, @@ -108,6 +96,20 @@ def distill( model, prompt_template, system_prompt, prompt_kwargs, model_kwargs ) + def extract_knowledge( + self, + document: str, + model: str, + prompt_template: str = EXTRACT_KNOWLEDGE_PROMPT_TEMPLATE, + system_prompt: str = NEMOTRON_CC_SYSTEM_PROMPT, + prompt_kwargs: dict = {}, + model_kwargs: dict = {}, + ) -> str: + prompt_kwargs["document"] = document + return self._prompt( + model, prompt_template, system_prompt, prompt_kwargs, model_kwargs + ) + class NemotronCCDiverseQAPostprocessor: """ @@ -194,3 +196,31 @@ def __call__(self, dataset: DocumentDataset) -> DocumentDataset: df = df[df[self.response_field] != ""] return DocumentDataset(df) + + +# Although this could be implemented as a DocumentModifier, +# I have kept it separate to match the other postprocessors. +class NemotronCCKnowledgeListPostprocessor: + """ + Postprocesses the output of the Nemotron-CC Knowledge List generation pipeline. + """ + + def __init__(self, text_field: str = "text") -> None: + self.text_field = text_field + + def _postprocess_llm_response(self, text: str) -> str: + lines = [] + for idx, line in enumerate(text.split("\n")): + if idx == 0 and not line.startswith("-"): + continue + + if line.startswith(" ") or line.startswith("- "): + lines.append(line[2:].strip()) + else: + lines.append(line) + return "\n".join(lines) + + def __call__(self, dataset: DocumentDataset) -> DocumentDataset: + df = dataset.df + df[self.text_field] = df[self.text_field].apply(self._postprocess_llm_response) + return DocumentDataset(df) diff --git a/nemo_curator/synthetic/prompts.py b/nemo_curator/synthetic/prompts.py index 465e1c79c..fd7072d43 100644 --- a/nemo_curator/synthetic/prompts.py +++ b/nemo_curator/synthetic/prompts.py @@ -62,6 +62,8 @@ NEMOTRON_CC_SYSTEM_PROMPT = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the questions." +NEMOTRON_CC_DISTILL_SYSTEM_PROMPT = "You are an artificial intelligence assistant. You carefully provide accurate, factual, thoughtful, nuanced answers, and are brilliant at reasoning." 
+ WIKIPEDIA_REPHRASING_PROMPT_TEMPLATE = """For the following paragraph give me a diverse paraphrase of the same in high quality English language as in sentences on Wikipedia. Begin your answer on a separate line with "Here is a paraphrased version:". Text: {document}""" @@ -107,3 +109,18 @@ Task: After thoroughly reading the above text, paraphrase it in high-quality and clear English following the instructions. Begin your response with "Paraphrased Text:".""" + +EXTRACT_KNOWLEDGE_PROMPT_TEMPLATE = """Your task is to rewrite knowledge from the provided text following these instructions. +- Rewrite the text as a passage or passages using easy-to-understand and high-quality English like sentences in textbooks and Wikipedia. +- Focus on content in disciplines such as humanities, social sciences, natural sciences, technology, engineering, math, law and legal, business, management, art, education, agricultural sciences, politics, and history. +- Disregard content that does not contain useful facts or knowledge. +- Retain examples, explanations of reasoning processes, and supporting evidence to maintain the text's depth and context. +- Do not add or alter details. Only restate what is already in the text. +- Write in plain text. +- Do not add titles, subtitles, note, or comment. + +Text: +{document} + +Task: +Rewrite facts and knowledge from the above text as a passage or passages following the instructions.""" From 042150f1f128a3091ef454fea1fc3e04084f277a Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Fri, 7 Feb 2025 09:10:36 -0800 Subject: [PATCH 11/22] Add knowledge list prompt template Signed-off-by: Ryan Wolf --- nemo_curator/synthetic/__init__.py | 7 ++++++- nemo_curator/synthetic/nemotron_cc.py | 18 +++++++++++++++++- nemo_curator/synthetic/prompts.py | 12 ++++++++++++ 3 files changed, 35 insertions(+), 2 deletions(-) diff --git a/nemo_curator/synthetic/__init__.py b/nemo_curator/synthetic/__init__.py index acc1121b1..dbcaffdb1 100644 --- a/nemo_curator/synthetic/__init__.py +++ b/nemo_curator/synthetic/__init__.py @@ -15,7 +15,11 @@ from .error import YamlConversionError from .mixtral import Mixtral8x7BFormatter from .nemotron import NemotronFormatter, NemotronGenerator -from .nemotron_cc import NemotronCC, NemotronCCDiverseQAPostprocessor +from .nemotron_cc import ( + NemotronCC, + NemotronCCDiverseQAPostprocessor, + NemotronCCKnowledgeListPostprocessor, +) from .no_format import NoFormat from .prompts import ( DEFAULT_CLOSED_QA_PROMPT_TEMPLATE, @@ -48,6 +52,7 @@ "NemotronFormatter", "NemotronCC", "NemotronCCDiverseQAPostprocessor", + "NemotronCCKnowledgeListPostprocessor", "Mixtral8x7BFormatter", "NoFormat", "DEFAULT_MACRO_TOPICS_PROMPT_TEMPLATE", diff --git a/nemo_curator/synthetic/nemotron_cc.py b/nemo_curator/synthetic/nemotron_cc.py index f791952fc..a2b7ab410 100644 --- a/nemo_curator/synthetic/nemotron_cc.py +++ b/nemo_curator/synthetic/nemotron_cc.py @@ -23,6 +23,8 @@ DISTILL_PROMPT_TEMPLATE, DIVERSE_QA_PROMPT_TEMPLATE, EXTRACT_KNOWLEDGE_PROMPT_TEMPLATE, + KNOWLEDGE_LIST_PROMPT_TEMPLATE, + NEMOTRON_CC_DISTILL_SYSTEM_PROMPT, NEMOTRON_CC_SYSTEM_PROMPT, WIKIPEDIA_REPHRASING_PROMPT_TEMPLATE, ) @@ -87,7 +89,7 @@ def distill( document: str, model: str, prompt_template: str = DISTILL_PROMPT_TEMPLATE, - system_prompt: str = NEMOTRON_CC_SYSTEM_PROMPT, + system_prompt: str = NEMOTRON_CC_DISTILL_SYSTEM_PROMPT, prompt_kwargs: dict = {}, model_kwargs: dict = {}, ) -> str: @@ -110,6 +112,20 @@ def extract_knowledge( model, prompt_template, system_prompt, prompt_kwargs, model_kwargs ) + def 
generate_knowledge_list( + self, + document: str, + model: str, + prompt_template: str = KNOWLEDGE_LIST_PROMPT_TEMPLATE, + system_prompt: str = NEMOTRON_CC_SYSTEM_PROMPT, + prompt_kwargs: dict = {}, + model_kwargs: dict = {}, + ) -> str: + prompt_kwargs["document"] = document + return self._prompt( + model, prompt_template, system_prompt, prompt_kwargs, model_kwargs + ) + class NemotronCCDiverseQAPostprocessor: """ diff --git a/nemo_curator/synthetic/prompts.py b/nemo_curator/synthetic/prompts.py index fd7072d43..1cdf53ed5 100644 --- a/nemo_curator/synthetic/prompts.py +++ b/nemo_curator/synthetic/prompts.py @@ -124,3 +124,15 @@ Task: Rewrite facts and knowledge from the above text as a passage or passages following the instructions.""" + +KNOWLEDGE_LIST_PROMPT_TEMPLATE = """Review the text and extract the key information. Follow these instructions: +- Carefully read the above text and provide a concise and organized list of factual information, concrete details, key concepts, and important numbers and statistics extracted from the text. +- Ensure each point is clear, specific, and supported by the original text. +- Ensure the extract text is information-dense and easier to learn from. +- Do not add titles or headings. + +Text: +{document} + +Task: +Extract the factual information, concrete details, and key concepts from the above text following the instructions.""" From 868a2b1c846c16322c4e97cbff0c692db2f29018 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Fri, 7 Feb 2025 09:14:40 -0800 Subject: [PATCH 12/22] Add metadata to knowledge list postprocessor Signed-off-by: Ryan Wolf --- nemo_curator/synthetic/nemotron_cc.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nemo_curator/synthetic/nemotron_cc.py b/nemo_curator/synthetic/nemotron_cc.py index a2b7ab410..98e19d7d2 100644 --- a/nemo_curator/synthetic/nemotron_cc.py +++ b/nemo_curator/synthetic/nemotron_cc.py @@ -238,5 +238,7 @@ def _postprocess_llm_response(self, text: str) -> str: def __call__(self, dataset: DocumentDataset) -> DocumentDataset: df = dataset.df - df[self.text_field] = df[self.text_field].apply(self._postprocess_llm_response) + df[self.text_field] = df[self.text_field].apply( + self._postprocess_llm_response, meta=(self.text_field, "object") + ) return DocumentDataset(df) From b0946fac7b93ff99175f5c5790ab7873d31ec1a1 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Fri, 7 Feb 2025 09:28:05 -0800 Subject: [PATCH 13/22] Remove tokenizer from nemotron cc and add docstrings Signed-off-by: Ryan Wolf --- nemo_curator/synthetic/nemotron_cc.py | 116 +++++++++++++++++++++----- 1 file changed, 97 insertions(+), 19 deletions(-) diff --git a/nemo_curator/synthetic/nemotron_cc.py b/nemo_curator/synthetic/nemotron_cc.py index 98e19d7d2..565302cb4 100644 --- a/nemo_curator/synthetic/nemotron_cc.py +++ b/nemo_curator/synthetic/nemotron_cc.py @@ -36,19 +36,25 @@ class NemotronCC: described in the Nemotron-CC paper (https://arxiv.org/abs/2412.02595). """ - def __init__(self, llm_client: LLMClient, tokenizer: AutoTokenizer) -> None: + def __init__(self, llm_client: LLMClient) -> None: + """ + Initialize the NemotronCC instance. + + Args: + llm_client (LLMClient): The language model client used for querying the model. 
+ """ self.client = llm_client - self.tokenizer = tokenizer def _prompt( self, model: str, + document: str, prompt_template: str, system_prompt: str, prompt_kwargs: dict, model_kwargs: dict, ) -> List[str]: - prompt = prompt_template.format(**prompt_kwargs) + prompt = prompt_template.format(document=document, **prompt_kwargs) messages = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}, @@ -64,10 +70,23 @@ def rewrite_to_wikipedia_style( system_prompt: str = NEMOTRON_CC_SYSTEM_PROMPT, prompt_kwargs: dict = {}, model_kwargs: dict = {}, - ) -> str: - prompt_kwargs["document"] = document + ) -> List[str]: + """ + Rewrites a document into a Wikipedia-style narrative. + + Args: + document (str): The input document text to rewrite. + model (str): The model identifier to use. + prompt_template (str, optional): The prompt template for rewriting. Defaults to WIKIPEDIA_REPHRASING_PROMPT_TEMPLATE. + system_prompt (str, optional): The system prompt to use. Defaults to NEMOTRON_CC_SYSTEM_PROMPT. + prompt_kwargs (dict, optional): Additional keyword arguments for the prompt. Defaults to {}. + model_kwargs (dict, optional): Additional keyword arguments for the model invocation. Defaults to {}. + + Returns: + List[str]: A list of responses from the LLM. The list is only greater than length 1 if n > 1 is set in model_kwargs. + """ return self._prompt( - model, prompt_template, system_prompt, prompt_kwargs, model_kwargs + model, document, prompt_template, system_prompt, prompt_kwargs, model_kwargs ) def generate_diverse_qa( @@ -78,10 +97,23 @@ def generate_diverse_qa( system_prompt: str = NEMOTRON_CC_SYSTEM_PROMPT, prompt_kwargs: dict = {}, model_kwargs: dict = {}, - ) -> str: - prompt_kwargs["document"] = document + ) -> List[str]: + """ + Generates diverse QA pairs from the provided document. + + Args: + document (str): The input document text used to generate QA pairs. + model (str): The model identifier to use. + prompt_template (str, optional): The prompt template for generating QA pairs. Defaults to DIVERSE_QA_PROMPT_TEMPLATE. + system_prompt (str, optional): The system prompt to use. Defaults to NEMOTRON_CC_SYSTEM_PROMPT. + prompt_kwargs (dict, optional): Additional keyword arguments for the prompt. Defaults to {}. + model_kwargs (dict, optional): Additional keyword arguments for the model invocation. Defaults to {}. + + Returns: + List[str]: A list of responses from the LLM. The list is only greater than length 1 if n > 1 is set in model_kwargs. + """ return self._prompt( - model, prompt_template, system_prompt, prompt_kwargs, model_kwargs + model, document, prompt_template, system_prompt, prompt_kwargs, model_kwargs ) def distill( @@ -92,10 +124,23 @@ def distill( system_prompt: str = NEMOTRON_CC_DISTILL_SYSTEM_PROMPT, prompt_kwargs: dict = {}, model_kwargs: dict = {}, - ) -> str: - prompt_kwargs["document"] = document + ) -> List[str]: + """ + Distills the essential content from a document. + + Args: + document (str): The input document text to distill. + model (str): The model identifier to use. + prompt_template (str, optional): The prompt template for distillation. Defaults to DISTILL_PROMPT_TEMPLATE. + system_prompt (str, optional): The system prompt to use. Defaults to NEMOTRON_CC_DISTILL_SYSTEM_PROMPT. + prompt_kwargs (dict, optional): Additional keyword arguments for the prompt. Defaults to {}. + model_kwargs (dict, optional): Additional keyword arguments for the model invocation. Defaults to {}. + + Returns: + List[str]: A list of responses from the LLM. 
The list is only greater than length 1 if n > 1 is set in model_kwargs. + """ return self._prompt( - model, prompt_template, system_prompt, prompt_kwargs, model_kwargs + model, document, prompt_template, system_prompt, prompt_kwargs, model_kwargs ) def extract_knowledge( @@ -106,10 +151,23 @@ def extract_knowledge( system_prompt: str = NEMOTRON_CC_SYSTEM_PROMPT, prompt_kwargs: dict = {}, model_kwargs: dict = {}, - ) -> str: - prompt_kwargs["document"] = document + ) -> List[str]: + """ + Extracts knowledge from the provided document. + + Args: + document (str): The input document text from which to extract knowledge. + model (str): The model identifier to use. + prompt_template (str, optional): The prompt template for knowledge extraction. Defaults to EXTRACT_KNOWLEDGE_PROMPT_TEMPLATE. + system_prompt (str, optional): The system prompt to use. Defaults to NEMOTRON_CC_SYSTEM_PROMPT. + prompt_kwargs (dict, optional): Additional keyword arguments for the prompt. Defaults to {}. + model_kwargs (dict, optional): Additional keyword arguments for the model invocation. Defaults to {}. + + Returns: + List[str]: A list of responses from the LLM. The list is only greater than length 1 if n > 1 is set in model_kwargs. + """ return self._prompt( - model, prompt_template, system_prompt, prompt_kwargs, model_kwargs + model, document, prompt_template, system_prompt, prompt_kwargs, model_kwargs ) def generate_knowledge_list( @@ -120,10 +178,23 @@ def generate_knowledge_list( system_prompt: str = NEMOTRON_CC_SYSTEM_PROMPT, prompt_kwargs: dict = {}, model_kwargs: dict = {}, - ) -> str: - prompt_kwargs["document"] = document + ) -> List[str]: + """ + Generates a list of knowledge items from the provided document. + + Args: + document (str): The input document text to process. + model (str): The model identifier to use. + prompt_template (str, optional): The prompt template for generating a knowledge list. Defaults to KNOWLEDGE_LIST_PROMPT_TEMPLATE. + system_prompt (str, optional): The system prompt to use. Defaults to NEMOTRON_CC_SYSTEM_PROMPT. + prompt_kwargs (dict, optional): Additional keyword arguments for the prompt. Defaults to {}. + model_kwargs (dict, optional): Additional keyword arguments for the model invocation. Defaults to {}. + + Returns: + List[str]: A list of responses from the LLM. The list is only greater than length 1 if n > 1 is set in model_kwargs. + """ return self._prompt( - model, prompt_template, system_prompt, prompt_kwargs, model_kwargs + model, document, prompt_template, system_prompt, prompt_kwargs, model_kwargs ) @@ -218,7 +289,14 @@ def __call__(self, dataset: DocumentDataset) -> DocumentDataset: # I have kept it separate to match the other postprocessors. class NemotronCCKnowledgeListPostprocessor: """ - Postprocesses the output of the Nemotron-CC Knowledge List generation pipeline. + Processes and cleans the output generated by the Nemotron-CC Knowledge List pipeline. + + This class is responsible for postprocessing raw text responses produced by the + Nemotron-CC Knowledge List generation pipeline. It removes formatting artifacts + such as bullet point prefixes ("- ") and extra indentation from each line, ensuring + that the final output is a clean, uniformly formatted list of knowledge items. + The processing includes skipping any initial non-bullet line and merging related lines + to reconstruct multi-line questions or answers. 
""" def __init__(self, text_field: str = "text") -> None: From a8f12ddad8ee8fbbec332ff5d0994c5537dfeef6 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Fri, 7 Feb 2025 09:45:05 -0800 Subject: [PATCH 14/22] Add API docs and make modules use base class Signed-off-by: Ryan Wolf --- docs/user-guide/api/filters.rst | 8 +++++ docs/user-guide/api/misc.rst | 6 ++++ docs/user-guide/api/modifiers.rst | 19 +++++++++++ nemo_curator/filters/heuristic_filter.py | 4 +-- nemo_curator/modifiers/markdown_remover.py | 2 +- nemo_curator/modifiers/slicer.py | 14 +++++--- .../services/huggingface_formatter.py | 33 ------------------- nemo_curator/synthetic/nemotron_cc.py | 7 ++-- 8 files changed, 50 insertions(+), 43 deletions(-) delete mode 100644 nemo_curator/services/huggingface_formatter.py diff --git a/docs/user-guide/api/filters.rst b/docs/user-guide/api/filters.rst index 55b78ed7b..24678b73e 100644 --- a/docs/user-guide/api/filters.rst +++ b/docs/user-guide/api/filters.rst @@ -152,6 +152,14 @@ Heuristic Filters :members: :member-order: bysource +.. autoclass:: nemo_curator.filters.TokenCountFilter + :members: + :member-order: bysource + +.. autoclass:: nemo_curator.filters.SubstringFilter + :members: + :member-order: bysource + ------------------------------ Code Filters ------------------------------ diff --git a/docs/user-guide/api/misc.rst b/docs/user-guide/api/misc.rst index b4785f022..9872cb858 100644 --- a/docs/user-guide/api/misc.rst +++ b/docs/user-guide/api/misc.rst @@ -15,3 +15,9 @@ Miscellaneous .. autoclass:: nemo_curator.Shuffle :members: + +.. autoclass:: nemo_curator.DocumentSplitter + :members: + +.. autoclass:: nemo_curator.DocumentJoiner + :members: diff --git a/docs/user-guide/api/modifiers.rst b/docs/user-guide/api/modifiers.rst index 6e5f506ed..252803a24 100644 --- a/docs/user-guide/api/modifiers.rst +++ b/docs/user-guide/api/modifiers.rst @@ -32,3 +32,22 @@ Modifiers .. autoclass:: nemo_curator.modifiers.PiiModifier :members: + +.. autoclass:: nemo_curator.modifiers.LineRemover + :members: + +.. autoclass:: nemo_curator.modifiers.MarkdownRemover + :members: + +.. autoclass:: nemo_curator.modifiers.NewlineNormalizer + :members: + +.. autoclass:: nemo_curator.modifiers.UrlRemover + :members: + +.. autoclass:: nemo_curator.modifiers.Slicer + :members: + +.. autoclass:: nemo_curator.modifiers.QuotationRemover + :members: + diff --git a/nemo_curator/filters/heuristic_filter.py b/nemo_curator/filters/heuristic_filter.py index 182a76672..26617bd60 100644 --- a/nemo_curator/filters/heuristic_filter.py +++ b/nemo_curator/filters/heuristic_filter.py @@ -693,11 +693,11 @@ def __init__(self, tokenizer: AutoTokenizer, min_tokens=0, max_tokens=float("inf self._max_tokens = max_tokens self._name = "token_count" - def score_document(self, text): + def score_document(self, text: str) -> int: tokens = self._tokenizer.encode(text) return len(tokens) - def keep_document(self, score): + def keep_document(self, score: int) -> bool: return self._min_tokens <= score <= self._max_tokens diff --git a/nemo_curator/modifiers/markdown_remover.py b/nemo_curator/modifiers/markdown_remover.py index cda29cd5f..be060fd48 100644 --- a/nemo_curator/modifiers/markdown_remover.py +++ b/nemo_curator/modifiers/markdown_remover.py @@ -24,7 +24,7 @@ class MarkdownRemover(DocumentModifier): """ - Removes Markdown formatting in a document including bold, italic, and URL text. + Removes Markdown formatting in a document including bold, italic, underline, and URL text. 
""" def __init__(self): diff --git a/nemo_curator/modifiers/slicer.py b/nemo_curator/modifiers/slicer.py index 9ce68a6e6..d88070388 100644 --- a/nemo_curator/modifiers/slicer.py +++ b/nemo_curator/modifiers/slicer.py @@ -18,7 +18,7 @@ class Slicer(DocumentModifier): """ - Slices a document based on indices or strings + Slices a document based on indices or strings. """ def __init__( @@ -31,10 +31,14 @@ def __init__( ): """ Args: - left (Union[int, str], optional): If the provided value is an int, slice the string from this index (inclusive). If the provided value is a str, slice the string from the first occurence of this substring. - right (Union[int, str], optional): If the provided value is an int, slice the string to this index (exclusive). If the provided value is a str, slice the string to the last occurence of this substring. - include_left (bool): Only used if `left` is a string. If True, the value of `left` is included in the slicing result. Defaults to False. - include_right (bool): Only used if `right` is a string. If True, the value of `right` is included in the slicing result. Defaults to False. + left (Union[int, str], optional): If the provided value is an int, slice the string from this index (inclusive). + If the provided value is a str, slice the string from the first occurence of this substring. + right (Union[int, str], optional): If the provided value is an int, slice the string to this index (exclusive). + If the provided value is a str, slice the string to the last occurence of this substring. + include_left (bool): Only used if `left` is a string. If True, the value of `left` is included in the + slicing result. Defaults to False. + include_right (bool): Only used if `right` is a string. If True, the value of `right` is included in the + slicing result. Defaults to False. strip (bool): If True, strip the resulting string. """ super().__init__() diff --git a/nemo_curator/services/huggingface_formatter.py b/nemo_curator/services/huggingface_formatter.py deleted file mode 100644 index 728d1aa73..000000000 --- a/nemo_curator/services/huggingface_formatter.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import List - -from transformers import AutoTokenizer - -from nemo_curator.services import ConversationFormatter - - -class HuggingFaceFormatter(ConversationFormatter): - """ - A formatter that uses a Hugging Face tokenizer to format a conversation. - """ - - def __init__(self, tokenizer: AutoTokenizer) -> None: - self.tokenizer = tokenizer - - def format_conversation(self, conversation: List[dict]) -> str: - """ - Format a conversation between a user, assistant, and potentially system into a string. 
- """ - return self.tokenizer.apply_chat_template(conversation, tokenize=False) diff --git a/nemo_curator/synthetic/nemotron_cc.py b/nemo_curator/synthetic/nemotron_cc.py index 565302cb4..0850eb94a 100644 --- a/nemo_curator/synthetic/nemotron_cc.py +++ b/nemo_curator/synthetic/nemotron_cc.py @@ -17,6 +17,7 @@ from transformers import AutoTokenizer +from nemo_curator import BaseModule from nemo_curator.datasets import DocumentDataset from nemo_curator.services import LLMClient from nemo_curator.synthetic.prompts import ( @@ -198,7 +199,7 @@ def generate_knowledge_list( ) -class NemotronCCDiverseQAPostprocessor: +class NemotronCCDiverseQAPostprocessor(BaseModule): """ Postprocesses the output of the Nemotron-CC Diverse QA generation pipeline. This postprocessor will sample a random number of QA pairs up to max_num_pairs. @@ -226,6 +227,7 @@ def __init__( max_num_pairs (int): The maximum number of QA pairs to sample. prefix (str): The prefix of the response from the LLM. """ + super().__init__(input_backend="pandas") self.tokenizer = tokenizer self.text_field = text_field self.response_field = response_field @@ -287,7 +289,7 @@ def __call__(self, dataset: DocumentDataset) -> DocumentDataset: # Although this could be implemented as a DocumentModifier, # I have kept it separate to match the other postprocessors. -class NemotronCCKnowledgeListPostprocessor: +class NemotronCCKnowledgeListPostprocessor(BaseModule): """ Processes and cleans the output generated by the Nemotron-CC Knowledge List pipeline. @@ -300,6 +302,7 @@ class NemotronCCKnowledgeListPostprocessor: """ def __init__(self, text_field: str = "text") -> None: + super().__init__(input_backend="pandas") self.text_field = text_field def _postprocess_llm_response(self, text: str) -> str: From d1d1c0ce159df2782ac8f9ae12929d0cb42f0f20 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Fri, 7 Feb 2025 11:33:56 -0800 Subject: [PATCH 15/22] Add tests for new modules Signed-off-by: Ryan Wolf --- nemo_curator/synthetic/nemotron_cc.py | 4 +- tests/test_cleaning.py | 334 +++++++++++++++++++++++++- tests/test_filters.py | 180 ++++++++++++++ tests/test_nemotron_cc.py | 287 ++++++++++++++++++++++ tests/test_splitter.py | 280 +++++++++++++++++++++ 5 files changed, 1081 insertions(+), 4 deletions(-) create mode 100644 tests/test_nemotron_cc.py create mode 100644 tests/test_splitter.py diff --git a/nemo_curator/synthetic/nemotron_cc.py b/nemo_curator/synthetic/nemotron_cc.py index 0850eb94a..1009699ec 100644 --- a/nemo_curator/synthetic/nemotron_cc.py +++ b/nemo_curator/synthetic/nemotron_cc.py @@ -273,7 +273,7 @@ def _postprocess_llm_response(self, text: str, llm_response: str) -> str: # Concatenate the document and the QA pairs return f"{text}\n\n{qa_pairs_str}" - def __call__(self, dataset: DocumentDataset) -> DocumentDataset: + def call(self, dataset: DocumentDataset) -> DocumentDataset: df = dataset.df df[self.response_field] = df.apply( lambda row: self._postprocess_llm_response( @@ -317,7 +317,7 @@ def _postprocess_llm_response(self, text: str) -> str: lines.append(line) return "\n".join(lines) - def __call__(self, dataset: DocumentDataset) -> DocumentDataset: + def call(self, dataset: DocumentDataset) -> DocumentDataset: df = dataset.df df[self.text_field] = df[self.text_field].apply( self._postprocess_llm_response, meta=(self.text_field, "object") diff --git a/tests/test_cleaning.py b/tests/test_cleaning.py index 906da3919..539fe49ba 100644 --- a/tests/test_cleaning.py +++ b/tests/test_cleaning.py @@ -14,17 +14,26 @@ import dask.dataframe as dd 
import pandas as pd +from dask.dataframe.utils import assert_eq from nemo_curator import Modify from nemo_curator.datasets import DocumentDataset -from nemo_curator.modifiers import NewlineNormalizer, UnicodeReformatter, UrlRemover +from nemo_curator.modifiers import ( + LineRemover, + MarkdownRemover, + NewlineNormalizer, + QuotationRemover, + Slicer, + UnicodeReformatter, + UrlRemover, +) def list_to_dataset(documents, col_name="text", npartitions=2): data = {col_name: documents} pdf = pd.DataFrame(data) - return DocumentDataset(dd.from_pandas(pdf, npartitions=npartitions)) + return DocumentDataset.from_pandas(pdf, npartitions=npartitions) class TestUnicodeReformatter: @@ -149,3 +158,324 @@ def test_urls(self): assert ( expected_results == actual_results ), f"Expected: {expected_results}, but got: {actual_results}" + + +class TestLineRemover: + def test_remove_exact_match(self): + text = "Keep this\nRemove me\nAlso keep this\nRemove me" + patterns = ["Remove me"] + remover = LineRemover(patterns) + result = remover.modify_document(text) + expected = "Keep this\nAlso keep this" + assert result == expected + + def test_no_removal_when_partial_match(self): + text = ( + "Keep this line\nThis line contains Remove me as a part of it\nAnother line" + ) + patterns = ["Remove me"] + remover = LineRemover(patterns) + # Only lines that exactly match "Remove me" are removed. + assert remover.modify_document(text) == text + + def test_empty_input(self): + text = "" + patterns = ["Remove me"] + remover = LineRemover(patterns) + result = remover.modify_document(text) + assert result == "" + + def test_multiple_patterns(self): + text = "Line one\nDelete\nLine two\nRemove\nLine three\nDelete" + patterns = ["Delete", "Remove"] + remover = LineRemover(patterns) + result = remover.modify_document(text) + expected = "Line one\nLine two\nLine three" + assert result == expected + + def test_whitespace_sensitivity(self): + # Exact match requires identical string content. + text = "Remove me \nRemove me\n Remove me" + patterns = ["Remove me"] + remover = LineRemover(patterns) + result = remover.modify_document(text) + # Only the line that exactly equals "Remove me" is removed. + expected = "Remove me \n Remove me" + assert result == expected + + def test_dataset_modification(self): + docs = [ + "Keep this\nRemove me\nKeep that", + "Remove me\nDon't remove\nRemove me", + "No removal here", + "Remove me", + ] + expected_results = [ + "Keep this\nKeep that", + "Don't remove", + "No removal here", + "", + ] + dataset = list_to_dataset(docs) + modifier = Modify(LineRemover(["Remove me"])) + fixed_dataset = modifier(dataset) + expected_dataset = list_to_dataset(expected_results) + assert_eq(fixed_dataset.df, expected_dataset.df) + + +class TestQuotationRemover: + def test_remove_quotes_no_newline(self): + text = '"Hello, World!"' + remover = QuotationRemover() + result = remover.modify_document(text) + expected = "Hello, World!" + assert result == expected + + def test_no_removal_when_quotes_not_enclosing(self): + text = 'Hello, "World!"' + remover = QuotationRemover() + result = remover.modify_document(text) + # The text does not start and end with a quotation mark. + assert result == text + + def test_remove_quotes_with_newline_removal(self): + text = '"Hello,\nWorld!"' + remover = QuotationRemover() + result = remover.modify_document(text) + # Since there is a newline and the first line does not end with a quote, + # the quotes are removed. + expected = "Hello,\nWorld!" 
+ assert result == expected + + def test_no_removal_with_newline_preserved(self): + text = '"Hello,"\nWorld!"' + remover = QuotationRemover() + result = remover.modify_document(text) + # The first line ends with a quote so the removal does not occur. + assert result == text + + def test_short_text_no_removal(self): + text = '""' + remover = QuotationRemover() + result = remover.modify_document(text) + # With text length not greater than 2 (after stripping), nothing changes. + assert result == text + + def test_extra_whitespace_prevents_removal(self): + # If leading/trailing whitespace prevents the text from starting with a quote, + # nothing is changed. + text = ' "Test Message" ' + remover = QuotationRemover() + result = remover.modify_document(text) + assert result == text + + def test_dataset_modification(self): + import pandas as pd + from dask.dataframe.utils import assert_eq + + docs = ['"Document one"', 'Start "Document two" End', '"Document\nthree"', '""'] + expected_results = [ + "Document one", + 'Start "Document two" End', + "Document\nthree", + '""', + ] + dataset = list_to_dataset(docs) + modifier = Modify(QuotationRemover()) + fixed_dataset = modifier(dataset) + expected_dataset = list_to_dataset(expected_results) + assert_eq(fixed_dataset.df, expected_dataset.df) + + +class TestSlicer: + def test_integer_indices(self): + text = "Hello, world!" + slicer = Slicer(left=7, right=12) + result = slicer.modify_document(text) + expected = "world" + assert result == expected + + def test_left_string_including(self): + text = "abcXYZdef" + slicer = Slicer(left="XYZ", include_left=True) + result = slicer.modify_document(text) + expected = "XYZdef" + assert result == expected + + def test_left_string_excluding(self): + text = "abcXYZdef" + slicer = Slicer(left="XYZ", include_left=False) + result = slicer.modify_document(text) + expected = "def" + assert result == expected + + def test_right_string_including(self): + text = "abcXYZdef" + slicer = Slicer(right="XYZ", include_right=True) + result = slicer.modify_document(text) + expected = "abcXYZ" + assert result == expected + + def test_right_string_excluding(self): + text = "abcXYZdef" + slicer = Slicer(right="XYZ", include_right=False) + result = slicer.modify_document(text) + expected = "abc" + assert result == expected + + def test_both_left_and_right_with_strings(self): + text = "start middle end" + slicer = Slicer( + left="start", right="end", include_left=False, include_right=False + ) + result = slicer.modify_document(text) + # "start" is removed and "end" is excluded; extra spaces are stripped. + expected = "middle" + assert result == expected + + def test_non_existing_left(self): + text = "abcdef" + slicer = Slicer(left="nonexistent") + result = slicer.modify_document(text) + assert result == "" + + def test_non_existing_right(self): + text = "abcdef" + slicer = Slicer(right="nonexistent") + result = slicer.modify_document(text) + assert result == "" + + def test_no_left_no_right(self): + text = " some text with spaces " + slicer = Slicer() + result = slicer.modify_document(text) + # With no boundaries specified, the entire text is returned (stripped). + expected = "some text with spaces" + assert result == expected + + def test_integer_out_of_range(self): + text = "short" + slicer = Slicer(left=10) + result = slicer.modify_document(text) + # Slicing starting beyond the text length yields an empty string. 
+ assert result == "" + + def test_multiple_occurrences(self): + text = "abc__def__ghi" + # Testing when markers appear multiple times. + slicer = Slicer(left="__", right="__", include_left=True, include_right=True) + result = slicer.modify_document(text) + # left: first occurrence at index 3; right: last occurrence at index 8, include_right adds len("__") + expected = "__def__" + assert result == expected + + def test_dataset_modification(self): + import pandas as pd + from dask.dataframe.utils import assert_eq + + docs = ["abcdef", "0123456789", "Hello", "Slicer"] + expected_results = [ + "cde", # "abcdef" sliced from index 2 to 5 + "234", # "0123456789" sliced from index 2 to 5 + "llo", # "Hello" sliced from index 2 to 5 + "ice", # "Slicer" sliced from index 2 to 5 + ] + dataset = list_to_dataset(docs) + modifier = Modify(Slicer(left=2, right=5)) + fixed_dataset = modifier(dataset) + expected_dataset = list_to_dataset(expected_results) + assert_eq(fixed_dataset.df, expected_dataset.df) + + +class TestMarkdownRemover: + def test_bold_removal(self): + text = "This is **bold** text." + remover = MarkdownRemover() + result = remover.modify_document(text) + expected = "This is bold text." + assert result == expected + + def test_italic_removal(self): + text = "This is *italic* text." + remover = MarkdownRemover() + result = remover.modify_document(text) + expected = "This is italic text." + assert result == expected + + def test_underline_removal(self): + text = "This is _underlined_ text." + remover = MarkdownRemover() + result = remover.modify_document(text) + expected = "This is underlined text." + assert result == expected + + def test_link_removal(self): + text = "Link: [Google](https://google.com)" + remover = MarkdownRemover() + result = remover.modify_document(text) + expected = "Link: https://google.com" + assert result == expected + + def test_multiple_markdown(self): + text = "This is **bold**, *italic*, and _underline_, check [Example](https://example.com)" + remover = MarkdownRemover() + result = remover.modify_document(text) + expected = "This is bold, italic, and underline, check https://example.com" + assert result == expected + + def test_no_markdown(self): + text = "This line has no markdown." + remover = MarkdownRemover() + result = remover.modify_document(text) + assert result == text + + def test_incomplete_markdown(self): + text = "This is *italic text" + remover = MarkdownRemover() + result = remover.modify_document(text) + # Without a closing '*', the text remains unchanged. + assert result == text + + def test_nested_markdown(self): + text = "This is **bold and *italic* inside** text." + remover = MarkdownRemover() + result = remover.modify_document(text) + # Bold formatting is removed first, then italics in the resulting string. + expected = "This is bold and italic inside text." 
+ assert result == expected + + def test_multiple_lines(self): + text = "**Bold line**\n*Italic line*\n_Normal line_" + remover = MarkdownRemover() + result = remover.modify_document(text) + expected = "Bold line\nItalic line\nNormal line" + assert result == expected + + def test_adjacent_markdown(self): + text = "**Bold****MoreBold**" + remover = MarkdownRemover() + result = remover.modify_document(text) + expected = "BoldMoreBold" + assert result == expected + + def test_dataset_modification(self): + import pandas as pd + from dask.dataframe.utils import assert_eq + + docs = [ + "This is **bold**", + "This is *italic*", + "Check [Link](https://example.com)", + "No markdown here", + ] + expected_results = [ + "This is bold", + "This is italic", + "Check https://example.com", + "No markdown here", + ] + dataset = list_to_dataset(docs) + modifier = Modify(MarkdownRemover()) + fixed_dataset = modifier(dataset) + expected_dataset = list_to_dataset(expected_results) + assert_eq(fixed_dataset.df, expected_dataset.df) diff --git a/tests/test_filters.py b/tests/test_filters.py index 2f8bd00d1..aa15a1f00 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -19,6 +19,7 @@ import pandas as pd import pytest from dask import dataframe as dd +from dask.dataframe.utils import assert_eq from nemo_curator.datasets import DocumentDataset from nemo_curator.datasets.parallel_dataset import ParallelDataset @@ -49,7 +50,9 @@ RepeatedParagraphsFilter, RepeatingDuplicateNGramsFilter, RepeatingTopNGramsFilter, + SubstringFilter, SymbolsToWordsFilter, + TokenCountFilter, TokenizerFertilityFilter, UrlsFilter, WhiteSpaceFilter, @@ -110,6 +113,13 @@ def keep_document(self, scores): return min_threshold & max_threshold +# A simple dummy tokenizer for our tests. +class DummyTokenizer: + def encode(self, text): + # Simply splits the text on whitespace. + return text.split() + + def all_equal(left_dataset, right_dataset): return all(left_dataset.df.compute() == right_dataset.df.compute()) @@ -767,6 +777,176 @@ def test_histogram(self): ), f"Expected {expected_data2} but got {filtered_data2}" +class TestTokenCountFilter: + def test_score_document(self): + tokenizer = DummyTokenizer() + token_filter = TokenCountFilter(tokenizer, min_tokens=2, max_tokens=3) + text = "another test case" # Should yield 3 tokens. + score = token_filter.score_document(text) + assert score == 3 + + def test_keep_document(self): + tokenizer = DummyTokenizer() + token_filter = TokenCountFilter(tokenizer, min_tokens=2, max_tokens=3) + # Check that a score of 1 (too few) and 4 (too many) are rejected, + # while scores of 2 and 3 are accepted. + assert token_filter.keep_document(2) + assert token_filter.keep_document(3) + assert not token_filter.keep_document(1) + assert not token_filter.keep_document(4) + + def test_filter_dataset(self): + # Create a dataset of documents with different word counts. + docs = [ + "hello", # 1 token + "hello world", # 2 tokens + "this is a test", # 4 tokens + "another test case", # 3 tokens + ] + dataset = list_to_dataset(docs, col_name="text") + + tokenizer = DummyTokenizer() + token_filter = TokenCountFilter(tokenizer, min_tokens=2, max_tokens=3) + filter_step = ScoreFilter(token_filter, text_field="text") + filtered_dataset = filter_step(dataset) + # Reset indices for filtered dataset to ensure identical labeling for comparison. + filtered_dataset.df = filtered_dataset.df.reset_index(drop=True) + + # We expect to keep only the documents with exactly 2 or 3 tokens. 
+ expected_docs = [ + "hello world", # 2 tokens + "another test case", # 3 tokens + ] + expected_dataset = list_to_dataset(expected_docs, col_name="text") + # Reset indices for expected dataset to ensure identical labeling. + expected_dataset.df = expected_dataset.df.reset_index(drop=True) + assert all_equal(expected_dataset, filtered_dataset) + + def test_filter_dataset_default(self): + # Create a dataset of documents with different word counts. + docs = [ + "hello", # 1 token + "hello world", # 2 tokens + "this is a test", # 4 tokens + "another test case", # 3 tokens + ] + dataset = list_to_dataset(docs, col_name="text") + + tokenizer = DummyTokenizer() + # Using default settings: min_tokens=0 and max_tokens=inf, so all documents pass. + token_filter = TokenCountFilter(tokenizer) + filter_step = ScoreFilter(token_filter, text_field="text") + filtered_dataset = filter_step(dataset) + + # We expect to keep all documents. + expected_dataset = list_to_dataset(docs, col_name="text") + assert all_equal(expected_dataset, filtered_dataset) + + +class TestSubstringFilter: + def test_invalid_position(self): + # Creating a SubstringFilter with an invalid position should raise a ValueError. + with pytest.raises(ValueError): + SubstringFilter("foo", "middle") + + def test_prefix_mode(self): + filter_prefix = SubstringFilter("Hello", "prefix") + # Positive example: text starts with "Hello". + text = "Hello world" + score = filter_prefix.score_document(text) + assert score == 1 + assert filter_prefix.keep_document(score) + # Negative example: text does not start with "Hello". + text2 = "world Hello" + score2 = filter_prefix.score_document(text2) + assert score2 == 0 + assert not filter_prefix.keep_document(score2) + + def test_suffix_mode(self): + filter_suffix = SubstringFilter("end", "suffix") + # Positive example: text ends with "end". + text = "This is the end" + score = filter_suffix.score_document(text) + assert score == 1 + assert filter_suffix.keep_document(score) + # Negative example: text does not end with "end". + text2 = "The end is near" + score2 = filter_suffix.score_document(text2) + assert score2 == 0 + assert not filter_suffix.keep_document(score2) + + def test_any_mode(self): + filter_any = SubstringFilter("test", "any") + # Positive example: text contains "test". + text = "this is a test string" + score = filter_any.score_document(text) + assert score == 1 + assert filter_any.keep_document(score) + # Negative example: text does not contain "test". + text2 = "this is a string" + score2 = filter_any.score_document(text2) + assert score2 == 0 + assert not filter_any.keep_document(score2) + + def test_filter_dataset_prefix(self): + docs = ["Hello world", "world Hello", "Hello everyone", "Not matching"] + dataset = list_to_dataset(docs, col_name="text") + filter_prefix = SubstringFilter("Hello", "prefix") + filter_step = ScoreFilter(filter_prefix, text_field="text") + filtered_dataset = filter_step(dataset) + + # Expect only those records where the text starts with "Hello". 
+ expected_docs = ["Hello world", "Hello everyone"] + expected_dataset = list_to_dataset(expected_docs, col_name="text") + + # Reset indices to ensure both DataFrames are identically labeled + filtered_dataset = DocumentDataset(filtered_dataset.df.reset_index(drop=True)) + expected_dataset = DocumentDataset(expected_dataset.df.reset_index(drop=True)) + assert all_equal(expected_dataset, filtered_dataset) + + def test_filter_dataset_suffix(self): + docs = [ + "This is the end", # ends with "end" + "end of story", # does not end with "end" + "ending is good", # does not end with "end" + "Not matching end", # ends with "end" + "The end", # ends with "end" + ] + dataset = list_to_dataset(docs, col_name="text") + filter_suffix = SubstringFilter("end", "suffix") + filter_step = ScoreFilter(filter_suffix, text_field="text") + filtered_dataset = filter_step(dataset) + + # Expect only those records that end with "end". + expected_docs = [ + "Not matching end", + "The end", + "This is the end", + ] + expected_dataset = list_to_dataset(expected_docs, col_name="text") + + # Compare only the 'text' column values to avoid index label issues. + filtered_dataset = DocumentDataset(filtered_dataset.df.reset_index(drop=True)) + expected_dataset = DocumentDataset(expected_dataset.df.reset_index(drop=True)) + assert_eq(expected_dataset.df["text"], filtered_dataset.df["text"]) + + def test_filter_dataset_any(self): + docs = ["test case", "This is a testcase", "no match here", "another test"] + dataset = list_to_dataset(docs, col_name="text") + filter_any = SubstringFilter("test", "any") + filter_step = ScoreFilter(filter_any, text_field="text") + filtered_dataset = filter_step(dataset) + + # Expect documents that contain "test" anywhere. + expected_docs = ["test case", "This is a testcase", "another test"] + expected_dataset = list_to_dataset(expected_docs, col_name="text") + + # Reset indices to ensure both DataFrames are identically labeled + filtered_dataset = DocumentDataset(filtered_dataset.df.reset_index(drop=True)) + expected_dataset = DocumentDataset(expected_dataset.df.reset_index(drop=True)) + assert all_equal(expected_dataset, filtered_dataset) + + class TestCodeFilters: def test_python_comment_to_code(self): doc_1 = "# Good code\nprint('hello world')" diff --git a/tests/test_nemotron_cc.py b/tests/test_nemotron_cc.py new file mode 100644 index 000000000..a267779e6 --- /dev/null +++ b/tests/test_nemotron_cc.py @@ -0,0 +1,287 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import random + +import dask.dataframe as dd +import pandas as pd +import pytest + +from nemo_curator.datasets import DocumentDataset +from nemo_curator.synthetic.nemotron_cc import ( + NemotronCCDiverseQAPostprocessor, + NemotronCCKnowledgeListPostprocessor, +) + + +# A dummy tokenizer that simply splits text by whitespace. +class DummyTokenizer: + def tokenize(self, text): + return text.split() + + +# Helper function to create a DocumentDataset from provided data. 
+def create_dataset(data): + pdf = pd.DataFrame(data) + return DocumentDataset.from_pandas(pdf) + + +class TestDiverseQAPostprocessor: + def test_valid_response_without_tokenizer(self, monkeypatch): + # Patch randomness so that the ordering and sampling is deterministic. + monkeypatch.setattr(random, "shuffle", lambda x: None) + # In the branch without a tokenizer, random.randint(1, max_num_pairs) + # will be forced to return the upper bound. + monkeypatch.setattr(random, "randint", lambda lo, hi: hi) + + text = "Document text" + llm_response = ( + "Here are the questions and answers based on the provided text:\n" + "- Question: What is this?\n" + "Answer: It is a test.\n" + "- Question: How does it work?\n" + "Answer: By magic." + ) + # Create a dataset with one row containing both the document and the LLM response. + ds = create_dataset({"text": [text], "response": [llm_response]}) + + # Use no tokenizer so that the branch using max_num_pairs (here, 2) is used. + processor = NemotronCCDiverseQAPostprocessor(tokenizer=None, max_num_pairs=2) + result_ds = processor(ds) + result_df = result_ds.df.compute() + + # Expected processing: + # 1. Split into lines and remove the leading "- " prefix. + # 2. Remove the prefix line ("Here are...") if it matches. + # 3. Merge lines: the first QA pair becomes: + # "Question: What is this?\nAnswer: It is a test." + # and the second: + # "Question: How does it work?\nAnswer: By magic." + # 4. With our patched randint, both QA pairs are kept. + expected_qa = ( + "Question: What is this?\nAnswer: It is a test.\n\n" + "Question: How does it work?\nAnswer: By magic." + ) + expected_response = f"{text}\n\n{expected_qa}" + + assert not result_df.empty, "Expected non-empty dataset" + actual_response = result_df.iloc[0]["response"] + assert ( + actual_response == expected_response + ), f"Expected: {expected_response}, got: {actual_response}" + + def test_valid_response_with_tokenizer(self, monkeypatch): + # Using a dummy tokenizer. + dummy_tokenizer = DummyTokenizer() + monkeypatch.setattr(random, "shuffle", lambda x: None) + # For the branch with a tokenizer, the number of tokens is determined by: + # num_tokens = len(dummy_tokenizer.tokenize(text)). For "Document text" this yields 2. + # Then max_num = max(1, int(max_num_pairs * num_tokens / 150)) becomes max(1, int(4/150)) -> 1. + monkeypatch.setattr(random, "randint", lambda lo, hi: hi) + + text = "Document text" + llm_response = ( + "Here are the questions and answers based on the provided text:\n" + "- Question: What is this?\n" + "Answer: It is a test.\n" + "- Question: How does it work?\n" + "Answer: By magic." + ) + ds = create_dataset({"text": [text], "response": [llm_response]}) + processor = NemotronCCDiverseQAPostprocessor( + tokenizer=dummy_tokenizer, max_num_pairs=2 + ) + result_ds = processor(ds) + result_df = result_ds.df.compute() + + # In the tokenizer branch only one QA pair is selected (the first one). + expected_qa = "Question: What is this?\nAnswer: It is a test." + expected_response = f"{text}\n\n{expected_qa}" + + assert not result_df.empty, "Expected non-empty dataset" + actual_response = result_df.iloc[0]["response"] + assert ( + actual_response == expected_response + ), f"Expected: {expected_response}, got: {actual_response}" + + def test_invalid_response_format(self, monkeypatch): + # Test a response with an invalid QA format (missing a "Question:" line). 
+ monkeypatch.setattr(random, "shuffle", lambda x: None) + monkeypatch.setattr(random, "randint", lambda lo, hi: hi) + + text = "Doc" + # The response only has an answer line. + llm_response = ( + "Here are the questions and answers based on the provided text:\n" + "- Answer: Missing question." + ) + ds = create_dataset({"text": [text], "response": [llm_response]}) + processor = NemotronCCDiverseQAPostprocessor(tokenizer=None, max_num_pairs=2) + result_ds = processor(ds) + result_df = result_ds.df.compute() + + # Since the response format is invalid (no "Question:" to start a QA pair), + # the postprocessing should return an empty string; the __call__ method then + # drops that row. + assert ( + result_df.empty + ), "Expected dataset to be empty due to invalid response format" + + def test_empty_response(self): + # Test when the LLM response is empty. + text = "Doc" + llm_response = "" + ds = create_dataset({"text": [text], "response": [llm_response]}) + processor = NemotronCCDiverseQAPostprocessor(tokenizer=None, max_num_pairs=2) + result_ds = processor(ds) + result_df = result_ds.df.compute() + + # The empty LLM response should lead to an empty processed text and get filtered out. + assert result_df.empty, "Expected dataset to be empty for an empty LLM response" + + def test_more_qa_than_max(self, monkeypatch): + # Test when there are more QA pairs than max_num_pairs. + monkeypatch.setattr(random, "shuffle", lambda x: None) + monkeypatch.setattr(random, "randint", lambda lo, hi: hi) + + text = "Document text" + llm_response = ( + "Here are the questions and answers based on the provided text:\n" + "- Question: Q1?\n" + "Answer: A1.\n" + "- Question: Q2?\n" + "Answer: A2.\n" + "- Question: Q3?\n" + "Answer: A3.\n" + "- Question: Q4?\n" + "Answer: A4." + ) + ds = create_dataset({"text": [text], "response": [llm_response]}) + processor = NemotronCCDiverseQAPostprocessor(tokenizer=None, max_num_pairs=2) + result_ds = processor(ds) + result_df = result_ds.df.compute() + + # With max_num_pairs set to 2 and patched randint returning the upper bound, + # only the first two QA pairs should be selected. + expected_qa = "Question: Q1?\nAnswer: A1.\n\n" "Question: Q2?\nAnswer: A2." + expected_response = f"{text}\n\n{expected_qa}" + + assert not result_df.empty, "Expected non-empty dataset" + actual_response = result_df.iloc[0]["response"] + assert ( + actual_response == expected_response + ), f"Expected: {expected_response}, got: {actual_response}" + + +class TestKnowledgeListPostprocessor: + def test_basic_formatting(self): + # Test that a response with an initial non-bullet line (to skip) and bullet lines + # is correctly cleaned. + input_response = ( + "Not a bullet line to skip\n" + "- Fact one: This is the first fact.\n" + " Continued fact one.\n" + "- Fact two: This is the second fact." + ) + ds = create_dataset({"text": [input_response]}) + processor = NemotronCCKnowledgeListPostprocessor(text_field="text") + result_ds = processor(ds) + result_df = result_ds.df.compute() + + # Expected: + # - First line is skipped (since it does not start with "-"). + # - Bullet lines have the leading "- " or " " removed. + expected_output = ( + "Fact one: This is the first fact.\n" + "Continued fact one.\n" + "Fact two: This is the second fact." + ) + actual_output = result_df.iloc[0]["text"] + assert ( + actual_output == expected_output + ), f"Expected: {expected_output}, got: {actual_output}" + + def test_all_bullet_lines(self): + # Test when every line starts with a bullet prefix. 
+ input_response = "- Item one\n" "- Item two\n" "- Item three" + ds = create_dataset({"text": [input_response]}) + processor = NemotronCCKnowledgeListPostprocessor(text_field="text") + result_ds = processor(ds) + result_df = result_ds.df.compute() + + # Each line should be cleaned by removing the leading bullet. + expected_output = "Item one\nItem two\nItem three" + actual_output = result_df.iloc[0]["text"] + assert ( + actual_output == expected_output + ), f"Expected: {expected_output}, got: {actual_output}" + + def test_no_bullet_lines(self): + # If the response contains no bullet lines, then the first line is + # skipped and no text remains. + input_response = "This is just plain text without any bullet." + ds = create_dataset({"text": [input_response]}) + processor = NemotronCCKnowledgeListPostprocessor(text_field="text") + result_ds = processor(ds) + result_df = result_ds.df.compute() + + expected_output = "" + actual_output = result_df.iloc[0]["text"] + assert ( + actual_output == expected_output + ), f"Expected an empty string, got: {actual_output}" + + def test_mixed_indentation(self): + # Test mixed bullet prefixes and additional non-bullet lines. + input_response = ( + "- Bullet one\n" + "Some extra text\n" + " Indented line\n" + "- Bullet two\n" + " Continuation of bullet two\n" + "Another standalone line" + ) + ds = create_dataset({"text": [input_response]}) + processor = NemotronCCKnowledgeListPostprocessor(text_field="text") + result_ds = processor(ds) + result_df = result_ds.df.compute() + + # Note: Only the very first line is conditionally skipped if it doesn't start with '-'. + # Here, since the first line starts with "-", nothing is skipped. + # Each line that starts with "- " or " " should have those two characters removed. + expected_output = ( + "Bullet one\n" + "Some extra text\n" + "Indented line\n" + "Bullet two\n" + "Continuation of bullet two\n" + "Another standalone line" + ) + actual_output = result_df.iloc[0]["text"] + assert ( + actual_output == expected_output + ), f"Expected: {expected_output}, got: {actual_output}" + + def test_empty_input(self): + # Test that an empty input returns an empty string. + input_response = "" + ds = create_dataset({"text": [input_response]}) + processor = NemotronCCKnowledgeListPostprocessor(text_field="text") + result_ds = processor(ds) + result_df = result_ds.df.compute() + + expected_output = "" + actual_output = result_df.iloc[0]["text"] + assert ( + actual_output == expected_output + ), f"Expected empty string, got: {actual_output}" diff --git a/tests/test_splitter.py b/tests/test_splitter.py new file mode 100644 index 000000000..7eadc054c --- /dev/null +++ b/tests/test_splitter.py @@ -0,0 +1,280 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import dask.dataframe as dd +import pandas as pd +from dask.dataframe.utils import assert_eq + +from nemo_curator.datasets import DocumentDataset +from nemo_curator.modules.splitter import DocumentJoiner, DocumentSplitter + + +class TestDocumentSplitter: + def test_basic_split_default(self): + # Use default text_field "text" and segment_id_field "segment_id" + # Four examples: + # "a|b|c" → splits to ["a", "b", "c"] + # "nosplit" → ["nosplit"] + # "start|middle" → ["start", "middle"] + # "end|" → ["end", ""] + docs = ["a|b|c", "nosplit", "start|middle", "end|"] + pdf = pd.DataFrame({"text": docs}) + dataset = DocumentDataset.from_pandas(pdf, npartitions=1) + splitter = DocumentSplitter(separator="|") + result_dataset = splitter(dataset) + + expected_df = pd.DataFrame( + { + "text": ["a", "b", "c", "nosplit", "start", "middle", "end", ""], + "segment_id": [0, 1, 2, 0, 0, 1, 0, 1], + } + ) + # Compare without considering the index order. + assert_eq( + result_dataset.df.compute().reset_index(drop=True), + expected_df, + check_index=False, + ) + + def test_split_custom_fields(self): + # Use a custom text field name ("content") and segment id field ("seg_id") + # with a different separator. + # Examples: + # "x;y" → ["x", "y"] + # "single" → ["single"] + # "first;second;third" → ["first", "second", "third"] + # ";leading" → ["", "leading"] + docs = ["x;y", "single", "first;second;third", ";leading"] + pdf = pd.DataFrame({"content": docs}) + dataset = DocumentDataset.from_pandas(pdf, npartitions=1) + splitter = DocumentSplitter( + separator=";", text_field="content", segment_id_field="seg_id" + ) + result_dataset = splitter(dataset) + + expected_df = pd.DataFrame( + { + "content": [ + "x", + "y", + "single", + "first", + "second", + "third", + "", + "leading", + ], + "seg_id": [0, 1, 0, 0, 1, 2, 0, 1], + } + ) + assert_eq( + result_dataset.df.compute().reset_index(drop=True), + expected_df, + check_index=False, + ) + + +class TestDocumentJoiner: + def test_join_default(self): + # Input represents documents already split. + # For example, a document with id=1 split as "a", "b", "c" becomes joined to "a|b|c". + # Four documents are used. + data = { + "id": [1, 1, 1, 2, 3, 3, 4, 4], + "text": ["a", "b", "c", "nosplit", "start", "middle", "end", ""], + "segment_id": [0, 1, 2, 0, 0, 1, 0, 1], + } + pdf = pd.DataFrame(data) + dataset = DocumentDataset.from_pandas(pdf, npartitions=1) + joiner = DocumentJoiner( + separator="|", + text_field="text", + segment_id_field="segment_id", + document_id_field="id", + drop_segment_id_field=True, + ) + result_dataset = joiner(dataset) + + expected_df = pd.DataFrame( + {"id": [1, 2, 3, 4], "text": ["a|b|c", "nosplit", "start|middle", "end|"]} + ) + assert_eq( + result_dataset.df.compute().reset_index(drop=True), + expected_df, + check_index=False, + ) + + def test_join_custom_fields(self): + # Use custom field names: + # document id field: "doc" + # text field: "content" + # segment id field: "s_id" + # Also keep the segment id field (drop_segment_id_field=False) + data = { + "doc": [101, 101, 102, 103, 103, 104, 104], + "content": ["first", "second", "only", "hello", "world", "baz", ""], + "s_id": [0, 1, 0, 0, 1, 0, 1], + } + pdf = pd.DataFrame(data) + dataset = DocumentDataset.from_pandas(pdf, npartitions=1) + joiner = DocumentJoiner( + separator="~", + text_field="content", + segment_id_field="s_id", + document_id_field="doc", + drop_segment_id_field=False, + ) + result_dataset = joiner(dataset) + + # Expected: each document is joined by "~". 
The segment id becomes the first segment's id. + expected_df = pd.DataFrame( + { + "doc": [101, 102, 103, 104], + "content": ["first~second", "only", "hello~world", "baz~"], + "s_id": [0, 0, 0, 0], + } + ) + assert_eq( + result_dataset.df.compute().reset_index(drop=True), + expected_df, + check_index=False, + ) + + def test_join_max_length(self): + # Here we test joining when a maximum length is specified. + # Each segment carries a precomputed "length" value. + # The joiner should accumulate segments until adding the next one (plus separator) + # would exceed max_length=5. + # + # For document 1: + # segments: "ab"(2), "cd"(2), "ef"(2), "gh"(2) + # - "ab" then "cd": 2+2+1 = 5 → join as "ab-cd" (length 5) + # - then "ef" then "gh": 2+2+1 = 5 → join as "ef-gh" (length 5) + # + # For document 2: + # segments: "a"(1), "b"(1) → join as "a-b" (length 3) + # + # For document 3: + # segment: "hello"(5) → remains "hello" + # + # For document 4: + # segments: "x"(1), "yz"(2), "0"(1) + # - "x" then "yz": 1+2+1 = 4 → "x-yz" (length 4) + # - "0" remains alone. + data = { + "id": [1, 1, 1, 1, 2, 2, 3, 4, 4, 4], + "text": ["ab", "cd", "ef", "gh", "a", "b", "hello", "x", "yz", "0"], + "segment_id": [0, 1, 2, 3, 0, 1, 0, 0, 1, 2], + "length": [2, 2, 2, 2, 1, 1, 5, 1, 2, 1], + } + pdf = pd.DataFrame(data) + dataset = DocumentDataset.from_pandas(pdf, npartitions=1) + joiner = DocumentJoiner( + separator="-", + text_field="text", + segment_id_field="segment_id", + document_id_field="id", + drop_segment_id_field=True, + max_length=5, + length_field="length", + ) + result_dataset = joiner(dataset) + + expected_df = pd.DataFrame( + [ + {"id": 1, "text": "ab-cd", "length": 5}, + {"id": 1, "text": "ef-gh", "length": 5}, + {"id": 2, "text": "a-b", "length": 3}, + {"id": 3, "text": "hello", "length": 5}, + {"id": 4, "text": "x-yz", "length": 4}, + {"id": 4, "text": "0", "length": 1}, + ] + ) + # Sort by id and text to ensure consistent order + expected_sorted = expected_df.sort_values(by=["id", "text"]).reset_index( + drop=True + ) + result_sorted = ( + result_dataset.df.compute() + .sort_values(by=["id", "text"]) + .reset_index(drop=True) + ) + assert_eq(result_sorted, expected_sorted, check_index=False) + + def test_join_with_string_ids(self): + # Test join functionality when document id field is a string. + data = { + "doc": ["doc1", "doc1", "doc2", "doc3", "doc3", "doc4", "doc4"], + "text": ["a", "b", "nosplit", "start", "middle", "end", ""], + "segment_id": [0, 1, 0, 0, 1, 0, 1], + } + pdf = pd.DataFrame(data) + dataset = DocumentDataset.from_pandas(pdf, npartitions=1) + joiner = DocumentJoiner( + separator="|", + text_field="text", + segment_id_field="segment_id", + document_id_field="doc", + drop_segment_id_field=True, + ) + result_dataset = joiner(dataset) + + expected_df = pd.DataFrame( + { + "doc": ["doc1", "doc2", "doc3", "doc4"], + "text": ["a|b", "nosplit", "start|middle", "end|"], + } + ) + assert_eq( + result_dataset.df.compute().reset_index(drop=True), + expected_df, + check_index=False, + ) + + +class TestSplitJoinReconstruction: + def test_reconstruction_default(self): + # Create an original dataset with a unique "id" column and text examples. 
+ # Four examples include edge cases: + # "a|b|c" → multiple splits + # "nosplit" → no separator present + # "a||b|" → consecutive separators yield empty strings + # "" → empty document + docs = ["a|b|c", "nosplit", "a||b|", ""] + pdf = pd.DataFrame({"id": [1, 2, 3, 4], "text": docs}) + original_dataset = DocumentDataset.from_pandas(pdf, npartitions=1) + + # First, split using "|" as separator. + splitter = DocumentSplitter(separator="|") + split_dataset = splitter(original_dataset) + + # Then, rejoin using the same separator. + joiner = DocumentJoiner( + separator="|", + text_field="text", + segment_id_field="segment_id", + document_id_field="id", + drop_segment_id_field=True, + ) + reconstructed_dataset = joiner(split_dataset) + + # The reconstructed "text" column should match the original. + original_sorted = ( + original_dataset.df.compute().sort_values(by="id").reset_index(drop=True) + ) + reconstructed_sorted = ( + reconstructed_dataset.df.compute() + .sort_values(by="id") + .reset_index(drop=True) + ) + assert_eq(reconstructed_sorted, original_sorted, check_index=False) From 7022a39b66d8634d853bf53f3c638d322aa1fe51 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Fri, 7 Feb 2025 11:39:29 -0800 Subject: [PATCH 16/22] Add async nemotron cc and rename classes Signed-off-by: Ryan Wolf --- nemo_curator/synthetic/__init__.py | 8 +- nemo_curator/synthetic/async_nemotron_cc.py | 196 ++++++++++++++++++++ nemo_curator/synthetic/nemotron_cc.py | 6 +- 3 files changed, 204 insertions(+), 6 deletions(-) create mode 100644 nemo_curator/synthetic/async_nemotron_cc.py diff --git a/nemo_curator/synthetic/__init__.py b/nemo_curator/synthetic/__init__.py index dbcaffdb1..1efb30430 100644 --- a/nemo_curator/synthetic/__init__.py +++ b/nemo_curator/synthetic/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. from .async_nemotron import AsyncNemotronGenerator +from .async_nemotron_cc import AsyncNemotronCCGenerator from .error import YamlConversionError from .mixtral import Mixtral8x7BFormatter from .nemotron import NemotronFormatter, NemotronGenerator from .nemotron_cc import ( - NemotronCC, NemotronCCDiverseQAPostprocessor, + NemotronCCGenerator, NemotronCCKnowledgeListPostprocessor, ) from .no_format import NoFormat @@ -50,7 +51,8 @@ "NemotronGenerator", "AsyncNemotronGenerator", "NemotronFormatter", - "NemotronCC", + "NemotronCCGenerator", + "AsyncNemotronCCGenerator", "NemotronCCDiverseQAPostprocessor", "NemotronCCKnowledgeListPostprocessor", "Mixtral8x7BFormatter", diff --git a/nemo_curator/synthetic/async_nemotron_cc.py b/nemo_curator/synthetic/async_nemotron_cc.py new file mode 100644 index 000000000..8c16e11a6 --- /dev/null +++ b/nemo_curator/synthetic/async_nemotron_cc.py @@ -0,0 +1,196 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +from nemo_curator.services import AsyncLLMClient +from nemo_curator.synthetic.prompts import ( + DISTILL_PROMPT_TEMPLATE, + DIVERSE_QA_PROMPT_TEMPLATE, + EXTRACT_KNOWLEDGE_PROMPT_TEMPLATE, + KNOWLEDGE_LIST_PROMPT_TEMPLATE, + NEMOTRON_CC_DISTILL_SYSTEM_PROMPT, + NEMOTRON_CC_SYSTEM_PROMPT, + WIKIPEDIA_REPHRASING_PROMPT_TEMPLATE, +) + + +class AsyncNemotronCCGenerator: + """ + Provides a collection of methods for generating synthetic data + described in the Nemotron-CC paper (https://arxiv.org/abs/2412.02595). + """ + + def __init__(self, llm_client: AsyncLLMClient) -> None: + """ + Initialize the AsyncNemotronCCGenerator instance. + + Args: + llm_client (LLMClient): The language model client used for querying the model. + """ + self.client = llm_client + + async def _prompt( + self, + model: str, + document: str, + prompt_template: str, + system_prompt: str, + prompt_kwargs: dict, + model_kwargs: dict, + ) -> List[str]: + prompt = prompt_template.format(document=document, **prompt_kwargs) + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt}, + ] + + return await self.client.query_model( + messages=messages, model=model, **model_kwargs + ) + + async def rewrite_to_wikipedia_style( + self, + document: str, + model: str, + prompt_template: str = WIKIPEDIA_REPHRASING_PROMPT_TEMPLATE, + system_prompt: str = NEMOTRON_CC_SYSTEM_PROMPT, + prompt_kwargs: dict = {}, + model_kwargs: dict = {}, + ) -> List[str]: + """ + Rewrites a document into a Wikipedia-style narrative. + + Args: + document (str): The input document text to rewrite. + model (str): The model identifier to use. + prompt_template (str, optional): The prompt template for rewriting. Defaults to WIKIPEDIA_REPHRASING_PROMPT_TEMPLATE. + system_prompt (str, optional): The system prompt to use. Defaults to NEMOTRON_CC_SYSTEM_PROMPT. + prompt_kwargs (dict, optional): Additional keyword arguments for the prompt. Defaults to {}. + model_kwargs (dict, optional): Additional keyword arguments for the model invocation. Defaults to {}. + + Returns: + List[str]: A list of responses from the LLM. The list is only greater than length 1 if n > 1 is set in model_kwargs. + """ + return await self._prompt( + model, document, prompt_template, system_prompt, prompt_kwargs, model_kwargs + ) + + async def generate_diverse_qa( + self, + document: str, + model: str, + prompt_template: str = DIVERSE_QA_PROMPT_TEMPLATE, + system_prompt: str = NEMOTRON_CC_SYSTEM_PROMPT, + prompt_kwargs: dict = {}, + model_kwargs: dict = {}, + ) -> List[str]: + """ + Generates diverse QA pairs from the provided document. + + Args: + document (str): The input document text used to generate QA pairs. + model (str): The model identifier to use. + prompt_template (str, optional): The prompt template for generating QA pairs. Defaults to DIVERSE_QA_PROMPT_TEMPLATE. + system_prompt (str, optional): The system prompt to use. Defaults to NEMOTRON_CC_SYSTEM_PROMPT. + prompt_kwargs (dict, optional): Additional keyword arguments for the prompt. Defaults to {}. 
+ model_kwargs (dict, optional): Additional keyword arguments for the model invocation. Defaults to {}. + + Returns: + List[str]: A list of responses from the LLM. The list is only greater than length 1 if n > 1 is set in model_kwargs. + """ + return await self._prompt( + model, document, prompt_template, system_prompt, prompt_kwargs, model_kwargs + ) + + async def distill( + self, + document: str, + model: str, + prompt_template: str = DISTILL_PROMPT_TEMPLATE, + system_prompt: str = NEMOTRON_CC_DISTILL_SYSTEM_PROMPT, + prompt_kwargs: dict = {}, + model_kwargs: dict = {}, + ) -> List[str]: + """ + Distills the essential content from a document. + + Args: + document (str): The input document text to distill. + model (str): The model identifier to use. + prompt_template (str, optional): The prompt template for distillation. Defaults to DISTILL_PROMPT_TEMPLATE. + system_prompt (str, optional): The system prompt to use. Defaults to NEMOTRON_CC_DISTILL_SYSTEM_PROMPT. + prompt_kwargs (dict, optional): Additional keyword arguments for the prompt. Defaults to {}. + model_kwargs (dict, optional): Additional keyword arguments for the model invocation. Defaults to {}. + + Returns: + List[str]: A list of responses from the LLM. The list is only greater than length 1 if n > 1 is set in model_kwargs. + """ + return await self._prompt( + model, document, prompt_template, system_prompt, prompt_kwargs, model_kwargs + ) + + async def extract_knowledge( + self, + document: str, + model: str, + prompt_template: str = EXTRACT_KNOWLEDGE_PROMPT_TEMPLATE, + system_prompt: str = NEMOTRON_CC_SYSTEM_PROMPT, + prompt_kwargs: dict = {}, + model_kwargs: dict = {}, + ) -> List[str]: + """ + Extracts knowledge from the provided document. + + Args: + document (str): The input document text from which to extract knowledge. + model (str): The model identifier to use. + prompt_template (str, optional): The prompt template for knowledge extraction. Defaults to EXTRACT_KNOWLEDGE_PROMPT_TEMPLATE. + system_prompt (str, optional): The system prompt to use. Defaults to NEMOTRON_CC_SYSTEM_PROMPT. + prompt_kwargs (dict, optional): Additional keyword arguments for the prompt. Defaults to {}. + model_kwargs (dict, optional): Additional keyword arguments for the model invocation. Defaults to {}. + + Returns: + List[str]: A list of responses from the LLM. The list is only greater than length 1 if n > 1 is set in model_kwargs. + """ + return await self._prompt( + model, document, prompt_template, system_prompt, prompt_kwargs, model_kwargs + ) + + async def generate_knowledge_list( + self, + document: str, + model: str, + prompt_template: str = KNOWLEDGE_LIST_PROMPT_TEMPLATE, + system_prompt: str = NEMOTRON_CC_SYSTEM_PROMPT, + prompt_kwargs: dict = {}, + model_kwargs: dict = {}, + ) -> List[str]: + """ + Generates a list of knowledge items from the provided document. + + Args: + document (str): The input document text to process. + model (str): The model identifier to use. + prompt_template (str, optional): The prompt template for generating a knowledge list. Defaults to KNOWLEDGE_LIST_PROMPT_TEMPLATE. + system_prompt (str, optional): The system prompt to use. Defaults to NEMOTRON_CC_SYSTEM_PROMPT. + prompt_kwargs (dict, optional): Additional keyword arguments for the prompt. Defaults to {}. + model_kwargs (dict, optional): Additional keyword arguments for the model invocation. Defaults to {}. + + Returns: + List[str]: A list of responses from the LLM. The list is only greater than length 1 if n > 1 is set in model_kwargs. 
+ """ + return await self._prompt( + model, document, prompt_template, system_prompt, prompt_kwargs, model_kwargs + ) diff --git a/nemo_curator/synthetic/nemotron_cc.py b/nemo_curator/synthetic/nemotron_cc.py index 1009699ec..24629945e 100644 --- a/nemo_curator/synthetic/nemotron_cc.py +++ b/nemo_curator/synthetic/nemotron_cc.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,7 +31,7 @@ ) -class NemotronCC: +class NemotronCCGenerator: """ Provides a collection of methods for generating synthetic data described in the Nemotron-CC paper (https://arxiv.org/abs/2412.02595). @@ -39,7 +39,7 @@ class NemotronCC: def __init__(self, llm_client: LLMClient) -> None: """ - Initialize the NemotronCC instance. + Initialize the NemotronCCGenerator instance. Args: llm_client (LLMClient): The language model client used for querying the model. From c53fe7d67b570c6fe63b0561aa42981f74554cee Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Mon, 10 Feb 2025 17:14:06 -0800 Subject: [PATCH 17/22] Add rst section and API docs Signed-off-by: Ryan Wolf --- docs/user-guide/api/synthetic.rst | 12 ++ docs/user-guide/syntheticdata.rst | 265 ++++++++++++++++++++++++++++++ 2 files changed, 277 insertions(+) diff --git a/docs/user-guide/api/synthetic.rst b/docs/user-guide/api/synthetic.rst index 685656b41..4e13e64b8 100644 --- a/docs/user-guide/api/synthetic.rst +++ b/docs/user-guide/api/synthetic.rst @@ -8,6 +8,18 @@ Synthetic Data .. autoclass:: nemo_curator.synthetic.AsyncNemotronGenerator :members: +.. autoclass:: nemo_curator.synthetic.NemotronCCGenerator + :members: + +.. autoclass:: nemo_curator.synthetic.NemotronCCDiverseQAPostprocessor + :members: + +.. autoclass:: nemo_curator.synthetic.NemotronCCKnowledgeListPostprocessor + :members: + +.. autoclass:: nemo_curator.synthetic.AsyncNemotronGenerator + :members: + .. autoclass:: nemo_curator.synthetic.NemotronFormatter :members: diff --git a/docs/user-guide/syntheticdata.rst b/docs/user-guide/syntheticdata.rst index d082ae5fe..2778bdc4f 100644 --- a/docs/user-guide/syntheticdata.rst +++ b/docs/user-guide/syntheticdata.rst @@ -15,6 +15,7 @@ Furthermore, NeMo Curator can also interface with `NeMo's Export and Deploy `_. +It also now supports the pipelines used in generating `Nemotron-CC `_. Additionally, you can seamlessly integrate filtering and deduplication steps in your synthetic data pipeline with the other modules available in NeMo Curator. Connect to an LLM Service @@ -690,6 +691,270 @@ All of the code so far has been sending requests to the LLM service synchronousl As you can see, the asynchronous modules have the same interface as the synchronous modules. The only exception is that a ``max_concurrent_requests`` parameter can be supplied to the constructor of ``AsyncNemotronGenerator`` as a form of rate limiting if your service is rate limited. +Customize the Nemotron-CC Pipeline +----------------------------------- + +Nemotron-CC used a collection of pipelines focused on rephrasing reference documents into different formats/styles. +NeMo Curator provides a synchronous and asynchronous version of each pipeline with ``nemo_curator.synthetic.NemotronCCGenerator`` and ``nemo_curator.synthetic.AsyncNemotronCCGenerator``. 
+ +Rewrite to Wikipedia Style +########################## + +The ``NemotronCCGenerator.rewrite_to_wikipedia_style`` method rewrites a document into a style that is similar to Wikipedia. + +.. code-block:: python + + from openai import OpenAI + from nemo_curator import OpenAIClient + from nemo_curator.synthetic import NemotronCCGenerator + + openai_client = OpenAI( + base_url="https://integrate.api.nvidia.com/v1", + api_key="" + ) + client = OpenAIClient(openai_client) + generator = NemotronCCGenerator(client) + + document = "The moon is bright. It shines at night." + model = "nv-mistralai/mistral-nemo-12b-instruct" + model_kwargs = { + "temperature": 0.5, + "top_p": 0.9, + "max_tokens": 512, + } + + responses = generator.rewrite_to_wikipedia_style( + document=document, model=model, model_kwargs=model_kwargs + ) + + print(responses[0]) + # Output: + # The lunar surface has a high albedo, which means it reflects a significant amount of sunlight. + + +Generate Diverse QA Pairs +######################### + +The ``NemotronCCGenerator.generate_diverse_qa`` method generates a list of diverse QA pairs from a document. + +.. code-block:: python + + from openai import OpenAI + from nemo_curator import OpenAIClient + from nemo_curator.synthetic import NemotronCCGenerator + + openai_client = OpenAI( + base_url="https://integrate.api.nvidia.com/v1", + api_key="" + ) + client = OpenAIClient(openai_client) + generator = NemotronCCGenerator(client) + + document = "The moon is bright. It shines at night." + model = "nv-mistralai/mistral-nemo-12b-instruct" + model_kwargs = { + "temperature": 0.5, + "top_p": 0.9, + "max_tokens": 600, + } + + responses = generator.generate_diverse_qa( + document=document, model=model, model_kwargs=model_kwargs + ) + + print(responses[0]) + # Output: + # Question: What is the moon made of? + # Answer: The moon is made of rock and dust. + + +To help with cleaning the output, the ``NemotronCCDiverseQAPostprocessor`` class is provided. + +.. code-block:: python + + import pandas as pd + from openai import OpenAI + from nemo_curator import OpenAIClient + from nemo_curator.datasets import DocumentDataset + from nemo_curator.synthetic import NemotronCCGenerator, NemotronCCDiverseQAPostprocessor + + openai_client = OpenAI( + base_url="https://integrate.api.nvidia.com/v1", + api_key="" + ) + client = OpenAIClient(openai_client) + generator = NemotronCCGenerator(client) + + document = "The moon is bright. It shines at night." + model = "nv-mistralai/mistral-nemo-12b-instruct" + model_kwargs = { + "temperature": 0.5, + "top_p": 0.9, + "max_tokens": 600, + } + responses = generator.generate_diverse_qa(document=document, model=model, model_kwargs=model_kwargs) + postprocessor = NemotronCCDiverseQAPostprocessor(text_field="text", response_field="diverse_qa_response") + dataset = DocumentDataset.from_pandas(pd.DataFrame({"text": document, "diverse_qa_response": responses})) + cleaned_dataset = postprocessor(dataset) + + first_entry = cleaned_dataset.df.head(1) + print(first_entry["diverse_qa_response"]) + # Output: + # The moon is bright. It shines at night. Question: What is the moon made of? Answer: The moon is made of rock and dust. + + +Generate Knowledge List +####################### + +The ``NemotronCCGenerator.generate_knowledge_list`` method generates a list of knowledge from a document. + +.. 
code-block:: python + + from openai import OpenAI + from nemo_curator import OpenAIClient + from nemo_curator.synthetic import NemotronCCGenerator + + openai_client = OpenAI( + base_url="https://integrate.api.nvidia.com/v1", + api_key="" + ) + client = OpenAIClient(openai_client) + generator = NemotronCCGenerator(client) + + document = "The moon is bright. It shines at night." + model = "nv-mistralai/mistral-nemo-12b-instruct" + model_kwargs = { + "temperature": 0.5, + "top_p": 0.9, + "max_tokens": 600, + } + + responses = generator.generate_knowledge_list( + document=document, model=model, model_kwargs=model_kwargs + ) + + print(responses[0]) + # Output: + # - The moon is made of rock and dust. + # - The moon is the only natural satellite of the Earth. + # ... + +To help with cleaning the output, the ``NemotronCCKnowledgeListPostprocessor`` class is provided. + +.. code-block:: python + + import pandas as pd + from openai import OpenAI + + from nemo_curator import OpenAIClient + from nemo_curator.datasets import DocumentDataset + from nemo_curator.synthetic import NemotronCCGenerator, NemotronCCKnowledgeListPostprocessor + + openai_client = OpenAI( + base_url="https://integrate.api.nvidia.com/v1", + api_key="" + ) + client = OpenAIClient(openai_client) + generator = NemotronCCGenerator(client) + + document = "The moon is bright. It shines at night." + model = "nv-mistralai/mistral-nemo-12b-instruct" + model_kwargs = { + "temperature": 0.5, + "top_p": 0.9, + "max_tokens": 600, + } + + responses = generator.generate_knowledge_list( + document=document, model=model, model_kwargs=model_kwargs + ) + + print(responses[0]) + # Output: + # - The moon is made of rock and dust. + # - The moon is the only natural satellite of the Earth. + # ... + + postprocessor = NemotronCCKnowledgeListPostprocessor(text_field="knowledge_list_response") + dataset = DocumentDataset.from_pandas(pd.DataFrame({"knowledge_list_response": responses})) + cleaned_dataset = postprocessor(dataset) + + first_entry = cleaned_dataset.df.head(1) + print(first_entry["knowledge_list_response"]) + # Output: + # The moon is made of rock and dust. + # The moon is the only natural satellite of the Earth. + +Distill Document +################# + +The ``NemotronCCGenerator.distill`` method distills a document into a more concise form. + +.. code-block:: python + + from openai import OpenAI + from nemo_curator import OpenAIClient + from nemo_curator.synthetic import NemotronCCGenerator + + openai_client = OpenAI( + base_url="https://integrate.api.nvidia.com/v1", + api_key="" + ) + client = OpenAIClient(openai_client) + generator = NemotronCCGenerator(client) + + document = "The moon is bright. It shines at night." + model = "nv-mistralai/mistral-nemo-12b-instruct" + model_kwargs = { + "temperature": 0.5, + "top_p": 0.9, + "max_tokens": 1600, + } + + responses = generator.distill( + document=document, model=model, model_kwargs=model_kwargs + ) + + print(responses[0]) + # Output: + # The moon is bright at night. + + +Extract Knowledge +################ + +The ``NemotronCCGenerator.extract_knowledge`` method extracts knowledge from a document. + +.. code-block:: python + + from openai import OpenAI + from nemo_curator import OpenAIClient + from nemo_curator.synthetic import NemotronCCGenerator + + openai_client = OpenAI( + base_url="https://integrate.api.nvidia.com/v1", + api_key="" + ) + client = OpenAIClient(openai_client) + generator = NemotronCCGenerator(client) + + document = "The moon is bright. It shines at night." 
+ model = "nv-mistralai/mistral-nemo-12b-instruct" + model_kwargs = { + "temperature": 0.5, + "top_p": 0.9, + "max_tokens": 1400, + } + + responses = generator.extract_knowledge( + document=document, model=model, model_kwargs=model_kwargs + ) + + print(responses[0]) + # Output: + # The moon is a reflective body visible from the Earth at night. + + Combine Synthetic Data Generation with other NeMo Curator Modules ----------------------------------------------------------------- Synthetic data generation, unlike the rest of NeMo Curator, operates independently of Dask. From dcf9d900df2d593847fd9aa6b17ec4e32cb705e9 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Tue, 11 Feb 2025 11:02:30 -0800 Subject: [PATCH 18/22] Address Vibhu and Praateek's reviews Signed-off-by: Ryan Wolf --- nemo_curator/modifiers/quotation_remover.py | 2 +- nemo_curator/modifiers/slicer.py | 5 +- nemo_curator/modules/__init__.py | 3 +- nemo_curator/modules/joiner.py | 168 ++++++++++++++ nemo_curator/modules/splitter.py | 156 +------------ tests/test_joiner.py | 176 +++++++++++++++ tests/test_splitter.py | 232 +++----------------- tests/test_splitter_joiner.py | 56 +++++ 8 files changed, 437 insertions(+), 361 deletions(-) create mode 100644 nemo_curator/modules/joiner.py create mode 100644 tests/test_joiner.py create mode 100644 tests/test_splitter_joiner.py diff --git a/nemo_curator/modifiers/quotation_remover.py b/nemo_curator/modifiers/quotation_remover.py index 3e36dfbcd..02f5bda1e 100644 --- a/nemo_curator/modifiers/quotation_remover.py +++ b/nemo_curator/modifiers/quotation_remover.py @@ -33,6 +33,6 @@ def modify_document(self, text: str) -> str: if len(text.strip()) > 2 and text[0] == '"' and text[-1] == '"': if "\n" not in text.strip(): text = text[1:-1] - elif "\n" in text.strip() and text.split("\n")[0][-1] != '"': + elif text.split("\n")[0][-1] != '"': text = text[1:-1] return text diff --git a/nemo_curator/modifiers/slicer.py b/nemo_curator/modifiers/slicer.py index d88070388..d267b8314 100644 --- a/nemo_curator/modifiers/slicer.py +++ b/nemo_curator/modifiers/slicer.py @@ -23,7 +23,7 @@ class Slicer(DocumentModifier): def __init__( self, - left: Optional[Union[int, str]] = None, + left: Optional[Union[int, str]] = 0, right: Optional[Union[int, str]] = None, include_left: bool = True, include_right: bool = True, @@ -34,7 +34,8 @@ def __init__( left (Union[int, str], optional): If the provided value is an int, slice the string from this index (inclusive). If the provided value is a str, slice the string from the first occurence of this substring. right (Union[int, str], optional): If the provided value is an int, slice the string to this index (exclusive). - If the provided value is a str, slice the string to the last occurence of this substring. + If the provided value is a str, slice the string to the last occurence of this substring. If None, + right is set to the length of the string. include_left (bool): Only used if `left` is a string. If True, the value of `left` is included in the slicing result. Defaults to False. include_right (bool): Only used if `right` is a string. 
If True, the value of `right` is included in the diff --git a/nemo_curator/modules/__init__.py b/nemo_curator/modules/__init__.py index bd86756f2..6273c88e5 100644 --- a/nemo_curator/modules/__init__.py +++ b/nemo_curator/modules/__init__.py @@ -28,7 +28,8 @@ from .exact_dedup import ExactDuplicates from .meta import Sequential from .modify import Modify -from .splitter import DocumentSplitter, DocumentJoiner +from .splitter import DocumentSplitter +from .joiner import DocumentJoiner from .task import TaskDecontamination from .to_backend import ToBackend diff --git a/nemo_curator/modules/joiner.py b/nemo_curator/modules/joiner.py new file mode 100644 index 000000000..8adad7779 --- /dev/null +++ b/nemo_curator/modules/joiner.py @@ -0,0 +1,168 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional + +import pandas as pd + +from nemo_curator.datasets import DocumentDataset +from nemo_curator.modules.base import BaseModule + + +class DocumentJoiner(BaseModule): + """ + Joins documents that have a common id back into a single document. + The order of the documents is dictated by an additional segment_id column. + A maximum length can be specified to limit the size of the joined documents. + + The joined documents are joined by a separator. + """ + + def __init__( + self, + separator: str, + text_field: str = "text", + segment_id_field: str = "segment_id", + document_id_field: str = "id", + drop_segment_id_field: bool = True, + max_length: Optional[int] = None, + length_field: Optional[str] = None, + ): + """ + Args: + separator (str): The separator to join the documents on. + text_field (str): The name of the column containing the text to join. + segment_id_field (str): The name of the column containing the segment id. + document_id_field (str): The name of the column containing the document id. + drop_segment_id_field (bool): Whether to drop the segment_id_field after joining. + max_length (int, optional): The maximum length of the joined documents. + Both max_length and length_field must be specified or neither can be specified. + length_field (str, optional): The name of the column containing the length of the documents. + Both max_length and length_field must be specified or neither can be specified. + """ + if max_length is not None and length_field is None: + raise ValueError("max_length is specified but length_field is not") + if max_length is None and length_field is not None: + raise ValueError("length_field is specified but max_length is not") + + super().__init__(input_backend="pandas") + self.separator = separator + self.text_field = text_field + self.segment_id_field = segment_id_field + self.document_id_field = document_id_field + self.drop_segment_id_field = drop_segment_id_field + self.max_length = max_length + self.length_field = length_field + + def _join_segments(self, group): + # Ensure segments are processed in order. 
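+        # The loop below packs segments greedily: each segment is appended to the
+        # current accumulator unless doing so (plus one separator character) would
+        # push the accumulated length past max_length, in which case the accumulator
+        # is emitted as a joined row and a new accumulation starts.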
+ group = group.sort_values(self.segment_id_field) + joined_rows = [] + current_seg_id = 0 + accumulator_text = None + accumulator_length = 0 + accumulator_row = None + + for _, row in group.iterrows(): + if accumulator_row is None: + # Start a new accumulation with the first segment. + accumulator_text = row[self.text_field] + accumulator_length = row[self.length_field] + accumulator_row = row + else: + # Calculate what the new length would be if we joined this segment. + proposed_length = accumulator_length + row[self.length_field] + 1 + if proposed_length <= self.max_length: + accumulator_text = ( + accumulator_text + self.separator + row[self.text_field] + ) + accumulator_length = proposed_length + else: + # Commit the current accumulation as one joined segment. + new_row = accumulator_row.copy() + new_row[self.text_field] = accumulator_text + new_row[self.length_field] = accumulator_length + new_row[self.segment_id_field] = current_seg_id + joined_rows.append(new_row) + current_seg_id += 1 + # Start a new accumulation with the current row. + accumulator_text = row[self.text_field] + accumulator_length = row[self.length_field] + accumulator_row = row + + # Commit the last accumulated segment. + if accumulator_row is not None: + new_row = accumulator_row.copy() + new_row[self.text_field] = accumulator_text + new_row[self.length_field] = accumulator_length + new_row[self.segment_id_field] = current_seg_id + joined_rows.append(new_row) + if joined_rows: + return pd.concat( + [group.iloc[0:0], pd.DataFrame(joined_rows)], ignore_index=True + ) + else: + return group.iloc[0:0] + + def _join_partition( + self, df: pd.DataFrame, expected_cols: List[str] + ) -> pd.DataFrame: + if df.empty: + return df + + if self.max_length is None: + # Sort the segments by the segment_id_field to maintain proper order before aggregating. + df_sorted = df.sort_values(self.segment_id_field) + # Build aggregation functions to preserve all original columns: + # - For self.text_field, join all segments using the separator. + # - For all other columns (except self.document_id_field, which is our grouping key), take the first occurrence. + agg_funcs = {} + for col in df_sorted.columns: + if col == self.text_field: + agg_funcs[col] = lambda texts: self.separator.join( + texts.astype(str) + ) + elif col != self.document_id_field: + agg_funcs[col] = "first" + # Group by document_id_field while keeping the key as a column. + joined = df_sorted.groupby(self.document_id_field, as_index=False).agg( + agg_funcs + ) + else: + joined = df.groupby(self.document_id_field, group_keys=False).apply( + self._join_segments + ) + + if self.drop_segment_id_field: + joined = joined.drop(columns=self.segment_id_field) + # Reorder the columns to match the expected metadata order. + joined = joined[expected_cols] + return joined + + def __call__(self, dataset: DocumentDataset) -> DocumentDataset: + """ + Joins the documents back into a single document while preserving all the original fields. + """ + # Construct meta information for the transformed dataframe. + meta = dataset.df._meta.copy() + if self.text_field not in meta.columns: + meta[self.text_field] = pd.Series(dtype="object") + # If dropping the segment id field, remove it from the metadata to prevent mismatches. + if self.drop_segment_id_field: + meta = meta.drop(columns=self.segment_id_field) + expected_cols = list(meta.columns) + # Apply the join operation partition-wise. 
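+        # map_partitions is lazy, so the join only executes when the resulting
+        # dataset is computed or written out.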
+ dataset.df = dataset.df.map_partitions( + self._join_partition, expected_cols=expected_cols, meta=meta + ) + return dataset diff --git a/nemo_curator/modules/splitter.py b/nemo_curator/modules/splitter.py index 71a260629..4002468ff 100644 --- a/nemo_curator/modules/splitter.py +++ b/nemo_curator/modules/splitter.py @@ -11,14 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional - import pandas as pd from nemo_curator.datasets import DocumentDataset +from nemo_curator.modules.base import BaseModule -class DocumentSplitter: +class DocumentSplitter(BaseModule): """ Splits documents into segments based on a separator. Each segment is a new document with an additional column @@ -40,13 +39,12 @@ def __init__( text_field (str): The name of the column containing the text to split. segment_id_field (str): The name of the column to add to indicate the segment id. """ + super().__init__(input_backend="any") self.separator = separator self.text_field = text_field self.segment_id_field = segment_id_field def _split_partition(self, df: pd.DataFrame) -> pd.DataFrame: - # Work on a copy to avoid modifying the original dataframe in place. - df = df.copy() # Split the text field into segments using the separator. df["split_text"] = df[self.text_field].str.split(self.separator) # Explode the list so that each segment becomes a separate row. @@ -73,151 +71,3 @@ def __call__(self, dataset: DocumentDataset) -> DocumentDataset: # Apply the partition-wise splitting transformation using Dask's map_partitions. dataset.df = dataset.df.map_partitions(self._split_partition, meta=meta) return dataset - - -class DocumentJoiner: - """ - Joins documents that have a common id back into a single document. - The order of the documents is dictated by an additional segment_id column. - A maximum length can be specified to limit the size of the joined documents. - - The joined documents are joined by a separator. - """ - - def __init__( - self, - separator: str, - text_field: str = "text", - segment_id_field: str = "segment_id", - document_id_field: str = "id", - drop_segment_id_field: bool = True, - max_length: Optional[int] = None, - length_field: Optional[str] = None, - ): - """ - Args: - separator (str): The separator to join the documents on. - text_field (str): The name of the column containing the text to join. - segment_id_field (str): The name of the column containing the segment id. - document_id_field (str): The name of the column containing the document id. - drop_segment_id_field (bool): Whether to drop the segment_id_field after joining. - max_length (int, optional): The maximum length of the joined documents. - Both max_length and length_field must be specified or neither can be specified. - length_field (str, optional): The name of the column containing the length of the documents. - Both max_length and length_field must be specified or neither can be specified. 
- """ - if max_length is not None and length_field is None: - raise ValueError("max_length is specified but length_field is not") - if max_length is None and length_field is not None: - raise ValueError("length_field is specified but max_length is not") - - self.separator = separator - self.text_field = text_field - self.segment_id_field = segment_id_field - self.document_id_field = document_id_field - self.drop_segment_id_field = drop_segment_id_field - self.max_length = max_length - self.length_field = length_field - - def _join_segments(self, group): - # Ensure segments are processed in order. - group = group.sort_values(self.segment_id_field) - joined_rows = [] - current_seg_id = 0 - accumulator_text = None - accumulator_length = 0 - accumulator_row = None - - for _, row in group.iterrows(): - if accumulator_row is None: - # Start a new accumulation with the first segment. - accumulator_text = row[self.text_field] - accumulator_length = row[self.length_field] - accumulator_row = row - else: - # Calculate what the new length would be if we joined this segment. - proposed_length = accumulator_length + row[self.length_field] + 1 - if proposed_length <= self.max_length: - accumulator_text = ( - accumulator_text + self.separator + row[self.text_field] - ) - accumulator_length = proposed_length - else: - # Commit the current accumulation as one joined segment. - new_row = accumulator_row.copy() - new_row[self.text_field] = accumulator_text - new_row[self.length_field] = accumulator_length - new_row[self.segment_id_field] = current_seg_id - joined_rows.append(new_row) - current_seg_id += 1 - # Start a new accumulation with the current row. - accumulator_text = row[self.text_field] - accumulator_length = row[self.length_field] - accumulator_row = row - - # Commit the last accumulated segment. - if accumulator_row is not None: - new_row = accumulator_row.copy() - new_row[self.text_field] = accumulator_text - new_row[self.length_field] = accumulator_length - new_row[self.segment_id_field] = current_seg_id - joined_rows.append(new_row) - if joined_rows: - return pd.concat( - [group.iloc[0:0], pd.DataFrame(joined_rows)], ignore_index=True - ) - else: - return group.iloc[0:0] - - def _join_partition( - self, df: pd.DataFrame, expected_cols: List[str] - ) -> pd.DataFrame: - if df.empty: - return df - - if self.max_length is None: - # Sort the segments by the segment_id_field to maintain proper order before aggregating. - df_sorted = df.sort_values(self.segment_id_field) - # Build aggregation functions to preserve all original columns: - # - For self.text_field, join all segments using the separator. - # - For all other columns (except self.document_id_field, which is our grouping key), take the first occurrence. - agg_funcs = {} - for col in df_sorted.columns: - if col == self.text_field: - agg_funcs[col] = lambda texts: self.separator.join( - texts.astype(str) - ) - elif col != self.document_id_field: - agg_funcs[col] = "first" - # Group by document_id_field while keeping the key as a column. - joined = df_sorted.groupby(self.document_id_field, as_index=False).agg( - agg_funcs - ) - else: - joined = df.groupby(self.document_id_field, group_keys=False).apply( - self._join_segments - ) - - if self.drop_segment_id_field: - joined = joined.drop(columns=self.segment_id_field) - # Reorder the columns to match the expected metadata order. 
- joined = joined[expected_cols] - return joined - - def __call__(self, dataset: DocumentDataset) -> DocumentDataset: - """ - Joins the documents back into a single document while preserving all the original fields. - """ - # Construct meta information for the transformed dataframe. - meta = dataset.df._meta.copy() - if self.text_field not in meta.columns: - meta[self.text_field] = pd.Series(dtype="object") - # If dropping the segment id field, remove it from the metadata to prevent mismatches. - if self.drop_segment_id_field: - meta = meta.drop(columns=self.segment_id_field) - expected_cols = list(meta.columns) - # Apply the join operation partition-wise. - dataset.df = dataset.df.map_partitions( - self._join_partition, expected_cols=expected_cols, meta=meta - ) - return dataset diff --git a/tests/test_joiner.py b/tests/test_joiner.py new file mode 100644 index 000000000..844beb1cd --- /dev/null +++ b/tests/test_joiner.py @@ -0,0 +1,176 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pandas as pd +from dask.dataframe.utils import assert_eq + +from nemo_curator import DocumentJoiner +from nemo_curator.datasets import DocumentDataset + + +class TestDocumentJoiner: + def test_join_default(self): + # Input represents documents already split. + # For example, a document with id=1 split as "a", "b", "c" becomes joined to "a|b|c". + # Four documents are used. + data = { + "id": [1, 1, 1, 2, 3, 3, 4, 4], + "text": ["a", "b", "c", "nosplit", "start", "middle", "end", ""], + "segment_id": [0, 1, 2, 0, 0, 1, 0, 1], + } + pdf = pd.DataFrame(data) + dataset = DocumentDataset.from_pandas(pdf, npartitions=1) + joiner = DocumentJoiner( + separator="|", + text_field="text", + segment_id_field="segment_id", + document_id_field="id", + drop_segment_id_field=True, + ) + result_dataset = joiner(dataset) + + expected_df = pd.DataFrame( + {"id": [1, 2, 3, 4], "text": ["a|b|c", "nosplit", "start|middle", "end|"]} + ) + assert_eq( + result_dataset.df.compute().reset_index(drop=True), + expected_df, + check_index=False, + ) + + def test_join_custom_fields(self): + # Use custom field names: + # document id field: "doc" + # text field: "content" + # segment id field: "s_id" + # Also keep the segment id field (drop_segment_id_field=False) + data = { + "doc": [101, 101, 102, 103, 103, 104, 104], + "content": ["first", "second", "only", "hello", "world", "baz", ""], + "s_id": [0, 1, 0, 0, 1, 0, 1], + } + pdf = pd.DataFrame(data) + dataset = DocumentDataset.from_pandas(pdf, npartitions=1) + joiner = DocumentJoiner( + separator="~", + text_field="content", + segment_id_field="s_id", + document_id_field="doc", + drop_segment_id_field=False, + ) + result_dataset = joiner(dataset) + + # Expected: each document is joined by "~". The segment id becomes the first segment's id. 
+ expected_df = pd.DataFrame( + { + "doc": [101, 102, 103, 104], + "content": ["first~second", "only", "hello~world", "baz~"], + "s_id": [0, 0, 0, 0], + } + ) + assert_eq( + result_dataset.df.compute().reset_index(drop=True), + expected_df, + check_index=False, + ) + + def test_join_max_length(self): + # Here we test joining when a maximum length is specified. + # Each segment carries a precomputed "length" value. + # The joiner should accumulate segments until adding the next one (plus separator) + # would exceed max_length=5. + # + # For document 1: + # segments: "ab"(2), "cd"(2), "ef"(2), "gh"(2) + # - "ab" then "cd": 2+2+1 = 5 → join as "ab-cd" (length 5) + # - then "ef" then "gh": 2+2+1 = 5 → join as "ef-gh" (length 5) + # + # For document 2: + # segments: "a"(1), "b"(1) → join as "a-b" (length 3) + # + # For document 3: + # segment: "hello"(5) → remains "hello" + # + # For document 4: + # segments: "x"(1), "yz"(2), "0"(1) + # - "x" then "yz": 1+2+1 = 4 → "x-yz" (length 4) + # - "0" remains alone. + data = { + "id": [1, 1, 1, 1, 2, 2, 3, 4, 4, 4], + "text": ["ab", "cd", "ef", "gh", "a", "b", "hello", "x", "yz", "0"], + "segment_id": [0, 1, 2, 3, 0, 1, 0, 0, 1, 2], + "length": [2, 2, 2, 2, 1, 1, 5, 1, 2, 1], + } + pdf = pd.DataFrame(data) + dataset = DocumentDataset.from_pandas(pdf, npartitions=1) + joiner = DocumentJoiner( + separator="-", + text_field="text", + segment_id_field="segment_id", + document_id_field="id", + drop_segment_id_field=True, + max_length=5, + length_field="length", + ) + result_dataset = joiner(dataset) + + expected_df = pd.DataFrame( + [ + {"id": 1, "text": "ab-cd", "length": 5}, + {"id": 1, "text": "ef-gh", "length": 5}, + {"id": 2, "text": "a-b", "length": 3}, + {"id": 3, "text": "hello", "length": 5}, + {"id": 4, "text": "x-yz", "length": 4}, + {"id": 4, "text": "0", "length": 1}, + ] + ) + # Sort by id and text to ensure consistent order + expected_sorted = expected_df.sort_values(by=["id", "text"]).reset_index( + drop=True + ) + result_sorted = ( + result_dataset.df.compute() + .sort_values(by=["id", "text"]) + .reset_index(drop=True) + ) + assert_eq(result_sorted, expected_sorted, check_index=False) + + def test_join_with_string_ids(self): + # Test join functionality when document id field is a string. + data = { + "doc": ["doc1", "doc1", "doc2", "doc3", "doc3", "doc4", "doc4"], + "text": ["a", "b", "nosplit", "start", "middle", "end", ""], + "segment_id": [0, 1, 0, 0, 1, 0, 1], + } + pdf = pd.DataFrame(data) + dataset = DocumentDataset.from_pandas(pdf, npartitions=1) + joiner = DocumentJoiner( + separator="|", + text_field="text", + segment_id_field="segment_id", + document_id_field="doc", + drop_segment_id_field=True, + ) + result_dataset = joiner(dataset) + + expected_df = pd.DataFrame( + { + "doc": ["doc1", "doc2", "doc3", "doc4"], + "text": ["a|b", "nosplit", "start|middle", "end|"], + } + ) + assert_eq( + result_dataset.df.compute().reset_index(drop=True), + expected_df, + check_index=False, + ) diff --git a/tests/test_splitter.py b/tests/test_splitter.py index 7eadc054c..6d987b86e 100644 --- a/tests/test_splitter.py +++ b/tests/test_splitter.py @@ -11,16 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
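+        # (The joiner sorts by s_id and aggregates non-text columns with "first",
+        #  which is why every joined row keeps s_id == 0 here.)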
-import dask.dataframe as dd import pandas as pd +import pytest from dask.dataframe.utils import assert_eq +from nemo_curator import DocumentSplitter, ToBackend from nemo_curator.datasets import DocumentDataset -from nemo_curator.modules.splitter import DocumentJoiner, DocumentSplitter class TestDocumentSplitter: - def test_basic_split_default(self): + @pytest.mark.parametrize( + "backend", ["pandas", pytest.param("cudf", marks=pytest.mark.gpu)] + ) + def test_basic_split_default(self, backend): # Use default text_field "text" and segment_id_field "segment_id" # Four examples: # "a|b|c" → splits to ["a", "b", "c"] @@ -30,9 +33,16 @@ def test_basic_split_default(self): docs = ["a|b|c", "nosplit", "start|middle", "end|"] pdf = pd.DataFrame({"text": docs}) dataset = DocumentDataset.from_pandas(pdf, npartitions=1) + to_backend = ToBackend(backend) + dataset = to_backend(dataset) + splitter = DocumentSplitter(separator="|") result_dataset = splitter(dataset) + result_df = result_dataset.df.compute() + if backend == "cudf": + result_df = result_df.to_pandas() + expected_df = pd.DataFrame( { "text": ["a", "b", "c", "nosplit", "start", "middle", "end", ""], @@ -41,12 +51,15 @@ def test_basic_split_default(self): ) # Compare without considering the index order. assert_eq( - result_dataset.df.compute().reset_index(drop=True), - expected_df, + result_df.reset_index(drop=True), + expected_df.reset_index(drop=True), check_index=False, ) - def test_split_custom_fields(self): + @pytest.mark.parametrize( + "backend", ["pandas", pytest.param("cudf", marks=pytest.mark.gpu)] + ) + def test_split_custom_fields(self, backend): # Use a custom text field name ("content") and segment id field ("seg_id") # with a different separator. # Examples: @@ -57,11 +70,18 @@ def test_split_custom_fields(self): docs = ["x;y", "single", "first;second;third", ";leading"] pdf = pd.DataFrame({"content": docs}) dataset = DocumentDataset.from_pandas(pdf, npartitions=1) + to_backend = ToBackend(backend) + dataset = to_backend(dataset) + splitter = DocumentSplitter( separator=";", text_field="content", segment_id_field="seg_id" ) result_dataset = splitter(dataset) + result_df = result_dataset.df.compute() + if backend == "cudf": + result_df = result_df.to_pandas() + expected_df = pd.DataFrame( { "content": [ @@ -78,203 +98,7 @@ def test_split_custom_fields(self): } ) assert_eq( - result_dataset.df.compute().reset_index(drop=True), - expected_df, - check_index=False, - ) - - -class TestDocumentJoiner: - def test_join_default(self): - # Input represents documents already split. - # For example, a document with id=1 split as "a", "b", "c" becomes joined to "a|b|c". - # Four documents are used. 
- data = { - "id": [1, 1, 1, 2, 3, 3, 4, 4], - "text": ["a", "b", "c", "nosplit", "start", "middle", "end", ""], - "segment_id": [0, 1, 2, 0, 0, 1, 0, 1], - } - pdf = pd.DataFrame(data) - dataset = DocumentDataset.from_pandas(pdf, npartitions=1) - joiner = DocumentJoiner( - separator="|", - text_field="text", - segment_id_field="segment_id", - document_id_field="id", - drop_segment_id_field=True, - ) - result_dataset = joiner(dataset) - - expected_df = pd.DataFrame( - {"id": [1, 2, 3, 4], "text": ["a|b|c", "nosplit", "start|middle", "end|"]} - ) - assert_eq( - result_dataset.df.compute().reset_index(drop=True), - expected_df, + result_df.reset_index(drop=True), + expected_df.reset_index(drop=True), check_index=False, ) - - def test_join_custom_fields(self): - # Use custom field names: - # document id field: "doc" - # text field: "content" - # segment id field: "s_id" - # Also keep the segment id field (drop_segment_id_field=False) - data = { - "doc": [101, 101, 102, 103, 103, 104, 104], - "content": ["first", "second", "only", "hello", "world", "baz", ""], - "s_id": [0, 1, 0, 0, 1, 0, 1], - } - pdf = pd.DataFrame(data) - dataset = DocumentDataset.from_pandas(pdf, npartitions=1) - joiner = DocumentJoiner( - separator="~", - text_field="content", - segment_id_field="s_id", - document_id_field="doc", - drop_segment_id_field=False, - ) - result_dataset = joiner(dataset) - - # Expected: each document is joined by "~". The segment id becomes the first segment's id. - expected_df = pd.DataFrame( - { - "doc": [101, 102, 103, 104], - "content": ["first~second", "only", "hello~world", "baz~"], - "s_id": [0, 0, 0, 0], - } - ) - assert_eq( - result_dataset.df.compute().reset_index(drop=True), - expected_df, - check_index=False, - ) - - def test_join_max_length(self): - # Here we test joining when a maximum length is specified. - # Each segment carries a precomputed "length" value. - # The joiner should accumulate segments until adding the next one (plus separator) - # would exceed max_length=5. - # - # For document 1: - # segments: "ab"(2), "cd"(2), "ef"(2), "gh"(2) - # - "ab" then "cd": 2+2+1 = 5 → join as "ab-cd" (length 5) - # - then "ef" then "gh": 2+2+1 = 5 → join as "ef-gh" (length 5) - # - # For document 2: - # segments: "a"(1), "b"(1) → join as "a-b" (length 3) - # - # For document 3: - # segment: "hello"(5) → remains "hello" - # - # For document 4: - # segments: "x"(1), "yz"(2), "0"(1) - # - "x" then "yz": 1+2+1 = 4 → "x-yz" (length 4) - # - "0" remains alone. 
- data = { - "id": [1, 1, 1, 1, 2, 2, 3, 4, 4, 4], - "text": ["ab", "cd", "ef", "gh", "a", "b", "hello", "x", "yz", "0"], - "segment_id": [0, 1, 2, 3, 0, 1, 0, 0, 1, 2], - "length": [2, 2, 2, 2, 1, 1, 5, 1, 2, 1], - } - pdf = pd.DataFrame(data) - dataset = DocumentDataset.from_pandas(pdf, npartitions=1) - joiner = DocumentJoiner( - separator="-", - text_field="text", - segment_id_field="segment_id", - document_id_field="id", - drop_segment_id_field=True, - max_length=5, - length_field="length", - ) - result_dataset = joiner(dataset) - - expected_df = pd.DataFrame( - [ - {"id": 1, "text": "ab-cd", "length": 5}, - {"id": 1, "text": "ef-gh", "length": 5}, - {"id": 2, "text": "a-b", "length": 3}, - {"id": 3, "text": "hello", "length": 5}, - {"id": 4, "text": "x-yz", "length": 4}, - {"id": 4, "text": "0", "length": 1}, - ] - ) - # Sort by id and text to ensure consistent order - expected_sorted = expected_df.sort_values(by=["id", "text"]).reset_index( - drop=True - ) - result_sorted = ( - result_dataset.df.compute() - .sort_values(by=["id", "text"]) - .reset_index(drop=True) - ) - assert_eq(result_sorted, expected_sorted, check_index=False) - - def test_join_with_string_ids(self): - # Test join functionality when document id field is a string. - data = { - "doc": ["doc1", "doc1", "doc2", "doc3", "doc3", "doc4", "doc4"], - "text": ["a", "b", "nosplit", "start", "middle", "end", ""], - "segment_id": [0, 1, 0, 0, 1, 0, 1], - } - pdf = pd.DataFrame(data) - dataset = DocumentDataset.from_pandas(pdf, npartitions=1) - joiner = DocumentJoiner( - separator="|", - text_field="text", - segment_id_field="segment_id", - document_id_field="doc", - drop_segment_id_field=True, - ) - result_dataset = joiner(dataset) - - expected_df = pd.DataFrame( - { - "doc": ["doc1", "doc2", "doc3", "doc4"], - "text": ["a|b", "nosplit", "start|middle", "end|"], - } - ) - assert_eq( - result_dataset.df.compute().reset_index(drop=True), - expected_df, - check_index=False, - ) - - -class TestSplitJoinReconstruction: - def test_reconstruction_default(self): - # Create an original dataset with a unique "id" column and text examples. - # Four examples include edge cases: - # "a|b|c" → multiple splits - # "nosplit" → no separator present - # "a||b|" → consecutive separators yield empty strings - # "" → empty document - docs = ["a|b|c", "nosplit", "a||b|", ""] - pdf = pd.DataFrame({"id": [1, 2, 3, 4], "text": docs}) - original_dataset = DocumentDataset.from_pandas(pdf, npartitions=1) - - # First, split using "|" as separator. - splitter = DocumentSplitter(separator="|") - split_dataset = splitter(original_dataset) - - # Then, rejoin using the same separator. - joiner = DocumentJoiner( - separator="|", - text_field="text", - segment_id_field="segment_id", - document_id_field="id", - drop_segment_id_field=True, - ) - reconstructed_dataset = joiner(split_dataset) - - # The reconstructed "text" column should match the original. - original_sorted = ( - original_dataset.df.compute().sort_values(by="id").reset_index(drop=True) - ) - reconstructed_sorted = ( - reconstructed_dataset.df.compute() - .sort_values(by="id") - .reset_index(drop=True) - ) - assert_eq(reconstructed_sorted, original_sorted, check_index=False) diff --git a/tests/test_splitter_joiner.py b/tests/test_splitter_joiner.py new file mode 100644 index 000000000..9b7b5cec6 --- /dev/null +++ b/tests/test_splitter_joiner.py @@ -0,0 +1,56 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pandas as pd +from dask.dataframe.utils import assert_eq + +from nemo_curator import DocumentJoiner, DocumentSplitter +from nemo_curator.datasets import DocumentDataset + + +class TestSplitJoinReconstruction: + def test_reconstruction_default(self): + # Create an original dataset with a unique "id" column and text examples. + # Four examples include edge cases: + # "a|b|c" → multiple splits + # "nosplit" → no separator present + # "a||b|" → consecutive separators yield empty strings + # "" → empty document + docs = ["a|b|c", "nosplit", "a||b|", ""] + pdf = pd.DataFrame({"id": [1, 2, 3, 4], "text": docs}) + original_dataset = DocumentDataset.from_pandas(pdf, npartitions=1) + + # First, split using "|" as separator. + splitter = DocumentSplitter(separator="|") + split_dataset = splitter(original_dataset) + + # Then, rejoin using the same separator. + joiner = DocumentJoiner( + separator="|", + text_field="text", + segment_id_field="segment_id", + document_id_field="id", + drop_segment_id_field=True, + ) + reconstructed_dataset = joiner(split_dataset) + + # The reconstructed "text" column should match the original. + original_sorted = ( + original_dataset.df.compute().sort_values(by="id").reset_index(drop=True) + ) + reconstructed_sorted = ( + reconstructed_dataset.df.compute() + .sort_values(by="id") + .reset_index(drop=True) + ) + assert_eq(reconstructed_sorted, original_sorted, check_index=False) From f28141dce2f8ecb0bc25bbdfc516917c425a5d64 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Tue, 11 Feb 2025 11:08:57 -0800 Subject: [PATCH 19/22] Fix splitter and joiner call method Signed-off-by: Ryan Wolf --- nemo_curator/modules/joiner.py | 2 +- nemo_curator/modules/splitter.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo_curator/modules/joiner.py b/nemo_curator/modules/joiner.py index 8adad7779..2ecdfc80b 100644 --- a/nemo_curator/modules/joiner.py +++ b/nemo_curator/modules/joiner.py @@ -149,7 +149,7 @@ def _join_partition( joined = joined[expected_cols] return joined - def __call__(self, dataset: DocumentDataset) -> DocumentDataset: + def call(self, dataset: DocumentDataset) -> DocumentDataset: """ Joins the documents back into a single document while preserving all the original fields. """ diff --git a/nemo_curator/modules/splitter.py b/nemo_curator/modules/splitter.py index 4002468ff..a51f766ac 100644 --- a/nemo_curator/modules/splitter.py +++ b/nemo_curator/modules/splitter.py @@ -57,7 +57,7 @@ def _split_partition(self, df: pd.DataFrame) -> pd.DataFrame: df = df.drop(columns="split_text") return df - def __call__(self, dataset: DocumentDataset) -> DocumentDataset: + def call(self, dataset: DocumentDataset) -> DocumentDataset: """ Splits the documents into segments based on the separator and adds a column indicating the segment id. 
From 3e9e0a20bd09490291e928856d360d86fdef7b02 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Tue, 11 Feb 2025 14:33:14 -0800 Subject: [PATCH 20/22] Add type hint for cudf Signed-off-by: Ryan Wolf --- nemo_curator/modules/splitter.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/nemo_curator/modules/splitter.py b/nemo_curator/modules/splitter.py index a51f766ac..b48872bbd 100644 --- a/nemo_curator/modules/splitter.py +++ b/nemo_curator/modules/splitter.py @@ -11,10 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Union + import pandas as pd from nemo_curator.datasets import DocumentDataset from nemo_curator.modules.base import BaseModule +from nemo_curator.utils.import_utils import gpu_only_import + +cudf = gpu_only_import("cudf") class DocumentSplitter(BaseModule): @@ -44,7 +49,9 @@ def __init__( self.text_field = text_field self.segment_id_field = segment_id_field - def _split_partition(self, df: pd.DataFrame) -> pd.DataFrame: + def _split_partition( + self, df: Union[pd.DataFrame, cudf.DataFrame] + ) -> Union[pd.DataFrame, cudf.DataFrame]: # Split the text field into segments using the separator. df["split_text"] = df[self.text_field].str.split(self.separator) # Explode the list so that each segment becomes a separate row. From d9ecba492f35e330b46cb64979c3682fab326564 Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Tue, 11 Feb 2025 15:04:30 -0800 Subject: [PATCH 21/22] Fix typing for cudf Signed-off-by: Ryan Wolf --- nemo_curator/modules/splitter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo_curator/modules/splitter.py b/nemo_curator/modules/splitter.py index b48872bbd..623b4d4a0 100644 --- a/nemo_curator/modules/splitter.py +++ b/nemo_curator/modules/splitter.py @@ -50,8 +50,8 @@ def __init__( self.segment_id_field = segment_id_field def _split_partition( - self, df: Union[pd.DataFrame, cudf.DataFrame] - ) -> Union[pd.DataFrame, cudf.DataFrame]: + self, df: Union[pd.DataFrame, "cudf.DataFrame"] + ) -> Union[pd.DataFrame, "cudf.DataFrame"]: # Split the text field into segments using the separator. df["split_text"] = df[self.text_field].str.split(self.separator) # Explode the list so that each segment becomes a separate row. From 50e53e3068d63e7c3bce51b221e5fa4e7e5f996e Mon Sep 17 00:00:00 2001 From: Ryan Wolf Date: Wed, 12 Feb 2025 10:06:56 -0800 Subject: [PATCH 22/22] Address Lawrence's review Signed-off-by: Ryan Wolf --- docs/user-guide/syntheticdata.rst | 45 ++++++++++++++++++++++++------- 1 file changed, 35 insertions(+), 10 deletions(-) diff --git a/docs/user-guide/syntheticdata.rst b/docs/user-guide/syntheticdata.rst index 2778bdc4f..43c017c2d 100644 --- a/docs/user-guide/syntheticdata.rst +++ b/docs/user-guide/syntheticdata.rst @@ -694,13 +694,19 @@ The only exception is that a ``max_concurrent_requests`` parameter can be suppli Customize the Nemotron-CC Pipeline ----------------------------------- -Nemotron-CC used a collection of pipelines focused on rephrasing reference documents into different formats/styles. -NeMo Curator provides a synchronous and asynchronous version of each pipeline with ``nemo_curator.synthetic.NemotronCCGenerator`` and ``nemo_curator.synthetic.AsyncNemotronCCGenerator``. +Nemotron-CC is an open, large, high-quality English Common Crawl dataset that enables pretraining highly accurate LLMs over both short and long token horizons. 
+ +You can use the Nemotron-CC pipeline collection to rewrite reference documents into different formats and styles. For example, you can rephrase short sentences with simple diction into technical, scholarly prose (like Wikipedia) or distill wandering paragraphs into condensed bulleted lists. + +NeMo Curator provides two versions of each pipeline: + +* **Synchronous**: ``nemo_curator.synthetic.NemotronCCGenerator`` +* **Asynchronous**: ``nemo_curator.synthetic.AsyncNemotronCCGenerator`` Rewrite to Wikipedia Style ########################## -The ``NemotronCCGenerator.rewrite_to_wikipedia_style`` method rewrites a document into a style that is similar to Wikipedia. +Use the ``NemotronCCGenerator.rewrite_to_wikipedia_style`` method to rewrite a document into a style that is similar to Wikipedia in terms of line spacing, punctuation, and style. .. code-block:: python @@ -735,7 +741,7 @@ The ``NemotronCCGenerator.rewrite_to_wikipedia_style`` method rewrites a documen Generate Diverse QA Pairs ######################### -The ``NemotronCCGenerator.generate_diverse_qa`` method generates a list of diverse QA pairs from a document. +Use the ``NemotronCCGenerator.generate_diverse_qa`` method to generate a list of diverse QA pairs from a document. .. code-block:: python @@ -768,7 +774,10 @@ The ``NemotronCCGenerator.generate_diverse_qa`` method generates a list of diver # Answer: The moon is made of rock and dust. -To help with cleaning the output, the ``NemotronCCDiverseQAPostprocessor`` class is provided. +Postprocessor +^^^^^^^^^^^^^ + +You can optionally use the ``NemotronCCDiverseQAPostprocessor`` class to reformat the output. .. code-block:: python @@ -795,6 +804,12 @@ To help with cleaning the output, the ``NemotronCCDiverseQAPostprocessor`` class responses = generator.generate_diverse_qa(document=document, model=model, model_kwargs=model_kwargs) postprocessor = NemotronCCDiverseQAPostprocessor(text_field="text", response_field="diverse_qa_response") dataset = DocumentDataset.from_pandas(pd.DataFrame({"text": document, "diverse_qa_response": responses})) + + # This postprocessor will sample a random number of QA pairs up to max_num_pairs. + # If a tokenizer is provided, the number of QA pairs will be sampled from at least + # 1 and at most floor(max_num_pairs * num_tokens / 150). + # Otherwise, the number of QA pairs will be sampled randomly strictly up to max_num_pairs. + # The generated QA pairs are shuffled and then appended to the original text. cleaned_dataset = postprocessor(dataset) first_entry = cleaned_dataset.df.head(1) @@ -806,7 +821,7 @@ To help with cleaning the output, the ``NemotronCCDiverseQAPostprocessor`` class Generate Knowledge List ####################### -The ``NemotronCCGenerator.generate_knowledge_list`` method generates a list of knowledge from a document. +Use the ``NemotronCCGenerator.generate_knowledge_list`` method to generate a list of knowledge from a document. .. code-block:: python @@ -839,7 +854,10 @@ The ``NemotronCCGenerator.generate_knowledge_list`` method generates a list of k # - The moon is the only natural satellite of the Earth. # ... -To help with cleaning the output, the ``NemotronCCKnowledgeListPostprocessor`` class is provided. +Postprocessor +^^^^^^^^^^^^^ + +You can optionally use the ``NemotronCCKnowledgeListPostprocessor`` class to reformat the output. .. 
code-block:: python @@ -877,6 +895,12 @@ To help with cleaning the output, the ``NemotronCCKnowledgeListPostprocessor`` c postprocessor = NemotronCCKnowledgeListPostprocessor(text_field="knowledge_list_response") dataset = DocumentDataset.from_pandas(pd.DataFrame({"knowledge_list_response": responses})) + + # This postprocessor removes formatting artifacts + # such as bullet point prefixes ("- ") and extra indentation from each line, + # ensuring that the final output is a clean, uniformly formatted list of knowledge items. + # The processing includes skipping any initial non-bullet line and merging related lines + # to reconstruct multi-line questions or answers. cleaned_dataset = postprocessor(dataset) first_entry = cleaned_dataset.df.head(1) @@ -888,7 +912,7 @@ To help with cleaning the output, the ``NemotronCCKnowledgeListPostprocessor`` c Distill Document ################# -The ``NemotronCCGenerator.distill`` method distills a document into a more concise form. +Use the ``NemotronCCGenerator.distill`` method to make a document more concise. .. code-block:: python @@ -923,7 +947,7 @@ The ``NemotronCCGenerator.distill`` method distills a document into a more conci Extract Knowledge ################ -The ``NemotronCCGenerator.extract_knowledge`` method extracts knowledge from a document. +Use the ``NemotronCCGenerator.extract_knowledge`` method to extract knowledge from a document. .. code-block:: python @@ -938,7 +962,8 @@ The ``NemotronCCGenerator.extract_knowledge`` method extracts knowledge from a d client = OpenAIClient(openai_client) generator = NemotronCCGenerator(client) - document = "The moon is bright. It shines at night." + document = ("The moon is bright. It shines at night. I love the moon. I first saw it up" + " close through a telescope in 1999 at a sleepover.") model = "nv-mistralai/mistral-nemo-12b-instruct" model_kwargs = { "temperature": 0.5,
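        # A moderate temperature (0.5) generally keeps the extracted knowledge
        # grounded in the source passage while still allowing some rewording.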