aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--flang/include/flang/Optimizer/Dialect/FIROps.td4
-rw-r--r--flang/test/Fir/cse.fir57
-rw-r--r--mlir/lib/Transforms/CSE.cpp151
-rw-r--r--mlir/test/Examples/Toy/Ch5/affine-lowering.mlir3
-rw-r--r--mlir/test/Examples/Toy/Ch6/affine-lowering.mlir3
-rw-r--r--mlir/test/Examples/Toy/Ch7/affine-lowering.mlir3
-rw-r--r--mlir/test/Transforms/cse.mlir45
-rw-r--r--mlir/test/lib/Dialect/Test/TestOps.td8
8 files changed, 231 insertions, 43 deletions
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 262d953..6eb0fdf 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -253,7 +253,7 @@ def fir_FreeMemOp : fir_Op<"freemem", [MemoryEffects<[MemFree]>]> {
let assemblyFormat = "$heapref attr-dict `:` qualified(type($heapref))";
}
-def fir_LoadOp : fir_OneResultOp<"load"> {
+def fir_LoadOp : fir_OneResultOp<"load", [MemoryEffects<[MemRead]>]> {
let summary = "load a value from a memory reference";
let description = [{
Load a value from a memory reference into an ssa-value (virtual register).
@@ -320,7 +320,7 @@ def fir_CharConvertOp : fir_Op<"char_convert", []> {
let hasVerifier = 1;
}
-def fir_StoreOp : fir_Op<"store", []> {
+def fir_StoreOp : fir_Op<"store", [MemoryEffects<[MemWrite]>]> {
let summary = "store an SSA-value to a memory location";
let description = [{
diff --git a/flang/test/Fir/cse.fir b/flang/test/Fir/cse.fir
new file mode 100644
index 0000000..148b689
--- /dev/null
+++ b/flang/test/Fir/cse.fir
@@ -0,0 +1,57 @@
+// RUN: fir-opt --cse -split-input-file %s | FileCheck %s
+
+// Check that the second of two identical, back-to-back fir.load operations is
+// removed and that its uses are redirected to the first load.
+func @fun(%arg0: !fir.ref<i64>) -> i64 {
+ %0 = fir.load %arg0 : !fir.ref<i64>
+ %1 = fir.load %arg0 : !fir.ref<i64>
+ %2 = arith.addi %0, %1 : i64
+ return %2 : i64
+}
+
+// CHECK-LABEL: func @fun
+// CHECK-NEXT: %[[LOAD:.*]] = fir.load %{{.*}} : !fir.ref<i64>
+// CHECK-NEXT: %{{.*}} = arith.addi %[[LOAD]], %[[LOAD]] : i64
+
+// -----
+
+// Single-block case with many repeated loads of the same reference: only the
+// first fir.load is expected to remain, directly followed by the chain of
+// arith.addi operations and the return.
+// CHECK-LABEL: func @fun(
+// CHECK-SAME: %[[A:.*]]: !fir.ref<i64>
+func @fun(%a : !fir.ref<i64>) -> i64 {
+ // CHECK: %[[LOAD:.*]] = fir.load %[[A]] : !fir.ref<i64>
+ %1 = fir.load %a : !fir.ref<i64>
+ %2 = fir.load %a : !fir.ref<i64>
+ // CHECK-NEXT: %{{.*}} = arith.addi %[[LOAD]], %[[LOAD]] : i64
+ %3 = arith.addi %1, %2 : i64
+ %4 = fir.load %a : !fir.ref<i64>
+ // CHECK-NEXT: %{{.*}} = arith.addi
+ %5 = arith.addi %3, %4 : i64
+ %6 = fir.load %a : !fir.ref<i64>
+ // CHECK-NEXT: %{{.*}} = arith.addi
+ %7 = arith.addi %5, %6 : i64
+ %8 = fir.load %a : !fir.ref<i64>
+ // CHECK-NEXT: %{{.*}} = arith.addi
+ %9 = arith.addi %7, %8 : i64
+ %10 = fir.load %a : !fir.ref<i64>
+ // CHECK-NEXT: %{{.*}} = arith.addi
+ %11 = arith.addi %10, %9 : i64
+ %12 = fir.load %a : !fir.ref<i64>
+ // CHECK-NEXT: %{{.*}} = arith.addi
+ %13 = arith.addi %11, %12 : i64
+ // CHECK-NEXT: return %{{.*}} : i64
+ return %13 : i64
+}
+
+// -----
+
+// Multi-block case: the same reference is loaded twice in ^bb1 and once more
+// in ^bb2.
+// NOTE(review): no FileCheck directives are attached to this function, so the
+// result of the pass on it is not verified here - TODO add expectations.
+func @fun(%a : !fir.ref<i64>) -> i64 {
+ cf.br ^bb1
+^bb1:
+ %1 = fir.load %a : !fir.ref<i64>
+ %2 = fir.load %a : !fir.ref<i64>
+ %3 = arith.addi %1, %2 : i64
+ cf.br ^bb2
+^bb2:
+ %4 = fir.load %a : !fir.ref<i64>
+ %5 = arith.subi %4, %4 : i64
+ return %5 : i64
+}
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index 0570c91..080e393 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -60,6 +60,14 @@ struct CSE : public CSEBase<CSE> {
using ScopedMapTy = llvm::ScopedHashTable<Operation *, Operation *,
SimpleOperationInfo, AllocatorTy>;
+ /// Cache holding MemoryEffects information between two operations. The first
+ /// operation is stored as the key. The second operation is stored inside a
+ /// pair in the value. The pair also holds the MemoryEffects between those
+ /// two operations. If the MemoryEffects is nullptr then we assume there is
+ /// no operation with MemoryEffects::Write between the two operations.
+ using MemEffectsCache =
+ DenseMap<Operation *, std::pair<Operation *, MemoryEffects::Effect *>>;
+
/// Represents a single entry in the depth first traversal of a CFG.
struct CFGStackNode {
CFGStackNode(ScopedMapTy &knownValues, DominanceInfoNode *node)
@@ -85,12 +93,94 @@ struct CSE : public CSEBase<CSE> {
void runOnOperation() override;
private:
+ void replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
+ Operation *existing, bool hasSSADominance);
+
+ /// Check whether an operation that writes to memory, or whose memory effects
+ /// are unknown, lies between the two operations (which share a block).
+ bool hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp);
+
/// Operations marked as dead and to be erased.
std::vector<Operation *> opsToErase;
DominanceInfo *domInfo = nullptr;
+ MemEffectsCache memEffectsCache;
};
} // namespace
+/// Redirect the uses of `op` to the results of the equivalent operation
+/// `existing` and queue `op` in `opsToErase` once it has no remaining uses.
+/// Every call is counted once in the `numCSE` statistic.
+void CSE::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
+                               Operation *existing, bool hasSSADominance) {
+  if (!hasSSADominance) {
+    // Without SSA dominance a user of `op` may already have been visited, and
+    // an operand may only be rewritten in a user that has not been visited
+    // yet, so the uses are filtered one result at a time.
+    for (auto resultPair :
+         llvm::zip(op->getResults(), existing->getResults())) {
+      auto duplicateResult = std::get<0>(resultPair);
+      auto existingResult = std::get<1>(resultPair);
+      duplicateResult.replaceUsesWithIf(
+          existingResult, [&](OpOperand &use) {
+            return knownValues.count(use.getOwner()) == 0;
+          });
+    }
+
+    // Some uses may have been kept, in which case `op` has to stay.
+    if (op->use_empty())
+      opsToErase.push_back(op);
+  } else {
+    // With SSA dominance no use of `op` can have been visited yet, so all of
+    // them are redirected and `op` is always queued for erasure.
+    op->replaceAllUsesWith(existing);
+    opsToErase.push_back(op);
+  }
+
+  // Keep the more precise location: when only the surviving operation has an
+  // unknown location, it takes over the location of the duplicate.
+  const bool existingLocIsUnknown = existing->getLoc().isa<UnknownLoc>();
+  if (existingLocIsUnknown && !op->getLoc().isa<UnknownLoc>())
+    existing->setLoc(op->getLoc());
+
+  ++numCSE;
+}
+
+/// Return true if an operation that may write to memory lies strictly between
+/// `fromOp` and `toOp`. Both operations must be in the same block and must
+/// have a MemoryEffects::Read effect. Operations that do not implement
+/// MemoryEffectOpInterface are conservatively treated as writes.
+///
+/// The outcome of each query is recorded in `memEffectsCache` under `fromOp`,
+/// so a later query with the same `fromOp` resumes where the previous one
+/// stopped instead of walking the block again from `fromOp`.
+bool CSE::hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp) {
+  assert(fromOp->getBlock() == toOp->getBlock());
+  assert(
+      isa<MemoryEffectOpInterface>(fromOp) &&
+      cast<MemoryEffectOpInterface>(fromOp).hasEffect<MemoryEffects::Read>() &&
+      isa<MemoryEffectOpInterface>(toOp) &&
+      cast<MemoryEffectOpInterface>(toOp).hasEffect<MemoryEffects::Read>());
+  Operation *nextOp = fromOp->getNextNode();
+  auto result =
+      memEffectsCache.try_emplace(fromOp, std::make_pair(fromOp, nullptr));
+  // try_emplace reports `second == false` when `fromOp` already had an entry;
+  // that is the only case in which there is an earlier result to reuse. A
+  // freshly inserted entry is just the placeholder created above.
+  if (!result.second) {
+    const auto &cachedEntry = result.first->second;
+    // MemoryEffects::Write has been detected before so there is no need to
+    // check further.
+    if (cachedEntry.second != nullptr)
+      return true;
+    // No MemoryEffects::Write has been detected until the cached operation.
+    // Continue looking from the cached operation to toOp.
+    nextOp = cachedEntry.first;
+  }
+  while (nextOp && nextOp != toOp) {
+    auto nextOpMemEffects = dyn_cast<MemoryEffectOpInterface>(nextOp);
+    // TODO: Do we need to handle other effects generically?
+    // If the operation does not implement the MemoryEffectOpInterface we
+    // conservatively assume it writes.
+    if (!nextOpMemEffects ||
+        nextOpMemEffects.hasEffect<MemoryEffects::Write>()) {
+      result.first->second =
+          std::make_pair(nextOp, MemoryEffects::Write::get());
+      return true;
+    }
+    nextOp = nextOp->getNextNode();
+  }
+  // Remember how far the block has been checked without finding a write.
+  result.first->second = std::make_pair(toOp, nullptr);
+  return false;
+}
+
/// Attempt to eliminate a redundant operation.
LogicalResult CSE::simplifyOperation(ScopedMapTy &knownValues, Operation *op,
bool hasSSADominance) {
@@ -111,45 +201,34 @@ LogicalResult CSE::simplifyOperation(ScopedMapTy &knownValues, Operation *op,
if (op->getNumRegions() != 0)
return failure();
- // TODO: We currently only eliminate non side-effecting
- // operations.
- if (!MemoryEffectOpInterface::hasNoEffect(op))
+ // Some simple use cases of operations with memory side-effects are dealt
+ // with here. Operations with no side-effects are handled afterwards.
+ if (!MemoryEffectOpInterface::hasNoEffect(op)) {
+ auto memEffects = dyn_cast<MemoryEffectOpInterface>(op);
+ // TODO: Only basic use cases for operations with MemoryEffects::Read can be
+ // eliminated now. More work needs to be done for more complicated patterns
+ // and other side-effects.
+ if (!memEffects || !memEffects.onlyHasEffect<MemoryEffects::Read>())
+ return failure();
+
+ // Look for an existing definition for the operation.
+ if (auto *existing = knownValues.lookup(op)) {
+ if (existing->getBlock() == op->getBlock() &&
+ !hasOtherSideEffectingOpInBetween(existing, op)) {
+ // The operation that can be deleted has been reached with no
+ // side-effecting operations in between the existing operation and
+ // this one so we can remove the duplicate.
+ replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
+ return success();
+ }
+ }
+ knownValues.insert(op, op);
return failure();
+ }
// Look for an existing definition for the operation.
if (auto *existing = knownValues.lookup(op)) {
-
- // If we find one then replace all uses of the current operation with the
- // existing one and mark it for deletion. We can only replace an operand in
- // an operation if it has not been visited yet.
- if (hasSSADominance) {
- // If the region has SSA dominance, then we are guaranteed to have not
- // visited any use of the current operation.
- op->replaceAllUsesWith(existing);
- opsToErase.push_back(op);
- } else {
- // When the region does not have SSA dominance, we need to check if we
- // have visited a use before replacing any use.
- for (auto it : llvm::zip(op->getResults(), existing->getResults())) {
- std::get<0>(it).replaceUsesWithIf(
- std::get<1>(it), [&](OpOperand &operand) {
- return !knownValues.count(operand.getOwner());
- });
- }
-
- // There may be some remaining uses of the operation.
- if (op->use_empty())
- opsToErase.push_back(op);
- }
-
- // If the existing operation has an unknown location and the current
- // operation doesn't, then set the existing op's location to that of the
- // current op.
- if (existing->getLoc().isa<UnknownLoc>() &&
- !op->getLoc().isa<UnknownLoc>()) {
- existing->setLoc(op->getLoc());
- }
-
+ // replaceUsesAndDelete() increments the numCSE statistic itself, so it must
+ // not be incremented a second time here.
+ replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
- ++numCSE;
return success();
}
@@ -184,6 +263,8 @@ void CSE::simplifyBlock(ScopedMapTy &knownValues, Block *bb,
for (auto &region : op.getRegions())
simplifyRegion(knownValues, region);
}
+ // Clear the MemoryEffects cache since its usage is by block only.
+ memEffectsCache.clear();
}
void CSE::simplifyRegion(ScopedMapTy &knownValues, Region &region) {
diff --git a/mlir/test/Examples/Toy/Ch5/affine-lowering.mlir b/mlir/test/Examples/Toy/Ch5/affine-lowering.mlir
index ad99faa..034474d 100644
--- a/mlir/test/Examples/Toy/Ch5/affine-lowering.mlir
+++ b/mlir/test/Examples/Toy/Ch5/affine-lowering.mlir
@@ -32,8 +32,7 @@ toy.func @main() {
// CHECK: affine.for [[VAL_12:%.*]] = 0 to 3 {
// CHECK: affine.for [[VAL_13:%.*]] = 0 to 2 {
// CHECK: [[VAL_14:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
-// CHECK: [[VAL_15:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
-// CHECK: [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_15]] : f64
+// CHECK: [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_14]] : f64
// CHECK: affine.store [[VAL_16]], [[VAL_6]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
// CHECK: toy.print [[VAL_6]] : memref<3x2xf64>
// CHECK: memref.dealloc [[VAL_8]] : memref<2x3xf64>
diff --git a/mlir/test/Examples/Toy/Ch6/affine-lowering.mlir b/mlir/test/Examples/Toy/Ch6/affine-lowering.mlir
index ca056b4..51dedaf 100644
--- a/mlir/test/Examples/Toy/Ch6/affine-lowering.mlir
+++ b/mlir/test/Examples/Toy/Ch6/affine-lowering.mlir
@@ -32,8 +32,7 @@ toy.func @main() {
// CHECK: affine.for [[VAL_12:%.*]] = 0 to 3 {
// CHECK: affine.for [[VAL_13:%.*]] = 0 to 2 {
// CHECK: [[VAL_14:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
-// CHECK: [[VAL_15:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
-// CHECK: [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_15]] : f64
+// CHECK: [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_14]] : f64
// CHECK: affine.store [[VAL_16]], [[VAL_6]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
// CHECK: toy.print [[VAL_6]] : memref<3x2xf64>
// CHECK: memref.dealloc [[VAL_8]] : memref<2x3xf64>
diff --git a/mlir/test/Examples/Toy/Ch7/affine-lowering.mlir b/mlir/test/Examples/Toy/Ch7/affine-lowering.mlir
index 60d466e..3cefd0e 100644
--- a/mlir/test/Examples/Toy/Ch7/affine-lowering.mlir
+++ b/mlir/test/Examples/Toy/Ch7/affine-lowering.mlir
@@ -32,8 +32,7 @@ toy.func @main() {
// CHECK: affine.for [[VAL_12:%.*]] = 0 to 3 {
// CHECK: affine.for [[VAL_13:%.*]] = 0 to 2 {
// CHECK: [[VAL_14:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
-// CHECK: [[VAL_15:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
-// CHECK: [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_15]] : f64
+// CHECK: [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_14]] : f64
// CHECK: affine.store [[VAL_16]], [[VAL_6]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
// CHECK: toy.print [[VAL_6]] : memref<3x2xf64>
// CHECK: memref.dealloc [[VAL_8]] : memref<2x3xf64>
diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir
index 982511f..189cdde 100644
--- a/mlir/test/Transforms/cse.mlir
+++ b/mlir/test/Transforms/cse.mlir
@@ -265,3 +265,48 @@ func @use_before_def() {
}
return
}
+
+/// This test checks that CSE removes a duplicated read op that directly
+/// follows the original one.
+// CHECK-LABEL: @remove_direct_duplicated_read_op
+func @remove_direct_duplicated_read_op() -> i32 {
+ // CHECK-NEXT: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i32
+ %0 = "test.op_with_memread"() : () -> (i32)
+ %1 = "test.op_with_memread"() : () -> (i32)
+ // CHECK-NEXT: %{{.*}} = arith.addi %[[READ_VALUE]], %[[READ_VALUE]] : i32
+ %2 = arith.addi %0, %1 : i32
+ return %2 : i32
+}
+
+/// This test checks that CSE removes several duplicated read ops that follow
+/// the original one, with arith.addi ops in between them.
+// CHECK-LABEL: @remove_multiple_duplicated_read_op
+func @remove_multiple_duplicated_read_op() -> i64 {
+ // CHECK: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i64
+ %0 = "test.op_with_memread"() : () -> (i64)
+ %1 = "test.op_with_memread"() : () -> (i64)
+ // CHECK-NEXT: %{{.*}} = arith.addi %{{.*}}, %[[READ_VALUE]] : i64
+ %2 = arith.addi %0, %1 : i64
+ %3 = "test.op_with_memread"() : () -> (i64)
+ // CHECK-NEXT: %{{.*}} = arith.addi %{{.*}}, %{{.*}} : i64
+ %4 = arith.addi %2, %3 : i64
+ %5 = "test.op_with_memread"() : () -> (i64)
+ // CHECK-NEXT: %{{.*}} = arith.addi %{{.*}}, %{{.*}} : i64
+ %6 = arith.addi %4, %5 : i64
+ // CHECK-NEXT: return %{{.*}} : i64
+ return %6 : i64
+}
+
+/// This test checks that CSE does not remove a duplicated read op when there
+/// is a write op between it and the original one.
+// CHECK-LABEL: @dont_remove_duplicated_read_op_with_sideeffecting
+func @dont_remove_duplicated_read_op_with_sideeffecting() -> i32 {
+ // CHECK-NEXT: %[[READ_VALUE0:.*]] = "test.op_with_memread"() : () -> i32
+ %0 = "test.op_with_memread"() : () -> (i32)
+ "test.op_with_memwrite"() : () -> ()
+ // CHECK: %[[READ_VALUE1:.*]] = "test.op_with_memread"() : () -> i32
+ %1 = "test.op_with_memread"() : () -> (i32)
+ // CHECK-NEXT: %{{.*}} = arith.addi %[[READ_VALUE0]], %[[READ_VALUE1]] : i32
+ %2 = arith.addi %0, %1 : i32
+ return %2 : i32
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index b157d5d..36e31d1 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2761,4 +2761,12 @@ def TestEffectsOpA : TEST_Op<"op_with_effects_a"> {
def TestEffectsOpB : TEST_Op<"op_with_effects_b",
[MemoryEffects<[MemWrite<TestResource>]>]>;
+// Test op that declares a MemRead memory effect and produces a single
+// AnyInteger result.
+def TestEffectsRead : TEST_Op<"op_with_memread",
+ [MemoryEffects<[MemRead]>]> {
+ let results = (outs AnyInteger);
+}
+
+// Test op that declares a MemWrite memory effect.
+def TestEffectsWrite : TEST_Op<"op_with_memwrite",
+ [MemoryEffects<[MemWrite]>]>;
+
#endif // TEST_OPS