1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
|
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Region.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/Passes.h"
#include "llvm/ADT/SmallVector.h"
using namespace mlir;
using namespace cir;
//===----------------------------------------------------------------------===//
// Rewrite patterns
//===----------------------------------------------------------------------===//
namespace {
/// Simplify suitable ternary operations into select operations.
///
/// For now we only simplify those ternary operations whose true and false
/// branches directly yield a value or a constant. That is, each of the true and
/// false branches must either contain a cir.yield operation as the only
/// operation in the branch, or contain a cir.const operation followed by a
/// cir.yield operation that yields the constant value.
///
/// For example, we will simplify the following ternary operation:
///
///   %0 = ...
///   %1 = cir.ternary (%condition, true {
///     %2 = cir.const ...
///     cir.yield %2
///   } false {
///     cir.yield %0
///   })
///
/// into the following sequence of operations:
///
///   %2 = cir.const ...
///   %1 = cir.select if %condition then %2 else %0
struct SimplifyTernary final : public OpRewritePattern<TernaryOp> {
  using OpRewritePattern<TernaryOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TernaryOp op,
                                PatternRewriter &rewriter) const override {
    // A cir.select produces exactly one value.
    if (op->getNumResults() != 1)
      return mlir::failure();

    if (!isSimpleTernaryBranch(op.getTrueRegion()) ||
        !isSimpleTernaryBranch(op.getFalseRegion()))
      return mlir::failure();

    // isSimpleTernaryBranch has verified that each branch is a single block
    // terminated by a cir.yield with exactly one operand, so these casts and
    // the [0] accesses below are safe.
    cir::YieldOp trueBranchYieldOp =
        mlir::cast<cir::YieldOp>(op.getTrueRegion().front().getTerminator());
    cir::YieldOp falseBranchYieldOp =
        mlir::cast<cir::YieldOp>(op.getFalseRegion().front().getTerminator());
    mlir::Value trueValue = trueBranchYieldOp.getArgs()[0];
    mlir::Value falseValue = falseBranchYieldOp.getArgs()[0];

    // Move the branch bodies (at most a cir.const each) in front of the
    // ternary, drop the yields that no longer have a parent region to
    // terminate, and replace the ternary with a select over the yielded values.
    rewriter.inlineBlockBefore(&op.getTrueRegion().front(), op);
    rewriter.inlineBlockBefore(&op.getFalseRegion().front(), op);
    rewriter.eraseOp(trueBranchYieldOp);
    rewriter.eraseOp(falseBranchYieldOp);
    rewriter.replaceOpWithNewOp<cir::SelectOp>(op, op.getCond(), trueValue,
                                               falseValue);
    return mlir::success();
  }

private:
  /// Returns true if `region` is a single block that consists of either
  ///   - only a cir.yield of one value, or
  ///   - a cir.const followed by a cir.yield of that constant.
  /// Any other shape (multiple blocks, extra operations, or a terminator that
  /// is not a one-operand cir.yield) is rejected so the pattern fails to match.
  bool isSimpleTernaryBranch(mlir::Region &region) const {
    if (!region.hasOneBlock())
      return false;

    mlir::Block &onlyBlock = region.front();
    mlir::Block::OpListType &ops = onlyBlock.getOperations();

    // The region/block may contain at most 2 operations.
    if (ops.empty() || ops.size() > 2)
      return false;

    // The block must end in a cir.yield carrying exactly one value; branches
    // ending in any other operation cannot be turned into a select operand.
    auto yieldOp = mlir::dyn_cast<cir::YieldOp>(&ops.back());
    if (!yieldOp || yieldOp.getArgs().size() != 1)
      return false;

    // The region/block only contains the cir.yield operation.
    if (ops.size() == 1)
      return true;

    // Otherwise the other operation must be a cir.const in this very block
    // whose result is what the cir.yield yields.
    auto yieldValueDefOp = mlir::dyn_cast_if_present<cir::ConstantOp>(
        yieldOp.getArgs()[0].getDefiningOp());
    return yieldValueDefOp && yieldValueDefOp->getBlock() == &onlyBlock;
  }
};
/// Fold a cir.select whose two candidate values are both boolean constants.
///
/// Only the two "opposite constants" shapes are rewritten:
///
/// 1. Choosing `true` when the condition holds and `false` otherwise yields
///    the condition itself:
///
///      %0 = cir.select if %condition then true else false
///      ->
///      (every use of %0 is replaced with %condition)
///
/// 2. Choosing `false` when the condition holds and `true` otherwise yields
///    the negated condition:
///
///      %0 = cir.select if %condition then false else true
///      ->
///      %0 = cir.unary not %condition
///
/// A select between two equal boolean constants is not matched.
struct SimplifySelect : public OpRewritePattern<SelectOp> {
  using OpRewritePattern<SelectOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SelectOp op,
                                PatternRewriter &rewriter) const final {
    // Both candidate values must be produced by cir.const operations.
    auto thenConst = mlir::dyn_cast_if_present<cir::ConstantOp>(
        op.getTrueValue().getDefiningOp());
    auto elseConst = mlir::dyn_cast_if_present<cir::ConstantOp>(
        op.getFalseValue().getDefiningOp());
    if (!thenConst || !elseConst)
      return mlir::failure();

    // ... and both constants must be booleans.
    auto thenBool = mlir::dyn_cast<cir::BoolAttr>(thenConst.getValue());
    auto elseBool = mlir::dyn_cast<cir::BoolAttr>(elseConst.getValue());
    if (!thenBool || !elseBool)
      return mlir::failure();

    const bool thenIsTrue = thenBool.getValue();
    const bool elseIsTrue = elseBool.getValue();

    // Identical constants: neither shape applies.
    if (thenIsTrue == elseIsTrue)
      return mlir::failure();

    if (thenIsTrue) {
      // cir.select if %0 then #true else #false -> %0
      rewriter.replaceAllUsesWith(op, op.getCondition());
      rewriter.eraseOp(op);
    } else {
      // cir.select if %0 then #false else #true -> cir.unary not %0
      rewriter.replaceOpWithNewOp<cir::UnaryOp>(op, cir::UnaryOpKind::Not,
                                                op.getCondition());
    }
    return mlir::success();
  }
};
/// Simplify `cir.switch` operations by folding cascading cases
/// into a single `cir.case` with the `anyof` kind.
///
/// This pattern identifies cascading cases within a `cir.switch` operation.
/// Cascading cases are consecutive `cir.case` operations of kind `equal` whose
/// body starts with a `cir.yield` operation.
///
/// The pattern merges these cascading cases into a single `cir.case` operation
/// with kind `anyof`, aggregating all the case values.
///
/// A chain of cascading cases is ended by the next case that is not itself
/// cascading:
///  - an `equal` case with a different body (e.g. containing `cir.break` or a
///    compound statement): the collected values plus that case's own value are
///    merged into that case, and the cascading cases are erased;
///  - a `default`, `anyof` or `range` case: if at least two cascading cases
///    were collected, they are merged into the last of them (which keeps its
///    `cir.yield` body) and the others are erased; with a single collected
///    case the chain is simply dropped.
/// If every case of the switch is cascading (including a switch with just one
/// such case), all of them are merged into the last one.
/// Trailing cascading cases that follow a non-cascading case are left alone.
///
/// Example:
///
/// Before:
///   cir.case equal, [#cir.int<0> : !s32i] {
///     cir.yield
///   }
///   cir.case equal, [#cir.int<1> : !s32i] {
///     cir.yield
///   }
///   cir.case equal, [#cir.int<2> : !s32i] {
///     cir.break
///   }
///
/// After applying SimplifySwitch:
///   cir.case anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i,
///                    #cir.int<2> : !s32i] {
///     cir.break
///   }
struct SimplifySwitch : public OpRewritePattern<SwitchOp> {
  using OpRewritePattern<SwitchOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(SwitchOp op,
                                PatternRewriter &rewriter) const override {
    // Set to success() as soon as any case has been rewritten.
    LogicalResult changed = mlir::failure();
    // Every case of the switch, in the order returned by collectCases.
    SmallVector<CaseOp, 8> cases;
    // The pending chain of cascading cases and the values they match.
    SmallVector<CaseOp, 4> cascadingCases;
    SmallVector<mlir::Attribute, 4> cascadingCaseValues;
    op.collectCases(cases);
    if (cases.empty())
      return mlir::failure();
    // Erase every case still in the pending chain (their values have already
    // been folded into another case) and reset the chain.
    auto flushMergedOps = [&]() {
      for (CaseOp &c : cascadingCases)
        rewriter.eraseOp(c);
      cascadingCases.clear();
      cascadingCaseValues.clear();
    };
    // Rewrite `target` in place into an `anyof` case over all collected values.
    auto mergeCascadingInto = [&](CaseOp &target) {
      rewriter.modifyOpInPlace(target, [&]() {
        target.setValueAttr(rewriter.getArrayAttr(cascadingCaseValues));
        target.setKind(CaseOpKind::Anyof);
      });
      changed = mlir::success();
    };
    // Only cases visited in earlier iterations are ever erased, so iterating
    // over the `cases` snapshot never touches an erased op again.
    for (CaseOp c : cases) {
      cir::CaseOpKind kind = c.getKind();
      if (kind == cir::CaseOpKind::Equal &&
          isa<YieldOp>(c.getCaseRegion().front().front())) {
        // If the case contains only a YieldOp, collect it for cascading merge
        cascadingCases.push_back(c);
        cascadingCaseValues.push_back(c.getValue()[0]);
      } else if (kind == cir::CaseOpKind::Equal && !cascadingCases.empty()) {
        // merge previously collected cascading cases
        cascadingCaseValues.push_back(c.getValue()[0]);
        mergeCascadingInto(c);
        flushMergedOps();
      } else if (kind != cir::CaseOpKind::Equal && cascadingCases.size() > 1) {
        // If a Default, Anyof or Range case is found and there are previous
        // cascading cases, merge all of them into the last cascading case.
        // We don't currently fold case range statements with other case
        // statements.
        assert(!cir::MissingFeatures::foldRangeCase());
        CaseOp lastCascadingCase = cascadingCases.back();
        mergeCascadingInto(lastCascadingCase);
        // The merge target survives; only the cases before it are erased.
        cascadingCases.pop_back();
        flushMergedOps();
      } else {
        // Nothing to merge: drop the pending chain without erasing anything.
        cascadingCases.clear();
        cascadingCaseValues.clear();
      }
    }
    // Edge case: all cases are simple cascading cases
    if (cascadingCases.size() == cases.size()) {
      CaseOp lastCascadingCase = cascadingCases.back();
      mergeCascadingInto(lastCascadingCase);
      cascadingCases.pop_back();
      flushMergedOps();
    }
    return changed;
  }
};
/// Fold a cir.vec.splat whose scalar operand is an integer or floating-point
/// cir.const into a single cir.const holding a constant vector in which every
/// lane is that scalar attribute. Splats of any other value are left alone.
struct SimplifyVecSplat : public OpRewritePattern<VecSplatOp> {
  using OpRewritePattern<VecSplatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(VecSplatOp op,
                                PatternRewriter &rewriter) const override {
    auto scalarConst = mlir::dyn_cast_if_present<cir::ConstantOp>(
        op.getValue().getDefiningOp());
    if (!scalarConst)
      return mlir::failure();

    // Only integer and floating-point scalars are folded.
    auto scalarAttr = scalarConst.getValue();
    if (!mlir::isa_and_nonnull<cir::IntAttr, cir::FPAttr>(scalarAttr))
      return mlir::failure();

    // One copy of the scalar attribute per vector lane.
    cir::VectorType vecTy = op.getResult().getType();
    SmallVector<mlir::Attribute, 16> lanes(vecTy.getSize(), scalarAttr);
    mlir::ArrayAttr laneArray = mlir::ArrayAttr::get(getContext(), lanes);

    rewriter.replaceOpWithNewOp<cir::ConstantOp>(
        op, cir::ConstVectorAttr::get(vecTy, laneArray));
    return mlir::success();
  }
};
//===----------------------------------------------------------------------===//
// CIRSimplifyPass
//===----------------------------------------------------------------------===//
/// The CIR simplification pass. Its options and boilerplate come from the
/// generated CIRSimplifyBase; only the pass body is provided in this file.
struct CIRSimplifyPass : public CIRSimplifyBase<CIRSimplifyPass> {
  // Reuse every constructor of the generated base class.
  using CIRSimplifyBase::CIRSimplifyBase;

  void runOnOperation() override;
};
/// Add the ternary, select, switch and vector-splat simplification patterns to
/// `patterns`, all bound to the pattern set's own MLIR context.
void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
  mlir::MLIRContext *ctx = patterns.getContext();
  patterns.add<SimplifyTernary, SimplifySelect, SimplifySwitch,
               SimplifyVecSplat>(ctx);
}
void CIRSimplifyPass::runOnOperation() {
// Collect rewrite patterns.
RewritePatternSet patterns(&getContext());
populateMergeCleanupPatterns(patterns);
// Collect operations to apply patterns.
llvm::SmallVector<Operation *, 16> ops;
getOperation()->walk([&](Operation *op) {
if (isa<TernaryOp, SelectOp, SwitchOp, VecSplatOp>(op))
ops.push_back(op);
});
// Apply patterns.
if (applyOpPatternsGreedily(ops, std::move(patterns)).failed())
signalPassFailure();
}
} // namespace
/// Factory for the CIR simplification pass; the caller takes ownership.
std::unique_ptr<Pass> mlir::createCIRSimplifyPass() {
  std::unique_ptr<Pass> pass = std::make_unique<CIRSimplifyPass>();
  return pass;
}
|