diff options
18 files changed, 491 insertions, 306 deletions
diff --git a/flang/lib/Optimizer/Transforms/StackArrays.cpp b/flang/lib/Optimizer/Transforms/StackArrays.cpp index d9e7bd6..0c474f4 100644 --- a/flang/lib/Optimizer/Transforms/StackArrays.cpp +++ b/flang/lib/Optimizer/Transforms/StackArrays.cpp @@ -375,7 +375,7 @@ mlir::LogicalResult AllocationAnalysis::visitOperation( } } else if (mlir::isa<fir::ResultOp>(op)) { mlir::Operation *parent = op->getParentOp(); - LatticePoint *parentLattice = getLattice(parent); + LatticePoint *parentLattice = getLattice(getProgramPointAfter(parent)); assert(parentLattice); mlir::ChangeResult parentChanged = parentLattice->join(*after); propagateIfChanged(parentLattice, parentChanged); @@ -396,28 +396,29 @@ void AllocationAnalysis::setToEntryState(LatticePoint *lattice) { /// Mostly a copy of AbstractDenseLattice::processOperation - the difference /// being that call operations are passed through to the transfer function mlir::LogicalResult AllocationAnalysis::processOperation(mlir::Operation *op) { + mlir::ProgramPoint *point = getProgramPointAfter(op); // If the containing block is not executable, bail out. - if (!getOrCreateFor<mlir::dataflow::Executable>(op, op->getBlock())->isLive()) + if (op->getBlock() != nullptr && + !getOrCreateFor<mlir::dataflow::Executable>( + point, getProgramPointBefore(op->getBlock())) + ->isLive()) return mlir::success(); // Get the dense lattice to update - mlir::dataflow::AbstractDenseLattice *after = getLattice(op); + mlir::dataflow::AbstractDenseLattice *after = getLattice(point); // If this op implements region control-flow, then control-flow dictates its // transfer function. if (auto branch = mlir::dyn_cast<mlir::RegionBranchOpInterface>(op)) { - visitRegionBranchOperation(op, branch, after); + visitRegionBranchOperation(point, branch, after); return mlir::success(); } // pass call operations through to the transfer function // Get the dense state before the execution of the op. - const mlir::dataflow::AbstractDenseLattice *before; - if (mlir::Operation *prev = op->getPrevNode()) - before = getLatticeFor(op, prev); - else - before = getLatticeFor(op, op->getBlock()); + const mlir::dataflow::AbstractDenseLattice *before = + getLatticeFor(point, getProgramPointBefore(op)); /// Invoke the operation transfer function return visitOperationImpl(op, *before, after); @@ -452,9 +453,10 @@ StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) { return mlir::failure(); } - LatticePoint point{func}; + LatticePoint point{solver.getProgramPointAfter(func)}; auto joinOperationLattice = [&](mlir::Operation *op) { - const LatticePoint *lattice = solver.lookupState<LatticePoint>(op); + const LatticePoint *lattice = + solver.lookupState<LatticePoint>(solver.getProgramPointAfter(op)); // there will be no lattice for an unreachable block if (lattice) (void)point.join(*lattice); diff --git a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h index 80c8b86..2250db8 100644 --- a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h @@ -182,7 +182,7 @@ public: /// Visit an operation with control-flow semantics and deduce which of its /// successors are live. - LogicalResult visit(ProgramPoint point) override; + LogicalResult visit(ProgramPoint *point) override; private: /// Find and mark symbol callables with potentially unknown callsites as diff --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h index 7917f1e..2e32bd1 100644 --- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h @@ -36,8 +36,7 @@ enum class CallControlFlowAction { EnterCallee, ExitCallee, ExternalCallee }; //===----------------------------------------------------------------------===// /// This class represents a dense lattice. A dense lattice is attached to -/// operations to represent the program state after their execution or to blocks -/// to represent the program state at the beginning of the block. A dense +/// program point to represent the program state at the program point. /// lattice is propagated through the IR by dense data-flow analysis. class AbstractDenseLattice : public AnalysisState { public: @@ -59,15 +58,13 @@ public: //===----------------------------------------------------------------------===// /// Base class for dense forward data-flow analyses. Dense data-flow analysis -/// attaches a lattice between the execution of operations and implements a -/// transfer function from the lattice before each operation to the lattice -/// after. The lattice contains information about the state of the program at -/// that point. +/// attaches a lattice to program points and implements a transfer function from +/// the lattice before each operation to the lattice after. The lattice contains +/// information about the state of the program at that program point. /// -/// In this implementation, a lattice attached to an operation represents the -/// state of the program after its execution, and a lattice attached to block -/// represents the state of the program right before it starts executing its -/// body. +/// Visit a program point in forward dense data-flow analysis will invoke the +/// transfer function of the operation preceding the program point iterator. +/// Visit a program point at the begining of block will visit the block itself. class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis { public: using DataFlowAnalysis::DataFlowAnalysis; @@ -76,13 +73,14 @@ public: /// may modify the program state; that is, every operation and block. LogicalResult initialize(Operation *top) override; - /// Visit a program point that modifies the state of the program. If this is a - /// block, then the state is propagated from control-flow predecessors or - /// callsites. If this is a call operation or region control-flow operation, - /// then the state after the execution of the operation is set by control-flow - /// or the callgraph. Otherwise, this function invokes the operation transfer - /// function. - LogicalResult visit(ProgramPoint point) override; + /// Visit a program point that modifies the state of the program. If the + /// program point is at the beginning of a block, then the state is propagated + /// from control-flow predecessors or callsites. If the operation before + /// program point iterator is a call operation or region control-flow + /// operation, then the state after the execution of the operation is set by + /// control-flow or the callgraph. Otherwise, this function invokes the + /// operation transfer function before the program point iterator. + LogicalResult visit(ProgramPoint *point) override; protected: /// Propagate the dense lattice before the execution of an operation to the @@ -91,15 +89,14 @@ protected: const AbstractDenseLattice &before, AbstractDenseLattice *after) = 0; - /// Get the dense lattice after the execution of the given lattice anchor. + /// Get the dense lattice on the given lattice anchor. virtual AbstractDenseLattice *getLattice(LatticeAnchor anchor) = 0; - /// Get the dense lattice after the execution of the given program point and - /// add it as a dependency to a lattice anchor. That is, every time the - /// lattice after anchor is updated, the dependent program point must be - /// visited, and the newly triggered visit might update the lattice after - /// dependent. - const AbstractDenseLattice *getLatticeFor(ProgramPoint dependent, + /// Get the dense lattice on the given lattice anchor and add dependent as its + /// dependency. That is, every time the lattice after anchor is updated, the + /// dependent program point must be visited, and the newly triggered visit + /// might update the lattice on dependent. + const AbstractDenseLattice *getLatticeFor(ProgramPoint *dependent, LatticeAnchor anchor); /// Set the dense lattice at control flow entry point and propagate an update @@ -153,7 +150,7 @@ protected: /// Visit a program point within a region branch operation with predecessors /// in it. This can either be an entry block of one of the regions of the /// parent operation itself. - void visitRegionBranchOperation(ProgramPoint point, + void visitRegionBranchOperation(ProgramPoint *point, RegionBranchOpInterface branch, AbstractDenseLattice *after); @@ -294,14 +291,12 @@ protected: //===----------------------------------------------------------------------===// /// Base class for dense backward dataflow analyses. Such analyses attach a -/// lattice between the execution of operations and implement a transfer -/// function from the lattice after the operation ot the lattice before it, thus -/// propagating backward. +/// lattice to program point and implement a transfer function from the lattice +/// after the operation to the lattice before it, thus propagating backward. /// -/// In this implementation, a lattice attached to an operation represents the -/// state of the program before its execution, and a lattice attached to a block -/// represents the state of the program before the end of the block, i.e., after -/// its terminator. +/// Visit a program point in dense backward data-flow analysis will invoke the +/// transfer function of the operation following the program point iterator. +/// Visit a program point at the end of block will visit the block itself. class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis { public: /// Construct the analysis in the given solver. Takes a symbol table @@ -321,9 +316,9 @@ public: /// operations, the state is propagated using the transfer function /// (visitOperation). /// - /// Note: the transfer function is currently *not* invoked for operations with - /// region or call interface, but *is* invoked for block terminators. - LogicalResult visit(ProgramPoint point) override; + /// Note: the transfer function is currently *not* invoked before operations + /// with region or call interface, but *is* invoked before block terminators. + LogicalResult visit(ProgramPoint *point) override; protected: /// Propagate the dense lattice after the execution of an operation to the @@ -337,10 +332,11 @@ protected: /// block. virtual AbstractDenseLattice *getLattice(LatticeAnchor anchor) = 0; - /// Get the dense lattice before the execution of the program point in - /// `anchor` and declare that the `dependent` program point must be updated - /// every time `point` is. - const AbstractDenseLattice *getLatticeFor(ProgramPoint dependent, + /// Get the dense lattice on the given lattice anchor and add dependent as its + /// dependency. That is, every time the lattice after anchor is updated, the + /// dependent program point must be visited, and the newly triggered visit + /// might update the lattice before dependent. + const AbstractDenseLattice *getLatticeFor(ProgramPoint *dependent, LatticeAnchor anchor); /// Set the dense lattice before at the control flow exit point and propagate @@ -400,7 +396,7 @@ private: /// (from which the state is propagated) in or after it. `regionNo` indicates /// the region that contains the successor, `nullopt` indicating the successor /// of the branch operation itself. - void visitRegionBranchOperation(ProgramPoint point, + void visitRegionBranchOperation(ProgramPoint *point, RegionBranchOpInterface branch, RegionBranchPoint branchPoint, AbstractDenseLattice *before); diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h index 933790b..dce7ab3 100644 --- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h @@ -179,18 +179,22 @@ private: /// operands to the lattices of the results. This analysis will propagate /// lattices across control-flow edges and the callgraph using liveness /// information. +/// +/// Visit a program point in sparse forward data-flow analysis will invoke the +/// transfer function of the operation preceding the program point iterator. +/// Visit a program point at the begining of block will visit the block itself. class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis { public: /// Initialize the analysis by visiting every owner of an SSA value: all /// operations and blocks. LogicalResult initialize(Operation *top) override; - /// Visit a program point. If this is a block and all control-flow - /// predecessors or callsites are known, then the arguments lattices are - /// propagated from them. If this is a call operation or an operation with - /// region control-flow, then its result lattices are set accordingly. - /// Otherwise, the operation transfer function is invoked. - LogicalResult visit(ProgramPoint point) override; + /// Visit a program point. If this is at beginning of block and all + /// control-flow predecessors or callsites are known, then the arguments + /// lattices are propagated from them. If this is after call operation or an + /// operation with region control-flow, then its result lattices are set + /// accordingly. Otherwise, the operation transfer function is invoked. + LogicalResult visit(ProgramPoint *point) override; protected: explicit AbstractSparseForwardDataFlowAnalysis(DataFlowSolver &solver); @@ -221,7 +225,7 @@ protected: /// Get a read-only lattice element for a value and add it as a dependency to /// a program point. - const AbstractSparseLattice *getLatticeElementFor(ProgramPoint point, + const AbstractSparseLattice *getLatticeElementFor(ProgramPoint *point, Value value); /// Set the given lattice element(s) at control flow entry point(s). @@ -251,7 +255,8 @@ private: /// operation `branch`, which can either be the entry block of one of the /// regions or the parent operation itself, and set either the argument or /// parent result lattices. - void visitRegionSuccessors(ProgramPoint point, RegionBranchOpInterface branch, + void visitRegionSuccessors(ProgramPoint *point, + RegionBranchOpInterface branch, RegionBranchPoint successor, ArrayRef<AbstractSparseLattice *> lattices); }; @@ -312,7 +317,7 @@ protected: /// Get the lattice element for a value and create a dependency on the /// provided program point. - const StateT *getLatticeElementFor(ProgramPoint point, Value value) { + const StateT *getLatticeElementFor(ProgramPoint *point, Value value) { return static_cast<const StateT *>( AbstractSparseForwardDataFlowAnalysis::getLatticeElementFor(point, value)); @@ -377,10 +382,10 @@ public: /// under it. LogicalResult initialize(Operation *top) override; - /// Visit a program point. If this is a call operation or an operation with + /// Visit a program point. If it is after call operation or an operation with /// block or region control-flow, then operand lattices are set accordingly. /// Otherwise, invokes the operation transfer function (`visitOperationImpl`). - LogicalResult visit(ProgramPoint point) override; + LogicalResult visit(ProgramPoint *point) override; protected: explicit AbstractSparseBackwardDataFlowAnalysis( @@ -445,14 +450,14 @@ private: /// Get the lattice element for a value, and also set up /// dependencies so that the analysis on the given ProgramPoint is re-invoked /// if the value changes. - const AbstractSparseLattice *getLatticeElementFor(ProgramPoint point, + const AbstractSparseLattice *getLatticeElementFor(ProgramPoint *point, Value value); /// Get the lattice elements for a range of values, and also set up /// dependencies so that the analysis on the given ProgramPoint is re-invoked /// if any of the values change. SmallVector<const AbstractSparseLattice *> - getLatticeElementsFor(ProgramPoint point, ValueRange values); + getLatticeElementsFor(ProgramPoint *point, ValueRange values); SymbolTableCollection &symbolTable; }; @@ -465,6 +470,10 @@ private: /// backwards across the IR by implementing transfer functions for operations. /// /// `StateT` is expected to be a subclass of `AbstractSparseLattice`. +/// +/// Visit a program point in sparse backward data-flow analysis will invoke the +/// transfer function of the operation preceding the program point iterator. +/// Visit a program point at the begining of block will visit the block itself. template <typename StateT> class SparseBackwardDataFlowAnalysis : public AbstractSparseBackwardDataFlowAnalysis { diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h index b0450ec..969664d 100644 --- a/mlir/include/mlir/Analysis/DataFlowFramework.h +++ b/mlir/include/mlir/Analysis/DataFlowFramework.h @@ -18,10 +18,12 @@ #include "mlir/IR/Operation.h" #include "mlir/Support/StorageUniquer.h" +#include "llvm/ADT/Hashing.h" #include "llvm/ADT/SetVector.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/TypeName.h" #include <queue> +#include <tuple> namespace mlir { @@ -51,23 +53,104 @@ class AnalysisState; /// Program point represents a specific location in the execution of a program. /// A sequence of program points can be combined into a control flow graph. -struct ProgramPoint : public PointerUnion<Operation *, Block *> { - using ParentTy = PointerUnion<Operation *, Block *>; - /// Inherit constructors. - using ParentTy::PointerUnion; - /// Allow implicit conversion from the parent type. - ProgramPoint(ParentTy point = nullptr) : ParentTy(point) {} - /// Allow implicit conversions from operation wrappers. - /// TODO: For Windows only. Find a better solution. - template <typename OpT, typename = std::enable_if_t< - std::is_convertible<OpT, Operation *>::value && - !std::is_same<OpT, Operation *>::value>> - ProgramPoint(OpT op) : ParentTy(op) {} +struct ProgramPoint : public StorageUniquer::BaseStorage { + /// Creates a new program point at the given location. + ProgramPoint(Block *parentBlock, Block::iterator pp) + : block(parentBlock), point(pp) {} + + /// Creates a new program point at the given operation. + ProgramPoint(Operation *op) : op(op) {} + + /// The concrete key type used by the storage uniquer. This class is uniqued + /// by its contents. + using KeyTy = std::tuple<Block *, Block::iterator, Operation *>; + + /// Create a empty program point. + ProgramPoint() {} + + /// Create a new program point from the given program point. + ProgramPoint(const ProgramPoint &point) { + this->block = point.getBlock(); + this->point = point.getPoint(); + this->op = point.getOperation(); + } + + static ProgramPoint *construct(StorageUniquer::StorageAllocator &alloc, + KeyTy &&key) { + if (std::get<0>(key)) { + return new (alloc.allocate<ProgramPoint>()) + ProgramPoint(std::get<0>(key), std::get<1>(key)); + } + return new (alloc.allocate<ProgramPoint>()) ProgramPoint(std::get<2>(key)); + } + + /// Returns true if this program point is set. + bool isNull() const { return block == nullptr && op == nullptr; } + + /// Two program points are equal if their block and iterator are equal. + bool operator==(const KeyTy &key) const { + return block == std::get<0>(key) && point == std::get<1>(key) && + op == std::get<2>(key); + } + + bool operator==(const ProgramPoint &pp) const { + return block == pp.block && point == pp.point && op == pp.op; + } + + /// Get the block contains this program point. + Block *getBlock() const { return block; } + + /// Get the the iterator this program point refers to. + Block::iterator getPoint() const { return point; } + + /// Get the the iterator this program point refers to. + Operation *getOperation() const { return op; } + + /// Get the next operation of this program point. + Operation *getNextOp() const { + assert(!isBlockEnd()); + // If the current program point has no parent block, both the next op and + // the previous op point to the op corresponding to the current program + // point. + if (block == nullptr) { + return op; + } + return &*point; + } + + /// Get the previous operation of this program point. + Operation *getPrevOp() const { + assert(!isBlockStart()); + // If the current program point has no parent block, both the next op and + // the previous op point to the op corresponding to the current program + // point. + if (block == nullptr) { + return op; + } + return &*(--Block::iterator(point)); + } + + bool isBlockStart() const { return block && block->begin() == point; } + + bool isBlockEnd() const { return block && block->end() == point; } /// Print the program point. void print(raw_ostream &os) const; + +private: + Block *block = nullptr; + Block::iterator point; + + /// For operations without a parent block, we record the operation itself as + /// its program point. + Operation *op = nullptr; }; +inline raw_ostream &operator<<(raw_ostream &os, ProgramPoint point) { + point.print(os); + return os; +} + //===----------------------------------------------------------------------===// // GenericLatticeAnchor //===----------------------------------------------------------------------===// @@ -165,21 +248,12 @@ private: /// Fundamental IR components are supported as first-class lattice anchor. struct LatticeAnchor - : public PointerUnion<GenericLatticeAnchor *, ProgramPoint, Value> { - using ParentTy = PointerUnion<GenericLatticeAnchor *, ProgramPoint, Value>; + : public PointerUnion<GenericLatticeAnchor *, ProgramPoint *, Value> { + using ParentTy = PointerUnion<GenericLatticeAnchor *, ProgramPoint *, Value>; /// Inherit constructors. using ParentTy::PointerUnion; /// Allow implicit conversion from the parent type. LatticeAnchor(ParentTy point = nullptr) : ParentTy(point) {} - /// Allow implicit conversions from operation wrappers. - /// TODO: For Windows only. Find a better solution. - template <typename OpT, typename = std::enable_if_t< - std::is_convertible<OpT, Operation *>::value && - !std::is_same<OpT, Operation *>::value>> - LatticeAnchor(OpT op) : ParentTy(ProgramPoint(op)) {} - - LatticeAnchor(Operation *op) : ParentTy(ProgramPoint(op)) {} - LatticeAnchor(Block *block) : ParentTy(ProgramPoint(block)) {} /// Print the lattice anchor. void print(raw_ostream &os) const; @@ -238,7 +312,9 @@ private: class DataFlowSolver { public: explicit DataFlowSolver(const DataFlowConfig &config = DataFlowConfig()) - : config(config) {} + : config(config) { + uniquer.registerParametricStorageType<ProgramPoint>(); + } /// Load an analysis into the solver. Return the analysis instance. template <typename AnalysisT, typename... Args> @@ -277,10 +353,39 @@ public: return AnchorT::get(uniquer, std::forward<Args>(args)...); } + /// Get a uniqued program point instance. + ProgramPoint *getProgramPointBefore(Operation *op) { + if (op->getBlock()) + return uniquer.get<ProgramPoint>(/*initFn*/ {}, op->getBlock(), + Block::iterator(op), nullptr); + else + return uniquer.get<ProgramPoint>(/*initFn*/ {}, nullptr, + Block::iterator(), op); + } + + ProgramPoint *getProgramPointBefore(Block *block) { + return uniquer.get<ProgramPoint>(/*initFn*/ {}, block, block->begin(), + nullptr); + } + + ProgramPoint *getProgramPointAfter(Operation *op) { + if (op->getBlock()) + return uniquer.get<ProgramPoint>(/*initFn*/ {}, op->getBlock(), + ++Block::iterator(op), nullptr); + else + return uniquer.get<ProgramPoint>(/*initFn*/ {}, nullptr, + Block::iterator(), op); + } + + ProgramPoint *getProgramPointAfter(Block *block) { + return uniquer.get<ProgramPoint>(/*initFn*/ {}, block, block->end(), + nullptr); + } + /// A work item on the solver queue is a program point, child analysis pair. /// Each item is processed by invoking the child analysis at the program /// point. - using WorkItem = std::pair<ProgramPoint, DataFlowAnalysis *>; + using WorkItem = std::pair<ProgramPoint *, DataFlowAnalysis *>; /// Push a work item onto the worklist. void enqueue(WorkItem item) { worklist.push(std::move(item)); } @@ -343,7 +448,7 @@ class AnalysisState { public: virtual ~AnalysisState(); - /// Create the analysis state at the given lattice anchor. + /// Create the analysis state on the given lattice anchor. AnalysisState(LatticeAnchor anchor) : anchor(anchor) {} /// Returns the lattice anchor this state is located at. @@ -356,7 +461,7 @@ public: /// Add a dependency to this analysis state on a lattice anchor and an /// analysis. If this state is updated, the analysis will be invoked on the /// given lattice anchor again (in onUpdate()). - void addDependency(ProgramPoint point, DataFlowAnalysis *analysis); + void addDependency(ProgramPoint *point, DataFlowAnalysis *analysis); protected: /// This function is called by the solver when the analysis state is updated @@ -446,12 +551,12 @@ public: /// `visit` can add new dependencies, but these dependencies will not be /// dynamically added to the worklist because the solver doesn't know what /// will provide a value for then. - virtual LogicalResult visit(ProgramPoint point) = 0; + virtual LogicalResult visit(ProgramPoint *point) = 0; protected: /// Create a dependency between the given analysis state and lattice anchor /// on this analysis. - void addDependency(AnalysisState *state, ProgramPoint point); + void addDependency(AnalysisState *state, ProgramPoint *point); /// Propagate an update to a state if it changed. void propagateIfChanged(AnalysisState *state, ChangeResult changed); @@ -480,12 +585,29 @@ protected: /// on `dependent`. If the return state is updated elsewhere, this analysis is /// re-invoked on the dependent. template <typename StateT, typename AnchorT> - const StateT *getOrCreateFor(ProgramPoint dependent, AnchorT anchor) { + const StateT *getOrCreateFor(ProgramPoint *dependent, AnchorT anchor) { StateT *state = getOrCreate<StateT>(anchor); addDependency(state, dependent); return state; } + /// Get a uniqued program point instance. + ProgramPoint *getProgramPointBefore(Operation *op) { + return solver.getProgramPointBefore(op); + } + + ProgramPoint *getProgramPointBefore(Block *block) { + return solver.getProgramPointBefore(block); + } + + ProgramPoint *getProgramPointAfter(Operation *op) { + return solver.getProgramPointAfter(op); + } + + ProgramPoint *getProgramPointAfter(Block *block) { + return solver.getProgramPointAfter(block); + } + /// Return the configuration of the solver used for this analysis. const DataFlowConfig &getSolverConfig() const { return solver.getConfig(); } @@ -539,12 +661,30 @@ inline raw_ostream &operator<<(raw_ostream &os, LatticeAnchor anchor) { namespace llvm { /// Allow hashing of lattice anchors and program points. template <> -struct DenseMapInfo<mlir::LatticeAnchor> - : public DenseMapInfo<mlir::LatticeAnchor::ParentTy> {}; +struct DenseMapInfo<mlir::ProgramPoint> { + static mlir::ProgramPoint getEmptyKey() { + void *pointer = llvm::DenseMapInfo<void *>::getEmptyKey(); + return mlir::ProgramPoint( + (mlir::Block *)pointer, + mlir::Block::iterator((mlir::Operation *)pointer)); + } + static mlir::ProgramPoint getTombstoneKey() { + void *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey(); + return mlir::ProgramPoint( + (mlir::Block *)pointer, + mlir::Block::iterator((mlir::Operation *)pointer)); + } + static unsigned getHashValue(mlir::ProgramPoint pp) { + return hash_combine(pp.getBlock(), pp.getPoint().getNodePtr()); + } + static bool isEqual(mlir::ProgramPoint lhs, mlir::ProgramPoint rhs) { + return lhs == rhs; + } +}; template <> -struct DenseMapInfo<mlir::ProgramPoint> - : public DenseMapInfo<mlir::ProgramPoint::ParentTy> {}; +struct DenseMapInfo<mlir::LatticeAnchor> + : public DenseMapInfo<mlir::LatticeAnchor::ParentTy> {}; // Allow llvm::cast style functions. template <typename To> @@ -555,19 +695,6 @@ template <typename To> struct CastInfo<To, const mlir::LatticeAnchor> : public CastInfo<To, const mlir::LatticeAnchor::PointerUnion> {}; -template <typename To> -struct CastInfo<To, mlir::ProgramPoint> - : public CastInfo<To, mlir::ProgramPoint::PointerUnion> {}; - -template <typename To> -struct CastInfo<To, const mlir::ProgramPoint> - : public CastInfo<To, const mlir::ProgramPoint::PointerUnion> {}; - -/// Allow stealing the low bits of a ProgramPoint. -template <> -struct PointerLikeTypeTraits<mlir::ProgramPoint> - : public PointerLikeTypeTraits<mlir::ProgramPoint::ParentTy> {}; - } // end namespace llvm #endif // MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h index 67825eb..536cbf9 100644 --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -406,4 +406,25 @@ private: raw_ostream &operator<<(raw_ostream &, Block &); } // namespace mlir +namespace llvm { +template <> +struct DenseMapInfo<mlir::Block::iterator> { + static mlir::Block::iterator getEmptyKey() { + void *pointer = llvm::DenseMapInfo<void *>::getEmptyKey(); + return mlir::Block::iterator((mlir::Operation *)pointer); + } + static mlir::Block::iterator getTombstoneKey() { + void *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey(); + return mlir::Block::iterator((mlir::Operation *)pointer); + } + static unsigned getHashValue(mlir::Block::iterator iter) { + return hash_value(iter.getNodePtr()); + } + static bool isEqual(mlir::Block::iterator lhs, mlir::Block::iterator rhs) { + return lhs == rhs; + } +}; + +} // end namespace llvm + #endif // MLIR_IR_BLOCK_H diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp index beb6801..3c190d4 100644 --- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp @@ -46,22 +46,23 @@ void Executable::print(raw_ostream &os) const { void Executable::onUpdate(DataFlowSolver *solver) const { AnalysisState::onUpdate(solver); - if (ProgramPoint pp = llvm::dyn_cast_if_present<ProgramPoint>(anchor)) { - if (Block *block = llvm::dyn_cast_if_present<Block *>(pp)) { + if (ProgramPoint *pp = llvm::dyn_cast_if_present<ProgramPoint *>(anchor)) { + if (pp->isBlockStart()) { // Re-invoke the analyses on the block itself. for (DataFlowAnalysis *analysis : subscribers) - solver->enqueue({block, analysis}); + solver->enqueue({pp, analysis}); // Re-invoke the analyses on all operations in the block. for (DataFlowAnalysis *analysis : subscribers) - for (Operation &op : *block) - solver->enqueue({&op, analysis}); + for (Operation &op : *pp->getBlock()) + solver->enqueue({solver->getProgramPointAfter(&op), analysis}); } } else if (auto *latticeAnchor = llvm::dyn_cast_if_present<GenericLatticeAnchor *>(anchor)) { // Re-invoke the analysis on the successor block. if (auto *edge = dyn_cast<CFGEdge>(latticeAnchor)) { for (DataFlowAnalysis *analysis : subscribers) - solver->enqueue({edge->getTo(), analysis}); + solver->enqueue( + {solver->getProgramPointBefore(edge->getTo()), analysis}); } } } @@ -125,7 +126,8 @@ LogicalResult DeadCodeAnalysis::initialize(Operation *top) { for (Region ®ion : top->getRegions()) { if (region.empty()) continue; - auto *state = getOrCreate<Executable>(®ion.front()); + auto *state = + getOrCreate<Executable>(getProgramPointBefore(®ion.front())); propagateIfChanged(state, state->setToLive()); } @@ -154,7 +156,8 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) { // Public symbol callables or those for which we can't see all uses have // potentially unknown callsites. if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) { - auto *state = getOrCreate<PredecessorState>(callable); + auto *state = + getOrCreate<PredecessorState>(getProgramPointAfter(callable)); propagateIfChanged(state, state->setHasUnknownPredecessors()); } foundSymbolCallable = true; @@ -171,7 +174,8 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) { // If we couldn't gather the symbol uses, conservatively assume that // we can't track information for any nested symbols. return top->walk([&](CallableOpInterface callable) { - auto *state = getOrCreate<PredecessorState>(callable); + auto *state = + getOrCreate<PredecessorState>(getProgramPointAfter(callable)); propagateIfChanged(state, state->setHasUnknownPredecessors()); }); } @@ -182,7 +186,7 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) { // If a callable symbol has a non-call use, then we can't be guaranteed to // know all callsites. Operation *symbol = symbolTable.lookupSymbolIn(top, use.getSymbolRef()); - auto *state = getOrCreate<PredecessorState>(symbol); + auto *state = getOrCreate<PredecessorState>(getProgramPointAfter(symbol)); propagateIfChanged(state, state->setHasUnknownPredecessors()); } }; @@ -193,7 +197,7 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) { /// Returns true if the operation is a returning terminator in region /// control-flow or the terminator of a callable region. static bool isRegionOrCallableReturn(Operation *op) { - return !op->getNumSuccessors() && + return op->getBlock() != nullptr && !op->getNumSuccessors() && isa<RegionBranchOpInterface, CallableOpInterface>(op->getParentOp()) && op->getBlock()->getTerminator() == op; } @@ -205,9 +209,10 @@ LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) { // When the liveness of the parent block changes, make sure to re-invoke the // analysis on the op. if (op->getBlock()) - getOrCreate<Executable>(op->getBlock())->blockContentSubscribe(this); + getOrCreate<Executable>(getProgramPointBefore(op->getBlock())) + ->blockContentSubscribe(this); // Visit the op. - if (failed(visit(op))) + if (failed(visit(getProgramPointAfter(op)))) return failure(); } // Recurse on nested operations. @@ -219,7 +224,7 @@ LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) { } void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) { - auto *state = getOrCreate<Executable>(to); + auto *state = getOrCreate<Executable>(getProgramPointBefore(to)); propagateIfChanged(state, state->setToLive()); auto *edgeState = getOrCreate<Executable>(getLatticeAnchor<CFGEdge>(from, to)); @@ -230,18 +235,20 @@ void DeadCodeAnalysis::markEntryBlocksLive(Operation *op) { for (Region ®ion : op->getRegions()) { if (region.empty()) continue; - auto *state = getOrCreate<Executable>(®ion.front()); + auto *state = + getOrCreate<Executable>(getProgramPointBefore(®ion.front())); propagateIfChanged(state, state->setToLive()); } } -LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) { - if (point.is<Block *>()) +LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) { + if (point->isBlockStart()) return success(); - auto *op = point.get<Operation *>(); + Operation *op = point->getPrevOp(); // If the parent block is not executable, there is nothing to do. - if (!getOrCreate<Executable>(op->getBlock())->isLive()) + if (op->getBlock() != nullptr && + !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))->isLive()) return success(); // We have a live call op. Add this as a live predecessor of the callee. @@ -256,7 +263,8 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) { // Check if this is a callable operation. } else if (auto callable = dyn_cast<CallableOpInterface>(op)) { - const auto *callsites = getOrCreateFor<PredecessorState>(op, callable); + const auto *callsites = getOrCreateFor<PredecessorState>( + getProgramPointAfter(op), getProgramPointAfter(callable)); // If the callsites could not be resolved or are known to be non-empty, // mark the callable as executable. @@ -316,11 +324,13 @@ void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) { if (isa_and_nonnull<SymbolOpInterface>(callableOp) && !isExternalCallable(callableOp)) { // Add the live callsite. - auto *callsites = getOrCreate<PredecessorState>(callableOp); + auto *callsites = + getOrCreate<PredecessorState>(getProgramPointAfter(callableOp)); propagateIfChanged(callsites, callsites->join(call)); } else { // Mark this call op's predecessors as overdefined. - auto *predecessors = getOrCreate<PredecessorState>(call); + auto *predecessors = + getOrCreate<PredecessorState>(getProgramPointAfter(call)); propagateIfChanged(predecessors, predecessors->setHasUnknownPredecessors()); } } @@ -378,9 +388,10 @@ void DeadCodeAnalysis::visitRegionBranchOperation( branch.getEntrySuccessorRegions(*operands, successors); for (const RegionSuccessor &successor : successors) { // The successor can be either an entry block or the parent operation. - ProgramPoint point = successor.getSuccessor() - ? &successor.getSuccessor()->front() - : ProgramPoint(branch); + ProgramPoint *point = + successor.getSuccessor() + ? getProgramPointBefore(&successor.getSuccessor()->front()) + : getProgramPointAfter(branch); // Mark the entry block as executable. auto *state = getOrCreate<Executable>(point); propagateIfChanged(state, state->setToLive()); @@ -409,12 +420,15 @@ void DeadCodeAnalysis::visitRegionTerminator(Operation *op, for (const RegionSuccessor &successor : successors) { PredecessorState *predecessors; if (Region *region = successor.getSuccessor()) { - auto *state = getOrCreate<Executable>(®ion->front()); + auto *state = + getOrCreate<Executable>(getProgramPointBefore(®ion->front())); propagateIfChanged(state, state->setToLive()); - predecessors = getOrCreate<PredecessorState>(®ion->front()); + predecessors = getOrCreate<PredecessorState>( + getProgramPointBefore(®ion->front())); } else { // Add this terminator as a predecessor to the parent op. - predecessors = getOrCreate<PredecessorState>(branch); + predecessors = + getOrCreate<PredecessorState>(getProgramPointAfter(branch)); } propagateIfChanged(predecessors, predecessors->join(op, successor.getSuccessorInputs())); @@ -424,11 +438,13 @@ void DeadCodeAnalysis::visitRegionTerminator(Operation *op, void DeadCodeAnalysis::visitCallableTerminator(Operation *op, CallableOpInterface callable) { // Add as predecessors to all callsites this return op. - auto *callsites = getOrCreateFor<PredecessorState>(op, callable); + auto *callsites = getOrCreateFor<PredecessorState>( + getProgramPointAfter(op), getProgramPointAfter(callable)); bool canResolve = op->hasTrait<OpTrait::ReturnLike>(); for (Operation *predecessor : callsites->getKnownPredecessors()) { assert(isa<CallOpInterface>(predecessor)); - auto *predecessors = getOrCreate<PredecessorState>(predecessor); + auto *predecessors = + getOrCreate<PredecessorState>(getProgramPointAfter(predecessor)); if (canResolve) { propagateIfChanged(predecessors, predecessors->join(op)); } else { diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp index 300c6e5..340aa39 100644 --- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp @@ -44,10 +44,10 @@ LogicalResult AbstractDenseForwardDataFlowAnalysis::initialize(Operation *top) { return success(); } -LogicalResult AbstractDenseForwardDataFlowAnalysis::visit(ProgramPoint point) { - if (auto *op = llvm::dyn_cast_if_present<Operation *>(point)) - return processOperation(op); - visitBlock(point.get<Block *>()); +LogicalResult AbstractDenseForwardDataFlowAnalysis::visit(ProgramPoint *point) { + if (!point->isBlockStart()) + return processOperation(point->getPrevOp()); + visitBlock(point->getBlock()); return success(); } @@ -64,8 +64,8 @@ void AbstractDenseForwardDataFlowAnalysis::visitCallOperation( call, CallControlFlowAction::ExternalCallee, before, after); } - const auto *predecessors = - getOrCreateFor<PredecessorState>(call.getOperation(), call); + const auto *predecessors = getOrCreateFor<PredecessorState>( + getProgramPointAfter(call.getOperation()), getProgramPointAfter(call)); // Otherwise, if not all return sites are known, then conservatively assume we // can't reason about the data-flow. if (!predecessors->allPredecessorsKnown()) @@ -87,7 +87,8 @@ void AbstractDenseForwardDataFlowAnalysis::visitCallOperation( // } AbstractDenseLattice *latticeAfterCall = after; const AbstractDenseLattice *latticeAtCalleeReturn = - getLatticeFor(call.getOperation(), predecessor); + getLatticeFor(getProgramPointAfter(call.getOperation()), + getProgramPointAfter(predecessor)); visitCallControlFlowTransfer(call, CallControlFlowAction::ExitCallee, *latticeAtCalleeReturn, latticeAfterCall); } @@ -95,24 +96,24 @@ void AbstractDenseForwardDataFlowAnalysis::visitCallOperation( LogicalResult AbstractDenseForwardDataFlowAnalysis::processOperation(Operation *op) { + ProgramPoint *point = getProgramPointAfter(op); // If the containing block is not executable, bail out. - if (!getOrCreateFor<Executable>(op, op->getBlock())->isLive()) + if (op->getBlock() != nullptr && + !getOrCreateFor<Executable>(point, getProgramPointBefore(op->getBlock())) + ->isLive()) return success(); // Get the dense lattice to update. - AbstractDenseLattice *after = getLattice(op); + AbstractDenseLattice *after = getLattice(point); // Get the dense state before the execution of the op. - const AbstractDenseLattice *before; - if (Operation *prev = op->getPrevNode()) - before = getLatticeFor(op, prev); - else - before = getLatticeFor(op, op->getBlock()); + const AbstractDenseLattice *before = + getLatticeFor(point, getProgramPointBefore(op)); // If this op implements region control-flow, then control-flow dictates its // transfer function. if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) { - visitRegionBranchOperation(op, branch, after); + visitRegionBranchOperation(point, branch, after); return success(); } @@ -129,11 +130,12 @@ AbstractDenseForwardDataFlowAnalysis::processOperation(Operation *op) { void AbstractDenseForwardDataFlowAnalysis::visitBlock(Block *block) { // If the block is not executable, bail out. - if (!getOrCreateFor<Executable>(block, block)->isLive()) + ProgramPoint *point = getProgramPointBefore(block); + if (!getOrCreateFor<Executable>(point, point)->isLive()) return; // Get the dense lattice to update. - AbstractDenseLattice *after = getLattice(block); + AbstractDenseLattice *after = getLattice(point); // The dense lattices of entry blocks are set by region control-flow or the // callgraph. @@ -141,7 +143,8 @@ void AbstractDenseForwardDataFlowAnalysis::visitBlock(Block *block) { // Check if this block is the entry block of a callable region. auto callable = dyn_cast<CallableOpInterface>(block->getParentOp()); if (callable && callable.getCallableRegion() == block->getParent()) { - const auto *callsites = getOrCreateFor<PredecessorState>(block, callable); + const auto *callsites = getOrCreateFor<PredecessorState>( + point, getProgramPointAfter(callable)); // If not all callsites are known, conservatively mark all lattices as // having reached their pessimistic fixpoints. Do the same if // interprocedural analysis is not enabled. @@ -151,10 +154,7 @@ void AbstractDenseForwardDataFlowAnalysis::visitBlock(Block *block) { for (Operation *callsite : callsites->getKnownPredecessors()) { // Get the dense lattice before the callsite. const AbstractDenseLattice *before; - if (Operation *prev = callsite->getPrevNode()) - before = getLatticeFor(block, prev); - else - before = getLatticeFor(block, callsite->getBlock()); + before = getLatticeFor(point, getProgramPointBefore(callsite)); visitCallControlFlowTransfer(cast<CallOpInterface>(callsite), CallControlFlowAction::EnterCallee, @@ -165,7 +165,7 @@ void AbstractDenseForwardDataFlowAnalysis::visitBlock(Block *block) { // Check if we can reason about the control-flow. if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) - return visitRegionBranchOperation(block, branch, after); + return visitRegionBranchOperation(point, branch, after); // Otherwise, we can't reason about the data-flow. return setToEntryState(after); @@ -177,17 +177,18 @@ void AbstractDenseForwardDataFlowAnalysis::visitBlock(Block *block) { // Skip control edges that aren't executable. Block *predecessor = *it; if (!getOrCreateFor<Executable>( - block, getLatticeAnchor<CFGEdge>(predecessor, block)) + point, getLatticeAnchor<CFGEdge>(predecessor, block)) ->isLive()) continue; // Merge in the state from the predecessor's terminator. - join(after, *getLatticeFor(block, predecessor->getTerminator())); + join(after, *getLatticeFor( + point, getProgramPointAfter(predecessor->getTerminator()))); } } void AbstractDenseForwardDataFlowAnalysis::visitRegionBranchOperation( - ProgramPoint point, RegionBranchOpInterface branch, + ProgramPoint *point, RegionBranchOpInterface branch, AbstractDenseLattice *after) { // Get the terminator predecessors. const auto *predecessors = getOrCreateFor<PredecessorState>(point, point); @@ -198,19 +199,15 @@ void AbstractDenseForwardDataFlowAnalysis::visitRegionBranchOperation( const AbstractDenseLattice *before; // If the predecessor is the parent, get the state before the parent. if (op == branch) { - if (Operation *prev = op->getPrevNode()) - before = getLatticeFor(point, prev); - else - before = getLatticeFor(point, op->getBlock()); - + before = getLatticeFor(point, getProgramPointBefore(op)); // Otherwise, get the state after the terminator. } else { - before = getLatticeFor(point, op); + before = getLatticeFor(point, getProgramPointAfter(op)); } // This function is called in two cases: - // 1. when visiting the block (point = block); - // 2. when visiting the parent operation (point = parent op). + // 1. when visiting the block (point = block start); + // 2. when visiting the parent operation (point = iter after parent op). // In both cases, we are looking for predecessor operations of the point, // 1. predecessor may be the terminator of another block from another // region (assuming that the block does belong to another region via an @@ -224,12 +221,12 @@ void AbstractDenseForwardDataFlowAnalysis::visitRegionBranchOperation( std::optional<unsigned> regionFrom = op == branch ? std::optional<unsigned>() : op->getBlock()->getParent()->getRegionNumber(); - if (auto *toBlock = point.dyn_cast<Block *>()) { - unsigned regionTo = toBlock->getParent()->getRegionNumber(); + if (point->isBlockStart()) { + unsigned regionTo = point->getBlock()->getParent()->getRegionNumber(); visitRegionBranchControlFlowTransfer(branch, regionFrom, regionTo, *before, after); } else { - assert(point.get<Operation *>() == branch && + assert(point->getPrevOp() == branch && "expected to be visiting the branch itself"); // Only need to call the arc transfer when the predecessor is the region // or the op itself, not the previous op. @@ -244,7 +241,7 @@ void AbstractDenseForwardDataFlowAnalysis::visitRegionBranchOperation( } const AbstractDenseLattice * -AbstractDenseForwardDataFlowAnalysis::getLatticeFor(ProgramPoint dependent, +AbstractDenseForwardDataFlowAnalysis::getLatticeFor(ProgramPoint *dependent, LatticeAnchor anchor) { AbstractDenseLattice *state = getLattice(anchor); addDependency(state, dependent); @@ -273,10 +270,11 @@ AbstractDenseBackwardDataFlowAnalysis::initialize(Operation *top) { return success(); } -LogicalResult AbstractDenseBackwardDataFlowAnalysis::visit(ProgramPoint point) { - if (auto *op = llvm::dyn_cast_if_present<Operation *>(point)) - return processOperation(op); - visitBlock(point.get<Block *>()); +LogicalResult +AbstractDenseBackwardDataFlowAnalysis::visit(ProgramPoint *point) { + if (!point->isBlockEnd()) + return processOperation(point->getNextOp()); + visitBlock(point->getBlock()); return success(); } @@ -316,11 +314,9 @@ void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation( // ... // } Block *calleeEntryBlock = ®ion->front(); - ProgramPoint calleeEntry = calleeEntryBlock->empty() - ? ProgramPoint(calleeEntryBlock) - : &calleeEntryBlock->front(); + ProgramPoint *calleeEntry = getProgramPointBefore(calleeEntryBlock); const AbstractDenseLattice &latticeAtCalleeEntry = - *getLatticeFor(call.getOperation(), calleeEntry); + *getLatticeFor(getProgramPointBefore(call.getOperation()), calleeEntry); AbstractDenseLattice *latticeBeforeCall = before; visitCallControlFlowTransfer(call, CallControlFlowAction::EnterCallee, latticeAtCalleeEntry, latticeBeforeCall); @@ -328,23 +324,24 @@ void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation( LogicalResult AbstractDenseBackwardDataFlowAnalysis::processOperation(Operation *op) { + ProgramPoint *point = getProgramPointBefore(op); // If the containing block is not executable, bail out. - if (!getOrCreateFor<Executable>(op, op->getBlock())->isLive()) + if (op->getBlock() != nullptr && + !getOrCreateFor<Executable>(point, getProgramPointBefore(op->getBlock())) + ->isLive()) return success(); // Get the dense lattice to update. - AbstractDenseLattice *before = getLattice(op); + AbstractDenseLattice *before = getLattice(point); // Get the dense state after execution of this op. - const AbstractDenseLattice *after; - if (Operation *next = op->getNextNode()) - after = getLatticeFor(op, next); - else - after = getLatticeFor(op, op->getBlock()); + const AbstractDenseLattice *after = + getLatticeFor(point, getProgramPointAfter(op)); // Special cases where control flow may dictate data flow. if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) { - visitRegionBranchOperation(op, branch, RegionBranchPoint::parent(), before); + visitRegionBranchOperation(point, branch, RegionBranchPoint::parent(), + before); return success(); } if (auto call = dyn_cast<CallOpInterface>(op)) { @@ -357,11 +354,13 @@ AbstractDenseBackwardDataFlowAnalysis::processOperation(Operation *op) { } void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) { + ProgramPoint *point = getProgramPointAfter(block); // If the block is not executable, bail out. - if (!getOrCreateFor<Executable>(block, block)->isLive()) + if (!getOrCreateFor<Executable>(point, getProgramPointBefore(block)) + ->isLive()) return; - AbstractDenseLattice *before = getLattice(block); + AbstractDenseLattice *before = getLattice(point); // We need "exit" blocks, i.e. the blocks that may return control to the // parent operation. @@ -382,7 +381,8 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) { // themselves are predecessors of the callable. auto callable = dyn_cast<CallableOpInterface>(block->getParentOp()); if (callable && callable.getCallableRegion() == block->getParent()) { - const auto *callsites = getOrCreateFor<PredecessorState>(block, callable); + const auto *callsites = getOrCreateFor<PredecessorState>( + point, getProgramPointAfter(callable)); // If not all call sites are known, conservative mark all lattices as // having reached their pessimistic fix points. if (!callsites->allPredecessorsKnown() || @@ -391,11 +391,8 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) { } for (Operation *callsite : callsites->getKnownPredecessors()) { - const AbstractDenseLattice *after; - if (Operation *next = callsite->getNextNode()) - after = getLatticeFor(block, next); - else - after = getLatticeFor(block, callsite->getBlock()); + const AbstractDenseLattice *after = + getLatticeFor(point, getProgramPointAfter(callsite)); visitCallControlFlowTransfer(cast<CallOpInterface>(callsite), CallControlFlowAction::ExitCallee, *after, before); @@ -406,7 +403,7 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) { // If this block is exiting from an operation with region-based control // flow, propagate the lattice back along the control flow edge. if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) { - visitRegionBranchOperation(block, branch, block->getParent(), before); + visitRegionBranchOperation(point, branch, block->getParent(), before); return; } @@ -417,22 +414,19 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) { // Meet the state with the state before block's successors. for (Block *successor : block->getSuccessors()) { - if (!getOrCreateFor<Executable>(block, + if (!getOrCreateFor<Executable>(point, getLatticeAnchor<CFGEdge>(block, successor)) ->isLive()) continue; // Merge in the state from the successor: either the first operation, or the // block itself when empty. - if (successor->empty()) - meet(before, *getLatticeFor(block, successor)); - else - meet(before, *getLatticeFor(block, &successor->front())); + meet(before, *getLatticeFor(point, getProgramPointBefore(successor))); } } void AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchOperation( - ProgramPoint point, RegionBranchOpInterface branch, + ProgramPoint *point, RegionBranchOpInterface branch, RegionBranchPoint branchPoint, AbstractDenseLattice *before) { // The successors of the operation may be either the first operation of the @@ -443,22 +437,18 @@ void AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchOperation( for (const RegionSuccessor &successor : successors) { const AbstractDenseLattice *after; if (successor.isParent() || successor.getSuccessor()->empty()) { - if (Operation *next = branch->getNextNode()) - after = getLatticeFor(point, next); - else - after = getLatticeFor(point, branch->getBlock()); + after = getLatticeFor(point, getProgramPointAfter(branch)); } else { Region *successorRegion = successor.getSuccessor(); assert(!successorRegion->empty() && "unexpected empty successor region"); Block *successorBlock = &successorRegion->front(); - if (!getOrCreateFor<Executable>(point, successorBlock)->isLive()) + if (!getOrCreateFor<Executable>(point, + getProgramPointBefore(successorBlock)) + ->isLive()) continue; - if (successorBlock->empty()) - after = getLatticeFor(point, successorBlock); - else - after = getLatticeFor(point, &successorBlock->front()); + after = getLatticeFor(point, getProgramPointBefore(successorBlock)); } visitRegionBranchControlFlowTransfer(branch, branchPoint, successor, *after, @@ -467,7 +457,7 @@ void AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchOperation( } const AbstractDenseLattice * -AbstractDenseBackwardDataFlowAnalysis::getLatticeFor(ProgramPoint dependent, +AbstractDenseBackwardDataFlowAnalysis::getLatticeFor(ProgramPoint *dependent, LatticeAnchor anchor) { AbstractDenseLattice *state = getLattice(anchor); addDependency(state, dependent); diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp index 9a95f17..bf9eabb 100644 --- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp @@ -111,7 +111,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments( LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n"); auto argRanges = llvm::map_to_vector(op->getOperands(), [&](Value value) { - return getLatticeElementFor(op, value)->getValue(); + return getLatticeElementFor(getProgramPointAfter(op), value)->getValue(); }); auto joinCallback = [&](Value v, const IntegerValueRange &attrs) { @@ -159,7 +159,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments( return bound.getValue(); } else if (auto value = llvm::dyn_cast_if_present<Value>(*loopBound)) { const IntegerValueRangeLattice *lattice = - getLatticeElementFor(op, value); + getLatticeElementFor(getProgramPointAfter(op), value); if (lattice != nullptr && !lattice->getValue().isUninitialized()) return getUpper ? lattice->getValue().getValue().smax() : lattice->getValue().getValue().smin(); diff --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp index 57a4d4a..9fb4d9d 100644 --- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp @@ -87,7 +87,7 @@ LivenessAnalysis::visitOperation(Operation *op, ArrayRef<Liveness *> operands, meet(operand, *r); foundLiveResult = true; } - addDependency(const_cast<Liveness *>(r), op); + addDependency(const_cast<Liveness *>(r), getProgramPointAfter(op)); } return success(); } diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp index 1bd6def..67cf8c9 100644 --- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp @@ -36,7 +36,7 @@ void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const { // Push all users of the value to the queue. for (Operation *user : anchor.get<Value>().getUsers()) for (DataFlowAnalysis *analysis : useDefSubscribers) - solver->enqueue({user, analysis}); + solver->enqueue({solver->getProgramPointAfter(user), analysis}); } //===----------------------------------------------------------------------===// @@ -72,7 +72,8 @@ AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) { for (Region ®ion : op->getRegions()) { for (Block &block : region) { - getOrCreate<Executable>(&block)->blockContentSubscribe(this); + getOrCreate<Executable>(getProgramPointBefore(&block)) + ->blockContentSubscribe(this); visitBlock(&block); for (Operation &op : block) if (failed(initializeRecursively(&op))) @@ -83,10 +84,11 @@ AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) { return success(); } -LogicalResult AbstractSparseForwardDataFlowAnalysis::visit(ProgramPoint point) { - if (Operation *op = llvm::dyn_cast_if_present<Operation *>(point)) - return visitOperation(op); - visitBlock(point.get<Block *>()); +LogicalResult +AbstractSparseForwardDataFlowAnalysis::visit(ProgramPoint *point) { + if (!point->isBlockStart()) + return visitOperation(point->getPrevOp()); + visitBlock(point->getBlock()); return success(); } @@ -97,7 +99,8 @@ AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) { return success(); // If the containing block is not executable, bail out. - if (!getOrCreate<Executable>(op->getBlock())->isLive()) + if (op->getBlock() != nullptr && + !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))->isLive()) return success(); // Get the result lattices. @@ -110,7 +113,7 @@ AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) { // The results of a region branch operation are determined by control-flow. if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) { - visitRegionSuccessors({branch}, branch, + visitRegionSuccessors(getProgramPointAfter(branch), branch, /*successor=*/RegionBranchPoint::parent(), resultLattices); return success(); @@ -138,7 +141,8 @@ AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) { // Otherwise, the results of a call operation are determined by the // callgraph. - const auto *predecessors = getOrCreateFor<PredecessorState>(op, call); + const auto *predecessors = getOrCreateFor<PredecessorState>( + getProgramPointAfter(op), getProgramPointAfter(call)); // If not all return sites are known, then conservatively assume we can't // reason about the data-flow. if (!predecessors->allPredecessorsKnown()) { @@ -148,7 +152,8 @@ AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) { for (Operation *predecessor : predecessors->getKnownPredecessors()) for (auto &&[operand, resLattice] : llvm::zip(predecessor->getOperands(), resultLattices)) - join(resLattice, *getLatticeElementFor(op, operand)); + join(resLattice, + *getLatticeElementFor(getProgramPointAfter(op), operand)); return success(); } @@ -162,7 +167,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) { return; // If the block is not executable, bail out. - if (!getOrCreate<Executable>(block)->isLive()) + if (!getOrCreate<Executable>(getProgramPointBefore(block))->isLive()) return; // Get the argument lattices. @@ -179,7 +184,8 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) { // Check if this block is the entry block of a callable region. auto callable = dyn_cast<CallableOpInterface>(block->getParentOp()); if (callable && callable.getCallableRegion() == block->getParent()) { - const auto *callsites = getOrCreateFor<PredecessorState>(block, callable); + const auto *callsites = getOrCreateFor<PredecessorState>( + getProgramPointBefore(block), getProgramPointAfter(callable)); // If not all callsites are known, conservatively mark all lattices as // having reached their pessimistic fixpoints. if (!callsites->allPredecessorsKnown() || @@ -189,15 +195,17 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) { for (Operation *callsite : callsites->getKnownPredecessors()) { auto call = cast<CallOpInterface>(callsite); for (auto it : llvm::zip(call.getArgOperands(), argLattices)) - join(std::get<1>(it), *getLatticeElementFor(block, std::get<0>(it))); + join(std::get<1>(it), + *getLatticeElementFor(getProgramPointBefore(block), + std::get<0>(it))); } return; } // Check if the lattices can be determined from region control flow. if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) { - return visitRegionSuccessors(block, branch, block->getParent(), - argLattices); + return visitRegionSuccessors(getProgramPointBefore(block), branch, + block->getParent(), argLattices); } // Otherwise, we can't reason about the data-flow. @@ -226,7 +234,8 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) { branch.getSuccessorOperands(it.getSuccessorIndex()); for (auto [idx, lattice] : llvm::enumerate(argLattices)) { if (Value operand = operands[idx]) { - join(lattice, *getLatticeElementFor(block, operand)); + join(lattice, + *getLatticeElementFor(getProgramPointBefore(block), operand)); } else { // Conservatively consider internally produced arguments as entry // points. @@ -240,7 +249,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) { } void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors( - ProgramPoint point, RegionBranchOpInterface branch, + ProgramPoint *point, RegionBranchOpInterface branch, RegionBranchPoint successor, ArrayRef<AbstractSparseLattice *> lattices) { const auto *predecessors = getOrCreateFor<PredecessorState>(point, point); assert(predecessors->allPredecessorsKnown() && @@ -270,7 +279,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors( unsigned firstIndex = 0; if (inputs.size() != lattices.size()) { - if (llvm::dyn_cast_if_present<Operation *>(point)) { + if (!point->isBlockStart()) { if (!inputs.empty()) firstIndex = cast<OpResult>(inputs.front()).getResultNumber(); visitNonControlFlowArgumentsImpl( @@ -281,7 +290,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors( } else { if (!inputs.empty()) firstIndex = cast<BlockArgument>(inputs.front()).getArgNumber(); - Region *region = point.get<Block *>()->getParent(); + Region *region = point->getBlock()->getParent(); visitNonControlFlowArgumentsImpl( branch, RegionSuccessor(region, region->getArguments().slice( @@ -296,7 +305,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors( } const AbstractSparseLattice * -AbstractSparseForwardDataFlowAnalysis::getLatticeElementFor(ProgramPoint point, +AbstractSparseForwardDataFlowAnalysis::getLatticeElementFor(ProgramPoint *point, Value value) { AbstractSparseLattice *state = getLatticeElement(value); addDependency(state, point); @@ -336,7 +345,8 @@ AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(Operation *op) { for (Region ®ion : op->getRegions()) { for (Block &block : region) { - getOrCreate<Executable>(&block)->blockContentSubscribe(this); + getOrCreate<Executable>(getProgramPointBefore(&block)) + ->blockContentSubscribe(this); // Initialize ops in reverse order, so we can do as much initial // propagation as possible without having to go through the // solver queue. @@ -349,14 +359,14 @@ AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(Operation *op) { } LogicalResult -AbstractSparseBackwardDataFlowAnalysis::visit(ProgramPoint point) { - if (Operation *op = llvm::dyn_cast_if_present<Operation *>(point)) - return visitOperation(op); +AbstractSparseBackwardDataFlowAnalysis::visit(ProgramPoint *point) { // For backward dataflow, we don't have to do any work for the blocks // themselves. CFG edges between blocks are processed by the BranchOp // logic in `visitOperation`, and entry blocks for functions are tied // to the CallOp arguments by visitOperation. - return success(); + if (point->isBlockStart()) + return success(); + return visitOperation(point->getPrevOp()); } SmallVector<AbstractSparseLattice *> @@ -372,7 +382,7 @@ AbstractSparseBackwardDataFlowAnalysis::getLatticeElements(ValueRange values) { SmallVector<const AbstractSparseLattice *> AbstractSparseBackwardDataFlowAnalysis::getLatticeElementsFor( - ProgramPoint point, ValueRange values) { + ProgramPoint *point, ValueRange values) { SmallVector<const AbstractSparseLattice *> resultLattices; resultLattices.reserve(values.size()); for (Value result : values) { @@ -390,13 +400,14 @@ static MutableArrayRef<OpOperand> operandsToOpOperands(OperandRange &operands) { LogicalResult AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) { // If we're in a dead block, bail out. - if (!getOrCreate<Executable>(op->getBlock())->isLive()) + if (op->getBlock() != nullptr && + !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))->isLive()) return success(); SmallVector<AbstractSparseLattice *> operandLattices = getLatticeElements(op->getOperands()); SmallVector<const AbstractSparseLattice *> resultLattices = - getLatticeElementsFor(op, op->getResults()); + getLatticeElementsFor(getProgramPointAfter(op), op->getResults()); // Block arguments of region branch operations flow back into the operands // of the parent op @@ -425,7 +436,7 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) { detail::getBranchSuccessorArgument( successorOperands, operand.getOperandNumber(), block)) { meet(getLatticeElement(operand.get()), - *getLatticeElementFor(op, *blockArg)); + *getLatticeElementFor(getProgramPointAfter(op), *blockArg)); } } } @@ -467,7 +478,7 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) { for (auto [blockArg, argOpOperand] : llvm::zip(block.getArguments(), argOpOperands)) { meet(getLatticeElement(argOpOperand.get()), - *getLatticeElementFor(op, blockArg)); + *getLatticeElementFor(getProgramPointAfter(op), blockArg)); unaccounted.reset(argOpOperand.getOperandNumber()); } @@ -502,12 +513,13 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) { // Going backwards, the operands of the return are derived from the // results of all CallOps calling this CallableOp. if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) { - const PredecessorState *callsites = - getOrCreateFor<PredecessorState>(op, callable); + const PredecessorState *callsites = getOrCreateFor<PredecessorState>( + getProgramPointAfter(op), getProgramPointAfter(callable)); if (callsites->allPredecessorsKnown()) { for (Operation *call : callsites->getKnownPredecessors()) { SmallVector<const AbstractSparseLattice *> callResultLattices = - getLatticeElementsFor(op, call->getResults()); + getLatticeElementsFor(getProgramPointAfter(op), + call->getResults()); for (auto [op, result] : llvm::zip(operandLattices, callResultLattices)) meet(op, *result); @@ -542,7 +554,8 @@ void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors( MutableArrayRef<OpOperand> opoperands = operandsToOpOperands(operands); ValueRange inputs = successor.getSuccessorInputs(); for (auto [operand, input] : llvm::zip(opoperands, inputs)) { - meet(getLatticeElement(operand.get()), *getLatticeElementFor(op, input)); + meet(getLatticeElement(operand.get()), + *getLatticeElementFor(getProgramPointAfter(op), input)); unaccounted.reset(operand.getOperandNumber()); } } @@ -576,7 +589,7 @@ void AbstractSparseBackwardDataFlowAnalysis:: MutableArrayRef<OpOperand> opOperands = operandsToOpOperands(operands); for (auto [opOperand, input] : llvm::zip(opOperands, inputs)) { meet(getLatticeElement(opOperand.get()), - *getLatticeElementFor(terminator, input)); + *getLatticeElementFor(getProgramPointAfter(terminator), input)); unaccounted.reset(const_cast<OpOperand &>(opOperand).getOperandNumber()); } } @@ -588,8 +601,8 @@ void AbstractSparseBackwardDataFlowAnalysis:: } const AbstractSparseLattice * -AbstractSparseBackwardDataFlowAnalysis::getLatticeElementFor(ProgramPoint point, - Value value) { +AbstractSparseBackwardDataFlowAnalysis::getLatticeElementFor( + ProgramPoint *point, Value value) { AbstractSparseLattice *state = getLatticeElement(value); addDependency(state, point); return state; diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp index a65ddc1..7e83668 100644 --- a/mlir/lib/Analysis/DataFlowFramework.cpp +++ b/mlir/lib/Analysis/DataFlowFramework.cpp @@ -37,7 +37,7 @@ GenericLatticeAnchor::~GenericLatticeAnchor() = default; AnalysisState::~AnalysisState() = default; -void AnalysisState::addDependency(ProgramPoint dependent, +void AnalysisState::addDependency(ProgramPoint *dependent, DataFlowAnalysis *analysis) { auto inserted = dependents.insert({dependent, analysis}); (void)inserted; @@ -53,7 +53,7 @@ void AnalysisState::addDependency(ProgramPoint dependent, void AnalysisState::dump() const { print(llvm::errs()); } //===----------------------------------------------------------------------===// -// LatticeAnchor +// ProgramPoint //===----------------------------------------------------------------------===// void ProgramPoint::print(raw_ostream &os) const { @@ -61,12 +61,18 @@ void ProgramPoint::print(raw_ostream &os) const { os << "<NULL POINT>"; return; } - if (Operation *op = llvm::dyn_cast<Operation *>(*this)) { - return op->print(os, OpPrintingFlags().skipRegions()); + if (!isBlockStart()) { + os << "<after operation>:"; + return getPrevOp()->print(os, OpPrintingFlags().skipRegions()); } - return get<Block *>()->print(os); + os << "<before operation>:"; + return getNextOp()->print(os, OpPrintingFlags().skipRegions()); } +//===----------------------------------------------------------------------===// +// LatticeAnchor +//===----------------------------------------------------------------------===// + void LatticeAnchor::print(raw_ostream &os) const { if (isNull()) { os << "<NULL POINT>"; @@ -78,7 +84,7 @@ void LatticeAnchor::print(raw_ostream &os) const { return value.print(os, OpPrintingFlags().skipRegions()); } - return get<ProgramPoint>().print(os); + return get<ProgramPoint *>()->print(os); } Location LatticeAnchor::getLoc() const { @@ -87,10 +93,10 @@ Location LatticeAnchor::getLoc() const { if (auto value = llvm::dyn_cast<Value>(*this)) return value.getLoc(); - ProgramPoint pp = get<ProgramPoint>(); - if (auto *op = llvm::dyn_cast<Operation *>(pp)) - return op->getLoc(); - return pp.get<Block *>()->getParent()->getLoc(); + ProgramPoint *pp = get<ProgramPoint *>(); + if (!pp->isBlockStart()) + return pp->getPrevOp()->getLoc(); + return pp->getBlock()->getParent()->getLoc(); } //===----------------------------------------------------------------------===// @@ -144,7 +150,8 @@ DataFlowAnalysis::~DataFlowAnalysis() = default; DataFlowAnalysis::DataFlowAnalysis(DataFlowSolver &solver) : solver(solver) {} -void DataFlowAnalysis::addDependency(AnalysisState *state, ProgramPoint point) { +void DataFlowAnalysis::addDependency(AnalysisState *state, + ProgramPoint *point) { state->addDependency(point, this); } diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp index 8005f91..521138c 100644 --- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp @@ -80,7 +80,7 @@ public: protected: void notifyOperationErased(Operation *op) override { - s.eraseState(op); + s.eraseState(s.getProgramPointAfter(op)); for (Value res : op->getResults()) s.eraseState(res); } diff --git a/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp index d02efaa..2dc77c9 100644 --- a/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp +++ b/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp @@ -29,7 +29,8 @@ static void printAnalysisResults(DataFlowSolver &solver, Operation *op, os << " "; block.printAsOperand(os); os << " = "; - auto *live = solver.lookupState<Executable>(&block); + auto *live = solver.lookupState<Executable>( + solver.getProgramPointBefore(&block)); if (live) os << *live; else @@ -49,12 +50,14 @@ static void printAnalysisResults(DataFlowSolver &solver, Operation *op, } } if (!region.empty()) { - auto *preds = solver.lookupState<PredecessorState>(®ion.front()); + auto *preds = solver.lookupState<PredecessorState>( + solver.getProgramPointBefore(®ion.front())); if (preds) os << "region_preds: " << *preds << "\n"; } } - auto *preds = solver.lookupState<PredecessorState>(op); + auto *preds = + solver.lookupState<PredecessorState>(solver.getProgramPointAfter(op)); if (preds) os << "op_preds: " << *preds << "\n"; }); @@ -68,15 +71,15 @@ struct ConstantAnalysis : public DataFlowAnalysis { LogicalResult initialize(Operation *top) override { WalkResult result = top->walk([&](Operation *op) { - if (failed(visit(op))) + if (failed(visit(getProgramPointAfter(op)))) return WalkResult::interrupt(); return WalkResult::advance(); }); return success(!result.wasInterrupted()); } - LogicalResult visit(ProgramPoint point) override { - Operation *op = point.get<Operation *>(); + LogicalResult visit(ProgramPoint *point) override { + Operation *op = point->getPrevOp(); Attribute value; if (matchPattern(op, m_Constant(&value))) { auto *constant = getOrCreate<Lattice<ConstantValue>>(op->getResult(0)); diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp index 6794cbb..fa6223a 100644 --- a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp +++ b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp @@ -115,7 +115,8 @@ LogicalResult NextAccessAnalysis::visitOperation(Operation *op, std::optional<Value> underlyingValue = UnderlyingValueAnalysis::getMostUnderlyingValue( value, [&](Value value) { - return getOrCreateFor<UnderlyingValueLattice>(op, value); + return getOrCreateFor<UnderlyingValueLattice>( + getProgramPointBefore(op), value); }); // If the underlying value is not known yet, don't propagate. @@ -151,7 +152,7 @@ void NextAccessAnalysis::visitCallControlFlowTransfer( UnderlyingValueAnalysis::getMostUnderlyingValue( operand, [&](Value value) { return getOrCreateFor<UnderlyingValueLattice>( - call.getOperation(), value); + getProgramPointBefore(call.getOperation()), value); }); if (!underlyingValue) return; @@ -283,9 +284,8 @@ struct TestNextAccessPass if (!tag) return; - const NextAccess *nextAccess = solver.lookupState<NextAccess>( - op->getNextNode() == nullptr ? ProgramPoint(op->getBlock()) - : op->getNextNode()); + const NextAccess *nextAccess = + solver.lookupState<NextAccess>(solver.getProgramPointAfter(op)); op->setAttr(kNextAccessAttrName, makeNextAccessAttribute(op, solver, nextAccess)); @@ -300,9 +300,8 @@ struct TestNextAccessPass if (!successor.getSuccessor() || successor.getSuccessor()->empty()) continue; Block &successorBlock = successor.getSuccessor()->front(); - ProgramPoint successorPoint = successorBlock.empty() - ? ProgramPoint(&successorBlock) - : &successorBlock.front(); + ProgramPoint *successorPoint = + solver.getProgramPointBefore(&successorBlock); entryPointNextAccess.push_back(makeNextAccessAttribute( op, solver, solver.lookupState<NextAccess>(successorPoint))); } diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp index 301d2a20..89b5c83 100644 --- a/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp +++ b/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp @@ -117,7 +117,8 @@ LogicalResult LastModifiedAnalysis::visitOperation( std::optional<Value> underlyingValue = UnderlyingValueAnalysis::getMostUnderlyingValue( value, [&](Value value) { - return getOrCreateFor<UnderlyingValueLattice>(op, value); + return getOrCreateFor<UnderlyingValueLattice>( + getProgramPointAfter(op), value); }); // If the underlying value is not yet known, don't propagate yet. @@ -157,7 +158,7 @@ void LastModifiedAnalysis::visitCallControlFlowTransfer( UnderlyingValueAnalysis::getMostUnderlyingValue( operand, [&](Value value) { return getOrCreateFor<UnderlyingValueLattice>( - call.getOperation(), value); + getProgramPointAfter(call.getOperation()), value); }); if (!underlyingValue) return; @@ -243,7 +244,7 @@ struct TestLastModifiedPass return; os << "test_tag: " << tag.getValue() << ":\n"; const LastModification *lastMods = - solver.lookupState<LastModification>(op); + solver.lookupState<LastModification>(solver.getProgramPointAfter(op)); assert(lastMods && "expected a dense lattice"); for (auto [index, operand] : llvm::enumerate(op->getOperands())) { os << " operand #" << index << "\n"; diff --git a/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp index 2445b58..1f3cab1 100644 --- a/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp +++ b/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp @@ -108,7 +108,7 @@ WrittenToAnalysis::visitOperation(Operation *op, ArrayRef<WrittenTo *> operands, for (WrittenTo *operand : operands) { meet(operand, *r); } - addDependency(const_cast<WrittenTo *>(r), op); + addDependency(const_cast<WrittenTo *>(r), getProgramPointAfter(op)); } return success(); } diff --git a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp index 9573ec1..3eb39fc 100644 --- a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp +++ b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp @@ -75,7 +75,7 @@ public: using DataFlowAnalysis::DataFlowAnalysis; LogicalResult initialize(Operation *top) override; - LogicalResult visit(ProgramPoint point) override; + LogicalResult visit(ProgramPoint *point) override; private: void visitBlock(Block *block); @@ -100,7 +100,8 @@ LogicalResult FooAnalysis::initialize(Operation *top) { return top->emitError("expected at least one block in the region"); // Initialize the top-level state. - (void)getOrCreate<FooState>(&top->getRegion(0).front())->join(0); + (void)getOrCreate<FooState>(getProgramPointBefore(&top->getRegion(0).front())) + ->join(0); // Visit all nested blocks and operations. for (Block &block : top->getRegion(0)) { @@ -114,11 +115,11 @@ LogicalResult FooAnalysis::initialize(Operation *top) { return success(); } -LogicalResult FooAnalysis::visit(ProgramPoint point) { - if (auto *op = llvm::dyn_cast_if_present<Operation *>(point)) - visitOperation(op); +LogicalResult FooAnalysis::visit(ProgramPoint *point) { + if (!point->isBlockStart()) + visitOperation(point->getPrevOp()); else - visitBlock(point.get<Block *>()); + visitBlock(point->getBlock()); return success(); } @@ -127,27 +128,26 @@ void FooAnalysis::visitBlock(Block *block) { // This is the initial state. Let the framework default-initialize it. return; } - FooState *state = getOrCreate<FooState>(block); + ProgramPoint *point = getProgramPointBefore(block); + FooState *state = getOrCreate<FooState>(point); ChangeResult result = ChangeResult::NoChange; for (Block *pred : block->getPredecessors()) { // Join the state at the terminators of all predecessors. - const FooState *predState = - getOrCreateFor<FooState>(block, pred->getTerminator()); + const FooState *predState = getOrCreateFor<FooState>( + point, getProgramPointAfter(pred->getTerminator())); result |= state->join(*predState); } propagateIfChanged(state, result); } void FooAnalysis::visitOperation(Operation *op) { - FooState *state = getOrCreate<FooState>(op); + ProgramPoint *point = getProgramPointAfter(op); + FooState *state = getOrCreate<FooState>(point); ChangeResult result = ChangeResult::NoChange; // Copy the state across the operation. const FooState *prevState; - if (Operation *prev = op->getPrevNode()) - prevState = getOrCreateFor<FooState>(op, prev); - else - prevState = getOrCreateFor<FooState>(op, op->getBlock()); + prevState = getOrCreateFor<FooState>(point, getProgramPointBefore(op)); result |= state->set(*prevState); // Modify the state with the attribute, if specified. @@ -172,7 +172,8 @@ void TestFooAnalysisPass::runOnOperation() { auto tag = op->getAttrOfType<StringAttr>("tag"); if (!tag) return; - const FooState *state = solver.lookupState<FooState>(op); + const FooState *state = + solver.lookupState<FooState>(solver.getProgramPointAfter(op)); assert(state && !state->isUninitialized()); os << tag.getValue() << " -> " << state->getValue() << "\n"; }); |