
Enable rocm and vulkan build in CI workflow for PJRT plugin #19279

Draft
wants to merge 12 commits into base: main

Conversation

@PragmaTwice commented Nov 23, 2024

This PR tries to enable the ROCm and Vulkan PJRT plugins in the CI workflow (via the runner nodai-amdgpu-w7900-x86-64, mentioned here: #19222 (comment)).

Currently it's marked as a draft PR for potential CI debugging.

ci-exactly: build_packages, test_pjrt

@PragmaTwice force-pushed the pjrt-ci-rocm branch 5 times, most recently from 311bd17 to f39168f on November 24, 2024 at 07:08
.github/workflows/pkgci_test_pjrt.yml (review thread resolved)
Comment on lines -16 to +17
- iree::experimental::rocm
- iree::experimental::rocm::registration
+ iree::hal::drivers::hip
+ iree::hal::drivers::hip::registration

Member

Oof sorry that this code has bit-rotted so much, and thanks for finding what needed to be updated.

@PragmaTwice
Member Author

+ python test/test_add.py
WARNING:jax._src.xla_bridge:Platform 'iree_rocm' is experimental and not all JAX functionality may be correctly supported!
[IREE-PJRT] DEBUG: Using IREE compiler binary: /home/esaimana/actions-runner/_work/iree/iree/.venv/lib/python3.11/site-packages/iree/compiler/_mlir_libs/libIREECompiler.so
[IREE-PJRT] DEBUG: Compiler Version: 3.1.0.dev+da45650503b955e35a45ea8e34b5f10b30a2f912 @ da45650503b955e35a45ea8e34b5f10b30a2f912 (API version 1.4)
[IREE-PJRT] DEBUG: Partitioner was not enabled. The partitioner can be enabled by setting the 'PARTITIONER_LIB_PATH' config var ('IREE_PJRT_PARTITIONER_LIB_PATH' env var)
[IREE-PJRT] DEBUG: ROCM driver created
[IREE-PJRT] DEBUG: HIP target detected: gfx1100
iree/runtime/src/iree/hal/drivers/hip/native_executable.c:306: UNKNOWN; HIP driver error 'hipErrorFileNotFound' (301): file not found; mismatched target chip? missing/wrong bitcode directory?; while invoking native function hal.executable.create; while calling import; 
[ 0] bytecode jit__add.__init:1284 "jit(_add)/jit(main)/add"("<module>"(/home/esaimana/actions-runner/_work/iree/iree/integrations/pjrt/test/test_add.py:9:6))
build_tools/testing/run_jax_tests.sh: line 31: 888283 Aborted                 (core dumped) JAX_PLATFORMS=$actual_jax_platform python $test_py_file > $actual_tmp_out

Hmm, the error in the ROCm plugin comes with a hint (mismatched target chip? missing/wrong bitcode directory?), but it seems weird since I've already added new logic to detect the HIP target and pass it to IREE (via --iree-hip-target) before the compilation phase.

I'll investigate these errors over the next few days.
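
As a rough illustration of that flow (a minimal sketch only; the helper name and the rocm backend flag are assumptions, not the plugin's actual interface), the detected target would end up as extra compiler flags roughly like this:

// Hedged sketch: BuildHipCompilerFlags is a hypothetical helper, not the
// plugin's real API. The HIP target detected from the device properties
// (e.g. "gfx1100") is forwarded via --iree-hip-target so the JIT-compiled
// executable matches the chip that will actually run it.
#include <string>
#include <vector>

std::vector<std::string> BuildHipCompilerFlags(const std::string& hip_target) {
  std::vector<std::string> flags;
  flags.push_back("--iree-hal-target-backends=rocm");  // assumed backend name
  if (!hip_target.empty()) {
    flags.push_back("--iree-hip-target=" + hip_target);
  }
  return flags;
}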

Comment on lines +48 to +65
// TODO: here we just use the target name of the first available device,
// but ideally we should find the device which will run the program
if (device_info_count > 0) {
  hipDeviceProp_tR0000 props;
  IREE_RETURN_IF_ERROR(iree_hal_hip_get_device_properties(
      *out_driver, device_infos->device_id, &props));

  // `gcnArchName` comes back like gfx90a:sramecc+:xnack- for a fully
  // specified target. However the IREE target-chip flag only expects the
  // prefix. refer to
  // https://github.com/iree-org/iree-turbine/blob/965247e/iree/turbine/runtime/device.py#L495
  std::string_view target = props.gcnArchName;
  if (auto pos = target.find(':'); pos != target.npos) {
    target = target.substr(0, pos);
  }

  hip_target_ = target;
  logger().debug("HIP target detected: " + hip_target_);

Member

(Quoting the hipErrorFileNotFound log and comment from above.)

For test workflows, we typically set an environment variable explicitly, matched to the hardware installed in the runner chosen via the runs-on property:

IREE_HIP_TEST_TARGET_CHIP: "gfx1100"

In PJRT, matching the code from iree-turbine that detects available devices makes sense to me, instead of that environment variable approach. This is a JIT scenario, where the compiler is being used to generate code for the current device, not an arbitrary deployment target.

Can you have the PJRT client log the compile command used, including all flags?

Member Author

Sure! I'll dump more information for debugging.
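
A minimal sketch of what that extra logging could look like (self-contained for illustration; the function and its parameters are hypothetical, not the plugin's actual API):

#include <functional>
#include <string>
#include <vector>

// Hedged sketch: join all compiler flags into one command line and hand it to
// whatever debug logger is available, so the exact invocation shows up in CI logs.
void LogCompileCommand(const std::vector<std::string>& flags,
                       const std::function<void(const std::string&)>& debug_log) {
  std::string command = "iree-compile";
  for (const std::string& flag : flags) {
    command += " " + flag;
  }
  debug_log("Invoking IREE compiler: " + command);
}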

PragmaTwice commented Nov 28, 2024

+ python test/test_add.py
WARNING:jax._src.xla_bridge:Platform 'iree_vulkan' is experimental and not all JAX functionality may be correctly supported!
[IREE-PJRT] DEBUG: Using IREE compiler binary: /home/esaimana/actions-runner-2/_work/iree/iree/.venv/lib/python3.11/site-packages/iree/compiler/_mlir_libs/libIREECompiler.so
[IREE-PJRT] DEBUG: Compiler Version: 3.1.0.dev+da45650503b955e35a45ea8e34b5f10b30a2f912 @ da45650503b955e35a45ea8e34b5f10b30a2f912 (API version 1.4)
[IREE-PJRT] DEBUG: Partitioner was not enabled. The partitioner can be enabled by setting the 'PARTITIONER_LIB_PATH' config var ('IREE_PJRT_PARTITIONER_LIB_PATH' env var)
[IREE-PJRT] DEBUG: Vulkan driver created
Traceback (most recent call last):
  File "/home/esaimana/actions-runner-2/_work/iree/iree/integrations/pjrt/test/test_add.py", line 7, in <module>
    a = jax.numpy.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9])
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/esaimana/actions-runner-2/_work/iree/iree/.venv/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 5592, in asarray
    return array(a, dtype=dtype, copy=bool(copy), order=order, device=device)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/esaimana/actions-runner-2/_work/iree/iree/.venv/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 5426, in array
    out_array: Array = lax_internal._convert_element_type(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/esaimana/actions-runner-2/_work/iree/iree/.venv/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 587, in _convert_element_type
    return convert_element_type_p.bind(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/esaimana/actions-runner-2/_work/iree/iree/.venv/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 2981, in _convert_element_type_bind
    operand = core.Primitive.bind(convert_element_type_p, operand,
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/esaimana/actions-runner-2/_work/iree/iree/.venv/lib/python3.11/site-packages/jax/_src/core.py", line 4[38](https://github.com/iree-org/iree/actions/runs/12031928938/job/33543044860?pr=19279#step:7:39), in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/esaimana/actions-runner-2/_work/iree/iree/.venv/lib/python3.11/site-packages/jax/_src/core.py", line 442, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/esaimana/actions-runner-2/_work/iree/iree/.venv/lib/python3.11/site-packages/jax/_src/core.py", line 955, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/esaimana/actions-runner-2/_work/iree/iree/.venv/lib/python3.11/site-packages/jax/_src/dispatch.py", line 91, in apply_primitive
    outs = fun(*args)
           ^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Expected args to execute_sharded_on_local_devices to have 4 shards, got: [1]

Some findings about the Vulkan PJRT plugin: I read the relevant code in JAX and debugged it locally. It seems the failure occurs because 4 GPU devices are enumerated by the IREE HAL Vulkan driver, while in JAX the sharding object (GSPMDSharding) holds only one device (I haven't checked why). This causes JAX to not correctly multiply the args by the number of devices when replicating the execution (_shard_np_array, execute_sharded), resulting in a mismatch in the number of arguments passed to execute_sharded_on_local_devices.

A very simple workaround is to make the IREE PJRT plugin return only one addressable device, but this is obviously not a good solution. I'm also not sure whether it's related to the partitioning part of the PJRT plugin; still investigating.

@ScottTodd
Member

(Quoting the Vulkan findings above.)

cc-ing some folks who may have suggestions there: @benvanik @AWoloszyn @antiagainst @sogartar

PragmaTwice commented Dec 4, 2024

About the Vulkan failure (actually not specific to Vulkan; it can also occur on CUDA/ROCm if multiple GPUs are present):

After comparing the behavior of the IREE PJRT plugin with the official JAX CUDA plugin, I finally got some clues. I think the difference is that on JAX CUDA there's only one device listed in the loaded executable, while the IREE PJRT plugin returns multiple (actually all) devices. The reason is that the IREE PJRT implementation of PJRT_Client_Compile ignores args->compile_options. We should read the device id/assignment from the compile options and pass it to the returned executable (rather than forwarding all devices to it). And compile_options is encoded in protobuf (the schema is here), so we need to decode it first (maybe via the protobuf C API generated from protoc, so that we don't need to import any XLA source files except the schema file).

I think I should split this PR into multiple ones, including a fix for the issue above.
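
A rough sketch of that decoding step (assumptions: the generated header path and the CompileOptionsProto / ExecutableBuildOptionsProto / DeviceAssignmentProto field names from XLA's compile_options.proto schema; this is not the code that eventually landed):

#include <cstdint>
#include <string>
#include <vector>

// Assumption: C++ code generated by protoc from XLA's compile_options.proto.
#include "xla/pjrt/compile_options.pb.h"

// Hedged sketch: decode the serialized compile options passed to
// PJRT_Client_Compile and pull out the device ids the executable should be
// bound to; an empty result means "no explicit assignment, use the default".
std::vector<int64_t> GetAssignedDeviceIds(const std::string& serialized_options) {
  xla::CompileOptionsProto options;
  if (!options.ParseFromString(serialized_options)) return {};
  const auto& build_options = options.executable_build_options();
  std::vector<int64_t> device_ids;
  if (build_options.has_device_assignment()) {
    for (const auto& computation :
         build_options.device_assignment().computation_devices()) {
      for (int64_t id : computation.replica_device_ids()) {
        device_ids.push_back(id);
      }
    }
  } else if (build_options.device_ordinal() >= 0) {
    // Simplified: a single explicitly requested device ordinal.
    device_ids.push_back(build_options.device_ordinal());
  }
  return device_ids;
}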

ScottTodd added a commit that referenced this pull request Dec 6, 2024
…#19369)

It closes #19366, and blocks #19279.

After this PR, `ClientOptions::Compile` will first check the device
assignment in the compile options, and then return the corresponding
device list with the loaded executable.

To achieve this, we introduce protobuf via `FetchContent` in this PR, scoped to the PJRT plugin. Compile options are passed by the PJRT client encoded in protobuf, and in this plugin we decode them first and then retrieve the fields of interest.

ci-exactly: build_packages, test_pjrt

---------

Signed-off-by: PragmaTwice <[email protected]>
Co-authored-by: Scott Todd <[email protected]>
@AWoloszyn
Contributor

Regarding "A very simple workaround is to make the iree PJRT plugin return only one addressable device, but this is obviously not a good solution": with the new HIP changes I have in #18790 you can expose a number of devices as a single logical device and use queue affinities to interact with each individual device.

@PragmaTwice
Member Author

(Quoting the comment above.)

Ahh, thanks! I've opened PR #19369 to fix it and it's now merged, but thank you for the information. Maybe in the future we can utilize such a feature via a configuration option as well.

PragmaTwice commented Dec 7, 2024

After merging #19369 into this branch, I think the Vulkan PJRT plugin starts to work and executes some computations. But it still aborts due to another failure:

Traceback (most recent call last):
  File "/home/esaimana/actions-runner-2/_work/iree/iree/integrations/pjrt/test/test_add.py", line 9, in <module>
    print(a + a)
  File "/home/esaimana/actions-runner-2/_work/iree/iree/.venv/lib/python3.11/site-packages/jax/_src/array.py", line 284, in __str__
    return str(self._value)
               ^^^^^^^^^^^
  File "/home/esaimana/actions-runner-2/_work/iree/iree/.venv/lib/python3.11/site-packages/jax/_src/profiler.py", line 333, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/esaimana/actions-runner-2/_work/iree/iree/.venv/lib/python3.11/site-packages/jax/_src/array.py", line 628, in _value
    self._npy_value = self._single_device_array_to_np_array()
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: iree/runtime/src/iree/hal/drivers/vulkan/native_allocator.cc:631: INVALID_ARGUMENT; cannot import mapped memory

I believe this implies that the computation is actually done, but the IREE runtime fails to copy the result into host memory to display it (via print). Here is the related code:

  // First check if the memory is importable.
  // Some drivers incorrectly succeed when attempting to import already-mapped
  // memory: https://gitlab.freedesktop.org/mesa/mesa/-/issues/9251.
  //
  // Attempt to synchronize the file with its memory map.
  // If the memory is not mapped from a file, attempting to synchronize it with
  // its memory map should fail fast and we can import the buffer. If the memory
  // *is* mapped, import may fail on some drivers (this may also be slow).
  // TODO(scotttodd): Further restrict this slow path to buggy drivers only?
  //   We'd need to plumb some driver information through to here
  errno = 0;
  (void)msync(external_buffer->handle.host_allocation.ptr,
              external_buffer->size, MS_SYNC);
  if (errno != ENOMEM) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "cannot import mapped memory");
  }

(If this path is only needed for Mesa (or some of its underlying drivers?), maybe we can add a check and allow other drivers to pass?)
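
For reference, a hedged sketch of what such a check could look like (not existing IREE code; it uses the standard VkPhysicalDeviceDriverProperties query from Vulkan 1.2 and assumes only Mesa RADV needs the slow path):

#include <vulkan/vulkan.h>

// Hedged sketch, not existing IREE code: query the driver ID and only take the
// msync probe path on drivers known to mis-report importability of
// already-mapped memory (the Mesa issue linked in the code above).
static bool NeedsMappedMemoryImportProbe(VkPhysicalDevice physical_device) {
  VkPhysicalDeviceDriverProperties driver_props = {};
  driver_props.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_DRIVER_PROPERTIES;
  VkPhysicalDeviceProperties2 props2 = {};
  props2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
  props2.pNext = &driver_props;
  vkGetPhysicalDeviceProperties2(physical_device, &props2);
  // Assumption: only Mesa RADV is affected; other drivers could import directly.
  return driver_props.driverID == VK_DRIVER_ID_MESA_RADV;
}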

Also, I think it's better to split this PR into a ROCm one and a Vulkan one, so that we have a clearer picture of the current status of each platform and so they don't block each other.
