Skip to content

Commit

Permalink
refactor the code
Browse files (browse the repository at this point in the history)
  • Loading branch information
pengwei715 committed Feb 20, 2024
1 parent eac3744 commit 72f509a
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion llm_judge/llm_judge.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -522,7 +522,7 @@ def _compute_bench_mark_response(self, question) -> str:

return response

def _compute_over_one_data(self, question, response) -> Dict[str, Any]:
def _compute_over_one_data(self, question, response, reference=None) -> Dict[str, Any]:
"""
Compute the metrics over one data point
:param question
Expand All @@ -531,6 +531,8 @@ def _compute_over_one_data(self, question, response) -> Dict[str, Any]:
logger.info(f"Computing the metrics over {question} and {response}")
self.prompt_config["question"] = question
self.prompt_config["answerA"] = response
if reference:
self.prompt_config["reference"] = reference
self.prompt_config["answerB"] = self._compute_bench_mark_response(question)
input_ids = self.tokenizer(self._fill_prompt(), return_tensors="pt").input_ids
outputs = self.model.generate(
Expand Down Expand Up @@ -560,6 +562,8 @@ def _compute_over_data(
"""
self._prepare_judge()
self._prepare_bench_mark_model()
cols = sample_df.columns

res_df = pd.DataFrame(
columns=[
"question",
Expand All @@ -573,9 +577,15 @@ def _compute_over_data(
)

for i in range(len(sample_df)):
if "reference" in cols:
ref = sample_df.loc[i, "reference"]
else:
ref = None

res_dic = self._compute_over_one_data(
sample_df.loc[i, "question"],
sample_df.loc[i, "answer"],
ref
)
res_df.loc[i] = [
sample_df.loc[i, "question"],
Expand Down

0 comments on commit 72f509a

Please sign in to comment.