diff options
author | Giuseppe Rossini <giuseppe.rossini@amd.com> | 2024-07-02 17:12:33 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-07-02 17:12:33 +0100 |
commit | 6c3897d90eda4c39789ac9f4efa51db46734a249 (patch) | |
tree | 60ae5326a5b02b9d45f12a5e630a7ca0755fdce4 | |
parent | 123beb7926651217024e5db58b93ab9e8f3c77c7 (diff) | |
download | llvm-6c3897d90eda4c39789ac9f4efa51db46734a249.zip llvm-6c3897d90eda4c39789ac9f4efa51db46734a249.tar.gz llvm-6c3897d90eda4c39789ac9f4efa51db46734a249.tar.bz2 |
Fix block merging (#96871)
With this PR I am trying to address:
https://github.com/llvm/llvm-project/issues/63230.
What changed:
- While merging identical blocks, don't add a block argument if it is
"identical" to another block argument. I.e., if the two block arguments
refer to the same `Value`. The operands of the operations in the block will
point to the argument we already inserted.
- After merging the blocks, get rid of "unnecessary" arguments. I.e., if
all the predecessors pass the same block argument, there is no need to
pass it as an argument.
- This last simplification clashed with
`BufferDeallocationSimplification`. The reason, I think, is that the two
simplifications interfere with each other. I.e., `BufferDeallocationSimplification`
contains an analysis based on the block structure. If we simplify the
block structure (by merging blocks and/or dropping block arguments), the
analysis is invalidated. The solution I found is to do a more prudent
simplification (the `Normal` region-simplification level) when running that pass.
**Note**: many tests are still not passing. But I wanted to submit the
code before changing all the tests (and probably adding a couple), so
that we can agree in principle on the algorithm/design.
12 files changed, 289 insertions, 93 deletions
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp index 954485c..5227b22 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp @@ -463,10 +463,15 @@ struct BufferDeallocationSimplificationPass SplitDeallocWhenNotAliasingAnyOther, RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(), analysis); + // We don't want that the block structure changes invalidating the + // `BufferOriginAnalysis` so we apply the rewrites witha `Normal` level of + // region simplification + GreedyRewriteConfig config; + config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal; populateDeallocOpCanonicalizationPatterns(patterns, &getContext()); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config))) signalPassFailure(); } }; diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp index 4c0f15b..412e245 100644 --- a/mlir/lib/Transforms/Utils/RegionUtils.cpp +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -9,6 +9,7 @@ #include "mlir/Transforms/RegionUtils.h" #include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/IR/Block.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" @@ -16,11 +17,15 @@ #include "mlir/IR/Value.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LogicalResult.h" #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallSet.h" #include <deque> +#include <iterator> using namespace mlir; @@ -699,9 +704,8 @@ 
LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) { blockIterators.push_back(mergeBlock->begin()); // Update each of the predecessor terminators with the new arguments. - SmallVector<SmallVector<Value, 8>, 2> newArguments( - 1 + blocksToMerge.size(), - SmallVector<Value, 8>(operandsToMerge.size())); + SmallVector<SmallVector<Value, 8>, 2> newArguments(1 + blocksToMerge.size(), + SmallVector<Value, 8>()); unsigned curOpIndex = 0; for (const auto &it : llvm::enumerate(operandsToMerge)) { unsigned nextOpOffset = it.value().first - curOpIndex; @@ -712,13 +716,22 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) { Block::iterator &blockIter = blockIterators[i]; std::advance(blockIter, nextOpOffset); auto &operand = blockIter->getOpOperand(it.value().second); - newArguments[i][it.index()] = operand.get(); - - // Update the operand and insert an argument if this is the leader. - if (i == 0) { - Value operandVal = operand.get(); - operand.set(leaderBlock->addArgument(operandVal.getType(), - operandVal.getLoc())); + Value operandVal = operand.get(); + Value *it = std::find(newArguments[i].begin(), newArguments[i].end(), + operandVal); + if (it == newArguments[i].end()) { + newArguments[i].push_back(operandVal); + // Update the operand and insert an argument if this is the leader. + if (i == 0) { + operand.set(leaderBlock->addArgument(operandVal.getType(), + operandVal.getLoc())); + } + } else if (i == 0) { + // If this is the leader, update the operand but do not insert a new + // argument. 
Instead, the opearand should point to one of the + // arguments we already passed (and that contained `operandVal`) + operand.set(leaderBlock->getArgument( + std::distance(newArguments[i].begin(), it))); } } } @@ -818,6 +831,109 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter, return success(anyChanged); } +static LogicalResult dropRedundantArguments(RewriterBase &rewriter, + Block &block) { + SmallVector<size_t> argsToErase; + + // Go through the arguments of the block + for (size_t argIdx = 0; argIdx < block.getNumArguments(); argIdx++) { + bool sameArg = true; + Value commonValue; + + // Go through the block predecessor and flag if they pass to the block + // different values for the same argument + for (auto predIt = block.pred_begin(), predE = block.pred_end(); + predIt != predE; ++predIt) { + auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator()); + if (!branch) { + sameArg = false; + break; + } + unsigned succIndex = predIt.getSuccessorIndex(); + SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex); + auto operands = succOperands.getForwardedOperands(); + if (!commonValue) { + commonValue = operands[argIdx]; + } else { + if (operands[argIdx] != commonValue) { + sameArg = false; + break; + } + } + } + + // If they are passing the same value, drop the argument + if (commonValue && sameArg) { + argsToErase.push_back(argIdx); + + // Remove the argument from the block + Value argVal = block.getArgument(argIdx); + rewriter.replaceAllUsesWith(argVal, commonValue); + } + } + + // Remove the arguments + for (auto argIdx : llvm::reverse(argsToErase)) { + block.eraseArgument(argIdx); + + // Remove the argument from the branch ops + for (auto predIt = block.pred_begin(), predE = block.pred_end(); + predIt != predE; ++predIt) { + auto branch = cast<BranchOpInterface>((*predIt)->getTerminator()); + unsigned succIndex = predIt.getSuccessorIndex(); + SuccessorOperands succOperands = 
branch.getSuccessorOperands(succIndex); + succOperands.erase(argIdx); + } + } + return success(!argsToErase.empty()); +} + +/// This optimization drops redundant argument to blocks. I.e., if a given +/// argument to a block receives the same value from each of the block +/// predecessors, we can remove the argument from the block and use directly the +/// original value. This is a simple example: +/// +/// %cond = llvm.call @rand() : () -> i1 +/// %val0 = llvm.mlir.constant(1 : i64) : i64 +/// %val1 = llvm.mlir.constant(2 : i64) : i64 +/// %val2 = llvm.mlir.constant(3 : i64) : i64 +/// llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2 +/// : i64) +/// +/// ^bb1(%arg0 : i64, %arg1 : i64): +/// llvm.call @foo(%arg0, %arg1) +/// +/// The previous IR can be rewritten as: +/// %cond = llvm.call @rand() : () -> i1 +/// %val0 = llvm.mlir.constant(1 : i64) : i64 +/// %val1 = llvm.mlir.constant(2 : i64) : i64 +/// %val2 = llvm.mlir.constant(3 : i64) : i64 +/// llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64) +/// +/// ^bb1(%arg0 : i64): +/// llvm.call @foo(%val0, %arg0) +/// +static LogicalResult dropRedundantArguments(RewriterBase &rewriter, + MutableArrayRef<Region> regions) { + llvm::SmallSetVector<Region *, 1> worklist; + for (auto ®ion : regions) + worklist.insert(®ion); + bool anyChanged = false; + while (!worklist.empty()) { + Region *region = worklist.pop_back_val(); + + // Add any nested regions to the worklist. 
+ for (Block &block : *region) { + anyChanged = succeeded(dropRedundantArguments(rewriter, block)); + + for (auto &op : block) + for (auto &nestedRegion : op.getRegions()) + worklist.insert(&nestedRegion); + } + } + return success(anyChanged); +} + //===----------------------------------------------------------------------===// // Region Simplification //===----------------------------------------------------------------------===// @@ -832,8 +948,12 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter, bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions)); bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions)); bool mergedIdenticalBlocks = false; - if (mergeBlocks) + bool droppedRedundantArguments = false; + if (mergeBlocks) { mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions)); + droppedRedundantArguments = + succeeded(dropRedundantArguments(rewriter, regions)); + } return success(eliminatedBlocks || eliminatedOpsOrArgs || - mergedIdenticalBlocks); + mergedIdenticalBlocks || droppedRedundantArguments); } diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir index 5e8104f..8e14990 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir @@ -178,7 +178,7 @@ func.func @condBranchDynamicTypeNested( // CHECK-NEXT: ^bb1 // CHECK-NOT: bufferization.dealloc // CHECK-NOT: bufferization.clone -// CHECK: cf.br ^bb5([[ARG1]], %false{{[0-9_]*}} : +// CHECK: cf.br ^bb6([[ARG1]], %false{{[0-9_]*}} : // CHECK: ^bb2([[IDX:%.*]]:{{.*}}) // CHECK: [[ALLOC1:%.*]] = memref.alloc([[IDX]]) // CHECK-NEXT: test.buffer_based @@ -186,20 +186,24 @@ func.func 
@condBranchDynamicTypeNested( // CHECK-NEXT: [[OWN:%.+]] = arith.select [[ARG0]], [[ARG0]], [[NOT_ARG0]] // CHECK-NOT: bufferization.dealloc // CHECK-NOT: bufferization.clone -// CHECK: cf.cond_br{{.*}}, ^bb3, ^bb3 +// CHECK: cf.cond_br{{.*}}, ^bb3, ^bb4 // CHECK-NEXT: ^bb3: // CHECK-NOT: bufferization.dealloc // CHECK-NOT: bufferization.clone -// CHECK: cf.br ^bb4([[ALLOC1]], [[OWN]] -// CHECK-NEXT: ^bb4([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}}) +// CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]] +// CHECK-NEXT: ^bb4: // CHECK-NOT: bufferization.dealloc // CHECK-NOT: bufferization.clone -// CHECK: cf.br ^bb5([[ALLOC2]], [[COND1]] -// CHECK-NEXT: ^bb5([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}}) +// CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]] +// CHECK-NEXT: ^bb5([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}}) +// CHECK-NOT: bufferization.dealloc +// CHECK-NOT: bufferization.clone +// CHECK: cf.br ^bb6([[ALLOC2]], [[COND1]] +// CHECK-NEXT: ^bb6([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}}) // CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC4]] // CHECK-NEXT: [[OWN:%.+]]:2 = bufferization.dealloc ([[BASE]] :{{.*}}) if ([[COND2]]) retain ([[ALLOC4]], [[ARG2]] : -// CHECK: cf.br ^bb6([[ALLOC4]], [[OWN]]#0 -// CHECK-NEXT: ^bb6([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}}) +// CHECK: cf.br ^bb7([[ALLOC4]], [[OWN]]#0 +// CHECK-NEXT: ^bb7([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}}) // CHECK: test.copy // CHECK: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC5]] // CHECK-NEXT: bufferization.dealloc ([[BASE]] : {{.*}}) if ([[COND3]]) diff --git a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir index d1a8922..50a2d6b 100644 --- a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir +++ b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir @@ -15,7 +15,7 @@ func.func @main(%arg0: tensor<f32>) -> tensor<f32> { // CHECK-LABEL: @main // CHECK-SAME: (%[[ARG0:.+]]: 
tensor<f32>) -> tensor<f32> // CHECK: %[[EXTRACTED:.+]] = tensor.extract %[[ARG0]][] : tensor<f32> -// CHECK: cf.br ^{{.*}}(%[[EXTRACTED]] : f32) -// CHECK: ^{{.*}}(%[[ARG1:.+]]: f32): -// CHECK: %[[ELEMENTS:.+]] = tensor.from_elements %[[ARG1]] : tensor<f32> +// CHECK: cf.br ^{{.*}} +// CHECK: ^{{.*}}: +// CHECK: %[[ELEMENTS:.+]] = tensor.from_elements %[[EXTRACTED]] : tensor<f32> // CHECK: return %[[ELEMENTS]] : tensor<f32> diff --git a/mlir/test/Dialect/Linalg/detensorize_if.mlir b/mlir/test/Dialect/Linalg/detensorize_if.mlir index 8d17763..c728ad2 100644 --- a/mlir/test/Dialect/Linalg/detensorize_if.mlir +++ b/mlir/test/Dialect/Linalg/detensorize_if.mlir @@ -42,18 +42,15 @@ func.func @main() -> (tensor<i32>) attributes {} { } // CHECK-LABEL: func @main() -// CHECK-DAG: arith.constant 0 -// CHECK-DAG: arith.constant 10 -// CHECK: cf.br ^[[bb1:.*]](%{{.*}}: i32) -// CHECK-NEXT: ^[[bb1]](%{{.*}}: i32): -// CHECK-NEXT: arith.cmpi slt, %{{.*}}, %{{.*}} -// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32) -// CHECK-NEXT: ^[[bb2]](%{{.*}}: i32) -// CHECK-NEXT: arith.addi %{{.*}}, %{{.*}} -// CHECK-NEXT: cf.br ^[[bb3:.*]](%{{.*}} : i32) -// CHECK-NEXT: ^[[bb3]](%{{.*}}: i32) -// CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<i32> -// CHECK-NEXT: return %{{.*}} +// CHECK-DAG: %[[cst:.*]] = arith.constant dense<0> +// CHECK-DAG: arith.constant true +// CHECK: cf.br +// CHECK-NEXT: ^[[bb1:.*]]: +// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb3 +// CHECK-NEXT: ^[[bb2]] +// CHECK-NEXT: cf.br ^[[bb3:.*]] +// CHECK-NEXT: ^[[bb3]] +// CHECK-NEXT: return %[[cst]] // CHECK-NEXT: } // ----- @@ -106,20 +103,17 @@ func.func @main() -> (tensor<i32>) attributes {} { } // CHECK-LABEL: func @main() -// CHECK-DAG: arith.constant 0 -// CHECK-DAG: arith.constant 10 -// CHECK: cf.br ^[[bb1:.*]](%{{.*}}: i32) -// CHECK-NEXT: ^[[bb1]](%{{.*}}: i32): -// CHECK-NEXT: arith.cmpi slt, %{{.*}}, %{{.*}} -// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : 
i32), ^bb3(%{{.*}} : i32) -// CHECK-NEXT: ^[[bb2]](%{{.*}}: i32) -// CHECK-NEXT: arith.addi %{{.*}}, %{{.*}} -// CHECK-NEXT: cf.br ^[[bb3:.*]](%{{.*}} : i32) -// CHECK-NEXT: ^[[bb3]](%{{.*}}: i32) -// CHECK-NEXT: cf.br ^[[bb4:.*]](%{{.*}} : i32) -// CHECK-NEXT: ^[[bb4]](%{{.*}}: i32) -// CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<i32> -// CHECK-NEXT: return %{{.*}} +// CHECK-DAG: %[[cst:.*]] = arith.constant dense<0> +// CHECK-DAG: arith.constant true +// CHECK: cf.br ^[[bb1:.*]] +// CHECK-NEXT: ^[[bb1:.*]]: +// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb3 +// CHECK-NEXT: ^[[bb2]]: +// CHECK-NEXT: cf.br ^[[bb3:.*]] +// CHECK-NEXT: ^[[bb3]]: +// CHECK-NEXT: cf.br ^[[bb4:.*]] +// CHECK-NEXT: ^[[bb4]]: +// CHECK-NEXT: return %[[cst]] // CHECK-NEXT: } // ----- @@ -171,16 +165,13 @@ func.func @main() -> (tensor<i32>) attributes {} { } // CHECK-LABEL: func @main() -// CHECK-DAG: arith.constant 0 -// CHECK-DAG: arith.constant 10 -// CHECK: cf.br ^[[bb1:.*]](%{{.*}}: i32) -// CHECK-NEXT: ^[[bb1]](%{{.*}}: i32): -// CHECK-NEXT: arith.cmpi slt, %{{.*}}, %{{.*}} -// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb2(%{{.*}} : i32) -// CHECK-NEXT: ^[[bb2]](%{{.*}}: i32) -// CHECK-NEXT: arith.addi %{{.*}}, %{{.*}} -// CHECK-NEXT: cf.br ^[[bb3:.*]](%{{.*}} : i32) -// CHECK-NEXT: ^[[bb3]](%{{.*}}: i32) -// CHECK-NEXT: tensor.from_elements %{{.*}} : tensor<i32> -// CHECK-NEXT: return %{{.*}} +// CHECK-DAG: %[[cst:.*]] = arith.constant dense<10> +// CHECK-DAG: arith.constant true +// CHECK: cf.br ^[[bb1:.*]] +// CHECK-NEXT: ^[[bb1]]: +// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb2 +// CHECK-NEXT: ^[[bb2]] +// CHECK-NEXT: cf.br ^[[bb3:.*]] +// CHECK-NEXT: ^[[bb3]] +// CHECK-NEXT: return %[[cst]] // CHECK-NEXT: } diff --git a/mlir/test/Dialect/Linalg/detensorize_while.mlir b/mlir/test/Dialect/Linalg/detensorize_while.mlir index aa30900..580a97d 100644 --- a/mlir/test/Dialect/Linalg/detensorize_while.mlir +++ 
b/mlir/test/Dialect/Linalg/detensorize_while.mlir @@ -46,11 +46,11 @@ func.func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attribu // DET-ALL: cf.br ^[[bb1:.*]](%{{.*}} : i32) // DET-ALL: ^[[bb1]](%{{.*}}: i32) // DET-ALL: arith.cmpi slt, {{.*}} -// DET-ALL: cf.cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32) -// DET-ALL: ^[[bb2]](%{{.*}}: i32) +// DET-ALL: cf.cond_br {{.*}}, ^[[bb2:.*]], ^[[bb3:.*]] +// DET-ALL: ^[[bb2]] // DET-ALL: arith.addi {{.*}} // DET-ALL: cf.br ^[[bb1]](%{{.*}} : i32) -// DET-ALL: ^[[bb3]](%{{.*}}: i32) +// DET-ALL: ^[[bb3]]: // DET-ALL: tensor.from_elements {{.*}} // DET-ALL: return %{{.*}} : tensor<i32> @@ -62,10 +62,10 @@ func.func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attribu // DET-CF: cf.br ^[[bb1:.*]](%{{.*}} : i32) // DET-CF: ^[[bb1]](%{{.*}}: i32) // DET-CF: arith.cmpi slt, {{.*}} -// DET-CF: cf.cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32) -// DET-CF: ^[[bb2]](%{{.*}}: i32) +// DET-CF: cf.cond_br {{.*}}, ^[[bb2:.*]], ^[[bb3:.*]] +// DET-CF: ^[[bb2]]: // DET-CF: arith.addi {{.*}} // DET-CF: cf.br ^[[bb1]](%{{.*}} : i32) -// DET-CF: ^[[bb3]](%{{.*}}: i32) +// DET-CF: ^[[bb3]]: // DET-CF: tensor.from_elements %{{.*}} : tensor<i32> // DET-CF: return %{{.*}} : tensor<i32> diff --git a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir index 955c7be..414d9b9 100644 --- a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir +++ b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir @@ -74,8 +74,8 @@ func.func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attr // DET-ALL: } -> tensor<i32> // DET-ALL: tensor.extract %{{.*}}[] : tensor<i32> // DET-ALL: cmpi slt, %{{.*}}, %{{.*}} : i32 -// DET-ALL: cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32) -// DET-ALL: ^[[bb2]](%{{.*}}: i32) +// DET-ALL: cf.cond_br %{{.*}}, ^[[bb2:.*]], 
^[[bb3:.*]] +// DET-ALL: ^[[bb2]]: // DET-ALL: tensor.from_elements %{{.*}} : tensor<i32> // DET-ALL: tensor.empty() : tensor<10xi32> // DET-ALL: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<i32>) outs(%{{.*}} : tensor<10xi32>) { @@ -83,7 +83,7 @@ func.func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attr // DET-ALL: linalg.yield %{{.*}} : i32 // DET-ALL: } -> tensor<10xi32> // DET-ALL: cf.br ^[[bb1]](%{{.*}} : tensor<10xi32>) -// DET-ALL: ^[[bb3]](%{{.*}}: i32) +// DET-ALL: ^[[bb3]] // DET-ALL: tensor.from_elements %{{.*}} : tensor<i32> // DET-ALL: return %{{.*}} : tensor<i32> // DET-ALL: } @@ -95,10 +95,10 @@ func.func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attr // DET-CF: %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<10xi32>) outs(%{{.*}} : tensor<i32>) { // DET-CF: tensor.extract %{{.*}}[] : tensor<i32> // DET-CF: cmpi slt, %{{.*}}, %{{.*}} : i32 -// DET-CF: cf.cond_br %{{.*}}, ^bb2(%{{.*}} : tensor<i32>), ^bb3(%{{.*}} : tensor<i32>) -// DET-CF: ^bb2(%{{.*}}: tensor<i32>) +// DET-CF: cf.cond_br %{{.*}}, ^bb2, ^bb3 +// DET-CF: ^bb2: // DET-CF: %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<i32>) outs(%{{.*}} : tensor<10xi32>) { // DET-CF: cf.br ^bb1(%{{.*}} : tensor<10xi32>) -// DET-CF: ^bb3(%{{.*}}: tensor<i32>) +// DET-CF: ^bb3: // DET-CF: return %{{.*}} : tensor<i32> // DET-CF: } diff --git a/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir index 6d8d5fe..913e782 100644 --- a/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir +++ b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir @@ -49,8 +49,8 @@ func.func @main() -> () attributes {} { // CHECK-NEXT: cf.br ^[[bb1:.*]](%{{.*}} : i32) // CHECK-NEXT: ^[[bb1]](%{{.*}}: i32) // CHECK-NEXT: %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} -// CHECK-NEXT: cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]] -// CHECK-NEXT: ^[[bb2]](%{{.*}}: i32) +// CHECK-NEXT: cf.cond_br 
%{{.*}}, ^[[bb2:.*]], ^[[bb3:.*]] +// CHECK-NEXT: ^[[bb2]] // CHECK-NEXT: %{{.*}} = arith.addi %{{.*}}, %{{.*}} // CHECK-NEXT: cf.br ^[[bb1]](%{{.*}} : i32) // CHECK-NEXT: ^[[bb3]]: diff --git a/mlir/test/Transforms/canonicalize-block-merge.mlir b/mlir/test/Transforms/canonicalize-block-merge.mlir index 3b8b1fc..92cfde8 100644 --- a/mlir/test/Transforms/canonicalize-block-merge.mlir +++ b/mlir/test/Transforms/canonicalize-block-merge.mlir @@ -87,7 +87,7 @@ func.func @mismatch_operands_matching_arguments(%cond : i1, %arg0 : i32, %arg1 : // CHECK-LABEL: func @mismatch_argument_uses( func.func @mismatch_argument_uses(%cond : i1, %arg0 : i32, %arg1 : i32) -> (i32, i32) { - // CHECK: cf.cond_br %{{.*}}, ^bb1(%{{.*}}), ^bb2 + // CHECK: return {{.*}}, {{.*}} cf.cond_br %cond, ^bb1(%arg1 : i32), ^bb2(%arg0 : i32) @@ -101,7 +101,7 @@ func.func @mismatch_argument_uses(%cond : i1, %arg0 : i32, %arg1 : i32) -> (i32, // CHECK-LABEL: func @mismatch_argument_types( func.func @mismatch_argument_types(%cond : i1, %arg0 : i32, %arg1 : i16) { - // CHECK: cf.cond_br %{{.*}}, ^bb1(%{{.*}}), ^bb2 + // CHECK: cf.cond_br %{{.*}}, ^bb1, ^bb2 cf.cond_br %cond, ^bb1(%arg0 : i32), ^bb2(%arg1 : i16) @@ -115,7 +115,7 @@ func.func @mismatch_argument_types(%cond : i1, %arg0 : i32, %arg1 : i16) { // CHECK-LABEL: func @mismatch_argument_count( func.func @mismatch_argument_count(%cond : i1, %arg0 : i32) { - // CHECK: cf.cond_br %{{.*}}, ^bb1(%{{.*}}), ^bb2 + // CHECK: cf.cond_br %{{.*}}, ^bb1, ^bb2 cf.cond_br %cond, ^bb1(%arg0 : i32), ^bb2 diff --git a/mlir/test/Transforms/canonicalize-dce.mlir b/mlir/test/Transforms/canonicalize-dce.mlir index ac034d5..8463194 100644 --- a/mlir/test/Transforms/canonicalize-dce.mlir +++ b/mlir/test/Transforms/canonicalize-dce.mlir @@ -137,10 +137,10 @@ func.func @f(%arg0: f32) { // Test case: Test the mechanics of deleting multiple block arguments. 
// CHECK: func @f(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>, %arg2: tensor<3xf32>, %arg3: tensor<4xf32>, %arg4: tensor<5xf32>) -// CHECK-NEXT: "test.br"(%arg1, %arg3)[^bb1] : (tensor<2xf32>, tensor<4xf32>) -// CHECK-NEXT: ^bb1([[VAL0:%.+]]: tensor<2xf32>, [[VAL1:%.+]]: tensor<4xf32>): -// CHECK-NEXT: "foo.print"([[VAL0]]) -// CHECK-NEXT: "foo.print"([[VAL1]]) +// CHECK-NEXT: "test.br"()[^bb1] +// CHECK-NEXT: ^bb1: +// CHECK-NEXT: "foo.print"(%arg1) +// CHECK-NEXT: "foo.print"(%arg3) // CHECK-NEXT: return diff --git a/mlir/test/Transforms/make-isolated-from-above.mlir b/mlir/test/Transforms/make-isolated-from-above.mlir index 58f6cfb..a9d4325 100644 --- a/mlir/test/Transforms/make-isolated-from-above.mlir +++ b/mlir/test/Transforms/make-isolated-from-above.mlir @@ -78,9 +78,9 @@ func.func @make_isolated_from_above_multiple_blocks(%arg0 : index, %arg1 : index // CHECK-DAG: %[[D1:.+]] = tensor.dim %[[EMPTY]], %[[C1]] // CHECK: test.isolated_one_region_op %[[ARG2]], %[[C0]], %[[C1]], %[[D0]], %[[D1]] // CHECK-NEXT: ^bb0(%[[B0:[a-zA-Z0-9]+]]: index, %[[B1:[a-zA-Z0-9]+]]: index, %[[B2:[a-zA-Z0-9]+]]: index, %[[B3:[a-zA-Z0-9]+]]: index, %[[B4:[a-zA-Z0-9]+]]: index) -// CHECK-NEXT: cf.br ^bb1(%[[B0]] : index) -// CHECK: ^bb1(%[[B5:.+]]: index) -// CHECK: "foo.yield"(%[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]]) +// CHECK-NEXT: cf.br ^bb1 +// CHECK: ^bb1: +// CHECK: "foo.yield"(%[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B0]]) // CLONE1-LABEL: func @make_isolated_from_above_multiple_blocks( // CLONE1-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index @@ -95,9 +95,9 @@ func.func @make_isolated_from_above_multiple_blocks(%arg0 : index, %arg1 : index // CLONE1-NEXT: ^bb0(%[[B0:[a-zA-Z0-9]+]]: index, %[[B1:[a-zA-Z0-9]+]]: index, %[[B2:[a-zA-Z0-9]+]]: index) // CLONE1-DAG: %[[C0_0:.+]] = arith.constant 0 : index // CLONE1-DAG: %[[C1_0:.+]] = arith.constant 1 : index -// CLONE1-NEXT: cf.br ^bb1(%[[B0]] : index) -// CLONE1: ^bb1(%[[B3:.+]]: index) -// CLONE1: "foo.yield"(%[[C0_0]], %[[C1_0]], 
%[[B1]], %[[B2]], %[[B3]]) +// CLONE1-NEXT: cf.br ^bb1 +// CLONE1: ^bb1: +// CLONE1: "foo.yield"(%[[C0_0]], %[[C1_0]], %[[B1]], %[[B2]], %[[B0]]) // CLONE2-LABEL: func @make_isolated_from_above_multiple_blocks( // CLONE2-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index @@ -110,6 +110,6 @@ func.func @make_isolated_from_above_multiple_blocks(%arg0 : index, %arg1 : index // CLONE2-DAG: %[[EMPTY:.+]] = tensor.empty(%[[B1]], %[[B2]]) // CLONE2-DAG: %[[D0:.+]] = tensor.dim %[[EMPTY]], %[[C0]] // CLONE2-DAG: %[[D1:.+]] = tensor.dim %[[EMPTY]], %[[C1]] -// CLONE2-NEXT: cf.br ^bb1(%[[B0]] : index) -// CLONE2: ^bb1(%[[B3:.+]]: index) -// CLONE2: "foo.yield"(%[[C0]], %[[C1]], %[[D0]], %[[D1]], %[[B3]]) +// CLONE2-NEXT: cf.br ^bb1 +// CLONE2: ^bb1: +// CLONE2: "foo.yield"(%[[C0]], %[[C1]], %[[D0]], %[[D1]], %[[B0]]) diff --git a/mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir b/mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir new file mode 100644 index 0000000..570ff69 --- /dev/null +++ b/mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir @@ -0,0 +1,76 @@ + // RUN: mlir-opt -pass-pipeline='builtin.module(llvm.func(canonicalize{region-simplify=aggressive}))' %s | FileCheck %s + +llvm.func @foo(%arg0: i64) + +llvm.func @rand() -> i1 + +// CHECK-LABEL: func @large_merge_block( +llvm.func @large_merge_block(%arg0: i64) { + // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64 + // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64 + // CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64 + // CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64 + // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i64) : i64 + + // CHECK: llvm.cond_br %5, ^bb1(%[[C1]], %[[C3]], %[[C4]], %[[C2]] : i64, i64, i64, i64), ^bb1(%[[C4]], %[[C2]], %[[C1]], %[[C3]] : i64, i64, i64, i64) + // CHECK: ^bb{{.*}}(%[[arg0:.*]]: i64, %[[arg1:.*]]: i64, %[[arg2:.*]]: i64, %[[arg3:.*]]: i64): + // CHECK: llvm.cond_br %{{.*}}, ^bb2(%[[arg0]] : i64), ^bb2(%[[arg3]] : i64) + // CHECK: 
^bb{{.*}}(%11: i64): + // CHECK: llvm.br ^bb{{.*}} + // CHECK: ^bb{{.*}}: + // CHECK: llvm.call + // CHECK: llvm.cond_br {{.*}}, ^bb{{.*}}(%[[arg1]] : i64), ^bb{{.*}}(%[[arg2]] : i64) + // CHECK: ^bb{{.*}}: + // CHECK: llvm.call + // CHECK llvm.br ^bb{{.*}} + + %0 = llvm.mlir.constant(0 : i64) : i64 + %1 = llvm.mlir.constant(1 : i64) : i64 + %2 = llvm.mlir.constant(2 : i64) : i64 + %3 = llvm.mlir.constant(3 : i64) : i64 + %4 = llvm.mlir.constant(4 : i64) : i64 + %10 = llvm.icmp "eq" %arg0, %0 : i64 + llvm.cond_br %10, ^bb1, ^bb14 +^bb1: // pred: ^bb0 + %11 = llvm.call @rand() : () -> i1 + llvm.cond_br %11, ^bb2, ^bb3 +^bb2: // pred: ^bb1 + llvm.call @foo(%1) : (i64) -> () + llvm.br ^bb4 +^bb3: // pred: ^bb1 + llvm.call @foo(%2) : (i64) -> () + llvm.br ^bb4 +^bb4: // 2 preds: ^bb2, ^bb3 + %14 = llvm.call @rand() : () -> i1 + llvm.cond_br %14, ^bb5, ^bb6 +^bb5: // pred: ^bb4 + llvm.call @foo(%3) : (i64) -> () + llvm.br ^bb13 +^bb6: // pred: ^bb4 + llvm.call @foo(%4) : (i64) -> () + llvm.br ^bb13 +^bb13: // 2 preds: ^bb11, ^bb12 + llvm.br ^bb27 +^bb14: // pred: ^bb0 + %23 = llvm.call @rand() : () -> i1 + llvm.cond_br %23, ^bb15, ^bb16 +^bb15: // pred: ^bb14 + llvm.call @foo(%4) : (i64) -> () + llvm.br ^bb17 +^bb16: // pred: ^bb14 + llvm.call @foo(%3) : (i64) -> () + llvm.br ^bb17 +^bb17: // 2 preds: ^bb15, ^bb16 + %26 = llvm.call @rand() : () -> i1 + llvm.cond_br %26, ^bb18, ^bb19 +^bb18: // pred: ^bb17 + llvm.call @foo(%2) : (i64) -> () + llvm.br ^bb26 +^bb19: // pred: ^bb17 + llvm.call @foo(%1) : (i64) -> () + llvm.br ^bb26 +^bb26: // 2 preds: ^bb24, ^bb25 + llvm.br ^bb27 +^bb27: // 2 preds: ^bb13, ^bb26 + llvm.return +} |