From 894e658cbb950dcf0ce32adc253796372eb7859e Mon Sep 17 00:00:00 2001
From: Dan McPherson
Date: Sun, 17 Nov 2024 23:54:00 -0500
Subject: [PATCH] Add reorg answer file test

Signed-off-by: Dan McPherson
---
 src/instructlab/eval/mt_bench_answers.py | 15 +++---
 tests/test_mt_bench_answers.py           | 65 ++++++++++++++++++++++++
 2 files changed, 74 insertions(+), 6 deletions(-)
 create mode 100644 tests/test_mt_bench_answers.py

diff --git a/src/instructlab/eval/mt_bench_answers.py b/src/instructlab/eval/mt_bench_answers.py
index f4337b46..b9b4852e 100644
--- a/src/instructlab/eval/mt_bench_answers.py
+++ b/src/instructlab/eval/mt_bench_answers.py
@@ -26,16 +26,19 @@
 def reorg_answer_file(answer_file):
     """Sort by question id and de-duplication"""
     logger.debug(locals())
-    answers = {}
-    with open(answer_file, "r", encoding="utf-8") as fin:
-        for l in fin:
+    with open(answer_file, "r+", encoding="utf-8") as f:
+        answers = {}
+        for l in f:
             qid = json.loads(l)["question_id"]
             answers[qid] = l
 
-    qids = sorted(list(answers.keys()))
-    with open(answer_file, "w", encoding="utf-8") as fout:
+        # Reset to the beginning of the file and clear it
+        f.seek(0)
+        f.truncate()
+
+        qids = sorted(list(answers.keys()))
         for qid in qids:
-            fout.write(answers[qid])
+            f.write(answers[qid])
 
 
 def get_answer(
diff --git a/tests/test_mt_bench_answers.py b/tests/test_mt_bench_answers.py
new file mode 100644
index 00000000..7c082c73
--- /dev/null
+++ b/tests/test_mt_bench_answers.py
@@ -0,0 +1,65 @@
+# SPDX-License-Identifier: Apache-2.0
+
+# Standard
+import json
+import os
+import random
+import shutil
+import tempfile
+
+# First Party
+from instructlab.eval.mt_bench_answers import reorg_answer_file
+
+
+def test_reorg_answer_file():
+    answer_file = os.path.join(
+        os.path.dirname(__file__),
+        "..",
+        "src",
+        "instructlab",
+        "eval",
+        "data",
+        "mt_bench",
+        "reference_answer",
+        "gpt-4.jsonl",
+    )
+
+    # Create a temporary file
+    with tempfile.NamedTemporaryFile(delete=True) as temp_file:
+        temp_answer_file = temp_file.name
+
+        # Copy the original file to the temp file
+        shutil.copy(answer_file, temp_answer_file)
+
+        orig_length = 0
+        with open(temp_answer_file, "r+", encoding="utf-8") as f:
+            answers = {}
+            for l in f:
+                orig_length += 1
+                qid = json.loads(l)["question_id"]
+                answers[qid] = l
+
+            # Reset to the beginning of the file and clear it
+            f.seek(0)
+            f.truncate()
+
+            # Randomize the values
+            qids = sorted(list(answers.keys()), key=lambda answer: random.random())
+            for qid in qids:
+                f.write(answers[qid])
+                # Write each answer twice
+                f.write(answers[qid])
+
+        # Run the reorg which should sort and dedup the file in place
+        reorg_answer_file(temp_answer_file)
+
+        new_length = 0
+        with open(temp_answer_file, "r", encoding="utf-8") as fin:
+            previous_question_id = -1
+            for l in fin:
+                new_length += 1
+                qid = json.loads(l)["question_id"]
+                assert qid > previous_question_id
+                previous_question_id = qid
+
+        assert new_length == orig_length