-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy pathrun.py
102 lines (95 loc) · 3.17 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# Copyright Sierra
import argparse
from tau_bench.types import RunConfig
from tau_bench.run import run
from litellm import provider_list
from tau_bench.envs.user import UserStrategy
def parse_args(argv: "list[str] | None" = None) -> RunConfig:
    """Parse command-line options and package them into a ``RunConfig``.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default), argparse
            falls back to ``sys.argv[1:]``.

    Returns:
        A ``RunConfig`` populated with one field per command-line option.

    Note:
        The parsed namespace is printed to stdout before the config is built.
        On invalid arguments, argparse prints a usage message and exits.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-trials", type=int, default=1)
    parser.add_argument(
        "--env", type=str, choices=["retail", "airline"], default="retail"
    )

    # Agent model selection. Neither option has a default, so both are None
    # unless supplied on the command line.
    parser.add_argument(
        "--model",
        type=str,
        help="The model to use for the agent",
    )
    parser.add_argument(
        "--model-provider",
        type=str,
        choices=provider_list,
        help="The model provider for the agent",
    )

    # User-simulator model selection.
    parser.add_argument(
        "--user-model",
        type=str,
        default="gpt-4o",
        help="The model to use for the user simulator",
    )
    parser.add_argument(
        "--user-model-provider",
        type=str,
        choices=provider_list,
        help="The model provider for the user simulator",
    )

    parser.add_argument(
        "--agent-strategy",
        type=str,
        default="tool-calling",
        choices=["tool-calling", "act", "react", "few-shot"],
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="The sampling temperature for the action model",
    )
    parser.add_argument(
        "--task-split",
        type=str,
        default="test",
        choices=["train", "test", "dev"],
        # Closing parenthesis was missing from this help text.
        help="The split of tasks to run (only applies to the retail domain for now)",
    )

    # Task selection: an index range, or an explicit list of task IDs.
    parser.add_argument("--start-index", type=int, default=0)
    parser.add_argument(
        "--end-index", type=int, default=-1, help="Run all tasks if -1"
    )
    parser.add_argument(
        "--task-ids",
        type=int,
        nargs="+",
        help="(Optional) run only the tasks with the given IDs",
    )

    parser.add_argument("--log-dir", type=str, default="results")
    parser.add_argument(
        "--max-concurrency",
        type=int,
        default=1,
        help="Number of tasks to run in parallel",
    )
    parser.add_argument("--seed", type=int, default=10)
    # NOTE(review): --shuffle is parsed as an int (default 0); presumably it
    # is consumed as a 0/1 flag downstream -- confirm against RunConfig.
    parser.add_argument("--shuffle", type=int, default=0)
    parser.add_argument(
        "--user-strategy",
        type=str,
        default="llm",
        choices=[item.value for item in UserStrategy],
    )
    parser.add_argument(
        "--few-shot-displays-path",
        type=str,
        help="Path to a jsonlines file containing few shot displays",
    )

    # argv=None makes argparse read sys.argv[1:].
    args = parser.parse_args(argv)
    print(args)
    return RunConfig(
        model_provider=args.model_provider,
        user_model_provider=args.user_model_provider,
        model=args.model,
        user_model=args.user_model,
        num_trials=args.num_trials,
        env=args.env,
        agent_strategy=args.agent_strategy,
        temperature=args.temperature,
        task_split=args.task_split,
        start_index=args.start_index,
        end_index=args.end_index,
        task_ids=args.task_ids,
        log_dir=args.log_dir,
        max_concurrency=args.max_concurrency,
        seed=args.seed,
        shuffle=args.shuffle,
        user_strategy=args.user_strategy,
        few_shot_displays_path=args.few_shot_displays_path,
    )
def main():
    """Script entry point: parse the command line and hand the config to ``run``."""
    run(parse_args())
# Only start a run when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()