aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthias Gehre <matthias.gehre@amd.com>2024-03-26 22:27:11 +0100
committerGitHub <noreply@github.com>2024-03-26 22:27:11 +0100
commitc6d419c15bf836085392212b8ab7600f7402829b (patch)
treeabc6fad6a67b59d6b6aac7f9fb2c402f46d40826
parenta22bd00ce03a77a38be2911979a4c4f2ca01379d (diff)
downloadllvm-c6d419c15bf836085392212b8ab7600f7402829b.zip
llvm-c6d419c15bf836085392212b8ab7600f7402829b.tar.gz
llvm-c6d419c15bf836085392212b8ab7600f7402829b.tar.bz2
[TOSA] Allow all integer types in most ops (#86509)
As discussed in one of the previous TOSA community meetings, we would like to allow for more integer types in the TOSA dialect to enable more use cases. For strict standards conformance, the TosaValidation pass can be used. Follow up PRs will extend conversions from TOSA where needed.
-rw-r--r--mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td2
-rw-r--r--mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td27
-rw-r--r--mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp51
-rw-r--r--mlir/test/Dialect/Tosa/level_check.mlir16
4 files changed, 71 insertions, 25 deletions
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 0ecded7..306e4a4 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1942,7 +1942,7 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
);
let results = (outs
- TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Plus_F64, Tosa_Int4]>]>:$output
+ TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Plus_F64]>]>:$output
);
let hasFolder = 1;
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 5a4d6ff..cff3de0 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -38,29 +38,17 @@ class Tosa_QuantizedType<string n, list<int> params, bit signed>
// Used to express accumulator results or compare results.
//===----------------------------------------------------------------------===//
-def Tosa_UInt8 : UI<8>;
-def Tosa_UInt16 : UI<16>;
-
def Tosa_Int4 : I<4>;
def Tosa_Int8 : I<8>;
-def Tosa_Int16 : I<16>;
def Tosa_Int32 : I<32>;
-def Tosa_Int48 : I<48>;
def Tosa_Int64 : I<64>;
-def Tosa_SignedInt : AnyTypeOf<[Tosa_Int8,
- Tosa_Int16,
- Tosa_Int32,
- Tosa_Int48,
- Tosa_Int64]>;
-
-def Tosa_Bool : I<1>;
-
-// No unsigned unquantized int types.
-def Tosa_Int : AnyTypeOf<[Tosa_Bool,
- Tosa_UInt8,
- Tosa_UInt16,
- Tosa_SignedInt]>;
+// The TOSA dialect allows more types than the TOSA standard to allow for
+// experimentation. For historical reasons, signless is used in the place of
+// signed.
+// The TosaValidation pass can be used to check for standard conformance.
+def Tosa_Int : AnyTypeOf<[AnyUnsignedInteger,
+ AnySignlessInteger]>;
def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32,
Tosa_Int64]>;
@@ -172,9 +160,6 @@ class Tosa_TypeLike<list<Type> types, string description = ""> : TypeConstraint<
def Tosa_IntLike : Tosa_TypeLike<[Tosa_Int], "signless-integer-like">;
def Tosa_Int8Like : Tosa_TypeLike<[Tosa_Int8], "signless-integer-8-bit-like">;
-def Tosa_Int16Like : Tosa_TypeLike<[Tosa_Int16], "signless-integer-16-bit-like">;
-def Tosa_Int32Like : Tosa_TypeLike<[Tosa_Int32], "signless-integer-32-bit-like">;
-def Tosa_Int64Like : Tosa_TypeLike<[Tosa_Int64], "signless-integer-64-bit-like">;
//===----------------------------------------------------------------------===//
// Attribute predicates and classes.
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 9677752..74ef638 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -410,6 +410,8 @@ private:
bool CheckVariable(Operation *op);
bool CheckVariableReadOrWrite(Operation *op);
+ bool isValidElementType(Type type);
+
SmallVector<std::function<LogicalResult(Operation *)>> constCheckers;
TosaLevel tosaLevel;
DenseMap<StringAttr, mlir::Type> variablesMap;
@@ -503,15 +505,58 @@ LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
return success();
}
+bool TosaValidation::isValidElementType(Type type) {
+ if ((profile == TosaProfileEnum::BaseInference) && isa<FloatType>(type)) {
+ return false;
+ }
+ if (type.isF64()) {
+ return false;
+ }
+ if (auto intTy = dyn_cast<IntegerType>(type)) {
+ if (intTy.isUnsigned()) {
+ switch (intTy.getWidth()) {
+ case 8:
+ case 16:
+ return true;
+ default:
+ return false;
+ }
+ } else {
+ // Signless - treated as signed.
+ switch (intTy.getWidth()) {
+ case 1:
+ case 4:
+ case 8:
+ case 16:
+ case 32:
+ case 48:
+ case 64:
+ return true;
+ default:
+ return false;
+ }
+ }
+ return false;
+ }
+ return true;
+}
+
void TosaValidation::runOnOperation() {
configLevelAndProfile();
getOperation().walk([&](Operation *op) {
for (Value operand : op->getOperands()) {
- if ((profile == TosaProfileEnum::BaseInference) &&
- isa<FloatType>(getElementTypeOrSelf(operand))) {
+ auto elementTy = getElementTypeOrSelf(operand);
+ if (!isValidElementType(elementTy)) {
+ op->emitOpError() << "is not profile-aligned: element type "
+ << elementTy << " is not legal";
return signalPassFailure();
}
- if (getElementTypeOrSelf(operand).isF64()) {
+ }
+ for (Type resultTy : op->getResultTypes()) {
+ auto elementTy = getElementTypeOrSelf(resultTy);
+ if (!isValidElementType(elementTy)) {
+ op->emitOpError() << "is not profile-aligned: element type "
+ << elementTy << " is not legal";
return signalPassFailure();
}
}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 35ecbcc7..d8dd878 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -115,6 +115,22 @@ func.func @test_const(%arg0 : tensor<1x1xi32>) -> tensor<1x1x1x1x1x1x1xi32> {
// -----
+func.func @test_const_i2(%arg0 : tensor<1xi2>) {
+ // expected-error@+1 {{'tosa.const' op is not profile-aligned: element type 'i2' is not legal}}
+ %0 = "tosa.const"() {value = dense<0> : tensor<1xi2>} : () -> tensor<1xi2>
+ return
+}
+
+// -----
+
+func.func @test_const_ui32(%arg0 : tensor<1xui32>) {
+ // expected-error@+1 {{'tosa.const' op is not profile-aligned: element type 'ui32' is not legal}}
+ %0 = "tosa.const"() {value = dense<0> : tensor<1xui32>} : () -> tensor<1xui32>
+ return
+}
+
+// -----
+
func.func @test_avgpool2d_kernel_y(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
// expected-error@+1 {{'tosa.avg_pool2d' op failed level check: kernel <= MAX_KERNEL}}
%0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 8193, 1>, pad = array<i64: 4, 4, 4, 4>, stride = array<i64: 1, 1>, acc_type = f32} :