//===- BuiltinAttributeInterfaces.cpp -------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" #include "llvm/ADT/Sequence.h" using namespace mlir; using namespace mlir::detail; //===----------------------------------------------------------------------===// /// Tablegen Interface Definitions //===----------------------------------------------------------------------===// #include "mlir/IR/BuiltinAttributeInterfaces.cpp.inc" //===----------------------------------------------------------------------===// // ElementsAttr //===----------------------------------------------------------------------===// Type ElementsAttr::getElementType(ElementsAttr elementsAttr) { return elementsAttr.getShapedType().getElementType(); } int64_t ElementsAttr::getNumElements(ElementsAttr elementsAttr) { return elementsAttr.getShapedType().getNumElements(); } bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef index) { // Verify that the rank of the indices matches the held type. int64_t rank = type.getRank(); if (rank == 0 && index.size() == 1 && index[0] == 0) return true; if (rank != static_cast(index.size())) return false; // Verify that all of the indices are within the shape dimensions. 
ArrayRef shape = type.getShape(); return llvm::all_of(llvm::seq(0, rank), [&](int i) { int64_t dim = static_cast(index[i]); return 0 <= dim && dim < shape[i]; }); } bool ElementsAttr::isValidIndex(ElementsAttr elementsAttr, ArrayRef index) { return isValidIndex(elementsAttr.getShapedType(), index); } uint64_t ElementsAttr::getFlattenedIndex(Type type, ArrayRef index) { ShapedType shapeType = llvm::cast(type); assert(isValidIndex(shapeType, index) && "expected valid multi-dimensional index"); // Reduce the provided multidimensional index into a flattended 1D row-major // index. auto rank = shapeType.getRank(); ArrayRef shape = shapeType.getShape(); uint64_t valueIndex = 0; uint64_t dimMultiplier = 1; for (int i = rank - 1; i >= 0; --i) { valueIndex += index[i] * dimMultiplier; dimMultiplier *= shape[i]; } return valueIndex; } //===----------------------------------------------------------------------===// // MemRefLayoutAttrInterface //===----------------------------------------------------------------------===// LogicalResult mlir::detail::verifyAffineMapAsLayout( AffineMap m, ArrayRef shape, function_ref emitError) { if (m.getNumDims() != shape.size()) return emitError() << "memref layout mismatch between rank and affine map: " << shape.size() << " != " << m.getNumDims(); return success(); } // Fallback cases for terminal dim/sym/cst that are not part of a binary op ( // i.e. single term). Accumulate the AffineExpr into the existing one. static void extractStridesFromTerm(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef strides, AffineExpr &offset) { if (auto dim = dyn_cast(e)) strides[dim.getPosition()] = strides[dim.getPosition()] + multiplicativeFactor; else offset = offset + e * multiplicativeFactor; } /// Takes a single AffineExpr `e` and populates the `strides` array with the /// strides expressions for each dim position. /// The convention is that the strides for dimensions d0, .. 
dn appear in /// order to make indexing intuitive into the result. static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef strides, AffineExpr &offset) { auto bin = dyn_cast(e); if (!bin) { extractStridesFromTerm(e, multiplicativeFactor, strides, offset); return success(); } if (bin.getKind() == AffineExprKind::CeilDiv || bin.getKind() == AffineExprKind::FloorDiv || bin.getKind() == AffineExprKind::Mod) return failure(); if (bin.getKind() == AffineExprKind::Mul) { auto dim = dyn_cast(bin.getLHS()); if (dim) { strides[dim.getPosition()] = strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor; return success(); } // LHS and RHS may both contain complex expressions of dims. Try one path // and if it fails try the other. This is guaranteed to succeed because // only one path may have a `dim`, otherwise this is not an AffineExpr in // the first place. if (bin.getLHS().isSymbolicOrConstant()) return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(), strides, offset); return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(), strides, offset); } if (bin.getKind() == AffineExprKind::Add) { auto res1 = extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset); auto res2 = extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset); return success(succeeded(res1) && succeeded(res2)); } llvm_unreachable("unexpected binary operation"); } /// A stride specification is a list of integer values that are either static /// or dynamic (encoded with ShapedType::kDynamic). Strides encode /// the distance in the number of elements between successive entries along a /// particular dimension. 
///
/// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a
/// non-contiguous memory region of `42` by `16` `f32` elements in which the
/// distance between two consecutive elements along the outer dimension is `64`
/// and the distance between two consecutive elements along the inner dimension
/// is `1`.
///
/// The convention is that the strides for dimensions d0, .. dn appear in
/// order to make indexing intuitive into the result.
///
/// If `m` is neither the identity nor a single-result map, returns failure
/// without touching `strides`/`offset`. On any later failure, `offset` is
/// reset to a null AffineExpr and `strides` is cleared.
static LogicalResult getStridesAndOffset(AffineMap m, ArrayRef<int64_t> shape,
                                         SmallVectorImpl<AffineExpr> &strides,
                                         AffineExpr &offset) {
  if (m.getNumResults() != 1 && !m.isIdentity())
    return failure();

  // Puts the outputs in a well-defined "no result" state and fails.
  auto resetAndFail = [&]() -> LogicalResult {
    offset = AffineExpr();
    strides.clear();
    return failure();
  };

  auto zero = getAffineConstantExpr(0, m.getContext());
  auto one = getAffineConstantExpr(1, m.getContext());
  offset = zero;
  strides.assign(shape.size(), zero);

  // Canonical case for the identity layout map: derive the strides from the
  // canonical strided layout expression of `shape`.
  if (m.isIdentity()) {
    // 0-D corner case, offset is already 0.
    if (shape.empty())
      return success();
    auto stridedExpr = makeCanonicalStridedLayoutExpr(shape, m.getContext());
    if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
      return success();
    assert(false && "unexpected failure: extract strides in canonical layout");
    // In release builds, do not fall through into the non-canonical path with
    // partially accumulated `strides`/`offset`.
    return resetAndFail();
  }

  // Non-canonical case requires more work. `extractStrides` indexes `strides`
  // (sized to `shape`) by dim position, so a map with more dims than `shape`
  // has entries cannot be processed without writing out of bounds.
  if (m.getNumDims() > shape.size())
    return resetAndFail();

  auto stridedExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (failed(extractStrides(stridedExpr, one, strides, offset)))
    return resetAndFail();

  // Simplify results to allow folding to constants and simple checks.
  unsigned numDims = m.getNumDims();
  unsigned numSymbols = m.getNumSymbols();
  offset = simplifyAffineExpr(offset, numDims, numSymbols);
  for (auto &stride : strides)
    stride = simplifyAffineExpr(stride, numDims, numSymbols);

  return success();
}

/// Computes the integer strides and offset of layout `map` for `shape`.
/// Entries whose expression does not fold to a constant are reported as
/// ShapedType::kDynamic. The strides are appended to `strides` (it is not
/// cleared first). On failure, `strides` and `offset` are left untouched.
LogicalResult mlir::detail::getAffineMapStridesAndOffset(
    AffineMap map, ArrayRef<int64_t> shape, SmallVectorImpl<int64_t> &strides,
    int64_t &offset) {
  AffineExpr offsetExpr;
  SmallVector<AffineExpr, 4> strideExprs;
  if (failed(::getStridesAndOffset(map, shape, strideExprs, offsetExpr)))
    return failure();

  // Constant expressions are static values; anything else is dynamic.
  auto toStaticOrDynamic = [](AffineExpr e) -> int64_t {
    if (auto cst = llvm::dyn_cast<AffineConstantExpr>(e))
      return cst.getValue();
    return ShapedType::kDynamic;
  };

  offset = toStaticOrDynamic(offsetExpr);
  strides.reserve(strides.size() + strideExprs.size());
  for (AffineExpr e : strideExprs)
    strides.push_back(toStaticOrDynamic(e));
  return success();
}