aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Transforms/LoopFusion.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Transforms/LoopFusion.cpp')
-rw-r--r--mlir/lib/Transforms/LoopFusion.cpp15
1 files changed, 8 insertions, 7 deletions
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 49bd52d..955230d 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -70,19 +70,20 @@ mlir::createLoopFusionPass(unsigned fastMemorySpace,
namespace {
// LoopNestStateCollector walks loop nests and collects load and store
-// operations, and whether or not an IfInst was encountered in the loop nest.
+// operations, and whether or not a region holding op other than ForOp and IfOp
+// was encountered in the loop nest.
struct LoopNestStateCollector {
SmallVector<AffineForOp, 4> forOps;
SmallVector<Operation *, 4> loadOpInsts;
SmallVector<Operation *, 4> storeOpInsts;
- bool hasNonForRegion = false;
+ bool hasNonAffineRegionOp = false;
void collect(Operation *opToWalk) {
opToWalk->walk([&](Operation *op) {
if (isa<AffineForOp>(op))
forOps.push_back(cast<AffineForOp>(op));
- else if (op->getNumRegions() != 0)
- hasNonForRegion = true;
+ else if (op->getNumRegions() != 0 && !isa<AffineIfOp>(op))
+ hasNonAffineRegionOp = true;
else if (isa<AffineReadOpInterface>(op))
loadOpInsts.push_back(op);
else if (isa<AffineWriteOpInterface>(op))
@@ -744,9 +745,9 @@ bool MemRefDependenceGraph::init(FuncOp f) {
// all loads and store accesses it contains.
LoopNestStateCollector collector;
collector.collect(&op);
- // Return false if a non 'affine.for' region was found (not currently
- // supported).
- if (collector.hasNonForRegion)
+ // Return false if a region holding op other than 'affine.for' and
+ // 'affine.if' was found (not currently supported).
+ if (collector.hasNonAffineRegionOp)
return false;
Node node(nextNodeId++, &op);
for (auto *opInst : collector.loadOpInsts) {