From d3d656de3938afc0902c7b1c49ba4c99357362e1 Mon Sep 17 00:00:00 2001 From: Andrew Woloszyn Date: Mon, 11 Nov 2024 15:44:18 -0500 Subject: [PATCH] [hip] Log a message if we are going to do a context switch. --- .../src/iree/hal/drivers/hip/CMakeLists.txt | 1 + .../src/iree/hal/drivers/hip/context_util.h | 27 ++++++++++ .../hal/drivers/hip/dynamic_symbol_tables.h | 1 + runtime/src/iree/hal/drivers/hip/event_pool.c | 4 +- .../iree/hal/drivers/hip/event_semaphore.c | 6 ++- .../src/iree/hal/drivers/hip/hip_allocator.c | 15 +++--- runtime/src/iree/hal/drivers/hip/hip_device.c | 27 +++++----- .../src/iree/hal/drivers/hip/memory_pools.c | 22 ++++++--- .../iree/hal/drivers/hip/native_executable.c | 4 +- .../src/iree/hal/drivers/hip/status_util.h | 3 -- .../hal/drivers/hip/stream_command_buffer.c | 49 ++++++++++--------- 11 files changed, 100 insertions(+), 59 deletions(-) create mode 100644 runtime/src/iree/hal/drivers/hip/context_util.h diff --git a/runtime/src/iree/hal/drivers/hip/CMakeLists.txt b/runtime/src/iree/hal/drivers/hip/CMakeLists.txt index ef9576d8d66c5..7fd5abf072948 100644 --- a/runtime/src/iree/hal/drivers/hip/CMakeLists.txt +++ b/runtime/src/iree/hal/drivers/hip/CMakeLists.txt @@ -21,6 +21,7 @@ iree_cc_library( "api.h" SRCS "api.h" + "context_util.h" "event_pool.c" "event_pool.h" "event_semaphore.c" diff --git a/runtime/src/iree/hal/drivers/hip/context_util.h b/runtime/src/iree/hal/drivers/hip/context_util.h new file mode 100644 index 0000000000000..9b184fd39a5c3 --- /dev/null +++ b/runtime/src/iree/hal/drivers/hip/context_util.h @@ -0,0 +1,27 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_HAL_DRIVERS_HIP_CONTEXT_UTIL_H_ +#define IREE_HAL_DRIVERS_HIP_CONTEXT_UTIL_H_ + +#include "iree/base/api.h" +#include "iree/hal/drivers/hip/dynamic_symbols.h" +#include "iree/hal/drivers/hip/status_util.h" + +static inline iree_status_t iree_hal_hip_set_context( + const iree_hal_hip_dynamic_symbols_t* syms, hipCtx_t hip_context) { + IREE_TRACE({ + hipCtx_t current_context = NULL; + IREE_HIP_RETURN_IF_ERROR(syms, hipCtxGetCurrent(&current_context), + "hipCtxGetCurrent"); + if (current_context != hip_context) { + IREE_TRACE_MESSAGE(INFO, "Hip Context Switch"); + } + }); + return IREE_HIP_RESULT_TO_STATUS(syms, hipCtxSetCurrent(hip_context)); +} + +#endif // IREE_HAL_DRIVERS_HIP_CONTEXT_UTIL_H_ diff --git a/runtime/src/iree/hal/drivers/hip/dynamic_symbol_tables.h b/runtime/src/iree/hal/drivers/hip/dynamic_symbol_tables.h index 963784d81dcc1..85d33740c83da 100644 --- a/runtime/src/iree/hal/drivers/hip/dynamic_symbol_tables.h +++ b/runtime/src/iree/hal/drivers/hip/dynamic_symbol_tables.h @@ -8,6 +8,7 @@ // HIP symbols //===----------------------------------------------------------------------===// +IREE_HAL_HIP_REQUIRED_PFN_DECL(hipCtxGetCurrent, hipCtx_t *) IREE_HAL_HIP_REQUIRED_PFN_DECL(hipCtxSetCurrent, hipCtx_t) IREE_HAL_HIP_REQUIRED_PFN_DECL(hipDeviceGet, hipDevice_t *, int) IREE_HAL_HIP_REQUIRED_PFN_DECL(hipDeviceGetAttribute, int *, diff --git a/runtime/src/iree/hal/drivers/hip/event_pool.c b/runtime/src/iree/hal/drivers/hip/event_pool.c index bfecc9945c9d0..010cfd67fd2e8 100644 --- a/runtime/src/iree/hal/drivers/hip/event_pool.c +++ b/runtime/src/iree/hal/drivers/hip/event_pool.c @@ -14,6 +14,7 @@ #include "iree/base/internal/atomics.h" #include "iree/base/internal/synchronization.h" #include "iree/hal/api.h" +#include "iree/hal/drivers/hip/context_util.h" #include "iree/hal/drivers/hip/dynamic_symbols.h" #include "iree/hal/drivers/hip/status_util.h" @@ -52,7 +53,8 @@ static inline void 
iree_hal_hip_event_destroy(iree_hal_hip_event_t* event) { iree_allocator_t host_allocator = event->host_allocator; const iree_hal_hip_dynamic_symbols_t* symbols = event->symbols; IREE_TRACE_ZONE_BEGIN(z0); - IREE_IGNORE_ERROR(HIP_SET_CONTEXT(event->symbols, event->hip_context)); + IREE_IGNORE_ERROR( + iree_hal_hip_set_context(event->symbols, event->hip_context)); IREE_ASSERT_REF_COUNT_ZERO(&event->ref_count); IREE_HIP_IGNORE_ERROR(symbols, hipEventDestroy(event->hip_event)); diff --git a/runtime/src/iree/hal/drivers/hip/event_semaphore.c b/runtime/src/iree/hal/drivers/hip/event_semaphore.c index 7e7434729099e..686e23015b8f7 100644 --- a/runtime/src/iree/hal/drivers/hip/event_semaphore.c +++ b/runtime/src/iree/hal/drivers/hip/event_semaphore.c @@ -9,6 +9,7 @@ #include "iree/base/internal/synchronization.h" #include "iree/base/internal/wait_handle.h" #include "iree/base/status.h" +#include "iree/hal/drivers/hip/context_util.h" #include "iree/hal/drivers/hip/dynamic_symbols.h" #include "iree/hal/drivers/hip/status_util.h" #include "iree/hal/drivers/hip/timepoint_pool.h" @@ -68,7 +69,8 @@ iree_status_t iree_hal_hip_event_semaphore_create( IREE_ASSERT_ARGUMENT(out_semaphore); IREE_TRACE_ZONE_BEGIN(z0); - IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, HIP_SET_CONTEXT(symbols, hip_context)); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_hip_set_context(symbols, hip_context)); iree_hal_hip_semaphore_t* semaphore = NULL; IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_allocator_malloc(host_allocator, sizeof(*semaphore), @@ -98,7 +100,7 @@ static void iree_hal_hip_semaphore_destroy( iree_allocator_t host_allocator = semaphore->host_allocator; IREE_TRACE_ZONE_BEGIN(z0); IREE_IGNORE_ERROR( - HIP_SET_CONTEXT(semaphore->symbols, semaphore->hip_context)); + iree_hal_hip_set_context(semaphore->symbols, semaphore->hip_context)); iree_status_ignore(semaphore->failure_status); iree_slim_mutex_deinitialize(&semaphore->mutex); diff --git a/runtime/src/iree/hal/drivers/hip/hip_allocator.c 
b/runtime/src/iree/hal/drivers/hip/hip_allocator.c index 3bd0389ac4092..95a04cd1812b2 100644 --- a/runtime/src/iree/hal/drivers/hip/hip_allocator.c +++ b/runtime/src/iree/hal/drivers/hip/hip_allocator.c @@ -10,6 +10,7 @@ #include "iree/base/api.h" #include "iree/base/tracing.h" +#include "iree/hal/drivers/hip/context_util.h" #include "iree/hal/drivers/hip/dynamic_symbols.h" #include "iree/hal/drivers/hip/hip_buffer.h" #include "iree/hal/drivers/hip/status_util.h" @@ -62,8 +63,8 @@ iree_status_t iree_hal_hip_allocator_create( IREE_ASSERT_ARGUMENT(hip_symbols); IREE_ASSERT_ARGUMENT(out_allocator); IREE_TRACE_ZONE_BEGIN(z0); - IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, - HIP_SET_CONTEXT(hip_symbols, hip_context)); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_hip_set_context(hip_symbols, hip_context)); // To support device-local + host-visible memory we need concurrent managed // access indicating that the host and devices can concurrently access the @@ -359,7 +360,7 @@ static iree_status_t iree_hal_hip_allocator_allocate_buffer( hipDeviceptr_t device_ptr = NULL; IREE_TRACE_ZONE_BEGIN_NAMED(z0, "iree_hal_hip_buffer_allocate"); IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, HIP_SET_CONTEXT(allocator->symbols, allocator->hip_context)); + z0, iree_hal_hip_set_context(allocator->symbols, allocator->hip_context)); IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, allocation_size); if (iree_all_bits_set(compat_params.type, @@ -441,7 +442,7 @@ static void iree_hal_hip_allocator_deallocate_buffer( iree_hal_hip_allocator_cast(base_allocator); IREE_IGNORE_ERROR( - HIP_SET_CONTEXT(allocator->symbols, allocator->hip_context)); + iree_hal_hip_set_context(allocator->symbols, allocator->hip_context)); const iree_hal_hip_buffer_type_t buffer_type = iree_hal_hip_buffer_type(base_buffer); @@ -479,7 +480,7 @@ static iree_status_t iree_hal_hip_allocator_import_buffer( iree_hal_hip_allocator_cast(base_allocator); IREE_RETURN_IF_ERROR( - HIP_SET_CONTEXT(allocator->symbols, allocator->hip_context)); + 
iree_hal_hip_set_context(allocator->symbols, allocator->hip_context)); // Coerce options into those required by the current device. iree_hal_buffer_params_t compat_params = *params; @@ -616,7 +617,7 @@ iree_status_t iree_hal_hip_allocator_alloc_async( iree_hal_hip_allocator_cast(base_allocator); IREE_RETURN_IF_ERROR( - HIP_SET_CONTEXT(allocator->symbols, allocator->hip_context)); + iree_hal_hip_set_context(allocator->symbols, allocator->hip_context)); hipDeviceptr_t ptr = NULL; iree_status_t status = IREE_HIP_RESULT_TO_STATUS( @@ -644,7 +645,7 @@ iree_status_t iree_hal_hip_allocator_free_async( iree_hal_hip_allocator_t* allocator = iree_hal_hip_allocator_cast(base_allocator); IREE_RETURN_IF_ERROR( - HIP_SET_CONTEXT(allocator->symbols, allocator->hip_context)); + iree_hal_hip_set_context(allocator->symbols, allocator->hip_context)); hipDeviceptr_t device_ptr = iree_hal_hip_buffer_device_pointer(buffer); if (!device_ptr) { diff --git a/runtime/src/iree/hal/drivers/hip/hip_device.c b/runtime/src/iree/hal/drivers/hip/hip_device.c index a1ebf0c7478aa..bc001d645c3a3 100644 --- a/runtime/src/iree/hal/drivers/hip/hip_device.c +++ b/runtime/src/iree/hal/drivers/hip/hip_device.c @@ -14,6 +14,7 @@ #include "iree/base/internal/event_pool.h" #include "iree/base/internal/math.h" #include "iree/base/tracing.h" +#include "iree/hal/drivers/hip/context_util.h" #include "iree/hal/drivers/hip/dynamic_symbols.h" #include "iree/hal/drivers/hip/event_pool.h" #include "iree/hal/drivers/hip/event_semaphore.h" @@ -732,7 +733,7 @@ static void iree_hal_hip_replace_channel_provider( static iree_status_t iree_hal_hip_device_trim(iree_hal_device_t* base_device) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); IREE_RETURN_IF_ERROR( - HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); + iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); iree_arena_block_pool_trim(&device->block_pool); 
IREE_RETURN_IF_ERROR(iree_hal_allocator_trim(device->device_allocator)); if (device->supports_memory_pools) { @@ -746,7 +747,7 @@ static iree_status_t iree_hal_hip_device_query_attribute( iree_hal_hip_device_t* device, hipDeviceAttribute_t attribute, int64_t* out_value) { IREE_RETURN_IF_ERROR( - HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); + iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); int value = 0; IREE_HIP_RETURN_IF_ERROR( device->hip_symbols, @@ -761,7 +762,7 @@ static iree_status_t iree_hal_hip_device_query_i64( iree_string_view_t key, int64_t* out_value) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); IREE_RETURN_IF_ERROR( - HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); + iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); *out_value = 0; if (iree_string_view_equal(category, IREE_SV("hal.device.id"))) { @@ -786,7 +787,7 @@ static iree_status_t iree_hal_hip_device_create_channel( iree_hal_channel_params_t params, iree_hal_channel_t** out_channel) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); IREE_RETURN_IF_ERROR( - HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); + iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); if (!device->nccl_symbols || !device->nccl_symbols->dylib) { return iree_make_status( @@ -875,7 +876,7 @@ iree_status_t iree_hal_hip_device_create_stream_command_buffer( iree_hal_command_buffer_t** out_command_buffer) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); IREE_RETURN_IF_ERROR( - HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); + iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); return iree_hal_hip_stream_command_buffer_create( iree_hal_device_allocator(base_device), device->hip_symbols, @@ -891,7 +892,7 @@ static iree_status_t iree_hal_hip_device_create_command_buffer( iree_hal_command_buffer_t** out_command_buffer) { 
iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); IREE_RETURN_IF_ERROR( - HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); + iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); if (device->params.allow_inline_execution && iree_all_bits_set(mode, @@ -947,7 +948,7 @@ static iree_status_t iree_hal_hip_device_import_file( iree_hal_external_file_flags_t flags, iree_hal_file_t** out_file) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); IREE_RETURN_IF_ERROR( - HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); + iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); if (iree_io_file_handle_type(handle) != IREE_IO_FILE_HANDLE_TYPE_HOST_ALLOCATION) { @@ -965,7 +966,7 @@ static iree_status_t iree_hal_hip_device_create_executable_cache( iree_loop_t loop, iree_hal_executable_cache_t** out_executable_cache) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); IREE_RETURN_IF_ERROR( - HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); + iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); return iree_hal_hip_nop_executable_cache_create( identifier, device->hip_symbols, device->hip_device, device->hip_context, device->host_allocator, out_executable_cache); @@ -976,7 +977,7 @@ static iree_status_t iree_hal_hip_device_create_semaphore( iree_hal_semaphore_flags_t flags, iree_hal_semaphore_t** out_semaphore) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); IREE_RETURN_IF_ERROR( - HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); + iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); return iree_hal_hip_event_semaphore_create( initial_value, device->hip_symbols, device->hip_context, @@ -1032,7 +1033,7 @@ static iree_status_t iree_hal_hip_device_queue_alloca( iree_hal_buffer_t** IREE_RESTRICT out_buffer) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); IREE_RETURN_IF_ERROR( 
- HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); + iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); if (device->supports_memory_pools && !iree_all_bits_set(params.type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) { @@ -1111,7 +1112,7 @@ static iree_status_t iree_hal_hip_device_queue_dealloca( iree_hal_buffer_t* buffer) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); IREE_RETURN_IF_ERROR( - HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); + iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); if (iree_hal_hip_allocator_isa(iree_hal_device_allocator(base_device))) { iree_status_t status = iree_hal_deferred_work_queue_enqueue_dealloc( @@ -1202,7 +1203,7 @@ static iree_status_t iree_hal_hip_device_queue_execute( IREE_TRACE_ZONE_BEGIN(z0); IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); + z0, iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); iree_status_t status = iree_hal_deferred_work_queue_enqueue( device->work_queue, iree_hal_hip_device_collect_tracing_context, @@ -1233,7 +1234,7 @@ static iree_status_t iree_hal_hip_device_wait_semaphores( const iree_hal_semaphore_list_t semaphore_list, iree_timeout_t timeout) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); IREE_RETURN_IF_ERROR( - HIP_SET_CONTEXT(device->hip_symbols, device->hip_context)); + iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); return iree_hal_hip_semaphore_multi_wait(semaphore_list, wait_mode, timeout, &device->block_pool); diff --git a/runtime/src/iree/hal/drivers/hip/memory_pools.c b/runtime/src/iree/hal/drivers/hip/memory_pools.c index 0a912c2daddd5..93a046b1f59ec 100644 --- a/runtime/src/iree/hal/drivers/hip/memory_pools.c +++ b/runtime/src/iree/hal/drivers/hip/memory_pools.c @@ -6,6 +6,7 @@ #include "iree/hal/drivers/hip/memory_pools.h" +#include "iree/hal/drivers/hip/context_util.h" #include 
"iree/hal/drivers/hip/dynamic_symbols.h" #include "iree/hal/drivers/hip/hip_buffer.h" #include "iree/hal/drivers/hip/status_util.h" @@ -67,8 +68,8 @@ iree_status_t iree_hal_hip_memory_pools_initialize( IREE_ASSERT_ARGUMENT(pooling_params); IREE_ASSERT_ARGUMENT(out_pools); IREE_TRACE_ZONE_BEGIN(z0); - IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, - HIP_SET_CONTEXT(hip_symbols, hip_context)); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_hip_set_context(hip_symbols, hip_context)); memset(out_pools, 0, sizeof(*out_pools)); out_pools->hip_symbols = hip_symbols; @@ -95,7 +96,8 @@ iree_status_t iree_hal_hip_memory_pools_initialize( void iree_hal_hip_memory_pools_deinitialize( iree_hal_hip_memory_pools_t* pools) { IREE_TRACE_ZONE_BEGIN(z0); - IREE_IGNORE_ERROR(HIP_SET_CONTEXT(pools->hip_symbols, pools->hip_context)); + IREE_IGNORE_ERROR( + iree_hal_hip_set_context(pools->hip_symbols, pools->hip_context)); if (pools->device_local) { IREE_HIP_IGNORE_ERROR(pools->hip_symbols, @@ -154,7 +156,8 @@ static void iree_hal_hip_memory_pool_track_free( void iree_hal_hip_memory_pools_merge_statistics( iree_hal_hip_memory_pools_t* pools, iree_hal_allocator_statistics_t* statistics) { - IREE_IGNORE_ERROR(HIP_SET_CONTEXT(pools->hip_symbols, pools->hip_context)); + IREE_IGNORE_ERROR( + iree_hal_hip_set_context(pools->hip_symbols, pools->hip_context)); IREE_STATISTICS({ statistics->device_bytes_allocated = iree_atomic_load( @@ -188,7 +191,8 @@ void iree_hal_hip_memory_pools_merge_statistics( iree_status_t iree_hal_hip_memory_pools_trim( iree_hal_hip_memory_pools_t* pools, const iree_hal_hip_memory_pooling_params_t* pooling_params) { - IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(pools->hip_symbols, pools->hip_context)); + IREE_RETURN_IF_ERROR( + iree_hal_hip_set_context(pools->hip_symbols, pools->hip_context)); IREE_HIP_RETURN_IF_ERROR( pools->hip_symbols, @@ -209,7 +213,8 @@ static void iree_hal_hip_async_buffer_release_callback( void* user_data, iree_hal_buffer_t* buffer) { 
iree_hal_hip_memory_pools_t* pools = (iree_hal_hip_memory_pools_t*)user_data; IREE_TRACE_ZONE_BEGIN(z0); - IREE_IGNORE_ERROR(HIP_SET_CONTEXT(pools->hip_symbols, pools->hip_context)); + IREE_IGNORE_ERROR( + iree_hal_hip_set_context(pools->hip_symbols, pools->hip_context)); hipDeviceptr_t device_ptr = iree_hal_hip_buffer_device_pointer(buffer); if (device_ptr) { @@ -223,7 +228,8 @@ static void iree_hal_hip_async_buffer_release_callback( iree_status_t iree_hal_hip_memory_pools_allocate_pointer( iree_hal_hip_memory_pools_t* pools, iree_hal_buffer_t* buffer, hipStream_t stream, iree_device_size_t allocation_size) { - IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(pools->hip_symbols, pools->hip_context)); + IREE_RETURN_IF_ERROR( + iree_hal_hip_set_context(pools->hip_symbols, pools->hip_context)); // TODO: more pools and better selection; this is coarsely deciding between // only device local (variables, constants, transients) and other (staging, @@ -290,7 +296,7 @@ iree_status_t iree_hal_hip_memory_pools_deallocate( iree_hal_buffer_t* buffer) { IREE_TRACE_ZONE_BEGIN(z0); IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, HIP_SET_CONTEXT(pools->hip_symbols, pools->hip_context)); + z0, iree_hal_hip_set_context(pools->hip_symbols, pools->hip_context)); IREE_TRACE_ZONE_APPEND_VALUE_I64( z0, (int64_t)iree_hal_buffer_allocation_size(buffer)); diff --git a/runtime/src/iree/hal/drivers/hip/native_executable.c b/runtime/src/iree/hal/drivers/hip/native_executable.c index 0a11a01625537..e37aba578f02c 100644 --- a/runtime/src/iree/hal/drivers/hip/native_executable.c +++ b/runtime/src/iree/hal/drivers/hip/native_executable.c @@ -9,6 +9,7 @@ #include #include "iree/base/api.h" +#include "iree/hal/drivers/hip/context_util.h" #include "iree/hal/drivers/hip/dynamic_symbols.h" #include "iree/hal/drivers/hip/status_util.h" #include "iree/hal/utils/executable_debug_info.h" @@ -213,7 +214,8 @@ iree_status_t iree_hal_hip_native_executable_create( IREE_ASSERT_ARGUMENT(executable_params); 
IREE_ASSERT_ARGUMENT(out_executable); IREE_TRACE_ZONE_BEGIN(z0); - IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, HIP_SET_CONTEXT(symbols, context)); + IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, + iree_hal_hip_set_context(symbols, context)); *out_executable = NULL; diff --git a/runtime/src/iree/hal/drivers/hip/status_util.h b/runtime/src/iree/hal/drivers/hip/status_util.h index eaac3067c2936..221f55fe02142 100644 --- a/runtime/src/iree/hal/drivers/hip/status_util.h +++ b/runtime/src/iree/hal/drivers/hip/status_util.h @@ -56,9 +56,6 @@ extern "C" { IREE_IGNORE_ERROR(iree_hal_hip_result_to_status((syms), ((syms)->expr), \ __FILE__, __LINE__)) -#define HIP_SET_CONTEXT(syms, ctx) \ - IREE_HIP_RESULT_TO_STATUS(syms, hipCtxSetCurrent(ctx)) - // Converts a hipError_t to an iree_status_t object. iree_status_t iree_hal_hip_result_to_status( const iree_hal_hip_dynamic_symbols_t* syms, hipError_t result, diff --git a/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c b/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c index c151a7e9d1014..6a201ca2976cb 100644 --- a/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c +++ b/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c @@ -7,6 +7,7 @@ #include "iree/hal/drivers/hip/stream_command_buffer.h" +#include "iree/hal/drivers/hip/context_util.h" #include "iree/hal/drivers/hip/hip_buffer.h" #include "iree/hal/drivers/hip/native_executable.h" #include "iree/hal/drivers/hip/rccl_channel.h" @@ -74,8 +75,8 @@ iree_status_t iree_hal_hip_stream_command_buffer_create( } IREE_TRACE_ZONE_BEGIN(z0); - IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, - HIP_SET_CONTEXT(hip_symbols, hip_context)); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_hip_set_context(hip_symbols, hip_context)); iree_hal_hip_stream_command_buffer_t* command_buffer = NULL; IREE_RETURN_AND_END_ZONE_IF_ERROR( @@ -120,8 +121,8 @@ static void iree_hal_hip_stream_command_buffer_destroy( iree_hal_hip_stream_command_buffer_cast(base_command_buffer); iree_allocator_t 
host_allocator = command_buffer->host_allocator; IREE_TRACE_ZONE_BEGIN(z0); - IREE_IGNORE_ERROR(HIP_SET_CONTEXT(command_buffer->hip_symbols, - command_buffer->hip_context)); + IREE_IGNORE_ERROR(iree_hal_hip_set_context(command_buffer->hip_symbols, + command_buffer->hip_context)); iree_hal_stream_tracing_free(command_buffer->tracing_context, &command_buffer->tracing_event_list); @@ -178,8 +179,8 @@ static iree_status_t iree_hal_hip_stream_command_buffer_begin( iree_hal_command_buffer_t* base_command_buffer) { iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); - IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(command_buffer->hip_symbols, - command_buffer->hip_context)); + IREE_RETURN_IF_ERROR(iree_hal_hip_set_context(command_buffer->hip_symbols, + command_buffer->hip_context)); IREE_HAL_STREAM_TRACE_ZONE_BEGIN_EXTERNAL( command_buffer->tracing_context, &command_buffer->tracing_event_list, @@ -196,8 +197,8 @@ static iree_status_t iree_hal_hip_stream_command_buffer_end( iree_hal_hip_stream_command_buffer_cast(base_command_buffer); IREE_TRACE_ZONE_BEGIN(z0); IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, HIP_SET_CONTEXT(command_buffer->hip_symbols, - command_buffer->hip_context)); + z0, iree_hal_hip_set_context(command_buffer->hip_symbols, + command_buffer->hip_context)); IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_hip_stream_command_buffer_flush_collectives(command_buffer)); @@ -229,8 +230,8 @@ static iree_status_t iree_hal_hip_stream_command_buffer_begin_debug_group( const iree_hal_label_location_t* location) { iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); - IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(command_buffer->hip_symbols, - command_buffer->hip_context)); + IREE_RETURN_IF_ERROR(iree_hal_hip_set_context(command_buffer->hip_symbols, + command_buffer->hip_context)); IREE_HAL_STREAM_TRACE_ZONE_BEGIN_EXTERNAL( command_buffer->tracing_context, 
&command_buffer->tracing_event_list, @@ -246,8 +247,8 @@ static iree_status_t iree_hal_hip_stream_command_buffer_end_debug_group( iree_hal_command_buffer_t* base_command_buffer) { iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); - IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(command_buffer->hip_symbols, - command_buffer->hip_context)); + IREE_RETURN_IF_ERROR(iree_hal_hip_set_context(command_buffer->hip_symbols, + command_buffer->hip_context)); IREE_HAL_STREAM_TRACE_ZONE_END(command_buffer->tracing_context, &command_buffer->tracing_event_list, @@ -267,8 +268,8 @@ static iree_status_t iree_hal_hip_stream_command_buffer_execution_barrier( const iree_hal_buffer_barrier_t* buffer_barriers) { iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); - IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(command_buffer->hip_symbols, - command_buffer->hip_context)); + IREE_RETURN_IF_ERROR(iree_hal_hip_set_context(command_buffer->hip_symbols, + command_buffer->hip_context)); if (iree_any_bit_set(source_stage_mask, IREE_HAL_EXECUTION_STAGE_HOST) || iree_any_bit_set(target_stage_mask, IREE_HAL_EXECUTION_STAGE_HOST)) { @@ -334,8 +335,8 @@ static iree_status_t iree_hal_hip_stream_command_buffer_fill_buffer( iree_hal_hip_stream_command_buffer_cast(base_command_buffer); IREE_TRACE_ZONE_BEGIN(z0); IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, HIP_SET_CONTEXT(command_buffer->hip_symbols, - command_buffer->hip_context)); + z0, iree_hal_hip_set_context(command_buffer->hip_symbols, + command_buffer->hip_context)); IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_hip_stream_command_buffer_flush_collectives(command_buffer)); @@ -390,8 +391,8 @@ static iree_status_t iree_hal_hip_stream_command_buffer_update_buffer( iree_hal_hip_stream_command_buffer_cast(base_command_buffer); IREE_TRACE_ZONE_BEGIN(z0); IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, HIP_SET_CONTEXT(command_buffer->hip_symbols, - 
command_buffer->hip_context)); + z0, iree_hal_hip_set_context(command_buffer->hip_symbols, + command_buffer->hip_context)); IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_hip_stream_command_buffer_flush_collectives(command_buffer)); @@ -436,8 +437,8 @@ static iree_status_t iree_hal_hip_stream_command_buffer_copy_buffer( iree_hal_hip_stream_command_buffer_cast(base_command_buffer); IREE_TRACE_ZONE_BEGIN(z0); IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, HIP_SET_CONTEXT(command_buffer->hip_symbols, - command_buffer->hip_context)); + z0, iree_hal_hip_set_context(command_buffer->hip_symbols, + command_buffer->hip_context)); IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_hip_stream_command_buffer_flush_collectives(command_buffer)); @@ -471,8 +472,8 @@ static iree_status_t iree_hal_hip_stream_command_buffer_collective( iree_hal_hip_stream_command_buffer_cast(base_command_buffer); IREE_TRACE_ZONE_BEGIN(z0); IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, HIP_SET_CONTEXT(command_buffer->hip_symbols, - command_buffer->hip_context)); + z0, iree_hal_hip_set_context(command_buffer->hip_symbols, + command_buffer->hip_context)); iree_hal_buffer_binding_t send_binding = { .buffer = send_ref.buffer, @@ -501,8 +502,8 @@ static iree_status_t iree_hal_hip_stream_command_buffer_dispatch( iree_hal_hip_stream_command_buffer_cast(base_command_buffer); IREE_TRACE_ZONE_BEGIN(z0); IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, HIP_SET_CONTEXT(command_buffer->hip_symbols, - command_buffer->hip_context)); + z0, iree_hal_hip_set_context(command_buffer->hip_symbols, + command_buffer->hip_context)); // If any of the workgroup counts are zero, we can skip execution // of the kernel. This prevents a 'hipErrorInvalidConfiguration' error when