aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Transforms/BubbleDownMemorySpaceCasts.cpp
blob: 00dac19e37171cebfa4cfd38cb9f4400ea0683bc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
//===- BubbleDownMemorySpaceCasts.cpp - Bubble down casts transform -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/BubbleDownMemorySpaceCasts.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/MemOpInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/Support/Debug.h"

using namespace mlir;

namespace mlir {
#define GEN_PASS_DEF_BUBBLEDOWNMEMORYSPACECASTS
#include "mlir/Transforms/Passes.h.inc"
} // namespace mlir

namespace {
//===----------------------------------------------------------------------===//
// BubbleDownCastsPattern pattern
//===----------------------------------------------------------------------===//
/// Rewrite pattern that invokes `bubbleDownCasts` on ops implementing
/// `MemorySpaceCastConsumerOpInterface` and applies the outcome it reports.
struct BubbleDownCastsPattern
    : public OpInterfaceRewritePattern<MemorySpaceCastConsumerOpInterface> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;

  /// Returns failure when `bubbleDownCasts` fails. Otherwise returns success
  /// after either replacing `op` with the returned values or, when no values
  /// are returned, notifying the rewriter of an in-place modification.
  LogicalResult matchAndRewrite(MemorySpaceCastConsumerOpInterface op,
                                PatternRewriter &rewriter) const override {
    FailureOr<std::optional<SmallVector<Value>>> bubbled =
        op.bubbleDownCasts(rewriter);
    if (failed(bubbled))
      return failure();

    std::optional<SmallVector<Value>> &replacements = *bubbled;
    if (replacements.has_value()) {
      // The interface produced replacement values: swap them in for `op`.
      rewriter.replaceOp(op, *replacements);
      return success();
    }

    // No replacement values: presumably the interface already updated `op`
    // itself, so the empty callback only signals the change to the rewriter.
    rewriter.modifyOpInPlace(op, []() {});
    return success();
  }
};

//===----------------------------------------------------------------------===//
// BubbleDownMemorySpaceCasts pass
//===----------------------------------------------------------------------===//

/// Pass that greedily applies the bubble-down memory-space cast patterns to
/// the operation it is run on.
struct BubbleDownMemorySpaceCasts
    : public impl::BubbleDownMemorySpaceCastsBase<BubbleDownMemorySpaceCasts> {
  using impl::BubbleDownMemorySpaceCastsBase<
      BubbleDownMemorySpaceCasts>::BubbleDownMemorySpaceCastsBase;

  /// Collects the patterns, runs the greedy driver over the current
  /// operation, and signals pass failure if the driver reports failure.
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet castPatterns(context);
    populateBubbleDownMemorySpaceCastPatterns(castPatterns, PatternBenefit(1));

    LogicalResult status =
        applyPatternsGreedily(getOperation(), std::move(castPatterns));
    if (failed(status))
      signalPassFailure();
  }
};
} // namespace

/// Adds `BubbleDownCastsPattern` to `patterns`, constructed with the pattern
/// set's context and the given `benefit`.
void mlir::populateBubbleDownMemorySpaceCastPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  MLIRContext *context = patterns.getContext();
  patterns.add<BubbleDownCastsPattern>(context, benefit);
}