Diffstat (limited to 'mlir/lib/Dialect')
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 2
 mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp | 7
 mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp | 6
 mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt | 1
 mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp | 169
 mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp | 6
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 2
 mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 2
 mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp | 2
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 37
 mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 2
 mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp | 6
 mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp | 2
 mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp | 2
 mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp | 8
 mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp | 2
 mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp | 2
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 4
 mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 2
 mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 2
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 4
 mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp | 2
 mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp | 2
 23 files changed, 221 insertions, 53 deletions
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 898d76c..980442e 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2751,7 +2751,7 @@ std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
.Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
.Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
.Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
- .Default([](Operation *op) { return std::nullopt; });
+ .Default(std::nullopt);
if (!maybeKind) {
return std::nullopt;
}
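The hunk above is representative of most hunks in this diff: `llvm::TypeSwitch` also accepts a plain value in `Default()`, so lambdas that only return a constant can be dropped. Below is a minimal sketch, not part of the patch; the op cases, helper name, and result type are illustrative only. The same mechanical rewrite appears in the GPU, LLVMIR, Linalg, MemRef, SCF, SPIR-V, and Tensor hunks further down.

```cpp
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include <optional>

using namespace mlir;

// Illustrative helper: classify a couple of arith ops, falling back to
// std::nullopt. The value overload of Default() replaces the old
// `.Default([](Operation *) { return std::nullopt; })` form.
static std::optional<unsigned> classifyArithOp(Operation *op) {
  return llvm::TypeSwitch<Operation *, std::optional<unsigned>>(op)
      .Case([](arith::AddIOp) { return 0u; })
      .Case([](arith::MulIOp) { return 1u; })
      .Default(std::nullopt);
}
```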
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index d9d6934..8655ed3 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -95,12 +95,7 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
/// Return the FuncOp called by `callOp`.
static FuncOp getCalledFunction(CallOpInterface callOp,
SymbolTableCollection &symbolTables) {
- SymbolRefAttr sym =
- llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
- if (!sym)
- return nullptr;
- return dyn_cast_or_null<FuncOp>(
- symbolTables.lookupNearestSymbolFrom(callOp, sym));
+ return dyn_cast_or_null<FuncOp>(callOp.resolveCallableInTable(&symbolTables));
}
/// Return the FuncOp called by `callOp`.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index aa53f94..c233e24 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -285,12 +285,8 @@ static void removeBufferizationAttributes(BlockArgument bbArg) {
static func::FuncOp
getCalledFunction(func::CallOp callOp,
mlir::SymbolTableCollection &symbolTable) {
- SymbolRefAttr sym =
- llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
- if (!sym)
- return nullptr;
return dyn_cast_or_null<func::FuncOp>(
- symbolTable.lookupNearestSymbolFrom(callOp, sym));
+ callOp.resolveCallableInTable(&symbolTable));
}
/// Return "true" if the given function signature has tensor semantics.
diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
index 47740d3..e9da135 100644
--- a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRControlFlowTransforms
BufferDeallocationOpInterfaceImpl.cpp
BufferizableOpInterfaceImpl.cpp
+ StructuralTypeConversions.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ControlFlow/Transforms
diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp
new file mode 100644
index 0000000..5e2a742
--- /dev/null
+++ b/mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp
@@ -0,0 +1,169 @@
+//===- StructuralTypeConversions.cpp - CF Structural Type Conversions ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements structural type conversion patterns for unstructured
+// control flow ops (cf.br, cf.cond_br, cf.switch) and their successor blocks.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h"
+
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Helper function for converting branch ops. This function converts the
+/// signature of the given block. If the new block signature is different from
+/// `expectedTypes`, returns "failure".
+static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter,
+ const TypeConverter *converter,
+ Operation *branchOp, Block *block,
+ TypeRange expectedTypes) {
+ assert(converter && "expected non-null type converter");
+ assert(!block->isEntryBlock() && "entry blocks have no predecessors");
+
+ // There is nothing to do if the types already match.
+ if (block->getArgumentTypes() == expectedTypes)
+ return block;
+
+ // Compute the new block argument types and convert the block.
+ std::optional<TypeConverter::SignatureConversion> conversion =
+ converter->convertBlockSignature(block);
+ if (!conversion)
+ return rewriter.notifyMatchFailure(branchOp,
+ "could not compute block signature");
+ if (expectedTypes != conversion->getConvertedTypes())
+ return rewriter.notifyMatchFailure(
+ branchOp,
+ "mismatch between adaptor operand types and computed block signature");
+ return rewriter.applySignatureConversion(block, *conversion, converter);
+}
+
+/// Flatten the given value ranges into a single vector of values.
+static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
+ SmallVector<Value> result;
+ for (const ValueRange &vals : values)
+ llvm::append_range(result, vals);
+ return result;
+}
+
+/// Convert the destination block signature (if necessary) and change the
+/// operands of the branch op.
+struct BranchOpConversion : public OpConversionPattern<cf::BranchOp> {
+ using OpConversionPattern<cf::BranchOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(cf::BranchOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ SmallVector<Value> flattenedAdaptor = flattenValues(adaptor.getOperands());
+ FailureOr<Block *> convertedBlock =
+ getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
+ TypeRange(ValueRange(flattenedAdaptor)));
+ if (failed(convertedBlock))
+ return failure();
+ rewriter.replaceOpWithNewOp<cf::BranchOp>(op, flattenedAdaptor,
+ *convertedBlock);
+ return success();
+ }
+};
+
+/// Convert the destination block signatures (if necessary) and change the
+/// operands of the branch op.
+struct CondBranchOpConversion : public OpConversionPattern<cf::CondBranchOp> {
+ using OpConversionPattern<cf::CondBranchOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(cf::CondBranchOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ SmallVector<Value> flattenedAdaptorTrue =
+ flattenValues(adaptor.getTrueDestOperands());
+ SmallVector<Value> flattenedAdaptorFalse =
+ flattenValues(adaptor.getFalseDestOperands());
+ if (!llvm::hasSingleElement(adaptor.getCondition()))
+ return rewriter.notifyMatchFailure(op,
+ "expected single element condition");
+ FailureOr<Block *> convertedTrueBlock =
+ getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(),
+ TypeRange(ValueRange(flattenedAdaptorTrue)));
+ if (failed(convertedTrueBlock))
+ return failure();
+ FailureOr<Block *> convertedFalseBlock =
+ getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(),
+ TypeRange(ValueRange(flattenedAdaptorFalse)));
+ if (failed(convertedFalseBlock))
+ return failure();
+ rewriter.replaceOpWithNewOp<cf::CondBranchOp>(
+ op, llvm::getSingleElement(adaptor.getCondition()),
+ flattenedAdaptorTrue, flattenedAdaptorFalse, op.getBranchWeightsAttr(),
+ *convertedTrueBlock, *convertedFalseBlock);
+ return success();
+ }
+};
+
+/// Convert the destination block signatures (if necessary) and change the
+/// operands of the switch op.
+struct SwitchOpConversion : public OpConversionPattern<cf::SwitchOp> {
+ using OpConversionPattern<cf::SwitchOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(cf::SwitchOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Get or convert default block.
+ FailureOr<Block *> convertedDefaultBlock = getConvertedBlock(
+ rewriter, getTypeConverter(), op, op.getDefaultDestination(),
+ TypeRange(adaptor.getDefaultOperands()));
+ if (failed(convertedDefaultBlock))
+ return failure();
+
+ // Get or convert all case blocks.
+ SmallVector<Block *> caseDestinations;
+ SmallVector<ValueRange> caseOperands = adaptor.getCaseOperands();
+ for (auto it : llvm::enumerate(op.getCaseDestinations())) {
+ Block *b = it.value();
+ FailureOr<Block *> convertedBlock =
+ getConvertedBlock(rewriter, getTypeConverter(), op, b,
+ TypeRange(caseOperands[it.index()]));
+ if (failed(convertedBlock))
+ return failure();
+ caseDestinations.push_back(*convertedBlock);
+ }
+
+ rewriter.replaceOpWithNewOp<cf::SwitchOp>(
+ op, adaptor.getFlag(), *convertedDefaultBlock,
+ adaptor.getDefaultOperands(), adaptor.getCaseValuesAttr(),
+ caseDestinations, caseOperands);
+ return success();
+ }
+};
+
+} // namespace
+
+void mlir::cf::populateCFStructuralTypeConversions(
+ const TypeConverter &typeConverter, RewritePatternSet &patterns,
+ PatternBenefit benefit) {
+ patterns.add<BranchOpConversion, CondBranchOpConversion, SwitchOpConversion>(
+ typeConverter, patterns.getContext(), benefit);
+}
+
+void mlir::cf::populateCFStructuralTypeConversionTarget(
+ const TypeConverter &typeConverter, ConversionTarget &target) {
+ target.addDynamicallyLegalOp<cf::BranchOp, cf::CondBranchOp, cf::SwitchOp>(
+ [&](Operation *op) { return typeConverter.isLegal(op->getOperands()); });
+}
+
+void mlir::cf::populateCFStructuralTypeConversionsAndLegality(
+ const TypeConverter &typeConverter, RewritePatternSet &patterns,
+ ConversionTarget &target, PatternBenefit benefit) {
+ populateCFStructuralTypeConversions(typeConverter, patterns, benefit);
+ populateCFStructuralTypeConversionTarget(typeConverter, target);
+}
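For reference, a hedged usage sketch of the new entry points. The wrapper function, context setup, and conversion driver are assumptions for illustration; only the `populateCFStructuralTypeConversionsAndLegality` call comes from the file added above.

```cpp
#include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Hypothetical driver: convert the operand types of cf.br/cf.cond_br/cf.switch
// (and their successor block signatures) under `typeConverter`.
static LogicalResult convertCFTypes(Operation *root,
                                    const TypeConverter &typeConverter) {
  MLIRContext *ctx = root->getContext();
  RewritePatternSet patterns(ctx);
  ConversionTarget target(*ctx);
  // Adds the three patterns and marks the ops legal only once their operand
  // types are legal for `typeConverter`.
  cf::populateCFStructuralTypeConversionsAndLegality(typeConverter, patterns,
                                                     target, /*benefit=*/1);
  return applyPartialConversion(root, target, std::move(patterns));
}
```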
diff --git a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp
index d2c2138..025d1ac 100644
--- a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp
@@ -330,7 +330,7 @@ static Value getBase(Value v) {
v = op.getSrc();
return true;
})
- .Default([](Operation *) { return false; });
+ .Default(false);
if (!shouldContinue)
break;
}
@@ -354,7 +354,7 @@ static Value propagatesCapture(Operation *op) {
.Case([](memref::TransposeOp transpose) { return transpose.getIn(); })
.Case<memref::ExpandShapeOp, memref::CollapseShapeOp>(
[](auto op) { return op.getSrc(); })
- .Default([](Operation *) { return Value(); });
+ .Default(nullptr);
}
/// Returns `true` if the given operation is known to capture the given value,
@@ -371,7 +371,7 @@ static std::optional<bool> getKnownCapturingStatus(Operation *op, Value v) {
// These operations are known not to capture.
.Case([](memref::DeallocOp) { return false; })
// By default, we don't know anything.
- .Default([](Operation *) { return std::nullopt; });
+ .Default(std::nullopt);
}
/// Returns `true` if the value may be captured by any of its users, i.e., if
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 3eae67f..2731069 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -698,7 +698,7 @@ static void destructureIndices(Type currType, ArrayRef<GEPArg> indices,
return structType.getBody()[memberIndex];
return nullptr;
})
- .Default(Type(nullptr));
+ .Default(nullptr);
}
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index cee943d..7d9058c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -1111,7 +1111,7 @@ memsetCanUsesBeRemoved(MemsetIntr op, const MemorySlot &slot,
.Case<IntegerType, FloatType>([](auto type) {
return type.getWidth() % 8 == 0 && type.getWidth() > 0;
})
- .Default([](Type) { return false; });
+ .Default(false);
if (!canConvertType)
return false;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index ac35eea..ce93d18 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -798,7 +798,7 @@ static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
// clang-format on
.Case<PtrLikeTypeInterface>(
[](Type type) { return isCompatiblePtrType(type); })
- .Default([](Type) { return false; });
+ .Default(false);
if (!result)
compatibleTypes.erase(type);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index cbc565b..3dc45ed 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1474,6 +1474,8 @@ void MapOp::getAsmBlockArgumentNames(Region &region,
OpAsmSetValueNameFn setNameFn) {
for (Value v : getRegionInputArgs())
setNameFn(v, "in");
+ for (Value v : getRegionOutputArgs())
+ setNameFn(v, "init");
}
void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
@@ -1495,14 +1497,14 @@ void MapOp::build(
if (bodyBuild)
buildGenericRegion(builder, result.location, *result.regions.front(),
- inputs, /*outputs=*/{}, bodyBuild);
+ inputs, /*outputs=*/{init}, bodyBuild);
}
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
const OperationName &payloadOpName,
const NamedAttrList &payloadOpAttrs,
ArrayRef<Value> operands,
- bool initFirst = false) {
+ bool initFirst = false, bool mapInit = true) {
OpBuilder b(parser.getContext());
Region *body = result.addRegion();
Block &block = body->emplaceBlock();
@@ -1516,12 +1518,13 @@ static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
// If initFirst flag is enabled, we consider init as the first position of
// payload operands.
if (initFirst) {
- payloadOpOperands.push_back(block.getArguments().back());
+ if (mapInit)
+ payloadOpOperands.push_back(block.getArguments().back());
for (const auto &arg : block.getArguments().drop_back())
payloadOpOperands.push_back(arg);
} else {
payloadOpOperands = {block.getArguments().begin(),
- block.getArguments().end()};
+ block.getArguments().end() - int(!mapInit)};
}
Operation *payloadOp = b.create(
@@ -1553,8 +1556,8 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
if (payloadOpName.has_value()) {
if (!result.operands.empty())
addBodyWithPayloadOp(parser, result, payloadOpName.value(),
- payloadOpAttrs,
- ArrayRef(result.operands).drop_back());
+ payloadOpAttrs, ArrayRef(result.operands), false,
+ false);
else
result.addRegion();
} else {
@@ -1570,7 +1573,11 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
return success();
}
-static bool canUseShortForm(Block *body, bool initFirst = false) {
+static bool canUseShortForm(Block *body, bool initFirst = false,
+ bool mapInit = true) {
+  // `initFirst == true` implies that we want to map the init arg.
+ if (initFirst && !mapInit)
+ return false;
// Check if the body can be printed in short form. The following 4 conditions
// must be satisfied:
@@ -1582,7 +1589,7 @@ static bool canUseShortForm(Block *body, bool initFirst = false) {
// 2) The payload op must have the same number of operands as the number of
// block arguments.
if (payload.getNumOperands() == 0 ||
- payload.getNumOperands() != body->getNumArguments())
+ payload.getNumOperands() != body->getNumArguments() - int(!mapInit))
return false;
// 3) If `initFirst` is true (e.g., for reduction ops), the init block
@@ -1600,7 +1607,8 @@ static bool canUseShortForm(Block *body, bool initFirst = false) {
}
} else {
for (const auto &[operand, bbArg] :
- llvm::zip(payload.getOperands(), body->getArguments())) {
+ llvm::zip(payload.getOperands(),
+ body->getArguments().drop_back(int(!mapInit)))) {
if (bbArg != operand)
return false;
}
@@ -1632,7 +1640,8 @@ static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
void MapOp::print(OpAsmPrinter &p) {
Block *mapper = getBody();
- bool useShortForm = canUseShortForm(mapper);
+ bool useShortForm =
+      canUseShortForm(mapper, /*initFirst=*/false, /*mapInit=*/false);
if (useShortForm) {
printShortForm(p, &mapper->getOperations().front());
}
@@ -1658,11 +1667,13 @@ LogicalResult MapOp::verify() {
auto *bodyBlock = getBody();
auto blockArgs = bodyBlock->getArguments();
- // Checks if the number of `inputs` match the arity of the `mapper` region.
- if (getInputs().size() != blockArgs.size())
+ // Checks if the number of `inputs` + `init` match the arity of the `mapper`
+ // region.
+ if (getInputs().size() + 1 != blockArgs.size())
return emitOpError() << "expects number of operands to match the arity of "
"mapper, but got: "
- << getInputs().size() << " and " << blockArgs.size();
+ << getInputs().size() + 1 << " and "
+ << blockArgs.size();
// The parameters of mapper should all match the element type of inputs.
for (const auto &[bbArgType, inputArg] :
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8b89244..b09112b 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -4499,7 +4499,7 @@ DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(
maybeTransformed = decomposeWinogradOutputTransformOp(rewriter, op);
return true;
})
- .Default([&](Operation *op) { return false; });
+ .Default(false);
if (!supported) {
DiagnosedSilenceableFailure diag =
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
index 3e31393..75bb175 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -31,10 +31,8 @@ using namespace mlir;
using namespace mlir::linalg;
static LogicalResult generalizeNamedOpPrecondition(LinalgOp linalgOp) {
- // Bailout if `linalgOp` is already a generic or a linalg.map. We cannot
- // trivially generalize a `linalg.map`, as it does not use the output as
- // region arguments in the block.
- if (isa<GenericOp>(linalgOp) || isa<MapOp>(linalgOp))
+ // Bailout if `linalgOp` is already a generic.
+ if (isa<GenericOp>(linalgOp))
return failure();
// Check if the operation has exactly one region.
if (linalgOp->getNumRegions() != 1) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index f05ffa8..6519c4f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -322,7 +322,7 @@ promoteSubViews(ImplicitLocOpBuilder &b,
tmp = arith::ConstantOp::create(b, IntegerAttr::get(et, 0));
return complex::CreateOp::create(b, t, tmp, tmp);
})
- .Default([](auto) { return Value(); });
+ .Default(nullptr);
if (!fillVal)
return failure();
linalg::FillOp::create(b, fillVal, promotionInfo->fullLocalView);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp b/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp
index 27ccf3c..6becc1f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp
@@ -89,7 +89,7 @@ matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel,
ValueRange{input, collapsedKernel, iZp, kZp},
ValueRange{collapsedInit}, stride, dilation);
})
- .Default([](Operation *op) { return nullptr; });
+ .Default(nullptr);
if (!newConv)
return failure();
for (auto attr : preservedAttrs)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 9d62491..cb6199f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -656,7 +656,7 @@ mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
[&](auto op) { return CombiningKind::MUL; })
.Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
.Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
- .Default([&](auto op) { return std::nullopt; });
+ .Default(std::nullopt);
}
/// Check whether `outputOperand` is a reduction with a single combiner
@@ -3911,21 +3911,21 @@ struct Conv1DGenerator
Value lhs = vector::TransferReadOp::create(
rewriter, loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},
/*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
- auto maybeMaskedLhs = maybeMaskXferOp(
+ auto *maybeMaskedLhs = maybeMaskXferOp(
lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
// Read rhs slice of size {kw, c} @ [0, 0].
Value rhs = vector::TransferReadOp::create(
rewriter, loc, rhsType, rhsShaped, ValueRange{zero, zero},
/*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
- auto maybeMaskedRhs = maybeMaskXferOp(
+ auto *maybeMaskedRhs = maybeMaskXferOp(
rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
// Read res slice of size {n, w, c} @ [0, 0, 0].
Value res = vector::TransferReadOp::create(
rewriter, loc, resType, resShaped, ValueRange{zero, zero, zero},
/*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
- auto maybeMaskedRes = maybeMaskXferOp(
+ auto *maybeMaskedRes = maybeMaskXferOp(
resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
//===------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index 1208fdd..e685089 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -104,7 +104,7 @@ static Value getTargetMemref(Operation *op) {
vector::MaskedStoreOp, vector::TransferReadOp,
vector::TransferWriteOp>(
[](auto op) { return op.getBase(); })
- .Default([](auto) { return Value{}; });
+ .Default(nullptr);
}
template <typename T>
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
index 4ebd90d..d380c46 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
@@ -55,7 +55,7 @@ static bool isShapePreserving(ForOp forOp, int64_t arg) {
? forOp.getInitArgs()[opResult.getResultNumber()]
: Value();
})
- .Default([&](auto op) { return Value(); });
+ .Default(nullptr);
}
return false;
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 0c8114d..938952e 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -346,7 +346,7 @@ LogicalResult spirv::CompositeConstructOp::verify() {
llvm::TypeSwitch<Type, Type>(getType())
.Case<spirv::CooperativeMatrixType>(
[](auto coopType) { return coopType.getElementType(); })
- .Default([](Type) { return nullptr; });
+ .Default(nullptr);
// Case 1. -- matrices.
if (coopElementType) {
@@ -1708,7 +1708,7 @@ LogicalResult spirv::MatrixTimesScalarOp::verify() {
llvm::TypeSwitch<Type, Type>(getMatrix().getType())
.Case<spirv::CooperativeMatrixType, spirv::MatrixType>(
[](auto matrixType) { return matrixType.getElementType(); })
- .Default([](Type) { return nullptr; });
+ .Default(nullptr);
assert(elementType && "Unhandled type");
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index f895807..d1e275d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -731,7 +731,7 @@ std::optional<int64_t> SPIRVType::getSizeInBytes() {
return *elementSize * type.getNumElements();
return std::nullopt;
})
- .Default(std::optional<int64_t>());
+ .Default(std::nullopt);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 88e1ab6..cb9b7f6 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1467,7 +1467,7 @@ mlir::spirv::getNativeVectorShape(Operation *op) {
return TypeSwitch<Operation *, std::optional<SmallVector<int64_t>>>(op)
.Case<vector::ReductionOp, vector::TransposeOp>(
[](auto typedOp) { return getNativeVectorShapeImpl(typedOp); })
- .Default([](Operation *) { return std::nullopt; });
+ .Default(std::nullopt);
}
LogicalResult mlir::spirv::unrollVectorsInSignatures(Operation *op) {
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index ac72002..110bfdc 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -41,10 +41,6 @@
using namespace mlir;
using namespace mlir::tensor;
-using llvm::divideCeilSigned;
-using llvm::divideFloorSigned;
-using llvm::mod;
-
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *TensorDialect::materializeConstant(OpBuilder &builder,
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index bce964e..c607ece 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -579,6 +579,7 @@ static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc,
linalg::MapOp::create(rewriter, loc, tensorType, /*inputs=*/ValueRange(),
/*init=*/tensorDestination);
Block &linalgBody = linalgOp.getMapper().emplaceBlock();
+ linalgBody.addArgument(tensorType.getElementType(), loc);
// Create linalg::IndexOps.
rewriter.setInsertionPointToStart(&linalgBody);
@@ -1068,6 +1069,7 @@ struct SplatOpInterface
/*inputs=*/ValueRange(),
/*init=*/*tensorAlloc);
Block &linalgBody = linalgOp.getMapper().emplaceBlock();
+ linalgBody.addArgument(tensorType.getElementType(), loc);
// Create linalg::IndexOps.
rewriter.setInsertionPointToStart(&linalgBody);
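Both hunks in this file add the same thing: with the stricter `linalg.map` verifier above (mapper arity = number of inputs + 1), a manually built mapper block needs an explicit block argument for the init operand. A condensed sketch of that shared pattern; the helper and value names are illustrative, not from the patch:

```cpp
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical helper: create an input-less linalg.map writing into `dest`
// and give its mapper block the init argument the verifier now expects.
static linalg::MapOp createInitOnlyMap(RewriterBase &rewriter, Location loc,
                                       RankedTensorType tensorType,
                                       Value dest) {
  auto mapOp = linalg::MapOp::create(rewriter, loc, tensorType,
                                     /*inputs=*/ValueRange(), /*init=*/dest);
  Block &body = mapOp.getMapper().emplaceBlock();
  body.addArgument(tensorType.getElementType(), loc); // init block argument
  rewriter.setInsertionPointToStart(&body);
  // ... emit the payload here and terminate with linalg.yield ...
  return mapOp;
}
```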
diff --git a/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp b/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
index 69e649d..bc4f5a5 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
@@ -189,7 +189,7 @@ struct PadOpToConstant final : public OpRewritePattern<PadOp> {
return constantFoldPadOp<llvm::APInt>(
rewriter, loc, inputAttr, integerAttr, *lowPad, *highPad);
})
- .Default(Value());
+ .Default(nullptr);
if (!newOp)
return rewriter.notifyMatchFailure(padTensorOp,