test_base_po_like.py
import argparse
import logging
import os
import warnings

import pandas as pd
import torch
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)

from configs import base, seeds
from utils import (
    calculate_metrics,
    gen_prompt_po,
    get_logger,
)
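
# Silence library warnings and register tqdm's progress_apply on pandas.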
warnings.filterwarnings("ignore")
tqdm.pandas()


def test_base_po_like(args: argparse.Namespace, logger: logging.Logger) -> None:
    """Tests the base LLM in the PO-like setting.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        logger (logging.Logger): Logger instance.
    """
    # Prepare components for inference ---------------------------
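    # Load the base model with 4-bit NF4 quantization (double quantization,
    # bfloat16 compute) to reduce memory; device_map="auto" spreads layers
    # across the available devices.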
    model = AutoModelForCausalLM.from_pretrained(
        args.base_model_repo,
        token=args.hf_token,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        ),
    )
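    # Prepend BOS but omit EOS on tokenized prompts so generation continues
    # past the prompt instead of terminating immediately.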
    tokenizer = AutoTokenizer.from_pretrained(
        args.base_model_repo,
        token=args.hf_token,
        add_bos_token=True,
        add_eos_token=False,
    )

    # Prepare prompts ---------------------------
    df = pd.read_csv(args.po_test_filepath, encoding="utf-8")
    df["prompt"] = df.progress_apply(gen_prompt_po, axis=1)

    # Inference ---------------------------
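    # The text-generation pipeline reuses the already-quantized model;
    # device_map="auto" again delegates placement to accelerate.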
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
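    # Beam-sample decoding: sampling over 3 beams, keeping only the single
    # best sequence per prompt.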
    predictions_base: list[str] = []
    for prompt in tqdm(df["prompt"].values.tolist(), leave=True):
        sequences = pipe(
            prompt,
            do_sample=True,
            max_new_tokens=args.max_new_tokens,
            num_beams=3,
            num_return_sequences=1,
        )
        # The pipeline echoes the prompt, so keep only the text after the
        # response marker added by gen_prompt_po.
        gen_text = sequences[0]["generated_text"]
        gen_text = gen_text.split("###RESPONSE:\n")[1].lower().strip()
        predictions_base.append(gen_text)

    # Save results ---------------------------
    df["prediction_base"] = predictions_base
    df.to_csv(f"{args.results_dir}/base_po_like_predictions.csv", index=False)

    # Get results ---------------------------
    logger.info("Results of llm-base-po-like:")
    df = pd.read_csv(
        f"{args.results_dir}/base_po_like_predictions.csv", encoding="utf-8"
    )
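    # Wrap each prediction and reference in its own list to match the
    # nested-list interface of calculate_metrics.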
    preds: list[list[str]] = []
    refs: list[list[str]] = []
    for pred, ref in zip(
        df["prediction_base"].values.tolist(), df["findings"].values.tolist()
    ):
        preds.append([pred])
        refs.append([ref])
    calculate_metrics(preds, refs, logger)


def main() -> None:
    """Main entry point: configure logging and seeding, then run the test."""
    args = base.get_base_args()

    # Configure necessities ---------------------------
    logger = get_logger(args.logging_dir)
    os.makedirs(args.results_dir, exist_ok=True)
    logger.info(f"EXPERIMENT SETTING:\n{args}")
    seeds.seed_everything(args.seed)

    # Test ---------------------------
    test_base_po_like(args, logger)


if __name__ == "__main__":
    main()