[post training] define llama stack post training dataset format #717

Merged: 19 commits, Jan 14, 2025
10 changes: 10 additions & 0 deletions llama_stack/apis/post_training/post_training.py
@@ -27,14 +27,24 @@ class OptimizerType(Enum):
sgd = "sgd"


@json_schema_type
class DatasetFormat(Enum):
alpaca = "alpaca"
instruct = "instruct"
chat_sharegpt = "chat_sharegpt"
chat_openai = "chat_openai"
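
# Illustrative examples (not part of the diff): one row per format, with
# hypothetical values. Column names are the ones the per-format schema
# validation in the torchtune provider expects.
#   alpaca:        {"instruction": "Summarize.", "input": "long text", "output": "short text"}
#   instruct:      {"input": "What is 2 + 2?", "output": "4"}
#   chat_sharegpt: {"conversations": [{"from": "human", "value": "Hi"}, {"from": "gpt", "value": "Hello!"}]}
#   chat_openai:   {"messages": [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]}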


@json_schema_type
class DataConfig(BaseModel):
dataset_id: str
batch_size: int
shuffle: bool
data_format: DatasetFormat
validation_dataset_id: Optional[str] = None
packed: Optional[bool] = False
train_on_input: Optional[bool] = False
column_map: Optional[Dict[str, str]] = None
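
# Illustrative sketch (not part of the diff): a DataConfig for an
# instruct-format dataset whose columns are named "prompt"/"response"
# rather than the expected "input"/"output". The dataset_id and column
# names here are hypothetical; column_map keys are the expected column
# names and values are the dataset's actual column names.
#
# DataConfig(
#     dataset_id="my_instruct_dataset",
#     batch_size=8,
#     shuffle=True,
#     data_format=DatasetFormat.instruct,
#     column_map={"input": "prompt", "output": "response"},
# )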


@json_schema_type
@@ -11,26 +11,37 @@
# the root directory of this source tree.

from enum import Enum
-from typing import Any, Callable, Dict, List
+from typing import Any, Callable, Dict, List, Optional

import torch
from llama_models.datatypes import Model
from llama_models.sku_list import resolve_model

from llama_stack.apis.common.type_system import ParamType, StringType
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import DatasetFormat

from pydantic import BaseModel
from torchtune.data._messages import (
AlpacaToMessages,
InputOutputToMessages,
OpenAIToMessages,
ShareGPTToMessages,
)

from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
from torchtune.models.llama3_2 import lora_llama3_2_3b
from torchtune.modules.transforms import Transform


class ColumnName(Enum):
instruction = "instruction"
input = "input"
output = "output"
text = "text"
conversations = "conversations"
messages = "messages"


class ModelConfig(BaseModel):
@@ -41,6 +52,9 @@ class ModelConfig(BaseModel):

class DatasetSchema(BaseModel):
alpaca: List[Dict[str, ParamType]]
instruct: List[Dict[str, ParamType]]
chat_sharegpt: List[Dict[str, ParamType]]
chat_openai: List[Dict[str, ParamType]]


MODEL_CONFIGS: Dict[str, ModelConfig] = {
@@ -56,6 +70,13 @@ class DatasetSchema(BaseModel):
),
}

DATA_FORMATS: Dict[str, Transform] = {
"alpaca": AlpacaToMessages,
"instruct": InputOutputToMessages,
"chat_sharegpt": ShareGPTToMessages,
"chat_openai": OpenAIToMessages,
}
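
# Note: the mapped values are torchtune transform classes, not instances;
# they are instantiated later with train_on_input/column_map (see
# get_data_transform below and the recipe's SFTDataset construction).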


EXPECTED_DATASET_SCHEMA = DatasetSchema(
alpaca=[
@@ -74,7 +95,23 @@ class DatasetSchema(BaseModel):
ColumnName.instruction.value: StringType(),
ColumnName.output.value: StringType(),
},
-    ]
+    ],
instruct=[
{
ColumnName.input.value: StringType(),
ColumnName.output.value: StringType(),
}
],
chat_sharegpt=[
{
ColumnName.conversations.value: StringType(),
}
],
chat_openai=[
{
ColumnName.messages.value: StringType(),
}
],
)
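
# Note: each format maps to a *list* of accepted schemas; a dataset matches a
# format if its column -> type mapping equals any entry in that list, which is
# what the membership check in validate_input_dataset_schema below relies on.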

BuildLoraModelCallable = Callable[..., torch.nn.Module]
@@ -122,10 +159,15 @@ async def get_checkpointer_model_type(
return model_config.checkpoint_type


async def get_data_transform(data_format: DatasetFormat) -> Transform:
return DATA_FORMATS[data_format.value]


async def validate_input_dataset_schema(
datasets_api: Datasets,
dataset_id: str,
dataset_type: str,
column_map: Optional[Dict[str, str]] = None,
) -> None:
dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
@@ -134,7 +176,21 @@ async def validate_input_dataset_schema(
if not hasattr(EXPECTED_DATASET_SCHEMA, dataset_type):
raise ValueError(f"Dataset type {dataset_type} is not supported.")

-    if dataset_def.dataset_schema not in getattr(EXPECTED_DATASET_SCHEMA, dataset_type):
dataset_schema = {}

if column_map:
for old_col_name in dataset_def.dataset_schema.keys():
if old_col_name in column_map.values():
new_col_name = next(
k for k, v in column_map.items() if v == old_col_name
)
dataset_schema[new_col_name] = dataset_def.dataset_schema[old_col_name]
else:
dataset_schema[old_col_name] = dataset_def.dataset_schema[old_col_name]
else:
dataset_schema = dataset_def.dataset_schema

if dataset_schema not in getattr(EXPECTED_DATASET_SCHEMA, dataset_type):
raise ValueError(
f"Dataset {dataset_id} does not have a correct input schema in {getattr(EXPECTED_DATASET_SCHEMA, dataset_type)}"
)
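
# Illustrative walk-through (not part of the diff), with hypothetical names:
# given a stored schema {"prompt": StringType(), "response": StringType()} and
# column_map={"input": "prompt", "output": "response"}, the loop above renames
# the columns back to {"input": StringType(), "output": StringType()}, which
# equals the "instruct" entry in EXPECTED_DATASET_SCHEMA, so validation passes.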
@@ -42,7 +42,7 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import modules, training, utils as torchtune_utils
-from torchtune.data import AlpacaToMessages, padded_collate_sft
+from torchtune.data import padded_collate_sft

from torchtune.modules.loss import CEWithChunkedOutputLoss
from torchtune.modules.peft import (
@@ -129,8 +129,10 @@ def model_checkpoint_dir(model) -> str:
self.seed = training.set_seed(seed=config.torch_seed)
self.epochs_run = 0
self.total_epochs = training_config.n_epochs
self._data_format = training_config.data_config.data_format
self._shuffle = training_config.data_config.shuffle
self._batch_size = training_config.data_config.batch_size
self._column_map = training_config.data_config.column_map

# this is important for debugging purposes
self.max_steps_per_epoch = training_config.max_steps_per_epoch
@@ -360,11 +362,15 @@ async def fetch_rows(dataset_id: str):
await utils.validate_input_dataset_schema(
datasets_api=self.datasets_api,
dataset_id=dataset_id,
-    dataset_type="alpaca",
+    dataset_type=self._data_format.value,
+    column_map=self._column_map,
)
data_transform = await utils.get_data_transform(self._data_format)
ds = SFTDataset(
rows,
-    message_transform=AlpacaToMessages(train_on_input=False),
+    message_transform=data_transform(
+        train_on_input=False, column_map=self._column_map
+    ),
model_transform=tokenizer,
)
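
# Illustrative note (not part of the diff): for DatasetFormat.chat_openai,
# data_transform resolves to torchtune's OpenAIToMessages, so a fetched row
# such as {"messages": [{"role": "user", "content": "..."}]} is converted to
# torchtune Message objects before model_transform (the tokenizer) runs.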
