diff options
Diffstat (limited to 'mlir/lib/Dialect/Linalg/IR')
| -rw-r--r-- | mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 23 | ||||
| -rw-r--r-- | mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 48 |
2 files changed, 46 insertions, 25 deletions
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp index dcc1ef9..b4b1347 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -1057,12 +1057,15 @@ LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) { // FillOpInterface implementation //===----------------------------------------------------------------------===// +namespace { enum class MatchFillResult { Success = 0, NotLinalgOp, WrongNumOperands, - NotScalarInput + NotScalarInput, + TypeMismatch }; +} // namespace static MatchFillResult isFillInterfaceImpl(Operation *op) { auto linalgOp = dyn_cast<linalg::LinalgOp>(op); @@ -1075,17 +1078,33 @@ static MatchFillResult isFillInterfaceImpl(Operation *op) { if (!linalgOp.isScalar(value)) return MatchFillResult::NotScalarInput; + // Check that the scalar input type matches the output element type. + OpOperand *output = linalgOp.getDpsInitOperand(0); + Type scalarType = value->get().getType(); + Type outputElementType = getElementTypeOrSelf(output->get().getType()); + if (scalarType != outputElementType) + return MatchFillResult::TypeMismatch; + return MatchFillResult::Success; } LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) { - auto res = isFillInterfaceImpl(op); + MatchFillResult res = isFillInterfaceImpl(op); if (res == MatchFillResult::NotLinalgOp) return op->emitError("expected a LinalgOp"); if (res == MatchFillResult::WrongNumOperands) return op->emitError("expected op with 1 input and 1 output"); if (res == MatchFillResult::NotScalarInput) return op->emitError("expected op with scalar input"); + if (res == MatchFillResult::TypeMismatch) { + auto linalgOp = cast<linalg::LinalgOp>(op); + Type scalarType = linalgOp.getDpsInputOperand(0)->get().getType(); + Type outputElementType = + getElementTypeOrSelf(linalgOp.getDpsInitOperand(0)->get().getType()); + return op->emitOpError("expected fill value type (") + << scalarType << ") to match output element type (" + << outputElementType << ")"; + } return success(); } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index cbc565b..33ec79b 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1338,8 +1338,6 @@ Speculation::Speculatability GenericOp::getSpeculatability() { return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation())); } -LogicalResult GenericOp::verify() { return success(); } - namespace { /// Remove linalg operations that are just copying the values from inputs to @@ -1474,6 +1472,8 @@ void MapOp::getAsmBlockArgumentNames(Region ®ion, OpAsmSetValueNameFn setNameFn) { for (Value v : getRegionInputArgs()) setNameFn(v, "in"); + for (Value v : getRegionOutputArgs()) + setNameFn(v, "init"); } void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) { @@ -1495,14 +1495,14 @@ void MapOp::build( if (bodyBuild) buildGenericRegion(builder, result.location, *result.regions.front(), - inputs, /*outputs=*/{}, bodyBuild); + inputs, /*outputs=*/{init}, bodyBuild); } static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef<Value> operands, - bool initFirst = false) { + bool initFirst = false, bool mapInit = true) { OpBuilder b(parser.getContext()); Region *body = result.addRegion(); Block &block = body->emplaceBlock(); @@ -1516,12 +1516,13 @@ static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, // If initFirst flag is enabled, we consider init as the first position of // payload operands. if (initFirst) { - payloadOpOperands.push_back(block.getArguments().back()); + if (mapInit) + payloadOpOperands.push_back(block.getArguments().back()); for (const auto &arg : block.getArguments().drop_back()) payloadOpOperands.push_back(arg); } else { payloadOpOperands = {block.getArguments().begin(), - block.getArguments().end()}; + block.getArguments().end() - int(!mapInit)}; } Operation *payloadOp = b.create( @@ -1553,8 +1554,8 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) { if (payloadOpName.has_value()) { if (!result.operands.empty()) addBodyWithPayloadOp(parser, result, payloadOpName.value(), - payloadOpAttrs, - ArrayRef(result.operands).drop_back()); + payloadOpAttrs, ArrayRef(result.operands), false, + false); else result.addRegion(); } else { @@ -1570,7 +1571,11 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } -static bool canUseShortForm(Block *body, bool initFirst = false) { +static bool canUseShortForm(Block *body, bool initFirst = false, + bool mapInit = true) { + // `intFirst == true` implies that we want to map init arg + if (initFirst && !mapInit) + return false; // Check if the body can be printed in short form. The following 4 conditions // must be satisfied: @@ -1582,7 +1587,7 @@ static bool canUseShortForm(Block *body, bool initFirst = false) { // 2) The payload op must have the same number of operands as the number of // block arguments. if (payload.getNumOperands() == 0 || - payload.getNumOperands() != body->getNumArguments()) + payload.getNumOperands() != body->getNumArguments() - int(!mapInit)) return false; // 3) If `initFirst` is true (e.g., for reduction ops), the init block @@ -1600,7 +1605,8 @@ static bool canUseShortForm(Block *body, bool initFirst = false) { } } else { for (const auto &[operand, bbArg] : - llvm::zip(payload.getOperands(), body->getArguments())) { + llvm::zip(payload.getOperands(), + body->getArguments().drop_back(int(!mapInit)))) { if (bbArg != operand) return false; } @@ -1632,7 +1638,8 @@ static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) { void MapOp::print(OpAsmPrinter &p) { Block *mapper = getBody(); - bool useShortForm = canUseShortForm(mapper); + bool useShortForm = + canUseShortForm(mapper, /*initFirst=*/false, /*mapInit*/ false); if (useShortForm) { printShortForm(p, &mapper->getOperations().front()); } @@ -1658,11 +1665,13 @@ LogicalResult MapOp::verify() { auto *bodyBlock = getBody(); auto blockArgs = bodyBlock->getArguments(); - // Checks if the number of `inputs` match the arity of the `mapper` region. - if (getInputs().size() != blockArgs.size()) + // Checks if the number of `inputs` + `init` match the arity of the `mapper` + // region. + if (getInputs().size() + 1 != blockArgs.size()) return emitOpError() << "expects number of operands to match the arity of " "mapper, but got: " - << getInputs().size() << " and " << blockArgs.size(); + << getInputs().size() + 1 << " and " + << blockArgs.size(); // The parameters of mapper should all match the element type of inputs. for (const auto &[bbArgType, inputArg] : @@ -2080,7 +2089,7 @@ LogicalResult TransposeOp::fold(FoldAdaptor adaptor, return failure(); // Single dimension transpose. - if (getPermutation().size() == 0) { + if (getPermutation().empty()) { result.push_back(getInput()); return success(); } @@ -4874,13 +4883,6 @@ void ElementwiseOp::print(OpAsmPrinter &p) { elidedAttrs); } -LogicalResult ElementwiseOp::verify() { - // All necessary checks are done either by - // - EnumAttr (e.g. unknown operation kind) - // - verifyStructuredOpInterface (incorrect map, sizes). - return success(); -} - /// Implements the block region builder for the ElementwiseOp. This is called by /// 'fillStructuredOpRegion'. void ElementwiseOp::regionBuilder( |
