diff options
Diffstat (limited to 'mlir/lib/Dialect/Linalg')
6 files changed, 38 insertions, 28 deletions
| diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index cbc565b..3dc45ed 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1474,6 +1474,8 @@ void MapOp::getAsmBlockArgumentNames(Region ®ion,                                       OpAsmSetValueNameFn setNameFn) {    for (Value v : getRegionInputArgs())      setNameFn(v, "in"); +  for (Value v : getRegionOutputArgs()) +    setNameFn(v, "init");  }  void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) { @@ -1495,14 +1497,14 @@ void MapOp::build(    if (bodyBuild)      buildGenericRegion(builder, result.location, *result.regions.front(), -                       inputs, /*outputs=*/{}, bodyBuild); +                       inputs, /*outputs=*/{init}, bodyBuild);  }  static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,                                   const OperationName &payloadOpName,                                   const NamedAttrList &payloadOpAttrs,                                   ArrayRef<Value> operands, -                                 bool initFirst = false) { +                                 bool initFirst = false, bool mapInit = true) {    OpBuilder b(parser.getContext());    Region *body = result.addRegion();    Block &block = body->emplaceBlock(); @@ -1516,12 +1518,13 @@ static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,    // If initFirst flag is enabled, we consider init as the first position of    // payload operands.    if (initFirst) { -    payloadOpOperands.push_back(block.getArguments().back()); +    if (mapInit) +      payloadOpOperands.push_back(block.getArguments().back());      for (const auto &arg : block.getArguments().drop_back())        payloadOpOperands.push_back(arg);    } else {      payloadOpOperands = {block.getArguments().begin(), -                         block.getArguments().end()}; +                         block.getArguments().end() - int(!mapInit)};    }    Operation *payloadOp = b.create( @@ -1553,8 +1556,8 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {    if (payloadOpName.has_value()) {      if (!result.operands.empty())        addBodyWithPayloadOp(parser, result, payloadOpName.value(), -                           payloadOpAttrs, -                           ArrayRef(result.operands).drop_back()); +                           payloadOpAttrs, ArrayRef(result.operands), false, +                           false);      else        result.addRegion();    } else { @@ -1570,7 +1573,11 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {    return success();  } -static bool canUseShortForm(Block *body, bool initFirst = false) { +static bool canUseShortForm(Block *body, bool initFirst = false, +                            bool mapInit = true) { +  // `intFirst == true` implies that we want to map init arg +  if (initFirst && !mapInit) +    return false;    // Check if the body can be printed in short form. The following 4 conditions    // must be satisfied: @@ -1582,7 +1589,7 @@ static bool canUseShortForm(Block *body, bool initFirst = false) {    // 2) The payload op must have the same number of operands as the number of    //    block arguments.    if (payload.getNumOperands() == 0 || -      payload.getNumOperands() != body->getNumArguments()) +      payload.getNumOperands() != body->getNumArguments() - int(!mapInit))      return false;    // 3) If `initFirst` is true (e.g., for reduction ops), the init block @@ -1600,7 +1607,8 @@ static bool canUseShortForm(Block *body, bool initFirst = false) {      }    } else {      for (const auto &[operand, bbArg] : -         llvm::zip(payload.getOperands(), body->getArguments())) { +         llvm::zip(payload.getOperands(), +                   body->getArguments().drop_back(int(!mapInit)))) {        if (bbArg != operand)          return false;      } @@ -1632,7 +1640,8 @@ static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {  void MapOp::print(OpAsmPrinter &p) {    Block *mapper = getBody(); -  bool useShortForm = canUseShortForm(mapper); +  bool useShortForm = +      canUseShortForm(mapper, /*initFirst=*/false, /*mapInit*/ false);    if (useShortForm) {      printShortForm(p, &mapper->getOperations().front());    } @@ -1658,11 +1667,13 @@ LogicalResult MapOp::verify() {    auto *bodyBlock = getBody();    auto blockArgs = bodyBlock->getArguments(); -  // Checks if the number of `inputs` match the arity of the `mapper` region. -  if (getInputs().size() != blockArgs.size()) +  // Checks if the number of `inputs` + `init` match the arity of the `mapper` +  // region. +  if (getInputs().size() + 1 != blockArgs.size())      return emitOpError() << "expects number of operands to match the arity of "                              "mapper, but got: " -                         << getInputs().size() << " and " << blockArgs.size(); +                         << getInputs().size() + 1 << " and " +                         << blockArgs.size();    // The parameters of mapper should all match the element type of inputs.    for (const auto &[bbArgType, inputArg] : diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 8b89244..3a43382 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -1958,7 +1958,7 @@ enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };  /// Return true if either `op` or `permutation` are empty to allow a simpler  /// polymorphic implementation.  template <typename RelayoutOpTy> -bool isValidPackingPermutation( +static bool isValidPackingPermutation(      RelayoutOpTy op, ArrayRef<int64_t> permutation,      OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {    static_assert( @@ -4322,9 +4322,10 @@ DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(  // InsertSliceToCopyOp  //===----------------------------------------------------------------------===//  template <typename OpTy> -DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target, -                                 transform::ApplyToEachResultList &results, -                                 transform::TransformState &state) { +static DiagnosedSilenceableFailure +doit(RewriterBase &rewriter, OpTy target, +     transform::ApplyToEachResultList &results, +     transform::TransformState &state) {    static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,                                  tensor::ParallelInsertSliceOp>() &&                  "wrong op type"); @@ -4499,7 +4500,7 @@ DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(              maybeTransformed = decomposeWinogradOutputTransformOp(rewriter, op);              return true;            }) -          .Default([&](Operation *op) { return false; }); +          .Default(false);    if (!supported) {      DiagnosedSilenceableFailure diag = diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp index 3e31393..75bb175 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp @@ -31,10 +31,8 @@ using namespace mlir;  using namespace mlir::linalg;  static LogicalResult generalizeNamedOpPrecondition(LinalgOp linalgOp) { -  // Bailout if `linalgOp` is already a generic or a linalg.map. We cannot -  // trivially generalize a `linalg.map`, as it does not use the output as -  // region arguments in the block. -  if (isa<GenericOp>(linalgOp) || isa<MapOp>(linalgOp)) +  // Bailout if `linalgOp` is already a generic. +  if (isa<GenericOp>(linalgOp))      return failure();    // Check if the operation has exactly one region.    if (linalgOp->getNumRegions() != 1) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp index f05ffa8..6519c4f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -322,7 +322,7 @@ promoteSubViews(ImplicitLocOpBuilder &b,                  tmp = arith::ConstantOp::create(b, IntegerAttr::get(et, 0));                return complex::CreateOp::create(b, t, tmp, tmp);              }) -            .Default([](auto) { return Value(); }); +            .Default(nullptr);      if (!fillVal)        return failure();      linalg::FillOp::create(b, fillVal, promotionInfo->fullLocalView); diff --git a/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp b/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp index 27ccf3c..6becc1f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp @@ -89,7 +89,7 @@ matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel,                  ValueRange{input, collapsedKernel, iZp, kZp},                  ValueRange{collapsedInit}, stride, dilation);            }) -          .Default([](Operation *op) { return nullptr; }); +          .Default(nullptr);    if (!newConv)      return failure();    for (auto attr : preservedAttrs) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 9d62491..cb6199f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -656,7 +656,7 @@ mlir::linalg::getCombinerOpKind(Operation *combinerOp) {            [&](auto op) { return CombiningKind::MUL; })        .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })        .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; }) -      .Default([&](auto op) { return std::nullopt; }); +      .Default(std::nullopt);  }  /// Check whether `outputOperand` is a reduction with a single combiner @@ -3911,21 +3911,21 @@ struct Conv1DGenerator      Value lhs = vector::TransferReadOp::create(          rewriter, loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},          /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType)); -    auto maybeMaskedLhs = maybeMaskXferOp( +    auto *maybeMaskedLhs = maybeMaskXferOp(          lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());      // Read rhs slice of size {kw, c} @ [0, 0].      Value rhs = vector::TransferReadOp::create(          rewriter, loc, rhsType, rhsShaped, ValueRange{zero, zero},          /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType)); -    auto maybeMaskedRhs = maybeMaskXferOp( +    auto *maybeMaskedRhs = maybeMaskXferOp(          rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());      // Read res slice of size {n, w, c} @ [0, 0, 0].      Value res = vector::TransferReadOp::create(          rewriter, loc, resType, resShaped, ValueRange{zero, zero, zero},          /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType)); -    auto maybeMaskedRes = maybeMaskXferOp( +    auto *maybeMaskedRes = maybeMaskXferOp(          resType.getShape(), resType.getScalableDims(), res.getDefiningOp());      //===------------------------------------------------------------------===// | 
