Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
infwinston committed Dec 27, 2023
1 parent da9bc88 commit a139b56
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 24 deletions.
3 changes: 3 additions & 0 deletions fastchat/serve/monitor/basic_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,18 @@ def load_log_files(filename):
)
return data


def load_log_files_parallel(log_files, num_threads=8):
    """Read every file in *log_files* with a worker pool and return all records as one flat list.

    Order of records follows the order of *log_files* (``imap`` preserves input order).
    """
    from multiprocessing import Pool

    with Pool(num_threads) as pool:
        # tqdm wraps the lazy imap iterator to show per-file progress.
        per_file_records = list(
            tqdm(pool.imap(load_log_files, log_files), total=len(log_files))
        )

    merged = []
    for records in per_file_records:
        merged += records
    return merged


def get_anony_vote_df(df):
anony_vote_df = df[
df["type"].isin(["leftvote", "rightvote", "tievote", "bothbad_vote"])
Expand Down
24 changes: 14 additions & 10 deletions fastchat/serve/monitor/clean_battle_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ def replace_model_name(old_name, tstamp):
}
if old_name in ["gpt-4", "gpt-3.5-turbo"]:
if tstamp > 1687849200:
return old_name+"-0613"
return old_name + "-0613"
else:
return old_name+"-0314"
return old_name + "-0314"
if old_name in replace_dict:
return replace_dict[old_name]
return old_name
Expand All @@ -99,17 +99,21 @@ def read_file(filename):
time.sleep(2)
return data


def read_file_parallel(log_files, num_threads=8):
    """Parse *log_files* concurrently via ``read_file`` and concatenate the per-file results."""
    from multiprocessing import Pool

    with Pool(num_threads) as pool:
        # Keep imap (not map) so tqdm can advance as each file finishes.
        progress = tqdm(pool.imap(read_file, log_files), total=len(log_files))
        chunks = list(progress)

    combined = []
    for chunk in chunks:
        combined.extend(chunk)
    return combined


def clean_battle_data(log_files, exclude_model_names, ban_ip_list=None, sanitize_ip=False):
def clean_battle_data(
log_files, exclude_model_names, ban_ip_list=None, sanitize_ip=False
):
data = read_file_parallel(log_files, num_threads=8)

convert_type = {
Expand Down Expand Up @@ -171,7 +175,9 @@ def clean_battle_data(log_files, exclude_model_names, ban_ip_list=None, sanitize
messages = ""
for i in range(2):
state = row["states"][i]
for turn_idx, (role, msg) in enumerate(state["messages"][state["offset"] :]):
for turn_idx, (role, msg) in enumerate(
state["messages"][state["offset"] :]
):
if msg:
messages += msg.lower()
for word in IDENTITY_WORDS:
Expand Down Expand Up @@ -201,11 +207,7 @@ def clean_battle_data(log_files, exclude_model_names, ban_ip_list=None, sanitize

ip = row["ip"]
if ip not in all_ips:
all_ips[ip] = {
"ip": ip,
"count": 0,
"sanitized_id": len(all_ips)
}
all_ips[ip] = {"ip": ip, "count": 0, "sanitized_id": len(all_ips)}
all_ips[ip]["count"] += 1
if sanitize_ip:
user_id = f"arena_user_{all_ips[ip]['sanitized_id']}"
Expand Down Expand Up @@ -272,7 +274,9 @@ def clean_battle_data(log_files, exclude_model_names, ban_ip_list=None, sanitize
log_files = get_log_files(args.max_num_files)
ban_ip_list = json.load(open(args.ban_ip_file)) if args.ban_ip_file else None

battles = clean_battle_data(log_files, args.exclude_model_names or [], ban_ip_list, args.sanitize_ip)
battles = clean_battle_data(
log_files, args.exclude_model_names or [], ban_ip_list, args.sanitize_ip
)
last_updated_tstamp = battles[-1]["tstamp"]
cutoff_date = datetime.datetime.fromtimestamp(
last_updated_tstamp, tz=timezone("US/Pacific")
Expand Down
25 changes: 17 additions & 8 deletions fastchat/serve/monitor/elo_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def get_bootstrap_result(battles, func_compute_elo, num_round=1000):

def compute_elo_mle_with_tie(df, SCALE=400, BASE=10, INIT_RATING=1000):
from sklearn.linear_model import LogisticRegression

models = pd.concat([df["model_a"], df["model_b"]]).unique()
models = pd.Series(np.arange(len(models)), index=models)

Expand All @@ -73,17 +74,17 @@ def compute_elo_mle_with_tie(df, SCALE=400, BASE=10, INIT_RATING=1000):
# one tie => one A win + one B win
# find tie + tie (both bad) index
tie_idx = (df["winner"] == "tie") | (df["winner"] == "tie (bothbad)")
tie_idx[len(tie_idx)//2:] = False
tie_idx[len(tie_idx) // 2 :] = False
Y[tie_idx] = 1.0

lr = LogisticRegression(fit_intercept=False)
lr.fit(X,Y)
lr.fit(X, Y)

elo_scores = SCALE * lr.coef_[0] + INIT_RATING
# calibrate llama-13b to 800 if applicable
if "llama-13b" in models.index:
elo_scores += (800-elo_scores[models["llama-13b"]])
return pd.Series(elo_scores, index = models.index).sort_values(ascending=False)
elo_scores += 800 - elo_scores[models["llama-13b"]]
return pd.Series(elo_scores, index=models.index).sort_values(ascending=False)


def get_median_elo_from_bootstrap(bootstrap_df):
Expand Down Expand Up @@ -260,10 +261,14 @@ def report_elo_analysis_results(battles_json, rating_system="bt", num_bootstrap=
elo_rating_online = compute_elo(battles)

if rating_system == "bt":
bootstrap_df = get_bootstrap_result(battles, compute_elo_mle_with_tie, num_round=num_bootstrap)
bootstrap_df = get_bootstrap_result(
battles, compute_elo_mle_with_tie, num_round=num_bootstrap
)
elo_rating_final = compute_elo_mle_with_tie(battles)
elif rating_system == "elo":
bootstrap_df = get_bootstrap_result(battles, compute_elo, num_round=num_bootstrap)
bootstrap_df = get_bootstrap_result(
battles, compute_elo, num_round=num_bootstrap
)
elo_rating_median = get_median_elo_from_bootstrap(bootstrap_df)
elo_rating_final = elo_rating_median

Expand Down Expand Up @@ -316,7 +321,9 @@ def pretty_print_elo_rating(rating):
parser.add_argument("--clean-battle-file", type=str)
parser.add_argument("--max-num-files", type=int)
parser.add_argument("--num-bootstrap", type=int, default=100)
parser.add_argument("--rating-system", type=str, choices=["bt", "elo"], default="bt")
parser.add_argument(
"--rating-system", type=str, choices=["bt", "elo"], default="bt"
)
parser.add_argument("--exclude-tie", action="store_true", default=False)
args = parser.parse_args()

Expand All @@ -330,7 +337,9 @@ def pretty_print_elo_rating(rating):
log_files = get_log_files(args.max_num_files)
battles = clean_battle_data(log_files)

results = report_elo_analysis_results(battles, rating_system=args.rating_system, num_bootstrap=args.num_bootstrap)
results = report_elo_analysis_results(
battles, rating_system=args.rating_system, num_bootstrap=args.num_bootstrap
)

print("# Online Elo")
pretty_print_elo_rating(results["elo_rating_online"])
Expand Down
28 changes: 22 additions & 6 deletions fastchat/serve/monitor/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,17 @@ def make_leaderboard_md_live(elo_results):
return leaderboard_md


def update_elo_components(max_num_files, elo_results_file, ban_ip_file, exclude_model_names):
def update_elo_components(
max_num_files, elo_results_file, ban_ip_file, exclude_model_names
):
log_files = get_log_files(max_num_files)

# Leaderboard
if elo_results_file is None: # Do live update
ban_ip_list = json.load(open(ban_ip_file)) if ban_ip_file else None
battles = clean_battle_data(log_files, exclude_model_names, ban_ip_list=ban_ip_list)
battles = clean_battle_data(
log_files, exclude_model_names, ban_ip_list=ban_ip_list
)
elo_results = report_elo_analysis_results(battles)

leader_component_values[0] = make_leaderboard_md_live(elo_results)
Expand Down Expand Up @@ -93,10 +97,14 @@ def update_elo_components(max_num_files, elo_results_file, ban_ip_file, exclude_
basic_component_values[5] = md4


def update_worker(
    max_num_files, interval, elo_results_file, ban_ip_file, exclude_model_names
):
    """Background loop that periodically refreshes the Elo/leaderboard components.

    Runs forever: each cycle calls ``update_elo_components`` with the given
    arguments, prints how long the refresh took, then sleeps for whatever
    remains of *interval* seconds (clamped at 0 so a slow update never
    produces a negative sleep).
    """
    while True:
        tic = time.time()
        update_elo_components(
            max_num_files, elo_results_file, ban_ip_file, exclude_model_names
        )
        # Fix: local was misspelled "durtaion" in the original.
        duration = time.time() - tic
        print(f"update duration: {duration:.2f} s")
        time.sleep(max(interval - duration, 0))
Expand Down Expand Up @@ -270,7 +278,9 @@ def build_demo(elo_results_file, leaderboard_table_file):
with gr.Tabs() as tabs:
with gr.Tab("Leaderboard", id=0):
leader_components = build_leaderboard_tab(
elo_results_file, leaderboard_table_file, show_plot=True,
elo_results_file,
leaderboard_table_file,
show_plot=True,
)

with gr.Tab("Basic Stats", id=1):
Expand Down Expand Up @@ -307,7 +317,13 @@ def build_demo(elo_results_file, leaderboard_table_file):
if args.elo_results_file is None: # Do live update
update_thread = threading.Thread(
target=update_worker,
args=(args.max_num_files, args.update_interval, args.elo_results_file, args.ban_ip_file, args.exclude_model_names),
args=(
args.max_num_files,
args.update_interval,
args.elo_results_file,
args.ban_ip_file,
args.exclude_model_names,
),
)
update_thread.start()

Expand Down

0 comments on commit a139b56

Please sign in to comment.