diff --git a/lib.py b/lib.py index becc77d..18c85d6 100644 --- a/lib.py +++ b/lib.py @@ -179,7 +179,7 @@ def make_test(name, problem, problem_spec, add_sizes=[], constraint=lambda d: d) for size in add_sizes: del example[size] example["target"] = tensor(out) - if yours is not None: + if torch.is_tensor(yours) and yours.ndim > 0: example["yours"] = yours examples.append(example)