Skip to content

Commit

Permalink
Merge pull request #185 from danmcp/answersunittests
Browse files Browse the repository at this point in the history
Add reorg answer file test
  • Loading branch information
mergify[bot] authored Dec 4, 2024
2 parents e5d89c6 + 894e658 commit ce8880f
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 6 deletions.
15 changes: 9 additions & 6 deletions src/instructlab/eval/mt_bench_answers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,19 @@
def reorg_answer_file(answer_file):
    """Sort the records of a JSONL answer file by question id and de-duplicate.

    The file is rewritten in place through a single read/write handle. Each
    non-blank line is parsed as JSON and keyed by its ``question_id``; when
    several lines share an id, the last one read is the one that is kept.

    Args:
        answer_file: Path of the JSONL file to reorganize.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record has no ``question_id`` key.
    """
    logger.debug(locals())
    with open(answer_file, "r+", encoding="utf-8") as f:
        answers = {}
        for line in f:
            if not line.strip():
                # Blank lines hold no record and would make json.loads fail
                continue
            qid = json.loads(line)["question_id"]
            # A last line without a trailing newline would otherwise be fused
            # with whichever record ends up being written right after it.
            answers[qid] = line if line.endswith("\n") else line + "\n"

        # Reset to the beginning of the file and clear it
        f.seek(0)
        f.truncate()

        f.writelines(answers[qid] for qid in sorted(answers))


def get_answer(
Expand Down
65 changes: 65 additions & 0 deletions tests/test_mt_bench_answers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# SPDX-License-Identifier: Apache-2.0

# Standard
import json
import os
import random
import shutil
import tempfile

# First Party
from instructlab.eval.mt_bench_answers import reorg_answer_file


def test_reorg_answer_file():
    """Check that reorg_answer_file sorts and de-duplicates a file in place.

    A copy of the gpt-4 reference answer file is rewritten with its records
    in random order and every record written twice. After the reorg, the
    question ids must be strictly increasing and the line count must equal
    the line count of the untouched copy (which presumes the reference file
    itself has no duplicate question ids).
    """
    answer_file = os.path.join(
        os.path.dirname(__file__),
        "..",
        "src",
        "instructlab",
        "eval",
        "data",
        "mt_bench",
        "reference_answer",
        "gpt-4.jsonl",
    )

    # Use a temporary directory rather than an open NamedTemporaryFile: the
    # file is re-opened by name below, which is not possible on every
    # platform while the original temporary file handle is still open.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_answer_file = os.path.join(temp_dir, "gpt-4.jsonl")

        # Copy the original file to the temp file
        shutil.copy(answer_file, temp_answer_file)

        orig_length = 0
        with open(temp_answer_file, "r+", encoding="utf-8") as f:
            answers = {}
            for line in f:
                orig_length += 1
                qid = json.loads(line)["question_id"]
                answers[qid] = line

            # Reset to the beginning of the file and clear it
            f.seek(0)
            f.truncate()

            # Randomize the order of the records
            qids = list(answers)
            random.shuffle(qids)
            for qid in qids:
                f.write(answers[qid])
                # Write each answer twice
                f.write(answers[qid])

        # Run the reorg which should sort and dedup the file in place
        reorg_answer_file(temp_answer_file)

        new_length = 0
        with open(temp_answer_file, "r", encoding="utf-8") as fin:
            previous_question_id = -1
            for line in fin:
                new_length += 1
                qid = json.loads(line)["question_id"]
                # Strictly increasing ids imply both sorted and de-duplicated
                assert qid > previous_question_id
                previous_question_id = qid

        assert new_length == orig_length

0 comments on commit ce8880f

Please sign in to comment.