aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorValentin Clement <clementval@gmail.com>2022-04-07 10:06:50 +0200
committerValentin Clement <clementval@gmail.com>2022-04-07 10:08:55 +0200
commit02da9643506dee4a82353e0f911513279634d846 (patch)
tree51889f0a2150283dee322514730d868e1a538dd6 /mlir
parent842d0bf93176b9cb1e0d6894a2bbfb32ad33ebb8 (diff)
downloadllvm-02da9643506dee4a82353e0f911513279634d846.zip
llvm-02da9643506dee4a82353e0f911513279634d846.tar.gz
llvm-02da9643506dee4a82353e0f911513279634d846.tar.bz2
[mlir][CSE] Remove duplicated operations with MemRead side-effect
This patch enhances the CSE pass to deal with simple cases of duplicated operations with MemoryEffects. It allows the CSE pass to remove safely duplicate operations with the MemoryEffects::Read that have no other side-effecting operations in between. Other MemoryEffects::Read operation are allowed. The use case is pretty simple so far so we can build on top of it to add more features. This patch is also meant to avoid a dedicated CSE pass in FIR and was brought together afetr discussion on https://reviews.llvm.org/D112711. It does not currently cover the full range of use cases described in https://reviews.llvm.org/D112711 but the idea is to gradually enhance the MLIR CSE pass to handle common use cases that can be used by other dialects. This patch takes advantage of the new CSE capabilities in Fir. Reviewed By: mehdi_amini, rriddle, schweitz Differential Revision: https://reviews.llvm.org/D122801
Diffstat (limited to 'mlir')
-rw-r--r--mlir/lib/Transforms/CSE.cpp151
-rw-r--r--mlir/test/Examples/Toy/Ch5/affine-lowering.mlir3
-rw-r--r--mlir/test/Examples/Toy/Ch6/affine-lowering.mlir3
-rw-r--r--mlir/test/Examples/Toy/Ch7/affine-lowering.mlir3
-rw-r--r--mlir/test/Transforms/cse.mlir45
-rw-r--r--mlir/test/lib/Dialect/Test/TestOps.td8
6 files changed, 172 insertions, 41 deletions
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index 0570c91..080e393 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -60,6 +60,14 @@ struct CSE : public CSEBase<CSE> {
using ScopedMapTy = llvm::ScopedHashTable<Operation *, Operation *,
SimpleOperationInfo, AllocatorTy>;
+ /// Cache holding MemoryEffects information between two operations. The first
+ /// operation is stored has the key. The second operation is stored inside a
+ /// pair in the value. The pair also hold the MemoryEffects between those
+ /// two operations. If the MemoryEffects is nullptr then we assume there is
+ /// no operation with MemoryEffects::Write between the two operations.
+ using MemEffectsCache =
+ DenseMap<Operation *, std::pair<Operation *, MemoryEffects::Effect *>>;
+
/// Represents a single entry in the depth first traversal of a CFG.
struct CFGStackNode {
CFGStackNode(ScopedMapTy &knownValues, DominanceInfoNode *node)
@@ -85,12 +93,94 @@ struct CSE : public CSEBase<CSE> {
void runOnOperation() override;
private:
+ void replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
+ Operation *existing, bool hasSSADominance);
+
+ /// Check if there is side-effecting operations other than the given effect
+ /// between the two operations.
+ bool hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp);
+
/// Operations marked as dead and to be erased.
std::vector<Operation *> opsToErase;
DominanceInfo *domInfo = nullptr;
+ MemEffectsCache memEffectsCache;
};
} // namespace
+void CSE::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
+ Operation *existing, bool hasSSADominance) {
+ // If we find one then replace all uses of the current operation with the
+ // existing one and mark it for deletion. We can only replace an operand in
+ // an operation if it has not been visited yet.
+ if (hasSSADominance) {
+ // If the region has SSA dominance, then we are guaranteed to have not
+ // visited any use of the current operation.
+ op->replaceAllUsesWith(existing);
+ opsToErase.push_back(op);
+ } else {
+ // When the region does not have SSA dominance, we need to check if we
+ // have visited a use before replacing any use.
+ for (auto it : llvm::zip(op->getResults(), existing->getResults())) {
+ std::get<0>(it).replaceUsesWithIf(
+ std::get<1>(it), [&](OpOperand &operand) {
+ return !knownValues.count(operand.getOwner());
+ });
+ }
+
+ // There may be some remaining uses of the operation.
+ if (op->use_empty())
+ opsToErase.push_back(op);
+ }
+
+ // If the existing operation has an unknown location and the current
+ // operation doesn't, then set the existing op's location to that of the
+ // current op.
+ if (existing->getLoc().isa<UnknownLoc>() && !op->getLoc().isa<UnknownLoc>())
+ existing->setLoc(op->getLoc());
+
+ ++numCSE;
+}
+
+bool CSE::hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp) {
+ assert(fromOp->getBlock() == toOp->getBlock());
+ assert(
+ isa<MemoryEffectOpInterface>(fromOp) &&
+ cast<MemoryEffectOpInterface>(fromOp).hasEffect<MemoryEffects::Read>() &&
+ isa<MemoryEffectOpInterface>(toOp) &&
+ cast<MemoryEffectOpInterface>(toOp).hasEffect<MemoryEffects::Read>());
+ Operation *nextOp = fromOp->getNextNode();
+ auto result =
+ memEffectsCache.try_emplace(fromOp, std::make_pair(fromOp, nullptr));
+ if (result.second) {
+ auto memEffectsCachePair = result.first->second;
+ if (memEffectsCachePair.second == nullptr) {
+ // No MemoryEffects::Write has been detected until the cached operation.
+ // Continue looking from the cached operation to toOp.
+ nextOp = memEffectsCachePair.first;
+ } else {
+ // MemoryEffects::Write has been detected before so there is no need to
+ // check further.
+ return true;
+ }
+ }
+ while (nextOp && nextOp != toOp) {
+ auto nextOpMemEffects = dyn_cast<MemoryEffectOpInterface>(nextOp);
+ // TODO: Do we need to handle other effects generically?
+ // If the operation does not implement the MemoryEffectOpInterface we
+ // conservatively assumes it writes.
+ if ((nextOpMemEffects &&
+ nextOpMemEffects.hasEffect<MemoryEffects::Write>()) ||
+ !nextOpMemEffects) {
+ result.first->second =
+ std::make_pair(nextOp, MemoryEffects::Write::get());
+ return true;
+ }
+ nextOp = nextOp->getNextNode();
+ }
+ result.first->second = std::make_pair(toOp, nullptr);
+ return false;
+}
+
/// Attempt to eliminate a redundant operation.
LogicalResult CSE::simplifyOperation(ScopedMapTy &knownValues, Operation *op,
bool hasSSADominance) {
@@ -111,45 +201,34 @@ LogicalResult CSE::simplifyOperation(ScopedMapTy &knownValues, Operation *op,
if (op->getNumRegions() != 0)
return failure();
- // TODO: We currently only eliminate non side-effecting
- // operations.
- if (!MemoryEffectOpInterface::hasNoEffect(op))
+ // Some simple use case of operation with memory side-effect are dealt with
+ // here. Operations with no side-effect are done after.
+ if (!MemoryEffectOpInterface::hasNoEffect(op)) {
+ auto memEffects = dyn_cast<MemoryEffectOpInterface>(op);
+ // TODO: Only basic use case for operations with MemoryEffects::Read can be
+ // eleminated now. More work needs to be done for more complicated patterns
+ // and other side-effects.
+ if (!memEffects || !memEffects.onlyHasEffect<MemoryEffects::Read>())
+ return failure();
+
+ // Look for an existing definition for the operation.
+ if (auto *existing = knownValues.lookup(op)) {
+ if (existing->getBlock() == op->getBlock() &&
+ !hasOtherSideEffectingOpInBetween(existing, op)) {
+ // The operation that can be deleted has been reach with no
+ // side-effecting operations in between the existing operation and
+ // this one so we can remove the duplicate.
+ replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
+ return success();
+ }
+ }
+ knownValues.insert(op, op);
return failure();
+ }
// Look for an existing definition for the operation.
if (auto *existing = knownValues.lookup(op)) {
-
- // If we find one then replace all uses of the current operation with the
- // existing one and mark it for deletion. We can only replace an operand in
- // an operation if it has not been visited yet.
- if (hasSSADominance) {
- // If the region has SSA dominance, then we are guaranteed to have not
- // visited any use of the current operation.
- op->replaceAllUsesWith(existing);
- opsToErase.push_back(op);
- } else {
- // When the region does not have SSA dominance, we need to check if we
- // have visited a use before replacing any use.
- for (auto it : llvm::zip(op->getResults(), existing->getResults())) {
- std::get<0>(it).replaceUsesWithIf(
- std::get<1>(it), [&](OpOperand &operand) {
- return !knownValues.count(operand.getOwner());
- });
- }
-
- // There may be some remaining uses of the operation.
- if (op->use_empty())
- opsToErase.push_back(op);
- }
-
- // If the existing operation has an unknown location and the current
- // operation doesn't, then set the existing op's location to that of the
- // current op.
- if (existing->getLoc().isa<UnknownLoc>() &&
- !op->getLoc().isa<UnknownLoc>()) {
- existing->setLoc(op->getLoc());
- }
-
+ replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
++numCSE;
return success();
}
@@ -184,6 +263,8 @@ void CSE::simplifyBlock(ScopedMapTy &knownValues, Block *bb,
for (auto &region : op.getRegions())
simplifyRegion(knownValues, region);
}
+ // Clear the MemoryEffects cache since its usage is by block only.
+ memEffectsCache.clear();
}
void CSE::simplifyRegion(ScopedMapTy &knownValues, Region &region) {
diff --git a/mlir/test/Examples/Toy/Ch5/affine-lowering.mlir b/mlir/test/Examples/Toy/Ch5/affine-lowering.mlir
index ad99faa..034474d 100644
--- a/mlir/test/Examples/Toy/Ch5/affine-lowering.mlir
+++ b/mlir/test/Examples/Toy/Ch5/affine-lowering.mlir
@@ -32,8 +32,7 @@ toy.func @main() {
// CHECK: affine.for [[VAL_12:%.*]] = 0 to 3 {
// CHECK: affine.for [[VAL_13:%.*]] = 0 to 2 {
// CHECK: [[VAL_14:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
-// CHECK: [[VAL_15:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
-// CHECK: [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_15]] : f64
+// CHECK: [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_14]] : f64
// CHECK: affine.store [[VAL_16]], [[VAL_6]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
// CHECK: toy.print [[VAL_6]] : memref<3x2xf64>
// CHECK: memref.dealloc [[VAL_8]] : memref<2x3xf64>
diff --git a/mlir/test/Examples/Toy/Ch6/affine-lowering.mlir b/mlir/test/Examples/Toy/Ch6/affine-lowering.mlir
index ca056b4..51dedaf 100644
--- a/mlir/test/Examples/Toy/Ch6/affine-lowering.mlir
+++ b/mlir/test/Examples/Toy/Ch6/affine-lowering.mlir
@@ -32,8 +32,7 @@ toy.func @main() {
// CHECK: affine.for [[VAL_12:%.*]] = 0 to 3 {
// CHECK: affine.for [[VAL_13:%.*]] = 0 to 2 {
// CHECK: [[VAL_14:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
-// CHECK: [[VAL_15:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
-// CHECK: [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_15]] : f64
+// CHECK: [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_14]] : f64
// CHECK: affine.store [[VAL_16]], [[VAL_6]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
// CHECK: toy.print [[VAL_6]] : memref<3x2xf64>
// CHECK: memref.dealloc [[VAL_8]] : memref<2x3xf64>
diff --git a/mlir/test/Examples/Toy/Ch7/affine-lowering.mlir b/mlir/test/Examples/Toy/Ch7/affine-lowering.mlir
index 60d466e..3cefd0e 100644
--- a/mlir/test/Examples/Toy/Ch7/affine-lowering.mlir
+++ b/mlir/test/Examples/Toy/Ch7/affine-lowering.mlir
@@ -32,8 +32,7 @@ toy.func @main() {
// CHECK: affine.for [[VAL_12:%.*]] = 0 to 3 {
// CHECK: affine.for [[VAL_13:%.*]] = 0 to 2 {
// CHECK: [[VAL_14:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
-// CHECK: [[VAL_15:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
-// CHECK: [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_15]] : f64
+// CHECK: [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_14]] : f64
// CHECK: affine.store [[VAL_16]], [[VAL_6]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
// CHECK: toy.print [[VAL_6]] : memref<3x2xf64>
// CHECK: memref.dealloc [[VAL_8]] : memref<2x3xf64>
diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir
index 982511f..189cdde 100644
--- a/mlir/test/Transforms/cse.mlir
+++ b/mlir/test/Transforms/cse.mlir
@@ -265,3 +265,48 @@ func @use_before_def() {
}
return
}
+
+/// This test is checking that CSE is removing duplicated read op that follow
+/// other.
+// CHECK-LABEL: @remove_direct_duplicated_read_op
+func @remove_direct_duplicated_read_op() -> i32 {
+ // CHECK-NEXT: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i32
+ %0 = "test.op_with_memread"() : () -> (i32)
+ %1 = "test.op_with_memread"() : () -> (i32)
+ // CHECK-NEXT: %{{.*}} = arith.addi %[[READ_VALUE]], %[[READ_VALUE]] : i32
+ %2 = arith.addi %0, %1 : i32
+ return %2 : i32
+}
+
+/// This test is checking that CSE is removing duplicated read op that follow
+/// other.
+// CHECK-LABEL: @remove_multiple_duplicated_read_op
+func @remove_multiple_duplicated_read_op() -> i64 {
+ // CHECK: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i64
+ %0 = "test.op_with_memread"() : () -> (i64)
+ %1 = "test.op_with_memread"() : () -> (i64)
+ // CHECK-NEXT: %{{.*}} = arith.addi %{{.*}}, %[[READ_VALUE]] : i64
+ %2 = arith.addi %0, %1 : i64
+ %3 = "test.op_with_memread"() : () -> (i64)
+ // CHECK-NEXT: %{{.*}} = arith.addi %{{.*}}, %{{.*}} : i64
+ %4 = arith.addi %2, %3 : i64
+ %5 = "test.op_with_memread"() : () -> (i64)
+ // CHECK-NEXT: %{{.*}} = arith.addi %{{.*}}, %{{.*}} : i64
+ %6 = arith.addi %4, %5 : i64
+ // CHECK-NEXT: return %{{.*}} : i64
+ return %6 : i64
+}
+
+/// This test is checking that CSE is not removing duplicated read op that
+/// have write op in between.
+// CHECK-LABEL: @dont_remove_duplicated_read_op_with_sideeffecting
+func @dont_remove_duplicated_read_op_with_sideeffecting() -> i32 {
+ // CHECK-NEXT: %[[READ_VALUE0:.*]] = "test.op_with_memread"() : () -> i32
+ %0 = "test.op_with_memread"() : () -> (i32)
+ "test.op_with_memwrite"() : () -> ()
+ // CHECK: %[[READ_VALUE1:.*]] = "test.op_with_memread"() : () -> i32
+ %1 = "test.op_with_memread"() : () -> (i32)
+ // CHECK-NEXT: %{{.*}} = arith.addi %[[READ_VALUE0]], %[[READ_VALUE1]] : i32
+ %2 = arith.addi %0, %1 : i32
+ return %2 : i32
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index b157d5d..36e31d1 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2761,4 +2761,12 @@ def TestEffectsOpA : TEST_Op<"op_with_effects_a"> {
def TestEffectsOpB : TEST_Op<"op_with_effects_b",
[MemoryEffects<[MemWrite<TestResource>]>]>;
+def TestEffectsRead : TEST_Op<"op_with_memread",
+ [MemoryEffects<[MemRead]>]> {
+ let results = (outs AnyInteger);
+}
+
+def TestEffectsWrite : TEST_Op<"op_with_memwrite",
+ [MemoryEffects<[MemWrite]>]>;
+
#endif // TEST_OPS