//===----- FlattenMemRefs.cpp - MemRef ops flattener pass ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains patterns for flattening multi-rank memref-related ops
// into 1-d memref ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_FLATTENMEMREFSPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir

using namespace mlir;

static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
                                      OpFoldResult in) {
  if (Attribute offsetAttr = dyn_cast<Attribute>(in)) {
    return arith::ConstantIndexOp::create(
        rewriter, loc, cast<IntegerAttr>(offsetAttr).getInt());
  }
  return cast<Value>(in);
}

/// Returns a collapsed memref and the linearized index to access the element
/// at the specified indices.
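///
/// For example (illustrative only): flattening an access into a contiguous
/// `memref<4x8xf32>` at indices (%i, %j) yields a `memref.reinterpret_cast`
/// to a rank-1 `memref<32xf32>` together with the linearized index
/// %i * 8 + %j; the index arithmetic itself is produced by
/// memref::getLinearizedMemRefOffsetAndSize below.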
static std::pair<Value, Value>
getFlattenMemrefAndOffset(OpBuilder &rewriter, Location loc, Value source,
                          ValueRange indices) {
  int64_t sourceOffset;
  SmallVector<int64_t> sourceStrides;
  auto sourceType = cast<MemRefType>(source.getType());
  if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset))) {
    assert(false);
  }

  memref::ExtractStridedMetadataOp stridedMetadata =
      memref::ExtractStridedMetadataOp::create(rewriter, loc, source);

  auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth();
  OpFoldResult linearizedIndices;
  memref::LinearizedMemRefInfo linearizedInfo;
  std::tie(linearizedInfo, linearizedIndices) =
      memref::getLinearizedMemRefOffsetAndSize(
          rewriter, loc, typeBit, typeBit,
          stridedMetadata.getConstifiedMixedOffset(),
          stridedMetadata.getConstifiedMixedSizes(),
          stridedMetadata.getConstifiedMixedStrides(),
          getAsOpFoldResult(indices));

  return std::make_pair(
      memref::ReinterpretCastOp::create(
          rewriter, loc, source,
          /* offset = */ linearizedInfo.linearizedOffset,
          /* shapes = */ ArrayRef<OpFoldResult>{linearizedInfo.linearizedSize},
          /* strides = */ ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)}),
      getValueFromOpFoldResult(rewriter, loc, linearizedIndices));
}

static bool needFlattening(Value val) {
  auto type = cast<MemRefType>(val.getType());
  return type.getRank() > 1;
}

static bool checkLayout(Value val) {
  auto type = cast<MemRefType>(val.getType());
  return type.getLayout().isIdentity() ||
         isa<StridedLayoutAttr>(type.getLayout());
}

namespace {
static Value getTargetMemref(Operation *op) {
  return llvm::TypeSwitch<Operation *, Value>(op)
      .template Case<memref::LoadOp, memref::StoreOp, memref::AllocaOp,
                     memref::AllocOp>([](auto op) { return op.getMemref(); })
      .template Case<vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp,
                     vector::MaskedStoreOp, vector::TransferReadOp,
                     vector::TransferWriteOp>(
          [](auto op) { return op.getBase(); })
      .Default([](auto) { return Value{}; });
}

template <typename T>
static void castAllocResult(T oper, T newOper, Location loc,
                            PatternRewriter &rewriter) {
  memref::ExtractStridedMetadataOp stridedMetadata =
      memref::ExtractStridedMetadataOp::create(rewriter, loc, oper);
  rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
      oper, cast<MemRefType>(oper.getType()), newOper,
      /*offset=*/rewriter.getIndexAttr(0),
      stridedMetadata.getConstifiedMixedSizes(),
      stridedMetadata.getConstifiedMixedStrides());
}

template <typename T>
static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
                      Value offset) {
  Location loc = op->getLoc();
  llvm::TypeSwitch<Operation *>(op.getOperation())
      .template Case<memref::AllocOp>([&](auto oper) {
        auto newAlloc = memref::AllocOp::create(
            rewriter, loc, cast<MemRefType>(flatMemref.getType()),
            oper.getAlignmentAttr());
        castAllocResult(oper, newAlloc, loc, rewriter);
      })
      .template Case<memref::AllocaOp>([&](auto oper) {
        auto newAlloca = memref::AllocaOp::create(
            rewriter, loc, cast<MemRefType>(flatMemref.getType()),
            oper.getAlignmentAttr());
        castAllocResult(oper, newAlloca, loc, rewriter);
      })
      .template Case<memref::LoadOp>([&](auto op) {
        auto newLoad = memref::LoadOp::create(
            rewriter, loc, op->getResultTypes(), flatMemref,
            ValueRange{offset});
        newLoad->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newLoad.getResult());
      })
      .template Case<memref::StoreOp>([&](auto op) {
        auto newStore = memref::StoreOp::create(
            rewriter, loc, op->getOperands().front(), flatMemref,
            ValueRange{offset});
        newStore->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newStore);
      })
      .template Case<vector::LoadOp>([&](auto op) {
        auto newLoad = vector::LoadOp::create(
            rewriter, loc, op->getResultTypes(), flatMemref,
            ValueRange{offset});
        newLoad->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newLoad.getResult());
      })
      .template Case<vector::StoreOp>([&](auto op) {
        auto newStore = vector::StoreOp::create(
            rewriter, loc, op->getOperands().front(), flatMemref,
            ValueRange{offset});
        newStore->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newStore);
      })
      .template Case<vector::MaskedLoadOp>([&](auto op) {
        auto newMaskedLoad = vector::MaskedLoadOp::create(
            rewriter, loc,
            op.getType(), flatMemref, ValueRange{offset}, op.getMask(),
            op.getPassThru());
        newMaskedLoad->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newMaskedLoad.getResult());
      })
      .template Case<vector::MaskedStoreOp>([&](auto op) {
        auto newMaskedStore = vector::MaskedStoreOp::create(
            rewriter, loc, flatMemref, ValueRange{offset}, op.getMask(),
            op.getValueToStore());
        newMaskedStore->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newMaskedStore);
      })
      .template Case<vector::TransferReadOp>([&](auto op) {
        auto newTransferRead = vector::TransferReadOp::create(
            rewriter, loc, op.getType(), flatMemref, ValueRange{offset},
            op.getPadding());
        rewriter.replaceOp(op, newTransferRead.getResult());
      })
      .template Case<vector::TransferWriteOp>([&](auto op) {
        auto newTransferWrite = vector::TransferWriteOp::create(
            rewriter, loc, op.getVector(), flatMemref, ValueRange{offset});
        rewriter.replaceOp(op, newTransferWrite);
      })
      .Default([&](auto op) {
        op->emitOpError("unimplemented: do not know how to replace op.");
      });
}

template <typename T>
static ValueRange getIndices(T op) {
  if constexpr (std::is_same_v<T, memref::AllocaOp> ||
                std::is_same_v<T, memref::AllocOp>) {
    return ValueRange{};
  } else {
    return op.getIndices();
  }
}

template <typename T>
static LogicalResult canBeFlattened(T op, PatternRewriter &rewriter) {
  return llvm::TypeSwitch<Operation *, LogicalResult>(op.getOperation())
      .template Case<vector::TransferReadOp, vector::TransferWriteOp>(
          [&](auto oper) {
            // For vector.transfer_read/write, must make sure:
            // 1. all accesses are inbound, and
            // 2. has an identity or minor identity permutation map.
            auto permutationMap = oper.getPermutationMap();
            if (!permutationMap.isIdentity() &&
                !permutationMap.isMinorIdentity()) {
              return rewriter.notifyMatchFailure(
                  oper, "only identity permutation map is supported");
            }
            mlir::ArrayAttr inbounds = oper.getInBounds();
            if (llvm::any_of(inbounds, [](Attribute attr) {
                  return !cast<BoolAttr>(attr).getValue();
                })) {
              return rewriter.notifyMatchFailure(oper,
                                                 "only inbounds are supported");
            }
            return success();
          })
      .Default([&](auto op) { return success(); });
}

template <typename T>
struct MemRefRewritePattern : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;
  LogicalResult matchAndRewrite(T op,
                                PatternRewriter &rewriter) const override {
    LogicalResult canFlatten = canBeFlattened(op, rewriter);
    if (failed(canFlatten)) {
      return canFlatten;
    }

    Value memref = getTargetMemref(op);
    if (!needFlattening(memref) || !checkLayout(memref))
      return failure();
    auto &&[flatMemref, offset] = getFlattenMemrefAndOffset(
        rewriter, op->getLoc(), memref, getIndices<T>(op));
    replaceOp<T>(op, rewriter, flatMemref, offset);
    return success();
  }
};

struct FlattenMemrefsPass
    : public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
  using Base::Base;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<affine::AffineDialect, arith::ArithDialect,
                    memref::MemRefDialect, vector::VectorDialect>();
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());

    memref::populateFlattenMemrefsPatterns(patterns);

    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void memref::populateFlattenVectorOpsOnMemrefPatterns(
    RewritePatternSet &patterns) {
  patterns.insert<MemRefRewritePattern<vector::LoadOp>,
                  MemRefRewritePattern<vector::StoreOp>,
                  MemRefRewritePattern<vector::MaskedLoadOp>,
                  MemRefRewritePattern<vector::MaskedStoreOp>,
                  MemRefRewritePattern<vector::TransferReadOp>,
                  MemRefRewritePattern<vector::TransferWriteOp>>(
      patterns.getContext());
}

void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
  patterns.insert<MemRefRewritePattern<memref::LoadOp>,
                  MemRefRewritePattern<memref::StoreOp>,
                  MemRefRewritePattern<memref::AllocOp>,
                  MemRefRewritePattern<memref::AllocaOp>>(
      patterns.getContext());
}

void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
  populateFlattenMemrefOpsPatterns(patterns);
  populateFlattenVectorOpsOnMemrefPatterns(patterns);
}
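
// Usage sketch (illustrative, not part of the pass itself): downstream code
// that wants these rewrites without scheduling the full FlattenMemrefsPass can
// pull in either pattern set and drive it with the greedy rewriter, e.g.
//
//   RewritePatternSet patterns(ctx);
//   memref::populateFlattenMemrefOpsPatterns(patterns);
//   if (failed(applyPatternsGreedily(op, std::move(patterns))))
//     /* handle failure */;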