From 95ffa8a2ac1154d968c0af2d428ade541eddd828 Mon Sep 17 00:00:00 2001 From: Tina Jung Date: Fri, 19 Apr 2024 10:52:31 +0200 Subject: [mlir][emitc] Restrict types in EmitC (#88391) Restrict the types which are valid for EmitC operations. Use what is currently supported by the emitter as restriction. Define a utility functions for valid types, such that they can be used to restrict the operations in table gen as well as being available for reuse in dialect conversions. --- mlir/include/mlir/Dialect/EmitC/IR/EmitC.h | 3 ++ mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 52 ++++++++++++------------ mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td | 3 ++ mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 38 +++++++++++++++++ mlir/test/Dialect/EmitC/invalid_types.mlir | 48 ++++++++++++++++++++++ 5 files changed, 118 insertions(+), 26 deletions(-) diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h index c039156..5d9531c 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h @@ -31,6 +31,9 @@ namespace mlir { namespace emitc { void buildTerminatedBody(OpBuilder &builder, Location loc); +/// Determines whether \p type is valid in EmitC. +bool isSupportedEmitCType(mlir::Type type); + /// Determines whether \p type is a valid integer type in EmitC. bool isSupportedIntegerType(mlir::Type type); diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index e611fd2..c1a1e77 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -34,16 +34,16 @@ class EmitC_Op traits = []> // Base class for unary operations. class EmitC_UnaryOp traits = []> : EmitC_Op { - let arguments = (ins AnyType); - let results = (outs AnyType); + let arguments = (ins EmitCType); + let results = (outs EmitCType); let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; } // Base class for binary operations. class EmitC_BinaryOp traits = []> : EmitC_Op { - let arguments = (ins AnyType:$lhs, AnyType:$rhs); - let results = (outs AnyType); + let arguments = (ins EmitCType:$lhs, EmitCType:$rhs); + let results = (outs EmitCType); let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; } @@ -97,9 +97,9 @@ def EmitC_ApplyOp : EmitC_Op<"apply", [CExpression]> { }]; let arguments = (ins Arg:$applicableOperator, - AnyType:$operand + EmitCType:$operand ); - let results = (outs AnyType:$result); + let results = (outs EmitCType:$result); let assemblyFormat = [{ $applicableOperator `(` $operand `)` attr-dict `:` functional-type($operand, results) }]; @@ -240,9 +240,9 @@ def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpression]> { Arg:$callee, Arg, "the order of operands and further attributes">:$args, Arg, "template arguments">:$template_args, - Variadic:$operands + Variadic:$operands ); - let results = (outs Variadic); + let results = (outs Variadic); let builders = [ OpBuilder<(ins "::mlir::TypeRange":$resultTypes, @@ -284,8 +284,8 @@ def EmitC_CastOp : EmitC_Op<"cast", ``` }]; - let arguments = (ins AnyType:$source); - let results = (outs AnyType:$dest); + let arguments = (ins EmitCType:$source); + let results = (outs EmitCType:$dest); let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)"; } @@ -323,9 +323,9 @@ def EmitC_CmpOp : EmitC_BinaryOp<"cmp", [CExpression]> { }]; let arguments = (ins EmitC_CmpPredicateAttr:$predicate, - AnyType:$lhs, - AnyType:$rhs); - let results = (outs AnyType); + EmitCType:$lhs, + EmitCType:$rhs); + let results = (outs EmitCType); let assemblyFormat = "$predicate `,` operands attr-dict `:` functional-type(operands, results)"; } @@ -353,7 +353,7 @@ def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> { }]; let arguments = (ins EmitC_OpaqueOrTypedAttr:$value); - let results = (outs AnyType); + let results = (outs EmitCType); let hasFolder = 1; let hasVerifier = 1; @@ -423,7 +423,7 @@ def EmitC_ExpressionOp : EmitC_Op<"expression", }]; let arguments = (ins UnitAttr:$do_not_inline); - let results = (outs AnyType:$result); + let results = (outs EmitCType:$result); let regions = (region SizedRegion<1>:$region); let hasVerifier = 1; @@ -531,8 +531,8 @@ def EmitC_CallOp : EmitC_Op<"call", %2 = emitc.call @my_add(%0, %1) : (f32, f32) -> f32 ``` }]; - let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$operands); - let results = (outs Variadic); + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$operands); + let results = (outs Variadic); let builders = [ OpBuilder<(ins "FuncOp":$callee, CArg<"ValueRange", "{}">:$operands), [{ @@ -722,7 +722,7 @@ def EmitC_ReturnOp : EmitC_Op<"return", [Pure, HasParent<"FuncOp">, } ``` }]; - let arguments = (ins Optional:$operand); + let arguments = (ins Optional:$operand); let assemblyFormat = "attr-dict ($operand^ `:` type($operand))?"; let hasVerifier = 1; @@ -766,7 +766,7 @@ def EmitC_LiteralOp : EmitC_Op<"literal", [Pure]> { }]; let arguments = (ins StrAttr:$value); - let results = (outs AnyType:$result); + let results = (outs EmitCType:$result); let hasVerifier = 1; let assemblyFormat = "$value attr-dict `:` type($result)"; @@ -932,8 +932,8 @@ def EmitC_ConditionalOp : EmitC_Op<"conditional", int32_t v6 = v3 ? v4 : v5; ``` }]; - let arguments = (ins I1:$condition, AnyType:$true_value, AnyType:$false_value); - let results = (outs AnyType:$result); + let arguments = (ins I1:$condition, EmitCType:$true_value, EmitCType:$false_value); + let results = (outs EmitCType:$result); let assemblyFormat = "operands attr-dict `:` type($result)"; } @@ -1009,7 +1009,7 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> { }]; let arguments = (ins EmitC_OpaqueOrTypedAttr:$value); - let results = (outs AnyType); + let results = (outs EmitCType); let hasVerifier = 1; } @@ -1068,7 +1068,7 @@ def EmitC_AssignOp : EmitC_Op<"assign", []> { ``` }]; - let arguments = (ins AnyType:$var, AnyType:$value); + let arguments = (ins EmitCType:$var, EmitCType:$value); let results = (outs); let hasVerifier = 1; @@ -1089,7 +1089,7 @@ def EmitC_YieldOp : EmitC_Op<"yield", value is yielded. }]; - let arguments = (ins Optional:$result); + let arguments = (ins Optional:$result); let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>]; let hasVerifier = 1; @@ -1173,8 +1173,8 @@ def EmitC_SubscriptOp : EmitC_Op<"subscript", []> { EmitC_OpaqueType, EmitC_PointerType]>, "the value to subscript">:$value, - Variadic:$indices); - let results = (outs AnyType:$result); + Variadic:$indices); + let results = (outs EmitCType:$result); let builders = [ OpBuilder<(ins "TypedValue":$array, "ValueRange":$indices), [{ diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td index bce5807..444395b 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td @@ -22,6 +22,9 @@ include "mlir/IR/BuiltinTypeInterfaces.td" // EmitC type definitions //===----------------------------------------------------------------------===// +def EmitCType : Type, + "type supported by EmitC">; + def EmitCIntegerType : Type, "integer type supported by EmitC">; diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 7cbf28b..b037ef3 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -10,11 +10,15 @@ #include "mlir/Dialect/EmitC/IR/EmitCTraits.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/IRMapping.h" +#include "mlir/IR/Types.h" #include "mlir/Interfaces/FunctionImplementation.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" using namespace mlir; using namespace mlir::emitc; @@ -54,6 +58,40 @@ void mlir::emitc::buildTerminatedBody(OpBuilder &builder, Location loc) { builder.create(loc); } +bool mlir::emitc::isSupportedEmitCType(Type type) { + if (llvm::isa(type)) + return true; + if (auto ptrType = llvm::dyn_cast(type)) + return isSupportedEmitCType(ptrType.getPointee()); + if (auto arrayType = llvm::dyn_cast(type)) { + auto elemType = arrayType.getElementType(); + return !llvm::isa(elemType) && + isSupportedEmitCType(elemType); + } + if (type.isIndex()) + return true; + if (llvm::isa(type)) + return isSupportedIntegerType(type); + if (llvm::isa(type)) + return isSupportedFloatType(type); + if (auto tensorType = llvm::dyn_cast(type)) { + if (!tensorType.hasStaticShape()) { + return false; + } + auto elemType = tensorType.getElementType(); + if (llvm::isa(elemType)) { + return false; + } + return isSupportedEmitCType(elemType); + } + if (auto tupleType = llvm::dyn_cast(type)) { + return llvm::all_of(tupleType.getTypes(), [](Type type) { + return !llvm::isa(type) && isSupportedEmitCType(type); + }); + } + return false; +} + bool mlir::emitc::isSupportedIntegerType(Type type) { if (auto intType = llvm::dyn_cast(type)) { switch (intType.getWidth()) { diff --git a/mlir/test/Dialect/EmitC/invalid_types.mlir b/mlir/test/Dialect/EmitC/invalid_types.mlir index f9d517b..0ad8d4e 100644 --- a/mlir/test/Dialect/EmitC/invalid_types.mlir +++ b/mlir/test/Dialect/EmitC/invalid_types.mlir @@ -97,3 +97,51 @@ func.func @illegal_float_type(%arg0: f80, %arg1: f80) { %mul = "emitc.mul" (%arg0, %arg1) : (f80, f80) -> f80 return } + +// ----- + +func.func @illegal_pointee_type() { + // expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got '!emitc.ptr'}} + %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> !emitc.ptr + return +} + +// ----- + +func.func @illegal_non_static_tensor_shape_type() { + // expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got 'tensor'}} + %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> tensor + return +} + +// ----- + +func.func @illegal_tensor_array_element_type() { + // expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got 'tensor>'}} + %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> tensor> + return +} + +// ----- + +func.func @illegal_tensor_integer_element_type() { + // expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got 'tensor<9xi11>'}} + %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> tensor<9xi11> + return +} + +// ----- + +func.func @illegal_tuple_array_element_type() { + // expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got 'tuple, f32>'}} + %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> tuple, f32> + return +} + +// ----- + +func.func @illegal_tuple_float_element_type() { + // expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got 'tuple'}} + %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> tuple + return +} -- cgit v1.1