aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Matthias Springer <me@m-sp.org> 2024-06-19 16:23:27 +0200
committer GitHub <noreply@github.com> 2024-06-19 16:23:27 +0200
commit 5b007582bb9142ccab47134e5d2de0dd451338af (patch)
tree f6d1528e601c834681d3f84626aff9b7dcf2a851
parent 2c1ae801e1b66a09a15028ae4ba614e0911eec00 (diff)
download llvm-5b007582bb9142ccab47134e5d2de0dd451338af.zip
llvm-5b007582bb9142ccab47134e5d2de0dd451338af.tar.gz
llvm-5b007582bb9142ccab47134e5d2de0dd451338af.tar.bz2
[mlir][Conversion] Generalize and fix crash in `reconcile-unrealized-casts` (#95700)
This commit fixes a crash in `-reconcile-unrealized-casts` when cast ops have multiple operands:

```
DialectConversion.cpp:1583: virtual void mlir::ConversionPatternRewriter::replaceOp(mlir::Operation *, mlir::ValueRange): Assertion `op->getNumResults() == newValues.size() && "incorrect # of replacement values"' failed.
```

This commit also generalizes the pass such that more ops are folded. In particular (letters indicate types):

```
  A
 / \
B   C
|
A
```

Previously, such IR was not folded at all. The `A -> B -> A` type cast cycle is now folded away. (The `A -> C` cast stays in place.)

This commit also turns the pass from a dialect conversion into a simple IR walk. The pattern and its `populate` function are removed. The pattern was a (non-conversion) rewrite pattern, but used in a dialect conversion, which is generally not safe. In particular, the rewrite pattern may traverse IR that was already scheduled for erasure by the dialect conversion.

Note: Some test cases changed slightly (NFC) because the new pass implementation no longer attempts to fold ops.

Note for LLVM integration: If your pipeline uses the removed `populate` function, try to simply remove that function call. Chances are you may not need it at all. If it is in fact needed, run the `-reconcile-unrealized-casts` pass right after the pass that used to populate the pattern.

---------

Co-authored-by: Maksim Levental <maksim.levental@gmail.com>
Co-authored-by: Markus Böck <markus.boeck02@gmail.com>
-rw-r--r-- mlir/include/mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h | 4
-rw-r--r-- mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp | 168
-rw-r--r-- mlir/test/Conversion/FuncToLLVM/calling-convention.mlir | 23
-rw-r--r-- mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts-failure.mlir | 45
-rw-r--r-- mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir | 91
-rw-r--r-- mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-target-tag.mlir | 4
-rw-r--r-- mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-top-level-named-sequence.mlir | 4
7 files changed, 180 insertions, 159 deletions
diff --git a/mlir/include/mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h b/mlir/include/mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h
index 9df1177..533a57a 100644
--- a/mlir/include/mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h
+++ b/mlir/include/mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h
@@ -21,10 +21,6 @@ class RewritePatternSet;
/// Creates a pass that eliminates noop `unrealized_conversion_cast` operation
/// sequences.
std::unique_ptr<Pass> createReconcileUnrealizedCastsPass();
-
-/// Populates `patterns` with rewrite patterns that eliminate noop
-/// `unrealized_conversion_cast` operation sequences.
-void populateReconcileUnrealizedCastsPatterns(RewritePatternSet &patterns);
} // namespace mlir
#endif // MLIR_CONVERSION_RECONCILEUNREALIZEDCASTS_RECONCILEUNREALIZEDCASTS_H_
diff --git a/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp b/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
index 86a3d8b..12e0029 100644
--- a/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
+++ b/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
@@ -9,9 +9,7 @@
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
#define GEN_PASS_DEF_RECONCILEUNREALIZEDCASTS
@@ -22,113 +20,87 @@ using namespace mlir;
namespace {
-/// Folds the DAGs of `unrealized_conversion_cast`s that have as exit types
-/// the same as the input ones.
-/// For example, the DAGs `A -> B -> C -> B -> A` and `A -> B -> C -> A`
-/// represent a noop within the IR, and thus the initial input values can be
-/// propagated.
-/// The same does not hold for 'open' chains of casts, such as
-/// `A -> B -> C`. In this last case there is no cycle among the types and thus
-/// the conversion is incomplete. The same hold for 'closed' chains like
-/// `A -> B -> A`, but with the result of type `B` being used by some non-cast
-/// operations.
-/// Bifurcations (that is when a chain starts in between of another one) are
-/// also taken into considerations, and all the above considerations remain
-/// valid.
-/// Special corner cases such as dead casts or single casts with same input and
-/// output types are also covered.
-struct UnrealizedConversionCastPassthrough
- : public OpRewritePattern<UnrealizedConversionCastOp> {
- using OpRewritePattern<UnrealizedConversionCastOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(UnrealizedConversionCastOp op,
- PatternRewriter &rewriter) const override {
- // The nodes that either are not used by any operation or have at least
- // one user that is not an unrealized cast.
- DenseSet<UnrealizedConversionCastOp> exitNodes;
-
- // The nodes whose users are all unrealized casts
- DenseSet<UnrealizedConversionCastOp> intermediateNodes;
-
- // Stack used for the depth-first traversal of the use-def DAG.
- SmallVector<UnrealizedConversionCastOp, 2> visitStack;
- visitStack.push_back(op);
-
- while (!visitStack.empty()) {
- UnrealizedConversionCastOp current = visitStack.pop_back_val();
- auto users = current->getUsers();
- bool isLive = false;
-
- for (Operation *user : users) {
- if (auto other = dyn_cast<UnrealizedConversionCastOp>(user)) {
- if (other.getInputs() != current.getOutputs())
- return rewriter.notifyMatchFailure(
- op, "mismatching values propagation");
- } else {
- isLive = true;
- }
-
- // Continue traversing the DAG of unrealized casts
- if (auto other = dyn_cast<UnrealizedConversionCastOp>(user))
- visitStack.push_back(other);
- }
-
- // If the cast is live, then we need to check if the results of the last
- // cast have the same type of the root inputs. It this is the case (e.g.
- // `{A -> B, B -> A}`, but also `{A -> A}`), then the cycle is just a
- // no-op and the inputs can be forwarded. If it's not (e.g.
- // `{A -> B, B -> C}`, `{A -> B}`), then the cast chain is incomplete.
-
- bool isCycle = current.getResultTypes() == op.getInputs().getTypes();
-
- if (isLive && !isCycle)
- return rewriter.notifyMatchFailure(op,
- "live unrealized conversion cast");
-
- bool isExitNode = users.empty() || isLive;
-
- if (isExitNode) {
- exitNodes.insert(current);
- } else {
- intermediateNodes.insert(current);
- }
- }
-
- // Replace the sink nodes with the root input values
- for (UnrealizedConversionCastOp exitNode : exitNodes)
- rewriter.replaceOp(exitNode, op.getInputs());
-
- // Erase all the other casts belonging to the DAG
- for (UnrealizedConversionCastOp castOp : intermediateNodes)
- rewriter.eraseOp(castOp);
-
- return success();
- }
-};
-
/// Pass to simplify and eliminate unrealized conversion casts.
+///
+/// This pass processes unrealized_conversion_cast ops in a worklist-driven
+/// fashion. For each matched cast op, if the chain of input casts eventually
+/// reaches a cast op where the input types match the output types of the
+/// matched op, replace the matched op with the inputs.
+///
+/// Example:
+/// %1 = unrealized_conversion_cast %0 : !A to !B
+/// %2 = unrealized_conversion_cast %1 : !B to !C
+/// %3 = unrealized_conversion_cast %2 : !C to !A
+///
+/// In the above example, %0 can be used instead of %3 and all cast ops are
+/// folded away.
struct ReconcileUnrealizedCasts
: public impl::ReconcileUnrealizedCastsBase<ReconcileUnrealizedCasts> {
ReconcileUnrealizedCasts() = default;
void runOnOperation() override {
- RewritePatternSet patterns(&getContext());
- populateReconcileUnrealizedCastsPatterns(patterns);
- ConversionTarget target(getContext());
- target.addIllegalOp<UnrealizedConversionCastOp>();
- if (failed(applyPartialConversion(getOperation(), target,
- std::move(patterns))))
- signalPassFailure();
+ // Gather all unrealized_conversion_cast ops.
+ SetVector<UnrealizedConversionCastOp> worklist;
+ getOperation()->walk(
+ [&](UnrealizedConversionCastOp castOp) { worklist.insert(castOp); });
+
+ // Helper function that adds all operands to the worklist that are an
+ // unrealized_conversion_cast op result.
+ auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
+ for (Value v : castOp.getInputs())
+ if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
+ worklist.insert(inputCastOp);
+ };
+
+ // Helper function that returns the unrealized_conversion_cast op that
+ // defines all inputs of the given op (in the same order). Return "nullptr"
+ // if there is no such op.
+ auto getInputCast =
+ [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
+ if (castOp.getInputs().empty())
+ return {};
+ auto inputCastOp = castOp.getInputs()
+ .front()
+ .getDefiningOp<UnrealizedConversionCastOp>();
+ if (!inputCastOp)
+ return {};
+ if (inputCastOp.getOutputs() != castOp.getInputs())
+ return {};
+ return inputCastOp;
+ };
+
+ // Process ops in the worklist bottom-to-top.
+ while (!worklist.empty()) {
+ UnrealizedConversionCastOp castOp = worklist.pop_back_val();
+ if (castOp->use_empty()) {
+ // DCE: If the op has no users, erase it. Add the operands to the
+ // worklist to find additional DCE opportunities.
+ enqueueOperands(castOp);
+ castOp->erase();
+ continue;
+ }
+
+ // Traverse the chain of input cast ops to see if an op with the same
+ // input types can be found.
+ UnrealizedConversionCastOp nextCast = castOp;
+ while (nextCast) {
+ if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
+ // Found a cast where the input types match the output types of the
+ // matched op. We can directly use those inputs and the matched op can
+ // be removed.
+ enqueueOperands(castOp);
+ castOp.replaceAllUsesWith(nextCast.getInputs());
+ castOp->erase();
+ break;
+ }
+ nextCast = getInputCast(nextCast);
+ }
+ }
}
};
} // namespace
-void mlir::populateReconcileUnrealizedCastsPatterns(
- RewritePatternSet &patterns) {
- patterns.add<UnrealizedConversionCastPassthrough>(patterns.getContext());
-}
-
std::unique_ptr<Pass> mlir::createReconcileUnrealizedCastsPass() {
return std::make_unique<ReconcileUnrealizedCasts>();
}
diff --git a/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir b/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
index 7cdb89e..18734d1 100644
--- a/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
+++ b/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
@@ -127,7 +127,7 @@ func.func @return_var_memref_caller(%arg0: memref<4x3xf32>) {
// CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
// CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]]
// CHECK: %[[RANK:.*]] = llvm.extractvalue %[[CALL_RES]][0] : !llvm.struct<(i64, ptr)>
- // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[RANK]], %[[TWO]]
+ // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK]]
// CHECK: %[[DOUBLE_RANK_INC:.*]] = llvm.add %[[DOUBLE_RANK]], %[[ONE]]
// CHECK: %[[TABLES_SIZE:.*]] = llvm.mul %[[DOUBLE_RANK_INC]], %[[IDX_SIZE]]
// CHECK: %[[ALLOC_SIZE:.*]] = llvm.add %[[DOUBLE_PTR_SIZE]], %[[TABLES_SIZE]]
@@ -159,14 +159,17 @@ func.func @return_var_memref(%arg0: memref<4x3xf32>) -> memref<*xf32> attributes
// CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
// CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]]
- // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[RANK]], %[[TWO]]
+ // CHECK: %[[RANK_EXTR:.*]] = llvm.extractvalue %[[DESC_2]][0]
+ // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK_EXTR]]
// CHECK: %[[DOUBLE_RANK_INC:.*]] = llvm.add %[[DOUBLE_RANK]], %[[ONE]]
// CHECK: %[[TABLES_SIZE:.*]] = llvm.mul %[[DOUBLE_RANK_INC]], %[[IDX_SIZE]]
// CHECK: %[[ALLOC_SIZE:.*]] = llvm.add %[[DOUBLE_PTR_SIZE]], %[[TABLES_SIZE]]
// CHECK: %[[ALLOCATED:.*]] = llvm.call @malloc(%[[ALLOC_SIZE]])
- // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED]], %[[ALLOCA]], %[[ALLOC_SIZE]]) <{isVolatile = false}>
+ // CHECK: %[[ALLOCA_EXTRACTED:.*]] = llvm.extractvalue %[[DESC_2]][1]
+ // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED]], %[[ALLOCA_EXTRACTED]], %[[ALLOC_SIZE]]) <{isVolatile = false}>
// CHECK: %[[NEW_DESC:.*]] = llvm.mlir.undef : !llvm.struct<(i64, ptr)>
- // CHECK: %[[NEW_DESC_1:.*]] = llvm.insertvalue %[[RANK]], %[[NEW_DESC]][0]
+ // CHECK: %[[RANK_EXTRACTED:.*]] = llvm.extractvalue %[[DESC_2]][0]
+ // CHECK: %[[NEW_DESC_1:.*]] = llvm.insertvalue %[[RANK_EXTRACTED]], %[[NEW_DESC]][0]
// CHECK: %[[NEW_DESC_2:.*]] = llvm.insertvalue %[[ALLOCATED]], %[[NEW_DESC_1]][1]
// CHECK: llvm.return %[[NEW_DESC_2]]
return %0 : memref<*xf32>
@@ -218,13 +221,15 @@ func.func @return_two_var_memref(%arg0: memref<4x3xf32>) -> (memref<*xf32>, memr
// convention requires the caller to free them and the caller cannot know
// whether they are the same value or not.
// CHECK: %[[ALLOCATED_1:.*]] = llvm.call @malloc(%{{.*}})
- // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED_1]], %[[ALLOCA]], %{{.*}}) <{isVolatile = false}>
+ // CHECK: %[[ALLOCA_EXTRACTED:.*]] = llvm.extractvalue %[[DESC_2]][1]
+ // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED_1]], %[[ALLOCA_EXTRACTED]], %{{.*}}) <{isVolatile = false}>
// CHECK: %[[RES_1:.*]] = llvm.mlir.undef
// CHECK: %[[RES_11:.*]] = llvm.insertvalue %{{.*}}, %[[RES_1]][0]
// CHECK: %[[RES_12:.*]] = llvm.insertvalue %[[ALLOCATED_1]], %[[RES_11]][1]
// CHECK: %[[ALLOCATED_2:.*]] = llvm.call @malloc(%{{.*}})
- // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED_2]], %[[ALLOCA]], %{{.*}}) <{isVolatile = false}>
+ // CHECK: %[[ALLOCA_EXTRACTED:.*]] = llvm.extractvalue %[[DESC_2]][1]
+ // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED_2]], %[[ALLOCA_EXTRACTED]], %{{.*}}) <{isVolatile = false}>
// CHECK: %[[RES_2:.*]] = llvm.mlir.undef
// CHECK: %[[RES_21:.*]] = llvm.insertvalue %{{.*}}, %[[RES_2]][0]
// CHECK: %[[RES_22:.*]] = llvm.insertvalue %[[ALLOCATED_2]], %[[RES_21]][1]
@@ -265,7 +270,8 @@ func.func @bare_ptr_calling_conv(%arg0: memref<4x3xf32>, %arg1 : index, %arg2 :
// CHECK: llvm.store %{{.*}}, %[[STOREPTR]]
memref.store %arg3, %arg0[%arg1, %arg2] : memref<4x3xf32>
- // CHECK: llvm.return %[[ARG0]]
+ // CHECK: %[[EXTRACT_MEMREF:.*]] = llvm.extractvalue %[[INSERT_STRIDE1]][0]
+ // CHECK: llvm.return %[[EXTRACT_MEMREF]]
return %arg0 : memref<4x3xf32>
}
@@ -298,9 +304,10 @@ func.func @bare_ptr_calling_conv_multiresult(%arg0: memref<4x3xf32>, %arg1 : ind
// CHECK: %[[RETURN0:.*]] = llvm.load %[[LOADPTR]]
%0 = memref.load %arg0[%arg1, %arg2] : memref<4x3xf32>
+ // CHECK: %[[EXTRACT_MEMREF:.*]] = llvm.extractvalue %[[INSERT_STRIDE1]][0]
// CHECK: %[[RETURN_DESC:.*]] = llvm.mlir.undef : !llvm.struct<(f32, ptr)>
// CHECK: %[[INSERT_RETURN0:.*]] = llvm.insertvalue %[[RETURN0]], %[[RETURN_DESC]][0]
- // CHECK: %[[INSERT_RETURN1:.*]] = llvm.insertvalue %[[ARG0]], %[[INSERT_RETURN0]][1]
+ // CHECK: %[[INSERT_RETURN1:.*]] = llvm.insertvalue %[[EXTRACT_MEMREF]], %[[INSERT_RETURN0]][1]
// CHECK: llvm.return %[[INSERT_RETURN1]]
return %0, %arg0 : f32, memref<4x3xf32>
}
diff --git a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts-failure.mlir b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts-failure.mlir
deleted file mode 100644
index f5ceb29..0000000
--- a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts-failure.mlir
+++ /dev/null
@@ -1,45 +0,0 @@
-// RUN: not mlir-opt %s -split-input-file -mlir-print-ir-after-failure -reconcile-unrealized-casts 2>&1 | FileCheck %s
-
-// CHECK-LABEL: @liveSingleCast
-// CHECK-SAME: (%[[arg0:.*]]: i64) -> i32
-// CHECK: %[[liveCast:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i32
-// CHECK: return %[[liveCast]] : i32
-
-func.func @liveSingleCast(%arg0: i64) -> i32 {
- %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
- return %0 : i32
-}
-
-// -----
-
-// CHECK-LABEL: @liveChain
-// CHECK-SAME: (%[[arg0:.*]]: i64) -> i32
-// CHECK: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i1
-// CHECK: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i1 to i32
-// CHECK: return %[[cast1]] : i32
-
-func.func @liveChain(%arg0: i64) -> i32 {
- %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i1
- %1 = builtin.unrealized_conversion_cast %0 : i1 to i32
- return %1 : i32
-}
-
-// -----
-
-// CHECK-LABEL: @liveBifurcation
-// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
-// CHECK: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i32
-// CHECK: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i32 to i64
-// CHECK: %[[cast2:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i32 to i1
-// CHECK: %[[extsi:.*]] = arith.extsi %[[cast2]] : i1 to i64
-// CHECK: %[[result:.*]] = arith.addi %[[cast1]], %[[extsi]] : i64
-// CHECK: return %[[result]] : i64
-
-func.func @liveBifurcation(%arg0: i64) -> i64 {
- %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
- %1 = builtin.unrealized_conversion_cast %0 : i32 to i64
- %2 = builtin.unrealized_conversion_cast %0 : i32 to i1
- %3 = arith.extsi %2 : i1 to i64
- %4 = arith.addi %1, %3 : i64
- return %4 : i64
-}
diff --git a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir
index d71cbba1..3573114 100644
--- a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir
+++ b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir
@@ -103,3 +103,94 @@ func.func @unusedBifurcation(%arg0: i64) -> i64 {
%4 = arith.addi %arg0, %3 : i64
return %4 : i64
}
+
+// -----
+
+// CHECK-LABEL: @liveSingleCast
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i32
+// CHECK: %[[liveCast:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i32
+// CHECK: return %[[liveCast]] : i32
+
+func.func @liveSingleCast(%arg0: i64) -> i32 {
+ %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+ return %0 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @liveChain
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i32
+// CHECK: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i1
+// CHECK: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i1 to i32
+// CHECK: return %[[cast1]] : i32
+
+func.func @liveChain(%arg0: i64) -> i32 {
+ %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i1
+ %1 = builtin.unrealized_conversion_cast %0 : i1 to i32
+ return %1 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @liveBifurcation
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i32
+// CHECK: %[[cast2:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i32 to i1
+// CHECK: %[[extsi:.*]] = arith.extsi %[[cast2]] : i1 to i64
+// CHECK: %[[result:.*]] = arith.addi %[[arg0]], %[[extsi]] : i64
+// CHECK: return %[[result]] : i64
+
+func.func @liveBifurcation(%arg0: i64) -> i64 {
+ %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+ %1 = builtin.unrealized_conversion_cast %0 : i32 to i64
+ %2 = builtin.unrealized_conversion_cast %0 : i32 to i1
+ %3 = arith.extsi %2 : i1 to i64
+ %4 = arith.addi %1, %3 : i64
+ return %4 : i64
+}
+
+// -----
+
+// CHECK-LABEL: func @deadNToOneCast(
+// CHECK-NEXT: return
+func.func @deadNToOneCast(%arg0: index, %arg1: index) {
+ %0 = builtin.unrealized_conversion_cast %arg0, %arg1 : index, index to i64
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @swappingOperands(
+// CHECK-SAME: %[[arg0:.*]]: index, %[[arg1:.*]]: index
+// CHECK: %[[cast1:.*]]:2 = builtin.unrealized_conversion_cast %[[arg0]], %[[arg1]]
+// CHECK: %[[cast2:.*]]:2 = builtin.unrealized_conversion_cast %[[cast1]]#1, %[[cast1]]#0
+// CHECK: %[[cast3:.*]]:2 = builtin.unrealized_conversion_cast %[[cast2]]#0, %[[cast2]]#1
+// CHECK: return %[[cast3]]#0, %[[cast3]]#1
+func.func @swappingOperands(%arg0: index, %arg1: index) -> (index, index) {
+ %0:2 = builtin.unrealized_conversion_cast %arg0, %arg1 : index, index to i64, i64
+ %1:2 = builtin.unrealized_conversion_cast %0#1, %0#0 : i64, i64 to i32, i32
+ %2:2 = builtin.unrealized_conversion_cast %1#0, %1#1 : i32, i32 to index, index
+ return %2#0, %2#1 : index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @matchingOperands(
+// CHECK-SAME: %[[arg0:.*]]: index, %[[arg1:.*]]: index
+// CHECK: return %[[arg0]], %[[arg1]]
+func.func @matchingOperands(%arg0: index, %arg1: index) -> (index, index) {
+ %0:2 = builtin.unrealized_conversion_cast %arg0, %arg1 : index, index to i64, i64
+ %1:3 = builtin.unrealized_conversion_cast %0#0, %0#1 : i64, i64 to i32, i32, i32
+ %2:2 = builtin.unrealized_conversion_cast %1#0, %1#1, %1#2 : i32, i32, i32 to index, index
+ return %2#0, %2#1 : index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @emptyCast()
+// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast to index
+// CHECK: return %[[cast]]
+func.func @emptyCast() -> index {
+ %0 = builtin.unrealized_conversion_cast to index
+ return %0 : index
+}
diff --git a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-target-tag.mlir b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-target-tag.mlir
index cc9759f..f6d3387 100644
--- a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-target-tag.mlir
+++ b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-target-tag.mlir
@@ -1,11 +1,11 @@
// Note: We run CSE here to make the pattern matching more direct.
-// RUN: mlir-opt %s -test-lower-to-llvm -cse | FileCheck %s
+// RUN: mlir-opt %s -test-lower-to-llvm -cse -canonicalize | FileCheck %s
// RUN: mlir-opt %s \
// RUN: -transform-preload-library="transform-library-paths=%p/lower-to-llvm-transform-symbol-def.mlir" \
// RUN: -transform-interpreter="debug-payload-root-tag=payload" \
-// RUN: -test-transform-dialect-erase-schedule -cse \
+// RUN: -test-transform-dialect-erase-schedule -cse -canonicalize \
// RUN: | FileCheck %s
module attributes {transform.target_tag="payload"} {
diff --git a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-top-level-named-sequence.mlir b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-top-level-named-sequence.mlir
index ac4608e..a74553c 100644
--- a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-top-level-named-sequence.mlir
+++ b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e-with-top-level-named-sequence.mlir
@@ -1,12 +1,12 @@
// Note: We run CSE here to make the pattern matching more direct.
-// RUN: mlir-opt %s -test-lower-to-llvm -cse | FileCheck %s
+// RUN: mlir-opt %s -test-lower-to-llvm -cse -canonicalize | FileCheck %s
// RUN: mlir-opt %s \
// RUN: -transform-preload-library=transform-library-paths=%p/../Transform/include/Library/lower-to-llvm.mlir \
// RUN: -transform-interpreter="entry-point=entry_point" \
// RUN: -test-transform-dialect-erase-schedule \
-// RUN: -cse \
+// RUN: -cse -canonicalize \
// RUN: | FileCheck %s
// Check that we properly lower to llvm memref operations that require to be