diff --git a/fastchat/serve/monitor/basic_stats.py b/fastchat/serve/monitor/basic_stats.py
index 21917ca99..8bab33662 100644
--- a/fastchat/serve/monitor/basic_stats.py
+++ b/fastchat/serve/monitor/basic_stats.py
@@ -53,15 +53,18 @@ def load_log_files(filename):
         )
     return data
 
+
 def load_log_files_parallel(log_files, num_threads=8):
     data_all = []
     from multiprocessing import Pool
+
     with Pool(num_threads) as p:
         ret_all = list(tqdm(p.imap(load_log_files, log_files), total=len(log_files)))
         for ret in ret_all:
             data_all.extend(ret)
     return data_all
 
+
 def get_anony_vote_df(df):
     anony_vote_df = df[
         df["type"].isin(["leftvote", "rightvote", "tievote", "bothbad_vote"])
diff --git a/fastchat/serve/monitor/clean_battle_data.py b/fastchat/serve/monitor/clean_battle_data.py
index 87f0f10ee..87328c94a 100644
--- a/fastchat/serve/monitor/clean_battle_data.py
+++ b/fastchat/serve/monitor/clean_battle_data.py
@@ -77,9 +77,9 @@ def replace_model_name(old_name, tstamp):
     }
     if old_name in ["gpt-4", "gpt-3.5-turbo"]:
         if tstamp > 1687849200:
-            return old_name+"-0613"
+            return old_name + "-0613"
         else:
-            return old_name+"-0314"
+            return old_name + "-0314"
     if old_name in replace_dict:
         return replace_dict[old_name]
     return old_name
@@ -99,9 +99,11 @@ def read_file(filename):
             time.sleep(2)
     return data
 
+
 def read_file_parallel(log_files, num_threads=8):
     data_all = []
     from multiprocessing import Pool
+
     with Pool(num_threads) as p:
         ret_all = list(tqdm(p.imap(read_file, log_files), total=len(log_files)))
         for ret in ret_all:
@@ -109,7 +111,9 @@ def read_file_parallel(log_files, num_threads=8):
             data_all.extend(ret)
     return data_all
 
-def clean_battle_data(log_files, exclude_model_names, ban_ip_list=None, sanitize_ip=False):
+def clean_battle_data(
+    log_files, exclude_model_names, ban_ip_list=None, sanitize_ip=False
+):
     data = read_file_parallel(log_files, num_threads=8)
 
     convert_type = {
@@ -171,7 +175,9 @@ def clean_battle_data(log_files, exclude_model_names, ban_ip_list=None, sanitize
         messages = ""
         for i in range(2):
             state = row["states"][i]
-            for turn_idx, (role, msg) in enumerate(state["messages"][state["offset"] :]):
+            for turn_idx, (role, msg) in enumerate(
+                state["messages"][state["offset"] :]
+            ):
                 if msg:
                     messages += msg.lower()
                     for word in IDENTITY_WORDS:
@@ -201,11 +207,7 @@ def clean_battle_data(log_files, exclude_model_names, ban_ip_list=None, sanitize
 
         ip = row["ip"]
         if ip not in all_ips:
-            all_ips[ip] = {
-                "ip": ip,
-                "count": 0,
-                "sanitized_id": len(all_ips)
-            }
+            all_ips[ip] = {"ip": ip, "count": 0, "sanitized_id": len(all_ips)}
         all_ips[ip]["count"] += 1
         if sanitize_ip:
             user_id = f"arena_user_{all_ips[ip]['sanitized_id']}"
@@ -272,7 +274,9 @@ def clean_battle_data(log_files, exclude_model_names, ban_ip_list=None, sanitize
     log_files = get_log_files(args.max_num_files)
     ban_ip_list = json.load(open(args.ban_ip_file)) if args.ban_ip_file else None
 
-    battles = clean_battle_data(log_files, args.exclude_model_names or [], ban_ip_list, args.sanitize_ip)
+    battles = clean_battle_data(
+        log_files, args.exclude_model_names or [], ban_ip_list, args.sanitize_ip
+    )
     last_updated_tstamp = battles[-1]["tstamp"]
     cutoff_date = datetime.datetime.fromtimestamp(
         last_updated_tstamp, tz=timezone("US/Pacific")
diff --git a/fastchat/serve/monitor/elo_analysis.py b/fastchat/serve/monitor/elo_analysis.py
index ee2991743..2841bcd12 100644
--- a/fastchat/serve/monitor/elo_analysis.py
+++ b/fastchat/serve/monitor/elo_analysis.py
@@ -54,6 +54,7 @@ def get_bootstrap_result(battles, func_compute_elo, num_round=1000):
 
 def compute_elo_mle_with_tie(df, SCALE=400, BASE=10, INIT_RATING=1000):
     from sklearn.linear_model import LogisticRegression
+
     models = pd.concat([df["model_a"], df["model_b"]]).unique()
     models = pd.Series(np.arange(len(models)), index=models)
 
@@ -73,17 +74,17 @@ def compute_elo_mle_with_tie(df, SCALE=400, BASE=10, INIT_RATING=1000):
     # one tie => one A win + one B win
     # find tie + tie (both bad) index
     tie_idx = (df["winner"] == "tie") | (df["winner"] == "tie (bothbad)")
-    tie_idx[len(tie_idx)//2:] = False
+    tie_idx[len(tie_idx) // 2 :] = False
     Y[tie_idx] = 1.0
 
     lr = LogisticRegression(fit_intercept=False)
-    lr.fit(X,Y)
+    lr.fit(X, Y)
 
     elo_scores = SCALE * lr.coef_[0] + INIT_RATING
     # calibrate llama-13b to 800 if applicable
     if "llama-13b" in models.index:
-        elo_scores += (800-elo_scores[models["llama-13b"]])
-    return pd.Series(elo_scores, index = models.index).sort_values(ascending=False)
+        elo_scores += 800 - elo_scores[models["llama-13b"]]
+    return pd.Series(elo_scores, index=models.index).sort_values(ascending=False)
 
 
 def get_median_elo_from_bootstrap(bootstrap_df):
@@ -260,10 +261,14 @@ def report_elo_analysis_results(battles_json, rating_system="bt", num_bootstrap=
     elo_rating_online = compute_elo(battles)
 
     if rating_system == "bt":
-        bootstrap_df = get_bootstrap_result(battles, compute_elo_mle_with_tie, num_round=num_bootstrap)
+        bootstrap_df = get_bootstrap_result(
+            battles, compute_elo_mle_with_tie, num_round=num_bootstrap
+        )
         elo_rating_final = compute_elo_mle_with_tie(battles)
     elif rating_system == "elo":
-        bootstrap_df = get_bootstrap_result(battles, compute_elo, num_round=num_bootstrap)
+        bootstrap_df = get_bootstrap_result(
+            battles, compute_elo, num_round=num_bootstrap
+        )
         elo_rating_median = get_median_elo_from_bootstrap(bootstrap_df)
         elo_rating_final = elo_rating_median
 
@@ -316,7 +321,9 @@ def pretty_print_elo_rating(rating):
     parser.add_argument("--clean-battle-file", type=str)
     parser.add_argument("--max-num-files", type=int)
     parser.add_argument("--num-bootstrap", type=int, default=100)
-    parser.add_argument("--rating-system", type=str, choices=["bt", "elo"], default="bt")
+    parser.add_argument(
+        "--rating-system", type=str, choices=["bt", "elo"], default="bt"
+    )
     parser.add_argument("--exclude-tie", action="store_true", default=False)
     args = parser.parse_args()
 
@@ -330,7 +337,9 @@ def pretty_print_elo_rating(rating):
         log_files = get_log_files(args.max_num_files)
         battles = clean_battle_data(log_files)
 
-    results = report_elo_analysis_results(battles, rating_system=args.rating_system, num_bootstrap=args.num_bootstrap)
+    results = report_elo_analysis_results(
+        battles, rating_system=args.rating_system, num_bootstrap=args.num_bootstrap
+    )
 
     print("# Online Elo")
     pretty_print_elo_rating(results["elo_rating_online"])
diff --git a/fastchat/serve/monitor/monitor.py b/fastchat/serve/monitor/monitor.py
index e8b0a4828..6728cafd3 100644
--- a/fastchat/serve/monitor/monitor.py
+++ b/fastchat/serve/monitor/monitor.py
@@ -54,13 +54,17 @@ def make_leaderboard_md_live(elo_results):
     return leaderboard_md
 
 
-def update_elo_components(max_num_files, elo_results_file, ban_ip_file, exclude_model_names):
+def update_elo_components(
+    max_num_files, elo_results_file, ban_ip_file, exclude_model_names
+):
     log_files = get_log_files(max_num_files)
 
     # Leaderboard
     if elo_results_file is None:  # Do live update
         ban_ip_list = json.load(open(ban_ip_file)) if ban_ip_file else None
-        battles = clean_battle_data(log_files, exclude_model_names, ban_ip_list=ban_ip_list)
+        battles = clean_battle_data(
+            log_files, exclude_model_names, ban_ip_list=ban_ip_list
+        )
         elo_results = report_elo_analysis_results(battles)
 
         leader_component_values[0] = make_leaderboard_md_live(elo_results)
@@ -93,10 +97,14 @@ def update_elo_components(max_num_files, elo_results_file, ban_ip_file, exclude_
     basic_component_values[5] = md4
 
 
-def update_worker(max_num_files, interval, elo_results_file, ban_ip_file, exclude_model_names):
+def update_worker(
+    max_num_files, interval, elo_results_file, ban_ip_file, exclude_model_names
+):
     while True:
         tic = time.time()
-        update_elo_components(max_num_files, elo_results_file, ban_ip_file, exclude_model_names)
+        update_elo_components(
+            max_num_files, elo_results_file, ban_ip_file, exclude_model_names
+        )
         durtaion = time.time() - tic
         print(f"update duration: {durtaion:.2f} s")
         time.sleep(max(interval - durtaion, 0))
@@ -270,7 +278,9 @@ def build_demo(elo_results_file, leaderboard_table_file):
         with gr.Tabs() as tabs:
             with gr.Tab("Leaderboard", id=0):
                 leader_components = build_leaderboard_tab(
-                    elo_results_file, leaderboard_table_file, show_plot=True,
+                    elo_results_file,
+                    leaderboard_table_file,
+                    show_plot=True,
                 )
 
             with gr.Tab("Basic Stats", id=1):
@@ -307,7 +317,13 @@ def build_demo(elo_results_file, leaderboard_table_file):
     if args.elo_results_file is None:  # Do live update
         update_thread = threading.Thread(
             target=update_worker,
-            args=(args.max_num_files, args.update_interval, args.elo_results_file, args.ban_ip_file, args.exclude_model_names),
+            args=(
+                args.max_num_files,
+                args.update_interval,
+                args.elo_results_file,
+                args.ban_ip_file,
+                args.exclude_model_names,
+            ),
         )
         update_thread.start()
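
For reviewers unfamiliar with `compute_elo_mle_with_tie` (the "bt" rating system touched above), here is a minimal, self-contained sketch of the Bradley-Terry fit it performs, including the `tie_idx[len(tie_idx) // 2 :] = False` trick reformatted in this diff. The toy battles are made up, and the X/Y construction sits outside the hunks shown here, so it is reconstructed from the surrounding code and should be read as illustrative, not as the exact upstream implementation.

```python
import math

import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression

SCALE, BASE, INIT_RATING = 400, 10, 1000

# Hypothetical battle log; model names are placeholders.
df = pd.DataFrame(
    {
        "model_a": ["model-x", "model-x", "model-y", "model-z"],
        "model_b": ["model-y", "model-z", "model-z", "model-x"],
        "winner": ["model_a", "tie", "model_b", "model_a"],
    }
)

# Map each model name to a column index, as the patched function does.
models = pd.concat([df["model_a"], df["model_b"]]).unique()
models = pd.Series(np.arange(len(models)), index=models)

# Duplicate every battle so a tie can later be split into one A win
# (first copy) and one B win (second copy).
df = pd.concat([df, df], ignore_index=True)
n, p = df.shape[0], len(models)

# One row per battle copy: +log(BASE) for model_a, -log(BASE) for model_b,
# so logistic regression coefficients become Bradley-Terry log-strengths.
X = np.zeros((n, p))
X[np.arange(n), models[df["model_a"]]] = +math.log(BASE)
X[np.arange(n), models[df["model_b"]]] = -math.log(BASE)

# A wins are labeled 1; ties are labeled 1 only in the first half of the
# duplicated rows, which is exactly what the reformatted line above does.
Y = np.zeros(n)
Y[df["winner"] == "model_a"] = 1.0
tie_idx = (df["winner"] == "tie") | (df["winner"] == "tie (bothbad)")
tie_idx[len(tie_idx) // 2 :] = False
Y[tie_idx] = 1.0

lr = LogisticRegression(fit_intercept=False)
lr.fit(X, Y)

# Rescale the fitted log-strengths to an Elo-like range.
ratings = pd.Series(SCALE * lr.coef_[0] + INIT_RATING, index=models.index)
print(ratings.sort_values(ascending=False))
```

Duplicating the battles rather than weighting rows keeps the fit to a plain `LogisticRegression` call: each tie then contributes symmetric evidence, one win for each side.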