aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Analysis/Utils.cpp5
-rw-r--r--mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp4
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp4
-rw-r--r--mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopFusion.cpp8
-rw-r--r--mlir/lib/Interfaces/SideEffects.cpp67
-rw-r--r--mlir/lib/TableGen/SideEffects.cpp12
-rw-r--r--mlir/lib/Transforms/CSE.cpp16
-rw-r--r--mlir/lib/Transforms/LoopInvariantCodeMotion.cpp16
-rw-r--r--mlir/lib/Transforms/Utils/FoldUtils.cpp6
-rw-r--r--mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp15
-rw-r--r--mlir/lib/Transforms/Utils/LoopUtils.cpp2
-rw-r--r--mlir/lib/Transforms/Utils/RegionUtils.cpp6
12 files changed, 119 insertions, 42 deletions
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index 14635a1..7b3cd58 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -974,11 +974,12 @@ void mlir::getSequentialLoops(AffineForOp forOp,
bool mlir::isLoopParallel(AffineForOp forOp) {
// Collect all load and store ops in loop nest rooted at 'forOp'.
SmallVector<Operation *, 8> loadAndStoreOpInsts;
- auto walkResult = forOp.walk([&](Operation *opInst) {
+ auto walkResult = forOp.walk([&](Operation *opInst) -> WalkResult {
if (isa<AffineLoadOp>(opInst) || isa<AffineStoreOp>(opInst))
loadAndStoreOpInsts.push_back(opInst);
else if (!isa<AffineForOp>(opInst) && !isa<AffineTerminatorOp>(opInst) &&
- !isa<AffineIfOp>(opInst) && !opInst->hasNoSideEffect())
+ !isa<AffineIfOp>(opInst) &&
+ !MemoryEffectOpInterface::hasNoEffect(opInst))
return WalkResult::interrupt();
return WalkResult::advance();
diff --git a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp
index 293d935..9ba3f40 100644
--- a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp
+++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp
@@ -797,7 +797,9 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
Operation *clone = rewriter.clone(*op, cloningMap);
cloningMap.map(op->getResults(), clone->getResults());
// Check for side effects.
- seenSideeffects |= !clone->hasNoSideEffect();
+ // TODO: Handle region side effects properly.
+ seenSideeffects |= !MemoryEffectOpInterface::hasNoEffect(clone) ||
+ clone->getNumRegions() != 0;
// If we are no longer in the innermost scope, sideeffects are disallowed.
if (seenSideeffects && leftNestingScope)
return matchFailure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index de29848..3274abd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -491,9 +491,7 @@ static void tileLinalgOps(FuncOp f, ArrayRef<int64_t> tileSizes) {
op.erase();
});
f.walk([](LinalgOp op) {
- if (!op.getOperation()->hasNoSideEffect())
- return;
- if (op.getOperation()->use_empty())
+ if (isOpTriviallyDead(op))
op.erase();
});
}
diff --git a/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopFusion.cpp
index 98f8313..b84cfa5 100644
--- a/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopFusion.cpp
@@ -132,13 +132,13 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
void mlir::loop::naivelyFuseParallelOps(Region &region) {
OpBuilder b(region);
// Consider every single block and attempt to fuse adjacent loops.
- for (auto &block : region.getBlocks()) {
+ for (auto &block : region) {
SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains{{}};
// Not using `walk()` to traverse only top-level parallel loops and also
// make sure that there are no side-effecting ops between the parallel
// loops.
bool noSideEffects = true;
- for (auto &op : block.getOperations()) {
+ for (auto &op : block) {
if (auto ploop = dyn_cast<ParallelOp>(op)) {
if (noSideEffects) {
ploopChains.back().push_back(ploop);
@@ -148,7 +148,9 @@ void mlir::loop::naivelyFuseParallelOps(Region &region) {
}
continue;
}
- noSideEffects &= op.hasNoSideEffect();
+ // TODO: Handle region side effects properly.
+ noSideEffects &=
+ MemoryEffectOpInterface::hasNoEffect(&op) && op.getNumRegions() == 0;
}
for (ArrayRef<ParallelOp> ploops : ploopChains) {
for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
diff --git a/mlir/lib/Interfaces/SideEffects.cpp b/mlir/lib/Interfaces/SideEffects.cpp
index da43239..53406c6 100644
--- a/mlir/lib/Interfaces/SideEffects.cpp
+++ b/mlir/lib/Interfaces/SideEffects.cpp
@@ -25,3 +25,70 @@ bool MemoryEffects::Effect::classof(const SideEffects::Effect *effect) {
return isa<Allocate>(effect) || isa<Free>(effect) || isa<Read>(effect) ||
isa<Write>(effect);
}
+
+//===----------------------------------------------------------------------===//
+// SideEffect Utilities
+//===----------------------------------------------------------------------===//
+
+bool mlir::isOpTriviallyDead(Operation *op) {
+ return op->use_empty() && wouldOpBeTriviallyDead(op);
+}
+
+/// Internal implementation of `mlir::wouldOpBeTriviallyDead` that also
+/// considers terminator operations as dead if they have no side effects. This
+/// allows for marking region operations as trivially dead without always being
+/// conservative of terminators.
+static bool wouldOpBeTriviallyDeadImpl(Operation *rootOp) {
+ // The set of operations to consider when checking for side effects.
+ SmallVector<Operation *, 1> effectingOps(1, rootOp);
+ while (!effectingOps.empty()) {
+ Operation *op = effectingOps.pop_back_val();
+
+ // If the operation has recursive effects, push all of the nested operations
+ // on to the stack to consider.
+ bool hasRecursiveEffects = op->hasTrait<OpTrait::HasRecursiveSideEffects>();
+ if (hasRecursiveEffects) {
+ for (Region &region : op->getRegions()) {
+ for (auto &block : region) {
+ for (auto &nestedOp : block)
+ effectingOps.push_back(&nestedOp);
+ }
+ }
+ }
+
+ // If the op has memory effects, try to characterize them to see if the op
+ // is trivially dead here.
+ if (auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
+ // Check to see if this op either has no effects, or only allocates/reads
+ // memory.
+ SmallVector<MemoryEffects::EffectInstance, 1> effects;
+ effectInterface.getEffects(effects);
+ if (!llvm::all_of(effects, [](const auto &it) {
+ return isa<MemoryEffects::Read>(it.getEffect()) ||
+ isa<MemoryEffects::Allocate>(it.getEffect());
+ })) {
+ return false;
+ }
+ continue;
+
+ // Otherwise, if the op has recursive side effects we can treat the
+ // operation itself as having no effects.
+ } else if (hasRecursiveEffects) {
+ continue;
+ }
+
+ // If there were no effect interfaces, we treat this op as conservatively
+ // having effects.
+ return false;
+ }
+
+ // If we get here, none of the operations had effects that prevented marking
+ // 'op' as dead.
+ return true;
+}
+
+bool mlir::wouldOpBeTriviallyDead(Operation *op) {
+ if (!op->isKnownNonTerminator())
+ return false;
+ return wouldOpBeTriviallyDeadImpl(op);
+}
diff --git a/mlir/lib/TableGen/SideEffects.cpp b/mlir/lib/TableGen/SideEffects.cpp
index 0b334b8..7fbeffa 100644
--- a/mlir/lib/TableGen/SideEffects.cpp
+++ b/mlir/lib/TableGen/SideEffects.cpp
@@ -20,12 +20,8 @@ StringRef SideEffect::getName() const {
return def->getValueAsString("effect");
}
-StringRef SideEffect::getBaseName() const {
- return def->getValueAsString("baseEffect");
-}
-
-StringRef SideEffect::getInterfaceTrait() const {
- return def->getValueAsString("interfaceTrait");
+StringRef SideEffect::getBaseEffectName() const {
+ return def->getValueAsString("baseEffectName");
}
StringRef SideEffect::getResource() const {
@@ -46,6 +42,10 @@ Operator::var_decorator_range SideEffectTrait::getEffects() const {
return {listInit->begin(), listInit->end()};
}
+StringRef SideEffectTrait::getBaseEffectName() const {
+ return def->getValueAsString("baseEffectName");
+}
+
bool SideEffectTrait::classof(const OpTrait *t) {
return t->getDef().isSubClassOf("SideEffectsTraitBase");
}
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index 3a76594..42ba715 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -127,6 +127,13 @@ LogicalResult CSE::simplifyOperation(ScopedMapTy &knownValues, Operation *op) {
if (op->isKnownTerminator())
return failure();
+ // If the operation is already trivially dead just add it to the erase list.
+ if (isOpTriviallyDead(op)) {
+ opsToErase.push_back(op);
+ ++numDCE;
+ return success();
+ }
+
// Don't simplify operations with nested blocks. We don't currently model
// equality comparisons correctly among other things. It is also unclear
// whether we would want to CSE such operations.
@@ -135,16 +142,9 @@ LogicalResult CSE::simplifyOperation(ScopedMapTy &knownValues, Operation *op) {
// TODO(riverriddle) We currently only eliminate non side-effecting
// operations.
- if (!op->hasNoSideEffect())
+ if (!MemoryEffectOpInterface::hasNoEffect(op))
return failure();
- // If the operation is already trivially dead just add it to the erase list.
- if (op->use_empty()) {
- opsToErase.push_back(op);
- ++numDCE;
- return success();
- }
-
// Look for an existing definition for the operation.
if (auto *existing = knownValues.lookup(op)) {
// If we find one then replace all uses of the current operation with the
diff --git a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
index a452e33..7300948 100644
--- a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
+++ b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
@@ -49,16 +49,18 @@ static bool canBeHoisted(Operation *op,
if (auto memInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
if (!memInterface.hasNoEffect())
return false;
- } else if (!op->hasNoSideEffect() &&
- !op->hasTrait<OpTrait::HasRecursiveSideEffects>()) {
+ // If the operation doesn't have side effects and it doesn't recursively
+ // have side effects, it can always be hoisted.
+ if (!op->hasTrait<OpTrait::HasRecursiveSideEffects>())
+ return true;
+
+ // Otherwise, if the operation doesn't provide the memory effect interface
+ // and it doesn't have recursive side effects we treat it conservatively as
+ // side-effecting.
+ } else if (!op->hasTrait<OpTrait::HasRecursiveSideEffects>()) {
return false;
}
- // If the operation doesn't have side effects and it doesn't recursively
- // have side effects, it can always be hoisted.
- if (!op->hasTrait<OpTrait::HasRecursiveSideEffects>())
- return true;
-
// Recurse into the regions for this op and check whether the contained ops
// can be hoisted.
for (auto &region : op->getRegions()) {
diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index f374d38..66535ec 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -126,6 +126,12 @@ void OperationFolder::notifyRemoval(Operation *op) {
referencedDialects.erase(it);
}
+/// Clear out any constants cached inside of the folder.
+void OperationFolder::clear() {
+ foldScopes.clear();
+ referencedDialects.clear();
+}
+
/// Tries to perform folding on the given `op`. If successful, populates
/// `results` with the results of the folding.
LogicalResult OperationFolder::tryToFold(
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index f9a9be5..e40a4d9 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -10,9 +10,8 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/SideEffects.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/DenseMap.h"
@@ -162,11 +161,8 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
if (op == nullptr)
continue;
- // If the operation has no side effects, and no users, then it is
- // trivially dead - remove it.
- if (op->isKnownNonTerminator() && op->hasNoSideEffect() &&
- op->use_empty()) {
- // Be careful to update bookkeeping.
+ // If the operation is trivially dead - remove it.
+ if (isOpTriviallyDead(op)) {
notifyOperationRemoved(op);
op->erase();
continue;
@@ -204,7 +200,10 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
// After applying patterns, make sure that the CFG of each of the regions is
// kept up to date.
- changed |= succeeded(simplifyRegions(regions));
+ if (succeeded(simplifyRegions(regions))) {
+ folder.clear();
+ changed = true;
+ }
} while (changed && ++i < maxIterations);
// Whether the rewrite converges, i.e. wasn't changed in the last iteration.
return !changed;
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index a02c71a..7095e55 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -887,7 +887,7 @@ static LogicalResult hoistOpsBetween(loop::ForOp outer, loop::ForOp inner) {
}
// Skip if op has side effects.
// TODO(ntv): loads to immutable memory regions are ok.
- if (!op.hasNoSideEffect()) {
+ if (!MemoryEffectOpInterface::hasNoEffect(&op)) {
status = failure();
continue;
}
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 78bb609b..162091c 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -12,6 +12,7 @@
#include "mlir/IR/RegionGraphTraits.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/SideEffects.h"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/PostOrderIterator.h"
@@ -196,9 +197,8 @@ static bool isOpIntrinsicallyLive(Operation *op) {
if (!op->isKnownNonTerminator())
return true;
// If the op has a side effect, we treat it as live.
- if (!op->hasNoSideEffect())
- return true;
- return false;
+ // TODO: Properly handle region side effects.
+ return !MemoryEffectOpInterface::hasNoEffect(op) || op->getNumRegions() != 0;
}
static void propagateLiveness(Region &region, LiveMap &liveMap);