//===- SubsetInsertionOpInterfaceImpl.cpp - Tensor subsets ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
using namespace mlir;
using namespace mlir::linalg;
namespace {
/// External model that implements `SubsetOpInterface` for `linalg::CopyOp`.
struct LinalgCopyOpSubsetOpInterface
    : public SubsetOpInterface::ExternalModel<LinalgCopyOpSubsetOpInterface,
                                              linalg::CopyOp> {
  /// When `candidate` is itself a `linalg::CopyOp`, returns the result of
  /// `equivalenceFn` applied to the first output of `op` and the first output
  /// of `candidate`. Returns false for every other kind of candidate.
  bool operatesOnEquivalentSubset(
      Operation *op, SubsetOpInterface candidate,
      function_ref<bool(Value, Value)> equivalenceFn) const {
    auto candidateCopy = dyn_cast<linalg::CopyOp>(candidate.getOperation());
    // No analysis is available for other subset ops; "false" is the
    // conservative answer for this interface method.
    if (!candidateCopy)
      return false;

    // linalg.copy operates on the entire destination tensor, hence the two
    // destinations are what gets compared.
    Value ownDest = cast<linalg::CopyOp>(op).getOutputs()[0];
    Value candidateDest = candidateCopy.getOutputs()[0];
    return equivalenceFn(ownDest, candidateDest);
  }

  /// Always returns false: no disjointness analysis is implemented, and
  /// "false" is the conservative answer for this interface method.
  bool operatesOnDisjointSubset(
      Operation *op, SubsetOpInterface candidate,
      function_ref<bool(Value, Value)> equivalenceFn) const {
    return false;
  }
};
/// External model that implements `SubsetInsertionOpInterface` for
/// `linalg::CopyOp`.
struct LinalgCopyOpInterface
    : public SubsetInsertionOpInterface::ExternalModel<LinalgCopyOpInterface,
                                                       linalg::CopyOp> {
  /// Returns the single element of the copy's mutable inputs range.
  OpOperand &getSourceOperand(Operation *op) const {
    auto copy = cast<CopyOp>(op);
    return llvm::getSingleElement(copy.getInputsMutable());
  }

  /// Returns the result of `equivalenceFn` applied to `candidate` and the
  /// single output of the copy.
  bool
  isEquivalentSubset(Operation *op, Value candidate,
                     function_ref<bool(Value, Value)> equivalenceFn) const {
    auto copy = cast<CopyOp>(op);
    Value dest = llvm::getSingleElement(copy.getOutputs());
    return equivalenceFn(candidate, dest);
  }

  /// Returns the single output of the copy as the extracted subset. Neither
  /// `builder` nor `loc` is used; no new operation is created here.
  Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
                              Location loc) const {
    auto copy = cast<CopyOp>(op);
    Value dest = llvm::getSingleElement(copy.getOutputs());
    return dest;
  }

  /// Returns a one-element vector holding the single output of the copy,
  /// i.e. the value that `buildSubsetExtraction` returns.
  SmallVector<Value>
  getValuesNeededToBuildSubsetExtraction(Operation *op) const {
    auto copy = cast<CopyOp>(op);
    Value dest = llvm::getSingleElement(copy.getOutputs());
    return {dest};
  }
};
} // namespace
void mlir::linalg::registerSubsetOpInterfaceExternalModels(
DialectRegistry ®istry) {
registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
linalg::CopyOp::attachInterface<LinalgCopyOpSubsetOpInterface>(*ctx);
linalg::CopyOp::attachInterface<LinalgCopyOpInterface>(*ctx);
});
}
|