diff options
5 files changed, 297 insertions, 0 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 8edaa7d..157dc67 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2207,6 +2207,42 @@ def HoistRedundantVectorTransfersOp :
 }
 
 //===----------------------------------------------------------------------===//
+// HoistRedundantVectorBroadcastsOp
+//===----------------------------------------------------------------------===//
+
+def HoistRedundantVectorBroadcastsOp :
+  Op<Transform_Dialect, "structured.hoist_redundant_vector_broadcasts",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformEachOpTrait, TransformOpInterface,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Hoist vector.extract / vector.broadcast pairs out of immediately
+    enclosing scf::ForOp iteratively.
+
+    #### Return modes:
+
+    The operation always succeeds and returns a handle to the transformed
+    function op.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) ";
+
+  let builders = [
+    OpBuilder<(ins "Value":$target)>,
+  ];
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
+//===----------------------------------------------------------------------===//
 // ConvertConv2DToImg2ColOp
 //===----------------------------------------------------------------------===//
 
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
index 186e83a..236c2ce 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
@@ -43,6 +43,17 @@ namespace linalg {
 /// when used on distributed loops with memref semantics!
 void hoistRedundantVectorTransfers(Operation *root);
 
+/// Hoist vector.extract/vector.broadcast pairs out of immediately enclosing
+/// scf::ForOp iteratively, if the following conditions are met:
+///   1. The vector.extract operation is applied on an iter_argument, and no
+///   other operation is using this argument in the body of the loop.
+///   2. The position of the vector.extract is either a static value, or defined
+///   outside of the loop.
+///   3. The vector.broadcast operation is yielded by the loop.
+/// To improve hoisting opportunities, call the `moveLoopInvariantCode` helper
+/// function on the candidate loop above which to hoist.
+void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 3c3d968..82020c0 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3307,6 +3307,21 @@ transform::HoistRedundantVectorTransfersOp::applyToOne(
 }
 
 //===----------------------------------------------------------------------===//
+// HoistRedundantVectorBroadcastsOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::HoistRedundantVectorBroadcastsOp::applyToOne(
+    transform::TransformRewriter &rewriter, mlir::Operation *target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  rewriter.setInsertionPoint(target);
+  linalg::hoistRedundantVectorBroadcasts(rewriter, target);
+  results.push_back(target);
+  return DiagnosedSilenceableFailure::success();
+}
+
+//===----------------------------------------------------------------------===//
 // ConvertConv2DToImg2ColOp.
 //===----------------------------------------------------------------------===//
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 34c9b2c..94f6b60 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -43,6 +43,132 @@ using llvm::dbgs;
 using namespace mlir;
 using namespace mlir::linalg;
 
+/// Replace `loop` with a new loop that has a different init operand at
+/// position `index`. The body of this loop is moved over to the new loop.
+///
+/// `newInitOperand` is the replacement "init" operand at position `index`.
+/// `newYieldValue` is the replacement yield value of the loop at position
+/// `index`.
+static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter,
+                                            scf::ForOp loop,
+                                            Value newInitOperand,
+                                            unsigned index,
+                                            Value newYieldValue) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(loop.getOperation());
+  auto inits = llvm::to_vector(loop.getInits());
+
+  // Replace the init value with the new operand.
+  assert(index < inits.size() && "iter_arg index out of range");
+  inits[index] = newInitOperand;
+
+  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
+      loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
+      inits, [](OpBuilder &, Location, Value, ValueRange) {});
+
+  // Update the old yield in place, notifying the rewriter about the change.
+  auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator());
+  rewriter.modifyOpInPlace(
+      yieldOp, [&] { yieldOp.setOperand(index, newYieldValue); });
+
+  // Move the loop body to the new op.
+  rewriter.mergeBlocks(loop.getBody(), newLoop.getBody(),
+                       newLoop.getBody()->getArguments());
+
+  // Replace the old loop.
+  rewriter.replaceOp(loop.getOperation(), newLoop->getResults());
+  return newLoop;
+}
+
+// Hoist out a pair of corresponding vector.extract+vector.broadcast
+// operations. This function transforms a loop like this:
+//  %res = scf.for _ = _ to _ step _ iter_args(%iarg = %v) -> (t1) {
+//   %e = vector.extract %iarg : t1 to t2
+//   %u = "some_use"(%e) : (t2) -> t2
+//   %b = vector.broadcast %u : t2 to t1
+//   scf.yield %b : t1
+//  }
+// into the following:
+//  %e = vector.extract %v: t1 to t2
+//  %res' = scf.for _ = _ to _ step _ iter_args(%iarg = %e) -> (t2) {
+//   %u' = "some_use"(%iarg) : (t2) -> t2
+//   scf.yield %u' : t2
+//  }
+//  %res = vector.broadcast %res' : t2 to t1
+// NOTE(review): if the loop runs zero times, the original %res is %v while
+// the rewritten %res is broadcast(extract(%v)); these only agree when %v is
+// already such a broadcast. Confirm that callers can rely on this.
+void mlir::linalg::hoistRedundantVectorBroadcasts(RewriterBase &rewriter,
+                                                  Operation *root) {
+  bool changed = true;
+  while (changed) {
+    changed = false;
+    // First move loop invariant ops outside of their loop. This needs to be
+    // done before as we cannot move ops without interrupting the function walk.
+    root->walk(
+        [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
+
+    root->walk([&](vector::ExtractOp extractOp) {
+      LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
+                        << *extractOp.getOperation() << "\n");
+
+      auto loop = dyn_cast<scf::ForOp>(extractOp->getParentOp());
+      if (!loop)
+        return WalkResult::advance();
+
+      // Check that the vector to extract from is a BlockArgument.
+      auto blockArg = dyn_cast<BlockArgument>(extractOp.getVector());
+      if (!blockArg)
+        return WalkResult::advance();
+
+      // Check that the blockArg is an iter_arg of the loop.
+      OpOperand *initArg = loop.getTiedLoopInit(blockArg);
+      if (!initArg)
+        return WalkResult::advance();
+
+      // If the iter_arg does not have only one use, it won't be possible to
+      // hoist the extractOp out.
+      if (!blockArg.hasOneUse())
+        return WalkResult::advance();
+
+      // The loop must yield a single-use broadcast; block args have no def op.
+      Value yieldedVal = loop.getTiedLoopYieldedValue(blockArg)->get();
+      auto broadcast = yieldedVal.getDefiningOp<vector::BroadcastOp>();
+      if (!broadcast || !broadcast.getResult().hasOneUse())
+        return WalkResult::advance();
+      LLVM_DEBUG(DBGS() << "Candidate broadcast: " << broadcast << "\n");
+
+      if (broadcast.getSourceType() != extractOp.getType())
+        return WalkResult::advance();
+
+      // The position of the extract must be defined outside of the loop if
+      // it is dynamic.
+      for (Value operand : extractOp.getDynamicPosition())
+        if (!loop.isDefinedOutsideOfLoop(operand))
+          return WalkResult::advance();
+
+      rewriter.modifyOpInPlace(extractOp, [&] {
+        extractOp.getVectorMutable().assign(initArg->get());
+      });
+      loop.moveOutOfLoop(extractOp);
+      rewriter.moveOpAfter(broadcast, loop);
+
+      unsigned index = blockArg.getArgNumber() - loop.getNumInductionVars();
+      scf::ForOp newLoop = replaceWithDifferentYield(
+          rewriter, loop, extractOp.getResult(), index, broadcast.getSource());
+      LLVM_DEBUG(DBGS() << "New loop: " << newLoop << "\n");
+
+      // Replace the result's uses first, so the broadcast's own use survives.
+      rewriter.replaceAllUsesWith(newLoop.getResult(index), broadcast);
+      rewriter.modifyOpInPlace(
+          broadcast, [&] { broadcast.setOperand(newLoop.getResult(index)); });
+
+      changed = true;
+      return WalkResult::interrupt();
+    });
+  }
+}
+
 static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
                                 LoopLikeOpInterface loop) {
   Value source = transferRead.getSource();
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 550ffbc..241b8a4 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -565,3 +565,112 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// Test hoisting of vector.extract/vector.broadcast pairs
+
+// CHECK-LABEL: func.func @hoist_vector_broadcasts
+// CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>) -> vector<3x4xf32> {
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[VEC]][0] : vector<4xf32> from vector<3x4xf32>
+// CHECK-NEXT: %[[LOOP:.+]] = scf.for {{.*}} {
+// CHECK-NEXT: %[[USE:.+]] = "some_use"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
+// CHECK-NEXT: scf.yield %[[USE]] : vector<4xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[BCAST:.+]] = vector.broadcast %[[LOOP]] : vector<4xf32> to vector<3x4xf32>
+// CHECK-NEXT: return %[[BCAST]] : vector<3x4xf32>
+
+func.func @hoist_vector_broadcasts(%lb : index, %ub : index, %step : index, %vec : vector<3x4xf32>) -> vector<3x4xf32> {
+  %bcast_vec = 
scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec) -> vector<3x4xf32> {
+    %extract = vector.extract %iarg[0] : vector<4xf32> from vector<3x4xf32>
+    %use = "some_use"(%extract) : (vector<4xf32>) -> vector<4xf32>
+    %broadcast = vector.broadcast %use : vector<4xf32> to vector<3x4xf32>
+    scf.yield %broadcast : vector<3x4xf32>
+  }
+  return %bcast_vec : vector<3x4xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_broadcasts %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Test hoisting of vector.extract/vector.broadcast pairs with dynamic position
+
+// CHECK-LABEL: func.func @hoist_vector_broadcasts_dynamic
+// CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>, %[[POS:.+]]: index) -> vector<3x4xf32> {
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[VEC]][%[[POS]]] : vector<4xf32> from vector<3x4xf32>
+// CHECK-NEXT: %[[LOOP:.+]] = scf.for {{.*}} {
+// CHECK-NEXT: %[[USE:.+]] = "some_use"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
+// CHECK-NEXT: scf.yield %[[USE]] : vector<4xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[BCAST:.+]] = vector.broadcast %[[LOOP]] : vector<4xf32> to vector<3x4xf32>
+// CHECK-NEXT: return %[[BCAST]] : vector<3x4xf32>
+
+func.func @hoist_vector_broadcasts_dynamic(%lb : index, %ub : index, %step : index, %vec : vector<3x4xf32>, %pos: index) -> vector<3x4xf32> {
+  %bcast_vec = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec) -> vector<3x4xf32> {
+    %extract = vector.extract %iarg[%pos] : vector<4xf32> from vector<3x4xf32>
+    %use = "some_use"(%extract) : (vector<4xf32>) -> vector<4xf32>
+    %broadcast = vector.broadcast %use : vector<4xf32> to vector<3x4xf32>
+    scf.yield %broadcast : vector<3x4xf32>
+  }
+  return %bcast_vec : vector<3x4xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_broadcasts %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Test hoisting of vector.extract/vector.broadcast pairs with multiple iter_args
+
+// CHECK-LABEL: func.func @hoist_vector_broadcasts_multiple
+// CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC1:.+]]: vector<3x4xf32>,
+// CHECK-SAME: %[[VEC2:.+]]: vector<3x5xf32>) -> (vector<3x4xf32>, vector<3x5xf32>) {
+// CHECK-DAG: %[[EXTRACT1:.+]] = vector.extract %[[VEC1]][0] : vector<4xf32> from vector<3x4xf32>
+// CHECK-DAG: %[[EXTRACT2:.+]] = vector.extract %[[VEC2]][1] : vector<5xf32> from vector<3x5xf32>
+// CHECK-NEXT: %[[LOOP:.+]]:2 = scf.for {{.*}} {
+// CHECK-DAG: %[[USE1:.+]] = "some_use1"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
+// CHECK-DAG: %[[USE2:.+]] = "some_use2"({{.*}}) : (vector<5xf32>) -> vector<5xf32>
+// CHECK-NEXT: scf.yield %[[USE1]], %[[USE2]] : vector<4xf32>, vector<5xf32>
+// CHECK-NEXT: }
+// CHECK-DAG: %[[BCAST1:.+]] = vector.broadcast %[[LOOP]]#0 : vector<4xf32> to vector<3x4xf32>
+// CHECK-DAG: %[[BCAST2:.+]] = vector.broadcast %[[LOOP]]#1 : vector<5xf32> to vector<3x5xf32>
+// CHECK-NEXT: return %[[BCAST1]], %[[BCAST2]] : vector<3x4xf32>, vector<3x5xf32>
+
+func.func @hoist_vector_broadcasts_multiple(%lb : index, %ub : index, %step : index, %vec1 : vector<3x4xf32>, %vec2 : vector<3x5xf32>) -> (vector<3x4xf32>, vector<3x5xf32>) {
+  %bcast_vec:2 = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec1, %iarg2 = %vec2) -> (vector<3x4xf32>, vector<3x5xf32>) {
+    %extract1 = vector.extract %iarg[0] : vector<4xf32> from vector<3x4xf32>
+    %extract2 = vector.extract %iarg2[1] : vector<5xf32> from vector<3x5xf32>
+    %use1 = "some_use1"(%extract1) : (vector<4xf32>) -> vector<4xf32>
+    %use2 = "some_use2"(%extract2) : (vector<5xf32>) -> vector<5xf32>
+    %broadcast1 = vector.broadcast %use1 : vector<4xf32> to vector<3x4xf32>
+    %broadcast2 = vector.broadcast %use2 : vector<5xf32> to vector<3x5xf32>
+    scf.yield %broadcast1, %broadcast2 : vector<3x4xf32>,vector<3x5xf32>
+  }
+  return %bcast_vec#0, %bcast_vec#1 : vector<3x4xf32>, vector<3x5xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.hoist_redundant_vector_broadcasts %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+} |
