aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Linalg/IR
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Linalg/IR')
-rw-r--r--mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp23
-rw-r--r--mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp48
2 files changed, 46 insertions, 25 deletions
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index dcc1ef9..b4b1347 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -1057,12 +1057,15 @@ LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
// FillOpInterface implementation
//===----------------------------------------------------------------------===//
+namespace {
enum class MatchFillResult {
Success = 0,
NotLinalgOp,
WrongNumOperands,
- NotScalarInput
+ NotScalarInput,
+ TypeMismatch
};
+} // namespace
static MatchFillResult isFillInterfaceImpl(Operation *op) {
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
@@ -1075,17 +1078,33 @@ static MatchFillResult isFillInterfaceImpl(Operation *op) {
if (!linalgOp.isScalar(value))
return MatchFillResult::NotScalarInput;
+ // Check that the scalar input type matches the output element type.
+ OpOperand *output = linalgOp.getDpsInitOperand(0);
+ Type scalarType = value->get().getType();
+ Type outputElementType = getElementTypeOrSelf(output->get().getType());
+ if (scalarType != outputElementType)
+ return MatchFillResult::TypeMismatch;
+
return MatchFillResult::Success;
}
LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) {
- auto res = isFillInterfaceImpl(op);
+ MatchFillResult res = isFillInterfaceImpl(op);
if (res == MatchFillResult::NotLinalgOp)
return op->emitError("expected a LinalgOp");
if (res == MatchFillResult::WrongNumOperands)
return op->emitError("expected op with 1 input and 1 output");
if (res == MatchFillResult::NotScalarInput)
return op->emitError("expected op with scalar input");
+ if (res == MatchFillResult::TypeMismatch) {
+ auto linalgOp = cast<linalg::LinalgOp>(op);
+ Type scalarType = linalgOp.getDpsInputOperand(0)->get().getType();
+ Type outputElementType =
+ getElementTypeOrSelf(linalgOp.getDpsInitOperand(0)->get().getType());
+ return op->emitOpError("expected fill value type (")
+ << scalarType << ") to match output element type ("
+ << outputElementType << ")";
+ }
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index cbc565b..33ec79b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1338,8 +1338,6 @@ Speculation::Speculatability GenericOp::getSpeculatability() {
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
-LogicalResult GenericOp::verify() { return success(); }
-
namespace {
/// Remove linalg operations that are just copying the values from inputs to
@@ -1474,6 +1472,8 @@ void MapOp::getAsmBlockArgumentNames(Region &region,
OpAsmSetValueNameFn setNameFn) {
for (Value v : getRegionInputArgs())
setNameFn(v, "in");
+ for (Value v : getRegionOutputArgs())
+ setNameFn(v, "init");
}
void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
@@ -1495,14 +1495,14 @@ void MapOp::build(
if (bodyBuild)
buildGenericRegion(builder, result.location, *result.regions.front(),
- inputs, /*outputs=*/{}, bodyBuild);
+ inputs, /*outputs=*/{init}, bodyBuild);
}
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
const OperationName &payloadOpName,
const NamedAttrList &payloadOpAttrs,
ArrayRef<Value> operands,
- bool initFirst = false) {
+ bool initFirst = false, bool mapInit = true) {
OpBuilder b(parser.getContext());
Region *body = result.addRegion();
Block &block = body->emplaceBlock();
@@ -1516,12 +1516,13 @@ static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
// If initFirst flag is enabled, we consider init as the first position of
// payload operands.
if (initFirst) {
- payloadOpOperands.push_back(block.getArguments().back());
+ if (mapInit)
+ payloadOpOperands.push_back(block.getArguments().back());
for (const auto &arg : block.getArguments().drop_back())
payloadOpOperands.push_back(arg);
} else {
payloadOpOperands = {block.getArguments().begin(),
- block.getArguments().end()};
+ block.getArguments().end() - int(!mapInit)};
}
Operation *payloadOp = b.create(
@@ -1553,8 +1554,8 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
if (payloadOpName.has_value()) {
if (!result.operands.empty())
addBodyWithPayloadOp(parser, result, payloadOpName.value(),
- payloadOpAttrs,
- ArrayRef(result.operands).drop_back());
+ payloadOpAttrs, ArrayRef(result.operands), false,
+ false);
else
result.addRegion();
} else {
@@ -1570,7 +1571,11 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
return success();
}
-static bool canUseShortForm(Block *body, bool initFirst = false) {
+static bool canUseShortForm(Block *body, bool initFirst = false,
+ bool mapInit = true) {
+ // `intFirst == true` implies that we want to map init arg
+ if (initFirst && !mapInit)
+ return false;
// Check if the body can be printed in short form. The following 4 conditions
// must be satisfied:
@@ -1582,7 +1587,7 @@ static bool canUseShortForm(Block *body, bool initFirst = false) {
// 2) The payload op must have the same number of operands as the number of
// block arguments.
if (payload.getNumOperands() == 0 ||
- payload.getNumOperands() != body->getNumArguments())
+ payload.getNumOperands() != body->getNumArguments() - int(!mapInit))
return false;
// 3) If `initFirst` is true (e.g., for reduction ops), the init block
@@ -1600,7 +1605,8 @@ static bool canUseShortForm(Block *body, bool initFirst = false) {
}
} else {
for (const auto &[operand, bbArg] :
- llvm::zip(payload.getOperands(), body->getArguments())) {
+ llvm::zip(payload.getOperands(),
+ body->getArguments().drop_back(int(!mapInit)))) {
if (bbArg != operand)
return false;
}
@@ -1632,7 +1638,8 @@ static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
void MapOp::print(OpAsmPrinter &p) {
Block *mapper = getBody();
- bool useShortForm = canUseShortForm(mapper);
+ bool useShortForm =
+ canUseShortForm(mapper, /*initFirst=*/false, /*mapInit*/ false);
if (useShortForm) {
printShortForm(p, &mapper->getOperations().front());
}
@@ -1658,11 +1665,13 @@ LogicalResult MapOp::verify() {
auto *bodyBlock = getBody();
auto blockArgs = bodyBlock->getArguments();
- // Checks if the number of `inputs` match the arity of the `mapper` region.
- if (getInputs().size() != blockArgs.size())
+ // Checks if the number of `inputs` + `init` match the arity of the `mapper`
+ // region.
+ if (getInputs().size() + 1 != blockArgs.size())
return emitOpError() << "expects number of operands to match the arity of "
"mapper, but got: "
- << getInputs().size() << " and " << blockArgs.size();
+ << getInputs().size() + 1 << " and "
+ << blockArgs.size();
// The parameters of mapper should all match the element type of inputs.
for (const auto &[bbArgType, inputArg] :
@@ -2080,7 +2089,7 @@ LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
return failure();
// Single dimension transpose.
- if (getPermutation().size() == 0) {
+ if (getPermutation().empty()) {
result.push_back(getInput());
return success();
}
@@ -4874,13 +4883,6 @@ void ElementwiseOp::print(OpAsmPrinter &p) {
elidedAttrs);
}
-LogicalResult ElementwiseOp::verify() {
- // All necessary checks are done either by
- // - EnumAttr (e.g. unknown operation kind)
- // - verifyStructuredOpInterface (incorrect map, sizes).
- return success();
-}
-
/// Implements the block region builder for the ElementwiseOp. This is called by
/// 'fillStructuredOpRegion'.
void ElementwiseOp::regionBuilder(