-rw-r--r--   mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h |  39
-rw-r--r--   mlir/include/mlir/Transforms/Passes.h                       |   3
-rw-r--r--   mlir/include/mlir/Transforms/Passes.td                      |   5
-rw-r--r--   mlir/lib/Transforms/LoopInvariantCodeMotion.cpp             |  20
-rw-r--r--   mlir/lib/Transforms/Utils/CMakeLists.txt                    |   1
-rw-r--r--   mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp  | 254
-rw-r--r--   mlir/test/Transforms/loop-invariant-subset-hoisting.mlir   | 237
-rw-r--r--   utils/bazel/llvm-project-overlay/mlir/BUILD.bazel           |   1
8 files changed, 556 insertions, 4 deletions
diff --git a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
index c7b816e..5790540 100644
--- a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
+++ b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
@@ -71,6 +71,45 @@ size_t moveLoopInvariantCode(
 /// methods provided by the interface.
 size_t moveLoopInvariantCode(LoopLikeOpInterface loopLike);
 
+/// Hoist loop-invariant tensor subsets (subset extraction and subset
+/// insertion ops) from loop-like ops: extraction ops are moved before the
+/// loop and insertion ops are moved after the loop. The loop body operates
+/// on newly added region iter_args (one per extraction-insertion pair).
+///
+/// A subset extraction op (`SubsetExtractionOpInterface`) extracts from a
+/// tensor value at a subset. The result of the op may have an arbitrary type,
+/// i.e., not necessarily a tensor type. Example: "tensor.extract_slice".
+///
+/// A subset insertion op (`SubsetInsertionOpInterface`) inserts into a tensor
+/// value ("destination") at a subset. Example: "tensor.insert_slice".
+///
+/// Matching extraction-insertion subset ops can be hoisted from a loop if
+/// there are no other ops within the loop that operate on the same or on an
+/// overlapping subset. In particular, non-subset ops can prevent hoisting
+/// because the analysis does not know which subset they operate on.
+///
+/// Example:
+/// ```
+/// %r = scf.for ... iter_args(%t = %a) -> (tensor<?xf32>) {
+///   %0 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+///   %1 = "test.foo"(%0) : (tensor<5xf32>) -> (tensor<5xf32>)
+///   %2 = tensor.insert_slice %1 into %t[0][5][1]
+///       : tensor<5xf32> into tensor<?xf32>
+///   scf.yield %2 : tensor<?xf32>
+/// }
+/// ```
+/// Is rewritten to:
+/// ```
+/// %0 = tensor.extract_slice %a[0][5][1] : tensor<?xf32> to tensor<5xf32>
+/// %new_loop:2 = scf.for ... iter_args(%t = %a, %h = %0)
+///     -> (tensor<?xf32>, tensor<5xf32>) {
+///   %1 = "test.foo"(%h) : (tensor<5xf32>) -> (tensor<5xf32>)
+///   scf.yield %t, %1 : tensor<?xf32>, tensor<5xf32>
+/// }
+/// %r = tensor.insert_slice %new_loop#1 into %new_loop#0
+///     : tensor<5xf32> into tensor<?xf32>
+/// ```
+LoopLikeOpInterface hoistLoopInvariantSubsets(LoopLikeOpInterface loopLike);
+
 } // end namespace mlir
 
 #endif // MLIR_TRANSFORMS_LOOPINVARIANTCODEMOTIONUTILS_H
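Besides the pass added further below, the new utility can be driven directly from C++. A minimal sketch, assuming a function-like root op; the helper name `hoistSubsetsIn` is hypothetical:

    #include "mlir/Interfaces/LoopLikeInterface.h"
    #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"

    static void hoistSubsetsIn(mlir::Operation *root) {
      // Post-order walk: inner loops are visited before their enclosing
      // loops, matching what the pass below does.
      root->walk([](mlir::LoopLikeOpInterface loopLike) {
        // The call may replace the loop with a new loop that has additional
        // iter_args; the returned op is not needed here.
        (void)mlir::hoistLoopInvariantSubsets(loopLike);
      });
    }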
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 320932b..11f5b23 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -78,6 +78,9 @@ std::unique_ptr<Pass> createGenerateRuntimeVerificationPass();
 /// instructions out of the loop.
 std::unique_ptr<Pass> createLoopInvariantCodeMotionPass();
 
+/// Creates a pass that hoists loop-invariant subset ops.
+std::unique_ptr<Pass> createLoopInvariantSubsetHoistingPass();
+
 /// Creates a pass to strip debug information from a function.
 std::unique_ptr<Pass> createStripDebugInfoPass();
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 26d2ff3..2d2d54fb 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -329,6 +329,11 @@ def LoopInvariantCodeMotion : Pass<"loop-invariant-code-motion"> {
   let constructor = "mlir::createLoopInvariantCodeMotionPass()";
 }
 
+def LoopInvariantSubsetHoisting : Pass<"loop-invariant-subset-hoisting"> {
+  let summary = "Hoist loop invariant subset ops outside of the loop";
+  let constructor = "mlir::createLoopInvariantSubsetHoistingPass()";
+}
+
 def Mem2Reg : Pass<"mem2reg"> {
   let summary = "Promotes memory slots into values.";
   let description = [{
diff --git a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
index 854fde0..e6d8af8 100644
--- a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
+++ b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
@@ -18,6 +18,7 @@
 namespace mlir {
 #define GEN_PASS_DEF_LOOPINVARIANTCODEMOTION
+#define GEN_PASS_DEF_LOOPINVARIANTSUBSETHOISTING
 #include "mlir/Transforms/Passes.h.inc"
 } // namespace mlir
 
@@ -29,6 +30,12 @@
 struct LoopInvariantCodeMotion
     : public impl::LoopInvariantCodeMotionBase<LoopInvariantCodeMotion> {
   void runOnOperation() override;
 };
+
+struct LoopInvariantSubsetHoisting
+    : public impl::LoopInvariantSubsetHoistingBase<
+          LoopInvariantSubsetHoisting> {
+  void runOnOperation() override;
+};
 } // namespace
 
@@ -39,6 +46,19 @@ void LoopInvariantCodeMotion::runOnOperation() {
       [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
 }
 
+void LoopInvariantSubsetHoisting::runOnOperation() {
+  // Walk all loops in the function in innermost-loop-first order. This way,
+  // we first hoist from the inner loop and place the hoisted ops in the
+  // enclosing loop, from which they may be hoisted further.
+  getOperation()->walk([&](LoopLikeOpInterface loopLike) {
+    (void)hoistLoopInvariantSubsets(loopLike);
+  });
+}
+
 std::unique_ptr<Pass> mlir::createLoopInvariantCodeMotionPass() {
   return std::make_unique<LoopInvariantCodeMotion>();
 }
+
+std::unique_ptr<Pass> mlir::createLoopInvariantSubsetHoistingPass() {
+  return std::make_unique<LoopInvariantSubsetHoisting>();
+}
diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt
index efc7a51..1c608e0 100644
--- a/mlir/lib/Transforms/Utils/CMakeLists.txt
+++ b/mlir/lib/Transforms/Utils/CMakeLists.txt
@@ -20,5 +20,6 @@ add_mlir_library(MLIRTransformUtils
   MLIRFunctionInterfaces
   MLIRLoopLikeInterface
   MLIRSideEffectInterfaces
+  MLIRSubsetOpInterface
   MLIRRewrite
 )
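For completeness, a sketch of scheduling the new pass programmatically instead of via mlir-opt. This assumes the standard MLIR pass-manager setup; the wrapper function is hypothetical:

    #include "mlir/Dialect/Func/IR/FuncOps.h"
    #include "mlir/IR/BuiltinOps.h"
    #include "mlir/Pass/PassManager.h"
    #include "mlir/Transforms/Passes.h"

    static mlir::LogicalResult hoistSubsets(mlir::ModuleOp module) {
      mlir::PassManager pm(module.getContext());
      // Run the pass on every function in the module, equivalent to
      // `mlir-opt -loop-invariant-subset-hoisting`.
      pm.addNestedPass<mlir::func::FuncOp>(
          mlir::createLoopInvariantSubsetHoistingPass());
      return pm.run(module);
    }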
diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
index 080492d..01318cf 100644
--- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
@@ -11,9 +11,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
+
 #include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/SubsetOpInterface.h"
 #include "llvm/Support/Debug.h"
 #include <queue>
 
@@ -26,7 +29,7 @@ using namespace mlir;
 ///   loop (by means of calling definedOutside).
 /// - the op has no side-effects.
 static bool canBeHoisted(Operation *op,
-                         function_ref<bool(Value)> definedOutside) {
+                         function_ref<bool(OpOperand &)> condition) {
   // Do not move terminators.
   if (op->hasTrait<OpTrait::IsTerminator>())
     return false;
@@ -35,11 +38,11 @@ static bool canBeHoisted(Operation *op,
   // defined outside of the loop or in a nested region, but not at the level of
   // the loop body.
   auto walkFn = [&](Operation *child) {
-    for (Value operand : child->getOperands()) {
+    for (OpOperand &operand : child->getOpOperands()) {
       // Ignore values defined in a nested region.
-      if (op->isAncestor(operand.getParentRegion()->getParentOp()))
+      if (op->isAncestor(operand.get().getParentRegion()->getParentOp()))
         continue;
-      if (!definedOutside(operand))
+      if (!condition(operand))
         return WalkResult::interrupt();
     }
     return WalkResult::advance();
@@ -47,6 +50,12 @@ static bool canBeHoisted(Operation *op,
   return !op->walk(walkFn).wasInterrupted();
 }
 
+static bool canBeHoisted(Operation *op,
+                         function_ref<bool(Value)> definedOutside) {
+  return canBeHoisted(
+      op, [&](OpOperand &operand) { return definedOutside(operand.get()); });
+}
+
 size_t mlir::moveLoopInvariantCode(
     ArrayRef<Region *> regions,
     function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
@@ -105,3 +114,240 @@ size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
       },
       [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); });
 }
+
+namespace {
+/// Helper data structure that keeps track of equivalent/disjoint subset ops.
+class MatchingSubsets {
+public:
+  /// Insert a subset op.
+  void insert(SubsetOpInterface op) {
+    allSubsetOps.push_back(op);
+    if (auto extractionOp =
+            dyn_cast<SubsetExtractionOpInterface>(op.getOperation()))
+      insertExtractionOp(extractionOp);
+    if (auto insertionOp =
+            dyn_cast<SubsetInsertionOpInterface>(op.getOperation()))
+      insertInsertionOp(insertionOp);
+  }
+
+  /// Return a range of matching extraction-insertion subset ops. If there is
+  /// no matching extraction/insertion op, the respective value is empty. Ops
+  /// are skipped if there are other subset ops that are not guaranteed to
+  /// operate on disjoint subsets.
+  auto getHoistableSubsetOps() {
+    return llvm::make_filter_range(
+        llvm::zip(extractions, insertions), [&](auto pair) {
+          auto [extractionOp, insertionOp] = pair;
+          // Hoist only if the extracted and inserted values have the same
+          // type.
+          if (extractionOp && insertionOp &&
+              extractionOp->getResult(0).getType() !=
+                  insertionOp.getSourceOperand().get().getType())
+            return false;
+          // Hoist only if there are no conflicting subset ops.
+          return allDisjoint(extractionOp, insertionOp);
+        });
+  }
+
+private:
+  /// Helper function for equivalence of tensor values. Since only insertion
+  /// subset ops (that are also destination style ops) are followed when
+  /// traversing the SSA use-def chain, all tensor values are equivalent.
+  static bool isEquivalent(Value v1, Value v2) { return true; }
+
+  /// Return "true" if the subsets of the given extraction and insertion ops
+  /// are operating disjoint from the subsets that all other known subset ops
+  /// are operating on.
+  bool allDisjoint(SubsetExtractionOpInterface extractionOp,
+                   SubsetInsertionOpInterface insertionOp) const {
+    for (SubsetOpInterface other : allSubsetOps) {
+      if (other == extractionOp || other == insertionOp)
+        continue;
+      if (extractionOp &&
+          !other.operatesOnDisjointSubset(extractionOp, isEquivalent))
+        return false;
+      if (insertionOp &&
+          !other.operatesOnDisjointSubset(insertionOp, isEquivalent))
+        return false;
+    }
+    return true;
+  }
+
+  /// Insert a subset extraction op. If its subset is equivalent to that of an
+  /// existing subset insertion op, pair the two ops up. (If the insertion op
+  /// already has a paired-up extraction op, overwrite that extraction op.)
+  void insertExtractionOp(SubsetExtractionOpInterface extractionOp) {
+    for (auto it : llvm::enumerate(insertions)) {
+      if (!it.value())
+        continue;
+      auto other = cast<SubsetOpInterface>(it.value().getOperation());
+      if (other.operatesOnEquivalentSubset(extractionOp, isEquivalent)) {
+        extractions[it.index()] = extractionOp;
+        return;
+      }
+    }
+    // There is no known equivalent insertion op. Create a new entry.
+    extractions.push_back(extractionOp);
+    insertions.push_back({});
+  }
+
+  /// Insert a subset insertion op. If its subset is equivalent to that of an
+  /// existing subset extraction op, pair the two ops up. (If the extraction
+  /// op already has a paired-up insertion op, overwrite that insertion op.)
+  void insertInsertionOp(SubsetInsertionOpInterface insertionOp) {
+    for (auto it : llvm::enumerate(extractions)) {
+      if (!it.value())
+        continue;
+      auto other = cast<SubsetOpInterface>(it.value().getOperation());
+      if (other.operatesOnEquivalentSubset(insertionOp, isEquivalent)) {
+        insertions[it.index()] = insertionOp;
+        return;
+      }
+    }
+    // There is no known equivalent extraction op. Create a new entry.
+    extractions.push_back({});
+    insertions.push_back(insertionOp);
+  }
+
+  SmallVector<SubsetExtractionOpInterface> extractions;
+  SmallVector<SubsetInsertionOpInterface> insertions;
+  SmallVector<SubsetOpInterface> allSubsetOps;
+};
+} // namespace
+
+/// If the given value has a single use by an op that is a terminator, return
+/// that use. Otherwise, return nullptr.
+static OpOperand *getSingleTerminatorUse(Value value) {
+  if (!value.hasOneUse())
+    return nullptr;
+  OpOperand &use = *value.getUses().begin();
+  if (use.getOwner()->hasTrait<OpTrait::IsTerminator>())
+    return &use;
+  return nullptr;
+}
+
+/// Hoist all subset ops that operate on the given region iter_arg of the
+/// given loop-like op and that index into loop-invariant subset locations.
+/// Return the newly created loop op (with additional iter_args) or the
+/// original loop op if nothing was hoisted.
+static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
+                                                BlockArgument iterArg) {
+  IRRewriter rewriter(loopLike.getContext());
+  assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
+  auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
+  int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
+  Value value = iterArg;
+  MatchingSubsets subsets;
+
+  // Traverse the use-def chain. Subset ops can be hoisted only if all ops
+  // along the use-def chain starting from the region iter_arg are subset
+  // extraction or subset insertion ops. The chain must terminate at the
+  // corresponding yield operand (e.g., no swapping of iter_args).
+  OpOperand *yieldedOperand = nullptr;
+  // Iterate until the single use of the current SSA value is a terminator,
+  // which is expected to be the yielding operation of the loop.
+  while (!(yieldedOperand = getSingleTerminatorUse(value))) {
+    Value nextValue = {};
+
+    for (OpOperand &use : value.getUses()) {
+      auto subsetOp = dyn_cast<SubsetOpInterface>(use.getOwner());
+      if (!subsetOp)
+        return loopLike;
+      subsets.insert(subsetOp);
+
+      if (auto insertionOp =
+              dyn_cast<SubsetInsertionOpInterface>(use.getOwner())) {
+        // The value must be used as a destination. (In case of a source, the
+        // entire tensor would be read, which would prevent any hoisting.)
+        if (&use != &insertionOp.getDestinationOperand())
+          return loopLike;
+        // There must be a single use-def chain from the region iter_arg to
+        // the terminator, i.e., only one insertion op. Branches are not
+        // supported.
+        if (nextValue)
+          return loopLike;
+        nextValue = insertionOp.getUpdatedDestination();
+      }
+    }
+
+    // Nothing can be hoisted if the chain does not continue with the loop's
+    // yielding op or a subset insertion op.
+    if (!nextValue)
+      return loopLike;
+    value = nextValue;
+  }
+
+  // Hoist only if the SSA use-def chain ends in the yielding terminator of
+  // the loop and the yielded value is the operand tied to `iterArg` (i.e.,
+  // there is no swapping yield).
+  if (loopLike.getTiedLoopYieldedValue(iterArg) != yieldedOperand)
+    return loopLike;
+
+  // Hoist all matching extraction-insertion pairs one-by-one.
+  for (auto it : subsets.getHoistableSubsetOps()) {
+    auto extractionOp = std::get<0>(it);
+    auto insertionOp = std::get<1>(it);
+
+    // Ops cannot be hoisted if they depend on loop-variant values.
+    if (extractionOp) {
+      if (!canBeHoisted(extractionOp, [&](OpOperand &operand) {
+            return loopLike.isDefinedOutsideOfLoop(operand.get()) ||
+                   &operand == &extractionOp.getSourceOperand();
+          }))
+        extractionOp = {};
+    }
+    if (insertionOp) {
+      if (!canBeHoisted(insertionOp, [&](OpOperand &operand) {
+            return loopLike.isDefinedOutsideOfLoop(operand.get()) ||
+                   &operand == &insertionOp.getSourceOperand() ||
+                   &operand == &insertionOp.getDestinationOperand();
+          }))
+        insertionOp = {};
+    }
+
+    // Only hoist extraction-insertion pairs for now. Standalone extractions/
+    // insertions that are loop-invariant could be hoisted as well, but there
+    // may be easier ways to canonicalize such IR.
+    if (extractionOp && insertionOp) {
+      // Create a new loop with an additional iter_arg.
+      NewYieldValuesFn newYieldValuesFn =
+          [&](OpBuilder &b, Location loc,
+              ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> {
+        return {insertionOp.getSourceOperand().get()};
+      };
+      FailureOr<LoopLikeOpInterface> newLoop =
+          loopLike.replaceWithAdditionalYields(
+              rewriter, extractionOp.getResult(),
+              /*replaceInitOperandUsesInLoop=*/true, newYieldValuesFn);
+      if (failed(newLoop))
+        return loopLike;
+      loopLike = *newLoop;
+
+      // Hoist the extraction/insertion ops.
+      iterArg = loopLike.getRegionIterArgs()[iterArgIdx];
+      OpResult loopResult = loopLike.getTiedLoopResult(iterArg);
+      OpResult newLoopResult = loopLike.getLoopResults()->back();
+      extractionOp->moveBefore(loopLike);
+      insertionOp->moveAfter(loopLike);
+      insertionOp.getUpdatedDestination().replaceAllUsesWith(
+          insertionOp.getDestinationOperand().get());
+      extractionOp.getSourceOperand().set(
+          loopLike.getTiedLoopInit(iterArg)->get());
+      loopResult.replaceAllUsesWith(insertionOp.getUpdatedDestination());
+      insertionOp.getSourceOperand().set(newLoopResult);
+      insertionOp.getDestinationOperand().set(loopResult);
+    }
+  }
+
+  return loopLike;
+}
+
+LoopLikeOpInterface
+mlir::hoistLoopInvariantSubsets(LoopLikeOpInterface loopLike) {
+  // Note: As subset ops are hoisted, the number of region iter_args
+  // increases. This can create new hoisting opportunities on the newly added
+  // iter_args, so iterate over the (growing) list by index.
+  for (int64_t i = 0;
+       i < static_cast<int64_t>(loopLike.getRegionIterArgs().size()); ++i) {
+    loopLike = hoistSubsetAtIterArg(loopLike, loopLike.getRegionIterArgs()[i]);
+  }
+  return loopLike;
+}
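The heavy lifting above is done by `replaceWithAdditionalYields`, which rebuilds the loop with an extra iter_arg. A rough sketch of the call shape, detached from the patch; `loopLike`, `rewriter`, and the loop-invariant value `v` are assumed to be set up by the caller:

    // Thread `v` through the loop as a new iter_arg. The new region
    // iter_arg is innerNewBBArgs[0]; yielding `v` again on every iteration
    // keeps the new iter_arg loop-invariant.
    mlir::NewYieldValuesFn fn =
        [&](mlir::OpBuilder &b, mlir::Location loc,
            llvm::ArrayRef<mlir::BlockArgument> innerNewBBArgs)
        -> llvm::SmallVector<mlir::Value> { return {v}; };
    mlir::FailureOr<mlir::LoopLikeOpInterface> newLoop =
        loopLike.replaceWithAdditionalYields(
            rewriter, /*newInitOperands=*/mlir::ValueRange{v},
            /*replaceInitOperandUsesInLoop=*/false, fn);
    if (mlir::failed(newLoop))
      return; // This loop op does not support adding yields.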
diff --git a/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir b/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
new file mode 100644
index 0000000..5cded4c
--- /dev/null
+++ b/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
@@ -0,0 +1,237 @@
+// RUN: mlir-opt %s -split-input-file -loop-invariant-subset-hoisting | FileCheck %s
+
+// CHECK-LABEL: func @hoist_matching_extract_insert(
+//  CHECK-SAME:     %[[arg:.*]]: tensor<?xf32>
+func.func @hoist_matching_extract_insert(%arg: tensor<?xf32>) -> tensor<?xf32> {
+  %lb = "test.foo"() : () -> (index)
+  %ub = "test.foo"() : () -> (index)
+  %step = "test.foo"() : () -> (index)
+
+  // CHECK: %[[extract:.*]] = tensor.extract_slice %[[arg]]
+  // CHECK: %[[for:.*]]:2 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted:.*]] = %[[extract]])
+  %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+    // CHECK: tensor.extract_slice %[[t]][9] [5] [1]
+    %standalone = tensor.extract_slice %t[9][5][1] : tensor<?xf32> to tensor<5xf32>
+    "test.foo"(%standalone) : (tensor<5xf32>) -> ()
+
+    %1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+    // CHECK: %[[foo:.*]] = "test.foo"(%[[hoisted]])
+    %2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
+    %3 = tensor.insert_slice %2 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32>
+    // CHECK: scf.yield %[[t]], %[[foo]]
+    scf.yield %3 : tensor<?xf32>
+  }
+  // CHECK: %[[insert:.*]] = tensor.insert_slice %[[for]]#1 into %[[for]]#0
+
+  // CHECK: return %[[insert]]
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @subset_of_subset(
+//  CHECK-SAME:     %[[arg:.*]]: tensor<?xf32>
+func.func @subset_of_subset(%arg: tensor<?xf32>) -> tensor<?xf32> {
+  %lb = "test.foo"() : () -> (index)
+  %ub = "test.foo"() : () -> (index)
+  %step = "test.foo"() : () -> (index)
+
+  // CHECK: %[[extract1:.*]] = tensor.extract_slice %[[arg]]
+  // CHECK: %[[extract2:.*]] = tensor.extract_slice %[[extract1]]
+  // CHECK: %[[for:.*]]:3 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted1:.*]] = %[[extract1]], %[[hoisted2:.*]] = %[[extract2]])
+  %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+    %extract1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+    %extract2 = tensor.extract_slice %extract1[1][2][1] : tensor<5xf32> to tensor<2xf32>
+
+    // CHECK: %[[foo:.*]] = "test.foo"(%[[hoisted2]])
+    %2 = "test.foo"(%extract2) : (tensor<2xf32>) -> (tensor<2xf32>)
+
+    %insert1 = tensor.insert_slice %2 into %extract1[1][2][1] : tensor<2xf32> into tensor<5xf32>
+    %insert2 = tensor.insert_slice %insert1 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32>
+
+    // CHECK: scf.yield %[[t]], %[[hoisted1]], %[[foo]]
+    scf.yield %insert2 : tensor<?xf32>
+  }
+  // CHECK: %[[insert2:.*]] = tensor.insert_slice %[[for]]#2 into %[[for]]#1[1] [2] [1]
+  // CHECK: %[[insert1:.*]] = tensor.insert_slice %[[insert2]] into %[[for]]#0[0] [5] [1]
+
+  // CHECK: return %[[insert1]]
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @hoist_matching_chain(
+//  CHECK-SAME:     %[[arg:.*]]: tensor<?xf32>
+func.func @hoist_matching_chain(%arg: tensor<?xf32>) -> tensor<?xf32> {
+  %lb = "test.foo"() : () -> (index)
+  %ub = "test.foo"() : () -> (index)
+  %step = "test.foo"() : () -> (index)
+  %sz = "test.foo"() : () -> (index)
+
+  // CHECK: %[[extract2:.*]] = tensor.extract_slice %[[arg]][%{{.*}}] [5] [1]
+  // CHECK: %[[extract1:.*]] = tensor.extract_slice %[[arg]][0] [%{{.*}}] [1]
+  // CHECK: %[[for:.*]]:3 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted2:.*]] = %[[extract2]], %[[hoisted1:.*]] = %[[extract1]])
+  %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+    %1 = tensor.extract_slice %t[0][%sz][1] : tensor<?xf32> to tensor<?xf32>
+    %2 = tensor.extract_slice %t[%sz][5][1] : tensor<?xf32> to tensor<5xf32>
+    // CHECK-DAG: %[[foo1:.*]] = "test.foo"(%[[hoisted1]])
+    // CHECK-DAG: %[[foo2:.*]] = "test.foo"(%[[hoisted2]])
+    %foo1 = "test.foo"(%1) : (tensor<?xf32>) -> (tensor<?xf32>)
+    %foo2 = "test.foo"(%2) : (tensor<5xf32>) -> (tensor<5xf32>)
+    %5 = tensor.insert_slice %foo2 into %t[%sz][5][1] : tensor<5xf32> into tensor<?xf32>
+    %6 = tensor.insert_slice %foo1 into %5[0][%sz][1] : tensor<?xf32> into tensor<?xf32>
+    // CHECK: scf.yield %[[t]], %[[foo2]], %[[foo1]]
+    scf.yield %6 : tensor<?xf32>
+  }
+  // CHECK: %[[insert2:.*]] = tensor.insert_slice %[[for]]#2 into %[[for]]#0[0] [%{{.*}}] [1]
+  // CHECK: %[[insert1:.*]] = tensor.insert_slice %[[for]]#1 into %[[insert2]][%{{.*}}] [5] [1]
+
+  // CHECK: return %[[insert1]]
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @do_not_hoist_overlapping_subsets(
+func.func @do_not_hoist_overlapping_subsets(%arg: tensor<?xf32>) -> tensor<?xf32> {
+  %lb = "test.foo"() : () -> (index)
+  %ub = "test.foo"() : () -> (index)
+  %step = "test.foo"() : () -> (index)
+  %sz1 = "test.foo"() : () -> (index)
+  %sz2 = "test.foo"() : () -> (index)
+
+  // CHECK: scf.for
+  %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+    // These two slices are potentially overlapping. Do not hoist.
+    // CHECK: tensor.extract_slice
+    // CHECK: tensor.extract_slice
+    %1 = tensor.extract_slice %t[0][%sz1][1] : tensor<?xf32> to tensor<?xf32>
+    %2 = tensor.extract_slice %t[10][%sz2][1] : tensor<?xf32> to tensor<?xf32>
+    // CHECK: "test.foo"
+    // CHECK: "test.foo"
+    %foo1 = "test.foo"(%1) : (tensor<?xf32>) -> (tensor<?xf32>)
+    %foo2 = "test.foo"(%2) : (tensor<?xf32>) -> (tensor<?xf32>)
+    // CHECK: tensor.insert_slice
+    // CHECK: tensor.insert_slice
+    %5 = tensor.insert_slice %foo2 into %t[0][%sz1][1] : tensor<?xf32> into tensor<?xf32>
+    %6 = tensor.insert_slice %foo1 into %5[10][%sz2][1] : tensor<?xf32> into tensor<?xf32>
+    // CHECK: scf.yield
+    scf.yield %6 : tensor<?xf32>
+  }
+
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @multiple_yields(
+//  CHECK-SAME:     %[[arg:.*]]: tensor<?xf32>
+func.func @multiple_yields(%arg: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
+  %lb = "test.foo"() : () -> (index)
+  %ub = "test.foo"() : () -> (index)
+  %step = "test.foo"() : () -> (index)
+
+  // CHECK: %[[extract1:.*]] = tensor.extract_slice
+  // CHECK: %[[extract2:.*]] = tensor.extract_slice
+  // CHECK: scf.for {{.*}} iter_args(%{{.*}} = %[[arg]], %{{.*}} = %[[arg]], %{{.*}} = %[[extract1]], %{{.*}} = %[[extract2]])
+  %0:2 = scf.for %iv = %lb to %ub step %step iter_args(%t1 = %arg, %t2 = %arg)
+      -> (tensor<?xf32>, tensor<?xf32>) {
+    %1 = tensor.extract_slice %t1[0][5][1] : tensor<?xf32> to tensor<5xf32>
+    %2 = tensor.extract_slice %t2[5][5][1] : tensor<?xf32> to tensor<5xf32>
+    // CHECK: "test.foo"
+    // CHECK: "test.foo"
+    %foo1 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
+    %foo2 = "test.foo"(%2) : (tensor<5xf32>) -> (tensor<5xf32>)
+    %5 = tensor.insert_slice %foo2 into %t1[0][5][1] : tensor<5xf32> into tensor<?xf32>
+    %6 = tensor.insert_slice %foo1 into %t2[5][5][1] : tensor<5xf32> into tensor<?xf32>
+    // CHECK: scf.yield
+    scf.yield %5, %6 : tensor<?xf32>, tensor<?xf32>
+  }
+  // CHECK: tensor.insert_slice
+  // CHECK: tensor.insert_slice
+
+  return %0#0, %0#1 : tensor<?xf32>, tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @do_not_hoist_swapping_yields(
+func.func @do_not_hoist_swapping_yields(%arg: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
+  %lb = "test.foo"() : () -> (index)
+  %ub = "test.foo"() : () -> (index)
+  %step = "test.foo"() : () -> (index)
+
+  // CHECK: scf.for
+  %0:2 = scf.for %iv = %lb to %ub step %step iter_args(%t1 = %arg, %t2 = %arg)
+      -> (tensor<?xf32>, tensor<?xf32>) {
+    // CHECK: tensor.extract_slice
+    // CHECK: tensor.extract_slice
+    %1 = tensor.extract_slice %t1[0][5][1] : tensor<?xf32> to tensor<5xf32>
+    %2 = tensor.extract_slice %t2[5][5][1] : tensor<?xf32> to tensor<5xf32>
+    // CHECK: "test.foo"
+    // CHECK: "test.foo"
+    %foo1 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
+    %foo2 = "test.foo"(%2) : (tensor<5xf32>) -> (tensor<5xf32>)
+    // CHECK: tensor.insert_slice
+    // CHECK: tensor.insert_slice
+    %5 = tensor.insert_slice %foo2 into %t1[0][5][1] : tensor<5xf32> into tensor<?xf32>
+    %6 = tensor.insert_slice %foo1 into %t2[5][5][1] : tensor<5xf32> into tensor<?xf32>
+    // Swapping yields: do not hoist.
+    // CHECK: scf.yield
+    scf.yield %6, %5 : tensor<?xf32>, tensor<?xf32>
+  }
+
+  return %0#0, %0#1 : tensor<?xf32>, tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @non_subset_op(
+func.func @non_subset_op(%arg: tensor<?xf32>) -> tensor<?xf32> {
+  %lb = "test.foo"() : () -> (index)
+  %ub = "test.foo"() : () -> (index)
+  %step = "test.foo"() : () -> (index)
+
+  // CHECK: scf.for
+  %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+    // If any value along the use-def chain from the region iter_arg to the
+    // terminator is used by a non-subset op, no subset op along that chain
+    // can be hoisted. That is because it is unknown which parts of the value
+    // are accessed by the non-subset op.
+    // CHECK: "test.non_subset_op"
+    "test.non_subset_op"(%t) : (tensor<?xf32>) -> ()
+    // CHECK: tensor.extract_slice
+    %1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
+    // CHECK: "test.foo"
+    %2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
+    // CHECK: tensor.insert_slice
+    %3 = tensor.insert_slice %2 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32>
+    // CHECK: scf.yield
+    scf.yield %3 : tensor<?xf32>
+  }
+
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @non_loop_invariant_subset_op(
+func.func @non_loop_invariant_subset_op(%arg: tensor<?xf32>) -> tensor<?xf32> {
+  %lb = "test.foo"() : () -> (index)
+  %ub = "test.foo"() : () -> (index)
+  %step = "test.foo"() : () -> (index)
+
+  // CHECK: scf.for
+  %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
+    // Subset ops that are not loop-invariant cannot be hoisted.
+    // CHECK: tensor.extract_slice
+    %1 = tensor.extract_slice %t[%iv][5][1] : tensor<?xf32> to tensor<5xf32>
+    // CHECK: "test.foo"
+    %2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
+    // CHECK: tensor.insert_slice
+    %3 = tensor.insert_slice %2 into %t[%iv][5][1] : tensor<5xf32> into tensor<?xf32>
+    // CHECK: scf.yield
+    scf.yield %3 : tensor<?xf32>
+  }
+
+  return %0 : tensor<?xf32>
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 0448ecbe..0a2ae42 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -7035,6 +7035,7 @@ cc_library(
         ":MemorySlotInterfaces",
         ":Rewrite",
         ":SideEffectInterfaces",
+        ":SubsetOpInterface",
         ":Support",
         ":TransformsPassIncGen",
         ":config",