Add Nemotron CC SDG Pipelines and Pre-processing/Post-Processing Stages #527

Merged: 23 commits from rywolf/nemotron-cc-sdg into main on Feb 12, 2025

Conversation

@ryantwolf (Collaborator) commented on Feb 7, 2025

Description

  • Adds the preprocessing and postprocessing stages for the following SDG pipelines from Nemotron-CC
    • Wikipedia-style rephrasing
    • Diverse QA generation
    • Distillation
    • Knowledge extraction
    • Knowledge listing

Usage

The full pipelines for all of them are too long for a PR description, but here is the complete pre- and post-processing pipeline for Wikipedia-style rephrasing.

from nemo_curator import (
    DocumentSplitter,
    DocumentJoiner,
    get_client,
    Sequential,
    ScoreFilter,
    Modify,
    Filter,
)
from nemo_curator.datasets import DocumentDataset
from nemo_curator.filters import TokenCountFilter, SubstringFilter
from nemo_curator.modifiers import (
    QuotationRemover,
    MarkdownRemover,
    Slicer,
)
from nemo_curator.services import OpenAIClient
from nemo_curator.synthetic import NemotronCCGenerator
from transformers import AutoTokenizer
from openai import OpenAI
from nemo_curator.synthetic.prompts import (
    NEMOTRON_CC_SYSTEM_PROMPT,
    WIKIPEDIA_REPHRASING_PROMPT_TEMPLATE,
)


def get_prefix_token_count(
    tokenizer: AutoTokenizer, system_prompt: str, user_prompt_template: str
):
    user_prompt = user_prompt_template.format(document="placeholder")
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    prefix_tokens = tokenizer.apply_chat_template(messages)

    return len(prefix_tokens)


def build_preprocessing_pipeline(
    tokenizer: AutoTokenizer,
    text_field: str,
    system_prompt: str,
    user_prompt_template: str,
    min_document_tokens: int,
    min_segment_tokens: int,
    max_input_tokens: int,
):
    # Construct filters for document filtering
    document_token_count_filter = TokenCountFilter(
        tokenizer=tokenizer, min_tokens=min_document_tokens
    )

    # Construct filters for segment filtering
    prefix_token_count = get_prefix_token_count(
        tokenizer, system_prompt, user_prompt_template
    )
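    # Token budget for each document segment: the model's max input length
    # minus the prompt prefix, with a small 2-token buffer.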
    max_segment_tokens = max_input_tokens - prefix_token_count - 2
    long_segment_token_count_filter = TokenCountFilter(
        tokenizer=tokenizer, max_tokens=max_segment_tokens
    )
    short_segment_token_count_filter = TokenCountFilter(
        tokenizer=tokenizer, min_tokens=min_segment_tokens
    )

    preprocessing_pipeline = Sequential(
        [
            # Filter out documents that are too short
            ScoreFilter(
                document_token_count_filter,
                text_field=text_field,
                score_field="document_token_count",
                score_type=int,
            ),
            # Split documents into segments
            DocumentSplitter(
                separator="\n", text_field=text_field, segment_id_field="segment_id"
            ),
            # Filter out segments that are too long
            ScoreFilter(
                long_segment_token_count_filter,
                text_field=text_field,
                score_field="segment_token_count",
                score_type=int,
            ),
            # Join adjacent short segments
            DocumentJoiner(
                separator="\n",
                text_field=text_field,
                segment_id_field="segment_id",
                document_id_field="id",
                max_length=max_segment_tokens,
                length_field="segment_token_count",
                drop_segment_id_field=False,
            ),
            # Filter out segments that are too short even after joining
            Filter(
                short_segment_token_count_filter.keep_document,
                filter_field="segment_token_count",
            ),
        ]
    )

    return preprocessing_pipeline


def build_wikipedia_postprocessing_pipeline(
    tokenizer: AutoTokenizer, rephrased_field: str
):
    MAX_REPHRASED_TOKENS = 510
    MIN_DOCUMENT_TOKENS = 50

    long_segment_token_count_filter = TokenCountFilter(
        tokenizer=tokenizer, max_tokens=MAX_REPHRASED_TOKENS
    )
    document_token_count_filter = TokenCountFilter(
        tokenizer=tokenizer, min_tokens=MIN_DOCUMENT_TOKENS
    )
    postprocessing_pipeline = Sequential(
        [
            # Filter by token count
            ScoreFilter(
                long_segment_token_count_filter,
                text_field=rephrased_field,
                score_field="rephrased_segment_token_count",
                score_type=int,
            ),
            # Remove markdown formatting
            Modify(MarkdownRemover(), text_field=rephrased_field),
            # Remove documents not starting with the specified prefix
            ScoreFilter(
                SubstringFilter(
                    substring="Here is a paraphrased version:", position="prefix"
                ),
                text_field=rephrased_field,
                score_field="substring",
                score_type=int,
            ),
            # Remove the paraphrase prefix
            Modify(
                Slicer(
                    left="Here is a paraphrased version:",
                    include_left=False,
                    strip=True,
                ),
                text_field=rephrased_field,
            ),
            # Remove quotation marks
            Modify(QuotationRemover(), text_field=rephrased_field),
            # Concat paragraphs belonging to the same document
            DocumentJoiner(
                separator="\n",
                text_field=rephrased_field,
                segment_id_field="segment_id",
                document_id_field="id",
            ),
            # Filter out documents that are too short
            ScoreFilter(
                document_token_count_filter,
                text_field=rephrased_field,
                score_field="rephrased_document_token_count",
                score_type=int,
            ),
        ]
    )

    return postprocessing_pipeline


def wikipedia_rephraser():
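    # Start the Dask client used by NeMo Curator's distributed modules.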
    _ = get_client()
    input_data_path = "/path/to/input/"
    output_path = "/path/to/output/"
    text_field = "text_field_name"
    rephrased_field = "rewritten_field_name"

    tokenizer = AutoTokenizer.from_pretrained("<insert-tokenizer>")

    openai_client = OpenAI(
        base_url="https://integrate.api.nvidia.com/v1",
        api_key="<insert-api-key>",  # NV DEV API KEY
    )
    client = OpenAIClient(openai_client)
    nemotron_cc = NemotronCCGenerator(client)
    api_model_name = "nv-mistralai/mistral-nemo-12b-instruct"

    dataset = DocumentDataset.read_json(input_data_path)

    MIN_DOCUMENT_TOKENS = 30
    MIN_SEGMENT_TOKENS = 10
    MAX_INPUT_TOKENS = 512
    preprocessing_pipeline = build_preprocessing_pipeline(
        tokenizer,
        text_field,
        NEMOTRON_CC_SYSTEM_PROMPT,
        WIKIPEDIA_REPHRASING_PROMPT_TEMPLATE,
        MIN_DOCUMENT_TOKENS,
        MIN_SEGMENT_TOKENS,
        MAX_INPUT_TOKENS,
    )

    dataset = preprocessing_pipeline(dataset)

    first_entries = dataset.df.head()
    print(first_entries)

    MAX_OUTPUT_TOKENS = 512
    TOP_K = 0
    TOP_P = 0.9
    END_STRINGS = "['</s>']"
    TEMPERATURE = 0.5
    rewritten_texts = []
    for text in first_entries[text_field]:
        rewritten_text = nemotron_cc.rewrite_to_wikipedia_style(
            text,
            api_model_name,
            model_kwargs={
                "top_k": TOP_K,
                "top_p": TOP_P,
                "stop": END_STRINGS,
                "max_tokens": MAX_OUTPUT_TOKENS,
                "temperature": TEMPERATURE,
            },
        )
        rewritten_texts.append(rewritten_text[0])

    first_entries[rephrased_field] = rewritten_texts

    rephrased_dataset = DocumentDataset.from_pandas(first_entries)
    postprocessed_pipeline = build_wikipedia_postprocessing_pipeline(
        tokenizer, rephrased_field
    )
    rephrased_dataset = postprocessed_pipeline(rephrased_dataset)

    rephrased_dataset.to_json(output_path)
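
The other pipelines (diverse QA, distillation, knowledge extraction, knowledge listing) follow the same shape: the same preprocessing pipeline, a different NemotronCCGenerator call with its own prompt template, and a matching postprocessing pipeline. As a rough sketch only (the generator method and prompt-template names below are my reading of the nemo_curator.synthetic API and should be checked against the merged code), the diverse QA variant of the generation loop inside a driver function like the one above would look like:

from nemo_curator.synthetic.prompts import DIVERSE_QA_PROMPT_TEMPLATE  # assumed name

# Assumed method name: NemotronCCGenerator.generate_diverse_qa; verify against
# the merged code before relying on it. Reuses tokenizer, nemotron_cc,
# first_entries, and the sampling constants defined above.
qa_texts = []
for text in first_entries[text_field]:
    response = nemotron_cc.generate_diverse_qa(
        text,
        api_model_name,
        model_kwargs={
            "top_k": TOP_K,
            "top_p": TOP_P,
            "stop": END_STRINGS,
            "max_tokens": MAX_OUTPUT_TOKENS,
            "temperature": TEMPERATURE,
        },
    )
    qa_texts.append(response[0])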

Checklist

  • I am familiar with the Contributing Guide.
  • New or Existing tests cover these changes.
  • The documentation is up to date with these changes.

@ryantwolf added the gpuci (Run GPU CI/CD on PR) label on Feb 7, 2025
Signed-off-by: Ryan Wolf <[email protected]>
@ryantwolf marked this pull request as ready for review on February 7, 2025, 19:34
@VibhuJawa (Collaborator) left a comment

Some minor feedback around the modules; mostly looks good to me.

Review comments (resolved): nemo_curator/filters/heuristic_filter.py, nemo_curator/modules/splitter.py (2 outdated)
Comment on lines +24 to +26
- If the document starts and ends with a quotation mark and there are
newlines in the document, the quotation marks are removed only if
the first line does not end with a quotation mark.
Collaborator:

Won't this lead to a weird result?

"""
'hello my name is bla,
what's your name?'
is your name 'xyz'?
"""
Here the quotation marks are quoting 'hello my name is bla, \n what's your name?' and 'xyz'.

But the removal logic will output quotation marks such that the quoted phrase is "is your name".

@ryantwolf (Collaborator, Author) replied on Feb 11, 2025:

Sorry, the formatting in your example is a bit weird and I'm having trouble understanding your question. Is your example meant to be one document or multiple documents? The trailing ? would cause no modifications to be made to this document. From how I am reading it, I see:

example="""""hello my name is bla,
what's your name?"
is your name "xyz"?
"""

Please let me know if I have misread this.

Collaborator:

My bad on the formatting; I mistakenly also added the question mark:

example=""""hello my name is bla,
what's your name?"
is your name "xyz"
"""

In this example, the first character of the first line and the last character of the last line are quotation marks, so the output will be:

example="""hello my name is bla,
what's your name?"
is your name "xyz

i.e., initially the quoted phrases were "hello my name is bla, \n what's your name?" and "xyz", but once we remove the quotation marks as per the algorithm, the remaining quoted phrase is "is your name ".

Collaborator Author:

Gotcha, thanks for clarifying. You're correct that this can happen, but I don't think that response format is common in the Nemotron-CC SDG pipelines this filter was applied to. A lot of these filters have odd data edge cases, but as long as folks apply them selectively and look at their data to see the results, it should be fine.
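
For readers following along, here is a minimal sketch of the heuristic being discussed, based only on the docstring lines quoted above (an illustration, not the actual QuotationRemover implementation):

def strip_wrapping_quotes(text: str) -> str:
    # Rule from the quoted docstring: if the document starts and ends with a
    # quotation mark and contains newlines, remove the outer quotes only when
    # the first line does not itself end with a quotation mark.
    # (Single-line behavior is not covered by the quoted lines and is left
    # unchanged here.)
    if len(text) > 1 and text.startswith('"') and text.endswith('"') and "\n" in text:
        first_line = text.splitlines()[0]
        if not first_line.rstrip().endswith('"'):
            return text[1:-1]
    return text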

@VibhuJawa (Collaborator) left a comment

LGTM.

Just a minor nit around type hints

nemo_curator/modules/splitter.py Outdated Show resolved Hide resolved
Signed-off-by: Ryan Wolf <[email protected]>
Signed-off-by: Ryan Wolf <[email protected]>
@lbliii (Contributor) left a comment

Second person + some subheader suggestions.

I think these could also benefit from use-case examples to help highlight their differences and when to use each.

docs/user-guide/syntheticdata.rst (8 review comments, resolved; outdated)
Signed-off-by: Ryan Wolf <[email protected]>
@ryantwolf merged commit eb8b613 into main on Feb 12, 2025
4 checks passed
@ryantwolf deleted the rywolf/nemotron-cc-sdg branch on February 12, 2025, 18:17