Plumb i1 datatype through the compilation pipeline #18483

Closed
lialan opened this issue Sep 10, 2024 · 7 comments · Fixed by #18713

lialan commented Sep 10, 2024

Targeting a very simple function:

func.func @add_tensors(%arg0: tensor<2x32x32x32xi1>, %arg1: tensor<2x32x32x32xi1>) -> tensor<2x32x32x32xi1> {
  %result = arith.xori %arg0, %arg1 : tensor<2x32x32x32xi1>
  return %result : tensor<2x32x32x32xi1>
}

Compile with:

iree-compile add.mlir --iree-hal-target-backends=llvm-cpu

This needs the following patch, which drops the bitWidth != 1 exclusion so that i1 is packed like the other sub-byte types:

diff --git a/compiler/src/iree/compiler/Utils/ElementPackingUtils.cpp b/compiler/src/iree/compiler/Utils/ElementPackingUtils.cpp
index a7fa1c3f7b..79445148e5 100644
--- a/compiler/src/iree/compiler/Utils/ElementPackingUtils.cpp
+++ b/compiler/src/iree/compiler/Utils/ElementPackingUtils.cpp
@@ -20,7 +20,7 @@ bool needToPackSubByteElementBitWidth(unsigned bitWidth) {
   // trickiness and weirdness of packing and cross-byte access.
   // Also disallow boolean values for now--they may require separate interface
   // choices.
-  return bitWidth < 8 && llvm::isPowerOf2_32(bitWidth) && bitWidth != 1;
+  return bitWidth < 8 && llvm::isPowerOf2_32(bitWidth);
 }
 
 bool needToPackSubByteElements(RankedTensorType shapedType) {
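The effect of the patch on individual bit widths can be sketched in Python. This is only an illustration of the two predicates in the diff: the function names below are hypothetical, and is_power_of_two stands in for llvm::isPowerOf2_32.

```python
def is_power_of_two(n: int) -> bool:
    """Stand-in for llvm::isPowerOf2_32: true for 1, 2, 4, 8, ..."""
    return n > 0 and (n & (n - 1)) == 0


def needs_packing_before(bit_width: int) -> bool:
    # Pre-patch predicate: sub-byte power-of-two widths, with i1 excluded.
    return bit_width < 8 and is_power_of_two(bit_width) and bit_width != 1


def needs_packing_after(bit_width: int) -> bool:
    # Post-patch predicate: i1 is treated like i2 and i4.
    return bit_width < 8 and is_power_of_two(bit_width)


if __name__ == "__main__":
    for width in (1, 2, 3, 4, 8):
        print(width, needs_packing_before(width), needs_packing_after(width))
```

Only the answer for a bit width of 1 changes; i2 and i4 were already packed, and non-power-of-two widths such as i3 are still left alone.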
@lialan lialan self-assigned this Sep 10, 2024

lialan commented Sep 11, 2024

LLVMCPUTileAndFusePass needs to change: it splits any i1 tensor load into 4xi1 (specifically with --pass-pipeline="builtin.module(func.func(iree-llvmcpu-tile-and-fuse{tiling-level=1}))"). This tile size was introduced in LLVMCPUSelectLoweringStrategyPass, so updating the lowering config should be enough.

The tiled result looks okay, but ConvertVectorLoad cannot handle sub-byte sizes.


lialan commented Sep 11, 2024

After hacking the tile config a little, the pipeline can now emit bytecode without issues. Checked narrow type emulation:
before:

  %8 = vector.load %1[%7] : memref<65536xi1>, vector<8xi1>
  %9 = arith.xori %6, %8 : vector<8xi1>

after:

  %9 = vector.load %1[%8] : memref<8192xi8>, vector<1xi8>
  %10 = vector.bitcast %9 : vector<1xi8> to vector<8xi1>
  %11 = arith.xori %7, %10 : vector<8xi1>

seems okay.
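As a rough sanity check of the rewrite above, the packed form can be modelled in Python: 65536 i1 values become 8192 bytes, and each step loads one byte, views it as 8 bits, XORs, and stores one byte back. The helper names are hypothetical, and the bit order (element 0 in the lowest bit) is an assumption made for the sketch, not something the IR above states.

```python
def pack_bits(bits):
    """Pack 0/1 values into bytes, 8 per byte (element 0 in the lowest bit, assumed)."""
    assert len(bits) % 8 == 0
    out = bytearray(len(bits) // 8)
    for i, bit in enumerate(bits):
        out[i // 8] |= (bit & 1) << (i % 8)
    return bytes(out)


def unpack_bits(data):
    """Inverse of pack_bits: expand each byte into 8 values of 0 or 1."""
    return [(byte >> k) & 1 for byte in data for k in range(8)]


def xor_packed(lhs_bytes, rhs_bytes):
    """Mirror the emulated loop: one byte in, 8 x i1 XOR, one byte out."""
    out = bytearray(len(lhs_bytes))
    for idx in range(len(lhs_bytes)):
        a = unpack_bits(lhs_bytes[idx:idx + 1])  # vector.load + vector.bitcast
        b = unpack_bits(rhs_bytes[idx:idx + 1])
        c = [x ^ y for x, y in zip(a, b)]        # arith.xori on 8 x i1
        out[idx] = pack_bits(c)[0]               # bitcast back to one i8
    return bytes(out)


if __name__ == "__main__":
    n = 2 * 32 * 32 * 32  # 65536 elements, as in tensor<2x32x32x32xi1>
    lhs = [(i // 3) % 2 for i in range(n)]
    rhs = [(i // 5) % 2 for i in range(n)]
    result = xor_packed(pack_bits(lhs), pack_bits(rhs))
    print(len(result))
    print(unpack_bits(result) == [a ^ b for a, b in zip(lhs, rhs)])
```

The packed result unpacks to the same values as an element-wise XOR on the unpacked inputs, which is the property the emulation has to preserve.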


lialan commented Sep 11, 2024

IR dump after EmulateNarrowTypePass for the simple input above:

// -----// IR Dump After EmulateNarrowTypePass (iree-codegen-emulate-narrow-type) //----- //
func.func @add_tensors_dispatch_0_elementwise_65536_i1() attributes {translation_info = #iree_codegen.translation_info<CPUDoubleTilingExpert>} {
  %c8 = arith.constant 8 : index
  %c4096 = arith.constant 4096 : index
  %c0 = arith.constant 0 : index
  %0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<8192xi8>
  memref.assume_alignment %0, 64 : memref<8192xi8>
  %1 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<8192xi8>
  memref.assume_alignment %1, 64 : memref<8192xi8>
  %2 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%c0) flags(Indirect) : memref<8192xi8>
  memref.assume_alignment %2, 64 : memref<8192xi8>
  %workgroup_id_x = hal.interface.workgroup.id[0] : index
  cf.br ^bb1(%c0 : index)
^bb1(%3: index):  // 2 preds: ^bb0, ^bb2
  %4 = arith.cmpi slt, %3, %c4096 : index
  cf.cond_br %4, ^bb2, ^bb3
^bb2:  // pred: ^bb1
  %5 = affine.apply affine_map<()[s0, s1] -> (s1 * 512 + s0 floordiv 8)>()[%3, %workgroup_id_x]
  %6 = vector.load %0[%5] : memref<8192xi8>, vector<1xi8>
  %7 = vector.bitcast %6 : vector<1xi8> to vector<8xi1>
  %8 = affine.apply affine_map<()[s0, s1] -> (s1 * 512 + s0 floordiv 8)>()[%3, %workgroup_id_x]
  %9 = vector.load %1[%8] : memref<8192xi8>, vector<1xi8>
  %10 = vector.bitcast %9 : vector<1xi8> to vector<8xi1>
  %11 = arith.xori %7, %10 : vector<8xi1>
  %12 = affine.apply affine_map<()[s0, s1] -> (s1 * 512 + s0 floordiv 8)>()[%3, %workgroup_id_x]
  %13 = vector.bitcast %11 : vector<8xi1> to vector<1xi8>
  vector.store %13, %2[%12] : memref<8192xi8>, vector<1xi8>
  %14 = arith.addi %3, %c8 : index
  cf.br ^bb1(%14 : index)
^bb3:  // pred: ^bb1
  return
}
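The byte indexing in this dump can be checked with a small Python sketch. It evaluates the affine map s1 * 512 + s0 floordiv 8 for every loop iteration (0 to 4096 in steps of 8) and every workgroup. The workgroup count of 16 does not appear in the dump; it is assumed here from 8192 bytes divided by 512 bytes per workgroup.

```python
TOTAL_BYTES = 8192          # memref<8192xi8>
BITS_PER_WORKGROUP = 4096   # loop bound %c4096
STEP = 8                    # loop step %c8, i.e. one byte of i1 values
BYTES_PER_WORKGROUP = 512   # coefficient of s1 in the affine map
NUM_WORKGROUPS = TOTAL_BYTES // BYTES_PER_WORKGROUP  # assumed: 16


def byte_index(i: int, workgroup_id_x: int) -> int:
    # affine_map<()[s0, s1] -> (s1 * 512 + s0 floordiv 8)>
    # with s0 = loop index i and s1 = workgroup_id_x.
    return workgroup_id_x * BYTES_PER_WORKGROUP + i // STEP


def touched_bytes():
    """Byte offsets visited by all workgroups over the whole loop."""
    return [
        byte_index(i, wg)
        for wg in range(NUM_WORKGROUPS)
        for i in range(0, BITS_PER_WORKGROUP, STEP)
    ]


if __name__ == "__main__":
    touched = touched_bytes()
    print(len(touched), min(touched), max(touched))
    print(sorted(touched) == list(range(TOTAL_BYTES)))
```

Under that assumption, every byte of the 8192-byte buffer is visited exactly once, with no access straddling a byte boundary.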


lialan commented Sep 12, 2024

The stream dialect requires all tensors to have an integral number of total bytes:

error: 'stream.tensor.sizeof' op failed to calculate total byte count: 'tensor<2x3xi1>' does not have integral number of total bytes
func.func @dot_matmul_i1_i16_i32(%arg0: tensor<2x3xi1>,

I am considering adding a check before we lower to the stream dialect: for i1, i2, i4 and other sub-byte data types, verify that the tensor occupies an integral number of total bytes. If it does not, we pad the tensor.

@benvanik I just want to check with you whether this is sound, and whether you have any remarks on where to make it happen?
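The size check behind that error can be sketched in Python as follows. The helper names are hypothetical, and rounding up to the next whole byte is only one possible reading of "pad the tensor"; the issue does not say which dimension would be padded.

```python
from math import prod


def total_bits(shape, bit_width: int) -> int:
    """Number of bits needed when elements are packed back to back."""
    return prod(shape) * bit_width


def has_integral_byte_count(shape, bit_width: int) -> bool:
    return total_bits(shape, bit_width) % 8 == 0


def padded_byte_count(shape, bit_width: int) -> int:
    """Byte count after rounding the packed size up to a whole byte (assumed policy)."""
    return (total_bits(shape, bit_width) + 7) // 8


if __name__ == "__main__":
    for shape in ((2, 32, 32, 32), (2, 3)):
        print(shape, total_bits(shape, 1),
              has_integral_byte_count(shape, 1),
              padded_byte_count(shape, 1))
```

tensor<2x32x32x32xi1> packs into exactly 8192 bytes, while tensor<2x3xi1> has only 6 bits and would need 2 bits of padding to fill one byte.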

@benvanik
Collaborator

We need an encoding attribute that preserves the existing behavior, so that we can interop with frontends that expect i1 to be stored unpacked in i8. For all other cases we should be able to remove the special case bitWidth != 1 from needToPackSubByteElementBitWidth. calculateStorageElementCountInBytes takes the tensor type, so the encoding can be checked there. We have an encoding dialect and can add a new attribute to it, #encoding.element_padding<8> (or something), that indicates how many bits an element type should be padded to; that way we can support other arbitrary formats. Ideally we end up with an attr interface that does it all, so we can make up whatever we want for composite types, but to start I suspect we just need element bit width padding (then i3 to i8, i9 to i16, or whatever can be controlled).
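To make the element padding idea concrete, here is a minimal Python sketch of how a per-element padding width would change the storage size. The function storage_bytes and its parameters are hypothetical; this is not the signature of calculateStorageElementCountInBytes, and the proposed attribute did not exist at the time of this comment.

```python
from math import prod


def storage_bytes(shape, element_bits: int, padded_element_bits=None) -> int:
    """Hypothetical storage size in bytes.

    With padded_element_bits set, every element occupies that many bits
    (for example i1 stored unpacked in i8). Without it, elements are
    packed back to back at their natural bit width.
    """
    bits_per_element = element_bits if padded_element_bits is None else padded_element_bits
    if bits_per_element < element_bits:
        raise ValueError("padding must not be narrower than the element type")
    bits = prod(shape) * bits_per_element
    if bits % 8 != 0:
        raise ValueError("tensor does not have an integral number of total bytes")
    return bits // 8


if __name__ == "__main__":
    shape = (2, 32, 32, 32)
    print(storage_bytes(shape, 1, padded_element_bits=8))  # i1 unpacked in i8
    print(storage_bytes(shape, 1))                         # i1 packed
    print(storage_bytes((2, 3), 3, padded_element_bits=8)) # i3 padded to i8
    print(storage_bytes((2, 3), 9, padded_element_bits=16))  # i9 padded to i16
```

With padding to 8 bits the i1 tensor keeps the existing one-byte-per-element layout, and without it the same tensor takes one eighth of the space.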

@stellaraccident
Collaborator

We could also change the frontends to lower to i8 plus casts for 8-bit bools. I think aliasing these things was just an error from the dawn of time.

@benvanik
Collaborator

Agreed :)
My concern is metadata checks in the various frontends/parameter formats/etc that aren't clean. If we could ensure consistency there and make it a hard rule I'd feel better.

@lialan lialan linked a pull request Oct 7, 2024 that will close this issue