Rwkv x eagle notebooks #75

Merged · 31 commits · Feb 6, 2024
Changes from all commits (31 commits)
4ddbf83
change logging format, to reduce confusion
pic-o Feb 3, 2024
b3a32c0
move to experiment folder
pic-o Feb 3, 2024
695545d
preparing experiment notebooks
pic-o Feb 3, 2024
d4e2d3e
Added notes on running the runner
pic-o Feb 3, 2024
fea8459
Merge remote-tracking branch 'origin/main' into rwkv-x-eagle-notebooks
pic-o Feb 3, 2024
d9a6211
include any option
pic-o Feb 3, 2024
3d5fc0f
wip benchmarks
PicoCreator Feb 4, 2024
9c7744c
tweak
PicoCreator Feb 4, 2024
59dc19e
Merge branch 'rwkv-x-eagle-notebooks' of https://github.com/RWKV/RWKV…
PicoCreator Feb 4, 2024
c42c7c0
Fixing the LR for batched
pic-o Feb 4, 2024
fe9bd2c
Merge branch 'rwkv-x-eagle-notebooks' of https://github.com/RWKV/RWKV…
pic-o Feb 4, 2024
475558d
fixing mask sum calc
pic-o Feb 4, 2024
0fd1387
wip calibration
PicoCreator Feb 4, 2024
a2443d7
enwiki 16k test
PicoCreator Feb 4, 2024
85ce0c1
Update notebook title in enwiki-16k-3e-5.ipynb
PicoCreator Feb 4, 2024
18f52d0
Fixing multipack
PicoCreator Feb 4, 2024
ade9b35
loss validation run
PicoCreator Feb 4, 2024
efced7f
WIP datapack fixing code
PicoCreator Feb 4, 2024
e5deb69
Update learning rate initialization and finalization values
PicoCreator Feb 4, 2024
5e1b715
WIP benchmarks
PicoCreator Feb 4, 2024
6d7ca1c
drop dataset_name and dataset_index
PicoCreator Feb 4, 2024
dd9c79d
Merge branch 'rwkv-x-eagle-notebooks' of https://github.com/RWKV/RWKV…
PicoCreator Feb 4, 2024
0ee162d
the datapack code
PicoCreator Feb 4, 2024
44fb0be
support custom dataset split
PicoCreator Feb 4, 2024
6f1e02e
config update
PicoCreator Feb 4, 2024
601986c
WIP tweaks
PicoCreator Feb 4, 2024
dfccbc7
wip MultiPack train
PicoCreator Feb 4, 2024
74deedb
tweaks
PicoCreator Feb 6, 2024
429c8c1
prototype train/test split swap - because sometimes you need that
pic-o Feb 6, 2024
0654172
Merge branch 'rwkv-x-eagle-notebooks' of https://github.com/RWKV/RWKV…
pic-o Feb 6, 2024
e02907a
fixing multi-gpu sync
pic-o Feb 6, 2024
5 changes: 5 additions & 0 deletions RWKV-v5/config-example.yaml
@@ -373,6 +373,10 @@ data:
# source: "teven/enwiki_00k" # Hugging face dataset
# source: text # Text mode, used with source_data_dir

# Dataset split to use from HF dataset
# ---
# source_dataset_split: train

# Additional source dataset params, used to grab subsets of the dataset
# ---
# source_dataset_params:
@@ -419,6 +423,7 @@ data:

# Custom text column to use, useful for datasets with alternative training column labels
# This is checked before multi-column merging, default is null (disabled)
# If set, this takes priority
# eg: 'code'
# ---
# custom_text_key: 'code'
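The two options documented above (`source_dataset_split` and `custom_text_key`) map onto HuggingFace `datasets` split and column selection. A minimal sketch of how they might be resolved, assuming a plain `load_dataset` call; the `config` dict and lookup logic here are illustrative, not the trainer's actual internals:

```python
# Minimal sketch (assumption): resolving `source_dataset_split` and
# `custom_text_key` against a HuggingFace dataset. Names are illustrative.
from datasets import load_dataset

config = {
    "source": "teven/enwiki_100k",
    "source_dataset_split": "train",  # which HF split to pull
    "custom_text_key": "text",        # column to train on, checked before multi-column merging
}

dataset = load_dataset(config["source"], split=config["source_dataset_split"])

# Fall back to the default 'text' column when no custom key is set
text_key = config.get("custom_text_key") or "text"
print(dataset[0][text_key][:200])
```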
11 changes: 8 additions & 3 deletions RWKV-v5/datapack-example.yaml
@@ -35,15 +35,20 @@ datapack:

# Mixing mode to use, this is used to alternate between datasets
#
# - batch : Meaning one dataset worth per batch, partial batches are discarded
# - sample : Dataset is mixed on a per sample level
mixing_mode: "batch"
# - concat : Keep It Simple Silly, let's just concat the datasets together
# - shuffle : Dataset is mixed on a per-sample level
#
# (@TODO: Advanced operations)
# - batch : Meaning one dataset's worth per batch, partial batches are discarded
mixing_mode: "shuffle"

# # Mixing distribution to use
# # - weighted : Dataset batches/mixture is distributed randomly, but weighted by dataset size
# # - uniform : Dataset batches/mixture is distributed randomly, but with uniform probability
# distribution: "weighted"

# (@TODO: Advanced operations)
#
# Mixed batch percentage
#
# % of batches which will contain a mixture of records from multiple datasets
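To make the new modes concrete, here is a minimal sketch of what "concat" versus "shuffle" mixing amounts to, expressed with HuggingFace `datasets` directly; the actual datapack builder lives in `RWKV-v5/src/data.py` and does considerably more than this:

```python
# Minimal sketch (assumption): "concat" vs "shuffle" dataset mixing.
from datasets import Dataset, concatenate_datasets

ds_a = Dataset.from_dict({"text": [f"A{i}" for i in range(4)]})
ds_b = Dataset.from_dict({"text": [f"B{i}" for i in range(4)]})

# "concat": keep it simple, one dataset after the other
concat_mix = concatenate_datasets([ds_a, ds_b])

# "shuffle": records from all datasets interleaved at the per-sample level
shuffle_mix = concatenate_datasets([ds_a, ds_b]).shuffle(seed=42)

print(concat_mix["text"])   # ['A0', 'A1', 'A2', 'A3', 'B0', 'B1', 'B2', 'B3']
print(shuffle_mix["text"])  # same records, mixed order
```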
418 changes: 232 additions & 186 deletions RWKV-v5/src/data.py

Large diffs are not rendered by default.

83 changes: 43 additions & 40 deletions RWKV-v5/src/model.py
@@ -812,14 +812,14 @@ def compute_loss(self, batch, batch_idx, is_training_run: bool = False, is_valid
assert isinstance(seq, torch.Tensor) and seq.ndim == 2
ori_seq_mask = batch['attention_mask']

# Get the dataset index
dataset_index = 0
dataset_name = "dataset_0"
if "dataset_index" in batch:
dataset_index = batch["dataset_index"]
dataset_name = f"dataset_{dataset_index}"
if "dataset_name" in batch and dataset_name is not None:
dataset_name = batch["dataset_name"]
# # Get the dataset index
# dataset_index = 0
# dataset_name = "dataset_0"
# if "dataset_index" in batch:
# dataset_index = batch["dataset_index"]
# dataset_name = f"dataset_{dataset_index}"
# if "dataset_name" in batch and dataset_name is not None:
# dataset_name = batch["dataset_name"]

# Check if the attention mask is set, if not initialize it
if ori_seq_mask is None or ori_seq_mask.ndim != 2:
@@ -913,21 +913,24 @@ def compute_loss(self, batch, batch_idx, is_training_run: bool = False, is_valid

# If total_mask_sum is 0, we skip, as there are no tokens of value to learn from anyway
total_mask_sum = torch.sum(seq_mask)
# Do a quick return, if there is no tokens of value to learn from due to full masking
if num_devices > 1 and total_mask_sum == 0:
return 0
avg_mask_sum = ( total_mask_sum / B )

# # Do a quick return, if there are no tokens of value to learn from due to full masking
# # DO NOT DO THIS : This causes multi node / multi GPU to go out of sync
# if num_devices <= 1 and total_mask_sum == 0:
# return 0

# Checkpoint steps
def checkpointed_step(idx, targets, mask, last_shift_states,
last_wkv_states):
# Skip if there is no tokens of value to learn from
if idx.shape[1] == 0:
# Prepare dummy loss
train_loss = torch.tensor(0, dtype=self.emb.weight.dtype).requires_grad_()
sample_loss = train_loss.clone().detach().requires_grad_(False)
# # Skip if there is no tokens of value to learn from
# if idx.shape[1] == 0:
# # Prepare dummy loss
# train_loss = torch.tensor(0, dtype=self.emb.weight.dtype).requires_grad_()
# sample_loss = train_loss.clone().detach().requires_grad_(False)

# Return the checkpoint values
return sample_loss, train_loss, last_shift_states, last_wkv_states, 0
# # Return the checkpoint values
# return sample_loss, train_loss, last_shift_states, last_wkv_states, 0

# Get the logits, and the new states
logits, new_shift_states, new_wkv_states = self(
@@ -947,7 +950,7 @@ def checkpointed_step(idx, targets, mask, last_shift_states,

# to encourage the logits to be close to 0
# factor_divisor is typically the total token count
L2Wrap_factor = 1e-4 / total_mask_sum
L2Wrap_factor = 1e-4 / avg_mask_sum

# Submask count
submask_count = torch.sum(submask)
@@ -983,7 +986,7 @@ def checkpointed_step(idx, targets, mask, last_shift_states,
train_token_count = torch.sum(train_mask)

# Adjust the factor accordingly
L2Wrap_factor = L2Wrap_factor * (submask_count / train_token_count)
# L2Wrap_factor = L2Wrap_factor * (submask_count / train_token_count)

else:
train_loss = torch.sum(token_loss * submask) / total_mask_sum
@@ -1254,20 +1257,20 @@ def checkpointed_step(idx, targets, mask, last_shift_states,
# Log the line values
wandb.log({
# The original loss and ctx_len (averaged by batch size)
'train/ctx_len': ctx_len,
'train/data_ctxlen': ctx_len,
'train/data_loss': sampling_loss,
# "train/dataset_index": dataset_index,

# The selective training tokens, and loss
'train/tokens': tokens,
'train/loss': training_loss,
"train/dataset_index": dataset_index,
'train/learn_tokens': tokens,
'train/learn_loss': training_loss,

# Dataset based tracking
f'dataset/train/{dataset_index}.loss': training_loss,
f'dataset/train/{dataset_index}.data_loss': sampling_loss,
f'dataset/train/{dataset_index}.tokens': tokens,
f'dataset/train/{dataset_index}.ctx_len': ctx_len,
f'dataset/train/{dataset_index}.name': dataset_name,
# # Dataset based tracking (not working)
# f'dataset/train/{dataset_index}.loss': training_loss,
# f'dataset/train/{dataset_index}.data_loss': sampling_loss,
# f'dataset/train/{dataset_index}.tokens': tokens,
# f'dataset/train/{dataset_index}.ctx_len': ctx_len,
# f'dataset/train/{dataset_index}.name': dataset_name,

# Perf tracking
f'perf/kTokens_per_sec.gpu.{global_rank}': self._counting_tokens / max(time.time() - self._counting_time_start, 1),
@@ -1286,19 +1289,19 @@ def checkpointed_step(idx, targets, mask, last_shift_states,
# Log the line values
wandb.log({
# The original loss and ctx_len (averaged by batch size)
'validation/ctx_len': T,
'validation/data_ctxlen': T,
'validation/data_loss': sampling_loss,
# "validation/dataset_index": dataset_index,

# The selective training tokens, and loss
'validation/tokens': training_tokens,
'validation/loss': training_loss,
"validation/dataset_index": dataset_index,

# Dataset based tracking
f'dataset/validation/{dataset_index}.loss': training_loss,
f'dataset/validation/{dataset_index}.data_loss': sampling_loss,
f'dataset/validation/{dataset_index}.ctx_len': T,
f'dataset/validation/{dataset_index}.name': dataset_name,
'validation/learn_tokens': training_tokens,
'validation/learn_loss': training_loss,

# # Dataset based tracking (not working)
# f'dataset/validation/{dataset_index}.loss': training_loss,
# f'dataset/validation/{dataset_index}.data_loss': sampling_loss,
# f'dataset/validation/{dataset_index}.ctx_len': T,
# f'dataset/validation/{dataset_index}.name': dataset_name,

# Step and trainer tracking
'global_rank': global_rank,
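The key behavioural change in `compute_loss` is dropping the early `return 0` on fully masked batches. Under multi-GPU / multi-node training, every rank must run the same forward and backward passes and hit the same collective operations, so one rank bailing out while the others proceed stalls or desyncs the job. A minimal sketch of the failure pattern and the safer alternative, illustrative only and not the trainer's actual code:

```python
# Minimal sketch (assumption): why an early return on one rank breaks
# multi-GPU training, and the usual fix of always running the forward and
# backward pass so the collective ops stay aligned across ranks.
import torch

def bad_step(model, batch, mask):
    # If only SOME ranks take this branch, the other ranks block on a
    # gradient all-reduce that this rank never joins -> hang / desync.
    if mask.sum() == 0:
        return torch.tensor(0.0)
    loss = (model(batch) * mask).sum() / mask.sum()
    loss.backward()
    return loss

def safe_step(model, batch, mask):
    # Every rank always runs forward + backward; a fully masked batch just
    # contributes (numerically) nothing, but the collectives stay aligned.
    out = model(batch)
    denom = mask.sum().clamp(min=1)  # avoid division by zero on empty masks
    loss = (out * mask).sum() / denom
    loss.backward()
    return loss
```

Relatedly, the L2Wrap factor is now divided by the per-sample average mask count (`total_mask_sum / B`) rather than the batch total, so its strength no longer shrinks as the batch size grows, and the W&B metrics are renamed to `data_ctxlen` / `data_loss` and `learn_tokens` / `learn_loss` while the per-dataset metrics are commented out.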
6 changes: 3 additions & 3 deletions docker/github-worker-cuda-11-8/entrypoint.sh
@@ -27,7 +27,7 @@ else
--token "${RUNNER_TOKEN}" \
--name "${RUNNER_NAME}" \
--replace \
--labels "nolane,${CUDA_VER},${RUNNER_LABELS}"
--labels "nolane,any-gpu,gpu-count-any,${CUDA_VER},${RUNNER_LABELS}"

# Run it in background, and get the PID
./run.sh &
@@ -41,7 +41,7 @@ else
--token "${RUNNER_TOKEN}" \
--name "${RUNNER_NAME}-lane1" \
--replace \
--labels "lane1,${CUDA_VER},${RUNNER_LABELS}"
--labels "lane1,any-gpu,gpu-count-any,${CUDA_VER},${RUNNER_LABELS}"

# Run it in background, and get the PID
./run.sh &
@@ -55,7 +55,7 @@ else
--token "${RUNNER_TOKEN}" \
--name "${RUNNER_NAME}-lane2" \
--replace \
--labels "lane2,${CUDA_VER},${RUNNER_LABELS}"
--labels "lane2,any-gpu,gpu-count-any,${CUDA_VER},${RUNNER_LABELS}"

# Run it in background, and get the PID
./run.sh &
4 changes: 2 additions & 2 deletions notebook/finetune-example/Eagle-x-ALMA-prompt-completion.yaml
@@ -74,8 +74,8 @@ model:
load_model: ../model/L6-D512-neox-init.pth

# Starting and ending learning rate
lr_init: 1e-5
lr_final: 1e-5
lr_init: 3e-5
lr_final: 3e-5

# Training context length, note that the dataset can be
# larger than the context size, in which the trainer
4 changes: 2 additions & 2 deletions notebook/finetune-example/Eagle-x-capybara-chat.yaml
@@ -79,8 +79,8 @@ model:
load_model: ../model/L6-D512-neox-init.pth

# Starting and ending learning rate
lr_init: 1e-5
lr_final: 1e-5
lr_init: 3e-5
lr_final: 3e-5

# Training context length, note that the dataset can be
# larger than the context size, in which the trainer
4 changes: 2 additions & 2 deletions notebook/finetune-example/Eagle-x-openhermes1-instruct.yaml
@@ -74,8 +74,8 @@ model:
load_model: ../model/L6-D512-neox-init.pth

# Starting and ending learning rate
lr_init: 1e-5
lr_final: 1e-5
lr_init: 3e-5
lr_final: 3e-5

# Training context length, note that the dataset can be
# larger than the context size, in which the trainer
4 changes: 2 additions & 2 deletions notebook/finetune-example/Eagle-x-textbooks.yaml
@@ -79,8 +79,8 @@ model:
load_model: ../model/L6-D512-neox-init.pth

# Starting and ending learning rate
lr_init: 1e-5
lr_final: 1e-5
lr_init: 3e-5
lr_final: 3e-5

# Training context length, note that the dataset can be
# larger than the context size, in which the trainer
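For reference, `lr_init` and `lr_final` bound the learning-rate schedule; setting both to the same value (here bumped from 1e-5 to 3e-5) gives a flat learning rate for the whole run. A hedged sketch, assuming a simple linear interpolation between the two values (the trainer's actual schedule shape may differ):

```python
# Minimal sketch (assumption): equal lr_init / lr_final means a constant LR.
def lr_at(step: int, total_steps: int, lr_init: float, lr_final: float) -> float:
    frac = min(step / max(total_steps, 1), 1.0)
    return lr_init + (lr_final - lr_init) * frac

# With the new settings (3e-5 -> 3e-5) the rate is constant throughout:
print(lr_at(0, 1000, 3e-5, 3e-5), lr_at(500, 1000, 3e-5, 3e-5))  # 3e-05 3e-05
```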
87 changes: 56 additions & 31 deletions notebook/finetune-example/Eagle-x-zMultipack-build.yaml
@@ -25,35 +25,14 @@ datapack:
# secret: <example S3 secret>
# endpoint_url: <example S3 endpoint>

# Batch size to use to alternate between datasets
# This should be a multiple of the GPU and node count
#
# Uses, `8 * (3 * 4 * 5 * 6 * 7) = 20160` for default, as it should align across
# a large number of batch size combinations. This helps reduce the amount of
# misaligned batches, and thus reduce the amount of wasted training time
batchsize: 512

# Mixing mode to use, this is used to alternate between datasets
#
# - batch : Meaning one dataset worth per batch, partial batches are discarded
# - sample : Dataset is mixed on a per sample level
mixing_mode: "batch"

# # Mixing distribution to use
# # - weighted : Dataset batches/mixture is distrbuted randomly, but weighted by dataset size
# # - uniform : Dataset batches/mixture is distrbuted randomly, but with uniform probability
# distribution: "weighted"

# Mixed batch percentage
#
# % of batches which will contain a mixture of records from multiple datasets
# instad of limiting each batch to a single dataset
# - concat : Keep It Simple Silly, let's just concat the datasets together
# - shuffle : Dataset is mixed on a per-sample level
#
# Use 0, to disable mixed batches, sampled mixing_mode is the equavalent of mixed batch 1.0
#
# NOTE: This is a guideline percentage, and is not guaranteed to be exact
# if a partial batch is built, it may get converted to a mixed batch
mixed_batch_percentage: 0.5
# (@TODO: Advanced operations)
# - batch : Meaning one dataset's worth per batch, partial batches are discarded
mixing_mode: "shuffle"

#
# Default settings used across all datasets in the datapack
@@ -115,7 +94,7 @@ default:
# If given an int value, the number of data samples is used.
#
# Due to the limitations in the trainer process, there is always a minimum of 1 test sample
test_split: 8 # Intentionally set to a low sample for test, cause the real eval is humans
test_split: 0.01 # Intentionally set to a low sample for test, cause the real eval is humans
test_split_shuffle: true

# Tokenizer to use, use either the inbuilt 'neox', or 'world' tokenizer
@@ -265,10 +244,14 @@ dataset:

- # Text book is all you need
# https://huggingface.co/datasets/TanvirOnHF/muse_textbooks
source: "TanvirOnHF/muse_textbooks"
source: "teven/enwiki_100k"

# Optional, provide a name for the dataset
name: "muse_textbooks"
name: "enwiki_100k"

# Minimum / Maximum token size of the dataset to use
min_token_size: 1024
max_token_size: -1

# Various overwrite settings
# ---
@@ -277,6 +260,29 @@
packing_enable: False
max_token_size: -1

- # SuperWiki (Multi-lingual)
# https://huggingface.co/datasets/RyokoExtra/SuperWIKI-Cleaned
source: "RyokoExtra/SuperWIKI-Cleaned"

# Optional, provide a name for the dataset
name: "super_wiki"

# Various overwrite settings
# ---
text_rechunk_size: 32768
text_rechunk_force: true
packing_enable: False
max_token_size: -1

source_dataset_split: lang25

# Custom text column to use, useful for datasets with alternative training column labels
# This is checked before multi-column merging, default is null (disabled)
# If set, this takes priority
# eg: 'code'
# ---
custom_text_key: 'text'

# All other settings found in default can be overridden here
# ---
# ...
@@ -297,8 +303,7 @@
# https://huggingface.co/datasets/kristaller486/ALMA-prompt-completion
source: "kristaller486/ALMA-prompt-completion"
name: "ALMA-prompt-completion"

# Prompt completion, notiong else
# Prompt completion, nothing else

- # Instruct, input, output format
# https://huggingface.co/datasets/teknium/openhermes
@@ -380,6 +385,26 @@
conversation_input_key_mask: {'input': false, 'output': true}
conversation_sender_suffix: {'input': "", 'output': ""}

######################################################
# Note: We found the ML-generated textbooks
# so low in perplexity that they hurt the model,
# so we are using the original enwiki_100k & superwiki
######################################################
# - # Text book is all you need
# # https://huggingface.co/datasets/TanvirOnHF/muse_textbooks
# source: "TanvirOnHF/muse_textbooks"

# # Optional, provide a name for the dataset
# name: "muse_textbooks"

# # Various overwrite settings
# # ---
# text_rechunk_size: 32768
# text_rechunk_force: true
# packing_enable: False
# max_token_size: -1
######################################################




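As a rough illustration of what the updated dataset entries ask the preprocessor to do (a fractional train/test split, a minimum token length for enwiki_100k, and a specific split plus text column for SuperWIKI), here is a hedged sketch in plain HuggingFace `datasets` terms; the whitespace token count stands in for the real tokenizer and none of this is the trainer's actual code:

```python
# Minimal sketch (assumption): approximating the new datapack entries with
# plain HuggingFace `datasets`. The real pipeline tokenizes with the chosen
# tokenizer and rechunks/packs the data; this only mirrors the selection logic.
from datasets import load_dataset

# enwiki_100k: drop records shorter than ~1024 tokens
# (whitespace split used here as a stand-in for the real tokenizer count)
enwiki = load_dataset("teven/enwiki_100k", split="train")
enwiki = enwiki.filter(lambda r: len(r["text"].split()) >= 1024)

# SuperWIKI-Cleaned: pull the `lang25` split and train on its `text` column
superwiki = load_dataset("RyokoExtra/SuperWIKI-Cleaned", split="lang25")

# test_split: 0.01 -> hold out 1% of samples, shuffled, as the test set
split = enwiki.train_test_split(test_size=0.01, shuffle=True, seed=42)
train_ds, test_ds = split["train"], split["test"]
print(len(train_ds), len(test_ds))
```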