Skip to content

Commit

Permalink
updates for interacting with newer torch-mlir (#534)
Browse files Browse the repository at this point in the history
  • Loading branch information
fifield authored Apr 24, 2024
1 parent 3856bbb commit 8afa722
Show file tree
Hide file tree
Showing 10 changed files with 63 additions and 36 deletions.
2 changes: 1 addition & 1 deletion python/air/backend/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import torch

from air.mlir.ir import Module
from air.ir import Module

# A type shared between the result of `AirBackend.compile` and the
# input to `AirBackend.load`. Each backend will likely have a
Expand Down
37 changes: 17 additions & 20 deletions python/air/backend/linalg_on_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
import torch
import torch_mlir.ir
import torch_mlir.passmanager
from torch_mlir.dynamo import make_simple_dynamo_backend
from torch_mlir import torchscript

import air.mlir.ir
import air.mlir.passmanager
import air.ir
import air.passmanager

from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend

Expand All @@ -26,8 +26,6 @@
from pathlib import Path
from typing import List

path = Path(air.backend.__file__).resolve().parent

# First need to load the libhsa-runtime64.so.1 so we can load libairhost_shared
try:
ctypes.CDLL(f"{rocm_path}/../../libhsa-runtime64.so.1", mode=ctypes.RTLD_GLOBAL)
Expand All @@ -38,12 +36,12 @@

# After loading libhsa-runtime64.so we can load the AIR runtime functions
try:
ctypes.CDLL(f"{path}/../../../runtime_lib/airhost/libairhost_shared.so", mode=ctypes.RTLD_GLOBAL)
ctypes.CDLL(f"{install_path()}/runtime_lib/x86_64/airhost/libairhost_shared.so", mode=ctypes.RTLD_GLOBAL)
except Exception as e:
print("[WARNING] We were not able to load .so for libairhost_shared.so");
print(e)
pass
import air.mlir._mlir_libs._airRt as airrt
import air._mlir_libs._airRt as airrt

__all__ = [
"LinalgOnTensorsAirBackend",
Expand Down Expand Up @@ -112,25 +110,26 @@ def compile(self, imported_module: torch_mlir.ir.Module, pipeline=None,

if type(imported_module) is torch_mlir.ir.Module:
with imported_module.context:
imported_module = torchscript._lower_mlir_module(False, torchscript.OutputType.LINALG_ON_TENSORS, imported_module)
pm = torch_mlir.passmanager.PassManager.parse('builtin.module(refback-mlprogram-bufferize)')
pm.run(imported_module)
pm.run(imported_module.operation)

with air.mlir.ir.Context():
air_module = air.mlir.ir.Module.parse(str(imported_module))
pm = air.mlir.passmanager.PassManager.parse(
with air.ir.Context():
air_module = air.ir.Module.parse(str(imported_module))
pm = air.passmanager.PassManager.parse(
air.compiler.util.LINALG_TENSOR_TO_MEMREF_PIPELINE)

if verbose:
print("Running MLIR pass pipeline: ",
air.compiler.util.LINALG_TENSOR_TO_MEMREF_PIPELINE)

pm.run(air_module)
pm.run(air_module.operation)

if verbose:
print("Running MLIR pass pipeline: ", pipeline)

pm = air.mlir.passmanager.PassManager.parse(pipeline)
pm.run(air_module)
pm = air.passmanager.PassManager.parse(pipeline)
pm.run(air_module.operation)

if verbose:
print("AIR Module:")
Expand Down Expand Up @@ -164,15 +163,15 @@ def load(self, module):
# Keeping the agent and queue as a part of the backend so Python doesn't delete them
self.a = a
self.q = q

return self.refbackend.load(module)

def unload(self):
"""Unload any loaded module and shutdown the air runtime."""
if self.handle:
airrt.host.module_unload(self.handle)
self.handle = None
airrt.host.shut_down()
self.q = None
self.a = None

def make_dynamo_backend(pipeline=None, verbose=False,
segment_offset=None, segment_size=None):
Expand All @@ -193,12 +192,10 @@ def make_dynamo_backend(pipeline=None, verbose=False,
@make_simple_dynamo_backend
def air_backend(fx_graph: torch.fx.GraphModule,
example_inputs: List[torch.Tensor]):

# get the linalg mlir of the model from torch_mlir
mlir_module = torch_mlir.compile(
mlir_module = torchscript.compile(
fx_graph, example_inputs,
output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)

output_type="linalg-on-tensors")
# compile the mlir model with aircc
compiled = backend.compile(mlir_module, pipeline=pipeline,
verbose=verbose, segment_offset=segment_offset,
Expand Down
7 changes: 7 additions & 0 deletions python/air/compiler/aircc/configure.py.in
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@
#
# (c) Copyright 2023 Advanced Micro Devices Inc.

import os

air_link_with_xchesscc = @CONFIG_LINK_WITH_XCHESSCC@
air_compile_with_xchesscc = @CONFIG_COMPILE_WITH_XCHESSCC@
libxaie_path = "@XILINX_XAIE_DIR@"
rocm_path = "@hsa-runtime64_DIR@"

def install_path():
path = os.path.dirname(os.path.realpath(__file__))
path = os.path.join(path, '..', '..', '..', '..')
return os.path.realpath(path)
23 changes: 15 additions & 8 deletions python/test/lit.cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@
# -*- Python -*-

import os
import platform
import re
import subprocess
import tempfile

import lit.formats
import lit.util
Expand All @@ -26,9 +22,20 @@

config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
config.environment['PYTHONPATH'] \
= "{}".format(os.path.join(config.air_obj_root, "python"))
#os.environ['PYTHONPATH']
print("Running with PYTHONPATH",config.environment['PYTHONPATH'])
= "{}:{}".format(os.path.join(config.air_obj_root, "python"),
os.path.join(config.aie_obj_root, "python"))

try:
import torch_mlir
torch_mlir_path = os.path.join(torch_mlir.__path__[0], "..")
print ("found torch_mlir:", torch_mlir_path)
config.environment['PYTHONPATH'] = config.environment['PYTHONPATH'] + ":" + torch_mlir_path
config.available_features.add("torch_mlir")
except:
print("torch_mlir not found")
pass

print("Running with PYTHONPATH", config.environment['PYTHONPATH'])

# suffixes: A list of file extensions to treat as test files.
config.suffixes = ['.py']
Expand All @@ -52,7 +59,7 @@
# excludes: A list of directories to exclude from the testsuite. The 'Inputs'
# subdirectories contain auxiliary inputs for various tests in their parent
# directories.
config.excludes = ['lit.cfg.py', 'torch_mlir_e2e']
config.excludes = ['lit.cfg.py']

# test_source_root: The root path where tests are located.
config.test_source_root = os.path.dirname(__file__)
Expand Down
20 changes: 13 additions & 7 deletions python/test/torch_mlir_e2e/add.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,17 @@
# Copyright (C) 2022, Advanced Micro Devices, Inc.
# SPDX-License-Identifier: MIT

# REQUIRES: torch_mlir

# RUN: %PYTHON %s | FileCheck %s
# CHECK: PASSED

import torch
import torch._dynamo as dynamo
import numpy
from air.backend import linalg_on_tensors as backend

air_backend = backend.make_dynamo_backend()
import air.backend.linalg_on_tensors as air_backend
from torch_mlir import fx

class model(torch.nn.Module):
def __init__(self):
Expand All @@ -23,17 +25,21 @@ def forward(self, a, b):
return x

def run_test(dtype, shape):
program = model()
dynamo_program = dynamo.optimize(air_backend)(program)
torch_program = model()

a = torch.randint(size = shape, low=1, high=100, dtype=dtype)
b = torch.randint(size = shape, low=1, high=100, dtype=dtype)
c = dynamo_program(a, b)
c_ref = program(a, b)
m = fx.export_and_import(torch_program, a, b, func_name="forward")

backend = air_backend.LinalgOnTensorsAirBackend()
air_program = backend.load(backend.compile(m))

c_ref = torch_program(a, b)
c = torch.tensor(air_program.forward(a.numpy(), b.numpy()))

print(f"input:\n{a}\n{b}\noutput:\n{c}")

if torch.allclose(c_ref,c):
if torch.allclose(c_ref, c):
print("PASS!")
return 1
else:
Expand Down
2 changes: 2 additions & 0 deletions python/test/torch_mlir_e2e/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# Copyright (C) 2022, Advanced Micro Devices, Inc.
# SPDX-License-Identifier: MIT

# REQUIRES: torch_mlir, needs_update

# RUN: %PYTHON %s | FileCheck %s
# CHECK: PASS

Expand Down
2 changes: 2 additions & 0 deletions python/test/torch_mlir_e2e/matmul_mul_i32.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# Copyright (C) 2022, Advanced Micro Devices, Inc.
# SPDX-License-Identifier: MIT

# REQUIRES: torch_mlir, needs_update

# RUN: %PYTHON %s | FileCheck %s
# CHECK: PASS

Expand Down
2 changes: 2 additions & 0 deletions python/test/torch_mlir_e2e/mul.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# Copyright (C) 2022, Advanced Micro Devices, Inc.
# SPDX-License-Identifier: MIT

# REQUIRES: torch_mlir, needs_update

# RUN: %PYTHON %s | FileCheck %s
# CHECK: PASSED

Expand Down
2 changes: 2 additions & 0 deletions python/test/torch_mlir_e2e/mul_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc.
# SPDX-License-Identifier: MIT

# REQUIRES: torch_mlir, needs_update

# RUN: %PYTHON %s | FileCheck %s
# CHECK: PASSED

Expand Down
2 changes: 2 additions & 0 deletions python/test/torch_mlir_e2e/relu.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# Copyright (C) 2022, Advanced Micro Devices, Inc.
# SPDX-License-Identifier: MIT

# REQUIRES: torch_mlir, needs_update

# RUN: %PYTHON %s | FileCheck %s
# CHECK: PASSED

Expand Down

0 comments on commit 8afa722

Please sign in to comment.