aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp
blob: f8bab8289cbc66e6287ac202dc6a02e0d19412f8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
//===- ResolveStridedMetadata.cpp - AMDGPU expand_strided_metadata ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::amdgpu {
#define GEN_PASS_DEF_AMDGPURESOLVESTRIDEDMETADATAPASS
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
} // namespace mlir::amdgpu

using namespace mlir;
using namespace mlir::amdgpu;

namespace {
/// Pass object for resolving strided metadata on AMDGPU buffer casts. The
/// base class comes from the `Passes.h.inc` include above (enabled by
/// GEN_PASS_DEF_AMDGPURESOLVESTRIDEDMETADATAPASS).
struct AmdgpuResolveStridedMetadataPass
    : public amdgpu::impl::AmdgpuResolveStridedMetadataPassBase<
          AmdgpuResolveStridedMetadataPass> {
  /// Pass entry point; defined out of line later in this file.
  void runOnOperation() override;
};

/// Rewrites a `memref::ExtractStridedMetadataOp` whose source is produced by a
/// `FatRawBufferCastOp` so that the metadata is read from the cast's source
/// instead.
///
/// The replacement values are, in result order:
///   * base buffer: left null when the original base buffer has no uses;
///     otherwise the cast result itself when its type already equals the base
///     buffer type, or a `memref::ReinterpretCastOp` of the cast result to the
///     base buffer type (offset 0, no sizes, no strides);
///   * offset: a constant index 0 when the cast has `resetOffset` set,
///     otherwise the offset of the source's metadata;
///   * sizes and strides: forwarded unchanged from the source's metadata.
struct ExtractStridedMetadataOnFatRawBufferCastFolder final
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp metadataOp,
                                PatternRewriter &rewriter) const override {
    auto bufferCast =
        metadataOp.getSource().getDefiningOp<FatRawBufferCastOp>();
    if (!bufferCast)
      return rewriter.notifyMatchFailure(metadataOp,
                                         "not a fat raw buffer cast");

    // All new ops are tagged with the location of the cast.
    Location loc = bufferCast.getLoc();
    auto innerMetadata = memref::ExtractStridedMetadataOp::create(
        rewriter, loc, bufferCast.getSource());

    // Stays null when nobody consumes the base buffer, so no extra op is
    // materialized for it.
    Value newBaseBuffer;
    if (!metadataOp.getBaseBuffer().use_empty()) {
      auto wantedType = cast<MemRefType>(metadataOp.getBaseBuffer().getType());
      if (bufferCast.getResult().getType() == wantedType) {
        newBaseBuffer = bufferCast.getResult();
      } else {
        newBaseBuffer = memref::ReinterpretCastOp::create(
            rewriter, loc, wantedType, bufferCast.getResult(), /*offset=*/0,
            /*sizes=*/ArrayRef<int64_t>{}, /*strides=*/ArrayRef<int64_t>{});
      }
    }

    // A cast that resets the offset always reports offset 0; otherwise the
    // source's offset carries through.
    Value newOffset;
    if (bufferCast.getResetOffset())
      newOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
    else
      newOffset = innerMetadata.getOffset();

    SmallVector<Value> replacements;
    replacements.push_back(newBaseBuffer);
    replacements.push_back(newOffset);
    llvm::append_range(replacements, innerMetadata.getSizes());
    llvm::append_range(replacements, innerMetadata.getStrides());

    rewriter.replaceOp(metadataOp, replacements);
    return success();
  }
};
} // namespace

/// Adds the fat-raw-buffer-cast strided-metadata folding pattern to
/// `patterns`, registered with the given `benefit`.
void mlir::amdgpu::populateAmdgpuResolveStridedMetadataPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  MLIRContext *context = patterns.getContext();
  patterns.add<ExtractStridedMetadataOnFatRawBufferCastFolder>(context,
                                                               benefit);
}

void AmdgpuResolveStridedMetadataPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  populateAmdgpuResolveStridedMetadataPatterns(patterns);
  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();
}