blob: 00dac19e37171cebfa4cfd38cb9f4400ea0683bc (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
|
//===- BubbleDownMemorySpaceCasts.cpp - Bubble down casts transform -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/BubbleDownMemorySpaceCasts.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/MemOpInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/Support/Debug.h"
using namespace mlir;
namespace mlir {
#define GEN_PASS_DEF_BUBBLEDOWNMEMORYSPACECASTS
#include "mlir/Transforms/Passes.h.inc"
} // namespace mlir
namespace {
//===----------------------------------------------------------------------===//
// BubbleDownCastsPattern pattern
//===----------------------------------------------------------------------===//
/// Rewrite pattern that asks a `MemorySpaceCastConsumerOpInterface` operation
/// to absorb ("bubble down") the memory-space casts feeding it.
///
/// The interface hook `bubbleDownCasts` reports one of three outcomes:
///   * failure            -> the pattern does not apply to `op`;
///   * success(nullopt)   -> no replacement values were produced; the op is
///                           reported to the rewriter as modified in place;
///   * success(values)    -> `op` is replaced by the returned values.
struct BubbleDownCastsPattern
    : public OpInterfaceRewritePattern<MemorySpaceCastConsumerOpInterface> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(MemorySpaceCastConsumerOpInterface op,
                                PatternRewriter &rewriter) const override {
    FailureOr<std::optional<SmallVector<Value>>> outcome =
        op.bubbleDownCasts(rewriter);
    if (failed(outcome))
      return failure();

    std::optional<SmallVector<Value>> &replacements = *outcome;
    if (replacements) {
      // The interface produced new values: swap them in for `op`'s results.
      rewriter.replaceOp(op, *replacements);
    } else {
      // No replacement values. The empty callback changes nothing itself; it
      // only routes an in-place modification notification for `op` through
      // the rewriter.
      // NOTE(review): assumes `bubbleDownCasts` already mutated `op` in place
      // when it returns nullopt — confirm against the interface contract.
      rewriter.modifyOpInPlace(op, [] {});
    }
    return success();
  }
};
//===----------------------------------------------------------------------===//
// BubbleDownMemorySpaceCasts pass
//===----------------------------------------------------------------------===//
/// Pass that greedily applies the bubble-down memory-space-cast patterns to
/// the operation it runs on. The pass is marked as failed when the greedy
/// driver reports failure.
struct BubbleDownMemorySpaceCasts
    : public impl::BubbleDownMemorySpaceCastsBase<BubbleDownMemorySpaceCasts> {
  using impl::BubbleDownMemorySpaceCastsBase<
      BubbleDownMemorySpaceCasts>::BubbleDownMemorySpaceCastsBase;

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet patternSet(context);
    populateBubbleDownMemorySpaceCastPatterns(patternSet, PatternBenefit(1));

    LogicalResult driverResult =
        applyPatternsGreedily(getOperation(), std::move(patternSet));
    if (failed(driverResult))
      signalPassFailure();
  }
};
} // namespace
/// Registers `BubbleDownCastsPattern` in `patterns`, using the set's own
/// context and the caller-provided `benefit`.
void mlir::populateBubbleDownMemorySpaceCastPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<BubbleDownCastsPattern>(ctx, benefit);
}
|