diff options
author | Matthias Springer <me@m-sp.org> | 2024-06-19 16:23:27 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-19 16:23:27 +0200 |
commit | 5b007582bb9142ccab47134e5d2de0dd451338af (patch) | |
tree | f6d1528e601c834681d3f84626aff9b7dcf2a851 /mlir/test | |
parent | 2c1ae801e1b66a09a15028ae4ba614e0911eec00 (diff) | |
download | llvm-5b007582bb9142ccab47134e5d2de0dd451338af.zip llvm-5b007582bb9142ccab47134e5d2de0dd451338af.tar.gz llvm-5b007582bb9142ccab47134e5d2de0dd451338af.tar.bz2 |
[mlir][Conversion] Generalize and fix crash in `reconcile-unrealized-casts` (#95700)
This commit fixes a crash in `-reconcile-unrealized-casts` when cast ops
have multiple operands:
```
DialectConversion.cpp:1583: virtual void mlir::ConversionPatternRewriter::replaceOp(mlir::Operation *, mlir::ValueRange): Assertion `op->getNumResults() == newValues.size() && "incorrect # of replacement values"' failed.
```
This commit also generalizes the pass such that more ops are folded. In
particular (letters indicate types):
```
A
/ \
B C
|
A
```
Previously, such IR was not folded at all. The `A -> B -> A` type cast
cycle is now folded away. (The `A -> C` cast stays in place.)
This commit also turns the pass from a dialect conversion into a simple
IR walk. The pattern and its `populate` function are removed. The
pattern was a (non-conversion) rewrite pattern, but used in a dialect
conversion, which is generally not safe. In particular, the rewrite
pattern may traverse IR that was already scheduled for erasure by the
dialect conversion.
Note: Some test cases changed slightly (NFC) because the new pass
implementation no longer attempts to fold ops.
Note for LLVM integration: If your pipeline uses the removed `populate`
function, try to simply remove that function call. Chances are you may
not need it at all. If it is in fact needed, run the
`-reconcile-unrealized-casts` pass right after the pass that used to
populate the pattern.
---------
Co-authored-by: Maksim Levental <maksim.levental@gmail.com>
Co-authored-by: Markus Böck <markus.boeck02@gmail.com>
Diffstat (limited to 'mlir/test')
5 files changed, 110 insertions, 57 deletions
diff --git a/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir b/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir index 7cdb89e..18734d1 100644 --- a/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir +++ b/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir @@ -127,7 +127,7 @@ func.func @return_var_memref_caller(%arg0: memref<4x3xf32>) { // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant // CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]] // CHECK: %[[RANK:.*]] = llvm.extractvalue %[[CALL_RES]][0] : !llvm.struct<(i64, ptr)> - // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[RANK]], %[[TWO]] + // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK]] // CHECK: %[[DOUBLE_RANK_INC:.*]] = llvm.add %[[DOUBLE_RANK]], %[[ONE]] // CHECK: %[[TABLES_SIZE:.*]] = llvm.mul %[[DOUBLE_RANK_INC]], %[[IDX_SIZE]] // CHECK: %[[ALLOC_SIZE:.*]] = llvm.add %[[DOUBLE_PTR_SIZE]], %[[TABLES_SIZE]] @@ -159,14 +159,17 @@ func.func @return_var_memref(%arg0: memref<4x3xf32>) -> memref<*xf32> attributes // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant // CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]] - // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[RANK]], %[[TWO]] + // CHECK: %[[RANK_EXTR:.*]] = llvm.extractvalue %[[DESC_2]][0] + // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK_EXTR]] // CHECK: %[[DOUBLE_RANK_INC:.*]] = llvm.add %[[DOUBLE_RANK]], %[[ONE]] // CHECK: %[[TABLES_SIZE:.*]] = llvm.mul %[[DOUBLE_RANK_INC]], %[[IDX_SIZE]] // CHECK: %[[ALLOC_SIZE:.*]] = llvm.add %[[DOUBLE_PTR_SIZE]], %[[TABLES_SIZE]] // CHECK: %[[ALLOCATED:.*]] = llvm.call @malloc(%[[ALLOC_SIZE]]) - // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED]], %[[ALLOCA]], %[[ALLOC_SIZE]]) <{isVolatile = false}> + // CHECK: %[[ALLOCA_EXTRACTED:.*]] = llvm.extractvalue %[[DESC_2]][1] + // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED]], %[[ALLOCA_EXTRACTED]], %[[ALLOC_SIZE]]) <{isVolatile = false}> // CHECK: %[[NEW_DESC:.*]] = llvm.mlir.undef : !llvm.struct<(i64, ptr)> - // CHECK: %[[NEW_DESC_1:.*]] = llvm.insertvalue %[[RANK]], %[[NEW_DESC]][0] + // CHECK: %[[RANK_EXTRACTED:.*]] = llvm.extractvalue %[[DESC_2]][0] + // CHECK: %[[NEW_DESC_1:.*]] = llvm.insertvalue %[[RANK_EXTRACTED]], %[[NEW_DESC]][0] // CHECK: %[[NEW_DESC_2:.*]] = llvm.insertvalue %[[ALLOCATED]], %[[NEW_DESC_1]][1] // CHECK: llvm.return %[[NEW_DESC_2]] return %0 : memref<*xf32> @@ -218,13 +221,15 @@ func.func @return_two_var_memref(%arg0: memref<4x3xf32>) -> (memref<*xf32>, memr // convention requires the caller to free them and the caller cannot know // whether they are the same value or not. // CHECK: %[[ALLOCATED_1:.*]] = llvm.call @malloc(%{{.*}}) - // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED_1]], %[[ALLOCA]], %{{.*}}) <{isVolatile = false}> + // CHECK: %[[ALLOCA_EXTRACTED:.*]] = llvm.extractvalue %[[DESC_2]][1] + // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED_1]], %[[ALLOCA_EXTRACTED]], %{{.*}}) <{isVolatile = false}> // CHECK: %[[RES_1:.*]] = llvm.mlir.undef // CHECK: %[[RES_11:.*]] = llvm.insertvalue %{{.*}}, %[[RES_1]][0] // CHECK: %[[RES_12:.*]] = llvm.insertvalue %[[ALLOCATED_1]], %[[RES_11]][1] // CHECK: %[[ALLOCATED_2:.*]] = llvm.call @malloc(%{{.*}}) - // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED_2]], %[[ALLOCA]], %{{.*}}) <{isVolatile = false}> + // CHECK: %[[ALLOCA_EXTRACTED:.*]] = llvm.extractvalue %[[DESC_2]][1] + // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED_2]], %[[ALLOCA_EXTRACTED]], %{{.*}}) <{isVolatile = false}> // CHECK: %[[RES_2:.*]] = llvm.mlir.undef // CHECK: %[[RES_21:.*]] = llvm.insertvalue %{{.*}}, %[[RES_2]][0] // CHECK: %[[RES_22:.*]] = llvm.insertvalue %[[ALLOCATED_2]], %[[RES_21]][1] @@ -265,7 +270,8 @@ func.func @bare_ptr_calling_conv(%arg0: memref<4x3xf32>, %arg1 : index, %arg2 : // CHECK: llvm.store %{{.*}}, %[[STOREPTR]] memref.store %arg3, %arg0[%arg1, %arg2] : memref<4x3xf32> - // CHECK: llvm.return %[[ARG0]] + // CHECK: %[[EXTRACT_MEMREF:.*]] = llvm.extractvalue %[[INSERT_STRIDE1]][0] + // CHECK: llvm.return %[[EXTRACT_MEMREF]] return %arg0 : memref<4x3xf32> } @@ -298,9 +304,10 @@ func.func @bare_ptr_calling_conv_multiresult(%arg0: memref<4x3xf32>, %arg1 : ind // CHECK: %[[RETURN0:.*]] = llvm.load %[[LOADPTR]] %0 = memref.load %arg0[%arg1, %arg2] : memref<4x3xf32> + // CHECK: %[[EXTRACT_MEMREF:.*]] = llvm.extractvalue %[[INSERT_STRIDE1]][0] // CHECK: %[[RETURN_DESC:.*]] = llvm.mlir.undef : !llvm.struct<(f32, ptr)> // CHECK: %[[INSERT_RETURN0:.*]] = llvm.insertvalue %[[RETURN0]], %[[RETURN_DESC]][0] - // CHECK: %[[INSERT_RETURN1:.*]] = llvm.insertvalue %[[ARG0]], %[[INSERT_RETURN0]][1] + // CHECK: %[[INSERT_RETURN1:.*]] = llvm.insertvalue %[[EXTRACT_MEMREF]], %[[INSERT_RETURN0]][1] // CHECK: llvm.return %[[INSERT_RETURN1]] return %0, %arg0 : f32, memref<4x3xf32> } diff --git a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts-failure.mlir b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts-failure.mlir deleted file mode 100644 index f5ceb29..0000000 --- a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts-failure.mlir +++ /dev/null @@ -1,45 +0,0 @@ -// RUN: not mlir-opt %s -split-input-file -mlir-print-ir-after-failure -reconcile-unrealized-casts 2>&1 | FileCheck %s - -// CHECK-LABEL: @liveSingleCast -// CHECK-SAME: (%[[arg0:.*]]: i64) -> i32 -// CHECK: %[[liveCast:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i32 -// CHECK: return %[[liveCast]] : i32 - -func.func @liveSingleCast(%arg0: i64) -> i32 { - %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32 - return %0 : i32 -} - -// ----- - -// CHECK-LABEL: @liveChain -// CHECK-SAME: (%[[arg0:.*]]: i64) -> i32 -// CHECK: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i1 -// CHECK: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i1 to i32 -// CHECK: return %[[cast1]] : i32 - -func.func @liveChain(%arg0: i64) -> i32 { - %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i1 - %1 = builtin.unrealized_conversion_cast %0 : i1 to i32 - return %1 : i32 -} - -// ----- - -// CHECK-LABEL: @liveBifurcation -// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64 -// CHECK: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i32 -// CHECK: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i32 to i64 -// CHECK: %[[cast2:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i32 to i1 -// CHECK: %[[extsi:.*]] = arith.extsi %[[cast2]] : i1 to i64 -// CHECK: %[[result:.*]] = arith.addi %[[cast1]], %[[extsi]] : i64 -// CHECK: return %[[result]] : i64 - -func.func @liveBifurcation(%arg0: i64) -> i64 { - %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32 - %1 = builtin.unrealized_conversion_cast %0 : i32 to i64 - %2 = builtin.unrealized_conversion_cast %0 : i32 to i1 - %3 = arith.extsi %2 : i1 to i64 - %4 = arith.addi %1, %3 : i64 - return %4 : i64 -} diff --git a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir index d71cbba1..3573114 100644 --- a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir +++ b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir @@ -103,3 +103,94 @@ func.func @unusedBifurcation(%arg0: i64) -> i64 { %4 = arith.addi %arg0, %3 : i64 return %4 : i64 } + +// ----- + +// CHECK-LABEL: @liveSingleCast +// CHECK-SAME: (%[[arg0:.*]]: i64) -> i32 +// CHECK: %[[liveCast:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i32 +// CHECK: return %[[liveCast]] : i32 + +func.func @liveSingleCast(%arg0: i64) -> i32 { + %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32 + return %0 : i32 +} + +// ----- + +// CHECK-LABEL: @liveChain +// CHECK-SAME: (%[[arg0:.*]]: i64) -> i32 +// CHECK: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i1 +// CHECK: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i1 to i32 +// CHECK: return %[[cast1]] : i32 + +func.func @liveChain(%arg0: i64) -> i32 { + %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i1 + %1 = builtin.unrealized_conversion_cast %0 : i1 to i32 + return %1 : i32 +} + +// ----- + +// CHECK-LABEL: @liveBifurcation +// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64 +// CHECK: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i32 +// CHECK: %[[cast2:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i32 to i1 +// CHECK: %[[extsi:.*]] = arith.extsi %[[cast2]] : i1 to i64 +// CHECK: %[[result:.*]] = arith.addi %[[arg0]], %[[extsi]] : i64 +// CHECK: return %[[result]] : i64 + +func.func @liveBifurcation(%arg0: i64) -> i64 { + %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32 + %1 = builtin.unrealized_conversion_cast %0 : i32 to i64 + %2 = builtin.unrealized_conversion_cast %0 : i32 to i1 + %3 = arith.extsi %2 : i1 to i64 + %4 = arith.addi %1, %3 : i64 + return %4 : i64 +} + +// ----- + +// CHECK-LABEL: func @deadNToOneCast( +// CHECK-NEXT: return +func.func @deadNToOneCast(%arg0: index, %arg1: index) { + %0 = builtin.unrealized_conversion_cast %arg0, %arg1 : index, index to i64 + return +} + +// ----- + +// CHECK-LABEL: func @swappingOperands( +// CHECK-SAME: %[[arg0:.*]]: index, %[[arg1:.*]]: index +// CHECK: %[[cast1:.*]]:2 = builtin.unrealized_conversion_cast %[[arg0]], %[[arg1]] +// CHECK: %[[cast2:.*]]:2 = builtin.unrealized_conversion_cast %[[cast1]]#1, %[[cast1]]#0 +// CHECK: %[[cast3:.*]]:2 = builtin.unrealized_conversion_cast %[[cast2]]#0, %[[cast2]]#1 +// CHECK: return %[[cast3]]#0, %[[cast3]]#1 +func.func @swappingOperands(%arg0: index, %arg1: index) -> (index, index) { + %0:2 = builtin.unrealized_conversion_cast %arg0, %arg1 : index, index to i64, i64 + %1:2 = builtin.unrealized_conversion_cast %0#1, %0#0 : i64, i64 to i32, i32 + %2:2 = builtin.unrealized_conversion_cast %1#0, %1#1 : i32, i32 to index, index + return %2#0, %2#1 : index, index +} + +// ----- + +// CHECK-LABEL: func @matchingOperands( +// CHECK-SAME: %[[arg0:.*]]: index, %[[arg1:.*]]: index +// CHECK: return %[[arg0]], %[[arg1]] +func.func @matchingOperands(%arg0: index, %arg1: index) -> (index, index) { + %0:2 = builtin.unrealized_conversion_cast %arg0, %arg1 : index, index to i64, i64 + %1:3 = builtin.unrealized_conversion_cast %0#0, %0#1 : i64, i64 to i32, i32, i32 + %2:2 = builtin.unrealized_conversion_cast %1#0, %1#1, %1#2 : i32, i32, i32 to index, index + return %2#0, %2#1 : index, index +} + +// ----- + +// CHECK-LABEL: func @emptyCast() +// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast to index +// CHECK: return %[[cast]] +func.func @emptyCast() -> index { + %0 = builtin.unrealized_conversion_cast to index + return %0 : index +} diff --git a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-target-tag.mlir b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-target-tag.mlir index cc9759f..f6d3387 100644 --- a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-target-tag.mlir +++ b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-target-tag.mlir @@ -1,11 +1,11 @@ // Note: We run CSE here to make the pattern matching more direct. -// RUN: mlir-opt %s -test-lower-to-llvm -cse | FileCheck %s +// RUN: mlir-opt %s -test-lower-to-llvm -cse -canonicalize | FileCheck %s // RUN: mlir-opt %s \ // RUN: -transform-preload-library="transform-library-paths=%p/lower-to-llvm-transform-symbol-def.mlir" \ // RUN: -transform-interpreter="debug-payload-root-tag=payload" \ -// RUN: -test-transform-dialect-erase-schedule -cse \ +// RUN: -test-transform-dialect-erase-schedule -cse -canonicalize \ // RUN: | FileCheck %s module attributes {transform.target_tag="payload"} { diff --git a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-top-level-named-sequence.mlir b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-top-level-named-sequence.mlir index ac4608e..a74553c 100644 --- a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-top-level-named-sequence.mlir +++ b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-top-level-named-sequence.mlir @@ -1,12 +1,12 @@ // Note: We run CSE here to make the pattern matching more direct. -// RUN: mlir-opt %s -test-lower-to-llvm -cse | FileCheck %s +// RUN: mlir-opt %s -test-lower-to-llvm -cse -canonicalize | FileCheck %s // RUN: mlir-opt %s \ // RUN: -transform-preload-library=transform-library-paths=%p/../Transform/include/Library/lower-to-llvm.mlir \ // RUN: -transform-interpreter="entry-point=entry_point" \ // RUN: -test-transform-dialect-erase-schedule \ -// RUN: -cse \ +// RUN: -cse -canonicalize \ // RUN: | FileCheck %s // Check that we properly lower to llvm memref operations that require to be |