diff options
Diffstat (limited to 'mlir/lib/Transforms/LoopFusion.cpp')
-rw-r--r-- | mlir/lib/Transforms/LoopFusion.cpp | 15 |
1 files changed, 8 insertions, 7 deletions
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 49bd52d..955230d 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -70,19 +70,20 @@ mlir::createLoopFusionPass(unsigned fastMemorySpace, namespace { // LoopNestStateCollector walks loop nests and collects load and store -// operations, and whether or not an IfInst was encountered in the loop nest. +// operations, and whether or not a region holding op other than ForOp and IfOp +// was encountered in the loop nest. struct LoopNestStateCollector { SmallVector<AffineForOp, 4> forOps; SmallVector<Operation *, 4> loadOpInsts; SmallVector<Operation *, 4> storeOpInsts; - bool hasNonForRegion = false; + bool hasNonAffineRegionOp = false; void collect(Operation *opToWalk) { opToWalk->walk([&](Operation *op) { if (isa<AffineForOp>(op)) forOps.push_back(cast<AffineForOp>(op)); - else if (op->getNumRegions() != 0) - hasNonForRegion = true; + else if (op->getNumRegions() != 0 && !isa<AffineIfOp>(op)) + hasNonAffineRegionOp = true; else if (isa<AffineReadOpInterface>(op)) loadOpInsts.push_back(op); else if (isa<AffineWriteOpInterface>(op)) @@ -744,9 +745,9 @@ bool MemRefDependenceGraph::init(FuncOp f) { // all loads and store accesses it contains. LoopNestStateCollector collector; collector.collect(&op); - // Return false if a non 'affine.for' region was found (not currently - // supported). - if (collector.hasNonForRegion) + // Return false if a region holding op other than 'affine.for' and + // 'affine.if' was found (not currently supported). + if (collector.hasNonAffineRegionOp) return false; Node node(nextNodeId++, &op); for (auto *opInst : collector.loadOpInsts) { |