
Commit

Merge main
loadams committed Feb 7, 2025
2 parents 360712f + f4157be commit 73a7b40
Showing 38 changed files with 51,560 additions and 251 deletions.
4 changes: 2 additions & 2 deletions examples_deepspeed/finetune_hf_llama/README.md
@@ -10,9 +10,9 @@ The pre-trained weights can be found at [Hugging Face - LLAMA-7B](https://huggin

#### 1. Converting Hugging Face Model Weights to Megatron-Deepspeed Model
```bash
-bash examples_deepspeed/finetune_hf_llama/finetune_llama.sh convert
+bash examples_deepspeed/finetune_hf_llama/finetune_llama.sh convert_hf2mds
```
-This command writes the Hugging Face model weights into the Megatron-Deepspeed model and saves it. You can adjust the parallel configuration in the script.
+This command writes the Hugging Face model weights into the Megatron-Deepspeed model and saves it. You can adjust the parallel configuration in the script. ```convert_mds2hf``` can convert a Megatron-Deepspeed model into the Hugging Face format.
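For the reverse direction mentioned above, the same wrapper script can be invoked with the `convert_mds2hf` argument; a usage sketch based on the argument handling added to `finetune_llama.sh` in this commit:

```bash
bash examples_deepspeed/finetune_hf_llama/finetune_llama.sh convert_mds2hf
```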

#### 2. Fine-tuning Process
```bash
5 changes: 5 additions & 0 deletions examples_deepspeed/finetune_hf_llama/ds_config_empty.json
@@ -0,0 +1,5 @@
+{
+  "train_batch_size" : 256,
+  "train_micro_batch_size_per_gpu": 16,
+  "steps_per_print": 100
+}
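For reference, DeepSpeed expects `train_batch_size` to equal `train_micro_batch_size_per_gpu × gradient_accumulation_steps × data-parallel world size`. With the values above that implies 16 data-parallel ranks and no gradient accumulation; that ratio is an illustrative assumption, not something this config file states:

```bash
# Hypothetical sanity check of the DeepSpeed batch-size identity:
#   train_batch_size = micro_batch_per_gpu * grad_accum_steps * dp_world_size
# Assuming 16 GPUs (DP=16) and grad_accum_steps=1:
echo $((16 * 1 * 16))   # prints 256, matching "train_batch_size" above
```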
33 changes: 26 additions & 7 deletions examples_deepspeed/finetune_hf_llama/finetune_llama.sh
@@ -1,8 +1,8 @@
DS_CONFIG=./examples_deepspeed/finetune_hf_llama/ds_config.json
-DATASET_PATH=./alpaca_data.json
+DATASET_PATH=./examples_deepspeed/finetune_hf_llama/alpaca_data.json
# dataset link: https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json

-HF_LLAMA_PATH=/data/llama-7b/
+HF_LLAMA_PATH=/data/llama-2-7b-hf/
# weights link: https://huggingface.co/huggyllama/llama-7b

MICRO_BATCH_SIZE=16
@@ -43,12 +43,28 @@ cat <<EOT > $DS_CONFIG
}
EOT

if [ "$1" = "convert_hf2mds" ]; then
DS_CONFIG_PATH="./examples_deepspeed/finetune_hf_llama/ds_config_empty.json"
elif [ "$1" = "convert_mds2hf" ]; then
DS_CONFIG_PATH="./examples_deepspeed/finetune_hf_llama/ds_config_empty.json"
else
DS_CONFIG_PATH="./examples_deepspeed/finetune_hf_llama/ds_config.json"
fi

-covert_args="deepspeed tools/hf2megads_weight_converter.py \
+covert_hf2mds_args="deepspeed tools/hf2megads_weight_converter.py \
--hf-ckpt-num-shards 2 \
---origin-hf-ckpt-dir $HF_LLAMA_PATH \
+--hf-ckpt-dir $HF_LLAMA_PATH \
+--load-mode auto \
--save $MEGA_DS_LLAMA_PATH"

+covert_mds2hf_args="deepspeed tools/hf2megads_weight_converter.py \
+--hf-ckpt-num-shards 2 \
+--hf-ckpt-dir $HF_LLAMA_PATH \
+--load-mode auto \
+--to-hf-ckpt \
+--load $MEGA_DS_LLAMA_PATH \
+--save $HF_LLAMA_PATH'-hf-out' "

finetune_args="deepspeed finetune_llama.py \
--load $MEGA_DS_LLAMA_PATH"

@@ -60,6 +76,7 @@ comm_args="--tensor-model-parallel-size $TP \
--num-layers $NUM_LAYERS \
--hidden-size $HIDDEN_SIZE \
--num-attention-heads $NUM_HEADS \
+--finetune \
--ffn-hidden-size $FFN_HIDDEN_SIZE \
--attention-dropout 0 \
--hidden-dropout 0 \
@@ -88,7 +105,7 @@ comm_args="--tensor-model-parallel-size $TP \
--zero-stage 0 \
--tokenizer-type HFTokenizer \
--tokenizer-model $HF_LLAMA_PATH \
---deepspeed_config ./examples_deepspeed/finetune_hf_llama/ds_config.json \
+--deepspeed_config $DS_CONFIG_PATH \
--deepspeed \
--distributed-backend nccl \
--num-workers 0 \
@@ -98,8 +115,10 @@ comm_args="--tensor-model-parallel-size $TP \
--no-gradient-accumulation-fusion \
--repeated-dataloader"

if [ "$1" = "convert" ]; then
task_args="$covert_args"
if [ "$1" = "convert_hf2mds" ]; then
task_args="$covert_hf2mds_args"
elif [ "$1" = "convert_mds2hf" ]; then
task_args="$covert_mds2hf_args"
else
task_args="$finetune_args"
fi
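Taken together, the dispatch on `$1` above implies a three-step workflow. The following is a sketch, not part of the commit itself, and assumes the paths configured at the top of this script:

```bash
# 1. Convert the Hugging Face checkpoint into a Megatron-DeepSpeed checkpoint.
bash examples_deepspeed/finetune_hf_llama/finetune_llama.sh convert_hf2mds

# 2. Fine-tune; with no argument the script falls through to the finetune_args branch.
bash examples_deepspeed/finetune_hf_llama/finetune_llama.sh

# 3. Convert the fine-tuned checkpoint back to Hugging Face format.
#    Per the --save flag above, the result is written next to $HF_LLAMA_PATH with an '-hf-out' suffix.
bash examples_deepspeed/finetune_hf_llama/finetune_llama.sh convert_mds2hf
```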
