aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--flang/lib/Optimizer/Transforms/StackArrays.cpp24
-rw-r--r--mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h2
-rw-r--r--mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h76
-rw-r--r--mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h35
-rw-r--r--mlir/include/mlir/Analysis/DataFlowFramework.h221
-rw-r--r--mlir/include/mlir/IR/Block.h21
-rw-r--r--mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp76
-rw-r--r--mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp148
-rw-r--r--mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp4
-rw-r--r--mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp2
-rw-r--r--mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp87
-rw-r--r--mlir/lib/Analysis/DataFlowFramework.cpp29
-rw-r--r--mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp2
-rw-r--r--mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp15
-rw-r--r--mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp15
-rw-r--r--mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp7
-rw-r--r--mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp2
-rw-r--r--mlir/test/lib/Analysis/TestDataFlowFramework.cpp31
18 files changed, 491 insertions, 306 deletions
diff --git a/flang/lib/Optimizer/Transforms/StackArrays.cpp b/flang/lib/Optimizer/Transforms/StackArrays.cpp
index d9e7bd6..0c474f4 100644
--- a/flang/lib/Optimizer/Transforms/StackArrays.cpp
+++ b/flang/lib/Optimizer/Transforms/StackArrays.cpp
@@ -375,7 +375,7 @@ mlir::LogicalResult AllocationAnalysis::visitOperation(
}
} else if (mlir::isa<fir::ResultOp>(op)) {
mlir::Operation *parent = op->getParentOp();
- LatticePoint *parentLattice = getLattice(parent);
+ LatticePoint *parentLattice = getLattice(getProgramPointAfter(parent));
assert(parentLattice);
mlir::ChangeResult parentChanged = parentLattice->join(*after);
propagateIfChanged(parentLattice, parentChanged);
@@ -396,28 +396,29 @@ void AllocationAnalysis::setToEntryState(LatticePoint *lattice) {
/// Mostly a copy of AbstractDenseLattice::processOperation - the difference
/// being that call operations are passed through to the transfer function
mlir::LogicalResult AllocationAnalysis::processOperation(mlir::Operation *op) {
+ mlir::ProgramPoint *point = getProgramPointAfter(op);
// If the containing block is not executable, bail out.
- if (!getOrCreateFor<mlir::dataflow::Executable>(op, op->getBlock())->isLive())
+ if (op->getBlock() != nullptr &&
+ !getOrCreateFor<mlir::dataflow::Executable>(
+ point, getProgramPointBefore(op->getBlock()))
+ ->isLive())
return mlir::success();
// Get the dense lattice to update
- mlir::dataflow::AbstractDenseLattice *after = getLattice(op);
+ mlir::dataflow::AbstractDenseLattice *after = getLattice(point);
// If this op implements region control-flow, then control-flow dictates its
// transfer function.
if (auto branch = mlir::dyn_cast<mlir::RegionBranchOpInterface>(op)) {
- visitRegionBranchOperation(op, branch, after);
+ visitRegionBranchOperation(point, branch, after);
return mlir::success();
}
// pass call operations through to the transfer function
// Get the dense state before the execution of the op.
- const mlir::dataflow::AbstractDenseLattice *before;
- if (mlir::Operation *prev = op->getPrevNode())
- before = getLatticeFor(op, prev);
- else
- before = getLatticeFor(op, op->getBlock());
+ const mlir::dataflow::AbstractDenseLattice *before =
+ getLatticeFor(point, getProgramPointBefore(op));
/// Invoke the operation transfer function
return visitOperationImpl(op, *before, after);
@@ -452,9 +453,10 @@ StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) {
return mlir::failure();
}
- LatticePoint point{func};
+ LatticePoint point{solver.getProgramPointAfter(func)};
auto joinOperationLattice = [&](mlir::Operation *op) {
- const LatticePoint *lattice = solver.lookupState<LatticePoint>(op);
+ const LatticePoint *lattice =
+ solver.lookupState<LatticePoint>(solver.getProgramPointAfter(op));
// there will be no lattice for an unreachable block
if (lattice)
(void)point.join(*lattice);
diff --git a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
index 80c8b86..2250db8 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
@@ -182,7 +182,7 @@ public:
/// Visit an operation with control-flow semantics and deduce which of its
/// successors are live.
- LogicalResult visit(ProgramPoint point) override;
+ LogicalResult visit(ProgramPoint *point) override;
private:
/// Find and mark symbol callables with potentially unknown callsites as
diff --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
index 7917f1e..2e32bd1 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
@@ -36,8 +36,7 @@ enum class CallControlFlowAction { EnterCallee, ExitCallee, ExternalCallee };
//===----------------------------------------------------------------------===//
/// This class represents a dense lattice. A dense lattice is attached to
-/// operations to represent the program state after their execution or to blocks
-/// to represent the program state at the beginning of the block. A dense
+/// program point to represent the program state at the program point.
/// lattice is propagated through the IR by dense data-flow analysis.
class AbstractDenseLattice : public AnalysisState {
public:
@@ -59,15 +58,13 @@ public:
//===----------------------------------------------------------------------===//
/// Base class for dense forward data-flow analyses. Dense data-flow analysis
-/// attaches a lattice between the execution of operations and implements a
-/// transfer function from the lattice before each operation to the lattice
-/// after. The lattice contains information about the state of the program at
-/// that point.
+/// attaches a lattice to program points and implements a transfer function from
+/// the lattice before each operation to the lattice after. The lattice contains
+/// information about the state of the program at that program point.
///
-/// In this implementation, a lattice attached to an operation represents the
-/// state of the program after its execution, and a lattice attached to block
-/// represents the state of the program right before it starts executing its
-/// body.
+/// Visit a program point in forward dense data-flow analysis will invoke the
+/// transfer function of the operation preceding the program point iterator.
+/// Visit a program point at the begining of block will visit the block itself.
class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis {
public:
using DataFlowAnalysis::DataFlowAnalysis;
@@ -76,13 +73,14 @@ public:
/// may modify the program state; that is, every operation and block.
LogicalResult initialize(Operation *top) override;
- /// Visit a program point that modifies the state of the program. If this is a
- /// block, then the state is propagated from control-flow predecessors or
- /// callsites. If this is a call operation or region control-flow operation,
- /// then the state after the execution of the operation is set by control-flow
- /// or the callgraph. Otherwise, this function invokes the operation transfer
- /// function.
- LogicalResult visit(ProgramPoint point) override;
+ /// Visit a program point that modifies the state of the program. If the
+ /// program point is at the beginning of a block, then the state is propagated
+ /// from control-flow predecessors or callsites. If the operation before
+ /// program point iterator is a call operation or region control-flow
+ /// operation, then the state after the execution of the operation is set by
+ /// control-flow or the callgraph. Otherwise, this function invokes the
+ /// operation transfer function before the program point iterator.
+ LogicalResult visit(ProgramPoint *point) override;
protected:
/// Propagate the dense lattice before the execution of an operation to the
@@ -91,15 +89,14 @@ protected:
const AbstractDenseLattice &before,
AbstractDenseLattice *after) = 0;
- /// Get the dense lattice after the execution of the given lattice anchor.
+ /// Get the dense lattice on the given lattice anchor.
virtual AbstractDenseLattice *getLattice(LatticeAnchor anchor) = 0;
- /// Get the dense lattice after the execution of the given program point and
- /// add it as a dependency to a lattice anchor. That is, every time the
- /// lattice after anchor is updated, the dependent program point must be
- /// visited, and the newly triggered visit might update the lattice after
- /// dependent.
- const AbstractDenseLattice *getLatticeFor(ProgramPoint dependent,
+ /// Get the dense lattice on the given lattice anchor and add dependent as its
+ /// dependency. That is, every time the lattice after anchor is updated, the
+ /// dependent program point must be visited, and the newly triggered visit
+ /// might update the lattice on dependent.
+ const AbstractDenseLattice *getLatticeFor(ProgramPoint *dependent,
LatticeAnchor anchor);
/// Set the dense lattice at control flow entry point and propagate an update
@@ -153,7 +150,7 @@ protected:
/// Visit a program point within a region branch operation with predecessors
/// in it. This can either be an entry block of one of the regions of the
/// parent operation itself.
- void visitRegionBranchOperation(ProgramPoint point,
+ void visitRegionBranchOperation(ProgramPoint *point,
RegionBranchOpInterface branch,
AbstractDenseLattice *after);
@@ -294,14 +291,12 @@ protected:
//===----------------------------------------------------------------------===//
/// Base class for dense backward dataflow analyses. Such analyses attach a
-/// lattice between the execution of operations and implement a transfer
-/// function from the lattice after the operation ot the lattice before it, thus
-/// propagating backward.
+/// lattice to program point and implement a transfer function from the lattice
+/// after the operation to the lattice before it, thus propagating backward.
///
-/// In this implementation, a lattice attached to an operation represents the
-/// state of the program before its execution, and a lattice attached to a block
-/// represents the state of the program before the end of the block, i.e., after
-/// its terminator.
+/// Visit a program point in dense backward data-flow analysis will invoke the
+/// transfer function of the operation following the program point iterator.
+/// Visit a program point at the end of block will visit the block itself.
class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
public:
/// Construct the analysis in the given solver. Takes a symbol table
@@ -321,9 +316,9 @@ public:
/// operations, the state is propagated using the transfer function
/// (visitOperation).
///
- /// Note: the transfer function is currently *not* invoked for operations with
- /// region or call interface, but *is* invoked for block terminators.
- LogicalResult visit(ProgramPoint point) override;
+ /// Note: the transfer function is currently *not* invoked before operations
+ /// with region or call interface, but *is* invoked before block terminators.
+ LogicalResult visit(ProgramPoint *point) override;
protected:
/// Propagate the dense lattice after the execution of an operation to the
@@ -337,10 +332,11 @@ protected:
/// block.
virtual AbstractDenseLattice *getLattice(LatticeAnchor anchor) = 0;
- /// Get the dense lattice before the execution of the program point in
- /// `anchor` and declare that the `dependent` program point must be updated
- /// every time `point` is.
- const AbstractDenseLattice *getLatticeFor(ProgramPoint dependent,
+ /// Get the dense lattice on the given lattice anchor and add dependent as its
+ /// dependency. That is, every time the lattice after anchor is updated, the
+ /// dependent program point must be visited, and the newly triggered visit
+ /// might update the lattice before dependent.
+ const AbstractDenseLattice *getLatticeFor(ProgramPoint *dependent,
LatticeAnchor anchor);
/// Set the dense lattice before at the control flow exit point and propagate
@@ -400,7 +396,7 @@ private:
/// (from which the state is propagated) in or after it. `regionNo` indicates
/// the region that contains the successor, `nullopt` indicating the successor
/// of the branch operation itself.
- void visitRegionBranchOperation(ProgramPoint point,
+ void visitRegionBranchOperation(ProgramPoint *point,
RegionBranchOpInterface branch,
RegionBranchPoint branchPoint,
AbstractDenseLattice *before);
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 933790b..dce7ab3 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -179,18 +179,22 @@ private:
/// operands to the lattices of the results. This analysis will propagate
/// lattices across control-flow edges and the callgraph using liveness
/// information.
+///
+/// Visit a program point in sparse forward data-flow analysis will invoke the
+/// transfer function of the operation preceding the program point iterator.
+/// Visit a program point at the begining of block will visit the block itself.
class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
public:
/// Initialize the analysis by visiting every owner of an SSA value: all
/// operations and blocks.
LogicalResult initialize(Operation *top) override;
- /// Visit a program point. If this is a block and all control-flow
- /// predecessors or callsites are known, then the arguments lattices are
- /// propagated from them. If this is a call operation or an operation with
- /// region control-flow, then its result lattices are set accordingly.
- /// Otherwise, the operation transfer function is invoked.
- LogicalResult visit(ProgramPoint point) override;
+ /// Visit a program point. If this is at beginning of block and all
+ /// control-flow predecessors or callsites are known, then the arguments
+ /// lattices are propagated from them. If this is after call operation or an
+ /// operation with region control-flow, then its result lattices are set
+ /// accordingly. Otherwise, the operation transfer function is invoked.
+ LogicalResult visit(ProgramPoint *point) override;
protected:
explicit AbstractSparseForwardDataFlowAnalysis(DataFlowSolver &solver);
@@ -221,7 +225,7 @@ protected:
/// Get a read-only lattice element for a value and add it as a dependency to
/// a program point.
- const AbstractSparseLattice *getLatticeElementFor(ProgramPoint point,
+ const AbstractSparseLattice *getLatticeElementFor(ProgramPoint *point,
Value value);
/// Set the given lattice element(s) at control flow entry point(s).
@@ -251,7 +255,8 @@ private:
/// operation `branch`, which can either be the entry block of one of the
/// regions or the parent operation itself, and set either the argument or
/// parent result lattices.
- void visitRegionSuccessors(ProgramPoint point, RegionBranchOpInterface branch,
+ void visitRegionSuccessors(ProgramPoint *point,
+ RegionBranchOpInterface branch,
RegionBranchPoint successor,
ArrayRef<AbstractSparseLattice *> lattices);
};
@@ -312,7 +317,7 @@ protected:
/// Get the lattice element for a value and create a dependency on the
/// provided program point.
- const StateT *getLatticeElementFor(ProgramPoint point, Value value) {
+ const StateT *getLatticeElementFor(ProgramPoint *point, Value value) {
return static_cast<const StateT *>(
AbstractSparseForwardDataFlowAnalysis::getLatticeElementFor(point,
value));
@@ -377,10 +382,10 @@ public:
/// under it.
LogicalResult initialize(Operation *top) override;
- /// Visit a program point. If this is a call operation or an operation with
+ /// Visit a program point. If it is after call operation or an operation with
/// block or region control-flow, then operand lattices are set accordingly.
/// Otherwise, invokes the operation transfer function (`visitOperationImpl`).
- LogicalResult visit(ProgramPoint point) override;
+ LogicalResult visit(ProgramPoint *point) override;
protected:
explicit AbstractSparseBackwardDataFlowAnalysis(
@@ -445,14 +450,14 @@ private:
/// Get the lattice element for a value, and also set up
/// dependencies so that the analysis on the given ProgramPoint is re-invoked
/// if the value changes.
- const AbstractSparseLattice *getLatticeElementFor(ProgramPoint point,
+ const AbstractSparseLattice *getLatticeElementFor(ProgramPoint *point,
Value value);
/// Get the lattice elements for a range of values, and also set up
/// dependencies so that the analysis on the given ProgramPoint is re-invoked
/// if any of the values change.
SmallVector<const AbstractSparseLattice *>
- getLatticeElementsFor(ProgramPoint point, ValueRange values);
+ getLatticeElementsFor(ProgramPoint *point, ValueRange values);
SymbolTableCollection &symbolTable;
};
@@ -465,6 +470,10 @@ private:
/// backwards across the IR by implementing transfer functions for operations.
///
/// `StateT` is expected to be a subclass of `AbstractSparseLattice`.
+///
+/// Visit a program point in sparse backward data-flow analysis will invoke the
+/// transfer function of the operation preceding the program point iterator.
+/// Visit a program point at the begining of block will visit the block itself.
template <typename StateT>
class SparseBackwardDataFlowAnalysis
: public AbstractSparseBackwardDataFlowAnalysis {
diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h
index b0450ec..969664d 100644
--- a/mlir/include/mlir/Analysis/DataFlowFramework.h
+++ b/mlir/include/mlir/Analysis/DataFlowFramework.h
@@ -18,10 +18,12 @@
#include "mlir/IR/Operation.h"
#include "mlir/Support/StorageUniquer.h"
+#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/TypeName.h"
#include <queue>
+#include <tuple>
namespace mlir {
@@ -51,23 +53,104 @@ class AnalysisState;
/// Program point represents a specific location in the execution of a program.
/// A sequence of program points can be combined into a control flow graph.
-struct ProgramPoint : public PointerUnion<Operation *, Block *> {
- using ParentTy = PointerUnion<Operation *, Block *>;
- /// Inherit constructors.
- using ParentTy::PointerUnion;
- /// Allow implicit conversion from the parent type.
- ProgramPoint(ParentTy point = nullptr) : ParentTy(point) {}
- /// Allow implicit conversions from operation wrappers.
- /// TODO: For Windows only. Find a better solution.
- template <typename OpT, typename = std::enable_if_t<
- std::is_convertible<OpT, Operation *>::value &&
- !std::is_same<OpT, Operation *>::value>>
- ProgramPoint(OpT op) : ParentTy(op) {}
+struct ProgramPoint : public StorageUniquer::BaseStorage {
+ /// Creates a new program point at the given location.
+ ProgramPoint(Block *parentBlock, Block::iterator pp)
+ : block(parentBlock), point(pp) {}
+
+ /// Creates a new program point at the given operation.
+ ProgramPoint(Operation *op) : op(op) {}
+
+ /// The concrete key type used by the storage uniquer. This class is uniqued
+ /// by its contents.
+ using KeyTy = std::tuple<Block *, Block::iterator, Operation *>;
+
+ /// Create a empty program point.
+ ProgramPoint() {}
+
+ /// Create a new program point from the given program point.
+ ProgramPoint(const ProgramPoint &point) {
+ this->block = point.getBlock();
+ this->point = point.getPoint();
+ this->op = point.getOperation();
+ }
+
+ static ProgramPoint *construct(StorageUniquer::StorageAllocator &alloc,
+ KeyTy &&key) {
+ if (std::get<0>(key)) {
+ return new (alloc.allocate<ProgramPoint>())
+ ProgramPoint(std::get<0>(key), std::get<1>(key));
+ }
+ return new (alloc.allocate<ProgramPoint>()) ProgramPoint(std::get<2>(key));
+ }
+
+ /// Returns true if this program point is set.
+ bool isNull() const { return block == nullptr && op == nullptr; }
+
+ /// Two program points are equal if their block and iterator are equal.
+ bool operator==(const KeyTy &key) const {
+ return block == std::get<0>(key) && point == std::get<1>(key) &&
+ op == std::get<2>(key);
+ }
+
+ bool operator==(const ProgramPoint &pp) const {
+ return block == pp.block && point == pp.point && op == pp.op;
+ }
+
+ /// Get the block contains this program point.
+ Block *getBlock() const { return block; }
+
+ /// Get the the iterator this program point refers to.
+ Block::iterator getPoint() const { return point; }
+
+ /// Get the the iterator this program point refers to.
+ Operation *getOperation() const { return op; }
+
+ /// Get the next operation of this program point.
+ Operation *getNextOp() const {
+ assert(!isBlockEnd());
+ // If the current program point has no parent block, both the next op and
+ // the previous op point to the op corresponding to the current program
+ // point.
+ if (block == nullptr) {
+ return op;
+ }
+ return &*point;
+ }
+
+ /// Get the previous operation of this program point.
+ Operation *getPrevOp() const {
+ assert(!isBlockStart());
+ // If the current program point has no parent block, both the next op and
+ // the previous op point to the op corresponding to the current program
+ // point.
+ if (block == nullptr) {
+ return op;
+ }
+ return &*(--Block::iterator(point));
+ }
+
+ bool isBlockStart() const { return block && block->begin() == point; }
+
+ bool isBlockEnd() const { return block && block->end() == point; }
/// Print the program point.
void print(raw_ostream &os) const;
+
+private:
+ Block *block = nullptr;
+ Block::iterator point;
+
+ /// For operations without a parent block, we record the operation itself as
+ /// its program point.
+ Operation *op = nullptr;
};
+inline raw_ostream &operator<<(raw_ostream &os, ProgramPoint point) {
+ point.print(os);
+ return os;
+}
+
//===----------------------------------------------------------------------===//
// GenericLatticeAnchor
//===----------------------------------------------------------------------===//
@@ -165,21 +248,12 @@ private:
/// Fundamental IR components are supported as first-class lattice anchor.
struct LatticeAnchor
- : public PointerUnion<GenericLatticeAnchor *, ProgramPoint, Value> {
- using ParentTy = PointerUnion<GenericLatticeAnchor *, ProgramPoint, Value>;
+ : public PointerUnion<GenericLatticeAnchor *, ProgramPoint *, Value> {
+ using ParentTy = PointerUnion<GenericLatticeAnchor *, ProgramPoint *, Value>;
/// Inherit constructors.
using ParentTy::PointerUnion;
/// Allow implicit conversion from the parent type.
LatticeAnchor(ParentTy point = nullptr) : ParentTy(point) {}
- /// Allow implicit conversions from operation wrappers.
- /// TODO: For Windows only. Find a better solution.
- template <typename OpT, typename = std::enable_if_t<
- std::is_convertible<OpT, Operation *>::value &&
- !std::is_same<OpT, Operation *>::value>>
- LatticeAnchor(OpT op) : ParentTy(ProgramPoint(op)) {}
-
- LatticeAnchor(Operation *op) : ParentTy(ProgramPoint(op)) {}
- LatticeAnchor(Block *block) : ParentTy(ProgramPoint(block)) {}
/// Print the lattice anchor.
void print(raw_ostream &os) const;
@@ -238,7 +312,9 @@ private:
class DataFlowSolver {
public:
explicit DataFlowSolver(const DataFlowConfig &config = DataFlowConfig())
- : config(config) {}
+ : config(config) {
+ uniquer.registerParametricStorageType<ProgramPoint>();
+ }
/// Load an analysis into the solver. Return the analysis instance.
template <typename AnalysisT, typename... Args>
@@ -277,10 +353,39 @@ public:
return AnchorT::get(uniquer, std::forward<Args>(args)...);
}
+ /// Get a uniqued program point instance.
+ ProgramPoint *getProgramPointBefore(Operation *op) {
+ if (op->getBlock())
+ return uniquer.get<ProgramPoint>(/*initFn*/ {}, op->getBlock(),
+ Block::iterator(op), nullptr);
+ else
+ return uniquer.get<ProgramPoint>(/*initFn*/ {}, nullptr,
+ Block::iterator(), op);
+ }
+
+ ProgramPoint *getProgramPointBefore(Block *block) {
+ return uniquer.get<ProgramPoint>(/*initFn*/ {}, block, block->begin(),
+ nullptr);
+ }
+
+ ProgramPoint *getProgramPointAfter(Operation *op) {
+ if (op->getBlock())
+ return uniquer.get<ProgramPoint>(/*initFn*/ {}, op->getBlock(),
+ ++Block::iterator(op), nullptr);
+ else
+ return uniquer.get<ProgramPoint>(/*initFn*/ {}, nullptr,
+ Block::iterator(), op);
+ }
+
+ ProgramPoint *getProgramPointAfter(Block *block) {
+ return uniquer.get<ProgramPoint>(/*initFn*/ {}, block, block->end(),
+ nullptr);
+ }
+
/// A work item on the solver queue is a program point, child analysis pair.
/// Each item is processed by invoking the child analysis at the program
/// point.
- using WorkItem = std::pair<ProgramPoint, DataFlowAnalysis *>;
+ using WorkItem = std::pair<ProgramPoint *, DataFlowAnalysis *>;
/// Push a work item onto the worklist.
void enqueue(WorkItem item) { worklist.push(std::move(item)); }
@@ -343,7 +448,7 @@ class AnalysisState {
public:
virtual ~AnalysisState();
- /// Create the analysis state at the given lattice anchor.
+ /// Create the analysis state on the given lattice anchor.
AnalysisState(LatticeAnchor anchor) : anchor(anchor) {}
/// Returns the lattice anchor this state is located at.
@@ -356,7 +461,7 @@ public:
/// Add a dependency to this analysis state on a lattice anchor and an
/// analysis. If this state is updated, the analysis will be invoked on the
/// given lattice anchor again (in onUpdate()).
- void addDependency(ProgramPoint point, DataFlowAnalysis *analysis);
+ void addDependency(ProgramPoint *point, DataFlowAnalysis *analysis);
protected:
/// This function is called by the solver when the analysis state is updated
@@ -446,12 +551,12 @@ public:
/// `visit` can add new dependencies, but these dependencies will not be
/// dynamically added to the worklist because the solver doesn't know what
/// will provide a value for then.
- virtual LogicalResult visit(ProgramPoint point) = 0;
+ virtual LogicalResult visit(ProgramPoint *point) = 0;
protected:
/// Create a dependency between the given analysis state and lattice anchor
/// on this analysis.
- void addDependency(AnalysisState *state, ProgramPoint point);
+ void addDependency(AnalysisState *state, ProgramPoint *point);
/// Propagate an update to a state if it changed.
void propagateIfChanged(AnalysisState *state, ChangeResult changed);
@@ -480,12 +585,29 @@ protected:
/// on `dependent`. If the return state is updated elsewhere, this analysis is
/// re-invoked on the dependent.
template <typename StateT, typename AnchorT>
- const StateT *getOrCreateFor(ProgramPoint dependent, AnchorT anchor) {
+ const StateT *getOrCreateFor(ProgramPoint *dependent, AnchorT anchor) {
StateT *state = getOrCreate<StateT>(anchor);
addDependency(state, dependent);
return state;
}
+ /// Get a uniqued program point instance.
+ ProgramPoint *getProgramPointBefore(Operation *op) {
+ return solver.getProgramPointBefore(op);
+ }
+
+ ProgramPoint *getProgramPointBefore(Block *block) {
+ return solver.getProgramPointBefore(block);
+ }
+
+ ProgramPoint *getProgramPointAfter(Operation *op) {
+ return solver.getProgramPointAfter(op);
+ }
+
+ ProgramPoint *getProgramPointAfter(Block *block) {
+ return solver.getProgramPointAfter(block);
+ }
+
/// Return the configuration of the solver used for this analysis.
const DataFlowConfig &getSolverConfig() const { return solver.getConfig(); }
@@ -539,12 +661,30 @@ inline raw_ostream &operator<<(raw_ostream &os, LatticeAnchor anchor) {
namespace llvm {
/// Allow hashing of lattice anchors and program points.
template <>
-struct DenseMapInfo<mlir::LatticeAnchor>
- : public DenseMapInfo<mlir::LatticeAnchor::ParentTy> {};
+struct DenseMapInfo<mlir::ProgramPoint> {
+ static mlir::ProgramPoint getEmptyKey() {
+ void *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
+ return mlir::ProgramPoint(
+ (mlir::Block *)pointer,
+ mlir::Block::iterator((mlir::Operation *)pointer));
+ }
+ static mlir::ProgramPoint getTombstoneKey() {
+ void *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
+ return mlir::ProgramPoint(
+ (mlir::Block *)pointer,
+ mlir::Block::iterator((mlir::Operation *)pointer));
+ }
+ static unsigned getHashValue(mlir::ProgramPoint pp) {
+ return hash_combine(pp.getBlock(), pp.getPoint().getNodePtr());
+ }
+ static bool isEqual(mlir::ProgramPoint lhs, mlir::ProgramPoint rhs) {
+ return lhs == rhs;
+ }
+};
template <>
-struct DenseMapInfo<mlir::ProgramPoint>
- : public DenseMapInfo<mlir::ProgramPoint::ParentTy> {};
+struct DenseMapInfo<mlir::LatticeAnchor>
+ : public DenseMapInfo<mlir::LatticeAnchor::ParentTy> {};
// Allow llvm::cast style functions.
template <typename To>
@@ -555,19 +695,6 @@ template <typename To>
struct CastInfo<To, const mlir::LatticeAnchor>
: public CastInfo<To, const mlir::LatticeAnchor::PointerUnion> {};
-template <typename To>
-struct CastInfo<To, mlir::ProgramPoint>
- : public CastInfo<To, mlir::ProgramPoint::PointerUnion> {};
-
-template <typename To>
-struct CastInfo<To, const mlir::ProgramPoint>
- : public CastInfo<To, const mlir::ProgramPoint::PointerUnion> {};
-
-/// Allow stealing the low bits of a ProgramPoint.
-template <>
-struct PointerLikeTypeTraits<mlir::ProgramPoint>
- : public PointerLikeTypeTraits<mlir::ProgramPoint::ParentTy> {};
-
} // end namespace llvm
#endif // MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H
diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
index 67825eb..536cbf9 100644
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -406,4 +406,25 @@ private:
raw_ostream &operator<<(raw_ostream &, Block &);
} // namespace mlir
+namespace llvm {
+template <>
+struct DenseMapInfo<mlir::Block::iterator> {
+ static mlir::Block::iterator getEmptyKey() {
+ void *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
+ return mlir::Block::iterator((mlir::Operation *)pointer);
+ }
+ static mlir::Block::iterator getTombstoneKey() {
+ void *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
+ return mlir::Block::iterator((mlir::Operation *)pointer);
+ }
+ static unsigned getHashValue(mlir::Block::iterator iter) {
+ return hash_value(iter.getNodePtr());
+ }
+ static bool isEqual(mlir::Block::iterator lhs, mlir::Block::iterator rhs) {
+ return lhs == rhs;
+ }
+};
+
+} // end namespace llvm
+
#endif // MLIR_IR_BLOCK_H
diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
index beb6801..3c190d4 100644
--- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
@@ -46,22 +46,23 @@ void Executable::print(raw_ostream &os) const {
void Executable::onUpdate(DataFlowSolver *solver) const {
AnalysisState::onUpdate(solver);
- if (ProgramPoint pp = llvm::dyn_cast_if_present<ProgramPoint>(anchor)) {
- if (Block *block = llvm::dyn_cast_if_present<Block *>(pp)) {
+ if (ProgramPoint *pp = llvm::dyn_cast_if_present<ProgramPoint *>(anchor)) {
+ if (pp->isBlockStart()) {
// Re-invoke the analyses on the block itself.
for (DataFlowAnalysis *analysis : subscribers)
- solver->enqueue({block, analysis});
+ solver->enqueue({pp, analysis});
// Re-invoke the analyses on all operations in the block.
for (DataFlowAnalysis *analysis : subscribers)
- for (Operation &op : *block)
- solver->enqueue({&op, analysis});
+ for (Operation &op : *pp->getBlock())
+ solver->enqueue({solver->getProgramPointAfter(&op), analysis});
}
} else if (auto *latticeAnchor =
llvm::dyn_cast_if_present<GenericLatticeAnchor *>(anchor)) {
// Re-invoke the analysis on the successor block.
if (auto *edge = dyn_cast<CFGEdge>(latticeAnchor)) {
for (DataFlowAnalysis *analysis : subscribers)
- solver->enqueue({edge->getTo(), analysis});
+ solver->enqueue(
+ {solver->getProgramPointBefore(edge->getTo()), analysis});
}
}
}
@@ -125,7 +126,8 @@ LogicalResult DeadCodeAnalysis::initialize(Operation *top) {
for (Region &region : top->getRegions()) {
if (region.empty())
continue;
- auto *state = getOrCreate<Executable>(&region.front());
+ auto *state =
+ getOrCreate<Executable>(getProgramPointBefore(&region.front()));
propagateIfChanged(state, state->setToLive());
}
@@ -154,7 +156,8 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
// Public symbol callables or those for which we can't see all uses have
// potentially unknown callsites.
if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) {
- auto *state = getOrCreate<PredecessorState>(callable);
+ auto *state =
+ getOrCreate<PredecessorState>(getProgramPointAfter(callable));
propagateIfChanged(state, state->setHasUnknownPredecessors());
}
foundSymbolCallable = true;
@@ -171,7 +174,8 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
// If we couldn't gather the symbol uses, conservatively assume that
// we can't track information for any nested symbols.
return top->walk([&](CallableOpInterface callable) {
- auto *state = getOrCreate<PredecessorState>(callable);
+ auto *state =
+ getOrCreate<PredecessorState>(getProgramPointAfter(callable));
propagateIfChanged(state, state->setHasUnknownPredecessors());
});
}
@@ -182,7 +186,7 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
// If a callable symbol has a non-call use, then we can't be guaranteed to
// know all callsites.
Operation *symbol = symbolTable.lookupSymbolIn(top, use.getSymbolRef());
- auto *state = getOrCreate<PredecessorState>(symbol);
+ auto *state = getOrCreate<PredecessorState>(getProgramPointAfter(symbol));
propagateIfChanged(state, state->setHasUnknownPredecessors());
}
};
@@ -193,7 +197,7 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
/// Returns true if the operation is a returning terminator in region
/// control-flow or the terminator of a callable region.
static bool isRegionOrCallableReturn(Operation *op) {
- return !op->getNumSuccessors() &&
+ return op->getBlock() != nullptr && !op->getNumSuccessors() &&
isa<RegionBranchOpInterface, CallableOpInterface>(op->getParentOp()) &&
op->getBlock()->getTerminator() == op;
}
@@ -205,9 +209,10 @@ LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
// When the liveness of the parent block changes, make sure to re-invoke the
// analysis on the op.
if (op->getBlock())
- getOrCreate<Executable>(op->getBlock())->blockContentSubscribe(this);
+ getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))
+ ->blockContentSubscribe(this);
// Visit the op.
- if (failed(visit(op)))
+ if (failed(visit(getProgramPointAfter(op))))
return failure();
}
// Recurse on nested operations.
@@ -219,7 +224,7 @@ LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
}
void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) {
- auto *state = getOrCreate<Executable>(to);
+ auto *state = getOrCreate<Executable>(getProgramPointBefore(to));
propagateIfChanged(state, state->setToLive());
auto *edgeState =
getOrCreate<Executable>(getLatticeAnchor<CFGEdge>(from, to));
@@ -230,18 +235,20 @@ void DeadCodeAnalysis::markEntryBlocksLive(Operation *op) {
for (Region &region : op->getRegions()) {
if (region.empty())
continue;
- auto *state = getOrCreate<Executable>(&region.front());
+ auto *state =
+ getOrCreate<Executable>(getProgramPointBefore(&region.front()));
propagateIfChanged(state, state->setToLive());
}
}
-LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) {
- if (point.is<Block *>())
+LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {
+ if (point->isBlockStart())
return success();
- auto *op = point.get<Operation *>();
+ Operation *op = point->getPrevOp();
// If the parent block is not executable, there is nothing to do.
- if (!getOrCreate<Executable>(op->getBlock())->isLive())
+ if (op->getBlock() != nullptr &&
+ !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))->isLive())
return success();
// We have a live call op. Add this as a live predecessor of the callee.
@@ -256,7 +263,8 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) {
// Check if this is a callable operation.
} else if (auto callable = dyn_cast<CallableOpInterface>(op)) {
- const auto *callsites = getOrCreateFor<PredecessorState>(op, callable);
+ const auto *callsites = getOrCreateFor<PredecessorState>(
+ getProgramPointAfter(op), getProgramPointAfter(callable));
// If the callsites could not be resolved or are known to be non-empty,
// mark the callable as executable.
@@ -316,11 +324,13 @@ void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
if (isa_and_nonnull<SymbolOpInterface>(callableOp) &&
!isExternalCallable(callableOp)) {
// Add the live callsite.
- auto *callsites = getOrCreate<PredecessorState>(callableOp);
+ auto *callsites =
+ getOrCreate<PredecessorState>(getProgramPointAfter(callableOp));
propagateIfChanged(callsites, callsites->join(call));
} else {
// Mark this call op's predecessors as overdefined.
- auto *predecessors = getOrCreate<PredecessorState>(call);
+ auto *predecessors =
+ getOrCreate<PredecessorState>(getProgramPointAfter(call));
propagateIfChanged(predecessors, predecessors->setHasUnknownPredecessors());
}
}
@@ -378,9 +388,10 @@ void DeadCodeAnalysis::visitRegionBranchOperation(
branch.getEntrySuccessorRegions(*operands, successors);
for (const RegionSuccessor &successor : successors) {
// The successor can be either an entry block or the parent operation.
- ProgramPoint point = successor.getSuccessor()
- ? &successor.getSuccessor()->front()
- : ProgramPoint(branch);
+ ProgramPoint *point =
+ successor.getSuccessor()
+ ? getProgramPointBefore(&successor.getSuccessor()->front())
+ : getProgramPointAfter(branch);
// Mark the entry block as executable.
auto *state = getOrCreate<Executable>(point);
propagateIfChanged(state, state->setToLive());
@@ -409,12 +420,15 @@ void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
for (const RegionSuccessor &successor : successors) {
PredecessorState *predecessors;
if (Region *region = successor.getSuccessor()) {
- auto *state = getOrCreate<Executable>(&region->front());
+ auto *state =
+ getOrCreate<Executable>(getProgramPointBefore(&region->front()));
propagateIfChanged(state, state->setToLive());
- predecessors = getOrCreate<PredecessorState>(&region->front());
+ predecessors = getOrCreate<PredecessorState>(
+ getProgramPointBefore(&region->front()));
} else {
// Add this terminator as a predecessor to the parent op.
- predecessors = getOrCreate<PredecessorState>(branch);
+ predecessors =
+ getOrCreate<PredecessorState>(getProgramPointAfter(branch));
}
propagateIfChanged(predecessors,
predecessors->join(op, successor.getSuccessorInputs()));
@@ -424,11 +438,13 @@ void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
void DeadCodeAnalysis::visitCallableTerminator(Operation *op,
CallableOpInterface callable) {
// Add as predecessors to all callsites this return op.
- auto *callsites = getOrCreateFor<PredecessorState>(op, callable);
+ auto *callsites = getOrCreateFor<PredecessorState>(
+ getProgramPointAfter(op), getProgramPointAfter(callable));
bool canResolve = op->hasTrait<OpTrait::ReturnLike>();
for (Operation *predecessor : callsites->getKnownPredecessors()) {
assert(isa<CallOpInterface>(predecessor));
- auto *predecessors = getOrCreate<PredecessorState>(predecessor);
+ auto *predecessors =
+ getOrCreate<PredecessorState>(getProgramPointAfter(predecessor));
if (canResolve) {
propagateIfChanged(predecessors, predecessors->join(op));
} else {
diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
index 300c6e5..340aa39 100644
--- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
@@ -44,10 +44,10 @@ LogicalResult AbstractDenseForwardDataFlowAnalysis::initialize(Operation *top) {
return success();
}
-LogicalResult AbstractDenseForwardDataFlowAnalysis::visit(ProgramPoint point) {
- if (auto *op = llvm::dyn_cast_if_present<Operation *>(point))
- return processOperation(op);
- visitBlock(point.get<Block *>());
+LogicalResult AbstractDenseForwardDataFlowAnalysis::visit(ProgramPoint *point) {
+ if (!point->isBlockStart())
+ return processOperation(point->getPrevOp());
+ visitBlock(point->getBlock());
return success();
}
@@ -64,8 +64,8 @@ void AbstractDenseForwardDataFlowAnalysis::visitCallOperation(
call, CallControlFlowAction::ExternalCallee, before, after);
}
- const auto *predecessors =
- getOrCreateFor<PredecessorState>(call.getOperation(), call);
+ const auto *predecessors = getOrCreateFor<PredecessorState>(
+ getProgramPointAfter(call.getOperation()), getProgramPointAfter(call));
// Otherwise, if not all return sites are known, then conservatively assume we
// can't reason about the data-flow.
if (!predecessors->allPredecessorsKnown())
@@ -87,7 +87,8 @@ void AbstractDenseForwardDataFlowAnalysis::visitCallOperation(
// }
AbstractDenseLattice *latticeAfterCall = after;
const AbstractDenseLattice *latticeAtCalleeReturn =
- getLatticeFor(call.getOperation(), predecessor);
+ getLatticeFor(getProgramPointAfter(call.getOperation()),
+ getProgramPointAfter(predecessor));
visitCallControlFlowTransfer(call, CallControlFlowAction::ExitCallee,
*latticeAtCalleeReturn, latticeAfterCall);
}
@@ -95,24 +96,24 @@ void AbstractDenseForwardDataFlowAnalysis::visitCallOperation(
LogicalResult
AbstractDenseForwardDataFlowAnalysis::processOperation(Operation *op) {
+ ProgramPoint *point = getProgramPointAfter(op);
// If the containing block is not executable, bail out.
- if (!getOrCreateFor<Executable>(op, op->getBlock())->isLive())
+ if (op->getBlock() != nullptr &&
+ !getOrCreateFor<Executable>(point, getProgramPointBefore(op->getBlock()))
+ ->isLive())
return success();
// Get the dense lattice to update.
- AbstractDenseLattice *after = getLattice(op);
+ AbstractDenseLattice *after = getLattice(point);
// Get the dense state before the execution of the op.
- const AbstractDenseLattice *before;
- if (Operation *prev = op->getPrevNode())
- before = getLatticeFor(op, prev);
- else
- before = getLatticeFor(op, op->getBlock());
+ const AbstractDenseLattice *before =
+ getLatticeFor(point, getProgramPointBefore(op));
// If this op implements region control-flow, then control-flow dictates its
// transfer function.
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
- visitRegionBranchOperation(op, branch, after);
+ visitRegionBranchOperation(point, branch, after);
return success();
}
@@ -129,11 +130,12 @@ AbstractDenseForwardDataFlowAnalysis::processOperation(Operation *op) {
void AbstractDenseForwardDataFlowAnalysis::visitBlock(Block *block) {
// If the block is not executable, bail out.
- if (!getOrCreateFor<Executable>(block, block)->isLive())
+ ProgramPoint *point = getProgramPointBefore(block);
+ if (!getOrCreateFor<Executable>(point, point)->isLive())
return;
// Get the dense lattice to update.
- AbstractDenseLattice *after = getLattice(block);
+ AbstractDenseLattice *after = getLattice(point);
// The dense lattices of entry blocks are set by region control-flow or the
// callgraph.
@@ -141,7 +143,8 @@ void AbstractDenseForwardDataFlowAnalysis::visitBlock(Block *block) {
// Check if this block is the entry block of a callable region.
auto callable = dyn_cast<CallableOpInterface>(block->getParentOp());
if (callable && callable.getCallableRegion() == block->getParent()) {
- const auto *callsites = getOrCreateFor<PredecessorState>(block, callable);
+ const auto *callsites = getOrCreateFor<PredecessorState>(
+ point, getProgramPointAfter(callable));
// If not all callsites are known, conservatively mark all lattices as
// having reached their pessimistic fixpoints. Do the same if
// interprocedural analysis is not enabled.
@@ -151,10 +154,7 @@ void AbstractDenseForwardDataFlowAnalysis::visitBlock(Block *block) {
for (Operation *callsite : callsites->getKnownPredecessors()) {
// Get the dense lattice before the callsite.
const AbstractDenseLattice *before;
- if (Operation *prev = callsite->getPrevNode())
- before = getLatticeFor(block, prev);
- else
- before = getLatticeFor(block, callsite->getBlock());
+ before = getLatticeFor(point, getProgramPointBefore(callsite));
visitCallControlFlowTransfer(cast<CallOpInterface>(callsite),
CallControlFlowAction::EnterCallee,
@@ -165,7 +165,7 @@ void AbstractDenseForwardDataFlowAnalysis::visitBlock(Block *block) {
// Check if we can reason about the control-flow.
if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp()))
- return visitRegionBranchOperation(block, branch, after);
+ return visitRegionBranchOperation(point, branch, after);
// Otherwise, we can't reason about the data-flow.
return setToEntryState(after);
@@ -177,17 +177,18 @@ void AbstractDenseForwardDataFlowAnalysis::visitBlock(Block *block) {
// Skip control edges that aren't executable.
Block *predecessor = *it;
if (!getOrCreateFor<Executable>(
- block, getLatticeAnchor<CFGEdge>(predecessor, block))
+ point, getLatticeAnchor<CFGEdge>(predecessor, block))
->isLive())
continue;
// Merge in the state from the predecessor's terminator.
- join(after, *getLatticeFor(block, predecessor->getTerminator()));
+ join(after, *getLatticeFor(
+ point, getProgramPointAfter(predecessor->getTerminator())));
}
}
void AbstractDenseForwardDataFlowAnalysis::visitRegionBranchOperation(
- ProgramPoint point, RegionBranchOpInterface branch,
+ ProgramPoint *point, RegionBranchOpInterface branch,
AbstractDenseLattice *after) {
// Get the terminator predecessors.
const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
@@ -198,19 +199,15 @@ void AbstractDenseForwardDataFlowAnalysis::visitRegionBranchOperation(
const AbstractDenseLattice *before;
// If the predecessor is the parent, get the state before the parent.
if (op == branch) {
- if (Operation *prev = op->getPrevNode())
- before = getLatticeFor(point, prev);
- else
- before = getLatticeFor(point, op->getBlock());
-
+ before = getLatticeFor(point, getProgramPointBefore(op));
// Otherwise, get the state after the terminator.
} else {
- before = getLatticeFor(point, op);
+ before = getLatticeFor(point, getProgramPointAfter(op));
}
// This function is called in two cases:
- // 1. when visiting the block (point = block);
- // 2. when visiting the parent operation (point = parent op).
+ // 1. when visiting the block (point = block start);
+ // 2. when visiting the parent operation (point = iter after parent op).
// In both cases, we are looking for predecessor operations of the point,
// 1. predecessor may be the terminator of another block from another
// region (assuming that the block does belong to another region via an
@@ -224,12 +221,12 @@ void AbstractDenseForwardDataFlowAnalysis::visitRegionBranchOperation(
std::optional<unsigned> regionFrom =
op == branch ? std::optional<unsigned>()
: op->getBlock()->getParent()->getRegionNumber();
- if (auto *toBlock = point.dyn_cast<Block *>()) {
- unsigned regionTo = toBlock->getParent()->getRegionNumber();
+ if (point->isBlockStart()) {
+ unsigned regionTo = point->getBlock()->getParent()->getRegionNumber();
visitRegionBranchControlFlowTransfer(branch, regionFrom, regionTo,
*before, after);
} else {
- assert(point.get<Operation *>() == branch &&
+ assert(point->getPrevOp() == branch &&
"expected to be visiting the branch itself");
// Only need to call the arc transfer when the predecessor is the region
// or the op itself, not the previous op.
@@ -244,7 +241,7 @@ void AbstractDenseForwardDataFlowAnalysis::visitRegionBranchOperation(
}
const AbstractDenseLattice *
-AbstractDenseForwardDataFlowAnalysis::getLatticeFor(ProgramPoint dependent,
+AbstractDenseForwardDataFlowAnalysis::getLatticeFor(ProgramPoint *dependent,
LatticeAnchor anchor) {
AbstractDenseLattice *state = getLattice(anchor);
addDependency(state, dependent);
@@ -273,10 +270,11 @@ AbstractDenseBackwardDataFlowAnalysis::initialize(Operation *top) {
return success();
}
-LogicalResult AbstractDenseBackwardDataFlowAnalysis::visit(ProgramPoint point) {
- if (auto *op = llvm::dyn_cast_if_present<Operation *>(point))
- return processOperation(op);
- visitBlock(point.get<Block *>());
+LogicalResult
+AbstractDenseBackwardDataFlowAnalysis::visit(ProgramPoint *point) {
+ if (!point->isBlockEnd())
+ return processOperation(point->getNextOp());
+ visitBlock(point->getBlock());
return success();
}
@@ -316,11 +314,9 @@ void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation(
// ...
// }
Block *calleeEntryBlock = &region->front();
- ProgramPoint calleeEntry = calleeEntryBlock->empty()
- ? ProgramPoint(calleeEntryBlock)
- : &calleeEntryBlock->front();
+ ProgramPoint *calleeEntry = getProgramPointBefore(calleeEntryBlock);
const AbstractDenseLattice &latticeAtCalleeEntry =
- *getLatticeFor(call.getOperation(), calleeEntry);
+ *getLatticeFor(getProgramPointBefore(call.getOperation()), calleeEntry);
AbstractDenseLattice *latticeBeforeCall = before;
visitCallControlFlowTransfer(call, CallControlFlowAction::EnterCallee,
latticeAtCalleeEntry, latticeBeforeCall);
@@ -328,23 +324,24 @@ void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation(
LogicalResult
AbstractDenseBackwardDataFlowAnalysis::processOperation(Operation *op) {
+ ProgramPoint *point = getProgramPointBefore(op);
// If the containing block is not executable, bail out.
- if (!getOrCreateFor<Executable>(op, op->getBlock())->isLive())
+ if (op->getBlock() != nullptr &&
+ !getOrCreateFor<Executable>(point, getProgramPointBefore(op->getBlock()))
+ ->isLive())
return success();
// Get the dense lattice to update.
- AbstractDenseLattice *before = getLattice(op);
+ AbstractDenseLattice *before = getLattice(point);
// Get the dense state after execution of this op.
- const AbstractDenseLattice *after;
- if (Operation *next = op->getNextNode())
- after = getLatticeFor(op, next);
- else
- after = getLatticeFor(op, op->getBlock());
+ const AbstractDenseLattice *after =
+ getLatticeFor(point, getProgramPointAfter(op));
// Special cases where control flow may dictate data flow.
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
- visitRegionBranchOperation(op, branch, RegionBranchPoint::parent(), before);
+ visitRegionBranchOperation(point, branch, RegionBranchPoint::parent(),
+ before);
return success();
}
if (auto call = dyn_cast<CallOpInterface>(op)) {
@@ -357,11 +354,13 @@ AbstractDenseBackwardDataFlowAnalysis::processOperation(Operation *op) {
}
void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
+ ProgramPoint *point = getProgramPointAfter(block);
// If the block is not executable, bail out.
- if (!getOrCreateFor<Executable>(block, block)->isLive())
+ if (!getOrCreateFor<Executable>(point, getProgramPointBefore(block))
+ ->isLive())
return;
- AbstractDenseLattice *before = getLattice(block);
+ AbstractDenseLattice *before = getLattice(point);
// We need "exit" blocks, i.e. the blocks that may return control to the
// parent operation.
@@ -382,7 +381,8 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
// themselves are predecessors of the callable.
auto callable = dyn_cast<CallableOpInterface>(block->getParentOp());
if (callable && callable.getCallableRegion() == block->getParent()) {
- const auto *callsites = getOrCreateFor<PredecessorState>(block, callable);
+ const auto *callsites = getOrCreateFor<PredecessorState>(
+ point, getProgramPointAfter(callable));
// If not all call sites are known, conservative mark all lattices as
// having reached their pessimistic fix points.
if (!callsites->allPredecessorsKnown() ||
@@ -391,11 +391,8 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
}
for (Operation *callsite : callsites->getKnownPredecessors()) {
- const AbstractDenseLattice *after;
- if (Operation *next = callsite->getNextNode())
- after = getLatticeFor(block, next);
- else
- after = getLatticeFor(block, callsite->getBlock());
+ const AbstractDenseLattice *after =
+ getLatticeFor(point, getProgramPointAfter(callsite));
visitCallControlFlowTransfer(cast<CallOpInterface>(callsite),
CallControlFlowAction::ExitCallee, *after,
before);
@@ -406,7 +403,7 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
// If this block is exiting from an operation with region-based control
// flow, propagate the lattice back along the control flow edge.
if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
- visitRegionBranchOperation(block, branch, block->getParent(), before);
+ visitRegionBranchOperation(point, branch, block->getParent(), before);
return;
}
@@ -417,22 +414,19 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
// Meet the state with the state before block's successors.
for (Block *successor : block->getSuccessors()) {
- if (!getOrCreateFor<Executable>(block,
+ if (!getOrCreateFor<Executable>(point,
getLatticeAnchor<CFGEdge>(block, successor))
->isLive())
continue;
// Merge in the state from the successor: either the first operation, or the
// block itself when empty.
- if (successor->empty())
- meet(before, *getLatticeFor(block, successor));
- else
- meet(before, *getLatticeFor(block, &successor->front()));
+ meet(before, *getLatticeFor(point, getProgramPointBefore(successor)));
}
}
void AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchOperation(
- ProgramPoint point, RegionBranchOpInterface branch,
+ ProgramPoint *point, RegionBranchOpInterface branch,
RegionBranchPoint branchPoint, AbstractDenseLattice *before) {
// The successors of the operation may be either the first operation of the
@@ -443,22 +437,18 @@ void AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchOperation(
for (const RegionSuccessor &successor : successors) {
const AbstractDenseLattice *after;
if (successor.isParent() || successor.getSuccessor()->empty()) {
- if (Operation *next = branch->getNextNode())
- after = getLatticeFor(point, next);
- else
- after = getLatticeFor(point, branch->getBlock());
+ after = getLatticeFor(point, getProgramPointAfter(branch));
} else {
Region *successorRegion = successor.getSuccessor();
assert(!successorRegion->empty() && "unexpected empty successor region");
Block *successorBlock = &successorRegion->front();
- if (!getOrCreateFor<Executable>(point, successorBlock)->isLive())
+ if (!getOrCreateFor<Executable>(point,
+ getProgramPointBefore(successorBlock))
+ ->isLive())
continue;
- if (successorBlock->empty())
- after = getLatticeFor(point, successorBlock);
- else
- after = getLatticeFor(point, &successorBlock->front());
+ after = getLatticeFor(point, getProgramPointBefore(successorBlock));
}
visitRegionBranchControlFlowTransfer(branch, branchPoint, successor, *after,
@@ -467,7 +457,7 @@ void AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchOperation(
}
const AbstractDenseLattice *
-AbstractDenseBackwardDataFlowAnalysis::getLatticeFor(ProgramPoint dependent,
+AbstractDenseBackwardDataFlowAnalysis::getLatticeFor(ProgramPoint *dependent,
LatticeAnchor anchor) {
AbstractDenseLattice *state = getLattice(anchor);
addDependency(state, dependent);
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index 9a95f17..bf9eabb 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -111,7 +111,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
auto argRanges = llvm::map_to_vector(op->getOperands(), [&](Value value) {
- return getLatticeElementFor(op, value)->getValue();
+ return getLatticeElementFor(getProgramPointAfter(op), value)->getValue();
});
auto joinCallback = [&](Value v, const IntegerValueRange &attrs) {
@@ -159,7 +159,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
return bound.getValue();
} else if (auto value = llvm::dyn_cast_if_present<Value>(*loopBound)) {
const IntegerValueRangeLattice *lattice =
- getLatticeElementFor(op, value);
+ getLatticeElementFor(getProgramPointAfter(op), value);
if (lattice != nullptr && !lattice->getValue().isUninitialized())
return getUpper ? lattice->getValue().getValue().smax()
: lattice->getValue().getValue().smin();
diff --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
index 57a4d4a..9fb4d9d 100644
--- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
@@ -87,7 +87,7 @@ LivenessAnalysis::visitOperation(Operation *op, ArrayRef<Liveness *> operands,
meet(operand, *r);
foundLiveResult = true;
}
- addDependency(const_cast<Liveness *>(r), op);
+ addDependency(const_cast<Liveness *>(r), getProgramPointAfter(op));
}
return success();
}
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 1bd6def..67cf8c9 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -36,7 +36,7 @@ void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const {
// Push all users of the value to the queue.
for (Operation *user : anchor.get<Value>().getUsers())
for (DataFlowAnalysis *analysis : useDefSubscribers)
- solver->enqueue({user, analysis});
+ solver->enqueue({solver->getProgramPointAfter(user), analysis});
}
//===----------------------------------------------------------------------===//
@@ -72,7 +72,8 @@ AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) {
for (Region &region : op->getRegions()) {
for (Block &block : region) {
- getOrCreate<Executable>(&block)->blockContentSubscribe(this);
+ getOrCreate<Executable>(getProgramPointBefore(&block))
+ ->blockContentSubscribe(this);
visitBlock(&block);
for (Operation &op : block)
if (failed(initializeRecursively(&op)))
@@ -83,10 +84,11 @@ AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) {
return success();
}
-LogicalResult AbstractSparseForwardDataFlowAnalysis::visit(ProgramPoint point) {
- if (Operation *op = llvm::dyn_cast_if_present<Operation *>(point))
- return visitOperation(op);
- visitBlock(point.get<Block *>());
+LogicalResult
+AbstractSparseForwardDataFlowAnalysis::visit(ProgramPoint *point) {
+ if (!point->isBlockStart())
+ return visitOperation(point->getPrevOp());
+ visitBlock(point->getBlock());
return success();
}
@@ -97,7 +99,8 @@ AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
return success();
// If the containing block is not executable, bail out.
- if (!getOrCreate<Executable>(op->getBlock())->isLive())
+ if (op->getBlock() != nullptr &&
+ !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))->isLive())
return success();
// Get the result lattices.
@@ -110,7 +113,7 @@ AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
// The results of a region branch operation are determined by control-flow.
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
- visitRegionSuccessors({branch}, branch,
+ visitRegionSuccessors(getProgramPointAfter(branch), branch,
/*successor=*/RegionBranchPoint::parent(),
resultLattices);
return success();
@@ -138,7 +141,8 @@ AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
// Otherwise, the results of a call operation are determined by the
// callgraph.
- const auto *predecessors = getOrCreateFor<PredecessorState>(op, call);
+ const auto *predecessors = getOrCreateFor<PredecessorState>(
+ getProgramPointAfter(op), getProgramPointAfter(call));
// If not all return sites are known, then conservatively assume we can't
// reason about the data-flow.
if (!predecessors->allPredecessorsKnown()) {
@@ -148,7 +152,8 @@ AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
for (Operation *predecessor : predecessors->getKnownPredecessors())
for (auto &&[operand, resLattice] :
llvm::zip(predecessor->getOperands(), resultLattices))
- join(resLattice, *getLatticeElementFor(op, operand));
+ join(resLattice,
+ *getLatticeElementFor(getProgramPointAfter(op), operand));
return success();
}
@@ -162,7 +167,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
return;
// If the block is not executable, bail out.
- if (!getOrCreate<Executable>(block)->isLive())
+ if (!getOrCreate<Executable>(getProgramPointBefore(block))->isLive())
return;
// Get the argument lattices.
@@ -179,7 +184,8 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
// Check if this block is the entry block of a callable region.
auto callable = dyn_cast<CallableOpInterface>(block->getParentOp());
if (callable && callable.getCallableRegion() == block->getParent()) {
- const auto *callsites = getOrCreateFor<PredecessorState>(block, callable);
+ const auto *callsites = getOrCreateFor<PredecessorState>(
+ getProgramPointBefore(block), getProgramPointAfter(callable));
// If not all callsites are known, conservatively mark all lattices as
// having reached their pessimistic fixpoints.
if (!callsites->allPredecessorsKnown() ||
@@ -189,15 +195,17 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
for (Operation *callsite : callsites->getKnownPredecessors()) {
auto call = cast<CallOpInterface>(callsite);
for (auto it : llvm::zip(call.getArgOperands(), argLattices))
- join(std::get<1>(it), *getLatticeElementFor(block, std::get<0>(it)));
+ join(std::get<1>(it),
+ *getLatticeElementFor(getProgramPointBefore(block),
+ std::get<0>(it)));
}
return;
}
// Check if the lattices can be determined from region control flow.
if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
- return visitRegionSuccessors(block, branch, block->getParent(),
- argLattices);
+ return visitRegionSuccessors(getProgramPointBefore(block), branch,
+ block->getParent(), argLattices);
}
// Otherwise, we can't reason about the data-flow.
@@ -226,7 +234,8 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
branch.getSuccessorOperands(it.getSuccessorIndex());
for (auto [idx, lattice] : llvm::enumerate(argLattices)) {
if (Value operand = operands[idx]) {
- join(lattice, *getLatticeElementFor(block, operand));
+ join(lattice,
+ *getLatticeElementFor(getProgramPointBefore(block), operand));
} else {
// Conservatively consider internally produced arguments as entry
// points.
@@ -240,7 +249,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
}
void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
- ProgramPoint point, RegionBranchOpInterface branch,
+ ProgramPoint *point, RegionBranchOpInterface branch,
RegionBranchPoint successor, ArrayRef<AbstractSparseLattice *> lattices) {
const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
assert(predecessors->allPredecessorsKnown() &&
@@ -270,7 +279,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
unsigned firstIndex = 0;
if (inputs.size() != lattices.size()) {
- if (llvm::dyn_cast_if_present<Operation *>(point)) {
+ if (!point->isBlockStart()) {
if (!inputs.empty())
firstIndex = cast<OpResult>(inputs.front()).getResultNumber();
visitNonControlFlowArgumentsImpl(
@@ -281,7 +290,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
} else {
if (!inputs.empty())
firstIndex = cast<BlockArgument>(inputs.front()).getArgNumber();
- Region *region = point.get<Block *>()->getParent();
+ Region *region = point->getBlock()->getParent();
visitNonControlFlowArgumentsImpl(
branch,
RegionSuccessor(region, region->getArguments().slice(
@@ -296,7 +305,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
}
const AbstractSparseLattice *
-AbstractSparseForwardDataFlowAnalysis::getLatticeElementFor(ProgramPoint point,
+AbstractSparseForwardDataFlowAnalysis::getLatticeElementFor(ProgramPoint *point,
Value value) {
AbstractSparseLattice *state = getLatticeElement(value);
addDependency(state, point);
@@ -336,7 +345,8 @@ AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(Operation *op) {
for (Region &region : op->getRegions()) {
for (Block &block : region) {
- getOrCreate<Executable>(&block)->blockContentSubscribe(this);
+ getOrCreate<Executable>(getProgramPointBefore(&block))
+ ->blockContentSubscribe(this);
// Initialize ops in reverse order, so we can do as much initial
// propagation as possible without having to go through the
// solver queue.
@@ -349,14 +359,14 @@ AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(Operation *op) {
}
LogicalResult
-AbstractSparseBackwardDataFlowAnalysis::visit(ProgramPoint point) {
- if (Operation *op = llvm::dyn_cast_if_present<Operation *>(point))
- return visitOperation(op);
+AbstractSparseBackwardDataFlowAnalysis::visit(ProgramPoint *point) {
// For backward dataflow, we don't have to do any work for the blocks
// themselves. CFG edges between blocks are processed by the BranchOp
// logic in `visitOperation`, and entry blocks for functions are tied
// to the CallOp arguments by visitOperation.
- return success();
+ if (point->isBlockStart())
+ return success();
+ return visitOperation(point->getPrevOp());
}
SmallVector<AbstractSparseLattice *>
@@ -372,7 +382,7 @@ AbstractSparseBackwardDataFlowAnalysis::getLatticeElements(ValueRange values) {
SmallVector<const AbstractSparseLattice *>
AbstractSparseBackwardDataFlowAnalysis::getLatticeElementsFor(
- ProgramPoint point, ValueRange values) {
+ ProgramPoint *point, ValueRange values) {
SmallVector<const AbstractSparseLattice *> resultLattices;
resultLattices.reserve(values.size());
for (Value result : values) {
@@ -390,13 +400,14 @@ static MutableArrayRef<OpOperand> operandsToOpOperands(OperandRange &operands) {
LogicalResult
AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
// If we're in a dead block, bail out.
- if (!getOrCreate<Executable>(op->getBlock())->isLive())
+ if (op->getBlock() != nullptr &&
+ !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))->isLive())
return success();
SmallVector<AbstractSparseLattice *> operandLattices =
getLatticeElements(op->getOperands());
SmallVector<const AbstractSparseLattice *> resultLattices =
- getLatticeElementsFor(op, op->getResults());
+ getLatticeElementsFor(getProgramPointAfter(op), op->getResults());
// Block arguments of region branch operations flow back into the operands
// of the parent op
@@ -425,7 +436,7 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
detail::getBranchSuccessorArgument(
successorOperands, operand.getOperandNumber(), block)) {
meet(getLatticeElement(operand.get()),
- *getLatticeElementFor(op, *blockArg));
+ *getLatticeElementFor(getProgramPointAfter(op), *blockArg));
}
}
}
@@ -467,7 +478,7 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
for (auto [blockArg, argOpOperand] :
llvm::zip(block.getArguments(), argOpOperands)) {
meet(getLatticeElement(argOpOperand.get()),
- *getLatticeElementFor(op, blockArg));
+ *getLatticeElementFor(getProgramPointAfter(op), blockArg));
unaccounted.reset(argOpOperand.getOperandNumber());
}
@@ -502,12 +513,13 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
// Going backwards, the operands of the return are derived from the
// results of all CallOps calling this CallableOp.
if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) {
- const PredecessorState *callsites =
- getOrCreateFor<PredecessorState>(op, callable);
+ const PredecessorState *callsites = getOrCreateFor<PredecessorState>(
+ getProgramPointAfter(op), getProgramPointAfter(callable));
if (callsites->allPredecessorsKnown()) {
for (Operation *call : callsites->getKnownPredecessors()) {
SmallVector<const AbstractSparseLattice *> callResultLattices =
- getLatticeElementsFor(op, call->getResults());
+ getLatticeElementsFor(getProgramPointAfter(op),
+ call->getResults());
for (auto [op, result] :
llvm::zip(operandLattices, callResultLattices))
meet(op, *result);
@@ -542,7 +554,8 @@ void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
MutableArrayRef<OpOperand> opoperands = operandsToOpOperands(operands);
ValueRange inputs = successor.getSuccessorInputs();
for (auto [operand, input] : llvm::zip(opoperands, inputs)) {
- meet(getLatticeElement(operand.get()), *getLatticeElementFor(op, input));
+ meet(getLatticeElement(operand.get()),
+ *getLatticeElementFor(getProgramPointAfter(op), input));
unaccounted.reset(operand.getOperandNumber());
}
}
@@ -576,7 +589,7 @@ void AbstractSparseBackwardDataFlowAnalysis::
MutableArrayRef<OpOperand> opOperands = operandsToOpOperands(operands);
for (auto [opOperand, input] : llvm::zip(opOperands, inputs)) {
meet(getLatticeElement(opOperand.get()),
- *getLatticeElementFor(terminator, input));
+ *getLatticeElementFor(getProgramPointAfter(terminator), input));
unaccounted.reset(const_cast<OpOperand &>(opOperand).getOperandNumber());
}
}
@@ -588,8 +601,8 @@ void AbstractSparseBackwardDataFlowAnalysis::
}
const AbstractSparseLattice *
-AbstractSparseBackwardDataFlowAnalysis::getLatticeElementFor(ProgramPoint point,
- Value value) {
+AbstractSparseBackwardDataFlowAnalysis::getLatticeElementFor(
+ ProgramPoint *point, Value value) {
AbstractSparseLattice *state = getLatticeElement(value);
addDependency(state, point);
return state;
diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp
index a65ddc1..7e83668 100644
--- a/mlir/lib/Analysis/DataFlowFramework.cpp
+++ b/mlir/lib/Analysis/DataFlowFramework.cpp
@@ -37,7 +37,7 @@ GenericLatticeAnchor::~GenericLatticeAnchor() = default;
AnalysisState::~AnalysisState() = default;
-void AnalysisState::addDependency(ProgramPoint dependent,
+void AnalysisState::addDependency(ProgramPoint *dependent,
DataFlowAnalysis *analysis) {
auto inserted = dependents.insert({dependent, analysis});
(void)inserted;
@@ -53,7 +53,7 @@ void AnalysisState::addDependency(ProgramPoint dependent,
void AnalysisState::dump() const { print(llvm::errs()); }
//===----------------------------------------------------------------------===//
-// LatticeAnchor
+// ProgramPoint
//===----------------------------------------------------------------------===//
void ProgramPoint::print(raw_ostream &os) const {
@@ -61,12 +61,18 @@ void ProgramPoint::print(raw_ostream &os) const {
os << "<NULL POINT>";
return;
}
- if (Operation *op = llvm::dyn_cast<Operation *>(*this)) {
- return op->print(os, OpPrintingFlags().skipRegions());
+ if (!isBlockStart()) {
+ os << "<after operation>:";
+ return getPrevOp()->print(os, OpPrintingFlags().skipRegions());
}
- return get<Block *>()->print(os);
+ os << "<before operation>:";
+ return getNextOp()->print(os, OpPrintingFlags().skipRegions());
}
+//===----------------------------------------------------------------------===//
+// LatticeAnchor
+//===----------------------------------------------------------------------===//
+
void LatticeAnchor::print(raw_ostream &os) const {
if (isNull()) {
os << "<NULL POINT>";
@@ -78,7 +84,7 @@ void LatticeAnchor::print(raw_ostream &os) const {
return value.print(os, OpPrintingFlags().skipRegions());
}
- return get<ProgramPoint>().print(os);
+ return get<ProgramPoint *>()->print(os);
}
Location LatticeAnchor::getLoc() const {
@@ -87,10 +93,10 @@ Location LatticeAnchor::getLoc() const {
if (auto value = llvm::dyn_cast<Value>(*this))
return value.getLoc();
- ProgramPoint pp = get<ProgramPoint>();
- if (auto *op = llvm::dyn_cast<Operation *>(pp))
- return op->getLoc();
- return pp.get<Block *>()->getParent()->getLoc();
+ ProgramPoint *pp = get<ProgramPoint *>();
+ if (!pp->isBlockStart())
+ return pp->getPrevOp()->getLoc();
+ return pp->getBlock()->getParent()->getLoc();
}
//===----------------------------------------------------------------------===//
@@ -144,7 +150,8 @@ DataFlowAnalysis::~DataFlowAnalysis() = default;
DataFlowAnalysis::DataFlowAnalysis(DataFlowSolver &solver) : solver(solver) {}
-void DataFlowAnalysis::addDependency(AnalysisState *state, ProgramPoint point) {
+void DataFlowAnalysis::addDependency(AnalysisState *state,
+ ProgramPoint *point) {
state->addDependency(point, this);
}
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 8005f91..521138c 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -80,7 +80,7 @@ public:
protected:
void notifyOperationErased(Operation *op) override {
- s.eraseState(op);
+ s.eraseState(s.getProgramPointAfter(op));
for (Value res : op->getResults())
s.eraseState(res);
}
diff --git a/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp
index d02efaa..2dc77c9 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp
@@ -29,7 +29,8 @@ static void printAnalysisResults(DataFlowSolver &solver, Operation *op,
os << " ";
block.printAsOperand(os);
os << " = ";
- auto *live = solver.lookupState<Executable>(&block);
+ auto *live = solver.lookupState<Executable>(
+ solver.getProgramPointBefore(&block));
if (live)
os << *live;
else
@@ -49,12 +50,14 @@ static void printAnalysisResults(DataFlowSolver &solver, Operation *op,
}
}
if (!region.empty()) {
- auto *preds = solver.lookupState<PredecessorState>(&region.front());
+ auto *preds = solver.lookupState<PredecessorState>(
+ solver.getProgramPointBefore(&region.front()));
if (preds)
os << "region_preds: " << *preds << "\n";
}
}
- auto *preds = solver.lookupState<PredecessorState>(op);
+ auto *preds =
+ solver.lookupState<PredecessorState>(solver.getProgramPointAfter(op));
if (preds)
os << "op_preds: " << *preds << "\n";
});
@@ -68,15 +71,15 @@ struct ConstantAnalysis : public DataFlowAnalysis {
LogicalResult initialize(Operation *top) override {
WalkResult result = top->walk([&](Operation *op) {
- if (failed(visit(op)))
+ if (failed(visit(getProgramPointAfter(op))))
return WalkResult::interrupt();
return WalkResult::advance();
});
return success(!result.wasInterrupted());
}
- LogicalResult visit(ProgramPoint point) override {
- Operation *op = point.get<Operation *>();
+ LogicalResult visit(ProgramPoint *point) override {
+ Operation *op = point->getPrevOp();
Attribute value;
if (matchPattern(op, m_Constant(&value))) {
auto *constant = getOrCreate<Lattice<ConstantValue>>(op->getResult(0));
diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
index 6794cbb..fa6223a 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
@@ -115,7 +115,8 @@ LogicalResult NextAccessAnalysis::visitOperation(Operation *op,
std::optional<Value> underlyingValue =
UnderlyingValueAnalysis::getMostUnderlyingValue(
value, [&](Value value) {
- return getOrCreateFor<UnderlyingValueLattice>(op, value);
+ return getOrCreateFor<UnderlyingValueLattice>(
+ getProgramPointBefore(op), value);
});
// If the underlying value is not known yet, don't propagate.
@@ -151,7 +152,7 @@ void NextAccessAnalysis::visitCallControlFlowTransfer(
UnderlyingValueAnalysis::getMostUnderlyingValue(
operand, [&](Value value) {
return getOrCreateFor<UnderlyingValueLattice>(
- call.getOperation(), value);
+ getProgramPointBefore(call.getOperation()), value);
});
if (!underlyingValue)
return;
@@ -283,9 +284,8 @@ struct TestNextAccessPass
if (!tag)
return;
- const NextAccess *nextAccess = solver.lookupState<NextAccess>(
- op->getNextNode() == nullptr ? ProgramPoint(op->getBlock())
- : op->getNextNode());
+ const NextAccess *nextAccess =
+ solver.lookupState<NextAccess>(solver.getProgramPointAfter(op));
op->setAttr(kNextAccessAttrName,
makeNextAccessAttribute(op, solver, nextAccess));
@@ -300,9 +300,8 @@ struct TestNextAccessPass
if (!successor.getSuccessor() || successor.getSuccessor()->empty())
continue;
Block &successorBlock = successor.getSuccessor()->front();
- ProgramPoint successorPoint = successorBlock.empty()
- ? ProgramPoint(&successorBlock)
- : &successorBlock.front();
+ ProgramPoint *successorPoint =
+ solver.getProgramPointBefore(&successorBlock);
entryPointNextAccess.push_back(makeNextAccessAttribute(
op, solver, solver.lookupState<NextAccess>(successorPoint)));
}
diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp
index 301d2a20..89b5c83 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp
@@ -117,7 +117,8 @@ LogicalResult LastModifiedAnalysis::visitOperation(
std::optional<Value> underlyingValue =
UnderlyingValueAnalysis::getMostUnderlyingValue(
value, [&](Value value) {
- return getOrCreateFor<UnderlyingValueLattice>(op, value);
+ return getOrCreateFor<UnderlyingValueLattice>(
+ getProgramPointAfter(op), value);
});
// If the underlying value is not yet known, don't propagate yet.
@@ -157,7 +158,7 @@ void LastModifiedAnalysis::visitCallControlFlowTransfer(
UnderlyingValueAnalysis::getMostUnderlyingValue(
operand, [&](Value value) {
return getOrCreateFor<UnderlyingValueLattice>(
- call.getOperation(), value);
+ getProgramPointAfter(call.getOperation()), value);
});
if (!underlyingValue)
return;
@@ -243,7 +244,7 @@ struct TestLastModifiedPass
return;
os << "test_tag: " << tag.getValue() << ":\n";
const LastModification *lastMods =
- solver.lookupState<LastModification>(op);
+ solver.lookupState<LastModification>(solver.getProgramPointAfter(op));
assert(lastMods && "expected a dense lattice");
for (auto [index, operand] : llvm::enumerate(op->getOperands())) {
os << " operand #" << index << "\n";
diff --git a/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
index 2445b58..1f3cab1 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
@@ -108,7 +108,7 @@ WrittenToAnalysis::visitOperation(Operation *op, ArrayRef<WrittenTo *> operands,
for (WrittenTo *operand : operands) {
meet(operand, *r);
}
- addDependency(const_cast<WrittenTo *>(r), op);
+ addDependency(const_cast<WrittenTo *>(r), getProgramPointAfter(op));
}
return success();
}
diff --git a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
index 9573ec1..3eb39fc 100644
--- a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
+++ b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
@@ -75,7 +75,7 @@ public:
using DataFlowAnalysis::DataFlowAnalysis;
LogicalResult initialize(Operation *top) override;
- LogicalResult visit(ProgramPoint point) override;
+ LogicalResult visit(ProgramPoint *point) override;
private:
void visitBlock(Block *block);
@@ -100,7 +100,8 @@ LogicalResult FooAnalysis::initialize(Operation *top) {
return top->emitError("expected at least one block in the region");
// Initialize the top-level state.
- (void)getOrCreate<FooState>(&top->getRegion(0).front())->join(0);
+ (void)getOrCreate<FooState>(getProgramPointBefore(&top->getRegion(0).front()))
+ ->join(0);
// Visit all nested blocks and operations.
for (Block &block : top->getRegion(0)) {
@@ -114,11 +115,11 @@ LogicalResult FooAnalysis::initialize(Operation *top) {
return success();
}
-LogicalResult FooAnalysis::visit(ProgramPoint point) {
- if (auto *op = llvm::dyn_cast_if_present<Operation *>(point))
- visitOperation(op);
+LogicalResult FooAnalysis::visit(ProgramPoint *point) {
+ if (!point->isBlockStart())
+ visitOperation(point->getPrevOp());
else
- visitBlock(point.get<Block *>());
+ visitBlock(point->getBlock());
return success();
}
@@ -127,27 +128,26 @@ void FooAnalysis::visitBlock(Block *block) {
// This is the initial state. Let the framework default-initialize it.
return;
}
- FooState *state = getOrCreate<FooState>(block);
+ ProgramPoint *point = getProgramPointBefore(block);
+ FooState *state = getOrCreate<FooState>(point);
ChangeResult result = ChangeResult::NoChange;
for (Block *pred : block->getPredecessors()) {
// Join the state at the terminators of all predecessors.
- const FooState *predState =
- getOrCreateFor<FooState>(block, pred->getTerminator());
+ const FooState *predState = getOrCreateFor<FooState>(
+ point, getProgramPointAfter(pred->getTerminator()));
result |= state->join(*predState);
}
propagateIfChanged(state, result);
}
void FooAnalysis::visitOperation(Operation *op) {
- FooState *state = getOrCreate<FooState>(op);
+ ProgramPoint *point = getProgramPointAfter(op);
+ FooState *state = getOrCreate<FooState>(point);
ChangeResult result = ChangeResult::NoChange;
// Copy the state across the operation.
const FooState *prevState;
- if (Operation *prev = op->getPrevNode())
- prevState = getOrCreateFor<FooState>(op, prev);
- else
- prevState = getOrCreateFor<FooState>(op, op->getBlock());
+ prevState = getOrCreateFor<FooState>(point, getProgramPointBefore(op));
result |= state->set(*prevState);
// Modify the state with the attribute, if specified.
@@ -172,7 +172,8 @@ void TestFooAnalysisPass::runOnOperation() {
auto tag = op->getAttrOfType<StringAttr>("tag");
if (!tag)
return;
- const FooState *state = solver.lookupState<FooState>(op);
+ const FooState *state =
+ solver.lookupState<FooState>(solver.getProgramPointAfter(op));
assert(state && !state->isUninitialized());
os << tag.getValue() << " -> " << state->getValue() << "\n";
});