aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGiuseppe Rossini <giuseppe.rossini@amd.com>2024-07-02 17:12:33 +0100
committerGitHub <noreply@github.com>2024-07-02 17:12:33 +0100
commit6c3897d90eda4c39789ac9f4efa51db46734a249 (patch)
tree60ae5326a5b02b9d45f12a5e630a7ca0755fdce4
parent123beb7926651217024e5db58b93ab9e8f3c77c7 (diff)
downloadllvm-6c3897d90eda4c39789ac9f4efa51db46734a249.zip
llvm-6c3897d90eda4c39789ac9f4efa51db46734a249.tar.gz
llvm-6c3897d90eda4c39789ac9f4efa51db46734a249.tar.bz2
Fix block merging (#96871)
With this PR I am trying to address: https://github.com/llvm/llvm-project/issues/63230. What changed: - While merging identical blocks, don't add a block argument if it is "identical" to another block argument. I.e., if the two block arguments refer to the same `Value`. The operations operands in the block will point to the argument we already inserted - After merged the blocks, get rid of "unnecessary" arguments. I.e., if all the predecessors pass the same block argument, there is no need to pass it as an argument. - This last simplification clashed with `BufferDeallocationSimplification`. The reason, I think, is that the two simplifications are clashing. I.e., `BufferDeallocationSimplification` contains an analysis based on the block structure. If we simplify the block structure (by merging and/or dropping block arguments) the analysis is invalid . The solution I found is to do a more prudent simplification when running that pass. **Note**: many tests are still not passing. But I wanted to submit the code before changing all the tests (and probably adding a couple), so that we can agree in principle on the algorithm/design.
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp9
-rw-r--r--mlir/lib/Transforms/Utils/RegionUtils.cpp144
-rw-r--r--mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir20
-rw-r--r--mlir/test/Dialect/Linalg/detensorize_entry_block.mlir6
-rw-r--r--mlir/test/Dialect/Linalg/detensorize_if.mlir67
-rw-r--r--mlir/test/Dialect/Linalg/detensorize_while.mlir12
-rw-r--r--mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir12
-rw-r--r--mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir4
-rw-r--r--mlir/test/Transforms/canonicalize-block-merge.mlir6
-rw-r--r--mlir/test/Transforms/canonicalize-dce.mlir8
-rw-r--r--mlir/test/Transforms/make-isolated-from-above.mlir18
-rw-r--r--mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir76
12 files changed, 289 insertions, 93 deletions
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index 954485c..5227b22 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -463,10 +463,15 @@ struct BufferDeallocationSimplificationPass
SplitDeallocWhenNotAliasingAnyOther,
RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
analysis);
+ // We don't want that the block structure changes invalidating the
+ // `BufferOriginAnalysis` so we apply the rewrites witha `Normal` level of
+ // region simplification
+ GreedyRewriteConfig config;
+ config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal;
populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
- if (failed(
- applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
+ config)))
signalPassFailure();
}
};
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 4c0f15b..412e245 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -9,6 +9,7 @@
#include "mlir/Transforms/RegionUtils.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Block.h"
+#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
@@ -16,11 +17,15 @@
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
#include <deque>
+#include <iterator>
using namespace mlir;
@@ -699,9 +704,8 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
blockIterators.push_back(mergeBlock->begin());
// Update each of the predecessor terminators with the new arguments.
- SmallVector<SmallVector<Value, 8>, 2> newArguments(
- 1 + blocksToMerge.size(),
- SmallVector<Value, 8>(operandsToMerge.size()));
+ SmallVector<SmallVector<Value, 8>, 2> newArguments(1 + blocksToMerge.size(),
+ SmallVector<Value, 8>());
unsigned curOpIndex = 0;
for (const auto &it : llvm::enumerate(operandsToMerge)) {
unsigned nextOpOffset = it.value().first - curOpIndex;
@@ -712,13 +716,22 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
Block::iterator &blockIter = blockIterators[i];
std::advance(blockIter, nextOpOffset);
auto &operand = blockIter->getOpOperand(it.value().second);
- newArguments[i][it.index()] = operand.get();
-
- // Update the operand and insert an argument if this is the leader.
- if (i == 0) {
- Value operandVal = operand.get();
- operand.set(leaderBlock->addArgument(operandVal.getType(),
- operandVal.getLoc()));
+ Value operandVal = operand.get();
+ Value *it = std::find(newArguments[i].begin(), newArguments[i].end(),
+ operandVal);
+ if (it == newArguments[i].end()) {
+ newArguments[i].push_back(operandVal);
+ // Update the operand and insert an argument if this is the leader.
+ if (i == 0) {
+ operand.set(leaderBlock->addArgument(operandVal.getType(),
+ operandVal.getLoc()));
+ }
+ } else if (i == 0) {
+ // If this is the leader, update the operand but do not insert a new
+ // argument. Instead, the opearand should point to one of the
+ // arguments we already passed (and that contained `operandVal`)
+ operand.set(leaderBlock->getArgument(
+ std::distance(newArguments[i].begin(), it)));
}
}
}
@@ -818,6 +831,109 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
return success(anyChanged);
}
+static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
+ Block &block) {
+ SmallVector<size_t> argsToErase;
+
+ // Go through the arguments of the block
+ for (size_t argIdx = 0; argIdx < block.getNumArguments(); argIdx++) {
+ bool sameArg = true;
+ Value commonValue;
+
+ // Go through the block predecessor and flag if they pass to the block
+ // different values for the same argument
+ for (auto predIt = block.pred_begin(), predE = block.pred_end();
+ predIt != predE; ++predIt) {
+ auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator());
+ if (!branch) {
+ sameArg = false;
+ break;
+ }
+ unsigned succIndex = predIt.getSuccessorIndex();
+ SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
+ auto operands = succOperands.getForwardedOperands();
+ if (!commonValue) {
+ commonValue = operands[argIdx];
+ } else {
+ if (operands[argIdx] != commonValue) {
+ sameArg = false;
+ break;
+ }
+ }
+ }
+
+ // If they are passing the same value, drop the argument
+ if (commonValue && sameArg) {
+ argsToErase.push_back(argIdx);
+
+ // Remove the argument from the block
+ Value argVal = block.getArgument(argIdx);
+ rewriter.replaceAllUsesWith(argVal, commonValue);
+ }
+ }
+
+ // Remove the arguments
+ for (auto argIdx : llvm::reverse(argsToErase)) {
+ block.eraseArgument(argIdx);
+
+ // Remove the argument from the branch ops
+ for (auto predIt = block.pred_begin(), predE = block.pred_end();
+ predIt != predE; ++predIt) {
+ auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
+ unsigned succIndex = predIt.getSuccessorIndex();
+ SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
+ succOperands.erase(argIdx);
+ }
+ }
+ return success(!argsToErase.empty());
+}
+
+/// This optimization drops redundant argument to blocks. I.e., if a given
+/// argument to a block receives the same value from each of the block
+/// predecessors, we can remove the argument from the block and use directly the
+/// original value. This is a simple example:
+///
+/// %cond = llvm.call @rand() : () -> i1
+/// %val0 = llvm.mlir.constant(1 : i64) : i64
+/// %val1 = llvm.mlir.constant(2 : i64) : i64
+/// %val2 = llvm.mlir.constant(3 : i64) : i64
+/// llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
+/// : i64)
+///
+/// ^bb1(%arg0 : i64, %arg1 : i64):
+/// llvm.call @foo(%arg0, %arg1)
+///
+/// The previous IR can be rewritten as:
+/// %cond = llvm.call @rand() : () -> i1
+/// %val0 = llvm.mlir.constant(1 : i64) : i64
+/// %val1 = llvm.mlir.constant(2 : i64) : i64
+/// %val2 = llvm.mlir.constant(3 : i64) : i64
+/// llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
+///
+/// ^bb1(%arg0 : i64):
+/// llvm.call @foo(%val0, %arg0)
+///
+static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
+ MutableArrayRef<Region> regions) {
+ llvm::SmallSetVector<Region *, 1> worklist;
+ for (auto &region : regions)
+ worklist.insert(&region);
+ bool anyChanged = false;
+ while (!worklist.empty()) {
+ Region *region = worklist.pop_back_val();
+
+ // Add any nested regions to the worklist.
+ for (Block &block : *region) {
+ anyChanged = succeeded(dropRedundantArguments(rewriter, block));
+
+ for (auto &op : block)
+ for (auto &nestedRegion : op.getRegions())
+ worklist.insert(&nestedRegion);
+ }
+ }
+ return success(anyChanged);
+}
+
//===----------------------------------------------------------------------===//
// Region Simplification
//===----------------------------------------------------------------------===//
@@ -832,8 +948,12 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
bool mergedIdenticalBlocks = false;
- if (mergeBlocks)
+ bool droppedRedundantArguments = false;
+ if (mergeBlocks) {
mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
+ droppedRedundantArguments =
+ succeeded(dropRedundantArguments(rewriter, regions));
+ }
return success(eliminatedBlocks || eliminatedOpsOrArgs ||
- mergedIdenticalBlocks);
+ mergedIdenticalBlocks || droppedRedundantArguments);
}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
index 5e8104f..8e14990 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
@@ -178,7 +178,7 @@ func.func @condBranchDynamicTypeNested(
// CHECK-NEXT: ^bb1
// CHECK-NOT: bufferization.dealloc
// CHECK-NOT: bufferization.clone
-// CHECK: cf.br ^bb5([[ARG1]], %false{{[0-9_]*}} :
+// CHECK: cf.br ^bb6([[ARG1]], %false{{[0-9_]*}} :
// CHECK: ^bb2([[IDX:%.*]]:{{.*}})
// CHECK: [[ALLOC1:%.*]] = memref.alloc([[IDX]])
// CHECK-NEXT: test.buffer_based
@@ -186,20 +186,24 @@ func.func @condBranchDynamicTypeNested(
// CHECK-NEXT: [[OWN:%.+]] = arith.select [[ARG0]], [[ARG0]], [[NOT_ARG0]]
// CHECK-NOT: bufferization.dealloc
// CHECK-NOT: bufferization.clone
-// CHECK: cf.cond_br{{.*}}, ^bb3, ^bb3
+// CHECK: cf.cond_br{{.*}}, ^bb3, ^bb4
// CHECK-NEXT: ^bb3:
// CHECK-NOT: bufferization.dealloc
// CHECK-NOT: bufferization.clone
-// CHECK: cf.br ^bb4([[ALLOC1]], [[OWN]]
-// CHECK-NEXT: ^bb4([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}})
+// CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]]
+// CHECK-NEXT: ^bb4:
// CHECK-NOT: bufferization.dealloc
// CHECK-NOT: bufferization.clone
-// CHECK: cf.br ^bb5([[ALLOC2]], [[COND1]]
-// CHECK-NEXT: ^bb5([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}})
+// CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]]
+// CHECK-NEXT: ^bb5([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}})
+// CHECK-NOT: bufferization.dealloc
+// CHECK-NOT: bufferization.clone
+// CHECK: cf.br ^bb6([[ALLOC2]], [[COND1]]
+// CHECK-NEXT: ^bb6([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}})
// CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC4]]
// CHECK-NEXT: [[OWN:%.+]]:2 = bufferization.dealloc ([[BASE]] :{{.*}}) if ([[COND2]]) retain ([[ALLOC4]], [[ARG2]] :
-// CHECK: cf.br ^bb6([[ALLOC4]], [[OWN]]#0
-// CHECK-NEXT: ^bb6([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}})
+// CHECK: cf.br ^bb7([[ALLOC4]], [[OWN]]#0
+// CHECK-NEXT: ^bb7([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}})
// CHECK: test.copy
// CHECK: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC5]]
// CHECK-NEXT: bufferization.dealloc ([[BASE]] : {{.*}}) if ([[COND3]])
diff --git a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
index d1a8922..50a2d6b 100644
--- a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
@@ -15,7 +15,7 @@ func.func @main(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK-LABEL: @main
// CHECK-SAME: (%[[ARG0:.+]]: tensor<f32>) -> tensor<f32>
// CHECK: %[[EXTRACTED:.+]] = tensor.extract %[[ARG0]][] : tensor<f32>
-// CHECK: cf.br ^{{.*}}(%[[EXTRACTED]] : f32)
-// CHECK: ^{{.*}}(%[[ARG1:.+]]: f32):
-// CHECK: %[[ELEMENTS:.+]] = tensor.from_elements %[[ARG1]] : tensor<f32>
+// CHECK: cf.br ^{{.*}}
+// CHECK: ^{{.*}}:
+// CHECK: %[[ELEMENTS:.+]] = tensor.from_elements %[[EXTRACTED]] : tensor<f32>
// CHECK: return %[[ELEMENTS]] : tensor<f32>
diff --git a/mlir/test/Dialect/Linalg/detensorize_if.mlir b/mlir/test/Dialect/Linalg/detensorize_if.mlir
index 8d17763..c728ad2 100644
--- a/mlir/test/Dialect/Linalg/detensorize_if.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_if.mlir
@@ -42,18 +42,15 @@ func.func @main() -> (tensor<i32>) attributes {} {
}
// CHECK-LABEL: func @main()
-// CHECK-DAG: arith.constant 0
-// CHECK-DAG: arith.constant 10
-// CHECK: cf.br ^[[bb1:.*]](%{{.*}}: i32)
-// CHECK-NEXT: ^[[bb1]](%{{.*}}: i32):
-// CHECK-NEXT: arith.cmpi slt, %{{.*}}, %{{.*}}
-// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32)
-// CHECK-NEXT: ^[[bb2]](%{{.*}}: i32)
-// CHECK-NEXT: arith.addi %{{.*}}, %{{.*}}
-// CHECK-NEXT: cf.br ^[[bb3:.*]](%{{.*}} : i32)
-// CHECK-NEXT: ^[[bb3]](%{{.*}}: i32)
-// CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<i32>
-// CHECK-NEXT: return %{{.*}}
+// CHECK-DAG: %[[cst:.*]] = arith.constant dense<0>
+// CHECK-DAG: arith.constant true
+// CHECK: cf.br
+// CHECK-NEXT: ^[[bb1:.*]]:
+// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb3
+// CHECK-NEXT: ^[[bb2]]
+// CHECK-NEXT: cf.br ^[[bb3:.*]]
+// CHECK-NEXT: ^[[bb3]]
+// CHECK-NEXT: return %[[cst]]
// CHECK-NEXT: }
// -----
@@ -106,20 +103,17 @@ func.func @main() -> (tensor<i32>) attributes {} {
}
// CHECK-LABEL: func @main()
-// CHECK-DAG: arith.constant 0
-// CHECK-DAG: arith.constant 10
-// CHECK: cf.br ^[[bb1:.*]](%{{.*}}: i32)
-// CHECK-NEXT: ^[[bb1]](%{{.*}}: i32):
-// CHECK-NEXT: arith.cmpi slt, %{{.*}}, %{{.*}}
-// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32)
-// CHECK-NEXT: ^[[bb2]](%{{.*}}: i32)
-// CHECK-NEXT: arith.addi %{{.*}}, %{{.*}}
-// CHECK-NEXT: cf.br ^[[bb3:.*]](%{{.*}} : i32)
-// CHECK-NEXT: ^[[bb3]](%{{.*}}: i32)
-// CHECK-NEXT: cf.br ^[[bb4:.*]](%{{.*}} : i32)
-// CHECK-NEXT: ^[[bb4]](%{{.*}}: i32)
-// CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<i32>
-// CHECK-NEXT: return %{{.*}}
+// CHECK-DAG: %[[cst:.*]] = arith.constant dense<0>
+// CHECK-DAG: arith.constant true
+// CHECK: cf.br ^[[bb1:.*]]
+// CHECK-NEXT: ^[[bb1:.*]]:
+// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb3
+// CHECK-NEXT: ^[[bb2]]:
+// CHECK-NEXT: cf.br ^[[bb3:.*]]
+// CHECK-NEXT: ^[[bb3]]:
+// CHECK-NEXT: cf.br ^[[bb4:.*]]
+// CHECK-NEXT: ^[[bb4]]:
+// CHECK-NEXT: return %[[cst]]
// CHECK-NEXT: }
// -----
@@ -171,16 +165,13 @@ func.func @main() -> (tensor<i32>) attributes {} {
}
// CHECK-LABEL: func @main()
-// CHECK-DAG: arith.constant 0
-// CHECK-DAG: arith.constant 10
-// CHECK: cf.br ^[[bb1:.*]](%{{.*}}: i32)
-// CHECK-NEXT: ^[[bb1]](%{{.*}}: i32):
-// CHECK-NEXT: arith.cmpi slt, %{{.*}}, %{{.*}}
-// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb2(%{{.*}} : i32)
-// CHECK-NEXT: ^[[bb2]](%{{.*}}: i32)
-// CHECK-NEXT: arith.addi %{{.*}}, %{{.*}}
-// CHECK-NEXT: cf.br ^[[bb3:.*]](%{{.*}} : i32)
-// CHECK-NEXT: ^[[bb3]](%{{.*}}: i32)
-// CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<i32>
-// CHECK-NEXT: return %{{.*}}
+// CHECK-DAG: %[[cst:.*]] = arith.constant dense<10>
+// CHECK-DAG: arith.constant true
+// CHECK: cf.br ^[[bb1:.*]]
+// CHECK-NEXT: ^[[bb1]]:
+// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb2
+// CHECK-NEXT: ^[[bb2]]
+// CHECK-NEXT: cf.br ^[[bb3:.*]]
+// CHECK-NEXT: ^[[bb3]]
+// CHECK-NEXT: return %[[cst]]
// CHECK-NEXT: }
diff --git a/mlir/test/Dialect/Linalg/detensorize_while.mlir b/mlir/test/Dialect/Linalg/detensorize_while.mlir
index aa30900..580a97d 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while.mlir
@@ -46,11 +46,11 @@ func.func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attribu
// DET-ALL: cf.br ^[[bb1:.*]](%{{.*}} : i32)
// DET-ALL: ^[[bb1]](%{{.*}}: i32)
// DET-ALL: arith.cmpi slt, {{.*}}
-// DET-ALL: cf.cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
-// DET-ALL: ^[[bb2]](%{{.*}}: i32)
+// DET-ALL: cf.cond_br {{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
+// DET-ALL: ^[[bb2]]
// DET-ALL: arith.addi {{.*}}
// DET-ALL: cf.br ^[[bb1]](%{{.*}} : i32)
-// DET-ALL: ^[[bb3]](%{{.*}}: i32)
+// DET-ALL: ^[[bb3]]:
// DET-ALL: tensor.from_elements {{.*}}
// DET-ALL: return %{{.*}} : tensor<i32>
@@ -62,10 +62,10 @@ func.func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attribu
// DET-CF: cf.br ^[[bb1:.*]](%{{.*}} : i32)
// DET-CF: ^[[bb1]](%{{.*}}: i32)
// DET-CF: arith.cmpi slt, {{.*}}
-// DET-CF: cf.cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
-// DET-CF: ^[[bb2]](%{{.*}}: i32)
+// DET-CF: cf.cond_br {{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
+// DET-CF: ^[[bb2]]:
// DET-CF: arith.addi {{.*}}
// DET-CF: cf.br ^[[bb1]](%{{.*}} : i32)
-// DET-CF: ^[[bb3]](%{{.*}}: i32)
+// DET-CF: ^[[bb3]]:
// DET-CF: tensor.from_elements %{{.*}} : tensor<i32>
// DET-CF: return %{{.*}} : tensor<i32>
diff --git a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
index 955c7be..414d9b9 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
@@ -74,8 +74,8 @@ func.func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attr
// DET-ALL: } -> tensor<i32>
// DET-ALL: tensor.extract %{{.*}}[] : tensor<i32>
// DET-ALL: cmpi slt, %{{.*}}, %{{.*}} : i32
-// DET-ALL: cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
-// DET-ALL: ^[[bb2]](%{{.*}}: i32)
+// DET-ALL: cf.cond_br %{{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
+// DET-ALL: ^[[bb2]]:
// DET-ALL: tensor.from_elements %{{.*}} : tensor<i32>
// DET-ALL: tensor.empty() : tensor<10xi32>
// DET-ALL: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<i32>) outs(%{{.*}} : tensor<10xi32>) {
@@ -83,7 +83,7 @@ func.func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attr
// DET-ALL: linalg.yield %{{.*}} : i32
// DET-ALL: } -> tensor<10xi32>
// DET-ALL: cf.br ^[[bb1]](%{{.*}} : tensor<10xi32>)
-// DET-ALL: ^[[bb3]](%{{.*}}: i32)
+// DET-ALL: ^[[bb3]]
// DET-ALL: tensor.from_elements %{{.*}} : tensor<i32>
// DET-ALL: return %{{.*}} : tensor<i32>
// DET-ALL: }
@@ -95,10 +95,10 @@ func.func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attr
// DET-CF: %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<10xi32>) outs(%{{.*}} : tensor<i32>) {
// DET-CF: tensor.extract %{{.*}}[] : tensor<i32>
// DET-CF: cmpi slt, %{{.*}}, %{{.*}} : i32
-// DET-CF: cf.cond_br %{{.*}}, ^bb2(%{{.*}} : tensor<i32>), ^bb3(%{{.*}} : tensor<i32>)
-// DET-CF: ^bb2(%{{.*}}: tensor<i32>)
+// DET-CF: cf.cond_br %{{.*}}, ^bb2, ^bb3
+// DET-CF: ^bb2:
// DET-CF: %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<i32>) outs(%{{.*}} : tensor<10xi32>) {
// DET-CF: cf.br ^bb1(%{{.*}} : tensor<10xi32>)
-// DET-CF: ^bb3(%{{.*}}: tensor<i32>)
+// DET-CF: ^bb3:
// DET-CF: return %{{.*}} : tensor<i32>
// DET-CF: }
diff --git a/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
index 6d8d5fe..913e782 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
@@ -49,8 +49,8 @@ func.func @main() -> () attributes {} {
// CHECK-NEXT: cf.br ^[[bb1:.*]](%{{.*}} : i32)
// CHECK-NEXT: ^[[bb1]](%{{.*}}: i32)
// CHECK-NEXT: %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}}
-// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]]
-// CHECK-NEXT: ^[[bb2]](%{{.*}}: i32)
+// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
+// CHECK-NEXT: ^[[bb2]]
// CHECK-NEXT: %{{.*}} = arith.addi %{{.*}}, %{{.*}}
// CHECK-NEXT: cf.br ^[[bb1]](%{{.*}} : i32)
// CHECK-NEXT: ^[[bb3]]:
diff --git a/mlir/test/Transforms/canonicalize-block-merge.mlir b/mlir/test/Transforms/canonicalize-block-merge.mlir
index 3b8b1fc..92cfde8 100644
--- a/mlir/test/Transforms/canonicalize-block-merge.mlir
+++ b/mlir/test/Transforms/canonicalize-block-merge.mlir
@@ -87,7 +87,7 @@ func.func @mismatch_operands_matching_arguments(%cond : i1, %arg0 : i32, %arg1 :
// CHECK-LABEL: func @mismatch_argument_uses(
func.func @mismatch_argument_uses(%cond : i1, %arg0 : i32, %arg1 : i32) -> (i32, i32) {
- // CHECK: cf.cond_br %{{.*}}, ^bb1(%{{.*}}), ^bb2
+ // CHECK: return {{.*}}, {{.*}}
cf.cond_br %cond, ^bb1(%arg1 : i32), ^bb2(%arg0 : i32)
@@ -101,7 +101,7 @@ func.func @mismatch_argument_uses(%cond : i1, %arg0 : i32, %arg1 : i32) -> (i32,
// CHECK-LABEL: func @mismatch_argument_types(
func.func @mismatch_argument_types(%cond : i1, %arg0 : i32, %arg1 : i16) {
- // CHECK: cf.cond_br %{{.*}}, ^bb1(%{{.*}}), ^bb2
+ // CHECK: cf.cond_br %{{.*}}, ^bb1, ^bb2
cf.cond_br %cond, ^bb1(%arg0 : i32), ^bb2(%arg1 : i16)
@@ -115,7 +115,7 @@ func.func @mismatch_argument_types(%cond : i1, %arg0 : i32, %arg1 : i16) {
// CHECK-LABEL: func @mismatch_argument_count(
func.func @mismatch_argument_count(%cond : i1, %arg0 : i32) {
- // CHECK: cf.cond_br %{{.*}}, ^bb1(%{{.*}}), ^bb2
+ // CHECK: cf.cond_br %{{.*}}, ^bb1, ^bb2
cf.cond_br %cond, ^bb1(%arg0 : i32), ^bb2
diff --git a/mlir/test/Transforms/canonicalize-dce.mlir b/mlir/test/Transforms/canonicalize-dce.mlir
index ac034d5..8463194 100644
--- a/mlir/test/Transforms/canonicalize-dce.mlir
+++ b/mlir/test/Transforms/canonicalize-dce.mlir
@@ -137,10 +137,10 @@ func.func @f(%arg0: f32) {
// Test case: Test the mechanics of deleting multiple block arguments.
// CHECK: func @f(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>, %arg2: tensor<3xf32>, %arg3: tensor<4xf32>, %arg4: tensor<5xf32>)
-// CHECK-NEXT: "test.br"(%arg1, %arg3)[^bb1] : (tensor<2xf32>, tensor<4xf32>)
-// CHECK-NEXT: ^bb1([[VAL0:%.+]]: tensor<2xf32>, [[VAL1:%.+]]: tensor<4xf32>):
-// CHECK-NEXT: "foo.print"([[VAL0]])
-// CHECK-NEXT: "foo.print"([[VAL1]])
+// CHECK-NEXT: "test.br"()[^bb1]
+// CHECK-NEXT: ^bb1:
+// CHECK-NEXT: "foo.print"(%arg1)
+// CHECK-NEXT: "foo.print"(%arg3)
// CHECK-NEXT: return
diff --git a/mlir/test/Transforms/make-isolated-from-above.mlir b/mlir/test/Transforms/make-isolated-from-above.mlir
index 58f6cfb..a9d4325 100644
--- a/mlir/test/Transforms/make-isolated-from-above.mlir
+++ b/mlir/test/Transforms/make-isolated-from-above.mlir
@@ -78,9 +78,9 @@ func.func @make_isolated_from_above_multiple_blocks(%arg0 : index, %arg1 : index
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[EMPTY]], %[[C1]]
// CHECK: test.isolated_one_region_op %[[ARG2]], %[[C0]], %[[C1]], %[[D0]], %[[D1]]
// CHECK-NEXT: ^bb0(%[[B0:[a-zA-Z0-9]+]]: index, %[[B1:[a-zA-Z0-9]+]]: index, %[[B2:[a-zA-Z0-9]+]]: index, %[[B3:[a-zA-Z0-9]+]]: index, %[[B4:[a-zA-Z0-9]+]]: index)
-// CHECK-NEXT: cf.br ^bb1(%[[B0]] : index)
-// CHECK: ^bb1(%[[B5:.+]]: index)
-// CHECK: "foo.yield"(%[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]])
+// CHECK-NEXT: cf.br ^bb1
+// CHECK: ^bb1:
+// CHECK: "foo.yield"(%[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B0]])
// CLONE1-LABEL: func @make_isolated_from_above_multiple_blocks(
// CLONE1-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index
@@ -95,9 +95,9 @@ func.func @make_isolated_from_above_multiple_blocks(%arg0 : index, %arg1 : index
// CLONE1-NEXT: ^bb0(%[[B0:[a-zA-Z0-9]+]]: index, %[[B1:[a-zA-Z0-9]+]]: index, %[[B2:[a-zA-Z0-9]+]]: index)
// CLONE1-DAG: %[[C0_0:.+]] = arith.constant 0 : index
// CLONE1-DAG: %[[C1_0:.+]] = arith.constant 1 : index
-// CLONE1-NEXT: cf.br ^bb1(%[[B0]] : index)
-// CLONE1: ^bb1(%[[B3:.+]]: index)
-// CLONE1: "foo.yield"(%[[C0_0]], %[[C1_0]], %[[B1]], %[[B2]], %[[B3]])
+// CLONE1-NEXT: cf.br ^bb1
+// CLONE1: ^bb1:
+// CLONE1: "foo.yield"(%[[C0_0]], %[[C1_0]], %[[B1]], %[[B2]], %[[B0]])
// CLONE2-LABEL: func @make_isolated_from_above_multiple_blocks(
// CLONE2-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index
@@ -110,6 +110,6 @@ func.func @make_isolated_from_above_multiple_blocks(%arg0 : index, %arg1 : index
// CLONE2-DAG: %[[EMPTY:.+]] = tensor.empty(%[[B1]], %[[B2]])
// CLONE2-DAG: %[[D0:.+]] = tensor.dim %[[EMPTY]], %[[C0]]
// CLONE2-DAG: %[[D1:.+]] = tensor.dim %[[EMPTY]], %[[C1]]
-// CLONE2-NEXT: cf.br ^bb1(%[[B0]] : index)
-// CLONE2: ^bb1(%[[B3:.+]]: index)
-// CLONE2: "foo.yield"(%[[C0]], %[[C1]], %[[D0]], %[[D1]], %[[B3]])
+// CLONE2-NEXT: cf.br ^bb1
+// CLONE2: ^bb1:
+// CLONE2: "foo.yield"(%[[C0]], %[[C1]], %[[D0]], %[[D1]], %[[B0]])
diff --git a/mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir b/mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir
new file mode 100644
index 0000000..570ff69
--- /dev/null
+++ b/mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir
@@ -0,0 +1,76 @@
+ // RUN: mlir-opt -pass-pipeline='builtin.module(llvm.func(canonicalize{region-simplify=aggressive}))' %s | FileCheck %s
+
+llvm.func @foo(%arg0: i64)
+
+llvm.func @rand() -> i1
+
+// CHECK-LABEL: func @large_merge_block(
+llvm.func @large_merge_block(%arg0: i64) {
+ // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
+ // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
+ // CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
+ // CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
+ // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i64) : i64
+
+ // CHECK: llvm.cond_br %5, ^bb1(%[[C1]], %[[C3]], %[[C4]], %[[C2]] : i64, i64, i64, i64), ^bb1(%[[C4]], %[[C2]], %[[C1]], %[[C3]] : i64, i64, i64, i64)
+ // CHECK: ^bb{{.*}}(%[[arg0:.*]]: i64, %[[arg1:.*]]: i64, %[[arg2:.*]]: i64, %[[arg3:.*]]: i64):
+ // CHECK: llvm.cond_br %{{.*}}, ^bb2(%[[arg0]] : i64), ^bb2(%[[arg3]] : i64)
+ // CHECK: ^bb{{.*}}(%11: i64):
+ // CHECK: llvm.br ^bb{{.*}}
+ // CHECK: ^bb{{.*}}:
+ // CHECK: llvm.call
+ // CHECK: llvm.cond_br {{.*}}, ^bb{{.*}}(%[[arg1]] : i64), ^bb{{.*}}(%[[arg2]] : i64)
+ // CHECK: ^bb{{.*}}:
+ // CHECK: llvm.call
+ // CHECK llvm.br ^bb{{.*}}
+
+ %0 = llvm.mlir.constant(0 : i64) : i64
+ %1 = llvm.mlir.constant(1 : i64) : i64
+ %2 = llvm.mlir.constant(2 : i64) : i64
+ %3 = llvm.mlir.constant(3 : i64) : i64
+ %4 = llvm.mlir.constant(4 : i64) : i64
+ %10 = llvm.icmp "eq" %arg0, %0 : i64
+ llvm.cond_br %10, ^bb1, ^bb14
+^bb1: // pred: ^bb0
+ %11 = llvm.call @rand() : () -> i1
+ llvm.cond_br %11, ^bb2, ^bb3
+^bb2: // pred: ^bb1
+ llvm.call @foo(%1) : (i64) -> ()
+ llvm.br ^bb4
+^bb3: // pred: ^bb1
+ llvm.call @foo(%2) : (i64) -> ()
+ llvm.br ^bb4
+^bb4: // 2 preds: ^bb2, ^bb3
+ %14 = llvm.call @rand() : () -> i1
+ llvm.cond_br %14, ^bb5, ^bb6
+^bb5: // pred: ^bb4
+ llvm.call @foo(%3) : (i64) -> ()
+ llvm.br ^bb13
+^bb6: // pred: ^bb4
+ llvm.call @foo(%4) : (i64) -> ()
+ llvm.br ^bb13
+^bb13: // 2 preds: ^bb11, ^bb12
+ llvm.br ^bb27
+^bb14: // pred: ^bb0
+ %23 = llvm.call @rand() : () -> i1
+ llvm.cond_br %23, ^bb15, ^bb16
+^bb15: // pred: ^bb14
+ llvm.call @foo(%4) : (i64) -> ()
+ llvm.br ^bb17
+^bb16: // pred: ^bb14
+ llvm.call @foo(%3) : (i64) -> ()
+ llvm.br ^bb17
+^bb17: // 2 preds: ^bb15, ^bb16
+ %26 = llvm.call @rand() : () -> i1
+ llvm.cond_br %26, ^bb18, ^bb19
+^bb18: // pred: ^bb17
+ llvm.call @foo(%2) : (i64) -> ()
+ llvm.br ^bb26
+^bb19: // pred: ^bb17
+ llvm.call @foo(%1) : (i64) -> ()
+ llvm.br ^bb26
+^bb26: // 2 preds: ^bb24, ^bb25
+ llvm.br ^bb27
+^bb27: // 2 preds: ^bb13, ^bb26
+ llvm.return
+}