//===- StridedMetadataRangeAnalysis.cpp - Integer range analysis --------*- C++ //-*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file defines the dataflow analysis class for integer range inference // which is used in transformations over the `arith` dialect such as // branch elimination or signed->unsigned rewriting // //===----------------------------------------------------------------------===// #include "mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h" #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/IR/Operation.h" #include "mlir/IR/Value.h" #include "mlir/Support/DebugStringHelper.h" #include "llvm/Support/Debug.h" #include "llvm/Support/DebugLog.h" #define DEBUG_TYPE "strided-metadata-range-analysis" using namespace mlir; using namespace mlir::dataflow; /// Get the entry state for a value. For any value that is not a ranked memref, /// this function sets the metadata to a top state with no offsets, sizes, or /// strides. For `memref` types, this function will use the metadata in the type /// to try to deduce as much informaiton as possible. static StridedMetadataRange getEntryStateImpl(Value v, int32_t indexBitwidth) { // TODO: generalize this method with a type interface. auto mTy = dyn_cast(v.getType()); // If not a memref or it's un-ranked, don't infer any metadata. if (!mTy || !mTy.hasRank()) return StridedMetadataRange::getMaxRanges(indexBitwidth, 0, 0, 0); // Get the top state. auto metadata = StridedMetadataRange::getMaxRanges(indexBitwidth, mTy.getRank()); // Compute the offset and strides. int64_t offset; SmallVector strides; if (failed(cast(mTy).getStridesAndOffset(strides, offset))) return metadata; // Refine the metadata if we know it from the type. if (!ShapedType::isDynamic(offset)) { metadata.getOffsets()[0] = ConstantIntRanges::constant(APInt(indexBitwidth, offset)); } for (auto &&[size, range] : llvm::zip_equal(mTy.getShape(), metadata.getSizes())) { if (ShapedType::isDynamic(size)) continue; range = ConstantIntRanges::constant(APInt(indexBitwidth, size)); } for (auto &&[stride, range] : llvm::zip_equal(strides, metadata.getStrides())) { if (ShapedType::isDynamic(stride)) continue; range = ConstantIntRanges::constant(APInt(indexBitwidth, stride)); } return metadata; } StridedMetadataRangeAnalysis::StridedMetadataRangeAnalysis( DataFlowSolver &solver, int32_t indexBitwidth) : SparseForwardDataFlowAnalysis(solver), indexBitwidth(indexBitwidth) { assert(indexBitwidth > 0 && "invalid bitwidth"); } void StridedMetadataRangeAnalysis::setToEntryState( StridedMetadataRangeLattice *lattice) { propagateIfChanged(lattice, lattice->join(getEntryStateImpl( lattice->getAnchor(), indexBitwidth))); } LogicalResult StridedMetadataRangeAnalysis::visitOperation( Operation *op, ArrayRef operands, ArrayRef results) { auto inferrable = dyn_cast(op); // Bail if we cannot reason about the op. if (!inferrable) { setAllToEntryStates(results); return success(); } LDBG() << "Inferring metadata for: " << OpWithFlags(op, OpPrintingFlags().skipRegions()); // Helper function to retrieve int range values. auto getIntRange = [&](Value value) -> IntegerValueRange { auto lattice = getOrCreateFor( getProgramPointAfter(op), value); return lattice ? lattice->getValue() : IntegerValueRange(); }; // Convert the arguments lattices to a vector. SmallVector argRanges = llvm::map_to_vector( operands, [](const StridedMetadataRangeLattice *lattice) { return lattice->getValue(); }); // Callback to set metadata on a result. auto joinCallback = [&](Value v, const StridedMetadataRange &md) { auto result = cast(v); assert(llvm::is_contained(op->getResults(), result)); LDBG() << "- Inferred metadata: " << md; StridedMetadataRangeLattice *lattice = results[result.getResultNumber()]; ChangeResult changed = lattice->join(md); LDBG() << "- Joined metadata: " << lattice->getValue(); propagateIfChanged(lattice, changed); }; // Infer the metadata. inferrable.inferStridedMetadataRanges(argRanges, getIntRange, joinCallback, indexBitwidth); return success(); }