diff options
author | Tina Jung <tinamaria.jung@amd.com> | 2024-04-19 10:52:31 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-19 10:52:31 +0200 |
commit | 95ffa8a2ac1154d968c0af2d428ade541eddd828 (patch) | |
tree | 924e326dd931c16bde74ba1bb8264cecb2b373e4 | |
parent | e2a72fa583d9ccec7e996e15ea86f0ceddbfe63c (diff) | |
download | llvm-95ffa8a2ac1154d968c0af2d428ade541eddd828.zip llvm-95ffa8a2ac1154d968c0af2d428ade541eddd828.tar.gz llvm-95ffa8a2ac1154d968c0af2d428ade541eddd828.tar.bz2 |
[mlir][emitc] Restrict types in EmitC (#88391)
Restrict the types which are valid for EmitC operations. Use what is
currently supported by the emitter as restriction. Define a utility
functions for valid types, such that they can be used to restrict the
operations in table gen as well as being available for reuse in dialect
conversions.
-rw-r--r-- | mlir/include/mlir/Dialect/EmitC/IR/EmitC.h | 3 | ||||
-rw-r--r-- | mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 52 | ||||
-rw-r--r-- | mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td | 3 | ||||
-rw-r--r-- | mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 38 | ||||
-rw-r--r-- | mlir/test/Dialect/EmitC/invalid_types.mlir | 48 |
5 files changed, 118 insertions, 26 deletions
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h index c039156..5d9531c 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h @@ -31,6 +31,9 @@ namespace mlir { namespace emitc { void buildTerminatedBody(OpBuilder &builder, Location loc); +/// Determines whether \p type is valid in EmitC. +bool isSupportedEmitCType(mlir::Type type); + /// Determines whether \p type is a valid integer type in EmitC. bool isSupportedIntegerType(mlir::Type type); diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index e611fd2..c1a1e77 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -34,16 +34,16 @@ class EmitC_Op<string mnemonic, list<Trait> traits = []> // Base class for unary operations. class EmitC_UnaryOp<string mnemonic, list<Trait> traits = []> : EmitC_Op<mnemonic, traits> { - let arguments = (ins AnyType); - let results = (outs AnyType); + let arguments = (ins EmitCType); + let results = (outs EmitCType); let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; } // Base class for binary operations. class EmitC_BinaryOp<string mnemonic, list<Trait> traits = []> : EmitC_Op<mnemonic, traits> { - let arguments = (ins AnyType:$lhs, AnyType:$rhs); - let results = (outs AnyType); + let arguments = (ins EmitCType:$lhs, EmitCType:$rhs); + let results = (outs EmitCType); let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; } @@ -97,9 +97,9 @@ def EmitC_ApplyOp : EmitC_Op<"apply", [CExpression]> { }]; let arguments = (ins Arg<StrAttr, "the operator to apply">:$applicableOperator, - AnyType:$operand + EmitCType:$operand ); - let results = (outs AnyType:$result); + let results = (outs EmitCType:$result); let assemblyFormat = [{ $applicableOperator `(` $operand `)` attr-dict `:` functional-type($operand, results) }]; @@ -240,9 +240,9 @@ def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpression]> { Arg<StrAttr, "the C++ function to call">:$callee, Arg<OptionalAttr<ArrayAttr>, "the order of operands and further attributes">:$args, Arg<OptionalAttr<ArrayAttr>, "template arguments">:$template_args, - Variadic<AnyType>:$operands + Variadic<EmitCType>:$operands ); - let results = (outs Variadic<AnyType>); + let results = (outs Variadic<EmitCType>); let builders = [ OpBuilder<(ins "::mlir::TypeRange":$resultTypes, @@ -284,8 +284,8 @@ def EmitC_CastOp : EmitC_Op<"cast", ``` }]; - let arguments = (ins AnyType:$source); - let results = (outs AnyType:$dest); + let arguments = (ins EmitCType:$source); + let results = (outs EmitCType:$dest); let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)"; } @@ -323,9 +323,9 @@ def EmitC_CmpOp : EmitC_BinaryOp<"cmp", [CExpression]> { }]; let arguments = (ins EmitC_CmpPredicateAttr:$predicate, - AnyType:$lhs, - AnyType:$rhs); - let results = (outs AnyType); + EmitCType:$lhs, + EmitCType:$rhs); + let results = (outs EmitCType); let assemblyFormat = "$predicate `,` operands attr-dict `:` functional-type(operands, results)"; } @@ -353,7 +353,7 @@ def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> { }]; let arguments = (ins EmitC_OpaqueOrTypedAttr:$value); - let results = (outs AnyType); + let results = (outs EmitCType); let hasFolder = 1; let hasVerifier = 1; @@ -423,7 +423,7 @@ def EmitC_ExpressionOp : EmitC_Op<"expression", }]; let arguments = (ins UnitAttr:$do_not_inline); - let results = (outs AnyType:$result); + let results = (outs EmitCType:$result); let regions = (region SizedRegion<1>:$region); let hasVerifier = 1; @@ -531,8 +531,8 @@ def EmitC_CallOp : EmitC_Op<"call", %2 = emitc.call @my_add(%0, %1) : (f32, f32) -> f32 ``` }]; - let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<AnyType>:$operands); - let results = (outs Variadic<AnyType>); + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<EmitCType>:$operands); + let results = (outs Variadic<EmitCType>); let builders = [ OpBuilder<(ins "FuncOp":$callee, CArg<"ValueRange", "{}">:$operands), [{ @@ -722,7 +722,7 @@ def EmitC_ReturnOp : EmitC_Op<"return", [Pure, HasParent<"FuncOp">, } ``` }]; - let arguments = (ins Optional<AnyType>:$operand); + let arguments = (ins Optional<EmitCType>:$operand); let assemblyFormat = "attr-dict ($operand^ `:` type($operand))?"; let hasVerifier = 1; @@ -766,7 +766,7 @@ def EmitC_LiteralOp : EmitC_Op<"literal", [Pure]> { }]; let arguments = (ins StrAttr:$value); - let results = (outs AnyType:$result); + let results = (outs EmitCType:$result); let hasVerifier = 1; let assemblyFormat = "$value attr-dict `:` type($result)"; @@ -932,8 +932,8 @@ def EmitC_ConditionalOp : EmitC_Op<"conditional", int32_t v6 = v3 ? v4 : v5; ``` }]; - let arguments = (ins I1:$condition, AnyType:$true_value, AnyType:$false_value); - let results = (outs AnyType:$result); + let arguments = (ins I1:$condition, EmitCType:$true_value, EmitCType:$false_value); + let results = (outs EmitCType:$result); let assemblyFormat = "operands attr-dict `:` type($result)"; } @@ -1009,7 +1009,7 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> { }]; let arguments = (ins EmitC_OpaqueOrTypedAttr:$value); - let results = (outs AnyType); + let results = (outs EmitCType); let hasVerifier = 1; } @@ -1068,7 +1068,7 @@ def EmitC_AssignOp : EmitC_Op<"assign", []> { ``` }]; - let arguments = (ins AnyType:$var, AnyType:$value); + let arguments = (ins EmitCType:$var, EmitCType:$value); let results = (outs); let hasVerifier = 1; @@ -1089,7 +1089,7 @@ def EmitC_YieldOp : EmitC_Op<"yield", value is yielded. }]; - let arguments = (ins Optional<AnyType>:$result); + let arguments = (ins Optional<EmitCType>:$result); let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>]; let hasVerifier = 1; @@ -1173,8 +1173,8 @@ def EmitC_SubscriptOp : EmitC_Op<"subscript", []> { EmitC_OpaqueType, EmitC_PointerType]>, "the value to subscript">:$value, - Variadic<AnyType>:$indices); - let results = (outs AnyType:$result); + Variadic<EmitCType>:$indices); + let results = (outs EmitCType:$result); let builders = [ OpBuilder<(ins "TypedValue<ArrayType>":$array, "ValueRange":$indices), [{ diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td index bce5807..444395b 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td @@ -22,6 +22,9 @@ include "mlir/IR/BuiltinTypeInterfaces.td" // EmitC type definitions //===----------------------------------------------------------------------===// +def EmitCType : Type<CPred<"emitc::isSupportedEmitCType($_self)">, + "type supported by EmitC">; + def EmitCIntegerType : Type<CPred<"emitc::isSupportedIntegerType($_self)">, "integer type supported by EmitC">; diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 7cbf28b..b037ef3 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -10,11 +10,15 @@ #include "mlir/Dialect/EmitC/IR/EmitCTraits.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/IRMapping.h" +#include "mlir/IR/Types.h" #include "mlir/Interfaces/FunctionImplementation.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" using namespace mlir; using namespace mlir::emitc; @@ -54,6 +58,40 @@ void mlir::emitc::buildTerminatedBody(OpBuilder &builder, Location loc) { builder.create<emitc::YieldOp>(loc); } +bool mlir::emitc::isSupportedEmitCType(Type type) { + if (llvm::isa<emitc::OpaqueType>(type)) + return true; + if (auto ptrType = llvm::dyn_cast<emitc::PointerType>(type)) + return isSupportedEmitCType(ptrType.getPointee()); + if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(type)) { + auto elemType = arrayType.getElementType(); + return !llvm::isa<emitc::ArrayType>(elemType) && + isSupportedEmitCType(elemType); + } + if (type.isIndex()) + return true; + if (llvm::isa<IntegerType>(type)) + return isSupportedIntegerType(type); + if (llvm::isa<FloatType>(type)) + return isSupportedFloatType(type); + if (auto tensorType = llvm::dyn_cast<TensorType>(type)) { + if (!tensorType.hasStaticShape()) { + return false; + } + auto elemType = tensorType.getElementType(); + if (llvm::isa<emitc::ArrayType>(elemType)) { + return false; + } + return isSupportedEmitCType(elemType); + } + if (auto tupleType = llvm::dyn_cast<TupleType>(type)) { + return llvm::all_of(tupleType.getTypes(), [](Type type) { + return !llvm::isa<emitc::ArrayType>(type) && isSupportedEmitCType(type); + }); + } + return false; +} + bool mlir::emitc::isSupportedIntegerType(Type type) { if (auto intType = llvm::dyn_cast<IntegerType>(type)) { switch (intType.getWidth()) { diff --git a/mlir/test/Dialect/EmitC/invalid_types.mlir b/mlir/test/Dialect/EmitC/invalid_types.mlir index f9d517b..0ad8d4e 100644 --- a/mlir/test/Dialect/EmitC/invalid_types.mlir +++ b/mlir/test/Dialect/EmitC/invalid_types.mlir @@ -97,3 +97,51 @@ func.func @illegal_float_type(%arg0: f80, %arg1: f80) { %mul = "emitc.mul" (%arg0, %arg1) : (f80, f80) -> f80 return } + +// ----- + +func.func @illegal_pointee_type() { + // expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got '!emitc.ptr<i11>'}} + %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> !emitc.ptr<i11> + return +} + +// ----- + +func.func @illegal_non_static_tensor_shape_type() { + // expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got 'tensor<?xf32>'}} + %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> tensor<?xf32> + return +} + +// ----- + +func.func @illegal_tensor_array_element_type() { + // expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got 'tensor<!emitc.array<9xi16>>'}} + %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> tensor<!emitc.array<9xi16>> + return +} + +// ----- + +func.func @illegal_tensor_integer_element_type() { + // expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got 'tensor<9xi11>'}} + %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> tensor<9xi11> + return +} + +// ----- + +func.func @illegal_tuple_array_element_type() { + // expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got 'tuple<!emitc.array<9xf32>, f32>'}} + %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> tuple<!emitc.array<9xf32>, f32> + return +} + +// ----- + +func.func @illegal_tuple_float_element_type() { + // expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got 'tuple<i32, f80>'}} + %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> tuple<i32, f80> + return +} |