-
Notifications
You must be signed in to change notification settings - Fork 230
/
run.py
424 lines (375 loc) · 19.8 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
import torch
import torch.distributed as dist
from vlmeval.config import supported_VLM
from vlmeval.dataset.video_dataset_config import supported_video_datasets
from vlmeval.dataset import build_dataset
from vlmeval.inference import infer_data_job
from vlmeval.inference_video import infer_data_job_video
from vlmeval.inference_mt import infer_data_job_mt
from vlmeval.smp import *
from vlmeval.utils.result_transfer import MMMU_result_transfer, MMTBench_result_transfer
def build_model_from_config(cfg, model_name):
    """Instantiate the model described by ``cfg[model_name]``.

    An empty entry (``{}``) falls back to the predefined factory in ``supported_VLM``.
    Otherwise the entry must name a ``class`` exported by ``vlmeval.api`` or
    ``vlmeval.vlm`` (searched in that order); the remaining keys of the entry are
    passed to that class's constructor.

    Args:
        cfg: Mapping from model name to its config dict (the ``model`` field of the
            ``--config`` JSON).
        model_name: Key of ``cfg`` to build.

    Returns:
        The constructed model instance.

    Raises:
        AssertionError: If a non-empty entry has no ``class`` key.
        ValueError: If the class exists in neither ``vlmeval.api`` nor ``vlmeval.vlm``.
    """
    import vlmeval.api
    import vlmeval.vlm

    # Work on a copy so popping 'class' leaves the caller's cfg untouched.
    model_cfg = cp.deepcopy(cfg[model_name])
    if model_cfg == {}:
        return supported_VLM[model_name]()

    assert 'class' in model_cfg
    class_name = model_cfg.pop('class')
    for module in (vlmeval.api, vlmeval.vlm):
        if hasattr(module, class_name):
            return getattr(module, class_name)(**model_cfg)
    raise ValueError(f'Class {class_name} is not supported in `vlmeval.api` or `vlmeval.vlm`')
def build_dataset_from_config(cfg, dataset_name):
    """Instantiate the dataset described by ``cfg[dataset_name]``.

    An empty entry (``{}``) falls back to the predefined setting in
    ``supported_video_datasets``. Otherwise the entry must name a ``class`` exported by
    ``vlmeval.dataset``. Only the keys accepted by that class's ``__init__`` are passed
    on; any other keys are silently dropped.

    Args:
        cfg: Mapping from dataset name to its config dict (the ``data`` field of the
            ``--config`` JSON).
        dataset_name: Key of ``cfg`` to build.

    Returns:
        The constructed dataset instance.

    Raises:
        AssertionError: If a non-empty entry has no ``class`` key.
        ValueError: If the class is not exported by ``vlmeval.dataset``, or if a class
            that accepts ``fps``/``nframe`` gets both or neither set to a positive value.
    """
    import vlmeval.dataset
    import inspect

    config = cp.deepcopy(cfg[dataset_name])
    if config == {}:
        return supported_video_datasets[dataset_name]()
    assert 'class' in config
    cls_name = config.pop('class')
    if not hasattr(vlmeval.dataset, cls_name):
        raise ValueError(f'Class {cls_name} is not supported in `vlmeval.dataset`')

    cls = getattr(vlmeval.dataset, cls_name)
    sig = inspect.signature(cls.__init__)
    valid_params = {k: v for k, v in config.items() if k in sig.parameters}

    # The frame-sampling checks only make sense for classes that take fps/nframe.
    # Image datasets (e.g. ImageMCQDataset) accept neither, so applying the
    # "at least one must be set" check to them would always raise.
    if 'fps' in sig.parameters or 'nframe' in sig.parameters:
        # A JSON null is treated as "not set" instead of crashing on `None > 0`.
        fps = valid_params.get('fps') or 0
        nframe = valid_params.get('nframe') or 0
        if fps > 0 and nframe > 0:
            raise ValueError('fps and nframe should not be set at the same time')
        if fps <= 0 and nframe <= 0:
            raise ValueError('fps and nframe should be set at least one valid value')
    return cls(**valid_params)
def parse_args():
    """Parse the command-line arguments of the evaluation launcher.

    Returns:
        argparse.Namespace with the fields ``data``, ``model``, ``config``, ``work_dir``,
        ``mode``, ``nproc``, ``retry``, ``judge``, ``verbose``, ``ignore`` and ``reuse``.
    """
    # Shown verbatim by RawTextHelpFormatter, so line breaks and indentation matter.
    help_msg = """\
You can launch the evaluation by setting either --data and --model or --config.

--data and --model:
    Each arg should be a list of strings, specifying the names of datasets and models.
    To find all supported model names, please refer to `vlmeval/config.py` or check the output
    of the command `vlmutil mlist all` in the terminal (you should first have vlmeval installed).
    To find all supported dataset names, please refer to the `vlmeval/dataset/__init__.py` file.
    The python script to print all supported dataset names is as follows:
    ```python
    from vlmeval.dataset import SUPPORTED_DATASETS
    print(SUPPORTED_DATASETS)
    ```
    or you can check the output of the command `vlmutil dlist all` in the terminal.
    To find all supported video dataset default settings, please refer to the
    `vlmeval/dataset/video_dataset_config.py` file.

--config:
    Launch the evaluation by specifying the path to the config json file. Sample Json Content:
    ```json
    {
        "model": {
            "GPT4o_20240806_T00_HIGH": {
                "class": "GPT4V",
                "model": "gpt-4o-2024-08-06",
                "temperature": 0,
                "img_detail": "high"
            },
            "GPT4o_20240806_T10_Low": {
                "class": "GPT4V",
                "model": "gpt-4o-2024-08-06",
                "temperature": 1.0,
                "img_detail": "low"
            },
            "GPT4o_20241120": {}
        },
        "data": {
            "MME-RealWorld-Lite": {
                "class": "MMERealWorld",
                "dataset": "MME-RealWorld-Lite"
            },
            "MMBench_DEV_EN_V11": {
                "class": "ImageMCQDataset",
                "dataset": "MMBench_DEV_EN_V11"
            },
            "MMBench_Video_8frame_nopack": {},
            "Video-MME_16frame_subs": {
                "class": "VideoMME",
                "dataset": "Video-MME",
                "nframe": 16,
                "use_subtitle": true
            }
        }
    }
    ```
    Currently, only `model` and `data` are supported fields. The content of each field is a dictionary.
    For `model`, the key is the name of the model, and the value is a dictionary containing the following keys:
    - `class`: The class name of the model, which should be a class in `vlmeval.vlm` or `vlmeval.api`.
    - Other keys are specific to the model, please refer to the corresponding class.
    - Tip: A model already defined in `supported_VLM` of `vlmeval/config.py` can be used as a shortcut
      by giving it an empty dictionary.
    For `data`, the key is the name of the dataset (should be the same as the `dataset` field in most
    cases, except for video datasets), and the value is a dictionary containing the following keys:
    - `class`: The class name of the dataset, which should be a class in `vlmeval.dataset`.
    - `dataset`: The name of the dataset, which should be a string that is accepted by the `dataset`
      argument of the corresponding class.
    - Other keys are specific to the dataset, please refer to the corresponding class.
    - Tip: A dataset setting already defined in `supported_video_datasets` of
      `vlmeval/dataset/video_dataset_config.py` can be used as a shortcut by giving it an empty dictionary.

    The keys in the `model` and `data` fields will be used for naming the prediction files and evaluation results.
    When launching with `--config`, args for API VLMs, such as `--retry`, `--verbose`, will be ignored.
"""
    parser = argparse.ArgumentParser(description=help_msg, formatter_class=argparse.RawTextHelpFormatter)
    # Essential Args, Setting the Names of Datasets and Models
    parser.add_argument('--data', type=str, nargs='+', help='Names of Datasets')
    parser.add_argument('--model', type=str, nargs='+', help='Names of Models')
    parser.add_argument('--config', type=str, help='Path to the Config Json File')
    # Work Dir
    parser.add_argument('--work-dir', type=str, default='./outputs', help='select the output directory')
    # Infer + Eval or Infer Only
    parser.add_argument('--mode', type=str, default='all', choices=['all', 'infer'],
                        help='all: inference + evaluation; infer: inference only')
    # API Kwargs, Apply to API VLMs and Judge API LLMs
    parser.add_argument('--nproc', type=int, default=4, help='Parallel API calling')
    parser.add_argument('--retry', type=int, default=None, help='retry numbers for API VLMs')
    # Explicitly Set the Judge Model
    parser.add_argument('--judge', type=str, default=None,
                        help='Judge model name; overrides the per-dataset default judge')
    # Logging Utils
    parser.add_argument('--verbose', action='store_true',
                        help='Pass verbose=True to inference and to the judge')
    # Configuration for Resume
    # Ignore: will not rerun failed VLM inference
    parser.add_argument('--ignore', action='store_true', help='Ignore failed indices. ')
    # Reuse: will reuse the existing prediction files
    parser.add_argument('--reuse', action='store_true',
                        help='Reuse the latest existing prediction & temporary pickle files')
    return parser.parse_args()
def _make_dataset(dataset_name, model_name, data_cfg=None):
    """Build one dataset, from the ``--config`` data section when given, otherwise by registered name.

    Args:
        dataset_name: Dataset key (config key or registered dataset name).
        model_name: Model key; forwarded as ``model=`` to the few datasets listed below.
        data_cfg: The ``data`` mapping of the config file, or None when launched with ``--data``.

    Returns:
        Whatever ``build_dataset_from_config`` / ``build_dataset`` return; the caller treats
        None as an invalid dataset.
    """
    if data_cfg is not None:
        return build_dataset_from_config(data_cfg, dataset_name)
    dataset_kwargs = {}
    # These datasets receive the model name as an extra `model` keyword argument.
    if dataset_name in ['MMLongBench_DOC', 'DUDE', 'DUDE_MINI', 'SLIDEVQA', 'SLIDEVQA_MINI']:
        dataset_kwargs['model'] = model_name
    return build_dataset(dataset_name, **dataset_kwargs)


def _build_judge_kwargs(args, dataset, dataset_name):
    """Return the kwargs forwarded to ``dataset.evaluate`` / result transfer for one dataset.

    ``--judge`` wins; otherwise the judge model is chosen from the dataset type or name.
    """
    judge_kwargs = {
        'nproc': args.nproc,
        'verbose': args.verbose,
        # 3 judge retries unless --retry overrides it.
        'retry': args.retry if args.retry is not None else 3,
    }
    if args.judge is not None:
        judge_kwargs['model'] = args.judge
    elif dataset.TYPE in ['MCQ', 'Y/N']:
        judge_kwargs['model'] = 'chatgpt-0125'
    elif listinstr(['MMVet', 'LLaVABench', 'MMBench-Video'], dataset_name):
        judge_kwargs['model'] = 'gpt-4-turbo'
    elif listinstr(['MathVista', 'MathVerse', 'MathVision', 'DynaMath', 'VL-RewardBench'], dataset_name):
        judge_kwargs['model'] = 'gpt-4o-mini'
    elif listinstr(['MMLongBench', 'MMDU', 'DUDE', 'SLIDEVQA', 'MIA-Bench', 'WildVision'], dataset_name):
        judge_kwargs['model'] = 'gpt-4o'
    # Otherwise no 'model' key is set; presumably evaluate() then uses its own default.
    return judge_kwargs


def _link_prediction_files(pred_root, pred_root_meta, model_name, dataset_name):
    """Symlink this pair's files (and any ``status.json``) from the run dir into the per-model dir."""
    cwd = os.getcwd()
    prefix = f'{model_name}_{dataset_name}'
    for fname in os.listdir(pred_root):
        if prefix not in fname and 'status.json' not in fname:
            continue
        file_addr = osp.join(cwd, pred_root, fname)
        link_addr = osp.join(cwd, pred_root_meta, fname)
        # exists() is False for a dangling symlink, hence the extra islink() check.
        if osp.exists(link_addr) or osp.islink(link_addr):
            os.remove(link_addr)
        os.symlink(file_addr, link_addr)


def main():
    """Run inference (and, unless ``--mode infer``, evaluation) for every model x dataset pair.

    Models and datasets come from ``--model`` / ``--data`` or from the ``--config`` JSON file.
    For each model, predictions are written to ``<work_dir>/<model>/T<date>_G<commit>/``
    (``MMEVAL_ROOT`` in the environment overrides ``--work-dir``), and rank 0 symlinks the
    evaluated files into ``<work_dir>/<model>/``. When ``world_size > 1`` an NCCL process
    group is created; only rank 0 reuses old predictions, evaluates and links files, while
    ``dist.barrier()`` calls keep the ranks in step. A failure in one model x dataset pair
    is logged and that pair is skipped.
    """
    logger = get_logger('RUN')
    rank, world_size = get_rank_and_world_size()
    args = parse_args()

    use_config, cfg = False, None
    if args.config is not None:
        assert args.data is None and args.model is None, '--data and --model should not be set when using --config'
        use_config, cfg = True, load(args.config)
        args.model = list(cfg['model'].keys())
        args.data = list(cfg['data'].keys())
    else:
        # argparse leaves --data / --model as None when omitted, so check truthiness, not len().
        assert args.data, '--data should be a list of data files'
        assert args.model, '--model should be a list of model names'

    if rank == 0:
        if not args.reuse:
            logger.warning('--reuse is not set, will not reuse previous (before one day) temporary files')
        else:
            logger.warning('--reuse is set, will reuse the latest prediction & temporary pickle files')

    if 'MMEVAL_ROOT' in os.environ:
        args.work_dir = os.environ['MMEVAL_ROOT']

    if not use_config:
        # Push CLI overrides into the keyword defaults of the predefined model factories
        # (objects exposing a `keywords` dict, presumably functools.partial). The dicts are
        # mutated in place, so nothing needs to be re-assigned into supported_VLM.
        for factory in supported_VLM.values():
            if not hasattr(factory, 'keywords'):
                continue
            if args.retry is not None and 'retry' in factory.keywords:
                factory.keywords['retry'] = args.retry
            # --verbose is a store_true flag, so it is always a bool here.
            if 'verbose' in factory.keywords:
                factory.keywords['verbose'] = args.verbose

    if world_size > 1:
        local_rank = os.environ.get('LOCAL_RANK', 0)
        torch.cuda.set_device(int(local_rank))
        dist.init_process_group(
            backend='nccl',
            timeout=datetime.timedelta(seconds=int(os.environ.get('DIST_TIMEOUT', 3600)))
        )

    data_cfg = cfg['data'] if use_config else None
    for model_name in args.model:
        model = None
        date, commit_id = timestr('day'), githash(digits=8)
        eval_id = f"T{date}_G{commit_id}"

        pred_root = osp.join(args.work_dir, model_name, eval_id)
        pred_root_meta = osp.join(args.work_dir, model_name)
        os.makedirs(pred_root_meta, exist_ok=True)

        prev_pred_roots = ls(pred_root_meta, mode='dir')
        if args.reuse:
            # Run dirs are named T<date>_G<commit>, so after sorting the newest run comes last.
            prev_pred_roots.sort()

        os.makedirs(pred_root, exist_ok=True)

        if use_config:
            model = build_model_from_config(cfg['model'], model_name)

        for dataset_name in args.data:
            try:
                result_file_base = f'{model_name}_{dataset_name}.xlsx'

                # If distributed, first build the dataset on the main process for doing preparation works
                if world_size > 1:
                    if rank == 0:
                        _make_dataset(dataset_name, model_name, data_cfg)
                    dist.barrier()
                dataset = _make_dataset(dataset_name, model_name, data_cfg)
                if dataset is None:
                    logger.error(f'Dataset {dataset_name} is not valid, will be skipped. ')
                    continue

                # Handling Multi-Turn Dataset
                if dataset.TYPE == 'MT':
                    result_file_base = result_file_base.replace('.xlsx', '.tsv')

                result_file = osp.join(pred_root, result_file_base)

                # Reuse the previous prediction file if exists (rank 0 only; other ranks wait below)
                if rank == 0 and args.reuse and prev_pred_roots:
                    prev_result_file = None
                    prev_pkl_file_list = []
                    # Newest run first: take its final result file, or else temporary pickles
                    # from another run of the same commit.
                    for root in prev_pred_roots[::-1]:
                        if osp.exists(osp.join(root, result_file_base)):
                            prev_result_file = osp.join(root, result_file_base)
                            break
                        elif commit_id in root and len(ls(root)) and root != pred_root:
                            temp_files = ls(root, match=[dataset_name, '.pkl'])
                            if len(temp_files):
                                prev_pkl_file_list.extend(temp_files)
                                break
                    if prev_result_file is not None:
                        logger.warning(
                            f'--reuse is set, will reuse the prediction file {prev_result_file}.')
                        if prev_result_file != result_file:
                            shutil.copy(prev_result_file, result_file)
                    elif len(prev_pkl_file_list):
                        for fname in prev_pkl_file_list:
                            target_path = osp.join(pred_root, osp.basename(fname))
                            if not osp.exists(target_path):
                                shutil.copy(fname, target_path)
                                logger.info(f'--reuse is set, will reuse the prediction pickle file {fname}.')
                            else:
                                logger.warning(f'File already exists: {target_path}')

                if world_size > 1:
                    dist.barrier()

                if model is None:
                    model = model_name  # which is only a name

                # Perform the Inference; the returned model object is reused for later datasets.
                if dataset.MODALITY == 'VIDEO':
                    model = infer_data_job_video(
                        model,
                        work_dir=pred_root,
                        model_name=model_name,
                        dataset=dataset,
                        result_file_name=result_file_base,
                        verbose=args.verbose,
                        api_nproc=args.nproc)
                elif dataset.TYPE == 'MT':
                    model = infer_data_job_mt(
                        model,
                        work_dir=pred_root,
                        model_name=model_name,
                        dataset=dataset,
                        verbose=args.verbose,
                        api_nproc=args.nproc,
                        ignore_failed=args.ignore)
                else:
                    model = infer_data_job(
                        model,
                        work_dir=pred_root,
                        model_name=model_name,
                        dataset=dataset,
                        verbose=args.verbose,
                        api_nproc=args.nproc,
                        ignore_failed=args.ignore)

                # Set the judge kwargs first before evaluation or dumping
                judge_kwargs = _build_judge_kwargs(args, dataset, dataset_name)
                if rank == 0:
                    logger.info(judge_kwargs)

                if world_size > 1:
                    dist.barrier()

                # Only rank 0 handles the evaluation part
                if rank == 0:
                    # Prepare submission files for MMMU_TEST and MMT-Bench_ALL instead of evaluating
                    if dataset_name == 'MMMU_TEST':
                        result_json = MMMU_result_transfer(result_file)
                        logger.info(f'Transfer MMMU_TEST result to json for official evaluation, '
                                    f'json file saved in {result_json}')
                        continue
                    if 'MMT-Bench_ALL' in dataset_name:
                        submission_file = MMTBench_result_transfer(result_file, **judge_kwargs)
                        logger.info(f'Extract options from prediction of MMT-Bench FULL split for official evaluation '
                                    f'(https://eval.ai/web/challenges/challenge-page/2328/overview), '
                                    f'submission file saved in {submission_file}')
                        continue

                    # Skip the evaluation part if only infer
                    if args.mode == 'infer':
                        continue

                    # Skip datasets whose evaluation is unsupported or whose annotations are missing
                    if 'MLLMGuard_DS' in dataset_name:
                        logger.info('The evaluation of MLLMGuard_DS is not supported yet. ')
                        continue
                    elif 'AesBench_TEST' == dataset_name:
                        # NOTE(review): "[email protected]" below looks like an email address
                        # mangled by a web scrape; restore the real AesBench contact address.
                        logger.info(f'The results are saved in {result_file}. '
                                    f'Please send it to the AesBench Team via [email protected].')
                        continue
                    elif dataset_name in ['DocVQA_TEST', 'InfoVQA_TEST', 'Q-Bench1_TEST', 'A-Bench_TEST']:
                        logger.info(f'{dataset_name} is a test split without ground-truth. '
                                    'Thus only the inference part is supported for those datasets. ')
                        continue
                    elif dataset_name in [
                        'MMBench_TEST_CN', 'MMBench_TEST_EN', 'MMBench', 'MMBench_CN',
                        'MMBench_TEST_CN_V11', 'MMBench_TEST_EN_V11', 'MMBench_V11', 'MMBench_CN_V11'
                    ] and not MMBenchOfficialServer(dataset_name):
                        logger.error(
                            f'Can not evaluate {dataset_name} on non-official servers, will skip the evaluation.')
                        continue

                    # Temporarily switch to EVAL_PROXY (if set). The finally-clause restores the
                    # previous proxy even when evaluate() raises, so a failing pair does not
                    # leak the evaluation proxy into the following pairs.
                    eval_proxy = os.environ.get('EVAL_PROXY', None)
                    old_proxy = os.environ.get('HTTP_PROXY', '')
                    if eval_proxy is not None:
                        proxy_set(eval_proxy)
                    try:
                        eval_results = dataset.evaluate(result_file, **judge_kwargs)
                    finally:
                        if eval_proxy is not None:
                            proxy_set(old_proxy)

                    # Display Evaluation Results in Terminal
                    if eval_results is not None:
                        assert isinstance(eval_results, (dict, pd.DataFrame))
                        logger.info(f'The evaluation of model {model_name} x dataset {dataset_name} has finished! ')
                        logger.info('Evaluation Results:')
                        if isinstance(eval_results, dict):
                            logger.info('\n' + json.dumps(eval_results, indent=4))
                        elif isinstance(eval_results, pd.DataFrame):
                            # Transpose wide tables so they print as a tall table.
                            if len(eval_results) < len(eval_results.columns):
                                eval_results = eval_results.T
                            logger.info('\n' + tabulate(eval_results))

                    # Create the symbolic links for the prediction files
                    _link_prediction_files(pred_root, pred_root_meta, model_name, dataset_name)

            except Exception as e:
                logger.exception(f'Model {model_name} x Dataset {dataset_name} combination failed: {e}, '
                                 'skipping this combination.')

        if world_size > 1:
            dist.barrier()

    if world_size > 1:
        dist.destroy_process_group()


if __name__ == '__main__':
    load_env()
    main()