Commit

clean up old files for public

mzio committed Oct 14, 2024
1 parent f923ba9 commit 11dbe91
Showing 113 changed files with 208 additions and 16,964 deletions.
README.md — 211 changes: 109 additions & 102 deletions

Large diffs are not rendered by default.

Config diff (dataset settings):

@@ -9,7 +9,7 @@ dataset:
   max_train_samples: 50000
   max_eval_num: 1000
   max_length: 32768
-  min_length: 1048
+  min_length: 1024
   chat_template: llama-3
   chunk_size: 1024 # sequence length for distilling
   seed: 42
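For context, a minimal sketch of how length fields like these are commonly applied when preparing distillation data. This is an assumption for illustration, not code from this repository, and the helper below is hypothetical.

# Hypothetical helper (not from this repo): filter tokenized samples by length
# and split them into fixed-size chunks for distillation.
from typing import Iterator, List

MIN_LENGTH = 1024    # changed from 1048 to 1024 in this commit
MAX_LENGTH = 32768
CHUNK_SIZE = 1024    # "sequence length for distilling" per the YAML comment


def chunk_sample(token_ids: List[int]) -> Iterator[List[int]]:
    """Drop out-of-range samples, then yield CHUNK_SIZE-token chunks."""
    if not (MIN_LENGTH <= len(token_ids) <= MAX_LENGTH):
        return
    for start in range(0, len(token_ids), CHUNK_SIZE):
        yield token_ids[start:start + CHUNK_SIZE]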
Config diff (Llama 3.1 70B model settings):
@@ -1,7 +1,7 @@
 name: llama
 model:
   pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-70B"
-  cache_dir: "/home/mzhang/models/llama-3_1-70b" # Set this to where you want to save checkpoint weights
+  cache_dir: "/scr/mzhang/models/llama-3_1-70b" # Set this to where you want to save checkpoint weights
   return_dict: true
   load_in_8bit: false
   load_in_4bit: false
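As a rough guide, fields like pretrained_model_name_or_path and cache_dir typically map onto Hugging Face Transformers' from_pretrained arguments. The sketch below shows that assumed mapping; it is not this repository's loading code.

# Minimal sketch, assuming the YAML above is read into a plain dict and the
# model is fetched with Hugging Face Transformers.
from transformers import AutoModelForCausalLM

model_cfg = {
    "pretrained_model_name_or_path": "meta-llama/Meta-Llama-3.1-70B",
    "cache_dir": "/scr/mzhang/models/llama-3_1-70b",  # where checkpoint weights are cached
    "return_dict": True,
}

model = AutoModelForCausalLM.from_pretrained(
    model_cfg["pretrained_model_name_or_path"],
    cache_dir=model_cfg["cache_dir"],       # overrides the default HF cache location
    return_dict=model_cfg["return_dict"],   # forwarded to the model config
)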
distill_llama.py — 31 changes: 19 additions & 12 deletions
@@ -315,6 +315,12 @@ def main():
         print_model=args.verbose,
         merge_loras=False,
         peft_gradient_checkpointing=not args.no_peft_grad_ckpt)
+    if args.verbose:
+        print_header(f'*** Trainable finetuning parameters ***')
+        for n, p in model.named_parameters():
+            if p.requires_grad:
+                print(f'├── {n} ({p.dtype})')
+
     finetune_trainer = get_finetuner(model, finetune_config, args.device, args, wandb)
     if args.verbose:
         print_header('Finetune config')
@@ -357,18 +363,19 @@ def main():
     finetune_trainer = get_evaluator(model, eval_config, args, args.device, wandb)
 
     # Final eval
-    print_header('*** Distilled + Finetuned Final Eval ***')
-    final_metrics = finetune_trainer.evaluate(model, step=-1, max_batches=None, prefix='final')
-    print_header('*** Saved Checkpoints ***')
-    print(f'--attn_mlp_checkpoint_path {args.load_distill_checkpoint} \\')
-    print(f'--finetune_checkpoint_path {args.load_finetune_checkpoint} \\')
-    # print(f'--finetune_long_checkpoint_path {args.load_finetune_long_checkpoint} \\')
-
-    print(final_metrics)
-    for k, v in final_metrics.items():
-        print(f'├── {k}: {v:.4f}')
-    if wandb is not None:
-        wandb.log({f'final/{k}': v for k, v in final_metrics.items()})
+    if 'save10' not in args.distill_config and 'save10' not in args.finetune_config:
+        print_header('*** Distilled + Finetuned Final Eval ***')
+        final_metrics = finetune_trainer.evaluate(model, step=-1, max_batches=None, prefix='final')
+        print_header('*** Saved Checkpoints ***')
+        print(f'--attn_mlp_checkpoint_path {args.load_distill_checkpoint} \\')
+        print(f'--finetune_checkpoint_path {args.load_finetune_checkpoint} \\')
+        # print(f'--finetune_long_checkpoint_path {args.load_finetune_long_checkpoint} \\')
+
+        print(final_metrics)
+        for k, v in final_metrics.items():
+            print(f'├── {k}: {v:.4f}')
+        if wandb is not None:
+            wandb.log({f'final/{k}': v for k, v in final_metrics.items()})
 
 
     # ------------------
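The two hunks above add a verbose listing of trainable parameters and gate the final evaluation behind a 'save10' config check. Below is a standalone sketch of the parameter listing; print_header is a hypothetical stand-in for the repo's helper, and a toy module replaces the distilled Llama model.

# Standalone sketch of the verbose trainable-parameter listing added above.
import torch.nn as nn


def print_header(title: str) -> None:
    # Hypothetical stand-in for the repo's print_header utility.
    print(f'\n{title}\n' + '-' * len(title))


def print_trainable_parameters(model: nn.Module) -> None:
    print_header('*** Trainable finetuning parameters ***')
    for n, p in model.named_parameters():
        if p.requires_grad:
            print(f'├── {n} ({p.dtype})')


if __name__ == '__main__':
    print_trainable_parameters(nn.Linear(8, 2))  # toy module for illustration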
distill_llama_layer.py — 507 changes: 0 additions & 507 deletions

This file was deleted.
