//===- FoldSubviewOps.cpp - AMDGPU fold subview ops -----------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/AMDGPU/Transforms/Passes.h" #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" #include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" #include "mlir/Transforms/WalkPatternRewriteDriver.h" #include "llvm/ADT/TypeSwitch.h" namespace mlir::amdgpu { #define GEN_PASS_DEF_AMDGPUFOLDMEMREFOPSPASS #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc" struct AmdgpuFoldMemRefOpsPass final : amdgpu::impl::AmdgpuFoldMemRefOpsPassBase { void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateAmdgpuFoldMemRefOpsPatterns(patterns); walkAndApplyPatterns(getOperation(), std::move(patterns)); } }; static LogicalResult foldMemrefViewOp(PatternRewriter &rewriter, Location loc, Value view, mlir::OperandRange indices, SmallVectorImpl &resolvedIndices, Value &memrefBase, StringRef role) { Operation *defOp = view.getDefiningOp(); if (!defOp) { return failure(); } return llvm::TypeSwitch(defOp) .Case([&](memref::SubViewOp subviewOp) { mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides( rewriter, loc, subviewOp.getMixedOffsets(), subviewOp.getMixedStrides(), subviewOp.getDroppedDims(), indices, resolvedIndices); memrefBase = subviewOp.getSource(); return success(); }) .Case([&](memref::ExpandShapeOp expandShapeOp) { if (failed(mlir::memref::resolveSourceIndicesExpandShape( loc, rewriter, expandShapeOp, indices, resolvedIndices, false))) { return failure(); } memrefBase = expandShapeOp.getViewSource(); return success(); }) .Case( [&](memref::CollapseShapeOp collapseShapeOp) { if (failed(mlir::memref::resolveSourceIndicesCollapseShape( loc, rewriter, collapseShapeOp, indices, resolvedIndices))) { return failure(); } memrefBase = collapseShapeOp.getViewSource(); return success(); }) .Default([&](Operation *op) { return rewriter.notifyMatchFailure( op, (role + " producer is not one of SubViewOp, ExpandShapeOp, or " "CollapseShapeOp") .str()); }); } struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(GatherToLDSOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); SmallVector sourceIndices, destIndices; Value memrefSource, memrefDest; auto foldSrcResult = foldMemrefViewOp(rewriter, loc, op.getSrc(), op.getSrcIndices(), sourceIndices, memrefSource, "source"); if (failed(foldSrcResult)) { memrefSource = op.getSrc(); sourceIndices = op.getSrcIndices(); } auto foldDstResult = foldMemrefViewOp(rewriter, loc, op.getDst(), op.getDstIndices(), destIndices, memrefDest, "destination"); if (failed(foldDstResult)) { memrefDest = op.getDst(); destIndices = op.getDstIndices(); } rewriter.replaceOpWithNewOp(op, memrefSource, sourceIndices, memrefDest, destIndices, op.getTransferType()); return success(); } }; void populateAmdgpuFoldMemRefOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add(patterns.getContext(), benefit); } } // namespace mlir::amdgpu