aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTina Jung <tinamaria.jung@amd.com>2024-04-19 10:52:31 +0200
committerGitHub <noreply@github.com>2024-04-19 10:52:31 +0200
commit95ffa8a2ac1154d968c0af2d428ade541eddd828 (patch)
tree924e326dd931c16bde74ba1bb8264cecb2b373e4
parente2a72fa583d9ccec7e996e15ea86f0ceddbfe63c (diff)
downloadllvm-95ffa8a2ac1154d968c0af2d428ade541eddd828.zip
llvm-95ffa8a2ac1154d968c0af2d428ade541eddd828.tar.gz
llvm-95ffa8a2ac1154d968c0af2d428ade541eddd828.tar.bz2
[mlir][emitc] Restrict types in EmitC (#88391)
Restrict the types which are valid for EmitC operations. Use what is currently supported by the emitter as restriction. Define a utility functions for valid types, such that they can be used to restrict the operations in table gen as well as being available for reuse in dialect conversions.
-rw-r--r--mlir/include/mlir/Dialect/EmitC/IR/EmitC.h3
-rw-r--r--mlir/include/mlir/Dialect/EmitC/IR/EmitC.td52
-rw-r--r--mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td3
-rw-r--r--mlir/lib/Dialect/EmitC/IR/EmitC.cpp38
-rw-r--r--mlir/test/Dialect/EmitC/invalid_types.mlir48
5 files changed, 118 insertions, 26 deletions
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
index c039156..5d9531c 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
@@ -31,6 +31,9 @@ namespace mlir {
namespace emitc {
void buildTerminatedBody(OpBuilder &builder, Location loc);
+/// Determines whether \p type is valid in EmitC.
+bool isSupportedEmitCType(mlir::Type type);
+
/// Determines whether \p type is a valid integer type in EmitC.
bool isSupportedIntegerType(mlir::Type type);
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index e611fd2..c1a1e77 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -34,16 +34,16 @@ class EmitC_Op<string mnemonic, list<Trait> traits = []>
// Base class for unary operations.
class EmitC_UnaryOp<string mnemonic, list<Trait> traits = []> :
EmitC_Op<mnemonic, traits> {
- let arguments = (ins AnyType);
- let results = (outs AnyType);
+ let arguments = (ins EmitCType);
+ let results = (outs EmitCType);
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
// Base class for binary operations.
class EmitC_BinaryOp<string mnemonic, list<Trait> traits = []> :
EmitC_Op<mnemonic, traits> {
- let arguments = (ins AnyType:$lhs, AnyType:$rhs);
- let results = (outs AnyType);
+ let arguments = (ins EmitCType:$lhs, EmitCType:$rhs);
+ let results = (outs EmitCType);
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
@@ -97,9 +97,9 @@ def EmitC_ApplyOp : EmitC_Op<"apply", [CExpression]> {
}];
let arguments = (ins
Arg<StrAttr, "the operator to apply">:$applicableOperator,
- AnyType:$operand
+ EmitCType:$operand
);
- let results = (outs AnyType:$result);
+ let results = (outs EmitCType:$result);
let assemblyFormat = [{
$applicableOperator `(` $operand `)` attr-dict `:` functional-type($operand, results)
}];
@@ -240,9 +240,9 @@ def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpression]> {
Arg<StrAttr, "the C++ function to call">:$callee,
Arg<OptionalAttr<ArrayAttr>, "the order of operands and further attributes">:$args,
Arg<OptionalAttr<ArrayAttr>, "template arguments">:$template_args,
- Variadic<AnyType>:$operands
+ Variadic<EmitCType>:$operands
);
- let results = (outs Variadic<AnyType>);
+ let results = (outs Variadic<EmitCType>);
let builders = [
OpBuilder<(ins
"::mlir::TypeRange":$resultTypes,
@@ -284,8 +284,8 @@ def EmitC_CastOp : EmitC_Op<"cast",
```
}];
- let arguments = (ins AnyType:$source);
- let results = (outs AnyType:$dest);
+ let arguments = (ins EmitCType:$source);
+ let results = (outs EmitCType:$dest);
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
}
@@ -323,9 +323,9 @@ def EmitC_CmpOp : EmitC_BinaryOp<"cmp", [CExpression]> {
}];
let arguments = (ins EmitC_CmpPredicateAttr:$predicate,
- AnyType:$lhs,
- AnyType:$rhs);
- let results = (outs AnyType);
+ EmitCType:$lhs,
+ EmitCType:$rhs);
+ let results = (outs EmitCType);
let assemblyFormat = "$predicate `,` operands attr-dict `:` functional-type(operands, results)";
}
@@ -353,7 +353,7 @@ def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> {
}];
let arguments = (ins EmitC_OpaqueOrTypedAttr:$value);
- let results = (outs AnyType);
+ let results = (outs EmitCType);
let hasFolder = 1;
let hasVerifier = 1;
@@ -423,7 +423,7 @@ def EmitC_ExpressionOp : EmitC_Op<"expression",
}];
let arguments = (ins UnitAttr:$do_not_inline);
- let results = (outs AnyType:$result);
+ let results = (outs EmitCType:$result);
let regions = (region SizedRegion<1>:$region);
let hasVerifier = 1;
@@ -531,8 +531,8 @@ def EmitC_CallOp : EmitC_Op<"call",
%2 = emitc.call @my_add(%0, %1) : (f32, f32) -> f32
```
}];
- let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<AnyType>:$operands);
- let results = (outs Variadic<AnyType>);
+ let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<EmitCType>:$operands);
+ let results = (outs Variadic<EmitCType>);
let builders = [
OpBuilder<(ins "FuncOp":$callee, CArg<"ValueRange", "{}">:$operands), [{
@@ -722,7 +722,7 @@ def EmitC_ReturnOp : EmitC_Op<"return", [Pure, HasParent<"FuncOp">,
}
```
}];
- let arguments = (ins Optional<AnyType>:$operand);
+ let arguments = (ins Optional<EmitCType>:$operand);
let assemblyFormat = "attr-dict ($operand^ `:` type($operand))?";
let hasVerifier = 1;
@@ -766,7 +766,7 @@ def EmitC_LiteralOp : EmitC_Op<"literal", [Pure]> {
}];
let arguments = (ins StrAttr:$value);
- let results = (outs AnyType:$result);
+ let results = (outs EmitCType:$result);
let hasVerifier = 1;
let assemblyFormat = "$value attr-dict `:` type($result)";
@@ -932,8 +932,8 @@ def EmitC_ConditionalOp : EmitC_Op<"conditional",
int32_t v6 = v3 ? v4 : v5;
```
}];
- let arguments = (ins I1:$condition, AnyType:$true_value, AnyType:$false_value);
- let results = (outs AnyType:$result);
+ let arguments = (ins I1:$condition, EmitCType:$true_value, EmitCType:$false_value);
+ let results = (outs EmitCType:$result);
let assemblyFormat = "operands attr-dict `:` type($result)";
}
@@ -1009,7 +1009,7 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> {
}];
let arguments = (ins EmitC_OpaqueOrTypedAttr:$value);
- let results = (outs AnyType);
+ let results = (outs EmitCType);
let hasVerifier = 1;
}
@@ -1068,7 +1068,7 @@ def EmitC_AssignOp : EmitC_Op<"assign", []> {
```
}];
- let arguments = (ins AnyType:$var, AnyType:$value);
+ let arguments = (ins EmitCType:$var, EmitCType:$value);
let results = (outs);
let hasVerifier = 1;
@@ -1089,7 +1089,7 @@ def EmitC_YieldOp : EmitC_Op<"yield",
value is yielded.
}];
- let arguments = (ins Optional<AnyType>:$result);
+ let arguments = (ins Optional<EmitCType>:$result);
let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];
let hasVerifier = 1;
@@ -1173,8 +1173,8 @@ def EmitC_SubscriptOp : EmitC_Op<"subscript", []> {
EmitC_OpaqueType,
EmitC_PointerType]>,
"the value to subscript">:$value,
- Variadic<AnyType>:$indices);
- let results = (outs AnyType:$result);
+ Variadic<EmitCType>:$indices);
+ let results = (outs EmitCType:$result);
let builders = [
OpBuilder<(ins "TypedValue<ArrayType>":$array, "ValueRange":$indices), [{
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
index bce5807..444395b 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
@@ -22,6 +22,9 @@ include "mlir/IR/BuiltinTypeInterfaces.td"
// EmitC type definitions
//===----------------------------------------------------------------------===//
+def EmitCType : Type<CPred<"emitc::isSupportedEmitCType($_self)">,
+ "type supported by EmitC">;
+
def EmitCIntegerType : Type<CPred<"emitc::isSupportedIntegerType($_self)">,
"integer type supported by EmitC">;
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 7cbf28b..b037ef3 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -10,11 +10,15 @@
#include "mlir/Dialect/EmitC/IR/EmitCTraits.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Types.h"
#include "mlir/Interfaces/FunctionImplementation.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
using namespace mlir;
using namespace mlir::emitc;
@@ -54,6 +58,40 @@ void mlir::emitc::buildTerminatedBody(OpBuilder &builder, Location loc) {
builder.create<emitc::YieldOp>(loc);
}
+bool mlir::emitc::isSupportedEmitCType(Type type) {
+ if (llvm::isa<emitc::OpaqueType>(type))
+ return true;
+ if (auto ptrType = llvm::dyn_cast<emitc::PointerType>(type))
+ return isSupportedEmitCType(ptrType.getPointee());
+ if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(type)) {
+ auto elemType = arrayType.getElementType();
+ return !llvm::isa<emitc::ArrayType>(elemType) &&
+ isSupportedEmitCType(elemType);
+ }
+ if (type.isIndex())
+ return true;
+ if (llvm::isa<IntegerType>(type))
+ return isSupportedIntegerType(type);
+ if (llvm::isa<FloatType>(type))
+ return isSupportedFloatType(type);
+ if (auto tensorType = llvm::dyn_cast<TensorType>(type)) {
+ if (!tensorType.hasStaticShape()) {
+ return false;
+ }
+ auto elemType = tensorType.getElementType();
+ if (llvm::isa<emitc::ArrayType>(elemType)) {
+ return false;
+ }
+ return isSupportedEmitCType(elemType);
+ }
+ if (auto tupleType = llvm::dyn_cast<TupleType>(type)) {
+ return llvm::all_of(tupleType.getTypes(), [](Type type) {
+ return !llvm::isa<emitc::ArrayType>(type) && isSupportedEmitCType(type);
+ });
+ }
+ return false;
+}
+
bool mlir::emitc::isSupportedIntegerType(Type type) {
if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
switch (intType.getWidth()) {
diff --git a/mlir/test/Dialect/EmitC/invalid_types.mlir b/mlir/test/Dialect/EmitC/invalid_types.mlir
index f9d517b..0ad8d4e 100644
--- a/mlir/test/Dialect/EmitC/invalid_types.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_types.mlir
@@ -97,3 +97,51 @@ func.func @illegal_float_type(%arg0: f80, %arg1: f80) {
%mul = "emitc.mul" (%arg0, %arg1) : (f80, f80) -> f80
return
}
+
+// -----
+
+func.func @illegal_pointee_type() {
+ // expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got '!emitc.ptr<i11>'}}
+ %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> !emitc.ptr<i11>
+ return
+}
+
+// -----
+
+func.func @illegal_non_static_tensor_shape_type() {
+ // expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got 'tensor<?xf32>'}}
+ %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> tensor<?xf32>
+ return
+}
+
+// -----
+
+func.func @illegal_tensor_array_element_type() {
+ // expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got 'tensor<!emitc.array<9xi16>>'}}
+ %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> tensor<!emitc.array<9xi16>>
+ return
+}
+
+// -----
+
+func.func @illegal_tensor_integer_element_type() {
+ // expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got 'tensor<9xi11>'}}
+ %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> tensor<9xi11>
+ return
+}
+
+// -----
+
+func.func @illegal_tuple_array_element_type() {
+ // expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got 'tuple<!emitc.array<9xf32>, f32>'}}
+ %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> tuple<!emitc.array<9xf32>, f32>
+ return
+}
+
+// -----
+
+func.func @illegal_tuple_float_element_type() {
+ // expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got 'tuple<i32, f80>'}}
+ %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> tuple<i32, f80>
+ return
+}