aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
Diffstat (limited to 'mlir')
-rw-r--r--mlir/.clang-format1
-rw-r--r--mlir/include/mlir-c/IR.h4
-rw-r--r--mlir/include/mlir/Bindings/Python/NanobindAdaptors.h38
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td22
-rw-r--r--mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td49
-rw-r--r--mlir/include/mlir/Dialect/Math/Transforms/Passes.td8
-rw-r--r--mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td39
-rw-r--r--mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td29
-rw-r--r--mlir/include/mlir/Dialect/OpenMP/OpenMPOpBase.td69
-rw-r--r--mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td63
-rw-r--r--mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h1
-rw-r--r--mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td54
-rw-r--r--mlir/lib/Bindings/Python/IRCore.cpp70
-rw-r--r--mlir/lib/Bindings/Python/IRModule.h6
-rw-r--r--mlir/lib/Bindings/Python/IRTypes.cpp4
-rw-r--r--mlir/lib/Bindings/Python/MainModule.cpp6
-rw-r--r--mlir/lib/Bindings/Python/Rewrite.cpp2
-rw-r--r--mlir/lib/CAPI/IR/IR.cpp4
-rw-r--r--mlir/lib/CAPI/Transforms/Rewrite.cpp13
-rw-r--r--mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp52
-rw-r--r--mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp5
-rw-r--r--mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp1
-rw-r--r--mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp2
-rw-r--r--mlir/lib/Dialect/Arith/IR/ArithOps.cpp7
-rw-r--r--mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp3
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp116
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp21
-rw-r--r--mlir/lib/Dialect/Math/Transforms/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp80
-rw-r--r--mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp23
-rw-r--r--mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp451
-rw-r--r--mlir/lib/Dialect/Transform/IR/TransformOps.cpp7
-rw-r--r--mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp184
-rw-r--r--mlir/lib/Dialect/Vector/IR/VectorOps.cpp82
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp2
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp2
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp6
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp2
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp6
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp10
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp2
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp4
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp2
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp2
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp2
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp6
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp12
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp4
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp16
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp8
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp24
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp2
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp40
-rw-r--r--mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp24
-rw-r--r--mlir/lib/IR/Builders.cpp7
-rw-r--r--mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp2
-rw-r--r--mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp162
-rw-r--r--mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt53
-rw-r--r--mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt14
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp42
-rw-r--r--mlir/lib/Transforms/Utils/DialectConversion.cpp46
-rw-r--r--mlir/python/mlir/dialects/transform/structured.py6
-rw-r--r--mlir/python/mlir/dialects/transform/tune.py66
-rw-r--r--mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir30
-rw-r--r--mlir/test/Dialect/Arith/canonicalize.mlir12
-rw-r--r--mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir11
-rw-r--r--mlir/test/Dialect/LLVMIR/rocdl.mlir32
-rw-r--r--mlir/test/Dialect/Math/sincos-fusion.mlir86
-rw-r--r--mlir/test/Dialect/MemRef/invalid.mlir16
-rw-r--r--mlir/test/Dialect/MemRef/ops.mlir9
-rw-r--r--mlir/test/Dialect/OpenMP/cli-canonical_loop.mlir198
-rw-r--r--mlir/test/Dialect/OpenMP/cli-tile.mlir138
-rw-r--r--mlir/test/Dialect/OpenMP/cli-unroll-heuristic.mlir28
-rw-r--r--mlir/test/Dialect/OpenMP/invalid-tile.mlir119
-rw-r--r--mlir/test/Dialect/Transform/test-promote-tensors.mlir104
-rw-r--r--mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir85
-rw-r--r--mlir/test/Dialect/Transform/test-tune-extension.mlir126
-rw-r--r--mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir19
-rw-r--r--mlir/test/Target/LLVMIR/openmp-cli-tile01.mlir94
-rw-r--r--mlir/test/Target/LLVMIR/openmp-cli-tile02.mlir184
-rw-r--r--mlir/test/Target/LLVMIR/rocdl.mlir28
-rw-r--r--mlir/test/lib/Dialect/TestIRDLToCpp/CMakeLists.txt2
-rw-r--r--mlir/test/lib/Dialect/TestIRDLToCpp/TestIRDLToCppDialect.cpp31
-rw-r--r--mlir/test/lib/Dialect/TestIRDLToCpp/test_conversion.testd.mlir18
-rw-r--r--mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp.irdl.mlir51
-rw-r--r--mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp_invalid_unsupported_types.irdl.mlir27
-rw-r--r--mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp2
-rw-r--r--mlir/test/mlir-tblgen/op-format-invalid.td2
-rw-r--r--mlir/test/mlir-tblgen/op-format-spec.td2
-rw-r--r--mlir/test/python/dialects/transform_tune_ext.py105
-rw-r--r--mlir/test/python/ir/operation.py8
-rw-r--r--mlir/tools/mlir-rewrite/mlir-rewrite.cpp12
-rw-r--r--mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp1
-rw-r--r--mlir/tools/mlir-tblgen/FormatGen.cpp2
-rw-r--r--mlir/tools/mlir-tblgen/OpFormatGen.cpp1
-rw-r--r--mlir/unittests/TableGen/PassGenTest.cpp3
96 files changed, 3161 insertions, 516 deletions
diff --git a/mlir/.clang-format b/mlir/.clang-format
index a74fda4..76cc928 100644
--- a/mlir/.clang-format
+++ b/mlir/.clang-format
@@ -1,2 +1,3 @@
BasedOnStyle: LLVM
AlwaysBreakTemplateDeclarations: Yes
+LineEnding: LF
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 061d762..c464e4d 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -634,6 +634,10 @@ MLIR_CAPI_EXPORTED MlirContext mlirOperationGetContext(MlirOperation op);
/// Gets the location of the operation.
MLIR_CAPI_EXPORTED MlirLocation mlirOperationGetLocation(MlirOperation op);
+/// Sets the location of the operation.
+MLIR_CAPI_EXPORTED void mlirOperationSetLocation(MlirOperation op,
+ MlirLocation loc);
+
/// Gets the type id of the operation.
/// Returns null if the operation does not have a registered operation
/// description.
diff --git a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
index b5f985f..847951a 100644
--- a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
@@ -116,7 +116,8 @@ mlirApiObjectToCapsule(nanobind::handle apiObject) {
/// Casts object <-> MlirAffineMap.
template <>
struct type_caster<MlirAffineMap> {
- NB_TYPE_CASTER(MlirAffineMap, const_name("MlirAffineMap"))
+ NB_TYPE_CASTER(MlirAffineMap,
+ const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.AffineMap")))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
if (auto capsule = mlirApiObjectToCapsule(src)) {
value = mlirPythonCapsuleToAffineMap(capsule->ptr());
@@ -138,7 +139,8 @@ struct type_caster<MlirAffineMap> {
/// Casts object <-> MlirAttribute.
template <>
struct type_caster<MlirAttribute> {
- NB_TYPE_CASTER(MlirAttribute, const_name("MlirAttribute"))
+ NB_TYPE_CASTER(MlirAttribute,
+ const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.Attribute")))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
if (auto capsule = mlirApiObjectToCapsule(src)) {
value = mlirPythonCapsuleToAttribute(capsule->ptr());
@@ -161,7 +163,7 @@ struct type_caster<MlirAttribute> {
/// Casts object -> MlirBlock.
template <>
struct type_caster<MlirBlock> {
- NB_TYPE_CASTER(MlirBlock, const_name("MlirBlock"))
+ NB_TYPE_CASTER(MlirBlock, const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.Block")))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
if (auto capsule = mlirApiObjectToCapsule(src)) {
value = mlirPythonCapsuleToBlock(capsule->ptr());
@@ -174,7 +176,8 @@ struct type_caster<MlirBlock> {
/// Casts object -> MlirContext.
template <>
struct type_caster<MlirContext> {
- NB_TYPE_CASTER(MlirContext, const_name("MlirContext"))
+ NB_TYPE_CASTER(MlirContext,
+ const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.Context")))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
if (src.is_none()) {
// Gets the current thread-bound context.
@@ -192,7 +195,8 @@ struct type_caster<MlirContext> {
/// Casts object <-> MlirDialectRegistry.
template <>
struct type_caster<MlirDialectRegistry> {
- NB_TYPE_CASTER(MlirDialectRegistry, const_name("MlirDialectRegistry"))
+ NB_TYPE_CASTER(MlirDialectRegistry,
+ const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.DialectRegistry")))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
if (auto capsule = mlirApiObjectToCapsule(src)) {
value = mlirPythonCapsuleToDialectRegistry(capsule->ptr());
@@ -214,7 +218,8 @@ struct type_caster<MlirDialectRegistry> {
/// Casts object <-> MlirLocation.
template <>
struct type_caster<MlirLocation> {
- NB_TYPE_CASTER(MlirLocation, const_name("MlirLocation"))
+ NB_TYPE_CASTER(MlirLocation,
+ const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.Location")))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
if (src.is_none()) {
// Gets the current thread-bound context.
@@ -240,7 +245,7 @@ struct type_caster<MlirLocation> {
/// Casts object <-> MlirModule.
template <>
struct type_caster<MlirModule> {
- NB_TYPE_CASTER(MlirModule, const_name("MlirModule"))
+ NB_TYPE_CASTER(MlirModule, const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.Module")))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
if (auto capsule = mlirApiObjectToCapsule(src)) {
value = mlirPythonCapsuleToModule(capsule->ptr());
@@ -262,8 +267,9 @@ struct type_caster<MlirModule> {
/// Casts object <-> MlirFrozenRewritePatternSet.
template <>
struct type_caster<MlirFrozenRewritePatternSet> {
- NB_TYPE_CASTER(MlirFrozenRewritePatternSet,
- const_name("MlirFrozenRewritePatternSet"))
+ NB_TYPE_CASTER(
+ MlirFrozenRewritePatternSet,
+ const_name(MAKE_MLIR_PYTHON_QUALNAME("rewrite.FrozenRewritePatternSet")))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
if (auto capsule = mlirApiObjectToCapsule(src)) {
value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule->ptr());
@@ -285,7 +291,8 @@ struct type_caster<MlirFrozenRewritePatternSet> {
/// Casts object <-> MlirOperation.
template <>
struct type_caster<MlirOperation> {
- NB_TYPE_CASTER(MlirOperation, const_name("MlirOperation"))
+ NB_TYPE_CASTER(MlirOperation,
+ const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.Operation")))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
if (auto capsule = mlirApiObjectToCapsule(src)) {
value = mlirPythonCapsuleToOperation(capsule->ptr());
@@ -309,7 +316,7 @@ struct type_caster<MlirOperation> {
/// Casts object <-> MlirValue.
template <>
struct type_caster<MlirValue> {
- NB_TYPE_CASTER(MlirValue, const_name("MlirValue"))
+ NB_TYPE_CASTER(MlirValue, const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.Value")))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
if (auto capsule = mlirApiObjectToCapsule(src)) {
value = mlirPythonCapsuleToValue(capsule->ptr());
@@ -334,7 +341,8 @@ struct type_caster<MlirValue> {
/// Casts object -> MlirPassManager.
template <>
struct type_caster<MlirPassManager> {
- NB_TYPE_CASTER(MlirPassManager, const_name("MlirPassManager"))
+ NB_TYPE_CASTER(MlirPassManager, const_name(MAKE_MLIR_PYTHON_QUALNAME(
+ "passmanager.PassManager")))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
if (auto capsule = mlirApiObjectToCapsule(src)) {
value = mlirPythonCapsuleToPassManager(capsule->ptr());
@@ -347,7 +355,7 @@ struct type_caster<MlirPassManager> {
/// Casts object <-> MlirTypeID.
template <>
struct type_caster<MlirTypeID> {
- NB_TYPE_CASTER(MlirTypeID, const_name("MlirTypeID"))
+ NB_TYPE_CASTER(MlirTypeID, const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.TypeID")))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
if (auto capsule = mlirApiObjectToCapsule(src)) {
value = mlirPythonCapsuleToTypeID(capsule->ptr());
@@ -371,7 +379,7 @@ struct type_caster<MlirTypeID> {
/// Casts object <-> MlirType.
template <>
struct type_caster<MlirType> {
- NB_TYPE_CASTER(MlirType, const_name("MlirType"))
+ NB_TYPE_CASTER(MlirType, const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.Type")))
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
if (auto capsule = mlirApiObjectToCapsule(src)) {
value = mlirPythonCapsuleToType(capsule->ptr());
@@ -394,7 +402,7 @@ struct type_caster<MlirType> {
/// Casts MlirStringRef -> object.
template <>
struct type_caster<MlirStringRef> {
- NB_TYPE_CASTER(MlirStringRef, const_name("MlirStringRef"))
+ NB_TYPE_CASTER(MlirStringRef, const_name("str"))
static handle from_cpp(MlirStringRef s, rv_policy,
cleanup_list *cleanup) noexcept {
return nanobind::str(s.data, s.length).release();
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 8b687a7..29001e2 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -985,7 +985,6 @@ class ScaleArgInfo<TypeConstraint argTyVal, string typeName> {
//===---------------------------------------------------------------------===//
// Scaled {fp4,bf8,fp8} to {bf16,f16,f32} conversion intrinsics
//===---------------------------------------------------------------------===//
-
foreach smallT = [
ScaleArgInfo<I32, "Fp4">,
ScaleArgInfo<ROCDL_V2I32Type, "Fp8">,
@@ -996,6 +995,8 @@ foreach smallT = [
ScaleArgInfo<ROCDL_V8BF16Type, "Bf16">,
ScaleArgInfo<ROCDL_V8F32Type, "F32">,
] in {
+
+ // Up-scaling
def ROCDL_CvtPkScalePk8 # largeT.nameForOp # smallT.nameForOp # Op :
ROCDL_ConcreteNonMemIntrOp<"cvt.scale.pk8." # largeT.name # "." # smallT.name,
[Pure], 1, [2], ["scaleSel"]>,
@@ -1010,13 +1011,30 @@ foreach smallT = [
attr-dict $src `,` $scale `[` $scaleSel `]` `:` type($res)
}];
}
+
+ // Down-scaling
+ def ROCDL_CvtScaleF32Pk8 # smallT.nameForOp # largeT.nameForOp # Op :
+ ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk8." # smallT.name # "." # largeT.name,
+ [Pure], 1>,
+ Arguments<(ins largeT.type:$src, F32:$scale)> {
+ let results = (outs smallT.type:$res);
+ let summary = "Scale and convert packed "
+ # largeT.name # " to packed " # smallT.name ;
+ let description = [{
+ Convert 8 packed }] # largeT.name # [{ values to packed }]
+ # smallT.name # [{, multiplying by the exponent part of `scale`
+ before doing so. This op is for gfx1250+ arch.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `,` $scale `:` type($res)
+ }];
+ }
} // foreach largeT
} // foreach smallTOp
//===---------------------------------------------------------------------===//
// Scaled {bf6,fp6} to {bf16,f16,f32} conversion intrinsics
//===---------------------------------------------------------------------===//
-
foreach smallT = [
ScaleArgInfo<ROCDL_V3I32Type, "Fp6">,
ScaleArgInfo<ROCDL_V3I32Type, "Bf6">
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 8f3232f..0d6ebc0 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -17,6 +17,7 @@ include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
include "mlir/Dialect/Transform/IR/TransformTypes.td"
include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/RegionKindInterface.td"
@@ -236,11 +237,51 @@ def BufferizeToAllocationOp : Op<Transform_Dialect,
Transform_AnyOpType:$new_ops);
let assemblyFormat = "$target attr-dict `:` type($target)";
let hasVerifier = 1;
+}
- let builders = [
- OpBuilder<(ins "Value":$target, "Attribute":$memorySpace)>,
- OpBuilder<(ins "Value":$target, "int64_t":$memorySpace)>
- ];
+//===----------------------------------------------------------------------===//
+// PromoteTensorOp
+//===----------------------------------------------------------------------===//
+
+def PromoteTensorOp : Op<Transform_Dialect, "structured.promote_tensor",
+ [DeclareOpInterfaceMethods<TransformOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ SameOperandsAndResultType]> {
+ let summary = "Request a tensor value to live in a specific memory space "
+ "after bufferization";
+ let description = [{
+ Requests that a tensor value lives in a specific memory space for its
+ lifetime. This is achieved by allocating a new tensor in the desired
+ memory space with `bufferization.alloc_tensor` and optionally materializing
+ the source value into that allocation with
+ `bufferization.materialize_in_destination`. All uses of the original value
+ are then redirected to the promoted value.
+
+ The generated code for promoting tensor value %0 resembles the following:
+
+ %1 = bufferization.alloc_tensor(<dynamic dims of %0>)
+ { memory_space = memory_space }
+ // Note: the materialization is omitted if %0 is never read and is only
+ // written into (i.e., it behaves as a result tensor).
+ %2 = bufferization.materialize_in_destination %0 in %1
+ // ...
+ <all users of %0 now use %2 instead>
+
+ Deallocation is not handled by this transform.
+
+ Return modes:
+ - Produces a silenceable failure if the given handle does not point to
+ tensor-typed values.
+ - Succeeds otherwise and returns a handle to the promoted value(s), i.e.,
+ the result of materialization if present and the allocation otherwise.
+ }];
+
+ let arguments = (ins TransformValueHandleTypeInterface:$tensor,
+ OptionalAttr<AnyAttr>:$memory_space);
+ let results = (outs TransformValueHandleTypeInterface:$promoted);
+
+ let assemblyFormat =
+ "(`to` $memory_space^)? $tensor attr-dict `:` type($tensor)";
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
index 4d415ae..48346abd 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
@@ -64,4 +64,12 @@ def MathExpandOpsPass : Pass<"math-expand-ops"> {
];
}
+def MathSincosFusionPass : Pass<"math-sincos-fusion"> {
+ let summary = "Fuse sin and cos operations.";
+ let description = [{
+ Fuse sin and cos operations into a sincos operation.
+ }];
+ let dependentDialects = ["math::MathDialect"];
+}
+
#endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 2bf953e..d4d67bf 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -155,7 +155,7 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
The `assume_alignment` operation takes a memref and an integer alignment
value. It returns a new SSA value of the same memref type, but associated
with the assumption that the underlying buffer is aligned to the given
- alignment.
+ alignment.
If the buffer isn't aligned to the given alignment, its result is poison.
This operation doesn't affect the semantics of a program where the
@@ -170,7 +170,7 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)";
let extraClassDeclaration = [{
MemRefType getType() { return ::llvm::cast<MemRefType>(getResult().getType()); }
-
+
Value getViewSource() { return getMemref(); }
}];
@@ -179,6 +179,41 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
}
//===----------------------------------------------------------------------===//
+// DistinctObjectsOp
+//===----------------------------------------------------------------------===//
+
+def DistinctObjectsOp : MemRef_Op<"distinct_objects", [
+ Pure,
+ DeclareOpInterfaceMethods<InferTypeOpInterface>
+ // ViewLikeOpInterface TODO: ViewLikeOpInterface only supports a single argument
+ ]> {
+ let summary = "assumption that acesses to specific memrefs will never alias";
+ let description = [{
+ The `distinct_objects` operation takes a list of memrefs and returns the same
+ memrefs, with the additional assumption that accesses to them will never
+ alias with each other. This means that loads and stores to different
+ memrefs in the list can be safely reordered.
+
+ If the memrefs do alias, the load/store behavior is undefined. This
+ operation doesn't affect the semantics of a valid program. It is
+ intended for optimization purposes, allowing the compiler to generate more
+ efficient code based on the non-aliasing assumption. The optimization is
+ best-effort.
+
+ Example:
+
+ ```mlir
+ %1, %2 = memref.distinct_objects %a, %b : memref<?xf32>, memref<?xf32>
+ ```
+ }];
+ let arguments = (ins Variadic<AnyMemRef>:$operands);
+ let results = (outs Variadic<AnyMemRef>:$results);
+
+ let assemblyFormat = "$operands attr-dict `:` type($operands)";
+ let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
// AllocOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 1eda5e4..8e43c42 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -996,6 +996,35 @@ class OpenMP_NumTeamsClauseSkip<
def OpenMP_NumTeamsClause : OpenMP_NumTeamsClauseSkip<>;
//===----------------------------------------------------------------------===//
+// V5.1: [10.1.2] `sizes` clause
+//===----------------------------------------------------------------------===//
+
+class OpenMP_SizesClauseSkip<
+ bit traits = false, bit arguments = false, bit assemblyFormat = false,
+ bit description = false, bit extraClassDeclaration = false
+ > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+ extraClassDeclaration> {
+ let arguments = (ins
+ Variadic<IntLikeType>:$sizes
+ );
+
+ let optAssemblyFormat = [{
+ `sizes` `(` $sizes `:` type($sizes) `)`
+ }];
+
+ let description = [{
+ The `sizes` clauses defines the size of a grid over a multi-dimensional
+ logical iteration space. This grid is used for loop transformations such as
+ `tile` and `strip`. The size per dimension can be a variable, but only
+ values that are not at least 2 make sense. It is not specified what happens
+ when smaller values are used, but should still result in a loop nest that
+ executes each logical iteration once.
+ }];
+}
+
+def OpenMP_SizesClause : OpenMP_SizesClauseSkip<>;
+
+//===----------------------------------------------------------------------===//
// V5.2: [10.1.2] `num_threads` clause
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpBase.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpBase.td
index bbcfb87f..5ad4e4b 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpBase.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpBase.td
@@ -38,6 +38,44 @@ def OpenMP_MapBoundsType : OpenMP_Type<"MapBounds", "map_bounds_ty"> {
let summary = "Type for representing omp map clause bounds information";
}
+//===---------------------------------------------------------------------===//
+// OpenMP Canonical Loop Info Type
+//===---------------------------------------------------------------------===//
+
+def CanonicalLoopInfoType : OpenMP_Type<"CanonicalLoopInfo", "cli"> {
+ let summary = "Type for representing a reference to a canonical loop";
+ let description = [{
+ A variable of type CanonicalLoopInfo refers to an OpenMP-compatible
+ canonical loop in the same function. Values of this type are not
+ available at runtime and therefore cannot be used by the program itself,
+ i.e. an opaque type. It is similar to the transform dialect's
+ `!transform.interface` type, but instead of implementing an interface
+ for each transformation, the OpenMP dialect itself defines possible
+ operations on this type.
+
+ A value of type CanonicalLoopInfoType (in the following: CLI) value can be
+
+ 1. created by omp.new_cli.
+ 2. passed to omp.canonical_loop to associate the loop to that CLI. A CLI
+ can only be associated once.
+ 3. passed to an omp loop transformation operation that modifies the loop
+ associated with the CLI. The CLI is the "applyee" and the operation is
+ the consumer. A CLI can only be consumed once.
+ 4. passed to an omp loop transformation operation to associate the cli with
+ a result of that transformation. The CLI is the "generatee" and the
+ operation is the generator.
+
+ A CLI cannot
+
+ 1. be returned from a function.
+ 2. be passed to operations that are not specifically designed to take a
+ CanonicalLoopInfoType, including AnyType.
+
+ A CLI directly corresponds to an object of
+ OpenMPIRBuilder's CanonicalLoopInfo struct when lowering to LLVM-IR.
+ }];
+}
+
//===----------------------------------------------------------------------===//
// Base classes for OpenMP dialect operations.
//===----------------------------------------------------------------------===//
@@ -211,8 +249,35 @@ class OpenMP_Op<string mnemonic, list<Trait> traits = [],
// Doesn't actually create a C++ base class (only defines default values for
// tablegen classes that derive from this). Use LoopTransformationInterface
// instead for common operations.
-class OpenMPTransform_Op<string mnemonic, list<Trait> traits = []> :
- OpenMP_Op<mnemonic, !listconcat([DeclareOpInterfaceMethods<LoopTransformationInterface>], traits) > {
+class OpenMPTransform_Op<string mnemonic,
+ list<Trait> traits = [],
+ list<OpenMP_Clause> clauses = []> :
+ OpenMP_Op<mnemonic,
+ traits = !listconcat([DeclareOpInterfaceMethods<LoopTransformationInterface>], traits),
+ clauses = clauses> {
+}
+
+// Base clause for loop transformations using the standard syntax.
+//
+// omp.opname ($generatees) <- ($applyees) clause(...) clause(...) ... <attr-dicr>
+// omp.opname ($applyees) clause(...) clause(...) ... <attr-dict>
+//
+// $generatees is optional and is assumed to be empty if omitted
+class OpenMPTransformBase_Op<string mnemonic,
+ list<Trait> traits = [],
+ list<OpenMP_Clause> clauses = []> :
+ OpenMPTransform_Op<mnemonic,
+ traits = !listconcat(traits, [AttrSizedOperandSegments]),
+ clauses = clauses> {
+
+ let arguments = !con(
+ (ins Variadic<CanonicalLoopInfoType>:$generatees,
+ Variadic<CanonicalLoopInfoType>:$applyees
+ ), clausesArgs);
+
+ let assemblyFormat = [{ custom<LoopTransformClis>($generatees, $applyees) }]
+ # clausesAssemblyFormat
+ # [{ attr-dict }];
}
#endif // OPENMP_OP_BASE
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 5c77e21..b73091e 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -358,44 +358,6 @@ def SingleOp : OpenMP_Op<"single", traits = [
}
//===---------------------------------------------------------------------===//
-// OpenMP Canonical Loop Info Type
-//===---------------------------------------------------------------------===//
-
-def CanonicalLoopInfoType : OpenMP_Type<"CanonicalLoopInfo", "cli"> {
- let summary = "Type for representing a reference to a canonical loop";
- let description = [{
- A variable of type CanonicalLoopInfo refers to an OpenMP-compatible
- canonical loop in the same function. Values of this type are not
- available at runtime and therefore cannot be used by the program itself,
- i.e. an opaque type. It is similar to the transform dialect's
- `!transform.interface` type, but instead of implementing an interface
- for each transformation, the OpenMP dialect itself defines possible
- operations on this type.
-
- A value of type CanonicalLoopInfoType (in the following: CLI) value can be
-
- 1. created by omp.new_cli.
- 2. passed to omp.canonical_loop to associate the loop to that CLI. A CLI
- can only be associated once.
- 3. passed to an omp loop transformation operation that modifies the loop
- associated with the CLI. The CLI is the "applyee" and the operation is
- the consumer. A CLI can only be consumed once.
- 4. passed to an omp loop transformation operation to associate the cli with
- a result of that transformation. The CLI is the "generatee" and the
- operation is the generator.
-
- A CLI cannot
-
- 1. be returned from a function.
- 2. be passed to operations that are not specifically designed to take a
- CanonicalLoopInfoType, including AnyType.
-
- A CLI directly corresponds to an object of
- OpenMPIRBuilder's CanonicalLoopInfo struct when lowering to LLVM-IR.
- }];
-}
-
-//===---------------------------------------------------------------------===//
// OpenMP Canonical Loop Info Creation
//===---------------------------------------------------------------------===//
@@ -564,6 +526,31 @@ def UnrollHeuristicOp : OpenMPTransform_Op<"unroll_heuristic", []> {
}
//===----------------------------------------------------------------------===//
+// OpenMP tile operation
+//===----------------------------------------------------------------------===//
+
+def TileOp : OpenMPTransformBase_Op<"tile",
+ clauses = [OpenMP_SizesClause]> {
+ let summary = "OpenMP tile operation";
+ let description = [{
+ Represents the OpenMP tile directive introduced in OpenMP 5.1.
+
+ The construct partitions the logical iteration space of the affected loops
+ into equally-sized tiles, then creates two sets of nested loops. The outer
+ loops, called the grid loops, iterate over all tiles. The inner loops,
+ called the intratile loops, iterate over the logical iterations of a tile.
+ The sizes clause determines the size of a tile.
+
+ Currently, the affected loops must be rectangular (the tripcount of the
+ inner loop must not depend on any iv of an surrounding affected loop) and
+ perfectly nested (except for the innermost affected loop, no operations
+ other than the nested loop and the terminator in the loop body).
+ }] # clausesDescription;
+
+ let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
// 2.8.3 Workshare Construct
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h
index 74e1d28..ba11259 100644
--- a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h
+++ b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h
@@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H
#define MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/OpDefinition.h"
diff --git a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td
index d68d451..d095659 100644
--- a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td
+++ b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td
@@ -11,10 +11,15 @@
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/CommonAttrConstraints.td"
+//===----------------------------------------------------------------------===//
+// KnobOp
+//===----------------------------------------------------------------------===//
+
def KnobOp : Op<Transform_Dialect, "tune.knob", [
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
@@ -52,4 +57,53 @@ def KnobOp : Op<Transform_Dialect, "tune.knob", [
"`<` $name `>` (`=` $selected^ `from`)? `options` `=` $options attr-dict `->` type(results)";
}
+//===----------------------------------------------------------------------===//
+// AlternativesOp
+//===----------------------------------------------------------------------===//
+
+def AlternativesOp : Op<Transform_Dialect, "tune.alternatives", [
+ DeclareOpInterfaceMethods<RegionBranchOpInterface,
+ ["getEntrySuccessorOperands", "getSuccessorRegions",
+ "getRegionInvocationBounds"]>,
+ DeclareOpInterfaceMethods<TransformOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">,
+ NoRegionArguments
+]> {
+ let summary = "Represents a choice among its regions, i.e. sub-schedules";
+
+ let description = [{
+ This op represents a choice over which of its regions is to be used.
+
+ When `selected_region` is provided, the semantics are that this op is to be
+ substituted for by the selected region, meaning the region's results become
+ the results of this op. Without a provided `selected_region`, the semantics
+ are that this non-deterministic choice is yet to be resolved -- which in
+ terms of the op's interpreted semantics is a failure.
+
+ The `selected_region` argument is either an `IntegerAttr` or a param holding
+ an `IntegerAttr`, which should provide a valid zero-based index with respect
+ to the number of alternatives, i.e. regions.
+ }];
+ let cppNamespace = [{ mlir::transform::tune }];
+
+ let arguments = (ins Builtin_StringAttr:$name,
+ OptionalAttr<APIntAttr>:$selected_region_attr,
+ Optional<TransformParamTypeInterface>:$selected_region_param);
+ let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
+ let regions = (region VariadicRegion<SizedRegion<1>>:$alternatives);
+
+ let assemblyFormat = [{
+ `<` $name `>`
+ (`selected_region` `=` custom<AlternativesOpSelectedRegion>(
+ $selected_region_attr, $selected_region_param)^)?
+ attr-dict-with-keyword
+ (`:` type($selected_region_param)^)?
+ (`->` type($results)^)?
+ regions
+ }];
+
+ let hasVerifier = 1;
+}
+
#endif // MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 83a8757..32b2b0c 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -3219,13 +3219,11 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("end_line"), nb::arg("end_col"),
nb::arg("context") = nb::none(), kContextGetFileRangeDocstring)
.def("is_a_file", mlirLocationIsAFileLineColRange)
- .def_prop_ro(
- "filename",
- [](MlirLocation loc) {
- return mlirIdentifierStr(
- mlirLocationFileLineColRangeGetFilename(loc));
- },
- nb::sig("def filename(self) -> str"))
+ .def_prop_ro("filename",
+ [](MlirLocation loc) {
+ return mlirIdentifierStr(
+ mlirLocationFileLineColRangeGetFilename(loc));
+ })
.def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine)
.def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn)
.def_prop_ro("end_line", mlirLocationFileLineColRangeGetEndLine)
@@ -3274,12 +3272,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("name"), nb::arg("childLoc") = nb::none(),
nb::arg("context") = nb::none(), kContextGetNameLocationDocString)
.def("is_a_name", mlirLocationIsAName)
- .def_prop_ro(
- "name_str",
- [](MlirLocation loc) {
- return mlirIdentifierStr(mlirLocationNameGetName(loc));
- },
- nb::sig("def name_str(self) -> str"))
+ .def_prop_ro("name_str",
+ [](MlirLocation loc) {
+ return mlirIdentifierStr(mlirLocationNameGetName(loc));
+ })
.def_prop_ro("child_loc",
[](PyLocation &self) {
return PyLocation(self.getContext(),
@@ -3453,15 +3449,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
return concreteOperation.getContext().getObject();
},
"Context that owns the Operation")
- .def_prop_ro(
- "name",
- [](PyOperationBase &self) {
- auto &concreteOperation = self.getOperation();
- concreteOperation.checkValid();
- MlirOperation operation = concreteOperation.get();
- return mlirIdentifierStr(mlirOperationGetName(operation));
- },
- nb::sig("def name(self) -> str"))
+ .def_prop_ro("name",
+ [](PyOperationBase &self) {
+ auto &concreteOperation = self.getOperation();
+ concreteOperation.checkValid();
+ MlirOperation operation = concreteOperation.get();
+ return mlirIdentifierStr(mlirOperationGetName(operation));
+ })
.def_prop_ro("operands",
[](PyOperationBase &self) {
return PyOpOperandList(self.getOperation().getRef());
@@ -3485,15 +3479,21 @@ void mlir::python::populateIRCore(nb::module_ &m) {
},
"Shortcut to get an op result if it has only one (throws an error "
"otherwise).")
- .def_prop_ro(
+ .def_prop_rw(
"location",
[](PyOperationBase &self) {
PyOperation &operation = self.getOperation();
return PyLocation(operation.getContext(),
mlirOperationGetLocation(operation.get()));
},
- "Returns the source location the operation was defined or derived "
- "from.")
+ [](PyOperationBase &self, const PyLocation &location) {
+ PyOperation &operation = self.getOperation();
+ mlirOperationSetLocation(operation.get(), location.get());
+ },
+ nb::for_getter("Returns the source location the operation was "
+ "defined or derived from."),
+ nb::for_setter("Sets the source location the operation was defined "
+ "or derived from."))
.def_prop_ro("parent",
[](PyOperationBase &self)
-> std::optional<nb::typed<nb::object, PyOperation>> {
@@ -3597,12 +3597,11 @@ void mlir::python::populateIRCore(nb::module_ &m) {
},
"Reports if the operation is attached to its parent block.")
.def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
- .def(
- "walk", &PyOperationBase::walk, nb::arg("callback"),
- nb::arg("walk_order") = MlirWalkPostOrder,
- // clang-format off
- nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder = " MAKE_MLIR_PYTHON_QUALNAME("ir.WalkOrder.POST_ORDER") ") -> None")
- // clang-format on
+ .def("walk", &PyOperationBase::walk, nb::arg("callback"),
+ nb::arg("walk_order") = MlirWalkPostOrder,
+ // clang-format off
+ nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder) -> None")
+ // clang-format on
);
nb::class_<PyOperation, PyOperationBase>(m, "Operation")
@@ -4118,7 +4117,6 @@ void mlir::python::populateIRCore(nb::module_ &m) {
[](PyNamedAttribute &self) {
return mlirIdentifierStr(self.namedAttr.name);
},
- nb::sig("def name(self) -> str"),
"The name of the NamedAttribute binding")
.def_prop_ro(
"attr",
@@ -4336,17 +4334,15 @@ void mlir::python::populateIRCore(nb::module_ &m) {
kValueReplaceAllUsesWithDocstring)
.def(
"replace_all_uses_except",
- [](MlirValue self, MlirValue with, PyOperation &exception) {
+ [](PyValue &self, PyValue &with, PyOperation &exception) {
MlirOperation exceptedUser = exception.get();
mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
},
nb::arg("with_"), nb::arg("exceptions"),
- nb::sig("def replace_all_uses_except(self, with_: Value, exceptions: "
- "Operation) -> None"),
kValueReplaceAllUsesExceptDocstring)
.def(
"replace_all_uses_except",
- [](MlirValue self, MlirValue with, nb::list exceptions) {
+ [](PyValue &self, PyValue &with, const nb::list &exceptions) {
// Convert Python list to a SmallVector of MlirOperations
llvm::SmallVector<MlirOperation> exceptionOps;
for (nb::handle exception : exceptions) {
@@ -4358,8 +4354,6 @@ void mlir::python::populateIRCore(nb::module_ &m) {
exceptionOps.data());
},
nb::arg("with_"), nb::arg("exceptions"),
- nb::sig("def replace_all_uses_except(self, with_: Value, exceptions: "
- "Sequence[Operation]) -> None"),
kValueReplaceAllUsesExceptDocstring)
.def(
"replace_all_uses_except",
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 598ae01..edbd73e 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -273,8 +273,7 @@ class DefaultingPyMlirContext
: public Defaulting<DefaultingPyMlirContext, PyMlirContext> {
public:
using Defaulting::Defaulting;
- static constexpr const char kTypeDescription[] =
- MAKE_MLIR_PYTHON_QUALNAME("ir.Context");
+ static constexpr const char kTypeDescription[] = "Context";
static PyMlirContext &resolve();
};
@@ -500,8 +499,7 @@ class DefaultingPyLocation
: public Defaulting<DefaultingPyLocation, PyLocation> {
public:
using Defaulting::Defaulting;
- static constexpr const char kTypeDescription[] =
- MAKE_MLIR_PYTHON_QUALNAME("ir.Location");
+ static constexpr const char kTypeDescription[] = "Location";
static PyLocation &resolve();
operator MlirLocation() const { return *get(); }
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 3488d92..34c5b8d 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -1010,7 +1010,7 @@ public:
},
nb::arg("elements"), nb::arg("context") = nb::none(),
// clang-format off
- nb::sig("def get_tuple(elements: Sequence[Type], context: mlir.ir.Context | None = None) -> TupleType"),
+ nb::sig("def get_tuple(elements: Sequence[Type], context: Context | None = None) -> TupleType"),
// clang-format on
"Create a tuple type");
c.def(
@@ -1070,7 +1070,7 @@ public:
},
nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
// clang-format off
- nb::sig("def get(inputs: Sequence[Type], results: Sequence[Type], context: mlir.ir.Context | None = None) -> FunctionType"),
+ nb::sig("def get(inputs: Sequence[Type], results: Sequence[Type], context: Context | None = None) -> FunctionType"),
// clang-format on
"Gets a FunctionType from a list of input and result types");
c.def_prop_ro(
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 52656138..a14f09f 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -115,9 +115,6 @@ NB_MODULE(_mlir, m) {
});
},
"typeid"_a, nb::kw_only(), "replace"_a = false,
- // clang-format off
- nb::sig("def register_type_caster(typeid: " MAKE_MLIR_PYTHON_QUALNAME("ir.TypeID") ", *, replace: bool = False) -> object"),
- // clang-format on
"Register a type caster for casting MLIR types to custom user types.");
m.def(
MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR,
@@ -130,9 +127,6 @@ NB_MODULE(_mlir, m) {
});
},
"typeid"_a, nb::kw_only(), "replace"_a = false,
- // clang-format off
- nb::sig("def register_value_caster(typeid: " MAKE_MLIR_PYTHON_QUALNAME("ir.TypeID") ", *, replace: bool = False) -> object"),
- // clang-format on
"Register a value caster for casting MLIR values to custom user values.");
// Define and populate IR submodule.
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index f18298e..836f44fd 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -127,7 +127,7 @@ public:
mlirPythonFrozenRewritePatternSetToCapsule(get()));
}
- static nb::object createFromCapsule(nb::object capsule) {
+ static nb::object createFromCapsule(const nb::object &capsule) {
MlirFrozenRewritePatternSet rawPm =
mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
if (rawPm.ptr == nullptr)
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index e9844a7..1881865 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -656,6 +656,10 @@ MlirLocation mlirOperationGetLocation(MlirOperation op) {
return wrap(unwrap(op)->getLoc());
}
+void mlirOperationSetLocation(MlirOperation op, MlirLocation loc) {
+ unwrap(op)->setLoc(unwrap(loc));
+}
+
MlirTypeID mlirOperationGetTypeID(MlirOperation op) {
if (auto info = unwrap(op)->getRegisteredInfo())
return wrap(info->getTypeID());
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 8ee6308..0d56259 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -259,22 +259,23 @@ void mlirIRRewriterDestroy(MlirRewriterBase rewriter) {
/// RewritePatternSet and FrozenRewritePatternSet API
//===----------------------------------------------------------------------===//
-inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) {
+static inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) {
assert(module.ptr && "unexpected null module");
return *(static_cast<mlir::RewritePatternSet *>(module.ptr));
}
-inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) {
+static inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) {
return {module};
}
-inline mlir::FrozenRewritePatternSet *
+static inline mlir::FrozenRewritePatternSet *
unwrap(MlirFrozenRewritePatternSet module) {
assert(module.ptr && "unexpected null module");
return static_cast<mlir::FrozenRewritePatternSet *>(module.ptr);
}
-inline MlirFrozenRewritePatternSet wrap(mlir::FrozenRewritePatternSet *module) {
+static inline MlirFrozenRewritePatternSet
+wrap(mlir::FrozenRewritePatternSet *module) {
return {module};
}
@@ -321,12 +322,12 @@ inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) {
//===----------------------------------------------------------------------===//
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
-inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) {
+static inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) {
assert(module.ptr && "unexpected null module");
return static_cast<mlir::PDLPatternModule *>(module.ptr);
}
-inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) {
+static inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) {
return {module};
}
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index cc6314c..a6f816a 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -465,6 +465,51 @@ struct AssumeAlignmentOpLowering
}
};
+struct DistinctObjectsOpLowering
+ : public ConvertOpToLLVMPattern<memref::DistinctObjectsOp> {
+ using ConvertOpToLLVMPattern<
+ memref::DistinctObjectsOp>::ConvertOpToLLVMPattern;
+ explicit DistinctObjectsOpLowering(const LLVMTypeConverter &converter)
+ : ConvertOpToLLVMPattern<memref::DistinctObjectsOp>(converter) {}
+
+ LogicalResult
+ matchAndRewrite(memref::DistinctObjectsOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ ValueRange operands = adaptor.getOperands();
+ if (operands.size() <= 1) {
+ // Fast path.
+ rewriter.replaceOp(op, operands);
+ return success();
+ }
+
+ Location loc = op.getLoc();
+ SmallVector<Value> ptrs;
+ for (auto [origOperand, newOperand] :
+ llvm::zip_equal(op.getOperands(), operands)) {
+ auto memrefType = cast<MemRefType>(origOperand.getType());
+ MemRefDescriptor memRefDescriptor(newOperand);
+ Value ptr = memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
+ memrefType);
+ ptrs.push_back(ptr);
+ }
+
+ auto cond =
+ LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(), 1);
+ // Generate separate_storage assumptions for each pair of pointers.
+ for (auto i : llvm::seq<size_t>(ptrs.size() - 1)) {
+ for (auto j : llvm::seq<size_t>(i + 1, ptrs.size())) {
+ Value ptr1 = ptrs[i];
+ Value ptr2 = ptrs[j];
+ LLVM::AssumeOp::create(rewriter, loc, cond,
+ LLVM::AssumeSeparateStorageTag{}, ptr1, ptr2);
+ }
+ }
+
+ rewriter.replaceOp(op, operands);
+ return success();
+ }
+};
+
// A `dealloc` is converted into a call to `free` on the underlying data buffer.
// The memref descriptor being an SSA value, there is no need to clean it up
// in any way.
@@ -1997,22 +2042,23 @@ void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
patterns.add<
AllocaOpLowering,
AllocaScopeOpLowering,
- AtomicRMWOpLowering,
AssumeAlignmentOpLowering,
+ AtomicRMWOpLowering,
ConvertExtractAlignedPointerAsIndex,
DimOpLowering,
+ DistinctObjectsOpLowering,
ExtractStridedMetadataOpLowering,
GenericAtomicRMWOpLowering,
GetGlobalMemrefOpLowering,
LoadOpLowering,
MemRefCastOpLowering,
- MemorySpaceCastOpLowering,
MemRefReinterpretCastOpLowering,
MemRefReshapeOpLowering,
+ MemorySpaceCastOpLowering,
PrefetchOpLowering,
RankOpLowering,
- ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
+ ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
StoreOpLowering,
SubViewOpLowering,
TransposeOpLowering,
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
index 035f197..399ccf3 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
@@ -267,9 +267,8 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
copyInfo.push_back(info);
}
// Create a call to the kernel and copy the data back.
- Operation *callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
- op, kernelFunc, ArrayRef<Value>());
- rewriter.setInsertionPointAfter(callOp);
+ rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, kernelFunc,
+ ArrayRef<Value>());
for (CopyInfo info : copyInfo)
copy(loc, info.src, info.dst, info.size, rewriter);
return success();
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 6f28849..0cb0bad 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -802,7 +802,6 @@ public:
ValueRange{paddedInput, fakeWindowDims}, filledEmptyTensor, strideAttr,
dilationAttr);
- rewriter.setInsertionPointAfter(op);
NanPropagationMode nanMode = op.getNanMode();
rewriter.replaceOp(op, resultOp);
diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
index f3e065a..9821a75 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
@@ -246,6 +246,6 @@ void SimplifyAffineMinMaxPass::runOnOperation() {
patterns.add<SimplifyAffineMaxOp, SimplifyAffineMinOp, SimplifyAffineApplyOp>(
func.getContext());
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
- if (failed(applyPatternsGreedily(func, std::move(frozenPatterns))))
+ if (failed(applyPatternsGreedily(func, frozenPatterns)))
return signalPassFailure();
}
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 7cfd6d3..898d76c 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1282,6 +1282,13 @@ OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
if (matchPattern(adaptor.getRhs(), m_OneFloat()))
return getLhs();
+ if (arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan |
+ arith::FastMathFlags::nsz)) {
+ // mulf(x, 0) -> 0
+ if (matchPattern(adaptor.getRhs(), m_AnyZeroFloat()))
+ return getRhs();
+ }
+
return constFoldBinaryOp<FloatAttr>(
adaptor.getOperands(),
[](const APFloat &a, const APFloat &b) { return a * b; });
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 7626d35..c64e10f5 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -123,7 +123,8 @@ void mlir::arith::populateEmulateUnsupportedFloatsLegality(
vector::OuterProductOp, vector::ScanOp>(
[&](Operation *op) { return converter.isLegal(op); });
target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
- arith::ConstantOp, vector::SplatOp, vector::BroadcastOp>();
+ arith::ConstantOp, arith::SelectOp, vector::SplatOp,
+ vector::BroadcastOp>();
}
void EmulateUnsupportedFloatsPass::runOnOperation() {
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 3f0b0ba..dd9b4c2 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -42,6 +42,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/LogicalResult.h"
@@ -273,32 +274,6 @@ void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
// BufferizeToAllocationOp
//===----------------------------------------------------------------------===//
-void transform::BufferizeToAllocationOp::build(OpBuilder &b,
- OperationState &result,
- Value target,
- Attribute memorySpace) {
- SmallVector<Type> resultTypes;
- resultTypes.push_back(b.getType<transform::AnyValueType>());
- resultTypes.push_back(b.getType<transform::AnyOpType>());
- return build(b, result,
- /*resultTypes=*/resultTypes,
- /*target=*/target,
- /*memory_space=*/memorySpace);
-}
-
-void transform::BufferizeToAllocationOp::build(OpBuilder &b,
- OperationState &result,
- Value target,
- int64_t memorySpace) {
- SmallVector<Type> resultTypes;
- resultTypes.push_back(b.getType<transform::AnyValueType>());
- resultTypes.push_back(b.getType<transform::AnyOpType>());
- return build(b, result,
- /*resultTypes=*/resultTypes,
- /*target=*/target,
- /*memory_space=*/b.getI64IntegerAttr(memorySpace));
-}
-
namespace {
class NewOpsListener : public RewriterBase::ForwardingListener {
public:
@@ -409,6 +384,95 @@ LogicalResult transform::BufferizeToAllocationOp::verify() {
}
//===----------------------------------------------------------------------===//
+// PromoteTensorOp
+//===----------------------------------------------------------------------===//
+
+/// Return true if the operand may be read from by its owner. This is currently
+/// very conservative and only looks inside linalg operations to prevent
+/// unintentional data loss.
+static bool mayBeRead(OpOperand &operand) {
+ auto linalgOp = dyn_cast<linalg::LinalgOp>(operand.getOwner());
+
+ // Be conservative about ops we cannot analyze deeper.
+ if (!linalgOp)
+ return true;
+
+ // Look inside linalg ops.
+ Value blockArgument = linalgOp.getMatchingBlockArgument(&operand);
+ return !blockArgument.use_empty();
+}
+
+/// Return true if the value may be read through any of its uses.
+static bool mayBeRead(Value value) {
+ // If the value has a reference semantics, it
+ // may be read through any alias...
+ if (!isa<TensorType, FloatType, IntegerType>(value.getType()))
+ return true;
+ return llvm::any_of(value.getUses(),
+ static_cast<bool (&)(OpOperand &)>(mayBeRead));
+}
+
+DiagnosedSilenceableFailure
+transform::PromoteTensorOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ SmallVector<Value> promoted;
+ for (Value tensor : state.getPayloadValues(getTensor())) {
+ auto type = dyn_cast<RankedTensorType>(tensor.getType());
+ if (!type) {
+ return emitSilenceableError() << "non-tensor type: " << tensor;
+ }
+
+ Operation *definingOp = tensor.getDefiningOp();
+ if (definingOp)
+ rewriter.setInsertionPointAfter(definingOp);
+ else
+ rewriter.setInsertionPointToStart(cast<BlockArgument>(tensor).getOwner());
+
+ // Check this before we emit operations using this value.
+ bool needsMaterialization = mayBeRead(tensor);
+
+ SmallVector<Value> dynamicDims;
+ llvm::SmallPtrSet<Operation *, 4> preservedOps;
+ for (auto [pos, dim] : llvm::enumerate(type.getShape())) {
+ if (!ShapedType::isDynamic(dim))
+ continue;
+ Value cst = rewriter.create<arith::ConstantIndexOp>(tensor.getLoc(), pos);
+ auto dimOp = rewriter.create<tensor::DimOp>(tensor.getLoc(), tensor, cst);
+ preservedOps.insert(dimOp);
+ dynamicDims.push_back(dimOp);
+ }
+ auto allocation = rewriter.create<bufferization::AllocTensorOp>(
+ tensor.getLoc(), type, dynamicDims);
+ // Set memory space if provided.
+ if (getMemorySpaceAttr())
+ allocation.setMemorySpaceAttr(getMemorySpaceAttr());
+ Value allocated = allocation;
+
+ // Only insert a materialization (typically bufferizes to a copy) when the
+ // value may be read from.
+ if (needsMaterialization) {
+ auto copy = rewriter.create<bufferization::MaterializeInDestinationOp>(
+ tensor.getLoc(), tensor, allocated);
+ preservedOps.insert(copy);
+ promoted.push_back(copy.getResult());
+ } else {
+ promoted.push_back(allocated);
+ }
+ rewriter.replaceAllUsesExcept(tensor, promoted.back(), preservedOps);
+ }
+ results.setValues(cast<OpResult>(getPromoted()), promoted);
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::PromoteTensorOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ transform::onlyReadsHandle(getTensorMutable(), effects);
+ transform::producesHandle(getOperation()->getOpResults(), effects);
+ transform::modifiesPayload(effects);
+}
+
+//===----------------------------------------------------------------------===//
// DecomposeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 3bd763e..05fc7cb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1622,12 +1622,12 @@ static void generateCollapsedIndexingRegion(
}
}
-void collapseOperandsAndResults(LinalgOp op,
- const CollapsingInfo &collapsingInfo,
- RewriterBase &rewriter,
- SmallVectorImpl<Value> &inputOperands,
- SmallVectorImpl<Value> &outputOperands,
- SmallVectorImpl<Type> &resultTypes) {
+static void collapseOperandsAndResults(LinalgOp op,
+ const CollapsingInfo &collapsingInfo,
+ RewriterBase &rewriter,
+ SmallVectorImpl<Value> &inputOperands,
+ SmallVectorImpl<Value> &outputOperands,
+ SmallVectorImpl<Type> &resultTypes) {
Location loc = op->getLoc();
inputOperands =
llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
@@ -1651,8 +1651,8 @@ void collapseOperandsAndResults(LinalgOp op,
/// Clone a `LinalgOp` to a collapsed version of same name
template <typename OpTy>
-OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
- const CollapsingInfo &collapsingInfo) {
+static OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
+ const CollapsingInfo &collapsingInfo) {
return nullptr;
}
@@ -1699,8 +1699,9 @@ GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter,
return collapsedOp;
}
-LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo,
- RewriterBase &rewriter) {
+static LinalgOp createCollapsedOp(LinalgOp op,
+ const CollapsingInfo &collapsingInfo,
+ RewriterBase &rewriter) {
if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo);
} else {
diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
index ff62b51..8899c3a 100644
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRMathTransforms
ExpandOps.cpp
ExtendToSupportedTypes.cpp
PolynomialApproximation.cpp
+ SincosFusion.cpp
UpliftToFMA.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp b/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp
new file mode 100644
index 0000000..69407df
--- /dev/null
+++ b/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp
@@ -0,0 +1,80 @@
+//===- SincosFusion.cpp - Fuse sin/cos into sincos -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::math;
+
+namespace {
+
+/// Fuse a math.sin and math.cos in the same block that use the same operand and
+/// have identical fastmath flags into a single math.sincos.
+struct SincosFusionPattern : OpRewritePattern<math::SinOp> {
+ using Base::Base;
+
+ LogicalResult matchAndRewrite(math::SinOp sinOp,
+ PatternRewriter &rewriter) const override {
+ Value operand = sinOp.getOperand();
+ mlir::arith::FastMathFlags sinFastMathFlags = sinOp.getFastmath();
+
+ math::CosOp cosOp = nullptr;
+ sinOp->getBlock()->walk([&](math::CosOp op) {
+ if (op.getOperand() == operand && op.getFastmath() == sinFastMathFlags) {
+ cosOp = op;
+ return WalkResult::interrupt();
+ }
+ return WalkResult::advance();
+ });
+
+ if (!cosOp)
+ return failure();
+
+ Operation *firstOp = sinOp->isBeforeInBlock(cosOp) ? sinOp.getOperation()
+ : cosOp.getOperation();
+ rewriter.setInsertionPoint(firstOp);
+
+ Type elemType = sinOp.getType();
+ auto sincos = math::SincosOp::create(rewriter, firstOp->getLoc(),
+ TypeRange{elemType, elemType}, operand,
+ sinOp.getFastmathAttr());
+
+ rewriter.replaceOp(sinOp, sincos.getSin());
+ rewriter.replaceOp(cosOp, sincos.getCos());
+ return success();
+ }
+};
+
+} // namespace
+
+namespace mlir::math {
+#define GEN_PASS_DEF_MATHSINCOSFUSIONPASS
+#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
+} // namespace mlir::math
+
+namespace {
+
+struct MathSincosFusionPass final
+ : math::impl::MathSincosFusionPassBase<MathSincosFusionPass> {
+ using MathSincosFusionPassBase::MathSincosFusionPassBase;
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ patterns.add<SincosFusionPattern>(&getContext());
+
+ GreedyRewriteConfig config;
+ if (failed(
+ applyPatternsGreedily(getOperation(), std::move(patterns), config)))
+ return signalPassFailure();
+ }
+};
+
+} // namespace
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 349b4de..e9bdcda 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -607,6 +607,29 @@ AssumeAlignmentOp::bubbleDownCasts(OpBuilder &builder) {
}
//===----------------------------------------------------------------------===//
+// DistinctObjectsOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult DistinctObjectsOp::verify() {
+ if (getOperandTypes() != getResultTypes())
+ return emitOpError("operand types and result types must match");
+
+ if (getOperandTypes().empty())
+ return emitOpError("expected at least one operand");
+
+ return success();
+}
+
+LogicalResult DistinctObjectsOp::inferReturnTypes(
+ MLIRContext * /*context*/, std::optional<Location> /*location*/,
+ ValueRange operands, DictionaryAttr /*attributes*/,
+ OpaqueProperties /*properties*/, RegionRange /*regions*/,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ llvm::copy(operands.getTypes(), std::back_inserter(inferredReturnTypes));
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index f01ad05..5672942 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -33,6 +33,7 @@
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/bit.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
+#include "llvm/Support/InterleavedRange.h"
#include <cstddef>
#include <iterator>
#include <optional>
@@ -77,6 +78,232 @@ struct LLVMPointerPointerLikeModel
};
} // namespace
+/// Generate a name of a canonical loop nest of the format
+/// `<prefix>(_r<idx>_s<idx>)*`. Hereby, `_r<idx>` identifies the region
+/// argument index of an operation that has multiple regions, if the operation
+/// has multiple regions.
+/// `_s<idx>` identifies the position of an operation within a region, where
+/// only operations that may potentially contain loops ("container operations"
+/// i.e. have region arguments) are counted. Again, it is omitted if there is
+/// only one such operation in a region. If there are canonical loops nested
+/// inside each other, also may also use the format `_d<num>` where <num> is the
+/// nesting depth of the loop.
+///
+/// The generated name is a best-effort to make canonical loop unique within an
+/// SSA namespace. This also means that regions with IsolatedFromAbove property
+/// do not consider any parents or siblings.
+static std::string generateLoopNestingName(StringRef prefix,
+ CanonicalLoopOp op) {
+ struct Component {
+ /// If true, this component describes a region operand of an operation (the
+ /// operand's owner) If false, this component describes an operation located
+ /// in a parent region
+ bool isRegionArgOfOp;
+ bool skip = false;
+ bool isUnique = false;
+
+ size_t idx;
+ Operation *op;
+ Region *parentRegion;
+ size_t loopDepth;
+
+ Operation *&getOwnerOp() {
+ assert(isRegionArgOfOp && "Must describe a region operand");
+ return op;
+ }
+ size_t &getArgIdx() {
+ assert(isRegionArgOfOp && "Must describe a region operand");
+ return idx;
+ }
+
+ Operation *&getContainerOp() {
+ assert(!isRegionArgOfOp && "Must describe a operation of a region");
+ return op;
+ }
+ size_t &getOpPos() {
+ assert(!isRegionArgOfOp && "Must describe a operation of a region");
+ return idx;
+ }
+ bool isLoopOp() const {
+ assert(!isRegionArgOfOp && "Must describe a operation of a region");
+ return isa<CanonicalLoopOp>(op);
+ }
+ Region *&getParentRegion() {
+ assert(!isRegionArgOfOp && "Must describe a operation of a region");
+ return parentRegion;
+ }
+ size_t &getLoopDepth() {
+ assert(!isRegionArgOfOp && "Must describe a operation of a region");
+ return loopDepth;
+ }
+
+ void skipIf(bool v = true) { skip = skip || v; }
+ };
+
+ // List of ancestors, from inner to outer.
+ // Alternates between
+ // * region argument of an operation
+ // * operation within a region
+ SmallVector<Component> components;
+
+ // Gather a list of parent regions and operations, and the position within
+ // their parent
+ Operation *o = op.getOperation();
+ while (o) {
+ // Operation within a region
+ Region *r = o->getParentRegion();
+ if (!r)
+ break;
+
+ llvm::ReversePostOrderTraversal<Block *> traversal(&r->getBlocks().front());
+ size_t idx = 0;
+ bool found = false;
+ size_t sequentialIdx = -1;
+ bool isOnlyContainerOp = true;
+ for (Block *b : traversal) {
+ for (Operation &op : *b) {
+ if (&op == o && !found) {
+ sequentialIdx = idx;
+ found = true;
+ }
+ if (op.getNumRegions()) {
+ idx += 1;
+ if (idx > 1)
+ isOnlyContainerOp = false;
+ }
+ if (found && !isOnlyContainerOp)
+ break;
+ }
+ }
+
+ Component &containerOpInRegion = components.emplace_back();
+ containerOpInRegion.isRegionArgOfOp = false;
+ containerOpInRegion.isUnique = isOnlyContainerOp;
+ containerOpInRegion.getContainerOp() = o;
+ containerOpInRegion.getOpPos() = sequentialIdx;
+ containerOpInRegion.getParentRegion() = r;
+
+ Operation *parent = r->getParentOp();
+
+ // Region argument of an operation
+ Component &regionArgOfOperation = components.emplace_back();
+ regionArgOfOperation.isRegionArgOfOp = true;
+ regionArgOfOperation.isUnique = true;
+ regionArgOfOperation.getArgIdx() = 0;
+ regionArgOfOperation.getOwnerOp() = parent;
+
+ // The IsolatedFromAbove trait of the parent operation implies that each
+ // individual region argument has its own separate namespace, so no
+ // ambiguity.
+ if (!parent || parent->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>())
+ break;
+
+ // Component only needed if operation has multiple region operands. Region
+ // arguments may be optional, but we currently do not consider this.
+ if (parent->getRegions().size() > 1) {
+ auto getRegionIndex = [](Operation *o, Region *r) {
+ for (auto [idx, region] : llvm::enumerate(o->getRegions())) {
+ if (&region == r)
+ return idx;
+ }
+ llvm_unreachable("Region not child of its parent operation");
+ };
+ regionArgOfOperation.isUnique = false;
+ regionArgOfOperation.getArgIdx() = getRegionIndex(parent, r);
+ }
+
+ // next parent
+ o = parent;
+ }
+
+ // Determine whether a region-argument component is not needed
+ for (Component &c : components)
+ c.skipIf(c.isRegionArgOfOp && c.isUnique);
+
+ // Find runs of nested loops and determine each loop's depth in the loop nest
+ size_t numSurroundingLoops = 0;
+ for (Component &c : llvm::reverse(components)) {
+ if (c.skip)
+ continue;
+
+ // non-skipped multi-argument operands interrupt the loop nest
+ if (c.isRegionArgOfOp) {
+ numSurroundingLoops = 0;
+ continue;
+ }
+
+ // Multiple loops in a region means each of them is the outermost loop of a
+ // new loop nest
+ if (!c.isUnique)
+ numSurroundingLoops = 0;
+
+ c.getLoopDepth() = numSurroundingLoops;
+
+ // Next loop is surrounded by one more loop
+ if (isa<CanonicalLoopOp>(c.getContainerOp()))
+ numSurroundingLoops += 1;
+ }
+
+ // In loop nests, skip all but the innermost loop that contains the depth
+ // number
+ bool isLoopNest = false;
+ for (Component &c : components) {
+ if (c.skip || c.isRegionArgOfOp)
+ continue;
+
+ if (!isLoopNest && c.getLoopDepth() >= 1) {
+ // Innermost loop of a loop nest of at least two loops
+ isLoopNest = true;
+ } else if (isLoopNest) {
+ // Non-innermost loop of a loop nest
+ c.skipIf(c.isUnique);
+
+ // If there is no surrounding loop left, this must have been the outermost
+ // loop; leave loop-nest mode for the next iteration
+ if (c.getLoopDepth() == 0)
+ isLoopNest = false;
+ }
+ }
+
+ // Skip non-loop unambiguous regions (but they should interrupt loop nests, so
+ // we mark them as skipped only after computing loop nests)
+ for (Component &c : components)
+ c.skipIf(!c.isRegionArgOfOp && c.isUnique &&
+ !isa<CanonicalLoopOp>(c.getContainerOp()));
+
+ // Components can be skipped if they are already disambiguated by their parent
+ // (or does not have a parent)
+ bool newRegion = true;
+ for (Component &c : llvm::reverse(components)) {
+ c.skipIf(newRegion && c.isUnique);
+
+ // non-skipped components disambiguate unique children
+ if (!c.skip)
+ newRegion = true;
+
+ // ...except canonical loops that need a suffix for each nest
+ if (!c.isRegionArgOfOp && c.getContainerOp())
+ newRegion = false;
+ }
+
+ // Compile the nesting name string
+ SmallString<64> Name{prefix};
+ llvm::raw_svector_ostream NameOS(Name);
+ for (auto &c : llvm::reverse(components)) {
+ if (c.skip)
+ continue;
+
+ if (c.isRegionArgOfOp)
+ NameOS << "_r" << c.getArgIdx();
+ else if (c.getLoopDepth() >= 1)
+ NameOS << "_d" << c.getLoopDepth();
+ else
+ NameOS << "_s" << c.getOpPos();
+ }
+
+ return NameOS.str().str();
+}
+
void OpenMPDialect::initialize() {
addOperations<
#define GET_OP_LIST
@@ -182,7 +409,7 @@ static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
}
template <typename ClauseAttr>
-void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
+static void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
p << stringifyEnum(attr.getValue());
}
@@ -1511,8 +1738,8 @@ static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
//===----------------------------------------------------------------------===//
// Helper function to get bitwise AND of `value` and 'flag'
-uint64_t mapTypeToBitFlag(uint64_t value,
- llvm::omp::OpenMPOffloadMappingFlags flag) {
+static uint64_t mapTypeToBitFlag(uint64_t value,
+ llvm::omp::OpenMPOffloadMappingFlags flag) {
return value & llvm::to_underlying(flag);
}
@@ -3159,6 +3386,9 @@ void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
Value result = getResult();
auto [newCli, gen, cons] = decodeCli(result);
+ // Structured binding `gen` cannot be captured in lambdas before C++20
+ OpOperand *generator = gen;
+
// Derive the CLI variable name from its generator:
// * "canonloop" for omp.canonical_loop
// * custom name for loop transformation generatees
@@ -3172,71 +3402,29 @@ void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
cliName =
TypeSwitch<Operation *, std::string>(gen->getOwner())
.Case([&](CanonicalLoopOp op) {
- // Find the canonical loop nesting: For each ancestor add a
- // "+_r<idx>" suffix (in reverse order)
- SmallVector<std::string> components;
- Operation *o = op.getOperation();
- while (o) {
- if (o->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>())
- break;
-
- Region *r = o->getParentRegion();
- if (!r)
- break;
-
- auto getSequentialIndex = [](Region *r, Operation *o) {
- llvm::ReversePostOrderTraversal<Block *> traversal(
- &r->getBlocks().front());
- size_t idx = 0;
- for (Block *b : traversal) {
- for (Operation &op : *b) {
- if (&op == o)
- return idx;
- // Only consider operations that are containers as
- // possible children
- if (!op.getRegions().empty())
- idx += 1;
- }
- }
- llvm_unreachable("Operation not part of the region");
- };
- size_t sequentialIdx = getSequentialIndex(r, o);
- components.push_back(("s" + Twine(sequentialIdx)).str());
-
- Operation *parent = r->getParentOp();
- if (!parent)
- break;
-
- // If the operation has more than one region, also count in
- // which of the regions
- if (parent->getRegions().size() > 1) {
- auto getRegionIndex = [](Operation *o, Region *r) {
- for (auto [idx, region] :
- llvm::enumerate(o->getRegions())) {
- if (&region == r)
- return idx;
- }
- llvm_unreachable("Region not child its parent operation");
- };
- size_t regionIdx = getRegionIndex(parent, r);
- components.push_back(("r" + Twine(regionIdx)).str());
- }
-
- // next parent
- o = parent;
- }
-
- SmallString<64> Name("canonloop");
- for (const std::string &s : reverse(components)) {
- Name += '_';
- Name += s;
- }
-
- return Name;
+ return generateLoopNestingName("canonloop", op);
})
.Case([&](UnrollHeuristicOp op) -> std::string {
llvm_unreachable("heuristic unrolling does not generate a loop");
})
+ .Case([&](TileOp op) -> std::string {
+ auto [generateesFirst, generateesCount] =
+ op.getGenerateesODSOperandIndexAndLength();
+ unsigned firstGrid = generateesFirst;
+ unsigned firstIntratile = generateesFirst + generateesCount / 2;
+ unsigned end = generateesFirst + generateesCount;
+ unsigned opnum = generator->getOperandNumber();
+ // In the OpenMP apply and looprange clauses, indices are 1-based
+ if (firstGrid <= opnum && opnum < firstIntratile) {
+ unsigned gridnum = opnum - firstGrid + 1;
+ return ("grid" + Twine(gridnum)).str();
+ }
+ if (firstIntratile <= opnum && opnum < end) {
+ unsigned intratilenum = opnum - firstIntratile + 1;
+ return ("intratile" + Twine(intratilenum)).str();
+ }
+ llvm_unreachable("Unexpected generatee argument");
+ })
.Default([&](Operation *op) {
assert(false && "TODO: Custom name for this operation");
return "transformed";
@@ -3323,7 +3511,8 @@ void CanonicalLoopOp::getAsmBlockNames(OpAsmSetBlockNameFn setNameFn) {
void CanonicalLoopOp::getAsmBlockArgumentNames(Region &region,
OpAsmSetValueNameFn setNameFn) {
- setNameFn(region.getArgument(0), "iv");
+ std::string ivName = generateLoopNestingName("iv", *this);
+ setNameFn(region.getArgument(0), ivName);
}
void CanonicalLoopOp::print(OpAsmPrinter &p) {
@@ -3465,6 +3654,138 @@ UnrollHeuristicOp::getGenerateesODSOperandIndexAndLength() {
}
//===----------------------------------------------------------------------===//
+// TileOp
+//===----------------------------------------------------------------------===//
+
+static void printLoopTransformClis(OpAsmPrinter &p, TileOp op,
+ OperandRange generatees,
+ OperandRange applyees) {
+ if (!generatees.empty())
+ p << '(' << llvm::interleaved(generatees) << ')';
+
+ if (!applyees.empty())
+ p << " <- (" << llvm::interleaved(applyees) << ')';
+}
+
+static ParseResult parseLoopTransformClis(
+ OpAsmParser &parser,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &generateesOperands,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &applyeesOperands) {
+ if (parser.parseOptionalLess()) {
+ // Syntax 1: generatees present
+
+ if (parser.parseOperandList(generateesOperands,
+ mlir::OpAsmParser::Delimiter::Paren))
+ return failure();
+
+ if (parser.parseLess())
+ return failure();
+ } else {
+ // Syntax 2: generatees omitted
+ }
+
+ // Parse `<-` (`<` has already been parsed)
+ if (parser.parseMinus())
+ return failure();
+
+ if (parser.parseOperandList(applyeesOperands,
+ mlir::OpAsmParser::Delimiter::Paren))
+ return failure();
+
+ return success();
+}
+
+LogicalResult TileOp::verify() {
+ if (getApplyees().empty())
+ return emitOpError() << "must apply to at least one loop";
+
+ if (getSizes().size() != getApplyees().size())
+ return emitOpError() << "there must be one tile size for each applyee";
+
+ if (!getGeneratees().empty() &&
+ 2 * getSizes().size() != getGeneratees().size())
+ return emitOpError()
+ << "expecting two times the number of generatees than applyees";
+
+ DenseSet<Value> parentIVs;
+
+ Value parent = getApplyees().front();
+ for (auto &&applyee : llvm::drop_begin(getApplyees())) {
+ auto [parentCreate, parentGen, parentCons] = decodeCli(parent);
+ auto [create, gen, cons] = decodeCli(applyee);
+
+ if (!parentGen)
+ return emitOpError() << "applyee CLI has no generator";
+
+ auto parentLoop = dyn_cast_or_null<CanonicalLoopOp>(parentGen->getOwner());
+ if (!parentGen)
+ return emitOpError()
+ << "currently only supports omp.canonical_loop as applyee";
+
+ parentIVs.insert(parentLoop.getInductionVar());
+
+ if (!gen)
+ return emitOpError() << "applyee CLI has no generator";
+ auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
+ if (!loop)
+ return emitOpError()
+ << "currently only supports omp.canonical_loop as applyee";
+
+ // Canonical loop must be perfectly nested, i.e. the body of the parent must
+ // only contain the omp.canonical_loop of the nested loops, and
+ // omp.terminator
+ bool isPerfectlyNested = [&]() {
+ auto &parentBody = parentLoop.getRegion();
+ if (!parentBody.hasOneBlock())
+ return false;
+ auto &parentBlock = parentBody.getBlocks().front();
+
+ auto nestedLoopIt = parentBlock.begin();
+ if (nestedLoopIt == parentBlock.end() ||
+ (&*nestedLoopIt != loop.getOperation()))
+ return false;
+
+ auto termIt = std::next(nestedLoopIt);
+ if (termIt == parentBlock.end() || !isa<TerminatorOp>(termIt))
+ return false;
+
+ if (std::next(termIt) != parentBlock.end())
+ return false;
+
+ return true;
+ }();
+ if (!isPerfectlyNested)
+ return emitOpError() << "tiled loop nest must be perfectly nested";
+
+ if (parentIVs.contains(loop.getTripCount()))
+ return emitOpError() << "tiled loop nest must be rectangular";
+
+ parent = applyee;
+ }
+
+ // TODO: The tile sizes must be computed before the loop, but checking this
+ // requires dominance analysis. For instance:
+ //
+ // %canonloop = omp.new_cli
+ // omp.canonical_loop(%canonloop) %iv : i32 in range(%tc) {
+ // // write to %x
+ // omp.terminator
+ // }
+ // %ts = llvm.load %x
+ // omp.tile <- (%canonloop) sizes(%ts : i32)
+
+ return success();
+}
+
+std::pair<unsigned, unsigned> TileOp ::getApplyeesODSOperandIndexAndLength() {
+ return getODSOperandIndexAndLength(odsIndex_applyees);
+}
+
+std::pair<unsigned, unsigned> TileOp::getGenerateesODSOperandIndexAndLength() {
+ return getODSOperandIndexAndLength(odsIndex_generatees);
+}
+
+//===----------------------------------------------------------------------===//
// Critical construct (2.17.1)
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 132ed81..3385b2a 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -616,11 +616,10 @@ DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
if (diag.succeeded()) {
// Tracking failure is the only failure.
return trackingFailure;
- } else {
- diag.attachNote() << "tracking listener also failed: "
- << trackingFailure.getMessage();
- (void)trackingFailure.silence();
}
+ diag.attachNote() << "tracking listener also failed: "
+ << trackingFailure.getMessage();
+ (void)trackingFailure.silence();
}
if (!diag.succeeded())
diff --git a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
index 842e880..c627158 100644
--- a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
+++ b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
@@ -6,13 +6,24 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h"
using namespace mlir;
+static ParseResult parseAlternativesOpSelectedRegion(
+ OpAsmParser &parser, IntegerAttr &selectedRegionAttr,
+ std::optional<OpAsmParser::UnresolvedOperand> &selectedRegionParam);
+
+static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer,
+ Operation *op,
+ IntegerAttr selectedRegionAttr,
+ Value selectedRegionParam);
+
#define GET_OP_CLASSES
#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp.inc"
@@ -57,3 +68,176 @@ LogicalResult transform::tune::KnobOp::verify() {
return success();
}
+
+//===----------------------------------------------------------------------===//
+// AlternativesOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseAlternativesOpSelectedRegion(
+ OpAsmParser &parser, IntegerAttr &selectedRegionAttr,
+ std::optional<OpAsmParser::UnresolvedOperand> &selectedRegionParam) {
+ size_t selectedRegionIdx;
+ OptionalParseResult attrParseRes =
+ parser.parseOptionalInteger(selectedRegionIdx);
+ if (attrParseRes.has_value()) {
+ if (failed(*attrParseRes))
+ return failure();
+
+ selectedRegionAttr = parser.getBuilder().getIndexAttr(selectedRegionIdx);
+ return success();
+ }
+
+ OpAsmParser::UnresolvedOperand param;
+ auto paramParseRes = parser.parseOptionalOperand(param);
+ if (paramParseRes.has_value()) {
+ if (failed(*paramParseRes))
+ return failure();
+
+ selectedRegionParam = param;
+ return success();
+ }
+
+ return parser.emitError(parser.getCurrentLocation())
+ << "expected either an integer attribute or a transform.param operand";
+}
+
+static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer,
+ Operation *op,
+ IntegerAttr selectedRegionAttr,
+ Value selectedRegionParam) {
+ if (selectedRegionAttr)
+ printer << selectedRegionAttr.getValue();
+ if (selectedRegionParam)
+ printer << selectedRegionParam;
+}
+
+OperandRange transform::tune::AlternativesOp::getEntrySuccessorOperands(
+ RegionBranchPoint point) {
+ // No operands will be forwarded to the region(s).
+ return getOperands().slice(0, 0);
+}
+
+void transform::tune::AlternativesOp::getSuccessorRegions(
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+ if (point.isParent())
+ if (auto selectedRegionIdx = getSelectedRegionAttr())
+ regions.emplace_back(
+ &getAlternatives()[selectedRegionIdx->getSExtValue()],
+ Block::BlockArgListType());
+ else
+ for (Region &alternative : getAlternatives())
+ regions.emplace_back(&alternative, Block::BlockArgListType());
+ else
+ regions.emplace_back(getOperation()->getResults());
+}
+
+void transform::tune::AlternativesOp::getRegionInvocationBounds(
+ ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
+ (void)operands;
+ bounds.reserve(getNumRegions());
+
+ if (auto selectedRegionIdx = getSelectedRegionAttr()) {
+ bounds.resize(getNumRegions(), InvocationBounds(0, 0));
+ bounds[selectedRegionIdx->getSExtValue()] = InvocationBounds(1, 1);
+ } else {
+ bounds.resize(getNumRegions(), InvocationBounds(0, 1));
+ }
+}
+
+void transform::tune::AlternativesOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ onlyReadsHandle(getSelectedRegionParamMutable(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
+ // TODO: should effects from regions be forwarded?
+}
+
+DiagnosedSilenceableFailure
+transform::tune::AlternativesOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ std::optional<size_t> selectedRegionIdx;
+
+ if (auto selectedRegionAttr = getSelectedRegionAttr())
+ selectedRegionIdx = selectedRegionAttr->getSExtValue();
+
+ if (Value selectedRegionParam = getSelectedRegionParam()) {
+ ArrayRef<Attribute> associatedAttrs = state.getParams(selectedRegionParam);
+ IntegerAttr selectedRegionAttr;
+ if (associatedAttrs.size() != 1 ||
+ !(selectedRegionAttr = dyn_cast<IntegerAttr>(associatedAttrs[0])))
+ return emitDefiniteFailure()
+ << "param should hold exactly one integer attribute, got: "
+ << associatedAttrs[0];
+ selectedRegionIdx = selectedRegionAttr.getValue().getSExtValue();
+ }
+
+ if (!selectedRegionIdx)
+ return emitDefiniteFailure() << "non-deterministic choice " << getName()
+ << " is only resolved through providing a "
+ "`selected_region` attr/param";
+
+ if (*selectedRegionIdx < 0 || *selectedRegionIdx >= getNumRegions())
+ return emitDefiniteFailure()
+ << "'selected_region' attribute/param specifies region at index "
+ << *selectedRegionIdx << " while op has only " << getNumRegions()
+ << " regions";
+
+ Region &selectedRegion = getRegion(*selectedRegionIdx);
+ auto scope = state.make_region_scope(selectedRegion);
+ Block &block = selectedRegion.front();
+ // Apply the region's ops one by one.
+ for (Operation &transform : block.without_terminator()) {
+ DiagnosedSilenceableFailure result =
+ state.applyTransform(cast<transform::TransformOpInterface>(transform));
+ if (result.isDefiniteFailure())
+ return result;
+
+ if (result.isSilenceableFailure()) {
+ for (const auto &res : getResults())
+ results.set(res, {});
+ return result;
+ }
+ }
+ // Forward the operation mapping for values yielded from the region to the
+ // values produced by the alternatives op.
+ transform::detail::forwardTerminatorOperands(&block, state, results);
+ return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult transform::tune::AlternativesOp::verify() {
+ for (auto *region : getRegions()) {
+ auto yieldTerminator =
+ llvm::dyn_cast_if_present<transform::YieldOp>(region->front().back());
+ if (!yieldTerminator)
+ return emitOpError() << "expected '"
+ << transform::YieldOp::getOperationName()
+ << "' as terminator";
+
+ if (yieldTerminator->getNumOperands() != getNumResults())
+ return yieldTerminator.emitOpError()
+ << "expected terminator to have as many operands as the parent op "
+ "has results";
+
+ for (auto [i, operandType, resultType] : llvm::zip_equal(
+ llvm::seq<unsigned>(0, yieldTerminator->getNumOperands()),
+ yieldTerminator->getOperands().getType(), getResultTypes())) {
+ if (operandType == resultType)
+ continue;
+ return yieldTerminator.emitOpError()
+ << "the type of the terminator operand #" << i
+ << " must match the type of the corresponding parent op result ("
+ << operandType << " vs " << resultType << ")";
+ }
+ }
+
+ if (auto selectedRegionAttr = getSelectedRegionAttr()) {
+ size_t regionIdx = selectedRegionAttr->getSExtValue();
+ if (regionIdx < 0 || regionIdx >= getNumRegions())
+ return emitOpError()
+ << "'selected_region' attribute specifies region at index "
+ << regionIdx << " while op has only " << getNumRegions()
+ << " regions";
+ }
+
+ return success();
+}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index eb46869..b0132e8 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -580,7 +580,7 @@ namespace {
// ElideSingleElementReduction for ReduceOp.
struct ElideUnitDimsInMultiDimReduction
: public OpRewritePattern<MultiDimReductionOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
PatternRewriter &rewriter) const override {
@@ -730,7 +730,7 @@ std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
namespace {
struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(ReductionOp reductionOp,
PatternRewriter &rewriter) const override {
@@ -2197,7 +2197,7 @@ namespace {
// Pattern to rewrite a ExtractOp(Broadcast) -> Broadcast.
class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(ExtractOp extractOp,
PatternRewriter &rewriter) const override {
@@ -2220,7 +2220,7 @@ public:
// Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask.
class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(ExtractOp extractOp,
PatternRewriter &rewriter) const override {
@@ -2546,7 +2546,7 @@ rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp,
class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(FromElementsOp fromElements,
PatternRewriter &rewriter) const override {
@@ -2938,7 +2938,7 @@ namespace {
// Fold broadcast1(broadcast2(x)) into broadcast1(x).
struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
PatternRewriter &rewriter) const override {
@@ -3109,7 +3109,7 @@ namespace {
// Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
// to a broadcast.
struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
PatternRewriter &rewriter) const override {
@@ -3165,7 +3165,7 @@ static Value getScalarSplatSource(Value value) {
/// Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v).
class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(ShuffleOp op,
PatternRewriter &rewriter) const override {
@@ -3182,7 +3182,7 @@ public:
/// vector.interleave.
class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(ShuffleOp op,
PatternRewriter &rewriter) const override {
@@ -3326,7 +3326,7 @@ namespace {
// broadcast.
class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(InsertOp insertOp,
PatternRewriter &rewriter) const override {
@@ -3344,7 +3344,7 @@ public:
/// Pattern to rewrite a insert(splat-like(v), splat-like(v)) as broadcast(v).
class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(InsertOp op,
PatternRewriter &rewriter) const override {
@@ -3380,7 +3380,7 @@ public:
/// %result = vector.from_elements %c1, %c2 : vector<2xi32>
class InsertChainFullyInitialized final : public OpRewritePattern<InsertOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(InsertOp op,
PatternRewriter &rewriter) const override {
@@ -3748,7 +3748,7 @@ namespace {
class FoldInsertStridedSliceSplat final
: public OpRewritePattern<InsertStridedSliceOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
PatternRewriter &rewriter) const override {
@@ -3768,7 +3768,7 @@ public:
class FoldInsertStridedSliceOfExtract final
: public OpRewritePattern<InsertStridedSliceOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
PatternRewriter &rewriter) const override {
@@ -3798,7 +3798,7 @@ public:
class InsertStridedSliceConstantFolder final
: public OpRewritePattern<InsertStridedSliceOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
// Do not create constants with more than `vectorSizeFoldThreashold` elements,
// unless the source vector constant has a single use.
@@ -4250,7 +4250,7 @@ namespace {
// %mask = vector.create_mask %new_ub : vector<8xi1>
class StridedSliceCreateMaskFolder final
: public OpRewritePattern<ExtractStridedSliceOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
public:
LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
@@ -4310,7 +4310,7 @@ public:
class StridedSliceConstantMaskFolder final
: public OpRewritePattern<ExtractStridedSliceOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
PatternRewriter &rewriter) const override {
@@ -4365,7 +4365,7 @@ public:
class StridedSliceBroadcast final
: public OpRewritePattern<ExtractStridedSliceOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
PatternRewriter &rewriter) const override {
@@ -4416,7 +4416,7 @@ public:
/// Rewrite extract_strided_slice(splat-like(v)) with broadcast(v).
class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
PatternRewriter &rewriter) const override {
@@ -4448,7 +4448,7 @@ public:
class ContiguousExtractStridedSliceToExtract final
: public OpRewritePattern<ExtractStridedSliceOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
PatternRewriter &rewriter) const override {
@@ -5023,7 +5023,7 @@ namespace {
/// ```
struct TransferReadAfterWriteToBroadcast
: public OpRewritePattern<TransferReadOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(TransferReadOp readOp,
PatternRewriter &rewriter) const override {
@@ -5458,7 +5458,7 @@ namespace {
/// any other uses.
class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(TransferWriteOp writeOp,
PatternRewriter &rewriter) const override {
if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
@@ -5514,7 +5514,7 @@ public:
struct SwapExtractSliceOfTransferWrite
: public OpRewritePattern<tensor::InsertSliceOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
PatternRewriter &rewriter) const override {
@@ -5737,7 +5737,7 @@ LogicalResult MaskedLoadOp::verify() {
namespace {
class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(MaskedLoadOp load,
PatternRewriter &rewriter) const override {
switch (getMaskFormat(load.getMask())) {
@@ -5794,7 +5794,7 @@ LogicalResult MaskedStoreOp::verify() {
namespace {
class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(MaskedStoreOp store,
PatternRewriter &rewriter) const override {
switch (getMaskFormat(store.getMask())) {
@@ -5890,7 +5890,7 @@ static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
namespace {
class GatherFolder final : public OpRewritePattern<GatherOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(GatherOp gather,
PatternRewriter &rewriter) const override {
switch (getMaskFormat(gather.getMask())) {
@@ -5910,7 +5910,7 @@ public:
/// maskedload. Only 1D fixed vectors are supported for now.
class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(GatherOp op,
PatternRewriter &rewriter) const override {
if (!isa<MemRefType>(op.getBase().getType()))
@@ -5962,7 +5962,7 @@ LogicalResult ScatterOp::verify() {
namespace {
class ScatterFolder final : public OpRewritePattern<ScatterOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(ScatterOp scatter,
PatternRewriter &rewriter) const override {
switch (getMaskFormat(scatter.getMask())) {
@@ -5982,7 +5982,7 @@ public:
/// maskedstore. Only 1D fixed vectors are supported for now.
class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(ScatterOp op,
PatternRewriter &rewriter) const override {
if (failed(isZeroBasedContiguousSeq(op.getIndices())))
@@ -6030,7 +6030,7 @@ LogicalResult ExpandLoadOp::verify() {
namespace {
class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(ExpandLoadOp expand,
PatternRewriter &rewriter) const override {
switch (getMaskFormat(expand.getMask())) {
@@ -6081,7 +6081,7 @@ LogicalResult CompressStoreOp::verify() {
namespace {
class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(CompressStoreOp compress,
PatternRewriter &rewriter) const override {
switch (getMaskFormat(compress.getMask())) {
@@ -6260,7 +6260,7 @@ static VectorType trimTrailingOneDims(VectorType oldType) {
class ShapeCastCreateMaskFolderTrailingOneDim final
: public OpRewritePattern<ShapeCastOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
PatternRewriter &rewriter) const override {
@@ -6330,7 +6330,7 @@ public:
/// If both (i) and (ii) are possible, (i) is chosen.
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
PatternRewriter &rewriter) const override {
@@ -6614,7 +6614,7 @@ namespace {
// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
PatternRewriter &rewriter) const override {
@@ -6646,7 +6646,7 @@ public:
/// Replace transpose(splat-like(v)) with broadcast(v)
class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(TransposeOp transposeOp,
PatternRewriter &rewriter) const override {
@@ -6663,7 +6663,7 @@ public:
/// Folds transpose(create_mask) into a new transposed create_mask.
class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(TransposeOp transpOp,
PatternRewriter &rewriter) const override {
@@ -6700,7 +6700,7 @@ public:
/// Folds transpose(shape_cast) into a new shape_cast.
class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(TransposeOp transposeOp,
PatternRewriter &rewriter) const override {
@@ -6750,7 +6750,7 @@ public:
/// within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6).
class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
: OpRewritePattern<vector::TransposeOp>(context, benefit) {}
@@ -6971,7 +6971,7 @@ namespace {
/// %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
PatternRewriter &rewriter) const override {
@@ -7300,7 +7300,7 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor,
/// %0 = arith.select %mask, %a, %passthru : vector<8xf32>
///
class CanonializeEmptyMaskOp : public OpRewritePattern<MaskOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(MaskOp maskOp,
PatternRewriter &rewriter) const override {
@@ -7410,7 +7410,7 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
// vector.broadcast.
class SplatToBroadcastPattern final : public OpRewritePattern<SplatOp> {
public:
- using OpRewritePattern<SplatOp>::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(SplatOp splatOp,
PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index dedc3b3..61d9357 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -34,7 +34,7 @@ namespace {
/// convertible to the lower level target dialect (LLVM, SPIR-V, etc.) directly.
class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::BroadcastOp op,
PatternRewriter &rewriter) const override {
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 65702ff..efe8d14 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -1151,7 +1151,7 @@ FailureOr<Value> ContractionOpLowering::lowerReduction(
///
class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::OuterProductOp op,
PatternRewriter &rewriter) const override {
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 1f96a3a..6bc8347 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -50,7 +50,7 @@ namespace {
///
/// Supports vector types with a fixed leading dimension.
struct UnrollGather : OpRewritePattern<vector::GatherOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::GatherOp op,
PatternRewriter &rewriter) const override {
@@ -98,7 +98,7 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
/// ATM this is effectively limited to reading a 1D Vector from a 2D MemRef,
/// but should be fairly straightforward to extend beyond that.
struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::GatherOp op,
PatternRewriter &rewriter) const override {
@@ -164,7 +164,7 @@ struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
/// loads/extracts are made conditional using `scf.if` ops.
struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::GatherOp op,
PatternRewriter &rewriter) const override {
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
index 9d6a865..479fc0c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
@@ -163,7 +163,7 @@ private:
/// : vector<7xi16>, vector<7xi16>
/// ```
struct InterleaveToShuffle final : OpRewritePattern<vector::InterleaveOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::InterleaveOp op,
PatternRewriter &rewriter) const override {
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index 5617b06..7730c4e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -48,7 +48,7 @@ namespace {
/// until a one-dimensional vector is reached.
class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::CreateMaskOp op,
PatternRewriter &rewriter) const override {
@@ -100,7 +100,7 @@ public:
/// will be folded at LLVM IR level.
class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
PatternRewriter &rewriter) const override {
@@ -184,7 +184,7 @@ namespace {
/// and actually match the traits of its the nested `MaskableOpInterface`.
template <class SourceOp>
struct MaskOpRewritePattern : OpRewritePattern<MaskOp> {
- using OpRewritePattern<MaskOp>::OpRewritePattern;
+ using Base::Base;
private:
LogicalResult matchAndRewrite(MaskOp maskOp,
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 4773732d..e86e2a9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -39,7 +39,7 @@ namespace {
class InnerOuterDimReductionConversion
: public OpRewritePattern<vector::MultiDimReductionOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
explicit InnerOuterDimReductionConversion(
MLIRContext *context, vector::VectorMultiReductionLowering options,
@@ -136,7 +136,7 @@ private:
class ReduceMultiDimReductionRank
: public OpRewritePattern<vector::MultiDimReductionOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
explicit ReduceMultiDimReductionRank(
MLIRContext *context, vector::VectorMultiReductionLowering options,
@@ -304,7 +304,7 @@ private:
/// and combines results
struct TwoDimMultiReductionToElementWise
: public OpRewritePattern<vector::MultiDimReductionOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
PatternRewriter &rewriter) const override {
@@ -359,7 +359,7 @@ struct TwoDimMultiReductionToElementWise
/// a sequence of vector.reduction ops.
struct TwoDimMultiReductionToReduction
: public OpRewritePattern<vector::MultiDimReductionOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
PatternRewriter &rewriter) const override {
@@ -420,7 +420,7 @@ struct TwoDimMultiReductionToReduction
/// separately.
struct OneDimMultiReductionToTwoDim
: public OpRewritePattern<vector::MultiDimReductionOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
PatternRewriter &rewriter) const override {
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
index af4851e..258f2cb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
@@ -99,7 +99,7 @@ namespace {
/// return %7, %8 : vector<2x3xi32>, vector<2xi32>
/// ```
struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::ScanOp scanOp,
PatternRewriter &rewriter) const override {
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index 603ea41..c5f22b2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -189,7 +189,7 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
}
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::ShapeCastOp op,
PatternRewriter &rewriter) const override {
@@ -356,7 +356,7 @@ public:
class ScalableShapeCastOpRewritePattern
: public OpRewritePattern<vector::ShapeCastOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::ShapeCastOp op,
PatternRewriter &rewriter) const override {
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp
index 78102f7..8f46ad6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp
@@ -44,7 +44,7 @@ namespace {
///
struct MixedSizeInputShuffleOpRewrite final
: OpRewritePattern<vector::ShuffleOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::ShuffleOp shuffleOp,
PatternRewriter &rewriter) const override {
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp
index ee5568a..08e7c89 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp
@@ -24,7 +24,7 @@ using namespace mlir::vector;
namespace {
struct StepToArithConstantOpRewrite final : OpRewritePattern<vector::StepOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::StepOp stepOp,
PatternRewriter &rewriter) const override {
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
index 6407a86..7521e24 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
@@ -667,7 +667,7 @@ getToElementsDefiningOps(FromElementsOp fromElemsOp,
struct ToFromElementsToShuffleTreeRewrite final
: OpRewritePattern<vector::FromElementsOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::FromElementsOp fromElemsOp,
PatternRewriter &rewriter) const override {
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 9e7d0ce..c3f7de0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -300,7 +300,7 @@ namespace {
/// %x = vector.insert .., .. [.., ..]
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
TransposeOpLowering(vector::VectorTransposeLowering vectorTransposeLowering,
MLIRContext *context, PatternBenefit benefit = 1)
@@ -395,7 +395,7 @@ private:
class Transpose2DWithUnitDimToShapeCast
: public OpRewritePattern<vector::TransposeOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
Transpose2DWithUnitDimToShapeCast(MLIRContext *context,
PatternBenefit benefit = 1)
@@ -433,7 +433,7 @@ public:
class TransposeOp2DToShuffleLowering
: public OpRewritePattern<vector::TransposeOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
TransposeOp2DToShuffleLowering(
vector::VectorTransposeLowering vectorTransposeLowering,
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index cab1289..963b2c8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -54,7 +54,7 @@ namespace {
// input by inserting vector.broadcast.
struct CastAwayExtractStridedSliceLeadingOneDim
: public OpRewritePattern<vector::ExtractStridedSliceOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
PatternRewriter &rewriter) const override {
@@ -104,7 +104,7 @@ struct CastAwayExtractStridedSliceLeadingOneDim
// inputs by inserting vector.broadcast.
struct CastAwayInsertStridedSliceLeadingOneDim
: public OpRewritePattern<vector::InsertStridedSliceOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
PatternRewriter &rewriter) const override {
@@ -145,7 +145,7 @@ struct CastAwayInsertStridedSliceLeadingOneDim
// Casts away leading one dimensions in vector.insert's vector inputs by
// inserting vector.broadcast.
struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::InsertOp insertOp,
PatternRewriter &rewriter) const override {
@@ -221,7 +221,7 @@ static Value dropUnitDimsFromMask(OpBuilder &b, Location loc, Value mask,
// 1 dimensions.
struct CastAwayTransferReadLeadingOneDim
: public OpRewritePattern<vector::TransferReadOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::TransferReadOp read,
PatternRewriter &rewriter) const override {
@@ -275,7 +275,7 @@ struct CastAwayTransferReadLeadingOneDim
// 1 dimensions.
struct CastAwayTransferWriteLeadingOneDim
: public OpRewritePattern<vector::TransferWriteOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::TransferWriteOp write,
PatternRewriter &rewriter) const override {
@@ -541,7 +541,7 @@ public:
// vector.broadcast back to the original shape.
struct CastAwayConstantMaskLeadingOneDim
: public OpRewritePattern<vector::ConstantMaskOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::ConstantMaskOp mask,
PatternRewriter &rewriter) const override {
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
index bdbb792..7acc120 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
@@ -48,7 +48,7 @@ namespace {
///
struct VectorMaskedLoadOpConverter final
: OpRewritePattern<vector::MaskedLoadOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedLoadOp,
PatternRewriter &rewriter) const override {
@@ -117,7 +117,7 @@ struct VectorMaskedLoadOpConverter final
///
struct VectorMaskedStoreOpConverter final
: OpRewritePattern<vector::MaskedStoreOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::MaskedStoreOp maskedStoreOp,
PatternRewriter &rewriter) const override {
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 264cbc1..3a6684f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -548,7 +548,7 @@ namespace {
// NOTE: By default, all RMW sequences are atomic. Set `disableAtomicRMW` to
// `false` to generate non-atomic RMW sequences.
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
- using OpConversionPattern::OpConversionPattern;
+ using Base::Base;
ConvertVectorStore(MLIRContext *context, bool disableAtomicRMW)
: OpConversionPattern<vector::StoreOp>(context),
@@ -827,7 +827,7 @@ private:
/// adjusted mask .
struct ConvertVectorMaskedStore final
: OpConversionPattern<vector::MaskedStoreOp> {
- using OpConversionPattern::OpConversionPattern;
+ using Base::Base;
LogicalResult
matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
@@ -950,7 +950,7 @@ struct ConvertVectorMaskedStore final
/// those cases, loads are converted to byte-aligned, byte-sized loads and the
/// target vector is extracted from the loaded vector.
struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
- using OpConversionPattern::OpConversionPattern;
+ using Base::Base;
LogicalResult
matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
@@ -1059,7 +1059,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
/// bitcasting, since each `i8` container element holds two `i4` values.
struct ConvertVectorMaskedLoad final
: OpConversionPattern<vector::MaskedLoadOp> {
- using OpConversionPattern::OpConversionPattern;
+ using Base::Base;
LogicalResult
matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
@@ -1257,7 +1257,7 @@ static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,
// TODO: Document-me
struct ConvertVectorTransferRead final
: OpConversionPattern<vector::TransferReadOp> {
- using OpConversionPattern::OpConversionPattern;
+ using Base::Base;
LogicalResult
matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
@@ -1942,7 +1942,7 @@ namespace {
/// advantage of high-level information to avoid leaving LLVM to scramble with
/// peephole optimizations.
struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
PatternRewriter &rewriter) const override {
@@ -2147,7 +2147,7 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
/// %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
///
struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
- using OpRewritePattern<arith::TruncIOp>::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
PatternRewriter &rewriter) const override {
@@ -2200,7 +2200,7 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
/// %2 = arith.trunci %1 : vector<16x8xi8> to vector<16x8xi4>
///
struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
- using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+ using Base::Base;
RewriteVectorTranspose(MLIRContext *context, PatternBenefit benefit)
: OpRewritePattern<vector::TransposeOp>(context, benefit) {}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index f6d6555..9e49873 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -34,7 +34,7 @@ using namespace mlir::vector;
class DecomposeDifferentRankInsertStridedSlice
: public OpRewritePattern<InsertStridedSliceOp> {
public:
- using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(InsertStridedSliceOp op,
PatternRewriter &rewriter) const override {
@@ -84,7 +84,7 @@ public:
class ConvertSameRankInsertStridedSliceIntoShuffle
: public OpRewritePattern<InsertStridedSliceOp> {
public:
- using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
+ using Base::Base;
void initialize() {
// This pattern creates recursive InsertStridedSliceOp, but the recursion is
@@ -183,7 +183,7 @@ public:
class Convert1DExtractStridedSliceIntoShuffle
: public OpRewritePattern<ExtractStridedSliceOp> {
public:
- using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
PatternRewriter &rewriter) const override {
@@ -271,7 +271,7 @@ private:
class DecomposeNDExtractStridedSlice
: public OpRewritePattern<ExtractStridedSliceOp> {
public:
- using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
+ using Base::Base;
void initialize() {
// This pattern creates recursive ExtractStridedSliceOp, but the recursion
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 82bac8c..71fba71c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -214,7 +214,7 @@ SmallVector<int64_t> static getStridedSliceInsertionIndices(
/// vector.extract_strided_slice operation.
struct LinearizeVectorExtractStridedSlice final
: public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
- using OpConversionPattern::OpConversionPattern;
+ using Base::Base;
LinearizeVectorExtractStridedSlice(const TypeConverter &typeConverter,
MLIRContext *context,
PatternBenefit benefit = 1)
@@ -285,7 +285,7 @@ struct LinearizeVectorExtractStridedSlice final
///
struct LinearizeVectorInsertStridedSlice final
: public mlir::OpConversionPattern<mlir::vector::InsertStridedSliceOp> {
- using OpConversionPattern::OpConversionPattern;
+ using Base::Base;
LinearizeVectorInsertStridedSlice(const TypeConverter &typeConverter,
MLIRContext *context,
PatternBenefit benefit = 1)
@@ -348,7 +348,7 @@ struct LinearizeVectorInsertStridedSlice final
/// of the original shuffle operation.
struct LinearizeVectorShuffle final
: public OpConversionPattern<vector::ShuffleOp> {
- using OpConversionPattern::OpConversionPattern;
+ using Base::Base;
LinearizeVectorShuffle(const TypeConverter &typeConverter,
MLIRContext *context, PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit) {}
@@ -423,7 +423,7 @@ struct LinearizeVectorShuffle final
///
struct LinearizeVectorExtract final
: public OpConversionPattern<vector::ExtractOp> {
- using OpConversionPattern::OpConversionPattern;
+ using Base::Base;
LinearizeVectorExtract(const TypeConverter &typeConverter,
MLIRContext *context, PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit) {}
@@ -501,7 +501,7 @@ struct LinearizeVectorExtract final
///
struct LinearizeVectorInsert final
: public OpConversionPattern<vector::InsertOp> {
- using OpConversionPattern::OpConversionPattern;
+ using Base::Base;
LinearizeVectorInsert(const TypeConverter &typeConverter,
MLIRContext *context, PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit) {}
@@ -575,7 +575,7 @@ struct LinearizeVectorInsert final
/// %out_nd = vector.shape_cast %out_1d: vector<16xf16> to vector<4x4xf16>
struct LinearizeVectorBitCast final
: public OpConversionPattern<vector::BitCastOp> {
- using OpConversionPattern::OpConversionPattern;
+ using Base::Base;
LinearizeVectorBitCast(const TypeConverter &typeConverter,
MLIRContext *context, PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit) {}
@@ -598,7 +598,7 @@ struct LinearizeVectorBitCast final
/// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
struct LinearizeVectorSplat final
: public OpConversionPattern<vector::SplatOp> {
- using OpConversionPattern::OpConversionPattern;
+ using Base::Base;
LinearizeVectorSplat(const TypeConverter &typeConverter, MLIRContext *context,
PatternBenefit benefit = 1)
@@ -629,7 +629,7 @@ struct LinearizeVectorSplat final
/// %shape_cast = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1>
struct LinearizeVectorCreateMask final
: OpConversionPattern<vector::CreateMaskOp> {
- using OpConversionPattern::OpConversionPattern;
+ using Base::Base;
LinearizeVectorCreateMask(const TypeConverter &typeConverter,
MLIRContext *context, PatternBenefit benefit = 1)
@@ -684,7 +684,7 @@ struct LinearizeVectorCreateMask final
/// For generic cases, the vector unroll pass should be used to unroll the load
/// to vector<1x1x...xN> form and then linearized
struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
- using OpConversionPattern::OpConversionPattern;
+ using Base::Base;
LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context,
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit) {}
@@ -731,7 +731,7 @@ struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
/// to vector<1x1x...xN> form and then linearized
struct LinearizeVectorStore final
: public OpConversionPattern<vector::StoreOp> {
- using OpConversionPattern::OpConversionPattern;
+ using Base::Base;
LinearizeVectorStore(const TypeConverter &typeConverter, MLIRContext *context,
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit) {}
@@ -778,7 +778,7 @@ struct LinearizeVectorStore final
///
struct LinearizeVectorFromElements final
: public OpConversionPattern<vector::FromElementsOp> {
- using OpConversionPattern::OpConversionPattern;
+ using Base::Base;
LinearizeVectorFromElements(const TypeConverter &typeConverter,
MLIRContext *context, PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit) {}
@@ -814,7 +814,7 @@ struct LinearizeVectorFromElements final
///
struct LinearizeVectorToElements final
: public OpConversionPattern<vector::ToElementsOp> {
- using OpConversionPattern::OpConversionPattern;
+ using Base::Base;
LinearizeVectorToElements(const TypeConverter &typeConverter,
MLIRContext *context, PatternBenefit benefit = 1)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index c364a8b..1121d95 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -1081,7 +1081,7 @@ private:
/// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
/// to memref.store.
class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
PatternRewriter &rewriter) const override {
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 866f789..d6a6d7cd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -78,7 +78,7 @@ namespace {
/// ```
struct MultiReduceToContract
: public OpRewritePattern<vector::MultiDimReductionOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
PatternRewriter &rewriter) const override {
@@ -138,7 +138,7 @@ struct MultiReduceToContract
/// ```
struct CombineContractABTranspose final
: public OpRewritePattern<vector::ContractionOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
PatternRewriter &rewriter) const override {
@@ -202,7 +202,7 @@ struct CombineContractABTranspose final
/// ```
struct CombineContractResultTranspose final
: public OpRewritePattern<vector::TransposeOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::TransposeOp resTOp,
PatternRewriter &rewriter) const override {
@@ -568,7 +568,7 @@ static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
// %2 = vector.extract %1[1] : f16 from vector<2xf16>
struct BubbleDownVectorBitCastForExtract
: public OpRewritePattern<vector::ExtractOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
PatternRewriter &rewriter) const override {
@@ -643,7 +643,7 @@ struct BubbleDownVectorBitCastForExtract
// %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
struct BubbleDownBitCastForStridedSliceExtract
: public OpRewritePattern<vector::ExtractStridedSliceOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
PatternRewriter &rewriter) const override {
@@ -721,7 +721,7 @@ struct BubbleDownBitCastForStridedSliceExtract
// %2 = vector.insert %0, %1 [4] : vector<16xi8> into vector<8x16xi8>
//
struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
PatternRewriter &rewriter) const override {
@@ -794,7 +794,7 @@ struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
// offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
struct BubbleUpBitCastForStridedSliceInsert
: public OpRewritePattern<vector::BitCastOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
PatternRewriter &rewriter) const override {
@@ -892,7 +892,7 @@ struct BubbleUpBitCastForStridedSliceInsert
// %7 = vector.insert_strided_slice %6, %cst {
// offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
public:
BreakDownVectorBitCast(MLIRContext *context,
@@ -1131,7 +1131,7 @@ struct ReorderElementwiseOpsOnBroadcast final
class ExtractOpFromElementwise final
: public OpRewritePattern<vector::ExtractOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::ExtractOp op,
PatternRewriter &rewriter) const override {
@@ -1206,7 +1206,7 @@ static bool isSupportedMemSinkElementType(Type type) {
/// ```
class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::ExtractOp op,
PatternRewriter &rewriter) const override {
@@ -1285,7 +1285,7 @@ public:
class StoreOpFromSplatOrBroadcast final
: public OpRewritePattern<vector::StoreOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::StoreOp op,
PatternRewriter &rewriter) const override {
@@ -1476,7 +1476,7 @@ static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) {
/// InstCombine seems to handle vectors with multiple elements but not the
/// single element ones.
struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
- using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(arith::SelectOp selectOp,
PatternRewriter &rewriter) const override {
@@ -1560,7 +1560,7 @@ getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
/// Drop inner most contiguous unit dimensions from transfer_read operand.
class DropInnerMostUnitDimsTransferRead
: public OpRewritePattern<vector::TransferReadOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
PatternRewriter &rewriter) const override {
@@ -1651,7 +1651,7 @@ class DropInnerMostUnitDimsTransferRead
/// Note, this pattern will not collapse "scalable unit" dims (i.e. `[1]`).
class DropInnerMostUnitDimsTransferWrite
: public OpRewritePattern<vector::TransferWriteOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
PatternRewriter &rewriter) const override {
@@ -1728,7 +1728,7 @@ class DropInnerMostUnitDimsTransferWrite
/// with the RHS transposed) lowering.
struct CanonicalizeContractMatmulToMMT final
: OpRewritePattern<vector::ContractionOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
@@ -1845,7 +1845,7 @@ private:
template <typename ExtOp>
struct FoldArithExtIntoContractionOp
: public OpRewritePattern<vector::ContractionOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
PatternRewriter &rewriter) const override {
@@ -1878,7 +1878,7 @@ struct FoldArithExtIntoContractionOp
/// %b = vector.reduction <add> %a, %acc
/// ```
struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::ReductionOp op,
PatternRewriter &rewriter) const override {
@@ -2033,7 +2033,7 @@ struct DropUnitDimFromElementwiseOps final
/// ```
struct DropUnitDimsFromTransposeOp final
: OpRewritePattern<vector::TransposeOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::TransposeOp op,
PatternRewriter &rewriter) const override {
@@ -2110,7 +2110,7 @@ struct DropUnitDimsFromTransposeOp final
/// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
/// ```
struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(scf::ForOp forOp,
PatternRewriter &rewriter) const override {
@@ -2155,7 +2155,7 @@ struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> {
/// %c = vector.reduction <add> %b, %acc
/// ```
struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::ReductionOp op,
PatternRewriter &rewriter) const override {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 9413a92..784e5d6 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -824,7 +824,7 @@ struct WgToSgStoreScatterOpWithOffset
return failure();
xegpu::DistributeLayoutAttr layout =
- xegpu::getDistributeLayoutAttr(op.getValue());
+ xegpu::getDistributeLayoutAttr(op.getOperand(0));
if (!layout || !layout.isForWorkgroup())
return failure();
@@ -844,12 +844,19 @@ struct WgToSgStoreScatterOpWithOffset
auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
for (auto [val, offs, mask] : llvm::zip(
adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
- xegpu::StoreScatterOp::create(rewriter, loc, val, op.getDest(), offs,
- mask, chunkSizeAttr, op.getL1HintAttr(),
- op.getL2HintAttr(), op.getL3HintAttr());
+ auto store = xegpu::StoreScatterOp::create(
+ rewriter, loc, val, op.getDest(), offs, mask, chunkSizeAttr,
+ op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
// Update the layout attribute to drop sg_layout and sg_data.
- if (auto newLayout = layout.dropSgLayoutAndData())
- op->setAttr("layout", newLayout);
+ if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
+ !layout.getEffectiveInstDataAsInt().empty()) {
+ for (OpOperand &operand : store->getOpOperands()) {
+ // Skip for operand one (memref)
+ if (operand.getOperandNumber() == 1)
+ continue;
+ xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData());
+ }
+ }
}
rewriter.eraseOp(op);
return success();
@@ -1247,10 +1254,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
[=](xegpu::StoreScatterOp op) -> bool {
- // Check if the layout attribute is present on the result.
- auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout");
- if (!layout)
- return true;
+ auto layout = xegpu::getDistributeLayoutAttr(op.getOperand(0));
return isLegal(layout);
});
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index c84e760..8f199b6 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -489,13 +489,6 @@ OpBuilder::tryFold(Operation *op, SmallVectorImpl<Value> &results,
SmallVector<OpFoldResult, 4> foldResults;
LDBG() << "Trying to fold: "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
- if (op->getName().getStringRef() == "vector.extract") {
- Operation *parent = op->getParentOp();
- while (parent && parent->getName().getStringRef() != "spirv.func")
- parent = parent->getParentOp();
- if (parent)
- parent->dump();
- }
if (failed(op->fold(foldResults)))
return cleanupFailure();
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index af4ea5a..0f28cbc 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -304,7 +304,7 @@ static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs,
umin = lhsMin.udiv(rhsMax);
// X u/ Y u<= X.
- APInt umax = lhsMax;
+ const APInt &umax = lhsMax;
return ConstantIntRanges::fromUnsigned(umin, umax);
}
diff --git a/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp b/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp
index d6b8a8a..e3f075f 100644
--- a/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp
+++ b/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp
@@ -54,6 +54,7 @@ struct OpStrings {
std::string opCppName;
SmallVector<std::string> opResultNames;
SmallVector<std::string> opOperandNames;
+ SmallVector<std::string> opRegionNames;
};
static std::string joinNameList(llvm::ArrayRef<std::string> names) {
@@ -87,8 +88,8 @@ static TypeStrings getStrings(irdl::TypeOp type) {
/// Generates OpStrings from an OperatioOp
static OpStrings getStrings(irdl::OperationOp op) {
auto operandOp = op.getOp<irdl::OperandsOp>();
-
auto resultOp = op.getOp<irdl::ResultsOp>();
+ auto regionsOp = op.getOp<irdl::RegionsOp>();
OpStrings strings;
strings.opName = op.getSymName();
@@ -108,6 +109,13 @@ static OpStrings getStrings(irdl::OperationOp op) {
}));
}
+ if (regionsOp) {
+ strings.opRegionNames = SmallVector<std::string>(
+ llvm::map_range(regionsOp->getNames(), [](Attribute attr) {
+ return llvm::formatv("{0}", cast<StringAttr>(attr));
+ }));
+ }
+
return strings;
}
@@ -122,6 +130,7 @@ static void fillDict(irdl::detail::dictionary &dict,
static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) {
const auto operandCount = strings.opOperandNames.size();
const auto resultCount = strings.opResultNames.size();
+ const auto regionCount = strings.opRegionNames.size();
dict["OP_NAME"] = strings.opName;
dict["OP_CPP_NAME"] = strings.opCppName;
@@ -131,6 +140,7 @@ static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) {
operandCount ? joinNameList(strings.opOperandNames) : "{\"\"}";
dict["OP_RESULT_INITIALIZER_LIST"] =
resultCount ? joinNameList(strings.opResultNames) : "{\"\"}";
+ dict["OP_REGION_COUNT"] = std::to_string(regionCount);
}
/// Fills a dictionary with values from DialectStrings
@@ -179,6 +189,8 @@ static void generateOpGetterDeclarations(irdl::detail::dictionary &dict,
const OpStrings &opStrings) {
auto opGetters = std::string{};
auto resGetters = std::string{};
+ auto regionGetters = std::string{};
+ auto regionAdaptorGetters = std::string{};
for (size_t i = 0, end = opStrings.opOperandNames.size(); i < end; ++i) {
const auto op =
@@ -196,8 +208,23 @@ static void generateOpGetterDeclarations(irdl::detail::dictionary &dict,
op, i);
}
+ for (size_t i = 0, end = opStrings.opRegionNames.size(); i < end; ++i) {
+ const auto op =
+ llvm::convertToCamelFromSnakeCase(opStrings.opRegionNames[i], true);
+ regionAdaptorGetters += llvm::formatv(
+ R"(::mlir::Region &get{0}() { return *getRegions()[{1}]; }
+ )",
+ op, i);
+ regionGetters += llvm::formatv(
+ R"(::mlir::Region &get{0}() { return (*this)->getRegion({1}); }
+ )",
+ op, i);
+ }
+
dict["OP_OPERAND_GETTER_DECLS"] = opGetters;
dict["OP_RESULT_GETTER_DECLS"] = resGetters;
+ dict["OP_REGION_ADAPTER_GETTER_DECLS"] = regionAdaptorGetters;
+ dict["OP_REGION_GETTER_DECLS"] = regionGetters;
}
static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict,
@@ -238,6 +265,22 @@ static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict,
dict["OP_BUILD_DECLS"] = buildDecls;
}
+// add traits to the dictionary, return true if any were added
+static SmallVector<std::string> generateTraits(irdl::OperationOp op,
+ const OpStrings &strings) {
+ SmallVector<std::string> cppTraitNames;
+ if (!strings.opRegionNames.empty()) {
+ cppTraitNames.push_back(
+ llvm::formatv("::mlir::OpTrait::NRegions<{0}>::Impl",
+ strings.opRegionNames.size())
+ .str());
+
+ // Requires verifyInvariantsImpl is implemented on the op
+ cppTraitNames.emplace_back("::mlir::OpTrait::OpInvariants");
+ }
+ return cppTraitNames;
+}
+
static LogicalResult generateOperationInclude(irdl::OperationOp op,
raw_ostream &output,
irdl::detail::dictionary &dict) {
@@ -247,6 +290,13 @@ static LogicalResult generateOperationInclude(irdl::OperationOp op,
const auto opStrings = getStrings(op);
fillDict(dict, opStrings);
+ SmallVector<std::string> traitNames = generateTraits(op, opStrings);
+ if (traitNames.empty())
+ dict["OP_TEMPLATE_ARGS"] = opStrings.opCppName;
+ else
+ dict["OP_TEMPLATE_ARGS"] = llvm::formatv("{0}, {1}", opStrings.opCppName,
+ llvm::join(traitNames, ", "));
+
generateOpGetterDeclarations(dict, opStrings);
generateOpBuilderDeclarations(dict, opStrings);
@@ -301,6 +351,110 @@ static LogicalResult generateInclude(irdl::DialectOp dialect,
return success();
}
+static void generateRegionConstraintVerifiers(
+ irdl::detail::dictionary &dict, irdl::OperationOp op,
+ const OpStrings &strings, SmallVectorImpl<std::string> &verifierHelpers,
+ SmallVectorImpl<std::string> &verifierCalls) {
+ auto regionsOp = op.getOp<irdl::RegionsOp>();
+ if (strings.opRegionNames.empty() || !regionsOp)
+ return;
+
+ for (size_t i = 0; i < strings.opRegionNames.size(); ++i) {
+ std::string regionName = strings.opRegionNames[i];
+ std::string helperFnName =
+ llvm::formatv("__mlir_irdl_local_region_constraint_{0}_{1}",
+ strings.opCppName, regionName)
+ .str();
+
+ // Extract the actual region constraint from the IRDL RegionOp
+ std::string condition = "true";
+ std::string textualConditionName = "any region";
+
+ if (auto regionDefOp =
+ dyn_cast<irdl::RegionOp>(regionsOp->getArgs()[i].getDefiningOp())) {
+ // Generate constraint condition based on RegionOp attributes
+ SmallVector<std::string> conditionParts;
+ SmallVector<std::string> descriptionParts;
+
+ // Check number of blocks constraint
+ if (auto blockCount = regionDefOp.getNumberOfBlocks()) {
+ conditionParts.push_back(
+ llvm::formatv("region.getBlocks().size() == {0}",
+ blockCount.value())
+ .str());
+ descriptionParts.push_back(
+ llvm::formatv("exactly {0} block(s)", blockCount.value()).str());
+ }
+
+ // Check entry block arguments constraint
+ if (regionDefOp.getConstrainedArguments()) {
+ size_t expectedArgCount = regionDefOp.getEntryBlockArgs().size();
+ conditionParts.push_back(
+ llvm::formatv("region.getNumArguments() == {0}", expectedArgCount)
+ .str());
+ descriptionParts.push_back(
+ llvm::formatv("{0} entry block argument(s)", expectedArgCount)
+ .str());
+ }
+
+ // Combine conditions
+ if (!conditionParts.empty()) {
+ condition = llvm::join(conditionParts, " && ");
+ }
+
+ // Generate descriptive error message
+ if (!descriptionParts.empty()) {
+ textualConditionName =
+ llvm::formatv("region with {0}",
+ llvm::join(descriptionParts, " and "))
+ .str();
+ }
+ }
+
+ verifierHelpers.push_back(llvm::formatv(
+ R"(static ::llvm::LogicalResult {0}(::mlir::Operation *op, ::mlir::Region &region, ::llvm::StringRef regionName, unsigned regionIndex) {{
+ if (!({1})) {{
+ return op->emitOpError("region #") << regionIndex
+ << (regionName.empty() ? " " : " ('" + regionName + "') ")
+ << "failed to verify constraint: {2}";
+ }
+ return ::mlir::success();
+})",
+ helperFnName, condition, textualConditionName));
+
+ verifierCalls.push_back(llvm::formatv(R"(
+ if (::mlir::failed({0}(*this, (*this)->getRegion({1}), "{2}", {1})))
+ return ::mlir::failure();)",
+ helperFnName, i, regionName)
+ .str());
+ }
+}
+
+static void generateVerifiers(irdl::detail::dictionary &dict,
+ irdl::OperationOp op, const OpStrings &strings) {
+ SmallVector<std::string> verifierHelpers;
+ SmallVector<std::string> verifierCalls;
+
+ generateRegionConstraintVerifiers(dict, op, strings, verifierHelpers,
+ verifierCalls);
+
+ // Add an overall verifier that sequences the helper calls
+ std::string verifierDef =
+ llvm::formatv(R"(
+::llvm::LogicalResult {0}::verifyInvariantsImpl() {{
+ if(::mlir::failed(verify()))
+ return ::mlir::failure();
+
+ {1}
+
+ return ::mlir::success();
+})",
+ strings.opCppName, llvm::join(verifierCalls, "\n"));
+
+ dict["OP_VERIFIER_HELPERS"] = llvm::join(verifierHelpers, "\n");
+ dict["OP_VERIFIER"] = verifierDef;
+}
+
static std::string generateOpDefinition(irdl::detail::dictionary &dict,
irdl::OperationOp op) {
static const auto perOpDefTemplate = mlir::irdl::detail::Template{
@@ -370,6 +524,8 @@ void {0}::build(::mlir::OpBuilder &opBuilder, ::mlir::OperationState &opState, {
dict["OP_BUILD_DEFS"] = buildDefinition;
+ generateVerifiers(dict, op, opStrings);
+
std::string str;
llvm::raw_string_ostream stream{str};
perOpDefTemplate.render(stream, dict);
@@ -427,7 +583,7 @@ static LogicalResult generateLib(irdl::DialectOp dialect, raw_ostream &output,
dict["TYPE_PARSER"] = llvm::formatv(
R"(static ::mlir::OptionalParseResult generatedTypeParser(::mlir::AsmParser &parser, ::llvm::StringRef *mnemonic, ::mlir::Type &value) {
return ::mlir::AsmParser::KeywordSwitch<::mlir::OptionalParseResult>(parser)
- {0}
+ {0}
.Default([&](llvm::StringRef keyword, llvm::SMLoc) {{
*mnemonic = keyword;
return std::nullopt;
@@ -520,6 +676,8 @@ static LogicalResult verifySupported(irdl::DialectOp dialect) {
"IRDL C++ translation does not yet support variadic results");
}))
.Case<irdl::AnyOp>(([](irdl::AnyOp) { return success(); }))
+ .Case<irdl::RegionOp>(([](irdl::RegionOp) { return success(); }))
+ .Case<irdl::RegionsOp>(([](irdl::RegionsOp) { return success(); }))
.Default([](mlir::Operation *op) -> LogicalResult {
return op->emitError("IRDL C++ translation does not yet support "
"translation of ")
diff --git a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt
index e9068e9..93ce0be 100644
--- a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt
+++ b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt
@@ -12,15 +12,15 @@ public:
struct Properties {
};
public:
- __OP_CPP_NAME__GenericAdaptorBase(::mlir::Operation *op)
- : odsAttrs(op->getRawDictionaryAttrs()), odsOpName(op->getName()),
- odsRegions(op->getRegions())
+ __OP_CPP_NAME__GenericAdaptorBase(::mlir::Operation *op)
+ : odsAttrs(op->getRawDictionaryAttrs()), odsOpName(op->getName()),
+ odsRegions(op->getRegions())
{}
/// Return the unstructured operand index of a structured operand along with
// the amount of unstructured operands it contains.
std::pair<unsigned, unsigned>
- getStructuredOperandIndexAndLength (unsigned index,
+ getStructuredOperandIndexAndLength (unsigned index,
unsigned odsOperandsSize) {
return {index, 1};
}
@@ -32,6 +32,12 @@ public:
::mlir::DictionaryAttr getAttributes() {
return odsAttrs;
}
+
+ __OP_REGION_ADAPTER_GETTER_DECLS__
+
+ ::mlir::RegionRange getRegions() {
+ return odsRegions;
+ }
protected:
::mlir::DictionaryAttr odsAttrs;
::std::optional<::mlir::OperationName> odsOpName;
@@ -42,28 +48,28 @@ protected:
} // namespace detail
template <typename RangeT>
-class __OP_CPP_NAME__GenericAdaptor
+class __OP_CPP_NAME__GenericAdaptor
: public detail::__OP_CPP_NAME__GenericAdaptorBase {
using ValueT = ::llvm::detail::ValueOfRange<RangeT>;
using Base = detail::__OP_CPP_NAME__GenericAdaptorBase;
public:
__OP_CPP_NAME__GenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs,
- ::mlir::OpaqueProperties properties,
- ::mlir::RegionRange regions = {})
- : __OP_CPP_NAME__GenericAdaptor(values, attrs,
- (properties ? *properties.as<::mlir::EmptyProperties *>()
+ ::mlir::OpaqueProperties properties,
+ ::mlir::RegionRange regions = {})
+ : __OP_CPP_NAME__GenericAdaptor(values, attrs,
+ (properties ? *properties.as<::mlir::EmptyProperties *>()
: ::mlir::EmptyProperties{}), regions) {}
- __OP_CPP_NAME__GenericAdaptor(RangeT values,
+ __OP_CPP_NAME__GenericAdaptor(RangeT values,
const __OP_CPP_NAME__GenericAdaptorBase &base)
: Base(base), odsOperands(values) {}
- // This template parameter allows using __OP_CPP_NAME__ which is declared
+ // This template parameter allows using __OP_CPP_NAME__ which is declared
// later.
template <typename LateInst = __OP_CPP_NAME__,
typename = std::enable_if_t<
std::is_same_v<LateInst, __OP_CPP_NAME__>>>
- __OP_CPP_NAME__GenericAdaptor(RangeT values, LateInst op)
+ __OP_CPP_NAME__GenericAdaptor(RangeT values, LateInst op)
: Base(op), odsOperands(values) {}
/// Return the unstructured operand index of a structured operand along with
@@ -77,7 +83,7 @@ public:
RangeT getStructuredOperands(unsigned index) {
auto valueRange = getStructuredOperandIndexAndLength(index);
return {std::next(odsOperands.begin(), valueRange.first),
- std::next(odsOperands.begin(),
+ std::next(odsOperands.begin(),
valueRange.first + valueRange.second)};
}
@@ -91,7 +97,7 @@ private:
RangeT odsOperands;
};
-class __OP_CPP_NAME__Adaptor
+class __OP_CPP_NAME__Adaptor
: public __OP_CPP_NAME__GenericAdaptor<::mlir::ValueRange> {
public:
using __OP_CPP_NAME__GenericAdaptor::__OP_CPP_NAME__GenericAdaptor;
@@ -100,7 +106,7 @@ public:
::llvm::LogicalResult verify(::mlir::Location loc);
};
-class __OP_CPP_NAME__ : public ::mlir::Op<__OP_CPP_NAME__> {
+class __OP_CPP_NAME__ : public ::mlir::Op<__OP_TEMPLATE_ARGS__> {
public:
using Op::Op;
using Op::print;
@@ -112,6 +118,8 @@ public:
return {};
}
+ ::llvm::LogicalResult verifyInvariantsImpl();
+
static constexpr ::llvm::StringLiteral getOperationName() {
return ::llvm::StringLiteral("__DIALECT_NAME__.__OP_NAME__");
}
@@ -147,7 +155,7 @@ public:
::mlir::Operation::operand_range getStructuredOperands(unsigned index) {
auto valueRange = getStructuredOperandIndexAndLength(index);
return {std::next(getOperation()->operand_begin(), valueRange.first),
- std::next(getOperation()->operand_begin(),
+ std::next(getOperation()->operand_begin(),
valueRange.first + valueRange.second)};
}
@@ -162,18 +170,19 @@ public:
::mlir::Operation::result_range getStructuredResults(unsigned index) {
auto valueRange = getStructuredResultIndexAndLength(index);
return {std::next(getOperation()->result_begin(), valueRange.first),
- std::next(getOperation()->result_begin(),
+ std::next(getOperation()->result_begin(),
valueRange.first + valueRange.second)};
}
__OP_OPERAND_GETTER_DECLS__
__OP_RESULT_GETTER_DECLS__
-
+ __OP_REGION_GETTER_DECLS__
+
__OP_BUILD_DECLS__
- static void build(::mlir::OpBuilder &odsBuilder,
- ::mlir::OperationState &odsState,
- ::mlir::TypeRange resultTypes,
- ::mlir::ValueRange operands,
+ static void build(::mlir::OpBuilder &odsBuilder,
+ ::mlir::OperationState &odsState,
+ ::mlir::TypeRange resultTypes,
+ ::mlir::ValueRange operands,
::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
static __OP_CPP_NAME__ create(::mlir::OpBuilder &odsBuilder,
diff --git a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt
index 30ca420..f4a1b7a 100644
--- a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt
+++ b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt
@@ -6,12 +6,14 @@ R"(
__NAMESPACE_OPEN__
+__OP_VERIFIER_HELPERS__
+
__OP_BUILD_DEFS__
-void __OP_CPP_NAME__::build(::mlir::OpBuilder &odsBuilder,
- ::mlir::OperationState &odsState,
- ::mlir::TypeRange resultTypes,
- ::mlir::ValueRange operands,
+void __OP_CPP_NAME__::build(::mlir::OpBuilder &odsBuilder,
+ ::mlir::OperationState &odsState,
+ ::mlir::TypeRange resultTypes,
+ ::mlir::ValueRange operands,
::llvm::ArrayRef<::mlir::NamedAttribute> attributes)
{
assert(operands.size() == __OP_OPERAND_COUNT__);
@@ -19,6 +21,9 @@ void __OP_CPP_NAME__::build(::mlir::OpBuilder &odsBuilder,
odsState.addOperands(operands);
odsState.addAttributes(attributes);
odsState.addTypes(resultTypes);
+ for (unsigned i = 0; i != __OP_REGION_COUNT__; ++i) {
+ (void)odsState.addRegion();
+ }
}
__OP_CPP_NAME__
@@ -44,6 +49,7 @@ __OP_CPP_NAME__::create(::mlir::ImplicitLocOpBuilder &odsBuilder,
return create(odsBuilder, odsBuilder.getLoc(), resultTypes, operands, attributes);
}
+__OP_VERIFIER__
__NAMESPACE_CLOSE__
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 53209a4..9fcb02e 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -3175,6 +3175,45 @@ applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder,
return success();
}
+/// Apply a `#pragma omp tile` / `!$omp tile` transformation using the
+/// OpenMPIRBuilder.
+static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+ llvm::OpenMPIRBuilder::LocationDescription loc(builder);
+
+ SmallVector<llvm::CanonicalLoopInfo *> translatedLoops;
+ SmallVector<llvm::Value *> translatedSizes;
+
+ for (Value size : op.getSizes()) {
+ llvm::Value *translatedSize = moduleTranslation.lookupValue(size);
+ assert(translatedSize &&
+ "sizes clause arguments must already be translated");
+ translatedSizes.push_back(translatedSize);
+ }
+
+ for (Value applyee : op.getApplyees()) {
+ llvm::CanonicalLoopInfo *consBuilderCLI =
+ moduleTranslation.lookupOMPLoop(applyee);
+ assert(applyee && "Canonical loop must already been translated");
+ translatedLoops.push_back(consBuilderCLI);
+ }
+
+ auto generatedLoops =
+ ompBuilder->tileLoops(loc.DL, translatedLoops, translatedSizes);
+ if (!op.getGeneratees().empty()) {
+ for (auto [mlirLoop, genLoop] :
+ zip_equal(op.getGeneratees(), generatedLoops))
+ moduleTranslation.mapOmpLoop(mlirLoop, genLoop);
+ }
+
+ // CLIs can only be consumed once
+ for (Value applyee : op.getApplyees())
+ moduleTranslation.invalidateOmpLoop(applyee);
+
+ return success();
+}
+
/// Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
static llvm::AtomicOrdering
convertAtomicOrdering(std::optional<omp::ClauseMemoryOrderKind> ao) {
@@ -6227,6 +6266,9 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
// the omp.canonical_loop.
return applyUnrollHeuristic(op, builder, moduleTranslation);
})
+ .Case([&](omp::TileOp op) {
+ return applyTile(op, builder, moduleTranslation);
+ })
.Case([&](omp::TargetAllocMemOp) {
return convertTargetAllocMemOp(*op, builder, moduleTranslation);
})
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index bf0136b..3a23bbf 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1856,6 +1856,44 @@ void ConversionPatternRewriterImpl::replaceOp(
Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
assert(newValues.size() == op->getNumResults() &&
"incorrect number of replacement values");
+ LLVM_DEBUG({
+ logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
+ << ")\n";
+ if (currentTypeConverter) {
+ // If the user-provided replacement types are different from the
+ // legalized types, as per the current type converter, print a note.
+ // In most cases, the replacement types are expected to match the types
+ // produced by the type converter, so this could indicate a bug in the
+ // user code.
+ for (auto [result, repls] :
+ llvm::zip_equal(op->getResults(), newValues)) {
+ Type resultType = result.getType();
+ auto logProlog = [&, repls = repls]() {
+ logger.startLine() << " Note: Replacing op result of type "
+ << resultType << " with value(s) of type (";
+ llvm::interleaveComma(repls, logger.getOStream(), [&](Value v) {
+ logger.getOStream() << v.getType();
+ });
+ logger.getOStream() << ")";
+ };
+ SmallVector<Type> convertedTypes;
+ if (failed(currentTypeConverter->convertTypes(resultType,
+ convertedTypes))) {
+ logProlog();
+ logger.getOStream() << ", but the type converter failed to legalize "
+ "the original type.\n";
+ continue;
+ }
+ if (TypeRange(convertedTypes) != TypeRange(ValueRange(repls))) {
+ logProlog();
+ logger.getOStream() << ", but the legalized type(s) is/are (";
+ llvm::interleaveComma(convertedTypes, logger.getOStream(),
+ [&](Type t) { logger.getOStream() << t; });
+ logger.getOStream() << ")\n";
+ }
+ }
+ }
+ });
if (!config.allowPatternRollback) {
// Pattern rollback is not allowed: materialize all IR changes immediately.
@@ -2072,10 +2110,6 @@ void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
assert(op->getNumResults() == newValues.size() &&
"incorrect # of replacement values");
- LLVM_DEBUG({
- impl->logger.startLine()
- << "** Replace : '" << op->getName() << "'(" << op << ")\n";
- });
// If the current insertion point is before the erased operation, we adjust
// the insertion point to be after the operation.
@@ -2093,10 +2127,6 @@ void ConversionPatternRewriter::replaceOpWithMultiple(
Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
assert(op->getNumResults() == newValues.size() &&
"incorrect # of replacement values");
- LLVM_DEBUG({
- impl->logger.startLine()
- << "** Replace : '" << op->getName() << "'(" << op << ")\n";
- });
// If the current insertion point is before the erased operation, we adjust
// the insertion point to be after the operation.
diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py
index bf40cc5..e3bacb5 100644
--- a/mlir/python/mlir/dialects/transform/structured.py
+++ b/mlir/python/mlir/dialects/transform/structured.py
@@ -44,18 +44,12 @@ class BufferizeToAllocationOp(BufferizeToAllocationOp):
loc=None,
ip=None,
):
- # No other types are allowed, so hard-code those here.
- allocated_buffer_type = transform.AnyValueType.get()
- new_ops_type = transform.AnyOpType.get()
-
if isinstance(memory_space, int):
memory_space = str(memory_space)
if isinstance(memory_space, str):
memory_space = Attribute.parse(memory_space)
super().__init__(
- allocated_buffer_type,
- new_ops_type,
target,
memory_space=memory_space,
memcpy_op=memcpy_op,
diff --git a/mlir/python/mlir/dialects/transform/tune.py b/mlir/python/mlir/dialects/transform/tune.py
index f63f88a..b3bfa80 100644
--- a/mlir/python/mlir/dialects/transform/tune.py
+++ b/mlir/python/mlir/dialects/transform/tune.py
@@ -6,6 +6,9 @@ from typing import Optional, Sequence
from ...ir import (
Type,
+ Value,
+ Operation,
+ OpView,
Attribute,
ArrayAttr,
StringAttr,
@@ -19,7 +22,10 @@ from .._transform_tune_extension_ops_gen import *
from .._transform_tune_extension_ops_gen import _Dialect
try:
- from .._ods_common import _cext as _ods_cext
+ from .._ods_common import (
+ get_op_result_or_value as _get_op_result_or_value,
+ _cext as _ods_cext,
+ )
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
@@ -36,7 +42,7 @@ class KnobOp(KnobOp):
ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute
],
*,
- selected: Optional[Attribute] = None,
+ selected: Optional[Union[Attribute, bool, int, float, str]] = None,
loc=None,
ip=None,
):
@@ -75,8 +81,62 @@ def knob(
ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute
],
*,
- selected: Optional[Attribute] = None,
+ selected: Optional[Union[Attribute, bool, int, float, str]] = None,
loc=None,
ip=None,
):
return KnobOp(result, name, options, selected=selected, loc=loc, ip=ip)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class AlternativesOp(AlternativesOp):
+ def __init__(
+ self,
+ results: Sequence[Type],
+ name: Union[StringAttr, str],
+ num_alternatives: int,
+ *,
+ selected_region: Optional[
+ Union[int, IntegerAttr, Value, Operation, OpView]
+ ] = None,
+ loc=None,
+ ip=None,
+ ):
+ if isinstance(name, str):
+ name = StringAttr.get(name)
+
+ selected_region_attr = selected_region_param = None
+ if isinstance(selected_region, IntegerAttr):
+ selected_region_attr = selected_region
+ elif isinstance(selected_region, int):
+ selected_region_attr = IntegerAttr.get(
+ IntegerType.get_signless(32), selected_region
+ )
+ elif isinstance(selected_region, (Value, Operation, OpView)):
+ selected_region_param = _get_op_result_or_value(selected_region)
+
+ super().__init__(
+ results,
+ name,
+ num_alternatives,
+ selected_region_attr=selected_region_attr,
+ selected_region_param=selected_region_param,
+ loc=loc,
+ ip=ip,
+ )
+ for region in self.regions:
+ region.blocks.append()
+
+
+def alternatives(
+ results: Sequence[Type],
+ name: Union[StringAttr, str],
+ num_alternatives: int,
+ *,
+ selected_region: Optional[Union[int, IntegerAttr, Value, Operation, OpView]] = None,
+ loc=None,
+ ip=None,
+):
+ return AlternativesOp(
+ results, name, num_alternatives, selected_region=selected_region, loc=loc, ip=ip
+ )
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 45b1a1f..0cbe064 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -195,6 +195,36 @@ func.func @assume_alignment(%0 : memref<4x4xf16>) {
// -----
+// ALL-LABEL: func @distinct_objects
+// ALL-SAME: (%[[ARG0:.*]]: memref<?xf16>, %[[ARG1:.*]]: memref<?xf32>, %[[ARG2:.*]]: memref<?xf64>)
+func.func @distinct_objects(%arg0: memref<?xf16>, %arg1: memref<?xf32>, %arg2: memref<?xf64>) -> (memref<?xf16>, memref<?xf32>, memref<?xf64>) {
+// ALL-DAG: %[[CAST_0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?xf16> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// ALL-DAG: %[[CAST_1:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<?xf32> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// ALL-DAG: %[[CAST_2:.*]] = builtin.unrealized_conversion_cast %[[ARG2]] : memref<?xf64> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// ALL: %[[PTR_0:.*]] = llvm.extractvalue %[[CAST_0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// ALL: %[[PTR_1:.*]] = llvm.extractvalue %[[CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// ALL: %[[PTR_2:.*]] = llvm.extractvalue %[[CAST_2]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// ALL: %[[TRUE:.*]] = llvm.mlir.constant(true) : i1
+// ALL: llvm.intr.assume %[[TRUE]] ["separate_storage"(%[[PTR_0]], %[[PTR_1]] : !llvm.ptr, !llvm.ptr)] : i1
+// ALL: llvm.intr.assume %[[TRUE]] ["separate_storage"(%[[PTR_0]], %[[PTR_2]] : !llvm.ptr, !llvm.ptr)] : i1
+// ALL: llvm.intr.assume %[[TRUE]] ["separate_storage"(%[[PTR_1]], %[[PTR_2]] : !llvm.ptr, !llvm.ptr)] : i1
+ %1, %2, %3 = memref.distinct_objects %arg0, %arg1, %arg2 : memref<?xf16>, memref<?xf32>, memref<?xf64>
+ return %1, %2, %3 : memref<?xf16>, memref<?xf32>, memref<?xf64>
+}
+
+// -----
+
+// ALL-LABEL: func @distinct_objects_noop
+// ALL-SAME: (%[[ARG0:.*]]: memref<?xf16>)
+func.func @distinct_objects_noop(%arg0: memref<?xf16>) -> memref<?xf16> {
+// 1-operand version is noop
+// ALL-NEXT: return %[[ARG0]]
+ %1 = memref.distinct_objects %arg0 : memref<?xf16>
+ return %1 : memref<?xf16>
+}
+
+// -----
+
// CHECK-LABEL: func @assume_alignment_w_offset
// CHECK-INTERFACE-LABEL: func @assume_alignment_w_offset
func.func @assume_alignment_w_offset(%0 : memref<4x4xf16, strided<[?, ?], offset: ?>>) {
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index ca3de3a..2fe0995 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2216,6 +2216,18 @@ func.func @test_mulf1(%arg0 : f32, %arg1 : f32) -> (f32) {
return %2 : f32
}
+// CHECK-LABEL: @test_mulf2(
+func.func @test_mulf2(%arg0 : f32) -> (f32, f32) {
+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0.000000e+00 : f32
+ // CHECK-DAG: %[[C0n:.+]] = arith.constant -0.000000e+00 : f32
+ // CHECK-NEXT: return %[[C0]], %[[C0n]]
+ %c0 = arith.constant 0.0 : f32
+ %c0n = arith.constant -0.0 : f32
+ %0 = arith.mulf %c0, %arg0 fastmath<nnan,nsz> : f32
+ %1 = arith.mulf %c0n, %arg0 fastmath<nnan,nsz> : f32
+ return %0, %1 : f32, f32
+}
+
// -----
// CHECK-LABEL: @test_divf(
diff --git a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
index 99790cc..fcd004a 100644
--- a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
+++ b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
@@ -85,3 +85,14 @@ func.func @no_expansion(%x: f32) -> f32 {
%y = arith.addf %x, %c : f32
func.return %y : f32
}
+
+// -----
+
+func.func @no_promote_select(%c: i1, %x: bf16, %y: bf16) -> bf16 {
+// CHECK-LABEL: @no_promote_select
+// CHECK-SAME: (%[[C:.+]]: i1, %[[X:.+]]: bf16, %[[Y:.+]]: bf16)
+// CHECK: %[[Z:.+]] = arith.select %[[C]], %[[X]], %[[Y]] : bf16
+// CHECK: return %[[Z]]
+ %z = arith.select %c, %x, %y : bf16
+ func.return %z : bf16
+}
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 0bad151..6134695 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -1068,6 +1068,38 @@ llvm.func @rocdl.cvt.scale.pk8(%i32: i32, %v2xi32: vector<2xi32>, %scale: i32) {
// -----
+// CHECK-LABEL: rocdl.cvt.scalef32.pk8
+llvm.func @rocdl.cvt.scalef32.pk8(%v8xf32: vector<8xf32>,
+ %v8xf16: vector<8xf16>,
+ %v8xbf16: vector<8xbf16>,
+ %scale: f32) {
+
+ // CHECK: rocdl.cvt.scalef32.pk8.fp8.f32
+ %0 = rocdl.cvt.scalef32.pk8.fp8.f32 %v8xf32, %scale : vector<2xi32>
+ // CHECK: rocdl.cvt.scalef32.pk8.bf8.f32
+ %1 = rocdl.cvt.scalef32.pk8.bf8.f32 %v8xf32, %scale : vector<2xi32>
+ // CHECK: rocdl.cvt.scalef32.pk8.fp4.f32
+ %2 = rocdl.cvt.scalef32.pk8.fp4.f32 %v8xf32, %scale : i32
+
+ // CHECK: rocdl.cvt.scalef32.pk8.fp8.f16
+ %3 = rocdl.cvt.scalef32.pk8.fp8.f16 %v8xf16, %scale : vector<2xi32>
+ // CHECK: rocdl.cvt.scalef32.pk8.bf8.f16
+ %4 = rocdl.cvt.scalef32.pk8.bf8.f16 %v8xf16, %scale : vector<2xi32>
+ // CHECK: rocdl.cvt.scalef32.pk8.fp4.f16
+ %5 = rocdl.cvt.scalef32.pk8.fp4.f16 %v8xf16, %scale : i32
+
+ // CHECK: rocdl.cvt.scalef32.pk8.fp8.bf16
+ %6 = rocdl.cvt.scalef32.pk8.fp8.bf16 %v8xbf16, %scale : vector<2xi32>
+ // CHECK: rocdl.cvt.scalef32.pk8.bf8.bf16
+ %7 = rocdl.cvt.scalef32.pk8.bf8.bf16 %v8xbf16, %scale : vector<2xi32>
+ // CHECK: rocdl.cvt.scalef32.pk8.fp4.bf16
+ %8 = rocdl.cvt.scalef32.pk8.fp4.bf16 %v8xbf16, %scale : i32
+
+ llvm.return
+}
+
+// -----
+
// CHECK-LABEL: rocdl.cvt.scale.pk16
llvm.func @rocdl.cvt.scale.pk16(%v3xi32: vector<3xi32>, %scale:i32) {
diff --git a/mlir/test/Dialect/Math/sincos-fusion.mlir b/mlir/test/Dialect/Math/sincos-fusion.mlir
new file mode 100644
index 0000000..29fb9f1
--- /dev/null
+++ b/mlir/test/Dialect/Math/sincos-fusion.mlir
@@ -0,0 +1,86 @@
+// RUN: mlir-opt -math-sincos-fusion %s | FileCheck %s
+
+// CHECK-LABEL: func.func @sincos_fusion(
+// CHECK-SAME: %[[ARG0:.*]]: f32,
+// CHECK-SAME: %[[ARG1:.*]]: f32) -> (f32, f32, f32, f32) {
+// CHECK: %[[VAL_0:.*]], %[[VAL_1:.*]] = math.sincos %[[ARG0]] : f32
+// CHECK: %[[VAL_2:.*]], %[[VAL_3:.*]] = math.sincos %[[ARG1]] : f32
+// CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_3]], %[[VAL_2]] : f32, f32, f32, f32
+// CHECK: }
+func.func @sincos_fusion(%arg0 : f32, %arg1 : f32) -> (f32, f32, f32, f32) {
+ %0 = math.sin %arg0 : f32
+ %1 = math.cos %arg0 : f32
+
+ %2 = math.cos %arg1 : f32
+ %3 = math.sin %arg1 : f32
+
+ func.return %0, %1, %2, %3 : f32, f32, f32, f32
+}
+
+func.func private @sink(%arg0 : f32)
+
+// CHECK: func.func private @sink(f32)
+// CHECK-LABEL: func.func @sincos_ensure_ssa_dominance(
+// CHECK-SAME: %[[ARG0:.*]]: f32,
+// CHECK-SAME: %[[ARG1:.*]]: f32) -> (f32, f32, f32, f32) {
+// CHECK: %[[VAL_0:.*]], %[[VAL_1:.*]] = math.sincos %[[ARG0]] : f32
+// CHECK: call @sink(%[[VAL_0]]) : (f32) -> ()
+// CHECK: %[[VAL_2:.*]], %[[VAL_3:.*]] = math.sincos %[[ARG1]] : f32
+// CHECK: call @sink(%[[VAL_3]]) : (f32) -> ()
+// CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_3]], %[[VAL_2]] : f32, f32, f32, f32
+// CHECK: }
+func.func @sincos_ensure_ssa_dominance(%arg0 : f32, %arg1 : f32) -> (f32, f32, f32, f32) {
+ %0 = math.sin %arg0 : f32
+ func.call @sink(%0) : (f32) -> ()
+ %1 = math.cos %arg0 : f32
+ %2 = math.cos %arg1 : f32
+ func.call @sink(%2) : (f32) -> ()
+ %3 = math.sin %arg1 : f32
+ func.return %0, %1, %2, %3 : f32, f32, f32, f32
+}
+
+// CHECK-LABEL: func.func @sincos_fusion_no_match_fmf(
+// CHECK-SAME: %[[ARG0:.*]]: f32) -> (f32, f32) {
+// CHECK: %[[VAL_0:.*]] = math.sin %[[ARG0]] fastmath<contract> : f32
+// CHECK: %[[VAL_1:.*]] = math.cos %[[ARG0]] : f32
+// CHECK: return %[[VAL_0]], %[[VAL_1]] : f32, f32
+// CHECK: }
+func.func @sincos_fusion_no_match_fmf(%arg0 : f32) -> (f32, f32) {
+ %0 = math.sin %arg0 fastmath<contract> : f32
+ %1 = math.cos %arg0 : f32
+ func.return %0, %1 : f32, f32
+}
+
+// CHECK-LABEL: func.func @sincos_no_fusion_different_block(
+// CHECK-SAME: %[[ARG0:.*]]: f32,
+// CHECK-SAME: %[[ARG1:.*]]: i1) -> f32 {
+// CHECK: %[[VAL_0:.*]] = scf.if %[[ARG1]] -> (f32) {
+// CHECK: %[[VAL_1:.*]] = math.sin %[[ARG0]] : f32
+// CHECK: scf.yield %[[VAL_1]] : f32
+// CHECK: } else {
+// CHECK: %[[VAL_2:.*]] = math.cos %[[ARG0]] : f32
+// CHECK: scf.yield %[[VAL_2]] : f32
+// CHECK: }
+// CHECK: return %[[VAL_0]] : f32
+// CHECK: }
+func.func @sincos_no_fusion_different_block(%arg0 : f32, %flag : i1) -> f32 {
+ %0 = scf.if %flag -> f32 {
+ %s = math.sin %arg0 : f32
+ scf.yield %s : f32
+ } else {
+ %c = math.cos %arg0 : f32
+ scf.yield %c : f32
+ }
+ func.return %0 : f32
+}
+
+// CHECK-LABEL: func.func @sincos_fusion_preserve_fastmath(
+// CHECK-SAME: %[[ARG0:.*]]: f32) -> (f32, f32) {
+// CHECK: %[[VAL_0:.*]], %[[VAL_1:.*]] = math.sincos %[[ARG0]] fastmath<contract> : f32
+// CHECK: return %[[VAL_0]], %[[VAL_1]] : f32, f32
+// CHECK: }
+func.func @sincos_fusion_preserve_fastmath(%arg0 : f32) -> (f32, f32) {
+ %0 = math.sin %arg0 fastmath<contract> : f32
+ %1 = math.cos %arg0 fastmath<contract> : f32
+ func.return %0, %1 : f32, f32
+}
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 3f96d90..5ff2920 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -1169,3 +1169,19 @@ func.func @expand_shape_invalid_output_shape(
into memref<2x15x20xf32, strided<[60000, 4000, 2], offset: 100>>
return
}
+
+// -----
+
+func.func @distinct_objects_types_mismatch(%arg0: memref<?xf32>, %arg1: memref<?xi32>) -> (memref<?xi32>, memref<?xf32>) {
+ // expected-error @+1 {{operand types and result types must match}}
+ %0, %1 = "memref.distinct_objects"(%arg0, %arg1) : (memref<?xf32>, memref<?xi32>) -> (memref<?xi32>, memref<?xf32>)
+ return %0, %1 : memref<?xi32>, memref<?xf32>
+}
+
+// -----
+
+func.func @distinct_objects_0_operands() {
+ // expected-error @+1 {{expected at least one operand}}
+ "memref.distinct_objects"() : () -> ()
+ return
+}
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 6c2298a..a90c950 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -302,6 +302,15 @@ func.func @assume_alignment(%0: memref<4x4xf16>) {
return
}
+// CHECK-LABEL: func @distinct_objects
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf16>, %[[ARG1:.*]]: memref<?xf32>, %[[ARG2:.*]]: memref<?xf64>)
+func.func @distinct_objects(%arg0: memref<?xf16>, %arg1: memref<?xf32>, %arg2: memref<?xf64>) -> (memref<?xf16>, memref<?xf32>, memref<?xf64>) {
+ // CHECK: %[[RES:.*]]:3 = memref.distinct_objects %[[ARG0]], %[[ARG1]], %[[ARG2]] : memref<?xf16>, memref<?xf32>, memref<?xf64>
+ %1, %2, %3 = memref.distinct_objects %arg0, %arg1, %arg2 : memref<?xf16>, memref<?xf32>, memref<?xf64>
+ // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : memref<?xf16>, memref<?xf32>, memref<?xf64>
+ return %1, %2, %3 : memref<?xf16>, memref<?xf32>, memref<?xf64>
+}
+
// CHECK-LABEL: func @expand_collapse_shape_static
func.func @expand_collapse_shape_static(
%arg0: memref<3x4x5xf32>,
diff --git a/mlir/test/Dialect/OpenMP/cli-canonical_loop.mlir b/mlir/test/Dialect/OpenMP/cli-canonical_loop.mlir
index adadb8b..0e9385e 100644
--- a/mlir/test/Dialect/OpenMP/cli-canonical_loop.mlir
+++ b/mlir/test/Dialect/OpenMP/cli-canonical_loop.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s | FileCheck %s
-// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s | FileCheck %s --enable-var-scope
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s --enable-var-scope
// CHECK-LABEL: @omp_canonloop_raw(
@@ -24,10 +24,10 @@ func.func @omp_canonloop_raw(%tc : i32) -> () {
func.func @omp_canonloop_sequential_raw(%tc : i32) -> () {
// CHECK-NEXT: %canonloop_s0 = omp.new_cli
%canonloop_s0 = "omp.new_cli" () : () -> (!omp.cli)
- // CHECK-NEXT: omp.canonical_loop(%canonloop_s0) %iv : i32 in range(%[[tc]]) {
+ // CHECK-NEXT: omp.canonical_loop(%canonloop_s0) %iv_s0 : i32 in range(%[[tc]]) {
"omp.canonical_loop" (%tc, %canonloop_s0) ({
^bb_first(%iv_first: i32):
- // CHECK-NEXT: = llvm.add %iv, %iv : i32
+ // CHECK-NEXT: = llvm.add %iv_s0, %iv_s0 : i32
%newval = llvm.add %iv_first, %iv_first : i32
// CHECK-NEXT: omp.terminator
omp.terminator
@@ -36,7 +36,7 @@ func.func @omp_canonloop_sequential_raw(%tc : i32) -> () {
// CHECK-NEXT: %canonloop_s1 = omp.new_cli
%canonloop_s1 = "omp.new_cli" () : () -> (!omp.cli)
- // CHECK-NEXT: omp.canonical_loop(%canonloop_s1) %iv : i32 in range(%[[tc]]) {
+ // CHECK-NEXT: omp.canonical_loop(%canonloop_s1) %iv_s1 : i32 in range(%[[tc]]) {
"omp.canonical_loop" (%tc, %canonloop_s1) ({
^bb_second(%iv_second: i32):
// CHECK: omp.terminator
@@ -52,17 +52,17 @@ func.func @omp_canonloop_sequential_raw(%tc : i32) -> () {
// CHECK-LABEL: @omp_nested_canonloop_raw(
// CHECK-SAME: %[[tc_outer:.+]]: i32, %[[tc_inner:.+]]: i32)
func.func @omp_nested_canonloop_raw(%tc_outer : i32, %tc_inner : i32) -> () {
- // CHECK-NEXT: %canonloop_s0 = omp.new_cli
+ // CHECK-NEXT: %canonloop = omp.new_cli
%outer = "omp.new_cli" () : () -> (!omp.cli)
- // CHECK-NEXT: %canonloop_s0_s0 = omp.new_cli
+ // CHECK-NEXT: %canonloop_d1 = omp.new_cli
%inner = "omp.new_cli" () : () -> (!omp.cli)
- // CHECK-NEXT: omp.canonical_loop(%canonloop_s0) %iv : i32 in range(%[[tc_outer]]) {
+ // CHECK-NEXT: omp.canonical_loop(%canonloop) %iv : i32 in range(%[[tc_outer]]) {
"omp.canonical_loop" (%tc_outer, %outer) ({
^bb_outer(%iv_outer: i32):
- // CHECK-NEXT: omp.canonical_loop(%canonloop_s0_s0) %iv_0 : i32 in range(%[[tc_inner]]) {
+ // CHECK-NEXT: omp.canonical_loop(%canonloop_d1) %iv_d1 : i32 in range(%[[tc_inner]]) {
"omp.canonical_loop" (%tc_inner, %inner) ({
^bb_inner(%iv_inner: i32):
- // CHECK-NEXT: = llvm.add %iv, %iv_0 : i32
+ // CHECK-NEXT: = llvm.add %iv, %iv_d1 : i32
%newval = llvm.add %iv_outer, %iv_inner: i32
// CHECK-NEXT: omp.terminator
omp.terminator
@@ -108,16 +108,24 @@ func.func @omp_canonloop_constant_pretty() -> () {
func.func @omp_canonloop_sequential_pretty(%tc : i32) -> () {
// CHECK-NEXT: %canonloop_s0 = omp.new_cli
%canonloop_s0 = omp.new_cli
- // CHECK-NEXT: omp.canonical_loop(%canonloop_s0) %iv : i32 in range(%[[tc]]) {
- omp.canonical_loop(%canonloop_s0) %iv : i32 in range(%tc) {
+ // CHECK-NEXT: omp.canonical_loop(%canonloop_s0) %iv_s0 : i32 in range(%[[tc]]) {
+ omp.canonical_loop(%canonloop_s0) %iv_s0 : i32 in range(%tc) {
// CHECK-NEXT: omp.terminator
omp.terminator
}
// CHECK: %canonloop_s1 = omp.new_cli
%canonloop_s1 = omp.new_cli
- // CHECK-NEXT: omp.canonical_loop(%canonloop_s1) %iv : i32 in range(%[[tc]]) {
- omp.canonical_loop(%canonloop_s1) %iv_0 : i32 in range(%tc) {
+ // CHECK-NEXT: omp.canonical_loop(%canonloop_s1) %iv_s1 : i32 in range(%[[tc]]) {
+ omp.canonical_loop(%canonloop_s1) %iv_s1 : i32 in range(%tc) {
+ // CHECK-NEXT: omp.terminator
+ omp.terminator
+ }
+
+ // CHECK: %canonloop_s2 = omp.new_cli
+ %canonloop_s2 = omp.new_cli
+ // CHECK-NEXT: omp.canonical_loop(%canonloop_s2) %iv_s2 : i32 in range(%[[tc]]) {
+ omp.canonical_loop(%canonloop_s2) %iv_s2 : i32 in range(%tc) {
// CHECK-NEXT: omp.terminator
omp.terminator
}
@@ -126,17 +134,17 @@ func.func @omp_canonloop_sequential_pretty(%tc : i32) -> () {
}
-// CHECK-LABEL: @omp_canonloop_nested_pretty(
+// CHECK-LABEL: @omp_canonloop_2d_nested_pretty(
// CHECK-SAME: %[[tc:.+]]: i32)
-func.func @omp_canonloop_nested_pretty(%tc : i32) -> () {
- // CHECK-NEXT: %canonloop_s0 = omp.new_cli
- %canonloop_s0 = omp.new_cli
- // CHECK-NEXT: %canonloop_s0_s0 = omp.new_cli
- %canonloop_s0_s0 = omp.new_cli
- // CHECK-NEXT: omp.canonical_loop(%canonloop_s0) %iv : i32 in range(%[[tc]]) {
- omp.canonical_loop(%canonloop_s0) %iv : i32 in range(%tc) {
- // CHECK-NEXT: omp.canonical_loop(%canonloop_s0_s0) %iv_0 : i32 in range(%[[tc]]) {
- omp.canonical_loop(%canonloop_s0_s0) %iv_0 : i32 in range(%tc) {
+func.func @omp_canonloop_2d_nested_pretty(%tc : i32) -> () {
+ // CHECK-NEXT: %canonloop = omp.new_cli
+ %canonloop = omp.new_cli
+ // CHECK-NEXT: %canonloop_d1 = omp.new_cli
+ %canonloop_d1 = omp.new_cli
+ // CHECK-NEXT: omp.canonical_loop(%canonloop) %iv : i32 in range(%[[tc]]) {
+ omp.canonical_loop(%canonloop) %iv : i32 in range(%tc) {
+ // CHECK-NEXT: omp.canonical_loop(%canonloop_d1) %iv_d1 : i32 in range(%[[tc]]) {
+ omp.canonical_loop(%canonloop_d1) %iv_d1 : i32 in range(%tc) {
// CHECK: omp.terminator
omp.terminator
}
@@ -147,6 +155,77 @@ func.func @omp_canonloop_nested_pretty(%tc : i32) -> () {
}
+// CHECK-LABEL: @omp_canonloop_3d_nested_pretty(
+// CHECK-SAME: %[[tc:.+]]: i32)
+func.func @omp_canonloop_3d_nested_pretty(%tc : i32) -> () {
+ // CHECK: %canonloop = omp.new_cli
+ %canonloop = omp.new_cli
+ // CHECK: %canonloop_d1 = omp.new_cli
+ %canonloop_d1 = omp.new_cli
+ // CHECK: %canonloop_d2 = omp.new_cli
+ %canonloop_d2 = omp.new_cli
+ // CHECK-NEXT: omp.canonical_loop(%canonloop) %iv : i32 in range(%[[tc]]) {
+ omp.canonical_loop(%canonloop) %iv : i32 in range(%tc) {
+ // CHECK-NEXT: omp.canonical_loop(%canonloop_d1) %iv_d1 : i32 in range(%[[tc]]) {
+ omp.canonical_loop(%canonloop_d1) %iv_1d : i32 in range(%tc) {
+ // CHECK-NEXT: omp.canonical_loop(%canonloop_d2) %iv_d2 : i32 in range(%[[tc]]) {
+ omp.canonical_loop(%canonloop_d2) %iv_d2 : i32 in range(%tc) {
+ // CHECK-NEXT: omp.terminator
+ omp.terminator
+ // CHECK-NEXT: }
+ }
+ // CHECK-NEXT: omp.terminator
+ omp.terminator
+ // CHECK-NEXT: }
+ }
+ // CHECK-NEXT: omp.terminator
+ omp.terminator
+ }
+
+ return
+}
+
+
+// CHECK-LABEL: @omp_canonloop_sequential_nested_pretty(
+// CHECK-SAME: %[[tc:.+]]: i32)
+func.func @omp_canonloop_sequential_nested_pretty(%tc : i32) -> () {
+ // CHECK-NEXT: %canonloop_s0 = omp.new_cli
+ %canonloop_s0 = omp.new_cli
+ // CHECK-NEXT: %canonloop_s0_d1 = omp.new_cli
+ %canonloop_s0_d1 = omp.new_cli
+ // CHECK-NEXT: omp.canonical_loop(%canonloop_s0) %iv_s0 : i32 in range(%[[tc]]) {
+ omp.canonical_loop(%canonloop_s0) %iv_s0 : i32 in range(%tc) {
+ // CHECK-NEXT: omp.canonical_loop(%canonloop_s0_d1) %iv_s0_d1 : i32 in range(%[[tc]]) {
+ omp.canonical_loop(%canonloop_s0_d1) %iv_s0_d1 : i32 in range(%tc) {
+ // CHECK-NEXT: omp.terminator
+ omp.terminator
+ // CHECK-NEXT: }
+ }
+ // CHECK-NEXT: omp.terminator
+ omp.terminator
+ // CHECK-NEXT: }
+ }
+
+ // CHECK-NEXT: %canonloop_s1 = omp.new_cli
+ %canonloop_s1 = omp.new_cli
+ // CHECK-NEXT: %canonloop_s1_d1 = omp.new_cli
+ %canonloop_s1_d1 = omp.new_cli
+ // CHECK-NEXT: omp.canonical_loop(%canonloop_s1) %iv_s1 : i32 in range(%[[tc]]) {
+ omp.canonical_loop(%canonloop_s1) %iv_s1 : i32 in range(%tc) {
+ // CHECK-NEXT: omp.canonical_loop(%canonloop_s1_d1) %iv_s1_d1 : i32 in range(%[[tc]]) {
+ omp.canonical_loop(%canonloop_s1_d1) %iv_s1d1 : i32 in range(%tc) {
+ // CHECK-NEXT: omp.terminator
+ omp.terminator
+ // CHECK-NEXT: }
+ }
+ // CHECK-NEXT: omp.terminator
+ omp.terminator
+ }
+
+ return
+}
+
+
// CHECK-LABEL: @omp_newcli_unused(
// CHECK-SAME: )
func.func @omp_newcli_unused() -> () {
@@ -155,3 +234,74 @@ func.func @omp_newcli_unused() -> () {
// CHECK-NEXT: return
return
}
+
+
+// CHECK-LABEL: @omp_canonloop_multiregion_isolatedfromabove(
+func.func @omp_canonloop_multiregion_isolatedfromabove() -> () {
+ omp.private {type = firstprivate} @x.privatizer : !llvm.ptr init {
+ ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
+ %c42_i32 = arith.constant 42: i32
+ // CHECK: omp.canonical_loop %iv : i32 in range(%c42_i32) {
+ omp.canonical_loop %iv1 : i32 in range(%c42_i32) {
+ omp.terminator
+ }
+ // CHECK: omp.yield
+ omp.yield(%arg0 : !llvm.ptr)
+ } copy {
+ ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
+ %c42_i32 = arith.constant 42: i32
+ // CHECK: omp.canonical_loop %iv : i32 in range(%c42_i32) {
+ omp.canonical_loop %iv : i32 in range(%c42_i32) {
+ // CHECK: omp.canonical_loop %iv_d1 : i32 in range(%c42_i32) {
+ omp.canonical_loop %iv_d1 : i32 in range(%c42_i32) {
+ omp.terminator
+ }
+ omp.terminator
+ }
+ // CHECK: omp.yield
+ omp.yield(%arg0 : !llvm.ptr)
+ } dealloc {
+ ^bb0(%arg0: !llvm.ptr):
+ %c42_i32 = arith.constant 42: i32
+ // CHECK: omp.canonical_loop %iv_s0 : i32 in range(%c42_i32) {
+ omp.canonical_loop %iv_s0 : i32 in range(%c42_i32) {
+ omp.terminator
+ }
+ // CHECK: omp.canonical_loop %iv_s1 : i32 in range(%c42_i32) {
+ omp.canonical_loop %iv_s1 : i32 in range(%c42_i32) {
+ omp.terminator
+ }
+ // CHECK: omp.yield
+ omp.yield
+ }
+
+ // CHECK: return
+ return
+}
+
+
+// CHECK-LABEL: @omp_canonloop_multiregion(
+func.func @omp_canonloop_multiregion(%c : i1) -> () {
+ %c42_i32 = arith.constant 42: i32
+ %canonloop1 = omp.new_cli
+ %canonloop2 = omp.new_cli
+ %canonloop3 = omp.new_cli
+ scf.if %c {
+ // CHECK: omp.canonical_loop(%canonloop_r0) %iv_r0 : i32 in range(%c42_i32) {
+ omp.canonical_loop(%canonloop1) %iv1 : i32 in range(%c42_i32) {
+ omp.terminator
+ }
+ } else {
+ // CHECK: omp.canonical_loop(%canonloop_r1_s0) %iv_r1_s0 : i32 in range(%c42_i32) {
+ omp.canonical_loop(%canonloop2) %iv2 : i32 in range(%c42_i32) {
+ omp.terminator
+ }
+ // CHECK: omp.canonical_loop(%canonloop_r1_s1) %iv_r1_s1 : i32 in range(%c42_i32) {
+ omp.canonical_loop(%canonloop3) %iv3 : i32 in range(%c42_i32) {
+ omp.terminator
+ }
+ }
+
+ // CHECK: return
+ return
+}
diff --git a/mlir/test/Dialect/OpenMP/cli-tile.mlir b/mlir/test/Dialect/OpenMP/cli-tile.mlir
new file mode 100644
index 0000000..73d5478
--- /dev/null
+++ b/mlir/test/Dialect/OpenMP/cli-tile.mlir
@@ -0,0 +1,138 @@
+// RUN: mlir-opt %s | FileCheck %s --enable-var-scope
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s --enable-var-scope
+
+
+// Raw syntax check (MLIR output is always pretty-printed)
+// CHECK-LABEL: @omp_tile_raw(
+// CHECK-SAME: %[[tc:.+]]: i32, %[[ts:.+]]: i32) {
+func.func @omp_tile_raw(%tc : i32, %ts : i32) -> () {
+ // CHECK-NEXT: %canonloop = omp.new_cli
+ %canonloop = "omp.new_cli" () : () -> (!omp.cli)
+ // CHECK-NEXT: %grid1 = omp.new_cli
+ %grid = "omp.new_cli" () : () -> (!omp.cli)
+ // CHECK-NEXT: %intratile1 = omp.new_cli
+ %intratile = "omp.new_cli" () : () -> (!omp.cli)
+ // CHECK-NEXT: omp.canonical_loop(%canonloop) %iv : i32 in range(%[[tc]]) {
+ "omp.canonical_loop" (%tc, %canonloop) ({
+ ^bb0(%iv: i32):
+ // CHECK: omp.terminator
+ omp.terminator
+ }) : (i32, !omp.cli) -> ()
+ // CHECK: omp.tile (%grid1, %intratile1) <- (%canonloop) sizes(%[[ts]] : i32)
+ "omp.tile"(%grid, %intratile, %canonloop, %ts) <{operandSegmentSizes = array<i32: 2, 1, 1>}> : (!omp.cli, !omp.cli, !omp.cli, i32) -> ()
+ //"omp.tile" (%canonloop) : (!omp.cli) -> ()
+ return
+}
+
+
+// Pretty syntax check
+// CHECK-LABEL: @omp_tile_pretty(
+// CHECK-SAME: %[[tc:.+]]: i32, %[[ts:.+]]: i32) {
+func.func @omp_tile_pretty(%tc : i32, %ts : i32) -> () {
+ // CHECK-NEXT: %[[CANONLOOP:.+]] = omp.new_cli
+ %canonloop = omp.new_cli
+ // CHECK-NEXT: %[[CANONLOOP:.+]] = omp.new_cli
+ %grid = omp.new_cli
+ // CHECK-NEXT: %[[CANONLOOP:.+]] = omp.new_cli
+ %intratile = omp.new_cli
+ // CHECK-NEXT: omp.canonical_loop(%canonloop) %iv : i32 in range(%[[tc]]) {
+ omp.canonical_loop(%canonloop) %iv : i32 in range(%tc) {
+ // CHECK: omp.terminator
+ omp.terminator
+ }
+ // CHECK: omp.tile (%grid1, %intratile1) <- (%canonloop) sizes(%[[ts]] : i32)
+ omp.tile(%grid, %intratile) <- (%canonloop) sizes(%ts : i32)
+ return
+}
+
+
+// Specifying the generatees for omp.tile is optional
+// CHECK-LABEL: @omp_tile_optionalgen_pretty(
+// CHECK-SAME: %[[tc:.+]]: i32, %[[ts:.+]]: i32) {
+func.func @omp_tile_optionalgen_pretty(%tc : i32, %ts : i32) -> () {
+ // CHECK-NEXT: %canonloop = omp.new_cli
+ %canonloop = omp.new_cli
+ // CHECK-NEXT: omp.canonical_loop(%canonloop) %iv : i32 in range(%[[tc]]) {
+ omp.canonical_loop(%canonloop) %iv : i32 in range(%tc) {
+ // CHECK: omp.terminator
+ omp.terminator
+ }
+ // CHECK: omp.tile <- (%canonloop) sizes(%[[ts]] : i32)
+ omp.tile <- (%canonloop) sizes(%ts : i32)
+ return
+}
+
+
+// Two-dimensional tiling
+// CHECK-LABEL: @omp_tile_2d_pretty(
+// CHECK-SAME: %[[tc1:.+]]: i32, %[[tc2:.+]]: i32, %[[ts1:.+]]: i32, %[[ts2:.+]]: i32) {
+func.func @omp_tile_2d_pretty(%tc1 : i32, %tc2 : i32, %ts1 : i32, %ts2 : i32) -> () {
+ // CHECK-NEXT: %canonloop = omp.new_cli
+ %cli_outer = omp.new_cli
+ // CHECK-NEXT: %canonloop_d1 = omp.new_cli
+ %cli_inner = omp.new_cli
+ // CHECK-NEXT: %grid1 = omp.new_cli
+ %grid1 = omp.new_cli
+ // CHECK-NEXT: %grid2 = omp.new_cli
+ %grid2 = omp.new_cli
+ // CHECK-NEXT: %intratile1 = omp.new_cli
+ %intratile1 = omp.new_cli
+ // CHECK-NEXT: %intratile2 = omp.new_cli
+ %intratile2 = omp.new_cli
+ // CHECK-NEXT: omp.canonical_loop(%canonloop) %iv : i32 in range(%[[tc1]]) {
+ omp.canonical_loop(%cli_outer) %iv_outer : i32 in range(%tc1) {
+ // CHECK-NEXT: omp.canonical_loop(%canonloop_d1) %iv_d1 : i32 in range(%[[tc2]]) {
+ omp.canonical_loop(%cli_inner) %iv_inner : i32 in range(%tc2) {
+ // CHECK: omp.terminator
+ omp.terminator
+ }
+ // CHECK: omp.terminator
+ omp.terminator
+ }
+ // CHECK: omp.tile (%grid1, %grid2, %intratile1, %intratile2) <- (%canonloop, %canonloop_d1) sizes(%[[ts1]], %[[ts2]] : i32, i32)
+ omp.tile (%grid1, %grid2, %intratile1, %intratile2) <- (%cli_outer, %cli_inner) sizes(%ts1, %ts2 : i32, i32)
+ return
+}
+
+
+// Three-dimensional tiling
+// CHECK-LABEL: @omp_tile_3d_pretty(
+// CHECK-SAME: %[[tc:.+]]: i32, %[[ts:.+]]: i32) {
+func.func @omp_tile_3d_pretty(%tc : i32, %ts : i32) -> () {
+ // CHECK-NEXT: %canonloop = omp.new_cli
+ %cli_outer = omp.new_cli
+ // CHECK-NEXT: %canonloop_d1 = omp.new_cli
+ %cli_middle = omp.new_cli
+ // CHECK-NEXT: %canonloop_d2 = omp.new_cli
+ %cli_inner = omp.new_cli
+ // CHECK-NEXT: %grid1 = omp.new_cli
+ %grid1 = omp.new_cli
+ // CHECK-NEXT: %grid2 = omp.new_cli
+ %grid2 = omp.new_cli
+ // CHECK-NEXT: %grid3 = omp.new_cli
+ %grid3 = omp.new_cli
+ // CHECK-NEXT: %intratile1 = omp.new_cli
+ %intratile1 = omp.new_cli
+ // CHECK-NEXT: %intratile2 = omp.new_cli
+ %intratile2 = omp.new_cli
+ // CHECK-NEXT: %intratile3 = omp.new_cli
+ %intratile3 = omp.new_cli
+ // CHECK-NEXT: omp.canonical_loop(%canonloop) %iv : i32 in range(%[[tc]]) {
+ omp.canonical_loop(%cli_outer) %iv_outer : i32 in range(%tc) {
+ // CHECK-NEXT: omp.canonical_loop(%canonloop_d1) %iv_d1 : i32 in range(%[[tc]]) {
+ omp.canonical_loop(%cli_middle) %iv_middle : i32 in range(%tc) {
+ // CHECK-NEXT: omp.canonical_loop(%canonloop_d2) %iv_d2 : i32 in range(%[[tc]]) {
+ omp.canonical_loop(%cli_inner) %iv_inner : i32 in range(%tc) {
+ // CHECK: omp.terminator
+ omp.terminator
+ }
+ // CHECK: omp.terminator
+ omp.terminator
+ }
+ // CHECK: omp.terminator
+ omp.terminator
+ }
+ // CHECK: omp.tile (%grid1, %grid2, %grid3, %intratile1, %intratile2, %intratile3) <- (%canonloop, %canonloop_d1, %canonloop_d2) sizes(%[[ts]], %[[ts]], %[[ts]] : i32, i32, i32)
+ omp.tile (%grid1, %grid2, %grid3, %intratile1, %intratile2, %intratile3) <- (%cli_outer, %cli_middle, %cli_inner) sizes(%ts, %ts, %ts: i32, i32, i32)
+ return
+}
diff --git a/mlir/test/Dialect/OpenMP/cli-unroll-heuristic.mlir b/mlir/test/Dialect/OpenMP/cli-unroll-heuristic.mlir
index cda7d0b..16884f4 100644
--- a/mlir/test/Dialect/OpenMP/cli-unroll-heuristic.mlir
+++ b/mlir/test/Dialect/OpenMP/cli-unroll-heuristic.mlir
@@ -1,18 +1,18 @@
-// RUN: mlir-opt %s | FileCheck %s
-// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s | FileCheck %s --enable-var-scope
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s --enable-var-scope
// CHECK-LABEL: @omp_unroll_heuristic_raw(
// CHECK-SAME: %[[tc:.+]]: i32) {
func.func @omp_unroll_heuristic_raw(%tc : i32) -> () {
- // CHECK-NEXT: %canonloop_s0 = omp.new_cli
+ // CHECK-NEXT: %canonloop = omp.new_cli
%canonloop = "omp.new_cli" () : () -> (!omp.cli)
- // CHECK-NEXT: omp.canonical_loop(%canonloop_s0) %iv : i32 in range(%[[tc]]) {
+ // CHECK-NEXT: omp.canonical_loop(%canonloop) %iv : i32 in range(%[[tc]]) {
"omp.canonical_loop" (%tc, %canonloop) ({
^bb0(%iv: i32):
omp.terminator
}) : (i32, !omp.cli) -> ()
- // CHECK: omp.unroll_heuristic(%canonloop_s0)
+ // CHECK: omp.unroll_heuristic(%canonloop)
"omp.unroll_heuristic" (%canonloop) : (!omp.cli) -> ()
return
}
@@ -22,12 +22,12 @@ func.func @omp_unroll_heuristic_raw(%tc : i32) -> () {
// CHECK-SAME: %[[tc:.+]]: i32) {
func.func @omp_unroll_heuristic_pretty(%tc : i32) -> () {
// CHECK-NEXT: %[[CANONLOOP:.+]] = omp.new_cli
- %canonloop = "omp.new_cli" () : () -> (!omp.cli)
- // CHECK-NEXT: omp.canonical_loop(%canonloop_s0) %iv : i32 in range(%[[tc]]) {
+ %canonloop = omp.new_cli
+ // CHECK-NEXT: omp.canonical_loop(%canonloop) %iv : i32 in range(%[[tc]]) {
omp.canonical_loop(%canonloop) %iv : i32 in range(%tc) {
omp.terminator
}
- // CHECK: omp.unroll_heuristic(%canonloop_s0)
+ // CHECK: omp.unroll_heuristic(%canonloop)
omp.unroll_heuristic(%canonloop)
return
}
@@ -36,13 +36,13 @@ func.func @omp_unroll_heuristic_pretty(%tc : i32) -> () {
// CHECK-LABEL: @omp_unroll_heuristic_nested_pretty(
// CHECK-SAME: %[[tc:.+]]: i32) {
func.func @omp_unroll_heuristic_nested_pretty(%tc : i32) -> () {
- // CHECK-NEXT: %canonloop_s0 = omp.new_cli
+ // CHECK-NEXT: %canonloop = omp.new_cli
%cli_outer = omp.new_cli
- // CHECK-NEXT: %canonloop_s0_s0 = omp.new_cli
+ // CHECK-NEXT: %canonloop_d1 = omp.new_cli
%cli_inner = omp.new_cli
- // CHECK-NEXT: omp.canonical_loop(%canonloop_s0) %iv : i32 in range(%[[tc]]) {
+ // CHECK-NEXT: omp.canonical_loop(%canonloop) %iv : i32 in range(%[[tc]]) {
omp.canonical_loop(%cli_outer) %iv_outer : i32 in range(%tc) {
- // CHECK-NEXT: omp.canonical_loop(%canonloop_s0_s0) %iv_0 : i32 in range(%[[tc]]) {
+ // CHECK-NEXT: omp.canonical_loop(%canonloop_d1) %iv_d1 : i32 in range(%[[tc]]) {
omp.canonical_loop(%cli_inner) %iv_inner : i32 in range(%tc) {
// CHECK: omp.terminator
omp.terminator
@@ -51,9 +51,9 @@ func.func @omp_unroll_heuristic_nested_pretty(%tc : i32) -> () {
omp.terminator
}
- // CHECK: omp.unroll_heuristic(%canonloop_s0)
+ // CHECK: omp.unroll_heuristic(%canonloop)
omp.unroll_heuristic(%cli_outer)
- // CHECK-NEXT: omp.unroll_heuristic(%canonloop_s0_s0)
+ // CHECK-NEXT: omp.unroll_heuristic(%canonloop_d1)
omp.unroll_heuristic(%cli_inner)
return
}
diff --git a/mlir/test/Dialect/OpenMP/invalid-tile.mlir b/mlir/test/Dialect/OpenMP/invalid-tile.mlir
new file mode 100644
index 0000000..e63a062
--- /dev/null
+++ b/mlir/test/Dialect/OpenMP/invalid-tile.mlir
@@ -0,0 +1,119 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s
+
+
+func.func @missing_sizes(%tc : i32, %ts : i32) {
+ %canonloop = omp.new_cli
+ omp.canonical_loop(%canonloop) %iv : i32 in range(%tc) {
+ omp.terminator
+ }
+
+ // expected-error@+1 {{'omp.tile' op there must be one tile size for each applyee}}
+ omp.tile <-(%canonloop)
+
+ llvm.return
+}
+
+// -----
+
+func.func @no_loop(%tc : i32, %ts : i32) {
+ // expected-error@+1 {{'omp.tile' op must apply to at least one loop}}
+ omp.tile <-()
+
+ return
+}
+
+// -----
+
+func.func @missing_generator(%tc : i32, %ts : i32) {
+ // expected-error@+1 {{'omp.new_cli' op CLI has no generator}}
+ %canonloop = omp.new_cli
+
+ // expected-note@+1 {{see consumer here: "omp.tile"(%0, %arg1) <{operandSegmentSizes = array<i32: 0, 1, 1>}> : (!omp.cli, i32) -> ()}}
+ omp.tile <-(%canonloop) sizes(%ts : i32)
+
+ return
+}
+
+// -----
+
+func.func @insufficient_sizes(%tc : i32, %ts : i32) {
+ %canonloop1 = omp.new_cli
+ %canonloop2 = omp.new_cli
+ omp.canonical_loop(%canonloop1) %iv : i32 in range(%tc) {
+ omp.terminator
+ }
+ omp.canonical_loop(%canonloop2) %iv : i32 in range(%tc) {
+ omp.terminator
+ }
+
+ // expected-error@+1 {{'omp.tile' op there must be one tile size for each applyee}}
+ omp.tile <-(%canonloop1, %canonloop2) sizes(%ts : i32)
+
+ llvm.return
+}
+
+// -----
+
+func.func @insufficient_applyees(%tc : i32, %ts : i32) {
+ %canonloop = omp.new_cli
+ omp.canonical_loop(%canonloop) %iv : i32 in range(%tc) {
+ omp.terminator
+ }
+
+ // expected-error@+1 {{omp.tile' op there must be one tile size for each applyee}}
+ omp.tile <- (%canonloop) sizes(%ts, %ts : i32, i32)
+
+ return
+}
+
+// -----
+
+func.func @insufficient_generatees(%tc : i32, %ts : i32) {
+ %canonloop = omp.new_cli
+ %grid = omp.new_cli
+ omp.canonical_loop(%canonloop) %iv : i32 in range(%tc) {
+ omp.terminator
+ }
+
+ // expected-error@+1 {{'omp.tile' op expecting two times the number of generatees than applyees}}
+ omp.tile (%grid) <- (%canonloop) sizes(%ts : i32)
+
+ return
+}
+
+// -----
+
+func.func @not_perfectly_nested(%tc : i32, %ts : i32) {
+ %canonloop1 = omp.new_cli
+ %canonloop2 = omp.new_cli
+ omp.canonical_loop(%canonloop1) %iv1 : i32 in range(%tc) {
+ %v = arith.constant 42 : i32
+ omp.canonical_loop(%canonloop2) %iv2 : i32 in range(%tc) {
+ omp.terminator
+ }
+ omp.terminator
+ }
+
+ // expected-error@+1 {{'omp.tile' op tiled loop nest must be perfectly nested}}
+ omp.tile <-(%canonloop1, %canonloop2) sizes(%ts, %ts : i32, i32)
+
+ llvm.return
+}
+
+// -----
+
+func.func @non_nectangular(%tc : i32, %ts : i32) {
+ %canonloop1 = omp.new_cli
+ %canonloop2 = omp.new_cli
+ omp.canonical_loop(%canonloop1) %iv1 : i32 in range(%tc) {
+ omp.canonical_loop(%canonloop2) %iv2 : i32 in range(%iv1) {
+ omp.terminator
+ }
+ omp.terminator
+ }
+
+ // expected-error@+1 {{'omp.tile' op tiled loop nest must be rectangular}}
+ omp.tile <-(%canonloop1, %canonloop2) sizes(%ts, %ts : i32, i32)
+
+ llvm.return
+}
diff --git a/mlir/test/Dialect/Transform/test-promote-tensors.mlir b/mlir/test/Dialect/Transform/test-promote-tensors.mlir
new file mode 100644
index 0000000..bc9a05a
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-promote-tensors.mlir
@@ -0,0 +1,104 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
+
+// CHECK-LABEL: @promote_in0
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x42xf32>, %{{.*}}, %{{.*}})
+// CHECK: %[[C0:.+]] = arith.constant 0
+// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK: %[[ALLOC:.+]] = bufferization.alloc_tensor(%[[DIM]]) {memory_space = 1 : i64}
+// CHECK: %[[MAT:.+]] = bufferization.materialize_in_destination %[[ARG0]] in %[[ALLOC]]
+// CHECK: linalg.matmul ins(%[[MAT]], %{{.*}}
+func.func @promote_in0(%arg0: tensor<?x42xf32>, %arg1: tensor<42x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.matmul ins(%arg0, %arg1: tensor<?x42xf32>, tensor<42x?xf32>)
+ outs(%arg2: tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%root: !transform.any_op) {
+ %mm = transform.structured.match ops{["linalg.matmul"]} in %root
+ : (!transform.any_op) -> !transform.any_op
+ %op0 = transform.get_operand %mm[0]
+ : (!transform.any_op) -> !transform.any_value
+ transform.structured.promote_tensor to 1 %op0 : !transform.any_value
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @promote_out
+// CHECK-SAME: (%{{.*}}: tensor<?x42xf32>, %{{.*}}: tensor<?x42xf32>, %[[ARG2:.+]]: tensor<?x?xf32>)
+func.func @promote_out(%arg0: tensor<?x42xf32>, %arg1: tensor<?x42xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // CHECK: %[[C0:.+]] = arith.constant 0
+ // CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG2]], %[[C0]]
+ // CHECK: %[[C1:.+]] = arith.constant 1
+ // CHECK: %[[DIM1:.+]] = tensor.dim %[[ARG2]], %[[C1]]
+ // CHECK: %[[ALLOC:.+]] = bufferization.alloc_tensor(%[[DIM0]], %[[DIM1]]) {memory_space = 1 : i64}
+ // CHECK-NOT: materialize_in_destination
+ // CHECK: linalg.add {{.*}} outs(%[[ALLOC]]
+ %0 = linalg.add ins(%arg0, %arg1 : tensor<?x42xf32>, tensor<?x42xf32>)
+ outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%root: !transform.any_op) {
+ %la = transform.structured.match ops{["linalg.add"]} in %root
+ : (!transform.any_op) -> !transform.any_op
+ %init = transform.get_operand %la[2]
+ : (!transform.any_op) -> !transform.any_value
+ transform.structured.promote_tensor to 1 %init : !transform.any_value
+
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @promote_in0_out_bufferize
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x42xf32>, %{{.*}}: tensor<42x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>)
+func.func @promote_in0_out_bufferize(%arg0: tensor<?x42xf32>, %arg1: tensor<42x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // CHECK: %[[IN1:.+]] = bufferization.to_buffer %arg1 : tensor<42x?xf32> to memref<42x?xf32, strided<[?, ?], offset: ?>>
+ // CHECK: %[[IN0:.+]] = bufferization.to_buffer %arg0 : tensor<?x42xf32> to memref<?x42xf32, strided<[?, ?], offset: ?>>
+ // CHECK: %{{.+}} = bufferization.to_buffer %arg0 : tensor<?x42xf32> to memref<?x42xf32, strided<[?, ?], offset: ?>>
+ // CHECK: %{{.+}} = bufferization.to_buffer %arg2 : tensor<?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+ // CHECK: %{{.+}} = bufferization.to_buffer %arg2 : tensor<?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
+ // CHECK: %{{.+}} = memref.dim %{{.+}}, %[[C0]] : memref<?x?xf32, strided<[?, ?], offset: ?>>
+ // CHECK: %[[C1:.+]] = arith.constant 1 : index
+ // CHECK: %{{.+}} = memref.dim %{{.+}}, %[[C1]] : memref<?x?xf32, strided<[?, ?], offset: ?>>
+ // CHECK: %[[ALLOC_OUT:.+]] = memref.alloc(%{{.+}}, %{{.+}}) {alignment = 64 : i64} : memref<?x?xf32, 1>
+ // CHECK: %{{.+}} = arith.constant 0 : index
+ // CHECK: %{{.+}} = memref.dim %{{.+}}, %{{.+}} : memref<?x42xf32, strided<[?, ?], offset: ?>>
+ // CHECK: %[[ALLOC_IN:.+]] = memref.alloc(%{{.+}}) {alignment = 64 : i64} : memref<?x42xf32, 1>
+ // CHECK: memref.copy %[[IN0]], %[[ALLOC_IN]] : memref<?x42xf32, strided<[?, ?], offset: ?>> to memref<?x42xf32, 1>
+ // CHECK: linalg.add ins(%[[ALLOC_IN]], %[[IN1]] : memref<?x42xf32, 1>, memref<42x?xf32, strided<[?, ?], offset: ?>>) outs(%[[ALLOC_OUT]] : memref<?x?xf32, 1>)
+ %0 = linalg.add ins(%arg0, %arg1: tensor<?x42xf32>, tensor<42x?xf32>)
+ outs(%arg2: tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%root: !transform.any_op) {
+ %la = transform.structured.match ops{["linalg.add"]} in %root
+ : (!transform.any_op) -> !transform.any_op
+ %op0 = transform.get_operand %la[0]
+ : (!transform.any_op) -> !transform.any_value
+ transform.structured.promote_tensor to 1 %op0 : !transform.any_value
+
+ %init = transform.get_operand %la[2]
+ : (!transform.any_op) -> !transform.any_value
+ transform.structured.promote_tensor to 1 %init : !transform.any_value
+
+ %func = transform.structured.match ops{["func.func"]} in %root
+ : (!transform.any_op) -> !transform.any_op
+
+ %bufferized = transform.bufferization.one_shot_bufferize %func
+ : (!transform.any_op) -> !transform.any_op
+
+ transform.yield
+ }
+}
+
+
+
diff --git a/mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir b/mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir
index 2e5f433..efc3890 100644
--- a/mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir
+++ b/mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir
@@ -19,3 +19,88 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+func.func private @f()
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ // expected-error@below {{'selected_region' attribute specifies region at index 2 while op has only 2 regions}}
+ transform.tune.alternatives<"bifurcation"> selected_region = 2 {
+ transform.yield
+ }, {
+ transform.yield
+ }
+ transform.yield
+ }
+}
+
+// -----
+
+func.func private @f()
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %singleton_of_c0 = transform.param.constant [0] -> !transform.any_param
+ // expected-error@below {{param should hold exactly one integer attribute, got: [0]}}
+ transform.tune.alternatives<"bifurcation"> selected_region = %singleton_of_c0 : !transform.any_param {
+ transform.yield
+ }, {
+ transform.yield
+ }
+ transform.yield
+ }
+}
+
+// -----
+
+func.func private @f()
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %c0 = transform.param.constant 0 -> !transform.any_param
+ %c1 = transform.param.constant 1 -> !transform.any_param
+ %c0_and_c1 = transform.merge_handles %c0, %c1 : !transform.any_param
+ // expected-error@below {{param should hold exactly one integer attribute}}
+ transform.tune.alternatives<"bifurcation"> selected_region = %c0_and_c1 : !transform.any_param {
+ transform.yield
+ }, {
+ transform.yield
+ }
+ transform.yield
+ }
+}
+
+// -----
+
+func.func private @f()
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %c2 = transform.param.constant 2 -> !transform.any_param
+ // expected-error@below {{'selected_region' attribute/param specifies region at index 2 while op has only 2 regions}}
+ transform.tune.alternatives<"bifurcation"> selected_region = %c2 : !transform.any_param {
+ transform.yield
+ }, {
+ transform.yield
+ }
+ transform.yield
+ }
+}
+
+// -----
+
+func.func private @f()
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ // expected-error@below {{non-deterministic choice "bifurcation" is only resolved through providing a `selected_region` attr/param}}
+ transform.tune.alternatives<"bifurcation"> {
+ transform.yield
+ }, {
+ transform.yield
+ }
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Transform/test-tune-extension.mlir b/mlir/test/Dialect/Transform/test-tune-extension.mlir
index 0a253c6..5da48a2 100644
--- a/mlir/test/Dialect/Transform/test-tune-extension.mlir
+++ b/mlir/test/Dialect/Transform/test-tune-extension.mlir
@@ -59,3 +59,129 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+
+// -----
+
+// CHECK-LABEL: schedule_with_two_independent_choices_already_made
+func.func @schedule_with_two_independent_choices_already_made(
+ %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
+ -> tensor<128x128xf32> {
+// CHECK-NOT: scf.forall
+// CHECK: scf.for
+// CHECK-NOT: scf.for
+// CHECK: scf.forall
+// CHECK-NOT: scf.for
+// CHECK: tensor.extract_slice
+// CHECK: tensor.extract_slice
+// CHECK: tensor.extract_slice
+// CHECK: linalg.matmul
+// CHECK: scf.forall.in_parallel
+// CHECK: tensor.parallel_insert_slice
+// CHECK: tensor.insert_slice
+// CHECK: scf.yield
+ %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%arg2: tensor<128x128xf32>) -> tensor<128x128xf32>
+ return %0 : tensor<128x128xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+
+ %tiled_matmul = transform.tune.alternatives<"outer_par_or_seq_tiling"> selected_region = 0 -> !transform.any_op
+ { // First alternative/region, with index = 0
+ %contained_matmul, %loop = transform.structured.tile_using_for %matmul tile_sizes [8] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield %contained_matmul : !transform.any_op
+ }, { // Second alternative/region, with index = 1
+ %contained_matmul, %loop = transform.structured.tile_using_forall %matmul tile_sizes [8] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield %contained_matmul : !transform.any_op
+ }
+
+ transform.tune.alternatives<"inner_par_or_seq_tiling"> selected_region = 1 -> !transform.any_op {
+ %contained_matmul, %loop = transform.structured.tile_using_for %tiled_matmul tile_sizes [0, 16] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield %contained_matmul : !transform.any_op
+ }, {
+ %contained_matmul, %loop = transform.structured.tile_using_forall %tiled_matmul tile_sizes [0, 16] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield %contained_matmul : !transform.any_op
+ }
+
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: subschedule_with_choice_resolved_in_main_schedule
+func.func @subschedule_with_choice_resolved_in_main_schedule(
+ %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
+ -> tensor<128x128xf32> {
+// CHECK-NOT: scf.for
+// CHECK: scf.forall
+// CHECK-NOT: scf.forall
+// CHECK: scf.for
+// CHECK-NOT: scf.forall
+// CHECK: tensor.extract_slice
+// CHECK: tensor.extract_slice
+// CHECK: tensor.extract_slice
+// CHECK: linalg.matmul
+// CHECK: tensor.insert_slice
+// CHECK: scf.yield
+// CHECK: scf.forall.in_parallel
+// CHECK: tensor.parallel_insert_slice
+ %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%arg2: tensor<128x128xf32>) -> tensor<128x128xf32>
+ return %0 : tensor<128x128xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @subschedule_with_embedded_choice(%matmul: !transform.any_op {transform.readonly},
+ %par_or_seq: !transform.param<i64> {transform.readonly},
+ %tile_size: !transform.param<i64> {transform.readonly}) -> !transform.any_op {
+ %tiled_matmul = transform.tune.alternatives<"par_or_seq_tiling"> selected_region = %par_or_seq : !transform.param<i64> -> !transform.any_op {
+ %contained_matmul, %loop = transform.structured.tile_using_for %matmul tile_sizes [%tile_size] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
+ transform.yield %contained_matmul : !transform.any_op
+ }, {
+ %contained_matmul, %loop = transform.structured.tile_using_forall %matmul tile_sizes [%tile_size] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
+ transform.yield %contained_matmul : !transform.any_op
+ }
+ transform.yield %tiled_matmul : !transform.any_op
+ }
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %outer_par = transform.param.constant 1 -> !transform.param<i64>
+ %outer_tile_size = transform.param.constant 32 -> !transform.param<i64>
+ %inner_seq = transform.tune.knob<"inner_par_or_seq"> = 0 from options = [0, 1] -> !transform.param<i64>
+ %inner_tile_size = transform.param.constant 8 -> !transform.param<i64>
+ %tiled_matmul = transform.include @subschedule_with_embedded_choice failures(propagate) (%matmul, %outer_par, %outer_tile_size) : (!transform.any_op, !transform.param<i64>, !transform.param<i64>) -> !transform.any_op
+ %tiled_tiled_matmul = transform.include @subschedule_with_embedded_choice failures(propagate) (%tiled_matmul, %inner_seq, %inner_tile_size) : (!transform.any_op, !transform.param<i64>, !transform.param<i64>) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: eeny_meeny_miny_moe
+func.func private @eeny_meeny_miny_moe()
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+
+ %tiled_matmul = transform.tune.alternatives<"4way"> selected_region = 3 -> !transform.any_param
+ { // First alternative/region, with index = 0
+ %out = transform.param.constant "eeny" -> !transform.any_param
+ transform.yield %out : !transform.any_param
+ }, { // Second alternative/region, with index = 1
+ %out = transform.param.constant "meeny" -> !transform.any_param
+ transform.yield %out : !transform.any_param
+ }, { // Third alternative/region, with index = 2
+ %out = transform.param.constant "miny" -> !transform.any_param
+ transform.yield %out : !transform.any_param
+ }, { // Fourth alternative/region, with index = 3
+ %out = transform.param.constant "moe" -> !transform.any_param
+ transform.yield %out : !transform.any_param
+ }
+ transform.yield
+ }
+} \ No newline at end of file
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 03c6386..38392fd 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -282,15 +282,20 @@ gpu.module @test_distribution {
// CHECK-LABEL: @store_scatter
// CHECK-SAME: %[[ARG0:.*]]: memref<256xf16>
gpu.func @store_scatter(%dest : memref<256xf16>) {
- // CHECK: %[[VAL:.*]] = arith.constant dense<2.550000e+01> : vector<8xf16>
- // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xindex>
- // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8xi1>
+ // CHECK: %[[VAL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8]>} dense<2.550000e+01> : vector<8xf16>
+ // CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8]>} dense<0> : vector<8xindex>
+ // CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8]>} dense<true> : vector<8xi1>
// CHECK: xegpu.store %[[VAL]], %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}>
+ // CHECK-SAME: {layout_operand_0 = #xegpu.layout<inst_data = [8]>, layout_operand_2 = #xegpu.layout<inst_data = [8]>,
+ // CHECK-SAME: layout_operand_3 = #xegpu.layout<inst_data = [8]>}
// CHECK-SAME: : vector<8xf16>, memref<256xf16>, vector<8xindex>, vector<8xi1>
- %val = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<25.5> : vector<256xf16>
- %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<0> : vector<256xindex>
- %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<1> : vector<256xi1>
- xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout = #xegpu.layout<sg_layout = [32], sg_data = [8]>, l1_hint = #xegpu.cache_hint<cached>}
+ %val = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>} dense<25.5> : vector<256xf16>
+ %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>} dense<0> : vector<256xindex>
+ %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>} dense<1> : vector<256xi1>
+ xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout_operand_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>,
+ layout_operand_2 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>,
+ layout_operand_3 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>,
+ l1_hint = #xegpu.cache_hint<cached>}
: vector<256xf16>, memref<256xf16>, vector<256xindex>, vector<256xi1>
gpu.return
}
diff --git a/mlir/test/Target/LLVMIR/openmp-cli-tile01.mlir b/mlir/test/Target/LLVMIR/openmp-cli-tile01.mlir
new file mode 100644
index 0000000..0d559b6
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-cli-tile01.mlir
@@ -0,0 +1,94 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s --enable-var-scope
+
+
+llvm.func @tile_trivial_loop(%baseptr: !llvm.ptr, %tc: i32, %ts: i32) -> () {
+ %literal_cli = omp.new_cli
+ omp.canonical_loop(%literal_cli) %iv : i32 in range(%tc) {
+ %ptr = llvm.getelementptr inbounds %baseptr[%iv] : (!llvm.ptr, i32) -> !llvm.ptr, f32
+ %val = llvm.mlir.constant(42.0 : f32) : f32
+ llvm.store %val, %ptr : f32, !llvm.ptr
+ omp.terminator
+ }
+ omp.tile <- (%literal_cli) sizes(%ts : i32)
+ llvm.return
+}
+
+
+// CHECK-LABEL: define void @tile_trivial_loop(
+// CHECK-SAME: ptr %[[TMP0:.+]], i32 %[[TMP1:.+]], i32 %[[TMP2:.+]]) {
+// CHECK-NEXT: br label %[[OMP_OMP_LOOP_PREHEADER:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_OMP_LOOP_PREHEADER]]:
+// CHECK-NEXT: %[[TMP4:.+]] = udiv i32 %[[TMP1:.+]], %[[TMP2:.+]]
+// CHECK-NEXT: %[[TMP5:.+]] = urem i32 %[[TMP1:.+]], %[[TMP2:.+]]
+// CHECK-NEXT: %[[TMP6:.+]] = icmp ne i32 %[[TMP5:.+]], 0
+// CHECK-NEXT: %[[TMP7:.+]] = zext i1 %[[TMP6:.+]] to i32
+// CHECK-NEXT: %[[OMP_FLOOR0_TRIPCOUNT:.+]] = add nuw i32 %[[TMP4:.+]], %[[TMP7:.+]]
+// CHECK-NEXT: br label %[[OMP_FLOOR0_PREHEADER:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_FLOOR0_PREHEADER]]:
+// CHECK-NEXT: br label %[[OMP_FLOOR0_HEADER:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_FLOOR0_HEADER]]:
+// CHECK-NEXT: %[[OMP_FLOOR0_IV:.+]] = phi i32 [ 0, %[[OMP_FLOOR0_PREHEADER:.+]] ], [ %[[OMP_FLOOR0_NEXT:.+]], %[[OMP_FLOOR0_INC:.+]] ]
+// CHECK-NEXT: br label %[[OMP_FLOOR0_COND:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_FLOOR0_COND]]:
+// CHECK-NEXT: %[[OMP_FLOOR0_CMP:.+]] = icmp ult i32 %[[OMP_FLOOR0_IV:.+]], %[[OMP_FLOOR0_TRIPCOUNT:.+]]
+// CHECK-NEXT: br i1 %[[OMP_FLOOR0_CMP:.+]], label %[[OMP_FLOOR0_BODY:.+]], label %[[OMP_FLOOR0_EXIT:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_FLOOR0_BODY]]:
+// CHECK-NEXT: %[[TMP8:.+]] = icmp eq i32 %[[OMP_FLOOR0_IV:.+]], %[[TMP4:.+]]
+// CHECK-NEXT: %[[TMP9:.+]] = select i1 %[[TMP8:.+]], i32 %[[TMP5:.+]], i32 %[[TMP2:.+]]
+// CHECK-NEXT: br label %[[OMP_TILE0_PREHEADER:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_TILE0_PREHEADER]]:
+// CHECK-NEXT: br label %[[OMP_TILE0_HEADER:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_TILE0_HEADER]]:
+// CHECK-NEXT: %[[OMP_TILE0_IV:.+]] = phi i32 [ 0, %[[OMP_TILE0_PREHEADER:.+]] ], [ %[[OMP_TILE0_NEXT:.+]], %[[OMP_TILE0_INC:.+]] ]
+// CHECK-NEXT: br label %[[OMP_TILE0_COND:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_TILE0_COND]]:
+// CHECK-NEXT: %[[OMP_TILE0_CMP:.+]] = icmp ult i32 %[[OMP_TILE0_IV:.+]], %[[TMP9:.+]]
+// CHECK-NEXT: br i1 %[[OMP_TILE0_CMP:.+]], label %[[OMP_TILE0_BODY:.+]], label %[[OMP_TILE0_EXIT:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_TILE0_BODY]]:
+// CHECK-NEXT: %[[TMP10:.+]] = mul nuw i32 %[[TMP2:.+]], %[[OMP_FLOOR0_IV:.+]]
+// CHECK-NEXT: %[[TMP11:.+]] = add nuw i32 %[[TMP10:.+]], %[[OMP_TILE0_IV:.+]]
+// CHECK-NEXT: br label %[[OMP_OMP_LOOP_BODY:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_OMP_LOOP_BODY]]:
+// CHECK-NEXT: br label %[[OMP_LOOP_REGION:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_LOOP_REGION]]:
+// CHECK-NEXT: %[[TMP12:.+]] = getelementptr inbounds float, ptr %[[TMP0:.+]], i32 %[[TMP11:.+]]
+// CHECK-NEXT: store float 4.200000e+01, ptr %[[TMP12:.+]], align 4
+// CHECK-NEXT: br label %[[OMP_REGION_CONT:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_REGION_CONT]]:
+// CHECK-NEXT: br label %[[OMP_TILE0_INC:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_TILE0_INC]]:
+// CHECK-NEXT: %[[OMP_TILE0_NEXT:.+]] = add nuw i32 %[[OMP_TILE0_IV:.+]], 1
+// CHECK-NEXT: br label %[[OMP_TILE0_HEADER:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_TILE0_EXIT]]:
+// CHECK-NEXT: br label %[[OMP_TILE0_AFTER:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_TILE0_AFTER]]:
+// CHECK-NEXT: br label %[[OMP_FLOOR0_INC:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_FLOOR0_INC]]:
+// CHECK-NEXT: %[[OMP_FLOOR0_NEXT:.+]] = add nuw i32 %[[OMP_FLOOR0_IV:.+]], 1
+// CHECK-NEXT: br label %[[OMP_FLOOR0_HEADER:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_FLOOR0_EXIT]]:
+// CHECK-NEXT: br label %[[OMP_FLOOR0_AFTER:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_FLOOR0_AFTER]]:
+// CHECK-NEXT: br label %[[OMP_OMP_LOOP_AFTER:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_OMP_LOOP_AFTER]]:
+// CHECK-NEXT: ret void
+// CHECK-NEXT: }
diff --git a/mlir/test/Target/LLVMIR/openmp-cli-tile02.mlir b/mlir/test/Target/LLVMIR/openmp-cli-tile02.mlir
new file mode 100644
index 0000000..22c2973
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-cli-tile02.mlir
@@ -0,0 +1,184 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s --enable-var-scope
+
+
+llvm.func @tile_2d_loop(%baseptr: !llvm.ptr, %tc1: i32, %tc2: i32, %ts1: i32, %ts2: i32) -> () {
+ %literal_outer = omp.new_cli
+ %literal_inner = omp.new_cli
+ omp.canonical_loop(%literal_outer) %iv1 : i32 in range(%tc1) {
+ omp.canonical_loop(%literal_inner) %iv2 : i32 in range(%tc2) {
+ %idx = llvm.add %iv1, %iv2 : i32
+ %ptr = llvm.getelementptr inbounds %baseptr[%idx] : (!llvm.ptr, i32) -> !llvm.ptr, f32
+ %val = llvm.mlir.constant(42.0 : f32) : f32
+ llvm.store %val, %ptr : f32, !llvm.ptr
+ omp.terminator
+ }
+ omp.terminator
+ }
+ omp.tile <- (%literal_outer, %literal_inner) sizes(%ts1, %ts2 : i32,i32)
+ llvm.return
+}
+
+
+// CHECK-LABEL: define void @tile_2d_loop(
+// CHECK-SAME: ptr %[[TMP0:.+]], i32 %[[TMP1:.+]], i32 %[[TMP2:.+]], i32 %[[TMP3:.+]], i32 %[[TMP4:.+]]) {
+// CHECK-NEXT: br label %[[OMP_OMP_LOOP_PREHEADER:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_OMP_LOOP_PREHEADER]]:
+// CHECK-NEXT: %[[TMP6:.+]] = udiv i32 %[[TMP1:.+]], %[[TMP3:.+]]
+// CHECK-NEXT: %[[TMP7:.+]] = urem i32 %[[TMP1:.+]], %[[TMP3:.+]]
+// CHECK-NEXT: %[[TMP8:.+]] = icmp ne i32 %[[TMP7:.+]], 0
+// CHECK-NEXT: %[[TMP9:.+]] = zext i1 %[[TMP8:.+]] to i32
+// CHECK-NEXT: %[[OMP_FLOOR0_TRIPCOUNT:.+]] = add nuw i32 %[[TMP6:.+]], %[[TMP9:.+]]
+// CHECK-NEXT: %[[TMP10:.+]] = udiv i32 %[[TMP2:.+]], %[[TMP4:.+]]
+// CHECK-NEXT: %[[TMP11:.+]] = urem i32 %[[TMP2:.+]], %[[TMP4:.+]]
+// CHECK-NEXT: %[[TMP12:.+]] = icmp ne i32 %[[TMP11:.+]], 0
+// CHECK-NEXT: %[[TMP13:.+]] = zext i1 %[[TMP12:.+]] to i32
+// CHECK-NEXT: %[[OMP_FLOOR1_TRIPCOUNT:.+]] = add nuw i32 %[[TMP10:.+]], %[[TMP13:.+]]
+// CHECK-NEXT: br label %[[OMP_FLOOR0_PREHEADER:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_OMP_LOOP_HEADER:.+]]:
+// CHECK-NEXT: %[[OMP_OMP_LOOP_IV:.+]] = phi i32 [ %[[OMP_OMP_LOOP_NEXT:.+]], %[[OMP_OMP_LOOP_INC:.+]] ]
+// CHECK-NEXT: br label %[[OMP_OMP_LOOP_COND:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_OMP_LOOP_COND]]:
+// CHECK-NEXT: %[[OMP_OMP_LOOP_CMP:.+]] = icmp ult i32 %[[TMP19:.+]], %[[TMP1:.+]]
+// CHECK-NEXT: br i1 %[[OMP_OMP_LOOP_CMP:.+]], label %[[OMP_OMP_LOOP_BODY:.+]], label %[[OMP_OMP_LOOP_EXIT:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_OMP_LOOP_BODY]]:
+// CHECK-NEXT: br label %[[OMP_LOOP_REGION:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_LOOP_REGION]]:
+// CHECK-NEXT: br label %[[OMP_OMP_LOOP_PREHEADER1:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_OMP_LOOP_PREHEADER1]]:
+// CHECK-NEXT: br label %[[OMP_OMP_LOOP_BODY4:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_FLOOR0_PREHEADER]]:
+// CHECK-NEXT: br label %[[OMP_FLOOR0_HEADER:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_FLOOR0_HEADER]]:
+// CHECK-NEXT: %[[OMP_FLOOR0_IV:.+]] = phi i32 [ 0, %[[OMP_FLOOR0_PREHEADER:.+]] ], [ %[[OMP_FLOOR0_NEXT:.+]], %[[OMP_FLOOR0_INC:.+]] ]
+// CHECK-NEXT: br label %[[OMP_FLOOR0_COND:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_FLOOR0_COND]]:
+// CHECK-NEXT: %[[OMP_FLOOR0_CMP:.+]] = icmp ult i32 %[[OMP_FLOOR0_IV:.+]], %[[OMP_FLOOR0_TRIPCOUNT:.+]]
+// CHECK-NEXT: br i1 %[[OMP_FLOOR0_CMP:.+]], label %[[OMP_FLOOR0_BODY:.+]], label %[[OMP_FLOOR0_EXIT:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_FLOOR0_BODY]]:
+// CHECK-NEXT: br label %[[OMP_FLOOR1_PREHEADER:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_FLOOR1_PREHEADER]]:
+// CHECK-NEXT: br label %[[OMP_FLOOR1_HEADER:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_FLOOR1_HEADER]]:
+// CHECK-NEXT: %[[OMP_FLOOR1_IV:.+]] = phi i32 [ 0, %[[OMP_FLOOR1_PREHEADER:.+]] ], [ %[[OMP_FLOOR1_NEXT:.+]], %[[OMP_FLOOR1_INC:.+]] ]
+// CHECK-NEXT: br label %[[OMP_FLOOR1_COND:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_FLOOR1_COND]]:
+// CHECK-NEXT: %[[OMP_FLOOR1_CMP:.+]] = icmp ult i32 %[[OMP_FLOOR1_IV:.+]], %[[OMP_FLOOR1_TRIPCOUNT:.+]]
+// CHECK-NEXT: br i1 %[[OMP_FLOOR1_CMP:.+]], label %[[OMP_FLOOR1_BODY:.+]], label %[[OMP_FLOOR1_EXIT:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_FLOOR1_BODY]]:
+// CHECK-NEXT: %[[TMP14:.+]] = icmp eq i32 %[[OMP_FLOOR0_IV:.+]], %[[TMP6:.+]]
+// CHECK-NEXT: %[[TMP15:.+]] = select i1 %[[TMP14:.+]], i32 %[[TMP7:.+]], i32 %[[TMP3:.+]]
+// CHECK-NEXT: %[[TMP16:.+]] = icmp eq i32 %[[OMP_FLOOR1_IV:.+]], %[[TMP10:.+]]
+// CHECK-NEXT: %[[TMP17:.+]] = select i1 %[[TMP16:.+]], i32 %[[TMP11:.+]], i32 %[[TMP4:.+]]
+// CHECK-NEXT: br label %[[OMP_TILE0_PREHEADER:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_TILE0_PREHEADER]]:
+// CHECK-NEXT: br label %[[OMP_TILE0_HEADER:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_TILE0_HEADER]]:
+// CHECK-NEXT: %[[OMP_TILE0_IV:.+]] = phi i32 [ 0, %[[OMP_TILE0_PREHEADER:.+]] ], [ %[[OMP_TILE0_NEXT:.+]], %[[OMP_TILE0_INC:.+]] ]
+// CHECK-NEXT: br label %[[OMP_TILE0_COND:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_TILE0_COND]]:
+// CHECK-NEXT: %[[OMP_TILE0_CMP:.+]] = icmp ult i32 %[[OMP_TILE0_IV:.+]], %[[TMP15:.+]]
+// CHECK-NEXT: br i1 %[[OMP_TILE0_CMP:.+]], label %[[OMP_TILE0_BODY:.+]], label %[[OMP_TILE0_EXIT:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_TILE0_BODY]]:
+// CHECK-NEXT: br label %[[OMP_TILE1_PREHEADER:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_TILE1_PREHEADER]]:
+// CHECK-NEXT: br label %[[OMP_TILE1_HEADER:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_TILE1_HEADER]]:
+// CHECK-NEXT: %[[OMP_TILE1_IV:.+]] = phi i32 [ 0, %[[OMP_TILE1_PREHEADER:.+]] ], [ %[[OMP_TILE1_NEXT:.+]], %[[OMP_TILE1_INC:.+]] ]
+// CHECK-NEXT: br label %[[OMP_TILE1_COND:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_TILE1_COND]]:
+// CHECK-NEXT: %[[OMP_TILE1_CMP:.+]] = icmp ult i32 %[[OMP_TILE1_IV:.+]], %[[TMP17:.+]]
+// CHECK-NEXT: br i1 %[[OMP_TILE1_CMP:.+]], label %[[OMP_TILE1_BODY:.+]], label %[[OMP_TILE1_EXIT:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_TILE1_BODY]]:
+// CHECK-NEXT: %[[TMP18:.+]] = mul nuw i32 %[[TMP3:.+]], %[[OMP_FLOOR0_IV:.+]]
+// CHECK-NEXT: %[[TMP19:.+]] = add nuw i32 %[[TMP18:.+]], %[[OMP_TILE0_IV:.+]]
+// CHECK-NEXT: %[[TMP20:.+]] = mul nuw i32 %[[TMP4:.+]], %[[OMP_FLOOR1_IV:.+]]
+// CHECK-NEXT: %[[TMP21:.+]] = add nuw i32 %[[TMP20:.+]], %[[OMP_TILE1_IV:.+]]
+// CHECK-NEXT: br label %[[OMP_OMP_LOOP_BODY:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_OMP_LOOP_BODY4]]:
+// CHECK-NEXT: br label %[[OMP_LOOP_REGION12:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_LOOP_REGION12]]:
+// CHECK-NEXT: %[[TMP22:.+]] = add i32 %[[TMP19:.+]], %[[TMP21:.+]]
+// CHECK-NEXT: %[[TMP23:.+]] = getelementptr inbounds float, ptr %[[TMP0:.+]], i32 %[[TMP22:.+]]
+// CHECK-NEXT: store float 4.200000e+01, ptr %[[TMP23:.+]], align 4
+// CHECK-NEXT: br label %[[OMP_REGION_CONT11:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_REGION_CONT11]]:
+// CHECK-NEXT: br label %[[OMP_TILE1_INC:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_TILE1_INC]]:
+// CHECK-NEXT: %[[OMP_TILE1_NEXT:.+]] = add nuw i32 %[[OMP_TILE1_IV:.+]], 1
+// CHECK-NEXT: br label %[[OMP_TILE1_HEADER:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_TILE1_EXIT]]:
+// CHECK-NEXT: br label %[[OMP_TILE1_AFTER:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_TILE1_AFTER]]:
+// CHECK-NEXT: br label %[[OMP_TILE0_INC:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_TILE0_INC]]:
+// CHECK-NEXT: %[[OMP_TILE0_NEXT:.+]] = add nuw i32 %[[OMP_TILE0_IV:.+]], 1
+// CHECK-NEXT: br label %[[OMP_TILE0_HEADER:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_TILE0_EXIT]]:
+// CHECK-NEXT: br label %[[OMP_TILE0_AFTER:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_TILE0_AFTER]]:
+// CHECK-NEXT: br label %[[OMP_FLOOR1_INC:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_FLOOR1_INC]]:
+// CHECK-NEXT: %[[OMP_FLOOR1_NEXT:.+]] = add nuw i32 %[[OMP_FLOOR1_IV:.+]], 1
+// CHECK-NEXT: br label %[[OMP_FLOOR1_HEADER:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_FLOOR1_EXIT]]:
+// CHECK-NEXT: br label %[[OMP_FLOOR1_AFTER:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_FLOOR1_AFTER]]:
+// CHECK-NEXT: br label %[[OMP_FLOOR0_INC:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_FLOOR0_INC]]:
+// CHECK-NEXT: %[[OMP_FLOOR0_NEXT:.+]] = add nuw i32 %[[OMP_FLOOR0_IV:.+]], 1
+// CHECK-NEXT: br label %[[OMP_FLOOR0_HEADER:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_FLOOR0_EXIT]]:
+// CHECK-NEXT: br label %[[OMP_FLOOR0_AFTER:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_FLOOR0_AFTER]]:
+// CHECK-NEXT: br label %[[OMP_OMP_LOOP_AFTER:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_REGION_CONT:.+]]:
+// CHECK-NEXT: br label %[[OMP_OMP_LOOP_INC:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_OMP_LOOP_INC]]:
+// CHECK-NEXT: %[[OMP_OMP_LOOP_NEXT:.+]] = add nuw i32 %[[TMP19:.+]], 1
+// CHECK-NEXT: br label %[[OMP_OMP_LOOP_HEADER:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_OMP_LOOP_EXIT]]:
+// CHECK-NEXT: br label %[[OMP_OMP_LOOP_AFTER:.+]]
+// CHECK-EMPTY:
+// CHECK-NEXT: [[OMP_OMP_LOOP_AFTER]]:
+// CHECK-NEXT: ret void
+// CHECK-NEXT: }
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index e043a8c..00ee6b7 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -1340,6 +1340,34 @@ llvm.func @rocdl.cvt.scale.pk8(%i32: i32, %v2xi32: vector<2xi32>, %scale: i32) {
llvm.return
}
+// CHECK-LABEL: rocdl.cvt.scalef32.pk8
+// CHECK-SAME:(<8 x float> %[[V8F32:.+]], <8 x half> %[[V8F16:.+]], <8 x bfloat> %[[V8BF16:.+]], float %[[SCALE:.+]])
+llvm.func @rocdl.cvt.scalef32.pk8(%v8xf32: vector<8xf32>, %v8xf16: vector<8xf16>, %v8xbf16: vector<8xbf16>, %scale: f32) {
+
+ // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.pk8.fp8.f32(<8 x float> %[[V8F32]], float %[[SCALE]])
+ %0 = rocdl.cvt.scalef32.pk8.fp8.f32 %v8xf32, %scale : vector<2xi32>
+ // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.pk8.bf8.f32(<8 x float> %[[V8F32]], float %[[SCALE]])
+ %1 = rocdl.cvt.scalef32.pk8.bf8.f32 %v8xf32, %scale : vector<2xi32>
+ // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.pk8.fp4.f32(<8 x float> %[[V8F32]], float %[[SCALE]])
+ %2 = rocdl.cvt.scalef32.pk8.fp4.f32 %v8xf32, %scale : i32
+
+ // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.pk8.fp8.f16(<8 x half> %[[V8F16]], float %[[SCALE]])
+ %3 = rocdl.cvt.scalef32.pk8.fp8.f16 %v8xf16, %scale : vector<2xi32>
+ // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.pk8.bf8.f16(<8 x half> %[[V8F16]], float %[[SCALE]])
+ %4 = rocdl.cvt.scalef32.pk8.bf8.f16 %v8xf16, %scale : vector<2xi32>
+ // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.pk8.fp4.f16(<8 x half> %[[V8F16]], float %[[SCALE]])
+ %5 = rocdl.cvt.scalef32.pk8.fp4.f16 %v8xf16, %scale : i32
+
+ // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.pk8.fp8.bf16(<8 x bfloat> %[[V8BF16]], float %[[SCALE]])
+ %6 = rocdl.cvt.scalef32.pk8.fp8.bf16 %v8xbf16, %scale : vector<2xi32>
+ // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.pk8.bf8.bf16(<8 x bfloat> %[[V8BF16]], float %[[SCALE]])
+ %7 = rocdl.cvt.scalef32.pk8.bf8.bf16 %v8xbf16, %scale : vector<2xi32>
+ // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.pk8.fp4.bf16(<8 x bfloat> %[[V8BF16]], float %[[SCALE]])
+ %8 = rocdl.cvt.scalef32.pk8.fp4.bf16 %v8xbf16, %scale : i32
+
+ llvm.return
+}
+
// CHECK-LABEL: @rocdl.cvt.scale.pk16
// CHECK-SAME:(<3 x i32> %[[SRC0:.+]], i32 %[[SCALE:.+]])
llvm.func @rocdl.cvt.scale.pk16(%v3xi32: vector<3xi32>, %scale:i32) {
diff --git a/mlir/test/lib/Dialect/TestIRDLToCpp/CMakeLists.txt b/mlir/test/lib/Dialect/TestIRDLToCpp/CMakeLists.txt
index 103bc94..7d32577 100644
--- a/mlir/test/lib/Dialect/TestIRDLToCpp/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/TestIRDLToCpp/CMakeLists.txt
@@ -12,5 +12,7 @@ add_mlir_library(MLIRTestIRDLToCppDialect
mlir_target_link_libraries(MLIRTestIRDLToCppDialect PUBLIC
MLIRIR
MLIRPass
+ MLIRSCFDialect
MLIRTransforms
+ MLIRTestDialect
)
diff --git a/mlir/test/lib/Dialect/TestIRDLToCpp/TestIRDLToCppDialect.cpp b/mlir/test/lib/Dialect/TestIRDLToCpp/TestIRDLToCppDialect.cpp
index 9550e4c..421db7e 100644
--- a/mlir/test/lib/Dialect/TestIRDLToCpp/TestIRDLToCppDialect.cpp
+++ b/mlir/test/lib/Dialect/TestIRDLToCpp/TestIRDLToCppDialect.cpp
@@ -13,6 +13,7 @@
// #include "mlir/IR/Dialect.h"
#include "mlir/IR/Region.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
@@ -54,16 +55,34 @@ struct TestOpConversion : public OpConversionPattern<test_irdl_to_cpp::BeefOp> {
}
};
+struct TestRegionConversion
+ : public OpConversionPattern<test_irdl_to_cpp::ConditionalOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(mlir::test_irdl_to_cpp::ConditionalOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Just exercising the C++ API even though these are not enforced in the
+ // dialect definition
+ assert(op.getThen().getBlocks().size() == 1);
+ assert(adaptor.getElse().getBlocks().size() == 1);
+ auto ifOp = scf::IfOp::create(rewriter, op.getLoc(), op.getInput());
+ rewriter.replaceOp(op, ifOp);
+ return success();
+ }
+};
+
struct ConvertTestDialectToSomethingPass
: PassWrapper<ConvertTestDialectToSomethingPass, OperationPass<ModuleOp>> {
void runOnOperation() override {
MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx);
- patterns.add<TestOpConversion>(ctx);
+ patterns.add<TestOpConversion, TestRegionConversion>(ctx);
ConversionTarget target(getContext());
- target.addIllegalOp<test_irdl_to_cpp::BeefOp>();
- target.addLegalOp<test_irdl_to_cpp::BarOp>();
- target.addLegalOp<test_irdl_to_cpp::HashOp>();
+ target.addIllegalOp<test_irdl_to_cpp::BeefOp,
+ test_irdl_to_cpp::ConditionalOp>();
+ target.addLegalOp<test_irdl_to_cpp::BarOp, test_irdl_to_cpp::HashOp,
+ scf::IfOp, scf::YieldOp>();
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
@@ -73,6 +92,10 @@ struct ConvertTestDialectToSomethingPass
StringRef getDescription() const final {
return "Checks the convertability of an irdl dialect";
}
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<scf::SCFDialect>();
+ }
};
void registerIrdlTestDialect(mlir::DialectRegistry &registry) {
diff --git a/mlir/test/lib/Dialect/TestIRDLToCpp/test_conversion.testd.mlir b/mlir/test/lib/Dialect/TestIRDLToCpp/test_conversion.testd.mlir
index f6233ee..1915324 100644
--- a/mlir/test/lib/Dialect/TestIRDLToCpp/test_conversion.testd.mlir
+++ b/mlir/test/lib/Dialect/TestIRDLToCpp/test_conversion.testd.mlir
@@ -1,15 +1,29 @@
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-irdl-conversion-check)" | FileCheck %s
// CHECK-LABEL: module {
module {
- // CHECK: func.func @test() {
+ // CHECK: func.func @test(%[[test_arg:[^ ]*]]: i1) {
// CHECK: %[[v0:[^ ]*]] = "test_irdl_to_cpp.bar"() : () -> i32
// CHECK: %[[v1:[^ ]*]] = "test_irdl_to_cpp.bar"() : () -> i32
// CHECK: %[[v2:[^ ]*]] = "test_irdl_to_cpp.hash"(%[[v0]], %[[v0]]) : (i32, i32) -> i32
+ // CHECK: scf.if %[[test_arg]]
// CHECK: return
// CHECK: }
- func.func @test() {
+ func.func @test(%test_arg: i1) {
%0 = "test_irdl_to_cpp.bar"() : () -> i32
%1 = "test_irdl_to_cpp.beef"(%0, %0) : (i32, i32) -> i32
+ "test_irdl_to_cpp.conditional"(%test_arg) ({
+ ^cond(%test: i1):
+ %3 = "test_irdl_to_cpp.bar"() : () -> i32
+ "test.terminator"() : ()->()
+ }, {
+ ^then(%what: i1, %ever: i32):
+ %4 = "test_irdl_to_cpp.bar"() : () -> i32
+ "test.terminator"() : ()->()
+ }, {
+ ^else():
+ %5 = "test_irdl_to_cpp.bar"() : () -> i32
+ "test.terminator"() : ()->()
+ }) : (i1) -> ()
return
}
diff --git a/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp.irdl.mlir b/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp.irdl.mlir
index 42e713e..85fb8cb 100644
--- a/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp.irdl.mlir
+++ b/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp.irdl.mlir
@@ -2,7 +2,7 @@
// CHECK: class TestIrdlToCpp
irdl.dialect @test_irdl_to_cpp {
-
+
// CHECK: class FooType
irdl.type @foo
@@ -32,4 +32,53 @@ irdl.dialect @test_irdl_to_cpp {
irdl.operands(lhs: %0, rhs: %0)
irdl.results(res: %0)
}
+
+ // CHECK: ConditionalOp declarations
+ // CHECK: ConditionalOpGenericAdaptorBase
+ // CHECK: ::mlir::Region &getCond() { return *getRegions()[0]; }
+ // CHECK: ::mlir::Region &getThen() { return *getRegions()[1]; }
+ // CHECK: ::mlir::Region &getElse() { return *getRegions()[2]; }
+ //
+ // CHECK: class ConditionalOp : public ::mlir::Op<ConditionalOp, ::mlir::OpTrait::NRegions<3>::Impl, ::mlir::OpTrait::OpInvariants>
+ // CHECK: ::mlir::Region &getCond() { return (*this)->getRegion(0); }
+ // CHECK: ::mlir::Region &getThen() { return (*this)->getRegion(1); }
+ // CHECK: ::mlir::Region &getElse() { return (*this)->getRegion(2); }
+
+ // CHECK: ConditionalOp definitions
+ // CHECK: __mlir_irdl_local_region_constraint_ConditionalOp_cond
+ // CHECK: if (!(region.getNumArguments() == 1)) {
+ // CHECK: failed to verify constraint: region with 1 entry block argument(s)
+
+ // CHECK: __mlir_irdl_local_region_constraint_ConditionalOp_then
+ // CHECK: if (!(true)) {
+
+ // CHECK: __mlir_irdl_local_region_constraint_ConditionalOp_else
+ // CHECK: if (!(region.getNumArguments() == 0)) {
+ // CHECK: failed to verify constraint: region with 0 entry block argument(s)
+
+ // CHECK: ConditionalOp::build
+ // CHECK: for (unsigned i = 0; i != 3; ++i)
+ // CHECK-NEXT: (void)odsState.addRegion();
+
+ // CHECK: ConditionalOp::verifyInvariantsImpl
+ // CHECK: __mlir_irdl_local_region_constraint_ConditionalOp_cond
+ // CHECK: failure
+ // CHECK: __mlir_irdl_local_region_constraint_ConditionalOp_then
+ // CHECK: failure
+ // CHECK: __mlir_irdl_local_region_constraint_ConditionalOp_else
+ // CHECK: failure
+ // CHECK: success
+ irdl.operation @conditional {
+ %r0 = irdl.region // Unconstrained region
+ %r1 = irdl.region() // Region with no entry block arguments
+
+ // TODO(#161018): support irdl.is in irdl-to-cpp
+ // %v0 = irdl.is i1 // Type constraint: i1 (boolean)
+ %v0 = irdl.any
+ %r2 = irdl.region(%v0) // Region with one i1 entry block argument
+ irdl.regions(cond: %r2, then: %r0, else: %r1)
+
+ %0 = irdl.any
+ irdl.operands(input: %0)
+ }
}
diff --git a/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp_invalid_unsupported_types.irdl.mlir b/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp_invalid_unsupported_types.irdl.mlir
index 403b492..cc27456 100644
--- a/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp_invalid_unsupported_types.irdl.mlir
+++ b/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp_invalid_unsupported_types.irdl.mlir
@@ -7,7 +7,7 @@ irdl.dialect @test_irdl_to_cpp {
irdl.results(res: %1)
}
}
-// -----
+// -----
irdl.dialect @test_irdl_to_cpp {
irdl.operation @operands_no_any_of {
@@ -42,7 +42,7 @@ irdl.dialect @test_irdl_to_cpp {
irdl.dialect @test_irdl_to_cpp {
irdl.type @ty {
- %0 = irdl.any
+ %0 = irdl.any
// expected-error@+1 {{IRDL C++ translation does not yet support translation of irdl.parameters operation}}
irdl.parameters(ty: %0)
}
@@ -51,29 +51,8 @@ irdl.dialect @test_irdl_to_cpp {
// -----
irdl.dialect @test_irdl_to_cpp {
- irdl.operation @test_op {
- // expected-error@+1 {{IRDL C++ translation does not yet support translation of irdl.region operation}}
- %0 = irdl.region()
- irdl.regions(reg: %0)
- }
-
-}
-
-// -----
-
-irdl.dialect @test_irdl_to_cpp {
- irdl.operation @test_op {
- // expected-error@+1 {{IRDL C++ translation does not yet support translation of irdl.regions operation}}
- irdl.regions()
- }
-
-}
-
-// -----
-
-irdl.dialect @test_irdl_to_cpp {
irdl.type @test_derived {
// expected-error@+1 {{IRDL C++ translation does not yet support translation of irdl.base operation}}
%0 = irdl.base "!builtin.integer"
- }
+ }
}
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 094ef0a..e51cac4 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -173,8 +173,6 @@ struct TestXeGPUUnrollingPatterns
#undef DEBUG_TYPE
#define DEBUG_TYPE "test-xegpu-layout-interface"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
// Test pattern for distributing vector::StepOp from workgroup to subgroup.
// Validates DistributeLayoutAttr interfaces for offset computation
diff --git a/mlir/test/mlir-tblgen/op-format-invalid.td b/mlir/test/mlir-tblgen/op-format-invalid.td
index 2f29543..0a022ad 100644
--- a/mlir/test/mlir-tblgen/op-format-invalid.td
+++ b/mlir/test/mlir-tblgen/op-format-invalid.td
@@ -307,7 +307,7 @@ def DirectiveTypeZOperandInvalidI : TestFormat_Op<[{
def LiteralInvalidA : TestFormat_Op<[{
`a:`
}]>;
-// CHECK: error: expected valid literal but got '1': single character literal must be a letter or one of '_:,=<>()[]{}?+*'
+// CHECK: error: expected valid literal but got '1': single character literal must be a letter or one of '_:,=<>()[]{}?+-*'
def LiteralInvalidB : TestFormat_Op<[{
`1`
}]>;
diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td
index 1541cd0..1ac2311 100644
--- a/mlir/test/mlir-tblgen/op-format-spec.td
+++ b/mlir/test/mlir-tblgen/op-format-spec.td
@@ -123,7 +123,7 @@ def DirectiveTypeValid : TestFormat_Op<[{
// CHECK-NOT: error
def LiteralValid : TestFormat_Op<[{
- `_` `:` `,` `=` `<` `>` `(` `)` `[` `]` `?` `+` `*` ` ` `` `->` `\n` `abc$._`
+ `_` `:` `,` `=` `<` `>` `(` `)` `[` `]` `?` `+` `-` `*` ` ` `` `->` `\n` `abc$._`
attr-dict
}]>;
diff --git a/mlir/test/python/dialects/transform_tune_ext.py b/mlir/test/python/dialects/transform_tune_ext.py
index dfb9359..eb2a083 100644
--- a/mlir/test/python/dialects/transform_tune_ext.py
+++ b/mlir/test/python/dialects/transform_tune_ext.py
@@ -1,21 +1,21 @@
# RUN: %PYTHON %s | FileCheck %s
-from mlir.ir import *
+from mlir import ir
from mlir.dialects import transform
from mlir.dialects.transform import tune, debug
def run(f):
- print("\nTEST:", f.__name__)
- with Context(), Location.unknown():
- module = Module.create()
- with InsertionPoint(module.body):
+ print("\n// TEST:", f.__name__)
+ with ir.Context(), ir.Location.unknown():
+ module = ir.Module.create()
+ with ir.InsertionPoint(module.body):
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[],
transform.AnyOpType.get(),
)
- with InsertionPoint(sequence.body):
+ with ir.InsertionPoint(sequence.body):
f(sequence.bodyTarget)
transform.YieldOp()
print(module)
@@ -29,10 +29,10 @@ def testKnobOp(target):
# CHECK: %[[HEADS_OR_TAILS:.*]] = transform.tune.knob<"coin"> options = [true, false] -> !transform.any_param
heads_or_tails = tune.KnobOp(
- result=any_param, name=StringAttr.get("coin"), options=[True, False]
+ result=any_param, name=ir.StringAttr.get("coin"), options=[True, False]
)
# CHECK: transform.tune.knob<"animal"> options = ["cat", "dog", unit] -> !transform.any_param
- tune.KnobOp(any_param, name="animal", options=["cat", "dog", UnitAttr.get()])
+ tune.KnobOp(any_param, name="animal", options=["cat", "dog", ir.UnitAttr.get()])
# CHECK: transform.tune.knob<"tile_size"> options = [2, 4, 8, 16, 24, 32] -> !transform.any_param
tune.KnobOp(any_param, "tile_size", [2, 4, 8, 16, 24, 32])
# CHECK: transform.tune.knob<"magic_value"> options = [2.000000e+00, 2.250000e+00, 2.500000e+00, 2.750000e+00, 3.000000e+00] -> !transform.any_param
@@ -45,7 +45,10 @@ def testKnobOp(target):
heads = tune.KnobOp(any_param, "coin", options=[True, False], selected=True)
# CHECK: transform.tune.knob<"animal"> = "dog" from options = ["cat", "dog", unit] -> !transform.any_param
tune.KnobOp(
- any_param, name="animal", options=["cat", "dog", UnitAttr.get()], selected="dog"
+ any_param,
+ name="animal",
+ options=["cat", "dog", ir.UnitAttr.get()],
+ selected="dog",
)
# CHECK: transform.tune.knob<"tile_size"> = 8 : i64 from options = [2, 4, 8, 16, 24, 32] -> !transform.any_param
tune.KnobOp(any_param, "tile_size", [2, 4, 8, 16, 24, 32], selected=8)
@@ -57,16 +60,90 @@ def testKnobOp(target):
# CHECK: transform.tune.knob<"range_as_a_dict"> = 4 : i64 from options = {start = 2 : i64, step = 2 : i64, stop = 16 : i64} -> !transform.any_param
# NB: Membership of `selected` in non-ArrayAttr `options` is _not_ verified.
- i64 = IntegerType.get_signless(64)
+ i64 = ir.IntegerType.get_signless(64)
tune.knob(
any_param,
"range_as_a_dict",
- DictAttr.get(
+ ir.DictAttr.get(
{
- "start": IntegerAttr.get(i64, 2),
- "stop": IntegerAttr.get(i64, 16),
- "step": IntegerAttr.get(i64, 2),
+ "start": ir.IntegerAttr.get(i64, 2),
+ "stop": ir.IntegerAttr.get(i64, 16),
+ "step": ir.IntegerAttr.get(i64, 2),
}
),
selected=4,
)
+
+
+# CHECK-LABEL: TEST: testAlternativesOp
+@run
+def testAlternativesOp(target):
+ any_param = transform.AnyParamType.get()
+
+ # CHECK: %[[LEFT_OR_RIGHT_OUTCOME:.*]] = transform.tune.alternatives<"left_or_right"> -> !transform.any_param {
+ left_or_right = tune.AlternativesOp(
+ [transform.AnyParamType.get()], "left_or_right", 2
+ )
+ idx_for_left, idx_for_right = 0, 1
+ with ir.InsertionPoint(left_or_right.alternatives[idx_for_left].blocks[0]):
+ # CHECK: %[[C0:.*]] = transform.param.constant 0
+ i32_0 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0)
+ c0 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_0)
+ # CHECK: transform.yield %[[C0]]
+ transform.yield_(c0)
+ # CHECK-NEXT: }, {
+ with ir.InsertionPoint(left_or_right.alternatives[idx_for_right].blocks[0]):
+ # CHECK: %[[C1:.*]] = transform.param.constant 1
+ i32_1 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 1)
+ c1 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_1)
+ # CHECK: transform.yield %[[C1]]
+ transform.yield_(c1)
+ # CHECK-NEXT: }
+ outcome_of_left_or_right_decision = left_or_right.results[0]
+
+ # CHECK: transform.tune.alternatives<"fork_in_the_road"> selected_region = 0 -> !transform.any_param {
+ fork_in_the_road = tune.AlternativesOp(
+ [transform.AnyParamType.get()], "fork_in_the_road", 2, selected_region=0
+ )
+ with ir.InsertionPoint(fork_in_the_road.alternatives[idx_for_left].blocks[0]):
+ # CHECK: %[[C0:.*]] = transform.param.constant 0
+ i32_0 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0)
+ c0 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_0)
+ # CHECK: transform.yield %[[C0]]
+ transform.yield_(c0)
+ # CHECK-NEXT: }, {
+ with ir.InsertionPoint(fork_in_the_road.alternatives[idx_for_right].blocks[0]):
+ # CHECK: %[[C1:.*]] = transform.param.constant 1
+ i32_1 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 1)
+ c1 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_1)
+ # CHECK: transform.yield %[[C1]]
+ transform.yield_(c1)
+ # CHECK-NEXT: }
+
+ # CHECK: transform.tune.alternatives<"left_or_right_as_before"> selected_region = %[[LEFT_OR_RIGHT_OUTCOME]] : !transform.any_param {
+ left_or_right_as_before = tune.AlternativesOp(
+ [],
+ "left_or_right_as_before",
+ 2,
+ selected_region=outcome_of_left_or_right_decision,
+ )
+ with ir.InsertionPoint(
+ left_or_right_as_before.alternatives[idx_for_left].blocks[0]
+ ):
+ # CHECK: transform.param.constant 1337
+ i32_1337 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 1337)
+ c1337 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_1337)
+ # CHECK: transform.debug.emit_param_as_remark
+ debug.emit_param_as_remark(c1337)
+ transform.yield_([])
+ # CHECK-NEXT: }, {
+ with ir.InsertionPoint(
+ left_or_right_as_before.alternatives[idx_for_right].blocks[0]
+ ):
+ # CHECK: transform.param.constant 42
+ i32_42 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 42)
+ c42 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_42)
+ # CHECK: transform.debug.emit_param_as_remark
+ debug.emit_param_as_remark(c42)
+ transform.yield_([])
+ # CHECK-NEXT: }
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 4a3625c..cb4cfc8c 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -696,6 +696,7 @@ def testOperationPrint():
# CHECK: resource1: "0x08
module.operation.print(large_elements_limit=2)
+
# CHECK-LABEL: TEST: testKnownOpView
@run
def testKnownOpView():
@@ -969,6 +970,13 @@ def testOperationLoc():
assert op.location == loc
assert op.operation.location == loc
+ another_loc = Location.name("another_loc")
+ op.location = another_loc
+ assert op.location == another_loc
+ assert op.operation.location == another_loc
+ # CHECK: loc("another_loc")
+ print(op.location)
+
# CHECK-LABEL: TEST: testModuleMerge
@run
diff --git a/mlir/tools/mlir-rewrite/mlir-rewrite.cpp b/mlir/tools/mlir-rewrite/mlir-rewrite.cpp
index fd8ae7e..795766f 100644
--- a/mlir/tools/mlir-rewrite/mlir-rewrite.cpp
+++ b/mlir/tools/mlir-rewrite/mlir-rewrite.cpp
@@ -35,7 +35,7 @@ namespace mlir {
using OperationDefinition = AsmParserState::OperationDefinition;
/// Return the source code associated with the OperationDefinition.
-SMRange getOpRange(const OperationDefinition &op) {
+static SMRange getOpRange(const OperationDefinition &op) {
const char *startOp = op.scopeLoc.Start.getPointer();
const char *endOp = op.scopeLoc.End.getPointer();
@@ -187,15 +187,15 @@ std::unique_ptr<RewritePad> RewritePad::init(StringRef inputFilename,
}
/// Return the source code associated with the operation name.
-SMRange getOpNameRange(const OperationDefinition &op) { return op.loc; }
+static SMRange getOpNameRange(const OperationDefinition &op) { return op.loc; }
/// Return whether the operation was printed using generic syntax in original
/// buffer.
-bool isGeneric(const OperationDefinition &op) {
+static bool isGeneric(const OperationDefinition &op) {
return op.loc.Start.getPointer()[0] == '"';
}
-inline int asMainReturnCode(LogicalResult r) {
+static inline int asMainReturnCode(LogicalResult r) {
return r.succeeded() ? EXIT_SUCCESS : EXIT_FAILURE;
}
@@ -293,7 +293,7 @@ static llvm::cl::opt<std::string> simpleRenameReplace{
llvm::cl::cat(clSimpleRenameCategory)};
// Rewriter that does simple renames.
-LogicalResult simpleRename(RewritePad &rewriteState, raw_ostream &os) {
+static LogicalResult simpleRename(RewritePad &rewriteState, raw_ostream &os) {
StringRef opName = simpleRenameOpName;
StringRef match = simpleRenameMatch;
StringRef replace = simpleRenameReplace;
@@ -317,7 +317,7 @@ static mlir::RewriterRegistration rewriteSimpleRename("simple-rename",
simpleRename);
// Rewriter that insert range markers.
-LogicalResult markRanges(RewritePad &rewriteState, raw_ostream &os) {
+static LogicalResult markRanges(RewritePad &rewriteState, raw_ostream &os) {
for (const auto &it : rewriteState.getOpDefs()) {
auto [startOp, endOp] = getOpRange(it);
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
index a1899a8..8dd9713 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -403,6 +403,7 @@ void DefFormat::genLiteralParser(StringRef value, FmtContext &ctx,
.Case("]", "RSquare")
.Case("?", "Question")
.Case("+", "Plus")
+ .Case("-", "Minus")
.Case("*", "Star")
.Case("...", "Ellipsis")
<< "()";
diff --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp
index 4dfdde2..04d3ed1 100644
--- a/mlir/tools/mlir-tblgen/FormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/FormatGen.cpp
@@ -518,7 +518,7 @@ bool mlir::tblgen::isValidLiteral(StringRef value,
// If there is only one character, this must either be punctuation or a
// single character bare identifier.
if (value.size() == 1) {
- StringRef bare = "_:,=<>()[]{}?+*";
+ StringRef bare = "_:,=<>()[]{}?+-*";
if (isalpha(front) || bare.contains(front))
return true;
if (emitError)
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 0d113b3..ccf21d1 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -852,6 +852,7 @@ static void genLiteralParser(StringRef value, MethodBody &body) {
.Case("]", "RSquare()")
.Case("?", "Question()")
.Case("+", "Plus()")
+ .Case("-", "Minus()")
.Case("*", "Star()")
.Case("...", "Ellipsis()");
}
diff --git a/mlir/unittests/TableGen/PassGenTest.cpp b/mlir/unittests/TableGen/PassGenTest.cpp
index 27f2fa0..ac01d49 100644
--- a/mlir/unittests/TableGen/PassGenTest.cpp
+++ b/mlir/unittests/TableGen/PassGenTest.cpp
@@ -11,7 +11,8 @@
#include "gmock/gmock.h"
-std::unique_ptr<mlir::Pass> createTestPassWithCustomConstructor(int v = 0);
+static std::unique_ptr<mlir::Pass>
+createTestPassWithCustomConstructor(int v = 0);
#define GEN_PASS_DECL
#define GEN_PASS_REGISTRATION