diff options
Diffstat (limited to 'mlir')
96 files changed, 3161 insertions, 516 deletions
diff --git a/mlir/.clang-format b/mlir/.clang-format index a74fda4..76cc928 100644 --- a/mlir/.clang-format +++ b/mlir/.clang-format @@ -1,2 +1,3 @@ BasedOnStyle: LLVM AlwaysBreakTemplateDeclarations: Yes +LineEnding: LF diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 061d762..c464e4d 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -634,6 +634,10 @@ MLIR_CAPI_EXPORTED MlirContext mlirOperationGetContext(MlirOperation op); /// Gets the location of the operation. MLIR_CAPI_EXPORTED MlirLocation mlirOperationGetLocation(MlirOperation op); +/// Sets the location of the operation. +MLIR_CAPI_EXPORTED void mlirOperationSetLocation(MlirOperation op, + MlirLocation loc); + /// Gets the type id of the operation. /// Returns null if the operation does not have a registered operation /// description. diff --git a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h index b5f985f..847951a 100644 --- a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h @@ -116,7 +116,8 @@ mlirApiObjectToCapsule(nanobind::handle apiObject) { /// Casts object <-> MlirAffineMap. template <> struct type_caster<MlirAffineMap> { - NB_TYPE_CASTER(MlirAffineMap, const_name("MlirAffineMap")) + NB_TYPE_CASTER(MlirAffineMap, + const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.AffineMap"))) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { if (auto capsule = mlirApiObjectToCapsule(src)) { value = mlirPythonCapsuleToAffineMap(capsule->ptr()); @@ -138,7 +139,8 @@ struct type_caster<MlirAffineMap> { /// Casts object <-> MlirAttribute. 
template <> struct type_caster<MlirAttribute> { - NB_TYPE_CASTER(MlirAttribute, const_name("MlirAttribute")) + NB_TYPE_CASTER(MlirAttribute, + const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.Attribute"))) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { if (auto capsule = mlirApiObjectToCapsule(src)) { value = mlirPythonCapsuleToAttribute(capsule->ptr()); @@ -161,7 +163,7 @@ struct type_caster<MlirAttribute> { /// Casts object -> MlirBlock. template <> struct type_caster<MlirBlock> { - NB_TYPE_CASTER(MlirBlock, const_name("MlirBlock")) + NB_TYPE_CASTER(MlirBlock, const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.Block"))) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { if (auto capsule = mlirApiObjectToCapsule(src)) { value = mlirPythonCapsuleToBlock(capsule->ptr()); @@ -174,7 +176,8 @@ struct type_caster<MlirBlock> { /// Casts object -> MlirContext. template <> struct type_caster<MlirContext> { - NB_TYPE_CASTER(MlirContext, const_name("MlirContext")) + NB_TYPE_CASTER(MlirContext, + const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.Context"))) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { if (src.is_none()) { // Gets the current thread-bound context. @@ -192,7 +195,8 @@ struct type_caster<MlirContext> { /// Casts object <-> MlirDialectRegistry. template <> struct type_caster<MlirDialectRegistry> { - NB_TYPE_CASTER(MlirDialectRegistry, const_name("MlirDialectRegistry")) + NB_TYPE_CASTER(MlirDialectRegistry, + const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.DialectRegistry"))) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { if (auto capsule = mlirApiObjectToCapsule(src)) { value = mlirPythonCapsuleToDialectRegistry(capsule->ptr()); @@ -214,7 +218,8 @@ struct type_caster<MlirDialectRegistry> { /// Casts object <-> MlirLocation. 
template <> struct type_caster<MlirLocation> { - NB_TYPE_CASTER(MlirLocation, const_name("MlirLocation")) + NB_TYPE_CASTER(MlirLocation, + const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.Location"))) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { if (src.is_none()) { // Gets the current thread-bound context. @@ -240,7 +245,7 @@ struct type_caster<MlirLocation> { /// Casts object <-> MlirModule. template <> struct type_caster<MlirModule> { - NB_TYPE_CASTER(MlirModule, const_name("MlirModule")) + NB_TYPE_CASTER(MlirModule, const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.Module"))) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { if (auto capsule = mlirApiObjectToCapsule(src)) { value = mlirPythonCapsuleToModule(capsule->ptr()); @@ -262,8 +267,9 @@ struct type_caster<MlirModule> { /// Casts object <-> MlirFrozenRewritePatternSet. template <> struct type_caster<MlirFrozenRewritePatternSet> { - NB_TYPE_CASTER(MlirFrozenRewritePatternSet, - const_name("MlirFrozenRewritePatternSet")) + NB_TYPE_CASTER( + MlirFrozenRewritePatternSet, + const_name(MAKE_MLIR_PYTHON_QUALNAME("rewrite.FrozenRewritePatternSet"))) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { if (auto capsule = mlirApiObjectToCapsule(src)) { value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule->ptr()); @@ -285,7 +291,8 @@ struct type_caster<MlirFrozenRewritePatternSet> { /// Casts object <-> MlirOperation. template <> struct type_caster<MlirOperation> { - NB_TYPE_CASTER(MlirOperation, const_name("MlirOperation")) + NB_TYPE_CASTER(MlirOperation, + const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.Operation"))) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { if (auto capsule = mlirApiObjectToCapsule(src)) { value = mlirPythonCapsuleToOperation(capsule->ptr()); @@ -309,7 +316,7 @@ struct type_caster<MlirOperation> { /// Casts object <-> MlirValue. 
template <> struct type_caster<MlirValue> { - NB_TYPE_CASTER(MlirValue, const_name("MlirValue")) + NB_TYPE_CASTER(MlirValue, const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.Value"))) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { if (auto capsule = mlirApiObjectToCapsule(src)) { value = mlirPythonCapsuleToValue(capsule->ptr()); @@ -334,7 +341,8 @@ struct type_caster<MlirValue> { /// Casts object -> MlirPassManager. template <> struct type_caster<MlirPassManager> { - NB_TYPE_CASTER(MlirPassManager, const_name("MlirPassManager")) + NB_TYPE_CASTER(MlirPassManager, const_name(MAKE_MLIR_PYTHON_QUALNAME( + "passmanager.PassManager"))) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { if (auto capsule = mlirApiObjectToCapsule(src)) { value = mlirPythonCapsuleToPassManager(capsule->ptr()); @@ -347,7 +355,7 @@ struct type_caster<MlirPassManager> { /// Casts object <-> MlirTypeID. template <> struct type_caster<MlirTypeID> { - NB_TYPE_CASTER(MlirTypeID, const_name("MlirTypeID")) + NB_TYPE_CASTER(MlirTypeID, const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.TypeID"))) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { if (auto capsule = mlirApiObjectToCapsule(src)) { value = mlirPythonCapsuleToTypeID(capsule->ptr()); @@ -371,7 +379,7 @@ struct type_caster<MlirTypeID> { /// Casts object <-> MlirType. template <> struct type_caster<MlirType> { - NB_TYPE_CASTER(MlirType, const_name("MlirType")) + NB_TYPE_CASTER(MlirType, const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.Type"))) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { if (auto capsule = mlirApiObjectToCapsule(src)) { value = mlirPythonCapsuleToType(capsule->ptr()); @@ -394,7 +402,7 @@ struct type_caster<MlirType> { /// Casts MlirStringRef -> object. 
template <> struct type_caster<MlirStringRef> { - NB_TYPE_CASTER(MlirStringRef, const_name("MlirStringRef")) + NB_TYPE_CASTER(MlirStringRef, const_name("str")) static handle from_cpp(MlirStringRef s, rv_policy, cleanup_list *cleanup) noexcept { return nanobind::str(s.data, s.length).release(); diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index 8b687a7..29001e2 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -985,7 +985,6 @@ class ScaleArgInfo<TypeConstraint argTyVal, string typeName> { //===---------------------------------------------------------------------===// // Scaled {fp4,bf8,fp8} to {bf16,f16,f32} conversion intrinsics //===---------------------------------------------------------------------===// - foreach smallT = [ ScaleArgInfo<I32, "Fp4">, ScaleArgInfo<ROCDL_V2I32Type, "Fp8">, @@ -996,6 +995,8 @@ foreach smallT = [ ScaleArgInfo<ROCDL_V8BF16Type, "Bf16">, ScaleArgInfo<ROCDL_V8F32Type, "F32">, ] in { + + // Up-scaling def ROCDL_CvtPkScalePk8 # largeT.nameForOp # smallT.nameForOp # Op : ROCDL_ConcreteNonMemIntrOp<"cvt.scale.pk8." # largeT.name # "." # smallT.name, [Pure], 1, [2], ["scaleSel"]>, @@ -1010,13 +1011,30 @@ foreach smallT = [ attr-dict $src `,` $scale `[` $scaleSel `]` `:` type($res) }]; } + + // Down-scaling + def ROCDL_CvtScaleF32Pk8 # smallT.nameForOp # largeT.nameForOp # Op : + ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk8." # smallT.name # "." # largeT.name, + [Pure], 1>, + Arguments<(ins largeT.type:$src, F32:$scale)> { + let results = (outs smallT.type:$res); + let summary = "Scale and convert packed " + # largeT.name # " to packed " # smallT.name ; + let description = [{ + Convert 8 packed }] # largeT.name # [{ values to packed }] + # smallT.name # [{, multiplying by the exponent part of `scale` + before doing so. This op is for gfx1250+ arch. 
+ }]; + let assemblyFormat = [{ + attr-dict $src `,` $scale `:` type($res) + }]; + } } // foreach largeT } // foreach smallTOp //===---------------------------------------------------------------------===// // Scaled {bf6,fp6} to {bf16,f16,f32} conversion intrinsics //===---------------------------------------------------------------------===// - foreach smallT = [ ScaleArgInfo<ROCDL_V3I32Type, "Fp6">, ScaleArgInfo<ROCDL_V3I32Type, "Bf6"> diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 8f3232f..0d6ebc0 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -17,6 +17,7 @@ include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" include "mlir/Dialect/Transform/IR/TransformTypes.td" include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/IR/OpBase.td" include "mlir/IR/RegionKindInterface.td" @@ -236,11 +237,51 @@ def BufferizeToAllocationOp : Op<Transform_Dialect, Transform_AnyOpType:$new_ops); let assemblyFormat = "$target attr-dict `:` type($target)"; let hasVerifier = 1; +} - let builders = [ - OpBuilder<(ins "Value":$target, "Attribute":$memorySpace)>, - OpBuilder<(ins "Value":$target, "int64_t":$memorySpace)> - ]; +//===----------------------------------------------------------------------===// +// PromoteTensorOp +//===----------------------------------------------------------------------===// + +def PromoteTensorOp : Op<Transform_Dialect, "structured.promote_tensor", + [DeclareOpInterfaceMethods<TransformOpInterface>, + DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, + SameOperandsAndResultType]> { + let summary = "Request a tensor value to live in a specific memory space " + "after bufferization"; + let description = [{ + 
Requests that a tensor value lives in a specific memory space for its + lifetime. This is achieved by allocating a new tensor in the desired + memory space with `bufferization.alloc_tensor` and optionally materializing + the source value into that allocation with + `bufferization.materialize_in_destination`. All uses of the original value + are then redirected to the promoted value. + + The generated code for promoting tensor value %0 resembles the following: + + %1 = bufferization.alloc_tensor(<dynamic dims of %0>) + { memory_space = memory_space } + // Note: the materialization is omitted if %0 is never read and is only + // written into (i.e., it behaves as a result tensor). + %2 = bufferization.materialize_in_destination %0 in %1 + // ... + <all users of %0 now use %2 instead> + + Deallocation is not handled by this transform. + + Return modes: + - Produces a silenceable failure if the given handle does not point to + tensor-typed values. + - Succeeds otherwise and returns a handle to the promoted value(s), i.e., + the result of materialization if present and the allocation otherwise. + }]; + + let arguments = (ins TransformValueHandleTypeInterface:$tensor, + OptionalAttr<AnyAttr>:$memory_space); + let results = (outs TransformValueHandleTypeInterface:$promoted); + + let assemblyFormat = + "(`to` $memory_space^)? $tensor attr-dict `:` type($tensor)"; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td index 4d415ae..48346abd 100644 --- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td @@ -64,4 +64,12 @@ def MathExpandOpsPass : Pass<"math-expand-ops"> { ]; } +def MathSincosFusionPass : Pass<"math-sincos-fusion"> { + let summary = "Fuse sin and cos operations."; + let description = [{ + Fuse sin and cos operations into a sincos operation. 
+ }]; + let dependentDialects = ["math::MathDialect"]; +} + #endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index 2bf953e..d4d67bf 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -155,7 +155,7 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [ The `assume_alignment` operation takes a memref and an integer alignment value. It returns a new SSA value of the same memref type, but associated with the assumption that the underlying buffer is aligned to the given - alignment. + alignment. If the buffer isn't aligned to the given alignment, its result is poison. This operation doesn't affect the semantics of a program where the @@ -170,7 +170,7 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [ let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)"; let extraClassDeclaration = [{ MemRefType getType() { return ::llvm::cast<MemRefType>(getResult().getType()); } - + Value getViewSource() { return getMemref(); } }]; @@ -179,6 +179,41 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [ } //===----------------------------------------------------------------------===// +// DistinctObjectsOp +//===----------------------------------------------------------------------===// + +def DistinctObjectsOp : MemRef_Op<"distinct_objects", [ + Pure, + DeclareOpInterfaceMethods<InferTypeOpInterface> + // ViewLikeOpInterface TODO: ViewLikeOpInterface only supports a single argument + ]> { + let summary = "assumption that acesses to specific memrefs will never alias"; + let description = [{ + The `distinct_objects` operation takes a list of memrefs and returns the same + memrefs, with the additional assumption that accesses to them will never + alias with each other. This means that loads and stores to different + memrefs in the list can be safely reordered. 
+ + If the memrefs do alias, the load/store behavior is undefined. This + operation doesn't affect the semantics of a valid program. It is + intended for optimization purposes, allowing the compiler to generate more + efficient code based on the non-aliasing assumption. The optimization is + best-effort. + + Example: + + ```mlir + %1, %2 = memref.distinct_objects %a, %b : memref<?xf32>, memref<?xf32> + ``` + }]; + let arguments = (ins Variadic<AnyMemRef>:$operands); + let results = (outs Variadic<AnyMemRef>:$results); + + let assemblyFormat = "$operands attr-dict `:` type($operands)"; + let hasVerifier = 1; +} + +//===----------------------------------------------------------------------===// // AllocOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index 1eda5e4..8e43c42 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -996,6 +996,35 @@ class OpenMP_NumTeamsClauseSkip< def OpenMP_NumTeamsClause : OpenMP_NumTeamsClauseSkip<>; //===----------------------------------------------------------------------===// +// V5.1: [10.1.2] `sizes` clause +//===----------------------------------------------------------------------===// + +class OpenMP_SizesClauseSkip< + bit traits = false, bit arguments = false, bit assemblyFormat = false, + bit description = false, bit extraClassDeclaration = false + > : OpenMP_Clause<traits, arguments, assemblyFormat, description, + extraClassDeclaration> { + let arguments = (ins + Variadic<IntLikeType>:$sizes + ); + + let optAssemblyFormat = [{ + `sizes` `(` $sizes `:` type($sizes) `)` + }]; + + let description = [{ + The `sizes` clauses defines the size of a grid over a multi-dimensional + logical iteration space. This grid is used for loop transformations such as + `tile` and `strip`. 
The size per dimension can be a variable, but only + values that are not at least 2 make sense. It is not specified what happens + when smaller values are used, but should still result in a loop nest that + executes each logical iteration once. + }]; +} + +def OpenMP_SizesClause : OpenMP_SizesClauseSkip<>; + +//===----------------------------------------------------------------------===// // V5.2: [10.1.2] `num_threads` clause //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpBase.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpBase.td index bbcfb87f..5ad4e4b 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpBase.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpBase.td @@ -38,6 +38,44 @@ def OpenMP_MapBoundsType : OpenMP_Type<"MapBounds", "map_bounds_ty"> { let summary = "Type for representing omp map clause bounds information"; } +//===---------------------------------------------------------------------===// +// OpenMP Canonical Loop Info Type +//===---------------------------------------------------------------------===// + +def CanonicalLoopInfoType : OpenMP_Type<"CanonicalLoopInfo", "cli"> { + let summary = "Type for representing a reference to a canonical loop"; + let description = [{ + A variable of type CanonicalLoopInfo refers to an OpenMP-compatible + canonical loop in the same function. Values of this type are not + available at runtime and therefore cannot be used by the program itself, + i.e. an opaque type. It is similar to the transform dialect's + `!transform.interface` type, but instead of implementing an interface + for each transformation, the OpenMP dialect itself defines possible + operations on this type. + + A value of type CanonicalLoopInfoType (in the following: CLI) value can be + + 1. created by omp.new_cli. + 2. passed to omp.canonical_loop to associate the loop to that CLI. A CLI + can only be associated once. + 3. 
passed to an omp loop transformation operation that modifies the loop + associated with the CLI. The CLI is the "applyee" and the operation is + the consumer. A CLI can only be consumed once. + 4. passed to an omp loop transformation operation to associate the cli with + a result of that transformation. The CLI is the "generatee" and the + operation is the generator. + + A CLI cannot + + 1. be returned from a function. + 2. be passed to operations that are not specifically designed to take a + CanonicalLoopInfoType, including AnyType. + + A CLI directly corresponds to an object of + OpenMPIRBuilder's CanonicalLoopInfo struct when lowering to LLVM-IR. + }]; +} + //===----------------------------------------------------------------------===// // Base classes for OpenMP dialect operations. //===----------------------------------------------------------------------===// @@ -211,8 +249,35 @@ class OpenMP_Op<string mnemonic, list<Trait> traits = [], // Doesn't actually create a C++ base class (only defines default values for // tablegen classes that derive from this). Use LoopTransformationInterface // instead for common operations. -class OpenMPTransform_Op<string mnemonic, list<Trait> traits = []> : - OpenMP_Op<mnemonic, !listconcat([DeclareOpInterfaceMethods<LoopTransformationInterface>], traits) > { +class OpenMPTransform_Op<string mnemonic, + list<Trait> traits = [], + list<OpenMP_Clause> clauses = []> : + OpenMP_Op<mnemonic, + traits = !listconcat([DeclareOpInterfaceMethods<LoopTransformationInterface>], traits), + clauses = clauses> { +} + +// Base clause for loop transformations using the standard syntax. +// +// omp.opname ($generatees) <- ($applyees) clause(...) clause(...) ... <attr-dicr> +// omp.opname ($applyees) clause(...) clause(...) ... 
<attr-dict> +// +// $generatees is optional and is assumed to be empty if omitted +class OpenMPTransformBase_Op<string mnemonic, + list<Trait> traits = [], + list<OpenMP_Clause> clauses = []> : + OpenMPTransform_Op<mnemonic, + traits = !listconcat(traits, [AttrSizedOperandSegments]), + clauses = clauses> { + + let arguments = !con( + (ins Variadic<CanonicalLoopInfoType>:$generatees, + Variadic<CanonicalLoopInfoType>:$applyees + ), clausesArgs); + + let assemblyFormat = [{ custom<LoopTransformClis>($generatees, $applyees) }] + # clausesAssemblyFormat + # [{ attr-dict }]; } #endif // OPENMP_OP_BASE diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 5c77e21..b73091e 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -358,44 +358,6 @@ def SingleOp : OpenMP_Op<"single", traits = [ } //===---------------------------------------------------------------------===// -// OpenMP Canonical Loop Info Type -//===---------------------------------------------------------------------===// - -def CanonicalLoopInfoType : OpenMP_Type<"CanonicalLoopInfo", "cli"> { - let summary = "Type for representing a reference to a canonical loop"; - let description = [{ - A variable of type CanonicalLoopInfo refers to an OpenMP-compatible - canonical loop in the same function. Values of this type are not - available at runtime and therefore cannot be used by the program itself, - i.e. an opaque type. It is similar to the transform dialect's - `!transform.interface` type, but instead of implementing an interface - for each transformation, the OpenMP dialect itself defines possible - operations on this type. - - A value of type CanonicalLoopInfoType (in the following: CLI) value can be - - 1. created by omp.new_cli. - 2. passed to omp.canonical_loop to associate the loop to that CLI. A CLI - can only be associated once. - 3. 
passed to an omp loop transformation operation that modifies the loop - associated with the CLI. The CLI is the "applyee" and the operation is - the consumer. A CLI can only be consumed once. - 4. passed to an omp loop transformation operation to associate the cli with - a result of that transformation. The CLI is the "generatee" and the - operation is the generator. - - A CLI cannot - - 1. be returned from a function. - 2. be passed to operations that are not specifically designed to take a - CanonicalLoopInfoType, including AnyType. - - A CLI directly corresponds to an object of - OpenMPIRBuilder's CanonicalLoopInfo struct when lowering to LLVM-IR. - }]; -} - -//===---------------------------------------------------------------------===// // OpenMP Canonical Loop Info Creation //===---------------------------------------------------------------------===// @@ -564,6 +526,31 @@ def UnrollHeuristicOp : OpenMPTransform_Op<"unroll_heuristic", []> { } //===----------------------------------------------------------------------===// +// OpenMP tile operation +//===----------------------------------------------------------------------===// + +def TileOp : OpenMPTransformBase_Op<"tile", + clauses = [OpenMP_SizesClause]> { + let summary = "OpenMP tile operation"; + let description = [{ + Represents the OpenMP tile directive introduced in OpenMP 5.1. + + The construct partitions the logical iteration space of the affected loops + into equally-sized tiles, then creates two sets of nested loops. The outer + loops, called the grid loops, iterate over all tiles. The inner loops, + called the intratile loops, iterate over the logical iterations of a tile. + The sizes clause determines the size of a tile. 
+ + Currently, the affected loops must be rectangular (the tripcount of the + inner loop must not depend on any iv of an surrounding affected loop) and + perfectly nested (except for the innermost affected loop, no operations + other than the nested loop and the terminator in the loop body). + }] # clausesDescription; + + let hasVerifier = 1; +} + +//===----------------------------------------------------------------------===// // 2.8.3 Workshare Construct //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h index 74e1d28..ba11259 100644 --- a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h +++ b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H #define MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H +#include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/OpDefinition.h" diff --git a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td index d68d451..d095659 100644 --- a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td +++ b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td @@ -11,10 +11,15 @@ include "mlir/Dialect/Transform/IR/TransformDialect.td" include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/BuiltinAttributes.td" include "mlir/IR/CommonAttrConstraints.td" +//===----------------------------------------------------------------------===// +// KnobOp 
+//===----------------------------------------------------------------------===// + def KnobOp : Op<Transform_Dialect, "tune.knob", [ DeclareOpInterfaceMethods<TransformOpInterface>, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, @@ -52,4 +57,53 @@ def KnobOp : Op<Transform_Dialect, "tune.knob", [ "`<` $name `>` (`=` $selected^ `from`)? `options` `=` $options attr-dict `->` type(results)"; } +//===----------------------------------------------------------------------===// +// AlternativesOp +//===----------------------------------------------------------------------===// + +def AlternativesOp : Op<Transform_Dialect, "tune.alternatives", [ + DeclareOpInterfaceMethods<RegionBranchOpInterface, + ["getEntrySuccessorOperands", "getSuccessorRegions", + "getRegionInvocationBounds"]>, + DeclareOpInterfaceMethods<TransformOpInterface>, + DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, + SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">, + NoRegionArguments +]> { + let summary = "Represents a choice among its regions, i.e. sub-schedules"; + + let description = [{ + This op represents a choice over which of its regions is to be used. + + When `selected_region` is provided, the semantics are that this op is to be + substituted for by the selected region, meaning the region's results become + the results of this op. Without a provided `selected_region`, the semantics + are that this non-deterministic choice is yet to be resolved -- which in + terms of the op's interpreted semantics is a failure. + + The `selected_region` argument is either an `IntegerAttr` or a param holding + an `IntegerAttr`, which should provide a valid zero-based index with respect + to the number of alternatives, i.e. regions. 
+ }]; + let cppNamespace = [{ mlir::transform::tune }]; + + let arguments = (ins Builtin_StringAttr:$name, + OptionalAttr<APIntAttr>:$selected_region_attr, + Optional<TransformParamTypeInterface>:$selected_region_param); + let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results); + let regions = (region VariadicRegion<SizedRegion<1>>:$alternatives); + + let assemblyFormat = [{ + `<` $name `>` + (`selected_region` `=` custom<AlternativesOpSelectedRegion>( + $selected_region_attr, $selected_region_param)^)? + attr-dict-with-keyword + (`:` type($selected_region_param)^)? + (`->` type($results)^)? + regions + }]; + + let hasVerifier = 1; +} + #endif // MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 83a8757..32b2b0c 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -3219,13 +3219,11 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::arg("end_line"), nb::arg("end_col"), nb::arg("context") = nb::none(), kContextGetFileRangeDocstring) .def("is_a_file", mlirLocationIsAFileLineColRange) - .def_prop_ro( - "filename", - [](MlirLocation loc) { - return mlirIdentifierStr( - mlirLocationFileLineColRangeGetFilename(loc)); - }, - nb::sig("def filename(self) -> str")) + .def_prop_ro("filename", + [](MlirLocation loc) { + return mlirIdentifierStr( + mlirLocationFileLineColRangeGetFilename(loc)); + }) .def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine) .def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn) .def_prop_ro("end_line", mlirLocationFileLineColRangeGetEndLine) @@ -3274,12 +3272,10 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::arg("name"), nb::arg("childLoc") = nb::none(), nb::arg("context") = nb::none(), kContextGetNameLocationDocString) .def("is_a_name", mlirLocationIsAName) - .def_prop_ro( - "name_str", - [](MlirLocation loc) { - return 
mlirIdentifierStr(mlirLocationNameGetName(loc)); - }, - nb::sig("def name_str(self) -> str")) + .def_prop_ro("name_str", + [](MlirLocation loc) { + return mlirIdentifierStr(mlirLocationNameGetName(loc)); + }) .def_prop_ro("child_loc", [](PyLocation &self) { return PyLocation(self.getContext(), @@ -3453,15 +3449,13 @@ void mlir::python::populateIRCore(nb::module_ &m) { return concreteOperation.getContext().getObject(); }, "Context that owns the Operation") - .def_prop_ro( - "name", - [](PyOperationBase &self) { - auto &concreteOperation = self.getOperation(); - concreteOperation.checkValid(); - MlirOperation operation = concreteOperation.get(); - return mlirIdentifierStr(mlirOperationGetName(operation)); - }, - nb::sig("def name(self) -> str")) + .def_prop_ro("name", + [](PyOperationBase &self) { + auto &concreteOperation = self.getOperation(); + concreteOperation.checkValid(); + MlirOperation operation = concreteOperation.get(); + return mlirIdentifierStr(mlirOperationGetName(operation)); + }) .def_prop_ro("operands", [](PyOperationBase &self) { return PyOpOperandList(self.getOperation().getRef()); @@ -3485,15 +3479,21 @@ void mlir::python::populateIRCore(nb::module_ &m) { }, "Shortcut to get an op result if it has only one (throws an error " "otherwise).") - .def_prop_ro( + .def_prop_rw( "location", [](PyOperationBase &self) { PyOperation &operation = self.getOperation(); return PyLocation(operation.getContext(), mlirOperationGetLocation(operation.get())); }, - "Returns the source location the operation was defined or derived " - "from.") + [](PyOperationBase &self, const PyLocation &location) { + PyOperation &operation = self.getOperation(); + mlirOperationSetLocation(operation.get(), location.get()); + }, + nb::for_getter("Returns the source location the operation was " + "defined or derived from."), + nb::for_setter("Sets the source location the operation was defined " + "or derived from.")) .def_prop_ro("parent", [](PyOperationBase &self) -> 
std::optional<nb::typed<nb::object, PyOperation>> { @@ -3597,12 +3597,11 @@ void mlir::python::populateIRCore(nb::module_ &m) { }, "Reports if the operation is attached to its parent block.") .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); }) - .def( - "walk", &PyOperationBase::walk, nb::arg("callback"), - nb::arg("walk_order") = MlirWalkPostOrder, - // clang-format off - nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder = " MAKE_MLIR_PYTHON_QUALNAME("ir.WalkOrder.POST_ORDER") ") -> None") - // clang-format on + .def("walk", &PyOperationBase::walk, nb::arg("callback"), + nb::arg("walk_order") = MlirWalkPostOrder, + // clang-format off + nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder) -> None") + // clang-format on ); nb::class_<PyOperation, PyOperationBase>(m, "Operation") @@ -4118,7 +4117,6 @@ void mlir::python::populateIRCore(nb::module_ &m) { [](PyNamedAttribute &self) { return mlirIdentifierStr(self.namedAttr.name); }, - nb::sig("def name(self) -> str"), "The name of the NamedAttribute binding") .def_prop_ro( "attr", @@ -4336,17 +4334,15 @@ void mlir::python::populateIRCore(nb::module_ &m) { kValueReplaceAllUsesWithDocstring) .def( "replace_all_uses_except", - [](MlirValue self, MlirValue with, PyOperation &exception) { + [](PyValue &self, PyValue &with, PyOperation &exception) { MlirOperation exceptedUser = exception.get(); mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser); }, nb::arg("with_"), nb::arg("exceptions"), - nb::sig("def replace_all_uses_except(self, with_: Value, exceptions: " - "Operation) -> None"), kValueReplaceAllUsesExceptDocstring) .def( "replace_all_uses_except", - [](MlirValue self, MlirValue with, nb::list exceptions) { + [](PyValue &self, PyValue &with, const nb::list &exceptions) { // Convert Python list to a SmallVector of MlirOperations llvm::SmallVector<MlirOperation> exceptionOps; for (nb::handle exception : 
exceptions) { @@ -4358,8 +4354,6 @@ void mlir::python::populateIRCore(nb::module_ &m) { exceptionOps.data()); }, nb::arg("with_"), nb::arg("exceptions"), - nb::sig("def replace_all_uses_except(self, with_: Value, exceptions: " - "Sequence[Operation]) -> None"), kValueReplaceAllUsesExceptDocstring) .def( "replace_all_uses_except", diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 598ae01..edbd73e 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -273,8 +273,7 @@ class DefaultingPyMlirContext : public Defaulting<DefaultingPyMlirContext, PyMlirContext> { public: using Defaulting::Defaulting; - static constexpr const char kTypeDescription[] = - MAKE_MLIR_PYTHON_QUALNAME("ir.Context"); + static constexpr const char kTypeDescription[] = "Context"; static PyMlirContext &resolve(); }; @@ -500,8 +499,7 @@ class DefaultingPyLocation : public Defaulting<DefaultingPyLocation, PyLocation> { public: using Defaulting::Defaulting; - static constexpr const char kTypeDescription[] = - MAKE_MLIR_PYTHON_QUALNAME("ir.Location"); + static constexpr const char kTypeDescription[] = "Location"; static PyLocation &resolve(); operator MlirLocation() const { return *get(); } diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 3488d92..34c5b8d 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -1010,7 +1010,7 @@ public: }, nb::arg("elements"), nb::arg("context") = nb::none(), // clang-format off - nb::sig("def get_tuple(elements: Sequence[Type], context: mlir.ir.Context | None = None) -> TupleType"), + nb::sig("def get_tuple(elements: Sequence[Type], context: Context | None = None) -> TupleType"), // clang-format on "Create a tuple type"); c.def( @@ -1070,7 +1070,7 @@ public: }, nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(), // clang-format off - nb::sig("def get(inputs: Sequence[Type], results: 
Sequence[Type], context: mlir.ir.Context | None = None) -> FunctionType"), + nb::sig("def get(inputs: Sequence[Type], results: Sequence[Type], context: Context | None = None) -> FunctionType"), // clang-format on "Gets a FunctionType from a list of input and result types"); c.def_prop_ro( diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 52656138..a14f09f 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -115,9 +115,6 @@ NB_MODULE(_mlir, m) { }); }, "typeid"_a, nb::kw_only(), "replace"_a = false, - // clang-format off - nb::sig("def register_type_caster(typeid: " MAKE_MLIR_PYTHON_QUALNAME("ir.TypeID") ", *, replace: bool = False) -> object"), - // clang-format on "Register a type caster for casting MLIR types to custom user types."); m.def( MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR, @@ -130,9 +127,6 @@ NB_MODULE(_mlir, m) { }); }, "typeid"_a, nb::kw_only(), "replace"_a = false, - // clang-format off - nb::sig("def register_value_caster(typeid: " MAKE_MLIR_PYTHON_QUALNAME("ir.TypeID") ", *, replace: bool = False) -> object"), - // clang-format on "Register a value caster for casting MLIR values to custom user values."); // Define and populate IR submodule. 
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index f18298e..836f44fd 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -127,7 +127,7 @@ public: mlirPythonFrozenRewritePatternSetToCapsule(get())); } - static nb::object createFromCapsule(nb::object capsule) { + static nb::object createFromCapsule(const nb::object &capsule) { MlirFrozenRewritePatternSet rawPm = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr()); if (rawPm.ptr == nullptr) diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index e9844a7..1881865 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -656,6 +656,10 @@ MlirLocation mlirOperationGetLocation(MlirOperation op) { return wrap(unwrap(op)->getLoc()); } +void mlirOperationSetLocation(MlirOperation op, MlirLocation loc) { + unwrap(op)->setLoc(unwrap(loc)); +} + MlirTypeID mlirOperationGetTypeID(MlirOperation op) { if (auto info = unwrap(op)->getRegisteredInfo()) return wrap(info->getTypeID()); diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp index 8ee6308..0d56259 100644 --- a/mlir/lib/CAPI/Transforms/Rewrite.cpp +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -259,22 +259,23 @@ void mlirIRRewriterDestroy(MlirRewriterBase rewriter) { /// RewritePatternSet and FrozenRewritePatternSet API //===----------------------------------------------------------------------===// -inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) { +static inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) { assert(module.ptr && "unexpected null module"); return *(static_cast<mlir::RewritePatternSet *>(module.ptr)); } -inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) { +static inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) { return {module}; } -inline mlir::FrozenRewritePatternSet * +static inline mlir::FrozenRewritePatternSet * 
unwrap(MlirFrozenRewritePatternSet module) { assert(module.ptr && "unexpected null module"); return static_cast<mlir::FrozenRewritePatternSet *>(module.ptr); } -inline MlirFrozenRewritePatternSet wrap(mlir::FrozenRewritePatternSet *module) { +static inline MlirFrozenRewritePatternSet +wrap(mlir::FrozenRewritePatternSet *module) { return {module}; } @@ -321,12 +322,12 @@ inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) { //===----------------------------------------------------------------------===// #if MLIR_ENABLE_PDL_IN_PATTERNMATCH -inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) { +static inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) { assert(module.ptr && "unexpected null module"); return static_cast<mlir::PDLPatternModule *>(module.ptr); } -inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) { +static inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) { return {module}; } diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index cc6314c..a6f816a 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -465,6 +465,51 @@ struct AssumeAlignmentOpLowering } }; +struct DistinctObjectsOpLowering + : public ConvertOpToLLVMPattern<memref::DistinctObjectsOp> { + using ConvertOpToLLVMPattern< + memref::DistinctObjectsOp>::ConvertOpToLLVMPattern; + explicit DistinctObjectsOpLowering(const LLVMTypeConverter &converter) + : ConvertOpToLLVMPattern<memref::DistinctObjectsOp>(converter) {} + + LogicalResult + matchAndRewrite(memref::DistinctObjectsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ValueRange operands = adaptor.getOperands(); + if (operands.size() <= 1) { + // Fast path. 
+ rewriter.replaceOp(op, operands); + return success(); + } + + Location loc = op.getLoc(); + SmallVector<Value> ptrs; + for (auto [origOperand, newOperand] : + llvm::zip_equal(op.getOperands(), operands)) { + auto memrefType = cast<MemRefType>(origOperand.getType()); + MemRefDescriptor memRefDescriptor(newOperand); + Value ptr = memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(), + memrefType); + ptrs.push_back(ptr); + } + + auto cond = + LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(), 1); + // Generate separate_storage assumptions for each pair of pointers. + for (auto i : llvm::seq<size_t>(ptrs.size() - 1)) { + for (auto j : llvm::seq<size_t>(i + 1, ptrs.size())) { + Value ptr1 = ptrs[i]; + Value ptr2 = ptrs[j]; + LLVM::AssumeOp::create(rewriter, loc, cond, + LLVM::AssumeSeparateStorageTag{}, ptr1, ptr2); + } + } + + rewriter.replaceOp(op, operands); + return success(); + } +}; + // A `dealloc` is converted into a call to `free` on the underlying data buffer. // The memref descriptor being an SSA value, there is no need to clean it up // in any way. 
@@ -1997,22 +2042,23 @@ void mlir::populateFinalizeMemRefToLLVMConversionPatterns( patterns.add< AllocaOpLowering, AllocaScopeOpLowering, - AtomicRMWOpLowering, AssumeAlignmentOpLowering, + AtomicRMWOpLowering, ConvertExtractAlignedPointerAsIndex, DimOpLowering, + DistinctObjectsOpLowering, ExtractStridedMetadataOpLowering, GenericAtomicRMWOpLowering, GetGlobalMemrefOpLowering, LoadOpLowering, MemRefCastOpLowering, - MemorySpaceCastOpLowering, MemRefReinterpretCastOpLowering, MemRefReshapeOpLowering, + MemorySpaceCastOpLowering, PrefetchOpLowering, RankOpLowering, - ReassociatingReshapeOpConversion<memref::ExpandShapeOp>, ReassociatingReshapeOpConversion<memref::CollapseShapeOp>, + ReassociatingReshapeOpConversion<memref::ExpandShapeOp>, StoreOpLowering, SubViewOpLowering, TransposeOpLowering, diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp index 035f197..399ccf3 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp @@ -267,9 +267,8 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> { copyInfo.push_back(info); } // Create a call to the kernel and copy the data back. 
- Operation *callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>( - op, kernelFunc, ArrayRef<Value>()); - rewriter.setInsertionPointAfter(callOp); + rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, kernelFunc, + ArrayRef<Value>()); for (CopyInfo info : copyInfo) copy(loc, info.src, info.dst, info.size, rewriter); return success(); diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp index 6f28849..0cb0bad 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp @@ -802,7 +802,6 @@ public: ValueRange{paddedInput, fakeWindowDims}, filledEmptyTensor, strideAttr, dilationAttr); - rewriter.setInsertionPointAfter(op); NanPropagationMode nanMode = op.getNanMode(); rewriter.replaceOp(op, resultOp); diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp index f3e065a..9821a75 100644 --- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp @@ -246,6 +246,6 @@ void SimplifyAffineMinMaxPass::runOnOperation() { patterns.add<SimplifyAffineMaxOp, SimplifyAffineMinOp, SimplifyAffineApplyOp>( func.getContext()); FrozenRewritePatternSet frozenPatterns(std::move(patterns)); - if (failed(applyPatternsGreedily(func, std::move(frozenPatterns)))) + if (failed(applyPatternsGreedily(func, frozenPatterns))) return signalPassFailure(); } diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 7cfd6d3..898d76c 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -1282,6 +1282,13 @@ OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) { if (matchPattern(adaptor.getRhs(), m_OneFloat())) return getLhs(); + if (arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan | + arith::FastMathFlags::nsz)) { + // 
mulf(x, 0) -> 0 + if (matchPattern(adaptor.getRhs(), m_AnyZeroFloat())) + return getRhs(); + } + return constFoldBinaryOp<FloatAttr>( adaptor.getOperands(), [](const APFloat &a, const APFloat &b) { return a * b; }); diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp index 7626d35..c64e10f5 100644 --- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp @@ -123,7 +123,8 @@ void mlir::arith::populateEmulateUnsupportedFloatsLegality( vector::OuterProductOp, vector::ScanOp>( [&](Operation *op) { return converter.isLegal(op); }); target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp, - arith::ConstantOp, vector::SplatOp, vector::BroadcastOp>(); + arith::ConstantOp, arith::SelectOp, vector::SplatOp, + vector::BroadcastOp>(); } void EmulateUnsupportedFloatsPass::runOnOperation() { diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 3f0b0ba..dd9b4c2 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -42,6 +42,7 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/DebugLog.h" #include "llvm/Support/LogicalResult.h" @@ -273,32 +274,6 @@ void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns( // BufferizeToAllocationOp //===----------------------------------------------------------------------===// -void transform::BufferizeToAllocationOp::build(OpBuilder &b, - OperationState &result, - Value target, - Attribute memorySpace) { - SmallVector<Type> resultTypes; - resultTypes.push_back(b.getType<transform::AnyValueType>()); - 
resultTypes.push_back(b.getType<transform::AnyOpType>()); - return build(b, result, - /*resultTypes=*/resultTypes, - /*target=*/target, - /*memory_space=*/memorySpace); -} - -void transform::BufferizeToAllocationOp::build(OpBuilder &b, - OperationState &result, - Value target, - int64_t memorySpace) { - SmallVector<Type> resultTypes; - resultTypes.push_back(b.getType<transform::AnyValueType>()); - resultTypes.push_back(b.getType<transform::AnyOpType>()); - return build(b, result, - /*resultTypes=*/resultTypes, - /*target=*/target, - /*memory_space=*/b.getI64IntegerAttr(memorySpace)); -} - namespace { class NewOpsListener : public RewriterBase::ForwardingListener { public: @@ -409,6 +384,95 @@ LogicalResult transform::BufferizeToAllocationOp::verify() { } //===----------------------------------------------------------------------===// +// PromoteTensorOp +//===----------------------------------------------------------------------===// + +/// Return true if the operand may be read from by its owner. This is currently +/// very conservative and only looks inside linalg operations to prevent +/// unintentional data loss. +static bool mayBeRead(OpOperand &operand) { + auto linalgOp = dyn_cast<linalg::LinalgOp>(operand.getOwner()); + + // Be conservative about ops we cannot analyze deeper. + if (!linalgOp) + return true; + + // Look inside linalg ops. + Value blockArgument = linalgOp.getMatchingBlockArgument(&operand); + return !blockArgument.use_empty(); +} + +/// Return true if the value may be read through any of its uses. +static bool mayBeRead(Value value) { + // If the value has a reference semantics, it + // may be read through any alias... 
+ if (!isa<TensorType, FloatType, IntegerType>(value.getType())) + return true; + return llvm::any_of(value.getUses(), + static_cast<bool (&)(OpOperand &)>(mayBeRead)); +} + +DiagnosedSilenceableFailure +transform::PromoteTensorOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + SmallVector<Value> promoted; + for (Value tensor : state.getPayloadValues(getTensor())) { + auto type = dyn_cast<RankedTensorType>(tensor.getType()); + if (!type) { + return emitSilenceableError() << "non-tensor type: " << tensor; + } + + Operation *definingOp = tensor.getDefiningOp(); + if (definingOp) + rewriter.setInsertionPointAfter(definingOp); + else + rewriter.setInsertionPointToStart(cast<BlockArgument>(tensor).getOwner()); + + // Check this before we emit operations using this value. + bool needsMaterialization = mayBeRead(tensor); + + SmallVector<Value> dynamicDims; + llvm::SmallPtrSet<Operation *, 4> preservedOps; + for (auto [pos, dim] : llvm::enumerate(type.getShape())) { + if (!ShapedType::isDynamic(dim)) + continue; + Value cst = rewriter.create<arith::ConstantIndexOp>(tensor.getLoc(), pos); + auto dimOp = rewriter.create<tensor::DimOp>(tensor.getLoc(), tensor, cst); + preservedOps.insert(dimOp); + dynamicDims.push_back(dimOp); + } + auto allocation = rewriter.create<bufferization::AllocTensorOp>( + tensor.getLoc(), type, dynamicDims); + // Set memory space if provided. + if (getMemorySpaceAttr()) + allocation.setMemorySpaceAttr(getMemorySpaceAttr()); + Value allocated = allocation; + + // Only insert a materialization (typically bufferizes to a copy) when the + // value may be read from. 
+ if (needsMaterialization) { + auto copy = rewriter.create<bufferization::MaterializeInDestinationOp>( + tensor.getLoc(), tensor, allocated); + preservedOps.insert(copy); + promoted.push_back(copy.getResult()); + } else { + promoted.push_back(allocated); + } + rewriter.replaceAllUsesExcept(tensor, promoted.back(), preservedOps); + } + results.setValues(cast<OpResult>(getPromoted()), promoted); + return DiagnosedSilenceableFailure::success(); +} + +void transform::PromoteTensorOp::getEffects( + SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + transform::onlyReadsHandle(getTensorMutable(), effects); + transform::producesHandle(getOperation()->getOpResults(), effects); + transform::modifiesPayload(effects); +} + +//===----------------------------------------------------------------------===// // DecomposeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 3bd763e..05fc7cb 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1622,12 +1622,12 @@ static void generateCollapsedIndexingRegion( } } -void collapseOperandsAndResults(LinalgOp op, - const CollapsingInfo &collapsingInfo, - RewriterBase &rewriter, - SmallVectorImpl<Value> &inputOperands, - SmallVectorImpl<Value> &outputOperands, - SmallVectorImpl<Type> &resultTypes) { +static void collapseOperandsAndResults(LinalgOp op, + const CollapsingInfo &collapsingInfo, + RewriterBase &rewriter, + SmallVectorImpl<Value> &inputOperands, + SmallVectorImpl<Value> &outputOperands, + SmallVectorImpl<Type> &resultTypes) { Location loc = op->getLoc(); inputOperands = llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) { @@ -1651,8 +1651,8 @@ void collapseOperandsAndResults(LinalgOp op, /// Clone a `LinalgOp` to a collapsed version of same name 
template <typename OpTy> -OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp, - const CollapsingInfo &collapsingInfo) { +static OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp, + const CollapsingInfo &collapsingInfo) { return nullptr; } @@ -1699,8 +1699,9 @@ GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter, return collapsedOp; } -LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo, - RewriterBase &rewriter) { +static LinalgOp createCollapsedOp(LinalgOp op, + const CollapsingInfo &collapsingInfo, + RewriterBase &rewriter) { if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) { return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo); } else { diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt index ff62b51..8899c3a 100644 --- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRMathTransforms ExpandOps.cpp ExtendToSupportedTypes.cpp PolynomialApproximation.cpp + SincosFusion.cpp UpliftToFMA.cpp ADDITIONAL_HEADER_DIRS diff --git a/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp b/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp new file mode 100644 index 0000000..69407df --- /dev/null +++ b/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp @@ -0,0 +1,80 @@ +//===- SincosFusion.cpp - Fuse sin/cos into sincos -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/Math/Transforms/Passes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace mlir::math; + +namespace { + +/// Fuse a math.sin and math.cos in the same block that use the same operand and +/// have identical fastmath flags into a single math.sincos. +struct SincosFusionPattern : OpRewritePattern<math::SinOp> { + using Base::Base; + + LogicalResult matchAndRewrite(math::SinOp sinOp, + PatternRewriter &rewriter) const override { + Value operand = sinOp.getOperand(); + mlir::arith::FastMathFlags sinFastMathFlags = sinOp.getFastmath(); + + math::CosOp cosOp = nullptr; + sinOp->getBlock()->walk([&](math::CosOp op) { + if (op.getOperand() == operand && op.getFastmath() == sinFastMathFlags) { + cosOp = op; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + + if (!cosOp) + return failure(); + + Operation *firstOp = sinOp->isBeforeInBlock(cosOp) ? 
sinOp.getOperation() + : cosOp.getOperation(); + rewriter.setInsertionPoint(firstOp); + + Type elemType = sinOp.getType(); + auto sincos = math::SincosOp::create(rewriter, firstOp->getLoc(), + TypeRange{elemType, elemType}, operand, + sinOp.getFastmathAttr()); + + rewriter.replaceOp(sinOp, sincos.getSin()); + rewriter.replaceOp(cosOp, sincos.getCos()); + return success(); + } +}; + +} // namespace + +namespace mlir::math { +#define GEN_PASS_DEF_MATHSINCOSFUSIONPASS +#include "mlir/Dialect/Math/Transforms/Passes.h.inc" +} // namespace mlir::math + +namespace { + +struct MathSincosFusionPass final + : math::impl::MathSincosFusionPassBase<MathSincosFusionPass> { + using MathSincosFusionPassBase::MathSincosFusionPassBase; + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + patterns.add<SincosFusionPattern>(&getContext()); + + GreedyRewriteConfig config; + if (failed( + applyPatternsGreedily(getOperation(), std::move(patterns), config))) + return signalPassFailure(); + } +}; + +} // namespace diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 349b4de..e9bdcda 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -607,6 +607,29 @@ AssumeAlignmentOp::bubbleDownCasts(OpBuilder &builder) { } //===----------------------------------------------------------------------===// +// DistinctObjectsOp +//===----------------------------------------------------------------------===// + +LogicalResult DistinctObjectsOp::verify() { + if (getOperandTypes() != getResultTypes()) + return emitOpError("operand types and result types must match"); + + if (getOperandTypes().empty()) + return emitOpError("expected at least one operand"); + + return success(); +} + +LogicalResult DistinctObjectsOp::inferReturnTypes( + MLIRContext * /*context*/, std::optional<Location> /*location*/, + ValueRange operands, DictionaryAttr /*attributes*/, + OpaqueProperties 
/*properties*/, RegionRange /*regions*/, + SmallVectorImpl<Type> &inferredReturnTypes) { + llvm::copy(operands.getTypes(), std::back_inserter(inferredReturnTypes)); + return success(); +} + +//===----------------------------------------------------------------------===// // CastOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index f01ad05..5672942 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -33,6 +33,7 @@ #include "llvm/ADT/TypeSwitch.h" #include "llvm/ADT/bit.h" #include "llvm/Frontend/OpenMP/OMPConstants.h" +#include "llvm/Support/InterleavedRange.h" #include <cstddef> #include <iterator> #include <optional> @@ -77,6 +78,232 @@ struct LLVMPointerPointerLikeModel }; } // namespace +/// Generate a name of a canonical loop nest of the format +/// `<prefix>(_r<idx>_s<idx>)*`. Hereby, `_r<idx>` identifies the region +/// argument index of an operation that has multiple regions, if the operation +/// has multiple regions. +/// `_s<idx>` identifies the position of an operation within a region, where +/// only operations that may potentially contain loops ("container operations" +/// i.e. have region arguments) are counted. Again, it is omitted if there is +/// only one such operation in a region. If there are canonical loops nested +/// inside each other, also may also use the format `_d<num>` where <num> is the +/// nesting depth of the loop. +/// +/// The generated name is a best-effort to make canonical loop unique within an +/// SSA namespace. This also means that regions with IsolatedFromAbove property +/// do not consider any parents or siblings. 
+static std::string generateLoopNestingName(StringRef prefix, + CanonicalLoopOp op) { + struct Component { + /// If true, this component describes a region operand of an operation (the + /// operand's owner) If false, this component describes an operation located + /// in a parent region + bool isRegionArgOfOp; + bool skip = false; + bool isUnique = false; + + size_t idx; + Operation *op; + Region *parentRegion; + size_t loopDepth; + + Operation *&getOwnerOp() { + assert(isRegionArgOfOp && "Must describe a region operand"); + return op; + } + size_t &getArgIdx() { + assert(isRegionArgOfOp && "Must describe a region operand"); + return idx; + } + + Operation *&getContainerOp() { + assert(!isRegionArgOfOp && "Must describe a operation of a region"); + return op; + } + size_t &getOpPos() { + assert(!isRegionArgOfOp && "Must describe a operation of a region"); + return idx; + } + bool isLoopOp() const { + assert(!isRegionArgOfOp && "Must describe a operation of a region"); + return isa<CanonicalLoopOp>(op); + } + Region *&getParentRegion() { + assert(!isRegionArgOfOp && "Must describe a operation of a region"); + return parentRegion; + } + size_t &getLoopDepth() { + assert(!isRegionArgOfOp && "Must describe a operation of a region"); + return loopDepth; + } + + void skipIf(bool v = true) { skip = skip || v; } + }; + + // List of ancestors, from inner to outer. 
+ // Alternates between + // * region argument of an operation + // * operation within a region + SmallVector<Component> components; + + // Gather a list of parent regions and operations, and the position within + // their parent + Operation *o = op.getOperation(); + while (o) { + // Operation within a region + Region *r = o->getParentRegion(); + if (!r) + break; + + llvm::ReversePostOrderTraversal<Block *> traversal(&r->getBlocks().front()); + size_t idx = 0; + bool found = false; + size_t sequentialIdx = -1; + bool isOnlyContainerOp = true; + for (Block *b : traversal) { + for (Operation &op : *b) { + if (&op == o && !found) { + sequentialIdx = idx; + found = true; + } + if (op.getNumRegions()) { + idx += 1; + if (idx > 1) + isOnlyContainerOp = false; + } + if (found && !isOnlyContainerOp) + break; + } + } + + Component &containerOpInRegion = components.emplace_back(); + containerOpInRegion.isRegionArgOfOp = false; + containerOpInRegion.isUnique = isOnlyContainerOp; + containerOpInRegion.getContainerOp() = o; + containerOpInRegion.getOpPos() = sequentialIdx; + containerOpInRegion.getParentRegion() = r; + + Operation *parent = r->getParentOp(); + + // Region argument of an operation + Component ®ionArgOfOperation = components.emplace_back(); + regionArgOfOperation.isRegionArgOfOp = true; + regionArgOfOperation.isUnique = true; + regionArgOfOperation.getArgIdx() = 0; + regionArgOfOperation.getOwnerOp() = parent; + + // The IsolatedFromAbove trait of the parent operation implies that each + // individual region argument has its own separate namespace, so no + // ambiguity. + if (!parent || parent->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>()) + break; + + // Component only needed if operation has multiple region operands. Region + // arguments may be optional, but we currently do not consider this. 
+ if (parent->getRegions().size() > 1) { + auto getRegionIndex = [](Operation *o, Region *r) { + for (auto [idx, region] : llvm::enumerate(o->getRegions())) { + if (®ion == r) + return idx; + } + llvm_unreachable("Region not child of its parent operation"); + }; + regionArgOfOperation.isUnique = false; + regionArgOfOperation.getArgIdx() = getRegionIndex(parent, r); + } + + // next parent + o = parent; + } + + // Determine whether a region-argument component is not needed + for (Component &c : components) + c.skipIf(c.isRegionArgOfOp && c.isUnique); + + // Find runs of nested loops and determine each loop's depth in the loop nest + size_t numSurroundingLoops = 0; + for (Component &c : llvm::reverse(components)) { + if (c.skip) + continue; + + // non-skipped multi-argument operands interrupt the loop nest + if (c.isRegionArgOfOp) { + numSurroundingLoops = 0; + continue; + } + + // Multiple loops in a region means each of them is the outermost loop of a + // new loop nest + if (!c.isUnique) + numSurroundingLoops = 0; + + c.getLoopDepth() = numSurroundingLoops; + + // Next loop is surrounded by one more loop + if (isa<CanonicalLoopOp>(c.getContainerOp())) + numSurroundingLoops += 1; + } + + // In loop nests, skip all but the innermost loop that contains the depth + // number + bool isLoopNest = false; + for (Component &c : components) { + if (c.skip || c.isRegionArgOfOp) + continue; + + if (!isLoopNest && c.getLoopDepth() >= 1) { + // Innermost loop of a loop nest of at least two loops + isLoopNest = true; + } else if (isLoopNest) { + // Non-innermost loop of a loop nest + c.skipIf(c.isUnique); + + // If there is no surrounding loop left, this must have been the outermost + // loop; leave loop-nest mode for the next iteration + if (c.getLoopDepth() == 0) + isLoopNest = false; + } + } + + // Skip non-loop unambiguous regions (but they should interrupt loop nests, so + // we mark them as skipped only after computing loop nests) + for (Component &c : components) + 
c.skipIf(!c.isRegionArgOfOp && c.isUnique && + !isa<CanonicalLoopOp>(c.getContainerOp())); + + // Components can be skipped if they are already disambiguated by their parent + // (or does not have a parent) + bool newRegion = true; + for (Component &c : llvm::reverse(components)) { + c.skipIf(newRegion && c.isUnique); + + // non-skipped components disambiguate unique children + if (!c.skip) + newRegion = true; + + // ...except canonical loops that need a suffix for each nest + if (!c.isRegionArgOfOp && c.getContainerOp()) + newRegion = false; + } + + // Compile the nesting name string + SmallString<64> Name{prefix}; + llvm::raw_svector_ostream NameOS(Name); + for (auto &c : llvm::reverse(components)) { + if (c.skip) + continue; + + if (c.isRegionArgOfOp) + NameOS << "_r" << c.getArgIdx(); + else if (c.getLoopDepth() >= 1) + NameOS << "_d" << c.getLoopDepth(); + else + NameOS << "_s" << c.getOpPos(); + } + + return NameOS.str().str(); +} + void OpenMPDialect::initialize() { addOperations< #define GET_OP_LIST @@ -182,7 +409,7 @@ static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) { } template <typename ClauseAttr> -void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) { +static void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) { p << stringifyEnum(attr.getValue()); } @@ -1511,8 +1738,8 @@ static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) { //===----------------------------------------------------------------------===// // Helper function to get bitwise AND of `value` and 'flag' -uint64_t mapTypeToBitFlag(uint64_t value, - llvm::omp::OpenMPOffloadMappingFlags flag) { +static uint64_t mapTypeToBitFlag(uint64_t value, + llvm::omp::OpenMPOffloadMappingFlags flag) { return value & llvm::to_underlying(flag); } @@ -3159,6 +3386,9 @@ void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { Value result = getResult(); auto [newCli, gen, cons] = decodeCli(result); + // Structured 
binding `gen` cannot be captured in lambdas before C++20 + OpOperand *generator = gen; + // Derive the CLI variable name from its generator: // * "canonloop" for omp.canonical_loop // * custom name for loop transformation generatees @@ -3172,71 +3402,29 @@ void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { cliName = TypeSwitch<Operation *, std::string>(gen->getOwner()) .Case([&](CanonicalLoopOp op) { - // Find the canonical loop nesting: For each ancestor add a - // "+_r<idx>" suffix (in reverse order) - SmallVector<std::string> components; - Operation *o = op.getOperation(); - while (o) { - if (o->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>()) - break; - - Region *r = o->getParentRegion(); - if (!r) - break; - - auto getSequentialIndex = [](Region *r, Operation *o) { - llvm::ReversePostOrderTraversal<Block *> traversal( - &r->getBlocks().front()); - size_t idx = 0; - for (Block *b : traversal) { - for (Operation &op : *b) { - if (&op == o) - return idx; - // Only consider operations that are containers as - // possible children - if (!op.getRegions().empty()) - idx += 1; - } - } - llvm_unreachable("Operation not part of the region"); - }; - size_t sequentialIdx = getSequentialIndex(r, o); - components.push_back(("s" + Twine(sequentialIdx)).str()); - - Operation *parent = r->getParentOp(); - if (!parent) - break; - - // If the operation has more than one region, also count in - // which of the regions - if (parent->getRegions().size() > 1) { - auto getRegionIndex = [](Operation *o, Region *r) { - for (auto [idx, region] : - llvm::enumerate(o->getRegions())) { - if (®ion == r) - return idx; - } - llvm_unreachable("Region not child its parent operation"); - }; - size_t regionIdx = getRegionIndex(parent, r); - components.push_back(("r" + Twine(regionIdx)).str()); - } - - // next parent - o = parent; - } - - SmallString<64> Name("canonloop"); - for (const std::string &s : reverse(components)) { - Name += '_'; - Name += s; - } - - return Name; + return 
generateLoopNestingName("canonloop", op); }) .Case([&](UnrollHeuristicOp op) -> std::string { llvm_unreachable("heuristic unrolling does not generate a loop"); }) + .Case([&](TileOp op) -> std::string { + auto [generateesFirst, generateesCount] = + op.getGenerateesODSOperandIndexAndLength(); + unsigned firstGrid = generateesFirst; + unsigned firstIntratile = generateesFirst + generateesCount / 2; + unsigned end = generateesFirst + generateesCount; + unsigned opnum = generator->getOperandNumber(); + // In the OpenMP apply and looprange clauses, indices are 1-based + if (firstGrid <= opnum && opnum < firstIntratile) { + unsigned gridnum = opnum - firstGrid + 1; + return ("grid" + Twine(gridnum)).str(); + } + if (firstIntratile <= opnum && opnum < end) { + unsigned intratilenum = opnum - firstIntratile + 1; + return ("intratile" + Twine(intratilenum)).str(); + } + llvm_unreachable("Unexpected generatee argument"); + }) .Default([&](Operation *op) { assert(false && "TODO: Custom name for this operation"); return "transformed"; @@ -3323,7 +3511,8 @@ void CanonicalLoopOp::getAsmBlockNames(OpAsmSetBlockNameFn setNameFn) { void CanonicalLoopOp::getAsmBlockArgumentNames(Region ®ion, OpAsmSetValueNameFn setNameFn) { - setNameFn(region.getArgument(0), "iv"); + std::string ivName = generateLoopNestingName("iv", *this); + setNameFn(region.getArgument(0), ivName); } void CanonicalLoopOp::print(OpAsmPrinter &p) { @@ -3465,6 +3654,138 @@ UnrollHeuristicOp::getGenerateesODSOperandIndexAndLength() { } //===----------------------------------------------------------------------===// +// TileOp +//===----------------------------------------------------------------------===// + +static void printLoopTransformClis(OpAsmPrinter &p, TileOp op, + OperandRange generatees, + OperandRange applyees) { + if (!generatees.empty()) + p << '(' << llvm::interleaved(generatees) << ')'; + + if (!applyees.empty()) + p << " <- (" << llvm::interleaved(applyees) << ')'; +} + +static ParseResult 
parseLoopTransformClis(
+    OpAsmParser &parser,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &generateesOperands,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &applyeesOperands) {
+  if (parser.parseOptionalLess()) {
+    // Syntax 1: generatees present
+
+    if (parser.parseOperandList(generateesOperands,
+                                mlir::OpAsmParser::Delimiter::Paren))
+      return failure();
+
+    if (parser.parseLess())
+      return failure();
+  } else {
+    // Syntax 2: generatees omitted
+  }
+
+  // Parse `<-` (`<` has already been parsed)
+  if (parser.parseMinus())
+    return failure();
+
+  if (parser.parseOperandList(applyeesOperands,
+                              mlir::OpAsmParser::Delimiter::Paren))
+    return failure();
+
+  return success();
+}
+
+LogicalResult TileOp::verify() {
+  if (getApplyees().empty())
+    return emitOpError() << "must apply to at least one loop";
+
+  if (getSizes().size() != getApplyees().size())
+    return emitOpError() << "there must be one tile size for each applyee";
+
+  if (!getGeneratees().empty() &&
+      2 * getSizes().size() != getGeneratees().size())
+    return emitOpError()
+           << "expecting two times the number of generatees than applyees";
+
+  DenseSet<Value> parentIVs;
+
+  Value parent = getApplyees().front();
+  for (auto &&applyee : llvm::drop_begin(getApplyees())) {
+    auto [parentCreate, parentGen, parentCons] = decodeCli(parent);
+    auto [create, gen, cons] = decodeCli(applyee);
+
+    if (!parentGen)
+      return emitOpError() << "applyee CLI has no generator";
+
+    auto parentLoop = dyn_cast_or_null<CanonicalLoopOp>(parentGen->getOwner());
+    if (!parentLoop)
+      return emitOpError()
+             << "currently only supports omp.canonical_loop as applyee";
+
+    parentIVs.insert(parentLoop.getInductionVar());
+
+    if (!gen)
+      return emitOpError() << "applyee CLI has no generator";
+    auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
+    if (!loop)
+      return emitOpError()
+             << "currently only supports omp.canonical_loop as applyee";
+
+    // Canonical loop must be perfectly nested, i.e.
the body of the parent must + // only contain the omp.canonical_loop of the nested loops, and + // omp.terminator + bool isPerfectlyNested = [&]() { + auto &parentBody = parentLoop.getRegion(); + if (!parentBody.hasOneBlock()) + return false; + auto &parentBlock = parentBody.getBlocks().front(); + + auto nestedLoopIt = parentBlock.begin(); + if (nestedLoopIt == parentBlock.end() || + (&*nestedLoopIt != loop.getOperation())) + return false; + + auto termIt = std::next(nestedLoopIt); + if (termIt == parentBlock.end() || !isa<TerminatorOp>(termIt)) + return false; + + if (std::next(termIt) != parentBlock.end()) + return false; + + return true; + }(); + if (!isPerfectlyNested) + return emitOpError() << "tiled loop nest must be perfectly nested"; + + if (parentIVs.contains(loop.getTripCount())) + return emitOpError() << "tiled loop nest must be rectangular"; + + parent = applyee; + } + + // TODO: The tile sizes must be computed before the loop, but checking this + // requires dominance analysis. 
For instance: + // + // %canonloop = omp.new_cli + // omp.canonical_loop(%canonloop) %iv : i32 in range(%tc) { + // // write to %x + // omp.terminator + // } + // %ts = llvm.load %x + // omp.tile <- (%canonloop) sizes(%ts : i32) + + return success(); +} + +std::pair<unsigned, unsigned> TileOp ::getApplyeesODSOperandIndexAndLength() { + return getODSOperandIndexAndLength(odsIndex_applyees); +} + +std::pair<unsigned, unsigned> TileOp::getGenerateesODSOperandIndexAndLength() { + return getODSOperandIndexAndLength(odsIndex_generatees); +} + +//===----------------------------------------------------------------------===// // Critical construct (2.17.1) //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 132ed81..3385b2a 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -616,11 +616,10 @@ DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply( if (diag.succeeded()) { // Tracking failure is the only failure. 
return trackingFailure; - } else { - diag.attachNote() << "tracking listener also failed: " - << trackingFailure.getMessage(); - (void)trackingFailure.silence(); } + diag.attachNote() << "tracking listener also failed: " + << trackingFailure.getMessage(); + (void)trackingFailure.silence(); } if (!diag.succeeded()) diff --git a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp index 842e880..c627158 100644 --- a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp +++ b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp @@ -6,13 +6,24 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/IR/OpImplementation.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h" using namespace mlir; +static ParseResult parseAlternativesOpSelectedRegion( + OpAsmParser &parser, IntegerAttr &selectedRegionAttr, + std::optional<OpAsmParser::UnresolvedOperand> &selectedRegionParam); + +static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer, + Operation *op, + IntegerAttr selectedRegionAttr, + Value selectedRegionParam); + #define GET_OP_CLASSES #include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp.inc" @@ -57,3 +68,176 @@ LogicalResult transform::tune::KnobOp::verify() { return success(); } + +//===----------------------------------------------------------------------===// +// AlternativesOp +//===----------------------------------------------------------------------===// + +static ParseResult parseAlternativesOpSelectedRegion( + OpAsmParser &parser, IntegerAttr &selectedRegionAttr, + std::optional<OpAsmParser::UnresolvedOperand> &selectedRegionParam) { + size_t selectedRegionIdx; + OptionalParseResult attrParseRes = + 
parser.parseOptionalInteger(selectedRegionIdx); + if (attrParseRes.has_value()) { + if (failed(*attrParseRes)) + return failure(); + + selectedRegionAttr = parser.getBuilder().getIndexAttr(selectedRegionIdx); + return success(); + } + + OpAsmParser::UnresolvedOperand param; + auto paramParseRes = parser.parseOptionalOperand(param); + if (paramParseRes.has_value()) { + if (failed(*paramParseRes)) + return failure(); + + selectedRegionParam = param; + return success(); + } + + return parser.emitError(parser.getCurrentLocation()) + << "expected either an integer attribute or a transform.param operand"; +} + +static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer, + Operation *op, + IntegerAttr selectedRegionAttr, + Value selectedRegionParam) { + if (selectedRegionAttr) + printer << selectedRegionAttr.getValue(); + if (selectedRegionParam) + printer << selectedRegionParam; +} + +OperandRange transform::tune::AlternativesOp::getEntrySuccessorOperands( + RegionBranchPoint point) { + // No operands will be forwarded to the region(s). 
+ return getOperands().slice(0, 0); +} + +void transform::tune::AlternativesOp::getSuccessorRegions( + RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { + if (point.isParent()) + if (auto selectedRegionIdx = getSelectedRegionAttr()) + regions.emplace_back( + &getAlternatives()[selectedRegionIdx->getSExtValue()], + Block::BlockArgListType()); + else + for (Region &alternative : getAlternatives()) + regions.emplace_back(&alternative, Block::BlockArgListType()); + else + regions.emplace_back(getOperation()->getResults()); +} + +void transform::tune::AlternativesOp::getRegionInvocationBounds( + ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) { + (void)operands; + bounds.reserve(getNumRegions()); + + if (auto selectedRegionIdx = getSelectedRegionAttr()) { + bounds.resize(getNumRegions(), InvocationBounds(0, 0)); + bounds[selectedRegionIdx->getSExtValue()] = InvocationBounds(1, 1); + } else { + bounds.resize(getNumRegions(), InvocationBounds(0, 1)); + } +} + +void transform::tune::AlternativesOp::getEffects( + SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + onlyReadsHandle(getSelectedRegionParamMutable(), effects); + producesHandle(getOperation()->getOpResults(), effects); + // TODO: should effects from regions be forwarded? 
+}
+
+DiagnosedSilenceableFailure
+transform::tune::AlternativesOp::apply(transform::TransformRewriter &rewriter,
+                                       transform::TransformResults &results,
+                                       transform::TransformState &state) {
+  std::optional<size_t> selectedRegionIdx;
+
+  if (auto selectedRegionAttr = getSelectedRegionAttr())
+    selectedRegionIdx = selectedRegionAttr->getSExtValue();
+
+  if (Value selectedRegionParam = getSelectedRegionParam()) {
+    ArrayRef<Attribute> associatedAttrs = state.getParams(selectedRegionParam);
+    IntegerAttr selectedRegionAttr;
+    if (associatedAttrs.size() != 1 ||
+        !(selectedRegionAttr = dyn_cast<IntegerAttr>(associatedAttrs[0])))
+      return emitDefiniteFailure()
+             << "param should hold exactly one integer attribute, got: "
+             << associatedAttrs[0];
+    selectedRegionIdx = selectedRegionAttr.getValue().getSExtValue();
+  }
+
+  if (!selectedRegionIdx)
+    return emitDefiniteFailure() << "non-deterministic choice " << getName()
+                                 << " is only resolved through providing a "
+                                    "`selected_region` attr/param";
+
+  if (*selectedRegionIdx >= getNumRegions())
+    return emitDefiniteFailure()
+           << "'selected_region' attribute/param specifies region at index "
+           << *selectedRegionIdx << " while op has only " << getNumRegions()
+           << " regions";
+
+  Region &selectedRegion = getRegion(*selectedRegionIdx);
+  auto scope = state.make_region_scope(selectedRegion);
+  Block &block = selectedRegion.front();
+  // Apply the region's ops one by one.
+  for (Operation &transform : block.without_terminator()) {
+    DiagnosedSilenceableFailure result =
+        state.applyTransform(cast<transform::TransformOpInterface>(transform));
+    if (result.isDefiniteFailure())
+      return result;
+
+    if (result.isSilenceableFailure()) {
+      for (const auto &res : getResults())
+        results.set(res, {});
+      return result;
+    }
+  }
+  // Forward the operation mapping for values yielded from the region to the
+  // values produced by the alternatives op.
+  transform::detail::forwardTerminatorOperands(&block, state, results);
+  return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult transform::tune::AlternativesOp::verify() {
+  for (auto *region : getRegions()) {
+    auto yieldTerminator =
+        llvm::dyn_cast_if_present<transform::YieldOp>(region->front().back());
+    if (!yieldTerminator)
+      return emitOpError() << "expected '"
+                           << transform::YieldOp::getOperationName()
+                           << "' as terminator";
+
+    if (yieldTerminator->getNumOperands() != getNumResults())
+      return yieldTerminator.emitOpError()
+             << "expected terminator to have as many operands as the parent op "
+                "has results";
+
+    for (auto [i, operandType, resultType] : llvm::zip_equal(
+             llvm::seq<unsigned>(0, yieldTerminator->getNumOperands()),
+             yieldTerminator->getOperands().getType(), getResultTypes())) {
+      if (operandType == resultType)
+        continue;
+      return yieldTerminator.emitOpError()
+             << "the type of the terminator operand #" << i
+             << " must match the type of the corresponding parent op result ("
+             << operandType << " vs " << resultType << ")";
+    }
+  }
+
+  if (auto selectedRegionAttr = getSelectedRegionAttr()) {
+    size_t regionIdx = selectedRegionAttr->getSExtValue();
+    if (regionIdx >= getNumRegions())
+      return emitOpError()
+             << "'selected_region' attribute specifies region at index "
+             << regionIdx << " while op has only " << getNumRegions()
+             << " regions";
+  }
+
+  return success();
+}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index eb46869..b0132e8 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -580,7 +580,7 @@ namespace {
 // ElideSingleElementReduction for ReduceOp.
struct ElideUnitDimsInMultiDimReduction : public OpRewritePattern<MultiDimReductionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp, PatternRewriter &rewriter) const override { @@ -730,7 +730,7 @@ std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() { namespace { struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ReductionOp reductionOp, PatternRewriter &rewriter) const override { @@ -2197,7 +2197,7 @@ namespace { // Pattern to rewrite a ExtractOp(Broadcast) -> Broadcast. class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExtractOp extractOp, PatternRewriter &rewriter) const override { @@ -2220,7 +2220,7 @@ public: // Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask. class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExtractOp extractOp, PatternRewriter &rewriter) const override { @@ -2546,7 +2546,7 @@ rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp, class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(FromElementsOp fromElements, PatternRewriter &rewriter) const override { @@ -2938,7 +2938,7 @@ namespace { // Fold broadcast1(broadcast2(x)) into broadcast1(x). 
struct BroadcastFolder : public OpRewritePattern<BroadcastOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(BroadcastOp broadcastOp, PatternRewriter &rewriter) const override { @@ -3109,7 +3109,7 @@ namespace { // Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector // to a broadcast. struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ShuffleOp shuffleOp, PatternRewriter &rewriter) const override { @@ -3165,7 +3165,7 @@ static Value getScalarSplatSource(Value value) { /// Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v). class ShuffleSplat final : public OpRewritePattern<ShuffleOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ShuffleOp op, PatternRewriter &rewriter) const override { @@ -3182,7 +3182,7 @@ public: /// vector.interleave. class ShuffleInterleave : public OpRewritePattern<ShuffleOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ShuffleOp op, PatternRewriter &rewriter) const override { @@ -3326,7 +3326,7 @@ namespace { // broadcast. class InsertToBroadcast final : public OpRewritePattern<InsertOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(InsertOp insertOp, PatternRewriter &rewriter) const override { @@ -3344,7 +3344,7 @@ public: /// Pattern to rewrite a insert(splat-like(v), splat-like(v)) as broadcast(v). 
class InsertSplatToSplat final : public OpRewritePattern<InsertOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(InsertOp op, PatternRewriter &rewriter) const override { @@ -3380,7 +3380,7 @@ public: /// %result = vector.from_elements %c1, %c2 : vector<2xi32> class InsertChainFullyInitialized final : public OpRewritePattern<InsertOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(InsertOp op, PatternRewriter &rewriter) const override { @@ -3748,7 +3748,7 @@ namespace { class FoldInsertStridedSliceSplat final : public OpRewritePattern<InsertStridedSliceOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp, PatternRewriter &rewriter) const override { @@ -3768,7 +3768,7 @@ public: class FoldInsertStridedSliceOfExtract final : public OpRewritePattern<InsertStridedSliceOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp, PatternRewriter &rewriter) const override { @@ -3798,7 +3798,7 @@ public: class InsertStridedSliceConstantFolder final : public OpRewritePattern<InsertStridedSliceOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; // Do not create constants with more than `vectorSizeFoldThreashold` elements, // unless the source vector constant has a single use. 
@@ -4250,7 +4250,7 @@ namespace { // %mask = vector.create_mask %new_ub : vector<8xi1> class StridedSliceCreateMaskFolder final : public OpRewritePattern<ExtractStridedSliceOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; public: LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp, @@ -4310,7 +4310,7 @@ public: class StridedSliceConstantMaskFolder final : public OpRewritePattern<ExtractStridedSliceOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp, PatternRewriter &rewriter) const override { @@ -4365,7 +4365,7 @@ public: class StridedSliceBroadcast final : public OpRewritePattern<ExtractStridedSliceOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override { @@ -4416,7 +4416,7 @@ public: /// Rewrite extract_strided_slice(splat-like(v)) with broadcast(v). class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override { @@ -4448,7 +4448,7 @@ public: class ContiguousExtractStridedSliceToExtract final : public OpRewritePattern<ExtractStridedSliceOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override { @@ -5023,7 +5023,7 @@ namespace { /// ``` struct TransferReadAfterWriteToBroadcast : public OpRewritePattern<TransferReadOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(TransferReadOp readOp, PatternRewriter &rewriter) const override { @@ -5458,7 +5458,7 @@ namespace { /// any other uses. 
class FoldWaw final : public OpRewritePattern<TransferWriteOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(TransferWriteOp writeOp, PatternRewriter &rewriter) const override { if (!llvm::isa<RankedTensorType>(writeOp.getShapedType())) @@ -5514,7 +5514,7 @@ public: struct SwapExtractSliceOfTransferWrite : public OpRewritePattern<tensor::InsertSliceOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp, PatternRewriter &rewriter) const override { @@ -5737,7 +5737,7 @@ LogicalResult MaskedLoadOp::verify() { namespace { class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(MaskedLoadOp load, PatternRewriter &rewriter) const override { switch (getMaskFormat(load.getMask())) { @@ -5794,7 +5794,7 @@ LogicalResult MaskedStoreOp::verify() { namespace { class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(MaskedStoreOp store, PatternRewriter &rewriter) const override { switch (getMaskFormat(store.getMask())) { @@ -5890,7 +5890,7 @@ static LogicalResult isZeroBasedContiguousSeq(Value indexVec) { namespace { class GatherFolder final : public OpRewritePattern<GatherOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(GatherOp gather, PatternRewriter &rewriter) const override { switch (getMaskFormat(gather.getMask())) { @@ -5910,7 +5910,7 @@ public: /// maskedload. Only 1D fixed vectors are supported for now. 
class FoldContiguousGather final : public OpRewritePattern<GatherOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(GatherOp op, PatternRewriter &rewriter) const override { if (!isa<MemRefType>(op.getBase().getType())) @@ -5962,7 +5962,7 @@ LogicalResult ScatterOp::verify() { namespace { class ScatterFolder final : public OpRewritePattern<ScatterOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ScatterOp scatter, PatternRewriter &rewriter) const override { switch (getMaskFormat(scatter.getMask())) { @@ -5982,7 +5982,7 @@ public: /// maskedstore. Only 1D fixed vectors are supported for now. class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ScatterOp op, PatternRewriter &rewriter) const override { if (failed(isZeroBasedContiguousSeq(op.getIndices()))) @@ -6030,7 +6030,7 @@ LogicalResult ExpandLoadOp::verify() { namespace { class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExpandLoadOp expand, PatternRewriter &rewriter) const override { switch (getMaskFormat(expand.getMask())) { @@ -6081,7 +6081,7 @@ LogicalResult CompressStoreOp::verify() { namespace { class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(CompressStoreOp compress, PatternRewriter &rewriter) const override { switch (getMaskFormat(compress.getMask())) { @@ -6260,7 +6260,7 @@ static VectorType trimTrailingOneDims(VectorType oldType) { class ShapeCastCreateMaskFolderTrailingOneDim final : public OpRewritePattern<ShapeCastOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult 
matchAndRewrite(ShapeCastOp shapeOp, PatternRewriter &rewriter) const override { @@ -6330,7 +6330,7 @@ public: /// If both (i) and (ii) are possible, (i) is chosen. class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp, PatternRewriter &rewriter) const override { @@ -6614,7 +6614,7 @@ namespace { // Rewrites two back-to-back TransposeOp operations into a single TransposeOp. class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, PatternRewriter &rewriter) const override { @@ -6646,7 +6646,7 @@ public: /// Replace transpose(splat-like(v)) with broadcast(v) class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(TransposeOp transposeOp, PatternRewriter &rewriter) const override { @@ -6663,7 +6663,7 @@ public: /// Folds transpose(create_mask) into a new transposed create_mask. class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(TransposeOp transpOp, PatternRewriter &rewriter) const override { @@ -6700,7 +6700,7 @@ public: /// Folds transpose(shape_cast) into a new shape_cast. class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(TransposeOp transposeOp, PatternRewriter &rewriter) const override { @@ -6750,7 +6750,7 @@ public: /// within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6). 
class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1) : OpRewritePattern<vector::TransposeOp>(context, benefit) {} @@ -6971,7 +6971,7 @@ namespace { /// %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1> class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(CreateMaskOp createMaskOp, PatternRewriter &rewriter) const override { @@ -7300,7 +7300,7 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor, /// %0 = arith.select %mask, %a, %passthru : vector<8xf32> /// class CanonializeEmptyMaskOp : public OpRewritePattern<MaskOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(MaskOp maskOp, PatternRewriter &rewriter) const override { @@ -7410,7 +7410,7 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) { // vector.broadcast. class SplatToBroadcastPattern final : public OpRewritePattern<SplatOp> { public: - using OpRewritePattern<SplatOp>::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(SplatOp splatOp, PatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(), diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp index dedc3b3..61d9357 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp @@ -34,7 +34,7 @@ namespace { /// convertible to the lower level target dialect (LLVM, SPIR-V, etc.) directly. 
class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::BroadcastOp op, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp index 65702ff..efe8d14 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp @@ -1151,7 +1151,7 @@ FailureOr<Value> ContractionOpLowering::lowerReduction( /// class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::OuterProductOp op, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp index 1f96a3a..6bc8347 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp @@ -50,7 +50,7 @@ namespace { /// /// Supports vector types with a fixed leading dimension. struct UnrollGather : OpRewritePattern<vector::GatherOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::GatherOp op, PatternRewriter &rewriter) const override { @@ -98,7 +98,7 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> { /// ATM this is effectively limited to reading a 1D Vector from a 2D MemRef, /// but should be fairly straightforward to extend beyond that. 
struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::GatherOp op, PatternRewriter &rewriter) const override { @@ -164,7 +164,7 @@ struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> { /// `tensor.extract`s. To avoid out-of-bounds memory accesses, these /// loads/extracts are made conditional using `scf.if` ops. struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::GatherOp op, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp index 9d6a865..479fc0c 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp @@ -163,7 +163,7 @@ private: /// : vector<7xi16>, vector<7xi16> /// ``` struct InterleaveToShuffle final : OpRewritePattern<vector::InterleaveOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::InterleaveOp op, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp index 5617b06..7730c4e 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp @@ -48,7 +48,7 @@ namespace { /// until a one-dimensional vector is reached. class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::CreateMaskOp op, PatternRewriter &rewriter) const override { @@ -100,7 +100,7 @@ public: /// will be folded at LLVM IR level. 
class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ConstantMaskOp op, PatternRewriter &rewriter) const override { @@ -184,7 +184,7 @@ namespace { /// and actually match the traits of its the nested `MaskableOpInterface`. template <class SourceOp> struct MaskOpRewritePattern : OpRewritePattern<MaskOp> { - using OpRewritePattern<MaskOp>::OpRewritePattern; + using Base::Base; private: LogicalResult matchAndRewrite(MaskOp maskOp, diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp index 4773732d..e86e2a9 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp @@ -39,7 +39,7 @@ namespace { class InnerOuterDimReductionConversion : public OpRewritePattern<vector::MultiDimReductionOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; explicit InnerOuterDimReductionConversion( MLIRContext *context, vector::VectorMultiReductionLowering options, @@ -136,7 +136,7 @@ private: class ReduceMultiDimReductionRank : public OpRewritePattern<vector::MultiDimReductionOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; explicit ReduceMultiDimReductionRank( MLIRContext *context, vector::VectorMultiReductionLowering options, @@ -304,7 +304,7 @@ private: /// and combines results struct TwoDimMultiReductionToElementWise : public OpRewritePattern<vector::MultiDimReductionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override { @@ -359,7 +359,7 @@ struct TwoDimMultiReductionToElementWise /// a sequence of vector.reduction ops. 
struct TwoDimMultiReductionToReduction : public OpRewritePattern<vector::MultiDimReductionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override { @@ -420,7 +420,7 @@ struct TwoDimMultiReductionToReduction /// separately. struct OneDimMultiReductionToTwoDim : public OpRewritePattern<vector::MultiDimReductionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp index af4851e..258f2cb 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp @@ -99,7 +99,7 @@ namespace { /// return %7, %8 : vector<2x3xi32>, vector<2xi32> /// ``` struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ScanOp scanOp, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp index 603ea41..c5f22b2 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp @@ -189,7 +189,7 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> { } public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ShapeCastOp op, PatternRewriter &rewriter) const override { @@ -356,7 +356,7 @@ public: class ScalableShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult 
matchAndRewrite(vector::ShapeCastOp op, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp index 78102f7..8f46ad6 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp @@ -44,7 +44,7 @@ namespace { /// struct MixedSizeInputShuffleOpRewrite final : OpRewritePattern<vector::ShuffleOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ShuffleOp shuffleOp, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp index ee5568a..08e7c89 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp @@ -24,7 +24,7 @@ using namespace mlir::vector; namespace { struct StepToArithConstantOpRewrite final : OpRewritePattern<vector::StepOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::StepOp stepOp, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp index 6407a86..7521e24 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp @@ -667,7 +667,7 @@ getToElementsDefiningOps(FromElementsOp fromElemsOp, struct ToFromElementsToShuffleTreeRewrite final : OpRewritePattern<vector::FromElementsOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::FromElementsOp fromElemsOp, PatternRewriter &rewriter) const override { diff --git 
a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp index 9e7d0ce..c3f7de0 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp @@ -300,7 +300,7 @@ namespace { /// %x = vector.insert .., .. [.., ..] class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; TransposeOpLowering(vector::VectorTransposeLowering vectorTransposeLowering, MLIRContext *context, PatternBenefit benefit = 1) @@ -395,7 +395,7 @@ private: class Transpose2DWithUnitDimToShapeCast : public OpRewritePattern<vector::TransposeOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; Transpose2DWithUnitDimToShapeCast(MLIRContext *context, PatternBenefit benefit = 1) @@ -433,7 +433,7 @@ public: class TransposeOp2DToShuffleLowering : public OpRewritePattern<vector::TransposeOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; TransposeOp2DToShuffleLowering( vector::VectorTransposeLowering vectorTransposeLowering, diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp index cab1289..963b2c8 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -54,7 +54,7 @@ namespace { // input by inserting vector.broadcast. struct CastAwayExtractStridedSliceLeadingOneDim : public OpRewritePattern<vector::ExtractStridedSliceOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, PatternRewriter &rewriter) const override { @@ -104,7 +104,7 @@ struct CastAwayExtractStridedSliceLeadingOneDim // inputs by inserting vector.broadcast. 
struct CastAwayInsertStridedSliceLeadingOneDim : public OpRewritePattern<vector::InsertStridedSliceOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp, PatternRewriter &rewriter) const override { @@ -145,7 +145,7 @@ struct CastAwayInsertStridedSliceLeadingOneDim // Casts away leading one dimensions in vector.insert's vector inputs by // inserting vector.broadcast. struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::InsertOp insertOp, PatternRewriter &rewriter) const override { @@ -221,7 +221,7 @@ static Value dropUnitDimsFromMask(OpBuilder &b, Location loc, Value mask, // 1 dimensions. struct CastAwayTransferReadLeadingOneDim : public OpRewritePattern<vector::TransferReadOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransferReadOp read, PatternRewriter &rewriter) const override { @@ -275,7 +275,7 @@ struct CastAwayTransferReadLeadingOneDim // 1 dimensions. struct CastAwayTransferWriteLeadingOneDim : public OpRewritePattern<vector::TransferWriteOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransferWriteOp write, PatternRewriter &rewriter) const override { @@ -541,7 +541,7 @@ public: // vector.broadcast back to the original shape. 
struct CastAwayConstantMaskLeadingOneDim : public OpRewritePattern<vector::ConstantMaskOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ConstantMaskOp mask, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp index bdbb792..7acc120 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp @@ -48,7 +48,7 @@ namespace { /// struct VectorMaskedLoadOpConverter final : OpRewritePattern<vector::MaskedLoadOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedLoadOp, PatternRewriter &rewriter) const override { @@ -117,7 +117,7 @@ struct VectorMaskedLoadOpConverter final /// struct VectorMaskedStoreOpConverter final : OpRewritePattern<vector::MaskedStoreOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MaskedStoreOp maskedStoreOp, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index 264cbc1..3a6684f 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -548,7 +548,7 @@ namespace { // NOTE: By default, all RMW sequences are atomic. Set `disableAtomicRMW` to // `false` to generate non-atomic RMW sequences. struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; ConvertVectorStore(MLIRContext *context, bool disableAtomicRMW) : OpConversionPattern<vector::StoreOp>(context), @@ -827,7 +827,7 @@ private: /// adjusted mask . 
struct ConvertVectorMaskedStore final : OpConversionPattern<vector::MaskedStoreOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor, @@ -950,7 +950,7 @@ struct ConvertVectorMaskedStore final /// those cases, loads are converted to byte-aligned, byte-sized loads and the /// target vector is extracted from the loaded vector. struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LogicalResult matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor, @@ -1059,7 +1059,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> { /// bitcasting, since each `i8` container element holds two `i4` values. struct ConvertVectorMaskedLoad final : OpConversionPattern<vector::MaskedLoadOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor, @@ -1257,7 +1257,7 @@ static bool fitsInMultiByteContainerTy(VectorType subByteVecTy, // TODO: Document-me struct ConvertVectorTransferRead final : OpConversionPattern<vector::TransferReadOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor, @@ -1942,7 +1942,7 @@ namespace { /// advantage of high-level information to avoid leaving LLVM to scramble with /// peephole optimizations. 
struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp, PatternRewriter &rewriter) const override { @@ -2147,7 +2147,7 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> { /// %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4> /// struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> { - using OpRewritePattern<arith::TruncIOp>::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(arith::TruncIOp truncOp, PatternRewriter &rewriter) const override { @@ -2200,7 +2200,7 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> { /// %2 = arith.trunci %1 : vector<16x8xi8> to vector<16x8xi4> /// struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> { - using OpRewritePattern<vector::TransposeOp>::OpRewritePattern; + using Base::Base; RewriteVectorTranspose(MLIRContext *context, PatternBenefit benefit) : OpRewritePattern<vector::TransposeOp>(context, benefit) {} diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp index f6d6555..9e49873 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp @@ -34,7 +34,7 @@ using namespace mlir::vector; class DecomposeDifferentRankInsertStridedSlice : public OpRewritePattern<InsertStridedSliceOp> { public: - using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(InsertStridedSliceOp op, PatternRewriter &rewriter) const override { @@ -84,7 +84,7 @@ public: class ConvertSameRankInsertStridedSliceIntoShuffle : public OpRewritePattern<InsertStridedSliceOp> { public: - using 
OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern; + using Base::Base; void initialize() { // This pattern creates recursive InsertStridedSliceOp, but the recursion is @@ -183,7 +183,7 @@ public: class Convert1DExtractStridedSliceIntoShuffle : public OpRewritePattern<ExtractStridedSliceOp> { public: - using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override { @@ -271,7 +271,7 @@ private: class DecomposeNDExtractStridedSlice : public OpRewritePattern<ExtractStridedSliceOp> { public: - using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern; + using Base::Base; void initialize() { // This pattern creates recursive ExtractStridedSliceOp, but the recursion diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 82bac8c..71fba71c 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -214,7 +214,7 @@ SmallVector<int64_t> static getStridedSliceInsertionIndices( /// vector.extract_strided_slice operation. 
struct LinearizeVectorExtractStridedSlice final : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorExtractStridedSlice(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) @@ -285,7 +285,7 @@ struct LinearizeVectorExtractStridedSlice final /// struct LinearizeVectorInsertStridedSlice final : public mlir::OpConversionPattern<mlir::vector::InsertStridedSliceOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorInsertStridedSlice(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) @@ -348,7 +348,7 @@ struct LinearizeVectorInsertStridedSlice final /// of the original shuffle operation. struct LinearizeVectorShuffle final : public OpConversionPattern<vector::ShuffleOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorShuffle(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} @@ -423,7 +423,7 @@ struct LinearizeVectorShuffle final /// struct LinearizeVectorExtract final : public OpConversionPattern<vector::ExtractOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorExtract(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} @@ -501,7 +501,7 @@ struct LinearizeVectorExtract final /// struct LinearizeVectorInsert final : public OpConversionPattern<vector::InsertOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorInsert(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} @@ -575,7 +575,7 @@ struct LinearizeVectorInsert final /// %out_nd = vector.shape_cast %out_1d: vector<16xf16> 
to vector<4x4xf16> struct LinearizeVectorBitCast final : public OpConversionPattern<vector::BitCastOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorBitCast(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} @@ -598,7 +598,7 @@ struct LinearizeVectorBitCast final /// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32> struct LinearizeVectorSplat final : public OpConversionPattern<vector::SplatOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorSplat(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) @@ -629,7 +629,7 @@ struct LinearizeVectorSplat final /// %shape_cast = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1> struct LinearizeVectorCreateMask final : OpConversionPattern<vector::CreateMaskOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorCreateMask(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) @@ -684,7 +684,7 @@ struct LinearizeVectorCreateMask final /// For generic cases, the vector unroll pass should be used to unroll the load /// to vector<1x1x...xN> form and then linearized struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} @@ -731,7 +731,7 @@ struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> { /// to vector<1x1x...xN> form and then linearized struct LinearizeVectorStore final : public OpConversionPattern<vector::StoreOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorStore(const TypeConverter &typeConverter, 
MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} @@ -778,7 +778,7 @@ struct LinearizeVectorStore final /// struct LinearizeVectorFromElements final : public OpConversionPattern<vector::FromElementsOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorFromElements(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} @@ -814,7 +814,7 @@ struct LinearizeVectorFromElements final /// struct LinearizeVectorToElements final : public OpConversionPattern<vector::ToElementsOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorToElements(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index c364a8b..1121d95 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -1081,7 +1081,7 @@ private: /// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>) /// to memref.store. 
class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 866f789..d6a6d7cd 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -78,7 +78,7 @@ namespace { /// ``` struct MultiReduceToContract : public OpRewritePattern<vector::MultiDimReductionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp, PatternRewriter &rewriter) const override { @@ -138,7 +138,7 @@ struct MultiReduceToContract /// ``` struct CombineContractABTranspose final : public OpRewritePattern<vector::ContractionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ContractionOp contractOp, PatternRewriter &rewriter) const override { @@ -202,7 +202,7 @@ struct CombineContractABTranspose final /// ``` struct CombineContractResultTranspose final : public OpRewritePattern<vector::TransposeOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransposeOp resTOp, PatternRewriter &rewriter) const override { @@ -568,7 +568,7 @@ static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) { // %2 = vector.extract %1[1] : f16 from vector<2xf16> struct BubbleDownVectorBitCastForExtract : public OpRewritePattern<vector::ExtractOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ExtractOp extractOp, PatternRewriter &rewriter) const override { @@ -643,7 +643,7 @@ struct BubbleDownVectorBitCastForExtract // %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16> 
struct BubbleDownBitCastForStridedSliceExtract : public OpRewritePattern<vector::ExtractStridedSliceOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, PatternRewriter &rewriter) const override { @@ -721,7 +721,7 @@ struct BubbleDownBitCastForStridedSliceExtract // %2 = vector.insert %0, %1 [4] : vector<16xi8> into vector<8x16xi8> // struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp, PatternRewriter &rewriter) const override { @@ -794,7 +794,7 @@ struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> { // offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32> struct BubbleUpBitCastForStridedSliceInsert : public OpRewritePattern<vector::BitCastOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp, PatternRewriter &rewriter) const override { @@ -892,7 +892,7 @@ struct BubbleUpBitCastForStridedSliceInsert // %7 = vector.insert_strided_slice %6, %cst { // offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32> struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; public: BreakDownVectorBitCast(MLIRContext *context, @@ -1131,7 +1131,7 @@ struct ReorderElementwiseOpsOnBroadcast final class ExtractOpFromElementwise final : public OpRewritePattern<vector::ExtractOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ExtractOp op, PatternRewriter &rewriter) const override { @@ -1206,7 +1206,7 @@ static bool isSupportedMemSinkElementType(Type type) { /// ``` class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> { public: - using 
OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ExtractOp op, PatternRewriter &rewriter) const override { @@ -1285,7 +1285,7 @@ public: class StoreOpFromSplatOrBroadcast final : public OpRewritePattern<vector::StoreOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::StoreOp op, PatternRewriter &rewriter) const override { @@ -1476,7 +1476,7 @@ static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) { /// InstCombine seems to handle vectors with multiple elements but not the /// single element ones. struct FoldI1Select : public OpRewritePattern<arith::SelectOp> { - using OpRewritePattern<arith::SelectOp>::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(arith::SelectOp selectOp, PatternRewriter &rewriter) const override { @@ -1560,7 +1560,7 @@ getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) { /// Drop inner most contiguous unit dimensions from transfer_read operand. class DropInnerMostUnitDimsTransferRead : public OpRewritePattern<vector::TransferReadOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransferReadOp readOp, PatternRewriter &rewriter) const override { @@ -1651,7 +1651,7 @@ class DropInnerMostUnitDimsTransferRead /// Note, this pattern will not collapse "scalable unit" dims (i.e. `[1]`). class DropInnerMostUnitDimsTransferWrite : public OpRewritePattern<vector::TransferWriteOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, PatternRewriter &rewriter) const override { @@ -1728,7 +1728,7 @@ class DropInnerMostUnitDimsTransferWrite /// with the RHS transposed) lowering. 
struct CanonicalizeContractMatmulToMMT final : OpRewritePattern<vector::ContractionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; using FilterConstraintType = std::function<LogicalResult(vector::ContractionOp op)>; @@ -1845,7 +1845,7 @@ private: template <typename ExtOp> struct FoldArithExtIntoContractionOp : public OpRewritePattern<vector::ContractionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ContractionOp contractOp, PatternRewriter &rewriter) const override { @@ -1878,7 +1878,7 @@ struct FoldArithExtIntoContractionOp /// %b = vector.reduction <add> %a, %acc /// ``` struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ReductionOp op, PatternRewriter &rewriter) const override { @@ -2033,7 +2033,7 @@ struct DropUnitDimFromElementwiseOps final /// ``` struct DropUnitDimsFromTransposeOp final : OpRewritePattern<vector::TransposeOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransposeOp op, PatternRewriter &rewriter) const override { @@ -2110,7 +2110,7 @@ struct DropUnitDimsFromTransposeOp final /// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32> /// ``` struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(scf::ForOp forOp, PatternRewriter &rewriter) const override { @@ -2155,7 +2155,7 @@ struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> { /// %c = vector.reduction <add> %b, %acc /// ``` struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ReductionOp op, PatternRewriter &rewriter) const override { diff --git 
a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 9413a92..784e5d6 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -824,7 +824,7 @@ struct WgToSgStoreScatterOpWithOffset return failure(); xegpu::DistributeLayoutAttr layout = - xegpu::getDistributeLayoutAttr(op.getValue()); + xegpu::getDistributeLayoutAttr(op.getOperand(0)); if (!layout || !layout.isForWorkgroup()) return failure(); @@ -844,12 +844,19 @@ struct WgToSgStoreScatterOpWithOffset auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize); for (auto [val, offs, mask] : llvm::zip( adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) { - xegpu::StoreScatterOp::create(rewriter, loc, val, op.getDest(), offs, - mask, chunkSizeAttr, op.getL1HintAttr(), - op.getL2HintAttr(), op.getL3HintAttr()); + auto store = xegpu::StoreScatterOp::create( + rewriter, loc, val, op.getDest(), offs, mask, chunkSizeAttr, + op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr()); // Update the layout attribute to drop sg_layout and sg_data. - if (auto newLayout = layout.dropSgLayoutAndData()) - op->setAttr("layout", newLayout); + if (!layout.getEffectiveLaneLayoutAsInt().empty() || + !layout.getEffectiveInstDataAsInt().empty()) { + for (OpOperand &operand : store->getOpOperands()) { + // Skip for operand one (memref) + if (operand.getOperandNumber() == 1) + continue; + xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData()); + } + } } rewriter.eraseOp(op); return success(); @@ -1247,10 +1254,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() { target.addDynamicallyLegalOp<xegpu::StoreScatterOp>( [=](xegpu::StoreScatterOp op) -> bool { - // Check if the layout attribute is present on the result. 
- auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout"); - if (!layout) - return true; + auto layout = xegpu::getDistributeLayoutAttr(op.getOperand(0)); return isLegal(layout); }); diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index c84e760..8f199b6 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -489,13 +489,6 @@ OpBuilder::tryFold(Operation *op, SmallVectorImpl<Value> &results, SmallVector<OpFoldResult, 4> foldResults; LDBG() << "Trying to fold: " << OpWithFlags(op, OpPrintingFlags().skipRegions()); - if (op->getName().getStringRef() == "vector.extract") { - Operation *parent = op->getParentOp(); - while (parent && parent->getName().getStringRef() != "spirv.func") - parent = parent->getParentOp(); - if (parent) - parent->dump(); - } if (failed(op->fold(foldResults))) return cleanupFailure(); diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp index af4ea5a..0f28cbc 100644 --- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp +++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp @@ -304,7 +304,7 @@ static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs, umin = lhsMin.udiv(rhsMax); // X u/ Y u<= X. 
- APInt umax = lhsMax; + const APInt &umax = lhsMax; return ConstantIntRanges::fromUnsigned(umin, umax); } diff --git a/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp b/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp index d6b8a8a..e3f075f 100644 --- a/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp +++ b/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp @@ -54,6 +54,7 @@ struct OpStrings { std::string opCppName; SmallVector<std::string> opResultNames; SmallVector<std::string> opOperandNames; + SmallVector<std::string> opRegionNames; }; static std::string joinNameList(llvm::ArrayRef<std::string> names) { @@ -87,8 +88,8 @@ static TypeStrings getStrings(irdl::TypeOp type) { /// Generates OpStrings from an OperationOp static OpStrings getStrings(irdl::OperationOp op) { auto operandOp = op.getOp<irdl::OperandsOp>(); - auto resultOp = op.getOp<irdl::ResultsOp>(); + auto regionsOp = op.getOp<irdl::RegionsOp>(); OpStrings strings; strings.opName = op.getSymName(); @@ -108,6 +109,13 @@ static OpStrings getStrings(irdl::OperationOp op) { })); } + if (regionsOp) { + strings.opRegionNames = SmallVector<std::string>( + llvm::map_range(regionsOp->getNames(), [](Attribute attr) { + return llvm::formatv("{0}", cast<StringAttr>(attr)); + })); + } + return strings; } @@ -122,6 +130,7 @@ static void fillDict(irdl::detail::dictionary &dict, static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) { const auto operandCount = strings.opOperandNames.size(); const auto resultCount = strings.opResultNames.size(); + const auto regionCount = strings.opRegionNames.size(); dict["OP_NAME"] = strings.opName; dict["OP_CPP_NAME"] = strings.opCppName; @@ -131,6 +140,7 @@ static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) { operandCount ? joinNameList(strings.opOperandNames) : "{\"\"}"; dict["OP_RESULT_INITIALIZER_LIST"] = resultCount ?
joinNameList(strings.opResultNames) : "{\"\"}"; + dict["OP_REGION_COUNT"] = std::to_string(regionCount); } /// Fills a dictionary with values from DialectStrings @@ -179,6 +189,8 @@ static void generateOpGetterDeclarations(irdl::detail::dictionary &dict, const OpStrings &opStrings) { auto opGetters = std::string{}; auto resGetters = std::string{}; + auto regionGetters = std::string{}; + auto regionAdaptorGetters = std::string{}; for (size_t i = 0, end = opStrings.opOperandNames.size(); i < end; ++i) { const auto op = llvm::convertToCamelFromSnakeCase(opStrings.opOperandNames[i], true); opGetters += llvm::formatv( @@ -196,8 +208,23 @@ static void generateOpGetterDeclarations(irdl::detail::dictionary &dict, op, i); } + for (size_t i = 0, end = opStrings.opRegionNames.size(); i < end; ++i) { + const auto op = + llvm::convertToCamelFromSnakeCase(opStrings.opRegionNames[i], true); + regionAdaptorGetters += llvm::formatv( + R"(::mlir::Region &get{0}() { return *getRegions()[{1}]; } + )", + op, i); + regionGetters += llvm::formatv( + R"(::mlir::Region &get{0}() { return (*this)->getRegion({1}); } + )", + op, i); + } + dict["OP_OPERAND_GETTER_DECLS"] = opGetters; dict["OP_RESULT_GETTER_DECLS"] = resGetters; + dict["OP_REGION_ADAPTER_GETTER_DECLS"] = regionAdaptorGetters; + dict["OP_REGION_GETTER_DECLS"] = regionGetters; } static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict, @@ -238,6 +265,22 @@ static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict, dict["OP_BUILD_DECLS"] = buildDecls; } +// Compute the C++ trait names implied by the op definition; empty if none apply. +static SmallVector<std::string> generateTraits(irdl::OperationOp op, + const OpStrings &strings) { + SmallVector<std::string> cppTraitNames; + if (!strings.opRegionNames.empty()) { + cppTraitNames.push_back( + llvm::formatv("::mlir::OpTrait::NRegions<{0}>::Impl", + strings.opRegionNames.size()) + .str()); + + // Requires verifyInvariantsImpl is implemented on the op + cppTraitNames.emplace_back("::mlir::OpTrait::OpInvariants"); + } + return cppTraitNames; +} + static LogicalResult
generateOperationInclude(irdl::OperationOp op, raw_ostream &output, irdl::detail::dictionary &dict) { @@ -247,6 +290,13 @@ static LogicalResult generateOperationInclude(irdl::OperationOp op, const auto opStrings = getStrings(op); fillDict(dict, opStrings); + SmallVector<std::string> traitNames = generateTraits(op, opStrings); + if (traitNames.empty()) + dict["OP_TEMPLATE_ARGS"] = opStrings.opCppName; + else + dict["OP_TEMPLATE_ARGS"] = llvm::formatv("{0}, {1}", opStrings.opCppName, + llvm::join(traitNames, ", ")); + generateOpGetterDeclarations(dict, opStrings); generateOpBuilderDeclarations(dict, opStrings); @@ -301,6 +351,110 @@ static LogicalResult generateInclude(irdl::DialectOp dialect, return success(); } +static void generateRegionConstraintVerifiers( + irdl::detail::dictionary &dict, irdl::OperationOp op, + const OpStrings &strings, SmallVectorImpl<std::string> &verifierHelpers, + SmallVectorImpl<std::string> &verifierCalls) { + auto regionsOp = op.getOp<irdl::RegionsOp>(); + if (strings.opRegionNames.empty() || !regionsOp) + return; + + for (size_t i = 0; i < strings.opRegionNames.size(); ++i) { + std::string regionName = strings.opRegionNames[i]; + std::string helperFnName = + llvm::formatv("__mlir_irdl_local_region_constraint_{0}_{1}", + strings.opCppName, regionName) + .str(); + + // Extract the actual region constraint from the IRDL RegionOp + std::string condition = "true"; + std::string textualConditionName = "any region"; + + if (auto regionDefOp = + dyn_cast<irdl::RegionOp>(regionsOp->getArgs()[i].getDefiningOp())) { + // Generate constraint condition based on RegionOp attributes + SmallVector<std::string> conditionParts; + SmallVector<std::string> descriptionParts; + + // Check number of blocks constraint + if (auto blockCount = regionDefOp.getNumberOfBlocks()) { + conditionParts.push_back( + llvm::formatv("region.getBlocks().size() == {0}", + blockCount.value()) + .str()); + descriptionParts.push_back( + llvm::formatv("exactly {0} block(s)", 
blockCount.value()).str()); + } + + // Check entry block arguments constraint + if (regionDefOp.getConstrainedArguments()) { + size_t expectedArgCount = regionDefOp.getEntryBlockArgs().size(); + conditionParts.push_back( + llvm::formatv("region.getNumArguments() == {0}", expectedArgCount) + .str()); + descriptionParts.push_back( + llvm::formatv("{0} entry block argument(s)", expectedArgCount) + .str()); + } + + // Combine conditions + if (!conditionParts.empty()) { + condition = llvm::join(conditionParts, " && "); + } + + // Generate descriptive error message + if (!descriptionParts.empty()) { + textualConditionName = + llvm::formatv("region with {0}", + llvm::join(descriptionParts, " and ")) + .str(); + } + } + + verifierHelpers.push_back(llvm::formatv( + R"(static ::llvm::LogicalResult {0}(::mlir::Operation *op, ::mlir::Region ®ion, ::llvm::StringRef regionName, unsigned regionIndex) {{ + if (!({1})) {{ + return op->emitOpError("region #") << regionIndex + << (regionName.empty() ? " " : " ('" + regionName + "') ") + << "failed to verify constraint: {2}"; + } + return ::mlir::success(); +})", + helperFnName, condition, textualConditionName)); + + verifierCalls.push_back(llvm::formatv(R"( + if (::mlir::failed({0}(*this, (*this)->getRegion({1}), "{2}", {1}))) + return ::mlir::failure();)", + helperFnName, i, regionName) + .str()); + } +} + +static void generateVerifiers(irdl::detail::dictionary &dict, + irdl::OperationOp op, const OpStrings &strings) { + SmallVector<std::string> verifierHelpers; + SmallVector<std::string> verifierCalls; + + generateRegionConstraintVerifiers(dict, op, strings, verifierHelpers, + verifierCalls); + + // Add an overall verifier that sequences the helper calls + std::string verifierDef = + llvm::formatv(R"( +::llvm::LogicalResult {0}::verifyInvariantsImpl() {{ + if(::mlir::failed(verify())) + return ::mlir::failure(); + + {1} + + return ::mlir::success(); +})", + strings.opCppName, llvm::join(verifierCalls, "\n")); + + 
dict["OP_VERIFIER_HELPERS"] = llvm::join(verifierHelpers, "\n"); + dict["OP_VERIFIER"] = verifierDef; +} + static std::string generateOpDefinition(irdl::detail::dictionary &dict, irdl::OperationOp op) { static const auto perOpDefTemplate = mlir::irdl::detail::Template{ @@ -370,6 +524,8 @@ void {0}::build(::mlir::OpBuilder &opBuilder, ::mlir::OperationState &opState, { dict["OP_BUILD_DEFS"] = buildDefinition; + generateVerifiers(dict, op, opStrings); + std::string str; llvm::raw_string_ostream stream{str}; perOpDefTemplate.render(stream, dict); @@ -427,7 +583,7 @@ static LogicalResult generateLib(irdl::DialectOp dialect, raw_ostream &output, dict["TYPE_PARSER"] = llvm::formatv( R"(static ::mlir::OptionalParseResult generatedTypeParser(::mlir::AsmParser &parser, ::llvm::StringRef *mnemonic, ::mlir::Type &value) { return ::mlir::AsmParser::KeywordSwitch<::mlir::OptionalParseResult>(parser) - {0} + {0} .Default([&](llvm::StringRef keyword, llvm::SMLoc) {{ *mnemonic = keyword; return std::nullopt; @@ -520,6 +676,8 @@ static LogicalResult verifySupported(irdl::DialectOp dialect) { "IRDL C++ translation does not yet support variadic results"); })) .Case<irdl::AnyOp>(([](irdl::AnyOp) { return success(); })) + .Case<irdl::RegionOp>(([](irdl::RegionOp) { return success(); })) + .Case<irdl::RegionsOp>(([](irdl::RegionsOp) { return success(); })) .Default([](mlir::Operation *op) -> LogicalResult { return op->emitError("IRDL C++ translation does not yet support " "translation of ") diff --git a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt index e9068e9..93ce0be 100644 --- a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt +++ b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt @@ -12,15 +12,15 @@ public: struct Properties { }; public: - __OP_CPP_NAME__GenericAdaptorBase(::mlir::Operation *op) - : odsAttrs(op->getRawDictionaryAttrs()), odsOpName(op->getName()), - odsRegions(op->getRegions()) + 
__OP_CPP_NAME__GenericAdaptorBase(::mlir::Operation *op) + : odsAttrs(op->getRawDictionaryAttrs()), odsOpName(op->getName()), + odsRegions(op->getRegions()) {} /// Return the unstructured operand index of a structured operand along with // the amount of unstructured operands it contains. std::pair<unsigned, unsigned> - getStructuredOperandIndexAndLength (unsigned index, + getStructuredOperandIndexAndLength (unsigned index, unsigned odsOperandsSize) { return {index, 1}; } @@ -32,6 +32,12 @@ public: ::mlir::DictionaryAttr getAttributes() { return odsAttrs; } + + __OP_REGION_ADAPTER_GETTER_DECLS__ + + ::mlir::RegionRange getRegions() { + return odsRegions; + } protected: ::mlir::DictionaryAttr odsAttrs; ::std::optional<::mlir::OperationName> odsOpName; @@ -42,28 +48,28 @@ protected: } // namespace detail template <typename RangeT> -class __OP_CPP_NAME__GenericAdaptor +class __OP_CPP_NAME__GenericAdaptor : public detail::__OP_CPP_NAME__GenericAdaptorBase { using ValueT = ::llvm::detail::ValueOfRange<RangeT>; using Base = detail::__OP_CPP_NAME__GenericAdaptorBase; public: __OP_CPP_NAME__GenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, - ::mlir::OpaqueProperties properties, - ::mlir::RegionRange regions = {}) - : __OP_CPP_NAME__GenericAdaptor(values, attrs, - (properties ? *properties.as<::mlir::EmptyProperties *>() + ::mlir::OpaqueProperties properties, + ::mlir::RegionRange regions = {}) + : __OP_CPP_NAME__GenericAdaptor(values, attrs, + (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} - __OP_CPP_NAME__GenericAdaptor(RangeT values, + __OP_CPP_NAME__GenericAdaptor(RangeT values, const __OP_CPP_NAME__GenericAdaptorBase &base) : Base(base), odsOperands(values) {} - // This template parameter allows using __OP_CPP_NAME__ which is declared + // This template parameter allows using __OP_CPP_NAME__ which is declared // later. 
template <typename LateInst = __OP_CPP_NAME__, typename = std::enable_if_t< std::is_same_v<LateInst, __OP_CPP_NAME__>>> - __OP_CPP_NAME__GenericAdaptor(RangeT values, LateInst op) + __OP_CPP_NAME__GenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} /// Return the unstructured operand index of a structured operand along with @@ -77,7 +83,7 @@ public: RangeT getStructuredOperands(unsigned index) { auto valueRange = getStructuredOperandIndexAndLength(index); return {std::next(odsOperands.begin(), valueRange.first), - std::next(odsOperands.begin(), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; } @@ -91,7 +97,7 @@ private: RangeT odsOperands; }; -class __OP_CPP_NAME__Adaptor +class __OP_CPP_NAME__Adaptor : public __OP_CPP_NAME__GenericAdaptor<::mlir::ValueRange> { public: using __OP_CPP_NAME__GenericAdaptor::__OP_CPP_NAME__GenericAdaptor; @@ -100,7 +106,7 @@ public: ::llvm::LogicalResult verify(::mlir::Location loc); }; -class __OP_CPP_NAME__ : public ::mlir::Op<__OP_CPP_NAME__> { +class __OP_CPP_NAME__ : public ::mlir::Op<__OP_TEMPLATE_ARGS__> { public: using Op::Op; using Op::print; @@ -112,6 +118,8 @@ public: return {}; } + ::llvm::LogicalResult verifyInvariantsImpl(); + static constexpr ::llvm::StringLiteral getOperationName() { return ::llvm::StringLiteral("__DIALECT_NAME__.__OP_NAME__"); } @@ -147,7 +155,7 @@ public: ::mlir::Operation::operand_range getStructuredOperands(unsigned index) { auto valueRange = getStructuredOperandIndexAndLength(index); return {std::next(getOperation()->operand_begin(), valueRange.first), - std::next(getOperation()->operand_begin(), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; } @@ -162,18 +170,19 @@ public: ::mlir::Operation::result_range getStructuredResults(unsigned index) { auto valueRange = getStructuredResultIndexAndLength(index); return {std::next(getOperation()->result_begin(), valueRange.first), - 
std::next(getOperation()->result_begin(), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; } __OP_OPERAND_GETTER_DECLS__ __OP_RESULT_GETTER_DECLS__ - + __OP_REGION_GETTER_DECLS__ + __OP_BUILD_DECLS__ - static void build(::mlir::OpBuilder &odsBuilder, - ::mlir::OperationState &odsState, - ::mlir::TypeRange resultTypes, - ::mlir::ValueRange operands, + static void build(::mlir::OpBuilder &odsBuilder, + ::mlir::OperationState &odsState, + ::mlir::TypeRange resultTypes, + ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); static __OP_CPP_NAME__ create(::mlir::OpBuilder &odsBuilder, diff --git a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt index 30ca420..f4a1b7a 100644 --- a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt +++ b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt @@ -6,12 +6,14 @@ R"( __NAMESPACE_OPEN__ +__OP_VERIFIER_HELPERS__ + __OP_BUILD_DEFS__ -void __OP_CPP_NAME__::build(::mlir::OpBuilder &odsBuilder, - ::mlir::OperationState &odsState, - ::mlir::TypeRange resultTypes, - ::mlir::ValueRange operands, +void __OP_CPP_NAME__::build(::mlir::OpBuilder &odsBuilder, + ::mlir::OperationState &odsState, + ::mlir::TypeRange resultTypes, + ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { assert(operands.size() == __OP_OPERAND_COUNT__); @@ -19,6 +21,9 @@ void __OP_CPP_NAME__::build(::mlir::OpBuilder &odsBuilder, odsState.addOperands(operands); odsState.addAttributes(attributes); odsState.addTypes(resultTypes); + for (unsigned i = 0; i != __OP_REGION_COUNT__; ++i) { + (void)odsState.addRegion(); + } } __OP_CPP_NAME__ @@ -44,6 +49,7 @@ __OP_CPP_NAME__::create(::mlir::ImplicitLocOpBuilder &odsBuilder, return create(odsBuilder, odsBuilder.getLoc(), resultTypes, operands, attributes); } +__OP_VERIFIER__ __NAMESPACE_CLOSE__ diff --git 
a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 53209a4..9fcb02e 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -3175,6 +3175,45 @@ applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder, return success(); } +/// Apply a `#pragma omp tile` / `!$omp tile` transformation using the +/// OpenMPIRBuilder. +static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); + llvm::OpenMPIRBuilder::LocationDescription loc(builder); + + SmallVector<llvm::CanonicalLoopInfo *> translatedLoops; + SmallVector<llvm::Value *> translatedSizes; + + for (Value size : op.getSizes()) { + llvm::Value *translatedSize = moduleTranslation.lookupValue(size); + assert(translatedSize && + "sizes clause arguments must already be translated"); + translatedSizes.push_back(translatedSize); + } + + for (Value applyee : op.getApplyees()) { + llvm::CanonicalLoopInfo *consBuilderCLI = + moduleTranslation.lookupOMPLoop(applyee); + assert(consBuilderCLI && "Canonical loop must have already been translated"); + translatedLoops.push_back(consBuilderCLI); + } + + auto generatedLoops = + ompBuilder->tileLoops(loc.DL, translatedLoops, translatedSizes); + if (!op.getGeneratees().empty()) { + for (auto [mlirLoop, genLoop] : + zip_equal(op.getGeneratees(), generatedLoops)) + moduleTranslation.mapOmpLoop(mlirLoop, genLoop); + } + + // CLIs can only be consumed once + for (Value applyee : op.getApplyees()) + moduleTranslation.invalidateOmpLoop(applyee); + + return success(); +} + /// Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
static llvm::AtomicOrdering convertAtomicOrdering(std::optional<omp::ClauseMemoryOrderKind> ao) { @@ -6227,6 +6266,9 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder, // the omp.canonical_loop. return applyUnrollHeuristic(op, builder, moduleTranslation); }) + .Case([&](omp::TileOp op) { + return applyTile(op, builder, moduleTranslation); + }) .Case([&](omp::TargetAllocMemOp) { return convertTargetAllocMemOp(*op, builder, moduleTranslation); }) diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index bf0136b..3a23bbf 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1856,6 +1856,44 @@ void ConversionPatternRewriterImpl::replaceOp( Operation *op, SmallVector<SmallVector<Value>> &&newValues) { assert(newValues.size() == op->getNumResults() && "incorrect number of replacement values"); + LLVM_DEBUG({ + logger.startLine() << "** Replace : '" << op->getName() << "'(" << op + << ")\n"; + if (currentTypeConverter) { + // If the user-provided replacement types are different from the + // legalized types, as per the current type converter, print a note. + // In most cases, the replacement types are expected to match the types + // produced by the type converter, so this could indicate a bug in the + // user code. 
+ for (auto [result, repls] : + llvm::zip_equal(op->getResults(), newValues)) { + Type resultType = result.getType(); + auto logProlog = [&, repls = repls]() { + logger.startLine() << " Note: Replacing op result of type " + << resultType << " with value(s) of type ("; + llvm::interleaveComma(repls, logger.getOStream(), [&](Value v) { + logger.getOStream() << v.getType(); + }); + logger.getOStream() << ")"; + }; + SmallVector<Type> convertedTypes; + if (failed(currentTypeConverter->convertTypes(resultType, + convertedTypes))) { + logProlog(); + logger.getOStream() << ", but the type converter failed to legalize " + "the original type.\n"; + continue; + } + if (TypeRange(convertedTypes) != TypeRange(ValueRange(repls))) { + logProlog(); + logger.getOStream() << ", but the legalized type(s) is/are ("; + llvm::interleaveComma(convertedTypes, logger.getOStream(), + [&](Type t) { logger.getOStream() << t; }); + logger.getOStream() << ")\n"; + } + } + } + }); if (!config.allowPatternRollback) { // Pattern rollback is not allowed: materialize all IR changes immediately. @@ -2072,10 +2110,6 @@ void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) { void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) { assert(op->getNumResults() == newValues.size() && "incorrect # of replacement values"); - LLVM_DEBUG({ - impl->logger.startLine() - << "** Replace : '" << op->getName() << "'(" << op << ")\n"; - }); // If the current insertion point is before the erased operation, we adjust // the insertion point to be after the operation. 
@@ -2093,10 +2127,6 @@ void ConversionPatternRewriter::replaceOpWithMultiple( Operation *op, SmallVector<SmallVector<Value>> &&newValues) { assert(op->getNumResults() == newValues.size() && "incorrect # of replacement values"); - LLVM_DEBUG({ - impl->logger.startLine() - << "** Replace : '" << op->getName() << "'(" << op << ")\n"; - }); // If the current insertion point is before the erased operation, we adjust // the insertion point to be after the operation. diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py index bf40cc5..e3bacb5 100644 --- a/mlir/python/mlir/dialects/transform/structured.py +++ b/mlir/python/mlir/dialects/transform/structured.py @@ -44,18 +44,12 @@ class BufferizeToAllocationOp(BufferizeToAllocationOp): loc=None, ip=None, ): - # No other types are allowed, so hard-code those here. - allocated_buffer_type = transform.AnyValueType.get() - new_ops_type = transform.AnyOpType.get() - if isinstance(memory_space, int): memory_space = str(memory_space) if isinstance(memory_space, str): memory_space = Attribute.parse(memory_space) super().__init__( - allocated_buffer_type, - new_ops_type, target, memory_space=memory_space, memcpy_op=memcpy_op, diff --git a/mlir/python/mlir/dialects/transform/tune.py b/mlir/python/mlir/dialects/transform/tune.py index f63f88a..b3bfa80 100644 --- a/mlir/python/mlir/dialects/transform/tune.py +++ b/mlir/python/mlir/dialects/transform/tune.py @@ -6,6 +6,9 @@ from typing import Optional, Sequence from ...ir import ( Type, + Value, + Operation, + OpView, Attribute, ArrayAttr, StringAttr, @@ -19,7 +22,10 @@ from .._transform_tune_extension_ops_gen import * from .._transform_tune_extension_ops_gen import _Dialect try: - from .._ods_common import _cext as _ods_cext + from .._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + _cext as _ods_cext, + ) except ImportError as e: raise RuntimeError("Error loading imports from extension module") from 
e @@ -36,7 +42,7 @@ class KnobOp(KnobOp): ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute ], *, - selected: Optional[Attribute] = None, + selected: Optional[Union[Attribute, bool, int, float, str]] = None, loc=None, ip=None, ): @@ -75,8 +81,62 @@ def knob( ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute ], *, - selected: Optional[Attribute] = None, + selected: Optional[Union[Attribute, bool, int, float, str]] = None, loc=None, ip=None, ): return KnobOp(result, name, options, selected=selected, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class AlternativesOp(AlternativesOp): + def __init__( + self, + results: Sequence[Type], + name: Union[StringAttr, str], + num_alternatives: int, + *, + selected_region: Optional[ + Union[int, IntegerAttr, Value, Operation, OpView] + ] = None, + loc=None, + ip=None, + ): + if isinstance(name, str): + name = StringAttr.get(name) + + selected_region_attr = selected_region_param = None + if isinstance(selected_region, IntegerAttr): + selected_region_attr = selected_region + elif isinstance(selected_region, int): + selected_region_attr = IntegerAttr.get( + IntegerType.get_signless(32), selected_region + ) + elif isinstance(selected_region, (Value, Operation, OpView)): + selected_region_param = _get_op_result_or_value(selected_region) + + super().__init__( + results, + name, + num_alternatives, + selected_region_attr=selected_region_attr, + selected_region_param=selected_region_param, + loc=loc, + ip=ip, + ) + for region in self.regions: + region.blocks.append() + + +def alternatives( + results: Sequence[Type], + name: Union[StringAttr, str], + num_alternatives: int, + *, + selected_region: Optional[Union[int, IntegerAttr, Value, Operation, OpView]] = None, + loc=None, + ip=None, +): + return AlternativesOp( + results, name, num_alternatives, selected_region=selected_region, loc=loc, ip=ip + ) diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir 
b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir index 45b1a1f..0cbe064 100644 --- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir @@ -195,6 +195,36 @@ func.func @assume_alignment(%0 : memref<4x4xf16>) { // ----- +// ALL-LABEL: func @distinct_objects +// ALL-SAME: (%[[ARG0:.*]]: memref<?xf16>, %[[ARG1:.*]]: memref<?xf32>, %[[ARG2:.*]]: memref<?xf64>) +func.func @distinct_objects(%arg0: memref<?xf16>, %arg1: memref<?xf32>, %arg2: memref<?xf64>) -> (memref<?xf16>, memref<?xf32>, memref<?xf64>) { +// ALL-DAG: %[[CAST_0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?xf16> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// ALL-DAG: %[[CAST_1:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : memref<?xf32> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// ALL-DAG: %[[CAST_2:.*]] = builtin.unrealized_conversion_cast %[[ARG2]] : memref<?xf64> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// ALL: %[[PTR_0:.*]] = llvm.extractvalue %[[CAST_0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// ALL: %[[PTR_1:.*]] = llvm.extractvalue %[[CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// ALL: %[[PTR_2:.*]] = llvm.extractvalue %[[CAST_2]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// ALL: %[[TRUE:.*]] = llvm.mlir.constant(true) : i1 +// ALL: llvm.intr.assume %[[TRUE]] ["separate_storage"(%[[PTR_0]], %[[PTR_1]] : !llvm.ptr, !llvm.ptr)] : i1 +// ALL: llvm.intr.assume %[[TRUE]] ["separate_storage"(%[[PTR_0]], %[[PTR_2]] : !llvm.ptr, !llvm.ptr)] : i1 +// ALL: llvm.intr.assume %[[TRUE]] ["separate_storage"(%[[PTR_1]], %[[PTR_2]] : !llvm.ptr, !llvm.ptr)] : i1 + %1, %2, %3 = memref.distinct_objects %arg0, %arg1, %arg2 : memref<?xf16>, memref<?xf32>, memref<?xf64> + return %1, %2, %3 : memref<?xf16>, memref<?xf32>, memref<?xf64> +} + +// ----- + +// ALL-LABEL: func 
@distinct_objects_noop +// ALL-SAME: (%[[ARG0:.*]]: memref<?xf16>) +func.func @distinct_objects_noop(%arg0: memref<?xf16>) -> memref<?xf16> { +// 1-operand version is noop +// ALL-NEXT: return %[[ARG0]] + %1 = memref.distinct_objects %arg0 : memref<?xf16> + return %1 : memref<?xf16> +} + +// ----- + // CHECK-LABEL: func @assume_alignment_w_offset // CHECK-INTERFACE-LABEL: func @assume_alignment_w_offset func.func @assume_alignment_w_offset(%0 : memref<4x4xf16, strided<[?, ?], offset: ?>>) { diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index ca3de3a..2fe0995 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -2216,6 +2216,18 @@ func.func @test_mulf1(%arg0 : f32, %arg1 : f32) -> (f32) { return %2 : f32 } +// CHECK-LABEL: @test_mulf2( +func.func @test_mulf2(%arg0 : f32) -> (f32, f32) { + // CHECK-DAG: %[[C0:.+]] = arith.constant 0.000000e+00 : f32 + // CHECK-DAG: %[[C0n:.+]] = arith.constant -0.000000e+00 : f32 + // CHECK-NEXT: return %[[C0]], %[[C0n]] + %c0 = arith.constant 0.0 : f32 + %c0n = arith.constant -0.0 : f32 + %0 = arith.mulf %c0, %arg0 fastmath<nnan,nsz> : f32 + %1 = arith.mulf %c0n, %arg0 fastmath<nnan,nsz> : f32 + return %0, %1 : f32, f32 +} + // ----- // CHECK-LABEL: @test_divf( diff --git a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir index 99790cc..fcd004a 100644 --- a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir +++ b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir @@ -85,3 +85,14 @@ func.func @no_expansion(%x: f32) -> f32 { %y = arith.addf %x, %c : f32 func.return %y : f32 } + +// ----- + +func.func @no_promote_select(%c: i1, %x: bf16, %y: bf16) -> bf16 { +// CHECK-LABEL: @no_promote_select +// CHECK-SAME: (%[[C:.+]]: i1, %[[X:.+]]: bf16, %[[Y:.+]]: bf16) +// CHECK: %[[Z:.+]] = arith.select %[[C]], %[[X]], %[[Y]] : bf16 +// CHECK: return %[[Z]] + %z = 
arith.select %c, %x, %y : bf16 + func.return %z : bf16 +} diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir index 0bad151..6134695 100644 --- a/mlir/test/Dialect/LLVMIR/rocdl.mlir +++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir @@ -1068,6 +1068,38 @@ llvm.func @rocdl.cvt.scale.pk8(%i32: i32, %v2xi32: vector<2xi32>, %scale: i32) { // ----- +// CHECK-LABEL: rocdl.cvt.scalef32.pk8 +llvm.func @rocdl.cvt.scalef32.pk8(%v8xf32: vector<8xf32>, + %v8xf16: vector<8xf16>, + %v8xbf16: vector<8xbf16>, + %scale: f32) { + + // CHECK: rocdl.cvt.scalef32.pk8.fp8.f32 + %0 = rocdl.cvt.scalef32.pk8.fp8.f32 %v8xf32, %scale : vector<2xi32> + // CHECK: rocdl.cvt.scalef32.pk8.bf8.f32 + %1 = rocdl.cvt.scalef32.pk8.bf8.f32 %v8xf32, %scale : vector<2xi32> + // CHECK: rocdl.cvt.scalef32.pk8.fp4.f32 + %2 = rocdl.cvt.scalef32.pk8.fp4.f32 %v8xf32, %scale : i32 + + // CHECK: rocdl.cvt.scalef32.pk8.fp8.f16 + %3 = rocdl.cvt.scalef32.pk8.fp8.f16 %v8xf16, %scale : vector<2xi32> + // CHECK: rocdl.cvt.scalef32.pk8.bf8.f16 + %4 = rocdl.cvt.scalef32.pk8.bf8.f16 %v8xf16, %scale : vector<2xi32> + // CHECK: rocdl.cvt.scalef32.pk8.fp4.f16 + %5 = rocdl.cvt.scalef32.pk8.fp4.f16 %v8xf16, %scale : i32 + + // CHECK: rocdl.cvt.scalef32.pk8.fp8.bf16 + %6 = rocdl.cvt.scalef32.pk8.fp8.bf16 %v8xbf16, %scale : vector<2xi32> + // CHECK: rocdl.cvt.scalef32.pk8.bf8.bf16 + %7 = rocdl.cvt.scalef32.pk8.bf8.bf16 %v8xbf16, %scale : vector<2xi32> + // CHECK: rocdl.cvt.scalef32.pk8.fp4.bf16 + %8 = rocdl.cvt.scalef32.pk8.fp4.bf16 %v8xbf16, %scale : i32 + + llvm.return +} + +// ----- + // CHECK-LABEL: rocdl.cvt.scale.pk16 llvm.func @rocdl.cvt.scale.pk16(%v3xi32: vector<3xi32>, %scale:i32) { diff --git a/mlir/test/Dialect/Math/sincos-fusion.mlir b/mlir/test/Dialect/Math/sincos-fusion.mlir new file mode 100644 index 0000000..29fb9f1 --- /dev/null +++ b/mlir/test/Dialect/Math/sincos-fusion.mlir @@ -0,0 +1,86 @@ +// RUN: mlir-opt -math-sincos-fusion %s | FileCheck %s + +// CHECK-LABEL: func.func 
@sincos_fusion( +// CHECK-SAME: %[[ARG0:.*]]: f32, +// CHECK-SAME: %[[ARG1:.*]]: f32) -> (f32, f32, f32, f32) { +// CHECK: %[[VAL_0:.*]], %[[VAL_1:.*]] = math.sincos %[[ARG0]] : f32 +// CHECK: %[[VAL_2:.*]], %[[VAL_3:.*]] = math.sincos %[[ARG1]] : f32 +// CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_3]], %[[VAL_2]] : f32, f32, f32, f32 +// CHECK: } +func.func @sincos_fusion(%arg0 : f32, %arg1 : f32) -> (f32, f32, f32, f32) { + %0 = math.sin %arg0 : f32 + %1 = math.cos %arg0 : f32 + + %2 = math.cos %arg1 : f32 + %3 = math.sin %arg1 : f32 + + func.return %0, %1, %2, %3 : f32, f32, f32, f32 +} + +func.func private @sink(%arg0 : f32) + +// CHECK: func.func private @sink(f32) +// CHECK-LABEL: func.func @sincos_ensure_ssa_dominance( +// CHECK-SAME: %[[ARG0:.*]]: f32, +// CHECK-SAME: %[[ARG1:.*]]: f32) -> (f32, f32, f32, f32) { +// CHECK: %[[VAL_0:.*]], %[[VAL_1:.*]] = math.sincos %[[ARG0]] : f32 +// CHECK: call @sink(%[[VAL_0]]) : (f32) -> () +// CHECK: %[[VAL_2:.*]], %[[VAL_3:.*]] = math.sincos %[[ARG1]] : f32 +// CHECK: call @sink(%[[VAL_3]]) : (f32) -> () +// CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_3]], %[[VAL_2]] : f32, f32, f32, f32 +// CHECK: } +func.func @sincos_ensure_ssa_dominance(%arg0 : f32, %arg1 : f32) -> (f32, f32, f32, f32) { + %0 = math.sin %arg0 : f32 + func.call @sink(%0) : (f32) -> () + %1 = math.cos %arg0 : f32 + %2 = math.cos %arg1 : f32 + func.call @sink(%2) : (f32) -> () + %3 = math.sin %arg1 : f32 + func.return %0, %1, %2, %3 : f32, f32, f32, f32 +} + +// CHECK-LABEL: func.func @sincos_fusion_no_match_fmf( +// CHECK-SAME: %[[ARG0:.*]]: f32) -> (f32, f32) { +// CHECK: %[[VAL_0:.*]] = math.sin %[[ARG0]] fastmath<contract> : f32 +// CHECK: %[[VAL_1:.*]] = math.cos %[[ARG0]] : f32 +// CHECK: return %[[VAL_0]], %[[VAL_1]] : f32, f32 +// CHECK: } +func.func @sincos_fusion_no_match_fmf(%arg0 : f32) -> (f32, f32) { + %0 = math.sin %arg0 fastmath<contract> : f32 + %1 = math.cos %arg0 : f32 + func.return %0, %1 : f32, f32 +} + +// CHECK-LABEL: 
func.func @sincos_no_fusion_different_block( +// CHECK-SAME: %[[ARG0:.*]]: f32, +// CHECK-SAME: %[[ARG1:.*]]: i1) -> f32 { +// CHECK: %[[VAL_0:.*]] = scf.if %[[ARG1]] -> (f32) { +// CHECK: %[[VAL_1:.*]] = math.sin %[[ARG0]] : f32 +// CHECK: scf.yield %[[VAL_1]] : f32 +// CHECK: } else { +// CHECK: %[[VAL_2:.*]] = math.cos %[[ARG0]] : f32 +// CHECK: scf.yield %[[VAL_2]] : f32 +// CHECK: } +// CHECK: return %[[VAL_0]] : f32 +// CHECK: } +func.func @sincos_no_fusion_different_block(%arg0 : f32, %flag : i1) -> f32 { + %0 = scf.if %flag -> f32 { + %s = math.sin %arg0 : f32 + scf.yield %s : f32 + } else { + %c = math.cos %arg0 : f32 + scf.yield %c : f32 + } + func.return %0 : f32 +} + +// CHECK-LABEL: func.func @sincos_fusion_preserve_fastmath( +// CHECK-SAME: %[[ARG0:.*]]: f32) -> (f32, f32) { +// CHECK: %[[VAL_0:.*]], %[[VAL_1:.*]] = math.sincos %[[ARG0]] fastmath<contract> : f32 +// CHECK: return %[[VAL_0]], %[[VAL_1]] : f32, f32 +// CHECK: } +func.func @sincos_fusion_preserve_fastmath(%arg0 : f32) -> (f32, f32) { + %0 = math.sin %arg0 fastmath<contract> : f32 + %1 = math.cos %arg0 fastmath<contract> : f32 + func.return %0, %1 : f32, f32 +} diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir index 3f96d90..5ff2920 100644 --- a/mlir/test/Dialect/MemRef/invalid.mlir +++ b/mlir/test/Dialect/MemRef/invalid.mlir @@ -1169,3 +1169,19 @@ func.func @expand_shape_invalid_output_shape( into memref<2x15x20xf32, strided<[60000, 4000, 2], offset: 100>> return } + +// ----- + +func.func @distinct_objects_types_mismatch(%arg0: memref<?xf32>, %arg1: memref<?xi32>) -> (memref<?xi32>, memref<?xf32>) { + // expected-error @+1 {{operand types and result types must match}} + %0, %1 = "memref.distinct_objects"(%arg0, %arg1) : (memref<?xf32>, memref<?xi32>) -> (memref<?xi32>, memref<?xf32>) + return %0, %1 : memref<?xi32>, memref<?xf32> +} + +// ----- + +func.func @distinct_objects_0_operands() { + // expected-error @+1 {{expected at least one operand}} 
+ "memref.distinct_objects"() : () -> () + return +} diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir index 6c2298a..a90c950 100644 --- a/mlir/test/Dialect/MemRef/ops.mlir +++ b/mlir/test/Dialect/MemRef/ops.mlir @@ -302,6 +302,15 @@ func.func @assume_alignment(%0: memref<4x4xf16>) { return } +// CHECK-LABEL: func @distinct_objects +// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf16>, %[[ARG1:.*]]: memref<?xf32>, %[[ARG2:.*]]: memref<?xf64>) +func.func @distinct_objects(%arg0: memref<?xf16>, %arg1: memref<?xf32>, %arg2: memref<?xf64>) -> (memref<?xf16>, memref<?xf32>, memref<?xf64>) { + // CHECK: %[[RES:.*]]:3 = memref.distinct_objects %[[ARG0]], %[[ARG1]], %[[ARG2]] : memref<?xf16>, memref<?xf32>, memref<?xf64> + %1, %2, %3 = memref.distinct_objects %arg0, %arg1, %arg2 : memref<?xf16>, memref<?xf32>, memref<?xf64> + // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : memref<?xf16>, memref<?xf32>, memref<?xf64> + return %1, %2, %3 : memref<?xf16>, memref<?xf32>, memref<?xf64> +} + // CHECK-LABEL: func @expand_collapse_shape_static func.func @expand_collapse_shape_static( %arg0: memref<3x4x5xf32>, diff --git a/mlir/test/Dialect/OpenMP/cli-canonical_loop.mlir b/mlir/test/Dialect/OpenMP/cli-canonical_loop.mlir index adadb8b..0e9385e 100644 --- a/mlir/test/Dialect/OpenMP/cli-canonical_loop.mlir +++ b/mlir/test/Dialect/OpenMP/cli-canonical_loop.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-opt %s | FileCheck %s -// RUN: mlir-opt %s | mlir-opt | FileCheck %s +// RUN: mlir-opt %s | FileCheck %s --enable-var-scope +// RUN: mlir-opt %s | mlir-opt | FileCheck %s --enable-var-scope // CHECK-LABEL: @omp_canonloop_raw( @@ -24,10 +24,10 @@ func.func @omp_canonloop_raw(%tc : i32) -> () { func.func @omp_canonloop_sequential_raw(%tc : i32) -> () { // CHECK-NEXT: %canonloop_s0 = omp.new_cli %canonloop_s0 = "omp.new_cli" () : () -> (!omp.cli) - // CHECK-NEXT: omp.canonical_loop(%canonloop_s0) %iv : i32 in range(%[[tc]]) { + // CHECK-NEXT: 
omp.canonical_loop(%canonloop_s0) %iv_s0 : i32 in range(%[[tc]]) { "omp.canonical_loop" (%tc, %canonloop_s0) ({ ^bb_first(%iv_first: i32): - // CHECK-NEXT: = llvm.add %iv, %iv : i32 + // CHECK-NEXT: = llvm.add %iv_s0, %iv_s0 : i32 %newval = llvm.add %iv_first, %iv_first : i32 // CHECK-NEXT: omp.terminator omp.terminator @@ -36,7 +36,7 @@ func.func @omp_canonloop_sequential_raw(%tc : i32) -> () { // CHECK-NEXT: %canonloop_s1 = omp.new_cli %canonloop_s1 = "omp.new_cli" () : () -> (!omp.cli) - // CHECK-NEXT: omp.canonical_loop(%canonloop_s1) %iv : i32 in range(%[[tc]]) { + // CHECK-NEXT: omp.canonical_loop(%canonloop_s1) %iv_s1 : i32 in range(%[[tc]]) { "omp.canonical_loop" (%tc, %canonloop_s1) ({ ^bb_second(%iv_second: i32): // CHECK: omp.terminator @@ -52,17 +52,17 @@ func.func @omp_canonloop_sequential_raw(%tc : i32) -> () { // CHECK-LABEL: @omp_nested_canonloop_raw( // CHECK-SAME: %[[tc_outer:.+]]: i32, %[[tc_inner:.+]]: i32) func.func @omp_nested_canonloop_raw(%tc_outer : i32, %tc_inner : i32) -> () { - // CHECK-NEXT: %canonloop_s0 = omp.new_cli + // CHECK-NEXT: %canonloop = omp.new_cli %outer = "omp.new_cli" () : () -> (!omp.cli) - // CHECK-NEXT: %canonloop_s0_s0 = omp.new_cli + // CHECK-NEXT: %canonloop_d1 = omp.new_cli %inner = "omp.new_cli" () : () -> (!omp.cli) - // CHECK-NEXT: omp.canonical_loop(%canonloop_s0) %iv : i32 in range(%[[tc_outer]]) { + // CHECK-NEXT: omp.canonical_loop(%canonloop) %iv : i32 in range(%[[tc_outer]]) { "omp.canonical_loop" (%tc_outer, %outer) ({ ^bb_outer(%iv_outer: i32): - // CHECK-NEXT: omp.canonical_loop(%canonloop_s0_s0) %iv_0 : i32 in range(%[[tc_inner]]) { + // CHECK-NEXT: omp.canonical_loop(%canonloop_d1) %iv_d1 : i32 in range(%[[tc_inner]]) { "omp.canonical_loop" (%tc_inner, %inner) ({ ^bb_inner(%iv_inner: i32): - // CHECK-NEXT: = llvm.add %iv, %iv_0 : i32 + // CHECK-NEXT: = llvm.add %iv, %iv_d1 : i32 %newval = llvm.add %iv_outer, %iv_inner: i32 // CHECK-NEXT: omp.terminator omp.terminator @@ -108,16 +108,24 @@ func.func 
@omp_canonloop_constant_pretty() -> () { func.func @omp_canonloop_sequential_pretty(%tc : i32) -> () { // CHECK-NEXT: %canonloop_s0 = omp.new_cli %canonloop_s0 = omp.new_cli - // CHECK-NEXT: omp.canonical_loop(%canonloop_s0) %iv : i32 in range(%[[tc]]) { - omp.canonical_loop(%canonloop_s0) %iv : i32 in range(%tc) { + // CHECK-NEXT: omp.canonical_loop(%canonloop_s0) %iv_s0 : i32 in range(%[[tc]]) { + omp.canonical_loop(%canonloop_s0) %iv_s0 : i32 in range(%tc) { // CHECK-NEXT: omp.terminator omp.terminator } // CHECK: %canonloop_s1 = omp.new_cli %canonloop_s1 = omp.new_cli - // CHECK-NEXT: omp.canonical_loop(%canonloop_s1) %iv : i32 in range(%[[tc]]) { - omp.canonical_loop(%canonloop_s1) %iv_0 : i32 in range(%tc) { + // CHECK-NEXT: omp.canonical_loop(%canonloop_s1) %iv_s1 : i32 in range(%[[tc]]) { + omp.canonical_loop(%canonloop_s1) %iv_s1 : i32 in range(%tc) { + // CHECK-NEXT: omp.terminator + omp.terminator + } + + // CHECK: %canonloop_s2 = omp.new_cli + %canonloop_s2 = omp.new_cli + // CHECK-NEXT: omp.canonical_loop(%canonloop_s2) %iv_s2 : i32 in range(%[[tc]]) { + omp.canonical_loop(%canonloop_s2) %iv_s2 : i32 in range(%tc) { // CHECK-NEXT: omp.terminator omp.terminator } @@ -126,17 +134,17 @@ func.func @omp_canonloop_sequential_pretty(%tc : i32) -> () { } -// CHECK-LABEL: @omp_canonloop_nested_pretty( +// CHECK-LABEL: @omp_canonloop_2d_nested_pretty( // CHECK-SAME: %[[tc:.+]]: i32) -func.func @omp_canonloop_nested_pretty(%tc : i32) -> () { - // CHECK-NEXT: %canonloop_s0 = omp.new_cli - %canonloop_s0 = omp.new_cli - // CHECK-NEXT: %canonloop_s0_s0 = omp.new_cli - %canonloop_s0_s0 = omp.new_cli - // CHECK-NEXT: omp.canonical_loop(%canonloop_s0) %iv : i32 in range(%[[tc]]) { - omp.canonical_loop(%canonloop_s0) %iv : i32 in range(%tc) { - // CHECK-NEXT: omp.canonical_loop(%canonloop_s0_s0) %iv_0 : i32 in range(%[[tc]]) { - omp.canonical_loop(%canonloop_s0_s0) %iv_0 : i32 in range(%tc) { +func.func @omp_canonloop_2d_nested_pretty(%tc : i32) -> () { + // CHECK-NEXT: 
%canonloop = omp.new_cli + %canonloop = omp.new_cli + // CHECK-NEXT: %canonloop_d1 = omp.new_cli + %canonloop_d1 = omp.new_cli + // CHECK-NEXT: omp.canonical_loop(%canonloop) %iv : i32 in range(%[[tc]]) { + omp.canonical_loop(%canonloop) %iv : i32 in range(%tc) { + // CHECK-NEXT: omp.canonical_loop(%canonloop_d1) %iv_d1 : i32 in range(%[[tc]]) { + omp.canonical_loop(%canonloop_d1) %iv_d1 : i32 in range(%tc) { // CHECK: omp.terminator omp.terminator } @@ -147,6 +155,77 @@ func.func @omp_canonloop_nested_pretty(%tc : i32) -> () { } +// CHECK-LABEL: @omp_canonloop_3d_nested_pretty( +// CHECK-SAME: %[[tc:.+]]: i32) +func.func @omp_canonloop_3d_nested_pretty(%tc : i32) -> () { + // CHECK: %canonloop = omp.new_cli + %canonloop = omp.new_cli + // CHECK: %canonloop_d1 = omp.new_cli + %canonloop_d1 = omp.new_cli + // CHECK: %canonloop_d2 = omp.new_cli + %canonloop_d2 = omp.new_cli + // CHECK-NEXT: omp.canonical_loop(%canonloop) %iv : i32 in range(%[[tc]]) { + omp.canonical_loop(%canonloop) %iv : i32 in range(%tc) { + // CHECK-NEXT: omp.canonical_loop(%canonloop_d1) %iv_d1 : i32 in range(%[[tc]]) { + omp.canonical_loop(%canonloop_d1) %iv_1d : i32 in range(%tc) { + // CHECK-NEXT: omp.canonical_loop(%canonloop_d2) %iv_d2 : i32 in range(%[[tc]]) { + omp.canonical_loop(%canonloop_d2) %iv_d2 : i32 in range(%tc) { + // CHECK-NEXT: omp.terminator + omp.terminator + // CHECK-NEXT: } + } + // CHECK-NEXT: omp.terminator + omp.terminator + // CHECK-NEXT: } + } + // CHECK-NEXT: omp.terminator + omp.terminator + } + + return +} + + +// CHECK-LABEL: @omp_canonloop_sequential_nested_pretty( +// CHECK-SAME: %[[tc:.+]]: i32) +func.func @omp_canonloop_sequential_nested_pretty(%tc : i32) -> () { + // CHECK-NEXT: %canonloop_s0 = omp.new_cli + %canonloop_s0 = omp.new_cli + // CHECK-NEXT: %canonloop_s0_d1 = omp.new_cli + %canonloop_s0_d1 = omp.new_cli + // CHECK-NEXT: omp.canonical_loop(%canonloop_s0) %iv_s0 : i32 in range(%[[tc]]) { + omp.canonical_loop(%canonloop_s0) %iv_s0 : i32 in range(%tc) 
{ + // CHECK-NEXT: omp.canonical_loop(%canonloop_s0_d1) %iv_s0_d1 : i32 in range(%[[tc]]) { + omp.canonical_loop(%canonloop_s0_d1) %iv_s0_d1 : i32 in range(%tc) { + // CHECK-NEXT: omp.terminator + omp.terminator + // CHECK-NEXT: } + } + // CHECK-NEXT: omp.terminator + omp.terminator + // CHECK-NEXT: } + } + + // CHECK-NEXT: %canonloop_s1 = omp.new_cli + %canonloop_s1 = omp.new_cli + // CHECK-NEXT: %canonloop_s1_d1 = omp.new_cli + %canonloop_s1_d1 = omp.new_cli + // CHECK-NEXT: omp.canonical_loop(%canonloop_s1) %iv_s1 : i32 in range(%[[tc]]) { + omp.canonical_loop(%canonloop_s1) %iv_s1 : i32 in range(%tc) { + // CHECK-NEXT: omp.canonical_loop(%canonloop_s1_d1) %iv_s1_d1 : i32 in range(%[[tc]]) { + omp.canonical_loop(%canonloop_s1_d1) %iv_s1d1 : i32 in range(%tc) { + // CHECK-NEXT: omp.terminator + omp.terminator + // CHECK-NEXT: } + } + // CHECK-NEXT: omp.terminator + omp.terminator + } + + return +} + + // CHECK-LABEL: @omp_newcli_unused( // CHECK-SAME: ) func.func @omp_newcli_unused() -> () { @@ -155,3 +234,74 @@ func.func @omp_newcli_unused() -> () { // CHECK-NEXT: return return } + + +// CHECK-LABEL: @omp_canonloop_multiregion_isolatedfromabove( +func.func @omp_canonloop_multiregion_isolatedfromabove() -> () { + omp.private {type = firstprivate} @x.privatizer : !llvm.ptr init { + ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr): + %c42_i32 = arith.constant 42: i32 + // CHECK: omp.canonical_loop %iv : i32 in range(%c42_i32) { + omp.canonical_loop %iv1 : i32 in range(%c42_i32) { + omp.terminator + } + // CHECK: omp.yield + omp.yield(%arg0 : !llvm.ptr) + } copy { + ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr): + %c42_i32 = arith.constant 42: i32 + // CHECK: omp.canonical_loop %iv : i32 in range(%c42_i32) { + omp.canonical_loop %iv : i32 in range(%c42_i32) { + // CHECK: omp.canonical_loop %iv_d1 : i32 in range(%c42_i32) { + omp.canonical_loop %iv_d1 : i32 in range(%c42_i32) { + omp.terminator + } + omp.terminator + } + // CHECK: omp.yield + omp.yield(%arg0 : !llvm.ptr) + } 
dealloc { + ^bb0(%arg0: !llvm.ptr): + %c42_i32 = arith.constant 42: i32 + // CHECK: omp.canonical_loop %iv_s0 : i32 in range(%c42_i32) { + omp.canonical_loop %iv_s0 : i32 in range(%c42_i32) { + omp.terminator + } + // CHECK: omp.canonical_loop %iv_s1 : i32 in range(%c42_i32) { + omp.canonical_loop %iv_s1 : i32 in range(%c42_i32) { + omp.terminator + } + // CHECK: omp.yield + omp.yield + } + + // CHECK: return + return +} + + +// CHECK-LABEL: @omp_canonloop_multiregion( +func.func @omp_canonloop_multiregion(%c : i1) -> () { + %c42_i32 = arith.constant 42: i32 + %canonloop1 = omp.new_cli + %canonloop2 = omp.new_cli + %canonloop3 = omp.new_cli + scf.if %c { + // CHECK: omp.canonical_loop(%canonloop_r0) %iv_r0 : i32 in range(%c42_i32) { + omp.canonical_loop(%canonloop1) %iv1 : i32 in range(%c42_i32) { + omp.terminator + } + } else { + // CHECK: omp.canonical_loop(%canonloop_r1_s0) %iv_r1_s0 : i32 in range(%c42_i32) { + omp.canonical_loop(%canonloop2) %iv2 : i32 in range(%c42_i32) { + omp.terminator + } + // CHECK: omp.canonical_loop(%canonloop_r1_s1) %iv_r1_s1 : i32 in range(%c42_i32) { + omp.canonical_loop(%canonloop3) %iv3 : i32 in range(%c42_i32) { + omp.terminator + } + } + + // CHECK: return + return +} diff --git a/mlir/test/Dialect/OpenMP/cli-tile.mlir b/mlir/test/Dialect/OpenMP/cli-tile.mlir new file mode 100644 index 0000000..73d5478 --- /dev/null +++ b/mlir/test/Dialect/OpenMP/cli-tile.mlir @@ -0,0 +1,138 @@ +// RUN: mlir-opt %s | FileCheck %s --enable-var-scope +// RUN: mlir-opt %s | mlir-opt | FileCheck %s --enable-var-scope + + +// Raw syntax check (MLIR output is always pretty-printed) +// CHECK-LABEL: @omp_tile_raw( +// CHECK-SAME: %[[tc:.+]]: i32, %[[ts:.+]]: i32) { +func.func @omp_tile_raw(%tc : i32, %ts : i32) -> () { + // CHECK-NEXT: %canonloop = omp.new_cli + %canonloop = "omp.new_cli" () : () -> (!omp.cli) + // CHECK-NEXT: %grid1 = omp.new_cli + %grid = "omp.new_cli" () : () -> (!omp.cli) + // CHECK-NEXT: %intratile1 = omp.new_cli + %intratile = 
"omp.new_cli" () : () -> (!omp.cli) + // CHECK-NEXT: omp.canonical_loop(%canonloop) %iv : i32 in range(%[[tc]]) { + "omp.canonical_loop" (%tc, %canonloop) ({ + ^bb0(%iv: i32): + // CHECK: omp.terminator + omp.terminator + }) : (i32, !omp.cli) -> () + // CHECK: omp.tile (%grid1, %intratile1) <- (%canonloop) sizes(%[[ts]] : i32) + "omp.tile"(%grid, %intratile, %canonloop, %ts) <{operandSegmentSizes = array<i32: 2, 1, 1>}> : (!omp.cli, !omp.cli, !omp.cli, i32) -> () + //"omp.tile" (%canonloop) : (!omp.cli) -> () + return +} + + +// Pretty syntax check +// CHECK-LABEL: @omp_tile_pretty( +// CHECK-SAME: %[[tc:.+]]: i32, %[[ts:.+]]: i32) { +func.func @omp_tile_pretty(%tc : i32, %ts : i32) -> () { + // CHECK-NEXT: %[[CANONLOOP:.+]] = omp.new_cli + %canonloop = omp.new_cli + // CHECK-NEXT: %[[CANONLOOP:.+]] = omp.new_cli + %grid = omp.new_cli + // CHECK-NEXT: %[[CANONLOOP:.+]] = omp.new_cli + %intratile = omp.new_cli + // CHECK-NEXT: omp.canonical_loop(%canonloop) %iv : i32 in range(%[[tc]]) { + omp.canonical_loop(%canonloop) %iv : i32 in range(%tc) { + // CHECK: omp.terminator + omp.terminator + } + // CHECK: omp.tile (%grid1, %intratile1) <- (%canonloop) sizes(%[[ts]] : i32) + omp.tile(%grid, %intratile) <- (%canonloop) sizes(%ts : i32) + return +} + + +// Specifying the generatees for omp.tile is optional +// CHECK-LABEL: @omp_tile_optionalgen_pretty( +// CHECK-SAME: %[[tc:.+]]: i32, %[[ts:.+]]: i32) { +func.func @omp_tile_optionalgen_pretty(%tc : i32, %ts : i32) -> () { + // CHECK-NEXT: %canonloop = omp.new_cli + %canonloop = omp.new_cli + // CHECK-NEXT: omp.canonical_loop(%canonloop) %iv : i32 in range(%[[tc]]) { + omp.canonical_loop(%canonloop) %iv : i32 in range(%tc) { + // CHECK: omp.terminator + omp.terminator + } + // CHECK: omp.tile <- (%canonloop) sizes(%[[ts]] : i32) + omp.tile <- (%canonloop) sizes(%ts : i32) + return +} + + +// Two-dimensional tiling +// CHECK-LABEL: @omp_tile_2d_pretty( +// CHECK-SAME: %[[tc1:.+]]: i32, %[[tc2:.+]]: i32, %[[ts1:.+]]: i32, 
%[[ts2:.+]]: i32) { +func.func @omp_tile_2d_pretty(%tc1 : i32, %tc2 : i32, %ts1 : i32, %ts2 : i32) -> () { + // CHECK-NEXT: %canonloop = omp.new_cli + %cli_outer = omp.new_cli + // CHECK-NEXT: %canonloop_d1 = omp.new_cli + %cli_inner = omp.new_cli + // CHECK-NEXT: %grid1 = omp.new_cli + %grid1 = omp.new_cli + // CHECK-NEXT: %grid2 = omp.new_cli + %grid2 = omp.new_cli + // CHECK-NEXT: %intratile1 = omp.new_cli + %intratile1 = omp.new_cli + // CHECK-NEXT: %intratile2 = omp.new_cli + %intratile2 = omp.new_cli + // CHECK-NEXT: omp.canonical_loop(%canonloop) %iv : i32 in range(%[[tc1]]) { + omp.canonical_loop(%cli_outer) %iv_outer : i32 in range(%tc1) { + // CHECK-NEXT: omp.canonical_loop(%canonloop_d1) %iv_d1 : i32 in range(%[[tc2]]) { + omp.canonical_loop(%cli_inner) %iv_inner : i32 in range(%tc2) { + // CHECK: omp.terminator + omp.terminator + } + // CHECK: omp.terminator + omp.terminator + } + // CHECK: omp.tile (%grid1, %grid2, %intratile1, %intratile2) <- (%canonloop, %canonloop_d1) sizes(%[[ts1]], %[[ts2]] : i32, i32) + omp.tile (%grid1, %grid2, %intratile1, %intratile2) <- (%cli_outer, %cli_inner) sizes(%ts1, %ts2 : i32, i32) + return +} + + +// Three-dimensional tiling +// CHECK-LABEL: @omp_tile_3d_pretty( +// CHECK-SAME: %[[tc:.+]]: i32, %[[ts:.+]]: i32) { +func.func @omp_tile_3d_pretty(%tc : i32, %ts : i32) -> () { + // CHECK-NEXT: %canonloop = omp.new_cli + %cli_outer = omp.new_cli + // CHECK-NEXT: %canonloop_d1 = omp.new_cli + %cli_middle = omp.new_cli + // CHECK-NEXT: %canonloop_d2 = omp.new_cli + %cli_inner = omp.new_cli + // CHECK-NEXT: %grid1 = omp.new_cli + %grid1 = omp.new_cli + // CHECK-NEXT: %grid2 = omp.new_cli + %grid2 = omp.new_cli + // CHECK-NEXT: %grid3 = omp.new_cli + %grid3 = omp.new_cli + // CHECK-NEXT: %intratile1 = omp.new_cli + %intratile1 = omp.new_cli + // CHECK-NEXT: %intratile2 = omp.new_cli + %intratile2 = omp.new_cli + // CHECK-NEXT: %intratile3 = omp.new_cli + %intratile3 = omp.new_cli + // CHECK-NEXT: 
omp.canonical_loop(%canonloop) %iv : i32 in range(%[[tc]]) { + omp.canonical_loop(%cli_outer) %iv_outer : i32 in range(%tc) { + // CHECK-NEXT: omp.canonical_loop(%canonloop_d1) %iv_d1 : i32 in range(%[[tc]]) { + omp.canonical_loop(%cli_middle) %iv_middle : i32 in range(%tc) { + // CHECK-NEXT: omp.canonical_loop(%canonloop_d2) %iv_d2 : i32 in range(%[[tc]]) { + omp.canonical_loop(%cli_inner) %iv_inner : i32 in range(%tc) { + // CHECK: omp.terminator + omp.terminator + } + // CHECK: omp.terminator + omp.terminator + } + // CHECK: omp.terminator + omp.terminator + } + // CHECK: omp.tile (%grid1, %grid2, %grid3, %intratile1, %intratile2, %intratile3) <- (%canonloop, %canonloop_d1, %canonloop_d2) sizes(%[[ts]], %[[ts]], %[[ts]] : i32, i32, i32) + omp.tile (%grid1, %grid2, %grid3, %intratile1, %intratile2, %intratile3) <- (%cli_outer, %cli_middle, %cli_inner) sizes(%ts, %ts, %ts: i32, i32, i32) + return +} diff --git a/mlir/test/Dialect/OpenMP/cli-unroll-heuristic.mlir b/mlir/test/Dialect/OpenMP/cli-unroll-heuristic.mlir index cda7d0b..16884f4 100644 --- a/mlir/test/Dialect/OpenMP/cli-unroll-heuristic.mlir +++ b/mlir/test/Dialect/OpenMP/cli-unroll-heuristic.mlir @@ -1,18 +1,18 @@ -// RUN: mlir-opt %s | FileCheck %s -// RUN: mlir-opt %s | mlir-opt | FileCheck %s +// RUN: mlir-opt %s | FileCheck %s --enable-var-scope +// RUN: mlir-opt %s | mlir-opt | FileCheck %s --enable-var-scope // CHECK-LABEL: @omp_unroll_heuristic_raw( // CHECK-SAME: %[[tc:.+]]: i32) { func.func @omp_unroll_heuristic_raw(%tc : i32) -> () { - // CHECK-NEXT: %canonloop_s0 = omp.new_cli + // CHECK-NEXT: %canonloop = omp.new_cli %canonloop = "omp.new_cli" () : () -> (!omp.cli) - // CHECK-NEXT: omp.canonical_loop(%canonloop_s0) %iv : i32 in range(%[[tc]]) { + // CHECK-NEXT: omp.canonical_loop(%canonloop) %iv : i32 in range(%[[tc]]) { "omp.canonical_loop" (%tc, %canonloop) ({ ^bb0(%iv: i32): omp.terminator }) : (i32, !omp.cli) -> () - // CHECK: omp.unroll_heuristic(%canonloop_s0) + // CHECK: 
omp.unroll_heuristic(%canonloop) "omp.unroll_heuristic" (%canonloop) : (!omp.cli) -> () return } @@ -22,12 +22,12 @@ func.func @omp_unroll_heuristic_raw(%tc : i32) -> () { // CHECK-SAME: %[[tc:.+]]: i32) { func.func @omp_unroll_heuristic_pretty(%tc : i32) -> () { // CHECK-NEXT: %[[CANONLOOP:.+]] = omp.new_cli - %canonloop = "omp.new_cli" () : () -> (!omp.cli) - // CHECK-NEXT: omp.canonical_loop(%canonloop_s0) %iv : i32 in range(%[[tc]]) { + %canonloop = omp.new_cli + // CHECK-NEXT: omp.canonical_loop(%canonloop) %iv : i32 in range(%[[tc]]) { omp.canonical_loop(%canonloop) %iv : i32 in range(%tc) { omp.terminator } - // CHECK: omp.unroll_heuristic(%canonloop_s0) + // CHECK: omp.unroll_heuristic(%canonloop) omp.unroll_heuristic(%canonloop) return } @@ -36,13 +36,13 @@ func.func @omp_unroll_heuristic_pretty(%tc : i32) -> () { // CHECK-LABEL: @omp_unroll_heuristic_nested_pretty( // CHECK-SAME: %[[tc:.+]]: i32) { func.func @omp_unroll_heuristic_nested_pretty(%tc : i32) -> () { - // CHECK-NEXT: %canonloop_s0 = omp.new_cli + // CHECK-NEXT: %canonloop = omp.new_cli %cli_outer = omp.new_cli - // CHECK-NEXT: %canonloop_s0_s0 = omp.new_cli + // CHECK-NEXT: %canonloop_d1 = omp.new_cli %cli_inner = omp.new_cli - // CHECK-NEXT: omp.canonical_loop(%canonloop_s0) %iv : i32 in range(%[[tc]]) { + // CHECK-NEXT: omp.canonical_loop(%canonloop) %iv : i32 in range(%[[tc]]) { omp.canonical_loop(%cli_outer) %iv_outer : i32 in range(%tc) { - // CHECK-NEXT: omp.canonical_loop(%canonloop_s0_s0) %iv_0 : i32 in range(%[[tc]]) { + // CHECK-NEXT: omp.canonical_loop(%canonloop_d1) %iv_d1 : i32 in range(%[[tc]]) { omp.canonical_loop(%cli_inner) %iv_inner : i32 in range(%tc) { // CHECK: omp.terminator omp.terminator @@ -51,9 +51,9 @@ func.func @omp_unroll_heuristic_nested_pretty(%tc : i32) -> () { omp.terminator } - // CHECK: omp.unroll_heuristic(%canonloop_s0) + // CHECK: omp.unroll_heuristic(%canonloop) omp.unroll_heuristic(%cli_outer) - // CHECK-NEXT: omp.unroll_heuristic(%canonloop_s0_s0) + // 
CHECK-NEXT: omp.unroll_heuristic(%canonloop_d1) omp.unroll_heuristic(%cli_inner) return } diff --git a/mlir/test/Dialect/OpenMP/invalid-tile.mlir b/mlir/test/Dialect/OpenMP/invalid-tile.mlir new file mode 100644 index 0000000..e63a062 --- /dev/null +++ b/mlir/test/Dialect/OpenMP/invalid-tile.mlir @@ -0,0 +1,119 @@ +// RUN: mlir-opt -split-input-file -verify-diagnostics %s + + +func.func @missing_sizes(%tc : i32, %ts : i32) { + %canonloop = omp.new_cli + omp.canonical_loop(%canonloop) %iv : i32 in range(%tc) { + omp.terminator + } + + // expected-error@+1 {{'omp.tile' op there must be one tile size for each applyee}} + omp.tile <-(%canonloop) + + llvm.return +} + +// ----- + +func.func @no_loop(%tc : i32, %ts : i32) { + // expected-error@+1 {{'omp.tile' op must apply to at least one loop}} + omp.tile <-() + + return +} + +// ----- + +func.func @missing_generator(%tc : i32, %ts : i32) { + // expected-error@+1 {{'omp.new_cli' op CLI has no generator}} + %canonloop = omp.new_cli + + // expected-note@+1 {{see consumer here: "omp.tile"(%0, %arg1) <{operandSegmentSizes = array<i32: 0, 1, 1>}> : (!omp.cli, i32) -> ()}} + omp.tile <-(%canonloop) sizes(%ts : i32) + + return +} + +// ----- + +func.func @insufficient_sizes(%tc : i32, %ts : i32) { + %canonloop1 = omp.new_cli + %canonloop2 = omp.new_cli + omp.canonical_loop(%canonloop1) %iv : i32 in range(%tc) { + omp.terminator + } + omp.canonical_loop(%canonloop2) %iv : i32 in range(%tc) { + omp.terminator + } + + // expected-error@+1 {{'omp.tile' op there must be one tile size for each applyee}} + omp.tile <-(%canonloop1, %canonloop2) sizes(%ts : i32) + + llvm.return +} + +// ----- + +func.func @insufficient_applyees(%tc : i32, %ts : i32) { + %canonloop = omp.new_cli + omp.canonical_loop(%canonloop) %iv : i32 in range(%tc) { + omp.terminator + } + + // expected-error@+1 {{omp.tile' op there must be one tile size for each applyee}} + omp.tile <- (%canonloop) sizes(%ts, %ts : i32, i32) + + return +} + +// ----- + +func.func 
@insufficient_generatees(%tc : i32, %ts : i32) { + %canonloop = omp.new_cli + %grid = omp.new_cli + omp.canonical_loop(%canonloop) %iv : i32 in range(%tc) { + omp.terminator + } + + // expected-error@+1 {{'omp.tile' op expecting two times the number of generatees than applyees}} + omp.tile (%grid) <- (%canonloop) sizes(%ts : i32) + + return +} + +// ----- + +func.func @not_perfectly_nested(%tc : i32, %ts : i32) { + %canonloop1 = omp.new_cli + %canonloop2 = omp.new_cli + omp.canonical_loop(%canonloop1) %iv1 : i32 in range(%tc) { + %v = arith.constant 42 : i32 + omp.canonical_loop(%canonloop2) %iv2 : i32 in range(%tc) { + omp.terminator + } + omp.terminator + } + + // expected-error@+1 {{'omp.tile' op tiled loop nest must be perfectly nested}} + omp.tile <-(%canonloop1, %canonloop2) sizes(%ts, %ts : i32, i32) + + llvm.return +} + +// ----- + +func.func @non_rectangular(%tc : i32, %ts : i32) { + %canonloop1 = omp.new_cli + %canonloop2 = omp.new_cli + omp.canonical_loop(%canonloop1) %iv1 : i32 in range(%tc) { + omp.canonical_loop(%canonloop2) %iv2 : i32 in range(%iv1) { + omp.terminator + } + omp.terminator + } + + // expected-error@+1 {{'omp.tile' op tiled loop nest must be rectangular}} + omp.tile <-(%canonloop1, %canonloop2) sizes(%ts, %ts : i32, i32) + + llvm.return +} diff --git a/mlir/test/Dialect/Transform/test-promote-tensors.mlir b/mlir/test/Dialect/Transform/test-promote-tensors.mlir new file mode 100644 index 0000000..bc9a05a --- /dev/null +++ b/mlir/test/Dialect/Transform/test-promote-tensors.mlir @@ -0,0 +1,104 @@ +// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s + +// CHECK-LABEL: @promote_in0 +// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x42xf32>, %{{.*}}, %{{.*}}) +// CHECK: %[[C0:.+]] = arith.constant 0 +// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK: %[[ALLOC:.+]] = bufferization.alloc_tensor(%[[DIM]]) {memory_space = 1 : i64} +// CHECK: %[[MAT:.+]] = bufferization.materialize_in_destination %[[ARG0]] in 
%[[ALLOC]] +// CHECK: linalg.matmul ins(%[[MAT]], %{{.*}} +func.func @promote_in0(%arg0: tensor<?x42xf32>, %arg1: tensor<42x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> { + %0 = linalg.matmul ins(%arg0, %arg1: tensor<?x42xf32>, tensor<42x?xf32>) + outs(%arg2: tensor<?x?xf32>) -> tensor<?x?xf32> + return %0 : tensor<?x?xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%root: !transform.any_op) { + %mm = transform.structured.match ops{["linalg.matmul"]} in %root + : (!transform.any_op) -> !transform.any_op + %op0 = transform.get_operand %mm[0] + : (!transform.any_op) -> !transform.any_value + transform.structured.promote_tensor to 1 %op0 : !transform.any_value + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @promote_out +// CHECK-SAME: (%{{.*}}: tensor<?x42xf32>, %{{.*}}: tensor<?x42xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) +func.func @promote_out(%arg0: tensor<?x42xf32>, %arg1: tensor<?x42xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> { + // CHECK: %[[C0:.+]] = arith.constant 0 + // CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG2]], %[[C0]] + // CHECK: %[[C1:.+]] = arith.constant 1 + // CHECK: %[[DIM1:.+]] = tensor.dim %[[ARG2]], %[[C1]] + // CHECK: %[[ALLOC:.+]] = bufferization.alloc_tensor(%[[DIM0]], %[[DIM1]]) {memory_space = 1 : i64} + // CHECK-NOT: materialize_in_destination + // CHECK: linalg.add {{.*}} outs(%[[ALLOC]] + %0 = linalg.add ins(%arg0, %arg1 : tensor<?x42xf32>, tensor<?x42xf32>) + outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> + return %0 : tensor<?x?xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%root: !transform.any_op) { + %la = transform.structured.match ops{["linalg.add"]} in %root + : (!transform.any_op) -> !transform.any_op + %init = transform.get_operand %la[2] + : (!transform.any_op) -> !transform.any_value + transform.structured.promote_tensor to 1 %init : !transform.any_value + + transform.yield + } +} + +// 
----- + +// CHECK-LABEL: @promote_in0_out_bufferize +// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x42xf32>, %{{.*}}: tensor<42x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) +func.func @promote_in0_out_bufferize(%arg0: tensor<?x42xf32>, %arg1: tensor<42x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> { + // CHECK: %[[IN1:.+]] = bufferization.to_buffer %arg1 : tensor<42x?xf32> to memref<42x?xf32, strided<[?, ?], offset: ?>> + // CHECK: %[[IN0:.+]] = bufferization.to_buffer %arg0 : tensor<?x42xf32> to memref<?x42xf32, strided<[?, ?], offset: ?>> + // CHECK: %{{.+}} = bufferization.to_buffer %arg0 : tensor<?x42xf32> to memref<?x42xf32, strided<[?, ?], offset: ?>> + // CHECK: %{{.+}} = bufferization.to_buffer %arg2 : tensor<?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>> + // CHECK: %{{.+}} = bufferization.to_buffer %arg2 : tensor<?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>> + // CHECK: %[[C0:.+]] = arith.constant 0 : index + // CHECK: %{{.+}} = memref.dim %{{.+}}, %[[C0]] : memref<?x?xf32, strided<[?, ?], offset: ?>> + // CHECK: %[[C1:.+]] = arith.constant 1 : index + // CHECK: %{{.+}} = memref.dim %{{.+}}, %[[C1]] : memref<?x?xf32, strided<[?, ?], offset: ?>> + // CHECK: %[[ALLOC_OUT:.+]] = memref.alloc(%{{.+}}, %{{.+}}) {alignment = 64 : i64} : memref<?x?xf32, 1> + // CHECK: %{{.+}} = arith.constant 0 : index + // CHECK: %{{.+}} = memref.dim %{{.+}}, %{{.+}} : memref<?x42xf32, strided<[?, ?], offset: ?>> + // CHECK: %[[ALLOC_IN:.+]] = memref.alloc(%{{.+}}) {alignment = 64 : i64} : memref<?x42xf32, 1> + // CHECK: memref.copy %[[IN0]], %[[ALLOC_IN]] : memref<?x42xf32, strided<[?, ?], offset: ?>> to memref<?x42xf32, 1> + // CHECK: linalg.add ins(%[[ALLOC_IN]], %[[IN1]] : memref<?x42xf32, 1>, memref<42x?xf32, strided<[?, ?], offset: ?>>) outs(%[[ALLOC_OUT]] : memref<?x?xf32, 1>) + %0 = linalg.add ins(%arg0, %arg1: tensor<?x42xf32>, tensor<42x?xf32>) + outs(%arg2: tensor<?x?xf32>) -> tensor<?x?xf32> + return %0 : tensor<?x?xf32> +} + +module attributes 
{transform.with_named_sequence} { + transform.named_sequence @__transform_main(%root: !transform.any_op) { + %la = transform.structured.match ops{["linalg.add"]} in %root + : (!transform.any_op) -> !transform.any_op + %op0 = transform.get_operand %la[0] + : (!transform.any_op) -> !transform.any_value + transform.structured.promote_tensor to 1 %op0 : !transform.any_value + + %init = transform.get_operand %la[2] + : (!transform.any_op) -> !transform.any_value + transform.structured.promote_tensor to 1 %init : !transform.any_value + + %func = transform.structured.match ops{["func.func"]} in %root + : (!transform.any_op) -> !transform.any_op + + %bufferized = transform.bufferization.one_shot_bufferize %func + : (!transform.any_op) -> !transform.any_op + + transform.yield + } +} + + + diff --git a/mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir b/mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir index 2e5f433..efc3890 100644 --- a/mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir +++ b/mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir @@ -19,3 +19,88 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +func.func private @f() + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + // expected-error@below {{'selected_region' attribute specifies region at index 2 while op has only 2 regions}} + transform.tune.alternatives<"bifurcation"> selected_region = 2 { + transform.yield + }, { + transform.yield + } + transform.yield + } +} + +// ----- + +func.func private @f() + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %singleton_of_c0 = transform.param.constant [0] -> !transform.any_param + // expected-error@below {{param should hold exactly one integer attribute, got: [0]}} + 
transform.tune.alternatives<"bifurcation"> selected_region = %singleton_of_c0 : !transform.any_param { + transform.yield + }, { + transform.yield + } + transform.yield + } +} + +// ----- + +func.func private @f() + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %c0 = transform.param.constant 0 -> !transform.any_param + %c1 = transform.param.constant 1 -> !transform.any_param + %c0_and_c1 = transform.merge_handles %c0, %c1 : !transform.any_param + // expected-error@below {{param should hold exactly one integer attribute}} + transform.tune.alternatives<"bifurcation"> selected_region = %c0_and_c1 : !transform.any_param { + transform.yield + }, { + transform.yield + } + transform.yield + } +} + +// ----- + +func.func private @f() + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %c2 = transform.param.constant 2 -> !transform.any_param + // expected-error@below {{'selected_region' attribute/param specifies region at index 2 while op has only 2 regions}} + transform.tune.alternatives<"bifurcation"> selected_region = %c2 : !transform.any_param { + transform.yield + }, { + transform.yield + } + transform.yield + } +} + +// ----- + +func.func private @f() + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + // expected-error@below {{non-deterministic choice "bifurcation" is only resolved through providing a `selected_region` attr/param}} + transform.tune.alternatives<"bifurcation"> { + transform.yield + }, { + transform.yield + } + transform.yield + } +} diff --git a/mlir/test/Dialect/Transform/test-tune-extension.mlir b/mlir/test/Dialect/Transform/test-tune-extension.mlir index 0a253c6..5da48a2 100644 --- a/mlir/test/Dialect/Transform/test-tune-extension.mlir +++ 
b/mlir/test/Dialect/Transform/test-tune-extension.mlir @@ -59,3 +59,129 @@ module attributes {transform.with_named_sequence} { transform.yield } } + + +// ----- + +// CHECK-LABEL: schedule_with_two_independent_choices_already_made +func.func @schedule_with_two_independent_choices_already_made( + %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>) + -> tensor<128x128xf32> { +// CHECK-NOT: scf.forall +// CHECK: scf.for +// CHECK-NOT: scf.for +// CHECK: scf.forall +// CHECK-NOT: scf.for +// CHECK: tensor.extract_slice +// CHECK: tensor.extract_slice +// CHECK: tensor.extract_slice +// CHECK: linalg.matmul +// CHECK: scf.forall.in_parallel +// CHECK: tensor.parallel_insert_slice +// CHECK: tensor.insert_slice +// CHECK: scf.yield + %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>) + outs(%arg2: tensor<128x128xf32>) -> tensor<128x128xf32> + return %0 : tensor<128x128xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op + + %tiled_matmul = transform.tune.alternatives<"outer_par_or_seq_tiling"> selected_region = 0 -> !transform.any_op + { // First alternative/region, with index = 0 + %contained_matmul, %loop = transform.structured.tile_using_for %matmul tile_sizes [8] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield %contained_matmul : !transform.any_op + }, { // Second alternative/region, with index = 1 + %contained_matmul, %loop = transform.structured.tile_using_forall %matmul tile_sizes [8] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield %contained_matmul : !transform.any_op + } + + transform.tune.alternatives<"inner_par_or_seq_tiling"> selected_region = 1 -> !transform.any_op { + %contained_matmul, %loop = 
transform.structured.tile_using_for %tiled_matmul tile_sizes [0, 16] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield %contained_matmul : !transform.any_op + }, { + %contained_matmul, %loop = transform.structured.tile_using_forall %tiled_matmul tile_sizes [0, 16] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield %contained_matmul : !transform.any_op + } + + transform.yield + } +} + +// ----- + +// CHECK-LABEL: subschedule_with_choice_resolved_in_main_schedule +func.func @subschedule_with_choice_resolved_in_main_schedule( + %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>) + -> tensor<128x128xf32> { +// CHECK-NOT: scf.for +// CHECK: scf.forall +// CHECK-NOT: scf.forall +// CHECK: scf.for +// CHECK-NOT: scf.forall +// CHECK: tensor.extract_slice +// CHECK: tensor.extract_slice +// CHECK: tensor.extract_slice +// CHECK: linalg.matmul +// CHECK: tensor.insert_slice +// CHECK: scf.yield +// CHECK: scf.forall.in_parallel +// CHECK: tensor.parallel_insert_slice + %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>) + outs(%arg2: tensor<128x128xf32>) -> tensor<128x128xf32> + return %0 : tensor<128x128xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @subschedule_with_embedded_choice(%matmul: !transform.any_op {transform.readonly}, + %par_or_seq: !transform.param<i64> {transform.readonly}, + %tile_size: !transform.param<i64> {transform.readonly}) -> !transform.any_op { + %tiled_matmul = transform.tune.alternatives<"par_or_seq_tiling"> selected_region = %par_or_seq : !transform.param<i64> -> !transform.any_op { + %contained_matmul, %loop = transform.structured.tile_using_for %matmul tile_sizes [%tile_size] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op) + transform.yield %contained_matmul : !transform.any_op + }, { + %contained_matmul, %loop = 
transform.structured.tile_using_forall %matmul tile_sizes [%tile_size] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op) + transform.yield %contained_matmul : !transform.any_op + } + transform.yield %tiled_matmul : !transform.any_op + } + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %outer_par = transform.param.constant 1 -> !transform.param<i64> + %outer_tile_size = transform.param.constant 32 -> !transform.param<i64> + %inner_seq = transform.tune.knob<"inner_par_or_seq"> = 0 from options = [0, 1] -> !transform.param<i64> + %inner_tile_size = transform.param.constant 8 -> !transform.param<i64> + %tiled_matmul = transform.include @subschedule_with_embedded_choice failures(propagate) (%matmul, %outer_par, %outer_tile_size) : (!transform.any_op, !transform.param<i64>, !transform.param<i64>) -> !transform.any_op + %tiled_tiled_matmul = transform.include @subschedule_with_embedded_choice failures(propagate) (%tiled_matmul, %inner_seq, %inner_tile_size) : (!transform.any_op, !transform.param<i64>, !transform.param<i64>) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: eeny_meeny_miny_moe +func.func private @eeny_meeny_miny_moe() + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op + + %tiled_matmul = transform.tune.alternatives<"4way"> selected_region = 3 -> !transform.any_param + { // First alternative/region, with index = 0 + %out = transform.param.constant "eeny" -> !transform.any_param + transform.yield %out : !transform.any_param + }, { // Second alternative/region, with index = 1 + %out = transform.param.constant "meeny" -> 
!transform.any_param + transform.yield %out : !transform.any_param + }, { // Third alternative/region, with index = 2 + %out = transform.param.constant "miny" -> !transform.any_param + transform.yield %out : !transform.any_param + }, { // Fourth alternative/region, with index = 3 + %out = transform.param.constant "moe" -> !transform.any_param + transform.yield %out : !transform.any_param + } + transform.yield + } +}
\ No newline at end of file diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir index 03c6386..38392fd 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir @@ -282,15 +282,20 @@ gpu.module @test_distribution { // CHECK-LABEL: @store_scatter // CHECK-SAME: %[[ARG0:.*]]: memref<256xf16> gpu.func @store_scatter(%dest : memref<256xf16>) { - // CHECK: %[[VAL:.*]] = arith.constant dense<2.550000e+01> : vector<8xf16> - // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xindex> - // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8xi1> + // CHECK: %[[VAL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8]>} dense<2.550000e+01> : vector<8xf16> + // CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8]>} dense<0> : vector<8xindex> + // CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8]>} dense<true> : vector<8xi1> // CHECK: xegpu.store %[[VAL]], %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> + // CHECK-SAME: {layout_operand_0 = #xegpu.layout<inst_data = [8]>, layout_operand_2 = #xegpu.layout<inst_data = [8]>, + // CHECK-SAME: layout_operand_3 = #xegpu.layout<inst_data = [8]>} // CHECK-SAME: : vector<8xf16>, memref<256xf16>, vector<8xindex>, vector<8xi1> - %val = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<25.5> : vector<256xf16> - %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<0> : vector<256xindex> - %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<1> : vector<256xi1> - xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout = #xegpu.layout<sg_layout = [32], sg_data = [8]>, l1_hint = #xegpu.cache_hint<cached>} + %val = arith.constant 
{layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>} dense<25.5> : vector<256xf16> + %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>} dense<0> : vector<256xindex> + %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>} dense<1> : vector<256xi1> + xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout_operand_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>, + layout_operand_2 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>, + layout_operand_3 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>, + l1_hint = #xegpu.cache_hint<cached>} : vector<256xf16>, memref<256xf16>, vector<256xindex>, vector<256xi1> gpu.return } diff --git a/mlir/test/Target/LLVMIR/openmp-cli-tile01.mlir b/mlir/test/Target/LLVMIR/openmp-cli-tile01.mlir new file mode 100644 index 0000000..0d559b6 --- /dev/null +++ b/mlir/test/Target/LLVMIR/openmp-cli-tile01.mlir @@ -0,0 +1,94 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s --enable-var-scope + + +llvm.func @tile_trivial_loop(%baseptr: !llvm.ptr, %tc: i32, %ts: i32) -> () { + %literal_cli = omp.new_cli + omp.canonical_loop(%literal_cli) %iv : i32 in range(%tc) { + %ptr = llvm.getelementptr inbounds %baseptr[%iv] : (!llvm.ptr, i32) -> !llvm.ptr, f32 + %val = llvm.mlir.constant(42.0 : f32) : f32 + llvm.store %val, %ptr : f32, !llvm.ptr + omp.terminator + } + omp.tile <- (%literal_cli) sizes(%ts : i32) + llvm.return +} + + +// CHECK-LABEL: define void @tile_trivial_loop( +// CHECK-SAME: ptr %[[TMP0:.+]], i32 %[[TMP1:.+]], i32 %[[TMP2:.+]]) { +// CHECK-NEXT: br label %[[OMP_OMP_LOOP_PREHEADER:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_OMP_LOOP_PREHEADER]]: +// CHECK-NEXT: %[[TMP4:.+]] = udiv i32 %[[TMP1:.+]], %[[TMP2:.+]] +// CHECK-NEXT: %[[TMP5:.+]] = urem i32 %[[TMP1:.+]], %[[TMP2:.+]] +// CHECK-NEXT: %[[TMP6:.+]] = icmp ne i32 
%[[TMP5:.+]], 0 +// CHECK-NEXT: %[[TMP7:.+]] = zext i1 %[[TMP6:.+]] to i32 +// CHECK-NEXT: %[[OMP_FLOOR0_TRIPCOUNT:.+]] = add nuw i32 %[[TMP4:.+]], %[[TMP7:.+]] +// CHECK-NEXT: br label %[[OMP_FLOOR0_PREHEADER:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_FLOOR0_PREHEADER]]: +// CHECK-NEXT: br label %[[OMP_FLOOR0_HEADER:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_FLOOR0_HEADER]]: +// CHECK-NEXT: %[[OMP_FLOOR0_IV:.+]] = phi i32 [ 0, %[[OMP_FLOOR0_PREHEADER:.+]] ], [ %[[OMP_FLOOR0_NEXT:.+]], %[[OMP_FLOOR0_INC:.+]] ] +// CHECK-NEXT: br label %[[OMP_FLOOR0_COND:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_FLOOR0_COND]]: +// CHECK-NEXT: %[[OMP_FLOOR0_CMP:.+]] = icmp ult i32 %[[OMP_FLOOR0_IV:.+]], %[[OMP_FLOOR0_TRIPCOUNT:.+]] +// CHECK-NEXT: br i1 %[[OMP_FLOOR0_CMP:.+]], label %[[OMP_FLOOR0_BODY:.+]], label %[[OMP_FLOOR0_EXIT:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_FLOOR0_BODY]]: +// CHECK-NEXT: %[[TMP8:.+]] = icmp eq i32 %[[OMP_FLOOR0_IV:.+]], %[[TMP4:.+]] +// CHECK-NEXT: %[[TMP9:.+]] = select i1 %[[TMP8:.+]], i32 %[[TMP5:.+]], i32 %[[TMP2:.+]] +// CHECK-NEXT: br label %[[OMP_TILE0_PREHEADER:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_TILE0_PREHEADER]]: +// CHECK-NEXT: br label %[[OMP_TILE0_HEADER:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_TILE0_HEADER]]: +// CHECK-NEXT: %[[OMP_TILE0_IV:.+]] = phi i32 [ 0, %[[OMP_TILE0_PREHEADER:.+]] ], [ %[[OMP_TILE0_NEXT:.+]], %[[OMP_TILE0_INC:.+]] ] +// CHECK-NEXT: br label %[[OMP_TILE0_COND:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_TILE0_COND]]: +// CHECK-NEXT: %[[OMP_TILE0_CMP:.+]] = icmp ult i32 %[[OMP_TILE0_IV:.+]], %[[TMP9:.+]] +// CHECK-NEXT: br i1 %[[OMP_TILE0_CMP:.+]], label %[[OMP_TILE0_BODY:.+]], label %[[OMP_TILE0_EXIT:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_TILE0_BODY]]: +// CHECK-NEXT: %[[TMP10:.+]] = mul nuw i32 %[[TMP2:.+]], %[[OMP_FLOOR0_IV:.+]] +// CHECK-NEXT: %[[TMP11:.+]] = add nuw i32 %[[TMP10:.+]], %[[OMP_TILE0_IV:.+]] +// CHECK-NEXT: br label %[[OMP_OMP_LOOP_BODY:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: 
[[OMP_OMP_LOOP_BODY]]: +// CHECK-NEXT: br label %[[OMP_LOOP_REGION:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_LOOP_REGION]]: +// CHECK-NEXT: %[[TMP12:.+]] = getelementptr inbounds float, ptr %[[TMP0:.+]], i32 %[[TMP11:.+]] +// CHECK-NEXT: store float 4.200000e+01, ptr %[[TMP12:.+]], align 4 +// CHECK-NEXT: br label %[[OMP_REGION_CONT:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_REGION_CONT]]: +// CHECK-NEXT: br label %[[OMP_TILE0_INC:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_TILE0_INC]]: +// CHECK-NEXT: %[[OMP_TILE0_NEXT:.+]] = add nuw i32 %[[OMP_TILE0_IV:.+]], 1 +// CHECK-NEXT: br label %[[OMP_TILE0_HEADER:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_TILE0_EXIT]]: +// CHECK-NEXT: br label %[[OMP_TILE0_AFTER:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_TILE0_AFTER]]: +// CHECK-NEXT: br label %[[OMP_FLOOR0_INC:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_FLOOR0_INC]]: +// CHECK-NEXT: %[[OMP_FLOOR0_NEXT:.+]] = add nuw i32 %[[OMP_FLOOR0_IV:.+]], 1 +// CHECK-NEXT: br label %[[OMP_FLOOR0_HEADER:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_FLOOR0_EXIT]]: +// CHECK-NEXT: br label %[[OMP_FLOOR0_AFTER:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_FLOOR0_AFTER]]: +// CHECK-NEXT: br label %[[OMP_OMP_LOOP_AFTER:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_OMP_LOOP_AFTER]]: +// CHECK-NEXT: ret void +// CHECK-NEXT: } diff --git a/mlir/test/Target/LLVMIR/openmp-cli-tile02.mlir b/mlir/test/Target/LLVMIR/openmp-cli-tile02.mlir new file mode 100644 index 0000000..22c2973 --- /dev/null +++ b/mlir/test/Target/LLVMIR/openmp-cli-tile02.mlir @@ -0,0 +1,184 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s --enable-var-scope + + +llvm.func @tile_2d_loop(%baseptr: !llvm.ptr, %tc1: i32, %tc2: i32, %ts1: i32, %ts2: i32) -> () { + %literal_outer = omp.new_cli + %literal_inner = omp.new_cli + omp.canonical_loop(%literal_outer) %iv1 : i32 in range(%tc1) { + omp.canonical_loop(%literal_inner) %iv2 : i32 in range(%tc2) { + %idx = llvm.add %iv1, %iv2 : i32 + %ptr = llvm.getelementptr 
inbounds %baseptr[%idx] : (!llvm.ptr, i32) -> !llvm.ptr, f32 + %val = llvm.mlir.constant(42.0 : f32) : f32 + llvm.store %val, %ptr : f32, !llvm.ptr + omp.terminator + } + omp.terminator + } + omp.tile <- (%literal_outer, %literal_inner) sizes(%ts1, %ts2 : i32,i32) + llvm.return +} + + +// CHECK-LABEL: define void @tile_2d_loop( +// CHECK-SAME: ptr %[[TMP0:.+]], i32 %[[TMP1:.+]], i32 %[[TMP2:.+]], i32 %[[TMP3:.+]], i32 %[[TMP4:.+]]) { +// CHECK-NEXT: br label %[[OMP_OMP_LOOP_PREHEADER:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_OMP_LOOP_PREHEADER]]: +// CHECK-NEXT: %[[TMP6:.+]] = udiv i32 %[[TMP1:.+]], %[[TMP3:.+]] +// CHECK-NEXT: %[[TMP7:.+]] = urem i32 %[[TMP1:.+]], %[[TMP3:.+]] +// CHECK-NEXT: %[[TMP8:.+]] = icmp ne i32 %[[TMP7:.+]], 0 +// CHECK-NEXT: %[[TMP9:.+]] = zext i1 %[[TMP8:.+]] to i32 +// CHECK-NEXT: %[[OMP_FLOOR0_TRIPCOUNT:.+]] = add nuw i32 %[[TMP6:.+]], %[[TMP9:.+]] +// CHECK-NEXT: %[[TMP10:.+]] = udiv i32 %[[TMP2:.+]], %[[TMP4:.+]] +// CHECK-NEXT: %[[TMP11:.+]] = urem i32 %[[TMP2:.+]], %[[TMP4:.+]] +// CHECK-NEXT: %[[TMP12:.+]] = icmp ne i32 %[[TMP11:.+]], 0 +// CHECK-NEXT: %[[TMP13:.+]] = zext i1 %[[TMP12:.+]] to i32 +// CHECK-NEXT: %[[OMP_FLOOR1_TRIPCOUNT:.+]] = add nuw i32 %[[TMP10:.+]], %[[TMP13:.+]] +// CHECK-NEXT: br label %[[OMP_FLOOR0_PREHEADER:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_OMP_LOOP_HEADER:.+]]: +// CHECK-NEXT: %[[OMP_OMP_LOOP_IV:.+]] = phi i32 [ %[[OMP_OMP_LOOP_NEXT:.+]], %[[OMP_OMP_LOOP_INC:.+]] ] +// CHECK-NEXT: br label %[[OMP_OMP_LOOP_COND:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_OMP_LOOP_COND]]: +// CHECK-NEXT: %[[OMP_OMP_LOOP_CMP:.+]] = icmp ult i32 %[[TMP19:.+]], %[[TMP1:.+]] +// CHECK-NEXT: br i1 %[[OMP_OMP_LOOP_CMP:.+]], label %[[OMP_OMP_LOOP_BODY:.+]], label %[[OMP_OMP_LOOP_EXIT:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_OMP_LOOP_BODY]]: +// CHECK-NEXT: br label %[[OMP_LOOP_REGION:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_LOOP_REGION]]: +// CHECK-NEXT: br label %[[OMP_OMP_LOOP_PREHEADER1:.+]] +// CHECK-EMPTY: 
+// CHECK-NEXT: [[OMP_OMP_LOOP_PREHEADER1]]: +// CHECK-NEXT: br label %[[OMP_OMP_LOOP_BODY4:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_FLOOR0_PREHEADER]]: +// CHECK-NEXT: br label %[[OMP_FLOOR0_HEADER:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_FLOOR0_HEADER]]: +// CHECK-NEXT: %[[OMP_FLOOR0_IV:.+]] = phi i32 [ 0, %[[OMP_FLOOR0_PREHEADER:.+]] ], [ %[[OMP_FLOOR0_NEXT:.+]], %[[OMP_FLOOR0_INC:.+]] ] +// CHECK-NEXT: br label %[[OMP_FLOOR0_COND:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_FLOOR0_COND]]: +// CHECK-NEXT: %[[OMP_FLOOR0_CMP:.+]] = icmp ult i32 %[[OMP_FLOOR0_IV:.+]], %[[OMP_FLOOR0_TRIPCOUNT:.+]] +// CHECK-NEXT: br i1 %[[OMP_FLOOR0_CMP:.+]], label %[[OMP_FLOOR0_BODY:.+]], label %[[OMP_FLOOR0_EXIT:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_FLOOR0_BODY]]: +// CHECK-NEXT: br label %[[OMP_FLOOR1_PREHEADER:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_FLOOR1_PREHEADER]]: +// CHECK-NEXT: br label %[[OMP_FLOOR1_HEADER:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_FLOOR1_HEADER]]: +// CHECK-NEXT: %[[OMP_FLOOR1_IV:.+]] = phi i32 [ 0, %[[OMP_FLOOR1_PREHEADER:.+]] ], [ %[[OMP_FLOOR1_NEXT:.+]], %[[OMP_FLOOR1_INC:.+]] ] +// CHECK-NEXT: br label %[[OMP_FLOOR1_COND:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_FLOOR1_COND]]: +// CHECK-NEXT: %[[OMP_FLOOR1_CMP:.+]] = icmp ult i32 %[[OMP_FLOOR1_IV:.+]], %[[OMP_FLOOR1_TRIPCOUNT:.+]] +// CHECK-NEXT: br i1 %[[OMP_FLOOR1_CMP:.+]], label %[[OMP_FLOOR1_BODY:.+]], label %[[OMP_FLOOR1_EXIT:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_FLOOR1_BODY]]: +// CHECK-NEXT: %[[TMP14:.+]] = icmp eq i32 %[[OMP_FLOOR0_IV:.+]], %[[TMP6:.+]] +// CHECK-NEXT: %[[TMP15:.+]] = select i1 %[[TMP14:.+]], i32 %[[TMP7:.+]], i32 %[[TMP3:.+]] +// CHECK-NEXT: %[[TMP16:.+]] = icmp eq i32 %[[OMP_FLOOR1_IV:.+]], %[[TMP10:.+]] +// CHECK-NEXT: %[[TMP17:.+]] = select i1 %[[TMP16:.+]], i32 %[[TMP11:.+]], i32 %[[TMP4:.+]] +// CHECK-NEXT: br label %[[OMP_TILE0_PREHEADER:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_TILE0_PREHEADER]]: +// CHECK-NEXT: br label 
%[[OMP_TILE0_HEADER:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_TILE0_HEADER]]: +// CHECK-NEXT: %[[OMP_TILE0_IV:.+]] = phi i32 [ 0, %[[OMP_TILE0_PREHEADER:.+]] ], [ %[[OMP_TILE0_NEXT:.+]], %[[OMP_TILE0_INC:.+]] ] +// CHECK-NEXT: br label %[[OMP_TILE0_COND:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_TILE0_COND]]: +// CHECK-NEXT: %[[OMP_TILE0_CMP:.+]] = icmp ult i32 %[[OMP_TILE0_IV:.+]], %[[TMP15:.+]] +// CHECK-NEXT: br i1 %[[OMP_TILE0_CMP:.+]], label %[[OMP_TILE0_BODY:.+]], label %[[OMP_TILE0_EXIT:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_TILE0_BODY]]: +// CHECK-NEXT: br label %[[OMP_TILE1_PREHEADER:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_TILE1_PREHEADER]]: +// CHECK-NEXT: br label %[[OMP_TILE1_HEADER:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_TILE1_HEADER]]: +// CHECK-NEXT: %[[OMP_TILE1_IV:.+]] = phi i32 [ 0, %[[OMP_TILE1_PREHEADER:.+]] ], [ %[[OMP_TILE1_NEXT:.+]], %[[OMP_TILE1_INC:.+]] ] +// CHECK-NEXT: br label %[[OMP_TILE1_COND:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_TILE1_COND]]: +// CHECK-NEXT: %[[OMP_TILE1_CMP:.+]] = icmp ult i32 %[[OMP_TILE1_IV:.+]], %[[TMP17:.+]] +// CHECK-NEXT: br i1 %[[OMP_TILE1_CMP:.+]], label %[[OMP_TILE1_BODY:.+]], label %[[OMP_TILE1_EXIT:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_TILE1_BODY]]: +// CHECK-NEXT: %[[TMP18:.+]] = mul nuw i32 %[[TMP3:.+]], %[[OMP_FLOOR0_IV:.+]] +// CHECK-NEXT: %[[TMP19:.+]] = add nuw i32 %[[TMP18:.+]], %[[OMP_TILE0_IV:.+]] +// CHECK-NEXT: %[[TMP20:.+]] = mul nuw i32 %[[TMP4:.+]], %[[OMP_FLOOR1_IV:.+]] +// CHECK-NEXT: %[[TMP21:.+]] = add nuw i32 %[[TMP20:.+]], %[[OMP_TILE1_IV:.+]] +// CHECK-NEXT: br label %[[OMP_OMP_LOOP_BODY:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_OMP_LOOP_BODY4]]: +// CHECK-NEXT: br label %[[OMP_LOOP_REGION12:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_LOOP_REGION12]]: +// CHECK-NEXT: %[[TMP22:.+]] = add i32 %[[TMP19:.+]], %[[TMP21:.+]] +// CHECK-NEXT: %[[TMP23:.+]] = getelementptr inbounds float, ptr %[[TMP0:.+]], i32 %[[TMP22:.+]] +// CHECK-NEXT: store float 4.200000e+01, 
ptr %[[TMP23:.+]], align 4 +// CHECK-NEXT: br label %[[OMP_REGION_CONT11:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_REGION_CONT11]]: +// CHECK-NEXT: br label %[[OMP_TILE1_INC:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_TILE1_INC]]: +// CHECK-NEXT: %[[OMP_TILE1_NEXT:.+]] = add nuw i32 %[[OMP_TILE1_IV:.+]], 1 +// CHECK-NEXT: br label %[[OMP_TILE1_HEADER:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_TILE1_EXIT]]: +// CHECK-NEXT: br label %[[OMP_TILE1_AFTER:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_TILE1_AFTER]]: +// CHECK-NEXT: br label %[[OMP_TILE0_INC:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_TILE0_INC]]: +// CHECK-NEXT: %[[OMP_TILE0_NEXT:.+]] = add nuw i32 %[[OMP_TILE0_IV:.+]], 1 +// CHECK-NEXT: br label %[[OMP_TILE0_HEADER:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_TILE0_EXIT]]: +// CHECK-NEXT: br label %[[OMP_TILE0_AFTER:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_TILE0_AFTER]]: +// CHECK-NEXT: br label %[[OMP_FLOOR1_INC:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_FLOOR1_INC]]: +// CHECK-NEXT: %[[OMP_FLOOR1_NEXT:.+]] = add nuw i32 %[[OMP_FLOOR1_IV:.+]], 1 +// CHECK-NEXT: br label %[[OMP_FLOOR1_HEADER:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_FLOOR1_EXIT]]: +// CHECK-NEXT: br label %[[OMP_FLOOR1_AFTER:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_FLOOR1_AFTER]]: +// CHECK-NEXT: br label %[[OMP_FLOOR0_INC:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_FLOOR0_INC]]: +// CHECK-NEXT: %[[OMP_FLOOR0_NEXT:.+]] = add nuw i32 %[[OMP_FLOOR0_IV:.+]], 1 +// CHECK-NEXT: br label %[[OMP_FLOOR0_HEADER:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_FLOOR0_EXIT]]: +// CHECK-NEXT: br label %[[OMP_FLOOR0_AFTER:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_FLOOR0_AFTER]]: +// CHECK-NEXT: br label %[[OMP_OMP_LOOP_AFTER:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_REGION_CONT:.+]]: +// CHECK-NEXT: br label %[[OMP_OMP_LOOP_INC:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_OMP_LOOP_INC]]: +// CHECK-NEXT: %[[OMP_OMP_LOOP_NEXT:.+]] = add nuw i32 %[[TMP19:.+]], 1 +// CHECK-NEXT: br label 
%[[OMP_OMP_LOOP_HEADER:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_OMP_LOOP_EXIT]]: +// CHECK-NEXT: br label %[[OMP_OMP_LOOP_AFTER:.+]] +// CHECK-EMPTY: +// CHECK-NEXT: [[OMP_OMP_LOOP_AFTER]]: +// CHECK-NEXT: ret void +// CHECK-NEXT: } diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir index e043a8c..00ee6b7 100644 --- a/mlir/test/Target/LLVMIR/rocdl.mlir +++ b/mlir/test/Target/LLVMIR/rocdl.mlir @@ -1340,6 +1340,34 @@ llvm.func @rocdl.cvt.scale.pk8(%i32: i32, %v2xi32: vector<2xi32>, %scale: i32) { llvm.return } +// CHECK-LABEL: rocdl.cvt.scalef32.pk8 +// CHECK-SAME:(<8 x float> %[[V8F32:.+]], <8 x half> %[[V8F16:.+]], <8 x bfloat> %[[V8BF16:.+]], float %[[SCALE:.+]]) +llvm.func @rocdl.cvt.scalef32.pk8(%v8xf32: vector<8xf32>, %v8xf16: vector<8xf16>, %v8xbf16: vector<8xbf16>, %scale: f32) { + + // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.pk8.fp8.f32(<8 x float> %[[V8F32]], float %[[SCALE]]) + %0 = rocdl.cvt.scalef32.pk8.fp8.f32 %v8xf32, %scale : vector<2xi32> + // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.pk8.bf8.f32(<8 x float> %[[V8F32]], float %[[SCALE]]) + %1 = rocdl.cvt.scalef32.pk8.bf8.f32 %v8xf32, %scale : vector<2xi32> + // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.pk8.fp4.f32(<8 x float> %[[V8F32]], float %[[SCALE]]) + %2 = rocdl.cvt.scalef32.pk8.fp4.f32 %v8xf32, %scale : i32 + + // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.pk8.fp8.f16(<8 x half> %[[V8F16]], float %[[SCALE]]) + %3 = rocdl.cvt.scalef32.pk8.fp8.f16 %v8xf16, %scale : vector<2xi32> + // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.pk8.bf8.f16(<8 x half> %[[V8F16]], float %[[SCALE]]) + %4 = rocdl.cvt.scalef32.pk8.bf8.f16 %v8xf16, %scale : vector<2xi32> + // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.pk8.fp4.f16(<8 x half> %[[V8F16]], float %[[SCALE]]) + %5 = rocdl.cvt.scalef32.pk8.fp4.f16 %v8xf16, %scale : i32 + + // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.pk8.fp8.bf16(<8 x bfloat> %[[V8BF16]], float %[[SCALE]]) + %6 = 
rocdl.cvt.scalef32.pk8.fp8.bf16 %v8xbf16, %scale : vector<2xi32> + // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.pk8.bf8.bf16(<8 x bfloat> %[[V8BF16]], float %[[SCALE]]) + %7 = rocdl.cvt.scalef32.pk8.bf8.bf16 %v8xbf16, %scale : vector<2xi32> + // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.pk8.fp4.bf16(<8 x bfloat> %[[V8BF16]], float %[[SCALE]]) + %8 = rocdl.cvt.scalef32.pk8.fp4.bf16 %v8xbf16, %scale : i32 + + llvm.return +} + // CHECK-LABEL: @rocdl.cvt.scale.pk16 // CHECK-SAME:(<3 x i32> %[[SRC0:.+]], i32 %[[SCALE:.+]]) llvm.func @rocdl.cvt.scale.pk16(%v3xi32: vector<3xi32>, %scale:i32) { diff --git a/mlir/test/lib/Dialect/TestIRDLToCpp/CMakeLists.txt b/mlir/test/lib/Dialect/TestIRDLToCpp/CMakeLists.txt index 103bc94..7d32577 100644 --- a/mlir/test/lib/Dialect/TestIRDLToCpp/CMakeLists.txt +++ b/mlir/test/lib/Dialect/TestIRDLToCpp/CMakeLists.txt @@ -12,5 +12,7 @@ add_mlir_library(MLIRTestIRDLToCppDialect mlir_target_link_libraries(MLIRTestIRDLToCppDialect PUBLIC MLIRIR MLIRPass + MLIRSCFDialect MLIRTransforms + MLIRTestDialect ) diff --git a/mlir/test/lib/Dialect/TestIRDLToCpp/TestIRDLToCppDialect.cpp b/mlir/test/lib/Dialect/TestIRDLToCpp/TestIRDLToCppDialect.cpp index 9550e4c..421db7e 100644 --- a/mlir/test/lib/Dialect/TestIRDLToCpp/TestIRDLToCppDialect.cpp +++ b/mlir/test/lib/Dialect/TestIRDLToCpp/TestIRDLToCppDialect.cpp @@ -13,6 +13,7 @@ // #include "mlir/IR/Dialect.h" #include "mlir/IR/Region.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/Interfaces/InferTypeOpInterface.h" @@ -54,16 +55,34 @@ struct TestOpConversion : public OpConversionPattern<test_irdl_to_cpp::BeefOp> { } }; +struct TestRegionConversion + : public OpConversionPattern<test_irdl_to_cpp::ConditionalOp> { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(mlir::test_irdl_to_cpp::ConditionalOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + 
// Just exercising the C++ API even though these are not enforced in the + // dialect definition + assert(op.getThen().getBlocks().size() == 1); + assert(adaptor.getElse().getBlocks().size() == 1); + auto ifOp = scf::IfOp::create(rewriter, op.getLoc(), op.getInput()); + rewriter.replaceOp(op, ifOp); + return success(); + } +}; + struct ConvertTestDialectToSomethingPass : PassWrapper<ConvertTestDialectToSomethingPass, OperationPass<ModuleOp>> { void runOnOperation() override { MLIRContext *ctx = &getContext(); RewritePatternSet patterns(ctx); - patterns.add<TestOpConversion>(ctx); + patterns.add<TestOpConversion, TestRegionConversion>(ctx); ConversionTarget target(getContext()); - target.addIllegalOp<test_irdl_to_cpp::BeefOp>(); - target.addLegalOp<test_irdl_to_cpp::BarOp>(); - target.addLegalOp<test_irdl_to_cpp::HashOp>(); + target.addIllegalOp<test_irdl_to_cpp::BeefOp, + test_irdl_to_cpp::ConditionalOp>(); + target.addLegalOp<test_irdl_to_cpp::BarOp, test_irdl_to_cpp::HashOp, + scf::IfOp, scf::YieldOp>(); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); @@ -73,6 +92,10 @@ struct ConvertTestDialectToSomethingPass StringRef getDescription() const final { return "Checks the convertability of an irdl dialect"; } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert<scf::SCFDialect>(); + } }; void registerIrdlTestDialect(mlir::DialectRegistry ®istry) { diff --git a/mlir/test/lib/Dialect/TestIRDLToCpp/test_conversion.testd.mlir b/mlir/test/lib/Dialect/TestIRDLToCpp/test_conversion.testd.mlir index f6233ee..1915324 100644 --- a/mlir/test/lib/Dialect/TestIRDLToCpp/test_conversion.testd.mlir +++ b/mlir/test/lib/Dialect/TestIRDLToCpp/test_conversion.testd.mlir @@ -1,15 +1,29 @@ // RUN: mlir-opt %s --pass-pipeline="builtin.module(test-irdl-conversion-check)" | FileCheck %s // CHECK-LABEL: module { module { - // CHECK: func.func @test() { + // CHECK: func.func @test(%[[test_arg:[^ ]*]]: 
i1) { // CHECK: %[[v0:[^ ]*]] = "test_irdl_to_cpp.bar"() : () -> i32 // CHECK: %[[v1:[^ ]*]] = "test_irdl_to_cpp.bar"() : () -> i32 // CHECK: %[[v2:[^ ]*]] = "test_irdl_to_cpp.hash"(%[[v0]], %[[v0]]) : (i32, i32) -> i32 + // CHECK: scf.if %[[test_arg]] // CHECK: return // CHECK: } - func.func @test() { + func.func @test(%test_arg: i1) { %0 = "test_irdl_to_cpp.bar"() : () -> i32 %1 = "test_irdl_to_cpp.beef"(%0, %0) : (i32, i32) -> i32 + "test_irdl_to_cpp.conditional"(%test_arg) ({ + ^cond(%test: i1): + %3 = "test_irdl_to_cpp.bar"() : () -> i32 + "test.terminator"() : ()->() + }, { + ^then(%what: i1, %ever: i32): + %4 = "test_irdl_to_cpp.bar"() : () -> i32 + "test.terminator"() : ()->() + }, { + ^else(): + %5 = "test_irdl_to_cpp.bar"() : () -> i32 + "test.terminator"() : ()->() + }) : (i1) -> () return } diff --git a/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp.irdl.mlir b/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp.irdl.mlir index 42e713e..85fb8cb 100644 --- a/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp.irdl.mlir +++ b/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp.irdl.mlir @@ -2,7 +2,7 @@ // CHECK: class TestIrdlToCpp irdl.dialect @test_irdl_to_cpp { - + // CHECK: class FooType irdl.type @foo @@ -32,4 +32,53 @@ irdl.dialect @test_irdl_to_cpp { irdl.operands(lhs: %0, rhs: %0) irdl.results(res: %0) } + + // CHECK: ConditionalOp declarations + // CHECK: ConditionalOpGenericAdaptorBase + // CHECK: ::mlir::Region &getCond() { return *getRegions()[0]; } + // CHECK: ::mlir::Region &getThen() { return *getRegions()[1]; } + // CHECK: ::mlir::Region &getElse() { return *getRegions()[2]; } + // + // CHECK: class ConditionalOp : public ::mlir::Op<ConditionalOp, ::mlir::OpTrait::NRegions<3>::Impl, ::mlir::OpTrait::OpInvariants> + // CHECK: ::mlir::Region &getCond() { return (*this)->getRegion(0); } + // CHECK: ::mlir::Region &getThen() { return (*this)->getRegion(1); } + // CHECK: ::mlir::Region &getElse() { return (*this)->getRegion(2); } + + 
// CHECK: ConditionalOp definitions + // CHECK: __mlir_irdl_local_region_constraint_ConditionalOp_cond + // CHECK: if (!(region.getNumArguments() == 1)) { + // CHECK: failed to verify constraint: region with 1 entry block argument(s) + + // CHECK: __mlir_irdl_local_region_constraint_ConditionalOp_then + // CHECK: if (!(true)) { + + // CHECK: __mlir_irdl_local_region_constraint_ConditionalOp_else + // CHECK: if (!(region.getNumArguments() == 0)) { + // CHECK: failed to verify constraint: region with 0 entry block argument(s) + + // CHECK: ConditionalOp::build + // CHECK: for (unsigned i = 0; i != 3; ++i) + // CHECK-NEXT: (void)odsState.addRegion(); + + // CHECK: ConditionalOp::verifyInvariantsImpl + // CHECK: __mlir_irdl_local_region_constraint_ConditionalOp_cond + // CHECK: failure + // CHECK: __mlir_irdl_local_region_constraint_ConditionalOp_then + // CHECK: failure + // CHECK: __mlir_irdl_local_region_constraint_ConditionalOp_else + // CHECK: failure + // CHECK: success + irdl.operation @conditional { + %r0 = irdl.region // Unconstrained region + %r1 = irdl.region() // Region with no entry block arguments + + // TODO(#161018): support irdl.is in irdl-to-cpp + // %v0 = irdl.is i1 // Type constraint: i1 (boolean) + %v0 = irdl.any + %r2 = irdl.region(%v0) // Region with one i1 entry block argument + irdl.regions(cond: %r2, then: %r0, else: %r1) + + %0 = irdl.any + irdl.operands(input: %0) + } } diff --git a/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp_invalid_unsupported_types.irdl.mlir b/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp_invalid_unsupported_types.irdl.mlir index 403b492..cc27456 100644 --- a/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp_invalid_unsupported_types.irdl.mlir +++ b/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp_invalid_unsupported_types.irdl.mlir @@ -7,7 +7,7 @@ irdl.dialect @test_irdl_to_cpp { irdl.results(res: %1) } } -// ----- +// ----- irdl.dialect @test_irdl_to_cpp { irdl.operation @operands_no_any_of { @@ 
-42,7 +42,7 @@ irdl.dialect @test_irdl_to_cpp { irdl.dialect @test_irdl_to_cpp { irdl.type @ty { - %0 = irdl.any + %0 = irdl.any // expected-error@+1 {{IRDL C++ translation does not yet support translation of irdl.parameters operation}} irdl.parameters(ty: %0) } @@ -51,29 +51,8 @@ irdl.dialect @test_irdl_to_cpp { // ----- irdl.dialect @test_irdl_to_cpp { - irdl.operation @test_op { - // expected-error@+1 {{IRDL C++ translation does not yet support translation of irdl.region operation}} - %0 = irdl.region() - irdl.regions(reg: %0) - } - -} - -// ----- - -irdl.dialect @test_irdl_to_cpp { - irdl.operation @test_op { - // expected-error@+1 {{IRDL C++ translation does not yet support translation of irdl.regions operation}} - irdl.regions() - } - -} - -// ----- - -irdl.dialect @test_irdl_to_cpp { irdl.type @test_derived { // expected-error@+1 {{IRDL C++ translation does not yet support translation of irdl.base operation}} %0 = irdl.base "!builtin.integer" - } + } } diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp index 094ef0a..e51cac4 100644 --- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp +++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp @@ -173,8 +173,6 @@ struct TestXeGPUUnrollingPatterns #undef DEBUG_TYPE #define DEBUG_TYPE "test-xegpu-layout-interface" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") // Test pattern for distributing vector::StepOp from workgroup to subgroup. 
// Validates DistributeLayoutAttr interfaces for offset computation diff --git a/mlir/test/mlir-tblgen/op-format-invalid.td b/mlir/test/mlir-tblgen/op-format-invalid.td index 2f29543..0a022ad 100644 --- a/mlir/test/mlir-tblgen/op-format-invalid.td +++ b/mlir/test/mlir-tblgen/op-format-invalid.td @@ -307,7 +307,7 @@ def DirectiveTypeZOperandInvalidI : TestFormat_Op<[{ def LiteralInvalidA : TestFormat_Op<[{ `a:` }]>; -// CHECK: error: expected valid literal but got '1': single character literal must be a letter or one of '_:,=<>()[]{}?+*' +// CHECK: error: expected valid literal but got '1': single character literal must be a letter or one of '_:,=<>()[]{}?+-*' def LiteralInvalidB : TestFormat_Op<[{ `1` }]>; diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td index 1541cd0..1ac2311 100644 --- a/mlir/test/mlir-tblgen/op-format-spec.td +++ b/mlir/test/mlir-tblgen/op-format-spec.td @@ -123,7 +123,7 @@ def DirectiveTypeValid : TestFormat_Op<[{ // CHECK-NOT: error def LiteralValid : TestFormat_Op<[{ - `_` `:` `,` `=` `<` `>` `(` `)` `[` `]` `?` `+` `*` ` ` `` `->` `\n` `abc$._` + `_` `:` `,` `=` `<` `>` `(` `)` `[` `]` `?` `+` `-` `*` ` ` `` `->` `\n` `abc$._` attr-dict }]>; diff --git a/mlir/test/python/dialects/transform_tune_ext.py b/mlir/test/python/dialects/transform_tune_ext.py index dfb9359..eb2a083 100644 --- a/mlir/test/python/dialects/transform_tune_ext.py +++ b/mlir/test/python/dialects/transform_tune_ext.py @@ -1,21 +1,21 @@ # RUN: %PYTHON %s | FileCheck %s -from mlir.ir import * +from mlir import ir from mlir.dialects import transform from mlir.dialects.transform import tune, debug def run(f): - print("\nTEST:", f.__name__) - with Context(), Location.unknown(): - module = Module.create() - with InsertionPoint(module.body): + print("\n// TEST:", f.__name__) + with ir.Context(), ir.Location.unknown(): + module = ir.Module.create() + with ir.InsertionPoint(module.body): sequence = transform.SequenceOp( 
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get(), ) - with InsertionPoint(sequence.body): + with ir.InsertionPoint(sequence.body): f(sequence.bodyTarget) transform.YieldOp() print(module) @@ -29,10 +29,10 @@ def testKnobOp(target): # CHECK: %[[HEADS_OR_TAILS:.*]] = transform.tune.knob<"coin"> options = [true, false] -> !transform.any_param heads_or_tails = tune.KnobOp( - result=any_param, name=StringAttr.get("coin"), options=[True, False] + result=any_param, name=ir.StringAttr.get("coin"), options=[True, False] ) # CHECK: transform.tune.knob<"animal"> options = ["cat", "dog", unit] -> !transform.any_param - tune.KnobOp(any_param, name="animal", options=["cat", "dog", UnitAttr.get()]) + tune.KnobOp(any_param, name="animal", options=["cat", "dog", ir.UnitAttr.get()]) # CHECK: transform.tune.knob<"tile_size"> options = [2, 4, 8, 16, 24, 32] -> !transform.any_param tune.KnobOp(any_param, "tile_size", [2, 4, 8, 16, 24, 32]) # CHECK: transform.tune.knob<"magic_value"> options = [2.000000e+00, 2.250000e+00, 2.500000e+00, 2.750000e+00, 3.000000e+00] -> !transform.any_param @@ -45,7 +45,10 @@ def testKnobOp(target): heads = tune.KnobOp(any_param, "coin", options=[True, False], selected=True) # CHECK: transform.tune.knob<"animal"> = "dog" from options = ["cat", "dog", unit] -> !transform.any_param tune.KnobOp( - any_param, name="animal", options=["cat", "dog", UnitAttr.get()], selected="dog" + any_param, + name="animal", + options=["cat", "dog", ir.UnitAttr.get()], + selected="dog", ) # CHECK: transform.tune.knob<"tile_size"> = 8 : i64 from options = [2, 4, 8, 16, 24, 32] -> !transform.any_param tune.KnobOp(any_param, "tile_size", [2, 4, 8, 16, 24, 32], selected=8) @@ -57,16 +60,90 @@ def testKnobOp(target): # CHECK: transform.tune.knob<"range_as_a_dict"> = 4 : i64 from options = {start = 2 : i64, step = 2 : i64, stop = 16 : i64} -> !transform.any_param # NB: Membership of `selected` in non-ArrayAttr `options` is _not_ verified. 
- i64 = IntegerType.get_signless(64) + i64 = ir.IntegerType.get_signless(64) tune.knob( any_param, "range_as_a_dict", - DictAttr.get( + ir.DictAttr.get( { - "start": IntegerAttr.get(i64, 2), - "stop": IntegerAttr.get(i64, 16), - "step": IntegerAttr.get(i64, 2), + "start": ir.IntegerAttr.get(i64, 2), + "stop": ir.IntegerAttr.get(i64, 16), + "step": ir.IntegerAttr.get(i64, 2), } ), selected=4, ) + + +# CHECK-LABEL: TEST: testAlternativesOp +@run +def testAlternativesOp(target): + any_param = transform.AnyParamType.get() + + # CHECK: %[[LEFT_OR_RIGHT_OUTCOME:.*]] = transform.tune.alternatives<"left_or_right"> -> !transform.any_param { + left_or_right = tune.AlternativesOp( + [transform.AnyParamType.get()], "left_or_right", 2 + ) + idx_for_left, idx_for_right = 0, 1 + with ir.InsertionPoint(left_or_right.alternatives[idx_for_left].blocks[0]): + # CHECK: %[[C0:.*]] = transform.param.constant 0 + i32_0 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0) + c0 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_0) + # CHECK: transform.yield %[[C0]] + transform.yield_(c0) + # CHECK-NEXT: }, { + with ir.InsertionPoint(left_or_right.alternatives[idx_for_right].blocks[0]): + # CHECK: %[[C1:.*]] = transform.param.constant 1 + i32_1 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 1) + c1 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_1) + # CHECK: transform.yield %[[C1]] + transform.yield_(c1) + # CHECK-NEXT: } + outcome_of_left_or_right_decision = left_or_right.results[0] + + # CHECK: transform.tune.alternatives<"fork_in_the_road"> selected_region = 0 -> !transform.any_param { + fork_in_the_road = tune.AlternativesOp( + [transform.AnyParamType.get()], "fork_in_the_road", 2, selected_region=0 + ) + with ir.InsertionPoint(fork_in_the_road.alternatives[idx_for_left].blocks[0]): + # CHECK: %[[C0:.*]] = transform.param.constant 0 + i32_0 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0) + c0 = 
transform.ParamConstantOp(transform.AnyParamType.get(), i32_0) + # CHECK: transform.yield %[[C0]] + transform.yield_(c0) + # CHECK-NEXT: }, { + with ir.InsertionPoint(fork_in_the_road.alternatives[idx_for_right].blocks[0]): + # CHECK: %[[C1:.*]] = transform.param.constant 1 + i32_1 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 1) + c1 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_1) + # CHECK: transform.yield %[[C1]] + transform.yield_(c1) + # CHECK-NEXT: } + + # CHECK: transform.tune.alternatives<"left_or_right_as_before"> selected_region = %[[LEFT_OR_RIGHT_OUTCOME]] : !transform.any_param { + left_or_right_as_before = tune.AlternativesOp( + [], + "left_or_right_as_before", + 2, + selected_region=outcome_of_left_or_right_decision, + ) + with ir.InsertionPoint( + left_or_right_as_before.alternatives[idx_for_left].blocks[0] + ): + # CHECK: transform.param.constant 1337 + i32_1337 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 1337) + c1337 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_1337) + # CHECK: transform.debug.emit_param_as_remark + debug.emit_param_as_remark(c1337) + transform.yield_([]) + # CHECK-NEXT: }, { + with ir.InsertionPoint( + left_or_right_as_before.alternatives[idx_for_right].blocks[0] + ): + # CHECK: transform.param.constant 42 + i32_42 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 42) + c42 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_42) + # CHECK: transform.debug.emit_param_as_remark + debug.emit_param_as_remark(c42) + transform.yield_([]) + # CHECK-NEXT: } diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py index 4a3625c..cb4cfc8c 100644 --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -696,6 +696,7 @@ def testOperationPrint(): # CHECK: resource1: "0x08 module.operation.print(large_elements_limit=2) + # CHECK-LABEL: TEST: testKnownOpView @run def testKnownOpView(): @@ -969,6 +970,13 @@ def testOperationLoc(): 
assert op.location == loc assert op.operation.location == loc + another_loc = Location.name("another_loc") + op.location = another_loc + assert op.location == another_loc + assert op.operation.location == another_loc + # CHECK: loc("another_loc") + print(op.location) + # CHECK-LABEL: TEST: testModuleMerge @run diff --git a/mlir/tools/mlir-rewrite/mlir-rewrite.cpp b/mlir/tools/mlir-rewrite/mlir-rewrite.cpp index fd8ae7e..795766f 100644 --- a/mlir/tools/mlir-rewrite/mlir-rewrite.cpp +++ b/mlir/tools/mlir-rewrite/mlir-rewrite.cpp @@ -35,7 +35,7 @@ namespace mlir { using OperationDefinition = AsmParserState::OperationDefinition; /// Return the source code associated with the OperationDefinition. -SMRange getOpRange(const OperationDefinition &op) { +static SMRange getOpRange(const OperationDefinition &op) { const char *startOp = op.scopeLoc.Start.getPointer(); const char *endOp = op.scopeLoc.End.getPointer(); @@ -187,15 +187,15 @@ std::unique_ptr<RewritePad> RewritePad::init(StringRef inputFilename, } /// Return the source code associated with the operation name. -SMRange getOpNameRange(const OperationDefinition &op) { return op.loc; } +static SMRange getOpNameRange(const OperationDefinition &op) { return op.loc; } /// Return whether the operation was printed using generic syntax in original /// buffer. -bool isGeneric(const OperationDefinition &op) { +static bool isGeneric(const OperationDefinition &op) { return op.loc.Start.getPointer()[0] == '"'; } -inline int asMainReturnCode(LogicalResult r) { +static inline int asMainReturnCode(LogicalResult r) { return r.succeeded() ? EXIT_SUCCESS : EXIT_FAILURE; } @@ -293,7 +293,7 @@ static llvm::cl::opt<std::string> simpleRenameReplace{ llvm::cl::cat(clSimpleRenameCategory)}; // Rewriter that does simple renames. 
-LogicalResult simpleRename(RewritePad &rewriteState, raw_ostream &os) { +static LogicalResult simpleRename(RewritePad &rewriteState, raw_ostream &os) { StringRef opName = simpleRenameOpName; StringRef match = simpleRenameMatch; StringRef replace = simpleRenameReplace; @@ -317,7 +317,7 @@ static mlir::RewriterRegistration rewriteSimpleRename("simple-rename", simpleRename); // Rewriter that insert range markers. -LogicalResult markRanges(RewritePad &rewriteState, raw_ostream &os) { +static LogicalResult markRanges(RewritePad &rewriteState, raw_ostream &os) { for (const auto &it : rewriteState.getOpDefs()) { auto [startOp, endOp] = getOpRange(it); diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp index a1899a8..8dd9713 100644 --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp @@ -403,6 +403,7 @@ void DefFormat::genLiteralParser(StringRef value, FmtContext &ctx, .Case("]", "RSquare") .Case("?", "Question") .Case("+", "Plus") + .Case("-", "Minus") .Case("*", "Star") .Case("...", "Ellipsis") << "()"; diff --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp index 4dfdde2..04d3ed1 100644 --- a/mlir/tools/mlir-tblgen/FormatGen.cpp +++ b/mlir/tools/mlir-tblgen/FormatGen.cpp @@ -518,7 +518,7 @@ bool mlir::tblgen::isValidLiteral(StringRef value, // If there is only one character, this must either be punctuation or a // single character bare identifier. 
if (value.size() == 1) { - StringRef bare = "_:,=<>()[]{}?+*"; + StringRef bare = "_:,=<>()[]{}?+-*"; if (isalpha(front) || bare.contains(front)) return true; if (emitError) diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp index 0d113b3..ccf21d1 100644 --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -852,6 +852,7 @@ static void genLiteralParser(StringRef value, MethodBody &body) { .Case("]", "RSquare()") .Case("?", "Question()") .Case("+", "Plus()") + .Case("-", "Minus()") .Case("*", "Star()") .Case("...", "Ellipsis()"); } diff --git a/mlir/unittests/TableGen/PassGenTest.cpp b/mlir/unittests/TableGen/PassGenTest.cpp index 27f2fa0..ac01d49 100644 --- a/mlir/unittests/TableGen/PassGenTest.cpp +++ b/mlir/unittests/TableGen/PassGenTest.cpp @@ -11,7 +11,8 @@ #include "gmock/gmock.h" -std::unique_ptr<mlir::Pass> createTestPassWithCustomConstructor(int v = 0); +static std::unique_ptr<mlir::Pass> +createTestPassWithCustomConstructor(int v = 0); #define GEN_PASS_DECL #define GEN_PASS_REGISTRATION |