//===- ComposeSubView.cpp - Combining composed subview ops ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains patterns for combining composed subview ops (i.e. subview
// of a subview becomes a single subview).
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/MemRef/Transforms/ComposeSubView.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
namespace {
// Replaces a subview of a subview with a single subview (both static and
// dynamic offsets are supported).
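//
// Illustrative example (shapes assumed for this comment, not taken from the
// original file):
//
//   %0 = memref.subview %src[1, 2] [32, 32] [2, 2]
//       : memref<128x128xf32>
//         to memref<32x32xf32, strided<[256, 2], offset: 130>>
//   %1 = memref.subview %0[3, 4] [16, 16] [2, 2]
//       : memref<32x32xf32, strided<[256, 2], offset: 130>>
//         to memref<16x16xf32, strided<[512, 4], offset: 906>>
//
// composes to the equivalent of:
//
//   %1 = memref.subview %src[7, 10] [16, 16] [4, 4]
//       : memref<128x128xf32>
//         to memref<16x16xf32, strided<[512, 4], offset: 906>>
//
// (offsets: 1 + 3 * 2 = 7 and 2 + 4 * 2 = 10; strides: 2 * 2 = 4.)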
struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp op,
                                PatternRewriter &rewriter) const override {
    // 'op' is the 'SubViewOp' we're rewriting. 'sourceOp' is the op that
    // produces the input of the op we're rewriting (for 'SubViewOp' the input
    // is called the "source" value). We can only combine them if both 'op' and
    // 'sourceOp' are 'SubViewOp'.
    auto sourceOp = op.getSource().getDefiningOp<memref::SubViewOp>();
    if (!sourceOp)
      return failure();

    // A 'SubViewOp' can be "rank-reducing" by eliminating dimensions of the
    // output memref that are statically known to be equal to 1. We do not
    // allow 'sourceOp' to be a rank-reducing subview because then our two
    // 'SubViewOp's would have different numbers of offset/size/stride
    // parameters (just difficult to deal with, not impossible if we end up
    // needing it).
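    // (Illustrative example: a subview of 'memref<8x16xf32>' with static
    // sizes [1, 16] may drop the unit dimension and return 'memref<16xf32>';
    // such a 'sourceOp' is skipped here.)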
    if (sourceOp.getSourceType().getRank() != sourceOp.getType().getRank()) {
      return failure();
    }

    // Offsets, sizes and strides OpFoldResults for the combined 'SubViewOp'.
    SmallVector<OpFoldResult> offsets, sizes, strides,
        opStrides = op.getMixedStrides(),
        sourceStrides = sourceOp.getMixedStrides();

    // The output stride in each dimension is the product of the strides of
    // 'sourceOp' and 'op' in that dimension.
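    // (Illustrative: a 'sourceOp' stride of 2 composed with an 'op' stride of
    // 3 steps through the underlying memref by 2 * 3 = 6 elements.)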
    int64_t sourceStrideValue;
    for (auto &&[opStride, sourceStride] :
         llvm::zip(opStrides, sourceStrides)) {
      Attribute opStrideAttr = dyn_cast_if_present<Attribute>(opStride);
      Attribute sourceStrideAttr = dyn_cast_if_present<Attribute>(sourceStride);
      if (!opStrideAttr || !sourceStrideAttr)
        return failure();
      sourceStrideValue = cast<IntegerAttr>(sourceStrideAttr).getInt();
      strides.push_back(rewriter.getI64IntegerAttr(
          cast<IntegerAttr>(opStrideAttr).getInt() * sourceStrideValue));
    }

    // The rules for calculating the new offsets and sizes are:
    // * Multiple subview offsets for a given dimension compose additively.
    //   ("Offset by m and Stride by k" followed by "Offset by n" == "Offset
    //   by m + n * k")
    // * Multiple sizes for a given dimension compose by taking the size of
    //   the final subview and ignoring the rest. ("Take m values" followed by
    //   "Take n values" == "Take n values") This size must also be the
    //   smallest one by definition (a subview needs to be the same size as or
    //   smaller than its source along each dimension; presumably subviews
    //   that are larger than their sources are disallowed by validation).
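    // Worked example (illustrative): if 'sourceOp' offsets a dimension by 5
    // with stride 3 and 'op' then offsets it by 2, the combined offset is
    // 5 + 2 * 3 = 11.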
for (auto &&[opOffset, sourceOffset, sourceStride, opSize] :
llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(),
sourceOp.getMixedStrides(), op.getMixedSizes())) {
sizes.push_back(opSize);
Attribute opOffsetAttr = llvm::dyn_cast_if_present<Attribute>(opOffset),
sourceOffsetAttr =
llvm::dyn_cast_if_present<Attribute>(sourceOffset),
sourceStrideAttr =
llvm::dyn_cast_if_present<Attribute>(sourceStride);
if (opOffsetAttr && sourceOffsetAttr) {
// If both offsets are static we can simply calculate the combined
// offset statically.
offsets.push_back(rewriter.getI64IntegerAttr(
cast<IntegerAttr>(opOffsetAttr).getInt() *
cast<IntegerAttr>(sourceStrideAttr).getInt() +
cast<IntegerAttr>(sourceOffsetAttr).getInt()));
      } else {
        // At least one offset is dynamic: compute the combined offset at
        // runtime with an 'affine.apply' of
        // sourceOffset + opOffset * sourceStride.
        AffineExpr expr;
        SmallVector<Value> affineApplyOperands;

        // Make 'expr' add 'sourceOffset'.
        if (auto attr = llvm::dyn_cast_if_present<Attribute>(sourceOffset)) {
          expr =
              rewriter.getAffineConstantExpr(cast<IntegerAttr>(attr).getInt());
        } else {
          expr = rewriter.getAffineSymbolExpr(affineApplyOperands.size());
          affineApplyOperands.push_back(cast<Value>(sourceOffset));
        }

        // Multiply 'opOffset' by 'sourceStride' and make the 'expr' add the
        // result.
        if (auto attr = llvm::dyn_cast_if_present<Attribute>(opOffset)) {
          expr = expr + cast<IntegerAttr>(attr).getInt() *
                            cast<IntegerAttr>(sourceStrideAttr).getInt();
        } else {
          expr =
              expr + rewriter.getAffineSymbolExpr(affineApplyOperands.size()) *
                         cast<IntegerAttr>(sourceStrideAttr).getInt();
          affineApplyOperands.push_back(cast<Value>(opOffset));
        }
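
        // For a fully dynamic case the IR assembled below looks like, e.g.:
        //   %off = affine.apply
        //     affine_map<()[s0, s1] -> (s0 + s1 * 3)>()[%srcOff, %opOff]
        // where 3 stands in for the static 'sourceOp' stride (illustrative;
        // the names are not from the original file).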
        AffineMap map = AffineMap::get(0, affineApplyOperands.size(), expr);
        Value result = affine::AffineApplyOp::create(rewriter, op.getLoc(),
                                                     map, affineApplyOperands);
        offsets.push_back(result);
      }
    }

    // This replaces 'op' but leaves 'sourceOp' alone; if it no longer has any
    // uses it can be removed by a (separate) dead code elimination pass.
    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        op, op.getType(), sourceOp.getSource(), offsets, sizes, strides);
    return success();
  }
};

} // namespace

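// Typical usage (illustrative sketch; 'op' and 'ctx' are placeholders, and
// the greedy driver call is an assumption, not part of this file):
//
//   RewritePatternSet patterns(ctx);
//   memref::populateComposeSubViewPatterns(patterns, ctx);
//   (void)applyPatternsGreedily(op, std::move(patterns));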
void mlir::memref::populateComposeSubViewPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add<ComposeSubViewOpPattern>(context);
}