This repository has been archived by the owner on Jul 24, 2024. It is now read-only.

Commit

Refactored warmup, increased dataset size for MLP (#78)
Egor-Krivov committed Feb 8, 2024
1 parent 1b06c2d commit 8029719
Showing 2 changed files with 12 additions and 24 deletions.
dl_bench/mlp.py: 6 changes (3 additions, 3 deletions)
@@ -81,8 +81,8 @@ def __init__(self, params) -> None:
 
         batch_size = int(params.get("batch_size", 1024))
 
-        min_batches = 10
-        DATASET_SIZE = max(10_240, batch_size * min_batches)
+        min_batches = 20
+        DATASET_SIZE = max(102_400, batch_size * min_batches)
         dataset = RandomInfDataset(DATASET_SIZE, in_shape)
 
         name = params.get("name", "size5")
@@ -92,5 +92,5 @@ def __init__(self, params) -> None:
 
         super().__init__(
             net=net, in_shape=in_shape, dataset=dataset, batch_size=batch_size,\
-            min_batches=min_batches, min_seconds=min_seconds
+            min_batches=min_batches, min_seconds=min_seconds, warmup_batches=10
         )
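Note: with these values, the MLP benchmark now draws from at least 102,400 random samples (up from 10,240), requires 20 measured batches instead of 10, and explicitly requests 10 warmup batches from the base Benchmark. A quick sanity check of the sizing rule, using the default batch size from the diff above (illustrative Python, not repository code):

    # Dataset sizing after this change (defaults taken from the diff above).
    batch_size = 1024                  # params.get("batch_size", 1024)
    min_batches = 20
    DATASET_SIZE = max(102_400, batch_size * min_batches)
    assert DATASET_SIZE == 102_400     # 1024 * 20 = 20_480, so the 102_400 floor wins
    # Only for batch sizes above 5,120 does batch_size * min_batches exceed the floor.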
dl_bench/utils.py: 30 changes (9 additions, 21 deletions)
@@ -343,12 +343,13 @@ def _get_device(device_name):
 
 class Benchmark:
     def __init__(
-        self, net, in_shape, dataset, batch_size, min_batches=10, min_seconds=10
+        self, net, in_shape, dataset, batch_size, min_batches=10, min_seconds=10, warmup_batches=3,
     ) -> None:
         self.net = net
         self.in_shape = in_shape
         self.dataset = dataset
         self.batch_size = batch_size
+        self.warmup_batches = warmup_batches
         self.min_batches = min_batches
         self.min_seconds = min_seconds
 
@@ -379,24 +380,6 @@ def inference(self, backend: Backend):
         sample = next(iter(test_loader))
         self.compile(sample, backend)
 
-        print("Warmup started")
-        with torch.no_grad(), tm.timeit("warmup_s"):
-            self.net.eval()
-            sample = backend.to_device(sample)
-            if backend.dtype != torch.float32:
-                with torch.autocast(
-                    device_type=backend.device_name,
-                    dtype=backend.dtype,
-                ):
-                    self.net(sample)
-                    self.net(sample)
-                    self.net(sample)
-            else:
-                self.net(sample)
-                self.net(sample)
-                self.net(sample)
-        print("Warmup done")
-
         n_items = 0
 
         self.net.eval()
@@ -417,15 +400,19 @@ def inference(self, backend: Backend):
                         y = self.net(x)
                 else:
                     y = self.net(x)
-                if i < 3: continue
+
+                if i < self.warmup_batches:
+                    start = time.perf_counter()
+                    continue
 
                 fw_times.append(get_time() - s)
                 n_items += len(x)
                 outputs.append(y)
 
                 # early stopping if we have 10+ batches and were running for 10+ seconds
                 if (
                     (time.perf_counter() - start) > self.min_seconds
-                    and n_items > self.batch_size * self.min_batches
+                    and n_items >= self.batch_size * self.min_batches
                 ):
                     break

@@ -437,6 +424,7 @@ def inference(self, backend: Backend):
         )
 
         results = tm.get_results()
+        results["duration_s"] = get_time() - start
         results["samples_per_s"] = n_items / sum(fw_times)
         results["flops_per_sample"] = self.flops_per_sample
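Note: taken together, the utils.py changes replace the separate pre-loop warmup block with in-loop warmup: the first warmup_batches iterations run the model but are excluded from the statistics, and start is reset on each skipped batch so warmup time counts toward neither the min_seconds budget nor the new duration_s result. A minimal standalone sketch of that pattern (timed_inference and to_device are illustrative stand-ins; the repository's get_time, backend, and autocast handling are omitted):

    import time

    def timed_inference(net, test_loader, to_device, warmup_batches=3,
                        min_batches=10, min_seconds=10, batch_size=1024):
        # Sketch of the in-loop warmup introduced here, not the repository's exact code.
        fw_times, n_items = [], 0
        start = time.perf_counter()
        for i, x in enumerate(test_loader):
            s = time.perf_counter()
            net(to_device(x))  # forward pass (outputs are collected in the real code)

            if i < warmup_batches:
                # Warmup batch: discard its timing and push `start` forward so the
                # warmup phase never counts toward min_seconds or duration_s.
                start = time.perf_counter()
                continue

            fw_times.append(time.perf_counter() - s)
            n_items += len(x)

            # Early stopping once enough samples have been measured for long enough.
            if (time.perf_counter() - start) > min_seconds and n_items >= batch_size * min_batches:
                break

        return {
            "duration_s": time.perf_counter() - start,
            "samples_per_s": n_items / sum(fw_times),
        }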


0 comments on commit 8029719
