[post training] define llama stack post training dataset format #717

Merged: 19 commits, Jan 14, 2025
6 changes: 6 additions & 0 deletions llama_stack/apis/common/type_system.py
@@ -54,6 +54,12 @@ class AgentTurnInputType(BaseModel):
type: Literal["agent_turn_input"] = "agent_turn_input"


class DialogType(BaseModel):
# expects List[Message] for messages
# this type semantically contains the output label whereas ChatCompletionInputType does not
type: Literal["dialog"] = "dialog"


ParamType = register_schema(
Annotated[
Union[
7 changes: 7 additions & 0 deletions llama_stack/apis/post_training/post_training.py
@@ -27,11 +27,18 @@ class OptimizerType(Enum):
sgd = "sgd"


@json_schema_type
class DatasetFormat(Enum):
instruct = "instruct"
dialog = "dialog"


@json_schema_type
class DataConfig(BaseModel):
dataset_id: str
batch_size: int
shuffle: bool
data_format: DatasetFormat
validation_dataset_id: Optional[str] = None
packed: Optional[bool] = False
train_on_input: Optional[bool] = False
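As a hedged illustration (not part of the diff), a client could fill in the new field roughly like this, assuming DataConfig and DatasetFormat are both importable from llama_stack.apis.post_training as the import later in this PR suggests; the dataset id is hypothetical:

from llama_stack.apis.post_training import DataConfig, DatasetFormat

# Minimal sketch: point the trainer at a registered dataset and declare its format.
data_config = DataConfig(
    dataset_id="my_instruct_dataset",   # hypothetical registered dataset id
    batch_size=4,
    shuffle=True,
    data_format=DatasetFormat.instruct,  # or DatasetFormat.dialog for multi-turn data
)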
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
52 changes: 52 additions & 0 deletions llama_stack/providers/inline/post_training/common/validator.py
@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import (
ChatCompletionInputType,
DialogType,
StringType,
)
from llama_stack.apis.datasets import Datasets
from llama_stack.providers.utils.common.data_schema_validator import (
ColumnName,
validate_dataset_schema,
)

EXPECTED_DATASET_SCHEMA = {
"instruct": [
{
ColumnName.chat_completion_input.value: ChatCompletionInputType(),
ColumnName.expected_answer.value: StringType(),
}
],
"dialog": [
{
ColumnName.dialog.value: DialogType(),
}
],
}


async def validate_input_dataset_schema(
datasets_api: Datasets,
dataset_id: str,
dataset_type: str,
) -> None:
dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")

if dataset_type not in EXPECTED_DATASET_SCHEMA:
raise ValueError(f"Dataset type {dataset_type} is not supported.")

validate_dataset_schema(
dataset_def.dataset_schema, EXPECTED_DATASET_SCHEMA[dataset_type]
)
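For orientation, here are illustrative rows (not taken from the PR) that would satisfy the two expected schemas; following the format adapter added below, the message lists are stored as stringified Python literals:

# Hypothetical "instruct" row: one user message plus the expected answer.
instruct_row = {
    "chat_completion_input": '[{"role": "user", "content": "What is 2 + 2?"}]',
    "expected_answer": "4",
}

# Hypothetical "dialog" row: a multi-turn conversation that includes the assistant labels.
dialog_row = {
    "dialog": '[{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello! How can I help?"}]',
}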
@@ -10,29 +10,22 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import Enum
from typing import Any, Callable, Dict, List
from typing import Any, Callable, Dict

import torch
from llama_models.datatypes import Model
from llama_models.sku_list import resolve_model

from pydantic import BaseModel
from torchtune.data._messages import InputOutputToMessages, ShareGPTToMessages

from torchtune.models.llama3 import llama3_tokenizer
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
from torchtune.models.llama3_1 import lora_llama3_1_8b
from torchtune.models.llama3_2 import lora_llama3_2_3b
from torchtune.modules.transforms import Transform

from llama_stack.apis.common.type_system import ParamType, StringType
from llama_stack.apis.datasets import Datasets


class ColumnName(Enum):
instruction = "instruction"
input = "input"
output = "output"
text = "text"
from llama_stack.apis.post_training import DatasetFormat


class ModelConfig(BaseModel):
@@ -41,10 +41,6 @@ class ModelConfig(BaseModel):
checkpoint_type: str


class DatasetSchema(BaseModel):
alpaca: List[Dict[str, ParamType]]


MODEL_CONFIGS: Dict[str, ModelConfig] = {
"Llama3.2-3B-Instruct": ModelConfig(
model_definition=lora_llama3_2_3b,
@@ -58,26 +58,11 @@ class DatasetSchema(BaseModel):
),
}

DATA_FORMATS: Dict[str, Transform] = {
"instruct": InputOutputToMessages,
"dialog": ShareGPTToMessages,
}

EXPECTED_DATASET_SCHEMA = DatasetSchema(
alpaca=[
{
ColumnName.instruction.value: StringType(),
ColumnName.input.value: StringType(),
ColumnName.output.value: StringType(),
ColumnName.text.value: StringType(),
},
{
ColumnName.instruction.value: StringType(),
ColumnName.input.value: StringType(),
ColumnName.output.value: StringType(),
},
{
ColumnName.instruction.value: StringType(),
ColumnName.output.value: StringType(),
},
]
)

BuildLoraModelCallable = Callable[..., torch.nn.Module]
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
@@ -124,19 +98,5 @@ async def get_checkpointer_model_type(
return model_config.checkpoint_type


async def validate_input_dataset_schema(
datasets_api: Datasets,
dataset_id: str,
dataset_type: str,
) -> None:
dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")

if not hasattr(EXPECTED_DATASET_SCHEMA, dataset_type):
raise ValueError(f"Dataset type {dataset_type} is not supported.")

if dataset_def.dataset_schema not in getattr(EXPECTED_DATASET_SCHEMA, dataset_type):
raise ValueError(
f"Dataset {dataset_id} does not have a correct input schema in {getattr(EXPECTED_DATASET_SCHEMA, dataset_type)}"
)
async def get_data_transform(data_format: DatasetFormat) -> Transform:
return DATA_FORMATS[data_format.value]
@@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Mapping

from llama_stack.providers.utils.common.data_schema_validator import ColumnName


def llama_stack_instruct_to_torchtune_instruct(
sample: Mapping[str, Any]
) -> Mapping[str, Any]:
assert (
ColumnName.chat_completion_input.value in sample
and ColumnName.expected_answer.value in sample
), "Invalid input row"
input_messages = eval(str(sample[ColumnName.chat_completion_input.value]))

assert (
len(input_messages) == 1
), "llama stack intruct dataset format only supports 1 user message"
input_message = input_messages[0]

assert "content" in input_message, "content not found in input message"
input = input_message["content"]
output = sample[ColumnName.expected_answer.value]

return {
"input": input,
"output": output,
}


def llama_stack_chat_to_torchtune_chat(sample: Mapping[str, Any]) -> Mapping[str, Any]:
assert ColumnName.dialog.value in sample, "Invalid input row"
role_map = {"user": "human", "assistant": "gpt"}
dialog = eval(str(sample[ColumnName.dialog.value]))

assert len(dialog) > 1, "dialog must have at least 2 messages"
roles = []
conversations = []
for message in dialog:
assert (
"role" in message and "content" in message
), "role and content must in message"
roles.append(message["role"])
conversations.append(
{"from": role_map[message["role"]], "value": message["content"]}
)

assert roles[0] == "user", "first message must be from user"
assert "assistant" in roles, "at least 1 message should be from assistant"

return {"conversations": conversations}
@@ -19,17 +19,24 @@
from torchtune.data._messages import validate_messages
from torchtune.modules.transforms import Transform

from llama_stack.providers.inline.post_training.torchtune.datasets.format_adapter import (
llama_stack_chat_to_torchtune_chat,
llama_stack_instruct_to_torchtune_instruct,
)


class SFTDataset(Dataset):
def __init__(
self,
rows: List[Dict[str, Any]],
message_transform: Transform,
model_transform: Transform,
dataset_type: str,
) -> None:
self._rows = rows
self._message_transform = message_transform
self._model_transform = model_transform
self._dataset_type = dataset_type

def __len__(self):
return len(self._rows)
@@ -39,6 +46,12 @@ def __getitem__(self, index: int) -> Dict[str, Any]:
return self._prepare_sample(sample)

def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]:
if self._dataset_type == "instruct":
sample = llama_stack_instruct_to_torchtune_instruct(sample)
elif self._dataset_type == "dialog":
sample = llama_stack_chat_to_torchtune_chat(sample)
else:
raise ValueError(f"Invalid dataset type: {self._dataset_type}")
transformed_sample = self._message_transform(sample)
if "messages" in transformed_sample:
validate_messages(transformed_sample["messages"])
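A rough sketch of how the new dataset_type flag is meant to be used with this class, assuming an instruct-format row and a locally available Llama 3 tokenizer model; the recipe change below does the equivalent wiring with rows fetched through the datasetio API:

from torchtune.data._messages import InputOutputToMessages
from torchtune.models.llama3 import llama3_tokenizer

tokenizer = llama3_tokenizer("/path/to/tokenizer.model")  # assumed local path
rows = [
    {
        "chat_completion_input": '[{"role": "user", "content": "What is 2 + 2?"}]',
        "expected_answer": "4",
    }
]
# SFTDataset as defined above: each row is first converted to torchtune's
# input/output form, then passed through the message and model transforms.
ds = SFTDataset(
    rows,
    message_transform=InputOutputToMessages(train_on_input=False),
    model_transform=tokenizer,
    dataset_type="instruct",
)
sample = ds[0]  # a tokenized sample ready for collation with padded_collate_sft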
@@ -18,7 +18,7 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import modules, training, utils as torchtune_utils
from torchtune.data import AlpacaToMessages, padded_collate_sft
from torchtune.data import padded_collate_sft

from torchtune.modules.loss import CEWithChunkedOutputLoss
from torchtune.modules.peft import (
@@ -47,6 +47,9 @@
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR

from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.inline.post_training.common.validator import (
validate_input_dataset_schema,
)

from llama_stack.providers.inline.post_training.torchtune.common import utils
from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
@@ -129,8 +132,10 @@ def model_checkpoint_dir(model) -> str:
self.seed = training.set_seed(seed=config.torch_seed)
self.epochs_run = 0
self.total_epochs = training_config.n_epochs
self._data_format = training_config.data_config.data_format
self._shuffle = training_config.data_config.shuffle
self._batch_size = training_config.data_config.batch_size
self._train_on_input = training_config.data_config.train_on_input

# this is important for debugging purposes
self.max_steps_per_epoch = training_config.max_steps_per_epoch
@@ -354,18 +359,17 @@ async def fetch_rows(dataset_id: str):
all_rows = await fetch_rows(dataset_id)
rows = all_rows.rows

# Currently only support alpaca instruct dataset
# TODO @SLR722 make the message_transform swappable and support more dataset types
# TODO @SLR722 make the input dataset schema more flexible by exposing column_map
await utils.validate_input_dataset_schema(
await validate_input_dataset_schema(
datasets_api=self.datasets_api,
dataset_id=dataset_id,
dataset_type="alpaca",
dataset_type=self._data_format.value,
)
data_transform = await utils.get_data_transform(self._data_format)
ds = SFTDataset(
rows,
message_transform=AlpacaToMessages(train_on_input=False),
message_transform=data_transform(train_on_input=self._train_on_input),
model_transform=tokenizer,
dataset_type=self._data_format.value,
)

sampler = DistributedSampler(
@@ -23,6 +23,7 @@ class ColumnName(Enum):
completion_input = "completion_input"
generated_answer = "generated_answer"
context = "context"
dialog = "dialog"


VALID_SCHEMAS_FOR_SCORING = [
3 changes: 3 additions & 0 deletions llama_stack/templates/experimental-post-training/build.yaml
@@ -13,6 +13,7 @@ distribution_spec:
post_training:
- inline::torchtune
datasetio:
- inline::localfs
- remote::huggingface
telemetry:
- inline::meta-reference
@@ -22,4 +23,6 @@ distribution_spec:
- inline::llama-guard
memory:
- inline::faiss
tool_runtime:
- remote::brave-search
image_type: conda