From adc2b8e527ad17e436039c15cd4d97f4744161fd Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Sun, 28 Apr 2024 16:54:31 +0800 Subject: [PATCH] [Torch] emit aten.__contains__.str_list and add folder --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 25 ++++++++++++ .../torch-mlir/Dialect/Torch/IR/TorchOps.h | 31 ++++++++++++++ lib/Dialect/Torch/IR/TorchOps.cpp | 24 +++++++++++ .../build_tools/torch_ods_gen.py | 1 + test/Dialect/Torch/canonicalize.mlir | 40 +++++++++++++++---- 5 files changed, 113 insertions(+), 8 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 281637f155e9..13ac111c35f6 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -13579,6 +13579,31 @@ def Torch_AtenWarnOp : Torch_Op<"aten.warn", [ }]; } +def Torch_Aten__Contains__StrListOp : Torch_Op<"aten.__contains__.str_list", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::__contains__.str_list : (str[], str) -> (bool)`"; + let arguments = (ins + AnyTorchListOfTorchStringType:$l, + Torch_StringType:$item + ); + let results = (outs + Torch_BoolType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten__Contains__StrListOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void Aten__Contains__StrListOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasFolder = 1; +} + def Torch_AtenFloatScalarOp : Torch_Op<"aten.Float.Scalar", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.h b/include/torch-mlir/Dialect/Torch/IR/TorchOps.h index 4508518bf297..f49fef0721c2 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.h +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.h @@ -239,6 +239,37 @@ m_TorchListOfConstantBools(SmallVectorImpl &bind_values) { return detail::torch_list_of_constant_bools_op_binder(bind_values); } +namespace detail { +/// Matches the constant strs stored in a `torch.ListConstruct`. +struct torch_list_of_constant_strs_op_binder { + SmallVectorImpl &bind_values; + + /// Creates a matcher instance that binds the value to bvs if match succeeds. + torch_list_of_constant_strs_op_binder(SmallVectorImpl &bvs) + : bind_values(bvs) {} + + bool match(Operation *op) { + auto listConstruct = dyn_cast(op); + if (!listConstruct) + return false; + for (Value value : listConstruct.getElements()) { + std::string str; + if (matchPattern(value, m_TorchConstantStr(str))) + bind_values.push_back(str); + else + return false; + } + return true; + } +}; +} // namespace detail + +/// Matches the constant strs stored in a `torch.prim.ListConstruct`. +inline detail::torch_list_of_constant_strs_op_binder +m_TorchListOfConstantStrs(SmallVectorImpl &bind_values) { + return detail::torch_list_of_constant_strs_op_binder(bind_values); +} + namespace detail { /// Matches the expected tensor and dim from `torch.aten.size.int`. struct torch_tensor_size_int_op_binder { diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 33079e35fda1..376e7dd2e584 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -2385,6 +2385,30 @@ OpFoldResult AtenNeStrOp::fold(FoldAdaptor adaptor) { return nullptr; } +//===----------------------------------------------------------------------===// +// Aten__Contains__StrListOp +//===----------------------------------------------------------------------===// + +OpFoldResult Aten__Contains__StrListOp::fold(FoldAdaptor adaptor) { + StringAttr item = dyn_cast(adaptor.getItem()); + if (!item) + return nullptr; + + if (auto listConstruct = getL().getDefiningOp()) { + if (isListPotentiallyMutated(listConstruct)) + return nullptr; + } + llvm::SmallVector strs; + if (matchPattern(getL(), m_TorchListOfConstantStrs(strs))) { + for (const auto &str : strs) { + if (item.getValue().str() == str) + return getI1IntegerAttr(getContext(), true); + } + return getI1IntegerAttr(getContext(), false); + } + return nullptr; +} + //===----------------------------------------------------------------------===// // AtenLtIntOp //===----------------------------------------------------------------------===// diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 3ebd007535e5..39a4480fdf4d 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -973,6 +973,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::format : (...) -> (str)") emit("aten::join : (str, str[]) -> (str)") emit("aten::warn : (str, int) -> ()") + emit("aten::__contains__.str_list : (str[], str) -> (bool)", has_folder=True) # Type conversion ops. emit("aten::Float.Scalar : (Scalar) -> (float)", has_folder=True) diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 7fd4e9832394..a317e4011b3e 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -504,8 +504,8 @@ func.func @torch.aten.eq.str$different_value() -> !torch.bool { // CHECK-LABEL: func.func @torch.aten.eq.str$same_operand( // CHECK-SAME: %{{.*}}: !torch.str) -> !torch.bool { -// CHECK-NEXT: %[[F:.*]] = torch.constant.bool true -// CHECK-NEXT: return %[[F]] : !torch.bool +// CHECK-NEXT: %[[TRUE:.*]] = torch.constant.bool true +// CHECK-NEXT: return %[[TRUE]] : !torch.bool func.func @torch.aten.eq.str$same_operand(%arg0: !torch.str) -> !torch.bool { %0 = torch.aten.eq.str %arg0, %arg0 : !torch.str, !torch.str -> !torch.bool return %0 : !torch.bool @@ -522,8 +522,8 @@ func.func @torch.aten.eq.str$same_value() -> !torch.bool { } // CHECK-LABEL: func.func @torch.aten.ne.str$different_value() -> !torch.bool { -// CHECK: %[[FALSE:.*]] = torch.constant.bool true -// CHECK: return %[[FALSE]] : !torch.bool +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: return %[[TRUE]] : !torch.bool func.func @torch.aten.ne.str$different_value() -> !torch.bool { %str4 = torch.constant.str "4" %str5 = torch.constant.str "5" @@ -533,16 +533,16 @@ func.func @torch.aten.ne.str$different_value() -> !torch.bool { // CHECK-LABEL: func.func @torch.aten.ne.str$same_operand( // CHECK-SAME: %{{.*}}: !torch.str) -> !torch.bool { -// CHECK-NEXT: %[[F:.*]] = torch.constant.bool false -// CHECK-NEXT: return %[[F]] : !torch.bool +// CHECK-NEXT: %[[FALSE:.*]] = torch.constant.bool false +// CHECK-NEXT: return %[[FALSE]] : !torch.bool func.func @torch.aten.ne.str$same_operand(%arg0: !torch.str) -> !torch.bool { %0 = torch.aten.ne.str %arg0, %arg0 : !torch.str, !torch.str -> !torch.bool return %0 : !torch.bool } // CHECK-LABEL: func.func @torch.aten.ne.str$same_value() -> !torch.bool { -// CHECK: %[[TRUE:.*]] = torch.constant.bool false -// CHECK: return %[[TRUE]] : !torch.bool +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: return %[[FALSE]] : !torch.bool func.func @torch.aten.ne.str$same_value() -> !torch.bool { %str4 = torch.constant.str "4" %str4_0 = torch.constant.str "4" @@ -568,6 +568,30 @@ func.func @torch.aten.len.str$empty() -> !torch.int { return %2 : !torch.int } +// CHECK-LABEL: func.func @torch.aten.__contains__.str_list$false() -> !torch.bool { +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: return %[[FALSE]] : !torch.bool +func.func @torch.aten.__contains__.str_list$false() -> !torch.bool { + %str = torch.constant.str "c" + %str_0 = torch.constant.str "b" + %str_1 = torch.constant.str "a" + %1 = torch.prim.ListConstruct %str_1, %str_0 : (!torch.str, !torch.str) -> !torch.list + %2 = torch.aten.__contains__.str_list %1, %str : !torch.list, !torch.str -> !torch.bool + return %2 : !torch.bool +} + +// CHECK-LABEL: func.func @torch.aten.__contains__.str_list$true() -> !torch.bool { +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: return %[[TRUE]] : !torch.bool +func.func @torch.aten.__contains__.str_list$true() -> !torch.bool { + %str = torch.constant.str "aa" + %str_0 = torch.constant.str "aa" + %str_1 = torch.constant.str "ccc" + %1 = torch.prim.ListConstruct %str_1, %str_0 : (!torch.str, !torch.str) -> !torch.list + %2 = torch.aten.__contains__.str_list %1, %str : !torch.list, !torch.str -> !torch.bool + return %2 : !torch.bool +} + // CHECK-LABEL: func.func @torch.aten.__not__ // CHECK: %[[TRUE:.*]] = torch.constant.bool true // CHECK: return %[[TRUE]] : !torch.bool