diff options
Diffstat (limited to 'mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp')
| -rw-r--r-- | mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 37 |
1 files changed, 24 insertions, 13 deletions
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index cbc565b..3dc45ed 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1474,6 +1474,8 @@ void MapOp::getAsmBlockArgumentNames(Region ®ion, OpAsmSetValueNameFn setNameFn) { for (Value v : getRegionInputArgs()) setNameFn(v, "in"); + for (Value v : getRegionOutputArgs()) + setNameFn(v, "init"); } void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) { @@ -1495,14 +1497,14 @@ void MapOp::build( if (bodyBuild) buildGenericRegion(builder, result.location, *result.regions.front(), - inputs, /*outputs=*/{}, bodyBuild); + inputs, /*outputs=*/{init}, bodyBuild); } static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef<Value> operands, - bool initFirst = false) { + bool initFirst = false, bool mapInit = true) { OpBuilder b(parser.getContext()); Region *body = result.addRegion(); Block &block = body->emplaceBlock(); @@ -1516,12 +1518,13 @@ static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, // If initFirst flag is enabled, we consider init as the first position of // payload operands. if (initFirst) { - payloadOpOperands.push_back(block.getArguments().back()); + if (mapInit) + payloadOpOperands.push_back(block.getArguments().back()); for (const auto &arg : block.getArguments().drop_back()) payloadOpOperands.push_back(arg); } else { payloadOpOperands = {block.getArguments().begin(), - block.getArguments().end()}; + block.getArguments().end() - int(!mapInit)}; } Operation *payloadOp = b.create( @@ -1553,8 +1556,8 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) { if (payloadOpName.has_value()) { if (!result.operands.empty()) addBodyWithPayloadOp(parser, result, payloadOpName.value(), - payloadOpAttrs, - ArrayRef(result.operands).drop_back()); + payloadOpAttrs, ArrayRef(result.operands), false, + false); else result.addRegion(); } else { @@ -1570,7 +1573,11 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } -static bool canUseShortForm(Block *body, bool initFirst = false) { +static bool canUseShortForm(Block *body, bool initFirst = false, + bool mapInit = true) { + // `intFirst == true` implies that we want to map init arg + if (initFirst && !mapInit) + return false; // Check if the body can be printed in short form. The following 4 conditions // must be satisfied: @@ -1582,7 +1589,7 @@ static bool canUseShortForm(Block *body, bool initFirst = false) { // 2) The payload op must have the same number of operands as the number of // block arguments. if (payload.getNumOperands() == 0 || - payload.getNumOperands() != body->getNumArguments()) + payload.getNumOperands() != body->getNumArguments() - int(!mapInit)) return false; // 3) If `initFirst` is true (e.g., for reduction ops), the init block @@ -1600,7 +1607,8 @@ static bool canUseShortForm(Block *body, bool initFirst = false) { } } else { for (const auto &[operand, bbArg] : - llvm::zip(payload.getOperands(), body->getArguments())) { + llvm::zip(payload.getOperands(), + body->getArguments().drop_back(int(!mapInit)))) { if (bbArg != operand) return false; } @@ -1632,7 +1640,8 @@ static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) { void MapOp::print(OpAsmPrinter &p) { Block *mapper = getBody(); - bool useShortForm = canUseShortForm(mapper); + bool useShortForm = + canUseShortForm(mapper, /*initFirst=*/false, /*mapInit*/ false); if (useShortForm) { printShortForm(p, &mapper->getOperations().front()); } @@ -1658,11 +1667,13 @@ LogicalResult MapOp::verify() { auto *bodyBlock = getBody(); auto blockArgs = bodyBlock->getArguments(); - // Checks if the number of `inputs` match the arity of the `mapper` region. - if (getInputs().size() != blockArgs.size()) + // Checks if the number of `inputs` + `init` match the arity of the `mapper` + // region. + if (getInputs().size() + 1 != blockArgs.size()) return emitOpError() << "expects number of operands to match the arity of " "mapper, but got: " - << getInputs().size() << " and " << blockArgs.size(); + << getInputs().size() + 1 << " and " + << blockArgs.size(); // The parameters of mapper should all match the element type of inputs. for (const auto &[bbArgType, inputArg] : |
