blob: 496238fcaa3ff164aed4df9b4ac589a3340565ed (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
|
//===- DestinationStyleOpInterface.cpp -- Destination style ops -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
using namespace mlir;
namespace mlir {
#include "mlir/Interfaces/DestinationStyleOpInterface.cpp.inc"
} // namespace mlir
namespace {
/// Returns the number of results of `op` whose type is a `TensorType`.
/// Results of any other type (memrefs, scalars, ...) are not counted.
size_t getNumTensorResults(Operation *op) {
  size_t count = 0;
  for (Type resultType : op->getResultTypes())
    count += isa<TensorType>(resultType) ? 1 : 0;
  return count;
}
} // namespace
/// Verifies the invariants checked for every op implementing
/// DestinationStyleOpInterface:
///   1. every DPS init operand is either a tensor or a memref;
///   2. the number of tensor-typed results equals the number of tensor inits;
///   3. each tensor init has the same type as the result it is tied to
///      (as reported by `getTiedOpResult`).
/// On the first violation an op error is emitted and failure is returned;
/// otherwise success is returned.
LogicalResult detail::verifyDestinationStyleOpInterface(Operation *op) {
  auto dpsOp = cast<DestinationStyleOpInterface>(op);

  // Classify the inits: reject anything that is neither a tensor nor a
  // memref, and remember the tensor ones for the checks below.
  SmallVector<OpOperand *> tensorInits;
  for (OpOperand &init : dpsOp.getDpsInitsMutable()) {
    Type initType = init.get().getType();
    if (!isa<TensorType>(initType) && !isa<BaseMemRefType>(initType))
      return op->emitOpError("expected that operand #")
             << init.getOperandNumber() << " is a tensor or a memref";
    if (isa<TensorType>(initType))
      tensorInits.push_back(&init);
  }

  // Every tensor init must be matched by exactly one tensor result. The count
  // is computed once and reused in the diagnostic.
  const size_t numTensorResults = getNumTensorResults(op);
  const size_t numTensorInits = tensorInits.size();
  if (numTensorResults != numTensorInits) {
    return op->emitOpError("expected the number of tensor results (")
           << numTensorResults
           << ") to be equal to the number of output tensors ("
           << numTensorInits << ")";
  }

  // A tensor init and its tied result must agree on type.
  // NOTE(review): only the totals are compared above; if an op mixes memref
  // and tensor inits, confirm that getTiedOpResult's index mapping still lands
  // on a valid result for each tensor init.
  for (OpOperand *init : tensorInits) {
    Type initType = init->get().getType();
    Type resultType = dpsOp.getTiedOpResult(init).getType();
    if (resultType == initType)
      continue;
    return op->emitOpError("expected type of operand #")
           << init->getOperandNumber() << " (" << initType << ")"
           << " to match type of corresponding result (" << resultType
           << ")";
  }

  return success();
}
|