[hip] Set the current device before calls into Hip.
This is a bit of a brute-force way to temporarily work around
our main HIP multi-device problems until the more complete
fix is in place.

Signed-off-by: Andrew Woloszyn <[email protected]>
AWoloszyn committed Nov 12, 2024
1 parent 2bfc639 commit e50e07f
Showing 16 changed files with 171 additions and 39 deletions.
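
The pattern repeated across these files is the same: each HIP-facing object now stores the hipCtx_t it was created under, and every public entry point makes that context current on the calling thread before issuing any HIP call. The sketch below shows the intended shape only; the HIP_SET_CONTEXT macro is defined in status_util.h (part of this commit but not shown in the hunks here), so its exact expansion, the example function name, and the status-macro argument forms are assumptions.

// Hypothetical sketch of the per-entry-point pattern applied by this commit.
// HIP_SET_CONTEXT is assumed to route hipCtxSetCurrent through the dynamic
// symbol table and convert the result to an iree_status_t; the real helpers
// in status_util.h may differ.
static iree_status_t iree_hal_hip_example_operation(
    const iree_hal_hip_dynamic_symbols_t* symbols, hipCtx_t hip_context,
    hipStream_t stream) {
  // The current HIP context is per-thread state, so it must be re-established
  // on whatever thread happens to run this entry point.
  IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(symbols, hip_context));
  // Subsequent HIP calls now target the device that owns |hip_context|.
  return IREE_HIP_RESULT_TO_STATUS(symbols, hipStreamSynchronize(stream),
                                   "hipStreamSynchronize");
}
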
26 changes: 19 additions & 7 deletions runtime/src/iree/hal/drivers/hip/event_pool.c
@@ -36,6 +36,10 @@ struct iree_hal_hip_event_t {
// The event pool that owns this event. This cannot be NULL. We retain it to
// make sure the event outlive the pool.
iree_hal_hip_event_pool_t* pool;
+
+ // The context to use to free this event, it must be the same
+ // context as was used when allocating the event.
+ hipCtx_t hip_context;
// The underlying hipEvent_t object.
hipEvent_t hip_event;
};
@@ -45,6 +49,8 @@ hipEvent_t iree_hal_hip_event_handle(const iree_hal_hip_event_t* event) {
}

static inline void iree_hal_hip_event_destroy(iree_hal_hip_event_t* event) {
+ IREE_IGNORE_ERROR(HIP_SET_CONTEXT(event->symbols, event->hip_context));
+
iree_allocator_t host_allocator = event->host_allocator;
const iree_hal_hip_dynamic_symbols_t* symbols = event->symbols;
IREE_TRACE_ZONE_BEGIN(z0);
@@ -58,8 +64,8 @@ static inline void iree_hal_hip_event_destroy(iree_hal_hip_event_t* event) {

static inline iree_status_t iree_hal_hip_event_create(
const iree_hal_hip_dynamic_symbols_t* symbols,
- iree_hal_hip_event_pool_t* pool, iree_allocator_t host_allocator,
- iree_hal_hip_event_t** out_event) {
+ iree_hal_hip_event_pool_t* pool, hipCtx_t context,
+ iree_allocator_t host_allocator, iree_hal_hip_event_t** out_event) {
IREE_ASSERT_ARGUMENT(symbols);
IREE_ASSERT_ARGUMENT(pool);
IREE_ASSERT_ARGUMENT(out_event);
@@ -75,6 +81,7 @@ static inline iree_status_t iree_hal_hip_event_create(
event->symbols = symbols;
event->pool = pool;
event->hip_event = NULL;
+ event->hip_context = context;

iree_status_t status = IREE_HIP_RESULT_TO_STATUS(
symbols,
@@ -122,6 +129,10 @@ struct iree_hal_hip_event_pool_t {
// The symbols used to create and destroy hipEvent_t objects.
const iree_hal_hip_dynamic_symbols_t* symbols;

+ // The context for this event pool to use to allocate
+ // events.
+ hipCtx_t hip_context;
+
// Guards event related fields in the pool. We don't expect a performant
// program to frequently allocate events for synchronization purposes; the
// traffic to this pool should be low. So it should be fine to use mutex to
@@ -142,7 +153,7 @@ struct iree_hal_hip_event_pool_t {
static void iree_hal_hip_event_pool_free(iree_hal_hip_event_pool_t* event_pool);

iree_status_t iree_hal_hip_event_pool_allocate(
- const iree_hal_hip_dynamic_symbols_t* symbols,
+ const iree_hal_hip_dynamic_symbols_t* symbols, hipCtx_t hip_context,
iree_host_size_t available_capacity, iree_allocator_t host_allocator,
iree_hal_hip_event_pool_t** out_event_pool) {
IREE_ASSERT_ARGUMENT(symbols);
@@ -163,11 +174,12 @@ iree_status_t iree_hal_hip_event_pool_allocate(
iree_slim_mutex_initialize(&event_pool->event_mutex);
event_pool->available_capacity = available_capacity;
event_pool->available_count = 0;
+ event_pool->hip_context = hip_context;

iree_status_t status = iree_ok_status();
for (iree_host_size_t i = 0; i < available_capacity; ++i) {
status = iree_hal_hip_event_create(
- symbols, event_pool, host_allocator,
+ symbols, event_pool, hip_context, host_allocator,
&event_pool->available_list[event_pool->available_count++]);
if (!iree_status_is_ok(status)) break;
}
@@ -240,9 +252,9 @@ iree_status_t iree_hal_hip_event_pool_acquire(
IREE_TRACE_ZONE_BEGIN_NAMED(z1, "event-pool-unpooled-acquire");
iree_status_t status = iree_ok_status();
for (iree_host_size_t i = 0; i < remaining_count; ++i) {
- status = iree_hal_hip_event_create(event_pool->symbols, event_pool,
-                                    event_pool->host_allocator,
-                                    &out_events[from_pool_count + i]);
+ status = iree_hal_hip_event_create(
+     event_pool->symbols, event_pool, event_pool->hip_context,
+     event_pool->host_allocator, &out_events[from_pool_count + i]);
if (!iree_status_is_ok(status)) {
// Must release all events we've acquired so far.
iree_hal_hip_event_pool_release_event(event_pool, from_pool_count + i,
2 changes: 1 addition & 1 deletion runtime/src/iree/hal/drivers/hip/event_pool.h
@@ -52,7 +52,7 @@ typedef struct iree_hal_hip_event_pool_t iree_hal_hip_event_pool_t;
// Extra events requested beyond the capability are directly created and
// destroyed without pooling.
iree_status_t iree_hal_hip_event_pool_allocate(
- const iree_hal_hip_dynamic_symbols_t* symbols,
+ const iree_hal_hip_dynamic_symbols_t* symbols, hipCtx_t hip_context,
iree_host_size_t available_capacity, iree_allocator_t host_allocator,
iree_hal_hip_event_pool_t** out_event_pool);
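
The caller-visible change in this pair of files is the extra hipCtx_t threaded into pool allocation. A hedged call-site sketch, assuming |symbols|, |hip_context|, and |host_allocator| were captured when the HAL device was created; the capacity value is purely illustrative:

// Hypothetical call site for the updated signature.
iree_hal_hip_event_pool_t* event_pool = NULL;
IREE_RETURN_IF_ERROR(iree_hal_hip_event_pool_allocate(
    symbols, hip_context, /*available_capacity=*/32, host_allocator,
    &event_pool));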

10 changes: 9 additions & 1 deletion runtime/src/iree/hal/drivers/hip/event_semaphore.c
@@ -10,6 +10,7 @@
#include "iree/base/internal/wait_handle.h"
#include "iree/base/status.h"
#include "iree/hal/drivers/hip/dynamic_symbols.h"
#include "iree/hal/drivers/hip/status_util.h"
#include "iree/hal/drivers/hip/timepoint_pool.h"
#include "iree/hal/utils/semaphore_base.h"

@@ -30,6 +31,8 @@ typedef struct iree_hal_hip_semaphore_t {
// new signaled values.
iree_hal_deferred_work_queue_t* work_queue;

+ hipCtx_t hip_context;
+
// Guards value and status. We expect low contention on semaphores and since
// iree_slim_mutex_t is (effectively) just a CAS this keeps things simpler
// than trying to make the entire structure lock-free.
@@ -56,9 +59,11 @@ static iree_hal_hip_semaphore_t* iree_hal_hip_semaphore_cast(

iree_status_t iree_hal_hip_event_semaphore_create(
uint64_t initial_value, const iree_hal_hip_dynamic_symbols_t* symbols,
- iree_hal_hip_timepoint_pool_t* timepoint_pool,
+ hipCtx_t hip_context, iree_hal_hip_timepoint_pool_t* timepoint_pool,
iree_hal_deferred_work_queue_t* work_queue, iree_allocator_t host_allocator,
iree_hal_semaphore_t** out_semaphore) {
+ IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(symbols, hip_context));
+
IREE_ASSERT_ARGUMENT(symbols);
IREE_ASSERT_ARGUMENT(timepoint_pool);
IREE_ASSERT_ARGUMENT(work_queue);
@@ -79,6 +84,7 @@ iree_status_t iree_hal_hip_event_semaphore_create(
iree_slim_mutex_initialize(&semaphore->mutex);
semaphore->current_value = initial_value;
semaphore->failure_status = iree_ok_status();
+ semaphore->hip_context = hip_context;

*out_semaphore = &semaphore->base;

@@ -90,6 +96,8 @@ static void iree_hal_hip_semaphore_destroy(
iree_hal_semaphore_t* base_semaphore) {
iree_hal_hip_semaphore_t* semaphore =
iree_hal_hip_semaphore_cast(base_semaphore);
+ IREE_IGNORE_ERROR(
+     HIP_SET_CONTEXT(semaphore->symbols, semaphore->hip_context));
iree_allocator_t host_allocator = semaphore->host_allocator;
IREE_TRACE_ZONE_BEGIN(z0);

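Note the asymmetry between the create and destroy hunks above: creation propagates a failed context switch with IREE_RETURN_IF_ERROR, while destruction wraps it in IREE_IGNORE_ERROR because a destructor has no status to return. A minimal sketch of that shape, with illustrative function names:

// Creation can fail cleanly before any resources are touched.
static iree_status_t iree_hal_hip_example_create(
    const iree_hal_hip_dynamic_symbols_t* symbols, hipCtx_t hip_context) {
  IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(symbols, hip_context));
  // ... allocate and initialize the object ...
  return iree_ok_status();
}

// Destruction must proceed even if the context switch fails, so the error is
// swallowed rather than propagated.
static void iree_hal_hip_example_destroy(
    const iree_hal_hip_dynamic_symbols_t* symbols, hipCtx_t hip_context) {
  IREE_IGNORE_ERROR(HIP_SET_CONTEXT(symbols, hip_context));
  // ... release resources ...
}
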
2 changes: 1 addition & 1 deletion runtime/src/iree/hal/drivers/hip/event_semaphore.h
@@ -31,7 +31,7 @@ extern "C" {
// Thread-safe; multiple threads may signal/wait values on the same semaphore.
iree_status_t iree_hal_hip_event_semaphore_create(
uint64_t initial_value, const iree_hal_hip_dynamic_symbols_t* symbols,
- iree_hal_hip_timepoint_pool_t* timepoint_pool,
+ hipCtx_t hip_context, iree_hal_hip_timepoint_pool_t* timepoint_pool,
iree_hal_deferred_work_queue_t* work_queue, iree_allocator_t host_allocator,
iree_hal_semaphore_t** out_semaphore);

25 changes: 23 additions & 2 deletions runtime/src/iree/hal/drivers/hip/hip_allocator.c
@@ -29,6 +29,8 @@ typedef struct iree_hal_hip_allocator_t {
// The HIP stream that allocations should be used in.
hipStream_t stream;

+ hipCtx_t hip_context;
+
// NOTE: optional depending on device support.
iree_hal_hip_memory_pools_t* pools;

@@ -54,8 +56,11 @@ static iree_hal_hip_allocator_t* iree_hal_hip_allocator_cast(

iree_status_t iree_hal_hip_allocator_create(
const iree_hal_hip_dynamic_symbols_t* hip_symbols, hipDevice_t device,
- hipStream_t stream, iree_hal_hip_memory_pools_t* pools,
- iree_allocator_t host_allocator, iree_hal_allocator_t** out_allocator) {
+ hipCtx_t hip_context, hipStream_t stream,
+ iree_hal_hip_memory_pools_t* pools, iree_allocator_t host_allocator,
+ iree_hal_allocator_t** out_allocator) {
+ IREE_RETURN_IF_ERROR(HIP_SET_CONTEXT(hip_symbols, hip_context));
+
IREE_ASSERT_ARGUMENT(hip_symbols);
IREE_ASSERT_ARGUMENT(out_allocator);
IREE_TRACE_ZONE_BEGIN(z0);
@@ -94,6 +99,7 @@ iree_status_t iree_hal_hip_allocator_create(
allocator->host_allocator = host_allocator;
allocator->supports_concurrent_managed_access =
supports_concurrent_managed_access != 0;
+ allocator->hip_context = hip_context;
*out_allocator = (iree_hal_allocator_t*)allocator;

IREE_TRACE_ZONE_END(z0);
@@ -319,6 +325,9 @@ static iree_status_t iree_hal_hip_allocator_allocate_buffer(
iree_hal_hip_allocator_t* allocator =
iree_hal_hip_allocator_cast(base_allocator);

+ IREE_RETURN_IF_ERROR(
+     HIP_SET_CONTEXT(allocator->symbols, allocator->hip_context));
+
// Coerce options into those required by the current device.
iree_hal_buffer_params_t compat_params = *params;
iree_hal_buffer_compatibility_t compatibility =
@@ -431,6 +440,9 @@ static void iree_hal_hip_allocator_deallocate_buffer(
iree_hal_hip_allocator_t* allocator =
iree_hal_hip_allocator_cast(base_allocator);

+ IREE_IGNORE_ERROR(
+     HIP_SET_CONTEXT(allocator->symbols, allocator->hip_context));
+
const iree_hal_hip_buffer_type_t buffer_type =
iree_hal_hip_buffer_type(base_buffer);

@@ -466,6 +478,9 @@ static iree_status_t iree_hal_hip_allocator_import_buffer(
iree_hal_hip_allocator_t* allocator =
iree_hal_hip_allocator_cast(base_allocator);

+ IREE_RETURN_IF_ERROR(
+     HIP_SET_CONTEXT(allocator->symbols, allocator->hip_context));
+
// Coerce options into those required by the current device.
iree_hal_buffer_params_t compat_params = *params;
iree_device_size_t allocation_size = external_buffer->size;
@@ -600,6 +615,9 @@ iree_status_t iree_hal_hip_allocator_alloc_async(
iree_hal_hip_allocator_t* allocator =
iree_hal_hip_allocator_cast(base_allocator);

+ IREE_RETURN_IF_ERROR(
+     HIP_SET_CONTEXT(allocator->symbols, allocator->hip_context));
+
hipDeviceptr_t ptr = NULL;
iree_status_t status = IREE_HIP_RESULT_TO_STATUS(
allocator->symbols,
@@ -625,6 +643,9 @@ iree_status_t iree_hal_hip_allocator_free_async(
iree_hal_buffer_t* buffer) {
iree_hal_hip_allocator_t* allocator =
iree_hal_hip_allocator_cast(base_allocator);
+ IREE_RETURN_IF_ERROR(
+     HIP_SET_CONTEXT(allocator->symbols, allocator->hip_context));
+
hipDeviceptr_t device_ptr = iree_hal_hip_buffer_device_pointer(buffer);
if (!device_ptr) {
return iree_ok_status();
7 changes: 4 additions & 3 deletions runtime/src/iree/hal/drivers/hip/hip_allocator.h
@@ -17,13 +17,14 @@ extern "C" {
#endif // __cplusplus

// Creates a HIP memory allocator.
- // |device| and |stream| will be used for management operations.
+ // |device| |hip_context| and |stream| will be used for management operations.
// |pools| provides memory pools that may be shared across multiple allocators
// and the pointer must remain valid for the lifetime of the allocator.
iree_status_t iree_hal_hip_allocator_create(
const iree_hal_hip_dynamic_symbols_t* hip_symbols, hipDevice_t device,
- hipStream_t stream, iree_hal_hip_memory_pools_t* pools,
- iree_allocator_t host_allocator, iree_hal_allocator_t** out_allocator);
+ hipCtx_t hip_context, hipStream_t stream,
+ iree_hal_hip_memory_pools_t* pools, iree_allocator_t host_allocator,
+ iree_hal_allocator_t** out_allocator);

bool iree_hal_hip_allocator_isa(iree_hal_allocator_t* base_value);
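
With this header change, allocator creation now receives the context alongside the device and stream. A hedged call-site sketch; |hip_symbols|, |device|, |hip_context|, |stream|, |pools|, and |host_allocator| are assumed to come from the owning HAL device:

// Hypothetical call site inside device creation.
iree_hal_allocator_t* device_allocator = NULL;
IREE_RETURN_IF_ERROR(iree_hal_hip_allocator_create(
    hip_symbols, device, hip_context, stream, pools, host_allocator,
    &device_allocator));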
