From cbe4a08b6be892fe142407e188eed82c55952cd2 Mon Sep 17 00:00:00 2001 From: Andrew Woloszyn Date: Mon, 11 Nov 2024 15:08:57 -0500 Subject: [PATCH 1/5] [hip] Set the current device before calls into Hip. This is a bit of a brute-force way to solve our main hip multi-device problems temporarily until the more complete fix is in place. Signed-off-by: Andrew Woloszyn --- runtime/src/iree/hal/drivers/hip/event_pool.c | 26 ++++++-- runtime/src/iree/hal/drivers/hip/event_pool.h | 2 +- .../iree/hal/drivers/hip/event_semaphore.c | 10 ++- .../iree/hal/drivers/hip/event_semaphore.h | 2 +- .../src/iree/hal/drivers/hip/hip_allocator.c | 25 +++++++- .../src/iree/hal/drivers/hip/hip_allocator.h | 7 +- runtime/src/iree/hal/drivers/hip/hip_device.c | 64 +++++++++++++++---- .../src/iree/hal/drivers/hip/memory_pools.c | 16 +++++ .../src/iree/hal/drivers/hip/memory_pools.h | 2 + .../iree/hal/drivers/hip/native_executable.c | 4 +- .../iree/hal/drivers/hip/native_executable.h | 2 +- .../hal/drivers/hip/nop_executable_cache.c | 7 +- .../hal/drivers/hip/nop_executable_cache.h | 2 +- .../src/iree/hal/drivers/hip/status_util.h | 3 + .../hal/drivers/hip/stream_command_buffer.c | 36 +++++++++-- .../hal/drivers/hip/stream_command_buffer.h | 2 +- 16 files changed, 171 insertions(+), 39 deletions(-) diff --git a/runtime/src/iree/hal/drivers/hip/event_pool.c b/runtime/src/iree/hal/drivers/hip/event_pool.c index 38c93113b839..b71a10de241d 100644 --- a/runtime/src/iree/hal/drivers/hip/event_pool.c +++ b/runtime/src/iree/hal/drivers/hip/event_pool.c @@ -36,6 +36,10 @@ struct iree_hal_hip_event_t { // The event pool that owns this event. This cannot be NULL. We retain it to // make sure the event outlive the pool. iree_hal_hip_event_pool_t* pool; + + // The context to use to free this event, it must be the same + // context as was used when allocating the event. + hipCtx_t hip_context; // The underlying hipEvent_t object. 
hipEvent_t hip_event; }; @@ -45,6 +49,8 @@ hipEvent_t iree_hal_hip_event_handle(const iree_hal_hip_event_t* event) { } static inline void iree_hal_hip_event_destroy(iree_hal_hip_event_t* event) { + IREE_IGNORE_ERROR(HIP_SET_CONTEXT(event->symbols, event->hip_context)); + iree_allocator_t host_allocator = event->host_allocator; const iree_hal_hip_dynamic_symbols_t* symbols = event->symbols; IREE_TRACE_ZONE_BEGIN(z0); @@ -58,8 +64,8 @@ static inline void iree_hal_hip_event_destroy(iree_hal_hip_event_t* event) { static inline iree_status_t iree_hal_hip_event_create( const iree_hal_hip_dynamic_symbols_t* symbols, - iree_hal_hip_event_pool_t* pool, iree_allocator_t host_allocator, - iree_hal_hip_event_t** out_event) { + iree_hal_hip_event_pool_t* pool, hipCtx_t context, + iree_allocator_t host_allocator, iree_hal_hip_event_t** out_event) { IREE_ASSERT_ARGUMENT(symbols); IREE_ASSERT_ARGUMENT(pool); IREE_ASSERT_ARGUMENT(out_event); @@ -75,6 +81,7 @@ static inline iree_status_t iree_hal_hip_event_create( event->symbols = symbols; event->pool = pool; event->hip_event = NULL; + event->hip_context = context; iree_status_t status = IREE_HIP_RESULT_TO_STATUS( symbols, @@ -122,6 +129,10 @@ struct iree_hal_hip_event_pool_t { // The symbols used to create and destroy hipEvent_t objects. const iree_hal_hip_dynamic_symbols_t* symbols; + // The context for this event pool to use to allocate + // events. + hipCtx_t hip_context; + // Guards event related fields in the pool. We don't expect a performant // program to frequently allocate events for synchronization purposes; the // traffic to this pool should be low. 
So it should be fine to use mutex to @@ -142,7 +153,7 @@ struct iree_hal_hip_event_pool_t { static void iree_hal_hip_event_pool_free(iree_hal_hip_event_pool_t* event_pool); iree_status_t iree_hal_hip_event_pool_allocate( - const iree_hal_hip_dynamic_symbols_t* symbols, + const iree_hal_hip_dynamic_symbols_t* symbols, hipCtx_t hip_context, iree_host_size_t available_capacity, iree_allocator_t host_allocator, iree_hal_hip_event_pool_t** out_event_pool) { IREE_ASSERT_ARGUMENT(symbols); @@ -163,11 +174,12 @@ iree_status_t iree_hal_hip_event_pool_allocate( iree_slim_mutex_initialize(&event_pool->event_mutex); event_pool->available_capacity = available_capacity; event_pool->available_count = 0; + event_pool->hip_context = hip_context; iree_status_t status = iree_ok_status(); for (iree_host_size_t i = 0; i < available_capacity; ++i) { status = iree_hal_hip_event_create( - symbols, event_pool, host_allocator, + symbols, event_pool, hip_context, host_allocator, &event_pool->available_list[event_pool->available_count++]); if (!iree_status_is_ok(status)) break; } @@ -240,9 +252,9 @@ iree_status_t iree_hal_hip_event_pool_acquire( IREE_TRACE_ZONE_BEGIN_NAMED(z1, "event-pool-unpooled-acquire"); iree_status_t status = iree_ok_status(); for (iree_host_size_t i = 0; i < remaining_count; ++i) { - status = iree_hal_hip_event_create(event_pool->symbols, event_pool, - event_pool->host_allocator, - &out_events[from_pool_count + i]); + status = iree_hal_hip_event_create( + event_pool->symbols, event_pool, event_pool->hip_context, + event_pool->host_allocator, &out_events[from_pool_count + i]); if (!iree_status_is_ok(status)) { // Must release all events we've acquired so far. 
iree_hal_hip_event_pool_release_event(event_pool, from_pool_count + i, diff --git a/runtime/src/iree/hal/drivers/hip/event_pool.h b/runtime/src/iree/hal/drivers/hip/event_pool.h index ea09b90cc7cf..0683714d9724 100644 --- a/runtime/src/iree/hal/drivers/hip/event_pool.h +++ b/runtime/src/iree/hal/drivers/hip/event_pool.h @@ -52,7 +52,7 @@ typedef struct iree_hal_hip_event_pool_t iree_hal_hip_event_pool_t; // Extra events requested beyond the capability are directly created and // destroyed without pooling. iree_status_t iree_hal_hip_event_pool_allocate( - const iree_hal_hip_dynamic_symbols_t* symbols, + const iree_hal_hip_dynamic_symbols_t* symbols, hipCtx_t hip_context, iree_host_size_t available_capacity, iree_allocator_t host_allocator, iree_hal_hip_event_pool_t** out_event_pool); diff --git a/runtime/src/iree/hal/drivers/hip/event_semaphore.c b/runtime/src/iree/hal/drivers/hip/event_semaphore.c index de10b09125ec..71765927d324 100644 --- a/runtime/src/iree/hal/drivers/hip/event_semaphore.c +++ b/runtime/src/iree/hal/drivers/hip/event_semaphore.c @@ -10,6 +10,7 @@ #include "iree/base/internal/wait_handle.h" #include "iree/base/status.h" #include "iree/hal/drivers/hip/dynamic_symbols.h" +#include "iree/hal/drivers/hip/status_util.h" #include "iree/hal/drivers/hip/timepoint_pool.h" #include "iree/hal/utils/semaphore_base.h" @@ -30,6 +31,8 @@ typedef struct iree_hal_hip_semaphore_t { // new signaled values. iree_hal_deferred_work_queue_t* work_queue; + hipCtx_t hip_context; + // Guards value and status. We expect low contention on semaphores and since // iree_slim_mutex_t is (effectively) just a CAS this keeps things simpler // than trying to make the entire structure lock-free. 
@@ -56,9 +59,11 @@ static iree_hal_hip_semaphore_t* iree_hal_hip_semaphore_cast( iree_status_t iree_hal_hip_event_semaphore_create( uint64_t initial_value, const iree_hal_hip_dynamic_symbols_t* symbols, - iree_hal_hip_timepoint_pool_t* timepoint_pool, + hipCtx_t hip_context, iree_hal_hip_timepoint_pool_t* timepoint_pool, iree_hal_deferred_work_queue_t* work_queue, iree_allocator_t host_allocator, iree_hal_semaphore_t** out_semaphore) { + IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(symbols, hip_context)); + IREE_ASSERT_ARGUMENT(symbols); IREE_ASSERT_ARGUMENT(timepoint_pool); IREE_ASSERT_ARGUMENT(work_queue); @@ -79,6 +84,7 @@ iree_status_t iree_hal_hip_event_semaphore_create( iree_slim_mutex_initialize(&semaphore->mutex); semaphore->current_value = initial_value; semaphore->failure_status = iree_ok_status(); + semaphore->hip_context = hip_context; *out_semaphore = &semaphore->base; @@ -90,6 +96,8 @@ static void iree_hal_hip_semaphore_destroy( iree_hal_semaphore_t* base_semaphore) { iree_hal_hip_semaphore_t* semaphore = iree_hal_hip_semaphore_cast(base_semaphore); + IREE_IGNORE_ERROR( + HIP_SET_CONTEXT(semaphore->symbols, semaphore->hip_context)); iree_allocator_t host_allocator = semaphore->host_allocator; IREE_TRACE_ZONE_BEGIN(z0); diff --git a/runtime/src/iree/hal/drivers/hip/event_semaphore.h b/runtime/src/iree/hal/drivers/hip/event_semaphore.h index 88a75e01c436..a5d8ff95b369 100644 --- a/runtime/src/iree/hal/drivers/hip/event_semaphore.h +++ b/runtime/src/iree/hal/drivers/hip/event_semaphore.h @@ -31,7 +31,7 @@ extern "C" { // Thread-safe; multiple threads may signal/wait values on the same semaphore. 
iree_status_t iree_hal_hip_event_semaphore_create( uint64_t initial_value, const iree_hal_hip_dynamic_symbols_t* symbols, - iree_hal_hip_timepoint_pool_t* timepoint_pool, + hipCtx_t hip_context, iree_hal_hip_timepoint_pool_t* timepoint_pool, iree_hal_deferred_work_queue_t* work_queue, iree_allocator_t host_allocator, iree_hal_semaphore_t** out_semaphore); diff --git a/runtime/src/iree/hal/drivers/hip/hip_allocator.c b/runtime/src/iree/hal/drivers/hip/hip_allocator.c index 041e011536f8..add9ca080146 100644 --- a/runtime/src/iree/hal/drivers/hip/hip_allocator.c +++ b/runtime/src/iree/hal/drivers/hip/hip_allocator.c @@ -29,6 +29,8 @@ typedef struct iree_hal_hip_allocator_t { // The HIP stream that allocations should be used in. hipStream_t stream; + hipCtx_t hip_context; + // NOTE: optional depending on device support. iree_hal_hip_memory_pools_t* pools; @@ -54,8 +56,11 @@ static iree_hal_hip_allocator_t* iree_hal_hip_allocator_cast( iree_status_t iree_hal_hip_allocator_create( const iree_hal_hip_dynamic_symbols_t* hip_symbols, hipDevice_t device, - hipStream_t stream, iree_hal_hip_memory_pools_t* pools, - iree_allocator_t host_allocator, iree_hal_allocator_t** out_allocator) { + hipCtx_t hip_context, hipStream_t stream, + iree_hal_hip_memory_pools_t* pools, iree_allocator_t host_allocator, + iree_hal_allocator_t** out_allocator) { + IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(hip_symbols, hip_context)); + IREE_ASSERT_ARGUMENT(hip_symbols); IREE_ASSERT_ARGUMENT(out_allocator); IREE_TRACE_ZONE_BEGIN(z0); @@ -94,6 +99,7 @@ iree_status_t iree_hal_hip_allocator_create( allocator->host_allocator = host_allocator; allocator->supports_concurrent_managed_access = supports_concurrent_managed_access != 0; + allocator->hip_context = hip_context; *out_allocator = (iree_hal_allocator_t*)allocator; IREE_TRACE_ZONE_END(z0); @@ -319,6 +325,9 @@ static iree_status_t iree_hal_hip_allocator_allocate_buffer( iree_hal_hip_allocator_t* allocator = iree_hal_hip_allocator_cast(base_allocator); + 
IREE_RETURN_IF_ERROR( + HIP_SET_CONTEXT(allocator->symbols, allocator->hip_context)); + // Coerce options into those required by the current device. iree_hal_buffer_params_t compat_params = *params; iree_hal_buffer_compatibility_t compatibility = @@ -431,6 +440,9 @@ static void iree_hal_hip_allocator_deallocate_buffer( iree_hal_hip_allocator_t* allocator = iree_hal_hip_allocator_cast(base_allocator); + IREE_IGNORE_ERROR( + HIP_SET_CONTEXT(allocator->symbols, allocator->hip_context)); + const iree_hal_hip_buffer_type_t buffer_type = iree_hal_hip_buffer_type(base_buffer); @@ -466,6 +478,9 @@ static iree_status_t iree_hal_hip_allocator_import_buffer( iree_hal_hip_allocator_t* allocator = iree_hal_hip_allocator_cast(base_allocator); + IREE_RETURN_IF_ERROR( + HIP_SET_CONTEXT(allocator->symbols, allocator->hip_context)); + // Coerce options into those required by the current device. iree_hal_buffer_params_t compat_params = *params; iree_device_size_t allocation_size = external_buffer->size; @@ -600,6 +615,9 @@ iree_status_t iree_hal_hip_allocator_alloc_async( iree_hal_hip_allocator_t* allocator = iree_hal_hip_allocator_cast(base_allocator); + IREE_RETURN_IF_ERROR( + HIP_SET_CONTEXT(allocator->symbols, allocator->hip_context)); + hipDeviceptr_t ptr = NULL; iree_status_t status = IREE_HIP_RESULT_TO_STATUS( allocator->symbols, @@ -625,6 +643,9 @@ iree_status_t iree_hal_hip_allocator_free_async( iree_hal_buffer_t* buffer) { iree_hal_hip_allocator_t* allocator = iree_hal_hip_allocator_cast(base_allocator); + IREE_RETURN_IF_ERROR( + HIP_SET_CONTEXT(allocator->symbols, allocator->hip_context)); + hipDeviceptr_t device_ptr = iree_hal_hip_buffer_device_pointer(buffer); if (!device_ptr) { return iree_ok_status(); diff --git a/runtime/src/iree/hal/drivers/hip/hip_allocator.h b/runtime/src/iree/hal/drivers/hip/hip_allocator.h index bb83ea18e5ad..89ed0936be0c 100644 --- a/runtime/src/iree/hal/drivers/hip/hip_allocator.h +++ b/runtime/src/iree/hal/drivers/hip/hip_allocator.h @@ -17,13 
+17,14 @@ extern "C" { #endif // __cplusplus // Creates a HIP memory allocator. -// |device| and |stream| will be used for management operations. +// |device| |hip_context| and |stream| will be used for management operations. // |pools| provides memory pools that may be shared across multiple allocators // and the pointer must remain valid for the lifetime of the allocator. iree_status_t iree_hal_hip_allocator_create( const iree_hal_hip_dynamic_symbols_t* hip_symbols, hipDevice_t device, - hipStream_t stream, iree_hal_hip_memory_pools_t* pools, - iree_allocator_t host_allocator, iree_hal_allocator_t** out_allocator); + hipCtx_t hip_context, hipStream_t stream, + iree_hal_hip_memory_pools_t* pools, iree_allocator_t host_allocator, + iree_hal_allocator_t** out_allocator); bool iree_hal_hip_allocator_isa(iree_hal_allocator_t* base_value); diff --git a/runtime/src/iree/hal/drivers/hip/hip_device.c b/runtime/src/iree/hal/drivers/hip/hip_device.c index e48b5241a0b6..9ee9ab5c2940 100644 --- a/runtime/src/iree/hal/drivers/hip/hip_device.c +++ b/runtime/src/iree/hal/drivers/hip/hip_device.c @@ -538,13 +538,13 @@ static iree_status_t iree_hal_hip_device_create_internal( // Create memory pools first so that we can share them with the allocator. if (iree_status_is_ok(status) && device->supports_memory_pools) { status = iree_hal_hip_memory_pools_initialize( - symbols, hip_device, &params->memory_pools, host_allocator, + symbols, hip_device, context, &params->memory_pools, host_allocator, &device->memory_pools); } if (iree_status_is_ok(status)) { status = iree_hal_hip_allocator_create( - symbols, hip_device, dispatch_stream, + symbols, hip_device, context, dispatch_stream, device->supports_memory_pools ? 
&device->memory_pools : NULL, host_allocator, &device->device_allocator); } @@ -608,9 +608,9 @@ iree_status_t iree_hal_hip_device_create( iree_hal_hip_event_pool_t* device_event_pool = NULL; if (iree_status_is_ok(status)) { - status = - iree_hal_hip_event_pool_allocate(symbols, params->event_pool_capacity, - host_allocator, &device_event_pool); + status = iree_hal_hip_event_pool_allocate( + symbols, context, params->event_pool_capacity, host_allocator, + &device_event_pool); } iree_hal_hip_timepoint_pool_t* timepoint_pool = NULL; @@ -731,6 +731,8 @@ static void iree_hal_hip_replace_channel_provider( static iree_status_t iree_hal_hip_device_trim(iree_hal_device_t* base_device) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); + IREE_RETURN_IF_ERROR( + HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); iree_arena_block_pool_trim(&device->block_pool); IREE_RETURN_IF_ERROR(iree_hal_allocator_trim(device->device_allocator)); if (device->supports_memory_pools) { @@ -743,6 +745,8 @@ static iree_status_t iree_hal_hip_device_trim(iree_hal_device_t* base_device) { static iree_status_t iree_hal_hip_device_query_attribute( iree_hal_hip_device_t* device, hipDeviceAttribute_t attribute, int64_t* out_value) { + IREE_RETURN_IF_ERROR( + HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); int value = 0; IREE_HIP_RETURN_IF_ERROR( device->hip_symbols, @@ -756,6 +760,8 @@ static iree_status_t iree_hal_hip_device_query_i64( iree_hal_device_t* base_device, iree_string_view_t category, iree_string_view_t key, int64_t* out_value) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); + IREE_RETURN_IF_ERROR( + HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); *out_value = 0; if (iree_string_view_equal(category, IREE_SV("hal.device.id"))) { @@ -779,6 +785,9 @@ static iree_status_t iree_hal_hip_device_create_channel( iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, iree_hal_channel_params_t params, 
iree_hal_channel_t** out_channel) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); + IREE_RETURN_IF_ERROR( + HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); + if (!device->nccl_symbols || !device->nccl_symbols->dylib) { return iree_make_status( IREE_STATUS_UNAVAILABLE, @@ -865,11 +874,14 @@ iree_status_t iree_hal_hip_device_create_stream_command_buffer( iree_host_size_t binding_capacity, iree_hal_command_buffer_t** out_command_buffer) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); + IREE_RETURN_IF_ERROR( + HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); + return iree_hal_hip_stream_command_buffer_create( iree_hal_device_allocator(base_device), device->hip_symbols, - device->nccl_symbols, device->tracing_context, mode, command_categories, - binding_capacity, device->hip_dispatch_stream, &device->block_pool, - device->host_allocator, out_command_buffer); + device->nccl_symbols, device->hip_context, device->tracing_context, mode, + command_categories, binding_capacity, device->hip_dispatch_stream, + &device->block_pool, device->host_allocator, out_command_buffer); } static iree_status_t iree_hal_hip_device_create_command_buffer( @@ -878,6 +890,9 @@ static iree_status_t iree_hal_hip_device_create_command_buffer( iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity, iree_hal_command_buffer_t** out_command_buffer) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); + IREE_RETURN_IF_ERROR( + HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); + if (device->params.allow_inline_execution && iree_all_bits_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION)) { @@ -887,9 +902,9 @@ static iree_status_t iree_hal_hip_device_create_command_buffer( // directly route commands to a HIP stream and let it eagerly flush. 
return iree_hal_hip_stream_command_buffer_create( iree_hal_device_allocator(base_device), device->hip_symbols, - device->nccl_symbols, device->tracing_context, mode, command_categories, - binding_capacity, device->hip_dispatch_stream, &device->block_pool, - device->host_allocator, out_command_buffer); + device->nccl_symbols, device->hip_context, device->tracing_context, + mode, command_categories, binding_capacity, device->hip_dispatch_stream, + &device->block_pool, device->host_allocator, out_command_buffer); } switch (device->params.command_buffer_mode) { case IREE_HAL_HIP_COMMAND_BUFFER_MODE_GRAPH: @@ -930,6 +945,10 @@ static iree_status_t iree_hal_hip_device_import_file( iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, iree_hal_memory_access_t access, iree_io_file_handle_t* handle, iree_hal_external_file_flags_t flags, iree_hal_file_t** out_file) { + iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); + IREE_RETURN_IF_ERROR( + HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); + if (iree_io_file_handle_type(handle) != IREE_IO_FILE_HANDLE_TYPE_HOST_ALLOCATION) { return iree_make_status( @@ -945,8 +964,10 @@ static iree_status_t iree_hal_hip_device_create_executable_cache( iree_hal_device_t* base_device, iree_string_view_t identifier, iree_loop_t loop, iree_hal_executable_cache_t** out_executable_cache) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); + IREE_RETURN_IF_ERROR( + HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); return iree_hal_hip_nop_executable_cache_create( - identifier, device->hip_symbols, device->hip_device, + identifier, device->hip_symbols, device->hip_device, device->hip_context, device->host_allocator, out_executable_cache); } @@ -954,9 +975,13 @@ static iree_status_t iree_hal_hip_device_create_semaphore( iree_hal_device_t* base_device, uint64_t initial_value, iree_hal_semaphore_flags_t flags, iree_hal_semaphore_t** out_semaphore) { 
iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); + IREE_RETURN_IF_ERROR( + HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); + return iree_hal_hip_event_semaphore_create( - initial_value, device->hip_symbols, device->timepoint_pool, - device->work_queue, device->host_allocator, out_semaphore); + initial_value, device->hip_symbols, device->hip_context, + device->timepoint_pool, device->work_queue, device->host_allocator, + out_semaphore); } static iree_hal_semaphore_compatibility_t @@ -1006,6 +1031,8 @@ static iree_status_t iree_hal_hip_device_queue_alloca( iree_device_size_t allocation_size, iree_hal_buffer_t** IREE_RESTRICT out_buffer) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); + IREE_RETURN_IF_ERROR( + HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); if (device->supports_memory_pools && !iree_all_bits_set(params.type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) { @@ -1083,6 +1110,9 @@ static iree_status_t iree_hal_hip_device_queue_dealloca( const iree_hal_semaphore_list_t signal_semaphore_list, iree_hal_buffer_t* buffer) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); + IREE_RETURN_IF_ERROR( + HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); + if (iree_hal_hip_allocator_isa(iree_hal_device_allocator(base_device))) { iree_status_t status = iree_hal_deferred_work_queue_enqueue_dealloc( device->work_queue, wait_semaphore_list, signal_semaphore_list, buffer); @@ -1169,6 +1199,9 @@ static iree_status_t iree_hal_hip_device_queue_execute( iree_hal_command_buffer_t* command_buffer, iree_hal_buffer_binding_table_t binding_table) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); + IREE_RETURN_IF_ERROR( + HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); + IREE_TRACE_ZONE_BEGIN(z0); iree_status_t status = iree_hal_deferred_work_queue_enqueue( @@ -1199,6 +1232,9 @@ static iree_status_t iree_hal_hip_device_wait_semaphores( iree_hal_device_t* 
base_device, iree_hal_wait_mode_t wait_mode, const iree_hal_semaphore_list_t semaphore_list, iree_timeout_t timeout) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); + IREE_RETURN_IF_ERROR( + HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); + return iree_hal_hip_semaphore_multi_wait(semaphore_list, wait_mode, timeout, &device->block_pool); } diff --git a/runtime/src/iree/hal/drivers/hip/memory_pools.c b/runtime/src/iree/hal/drivers/hip/memory_pools.c index 0258fa03198f..ff1a36d1458b 100644 --- a/runtime/src/iree/hal/drivers/hip/memory_pools.c +++ b/runtime/src/iree/hal/drivers/hip/memory_pools.c @@ -59,9 +59,12 @@ static iree_status_t iree_hal_hip_create_memory_pool( iree_status_t iree_hal_hip_memory_pools_initialize( const iree_hal_hip_dynamic_symbols_t* hip_symbols, hipDevice_t hip_device, + hipCtx_t hip_context, const iree_hal_hip_memory_pooling_params_t* pooling_params, iree_allocator_t host_allocator, iree_hal_hip_memory_pools_t* IREE_RESTRICT out_pools) { + IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(hip_symbols, hip_context)); + IREE_ASSERT_ARGUMENT(hip_symbols); IREE_ASSERT_ARGUMENT(pooling_params); IREE_ASSERT_ARGUMENT(out_pools); @@ -70,6 +73,7 @@ iree_status_t iree_hal_hip_memory_pools_initialize( memset(out_pools, 0, sizeof(*out_pools)); out_pools->hip_symbols = hip_symbols; out_pools->host_allocator = host_allocator; + out_pools->hip_context = hip_context; iree_status_t status = iree_ok_status(); @@ -90,6 +94,8 @@ iree_status_t iree_hal_hip_memory_pools_initialize( void iree_hal_hip_memory_pools_deinitialize( iree_hal_hip_memory_pools_t* pools) { + IREE_IGNORE_ERROR(HIP_SET_CONTEXT(pools->hip_symbols, pools->hip_context)); + IREE_TRACE_ZONE_BEGIN(z0); if (pools->device_local) { @@ -149,6 +155,8 @@ static void iree_hal_hip_memory_pool_track_free( void iree_hal_hip_memory_pools_merge_statistics( iree_hal_hip_memory_pools_t* pools, iree_hal_allocator_statistics_t* statistics) { + 
IREE_IGNORE_ERROR(HIP_SET_CONTEXT(pools->hip_symbols, pools->hip_context)); + IREE_STATISTICS({ statistics->device_bytes_allocated = iree_atomic_load( &pools->statistics.device_bytes_allocated, iree_memory_order_relaxed); @@ -181,6 +189,8 @@ void iree_hal_hip_memory_pools_merge_statistics( iree_status_t iree_hal_hip_memory_pools_trim( iree_hal_hip_memory_pools_t* pools, const iree_hal_hip_memory_pooling_params_t* pooling_params) { + IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(pools->hip_symbols, pools->hip_context)); + IREE_HIP_RETURN_IF_ERROR( pools->hip_symbols, hipMemPoolTrimTo(pools->device_local, @@ -199,6 +209,8 @@ iree_status_t iree_hal_hip_memory_pools_trim( static void iree_hal_hip_async_buffer_release_callback( void* user_data, iree_hal_buffer_t* buffer) { iree_hal_hip_memory_pools_t* pools = (iree_hal_hip_memory_pools_t*)user_data; + IREE_IGNORE_ERROR(HIP_SET_CONTEXT(pools->hip_symbols, pools->hip_context)); + IREE_TRACE_ZONE_BEGIN(z0); hipDeviceptr_t device_ptr = iree_hal_hip_buffer_device_pointer(buffer); @@ -213,6 +225,8 @@ static void iree_hal_hip_async_buffer_release_callback( iree_status_t iree_hal_hip_memory_pools_allocate_pointer( iree_hal_hip_memory_pools_t* pools, iree_hal_buffer_t* buffer, hipStream_t stream, iree_device_size_t allocation_size) { + IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(pools->hip_symbols, pools->hip_context)); + // TODO: more pools and better selection; this is coarsely deciding between // only device local (variables, constants, transients) and other (staging, // external) but could use more buffer properties (including usage/export @@ -276,6 +290,8 @@ iree_status_t iree_hal_hip_memory_pools_prepare_buffer( iree_status_t iree_hal_hip_memory_pools_deallocate( iree_hal_hip_memory_pools_t* pools, hipStream_t stream, iree_hal_buffer_t* buffer) { + IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(pools->hip_symbols, pools->hip_context)); + IREE_TRACE_ZONE_BEGIN(z0); IREE_TRACE_ZONE_APPEND_VALUE_I64( z0, 
(int64_t)iree_hal_buffer_allocation_size(buffer)); diff --git a/runtime/src/iree/hal/drivers/hip/memory_pools.h b/runtime/src/iree/hal/drivers/hip/memory_pools.h index c505bbe34e98..f95b76d9816b 100644 --- a/runtime/src/iree/hal/drivers/hip/memory_pools.h +++ b/runtime/src/iree/hal/drivers/hip/memory_pools.h @@ -33,6 +33,7 @@ typedef struct iree_hal_hip_memory_pools_t { hipMemPool_t other; const iree_hal_hip_dynamic_symbols_t* hip_symbols; + hipCtx_t hip_context; iree_allocator_t host_allocator; IREE_STATISTICS(struct { @@ -46,6 +47,7 @@ typedef struct iree_hal_hip_memory_pools_t { // Initializes |out_pools| by configuring new HIP memory pools. iree_status_t iree_hal_hip_memory_pools_initialize( const iree_hal_hip_dynamic_symbols_t* hip_symbols, hipDevice_t hip_device, + hipCtx_t hip_context, const iree_hal_hip_memory_pooling_params_t* pooling_params, iree_allocator_t host_allocator, iree_hal_hip_memory_pools_t* IREE_RESTRICT out_pools); diff --git a/runtime/src/iree/hal/drivers/hip/native_executable.c b/runtime/src/iree/hal/drivers/hip/native_executable.c index 18b3d378ed9a..f2af0e580f44 100644 --- a/runtime/src/iree/hal/drivers/hip/native_executable.c +++ b/runtime/src/iree/hal/drivers/hip/native_executable.c @@ -207,8 +207,10 @@ static iree_status_t iree_hal_hip_native_executable_flatbuffer_verify( iree_status_t iree_hal_hip_native_executable_create( const iree_hal_hip_dynamic_symbols_t* symbols, hipDevice_t device, - const iree_hal_executable_params_t* executable_params, + hipCtx_t context, const iree_hal_executable_params_t* executable_params, iree_allocator_t host_allocator, iree_hal_executable_t** out_executable) { + IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(symbols, context)); + IREE_ASSERT_ARGUMENT(symbols); IREE_ASSERT_ARGUMENT(executable_params); IREE_ASSERT_ARGUMENT(out_executable); diff --git a/runtime/src/iree/hal/drivers/hip/native_executable.h b/runtime/src/iree/hal/drivers/hip/native_executable.h index beb1e7cf92f5..b67f5e73e599 100644 --- 
a/runtime/src/iree/hal/drivers/hip/native_executable.h +++ b/runtime/src/iree/hal/drivers/hip/native_executable.h @@ -49,7 +49,7 @@ typedef struct iree_hal_hip_kernel_params_t { // several kernels that can be extracted along with the associated block size. iree_status_t iree_hal_hip_native_executable_create( const iree_hal_hip_dynamic_symbols_t* symbols, hipDevice_t device, - const iree_hal_executable_params_t* executable_params, + hipCtx_t context, const iree_hal_executable_params_t* executable_params, iree_allocator_t host_allocator, iree_hal_executable_t** out_executable); // Returns the kernel launch parameters for the given |entry_point| in the diff --git a/runtime/src/iree/hal/drivers/hip/nop_executable_cache.c b/runtime/src/iree/hal/drivers/hip/nop_executable_cache.c index c85d06a4a3d7..9680e3bf9f90 100644 --- a/runtime/src/iree/hal/drivers/hip/nop_executable_cache.c +++ b/runtime/src/iree/hal/drivers/hip/nop_executable_cache.c @@ -23,6 +23,7 @@ typedef struct iree_hal_hip_nop_executable_cache_t { const iree_hal_hip_dynamic_symbols_t* symbols; hipDevice_t device; + hipCtx_t hip_context; } iree_hal_hip_nop_executable_cache_t; static const iree_hal_executable_cache_vtable_t @@ -38,7 +39,7 @@ iree_hal_hip_nop_executable_cache_cast( iree_status_t iree_hal_hip_nop_executable_cache_create( iree_string_view_t identifier, const iree_hal_hip_dynamic_symbols_t* symbols, hipDevice_t device, - iree_allocator_t host_allocator, + hipCtx_t hip_context, iree_allocator_t host_allocator, iree_hal_executable_cache_t** out_executable_cache) { IREE_ASSERT_ARGUMENT(symbols); IREE_ASSERT_ARGUMENT(out_executable_cache); @@ -55,6 +56,7 @@ iree_status_t iree_hal_hip_nop_executable_cache_create( executable_cache->host_allocator = host_allocator; executable_cache->symbols = symbols; executable_cache->device = device; + executable_cache->hip_context = hip_context; *out_executable_cache = (iree_hal_executable_cache_t*)executable_cache; @@ -89,7 +91,8 @@ static iree_status_t 
iree_hal_hip_nop_executable_cache_prepare_executable( iree_hal_hip_nop_executable_cache_t* executable_cache = iree_hal_hip_nop_executable_cache_cast(base_executable_cache); return iree_hal_hip_native_executable_create( - executable_cache->symbols, executable_cache->device, executable_params, + executable_cache->symbols, executable_cache->device, + executable_cache->hip_context, executable_params, executable_cache->host_allocator, out_executable); } diff --git a/runtime/src/iree/hal/drivers/hip/nop_executable_cache.h b/runtime/src/iree/hal/drivers/hip/nop_executable_cache.h index 8726f1ebfbe5..795aa21b53c9 100644 --- a/runtime/src/iree/hal/drivers/hip/nop_executable_cache.h +++ b/runtime/src/iree/hal/drivers/hip/nop_executable_cache.h @@ -22,7 +22,7 @@ extern "C" { iree_status_t iree_hal_hip_nop_executable_cache_create( iree_string_view_t identifier, const iree_hal_hip_dynamic_symbols_t* symbols, hipDevice_t device, - iree_allocator_t host_allocator, + hipCtx_t hip_context, iree_allocator_t host_allocator, iree_hal_executable_cache_t** out_executable_cache); #ifdef __cplusplus diff --git a/runtime/src/iree/hal/drivers/hip/status_util.h b/runtime/src/iree/hal/drivers/hip/status_util.h index 221f55fe0214..eaac3067c293 100644 --- a/runtime/src/iree/hal/drivers/hip/status_util.h +++ b/runtime/src/iree/hal/drivers/hip/status_util.h @@ -56,6 +56,9 @@ extern "C" { IREE_IGNORE_ERROR(iree_hal_hip_result_to_status((syms), ((syms)->expr), \ __FILE__, __LINE__)) +#define HIP_SET_CONTEXT(syms, ctx) \ + IREE_HIP_RESULT_TO_STATUS(syms, hipCtxSetCurrent(ctx)) + // Converts a hipError_t to an iree_status_t object. 
iree_status_t iree_hal_hip_result_to_status( const iree_hal_hip_dynamic_symbols_t* syms, hipError_t result, diff --git a/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c b/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c index 0c7eb86e8448..711fcad8959c 100644 --- a/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c +++ b/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c @@ -27,6 +27,7 @@ typedef struct iree_hal_hip_stream_command_buffer_t { iree_hal_stream_tracing_context_event_list_t tracing_event_list; hipStream_t hip_stream; + hipCtx_t hip_context; // A resource set to maintain references to all resources used within the // command buffer. Reset on each begin. @@ -54,12 +55,14 @@ iree_status_t iree_hal_hip_stream_command_buffer_create( iree_hal_allocator_t* device_allocator, const iree_hal_hip_dynamic_symbols_t* hip_symbols, const iree_hal_hip_nccl_dynamic_symbols_t* nccl_symbols, - iree_hal_stream_tracing_context_t* tracing_context, + hipCtx_t hip_context, iree_hal_stream_tracing_context_t* tracing_context, iree_hal_command_buffer_mode_t mode, iree_hal_command_category_t command_categories, iree_host_size_t binding_capacity, hipStream_t stream, iree_arena_block_pool_t* block_pool, iree_allocator_t host_allocator, iree_hal_command_buffer_t** out_command_buffer) { + IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(hip_symbols, hip_context)); + IREE_ASSERT_ARGUMENT(device_allocator); IREE_ASSERT_ARGUMENT(hip_symbols); IREE_ASSERT_ARGUMENT(nccl_symbols); @@ -94,6 +97,7 @@ iree_status_t iree_hal_hip_stream_command_buffer_create( command_buffer->tracing_event_list.head = NULL; command_buffer->tracing_event_list.tail = NULL; command_buffer->hip_stream = stream; + command_buffer->hip_context = hip_context; iree_arena_initialize(block_pool, &command_buffer->arena); iree_status_t status = @@ -114,6 +118,8 @@ static void iree_hal_hip_stream_command_buffer_destroy( iree_hal_command_buffer_t* base_command_buffer) { iree_hal_hip_stream_command_buffer_t* 
command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); + IREE_IGNORE_ERROR(HIP_SET_CONTEXT(command_buffer->hip_symbols, + command_buffer->hip_context)); iree_allocator_t host_allocator = command_buffer->host_allocator; IREE_TRACE_ZONE_BEGIN(z0); @@ -172,7 +178,8 @@ static iree_status_t iree_hal_hip_stream_command_buffer_begin( iree_hal_command_buffer_t* base_command_buffer) { iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); - (void)command_buffer; + IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(command_buffer->hip_symbols, + command_buffer->hip_context)); IREE_HAL_STREAM_TRACE_ZONE_BEGIN_EXTERNAL( command_buffer->tracing_context, &command_buffer->tracing_event_list, @@ -187,6 +194,8 @@ static iree_status_t iree_hal_hip_stream_command_buffer_end( iree_hal_command_buffer_t* base_command_buffer) { iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); + IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(command_buffer->hip_symbols, + command_buffer->hip_context)); IREE_TRACE_ZONE_BEGIN(z0); IREE_RETURN_AND_END_ZONE_IF_ERROR( @@ -219,7 +228,8 @@ static iree_status_t iree_hal_hip_stream_command_buffer_begin_debug_group( const iree_hal_label_location_t* location) { iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); - (void)command_buffer; + IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(command_buffer->hip_symbols, + command_buffer->hip_context)); IREE_HAL_STREAM_TRACE_ZONE_BEGIN_EXTERNAL( command_buffer->tracing_context, &command_buffer->tracing_event_list, @@ -235,7 +245,8 @@ static iree_status_t iree_hal_hip_stream_command_buffer_end_debug_group( iree_hal_command_buffer_t* base_command_buffer) { iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); - (void)command_buffer; + 
IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(command_buffer->hip_symbols, + command_buffer->hip_context)); IREE_HAL_STREAM_TRACE_ZONE_END(command_buffer->tracing_context, &command_buffer->tracing_event_list, @@ -255,6 +266,8 @@ static iree_status_t iree_hal_hip_stream_command_buffer_execution_barrier( const iree_hal_buffer_barrier_t* buffer_barriers) { iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); + IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(command_buffer->hip_symbols, + command_buffer->hip_context)); if (iree_any_bit_set(source_stage_mask, IREE_HAL_EXECUTION_STAGE_HOST) || iree_any_bit_set(target_stage_mask, IREE_HAL_EXECUTION_STAGE_HOST)) { @@ -318,6 +331,9 @@ static iree_status_t iree_hal_hip_stream_command_buffer_fill_buffer( iree_host_size_t pattern_length, iree_hal_fill_flags_t flags) { iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); + IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(command_buffer->hip_symbols, + command_buffer->hip_context)); + IREE_TRACE_ZONE_BEGIN(z0); IREE_RETURN_AND_END_ZONE_IF_ERROR( @@ -371,6 +387,9 @@ static iree_status_t iree_hal_hip_stream_command_buffer_update_buffer( iree_hal_update_flags_t flags) { iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); + IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(command_buffer->hip_symbols, + command_buffer->hip_context)); + IREE_TRACE_ZONE_BEGIN(z0); IREE_RETURN_AND_END_ZONE_IF_ERROR( @@ -414,6 +433,9 @@ static iree_status_t iree_hal_hip_stream_command_buffer_copy_buffer( iree_hal_copy_flags_t flags) { iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); + IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(command_buffer->hip_symbols, + command_buffer->hip_context)); + IREE_TRACE_ZONE_BEGIN(z0); IREE_RETURN_AND_END_ZONE_IF_ERROR( @@ -446,6 +468,9 @@ static iree_status_t 
iree_hal_hip_stream_command_buffer_collective( iree_hal_buffer_ref_t recv_ref, iree_device_size_t element_count) { iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); + IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(command_buffer->hip_symbols, + command_buffer->hip_context)); + IREE_TRACE_ZONE_BEGIN(z0); iree_hal_buffer_binding_t send_binding = { @@ -473,6 +498,9 @@ static iree_status_t iree_hal_hip_stream_command_buffer_dispatch( iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) { iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); + IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(command_buffer->hip_symbols, + command_buffer->hip_context)); + IREE_TRACE_ZONE_BEGIN(z0); // If any of the workgroup counts are zero, we can skip execution diff --git a/runtime/src/iree/hal/drivers/hip/stream_command_buffer.h b/runtime/src/iree/hal/drivers/hip/stream_command_buffer.h index cc88c3a4b6b7..43820866d8cd 100644 --- a/runtime/src/iree/hal/drivers/hip/stream_command_buffer.h +++ b/runtime/src/iree/hal/drivers/hip/stream_command_buffer.h @@ -33,7 +33,7 @@ iree_status_t iree_hal_hip_stream_command_buffer_create( iree_hal_allocator_t* device_allocator, const iree_hal_hip_dynamic_symbols_t* hip_symbols, const iree_hal_hip_nccl_dynamic_symbols_t* nccl_symbols, - iree_hal_stream_tracing_context_t* tracing_context, + hipCtx_t hip_context, iree_hal_stream_tracing_context_t* tracing_context, iree_hal_command_buffer_mode_t mode, iree_hal_command_category_t command_categories, iree_host_size_t binding_capacity, hipStream_t stream, From 0065254975ef4880792e375cb09eba2679250989 Mon Sep 17 00:00:00 2001 From: Andrew Woloszyn Date: Mon, 11 Nov 2024 15:26:26 -0500 Subject: [PATCH 2/5] Move our context switches inside any zones that need them. 
Signed-off-by: Andrew Woloszyn --- runtime/src/iree/hal/drivers/hip/event_pool.c | 3 +- .../iree/hal/drivers/hip/event_semaphore.c | 7 ++- .../src/iree/hal/drivers/hip/hip_allocator.c | 10 ++--- runtime/src/iree/hal/drivers/hip/hip_device.c | 4 +- .../src/iree/hal/drivers/hip/memory_pools.c | 14 +++--- .../iree/hal/drivers/hip/native_executable.c | 3 +- .../hal/drivers/hip/stream_command_buffer.c | 43 ++++++++++--------- 7 files changed, 40 insertions(+), 44 deletions(-) diff --git a/runtime/src/iree/hal/drivers/hip/event_pool.c b/runtime/src/iree/hal/drivers/hip/event_pool.c index b71a10de241d..bfecc9945c9d 100644 --- a/runtime/src/iree/hal/drivers/hip/event_pool.c +++ b/runtime/src/iree/hal/drivers/hip/event_pool.c @@ -49,11 +49,10 @@ hipEvent_t iree_hal_hip_event_handle(const iree_hal_hip_event_t* event) { } static inline void iree_hal_hip_event_destroy(iree_hal_hip_event_t* event) { - IREE_IGNORE_ERROR(HIP_SET_CONTEXT(event->symbols, event->hip_context)); - iree_allocator_t host_allocator = event->host_allocator; const iree_hal_hip_dynamic_symbols_t* symbols = event->symbols; IREE_TRACE_ZONE_BEGIN(z0); + IREE_IGNORE_ERROR(HIP_SET_CONTEXT(event->symbols, event->hip_context)); IREE_ASSERT_REF_COUNT_ZERO(&event->ref_count); IREE_HIP_IGNORE_ERROR(symbols, hipEventDestroy(event->hip_event)); diff --git a/runtime/src/iree/hal/drivers/hip/event_semaphore.c b/runtime/src/iree/hal/drivers/hip/event_semaphore.c index 71765927d324..7e7434729099 100644 --- a/runtime/src/iree/hal/drivers/hip/event_semaphore.c +++ b/runtime/src/iree/hal/drivers/hip/event_semaphore.c @@ -62,14 +62,13 @@ iree_status_t iree_hal_hip_event_semaphore_create( hipCtx_t hip_context, iree_hal_hip_timepoint_pool_t* timepoint_pool, iree_hal_deferred_work_queue_t* work_queue, iree_allocator_t host_allocator, iree_hal_semaphore_t** out_semaphore) { - IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(symbols, hip_context)); - IREE_ASSERT_ARGUMENT(symbols); IREE_ASSERT_ARGUMENT(timepoint_pool); 
IREE_ASSERT_ARGUMENT(work_queue); IREE_ASSERT_ARGUMENT(out_semaphore); IREE_TRACE_ZONE_BEGIN(z0); + IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, HIP_SET_CONTEXT(symbols, hip_context)); iree_hal_hip_semaphore_t* semaphore = NULL; IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_allocator_malloc(host_allocator, sizeof(*semaphore), @@ -96,10 +95,10 @@ static void iree_hal_hip_semaphore_destroy( iree_hal_semaphore_t* base_semaphore) { iree_hal_hip_semaphore_t* semaphore = iree_hal_hip_semaphore_cast(base_semaphore); - IREE_IGNORE_ERROR( - HIP_SET_CONTEXT(semaphore->symbols, semaphore->hip_context)); iree_allocator_t host_allocator = semaphore->host_allocator; IREE_TRACE_ZONE_BEGIN(z0); + IREE_IGNORE_ERROR( + HIP_SET_CONTEXT(semaphore->symbols, semaphore->hip_context)); iree_status_ignore(semaphore->failure_status); iree_slim_mutex_deinitialize(&semaphore->mutex); diff --git a/runtime/src/iree/hal/drivers/hip/hip_allocator.c b/runtime/src/iree/hal/drivers/hip/hip_allocator.c index add9ca080146..3bd0389ac409 100644 --- a/runtime/src/iree/hal/drivers/hip/hip_allocator.c +++ b/runtime/src/iree/hal/drivers/hip/hip_allocator.c @@ -59,11 +59,11 @@ iree_status_t iree_hal_hip_allocator_create( hipCtx_t hip_context, hipStream_t stream, iree_hal_hip_memory_pools_t* pools, iree_allocator_t host_allocator, iree_hal_allocator_t** out_allocator) { - IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(hip_symbols, hip_context)); - IREE_ASSERT_ARGUMENT(hip_symbols); IREE_ASSERT_ARGUMENT(out_allocator); IREE_TRACE_ZONE_BEGIN(z0); + IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, + HIP_SET_CONTEXT(hip_symbols, hip_context)); // To support device-local + host-visible memory we need concurrent managed // access indicating that the host and devices can concurrently access the @@ -325,9 +325,6 @@ static iree_status_t iree_hal_hip_allocator_allocate_buffer( iree_hal_hip_allocator_t* allocator = iree_hal_hip_allocator_cast(base_allocator); - IREE_RETURN_IF_ERROR( - HIP_SET_CONTEXT(allocator->symbols, allocator->hip_context)); - 
// Coerce options into those required by the current device. iree_hal_buffer_params_t compat_params = *params; iree_hal_buffer_compatibility_t compatibility = @@ -361,6 +358,9 @@ static iree_status_t iree_hal_hip_allocator_allocate_buffer( void* host_ptr = NULL; hipDeviceptr_t device_ptr = NULL; IREE_TRACE_ZONE_BEGIN_NAMED(z0, "iree_hal_hip_buffer_allocate"); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, HIP_SET_CONTEXT(allocator->symbols, allocator->hip_context)); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, allocation_size); if (iree_all_bits_set(compat_params.type, IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL)) { diff --git a/runtime/src/iree/hal/drivers/hip/hip_device.c b/runtime/src/iree/hal/drivers/hip/hip_device.c index 9ee9ab5c2940..a1ebf0c7478a 100644 --- a/runtime/src/iree/hal/drivers/hip/hip_device.c +++ b/runtime/src/iree/hal/drivers/hip/hip_device.c @@ -1199,10 +1199,10 @@ static iree_status_t iree_hal_hip_device_queue_execute( iree_hal_command_buffer_t* command_buffer, iree_hal_buffer_binding_table_t binding_table) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); - IREE_RETURN_IF_ERROR( - HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); IREE_TRACE_ZONE_BEGIN(z0); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); iree_status_t status = iree_hal_deferred_work_queue_enqueue( device->work_queue, iree_hal_hip_device_collect_tracing_context, diff --git a/runtime/src/iree/hal/drivers/hip/memory_pools.c b/runtime/src/iree/hal/drivers/hip/memory_pools.c index ff1a36d1458b..0a912c2daddd 100644 --- a/runtime/src/iree/hal/drivers/hip/memory_pools.c +++ b/runtime/src/iree/hal/drivers/hip/memory_pools.c @@ -63,12 +63,12 @@ iree_status_t iree_hal_hip_memory_pools_initialize( const iree_hal_hip_memory_pooling_params_t* pooling_params, iree_allocator_t host_allocator, iree_hal_hip_memory_pools_t* IREE_RESTRICT out_pools) { - IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(hip_symbols, hip_context)); - 
IREE_ASSERT_ARGUMENT(hip_symbols); IREE_ASSERT_ARGUMENT(pooling_params); IREE_ASSERT_ARGUMENT(out_pools); IREE_TRACE_ZONE_BEGIN(z0); + IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, + HIP_SET_CONTEXT(hip_symbols, hip_context)); memset(out_pools, 0, sizeof(*out_pools)); out_pools->hip_symbols = hip_symbols; @@ -94,9 +94,8 @@ iree_status_t iree_hal_hip_memory_pools_initialize( void iree_hal_hip_memory_pools_deinitialize( iree_hal_hip_memory_pools_t* pools) { - IREE_IGNORE_ERROR(HIP_SET_CONTEXT(pools->hip_symbols, pools->hip_context)); - IREE_TRACE_ZONE_BEGIN(z0); + IREE_IGNORE_ERROR(HIP_SET_CONTEXT(pools->hip_symbols, pools->hip_context)); if (pools->device_local) { IREE_HIP_IGNORE_ERROR(pools->hip_symbols, @@ -209,9 +208,8 @@ iree_status_t iree_hal_hip_memory_pools_trim( static void iree_hal_hip_async_buffer_release_callback( void* user_data, iree_hal_buffer_t* buffer) { iree_hal_hip_memory_pools_t* pools = (iree_hal_hip_memory_pools_t*)user_data; - IREE_IGNORE_ERROR(HIP_SET_CONTEXT(pools->hip_symbols, pools->hip_context)); - IREE_TRACE_ZONE_BEGIN(z0); + IREE_IGNORE_ERROR(HIP_SET_CONTEXT(pools->hip_symbols, pools->hip_context)); hipDeviceptr_t device_ptr = iree_hal_hip_buffer_device_pointer(buffer); if (device_ptr) { @@ -290,9 +288,9 @@ iree_status_t iree_hal_hip_memory_pools_prepare_buffer( iree_status_t iree_hal_hip_memory_pools_deallocate( iree_hal_hip_memory_pools_t* pools, hipStream_t stream, iree_hal_buffer_t* buffer) { - IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(pools->hip_symbols, pools->hip_context)); - IREE_TRACE_ZONE_BEGIN(z0); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, HIP_SET_CONTEXT(pools->hip_symbols, pools->hip_context)); IREE_TRACE_ZONE_APPEND_VALUE_I64( z0, (int64_t)iree_hal_buffer_allocation_size(buffer)); diff --git a/runtime/src/iree/hal/drivers/hip/native_executable.c b/runtime/src/iree/hal/drivers/hip/native_executable.c index f2af0e580f44..0a11a0162553 100644 --- a/runtime/src/iree/hal/drivers/hip/native_executable.c +++ 
b/runtime/src/iree/hal/drivers/hip/native_executable.c @@ -209,12 +209,11 @@ iree_status_t iree_hal_hip_native_executable_create( const iree_hal_hip_dynamic_symbols_t* symbols, hipDevice_t device, hipCtx_t context, const iree_hal_executable_params_t* executable_params, iree_allocator_t host_allocator, iree_hal_executable_t** out_executable) { - IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(symbols, context)); - IREE_ASSERT_ARGUMENT(symbols); IREE_ASSERT_ARGUMENT(executable_params); IREE_ASSERT_ARGUMENT(out_executable); IREE_TRACE_ZONE_BEGIN(z0); + IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, HIP_SET_CONTEXT(symbols, context)); *out_executable = NULL; diff --git a/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c b/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c index 711fcad8959c..c151a7e9d101 100644 --- a/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c +++ b/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c @@ -61,8 +61,6 @@ iree_status_t iree_hal_hip_stream_command_buffer_create( iree_host_size_t binding_capacity, hipStream_t stream, iree_arena_block_pool_t* block_pool, iree_allocator_t host_allocator, iree_hal_command_buffer_t** out_command_buffer) { - IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(hip_symbols, hip_context)); - IREE_ASSERT_ARGUMENT(device_allocator); IREE_ASSERT_ARGUMENT(hip_symbols); IREE_ASSERT_ARGUMENT(nccl_symbols); @@ -76,6 +74,8 @@ iree_status_t iree_hal_hip_stream_command_buffer_create( } IREE_TRACE_ZONE_BEGIN(z0); + IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, + HIP_SET_CONTEXT(hip_symbols, hip_context)); iree_hal_hip_stream_command_buffer_t* command_buffer = NULL; IREE_RETURN_AND_END_ZONE_IF_ERROR( @@ -118,10 +118,10 @@ static void iree_hal_hip_stream_command_buffer_destroy( iree_hal_command_buffer_t* base_command_buffer) { iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); - IREE_IGNORE_ERROR(HIP_SET_CONTEXT(command_buffer->hip_symbols, - command_buffer->hip_context)); 
iree_allocator_t host_allocator = command_buffer->host_allocator; IREE_TRACE_ZONE_BEGIN(z0); + IREE_IGNORE_ERROR(HIP_SET_CONTEXT(command_buffer->hip_symbols, + command_buffer->hip_context)); iree_hal_stream_tracing_free(command_buffer->tracing_context, &command_buffer->tracing_event_list); @@ -194,9 +194,10 @@ static iree_status_t iree_hal_hip_stream_command_buffer_end( iree_hal_command_buffer_t* base_command_buffer) { iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); - IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(command_buffer->hip_symbols, - command_buffer->hip_context)); IREE_TRACE_ZONE_BEGIN(z0); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, HIP_SET_CONTEXT(command_buffer->hip_symbols, + command_buffer->hip_context)); IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_hip_stream_command_buffer_flush_collectives(command_buffer)); @@ -331,10 +332,10 @@ static iree_status_t iree_hal_hip_stream_command_buffer_fill_buffer( iree_host_size_t pattern_length, iree_hal_fill_flags_t flags) { iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); - IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(command_buffer->hip_symbols, - command_buffer->hip_context)); - IREE_TRACE_ZONE_BEGIN(z0); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, HIP_SET_CONTEXT(command_buffer->hip_symbols, + command_buffer->hip_context)); IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_hip_stream_command_buffer_flush_collectives(command_buffer)); @@ -387,10 +388,10 @@ static iree_status_t iree_hal_hip_stream_command_buffer_update_buffer( iree_hal_update_flags_t flags) { iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); - IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(command_buffer->hip_symbols, - command_buffer->hip_context)); - IREE_TRACE_ZONE_BEGIN(z0); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, HIP_SET_CONTEXT(command_buffer->hip_symbols, + 
command_buffer->hip_context)); IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_hip_stream_command_buffer_flush_collectives(command_buffer)); @@ -433,10 +434,10 @@ static iree_status_t iree_hal_hip_stream_command_buffer_copy_buffer( iree_hal_copy_flags_t flags) { iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); - IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(command_buffer->hip_symbols, - command_buffer->hip_context)); - IREE_TRACE_ZONE_BEGIN(z0); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, HIP_SET_CONTEXT(command_buffer->hip_symbols, + command_buffer->hip_context)); IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_hip_stream_command_buffer_flush_collectives(command_buffer)); @@ -468,10 +469,10 @@ static iree_status_t iree_hal_hip_stream_command_buffer_collective( iree_hal_buffer_ref_t recv_ref, iree_device_size_t element_count) { iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); - IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(command_buffer->hip_symbols, - command_buffer->hip_context)); - IREE_TRACE_ZONE_BEGIN(z0); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, HIP_SET_CONTEXT(command_buffer->hip_symbols, + command_buffer->hip_context)); iree_hal_buffer_binding_t send_binding = { .buffer = send_ref.buffer, @@ -498,10 +499,10 @@ static iree_status_t iree_hal_hip_stream_command_buffer_dispatch( iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) { iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); - IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(command_buffer->hip_symbols, - command_buffer->hip_context)); - IREE_TRACE_ZONE_BEGIN(z0); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, HIP_SET_CONTEXT(command_buffer->hip_symbols, + command_buffer->hip_context)); // If any of the workgroup counts are zero, we can skip execution // of the kernel. 
This prevents a 'hipErrorInvalidConfiguration' error when From b82bc2e2e4e7efa4dcf9aa7fbd8312999d11b3ba Mon Sep 17 00:00:00 2001 From: Andrew Woloszyn Date: Mon, 11 Nov 2024 15:44:18 -0500 Subject: [PATCH 3/5] [hip] Log a message if we are going to do a context switch. Signed-off-by: Andrew Woloszyn --- .../src/iree/hal/drivers/hip/CMakeLists.txt | 1 + .../src/iree/hal/drivers/hip/context_util.h | 27 ++++++++++ .../hal/drivers/hip/dynamic_symbol_tables.h | 1 + runtime/src/iree/hal/drivers/hip/event_pool.c | 4 +- .../iree/hal/drivers/hip/event_semaphore.c | 6 ++- .../src/iree/hal/drivers/hip/hip_allocator.c | 15 +++--- runtime/src/iree/hal/drivers/hip/hip_device.c | 27 +++++----- .../src/iree/hal/drivers/hip/memory_pools.c | 22 ++++++--- .../iree/hal/drivers/hip/native_executable.c | 4 +- .../src/iree/hal/drivers/hip/status_util.h | 3 -- .../hal/drivers/hip/stream_command_buffer.c | 49 ++++++++++--------- 11 files changed, 100 insertions(+), 59 deletions(-) create mode 100644 runtime/src/iree/hal/drivers/hip/context_util.h diff --git a/runtime/src/iree/hal/drivers/hip/CMakeLists.txt b/runtime/src/iree/hal/drivers/hip/CMakeLists.txt index ef9576d8d66c..7fd5abf07294 100644 --- a/runtime/src/iree/hal/drivers/hip/CMakeLists.txt +++ b/runtime/src/iree/hal/drivers/hip/CMakeLists.txt @@ -21,6 +21,7 @@ iree_cc_library( "api.h" SRCS "api.h" + "context_util.h" "event_pool.c" "event_pool.h" "event_semaphore.c" diff --git a/runtime/src/iree/hal/drivers/hip/context_util.h b/runtime/src/iree/hal/drivers/hip/context_util.h new file mode 100644 index 000000000000..9b184fd39a5c --- /dev/null +++ b/runtime/src/iree/hal/drivers/hip/context_util.h @@ -0,0 +1,27 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_HAL_DRIVERS_HIP_CONTEXT_UTIL_H_ +#define IREE_HAL_DRIVERS_HIP_CONTEXT_UTIL_H_ + +#include "iree/base/api.h" +#include "iree/hal/drivers/hip/dynamic_symbols.h" +#include "iree/hal/drivers/hip/status_util.h" + +static inline iree_status_t iree_hal_hip_set_context( + const iree_hal_hip_dynamic_symbols_t* syms, hipCtx_t hip_context) { + IREE_TRACE({ + hipCtx_t current_context = NULL; + IREE_HIP_RETURN_IF_ERROR(syms, hipCtxGetCurrent(&current_context), + "hipCtxGetCurrent"); + if (current_context != hip_context) { + IREE_TRACE_MESSAGE(INFO, "Hip Context Switch"); + } + }); + return IREE_HIP_RESULT_TO_STATUS(syms, hipCtxSetCurrent(hip_context)); +} + +#endif // IREE_HAL_DRIVERS_HIP_CONTEXT_UTIL_H_ diff --git a/runtime/src/iree/hal/drivers/hip/dynamic_symbol_tables.h b/runtime/src/iree/hal/drivers/hip/dynamic_symbol_tables.h index 963784d81dcc..85d33740c83d 100644 --- a/runtime/src/iree/hal/drivers/hip/dynamic_symbol_tables.h +++ b/runtime/src/iree/hal/drivers/hip/dynamic_symbol_tables.h @@ -8,6 +8,7 @@ // HIP symbols //===----------------------------------------------------------------------===// +IREE_HAL_HIP_REQUIRED_PFN_DECL(hipCtxGetCurrent, hipCtx_t *) IREE_HAL_HIP_REQUIRED_PFN_DECL(hipCtxSetCurrent, hipCtx_t) IREE_HAL_HIP_REQUIRED_PFN_DECL(hipDeviceGet, hipDevice_t *, int) IREE_HAL_HIP_REQUIRED_PFN_DECL(hipDeviceGetAttribute, int *, diff --git a/runtime/src/iree/hal/drivers/hip/event_pool.c b/runtime/src/iree/hal/drivers/hip/event_pool.c index bfecc9945c9d..010cfd67fd2e 100644 --- a/runtime/src/iree/hal/drivers/hip/event_pool.c +++ b/runtime/src/iree/hal/drivers/hip/event_pool.c @@ -14,6 +14,7 @@ #include "iree/base/internal/atomics.h" #include "iree/base/internal/synchronization.h" #include "iree/hal/api.h" +#include "iree/hal/drivers/hip/context_util.h" #include "iree/hal/drivers/hip/dynamic_symbols.h" #include "iree/hal/drivers/hip/status_util.h" @@ -52,7 +53,8 @@ static inline void
iree_hal_hip_event_destroy(iree_hal_hip_event_t* event) { iree_allocator_t host_allocator = event->host_allocator; const iree_hal_hip_dynamic_symbols_t* symbols = event->symbols; IREE_TRACE_ZONE_BEGIN(z0); - IREE_IGNORE_ERROR(HIP_SET_CONTEXT(event->symbols, event->hip_context)); + IREE_IGNORE_ERROR( + iree_hal_hip_set_context(event->symbols, event->hip_context)); IREE_ASSERT_REF_COUNT_ZERO(&event->ref_count); IREE_HIP_IGNORE_ERROR(symbols, hipEventDestroy(event->hip_event)); diff --git a/runtime/src/iree/hal/drivers/hip/event_semaphore.c b/runtime/src/iree/hal/drivers/hip/event_semaphore.c index 7e7434729099..686e23015b8f 100644 --- a/runtime/src/iree/hal/drivers/hip/event_semaphore.c +++ b/runtime/src/iree/hal/drivers/hip/event_semaphore.c @@ -9,6 +9,7 @@ #include "iree/base/internal/synchronization.h" #include "iree/base/internal/wait_handle.h" #include "iree/base/status.h" +#include "iree/hal/drivers/hip/context_util.h" #include "iree/hal/drivers/hip/dynamic_symbols.h" #include "iree/hal/drivers/hip/status_util.h" #include "iree/hal/drivers/hip/timepoint_pool.h" @@ -68,7 +69,8 @@ iree_status_t iree_hal_hip_event_semaphore_create( IREE_ASSERT_ARGUMENT(out_semaphore); IREE_TRACE_ZONE_BEGIN(z0); - IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, HIP_SET_CONTEXT(symbols, hip_context)); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_hip_set_context(symbols, hip_context)); iree_hal_hip_semaphore_t* semaphore = NULL; IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_allocator_malloc(host_allocator, sizeof(*semaphore), @@ -98,7 +100,7 @@ static void iree_hal_hip_semaphore_destroy( iree_allocator_t host_allocator = semaphore->host_allocator; IREE_TRACE_ZONE_BEGIN(z0); IREE_IGNORE_ERROR( - HIP_SET_CONTEXT(semaphore->symbols, semaphore->hip_context)); + iree_hal_hip_set_context(semaphore->symbols, semaphore->hip_context)); iree_status_ignore(semaphore->failure_status); iree_slim_mutex_deinitialize(&semaphore->mutex); diff --git a/runtime/src/iree/hal/drivers/hip/hip_allocator.c 
b/runtime/src/iree/hal/drivers/hip/hip_allocator.c index 3bd0389ac409..95a04cd1812b 100644 --- a/runtime/src/iree/hal/drivers/hip/hip_allocator.c +++ b/runtime/src/iree/hal/drivers/hip/hip_allocator.c @@ -10,6 +10,7 @@ #include "iree/base/api.h" #include "iree/base/tracing.h" +#include "iree/hal/drivers/hip/context_util.h" #include "iree/hal/drivers/hip/dynamic_symbols.h" #include "iree/hal/drivers/hip/hip_buffer.h" #include "iree/hal/drivers/hip/status_util.h" @@ -62,8 +63,8 @@ iree_status_t iree_hal_hip_allocator_create( IREE_ASSERT_ARGUMENT(hip_symbols); IREE_ASSERT_ARGUMENT(out_allocator); IREE_TRACE_ZONE_BEGIN(z0); - IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, - HIP_SET_CONTEXT(hip_symbols, hip_context)); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_hip_set_context(hip_symbols, hip_context)); // To support device-local + host-visible memory we need concurrent managed // access indicating that the host and devices can concurrently access the @@ -359,7 +360,7 @@ static iree_status_t iree_hal_hip_allocator_allocate_buffer( hipDeviceptr_t device_ptr = NULL; IREE_TRACE_ZONE_BEGIN_NAMED(z0, "iree_hal_hip_buffer_allocate"); IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, HIP_SET_CONTEXT(allocator->symbols, allocator->hip_context)); + z0, iree_hal_hip_set_context(allocator->symbols, allocator->hip_context)); IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, allocation_size); if (iree_all_bits_set(compat_params.type, @@ -441,7 +442,7 @@ static void iree_hal_hip_allocator_deallocate_buffer( iree_hal_hip_allocator_cast(base_allocator); IREE_IGNORE_ERROR( - HIP_SET_CONTEXT(allocator->symbols, allocator->hip_context)); + iree_hal_hip_set_context(allocator->symbols, allocator->hip_context)); const iree_hal_hip_buffer_type_t buffer_type = iree_hal_hip_buffer_type(base_buffer); @@ -479,7 +480,7 @@ static iree_status_t iree_hal_hip_allocator_import_buffer( iree_hal_hip_allocator_cast(base_allocator); IREE_RETURN_IF_ERROR( - HIP_SET_CONTEXT(allocator->symbols, allocator->hip_context)); + 
iree_hal_hip_set_context(allocator->symbols, allocator->hip_context)); // Coerce options into those required by the current device. iree_hal_buffer_params_t compat_params = *params; @@ -616,7 +617,7 @@ iree_status_t iree_hal_hip_allocator_alloc_async( iree_hal_hip_allocator_cast(base_allocator); IREE_RETURN_IF_ERROR( - HIP_SET_CONTEXT(allocator->symbols, allocator->hip_context)); + iree_hal_hip_set_context(allocator->symbols, allocator->hip_context)); hipDeviceptr_t ptr = NULL; iree_status_t status = IREE_HIP_RESULT_TO_STATUS( @@ -644,7 +645,7 @@ iree_status_t iree_hal_hip_allocator_free_async( iree_hal_hip_allocator_t* allocator = iree_hal_hip_allocator_cast(base_allocator); IREE_RETURN_IF_ERROR( - HIP_SET_CONTEXT(allocator->symbols, allocator->hip_context)); + iree_hal_hip_set_context(allocator->symbols, allocator->hip_context)); hipDeviceptr_t device_ptr = iree_hal_hip_buffer_device_pointer(buffer); if (!device_ptr) { diff --git a/runtime/src/iree/hal/drivers/hip/hip_device.c b/runtime/src/iree/hal/drivers/hip/hip_device.c index a1ebf0c7478a..bc001d645c3a 100644 --- a/runtime/src/iree/hal/drivers/hip/hip_device.c +++ b/runtime/src/iree/hal/drivers/hip/hip_device.c @@ -14,6 +14,7 @@ #include "iree/base/internal/event_pool.h" #include "iree/base/internal/math.h" #include "iree/base/tracing.h" +#include "iree/hal/drivers/hip/context_util.h" #include "iree/hal/drivers/hip/dynamic_symbols.h" #include "iree/hal/drivers/hip/event_pool.h" #include "iree/hal/drivers/hip/event_semaphore.h" @@ -732,7 +733,7 @@ static void iree_hal_hip_replace_channel_provider( static iree_status_t iree_hal_hip_device_trim(iree_hal_device_t* base_device) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); IREE_RETURN_IF_ERROR( - HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); + iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); iree_arena_block_pool_trim(&device->block_pool); 
IREE_RETURN_IF_ERROR(iree_hal_allocator_trim(device->device_allocator)); if (device->supports_memory_pools) { @@ -746,7 +747,7 @@ static iree_status_t iree_hal_hip_device_query_attribute( iree_hal_hip_device_t* device, hipDeviceAttribute_t attribute, int64_t* out_value) { IREE_RETURN_IF_ERROR( - HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); + iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); int value = 0; IREE_HIP_RETURN_IF_ERROR( device->hip_symbols, @@ -761,7 +762,7 @@ static iree_status_t iree_hal_hip_device_query_i64( iree_string_view_t key, int64_t* out_value) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); IREE_RETURN_IF_ERROR( - HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); + iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); *out_value = 0; if (iree_string_view_equal(category, IREE_SV("hal.device.id"))) { @@ -786,7 +787,7 @@ static iree_status_t iree_hal_hip_device_create_channel( iree_hal_channel_params_t params, iree_hal_channel_t** out_channel) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); IREE_RETURN_IF_ERROR( - HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); + iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); if (!device->nccl_symbols || !device->nccl_symbols->dylib) { return iree_make_status( @@ -875,7 +876,7 @@ iree_status_t iree_hal_hip_device_create_stream_command_buffer( iree_hal_command_buffer_t** out_command_buffer) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); IREE_RETURN_IF_ERROR( - HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); + iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); return iree_hal_hip_stream_command_buffer_create( iree_hal_device_allocator(base_device), device->hip_symbols, @@ -891,7 +892,7 @@ static iree_status_t iree_hal_hip_device_create_command_buffer( iree_hal_command_buffer_t** out_command_buffer) { 
iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); IREE_RETURN_IF_ERROR( - HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); + iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); if (device->params.allow_inline_execution && iree_all_bits_set(mode, @@ -947,7 +948,7 @@ static iree_status_t iree_hal_hip_device_import_file( iree_hal_external_file_flags_t flags, iree_hal_file_t** out_file) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); IREE_RETURN_IF_ERROR( - HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); + iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); if (iree_io_file_handle_type(handle) != IREE_IO_FILE_HANDLE_TYPE_HOST_ALLOCATION) { @@ -965,7 +966,7 @@ static iree_status_t iree_hal_hip_device_create_executable_cache( iree_loop_t loop, iree_hal_executable_cache_t** out_executable_cache) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); IREE_RETURN_IF_ERROR( - HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); + iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); return iree_hal_hip_nop_executable_cache_create( identifier, device->hip_symbols, device->hip_device, device->hip_context, device->host_allocator, out_executable_cache); @@ -976,7 +977,7 @@ static iree_status_t iree_hal_hip_device_create_semaphore( iree_hal_semaphore_flags_t flags, iree_hal_semaphore_t** out_semaphore) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); IREE_RETURN_IF_ERROR( - HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); + iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); return iree_hal_hip_event_semaphore_create( initial_value, device->hip_symbols, device->hip_context, @@ -1032,7 +1033,7 @@ static iree_status_t iree_hal_hip_device_queue_alloca( iree_hal_buffer_t** IREE_RESTRICT out_buffer) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); IREE_RETURN_IF_ERROR( 
- HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); + iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); if (device->supports_memory_pools && !iree_all_bits_set(params.type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) { @@ -1111,7 +1112,7 @@ static iree_status_t iree_hal_hip_device_queue_dealloca( iree_hal_buffer_t* buffer) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); IREE_RETURN_IF_ERROR( - HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); + iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); if (iree_hal_hip_allocator_isa(iree_hal_device_allocator(base_device))) { iree_status_t status = iree_hal_deferred_work_queue_enqueue_dealloc( @@ -1202,7 +1203,7 @@ static iree_status_t iree_hal_hip_device_queue_execute( IREE_TRACE_ZONE_BEGIN(z0); IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); + z0, iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); iree_status_t status = iree_hal_deferred_work_queue_enqueue( device->work_queue, iree_hal_hip_device_collect_tracing_context, @@ -1233,7 +1234,7 @@ static iree_status_t iree_hal_hip_device_wait_semaphores( const iree_hal_semaphore_list_t semaphore_list, iree_timeout_t timeout) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); IREE_RETURN_IF_ERROR( - HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); + iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); return iree_hal_hip_semaphore_multi_wait(semaphore_list, wait_mode, timeout, &device->block_pool); diff --git a/runtime/src/iree/hal/drivers/hip/memory_pools.c b/runtime/src/iree/hal/drivers/hip/memory_pools.c index 0a912c2daddd..93a046b1f59e 100644 --- a/runtime/src/iree/hal/drivers/hip/memory_pools.c +++ b/runtime/src/iree/hal/drivers/hip/memory_pools.c @@ -6,6 +6,7 @@ #include "iree/hal/drivers/hip/memory_pools.h" +#include "iree/hal/drivers/hip/context_util.h" #include 
"iree/hal/drivers/hip/dynamic_symbols.h" #include "iree/hal/drivers/hip/hip_buffer.h" #include "iree/hal/drivers/hip/status_util.h" @@ -67,8 +68,8 @@ iree_status_t iree_hal_hip_memory_pools_initialize( IREE_ASSERT_ARGUMENT(pooling_params); IREE_ASSERT_ARGUMENT(out_pools); IREE_TRACE_ZONE_BEGIN(z0); - IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, - HIP_SET_CONTEXT(hip_symbols, hip_context)); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_hip_set_context(hip_symbols, hip_context)); memset(out_pools, 0, sizeof(*out_pools)); out_pools->hip_symbols = hip_symbols; @@ -95,7 +96,8 @@ iree_status_t iree_hal_hip_memory_pools_initialize( void iree_hal_hip_memory_pools_deinitialize( iree_hal_hip_memory_pools_t* pools) { IREE_TRACE_ZONE_BEGIN(z0); - IREE_IGNORE_ERROR(HIP_SET_CONTEXT(pools->hip_symbols, pools->hip_context)); + IREE_IGNORE_ERROR( + iree_hal_hip_set_context(pools->hip_symbols, pools->hip_context)); if (pools->device_local) { IREE_HIP_IGNORE_ERROR(pools->hip_symbols, @@ -154,7 +156,8 @@ static void iree_hal_hip_memory_pool_track_free( void iree_hal_hip_memory_pools_merge_statistics( iree_hal_hip_memory_pools_t* pools, iree_hal_allocator_statistics_t* statistics) { - IREE_IGNORE_ERROR(HIP_SET_CONTEXT(pools->hip_symbols, pools->hip_context)); + IREE_IGNORE_ERROR( + iree_hal_hip_set_context(pools->hip_symbols, pools->hip_context)); IREE_STATISTICS({ statistics->device_bytes_allocated = iree_atomic_load( @@ -188,7 +191,8 @@ void iree_hal_hip_memory_pools_merge_statistics( iree_status_t iree_hal_hip_memory_pools_trim( iree_hal_hip_memory_pools_t* pools, const iree_hal_hip_memory_pooling_params_t* pooling_params) { - IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(pools->hip_symbols, pools->hip_context)); + IREE_RETURN_IF_ERROR( + iree_hal_hip_set_context(pools->hip_symbols, pools->hip_context)); IREE_HIP_RETURN_IF_ERROR( pools->hip_symbols, @@ -209,7 +213,8 @@ static void iree_hal_hip_async_buffer_release_callback( void* user_data, iree_hal_buffer_t* buffer) { 
iree_hal_hip_memory_pools_t* pools = (iree_hal_hip_memory_pools_t*)user_data; IREE_TRACE_ZONE_BEGIN(z0); - IREE_IGNORE_ERROR(HIP_SET_CONTEXT(pools->hip_symbols, pools->hip_context)); + IREE_IGNORE_ERROR( + iree_hal_hip_set_context(pools->hip_symbols, pools->hip_context)); hipDeviceptr_t device_ptr = iree_hal_hip_buffer_device_pointer(buffer); if (device_ptr) { @@ -223,7 +228,8 @@ static void iree_hal_hip_async_buffer_release_callback( iree_status_t iree_hal_hip_memory_pools_allocate_pointer( iree_hal_hip_memory_pools_t* pools, iree_hal_buffer_t* buffer, hipStream_t stream, iree_device_size_t allocation_size) { - IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(pools->hip_symbols, pools->hip_context)); + IREE_RETURN_IF_ERROR( + iree_hal_hip_set_context(pools->hip_symbols, pools->hip_context)); // TODO: more pools and better selection; this is coarsely deciding between // only device local (variables, constants, transients) and other (staging, @@ -290,7 +296,7 @@ iree_status_t iree_hal_hip_memory_pools_deallocate( iree_hal_buffer_t* buffer) { IREE_TRACE_ZONE_BEGIN(z0); IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, HIP_SET_CONTEXT(pools->hip_symbols, pools->hip_context)); + z0, iree_hal_hip_set_context(pools->hip_symbols, pools->hip_context)); IREE_TRACE_ZONE_APPEND_VALUE_I64( z0, (int64_t)iree_hal_buffer_allocation_size(buffer)); diff --git a/runtime/src/iree/hal/drivers/hip/native_executable.c b/runtime/src/iree/hal/drivers/hip/native_executable.c index 0a11a0162553..e37aba578f02 100644 --- a/runtime/src/iree/hal/drivers/hip/native_executable.c +++ b/runtime/src/iree/hal/drivers/hip/native_executable.c @@ -9,6 +9,7 @@ #include #include "iree/base/api.h" +#include "iree/hal/drivers/hip/context_util.h" #include "iree/hal/drivers/hip/dynamic_symbols.h" #include "iree/hal/drivers/hip/status_util.h" #include "iree/hal/utils/executable_debug_info.h" @@ -213,7 +214,8 @@ iree_status_t iree_hal_hip_native_executable_create( IREE_ASSERT_ARGUMENT(executable_params); 
IREE_ASSERT_ARGUMENT(out_executable); IREE_TRACE_ZONE_BEGIN(z0); - IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, HIP_SET_CONTEXT(symbols, context)); + IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, + iree_hal_hip_set_context(symbols, context)); *out_executable = NULL; diff --git a/runtime/src/iree/hal/drivers/hip/status_util.h b/runtime/src/iree/hal/drivers/hip/status_util.h index eaac3067c293..221f55fe0214 100644 --- a/runtime/src/iree/hal/drivers/hip/status_util.h +++ b/runtime/src/iree/hal/drivers/hip/status_util.h @@ -56,9 +56,6 @@ extern "C" { IREE_IGNORE_ERROR(iree_hal_hip_result_to_status((syms), ((syms)->expr), \ __FILE__, __LINE__)) -#define HIP_SET_CONTEXT(syms, ctx) \ - IREE_HIP_RESULT_TO_STATUS(syms, hipCtxSetCurrent(ctx)) - // Converts a hipError_t to an iree_status_t object. iree_status_t iree_hal_hip_result_to_status( const iree_hal_hip_dynamic_symbols_t* syms, hipError_t result, diff --git a/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c b/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c index c151a7e9d101..6a201ca2976c 100644 --- a/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c +++ b/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c @@ -7,6 +7,7 @@ #include "iree/hal/drivers/hip/stream_command_buffer.h" +#include "iree/hal/drivers/hip/context_util.h" #include "iree/hal/drivers/hip/hip_buffer.h" #include "iree/hal/drivers/hip/native_executable.h" #include "iree/hal/drivers/hip/rccl_channel.h" @@ -74,8 +75,8 @@ iree_status_t iree_hal_hip_stream_command_buffer_create( } IREE_TRACE_ZONE_BEGIN(z0); - IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, - HIP_SET_CONTEXT(hip_symbols, hip_context)); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_hip_set_context(hip_symbols, hip_context)); iree_hal_hip_stream_command_buffer_t* command_buffer = NULL; IREE_RETURN_AND_END_ZONE_IF_ERROR( @@ -120,8 +121,8 @@ static void iree_hal_hip_stream_command_buffer_destroy( iree_hal_hip_stream_command_buffer_cast(base_command_buffer); iree_allocator_t 
host_allocator = command_buffer->host_allocator; IREE_TRACE_ZONE_BEGIN(z0); - IREE_IGNORE_ERROR(HIP_SET_CONTEXT(command_buffer->hip_symbols, - command_buffer->hip_context)); + IREE_IGNORE_ERROR(iree_hal_hip_set_context(command_buffer->hip_symbols, + command_buffer->hip_context)); iree_hal_stream_tracing_free(command_buffer->tracing_context, &command_buffer->tracing_event_list); @@ -178,8 +179,8 @@ static iree_status_t iree_hal_hip_stream_command_buffer_begin( iree_hal_command_buffer_t* base_command_buffer) { iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); - IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(command_buffer->hip_symbols, - command_buffer->hip_context)); + IREE_RETURN_IF_ERROR(iree_hal_hip_set_context(command_buffer->hip_symbols, + command_buffer->hip_context)); IREE_HAL_STREAM_TRACE_ZONE_BEGIN_EXTERNAL( command_buffer->tracing_context, &command_buffer->tracing_event_list, @@ -196,8 +197,8 @@ static iree_status_t iree_hal_hip_stream_command_buffer_end( iree_hal_hip_stream_command_buffer_cast(base_command_buffer); IREE_TRACE_ZONE_BEGIN(z0); IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, HIP_SET_CONTEXT(command_buffer->hip_symbols, - command_buffer->hip_context)); + z0, iree_hal_hip_set_context(command_buffer->hip_symbols, + command_buffer->hip_context)); IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_hip_stream_command_buffer_flush_collectives(command_buffer)); @@ -229,8 +230,8 @@ static iree_status_t iree_hal_hip_stream_command_buffer_begin_debug_group( const iree_hal_label_location_t* location) { iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); - IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(command_buffer->hip_symbols, - command_buffer->hip_context)); + IREE_RETURN_IF_ERROR(iree_hal_hip_set_context(command_buffer->hip_symbols, + command_buffer->hip_context)); IREE_HAL_STREAM_TRACE_ZONE_BEGIN_EXTERNAL( command_buffer->tracing_context, 
&command_buffer->tracing_event_list, @@ -246,8 +247,8 @@ static iree_status_t iree_hal_hip_stream_command_buffer_end_debug_group( iree_hal_command_buffer_t* base_command_buffer) { iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); - IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(command_buffer->hip_symbols, - command_buffer->hip_context)); + IREE_RETURN_IF_ERROR(iree_hal_hip_set_context(command_buffer->hip_symbols, + command_buffer->hip_context)); IREE_HAL_STREAM_TRACE_ZONE_END(command_buffer->tracing_context, &command_buffer->tracing_event_list, @@ -267,8 +268,8 @@ static iree_status_t iree_hal_hip_stream_command_buffer_execution_barrier( const iree_hal_buffer_barrier_t* buffer_barriers) { iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); - IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(command_buffer->hip_symbols, - command_buffer->hip_context)); + IREE_RETURN_IF_ERROR(iree_hal_hip_set_context(command_buffer->hip_symbols, + command_buffer->hip_context)); if (iree_any_bit_set(source_stage_mask, IREE_HAL_EXECUTION_STAGE_HOST) || iree_any_bit_set(target_stage_mask, IREE_HAL_EXECUTION_STAGE_HOST)) { @@ -334,8 +335,8 @@ static iree_status_t iree_hal_hip_stream_command_buffer_fill_buffer( iree_hal_hip_stream_command_buffer_cast(base_command_buffer); IREE_TRACE_ZONE_BEGIN(z0); IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, HIP_SET_CONTEXT(command_buffer->hip_symbols, - command_buffer->hip_context)); + z0, iree_hal_hip_set_context(command_buffer->hip_symbols, + command_buffer->hip_context)); IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_hip_stream_command_buffer_flush_collectives(command_buffer)); @@ -390,8 +391,8 @@ static iree_status_t iree_hal_hip_stream_command_buffer_update_buffer( iree_hal_hip_stream_command_buffer_cast(base_command_buffer); IREE_TRACE_ZONE_BEGIN(z0); IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, HIP_SET_CONTEXT(command_buffer->hip_symbols, - 
command_buffer->hip_context)); + z0, iree_hal_hip_set_context(command_buffer->hip_symbols, + command_buffer->hip_context)); IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_hip_stream_command_buffer_flush_collectives(command_buffer)); @@ -436,8 +437,8 @@ static iree_status_t iree_hal_hip_stream_command_buffer_copy_buffer( iree_hal_hip_stream_command_buffer_cast(base_command_buffer); IREE_TRACE_ZONE_BEGIN(z0); IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, HIP_SET_CONTEXT(command_buffer->hip_symbols, - command_buffer->hip_context)); + z0, iree_hal_hip_set_context(command_buffer->hip_symbols, + command_buffer->hip_context)); IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_hip_stream_command_buffer_flush_collectives(command_buffer)); @@ -471,8 +472,8 @@ static iree_status_t iree_hal_hip_stream_command_buffer_collective( iree_hal_hip_stream_command_buffer_cast(base_command_buffer); IREE_TRACE_ZONE_BEGIN(z0); IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, HIP_SET_CONTEXT(command_buffer->hip_symbols, - command_buffer->hip_context)); + z0, iree_hal_hip_set_context(command_buffer->hip_symbols, + command_buffer->hip_context)); iree_hal_buffer_binding_t send_binding = { .buffer = send_ref.buffer, @@ -501,8 +502,8 @@ static iree_status_t iree_hal_hip_stream_command_buffer_dispatch( iree_hal_hip_stream_command_buffer_cast(base_command_buffer); IREE_TRACE_ZONE_BEGIN(z0); IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, HIP_SET_CONTEXT(command_buffer->hip_symbols, - command_buffer->hip_context)); + z0, iree_hal_hip_set_context(command_buffer->hip_symbols, + command_buffer->hip_context)); // If any of the workgroup counts are zero, we can skip execution // of the kernel. This prevents a 'hipErrorInvalidConfiguration' error when From f563521e8c3a72ddf1904744108bf00b697fd036 Mon Sep 17 00:00:00 2001 From: Andrew Woloszyn Date: Mon, 11 Nov 2024 15:48:50 -0500 Subject: [PATCH 4/5] Use a zone instead of a message so that we can count in the trace.
Signed-off-by: Andrew Woloszyn --- runtime/src/iree/hal/drivers/hip/context_util.h | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/runtime/src/iree/hal/drivers/hip/context_util.h b/runtime/src/iree/hal/drivers/hip/context_util.h index 9b184fd39a5c..099da8fec5a9 100644 --- a/runtime/src/iree/hal/drivers/hip/context_util.h +++ b/runtime/src/iree/hal/drivers/hip/context_util.h @@ -18,7 +18,11 @@ static inline iree_status_t iree_hal_hip_set_context( IREE_HIP_RETURN_IF_ERROR(syms, hipCtxGetCurrent(&current_context), "hipCtxGetCurrent"); if (current_context != hip_context) { - IREE_TRACE_MESSAGE(INFO, "Hip Context Switch"); + IREE_TRACE_ZONE_BEGIN_NAMED(z0, "iree_hal_hip_set_context_switch"); + iree_status_t status = + IREE_HIP_RESULT_TO_STATUS(syms, hipCtxSetCurrent(hip_context)); + IREE_TRACE_ZONE_END(z0); + return status; } }); return IREE_HIP_RESULT_TO_STATUS(syms, hipCtxSetCurrent(hip_context)); From 8a8ad2ebd23baa3299d5ff18053bf3263c2ffaeb Mon Sep 17 00:00:00 2001 From: Andrew Woloszyn Date: Tue, 12 Nov 2024 20:52:27 -0500 Subject: [PATCH 5/5] Only set the hip context if it is valid. Signed-off-by: Andrew Woloszyn --- runtime/src/iree/hal/drivers/hip/context_util.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/runtime/src/iree/hal/drivers/hip/context_util.h b/runtime/src/iree/hal/drivers/hip/context_util.h index 099da8fec5a9..1aa1d79b4c28 100644 --- a/runtime/src/iree/hal/drivers/hip/context_util.h +++ b/runtime/src/iree/hal/drivers/hip/context_util.h @@ -13,6 +13,9 @@ static inline iree_status_t iree_hal_hip_set_context( const iree_hal_hip_dynamic_symbols_t* syms, hipCtx_t hip_context) { + if (!hip_context) { + return iree_ok_status(); + } IREE_TRACE({ hipCtx_t current_context = NULL; IREE_HIP_RETURN_IF_ERROR(syms, hipCtxGetCurrent(&current_context),