aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h48
-rw-r--r--mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td4
-rw-r--r--mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td12
-rw-r--r--mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp2
-rw-r--r--mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp2
-rw-r--r--mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp2
-rw-r--r--mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp64
-rw-r--r--mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp2
-rw-r--r--mlir/lib/Dialect/Arith/IR/ArithOps.cpp2
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp7
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp6
-rw-r--r--mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp169
-rw-r--r--mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp6
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp2
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp2
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp2
-rw-r--r--mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp37
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp2
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp6
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp2
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp2
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp8
-rw-r--r--mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp2
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp2
-rw-r--r--mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp4
-rw-r--r--mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp2
-rw-r--r--mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp2
-rw-r--r--mlir/lib/Dialect/Tensor/IR/TensorOps.cpp4
-rw-r--r--mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp2
-rw-r--r--mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp2
-rw-r--r--mlir/lib/Query/Query.cpp5
-rw-r--r--mlir/lib/Support/Timing.cpp1
-rw-r--r--mlir/lib/TableGen/Type.cpp2
-rw-r--r--mlir/lib/Target/LLVMIR/DebugTranslation.cpp6
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp55
-rw-r--r--mlir/test/Conversion/SCFToGPU/parallel_loop.mlir32
-rw-r--r--mlir/test/Dialect/Linalg/canonicalize.mlir2
-rw-r--r--mlir/test/Dialect/Linalg/generalize-named-ops.mlir22
-rw-r--r--mlir/test/Dialect/Linalg/invalid.mlir10
-rw-r--r--mlir/test/Dialect/Linalg/one-shot-bufferize.mlir2
-rw-r--r--mlir/test/Dialect/Linalg/roundtrip.mlir18
-rw-r--r--mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir2
-rw-r--r--mlir/test/Dialect/Tensor/bufferize.mlir2
-rw-r--r--mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir6
-rw-r--r--mlir/test/Target/LLVMIR/omptarget-record-type-with-ptr-member-host.mlir3
-rw-r--r--mlir/test/Transforms/test-legalize-type-conversion.mlir22
-rw-r--r--mlir/test/lib/Dialect/Test/CMakeLists.txt1
-rw-r--r--mlir/test/lib/Dialect/Test/TestPatterns.cpp7
-rw-r--r--mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp2
-rw-r--r--mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp4
51 files changed, 471 insertions, 143 deletions
diff --git a/mlir/include/mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h b/mlir/include/mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h
new file mode 100644
index 0000000..a32d9e2
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h
@@ -0,0 +1,48 @@
+//===- StructuralTypeConversions.h - CF Type Conversions --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_CONTROL_FLOW_TRANSFORMS_STRUCTURAL_TYPE_CONVERSIONS_H
+#define MLIR_DIALECT_CONTROL_FLOW_TRANSFORMS_STRUCTURAL_TYPE_CONVERSIONS_H
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+
+class ConversionTarget;
+class TypeConverter;
+
+namespace cf {
+
+/// Populates patterns for CF structural type conversions and sets up the
+/// provided ConversionTarget with the appropriate legality configuration for
+/// the ops to get converted properly.
+///
+/// A "structural" type conversion is one where the underlying ops are
+/// completely agnostic to the actual types involved and simply need to update
+/// their types. An example of this is cf.br -- the cf.br op needs to update
+/// its types accordingly to the TypeConverter, but otherwise does not care
+/// what type conversions are happening.
+void populateCFStructuralTypeConversionsAndLegality(
+ const TypeConverter &typeConverter, RewritePatternSet &patterns,
+ ConversionTarget &target, PatternBenefit benefit = 1);
+
+/// Similar to `populateCFStructuralTypeConversionsAndLegality` but does not
+/// populate the conversion target.
+void populateCFStructuralTypeConversions(const TypeConverter &typeConverter,
+ RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
+/// Updates the ConversionTarget with dynamic legality of CF operations based
+/// on the provided type converter.
+void populateCFStructuralTypeConversionTarget(
+ const TypeConverter &typeConverter, ConversionTarget &target);
+
+} // namespace cf
+} // namespace mlir
+
+#endif // MLIR_DIALECT_CONTROL_FLOW_TRANSFORMS_STRUCTURAL_TYPE_CONVERSIONS_H
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index f3674c3..ecd036d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -293,10 +293,6 @@ def MapOp : LinalgStructuredBase_Op<"map", [
// Implement functions necessary for DestinationStyleOpInterface.
MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
- SmallVector<OpOperand *> getOpOperandsMatchingBBargs() {
- return getDpsInputOperands();
- }
-
bool payloadUsesValueFromOperand(OpOperand * opOperand) {
if (isDpsInit(opOperand)) return false;
return !getMatchingBlockArgument(opOperand).use_empty();
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td
index 93e9e3d..d1bbc7f 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td
@@ -261,6 +261,18 @@ def OpenACC_MappableTypeInterface : TypeInterface<"MappableType"> {
>,
InterfaceMethod<
/*description=*/[{
+ Returns true if the dimensions of this type are not known. This occurs
+ when the MLIR type does not encode dimensional information and there is
+ no associated descriptor or metadata in the current entity that would
+ make this information extractable. For example, an opaque pointer type
+ pointing to an array without dimension information would have unknown
+ dimensions.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"hasUnknownDimensions"
+ >,
+ InterfaceMethod<
+ /*description=*/[{
Returns explicit `acc.bounds` operations that envelop the whole
data structure. These operations are inserted using the provided builder
at the location set before calling this API.
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 41e333c..3a307a0 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -935,7 +935,7 @@ static std::optional<uint32_t> mfmaTypeSelectCode(Type mlirElemType) {
.Case([](Float6E2M3FNType) { return 2u; })
.Case([](Float6E3M2FNType) { return 3u; })
.Case([](Float4E2M1FNType) { return 4u; })
- .Default([](Type) { return std::nullopt; });
+ .Default(std::nullopt);
}
/// If there is a scaled MFMA instruction for the input element types `aType`
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 247dba1..cfdcd9c 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -432,7 +432,7 @@ static Value getOriginalVectorValue(Value value) {
current = op.getSource();
return false;
})
- .Default([](Operation *) { return false; });
+ .Default(false);
if (!skipOp) {
break;
diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
index 25f1e1b..425594b 100644
--- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
+++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
@@ -259,7 +259,7 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
}
return std::nullopt;
})
- .Default([](auto) { return std::nullopt; });
+ .Default(std::nullopt);
}
static std::optional<std::string> getFuncName(gpu::ShuffleMode mode,
diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
index 7d0a236..76a822b 100644
--- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
+++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
@@ -14,6 +14,7 @@
#include "mlir/Conversion/SCFToGPU/SCFToGPU.h"
+#include "mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -27,6 +28,7 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/DebugLog.h"
#include <optional>
@@ -625,18 +627,49 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
bool seenSideeffects = false;
// Whether we have left a nesting scope (and hence are no longer innermost).
bool leftNestingScope = false;
+ LocalAliasAnalysis aliasAnalysis;
+ llvm::DenseSet<Value> writtenBuffer;
while (!worklist.empty()) {
Operation *op = worklist.pop_back_val();
// Now walk over the body and clone it.
// TODO: This is only correct if there either is no further scf.parallel
- // nested or this code is side-effect free. Otherwise we might need
- // predication. We are overly conservative for now and only allow
- // side-effects in the innermost scope.
+ // nested or this code has side-effect but the memory buffer is not
+ // alias to inner loop access buffer. Otherwise we might need
+ // predication.
if (auto nestedParallel = dyn_cast<ParallelOp>(op)) {
// Before entering a nested scope, make sure there have been no
- // sideeffects until now.
- if (seenSideeffects)
- return failure();
+ // sideeffects until now or the nested operations do not access the
+ // buffer written by outer scope.
+ if (seenSideeffects) {
+ WalkResult walkRes = nestedParallel.walk([&](Operation *nestedOp) {
+ if (isMemoryEffectFree(nestedOp))
+ return WalkResult::advance();
+
+ auto memEffectInterface = dyn_cast<MemoryEffectOpInterface>(nestedOp);
+ if (!memEffectInterface)
+ return WalkResult::advance();
+
+ SmallVector<MemoryEffects::EffectInstance> effects;
+ memEffectInterface.getEffects(effects);
+ for (const MemoryEffects::EffectInstance &effect : effects) {
+ if (isa<MemoryEffects::Read>(effect.getEffect()) ||
+ isa<MemoryEffects::Write>(effect.getEffect())) {
+ Value baseBuffer = effect.getValue();
+ if (!baseBuffer)
+ return WalkResult::interrupt();
+ for (Value val : writtenBuffer) {
+ if (aliasAnalysis.alias(baseBuffer, val) !=
+ AliasResult::NoAlias) {
+ return WalkResult::interrupt();
+ }
+ }
+ }
+ }
+ return WalkResult::advance();
+ });
+ if (walkRes.wasInterrupted())
+ return failure();
+ }
// A nested scf.parallel needs insertion of code to compute indices.
// Insert that now. This will also update the worklist with the loops
// body.
@@ -650,6 +683,7 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
rewriter.setInsertionPointAfter(parent);
leftNestingScope = true;
seenSideeffects = false;
+ writtenBuffer.clear();
} else if (auto reduceOp = dyn_cast<scf::ReduceOp>(op)) {
// Convert scf.reduction op
auto parentLoop = op->getParentOfType<ParallelOp>();
@@ -682,6 +716,24 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
Operation *clone = rewriter.clone(*op, cloningMap);
cloningMap.map(op->getResults(), clone->getResults());
// Check for side effects.
+ if (!isMemoryEffectFree(clone)) {
+ // Record the buffer accessed by the operations with write effects.
+ if (auto memEffectInterface =
+ dyn_cast<MemoryEffectOpInterface>(clone)) {
+ SmallVector<MemoryEffects::EffectInstance> effects;
+ memEffectInterface.getEffects(effects);
+ for (const MemoryEffects::EffectInstance &effect : effects) {
+ if (isa<MemoryEffects::Write>(effect.getEffect())) {
+ Value writtenBase = effect.getValue();
+ // Conservatively return failure if we cannot find the written
+ // address.
+ if (!writtenBase)
+ return failure();
+ writtenBuffer.insert(writtenBase);
+ }
+ }
+ }
+ }
// TODO: Handle region side effects properly.
seenSideeffects |=
!isMemoryEffectFree(clone) || clone->getNumRegions() != 0;
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index e2c7d80..91c1aa5 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -46,7 +46,7 @@ static bool isZeroConstant(Value val) {
[](auto floatAttr) { return floatAttr.getValue().isZero(); })
.Case<IntegerAttr>(
[](auto intAttr) { return intAttr.getValue().isZero(); })
- .Default([](auto) { return false; });
+ .Default(false);
}
static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 898d76c..980442e 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2751,7 +2751,7 @@ std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
.Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
.Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
.Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
- .Default([](Operation *op) { return std::nullopt; });
+ .Default(std::nullopt);
if (!maybeKind) {
return std::nullopt;
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index d9d6934..8655ed3 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -95,12 +95,7 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
/// Return the FuncOp called by `callOp`.
static FuncOp getCalledFunction(CallOpInterface callOp,
SymbolTableCollection &symbolTables) {
- SymbolRefAttr sym =
- llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
- if (!sym)
- return nullptr;
- return dyn_cast_or_null<FuncOp>(
- symbolTables.lookupNearestSymbolFrom(callOp, sym));
+ return dyn_cast_or_null<FuncOp>(callOp.resolveCallableInTable(&symbolTables));
}
/// Return the FuncOp called by `callOp`.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index aa53f94..c233e24 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -285,12 +285,8 @@ static void removeBufferizationAttributes(BlockArgument bbArg) {
static func::FuncOp
getCalledFunction(func::CallOp callOp,
mlir::SymbolTableCollection &symbolTable) {
- SymbolRefAttr sym =
- llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
- if (!sym)
- return nullptr;
return dyn_cast_or_null<func::FuncOp>(
- symbolTable.lookupNearestSymbolFrom(callOp, sym));
+ callOp.resolveCallableInTable(&symbolTable));
}
/// Return "true" if the given function signature has tensor semantics.
diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
index 47740d3..e9da135 100644
--- a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRControlFlowTransforms
BufferDeallocationOpInterfaceImpl.cpp
BufferizableOpInterfaceImpl.cpp
+ StructuralTypeConversions.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ControlFlow/Transforms
diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp
new file mode 100644
index 0000000..5e2a742
--- /dev/null
+++ b/mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp
@@ -0,0 +1,169 @@
+//===- TypeConversion.cpp - Type Conversion of Unstructured Control Flow --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert MLIR standard and builtin dialects
+// into the LLVM IR dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h"
+
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Helper function for converting branch ops. This function converts the
+/// signature of the given block. If the new block signature is different from
+/// `expectedTypes`, returns "failure".
+static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter,
+ const TypeConverter *converter,
+ Operation *branchOp, Block *block,
+ TypeRange expectedTypes) {
+ assert(converter && "expected non-null type converter");
+ assert(!block->isEntryBlock() && "entry blocks have no predecessors");
+
+ // There is nothing to do if the types already match.
+ if (block->getArgumentTypes() == expectedTypes)
+ return block;
+
+ // Compute the new block argument types and convert the block.
+ std::optional<TypeConverter::SignatureConversion> conversion =
+ converter->convertBlockSignature(block);
+ if (!conversion)
+ return rewriter.notifyMatchFailure(branchOp,
+ "could not compute block signature");
+ if (expectedTypes != conversion->getConvertedTypes())
+ return rewriter.notifyMatchFailure(
+ branchOp,
+ "mismatch between adaptor operand types and computed block signature");
+ return rewriter.applySignatureConversion(block, *conversion, converter);
+}
+
+/// Flatten the given value ranges into a single vector of values.
+static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
+ SmallVector<Value> result;
+ for (const ValueRange &vals : values)
+ llvm::append_range(result, vals);
+ return result;
+}
+
+/// Convert the destination block signature (if necessary) and change the
+/// operands of the branch op.
+struct BranchOpConversion : public OpConversionPattern<cf::BranchOp> {
+ using OpConversionPattern<cf::BranchOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(cf::BranchOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ SmallVector<Value> flattenedAdaptor = flattenValues(adaptor.getOperands());
+ FailureOr<Block *> convertedBlock =
+ getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
+ TypeRange(ValueRange(flattenedAdaptor)));
+ if (failed(convertedBlock))
+ return failure();
+ rewriter.replaceOpWithNewOp<cf::BranchOp>(op, flattenedAdaptor,
+ *convertedBlock);
+ return success();
+ }
+};
+
+/// Convert the destination block signatures (if necessary) and change the
+/// operands of the branch op.
+struct CondBranchOpConversion : public OpConversionPattern<cf::CondBranchOp> {
+ using OpConversionPattern<cf::CondBranchOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(cf::CondBranchOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ SmallVector<Value> flattenedAdaptorTrue =
+ flattenValues(adaptor.getTrueDestOperands());
+ SmallVector<Value> flattenedAdaptorFalse =
+ flattenValues(adaptor.getFalseDestOperands());
+ if (!llvm::hasSingleElement(adaptor.getCondition()))
+ return rewriter.notifyMatchFailure(op,
+ "expected single element condition");
+ FailureOr<Block *> convertedTrueBlock =
+ getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(),
+ TypeRange(ValueRange(flattenedAdaptorTrue)));
+ if (failed(convertedTrueBlock))
+ return failure();
+ FailureOr<Block *> convertedFalseBlock =
+ getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(),
+ TypeRange(ValueRange(flattenedAdaptorFalse)));
+ if (failed(convertedFalseBlock))
+ return failure();
+ rewriter.replaceOpWithNewOp<cf::CondBranchOp>(
+ op, llvm::getSingleElement(adaptor.getCondition()),
+ flattenedAdaptorTrue, flattenedAdaptorFalse, op.getBranchWeightsAttr(),
+ *convertedTrueBlock, *convertedFalseBlock);
+ return success();
+ }
+};
+
+/// Convert the destination block signatures (if necessary) and change the
+/// operands of the switch op.
+struct SwitchOpConversion : public OpConversionPattern<cf::SwitchOp> {
+ using OpConversionPattern<cf::SwitchOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(cf::SwitchOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Get or convert default block.
+ FailureOr<Block *> convertedDefaultBlock = getConvertedBlock(
+ rewriter, getTypeConverter(), op, op.getDefaultDestination(),
+ TypeRange(adaptor.getDefaultOperands()));
+ if (failed(convertedDefaultBlock))
+ return failure();
+
+ // Get or convert all case blocks.
+ SmallVector<Block *> caseDestinations;
+ SmallVector<ValueRange> caseOperands = adaptor.getCaseOperands();
+ for (auto it : llvm::enumerate(op.getCaseDestinations())) {
+ Block *b = it.value();
+ FailureOr<Block *> convertedBlock =
+ getConvertedBlock(rewriter, getTypeConverter(), op, b,
+ TypeRange(caseOperands[it.index()]));
+ if (failed(convertedBlock))
+ return failure();
+ caseDestinations.push_back(*convertedBlock);
+ }
+
+ rewriter.replaceOpWithNewOp<cf::SwitchOp>(
+ op, adaptor.getFlag(), *convertedDefaultBlock,
+ adaptor.getDefaultOperands(), adaptor.getCaseValuesAttr(),
+ caseDestinations, caseOperands);
+ return success();
+ }
+};
+
+} // namespace
+
+void mlir::cf::populateCFStructuralTypeConversions(
+ const TypeConverter &typeConverter, RewritePatternSet &patterns,
+ PatternBenefit benefit) {
+ patterns.add<BranchOpConversion, CondBranchOpConversion, SwitchOpConversion>(
+ typeConverter, patterns.getContext(), benefit);
+}
+
+void mlir::cf::populateCFStructuralTypeConversionTarget(
+ const TypeConverter &typeConverter, ConversionTarget &target) {
+ target.addDynamicallyLegalOp<cf::BranchOp, cf::CondBranchOp, cf::SwitchOp>(
+ [&](Operation *op) { return typeConverter.isLegal(op->getOperands()); });
+}
+
+void mlir::cf::populateCFStructuralTypeConversionsAndLegality(
+ const TypeConverter &typeConverter, RewritePatternSet &patterns,
+ ConversionTarget &target, PatternBenefit benefit) {
+ populateCFStructuralTypeConversions(typeConverter, patterns, benefit);
+ populateCFStructuralTypeConversionTarget(typeConverter, target);
+}
diff --git a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp
index d2c2138..025d1ac 100644
--- a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp
@@ -330,7 +330,7 @@ static Value getBase(Value v) {
v = op.getSrc();
return true;
})
- .Default([](Operation *) { return false; });
+ .Default(false);
if (!shouldContinue)
break;
}
@@ -354,7 +354,7 @@ static Value propagatesCapture(Operation *op) {
.Case([](memref::TransposeOp transpose) { return transpose.getIn(); })
.Case<memref::ExpandShapeOp, memref::CollapseShapeOp>(
[](auto op) { return op.getSrc(); })
- .Default([](Operation *) { return Value(); });
+ .Default(nullptr);
}
/// Returns `true` if the given operation is known to capture the given value,
@@ -371,7 +371,7 @@ static std::optional<bool> getKnownCapturingStatus(Operation *op, Value v) {
// These operations are known not to capture.
.Case([](memref::DeallocOp) { return false; })
// By default, we don't know anything.
- .Default([](Operation *) { return std::nullopt; });
+ .Default(std::nullopt);
}
/// Returns `true` if the value may be captured by any of its users, i.e., if
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 3eae67f..2731069 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -698,7 +698,7 @@ static void destructureIndices(Type currType, ArrayRef<GEPArg> indices,
return structType.getBody()[memberIndex];
return nullptr;
})
- .Default(Type(nullptr));
+ .Default(nullptr);
}
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index cee943d..7d9058c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -1111,7 +1111,7 @@ memsetCanUsesBeRemoved(MemsetIntr op, const MemorySlot &slot,
.Case<IntegerType, FloatType>([](auto type) {
return type.getWidth() % 8 == 0 && type.getWidth() > 0;
})
- .Default([](Type) { return false; });
+ .Default(false);
if (!canConvertType)
return false;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index ac35eea..ce93d18 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -798,7 +798,7 @@ static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
// clang-format on
.Case<PtrLikeTypeInterface>(
[](Type type) { return isCompatiblePtrType(type); })
- .Default([](Type) { return false; });
+ .Default(false);
if (!result)
compatibleTypes.erase(type);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index cbc565b..3dc45ed 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1474,6 +1474,8 @@ void MapOp::getAsmBlockArgumentNames(Region &region,
OpAsmSetValueNameFn setNameFn) {
for (Value v : getRegionInputArgs())
setNameFn(v, "in");
+ for (Value v : getRegionOutputArgs())
+ setNameFn(v, "init");
}
void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
@@ -1495,14 +1497,14 @@ void MapOp::build(
if (bodyBuild)
buildGenericRegion(builder, result.location, *result.regions.front(),
- inputs, /*outputs=*/{}, bodyBuild);
+ inputs, /*outputs=*/{init}, bodyBuild);
}
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
const OperationName &payloadOpName,
const NamedAttrList &payloadOpAttrs,
ArrayRef<Value> operands,
- bool initFirst = false) {
+ bool initFirst = false, bool mapInit = true) {
OpBuilder b(parser.getContext());
Region *body = result.addRegion();
Block &block = body->emplaceBlock();
@@ -1516,12 +1518,13 @@ static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
// If initFirst flag is enabled, we consider init as the first position of
// payload operands.
if (initFirst) {
- payloadOpOperands.push_back(block.getArguments().back());
+ if (mapInit)
+ payloadOpOperands.push_back(block.getArguments().back());
for (const auto &arg : block.getArguments().drop_back())
payloadOpOperands.push_back(arg);
} else {
payloadOpOperands = {block.getArguments().begin(),
- block.getArguments().end()};
+ block.getArguments().end() - int(!mapInit)};
}
Operation *payloadOp = b.create(
@@ -1553,8 +1556,8 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
if (payloadOpName.has_value()) {
if (!result.operands.empty())
addBodyWithPayloadOp(parser, result, payloadOpName.value(),
- payloadOpAttrs,
- ArrayRef(result.operands).drop_back());
+ payloadOpAttrs, ArrayRef(result.operands), false,
+ false);
else
result.addRegion();
} else {
@@ -1570,7 +1573,11 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
return success();
}
-static bool canUseShortForm(Block *body, bool initFirst = false) {
+static bool canUseShortForm(Block *body, bool initFirst = false,
+ bool mapInit = true) {
+ // `intFirst == true` implies that we want to map init arg
+ if (initFirst && !mapInit)
+ return false;
// Check if the body can be printed in short form. The following 4 conditions
// must be satisfied:
@@ -1582,7 +1589,7 @@ static bool canUseShortForm(Block *body, bool initFirst = false) {
// 2) The payload op must have the same number of operands as the number of
// block arguments.
if (payload.getNumOperands() == 0 ||
- payload.getNumOperands() != body->getNumArguments())
+ payload.getNumOperands() != body->getNumArguments() - int(!mapInit))
return false;
// 3) If `initFirst` is true (e.g., for reduction ops), the init block
@@ -1600,7 +1607,8 @@ static bool canUseShortForm(Block *body, bool initFirst = false) {
}
} else {
for (const auto &[operand, bbArg] :
- llvm::zip(payload.getOperands(), body->getArguments())) {
+ llvm::zip(payload.getOperands(),
+ body->getArguments().drop_back(int(!mapInit)))) {
if (bbArg != operand)
return false;
}
@@ -1632,7 +1640,8 @@ static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
void MapOp::print(OpAsmPrinter &p) {
Block *mapper = getBody();
- bool useShortForm = canUseShortForm(mapper);
+ bool useShortForm =
+ canUseShortForm(mapper, /*initFirst=*/false, /*mapInit*/ false);
if (useShortForm) {
printShortForm(p, &mapper->getOperations().front());
}
@@ -1658,11 +1667,13 @@ LogicalResult MapOp::verify() {
auto *bodyBlock = getBody();
auto blockArgs = bodyBlock->getArguments();
- // Checks if the number of `inputs` match the arity of the `mapper` region.
- if (getInputs().size() != blockArgs.size())
+ // Checks if the number of `inputs` + `init` match the arity of the `mapper`
+ // region.
+ if (getInputs().size() + 1 != blockArgs.size())
return emitOpError() << "expects number of operands to match the arity of "
"mapper, but got: "
- << getInputs().size() << " and " << blockArgs.size();
+ << getInputs().size() + 1 << " and "
+ << blockArgs.size();
// The parameters of mapper should all match the element type of inputs.
for (const auto &[bbArgType, inputArg] :
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8b89244..b09112b 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -4499,7 +4499,7 @@ DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(
maybeTransformed = decomposeWinogradOutputTransformOp(rewriter, op);
return true;
})
- .Default([&](Operation *op) { return false; });
+ .Default(false);
if (!supported) {
DiagnosedSilenceableFailure diag =
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
index 3e31393..75bb175 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -31,10 +31,8 @@ using namespace mlir;
using namespace mlir::linalg;
static LogicalResult generalizeNamedOpPrecondition(LinalgOp linalgOp) {
- // Bailout if `linalgOp` is already a generic or a linalg.map. We cannot
- // trivially generalize a `linalg.map`, as it does not use the output as
- // region arguments in the block.
- if (isa<GenericOp>(linalgOp) || isa<MapOp>(linalgOp))
+ // Bailout if `linalgOp` is already a generic.
+ if (isa<GenericOp>(linalgOp))
return failure();
// Check if the operation has exactly one region.
if (linalgOp->getNumRegions() != 1) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index f05ffa8..6519c4f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -322,7 +322,7 @@ promoteSubViews(ImplicitLocOpBuilder &b,
tmp = arith::ConstantOp::create(b, IntegerAttr::get(et, 0));
return complex::CreateOp::create(b, t, tmp, tmp);
})
- .Default([](auto) { return Value(); });
+ .Default(nullptr);
if (!fillVal)
return failure();
linalg::FillOp::create(b, fillVal, promotionInfo->fullLocalView);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp b/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp
index 27ccf3c..6becc1f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp
@@ -89,7 +89,7 @@ matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel,
ValueRange{input, collapsedKernel, iZp, kZp},
ValueRange{collapsedInit}, stride, dilation);
})
- .Default([](Operation *op) { return nullptr; });
+ .Default(nullptr);
if (!newConv)
return failure();
for (auto attr : preservedAttrs)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 9d62491..cb6199f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -656,7 +656,7 @@ mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
[&](auto op) { return CombiningKind::MUL; })
.Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
.Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
- .Default([&](auto op) { return std::nullopt; });
+ .Default(std::nullopt);
}
/// Check whether `outputOperand` is a reduction with a single combiner
@@ -3911,21 +3911,21 @@ struct Conv1DGenerator
Value lhs = vector::TransferReadOp::create(
rewriter, loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},
/*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
- auto maybeMaskedLhs = maybeMaskXferOp(
+ auto *maybeMaskedLhs = maybeMaskXferOp(
lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
// Read rhs slice of size {kw, c} @ [0, 0].
Value rhs = vector::TransferReadOp::create(
rewriter, loc, rhsType, rhsShaped, ValueRange{zero, zero},
/*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
- auto maybeMaskedRhs = maybeMaskXferOp(
+ auto *maybeMaskedRhs = maybeMaskXferOp(
rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
// Read res slice of size {n, w, c} @ [0, 0, 0].
Value res = vector::TransferReadOp::create(
rewriter, loc, resType, resShaped, ValueRange{zero, zero, zero},
/*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
- auto maybeMaskedRes = maybeMaskXferOp(
+ auto *maybeMaskedRes = maybeMaskXferOp(
resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
//===------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index 1208fdd..e685089 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -104,7 +104,7 @@ static Value getTargetMemref(Operation *op) {
vector::MaskedStoreOp, vector::TransferReadOp,
vector::TransferWriteOp>(
[](auto op) { return op.getBase(); })
- .Default([](auto) { return Value{}; });
+ .Default(nullptr);
}
template <typename T>
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
index 4ebd90d..d380c46 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
@@ -55,7 +55,7 @@ static bool isShapePreserving(ForOp forOp, int64_t arg) {
? forOp.getInitArgs()[opResult.getResultNumber()]
: Value();
})
- .Default([&](auto op) { return Value(); });
+ .Default(nullptr);
}
return false;
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 0c8114d..938952e 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -346,7 +346,7 @@ LogicalResult spirv::CompositeConstructOp::verify() {
llvm::TypeSwitch<Type, Type>(getType())
.Case<spirv::CooperativeMatrixType>(
[](auto coopType) { return coopType.getElementType(); })
- .Default([](Type) { return nullptr; });
+ .Default(nullptr);
// Case 1. -- matrices.
if (coopElementType) {
@@ -1708,7 +1708,7 @@ LogicalResult spirv::MatrixTimesScalarOp::verify() {
llvm::TypeSwitch<Type, Type>(getMatrix().getType())
.Case<spirv::CooperativeMatrixType, spirv::MatrixType>(
[](auto matrixType) { return matrixType.getElementType(); })
- .Default([](Type) { return nullptr; });
+ .Default(nullptr);
assert(elementType && "Unhandled type");
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index f895807..d1e275d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -731,7 +731,7 @@ std::optional<int64_t> SPIRVType::getSizeInBytes() {
return *elementSize * type.getNumElements();
return std::nullopt;
})
- .Default(std::optional<int64_t>());
+ .Default(std::nullopt);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 88e1ab6..cb9b7f6 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1467,7 +1467,7 @@ mlir::spirv::getNativeVectorShape(Operation *op) {
return TypeSwitch<Operation *, std::optional<SmallVector<int64_t>>>(op)
.Case<vector::ReductionOp, vector::TransposeOp>(
[](auto typedOp) { return getNativeVectorShapeImpl(typedOp); })
- .Default([](Operation *) { return std::nullopt; });
+ .Default(std::nullopt);
}
LogicalResult mlir::spirv::unrollVectorsInSignatures(Operation *op) {
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index ac72002..110bfdc 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -41,10 +41,6 @@
using namespace mlir;
using namespace mlir::tensor;
-using llvm::divideCeilSigned;
-using llvm::divideFloorSigned;
-using llvm::mod;
-
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *TensorDialect::materializeConstant(OpBuilder &builder,
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index bce964e..c607ece 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -579,6 +579,7 @@ static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc,
linalg::MapOp::create(rewriter, loc, tensorType, /*inputs=*/ValueRange(),
/*init=*/tensorDestination);
Block &linalgBody = linalgOp.getMapper().emplaceBlock();
+ linalgBody.addArgument(tensorType.getElementType(), loc);
// Create linalg::IndexOps.
rewriter.setInsertionPointToStart(&linalgBody);
@@ -1068,6 +1069,7 @@ struct SplatOpInterface
/*inputs=*/ValueRange(),
/*init=*/*tensorAlloc);
Block &linalgBody = linalgOp.getMapper().emplaceBlock();
+ linalgBody.addArgument(tensorType.getElementType(), loc);
// Create linalg::IndexOps.
rewriter.setInsertionPointToStart(&linalgBody);
diff --git a/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp b/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
index 69e649d..bc4f5a5 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
@@ -189,7 +189,7 @@ struct PadOpToConstant final : public OpRewritePattern<PadOp> {
return constantFoldPadOp<llvm::APInt>(
rewriter, loc, inputAttr, integerAttr, *lowPad, *highPad);
})
- .Default(Value());
+ .Default(nullptr);
if (!newOp)
return rewriter.notifyMatchFailure(padTensorOp,
diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp
index 375e820..cf8a4d2 100644
--- a/mlir/lib/Query/Query.cpp
+++ b/mlir/lib/Query/Query.cpp
@@ -121,12 +121,13 @@ LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
Operation *rootOp = qs.getRootOp();
int matchCount = 0;
matcher::MatchFinder finder;
+
+ StringRef functionName = matcher.getFunctionName();
auto matches = finder.collectMatches(rootOp, std::move(matcher));
// An extract call is recognized by considering if the matcher has a name.
// TODO: Consider making the extract more explicit.
- if (matcher.hasFunctionName()) {
- auto functionName = matcher.getFunctionName();
+ if (!functionName.empty()) {
std::vector<Operation *> flattenedMatches =
finder.flattenMatchedOps(matches);
Operation *function =
diff --git a/mlir/lib/Support/Timing.cpp b/mlir/lib/Support/Timing.cpp
index fb6f82c..16306d7 100644
--- a/mlir/lib/Support/Timing.cpp
+++ b/mlir/lib/Support/Timing.cpp
@@ -319,7 +319,6 @@ public:
void mergeChildren(AsyncChildrenMap &&other) {
for (auto &thread : other) {
mergeChildren(std::move(thread.second));
- assert(thread.second.empty());
}
other.clear();
}
diff --git a/mlir/lib/TableGen/Type.cpp b/mlir/lib/TableGen/Type.cpp
index b31377e..0f1bf83 100644
--- a/mlir/lib/TableGen/Type.cpp
+++ b/mlir/lib/TableGen/Type.cpp
@@ -56,7 +56,7 @@ std::optional<StringRef> TypeConstraint::getBuilderCall() const {
StringRef value = init->getValue();
return value.empty() ? std::optional<StringRef>() : value;
})
- .Default([](auto *) { return std::nullopt; });
+ .Default(std::nullopt);
}
// Return the C++ type for this type (which may just be ::mlir::Type).
diff --git a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
index eeb8725..e3bcf27 100644
--- a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
@@ -390,7 +390,7 @@ llvm::DISubrange *DebugTranslation::translateImpl(DISubrangeAttr attr) {
.Case<>([&](LLVM::DIGlobalVariableAttr global) {
return translate(global);
})
- .Default([&](Attribute attr) { return nullptr; });
+ .Default(nullptr);
return metadata;
};
return llvm::DISubrange::get(llvmCtx, getMetadataOrNull(attr.getCount()),
@@ -420,10 +420,10 @@ DebugTranslation::translateImpl(DIGenericSubrangeAttr attr) {
.Case([&](LLVM::DILocalVariableAttr local) {
return translate(local);
})
- .Case<>([&](LLVM::DIGlobalVariableAttr global) {
+ .Case([&](LLVM::DIGlobalVariableAttr global) {
return translate(global);
})
- .Default([&](Attribute attr) { return nullptr; });
+ .Default(nullptr);
return metadata;
};
return llvm::DIGenericSubrange::get(llvmCtx,
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index f284540..8edec99 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -4084,12 +4084,13 @@ static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
///
/// Fortran
/// map(tofrom: array(2:5, 3:2))
-/// or
-/// C++
-/// map(tofrom: array[1:4][2:3])
+///
/// We must calculate the initial pointer offset to pass across, this function
/// performs this using bounds.
///
+/// TODO/WARNING: This only supports Fortran's column major indexing currently
+/// as is noted in the note below and comments in the function, we must extend
+/// this function when we add a C++ frontend.
/// NOTE: which while specified in row-major order it currently needs to be
/// flipped for Fortran's column order array allocation and access (as
/// opposed to C++'s row-major, hence the backwards processing where order is
@@ -4125,46 +4126,28 @@ calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
// with a pointer that's being treated like an array and we have the
// underlying type e.g. an i32, or f64 etc, e.g. a fortran descriptor base
// address (pointer pointing to the actual data) so we must caclulate the
- // offset using a single index which the following two loops attempts to
- // compute.
-
- // Calculates the size offset we need to make per row e.g. first row or
- // column only needs to be offset by one, but the next would have to be
- // the previous row/column offset multiplied by the extent of current row.
+ // offset using a single index which the following loop attempts to
+ // compute using the standard column-major algorithm e.g for a 3D array:
//
- // For example ([1][10][100]):
+ // ((((c_idx * b_len) + b_idx) * a_len) + a_idx)
//
- // - First row/column we move by 1 for each index increment
- // - Second row/column we move by 1 (first row/column) * 10 (extent/size of
- // current) for 10 for each index increment
- // - Third row/column we would move by 10 (second row/column) *
- // (extent/size of current) 100 for 1000 for each index increment
- std::vector<llvm::Value *> dimensionIndexSizeOffset{builder.getInt64(1)};
- for (size_t i = 1; i < bounds.size(); ++i) {
- if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
- bounds[i].getDefiningOp())) {
- dimensionIndexSizeOffset.push_back(builder.CreateMul(
- moduleTranslation.lookupValue(boundOp.getExtent()),
- dimensionIndexSizeOffset[i - 1]));
- }
- }
-
- // Now that we have calculated how much we move by per index, we must
- // multiply each lower bound offset in indexes by the size offset we
- // have calculated in the previous and accumulate the results to get
- // our final resulting offset.
+ // It is of note that it's doing column-major rather than row-major at the
+ // moment, but having a way for the frontend to indicate which major format
+ // to use or standardizing/canonicalizing the order of the bounds to compute
+ // the offset may be useful in the future when there's other frontends with
+ // different formats.
+ std::vector<llvm::Value *> dimensionIndexSizeOffset;
for (int i = bounds.size() - 1; i >= 0; --i) {
if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
bounds[i].getDefiningOp())) {
- if (idx.empty())
- idx.emplace_back(builder.CreateMul(
- moduleTranslation.lookupValue(boundOp.getLowerBound()),
- dimensionIndexSizeOffset[i]));
+ if (i == ((int)bounds.size() - 1))
+ idx.emplace_back(
+ moduleTranslation.lookupValue(boundOp.getLowerBound()));
else
idx.back() = builder.CreateAdd(
- idx.back(), builder.CreateMul(moduleTranslation.lookupValue(
- boundOp.getLowerBound()),
- dimensionIndexSizeOffset[i]));
+ builder.CreateMul(idx.back(), moduleTranslation.lookupValue(
+ boundOp.getExtent())),
+ moduleTranslation.lookupValue(boundOp.getLowerBound()));
}
}
}
diff --git a/mlir/test/Conversion/SCFToGPU/parallel_loop.mlir b/mlir/test/Conversion/SCFToGPU/parallel_loop.mlir
index 1dbce05..26f5a3e 100644
--- a/mlir/test/Conversion/SCFToGPU/parallel_loop.mlir
+++ b/mlir/test/Conversion/SCFToGPU/parallel_loop.mlir
@@ -641,3 +641,35 @@ func.func @parallel_reduction_1d_outside() {
// CHECK: scf.parallel
// CHECK-NEXT: scf.parallel
// CHECK: scf.reduce
+
+// -----
+
+// CHECK-LABEL: @nested_parallel_with_side_effect
+func.func @nested_parallel_with_side_effect() {
+ %c65536 = arith.constant 65536 : index
+ %c2 = arith.constant 2 : index
+ %c256 = arith.constant 256 : index
+ %c0 = arith.constant 0 : index
+ %c4 = arith.constant 4 : index
+ %c1 = arith.constant 1 : index
+ %alloc_0 = memref.alloc() : memref<2x256x256xf32>
+ %alloc_1 = memref.alloc() : memref<2x4x256x256xf32>
+ %alloc_2 = memref.alloc() : memref<4x4xf32>
+ %alloc_3 = memref.alloc() : memref<4x4xf32>
+ scf.parallel (%arg2, %arg3, %arg4) = (%c0, %c0, %c0) to (%c2, %c4, %c65536) step (%c1, %c1, %c1) {
+ %1 = arith.remsi %arg4, %c256 : index
+ %2 = arith.divsi %arg4, %c256 : index
+ %4 = memref.load %alloc_0[%arg2, %2, %1] : memref<2x256x256xf32>
+ memref.store %4, %alloc_1[%arg2, %arg3, %2, %1] : memref<2x4x256x256xf32>
+ scf.parallel (%arg5) = (%c0) to (%c4) step (%c1) {
+ %5 = memref.load %alloc_2[%arg5, %c0] : memref<4x4xf32>
+ memref.store %5, %alloc_3[%arg5, %c0] : memref<4x4xf32>
+ scf.reduce
+ } {mapping = [#gpu.loop_dim_map<processor = thread_x, map = (d0) -> (d0), bound = (d0) -> (d0)>]}
+ scf.reduce
+ } {mapping = [#gpu.loop_dim_map<processor = block_z, map = (d0) -> (d0), bound = (d0) -> (d0)>, #gpu.loop_dim_map<processor = block_y, map = (d0) -> (d0), bound = (d0) -> (d0)>, #gpu.loop_dim_map<processor = block_x, map = (d0) -> (d0), bound = (d0) -> (d0)>]}
+ return
+}
+
+// CHECK: gpu.launch
+// CHECK-NOT: scf.parallel
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 26d2d98..f4020ede 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1423,7 +1423,7 @@ func.func @transpose_buffer(%input: memref<?xf32>,
func.func @recursive_effect(%arg : tensor<1xf32>) {
%init = arith.constant dense<0.0> : tensor<1xf32>
%mapped = linalg.map ins(%arg:tensor<1xf32>) outs(%init :tensor<1xf32>)
- (%in : f32) {
+ (%in : f32, %out: f32) {
vector.print %in : f32
linalg.yield %in : f32
}
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index ae07b1b..dcdd6c8 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -386,18 +386,24 @@ func.func @generalize_batch_reduce_gemm_bf16(%lhs: memref<7x8x9xbf16>, %rhs: mem
// -----
-// CHECK-LABEL: generalize_linalg_map
-func.func @generalize_linalg_map(%arg0: memref<1x8x8x8xf32>) {
+func.func @generalize_linalg_map(%arg0: memref<1x8x8x8xf32>, %arg1: memref<1x8x8x8xf32>, %arg2: memref<1x8x8x8xf32>) {
%cst = arith.constant 0.000000e+00 : f32
- // CHECK: linalg.map
- // CHECK-NOT: linalg.generic
- linalg.map outs(%arg0 : memref<1x8x8x8xf32>)
- () {
- linalg.yield %cst : f32
- }
+ linalg.map {arith.addf} ins(%arg0, %arg1: memref<1x8x8x8xf32>, memref<1x8x8x8xf32>) outs(%arg2 : memref<1x8x8x8xf32>)
return
}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+// CHECK: @generalize_linalg_map
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x8x8x8xf32>, memref<1x8x8x8xf32>) outs(%{{.+}} : memref<1x8x8x8xf32>
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
+// CHECK: %[[ADD:.+]] = arith.addf %[[BBARG0]], %[[BBARG1]] : f32
+// CHECK: linalg.yield %[[ADD]] : f32
+
// -----
func.func @generalize_add(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 40bf4d1..fabc8e6 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -681,7 +681,7 @@ func.func @map_binary_wrong_yield_operands(
%add = linalg.map
ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
outs(%init:tensor<64xf32>)
- (%lhs_elem: f32, %rhs_elem: f32) {
+ (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
%0 = arith.addf %lhs_elem, %rhs_elem: f32
// expected-error @+1{{'linalg.yield' op expected number of yield values (2) to match the number of inits / outs operands of the enclosing LinalgOp (1)}}
linalg.yield %0, %0: f32, f32
@@ -694,11 +694,11 @@ func.func @map_binary_wrong_yield_operands(
func.func @map_input_mapper_arity_mismatch(
%lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>)
-> tensor<64xf32> {
- // expected-error@+1{{'linalg.map' op expects number of operands to match the arity of mapper, but got: 2 and 3}}
+ // expected-error@+1{{'linalg.map' op expects number of operands to match the arity of mapper, but got: 3 and 4}}
%add = linalg.map
ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
outs(%init:tensor<64xf32>)
- (%lhs_elem: f32, %rhs_elem: f32, %extra_elem: f32) {
+ (%lhs_elem: f32, %rhs_elem: f32, %out: f32, %extra_elem: f32) {
%0 = arith.addf %lhs_elem, %rhs_elem: f32
linalg.yield %0: f32
}
@@ -714,7 +714,7 @@ func.func @map_input_mapper_type_mismatch(
%add = linalg.map
ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
outs(%init:tensor<64xf32>)
- (%lhs_elem: f64, %rhs_elem: f64) {
+ (%lhs_elem: f64, %rhs_elem: f64, %out: f32) {
%0 = arith.addf %lhs_elem, %rhs_elem: f64
linalg.yield %0: f64
}
@@ -730,7 +730,7 @@ func.func @map_input_output_shape_mismatch(
%add = linalg.map
ins(%lhs, %rhs : tensor<64x64xf32>, tensor<64x64xf32>)
outs(%init:tensor<32xf32>)
- (%lhs_elem: f32, %rhs_elem: f32) {
+ (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
%0 = arith.addf %lhs_elem, %rhs_elem: f32
linalg.yield %0: f32
}
diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
index 1df15e8..85cc1ff 100644
--- a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
@@ -339,7 +339,7 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
%add = linalg.map
ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
outs(%init:tensor<64xf32>)
- (%lhs_elem: f32, %rhs_elem: f32) {
+ (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
%0 = arith.addf %lhs_elem, %rhs_elem: f32
linalg.yield %0: f32
}
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 563013d..7492892 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -341,7 +341,7 @@ func.func @mixed_parallel_reduced_results(%arg0 : tensor<?x?x?xf32>,
func.func @map_no_inputs(%init: tensor<64xf32>) -> tensor<64xf32> {
%add = linalg.map
outs(%init:tensor<64xf32>)
- () {
+ (%out: f32) {
%0 = arith.constant 0.0: f32
linalg.yield %0: f32
}
@@ -349,7 +349,7 @@ func.func @map_no_inputs(%init: tensor<64xf32>) -> tensor<64xf32> {
}
// CHECK-LABEL: func @map_no_inputs
// CHECK: linalg.map outs
-// CHECK-NEXT: () {
+// CHECK-NEXT: (%[[OUT:.*]]: f32) {
// CHECK-NEXT: arith.constant
// CHECK-NEXT: linalg.yield
// CHECK-NEXT: }
@@ -361,7 +361,7 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
%add = linalg.map
ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
outs(%init:tensor<64xf32>)
- (%lhs_elem: f32, %rhs_elem: f32) {
+ (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
%0 = arith.addf %lhs_elem, %rhs_elem: f32
linalg.yield %0: f32
}
@@ -378,7 +378,7 @@ func.func @map_binary_memref(%lhs: memref<64xf32>, %rhs: memref<64xf32>,
linalg.map
ins(%lhs, %rhs: memref<64xf32>, memref<64xf32>)
outs(%init:memref<64xf32>)
- (%lhs_elem: f32, %rhs_elem: f32) {
+ (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
%0 = arith.addf %lhs_elem, %rhs_elem: f32
linalg.yield %0: f32
}
@@ -393,7 +393,7 @@ func.func @map_unary(%input: tensor<64xf32>, %init: tensor<64xf32>) -> tensor<64
%abs = linalg.map
ins(%input:tensor<64xf32>)
outs(%init:tensor<64xf32>)
- (%input_elem: f32) {
+ (%input_elem: f32, %out: f32) {
%0 = math.absf %input_elem: f32
linalg.yield %0: f32
}
@@ -408,7 +408,7 @@ func.func @map_unary_memref(%input: memref<64xf32>, %init: memref<64xf32>) {
linalg.map
ins(%input:memref<64xf32>)
outs(%init:memref<64xf32>)
- (%input_elem: f32) {
+ (%input_elem: f32, %out: f32) {
%0 = math.absf %input_elem: f32
linalg.yield %0: f32
}
@@ -604,7 +604,7 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
%add = linalg.map
ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
outs(%init:tensor<64xf32>)
- (%lhs_elem: f32, %rhs_elem: f32) {
+ (%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
%0 = arith.addf %lhs_elem, %rhs_elem fastmath<fast> : f32
linalg.yield %0: f32
}
@@ -622,7 +622,7 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
func.func @map_not_short_form_compatible(%lhs: tensor<1x32xf32>, %rhs: tensor<1x32xf32>, %init: tensor<1x32xf32>) -> tensor<1x32xf32> {
%mapped = linalg.map ins(%lhs, %rhs : tensor<1x32xf32>, tensor<1x32xf32>) outs(%init : tensor<1x32xf32>)
- (%in_1: f32, %in_2: f32) {
+ (%in_1: f32, %in_2: f32, %out: f32) {
%1 = arith.maximumf %in_1, %in_2 : f32
linalg.yield %in_1 : f32
}
@@ -634,7 +634,7 @@ func.func @map_not_short_form_compatible(%lhs: tensor<1x32xf32>, %rhs: tensor<1x
// CHECK-NOT: linalg.map { arith.maximumf } ins(%[[LHS]] : tensor<1x32xf32>
// CHECK: linalg.map ins(%[[LHS]], %[[RHS]] : tensor<1x32xf32>, tensor<1x32xf32>)
// CHECK-SAME: outs(%[[INIT]] : tensor<1x32xf32>)
-// CHECK-NEXT: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
+// CHECK-NEXT: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32, %[[OUT:.*]]: f32) {
// CHECK-NEXT: %[[MAX_RESULT:.*]] = arith.maximumf %[[IN1]], %[[IN2]] : f32
// CHECK-NEXT: linalg.yield %[[IN1]] : f32
// CHECK-NEXT: }
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
index 93a0336..aa2c1da 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
@@ -356,7 +356,7 @@ func.func @vectorize_map(%arg0: memref<64xf32>,
%arg1: memref<64xf32>, %arg2: memref<64xf32>) {
linalg.map ins(%arg0, %arg1 : memref<64xf32>, memref<64xf32>)
outs(%arg2 : memref<64xf32>)
- (%in: f32, %in_0: f32) {
+ (%in: f32, %in_0: f32, %out: f32) {
%0 = arith.addf %in, %in_0 : f32
linalg.yield %0 : f32
}
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 296ca02..5eb2360 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -728,7 +728,7 @@ func.func @tensor.concat_dynamic_nonconcat_dim(%f: tensor<?x?xf32>, %g: tensor<?
// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc(%[[M]], %[[N]]) {{.*}} : memref<?x3x?xf32>
// CHECK: %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]]
// CHECK: %[[MAPPED:.*]] = linalg.map outs(%[[ALLOC_T]] : tensor<?x3x?xf32>)
-// CHECK: () {
+// CHECK: (%[[INIT:.*]]: f32) {
// CHECK: linalg.yield %[[F]] : f32
// CHECK: }
// CHECK: return %[[MAPPED]] : tensor<?x3x?xf32>
diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
index 8cbee3c..aa8882d 100644
--- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
@@ -257,10 +257,10 @@ module attributes {transform.with_named_sequence} {
// -----
func.func @map(%lhs: memref<64xf32>,
- %rhs: memref<64xf32>, %out: memref<64xf32>) {
+ %rhs: memref<64xf32>, %init: memref<64xf32>) {
linalg.map ins(%lhs, %rhs : memref<64xf32>, memref<64xf32>)
- outs(%out : memref<64xf32>)
- (%in: f32, %in_0: f32) {
+ outs(%init : memref<64xf32>)
+ (%in: f32, %in_0: f32, %out: f32) {
%0 = arith.addf %in, %in_0 : f32
linalg.yield %0 : f32
}
diff --git a/mlir/test/Target/LLVMIR/omptarget-record-type-with-ptr-member-host.mlir b/mlir/test/Target/LLVMIR/omptarget-record-type-with-ptr-member-host.mlir
index a1e415c..9640f03 100644
--- a/mlir/test/Target/LLVMIR/omptarget-record-type-with-ptr-member-host.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-record-type-with-ptr-member-host.mlir
@@ -81,9 +81,8 @@ module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-a
// CHECK: %[[ARR_SECT_SIZE:.*]] = mul i64 %[[ARR_SECT_SIZE1]], 4
// CHECK: %[[LFULL_ARR:.*]] = load ptr, ptr @full_arr, align 8
// CHECK: %[[FULL_ARR_PTR:.*]] = getelementptr inbounds float, ptr %[[LFULL_ARR]], i64 0
-// CHECK: %[[ARR_SECT_OFFSET1:.*]] = mul i64 %[[ARR_SECT_OFFSET2]], 1
// CHECK: %[[LARR_SECT:.*]] = load ptr, ptr @sect_arr, align 8
-// CHECK: %[[ARR_SECT_PTR:.*]] = getelementptr inbounds i32, ptr %[[LARR_SECT]], i64 %[[ARR_SECT_OFFSET1]]
+// CHECK: %[[ARR_SECT_PTR:.*]] = getelementptr inbounds i32, ptr %[[LARR_SECT]], i64 %[[ARR_SECT_OFFSET2]]
// CHECK: %[[SCALAR_PTR_LOAD:.*]] = load ptr, ptr %[[SCALAR_BASE]], align 8
// CHECK: %[[FULL_ARR_DESC_SIZE:.*]] = sdiv exact i64 48, ptrtoint (ptr getelementptr (i8, ptr null, i32 1) to i64)
// CHECK: %[[FULL_ARR_SIZE_CMP:.*]] = icmp eq ptr %[[FULL_ARR_PTR]], null
diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index c003f8b..91f83a0 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -143,3 +143,25 @@ func.func @test_signature_conversion_no_converter() {
return
}
+// -----
+
+// CHECK-LABEL: func @test_unstructured_cf_conversion(
+// CHECK-SAME: %[[arg0:.*]]: f64, %[[c:.*]]: i1)
+// CHECK: %[[cast1:.*]] = "test.cast"(%[[arg0]]) : (f64) -> f32
+// CHECK: "test.foo"(%[[cast1]])
+// CHECK: cf.br ^[[bb1:.*]](%[[arg0]] : f64)
+// CHECK: ^[[bb1]](%[[arg1:.*]]: f64):
+// CHECK: cf.cond_br %[[c]], ^[[bb1]](%[[arg1]] : f64), ^[[bb2:.*]](%[[arg1]] : f64)
+// CHECK: ^[[bb2]](%[[arg2:.*]]: f64):
+// CHECK: %[[cast2:.*]] = "test.cast"(%[[arg2]]) : (f64) -> f32
+// CHECK: "test.bar"(%[[cast2]])
+// CHECK: return
+func.func @test_unstructured_cf_conversion(%arg0: f32, %c: i1) {
+ "test.foo"(%arg0) : (f32) -> ()
+ cf.br ^bb1(%arg0: f32)
+^bb1(%arg1: f32):
+ cf.cond_br %c, ^bb1(%arg1 : f32), ^bb2(%arg1 : f32)
+^bb2(%arg2: f32):
+ "test.bar"(%arg2) : (f32) -> ()
+ return
+}
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index f099d01..9354a85 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -71,6 +71,7 @@ add_mlir_library(MLIRTestDialect
)
mlir_target_link_libraries(MLIRTestDialect PUBLIC
MLIRControlFlowInterfaces
+ MLIRControlFlowTransforms
MLIRDataLayoutInterfaces
MLIRDerivedAttributeOpInterface
MLIRDestinationStyleOpInterface
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index efbdbfb..fd2b943 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -11,6 +11,7 @@
#include "TestTypes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/CommonFolders.h"
+#include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
@@ -2042,6 +2043,10 @@ struct TestTypeConversionDriver
});
converter.addConversion([](IndexType type) { return type; });
converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &types) {
+ if (type.isInteger(1)) {
+ // i1 is legal.
+ types.push_back(type);
+ }
if (type.isInteger(38)) {
// i38 is legal.
types.push_back(type);
@@ -2175,6 +2180,8 @@ struct TestTypeConversionDriver
converter);
mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
converter, patterns, target);
+ mlir::cf::populateCFStructuralTypeConversionsAndLegality(converter,
+ patterns, target);
ConversionConfig config;
config.allowPatternRollback = allowPatternRollback;
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 496f18b..61db9d2 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -797,7 +797,7 @@ DiagnosedSilenceableFailure mlir::test::TestProduceInvalidIR::applyToOne(
// Provide some IR that does not verify.
rewriter.setInsertionPointToStart(&target->getRegion(0).front());
TestDummyPayloadOp::create(rewriter, target->getLoc(), TypeRange(),
- ValueRange(), /*failToVerify=*/true);
+ ValueRange(), /*fail_to_verify=*/true);
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
index 6ac9a87..d6203b9 100644
--- a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
+++ b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
@@ -766,7 +766,9 @@ void testShortDataEntryOpBuildersMappableVar(OpBuilder &b, MLIRContext &context,
struct IntegerOpenACCMappableModel
: public mlir::acc::MappableType::ExternalModel<IntegerOpenACCMappableModel,
- IntegerType> {};
+ IntegerType> {
+ bool hasUnknownDimensions(mlir::Type type) const { return false; }
+};
TEST_F(OpenACCOpsTest, mappableTypeBuilderDataEntry) {
// First, set up the test by attaching MappableInterface to IntegerType.