aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Transforms/LoopFusion.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Transforms/LoopFusion.cpp')
-rw-r--r--mlir/lib/Transforms/LoopFusion.cpp519
1 files changed, 216 insertions, 303 deletions
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index ed79be0..6716260 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -741,77 +741,6 @@ static void moveLoadsAccessingMemrefTo(Value memref,
srcLoads->swap(srcLoadsToKeep);
}
-// Returns the innermost common loop depth for the set of operations in 'ops'.
-static unsigned getInnermostCommonLoopDepth(ArrayRef<Operation *> ops) {
- unsigned numOps = ops.size();
- assert(numOps > 0);
-
- std::vector<SmallVector<AffineForOp, 4>> loops(numOps);
- unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
- for (unsigned i = 0; i < numOps; ++i) {
- getLoopIVs(*ops[i], &loops[i]);
- loopDepthLimit =
- std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size()));
- }
-
- unsigned loopDepth = 0;
- for (unsigned d = 0; d < loopDepthLimit; ++d) {
- unsigned i;
- for (i = 1; i < numOps; ++i) {
- if (loops[i - 1][d] != loops[i][d])
- break;
- }
- if (i != numOps)
- break;
- ++loopDepth;
- }
- return loopDepth;
-}
-
-// Returns the maximum loop depth at which no dependences between 'loadOpInsts'
-// and 'storeOpInsts' are satisfied.
-static unsigned getMaxLoopDepth(ArrayRef<Operation *> loadOpInsts,
- ArrayRef<Operation *> storeOpInsts) {
- // Merge loads and stores into the same array.
- SmallVector<Operation *, 2> ops(loadOpInsts.begin(), loadOpInsts.end());
- ops.append(storeOpInsts.begin(), storeOpInsts.end());
-
- // Compute the innermost common loop depth for loads and stores.
- unsigned loopDepth = getInnermostCommonLoopDepth(ops);
-
- // Return common loop depth for loads if there are no store ops.
- if (storeOpInsts.empty())
- return loopDepth;
-
- // Check dependences on all pairs of ops in 'ops' and store the minimum
- // loop depth at which a dependence is satisfied.
- for (unsigned i = 0, e = ops.size(); i < e; ++i) {
- auto *srcOpInst = ops[i];
- MemRefAccess srcAccess(srcOpInst);
- for (unsigned j = 0; j < e; ++j) {
- auto *dstOpInst = ops[j];
- MemRefAccess dstAccess(dstOpInst);
-
- unsigned numCommonLoops =
- getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst);
- for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
- FlatAffineConstraints dependenceConstraints;
- // TODO: Cache dependence analysis results, check cache here.
- DependenceResult result = checkMemrefAccessDependence(
- srcAccess, dstAccess, d, &dependenceConstraints,
- /*dependenceComponents=*/nullptr);
- if (hasDependence(result)) {
- // Store minimum loop depth and break because we want the min 'd' at
- // which there is a dependence.
- loopDepth = std::min(loopDepth, d - 1);
- break;
- }
- }
- }
- }
- return loopDepth;
-}
-
// Sinks all sequential loops to the innermost levels (while preserving
// relative order among them) and moves all parallel loops to the
// outermost (while again preserving relative order among them).
@@ -1077,14 +1006,16 @@ canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
// The argument 'srcStoreOpInst' is used to calculate the storage reduction on
// the memref being produced and consumed, which is an input to the cost model.
// For producer-consumer fusion, 'srcStoreOpInst' will be the same as
-// 'srcOpInst', as we are slicing w.r.t to that producer.
-// For input-reuse fusion, 'srcOpInst' will be the src loop nest LoadOp which
-// reads from the same memref as dst loop nest load ops, and 'srcStoreOpInst'
-// will be the unique store op in the src node, which will be used to check
-// that the write region is the same after input-reuse fusion.
-// Returns true if it is profitable to fuse the candidate loop nests. Returns
-// false otherwise. `dstLoopDepth` is set to the most profitable depth at which
-// to materialize the source loop nest slice.
+// 'srcOpInst', as we are slicing w.r.t. that producer. For input-reuse
+// fusion, 'srcOpInst' will be the src loop nest LoadOp which reads from the
+// same memref as dst loop nest load ops, and 'srcStoreOpInst' will be the
+// unique store op in the src node, which will be used to check that the write
+// region is the same after input-reuse fusion. Computation slices are provided
+// in 'depthSliceUnions' for each legal fusion depth. The maximal depth at which
+// fusion is legal is provided in 'maxLegalFusionDepth'. Returns true if it is
+// profitable to fuse the candidate loop nests. Returns false otherwise.
+// `dstLoopDepth` is set to the most profitable depth at which to materialize
+// the source loop nest slice.
// The profitability model executes the following steps:
// *) Computes the backward computation slice at 'srcOpInst'. This
// computation slice of the loop nest surrounding 'srcOpInst' is
@@ -1112,9 +1043,9 @@ canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
// is lower.
static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
ArrayRef<Operation *> dstLoadOpInsts,
- ArrayRef<Operation *> dstStoreOpInsts,
- ComputationSliceState *sliceState,
- unsigned *dstLoopDepth, bool maximalFusion,
+ ArrayRef<ComputationSliceState> depthSliceUnions,
+ unsigned maxLegalFusionDepth,
+ unsigned *dstLoopDepth,
double computeToleranceThreshold) {
LLVM_DEBUG({
llvm::dbgs() << "Checking whether fusion is profitable between src op:\n";
@@ -1124,10 +1055,14 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
};
});
+ if (maxLegalFusionDepth == 0) {
+ LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxLegalFusionDepth == 0 .\n");
+ return false;
+ }
+
// Compute cost of sliced and unsliced src loop nest.
SmallVector<AffineForOp, 4> srcLoopIVs;
getLoopIVs(*srcOpInst, &srcLoopIVs);
- unsigned numSrcLoopIVs = srcLoopIVs.size();
// Walk src loop nest and collect stats.
LoopNestStats srcLoopNestStats;
@@ -1142,19 +1077,8 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
if (!getLoopNestStats(dstLoopIVs[0], &dstLoopNestStats))
return false;
- // Compute the maximum loop depth at which we can can insert the src slice
- // and still satisfy dest loop nest dependences, for producer-consumer fusion.
- unsigned maxDstLoopDepth =
- (srcOpInst == srcStoreOpInst)
- ? getMaxLoopDepth(dstLoadOpInsts, dstStoreOpInsts)
- : dstLoopIVs.size();
- if (maxDstLoopDepth == 0) {
- LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxDstLoopDepth == 0 .\n");
- return false;
- }
-
// Search for min cost value for 'dstLoopDepth'. At each value of
- // 'dstLoopDepth' from 'maxDstLoopDepth' to '1', compute computation slice
+ // 'dstLoopDepth' from 'maxLegalFusionDepth' to '1', compute computation slice
// bounds between 'srcOpInst' and each op in 'dstOpinsts' (taking the union
// of these bounds). Next the union slice bounds are used to calculate
// the cost of the slice and the cost of the slice inserted into the dst
@@ -1163,8 +1087,6 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
double maxStorageReduction = 0.0;
Optional<uint64_t> sliceMemEstimate = None;
- SmallVector<ComputationSliceState, 4> sliceStates;
- sliceStates.resize(maxDstLoopDepth);
// The best loop depth at which to materialize the slice.
Optional<unsigned> bestDstLoopDepth = None;
@@ -1190,21 +1112,14 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
// Evaluate all depth choices for materializing the slice in the destination
// loop nest.
- for (unsigned i = maxDstLoopDepth; i >= 1; --i) {
- // Compute the union of slice bounds of all ops in 'dstLoadOpInsts'.
- if (failed(mlir::computeSliceUnion({srcOpInst}, dstLoadOpInsts,
- /*loopDepth=*/i,
- /*numCommonLoops=*/0,
- /*isBackwardSlice=*/true,
- &sliceStates[i - 1]))) {
- LLVM_DEBUG(llvm::dbgs()
- << "computeSliceUnion failed for loopDepth: " << i << "\n");
+ for (unsigned i = maxLegalFusionDepth; i >= 1; --i) {
+ // Skip slice union if it wasn't computed for this depth.
+ if (depthSliceUnions[i - 1].isEmpty())
continue;
- }
int64_t fusedLoopNestComputeCost;
if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstLoopIVs[0],
- dstLoopNestStats, &sliceStates[i - 1],
+ dstLoopNestStats, depthSliceUnions[i - 1],
&fusedLoopNestComputeCost)) {
LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost.\n.");
continue;
@@ -1216,11 +1131,11 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
1;
// Determine what the slice write MemRefRegion would be, if the src loop
- // nest slice 'sliceStates[i - 1]' were to be inserted into the dst loop
- // nest at loop depth 'i'
+ // nest slice 'depthSliceUnions[i - 1]' were to be inserted into the dst
+ // loop nest at loop depth 'i'.
MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc());
if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0,
- &sliceStates[i - 1]))) {
+ &depthSliceUnions[i - 1]))) {
LLVM_DEBUG(llvm::dbgs()
<< "Failed to compute slice write region at loopDepth: " << i
<< "\n");
@@ -1269,8 +1184,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
// (as per computeToleranceThreshold), we will simply pick the one that
// reduces the intermediary size the most.
if ((storageReduction > maxStorageReduction) &&
- (maximalFusion ||
- (additionalComputeFraction < computeToleranceThreshold))) {
+ (additionalComputeFraction < computeToleranceThreshold)) {
maxStorageReduction = storageReduction;
bestDstLoopDepth = i;
minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
@@ -1278,10 +1192,9 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
}
}
- // A simple cost model: fuse if it reduces the memory footprint. If
- // -maximal-fusion is set, fuse nevertheless.
+ // A simple cost model: fuse if it reduces the memory footprint.
- if (!maximalFusion && !bestDstLoopDepth.hasValue()) {
+ if (!bestDstLoopDepth.hasValue()) {
LLVM_DEBUG(
llvm::dbgs()
<< "All fusion choices involve more than the threshold amount of "
@@ -1310,33 +1223,30 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
Optional<double> storageReduction = None;
- if (!maximalFusion) {
- if (!dstMemSize.hasValue() || !srcMemSize.hasValue()) {
- LLVM_DEBUG(
- llvm::dbgs()
- << " fusion memory benefit cannot be evaluated; NOT fusing.\n");
- return false;
- }
+ if (!dstMemSize.hasValue() || !srcMemSize.hasValue()) {
+ LLVM_DEBUG(llvm::dbgs()
+ << " fusion memory benefit cannot be evaluated; NOT fusing.\n");
+ return false;
+ }
- auto srcMemSizeVal = srcMemSize.getValue();
- auto dstMemSizeVal = dstMemSize.getValue();
+ auto srcMemSizeVal = srcMemSize.getValue();
+ auto dstMemSizeVal = dstMemSize.getValue();
- assert(sliceMemEstimate.hasValue() && "expected value");
- auto fusedMem = dstMemSizeVal + sliceMemEstimate.getValue();
+ assert(sliceMemEstimate.hasValue() && "expected value");
+ auto fusedMem = dstMemSizeVal + sliceMemEstimate.getValue();
- LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n"
- << " dst mem: " << dstMemSizeVal << "\n"
- << " fused mem: " << fusedMem << "\n"
- << " slice mem: " << sliceMemEstimate << "\n");
+ LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n"
+ << " dst mem: " << dstMemSizeVal << "\n"
+ << " fused mem: " << fusedMem << "\n"
+ << " slice mem: " << sliceMemEstimate << "\n");
- if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
- LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
- return false;
- }
- storageReduction =
- 100.0 *
- (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
+ if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
+ LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
+ return false;
}
+ storageReduction =
+ 100.0 *
+ (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
double additionalComputeFraction =
100.0 * (minFusedLoopNestComputeCost /
@@ -1355,24 +1265,6 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
llvm::dbgs() << msg.str();
});
- // Update return parameter 'sliceState' with 'bestSliceState'.
- ComputationSliceState *bestSliceState = &sliceStates[*dstLoopDepth - 1];
- sliceState->lbs = bestSliceState->lbs;
- sliceState->ubs = bestSliceState->ubs;
- sliceState->lbOperands = bestSliceState->lbOperands;
- sliceState->ubOperands = bestSliceState->ubOperands;
-
- // Canonicalize slice bound affine maps.
- for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
- if (sliceState->lbs[i] != AffineMap()) {
- canonicalizeMapAndOperands(&sliceState->lbs[i],
- &sliceState->lbOperands[i]);
- }
- if (sliceState->ubs[i] != AffineMap()) {
- canonicalizeMapAndOperands(&sliceState->ubs[i],
- &sliceState->ubOperands[i]);
- }
- }
return true;
}
@@ -1592,138 +1484,142 @@ public:
if (insertPointInst == nullptr)
continue;
+ auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
+ auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
+
// Compute the innermost common loop depth for dstNode loads/stores.
- SmallVector<Operation *, 2> dstOps(dstNode->loads.begin(),
- dstNode->loads.end());
- dstOps.append(dstNode->stores.begin(), dstNode->stores.end());
- unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstOps);
+ SmallVector<Operation *, 2> dstMemrefOps;
+ for (Operation *op : dstNode->loads)
+ if (cast<AffineReadOpInterface>(op).getMemRef() == memref)
+ dstMemrefOps.push_back(op);
+ for (Operation *op : dstNode->stores)
+ if (cast<AffineWriteOpInterface>(op).getMemRef() == memref)
+ dstMemrefOps.push_back(op);
+ unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstMemrefOps);
+
// Check the feasibility of fusing src loop nest into dst loop nest
// at loop depths in range [1, dstLoopDepthTest].
- // TODO: Use slice union computation and union of memref
- // read/write regions to cost model and fusion.
- bool canFuse = false;
+ unsigned maxLegalFusionDepth = 0;
+ SmallVector<ComputationSliceState, 8> depthSliceUnions;
+ depthSliceUnions.resize(dstLoopDepthTest);
+ FusionStrategy strategy(FusionStrategy::ProducerConsumer, memref);
for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
- ComputationSliceState sliceUnion;
FusionResult result = mlir::canFuseLoops(
- cast<AffineForOp>(srcNode->op), cast<AffineForOp>(dstNode->op),
- /*dstLoopDepth=*/i, &sliceUnion);
+ srcAffineForOp, dstAffineForOp,
+ /*dstLoopDepth=*/i, &depthSliceUnions[i - 1], strategy);
+
if (result.value == FusionResult::Success)
- canFuse = true;
+ maxLegalFusionDepth = i;
}
- // Skip if fusion is not feasible at all loop depths.
- if (!canFuse)
+ // Skip if fusion is not feasible at any loop depth.
+ if (maxLegalFusionDepth == 0)
continue;
- // Gather 'dstNode' store ops to 'memref'.
- SmallVector<Operation *, 2> dstStoreOpInsts;
- for (auto *storeOpInst : dstNode->stores)
- if (cast<AffineWriteOpInterface>(storeOpInst).getMemRef() == memref)
- dstStoreOpInsts.push_back(storeOpInst);
-
- unsigned bestDstLoopDepth;
- mlir::ComputationSliceState sliceState;
- // Check if fusion would be profitable.
- if (!isFusionProfitable(srcStoreOp, srcStoreOp, dstLoadOpInsts,
- dstStoreOpInsts, &sliceState,
- &bestDstLoopDepth, maximalFusion,
- computeToleranceThreshold))
+ // Check if fusion would be profitable. We skip profitability analysis
+ // for maximal fusion since we already know the maximal legal depth to
+ // fuse.
+ unsigned bestDstLoopDepth = maxLegalFusionDepth;
+ if (!maximalFusion &&
+ !isFusionProfitable(srcStoreOp, srcStoreOp, dstLoadOpInsts,
+ depthSliceUnions, maxLegalFusionDepth,
+ &bestDstLoopDepth, computeToleranceThreshold))
continue;
+ assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
+ assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() &&
+ "Missing slice union for depth");
+
// Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
- auto sliceLoopNest = mlir::insertBackwardComputationSlice(
- srcStoreOp, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
- if (sliceLoopNest) {
- LLVM_DEBUG(llvm::dbgs() << "\tslice loop nest:\n"
- << *sliceLoopNest.getOperation() << "\n");
- // Move 'dstAffineForOp' before 'insertPointInst' if needed.
- auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
- if (insertPointInst != dstAffineForOp.getOperation()) {
- dstAffineForOp.getOperation()->moveBefore(insertPointInst);
- }
- // Update edges between 'srcNode' and 'dstNode'.
- mdg->updateEdges(srcNode->id, dstNode->id, memref,
- createPrivateMemref);
-
- // Collect slice loop stats.
- LoopNestStateCollector sliceCollector;
- sliceCollector.collect(sliceLoopNest.getOperation());
- // Promote single iteration slice loops to single IV value.
- for (auto forOp : sliceCollector.forOps) {
- promoteIfSingleIteration(forOp);
- }
- if (createPrivateMemref) {
- // Create private memref for 'memref' in 'dstAffineForOp'.
- SmallVector<Operation *, 4> storesForMemref;
- for (auto *storeOpInst : sliceCollector.storeOpInsts) {
- if (cast<AffineWriteOpInterface>(storeOpInst).getMemRef() ==
- memref)
- storesForMemref.push_back(storeOpInst);
- }
- // TODO: Use union of memref write regions to compute
- // private memref footprint.
- auto newMemRef = createPrivateMemRef(
- dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
- fastMemorySpace, localBufSizeThreshold);
- visitedMemrefs.insert(newMemRef);
- // Create new node in dependence graph for 'newMemRef' alloc op.
- unsigned newMemRefNodeId =
- mdg->addNode(newMemRef.getDefiningOp());
- // Add edge from 'newMemRef' node to dstNode.
- mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
+ fuseLoops(srcAffineForOp, dstAffineForOp,
+ depthSliceUnions[bestDstLoopDepth - 1]);
+
+ LLVM_DEBUG(llvm::dbgs()
+ << "Fused src loop " << srcId << " into dst loop " << dstId
+ << " at depth " << bestDstLoopDepth << ":\n"
+ << dstAffineForOp << "\n");
+
+ // Move 'dstAffineForOp' before 'insertPointInst' if needed.
+ if (insertPointInst != dstAffineForOp.getOperation())
+ dstAffineForOp.getOperation()->moveBefore(insertPointInst);
+
+ // Update edges between 'srcNode' and 'dstNode'.
+ mdg->updateEdges(srcNode->id, dstNode->id, memref,
+ createPrivateMemref);
+
+ // Collect the ops in the fused 'dstAffineForOp' loop nest.
+ LoopNestStateCollector dstForCollector;
+ dstForCollector.collect(dstAffineForOp);
+ if (createPrivateMemref) {
+ // Create private memref for 'memref' in 'dstAffineForOp'.
+ SmallVector<Operation *, 4> storesForMemref;
+ for (auto *storeOpInst : dstForCollector.storeOpInsts) {
+ if (cast<AffineWriteOpInterface>(storeOpInst).getMemRef() ==
+ memref)
+ storesForMemref.push_back(storeOpInst);
}
+ // TODO: Use union of memref write regions to compute
+ // private memref footprint.
+ auto newMemRef = createPrivateMemRef(
+ dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
+ fastMemorySpace, localBufSizeThreshold);
+ visitedMemrefs.insert(newMemRef);
+ // Create new node in dependence graph for 'newMemRef' alloc op.
+ unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
+ // Add edge from 'newMemRef' node to dstNode.
+ mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
+ }
- // Collect dst loop stats after memref privatization transformation.
- LoopNestStateCollector dstLoopCollector;
- dstLoopCollector.collect(dstAffineForOp.getOperation());
-
- // Add new load ops to current Node load op list 'loads' to
- // continue fusing based on new operands.
- for (auto *loadOpInst : dstLoopCollector.loadOpInsts) {
- // NOTE: Change 'loads' to a hash set in case efficiency is an
- // issue. We still use a vector since it's expected to be small.
- if (!llvm::is_contained(loads, loadOpInst))
- loads.push_back(loadOpInst);
- }
- // Clear visited memrefs after fusion so that previously visited src
- // nodes are considered for fusion again in the context of the new
- // fused node.
- // TODO: This shouldn't be necessary if we visited candidates in the
- // dependence graph in post-order or once we fully support
- // multi-store producers. Currently, in a multi-store producer
- // scenario such as A->B, A->C, B->C, we fail to fuse A+B due to the
- // multiple outgoing edges. However, after fusing B+C, A has a
- // single outgoing edge and can be fused if we revisit it in the
- // context of the new fused B+C node.
- visitedMemrefs.clear();
-
- // Clear and add back loads and stores.
- mdg->clearNodeLoadAndStores(dstNode->id);
- mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
- dstLoopCollector.storeOpInsts);
- // Remove old src loop nest if it no longer has outgoing dependence
- // edges, and if it does not write to a memref which escapes the
- // function. If 'writesToLiveInOrOut' is true, then 'srcNode' has
- // been fused into 'dstNode' and write region of 'dstNode' covers
- // the write region of 'srcNode', and 'srcNode' has no other users
- // so it is safe to remove.
- if (writesToLiveInOrOut || mdg->canRemoveNode(srcNode->id)) {
- mdg->removeNode(srcNode->id);
- srcNode->op->erase();
- } else {
- // Add remaining users of 'oldMemRef' back on the worklist (if not
- // already there), as its replacement with a local/private memref
- // has reduced dependences on 'oldMemRef' which may have created
- // new fusion opportunities.
- if (mdg->outEdges.count(srcNode->id) > 0) {
- SmallVector<MemRefDependenceGraph::Edge, 2> oldOutEdges =
- mdg->outEdges[srcNode->id];
- for (auto &outEdge : oldOutEdges) {
- if (outEdge.value == memref &&
- worklistSet.count(outEdge.id) == 0) {
- worklist.push_back(outEdge.id);
- worklistSet.insert(outEdge.id);
- }
+ // Collect dst loop stats after memref privatization transformation.
+ LoopNestStateCollector dstLoopCollector;
+ dstLoopCollector.collect(dstAffineForOp.getOperation());
+
+ // Add new load ops to current Node load op list 'loads' to continue
+ // fusing based on new operands.
+ for (auto *loadOpInst : dstLoopCollector.loadOpInsts) {
+ // NOTE: Change 'loads' to a hash set in case efficiency is an
+ // issue. We still use a vector since it's expected to be small.
+ if (!llvm::is_contained(loads, loadOpInst))
+ loads.push_back(loadOpInst);
+ }
+ // Clear visited memrefs after fusion so that previously visited src
+ // nodes are considered for fusion again in the context of the new
+ // fused node.
+ // TODO: This shouldn't be necessary if we visited candidates in the
+ // dependence graph in post-order or once we fully support multi-store
+ // producers. Currently, in a multi-store producer scenario such as
+ // A->B, A->C, B->C, we fail to fuse A+B due to the multiple outgoing
+ // edges. However, after fusing B+C, A has a single outgoing edge and
+ // can be fused if we revisit it in the context of the new fused B+C
+ // node.
+ visitedMemrefs.clear();
+
+ // Clear and add back loads and stores.
+ mdg->clearNodeLoadAndStores(dstNode->id);
+ mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
+ dstLoopCollector.storeOpInsts);
+ // Remove old src loop nest if it no longer has outgoing dependence
+ // edges, and if it does not write to a memref which escapes the
+ // function. If 'writesToLiveInOrOut' is true, then 'srcNode' has been
+ // fused into 'dstNode' and write region of 'dstNode' covers the write
+ // region of 'srcNode', and 'srcNode' has no other users so it is safe
+ // to remove.
+ if (writesToLiveInOrOut || mdg->canRemoveNode(srcNode->id)) {
+ mdg->removeNode(srcNode->id);
+ srcNode->op->erase();
+ } else {
+ // Add remaining users of 'oldMemRef' back on the worklist (if not
+ // already there), as its replacement with a local/private memref
+ // has reduced dependences on 'oldMemRef' which may have created new
+ // fusion opportunities.
+ if (mdg->outEdges.count(srcNode->id) > 0) {
+ SmallVector<MemRefDependenceGraph::Edge, 2> oldOutEdges =
+ mdg->outEdges[srcNode->id];
+ for (auto &outEdge : oldOutEdges) {
+ if (outEdge.value == memref &&
+ worklistSet.count(outEdge.id) == 0) {
+ worklist.push_back(outEdge.id);
+ worklistSet.insert(outEdge.id);
}
}
}
@@ -1759,6 +1655,8 @@ public:
void fuseWithSiblingNodes(Node *dstNode) {
DenseSet<unsigned> visitedSibNodeIds;
std::pair<unsigned, Value> idAndMemref;
+ auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
+
while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
unsigned sibId = idAndMemref.first;
Value memref = idAndMemref.second;
@@ -1791,31 +1689,53 @@ public:
SmallVector<Operation *, 2> dstLoadOpInsts;
dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);
- // Gather 'dstNode' store ops to 'memref'.
- SmallVector<Operation *, 2> dstStoreOpInsts;
- dstNode->getStoreOpsForMemref(memref, &dstStoreOpInsts);
-
- unsigned bestDstLoopDepth;
- mlir::ComputationSliceState sliceState;
+ SmallVector<AffineForOp, 4> dstLoopIVs;
+ getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs);
+ unsigned dstLoopDepthTest = dstLoopIVs.size();
+ auto sibAffineForOp = cast<AffineForOp>(sibNode->op);
+
+ // Compute loop depth and slice union for fusion.
+ SmallVector<ComputationSliceState, 8> depthSliceUnions;
+ depthSliceUnions.resize(dstLoopDepthTest);
+ unsigned maxLegalFusionDepth = 0;
+ FusionStrategy strategy(FusionStrategy::Sibling, memref);
+ for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
+ FusionResult result = mlir::canFuseLoops(
+ sibAffineForOp, dstAffineForOp,
+ /*dstLoopDepth=*/i, &depthSliceUnions[i - 1], strategy);
+
+ if (result.value == FusionResult::Success)
+ maxLegalFusionDepth = i;
+ }
- // Check if fusion would be profitable.
- if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstLoadOpInsts,
- dstStoreOpInsts, &sliceState, &bestDstLoopDepth,
- maximalFusion, computeToleranceThreshold))
+ // Skip if fusion is not feasible at any loop depth.
+ if (maxLegalFusionDepth == 0)
continue;
+ unsigned bestDstLoopDepth = dstLoopDepthTest;
+ if (!maximalFusion) {
+ // Check if fusion would be profitable.
+ if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstLoadOpInsts,
+ depthSliceUnions, maxLegalFusionDepth,
+ &bestDstLoopDepth, computeToleranceThreshold))
+ continue;
+ }
+
+ assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
+ assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() &&
+ "Fusion depth has no computed slice union");
+
// Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'.
- auto sliceLoopNest = mlir::insertBackwardComputationSlice(
- sibLoadOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
- if (sliceLoopNest != nullptr) {
- auto dstForInst = cast<AffineForOp>(dstNode->op);
- // Update operation position of fused loop nest (if needed).
- if (insertPointInst != dstForInst.getOperation()) {
- dstForInst.getOperation()->moveBefore(insertPointInst);
- }
- // Update data dependence graph state post fusion.
- updateStateAfterSiblingFusion(sliceLoopNest, sibNode, dstNode);
+ mlir::fuseLoops(sibAffineForOp, dstAffineForOp,
+ depthSliceUnions[bestDstLoopDepth - 1]);
+
+ auto dstForInst = cast<AffineForOp>(dstNode->op);
+ // Update operation position of fused loop nest (if needed).
+ if (insertPointInst != dstForInst.getOperation()) {
+ dstForInst.getOperation()->moveBefore(insertPointInst);
}
+ // Update data dependence graph state post fusion.
+ updateStateAfterSiblingFusion(sibNode, dstNode);
}
}
@@ -1943,19 +1863,12 @@ public:
return false;
}
- void updateStateAfterSiblingFusion(AffineForOp sliceLoopNest, Node *sibNode,
- Node *dstNode) {
+ /// Update data dependence graph state to reflect sibling fusion of 'sibNode'
+ /// into 'dstNode'.
+ void updateStateAfterSiblingFusion(Node *sibNode, Node *dstNode) {
// Update 'sibNode' and 'dstNode' input/output edges to reflect fusion.
mdg->updateEdges(sibNode->id, dstNode->id);
- // Collect slice loop stats.
- LoopNestStateCollector sliceCollector;
- sliceCollector.collect(sliceLoopNest.getOperation());
- // Promote single iteration slice loops to single IV value.
- for (auto forOp : sliceCollector.forOps) {
- promoteIfSingleIteration(forOp);
- }
-
// Collect dst loop stats after memref privatization transformation.
auto dstForInst = cast<AffineForOp>(dstNode->op);
LoopNestStateCollector dstLoopCollector;