diff --git a/fastedit/rome/rome_hparams.py b/fastedit/rome/rome_hparams.py index 2d0a4ea..d9253b2 100644 --- a/fastedit/rome/rome_hparams.py +++ b/fastedit/rome/rome_hparams.py @@ -56,6 +56,9 @@ def from_name(cls, name: str): if name == "gpj-j-6b": pass elif name == "llama-7b": + r""" + Supports: LLaMA-7B, Baichuan-7B, InternLM-7B... + """ data.update(dict( v_loss_layer=31, rewrite_module_tmp="model.layers.{}.mlp.down_proj", @@ -65,6 +68,9 @@ def from_name(cls, name: str): ln_f_module="model.norm" )) elif name == "llama-13b": + r""" + Supports LLaMA-13B, Baichuan-13B... + """ data.update(dict( layers=[10], v_loss_layer=39,