aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
blob: d54751098410bc5c0556826fec3e245b8b706da7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
//===- FoldMemRefsOps.cpp - AMDGPU fold memref ops ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

namespace mlir::amdgpu {
#define GEN_PASS_DEF_AMDGPUFOLDMEMREFOPSPASS
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"

/// Pass that folds memref view-like producers into AMDGPU ops.
///
/// It collects the patterns registered by populateAmdgpuFoldMemRefOpsPatterns
/// and applies them to the pass's root operation with walkAndApplyPatterns.
struct AmdgpuFoldMemRefOpsPass final
    : amdgpu::impl::AmdgpuFoldMemRefOpsPassBase<AmdgpuFoldMemRefOpsPass> {
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet foldPatterns(context);
    populateAmdgpuFoldMemRefOpsPatterns(foldPatterns);
    walkAndApplyPatterns(getOperation(), std::move(foldPatterns));
  }
};

/// Tries to fold the view-like op that produces `view` into an access that
/// indexes `view` with `indices`.
///
/// Supported producers are `memref.subview`, `memref.expand_shape`, and
/// `memref.collapse_shape`. On success, `memrefBase` is set to the producer's
/// source memref and `resolvedIndices` receives the equivalent indices into
/// that source. Any index arithmetic this requires is created with `rewriter`
/// at `loc`.
///
/// Returns failure in three cases:
/// - `view` is a block argument;
/// - `view` is produced by any other op;
/// - index resolution for an expand/collapse_shape producer fails.
/// Where a reason is available, it is reported through `notifyMatchFailure`,
/// prefixed with `role` (e.g. "source" or "destination"). `memrefBase` is only
/// written on success.
static LogicalResult foldMemrefViewOp(PatternRewriter &rewriter, Location loc,
                                      Value view, mlir::OperandRange indices,
                                      SmallVectorImpl<Value> &resolvedIndices,
                                      Value &memrefBase, StringRef role) {
  Operation *defOp = view.getDefiningOp();
  if (!defOp) {
    // A block argument has no producer that could be folded away; report it
    // like every other non-matching case so debug output explains the miss.
    return rewriter.notifyMatchFailure(
        loc, (role + " is a block argument, not a foldable view").str());
  }
  return llvm::TypeSwitch<Operation *, LogicalResult>(defOp)
      .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
        // Map indices through the subview's offsets/strides, accounting for
        // any rank-reduced (dropped) dimensions.
        mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
            rewriter, loc, subviewOp.getMixedOffsets(),
            subviewOp.getMixedStrides(), subviewOp.getDroppedDims(), indices,
            resolvedIndices);
        memrefBase = subviewOp.getSource();
        return success();
      })
      .Case<memref::ExpandShapeOp>([&](memref::ExpandShapeOp expandShapeOp) {
        // NOTE(review): `false` presumably means the indices are not assumed
        // to be in bounds, i.e. the conservative choice — confirm against
        // MemRefUtils.h.
        if (failed(mlir::memref::resolveSourceIndicesExpandShape(
                loc, rewriter, expandShapeOp, indices, resolvedIndices,
                /*startsInbounds=*/false))) {
          return failure();
        }
        memrefBase = expandShapeOp.getViewSource();
        return success();
      })
      .Case<memref::CollapseShapeOp>(
          [&](memref::CollapseShapeOp collapseShapeOp) {
            if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
                    loc, rewriter, collapseShapeOp, indices,
                    resolvedIndices))) {
              return failure();
            }
            memrefBase = collapseShapeOp.getViewSource();
            return success();
          })
      .Default([&](Operation *op) {
        return rewriter.notifyMatchFailure(
            op, (role + " producer is not one of SubViewOp, ExpandShapeOp, or "
                        "CollapseShapeOp")
                    .str());
      });
}

/// Folds view-like producers (`memref.subview`, `memref.expand_shape`,
/// `memref.collapse_shape`) of the source and/or destination operand of
/// `amdgpu.gather_to_lds` into the op. The indices are rewritten so they
/// address the underlying memref directly.
///
/// The two operands are folded independently: an operand whose producer
/// cannot be folded keeps its original memref and indices. The pattern only
/// succeeds when at least one operand was actually folded.
///
/// Otherwise it reports a match failure rather than rebuilding an identical
/// op. An identical rebuild would signal a change that did not happen; under
/// a fixpoint driver (e.g. the greedy rewriter) the pattern would keep
/// re-matching its own output.
struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(GatherToLDSOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    SmallVector<Value> sourceIndices, destIndices;
    Value memrefSource, memrefDest;

    LogicalResult foldSrcResult =
        foldMemrefViewOp(rewriter, loc, op.getSrc(), op.getSrcIndices(),
                         sourceIndices, memrefSource, "source");
    if (failed(foldSrcResult)) {
      // Source is not a foldable view: keep it unchanged.
      memrefSource = op.getSrc();
      sourceIndices = op.getSrcIndices();
    }

    LogicalResult foldDstResult =
        foldMemrefViewOp(rewriter, loc, op.getDst(), op.getDstIndices(),
                         destIndices, memrefDest, "destination");
    if (failed(foldDstResult)) {
      // Destination is not a foldable view: keep it unchanged.
      memrefDest = op.getDst();
      destIndices = op.getDstIndices();
    }

    // Nothing to fold: bail out instead of replacing `op` with a copy of
    // itself and claiming success.
    if (failed(foldSrcResult) && failed(foldDstResult))
      return rewriter.notifyMatchFailure(
          op, "neither source nor destination is produced by a foldable view");

    rewriter.replaceOpWithNewOp<GatherToLDSOp>(op, memrefSource, sourceIndices,
                                               memrefDest, destIndices,
                                               op.getTransferType());

    return success();
  }
};

/// Registers the patterns that fold memref view-like producers into AMDGPU
/// ops (currently only `amdgpu.gather_to_lds`) into `patterns`, each with
/// the given `benefit`.
void populateAmdgpuFoldMemRefOpsPatterns(RewritePatternSet &patterns,
                                         PatternBenefit benefit) {
  MLIRContext *context = patterns.getContext();
  patterns.add<FoldMemRefOpsIntoGatherToLDSOp>(context, benefit);
}
} // namespace mlir::amdgpu