diff options
author | Luke Hutton <luke.hutton@arm.com> | 2025-07-16 07:33:40 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-07-16 07:33:40 +0100 |
commit | 5480fc6bb8ef6a6a895be7952d50d557116dcb38 (patch) | |
tree | 40ba61ba700dab07f9c1f77eca56746aa476b6f2 | |
parent | dbb6ed76317a52ada7045611649e50c2afe51496 (diff) | |
download | llvm-5480fc6bb8ef6a6a895be7952d50d557116dcb38.zip llvm-5480fc6bb8ef6a6a895be7952d50d557116dcb38.tar.gz llvm-5480fc6bb8ef6a6a895be7952d50d557116dcb38.tar.bz2 |
[mlir][tosa] Interpret boolean values correctly in cast folder (#147078)
Previously the cast folder would sign extend boolean values, leading
"true" to be casted to a value of -1 instead of 1. This change ensures
i1 values are zero extended, since i1 is used as a boolean value in
TOSA. According to the TOSA spec, the result of a boolean cast with
value "true" to another integer type should give a result of 1.
Fixes https://github.com/llvm/llvm-project/issues/57951
-rw-r--r-- | mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 6 | ||||
-rw-r--r-- | mlir/test/Dialect/Tosa/canonicalize.mlir | 11 |
2 files changed, 15 insertions, 2 deletions
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 2dd45d27..5758d8d 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -1295,7 +1295,8 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) { } if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) { - auto unsignIn = llvm::cast<IntegerType>(inETy).isUnsignedInteger(); + const auto inIntType = llvm::cast<IntegerType>(inETy); + auto unsignIn = inIntType.isUnsignedInteger(); bool trunc = inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth(); auto intVal = operand.getSplatValue<APInt>(); @@ -1303,7 +1304,8 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) { if (trunc) { intVal = intVal.trunc(bitwidth); - } else if (unsignIn) { + // i1 types are boolean in TOSA + } else if (unsignIn || inIntType.isInteger(1)) { intVal = intVal.zext(bitwidth); } else { intVal = intVal.sext(bitwidth); diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 2728080..11c8d54 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -1338,3 +1338,14 @@ func.func @no_fold_mul_result_exceeds_i32() -> tensor<i32> { %3 = tosa.mul %0, %1, %2 : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32> return %3 : tensor<i32> } + +// ----- + +// CHECK-LABEL: @test_fold_i1_to_i32_cast +// CHECK: %[[OUT:.*]] = "tosa.const"() <{values = dense<1> : tensor<i32>}> : () -> tensor<i32> +// CHECK: return %[[OUT]] : tensor<i32> +func.func @test_fold_i1_to_i32_cast() -> tensor<i32> { + %0 = "tosa.const"() <{values = dense<1> : tensor<i1>}> : () -> tensor<i1> + %1 = "tosa.cast"(%0) : (tensor<i1>) -> tensor<i32> + return %1 : tensor<i32> +} |