diff --git a/dl_bench/mlp.py b/dl_bench/mlp.py
index c05ddc2..93845a8 100644
--- a/dl_bench/mlp.py
+++ b/dl_bench/mlp.py
@@ -81,8 +81,8 @@ def __init__(self, params) -> None:
 
         batch_size = int(params.get("batch_size", 1024))
 
-        min_batches = 10
-        DATASET_SIZE = max(10_240, batch_size * min_batches)
+        min_batches = 20
+        DATASET_SIZE = max(102_400, batch_size * min_batches)
         dataset = RandomInfDataset(DATASET_SIZE, in_shape)
 
         name = params.get("name", "size5")
@@ -92,5 +92,5 @@ def __init__(self, params) -> None:
 
         super().__init__(
             net=net, in_shape=in_shape, dataset=dataset, batch_size=batch_size,\
-            min_batches=min_batches, min_seconds=min_seconds
+            min_batches=min_batches, min_seconds=min_seconds, warmup_batches=10
         )
diff --git a/dl_bench/utils.py b/dl_bench/utils.py
index 28776e5..9300fa2 100644
--- a/dl_bench/utils.py
+++ b/dl_bench/utils.py
@@ -343,12 +343,13 @@ def _get_device(device_name):
 
 
 class Benchmark:
     def __init__(
-        self, net, in_shape, dataset, batch_size, min_batches=10, min_seconds=10
+        self, net, in_shape, dataset, batch_size, min_batches=10, min_seconds=10, warmup_batches=3,
     ) -> None:
         self.net = net
         self.in_shape = in_shape
         self.dataset = dataset
         self.batch_size = batch_size
+        self.warmup_batches = warmup_batches
         self.min_batches = min_batches
         self.min_seconds = min_seconds
@@ -379,24 +380,6 @@ def inference(self, backend: Backend):
         sample = next(iter(test_loader))
         self.compile(sample, backend)
 
-        print("Warmup started")
-        with torch.no_grad(), tm.timeit("warmup_s"):
-            self.net.eval()
-            sample = backend.to_device(sample)
-            if backend.dtype != torch.float32:
-                with torch.autocast(
-                    device_type=backend.device_name,
-                    dtype=backend.dtype,
-                ):
-                    self.net(sample)
-                    self.net(sample)
-                    self.net(sample)
-            else:
-                self.net(sample)
-                self.net(sample)
-                self.net(sample)
-        print("Warmup done")
-
         n_items = 0
 
         self.net.eval()
@@ -417,7 +400,11 @@ def inference(self, backend: Backend):
                     y = self.net(x)
             else:
                 y = self.net(x)
-            if i < 3: continue
+
+            if i < self.warmup_batches:
+                start = time.perf_counter()
+                continue
+
             fw_times.append(get_time() - s)
             n_items += len(x)
             outputs.append(y)
@@ -425,7 +412,7 @@ def inference(self, backend: Backend):
 
             # early stopping if we have 10+ batches and were running for 10+ seconds
             if (
                 (time.perf_counter() - start) > self.min_seconds
-                and n_items > self.batch_size * self.min_batches
+                and n_items >= self.batch_size * self.min_batches
             ):
                 break
@@ -437,6 +424,7 @@ def inference(self, backend: Backend):
         )
 
         results = tm.get_results()
+        results["duration_s"] = get_time() - start
        results["samples_per_s"] = n_items / sum(fw_times)
        results["flops_per_sample"] = self.flops_per_sample