From 385e488d0a62b07b7b296c9942891e2caadb633e Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Tue, 5 Nov 2024 08:29:08 -0800 Subject: [PATCH] Adding fill/update/copy HAL ops. (#19026) These ops use the newer style of 64-bit flags. TODOs were added to hal.imports.mlir for future cleanup to existing ops whenever we want to bump the version. Signed-off-by: Giacomo Serafini <179146510+giacs-epic@users.noreply.github.com> --- .../Conversion/HALToVM/ConvertDeviceOps.cpp | 50 ++++++ .../Conversion/HALToVM/test/device_ops.mlir | 142 ++++++++++++++++++ .../iree/compiler/Dialect/HAL/IR/HALOps.cpp | 12 ++ .../iree/compiler/Dialect/HAL/IR/HALOps.td | 132 ++++++++++++++++ .../Dialect/HAL/IR/test/device_ops.mlir | 103 +++++++++++++ .../compiler/Dialect/HAL/hal.imports.mlir | 64 +++++++- runtime/src/iree/modules/hal/exports.inl | 3 + runtime/src/iree/modules/hal/module.c | 71 +++++++++ runtime/src/iree/vm/shims.c | 2 + runtime/src/iree/vm/shims.h | 28 ++++ 10 files changed, 605 insertions(+), 2 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDeviceOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDeviceOps.cpp index 915e5f2a54137..3af63b32abfc6 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDeviceOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDeviceOps.cpp @@ -115,6 +115,50 @@ class DeviceQueryI64OpConversion mutable IREE::VM::ImportOp importOp; }; +class DeviceQueueFillOpConversion + : public OpConversionPattern { +public: + DeviceQueueFillOpConversion(MLIRContext *context, SymbolTable &importSymbols, + TypeConverter &typeConverter, + StringRef importName) + : OpConversionPattern(context) { + importOp = importSymbols.lookup(importName); + assert(importOp); + } + + LogicalResult + matchAndRewrite(IREE::HAL::DeviceQueueFillOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto importType = importOp.getFunctionType(); + auto i64Type = rewriter.getI64Type(); + auto patternLength = rewriter.create( + op.getLoc(), + llvm::divideCeil(op.getPattern().getType().getIntOrFloatBitWidth(), 8)); + auto flags = + rewriter.create(op.getLoc(), op.getFlags()); + std::array callOperands = { + adaptor.getDevice(), + castToImportType(adaptor.getQueueAffinity(), i64Type, rewriter), + adaptor.getWaitFence(), + adaptor.getSignalFence(), + adaptor.getTargetBuffer(), + castToImportType(adaptor.getTargetOffset(), i64Type, rewriter), + castToImportType(adaptor.getLength(), i64Type, rewriter), + castToImportType(adaptor.getPattern(), i64Type, rewriter), + patternLength, + flags, + }; + auto callOp = rewriter.replaceOpWithNewOp( + op, SymbolRefAttr::get(importOp), importType.getResults(), + callOperands); + copyImportAttrs(importOp, callOp); + return success(); + } + +private: + mutable IREE::VM::ImportOp importOp; +}; + class DeviceQueueExecuteIndirectOpConversion : public OpConversionPattern { public: @@ -185,6 +229,12 @@ void populateHALDeviceToVMPatterns(MLIRContext *context, context, importSymbols, typeConverter, "hal.device.queue.alloca"); patterns.insert>( context, importSymbols, typeConverter, "hal.device.queue.dealloca"); + patterns.insert( + context, importSymbols, typeConverter, "hal.device.queue.fill"); + patterns.insert>( + context, importSymbols, typeConverter, "hal.device.queue.update"); + patterns.insert>( + context, importSymbols, typeConverter, "hal.device.queue.copy"); patterns.insert>( context, importSymbols, typeConverter, "hal.device.queue.read"); patterns.insert>( diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/device_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/device_ops.mlir index a0052cbce09ae..608cacc1f278f 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/device_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/device_ops.mlir @@ -141,6 +141,148 @@ util.func public @device_queue_dealloca( // ----- +// CHECK-LABEL: @device_queue_fill_i8 +util.func public @device_queue_fill_i8( + // CHECK-SAME: (%[[DEVICE:.+]]: !vm.ref, %[[AFFINITY:.+]]: i64, + %device: !hal.device, %affinity: i64, + // CHECK-SAME: %[[WAIT_FENCE:.+]]: !vm.ref, %[[SIGNAL_FENCE:.+]]: !vm.ref, + %wait_fence: !hal.fence, %signal_fence: !hal.fence, + // CHECK-SAME: %[[PATTERN_I8_I32:.+]]: i32, + %pattern_i8: i8, + // CHECK-SAME: %[[TARGET_BUFFER:.+]]: !vm.ref) + %target_buffer: !hal.buffer) { + // CHECK-DAG: %[[TARGET_OFFSET:.+]] = vm.const.i64 200 + %target_offset = arith.constant 200 : index + // CHECK-DAG: %[[LENGTH:.+]] = vm.const.i64 300 + %length = arith.constant 300 : index + // CHECK-DAG: %[[PATTERN_LENGTH:.+]] = vm.const.i32 1 + // CHECK-DAG: %[[FLAGS:.+]] = vm.const.i64.zero + // CHECK-DAG: %[[PATTERN_I8_I64:.+]] = vm.ext.i32.i64.s %[[PATTERN_I8_I32]] + // CHECK: vm.call @hal.device.queue.fill( + // CHECK-SAME: %[[DEVICE]], %[[AFFINITY]], + // CHECK-SAME: %[[WAIT_FENCE]], %[[SIGNAL_FENCE]], + // CHECK-SAME: %[[TARGET_BUFFER]], %[[TARGET_OFFSET]], + // CHECK-SAME: %[[LENGTH]], + // CHECK-SAME: %[[PATTERN_I8_I64]], %[[PATTERN_LENGTH]], + // CHECK-SAME: %[[FLAGS]]) + hal.device.queue.fill<%device : !hal.device> + affinity(%affinity) + wait(%wait_fence) signal(%signal_fence) + target(%target_buffer : !hal.buffer)[%target_offset] + length(%length) + pattern(%pattern_i8 : i8) + flags(0) + util.return +} + +// ----- + +// CHECK-LABEL: @device_queue_fill_i32 +util.func public @device_queue_fill_i32( + // CHECK-SAME: (%[[DEVICE:.+]]: !vm.ref, %[[AFFINITY:.+]]: i64, + %device: !hal.device, %affinity: i64, + // CHECK-SAME: %[[WAIT_FENCE:.+]]: !vm.ref, %[[SIGNAL_FENCE:.+]]: !vm.ref, + %wait_fence: !hal.fence, %signal_fence: !hal.fence, + // CHECK-SAME: %[[PATTERN_I32:.+]]: i32, + %pattern_i32: i32, + // CHECK-SAME: %[[TARGET_BUFFER:.+]]: !vm.ref) + %target_buffer: !hal.buffer) { + // CHECK-DAG: %[[TARGET_OFFSET:.+]] = vm.const.i64 200 + %target_offset = arith.constant 200 : index + // CHECK-DAG: %[[LENGTH:.+]] = vm.const.i64 300 + %length = arith.constant 300 : index + // CHECK-DAG: %[[PATTERN_LENGTH:.+]] = vm.const.i32 4 + // CHECK-DAG: %[[FLAGS:.+]] = vm.const.i64.zero + // CHECK-DAG: %[[PATTERN_I32_I64:.+]] = vm.ext.i32.i64.s %[[PATTERN_I32]] + // CHECK: vm.call @hal.device.queue.fill( + // CHECK-SAME: %[[DEVICE]], %[[AFFINITY]], + // CHECK-SAME: %[[WAIT_FENCE]], %[[SIGNAL_FENCE]], + // CHECK-SAME: %[[TARGET_BUFFER]], %[[TARGET_OFFSET]], + // CHECK-SAME: %[[LENGTH]], + // CHECK-SAME: %[[PATTERN_I32_I64]], %[[PATTERN_LENGTH]], + // CHECK-SAME: %[[FLAGS]]) + hal.device.queue.fill<%device : !hal.device> + affinity(%affinity) + wait(%wait_fence) signal(%signal_fence) + target(%target_buffer : !hal.buffer)[%target_offset] + length(%length) + pattern(%pattern_i32 : i32) + flags(0) + util.return +} + +// ----- + +// CHECK-LABEL: @device_queue_update +util.func public @device_queue_update( + // CHECK-SAME: (%[[DEVICE:.+]]: !vm.ref, %[[AFFINITY:.+]]: i64, + %device: !hal.device, %affinity: i64, + // CHECK-SAME: %[[WAIT_FENCE:.+]]: !vm.ref, %[[SIGNAL_FENCE:.+]]: !vm.ref, + %wait_fence: !hal.fence, %signal_fence: !hal.fence, + // CHECK-SAME: %[[SOURCE_BUFFER:.+]]: !vm.buffer, + %source_buffer: !util.buffer, + // CHECK-SAME: %[[TARGET_BUFFER:.+]]: !vm.ref) + %target_buffer: !hal.buffer) { + // CHECK-DAG: %[[SOURCE_OFFSET:.+]] = vm.const.i64 100 + %source_offset = arith.constant 100 : index + // CHECK-DAG: %[[TARGET_OFFSET:.+]] = vm.const.i64 200 + %target_offset = arith.constant 200 : index + // CHECK-DAG: %[[LENGTH:.+]] = vm.const.i64 300 + %length = arith.constant 300 : index + // CHECK-DAG: %[[FLAGS:.+]] = vm.const.i64.zero + // CHECK: vm.call @hal.device.queue.update( + // CHECK-SAME: %[[DEVICE]], %[[AFFINITY]], + // CHECK-SAME: %[[WAIT_FENCE]], %[[SIGNAL_FENCE]], + // CHECK-SAME: %[[SOURCE_BUFFER]], %[[SOURCE_OFFSET]], + // CHECK-SAME: %[[TARGET_BUFFER]], %[[TARGET_OFFSET]], + // CHECK-SAME: %[[LENGTH]], %[[FLAGS]]) + hal.device.queue.update<%device : !hal.device> + affinity(%affinity) + wait(%wait_fence) signal(%signal_fence) + source(%source_buffer : !util.buffer)[%source_offset] + target(%target_buffer : !hal.buffer)[%target_offset] + length(%length) + flags(0) + util.return +} + +// ----- + +// CHECK-LABEL: @device_queue_copy +util.func public @device_queue_copy( + // CHECK-SAME: (%[[DEVICE:.+]]: !vm.ref, %[[AFFINITY:.+]]: i64, + %device: !hal.device, %affinity: i64, + // CHECK-SAME: %[[WAIT_FENCE:.+]]: !vm.ref, %[[SIGNAL_FENCE:.+]]: !vm.ref, + %wait_fence: !hal.fence, %signal_fence: !hal.fence, + // CHECK-SAME: %[[SOURCE_BUFFER:.+]]: !vm.ref, + %source_buffer: !hal.buffer, + // CHECK-SAME: %[[TARGET_BUFFER:.+]]: !vm.ref) + %target_buffer: !hal.buffer) { + // CHECK-DAG: %[[SOURCE_OFFSET:.+]] = vm.const.i64 100 + %source_offset = arith.constant 100 : index + // CHECK-DAG: %[[TARGET_OFFSET:.+]] = vm.const.i64 200 + %target_offset = arith.constant 200 : index + // CHECK-DAG: %[[LENGTH:.+]] = vm.const.i64 300 + %length = arith.constant 300 : index + // CHECK-DAG: %[[FLAGS:.+]] = vm.const.i64.zero + // CHECK: vm.call @hal.device.queue.copy( + // CHECK-SAME: %[[DEVICE]], %[[AFFINITY]], + // CHECK-SAME: %[[WAIT_FENCE]], %[[SIGNAL_FENCE]], + // CHECK-SAME: %[[SOURCE_BUFFER]], %[[SOURCE_OFFSET]], + // CHECK-SAME: %[[TARGET_BUFFER]], %[[TARGET_OFFSET]], + // CHECK-SAME: %[[LENGTH]], %[[FLAGS]]) + hal.device.queue.copy<%device : !hal.device> + affinity(%affinity) + wait(%wait_fence) signal(%signal_fence) + source(%source_buffer : !hal.buffer)[%source_offset] + target(%target_buffer : !hal.buffer)[%target_offset] + length(%length) + flags(0) + util.return +} + +// ----- + // CHECK-LABEL: @device_queue_read util.func public @device_queue_read( // CHECK-SAME: (%[[DEVICE:.+]]: !vm.ref, %[[AFFINITY:.+]]: i64, diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp index 9f77e8e7f1561..7210d402598d5 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp @@ -1319,6 +1319,18 @@ LogicalResult DeviceQueueDeallocaOp::verify() { return verifyDeviceQueueFences(*this, getWaitFence(), getSignalFence()); } +LogicalResult DeviceQueueFillOp::verify() { + return verifyDeviceQueueFences(*this, getWaitFence(), getSignalFence()); +} + +LogicalResult DeviceQueueUpdateOp::verify() { + return verifyDeviceQueueFences(*this, getWaitFence(), getSignalFence()); +} + +LogicalResult DeviceQueueCopyOp::verify() { + return verifyDeviceQueueFences(*this, getWaitFence(), getSignalFence()); +} + LogicalResult DeviceQueueReadOp::verify() { return verifyDeviceQueueFences(*this, getWaitFence(), getSignalFence()); } diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td index c1d9a4ea5b561..25b9753070fd8 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td @@ -1851,6 +1851,138 @@ def HAL_DeviceQueueDeallocaOp : HAL_Op<"device.queue.dealloca"> { let hasVerifier = 1; } +def HAL_DeviceQueueFillOp : HAL_Op<"device.queue.fill"> { + let summary = [{fills a buffer with a repeating pattern}]; + let description = [{ + The target buffer must be visible to the device queue performing the update. + In most cases the queue affinity should be set to where the target buffer + will be consumed so that it has a chance of being cached. + + Note that individual queue transfer operations have a high overhead and they + should be batched with other operations in command buffers. + }]; + + let arguments = (ins + HAL_Device:$device, + HAL_DeviceQueueAffinity:$queue_affinity, + HAL_Fence:$wait_fence, + HAL_Fence:$signal_fence, + HAL_Buffer:$target_buffer, + HAL_DeviceSize:$target_offset, + HAL_DeviceSize:$length, + HAL_FillPatternType:$pattern, + I64Attr:$flags + ); + let results = (outs); + + let assemblyFormat = [{ + `<` $device `:` type($device) `>` + `affinity` `(` $queue_affinity `)` + `wait` `(` $wait_fence `)` + `signal` `(` $signal_fence `)` + `target` `(` $target_buffer `:` type($target_buffer) `)` + `` `[` $target_offset `]` + `length` `(` $length `)` + `pattern` `(` $pattern `:` type($pattern) `)` + `flags` `(` $flags `)` + attr-dict-with-keyword + }]; + + let hasVerifier = 1; +} + +def HAL_DeviceQueueUpdateOp : HAL_Op<"device.queue.update"> { + let summary = [{updates a buffer with the contents of a host buffer}]; + let description = [{ + The provided host source buffer will be captured and need not remain live or + unchanged while the operation is queued. The target buffer must be visible + to the device queue performing the update. In most cases the queue affinity + should be set to where the target buffer will be consumed so that it has a + chance of being cached. + + Some implementations may have limits on the size of the update or may + perform poorly if the size is larger than an implementation-defined limit. + Updates should be kept as small and infrequent as possible. + + Note that individual queue transfer operations have a high overhead and they + should be batched with other operations in command buffers. + }]; + + let arguments = (ins + HAL_Device:$device, + HAL_DeviceQueueAffinity:$queue_affinity, + HAL_Fence:$wait_fence, + HAL_Fence:$signal_fence, + Util_BufferType:$source_buffer, + HAL_DeviceSize:$source_offset, + HAL_Buffer:$target_buffer, + HAL_DeviceSize:$target_offset, + HAL_DeviceSize:$length, + I64Attr:$flags + ); + let results = (outs); + + let assemblyFormat = [{ + `<` $device `:` type($device) `>` + `affinity` `(` $queue_affinity `)` + `wait` `(` $wait_fence `)` + `signal` `(` $signal_fence `)` + `source` `(` $source_buffer `:` type($source_buffer) `)` + `` `[` $source_offset `]` + `target` `(` $target_buffer `:` type($target_buffer) `)` + `` `[` $target_offset `]` + `length` `(` $length `)` + `flags` `(` $flags `)` + attr-dict-with-keyword + }]; + + let hasVerifier = 1; +} + +def HAL_DeviceQueueCopyOp : HAL_Op<"device.queue.copy"> { + let summary = [{copies one device-visible buffer to another}]; + let description = [{ + The source buffer and target buffer must both be visible to the device + queue performing the copy. In most cases the queue affinity should be set to + where the target buffer will be consumed so that it has a chance of being + cached. The source buffer must have transfer-source usage and the target + buffer must have transfer-target usage. + + Note that individual queue transfer operations have a high overhead and they + should be batched with other operations in command buffers. + }]; + + let arguments = (ins + HAL_Device:$device, + HAL_DeviceQueueAffinity:$queue_affinity, + HAL_Fence:$wait_fence, + HAL_Fence:$signal_fence, + HAL_Buffer:$source_buffer, + HAL_DeviceSize:$source_offset, + HAL_Buffer:$target_buffer, + HAL_DeviceSize:$target_offset, + HAL_DeviceSize:$length, + I64Attr:$flags + ); + let results = (outs); + + let assemblyFormat = [{ + `<` $device `:` type($device) `>` + `affinity` `(` $queue_affinity `)` + `wait` `(` $wait_fence `)` + `signal` `(` $signal_fence `)` + `source` `(` $source_buffer `:` type($source_buffer) `)` + `` `[` $source_offset `]` + `target` `(` $target_buffer `:` type($target_buffer) `)` + `` `[` $target_offset `]` + `length` `(` $length `)` + `flags` `(` $flags `)` + attr-dict-with-keyword + }]; + + let hasVerifier = 1; +} + def HAL_DeviceQueueReadOp : HAL_Op<"device.queue.read"> { let summary = [{reads a segment from a file into a device buffer}]; let description = [{ diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/device_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/device_ops.mlir index 93a31acbbe99b..fd590cd609403 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/device_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/device_ops.mlir @@ -67,6 +67,109 @@ util.func public @device_queue_dealloca( // ----- +// CHECK-LABEL: @device_queue_fill +util.func public @device_queue_fill( + // CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, %[[AFFINITY:.+]]: i64, + %device: !hal.device, %affinity: i64, + // CHECK-SAME: %[[WAIT_FENCE:.+]]: !hal.fence, %[[SIGNAL_FENCE:.+]]: !hal.fence, + %wait_fence: !hal.fence, %signal_fence: !hal.fence, + // CHECK-SAME: %[[PATTERN_I8:.+]]: i8, + %pattern_i8: i8, + // CHECK-SAME: %[[TARGET_BUFFER:.+]]: !hal.buffer) + %target_buffer: !hal.buffer) { + // CHECK-DAG: %[[TARGET_OFFSET:.+]] = arith.constant 200 + %target_offset = arith.constant 200 : index + // CHECK-DAG: %[[LENGTH:.+]] = arith.constant 300 + %length = arith.constant 300 : index + // CHECK: hal.device.queue.fill<%[[DEVICE]] : !hal.device> + hal.device.queue.fill<%device : !hal.device> + // CHECK-SAME: affinity(%[[AFFINITY]]) + affinity(%affinity) + // CHECK-SAME: wait(%[[WAIT_FENCE]]) signal(%[[SIGNAL_FENCE]]) + wait(%wait_fence) signal(%signal_fence) + // CHECK-SAME: target(%[[TARGET_BUFFER]] : !hal.buffer)[%[[TARGET_OFFSET]]] + target(%target_buffer : !hal.buffer)[%target_offset] + // CHECK-SAME: length(%[[LENGTH]]) + length(%length) + // CHECK-SAME: pattern(%[[PATTERN_I8]] : i8) + pattern(%pattern_i8 : i8) + // CHECK-SAME: flags(0) + flags(0) + util.return +} + +// ----- + +// CHECK-LABEL: @device_queue_update +util.func public @device_queue_update( + // CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, %[[AFFINITY:.+]]: i64, + %device: !hal.device, %affinity: i64, + // CHECK-SAME: %[[WAIT_FENCE:.+]]: !hal.fence, %[[SIGNAL_FENCE:.+]]: !hal.fence, + %wait_fence: !hal.fence, %signal_fence: !hal.fence, + // CHECK-SAME: %[[SOURCE_BUFFER:.+]]: !util.buffer, + %source_buffer: !util.buffer, + // CHECK-SAME: %[[TARGET_BUFFER:.+]]: !hal.buffer) + %target_buffer: !hal.buffer) { + // CHECK-DAG: %[[SOURCE_OFFSET:.+]] = arith.constant 100 + %source_offset = arith.constant 100 : index + // CHECK-DAG: %[[TARGET_OFFSET:.+]] = arith.constant 200 + %target_offset = arith.constant 200 : index + // CHECK-DAG: %[[LENGTH:.+]] = arith.constant 300 + %length = arith.constant 300 : index + // CHECK: hal.device.queue.update<%[[DEVICE]] : !hal.device> + hal.device.queue.update<%device : !hal.device> + // CHECK-SAME: affinity(%[[AFFINITY]]) + affinity(%affinity) + // CHECK-SAME: wait(%[[WAIT_FENCE]]) signal(%[[SIGNAL_FENCE]]) + wait(%wait_fence) signal(%signal_fence) + // CHECK-SAME: source(%[[SOURCE_BUFFER]] : !util.buffer)[%[[SOURCE_OFFSET]]] + source(%source_buffer : !util.buffer)[%source_offset] + // CHECK-SAME: target(%[[TARGET_BUFFER]] : !hal.buffer)[%[[TARGET_OFFSET]]] + target(%target_buffer : !hal.buffer)[%target_offset] + // CHECK-SAME: length(%[[LENGTH]]) + length(%length) + // CHECK-SAME: flags(0) + flags(0) + util.return +} + +// ----- + +// CHECK-LABEL: @device_queue_copy +util.func public @device_queue_copy( + // CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, %[[AFFINITY:.+]]: i64, + %device: !hal.device, %affinity: i64, + // CHECK-SAME: %[[WAIT_FENCE:.+]]: !hal.fence, %[[SIGNAL_FENCE:.+]]: !hal.fence, + %wait_fence: !hal.fence, %signal_fence: !hal.fence, + // CHECK-SAME: %[[SOURCE_BUFFER:.+]]: !hal.buffer, + %source_buffer: !hal.buffer, + // CHECK-SAME: %[[TARGET_BUFFER:.+]]: !hal.buffer) + %target_buffer: !hal.buffer) { + // CHECK-DAG: %[[SOURCE_OFFSET:.+]] = arith.constant 100 + %source_offset = arith.constant 100 : index + // CHECK-DAG: %[[TARGET_OFFSET:.+]] = arith.constant 200 + %target_offset = arith.constant 200 : index + // CHECK-DAG: %[[LENGTH:.+]] = arith.constant 300 + %length = arith.constant 300 : index + // CHECK: hal.device.queue.copy<%[[DEVICE]] : !hal.device> + hal.device.queue.copy<%device : !hal.device> + // CHECK-SAME: affinity(%[[AFFINITY]]) + affinity(%affinity) + // CHECK-SAME: wait(%[[WAIT_FENCE]]) signal(%[[SIGNAL_FENCE]]) + wait(%wait_fence) signal(%signal_fence) + // CHECK-SAME: source(%[[SOURCE_BUFFER]] : !hal.buffer)[%[[SOURCE_OFFSET]]] + source(%source_buffer : !hal.buffer)[%source_offset] + // CHECK-SAME: target(%[[TARGET_BUFFER]] : !hal.buffer)[%[[TARGET_OFFSET]]] + target(%target_buffer : !hal.buffer)[%target_offset] + // CHECK-SAME: length(%[[LENGTH]]) + length(%length) + // CHECK-SAME: flags(0) + flags(0) + util.return +} + +// ----- + // CHECK-LABEL: @device_queue_read util.func public @device_queue_read( // CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, %[[AFFINITY:.+]]: i64, diff --git a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir index a935732d8a6b9..13bdfb6023f50 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir @@ -228,9 +228,12 @@ vm.import private @command_buffer.execution_barrier( %command_buffer : !vm.ref, %source_stage_mask : i32, %target_stage_mask : i32, + // TODO(benvanik): make i64. %flags : i32 ) +// TODO(benvanik): add @command_buffer.advise_buffer. + // Fills the target buffer with the given repeating value. // NOTE: order slightly differs from op in order to get better arg alignment. vm.import private @command_buffer.fill_buffer( @@ -239,8 +242,10 @@ vm.import private @command_buffer.fill_buffer( %target_offset : i64, %length : i64, %target_buffer_slot : i32, + // TODO(benvanik): make pattern i64. %pattern : i32, - %pattern_length: i32 + %pattern_length : i32 + // TODO(benvanik): add %flags : i64. ) // Updates a device buffer with the captured contents of a host buffer. @@ -253,6 +258,7 @@ vm.import private @command_buffer.update_buffer( %target_offset : i64, %length : i64, %target_buffer_slot : i32 + // TODO(benvanik): add %flags : i64. ) // Copies a range of one buffer to another. @@ -266,6 +272,7 @@ vm.import private @command_buffer.copy_buffer( %target_buffer : !vm.ref, %target_offset : i64, %length : i64 + // TODO(benvanik): add %flags : i64. ) // Dispatches a collective operation defined by |op| using the given buffers. @@ -345,7 +352,9 @@ vm.import private @device.queue.alloca( %pool : i32, %memory_types : i32, %buffer_usage : i32, + // TODO(benvanik): add %reserved : i32 (padding). %allocation_size : i64 + // TODO(benvanik): add %flags : i64. ) -> !vm.ref // Deallocates a queue-ordered transient buffer. @@ -356,9 +365,52 @@ vm.import private @device.queue.dealloca( %queue_affinity : i64, %wait_fence : !vm.ref, %signal_fence : !vm.ref, + // TODO(benvanik): add %flags : i64. %buffer : !vm.ref ) +// Enqueues a single queue-ordered fill operation. +vm.import private @device.queue.fill( + %device : !vm.ref, + %queue_affinity : i64, + %wait_fence : !vm.ref, + %signal_fence : !vm.ref, + %target_buffer : !vm.ref, + %target_offset : i64, + %length : i64, + %pattern : i64, + %pattern_length: i32, + %flags : i64 +) + +// Enqueues a single queue-ordered buffer update operation. +vm.import private @device.queue.update( + %device : !vm.ref, + %queue_affinity : i64, + %wait_fence : !vm.ref, + %signal_fence : !vm.ref, + %source_buffer : !vm.buffer, + %source_offset : i64, + %target_buffer : !vm.ref, + %target_offset : i64, + %length : i64, + %flags : i64 +) + +// Enqueues a single queue-ordered buffer copy operation. +vm.import private @device.queue.copy( + %device : !vm.ref, + %queue_affinity : i64, + %wait_fence : !vm.ref, + %signal_fence : !vm.ref, + %source_buffer : !vm.ref, + %source_offset : i64, + %target_buffer : !vm.ref, + %target_offset : i64, + %length : i64, + %flags : i64 +) + // Reads a segment from a file into a device buffer. vm.import private @device.queue.read( %device : !vm.ref, @@ -370,6 +422,7 @@ vm.import private @device.queue.read( %target_buffer : !vm.ref, %target_offset : i64, %length : i64, + // TODO(benvanik): make i64. %flags : i32 ) @@ -384,6 +437,7 @@ vm.import private @device.queue.write( %target_file : !vm.ref, %target_offset : i64, %length : i64, + // TODO(benvanik): make i64. %flags : i32 ) @@ -396,7 +450,9 @@ vm.import private @device.queue.execute( %queue_affinity : i64, %wait_fence : !vm.ref, %signal_fence : !vm.ref, + // TODO(benvanik): make a single command buffer. %command_buffers : !vm.ref... + // TODO(benvanik): add %flags : i64. ) // Executes a command buffer on a device queue with the given binding table. @@ -410,6 +466,7 @@ vm.import private @device.queue.execute.indirect( %command_buffer : !vm.ref, // %binding_table : tuple, i64, i64>... + // TODO(benvanik): add %flags : i64. ) // Flushes any locally-pending submissions in the queue. @@ -437,6 +494,7 @@ attributes {nosideeffects} // Creates an executable for use with the specified device. vm.import private @executable.create( %device : !vm.ref, + // TODO(benvanik): add %queue_affinity : i64. %executable_format : !vm.buffer, %executable_data : !vm.buffer, %constants : !vm.buffer @@ -452,11 +510,13 @@ attributes { // Returns an unsignaled fence that defines a point in time. vm.import private @fence.create( %device : !vm.ref, + // TODO(benvanik): make i64. %flags : i32 ) -> !vm.ref // Returns a fence that joins the input fences as a wait-all operation. vm.import private @fence.join( + // TODO(benvanik): add %flags : i64? May be worth it to control access. %fences : !vm.ref ... ) -> !vm.ref attributes {nosideeffects} @@ -480,7 +540,7 @@ vm.import private @fence.fail( %status : i32 ) -// Yields the caller until all fences is reached. +// Yields the caller until all fences are reached. vm.import private @fence.await( %timeout_millis : i32, %fences : !vm.ref ... diff --git a/runtime/src/iree/modules/hal/exports.inl b/runtime/src/iree/modules/hal/exports.inl index 46e7f17358004..6c7ba56766ecd 100644 --- a/runtime/src/iree/modules/hal/exports.inl +++ b/runtime/src/iree/modules/hal/exports.inl @@ -61,11 +61,14 @@ EXPORT_FN("command_buffer.update_buffer", iree_hal_module_command_buffer_update_ EXPORT_FN("device.allocator", iree_hal_module_device_allocator, r, r) EXPORT_FN("device.query.i64", iree_hal_module_device_query_i64, rrr, iI) EXPORT_FN("device.queue.alloca", iree_hal_module_device_queue_alloca, rIrriiiI, r) +EXPORT_FN("device.queue.copy", iree_hal_module_device_queue_copy, rIrrrIrIII, v) EXPORT_FN("device.queue.dealloca", iree_hal_module_device_queue_dealloca, rIrrr, v) EXPORT_FN("device.queue.execute", iree_hal_module_device_queue_execute, rIrrCrD, v) EXPORT_FN("device.queue.execute.indirect", iree_hal_module_device_queue_execute_indirect, rIrrrCrIID, v) +EXPORT_FN("device.queue.fill", iree_hal_module_device_queue_fill, rIrrrIIIiI, v) EXPORT_FN("device.queue.flush", iree_hal_module_device_queue_flush, rI, v) EXPORT_FN("device.queue.read", iree_hal_module_device_queue_read, rIrrrIrIIi, v) +EXPORT_FN("device.queue.update", iree_hal_module_device_queue_update, rIrrrIrIII, v) EXPORT_FN("device.queue.write", iree_hal_module_device_queue_write, rIrrrIrIIi, v) EXPORT_FN("devices.count", iree_hal_module_devices_count, v, i) diff --git a/runtime/src/iree/modules/hal/module.c b/runtime/src/iree/modules/hal/module.c index 7bece016674db..3baa27dd91e43 100644 --- a/runtime/src/iree/modules/hal/module.c +++ b/runtime/src/iree/modules/hal/module.c @@ -1196,6 +1196,77 @@ IREE_VM_ABI_EXPORT(iree_hal_module_device_queue_dealloca, // iree_hal_fence_semaphore_list(signal_fence), buffer); } +IREE_VM_ABI_EXPORT(iree_hal_module_device_queue_fill, // + iree_hal_module_state_t, // + rIrrrIIIiI, v) { + iree_hal_device_t* device = NULL; + IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device)); + iree_hal_queue_affinity_t queue_affinity = + (iree_hal_queue_affinity_t)args->i1; + iree_hal_fence_t* wait_fence = iree_hal_fence_deref(args->r2); + iree_hal_fence_t* signal_fence = iree_hal_fence_deref(args->r3); + iree_hal_buffer_t* target_buffer = NULL; + IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r4, &target_buffer)); + iree_device_size_t target_offset = iree_hal_cast_device_size(args->i5); + iree_device_size_t length = iree_hal_cast_device_size(args->i6); + uint64_t pattern = args->i7; + iree_host_size_t pattern_length = iree_hal_cast_host_size(args->i8); + iree_hal_fill_flags_t flags = (iree_hal_fill_flags_t)args->i9; + return iree_hal_device_queue_fill( + device, queue_affinity, iree_hal_fence_semaphore_list(wait_fence), + iree_hal_fence_semaphore_list(signal_fence), target_buffer, target_offset, + length, &pattern, pattern_length, flags); +} + +IREE_VM_ABI_EXPORT(iree_hal_module_device_queue_update, // + iree_hal_module_state_t, // + rIrrrIrIII, v) { + iree_hal_device_t* device = NULL; + IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device)); + iree_hal_queue_affinity_t queue_affinity = + (iree_hal_queue_affinity_t)args->i1; + iree_hal_fence_t* wait_fence = iree_hal_fence_deref(args->r2); + iree_hal_fence_t* signal_fence = iree_hal_fence_deref(args->r3); + iree_vm_buffer_t* source_buffer = NULL; + IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r4, &source_buffer)); + iree_host_size_t source_offset = iree_hal_cast_host_size(args->i5); + iree_hal_buffer_t* target_buffer = NULL; + IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r6, &target_buffer)); + iree_device_size_t target_offset = iree_hal_cast_device_size(args->i7); + iree_device_size_t length = iree_hal_cast_device_size(args->i8); + iree_hal_copy_flags_t flags = (iree_hal_copy_flags_t)args->i9; + iree_const_byte_span_t source_span = iree_const_byte_span_empty(); + IREE_RETURN_IF_ERROR(iree_vm_buffer_map_ro(source_buffer, source_offset, + length, 1, &source_span)); + return iree_hal_device_queue_update( + device, queue_affinity, iree_hal_fence_semaphore_list(wait_fence), + iree_hal_fence_semaphore_list(signal_fence), source_span.data, 0, + target_buffer, target_offset, length, flags); +} + +IREE_VM_ABI_EXPORT(iree_hal_module_device_queue_copy, // + iree_hal_module_state_t, // + rIrrrIrIII, v) { + iree_hal_device_t* device = NULL; + IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device)); + iree_hal_queue_affinity_t queue_affinity = + (iree_hal_queue_affinity_t)args->i1; + iree_hal_fence_t* wait_fence = iree_hal_fence_deref(args->r2); + iree_hal_fence_t* signal_fence = iree_hal_fence_deref(args->r3); + iree_hal_buffer_t* source_buffer = NULL; + IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r4, &source_buffer)); + iree_device_size_t source_offset = iree_hal_cast_device_size(args->i5); + iree_hal_buffer_t* target_buffer = NULL; + IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r6, &target_buffer)); + iree_device_size_t target_offset = iree_hal_cast_device_size(args->i7); + iree_device_size_t length = iree_hal_cast_device_size(args->i8); + iree_hal_copy_flags_t flags = (iree_hal_copy_flags_t)args->i9; + return iree_hal_device_queue_copy( + device, queue_affinity, iree_hal_fence_semaphore_list(wait_fence), + iree_hal_fence_semaphore_list(signal_fence), source_buffer, source_offset, + target_buffer, target_offset, length, flags); +} + IREE_VM_ABI_EXPORT(iree_hal_module_device_queue_read, // iree_hal_module_state_t, // rIrrrIrIIi, v) { diff --git a/runtime/src/iree/vm/shims.c b/runtime/src/iree/vm/shims.c index 2509ffa7dc121..48d24d5d01277 100644 --- a/runtime/src/iree/vm/shims.c +++ b/runtime/src/iree/vm/shims.c @@ -70,6 +70,8 @@ IREE_VM_ABI_DEFINE_SHIM(riirIrII, v); IREE_VM_ABI_DEFINE_SHIM(rrIii, v); IREE_VM_ABI_DEFINE_SHIM(rrrIii, v); IREE_VM_ABI_DEFINE_SHIM(rIrriiiI, r); +IREE_VM_ABI_DEFINE_SHIM(rIrrrIIIiI, v); +IREE_VM_ABI_DEFINE_SHIM(rIrrrIrIII, v); IREE_VM_ABI_DEFINE_SHIM(rIrrrIrIIi, v); IREE_VM_ABI_DEFINE_SHIM(rIrrrrrrr, v); IREE_VM_ABI_DEFINE_SHIM(rIrrrIiirrr, r); diff --git a/runtime/src/iree/vm/shims.h b/runtime/src/iree/vm/shims.h index cd14a4645589b..e1da0a70d5edd 100644 --- a/runtime/src/iree/vm/shims.h +++ b/runtime/src/iree/vm/shims.h @@ -458,6 +458,32 @@ IREE_VM_ABI_FIXED_STRUCT(rIrriiiI, { int64_t i7; }); +IREE_VM_ABI_FIXED_STRUCT(rIrrrIIIiI, { + iree_vm_ref_t r0; + int64_t i1; + iree_vm_ref_t r2; + iree_vm_ref_t r3; + iree_vm_ref_t r4; + int64_t i5; + int64_t i6; + int64_t i7; + int32_t i8; + int64_t i9; +}); + +IREE_VM_ABI_FIXED_STRUCT(rIrrrIrIII, { + iree_vm_ref_t r0; + int64_t i1; + iree_vm_ref_t r2; + iree_vm_ref_t r3; + iree_vm_ref_t r4; + int64_t i5; + iree_vm_ref_t r6; + int64_t i7; + int64_t i8; + int64_t i9; +}); + IREE_VM_ABI_FIXED_STRUCT(rIrrrIrIIi, { iree_vm_ref_t r0; int64_t i1; @@ -728,6 +754,8 @@ IREE_VM_ABI_DECLARE_SHIM(riirIrII, v); IREE_VM_ABI_DECLARE_SHIM(rrIii, v); IREE_VM_ABI_DECLARE_SHIM(rrrIii, v); IREE_VM_ABI_DECLARE_SHIM(rIrriiiI, r); +IREE_VM_ABI_DECLARE_SHIM(rIrrrIIIiI, v); +IREE_VM_ABI_DECLARE_SHIM(rIrrrIrIII, v); IREE_VM_ABI_DECLARE_SHIM(rIrrrIrIIi, v); IREE_VM_ABI_DECLARE_SHIM(rIrrrrrrr, v); IREE_VM_ABI_DECLARE_SHIM(rIrrrIiirrr, r);