Enable rocm and vulkan build in CI workflow for PJRT plugin #19279
Signed-off-by: PragmaTwice <[email protected]>
iree::experimental::rocm
iree::experimental::rocm::registration
iree::hal::drivers::hip
iree::hal::drivers::hip::registration
Oof sorry that this code has bit-rotted so much, and thanks for finding what needed to be updated.
// TODO: here we just use the target name of the first available device,
// but ideally we should find the device which will run the program
if (device_info_count > 0) {
  hipDeviceProp_tR0000 props;
  IREE_RETURN_IF_ERROR(iree_hal_hip_get_device_properties(
      *out_driver, device_infos->device_id, &props));

  // `gcnArchName` comes back like gfx90a:sramecc+:xnack- for a fully
  // specified target. However the IREE target-chip flag only expects the
  // prefix. refer to
  // https://github.com/iree-org/iree-turbine/blob/965247e/iree/turbine/runtime/device.py#L495
  std::string_view target = props.gcnArchName;
  if (auto pos = target.find(':'); pos != target.npos) {
    target = target.substr(0, pos);
  }

  hip_target_ = target;
  logger().debug("HIP target detected: " + hip_target_);
+ python test/test_add.py
WARNING:jax._src.xla_bridge:Platform 'iree_rocm' is experimental and not all JAX functionality may be correctly supported!
[IREE-PJRT] DEBUG: Using IREE compiler binary: /home/esaimana/actions-runner/_work/iree/iree/.venv/lib/python3.11/site-packages/iree/compiler/_mlir_libs/libIREECompiler.so
[IREE-PJRT] DEBUG: Compiler Version: 3.1.0.dev+da45650503b955e35a45ea8e34b5f10b30a2f912 @ da45650503b955e35a45ea8e34b5f10b30a2f912 (API version 1.4)
[IREE-PJRT] DEBUG: Partitioner was not enabled. The partitioner can be enabled by setting the 'PARTITIONER_LIB_PATH' config var ('IREE_PJRT_PARTITIONER_LIB_PATH' env var)
[IREE-PJRT] DEBUG: ROCM driver created
[IREE-PJRT] DEBUG: HIP target detected: gfx1100
iree/runtime/src/iree/hal/drivers/hip/native_executable.c:306: UNKNOWN; HIP driver error 'hipErrorFileNotFound' (301): file not found; mismatched target chip? missing/wrong bitcode directory?; while invoking native function hal.executable.create; while calling import;
[ 0] bytecode jit__add.__init:1284 "jit(_add)/jit(main)/add"("<module>"(/home/esaimana/actions-runner/_work/iree/iree/integrations/pjrt/test/test_add.py:9:6))
build_tools/testing/run_jax_tests.sh: line 31: 888283 Aborted (core dumped) JAX_PLATFORMS=$actual_jax_platform python $test_py_file > $actual_tmp_out
Hmm, the error in the ROCm plugin comes with a hint (mismatched target chip? missing/wrong bitcode directory?), but that seems odd, because I've already added new logic to detect the HIP target and pass it to IREE (via --iree-hip-target) before the compilation phase.
I'll investigate these errors over the next few days.
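For readers following along, the target detection described above can be sketched in Python. This is only an illustration of the prefix-stripping step from the C++ snippet under review, not the plugin's code; the helper names `hip_target_from_arch_name` and `hip_target_flag` are hypothetical, and only the `--iree-hip-target` flag and the `gfx90a:sramecc+:xnack-` format come from this thread.

```python
def hip_target_from_arch_name(gcn_arch_name: str) -> str:
    """Return the chip prefix of a gcnArchName-style string.

    Hypothetical helper: mirrors the find(':') / substr(0, pos) logic
    in the C++ snippet, dropping feature suffixes such as ':sramecc+'.
    """
    # partition() returns the whole string as the first element when
    # no ':' is present, so plain names like 'gfx1100' pass through.
    chip, _sep, _features = gcn_arch_name.partition(":")
    return chip


def hip_target_flag(gcn_arch_name: str) -> str:
    """Format the compiler flag named in this thread for a detected device."""
    return f"--iree-hip-target={hip_target_from_arch_name(gcn_arch_name)}"
```

Under these assumptions, a fully specified name and a bare chip name both reduce to the prefix that the target-chip flag expects.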
For test workflows, we typically set an environment variable explicitly, matched to the hardware installed in the runner chosen using the runs-on property:
IREE_HIP_TEST_TARGET_CHIP: "gfx1100"
In PJRT, matching the code from iree-turbine that detects available devices makes sense to me, instead of that environment variable approach. This is a JIT scenario, where the compiler is being used to generate code for the current device, not an arbitrary deployment target.
Can you have the PJRT client log the compile command used, including all flags?
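One possible shape for that logging, sketched in Python rather than in the plugin's C++: collect the flags in a list and log them as a single command line before compiling. The function names and the logger name are hypothetical, and rendering the flags as an `iree-compile` command line is an assumption made for readability (the log above shows the plugin loading libIREECompiler.so rather than running a CLI).

```python
import logging
import shlex

# Hypothetical logger name, used only for this sketch.
logger = logging.getLogger("iree_pjrt_sketch")


def build_compile_flags(hip_target: str) -> list[str]:
    """Collect compiler flags for a detected HIP target.

    Only --iree-hip-target is taken from this thread; any other
    flags the client sets would be appended to the same list.
    """
    return [f"--iree-hip-target={hip_target}"]


def log_compile_command(flags: list[str]) -> str:
    """Log the flags as one shell-quoted command line and return it."""
    command = shlex.join(["iree-compile", *flags])
    logger.debug("compile command: %s", command)
    return command
```

Logging the full flag list in one place makes it easier to check whether the detected target actually reaches the compiler.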
Sure! I'll dump more information for debugging.
Some findings about the Vulkan PJRT plugin: I read the relevant code in JAX and debugged it locally. The failure seems to occur because 4 GPU devices are enumerated by the IREE HAL Vulkan driver, while in JAX the sharding object ( A very simple workaround is to make the IREE PJRT plugin return only one addressable device, but this is obviously not a good solution. I'm also not sure whether it's related to the partitioning part of the PJRT plugin; I'm still investigating.
cc-ing some folks who may have suggestions there: @benvanik @AWoloszyn @antiagainst @sogartar
About the Vulkan failure (actually not specific to Vulkan; it can also fail on CUDA/ROCm if multiple GPUs are provided): After comparing the behavior of the IREE PJRT plugin and the official JAX CUDA plugin, I finally have a clue. I think the difference is that, with JAX CUDA, only one device is listed in the loaded executable, whereas IREE PJRT returns multiple (actually all) devices. And the reason is that in IREE PJRT's impl of PJRT_Client_Compile, I think I should split this PR into multiple ones, including the fix for the issue above.
…#19369)
It closes #19366, and blocks #19279.
After this PR, `ClientOptions::Compile` will first check the device assignment in the compile options, and then return the corresponding device list with the loaded executable.
To achieve this, we introduce protobuf via `FetchContent` in this PR, which is scoped to the PJRT plugin. Compile options will be passed by the PJRT client encoded in protobuf, and in this plugin we decode it first and then retrieve some interesting fields.
ci-exactly: build_packages, test_pjrt
---------
Signed-off-by: PragmaTwice <[email protected]>
Co-authored-by: Scott Todd <[email protected]>
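The device-selection idea in that commit message can be illustrated with a small Python sketch. It is not the plugin's implementation: the function name, the dict-of-devices representation, and the fallback to all devices when no assignment is given (which mirrors the pre-fix behavior described earlier in this thread) are all assumptions made for illustration.

```python
from typing import Dict, List, Optional, Sequence


def devices_for_executable(
    addressable: Dict[int, str],
    assigned_ids: Optional[Sequence[int]],
) -> List[str]:
    """Pick the devices a loaded executable should report.

    addressable: hypothetical mapping of device id -> device name.
    assigned_ids: device ids decoded from the compile options, if any.
    """
    if not assigned_ids:
        # Assumption: with no device assignment, report every device,
        # as the plugin reportedly did before the fix.
        return list(addressable.values())
    missing = [i for i in assigned_ids if i not in addressable]
    if missing:
        raise ValueError(f"unknown device ids in assignment: {missing}")
    # Preserve the order given by the assignment.
    return [addressable[i] for i in assigned_ids]
```

With four enumerated GPUs and an assignment naming only device 0, this sketch reports a single device, which matches the single-device behavior attributed to the JAX CUDA plugin above.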
Regarding |
Ahh thanks! I've opened PR #19369 to fix it and it's now merged, but thank you for the information. Maybe in the future we can make use of such a feature via a configuration option as well.
After merging #19369 into this branch, I think the Vulkan PJRT plugin starts to work and executes some computations. But it still aborts, due to another failure:
I believe this implies that the computation actually completes, but the IREE runtime fails to copy the result into host memory to display it (via iree/runtime/src/iree/hal/drivers/vulkan/native_allocator.cc Lines 616 to 633 in cb59389
(if this path is only for Mesa (or some of its underlying drivers?), maybe we can add a check and allow other drivers to pass?) Also, I think it's better to split this PR into a ROCm one and a Vulkan one, so that we have a clearer understanding of the current status of each platform and they don't block each other.
This PR tries to enable the ROCm and Vulkan PJRT plugins in the CI workflow (via the runner nodai-amdgpu-w7900-x86-64, mentioned here: #19222 (comment)).
Currently it's marked as a draft PR for potential CI debugging.
ci-exactly: build_packages, test_pjrt