//===- IndexingUtils.cpp - Helpers related to index computations ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Utils/IndexingUtils.h"

#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/STLExtras.h"

#include <numeric>
#include <optional>

using namespace mlir;

template <typename ExprType>
SmallVector<ExprType> computeSuffixProductImpl(ArrayRef<ExprType> sizes,
                                               ExprType unit) {
  if (sizes.empty())
    return {};
  SmallVector<ExprType> strides(sizes.size(), unit);
  for (int64_t r = strides.size() - 2; r >= 0; --r)
    strides[r] = strides[r + 1] * sizes[r + 1];
  return strides;
}

template <typename ExprType>
SmallVector<ExprType> computeElementwiseMulImpl(ArrayRef<ExprType> v1,
                                                ArrayRef<ExprType> v2) {
  // Early exit if both are empty; let zip_equal fail if only one is empty.
  if (v1.empty() && v2.empty())
    return {};
  SmallVector<ExprType> result;
  for (auto it : llvm::zip_equal(v1, v2))
    result.push_back(std::get<0>(it) * std::get<1>(it));
  return result;
}

template <typename ExprType>
ExprType linearizeImpl(ArrayRef<ExprType> offsets, ArrayRef<ExprType> basis,
                       ExprType zero) {
  assert(offsets.size() == basis.size());
  ExprType linearIndex = zero;
  for (unsigned idx = 0, e = basis.size(); idx < e; ++idx)
    linearIndex = linearIndex + offsets[idx] * basis[idx];
  return linearIndex;
}

template <typename ExprType, typename DivOpTy>
SmallVector<ExprType> delinearizeImpl(ExprType linearIndex,
                                      ArrayRef<ExprType> strides,
                                      DivOpTy divOp) {
  int64_t rank = strides.size();
  SmallVector<ExprType> offsets(rank);
  for (int64_t r = 0; r < rank; ++r) {
    offsets[r] = divOp(linearIndex, strides[r]);
    linearIndex = linearIndex % strides[r];
  }
  return offsets;
}

//===----------------------------------------------------------------------===//
// Utils that operate on static integer values.
//===----------------------------------------------------------------------===//
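
// Worked example (illustrative): for row-major sizes {2, 3, 4}, the utilities
// below compose as follows.
//
//   SmallVector<int64_t> strides = computeSuffixProduct({2, 3, 4});
//   // strides == {12, 4, 1}
//   int64_t linear = linearize({1, 2, 3}, strides);
//   // linear == 1 * 12 + 2 * 4 + 3 * 1 == 23
//   SmallVector<int64_t> offsets = delinearize(linear, strides);
//   // offsets == {1, 2, 3}, recovering the original multi-index.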

SmallVector<int64_t> mlir::computeSuffixProduct(ArrayRef<int64_t> sizes) {
  assert(llvm::all_of(sizes, [](int64_t s) { return s >= 0; }) &&
         "sizes must be nonnegative");
  int64_t unit = 1;
  return ::computeSuffixProductImpl(sizes, unit);
}

SmallVector<int64_t> mlir::computeElementwiseMul(ArrayRef<int64_t> v1,
                                                 ArrayRef<int64_t> v2) {
  return computeElementwiseMulImpl(v1, v2);
}

int64_t mlir::computeSum(ArrayRef<int64_t> basis) {
  assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
         "basis must be strictly positive");
  if (basis.empty())
    return 0;
  return std::accumulate(basis.begin(), basis.end(), int64_t(0),
                         std::plus<int64_t>());
}

int64_t mlir::computeProduct(ArrayRef<int64_t> basis) {
  assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
         "basis must be strictly positive");
  if (basis.empty())
    return 0;
  return std::accumulate(basis.begin(), basis.end(), int64_t(1),
                         std::multiplies<int64_t>());
}

int64_t mlir::linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
  assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
         "basis must be strictly positive");
  int64_t zero = 0;
  return linearizeImpl(offsets, basis, zero);
}

SmallVector<int64_t> mlir::delinearize(int64_t linearIndex,
                                       ArrayRef<int64_t> strides) {
  assert(llvm::all_of(strides, [](int64_t s) { return s > 0; }) &&
         "strides must be strictly positive");
  return delinearizeImpl(linearIndex, strides,
                         [](int64_t e1, int64_t e2) { return e1 / e2; });
}

std::optional<SmallVector<int64_t>>
mlir::computeShapeRatio(ArrayRef<int64_t> shape, ArrayRef<int64_t> subShape) {
  if (shape.size() < subShape.size())
    return std::nullopt;
  assert(llvm::all_of(shape, [](int64_t s) { return s > 0; }) &&
         "shape must be strictly positive");
  assert(llvm::all_of(subShape, [](int64_t s) { return s > 0; }) &&
         "subShape must be strictly positive");

  // Starting from the end, compute the integer divisors.
  std::vector<int64_t> result;
  result.reserve(shape.size());
  for (auto [size, subSize] :
       llvm::zip(llvm::reverse(shape), llvm::reverse(subShape))) {
    // If integral division does not occur, return and let the caller decide.
    if (size % subSize != 0)
      return std::nullopt;
    result.push_back(size / subSize);
  }

  // At this point we computed the ratio (in reverse) for the common size.
  // Fill with the remaining entries from the shape (still in reverse).
  int commonSize = subShape.size();
  std::copy(shape.rbegin() + commonSize, shape.rend(),
            std::back_inserter(result));

  // Reverse again to get it back in the proper order and return.
  return SmallVector<int64_t>{result.rbegin(), result.rend()};
}

//===----------------------------------------------------------------------===//
// Utils that operate on AffineExpr.
//===----------------------------------------------------------------------===//
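
// Worked example (illustrative): the AffineExpr overloads mirror the integer
// versions above but build symbolic expressions. Assuming the usual MLIR
// affine-expression simplification, delinearizing a symbol over strides
// {4, 1} gives:
//
//   MLIRContext ctx;
//   AffineExpr s0 = getAffineSymbolExpr(0, &ctx);
//   SmallVector<AffineExpr> exprs =
//       delinearize(s0, ArrayRef<int64_t>{4, 1});
//   // exprs == {s0 floordiv 4, s0 mod 4}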

SmallVector<AffineExpr> mlir::computeSuffixProduct(ArrayRef<AffineExpr> sizes) {
  if (sizes.empty())
    return {};
  AffineExpr unit = getAffineConstantExpr(1, sizes.front().getContext());
  return ::computeSuffixProductImpl(sizes, unit);
}

SmallVector<AffineExpr> mlir::computeElementwiseMul(ArrayRef<AffineExpr> v1,
                                                    ArrayRef<AffineExpr> v2) {
  return computeElementwiseMulImpl(v1, v2);
}

AffineExpr mlir::computeSum(MLIRContext *ctx, ArrayRef<AffineExpr> basis) {
  if (basis.empty())
    return getAffineConstantExpr(0, ctx);
  return std::accumulate(basis.begin(), basis.end(),
                         getAffineConstantExpr(0, ctx),
                         std::plus<AffineExpr>());
}

AffineExpr mlir::computeProduct(MLIRContext *ctx, ArrayRef<AffineExpr> basis) {
  if (basis.empty())
    return getAffineConstantExpr(1, ctx);
  return std::accumulate(basis.begin(), basis.end(),
                         getAffineConstantExpr(1, ctx),
                         std::multiplies<AffineExpr>());
}

AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
                           ArrayRef<AffineExpr> basis) {
  AffineExpr zero = getAffineConstantExpr(0, ctx);
  return linearizeImpl(offsets, basis, zero);
}

AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
                           ArrayRef<int64_t> basis) {
  return linearize(ctx, offsets, getAffineConstantExprs(basis, ctx));
}

SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
                                          ArrayRef<AffineExpr> strides) {
  return delinearizeImpl(
      linearIndex, strides,
      [](AffineExpr e1, AffineExpr e2) { return e1.floorDiv(e2); });
}

SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
                                          ArrayRef<int64_t> strides) {
  MLIRContext *ctx = linearIndex.getContext();
  return delinearize(linearIndex, getAffineConstantExprs(strides, ctx));
}

//===----------------------------------------------------------------------===//
// Permutation utils.
//===----------------------------------------------------------------------===//

SmallVector<int64_t>
mlir::invertPermutationVector(ArrayRef<int64_t> permutation) {
  assert(llvm::all_of(permutation, [](int64_t s) { return s >= 0; }) &&
         "permutation must be non-negative");
  SmallVector<int64_t> inversion(permutation.size());
  for (const auto &pos : llvm::enumerate(permutation)) {
    inversion[pos.value()] = pos.index();
  }
  return inversion;
}

bool mlir::isPermutationVector(ArrayRef<int64_t> interchange) {
  assert(llvm::all_of(interchange, [](int64_t s) { return s >= 0; }) &&
         "permutation must be non-negative");
  llvm::SmallDenseSet<int64_t, 4> seenVals;
  for (auto val : interchange) {
    if (seenVals.count(val))
      return false;
    seenVals.insert(val);
  }
  return seenVals.size() == interchange.size();
}

SmallVector<int64_t>
mlir::computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
                               ArrayRef<int64_t> desiredPositions) {
  SmallVector<int64_t> res(permSize, -1);
  DenseSet<int64_t> seen;
  for (auto [pos, desiredPos] : llvm::zip_equal(positions, desiredPositions)) {
    res[desiredPos] = pos;
    seen.insert(pos);
  }
  int64_t nextPos = 0;
  for (int64_t &entry : res) {
    if (entry != -1)
      continue;
    while (seen.contains(nextPos))
      ++nextPos;
    entry = nextPos;
    ++nextPos;
  }
  return res;
}

SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
                                          unsigned dropFront,
                                          unsigned dropBack) {
  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
  auto range = arrayAttr.getAsRange<IntegerAttr>();
  SmallVector<int64_t> res;
  res.reserve(arrayAttr.size() - dropFront - dropBack);
  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
       it != eit; ++it)
    res.push_back((*it).getValue().getSExtValue());
  return res;
}
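
// Worked examples (illustrative) for the permutation utilities above:
//
//   invertPermutationVector({2, 0, 1});  // returns {1, 2, 0}
//   isPermutationVector({2, 0, 1});      // returns true
//   computePermutationVector(4, /*positions=*/{1, 3},
//                            /*desiredPositions=*/{0, 2});
//   // returns {1, 0, 3, 2}: positions 1 and 3 land in slots 0 and 2; the
//   // remaining slots are filled with the smallest values not yet used.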

// TODO: do we have any common utility for this?
static MLIRContext *getContext(OpFoldResult val) {
  assert(val && "Invalid value");
  if (auto attr = dyn_cast<Attribute>(val))
    return attr.getContext();
  return cast<Value>(val).getContext();
}

std::pair<AffineExpr, SmallVector<OpFoldResult>>
mlir::computeLinearIndex(OpFoldResult sourceOffset,
                         ArrayRef<OpFoldResult> strides,
                         ArrayRef<OpFoldResult> indices) {
  assert(strides.size() == indices.size());
  auto sourceRank = static_cast<unsigned>(strides.size());

  // Hold the affine symbols and values for the computation of the offset.
  SmallVector<OpFoldResult> values(2 * sourceRank + 1);
  SmallVector<AffineExpr> symbols(2 * sourceRank + 1);

  bindSymbolsList(getContext(sourceOffset), MutableArrayRef{symbols});
  AffineExpr expr = symbols.front();
  values[0] = sourceOffset;

  for (unsigned i = 0; i < sourceRank; ++i) {
    // Compute the stride.
    OpFoldResult origStride = strides[i];

    // Build up the computation of the offset.
    unsigned baseIdxForDim = 1 + 2 * i;
    unsigned subOffsetForDim = baseIdxForDim;
    unsigned origStrideForDim = baseIdxForDim + 1;
    expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
    values[subOffsetForDim] = indices[i];
    values[origStrideForDim] = origStride;
  }

  return {expr, values};
}

//===----------------------------------------------------------------------===//
// TileOffsetRange
//===----------------------------------------------------------------------===//

/// Apply left-padding by 1 to the tile shape if required.
static SmallVector<int64_t> padTileShapeToSize(ArrayRef<int64_t> tileShape,
                                               unsigned paddedSize) {
  assert(tileShape.size() <= paddedSize &&
         "expected tileShape size to be <= paddedSize");
  if (tileShape.size() == paddedSize)
    return to_vector(tileShape);
  SmallVector<int64_t> result(paddedSize - tileShape.size(), 1);
  llvm::append_range(result, tileShape);
  return result;
}

mlir::detail::TileOffsetRangeImpl::TileOffsetRangeImpl(
    ArrayRef<int64_t> shape, ArrayRef<int64_t> tileShape,
    ArrayRef<int64_t> loopOrder)
    : tileShape(padTileShapeToSize(tileShape, shape.size())),
      inverseLoopOrder(invertPermutationVector(loopOrder)),
      sliceStrides(shape.size()) {
  // Divide the shape by the tile shape.
  std::optional<SmallVector<int64_t>> shapeRatio =
      mlir::computeShapeRatio(shape, tileShape);
  assert(shapeRatio && shapeRatio->size() == shape.size() &&
         "target shape does not evenly divide the original shape");
  assert(isPermutationVector(loopOrder) && loopOrder.size() == shape.size() &&
         "expected loop order to be a permutation of rank equal to outer "
         "shape");

  maxLinearIndex = mlir::computeMaxLinearIndex(*shapeRatio);
  mlir::applyPermutationToVector(*shapeRatio, loopOrder);
  sliceStrides = mlir::computeStrides(*shapeRatio);
}

SmallVector<int64_t> mlir::detail::TileOffsetRangeImpl::getStaticTileOffsets(
    int64_t linearIndex) const {
  SmallVector<int64_t> tileCoords = applyPermutation(
      delinearize(linearIndex, sliceStrides), inverseLoopOrder);
  return computeElementwiseMul(tileCoords, tileShape);
}

SmallVector<AffineExpr>
mlir::detail::TileOffsetRangeImpl::getDynamicTileOffsets(
    AffineExpr linearIndex) const {
  MLIRContext *ctx = linearIndex.getContext();
  SmallVector<AffineExpr> tileCoords = applyPermutation(
      delinearize(linearIndex, sliceStrides), inverseLoopOrder);
  return mlir::computeElementwiseMul(tileCoords,
                                     getAffineConstantExprs(tileShape, ctx));
}
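
// Worked example (illustrative): tiling a 4x4 shape by 2x2 tiles with the
// identity loop order {0, 1}. The shape ratio is {2, 2}, so sliceStrides ==
// computeStrides({2, 2}) == {2, 1} and maxLinearIndex == 4. The static tile
// offsets then enumerate as:
//
//   getStaticTileOffsets(0) == {0, 0}
//   getStaticTileOffsets(1) == {0, 2}
//   getStaticTileOffsets(2) == {2, 0}
//   getStaticTileOffsets(3) == {2, 2}
//
// i.e. tile coordinates delinearized over the shape ratio and scaled
// elementwise by the tile shape.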