diff --git a/examples/imagenet_val.py b/examples/imagenet_val.py index 1643b1a26..b8036f6f9 100644 --- a/examples/imagenet_val.py +++ b/examples/imagenet_val.py @@ -51,7 +51,7 @@ def main(): # Map model to be loaded to specified single gpu. loc = 'cuda:{}'.format(args.gpu) checkpoint = torch.load(args.resume, map_location=loc) - model.load_state_dict(checkpoint['state_dict'], strict=False) + model.load_state_dict(checkpoint, strict=False) valdir = os.path.join(args.imagenet_dir, 'val') normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])