Skip to content

Commit

Permalink
[hip] Set the current device before calls into Hip. (iree-org#19103)
Browse files Browse the repository at this point in the history
This is a bit of a brute-force way to solve our main hip multi-device
problems temporarily until the more complete fix is in place.

For the single-device case this has negligible performance implications
as `hipCtxSetCurrent` is a no-op in that case.
For the multi-device case this could cause more significant performance
problems if the user program swaps between devices within a thread.

---------

Signed-off-by: Andrew Woloszyn <[email protected]>
  • Loading branch information
AWoloszyn authored and Groverkss committed Nov 29, 2024
1 parent 5112038 commit d213b3b
Show file tree
Hide file tree
Showing 18 changed files with 215 additions and 39 deletions.
1 change: 1 addition & 0 deletions runtime/src/iree/hal/drivers/hip/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ iree_cc_library(
"api.h"
SRCS
"api.h"
"context_util.h"
"event_pool.c"
"event_pool.h"
"event_semaphore.c"
Expand Down
34 changes: 34 additions & 0 deletions runtime/src/iree/hal/drivers/hip/context_util.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef IREE_HAL_DRIVERS_HIP_CONTEXT_UTIL_H_
#define IREE_HAL_DRIVERS_HIP_CONTEXT_UTIL_H_

#include "iree/base/api.h"
#include "iree/hal/drivers/hip/dynamic_symbols.h"
#include "iree/hal/drivers/hip/status_util.h"

// Makes |hip_context| the current HIP context via hipCtxSetCurrent.
//
// A NULL |hip_context| is treated as "nothing to do": an OK status is returned
// without calling into HIP. Otherwise the hipCtxSetCurrent result is converted
// to an iree_status_t and returned to the caller.
//
// The IREE_TRACE section (presumably only compiled in when tracing is enabled;
// confirm against the tracing headers) first queries the current context so
// that a named trace zone is emitted only when the context actually changes.
static inline iree_status_t iree_hal_hip_set_context(
    const iree_hal_hip_dynamic_symbols_t* syms, hipCtx_t hip_context) {
  if (hip_context == NULL) {
    // No context supplied: leave whatever is current untouched.
    return iree_ok_status();
  }
  IREE_TRACE({
    hipCtx_t current_context = NULL;
    IREE_HIP_RETURN_IF_ERROR(syms, hipCtxGetCurrent(&current_context),
                             "hipCtxGetCurrent");
    if (current_context != hip_context) {
      // A real switch is about to happen; bracket it with a trace zone.
      IREE_TRACE_ZONE_BEGIN_NAMED(z0, "iree_hal_hip_set_context_switch");
      iree_status_t switch_status =
          IREE_HIP_RESULT_TO_STATUS(syms, hipCtxSetCurrent(hip_context));
      IREE_TRACE_ZONE_END(z0);
      return switch_status;
    }
  });
  return IREE_HIP_RESULT_TO_STATUS(syms, hipCtxSetCurrent(hip_context));
}

#endif // IREE_HAL_DRIVERS_HIP_CONTEXT_UTIL_H_
1 change: 1 addition & 0 deletions runtime/src/iree/hal/drivers/hip/dynamic_symbol_tables.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
// HIP symbols
//===----------------------------------------------------------------------===//

IREE_HAL_HIP_REQUIRED_PFN_DECL(hipCtxGetCurrent, hipCtx_t *)
IREE_HAL_HIP_REQUIRED_PFN_DECL(hipCtxSetCurrent, hipCtx_t)
IREE_HAL_HIP_REQUIRED_PFN_DECL(hipDeviceGet, hipDevice_t *, int)
IREE_HAL_HIP_REQUIRED_PFN_DECL(hipDeviceGetAttribute, int *,
Expand Down
27 changes: 20 additions & 7 deletions runtime/src/iree/hal/drivers/hip/event_pool.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "iree/base/internal/atomics.h"
#include "iree/base/internal/synchronization.h"
#include "iree/hal/api.h"
#include "iree/hal/drivers/hip/context_util.h"
#include "iree/hal/drivers/hip/dynamic_symbols.h"
#include "iree/hal/drivers/hip/status_util.h"

Expand All @@ -36,6 +37,10 @@ struct iree_hal_hip_event_t {
// The event pool that owns this event. This cannot be NULL. We retain it to
// make sure the pool outlives the event.
iree_hal_hip_event_pool_t* pool;

// The context to use to free this event; it must be the same
// context that was used when allocating the event.
hipCtx_t hip_context;
// The underlying hipEvent_t object.
hipEvent_t hip_event;
};
Expand All @@ -48,6 +53,8 @@ static inline void iree_hal_hip_event_destroy(iree_hal_hip_event_t* event) {
iree_allocator_t host_allocator = event->host_allocator;
const iree_hal_hip_dynamic_symbols_t* symbols = event->symbols;
IREE_TRACE_ZONE_BEGIN(z0);
IREE_IGNORE_ERROR(
iree_hal_hip_set_context(event->symbols, event->hip_context));

IREE_ASSERT_REF_COUNT_ZERO(&event->ref_count);
IREE_HIP_IGNORE_ERROR(symbols, hipEventDestroy(event->hip_event));
Expand All @@ -58,8 +65,8 @@ static inline void iree_hal_hip_event_destroy(iree_hal_hip_event_t* event) {

static inline iree_status_t iree_hal_hip_event_create(
const iree_hal_hip_dynamic_symbols_t* symbols,
iree_hal_hip_event_pool_t* pool, iree_allocator_t host_allocator,
iree_hal_hip_event_t** out_event) {
iree_hal_hip_event_pool_t* pool, hipCtx_t context,
iree_allocator_t host_allocator, iree_hal_hip_event_t** out_event) {
IREE_ASSERT_ARGUMENT(symbols);
IREE_ASSERT_ARGUMENT(pool);
IREE_ASSERT_ARGUMENT(out_event);
Expand All @@ -75,6 +82,7 @@ static inline iree_status_t iree_hal_hip_event_create(
event->symbols = symbols;
event->pool = pool;
event->hip_event = NULL;
event->hip_context = context;

iree_status_t status = IREE_HIP_RESULT_TO_STATUS(
symbols,
Expand Down Expand Up @@ -122,6 +130,10 @@ struct iree_hal_hip_event_pool_t {
// The symbols used to create and destroy hipEvent_t objects.
const iree_hal_hip_dynamic_symbols_t* symbols;

// The context for this event pool to use to allocate
// events.
hipCtx_t hip_context;

// Guards event related fields in the pool. We don't expect a performant
// program to frequently allocate events for synchronization purposes; the
// traffic to this pool should be low. So it should be fine to use mutex to
Expand All @@ -142,7 +154,7 @@ struct iree_hal_hip_event_pool_t {
static void iree_hal_hip_event_pool_free(iree_hal_hip_event_pool_t* event_pool);

iree_status_t iree_hal_hip_event_pool_allocate(
const iree_hal_hip_dynamic_symbols_t* symbols,
const iree_hal_hip_dynamic_symbols_t* symbols, hipCtx_t hip_context,
iree_host_size_t available_capacity, iree_allocator_t host_allocator,
iree_hal_hip_event_pool_t** out_event_pool) {
IREE_ASSERT_ARGUMENT(symbols);
Expand All @@ -163,11 +175,12 @@ iree_status_t iree_hal_hip_event_pool_allocate(
iree_slim_mutex_initialize(&event_pool->event_mutex);
event_pool->available_capacity = available_capacity;
event_pool->available_count = 0;
event_pool->hip_context = hip_context;

iree_status_t status = iree_ok_status();
for (iree_host_size_t i = 0; i < available_capacity; ++i) {
status = iree_hal_hip_event_create(
symbols, event_pool, host_allocator,
symbols, event_pool, hip_context, host_allocator,
&event_pool->available_list[event_pool->available_count++]);
if (!iree_status_is_ok(status)) break;
}
Expand Down Expand Up @@ -240,9 +253,9 @@ iree_status_t iree_hal_hip_event_pool_acquire(
IREE_TRACE_ZONE_BEGIN_NAMED(z1, "event-pool-unpooled-acquire");
iree_status_t status = iree_ok_status();
for (iree_host_size_t i = 0; i < remaining_count; ++i) {
status = iree_hal_hip_event_create(event_pool->symbols, event_pool,
event_pool->host_allocator,
&out_events[from_pool_count + i]);
status = iree_hal_hip_event_create(
event_pool->symbols, event_pool, event_pool->hip_context,
event_pool->host_allocator, &out_events[from_pool_count + i]);
if (!iree_status_is_ok(status)) {
// Must release all events we've acquired so far.
iree_hal_hip_event_pool_release_event(event_pool, from_pool_count + i,
Expand Down
2 changes: 1 addition & 1 deletion runtime/src/iree/hal/drivers/hip/event_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ typedef struct iree_hal_hip_event_pool_t iree_hal_hip_event_pool_t;
// Extra events requested beyond the capacity are directly created and
// destroyed without pooling.
iree_status_t iree_hal_hip_event_pool_allocate(
const iree_hal_hip_dynamic_symbols_t* symbols,
const iree_hal_hip_dynamic_symbols_t* symbols, hipCtx_t hip_context,
iree_host_size_t available_capacity, iree_allocator_t host_allocator,
iree_hal_hip_event_pool_t** out_event_pool);

Expand Down
11 changes: 10 additions & 1 deletion runtime/src/iree/hal/drivers/hip/event_semaphore.c
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
#include "iree/base/internal/synchronization.h"
#include "iree/base/internal/wait_handle.h"
#include "iree/base/status.h"
#include "iree/hal/drivers/hip/context_util.h"
#include "iree/hal/drivers/hip/dynamic_symbols.h"
#include "iree/hal/drivers/hip/status_util.h"
#include "iree/hal/drivers/hip/timepoint_pool.h"
#include "iree/hal/utils/semaphore_base.h"

Expand All @@ -30,6 +32,8 @@ typedef struct iree_hal_hip_semaphore_t {
// new signaled values.
iree_hal_deferred_work_queue_t* work_queue;

hipCtx_t hip_context;

// Guards value and status. We expect low contention on semaphores and since
// iree_slim_mutex_t is (effectively) just a CAS this keeps things simpler
// than trying to make the entire structure lock-free.
Expand All @@ -56,7 +60,7 @@ static iree_hal_hip_semaphore_t* iree_hal_hip_semaphore_cast(

iree_status_t iree_hal_hip_event_semaphore_create(
uint64_t initial_value, const iree_hal_hip_dynamic_symbols_t* symbols,
iree_hal_hip_timepoint_pool_t* timepoint_pool,
hipCtx_t hip_context, iree_hal_hip_timepoint_pool_t* timepoint_pool,
iree_hal_deferred_work_queue_t* work_queue, iree_allocator_t host_allocator,
iree_hal_semaphore_t** out_semaphore) {
IREE_ASSERT_ARGUMENT(symbols);
Expand All @@ -65,6 +69,8 @@ iree_status_t iree_hal_hip_event_semaphore_create(
IREE_ASSERT_ARGUMENT(out_semaphore);
IREE_TRACE_ZONE_BEGIN(z0);

IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_hip_set_context(symbols, hip_context));
iree_hal_hip_semaphore_t* semaphore = NULL;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_allocator_malloc(host_allocator, sizeof(*semaphore),
Expand All @@ -79,6 +85,7 @@ iree_status_t iree_hal_hip_event_semaphore_create(
iree_slim_mutex_initialize(&semaphore->mutex);
semaphore->current_value = initial_value;
semaphore->failure_status = iree_ok_status();
semaphore->hip_context = hip_context;

*out_semaphore = &semaphore->base;

Expand All @@ -92,6 +99,8 @@ static void iree_hal_hip_semaphore_destroy(
iree_hal_hip_semaphore_cast(base_semaphore);
iree_allocator_t host_allocator = semaphore->host_allocator;
IREE_TRACE_ZONE_BEGIN(z0);
IREE_IGNORE_ERROR(
iree_hal_hip_set_context(semaphore->symbols, semaphore->hip_context));

iree_status_ignore(semaphore->failure_status);
iree_slim_mutex_deinitialize(&semaphore->mutex);
Expand Down
2 changes: 1 addition & 1 deletion runtime/src/iree/hal/drivers/hip/event_semaphore.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ extern "C" {
// Thread-safe; multiple threads may signal/wait values on the same semaphore.
iree_status_t iree_hal_hip_event_semaphore_create(
uint64_t initial_value, const iree_hal_hip_dynamic_symbols_t* symbols,
iree_hal_hip_timepoint_pool_t* timepoint_pool,
hipCtx_t hip_context, iree_hal_hip_timepoint_pool_t* timepoint_pool,
iree_hal_deferred_work_queue_t* work_queue, iree_allocator_t host_allocator,
iree_hal_semaphore_t** out_semaphore);

Expand Down
26 changes: 24 additions & 2 deletions runtime/src/iree/hal/drivers/hip/hip_allocator.c
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "iree/base/api.h"
#include "iree/base/tracing.h"
#include "iree/hal/drivers/hip/context_util.h"
#include "iree/hal/drivers/hip/dynamic_symbols.h"
#include "iree/hal/drivers/hip/hip_buffer.h"
#include "iree/hal/drivers/hip/status_util.h"
Expand All @@ -29,6 +30,8 @@ typedef struct iree_hal_hip_allocator_t {
// The HIP stream that allocations should be used in.
hipStream_t stream;

hipCtx_t hip_context;

// NOTE: optional depending on device support.
iree_hal_hip_memory_pools_t* pools;

Expand All @@ -54,11 +57,14 @@ static iree_hal_hip_allocator_t* iree_hal_hip_allocator_cast(

iree_status_t iree_hal_hip_allocator_create(
const iree_hal_hip_dynamic_symbols_t* hip_symbols, hipDevice_t device,
hipStream_t stream, iree_hal_hip_memory_pools_t* pools,
iree_allocator_t host_allocator, iree_hal_allocator_t** out_allocator) {
hipCtx_t hip_context, hipStream_t stream,
iree_hal_hip_memory_pools_t* pools, iree_allocator_t host_allocator,
iree_hal_allocator_t** out_allocator) {
IREE_ASSERT_ARGUMENT(hip_symbols);
IREE_ASSERT_ARGUMENT(out_allocator);
IREE_TRACE_ZONE_BEGIN(z0);
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_hip_set_context(hip_symbols, hip_context));

// To support device-local + host-visible memory we need concurrent managed
// access indicating that the host and devices can concurrently access the
Expand Down Expand Up @@ -94,6 +100,7 @@ iree_status_t iree_hal_hip_allocator_create(
allocator->host_allocator = host_allocator;
allocator->supports_concurrent_managed_access =
supports_concurrent_managed_access != 0;
allocator->hip_context = hip_context;
*out_allocator = (iree_hal_allocator_t*)allocator;

IREE_TRACE_ZONE_END(z0);
Expand Down Expand Up @@ -352,6 +359,9 @@ static iree_status_t iree_hal_hip_allocator_allocate_buffer(
void* host_ptr = NULL;
hipDeviceptr_t device_ptr = NULL;
IREE_TRACE_ZONE_BEGIN_NAMED(z0, "iree_hal_hip_buffer_allocate");
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_hip_set_context(allocator->symbols, allocator->hip_context));

IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, allocation_size);
if (iree_all_bits_set(compat_params.type,
IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL)) {
Expand Down Expand Up @@ -431,6 +441,9 @@ static void iree_hal_hip_allocator_deallocate_buffer(
iree_hal_hip_allocator_t* allocator =
iree_hal_hip_allocator_cast(base_allocator);

IREE_IGNORE_ERROR(
iree_hal_hip_set_context(allocator->symbols, allocator->hip_context));

const iree_hal_hip_buffer_type_t buffer_type =
iree_hal_hip_buffer_type(base_buffer);

Expand Down Expand Up @@ -466,6 +479,9 @@ static iree_status_t iree_hal_hip_allocator_import_buffer(
iree_hal_hip_allocator_t* allocator =
iree_hal_hip_allocator_cast(base_allocator);

IREE_RETURN_IF_ERROR(
iree_hal_hip_set_context(allocator->symbols, allocator->hip_context));

// Coerce options into those required by the current device.
iree_hal_buffer_params_t compat_params = *params;
iree_device_size_t allocation_size = external_buffer->size;
Expand Down Expand Up @@ -600,6 +616,9 @@ iree_status_t iree_hal_hip_allocator_alloc_async(
iree_hal_hip_allocator_t* allocator =
iree_hal_hip_allocator_cast(base_allocator);

IREE_RETURN_IF_ERROR(
iree_hal_hip_set_context(allocator->symbols, allocator->hip_context));

hipDeviceptr_t ptr = NULL;
iree_status_t status = IREE_HIP_RESULT_TO_STATUS(
allocator->symbols,
Expand All @@ -625,6 +644,9 @@ iree_status_t iree_hal_hip_allocator_free_async(
iree_hal_buffer_t* buffer) {
iree_hal_hip_allocator_t* allocator =
iree_hal_hip_allocator_cast(base_allocator);
IREE_RETURN_IF_ERROR(
iree_hal_hip_set_context(allocator->symbols, allocator->hip_context));

hipDeviceptr_t device_ptr = iree_hal_hip_buffer_device_pointer(buffer);
if (!device_ptr) {
return iree_ok_status();
Expand Down
7 changes: 4 additions & 3 deletions runtime/src/iree/hal/drivers/hip/hip_allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@ extern "C" {
#endif // __cplusplus

// Creates a HIP memory allocator.
// |device| and |stream| will be used for management operations.
// |device|, |hip_context|, and |stream| will be used for management
// operations.
// |pools| provides memory pools that may be shared across multiple allocators
// and the pointer must remain valid for the lifetime of the allocator.
iree_status_t iree_hal_hip_allocator_create(
const iree_hal_hip_dynamic_symbols_t* hip_symbols, hipDevice_t device,
hipStream_t stream, iree_hal_hip_memory_pools_t* pools,
iree_allocator_t host_allocator, iree_hal_allocator_t** out_allocator);
hipCtx_t hip_context, hipStream_t stream,
iree_hal_hip_memory_pools_t* pools, iree_allocator_t host_allocator,
iree_hal_allocator_t** out_allocator);

bool iree_hal_hip_allocator_isa(iree_hal_allocator_t* base_value);

Expand Down
Loading

0 comments on commit d213b3b

Please sign in to comment.