diff --git a/runtime/src/iree/hal/drivers/hip/CMakeLists.txt b/runtime/src/iree/hal/drivers/hip/CMakeLists.txt index ef9576d8d66c5..7fd5abf072948 100644 --- a/runtime/src/iree/hal/drivers/hip/CMakeLists.txt +++ b/runtime/src/iree/hal/drivers/hip/CMakeLists.txt @@ -21,6 +21,7 @@ iree_cc_library( "api.h" SRCS "api.h" + "context_util.h" "event_pool.c" "event_pool.h" "event_semaphore.c" diff --git a/runtime/src/iree/hal/drivers/hip/context_util.h b/runtime/src/iree/hal/drivers/hip/context_util.h new file mode 100644 index 0000000000000..1aa1d79b4c28f --- /dev/null +++ b/runtime/src/iree/hal/drivers/hip/context_util.h @@ -0,0 +1,34 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_HAL_DRIVERS_HIP_CONTEXT_UTIL_H_ +#define IREE_HAL_DRIVERS_HIP_CONTEXT_UTIL_H_ + +#include "iree/base/api.h" +#include "iree/hal/drivers/hip/dynamic_symbols.h" +#include "iree/hal/drivers/hip/status_util.h" + +static inline iree_status_t iree_hal_hip_set_context( + const iree_hal_hip_dynamic_symbols_t* syms, hipCtx_t hip_context) { + if (!hip_context) { + return iree_ok_status(); + } + IREE_TRACE({ + hipCtx_t current_context = NULL; + IREE_HIP_RETURN_IF_ERROR(syms, hipCtxGetCurrent(&current_context), + "hipCtxGetCurrent"); + if (current_context != hip_context) { + IREE_TRACE_ZONE_BEGIN_NAMED(z0, "iree_hal_hip_set_context_switch"); + iree_status_t status = + IREE_HIP_RESULT_TO_STATUS(syms, hipCtxSetCurrent(hip_context)); + IREE_TRACE_ZONE_END(z0); + return status; + } + }); + return IREE_HIP_RESULT_TO_STATUS(syms, hipCtxSetCurrent(hip_context)); +} + +#endif // IREE_HAL_DRIVERS_HIP_CONTEXT_UTIL_H_ diff --git a/runtime/src/iree/hal/drivers/hip/dynamic_symbol_tables.h b/runtime/src/iree/hal/drivers/hip/dynamic_symbol_tables.h index 963784d81dcc1..85d33740c83da 100644 --- 
a/runtime/src/iree/hal/drivers/hip/dynamic_symbol_tables.h +++ b/runtime/src/iree/hal/drivers/hip/dynamic_symbol_tables.h @@ -8,6 +8,7 @@ // HIP symbols //===----------------------------------------------------------------------===// +IREE_HAL_HIP_REQUIRED_PFN_DECL(hipCtxGetCurrent, hipCtx_t *) IREE_HAL_HIP_REQUIRED_PFN_DECL(hipCtxSetCurrent, hipCtx_t) IREE_HAL_HIP_REQUIRED_PFN_DECL(hipDeviceGet, hipDevice_t *, int) IREE_HAL_HIP_REQUIRED_PFN_DECL(hipDeviceGetAttribute, int *, diff --git a/runtime/src/iree/hal/drivers/hip/event_pool.c b/runtime/src/iree/hal/drivers/hip/event_pool.c index 38c93113b8390..010cfd67fd2e8 100644 --- a/runtime/src/iree/hal/drivers/hip/event_pool.c +++ b/runtime/src/iree/hal/drivers/hip/event_pool.c @@ -14,6 +14,7 @@ #include "iree/base/internal/atomics.h" #include "iree/base/internal/synchronization.h" #include "iree/hal/api.h" +#include "iree/hal/drivers/hip/context_util.h" #include "iree/hal/drivers/hip/dynamic_symbols.h" #include "iree/hal/drivers/hip/status_util.h" @@ -36,6 +37,10 @@ struct iree_hal_hip_event_t { // The event pool that owns this event. This cannot be NULL. We retain it to // make sure the event outlive the pool. iree_hal_hip_event_pool_t* pool; + + // The context to use to free this event, it must be the same + // context as was used when allocating the event. + hipCtx_t hip_context; // The underlying hipEvent_t object. 
hipEvent_t hip_event; }; @@ -48,6 +53,8 @@ static inline void iree_hal_hip_event_destroy(iree_hal_hip_event_t* event) { iree_allocator_t host_allocator = event->host_allocator; const iree_hal_hip_dynamic_symbols_t* symbols = event->symbols; IREE_TRACE_ZONE_BEGIN(z0); + IREE_IGNORE_ERROR( + iree_hal_hip_set_context(event->symbols, event->hip_context)); IREE_ASSERT_REF_COUNT_ZERO(&event->ref_count); IREE_HIP_IGNORE_ERROR(symbols, hipEventDestroy(event->hip_event)); @@ -58,8 +65,8 @@ static inline void iree_hal_hip_event_destroy(iree_hal_hip_event_t* event) { static inline iree_status_t iree_hal_hip_event_create( const iree_hal_hip_dynamic_symbols_t* symbols, - iree_hal_hip_event_pool_t* pool, iree_allocator_t host_allocator, - iree_hal_hip_event_t** out_event) { + iree_hal_hip_event_pool_t* pool, hipCtx_t context, + iree_allocator_t host_allocator, iree_hal_hip_event_t** out_event) { IREE_ASSERT_ARGUMENT(symbols); IREE_ASSERT_ARGUMENT(pool); IREE_ASSERT_ARGUMENT(out_event); @@ -75,6 +82,7 @@ static inline iree_status_t iree_hal_hip_event_create( event->symbols = symbols; event->pool = pool; event->hip_event = NULL; + event->hip_context = context; iree_status_t status = IREE_HIP_RESULT_TO_STATUS( symbols, @@ -122,6 +130,10 @@ struct iree_hal_hip_event_pool_t { // The symbols used to create and destroy hipEvent_t objects. const iree_hal_hip_dynamic_symbols_t* symbols; + // The context for this event pool to use to allocate + // events. + hipCtx_t hip_context; + // Guards event related fields in the pool. We don't expect a performant // program to frequently allocate events for synchronization purposes; the // traffic to this pool should be low. 
So it should be fine to use mutex to @@ -142,7 +154,7 @@ struct iree_hal_hip_event_pool_t { static void iree_hal_hip_event_pool_free(iree_hal_hip_event_pool_t* event_pool); iree_status_t iree_hal_hip_event_pool_allocate( - const iree_hal_hip_dynamic_symbols_t* symbols, + const iree_hal_hip_dynamic_symbols_t* symbols, hipCtx_t hip_context, iree_host_size_t available_capacity, iree_allocator_t host_allocator, iree_hal_hip_event_pool_t** out_event_pool) { IREE_ASSERT_ARGUMENT(symbols); @@ -163,11 +175,12 @@ iree_status_t iree_hal_hip_event_pool_allocate( iree_slim_mutex_initialize(&event_pool->event_mutex); event_pool->available_capacity = available_capacity; event_pool->available_count = 0; + event_pool->hip_context = hip_context; iree_status_t status = iree_ok_status(); for (iree_host_size_t i = 0; i < available_capacity; ++i) { status = iree_hal_hip_event_create( - symbols, event_pool, host_allocator, + symbols, event_pool, hip_context, host_allocator, &event_pool->available_list[event_pool->available_count++]); if (!iree_status_is_ok(status)) break; } @@ -240,9 +253,9 @@ iree_status_t iree_hal_hip_event_pool_acquire( IREE_TRACE_ZONE_BEGIN_NAMED(z1, "event-pool-unpooled-acquire"); iree_status_t status = iree_ok_status(); for (iree_host_size_t i = 0; i < remaining_count; ++i) { - status = iree_hal_hip_event_create(event_pool->symbols, event_pool, - event_pool->host_allocator, - &out_events[from_pool_count + i]); + status = iree_hal_hip_event_create( + event_pool->symbols, event_pool, event_pool->hip_context, + event_pool->host_allocator, &out_events[from_pool_count + i]); if (!iree_status_is_ok(status)) { // Must release all events we've acquired so far. 
iree_hal_hip_event_pool_release_event(event_pool, from_pool_count + i, diff --git a/runtime/src/iree/hal/drivers/hip/event_pool.h b/runtime/src/iree/hal/drivers/hip/event_pool.h index ea09b90cc7cf2..0683714d97240 100644 --- a/runtime/src/iree/hal/drivers/hip/event_pool.h +++ b/runtime/src/iree/hal/drivers/hip/event_pool.h @@ -52,7 +52,7 @@ typedef struct iree_hal_hip_event_pool_t iree_hal_hip_event_pool_t; // Extra events requested beyond the capability are directly created and // destroyed without pooling. iree_status_t iree_hal_hip_event_pool_allocate( - const iree_hal_hip_dynamic_symbols_t* symbols, + const iree_hal_hip_dynamic_symbols_t* symbols, hipCtx_t hip_context, iree_host_size_t available_capacity, iree_allocator_t host_allocator, iree_hal_hip_event_pool_t** out_event_pool); diff --git a/runtime/src/iree/hal/drivers/hip/event_semaphore.c b/runtime/src/iree/hal/drivers/hip/event_semaphore.c index de10b09125ec8..686e23015b8f7 100644 --- a/runtime/src/iree/hal/drivers/hip/event_semaphore.c +++ b/runtime/src/iree/hal/drivers/hip/event_semaphore.c @@ -9,7 +9,9 @@ #include "iree/base/internal/synchronization.h" #include "iree/base/internal/wait_handle.h" #include "iree/base/status.h" +#include "iree/hal/drivers/hip/context_util.h" #include "iree/hal/drivers/hip/dynamic_symbols.h" +#include "iree/hal/drivers/hip/status_util.h" #include "iree/hal/drivers/hip/timepoint_pool.h" #include "iree/hal/utils/semaphore_base.h" @@ -30,6 +32,8 @@ typedef struct iree_hal_hip_semaphore_t { // new signaled values. iree_hal_deferred_work_queue_t* work_queue; + hipCtx_t hip_context; + // Guards value and status. We expect low contention on semaphores and since // iree_slim_mutex_t is (effectively) just a CAS this keeps things simpler // than trying to make the entire structure lock-free. 
@@ -56,7 +60,7 @@ static iree_hal_hip_semaphore_t* iree_hal_hip_semaphore_cast( iree_status_t iree_hal_hip_event_semaphore_create( uint64_t initial_value, const iree_hal_hip_dynamic_symbols_t* symbols, - iree_hal_hip_timepoint_pool_t* timepoint_pool, + hipCtx_t hip_context, iree_hal_hip_timepoint_pool_t* timepoint_pool, iree_hal_deferred_work_queue_t* work_queue, iree_allocator_t host_allocator, iree_hal_semaphore_t** out_semaphore) { IREE_ASSERT_ARGUMENT(symbols); @@ -65,6 +69,8 @@ iree_status_t iree_hal_hip_event_semaphore_create( IREE_ASSERT_ARGUMENT(out_semaphore); IREE_TRACE_ZONE_BEGIN(z0); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_hip_set_context(symbols, hip_context)); iree_hal_hip_semaphore_t* semaphore = NULL; IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_allocator_malloc(host_allocator, sizeof(*semaphore), @@ -79,6 +85,7 @@ iree_status_t iree_hal_hip_event_semaphore_create( iree_slim_mutex_initialize(&semaphore->mutex); semaphore->current_value = initial_value; semaphore->failure_status = iree_ok_status(); + semaphore->hip_context = hip_context; *out_semaphore = &semaphore->base; @@ -92,6 +99,8 @@ static void iree_hal_hip_semaphore_destroy( iree_hal_hip_semaphore_cast(base_semaphore); iree_allocator_t host_allocator = semaphore->host_allocator; IREE_TRACE_ZONE_BEGIN(z0); + IREE_IGNORE_ERROR( + iree_hal_hip_set_context(semaphore->symbols, semaphore->hip_context)); iree_status_ignore(semaphore->failure_status); iree_slim_mutex_deinitialize(&semaphore->mutex); diff --git a/runtime/src/iree/hal/drivers/hip/event_semaphore.h b/runtime/src/iree/hal/drivers/hip/event_semaphore.h index 88a75e01c4361..a5d8ff95b3691 100644 --- a/runtime/src/iree/hal/drivers/hip/event_semaphore.h +++ b/runtime/src/iree/hal/drivers/hip/event_semaphore.h @@ -31,7 +31,7 @@ extern "C" { // Thread-safe; multiple threads may signal/wait values on the same semaphore. 
iree_status_t iree_hal_hip_event_semaphore_create( uint64_t initial_value, const iree_hal_hip_dynamic_symbols_t* symbols, - iree_hal_hip_timepoint_pool_t* timepoint_pool, + hipCtx_t hip_context, iree_hal_hip_timepoint_pool_t* timepoint_pool, iree_hal_deferred_work_queue_t* work_queue, iree_allocator_t host_allocator, iree_hal_semaphore_t** out_semaphore); diff --git a/runtime/src/iree/hal/drivers/hip/hip_allocator.c b/runtime/src/iree/hal/drivers/hip/hip_allocator.c index 041e011536f83..95a04cd1812b2 100644 --- a/runtime/src/iree/hal/drivers/hip/hip_allocator.c +++ b/runtime/src/iree/hal/drivers/hip/hip_allocator.c @@ -10,6 +10,7 @@ #include "iree/base/api.h" #include "iree/base/tracing.h" +#include "iree/hal/drivers/hip/context_util.h" #include "iree/hal/drivers/hip/dynamic_symbols.h" #include "iree/hal/drivers/hip/hip_buffer.h" #include "iree/hal/drivers/hip/status_util.h" @@ -29,6 +30,8 @@ typedef struct iree_hal_hip_allocator_t { // The HIP stream that allocations should be used in. hipStream_t stream; + hipCtx_t hip_context; + // NOTE: optional depending on device support. 
iree_hal_hip_memory_pools_t* pools; @@ -54,11 +57,14 @@ static iree_hal_hip_allocator_t* iree_hal_hip_allocator_cast( iree_status_t iree_hal_hip_allocator_create( const iree_hal_hip_dynamic_symbols_t* hip_symbols, hipDevice_t device, - hipStream_t stream, iree_hal_hip_memory_pools_t* pools, - iree_allocator_t host_allocator, iree_hal_allocator_t** out_allocator) { + hipCtx_t hip_context, hipStream_t stream, + iree_hal_hip_memory_pools_t* pools, iree_allocator_t host_allocator, + iree_hal_allocator_t** out_allocator) { IREE_ASSERT_ARGUMENT(hip_symbols); IREE_ASSERT_ARGUMENT(out_allocator); IREE_TRACE_ZONE_BEGIN(z0); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_hip_set_context(hip_symbols, hip_context)); // To support device-local + host-visible memory we need concurrent managed // access indicating that the host and devices can concurrently access the @@ -94,6 +100,7 @@ iree_status_t iree_hal_hip_allocator_create( allocator->host_allocator = host_allocator; allocator->supports_concurrent_managed_access = supports_concurrent_managed_access != 0; + allocator->hip_context = hip_context; *out_allocator = (iree_hal_allocator_t*)allocator; IREE_TRACE_ZONE_END(z0); @@ -352,6 +359,9 @@ static iree_status_t iree_hal_hip_allocator_allocate_buffer( void* host_ptr = NULL; hipDeviceptr_t device_ptr = NULL; IREE_TRACE_ZONE_BEGIN_NAMED(z0, "iree_hal_hip_buffer_allocate"); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_hip_set_context(allocator->symbols, allocator->hip_context)); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, allocation_size); if (iree_all_bits_set(compat_params.type, IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL)) { @@ -431,6 +441,9 @@ static void iree_hal_hip_allocator_deallocate_buffer( iree_hal_hip_allocator_t* allocator = iree_hal_hip_allocator_cast(base_allocator); + IREE_IGNORE_ERROR( + iree_hal_hip_set_context(allocator->symbols, allocator->hip_context)); + const iree_hal_hip_buffer_type_t buffer_type = iree_hal_hip_buffer_type(base_buffer); @@ -466,6 +479,9 @@ 
static iree_status_t iree_hal_hip_allocator_import_buffer( iree_hal_hip_allocator_t* allocator = iree_hal_hip_allocator_cast(base_allocator); + IREE_RETURN_IF_ERROR( + iree_hal_hip_set_context(allocator->symbols, allocator->hip_context)); + // Coerce options into those required by the current device. iree_hal_buffer_params_t compat_params = *params; iree_device_size_t allocation_size = external_buffer->size; @@ -600,6 +616,9 @@ iree_status_t iree_hal_hip_allocator_alloc_async( iree_hal_hip_allocator_t* allocator = iree_hal_hip_allocator_cast(base_allocator); + IREE_RETURN_IF_ERROR( + iree_hal_hip_set_context(allocator->symbols, allocator->hip_context)); + hipDeviceptr_t ptr = NULL; iree_status_t status = IREE_HIP_RESULT_TO_STATUS( allocator->symbols, @@ -625,6 +644,9 @@ iree_status_t iree_hal_hip_allocator_free_async( iree_hal_buffer_t* buffer) { iree_hal_hip_allocator_t* allocator = iree_hal_hip_allocator_cast(base_allocator); + IREE_RETURN_IF_ERROR( + iree_hal_hip_set_context(allocator->symbols, allocator->hip_context)); + hipDeviceptr_t device_ptr = iree_hal_hip_buffer_device_pointer(buffer); if (!device_ptr) { return iree_ok_status(); diff --git a/runtime/src/iree/hal/drivers/hip/hip_allocator.h b/runtime/src/iree/hal/drivers/hip/hip_allocator.h index bb83ea18e5ad4..89ed0936be0c8 100644 --- a/runtime/src/iree/hal/drivers/hip/hip_allocator.h +++ b/runtime/src/iree/hal/drivers/hip/hip_allocator.h @@ -17,13 +17,14 @@ extern "C" { #endif // __cplusplus // Creates a HIP memory allocator. -// |device| and |stream| will be used for management operations. +// |device| |hip_context| and |stream| will be used for management operations. // |pools| provides memory pools that may be shared across multiple allocators // and the pointer must remain valid for the lifetime of the allocator. 
iree_status_t iree_hal_hip_allocator_create( const iree_hal_hip_dynamic_symbols_t* hip_symbols, hipDevice_t device, - hipStream_t stream, iree_hal_hip_memory_pools_t* pools, - iree_allocator_t host_allocator, iree_hal_allocator_t** out_allocator); + hipCtx_t hip_context, hipStream_t stream, + iree_hal_hip_memory_pools_t* pools, iree_allocator_t host_allocator, + iree_hal_allocator_t** out_allocator); bool iree_hal_hip_allocator_isa(iree_hal_allocator_t* base_value); diff --git a/runtime/src/iree/hal/drivers/hip/hip_device.c b/runtime/src/iree/hal/drivers/hip/hip_device.c index e48b5241a0b6d..bc001d645c3a3 100644 --- a/runtime/src/iree/hal/drivers/hip/hip_device.c +++ b/runtime/src/iree/hal/drivers/hip/hip_device.c @@ -14,6 +14,7 @@ #include "iree/base/internal/event_pool.h" #include "iree/base/internal/math.h" #include "iree/base/tracing.h" +#include "iree/hal/drivers/hip/context_util.h" #include "iree/hal/drivers/hip/dynamic_symbols.h" #include "iree/hal/drivers/hip/event_pool.h" #include "iree/hal/drivers/hip/event_semaphore.h" @@ -538,13 +539,13 @@ static iree_status_t iree_hal_hip_device_create_internal( // Create memory pools first so that we can share them with the allocator. if (iree_status_is_ok(status) && device->supports_memory_pools) { status = iree_hal_hip_memory_pools_initialize( - symbols, hip_device, &params->memory_pools, host_allocator, + symbols, hip_device, context, &params->memory_pools, host_allocator, &device->memory_pools); } if (iree_status_is_ok(status)) { status = iree_hal_hip_allocator_create( - symbols, hip_device, dispatch_stream, + symbols, hip_device, context, dispatch_stream, device->supports_memory_pools ? 
&device->memory_pools : NULL, host_allocator, &device->device_allocator); } @@ -608,9 +609,9 @@ iree_status_t iree_hal_hip_device_create( iree_hal_hip_event_pool_t* device_event_pool = NULL; if (iree_status_is_ok(status)) { - status = - iree_hal_hip_event_pool_allocate(symbols, params->event_pool_capacity, - host_allocator, &device_event_pool); + status = iree_hal_hip_event_pool_allocate( + symbols, context, params->event_pool_capacity, host_allocator, + &device_event_pool); } iree_hal_hip_timepoint_pool_t* timepoint_pool = NULL; @@ -731,6 +732,8 @@ static void iree_hal_hip_replace_channel_provider( static iree_status_t iree_hal_hip_device_trim(iree_hal_device_t* base_device) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); + IREE_RETURN_IF_ERROR( + iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); iree_arena_block_pool_trim(&device->block_pool); IREE_RETURN_IF_ERROR(iree_hal_allocator_trim(device->device_allocator)); if (device->supports_memory_pools) { @@ -743,6 +746,8 @@ static iree_status_t iree_hal_hip_device_trim(iree_hal_device_t* base_device) { static iree_status_t iree_hal_hip_device_query_attribute( iree_hal_hip_device_t* device, hipDeviceAttribute_t attribute, int64_t* out_value) { + IREE_RETURN_IF_ERROR( + iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); int value = 0; IREE_HIP_RETURN_IF_ERROR( device->hip_symbols, @@ -756,6 +761,8 @@ static iree_status_t iree_hal_hip_device_query_i64( iree_hal_device_t* base_device, iree_string_view_t category, iree_string_view_t key, int64_t* out_value) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); + IREE_RETURN_IF_ERROR( + iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); *out_value = 0; if (iree_string_view_equal(category, IREE_SV("hal.device.id"))) { @@ -779,6 +786,9 @@ static iree_status_t iree_hal_hip_device_create_channel( iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, 
iree_hal_channel_params_t params, iree_hal_channel_t** out_channel) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); + IREE_RETURN_IF_ERROR( + iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); + if (!device->nccl_symbols || !device->nccl_symbols->dylib) { return iree_make_status( IREE_STATUS_UNAVAILABLE, @@ -865,11 +875,14 @@ iree_status_t iree_hal_hip_device_create_stream_command_buffer( iree_host_size_t binding_capacity, iree_hal_command_buffer_t** out_command_buffer) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); + IREE_RETURN_IF_ERROR( + iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); + return iree_hal_hip_stream_command_buffer_create( iree_hal_device_allocator(base_device), device->hip_symbols, - device->nccl_symbols, device->tracing_context, mode, command_categories, - binding_capacity, device->hip_dispatch_stream, &device->block_pool, - device->host_allocator, out_command_buffer); + device->nccl_symbols, device->hip_context, device->tracing_context, mode, + command_categories, binding_capacity, device->hip_dispatch_stream, + &device->block_pool, device->host_allocator, out_command_buffer); } static iree_status_t iree_hal_hip_device_create_command_buffer( @@ -878,6 +891,9 @@ static iree_status_t iree_hal_hip_device_create_command_buffer( iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity, iree_hal_command_buffer_t** out_command_buffer) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); + IREE_RETURN_IF_ERROR( + iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); + if (device->params.allow_inline_execution && iree_all_bits_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION)) { @@ -887,9 +903,9 @@ static iree_status_t iree_hal_hip_device_create_command_buffer( // directly route commands to a HIP stream and let it eagerly flush. 
return iree_hal_hip_stream_command_buffer_create( iree_hal_device_allocator(base_device), device->hip_symbols, - device->nccl_symbols, device->tracing_context, mode, command_categories, - binding_capacity, device->hip_dispatch_stream, &device->block_pool, - device->host_allocator, out_command_buffer); + device->nccl_symbols, device->hip_context, device->tracing_context, + mode, command_categories, binding_capacity, device->hip_dispatch_stream, + &device->block_pool, device->host_allocator, out_command_buffer); } switch (device->params.command_buffer_mode) { case IREE_HAL_HIP_COMMAND_BUFFER_MODE_GRAPH: @@ -930,6 +946,10 @@ static iree_status_t iree_hal_hip_device_import_file( iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, iree_hal_memory_access_t access, iree_io_file_handle_t* handle, iree_hal_external_file_flags_t flags, iree_hal_file_t** out_file) { + iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); + IREE_RETURN_IF_ERROR( + iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); + if (iree_io_file_handle_type(handle) != IREE_IO_FILE_HANDLE_TYPE_HOST_ALLOCATION) { return iree_make_status( @@ -945,8 +965,10 @@ static iree_status_t iree_hal_hip_device_create_executable_cache( iree_hal_device_t* base_device, iree_string_view_t identifier, iree_loop_t loop, iree_hal_executable_cache_t** out_executable_cache) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); + IREE_RETURN_IF_ERROR( + iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); return iree_hal_hip_nop_executable_cache_create( - identifier, device->hip_symbols, device->hip_device, + identifier, device->hip_symbols, device->hip_device, device->hip_context, device->host_allocator, out_executable_cache); } @@ -954,9 +976,13 @@ static iree_status_t iree_hal_hip_device_create_semaphore( iree_hal_device_t* base_device, uint64_t initial_value, iree_hal_semaphore_flags_t flags, iree_hal_semaphore_t** out_semaphore) { 
iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); + IREE_RETURN_IF_ERROR( + iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); + return iree_hal_hip_event_semaphore_create( - initial_value, device->hip_symbols, device->timepoint_pool, - device->work_queue, device->host_allocator, out_semaphore); + initial_value, device->hip_symbols, device->hip_context, + device->timepoint_pool, device->work_queue, device->host_allocator, + out_semaphore); } static iree_hal_semaphore_compatibility_t @@ -1006,6 +1032,8 @@ static iree_status_t iree_hal_hip_device_queue_alloca( iree_device_size_t allocation_size, iree_hal_buffer_t** IREE_RESTRICT out_buffer) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); + IREE_RETURN_IF_ERROR( + iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); if (device->supports_memory_pools && !iree_all_bits_set(params.type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) { @@ -1083,6 +1111,9 @@ static iree_status_t iree_hal_hip_device_queue_dealloca( const iree_hal_semaphore_list_t signal_semaphore_list, iree_hal_buffer_t* buffer) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); + IREE_RETURN_IF_ERROR( + iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); + if (iree_hal_hip_allocator_isa(iree_hal_device_allocator(base_device))) { iree_status_t status = iree_hal_deferred_work_queue_enqueue_dealloc( device->work_queue, wait_semaphore_list, signal_semaphore_list, buffer); @@ -1169,7 +1200,10 @@ static iree_status_t iree_hal_hip_device_queue_execute( iree_hal_command_buffer_t* command_buffer, iree_hal_buffer_binding_table_t binding_table) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); + IREE_TRACE_ZONE_BEGIN(z0); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); iree_status_t status = iree_hal_deferred_work_queue_enqueue( device->work_queue, 
iree_hal_hip_device_collect_tracing_context, @@ -1199,6 +1233,9 @@ static iree_status_t iree_hal_hip_device_wait_semaphores( iree_hal_device_t* base_device, iree_hal_wait_mode_t wait_mode, const iree_hal_semaphore_list_t semaphore_list, iree_timeout_t timeout) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); + IREE_RETURN_IF_ERROR( + iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); + return iree_hal_hip_semaphore_multi_wait(semaphore_list, wait_mode, timeout, &device->block_pool); } diff --git a/runtime/src/iree/hal/drivers/hip/memory_pools.c b/runtime/src/iree/hal/drivers/hip/memory_pools.c index 0258fa03198f4..93a046b1f59ec 100644 --- a/runtime/src/iree/hal/drivers/hip/memory_pools.c +++ b/runtime/src/iree/hal/drivers/hip/memory_pools.c @@ -6,6 +6,7 @@ #include "iree/hal/drivers/hip/memory_pools.h" +#include "iree/hal/drivers/hip/context_util.h" #include "iree/hal/drivers/hip/dynamic_symbols.h" #include "iree/hal/drivers/hip/hip_buffer.h" #include "iree/hal/drivers/hip/status_util.h" @@ -59,6 +60,7 @@ static iree_status_t iree_hal_hip_create_memory_pool( iree_status_t iree_hal_hip_memory_pools_initialize( const iree_hal_hip_dynamic_symbols_t* hip_symbols, hipDevice_t hip_device, + hipCtx_t hip_context, const iree_hal_hip_memory_pooling_params_t* pooling_params, iree_allocator_t host_allocator, iree_hal_hip_memory_pools_t* IREE_RESTRICT out_pools) { @@ -66,10 +68,13 @@ iree_status_t iree_hal_hip_memory_pools_initialize( IREE_ASSERT_ARGUMENT(pooling_params); IREE_ASSERT_ARGUMENT(out_pools); IREE_TRACE_ZONE_BEGIN(z0); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_hip_set_context(hip_symbols, hip_context)); memset(out_pools, 0, sizeof(*out_pools)); out_pools->hip_symbols = hip_symbols; out_pools->host_allocator = host_allocator; + out_pools->hip_context = hip_context; iree_status_t status = iree_ok_status(); @@ -91,6 +96,8 @@ iree_status_t iree_hal_hip_memory_pools_initialize( void 
iree_hal_hip_memory_pools_deinitialize( iree_hal_hip_memory_pools_t* pools) { IREE_TRACE_ZONE_BEGIN(z0); + IREE_IGNORE_ERROR( + iree_hal_hip_set_context(pools->hip_symbols, pools->hip_context)); if (pools->device_local) { IREE_HIP_IGNORE_ERROR(pools->hip_symbols, @@ -149,6 +156,9 @@ static void iree_hal_hip_memory_pool_track_free( void iree_hal_hip_memory_pools_merge_statistics( iree_hal_hip_memory_pools_t* pools, iree_hal_allocator_statistics_t* statistics) { + IREE_IGNORE_ERROR( + iree_hal_hip_set_context(pools->hip_symbols, pools->hip_context)); + IREE_STATISTICS({ statistics->device_bytes_allocated = iree_atomic_load( &pools->statistics.device_bytes_allocated, iree_memory_order_relaxed); @@ -181,6 +191,9 @@ void iree_hal_hip_memory_pools_merge_statistics( iree_status_t iree_hal_hip_memory_pools_trim( iree_hal_hip_memory_pools_t* pools, const iree_hal_hip_memory_pooling_params_t* pooling_params) { + IREE_RETURN_IF_ERROR( + iree_hal_hip_set_context(pools->hip_symbols, pools->hip_context)); + IREE_HIP_RETURN_IF_ERROR( pools->hip_symbols, hipMemPoolTrimTo(pools->device_local, @@ -200,6 +213,8 @@ static void iree_hal_hip_async_buffer_release_callback( void* user_data, iree_hal_buffer_t* buffer) { iree_hal_hip_memory_pools_t* pools = (iree_hal_hip_memory_pools_t*)user_data; IREE_TRACE_ZONE_BEGIN(z0); + IREE_IGNORE_ERROR( + iree_hal_hip_set_context(pools->hip_symbols, pools->hip_context)); hipDeviceptr_t device_ptr = iree_hal_hip_buffer_device_pointer(buffer); if (device_ptr) { @@ -213,6 +228,9 @@ static void iree_hal_hip_async_buffer_release_callback( iree_status_t iree_hal_hip_memory_pools_allocate_pointer( iree_hal_hip_memory_pools_t* pools, iree_hal_buffer_t* buffer, hipStream_t stream, iree_device_size_t allocation_size) { + IREE_RETURN_IF_ERROR( + iree_hal_hip_set_context(pools->hip_symbols, pools->hip_context)); + // TODO: more pools and better selection; this is coarsely deciding between // only device local (variables, constants, transients) and other 
(staging, // external) but could use more buffer properties (including usage/export @@ -277,6 +295,8 @@ iree_status_t iree_hal_hip_memory_pools_deallocate( iree_hal_hip_memory_pools_t* pools, hipStream_t stream, iree_hal_buffer_t* buffer) { IREE_TRACE_ZONE_BEGIN(z0); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_hip_set_context(pools->hip_symbols, pools->hip_context)); IREE_TRACE_ZONE_APPEND_VALUE_I64( z0, (int64_t)iree_hal_buffer_allocation_size(buffer)); diff --git a/runtime/src/iree/hal/drivers/hip/memory_pools.h b/runtime/src/iree/hal/drivers/hip/memory_pools.h index c505bbe34e98b..f95b76d9816b1 100644 --- a/runtime/src/iree/hal/drivers/hip/memory_pools.h +++ b/runtime/src/iree/hal/drivers/hip/memory_pools.h @@ -33,6 +33,7 @@ typedef struct iree_hal_hip_memory_pools_t { hipMemPool_t other; const iree_hal_hip_dynamic_symbols_t* hip_symbols; + hipCtx_t hip_context; iree_allocator_t host_allocator; IREE_STATISTICS(struct { @@ -46,6 +47,7 @@ typedef struct iree_hal_hip_memory_pools_t { // Initializes |out_pools| by configuring new HIP memory pools. 
iree_status_t iree_hal_hip_memory_pools_initialize( const iree_hal_hip_dynamic_symbols_t* hip_symbols, hipDevice_t hip_device, + hipCtx_t hip_context, const iree_hal_hip_memory_pooling_params_t* pooling_params, iree_allocator_t host_allocator, iree_hal_hip_memory_pools_t* IREE_RESTRICT out_pools); diff --git a/runtime/src/iree/hal/drivers/hip/native_executable.c b/runtime/src/iree/hal/drivers/hip/native_executable.c index 18b3d378ed9a3..e37aba578f02c 100644 --- a/runtime/src/iree/hal/drivers/hip/native_executable.c +++ b/runtime/src/iree/hal/drivers/hip/native_executable.c @@ -9,6 +9,7 @@ #include #include "iree/base/api.h" +#include "iree/hal/drivers/hip/context_util.h" #include "iree/hal/drivers/hip/dynamic_symbols.h" #include "iree/hal/drivers/hip/status_util.h" #include "iree/hal/utils/executable_debug_info.h" @@ -207,12 +208,14 @@ static iree_status_t iree_hal_hip_native_executable_flatbuffer_verify( iree_status_t iree_hal_hip_native_executable_create( const iree_hal_hip_dynamic_symbols_t* symbols, hipDevice_t device, - const iree_hal_executable_params_t* executable_params, + hipCtx_t context, const iree_hal_executable_params_t* executable_params, iree_allocator_t host_allocator, iree_hal_executable_t** out_executable) { IREE_ASSERT_ARGUMENT(symbols); IREE_ASSERT_ARGUMENT(executable_params); IREE_ASSERT_ARGUMENT(out_executable); IREE_TRACE_ZONE_BEGIN(z0); + IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, + iree_hal_hip_set_context(symbols, context)); *out_executable = NULL; diff --git a/runtime/src/iree/hal/drivers/hip/native_executable.h b/runtime/src/iree/hal/drivers/hip/native_executable.h index beb1e7cf92f5a..b67f5e73e5991 100644 --- a/runtime/src/iree/hal/drivers/hip/native_executable.h +++ b/runtime/src/iree/hal/drivers/hip/native_executable.h @@ -49,7 +49,7 @@ typedef struct iree_hal_hip_kernel_params_t { // several kernels that can be extracted along with the associated block size. 
iree_status_t iree_hal_hip_native_executable_create( const iree_hal_hip_dynamic_symbols_t* symbols, hipDevice_t device, - const iree_hal_executable_params_t* executable_params, + hipCtx_t context, const iree_hal_executable_params_t* executable_params, iree_allocator_t host_allocator, iree_hal_executable_t** out_executable); // Returns the kernel launch parameters for the given |entry_point| in the diff --git a/runtime/src/iree/hal/drivers/hip/nop_executable_cache.c b/runtime/src/iree/hal/drivers/hip/nop_executable_cache.c index c85d06a4a3d70..9680e3bf9f90e 100644 --- a/runtime/src/iree/hal/drivers/hip/nop_executable_cache.c +++ b/runtime/src/iree/hal/drivers/hip/nop_executable_cache.c @@ -23,6 +23,7 @@ typedef struct iree_hal_hip_nop_executable_cache_t { const iree_hal_hip_dynamic_symbols_t* symbols; hipDevice_t device; + hipCtx_t hip_context; } iree_hal_hip_nop_executable_cache_t; static const iree_hal_executable_cache_vtable_t @@ -38,7 +39,7 @@ iree_hal_hip_nop_executable_cache_cast( iree_status_t iree_hal_hip_nop_executable_cache_create( iree_string_view_t identifier, const iree_hal_hip_dynamic_symbols_t* symbols, hipDevice_t device, - iree_allocator_t host_allocator, + hipCtx_t hip_context, iree_allocator_t host_allocator, iree_hal_executable_cache_t** out_executable_cache) { IREE_ASSERT_ARGUMENT(symbols); IREE_ASSERT_ARGUMENT(out_executable_cache); @@ -55,6 +56,7 @@ iree_status_t iree_hal_hip_nop_executable_cache_create( executable_cache->host_allocator = host_allocator; executable_cache->symbols = symbols; executable_cache->device = device; + executable_cache->hip_context = hip_context; *out_executable_cache = (iree_hal_executable_cache_t*)executable_cache; @@ -89,7 +91,8 @@ static iree_status_t iree_hal_hip_nop_executable_cache_prepare_executable( iree_hal_hip_nop_executable_cache_t* executable_cache = iree_hal_hip_nop_executable_cache_cast(base_executable_cache); return iree_hal_hip_native_executable_create( - executable_cache->symbols, 
executable_cache->device, executable_params, + executable_cache->symbols, executable_cache->device, + executable_cache->hip_context, executable_params, executable_cache->host_allocator, out_executable); } diff --git a/runtime/src/iree/hal/drivers/hip/nop_executable_cache.h b/runtime/src/iree/hal/drivers/hip/nop_executable_cache.h index 8726f1ebfbe5b..795aa21b53c92 100644 --- a/runtime/src/iree/hal/drivers/hip/nop_executable_cache.h +++ b/runtime/src/iree/hal/drivers/hip/nop_executable_cache.h @@ -22,7 +22,7 @@ extern "C" { iree_status_t iree_hal_hip_nop_executable_cache_create( iree_string_view_t identifier, const iree_hal_hip_dynamic_symbols_t* symbols, hipDevice_t device, - iree_allocator_t host_allocator, + hipCtx_t hip_context, iree_allocator_t host_allocator, iree_hal_executable_cache_t** out_executable_cache); #ifdef __cplusplus diff --git a/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c b/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c index 0c7eb86e8448f..6a201ca2976cb 100644 --- a/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c +++ b/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c @@ -7,6 +7,7 @@ #include "iree/hal/drivers/hip/stream_command_buffer.h" +#include "iree/hal/drivers/hip/context_util.h" #include "iree/hal/drivers/hip/hip_buffer.h" #include "iree/hal/drivers/hip/native_executable.h" #include "iree/hal/drivers/hip/rccl_channel.h" @@ -27,6 +28,7 @@ typedef struct iree_hal_hip_stream_command_buffer_t { iree_hal_stream_tracing_context_event_list_t tracing_event_list; hipStream_t hip_stream; + hipCtx_t hip_context; // A resource set to maintain references to all resources used within the // command buffer. Reset on each begin. 
@@ -54,7 +56,7 @@ iree_status_t iree_hal_hip_stream_command_buffer_create( iree_hal_allocator_t* device_allocator, const iree_hal_hip_dynamic_symbols_t* hip_symbols, const iree_hal_hip_nccl_dynamic_symbols_t* nccl_symbols, - iree_hal_stream_tracing_context_t* tracing_context, + hipCtx_t hip_context, iree_hal_stream_tracing_context_t* tracing_context, iree_hal_command_buffer_mode_t mode, iree_hal_command_category_t command_categories, iree_host_size_t binding_capacity, hipStream_t stream, @@ -73,6 +75,8 @@ iree_status_t iree_hal_hip_stream_command_buffer_create( } IREE_TRACE_ZONE_BEGIN(z0); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_hip_set_context(hip_symbols, hip_context)); iree_hal_hip_stream_command_buffer_t* command_buffer = NULL; IREE_RETURN_AND_END_ZONE_IF_ERROR( @@ -94,6 +98,7 @@ iree_status_t iree_hal_hip_stream_command_buffer_create( command_buffer->tracing_event_list.head = NULL; command_buffer->tracing_event_list.tail = NULL; command_buffer->hip_stream = stream; + command_buffer->hip_context = hip_context; iree_arena_initialize(block_pool, &command_buffer->arena); iree_status_t status = @@ -116,6 +121,8 @@ static void iree_hal_hip_stream_command_buffer_destroy( iree_hal_hip_stream_command_buffer_cast(base_command_buffer); iree_allocator_t host_allocator = command_buffer->host_allocator; IREE_TRACE_ZONE_BEGIN(z0); + IREE_IGNORE_ERROR(iree_hal_hip_set_context(command_buffer->hip_symbols, + command_buffer->hip_context)); iree_hal_stream_tracing_free(command_buffer->tracing_context, &command_buffer->tracing_event_list); @@ -172,7 +179,8 @@ static iree_status_t iree_hal_hip_stream_command_buffer_begin( iree_hal_command_buffer_t* base_command_buffer) { iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); - (void)command_buffer; + IREE_RETURN_IF_ERROR(iree_hal_hip_set_context(command_buffer->hip_symbols, + command_buffer->hip_context)); IREE_HAL_STREAM_TRACE_ZONE_BEGIN_EXTERNAL( 
command_buffer->tracing_context, &command_buffer->tracing_event_list, @@ -188,6 +196,9 @@ static iree_status_t iree_hal_hip_stream_command_buffer_end( iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); IREE_TRACE_ZONE_BEGIN(z0); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_hip_set_context(command_buffer->hip_symbols, + command_buffer->hip_context)); IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_hip_stream_command_buffer_flush_collectives(command_buffer)); @@ -219,7 +230,8 @@ static iree_status_t iree_hal_hip_stream_command_buffer_begin_debug_group( const iree_hal_label_location_t* location) { iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); - (void)command_buffer; + IREE_RETURN_IF_ERROR(iree_hal_hip_set_context(command_buffer->hip_symbols, + command_buffer->hip_context)); IREE_HAL_STREAM_TRACE_ZONE_BEGIN_EXTERNAL( command_buffer->tracing_context, &command_buffer->tracing_event_list, @@ -235,7 +247,8 @@ static iree_status_t iree_hal_hip_stream_command_buffer_end_debug_group( iree_hal_command_buffer_t* base_command_buffer) { iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); - (void)command_buffer; + IREE_RETURN_IF_ERROR(iree_hal_hip_set_context(command_buffer->hip_symbols, + command_buffer->hip_context)); IREE_HAL_STREAM_TRACE_ZONE_END(command_buffer->tracing_context, &command_buffer->tracing_event_list, @@ -255,6 +268,8 @@ static iree_status_t iree_hal_hip_stream_command_buffer_execution_barrier( const iree_hal_buffer_barrier_t* buffer_barriers) { iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); + IREE_RETURN_IF_ERROR(iree_hal_hip_set_context(command_buffer->hip_symbols, + command_buffer->hip_context)); if (iree_any_bit_set(source_stage_mask, IREE_HAL_EXECUTION_STAGE_HOST) || 
iree_any_bit_set(target_stage_mask, IREE_HAL_EXECUTION_STAGE_HOST)) { @@ -319,6 +334,9 @@ static iree_status_t iree_hal_hip_stream_command_buffer_fill_buffer( iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); IREE_TRACE_ZONE_BEGIN(z0); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_hip_set_context(command_buffer->hip_symbols, + command_buffer->hip_context)); IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_hip_stream_command_buffer_flush_collectives(command_buffer)); @@ -372,6 +390,9 @@ static iree_status_t iree_hal_hip_stream_command_buffer_update_buffer( iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); IREE_TRACE_ZONE_BEGIN(z0); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_hip_set_context(command_buffer->hip_symbols, + command_buffer->hip_context)); IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_hip_stream_command_buffer_flush_collectives(command_buffer)); @@ -415,6 +436,9 @@ static iree_status_t iree_hal_hip_stream_command_buffer_copy_buffer( iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); IREE_TRACE_ZONE_BEGIN(z0); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_hip_set_context(command_buffer->hip_symbols, + command_buffer->hip_context)); IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_hip_stream_command_buffer_flush_collectives(command_buffer)); @@ -447,6 +471,9 @@ static iree_status_t iree_hal_hip_stream_command_buffer_collective( iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); IREE_TRACE_ZONE_BEGIN(z0); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_hip_set_context(command_buffer->hip_symbols, + command_buffer->hip_context)); iree_hal_buffer_binding_t send_binding = { .buffer = send_ref.buffer, @@ -474,6 +501,9 @@ static iree_status_t 
iree_hal_hip_stream_command_buffer_dispatch( iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); IREE_TRACE_ZONE_BEGIN(z0); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_hip_set_context(command_buffer->hip_symbols, + command_buffer->hip_context)); // If any of the workgroup counts are zero, we can skip execution // of the kernel. This prevents a 'hipErrorInvalidConfiguration' error when diff --git a/runtime/src/iree/hal/drivers/hip/stream_command_buffer.h b/runtime/src/iree/hal/drivers/hip/stream_command_buffer.h index cc88c3a4b6b76..43820866d8cd0 100644 --- a/runtime/src/iree/hal/drivers/hip/stream_command_buffer.h +++ b/runtime/src/iree/hal/drivers/hip/stream_command_buffer.h @@ -33,7 +33,7 @@ iree_status_t iree_hal_hip_stream_command_buffer_create( iree_hal_allocator_t* device_allocator, const iree_hal_hip_dynamic_symbols_t* hip_symbols, const iree_hal_hip_nccl_dynamic_symbols_t* nccl_symbols, - iree_hal_stream_tracing_context_t* tracing_context, + hipCtx_t hip_context, iree_hal_stream_tracing_context_t* tracing_context, iree_hal_command_buffer_mode_t mode, iree_hal_command_category_t command_categories, iree_host_size_t binding_capacity, hipStream_t stream,