Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[hip] Set the current device before calls into Hip. #19103

Merged
merged 5 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions runtime/src/iree/hal/drivers/hip/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ iree_cc_library(
"api.h"
SRCS
"api.h"
"context_util.h"
"event_pool.c"
"event_pool.h"
"event_semaphore.c"
Expand Down
34 changes: 34 additions & 0 deletions runtime/src/iree/hal/drivers/hip/context_util.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef IREE_HAL_DRIVERS_HIP_CONTEXT_UTIL_H_
#define IREE_HAL_DRIVERS_HIP_CONTEXT_UTIL_H_

#include "iree/base/api.h"
#include "iree/hal/drivers/hip/dynamic_symbols.h"
#include "iree/hal/drivers/hip/status_util.h"

// Makes |hip_context| the current HIP context by calling hipCtxSetCurrent
// through the dynamically loaded symbol table |syms|.
//
// A NULL |hip_context| is treated as "nothing to set": the function returns an
// OK status without calling into HIP.
//
// Returns the status converted from the hipCtxSetCurrent result, or the
// failure reported for hipCtxGetCurrent on the tracing path below.
static inline iree_status_t iree_hal_hip_set_context(
const iree_hal_hip_dynamic_symbols_t* syms, hipCtx_t hip_context) {
if (!hip_context) {
return iree_ok_status();
}
// NOTE(review): the body of IREE_TRACE is presumably only compiled in when
// tracing is enabled - confirm against the macro definition. It queries the
// current context first so that the "iree_hal_hip_set_context_switch" zone is
// recorded only when the context actually changes, and it returns from the
// function directly in that case.
IREE_TRACE({
hipCtx_t current_context = NULL;
IREE_HIP_RETURN_IF_ERROR(syms, hipCtxGetCurrent(&current_context),
"hipCtxGetCurrent");
if (current_context != hip_context) {
IREE_TRACE_ZONE_BEGIN_NAMED(z0, "iree_hal_hip_set_context_switch");
iree_status_t status =
IREE_HIP_RESULT_TO_STATUS(syms, hipCtxSetCurrent(hip_context));
IREE_TRACE_ZONE_END(z0);
return status;
}
});
// Reached whenever the block above did not return (the contexts already
// matched, or the block was not compiled in): set the context unconditionally
// without recording a trace zone.
return IREE_HIP_RESULT_TO_STATUS(syms, hipCtxSetCurrent(hip_context));
}

#endif // IREE_HAL_DRIVERS_HIP_CONTEXT_UTIL_H_
1 change: 1 addition & 0 deletions runtime/src/iree/hal/drivers/hip/dynamic_symbol_tables.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
// HIP symbols
//===----------------------------------------------------------------------===//

IREE_HAL_HIP_REQUIRED_PFN_DECL(hipCtxGetCurrent, hipCtx_t *)
IREE_HAL_HIP_REQUIRED_PFN_DECL(hipCtxSetCurrent, hipCtx_t)
IREE_HAL_HIP_REQUIRED_PFN_DECL(hipDeviceGet, hipDevice_t *, int)
IREE_HAL_HIP_REQUIRED_PFN_DECL(hipDeviceGetAttribute, int *,
Expand Down
27 changes: 20 additions & 7 deletions runtime/src/iree/hal/drivers/hip/event_pool.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "iree/base/internal/atomics.h"
#include "iree/base/internal/synchronization.h"
#include "iree/hal/api.h"
#include "iree/hal/drivers/hip/context_util.h"
#include "iree/hal/drivers/hip/dynamic_symbols.h"
#include "iree/hal/drivers/hip/status_util.h"

Expand All @@ -36,6 +37,10 @@ struct iree_hal_hip_event_t {
// The event pool that owns this event. This cannot be NULL. We retain it to
// make sure the pool outlives the event.
iree_hal_hip_event_pool_t* pool;

// The context to use to free this event; it must be the same
// context that was used when allocating the event.
hipCtx_t hip_context;
// The underlying hipEvent_t object.
hipEvent_t hip_event;
};
Expand All @@ -48,6 +53,8 @@ static inline void iree_hal_hip_event_destroy(iree_hal_hip_event_t* event) {
iree_allocator_t host_allocator = event->host_allocator;
const iree_hal_hip_dynamic_symbols_t* symbols = event->symbols;
IREE_TRACE_ZONE_BEGIN(z0);
IREE_IGNORE_ERROR(
iree_hal_hip_set_context(event->symbols, event->hip_context));

IREE_ASSERT_REF_COUNT_ZERO(&event->ref_count);
IREE_HIP_IGNORE_ERROR(symbols, hipEventDestroy(event->hip_event));
Expand All @@ -58,8 +65,8 @@ static inline void iree_hal_hip_event_destroy(iree_hal_hip_event_t* event) {

static inline iree_status_t iree_hal_hip_event_create(
const iree_hal_hip_dynamic_symbols_t* symbols,
iree_hal_hip_event_pool_t* pool, iree_allocator_t host_allocator,
iree_hal_hip_event_t** out_event) {
iree_hal_hip_event_pool_t* pool, hipCtx_t context,
iree_allocator_t host_allocator, iree_hal_hip_event_t** out_event) {
IREE_ASSERT_ARGUMENT(symbols);
IREE_ASSERT_ARGUMENT(pool);
IREE_ASSERT_ARGUMENT(out_event);
Expand All @@ -75,6 +82,7 @@ static inline iree_status_t iree_hal_hip_event_create(
event->symbols = symbols;
event->pool = pool;
event->hip_event = NULL;
event->hip_context = context;

iree_status_t status = IREE_HIP_RESULT_TO_STATUS(
symbols,
Expand Down Expand Up @@ -122,6 +130,10 @@ struct iree_hal_hip_event_pool_t {
// The symbols used to create and destroy hipEvent_t objects.
const iree_hal_hip_dynamic_symbols_t* symbols;

// The context for this event pool to use to allocate
// events.
hipCtx_t hip_context;

// Guards event related fields in the pool. We don't expect a performant
// program to frequently allocate events for synchronization purposes; the
// traffic to this pool should be low. So it should be fine to use mutex to
Expand All @@ -142,7 +154,7 @@ struct iree_hal_hip_event_pool_t {
static void iree_hal_hip_event_pool_free(iree_hal_hip_event_pool_t* event_pool);

iree_status_t iree_hal_hip_event_pool_allocate(
const iree_hal_hip_dynamic_symbols_t* symbols,
const iree_hal_hip_dynamic_symbols_t* symbols, hipCtx_t hip_context,
iree_host_size_t available_capacity, iree_allocator_t host_allocator,
iree_hal_hip_event_pool_t** out_event_pool) {
IREE_ASSERT_ARGUMENT(symbols);
Expand All @@ -163,11 +175,12 @@ iree_status_t iree_hal_hip_event_pool_allocate(
iree_slim_mutex_initialize(&event_pool->event_mutex);
event_pool->available_capacity = available_capacity;
event_pool->available_count = 0;
event_pool->hip_context = hip_context;

iree_status_t status = iree_ok_status();
for (iree_host_size_t i = 0; i < available_capacity; ++i) {
status = iree_hal_hip_event_create(
symbols, event_pool, host_allocator,
symbols, event_pool, hip_context, host_allocator,
&event_pool->available_list[event_pool->available_count++]);
if (!iree_status_is_ok(status)) break;
}
Expand Down Expand Up @@ -240,9 +253,9 @@ iree_status_t iree_hal_hip_event_pool_acquire(
IREE_TRACE_ZONE_BEGIN_NAMED(z1, "event-pool-unpooled-acquire");
iree_status_t status = iree_ok_status();
for (iree_host_size_t i = 0; i < remaining_count; ++i) {
status = iree_hal_hip_event_create(event_pool->symbols, event_pool,
event_pool->host_allocator,
&out_events[from_pool_count + i]);
status = iree_hal_hip_event_create(
event_pool->symbols, event_pool, event_pool->hip_context,
event_pool->host_allocator, &out_events[from_pool_count + i]);
if (!iree_status_is_ok(status)) {
// Must release all events we've acquired so far.
iree_hal_hip_event_pool_release_event(event_pool, from_pool_count + i,
Expand Down
2 changes: 1 addition & 1 deletion runtime/src/iree/hal/drivers/hip/event_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ typedef struct iree_hal_hip_event_pool_t iree_hal_hip_event_pool_t;
// Extra events requested beyond the capacity are directly created and
// destroyed without pooling.
iree_status_t iree_hal_hip_event_pool_allocate(
const iree_hal_hip_dynamic_symbols_t* symbols,
const iree_hal_hip_dynamic_symbols_t* symbols, hipCtx_t hip_context,
iree_host_size_t available_capacity, iree_allocator_t host_allocator,
iree_hal_hip_event_pool_t** out_event_pool);

Expand Down
11 changes: 10 additions & 1 deletion runtime/src/iree/hal/drivers/hip/event_semaphore.c
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
#include "iree/base/internal/synchronization.h"
#include "iree/base/internal/wait_handle.h"
#include "iree/base/status.h"
#include "iree/hal/drivers/hip/context_util.h"
#include "iree/hal/drivers/hip/dynamic_symbols.h"
#include "iree/hal/drivers/hip/status_util.h"
#include "iree/hal/drivers/hip/timepoint_pool.h"
#include "iree/hal/utils/semaphore_base.h"

Expand All @@ -30,6 +32,8 @@ typedef struct iree_hal_hip_semaphore_t {
// new signaled values.
iree_hal_deferred_work_queue_t* work_queue;

hipCtx_t hip_context;

// Guards value and status. We expect low contention on semaphores and since
// iree_slim_mutex_t is (effectively) just a CAS this keeps things simpler
// than trying to make the entire structure lock-free.
Expand All @@ -56,7 +60,7 @@ static iree_hal_hip_semaphore_t* iree_hal_hip_semaphore_cast(

iree_status_t iree_hal_hip_event_semaphore_create(
uint64_t initial_value, const iree_hal_hip_dynamic_symbols_t* symbols,
iree_hal_hip_timepoint_pool_t* timepoint_pool,
hipCtx_t hip_context, iree_hal_hip_timepoint_pool_t* timepoint_pool,
iree_hal_deferred_work_queue_t* work_queue, iree_allocator_t host_allocator,
iree_hal_semaphore_t** out_semaphore) {
IREE_ASSERT_ARGUMENT(symbols);
Expand All @@ -65,6 +69,8 @@ iree_status_t iree_hal_hip_event_semaphore_create(
IREE_ASSERT_ARGUMENT(out_semaphore);
IREE_TRACE_ZONE_BEGIN(z0);

IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_hip_set_context(symbols, hip_context));
iree_hal_hip_semaphore_t* semaphore = NULL;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_allocator_malloc(host_allocator, sizeof(*semaphore),
Expand All @@ -79,6 +85,7 @@ iree_status_t iree_hal_hip_event_semaphore_create(
iree_slim_mutex_initialize(&semaphore->mutex);
semaphore->current_value = initial_value;
semaphore->failure_status = iree_ok_status();
semaphore->hip_context = hip_context;

*out_semaphore = &semaphore->base;

Expand All @@ -92,6 +99,8 @@ static void iree_hal_hip_semaphore_destroy(
iree_hal_hip_semaphore_cast(base_semaphore);
iree_allocator_t host_allocator = semaphore->host_allocator;
IREE_TRACE_ZONE_BEGIN(z0);
IREE_IGNORE_ERROR(
iree_hal_hip_set_context(semaphore->symbols, semaphore->hip_context));

iree_status_ignore(semaphore->failure_status);
iree_slim_mutex_deinitialize(&semaphore->mutex);
Expand Down
2 changes: 1 addition & 1 deletion runtime/src/iree/hal/drivers/hip/event_semaphore.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ extern "C" {
// Thread-safe; multiple threads may signal/wait values on the same semaphore.
iree_status_t iree_hal_hip_event_semaphore_create(
uint64_t initial_value, const iree_hal_hip_dynamic_symbols_t* symbols,
iree_hal_hip_timepoint_pool_t* timepoint_pool,
hipCtx_t hip_context, iree_hal_hip_timepoint_pool_t* timepoint_pool,
iree_hal_deferred_work_queue_t* work_queue, iree_allocator_t host_allocator,
iree_hal_semaphore_t** out_semaphore);

Expand Down
26 changes: 24 additions & 2 deletions runtime/src/iree/hal/drivers/hip/hip_allocator.c
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "iree/base/api.h"
#include "iree/base/tracing.h"
#include "iree/hal/drivers/hip/context_util.h"
#include "iree/hal/drivers/hip/dynamic_symbols.h"
#include "iree/hal/drivers/hip/hip_buffer.h"
#include "iree/hal/drivers/hip/status_util.h"
Expand All @@ -29,6 +30,8 @@ typedef struct iree_hal_hip_allocator_t {
// The HIP stream that allocations should be used in.
hipStream_t stream;

hipCtx_t hip_context;

// NOTE: optional depending on device support.
iree_hal_hip_memory_pools_t* pools;

Expand All @@ -54,11 +57,14 @@ static iree_hal_hip_allocator_t* iree_hal_hip_allocator_cast(

iree_status_t iree_hal_hip_allocator_create(
const iree_hal_hip_dynamic_symbols_t* hip_symbols, hipDevice_t device,
hipStream_t stream, iree_hal_hip_memory_pools_t* pools,
iree_allocator_t host_allocator, iree_hal_allocator_t** out_allocator) {
hipCtx_t hip_context, hipStream_t stream,
iree_hal_hip_memory_pools_t* pools, iree_allocator_t host_allocator,
iree_hal_allocator_t** out_allocator) {
IREE_ASSERT_ARGUMENT(hip_symbols);
IREE_ASSERT_ARGUMENT(out_allocator);
IREE_TRACE_ZONE_BEGIN(z0);
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_hip_set_context(hip_symbols, hip_context));

// To support device-local + host-visible memory we need concurrent managed
// access indicating that the host and devices can concurrently access the
Expand Down Expand Up @@ -94,6 +100,7 @@ iree_status_t iree_hal_hip_allocator_create(
allocator->host_allocator = host_allocator;
allocator->supports_concurrent_managed_access =
supports_concurrent_managed_access != 0;
allocator->hip_context = hip_context;
*out_allocator = (iree_hal_allocator_t*)allocator;

IREE_TRACE_ZONE_END(z0);
Expand Down Expand Up @@ -352,6 +359,9 @@ static iree_status_t iree_hal_hip_allocator_allocate_buffer(
void* host_ptr = NULL;
hipDeviceptr_t device_ptr = NULL;
IREE_TRACE_ZONE_BEGIN_NAMED(z0, "iree_hal_hip_buffer_allocate");
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_hip_set_context(allocator->symbols, allocator->hip_context));

IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, allocation_size);
if (iree_all_bits_set(compat_params.type,
IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL)) {
Expand Down Expand Up @@ -431,6 +441,9 @@ static void iree_hal_hip_allocator_deallocate_buffer(
iree_hal_hip_allocator_t* allocator =
iree_hal_hip_allocator_cast(base_allocator);

IREE_IGNORE_ERROR(
iree_hal_hip_set_context(allocator->symbols, allocator->hip_context));

const iree_hal_hip_buffer_type_t buffer_type =
iree_hal_hip_buffer_type(base_buffer);

Expand Down Expand Up @@ -466,6 +479,9 @@ static iree_status_t iree_hal_hip_allocator_import_buffer(
iree_hal_hip_allocator_t* allocator =
iree_hal_hip_allocator_cast(base_allocator);

IREE_RETURN_IF_ERROR(
iree_hal_hip_set_context(allocator->symbols, allocator->hip_context));

// Coerce options into those required by the current device.
iree_hal_buffer_params_t compat_params = *params;
iree_device_size_t allocation_size = external_buffer->size;
Expand Down Expand Up @@ -600,6 +616,9 @@ iree_status_t iree_hal_hip_allocator_alloc_async(
iree_hal_hip_allocator_t* allocator =
iree_hal_hip_allocator_cast(base_allocator);

IREE_RETURN_IF_ERROR(
iree_hal_hip_set_context(allocator->symbols, allocator->hip_context));

hipDeviceptr_t ptr = NULL;
iree_status_t status = IREE_HIP_RESULT_TO_STATUS(
allocator->symbols,
Expand All @@ -625,6 +644,9 @@ iree_status_t iree_hal_hip_allocator_free_async(
iree_hal_buffer_t* buffer) {
iree_hal_hip_allocator_t* allocator =
iree_hal_hip_allocator_cast(base_allocator);
IREE_RETURN_IF_ERROR(
iree_hal_hip_set_context(allocator->symbols, allocator->hip_context));

hipDeviceptr_t device_ptr = iree_hal_hip_buffer_device_pointer(buffer);
if (!device_ptr) {
return iree_ok_status();
Expand Down
7 changes: 4 additions & 3 deletions runtime/src/iree/hal/drivers/hip/hip_allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@ extern "C" {
#endif // __cplusplus

// Creates a HIP memory allocator.
// |device|, |hip_context|, and |stream| will be used for management operations.
// |hip_context| is made current before the allocator calls into HIP; a NULL
// context leaves the current context unchanged.
// |pools| provides memory pools that may be shared across multiple allocators
// and the pointer must remain valid for the lifetime of the allocator.
// On success the new allocator is returned in |out_allocator|.
iree_status_t iree_hal_hip_allocator_create(
const iree_hal_hip_dynamic_symbols_t* hip_symbols, hipDevice_t device,
hipCtx_t hip_context, hipStream_t stream,
iree_hal_hip_memory_pools_t* pools, iree_allocator_t host_allocator,
iree_hal_allocator_t** out_allocator);

bool iree_hal_hip_allocator_isa(iree_hal_allocator_t* base_value);

Expand Down
Loading
Loading