diff options
Diffstat (limited to 'mlir/lib/Transforms/LoopFusion.cpp')
-rw-r--r-- | mlir/lib/Transforms/LoopFusion.cpp | 519 |
1 files changed, 216 insertions, 303 deletions
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index ed79be0..6716260 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -741,77 +741,6 @@ static void moveLoadsAccessingMemrefTo(Value memref, srcLoads->swap(srcLoadsToKeep); } -// Returns the innermost common loop depth for the set of operations in 'ops'. -static unsigned getInnermostCommonLoopDepth(ArrayRef<Operation *> ops) { - unsigned numOps = ops.size(); - assert(numOps > 0); - - std::vector<SmallVector<AffineForOp, 4>> loops(numOps); - unsigned loopDepthLimit = std::numeric_limits<unsigned>::max(); - for (unsigned i = 0; i < numOps; ++i) { - getLoopIVs(*ops[i], &loops[i]); - loopDepthLimit = - std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size())); - } - - unsigned loopDepth = 0; - for (unsigned d = 0; d < loopDepthLimit; ++d) { - unsigned i; - for (i = 1; i < numOps; ++i) { - if (loops[i - 1][d] != loops[i][d]) - break; - } - if (i != numOps) - break; - ++loopDepth; - } - return loopDepth; -} - -// Returns the maximum loop depth at which no dependences between 'loadOpInsts' -// and 'storeOpInsts' are satisfied. -static unsigned getMaxLoopDepth(ArrayRef<Operation *> loadOpInsts, - ArrayRef<Operation *> storeOpInsts) { - // Merge loads and stores into the same array. - SmallVector<Operation *, 2> ops(loadOpInsts.begin(), loadOpInsts.end()); - ops.append(storeOpInsts.begin(), storeOpInsts.end()); - - // Compute the innermost common loop depth for loads and stores. - unsigned loopDepth = getInnermostCommonLoopDepth(ops); - - // Return common loop depth for loads if there are no store ops. - if (storeOpInsts.empty()) - return loopDepth; - - // Check dependences on all pairs of ops in 'ops' and store the minimum - // loop depth at which a dependence is satisfied. 
- for (unsigned i = 0, e = ops.size(); i < e; ++i) { - auto *srcOpInst = ops[i]; - MemRefAccess srcAccess(srcOpInst); - for (unsigned j = 0; j < e; ++j) { - auto *dstOpInst = ops[j]; - MemRefAccess dstAccess(dstOpInst); - - unsigned numCommonLoops = - getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst); - for (unsigned d = 1; d <= numCommonLoops + 1; ++d) { - FlatAffineConstraints dependenceConstraints; - // TODO: Cache dependence analysis results, check cache here. - DependenceResult result = checkMemrefAccessDependence( - srcAccess, dstAccess, d, &dependenceConstraints, - /*dependenceComponents=*/nullptr); - if (hasDependence(result)) { - // Store minimum loop depth and break because we want the min 'd' at - // which there is a dependence. - loopDepth = std::min(loopDepth, d - 1); - break; - } - } - } - } - return loopDepth; -} - // Sinks all sequential loops to the innermost levels (while preserving // relative order among them) and moves all parallel loops to the // outermost (while again preserving relative order among them). @@ -1077,14 +1006,16 @@ canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, // The argument 'srcStoreOpInst' is used to calculate the storage reduction on // the memref being produced and consumed, which is an input to the cost model. // For producer-consumer fusion, 'srcStoreOpInst' will be the same as -// 'srcOpInst', as we are slicing w.r.t to that producer. -// For input-reuse fusion, 'srcOpInst' will be the src loop nest LoadOp which -// reads from the same memref as dst loop nest load ops, and 'srcStoreOpInst' -// will be the unique store op in the src node, which will be used to check -// that the write region is the same after input-reuse fusion. -// Returns true if it is profitable to fuse the candidate loop nests. Returns -// false otherwise. `dstLoopDepth` is set to the most profitable depth at which -// to materialize the source loop nest slice. +// 'srcOpInst', as we are slicing w.r.t to that producer. 
For input-reuse +// fusion, 'srcOpInst' will be the src loop nest LoadOp which reads from the +// same memref as dst loop nest load ops, and 'srcStoreOpInst' will be the +// unique store op in the src node, which will be used to check that the write +// region is the same after input-reuse fusion. Computation slices are provided +// in 'depthSliceUnions' for each legal fusion depth. The maximal depth at which +// fusion is legal is provided in 'maxLegalFusionDepth'. Returns true if it is +// profitable to fuse the candidate loop nests. Returns false otherwise. +// `dstLoopDepth` is set to the most profitable depth at which to materialize +// the source loop nest slice. // The profitability model executes the following steps: // *) Computes the backward computation slice at 'srcOpInst'. This // computation slice of the loop nest surrounding 'srcOpInst' is @@ -1112,9 +1043,9 @@ canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, // is lower. static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, ArrayRef<Operation *> dstLoadOpInsts, - ArrayRef<Operation *> dstStoreOpInsts, - ComputationSliceState *sliceState, - unsigned *dstLoopDepth, bool maximalFusion, + ArrayRef<ComputationSliceState> depthSliceUnions, + unsigned maxLegalFusionDepth, + unsigned *dstLoopDepth, double computeToleranceThreshold) { LLVM_DEBUG({ llvm::dbgs() << "Checking whether fusion is profitable between src op:\n"; @@ -1124,10 +1055,14 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, }; }); + if (maxLegalFusionDepth == 0) { + LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxLegalFusionDepth == 0 .\n"); + return false; + } + // Compute cost of sliced and unsliced src loop nest. SmallVector<AffineForOp, 4> srcLoopIVs; getLoopIVs(*srcOpInst, &srcLoopIVs); - unsigned numSrcLoopIVs = srcLoopIVs.size(); // Walk src loop nest and collect stats. 
LoopNestStats srcLoopNestStats; @@ -1142,19 +1077,8 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, if (!getLoopNestStats(dstLoopIVs[0], &dstLoopNestStats)) return false; - // Compute the maximum loop depth at which we can can insert the src slice - // and still satisfy dest loop nest dependences, for producer-consumer fusion. - unsigned maxDstLoopDepth = - (srcOpInst == srcStoreOpInst) - ? getMaxLoopDepth(dstLoadOpInsts, dstStoreOpInsts) - : dstLoopIVs.size(); - if (maxDstLoopDepth == 0) { - LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxDstLoopDepth == 0 .\n"); - return false; - } - // Search for min cost value for 'dstLoopDepth'. At each value of - // 'dstLoopDepth' from 'maxDstLoopDepth' to '1', compute computation slice + // 'dstLoopDepth' from 'maxLegalFusionDepth' to '1', compute computation slice // bounds between 'srcOpInst' and each op in 'dstOpinsts' (taking the union // of these bounds). Next the union slice bounds are used to calculate // the cost of the slice and the cost of the slice inserted into the dst @@ -1163,8 +1087,6 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, double maxStorageReduction = 0.0; Optional<uint64_t> sliceMemEstimate = None; - SmallVector<ComputationSliceState, 4> sliceStates; - sliceStates.resize(maxDstLoopDepth); // The best loop depth at which to materialize the slice. Optional<unsigned> bestDstLoopDepth = None; @@ -1190,21 +1112,14 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, // Evaluate all depth choices for materializing the slice in the destination // loop nest. - for (unsigned i = maxDstLoopDepth; i >= 1; --i) { - // Compute the union of slice bounds of all ops in 'dstLoadOpInsts'.
- if (failed(mlir::computeSliceUnion({srcOpInst}, dstLoadOpInsts, - /*loopDepth=*/i, - /*numCommonLoops=*/0, - /*isBackwardSlice=*/true, - &sliceStates[i - 1]))) { - LLVM_DEBUG(llvm::dbgs() - << "computeSliceUnion failed for loopDepth: " << i << "\n"); + for (unsigned i = maxLegalFusionDepth; i >= 1; --i) { + // Skip slice union if it wasn't computed for this depth. + if (depthSliceUnions[i - 1].isEmpty()) continue; - } int64_t fusedLoopNestComputeCost; if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstLoopIVs[0], - dstLoopNestStats, &sliceStates[i - 1], + dstLoopNestStats, depthSliceUnions[i - 1], &fusedLoopNestComputeCost)) { LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost.\n."); continue; @@ -1216,11 +1131,11 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, 1; // Determine what the slice write MemRefRegion would be, if the src loop - // nest slice 'sliceStates[i - 1]' were to be inserted into the dst loop - // nest at loop depth 'i' + // nest slice 'depthSliceUnions[i - 1]' were to be inserted into the dst + // loop nest at loop depth 'i'. MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc()); if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0, - &sliceStates[i - 1]))) { + &depthSliceUnions[i - 1]))) { LLVM_DEBUG(llvm::dbgs() << "Failed to compute slice write region at loopDepth: " << i << "\n"); @@ -1269,8 +1184,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, // (as per computeToleranceThreshold), we will simply pick the one that // reduces the intermediary size the most. 
if ((storageReduction > maxStorageReduction) && - (maximalFusion || - (additionalComputeFraction < computeToleranceThreshold))) { + (additionalComputeFraction < computeToleranceThreshold)) { maxStorageReduction = storageReduction; bestDstLoopDepth = i; minFusedLoopNestComputeCost = fusedLoopNestComputeCost; @@ -1278,10 +1192,9 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, } } - // A simple cost model: fuse if it reduces the memory footprint. If - // -maximal-fusion is set, fuse nevertheless. + // A simple cost model: fuse if it reduces the memory footprint. - if (!maximalFusion && !bestDstLoopDepth.hasValue()) { + if (!bestDstLoopDepth.hasValue()) { LLVM_DEBUG( llvm::dbgs() << "All fusion choices involve more than the threshold amount of " @@ -1310,33 +1223,30 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, Optional<double> storageReduction = None; - if (!maximalFusion) { - if (!dstMemSize.hasValue() || !srcMemSize.hasValue()) { - LLVM_DEBUG( - llvm::dbgs() - << " fusion memory benefit cannot be evaluated; NOT fusing.\n"); - return false; - } + if (!dstMemSize.hasValue() || !srcMemSize.hasValue()) { + LLVM_DEBUG(llvm::dbgs() + << " fusion memory benefit cannot be evaluated; NOT fusing.\n"); + return false; + } - auto srcMemSizeVal = srcMemSize.getValue(); - auto dstMemSizeVal = dstMemSize.getValue(); + auto srcMemSizeVal = srcMemSize.getValue(); + auto dstMemSizeVal = dstMemSize.getValue(); - assert(sliceMemEstimate.hasValue() && "expected value"); - auto fusedMem = dstMemSizeVal + sliceMemEstimate.getValue(); + assert(sliceMemEstimate.hasValue() && "expected value"); + auto fusedMem = dstMemSizeVal + sliceMemEstimate.getValue(); - LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n" - << " dst mem: " << dstMemSizeVal << "\n" - << " fused mem: " << fusedMem << "\n" - << " slice mem: " << sliceMemEstimate << "\n"); + LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n" + 
<< " dst mem: " << dstMemSizeVal << "\n" + << " fused mem: " << fusedMem << "\n" + << " slice mem: " << sliceMemEstimate << "\n"); - if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) { - LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n"); - return false; - } - storageReduction = - 100.0 * - (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal)); + if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) { + LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n"); + return false; } + storageReduction = + 100.0 * + (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal)); double additionalComputeFraction = 100.0 * (minFusedLoopNestComputeCost / @@ -1355,24 +1265,6 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, llvm::dbgs() << msg.str(); }); - // Update return parameter 'sliceState' with 'bestSliceState'. - ComputationSliceState *bestSliceState = &sliceStates[*dstLoopDepth - 1]; - sliceState->lbs = bestSliceState->lbs; - sliceState->ubs = bestSliceState->ubs; - sliceState->lbOperands = bestSliceState->lbOperands; - sliceState->ubOperands = bestSliceState->ubOperands; - - // Canonicalize slice bound affine maps. - for (unsigned i = 0; i < numSrcLoopIVs; ++i) { - if (sliceState->lbs[i] != AffineMap()) { - canonicalizeMapAndOperands(&sliceState->lbs[i], - &sliceState->lbOperands[i]); - } - if (sliceState->ubs[i] != AffineMap()) { - canonicalizeMapAndOperands(&sliceState->ubs[i], - &sliceState->ubOperands[i]); - } - } return true; } @@ -1592,138 +1484,142 @@ public: if (insertPointInst == nullptr) continue; + auto srcAffineForOp = cast<AffineForOp>(srcNode->op); + auto dstAffineForOp = cast<AffineForOp>(dstNode->op); + // Compute the innermost common loop depth for dstNode loads/stores. 
- SmallVector<Operation *, 2> dstOps(dstNode->loads.begin(), - dstNode->loads.end()); - dstOps.append(dstNode->stores.begin(), dstNode->stores.end()); - unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstOps); + SmallVector<Operation *, 2> dstMemrefOps; + for (Operation *op : dstNode->loads) + if (cast<AffineReadOpInterface>(op).getMemRef() == memref) + dstMemrefOps.push_back(op); + for (Operation *op : dstNode->stores) + if (cast<AffineWriteOpInterface>(op).getMemRef() == memref) + dstMemrefOps.push_back(op); + unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstMemrefOps); + // Check the feasibility of fusing src loop nest into dst loop nest // at loop depths in range [1, dstLoopDepthTest]. - // TODO: Use slice union computation and union of memref - // read/write regions to cost model and fusion. - bool canFuse = false; + unsigned maxLegalFusionDepth = 0; + SmallVector<ComputationSliceState, 8> depthSliceUnions; + depthSliceUnions.resize(dstLoopDepthTest); + FusionStrategy strategy(FusionStrategy::ProducerConsumer, memref); for (unsigned i = 1; i <= dstLoopDepthTest; ++i) { - ComputationSliceState sliceUnion; FusionResult result = mlir::canFuseLoops( - cast<AffineForOp>(srcNode->op), cast<AffineForOp>(dstNode->op), - /*dstLoopDepth=*/i, &sliceUnion); + srcAffineForOp, dstAffineForOp, + /*dstLoopDepth=*/i, &depthSliceUnions[i - 1], strategy); + if (result.value == FusionResult::Success) - canFuse = true; + maxLegalFusionDepth = i; } - // Skip if fusion is not feasible at all loop depths. - if (!canFuse) + // Skip if fusion is not feasible at any loop depths. + if (maxLegalFusionDepth == 0) continue; - // Gather 'dstNode' store ops to 'memref'. 
- SmallVector<Operation *, 2> dstStoreOpInsts; - for (auto *storeOpInst : dstNode->stores) - if (cast<AffineWriteOpInterface>(storeOpInst).getMemRef() == memref) - dstStoreOpInsts.push_back(storeOpInst); - - unsigned bestDstLoopDepth; - mlir::ComputationSliceState sliceState; - // Check if fusion would be profitable. - if (!isFusionProfitable(srcStoreOp, srcStoreOp, dstLoadOpInsts, - dstStoreOpInsts, &sliceState, - &bestDstLoopDepth, maximalFusion, - computeToleranceThreshold)) + // Check if fusion would be profitable. We skip profitability analysis + // for maximal fusion since we already know the maximal legal depth to + // fuse. + unsigned bestDstLoopDepth = maxLegalFusionDepth; + if (!maximalFusion && + !isFusionProfitable(srcStoreOp, srcStoreOp, dstLoadOpInsts, + depthSliceUnions, maxLegalFusionDepth, + &bestDstLoopDepth, computeToleranceThreshold)) continue; + assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth"); + assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() && + "Missing slice union for depth"); + // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'. - auto sliceLoopNest = mlir::insertBackwardComputationSlice( - srcStoreOp, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState); - if (sliceLoopNest) { - LLVM_DEBUG(llvm::dbgs() << "\tslice loop nest:\n" - << *sliceLoopNest.getOperation() << "\n"); - // Move 'dstAffineForOp' before 'insertPointInst' if needed. - auto dstAffineForOp = cast<AffineForOp>(dstNode->op); - if (insertPointInst != dstAffineForOp.getOperation()) { - dstAffineForOp.getOperation()->moveBefore(insertPointInst); - } - // Update edges between 'srcNode' and 'dstNode'. - mdg->updateEdges(srcNode->id, dstNode->id, memref, - createPrivateMemref); - - // Collect slice loop stats. - LoopNestStateCollector sliceCollector; - sliceCollector.collect(sliceLoopNest.getOperation()); - // Promote single iteration slice loops to single IV value. 
- for (auto forOp : sliceCollector.forOps) { - promoteIfSingleIteration(forOp); - } - if (createPrivateMemref) { - // Create private memref for 'memref' in 'dstAffineForOp'. - SmallVector<Operation *, 4> storesForMemref; - for (auto *storeOpInst : sliceCollector.storeOpInsts) { - if (cast<AffineWriteOpInterface>(storeOpInst).getMemRef() == - memref) - storesForMemref.push_back(storeOpInst); - } - // TODO: Use union of memref write regions to compute - // private memref footprint. - auto newMemRef = createPrivateMemRef( - dstAffineForOp, storesForMemref[0], bestDstLoopDepth, - fastMemorySpace, localBufSizeThreshold); - visitedMemrefs.insert(newMemRef); - // Create new node in dependence graph for 'newMemRef' alloc op. - unsigned newMemRefNodeId = - mdg->addNode(newMemRef.getDefiningOp()); - // Add edge from 'newMemRef' node to dstNode. - mdg->addEdge(newMemRefNodeId, dstId, newMemRef); + fuseLoops(srcAffineForOp, dstAffineForOp, + depthSliceUnions[bestDstLoopDepth - 1]); + + LLVM_DEBUG(llvm::dbgs() + << "Fused src loop " << srcId << " into dst loop " << dstId + << " at depth " << bestDstLoopDepth << ":\n" + << dstAffineForOp << "\n"); + + // Move 'dstAffineForOp' before 'insertPointInst' if needed. + if (insertPointInst != dstAffineForOp.getOperation()) + dstAffineForOp.getOperation()->moveBefore(insertPointInst); + + // Update edges between 'srcNode' and 'dstNode'. + mdg->updateEdges(srcNode->id, dstNode->id, memref, + createPrivateMemref); + + // Collect slice loop stats. + LoopNestStateCollector dstForCollector; + dstForCollector.collect(dstAffineForOp); + if (createPrivateMemref) { + // Create private memref for 'memref' in 'dstAffineForOp'. + SmallVector<Operation *, 4> storesForMemref; + for (auto *storeOpInst : dstForCollector.storeOpInsts) { + if (cast<AffineWriteOpInterface>(storeOpInst).getMemRef() == + memref) + storesForMemref.push_back(storeOpInst); } + // TODO: Use union of memref write regions to compute + // private memref footprint. 
+ auto newMemRef = createPrivateMemRef( + dstAffineForOp, storesForMemref[0], bestDstLoopDepth, + fastMemorySpace, localBufSizeThreshold); + visitedMemrefs.insert(newMemRef); + // Create new node in dependence graph for 'newMemRef' alloc op. + unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp()); + // Add edge from 'newMemRef' node to dstNode. + mdg->addEdge(newMemRefNodeId, dstId, newMemRef); + } - // Collect dst loop stats after memref privatization transformation. - LoopNestStateCollector dstLoopCollector; - dstLoopCollector.collect(dstAffineForOp.getOperation()); - - // Add new load ops to current Node load op list 'loads' to - // continue fusing based on new operands. - for (auto *loadOpInst : dstLoopCollector.loadOpInsts) { - // NOTE: Change 'loads' to a hash set in case efficiency is an - // issue. We still use a vector since it's expected to be small. - if (!llvm::is_contained(loads, loadOpInst)) - loads.push_back(loadOpInst); - } - // Clear visited memrefs after fusion so that previously visited src - // nodes are considered for fusion again in the context of the new - // fused node. - // TODO: This shouldn't be necessary if we visited candidates in the - // dependence graph in post-order or once we fully support - // multi-store producers. Currently, in a multi-store producer - // scenario such as A->B, A->C, B->C, we fail to fuse A+B due to the - // multiple outgoing edges. However, after fusing B+C, A has a - // single outgoing edge and can be fused if we revisit it in the - // context of the new fused B+C node. - visitedMemrefs.clear(); - - // Clear and add back loads and stores. - mdg->clearNodeLoadAndStores(dstNode->id); - mdg->addToNode(dstId, dstLoopCollector.loadOpInsts, - dstLoopCollector.storeOpInsts); - // Remove old src loop nest if it no longer has outgoing dependence - // edges, and if it does not write to a memref which escapes the - // function. 
If 'writesToLiveInOrOut' is true, then 'srcNode' has - // been fused into 'dstNode' and write region of 'dstNode' covers - // the write region of 'srcNode', and 'srcNode' has no other users - // so it is safe to remove. - if (writesToLiveInOrOut || mdg->canRemoveNode(srcNode->id)) { - mdg->removeNode(srcNode->id); - srcNode->op->erase(); - } else { - // Add remaining users of 'oldMemRef' back on the worklist (if not - // already there), as its replacement with a local/private memref - // has reduced dependences on 'oldMemRef' which may have created - // new fusion opportunities. - if (mdg->outEdges.count(srcNode->id) > 0) { - SmallVector<MemRefDependenceGraph::Edge, 2> oldOutEdges = - mdg->outEdges[srcNode->id]; - for (auto &outEdge : oldOutEdges) { - if (outEdge.value == memref && - worklistSet.count(outEdge.id) == 0) { - worklist.push_back(outEdge.id); - worklistSet.insert(outEdge.id); - } + // Collect dst loop stats after memref privatization transformation. + LoopNestStateCollector dstLoopCollector; + dstLoopCollector.collect(dstAffineForOp.getOperation()); + + // Add new load ops to current Node load op list 'loads' to continue + // fusing based on new operands. + for (auto *loadOpInst : dstLoopCollector.loadOpInsts) { + // NOTE: Change 'loads' to a hash set in case efficiency is an + // issue. We still use a vector since it's expected to be small. + if (!llvm::is_contained(loads, loadOpInst)) + loads.push_back(loadOpInst); + } + // Clear visited memrefs after fusion so that previously visited src + // nodes are considered for fusion again in the context of the new + // fused node. + // TODO: This shouldn't be necessary if we visited candidates in the + // dependence graph in post-order or once we fully support multi-store + // producers. Currently, in a multi-store producer scenario such as + // A->B, A->C, B->C, we fail to fuse A+B due to the multiple outgoing + // edges. 
However, after fusing B+C, A has a single outgoing edge and + // can be fused if we revisit it in the context of the new fused B+C + // node. + visitedMemrefs.clear(); + + // Clear and add back loads and stores. + mdg->clearNodeLoadAndStores(dstNode->id); + mdg->addToNode(dstId, dstLoopCollector.loadOpInsts, + dstLoopCollector.storeOpInsts); + // Remove old src loop nest if it no longer has outgoing dependence + // edges, and if it does not write to a memref which escapes the + // function. If 'writesToLiveInOrOut' is true, then 'srcNode' has been + // fused into 'dstNode' and write region of 'dstNode' covers the write + // region of 'srcNode', and 'srcNode' has no other users so it is safe + // to remove. + if (writesToLiveInOrOut || mdg->canRemoveNode(srcNode->id)) { + mdg->removeNode(srcNode->id); + srcNode->op->erase(); + } else { + // Add remaining users of 'oldMemRef' back on the worklist (if not + // already there), as its replacement with a local/private memref + // has reduced dependences on 'oldMemRef' which may have created new + // fusion opportunities. + if (mdg->outEdges.count(srcNode->id) > 0) { + SmallVector<MemRefDependenceGraph::Edge, 2> oldOutEdges = + mdg->outEdges[srcNode->id]; + for (auto &outEdge : oldOutEdges) { + if (outEdge.value == memref && + worklistSet.count(outEdge.id) == 0) { + worklist.push_back(outEdge.id); + worklistSet.insert(outEdge.id); } } } @@ -1759,6 +1655,8 @@ public: void fuseWithSiblingNodes(Node *dstNode) { DenseSet<unsigned> visitedSibNodeIds; std::pair<unsigned, Value> idAndMemref; + auto dstAffineForOp = cast<AffineForOp>(dstNode->op); + while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) { unsigned sibId = idAndMemref.first; Value memref = idAndMemref.second; @@ -1791,31 +1689,53 @@ public: SmallVector<Operation *, 2> dstLoadOpInsts; dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts); - // Gather 'dstNode' store ops to 'memref'. 
- SmallVector<Operation *, 2> dstStoreOpInsts; - dstNode->getStoreOpsForMemref(memref, &dstStoreOpInsts); - - unsigned bestDstLoopDepth; - mlir::ComputationSliceState sliceState; + SmallVector<AffineForOp, 4> dstLoopIVs; + getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs); + unsigned dstLoopDepthTest = dstLoopIVs.size(); + auto sibAffineForOp = cast<AffineForOp>(sibNode->op); + + // Compute loop depth and slice union for fusion. + SmallVector<ComputationSliceState, 8> depthSliceUnions; + depthSliceUnions.resize(dstLoopDepthTest); + unsigned maxLegalFusionDepth = 0; + FusionStrategy strategy(FusionStrategy::Sibling, memref); + for (unsigned i = 1; i <= dstLoopDepthTest; ++i) { + FusionResult result = mlir::canFuseLoops( + sibAffineForOp, dstAffineForOp, + /*dstLoopDepth=*/i, &depthSliceUnions[i - 1], strategy); + + if (result.value == FusionResult::Success) + maxLegalFusionDepth = i; + } - // Check if fusion would be profitable. - if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstLoadOpInsts, - dstStoreOpInsts, &sliceState, &bestDstLoopDepth, - maximalFusion, computeToleranceThreshold)) + // Skip if fusion is not feasible at any loop depths. + if (maxLegalFusionDepth == 0) continue; + unsigned bestDstLoopDepth = dstLoopDepthTest; + if (!maximalFusion) { + // Check if fusion would be profitable. + if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstLoadOpInsts, + depthSliceUnions, maxLegalFusionDepth, + &bestDstLoopDepth, computeToleranceThreshold)) + continue; + } + + assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth"); + assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() && + "Fusion depth has no computed slice union"); + // Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'. 
- auto sliceLoopNest = mlir::insertBackwardComputationSlice( - sibLoadOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState); - if (sliceLoopNest != nullptr) { - auto dstForInst = cast<AffineForOp>(dstNode->op); - // Update operation position of fused loop nest (if needed). - if (insertPointInst != dstForInst.getOperation()) { - dstForInst.getOperation()->moveBefore(insertPointInst); - } - // Update data dependence graph state post fusion. - updateStateAfterSiblingFusion(sliceLoopNest, sibNode, dstNode); + mlir::fuseLoops(sibAffineForOp, dstAffineForOp, + depthSliceUnions[bestDstLoopDepth - 1]); + + auto dstForInst = cast<AffineForOp>(dstNode->op); + // Update operation position of fused loop nest (if needed). + if (insertPointInst != dstForInst.getOperation()) { + dstForInst.getOperation()->moveBefore(insertPointInst); } + // Update data dependence graph state post fusion. + updateStateAfterSiblingFusion(sibNode, dstNode); } } @@ -1943,19 +1863,12 @@ public: return false; } - void updateStateAfterSiblingFusion(AffineForOp sliceLoopNest, Node *sibNode, - Node *dstNode) { + /// Update data dependence graph state to reflect sibling fusion of 'sibNode' + /// into 'dstNode'. + void updateStateAfterSiblingFusion(Node *sibNode, Node *dstNode) { // Update 'sibNode' and 'dstNode' input/output edges to reflect fusion. mdg->updateEdges(sibNode->id, dstNode->id); - // Collect slice loop stats. - LoopNestStateCollector sliceCollector; - sliceCollector.collect(sliceLoopNest.getOperation()); - // Promote single iteration slice loops to single IV value. - for (auto forOp : sliceCollector.forOps) { - promoteIfSingleIteration(forOp); - } - // Collect dst loop stats after memref privatization transformation. auto dstForInst = cast<AffineForOp>(dstNode->op); LoopNestStateCollector dstLoopCollector; |