diff options
author | Diego Caballero <diego.caballero@intel.com> | 2020-11-18 13:24:39 -0800 |
---|---|---|
committer | Diego Caballero <diego.caballero@intel.com> | 2020-11-18 13:50:32 -0800 |
commit | c1ba9c43adb7ee101048e88ab33c94a1ceda398e (patch) | |
tree | 9708d35e3710835de10b8b35c9e7cea865970715 /mlir/lib/Transforms/LoopFusion.cpp | |
parent | 5f2c5541f78750c21004e0172f13db4632966fd3 (diff) | |
download | llvm-c1ba9c43adb7ee101048e88ab33c94a1ceda398e.zip llvm-c1ba9c43adb7ee101048e88ab33c94a1ceda398e.tar.gz llvm-c1ba9c43adb7ee101048e88ab33c94a1ceda398e.tar.bz2 |
[mlir][Affine] Refactor affine fusion code in pass to utilities
This is a refactoring/clean-up step needed to add support for producer-consumer
fusion with multi-store producer loops and, more broadly, to implement more
general loop fusion strategies in Affine. It introduces the following changes:
- AffineLoopFusion pass now uses loop fusion utilities more broadly to compute
fusion legality (canFuseLoops utility) and perform the fusion transformation
(fuseLoops utility).
- Loop fusion utilities have been extended to deal with AffineLoopFusion
requirements and assumptions while preserving the current functionality of both
the loop fusion utilities and AffineLoopFusion within a unified implementation.
'FusionStrategy' has been introduced for this purpose and, in the future, it
will allow us to have a single loop fusion core implementation that will produce
different fusion outputs depending on the strategy used.
- Improve separation of concerns for legality and profitability analysis:
'isFusionProfitable' no longer filters out illegal scenarios that 'canFuse'
didn't detect, or the other way around. 'canFuse' now takes loop dependences
into account to determine the fusion loop depth (producer-consumer fusion only).
- As a result, maximal fusion now doesn't require any profitability analysis.
- Slices are now computed only once and reused across the legality, profitability
and fusion transformation steps (producer-consumer).
- Refactor some utilities and remove redundant copies of them.
This patch is NFCI (no functional change intended) and should preserve the
existing functionality of both the AffineLoopFusion pass and the affine fusion
utilities.
Reviewed By: andydavis1, bondhugula
Differential Revision: https://reviews.llvm.org/D90798
Diffstat (limited to 'mlir/lib/Transforms/LoopFusion.cpp')
-rw-r--r-- | mlir/lib/Transforms/LoopFusion.cpp | 519 |
1 files changed, 216 insertions, 303 deletions
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index ed79be0..6716260 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -741,77 +741,6 @@ static void moveLoadsAccessingMemrefTo(Value memref, srcLoads->swap(srcLoadsToKeep); } -// Returns the innermost common loop depth for the set of operations in 'ops'. -static unsigned getInnermostCommonLoopDepth(ArrayRef<Operation *> ops) { - unsigned numOps = ops.size(); - assert(numOps > 0); - - std::vector<SmallVector<AffineForOp, 4>> loops(numOps); - unsigned loopDepthLimit = std::numeric_limits<unsigned>::max(); - for (unsigned i = 0; i < numOps; ++i) { - getLoopIVs(*ops[i], &loops[i]); - loopDepthLimit = - std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size())); - } - - unsigned loopDepth = 0; - for (unsigned d = 0; d < loopDepthLimit; ++d) { - unsigned i; - for (i = 1; i < numOps; ++i) { - if (loops[i - 1][d] != loops[i][d]) - break; - } - if (i != numOps) - break; - ++loopDepth; - } - return loopDepth; -} - -// Returns the maximum loop depth at which no dependences between 'loadOpInsts' -// and 'storeOpInsts' are satisfied. -static unsigned getMaxLoopDepth(ArrayRef<Operation *> loadOpInsts, - ArrayRef<Operation *> storeOpInsts) { - // Merge loads and stores into the same array. - SmallVector<Operation *, 2> ops(loadOpInsts.begin(), loadOpInsts.end()); - ops.append(storeOpInsts.begin(), storeOpInsts.end()); - - // Compute the innermost common loop depth for loads and stores. - unsigned loopDepth = getInnermostCommonLoopDepth(ops); - - // Return common loop depth for loads if there are no store ops. - if (storeOpInsts.empty()) - return loopDepth; - - // Check dependences on all pairs of ops in 'ops' and store the minimum - // loop depth at which a dependence is satisfied. 
- for (unsigned i = 0, e = ops.size(); i < e; ++i) { - auto *srcOpInst = ops[i]; - MemRefAccess srcAccess(srcOpInst); - for (unsigned j = 0; j < e; ++j) { - auto *dstOpInst = ops[j]; - MemRefAccess dstAccess(dstOpInst); - - unsigned numCommonLoops = - getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst); - for (unsigned d = 1; d <= numCommonLoops + 1; ++d) { - FlatAffineConstraints dependenceConstraints; - // TODO: Cache dependence analysis results, check cache here. - DependenceResult result = checkMemrefAccessDependence( - srcAccess, dstAccess, d, &dependenceConstraints, - /*dependenceComponents=*/nullptr); - if (hasDependence(result)) { - // Store minimum loop depth and break because we want the min 'd' at - // which there is a dependence. - loopDepth = std::min(loopDepth, d - 1); - break; - } - } - } - } - return loopDepth; -} - // Sinks all sequential loops to the innermost levels (while preserving // relative order among them) and moves all parallel loops to the // outermost (while again preserving relative order among them). @@ -1077,14 +1006,16 @@ canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, // The argument 'srcStoreOpInst' is used to calculate the storage reduction on // the memref being produced and consumed, which is an input to the cost model. // For producer-consumer fusion, 'srcStoreOpInst' will be the same as -// 'srcOpInst', as we are slicing w.r.t to that producer. -// For input-reuse fusion, 'srcOpInst' will be the src loop nest LoadOp which -// reads from the same memref as dst loop nest load ops, and 'srcStoreOpInst' -// will be the unique store op in the src node, which will be used to check -// that the write region is the same after input-reuse fusion. -// Returns true if it is profitable to fuse the candidate loop nests. Returns -// false otherwise. `dstLoopDepth` is set to the most profitable depth at which -// to materialize the source loop nest slice. +// 'srcOpInst', as we are slicing w.r.t to that producer. 
For input-reuse +// fusion, 'srcOpInst' will be the src loop nest LoadOp which reads from the +// same memref as dst loop nest load ops, and 'srcStoreOpInst' will be the +// unique store op in the src node, which will be used to check that the write +// region is the same after input-reuse fusion. Computation slices are provided +// in 'depthSliceUnions' for each legal fusion depth. The maximal depth at which +// fusion is legal is provided in 'maxLegalFusionDepth'. Returns true if it is +// profitable to fuse the candidate loop nests. Returns false otherwise. +// `dstLoopDepth` is set to the most profitable depth at which to materialize +// the source loop nest slice. // The profitability model executes the following steps: // *) Computes the backward computation slice at 'srcOpInst'. This // computation slice of the loop nest surrounding 'srcOpInst' is @@ -1112,9 +1043,9 @@ canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, // is lower. static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, ArrayRef<Operation *> dstLoadOpInsts, - ArrayRef<Operation *> dstStoreOpInsts, - ComputationSliceState *sliceState, - unsigned *dstLoopDepth, bool maximalFusion, + ArrayRef<ComputationSliceState> depthSliceUnions, + unsigned maxLegalFusionDepth, + unsigned *dstLoopDepth, double computeToleranceThreshold) { LLVM_DEBUG({ llvm::dbgs() << "Checking whether fusion is profitable between src op:\n"; @@ -1124,10 +1055,14 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, }; }); + if (maxLegalFusionDepth == 0) { + LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxLegalFusionDepth == 0 .\n"); + return false; + } + // Compute cost of sliced and unsliced src loop nest. SmallVector<AffineForOp, 4> srcLoopIVs; getLoopIVs(*srcOpInst, &srcLoopIVs); - unsigned numSrcLoopIVs = srcLoopIVs.size(); // Walk src loop nest and collect stats. 
LoopNestStats srcLoopNestStats; @@ -1142,19 +1077,8 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, if (!getLoopNestStats(dstLoopIVs[0], &dstLoopNestStats)) return false; - // Compute the maximum loop depth at which we can can insert the src slice - // and still satisfy dest loop nest dependences, for producer-consumer fusion. - unsigned maxDstLoopDepth = - (srcOpInst == srcStoreOpInst) - ? getMaxLoopDepth(dstLoadOpInsts, dstStoreOpInsts) - : dstLoopIVs.size(); - if (maxDstLoopDepth == 0) { - LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxDstLoopDepth == 0 .\n"); - return false; - } - // Search for min cost value for 'dstLoopDepth'. At each value of - // 'dstLoopDepth' from 'maxDstLoopDepth' to '1', compute computation slice + // 'dstLoopDepth' from 'maxLegalLoopDepth' to '1', compute computation slice // bounds between 'srcOpInst' and each op in 'dstOpinsts' (taking the union // of these bounds). Next the union slice bounds are used to calculate // the cost of the slice and the cost of the slice inserted into the dst @@ -1163,8 +1087,6 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, double maxStorageReduction = 0.0; Optional<uint64_t> sliceMemEstimate = None; - SmallVector<ComputationSliceState, 4> sliceStates; - sliceStates.resize(maxDstLoopDepth); // The best loop depth at which to materialize the slice. Optional<unsigned> bestDstLoopDepth = None; @@ -1190,21 +1112,14 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, // Evaluate all depth choices for materializing the slice in the destination // loop nest. - for (unsigned i = maxDstLoopDepth; i >= 1; --i) { - // Compute the union of slice bounds of all ops in 'dstLoadOpInsts'. 
- if (failed(mlir::computeSliceUnion({srcOpInst}, dstLoadOpInsts, - /*loopDepth=*/i, - /*numCommonLoops=*/0, - /*isBackwardSlice=*/true, - &sliceStates[i - 1]))) { - LLVM_DEBUG(llvm::dbgs() - << "computeSliceUnion failed for loopDepth: " << i << "\n"); + for (unsigned i = maxLegalFusionDepth; i >= 1; --i) { + // Skip slice union if it wasn't computed for this depth. + if (depthSliceUnions[i - 1].isEmpty()) continue; - } int64_t fusedLoopNestComputeCost; if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstLoopIVs[0], - dstLoopNestStats, &sliceStates[i - 1], + dstLoopNestStats, depthSliceUnions[i - 1], &fusedLoopNestComputeCost)) { LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost.\n."); continue; @@ -1216,11 +1131,11 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, 1; // Determine what the slice write MemRefRegion would be, if the src loop - // nest slice 'sliceStates[i - 1]' were to be inserted into the dst loop - // nest at loop depth 'i' + // nest slice 'depthSliceUnions[i - 1]' were to be inserted into the dst + // loop nest at loop depth 'i'. MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc()); if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0, - &sliceStates[i - 1]))) { + &depthSliceUnions[i - 1]))) { LLVM_DEBUG(llvm::dbgs() << "Failed to compute slice write region at loopDepth: " << i << "\n"); @@ -1269,8 +1184,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, // (as per computeToleranceThreshold), we will simply pick the one that // reduces the intermediary size the most. 
if ((storageReduction > maxStorageReduction) && - (maximalFusion || - (additionalComputeFraction < computeToleranceThreshold))) { + (additionalComputeFraction < computeToleranceThreshold)) { maxStorageReduction = storageReduction; bestDstLoopDepth = i; minFusedLoopNestComputeCost = fusedLoopNestComputeCost; @@ -1278,10 +1192,9 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, } } - // A simple cost model: fuse if it reduces the memory footprint. If - // -maximal-fusion is set, fuse nevertheless. + // A simple cost model: fuse if it reduces the memory footprint. - if (!maximalFusion && !bestDstLoopDepth.hasValue()) { + if (!bestDstLoopDepth.hasValue()) { LLVM_DEBUG( llvm::dbgs() << "All fusion choices involve more than the threshold amount of " @@ -1310,33 +1223,30 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, Optional<double> storageReduction = None; - if (!maximalFusion) { - if (!dstMemSize.hasValue() || !srcMemSize.hasValue()) { - LLVM_DEBUG( - llvm::dbgs() - << " fusion memory benefit cannot be evaluated; NOT fusing.\n"); - return false; - } + if (!dstMemSize.hasValue() || !srcMemSize.hasValue()) { + LLVM_DEBUG(llvm::dbgs() + << " fusion memory benefit cannot be evaluated; NOT fusing.\n"); + return false; + } - auto srcMemSizeVal = srcMemSize.getValue(); - auto dstMemSizeVal = dstMemSize.getValue(); + auto srcMemSizeVal = srcMemSize.getValue(); + auto dstMemSizeVal = dstMemSize.getValue(); - assert(sliceMemEstimate.hasValue() && "expected value"); - auto fusedMem = dstMemSizeVal + sliceMemEstimate.getValue(); + assert(sliceMemEstimate.hasValue() && "expected value"); + auto fusedMem = dstMemSizeVal + sliceMemEstimate.getValue(); - LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n" - << " dst mem: " << dstMemSizeVal << "\n" - << " fused mem: " << fusedMem << "\n" - << " slice mem: " << sliceMemEstimate << "\n"); + LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n" + 
<< " dst mem: " << dstMemSizeVal << "\n" + << " fused mem: " << fusedMem << "\n" + << " slice mem: " << sliceMemEstimate << "\n"); - if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) { - LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n"); - return false; - } - storageReduction = - 100.0 * - (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal)); + if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) { + LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n"); + return false; } + storageReduction = + 100.0 * + (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal)); double additionalComputeFraction = 100.0 * (minFusedLoopNestComputeCost / @@ -1355,24 +1265,6 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, llvm::dbgs() << msg.str(); }); - // Update return parameter 'sliceState' with 'bestSliceState'. - ComputationSliceState *bestSliceState = &sliceStates[*dstLoopDepth - 1]; - sliceState->lbs = bestSliceState->lbs; - sliceState->ubs = bestSliceState->ubs; - sliceState->lbOperands = bestSliceState->lbOperands; - sliceState->ubOperands = bestSliceState->ubOperands; - - // Canonicalize slice bound affine maps. - for (unsigned i = 0; i < numSrcLoopIVs; ++i) { - if (sliceState->lbs[i] != AffineMap()) { - canonicalizeMapAndOperands(&sliceState->lbs[i], - &sliceState->lbOperands[i]); - } - if (sliceState->ubs[i] != AffineMap()) { - canonicalizeMapAndOperands(&sliceState->ubs[i], - &sliceState->ubOperands[i]); - } - } return true; } @@ -1592,138 +1484,142 @@ public: if (insertPointInst == nullptr) continue; + auto srcAffineForOp = cast<AffineForOp>(srcNode->op); + auto dstAffineForOp = cast<AffineForOp>(dstNode->op); + // Compute the innermost common loop depth for dstNode loads/stores. 
- SmallVector<Operation *, 2> dstOps(dstNode->loads.begin(), - dstNode->loads.end()); - dstOps.append(dstNode->stores.begin(), dstNode->stores.end()); - unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstOps); + SmallVector<Operation *, 2> dstMemrefOps; + for (Operation *op : dstNode->loads) + if (cast<AffineReadOpInterface>(op).getMemRef() == memref) + dstMemrefOps.push_back(op); + for (Operation *op : dstNode->stores) + if (cast<AffineWriteOpInterface>(op).getMemRef() == memref) + dstMemrefOps.push_back(op); + unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstMemrefOps); + // Check the feasibility of fusing src loop nest into dst loop nest // at loop depths in range [1, dstLoopDepthTest]. - // TODO: Use slice union computation and union of memref - // read/write regions to cost model and fusion. - bool canFuse = false; + unsigned maxLegalFusionDepth = 0; + SmallVector<ComputationSliceState, 8> depthSliceUnions; + depthSliceUnions.resize(dstLoopDepthTest); + FusionStrategy strategy(FusionStrategy::ProducerConsumer, memref); for (unsigned i = 1; i <= dstLoopDepthTest; ++i) { - ComputationSliceState sliceUnion; FusionResult result = mlir::canFuseLoops( - cast<AffineForOp>(srcNode->op), cast<AffineForOp>(dstNode->op), - /*dstLoopDepth=*/i, &sliceUnion); + srcAffineForOp, dstAffineForOp, + /*dstLoopDepth=*/i, &depthSliceUnions[i - 1], strategy); + if (result.value == FusionResult::Success) - canFuse = true; + maxLegalFusionDepth = i; } - // Skip if fusion is not feasible at all loop depths. - if (!canFuse) + // Skip if fusion is not feasible at any loop depths. + if (maxLegalFusionDepth == 0) continue; - // Gather 'dstNode' store ops to 'memref'. 
- SmallVector<Operation *, 2> dstStoreOpInsts; - for (auto *storeOpInst : dstNode->stores) - if (cast<AffineWriteOpInterface>(storeOpInst).getMemRef() == memref) - dstStoreOpInsts.push_back(storeOpInst); - - unsigned bestDstLoopDepth; - mlir::ComputationSliceState sliceState; - // Check if fusion would be profitable. - if (!isFusionProfitable(srcStoreOp, srcStoreOp, dstLoadOpInsts, - dstStoreOpInsts, &sliceState, - &bestDstLoopDepth, maximalFusion, - computeToleranceThreshold)) + // Check if fusion would be profitable. We skip profitability analysis + // for maximal fusion since we already know the maximal legal depth to + // fuse. + unsigned bestDstLoopDepth = maxLegalFusionDepth; + if (!maximalFusion && + !isFusionProfitable(srcStoreOp, srcStoreOp, dstLoadOpInsts, + depthSliceUnions, maxLegalFusionDepth, + &bestDstLoopDepth, computeToleranceThreshold)) continue; + assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth"); + assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() && + "Missing slice union for depth"); + // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'. - auto sliceLoopNest = mlir::insertBackwardComputationSlice( - srcStoreOp, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState); - if (sliceLoopNest) { - LLVM_DEBUG(llvm::dbgs() << "\tslice loop nest:\n" - << *sliceLoopNest.getOperation() << "\n"); - // Move 'dstAffineForOp' before 'insertPointInst' if needed. - auto dstAffineForOp = cast<AffineForOp>(dstNode->op); - if (insertPointInst != dstAffineForOp.getOperation()) { - dstAffineForOp.getOperation()->moveBefore(insertPointInst); - } - // Update edges between 'srcNode' and 'dstNode'. - mdg->updateEdges(srcNode->id, dstNode->id, memref, - createPrivateMemref); - - // Collect slice loop stats. - LoopNestStateCollector sliceCollector; - sliceCollector.collect(sliceLoopNest.getOperation()); - // Promote single iteration slice loops to single IV value. 
- for (auto forOp : sliceCollector.forOps) { - promoteIfSingleIteration(forOp); - } - if (createPrivateMemref) { - // Create private memref for 'memref' in 'dstAffineForOp'. - SmallVector<Operation *, 4> storesForMemref; - for (auto *storeOpInst : sliceCollector.storeOpInsts) { - if (cast<AffineWriteOpInterface>(storeOpInst).getMemRef() == - memref) - storesForMemref.push_back(storeOpInst); - } - // TODO: Use union of memref write regions to compute - // private memref footprint. - auto newMemRef = createPrivateMemRef( - dstAffineForOp, storesForMemref[0], bestDstLoopDepth, - fastMemorySpace, localBufSizeThreshold); - visitedMemrefs.insert(newMemRef); - // Create new node in dependence graph for 'newMemRef' alloc op. - unsigned newMemRefNodeId = - mdg->addNode(newMemRef.getDefiningOp()); - // Add edge from 'newMemRef' node to dstNode. - mdg->addEdge(newMemRefNodeId, dstId, newMemRef); + fuseLoops(srcAffineForOp, dstAffineForOp, + depthSliceUnions[bestDstLoopDepth - 1]); + + LLVM_DEBUG(llvm::dbgs() + << "Fused src loop " << srcId << " into dst loop " << dstId + << " at depth " << bestDstLoopDepth << ":\n" + << dstAffineForOp << "\n"); + + // Move 'dstAffineForOp' before 'insertPointInst' if needed. + if (insertPointInst != dstAffineForOp.getOperation()) + dstAffineForOp.getOperation()->moveBefore(insertPointInst); + + // Update edges between 'srcNode' and 'dstNode'. + mdg->updateEdges(srcNode->id, dstNode->id, memref, + createPrivateMemref); + + // Collect slice loop stats. + LoopNestStateCollector dstForCollector; + dstForCollector.collect(dstAffineForOp); + if (createPrivateMemref) { + // Create private memref for 'memref' in 'dstAffineForOp'. + SmallVector<Operation *, 4> storesForMemref; + for (auto *storeOpInst : dstForCollector.storeOpInsts) { + if (cast<AffineWriteOpInterface>(storeOpInst).getMemRef() == + memref) + storesForMemref.push_back(storeOpInst); } + // TODO: Use union of memref write regions to compute + // private memref footprint. 
+ auto newMemRef = createPrivateMemRef( + dstAffineForOp, storesForMemref[0], bestDstLoopDepth, + fastMemorySpace, localBufSizeThreshold); + visitedMemrefs.insert(newMemRef); + // Create new node in dependence graph for 'newMemRef' alloc op. + unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp()); + // Add edge from 'newMemRef' node to dstNode. + mdg->addEdge(newMemRefNodeId, dstId, newMemRef); + } - // Collect dst loop stats after memref privatization transformation. - LoopNestStateCollector dstLoopCollector; - dstLoopCollector.collect(dstAffineForOp.getOperation()); - - // Add new load ops to current Node load op list 'loads' to - // continue fusing based on new operands. - for (auto *loadOpInst : dstLoopCollector.loadOpInsts) { - // NOTE: Change 'loads' to a hash set in case efficiency is an - // issue. We still use a vector since it's expected to be small. - if (!llvm::is_contained(loads, loadOpInst)) - loads.push_back(loadOpInst); - } - // Clear visited memrefs after fusion so that previously visited src - // nodes are considered for fusion again in the context of the new - // fused node. - // TODO: This shouldn't be necessary if we visited candidates in the - // dependence graph in post-order or once we fully support - // multi-store producers. Currently, in a multi-store producer - // scenario such as A->B, A->C, B->C, we fail to fuse A+B due to the - // multiple outgoing edges. However, after fusing B+C, A has a - // single outgoing edge and can be fused if we revisit it in the - // context of the new fused B+C node. - visitedMemrefs.clear(); - - // Clear and add back loads and stores. - mdg->clearNodeLoadAndStores(dstNode->id); - mdg->addToNode(dstId, dstLoopCollector.loadOpInsts, - dstLoopCollector.storeOpInsts); - // Remove old src loop nest if it no longer has outgoing dependence - // edges, and if it does not write to a memref which escapes the - // function. 
If 'writesToLiveInOrOut' is true, then 'srcNode' has - // been fused into 'dstNode' and write region of 'dstNode' covers - // the write region of 'srcNode', and 'srcNode' has no other users - // so it is safe to remove. - if (writesToLiveInOrOut || mdg->canRemoveNode(srcNode->id)) { - mdg->removeNode(srcNode->id); - srcNode->op->erase(); - } else { - // Add remaining users of 'oldMemRef' back on the worklist (if not - // already there), as its replacement with a local/private memref - // has reduced dependences on 'oldMemRef' which may have created - // new fusion opportunities. - if (mdg->outEdges.count(srcNode->id) > 0) { - SmallVector<MemRefDependenceGraph::Edge, 2> oldOutEdges = - mdg->outEdges[srcNode->id]; - for (auto &outEdge : oldOutEdges) { - if (outEdge.value == memref && - worklistSet.count(outEdge.id) == 0) { - worklist.push_back(outEdge.id); - worklistSet.insert(outEdge.id); - } + // Collect dst loop stats after memref privatization transformation. + LoopNestStateCollector dstLoopCollector; + dstLoopCollector.collect(dstAffineForOp.getOperation()); + + // Add new load ops to current Node load op list 'loads' to continue + // fusing based on new operands. + for (auto *loadOpInst : dstLoopCollector.loadOpInsts) { + // NOTE: Change 'loads' to a hash set in case efficiency is an + // issue. We still use a vector since it's expected to be small. + if (!llvm::is_contained(loads, loadOpInst)) + loads.push_back(loadOpInst); + } + // Clear visited memrefs after fusion so that previously visited src + // nodes are considered for fusion again in the context of the new + // fused node. + // TODO: This shouldn't be necessary if we visited candidates in the + // dependence graph in post-order or once we fully support multi-store + // producers. Currently, in a multi-store producer scenario such as + // A->B, A->C, B->C, we fail to fuse A+B due to the multiple outgoing + // edges. 
However, after fusing B+C, A has a single outgoing edge and + // can be fused if we revisit it in the context of the new fused B+C + // node. + visitedMemrefs.clear(); + + // Clear and add back loads and stores. + mdg->clearNodeLoadAndStores(dstNode->id); + mdg->addToNode(dstId, dstLoopCollector.loadOpInsts, + dstLoopCollector.storeOpInsts); + // Remove old src loop nest if it no longer has outgoing dependence + // edges, and if it does not write to a memref which escapes the + // function. If 'writesToLiveInOrOut' is true, then 'srcNode' has been + // fused into 'dstNode' and write region of 'dstNode' covers the write + // region of 'srcNode', and 'srcNode' has no other users so it is safe + // to remove. + if (writesToLiveInOrOut || mdg->canRemoveNode(srcNode->id)) { + mdg->removeNode(srcNode->id); + srcNode->op->erase(); + } else { + // Add remaining users of 'oldMemRef' back on the worklist (if not + // already there), as its replacement with a local/private memref + // has reduced dependences on 'oldMemRef' which may have created new + // fusion opportunities. + if (mdg->outEdges.count(srcNode->id) > 0) { + SmallVector<MemRefDependenceGraph::Edge, 2> oldOutEdges = + mdg->outEdges[srcNode->id]; + for (auto &outEdge : oldOutEdges) { + if (outEdge.value == memref && + worklistSet.count(outEdge.id) == 0) { + worklist.push_back(outEdge.id); + worklistSet.insert(outEdge.id); } } } @@ -1759,6 +1655,8 @@ public: void fuseWithSiblingNodes(Node *dstNode) { DenseSet<unsigned> visitedSibNodeIds; std::pair<unsigned, Value> idAndMemref; + auto dstAffineForOp = cast<AffineForOp>(dstNode->op); + while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) { unsigned sibId = idAndMemref.first; Value memref = idAndMemref.second; @@ -1791,31 +1689,53 @@ public: SmallVector<Operation *, 2> dstLoadOpInsts; dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts); - // Gather 'dstNode' store ops to 'memref'. 
- SmallVector<Operation *, 2> dstStoreOpInsts; - dstNode->getStoreOpsForMemref(memref, &dstStoreOpInsts); - - unsigned bestDstLoopDepth; - mlir::ComputationSliceState sliceState; + SmallVector<AffineForOp, 4> dstLoopIVs; + getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs); + unsigned dstLoopDepthTest = dstLoopIVs.size(); + auto sibAffineForOp = cast<AffineForOp>(sibNode->op); + + // Compute loop depth and slice union for fusion. + SmallVector<ComputationSliceState, 8> depthSliceUnions; + depthSliceUnions.resize(dstLoopDepthTest); + unsigned maxLegalFusionDepth = 0; + FusionStrategy strategy(FusionStrategy::Sibling, memref); + for (unsigned i = 1; i <= dstLoopDepthTest; ++i) { + FusionResult result = mlir::canFuseLoops( + sibAffineForOp, dstAffineForOp, + /*dstLoopDepth=*/i, &depthSliceUnions[i - 1], strategy); + + if (result.value == FusionResult::Success) + maxLegalFusionDepth = i; + } - // Check if fusion would be profitable. - if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstLoadOpInsts, - dstStoreOpInsts, &sliceState, &bestDstLoopDepth, - maximalFusion, computeToleranceThreshold)) + // Skip if fusion is not feasible at any loop depths. + if (maxLegalFusionDepth == 0) continue; + unsigned bestDstLoopDepth = dstLoopDepthTest; + if (!maximalFusion) { + // Check if fusion would be profitable. + if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstLoadOpInsts, + depthSliceUnions, maxLegalFusionDepth, + &bestDstLoopDepth, computeToleranceThreshold)) + continue; + } + + assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth"); + assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() && + "Fusion depth has no computed slice union"); + // Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'. 
- auto sliceLoopNest = mlir::insertBackwardComputationSlice( - sibLoadOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState); - if (sliceLoopNest != nullptr) { - auto dstForInst = cast<AffineForOp>(dstNode->op); - // Update operation position of fused loop nest (if needed). - if (insertPointInst != dstForInst.getOperation()) { - dstForInst.getOperation()->moveBefore(insertPointInst); - } - // Update data dependence graph state post fusion. - updateStateAfterSiblingFusion(sliceLoopNest, sibNode, dstNode); + mlir::fuseLoops(sibAffineForOp, dstAffineForOp, + depthSliceUnions[bestDstLoopDepth - 1]); + + auto dstForInst = cast<AffineForOp>(dstNode->op); + // Update operation position of fused loop nest (if needed). + if (insertPointInst != dstForInst.getOperation()) { + dstForInst.getOperation()->moveBefore(insertPointInst); } + // Update data dependence graph state post fusion. + updateStateAfterSiblingFusion(sibNode, dstNode); } } @@ -1943,19 +1863,12 @@ public: return false; } - void updateStateAfterSiblingFusion(AffineForOp sliceLoopNest, Node *sibNode, - Node *dstNode) { + /// Update data dependence graph state to reflect sibling fusion of 'sibNode' + /// into 'dstNode'. + void updateStateAfterSiblingFusion(Node *sibNode, Node *dstNode) { // Update 'sibNode' and 'dstNode' input/output edges to reflect fusion. mdg->updateEdges(sibNode->id, dstNode->id); - // Collect slice loop stats. - LoopNestStateCollector sliceCollector; - sliceCollector.collect(sliceLoopNest.getOperation()); - // Promote single iteration slice loops to single IV value. - for (auto forOp : sliceCollector.forOps) { - promoteIfSingleIteration(forOp); - } - // Collect dst loop stats after memref privatization transformation. auto dstForInst = cast<AffineForOp>(dstNode->op); LoopNestStateCollector dstLoopCollector; |