diff options
Diffstat (limited to 'mlir/lib/Analysis')
-rw-r--r-- | mlir/lib/Analysis/CMakeLists.txt | 2 | ||||
-rw-r--r-- | mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp | 127 | ||||
-rw-r--r-- | mlir/lib/Analysis/FlatLinearValueConstraints.cpp | 9 | ||||
-rw-r--r-- | mlir/lib/Analysis/Presburger/Simplex.cpp | 2 |
4 files changed, 135 insertions, 5 deletions
diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt index 609cb34..db10ebc 100644 --- a/mlir/lib/Analysis/CMakeLists.txt +++ b/mlir/lib/Analysis/CMakeLists.txt @@ -40,6 +40,7 @@ add_mlir_library(MLIRAnalysis DataFlow/IntegerRangeAnalysis.cpp DataFlow/LivenessAnalysis.cpp DataFlow/SparseAnalysis.cpp + DataFlow/StridedMetadataRangeAnalysis.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Analysis @@ -53,6 +54,7 @@ add_mlir_library(MLIRAnalysis MLIRDataLayoutInterfaces MLIRFunctionInterfaces MLIRInferIntRangeInterface + MLIRInferStridedMetadataInterface MLIRInferTypeOpInterface MLIRLoopLikeInterface MLIRPresburger diff --git a/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp new file mode 100644 index 0000000..01c9daf --- /dev/null +++ b/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp @@ -0,0 +1,127 @@ +//===- StridedMetadataRangeAnalysis.cpp - Integer range analysis --------*- C++ +//-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the dataflow analysis class for integer range inference +// which is used in transformations over the `arith` dialect such as +// branch elimination or signed->unsigned rewriting +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h" +#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/DebugStringHelper.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" + +#define DEBUG_TYPE "strided-metadata-range-analysis" + +using namespace mlir; +using namespace mlir::dataflow; + +/// Get the entry state for a value. For any value that is not a ranked memref, +/// this function sets the metadata to a top state with no offsets, sizes, or +/// strides. For `memref` types, this function will use the metadata in the type +/// to try to deduce as much informaiton as possible. +static StridedMetadataRange getEntryStateImpl(Value v, int32_t indexBitwidth) { + // TODO: generalize this method with a type interface. + auto mTy = dyn_cast<BaseMemRefType>(v.getType()); + + // If not a memref or it's un-ranked, don't infer any metadata. + if (!mTy || !mTy.hasRank()) + return StridedMetadataRange::getMaxRanges(indexBitwidth, 0, 0, 0); + + // Get the top state. + auto metadata = + StridedMetadataRange::getMaxRanges(indexBitwidth, mTy.getRank()); + + // Compute the offset and strides. + int64_t offset; + SmallVector<int64_t> strides; + if (failed(cast<MemRefType>(mTy).getStridesAndOffset(strides, offset))) + return metadata; + + // Refine the metadata if we know it from the type. + if (!ShapedType::isDynamic(offset)) { + metadata.getOffsets()[0] = + ConstantIntRanges::constant(APInt(indexBitwidth, offset)); + } + for (auto &&[size, range] : + llvm::zip_equal(mTy.getShape(), metadata.getSizes())) { + if (ShapedType::isDynamic(size)) + continue; + range = ConstantIntRanges::constant(APInt(indexBitwidth, size)); + } + for (auto &&[stride, range] : + llvm::zip_equal(strides, metadata.getStrides())) { + if (ShapedType::isDynamic(stride)) + continue; + range = ConstantIntRanges::constant(APInt(indexBitwidth, stride)); + } + + return metadata; +} + +StridedMetadataRangeAnalysis::StridedMetadataRangeAnalysis( + DataFlowSolver &solver, int32_t indexBitwidth) + : SparseForwardDataFlowAnalysis(solver), indexBitwidth(indexBitwidth) { + assert(indexBitwidth > 0 && "invalid bitwidth"); +} + +void StridedMetadataRangeAnalysis::setToEntryState( + StridedMetadataRangeLattice *lattice) { + propagateIfChanged(lattice, lattice->join(getEntryStateImpl( + lattice->getAnchor(), indexBitwidth))); +} + +LogicalResult StridedMetadataRangeAnalysis::visitOperation( + Operation *op, ArrayRef<const StridedMetadataRangeLattice *> operands, + ArrayRef<StridedMetadataRangeLattice *> results) { + auto inferrable = dyn_cast<InferStridedMetadataOpInterface>(op); + + // Bail if we cannot reason about the op. + if (!inferrable) { + setAllToEntryStates(results); + return success(); + } + + LDBG() << "Inferring metadata for: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); + + // Helper function to retrieve int range values. + auto getIntRange = [&](Value value) -> IntegerValueRange { + auto lattice = getOrCreateFor<IntegerValueRangeLattice>( + getProgramPointAfter(op), value); + return lattice ? lattice->getValue() : IntegerValueRange(); + }; + + // Convert the arguments lattices to a vector. + SmallVector<StridedMetadataRange> argRanges = llvm::map_to_vector( + operands, [](const StridedMetadataRangeLattice *lattice) { + return lattice->getValue(); + }); + + // Callback to set metadata on a result. + auto joinCallback = [&](Value v, const StridedMetadataRange &md) { + auto result = cast<OpResult>(v); + assert(llvm::is_contained(op->getResults(), result)); + LDBG() << "- Inferred metadata: " << md; + StridedMetadataRangeLattice *lattice = results[result.getResultNumber()]; + ChangeResult changed = lattice->join(md); + LDBG() << "- Joined metadata: " << lattice->getValue(); + propagateIfChanged(lattice, changed); + }; + + // Infer the metadata. + inferrable.inferStridedMetadataRanges(argRanges, getIntRange, joinCallback, + indexBitwidth); + return success(); +} diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp index 30ce1fb..6588b53 100644 --- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp +++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp @@ -1244,8 +1244,9 @@ bool FlatLinearValueConstraints::areVarsAlignedWithOther( /// Checks if the SSA values associated with `cst`'s variables in range /// [start, end) are unique. -static bool LLVM_ATTRIBUTE_UNUSED areVarsUnique( - const FlatLinearValueConstraints &cst, unsigned start, unsigned end) { +[[maybe_unused]] static bool +areVarsUnique(const FlatLinearValueConstraints &cst, unsigned start, + unsigned end) { assert(start <= cst.getNumDimAndSymbolVars() && "Start position out of bounds"); @@ -1267,14 +1268,14 @@ static bool LLVM_ATTRIBUTE_UNUSED areVarsUnique( } /// Checks if the SSA values associated with `cst`'s variables are unique. -static bool LLVM_ATTRIBUTE_UNUSED +[[maybe_unused]] static bool areVarsUnique(const FlatLinearValueConstraints &cst) { return areVarsUnique(cst, 0, cst.getNumDimAndSymbolVars()); } /// Checks if the SSA values associated with `cst`'s variables of kind `kind` /// are unique. -static bool LLVM_ATTRIBUTE_UNUSED +[[maybe_unused]] static bool areVarsUnique(const FlatLinearValueConstraints &cst, VarKind kind) { if (kind == VarKind::SetDim) diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp index a1cbe29..547a4c2 100644 --- a/mlir/lib/Analysis/Presburger/Simplex.cpp +++ b/mlir/lib/Analysis/Presburger/Simplex.cpp @@ -34,7 +34,7 @@ using Direction = Simplex::Direction; const int nullIndex = std::numeric_limits<int>::max(); // Return a + scale*b; -LLVM_ATTRIBUTE_UNUSED +[[maybe_unused]] static SmallVector<DynamicAPInt, 8> scaleAndAddForAssert(ArrayRef<DynamicAPInt> a, const DynamicAPInt &scale, ArrayRef<DynamicAPInt> b) { |