diff options
Diffstat (limited to 'mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp')
| -rw-r--r-- | mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 37 | 
1 files changed, 24 insertions, 13 deletions
| diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index cbc565b..3dc45ed 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1474,6 +1474,8 @@ void MapOp::getAsmBlockArgumentNames(Region ®ion,                                       OpAsmSetValueNameFn setNameFn) {    for (Value v : getRegionInputArgs())      setNameFn(v, "in"); +  for (Value v : getRegionOutputArgs()) +    setNameFn(v, "init");  }  void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) { @@ -1495,14 +1497,14 @@ void MapOp::build(    if (bodyBuild)      buildGenericRegion(builder, result.location, *result.regions.front(), -                       inputs, /*outputs=*/{}, bodyBuild); +                       inputs, /*outputs=*/{init}, bodyBuild);  }  static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,                                   const OperationName &payloadOpName,                                   const NamedAttrList &payloadOpAttrs,                                   ArrayRef<Value> operands, -                                 bool initFirst = false) { +                                 bool initFirst = false, bool mapInit = true) {    OpBuilder b(parser.getContext());    Region *body = result.addRegion();    Block &block = body->emplaceBlock(); @@ -1516,12 +1518,13 @@ static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,    // If initFirst flag is enabled, we consider init as the first position of    // payload operands.    if (initFirst) { -    payloadOpOperands.push_back(block.getArguments().back()); +    if (mapInit) +      payloadOpOperands.push_back(block.getArguments().back());      for (const auto &arg : block.getArguments().drop_back())        payloadOpOperands.push_back(arg);    } else {      payloadOpOperands = {block.getArguments().begin(), -                         block.getArguments().end()}; +                         block.getArguments().end() - int(!mapInit)};    }    Operation *payloadOp = b.create( @@ -1553,8 +1556,8 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {    if (payloadOpName.has_value()) {      if (!result.operands.empty())        addBodyWithPayloadOp(parser, result, payloadOpName.value(), -                           payloadOpAttrs, -                           ArrayRef(result.operands).drop_back()); +                           payloadOpAttrs, ArrayRef(result.operands), false, +                           false);      else        result.addRegion();    } else { @@ -1570,7 +1573,11 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {    return success();  } -static bool canUseShortForm(Block *body, bool initFirst = false) { +static bool canUseShortForm(Block *body, bool initFirst = false, +                            bool mapInit = true) { +  // `intFirst == true` implies that we want to map init arg +  if (initFirst && !mapInit) +    return false;    // Check if the body can be printed in short form. The following 4 conditions    // must be satisfied: @@ -1582,7 +1589,7 @@ static bool canUseShortForm(Block *body, bool initFirst = false) {    // 2) The payload op must have the same number of operands as the number of    //    block arguments.    if (payload.getNumOperands() == 0 || -      payload.getNumOperands() != body->getNumArguments()) +      payload.getNumOperands() != body->getNumArguments() - int(!mapInit))      return false;    // 3) If `initFirst` is true (e.g., for reduction ops), the init block @@ -1600,7 +1607,8 @@ static bool canUseShortForm(Block *body, bool initFirst = false) {      }    } else {      for (const auto &[operand, bbArg] : -         llvm::zip(payload.getOperands(), body->getArguments())) { +         llvm::zip(payload.getOperands(), +                   body->getArguments().drop_back(int(!mapInit)))) {        if (bbArg != operand)          return false;      } @@ -1632,7 +1640,8 @@ static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {  void MapOp::print(OpAsmPrinter &p) {    Block *mapper = getBody(); -  bool useShortForm = canUseShortForm(mapper); +  bool useShortForm = +      canUseShortForm(mapper, /*initFirst=*/false, /*mapInit*/ false);    if (useShortForm) {      printShortForm(p, &mapper->getOperations().front());    } @@ -1658,11 +1667,13 @@ LogicalResult MapOp::verify() {    auto *bodyBlock = getBody();    auto blockArgs = bodyBlock->getArguments(); -  // Checks if the number of `inputs` match the arity of the `mapper` region. -  if (getInputs().size() != blockArgs.size()) +  // Checks if the number of `inputs` + `init` match the arity of the `mapper` +  // region. +  if (getInputs().size() + 1 != blockArgs.size())      return emitOpError() << "expects number of operands to match the arity of "                              "mapper, but got: " -                         << getInputs().size() << " and " << blockArgs.size(); +                         << getInputs().size() + 1 << " and " +                         << blockArgs.size();    // The parameters of mapper should all match the element type of inputs.    for (const auto &[bbArgType, inputArg] : | 
