Skip to content
This repository has been archived by the owner on Jul 24, 2024. It is now read-only.

Commit

Permalink
Update utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Egor-Krivov authored Dec 28, 2023
1 parent 587c08e commit cf902d1
Showing 1 changed file with 5 additions and 9 deletions.
14 changes: 5 additions & 9 deletions dl_bench/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,19 +158,15 @@ def _compile_model(compile_mode: str, device, model: Module, sample_input, dtype
if compile_mode == "torch":
compiled_model = model
elif compile_mode == "torchscript":
enabled = dtype != torch.float32
with torch.cpu.amp.autocast(enabled=enabled, dtype=dtype), torch.no_grad():
compiled_model = torch.jit.trace(model, sample_input)
compiled_model = torch.jit.freeze(compiled_model)
compiled_model = torch.jit.trace(model, sample_input)
compiled_model = torch.jit.freeze(compiled_model)
print("Compiled with torchscript")
elif compile_mode == "torchscript_onednn":
# enable oneDNN graph fusion globally
torch.jit.enable_onednn_fusion(True)
enabled = dtype != torch.float32
with torch.cpu.amp.autocast(enabled=enabled, dtype=dtype), torch.no_grad():
compiled_model = torch.jit.trace(model, sample_input)
compiled_model = torch.jit.freeze(compiled_model)
print("Compiled with torchscript onednn")
compiled_model = torch.jit.trace(model, sample_input)
compiled_model = torch.jit.freeze(compiled_model)
print("Compiled with torchscript onednn")
elif compile_mode == "ipex":
import intel_extension_for_pytorch as ipex

Expand Down

0 comments on commit cf902d1

Please sign in to comment.