diff options
Diffstat (limited to 'mlir/include')
29 files changed, 543 insertions, 533 deletions
diff --git a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h index 364a70c..b595b6a3 100644 --- a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h +++ b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h @@ -8,6 +8,11 @@ #ifndef MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H #define MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H +constexpr const char *alignedAllocFunctionName = "aligned_alloc"; +constexpr const char *mallocFunctionName = "malloc"; +constexpr const char *cppStandardLibraryHeader = "cstdlib"; +constexpr const char *cStandardLibraryHeader = "stdlib.h"; + namespace mlir { class DialectRegistry; class RewritePatternSet; diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index eb18160..6e1baaf 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -196,6 +196,10 @@ def ConvertArithToSPIRVPass : Pass<"convert-arith-to-spirv"> { "bool", /*default=*/"true", "Emulate narrower scalar types with 32-bit ones if not supported by " "the target">, + Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types", + "bool", /*default=*/"true", + "Emulate unsupported float types by representing them with integer " + "types of same bit width"> ]; } @@ -416,7 +420,11 @@ def ConvertControlFlowToSPIRVPass : Pass<"convert-cf-to-spirv"> { Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types", "bool", /*default=*/"true", "Emulate narrower scalar types with 32-bit ones if not supported by" - " the target"> + " the target">, + Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types", + "bool", /*default=*/"true", + "Emulate unsupported float types by representing them with integer " + "types of same bit width"> ]; } @@ -500,7 +508,11 @@ def ConvertFuncToSPIRVPass : Pass<"convert-func-to-spirv"> { Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types", "bool", /*default=*/"true", "Emulate narrower scalar types with 32-bit ones if not supported by" - " the target"> + " the target">, + Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types", + "bool", /*default=*/"true", + "Emulate unsupported float types by representing them with integer " + "types of same bit width"> ]; } @@ -841,9 +853,13 @@ def ConvertMathToFuncs : Pass<"convert-math-to-funcs", "ModuleOp"> { // MemRefToEmitC //===----------------------------------------------------------------------===// -def ConvertMemRefToEmitC : Pass<"convert-memref-to-emitc"> { +def ConvertMemRefToEmitC : Pass<"convert-memref-to-emitc", "ModuleOp"> { let summary = "Convert MemRef dialect to EmitC dialect"; let dependentDialects = ["emitc::EmitCDialect"]; + let options = [Option< + "lowerToCpp", "lower-to-cpp", "bool", + /*default=*/"false", + /*description=*/"Target C++ (true) instead of C (false)">]; } //===----------------------------------------------------------------------===// @@ -1163,7 +1179,11 @@ def ConvertTensorToSPIRVPass : Pass<"convert-tensor-to-spirv"> { Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types", "bool", /*default=*/"true", "Emulate narrower scalar types with 32-bit ones if not supported by" - " the target"> + " the target">, + Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types", + "bool", /*default=*/"true", + "Emulate unsupported float types by representing them with integer " + "types of same bit width"> ]; } diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td index e81db32..06fb851 100644 --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td @@ -71,6 +71,7 @@ class ArmSME_IntrOp<string mnemonic, /*bit requiresAccessGroup=*/0, /*bit requiresAliasAnalysis=*/0, /*bit requiresFastmath=*/0, + /*bit requiresArgAndResultAttrs=*/0, /*bit requiresOpBundles=*/0, /*list<int> immArgPositions=*/immArgPositions, /*list<string> immArgAttrNames=*/immArgAttrNames>; diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td index 8988df6..d055bb4 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td +++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td @@ -92,6 +92,7 @@ class ArmSVE_IntrOp<string mnemonic, /*bit requiresAccessGroup=*/0, /*bit requiresAliasAnalysis=*/0, /*bit requiresFastmath=*/0, + /*bit requiresArgAndResultAttrs=*/0, /*bit requiresOpBundles=*/0, /*list<int> immArgPositions=*/immArgPositions, /*list<string> immArgAttrNames=*/immArgAttrNames>; diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h index 2cf801d..09700f8 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h @@ -14,7 +14,7 @@ struct LogicalResult; } // namespace llvm namespace mlir { -class ModuleOp; +class Operation; namespace bufferization { struct BufferizationStatistics; @@ -23,12 +23,13 @@ struct OneShotBufferizationOptions; class BufferizationState; /// Analyze `moduleOp` and its nested ops. Bufferization decisions are stored in -/// `state`. +/// `state`. This operates on any `SymbolTable` op. llvm::LogicalResult -analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state, +analyzeModuleOp(Operation *moduleOp, OneShotAnalysisState &state, BufferizationStatistics *statistics = nullptr); -/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`. +/// Bufferize an `op`s nested ops that implement `BufferizableOpInterface`. +/// This operates on any `SymbolTable` op. /// /// Note: This function does not run One-Shot Analysis. No buffer copies are /// inserted except two cases: @@ -37,20 +38,20 @@ analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state, /// - `options.copyBeforeWrite` is not set and `options.noAnalysisFuncFilter` /// is not empty. The FuncOps it contains were not analyzed. Buffer copies /// will be inserted only to these FuncOps. -llvm::LogicalResult -bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options, - BufferizationState &state, - BufferizationStatistics *statistics = nullptr); +llvm::LogicalResult bufferizeModuleOp( + Operation *moduleOp, const OneShotBufferizationOptions &options, + BufferizationState &state, BufferizationStatistics *statistics = nullptr); -/// Remove bufferization attributes on every FuncOp arguments in the ModuleOp. -void removeBufferizationAttributesInModule(ModuleOp moduleOp); +/// Remove bufferization attributes on every FuncOp arguments in the SymbolTable +/// op. +void removeBufferizationAttributesInModule(Operation *moduleOp); -/// Run One-Shot Module Bufferization on the given module. Performs a simple -/// function call analysis to determine which function arguments are +/// Run One-Shot Module Bufferization on the given SymbolTable. Performs a +/// simple function call analysis to determine which function arguments are /// inplaceable. Then analyzes and bufferizes FuncOps one-by-one with One-Shot /// Bufferize. llvm::LogicalResult runOneShotModuleBufferize( - ModuleOp moduleOp, + Operation *moduleOp, const bufferization::OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics = nullptr); diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td index 1dbaf5d..2ed7d38 100644 --- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td @@ -1368,12 +1368,14 @@ def GPU_ShuffleOp : GPU_Op< def GPU_RotateOp : GPU_Op< "rotate", [Pure, AllTypesMatch<["value", "rotateResult"]>]>, - Arguments<(ins AnyIntegerOrFloatOr1DVector:$value, I32:$offset, I32:$width)>, + Arguments<(ins AnyIntegerOrFloatOr1DVector:$value, + ConfinedAttr<I32Attr, [IntMinValue<0>]>:$offset, + ConfinedAttr<I32Attr, [IntPowerOf2]>:$width)>, Results<(outs AnyIntegerOrFloatOr1DVector:$rotateResult, I1:$valid)> { let summary = "Rotate values within a subgroup."; let description = [{ The "rotate" op moves values across lanes in a subgroup (a.k.a., local - invocations) within the same subgroup. The `width` argument specifies the + invocations) within the same subgroup. The `width` attribute specifies the number of lanes that participate in the rotation, and must be uniform across all participating lanes. Further, the first `width` lanes of the subgroup must be active. @@ -1394,9 +1396,7 @@ def GPU_RotateOp : GPU_Op< example: ```mlir - %offset = arith.constant 1 : i32 - %width = arith.constant 16 : i32 - %1, %2 = gpu.rotate %0, %offset, %width : f32 + %1, %2 = gpu.rotate %0, 1, 16 : f32 ``` For lane `k`, returns the value from lane `(k + cst1) % width`. @@ -1406,11 +1406,6 @@ def GPU_RotateOp : GPU_Op< $value `,` $offset `,` $width attr-dict `:` type($value) }]; - let builders = [ - // Helper function that creates a rotate with constant offset/width. - OpBuilder<(ins "Value":$value, "int32_t":$offset, "int32_t":$width)> - ]; - let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td index 8c6f1ee..d38298f 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td @@ -140,8 +140,8 @@ def LLVM_Log2Op : LLVM_UnaryIntrOpF<"log2">; def LLVM_LogOp : LLVM_UnaryIntrOpF<"log">; def LLVM_Prefetch : LLVM_ZeroResultIntrOp<"prefetch", [0], /*traits=*/[], /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0, - /*requiresOpBundles=*/0, /*immArgPositions=*/[1, 2, 3], - /*immArgAttrNames=*/["rw", "hint", "cache"] + /*requiresArgAndResultAttrs=*/0, /*requiresOpBundles=*/0, + /*immArgPositions=*/[1, 2, 3], /*immArgAttrNames=*/["rw", "hint", "cache"] > { let arguments = (ins LLVM_AnyPointer:$addr, I32Attr:$rw, I32Attr:$hint, I32Attr:$cache); } @@ -200,13 +200,13 @@ class LLVM_MemcpyIntrOpBase<string name> : DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>, DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>], /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1, - /*requiresOpBundles=*/0, /*immArgPositions=*/[3], - /*immArgAttrNames=*/["isVolatile"]> { + /*requiresArgAndResultAttrs=*/1, /*requiresOpBundles=*/0, + /*immArgPositions=*/[3], /*immArgAttrNames=*/["isVolatile"]> { dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst, Arg<LLVM_AnyPointer,"",[MemRead]>:$src, AnySignlessInteger:$len, I1Attr:$isVolatile); - // Append the alias attributes defined by LLVM_IntrOpBase. - let arguments = !con(args, aliasAttrs); + // Append the arguments defined by LLVM_IntrOpBase. + let arguments = !con(args, baseArgs); let builders = [ OpBuilder<(ins "Value":$dst, "Value":$src, "Value":$len, "bool":$isVolatile), [{ @@ -217,7 +217,8 @@ class LLVM_MemcpyIntrOpBase<string name> : "IntegerAttr":$isVolatile), [{ build($_builder, $_state, dst, src, len, isVolatile, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, - /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); + /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr); }]> ]; } @@ -231,13 +232,13 @@ def LLVM_MemcpyInlineOp : DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>, DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>], /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1, - /*requiresOpBundles=*/0, /*immArgPositions=*/[2, 3], - /*immArgAttrNames=*/["len", "isVolatile"]> { + /*requiresArgAndResultAttrs=*/1, /*requiresOpBundles=*/0, + /*immArgPositions=*/[2, 3], /*immArgAttrNames=*/["len", "isVolatile"]> { dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst, Arg<LLVM_AnyPointer,"",[MemRead]>:$src, APIntAttr:$len, I1Attr:$isVolatile); - // Append the alias attributes defined by LLVM_IntrOpBase. - let arguments = !con(args, aliasAttrs); + // Append the arguments defined by LLVM_IntrOpBase. + let arguments = !con(args, baseArgs); let builders = [ OpBuilder<(ins "Value":$dst, "Value":$src, "IntegerAttr":$len, "bool":$isVolatile), [{ @@ -248,7 +249,8 @@ def LLVM_MemcpyInlineOp : "IntegerAttr":$isVolatile), [{ build($_builder, $_state, dst, src, len, isVolatile, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, - /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); + /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr); }]> ]; } @@ -258,12 +260,12 @@ def LLVM_MemsetOp : LLVM_ZeroResultIntrOp<"memset", [0, 2], DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>, DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>], /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1, - /*requiresOpBundles=*/0, /*immArgPositions=*/[3], - /*immArgAttrNames=*/["isVolatile"]> { + /*requiresArgAndResultAttrs=*/1, /*requiresOpBundles=*/0, + /*immArgPositions=*/[3], /*immArgAttrNames=*/["isVolatile"]> { dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst, I8:$val, AnySignlessInteger:$len, I1Attr:$isVolatile); - // Append the alias attributes defined by LLVM_IntrOpBase. - let arguments = !con(args, aliasAttrs); + // Append the arguments defined by LLVM_IntrOpBase. + let arguments = !con(args, baseArgs); let builders = [ OpBuilder<(ins "Value":$dst, "Value":$val, "Value":$len, "bool":$isVolatile), [{ @@ -274,7 +276,8 @@ def LLVM_MemsetOp : LLVM_ZeroResultIntrOp<"memset", [0, 2], "IntegerAttr":$isVolatile), [{ build($_builder, $_state, dst, val, len, isVolatile, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, - /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); + /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr); }]> ]; } @@ -284,12 +287,12 @@ def LLVM_MemsetInlineOp : LLVM_ZeroResultIntrOp<"memset.inline", [0, 2], DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>, DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>], /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1, - /*requiresOpBundles=*/0, /*immArgPositions=*/[2, 3], - /*immArgAttrNames=*/["len", "isVolatile"]> { + /*requiresArgAndResultAttrs=*/1, /*requiresOpBundles=*/0, + /*immArgPositions=*/[2, 3], /*immArgAttrNames=*/["len", "isVolatile"]> { dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst, I8:$val, APIntAttr:$len, I1Attr:$isVolatile); - // Append the alias attributes defined by LLVM_IntrOpBase. - let arguments = !con(args, aliasAttrs); + // Append the arguments defined by LLVM_IntrOpBase. + let arguments = !con(args, baseArgs); let builders = [ OpBuilder<(ins "Value":$dst, "Value":$val, "IntegerAttr":$len, "bool":$isVolatile), [{ @@ -300,7 +303,8 @@ def LLVM_MemsetInlineOp : LLVM_ZeroResultIntrOp<"memset.inline", [0, 2], "IntegerAttr":$isVolatile), [{ build($_builder, $_state, dst, val, len, isVolatile, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, - /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); + /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr); }]> ]; } @@ -349,8 +353,8 @@ def LLVM_PtrMaskOp class LLVM_LifetimeBaseOp<string opName> : LLVM_ZeroResultIntrOp<opName, [1], [DeclareOpInterfaceMethods<PromotableOpInterface>], /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0, - /*requiresOpBundles=*/0, /*immArgPositions=*/[0], - /*immArgAttrNames=*/["size"]> { + /*requiresArgAndResultAttrs=*/0, /*requiresOpBundles=*/0, + /*immArgPositions=*/[0], /*immArgAttrNames=*/["size"]> { let arguments = (ins I64Attr:$size, LLVM_AnyPointer:$ptr); let assemblyFormat = "$size `,` $ptr attr-dict `:` qualified(type($ptr))"; } @@ -370,8 +374,8 @@ def LLVM_InvariantStartOp : LLVM_OneResultIntrOp<"invariant.start", [], [1], def LLVM_InvariantEndOp : LLVM_ZeroResultIntrOp<"invariant.end", [2], [DeclareOpInterfaceMethods<PromotableOpInterface>], /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0, - /*requiresOpBundles=*/0, /*immArgPositions=*/[1], - /*immArgAttrNames=*/["size"]> { + /*requiresArgAndResultAttrs=*/0, /*requiresOpBundles=*/0, + /*immArgPositions=*/[1], /*immArgAttrNames=*/["size"]> { let arguments = (ins LLVM_DefaultPointer:$start, I64Attr:$size, LLVM_AnyPointer:$ptr); @@ -542,9 +546,10 @@ def LLVM_AssumeOp : LLVM_ZeroResultIntrOp<"assume", /*overloadedOperands=*/[], /*traits=*/[], /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0, + /*requiresArgAndResultAttrs=*/0, /*requiresOpBundles=*/1> { dag args = (ins I1:$cond); - let arguments = !con(args, opBundleArgs); + let arguments = !con(args, baseArgs); let assemblyFormat = [{ $cond @@ -1126,8 +1131,8 @@ def LLVM_DebugTrap : LLVM_ZeroResultIntrOp<"debugtrap">; def LLVM_UBSanTrap : LLVM_ZeroResultIntrOp<"ubsantrap", /*overloadedOperands=*/[], /*traits=*/[], /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0, - /*requiresOpBundles=*/0, /*immArgPositions=*/[0], - /*immArgAttrNames=*/["failureKind"]> { + /*requiresArgAndResultAttrs=*/0, /*requiresOpBundles=*/0, + /*immArgPositions=*/[0], /*immArgAttrNames=*/["failureKind"]> { let arguments = (ins I8Attr:$failureKind); } diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td index e845ea9f..a8d7cf2 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -18,6 +18,7 @@ include "mlir/Dialect/LLVMIR/LLVMAttrDefs.td" include "mlir/Dialect/LLVMIR/LLVMInterfaces.td" include "mlir/IR/OpBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/CallInterfaces.td" //===----------------------------------------------------------------------===// // LLVM dialect type constraints. @@ -286,22 +287,26 @@ class LLVM_MemAccessOpBase<string mnemonic, list<Trait> traits = []> : // intrinsic and "enumName" contains the name of the intrinsic as appears in // `llvm::Intrinsic` enum; one usually wants these to be related. Additionally, // the base class also defines the "mlirBuilder" field to support the inverse -// translation starting from an LLVM IR intrinsic. The "requiresAccessGroup", -// "requiresAliasAnalysis", and "requiresFastmath" flags specify which -// interfaces the intrinsic implements. If the corresponding flags are set, the -// "aliasAttrs" list contains the arguments required by the access group and -// alias analysis interfaces. Derived intrinsics should append the "aliasAttrs" -// to their argument list if they set one of the flags. LLVM `immargs` can be -// represented as MLIR attributes by providing both the `immArgPositions` and -// `immArgAttrNames` lists. These two lists should have equal length, with -// `immArgPositions` containing the argument positions on the LLVM IR attribute -// that are `immargs`, and `immArgAttrNames` mapping these to corresponding -// MLIR attributes. +// translation starting from an LLVM IR intrinsic. +// +// The flags "requiresAccessGroup", "requiresAliasAnalysis", +// "requiresFastmath", and "requiresArgAndResultAttrs" indicate which +// interfaces the intrinsic implements. When a flag is set, the "baseArgs" +// list includes the arguments required by the corresponding interface. +// Derived intrinsics must append "baseArgs" to their argument list if they +// enable any of these flags. +// +// LLVM `immargs` can be represented as MLIR attributes by providing both +// the `immArgPositions` and `immArgAttrNames` lists. These two lists should +// have equal length, with `immArgPositions` containing the argument +// positions on the LLVM IR attribute that are `immargs`, and +// `immArgAttrNames` mapping these to corresponding MLIR attributes. class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName, list<int> overloadedResults, list<int> overloadedOperands, list<Trait> traits, int numResults, bit requiresAccessGroup = 0, bit requiresAliasAnalysis = 0, - bit requiresFastmath = 0, bit requiresOpBundles = 0, + bit requiresFastmath = 0, bit requiresArgAndResultAttrs = 0, + bit requiresOpBundles = 0, list<int> immArgPositions = [], list<string> immArgAttrNames = []> : LLVM_OpBase<dialect, opName, !listconcat( @@ -311,10 +316,12 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName, [DeclareOpInterfaceMethods<AliasAnalysisOpInterface>], []), !if(!gt(requiresFastmath, 0), [DeclareOpInterfaceMethods<FastmathFlagsInterface>], []), + !if(!gt(requiresArgAndResultAttrs, 0), + [DeclareOpInterfaceMethods<ArgAndResultAttrsOpInterface>], []), traits)>, LLVM_MemOpPatterns, Results<!if(!gt(numResults, 0), (outs LLVM_Type:$res), (outs))> { - dag aliasAttrs = !con( + dag baseArgs = !con( !if(!gt(requiresAccessGroup, 0), (ins OptionalAttr<LLVM_AccessGroupArrayAttr>:$access_groups), (ins )), @@ -322,13 +329,17 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName, (ins OptionalAttr<LLVM_AliasScopeArrayAttr>:$alias_scopes, OptionalAttr<LLVM_AliasScopeArrayAttr>:$noalias_scopes, OptionalAttr<LLVM_TBAATagArrayAttr>:$tbaa), + (ins )), + !if(!gt(requiresArgAndResultAttrs, 0), + (ins OptionalAttr<DictArrayAttr>:$arg_attrs, + OptionalAttr<DictArrayAttr>:$res_attrs), + (ins )), + !if(!gt(requiresOpBundles, 0), + (ins VariadicOfVariadic<LLVM_Type, + "op_bundle_sizes">:$op_bundle_operands, + DenseI32ArrayAttr:$op_bundle_sizes, + OptionalAttr<ArrayAttr>:$op_bundle_tags), (ins ))); - dag opBundleArgs = !if(!gt(requiresOpBundles, 0), - (ins VariadicOfVariadic<LLVM_Type, - "op_bundle_sizes">:$op_bundle_operands, - DenseI32ArrayAttr:$op_bundle_sizes, - OptionalAttr<ArrayAttr>:$op_bundle_tags), - (ins )); string llvmEnumName = enumName; string overloadedResultsCpp = "{" # !interleave(overloadedResults, ", ") # "}"; string overloadedOperandsCpp = "{" # !interleave(overloadedOperands, ", ") # "}"; @@ -342,23 +353,35 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName, immArgPositionsCpp, immArgAttrNamesCpp], ",") # [{); (void) inst; }]; + string baseLlvmBuilderArgAndResultAttrs = [{ + if (failed(moduleTranslation.convertArgAndResultAttrs( + op, + inst, + }] # immArgPositionsCpp # [{))) { + return failure(); + } + }]; string baseLlvmBuilderCoda = !if(!gt(numResults, 0), "$res = inst;", ""); - let llvmBuilder = baseLlvmBuilder # !if(!gt(requiresAccessGroup, 0), setAccessGroupsMetadataCode, "") - # !if(!gt(requiresAliasAnalysis, 0), setAliasAnalysisMetadataCode, "") - # baseLlvmBuilderCoda; + let llvmBuilder = baseLlvmBuilder + # !if(!gt(requiresAccessGroup, 0), + setAccessGroupsMetadataCode, "") + # !if(!gt(requiresAliasAnalysis, 0), + setAliasAnalysisMetadataCode, "") + # !if(!gt(requiresArgAndResultAttrs, 0), + baseLlvmBuilderArgAndResultAttrs, "") + # baseLlvmBuilderCoda; string baseMlirBuilder = [{ SmallVector<Value> mlirOperands; SmallVector<NamedAttribute> mlirAttrs; if (failed(moduleImport.convertIntrinsicArguments( - llvmOperands, - llvmOpBundles, - }] # !if(!gt(requiresOpBundles, 0), "true", "false") # [{, - }] # immArgPositionsCpp # [{, - }] # immArgAttrNamesCpp # [{, - mlirOperands, - mlirAttrs)) - ) { + llvmOperands, + llvmOpBundles, + }] # !if(!gt(requiresOpBundles, 0), "true", "false") # [{, + }] # immArgPositionsCpp # [{, + }] # immArgAttrNamesCpp # [{, + mlirOperands, + mlirAttrs))) { return failure(); } SmallVector<Type> resultTypes = @@ -366,9 +389,16 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName, auto op = $_qualCppClassName::create($_builder, $_location, resultTypes, mlirOperands, mlirAttrs); }]; + string baseMlirBuilderArgAndResultAttrs = [{ + moduleImport.convertArgAndResultAttrs( + inst, op, }] # immArgPositionsCpp # [{); + }]; string baseMlirBuilderCoda = !if(!gt(numResults, 0), "$res = op;", "$_op = op;"); - let mlirBuilder = baseMlirBuilder # !if(!gt(requiresFastmath, 0), + let mlirBuilder = baseMlirBuilder + # !if(!gt(requiresFastmath, 0), "moduleImport.setFastmathFlagsAttr(inst, op);", "") + # !if(!gt(requiresArgAndResultAttrs, 0), + baseMlirBuilderArgAndResultAttrs, "") # baseMlirBuilderCoda; // Code for handling a `range` attribute that holds the constant range of the @@ -399,14 +429,14 @@ class LLVM_IntrOp<string mnem, list<int> overloadedResults, list<int> overloadedOperands, list<Trait> traits, int numResults, bit requiresAccessGroup = 0, bit requiresAliasAnalysis = 0, bit requiresFastmath = 0, - bit requiresOpBundles = 0, + bit requiresArgAndResultAttrs = 0, bit requiresOpBundles = 0, list<int> immArgPositions = [], list<string> immArgAttrNames = []> : LLVM_IntrOpBase<LLVM_Dialect, "intr." # mnem, !subst(".", "_", mnem), overloadedResults, overloadedOperands, traits, numResults, requiresAccessGroup, requiresAliasAnalysis, - requiresFastmath, requiresOpBundles, immArgPositions, - immArgAttrNames>; + requiresFastmath, requiresArgAndResultAttrs, + requiresOpBundles, immArgPositions, immArgAttrNames>; // Base class for LLVM intrinsic operations returning no results. Places the // intrinsic into the LLVM dialect and prefixes its name with "intr.". @@ -426,13 +456,14 @@ class LLVM_ZeroResultIntrOp<string mnem, list<int> overloadedOperands = [], list<Trait> traits = [], bit requiresAccessGroup = 0, bit requiresAliasAnalysis = 0, + bit requiresArgAndResultAttrs = 0, bit requiresOpBundles = 0, list<int> immArgPositions = [], list<string> immArgAttrNames = []> : LLVM_IntrOp<mnem, [], overloadedOperands, traits, /*numResults=*/0, requiresAccessGroup, requiresAliasAnalysis, - /*requiresFastMath=*/0, requiresOpBundles, immArgPositions, - immArgAttrNames>; + /*requiresFastMath=*/0, requiresArgAndResultAttrs, + requiresOpBundles, immArgPositions, immArgAttrNames>; // Base class for LLVM intrinsic operations returning one result. Places the // intrinsic into the LLVM dialect and prefixes its name with "intr.". This is @@ -448,7 +479,8 @@ class LLVM_OneResultIntrOp<string mnem, list<int> overloadedResults = [], list<string> immArgAttrNames = []> : LLVM_IntrOp<mnem, overloadedResults, overloadedOperands, traits, 1, /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0, - requiresFastmath, /*requiresOpBundles=*/0, immArgPositions, + requiresFastmath, /*requiresArgAndResultAttrs=*/0, + /*requiresOpBundles=*/0, immArgPositions, immArgAttrNames>; // Base class for LLVM intrinsic operations returning two results. Places the @@ -465,7 +497,8 @@ class LLVM_TwoResultIntrOp<string mnem, list<int> overloadedResults = [], list<string> immArgAttrNames = []> : LLVM_IntrOp<mnem, overloadedResults, overloadedOperands, traits, 2, /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0, - requiresFastmath, /*requiresOpBundles=*/0, immArgPositions, + requiresFastmath, /*requiresArgAndResultAttrs=*/0, + /*requiresOpBundles=*/0, immArgPositions, immArgAttrNames>; def LLVM_OneResultOpBuilder : diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 51004f5..3f27f6d 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -2405,7 +2405,8 @@ def LLVM_InlineAsmOp : LLVM_Op<"inline_asm", [DeclareOpInterfaceMethods<MemoryEf def LLVM_CallIntrinsicOp : LLVM_Op<"call_intrinsic", - [AttrSizedOperandSegments, + [ArgAndResultAttrsOpInterface, + AttrSizedOperandSegments, DeclareOpInterfaceMethods<FastmathFlagsInterface>]> { let summary = "Call to an LLVM intrinsic function."; let description = [{ diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index 04a0b58..a2354e2 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -98,7 +98,7 @@ class ROCDL_IntrOp<string mnemonic, list<int> overloadedResults, LLVM_IntrOpBase<ROCDL_Dialect, mnemonic, "amdgcn_" # !subst(".", "_", mnemonic), overloadedResults, overloadedOperands, traits, numResults, requiresAccessGroup, - requiresAliasAnalysis, 0, 0, immArgPositions, immArgAttrNames>; + requiresAliasAnalysis, 0, 0, 0, immArgPositions, immArgAttrNames>; // Subclass to save typing and ease readibility when there aren't overloaded // operands or memory accesses. @@ -482,7 +482,7 @@ def ROCDLBufferLDS : LLVM_PointerInAddressSpace<3>; class ROCDL_LDS_Read_Tr_IntrOp<string mnemonic> : ROCDL_IntrOp<mnemonic, [1], [], [], 1, 0, 1> { dag args = (ins Arg<ROCDLBufferLDS, "", [MemRead]>:$ptr); - let arguments = !con(args, aliasAttrs); + let arguments = !con(args, baseArgs); let assemblyFormat = "$ptr attr-dict `:` type($ptr) `->` type($res)"; let extraClassDefinition = [{ ::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() { @@ -507,7 +507,7 @@ def ROCDL_LoadToLDSOp : I32Attr:$size, I32Attr:$offset, I32Attr:$aux); - let arguments = !con(args, aliasAttrs); + let arguments = !con(args, baseArgs); let assemblyFormat = [{ $globalPtr `,` $ldsPtr `,` $size `,` $offset `,` $aux attr-dict `:` type($globalPtr) @@ -526,7 +526,7 @@ def ROCDL_GlobalLoadLDSOp : I32Attr:$size, I32Attr:$offset, I32Attr:$aux); - let arguments = !con(args, aliasAttrs); + let arguments = !con(args, baseArgs); let assemblyFormat = [{ $globalPtr `,` $ldsPtr `,` $size `,` $offset `,` $aux attr-dict @@ -561,7 +561,7 @@ def ROCDL_RawPtrBufferLoadOp : I32:$offset, I32:$soffset, I32:$aux); - let arguments = !con(args, aliasAttrs); + let arguments = !con(args, baseArgs); let assemblyFormat = "operands attr-dict `:` type($res)"; let extraClassDefinition = [{ ::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() { @@ -579,7 +579,7 @@ def ROCDL_RawPtrBufferLoadLdsOp : I32:$soffset, I32:$offset, I32:$aux); - let arguments = !con(args, aliasAttrs); + let arguments = !con(args, baseArgs); let assemblyFormat = "operands attr-dict"; let extraClassDefinition = [{ ::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() { @@ -595,7 +595,7 @@ def ROCDL_RawPtrBufferStoreOp : I32:$offset, I32:$soffset, I32:$aux); - let arguments = !con(args, aliasAttrs); + let arguments = !con(args, baseArgs); let assemblyFormat = "operands attr-dict `:` type($vdata)"; let extraClassDefinition = [{ ::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() { @@ -614,7 +614,7 @@ def ROCDL_RawPtrBufferAtomicCmpSwap : I32:$offset, I32:$soffset, I32:$aux); - let arguments = !con(args, aliasAttrs); + let arguments = !con(args, baseArgs); let assemblyFormat = "operands attr-dict `:` type($res)"; let extraClassDefinition = [{ ::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() { @@ -630,7 +630,7 @@ class ROCDL_RawPtrBufferAtomicNoRet<string op> : I32:$offset, I32:$soffset, I32:$aux); - let arguments = !con(args, aliasAttrs); + let arguments = !con(args, baseArgs); let assemblyFormat = "operands attr-dict `:` type($vdata)"; let extraClassDefinition = [{ ::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() { diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 8d45c40..61ce23f 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -1191,6 +1191,7 @@ def PadTilingInterfaceOp : Op<Transform_Dialect, "structured.pad_tiling_interfac iteration domain induces a padding of the operands that is consistent across the op semantics and, unlike for simple elementwise ops, may not be trivially deducible or specifiable on operands only (e.g. convolutions). + Currently, only a limited set of projected permutation maps are supported. The specification of `padding_sizes` follows that of `tile_sizes` during tiling: the value "0" on a particular iterator encode "no padding". Like in diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index e625eef..d4ffe0a 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -611,6 +611,13 @@ LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, /// affine.apply operations. /// The `indexingMap` + `indexingSizes` encoding suits StructuredOps and /// provides a gentle portability path for Linalg-like ops with affine maps. +/// The padded shape is computed by evaluating the maximum accessed index per +/// dimension, which may involve multiplying by constant factors derived from +/// the affine indexing expressions. Currently, only a limited set of projected +/// permuation indexing maps are supported, such as +/// - affine_map<(d0, d1, d2) -> (d0, d1)> +/// - affine_map<(d0, d1, d2) -> (d0, d1 + d2)> +/// - affine_map<(d0, d1) -> (d0 * 3 + d1)> /// In the future, more general interfaces can be devised to encode similar /// shape evolutions and map between an op and its operands. SmallVector<OpFoldResult> diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td index 96b9adc..e1e99c3 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -134,6 +134,24 @@ def OpenACC_VariableTypeCategory : I32BitEnumAttr< let printBitEnumPrimaryGroups = 1; } +// These are parallelism determination modes for `acc loop`. +// In the enum names, we use the "loop_" prefix because "auto" is +// a language keyword - and thus for consistency all other cases +// do the same. +def OpenACC_LoopSeq : I32EnumAttrCase<"loop_seq", 0>; +def OpenACC_LoopAuto : I32EnumAttrCase<"loop_auto", 1>; +def OpenACC_LoopIndependent : I32EnumAttrCase<"loop_independent", 2>; + +def OpenACC_LoopParMode : I32EnumAttr< + "LoopParMode", + "Encodes the options for loop parallelism determination mode", + [ + OpenACC_LoopAuto, OpenACC_LoopIndependent, + OpenACC_LoopSeq]> { + let cppNamespace = "::mlir::acc"; + let genSpecializedAttr = 0; +} + // Type used in operation below. def IntOrIndex : AnyTypeOf<[AnyInteger, Index]>; @@ -2373,6 +2391,11 @@ def OpenACC_LoopOp : OpenACC_Op<"loop", // Return whether this LoopOp has a gang, worker, or vector applying to the // 'default'/None device-type. bool hasDefaultGangWorkerVector(); + + // Used to obtain the parallelism mode for the requested device type. + // This first checks if the mode is set for the device_type requested. + // And if not, it returns the non-device_type mode. + LoopParMode getDefaultOrDeviceTypeParallelism(DeviceType); }]; let hasCustomAssemblyFormat = 1; @@ -2404,6 +2427,53 @@ def OpenACC_LoopOp : OpenACC_Op<"loop", }]; let hasVerifier = 1; + + let builders = [ + OpBuilder<(ins "::mlir::ValueRange":$lowerbounds, + "::mlir::ValueRange":$upperbounds, + "::mlir::ValueRange":$steps, + "LoopParMode":$parMode), [{ + auto deviceNoneAttr = mlir::acc::DeviceTypeAttr::get( + $_builder.getContext(), mlir::acc::DeviceType::None); + auto arrOfDeviceNone = mlir::ArrayAttr::get( + $_builder.getContext(), deviceNoneAttr); + build($_builder, $_state, + /*results=*/{}, + /*lowerbound=*/lowerbounds, + /*upperbound=*/upperbounds, + /*step=*/steps, + /*inclusiveUpperbound=*/nullptr, + /*collapse=*/nullptr, + /*collapseDeviceType=*/nullptr, + /*gangOperands=*/{}, + /*gangOperandsArgType=*/nullptr, + /*gangOperandsSegments=*/nullptr, + /*gangOperandsDeviceType=*/nullptr, + /*workerNumOperands=*/{}, + /*workerNumOperandsDeviceType=*/nullptr, + /*vectorOperands=*/{}, + /*vectorOperandsDeviceType=*/nullptr, + /*seq=*/parMode == LoopParMode::loop_seq ? + arrOfDeviceNone : nullptr, + /*independent=*/parMode == LoopParMode::loop_independent ? + arrOfDeviceNone : nullptr, + /*auto_=*/parMode == LoopParMode::loop_auto ? + arrOfDeviceNone : nullptr, + /*gang=*/nullptr, + /*worker=*/nullptr, + /*vector=*/nullptr, + /*tileOperands=*/{}, + /*tileOperandsSegments=*/nullptr, + /*tileOperandsDeviceType=*/nullptr, + /*cacheOperands=*/{}, + /*privateOperands=*/{}, + /*privatizationRecipes=*/nullptr, + /*reductionOperands=*/{}, + /*reductionRecipes=*/nullptr, + /*combined=*/nullptr); + }] + > + ]; } // Yield operation for the acc.loop and acc.parallel operations. diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h index c691d59..531fecc 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h @@ -330,10 +330,34 @@ public: bool hasValue() const { return !isa<UnitAttr>(decorationValue); } }; + // Type for specifying the decoration(s) on the struct itself. + struct StructDecorationInfo { + Decoration decoration; + Attribute decorationValue; + + StructDecorationInfo(Decoration decoration, Attribute decorationValue) + : decoration(decoration), decorationValue(decorationValue) {} + + friend bool operator==(const StructDecorationInfo &lhs, + const StructDecorationInfo &rhs) { + return lhs.decoration == rhs.decoration && + lhs.decorationValue == rhs.decorationValue; + } + + friend bool operator<(const StructDecorationInfo &lhs, + const StructDecorationInfo &rhs) { + return llvm::to_underlying(lhs.decoration) < + llvm::to_underlying(rhs.decoration); + } + + bool hasValue() const { return !isa<UnitAttr>(decorationValue); } + }; + /// Construct a literal StructType with at least one member. static StructType get(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {}, - ArrayRef<MemberDecorationInfo> memberDecorations = {}); + ArrayRef<MemberDecorationInfo> memberDecorations = {}, + ArrayRef<StructDecorationInfo> structDecorations = {}); /// Construct an identified StructType. This creates a StructType whose body /// (member types, offset info, and decorations) is not set yet. A call to @@ -367,6 +391,9 @@ public: bool hasOffset() const; + /// Returns true if the struct has a specified decoration. + bool hasDecoration(spirv::Decoration decoration) const; + uint64_t getMemberOffset(unsigned) const; // Returns in `memberDecorations` the Decorations (apart from Offset) @@ -380,12 +407,18 @@ public: unsigned i, SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const; + // Returns in `structDecorations` the Decorations associated with the + // StructType. + void getStructDecorations(SmallVectorImpl<StructType::StructDecorationInfo> + &structDecorations) const; + /// Sets the contents of an incomplete identified StructType. This method must /// be called only for identified StructTypes and it must be called only once /// per instance. Otherwise, failure() is returned. LogicalResult trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {}, - ArrayRef<MemberDecorationInfo> memberDecorations = {}); + ArrayRef<MemberDecorationInfo> memberDecorations = {}, + ArrayRef<StructDecorationInfo> structDecorations = {}); void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional<StorageClass> storage = std::nullopt); @@ -396,6 +429,9 @@ public: llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo); +llvm::hash_code +hash_value(const StructType::StructDecorationInfo &structDecorationInfo); + // SPIR-V KHR cooperative matrix type class CooperativeMatrixType : public Type::TypeBase<CooperativeMatrixType, CompositeType, diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h index 3d22ec9..03ae54a 100644 --- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h +++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h @@ -39,6 +39,10 @@ struct SPIRVConversionOptions { /// The number of bits to store a boolean value. unsigned boolNumBits{8}; + /// Whether to emulate unsupported floats with integer types of same bit + /// width. + bool emulateUnsupportedFloatTypes{true}; + /// How sub-byte values are storaged in memory. SPIRVSubByteTypeStorage subByteTypeStorage{SPIRVSubByteTypeStorage::Packed}; diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index 91d6b2a..75b16a87 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -628,35 +628,71 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> { As compared to prefetch_nd, which works on non-scattered TensorDesc, it works on scattered TensorDesc instead. - Example: + Example 1: ```mlir xegpu.prefetch %tdesc {l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>} : !xegpu.tensor_desc<16xf16> ``` + + Example 2: + A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc. + It combines "create scattered TensorTdesc" and "prefetch with scattered TensorTdesc". + The source operand could be a raw pointer (uint64_t). + Please refer to create_tdesc for the restriction of memref. + ```mlir + %a = memref.alloc() : memref<1024xf32> + %0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex> + xegpu.prefetch %a[%0] {l1_hint = #xegpu.cache_hint<cached>, + l2_hint = #xegpu.cache_hint<cached>, + l3_hint = #xegpu.cache_hint<cached>} + : memref<1024xf32>, vector<4xindex> + ``` }]; - let arguments = (ins XeGPU_TensorDesc: $TensorDesc, + let arguments = (ins XeGPU_GatherScatterSourceType: $source, + Optional<XeGPU_OffsetType>: $offsets, OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint, OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint, OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint); let extraClassDeclaration = extraBaseClassDeclaration # [{ + Type getSourceType() { + return getSource().getType(); + } + + TypedValue<xegpu::TensorDescType> getTensorDesc() { + if (auto tdescType = getTensorDescType()) { + return llvm::cast<TypedValue<xegpu::TensorDescType>>(getSource()); + } + return TypedValue<xegpu::TensorDescType>(); + } + xegpu::TensorDescType getTensorDescType() { - return getTensorDesc().getType(); + return dyn_cast<xegpu::TensorDescType>(getSourceType()); } }]; - let assemblyFormat = "$TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc))"; + let assemblyFormat = [{ + $source + (`[` $offsets^ `]`)? + prop-dict + attr-dict `:` type(operands) + }]; + + let builders = [ + OpBuilder<(ins "Value": $source, + "xegpu::CachePolicyAttr": $l1_hint, + "xegpu::CachePolicyAttr": $l2_hint, + "xegpu::CachePolicyAttr": $l3_hint)> + ]; let hasVerifier = 1; } -def XeGPU_LoadGatherOp : XeGPU_Op<"load", [ - AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemRead]> - ]> { +def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> { let summary = "load a set of scattered data points from memory."; let description = [{ It (aka. load) load data per each work-item. The output @@ -687,6 +723,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [ : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>, vector<16xi1> -> vector<16x8xf32> ``` + Example 3 (SIMT mode): ```mlir %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, @@ -695,19 +732,48 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [ : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>> vector<16xi1> -> vector<8xf32> ``` + + Example 4: + A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc. + It combines "create scattered TensorTdesc" and "load with scattered TensorTdesc". + The source operand could be a raw pointer (uint64_t). Please refer to create_tdesc + for the restriction of memref. + ```mlir + %a = memref.alloc() : memref<1024xf32> + %offsets = vector.step : vector<16xindex> + %mask = vector.constant_mask [16]: vector<16xi1> + %val = xegpu.load %a[%offsets], %mask {l1_hint = #xegpu.cache_hint<cached>, + l2_hint = #xegpu.cache_hint<cached>, + l3_hint = #xegpu.cache_hint<cached>} + : memref<1024xf32>, vector<16xi1>, vector<16xindex> -> vector<16xf32> + ``` }]; - let arguments = (ins XeGPU_TensorDesc: $TensorDesc, + let arguments = (ins XeGPU_GatherScatterSourceType: $source, + Optional<XeGPU_OffsetType>: $offsets, XeGPU_MaskType: $mask, + OptionalAttr<I64Attr>: $chunk_size, OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint, OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint, OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint); let results = (outs XeGPU_ValueType: $value); let extraClassDeclaration = extraBaseClassDeclaration # [{ + + Type getSourceType() { + return getSource().getType(); + } + + TypedValue<xegpu::TensorDescType> getTensorDesc() { + if (auto tdescType = getTensorDescType()) { + return llvm::cast<TypedValue<xegpu::TensorDescType>>(getSource()); + } + return TypedValue<xegpu::TensorDescType>(); + } + xegpu::TensorDescType getTensorDescType() { - return getTensorDesc().getType(); + return dyn_cast<xegpu::TensorDescType>(getSourceType()); } mlir::Type getElementType() { @@ -725,15 +791,24 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [ }]; - let assemblyFormat = [{$TensorDesc `,` $mask prop-dict attr-dict - `:` qualified(type($TensorDesc)) `,` type($mask) `->` type($value)}]; + let assemblyFormat = [{ + $source + (`[` $offsets^ `]`)? `,` + $mask prop-dict + attr-dict `:` type(operands) `->` type($value) + }]; + + let builders = [ + OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask, + "xegpu::CachePolicyAttr": $l1_hint, + "xegpu::CachePolicyAttr": $l2_hint, + "xegpu::CachePolicyAttr": $l3_hint)> + ]; let hasVerifier = 1; } -def XeGPU_StoreScatterOp : XeGPU_Op<"store", [ - AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemWrite]> - ]> { +def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> { let summary = "store data to scattered memory locations."; let description = [{ It (aka. store) stores data to scattered memory locations. The value is typically a 1D vector. But when the chunk size of the TensorDesc is larger than 1, it will be @@ -768,19 +843,49 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [ l3_hint = #xegpu.cache_hint<write_through>}> : vector<8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>> vector<16xi1> ``` + + Example 4: + A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc. + It combines "create scattered TensorTdesc" and "store with scattered TensorTdesc". + The dest operand could be a raw pointer (uint64_t). + Please refer to create_tdesc for the restriction of memref. + ```mlir + %a = memref.alloc() : memref<1024xf32> + %val = arith.constant dense<0.0> : vector<16xf32> + %offsets = vector.step : vector<16xindex> + %mask = vector.constant_mask [16]: vector<16xi1> + xegpu.store %val, %a[%offsets], %mask {l1_hint = #xegpu.cache_hint<cached>, + l2_hint = #xegpu.cache_hint<cached>, + l3_hint = #xegpu.cache_hint<cached>} + : memref<1024xf32>, vector<16xi1>, vector<16xindex> -> vector<16xf32> + ``` + }]; let arguments = (ins XeGPU_ValueType: $value, - XeGPU_TensorDesc: $TensorDesc, + XeGPU_GatherScatterSourceType: $dest, + Optional<XeGPU_OffsetType>: $offsets, XeGPU_MaskType: $mask, + OptionalAttr<I64Attr>: $chunk_size, OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint, OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint, OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint); let extraClassDeclaration = extraBaseClassDeclaration # [{ + Type getDestType() { + return getDest().getType(); + } + + TypedValue<xegpu::TensorDescType> getTensorDesc() { + if (auto tdescType = getTensorDescType()) { + return llvm::cast<TypedValue<xegpu::TensorDescType>>(getDest()); + } + return TypedValue<xegpu::TensorDescType>(); + } + xegpu::TensorDescType getTensorDescType() { - return getTensorDesc().getType(); + return dyn_cast<xegpu::TensorDescType>(getDestType()); } VectorType getValueType() { @@ -792,8 +897,21 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [ } }]; - let assemblyFormat = [{$value `,` $TensorDesc `,` $mask prop-dict attr-dict - `:` type($value) `,` qualified(type($TensorDesc)) `,` type($mask)}]; + let assemblyFormat = [{ + $value `,` + $dest + (`[` $offsets^ `]`)? `,` + $mask + prop-dict + attr-dict `:` type(operands) + }]; + + let builders = [ + OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask, + "xegpu::CachePolicyAttr": $l1_hint, + "xegpu::CachePolicyAttr": $l2_hint, + "xegpu::CachePolicyAttr": $l3_hint)> + ]; let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td index 20916ae..b268cab 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td @@ -189,6 +189,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc", let genVerifyDecl = 1; } +def XeGPU_GatherScatterSourceType : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64]>; def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> { let summary = "!xegpu.nbarrier a custom XeGPU type representing a barrier."; diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h index 4ed0423..7ff718a 100644 --- a/mlir/include/mlir/IR/Diagnostics.h +++ b/mlir/include/mlir/IR/Diagnostics.h @@ -639,6 +639,10 @@ public: /// verified correctly, failure otherwise. LogicalResult verify(); + /// Register this handler with the given context. This is intended for use + /// with the splitAndProcessBuffer function. + void registerInContext(MLIRContext *ctx); + private: /// Process a single diagnostic. void process(Diagnostic &diag); diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h index 2162a74..8959dab 100644 --- a/mlir/include/mlir/IR/StorageUniquerSupport.h +++ b/mlir/include/mlir/IR/StorageUniquerSupport.h @@ -200,7 +200,7 @@ public: // If the construction invariants fail then we return a null attribute. if (failed(ConcreteT::verifyInvariants(emitErrorFn, args...))) return ConcreteT(); - return UniquerT::template get<ConcreteT>(ctx, args...); + return UniquerT::template get<ConcreteT>(ctx, std::forward<Args>(args)...); } /// Get an instance of the concrete type from a void pointer. diff --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td index a8b04d0..bbfa308 100644 --- a/mlir/include/mlir/IR/SymbolInterfaces.td +++ b/mlir/include/mlir/IR/SymbolInterfaces.td @@ -55,19 +55,19 @@ def Symbol : OpInterface<"SymbolOpInterface"> { InterfaceMethod<"Returns true if this symbol has nested visibility.", "bool", "isNested", (ins), [{}], /*defaultImplementation=*/[{ - return getVisibility() == mlir::SymbolTable::Visibility::Nested; + return $_op.getVisibility() == mlir::SymbolTable::Visibility::Nested; }] >, InterfaceMethod<"Returns true if this symbol has private visibility.", "bool", "isPrivate", (ins), [{}], /*defaultImplementation=*/[{ - return getVisibility() == mlir::SymbolTable::Visibility::Private; + return $_op.getVisibility() == mlir::SymbolTable::Visibility::Private; }] >, InterfaceMethod<"Returns true if this symbol has public visibility.", "bool", "isPublic", (ins), [{}], /*defaultImplementation=*/[{ - return getVisibility() == mlir::SymbolTable::Visibility::Public; + return $_op.getVisibility() == mlir::SymbolTable::Visibility::Public; }] >, InterfaceMethod<"Sets the visibility of this symbol.", @@ -79,19 +79,19 @@ def Symbol : OpInterface<"SymbolOpInterface"> { InterfaceMethod<"Sets the visibility of this symbol to be nested.", "void", "setNested", (ins), [{}], /*defaultImplementation=*/[{ - setVisibility(mlir::SymbolTable::Visibility::Nested); + $_op.setVisibility(mlir::SymbolTable::Visibility::Nested); }] >, InterfaceMethod<"Sets the visibility of this symbol to be private.", "void", "setPrivate", (ins), [{}], /*defaultImplementation=*/[{ - setVisibility(mlir::SymbolTable::Visibility::Private); + $_op.setVisibility(mlir::SymbolTable::Visibility::Private); }] >, InterfaceMethod<"Sets the visibility of this symbol to be public.", "void", "setPublic", (ins), [{}], /*defaultImplementation=*/[{ - setVisibility(mlir::SymbolTable::Visibility::Public); + $_op.setVisibility(mlir::SymbolTable::Visibility::Public); }] >, InterfaceMethod<[{ @@ -144,7 +144,7 @@ def Symbol : OpInterface<"SymbolOpInterface"> { // By default, base this on the visibility alone. A symbol can be // discarded as long as it is not public. Only public symbols may be // visible from outside of the IR. - return getVisibility() != ::mlir::SymbolTable::Visibility::Public; + return $_op.getVisibility() != ::mlir::SymbolTable::Visibility::Public; }] >, InterfaceMethod<[{ diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h index 856170e..7628171 100644 --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -14,200 +14,15 @@ #ifndef MLIR_INITALLDIALECTS_H_ #define MLIR_INITALLDIALECTS_H_ -#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" -#include "mlir/Dialect/AMX/AMXDialect.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h" -#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h" -#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h" -#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h" -#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" -#include "mlir/Dialect/ArmSME/IR/ArmSME.h" -#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" -#include "mlir/Dialect/Async/IR/Async.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/Complex/IR/Complex.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" -#include "mlir/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.h" -#include "mlir/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/DLTI/DLTI.h" -#include "mlir/Dialect/EmitC/IR/EmitC.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/Dialect/GPU/IR/ValueBoundsOpInterfaceImpl.h" -#include "mlir/Dialect/GPU/Transforms/BufferDeallocationOpInterfaceImpl.h" -#include "mlir/Dialect/IRDL/IR/IRDL.h" -#include "mlir/Dialect/Index/IR/IndexDialect.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/LLVMIR/NVVMDialect.h" -#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" -#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" -#include "mlir/Dialect/LLVMIR/XeVMDialect.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h" -#include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h" -#include "mlir/Dialect/MLProgram/IR/MLProgram.h" -#include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/MPI/IR/MPI.h" -#include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h" -#include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h" -#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h" -#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h" -#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h" -#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" -#include "mlir/Dialect/OpenACC/OpenACC.h" -#include "mlir/Dialect/OpenMP/OpenMPDialect.h" -#include "mlir/Dialect/PDL/IR/PDL.h" -#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" -#include "mlir/Dialect/Ptr/IR/PtrDialect.h" -#include "mlir/Dialect/Quant/IR/Quant.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h" -#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" -#include "mlir/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.h" -#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/SMT/IR/SMTDialect.h" -#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" -#include "mlir/Dialect/Shape/IR/Shape.h" -#include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/Shard/IR/ShardDialect.h" -#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" -#include "mlir/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h" -#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h" -#include "mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h" -#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h" -#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/Tensor/Transforms/RuntimeOpVerification.h" -#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h" -#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h" -#include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/Dialect/Transform/IR/TransformDialect.h" -#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h" -#include "mlir/Dialect/UB/IR/UBOps.h" -#include "mlir/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h" -#include "mlir/Dialect/X86Vector/X86VectorDialect.h" -#include "mlir/Dialect/XeGPU/IR/XeGPU.h" -#include "mlir/IR/Dialect.h" -#include "mlir/Interfaces/CastInterfaces.h" -#include "mlir/Target/LLVM/NVVM/Target.h" -#include "mlir/Target/LLVM/ROCDL/Target.h" -#include "mlir/Target/SPIRV/Target.h" - namespace mlir { +class DialectRegistry; +class MLIRContext; /// Add all the MLIR dialects to the provided registry. -inline void registerAllDialects(DialectRegistry ®istry) { - // clang-format off - registry.insert<acc::OpenACCDialect, - affine::AffineDialect, - amdgpu::AMDGPUDialect, - amx::AMXDialect, - arith::ArithDialect, - arm_neon::ArmNeonDialect, - arm_sme::ArmSMEDialect, - arm_sve::ArmSVEDialect, - async::AsyncDialect, - bufferization::BufferizationDialect, - cf::ControlFlowDialect, - complex::ComplexDialect, - DLTIDialect, - emitc::EmitCDialect, - func::FuncDialect, - gpu::GPUDialect, - index::IndexDialect, - irdl::IRDLDialect, - linalg::LinalgDialect, - LLVM::LLVMDialect, - math::MathDialect, - memref::MemRefDialect, - shard::ShardDialect, - ml_program::MLProgramDialect, - mpi::MPIDialect, - nvgpu::NVGPUDialect, - NVVM::NVVMDialect, - omp::OpenMPDialect, - pdl::PDLDialect, - pdl_interp::PDLInterpDialect, - ptr::PtrDialect, - quant::QuantDialect, - ROCDL::ROCDLDialect, - scf::SCFDialect, - shape::ShapeDialect, - smt::SMTDialect, - sparse_tensor::SparseTensorDialect, - spirv::SPIRVDialect, - tensor::TensorDialect, - tosa::TosaDialect, - transform::TransformDialect, - ub::UBDialect, - vector::VectorDialect, - x86vector::X86VectorDialect, - xegpu::XeGPUDialect, - xevm::XeVMDialect>(); - // clang-format on - - // Register all external models. - affine::registerValueBoundsOpInterfaceExternalModels(registry); - arith::registerBufferDeallocationOpInterfaceExternalModels(registry); - arith::registerBufferizableOpInterfaceExternalModels(registry); - arith::registerBufferViewFlowOpInterfaceExternalModels(registry); - arith::registerShardingInterfaceExternalModels(registry); - arith::registerValueBoundsOpInterfaceExternalModels(registry); - bufferization::func_ext::registerBufferizableOpInterfaceExternalModels( - registry); - builtin::registerCastOpInterfaceExternalModels(registry); - cf::registerBufferizableOpInterfaceExternalModels(registry); - cf::registerBufferDeallocationOpInterfaceExternalModels(registry); - gpu::registerBufferDeallocationOpInterfaceExternalModels(registry); - gpu::registerValueBoundsOpInterfaceExternalModels(registry); - LLVM::registerInlinerInterface(registry); - NVVM::registerInlinerInterface(registry); - linalg::registerAllDialectInterfaceImplementations(registry); - linalg::registerRuntimeVerifiableOpInterfaceExternalModels(registry); - memref::registerAllocationOpInterfaceExternalModels(registry); - memref::registerBufferViewFlowOpInterfaceExternalModels(registry); - memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry); - memref::registerValueBoundsOpInterfaceExternalModels(registry); - memref::registerMemorySlotExternalModels(registry); - ml_program::registerBufferizableOpInterfaceExternalModels(registry); - scf::registerBufferDeallocationOpInterfaceExternalModels(registry); - scf::registerBufferizableOpInterfaceExternalModels(registry); - scf::registerValueBoundsOpInterfaceExternalModels(registry); - shape::registerBufferizableOpInterfaceExternalModels(registry); - sparse_tensor::registerBufferizableOpInterfaceExternalModels(registry); - tensor::registerBufferizableOpInterfaceExternalModels(registry); - tensor::registerFindPayloadReplacementOpInterfaceExternalModels(registry); - tensor::registerInferTypeOpInterfaceExternalModels(registry); - tensor::registerRuntimeVerifiableOpInterfaceExternalModels(registry); - tensor::registerSubsetOpInterfaceExternalModels(registry); - tensor::registerTilingInterfaceExternalModels(registry); - tensor::registerValueBoundsOpInterfaceExternalModels(registry); - tosa::registerShardingInterfaceExternalModels(registry); - vector::registerBufferizableOpInterfaceExternalModels(registry); - vector::registerSubsetOpInterfaceExternalModels(registry); - vector::registerValueBoundsOpInterfaceExternalModels(registry); - NVVM::registerNVVMTargetInterfaceExternalModels(registry); - ROCDL::registerROCDLTargetInterfaceExternalModels(registry); - spirv::registerSPIRVTargetInterfaceExternalModels(registry); -} +void registerAllDialects(DialectRegistry ®istry); /// Append all the MLIR dialects to the registry contained in the given context. -inline void registerAllDialects(MLIRContext &context) { - DialectRegistry registry; - registerAllDialects(registry); - context.appendDialectRegistry(registry); -} +void registerAllDialects(MLIRContext &context); } // namespace mlir diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h index d5a9a2c..a7f64d9 100644 --- a/mlir/include/mlir/InitAllExtensions.h +++ b/mlir/include/mlir/InitAllExtensions.h @@ -14,110 +14,15 @@ #ifndef MLIR_INITALLEXTENSIONS_H_ #define MLIR_INITALLEXTENSIONS_H_ -#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h" -#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" -#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" -#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" -#include "mlir/Conversion/FuncToEmitC/FuncToEmitC.h" -#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" -#include "mlir/Conversion/GPUCommon/GPUToLLVM.h" -#include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h" -#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" -#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h" -#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" -#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h" -#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" -#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" -#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" -#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h" -#include "mlir/Conversion/UBToLLVM/UBToLLVM.h" -#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" -#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h" -#include "mlir/Dialect/AMX/Transforms.h" -#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h" -#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h" -#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h" -#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h" -#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h" -#include "mlir/Dialect/Func/Extensions/AllExtensions.h" -#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h" -#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h" -#include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h" -#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h" -#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h" -#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" -#include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h" -#include "mlir/Dialect/Tensor/Extensions/AllExtensions.h" -#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h" -#include "mlir/Dialect/Transform/DebugExtension/DebugExtension.h" -#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h" -#include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h" -#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h" -#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h" -#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" -#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" -#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" -#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" -#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" - -#include <cstdlib> - namespace mlir { +class DialectRegistry; /// This function may be called to register all MLIR dialect extensions with the /// provided registry. /// If you're building a compiler, you generally shouldn't use this: you would /// individually register the specific extensions that are useful for the /// pipelines and transformations you are using. -inline void registerAllExtensions(DialectRegistry ®istry) { - // Register all conversions to LLVM extensions. - registerConvertArithToEmitCInterface(registry); - arith::registerConvertArithToLLVMInterface(registry); - registerConvertComplexToLLVMInterface(registry); - cf::registerConvertControlFlowToLLVMInterface(registry); - func::registerAllExtensions(registry); - tensor::registerAllExtensions(registry); - registerConvertFuncToEmitCInterface(registry); - registerConvertFuncToLLVMInterface(registry); - index::registerConvertIndexToLLVMInterface(registry); - registerConvertMathToLLVMInterface(registry); - mpi::registerConvertMPIToLLVMInterface(registry); - registerConvertMemRefToEmitCInterface(registry); - registerConvertMemRefToLLVMInterface(registry); - registerConvertNVVMToLLVMInterface(registry); - registerConvertOpenMPToLLVMInterface(registry); - registerConvertSCFToEmitCInterface(registry); - ub::registerConvertUBToLLVMInterface(registry); - registerConvertAMXToLLVMInterface(registry); - gpu::registerConvertGpuToLLVMInterface(registry); - NVVM::registerConvertGpuToNVVMInterface(registry); - vector::registerConvertVectorToLLVMInterface(registry); - registerConvertXeVMToLLVMInterface(registry); - - // Register all transform dialect extensions. - affine::registerTransformDialectExtension(registry); - bufferization::registerTransformDialectExtension(registry); - dlti::registerTransformDialectExtension(registry); - func::registerTransformDialectExtension(registry); - gpu::registerTransformDialectExtension(registry); - linalg::registerTransformDialectExtension(registry); - memref::registerTransformDialectExtension(registry); - nvgpu::registerTransformDialectExtension(registry); - scf::registerTransformDialectExtension(registry); - sparse_tensor::registerTransformDialectExtension(registry); - tensor::registerTransformDialectExtension(registry); - transform::registerDebugExtension(registry); - transform::registerIRDLExtension(registry); - transform::registerLoopExtension(registry); - transform::registerPDLExtension(registry); - transform::registerTuneExtension(registry); - vector::registerTransformDialectExtension(registry); - arm_neon::registerTransformDialectExtension(registry); - arm_sve::registerTransformDialectExtension(registry); - - // Translation extensions need to be registered by calling - // `registerAllToLLVMIRTranslations` (see All.h). -} +void registerAllExtensions(DialectRegistry ®istry); } // namespace mlir diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h index 002ff61..4554290 100644 --- a/mlir/include/mlir/InitAllPasses.h +++ b/mlir/include/mlir/InitAllPasses.h @@ -1,4 +1,4 @@ -//===- LinkAllPassesAndDialects.h - MLIR Registration -----------*- C++ -*-===// +//===- InitAllPasses.h - MLIR Registration ----------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,50 +6,14 @@ // //===----------------------------------------------------------------------===// // -// This file defines a helper to trigger the registration of all dialects and -// passes to the system. +// This file defines a helper to trigger the registration of all passes to the +// system. // //===----------------------------------------------------------------------===// #ifndef MLIR_INITALLPASSES_H_ #define MLIR_INITALLPASSES_H_ -#include "mlir/Conversion/Passes.h" -#include "mlir/Dialect/AMDGPU/Transforms/Passes.h" -#include "mlir/Dialect/Affine/Passes.h" -#include "mlir/Dialect/Arith/Transforms/Passes.h" -#include "mlir/Dialect/ArmSME/Transforms/Passes.h" -#include "mlir/Dialect/ArmSVE/Transforms/Passes.h" -#include "mlir/Dialect/Async/Passes.h" -#include "mlir/Dialect/Bufferization/Pipelines/Passes.h" -#include "mlir/Dialect/Bufferization/Transforms/Passes.h" -#include "mlir/Dialect/EmitC/Transforms/Passes.h" -#include "mlir/Dialect/Func/Transforms/Passes.h" -#include "mlir/Dialect/GPU/Pipelines/Passes.h" -#include "mlir/Dialect/GPU/Transforms/Passes.h" -#include "mlir/Dialect/LLVMIR/Transforms/Passes.h" -#include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Dialect/MLProgram/Transforms/Passes.h" -#include "mlir/Dialect/Math/Transforms/Passes.h" -#include "mlir/Dialect/MemRef/Transforms/Passes.h" -#include "mlir/Dialect/NVGPU/Transforms/Passes.h" -#include "mlir/Dialect/OpenACC/Transforms/Passes.h" -#include "mlir/Dialect/Quant/Transforms/Passes.h" -#include "mlir/Dialect/SCF/Transforms/Passes.h" -#include "mlir/Dialect/SPIRV/Transforms/Passes.h" -#include "mlir/Dialect/Shape/Transforms/Passes.h" -#include "mlir/Dialect/Shard/Transforms/Passes.h" -#include "mlir/Dialect/SparseTensor/Pipelines/Passes.h" -#include "mlir/Dialect/SparseTensor/Transforms/Passes.h" -#include "mlir/Dialect/Tensor/Transforms/Passes.h" -#include "mlir/Dialect/Tosa/Transforms/Passes.h" -#include "mlir/Dialect/Transform/Transforms/Passes.h" -#include "mlir/Dialect/Vector/Transforms/Passes.h" -#include "mlir/Dialect/XeGPU/Transforms/Passes.h" -#include "mlir/Transforms/Passes.h" - -#include <cstdlib> - namespace mlir { // This function may be called to register the MLIR passes with the @@ -59,49 +23,7 @@ namespace mlir { // registry, since it would already be calling the creation routine of the // individual passes. // The global registry is interesting to interact with the command-line tools. -inline void registerAllPasses() { - // General passes - registerTransformsPasses(); - - // Conversion passes - registerConversionPasses(); - - // Dialect passes - acc::registerOpenACCPasses(); - affine::registerAffinePasses(); - amdgpu::registerAMDGPUPasses(); - registerAsyncPasses(); - arith::registerArithPasses(); - bufferization::registerBufferizationPasses(); - func::registerFuncPasses(); - registerGPUPasses(); - registerLinalgPasses(); - registerNVGPUPasses(); - registerSparseTensorPasses(); - LLVM::registerLLVMPasses(); - math::registerMathPasses(); - memref::registerMemRefPasses(); - shard::registerShardPasses(); - ml_program::registerMLProgramPasses(); - quant::registerQuantPasses(); - registerSCFPasses(); - registerShapePasses(); - spirv::registerSPIRVPasses(); - tensor::registerTensorPasses(); - tosa::registerTosaOptPasses(); - transform::registerTransformPasses(); - vector::registerVectorPasses(); - arm_sme::registerArmSMEPasses(); - arm_sve::registerArmSVEPasses(); - emitc::registerEmitCPasses(); - xegpu::registerXeGPUPasses(); - - // Dialect pipelines - bufferization::registerBufferizationPipelines(); - sparse_tensor::registerSparseTensorPipelines(); - tosa::registerTosaToLinalgPipelines(); - gpu::registerGPUToNVVMPipeline(); -} +void registerAllPasses(); } // namespace mlir diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.td b/mlir/include/mlir/Interfaces/CallInterfaces.td index e3c2aec..19d3afe 100644 --- a/mlir/include/mlir/Interfaces/CallInterfaces.td +++ b/mlir/include/mlir/Interfaces/CallInterfaces.td @@ -18,9 +18,15 @@ include "mlir/IR/OpBase.td" -/// Interface for operations with arguments attributes (both call-like -/// and callable operations). -def ArgumentAttributesMethods { +/// Interface for operations with result and argument attributes. +def ArgAndResultAttrsOpInterface : OpInterface<"ArgAndResultAttrsOpInterface"> { + let description = [{ + An operation that has argument and result attributes. This interface + provides functions to access and modify the argument and result + attributes of the operation. + }]; + let cppNamespace = "::mlir"; + list<InterfaceMethod> methods = [ InterfaceMethod<[{ Get the array of argument attribute dictionaries. The method should @@ -64,7 +70,8 @@ def ArgumentAttributesMethods { // a call-like operation. This represents the destination of the call. /// Interface for call-like operations. -def CallOpInterface : OpInterface<"CallOpInterface"> { +def CallOpInterface : OpInterface<"CallOpInterface", + [ArgAndResultAttrsOpInterface]> { let description = [{ A call-like operation is one that transfers control from one sub-routine to another. These operations may be traditional direct calls `call @foo`, or @@ -123,11 +130,12 @@ def CallOpInterface : OpInterface<"CallOpInterface"> { return ::mlir::call_interface_impl::resolveCallable($_op); }] > - ] # ArgumentAttributesMethods.methods; + ]; } /// Interface for callable operations. -def CallableOpInterface : OpInterface<"CallableOpInterface"> { +def CallableOpInterface : OpInterface<"CallableOpInterface", + [ArgAndResultAttrsOpInterface]> { let description = [{ A callable operation is one who represents a potential sub-routine, and may be a target for a call-like operation (those providing the CallOpInterface @@ -140,11 +148,11 @@ def CallableOpInterface : OpInterface<"CallableOpInterface"> { let methods = [ InterfaceMethod<[{ - Returns the region on the current operation that is callable. This may - return null in the case of an external callable object, e.g. an external - function. - }], - "::mlir::Region *", "getCallableRegion">, + Returns the region on the current operation that is callable. This may + return null in the case of an external callable object, e.g. an external + function. + }], + "::mlir::Region *", "getCallableRegion">, InterfaceMethod<[{ Returns the callable's argument types based exclusively on the type (to allow for this method may be called on function declarations). @@ -155,7 +163,7 @@ def CallableOpInterface : OpInterface<"CallableOpInterface"> { allow for this method may be called on function declarations). }], "::llvm::ArrayRef<::mlir::Type>", "getResultTypes">, - ] # ArgumentAttributesMethods.methods; + ]; } #endif // MLIR_INTERFACES_CALLINTERFACES diff --git a/mlir/include/mlir/Support/ToolUtilities.h b/mlir/include/mlir/Support/ToolUtilities.h index cb6ba29..657f117 100644 --- a/mlir/include/mlir/Support/ToolUtilities.h +++ b/mlir/include/mlir/Support/ToolUtilities.h @@ -21,10 +21,16 @@ namespace llvm { class MemoryBuffer; +class MemoryBufferRef; } // namespace llvm namespace mlir { +// A function that processes a chunk of a buffer and writes the result to an +// output stream. using ChunkBufferHandler = function_ref<LogicalResult( + std::unique_ptr<llvm::MemoryBuffer> chunkBuffer, + const llvm::MemoryBufferRef &sourceBuffer, raw_ostream &os)>; +using NoSourceChunkBufferHandler = function_ref<LogicalResult( std::unique_ptr<llvm::MemoryBuffer> chunkBuffer, raw_ostream &os)>; extern inline const char *const kDefaultSplitMarker = "// -----"; @@ -45,6 +51,15 @@ splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer, ChunkBufferHandler processChunkBuffer, raw_ostream &os, llvm::StringRef inputSplitMarker = kDefaultSplitMarker, llvm::StringRef outputSplitMarker = ""); + +/// Same as above, but for case where the original buffer is not used while +/// processing the chunk. +LogicalResult +splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer, + NoSourceChunkBufferHandler processChunkBuffer, + raw_ostream &os, + llvm::StringRef inputSplitMarker = kDefaultSplitMarker, + llvm::StringRef outputSplitMarker = ""); } // namespace mlir #endif // MLIR_SUPPORT_TOOLUTILITIES_H diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h index 60615cf6..e4670cb 100644 --- a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h +++ b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h @@ -28,6 +28,7 @@ #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/SPIRV/SPIRVToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/VCIX/VCIXToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h" namespace mlir { class DialectRegistry; @@ -47,6 +48,7 @@ static inline void registerAllToLLVMIRTranslations(DialectRegistry ®istry) { registerROCDLDialectTranslation(registry); registerSPIRVDialectTranslation(registry); registerVCIXDialectTranslation(registry); + registerXeVMDialectTranslation(registry); // Extension required for translating GPU offloading Ops. gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(registry); @@ -63,6 +65,7 @@ registerAllGPUToLLVMIRTranslations(DialectRegistry ®istry) { registerNVVMDialectTranslation(registry); registerROCDLDialectTranslation(registry); registerSPIRVDialectTranslation(registry); + registerXeVMDialectTranslation(registry); // Extension required for translating GPU offloading Ops. gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(registry); diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h new file mode 100644 index 0000000..b4f6750 --- /dev/null +++ b/mlir/include/mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h @@ -0,0 +1,31 @@ +//===-- XeVMToLLVMIRTranslation.h - XeVM to LLVM IR -------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This provides registration calls for XeVM dialect to LLVM IR translation. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TARGET_LLVMIR_DIALECT_XEVM_XEVMTOLLVMIRTRANSLATION_H +#define MLIR_TARGET_LLVMIR_DIALECT_XEVM_XEVMTOLLVMIRTRANSLATION_H + +namespace mlir { + +class DialectRegistry; +class MLIRContext; + +/// Register the XeVM dialect and the translation from it to the LLVM IR in the +/// given registry; +void registerXeVMDialectTranslation(mlir::DialectRegistry ®istry); + +/// Register the XeVM dialect and the translation from it in the registry +/// associated with the given context. +void registerXeVMDialectTranslation(mlir::MLIRContext &context); + +} // namespace mlir + +#endif // MLIR_TARGET_LLVMIR_DIALECT_XEVM_XEVMTOLLVMIRTRANSLATION_H diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h index 17ef8e4..b22ed60 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h @@ -291,10 +291,12 @@ public: SmallVectorImpl<Value> &valuesOut, SmallVectorImpl<NamedAttribute> &attrsOut); - /// Converts the parameter and result attributes in `argsAttr` and `resAttr` - /// and add them to the `callOp`. - void convertParameterAttributes(llvm::CallBase *call, ArrayAttr &argsAttr, - ArrayAttr &resAttr, OpBuilder &builder); + /// Converts the argument and result attributes attached to `call` and adds + /// them to `attrsOp`. For intrinsic calls, filters out attributes + /// corresponding to immediate arguments specified by `immArgPositions`. + void convertArgAndResultAttrs(llvm::CallBase *call, + ArgAndResultAttrsOpInterface attrsOp, + ArrayRef<unsigned> immArgPositions = {}); /// Whether the importer should try to convert all intrinsics to /// llvm.call_intrinsic instead of dialect supported operations. @@ -378,19 +380,12 @@ private: bool &isIncompatibleCall); /// Returns the callee name, or an empty symbol if the call is not direct. FlatSymbolRefAttr convertCalleeName(llvm::CallBase *callInst); - /// Converts the parameter and result attributes attached to `func` and adds + /// Converts the argument and result attributes attached to `func` and adds /// them to the `funcOp`. - void convertParameterAttributes(llvm::Function *func, LLVMFuncOp funcOp, - OpBuilder &builder); - /// Converts the AttributeSet of one parameter in LLVM IR to a corresponding - /// DictionaryAttr for the LLVM dialect. - DictionaryAttr convertParameterAttribute(llvm::AttributeSet llvmParamAttrs, - OpBuilder &builder); - /// Converts the parameter and result attributes attached to `call` and adds - /// them to the `callOp`. Implemented in terms of the the public definition of - /// convertParameterAttributes. - void convertParameterAttributes(llvm::CallBase *call, CallOpInterface callOp, - OpBuilder &builder); + void convertArgAndResultAttrs(llvm::Function *func, LLVMFuncOp funcOp); + /// Converts the argument or result attributes in `llvmAttrSet` to a + /// corresponding MLIR LLVM dialect attribute dictionary. + DictionaryAttr convertArgOrResultAttrSet(llvm::AttributeSet llvmAttrSet); /// Converts the attributes attached to `inst` and adds them to the `op`. LogicalResult convertCallAttributes(llvm::CallInst *inst, CallOp op); /// Converts the attributes attached to `inst` and adds them to the `op`. diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h index f3f73f4..eb7dfa7 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -25,11 +25,13 @@ #include "mlir/Target/LLVMIR/TypeToLLVM.h" #include "llvm/ADT/SetVector.h" -#include "llvm/Frontend/OpenMP/OMPIRBuilder.h" #include "llvm/IR/FPEnv.h" +#include "llvm/IR/Module.h" namespace llvm { class BasicBlock; +class CallBase; +class CanonicalLoopInfo; class Function; class IRBuilderBase; class OpenMPIRBuilder; @@ -306,10 +308,16 @@ public: /*recordInsertions=*/false); } - /// Translates parameter attributes of a call and adds them to the returned - /// AttrBuilder. Returns failure if any of the translations failed. - FailureOr<llvm::AttrBuilder> convertParameterAttrs(mlir::Location loc, - DictionaryAttr paramAttrs); + /// Converts argument and result attributes from `attrsOp` to LLVM IR + /// attributes on the `call` instruction. Returns failure if conversion fails. + /// The `immArgPositions` parameter is only relevant for intrinsics. It + /// specifies the positions of immediate arguments, which do not have + /// associated argument attributes in MLIR and should be skipped during + /// attribute mapping. + LogicalResult + convertArgAndResultAttrs(ArgAndResultAttrsOpInterface attrsOp, + llvm::CallBase *call, + ArrayRef<unsigned> immArgPositions = {}); /// Gets the named metadata in the LLVM IR module being constructed, creating /// it if it does not exist. @@ -389,6 +397,11 @@ private: convertDialectAttributes(Operation *op, ArrayRef<llvm::Instruction *> instructions); + /// Translates parameter attributes of a call and adds them to the returned + /// AttrBuilder. Returns failure if any of the translations failed. + FailureOr<llvm::AttrBuilder> convertParameterAttrs(mlir::Location loc, + DictionaryAttr paramAttrs); + /// Translates parameter attributes of a function and adds them to the /// returned AttrBuilder. Returns failure if any of the translations failed. FailureOr<llvm::AttrBuilder> |