aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Linalg
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Linalg')
-rw-r--r--mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp37
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp11
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp6
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp2
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp2
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp8
6 files changed, 38 insertions, 28 deletions
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index cbc565b..3dc45ed 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1474,6 +1474,8 @@ void MapOp::getAsmBlockArgumentNames(Region &region,
OpAsmSetValueNameFn setNameFn) {
for (Value v : getRegionInputArgs())
setNameFn(v, "in");
+ for (Value v : getRegionOutputArgs())
+ setNameFn(v, "init");
}
void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
@@ -1495,14 +1497,14 @@ void MapOp::build(
if (bodyBuild)
buildGenericRegion(builder, result.location, *result.regions.front(),
- inputs, /*outputs=*/{}, bodyBuild);
+ inputs, /*outputs=*/{init}, bodyBuild);
}
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
const OperationName &payloadOpName,
const NamedAttrList &payloadOpAttrs,
ArrayRef<Value> operands,
- bool initFirst = false) {
+ bool initFirst = false, bool mapInit = true) {
OpBuilder b(parser.getContext());
Region *body = result.addRegion();
Block &block = body->emplaceBlock();
@@ -1516,12 +1518,13 @@ static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
// If initFirst flag is enabled, we consider init as the first position of
// payload operands.
if (initFirst) {
- payloadOpOperands.push_back(block.getArguments().back());
+ if (mapInit)
+ payloadOpOperands.push_back(block.getArguments().back());
for (const auto &arg : block.getArguments().drop_back())
payloadOpOperands.push_back(arg);
} else {
payloadOpOperands = {block.getArguments().begin(),
- block.getArguments().end()};
+ block.getArguments().end() - int(!mapInit)};
}
Operation *payloadOp = b.create(
@@ -1553,8 +1556,8 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
if (payloadOpName.has_value()) {
if (!result.operands.empty())
addBodyWithPayloadOp(parser, result, payloadOpName.value(),
- payloadOpAttrs,
- ArrayRef(result.operands).drop_back());
+ payloadOpAttrs, ArrayRef(result.operands), false,
+ false);
else
result.addRegion();
} else {
@@ -1570,7 +1573,11 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
return success();
}
-static bool canUseShortForm(Block *body, bool initFirst = false) {
+static bool canUseShortForm(Block *body, bool initFirst = false,
+ bool mapInit = true) {
+ // `intFirst == true` implies that we want to map init arg
+ if (initFirst && !mapInit)
+ return false;
// Check if the body can be printed in short form. The following 4 conditions
// must be satisfied:
@@ -1582,7 +1589,7 @@ static bool canUseShortForm(Block *body, bool initFirst = false) {
// 2) The payload op must have the same number of operands as the number of
// block arguments.
if (payload.getNumOperands() == 0 ||
- payload.getNumOperands() != body->getNumArguments())
+ payload.getNumOperands() != body->getNumArguments() - int(!mapInit))
return false;
// 3) If `initFirst` is true (e.g., for reduction ops), the init block
@@ -1600,7 +1607,8 @@ static bool canUseShortForm(Block *body, bool initFirst = false) {
}
} else {
for (const auto &[operand, bbArg] :
- llvm::zip(payload.getOperands(), body->getArguments())) {
+ llvm::zip(payload.getOperands(),
+ body->getArguments().drop_back(int(!mapInit)))) {
if (bbArg != operand)
return false;
}
@@ -1632,7 +1640,8 @@ static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
void MapOp::print(OpAsmPrinter &p) {
Block *mapper = getBody();
- bool useShortForm = canUseShortForm(mapper);
+ bool useShortForm =
+ canUseShortForm(mapper, /*initFirst=*/false, /*mapInit*/ false);
if (useShortForm) {
printShortForm(p, &mapper->getOperations().front());
}
@@ -1658,11 +1667,13 @@ LogicalResult MapOp::verify() {
auto *bodyBlock = getBody();
auto blockArgs = bodyBlock->getArguments();
- // Checks if the number of `inputs` match the arity of the `mapper` region.
- if (getInputs().size() != blockArgs.size())
+ // Checks if the number of `inputs` + `init` match the arity of the `mapper`
+ // region.
+ if (getInputs().size() + 1 != blockArgs.size())
return emitOpError() << "expects number of operands to match the arity of "
"mapper, but got: "
- << getInputs().size() << " and " << blockArgs.size();
+ << getInputs().size() + 1 << " and "
+ << blockArgs.size();
// The parameters of mapper should all match the element type of inputs.
for (const auto &[bbArgType, inputArg] :
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8b89244..3a43382 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1958,7 +1958,7 @@ enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
/// Return true if either `op` or `permutation` are empty to allow a simpler
/// polymorphic implementation.
template <typename RelayoutOpTy>
-bool isValidPackingPermutation(
+static bool isValidPackingPermutation(
RelayoutOpTy op, ArrayRef<int64_t> permutation,
OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
static_assert(
@@ -4322,9 +4322,10 @@ DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(
// InsertSliceToCopyOp
//===----------------------------------------------------------------------===//
template <typename OpTy>
-DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target,
- transform::ApplyToEachResultList &results,
- transform::TransformState &state) {
+static DiagnosedSilenceableFailure
+doit(RewriterBase &rewriter, OpTy target,
+ transform::ApplyToEachResultList &results,
+ transform::TransformState &state) {
static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
tensor::ParallelInsertSliceOp>() &&
"wrong op type");
@@ -4499,7 +4500,7 @@ DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(
maybeTransformed = decomposeWinogradOutputTransformOp(rewriter, op);
return true;
})
- .Default([&](Operation *op) { return false; });
+ .Default(false);
if (!supported) {
DiagnosedSilenceableFailure diag =
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
index 3e31393..75bb175 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -31,10 +31,8 @@ using namespace mlir;
using namespace mlir::linalg;
static LogicalResult generalizeNamedOpPrecondition(LinalgOp linalgOp) {
- // Bailout if `linalgOp` is already a generic or a linalg.map. We cannot
- // trivially generalize a `linalg.map`, as it does not use the output as
- // region arguments in the block.
- if (isa<GenericOp>(linalgOp) || isa<MapOp>(linalgOp))
+ // Bailout if `linalgOp` is already a generic.
+ if (isa<GenericOp>(linalgOp))
return failure();
// Check if the operation has exactly one region.
if (linalgOp->getNumRegions() != 1) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index f05ffa8..6519c4f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -322,7 +322,7 @@ promoteSubViews(ImplicitLocOpBuilder &b,
tmp = arith::ConstantOp::create(b, IntegerAttr::get(et, 0));
return complex::CreateOp::create(b, t, tmp, tmp);
})
- .Default([](auto) { return Value(); });
+ .Default(nullptr);
if (!fillVal)
return failure();
linalg::FillOp::create(b, fillVal, promotionInfo->fullLocalView);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp b/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp
index 27ccf3c..6becc1f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp
@@ -89,7 +89,7 @@ matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel,
ValueRange{input, collapsedKernel, iZp, kZp},
ValueRange{collapsedInit}, stride, dilation);
})
- .Default([](Operation *op) { return nullptr; });
+ .Default(nullptr);
if (!newConv)
return failure();
for (auto attr : preservedAttrs)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 9d62491..cb6199f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -656,7 +656,7 @@ mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
[&](auto op) { return CombiningKind::MUL; })
.Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
.Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
- .Default([&](auto op) { return std::nullopt; });
+ .Default(std::nullopt);
}
/// Check whether `outputOperand` is a reduction with a single combiner
@@ -3911,21 +3911,21 @@ struct Conv1DGenerator
Value lhs = vector::TransferReadOp::create(
rewriter, loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},
/*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
- auto maybeMaskedLhs = maybeMaskXferOp(
+ auto *maybeMaskedLhs = maybeMaskXferOp(
lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
// Read rhs slice of size {kw, c} @ [0, 0].
Value rhs = vector::TransferReadOp::create(
rewriter, loc, rhsType, rhsShaped, ValueRange{zero, zero},
/*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
- auto maybeMaskedRhs = maybeMaskXferOp(
+ auto *maybeMaskedRhs = maybeMaskXferOp(
rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
// Read res slice of size {n, w, c} @ [0, 0, 0].
Value res = vector::TransferReadOp::create(
rewriter, loc, resType, resShaped, ValueRange{zero, zero, zero},
/*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
- auto maybeMaskedRes = maybeMaskXferOp(
+ auto *maybeMaskedRes = maybeMaskXferOp(
resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
//===------------------------------------------------------------------===//