Skip to content

Commit

Permalink
Move our context switches inside any zones that need them.
Browse files Browse the repository at this point in the history
Signed-off-by: Andrew Woloszyn <[email protected]>
  • Loading branch information
AWoloszyn committed Nov 12, 2024
1 parent e50e07f commit d83fee6
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 44 deletions.
3 changes: 1 addition & 2 deletions runtime/src/iree/hal/drivers/hip/event_pool.c
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,10 @@ hipEvent_t iree_hal_hip_event_handle(const iree_hal_hip_event_t* event) {
}

static inline void iree_hal_hip_event_destroy(iree_hal_hip_event_t* event) {
IREE_IGNORE_ERROR(HIP_SET_CONTEXT(event->symbols, event->hip_context));

iree_allocator_t host_allocator = event->host_allocator;
const iree_hal_hip_dynamic_symbols_t* symbols = event->symbols;
IREE_TRACE_ZONE_BEGIN(z0);
IREE_IGNORE_ERROR(HIP_SET_CONTEXT(event->symbols, event->hip_context));

IREE_ASSERT_REF_COUNT_ZERO(&event->ref_count);
IREE_HIP_IGNORE_ERROR(symbols, hipEventDestroy(event->hip_event));
Expand Down
7 changes: 3 additions & 4 deletions runtime/src/iree/hal/drivers/hip/event_semaphore.c
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,13 @@ iree_status_t iree_hal_hip_event_semaphore_create(
hipCtx_t hip_context, iree_hal_hip_timepoint_pool_t* timepoint_pool,
iree_hal_deferred_work_queue_t* work_queue, iree_allocator_t host_allocator,
iree_hal_semaphore_t** out_semaphore) {
IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(symbols, hip_context));

IREE_ASSERT_ARGUMENT(symbols);
IREE_ASSERT_ARGUMENT(timepoint_pool);
IREE_ASSERT_ARGUMENT(work_queue);
IREE_ASSERT_ARGUMENT(out_semaphore);
IREE_TRACE_ZONE_BEGIN(z0);

IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, HIP_SET_CONTEXT(symbols, hip_context));
iree_hal_hip_semaphore_t* semaphore = NULL;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_allocator_malloc(host_allocator, sizeof(*semaphore),
Expand All @@ -96,10 +95,10 @@ static void iree_hal_hip_semaphore_destroy(
iree_hal_semaphore_t* base_semaphore) {
iree_hal_hip_semaphore_t* semaphore =
iree_hal_hip_semaphore_cast(base_semaphore);
IREE_IGNORE_ERROR(
HIP_SET_CONTEXT(semaphore->symbols, semaphore->hip_context));
iree_allocator_t host_allocator = semaphore->host_allocator;
IREE_TRACE_ZONE_BEGIN(z0);
IREE_IGNORE_ERROR(
HIP_SET_CONTEXT(semaphore->symbols, semaphore->hip_context));

iree_status_ignore(semaphore->failure_status);
iree_slim_mutex_deinitialize(&semaphore->mutex);
Expand Down
10 changes: 5 additions & 5 deletions runtime/src/iree/hal/drivers/hip/hip_allocator.c
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,11 @@ iree_status_t iree_hal_hip_allocator_create(
hipCtx_t hip_context, hipStream_t stream,
iree_hal_hip_memory_pools_t* pools, iree_allocator_t host_allocator,
iree_hal_allocator_t** out_allocator) {
IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(hip_symbols, hip_context));

IREE_ASSERT_ARGUMENT(hip_symbols);
IREE_ASSERT_ARGUMENT(out_allocator);
IREE_TRACE_ZONE_BEGIN(z0);
IREE_RETURN_AND_END_ZONE_IF_ERROR(z0,
HIP_SET_CONTEXT(hip_symbols, hip_context));

// To support device-local + host-visible memory we need concurrent managed
// access indicating that the host and devices can concurrently access the
Expand Down Expand Up @@ -325,9 +325,6 @@ static iree_status_t iree_hal_hip_allocator_allocate_buffer(
iree_hal_hip_allocator_t* allocator =
iree_hal_hip_allocator_cast(base_allocator);

IREE_RETURN_IF_ERROR(
HIP_SET_CONTEXT(allocator->symbols, allocator->hip_context));

// Coerce options into those required by the current device.
iree_hal_buffer_params_t compat_params = *params;
iree_hal_buffer_compatibility_t compatibility =
Expand Down Expand Up @@ -361,6 +358,9 @@ static iree_status_t iree_hal_hip_allocator_allocate_buffer(
void* host_ptr = NULL;
hipDeviceptr_t device_ptr = NULL;
IREE_TRACE_ZONE_BEGIN_NAMED(z0, "iree_hal_hip_buffer_allocate");
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, HIP_SET_CONTEXT(allocator->symbols, allocator->hip_context));

IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, allocation_size);
if (iree_all_bits_set(compat_params.type,
IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL)) {
Expand Down
4 changes: 2 additions & 2 deletions runtime/src/iree/hal/drivers/hip/hip_device.c
Original file line number Diff line number Diff line change
Expand Up @@ -1199,10 +1199,10 @@ static iree_status_t iree_hal_hip_device_queue_execute(
iree_hal_command_buffer_t* command_buffer,
iree_hal_buffer_binding_table_t binding_table) {
iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device);
IREE_RETURN_IF_ERROR(
HIP_SET_CONTEXT(device->hip_symbols, device->hip_context));

IREE_TRACE_ZONE_BEGIN(z0);
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, HIP_SET_CONTEXT(device->hip_symbols, device->hip_context));

iree_status_t status = iree_hal_deferred_work_queue_enqueue(
device->work_queue, iree_hal_hip_device_collect_tracing_context,
Expand Down
14 changes: 6 additions & 8 deletions runtime/src/iree/hal/drivers/hip/memory_pools.c
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,12 @@ iree_status_t iree_hal_hip_memory_pools_initialize(
const iree_hal_hip_memory_pooling_params_t* pooling_params,
iree_allocator_t host_allocator,
iree_hal_hip_memory_pools_t* IREE_RESTRICT out_pools) {
IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(hip_symbols, hip_context));

IREE_ASSERT_ARGUMENT(hip_symbols);
IREE_ASSERT_ARGUMENT(pooling_params);
IREE_ASSERT_ARGUMENT(out_pools);
IREE_TRACE_ZONE_BEGIN(z0);
IREE_RETURN_AND_END_ZONE_IF_ERROR(z0,
HIP_SET_CONTEXT(hip_symbols, hip_context));

memset(out_pools, 0, sizeof(*out_pools));
out_pools->hip_symbols = hip_symbols;
Expand All @@ -94,9 +94,8 @@ iree_status_t iree_hal_hip_memory_pools_initialize(

void iree_hal_hip_memory_pools_deinitialize(
iree_hal_hip_memory_pools_t* pools) {
IREE_IGNORE_ERROR(HIP_SET_CONTEXT(pools->hip_symbols, pools->hip_context));

IREE_TRACE_ZONE_BEGIN(z0);
IREE_IGNORE_ERROR(HIP_SET_CONTEXT(pools->hip_symbols, pools->hip_context));

if (pools->device_local) {
IREE_HIP_IGNORE_ERROR(pools->hip_symbols,
Expand Down Expand Up @@ -209,9 +208,8 @@ iree_status_t iree_hal_hip_memory_pools_trim(
static void iree_hal_hip_async_buffer_release_callback(
void* user_data, iree_hal_buffer_t* buffer) {
iree_hal_hip_memory_pools_t* pools = (iree_hal_hip_memory_pools_t*)user_data;
IREE_IGNORE_ERROR(HIP_SET_CONTEXT(pools->hip_symbols, pools->hip_context));

IREE_TRACE_ZONE_BEGIN(z0);
IREE_IGNORE_ERROR(HIP_SET_CONTEXT(pools->hip_symbols, pools->hip_context));

hipDeviceptr_t device_ptr = iree_hal_hip_buffer_device_pointer(buffer);
if (device_ptr) {
Expand Down Expand Up @@ -290,9 +288,9 @@ iree_status_t iree_hal_hip_memory_pools_prepare_buffer(
iree_status_t iree_hal_hip_memory_pools_deallocate(
iree_hal_hip_memory_pools_t* pools, hipStream_t stream,
iree_hal_buffer_t* buffer) {
IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(pools->hip_symbols, pools->hip_context));

IREE_TRACE_ZONE_BEGIN(z0);
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, HIP_SET_CONTEXT(pools->hip_symbols, pools->hip_context));
IREE_TRACE_ZONE_APPEND_VALUE_I64(
z0, (int64_t)iree_hal_buffer_allocation_size(buffer));

Expand Down
3 changes: 1 addition & 2 deletions runtime/src/iree/hal/drivers/hip/native_executable.c
Original file line number Diff line number Diff line change
Expand Up @@ -209,12 +209,11 @@ iree_status_t iree_hal_hip_native_executable_create(
const iree_hal_hip_dynamic_symbols_t* symbols, hipDevice_t device,
hipCtx_t context, const iree_hal_executable_params_t* executable_params,
iree_allocator_t host_allocator, iree_hal_executable_t** out_executable) {
IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(symbols, context));

IREE_ASSERT_ARGUMENT(symbols);
IREE_ASSERT_ARGUMENT(executable_params);
IREE_ASSERT_ARGUMENT(out_executable);
IREE_TRACE_ZONE_BEGIN(z0);
IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, HIP_SET_CONTEXT(symbols, context));

*out_executable = NULL;

Expand Down
43 changes: 22 additions & 21 deletions runtime/src/iree/hal/drivers/hip/stream_command_buffer.c
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,6 @@ iree_status_t iree_hal_hip_stream_command_buffer_create(
iree_host_size_t binding_capacity, hipStream_t stream,
iree_arena_block_pool_t* block_pool, iree_allocator_t host_allocator,
iree_hal_command_buffer_t** out_command_buffer) {
IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(hip_symbols, hip_context));

IREE_ASSERT_ARGUMENT(device_allocator);
IREE_ASSERT_ARGUMENT(hip_symbols);
IREE_ASSERT_ARGUMENT(nccl_symbols);
Expand All @@ -76,6 +74,8 @@ iree_status_t iree_hal_hip_stream_command_buffer_create(
}

IREE_TRACE_ZONE_BEGIN(z0);
IREE_RETURN_AND_END_ZONE_IF_ERROR(z0,
HIP_SET_CONTEXT(hip_symbols, hip_context));

iree_hal_hip_stream_command_buffer_t* command_buffer = NULL;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
Expand Down Expand Up @@ -118,10 +118,10 @@ static void iree_hal_hip_stream_command_buffer_destroy(
iree_hal_command_buffer_t* base_command_buffer) {
iree_hal_hip_stream_command_buffer_t* command_buffer =
iree_hal_hip_stream_command_buffer_cast(base_command_buffer);
IREE_IGNORE_ERROR(HIP_SET_CONTEXT(command_buffer->hip_symbols,
command_buffer->hip_context));
iree_allocator_t host_allocator = command_buffer->host_allocator;
IREE_TRACE_ZONE_BEGIN(z0);
IREE_IGNORE_ERROR(HIP_SET_CONTEXT(command_buffer->hip_symbols,
command_buffer->hip_context));

iree_hal_stream_tracing_free(command_buffer->tracing_context,
&command_buffer->tracing_event_list);
Expand Down Expand Up @@ -194,9 +194,10 @@ static iree_status_t iree_hal_hip_stream_command_buffer_end(
iree_hal_command_buffer_t* base_command_buffer) {
iree_hal_hip_stream_command_buffer_t* command_buffer =
iree_hal_hip_stream_command_buffer_cast(base_command_buffer);
IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(command_buffer->hip_symbols,
command_buffer->hip_context));
IREE_TRACE_ZONE_BEGIN(z0);
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, HIP_SET_CONTEXT(command_buffer->hip_symbols,
command_buffer->hip_context));

IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_hip_stream_command_buffer_flush_collectives(command_buffer));
Expand Down Expand Up @@ -331,10 +332,10 @@ static iree_status_t iree_hal_hip_stream_command_buffer_fill_buffer(
iree_host_size_t pattern_length, iree_hal_fill_flags_t flags) {
iree_hal_hip_stream_command_buffer_t* command_buffer =
iree_hal_hip_stream_command_buffer_cast(base_command_buffer);
IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(command_buffer->hip_symbols,
command_buffer->hip_context));

IREE_TRACE_ZONE_BEGIN(z0);
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, HIP_SET_CONTEXT(command_buffer->hip_symbols,
command_buffer->hip_context));

IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_hip_stream_command_buffer_flush_collectives(command_buffer));
Expand Down Expand Up @@ -387,10 +388,10 @@ static iree_status_t iree_hal_hip_stream_command_buffer_update_buffer(
iree_hal_update_flags_t flags) {
iree_hal_hip_stream_command_buffer_t* command_buffer =
iree_hal_hip_stream_command_buffer_cast(base_command_buffer);
IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(command_buffer->hip_symbols,
command_buffer->hip_context));

IREE_TRACE_ZONE_BEGIN(z0);
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, HIP_SET_CONTEXT(command_buffer->hip_symbols,
command_buffer->hip_context));

IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_hip_stream_command_buffer_flush_collectives(command_buffer));
Expand Down Expand Up @@ -433,10 +434,10 @@ static iree_status_t iree_hal_hip_stream_command_buffer_copy_buffer(
iree_hal_copy_flags_t flags) {
iree_hal_hip_stream_command_buffer_t* command_buffer =
iree_hal_hip_stream_command_buffer_cast(base_command_buffer);
IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(command_buffer->hip_symbols,
command_buffer->hip_context));

IREE_TRACE_ZONE_BEGIN(z0);
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, HIP_SET_CONTEXT(command_buffer->hip_symbols,
command_buffer->hip_context));

IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_hip_stream_command_buffer_flush_collectives(command_buffer));
Expand Down Expand Up @@ -468,10 +469,10 @@ static iree_status_t iree_hal_hip_stream_command_buffer_collective(
iree_hal_buffer_ref_t recv_ref, iree_device_size_t element_count) {
iree_hal_hip_stream_command_buffer_t* command_buffer =
iree_hal_hip_stream_command_buffer_cast(base_command_buffer);
IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(command_buffer->hip_symbols,
command_buffer->hip_context));

IREE_TRACE_ZONE_BEGIN(z0);
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, HIP_SET_CONTEXT(command_buffer->hip_symbols,
command_buffer->hip_context));

iree_hal_buffer_binding_t send_binding = {
.buffer = send_ref.buffer,
Expand All @@ -498,10 +499,10 @@ static iree_status_t iree_hal_hip_stream_command_buffer_dispatch(
iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) {
iree_hal_hip_stream_command_buffer_t* command_buffer =
iree_hal_hip_stream_command_buffer_cast(base_command_buffer);
IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(command_buffer->hip_symbols,
command_buffer->hip_context));

IREE_TRACE_ZONE_BEGIN(z0);
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, HIP_SET_CONTEXT(command_buffer->hip_symbols,
command_buffer->hip_context));

// If any of the workgroup counts are zero, we can skip execution
// of the kernel. This prevents a 'hipErrorInvalidConfiguration' error when
Expand Down

0 comments on commit d83fee6

Please sign in to comment.