Diffstat (limited to 'mlir')
113 files changed, 2695 insertions, 836 deletions
diff --git a/mlir/docs/DefiningDialects/AttributesAndTypes.md b/mlir/docs/DefiningDialects/AttributesAndTypes.md index 022bdad..b991863 100644 --- a/mlir/docs/DefiningDialects/AttributesAndTypes.md +++ b/mlir/docs/DefiningDialects/AttributesAndTypes.md @@ -136,7 +136,7 @@ def My_IntegerAttr : MyDialect_Attr<"Integer", "int"> { /// Here we've defined two parameters, one is a "self" type parameter, and the /// other is the integer value of the attribute. The self type parameter is /// specially handled by the assembly format. - let parameters = (ins AttributeSelfTypeParameter<"">:$type, "APInt":$value); + let parameters = (ins AttributeSelfTypeParameter<"">:$type, APIntParameter<"">:$value); /// Here we've defined a custom builder for the type, that removes the need to pass /// in an MLIRContext instance; as it can be infered from the `type`. diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index cf7596c..6e1baaf 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -196,6 +196,10 @@ def ConvertArithToSPIRVPass : Pass<"convert-arith-to-spirv"> { "bool", /*default=*/"true", "Emulate narrower scalar types with 32-bit ones if not supported by " "the target">, + Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types", + "bool", /*default=*/"true", + "Emulate unsupported float types by representing them with integer " + "types of same bit width"> ]; } @@ -416,7 +420,11 @@ def ConvertControlFlowToSPIRVPass : Pass<"convert-cf-to-spirv"> { Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types", "bool", /*default=*/"true", "Emulate narrower scalar types with 32-bit ones if not supported by" - " the target"> + " the target">, + Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types", + "bool", /*default=*/"true", + "Emulate unsupported float types by representing them with integer " + "types of same bit width"> ]; } @@ -500,7 +508,11 @@ def ConvertFuncToSPIRVPass : Pass<"convert-func-to-spirv"> { Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types", "bool", /*default=*/"true", "Emulate narrower scalar types with 32-bit ones if not supported by" - " the target"> + " the target">, + Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types", + "bool", /*default=*/"true", + "Emulate unsupported float types by representing them with integer " + "types of same bit width"> ]; } @@ -1167,7 +1179,11 @@ def ConvertTensorToSPIRVPass : Pass<"convert-tensor-to-spirv"> { Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types", "bool", /*default=*/"true", "Emulate narrower scalar types with 32-bit ones if not supported by" - " the target"> + " the target">, + Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types", + "bool", /*default=*/"true", + "Emulate unsupported float types by representing them with integer " + "types of same bit width"> ]; } diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td index e81db32..06fb851 100644 --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td @@ -71,6 +71,7 @@ class ArmSME_IntrOp<string mnemonic, /*bit requiresAccessGroup=*/0, /*bit requiresAliasAnalysis=*/0, /*bit requiresFastmath=*/0, + /*bit requiresArgAndResultAttrs=*/0, /*bit requiresOpBundles=*/0, /*list<int> immArgPositions=*/immArgPositions, /*list<string> 
immArgAttrNames=*/immArgAttrNames>; diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td index 8988df6..d055bb4 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td +++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td @@ -92,6 +92,7 @@ class ArmSVE_IntrOp<string mnemonic, /*bit requiresAccessGroup=*/0, /*bit requiresAliasAnalysis=*/0, /*bit requiresFastmath=*/0, + /*bit requiresArgAndResultAttrs=*/0, /*bit requiresOpBundles=*/0, /*list<int> immArgPositions=*/immArgPositions, /*list<string> immArgAttrNames=*/immArgAttrNames>; diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td index a8455c2..b52f136 100644 --- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td @@ -38,7 +38,8 @@ def Async_ExecuteOp : ["getEntrySuccessorOperands", "areTypesCompatible"]>, AttrSizedOperandSegments, - AutomaticAllocationScope]> { + AutomaticAllocationScope, + RecursiveMemoryEffects]> { let summary = "Asynchronous execute operation"; let description = [{ The `body` region attached to the `async.execute` operation semantically diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td index b5ea8fc..107bf3e 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td @@ -83,6 +83,9 @@ def LLVM_Dialect : Dialect { return "llvm.emit_c_interface"; } + /// Name of the module level assembly attribute. + static StringRef getModuleLevelAsmAttrName() { return "llvm.module_asm"; } + /// Name of the dependent libraries attribute. static StringRef getDependentLibrariesAttrName() { return "llvm.dependent_libraries"; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td index 8c6f1ee..d38298f 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td @@ -140,8 +140,8 @@ def LLVM_Log2Op : LLVM_UnaryIntrOpF<"log2">; def LLVM_LogOp : LLVM_UnaryIntrOpF<"log">; def LLVM_Prefetch : LLVM_ZeroResultIntrOp<"prefetch", [0], /*traits=*/[], /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0, - /*requiresOpBundles=*/0, /*immArgPositions=*/[1, 2, 3], - /*immArgAttrNames=*/["rw", "hint", "cache"] + /*requiresArgAndResultAttrs=*/0, /*requiresOpBundles=*/0, + /*immArgPositions=*/[1, 2, 3], /*immArgAttrNames=*/["rw", "hint", "cache"] > { let arguments = (ins LLVM_AnyPointer:$addr, I32Attr:$rw, I32Attr:$hint, I32Attr:$cache); } @@ -200,13 +200,13 @@ class LLVM_MemcpyIntrOpBase<string name> : DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>, DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>], /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1, - /*requiresOpBundles=*/0, /*immArgPositions=*/[3], - /*immArgAttrNames=*/["isVolatile"]> { + /*requiresArgAndResultAttrs=*/1, /*requiresOpBundles=*/0, + /*immArgPositions=*/[3], /*immArgAttrNames=*/["isVolatile"]> { dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst, Arg<LLVM_AnyPointer,"",[MemRead]>:$src, AnySignlessInteger:$len, I1Attr:$isVolatile); - // Append the alias attributes defined by LLVM_IntrOpBase. - let arguments = !con(args, aliasAttrs); + // Append the arguments defined by LLVM_IntrOpBase. 
+ let arguments = !con(args, baseArgs); let builders = [ OpBuilder<(ins "Value":$dst, "Value":$src, "Value":$len, "bool":$isVolatile), [{ @@ -217,7 +217,8 @@ class LLVM_MemcpyIntrOpBase<string name> : "IntegerAttr":$isVolatile), [{ build($_builder, $_state, dst, src, len, isVolatile, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, - /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); + /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr); }]> ]; } @@ -231,13 +232,13 @@ def LLVM_MemcpyInlineOp : DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>, DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>], /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1, - /*requiresOpBundles=*/0, /*immArgPositions=*/[2, 3], - /*immArgAttrNames=*/["len", "isVolatile"]> { + /*requiresArgAndResultAttrs=*/1, /*requiresOpBundles=*/0, + /*immArgPositions=*/[2, 3], /*immArgAttrNames=*/["len", "isVolatile"]> { dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst, Arg<LLVM_AnyPointer,"",[MemRead]>:$src, APIntAttr:$len, I1Attr:$isVolatile); - // Append the alias attributes defined by LLVM_IntrOpBase. - let arguments = !con(args, aliasAttrs); + // Append the arguments defined by LLVM_IntrOpBase. + let arguments = !con(args, baseArgs); let builders = [ OpBuilder<(ins "Value":$dst, "Value":$src, "IntegerAttr":$len, "bool":$isVolatile), [{ @@ -248,7 +249,8 @@ def LLVM_MemcpyInlineOp : "IntegerAttr":$isVolatile), [{ build($_builder, $_state, dst, src, len, isVolatile, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, - /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); + /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr); }]> ]; } @@ -258,12 +260,12 @@ def LLVM_MemsetOp : LLVM_ZeroResultIntrOp<"memset", [0, 2], DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>, DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>], /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1, - /*requiresOpBundles=*/0, /*immArgPositions=*/[3], - /*immArgAttrNames=*/["isVolatile"]> { + /*requiresArgAndResultAttrs=*/1, /*requiresOpBundles=*/0, + /*immArgPositions=*/[3], /*immArgAttrNames=*/["isVolatile"]> { dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst, I8:$val, AnySignlessInteger:$len, I1Attr:$isVolatile); - // Append the alias attributes defined by LLVM_IntrOpBase. - let arguments = !con(args, aliasAttrs); + // Append the arguments defined by LLVM_IntrOpBase. 
+ let arguments = !con(args, baseArgs); let builders = [ OpBuilder<(ins "Value":$dst, "Value":$val, "Value":$len, "bool":$isVolatile), [{ @@ -274,7 +276,8 @@ def LLVM_MemsetOp : LLVM_ZeroResultIntrOp<"memset", [0, 2], "IntegerAttr":$isVolatile), [{ build($_builder, $_state, dst, val, len, isVolatile, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, - /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); + /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr); }]> ]; } @@ -284,12 +287,12 @@ def LLVM_MemsetInlineOp : LLVM_ZeroResultIntrOp<"memset.inline", [0, 2], DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>, DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>], /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1, - /*requiresOpBundles=*/0, /*immArgPositions=*/[2, 3], - /*immArgAttrNames=*/["len", "isVolatile"]> { + /*requiresArgAndResultAttrs=*/1, /*requiresOpBundles=*/0, + /*immArgPositions=*/[2, 3], /*immArgAttrNames=*/["len", "isVolatile"]> { dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst, I8:$val, APIntAttr:$len, I1Attr:$isVolatile); - // Append the alias attributes defined by LLVM_IntrOpBase. - let arguments = !con(args, aliasAttrs); + // Append the arguments defined by LLVM_IntrOpBase. + let arguments = !con(args, baseArgs); let builders = [ OpBuilder<(ins "Value":$dst, "Value":$val, "IntegerAttr":$len, "bool":$isVolatile), [{ @@ -300,7 +303,8 @@ def LLVM_MemsetInlineOp : LLVM_ZeroResultIntrOp<"memset.inline", [0, 2], "IntegerAttr":$isVolatile), [{ build($_builder, $_state, dst, val, len, isVolatile, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, - /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); + /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr); }]> ]; } @@ -349,8 +353,8 @@ def LLVM_PtrMaskOp class LLVM_LifetimeBaseOp<string opName> : LLVM_ZeroResultIntrOp<opName, [1], [DeclareOpInterfaceMethods<PromotableOpInterface>], /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0, - /*requiresOpBundles=*/0, /*immArgPositions=*/[0], - /*immArgAttrNames=*/["size"]> { + /*requiresArgAndResultAttrs=*/0, /*requiresOpBundles=*/0, + /*immArgPositions=*/[0], /*immArgAttrNames=*/["size"]> { let arguments = (ins I64Attr:$size, LLVM_AnyPointer:$ptr); let assemblyFormat = "$size `,` $ptr attr-dict `:` qualified(type($ptr))"; } @@ -370,8 +374,8 @@ def LLVM_InvariantStartOp : LLVM_OneResultIntrOp<"invariant.start", [], [1], def LLVM_InvariantEndOp : LLVM_ZeroResultIntrOp<"invariant.end", [2], [DeclareOpInterfaceMethods<PromotableOpInterface>], /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0, - /*requiresOpBundles=*/0, /*immArgPositions=*/[1], - /*immArgAttrNames=*/["size"]> { + /*requiresArgAndResultAttrs=*/0, /*requiresOpBundles=*/0, + /*immArgPositions=*/[1], /*immArgAttrNames=*/["size"]> { let arguments = (ins LLVM_DefaultPointer:$start, I64Attr:$size, LLVM_AnyPointer:$ptr); @@ -542,9 +546,10 @@ def LLVM_AssumeOp : LLVM_ZeroResultIntrOp<"assume", /*overloadedOperands=*/[], /*traits=*/[], /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0, + /*requiresArgAndResultAttrs=*/0, /*requiresOpBundles=*/1> { dag args = (ins I1:$cond); - let arguments = !con(args, opBundleArgs); + let arguments = !con(args, baseArgs); let assemblyFormat = [{ $cond @@ -1126,8 +1131,8 @@ def LLVM_DebugTrap : LLVM_ZeroResultIntrOp<"debugtrap">; def LLVM_UBSanTrap : LLVM_ZeroResultIntrOp<"ubsantrap", /*overloadedOperands=*/[], /*traits=*/[], /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0, - 
/*requiresOpBundles=*/0, /*immArgPositions=*/[0], - /*immArgAttrNames=*/["failureKind"]> { + /*requiresArgAndResultAttrs=*/0, /*requiresOpBundles=*/0, + /*immArgPositions=*/[0], /*immArgAttrNames=*/["failureKind"]> { let arguments = (ins I8Attr:$failureKind); } diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td index e845ea9f..a8d7cf2 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -18,6 +18,7 @@ include "mlir/Dialect/LLVMIR/LLVMAttrDefs.td" include "mlir/Dialect/LLVMIR/LLVMInterfaces.td" include "mlir/IR/OpBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/CallInterfaces.td" //===----------------------------------------------------------------------===// // LLVM dialect type constraints. @@ -286,22 +287,26 @@ class LLVM_MemAccessOpBase<string mnemonic, list<Trait> traits = []> : // intrinsic and "enumName" contains the name of the intrinsic as appears in // `llvm::Intrinsic` enum; one usually wants these to be related. Additionally, // the base class also defines the "mlirBuilder" field to support the inverse -// translation starting from an LLVM IR intrinsic. The "requiresAccessGroup", -// "requiresAliasAnalysis", and "requiresFastmath" flags specify which -// interfaces the intrinsic implements. If the corresponding flags are set, the -// "aliasAttrs" list contains the arguments required by the access group and -// alias analysis interfaces. Derived intrinsics should append the "aliasAttrs" -// to their argument list if they set one of the flags. LLVM `immargs` can be -// represented as MLIR attributes by providing both the `immArgPositions` and -// `immArgAttrNames` lists. These two lists should have equal length, with -// `immArgPositions` containing the argument positions on the LLVM IR attribute -// that are `immargs`, and `immArgAttrNames` mapping these to corresponding -// MLIR attributes. +// translation starting from an LLVM IR intrinsic. +// +// The flags "requiresAccessGroup", "requiresAliasAnalysis", +// "requiresFastmath", and "requiresArgAndResultAttrs" indicate which +// interfaces the intrinsic implements. When a flag is set, the "baseArgs" +// list includes the arguments required by the corresponding interface. +// Derived intrinsics must append "baseArgs" to their argument list if they +// enable any of these flags. +// +// LLVM `immargs` can be represented as MLIR attributes by providing both +// the `immArgPositions` and `immArgAttrNames` lists. These two lists should +// have equal length, with `immArgPositions` containing the argument +// positions on the LLVM IR attribute that are `immargs`, and +// `immArgAttrNames` mapping these to corresponding MLIR attributes. 
class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName, list<int> overloadedResults, list<int> overloadedOperands, list<Trait> traits, int numResults, bit requiresAccessGroup = 0, bit requiresAliasAnalysis = 0, - bit requiresFastmath = 0, bit requiresOpBundles = 0, + bit requiresFastmath = 0, bit requiresArgAndResultAttrs = 0, + bit requiresOpBundles = 0, list<int> immArgPositions = [], list<string> immArgAttrNames = []> : LLVM_OpBase<dialect, opName, !listconcat( @@ -311,10 +316,12 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName, [DeclareOpInterfaceMethods<AliasAnalysisOpInterface>], []), !if(!gt(requiresFastmath, 0), [DeclareOpInterfaceMethods<FastmathFlagsInterface>], []), + !if(!gt(requiresArgAndResultAttrs, 0), + [DeclareOpInterfaceMethods<ArgAndResultAttrsOpInterface>], []), traits)>, LLVM_MemOpPatterns, Results<!if(!gt(numResults, 0), (outs LLVM_Type:$res), (outs))> { - dag aliasAttrs = !con( + dag baseArgs = !con( !if(!gt(requiresAccessGroup, 0), (ins OptionalAttr<LLVM_AccessGroupArrayAttr>:$access_groups), (ins )), @@ -322,13 +329,17 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName, (ins OptionalAttr<LLVM_AliasScopeArrayAttr>:$alias_scopes, OptionalAttr<LLVM_AliasScopeArrayAttr>:$noalias_scopes, OptionalAttr<LLVM_TBAATagArrayAttr>:$tbaa), + (ins )), + !if(!gt(requiresArgAndResultAttrs, 0), + (ins OptionalAttr<DictArrayAttr>:$arg_attrs, + OptionalAttr<DictArrayAttr>:$res_attrs), + (ins )), + !if(!gt(requiresOpBundles, 0), + (ins VariadicOfVariadic<LLVM_Type, + "op_bundle_sizes">:$op_bundle_operands, + DenseI32ArrayAttr:$op_bundle_sizes, + OptionalAttr<ArrayAttr>:$op_bundle_tags), (ins ))); - dag opBundleArgs = !if(!gt(requiresOpBundles, 0), - (ins VariadicOfVariadic<LLVM_Type, - "op_bundle_sizes">:$op_bundle_operands, - DenseI32ArrayAttr:$op_bundle_sizes, - OptionalAttr<ArrayAttr>:$op_bundle_tags), - (ins )); string llvmEnumName = enumName; string overloadedResultsCpp = "{" # !interleave(overloadedResults, ", ") # "}"; string overloadedOperandsCpp = "{" # !interleave(overloadedOperands, ", ") # "}"; @@ -342,23 +353,35 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName, immArgPositionsCpp, immArgAttrNamesCpp], ",") # [{); (void) inst; }]; + string baseLlvmBuilderArgAndResultAttrs = [{ + if (failed(moduleTranslation.convertArgAndResultAttrs( + op, + inst, + }] # immArgPositionsCpp # [{))) { + return failure(); + } + }]; string baseLlvmBuilderCoda = !if(!gt(numResults, 0), "$res = inst;", ""); - let llvmBuilder = baseLlvmBuilder # !if(!gt(requiresAccessGroup, 0), setAccessGroupsMetadataCode, "") - # !if(!gt(requiresAliasAnalysis, 0), setAliasAnalysisMetadataCode, "") - # baseLlvmBuilderCoda; + let llvmBuilder = baseLlvmBuilder + # !if(!gt(requiresAccessGroup, 0), + setAccessGroupsMetadataCode, "") + # !if(!gt(requiresAliasAnalysis, 0), + setAliasAnalysisMetadataCode, "") + # !if(!gt(requiresArgAndResultAttrs, 0), + baseLlvmBuilderArgAndResultAttrs, "") + # baseLlvmBuilderCoda; string baseMlirBuilder = [{ SmallVector<Value> mlirOperands; SmallVector<NamedAttribute> mlirAttrs; if (failed(moduleImport.convertIntrinsicArguments( - llvmOperands, - llvmOpBundles, - }] # !if(!gt(requiresOpBundles, 0), "true", "false") # [{, - }] # immArgPositionsCpp # [{, - }] # immArgAttrNamesCpp # [{, - mlirOperands, - mlirAttrs)) - ) { + llvmOperands, + llvmOpBundles, + }] # !if(!gt(requiresOpBundles, 0), "true", "false") # [{, + }] # immArgPositionsCpp # [{, + }] # immArgAttrNamesCpp # [{, + mlirOperands, + mlirAttrs))) 
{ return failure(); } SmallVector<Type> resultTypes = @@ -366,9 +389,16 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName, auto op = $_qualCppClassName::create($_builder, $_location, resultTypes, mlirOperands, mlirAttrs); }]; + string baseMlirBuilderArgAndResultAttrs = [{ + moduleImport.convertArgAndResultAttrs( + inst, op, }] # immArgPositionsCpp # [{); + }]; string baseMlirBuilderCoda = !if(!gt(numResults, 0), "$res = op;", "$_op = op;"); - let mlirBuilder = baseMlirBuilder # !if(!gt(requiresFastmath, 0), + let mlirBuilder = baseMlirBuilder + # !if(!gt(requiresFastmath, 0), "moduleImport.setFastmathFlagsAttr(inst, op);", "") + # !if(!gt(requiresArgAndResultAttrs, 0), + baseMlirBuilderArgAndResultAttrs, "") # baseMlirBuilderCoda; // Code for handling a `range` attribute that holds the constant range of the @@ -399,14 +429,14 @@ class LLVM_IntrOp<string mnem, list<int> overloadedResults, list<int> overloadedOperands, list<Trait> traits, int numResults, bit requiresAccessGroup = 0, bit requiresAliasAnalysis = 0, bit requiresFastmath = 0, - bit requiresOpBundles = 0, + bit requiresArgAndResultAttrs = 0, bit requiresOpBundles = 0, list<int> immArgPositions = [], list<string> immArgAttrNames = []> : LLVM_IntrOpBase<LLVM_Dialect, "intr." # mnem, !subst(".", "_", mnem), overloadedResults, overloadedOperands, traits, numResults, requiresAccessGroup, requiresAliasAnalysis, - requiresFastmath, requiresOpBundles, immArgPositions, - immArgAttrNames>; + requiresFastmath, requiresArgAndResultAttrs, + requiresOpBundles, immArgPositions, immArgAttrNames>; // Base class for LLVM intrinsic operations returning no results. Places the // intrinsic into the LLVM dialect and prefixes its name with "intr.". @@ -426,13 +456,14 @@ class LLVM_ZeroResultIntrOp<string mnem, list<int> overloadedOperands = [], list<Trait> traits = [], bit requiresAccessGroup = 0, bit requiresAliasAnalysis = 0, + bit requiresArgAndResultAttrs = 0, bit requiresOpBundles = 0, list<int> immArgPositions = [], list<string> immArgAttrNames = []> : LLVM_IntrOp<mnem, [], overloadedOperands, traits, /*numResults=*/0, requiresAccessGroup, requiresAliasAnalysis, - /*requiresFastMath=*/0, requiresOpBundles, immArgPositions, - immArgAttrNames>; + /*requiresFastMath=*/0, requiresArgAndResultAttrs, + requiresOpBundles, immArgPositions, immArgAttrNames>; // Base class for LLVM intrinsic operations returning one result. Places the // intrinsic into the LLVM dialect and prefixes its name with "intr.". This is @@ -448,7 +479,8 @@ class LLVM_OneResultIntrOp<string mnem, list<int> overloadedResults = [], list<string> immArgAttrNames = []> : LLVM_IntrOp<mnem, overloadedResults, overloadedOperands, traits, 1, /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0, - requiresFastmath, /*requiresOpBundles=*/0, immArgPositions, + requiresFastmath, /*requiresArgAndResultAttrs=*/0, + /*requiresOpBundles=*/0, immArgPositions, immArgAttrNames>; // Base class for LLVM intrinsic operations returning two results. 
Places the @@ -465,7 +497,8 @@ class LLVM_TwoResultIntrOp<string mnem, list<int> overloadedResults = [], list<string> immArgAttrNames = []> : LLVM_IntrOp<mnem, overloadedResults, overloadedOperands, traits, 2, /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0, - requiresFastmath, /*requiresOpBundles=*/0, immArgPositions, + requiresFastmath, /*requiresArgAndResultAttrs=*/0, + /*requiresOpBundles=*/0, immArgPositions, immArgAttrNames>; def LLVM_OneResultOpBuilder : diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 51004f5..3f27f6d 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -2405,7 +2405,8 @@ def LLVM_InlineAsmOp : LLVM_Op<"inline_asm", [DeclareOpInterfaceMethods<MemoryEf def LLVM_CallIntrinsicOp : LLVM_Op<"call_intrinsic", - [AttrSizedOperandSegments, + [ArgAndResultAttrsOpInterface, + AttrSizedOperandSegments, DeclareOpInterfaceMethods<FastmathFlagsInterface>]> { let summary = "Call to an LLVM intrinsic function."; let description = [{ diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index 04a0b58..a2354e2 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -98,7 +98,7 @@ class ROCDL_IntrOp<string mnemonic, list<int> overloadedResults, LLVM_IntrOpBase<ROCDL_Dialect, mnemonic, "amdgcn_" # !subst(".", "_", mnemonic), overloadedResults, overloadedOperands, traits, numResults, requiresAccessGroup, - requiresAliasAnalysis, 0, 0, immArgPositions, immArgAttrNames>; + requiresAliasAnalysis, 0, 0, 0, immArgPositions, immArgAttrNames>; // Subclass to save typing and ease readibility when there aren't overloaded // operands or memory accesses. 
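The LLVMOpBase.td and LLVMOps.td hunks above collapse the access-group, alias-analysis, argument/result-attribute, and op-bundle operands into a single `baseArgs` dag, and let intrinsics opt into the new `ArgAndResultAttrsOpInterface` via the `requiresArgAndResultAttrs` flag (the memcpy/memset family enables it). As a rough C++ sketch of what client code could then do with such an op; the helper name, the `setArgAttrsAttr` accessor, and the `llvm.nonnull` spelling are assumptions for illustration, not part of this patch:

```cpp
// Rough sketch, not from the patch: attach a parameter attribute to the
// destination pointer of an llvm.intr.memcpy through the interface added here.
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/Interfaces/CallInterfaces.h"

using namespace mlir;

static void annotateMemcpyDst(OpBuilder &builder, LLVM::MemcpyOp memcpyOp) {
  // One attribute dictionary per operand (dst, src, len); "llvm.nonnull" and
  // setArgAttrsAttr are assumed spellings.
  SmallVector<Attribute> argAttrs = {
      builder.getDictionaryAttr(
          builder.getNamedAttr("llvm.nonnull", builder.getUnitAttr())),
      builder.getDictionaryAttr({}), builder.getDictionaryAttr({})};
  auto attrsOp = cast<ArgAndResultAttrsOpInterface>(memcpyOp.getOperation());
  attrsOp.setArgAttrsAttr(builder.getArrayAttr(argAttrs));
}
```

On translation, the `convertArgAndResultAttrs` hooks on ModuleImport and ModuleTranslation (renamed later in this diff) round-trip these dictionaries to LLVM IR parameter attributes, skipping the positions listed in `immArgPositions`.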
@@ -482,7 +482,7 @@ def ROCDLBufferLDS : LLVM_PointerInAddressSpace<3>; class ROCDL_LDS_Read_Tr_IntrOp<string mnemonic> : ROCDL_IntrOp<mnemonic, [1], [], [], 1, 0, 1> { dag args = (ins Arg<ROCDLBufferLDS, "", [MemRead]>:$ptr); - let arguments = !con(args, aliasAttrs); + let arguments = !con(args, baseArgs); let assemblyFormat = "$ptr attr-dict `:` type($ptr) `->` type($res)"; let extraClassDefinition = [{ ::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() { @@ -507,7 +507,7 @@ def ROCDL_LoadToLDSOp : I32Attr:$size, I32Attr:$offset, I32Attr:$aux); - let arguments = !con(args, aliasAttrs); + let arguments = !con(args, baseArgs); let assemblyFormat = [{ $globalPtr `,` $ldsPtr `,` $size `,` $offset `,` $aux attr-dict `:` type($globalPtr) @@ -526,7 +526,7 @@ def ROCDL_GlobalLoadLDSOp : I32Attr:$size, I32Attr:$offset, I32Attr:$aux); - let arguments = !con(args, aliasAttrs); + let arguments = !con(args, baseArgs); let assemblyFormat = [{ $globalPtr `,` $ldsPtr `,` $size `,` $offset `,` $aux attr-dict @@ -561,7 +561,7 @@ def ROCDL_RawPtrBufferLoadOp : I32:$offset, I32:$soffset, I32:$aux); - let arguments = !con(args, aliasAttrs); + let arguments = !con(args, baseArgs); let assemblyFormat = "operands attr-dict `:` type($res)"; let extraClassDefinition = [{ ::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() { @@ -579,7 +579,7 @@ def ROCDL_RawPtrBufferLoadLdsOp : I32:$soffset, I32:$offset, I32:$aux); - let arguments = !con(args, aliasAttrs); + let arguments = !con(args, baseArgs); let assemblyFormat = "operands attr-dict"; let extraClassDefinition = [{ ::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() { @@ -595,7 +595,7 @@ def ROCDL_RawPtrBufferStoreOp : I32:$offset, I32:$soffset, I32:$aux); - let arguments = !con(args, aliasAttrs); + let arguments = !con(args, baseArgs); let assemblyFormat = "operands attr-dict `:` type($vdata)"; let extraClassDefinition = [{ ::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() { @@ -614,7 +614,7 @@ def ROCDL_RawPtrBufferAtomicCmpSwap : I32:$offset, I32:$soffset, I32:$aux); - let arguments = !con(args, aliasAttrs); + let arguments = !con(args, baseArgs); let assemblyFormat = "operands attr-dict `:` type($res)"; let extraClassDefinition = [{ ::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() { @@ -630,7 +630,7 @@ class ROCDL_RawPtrBufferAtomicNoRet<string op> : I32:$offset, I32:$soffset, I32:$aux); - let arguments = !con(args, aliasAttrs); + let arguments = !con(args, baseArgs); let assemblyFormat = "operands attr-dict `:` type($vdata)"; let extraClassDefinition = [{ ::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() { diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 8d45c40..61ce23f 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -1191,6 +1191,7 @@ def PadTilingInterfaceOp : Op<Transform_Dialect, "structured.pad_tiling_interfac iteration domain induces a padding of the operands that is consistent across the op semantics and, unlike for simple elementwise ops, may not be trivially deducible or specifiable on operands only (e.g. convolutions). + Currently, only a limited set of projected permutation maps are supported. 
The specification of `padding_sizes` follows that of `tile_sizes` during tiling: the value "0" on a particular iterator encode "no padding". Like in diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index e625eef..d4ffe0a 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -611,6 +611,13 @@ LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, /// affine.apply operations. /// The `indexingMap` + `indexingSizes` encoding suits StructuredOps and /// provides a gentle portability path for Linalg-like ops with affine maps. +/// The padded shape is computed by evaluating the maximum accessed index per +/// dimension, which may involve multiplying by constant factors derived from +/// the affine indexing expressions. Currently, only a limited set of projected +/// permuation indexing maps are supported, such as +/// - affine_map<(d0, d1, d2) -> (d0, d1)> +/// - affine_map<(d0, d1, d2) -> (d0, d1 + d2)> +/// - affine_map<(d0, d1) -> (d0 * 3 + d1)> /// In the future, more general interfaces can be devised to encode similar /// shape evolutions and map between an op and its operands. SmallVector<OpFoldResult> diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td index 96b9adc..e1e99c3 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -134,6 +134,24 @@ def OpenACC_VariableTypeCategory : I32BitEnumAttr< let printBitEnumPrimaryGroups = 1; } +// These are parallelism determination modes for `acc loop`. +// In the enum names, we use the "loop_" prefix because "auto" is +// a language keyword - and thus for consistency all other cases +// do the same. +def OpenACC_LoopSeq : I32EnumAttrCase<"loop_seq", 0>; +def OpenACC_LoopAuto : I32EnumAttrCase<"loop_auto", 1>; +def OpenACC_LoopIndependent : I32EnumAttrCase<"loop_independent", 2>; + +def OpenACC_LoopParMode : I32EnumAttr< + "LoopParMode", + "Encodes the options for loop parallelism determination mode", + [ + OpenACC_LoopAuto, OpenACC_LoopIndependent, + OpenACC_LoopSeq]> { + let cppNamespace = "::mlir::acc"; + let genSpecializedAttr = 0; +} + // Type used in operation below. def IntOrIndex : AnyTypeOf<[AnyInteger, Index]>; @@ -2373,6 +2391,11 @@ def OpenACC_LoopOp : OpenACC_Op<"loop", // Return whether this LoopOp has a gang, worker, or vector applying to the // 'default'/None device-type. bool hasDefaultGangWorkerVector(); + + // Used to obtain the parallelism mode for the requested device type. + // This first checks if the mode is set for the device_type requested. + // And if not, it returns the non-device_type mode. 
+ LoopParMode getDefaultOrDeviceTypeParallelism(DeviceType); }]; let hasCustomAssemblyFormat = 1; @@ -2404,6 +2427,53 @@ def OpenACC_LoopOp : OpenACC_Op<"loop", }]; let hasVerifier = 1; + + let builders = [ + OpBuilder<(ins "::mlir::ValueRange":$lowerbounds, + "::mlir::ValueRange":$upperbounds, + "::mlir::ValueRange":$steps, + "LoopParMode":$parMode), [{ + auto deviceNoneAttr = mlir::acc::DeviceTypeAttr::get( + $_builder.getContext(), mlir::acc::DeviceType::None); + auto arrOfDeviceNone = mlir::ArrayAttr::get( + $_builder.getContext(), deviceNoneAttr); + build($_builder, $_state, + /*results=*/{}, + /*lowerbound=*/lowerbounds, + /*upperbound=*/upperbounds, + /*step=*/steps, + /*inclusiveUpperbound=*/nullptr, + /*collapse=*/nullptr, + /*collapseDeviceType=*/nullptr, + /*gangOperands=*/{}, + /*gangOperandsArgType=*/nullptr, + /*gangOperandsSegments=*/nullptr, + /*gangOperandsDeviceType=*/nullptr, + /*workerNumOperands=*/{}, + /*workerNumOperandsDeviceType=*/nullptr, + /*vectorOperands=*/{}, + /*vectorOperandsDeviceType=*/nullptr, + /*seq=*/parMode == LoopParMode::loop_seq ? + arrOfDeviceNone : nullptr, + /*independent=*/parMode == LoopParMode::loop_independent ? + arrOfDeviceNone : nullptr, + /*auto_=*/parMode == LoopParMode::loop_auto ? + arrOfDeviceNone : nullptr, + /*gang=*/nullptr, + /*worker=*/nullptr, + /*vector=*/nullptr, + /*tileOperands=*/{}, + /*tileOperandsSegments=*/nullptr, + /*tileOperandsDeviceType=*/nullptr, + /*cacheOperands=*/{}, + /*privateOperands=*/{}, + /*privatizationRecipes=*/nullptr, + /*reductionOperands=*/{}, + /*reductionRecipes=*/nullptr, + /*combined=*/nullptr); + }] + > + ]; } // Yield operation for the acc.loop and acc.parallel operations. diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index 2d15544..0c1c15b 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -87,6 +87,9 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [ be accessed inside the op. The op's region can have multiple blocks and the blocks can have multiple distinct terminators. Values returned from this op's region define the op's results. + The optional 'no_inline' flag can be set to request the ExecuteRegionOp to be + preserved as much as possible and not being inlined in the parent block until + an explicit lowering step. 
Example: @@ -98,6 +101,14 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [ } } + // the same as above but with no_inline attribute + scf.for %i = 0 to 128 step %c1 { + %y = scf.execute_region -> i32 no_inline { + %x = load %A[%i] : memref<128xi32> + scf.yield %x : i32 + } + } + affine.for %i = 0 to 100 { "foo"() : () -> () %v = scf.execute_region -> i64 { @@ -119,6 +130,10 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [ ``` }]; + let arguments = (ins + UnitAttr:$no_inline + ); + let results = (outs Variadic<AnyType>); let regions = (region AnyRegion:$region); diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index 9038326..9c74cff0 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -4448,6 +4448,7 @@ def SPIRV_OC_OpUMulExtended : I32EnumAttrCase<"OpUMulExtended" def SPIRV_OC_OpSMulExtended : I32EnumAttrCase<"OpSMulExtended", 152>; def SPIRV_OC_OpIsNan : I32EnumAttrCase<"OpIsNan", 156>; def SPIRV_OC_OpIsInf : I32EnumAttrCase<"OpIsInf", 157>; +def SPIRV_OC_OpIsFinite : I32EnumAttrCase<"OpIsFinite", 158>; def SPIRV_OC_OpOrdered : I32EnumAttrCase<"OpOrdered", 162>; def SPIRV_OC_OpUnordered : I32EnumAttrCase<"OpUnordered", 163>; def SPIRV_OC_OpLogicalEqual : I32EnumAttrCase<"OpLogicalEqual", 164>; @@ -4630,7 +4631,8 @@ def SPIRV_OpcodeAttr : SPIRV_OC_OpVectorTimesMatrix, SPIRV_OC_OpMatrixTimesVector, SPIRV_OC_OpMatrixTimesMatrix, SPIRV_OC_OpDot, SPIRV_OC_OpIAddCarry, SPIRV_OC_OpISubBorrow, SPIRV_OC_OpUMulExtended, SPIRV_OC_OpSMulExtended, - SPIRV_OC_OpIsNan, SPIRV_OC_OpIsInf, SPIRV_OC_OpOrdered, SPIRV_OC_OpUnordered, + SPIRV_OC_OpIsNan, SPIRV_OC_OpIsInf, SPIRV_OC_OpIsFinite, + SPIRV_OC_OpOrdered, SPIRV_OC_OpUnordered, SPIRV_OC_OpLogicalEqual, SPIRV_OC_OpLogicalNotEqual, SPIRV_OC_OpLogicalOr, SPIRV_OC_OpLogicalAnd, SPIRV_OC_OpLogicalNot, SPIRV_OC_OpSelect, SPIRV_OC_OpIEqual, SPIRV_OC_OpINotEqual, SPIRV_OC_OpUGreaterThan, diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td index ab535d7..9331fc5 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td @@ -403,6 +403,28 @@ def SPIRV_INotEqualOp : SPIRV_LogicalBinaryOp<"INotEqual", // ----- +def SPIRV_IsFiniteOp : SPIRV_LogicalUnaryOp<"IsFinite", SPIRV_Float, []> { + let summary = "Result is true if x is an IEEE Finite, otherwise result is false"; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + x must be a scalar or vector of floating-point type. It must have the + same number of components as Result Type. + + Results are computed per component. 
+ + #### Example: + + ```mlir + %2 = spirv.IsFinite %0: f32 + %3 = spirv.IsFinite %1: vector<4xf32> + ``` + }]; +} + +// ----- + def SPIRV_IsInfOp : SPIRV_LogicalUnaryOp<"IsInf", SPIRV_Float, []> { let summary = "Result is true if x is an IEEE Inf, otherwise result is false"; @@ -418,7 +440,7 @@ def SPIRV_IsInfOp : SPIRV_LogicalUnaryOp<"IsInf", SPIRV_Float, []> { ```mlir %2 = spirv.IsInf %0: f32 - %3 = spirv.IsInf %1: vector<4xi32> + %3 = spirv.IsInf %1: vector<4xf32> ``` }]; } @@ -442,7 +464,7 @@ def SPIRV_IsNanOp : SPIRV_LogicalUnaryOp<"IsNan", SPIRV_Float, []> { ```mlir %2 = spirv.IsNan %0: f32 - %3 = spirv.IsNan %1: vector<4xi32> + %3 = spirv.IsNan %1: vector<4xf32> ``` }]; } diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h index c691d59..531fecc 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h @@ -330,10 +330,34 @@ public: bool hasValue() const { return !isa<UnitAttr>(decorationValue); } }; + // Type for specifying the decoration(s) on the struct itself. + struct StructDecorationInfo { + Decoration decoration; + Attribute decorationValue; + + StructDecorationInfo(Decoration decoration, Attribute decorationValue) + : decoration(decoration), decorationValue(decorationValue) {} + + friend bool operator==(const StructDecorationInfo &lhs, + const StructDecorationInfo &rhs) { + return lhs.decoration == rhs.decoration && + lhs.decorationValue == rhs.decorationValue; + } + + friend bool operator<(const StructDecorationInfo &lhs, + const StructDecorationInfo &rhs) { + return llvm::to_underlying(lhs.decoration) < + llvm::to_underlying(rhs.decoration); + } + + bool hasValue() const { return !isa<UnitAttr>(decorationValue); } + }; + /// Construct a literal StructType with at least one member. static StructType get(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {}, - ArrayRef<MemberDecorationInfo> memberDecorations = {}); + ArrayRef<MemberDecorationInfo> memberDecorations = {}, + ArrayRef<StructDecorationInfo> structDecorations = {}); /// Construct an identified StructType. This creates a StructType whose body /// (member types, offset info, and decorations) is not set yet. A call to @@ -367,6 +391,9 @@ public: bool hasOffset() const; + /// Returns true if the struct has a specified decoration. + bool hasDecoration(spirv::Decoration decoration) const; + uint64_t getMemberOffset(unsigned) const; // Returns in `memberDecorations` the Decorations (apart from Offset) @@ -380,12 +407,18 @@ public: unsigned i, SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const; + // Returns in `structDecorations` the Decorations associated with the + // StructType. + void getStructDecorations(SmallVectorImpl<StructType::StructDecorationInfo> + &structDecorations) const; + /// Sets the contents of an incomplete identified StructType. This method must /// be called only for identified StructTypes and it must be called only once /// per instance. Otherwise, failure() is returned. 
LogicalResult trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {}, - ArrayRef<MemberDecorationInfo> memberDecorations = {}); + ArrayRef<MemberDecorationInfo> memberDecorations = {}, + ArrayRef<StructDecorationInfo> structDecorations = {}); void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional<StorageClass> storage = std::nullopt); @@ -396,6 +429,9 @@ public: llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo); +llvm::hash_code +hash_value(const StructType::StructDecorationInfo &structDecorationInfo); + // SPIR-V KHR cooperative matrix type class CooperativeMatrixType : public Type::TypeBase<CooperativeMatrixType, CompositeType, diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h index 3d22ec9..03ae54a 100644 --- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h +++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h @@ -39,6 +39,10 @@ struct SPIRVConversionOptions { /// The number of bits to store a boolean value. unsigned boolNumBits{8}; + /// Whether to emulate unsupported floats with integer types of same bit + /// width. + bool emulateUnsupportedFloatTypes{true}; + /// How sub-byte values are storaged in memory. SPIRVSubByteTypeStorage subByteTypeStorage{SPIRVSubByteTypeStorage::Packed}; diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 3885439..5d45508 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2595,6 +2595,7 @@ def Vector_MaskOp : Vector_Op<"mask", [ def Vector_TransposeOp : Vector_Op<"transpose", [Pure, + DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>, DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>, PredOpTrait<"operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>]> { @@ -2876,7 +2877,10 @@ def Vector_ScanOp : // VectorStepOp //===----------------------------------------------------------------------===// -def Vector_StepOp : Vector_Op<"step", [Pure]> { +def Vector_StepOp : Vector_Op<"step", [ + Pure, + DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]> + ]> { let summary = "A linear sequence of values from 0 to N"; let description = [{ A `step` operation produces an index vector, i.e. a 1-D vector of values of diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index 91d6b2a..75b16a87 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -628,35 +628,71 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> { As compared to prefetch_nd, which works on non-scattered TensorDesc, it works on scattered TensorDesc instead. - Example: + Example 1: ```mlir xegpu.prefetch %tdesc {l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>} : !xegpu.tensor_desc<16xf16> ``` + + Example 2: + A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc. + It combines "create scattered TensorTdesc" and "prefetch with scattered TensorTdesc". + The source operand could be a raw pointer (uint64_t). + Please refer to create_tdesc for the restriction of memref. 
+ ```mlir + %a = memref.alloc() : memref<1024xf32> + %0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex> + xegpu.prefetch %a[%0] {l1_hint = #xegpu.cache_hint<cached>, + l2_hint = #xegpu.cache_hint<cached>, + l3_hint = #xegpu.cache_hint<cached>} + : memref<1024xf32>, vector<4xindex> + ``` }]; - let arguments = (ins XeGPU_TensorDesc: $TensorDesc, + let arguments = (ins XeGPU_GatherScatterSourceType: $source, + Optional<XeGPU_OffsetType>: $offsets, OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint, OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint, OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint); let extraClassDeclaration = extraBaseClassDeclaration # [{ + Type getSourceType() { + return getSource().getType(); + } + + TypedValue<xegpu::TensorDescType> getTensorDesc() { + if (auto tdescType = getTensorDescType()) { + return llvm::cast<TypedValue<xegpu::TensorDescType>>(getSource()); + } + return TypedValue<xegpu::TensorDescType>(); + } + xegpu::TensorDescType getTensorDescType() { - return getTensorDesc().getType(); + return dyn_cast<xegpu::TensorDescType>(getSourceType()); } }]; - let assemblyFormat = "$TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc))"; + let assemblyFormat = [{ + $source + (`[` $offsets^ `]`)? + prop-dict + attr-dict `:` type(operands) + }]; + + let builders = [ + OpBuilder<(ins "Value": $source, + "xegpu::CachePolicyAttr": $l1_hint, + "xegpu::CachePolicyAttr": $l2_hint, + "xegpu::CachePolicyAttr": $l3_hint)> + ]; let hasVerifier = 1; } -def XeGPU_LoadGatherOp : XeGPU_Op<"load", [ - AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemRead]> - ]> { +def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> { let summary = "load a set of scattered data points from memory."; let description = [{ It (aka. load) load data per each work-item. The output @@ -687,6 +723,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [ : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>, vector<16xi1> -> vector<16x8xf32> ``` + Example 3 (SIMT mode): ```mlir %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, @@ -695,19 +732,48 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [ : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>> vector<16xi1> -> vector<8xf32> ``` + + Example 4: + A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc. + It combines "create scattered TensorTdesc" and "load with scattered TensorTdesc". + The source operand could be a raw pointer (uint64_t). Please refer to create_tdesc + for the restriction of memref. 
+ ```mlir + %a = memref.alloc() : memref<1024xf32> + %offsets = vector.step : vector<16xindex> + %mask = vector.constant_mask [16]: vector<16xi1> + %val = xegpu.load %a[%offsets], %mask {l1_hint = #xegpu.cache_hint<cached>, + l2_hint = #xegpu.cache_hint<cached>, + l3_hint = #xegpu.cache_hint<cached>} + : memref<1024xf32>, vector<16xi1>, vector<16xindex> -> vector<16xf32> + ``` }]; - let arguments = (ins XeGPU_TensorDesc: $TensorDesc, + let arguments = (ins XeGPU_GatherScatterSourceType: $source, + Optional<XeGPU_OffsetType>: $offsets, XeGPU_MaskType: $mask, + OptionalAttr<I64Attr>: $chunk_size, OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint, OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint, OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint); let results = (outs XeGPU_ValueType: $value); let extraClassDeclaration = extraBaseClassDeclaration # [{ + + Type getSourceType() { + return getSource().getType(); + } + + TypedValue<xegpu::TensorDescType> getTensorDesc() { + if (auto tdescType = getTensorDescType()) { + return llvm::cast<TypedValue<xegpu::TensorDescType>>(getSource()); + } + return TypedValue<xegpu::TensorDescType>(); + } + xegpu::TensorDescType getTensorDescType() { - return getTensorDesc().getType(); + return dyn_cast<xegpu::TensorDescType>(getSourceType()); } mlir::Type getElementType() { @@ -725,15 +791,24 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [ }]; - let assemblyFormat = [{$TensorDesc `,` $mask prop-dict attr-dict - `:` qualified(type($TensorDesc)) `,` type($mask) `->` type($value)}]; + let assemblyFormat = [{ + $source + (`[` $offsets^ `]`)? `,` + $mask prop-dict + attr-dict `:` type(operands) `->` type($value) + }]; + + let builders = [ + OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask, + "xegpu::CachePolicyAttr": $l1_hint, + "xegpu::CachePolicyAttr": $l2_hint, + "xegpu::CachePolicyAttr": $l3_hint)> + ]; let hasVerifier = 1; } -def XeGPU_StoreScatterOp : XeGPU_Op<"store", [ - AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemWrite]> - ]> { +def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> { let summary = "store data to scattered memory locations."; let description = [{ It (aka. store) stores data to scattered memory locations. The value is typically a 1D vector. But when the chunk size of the TensorDesc is larger than 1, it will be @@ -768,19 +843,49 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [ l3_hint = #xegpu.cache_hint<write_through>}> : vector<8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>> vector<16xi1> ``` + + Example 4: + A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc. + It combines "create scattered TensorTdesc" and "store with scattered TensorTdesc". + The dest operand could be a raw pointer (uint64_t). + Please refer to create_tdesc for the restriction of memref. 
+ ```mlir + %a = memref.alloc() : memref<1024xf32> + %val = arith.constant dense<0.0> : vector<16xf32> + %offsets = vector.step : vector<16xindex> + %mask = vector.constant_mask [16]: vector<16xi1> + xegpu.store %val, %a[%offsets], %mask {l1_hint = #xegpu.cache_hint<cached>, + l2_hint = #xegpu.cache_hint<cached>, + l3_hint = #xegpu.cache_hint<cached>} + : memref<1024xf32>, vector<16xi1>, vector<16xindex> -> vector<16xf32> + ``` + }]; let arguments = (ins XeGPU_ValueType: $value, - XeGPU_TensorDesc: $TensorDesc, + XeGPU_GatherScatterSourceType: $dest, + Optional<XeGPU_OffsetType>: $offsets, XeGPU_MaskType: $mask, + OptionalAttr<I64Attr>: $chunk_size, OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint, OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint, OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint); let extraClassDeclaration = extraBaseClassDeclaration # [{ + Type getDestType() { + return getDest().getType(); + } + + TypedValue<xegpu::TensorDescType> getTensorDesc() { + if (auto tdescType = getTensorDescType()) { + return llvm::cast<TypedValue<xegpu::TensorDescType>>(getDest()); + } + return TypedValue<xegpu::TensorDescType>(); + } + xegpu::TensorDescType getTensorDescType() { - return getTensorDesc().getType(); + return dyn_cast<xegpu::TensorDescType>(getDestType()); } VectorType getValueType() { @@ -792,8 +897,21 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [ } }]; - let assemblyFormat = [{$value `,` $TensorDesc `,` $mask prop-dict attr-dict - `:` type($value) `,` qualified(type($TensorDesc)) `,` type($mask)}]; + let assemblyFormat = [{ + $value `,` + $dest + (`[` $offsets^ `]`)? `,` + $mask + prop-dict + attr-dict `:` type(operands) + }]; + + let builders = [ + OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask, + "xegpu::CachePolicyAttr": $l1_hint, + "xegpu::CachePolicyAttr": $l2_hint, + "xegpu::CachePolicyAttr": $l3_hint)> + ]; let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td index 20916ae..b268cab 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td @@ -189,6 +189,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc", let genVerifyDecl = 1; } +def XeGPU_GatherScatterSourceType : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64]>; def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> { let summary = "!xegpu.nbarrier a custom XeGPU type representing a barrier."; diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.td b/mlir/include/mlir/Interfaces/CallInterfaces.td index e3c2aec..19d3afe 100644 --- a/mlir/include/mlir/Interfaces/CallInterfaces.td +++ b/mlir/include/mlir/Interfaces/CallInterfaces.td @@ -18,9 +18,15 @@ include "mlir/IR/OpBase.td" -/// Interface for operations with arguments attributes (both call-like -/// and callable operations). -def ArgumentAttributesMethods { +/// Interface for operations with result and argument attributes. +def ArgAndResultAttrsOpInterface : OpInterface<"ArgAndResultAttrsOpInterface"> { + let description = [{ + An operation that has argument and result attributes. This interface + provides functions to access and modify the argument and result + attributes of the operation. + }]; + let cppNamespace = "::mlir"; + list<InterfaceMethod> methods = [ InterfaceMethod<[{ Get the array of argument attribute dictionaries. The method should @@ -64,7 +70,8 @@ def ArgumentAttributesMethods { // a call-like operation. 
This represents the destination of the call. /// Interface for call-like operations. -def CallOpInterface : OpInterface<"CallOpInterface"> { +def CallOpInterface : OpInterface<"CallOpInterface", + [ArgAndResultAttrsOpInterface]> { let description = [{ A call-like operation is one that transfers control from one sub-routine to another. These operations may be traditional direct calls `call @foo`, or @@ -123,11 +130,12 @@ def CallOpInterface : OpInterface<"CallOpInterface"> { return ::mlir::call_interface_impl::resolveCallable($_op); }] > - ] # ArgumentAttributesMethods.methods; + ]; } /// Interface for callable operations. -def CallableOpInterface : OpInterface<"CallableOpInterface"> { +def CallableOpInterface : OpInterface<"CallableOpInterface", + [ArgAndResultAttrsOpInterface]> { let description = [{ A callable operation is one who represents a potential sub-routine, and may be a target for a call-like operation (those providing the CallOpInterface @@ -140,11 +148,11 @@ def CallableOpInterface : OpInterface<"CallableOpInterface"> { let methods = [ InterfaceMethod<[{ - Returns the region on the current operation that is callable. This may - return null in the case of an external callable object, e.g. an external - function. - }], - "::mlir::Region *", "getCallableRegion">, + Returns the region on the current operation that is callable. This may + return null in the case of an external callable object, e.g. an external + function. + }], + "::mlir::Region *", "getCallableRegion">, InterfaceMethod<[{ Returns the callable's argument types based exclusively on the type (to allow for this method may be called on function declarations). @@ -155,7 +163,7 @@ def CallableOpInterface : OpInterface<"CallableOpInterface"> { allow for this method may be called on function declarations). }], "::llvm::ArrayRef<::mlir::Type>", "getResultTypes">, - ] # ArgumentAttributesMethods.methods; + ]; } #endif // MLIR_INTERFACES_CALLINTERFACES diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h index 60615cf6..e4670cb 100644 --- a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h +++ b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h @@ -28,6 +28,7 @@ #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/SPIRV/SPIRVToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/VCIX/VCIXToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h" namespace mlir { class DialectRegistry; @@ -47,6 +48,7 @@ static inline void registerAllToLLVMIRTranslations(DialectRegistry ®istry) { registerROCDLDialectTranslation(registry); registerSPIRVDialectTranslation(registry); registerVCIXDialectTranslation(registry); + registerXeVMDialectTranslation(registry); // Extension required for translating GPU offloading Ops. gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(registry); @@ -63,6 +65,7 @@ registerAllGPUToLLVMIRTranslations(DialectRegistry ®istry) { registerNVVMDialectTranslation(registry); registerROCDLDialectTranslation(registry); registerSPIRVDialectTranslation(registry); + registerXeVMDialectTranslation(registry); // Extension required for translating GPU offloading Ops. 
gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(registry); diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h new file mode 100644 index 0000000..b4f6750 --- /dev/null +++ b/mlir/include/mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h @@ -0,0 +1,31 @@ +//===-- XeVMToLLVMIRTranslation.h - XeVM to LLVM IR -------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This provides registration calls for XeVM dialect to LLVM IR translation. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TARGET_LLVMIR_DIALECT_XEVM_XEVMTOLLVMIRTRANSLATION_H +#define MLIR_TARGET_LLVMIR_DIALECT_XEVM_XEVMTOLLVMIRTRANSLATION_H + +namespace mlir { + +class DialectRegistry; +class MLIRContext; + +/// Register the XeVM dialect and the translation from it to the LLVM IR in the +/// given registry; +void registerXeVMDialectTranslation(mlir::DialectRegistry ®istry); + +/// Register the XeVM dialect and the translation from it in the registry +/// associated with the given context. +void registerXeVMDialectTranslation(mlir::MLIRContext &context); + +} // namespace mlir + +#endif // MLIR_TARGET_LLVMIR_DIALECT_XEVM_XEVMTOLLVMIRTRANSLATION_H diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h index 17ef8e4..09d819a 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h @@ -83,6 +83,10 @@ public: /// specification. void convertTargetTriple(); + /// Converts the module level asm of the LLVM module to an MLIR module + /// level asm specification. + void convertModuleLevelAsm(); + /// Stores the mapping between an LLVM value and its MLIR counterpart. void mapValue(llvm::Value *llvm, Value mlir) { mapValue(llvm) = mlir; } @@ -291,10 +295,12 @@ public: SmallVectorImpl<Value> &valuesOut, SmallVectorImpl<NamedAttribute> &attrsOut); - /// Converts the parameter and result attributes in `argsAttr` and `resAttr` - /// and add them to the `callOp`. - void convertParameterAttributes(llvm::CallBase *call, ArrayAttr &argsAttr, - ArrayAttr &resAttr, OpBuilder &builder); + /// Converts the argument and result attributes attached to `call` and adds + /// them to `attrsOp`. For intrinsic calls, filters out attributes + /// corresponding to immediate arguments specified by `immArgPositions`. + void convertArgAndResultAttrs(llvm::CallBase *call, + ArgAndResultAttrsOpInterface attrsOp, + ArrayRef<unsigned> immArgPositions = {}); /// Whether the importer should try to convert all intrinsics to /// llvm.call_intrinsic instead of dialect supported operations. @@ -378,19 +384,12 @@ private: bool &isIncompatibleCall); /// Returns the callee name, or an empty symbol if the call is not direct. FlatSymbolRefAttr convertCalleeName(llvm::CallBase *callInst); - /// Converts the parameter and result attributes attached to `func` and adds + /// Converts the argument and result attributes attached to `func` and adds /// them to the `funcOp`. 
- void convertParameterAttributes(llvm::Function *func, LLVMFuncOp funcOp, - OpBuilder &builder); - /// Converts the AttributeSet of one parameter in LLVM IR to a corresponding - /// DictionaryAttr for the LLVM dialect. - DictionaryAttr convertParameterAttribute(llvm::AttributeSet llvmParamAttrs, - OpBuilder &builder); - /// Converts the parameter and result attributes attached to `call` and adds - /// them to the `callOp`. Implemented in terms of the the public definition of - /// convertParameterAttributes. - void convertParameterAttributes(llvm::CallBase *call, CallOpInterface callOp, - OpBuilder &builder); + void convertArgAndResultAttrs(llvm::Function *func, LLVMFuncOp funcOp); + /// Converts the argument or result attributes in `llvmAttrSet` to a + /// corresponding MLIR LLVM dialect attribute dictionary. + DictionaryAttr convertArgOrResultAttrSet(llvm::AttributeSet llvmAttrSet); /// Converts the attributes attached to `inst` and adds them to the `op`. LogicalResult convertCallAttributes(llvm::CallInst *inst, CallOp op); /// Converts the attributes attached to `inst` and adds them to the `op`. diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h index f3f73f4..eb7dfa7 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -25,11 +25,13 @@ #include "mlir/Target/LLVMIR/TypeToLLVM.h" #include "llvm/ADT/SetVector.h" -#include "llvm/Frontend/OpenMP/OMPIRBuilder.h" #include "llvm/IR/FPEnv.h" +#include "llvm/IR/Module.h" namespace llvm { class BasicBlock; +class CallBase; +class CanonicalLoopInfo; class Function; class IRBuilderBase; class OpenMPIRBuilder; @@ -306,10 +308,16 @@ public: /*recordInsertions=*/false); } - /// Translates parameter attributes of a call and adds them to the returned - /// AttrBuilder. Returns failure if any of the translations failed. - FailureOr<llvm::AttrBuilder> convertParameterAttrs(mlir::Location loc, - DictionaryAttr paramAttrs); + /// Converts argument and result attributes from `attrsOp` to LLVM IR + /// attributes on the `call` instruction. Returns failure if conversion fails. + /// The `immArgPositions` parameter is only relevant for intrinsics. It + /// specifies the positions of immediate arguments, which do not have + /// associated argument attributes in MLIR and should be skipped during + /// attribute mapping. + LogicalResult + convertArgAndResultAttrs(ArgAndResultAttrsOpInterface attrsOp, + llvm::CallBase *call, + ArrayRef<unsigned> immArgPositions = {}); /// Gets the named metadata in the LLVM IR module being constructed, creating /// it if it does not exist. @@ -389,6 +397,11 @@ private: convertDialectAttributes(Operation *op, ArrayRef<llvm::Instruction *> instructions); + /// Translates parameter attributes of a call and adds them to the returned + /// AttrBuilder. Returns failure if any of the translations failed. + FailureOr<llvm::AttrBuilder> convertParameterAttrs(mlir::Location loc, + DictionaryAttr paramAttrs); + /// Translates parameter attributes of a function and adds them to the /// returned AttrBuilder. Returns failure if any of the translations failed. 
FailureOr<llvm::AttrBuilder> diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp index d43e681..265293b 100644 --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -99,6 +99,17 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, return builder.getF32FloatAttr(dstVal.convertToFloat()); } +// Get an IntegerAttr from FloatAttr while preserving the bits. +// Useful for converting float constants to integer constants while preserving +// the bits. +static IntegerAttr +getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType, + ConversionPatternRewriter &rewriter) { + APFloat floatVal = floatAttr.getValue(); + APInt intVal = floatVal.bitcastToAPInt(); + return rewriter.getIntegerAttr(dstType, intVal); +} + /// Returns true if the given `type` is a boolean scalar or vector type. static bool isBoolScalarOrVector(Type type) { assert(type && "Not a valid type"); @@ -296,8 +307,18 @@ struct ConstantCompositeOpPattern final SmallVector<Attribute, 8> elements; if (isa<FloatType>(srcElemType)) { for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) { - FloatAttr dstAttr = - convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter); + Attribute dstAttr = nullptr; + // Handle 8-bit float conversion to 8-bit integer. + auto *typeConverter = getTypeConverter<SPIRVTypeConverter>(); + if (typeConverter->getOptions().emulateUnsupportedFloatTypes && + srcElemType.getIntOrFloatBitWidth() == 8 && + isa<IntegerType>(dstElemType)) { + dstAttr = + getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter); + } else { + dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), + rewriter); + } if (!dstAttr) return failure(); elements.push_back(dstAttr); @@ -361,11 +382,19 @@ struct ConstantScalarOpPattern final // Floating-point types. if (isa<FloatType>(srcType)) { auto srcAttr = cast<FloatAttr>(cstAttr); - auto dstAttr = srcAttr; + Attribute dstAttr = srcAttr; // Floating-point types not supported in the target environment are all // converted to float type. - if (srcType != dstType) { + auto *typeConverter = getTypeConverter<SPIRVTypeConverter>(); + if (typeConverter->getOptions().emulateUnsupportedFloatTypes && + srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) && + dstType.getIntOrFloatBitWidth() == 8) { + // If the source is an 8-bit float, convert it to an 8-bit integer.
+ dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter); + if (!dstAttr) + return failure(); + } else if (srcType != dstType) { dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter); if (!dstAttr) return failure(); @@ -1352,6 +1381,7 @@ struct ConvertArithToSPIRVPass SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; + options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes; SPIRVTypeConverter typeConverter(targetAttr, options); // Use UnrealizedConversionCast as the bridge so that we don't need to pull diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp index 6f0fc29..35ad99c 100644 --- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp +++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp @@ -64,10 +64,46 @@ void mlir::populateComplexToROCDLLibraryCallsConversionPatterns( patterns.getContext(), "__ocml_cabs_f32"); patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>( patterns.getContext(), "__ocml_cabs_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float32Type>>( + patterns.getContext(), "__ocml_carg_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float64Type>>( + patterns.getContext(), "__ocml_carg_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float32Type>>( + patterns.getContext(), "__ocml_conj_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float64Type>>( + patterns.getContext(), "__ocml_conj_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float32Type>>( + patterns.getContext(), "__ocml_ccos_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float64Type>>( + patterns.getContext(), "__ocml_ccos_f64"); patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float32Type>>( patterns.getContext(), "__ocml_cexp_f32"); patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float64Type>>( patterns.getContext(), "__ocml_cexp_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float32Type>>( + patterns.getContext(), "__ocml_clog_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float64Type>>( + patterns.getContext(), "__ocml_clog_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float32Type>>( + patterns.getContext(), "__ocml_cpow_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float64Type>>( + patterns.getContext(), "__ocml_cpow_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float32Type>>( + patterns.getContext(), "__ocml_csin_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float64Type>>( + patterns.getContext(), "__ocml_csin_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float32Type>>( + patterns.getContext(), "__ocml_csqrt_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float64Type>>( + patterns.getContext(), "__ocml_csqrt_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float32Type>>( + patterns.getContext(), "__ocml_ctan_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float64Type>>( + patterns.getContext(), "__ocml_ctan_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float32Type>>( + patterns.getContext(), "__ocml_ctanh_f32"); + 
patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float64Type>>( + patterns.getContext(), "__ocml_ctanh_f64"); } namespace { @@ -86,7 +122,10 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() { ConversionTarget target(getContext()); target.addLegalDialect<func::FuncDialect>(); - target.addIllegalOp<complex::AbsOp, complex::ExpOp>(); + target.addIllegalOp<complex::AbsOp, complex::AngleOp, complex::ConjOp, + complex::CosOp, complex::ExpOp, complex::LogOp, + complex::PowOp, complex::SinOp, complex::SqrtOp, + complex::TanOp, complex::TanhOp>(); if (failed(applyPartialConversion(op, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp index 03f4bf4..56b6181 100644 --- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp +++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp @@ -43,6 +43,7 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() { SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; + options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes; SPIRVTypeConverter typeConverter(targetAttr, options); // TODO: We should also take care of block argument type conversion. diff --git a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp index 8ed9f65..c0439a4 100644 --- a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp +++ b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp @@ -42,6 +42,7 @@ void ConvertFuncToSPIRVPass::runOnOperation() { SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; + options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes; SPIRVTypeConverter typeConverter(targetAttr, options); RewritePatternSet patterns(context); diff --git a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp index 855c582..cde2340 100644 --- a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp +++ b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp @@ -22,7 +22,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" namespace mlir { #define GEN_PASS_DEF_CONVERTMATHTOFUNCS @@ -32,7 +32,6 @@ namespace mlir { using namespace mlir; #define DEBUG_TYPE "math-to-funcs" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") namespace { // Pattern to convert vector operations to scalar operations. 
@@ -653,10 +652,8 @@ FPowIOpLowering::matchAndRewrite(math::FPowIOp op, /// } static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) { if (!isa<IntegerType>(elementType)) { - LLVM_DEBUG({ - DBGS() << "non-integer element type for CtlzFunc; type was: "; - elementType.print(llvm::dbgs()); - }); + LDBG() << "non-integer element type for CtlzFunc; type was: " + << elementType; llvm_unreachable("non-integer element type"); } int64_t bitWidth = elementType.getIntOrFloatBitWidth(); diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp index 93d8b49..df219f3 100644 --- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp +++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/MathToROCDL/MathToROCDL.h" +#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -21,7 +22,6 @@ #include "../GPUCommon/GPUOpsLowering.h" #include "../GPUCommon/OpToFuncCallLowering.h" -#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" namespace mlir { #define GEN_PASS_DEF_CONVERTMATHTOROCDL @@ -31,7 +31,6 @@ namespace mlir { using namespace mlir; #define DEBUG_TYPE "math-to-rocdl" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") template <typename OpTy> static void populateOpPatterns(const LLVMTypeConverter &converter, diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp index a877ad2..1787e0a 100644 --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -488,7 +488,12 @@ namespace mlir { void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { // Core patterns - patterns.add<CopySignPattern>(typeConverter, patterns.getContext()); + patterns + .add<CopySignPattern, + CheckedElementwiseOpPattern<math::IsInfOp, spirv::IsInfOp>, + CheckedElementwiseOpPattern<math::IsNaNOp, spirv::IsNanOp>, + CheckedElementwiseOpPattern<math::IsFiniteOp, spirv::IsFiniteOp>>( + typeConverter, patterns.getContext()); // GLSL patterns patterns diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 6ba5bfe4..dc2035b 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -24,11 +24,12 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/IRMapping.h" #include "mlir/Pass/Pass.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/MathExtras.h" + #include <optional> #define DEBUG_TYPE "memref-to-llvm" -#define DBGS() llvm::dbgs() << "[" DEBUG_TYPE "] " namespace mlir { #define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS @@ -1848,8 +1849,8 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) { return LLVM::AtomicBinOp::xchg; case arith::AtomicRMWKind::maximumf: // TODO: remove this by end of 2025. 
- LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw maximumf changed " - "from fmax to fmaximum, expect more NaNs"); + LDBG() << "the lowering of memref.atomicrmw maximumf changed " + "from fmax to fmaximum, expect more NaNs"; return LLVM::AtomicBinOp::fmaximum; case arith::AtomicRMWKind::maxnumf: return LLVM::AtomicBinOp::fmax; @@ -1859,8 +1860,8 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) { return LLVM::AtomicBinOp::umax; case arith::AtomicRMWKind::minimumf: // TODO: remove this by end of 2025. - LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw minimum changed " - "from fmin to fminimum, expect more NaNs"); + LDBG() << "the lowering of memref.atomicrmw minimum changed " + "from fmin to fminimum, expect more NaNs"; return LLVM::AtomicBinOp::fminimum; case arith::AtomicRMWKind::minnumf: return LLVM::AtomicBinOp::fmin; diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index 5d13353..2549a9c 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -26,13 +26,12 @@ #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include <optional> #define DEBUG_TYPE "nvgpu-to-nvvm" -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") -#define DBGSE() (llvm::dbgs()) namespace mlir { #define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS @@ -1105,13 +1104,13 @@ struct NVGPUGenerateWarpgroupDescriptorLowering // // [0,14) start_address dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit); - LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: " - << "leading_off:" << leadDimVal << "\t" - << "stride_off :" << strideDimVal << "\t" - << "base_offset:" << offsetVal << "\t" - << "layout_type:" << swizzle << " (" - << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind) - << ")\n start_addr : " << baseAddr << "\n"); + LDBG() << "Generating warpgroup.descriptor: " + << "leading_off:" << leadDimVal << "\t" + << "stride_off :" << strideDimVal << "\t" + << "base_offset:" << offsetVal << "\t" + << "layout_type:" << swizzle << " (" + << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind) + << ")\n start_addr : " << baseAddr; rewriter.replaceOp(op, dsc); return success(); @@ -1281,8 +1280,8 @@ struct NVGPUWarpgroupMmaOpLowering } else { llvm_unreachable("msg: not supported K shape"); } - LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM - << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n"); + LDBG() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM + << ", n = " << wgmmaN << ", k = " << wgmmaK << "]"; } /// Generates WGMMATypesAttr from MLIR Type @@ -1366,9 +1365,9 @@ struct NVGPUWarpgroupMmaOpLowering int tileShapeA = matrixTypeA.getDimSize(1); int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte; incrementVal = incrementVal >> exclude4LSB; - LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k - << "] [wgmma descriptors] Descriptor A + " - << incrementVal << " | \t "); + LDBG() << "\t\t[m: " << i << " n: " << j << " k: " << k + << "] [wgmma descriptors] Descriptor A + " << incrementVal + << " | \t "; if (!incrementVal) return desc; return makeAdd(desc, makeI64Const(b, incrementVal)); @@ -1391,7 +1390,7 @@ struct NVGPUWarpgroupMmaOpLowering int byte = elemB.getIntOrFloatBitWidth() / 8; int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte; incrementVal = incrementVal >> 
exclude4LSB; - LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n"); + LDBG() << "Descriptor B + " << incrementVal; if (!incrementVal) return desc; return makeAdd(desc, makeI64Const(b, incrementVal)); @@ -1400,15 +1399,14 @@ struct NVGPUWarpgroupMmaOpLowering /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix /// descriptors and arranges them based on induction variables: i, j, and k. Value generateWgmma(int i, int j, int k, Value matrixC) { - LLVM_DEBUG(DBGS() << "\t wgmma." - << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK - << "(A[" << (iterationM * wgmmaM) << ":" - << (iterationM * wgmmaM) + wgmmaM << "][" - << (iterationK * wgmmaK) << ":" - << (iterationK * wgmmaK + wgmmaK) << "] * " - << " B[" << (iterationK * wgmmaK) << ":" - << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":" - << wgmmaN << "])\n"); + LDBG() << "\t wgmma." + << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK << "(A[" + << (iterationM * wgmmaM) << ":" << (iterationM * wgmmaM) + wgmmaM + << "][" << (iterationK * wgmmaK) << ":" + << (iterationK * wgmmaK + wgmmaK) << "] * " + << " B[" << (iterationK * wgmmaK) << ":" + << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":" << wgmmaN + << "])"; Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k); Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k); @@ -1467,9 +1465,9 @@ struct NVGPUWarpgroupMmaOpLowering totalM = op.getDescriptorA().getType().getTensor().getDimSize(0); totalN = op.getDescriptorB().getType().getTensor().getDimSize(1); totalK = op.getDescriptorA().getType().getTensor().getDimSize(1); - LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN - << "] += A[" << totalM << "][" << totalK << "] * B[" - << totalK << "][" << totalN << "] ---===\n"); + LDBG() << "===--- GEMM D[" << totalM << "][" << totalN << "] += A[" + << totalM << "][" << totalK << "] * B[" << totalK << "][" << totalN + << "] ---==="; // Find the shape for one wgmma instruction findWgmmaShape( diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp index 662ee9e..91788f9 100644 --- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp +++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp @@ -25,11 +25,10 @@ #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/raw_ostream.h" #define DEBUG_TYPE "nvvm-to-llvm" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define DBGSNL() (llvm::dbgs() << "\n") namespace mlir { #define GEN_PASS_DEF_CONVERTNVVMTOLLVMPASS @@ -52,17 +51,17 @@ struct PtxLowering LogicalResult matchAndRewrite(BasicPtxBuilderInterface op, PatternRewriter &rewriter) const override { if (op.hasIntrinsic()) { - LLVM_DEBUG(DBGS() << "Ptx Builder does not lower \n\t" << op << "\n"); + LDBG() << "Ptx Builder does not lower \n\t" << op; return failure(); } SmallVector<std::pair<Value, PTXRegisterMod>> asmValues; - LLVM_DEBUG(DBGS() << op.getPtx() << "\n"); + LDBG() << op.getPtx(); PtxBuilder generator(op, rewriter); op.getAsmValues(rewriter, asmValues); for (auto &[asmValue, modifier] : asmValues) { - LLVM_DEBUG(DBGSNL() << asmValue << "\t Modifier : " << &modifier); + LDBG() << asmValue << "\t Modifier : " << &modifier; generator.insertValue(asmValue, modifier); } diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp index fd40e7c..fa9e544 100644 --- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp +++ 
b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp @@ -36,7 +36,6 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #define DEBUG_TYPE "shard-to-mpi" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") namespace mlir { #define GEN_PASS_DEF_CONVERTSHARDTOMPIPASS diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp index f07386e..8cd650e 100644 --- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp +++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp @@ -41,6 +41,7 @@ class ConvertTensorToSPIRVPass SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; + options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes; SPIRVTypeConverter typeConverter(targetAttr, options); RewritePatternSet patterns(context); diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp index a425eff..1d1904f 100644 --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -31,10 +31,9 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/DebugLog.h" #define DEBUG_TYPE "vector-to-gpu" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define DBGSNL() (llvm::dbgs() << "\n") namespace mlir { #define GEN_PASS_DEF_CONVERTVECTORTOGPU @@ -366,7 +365,7 @@ static SetVector<Operation *> getOpToConvert(mlir::Operation *op, // by all operations. if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) { if (!supportsMMaMatrixType(op, useNvGpu)) { - LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n"); + LDBG() << "cannot convert op: " << *op; return true; } return false; @@ -548,7 +547,7 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op, std::optional<int64_t> stride = getStaticallyKnownRowStride(op.getShapedType()); if (!stride.has_value()) { - LLVM_DEBUG(DBGS() << "no stride\n"); + LDBG() << "no stride"; return rewriter.notifyMatchFailure(op, "no stride"); } @@ -583,7 +582,7 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op, isTranspose ? 
rewriter.getUnitAttr() : UnitAttr()); valueMapping[mappingResult] = load; - LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n"); + LDBG() << "transfer read to: " << load; return success(); } @@ -597,13 +596,13 @@ convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op, std::optional<int64_t> stride = getStaticallyKnownRowStride(op.getShapedType()); if (!stride.has_value()) { - LLVM_DEBUG(DBGS() << "no stride\n"); + LDBG() << "no stride"; return rewriter.notifyMatchFailure(op, "no stride"); } auto it = valueMapping.find(op.getVector()); if (it == valueMapping.end()) { - LLVM_DEBUG(DBGS() << "no mapping\n"); + LDBG() << "no mapping"; return rewriter.notifyMatchFailure(op, "no mapping"); } @@ -613,9 +612,9 @@ convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op, rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr()); (void)store; - LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n"); + LDBG() << "transfer write to: " << store; - LLVM_DEBUG(DBGS() << "erase: " << op << "\n"); + LDBG() << "erase: " << op; rewriter.eraseOp(op); return success(); } @@ -641,21 +640,21 @@ convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op, FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = nvgpu::getWarpMatrixInfo(op); if (failed(warpMatrixInfo)) { - LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n"); + LDBG() << "no warpMatrixInfo"; return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); } FailureOr<nvgpu::FragmentElementInfo> regInfo = nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); if (failed(regInfo)) { - LLVM_DEBUG(DBGS() << "not mma sync reg info\n"); + LDBG() << "not mma sync reg info"; return rewriter.notifyMatchFailure(op, "not mma sync reg info"); } VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); auto dense = dyn_cast<SplatElementsAttr>(op.getValue()); if (!dense) { - LLVM_DEBUG(DBGS() << "not a splat\n"); + LDBG() << "not a splat"; return rewriter.notifyMatchFailure(op, "not a splat"); } @@ -677,8 +676,8 @@ static FailureOr<bool> isTransposed(vector::TransferReadOp op) { mlir::AffineMap map = op.getPermutationMap(); if (map.getNumResults() != 2) { - LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` " - "is not a 2d operand\n"); + LDBG() << "Failed because the result of `vector.transfer_read` " + "is not a 2d operand"; return failure(); } @@ -691,8 +690,8 @@ static FailureOr<bool> isTransposed(vector::TransferReadOp op) { auto exprN = dyn_cast<AffineDimExpr>(dN); if (!exprM || !exprN) { - LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim " - "expressions, then transpose cannot be determined.\n"); + LDBG() << "Failed because expressions are not affine dim " + "expressions, then transpose cannot be determined."; return failure(); } @@ -709,20 +708,20 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op, FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = nvgpu::getWarpMatrixInfo(op); if (failed(warpMatrixInfo)) { - LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n"); + LDBG() << "no warpMatrixInfo"; return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); } FailureOr<nvgpu::FragmentElementInfo> regInfo = nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); if (failed(regInfo)) { - LLVM_DEBUG(DBGS() << "not mma sync reg info\n"); + LDBG() << "not mma sync reg info"; return rewriter.notifyMatchFailure(op, "not mma sync reg info"); } FailureOr<bool> transpose = isTransposed(op); if (failed(transpose)) { - LLVM_DEBUG(DBGS() << "failed to determine the 
transpose\n"); + LDBG() << "failed to determine the transpose"; return rewriter.notifyMatchFailure( op, "Op should likely not be converted to a nvgpu.ldmatrix call."); } @@ -731,10 +730,8 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op, nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose); if (failed(params)) { - LLVM_DEBUG( - DBGS() - << "failed to convert vector.transfer_read to ldmatrix. " - << "Op should likely not be converted to a nvgpu.ldmatrix call.\n"); + LDBG() << "failed to convert vector.transfer_read to ldmatrix. " + << "Op should likely not be converted to a nvgpu.ldmatrix call."; return rewriter.notifyMatchFailure( op, "failed to convert vector.transfer_read to ldmatrix; this op " "likely should not be converted to a nvgpu.ldmatrix call."); @@ -745,7 +742,7 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op, FailureOr<AffineMap> offsets = nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params); if (failed(offsets)) { - LLVM_DEBUG(DBGS() << "no offsets\n"); + LDBG() << "no offsets"; return rewriter.notifyMatchFailure(op, "no offsets"); } @@ -934,7 +931,7 @@ convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op, vector::StoreOp::create(rewriter, loc, el, op.getBase(), newIndices); } - LLVM_DEBUG(DBGS() << "erase: " << op << "\n"); + LDBG() << "erase: " << op; rewriter.eraseOp(op); return success(); } @@ -1132,9 +1129,9 @@ static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter, loop.getNumResults()))) rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it)); - LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n"); - LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n"); - LLVM_DEBUG(DBGS() << "erase: " << loop); + LDBG() << "newLoop now: " << newLoop; + LDBG() << "stripped scf.for: " << loop; + LDBG() << "erase: " << loop; rewriter.eraseOp(loop); return newLoop; @@ -1150,7 +1147,7 @@ static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op, for (const auto &operand : llvm::enumerate(op.getInitArgs())) { auto it = valueMapping.find(operand.value()); if (it == valueMapping.end()) { - LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n"); + LDBG() << "no value mapping for: " << operand.value(); continue; } argMapping.push_back(std::make_pair( @@ -1168,7 +1165,7 @@ static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op, loopBody.getArgument(mapping.second + newForOp.getNumInductionVars()); } - LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n"); + LDBG() << "scf.for to: " << newForOp; return success(); } @@ -1191,7 +1188,7 @@ convertYieldOp(RewriterBase &rewriter, scf::YieldOp op, } scf::YieldOp::create(rewriter, op.getLoc(), yieldOperands); - LLVM_DEBUG(DBGS() << "erase: " << op << "\n"); + LDBG() << "erase: " << op; rewriter.eraseOp(op); return success(); } @@ -1244,7 +1241,7 @@ LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter, auto globalRes = LogicalResult::success(); for (Operation *op : ops) { - LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n"); + LDBG() << "Process op: " << *op; // Apparently callers do not want to early exit on failure here. 
auto res = LogicalResult::success(); if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) { diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 8d7053c..22608a1 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -26,7 +26,7 @@ #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallVectorExtras.h" #include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/LogicalResult.h" #include "llvm/Support/MathExtras.h" #include <numeric> @@ -40,7 +40,6 @@ using llvm::divideFloorSigned; using llvm::mod; #define DEBUG_TYPE "affine-ops" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") #include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc" @@ -1062,12 +1061,9 @@ static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp, AffineMap *map, ValueRange dims, ValueRange syms) { + LDBG() << "replaceAffineMinBoundingBoxExpression: `" << minOp << "`"; AffineMap affineMinMap = minOp.getAffineMap(); - LLVM_DEBUG({ - DBGS() << "replaceAffineMinBoundingBoxExpression: `" << minOp << "`\n"; - }); - // Check the value is positive. for (unsigned i = 0, e = affineMinMap.getNumResults(); i < e; ++i) { // Compare each expression in the minimum against 0. diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index cffe310..52cd0ce 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -30,6 +30,7 @@ #include "mlir/IR/Types.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp index 935aa3c..b951df8 100644 --- a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp +++ b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp @@ -22,6 +22,8 @@ #include "llvm/ADT/ScopeExit.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" + #define DEBUG_TYPE "llvm-inliner" using namespace mlir; @@ -670,44 +672,42 @@ struct LLVMInlinerInterface : public DialectInlinerInterface { bool wouldBeCloned) const final { auto callOp = dyn_cast<LLVM::CallOp>(call); if (!callOp) { - LLVM_DEBUG(llvm::dbgs() << "Cannot inline: call is not an '" - << LLVM::CallOp::getOperationName() << "' op\n"); + LDBG() << "Cannot inline: call is not an '" + << LLVM::CallOp::getOperationName() << "' op"; return false; } if (callOp.getNoInline()) { - LLVM_DEBUG(llvm::dbgs() << "Cannot inline: call is marked no_inline\n"); + LDBG() << "Cannot inline: call is marked no_inline"; return false; } auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(callable); if (!funcOp) { - LLVM_DEBUG(llvm::dbgs() - << "Cannot inline: callable is not an '" - << LLVM::LLVMFuncOp::getOperationName() << "' op\n"); + LDBG() << "Cannot inline: callable is not an '" + << LLVM::LLVMFuncOp::getOperationName() << "' op"; return false; } if (funcOp.isNoInline()) { - LLVM_DEBUG(llvm::dbgs() - << "Cannot inline: function is marked no_inline\n"); + LDBG() << "Cannot inline: function is marked no_inline"; return false; } if (funcOp.isVarArg()) { - LLVM_DEBUG(llvm::dbgs() << "Cannot inline: callable is variadic\n"); + LDBG() << "Cannot inline: callable is variadic"; return false; } // TODO: 
Generate aliasing metadata from noalias result attributes. if (auto attrs = funcOp.getArgAttrs()) { for (DictionaryAttr attrDict : attrs->getAsRange<DictionaryAttr>()) { if (attrDict.contains(LLVM::LLVMDialect::getInAllocaAttrName())) { - LLVM_DEBUG(llvm::dbgs() << "Cannot inline " << funcOp.getSymName() - << ": inalloca arguments not supported\n"); + LDBG() << "Cannot inline " << funcOp.getSymName() + << ": inalloca arguments not supported"; return false; } } } // TODO: Handle exceptions. if (funcOp.getPersonality()) { - LLVM_DEBUG(llvm::dbgs() << "Cannot inline " << funcOp.getSymName() - << ": unhandled function personality\n"); + LDBG() << "Cannot inline " << funcOp.getSymName() + << ": unhandled function personality"; return false; } if (funcOp.getPassthrough()) { @@ -717,10 +717,8 @@ struct LLVMInlinerInterface : public DialectInlinerInterface { if (!stringAttr) return false; if (disallowedFunctionAttrs.contains(stringAttr)) { - LLVM_DEBUG(llvm::dbgs() - << "Cannot inline " << funcOp.getSymName() - << ": found disallowed function attribute " - << stringAttr << "\n"); + LDBG() << "Cannot inline " << funcOp.getSymName() + << ": found disallowed function attribute " << stringAttr; return true; } return false; diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index 7f9ba1b..bf66ed0 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -637,6 +637,7 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> { } ArrayRef<int64_t> sourceShape = padOp.getSourceType().getShape(); + ArrayRef<int64_t> resultShape = padOp.getResultType().getShape(); int64_t padRank = sourceShape.size(); auto isStaticZero = [](OpFoldResult f) { @@ -647,16 +648,18 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> { allowedUnitDims.end()); llvm::SmallDenseSet<unsigned> unitDims; SmallVector<int64_t> newShape; + SmallVector<int64_t> newResultShape; SmallVector<OpFoldResult> newLowPad; SmallVector<OpFoldResult> newHighPad; - for (const auto [dim, size, low, high] : - zip_equal(llvm::seq(static_cast<int64_t>(0), padRank), sourceShape, - padOp.getMixedLowPad(), padOp.getMixedHighPad())) { + for (const auto [dim, size, outSize, low, high] : zip_equal( + llvm::seq(static_cast<int64_t>(0), padRank), sourceShape, + resultShape, padOp.getMixedLowPad(), padOp.getMixedHighPad())) { if (unitDimsFilter.contains(dim) && size == 1 && isStaticZero(low) && isStaticZero(high)) { unitDims.insert(dim); } else { newShape.push_back(size); + newResultShape.push_back(outSize); newLowPad.push_back(low); newHighPad.push_back(high); } @@ -686,8 +689,10 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> { collapseValue(rewriter, padOp.getLoc(), padOp.getSource(), newShape, reassociationMap, options.rankReductionStrategy); - auto newPadOp = tensor::PadOp::create( - rewriter, padOp.getLoc(), /*result=*/Type(), collapsedSource, newLowPad, + auto newResultType = RankedTensorType::get( + newResultShape, padOp.getResultType().getElementType()); + auto newPadOp = rewriter.create<tensor::PadOp>( + padOp.getLoc(), /*result=*/newResultType, collapsedSource, newLowPad, newHighPad, paddingVal, padOp.getNofold()); Value dest = padOp.getResult(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp index 2c62cb6..2e62523 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp +++ 
b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp @@ -55,6 +55,28 @@ getFullRankPaddingSizes(Builder &b, ArrayRef<OpFoldResult> indexingSizes, return paddingSizes; } +/// Extracts the constant multiplier from an affine expression of the form +/// `d * c` or `c * d`, where `d` is an AffineDimExpr and `c` is an +/// AffineConstantExpr. Returns 1 if the expression is not a simple +/// multiplication of a dimension and a constant. +static int64_t extractConstantMultiplier(AffineExpr expr) { + if (auto binOp = dyn_cast<AffineBinaryOpExpr>(expr)) { + if (binOp.getKind() == AffineExprKind::Mul) { + auto lhsD = dyn_cast<AffineDimExpr>(binOp.getLHS()); + auto rhsC = dyn_cast<AffineConstantExpr>(binOp.getRHS()); + if (lhsD && rhsC) { + return rhsC.getValue(); + } + auto lhsC = dyn_cast<AffineConstantExpr>(binOp.getLHS()); + auto rhsD = dyn_cast<AffineDimExpr>(binOp.getRHS()); + if (lhsC && rhsD) { + return lhsC.getValue(); + } + } + } + return 1; +} + /// Compute the padded shape of the given value `v` of `RankedTensorType` given /// - `indexingSizes` a list of OpFoldResult. /// - an `indexingMap` that encodes how the shape of varies with increases @@ -63,6 +85,13 @@ getFullRankPaddingSizes(Builder &b, ArrayRef<OpFoldResult> indexingSizes, /// The `indexingMap` + `indexingSizes` encoding suits StructuredOps. /// The implementaiton below iteratively combines increases from contributing /// dimensions using affine.apply operations. +/// The padded shape is computed by evaluating the maximum accessed index per +/// dimension, which may involve multiplying by constant factors derived from +/// the affine indexing expressions. Currently, only a limited set of projected +/// permutation indexing maps are supported, such as +/// - affine_map<(d0, d1, d2) -> (d0, d1)> +/// - affine_map<(d0, d1, d2) -> (d0, d1 + d2)> +/// - affine_map<(d0, d1) -> (d0 * 3 + d1)> /// In the future, more general interfaces can be devised to encode similar /// shape evolutions and map between an op and its operands. SmallVector<OpFoldResult> linalg::computePaddedShape( @@ -114,24 +143,33 @@ SmallVector<OpFoldResult> linalg::computePaddedShape( /*compressDims=*/true); // If we are padding to the next multiple of, compose with ceil(sz) * sz. + OpFoldResult paddingDimOfr; if (options.padToMultipleOf) { AffineExpr d0, s0; bindDims(rewriter.getContext(), d0); bindSymbols(rewriter.getContext(), s0); AffineMap ceilMap = AffineMap::get(1, 1, d0.ceilDiv(s0) * s0); AffineMap composedMap = projectedMap.compose(ceilMap); - OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply( + paddingDimOfr = affine::makeComposedFoldedAffineApply( rewriter, loc, composedMap, {indexingSizes[paddingDim], paddingSize}, /*composeAffineMin=*/true); - terms.push_back(paddingDimOfr); } else { // Otherwise just set to paddingSize. - OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply( + paddingDimOfr = affine::makeComposedFoldedAffineApply( rewriter, loc, projectedMap, paddingSize); - terms.push_back(paddingDimOfr); } + // Adjust for the maximum accessed index, which is (paddingSize - 1) * + // multiplier. 
+ AffineExpr d0; + bindDims(rewriter.getContext(), d0); + int64_t multiplier = extractConstantMultiplier(projectedMap.getResult(0)); + AffineMap subtractMap = AffineMap::get(1, 0, d0 - multiplier); + OpFoldResult maxAccessIdx = affine::makeComposedFoldedAffineApply( + rewriter, loc, subtractMap, {paddingDimOfr}); + terms.push_back(maxAccessIdx); + LLVM_DEBUG(DBGS() << "------new term: " << terms.back() << "\n"); } @@ -148,8 +186,9 @@ SmallVector<OpFoldResult> linalg::computePaddedShape( AffineExpr sumExpr = dims.front(); for (unsigned i = 1; i < dims.size(); ++i) sumExpr = sumExpr + dims[i]; - OpFoldResult paddedDimOfr = - affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, terms); + // Add 1 to the maximum accessed index and get the final padded size. + OpFoldResult paddedDimOfr = affine::makeComposedFoldedAffineApply( + rewriter, loc, sumExpr + 1, terms); paddedShape[resultIndex] = paddedDimOfr; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 793eec7..ea68b1a 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -1946,12 +1946,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp, vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create( rewriter, loc, vecCollapsedType, transposeOp->getResult(0)); - // writeVectorSizes had to match the shapecast shape for dynamic sizes, - // otherwise the validator complains that the mask size is invalid. - SmallVector<int64_t> writeVectorSizes( - unpackOp.getDestType().hasStaticShape() - ? vectorSizes - : shapeCastOp.getResultVectorType().getShape()); Operation *write = createWriteOrMaskedWrite( rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(), /*writeIndices=*/{}, useInBoundsInsteadOfMasking); diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index e73bdd3..9d5dfc1 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -2957,6 +2957,23 @@ bool acc::LoopOp::hasDefaultGangWorkerVector() { getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static); } +acc::LoopParMode +acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) { + if (hasSeq(deviceType)) + return LoopParMode::loop_seq; + if (hasAuto(deviceType)) + return LoopParMode::loop_auto; + if (hasIndependent(deviceType)) + return LoopParMode::loop_independent; + if (hasSeq()) + return LoopParMode::loop_seq; + if (hasAuto()) + return LoopParMode::loop_auto; + assert(hasIndependent() && + "loop must have default auto, seq, or independent"); + return LoopParMode::loop_independent; +} + void acc::LoopOp::addGangOperands( MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes, llvm::ArrayRef<GangArgType> argTypes, mlir::ValueRange values) { diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 759e58b..0262a1b 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -137,6 +137,9 @@ ParseResult ExecuteRegionOp::parse(OpAsmParser &parser, if (parser.parseOptionalArrowTypeList(result.types)) return failure(); + if (succeeded(parser.parseOptionalKeyword("no_inline"))) + result.addAttribute("no_inline", parser.getBuilder().getUnitAttr()); + // Introduce the body region and parse it. 
Region *body = result.addRegion(); if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) || @@ -148,8 +151,9 @@ ParseResult ExecuteRegionOp::parse(OpAsmParser &parser, void ExecuteRegionOp::print(OpAsmPrinter &p) { p.printOptionalArrowTypeList(getResultTypes()); - p << ' '; + if (getNoInline()) + p << "no_inline "; p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/true); @@ -184,7 +188,7 @@ struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> { LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override { - if (!op.getRegion().hasOneBlock()) + if (!op.getRegion().hasOneBlock() || op.getNoInline()) return failure(); replaceOpWithRegion(rewriter, op, op.getRegion()); return success(); diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp index 9bee200..fcf1526 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -693,7 +693,9 @@ static ParseResult parseStructMemberDecorations( // `!spirv.struct<` (id `,`)? // `(` // (spirv-type (`[` struct-member-decoration `]`)?)* -// `)>` +// `)` +// (`,` struct-decoration)? +// `>` static Type parseStructType(SPIRVDialect const &dialect, DialectAsmParser &parser) { // TODO: This function is quite lengthy. Break it down into smaller chunks. @@ -767,17 +769,48 @@ static Type parseStructType(SPIRVDialect const &dialect, return Type(); } - if (failed(parser.parseRParen()) || failed(parser.parseGreater())) + if (failed(parser.parseRParen())) + return Type(); + + SmallVector<StructType::StructDecorationInfo, 1> structDecorationInfo; + + auto parseStructDecoration = [&]() { + std::optional<spirv::Decoration> decoration = + parseAndVerify<spirv::Decoration>(dialect, parser); + if (!decoration) + return failure(); + + // Parse decoration value if it exists. 
+ if (succeeded(parser.parseOptionalEqual())) { + Attribute decorationValue; + if (failed(parser.parseAttribute(decorationValue))) + return failure(); + + structDecorationInfo.emplace_back(decoration.value(), decorationValue); + } else { + structDecorationInfo.emplace_back(decoration.value(), + UnitAttr::get(dialect.getContext())); + } + return success(); + }; + + while (succeeded(parser.parseOptionalComma())) + if (failed(parseStructDecoration())) + return Type(); + + if (failed(parser.parseGreater())) return Type(); if (!identifier.empty()) { if (failed(idStructTy.trySetBody(memberTypes, offsetInfo, - memberDecorationInfo))) + memberDecorationInfo, + structDecorationInfo))) return Type(); return idStructTy; } - return StructType::get(memberTypes, offsetInfo, memberDecorationInfo); + return StructType::get(memberTypes, offsetInfo, memberDecorationInfo, + structDecorationInfo); } // spirv-type ::= array-type @@ -893,7 +926,23 @@ static void print(StructType type, DialectAsmPrinter &os) { }; llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os, printMember); - os << ")>"; + os << ")"; + + SmallVector<spirv::StructType::StructDecorationInfo, 1> decorations; + type.getStructDecorations(decorations); + if (!decorations.empty()) { + os << ", "; + auto eachFn = [&os](spirv::StructType::StructDecorationInfo decoration) { + os << stringifyDecoration(decoration.decoration); + if (decoration.hasValue()) { + os << "="; + os.printAttributeWithoutType(decoration.decorationValue); + } + }; + llvm::interleaveComma(decorations, os, eachFn); + } + + os << ">"; } static void print(CooperativeMatrixType type, DialectAsmPrinter &os) { diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp index 46739bc..ddb3426 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -835,12 +835,14 @@ void SampledImageType::getCapabilities( /// - for literal structs: /// - a list of member types; /// - a list of member offset info; -/// - a list of member decoration info. +/// - a list of member decoration info; +/// - a list of struct decoration info. /// /// Identified structures only have a mutable component consisting of: /// - a list of member types; /// - a list of member offset info; -/// - a list of member decoration info. +/// - a list of member decoration info; +/// - a list of struct decoration info. struct spirv::detail::StructTypeStorage : public TypeStorage { /// Construct a storage object for an identified struct type. A struct type /// associated with such storage must call StructType::trySetBody(...) later @@ -848,6 +850,7 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { StructTypeStorage(StringRef identifier) : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr), numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr), + numStructDecorations(0), structDecorationsInfo(nullptr), identifier(identifier) {} /// Construct a storage object for a literal struct type. 
A struct type @@ -855,10 +858,14 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { StructTypeStorage( unsigned numMembers, Type const *memberTypes, StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, - StructType::MemberDecorationInfo const *memberDecorationsInfo) + StructType::MemberDecorationInfo const *memberDecorationsInfo, + unsigned numStructDecorations, + StructType::StructDecorationInfo const *structDecorationsInfo) : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo), numMembers(numMembers), numMemberDecorations(numMemberDecorations), - memberDecorationsInfo(memberDecorationsInfo) {} + memberDecorationsInfo(memberDecorationsInfo), + numStructDecorations(numStructDecorations), + structDecorationsInfo(structDecorationsInfo) {} /// A storage key is divided into 2 parts: /// - for identified structs: @@ -867,16 +874,19 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { /// - an ArrayRef<Type> for member types; /// - an ArrayRef<StructType::OffsetInfo> for member offset info; /// - an ArrayRef<StructType::MemberDecorationInfo> for member decoration + /// info; + /// - an ArrayRef<StructType::StructDecorationInfo> for struct decoration /// info. /// /// An identified struct type is uniqued only by the first part (field 0) /// of the key. /// - /// A literal struct type is uniqued only by the second part (fields 1, 2, and - /// 3) of the key. The identifier field (field 0) must be empty. + /// A literal struct type is uniqued only by the second part (fields 1, 2, 3 + /// and 4) of the key. The identifier field (field 0) must be empty. using KeyTy = std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>, - ArrayRef<StructType::MemberDecorationInfo>>; + ArrayRef<StructType::MemberDecorationInfo>, + ArrayRef<StructType::StructDecorationInfo>>; /// For identified structs, return true if the given key contains the same /// identifier. 
@@ -890,7 +900,7 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { } return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(), - getMemberDecorationsInfo()); + getMemberDecorationsInfo(), getStructDecorationsInfo()); } /// If the given key contains a non-empty identifier, this method constructs @@ -937,9 +947,17 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { memberDecorationList = allocator.copyInto(keyMemberDecorations).data(); } - return new (allocator.allocate<StructTypeStorage>()) - StructTypeStorage(keyTypes.size(), typesList, offsetInfoList, - numMemberDecorations, memberDecorationList); + const StructType::StructDecorationInfo *structDecorationList = nullptr; + unsigned numStructDecorations = 0; + if (!std::get<4>(key).empty()) { + auto keyStructDecorations = std::get<4>(key); + numStructDecorations = keyStructDecorations.size(); + structDecorationList = allocator.copyInto(keyStructDecorations).data(); + } + + return new (allocator.allocate<StructTypeStorage>()) StructTypeStorage( + keyTypes.size(), typesList, offsetInfoList, numMemberDecorations, + memberDecorationList, numStructDecorations, structDecorationList); } ArrayRef<Type> getMemberTypes() const { @@ -961,6 +979,13 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { return {}; } + ArrayRef<StructType::StructDecorationInfo> getStructDecorationsInfo() const { + if (structDecorationsInfo) + return ArrayRef<StructType::StructDecorationInfo>(structDecorationsInfo, + numStructDecorations); + return {}; + } + StringRef getIdentifier() const { return identifier; } bool isIdentified() const { return !identifier.empty(); } @@ -973,17 +998,19 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { /// - If called for an identified struct whose body was set before (through a /// call to this method) but with different contents from the passed /// arguments. 
- LogicalResult mutate( - TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes, - ArrayRef<StructType::OffsetInfo> structOffsetInfo, - ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo) { + LogicalResult + mutate(TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes, + ArrayRef<StructType::OffsetInfo> structOffsetInfo, + ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo, + ArrayRef<StructType::StructDecorationInfo> structDecorationInfo) { if (!isIdentified()) return failure(); if (memberTypesAndIsBodySet.getInt() && (getMemberTypes() != structMemberTypes || getOffsetInfo() != structOffsetInfo || - getMemberDecorationsInfo() != structMemberDecorationInfo)) + getMemberDecorationsInfo() != structMemberDecorationInfo || + getStructDecorationsInfo() != structDecorationInfo)) return failure(); memberTypesAndIsBodySet.setInt(true); @@ -1007,6 +1034,11 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { allocator.copyInto(structMemberDecorationInfo).data(); } + if (!structDecorationInfo.empty()) { + numStructDecorations = structDecorationInfo.size(); + structDecorationsInfo = allocator.copyInto(structDecorationInfo).data(); + } + return success(); } @@ -1015,21 +1047,30 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { unsigned numMembers; unsigned numMemberDecorations; StructType::MemberDecorationInfo const *memberDecorationsInfo; + unsigned numStructDecorations; + StructType::StructDecorationInfo const *structDecorationsInfo; StringRef identifier; }; StructType StructType::get(ArrayRef<Type> memberTypes, ArrayRef<StructType::OffsetInfo> offsetInfo, - ArrayRef<StructType::MemberDecorationInfo> memberDecorations) { + ArrayRef<StructType::MemberDecorationInfo> memberDecorations, + ArrayRef<StructType::StructDecorationInfo> structDecorations) { assert(!memberTypes.empty() && "Struct needs at least one member type"); // Sort the decorations. - SmallVector<StructType::MemberDecorationInfo, 4> sortedDecorations( + SmallVector<StructType::MemberDecorationInfo, 4> sortedMemberDecorations( memberDecorations); - llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end()); + llvm::array_pod_sort(sortedMemberDecorations.begin(), + sortedMemberDecorations.end()); + SmallVector<StructType::StructDecorationInfo, 1> sortedStructDecorations( + structDecorations); + llvm::array_pod_sort(sortedStructDecorations.begin(), + sortedStructDecorations.end()); + return Base::get(memberTypes.vec().front().getContext(), /*identifier=*/StringRef(), memberTypes, offsetInfo, - sortedDecorations); + sortedMemberDecorations, sortedStructDecorations); } StructType StructType::getIdentified(MLIRContext *context, @@ -1039,18 +1080,21 @@ StructType StructType::getIdentified(MLIRContext *context, return Base::get(context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(), - ArrayRef<StructType::MemberDecorationInfo>()); + ArrayRef<StructType::MemberDecorationInfo>(), + ArrayRef<StructType::StructDecorationInfo>()); } StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) { StructType newStructType = Base::get( context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(), - ArrayRef<StructType::MemberDecorationInfo>()); + ArrayRef<StructType::MemberDecorationInfo>(), + ArrayRef<StructType::StructDecorationInfo>()); // Set an empty body in case this is a identified struct. 
if (newStructType.isIdentified() && failed(newStructType.trySetBody( ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(), - ArrayRef<StructType::MemberDecorationInfo>()))) + ArrayRef<StructType::MemberDecorationInfo>(), + ArrayRef<StructType::StructDecorationInfo>()))) return StructType(); return newStructType; @@ -1074,6 +1118,15 @@ TypeRange StructType::getElementTypes() const { bool StructType::hasOffset() const { return getImpl()->offsetInfo; } +bool StructType::hasDecoration(spirv::Decoration decoration) const { + for (StructType::StructDecorationInfo info : + getImpl()->getStructDecorationsInfo()) + if (info.decoration == decoration) + return true; + + return false; +} + uint64_t StructType::getMemberOffset(unsigned index) const { assert(getNumElements() > index && "member index out of range"); return getImpl()->offsetInfo[index]; @@ -1105,11 +1158,21 @@ void StructType::getMemberDecorations( } } +void StructType::getStructDecorations( + SmallVectorImpl<StructType::StructDecorationInfo> &structDecorations) + const { + structDecorations.clear(); + auto implDecorations = getImpl()->getStructDecorationsInfo(); + structDecorations.append(implDecorations.begin(), implDecorations.end()); +} + LogicalResult StructType::trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo, - ArrayRef<MemberDecorationInfo> memberDecorations) { - return Base::mutate(memberTypes, offsetInfo, memberDecorations); + ArrayRef<MemberDecorationInfo> memberDecorations, + ArrayRef<StructDecorationInfo> structDecorations) { + return Base::mutate(memberTypes, offsetInfo, memberDecorations, + structDecorations); } void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, @@ -1131,6 +1194,11 @@ llvm::hash_code spirv::hash_value( memberDecorationInfo.decoration); } +llvm::hash_code spirv::hash_value( + const StructType::StructDecorationInfo &structDecorationInfo) { + return llvm::hash_value(structDecorationInfo.decoration); +} + //===----------------------------------------------------------------------===// // MatrixType //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 35ec019..8f4c4cc 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -182,6 +182,14 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) { return bitWidth / 8; } + // Handle 8-bit floats. + if (options.emulateUnsupportedFloatTypes && isa<FloatType>(type)) { + auto bitWidth = type.getIntOrFloatBitWidth(); + if (bitWidth == 8) + return bitWidth / 8; + return std::nullopt; + } + if (auto complexType = dyn_cast<ComplexType>(type)) { auto elementSize = getTypeNumBytes(options, complexType.getElementType()); if (!elementSize) @@ -318,6 +326,44 @@ static Type convertSubByteIntegerType(const SPIRVConversionOptions &options, type.getSignedness()); } +/// Converts 8-bit float types to integer types with the same bit width. +/// Returns a nullptr for unsupported 8-bit float types. +static Type convert8BitFloatType(const SPIRVConversionOptions &options, + FloatType type) { + if (!options.emulateUnsupportedFloatTypes) + return nullptr; + // F8 types are converted to integer types with the same bit width. 
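// For instance (illustrative only): f8E4M3 and f8E5M2 scalars are emulated as i8, and shaped types keep their shape, so vector<4xf8E5M2> becomes vector<4xi8>. Assuming an MLIRContext *ctx, the mapping amounts to:
//   Type f8 = Float8E4M3Type::get(ctx);
//   Type emulated = IntegerType::get(ctx, f8.getIntOrFloatBitWidth()); // i8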
+ if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType, + Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type, + Float8E8M0FNUType>(type)) + return IntegerType::get(type.getContext(), type.getWidth()); + LLVM_DEBUG(llvm::dbgs() << "unsupported 8-bit float type: " << type << "\n"); + return nullptr; +} + +/// Returns a type with the same shape but with any 8-bit float element type +/// converted to the same bit width integer type. This is a noop when the +/// element type is not an 8-bit float type or the emulation flag is set to false. +static ShapedType +convertShaped8BitFloatType(ShapedType type, + const SPIRVConversionOptions &options) { + if (!options.emulateUnsupportedFloatTypes) + return type; + Type srcElementType = type.getElementType(); + Type convertedElementType = nullptr; + // F8 types are converted to integer types with the same bit width. + if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType, + Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type, + Float8E8M0FNUType>(srcElementType)) + convertedElementType = IntegerType::get( + type.getContext(), srcElementType.getIntOrFloatBitWidth()); + + if (!convertedElementType) + return type; + + return type.clone(convertedElementType); +} + /// Returns a type with the same shape but with any index element type converted /// to the matching integer type. This is a noop when the element type is not /// the index type. @@ -337,6 +383,7 @@ convertVectorType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, VectorType type, std::optional<spirv::StorageClass> storageClass = {}) { type = cast<VectorType>(convertIndexElementType(type, options)); + type = cast<VectorType>(convertShaped8BitFloatType(type, options)); auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType()); if (!scalarType) { // If this is not a spec allowed scalar type, try to handle sub-byte integer @@ -433,6 +480,7 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv, } type = cast<TensorType>(convertIndexElementType(type, options)); + type = cast<TensorType>(convertShaped8BitFloatType(type, options)); auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType()); if (!scalarType) { LLVM_DEBUG(llvm::dbgs() @@ -596,6 +644,10 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv, } else if (auto indexType = dyn_cast<IndexType>(elementType)) { type = cast<MemRefType>(convertIndexElementType(type, options)); arrayElemType = type.getElementType(); + } else if (auto floatType = dyn_cast<FloatType>(elementType)) { + // Handle 8-bit float types. 
+ type = cast<MemRefType>(convertShaped8BitFloatType(type, options)); + arrayElemType = type.getElementType(); } else { LLVM_DEBUG( llvm::dbgs() @@ -1444,6 +1496,8 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, addConversion([this](FloatType floatType) -> std::optional<Type> { if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType)) return convertScalarType(this->targetEnv, this->options, scalarType); + if (floatType.getWidth() == 8) + return convert8BitFloatType(this->options, floatType); return Type(); }); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp index 6a9b951..a53d0a7 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp @@ -174,6 +174,21 @@ void UpdateVCEPass::runOnOperation() { if (walkResult.wasInterrupted()) return signalPassFailure(); + // Update min version requirement for capabilities after deducing them. + for (spirv::Capability cap : deducedCapabilities) { + if (std::optional<spirv::Version> minVersion = spirv::getMinVersion(cap)) { + deducedVersion = std::max(deducedVersion, *minVersion); + if (deducedVersion > allowedVersion) { + module.emitError("Capability '") + << spirv::stringifyCapability(cap) << "' requires min version " + << spirv::stringifyVersion(deducedVersion) + << " but target environment allows up to " + << spirv::stringifyVersion(allowedVersion); + return signalPassFailure(); + } + } + } + // TODO: verify that the deduced version is consistent with // SPIR-V ops' maximal version requirements. diff --git a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp index e5a3b5d..08fccfa 100644 --- a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp +++ b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp @@ -38,7 +38,6 @@ #include <utility> #define DEBUG_TYPE "shard-ops" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") using namespace mlir; using namespace mlir::shard; diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp index 88b0f36..9543fa1 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp @@ -464,9 +464,12 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) { CheckCondition condition = CheckCondition::invalid; const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition); const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition); + if (failed(maybeProfDef) && failed(maybeExtDef)) + return success(); - if (!failed(maybeProfDef) && !failed(maybeExtDef) && - !maybeProfDef.value().size() && !maybeExtDef.value().size()) { + const bool hasEntry = (succeeded(maybeProfDef) && !maybeProfDef->empty()) || + (succeeded(maybeExtDef) && !maybeExtDef->empty()); + if (!hasEntry) { std::string message; llvm::raw_string_ostream os(message); os << "illegal: operation operand/result data types did not align with any " diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 8789f55..a21b5ba 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5916,14 +5916,13 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) { } // shape_cast(constant) -> constant - if (auto splatAttr = - llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) - return splatAttr.reshape(getType()); + if (auto 
denseAttr = + dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource())) + return denseAttr.reshape(getType()); // shape_cast(poison) -> poison - if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) { + if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) return ub::PoisonAttr::get(getContext()); - } return {}; } @@ -6316,6 +6315,11 @@ std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() { return llvm::to_vector<4>(getResultVectorType().getShape()); } +void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, + SetIntRangeFn setResultRanges) { + setResultRanges(getResult(), argRanges.front()); +} + namespace { // Rewrites two back-to-back TransposeOp operations into a single TransposeOp. @@ -7198,6 +7202,23 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc, } //===----------------------------------------------------------------------===// +// StepOp +//===----------------------------------------------------------------------===// + +void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, + SetIntRangeFn setResultRanges) { + auto resultType = cast<VectorType>(getType()); + if (resultType.isScalable()) { + return; + } + unsigned bitwidth = ConstantIntRanges::getStorageBitwidth(resultType); + APInt zero(bitwidth, 0); + APInt high(bitwidth, resultType.getDimSize(0) - 1); + ConstantIntRanges result = {zero, high, zero, high}; + setResultRanges(getResult(), result); +} + +//===----------------------------------------------------------------------===// // Vector Masking Utilities //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp index cb8e566..dedc3b3 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp @@ -28,7 +28,10 @@ using namespace mlir; using namespace mlir::vector; namespace { -/// Progressive lowering of BroadcastOp. + +/// Convert a vector.broadcast with a vector operand to a lower rank +/// vector.broadcast. vector.broadcast with a scalar operand is expected to be +/// convertible to the lower level target dialect (LLVM, SPIR-V, etc.) directly. class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> { public: using OpRewritePattern::OpRewritePattern; @@ -40,20 +43,23 @@ public: VectorType srcType = dyn_cast<VectorType>(op.getSourceType()); Type eltType = dstType.getElementType(); - // Scalar to any vector can use splat. - if (!srcType) { - rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.getSource()); - return success(); - } + // A broadcast from a scalar is considered to be in the lowered form. + if (!srcType) + return rewriter.notifyMatchFailure( + op, "broadcast from scalar already in lowered form"); // Determine rank of source and destination. int64_t srcRank = srcType.getRank(); int64_t dstRank = dstType.getRank(); - // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat. + // Here we are broadcasting to a rank-1 vector. Ensure that the source is a + // scalar. 
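// Illustrative before/after for this rank-1 case (hand-written IR, not taken from a test):
//   %r = vector.broadcast %src : vector<1xf32> to vector<4xf32>
// is rewritten to
//   %s = vector.extract %src[0] : f32 from vector<1xf32>
//   %r = vector.broadcast %s : f32 to vector<4xf32>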
if (srcRank <= 1 && dstRank == 1) { - Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource()); - rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext); + SmallVector<int64_t> fullRankPosition(srcRank, 0); + Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(), + fullRankPosition); + assert(!isa<VectorType>(ext.getType()) && "expected scalar"); + rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, dstType, ext); return success(); } diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp index 4baeb11..2cf8f0b 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -468,7 +468,7 @@ struct TransferReadToVectorLoadLowering read, "vector type is not rank 1, can't create masked load, needs " "VectorToSCF"); - Value fill = vector::SplatOp::create( + Value fill = vector::BroadcastOp::create( rewriter, read.getLoc(), unbroadcastedVectorType, read.getPadding()); res = vector::MaskedLoadOp::create( rewriter, read.getLoc(), unbroadcastedVectorType, read.getBase(), diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp index 72352d7..cbb9d4b 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp @@ -303,7 +303,7 @@ public: // Extract/insert on a lower ranked extract strided slice op. Value zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getZeroAttr(elemType)); - Value res = SplatOp::create(rewriter, loc, dstType, zero); + Value res = BroadcastOp::create(rewriter, loc, dstType, zero); for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; off += stride, ++idx) { Value one = ExtractOp::create(rewriter, loc, op.getVector(), off); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index 48d680c..c707f38 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -25,12 +25,10 @@ #include "mlir/Interfaces/SideEffectInterfaces.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #define DEBUG_TYPE "vector-transfer-opt" -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") - using namespace mlir; /// Return the ancestor op in the region or nullptr if the region is not @@ -88,8 +86,7 @@ bool TransferOptimization::isReachable(Operation *start, Operation *dest) { /// transfer_write is dead if all reads that can be reached from the potentially /// dead transfer_write are dominated by the overwriting transfer_write. 
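// For intuition, a hand-written example (not from the tests): the first write below is dead because the later write to the same location fully overwrites it before any read observes it:
//   vector.transfer_write %v0, %m[%c0, %c0] {in_bounds = [true]} : vector<4xf32>, memref<4x4xf32>
//   vector.transfer_write %v1, %m[%c0, %c0] {in_bounds = [true]} : vector<4xf32>, memref<4x4xf32>
//   %r = vector.transfer_read %m[%c0, %c0], %pad {in_bounds = [true]} : memref<4x4xf32>, vector<4xf32>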
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) { - LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation() - << "\n"); + LDBG() << "Candidate for dead store: " << *write.getOperation(); llvm::SmallVector<Operation *, 8> blockingAccesses; Operation *firstOverwriteCandidate = nullptr; Value source = memref::skipViewLikeOps(cast<MemrefValue>(write.getBase())); @@ -150,13 +147,12 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) { !isReachable(writeAncestor, accessAncestor)) continue; if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) { - LLVM_DEBUG(DBGS() << "Store may not be dead due to op: " - << *accessAncestor << "\n"); + LDBG() << "Store may not be dead due to op: " << *accessAncestor; return; } } - LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation() - << " overwritten by: " << *firstOverwriteCandidate << "\n"); + LDBG() << "Found dead store: " << *write.getOperation() + << " overwritten by: " << *firstOverwriteCandidate; opToErase.push_back(write.getOperation()); } @@ -174,8 +170,7 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) { void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) { if (read.hasOutOfBoundsDim()) return; - LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation() - << "\n"); + LDBG() << "Candidate for Forwarding: " << *read.getOperation(); SmallVector<Operation *, 8> blockingWrites; vector::TransferWriteOp lastwrite = nullptr; Value source = memref::skipViewLikeOps(cast<MemrefValue>(read.getBase())); @@ -230,14 +225,13 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) { if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor)) continue; if (!postDominators.postDominates(lastwrite, write)) { - LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: " - << *write << "\n"); + LDBG() << "Fail to do write to read forwarding due to op: " << *write; return; } } - LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation() - << " to: " << *read.getOperation() << "\n"); + LDBG() << "Forward value from " << *lastwrite.getOperation() + << " to: " << *read.getOperation(); read.replaceAllUsesWith(lastwrite.getVector()); opToErase.push_back(read.getOperation()); } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 8de87fe..2269a40 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -939,7 +939,7 @@ public: Value zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getZeroAttr(elemType)); - Value res = SplatOp::create(rewriter, loc, castDstType, zero); + Value res = BroadcastOp::create(rewriter, loc, castDstType, zero); SmallVector<int64_t> sliceShape = {castDstLastDim}; SmallVector<int64_t> strides = {1}; @@ -965,6 +965,45 @@ private: std::function<bool(BitCastOp)> controlFn; }; +static bool haveSameShapeAndScaling(Type t, Type u) { + auto tVec = dyn_cast<VectorType>(t); + auto uVec = dyn_cast<VectorType>(u); + if (!tVec) { + return !uVec; + } + if (!uVec) { + return false; + } + return tVec.getShape() == uVec.getShape() && + tVec.getScalableDims() == uVec.getScalableDims(); +} + +/// If `type` is shaped, clone it with `newElementType`. Otherwise, +/// return `newElementType`. 
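// For example (illustrative, assuming f32 and i1 types from the same context):
//   Type a = cloneOrReplace(VectorType::get({4}, f32), i1); // vector<4xi1>
//   Type b = cloneOrReplace(f32, i1);                       // i1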
+static Type cloneOrReplace(Type type, Type newElementType) { + if (auto shapedType = dyn_cast<ShapedType>(type)) { + return shapedType.clone(newElementType); + } + return newElementType; +} + +/// If `value` is the result of a splat or broadcast operation, return the input +/// of the splat/broadcast operation. +static Value getBroadcastLikeSource(Value value) { + + Operation *op = value.getDefiningOp(); + if (!op) + return {}; + + if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) + return broadcast.getSource(); + + if (auto splat = dyn_cast<vector::SplatOp>(op)) + return splat.getInput(); + + return {}; +} + /// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex: /// /// Example: @@ -988,16 +1027,14 @@ struct ReorderElementwiseOpsOnBroadcast final PatternRewriter &rewriter) const override { if (op->getNumResults() != 1) return failure(); - if (!llvm::isa<ShapedType>(op->getResults()[0].getType())) + auto resultType = dyn_cast<VectorType>(op->getResult(0).getType()); + if (!resultType) return failure(); if (!OpTrait::hasElementwiseMappableTraits(op)) return rewriter.notifyMatchFailure( op, "Op doesn't have ElementwiseMappableTraits"); if (op->getNumOperands() == 0) return failure(); - if (op->getResults()[0].getType() != op->getOperand(0).getType()) - return rewriter.notifyMatchFailure(op, - "result and operand type mismatch"); if (isa<vector::FMAOp>(op)) { return rewriter.notifyMatchFailure( op, @@ -1005,45 +1042,71 @@ struct ReorderElementwiseOpsOnBroadcast final "might be a scalar"); } - // Get the type of the lhs operand - auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp(); - if (!lhsBcastOrSplat || - !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat)) + Type resultElemType = resultType.getElementType(); + + // Get the type of the first non-constant operand + Value splatSource; + for (Value operand : op->getOperands()) { + Operation *definingOp = operand.getDefiningOp(); + if (!definingOp) + return failure(); + if (definingOp->hasTrait<OpTrait::ConstantLike>()) + continue; + splatSource = getBroadcastLikeSource(operand); + break; + } + if (!splatSource) return failure(); - auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType(); + Type unbroadcastResultType = + cloneOrReplace(splatSource.getType(), resultElemType); - // Make sure that all operands are broadcast from identical types: + // Make sure that all operands are broadcast from identically-shaped types: // * scalar (`vector.broadcast` + `vector.splat`), or // * vector (`vector.broadcast`). // Otherwise the re-ordering wouldn't be safe. 
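// A hand-written example of the reordering this enables, now also accepting splat constants as operands:
//   %b = vector.broadcast %a : f32 to vector<4xf32>
//   %c = arith.constant dense<2.0> : vector<4xf32>
//   %r = arith.mulf %b, %c : vector<4xf32>
// becomes
//   %cst = arith.constant 2.0 : f32
//   %m = arith.mulf %a, %cst : f32
//   %r = vector.broadcast %m : f32 to vector<4xf32>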
- if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) { - auto bcast = val.getDefiningOp<vector::BroadcastOp>(); - if (bcast) - return (bcast.getOperand().getType() == lhsBcastOrSplatType); - auto splat = val.getDefiningOp<vector::SplatOp>(); - if (splat) - return (splat.getOperand().getType() == lhsBcastOrSplatType); - return false; + if (!llvm::all_of(op->getOperands(), [splatSource](Value val) { + if (auto source = getBroadcastLikeSource(val)) + return haveSameShapeAndScaling(source.getType(), + splatSource.getType()); + SplatElementsAttr splatConst; + return matchPattern(val, m_Constant(&splatConst)); })) { - return failure(); + return rewriter.notifyMatchFailure( + op, + "not all operands are constants or broadcasts from the same type"); } // Collect the source values before broadcasting SmallVector<Value> srcValues; srcValues.reserve(op->getNumOperands()); for (Value operand : op->getOperands()) { - srcValues.push_back(operand.getDefiningOp()->getOperand(0)); + SplatElementsAttr splatConst; + if (matchPattern(operand, m_Constant(&splatConst))) { + Attribute newConst; + Type elementType = getElementTypeOrSelf(operand.getType()); + Type newType = cloneOrReplace(unbroadcastResultType, elementType); + if (auto newTypeShaped = dyn_cast<ShapedType>(newType)) { + newConst = splatConst.resizeSplat(newTypeShaped); + } else { + newConst = splatConst.getSplatValue<Attribute>(); + } + Operation *newConstOp = + operand.getDefiningOp()->getDialect()->materializeConstant( + rewriter, newConst, newType, operand.getLoc()); + srcValues.push_back(newConstOp->getResult(0)); + } else { + srcValues.push_back(operand.getDefiningOp()->getOperand(0)); + } } // Create the "elementwise" Op Operation *elementwiseOp = rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues, - lhsBcastOrSplatType, op->getAttrs()); + unbroadcastResultType, op->getAttrs()); // Replace the original Op with the elementwise Op - auto vectorType = op->getResultTypes()[0]; rewriter.replaceOpWithNewOp<vector::BroadcastOp>( - op, vectorType, elementwiseOp->getResults()); + op, resultType, elementwiseOp->getResults()); return success(); } @@ -1239,15 +1302,17 @@ public: return rewriter.notifyMatchFailure( op, "only 1-element vectors are supported"); - Operation *splat = op.getValueToStore().getDefiningOp(); - if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat)) - return rewriter.notifyMatchFailure(op, "neither a splat nor a broadcast"); + Value toStore = op.getValueToStore(); + Value source = getBroadcastLikeSource(toStore); + if (!source) + return rewriter.notifyMatchFailure( + op, "value to store is not from a broadcast"); // Checking for single use so we can remove splat. + Operation *splat = toStore.getDefiningOp(); if (!splat->hasOneUse()) return rewriter.notifyMatchFailure(op, "expected single op use"); - Value source = splat->getOperand(0); Value base = op.getBase(); ValueRange indices = op.getIndices(); @@ -1297,13 +1362,13 @@ static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op, // Add in an offset if requested. if (off) { Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off); - Value ov = vector::SplatOp::create(rewriter, loc, indices.getType(), o); + Value ov = vector::BroadcastOp::create(rewriter, loc, indices.getType(), o); indices = arith::AddIOp::create(rewriter, loc, ov, indices); } // Construct the vector comparison. 
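// Illustratively (hand-written IR), for a 4-element mask with bound %b this builds roughly the following, yielding a vector<4xi1> mask:
//   %idx    = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
//   %bounds = vector.broadcast %b : index to vector<4xindex>
//   %mask   = arith.cmpi slt, %idx, %bounds : vector<4xindex>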
Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b); Value bounds = - vector::SplatOp::create(rewriter, loc, indices.getType(), bound); + vector::BroadcastOp::create(rewriter, loc, indices.getType(), bound); return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, indices, bounds); } diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 704deea..33450f3 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -110,6 +110,34 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy, return success(); } +static LogicalResult +isValidGatherScatterBufferParams(Type maskTy, VectorType valueTy, + int64_t chunkSize, + function_ref<InFlightDiagnostic()> emitError) { + + if (!valueTy) + return emitError() << "Expecting a vector type result."; + + auto maskShape = getShapeOf(maskTy); + auto valueShape = getShapeOf(valueTy); + + // a valid shape for SIMT case + if (valueTy.getRank() == 1) { + if (valueTy.getNumElements() != chunkSize) + return emitError() << "value elements must match chunk size " << chunkSize + << " for SIMT code."; + return success(); + } + + llvm::SmallVector<int64_t> expectedMaskShape(valueShape); + if (chunkSize > 1) + expectedMaskShape.pop_back(); + if (expectedMaskShape != maskShape) + return emitError() << "Mask should match value except the chunk size dim."; + + return success(); +} + //===----------------------------------------------------------------------===// // XeGPU_CreateNdDescOp //===----------------------------------------------------------------------===// @@ -644,9 +672,14 @@ LogicalResult CreateDescOp::verify() { //===----------------------------------------------------------------------===// LogicalResult PrefetchOp::verify() { auto tdescTy = getTensorDescType(); - if (!tdescTy.isScattered()) + + if (tdescTy && !tdescTy.isScattered()) return emitOpError("Expects a scattered TensorDesc.\n"); + if (!tdescTy && getRankOf(getSource()) > 1) + return emitOpError( + "Expecting the source is a 1D memref or pointer (uint64_t)."); + if (!isReadHintOrNone(getL1HintAttr())) return emitOpError("invalid l1_hint: ") << getL1HintAttr(); @@ -659,6 +692,13 @@ LogicalResult PrefetchOp::verify() { return success(); } +void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source, + xegpu::CachePolicyAttr l1_hint, + xegpu::CachePolicyAttr l2_hint, + xegpu::CachePolicyAttr l3_hint) { + build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint); +} + //===----------------------------------------------------------------------===// // XeGPU_LoadGatherOp //===----------------------------------------------------------------------===// @@ -667,6 +707,13 @@ LogicalResult LoadGatherOp::verify() { auto maskTy = getMaskType(); auto valueTy = getValueType(); + if (tdescTy && !tdescTy.isScattered()) + return emitOpError("Expects a scattered TensorDesc."); + + if (!tdescTy && getRankOf(getSource()) > 1) + return emitOpError( + "Expecting the source is a 1D memref or pointer (uint64_t)."); + if (!isReadHintOrNone(getL1HintAttr())) return emitOpError("invalid l1_hint: ") << getL1HintAttr(); @@ -676,8 +723,27 @@ LogicalResult LoadGatherOp::verify() { if (!isReadHintOrNone(getL3HintAttr())) return emitOpError("invalid l3_hint: ") << getL3HintAttr(); - return isValidGatherScatterParams(maskTy, valueTy, tdescTy, - [&]() { return emitOpError(); }); + if (tdescTy) + return isValidGatherScatterParams(maskTy, valueTy, tdescTy, + [&]() { return 
emitOpError(); }); + auto srcTy = getSourceType(); + uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1)); + auto memTy = dyn_cast<MemRefType>(srcTy); + + if (memTy && (valueTy.getElementType() != memTy.getElementType())) + return emitError() << "Value should have the same element type as MemRef."; + + return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize, + [&]() { return emitOpError(); }); +} + +void LoadGatherOp::build(OpBuilder &builder, OperationState &state, + Type valueType, Value source, Value mask, + xegpu::CachePolicyAttr l1_hint, + xegpu::CachePolicyAttr l2_hint, + xegpu::CachePolicyAttr l3_hint) { + build(builder, state, valueType, source, Value(), mask, IntegerAttr(), + l1_hint, l2_hint, l3_hint); } //===----------------------------------------------------------------------===// @@ -688,6 +754,13 @@ LogicalResult StoreScatterOp::verify() { auto maskTy = getMaskType(); auto valueTy = getValueType(); + if (tdescTy && !tdescTy.isScattered()) + return emitOpError("Expects a scattered TensorDesc.\n"); + + if (!tdescTy && getRankOf(getDest()) > 1) + return emitOpError( + "Expecting the dest is a 1D memref or pointer (uint64_t)."); + if (!isWriteHintOrNone(getL1HintAttr())) return emitOpError("invalid l1_hint: ") << getL1HintAttr(); @@ -697,8 +770,28 @@ LogicalResult StoreScatterOp::verify() { if (!isWriteHintOrNone(getL3HintAttr())) return emitOpError("invalid l3_hint: ") << getL3HintAttr(); - return isValidGatherScatterParams(maskTy, valueTy, tdescTy, - [&]() { return emitOpError(); }); + if (tdescTy) + return isValidGatherScatterParams(maskTy, valueTy, tdescTy, + [&]() { return emitOpError(); }); + + auto destTy = getDestType(); + uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1)); + auto memTy = dyn_cast<MemRefType>(destTy); + + if (memTy && (valueTy.getElementType() != memTy.getElementType())) + return emitError() << "Value should have the same element type as MemRef."; + + return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize, + [&]() { return emitOpError(); }); +} + +void StoreScatterOp::build(OpBuilder &builder, OperationState &state, + Value value, Value dest, Value mask, + xegpu::CachePolicyAttr l1_hint, + xegpu::CachePolicyAttr l2_hint, + xegpu::CachePolicyAttr l3_hint) { + build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint, + l2_hint, l3_hint); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp index ec8fad4..c793b71 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp @@ -481,7 +481,8 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> { VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType()); xegpu::TensorDescType tdescTy = op.getTensorDescType(); - if (!tdescTy.isScattered()) + // TODO: handle the unstructure source case (!tdesTy) + if (!tdescTy || op.getOffsets()) return failure(); std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op); @@ -543,7 +544,8 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> { Location loc = op.getLoc(); xegpu::TensorDescType tdescTy = op.getTensorDescType(); - if (!tdescTy.isScattered()) + // TODO: handle the unstructure source case (!tdesTy) + if (!tdescTy || op.getOffsets()) return failure(); std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op); @@ -572,7 
+574,8 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> { VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType()); xegpu::TensorDescType tdescTy = op.getTensorDescType(); - if (!tdescTy.isScattered()) + // TODO: handle the unstructure source case (!tdesTy) + if (!tdescTy || op.getOffsets()) return failure(); std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index f95ad29..de52fbd 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -40,7 +40,7 @@ #include "llvm/ADT/StringSet.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/CommandLine.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/Endian.h" #include "llvm/Support/ManagedStatic.h" #include "llvm/Support/Regex.h" @@ -2070,9 +2070,8 @@ static OpPrintingFlags verifyOpAndAdjustFlags(Operation *op, return failure(); }); if (failed(verify(op))) { - LLVM_DEBUG(llvm::dbgs() - << DEBUG_TYPE << ": '" << op->getName() - << "' failed to verify and will be printed in generic form\n"); + LDBG() << op->getName() + << "' failed to verify and will be printed in generic form"; printerFlags.printGenericOpForm(); } diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index e9b5e92..310680b 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -17,14 +17,32 @@ using namespace mlir; +static std::pair<int64_t, int64_t> +getLineAndColStart(const llvm::SourceMgr &sourceMgr) { + unsigned lastFileID = sourceMgr.getNumBuffers(); + if (lastFileID == 1) + return {0, 0}; + + auto bufferID = sourceMgr.getMainFileID(); + const llvm::MemoryBuffer *main = sourceMgr.getMemoryBuffer(bufferID); + const llvm::MemoryBuffer *last = sourceMgr.getMemoryBuffer(lastFileID); + // Exclude same start. 
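// Intent, as far as this code suggests (an assumption, not stated in the patch): when the buffer being parsed is a chunk carved out of a larger buffer (e.g. by -split-input-file), the returned pair is the chunk's starting line and column within the enclosing file, so the FileLineColLoc produced below points at the chunk's true position instead of 0:0.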
+ if (main->getBufferStart() < last->getBufferStart() && + main->getBufferEnd() >= last->getBufferEnd()) { + return sourceMgr.getLineAndColumn( + llvm::SMLoc::getFromPointer(last->getBufferStart()), bufferID); + } + return {0, 0}; +} + LogicalResult mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr, Block *block, const ParserConfig &config, LocationAttr *sourceFileLoc) { const auto *sourceBuf = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()); if (sourceFileLoc) { - *sourceFileLoc = FileLineColLoc::get(config.getContext(), - sourceBuf->getBufferIdentifier(), - /*line=*/0, /*column=*/0); + auto [line, column] = getLineAndColStart(sourceMgr); + *sourceFileLoc = FileLineColLoc::get( + config.getContext(), sourceBuf->getBufferIdentifier(), line, column); } if (isBytecode(*sourceBuf)) return readBytecodeFile(*sourceBuf, block, config); @@ -37,9 +55,9 @@ mlir::parseSourceFile(const std::shared_ptr<llvm::SourceMgr> &sourceMgr, const auto *sourceBuf = sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID()); if (sourceFileLoc) { - *sourceFileLoc = FileLineColLoc::get(config.getContext(), - sourceBuf->getBufferIdentifier(), - /*line=*/0, /*column=*/0); + auto [line, column] = getLineAndColStart(*sourceMgr); + *sourceFileLoc = FileLineColLoc::get( + config.getContext(), sourceBuf->getBufferIdentifier(), line, column); } if (isBytecode(*sourceBuf)) return readBytecodeFile(sourceMgr, block, config); diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt index af22a7f..9ea5c683 100644 --- a/mlir/lib/Target/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt @@ -60,6 +60,7 @@ add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration MLIRROCDLToLLVMIRTranslation MLIRSPIRVToLLVMIRTranslation MLIRVCIXToLLVMIRTranslation + MLIRXeVMToLLVMIRTranslation ) add_mlir_translation_library(MLIRTargetLLVMIRImport diff --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt index f030fa7..86c731a 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt @@ -10,3 +10,4 @@ add_subdirectory(OpenMP) add_subdirectory(ROCDL) add_subdirectory(SPIRV) add_subdirectory(VCIX) +add_subdirectory(XeVM) diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp index ff34a08..0f675a0 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -13,6 +13,7 @@ #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Operation.h" +#include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Support/LLVM.h" #include "mlir/Target/LLVMIR/ModuleTranslation.h" @@ -136,46 +137,6 @@ convertOperandBundles(OperandRangeRange bundleOperands, return convertOperandBundles(bundleOperands, *bundleTags, moduleTranslation); } -static LogicalResult -convertParameterAndResultAttrs(mlir::Location loc, ArrayAttr argAttrsArray, - ArrayAttr resAttrsArray, llvm::CallBase *call, - LLVM::ModuleTranslation &moduleTranslation) { - if (argAttrsArray) { - for (auto [argIdx, argAttrsAttr] : llvm::enumerate(argAttrsArray)) { - if (auto argAttrs = cast<DictionaryAttr>(argAttrsAttr); - !argAttrs.empty()) { - FailureOr<llvm::AttrBuilder> attrBuilder = - moduleTranslation.convertParameterAttrs(loc, argAttrs); - if (failed(attrBuilder)) - return failure(); - 
call->addParamAttrs(argIdx, *attrBuilder); - } - } - } - - if (resAttrsArray && resAttrsArray.size() > 0) { - if (resAttrsArray.size() != 1) - return mlir::emitError(loc, "llvm.func cannot have multiple results"); - if (auto resAttrs = cast<DictionaryAttr>(resAttrsArray[0]); - !resAttrs.empty()) { - FailureOr<llvm::AttrBuilder> attrBuilder = - moduleTranslation.convertParameterAttrs(loc, resAttrs); - if (failed(attrBuilder)) - return failure(); - call->addRetAttrs(*attrBuilder); - } - } - return success(); -} - -static LogicalResult -convertParameterAndResultAttrs(CallOpInterface callOp, llvm::CallBase *call, - LLVM::ModuleTranslation &moduleTranslation) { - return convertParameterAndResultAttrs( - callOp.getLoc(), callOp.getArgAttrsAttr(), callOp.getResAttrsAttr(), call, - moduleTranslation); -} - /// Builder for LLVM_CallIntrinsicOp static LogicalResult convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder, @@ -243,9 +204,7 @@ convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder, convertOperandBundles(op.getOpBundleOperands(), op.getOpBundleTags(), moduleTranslation)); - if (failed(convertParameterAndResultAttrs(op.getLoc(), op.getArgAttrsAttr(), - op.getResAttrsAttr(), inst, - moduleTranslation))) + if (failed(moduleTranslation.convertArgAndResultAttrs(op, inst))) return failure(); if (op.getNumResults() == 1) @@ -455,7 +414,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, if (callOp.getInlineHintAttr()) call->addFnAttr(llvm::Attribute::InlineHint); - if (failed(convertParameterAndResultAttrs(callOp, call, moduleTranslation))) + if (failed(moduleTranslation.convertArgAndResultAttrs(callOp, call))) return failure(); if (MemoryEffectsAttr memAttr = callOp.getMemoryEffectsAttr()) { @@ -569,8 +528,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, operandsRef.drop_front(), opBundles); } result->setCallingConv(convertCConvToLLVM(invOp.getCConv())); - if (failed( - convertParameterAndResultAttrs(invOp, result, moduleTranslation))) + if (failed(moduleTranslation.convertArgAndResultAttrs(invOp, result))) return failure(); moduleTranslation.mapBranch(invOp, result); // InvokeOp can only have 0 or 1 result diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp index 1c9e226..55e73e8 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp @@ -13,6 +13,7 @@ #include "mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Target/LLVMIR/ModuleImport.h" +#include "llvm/IR/ConstantRange.h" using namespace mlir; using namespace mlir::NVVM; diff --git a/mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt new file mode 100644 index 0000000..6308d7e --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt @@ -0,0 +1,21 @@ +set(LLVM_OPTIONAL_SOURCES + XeVMToLLVMIRTranslation.cpp +) + +add_mlir_translation_library(MLIRXeVMToLLVMIRTranslation + XeVMToLLVMIRTranslation.cpp + + DEPENDS + MLIRXeVMConversionsIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRDialectUtils + MLIRIR + MLIRLLVMDialect + MLIRXeVMDialect + MLIRSupport + MLIRTargetLLVMIRExport +) diff --git a/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp new file 
mode 100644 index 0000000..73b166d --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp @@ -0,0 +1,103 @@ +//===-- XeVMToLLVMIRTranslation.cpp - Translate XeVM to LLVM IR -*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a translation between the MLIR XeVM dialect and +// LLVM IR. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h" +#include "mlir/Dialect/LLVMIR/XeVMDialect.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Operation.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" + +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" + +#include "llvm/IR/ConstantRange.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; +using namespace mlir::LLVM; + +namespace { +/// Implementation of the dialect interface that converts operations belonging +/// to the XeVM dialect to LLVM IR. +class XeVMDialectLLVMIRTranslationInterface + : public LLVMTranslationDialectInterface { +public: + using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; + + /// Attaches module-level metadata for functions marked as kernels. + LogicalResult + amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions, + NamedAttribute attribute, + LLVM::ModuleTranslation &moduleTranslation) const final { + StringRef attrName = attribute.getName().getValue(); + if (attrName == mlir::xevm::XeVMDialect::getCacheControlsAttrName()) { + auto cacheControlsArray = dyn_cast<ArrayAttr>(attribute.getValue()); + if (cacheControlsArray.size() != 2) { + return op->emitOpError( + "Expected both L1 and L3 cache control attributes!"); + } + if (instructions.size() != 1) { + return op->emitOpError("Expecting a single instruction"); + } + return handleDecorationCacheControl(instructions.front(), + cacheControlsArray.getValue()); + } + auto func = dyn_cast<LLVM::LLVMFuncOp>(op); + if (!func) + return failure(); + + return success(); + } + +private: + static LogicalResult handleDecorationCacheControl(llvm::Instruction *inst, + ArrayRef<Attribute> attrs) { + SmallVector<llvm::Metadata *> decorations; + llvm::LLVMContext &ctx = inst->getContext(); + llvm::Type *i32Ty = llvm::IntegerType::getInt32Ty(ctx); + llvm::transform( + attrs, std::back_inserter(decorations), + [&ctx, i32Ty](Attribute attr) -> llvm::Metadata * { + auto valuesArray = dyn_cast<ArrayAttr>(attr).getValue(); + std::array<llvm::Metadata *, 4> metadata; + llvm::transform( + valuesArray, metadata.begin(), [i32Ty](Attribute valueAttr) { + return llvm::ConstantAsMetadata::get(llvm::ConstantInt::get( + i32Ty, cast<IntegerAttr>(valueAttr).getValue())); + }); + return llvm::MDNode::get(ctx, metadata); + }); + constexpr llvm::StringLiteral decorationCacheControlMDName = + "spirv.DecorationCacheControlINTEL"; + inst->setMetadata(decorationCacheControlMDName, + llvm::MDNode::get(ctx, decorations)); + return success(); + } +}; +} // namespace + +void mlir::registerXeVMDialectTranslation(::mlir::DialectRegistry ®istry) { + registry.insert<xevm::XeVMDialect>(); + registry.addExtension(+[](MLIRContext *ctx, 
xevm::XeVMDialect *dialect) { + dialect->addInterfaces<XeVMDialectLLVMIRTranslationInterface>(); + }); +} + +void mlir::registerXeVMDialectTranslation(::mlir::MLIRContext &context) { + DialectRegistry registry; + registerXeVMDialectTranslation(registry); + context.appendDialectRegistry(registry); +} diff --git a/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp b/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp index 580afdd..cb1f234 100644 --- a/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp +++ b/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp @@ -33,7 +33,9 @@ LogicalResult mlir::LLVMImportInterface::convertUnregisteredIntrinsic( SmallVector<Value> mlirOperands; SmallVector<NamedAttribute> mlirAttrs; if (failed(moduleImport.convertIntrinsicArguments( - llvmOperands, llvmOpBundles, false, {}, {}, mlirOperands, mlirAttrs))) + llvmOperands, llvmOpBundles, /*requiresOpBundles=*/false, + /*immArgPositions=*/{}, /*immArgAttrNames=*/{}, mlirOperands, + mlirAttrs))) return failure(); Type resultType = moduleImport.convertType(inst->getType()); @@ -44,11 +46,7 @@ LogicalResult mlir::LLVMImportInterface::convertUnregisteredIntrinsic( ValueRange{mlirOperands}, FastmathFlagsAttr{}); moduleImport.setFastmathFlagsAttr(inst, op); - - ArrayAttr argsAttr, resAttr; - moduleImport.convertParameterAttributes(inst, argsAttr, resAttr, builder); - op.setArgAttrsAttr(argsAttr); - op.setResAttrsAttr(resAttr); + moduleImport.convertArgAndResultAttrs(inst, op); // Update importer tracking of results. unsigned numRes = op.getNumResults(); diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 58e3c44..6325480 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -30,6 +30,7 @@ #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/IR/Comdat.h" #include "llvm/IR/Constants.h" @@ -1063,6 +1064,18 @@ void ModuleImport::convertTargetTriple() { builder.getStringAttr(llvmModule->getTargetTriple().str())); } +void ModuleImport::convertModuleLevelAsm() { + llvm::StringRef asmStr = llvmModule->getModuleInlineAsm(); + llvm::SmallVector<mlir::Attribute> asmArrayAttr; + + for (llvm::StringRef line : llvm::split(asmStr, '\n')) + if (!line.empty()) + asmArrayAttr.push_back(builder.getStringAttr(line)); + + mlirModule->setAttr(LLVM::LLVMDialect::getModuleLevelAsmAttrName(), + builder.getArrayAttr(asmArrayAttr)); +} + LogicalResult ModuleImport::convertFunctions() { for (llvm::Function &func : llvmModule->functions()) if (failed(processFunction(&func))) @@ -2267,7 +2280,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { // Handle parameter and result attributes unless it's an incompatible // call. if (!isIncompatibleCall) - convertParameterAttributes(callInst, callOp, builder); + convertArgAndResultAttrs(callInst, callOp); return callOp.getOperation(); }(); @@ -2364,7 +2377,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { // Handle parameter and result attributes unless it's an incompatible // invoke. 
if (!isIncompatibleInvoke) - convertParameterAttributes(invokeInst, invokeOp, builder); + convertArgAndResultAttrs(invokeInst, invokeOp); if (!invokeInst->getType()->isVoidTy()) mapValue(inst, invokeOp.getResults().front()); @@ -2730,11 +2743,10 @@ void ModuleImport::processFunctionAttributes(llvm::Function *func, } DictionaryAttr -ModuleImport::convertParameterAttribute(llvm::AttributeSet llvmParamAttrs, - OpBuilder &builder) { +ModuleImport::convertArgOrResultAttrSet(llvm::AttributeSet llvmAttrSet) { SmallVector<NamedAttribute> paramAttrs; for (auto [llvmKind, mlirName] : getAttrKindToNameMapping()) { - auto llvmAttr = llvmParamAttrs.getAttribute(llvmKind); + auto llvmAttr = llvmAttrSet.getAttribute(llvmKind); // Skip attributes that are not attached. if (!llvmAttr.isValid()) continue; @@ -2769,13 +2781,12 @@ ModuleImport::convertParameterAttribute(llvm::AttributeSet llvmParamAttrs, return builder.getDictionaryAttr(paramAttrs); } -void ModuleImport::convertParameterAttributes(llvm::Function *func, - LLVMFuncOp funcOp, - OpBuilder &builder) { +void ModuleImport::convertArgAndResultAttrs(llvm::Function *func, + LLVMFuncOp funcOp) { auto llvmAttrs = func->getAttributes(); for (size_t i = 0, e = funcOp.getNumArguments(); i < e; ++i) { llvm::AttributeSet llvmArgAttrs = llvmAttrs.getParamAttrs(i); - funcOp.setArgAttrs(i, convertParameterAttribute(llvmArgAttrs, builder)); + funcOp.setArgAttrs(i, convertArgOrResultAttrSet(llvmArgAttrs)); } // Convert the result attributes and attach them wrapped in an ArrayAttribute // to the funcOp. @@ -2783,17 +2794,23 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func, if (!llvmResAttr.hasAttributes()) return; funcOp.setResAttrsAttr( - builder.getArrayAttr(convertParameterAttribute(llvmResAttr, builder))); + builder.getArrayAttr({convertArgOrResultAttrSet(llvmResAttr)})); } -void ModuleImport::convertParameterAttributes(llvm::CallBase *call, - ArrayAttr &argsAttr, - ArrayAttr &resAttr, - OpBuilder &builder) { +void ModuleImport::convertArgAndResultAttrs( + llvm::CallBase *call, ArgAndResultAttrsOpInterface attrsOp, + ArrayRef<unsigned> immArgPositions) { + // Compute the set of immediate argument positions. + llvm::SmallDenseSet<unsigned> immArgPositionsSet(immArgPositions.begin(), + immArgPositions.end()); + // Convert the argument attributes and filter out immediate arguments. llvm::AttributeList llvmAttrs = call->getAttributes(); SmallVector<llvm::AttributeSet> llvmArgAttrsSet; bool anyArgAttrs = false; for (size_t i = 0, e = call->arg_size(); i < e; ++i) { + // Skip immediate arguments. + if (immArgPositionsSet.contains(i)) + continue; llvmArgAttrsSet.emplace_back(llvmAttrs.getParamAttrs(i)); if (llvmArgAttrsSet.back().hasAttributes()) anyArgAttrs = true; @@ -2807,24 +2824,16 @@ void ModuleImport::convertParameterAttributes(llvm::CallBase *call, if (anyArgAttrs) { SmallVector<DictionaryAttr> argAttrs; for (auto &llvmArgAttrs : llvmArgAttrsSet) - argAttrs.emplace_back(convertParameterAttribute(llvmArgAttrs, builder)); - argsAttr = getArrayAttr(argAttrs); + argAttrs.emplace_back(convertArgOrResultAttrSet(llvmArgAttrs)); + attrsOp.setArgAttrsAttr(getArrayAttr(argAttrs)); } + // Convert the result attributes. 
llvm::AttributeSet llvmResAttr = llvmAttrs.getRetAttrs(); if (!llvmResAttr.hasAttributes()) return; - DictionaryAttr resAttrs = convertParameterAttribute(llvmResAttr, builder); - resAttr = getArrayAttr({resAttrs}); -} - -void ModuleImport::convertParameterAttributes(llvm::CallBase *call, - CallOpInterface callOp, - OpBuilder &builder) { - ArrayAttr argsAttr, resAttr; - convertParameterAttributes(call, argsAttr, resAttr, builder); - callOp.setArgAttrsAttr(argsAttr); - callOp.setResAttrsAttr(resAttr); + DictionaryAttr resAttrs = convertArgOrResultAttrSet(llvmResAttr); + attrsOp.setResAttrsAttr(getArrayAttr({resAttrs})); } template <typename Op> @@ -2892,7 +2901,7 @@ LogicalResult ModuleImport::processFunction(llvm::Function *func) { builder, loc, func->getName(), functionType, convertLinkageFromLLVM(func->getLinkage()), dsoLocal, cconv); - convertParameterAttributes(func, funcOp, builder); + convertArgAndResultAttrs(func, funcOp); if (FlatSymbolRefAttr personality = getPersonalityAsAttr(func)) funcOp.setPersonalityAttr(personality); @@ -3199,5 +3208,6 @@ OwningOpRef<ModuleOp> mlir::translateLLVMIRToModule( if (failed(moduleImport.convertIFuncs())) return {}; moduleImport.convertTargetTriple(); + moduleImport.convertModuleLevelAsm(); return module; } diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index b997e55..b3a06e2 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -1758,6 +1758,48 @@ ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx, return attrBuilder; } +LogicalResult ModuleTranslation::convertArgAndResultAttrs( + ArgAndResultAttrsOpInterface attrsOp, llvm::CallBase *call, + ArrayRef<unsigned> immArgPositions) { + // Convert the argument attributes. + if (ArrayAttr argAttrsArray = attrsOp.getArgAttrsAttr()) { + unsigned argAttrIdx = 0; + llvm::SmallDenseSet<unsigned> immArgPositionsSet(immArgPositions.begin(), + immArgPositions.end()); + for (unsigned argIdx : llvm::seq<unsigned>(call->arg_size())) { + if (argAttrIdx >= argAttrsArray.size()) + break; + // Skip immediate arguments (they have no entries in argAttrsArray). + if (immArgPositionsSet.contains(argIdx)) + continue; + // Skip empty argument attributes. + auto argAttrs = cast<DictionaryAttr>(argAttrsArray[argAttrIdx++]); + if (argAttrs.empty()) + continue; + // Convert and add attributes to the call instruction. + FailureOr<llvm::AttrBuilder> attrBuilder = + convertParameterAttrs(attrsOp->getLoc(), argAttrs); + if (failed(attrBuilder)) + return failure(); + call->addParamAttrs(argIdx, *attrBuilder); + } + } + + // Convert the result attributes. 
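// An invented example of the mapping: for a call with immArgPositions = {1}, an MLIR arg_attrs array of two dictionaries is applied to LLVM arguments 0 and 2; argument 1 (the immediate) gets no parameter attributes. The result attribute array, when present, holds a single dictionary that is attached as the call's return attributes below.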
+ if (ArrayAttr resAttrsArray = attrsOp.getResAttrsAttr()) { + if (!resAttrsArray.empty()) { + auto resAttrs = cast<DictionaryAttr>(resAttrsArray[0]); + FailureOr<llvm::AttrBuilder> attrBuilder = + convertParameterAttrs(attrsOp->getLoc(), resAttrs); + if (failed(attrBuilder)) + return failure(); + call->addRetAttrs(*attrBuilder); + } + } + + return success(); +} + FailureOr<llvm::AttrBuilder> ModuleTranslation::convertParameterAttrs(Location loc, DictionaryAttr paramAttrs) { @@ -2276,6 +2318,25 @@ prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext, llvmModule->setTargetTriple( llvm::Triple(cast<StringAttr>(targetTripleAttr).getValue())); + if (auto asmAttr = m->getDiscardableAttr( + LLVM::LLVMDialect::getModuleLevelAsmAttrName())) { + auto asmArrayAttr = dyn_cast<ArrayAttr>(asmAttr); + if (!asmArrayAttr) { + m->emitError("expected an array attribute for a module level asm"); + return nullptr; + } + + for (Attribute elt : asmArrayAttr) { + auto asmStrAttr = dyn_cast<StringAttr>(elt); + if (!asmStrAttr) { + m->emitError( + "expected a string attribute for each entry of a module level asm"); + return nullptr; + } + llvmModule->appendModuleInlineAsm(asmStrAttr.getValue()); + } + } + return llvmModule; } diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index e5934bb..88931b5 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -347,10 +347,6 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) { return emitError(unknownLoc, "OpDecoration with ") << decorationName << "needs a single target <id>"; } - // Block decoration does not affect spirv.struct type, but is still stored - // for verification. - // TODO: Update StructType to contain this information since - // it is needed for many validation rules. 
decorations[words[0]].set(symbol, opBuilder.getUnitAttr()); break; case spirv::Decoration::Location: @@ -993,7 +989,8 @@ spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) { if (failed(structType.trySetBody( deferredStructIt->memberTypes, deferredStructIt->offsetInfo, - deferredStructIt->memberDecorationsInfo))) + deferredStructIt->memberDecorationsInfo, + deferredStructIt->structDecorationsInfo))) return failure(); deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt); @@ -1203,24 +1200,37 @@ spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) { } } + SmallVector<spirv::StructType::StructDecorationInfo, 0> structDecorationsInfo; + if (decorations.count(operands[0])) { + NamedAttrList &allDecorations = decorations[operands[0]]; + for (NamedAttribute &decorationAttr : allDecorations) { + std::optional<spirv::Decoration> decoration = spirv::symbolizeDecoration( + llvm::convertToCamelFromSnakeCase(decorationAttr.getName(), true)); + assert(decoration.has_value()); + structDecorationsInfo.emplace_back(decoration.value(), + decorationAttr.getValue()); + } + } + uint32_t structID = operands[0]; std::string structIdentifier = nameMap.lookup(structID).str(); if (structIdentifier.empty()) { assert(unresolvedMemberTypes.empty() && "didn't expect unresolved member types"); - typeMap[structID] = - spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo); + typeMap[structID] = spirv::StructType::get( + memberTypes, offsetInfo, memberDecorationsInfo, structDecorationsInfo); } else { auto structTy = spirv::StructType::getIdentified(context, structIdentifier); typeMap[structID] = structTy; if (!unresolvedMemberTypes.empty()) - deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes, - memberTypes, offsetInfo, - memberDecorationsInfo}); + deferredStructTypesInfos.push_back( + {structTy, unresolvedMemberTypes, memberTypes, offsetInfo, + memberDecorationsInfo, structDecorationsInfo}); else if (failed(structTy.trySetBody(memberTypes, offsetInfo, - memberDecorationsInfo))) + memberDecorationsInfo, + structDecorationsInfo))) return failure(); } diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h index 20482bd..db1cc3f 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h @@ -95,6 +95,7 @@ struct DeferredStructTypeInfo { SmallVector<Type, 4> memberTypes; SmallVector<spirv::StructType::OffsetInfo, 0> offsetInfo; SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo; + SmallVector<spirv::StructType::StructDecorationInfo, 0> structDecorationsInfo; }; /// A struct that collects the info needed to materialize/emit a diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index a8a2b2e..737f296 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -318,6 +318,7 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID, case spirv::Decoration::RestrictPointer: case spirv::Decoration::NoContraction: case spirv::Decoration::Constant: + case spirv::Decoration::Block: // For unit attributes and decoration attributes, the args list // has no values so we do nothing. 
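// Illustratively (an assumption about the emitted form, not a quote from the patch): a struct type carrying a Block StructDecorationInfo now has that decoration serialized from the type itself, roughly as
//   OpDecorate %struct_id Block
// rather than only through the interface-pointer special case.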
if (isa<UnitAttr, DecorationAttr>(attr)) @@ -630,11 +631,16 @@ LogicalResult Serializer::prepareBasicType( operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass())); operands.push_back(pointeeTypeID); + // TODO: Now that struct decorations are supported, this code may not be + // necessary. However, it is left to support backwards compatibility. + // Ideally, Block decorations should be inserted when converting to SPIR-V. if (isInterfaceStructPtrType(ptrType)) { - if (failed(emitDecoration(getTypeID(pointeeStruct), - spirv::Decoration::Block))) - return emitError(loc, "cannot decorate ") - << pointeeStruct << " with Block decoration"; + auto structType = cast<spirv::StructType>(ptrType.getPointeeType()); + if (!structType.hasDecoration(spirv::Decoration::Block)) + if (failed(emitDecoration(getTypeID(pointeeStruct), + spirv::Decoration::Block))) + return emitError(loc, "cannot decorate ") + << pointeeStruct << " with Block decoration"; } return success(); @@ -704,6 +710,20 @@ LogicalResult Serializer::prepareBasicType( } } + SmallVector<spirv::StructType::StructDecorationInfo, 1> structDecorations; + structType.getStructDecorations(structDecorations); + + for (spirv::StructType::StructDecorationInfo &structDecoration : + structDecorations) { + if (failed(processDecorationAttr(loc, resultID, + structDecoration.decoration, + structDecoration.decorationValue))) { + return emitError(loc, "cannot decorate struct ") + << structType << " with " + << stringifyDecoration(structDecoration.decoration); + } + } + typeEnum = spirv::Opcode::OpTypeStruct; if (structType.isIdentified()) @@ -938,6 +958,25 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType, } else { return 0; } + } else if (isa<spirv::TensorArmType>(constType)) { + numberOfConstituents = shapedType.getNumElements(); + operands.reserve(numberOfConstituents + 2); + for (int i = 0; i < numberOfConstituents; ++i) { + uint32_t elementID = 0; + if (auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) { + elementID = + elementType.isInteger(1) + ? 
prepareConstantBool(loc, attr.getValues<BoolAttr>()[i]) + : prepareConstantInt(loc, attr.getValues<IntegerAttr>()[i]); + } + if (auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) { + elementID = prepareConstantFp(loc, attr.getValues<FloatAttr>()[i]); + } + if (!elementID) { + return 0; + } + operands.push_back(elementID); + } } else { operands.reserve(numberOfConstituents + 2); for (int i = 0; i < numberOfConstituents; ++i) { diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir index 1abe0fd..6e2352e 100644 --- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir +++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir @@ -559,6 +559,23 @@ func.func @constant() { return } +// CHECK-LABEL: @constant_8bit_float +func.func @constant_8bit_float() { + // CHECK: spirv.Constant 56 : i8 + %cst = arith.constant 1.0 : f8E4M3 + // CHECK: spirv.Constant 56 : i8 + %cst_i8 = arith.bitcast %cst : f8E4M3 to i8 + // CHECK: spirv.Constant dense<56> : vector<4xi8> + %cst_vector = arith.constant dense<1.0> : vector<4xf8E4M3> + // CHECK: spirv.Constant dense<56> : vector<4xi8> + %cst_vector_i8 = arith.bitcast %cst_vector : vector<4xf8E4M3> to vector<4xi8> + // CHECK: spirv.Constant dense<60> : tensor<4xi8> : !spirv.array<4 x i8> + %cst_tensor = arith.constant dense<1.0> : tensor<4xf8E5M2> + // CHECK: spirv.Constant dense<60> : tensor<4xi8> : !spirv.array<4 x i8> + %cst_tensor_i8 = arith.bitcast %cst_tensor : tensor<4xf8E5M2> to tensor<4xi8> + return +} + // CHECK-LABEL: @constant_16bit func.func @constant_16bit() { // CHECK: spirv.Constant 4 : i16 diff --git a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir index bae7c59..ae59f28 100644 --- a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir +++ b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir @@ -2,8 +2,26 @@ // CHECK-DAG: @__ocml_cabs_f32(complex<f32>) -> f32 // CHECK-DAG: @__ocml_cabs_f64(complex<f64>) -> f64 +// CHECK-DAG: @__ocml_carg_f32(complex<f32>) -> f32 +// CHECK-DAG: @__ocml_carg_f64(complex<f64>) -> f64 +// CHECK-DAG: @__ocml_ccos_f32(complex<f32>) -> complex<f32> +// CHECK-DAG: @__ocml_ccos_f64(complex<f64>) -> complex<f64> // CHECK-DAG: @__ocml_cexp_f32(complex<f32>) -> complex<f32> // CHECK-DAG: @__ocml_cexp_f64(complex<f64>) -> complex<f64> +// CHECK-DAG: @__ocml_clog_f32(complex<f32>) -> complex<f32> +// CHECK-DAG: @__ocml_clog_f64(complex<f64>) -> complex<f64> +// CHECK-DAG: @__ocml_conj_f32(complex<f32>) -> complex<f32> +// CHECK-DAG: @__ocml_conj_f64(complex<f64>) -> complex<f64> +// CHECK-DAG: @__ocml_cpow_f32(complex<f32>, complex<f32>) -> complex<f32> +// CHECK-DAG: @__ocml_cpow_f64(complex<f64>, complex<f64>) -> complex<f64> +// CHECK-DAG: @__ocml_csin_f32(complex<f32>) -> complex<f32> +// CHECK-DAG: @__ocml_csin_f64(complex<f64>) -> complex<f64> +// CHECK-DAG: @__ocml_csqrt_f32(complex<f32>) -> complex<f32> +// CHECK-DAG: @__ocml_csqrt_f64(complex<f64>) -> complex<f64> +// CHECK-DAG: @__ocml_ctan_f32(complex<f32>) -> complex<f32> +// CHECK-DAG: @__ocml_ctan_f64(complex<f64>) -> complex<f64> +// CHECK-DAG: @__ocml_ctanh_f32(complex<f32>) -> complex<f32> +// CHECK-DAG: @__ocml_ctanh_f64(complex<f64>) -> complex<f64> //CHECK-LABEL: @abs_caller func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) { @@ -15,6 +33,26 @@ func.func @abs_caller(%f: 
complex<f32>, %d: complex<f64>) -> (f32, f64) { return %rf, %rd : f32, f64 } +//CHECK-LABEL: @angle_caller +func.func @angle_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) { + // CHECK: %[[AF:.*]] = call @__ocml_carg_f32(%{{.*}}) + %af = complex.angle %f : complex<f32> + // CHECK: %[[AD:.*]] = call @__ocml_carg_f64(%{{.*}}) + %ad = complex.angle %d : complex<f64> + // CHECK: return %[[AF]], %[[AD]] + return %af, %ad : f32, f64 +} + +//CHECK-LABEL: @cos_caller +func.func @cos_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { + // CHECK: %[[CF:.*]] = call @__ocml_ccos_f32(%{{.*}}) + %cf = complex.cos %f : complex<f32> + // CHECK: %[[CD:.*]] = call @__ocml_ccos_f64(%{{.*}}) + %cd = complex.cos %d : complex<f64> + // CHECK: return %[[CF]], %[[CD]] + return %cf, %cd : complex<f32>, complex<f64> +} + //CHECK-LABEL: @exp_caller func.func @exp_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { // CHECK: %[[EF:.*]] = call @__ocml_cexp_f32(%{{.*}}) @@ -24,3 +62,73 @@ func.func @exp_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, comp // CHECK: return %[[EF]], %[[ED]] return %ef, %ed : complex<f32>, complex<f64> } + +//CHECK-LABEL: @log_caller +func.func @log_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { + // CHECK: %[[LF:.*]] = call @__ocml_clog_f32(%{{.*}}) + %lf = complex.log %f : complex<f32> + // CHECK: %[[LD:.*]] = call @__ocml_clog_f64(%{{.*}}) + %ld = complex.log %d : complex<f64> + // CHECK: return %[[LF]], %[[LD]] + return %lf, %ld : complex<f32>, complex<f64> +} + +//CHECK-LABEL: @conj_caller +func.func @conj_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { + // CHECK: %[[CF:.*]] = call @__ocml_conj_f32(%{{.*}}) + %cf2 = complex.conj %f : complex<f32> + // CHECK: %[[CD:.*]] = call @__ocml_conj_f64(%{{.*}}) + %cd2 = complex.conj %d : complex<f64> + // CHECK: return %[[CF]], %[[CD]] + return %cf2, %cd2 : complex<f32>, complex<f64> +} + +//CHECK-LABEL: @pow_caller +func.func @pow_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { + // CHECK: %[[PF:.*]] = call @__ocml_cpow_f32(%{{.*}}, %{{.*}}) + %pf = complex.pow %f, %f : complex<f32> + // CHECK: %[[PD:.*]] = call @__ocml_cpow_f64(%{{.*}}, %{{.*}}) + %pd = complex.pow %d, %d : complex<f64> + // CHECK: return %[[PF]], %[[PD]] + return %pf, %pd : complex<f32>, complex<f64> +} + +//CHECK-LABEL: @sin_caller +func.func @sin_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { + // CHECK: %[[SF:.*]] = call @__ocml_csin_f32(%{{.*}}) + %sf2 = complex.sin %f : complex<f32> + // CHECK: %[[SD:.*]] = call @__ocml_csin_f64(%{{.*}}) + %sd2 = complex.sin %d : complex<f64> + // CHECK: return %[[SF]], %[[SD]] + return %sf2, %sd2 : complex<f32>, complex<f64> +} + +//CHECK-LABEL: @sqrt_caller +func.func @sqrt_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { + // CHECK: %[[SF:.*]] = call @__ocml_csqrt_f32(%{{.*}}) + %sf = complex.sqrt %f : complex<f32> + // CHECK: %[[SD:.*]] = call @__ocml_csqrt_f64(%{{.*}}) + %sd = complex.sqrt %d : complex<f64> + // CHECK: return %[[SF]], %[[SD]] + return %sf, %sd : complex<f32>, complex<f64> +} + +//CHECK-LABEL: @tan_caller +func.func @tan_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { + // CHECK: %[[TF:.*]] = call @__ocml_ctan_f32(%{{.*}}) + %tf2 = complex.tan %f : complex<f32> + // CHECK: %[[TD:.*]] = call @__ocml_ctan_f64(%{{.*}}) + %td2 = complex.tan %d : complex<f64> + // CHECK: 
return %[[TF]], %[[TD]] + return %tf2, %td2 : complex<f32>, complex<f64> +} + +//CHECK-LABEL: @tanh_caller +func.func @tanh_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { + // CHECK: %[[TF:.*]] = call @__ocml_ctanh_f32(%{{.*}}) + %tf = complex.tanh %f : complex<f32> + // CHECK: %[[TD:.*]] = call @__ocml_ctanh_f64(%{{.*}}) + %td = complex.tanh %d : complex<f64> + // CHECK: return %[[TF]], %[[TD]] + return %tf, %td : complex<f32>, complex<f64> +} diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir index 1737f4a..0c77c88 100644 --- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir +++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir @@ -1,6 +1,8 @@ // RUN: mlir-opt -split-input-file -convert-func-to-spirv %s -o - | FileCheck %s // RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-lt-32-bit-scalar-types=false" %s | \ // RUN: FileCheck %s --check-prefix=NOEMU +// RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-unsupported-float-types=false" %s | \ +// RUN: FileCheck %s --check-prefix=UNSUPPORTED_FLOAT //===----------------------------------------------------------------------===// // Integer types @@ -944,3 +946,55 @@ func.func @unranked_tensor(%arg0: tensor<*xi32>) { return } func.func @dynamic_dim_tensor(%arg0: tensor<8x?xi32>) { return } } // end module + + +// ----- + +// Check that 8-bit float types are emulated as i8. +module attributes { + spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int8], []>, #spirv.resource_limits<>> +} { + + // CHECK: spirv.func @float8_to_integer8 + // CHECK-SAME: (%arg0: i8 + // CHECK-SAME: %arg1: i8 + // CHECK-SAME: %arg2: i8 + // CHECK-SAME: %arg3: i8 + // CHECK-SAME: %arg4: i8 + // CHECK-SAME: %arg5: i8 + // CHECK-SAME: %arg6: i8 + // CHECK-SAME: %arg7: i8 + // CHECK-SAME: %arg8: vector<4xi8> + // CHECK-SAME: %arg9: !spirv.ptr<!spirv.struct<(!spirv.array<8 x i8, stride=1> [0])>, StorageBuffer> + // CHECK-SAME: %arg10: !spirv.array<4 x i8> + // UNSUPPORTED_FLOAT-LABEL: func.func @float8_to_integer8 + // UNSUPPORTED_FLOAT-SAME: (%arg0: f8E5M2 + // UNSUPPORTED_FLOAT-SAME: %arg1: f8E4M3 + // UNSUPPORTED_FLOAT-SAME: %arg2: f8E4M3FN + // UNSUPPORTED_FLOAT-SAME: %arg3: f8E5M2FNUZ + // UNSUPPORTED_FLOAT-SAME: %arg4: f8E4M3FNUZ + // UNSUPPORTED_FLOAT-SAME: %arg5: f8E4M3B11FNUZ + // UNSUPPORTED_FLOAT-SAME: %arg6: f8E3M4 + // UNSUPPORTED_FLOAT-SAME: %arg7: f8E8M0FNU + // UNSUPPORTED_FLOAT-SAME: %arg8: vector<4xf8E4M3B11FNUZ> + // UNSUPPORTED_FLOAT-SAME: %arg9: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>> + // UNSUPPORTED_FLOAT-SAME: %arg10: tensor<4xf8E5M2> + // UNSUPPORTED_FLOAT-SAME: ) { + + func.func @float8_to_integer8( + %arg0: f8E5M2, // CHECK-NOT: f8E5M2 + %arg1: f8E4M3, // CHECK-NOT: f8E4M3 + %arg2: f8E4M3FN, // CHECK-NOT: f8E4M3FN + %arg3: f8E5M2FNUZ, // CHECK-NOT: f8E5M2FNUZ + %arg4: f8E4M3FNUZ, // CHECK-NOT: f8E4M3FNUZ + %arg5: f8E4M3B11FNUZ, // CHECK-NOT: f8E4M3B11FNUZ + %arg6: f8E3M4, // CHECK-NOT: f8E3M4 + %arg7: f8E8M0FNU, // CHECK-NOT: f8E8M0FNU + %arg8: vector<4xf8E4M3B11FNUZ>, // CHECK-NOT: vector<4xf8E4M3B11FNUZ> + %arg9: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>>, // CHECK-NOT: memref + %arg10: tensor<4xf8E5M2> // CHECK-NOT: tensor + ) { + // CHECK: spirv.Return + return + } +} diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-fpclassify-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-fpclassify-spirv.mlir new file mode 100644 index 0000000..3e5f592 --- /dev/null +++ 
b/mlir/test/Conversion/MathToSPIRV/math-to-fpclassify-spirv.mlir @@ -0,0 +1,27 @@ +// RUN: mlir-opt --convert-math-to-spirv %s | FileCheck %s + +module attributes { + spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], []>, #spirv.resource_limits<>> +} { + + // CHECK-LABEL: @fpclassify + func.func @fpclassify(%x: f32, %v: vector<4xf32>) { + // CHECK: spirv.IsFinite %{{.*}} : f32 + %0 = math.isfinite %x : f32 + // CHECK: spirv.IsFinite %{{.*}} : vector<4xf32> + %1 = math.isfinite %v : vector<4xf32> + + // CHECK: spirv.IsNan %{{.*}} : f32 + %2 = math.isnan %x : f32 + // CHECK: spirv.IsNan %{{.*}} : vector<4xf32> + %3 = math.isnan %v : vector<4xf32> + + // CHECK: spirv.IsInf %{{.*}} : f32 + %4 = math.isinf %x : f32 + // CHECK: spirv.IsInf %{{.*}} : vector<4xf32> + %5 = math.isinf %v : vector<4xf32> + + return + } + +} diff --git a/mlir/test/Dialect/Async/canonicalize.mlir b/mlir/test/Dialect/Async/canonicalize.mlir new file mode 100644 index 0000000..1a74eaa --- /dev/null +++ b/mlir/test/Dialect/Async/canonicalize.mlir @@ -0,0 +1,10 @@ +// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s + +// CHECK-NOT: async.execute + +func.func @empty_execute() { + %token = async.execute { + async.yield + } + return +} diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir index a00c798..5f42938 100644 --- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir +++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir @@ -1076,6 +1076,44 @@ func.func @drop_known_unit_constant_low_high(%arg0: tensor<1x383x128xf32>) -> te // ----- +func.func @drop_unit_dim_mixed_static_dynamic(%arg0: tensor<1x?xf32>) -> tensor<1x16xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %cst = arith.constant 0.000000e+00 : f32 + %padded = tensor.pad %arg0 low[%c0, %c1] high[%c0, %c0] { + ^bb0(%arg1: index, %arg2: index): + tensor.yield %cst : f32 + } : tensor<1x?xf32> to tensor<1x16xf32> + return %padded : tensor<1x16xf32> +} +// CHECK-LABEL: func @drop_unit_dim_mixed_static_dynamic +// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARGS:.*]] : tensor<1x?xf32> into tensor<?xf32> +// CHECK: %[[PADDED:.*]] = tensor.pad %[[COLLAPSE]] low[1] high[0] { +// CHECK: ^bb0(%[[IDX:.*]]: index): +// CHECK: tensor.yield %[[CST]] : f32 +// CHECK: } : tensor<?xf32> to tensor<16xf32> +// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[PADDED]] {{\[\[}}0, 1]] output_shape [1, 16] : tensor<16xf32> into tensor<1x16xf32> +// CHECK: return %[[EXPANDED]] : tensor<1x16xf32> + +// ----- + +#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> +module { + func.func @drop_unit_dim_corresponding_to_dynamic_dim(%arg0: tensor<1x?x?x1xf32>, %arg1: index) -> tensor<?x1x61x1xf32> { + %cst = arith.constant dense<1.000000e+00> : tensor<1x1x1x1xf32> + %0 = tensor.empty(%arg1) : tensor<?x1x61x1xf32> + %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %cst : tensor<1x?x?x1xf32>, tensor<1x1x1x1xf32>) outs(%0 : tensor<?x1x61x1xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %2 = arith.mulf %in, %in_0 : f32 + %3 = arith.addf %out, %2 : f32 + linalg.yield %3 : f32 + } -> tensor<?x1x61x1xf32> + return %1 : 
tensor<?x1x61x1xf32> + } +} // CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (0, d0)> // CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> ()> @@ -1097,23 +1135,6 @@ func.func @drop_known_unit_constant_low_high(%arg0: tensor<1x383x128xf32>) -> te // CHECK: return %[[VAL_14]] : tensor<?x1x61x1xf32> // CHECK: } -#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> -#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)> -#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> -module { - func.func @drop_unit_dim_corresponding_to_dynamic_dim(%arg0: tensor<1x?x?x1xf32>, %arg1: index) -> tensor<?x1x61x1xf32> { - %cst = arith.constant dense<1.000000e+00> : tensor<1x1x1x1xf32> - %0 = tensor.empty(%arg1) : tensor<?x1x61x1xf32> - %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %cst : tensor<1x?x?x1xf32>, tensor<1x1x1x1xf32>) outs(%0 : tensor<?x1x61x1xf32>) { - ^bb0(%in: f32, %in_0: f32, %out: f32): - %2 = arith.mulf %in, %in_0 : f32 - %3 = arith.addf %out, %2 : f32 - linalg.yield %3 : f32 - } -> tensor<?x1x61x1xf32> - return %1 : tensor<?x1x61x1xf32> - } -} - // ----- func.func @no_fold_empty_tensor_dim_out_of_bounds(%arg0: tensor<1x?x10xf32>) -> tensor<1x?xf32> { diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir index 78619b6..981f5dc 100644 --- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir @@ -52,22 +52,22 @@ module { // CHECK-LABEL: @generic // CHECK-SAME: %[[T0:.*]]: tensor<7x5xf32>, -// CHECK-SAME: %[[T1:.*]]: tensor<7x11x12xf32>) - func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x12xf32>) -> tensor<7x11x12xf32> { +// CHECK-SAME: %[[T1:.*]]: tensor<7x11x11xf32>) + func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x11xf32>) -> tensor<7x11x11xf32> { // CHECK-DAG: %[[CST:.*]] = arith.constant 0. 
// CHECK: %[[PAD0:.*]] = tensor.pad %[[T0]] low[0, 0] high[2, 0] // CHECK: : tensor<7x5xf32> to tensor<9x5xf32> // CHECK: %[[PAD1:.*]] = tensor.pad %[[T1]] low[0, 0, 0] high[2, 4, 2] { - // CHECK: : tensor<7x11x12xf32> to tensor<9x15x14xf32> + // CHECK: : tensor<7x11x11xf32> to tensor<9x15x13xf32> // CHECK-NEXT: linalg.generic - // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 12] [1, 1, 1] : tensor<9x15x14xf32> to tensor<7x11x12xf32> - %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x12xf32>) { + // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 11] [1, 1, 1] : tensor<9x15x13xf32> to tensor<7x11x11xf32> + %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x11xf32>) { ^bb0(%in: f32, %out: f32): linalg.yield %in : f32 - } -> tensor<7x11x12xf32> - return %0 : tensor<7x11x12xf32> + } -> tensor<7x11x11xf32> + return %0 : tensor<7x11x11xf32> } module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { @@ -83,7 +83,7 @@ module { // ----- // CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3)> -// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3 + 5)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3 + 4)> // CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 5)> #map = affine_map<(d0, d1, d2) -> (d0, d1)> @@ -272,3 +272,136 @@ module attributes {transform.with_named_sequence} { } } +// ----- + +// CHECK-LABEL: pad_conv +func.func @pad_conv(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> { + + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 12] + // CHECK: : tensor<1x16x16x4xf32> to tensor<1x16x18x16xf32> + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12] + // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32> + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0] + // CHECK: : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32> + // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc + // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32> + + %0 = linalg.conv_2d_nhwc_fhwc + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> } + ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32>) + outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> + return %0 : tensor<1x14x14x16xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of { + padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32] + } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 16) * 16 + 2)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 16) * 16)> + +// CHECK-LABEL: pad_conv_dynamic +func.func 
@pad_conv_dynamic(%arg0: tensor<1x16x?x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x?x16xf32>) -> tensor<1x14x?x16xf32> { + + // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index + // CHECK: %[[D0_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32> + // CHECK: %[[D0_1:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x16x?x4xf32> + // CHECK: %[[H0:.*]] = affine.apply #[[$MAP0]]()[%[[D0_0]], %[[D0_1]]] + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, %[[H0]], 12] + // CHECK: : tensor<1x16x?x4xf32> to tensor<1x16x?x16xf32> + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12] + // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32> + // CHECK: %[[D1_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32> + // CHECK: %[[H1:.*]] = affine.apply #[[$MAP1]]()[%[[D0_0]], %[[D1_0]]] + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, %[[H1]], 0] + // CHECK: : tensor<1x14x?x16xf32> to tensor<1x14x?x16xf32> + // CHECK: %[[D2_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32> + // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc + // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, %[[D2_0]], 16] [1, 1, 1, 1] : tensor<1x14x?x16xf32> to tensor<1x14x?x16xf32> + + %0 = linalg.conv_2d_nhwc_fhwc + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> } + ins(%arg0, %arg1: tensor<1x16x?x4xf32>, tensor<16x3x3x4xf32>) + outs(%arg2: tensor<1x14x?x16xf32>) -> tensor<1x14x?x16xf32> + return %0 : tensor<1x14x?x16xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of { + padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32] + } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +// CHECK-LABEL: pad_conv_strided +func.func @pad_conv_strided(%arg0: tensor<1x42x42x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> { + + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 6, 12] + // CHECK: : tensor<1x42x42x4xf32> to tensor<1x42x48x16xf32> + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12] + // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32> + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0] + // CHECK: : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32> + // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc + // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32> + + %0 = linalg.conv_2d_nhwc_fhwc + {dilations = dense<1> : tensor<2xi64>, strides = dense<3> : tensor<2xi64> } + ins(%arg0, %arg1: tensor<1x42x42x4xf32>, tensor<16x3x3x4xf32>) + outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> + return %0 : tensor<1x14x14x16xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of { + padding_values = [0.0 : f32, 0.0 : f32, 0.0 : 
f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32] + } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +// CHECK-LABEL: pad_conv_dilated +func.func @pad_conv_dilated(%arg0: tensor<1x18x18x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> { + + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 12] + // CHECK: : tensor<1x18x18x4xf32> to tensor<1x18x20x16xf32> + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12] + // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32> + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0] + // CHECK: : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32> + // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc + // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32> + + %0 = linalg.conv_2d_nhwc_fhwc + {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> } + ins(%arg0, %arg1: tensor<1x18x18x4xf32>, tensor<16x3x3x4xf32>) + outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> + return %0 : tensor<1x14x14x16xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of { + padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32] + } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir index 26c03ed..f741876 100644 --- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir @@ -69,22 +69,22 @@ module { // CHECK-LABEL: @generic // CHECK-SAME: %[[T0:.*]]: tensor<7x5xf32>, -// CHECK-SAME: %[[T1:.*]]: tensor<7x11x12xf32>) - func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x12xf32>) -> tensor<7x11x12xf32> { +// CHECK-SAME: %[[T1:.*]]: tensor<7x11x11xf32>) + func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x11xf32>) -> tensor<7x11x11xf32> { // CHECK-DAG: %[[CST:.*]] = arith.constant 0. 
// CHECK: %[[PAD0:.*]] = tensor.pad %[[T0]] low[0, 0] high[1, 0] // CHECK: : tensor<7x5xf32> to tensor<8x5xf32> // CHECK: %[[PAD1:.*]] = tensor.pad %[[T1]] low[0, 0, 0] high[1, 3, 1] { - // CHECK: : tensor<7x11x12xf32> to tensor<8x14x13xf32> + // CHECK: : tensor<7x11x11xf32> to tensor<8x14x12xf32> // CHECK-NEXT: linalg.generic - // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 12] [1, 1, 1] : tensor<8x14x13xf32> to tensor<7x11x12xf32> - %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x12xf32>) { + // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 11] [1, 1, 1] : tensor<8x14x12xf32> to tensor<7x11x11xf32> + %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x11xf32>) { ^bb0(%in: f32, %out: f32): linalg.yield %in : f32 - } -> tensor<7x11x12xf32> - return %0 : tensor<7x11x12xf32> + } -> tensor<7x11x11xf32> + return %0 : tensor<7x11x11xf32> } module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { @@ -102,7 +102,7 @@ module { // CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (-s0 + 8)> -// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (-s0 + 13)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (-s0 + 12)> // CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 5)> #map = affine_map<(d0, d1, d2) -> (d0, d1)> @@ -127,13 +127,13 @@ module { // CHECK: %[[D2_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<?x11x?xf32> // CHECK: %[[H2:.*]] = affine.apply #[[$MAP1]]()[%[[D2_0]]] // CHECK: tensor.pad %{{.*}} low[0, 0, 0] high[%[[H1]], 3, %[[H2]]] { - // CHECK: : tensor<?x11x?xf32> to tensor<8x14x13xf32> + // CHECK: : tensor<?x11x?xf32> to tensor<8x14x12xf32> // // CHECK: %[[D0_2:.*]] = tensor.dim %{{.*}}, %[[C0]] : tensor<?x5xf32> // CHECK: %[[D2_1:.*]] = affine.apply #[[$MAP2]]()[%[[D0_2]]] - // CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<8x5xf32>) outs(%{{.*}} : tensor<8x14x13xf32>) { - // CHECK: } -> tensor<8x14x13xf32> - // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [%[[D0_2]], 11, %[[D2_1]]] [1, 1, 1] : tensor<8x14x13xf32> to tensor<?x11x?xf32> + // CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<8x5xf32>) outs(%{{.*}} : tensor<8x14x12xf32>) { + // CHECK: } -> tensor<8x14x12xf32> + // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [%[[D0_2]], 11, %[[D2_1]]] [1, 1, 1] : tensor<8x14x12xf32> to tensor<?x11x?xf32> // %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<?x5xf32>) outs(%arg1 : tensor<?x11x?xf32>) { ^bb0(%in: f32, %out: f32): diff --git a/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir index c3ee892..d7722ea 100644 --- a/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir +++ b/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir @@ -230,18 +230,17 @@ func.func @vectorize_nd_tensor_extract_index_from_tensor(%arg0: tensor<3x3xf32>, // CHECK-SAME: %[[ARG4:.*]]: tensor<4x7x3x2xf32> // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[PV:.*]] = ub.poison : i32 -// CHECK-DAG: %[[CST:.*]] = arith.constant dense<3> : vector<7x2x4x3xindex> +// CHECK-DAG: %[[CST:.*]] = arith.constant dense<3> : vector<4x3xindex> // CHECK-DAG: %[[CST_1:.*]] = arith.constant 
dense<true> : vector<4x7x3x2xi1> // CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<4x7x3x2xf32> // CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], %[[PV]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32> // CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], %[[PV]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32> // CHECK: %[[CAST:.*]] = arith.index_cast %[[V0]] : vector<4x3xi32> to vector<4x3xindex> -// CHECK: %[[B1:.*]] = vector.broadcast %[[CAST]] : vector<4x3xindex> to vector<7x2x4x3xindex> // CHECK: %[[CAST_1:.*]] = arith.index_cast %[[V1]] : vector<4x3xi32> to vector<4x3xindex> -// CHECK: %[[B2:.*]] = vector.broadcast %[[CAST_1]] : vector<4x3xindex> to vector<7x2x4x3xindex> -// CHECK: %[[MULI:.*]] = arith.muli %[[B1]], %[[CST]] : vector<7x2x4x3xindex> -// CHECK: %[[ADDI:.*]] = arith.addi %[[B2]], %[[MULI]] : vector<7x2x4x3xindex> -// CHECK: %[[T:.*]] = vector.transpose %[[ADDI]], [2, 0, 3, 1] : vector<7x2x4x3xindex> to vector<4x7x3x2xindex> +// CHECK: %[[MULI:.*]] = arith.muli %[[CAST]], %[[CST]] : vector<4x3xindex> +// CHECK: %[[ADDI:.*]] = arith.addi %[[CAST_1]], %[[MULI]] : vector<4x3xindex> +// CHECK: %[[B:.*]] = vector.broadcast %[[ADDI]] : vector<4x3xindex> to vector<7x2x4x3xindex> +// CHECK: %[[T:.*]] = vector.transpose %[[B]], [2, 0, 3, 1] : vector<7x2x4x3xindex> to vector<4x7x3x2xindex> // CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]]] [%[[T]]], %[[CST_1]], %[[PASSTHRU]] : tensor<3x3xf32>, vector<4x7x3x2xindex>, vector<4x7x3x2xi1>, vector<4x7x3x2xf32> into vector<4x7x3x2xf32> // CHECK: vector.transfer_write %[[GATHER]], %[[ARG4]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true]} : vector<4x7x3x2xf32>, tensor<4x7x3x2xf32> @@ -270,20 +269,16 @@ func.func @vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load(% // CHECK-SAME: %[[ARG0:.*]]: tensor<8x128x768xf32> // CHECK-SAME: %[[ARG1:.*]]: index // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[CST:.*]] = arith.constant dense<768> : vector<1x8xindex> -// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<128> : vector<1x8xindex> // CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32> -// CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<true> : vector<8x1xi1> -// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex> +// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<true> : vector<8x1xi1> +// CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<[0, 98304, 196608, 294912, 393216, 491520, 589824, 688128]> : vector<8xindex> // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<8x1xf32> -// CHECK: %[[B1:.*]] = vector.broadcast %[[CST_3]] : vector<8xindex> to vector<1x8xindex> // CHECK: %[[ADDI_ARG1:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : index -// CHECK: %[[MULI_1:.*]] = arith.muli %[[B1]], %[[CST_0]] : vector<1x8xindex> -// CHECK: %[[MULI_2:.*]] = arith.muli %[[MULI_1]], %[[CST]] : vector<1x8xindex> -// CHECK: %[[T:.*]] = vector.transpose %[[MULI_2]], [1, 0] : vector<1x8xindex> to vector<8x1xindex> +// CHECK: %[[B1:.*]] = vector.broadcast %[[CST_1]] : vector<8xindex> to vector<1x8xindex> +// CHECK: %[[T:.*]] = vector.transpose %[[B1]], [1, 0] : vector<1x8xindex> to vector<8x1xindex> // CHECK: %[[B2:.*]] = vector.broadcast %[[ADDI_ARG1]] : index to vector<8x1xindex> // CHECK: %[[ADDI:.*]] = arith.addi %[[B2]], %[[T]] : vector<8x1xindex> -// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], 
%[[C0]]] [%[[ADDI]]], %[[CST_2]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32> +// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_0]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32> // CHECK: vector.transfer_write %[[GATHER]], %[[EMPTY]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32> // ----- @@ -309,15 +304,13 @@ func.func @index_from_output_column_vector_gather_load(%src: tensor<8x128xf32>) // CHECK-LABEL: func.func @index_from_output_column_vector_gather_load( // CHECK-SAME: %[[SRC:.*]]: tensor<8x128xf32>) -> tensor<8x1xf32> { -// CHECK: %[[C128:.*]] = arith.constant dense<128> : vector<1x8xindex> +// CHECK: %[[IDX_VEC:.*]] = arith.constant dense<[0, 128, 256, 384, 512, 640, 768, 896]> : vector<8xindex> // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[PASS_THRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32> // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8x1xi1> -// CHECK: %[[IDX_VEC:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex> // CHECK: %[[OUT:.*]] = tensor.empty() : tensor<8x1xf32> // CHECK: %[[B:.*]] = vector.broadcast %[[IDX_VEC]] : vector<8xindex> to vector<1x8xindex> -// CHECK: %[[MUL:.*]] = arith.muli %[[B]], %[[C128]] : vector<1x8xindex> -// CHECK: %[[TR:.*]] = vector.transpose %[[MUL]], [1, 0] : vector<1x8xindex> to vector<8x1xindex> +// CHECK: %[[TR:.*]] = vector.transpose %[[B]], [1, 0] : vector<1x8xindex> to vector<8x1xindex> // CHECK: %[[GATHER:.*]] = vector.gather %[[SRC]]{{\[}}%[[C0]], %[[C0]]] {{\[}}%[[TR]]], %[[MASK]], %[[PASS_THRU]] : tensor<8x128xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32> // CHECK: %[[RES:.*]] = vector.transfer_write %[[GATHER]], %[[OUT]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32> // CHECK: return %[[RES]] : tensor<8x1xf32> @@ -420,12 +413,12 @@ func.func @vectorize_nd_tensor_extract_with_affine_apply_gather(%6: tensor<80x16 // CHECK-DAG: %[[VAL_4:.*]] = arith.constant dense<true> : vector<1x4xi1> // CHECK-DAG: %[[VAL_5:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32> // CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant dense<16> : vector<1x4xindex> +// CHECK-DAG: %[[VAL_7:.*]] = arith.constant dense<16> : vector<4xindex> // CHECK: %[[VAL_8:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex> // CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : vector<4xindex> -// CHECK: %[[VAL_10:.*]] = vector.broadcast %[[VAL_9]] : vector<4xindex> to vector<1x4xindex> -// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_10]], %[[VAL_7]] : vector<1x4xindex> -// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_7]] : vector<1x4xindex> +// CHECK: %[[VAL_10:.*]] = arith.muli %[[VAL_9]], %[[VAL_7]] : vector<4xindex> +// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_7]] : vector<4xindex> +// CHECK: %[[VAL_12:.*]] = vector.broadcast %[[VAL_11]] : vector<4xindex> to vector<1x4xindex> // CHECK: %[[VAL_13:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_12]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32> // CHECK: %[[VAL_14:.*]] = vector.transfer_write %[[VAL_13]], %[[VAL_2]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, 
tensor<1x4xf32> // CHECK: return %[[VAL_14]] : tensor<1x4xf32> @@ -450,14 +443,12 @@ func.func @vectorize_nd_tensor_extract_with_maxsi_gather(%arg0: tensor<80x16xf32 // CHECK-LABEL: func.func @vectorize_nd_tensor_extract_with_maxsi_gather( // CHECK-SAME: %[[VAL_0:.*]]: tensor<80x16xf32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> { -// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex> -// CHECK-DAG: %[[VAL_3:.*]] = arith.constant dense<1264> : vector<1x4xindex> +// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[1264, 1265, 1266, 1267]> : vector<4xindex> // CHECK-DAG: %[[VAL_4:.*]] = arith.constant dense<true> : vector<1x4xi1> // CHECK-DAG: %[[VAL_5:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32> // CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index // CHECK: %[[VAL_7:.*]] = vector.broadcast %[[VAL_2]] : vector<4xindex> to vector<1x4xindex> -// CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_3]] : vector<1x4xindex> -// CHECK: %[[VAL_9:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_8]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32> +// CHECK: %[[VAL_9:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_7]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32> // CHECK: %[[VAL_10:.*]] = vector.transfer_write %[[VAL_9]], %[[VAL_1]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32> // CHECK: return %[[VAL_10]] : tensor<1x4xf32> // CHECK: } @@ -519,13 +510,13 @@ func.func @vectorize_reverse_like_tensor_extract(%arg0: tensor<1x2x3xf32>, %arg1 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]] // CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]] // CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]] -// CHECK-DAG: %[[CST:.+]] = arith.constant dense<3> : vector<1x1x3xindex> +// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index // CHECK-DAG: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x3xi1> // CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x3xf32> // CHECK-DAG: %[[INIT_IDX:.+]] = arith.constant dense<[2, 1, 0]> : vector<3xindex> -// CHECK: %[[T0:.+]] = vector.broadcast %[[ARG2]] : index to vector<1x1x3xindex> -// CHECK: %[[T1:.+]] = arith.muli %[[T0]], %[[CST]] : vector<1x1x3xindex> +// CHECK: %[[T0:.+]] = arith.muli %[[ARG2]], %[[C3]] : index +// CHECK: %[[T1:.+]] = vector.broadcast %[[T0]] : index to vector<1x1x3xindex> // CHECK: %[[T2:.+]] = vector.broadcast %[[INIT_IDX]] // CHECK: %[[T3:.+]] = arith.addi %[[T2]], %[[T1]] // CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[T3]]], %[[MASK]], %[[PASSTHRU]] diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index 12d30e17..308cf150 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -1440,8 +1440,8 @@ func.func @propagate_into_execute_region() { // ----- -// CHECK-LABEL: func @execute_region_elim -func.func @execute_region_elim() { +// CHECK-LABEL: func @execute_region_inline +func.func @execute_region_inline() { affine.for %i = 0 to 100 { "test.foo"() : () -> () %v = scf.execute_region -> i64 { @@ -1461,8 +1461,30 @@ func.func @execute_region_elim() { // ----- -// CHECK-LABEL: func @func_execute_region_elim -func.func @func_execute_region_elim() { +// CHECK-LABEL: func 
@execute_region_no_inline +func.func @execute_region_no_inline() { + affine.for %i = 0 to 100 { + "test.foo"() : () -> () + %v = scf.execute_region -> i64 no_inline { + %x = "test.val"() : () -> i64 + scf.yield %x : i64 + } + "test.bar"(%v) : (i64) -> () + } + return +} + +// CHECK-NEXT: affine.for %arg0 = 0 to 100 { +// CHECK-NEXT: "test.foo"() : () -> () +// CHECK-NEXT: scf.execute_region +// CHECK-NEXT: %[[VAL:.*]] = "test.val"() : () -> i64 +// CHECK-NEXT: scf.yield %[[VAL]] : i64 +// CHECK-NEXT: } + +// ----- + +// CHECK-LABEL: func @func_execute_region_inline +func.func @func_execute_region_inline() { "test.foo"() : () -> () %v = scf.execute_region -> i64 { %c = "test.cmp"() : () -> i1 @@ -1496,8 +1518,8 @@ func.func @func_execute_region_elim() { // ----- -// CHECK-LABEL: func @func_execute_region_elim_multi_yield -func.func @func_execute_region_elim_multi_yield() { +// CHECK-LABEL: func @func_execute_region_inline_multi_yield +func.func @func_execute_region_inline_multi_yield() { "test.foo"() : () -> () %v = scf.execute_region -> i64 { %c = "test.cmp"() : () -> i1 diff --git a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir index d6c3464..58b8288 100644 --- a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir @@ -33,6 +33,24 @@ func.func @inotequal_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vecto // ----- //===----------------------------------------------------------------------===// +// spirv.IsFinite +//===----------------------------------------------------------------------===// + +func.func @isfinite_scalar(%arg0: f32) -> i1 { + // CHECK: spirv.IsFinite {{.*}} : f32 + %0 = spirv.IsFinite %arg0 : f32 + return %0 : i1 +} + +func.func @isfinite_vector(%arg0: vector<2xf32>) -> vector<2xi1> { + // CHECK: spirv.IsFinite {{.*}} : vector<2xf32> + %0 = spirv.IsFinite %arg0 : vector<2xf32> + return %0 : vector<2xi1> +} + +// ----- + +//===----------------------------------------------------------------------===// // spirv.IsInf //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir index 5d05a654..6d321af 100644 --- a/mlir/test/Dialect/SPIRV/IR/types.mlir +++ b/mlir/test/Dialect/SPIRV/IR/types.mlir @@ -296,6 +296,12 @@ func.func private @struct_type_with_matrix_2(!spirv.struct<(!spirv.matrix<3 x ve // CHECK: func private @struct_empty(!spirv.struct<()>) func.func private @struct_empty(!spirv.struct<()>) +// CHECK: func.func private @struct_block(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>) +func.func private @struct_block(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>) + +// CHECK: func.func private @struct_two_dec(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block, CPacked>) +func.func private @struct_two_dec(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block, CPacked>) + // ----- // expected-error @+1 {{offset specification must be given for all members}} diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir index 2b23766..8d7f3da 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir @@ -178,7 +178,7 @@ spirv.module Logical GLSL450 attributes { // Vulkan memory model requires SPV_KHR_vulkan_memory_model, which is enabled // implicitly by v1.5. 
-// CHECK: requires #spirv.vce<v1.0, [VulkanMemoryModel], [SPV_KHR_vulkan_memory_model]> +// CHECK: requires #spirv.vce<v1.5, [VulkanMemoryModel], [SPV_KHR_vulkan_memory_model]> spirv.module Logical Vulkan attributes { spirv.target_env = #spirv.target_env< #spirv.vce<v1.5, [Shader, VulkanMemoryModel], []>, #spirv.resource_limits<>> diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index b90d6f5..3bccb32 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -2036,3 +2036,19 @@ func.func @test_scatter_duplicate_indices(%arg0: tensor<2x52x3xf32>, %arg2: tens %0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi32>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32> return %0 : tensor<2x52x3xf32> } + +// ----- + +func.func @test_reduce_all_unsupported_data_types(%arg0: tensor<2x12x11xf32>) -> tensor<1x12x11xf32> { + // expected-error@+1 {{'tosa.reduce_all' op illegal: operation operand/result data types did not align with any profile or extension, got (f32,f32), did you mean (i1,i1)?}} + %0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<2x12x11xf32>) -> tensor<1x12x11xf32> + return %0 : tensor<1x12x11xf32> +} + +// ----- + +func.func @test_rfft2d(%arg0: tensor<13x8x16xbf16>) -> (tensor<13x8x9xbf16>, tensor<13x8x9xbf16>) { + // expected-error@+1 {{'tosa.rfft2d' op illegal: operation operand/result data types did not align with any profile or extension, got (bf16,bf16,bf16), did you mean (f32,f32,f32)?}} + %0, %1 = tosa.rfft2d %arg0 : (tensor<13x8x16xbf16>) -> (tensor<13x8x9xbf16>, tensor<13x8x9xbf16>) + return %0, %1 : tensor<13x8x9xbf16>, tensor<13x8x9xbf16> +} diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index cbe0056..bf9ed8a 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -48,10 +48,10 @@ func.func @test_add_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tens // ----- -func.func @test_arithmetic_right_shift_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> { +func.func @test_arithmetic_right_shift_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi32>, %arg1: tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32> { // expected-error@+1 {{'tosa.arithmetic_right_shift' op failed level check: operand rank(shape) <= MAX_RANK}} - %0 = tosa.arithmetic_right_shift %arg0, %arg1 {round = false} : (tensor<1x1x1x1x13x21x3xf32>, tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> - return %0 : tensor<1x1x1x1x13x21x3xf32> + %0 = tosa.arithmetic_right_shift %arg0, %arg1 {round = false} : (tensor<1x1x1x1x13x21x3xi32>, tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32> + return %0 : tensor<1x1x1x1x13x21x3xi32> } // ----- diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 9cfebd5..56996b5 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1330,11 +1330,11 @@ func.func @fold_consecutive_broadcasts(%a : i32) -> vector<4x16xi32> { // ----- -// CHECK-LABEL: shape_cast_constant +// CHECK-LABEL: shape_cast_splat_constant // CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1> : vector<3x4x2xi32> // CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<20x2xf32> // CHECK: return %[[CST0]], %[[CST1]] : vector<20x2xf32>, vector<3x4x2xi32> -func.func @shape_cast_constant() -> (vector<20x2xf32>, 
vector<3x4x2xi32>) { +func.func @shape_cast_splat_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) { %cst = arith.constant dense<2.000000e+00> : vector<5x4x2xf32> %cst_1 = arith.constant dense<1> : vector<12x2xi32> %0 = vector.shape_cast %cst : vector<5x4x2xf32> to vector<20x2xf32> @@ -1344,6 +1344,36 @@ func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) { // ----- +// Test of shape_cast's fold method: +// shape_cast(constant) -> constant. +// +// CHECK-LABEL: @shape_cast_dense_int_constant +// CHECK: %[[CST:.*]] = arith.constant +// CHECK-SAME{LITERAL}: dense<[[2, 3, 5], [7, 11, 13]]> +// CHECK: return %[[CST]] : vector<2x3xi8> +func.func @shape_cast_dense_int_constant() -> vector<2x3xi8> { + %cst = arith.constant dense<[2, 3, 5, 7, 11, 13]> : vector<6xi8> + %0 = vector.shape_cast %cst : vector<6xi8> to vector<2x3xi8> + return %0 : vector<2x3xi8> +} + +// ----- + +// Test of shape_cast fold's method: +// (shape_cast(const_x), const_x) -> (const_x_folded, const_x) +// +// CHECK-LABEL: @shape_cast_dense_float_constant +// CHECK-DAG: %[[CST0:.*]] = {{.*}}1.000000e+00, 2.000000e+00{{.*}} vector<1x2xf32> +// CHECK-DAG: %[[CST1:.*]] = {{.*}}1.000000e+00, 2.000000e+00{{.*}} vector<2xf32> +// CHECK: return %[[CST1]], %[[CST0]] : vector<2xf32>, vector<1x2xf32> +func.func @shape_cast_dense_float_constant() -> (vector<2xf32>, vector<1x2xf32>){ + %cst = arith.constant dense<[[1.0, 2.0]]> : vector<1x2xf32> + %0 = vector.shape_cast %cst : vector<1x2xf32> to vector<2xf32> + return %0, %cst : vector<2xf32>, vector<1x2xf32> +} + +// ----- + // CHECK-LABEL: shape_cast_poison // CHECK-DAG: %[[CST1:.*]] = ub.poison : vector<3x4x2xi32> // CHECK-DAG: %[[CST0:.*]] = ub.poison : vector<20x2xf32> diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir index 2563b48..b2f16bb 100644 --- a/mlir/test/Dialect/Vector/int-range-interface.mlir +++ b/mlir/test/Dialect/Vector/int-range-interface.mlir @@ -51,6 +51,15 @@ func.func @vector_shape_cast() -> vector<4x4xindex> { func.return %2 : vector<4x4xindex> } +// CHECK-LABEL: func @vector_transpose +// CHECK: test.reflect_bounds {smax = 8 : index, smin = 7 : index, umax = 8 : index, umin = 7 : index} +func.func @vector_transpose() -> vector<2x4xindex> { + %0 = test.with_bounds { smax = 8 : index, smin = 7 : index, umax = 8 : index, umin = 7 : index } : vector<4x2xindex> + %1 = vector.transpose %0, [1, 0] : vector<4x2xindex> to vector<2x4xindex> + %2 = test.reflect_bounds %1 : vector<2x4xindex> + func.return %2 : vector<2x4xindex> +} + // CHECK-LABEL: func @vector_extract // CHECK: test.reflect_bounds {smax = 6 : index, smin = 5 : index, umax = 6 : index, umin = 5 : index} func.func @vector_extract() -> index { @@ -99,3 +108,11 @@ func.func @test_vector_extsi() -> vector<2xi32> { %2 = test.reflect_bounds %1 : vector<2xi32> func.return %2 : vector<2xi32> } + +// CHECK-LABEL: func @vector_step +// CHECK: test.reflect_bounds {smax = 7 : index, smin = 0 : index, umax = 7 : index, umin = 0 : index} +func.func @vector_step() -> vector<8xindex> { + %0 = vector.step : vector<8xindex> + %1 = test.reflect_bounds %0 : vector<8xindex> + func.return %1 : vector<8xindex> +} diff --git a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir index 8e167a5..d5e3443 100644 --- a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir @@ -2,7 +2,7 
@@ // CHECK-LABEL: func @broadcast_vec1d_from_scalar // CHECK-SAME: %[[A:.*0]]: f32 -// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2xf32> +// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2xf32> // CHECK: return %[[T0]] : vector<2xf32> func.func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> { @@ -12,7 +12,7 @@ func.func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> { // CHECK-LABEL: func @broadcast_vec2d_from_scalar // CHECK-SAME: %[[A:.*0]]: f32 -// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3xf32> +// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3xf32> // CHECK: return %[[T0]] : vector<2x3xf32> func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> { @@ -22,7 +22,7 @@ func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> { // CHECK-LABEL: func @broadcast_vec3d_from_scalar // CHECK-SAME: %[[A:.*0]]: f32 -// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3x4xf32> +// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3x4xf32> // CHECK: return %[[T0]] : vector<2x3x4xf32> func.func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> { @@ -87,7 +87,7 @@ func.func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf3 // CHECK-LABEL: func @broadcast_stretch // CHECK-SAME: %[[A:.*0]]: vector<1xf32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<1xf32> -// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<4xf32> +// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<4xf32> // CHECK: return %[[T1]] : vector<4xf32> func.func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> { @@ -113,16 +113,16 @@ func.func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32> // CHECK-SAME: %[[A:.*0]]: vector<4x1xf32> // CHECK: %[[U0:.*]] = ub.poison : vector<4x3xf32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : f32 from vector<4x1xf32> -// CHECK: %[[T2:.*]] = vector.splat %[[T0]] : vector<3xf32> +// CHECK: %[[T2:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32> // CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[U0]] [0] : vector<3xf32> into vector<4x3xf32> // CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : f32 from vector<4x1xf32> -// CHECK: %[[T6:.*]] = vector.splat %[[T4]] : vector<3xf32> +// CHECK: %[[T6:.*]] = vector.broadcast %[[T4]] : f32 to vector<3xf32> // CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<4x3xf32> // CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : f32 from vector<4x1xf32> -// CHECK: %[[T10:.*]] = vector.splat %[[T8]] : vector<3xf32> +// CHECK: %[[T10:.*]] = vector.broadcast %[[T8]] : f32 to vector<3xf32> // CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T7]] [2] : vector<3xf32> into vector<4x3xf32> // CHECK: %[[T12:.*]] = vector.extract %[[A]][3, 0] : f32 from vector<4x1xf32> -// CHECK: %[[T14:.*]] = vector.splat %[[T12]] : vector<3xf32> +// CHECK: %[[T14:.*]] = vector.broadcast %[[T12]] : f32 to vector<3xf32> // CHECK: %[[T15:.*]] = vector.insert %[[T14]], %[[T11]] [3] : vector<3xf32> into vector<4x3xf32> // CHECK: return %[[T15]] : vector<4x3xf32> diff --git a/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir index 059d955..5a8125e 100644 --- a/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir @@ -5,11 +5,11 @@ // CHECK-SAME: %[[B:.*1]]: vector<3xf32> // CHECK: 
%[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<2xf32> -// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32> +// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32> // CHECK: %[[T2:.*]] = arith.mulf %[[T1]], %[[B]] : vector<3xf32> // CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32> // CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : f32 from vector<2xf32> -// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xf32> +// CHECK: %[[T5:.*]] = vector.broadcast %[[T4]] : f32 to vector<3xf32> // CHECK: %[[T6:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32> // CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<2x3xf32> // CHECK: return %[[T7]] : vector<2x3xf32> @@ -26,12 +26,12 @@ func.func @outerproduct_noacc(%arg0: vector<2xf32>, // CHECK-SAME: %[[C:.*2]]: vector<2x3xf32> // CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<2xf32> -// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32> +// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32> // CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xf32> from vector<2x3xf32> // CHECK: %[[T3:.*]] = vector.fma %[[T1]], %[[B]], %[[T2]] : vector<3xf32> // CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32> // CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : f32 from vector<2xf32> -// CHECK: %[[T6:.*]] = vector.splat %[[T5]] : vector<3xf32> +// CHECK: %[[T6:.*]] = vector.broadcast %[[T5]] : f32 to vector<3xf32> // CHECK: %[[T7:.*]] = vector.extract %[[C]][1] : vector<3xf32> from vector<2x3xf32> // CHECK: %[[T8:.*]] = vector.fma %[[T6]], %[[B]], %[[T7]] : vector<3xf32> // CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : vector<3xf32> into vector<2x3xf32> @@ -49,11 +49,11 @@ func.func @outerproduct_acc(%arg0: vector<2xf32>, // CHECK-SAME: %[[B:.*1]]: vector<3xi32> // CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : i32 from vector<2xi32> -// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32> +// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : i32 to vector<3xi32> // CHECK: %[[T2:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32> // CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32> // CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : i32 from vector<2xi32> -// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xi32> +// CHECK: %[[T5:.*]] = vector.broadcast %[[T4]] : i32 to vector<3xi32> // CHECK: %[[T6:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32> // CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xi32> into vector<2x3xi32> // CHECK: return %[[T7]] : vector<2x3xi32> @@ -69,13 +69,13 @@ func.func @outerproduct_noacc_int(%arg0: vector<2xi32>, // CHECK-SAME: %[[C:.*2]]: vector<2x3xi32> // CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : i32 from vector<2xi32> -// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32> +// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : i32 to vector<3xi32> // CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xi32> from vector<2x3xi32> // CHECK: %[[T3:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32> // CHECK: %[[T4:.*]] = arith.addi %[[T3]], %[[T2]] : vector<3xi32> // CHECK: %[[T5:.*]] = vector.insert 
%[[T4]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32> // CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : i32 from vector<2xi32> -// CHECK: %[[T7:.*]] = vector.splat %[[T6]] : vector<3xi32> +// CHECK: %[[T7:.*]] = vector.broadcast %[[T6]] : i32 to vector<3xi32> // CHECK: %[[T8:.*]] = vector.extract %[[C]][1] : vector<3xi32> from vector<2x3xi32> // CHECK: %[[T9:.*]] = arith.muli %[[T7]], %[[B]] : vector<3xi32> // CHECK: %[[T10:.*]] = arith.addi %[[T9]], %[[T8]] : vector<3xi32> @@ -91,7 +91,7 @@ func.func @outerproduct_acc_int(%arg0: vector<2xi32>, // CHECK-LABEL: func @axpy_fp( // CHECK-SAME: %[[A:.*0]]: vector<16xf32>, // CHECK-SAME: %[[B:.*1]]: f32) -// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32> +// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : f32 to vector<16xf32> // CHECK: %[[T1:.*]] = arith.mulf %[[A]], %[[T0]] : vector<16xf32> // CHECK: return %[[T1]] : vector<16xf32> func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> { @@ -103,7 +103,7 @@ func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> { // CHECK-SAME: %[[A:.*0]]: vector<16xf32>, // CHECK-SAME: %[[B:.*1]]: f32, // CHECK-SAME: %[[C:.*2]]: vector<16xf32>) -// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32> +// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : f32 to vector<16xf32> // CHECK: %[[T1:.*]] = vector.fma %[[A]], %[[T0]], %[[C]] : vector<16xf32> // CHECK: return %[[T1]] : vector<16xf32> func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>) -> vector<16xf32> { @@ -114,7 +114,7 @@ func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32> // CHECK-LABEL: func @axpy_int( // CHECK-SAME: %[[A:.*0]]: vector<16xi32>, // CHECK-SAME: %[[B:.*1]]: i32) -// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32> +// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : i32 to vector<16xi32> // CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32> // CHECK: return %[[T1]] : vector<16xi32> func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> { @@ -126,7 +126,7 @@ func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> { // CHECK-SAME: %[[A:.*0]]: vector<16xi32>, // CHECK-SAME: %[[B:.*1]]: i32, // CHECK-SAME: %[[C:.*2]]: vector<16xi32>) -// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32> +// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : i32 to vector<16xi32> // CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32> // CHECK: %[[T2:.*]] = arith.addi %[[T1]], %[[C]] : vector<16xi32> // CHECK: return %[[T2]] : vector<16xi32> diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir index b826cdc..ef881ba 100644 --- a/mlir/test/Dialect/Vector/vector-sink.mlir +++ b/mlir/test/Dialect/Vector/vector-sink.mlir @@ -180,13 +180,14 @@ func.func @negative_not_elementwise() -> vector<2x2xf32> { // ----- -// The source and the result for arith.cmp have different types - not supported - -// CHECK-LABEL: func.func @negative_source_and_result_mismatch -// CHECK: %[[BROADCAST:.+]] = vector.broadcast -// CHECK: %[[RETURN:.+]] = arith.cmpf uno, %[[BROADCAST]], %[[BROADCAST]] -// CHECK: return %[[RETURN]] -func.func @negative_source_and_result_mismatch(%arg0 : f32, %arg1 : vector<1xf32>) -> vector<1xi1> { +// The source and the result for arith.cmp have different types + +// CHECK-LABEL: func.func @source_and_result_mismatch( +// CHECK-SAME: %[[ARG0:.+]]: f32) +// CHECK: %[[COMPARE:.+]] = arith.cmpf uno, %[[ARG0]], %[[ARG0]] +// CHECK: 
%[[BROADCAST:.+]] = vector.broadcast %[[COMPARE]] : i1 to vector<1xi1> +// CHECK: return %[[BROADCAST]] +func.func @source_and_result_mismatch(%arg0 : f32) -> vector<1xi1> { %0 = vector.broadcast %arg0 : f32 to vector<1xf32> %1 = arith.cmpf uno, %0, %0 : vector<1xf32> return %1 : vector<1xi1> @@ -210,6 +211,130 @@ func.func @negative_op_only_supports_vectors(%arg0 : f32) -> vector<1xf32> { return %1 : vector<1xf32> } +// ----- + +// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const( +// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> { +// CHECK: %[[NEW_CST:.*]] = arith.constant 2 : index +// CHECK: %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[NEW_CST]] : index +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex> +// CHECK: return %[[BCAST]] : vector<1x4xindex> + +func.func @broadcast_scalar_and_splat_const(%arg0: index) -> vector<1x4xindex> { + %0 = vector.broadcast %arg0 : index to vector<1x4xindex> + %cst = arith.constant dense<2> : vector<1x4xindex> + %2 = arith.addi %0, %cst : vector<1x4xindex> + return %2 : vector<1x4xindex> +} + +// ----- + +// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const_const_first( +// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> { +// CHECK: %[[NEW_CST:.*]] = arith.constant 2 : index +// CHECK: %[[SUB:.*]] = arith.subi %[[NEW_CST]], %[[ARG_0]] : index +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[SUB]] : index to vector<1x4xindex> +// CHECK: return %[[BCAST]] : vector<1x4xindex> + +func.func @broadcast_scalar_and_splat_const_const_first(%arg0: index) -> vector<1x4xindex> { + %0 = vector.broadcast %arg0 : index to vector<1x4xindex> + %cst = arith.constant dense<2> : vector<1x4xindex> + %2 = arith.subi %cst, %0 : vector<1x4xindex> + return %2 : vector<1x4xindex> +} + +// ----- + +// CHECK-LABEL: func.func @broadcast_vector_and_splat_const( +// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>) -> vector<3x4xf32> { +// CHECK: %[[NEW_CST:.*]] = arith.constant dense<2.000000e+00> : vector<4xf32> +// CHECK: %[[ADD:.*]] = arith.mulf %[[ARG_0]], %[[NEW_CST]] : vector<4xf32> +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : vector<4xf32> to vector<3x4xf32> +// CHECK: return %[[BCAST]] : vector<3x4xf32> + +func.func @broadcast_vector_and_splat_const(%arg0: vector<4xf32>) -> vector<3x4xf32> { + %0 = vector.broadcast %arg0 : vector<4xf32> to vector<3x4xf32> + %cst = arith.constant dense<2.000000e+00> : vector<3x4xf32> + %2 = arith.mulf %0, %cst : vector<3x4xf32> + return %2 : vector<3x4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @negative_broadcast_with_non_splat_const( +// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> { +// CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : index to vector<1x4xindex> +// CHECK-DAG: %[[CST:.*]] = arith.constant dense<{{\[}}[0, 1, 2, 3]]> : vector<1x4xindex> +// CHECK: %[[ADD:.*]] = arith.addi %[[BCAST]], %[[CST]] : vector<1x4xindex> +// CHECK: return %[[ADD]] : vector<1x4xindex> + +func.func @negative_broadcast_with_non_splat_const(%arg0: index) -> vector<1x4xindex> { + %0 = vector.broadcast %arg0 : index to vector<1x4xindex> + %cst = arith.constant dense<[[0, 1, 2, 3]]> : vector<1x4xindex> + %2 = arith.addi %0, %cst : vector<1x4xindex> + return %2 : vector<1x4xindex> +} + +// ----- + +// CHECK-LABEL: func.func @broadcast_scalar_mixed_type( +// CHECK-SAME: %[[ARG_0:.*]]: f16) -> vector<1x4xf32> { +// CHECK: %[[EXTF:.*]] = arith.extf %[[ARG_0]] : f16 to f32 +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTF]] : f32 to vector<1x4xf32> +// CHECK: return %[[BCAST]] : 
vector<1x4xf32> + +func.func @broadcast_scalar_mixed_type(%arg0: f16) -> vector<1x4xf32> { + %0 = vector.broadcast %arg0 : f16 to vector<1x4xf16> + %1 = arith.extf %0 : vector<1x4xf16> to vector<1x4xf32> + return %1 : vector<1x4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @broadcast_vector_mixed_type( +// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf16>) -> vector<3x4xf32> { +// CHECK: %[[EXTF:.*]] = arith.extf %[[ARG_0]] : vector<4xf16> to vector<4xf32> +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTF]] : vector<4xf32> to vector<3x4xf32> +// CHECK: return %[[BCAST]] : vector<3x4xf32> + +func.func @broadcast_vector_mixed_type(%arg0: vector<4xf16>) -> vector<3x4xf32> { + %0 = vector.broadcast %arg0 : vector<4xf16> to vector<3x4xf16> + %1 = arith.extf %0 : vector<3x4xf16> to vector<3x4xf32> + return %1 : vector<3x4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const_mixed_type( +// CHECK-SAME: %[[ARG_0:.*]]: f32) -> vector<1x4xf32> { +// CHECK: %[[NEW_CST:.*]] = arith.constant 3 : i32 +// CHECK: %[[POW:.*]] = math.fpowi %[[ARG_0]], %[[NEW_CST]] : f32, i32 +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[POW]] : f32 to vector<1x4xf32> +// CHECK: return %[[BCAST]] : vector<1x4xf32> + +func.func @broadcast_scalar_and_splat_const_mixed_type(%arg0: f32) -> vector<1x4xf32> { + %0 = vector.broadcast %arg0 : f32 to vector<1x4xf32> + %cst = arith.constant dense<3> : vector<1x4xi32> + %2 = math.fpowi %0, %cst : vector<1x4xf32>, vector<1x4xi32> + return %2 : vector<1x4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @broadcast_vector_and_splat_const_mixed_type( +// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>) -> vector<3x4xf32> { +// CHECK: %[[NEW_CST:.*]] = arith.constant dense<3> : vector<4xi32> +// CHECK: %[[POW:.*]] = math.fpowi %[[ARG_0]], %[[NEW_CST]] : vector<4xf32>, vector<4xi32> +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[POW]] : vector<4xf32> to vector<3x4xf32> +// CHECK: return %[[BCAST]] : vector<3x4xf32> + +func.func @broadcast_vector_and_splat_const_mixed_type(%arg0: vector<4xf32>) -> vector<3x4xf32> { + %0 = vector.broadcast %arg0 : vector<4xf32> to vector<3x4xf32> + %cst = arith.constant dense<3> : vector<3x4xi32> + %2 = math.fpowi %0, %cst : vector<3x4xf32>, vector<3x4xi32> + return %2 : vector<3x4xf32> +} + //===----------------------------------------------------------------------===// // [Pattern: ReorderCastOpsOnBroadcast] // diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index 0160bfe..dff3ffa 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -385,6 +385,74 @@ func.func @load_gather_vc_3(%src: ui64) { } // ----- +func.func @prefetch_offset_wi_1(%src: memref<4x4xf32>) { + %offsets = arith.constant dense<[0]> : vector<1xindex> + // expected-error@+1 {{Expecting the source is a 1D memref or pointer}} + xegpu.prefetch %src[%offsets]: memref<4x4xf32>, vector<1xindex> + return +} + +// ----- +func.func @load_gather_offset_sg(%src: memref<?xf16>) { + %offsets = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %mask = arith.constant dense<1>: vector<8xi1> + // expected-error@+1 {{Mask should match value except the chunk size dim}} + %2 = xegpu.load %src[%offsets], %mask + : memref<?xf16>, vector<4xindex>, vector<8xi1> + -> vector<4x2xf16> + return +} + +// ----- +func.func @load_gather_offset_wi(%src: ui64) { + %mask = arith.constant dense<1>: vector<1xi1> + %offsets = arith.constant dense<[0]> : vector<1xindex> + // expected-error@+1 {{value elements must match 
chunk size}} + %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64, vector<1xindex>, vector<1xi1> -> vector<4xf32> + return +} + +// ----- +func.func @store_scatter_offset_wi_1(%src: memref<?xf16>) { + %val = arith.constant dense<2.9>: vector<4xf16> + %offsets = arith.constant dense<[0]> : vector<1xindex> + %mask = arith.constant dense<1>: vector<1xi1> + // expected-error@+1 {{value elements must match chunk size}} + xegpu.store %val, %src[%offsets], %mask + : vector<4xf16>, memref<?xf16>, vector<1xindex>, vector<1xi1> + return +} + +// ----- +func.func @store_scatter_offset_wi_2(%src: memref<4x4xf16>) { + %val = arith.constant dense<2.9>: vector<4xf16> + %offsets = arith.constant dense<[0]> : vector<1xindex> + %mask = arith.constant dense<1>: vector<1xi1> + // expected-error@+1 {{Expecting the dest is a 1D memref or pointer}} + xegpu.store %val, %src[%offsets], %mask + : vector<4xf16>, memref<4x4xf16>, vector<1xindex>, vector<1xi1> + return +} + +// ----- +func.func @load_gather_offset_wi_2(%src: ui64) { + %mask = arith.constant dense<1>: vector<1xi1> + %offsets = arith.constant dense<[0]> : vector<1xindex> + // expected-error@+1 {{value elements must match chunk size}} + %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64, vector<1xindex>, vector<1xi1> -> vector<4xf16> + return +} + +// ----- +func.func @load_gather_offset_wi_1(%src: memref<4x4xf32>) { + %mask = arith.constant dense<1>: vector<1xi1> + %offsets = arith.constant dense<[0]> : vector<1xindex> + // expected-error@+1 {{Expecting the source is a 1D memref or pointer}} + %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : memref<4x4xf32>, vector<1xindex>, vector<1xi1> -> vector<2xf32> + return +} + +// ----- func.func @store_scatter_vc_1(%src: memref<24x32xf32>) { %0 = arith.constant dense<1>: vector<4xi1> %1 = arith.constant dense<2.9>: vector<4x2xf32> diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir index 3ebb1b969a..6be2371 100644 --- a/mlir/test/Dialect/XeGPU/ops.mlir +++ b/mlir/test/Dialect/XeGPU/ops.mlir @@ -521,6 +521,16 @@ gpu.func @subgroup_load_4(%src: ui64) { gpu.return } +// CHECK: gpu.func @subgroup_load_offset_1(%arg0: memref<?xf16>) { +gpu.func @subgroup_load_offset_1(%src: memref<?xf16>) { + %offset = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %mask = arith.constant dense<1>: vector<4xi1> + //CHECK: %[[R1:.*]] = xegpu.load %arg0[%cst], %cst_0 <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : memref<?xf16>, vector<4xindex>, vector<4xi1> -> vector<4x2xf16> + %val = xegpu.load %src[%offset], %mask <{chunk_size=2, l1_hint = #xegpu.cache_hint<cached>}> + : memref<?xf16>, vector<4xindex>, vector<4xi1> -> vector<4x2xf16> + gpu.return +} + // CHECK: gpu.func @subgroup_store(%[[arg0:.*]]: ui64) { gpu.func @subgroup_store(%src: ui64) { //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> @@ -626,6 +636,17 @@ gpu.func @subgroup_store_4(%src: ui64) { gpu.return } +// CHECK: gpu.func @subgroup_store_offset_1(%arg0: memref<?xf16>) { +gpu.func @subgroup_store_offset_1(%dest: memref<?xf16>) { + %val = arith.constant dense<2.9>: vector<4x2xf16> + %offset = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %mask = arith.constant dense<1>: vector<4xi1> + //CHECK: xegpu.store %[[R0:.*]], %arg0[%cst_0], %cst_1 <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<4x2xf16>, memref<?xf16>, vector<4xindex>, vector<4xi1> + xegpu.store %val, %dest[%offset], %mask <{chunk_size=2, l1_hint = 
#xegpu.cache_hint<cached>}> + : vector<4x2xf16>, memref<?xf16>, vector<4xindex>, vector<4xi1> + gpu.return +} + // CHECK: gpu.func @prefetch(%[[arg0:.*]]: ui64) { gpu.func @prefetch(%src: ui64) { //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> @@ -637,6 +658,14 @@ gpu.func @prefetch(%src: ui64) { gpu.return } +// CHECK: gpu.func @prefetch_offset(%[[arg0:.*]]: ui64) { +gpu.func @prefetch_offset(%src: ui64) { + //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + // CHECK: xegpu.prefetch %[[arg0]][%cst] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : ui64, vector<4xindex> + xegpu.prefetch %src[%0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: ui64, vector<4xindex> + gpu.return +} // CHECK: gpu.func @create_update_tdesc(%[[arg0:.*]]: ui64) { gpu.func @create_update_tdesc(%src: ui64) { diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir index d67bdb4..628a485 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir @@ -2,122 +2,117 @@ gpu.module @test_round_robin_assignment { // CHECK-LABEL: create_nd_tdesc - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> - gpu.func @create_nd_tdesc(%src: memref<24x32xf32>) { - // CHECK-COUNT-12: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<24x32xf32> - // CHECK-SAME: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> + gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) { + // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf32> + // CHECK-SAME: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK-NOT: xegpu.create_nd_tdesc - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> gpu.return } // CHECK-LABEL: load_nd_tdesc - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> - gpu.func @load_nd_tdesc(%src: memref<24x32xf32>) { - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> - // CHECK-COUNT-12: xegpu.load_nd %{{.*}} - // CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> - // CHECK-SAME-COUNT-12: -> vector<2x2xf32> + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> + gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) { + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK-COUNT-4: xegpu.load_nd %{{.*}} + // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK-SAME-COUNT-4: -> vector<16x16xf32> // CHECK-NOT: xegpu.load_nd %load = xegpu.load_nd %tdesc - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], 
lane_layout = [2, 2], lane_data = [1, 1]>> - -> vector<24x32xf32> + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> + -> vector<256x128xf32> gpu.return } // CHECK-LABEL: store_nd - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> - gpu.func @store_nd(%src: memref<24x32xf32>) { - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> - // CHECK-COUNT-12: xegpu.store_nd %{{.*}}, %{{.*}} - // CHECK-SAME-COUNT-12: : vector<2x2xf32>, !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> + gpu.func @store_nd(%src: memref<256x128xf32>) { + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, %{{.*}} + // CHECK-SAME-COUNT-4: : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK-NOT : xegpu.store_nd %load = xegpu.load_nd %tdesc - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> - -> vector<24x32xf32> + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> + -> vector<256x128xf32> xegpu.store_nd %load, %tdesc - : vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> + : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> gpu.return } // CHECK-LABEL: update_nd - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> - gpu.func @update_nd(%src: memref<24x32xf32>){ - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> - // CHECK-COUNT-12: xegpu.update_nd_offset %{{.*}}, [0, 16] - // CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> + gpu.func @update_nd(%src: memref<256x128xf32>){ + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK-COUNT-4: xegpu.update_nd_offset %{{.*}}, [0, 16] + // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>> // CHECK-NOT: xegpu.update_nd_offset %update = xegpu.update_nd_offset %tdesc, [0, 16] - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> gpu.return } // CHECK-LABEL: dpas - // CHECK-SAME: (%[[ARG_0:.*]]: memref<8x8xf32>, %[[ARG_1:.*]]: memref<8x8xf32>, %[[ARG_2:.*]]: memref<8x8xf32>) - gpu.func @dpas(%a: memref<8x8xf32>, %b: memref<8x8xf32>, %c: memref<8x8xf32>) { - // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, 
%{{.*}}] : memref<8x8xf32> - // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> + // CHECK-SAME: (%[[ARG_0:.*]]: memref<256x128xf16>, %[[ARG_1:.*]]: memref<128x256xf16>) + gpu.func @dpas(%a: memref<256x128xf16>, %b: memref<128x256xf16>) { + // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf16> + // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK-NOT: xegpu.create_nd_tdesc - // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<8x8xf32> - // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> - // CHECK-NOT: xegpu.create_nd_tdesc - // CHECK-COUNT-4: xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}] : memref<8x8xf32> - // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> + // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<128x256xf16> + // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [4, 8], lane_data = [1, 1]>> // CHECK-NOT: xegpu.create_nd_tdesc // CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}} - // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} - // CHECK-SAME-COUNT-16: : vector<2x2xf32>, vector<2x2xf32> -> vector<2x2xf32> + // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} + // CHECK-SAME-COUNT-16: : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32> // CHECK-NOT: xegpu.dpas - %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<8x8xf32> - -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> + %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<256x128xf16> + -> !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> %load_a = xegpu.load_nd %tdesc_a - : !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> - -> vector<8x8xf32> - %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<8x8xf32> - -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> + : !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> + -> vector<256x128xf16> + %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<128x256xf16> + -> !xegpu.tensor_desc<128x256xf16, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>> %load_b = xegpu.load_nd %tdesc_b - : !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> - -> vector<8x8xf32> - %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<8x8xf32> - -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> + : !xegpu.tensor_desc<128x256xf16, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>> + -> vector<128x256xf16> %dpas = xegpu.dpas %load_a, %load_b - {layout_result_0 = #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>} - : vector<8x8xf32>, vector<8x8xf32> -> vector<8x8xf32> + {layout_result_0 = 
#xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>} + : vector<256x128xf16>, vector<128x256xf16> -> vector<256x256xf32> gpu.return } // CHECK-LABEL: prefetch_nd_tdesc - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> - gpu.func @prefetch_nd_tdesc(%src: memref<24x32xf32>) { - // CHECK-COUNT-12: xegpu.prefetch_nd %{{.*}} - // CHECK-SAME-COUNT-12 : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> + gpu.func @prefetch_nd_tdesc(%src: memref<256x128xf32>) { + // CHECK-COUNT-4: xegpu.prefetch_nd %{{.*}} + // CHECK-SAME-COUNT-4: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK-NOT: xegpu.prefetch_nd - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> xegpu.prefetch_nd %tdesc - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> gpu.return } // CHECK-LABEL: broadcast - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32> - gpu.func @broadcast(%src: memref<24x1xf32>) { - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32> - -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<128x1xf32> + gpu.func @broadcast(%src: memref<128x1xf32>) { + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<128x1xf32> + -> !xegpu.tensor_desc<128x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 1], lane_layout = [8, 1], lane_data = [1, 1]>> %load = xegpu.load_nd %tdesc - : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>> - -> vector<24x1xf32> - // CHECK-COUNT-3: vector.broadcast {{.*}} - // CHECK-SAME-COUNT-3: {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>} - // CHECK-SAME-COUNT-3: : vector<2x1xf32> to vector<2x4xf32> + : !xegpu.tensor_desc<128x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 1], lane_layout = [8, 1], lane_data = [1, 1]>> + -> vector<128x1xf32> + // CHECK-COUNT-2: vector.broadcast {{.*}} + // CHECK-SAME-COUNT-2: {layout_result_0 = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>} + // CHECK-SAME-COUNT-2: : vector<16x1xf32> to vector<16x32xf32> // CHECK-NOT: vector.broadcast %broadcast = vector.broadcast %load - {layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 4], lane_layout = [2, 1], lane_data = [1, 1]>} - : vector<24x1xf32> to vector<24x8xf32> + {layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 32], lane_layout = [8, 1], lane_data = [1, 1]>} + : vector<128x1xf32> to vector<128x64xf32> gpu.return } diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir index d511224..d4b0037 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir @@ -4,201 +4,181 @@ //CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)> gpu.module 
@test_1_1_assignment { // CHECK-LABEL: create_nd_tdesc - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> - gpu.func @create_nd_tdesc(%src: memref<24x32xf32>) { + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> + gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) { // CHECK: %[[SGID:.*]] = gpu.subgroup_id - // CHECK: %[[C12:.*]] = arith.constant 12 : index - // CHECK: %[[C4:.*]] = arith.constant 4 : index // CHECK: %[[C8:.*]] = arith.constant 8 : index + // CHECK: %[[C32:.*]] = arith.constant 32 : index + // CHECK: %[[C4:.*]] = arith.constant 4 : index + // CHECK: %[[C32_0:.*]] = arith.constant 32 : index + // CHECK: %[[C4_1:.*]] = arith.constant 4 : index // CHECK: %[[DIV:.*]] = affine.apply #map()[%[[SGID]]] // CHECK: %[[REM:.*]] = affine.apply #map1()[%[[SGID]]] - // CHECK: %[[MUL1:.*]] = index.mul %[[DIV]], %[[C12]] - // CHECK: %[[MUL2:.*]] = index.mul %[[REM]], %[[C8]] - // CHECK: %[[C24:.*]] = arith.constant 24 : index - // CHECK: %[[MOD:.*]] = index.remu %[[MUL1]], %[[C24]] + // CHECK: %[[MUL1:.*]] = index.mul %[[DIV]], %[[C32]] + // CHECK: %[[MUL2:.*]] = index.mul %[[REM]], %[[C32_0]] // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[ADD1:.*]] = index.add %[[MOD]], %[[C0]] - // CHECK: %[[C32:.*]] = arith.constant 32 : index - // CHECK: %[[MOD1:.*]] = index.remu %[[MUL2]], %[[C32]] - // CHECK: %[[C0_1:.*]] = arith.constant 0 : index - // CHECK: %[[ADD2:.*]] = index.add %[[MOD1]], %[[C0_1]] - // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<24x32xf32> - // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> + // CHECK: %[[C256:.*]] = arith.constant 256 : index + // CHECK: %[[MOD:.*]] = index.remu %[[MUL1]], %[[C256]] + // CHECK: %[[C0_2:.*]] = arith.constant 0 : index + // CHECK: %[[ADD1:.*]] = index.add %[[MOD]], %[[C0_2]] + // CHECK: %[[C0_3:.*]] = arith.constant 0 : index + // CHECK: %[[C128:.*]] = arith.constant 128 : index + // CHECK: %[[MOD1:.*]] = index.remu %[[MUL2]], %[[C128]] + // CHECK: %[[C0_4:.*]] = arith.constant 0 : index + // CHECK: %[[ADD2:.*]] = index.add %[[MOD1]], %[[C0_4]] + // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<256x128xf32> + // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK: gpu.return - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> gpu.return } // CHECK-LABEL: load_nd_tdesc - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> - gpu.func @load_nd_tdesc(%src: memref<24x32xf32>) { - // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> - // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> + gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) { + // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32> + // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] - // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, 
#xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> - // CHECK-SAME: -> vector<12x8xf32> - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> + // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK-SAME: -> vector<32x32xf32> + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> %load = xegpu.load_nd %tdesc - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> - -> vector<24x32xf32> + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> + -> vector<256x128xf32> gpu.return } // CHECK-LABEL: store_nd - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> - gpu.func @store_nd(%src: memref<24x32xf32>) { - // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> - // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> + gpu.func @store_nd(%src: memref<256x128xf32>) { + // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32> + // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] - // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> - // CHECK-SAME: -> vector<12x8xf32> + // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK-SAME: -> vector<32x32xf32> // CHECK: xegpu.store_nd %[[LOAD]], %[[TDESC]] - // CHECK-SAME: : vector<12x8xf32>, !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> + // CHECK-SAME: : vector<32x32xf32>, !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> %load = xegpu.load_nd %tdesc - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> - -> vector<24x32xf32> + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> + -> vector<256x128xf32> xegpu.store_nd %load, %tdesc - : vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> + : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> gpu.return } // CHECK-LABEL: update_nd -// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> -gpu.func @update_nd(%src: memref<24x32xf32>){ - // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : 
memref<24x32xf32> - // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> +// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> +gpu.func @update_nd(%src: memref<256x128xf32>){ + // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32> + // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK: %[[UPDATE:.*]] = xegpu.update_nd_offset %[[TDESC]], [0, 16] - // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> + // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> %update = xegpu.update_nd_offset %tdesc, [0, 16] - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> gpu.return } // CHECK-LABEL: dpas -// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> -// CHECK-SAME: %[[ARG_1:.*]]: memref<32x24xf32> -gpu.func @dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) { - // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> - // CHECk-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> - // CHECK: %[[LOAD_A:.*]] = xegpu.load_nd %[[TDESC_A]] - // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> - // CHECK-SAME: -> vector<12x8xf32> - // CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][{{%.*}}, {{%.*}}] : memref<32x24xf32> - // CHECK-SAME: -> !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>> - // CHECK: %[[LOAD_B:.*]] = xegpu.load_nd %[[TDESC_B]] - // CHECK-SAME: : !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>> - // CHECK-SAME: -> vector<8x12xf32> - // CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]] - // CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} - // CHECK-SAME: : vector<12x8xf32>, vector<8x12xf32> -> vector<12x12xf32> - %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> +gpu.func @dpas(%a: memref<128x128xf16>, %b: memref<128x128xf16>) { + // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x128xf16>, vector<128x16xf16> -> vector<16x16xf32> + %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<128x128xf16> + -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>> %load_a = xegpu.load_nd %tdesc_a - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> - -> vector<24x32xf32> - %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : 
memref<32x24xf32> - -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>> + : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>> + -> vector<128x128xf16> + %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<128x128xf16> + -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>> %load_b = xegpu.load_nd %tdesc_b - : !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>> - -> vector<32x24xf32> + : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>> + -> vector<128x128xf16> %dpas = xegpu.dpas %load_a, %load_b - {layout_result_0 = #xegpu.layout<sg_layout = [2, 2], sg_data = [12, 12], lane_layout = [2, 2], lane_data = [1, 1]>} - : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32> + {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>} + : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32> gpu.return } // CHECK-LABEL: dpas_no_sg_data -// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> -// CHECK-SAME: %[[ARG_1:.*]]: memref<32x24xf32> -gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) { - // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> - // CHECk-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> - // CHECK: %[[LOAD_A:.*]] = xegpu.load_nd %[[TDESC_A]] - // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> - // CHECK-SAME: -> vector<12x8xf32> - // CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][{{%.*}}, {{%.*}}] : memref<32x24xf32> - // CHECK-SAME: -> !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>> - // CHECK: %[[LOAD_B:.*]] = xegpu.load_nd %[[TDESC_B]] - // CHECK-SAME: : !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>> - // CHECK-SAME: -> vector<8x12xf32> - // CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]] - // CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} - // CHECK-SAME: : vector<12x8xf32>, vector<8x12xf32> -> vector<12x12xf32> - %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], lane_layout = [2, 8], lane_data = [1, 1]>> +gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) { + // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32> + %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<128x128xf16> + -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1], + order = [1, 0]>> %load_a = xegpu.load_nd %tdesc_a - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], lane_layout = [2, 8], lane_data = [1, 1]>> - -> vector<24x32xf32> - %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32> - -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], lane_layout = [8, 2], lane_data = [1, 1]>> + : 
!xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1], + order = [1, 0]>> + -> vector<128x128xf16> + %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<128x128xf16> + -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1], + order = [1, 0]>> %load_b = xegpu.load_nd %tdesc_b - : !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], lane_layout = [8, 2], lane_data = [1, 1]>> - -> vector<32x24xf32> + : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1], + order = [1, 0]>> + -> vector<128x128xf16> %dpas = xegpu.dpas %load_a, %load_b - {layout_result_0 = #xegpu.layout<sg_layout = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>} - : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32> + {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} + : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32> gpu.return } // CHECK-LABEL: prefetch_nd_tdesc - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> - gpu.func @prefetch_nd_tdesc(%src: memref<24x32xf32>) { - // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> - // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> + gpu.func @prefetch_nd_tdesc(%src: memref<256x128xf32>) { + // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32> + // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK: xegpu.prefetch_nd %[[TDESC]] - // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> + // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> xegpu.prefetch_nd %tdesc - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> gpu.return } // CHECK-LABEL: dpas_with_no_create_nd_desc - gpu.func @dpas_with_no_create_nd_desc(%a: vector<24x32xf32>, %b: vector<32x24xf32>) { - // CHECK-NOT: vector<12x12xf32> + gpu.func @dpas_with_no_create_nd_desc(%a: vector<256x128xf32>, %b: vector<128x256xf32>) { + // CHECK-NOT: vector<32x32xf32> %dpas = xegpu.dpas %a, %b {layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [12, 12], lane_layout = [2, 2], lane_data = [1, 1]>} - : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32> + : vector<256x128xf32>, vector<128x256xf32> -> vector<256x256xf32> gpu.return } // CHECK-LABEL: broadcast_dim1 - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32> - gpu.func @broadcast_dim1(%src: memref<24x1xf32>) { - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32> - -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 
1], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x1xf32> + gpu.func @broadcast_dim1(%src: memref<256x1xf32>) { + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x1xf32> + -> !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 1], lane_layout = [8, 1], lane_data = [1, 1]>> %load = xegpu.load_nd %tdesc - : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>> - -> vector<24x1xf32> - // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>} - // CHECK-SAME: : vector<12x1xf32> to vector<12x8xf32> - %broadcast = vector.broadcast %load - {layout_result_0 = #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 8], lane_layout = [2, 1], lane_data = [1, 1]>} - : vector<24x1xf32> to vector<24x8xf32> + : !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 1], lane_layout = [8, 1], lane_data = [1, 1]>> + -> vector<256x1xf32> + // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>} + // CHECK-SAME: : vector<32x1xf32> to vector<32x32xf32> + %broadcast = vector.broadcast %load + {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 32], lane_layout = [8, 1], lane_data = [1, 1]>} + : vector<256x1xf32> to vector<256x32xf32> gpu.return } // CHECK-LABEL: broadcast_dim0 - // CHECK-SAME: %[[ARG_0:.*]]: memref<1x32xf32> - gpu.func @broadcast_dim0(%src: memref<1x32xf32>) { - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<1x32xf32> - -> !xegpu.tensor_desc<1x32xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 8], lane_layout = [1, 8], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<1x128xf32> + gpu.func @broadcast_dim0(%src: memref<1x128xf32>) { + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<1x128xf32> + -> !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 32], lane_layout = [1, 16], lane_data = [1, 1]>> %load = xegpu.load_nd %tdesc - : !xegpu.tensor_desc<1x32xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 8], lane_layout = [1, 8], lane_data = [1, 1]>> - -> vector<1x32xf32> - // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 8], lane_data = [1, 1]>} - // CHECK-SAME: : vector<1x8xf32> to vector<12x8xf32> + : !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 32], lane_layout = [1, 16], lane_data = [1, 1]>> + -> vector<1x128xf32> + // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} + // CHECK-SAME: : vector<1x32xf32> to vector<32x32xf32> %broadcast = vector.broadcast %load - {layout_result_0 = #xegpu.layout<sg_layout = [1, 4], sg_data = [12, 8], lane_layout = [1, 8], lane_data = [1, 1]>} - : vector<1x32xf32> to vector<12x32xf32> + {layout_result_0 = #xegpu.layout<sg_layout = [1, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>} + : vector<1x128xf32> to vector<32x128xf32> gpu.return } diff --git a/mlir/test/IR/top-level.mlir b/mlir/test/IR/top-level.mlir index e0adb4d82..5389691 100644 --- a/mlir/test/IR/top-level.mlir +++ b/mlir/test/IR/top-level.mlir @@ -6,10 +6,10 @@ func.func private @foo() // ----- -// expected-error@-9 {{source must contain a single top-level operation, found: 2}} +// expected-error@-2 {{source must contain a single top-level operation, found: 2}} func.func private @bar() func.func private @baz() // ----- -// expected-error@-15 {{source 
must contain a single top-level operation, found: 0}} +// expected-error@-2 {{source must contain a single top-level operation, found: 0}} diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll index 24380b5..a419d75 100644 --- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll +++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll @@ -570,10 +570,10 @@ define void @trap_intrinsics() { ; CHECK-LABEL: llvm.func @memcpy_test define void @memcpy_test(i32 %0, ptr %1, ptr %2) { - ; CHECK: "llvm.intr.memcpy"(%{{.*}}, %{{.*}}, %{{.*}}) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> () - call void @llvm.memcpy.p0.p0.i32(ptr %1, ptr %2, i32 %0, i1 false) - ; CHECK: "llvm.intr.memcpy.inline"(%{{.*}}, %{{.*}}) <{isVolatile = false, len = 10 : i64}> : (!llvm.ptr, !llvm.ptr) -> () - call void @llvm.memcpy.inline.p0.p0.i64(ptr %1, ptr %2, i64 10, i1 false) + ; CHECK: "llvm.intr.memcpy"(%{{.*}}, %{{.*}}, %{{.*}}) <{arg_attrs = [{llvm.align = 4 : i64}, {llvm.align = 8 : i64}, {}], isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> () + call void @llvm.memcpy.p0.p0.i32(ptr align 4 %1, ptr align 8 %2, i32 %0, i1 false) + ; CHECK: "llvm.intr.memcpy.inline"(%{{.*}}, %{{.*}}) <{arg_attrs = [{}, {llvm.align = 4 : i64}], isVolatile = false, len = 10 : i64}> : (!llvm.ptr, !llvm.ptr) -> () + call void @llvm.memcpy.inline.p0.p0.i64(ptr %1, ptr align 4 %2, i64 10, i1 false) ; CHECK: "llvm.intr.memcpy.inline"(%{{.*}}, %{{.*}}) <{isVolatile = false, len = 10 : i32}> : (!llvm.ptr, !llvm.ptr) -> () call void @llvm.memcpy.inline.p0.p0.i32(ptr %1, ptr %2, i32 10, i1 false) ret void @@ -581,17 +581,17 @@ define void @memcpy_test(i32 %0, ptr %1, ptr %2) { ; CHECK-LABEL: llvm.func @memmove_test define void @memmove_test(i32 %0, ptr %1, ptr %2) { - ; CHECK: "llvm.intr.memmove"(%{{.*}}, %{{.*}}, %{{.*}}) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> () - call void @llvm.memmove.p0.p0.i32(ptr %1, ptr %2, i32 %0, i1 false) + ; CHECK: "llvm.intr.memmove"(%{{.*}}, %{{.*}}, %{{.*}}) <{arg_attrs = [{llvm.align = 16 : i64}, {}, {}], isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> () + call void @llvm.memmove.p0.p0.i32(ptr align 16 %1, ptr %2, i32 %0, i1 false) ret void } ; CHECK-LABEL: llvm.func @memset_test define void @memset_test(i32 %0, ptr %1, i8 %2) { - ; CHECK: "llvm.intr.memset"(%{{.*}}, %{{.*}}, %{{.*}}) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> () - call void @llvm.memset.p0.i32(ptr %1, i8 %2, i32 %0, i1 false) - ; CHECK: "llvm.intr.memset.inline"(%{{.*}}, %{{.*}}) <{isVolatile = false, len = 10 : i64}> : (!llvm.ptr, i8) -> () - call void @llvm.memset.inline.p0.i64(ptr %1, i8 %2, i64 10, i1 false) + ; CHECK: "llvm.intr.memset"(%{{.*}}, %{{.*}}, %{{.*}}) <{arg_attrs = [{llvm.align = 2 : i64}, {}, {}], isVolatile = false}> : (!llvm.ptr, i8, i32) -> () + call void @llvm.memset.p0.i32(ptr align 2 %1, i8 %2, i32 %0, i1 false) + ; CHECK: "llvm.intr.memset.inline"(%{{.*}}, %{{.*}}) <{arg_attrs = [{llvm.align = 4 : i64}, {}], isVolatile = false, len = 10 : i64}> : (!llvm.ptr, i8) -> () + call void @llvm.memset.inline.p0.i64(ptr align 4 %1, i8 %2, i64 10, i1 false) ; CHECK: "llvm.intr.memset.inline"(%{{.*}}, %{{.*}}) <{isVolatile = false, len = 10 : i32}> : (!llvm.ptr, i8) -> () call void @llvm.memset.inline.p0.i32(ptr %1, i8 %2, i32 10, i1 false) ret void diff --git a/mlir/test/Target/LLVMIR/Import/module-asm.ll b/mlir/test/Target/LLVMIR/Import/module-asm.ll new file mode 100644 index 0000000..38f6ea4 --- /dev/null +++ 
b/mlir/test/Target/LLVMIR/Import/module-asm.ll @@ -0,0 +1,5 @@ +; RUN: mlir-translate -import-llvm %s | FileCheck %s +; CHECK: llvm.module_asm = ["foo", "bar"] + +module asm "foo" +module asm "bar" diff --git a/mlir/test/Target/LLVMIR/invalid-module.mlir b/mlir/test/Target/LLVMIR/invalid-module.mlir index 7fd5f26..5ed6244 100644 --- a/mlir/test/Target/LLVMIR/invalid-module.mlir +++ b/mlir/test/Target/LLVMIR/invalid-module.mlir @@ -1,6 +1,16 @@ -// RUN: mlir-translate -verify-diagnostics -mlir-to-llvmir --no-implicit-module %s +// RUN: mlir-translate -verify-diagnostics -mlir-to-llvmir --no-implicit-module -split-input-file %s // expected-error@below {{'llvm.func' op can not be translated to an LLVMIR module}} llvm.func @foo() { llvm.return } + +// ----- + +// expected-error@below {{expected an array attribute for a module level asm}} +module attributes {llvm.module_asm = "foo"} {} + +// ----- + +// expected-error@below {{expected a string attribute for each entry of a module level asm}} +module attributes {llvm.module_asm = [42]} {} diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir index 44074ce..eb3510c 100644 --- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir +++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir @@ -601,29 +601,33 @@ llvm.func @trap_intrinsics() { // CHECK-LABEL: @memcpy_test llvm.func @memcpy_test(%arg0: i32, %arg2: !llvm.ptr, %arg3: !llvm.ptr) { - // CHECK: call void @llvm.memcpy.p0.p0.i32(ptr %{{.*}}, ptr %{{.*}}, i32 %{{.*}}, i1 false - "llvm.intr.memcpy"(%arg2, %arg3, %arg0) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> () - // CHECK: call void @llvm.memcpy.inline.p0.p0.i32(ptr %{{.*}}, ptr %{{.*}}, i32 10, i1 true - "llvm.intr.memcpy.inline"(%arg2, %arg3) <{isVolatile = true, len = 10 : i32}> : (!llvm.ptr, !llvm.ptr) -> () + // CHECK: call void @llvm.memcpy.p0.p0.i32(ptr align 4 %{{.*}}, ptr align 8 %{{.*}}, i32 %{{.*}}, i1 false + "llvm.intr.memcpy"(%arg2, %arg3, %arg0) <{arg_attrs = [{llvm.align = 4 : i64}, {llvm.align = 8 : i64}, {}], isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> () + // CHECK: call void @llvm.memcpy.inline.p0.p0.i32(ptr align 4 %{{.*}}, ptr %{{.*}}, i32 10, i1 true + "llvm.intr.memcpy.inline"(%arg2, %arg3) <{arg_attrs = [{llvm.align = 4 : i64}, {}], isVolatile = true, len = 10 : i32}> : (!llvm.ptr, !llvm.ptr) -> () // CHECK: call void @llvm.memcpy.inline.p0.p0.i64(ptr %{{.*}}, ptr %{{.*}}, i64 10, i1 true "llvm.intr.memcpy.inline"(%arg2, %arg3) <{isVolatile = true, len = 10 : i64}> : (!llvm.ptr, !llvm.ptr) -> () + + // Verify that trailing empty argument attribute dictionaries can be omitted. 
+ // CHECK: call void @llvm.memcpy.p0.p0.i32(ptr align 4 %{{.*}}, ptr align 8 %{{.*}}, i32 %{{.*}}, i1 false + "llvm.intr.memcpy"(%arg2, %arg3, %arg0) <{arg_attrs = [{llvm.align = 4 : i64}, {llvm.align = 8 : i64}], isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> () llvm.return } // CHECK-LABEL: @memmove_test llvm.func @memmove_test(%arg0: i32, %arg2: !llvm.ptr, %arg3: !llvm.ptr) { - // CHECK: call void @llvm.memmove.p0.p0.i32(ptr %{{.*}}, ptr %{{.*}}, i32 %{{.*}}, i1 false - "llvm.intr.memmove"(%arg2, %arg3, %arg0) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> () + // CHECK: call void @llvm.memmove.p0.p0.i32(ptr align 4 %{{.*}}, ptr align 8 %{{.*}}, i32 %{{.*}}, i1 false + "llvm.intr.memmove"(%arg2, %arg3, %arg0) <{arg_attrs = [{llvm.align = 4 : i64}, {llvm.align = 8 : i64}, {}], isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> () llvm.return } // CHECK-LABEL: @memset_test llvm.func @memset_test(%arg0: i32, %arg2: !llvm.ptr, %arg3: i8) { %i1 = llvm.mlir.constant(false) : i1 - // CHECK: call void @llvm.memset.p0.i32(ptr %{{.*}}, i8 %{{.*}}, i32 %{{.*}}, i1 false - "llvm.intr.memset"(%arg2, %arg3, %arg0) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> () - // CHECK: call void @llvm.memset.inline.p0.i32(ptr %{{.*}}, i8 %{{.*}}, i32 10, i1 true - "llvm.intr.memset.inline"(%arg2, %arg3) <{isVolatile = true, len = 10 : i32}> : (!llvm.ptr, i8) -> () + // CHECK: call void @llvm.memset.p0.i32(ptr align 8 %{{.*}}, i8 %{{.*}}, i32 %{{.*}}, i1 false + "llvm.intr.memset"(%arg2, %arg3, %arg0) <{arg_attrs = [{llvm.align = 8 : i64}, {}, {}], isVolatile = false}> : (!llvm.ptr, i8, i32) -> () + // CHECK: call void @llvm.memset.inline.p0.i32(ptr align 8 %{{.*}}, i8 %{{.*}}, i32 10, i1 true + "llvm.intr.memset.inline"(%arg2, %arg3) <{arg_attrs = [{llvm.align = 8 : i64}, {}], isVolatile = true, len = 10 : i32}> : (!llvm.ptr, i8) -> () // CHECK: call void @llvm.memset.inline.p0.i64(ptr %{{.*}}, i8 %{{.*}}, i64 10, i1 true "llvm.intr.memset.inline"(%arg2, %arg3) <{isVolatile = true, len = 10 : i64}> : (!llvm.ptr, i8) -> () llvm.return diff --git a/mlir/test/Target/LLVMIR/module-asm.mlir b/mlir/test/Target/LLVMIR/module-asm.mlir new file mode 100644 index 0000000..2afb37c --- /dev/null +++ b/mlir/test/Target/LLVMIR/module-asm.mlir @@ -0,0 +1,6 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +module attributes {llvm.module_asm = ["foo", "bar"]} {} + +// CHECK: module asm "foo" +// CHECK: module asm "bar" diff --git a/mlir/test/Target/LLVMIR/xevm.mlir b/mlir/test/Target/LLVMIR/xevm.mlir new file mode 100644 index 0000000..a3dd0b6 --- /dev/null +++ b/mlir/test/Target/LLVMIR/xevm.mlir @@ -0,0 +1,21 @@ +// RUN: mlir-translate --split-input-file -mlir-to-llvmir %s | FileCheck %s + +module { + llvm.func spir_funccc @_Z8prefetchPU3AS1Kcm(!llvm.ptr<1>, i64) + llvm.func @prefetch(%arg0: !llvm.ptr<1>) { + %0 = llvm.mlir.constant(1 : i64) : i64 + // CHECK-LABEL: call spir_func void @_Z8prefetchPU3AS1Kcm + // CHECK-SAME: !spirv.DecorationCacheControlINTEL ![[DECO1:.*]] + llvm.call spir_funccc @_Z8prefetchPU3AS1Kcm(%arg0, %0) + {function_type = !llvm.func<void (ptr<1>, i64)>, linkage = #llvm.linkage<external>, + no_unwind, sym_name = "_Z8prefetchPU3AS1Kcm", visibility_ = 0 : i64, + xevm.DecorationCacheControl = [[6442 : i32, 0 : i32, 1 : i32, 0 : i32], [6442 : i32, 1 : i32, 1 : i32, 0 : i32]]} + : (!llvm.ptr<1>, i64) -> () + llvm.return + } +} + +// CHECK: ![[DECO1]] = !{![[DECO2:.*]], ![[DECO3:.*]]} +// CHECK: ![[DECO2]] = !{i32 6442, i32 0, i32 1, i32 0} +// CHECK: ![[DECO3]] = !{i32 
diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir
index 6aca11e..1695d2a 100644
--- a/mlir/test/Target/SPIRV/constant.mlir
+++ b/mlir/test/Target/SPIRV/constant.mlir
@@ -307,6 +307,34 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader
     spirv.ReturnValue %coop : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc>
   }

+  // CHECK-LABEL: @arm_tensor_of_i32
+  spirv.func @arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
+    // CHECK: {{%.*}} = spirv.Constant dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : !spirv.arm.tensor<2x3xi32>
+    %0 = spirv.Constant dense<[[1, 2, 3], [4, 5, 6]]> : !spirv.arm.tensor<2x3xi32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
+  }
+
+  // CHECK-LABEL: @splat_arm_tensor_of_i32
+  spirv.func @splat_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
+    // CHECK: {{%.*}} = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32>
+    %0 = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
+  }
+
+  // CHECK-LABEL: @arm_tensor_of_f32
+  spirv.func @arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
+    // CHECK: {{%.*}} = spirv.Constant dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : !spirv.arm.tensor<2x3xf32>
+    %0 = spirv.Constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]>: !spirv.arm.tensor<2x3xf32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
+  }
+
+  // CHECK-LABEL: @splat_arm_tensor_of_f32
+  spirv.func @splat_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
+    // CHECK: {{%.*}} = spirv.Constant dense<2.000000e+00> : !spirv.arm.tensor<2x3xf32>
+    %0 = spirv.Constant dense<2.0> : !spirv.arm.tensor<2x3xf32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
+  }
+
   spirv.EntryPoint "GLCompute" @bool_const
 }
diff --git a/mlir/test/Target/SPIRV/logical-ops.mlir b/mlir/test/Target/SPIRV/logical-ops.mlir
index b200871..05cbddc 100644
--- a/mlir/test/Target/SPIRV/logical-ops.mlir
+++ b/mlir/test/Target/SPIRV/logical-ops.mlir
@@ -84,6 +84,8 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     %15 = spirv.IsNan %arg0 : f32
     // CHECK: spirv.IsInf
     %16 = spirv.IsInf %arg1 : f32
+    // CHECK: spirv.IsFinite
+    %17 = spirv.IsFinite %arg0 : f32
     spirv.Return
   }
 }
diff --git a/mlir/test/Target/SPIRV/memory-ops.mlir b/mlir/test/Target/SPIRV/memory-ops.mlir
index 6b50c39..786d07a2 100644
--- a/mlir/test/Target/SPIRV/memory-ops.mlir
+++ b/mlir/test/Target/SPIRV/memory-ops.mlir
@@ -37,32 +37,32 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {

 // -----

 spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
-  spirv.func @load_store_zero_rank_float(%arg0: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>) "None" {
-    // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>
+  spirv.func @load_store_zero_rank_float(%arg0: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>) "None" {
+    // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>
     // CHECK-NEXT: [[VAL:%.*]] = spirv.Load "StorageBuffer" [[LOAD_PTR]] : f32
     %0 = spirv.Constant 0 : i32
-    %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
+    %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
     %2 = spirv.Load "StorageBuffer" %1 : f32
-    // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>
+    // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>
     // CHECK-NEXT: spirv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : f32
     %3 = spirv.Constant 0 : i32
-    %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
+    %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
     spirv.Store "StorageBuffer" %4, %2 : f32
     spirv.Return
   }

-  spirv.func @load_store_zero_rank_int(%arg0: !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>) "None" {
-    // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>
+  spirv.func @load_store_zero_rank_int(%arg0: !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>) "None" {
+    // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>
     // CHECK-NEXT: [[VAL:%.*]] = spirv.Load "StorageBuffer" [[LOAD_PTR]] : i32
     %0 = spirv.Constant 0 : i32
-    %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer>
+    %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer>
     %2 = spirv.Load "StorageBuffer" %1 : i32
-    // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>
+    // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>
     // CHECK-NEXT: spirv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : i32
     %3 = spirv.Constant 0 : i32
-    %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer>
+    %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer>
     spirv.Store "StorageBuffer" %4, %2 : i32
     spirv.Return
   }
diff --git a/mlir/test/Target/SPIRV/struct.mlir b/mlir/test/Target/SPIRV/struct.mlir
index 0db0c0b..4984ee7 100644
--- a/mlir/test/Target/SPIRV/struct.mlir
+++ b/mlir/test/Target/SPIRV/struct.mlir
@@ -7,23 +7,23 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
   // CHECK: !spirv.ptr<!spirv.struct<(f32 [0], !spirv.struct<(f32 [0], !spirv.array<16 x f32, stride=4> [4])> [4])>, Input>
   spirv.GlobalVariable @var1 bind(0, 2) : !spirv.ptr<!spirv.struct<(f32 [0], !spirv.struct<(f32 [0], !spirv.array<16 x f32, stride=4> [4])> [4])>, Input>

-  // CHECK: !spirv.ptr<!spirv.struct<(f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38])>, StorageBuffer>
-  spirv.GlobalVariable @var2 : !spirv.ptr<!spirv.struct<(f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38])>, StorageBuffer>
+  // CHECK: !spirv.ptr<!spirv.struct<(f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38]), Block>, StorageBuffer>
+  spirv.GlobalVariable @var2 : !spirv.ptr<!spirv.struct<(f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38]), Block>, StorageBuffer>

-  // CHECK: !spirv.ptr<!spirv.struct<(!spirv.array<128 x !spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, stride=512> [0])>, StorageBuffer>
-  spirv.GlobalVariable @var3 : !spirv.ptr<!spirv.struct<(!spirv.array<128 x !spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, stride=512> [0])>, StorageBuffer>
+  // CHECK: !spirv.ptr<!spirv.struct<(!spirv.array<128 x !spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, stride=512> [0]), Block>, StorageBuffer>
+  spirv.GlobalVariable @var3 : !spirv.ptr<!spirv.struct<(!spirv.array<128 x !spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, stride=512> [0]), Block>, StorageBuffer>

-  // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4])>, StorageBuffer>
-  spirv.GlobalVariable @var4 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4])>, StorageBuffer>
+  // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4]), Block>, StorageBuffer>
+  spirv.GlobalVariable @var4 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4]), Block>, StorageBuffer>

-  // CHECK: !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable])>, StorageBuffer>
-  spirv.GlobalVariable @var5 : !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable])>, StorageBuffer>
+  // CHECK: !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable]), Block>, StorageBuffer>
+  spirv.GlobalVariable @var5 : !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable]), Block>, StorageBuffer>

-  // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable])>, StorageBuffer>
-  spirv.GlobalVariable @var6 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable])>, StorageBuffer>
+  // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable]), Block>, StorageBuffer>
+  spirv.GlobalVariable @var6 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable]), Block>, StorageBuffer>

-  // CHECK: !spirv.ptr<!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16])>, StorageBuffer>
-  spirv.GlobalVariable @var7 : !spirv.ptr<!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16])>, StorageBuffer>
+  // CHECK: !spirv.ptr<!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]), Block>, StorageBuffer>
+  spirv.GlobalVariable @var7 : !spirv.ptr<!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]), Block>, StorageBuffer>

   // CHECK: !spirv.ptr<!spirv.struct<()>, StorageBuffer>
   spirv.GlobalVariable @empty : !spirv.ptr<!spirv.struct<()>, StorageBuffer>
@@ -34,15 +34,17 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
   // CHECK: !spirv.ptr<!spirv.struct<test_id, (!spirv.array<128 x f32, stride=4> [0])>, Input>
   spirv.GlobalVariable @id_var0 : !spirv.ptr<!spirv.struct<test_id, (!spirv.array<128 x f32, stride=4> [0])>, Input>

+  // CHECK: !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>), Block>, StorageBuffer>
+  spirv.GlobalVariable @recursive_simple : !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>), Block>, StorageBuffer>
-  // CHECK: !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>)>, StorageBuffer>
-  spirv.GlobalVariable @recursive_simple : !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>)>, StorageBuffer>

+  // CHECK: !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>), Block>, Uniform>), Block>, Uniform>
+  spirv.GlobalVariable @recursive_2 : !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>), Block>, Uniform>), Block>, Uniform>
-  // CHECK: !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>)>, Uniform>)>, Uniform>
-  spirv.GlobalVariable @recursive_2 : !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>)>, Uniform>)>, Uniform>

+  // CHECK: !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>), Block>, Uniform>), Block>, Uniform>
+  spirv.GlobalVariable @recursive_3 : !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>), Block>, Uniform>), Block>, Uniform>
-  // CHECK: !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>)>, Uniform>)>, Uniform>
-  spirv.GlobalVariable @recursive_3 : !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>)>, Uniform>)>, Uniform>

+  // CHECK: spirv.GlobalVariable @block : !spirv.ptr<!spirv.struct<vert, (vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>, Output>
+  spirv.GlobalVariable @block : !spirv.ptr<!spirv.struct<vert, (vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>, Output>

   // CHECK: !spirv.ptr<!spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, Input>,
   // CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, Output>
diff --git a/mlir/test/Target/SPIRV/undef.mlir b/mlir/test/Target/SPIRV/undef.mlir
index b9044fe..8889b80 100644
--- a/mlir/test/Target/SPIRV/undef.mlir
+++ b/mlir/test/Target/SPIRV/undef.mlir
@@ -13,10 +13,10 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     // CHECK: {{%.*}} = spirv.Undef : !spirv.array<4 x !spirv.array<4 x i32>>
     %5 = spirv.Undef : !spirv.array<4x!spirv.array<4xi32>>
     %6 = spirv.CompositeExtract %5[1 : i32, 2 : i32] : !spirv.array<4x!spirv.array<4xi32>>
-    // CHECK: {{%.*}} = spirv.Undef : !spirv.ptr<!spirv.struct<(f32)>, StorageBuffer>
-    %7 = spirv.Undef : !spirv.ptr<!spirv.struct<(f32)>, StorageBuffer>
+    // CHECK: {{%.*}} = spirv.Undef : !spirv.ptr<!spirv.struct<(f32), Block>, StorageBuffer>
+    %7 = spirv.Undef : !spirv.ptr<!spirv.struct<(f32), Block>, StorageBuffer>
     %8 = spirv.Constant 0 : i32
-    %9 = spirv.AccessChain %7[%8] : !spirv.ptr<!spirv.struct<(f32)>, StorageBuffer>, i32 -> !spirv.ptr<f32, StorageBuffer>
+    %9 = spirv.AccessChain %7[%8] : !spirv.ptr<!spirv.struct<(f32), Block>, StorageBuffer>, i32 -> !spirv.ptr<f32, StorageBuffer>
     spirv.Return
   }
 }
diff --git a/mlir/test/mlir-tblgen/op-properties-predicates.td b/mlir/test/mlir-tblgen/op-properties-predicates.td
index 7cd24aa..af09ee7 100644
--- a/mlir/test/mlir-tblgen/op-properties-predicates.td
+++ b/mlir/test/mlir-tblgen/op-properties-predicates.td
@@ -70,6 +70,12 @@ def OpWithPredicates : NS_Op<"op_with_predicates"> {
 // CHECK-NEXT: if (!(((!prop.has_value())) || ((::llvm::all_of((*(prop)), [](const int64_t& baseStore) -> bool { return [](int64_t baseIface) -> bool { return ((baseIface >= 0)); }(baseStore); })) && (!(((*(prop)).empty()))))))
 // CHECK: failed to satisfy constraint: optional non-empty array of non-negative int64_

+// CHECK-LABEL: ::llvm::LogicalResult OpWithPredicatesAdaptor::verify
+// Note: comprehensive emission of verifiers is tested in verifyInvariantsImpl() below
+// CHECK: int64_t tblgen_scalar = this->getScalar();
+// CHECK: if (!((tblgen_scalar >= 0)))
+// CHECK: return emitError(loc, "'test.op_with_predicates' op ""property 'scalar' failed to satisfy constraint: non-negative int64_t");
+
 // CHECK-LABEL: OpWithPredicates::verifyInvariantsImpl()
 // Note: for test readability, we capture [[maybe_unused]] into the variable maybe_unused
 // CHECK: [[maybe_unused:\[\[maybe_unused\]\]]] int64_t tblgen_scalar = this->getScalar();
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index f35cfa6..8ea4eb7 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1127,7 +1127,7 @@ static void genPropertyVerifier(
     body << formatv(fetchProperty, varName, getterName,
                     prop.prop.getInterfaceType());
     auto uniquedFn = staticVerifierEmitter.getPropConstraintFn(prop.prop);
-    if (uniquedFn.has_value())
+    if (uniquedFn.has_value() && emitHelper.isEmittingForOp())
       body << formatv(verifyPropertyUniqued, *uniquedFn, varName, prop.name);
     else
       body << formatv(
@@ -4764,6 +4764,7 @@ void OpOperandAdaptorEmitter::addVerification() {
   FmtContext verifyCtx;
   populateSubstitutions(emitHelper, verifyCtx);

+  genPropertyVerifier(emitHelper, verifyCtx, body, staticVerifierEmitter);
   genAttributeVerifier(emitHelper, verifyCtx, body, staticVerifierEmitter,
                        useProperties);
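Read together, the two OpDefinitionsGen.cpp hunks above make the adaptor emit its property checks inline, while the uniqued static-verifier helper path is kept for the op emission only; that is what the new OpWithPredicatesAdaptor::verify CHECK lines in op-properties-predicates.td exercise. Pieced together from those CHECK lines, the generated adaptor verifier is expected to look roughly like the following sketch (the function shell and the trailing success return are assumed for illustration, not quoted from the generated source):

  ::llvm::LogicalResult OpWithPredicatesAdaptor::verify(::mlir::Location loc) {
    // Fetch the property through its generated getter, then test the predicate.
    int64_t tblgen_scalar = this->getScalar();
    if (!((tblgen_scalar >= 0)))
      return emitError(loc, "'test.op_with_predicates' op "
                            "property 'scalar' failed to satisfy constraint: "
                            "non-negative int64_t");
    // Checks for the remaining properties follow the same fetch-then-test pattern.
    return ::mlir::success();
  }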