Diffstat (limited to 'mlir')
-rw-r--r--  mlir/docs/DefiningDialects/AttributesAndTypes.md | 2
-rw-r--r--  mlir/include/mlir/Conversion/Passes.td | 22
-rw-r--r--  mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td | 1
-rw-r--r--  mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td | 1
-rw-r--r--  mlir/include/mlir/Dialect/Async/IR/AsyncOps.td | 3
-rw-r--r--  mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td | 3
-rw-r--r--  mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td | 63
-rw-r--r--  mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td | 109
-rw-r--r--  mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 3
-rw-r--r--  mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 18
-rw-r--r--  mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td | 1
-rw-r--r--  mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h | 7
-rw-r--r--  mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td | 70
-rw-r--r--  mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 15
-rw-r--r--  mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td | 4
-rw-r--r--  mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td | 26
-rw-r--r--  mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h | 40
-rw-r--r--  mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h | 4
-rw-r--r--  mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 6
-rw-r--r--  mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 154
-rw-r--r--  mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 1
-rw-r--r--  mlir/include/mlir/Interfaces/CallInterfaces.td | 32
-rw-r--r--  mlir/include/mlir/Target/LLVMIR/Dialect/All.h | 3
-rw-r--r--  mlir/include/mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h | 31
-rw-r--r--  mlir/include/mlir/Target/LLVMIR/ModuleImport.h | 31
-rw-r--r--  mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h | 23
-rw-r--r--  mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 38
-rw-r--r--  mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp | 41
-rw-r--r--  mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp | 1
-rw-r--r--  mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp | 1
-rw-r--r--  mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp | 9
-rw-r--r--  mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp | 3
-rw-r--r--  mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp | 7
-rw-r--r--  mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 11
-rw-r--r--  mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 52
-rw-r--r--  mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp | 9
-rw-r--r--  mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp | 1
-rw-r--r--  mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp | 1
-rw-r--r--  mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp | 61
-rw-r--r--  mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 8
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 1
-rw-r--r--  mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp | 32
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp | 15
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp | 51
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp | 6
-rw-r--r--  mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 17
-rw-r--r--  mlir/lib/Dialect/SCF/IR/SCF.cpp | 8
-rw-r--r--  mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 59
-rw-r--r--  mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 118
-rw-r--r--  mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 54
-rw-r--r--  mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp | 15
-rw-r--r--  mlir/lib/Dialect/Shard/IR/ShardOps.cpp | 1
-rw-r--r--  mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp | 7
-rw-r--r--  mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 31
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp | 24
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp | 24
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 125
-rw-r--r--  mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 103
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 9
-rw-r--r--  mlir/lib/IR/AsmPrinter.cpp | 7
-rw-r--r--  mlir/lib/Parser/Parser.cpp | 30
-rw-r--r--  mlir/lib/Target/LLVMIR/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp | 50
-rw-r--r--  mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp | 1
-rw-r--r--  mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt | 21
-rw-r--r--  mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp | 103
-rw-r--r--  mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp | 10
-rw-r--r--  mlir/lib/Target/LLVMIR/ModuleImport.cpp | 66
-rw-r--r--  mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 61
-rw-r--r--  mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp | 32
-rw-r--r--  mlir/lib/Target/SPIRV/Deserialization/Deserializer.h | 1
-rw-r--r--  mlir/lib/Target/SPIRV/Serialization/Serializer.cpp | 47
-rw-r--r--  mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir | 17
-rw-r--r--  mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir | 108
-rw-r--r--  mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir | 54
-rw-r--r--  mlir/test/Conversion/MathToSPIRV/math-to-fpclassify-spirv.mlir | 27
-rw-r--r--  mlir/test/Dialect/Async/canonicalize.mlir | 10
-rw-r--r--  mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir | 55
-rw-r--r--  mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir | 149
-rw-r--r--  mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir | 24
-rw-r--r--  mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir | 51
-rw-r--r--  mlir/test/Dialect/SCF/canonicalize.mlir | 34
-rw-r--r--  mlir/test/Dialect/SPIRV/IR/logical-ops.mlir | 18
-rw-r--r--  mlir/test/Dialect/SPIRV/IR/types.mlir | 6
-rw-r--r--  mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir | 2
-rw-r--r--  mlir/test/Dialect/Tosa/invalid.mlir | 16
-rw-r--r--  mlir/test/Dialect/Tosa/level_check.mlir | 6
-rw-r--r--  mlir/test/Dialect/Vector/canonicalize.mlir | 34
-rw-r--r--  mlir/test/Dialect/Vector/int-range-interface.mlir | 17
-rw-r--r--  mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir | 16
-rw-r--r--  mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir | 24
-rw-r--r--  mlir/test/Dialect/Vector/vector-sink.mlir | 139
-rw-r--r--  mlir/test/Dialect/XeGPU/invalid.mlir | 68
-rw-r--r--  mlir/test/Dialect/XeGPU/ops.mlir | 29
-rw-r--r--  mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir | 139
-rw-r--r--  mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir | 242
-rw-r--r--  mlir/test/IR/top-level.mlir | 4
-rw-r--r--  mlir/test/Target/LLVMIR/Import/intrinsic.ll | 20
-rw-r--r--  mlir/test/Target/LLVMIR/Import/module-asm.ll | 5
-rw-r--r--  mlir/test/Target/LLVMIR/invalid-module.mlir | 12
-rw-r--r--  mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir | 24
-rw-r--r--  mlir/test/Target/LLVMIR/module-asm.mlir | 6
-rw-r--r--  mlir/test/Target/LLVMIR/xevm.mlir | 21
-rw-r--r--  mlir/test/Target/SPIRV/constant.mlir | 28
-rw-r--r--  mlir/test/Target/SPIRV/logical-ops.mlir | 2
-rw-r--r--  mlir/test/Target/SPIRV/memory-ops.mlir | 20
-rw-r--r--  mlir/test/Target/SPIRV/struct.mlir | 38
-rw-r--r--  mlir/test/Target/SPIRV/undef.mlir | 6
-rw-r--r--  mlir/tools/mlir-tblgen/op-properties-predicates.td | 6
-rw-r--r--  mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 3
113 files changed, 2695 insertions(+), 836 deletions(-)
diff --git a/mlir/docs/DefiningDialects/AttributesAndTypes.md b/mlir/docs/DefiningDialects/AttributesAndTypes.md
index 022bdad..b991863 100644
--- a/mlir/docs/DefiningDialects/AttributesAndTypes.md
+++ b/mlir/docs/DefiningDialects/AttributesAndTypes.md
@@ -136,7 +136,7 @@ def My_IntegerAttr : MyDialect_Attr<"Integer", "int"> {
/// Here we've defined two parameters, one is a "self" type parameter, and the
/// other is the integer value of the attribute. The self type parameter is
/// specially handled by the assembly format.
- let parameters = (ins AttributeSelfTypeParameter<"">:$type, "APInt":$value);
+ let parameters = (ins AttributeSelfTypeParameter<"">:$type, APIntParameter<"">:$value);
/// Here we've defined a custom builder for the type, that removes the need to pass
  /// in an MLIRContext instance, as it can be inferred from the `type`.
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index cf7596c..6e1baaf 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -196,6 +196,10 @@ def ConvertArithToSPIRVPass : Pass<"convert-arith-to-spirv"> {
"bool", /*default=*/"true",
"Emulate narrower scalar types with 32-bit ones if not supported by "
"the target">,
+ Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
+ "bool", /*default=*/"true",
+ "Emulate unsupported float types by representing them with integer "
+           "types of the same bit width">
];
}
@@ -416,7 +420,11 @@ def ConvertControlFlowToSPIRVPass : Pass<"convert-cf-to-spirv"> {
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
"bool", /*default=*/"true",
"Emulate narrower scalar types with 32-bit ones if not supported by"
- " the target">
+ " the target">,
+ Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
+ "bool", /*default=*/"true",
+ "Emulate unsupported float types by representing them with integer "
+           "types of the same bit width">
];
}
@@ -500,7 +508,11 @@ def ConvertFuncToSPIRVPass : Pass<"convert-func-to-spirv"> {
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
"bool", /*default=*/"true",
"Emulate narrower scalar types with 32-bit ones if not supported by"
- " the target">
+ " the target">,
+ Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
+ "bool", /*default=*/"true",
+ "Emulate unsupported float types by representing them with integer "
+           "types of the same bit width">
];
}
@@ -1167,7 +1179,11 @@ def ConvertTensorToSPIRVPass : Pass<"convert-tensor-to-spirv"> {
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
"bool", /*default=*/"true",
"Emulate narrower scalar types with 32-bit ones if not supported by"
- " the target">
+ " the target">,
+ Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
+ "bool", /*default=*/"true",
+ "Emulate unsupported float types by representing them with integer "
+           "types of the same bit width">
];
}
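[Editor's note] A minimal sketch of what the new `emulate-unsupported-float-types` option implies, not taken from this patch: on a target without bf16 support, a bf16 value is carried as an integer of the same bit width after conversion. The function name and shape below are illustrative assumptions.
```mlir
// Hypothetical post-conversion form: the bf16 payload travels as an i16.
spirv.func @emulated(%arg0: i16) -> i16 "None" {
  spirv.ReturnValue %arg0 : i16
}
```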
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
index e81db32..06fb851 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -71,6 +71,7 @@ class ArmSME_IntrOp<string mnemonic,
/*bit requiresAccessGroup=*/0,
/*bit requiresAliasAnalysis=*/0,
/*bit requiresFastmath=*/0,
+ /*bit requiresArgAndResultAttrs=*/0,
/*bit requiresOpBundles=*/0,
/*list<int> immArgPositions=*/immArgPositions,
/*list<string> immArgAttrNames=*/immArgAttrNames>;
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index 8988df6..d055bb4 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -92,6 +92,7 @@ class ArmSVE_IntrOp<string mnemonic,
/*bit requiresAccessGroup=*/0,
/*bit requiresAliasAnalysis=*/0,
/*bit requiresFastmath=*/0,
+ /*bit requiresArgAndResultAttrs=*/0,
/*bit requiresOpBundles=*/0,
/*list<int> immArgPositions=*/immArgPositions,
/*list<string> immArgAttrNames=*/immArgAttrNames>;
diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
index a8455c2..b52f136 100644
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
@@ -38,7 +38,8 @@ def Async_ExecuteOp :
["getEntrySuccessorOperands",
"areTypesCompatible"]>,
AttrSizedOperandSegments,
- AutomaticAllocationScope]> {
+ AutomaticAllocationScope,
+ RecursiveMemoryEffects]> {
let summary = "Asynchronous execute operation";
let description = [{
The `body` region attached to the `async.execute` operation semantically
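[Editor's note] Adding `RecursiveMemoryEffects` lets the effects of `async.execute` be derived from its body, so an execute whose body is side-effect free becomes eligible for dead-code elimination when its results are unused. A hedged sketch, not from this patch:
```mlir
// Body has no side effects, so the op can be erased if %token is unused.
%token = async.execute {
  async.yield
}
```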
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td
index b5ea8fc..107bf3e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td
@@ -83,6 +83,9 @@ def LLVM_Dialect : Dialect {
return "llvm.emit_c_interface";
}
+ /// Name of the module level assembly attribute.
+ static StringRef getModuleLevelAsmAttrName() { return "llvm.module_asm"; }
+
/// Name of the dependent libraries attribute.
static StringRef getDependentLibrariesAttrName() {
return "llvm.dependent_libraries";
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index 8c6f1ee..d38298f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -140,8 +140,8 @@ def LLVM_Log2Op : LLVM_UnaryIntrOpF<"log2">;
def LLVM_LogOp : LLVM_UnaryIntrOpF<"log">;
def LLVM_Prefetch : LLVM_ZeroResultIntrOp<"prefetch", [0],
/*traits=*/[], /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
- /*requiresOpBundles=*/0, /*immArgPositions=*/[1, 2, 3],
- /*immArgAttrNames=*/["rw", "hint", "cache"]
+ /*requiresArgAndResultAttrs=*/0, /*requiresOpBundles=*/0,
+ /*immArgPositions=*/[1, 2, 3], /*immArgAttrNames=*/["rw", "hint", "cache"]
> {
let arguments = (ins LLVM_AnyPointer:$addr, I32Attr:$rw, I32Attr:$hint, I32Attr:$cache);
}
@@ -200,13 +200,13 @@ class LLVM_MemcpyIntrOpBase<string name> :
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>],
/*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1,
- /*requiresOpBundles=*/0, /*immArgPositions=*/[3],
- /*immArgAttrNames=*/["isVolatile"]> {
+ /*requiresArgAndResultAttrs=*/1, /*requiresOpBundles=*/0,
+ /*immArgPositions=*/[3], /*immArgAttrNames=*/["isVolatile"]> {
dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst,
Arg<LLVM_AnyPointer,"",[MemRead]>:$src,
AnySignlessInteger:$len, I1Attr:$isVolatile);
- // Append the alias attributes defined by LLVM_IntrOpBase.
- let arguments = !con(args, aliasAttrs);
+ // Append the arguments defined by LLVM_IntrOpBase.
+ let arguments = !con(args, baseArgs);
let builders = [
OpBuilder<(ins "Value":$dst, "Value":$src, "Value":$len,
"bool":$isVolatile), [{
@@ -217,7 +217,8 @@ class LLVM_MemcpyIntrOpBase<string name> :
"IntegerAttr":$isVolatile), [{
build($_builder, $_state, dst, src, len, isVolatile,
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
- /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
+ /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr,
+ /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr);
}]>
];
}
@@ -231,13 +232,13 @@ def LLVM_MemcpyInlineOp :
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>],
/*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1,
- /*requiresOpBundles=*/0, /*immArgPositions=*/[2, 3],
- /*immArgAttrNames=*/["len", "isVolatile"]> {
+ /*requiresArgAndResultAttrs=*/1, /*requiresOpBundles=*/0,
+ /*immArgPositions=*/[2, 3], /*immArgAttrNames=*/["len", "isVolatile"]> {
dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst,
Arg<LLVM_AnyPointer,"",[MemRead]>:$src,
APIntAttr:$len, I1Attr:$isVolatile);
- // Append the alias attributes defined by LLVM_IntrOpBase.
- let arguments = !con(args, aliasAttrs);
+ // Append the arguments defined by LLVM_IntrOpBase.
+ let arguments = !con(args, baseArgs);
let builders = [
OpBuilder<(ins "Value":$dst, "Value":$src, "IntegerAttr":$len,
"bool":$isVolatile), [{
@@ -248,7 +249,8 @@ def LLVM_MemcpyInlineOp :
"IntegerAttr":$isVolatile), [{
build($_builder, $_state, dst, src, len, isVolatile,
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
- /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
+ /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr,
+ /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr);
}]>
];
}
@@ -258,12 +260,12 @@ def LLVM_MemsetOp : LLVM_ZeroResultIntrOp<"memset", [0, 2],
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>],
/*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1,
- /*requiresOpBundles=*/0, /*immArgPositions=*/[3],
- /*immArgAttrNames=*/["isVolatile"]> {
+ /*requiresArgAndResultAttrs=*/1, /*requiresOpBundles=*/0,
+ /*immArgPositions=*/[3], /*immArgAttrNames=*/["isVolatile"]> {
dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst,
I8:$val, AnySignlessInteger:$len, I1Attr:$isVolatile);
- // Append the alias attributes defined by LLVM_IntrOpBase.
- let arguments = !con(args, aliasAttrs);
+ // Append the arguments defined by LLVM_IntrOpBase.
+ let arguments = !con(args, baseArgs);
let builders = [
OpBuilder<(ins "Value":$dst, "Value":$val, "Value":$len,
"bool":$isVolatile), [{
@@ -274,7 +276,8 @@ def LLVM_MemsetOp : LLVM_ZeroResultIntrOp<"memset", [0, 2],
"IntegerAttr":$isVolatile), [{
build($_builder, $_state, dst, val, len, isVolatile,
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
- /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
+ /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr,
+ /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr);
}]>
];
}
@@ -284,12 +287,12 @@ def LLVM_MemsetInlineOp : LLVM_ZeroResultIntrOp<"memset.inline", [0, 2],
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>],
/*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1,
- /*requiresOpBundles=*/0, /*immArgPositions=*/[2, 3],
- /*immArgAttrNames=*/["len", "isVolatile"]> {
+ /*requiresArgAndResultAttrs=*/1, /*requiresOpBundles=*/0,
+ /*immArgPositions=*/[2, 3], /*immArgAttrNames=*/["len", "isVolatile"]> {
dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst,
I8:$val, APIntAttr:$len, I1Attr:$isVolatile);
- // Append the alias attributes defined by LLVM_IntrOpBase.
- let arguments = !con(args, aliasAttrs);
+ // Append the arguments defined by LLVM_IntrOpBase.
+ let arguments = !con(args, baseArgs);
let builders = [
OpBuilder<(ins "Value":$dst, "Value":$val, "IntegerAttr":$len,
"bool":$isVolatile), [{
@@ -300,7 +303,8 @@ def LLVM_MemsetInlineOp : LLVM_ZeroResultIntrOp<"memset.inline", [0, 2],
"IntegerAttr":$isVolatile), [{
build($_builder, $_state, dst, val, len, isVolatile,
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
- /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
+ /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr,
+ /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr);
}]>
];
}
@@ -349,8 +353,8 @@ def LLVM_PtrMaskOp
class LLVM_LifetimeBaseOp<string opName> : LLVM_ZeroResultIntrOp<opName, [1],
[DeclareOpInterfaceMethods<PromotableOpInterface>],
/*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
- /*requiresOpBundles=*/0, /*immArgPositions=*/[0],
- /*immArgAttrNames=*/["size"]> {
+ /*requiresArgAndResultAttrs=*/0, /*requiresOpBundles=*/0,
+ /*immArgPositions=*/[0], /*immArgAttrNames=*/["size"]> {
let arguments = (ins I64Attr:$size, LLVM_AnyPointer:$ptr);
let assemblyFormat = "$size `,` $ptr attr-dict `:` qualified(type($ptr))";
}
@@ -370,8 +374,8 @@ def LLVM_InvariantStartOp : LLVM_OneResultIntrOp<"invariant.start", [], [1],
def LLVM_InvariantEndOp : LLVM_ZeroResultIntrOp<"invariant.end", [2],
[DeclareOpInterfaceMethods<PromotableOpInterface>],
/*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
- /*requiresOpBundles=*/0, /*immArgPositions=*/[1],
- /*immArgAttrNames=*/["size"]> {
+ /*requiresArgAndResultAttrs=*/0, /*requiresOpBundles=*/0,
+ /*immArgPositions=*/[1], /*immArgAttrNames=*/["size"]> {
let arguments = (ins LLVM_DefaultPointer:$start,
I64Attr:$size,
LLVM_AnyPointer:$ptr);
@@ -542,9 +546,10 @@ def LLVM_AssumeOp
: LLVM_ZeroResultIntrOp<"assume", /*overloadedOperands=*/[], /*traits=*/[],
/*requiresAccessGroup=*/0,
/*requiresAliasAnalysis=*/0,
+ /*requiresArgAndResultAttrs=*/0,
/*requiresOpBundles=*/1> {
dag args = (ins I1:$cond);
- let arguments = !con(args, opBundleArgs);
+ let arguments = !con(args, baseArgs);
let assemblyFormat = [{
$cond
@@ -1126,8 +1131,8 @@ def LLVM_DebugTrap : LLVM_ZeroResultIntrOp<"debugtrap">;
def LLVM_UBSanTrap : LLVM_ZeroResultIntrOp<"ubsantrap",
/*overloadedOperands=*/[], /*traits=*/[],
/*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
- /*requiresOpBundles=*/0, /*immArgPositions=*/[0],
- /*immArgAttrNames=*/["failureKind"]> {
+ /*requiresArgAndResultAttrs=*/0, /*requiresOpBundles=*/0,
+ /*immArgPositions=*/[0], /*immArgAttrNames=*/["failureKind"]> {
let arguments = (ins I8Attr:$failureKind);
}
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index e845ea9f..a8d7cf2 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -18,6 +18,7 @@ include "mlir/Dialect/LLVMIR/LLVMAttrDefs.td"
include "mlir/Dialect/LLVMIR/LLVMInterfaces.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/CallInterfaces.td"
//===----------------------------------------------------------------------===//
// LLVM dialect type constraints.
@@ -286,22 +287,26 @@ class LLVM_MemAccessOpBase<string mnemonic, list<Trait> traits = []> :
// intrinsic and "enumName" contains the name of the intrinsic as appears in
// `llvm::Intrinsic` enum; one usually wants these to be related. Additionally,
// the base class also defines the "mlirBuilder" field to support the inverse
-// translation starting from an LLVM IR intrinsic. The "requiresAccessGroup",
-// "requiresAliasAnalysis", and "requiresFastmath" flags specify which
-// interfaces the intrinsic implements. If the corresponding flags are set, the
-// "aliasAttrs" list contains the arguments required by the access group and
-// alias analysis interfaces. Derived intrinsics should append the "aliasAttrs"
-// to their argument list if they set one of the flags. LLVM `immargs` can be
-// represented as MLIR attributes by providing both the `immArgPositions` and
-// `immArgAttrNames` lists. These two lists should have equal length, with
-// `immArgPositions` containing the argument positions on the LLVM IR attribute
-// that are `immargs`, and `immArgAttrNames` mapping these to corresponding
-// MLIR attributes.
+// translation starting from an LLVM IR intrinsic.
+//
+// The flags "requiresAccessGroup", "requiresAliasAnalysis",
+// "requiresFastmath", and "requiresArgAndResultAttrs" indicate which
+// interfaces the intrinsic implements. When a flag is set, the "baseArgs"
+// list includes the arguments required by the corresponding interface.
+// Derived intrinsics must append "baseArgs" to their argument list if they
+// enable any of these flags.
+//
+// LLVM `immargs` can be represented as MLIR attributes by providing both
+// the `immArgPositions` and `immArgAttrNames` lists. These two lists should
+// have equal length, with `immArgPositions` containing the argument
+// positions on the LLVM IR attribute that are `immargs`, and
+// `immArgAttrNames` mapping these to corresponding MLIR attributes.
class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
list<int> overloadedResults, list<int> overloadedOperands,
list<Trait> traits, int numResults,
bit requiresAccessGroup = 0, bit requiresAliasAnalysis = 0,
- bit requiresFastmath = 0, bit requiresOpBundles = 0,
+ bit requiresFastmath = 0, bit requiresArgAndResultAttrs = 0,
+ bit requiresOpBundles = 0,
list<int> immArgPositions = [],
list<string> immArgAttrNames = []>
: LLVM_OpBase<dialect, opName, !listconcat(
@@ -311,10 +316,12 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
[DeclareOpInterfaceMethods<AliasAnalysisOpInterface>], []),
!if(!gt(requiresFastmath, 0),
[DeclareOpInterfaceMethods<FastmathFlagsInterface>], []),
+ !if(!gt(requiresArgAndResultAttrs, 0),
+ [DeclareOpInterfaceMethods<ArgAndResultAttrsOpInterface>], []),
traits)>,
LLVM_MemOpPatterns,
Results<!if(!gt(numResults, 0), (outs LLVM_Type:$res), (outs))> {
- dag aliasAttrs = !con(
+ dag baseArgs = !con(
!if(!gt(requiresAccessGroup, 0),
(ins OptionalAttr<LLVM_AccessGroupArrayAttr>:$access_groups),
(ins )),
@@ -322,13 +329,17 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
(ins OptionalAttr<LLVM_AliasScopeArrayAttr>:$alias_scopes,
OptionalAttr<LLVM_AliasScopeArrayAttr>:$noalias_scopes,
OptionalAttr<LLVM_TBAATagArrayAttr>:$tbaa),
+ (ins )),
+ !if(!gt(requiresArgAndResultAttrs, 0),
+ (ins OptionalAttr<DictArrayAttr>:$arg_attrs,
+ OptionalAttr<DictArrayAttr>:$res_attrs),
+ (ins )),
+ !if(!gt(requiresOpBundles, 0),
+ (ins VariadicOfVariadic<LLVM_Type,
+ "op_bundle_sizes">:$op_bundle_operands,
+ DenseI32ArrayAttr:$op_bundle_sizes,
+ OptionalAttr<ArrayAttr>:$op_bundle_tags),
(ins )));
- dag opBundleArgs = !if(!gt(requiresOpBundles, 0),
- (ins VariadicOfVariadic<LLVM_Type,
- "op_bundle_sizes">:$op_bundle_operands,
- DenseI32ArrayAttr:$op_bundle_sizes,
- OptionalAttr<ArrayAttr>:$op_bundle_tags),
- (ins ));
string llvmEnumName = enumName;
string overloadedResultsCpp = "{" # !interleave(overloadedResults, ", ") # "}";
string overloadedOperandsCpp = "{" # !interleave(overloadedOperands, ", ") # "}";
@@ -342,23 +353,35 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
immArgPositionsCpp, immArgAttrNamesCpp], ",") # [{);
(void) inst;
}];
+ string baseLlvmBuilderArgAndResultAttrs = [{
+ if (failed(moduleTranslation.convertArgAndResultAttrs(
+ op,
+ inst,
+ }] # immArgPositionsCpp # [{))) {
+ return failure();
+ }
+ }];
string baseLlvmBuilderCoda = !if(!gt(numResults, 0), "$res = inst;", "");
- let llvmBuilder = baseLlvmBuilder # !if(!gt(requiresAccessGroup, 0), setAccessGroupsMetadataCode, "")
- # !if(!gt(requiresAliasAnalysis, 0), setAliasAnalysisMetadataCode, "")
- # baseLlvmBuilderCoda;
+ let llvmBuilder = baseLlvmBuilder
+ # !if(!gt(requiresAccessGroup, 0),
+ setAccessGroupsMetadataCode, "")
+ # !if(!gt(requiresAliasAnalysis, 0),
+ setAliasAnalysisMetadataCode, "")
+ # !if(!gt(requiresArgAndResultAttrs, 0),
+ baseLlvmBuilderArgAndResultAttrs, "")
+ # baseLlvmBuilderCoda;
string baseMlirBuilder = [{
SmallVector<Value> mlirOperands;
SmallVector<NamedAttribute> mlirAttrs;
if (failed(moduleImport.convertIntrinsicArguments(
- llvmOperands,
- llvmOpBundles,
- }] # !if(!gt(requiresOpBundles, 0), "true", "false") # [{,
- }] # immArgPositionsCpp # [{,
- }] # immArgAttrNamesCpp # [{,
- mlirOperands,
- mlirAttrs))
- ) {
+ llvmOperands,
+ llvmOpBundles,
+ }] # !if(!gt(requiresOpBundles, 0), "true", "false") # [{,
+ }] # immArgPositionsCpp # [{,
+ }] # immArgAttrNamesCpp # [{,
+ mlirOperands,
+ mlirAttrs))) {
return failure();
}
SmallVector<Type> resultTypes =
@@ -366,9 +389,16 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
auto op = $_qualCppClassName::create($_builder,
$_location, resultTypes, mlirOperands, mlirAttrs);
}];
+ string baseMlirBuilderArgAndResultAttrs = [{
+ moduleImport.convertArgAndResultAttrs(
+ inst, op, }] # immArgPositionsCpp # [{);
+ }];
string baseMlirBuilderCoda = !if(!gt(numResults, 0), "$res = op;", "$_op = op;");
- let mlirBuilder = baseMlirBuilder # !if(!gt(requiresFastmath, 0),
+ let mlirBuilder = baseMlirBuilder
+ # !if(!gt(requiresFastmath, 0),
"moduleImport.setFastmathFlagsAttr(inst, op);", "")
+ # !if(!gt(requiresArgAndResultAttrs, 0),
+ baseMlirBuilderArgAndResultAttrs, "")
# baseMlirBuilderCoda;
// Code for handling a `range` attribute that holds the constant range of the
@@ -399,14 +429,14 @@ class LLVM_IntrOp<string mnem, list<int> overloadedResults,
list<int> overloadedOperands, list<Trait> traits,
int numResults, bit requiresAccessGroup = 0,
bit requiresAliasAnalysis = 0, bit requiresFastmath = 0,
- bit requiresOpBundles = 0,
+ bit requiresArgAndResultAttrs = 0, bit requiresOpBundles = 0,
list<int> immArgPositions = [],
list<string> immArgAttrNames = []>
: LLVM_IntrOpBase<LLVM_Dialect, "intr." # mnem, !subst(".", "_", mnem),
overloadedResults, overloadedOperands, traits,
numResults, requiresAccessGroup, requiresAliasAnalysis,
- requiresFastmath, requiresOpBundles, immArgPositions,
- immArgAttrNames>;
+ requiresFastmath, requiresArgAndResultAttrs,
+ requiresOpBundles, immArgPositions, immArgAttrNames>;
// Base class for LLVM intrinsic operations returning no results. Places the
// intrinsic into the LLVM dialect and prefixes its name with "intr.".
@@ -426,13 +456,14 @@ class LLVM_ZeroResultIntrOp<string mnem, list<int> overloadedOperands = [],
list<Trait> traits = [],
bit requiresAccessGroup = 0,
bit requiresAliasAnalysis = 0,
+ bit requiresArgAndResultAttrs = 0,
bit requiresOpBundles = 0,
list<int> immArgPositions = [],
list<string> immArgAttrNames = []>
: LLVM_IntrOp<mnem, [], overloadedOperands, traits, /*numResults=*/0,
requiresAccessGroup, requiresAliasAnalysis,
- /*requiresFastMath=*/0, requiresOpBundles, immArgPositions,
- immArgAttrNames>;
+ /*requiresFastMath=*/0, requiresArgAndResultAttrs,
+ requiresOpBundles, immArgPositions, immArgAttrNames>;
// Base class for LLVM intrinsic operations returning one result. Places the
// intrinsic into the LLVM dialect and prefixes its name with "intr.". This is
@@ -448,7 +479,8 @@ class LLVM_OneResultIntrOp<string mnem, list<int> overloadedResults = [],
list<string> immArgAttrNames = []>
: LLVM_IntrOp<mnem, overloadedResults, overloadedOperands, traits, 1,
/*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
- requiresFastmath, /*requiresOpBundles=*/0, immArgPositions,
+ requiresFastmath, /*requiresArgAndResultAttrs=*/0,
+ /*requiresOpBundles=*/0, immArgPositions,
immArgAttrNames>;
// Base class for LLVM intrinsic operations returning two results. Places the
@@ -465,7 +497,8 @@ class LLVM_TwoResultIntrOp<string mnem, list<int> overloadedResults = [],
list<string> immArgAttrNames = []>
: LLVM_IntrOp<mnem, overloadedResults, overloadedOperands, traits, 2,
/*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
- requiresFastmath, /*requiresOpBundles=*/0, immArgPositions,
+ requiresFastmath, /*requiresArgAndResultAttrs=*/0,
+ /*requiresOpBundles=*/0, immArgPositions,
immArgAttrNames>;
def LLVM_OneResultOpBuilder :
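[Editor's note] A sketch of what the new `requiresArgAndResultAttrs` flag enables on intrinsics such as `llvm.intr.memcpy`: optional per-argument and per-result attribute dictionaries. The generic form and the specific argument attributes below are illustrative assumptions, not copied from this patch.
```mlir
llvm.func @copy(%dst: !llvm.ptr, %src: !llvm.ptr, %len: i64) {
  // arg_attrs carries one dictionary per operand; res_attrs would mirror it
  // for results.
  "llvm.intr.memcpy"(%dst, %src, %len) <{
    isVolatile = false,
    arg_attrs = [{llvm.noalias}, {llvm.noalias}, {}]
  }> : (!llvm.ptr, !llvm.ptr, i64) -> ()
  llvm.return
}
```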
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 51004f5..3f27f6d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -2405,7 +2405,8 @@ def LLVM_InlineAsmOp : LLVM_Op<"inline_asm", [DeclareOpInterfaceMethods<MemoryEf
def LLVM_CallIntrinsicOp
: LLVM_Op<"call_intrinsic",
- [AttrSizedOperandSegments,
+ [ArgAndResultAttrsOpInterface,
+ AttrSizedOperandSegments,
DeclareOpInterfaceMethods<FastmathFlagsInterface>]> {
let summary = "Call to an LLVM intrinsic function.";
let description = [{
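[Editor's note] With `ArgAndResultAttrsOpInterface` on `llvm.call_intrinsic`, argument attributes can ride on the call as well. A hedged sketch; the intrinsic and attribute names are illustrative assumptions:
```mlir
llvm.func @pow(%v: f32) -> f32 {
  %r = llvm.call_intrinsic "llvm.pow.f32"(%v, %v)
      {arg_attrs = [{llvm.noundef}, {llvm.noundef}]} : (f32, f32) -> f32
  llvm.return %r : f32
}
```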
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 04a0b58..a2354e2 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -98,7 +98,7 @@ class ROCDL_IntrOp<string mnemonic, list<int> overloadedResults,
LLVM_IntrOpBase<ROCDL_Dialect, mnemonic,
"amdgcn_" # !subst(".", "_", mnemonic), overloadedResults,
overloadedOperands, traits, numResults, requiresAccessGroup,
- requiresAliasAnalysis, 0, 0, immArgPositions, immArgAttrNames>;
+ requiresAliasAnalysis, 0, 0, 0, immArgPositions, immArgAttrNames>;
// Subclass to save typing and ease readability when there aren't overloaded
// operands or memory accesses.
@@ -482,7 +482,7 @@ def ROCDLBufferLDS : LLVM_PointerInAddressSpace<3>;
class ROCDL_LDS_Read_Tr_IntrOp<string mnemonic> :
ROCDL_IntrOp<mnemonic, [1], [], [], 1, 0, 1> {
dag args = (ins Arg<ROCDLBufferLDS, "", [MemRead]>:$ptr);
- let arguments = !con(args, aliasAttrs);
+ let arguments = !con(args, baseArgs);
let assemblyFormat = "$ptr attr-dict `:` type($ptr) `->` type($res)";
let extraClassDefinition = [{
::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() {
@@ -507,7 +507,7 @@ def ROCDL_LoadToLDSOp :
I32Attr:$size,
I32Attr:$offset,
I32Attr:$aux);
- let arguments = !con(args, aliasAttrs);
+ let arguments = !con(args, baseArgs);
let assemblyFormat = [{
$globalPtr `,` $ldsPtr `,` $size `,` $offset `,` $aux
attr-dict `:` type($globalPtr)
@@ -526,7 +526,7 @@ def ROCDL_GlobalLoadLDSOp :
I32Attr:$size,
I32Attr:$offset,
I32Attr:$aux);
- let arguments = !con(args, aliasAttrs);
+ let arguments = !con(args, baseArgs);
let assemblyFormat = [{
$globalPtr `,` $ldsPtr `,` $size `,` $offset `,` $aux
attr-dict
@@ -561,7 +561,7 @@ def ROCDL_RawPtrBufferLoadOp :
I32:$offset,
I32:$soffset,
I32:$aux);
- let arguments = !con(args, aliasAttrs);
+ let arguments = !con(args, baseArgs);
let assemblyFormat = "operands attr-dict `:` type($res)";
let extraClassDefinition = [{
::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() {
@@ -579,7 +579,7 @@ def ROCDL_RawPtrBufferLoadLdsOp :
I32:$soffset,
I32:$offset,
I32:$aux);
- let arguments = !con(args, aliasAttrs);
+ let arguments = !con(args, baseArgs);
let assemblyFormat = "operands attr-dict";
let extraClassDefinition = [{
::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() {
@@ -595,7 +595,7 @@ def ROCDL_RawPtrBufferStoreOp :
I32:$offset,
I32:$soffset,
I32:$aux);
- let arguments = !con(args, aliasAttrs);
+ let arguments = !con(args, baseArgs);
let assemblyFormat = "operands attr-dict `:` type($vdata)";
let extraClassDefinition = [{
::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() {
@@ -614,7 +614,7 @@ def ROCDL_RawPtrBufferAtomicCmpSwap :
I32:$offset,
I32:$soffset,
I32:$aux);
- let arguments = !con(args, aliasAttrs);
+ let arguments = !con(args, baseArgs);
let assemblyFormat = "operands attr-dict `:` type($res)";
let extraClassDefinition = [{
::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() {
@@ -630,7 +630,7 @@ class ROCDL_RawPtrBufferAtomicNoRet<string op> :
I32:$offset,
I32:$soffset,
I32:$aux);
- let arguments = !con(args, aliasAttrs);
+ let arguments = !con(args, baseArgs);
let assemblyFormat = "operands attr-dict `:` type($vdata)";
let extraClassDefinition = [{
::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() {
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 8d45c40..61ce23f 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1191,6 +1191,7 @@ def PadTilingInterfaceOp : Op<Transform_Dialect, "structured.pad_tiling_interfac
iteration domain induces a padding of the operands that is consistent
across the op semantics and, unlike for simple elementwise ops, may not be
trivially deducible or specifiable on operands only (e.g. convolutions).
+ Currently, only a limited set of projected permutation maps are supported.
The specification of `padding_sizes` follows that of `tile_sizes` during
    tiling: the value "0" on a particular iterator encodes "no padding". Like in
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index e625eef..d4ffe0a 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -611,6 +611,13 @@ LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
/// affine.apply operations.
/// The `indexingMap` + `indexingSizes` encoding suits StructuredOps and
/// provides a gentle portability path for Linalg-like ops with affine maps.
+/// The padded shape is computed by evaluating the maximum accessed index per
+/// dimension, which may involve multiplying by constant factors derived from
+/// the affine indexing expressions. Currently, only a limited set of projected
+/// permutation indexing maps are supported, such as
+/// - affine_map<(d0, d1, d2) -> (d0, d1)>
+/// - affine_map<(d0, d1, d2) -> (d0, d1 + d2)>
+/// - affine_map<(d0, d1) -> (d0 * 3 + d1)>
/// In the future, more general interfaces can be devised to encode similar
/// shape evolutions and map between an op and its operands.
SmallVector<OpFoldResult>
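[Editor's note] A worked instance of the documented rule, with illustrative numbers: the padded operand size follows from the maximum index the affine expression can access.
```mlir
// Illustrative arithmetic only:
//   map:   affine_map<(d0, d1) -> (d0 * 3 + d1)>
//   padded iteration sizes: d0 = 4, d1 = 2
//   maximum accessed index = (4 - 1) * 3 + (2 - 1) = 10
//   padded operand size    = 10 + 1 = 11
```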
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 96b9adc..e1e99c3 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -134,6 +134,24 @@ def OpenACC_VariableTypeCategory : I32BitEnumAttr<
let printBitEnumPrimaryGroups = 1;
}
+// These are parallelism determination modes for `acc loop`.
+// In the enum names, we use the "loop_" prefix because "auto" is
+// a language keyword - and thus for consistency all other cases
+// do the same.
+def OpenACC_LoopSeq : I32EnumAttrCase<"loop_seq", 0>;
+def OpenACC_LoopAuto : I32EnumAttrCase<"loop_auto", 1>;
+def OpenACC_LoopIndependent : I32EnumAttrCase<"loop_independent", 2>;
+
+def OpenACC_LoopParMode : I32EnumAttr<
+ "LoopParMode",
+ "Encodes the options for loop parallelism determination mode",
+ [
+ OpenACC_LoopAuto, OpenACC_LoopIndependent,
+ OpenACC_LoopSeq]> {
+ let cppNamespace = "::mlir::acc";
+ let genSpecializedAttr = 0;
+}
+
// Type used in operation below.
def IntOrIndex : AnyTypeOf<[AnyInteger, Index]>;
@@ -2373,6 +2391,11 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
// Return whether this LoopOp has a gang, worker, or vector applying to the
// 'default'/None device-type.
bool hasDefaultGangWorkerVector();
+
+ // Used to obtain the parallelism mode for the requested device type.
+    // This first checks if the mode is set for the requested device_type;
+    // if not, it returns the mode set without a device_type (None).
+ LoopParMode getDefaultOrDeviceTypeParallelism(DeviceType);
}];
let hasCustomAssemblyFormat = 1;
@@ -2404,6 +2427,53 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
}];
let hasVerifier = 1;
+
+ let builders = [
+ OpBuilder<(ins "::mlir::ValueRange":$lowerbounds,
+ "::mlir::ValueRange":$upperbounds,
+ "::mlir::ValueRange":$steps,
+ "LoopParMode":$parMode), [{
+ auto deviceNoneAttr = mlir::acc::DeviceTypeAttr::get(
+ $_builder.getContext(), mlir::acc::DeviceType::None);
+ auto arrOfDeviceNone = mlir::ArrayAttr::get(
+ $_builder.getContext(), deviceNoneAttr);
+ build($_builder, $_state,
+ /*results=*/{},
+ /*lowerbound=*/lowerbounds,
+ /*upperbound=*/upperbounds,
+ /*step=*/steps,
+ /*inclusiveUpperbound=*/nullptr,
+ /*collapse=*/nullptr,
+ /*collapseDeviceType=*/nullptr,
+ /*gangOperands=*/{},
+ /*gangOperandsArgType=*/nullptr,
+ /*gangOperandsSegments=*/nullptr,
+ /*gangOperandsDeviceType=*/nullptr,
+ /*workerNumOperands=*/{},
+ /*workerNumOperandsDeviceType=*/nullptr,
+ /*vectorOperands=*/{},
+ /*vectorOperandsDeviceType=*/nullptr,
+ /*seq=*/parMode == LoopParMode::loop_seq ?
+ arrOfDeviceNone : nullptr,
+ /*independent=*/parMode == LoopParMode::loop_independent ?
+ arrOfDeviceNone : nullptr,
+ /*auto_=*/parMode == LoopParMode::loop_auto ?
+ arrOfDeviceNone : nullptr,
+ /*gang=*/nullptr,
+ /*worker=*/nullptr,
+ /*vector=*/nullptr,
+ /*tileOperands=*/{},
+ /*tileOperandsSegments=*/nullptr,
+ /*tileOperandsDeviceType=*/nullptr,
+ /*cacheOperands=*/{},
+ /*privateOperands=*/{},
+ /*privatizationRecipes=*/nullptr,
+ /*reductionOperands=*/{},
+ /*reductionRecipes=*/nullptr,
+ /*combined=*/nullptr);
+ }]
+ >
+ ];
}
// Yield operation for the acc.loop and acc.parallel operations.
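[Editor's note] Roughly what the new builder produces for `LoopParMode::loop_seq`; the `acc.loop` syntax below is abbreviated and illustrative, with the `seq` attribute pinned to the None device type.
```mlir
func.func @seq_loop(%lb: index, %ub: index, %step: index) {
  // Built with the LoopParMode::loop_seq overload (sketch).
  acc.loop control(%i : index) = (%lb : index) to (%ub : index) step (%step : index) {
    acc.yield
  } attributes {seq = [#acc.device_type<none>]}
  return
}
```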
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 2d15544..0c1c15b 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -87,6 +87,9 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
be accessed inside the op. The op's region can have multiple blocks and the
blocks can have multiple distinct terminators. Values returned from this op's
region define the op's results.
+    The optional 'no_inline' flag can be set to request that the ExecuteRegionOp
+    be preserved as much as possible and not inlined into the parent block until
+    an explicit lowering step.
Example:
@@ -98,6 +101,14 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
}
}
+    // The same as above, but with the no_inline attribute.
+ scf.for %i = 0 to 128 step %c1 {
+ %y = scf.execute_region -> i32 no_inline {
+ %x = load %A[%i] : memref<128xi32>
+ scf.yield %x : i32
+ }
+ }
+
affine.for %i = 0 to 100 {
"foo"() : () -> ()
%v = scf.execute_region -> i64 {
@@ -119,6 +130,10 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
```
}];
+ let arguments = (ins
+ UnitAttr:$no_inline
+ );
+
let results = (outs Variadic<AnyType>);
let regions = (region AnyRegion:$region);
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 9038326..9c74cff0 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4448,6 +4448,7 @@ def SPIRV_OC_OpUMulExtended : I32EnumAttrCase<"OpUMulExtended"
def SPIRV_OC_OpSMulExtended : I32EnumAttrCase<"OpSMulExtended", 152>;
def SPIRV_OC_OpIsNan : I32EnumAttrCase<"OpIsNan", 156>;
def SPIRV_OC_OpIsInf : I32EnumAttrCase<"OpIsInf", 157>;
+def SPIRV_OC_OpIsFinite : I32EnumAttrCase<"OpIsFinite", 158>;
def SPIRV_OC_OpOrdered : I32EnumAttrCase<"OpOrdered", 162>;
def SPIRV_OC_OpUnordered : I32EnumAttrCase<"OpUnordered", 163>;
def SPIRV_OC_OpLogicalEqual : I32EnumAttrCase<"OpLogicalEqual", 164>;
@@ -4630,7 +4631,8 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpVectorTimesMatrix, SPIRV_OC_OpMatrixTimesVector,
SPIRV_OC_OpMatrixTimesMatrix, SPIRV_OC_OpDot, SPIRV_OC_OpIAddCarry,
SPIRV_OC_OpISubBorrow, SPIRV_OC_OpUMulExtended, SPIRV_OC_OpSMulExtended,
- SPIRV_OC_OpIsNan, SPIRV_OC_OpIsInf, SPIRV_OC_OpOrdered, SPIRV_OC_OpUnordered,
+ SPIRV_OC_OpIsNan, SPIRV_OC_OpIsInf, SPIRV_OC_OpIsFinite,
+ SPIRV_OC_OpOrdered, SPIRV_OC_OpUnordered,
SPIRV_OC_OpLogicalEqual, SPIRV_OC_OpLogicalNotEqual, SPIRV_OC_OpLogicalOr,
SPIRV_OC_OpLogicalAnd, SPIRV_OC_OpLogicalNot, SPIRV_OC_OpSelect,
SPIRV_OC_OpIEqual, SPIRV_OC_OpINotEqual, SPIRV_OC_OpUGreaterThan,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
index ab535d7..9331fc5 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
@@ -403,6 +403,28 @@ def SPIRV_INotEqualOp : SPIRV_LogicalBinaryOp<"INotEqual",
// -----
+def SPIRV_IsFiniteOp : SPIRV_LogicalUnaryOp<"IsFinite", SPIRV_Float, []> {
+ let summary = "Result is true if x is an IEEE Finite, otherwise result is false";
+
+ let description = [{
+ Result Type must be a scalar or vector of Boolean type.
+
+ x must be a scalar or vector of floating-point type. It must have the
+ same number of components as Result Type.
+
+ Results are computed per component.
+
+ #### Example:
+
+ ```mlir
+ %2 = spirv.IsFinite %0: f32
+ %3 = spirv.IsFinite %1: vector<4xf32>
+ ```
+ }];
+}
+
+// -----
+
def SPIRV_IsInfOp : SPIRV_LogicalUnaryOp<"IsInf", SPIRV_Float, []> {
let summary = "Result is true if x is an IEEE Inf, otherwise result is false";
@@ -418,7 +440,7 @@ def SPIRV_IsInfOp : SPIRV_LogicalUnaryOp<"IsInf", SPIRV_Float, []> {
```mlir
%2 = spirv.IsInf %0: f32
- %3 = spirv.IsInf %1: vector<4xi32>
+ %3 = spirv.IsInf %1: vector<4xf32>
```
}];
}
@@ -442,7 +464,7 @@ def SPIRV_IsNanOp : SPIRV_LogicalUnaryOp<"IsNan", SPIRV_Float, []> {
```mlir
%2 = spirv.IsNan %0: f32
- %3 = spirv.IsNan %1: vector<4xi32>
+ %3 = spirv.IsNan %1: vector<4xf32>
```
}];
}
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index c691d59..531fecc 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -330,10 +330,34 @@ public:
bool hasValue() const { return !isa<UnitAttr>(decorationValue); }
};
+ // Type for specifying the decoration(s) on the struct itself.
+ struct StructDecorationInfo {
+ Decoration decoration;
+ Attribute decorationValue;
+
+ StructDecorationInfo(Decoration decoration, Attribute decorationValue)
+ : decoration(decoration), decorationValue(decorationValue) {}
+
+ friend bool operator==(const StructDecorationInfo &lhs,
+ const StructDecorationInfo &rhs) {
+ return lhs.decoration == rhs.decoration &&
+ lhs.decorationValue == rhs.decorationValue;
+ }
+
+ friend bool operator<(const StructDecorationInfo &lhs,
+ const StructDecorationInfo &rhs) {
+ return llvm::to_underlying(lhs.decoration) <
+ llvm::to_underlying(rhs.decoration);
+ }
+
+ bool hasValue() const { return !isa<UnitAttr>(decorationValue); }
+ };
+
/// Construct a literal StructType with at least one member.
static StructType get(ArrayRef<Type> memberTypes,
ArrayRef<OffsetInfo> offsetInfo = {},
- ArrayRef<MemberDecorationInfo> memberDecorations = {});
+ ArrayRef<MemberDecorationInfo> memberDecorations = {},
+ ArrayRef<StructDecorationInfo> structDecorations = {});
/// Construct an identified StructType. This creates a StructType whose body
/// (member types, offset info, and decorations) is not set yet. A call to
@@ -367,6 +391,9 @@ public:
bool hasOffset() const;
+ /// Returns true if the struct has a specified decoration.
+ bool hasDecoration(spirv::Decoration decoration) const;
+
uint64_t getMemberOffset(unsigned) const;
// Returns in `memberDecorations` the Decorations (apart from Offset)
@@ -380,12 +407,18 @@ public:
unsigned i,
SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const;
+ // Returns in `structDecorations` the Decorations associated with the
+ // StructType.
+ void getStructDecorations(SmallVectorImpl<StructType::StructDecorationInfo>
+ &structDecorations) const;
+
/// Sets the contents of an incomplete identified StructType. This method must
/// be called only for identified StructTypes and it must be called only once
/// per instance. Otherwise, failure() is returned.
LogicalResult
trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {},
- ArrayRef<MemberDecorationInfo> memberDecorations = {});
+ ArrayRef<MemberDecorationInfo> memberDecorations = {},
+ ArrayRef<StructDecorationInfo> structDecorations = {});
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage = std::nullopt);
@@ -396,6 +429,9 @@ public:
llvm::hash_code
hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo);
+llvm::hash_code
+hash_value(const StructType::StructDecorationInfo &structDecorationInfo);
+
// SPIR-V KHR cooperative matrix type
class CooperativeMatrixType
: public Type::TypeBase<CooperativeMatrixType, CompositeType,
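[Editor's note] The new struct-level decoration support suggests type syntax along these lines; this is a hedged sketch, and the serialization tests in this patch define the accepted form.
```mlir
// Illustrative: a Block-decorated struct with an explicit member offset.
!block_t = !spirv.struct<(vector<4xf32> [0]), Block>
```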
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 3d22ec9..03ae54a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -39,6 +39,10 @@ struct SPIRVConversionOptions {
/// The number of bits to store a boolean value.
unsigned boolNumBits{8};
+ /// Whether to emulate unsupported floats with integer types of same bit
+ /// width.
+ bool emulateUnsupportedFloatTypes{true};
+
  /// How sub-byte values are stored in memory.
SPIRVSubByteTypeStorage subByteTypeStorage{SPIRVSubByteTypeStorage::Packed};
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 3885439..5d45508 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2595,6 +2595,7 @@ def Vector_MaskOp : Vector_Op<"mask", [
def Vector_TransposeOp :
Vector_Op<"transpose", [Pure,
+ DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
PredOpTrait<"operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>]> {
@@ -2876,7 +2877,10 @@ def Vector_ScanOp :
// VectorStepOp
//===----------------------------------------------------------------------===//
-def Vector_StepOp : Vector_Op<"step", [Pure]> {
+def Vector_StepOp : Vector_Op<"step", [
+ Pure,
+ DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
+ ]> {
let summary = "A linear sequence of values from 0 to N";
let description = [{
A `step` operation produces an index vector, i.e. a 1-D vector of values of
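[Editor's note] A small sketch of what the new `InferIntRangeInterface` implementations make possible: `vector.step` is bounded by its static length, and `vector.transpose` merely permutes lanes, so it forwards its operand's range. Not taken from this patch.
```mlir
func.func @step_range() -> vector<8xindex> {
  // Result lanes are 0..7, so integer range analysis can infer [0, 7].
  %s = vector.step : vector<8xindex>
  return %s : vector<8xindex>
}
```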
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 91d6b2a..75b16a87 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -628,35 +628,71 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
As compared to prefetch_nd, which works on non-scattered TensorDesc,
it works on scattered TensorDesc instead.
- Example:
+ Example 1:
```mlir
xegpu.prefetch %tdesc {l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<cached>,
l3_hint = #xegpu.cache_hint<cached>}
: !xegpu.tensor_desc<16xf16>
```
+
+ Example 2:
+    A variant accepts a memref as the base pointer and a vector of offsets
+    instead of a scattered TensorDesc. It combines "create scattered TensorDesc"
+    and "prefetch with scattered TensorDesc". The source operand can also be a
+    raw pointer (uint64_t). Please refer to create_tdesc for the restrictions
+    on the memref.
+ ```mlir
+ %a = memref.alloc() : memref<1024xf32>
+ %0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
+ xegpu.prefetch %a[%0] {l1_hint = #xegpu.cache_hint<cached>,
+ l2_hint = #xegpu.cache_hint<cached>,
+ l3_hint = #xegpu.cache_hint<cached>}
+ : memref<1024xf32>, vector<4xindex>
+ ```
}];
- let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
+ let arguments = (ins XeGPU_GatherScatterSourceType: $source,
+ Optional<XeGPU_OffsetType>: $offsets,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
let extraClassDeclaration = extraBaseClassDeclaration # [{
+ Type getSourceType() {
+ return getSource().getType();
+ }
+
+ TypedValue<xegpu::TensorDescType> getTensorDesc() {
+ if (auto tdescType = getTensorDescType()) {
+ return llvm::cast<TypedValue<xegpu::TensorDescType>>(getSource());
+ }
+ return TypedValue<xegpu::TensorDescType>();
+ }
+
xegpu::TensorDescType getTensorDescType() {
- return getTensorDesc().getType();
+ return dyn_cast<xegpu::TensorDescType>(getSourceType());
}
}];
- let assemblyFormat = "$TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc))";
+ let assemblyFormat = [{
+ $source
+ (`[` $offsets^ `]`)?
+ prop-dict
+ attr-dict `:` type(operands)
+ }];
+
+ let builders = [
+ OpBuilder<(ins "Value": $source,
+ "xegpu::CachePolicyAttr": $l1_hint,
+ "xegpu::CachePolicyAttr": $l2_hint,
+ "xegpu::CachePolicyAttr": $l3_hint)>
+ ];
let hasVerifier = 1;
}
-def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
- AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemRead]>
- ]> {
+def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
let summary = "load a set of scattered data points from memory.";
  let description = [{ It (aka. load) loads data per work-item. The output
@@ -687,6 +723,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
vector<16xi1> -> vector<16x8xf32>
```
+
Example 3 (SIMT mode):
```mlir
%2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>,
@@ -695,19 +732,48 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>
vector<16xi1> -> vector<8xf32>
```
+
+ Example 4:
+    A variant accepts a memref as the base pointer and a vector of offsets
+    instead of a scattered TensorDesc. It combines "create scattered TensorDesc"
+    and "load with scattered TensorDesc". The source operand can also be a raw
+    pointer (uint64_t). Please refer to create_tdesc for the restrictions on
+    the memref.
+ ```mlir
+ %a = memref.alloc() : memref<1024xf32>
+ %offsets = vector.step : vector<16xindex>
+ %mask = vector.constant_mask [16]: vector<16xi1>
+ %val = xegpu.load %a[%offsets], %mask {l1_hint = #xegpu.cache_hint<cached>,
+ l2_hint = #xegpu.cache_hint<cached>,
+ l3_hint = #xegpu.cache_hint<cached>}
+          : memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16xf32>
+ ```
}];
- let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
+ let arguments = (ins XeGPU_GatherScatterSourceType: $source,
+ Optional<XeGPU_OffsetType>: $offsets,
XeGPU_MaskType: $mask,
+ OptionalAttr<I64Attr>: $chunk_size,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
let results = (outs XeGPU_ValueType: $value);
let extraClassDeclaration = extraBaseClassDeclaration # [{
+
+ Type getSourceType() {
+ return getSource().getType();
+ }
+
+ TypedValue<xegpu::TensorDescType> getTensorDesc() {
+ if (auto tdescType = getTensorDescType()) {
+ return llvm::cast<TypedValue<xegpu::TensorDescType>>(getSource());
+ }
+ return TypedValue<xegpu::TensorDescType>();
+ }
+
xegpu::TensorDescType getTensorDescType() {
- return getTensorDesc().getType();
+ return dyn_cast<xegpu::TensorDescType>(getSourceType());
}
mlir::Type getElementType() {
@@ -725,15 +791,24 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
}];
- let assemblyFormat = [{$TensorDesc `,` $mask prop-dict attr-dict
- `:` qualified(type($TensorDesc)) `,` type($mask) `->` type($value)}];
+ let assemblyFormat = [{
+ $source
+ (`[` $offsets^ `]`)? `,`
+ $mask prop-dict
+ attr-dict `:` type(operands) `->` type($value)
+ }];
+
+ let builders = [
+ OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
+ "xegpu::CachePolicyAttr": $l1_hint,
+ "xegpu::CachePolicyAttr": $l2_hint,
+ "xegpu::CachePolicyAttr": $l3_hint)>
+ ];
let hasVerifier = 1;
}
-def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
- AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemWrite]>
- ]> {
+def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
let summary = "store data to scattered memory locations.";
let description = [{ It (aka. store) stores data to scattered memory locations. The value is
typically a 1D vector. But when the chunk size of the TensorDesc is larger than 1, it will be
@@ -768,19 +843,49 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
l3_hint = #xegpu.cache_hint<write_through>}>
: vector<8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>> vector<16xi1>
```
+
+ Example 4:
+    A variant accepts a memref as the base pointer and a vector of offsets
+    instead of a scattered TensorDesc. It combines "create scattered TensorDesc"
+    and "store with scattered TensorDesc". The dest operand can also be a raw
+    pointer (uint64_t). Please refer to create_tdesc for the restrictions on
+    the memref.
+ ```mlir
+ %a = memref.alloc() : memref<1024xf32>
+ %val = arith.constant dense<0.0> : vector<16xf32>
+ %offsets = vector.step : vector<16xindex>
+ %mask = vector.constant_mask [16]: vector<16xi1>
+ xegpu.store %val, %a[%offsets], %mask {l1_hint = #xegpu.cache_hint<cached>,
+ l2_hint = #xegpu.cache_hint<cached>,
+ l3_hint = #xegpu.cache_hint<cached>}
+          : vector<16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
+ ```
+
}];
let arguments = (ins
XeGPU_ValueType: $value,
- XeGPU_TensorDesc: $TensorDesc,
+ XeGPU_GatherScatterSourceType: $dest,
+ Optional<XeGPU_OffsetType>: $offsets,
XeGPU_MaskType: $mask,
+ OptionalAttr<I64Attr>: $chunk_size,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
let extraClassDeclaration = extraBaseClassDeclaration # [{
+ Type getDestType() {
+ return getDest().getType();
+ }
+
+ TypedValue<xegpu::TensorDescType> getTensorDesc() {
+ if (auto tdescType = getTensorDescType()) {
+ return llvm::cast<TypedValue<xegpu::TensorDescType>>(getDest());
+ }
+ return TypedValue<xegpu::TensorDescType>();
+ }
+
xegpu::TensorDescType getTensorDescType() {
- return getTensorDesc().getType();
+ return dyn_cast<xegpu::TensorDescType>(getDestType());
}
VectorType getValueType() {
@@ -792,8 +897,21 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
}
}];
- let assemblyFormat = [{$value `,` $TensorDesc `,` $mask prop-dict attr-dict
- `:` type($value) `,` qualified(type($TensorDesc)) `,` type($mask)}];
+ let assemblyFormat = [{
+ $value `,`
+ $dest
+ (`[` $offsets^ `]`)? `,`
+ $mask
+ prop-dict
+ attr-dict `:` type(operands)
+ }];
+
+ let builders = [
+ OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
+ "xegpu::CachePolicyAttr": $l1_hint,
+ "xegpu::CachePolicyAttr": $l2_hint,
+ "xegpu::CachePolicyAttr": $l3_hint)>
+ ];
let hasVerifier = 1;
}
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 20916ae..b268cab 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -189,6 +189,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
let genVerifyDecl = 1;
}
+def XeGPU_GatherScatterSourceType : AnyTypeOf<[XeGPU_TensorDesc, Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64]>;
def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> {
let summary = "!xegpu.nbarrier a custom XeGPU type representing a barrier.";
diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.td b/mlir/include/mlir/Interfaces/CallInterfaces.td
index e3c2aec..19d3afe 100644
--- a/mlir/include/mlir/Interfaces/CallInterfaces.td
+++ b/mlir/include/mlir/Interfaces/CallInterfaces.td
@@ -18,9 +18,15 @@
include "mlir/IR/OpBase.td"
-/// Interface for operations with arguments attributes (both call-like
-/// and callable operations).
-def ArgumentAttributesMethods {
+/// Interface for operations with result and argument attributes.
+def ArgAndResultAttrsOpInterface : OpInterface<"ArgAndResultAttrsOpInterface"> {
+ let description = [{
+ An operation that has argument and result attributes. This interface
+ provides functions to access and modify the argument and result
+ attributes of the operation.
+ }];
+ let cppNamespace = "::mlir";
+
list<InterfaceMethod> methods = [
InterfaceMethod<[{
Get the array of argument attribute dictionaries. The method should
@@ -64,7 +70,8 @@ def ArgumentAttributesMethods {
// a call-like operation. This represents the destination of the call.
/// Interface for call-like operations.
-def CallOpInterface : OpInterface<"CallOpInterface"> {
+def CallOpInterface : OpInterface<"CallOpInterface",
+ [ArgAndResultAttrsOpInterface]> {
let description = [{
A call-like operation is one that transfers control from one sub-routine to
another. These operations may be traditional direct calls `call @foo`, or
@@ -123,11 +130,12 @@ def CallOpInterface : OpInterface<"CallOpInterface"> {
return ::mlir::call_interface_impl::resolveCallable($_op);
}]
>
- ] # ArgumentAttributesMethods.methods;
+ ];
}
/// Interface for callable operations.
-def CallableOpInterface : OpInterface<"CallableOpInterface"> {
+def CallableOpInterface : OpInterface<"CallableOpInterface",
+ [ArgAndResultAttrsOpInterface]> {
let description = [{
A callable operation is one who represents a potential sub-routine, and may
be a target for a call-like operation (those providing the CallOpInterface
@@ -140,11 +148,11 @@ def CallableOpInterface : OpInterface<"CallableOpInterface"> {
let methods = [
InterfaceMethod<[{
- Returns the region on the current operation that is callable. This may
- return null in the case of an external callable object, e.g. an external
- function.
- }],
- "::mlir::Region *", "getCallableRegion">,
+ Returns the region on the current operation that is callable. This may
+ return null in the case of an external callable object, e.g. an external
+ function.
+ }],
+ "::mlir::Region *", "getCallableRegion">,
InterfaceMethod<[{
      Returns the callable's argument types based exclusively on the type (to
      allow this method to be called on function declarations).
@@ -155,7 +163,7 @@ def CallableOpInterface : OpInterface<"CallableOpInterface"> {
      allow this method to be called on function declarations).
}],
"::llvm::ArrayRef<::mlir::Type>", "getResultTypes">,
- ] # ArgumentAttributesMethods.methods;
+ ];
}
#endif // MLIR_INTERFACES_CALLINTERFACES
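Since `CallOpInterface` and `CallableOpInterface` now both derive from `ArgAndResultAttrsOpInterface`, a pass can handle argument and result attributes uniformly across calls and callables. A minimal sketch, assuming the generated accessor names follow the inherited methods (`getArgAttrsAttr` returning an `ArrayAttr` of dictionaries; the method bodies are not shown in this hunk):

```cpp
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Interfaces/CallInterfaces.h"

using namespace mlir;

// Count ops that carry explicit argument attributes, whether they are
// call-like or callable.
static unsigned countOpsWithArgAttrs(Operation *root) {
  unsigned count = 0;
  root->walk([&](Operation *op) {
    auto attrsOp = dyn_cast<ArgAndResultAttrsOpInterface>(op);
    if (attrsOp && attrsOp.getArgAttrsAttr())
      ++count;
  });
  return count;
}
```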
diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
index 60615cf6..e4670cb 100644
--- a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
+++ b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
@@ -28,6 +28,7 @@
#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/SPIRV/SPIRVToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/VCIX/VCIXToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h"
namespace mlir {
class DialectRegistry;
@@ -47,6 +48,7 @@ static inline void registerAllToLLVMIRTranslations(DialectRegistry &registry) {
registerROCDLDialectTranslation(registry);
registerSPIRVDialectTranslation(registry);
registerVCIXDialectTranslation(registry);
+ registerXeVMDialectTranslation(registry);
// Extension required for translating GPU offloading Ops.
gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(registry);
@@ -63,6 +65,7 @@ registerAllGPUToLLVMIRTranslations(DialectRegistry &registry) {
registerNVVMDialectTranslation(registry);
registerROCDLDialectTranslation(registry);
registerSPIRVDialectTranslation(registry);
+ registerXeVMDialectTranslation(registry);
// Extension required for translating GPU offloading Ops.
gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(registry);
diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h
new file mode 100644
index 0000000..b4f6750
--- /dev/null
+++ b/mlir/include/mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h
@@ -0,0 +1,31 @@
+//===-- XeVMToLLVMIRTranslation.h - XeVM to LLVM IR -------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This provides registration calls for XeVM dialect to LLVM IR translation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_LLVMIR_DIALECT_XEVM_XEVMTOLLVMIRTRANSLATION_H
+#define MLIR_TARGET_LLVMIR_DIALECT_XEVM_XEVMTOLLVMIRTRANSLATION_H
+
+namespace mlir {
+
+class DialectRegistry;
+class MLIRContext;
+
+/// Register the XeVM dialect and the translation from it to LLVM IR in the
+/// given registry.
+void registerXeVMDialectTranslation(mlir::DialectRegistry &registry);
+
+/// Register the XeVM dialect and the translation from it in the registry
+/// associated with the given context.
+void registerXeVMDialectTranslation(mlir::MLIRContext &context);
+
+} // namespace mlir
+
+#endif // MLIR_TARGET_LLVMIR_DIALECT_XEVM_XEVMTOLLVMIRTRANSLATION_H
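A minimal sketch of wiring the new translation into a standalone tool, mirroring what `registerAllToLLVMIRTranslations` does above:

```cpp
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h"

int main() {
  mlir::DialectRegistry registry;
  // Makes XeVM ops translatable to LLVM IR in contexts built from
  // this registry.
  mlir::registerXeVMDialectTranslation(registry);
  mlir::MLIRContext context(registry);
  return 0;
}
```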
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index 17ef8e4..09d819a 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -83,6 +83,10 @@ public:
/// specification.
void convertTargetTriple();
+  /// Converts the module-level inline assembly of the LLVM module to an MLIR
+  /// module-level asm specification.
+ void convertModuleLevelAsm();
+
/// Stores the mapping between an LLVM value and its MLIR counterpart.
void mapValue(llvm::Value *llvm, Value mlir) { mapValue(llvm) = mlir; }
@@ -291,10 +295,12 @@ public:
SmallVectorImpl<Value> &valuesOut,
SmallVectorImpl<NamedAttribute> &attrsOut);
- /// Converts the parameter and result attributes in `argsAttr` and `resAttr`
- /// and add them to the `callOp`.
- void convertParameterAttributes(llvm::CallBase *call, ArrayAttr &argsAttr,
- ArrayAttr &resAttr, OpBuilder &builder);
+ /// Converts the argument and result attributes attached to `call` and adds
+ /// them to `attrsOp`. For intrinsic calls, filters out attributes
+ /// corresponding to immediate arguments specified by `immArgPositions`.
+ void convertArgAndResultAttrs(llvm::CallBase *call,
+ ArgAndResultAttrsOpInterface attrsOp,
+ ArrayRef<unsigned> immArgPositions = {});
/// Whether the importer should try to convert all intrinsics to
/// llvm.call_intrinsic instead of dialect supported operations.
@@ -378,19 +384,12 @@ private:
bool &isIncompatibleCall);
/// Returns the callee name, or an empty symbol if the call is not direct.
FlatSymbolRefAttr convertCalleeName(llvm::CallBase *callInst);
- /// Converts the parameter and result attributes attached to `func` and adds
+ /// Converts the argument and result attributes attached to `func` and adds
/// them to the `funcOp`.
- void convertParameterAttributes(llvm::Function *func, LLVMFuncOp funcOp,
- OpBuilder &builder);
- /// Converts the AttributeSet of one parameter in LLVM IR to a corresponding
- /// DictionaryAttr for the LLVM dialect.
- DictionaryAttr convertParameterAttribute(llvm::AttributeSet llvmParamAttrs,
- OpBuilder &builder);
- /// Converts the parameter and result attributes attached to `call` and adds
- /// them to the `callOp`. Implemented in terms of the the public definition of
- /// convertParameterAttributes.
- void convertParameterAttributes(llvm::CallBase *call, CallOpInterface callOp,
- OpBuilder &builder);
+ void convertArgAndResultAttrs(llvm::Function *func, LLVMFuncOp funcOp);
+ /// Converts the argument or result attributes in `llvmAttrSet` to a
+ /// corresponding MLIR LLVM dialect attribute dictionary.
+ DictionaryAttr convertArgOrResultAttrSet(llvm::AttributeSet llvmAttrSet);
/// Converts the attributes attached to `inst` and adds them to the `op`.
LogicalResult convertCallAttributes(llvm::CallInst *inst, CallOp op);
/// Converts the attributes attached to `inst` and adds them to the `op`.
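For intrinsics, the positions in `immArgPositions` hold compile-time immediates rather than SSA operands, so they have no MLIR argument-attribute slot and must be skipped. A hypothetical call site (the position is illustrative; `llvm.memcpy`, for instance, carries its `isvolatile` flag as the immarg at operand 3):

```cpp
// Sketch: import arg/result attributes for an intrinsic call,
// skipping the dictionary slot of the immarg at position 3.
convertArgAndResultAttrs(call, attrsOp, /*immArgPositions=*/{3});
```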
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index f3f73f4..eb7dfa7 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -25,11 +25,13 @@
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
#include "llvm/ADT/SetVector.h"
-#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/IR/FPEnv.h"
+#include "llvm/IR/Module.h"
namespace llvm {
class BasicBlock;
+class CallBase;
+class CanonicalLoopInfo;
class Function;
class IRBuilderBase;
class OpenMPIRBuilder;
@@ -306,10 +308,16 @@ public:
/*recordInsertions=*/false);
}
- /// Translates parameter attributes of a call and adds them to the returned
- /// AttrBuilder. Returns failure if any of the translations failed.
- FailureOr<llvm::AttrBuilder> convertParameterAttrs(mlir::Location loc,
- DictionaryAttr paramAttrs);
+ /// Converts argument and result attributes from `attrsOp` to LLVM IR
+ /// attributes on the `call` instruction. Returns failure if conversion fails.
+ /// The `immArgPositions` parameter is only relevant for intrinsics. It
+ /// specifies the positions of immediate arguments, which do not have
+ /// associated argument attributes in MLIR and should be skipped during
+ /// attribute mapping.
+ LogicalResult
+ convertArgAndResultAttrs(ArgAndResultAttrsOpInterface attrsOp,
+ llvm::CallBase *call,
+ ArrayRef<unsigned> immArgPositions = {});
/// Gets the named metadata in the LLVM IR module being constructed, creating
/// it if it does not exist.
@@ -389,6 +397,11 @@ private:
convertDialectAttributes(Operation *op,
ArrayRef<llvm::Instruction *> instructions);
+ /// Translates parameter attributes of a call and adds them to the returned
+ /// AttrBuilder. Returns failure if any of the translations failed.
+ FailureOr<llvm::AttrBuilder> convertParameterAttrs(mlir::Location loc,
+ DictionaryAttr paramAttrs);
+
/// Translates parameter attributes of a function and adds them to the
/// returned AttrBuilder. Returns failure if any of the translations failed.
FailureOr<llvm::AttrBuilder>
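On the export side, a dialect translation can now forward all argument and result attributes with one call instead of building an `AttrBuilder` per dictionary. A minimal sketch, assuming `callOp` implements the interface and `inst` is the just-created `llvm::CallInst`:

```cpp
// Sketch: copy MLIR arg/result attributes onto the LLVM call site.
if (failed(moduleTranslation.convertArgAndResultAttrs(callOp, inst)))
  return failure();
```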
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index d43e681..265293b 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -99,6 +99,17 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
return builder.getF32FloatAttr(dstVal.convertToFloat());
}
+// Get an IntegerAttr from a FloatAttr while preserving the bits. Useful for
+// converting float constants to integer constants without changing the
+// underlying bit pattern.
+static IntegerAttr
+getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
+ ConversionPatternRewriter &rewriter) {
+ APFloat floatVal = floatAttr.getValue();
+ APInt intVal = floatVal.bitcastToAPInt();
+ return rewriter.getIntegerAttr(dstType, intVal);
+}
+
/// Returns true if the given `type` is a boolean scalar or vector type.
static bool isBoolScalarOrVector(Type type) {
assert(type && "Not a valid type");
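The bit-preserving conversion is what the new `emulateUnsupportedFloatTypes` option relies on: an 8-bit float constant is re-emitted as an `i8` with the identical bit pattern. A hypothetical before/after (1.0 in `f8E4M3FN` encodes as 0x38 = 56; the exact target op spelling is a sketch):

```mlir
// Before conversion:
%c = arith.constant 1.000000e+00 : f8E4M3FN
// After conversion (sketch): same bits, integer type.
%c = spirv.Constant 56 : i8
```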
@@ -296,8 +307,18 @@ struct ConstantCompositeOpPattern final
SmallVector<Attribute, 8> elements;
if (isa<FloatType>(srcElemType)) {
for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
- FloatAttr dstAttr =
- convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter);
+ Attribute dstAttr = nullptr;
+ // Handle 8-bit float conversion to 8-bit integer.
+ auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
+ if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
+ srcElemType.getIntOrFloatBitWidth() == 8 &&
+ isa<IntegerType>(dstElemType)) {
+ dstAttr =
+ getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter);
+ } else {
+ dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstElemType),
+ rewriter);
+ }
if (!dstAttr)
return failure();
elements.push_back(dstAttr);
@@ -361,11 +382,19 @@ struct ConstantScalarOpPattern final
// Floating-point types.
if (isa<FloatType>(srcType)) {
auto srcAttr = cast<FloatAttr>(cstAttr);
- auto dstAttr = srcAttr;
+ Attribute dstAttr = srcAttr;
// Floating-point types not supported in the target environment are all
// converted to float type.
- if (srcType != dstType) {
+ auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
+ if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
+ srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) &&
+ dstType.getIntOrFloatBitWidth() == 8) {
+        // If the source is an 8-bit float, convert it to an 8-bit integer.
+ dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter);
+ if (!dstAttr)
+ return failure();
+ } else if (srcType != dstType) {
dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
if (!dstAttr)
return failure();
@@ -1352,6 +1381,7 @@ struct ConvertArithToSPIRVPass
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+ options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
// Use UnrealizedConversionCast as the bridge so that we don't need to pull
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index 6f0fc29..35ad99c 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -64,10 +64,46 @@ void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
patterns.getContext(), "__ocml_cabs_f32");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>(
patterns.getContext(), "__ocml_cabs_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float32Type>>(
+ patterns.getContext(), "__ocml_carg_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float64Type>>(
+ patterns.getContext(), "__ocml_carg_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float32Type>>(
+ patterns.getContext(), "__ocml_conj_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float64Type>>(
+ patterns.getContext(), "__ocml_conj_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float32Type>>(
+ patterns.getContext(), "__ocml_ccos_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float64Type>>(
+ patterns.getContext(), "__ocml_ccos_f64");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float32Type>>(
patterns.getContext(), "__ocml_cexp_f32");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float64Type>>(
patterns.getContext(), "__ocml_cexp_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float32Type>>(
+ patterns.getContext(), "__ocml_clog_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float64Type>>(
+ patterns.getContext(), "__ocml_clog_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float32Type>>(
+ patterns.getContext(), "__ocml_cpow_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float64Type>>(
+ patterns.getContext(), "__ocml_cpow_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float32Type>>(
+ patterns.getContext(), "__ocml_csin_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float64Type>>(
+ patterns.getContext(), "__ocml_csin_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float32Type>>(
+ patterns.getContext(), "__ocml_csqrt_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float64Type>>(
+ patterns.getContext(), "__ocml_csqrt_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float32Type>>(
+ patterns.getContext(), "__ocml_ctan_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float64Type>>(
+ patterns.getContext(), "__ocml_ctan_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float32Type>>(
+ patterns.getContext(), "__ocml_ctanh_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float64Type>>(
+ patterns.getContext(), "__ocml_ctanh_f64");
}
namespace {
@@ -86,7 +122,10 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
ConversionTarget target(getContext());
target.addLegalDialect<func::FuncDialect>();
- target.addIllegalOp<complex::AbsOp, complex::ExpOp>();
+ target.addIllegalOp<complex::AbsOp, complex::AngleOp, complex::ConjOp,
+ complex::CosOp, complex::ExpOp, complex::LogOp,
+ complex::PowOp, complex::SinOp, complex::SqrtOp,
+ complex::TanOp, complex::TanhOp>();
if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
}
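Each new pattern rewrites the complex op into a call to the matching OCML symbol. A sketch of the effect for one of the newly covered ops (the function declaration is elided):

```mlir
// Before:
%0 = complex.sin %arg0 : complex<f32>
// After (sketch): a call to the OCML library routine.
%0 = call @__ocml_csin_f32(%arg0) : (complex<f32>) -> complex<f32>
```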
diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
index 03f4bf4..56b6181 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
@@ -43,6 +43,7 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() {
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+ options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
// TODO: We should also take care of block argument type conversion.
diff --git a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
index 8ed9f65..c0439a4 100644
--- a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
@@ -42,6 +42,7 @@ void ConvertFuncToSPIRVPass::runOnOperation() {
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+ options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
RewritePatternSet patterns(context);
diff --git a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
index 855c582..cde2340 100644
--- a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
+++ b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
@@ -22,7 +22,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTMATHTOFUNCS
@@ -32,7 +32,6 @@ namespace mlir {
using namespace mlir;
#define DEBUG_TYPE "math-to-funcs"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
namespace {
// Pattern to convert vector operations to scalar operations.
@@ -653,10 +652,8 @@ FPowIOpLowering::matchAndRewrite(math::FPowIOp op,
/// }
static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) {
if (!isa<IntegerType>(elementType)) {
- LLVM_DEBUG({
- DBGS() << "non-integer element type for CtlzFunc; type was: ";
- elementType.print(llvm::dbgs());
- });
+ LDBG() << "non-integer element type for CtlzFunc; type was: "
+ << elementType;
llvm_unreachable("non-integer element type");
}
int64_t bitWidth = elementType.getIntOrFloatBitWidth();
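This file and several below migrate from the `LLVM_DEBUG`/`DBGS()` macro pair to `DebugLog.h`. `LDBG()` emits the `[DEBUG_TYPE]` prefix and the trailing newline itself, so both disappear from call sites. A minimal sketch of the pattern:

```cpp
#include "llvm/Support/DebugLog.h"
#define DEBUG_TYPE "my-pass"

static void example(int value) {
  // Old style (removed): LLVM_DEBUG(DBGS() << "value: " << value << "\n");
  // New style: prefix and newline are supplied by LDBG().
  LDBG() << "value: " << value;
}
```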
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index 93d8b49..df219f3 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
+#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -21,7 +22,6 @@
#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/OpToFuncCallLowering.h"
-#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTMATHTOROCDL
@@ -31,7 +31,6 @@ namespace mlir {
using namespace mlir;
#define DEBUG_TYPE "math-to-rocdl"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
template <typename OpTy>
static void populateOpPatterns(const LLVMTypeConverter &converter,
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index a877ad2..1787e0a 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -488,7 +488,12 @@ namespace mlir {
void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
// Core patterns
- patterns.add<CopySignPattern>(typeConverter, patterns.getContext());
+ patterns
+ .add<CopySignPattern,
+ CheckedElementwiseOpPattern<math::IsInfOp, spirv::IsInfOp>,
+ CheckedElementwiseOpPattern<math::IsNaNOp, spirv::IsNanOp>,
+ CheckedElementwiseOpPattern<math::IsFiniteOp, spirv::IsFiniteOp>>(
+ typeConverter, patterns.getContext());
// GLSL patterns
patterns
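With the checked patterns registered, the classification ops lower one-to-one when the target supports the corresponding SPIR-V instructions. A sketch:

```mlir
// Sketch of the new 1:1 lowerings:
%a = math.isnan %x : f32     // -> %a = spirv.IsNan %x : f32
%b = math.isinf %x : f32     // -> %b = spirv.IsInf %x : f32
%c = math.isfinite %x : f32  // -> %c = spirv.IsFinite %x : f32
```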
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 6ba5bfe4..dc2035b 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -24,11 +24,12 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/MathExtras.h"
+
#include <optional>
#define DEBUG_TYPE "memref-to-llvm"
-#define DBGS() llvm::dbgs() << "[" DEBUG_TYPE "] "
namespace mlir {
#define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
@@ -1848,8 +1849,8 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
return LLVM::AtomicBinOp::xchg;
case arith::AtomicRMWKind::maximumf:
// TODO: remove this by end of 2025.
- LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw maximumf changed "
- "from fmax to fmaximum, expect more NaNs");
+ LDBG() << "the lowering of memref.atomicrmw maximumf changed "
+ "from fmax to fmaximum, expect more NaNs";
return LLVM::AtomicBinOp::fmaximum;
case arith::AtomicRMWKind::maxnumf:
return LLVM::AtomicBinOp::fmax;
@@ -1859,8 +1860,8 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
return LLVM::AtomicBinOp::umax;
case arith::AtomicRMWKind::minimumf:
// TODO: remove this by end of 2025.
- LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw minimum changed "
- "from fmin to fminimum, expect more NaNs");
+ LDBG() << "the lowering of memref.atomicrmw minimum changed "
+ "from fmin to fminimum, expect more NaNs";
return LLVM::AtomicBinOp::fminimum;
case arith::AtomicRMWKind::minnumf:
return LLVM::AtomicBinOp::fmin;
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 5d13353..2549a9c 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -26,13 +26,12 @@
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
#define DEBUG_TYPE "nvgpu-to-nvvm"
-#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
-#define DBGSE() (llvm::dbgs())
namespace mlir {
#define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
@@ -1105,13 +1104,13 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
// // [0,14) start_address
dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
- LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
- << "leading_off:" << leadDimVal << "\t"
- << "stride_off :" << strideDimVal << "\t"
- << "base_offset:" << offsetVal << "\t"
- << "layout_type:" << swizzle << " ("
- << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
- << ")\n start_addr : " << baseAddr << "\n");
+ LDBG() << "Generating warpgroup.descriptor: "
+ << "leading_off:" << leadDimVal << "\t"
+ << "stride_off :" << strideDimVal << "\t"
+ << "base_offset:" << offsetVal << "\t"
+ << "layout_type:" << swizzle << " ("
+ << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
+ << ")\n start_addr : " << baseAddr;
rewriter.replaceOp(op, dsc);
return success();
@@ -1281,8 +1280,8 @@ struct NVGPUWarpgroupMmaOpLowering
} else {
llvm_unreachable("msg: not supported K shape");
}
- LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
- << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
+ LDBG() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
+ << ", n = " << wgmmaN << ", k = " << wgmmaK << "]";
}
/// Generates WGMMATypesAttr from MLIR Type
@@ -1366,9 +1365,9 @@ struct NVGPUWarpgroupMmaOpLowering
int tileShapeA = matrixTypeA.getDimSize(1);
int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
incrementVal = incrementVal >> exclude4LSB;
- LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
- << "] [wgmma descriptors] Descriptor A + "
- << incrementVal << " | \t ");
+ LDBG() << "\t\t[m: " << i << " n: " << j << " k: " << k
+ << "] [wgmma descriptors] Descriptor A + " << incrementVal
+ << " | \t ";
if (!incrementVal)
return desc;
return makeAdd(desc, makeI64Const(b, incrementVal));
@@ -1391,7 +1390,7 @@ struct NVGPUWarpgroupMmaOpLowering
int byte = elemB.getIntOrFloatBitWidth() / 8;
int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
incrementVal = incrementVal >> exclude4LSB;
- LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
+ LDBG() << "Descriptor B + " << incrementVal;
if (!incrementVal)
return desc;
return makeAdd(desc, makeI64Const(b, incrementVal));
@@ -1400,15 +1399,14 @@ struct NVGPUWarpgroupMmaOpLowering
/// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
/// descriptors and arranges them based on induction variables: i, j, and k.
Value generateWgmma(int i, int j, int k, Value matrixC) {
- LLVM_DEBUG(DBGS() << "\t wgmma."
- << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
- << "(A[" << (iterationM * wgmmaM) << ":"
- << (iterationM * wgmmaM) + wgmmaM << "]["
- << (iterationK * wgmmaK) << ":"
- << (iterationK * wgmmaK + wgmmaK) << "] * "
- << " B[" << (iterationK * wgmmaK) << ":"
- << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
- << wgmmaN << "])\n");
+ LDBG() << "\t wgmma."
+ << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK << "(A["
+ << (iterationM * wgmmaM) << ":" << (iterationM * wgmmaM) + wgmmaM
+ << "][" << (iterationK * wgmmaK) << ":"
+ << (iterationK * wgmmaK + wgmmaK) << "] * "
+ << " B[" << (iterationK * wgmmaK) << ":"
+ << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":" << wgmmaN
+ << "])";
Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
@@ -1467,9 +1465,9 @@ struct NVGPUWarpgroupMmaOpLowering
totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
- LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
- << "] += A[" << totalM << "][" << totalK << "] * B["
- << totalK << "][" << totalN << "] ---===\n");
+ LDBG() << "===--- GEMM D[" << totalM << "][" << totalN << "] += A["
+ << totalM << "][" << totalK << "] * B[" << totalK << "][" << totalN
+ << "] ---===";
// Find the shape for one wgmma instruction
findWgmmaShape(
diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
index 662ee9e..91788f9 100644
--- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
+++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
@@ -25,11 +25,10 @@
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/raw_ostream.h"
#define DEBUG_TYPE "nvvm-to-llvm"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-#define DBGSNL() (llvm::dbgs() << "\n")
namespace mlir {
#define GEN_PASS_DEF_CONVERTNVVMTOLLVMPASS
@@ -52,17 +51,17 @@ struct PtxLowering
LogicalResult matchAndRewrite(BasicPtxBuilderInterface op,
PatternRewriter &rewriter) const override {
if (op.hasIntrinsic()) {
- LLVM_DEBUG(DBGS() << "Ptx Builder does not lower \n\t" << op << "\n");
+ LDBG() << "Ptx Builder does not lower \n\t" << op;
return failure();
}
SmallVector<std::pair<Value, PTXRegisterMod>> asmValues;
- LLVM_DEBUG(DBGS() << op.getPtx() << "\n");
+ LDBG() << op.getPtx();
PtxBuilder generator(op, rewriter);
op.getAsmValues(rewriter, asmValues);
for (auto &[asmValue, modifier] : asmValues) {
- LLVM_DEBUG(DBGSNL() << asmValue << "\t Modifier : " << &modifier);
+ LDBG() << asmValue << "\t Modifier : " << &modifier;
generator.insertValue(asmValue, modifier);
}
diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
index fd40e7c..fa9e544 100644
--- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -36,7 +36,6 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "shard-to-mpi"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
namespace mlir {
#define GEN_PASS_DEF_CONVERTSHARDTOMPIPASS
diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
index f07386e..8cd650e 100644
--- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
@@ -41,6 +41,7 @@ class ConvertTensorToSPIRVPass
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+ options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
RewritePatternSet patterns(context);
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index a425eff..1d1904f 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -31,10 +31,9 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/DebugLog.h"
#define DEBUG_TYPE "vector-to-gpu"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-#define DBGSNL() (llvm::dbgs() << "\n")
namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOGPU
@@ -366,7 +365,7 @@ static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
// by all operations.
if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
if (!supportsMMaMatrixType(op, useNvGpu)) {
- LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n");
+ LDBG() << "cannot convert op: " << *op;
return true;
}
return false;
@@ -548,7 +547,7 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
std::optional<int64_t> stride =
getStaticallyKnownRowStride(op.getShapedType());
if (!stride.has_value()) {
- LLVM_DEBUG(DBGS() << "no stride\n");
+ LDBG() << "no stride";
return rewriter.notifyMatchFailure(op, "no stride");
}
@@ -583,7 +582,7 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
isTranspose ? rewriter.getUnitAttr() : UnitAttr());
valueMapping[mappingResult] = load;
- LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n");
+ LDBG() << "transfer read to: " << load;
return success();
}
@@ -597,13 +596,13 @@ convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
std::optional<int64_t> stride =
getStaticallyKnownRowStride(op.getShapedType());
if (!stride.has_value()) {
- LLVM_DEBUG(DBGS() << "no stride\n");
+ LDBG() << "no stride";
return rewriter.notifyMatchFailure(op, "no stride");
}
auto it = valueMapping.find(op.getVector());
if (it == valueMapping.end()) {
- LLVM_DEBUG(DBGS() << "no mapping\n");
+ LDBG() << "no mapping";
return rewriter.notifyMatchFailure(op, "no mapping");
}
@@ -613,9 +612,9 @@ convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
(void)store;
- LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n");
+ LDBG() << "transfer write to: " << store;
- LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
+ LDBG() << "erase: " << op;
rewriter.eraseOp(op);
return success();
}
@@ -641,21 +640,21 @@ convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
nvgpu::getWarpMatrixInfo(op);
if (failed(warpMatrixInfo)) {
- LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
+ LDBG() << "no warpMatrixInfo";
return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
}
FailureOr<nvgpu::FragmentElementInfo> regInfo =
nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
if (failed(regInfo)) {
- LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
+ LDBG() << "not mma sync reg info";
return rewriter.notifyMatchFailure(op, "not mma sync reg info");
}
VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
if (!dense) {
- LLVM_DEBUG(DBGS() << "not a splat\n");
+ LDBG() << "not a splat";
return rewriter.notifyMatchFailure(op, "not a splat");
}
@@ -677,8 +676,8 @@ static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
mlir::AffineMap map = op.getPermutationMap();
if (map.getNumResults() != 2) {
- LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` "
- "is not a 2d operand\n");
+ LDBG() << "Failed because the result of `vector.transfer_read` "
+ "is not a 2d operand";
return failure();
}
@@ -691,8 +690,8 @@ static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
auto exprN = dyn_cast<AffineDimExpr>(dN);
if (!exprM || !exprN) {
- LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim "
- "expressions, then transpose cannot be determined.\n");
+ LDBG() << "Failed because expressions are not affine dim "
+ "expressions, then transpose cannot be determined.";
return failure();
}
@@ -709,20 +708,20 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
nvgpu::getWarpMatrixInfo(op);
if (failed(warpMatrixInfo)) {
- LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
+ LDBG() << "no warpMatrixInfo";
return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
}
FailureOr<nvgpu::FragmentElementInfo> regInfo =
nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
if (failed(regInfo)) {
- LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
+ LDBG() << "not mma sync reg info";
return rewriter.notifyMatchFailure(op, "not mma sync reg info");
}
FailureOr<bool> transpose = isTransposed(op);
if (failed(transpose)) {
- LLVM_DEBUG(DBGS() << "failed to determine the transpose\n");
+ LDBG() << "failed to determine the transpose";
return rewriter.notifyMatchFailure(
op, "Op should likely not be converted to a nvgpu.ldmatrix call.");
}
@@ -731,10 +730,8 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose);
if (failed(params)) {
- LLVM_DEBUG(
- DBGS()
- << "failed to convert vector.transfer_read to ldmatrix. "
- << "Op should likely not be converted to a nvgpu.ldmatrix call.\n");
+ LDBG() << "failed to convert vector.transfer_read to ldmatrix. "
+ << "Op should likely not be converted to a nvgpu.ldmatrix call.";
return rewriter.notifyMatchFailure(
op, "failed to convert vector.transfer_read to ldmatrix; this op "
"likely should not be converted to a nvgpu.ldmatrix call.");
@@ -745,7 +742,7 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
FailureOr<AffineMap> offsets =
nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
if (failed(offsets)) {
- LLVM_DEBUG(DBGS() << "no offsets\n");
+ LDBG() << "no offsets";
return rewriter.notifyMatchFailure(op, "no offsets");
}
@@ -934,7 +931,7 @@ convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
vector::StoreOp::create(rewriter, loc, el, op.getBase(), newIndices);
}
- LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
+ LDBG() << "erase: " << op;
rewriter.eraseOp(op);
return success();
}
@@ -1132,9 +1129,9 @@ static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
loop.getNumResults())))
rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
- LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n");
- LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n");
- LLVM_DEBUG(DBGS() << "erase: " << loop);
+ LDBG() << "newLoop now: " << newLoop;
+ LDBG() << "stripped scf.for: " << loop;
+ LDBG() << "erase: " << loop;
rewriter.eraseOp(loop);
return newLoop;
@@ -1150,7 +1147,7 @@ static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
for (const auto &operand : llvm::enumerate(op.getInitArgs())) {
auto it = valueMapping.find(operand.value());
if (it == valueMapping.end()) {
- LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n");
+ LDBG() << "no value mapping for: " << operand.value();
continue;
}
argMapping.push_back(std::make_pair(
@@ -1168,7 +1165,7 @@ static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
}
- LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n");
+ LDBG() << "scf.for to: " << newForOp;
return success();
}
@@ -1191,7 +1188,7 @@ convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
}
scf::YieldOp::create(rewriter, op.getLoc(), yieldOperands);
- LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
+ LDBG() << "erase: " << op;
rewriter.eraseOp(op);
return success();
}
@@ -1244,7 +1241,7 @@ LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
auto globalRes = LogicalResult::success();
for (Operation *op : ops) {
- LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n");
+ LDBG() << "Process op: " << *op;
// Apparently callers do not want to early exit on failure here.
auto res = LogicalResult::success();
if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 8d7053c..22608a1 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -26,7 +26,7 @@
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include <numeric>
@@ -40,7 +40,6 @@ using llvm::divideFloorSigned;
using llvm::mod;
#define DEBUG_TYPE "affine-ops"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
#include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
@@ -1062,12 +1061,9 @@ static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp,
AffineMap *map,
ValueRange dims,
ValueRange syms) {
+ LDBG() << "replaceAffineMinBoundingBoxExpression: `" << minOp << "`";
AffineMap affineMinMap = minOp.getAffineMap();
- LLVM_DEBUG({
- DBGS() << "replaceAffineMinBoundingBoxExpression: `" << minOp << "`\n";
- });
-
// Check the value is positive.
for (unsigned i = 0, e = affineMinMap.getNumResults(); i < e; ++i) {
// Compare each expression in the minimum against 0.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index cffe310..52cd0ce 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -30,6 +30,7 @@
#include "mlir/IR/Types.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/IR/IRBuilder.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
index 935aa3c..b951df8 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
@@ -22,6 +22,8 @@
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
+
#define DEBUG_TYPE "llvm-inliner"
using namespace mlir;
@@ -670,44 +672,42 @@ struct LLVMInlinerInterface : public DialectInlinerInterface {
bool wouldBeCloned) const final {
auto callOp = dyn_cast<LLVM::CallOp>(call);
if (!callOp) {
- LLVM_DEBUG(llvm::dbgs() << "Cannot inline: call is not an '"
- << LLVM::CallOp::getOperationName() << "' op\n");
+ LDBG() << "Cannot inline: call is not an '"
+ << LLVM::CallOp::getOperationName() << "' op";
return false;
}
if (callOp.getNoInline()) {
- LLVM_DEBUG(llvm::dbgs() << "Cannot inline: call is marked no_inline\n");
+ LDBG() << "Cannot inline: call is marked no_inline";
return false;
}
auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(callable);
if (!funcOp) {
- LLVM_DEBUG(llvm::dbgs()
- << "Cannot inline: callable is not an '"
- << LLVM::LLVMFuncOp::getOperationName() << "' op\n");
+ LDBG() << "Cannot inline: callable is not an '"
+ << LLVM::LLVMFuncOp::getOperationName() << "' op";
return false;
}
if (funcOp.isNoInline()) {
- LLVM_DEBUG(llvm::dbgs()
- << "Cannot inline: function is marked no_inline\n");
+ LDBG() << "Cannot inline: function is marked no_inline";
return false;
}
if (funcOp.isVarArg()) {
- LLVM_DEBUG(llvm::dbgs() << "Cannot inline: callable is variadic\n");
+ LDBG() << "Cannot inline: callable is variadic";
return false;
}
// TODO: Generate aliasing metadata from noalias result attributes.
if (auto attrs = funcOp.getArgAttrs()) {
for (DictionaryAttr attrDict : attrs->getAsRange<DictionaryAttr>()) {
if (attrDict.contains(LLVM::LLVMDialect::getInAllocaAttrName())) {
- LLVM_DEBUG(llvm::dbgs() << "Cannot inline " << funcOp.getSymName()
- << ": inalloca arguments not supported\n");
+ LDBG() << "Cannot inline " << funcOp.getSymName()
+ << ": inalloca arguments not supported";
return false;
}
}
}
// TODO: Handle exceptions.
if (funcOp.getPersonality()) {
- LLVM_DEBUG(llvm::dbgs() << "Cannot inline " << funcOp.getSymName()
- << ": unhandled function personality\n");
+ LDBG() << "Cannot inline " << funcOp.getSymName()
+ << ": unhandled function personality";
return false;
}
if (funcOp.getPassthrough()) {
@@ -717,10 +717,8 @@ struct LLVMInlinerInterface : public DialectInlinerInterface {
if (!stringAttr)
return false;
if (disallowedFunctionAttrs.contains(stringAttr)) {
- LLVM_DEBUG(llvm::dbgs()
- << "Cannot inline " << funcOp.getSymName()
- << ": found disallowed function attribute "
- << stringAttr << "\n");
+ LDBG() << "Cannot inline " << funcOp.getSymName()
+ << ": found disallowed function attribute " << stringAttr;
return true;
}
return false;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 7f9ba1b..bf66ed0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -637,6 +637,7 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
}
ArrayRef<int64_t> sourceShape = padOp.getSourceType().getShape();
+ ArrayRef<int64_t> resultShape = padOp.getResultType().getShape();
int64_t padRank = sourceShape.size();
auto isStaticZero = [](OpFoldResult f) {
@@ -647,16 +648,18 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
allowedUnitDims.end());
llvm::SmallDenseSet<unsigned> unitDims;
SmallVector<int64_t> newShape;
+ SmallVector<int64_t> newResultShape;
SmallVector<OpFoldResult> newLowPad;
SmallVector<OpFoldResult> newHighPad;
- for (const auto [dim, size, low, high] :
- zip_equal(llvm::seq(static_cast<int64_t>(0), padRank), sourceShape,
- padOp.getMixedLowPad(), padOp.getMixedHighPad())) {
+ for (const auto [dim, size, outSize, low, high] : zip_equal(
+ llvm::seq(static_cast<int64_t>(0), padRank), sourceShape,
+ resultShape, padOp.getMixedLowPad(), padOp.getMixedHighPad())) {
if (unitDimsFilter.contains(dim) && size == 1 && isStaticZero(low) &&
isStaticZero(high)) {
unitDims.insert(dim);
} else {
newShape.push_back(size);
+ newResultShape.push_back(outSize);
newLowPad.push_back(low);
newHighPad.push_back(high);
}
@@ -686,8 +689,10 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
collapseValue(rewriter, padOp.getLoc(), padOp.getSource(), newShape,
reassociationMap, options.rankReductionStrategy);
- auto newPadOp = tensor::PadOp::create(
- rewriter, padOp.getLoc(), /*result=*/Type(), collapsedSource, newLowPad,
+ auto newResultType = RankedTensorType::get(
+ newResultShape, padOp.getResultType().getElementType());
+ auto newPadOp = rewriter.create<tensor::PadOp>(
+ padOp.getLoc(), /*result=*/newResultType, collapsedSource, newLowPad,
newHighPad, paddingVal, padOp.getNofold());
Value dest = padOp.getResult();
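The fix threads the static result shape through so the new `tensor.pad` keeps a ranked static type instead of an inferred one. A hypothetical instance of the rewrite (shapes illustrative):

```mlir
%cst = arith.constant 0.000000e+00 : f32
// Before: the leading dim is a unit dim with no padding.
%0 = tensor.pad %src low[0, 1] high[0, 1] {
^bb0(%i: index, %j: index):
  tensor.yield %cst : f32
} : tensor<1x8xf32> to tensor<1x10xf32>

// After (sketch): unit dim dropped, result type stays static.
%c = tensor.collapse_shape %src [[0, 1]] : tensor<1x8xf32> into tensor<8xf32>
%p = tensor.pad %c low[1] high[1] {
^bb0(%i: index):
  tensor.yield %cst : f32
} : tensor<8xf32> to tensor<10xf32>
%1 = tensor.expand_shape %p [[0, 1]] output_shape [1, 10]
    : tensor<10xf32> into tensor<1x10xf32>
```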
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
index 2c62cb6..2e62523 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -55,6 +55,28 @@ getFullRankPaddingSizes(Builder &b, ArrayRef<OpFoldResult> indexingSizes,
return paddingSizes;
}
+/// Extracts the constant multiplier from an affine expression of the form
+/// `d * c` or `c * d`, where `d` is an AffineDimExpr and `c` is an
+/// AffineConstantExpr. Returns 1 if the expression is not a simple
+/// multiplication of a dimension and a constant.
+static int64_t extractConstantMultiplier(AffineExpr expr) {
+ if (auto binOp = dyn_cast<AffineBinaryOpExpr>(expr)) {
+ if (binOp.getKind() == AffineExprKind::Mul) {
+ auto lhsD = dyn_cast<AffineDimExpr>(binOp.getLHS());
+ auto rhsC = dyn_cast<AffineConstantExpr>(binOp.getRHS());
+ if (lhsD && rhsC) {
+ return rhsC.getValue();
+ }
+ auto lhsC = dyn_cast<AffineConstantExpr>(binOp.getLHS());
+ auto rhsD = dyn_cast<AffineDimExpr>(binOp.getRHS());
+ if (lhsC && rhsD) {
+ return lhsC.getValue();
+ }
+ }
+ }
+ return 1;
+}
+
/// Compute the padded shape of the given value `v` of `RankedTensorType` given
/// - `indexingSizes` a list of OpFoldResult.
/// - an `indexingMap` that encodes how the shape of `v` varies with increases
@@ -63,6 +85,13 @@ getFullRankPaddingSizes(Builder &b, ArrayRef<OpFoldResult> indexingSizes,
/// The `indexingMap` + `indexingSizes` encoding suits StructuredOps.
/// The implementation below iteratively combines increases from contributing
/// dimensions using affine.apply operations.
+/// The padded shape is computed by evaluating the maximum accessed index per
+/// dimension, which may involve multiplying by constant factors derived from
+/// the affine indexing expressions. Currently, only a limited set of projected
+/// permutation indexing maps are supported, such as
+/// - affine_map<(d0, d1, d2) -> (d0, d1)>
+/// - affine_map<(d0, d1, d2) -> (d0, d1 + d2)>
+/// - affine_map<(d0, d1) -> (d0 * 3 + d1)>
/// In the future, more general interfaces can be devised to encode similar
/// shape evolutions and map between an op and its operands.
SmallVector<OpFoldResult> linalg::computePaddedShape(
@@ -114,24 +143,33 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
/*compressDims=*/true);
// If we are padding to the next multiple of, compose with ceil(sz) * sz.
+ OpFoldResult paddingDimOfr;
if (options.padToMultipleOf) {
AffineExpr d0, s0;
bindDims(rewriter.getContext(), d0);
bindSymbols(rewriter.getContext(), s0);
AffineMap ceilMap = AffineMap::get(1, 1, d0.ceilDiv(s0) * s0);
AffineMap composedMap = projectedMap.compose(ceilMap);
- OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply(
+ paddingDimOfr = affine::makeComposedFoldedAffineApply(
rewriter, loc, composedMap,
{indexingSizes[paddingDim], paddingSize},
/*composeAffineMin=*/true);
- terms.push_back(paddingDimOfr);
} else {
// Otherwise just set to paddingSize.
- OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply(
+ paddingDimOfr = affine::makeComposedFoldedAffineApply(
rewriter, loc, projectedMap, paddingSize);
- terms.push_back(paddingDimOfr);
}
+ // Adjust for the maximum accessed index, which is (paddingSize - 1) *
+ // multiplier.
+ AffineExpr d0;
+ bindDims(rewriter.getContext(), d0);
+ int64_t multiplier = extractConstantMultiplier(projectedMap.getResult(0));
+ AffineMap subtractMap = AffineMap::get(1, 0, d0 - multiplier);
+ OpFoldResult maxAccessIdx = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, subtractMap, {paddingDimOfr});
+ terms.push_back(maxAccessIdx);
+
LLVM_DEBUG(DBGS() << "------new term: " << terms.back() << "\n");
}
@@ -148,8 +186,9 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
AffineExpr sumExpr = dims.front();
for (unsigned i = 1; i < dims.size(); ++i)
sumExpr = sumExpr + dims[i];
- OpFoldResult paddedDimOfr =
- affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, terms);
+ // Add 1 to the maximum accessed index and get the final padded size.
+ OpFoldResult paddedDimOfr = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, sumExpr + 1, terms);
paddedShape[resultIndex] = paddedDimOfr;
}
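A worked instance of the adjusted computation, assuming indexing map `(d0, d1) -> (d0 * 3 + d1)` with padded sizes `s0` and `s1` (so `extractConstantMultiplier` returns 3 for the `d0` term and 1 for the `d1` term):

```latex
% Padded size for map (d0, d1) -> (d0 * 3 + d1):
\mathrm{term}_{d_0} = 3s_0 - 3, \qquad \mathrm{term}_{d_1} = s_1 - 1
\mathrm{padded} = \mathrm{term}_{d_0} + \mathrm{term}_{d_1} + 1
                = 3(s_0 - 1) + (s_1 - 1) + 1
```

That is, one past the maximum accessed index, which is what the `sumExpr + 1` above computes.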
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 793eec7..ea68b1a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1946,12 +1946,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
rewriter, loc, vecCollapsedType, transposeOp->getResult(0));
- // writeVectorSizes had to match the shapecast shape for dynamic sizes,
- // otherwise the validator complains that the mask size is invalid.
- SmallVector<int64_t> writeVectorSizes(
- unpackOp.getDestType().hasStaticShape()
- ? vectorSizes
- : shapeCastOp.getResultVectorType().getShape());
Operation *write = createWriteOrMaskedWrite(
rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
/*writeIndices=*/{}, useInBoundsInsteadOfMasking);
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index e73bdd3..9d5dfc1 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -2957,6 +2957,23 @@ bool acc::LoopOp::hasDefaultGangWorkerVector() {
getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static);
}
+acc::LoopParMode
+acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) {
+ if (hasSeq(deviceType))
+ return LoopParMode::loop_seq;
+ if (hasAuto(deviceType))
+ return LoopParMode::loop_auto;
+ if (hasIndependent(deviceType))
+ return LoopParMode::loop_independent;
+ if (hasSeq())
+ return LoopParMode::loop_seq;
+ if (hasAuto())
+ return LoopParMode::loop_auto;
+ assert(hasIndependent() &&
+ "loop must have default auto, seq, or independent");
+ return LoopParMode::loop_independent;
+}
+
void acc::LoopOp::addGangOperands(
MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
llvm::ArrayRef<GangArgType> argTypes, mlir::ValueRange values) {
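A sketch of how a consumer might query the effective parallelism mode; the helper falls back from the device-specific clauses to the defaults in the order shown above (the surrounding names are illustrative):

```cpp
// Sketch: choose a lowering strategy from the loop's par mode.
acc::LoopParMode mode =
    loopOp.getDefaultOrDeviceTypeParallelism(acc::DeviceType::Nvidia);
if (mode == acc::LoopParMode::loop_seq) {
  // Lower as a sequential loop.
}
```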
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 759e58b..0262a1b 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -137,6 +137,9 @@ ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,
if (parser.parseOptionalArrowTypeList(result.types))
return failure();
+ if (succeeded(parser.parseOptionalKeyword("no_inline")))
+ result.addAttribute("no_inline", parser.getBuilder().getUnitAttr());
+
// Introduce the body region and parse it.
Region *body = result.addRegion();
if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
@@ -148,8 +151,9 @@ ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,
void ExecuteRegionOp::print(OpAsmPrinter &p) {
p.printOptionalArrowTypeList(getResultTypes());
-
p << ' ';
+ if (getNoInline())
+ p << "no_inline ";
p.printRegion(getRegion(),
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/true);
@@ -184,7 +188,7 @@ struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
LogicalResult matchAndRewrite(ExecuteRegionOp op,
PatternRewriter &rewriter) const override {
- if (!op.getRegion().hasOneBlock())
+ if (!op.getRegion().hasOneBlock() || op.getNoInline())
return failure();
replaceOpWithRegion(rewriter, op, op.getRegion());
return success();
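Taken together, the parser, printer, and canonicalization changes give `no_inline` the following round-trip form, which the single-block inliner now leaves alone (a sketch):

```mlir
%r = scf.execute_region -> i32 no_inline {
  %c1 = arith.constant 1 : i32
  scf.yield %c1 : i32
}
```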
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 9bee200..fcf1526 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -693,7 +693,9 @@ static ParseResult parseStructMemberDecorations(
// `!spirv.struct<` (id `,`)?
// `(`
// (spirv-type (`[` struct-member-decoration `]`)?)*
-// `)>`
+// `)`
+// (`,` struct-decoration)?
+// `>`
static Type parseStructType(SPIRVDialect const &dialect,
DialectAsmParser &parser) {
// TODO: This function is quite lengthy. Break it down into smaller chunks.
@@ -767,17 +769,48 @@ static Type parseStructType(SPIRVDialect const &dialect,
return Type();
}
- if (failed(parser.parseRParen()) || failed(parser.parseGreater()))
+ if (failed(parser.parseRParen()))
+ return Type();
+
+ SmallVector<StructType::StructDecorationInfo, 1> structDecorationInfo;
+
+ auto parseStructDecoration = [&]() {
+ std::optional<spirv::Decoration> decoration =
+ parseAndVerify<spirv::Decoration>(dialect, parser);
+ if (!decoration)
+ return failure();
+
+ // Parse decoration value if it exists.
+ if (succeeded(parser.parseOptionalEqual())) {
+ Attribute decorationValue;
+ if (failed(parser.parseAttribute(decorationValue)))
+ return failure();
+
+ structDecorationInfo.emplace_back(decoration.value(), decorationValue);
+ } else {
+ structDecorationInfo.emplace_back(decoration.value(),
+ UnitAttr::get(dialect.getContext()));
+ }
+ return success();
+ };
+
+ while (succeeded(parser.parseOptionalComma()))
+ if (failed(parseStructDecoration()))
+ return Type();
+
+ if (failed(parser.parseGreater()))
return Type();
if (!identifier.empty()) {
if (failed(idStructTy.trySetBody(memberTypes, offsetInfo,
- memberDecorationInfo)))
+ memberDecorationInfo,
+ structDecorationInfo)))
return Type();
return idStructTy;
}
- return StructType::get(memberTypes, offsetInfo, memberDecorationInfo);
+ return StructType::get(memberTypes, offsetInfo, memberDecorationInfo,
+ structDecorationInfo);
}
// spirv-type ::= array-type
@@ -893,7 +926,23 @@ static void print(StructType type, DialectAsmPrinter &os) {
};
llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
printMember);
- os << ")>";
+ os << ")";
+
+ SmallVector<spirv::StructType::StructDecorationInfo, 1> decorations;
+ type.getStructDecorations(decorations);
+ if (!decorations.empty()) {
+ os << ", ";
+ auto eachFn = [&os](spirv::StructType::StructDecorationInfo decoration) {
+ os << stringifyDecoration(decoration.decoration);
+ if (decoration.hasValue()) {
+ os << "=";
+ os.printAttributeWithoutType(decoration.decorationValue);
+ }
+ };
+ llvm::interleaveComma(decorations, os, eachFn);
+ }
+
+ os << ">";
}
static void print(CooperativeMatrixType type, DialectAsmPrinter &os) {
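The extended grammar places struct-level decorations after the member list, each optionally valued with `=`. A sketch of the resulting round-trip syntax (decorations illustrative):

```mlir
// Unchanged: no struct decorations.
!spirv.struct<(f32 [0], i32 [4])>
// New: trailing struct decorations.
!spirv.struct<(f32 [0]), Block>
!spirv.struct<(f32 [0]), CPacked, Block>
```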
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 46739bc..ddb3426 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -835,12 +835,14 @@ void SampledImageType::getCapabilities(
/// - for literal structs:
/// - a list of member types;
/// - a list of member offset info;
-/// - a list of member decoration info.
+/// - a list of member decoration info;
+/// - a list of struct decoration info.
///
/// Identified structures only have a mutable component consisting of:
/// - a list of member types;
/// - a list of member offset info;
-/// - a list of member decoration info.
+/// - a list of member decoration info;
+/// - a list of struct decoration info.
struct spirv::detail::StructTypeStorage : public TypeStorage {
/// Construct a storage object for an identified struct type. A struct type
/// associated with such storage must call StructType::trySetBody(...) later
@@ -848,6 +850,7 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
StructTypeStorage(StringRef identifier)
: memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr),
+ numStructDecorations(0), structDecorationsInfo(nullptr),
identifier(identifier) {}
/// Construct a storage object for a literal struct type. A struct type
@@ -855,10 +858,14 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
StructTypeStorage(
unsigned numMembers, Type const *memberTypes,
StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
- StructType::MemberDecorationInfo const *memberDecorationsInfo)
+ StructType::MemberDecorationInfo const *memberDecorationsInfo,
+ unsigned numStructDecorations,
+ StructType::StructDecorationInfo const *structDecorationsInfo)
: memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
numMembers(numMembers), numMemberDecorations(numMemberDecorations),
- memberDecorationsInfo(memberDecorationsInfo) {}
+ memberDecorationsInfo(memberDecorationsInfo),
+ numStructDecorations(numStructDecorations),
+ structDecorationsInfo(structDecorationsInfo) {}
/// A storage key is divided into 2 parts:
/// - for identified structs:
@@ -867,16 +874,19 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
/// - an ArrayRef<Type> for member types;
/// - an ArrayRef<StructType::OffsetInfo> for member offset info;
/// - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
+ /// info;
+ /// - an ArrayRef<StructType::StructDecorationInfo> for struct decoration
/// info.
///
/// An identified struct type is uniqued only by the first part (field 0)
/// of the key.
///
- /// A literal struct type is uniqued only by the second part (fields 1, 2, and
- /// 3) of the key. The identifier field (field 0) must be empty.
+ /// A literal struct type is uniqued only by the second part (fields 1, 2, 3
+ /// and 4) of the key. The identifier field (field 0) must be empty.
using KeyTy =
std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
- ArrayRef<StructType::MemberDecorationInfo>>;
+ ArrayRef<StructType::MemberDecorationInfo>,
+ ArrayRef<StructType::StructDecorationInfo>>;
/// For identified structs, return true if the given key contains the same
/// identifier.
@@ -890,7 +900,7 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
}
return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
- getMemberDecorationsInfo());
+ getMemberDecorationsInfo(), getStructDecorationsInfo());
}
/// If the given key contains a non-empty identifier, this method constructs
@@ -937,9 +947,17 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
}
- return new (allocator.allocate<StructTypeStorage>())
- StructTypeStorage(keyTypes.size(), typesList, offsetInfoList,
- numMemberDecorations, memberDecorationList);
+ const StructType::StructDecorationInfo *structDecorationList = nullptr;
+ unsigned numStructDecorations = 0;
+ if (!std::get<4>(key).empty()) {
+ auto keyStructDecorations = std::get<4>(key);
+ numStructDecorations = keyStructDecorations.size();
+ structDecorationList = allocator.copyInto(keyStructDecorations).data();
+ }
+
+ return new (allocator.allocate<StructTypeStorage>()) StructTypeStorage(
+ keyTypes.size(), typesList, offsetInfoList, numMemberDecorations,
+ memberDecorationList, numStructDecorations, structDecorationList);
}
ArrayRef<Type> getMemberTypes() const {
@@ -961,6 +979,13 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
return {};
}
+ ArrayRef<StructType::StructDecorationInfo> getStructDecorationsInfo() const {
+ if (structDecorationsInfo)
+ return ArrayRef<StructType::StructDecorationInfo>(structDecorationsInfo,
+ numStructDecorations);
+ return {};
+ }
+
StringRef getIdentifier() const { return identifier; }
bool isIdentified() const { return !identifier.empty(); }
@@ -973,17 +998,19 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
/// - If called for an identified struct whose body was set before (through a
/// call to this method) but with different contents from the passed
/// arguments.
- LogicalResult mutate(
- TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
- ArrayRef<StructType::OffsetInfo> structOffsetInfo,
- ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo) {
+ LogicalResult
+ mutate(TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
+ ArrayRef<StructType::OffsetInfo> structOffsetInfo,
+ ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo,
+ ArrayRef<StructType::StructDecorationInfo> structDecorationInfo) {
if (!isIdentified())
return failure();
if (memberTypesAndIsBodySet.getInt() &&
(getMemberTypes() != structMemberTypes ||
getOffsetInfo() != structOffsetInfo ||
- getMemberDecorationsInfo() != structMemberDecorationInfo))
+ getMemberDecorationsInfo() != structMemberDecorationInfo ||
+ getStructDecorationsInfo() != structDecorationInfo))
return failure();
memberTypesAndIsBodySet.setInt(true);
@@ -1007,6 +1034,11 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
allocator.copyInto(structMemberDecorationInfo).data();
}
+ if (!structDecorationInfo.empty()) {
+ numStructDecorations = structDecorationInfo.size();
+ structDecorationsInfo = allocator.copyInto(structDecorationInfo).data();
+ }
+
return success();
}
@@ -1015,21 +1047,30 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
unsigned numMembers;
unsigned numMemberDecorations;
StructType::MemberDecorationInfo const *memberDecorationsInfo;
+ unsigned numStructDecorations;
+ StructType::StructDecorationInfo const *structDecorationsInfo;
StringRef identifier;
};
StructType
StructType::get(ArrayRef<Type> memberTypes,
ArrayRef<StructType::OffsetInfo> offsetInfo,
- ArrayRef<StructType::MemberDecorationInfo> memberDecorations) {
+ ArrayRef<StructType::MemberDecorationInfo> memberDecorations,
+ ArrayRef<StructType::StructDecorationInfo> structDecorations) {
assert(!memberTypes.empty() && "Struct needs at least one member type");
// Sort the decorations.
- SmallVector<StructType::MemberDecorationInfo, 4> sortedDecorations(
+ SmallVector<StructType::MemberDecorationInfo, 4> sortedMemberDecorations(
memberDecorations);
- llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
+ llvm::array_pod_sort(sortedMemberDecorations.begin(),
+ sortedMemberDecorations.end());
+ SmallVector<StructType::StructDecorationInfo, 1> sortedStructDecorations(
+ structDecorations);
+ llvm::array_pod_sort(sortedStructDecorations.begin(),
+ sortedStructDecorations.end());
+
return Base::get(memberTypes.vec().front().getContext(),
/*identifier=*/StringRef(), memberTypes, offsetInfo,
- sortedDecorations);
+ sortedMemberDecorations, sortedStructDecorations);
}
StructType StructType::getIdentified(MLIRContext *context,
@@ -1039,18 +1080,21 @@ StructType StructType::getIdentified(MLIRContext *context,
return Base::get(context, identifier, ArrayRef<Type>(),
ArrayRef<StructType::OffsetInfo>(),
- ArrayRef<StructType::MemberDecorationInfo>());
+ ArrayRef<StructType::MemberDecorationInfo>(),
+ ArrayRef<StructType::StructDecorationInfo>());
}
StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
StructType newStructType = Base::get(
context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
- ArrayRef<StructType::MemberDecorationInfo>());
+ ArrayRef<StructType::MemberDecorationInfo>(),
+ ArrayRef<StructType::StructDecorationInfo>());
// Set an empty body in case this is an identified struct.
if (newStructType.isIdentified() &&
failed(newStructType.trySetBody(
ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
- ArrayRef<StructType::MemberDecorationInfo>())))
+ ArrayRef<StructType::MemberDecorationInfo>(),
+ ArrayRef<StructType::StructDecorationInfo>())))
return StructType();
return newStructType;
@@ -1074,6 +1118,15 @@ TypeRange StructType::getElementTypes() const {
bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
+bool StructType::hasDecoration(spirv::Decoration decoration) const {
+ for (StructType::StructDecorationInfo info :
+ getImpl()->getStructDecorationsInfo())
+ if (info.decoration == decoration)
+ return true;
+
+ return false;
+}
+
uint64_t StructType::getMemberOffset(unsigned index) const {
assert(getNumElements() > index && "member index out of range");
return getImpl()->offsetInfo[index];
@@ -1105,11 +1158,21 @@ void StructType::getMemberDecorations(
}
}
+void StructType::getStructDecorations(
+ SmallVectorImpl<StructType::StructDecorationInfo> &structDecorations)
+ const {
+ structDecorations.clear();
+ auto implDecorations = getImpl()->getStructDecorationsInfo();
+ structDecorations.append(implDecorations.begin(), implDecorations.end());
+}
+
LogicalResult
StructType::trySetBody(ArrayRef<Type> memberTypes,
ArrayRef<OffsetInfo> offsetInfo,
- ArrayRef<MemberDecorationInfo> memberDecorations) {
- return Base::mutate(memberTypes, offsetInfo, memberDecorations);
+ ArrayRef<MemberDecorationInfo> memberDecorations,
+ ArrayRef<StructDecorationInfo> structDecorations) {
+ return Base::mutate(memberTypes, offsetInfo, memberDecorations,
+ structDecorations);
}
void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
@@ -1131,6 +1194,11 @@ llvm::hash_code spirv::hash_value(
memberDecorationInfo.decoration);
}
+llvm::hash_code spirv::hash_value(
+ const StructType::StructDecorationInfo &structDecorationInfo) {
+ return llvm::hash_value(structDecorationInfo.decoration);
+}
+
//===----------------------------------------------------------------------===//
// MatrixType
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 35ec019..8f4c4cc 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -182,6 +182,14 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
return bitWidth / 8;
}
+ // Handle 8-bit floats.
+ if (options.emulateUnsupportedFloatTypes && isa<FloatType>(type)) {
+ auto bitWidth = type.getIntOrFloatBitWidth();
+ if (bitWidth == 8)
+ return bitWidth / 8;
+ return std::nullopt;
+ }
+
if (auto complexType = dyn_cast<ComplexType>(type)) {
auto elementSize = getTypeNumBytes(options, complexType.getElementType());
if (!elementSize)
@@ -318,6 +326,44 @@ static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
type.getSignedness());
}
+/// Converts 8-bit float types to integer types with the same bit width.
+/// Returns a nullptr for unsupported 8-bit float types.
+static Type convert8BitFloatType(const SPIRVConversionOptions &options,
+ FloatType type) {
+ if (!options.emulateUnsupportedFloatTypes)
+ return nullptr;
+ // F8 types are converted to integer types with the same bit width.
+ if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
+ Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
+ Float8E8M0FNUType>(type))
+ return IntegerType::get(type.getContext(), type.getWidth());
+ LLVM_DEBUG(llvm::dbgs() << "unsupported 8-bit float type: " << type << "\n");
+ return nullptr;
+}
+
+/// Returns a type with the same shape but with any 8-bit float element type
+/// converted to an integer type of the same bit width. This is a noop when
+/// the element type is not an 8-bit float type or the emulation flag is off.
+static ShapedType
+convertShaped8BitFloatType(ShapedType type,
+ const SPIRVConversionOptions &options) {
+ if (!options.emulateUnsupportedFloatTypes)
+ return type;
+ Type srcElementType = type.getElementType();
+ Type convertedElementType = nullptr;
+ // F8 types are converted to integer types with the same bit width.
+ if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
+ Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
+ Float8E8M0FNUType>(srcElementType))
+ convertedElementType = IntegerType::get(
+ type.getContext(), srcElementType.getIntOrFloatBitWidth());
+
+ if (!convertedElementType)
+ return type;
+
+ return type.clone(convertedElementType);
+}
+
/// Returns a type with the same shape but with any index element type converted
/// to the matching integer type. This is a noop when the element type is not
/// the index type.
@@ -337,6 +383,7 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
const SPIRVConversionOptions &options, VectorType type,
std::optional<spirv::StorageClass> storageClass = {}) {
type = cast<VectorType>(convertIndexElementType(type, options));
+ type = cast<VectorType>(convertShaped8BitFloatType(type, options));
auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
if (!scalarType) {
// If this is not a spec allowed scalar type, try to handle sub-byte integer
@@ -433,6 +480,7 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
}
type = cast<TensorType>(convertIndexElementType(type, options));
+ type = cast<TensorType>(convertShaped8BitFloatType(type, options));
auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
if (!scalarType) {
LLVM_DEBUG(llvm::dbgs()
@@ -596,6 +644,10 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
} else if (auto indexType = dyn_cast<IndexType>(elementType)) {
type = cast<MemRefType>(convertIndexElementType(type, options));
arrayElemType = type.getElementType();
+ } else if (auto floatType = dyn_cast<FloatType>(elementType)) {
+ // Handle 8-bit float types.
+ type = cast<MemRefType>(convertShaped8BitFloatType(type, options));
+ arrayElemType = type.getElementType();
} else {
LLVM_DEBUG(
llvm::dbgs()
@@ -1444,6 +1496,8 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
addConversion([this](FloatType floatType) -> std::optional<Type> {
if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
return convertScalarType(this->targetEnv, this->options, scalarType);
+ if (floatType.getWidth() == 8)
+ return convert8BitFloatType(this->options, floatType);
return Type();
});
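A minimal usage sketch of the emulation path added above, assuming `emulateUnsupportedFloatTypes` is set; the driver function itself is hypothetical and only strings together calls defined in this patch.

    static void exampleF8Emulation(MLIRContext *ctx) {
      SPIRVConversionOptions options;
      options.emulateUnsupportedFloatTypes = true;
      // Scalar: f8E4M3FN (8 bits wide) becomes i8.
      Type i8Ty = convert8BitFloatType(options, Float8E4M3FNType::get(ctx));
      // Shaped: vector<4xf8E5M2> becomes vector<4xi8>; non-f8 element types
      // pass through unchanged.
      auto v4f8 = VectorType::get({4}, Float8E5M2Type::get(ctx));
      ShapedType v4i8 = convertShaped8BitFloatType(v4f8, options);
      (void)i8Ty;
      (void)v4i8;
    }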
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
index 6a9b951..a53d0a7 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
@@ -174,6 +174,21 @@ void UpdateVCEPass::runOnOperation() {
if (walkResult.wasInterrupted())
return signalPassFailure();
+ // Update min version requirement for capabilities after deducing them.
+ for (spirv::Capability cap : deducedCapabilities) {
+ if (std::optional<spirv::Version> minVersion = spirv::getMinVersion(cap)) {
+ deducedVersion = std::max(deducedVersion, *minVersion);
+ if (deducedVersion > allowedVersion) {
+ module.emitError("Capability '")
+ << spirv::stringifyCapability(cap) << "' requires min version "
+ << spirv::stringifyVersion(deducedVersion)
+ << " but target environment allows up to "
+ << spirv::stringifyVersion(allowedVersion);
+ return signalPassFailure();
+ }
+ }
+ }
+
// TODO: verify that the deduced version is consistent with
// SPIR-V ops' maximal version requirements.
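A worked example of the new check, with an assumed capability `CapX` (the name and versions are placeholders, not taken from the patch):

    // getMinVersion(CapX) == v1.5
    //   allowedVersion == v1.3: emit "Capability 'CapX' requires min version
    //     1.5 but target environment allows up to 1.3" and fail the pass.
    //   allowedVersion == v1.6: deducedVersion = max(deducedVersion, v1.5),
    //     and the VCE triple is computed as before.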
diff --git a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
index e5a3b5d..08fccfa 100644
--- a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
+++ b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
@@ -38,7 +38,6 @@
#include <utility>
#define DEBUG_TYPE "shard-ops"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
using namespace mlir;
using namespace mlir::shard;
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 88b0f36..9543fa1 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -464,9 +464,12 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
CheckCondition condition = CheckCondition::invalid;
const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);
+ if (failed(maybeProfDef) && failed(maybeExtDef))
+ return success();
- if (!failed(maybeProfDef) && !failed(maybeExtDef) &&
- !maybeProfDef.value().size() && !maybeExtDef.value().size()) {
+ const bool hasEntry = (succeeded(maybeProfDef) && !maybeProfDef->empty()) ||
+ (succeeded(maybeExtDef) && !maybeExtDef->empty());
+ if (!hasEntry) {
std::string message;
llvm::raw_string_ostream os(message);
os << "illegal: operation operand/result data types did not align with any "
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8789f55..a21b5ba 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5916,14 +5916,13 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
}
// shape_cast(constant) -> constant
- if (auto splatAttr =
- llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
- return splatAttr.reshape(getType());
+ if (auto denseAttr =
+ dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()))
+ return denseAttr.reshape(getType());
// shape_cast(poison) -> poison
- if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
+ if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource()))
return ub::PoisonAttr::get(getContext());
- }
return {};
}
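A sketch of the broadened fold; the IR below is illustrative.

    // Previously only splat constants folded, e.g. dense<1.0>. Now any dense
    // constant is reshaped in place:
    //   %c = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32>
    //   %r = vector.shape_cast %c : vector<4xi32> to vector<2x2xi32>
    // folds to:
    //   %r = arith.constant dense<[[1, 2], [3, 4]]> : vector<2x2xi32>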
@@ -6316,6 +6315,11 @@ std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
return llvm::to_vector<4>(getResultVectorType().getShape());
}
+void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRanges) {
+ setResultRanges(getResult(), argRanges.front());
+}
+
namespace {
// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
@@ -7198,6 +7202,23 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
}
//===----------------------------------------------------------------------===//
+// StepOp
+//===----------------------------------------------------------------------===//
+
+void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRanges) {
+ auto resultType = cast<VectorType>(getType());
+ if (resultType.isScalable()) {
+ return;
+ }
+ unsigned bitwidth = ConstantIntRanges::getStorageBitwidth(resultType);
+ APInt zero(bitwidth, 0);
+ APInt high(bitwidth, resultType.getDimSize(0) - 1);
+ ConstantIntRanges result = {zero, high, zero, high};
+ setResultRanges(getResult(), result);
+}
+
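A concrete instance of the range inference above (fixed-length vectors only):

    // %s = vector.step : vector<8xindex> defines the lanes [0, 1, ..., 7], so
    // the inferred range is umin = smin = 0 and umax = smax = 7. Scalable
    // vectors are skipped because getDimSize(0) is only a lower bound on the
    // lane count.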
+//===----------------------------------------------------------------------===//
// Vector Masking Utilities
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index cb8e566..dedc3b3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -28,7 +28,10 @@ using namespace mlir;
using namespace mlir::vector;
namespace {
-/// Progressive lowering of BroadcastOp.
+
+/// Convert a vector.broadcast with a vector operand to a lower rank
+/// vector.broadcast. vector.broadcast with a scalar operand is expected to be
+/// convertible to the lower level target dialect (LLVM, SPIR-V, etc.) directly.
class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
public:
using OpRewritePattern::OpRewritePattern;
@@ -40,20 +43,23 @@ public:
VectorType srcType = dyn_cast<VectorType>(op.getSourceType());
Type eltType = dstType.getElementType();
- // Scalar to any vector can use splat.
- if (!srcType) {
- rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.getSource());
- return success();
- }
+ // A broadcast from a scalar is considered to be in the lowered form.
+ if (!srcType)
+ return rewriter.notifyMatchFailure(
+ op, "broadcast from scalar already in lowered form");
// Determine rank of source and destination.
int64_t srcRank = srcType.getRank();
int64_t dstRank = dstType.getRank();
- // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
+ // Here we are broadcasting to a rank-1 vector. Ensure that the source is a
+ // scalar.
if (srcRank <= 1 && dstRank == 1) {
- Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource());
- rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
+ SmallVector<int64_t> fullRankPosition(srcRank, 0);
+ Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(),
+ fullRankPosition);
+ assert(!isa<VectorType>(ext.getType()) && "expected scalar");
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, dstType, ext);
return success();
}
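A before/after sketch of the rank-1 path above (IR illustrative):

    // Input:
    //   %r = vector.broadcast %v : vector<1xf32> to vector<4xf32>
    // Rewritten to (vector.splat is no longer produced):
    //   %e = vector.extract %v[0] : f32 from vector<1xf32>
    //   %r = vector.broadcast %e : f32 to vector<4xf32>
    // A broadcast whose source is already a scalar is left as-is for the
    // target dialect (LLVM, SPIR-V, etc.) to lower directly.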
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 4baeb11..2cf8f0b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -468,7 +468,7 @@ struct TransferReadToVectorLoadLowering
read, "vector type is not rank 1, can't create masked load, needs "
"VectorToSCF");
- Value fill = vector::SplatOp::create(
+ Value fill = vector::BroadcastOp::create(
rewriter, read.getLoc(), unbroadcastedVectorType, read.getPadding());
res = vector::MaskedLoadOp::create(
rewriter, read.getLoc(), unbroadcastedVectorType, read.getBase(),
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index 72352d7..cbb9d4b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -303,7 +303,7 @@ public:
// Extract/insert on a lower ranked extract strided slice op.
Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
rewriter.getZeroAttr(elemType));
- Value res = SplatOp::create(rewriter, loc, dstType, zero);
+ Value res = BroadcastOp::create(rewriter, loc, dstType, zero);
for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
off += stride, ++idx) {
Value one = ExtractOp::create(rewriter, loc, op.getVector(), off);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 48d680c..c707f38 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -25,12 +25,10 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#define DEBUG_TYPE "vector-transfer-opt"
-#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
-
using namespace mlir;
/// Return the ancestor op in the region or nullptr if the region is not
@@ -88,8 +86,7 @@ bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
/// transfer_write is dead if all reads that can be reached from the potentially
/// dead transfer_write are dominated by the overwriting transfer_write.
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
- LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
- << "\n");
+ LDBG() << "Candidate for dead store: " << *write.getOperation();
llvm::SmallVector<Operation *, 8> blockingAccesses;
Operation *firstOverwriteCandidate = nullptr;
Value source = memref::skipViewLikeOps(cast<MemrefValue>(write.getBase()));
@@ -150,13 +147,12 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
!isReachable(writeAncestor, accessAncestor))
continue;
if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
- LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
- << *accessAncestor << "\n");
+ LDBG() << "Store may not be dead due to op: " << *accessAncestor;
return;
}
}
- LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
- << " overwritten by: " << *firstOverwriteCandidate << "\n");
+ LDBG() << "Found dead store: " << *write.getOperation()
+ << " overwritten by: " << *firstOverwriteCandidate;
opToErase.push_back(write.getOperation());
}
@@ -174,8 +170,7 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
if (read.hasOutOfBoundsDim())
return;
- LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
- << "\n");
+ LDBG() << "Candidate for Forwarding: " << *read.getOperation();
SmallVector<Operation *, 8> blockingWrites;
vector::TransferWriteOp lastwrite = nullptr;
Value source = memref::skipViewLikeOps(cast<MemrefValue>(read.getBase()));
@@ -230,14 +225,13 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
continue;
if (!postDominators.postDominates(lastwrite, write)) {
- LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
- << *write << "\n");
+ LDBG() << "Fail to do write to read forwarding due to op: " << *write;
return;
}
}
- LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
- << " to: " << *read.getOperation() << "\n");
+ LDBG() << "Forward value from " << *lastwrite.getOperation()
+ << " to: " << *read.getOperation();
read.replaceAllUsesWith(lastwrite.getVector());
opToErase.push_back(read.getOperation());
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 8de87fe..2269a40 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -939,7 +939,7 @@ public:
Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
rewriter.getZeroAttr(elemType));
- Value res = SplatOp::create(rewriter, loc, castDstType, zero);
+ Value res = BroadcastOp::create(rewriter, loc, castDstType, zero);
SmallVector<int64_t> sliceShape = {castDstLastDim};
SmallVector<int64_t> strides = {1};
@@ -965,6 +965,45 @@ private:
std::function<bool(BitCastOp)> controlFn;
};
+static bool haveSameShapeAndScaling(Type t, Type u) {
+ auto tVec = dyn_cast<VectorType>(t);
+ auto uVec = dyn_cast<VectorType>(u);
+ if (!tVec) {
+ return !uVec;
+ }
+ if (!uVec) {
+ return false;
+ }
+ return tVec.getShape() == uVec.getShape() &&
+ tVec.getScalableDims() == uVec.getScalableDims();
+}
+
+/// If `type` is shaped, clone it with `newElementType`. Otherwise,
+/// return `newElementType`.
+static Type cloneOrReplace(Type type, Type newElementType) {
+ if (auto shapedType = dyn_cast<ShapedType>(type)) {
+ return shapedType.clone(newElementType);
+ }
+ return newElementType;
+}
+
+/// If `value` is the result of a splat or broadcast operation, return the input
+/// of the splat/broadcast operation.
+static Value getBroadcastLikeSource(Value value) {
+ Operation *op = value.getDefiningOp();
+ if (!op)
+ return {};
+
+ if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
+ return broadcast.getSource();
+
+ if (auto splat = dyn_cast<vector::SplatOp>(op))
+ return splat.getInput();
+
+ return {};
+}
+
/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
///
/// Example:
@@ -988,16 +1027,14 @@ struct ReorderElementwiseOpsOnBroadcast final
PatternRewriter &rewriter) const override {
if (op->getNumResults() != 1)
return failure();
- if (!llvm::isa<ShapedType>(op->getResults()[0].getType()))
+ auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
+ if (!resultType)
return failure();
if (!OpTrait::hasElementwiseMappableTraits(op))
return rewriter.notifyMatchFailure(
op, "Op doesn't have ElementwiseMappableTraits");
if (op->getNumOperands() == 0)
return failure();
- if (op->getResults()[0].getType() != op->getOperand(0).getType())
- return rewriter.notifyMatchFailure(op,
- "result and operand type mismatch");
if (isa<vector::FMAOp>(op)) {
return rewriter.notifyMatchFailure(
op,
@@ -1005,45 +1042,71 @@ struct ReorderElementwiseOpsOnBroadcast final
"might be a scalar");
}
- // Get the type of the lhs operand
- auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
- if (!lhsBcastOrSplat ||
- !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
+ Type resultElemType = resultType.getElementType();
+
+ // Get the type of the first non-constant operand
+ Value splatSource;
+ for (Value operand : op->getOperands()) {
+ Operation *definingOp = operand.getDefiningOp();
+ if (!definingOp)
+ return failure();
+ if (definingOp->hasTrait<OpTrait::ConstantLike>())
+ continue;
+ splatSource = getBroadcastLikeSource(operand);
+ break;
+ }
+ if (!splatSource)
return failure();
- auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
+ Type unbroadcastResultType =
+ cloneOrReplace(splatSource.getType(), resultElemType);
- // Make sure that all operands are broadcast from identical types:
+ // Make sure that all operands are broadcast from identically-shaped types:
// * scalar (`vector.broadcast` + `vector.splat`), or
// * vector (`vector.broadcast`).
// Otherwise the re-ordering wouldn't be safe.
- if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
- auto bcast = val.getDefiningOp<vector::BroadcastOp>();
- if (bcast)
- return (bcast.getOperand().getType() == lhsBcastOrSplatType);
- auto splat = val.getDefiningOp<vector::SplatOp>();
- if (splat)
- return (splat.getOperand().getType() == lhsBcastOrSplatType);
- return false;
+ if (!llvm::all_of(op->getOperands(), [splatSource](Value val) {
+ if (auto source = getBroadcastLikeSource(val))
+ return haveSameShapeAndScaling(source.getType(),
+ splatSource.getType());
+ SplatElementsAttr splatConst;
+ return matchPattern(val, m_Constant(&splatConst));
})) {
- return failure();
+ return rewriter.notifyMatchFailure(
+ op,
+ "not all operands are constants or broadcasts from the same type");
}
// Collect the source values before broadcasting
SmallVector<Value> srcValues;
srcValues.reserve(op->getNumOperands());
for (Value operand : op->getOperands()) {
- srcValues.push_back(operand.getDefiningOp()->getOperand(0));
+ SplatElementsAttr splatConst;
+ if (matchPattern(operand, m_Constant(&splatConst))) {
+ Attribute newConst;
+ Type elementType = getElementTypeOrSelf(operand.getType());
+ Type newType = cloneOrReplace(unbroadcastResultType, elementType);
+ if (auto newTypeShaped = dyn_cast<ShapedType>(newType)) {
+ newConst = splatConst.resizeSplat(newTypeShaped);
+ } else {
+ newConst = splatConst.getSplatValue<Attribute>();
+ }
+ Operation *newConstOp =
+ operand.getDefiningOp()->getDialect()->materializeConstant(
+ rewriter, newConst, newType, operand.getLoc());
+ srcValues.push_back(newConstOp->getResult(0));
+ } else {
+ srcValues.push_back(operand.getDefiningOp()->getOperand(0));
+ }
}
// Create the "elementwise" Op
Operation *elementwiseOp =
rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
- lhsBcastOrSplatType, op->getAttrs());
+ unbroadcastResultType, op->getAttrs());
// Replace the original Op with the elementwise Op
- auto vectorType = op->getResultTypes()[0];
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
- op, vectorType, elementwiseOp->getResults());
+ op, resultType, elementwiseOp->getResults());
return success();
}
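A sketch of the generalized reordering, which now also tolerates splat-constant operands (IR illustrative):

    // Input:
    //   %b = vector.broadcast %x : f32 to vector<4xf32>
    //   %c = arith.constant dense<2.0> : vector<4xf32>
    //   %r = arith.mulf %b, %c : vector<4xf32>
    // Rewritten to:
    //   %s = arith.constant 2.000000e+00 : f32
    //   %m = arith.mulf %x, %s : f32
    //   %r = vector.broadcast %m : f32 to vector<4xf32>
    // The splat constant is resized to the unbroadcast type via
    // resizeSplat/getSplatValue before the elementwise op is rebuilt.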
@@ -1239,15 +1302,17 @@ public:
return rewriter.notifyMatchFailure(
op, "only 1-element vectors are supported");
- Operation *splat = op.getValueToStore().getDefiningOp();
- if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
- return rewriter.notifyMatchFailure(op, "neither a splat nor a broadcast");
+ Value toStore = op.getValueToStore();
+ Value source = getBroadcastLikeSource(toStore);
+ if (!source)
+ return rewriter.notifyMatchFailure(
+ op, "value to store is not from a broadcast");
// Checking for single use so we can remove splat.
+ Operation *splat = toStore.getDefiningOp();
if (!splat->hasOneUse())
return rewriter.notifyMatchFailure(op, "expected single op use");
- Value source = splat->getOperand(0);
Value base = op.getBase();
ValueRange indices = op.getIndices();
@@ -1297,13 +1362,13 @@ static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
// Add in an offset if requested.
if (off) {
Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
- Value ov = vector::SplatOp::create(rewriter, loc, indices.getType(), o);
+ Value ov = vector::BroadcastOp::create(rewriter, loc, indices.getType(), o);
indices = arith::AddIOp::create(rewriter, loc, ov, indices);
}
// Construct the vector comparison.
Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
Value bounds =
- vector::SplatOp::create(rewriter, loc, indices.getType(), bound);
+ vector::BroadcastOp::create(rewriter, loc, indices.getType(), bound);
return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
indices, bounds);
}
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 704deea..33450f3 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -110,6 +110,34 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
return success();
}
+static LogicalResult
+isValidGatherScatterBufferParams(Type maskTy, VectorType valueTy,
+ int64_t chunkSize,
+ function_ref<InFlightDiagnostic()> emitError) {
+
+ if (!valueTy)
+ return emitError() << "Expecting a vector type result.";
+
+ auto maskShape = getShapeOf(maskTy);
+ auto valueShape = getShapeOf(valueTy);
+
+ // A valid shape for the SIMT case.
+ if (valueTy.getRank() == 1) {
+ if (valueTy.getNumElements() != chunkSize)
+ return emitError() << "value elements must match chunk size " << chunkSize
+ << " for SIMT code.";
+ return success();
+ }
+
+ llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
+ if (chunkSize > 1)
+ expectedMaskShape.pop_back();
+ if (expectedMaskShape != maskShape)
+ return emitError() << "Mask should match value except the chunk size dim.";
+
+ return success();
+}
+
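Two shape combinations accepted by the check above, for illustration:

    // SIMT (rank-1 value): value vector<4xf32> with chunk_size = 4 holds one
    // lane's chunk; any other element count is rejected.
    // Otherwise: value vector<16x4xf32> with chunk_size = 4 drops the
    // trailing chunk dim, so the mask must be vector<16xi1>; with
    // chunk_size = 1 the mask shape must equal the value shape.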
//===----------------------------------------------------------------------===//
// XeGPU_CreateNdDescOp
//===----------------------------------------------------------------------===//
@@ -644,9 +672,14 @@ LogicalResult CreateDescOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult PrefetchOp::verify() {
auto tdescTy = getTensorDescType();
- if (!tdescTy.isScattered())
+
+ if (tdescTy && !tdescTy.isScattered())
return emitOpError("Expects a scattered TensorDesc.\n");
+ if (!tdescTy && getRankOf(getSource()) > 1)
+ return emitOpError(
+ "Expecting the source is a 1D memref or pointer (uint64_t).");
+
if (!isReadHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -659,6 +692,13 @@ LogicalResult PrefetchOp::verify() {
return success();
}
+void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source,
+ xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint) {
+ build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint);
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_LoadGatherOp
//===----------------------------------------------------------------------===//
@@ -667,6 +707,13 @@ LogicalResult LoadGatherOp::verify() {
auto maskTy = getMaskType();
auto valueTy = getValueType();
+ if (tdescTy && !tdescTy.isScattered())
+ return emitOpError("Expects a scattered TensorDesc.");
+
+ if (!tdescTy && getRankOf(getSource()) > 1)
+ return emitOpError(
+ "Expecting the source is a 1D memref or pointer (uint64_t).");
+
if (!isReadHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -676,8 +723,27 @@ LogicalResult LoadGatherOp::verify() {
if (!isReadHintOrNone(getL3HintAttr()))
return emitOpError("invalid l3_hint: ") << getL3HintAttr();
- return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
- [&]() { return emitOpError(); });
+ if (tdescTy)
+ return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
+ [&]() { return emitOpError(); });
+ auto srcTy = getSourceType();
+ int64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
+ auto memTy = dyn_cast<MemRefType>(srcTy);
+
+ if (memTy && (valueTy.getElementType() != memTy.getElementType()))
+ return emitError() << "Value should have the same element type as MemRef.";
+
+ return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize,
+ [&]() { return emitOpError(); });
+}
+
+void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
+ Type valueType, Value source, Value mask,
+ xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint) {
+ build(builder, state, valueType, source, Value(), mask, IntegerAttr(),
+ l1_hint, l2_hint, l3_hint);
}
//===----------------------------------------------------------------------===//
@@ -688,6 +754,13 @@ LogicalResult StoreScatterOp::verify() {
auto maskTy = getMaskType();
auto valueTy = getValueType();
+ if (tdescTy && !tdescTy.isScattered())
+ return emitOpError("Expects a scattered TensorDesc.\n");
+
+ if (!tdescTy && getRankOf(getDest()) > 1)
+ return emitOpError(
+ "Expecting the dest is a 1D memref or pointer (uint64_t).");
+
if (!isWriteHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -697,8 +770,28 @@ LogicalResult StoreScatterOp::verify() {
if (!isWriteHintOrNone(getL3HintAttr()))
return emitOpError("invalid l3_hint: ") << getL3HintAttr();
- return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
- [&]() { return emitOpError(); });
+ if (tdescTy)
+ return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
+ [&]() { return emitOpError(); });
+
+ auto destTy = getDestType();
+ int64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
+ auto memTy = dyn_cast<MemRefType>(destTy);
+
+ if (memTy && (valueTy.getElementType() != memTy.getElementType()))
+ return emitError() << "Value should have the same element type as MemRef.";
+
+ return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize,
+ [&]() { return emitOpError(); });
+}
+
+void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
+ Value value, Value dest, Value mask,
+ xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint) {
+ build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,
+ l2_hint, l3_hint);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index ec8fad4..c793b71 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -481,7 +481,8 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
xegpu::TensorDescType tdescTy = op.getTensorDescType();
- if (!tdescTy.isScattered())
+ // TODO: handle the unstructured source case (!tdescTy).
+ if (!tdescTy || op.getOffsets())
return failure();
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -543,7 +544,8 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
Location loc = op.getLoc();
xegpu::TensorDescType tdescTy = op.getTensorDescType();
- if (!tdescTy.isScattered())
+ // TODO: handle the unstructured source case (!tdescTy).
+ if (!tdescTy || op.getOffsets())
return failure();
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -572,7 +574,8 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
xegpu::TensorDescType tdescTy = op.getTensorDescType();
- if (!tdescTy.isScattered())
+ // TODO: handle the unstructured source case (!tdescTy).
+ if (!tdescTy || op.getOffsets())
return failure();
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index f95ad29..de52fbd 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -40,7 +40,7 @@
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/Regex.h"
@@ -2070,9 +2070,8 @@ static OpPrintingFlags verifyOpAndAdjustFlags(Operation *op,
return failure();
});
if (failed(verify(op))) {
- LLVM_DEBUG(llvm::dbgs()
- << DEBUG_TYPE << ": '" << op->getName()
- << "' failed to verify and will be printed in generic form\n");
+ LDBG() << "'" << op->getName()
+ << "' failed to verify and will be printed in generic form";
printerFlags.printGenericOpForm();
}
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index e9b5e92..310680b 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -17,14 +17,32 @@
using namespace mlir;
+/// Returns the line and column at which the last buffer in the source manager
+/// starts within the main buffer, or {0, 0} if the buffers are not nested.
+static std::pair<int64_t, int64_t>
+getLineAndColStart(const llvm::SourceMgr &sourceMgr) {
+ unsigned lastFileID = sourceMgr.getNumBuffers();
+ if (lastFileID == 1)
+ return {0, 0};
+
+ auto bufferID = sourceMgr.getMainFileID();
+ const llvm::MemoryBuffer *main = sourceMgr.getMemoryBuffer(bufferID);
+ const llvm::MemoryBuffer *last = sourceMgr.getMemoryBuffer(lastFileID);
+ // Use the last buffer only when it starts strictly inside the main buffer;
+ // a buffer sharing the main buffer's start is excluded.
+ if (main->getBufferStart() < last->getBufferStart() &&
+ main->getBufferEnd() >= last->getBufferEnd()) {
+ return sourceMgr.getLineAndColumn(
+ llvm::SMLoc::getFromPointer(last->getBufferStart()), bufferID);
+ }
+ return {0, 0};
+}
+
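An illustrative consequence, assuming a tool that parses several chunks of one file through the same SourceMgr (the scenario is an assumption, not part of this patch):

    // The last buffer is a chunk starting at line 42 of the main buffer:
    //   getLineAndColStart(sourceMgr) ==> {42, 1}
    // A single-buffer SourceMgr, or non-nested buffers:
    //   getLineAndColStart(sourceMgr) ==> {0, 0}
    // so the synthesized FileLineColLoc below points at the chunk's real
    // start instead of 0:0.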
LogicalResult mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr,
Block *block, const ParserConfig &config,
LocationAttr *sourceFileLoc) {
const auto *sourceBuf = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID());
if (sourceFileLoc) {
- *sourceFileLoc = FileLineColLoc::get(config.getContext(),
- sourceBuf->getBufferIdentifier(),
- /*line=*/0, /*column=*/0);
+ auto [line, column] = getLineAndColStart(sourceMgr);
+ *sourceFileLoc = FileLineColLoc::get(
+ config.getContext(), sourceBuf->getBufferIdentifier(), line, column);
}
if (isBytecode(*sourceBuf))
return readBytecodeFile(*sourceBuf, block, config);
@@ -37,9 +55,9 @@ mlir::parseSourceFile(const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
const auto *sourceBuf =
sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID());
if (sourceFileLoc) {
- *sourceFileLoc = FileLineColLoc::get(config.getContext(),
- sourceBuf->getBufferIdentifier(),
- /*line=*/0, /*column=*/0);
+ auto [line, column] = getLineAndColStart(*sourceMgr);
+ *sourceFileLoc = FileLineColLoc::get(
+ config.getContext(), sourceBuf->getBufferIdentifier(), line, column);
}
if (isBytecode(*sourceBuf))
return readBytecodeFile(sourceMgr, block, config);
diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt
index af22a7f..9ea5c683 100644
--- a/mlir/lib/Target/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt
@@ -60,6 +60,7 @@ add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration
MLIRROCDLToLLVMIRTranslation
MLIRSPIRVToLLVMIRTranslation
MLIRVCIXToLLVMIRTranslation
+ MLIRXeVMToLLVMIRTranslation
)
add_mlir_translation_library(MLIRTargetLLVMIRImport
diff --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
index f030fa7..86c731a 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
@@ -10,3 +10,4 @@ add_subdirectory(OpenMP)
add_subdirectory(ROCDL)
add_subdirectory(SPIRV)
add_subdirectory(VCIX)
+add_subdirectory(XeVM)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index ff34a08..0f675a0 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -13,6 +13,7 @@
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
@@ -136,46 +137,6 @@ convertOperandBundles(OperandRangeRange bundleOperands,
return convertOperandBundles(bundleOperands, *bundleTags, moduleTranslation);
}
-static LogicalResult
-convertParameterAndResultAttrs(mlir::Location loc, ArrayAttr argAttrsArray,
- ArrayAttr resAttrsArray, llvm::CallBase *call,
- LLVM::ModuleTranslation &moduleTranslation) {
- if (argAttrsArray) {
- for (auto [argIdx, argAttrsAttr] : llvm::enumerate(argAttrsArray)) {
- if (auto argAttrs = cast<DictionaryAttr>(argAttrsAttr);
- !argAttrs.empty()) {
- FailureOr<llvm::AttrBuilder> attrBuilder =
- moduleTranslation.convertParameterAttrs(loc, argAttrs);
- if (failed(attrBuilder))
- return failure();
- call->addParamAttrs(argIdx, *attrBuilder);
- }
- }
- }
-
- if (resAttrsArray && resAttrsArray.size() > 0) {
- if (resAttrsArray.size() != 1)
- return mlir::emitError(loc, "llvm.func cannot have multiple results");
- if (auto resAttrs = cast<DictionaryAttr>(resAttrsArray[0]);
- !resAttrs.empty()) {
- FailureOr<llvm::AttrBuilder> attrBuilder =
- moduleTranslation.convertParameterAttrs(loc, resAttrs);
- if (failed(attrBuilder))
- return failure();
- call->addRetAttrs(*attrBuilder);
- }
- }
- return success();
-}
-
-static LogicalResult
-convertParameterAndResultAttrs(CallOpInterface callOp, llvm::CallBase *call,
- LLVM::ModuleTranslation &moduleTranslation) {
- return convertParameterAndResultAttrs(
- callOp.getLoc(), callOp.getArgAttrsAttr(), callOp.getResAttrsAttr(), call,
- moduleTranslation);
-}
-
/// Builder for LLVM_CallIntrinsicOp
static LogicalResult
convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
@@ -243,9 +204,7 @@ convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
convertOperandBundles(op.getOpBundleOperands(), op.getOpBundleTags(),
moduleTranslation));
- if (failed(convertParameterAndResultAttrs(op.getLoc(), op.getArgAttrsAttr(),
- op.getResAttrsAttr(), inst,
- moduleTranslation)))
+ if (failed(moduleTranslation.convertArgAndResultAttrs(op, inst)))
return failure();
if (op.getNumResults() == 1)
@@ -455,7 +414,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
if (callOp.getInlineHintAttr())
call->addFnAttr(llvm::Attribute::InlineHint);
- if (failed(convertParameterAndResultAttrs(callOp, call, moduleTranslation)))
+ if (failed(moduleTranslation.convertArgAndResultAttrs(callOp, call)))
return failure();
if (MemoryEffectsAttr memAttr = callOp.getMemoryEffectsAttr()) {
@@ -569,8 +528,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
operandsRef.drop_front(), opBundles);
}
result->setCallingConv(convertCConvToLLVM(invOp.getCConv()));
- if (failed(
- convertParameterAndResultAttrs(invOp, result, moduleTranslation)))
+ if (failed(moduleTranslation.convertArgAndResultAttrs(invOp, result)))
return failure();
moduleTranslation.mapBranch(invOp, result);
// InvokeOp can only have 0 or 1 result
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp
index 1c9e226..55e73e8 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp
@@ -13,6 +13,7 @@
#include "mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Target/LLVMIR/ModuleImport.h"
+#include "llvm/IR/ConstantRange.h"
using namespace mlir;
using namespace mlir::NVVM;
diff --git a/mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt
new file mode 100644
index 0000000..6308d7e
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt
@@ -0,0 +1,21 @@
+set(LLVM_OPTIONAL_SOURCES
+ XeVMToLLVMIRTranslation.cpp
+)
+
+add_mlir_translation_library(MLIRXeVMToLLVMIRTranslation
+ XeVMToLLVMIRTranslation.cpp
+
+ DEPENDS
+ MLIRXeVMConversionsIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRDialectUtils
+ MLIRIR
+ MLIRLLVMDialect
+ MLIRXeVMDialect
+ MLIRSupport
+ MLIRTargetLLVMIRExport
+)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp
new file mode 100644
index 0000000..73b166d
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp
@@ -0,0 +1,103 @@
+//===-- XeVMToLLVMIRTranslation.cpp - Translate XeVM to LLVM IR -*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation between the MLIR XeVM dialect and
+// LLVM IR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Metadata.h"
+
+#include "llvm/IR/ConstantRange.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+namespace {
+/// Implementation of the dialect interface that converts operations belonging
+/// to the XeVM dialect to LLVM IR.
+class XeVMDialectLLVMIRTranslationInterface
+ : public LLVMTranslationDialectInterface {
+public:
+ using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
+
+ /// Attaches module-level metadata for functions marked as kernels.
+ LogicalResult
+ amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
+ NamedAttribute attribute,
+ LLVM::ModuleTranslation &moduleTranslation) const final {
+ StringRef attrName = attribute.getName().getValue();
+ if (attrName == mlir::xevm::XeVMDialect::getCacheControlsAttrName()) {
+ auto cacheControlsArray = dyn_cast<ArrayAttr>(attribute.getValue());
+ if (!cacheControlsArray || cacheControlsArray.size() != 2) {
+ return op->emitOpError(
+ "Expected both L1 and L3 cache control attributes!");
+ }
+ if (instructions.size() != 1) {
+ return op->emitOpError("Expecting a single instruction");
+ }
+ return handleDecorationCacheControl(instructions.front(),
+ cacheControlsArray.getValue());
+ }
+ auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
+ if (!func)
+ return failure();
+
+ return success();
+ }
+
+private:
+ static LogicalResult handleDecorationCacheControl(llvm::Instruction *inst,
+ ArrayRef<Attribute> attrs) {
+ SmallVector<llvm::Metadata *> decorations;
+ llvm::LLVMContext &ctx = inst->getContext();
+ llvm::Type *i32Ty = llvm::IntegerType::getInt32Ty(ctx);
+ llvm::transform(
+ attrs, std::back_inserter(decorations),
+ [&ctx, i32Ty](Attribute attr) -> llvm::Metadata * {
+ auto valuesArray = cast<ArrayAttr>(attr).getValue();
+ std::array<llvm::Metadata *, 4> metadata;
+ llvm::transform(
+ valuesArray, metadata.begin(), [i32Ty](Attribute valueAttr) {
+ return llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(
+ i32Ty, cast<IntegerAttr>(valueAttr).getValue()));
+ });
+ return llvm::MDNode::get(ctx, metadata);
+ });
+ constexpr llvm::StringLiteral decorationCacheControlMDName =
+ "spirv.DecorationCacheControlINTEL";
+ inst->setMetadata(decorationCacheControlMDName,
+ llvm::MDNode::get(ctx, decorations));
+ return success();
+ }
+};
+} // namespace
+
+void mlir::registerXeVMDialectTranslation(::mlir::DialectRegistry &registry) {
+ registry.insert<xevm::XeVMDialect>();
+ registry.addExtension(+[](MLIRContext *ctx, xevm::XeVMDialect *dialect) {
+ dialect->addInterfaces<XeVMDialectLLVMIRTranslationInterface>();
+ });
+}
+
+void mlir::registerXeVMDialectTranslation(::mlir::MLIRContext &context) {
+ DialectRegistry registry;
+ registerXeVMDialectTranslation(registry);
+ context.appendDialectRegistry(registry);
+}
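A sketch of the metadata shape produced by handleDecorationCacheControl above. The LLVM IR is shown in comments; the i32 payloads are placeholders that come straight from the attribute values, not fixed constants:

    // Each cache-control attribute becomes one 4 x i32 metadata node, and the
    // nodes are grouped under a single named node on the instruction:
    //   %v = load i32, ptr %p, !spirv.DecorationCacheControlINTEL !1
    //   !1 = !{!2, !3}
    //   !2 = !{i32 <decoration>, i32 <level>, i32 <control>, i32 <operand>}
    //   !3 = !{i32 <decoration>, i32 <level>, i32 <control>, i32 <operand>}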
diff --git a/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp b/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp
index 580afdd..cb1f234 100644
--- a/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp
+++ b/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp
@@ -33,7 +33,9 @@ LogicalResult mlir::LLVMImportInterface::convertUnregisteredIntrinsic(
SmallVector<Value> mlirOperands;
SmallVector<NamedAttribute> mlirAttrs;
if (failed(moduleImport.convertIntrinsicArguments(
- llvmOperands, llvmOpBundles, false, {}, {}, mlirOperands, mlirAttrs)))
+ llvmOperands, llvmOpBundles, /*requiresOpBundles=*/false,
+ /*immArgPositions=*/{}, /*immArgAttrNames=*/{}, mlirOperands,
+ mlirAttrs)))
return failure();
Type resultType = moduleImport.convertType(inst->getType());
@@ -44,11 +46,7 @@ LogicalResult mlir::LLVMImportInterface::convertUnregisteredIntrinsic(
ValueRange{mlirOperands}, FastmathFlagsAttr{});
moduleImport.setFastmathFlagsAttr(inst, op);
-
- ArrayAttr argsAttr, resAttr;
- moduleImport.convertParameterAttributes(inst, argsAttr, resAttr, builder);
- op.setArgAttrsAttr(argsAttr);
- op.setResAttrsAttr(resAttr);
+ moduleImport.convertArgAndResultAttrs(inst, op);
// Update importer tracking of results.
unsigned numRes = op.getNumResults();
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 58e3c44..6325480 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -30,6 +30,7 @@
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/Comdat.h"
#include "llvm/IR/Constants.h"
@@ -1063,6 +1064,18 @@ void ModuleImport::convertTargetTriple() {
builder.getStringAttr(llvmModule->getTargetTriple().str()));
}
+void ModuleImport::convertModuleLevelAsm() {
+ llvm::StringRef asmStr = llvmModule->getModuleInlineAsm();
+ llvm::SmallVector<mlir::Attribute> asmArrayAttr;
+
+ for (llvm::StringRef line : llvm::split(asmStr, '\n'))
+ if (!line.empty())
+ asmArrayAttr.push_back(builder.getStringAttr(line));
+
+ mlirModule->setAttr(LLVM::LLVMDialect::getModuleLevelAsmAttrName(),
+ builder.getArrayAttr(asmArrayAttr));
+}
+
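An end-to-end sketch of the import above. The attribute name is whatever getModuleLevelAsmAttrName() returns; `llvm.module_asm` is assumed here for illustration:

    // LLVM module:
    //   module asm ".globl saved_regs"
    //   module asm ".p2align 4"
    // imports as:
    //   module attributes
    //       {llvm.module_asm = [".globl saved_regs", ".p2align 4"]} { ... }
    // Empty lines in the inline-asm blob are dropped by the split loop above.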
LogicalResult ModuleImport::convertFunctions() {
for (llvm::Function &func : llvmModule->functions())
if (failed(processFunction(&func)))
@@ -2267,7 +2280,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
// Handle parameter and result attributes unless it's an incompatible
// call.
if (!isIncompatibleCall)
- convertParameterAttributes(callInst, callOp, builder);
+ convertArgAndResultAttrs(callInst, callOp);
return callOp.getOperation();
}();
@@ -2364,7 +2377,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
// Handle parameter and result attributes unless it's an incompatible
// invoke.
if (!isIncompatibleInvoke)
- convertParameterAttributes(invokeInst, invokeOp, builder);
+ convertArgAndResultAttrs(invokeInst, invokeOp);
if (!invokeInst->getType()->isVoidTy())
mapValue(inst, invokeOp.getResults().front());
@@ -2730,11 +2743,10 @@ void ModuleImport::processFunctionAttributes(llvm::Function *func,
}
DictionaryAttr
-ModuleImport::convertParameterAttribute(llvm::AttributeSet llvmParamAttrs,
- OpBuilder &builder) {
+ModuleImport::convertArgOrResultAttrSet(llvm::AttributeSet llvmAttrSet) {
SmallVector<NamedAttribute> paramAttrs;
for (auto [llvmKind, mlirName] : getAttrKindToNameMapping()) {
- auto llvmAttr = llvmParamAttrs.getAttribute(llvmKind);
+ auto llvmAttr = llvmAttrSet.getAttribute(llvmKind);
// Skip attributes that are not attached.
if (!llvmAttr.isValid())
continue;
@@ -2769,13 +2781,12 @@ ModuleImport::convertParameterAttribute(llvm::AttributeSet llvmParamAttrs,
return builder.getDictionaryAttr(paramAttrs);
}
-void ModuleImport::convertParameterAttributes(llvm::Function *func,
- LLVMFuncOp funcOp,
- OpBuilder &builder) {
+void ModuleImport::convertArgAndResultAttrs(llvm::Function *func,
+ LLVMFuncOp funcOp) {
auto llvmAttrs = func->getAttributes();
for (size_t i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
llvm::AttributeSet llvmArgAttrs = llvmAttrs.getParamAttrs(i);
- funcOp.setArgAttrs(i, convertParameterAttribute(llvmArgAttrs, builder));
+ funcOp.setArgAttrs(i, convertArgOrResultAttrSet(llvmArgAttrs));
}
// Convert the result attributes and attach them wrapped in an ArrayAttribute
// to the funcOp.
@@ -2783,17 +2794,23 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func,
if (!llvmResAttr.hasAttributes())
return;
funcOp.setResAttrsAttr(
- builder.getArrayAttr(convertParameterAttribute(llvmResAttr, builder)));
+ builder.getArrayAttr({convertArgOrResultAttrSet(llvmResAttr)}));
}
-void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
- ArrayAttr &argsAttr,
- ArrayAttr &resAttr,
- OpBuilder &builder) {
+void ModuleImport::convertArgAndResultAttrs(
+ llvm::CallBase *call, ArgAndResultAttrsOpInterface attrsOp,
+ ArrayRef<unsigned> immArgPositions) {
+ // Compute the set of immediate argument positions.
+ llvm::SmallDenseSet<unsigned> immArgPositionsSet(immArgPositions.begin(),
+ immArgPositions.end());
+ // Convert the argument attributes and filter out immediate arguments.
llvm::AttributeList llvmAttrs = call->getAttributes();
SmallVector<llvm::AttributeSet> llvmArgAttrsSet;
bool anyArgAttrs = false;
for (size_t i = 0, e = call->arg_size(); i < e; ++i) {
+ // Skip immediate arguments.
+ if (immArgPositionsSet.contains(i))
+ continue;
llvmArgAttrsSet.emplace_back(llvmAttrs.getParamAttrs(i));
if (llvmArgAttrsSet.back().hasAttributes())
anyArgAttrs = true;
@@ -2807,24 +2824,16 @@ void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
if (anyArgAttrs) {
SmallVector<DictionaryAttr> argAttrs;
for (auto &llvmArgAttrs : llvmArgAttrsSet)
- argAttrs.emplace_back(convertParameterAttribute(llvmArgAttrs, builder));
- argsAttr = getArrayAttr(argAttrs);
+ argAttrs.emplace_back(convertArgOrResultAttrSet(llvmArgAttrs));
+ attrsOp.setArgAttrsAttr(getArrayAttr(argAttrs));
}
+ // Convert the result attributes.
llvm::AttributeSet llvmResAttr = llvmAttrs.getRetAttrs();
if (!llvmResAttr.hasAttributes())
return;
- DictionaryAttr resAttrs = convertParameterAttribute(llvmResAttr, builder);
- resAttr = getArrayAttr({resAttrs});
-}
-
-void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
- CallOpInterface callOp,
- OpBuilder &builder) {
- ArrayAttr argsAttr, resAttr;
- convertParameterAttributes(call, argsAttr, resAttr, builder);
- callOp.setArgAttrsAttr(argsAttr);
- callOp.setResAttrsAttr(resAttr);
+ DictionaryAttr resAttrs = convertArgOrResultAttrSet(llvmResAttr);
+ attrsOp.setResAttrsAttr(getArrayAttr({resAttrs}));
}
template <typename Op>
@@ -2892,7 +2901,7 @@ LogicalResult ModuleImport::processFunction(llvm::Function *func) {
builder, loc, func->getName(), functionType,
convertLinkageFromLLVM(func->getLinkage()), dsoLocal, cconv);
- convertParameterAttributes(func, funcOp, builder);
+ convertArgAndResultAttrs(func, funcOp);
if (FlatSymbolRefAttr personality = getPersonalityAsAttr(func))
funcOp.setPersonalityAttr(personality);
@@ -3199,5 +3208,6 @@ OwningOpRef<ModuleOp> mlir::translateLLVMIRToModule(
if (failed(moduleImport.convertIFuncs()))
return {};
moduleImport.convertTargetTriple();
+ moduleImport.convertModuleLevelAsm();
return module;
}
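
A short sketch of the immArgPositions contract introduced above. Immediate arguments (e.g. the trailing i1 isvolatile flag of llvm.memcpy) become MLIR attributes rather than operands, so their parameter-attribute slots must be dropped to keep the remaining array aligned with the op's value operands. Names below are illustrative:

    // Collect parameter attribute sets, skipping immarg positions so the
    // result lines up 1:1 with the MLIR op's value operands.
    llvm::AttributeList llvmAttrs = call->getAttributes();
    SmallVector<llvm::AttributeSet> filtered;
    for (unsigned i = 0, e = call->arg_size(); i < e; ++i)
      if (!llvm::is_contained(immArgPositions, i))
        filtered.push_back(llvmAttrs.getParamAttrs(i));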
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index b997e55..b3a06e2 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1758,6 +1758,48 @@ ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
return attrBuilder;
}
+LogicalResult ModuleTranslation::convertArgAndResultAttrs(
+ ArgAndResultAttrsOpInterface attrsOp, llvm::CallBase *call,
+ ArrayRef<unsigned> immArgPositions) {
+ // Convert the argument attributes.
+ if (ArrayAttr argAttrsArray = attrsOp.getArgAttrsAttr()) {
+ unsigned argAttrIdx = 0;
+ llvm::SmallDenseSet<unsigned> immArgPositionsSet(immArgPositions.begin(),
+ immArgPositions.end());
+ for (unsigned argIdx : llvm::seq<unsigned>(call->arg_size())) {
+ if (argAttrIdx >= argAttrsArray.size())
+ break;
+ // Skip immediate arguments (they have no entries in argAttrsArray).
+ if (immArgPositionsSet.contains(argIdx))
+ continue;
+ // Skip empty argument attributes.
+ auto argAttrs = cast<DictionaryAttr>(argAttrsArray[argAttrIdx++]);
+ if (argAttrs.empty())
+ continue;
+ // Convert and add attributes to the call instruction.
+ FailureOr<llvm::AttrBuilder> attrBuilder =
+ convertParameterAttrs(attrsOp->getLoc(), argAttrs);
+ if (failed(attrBuilder))
+ return failure();
+ call->addParamAttrs(argIdx, *attrBuilder);
+ }
+ }
+
+ // Convert the result attributes.
+ if (ArrayAttr resAttrsArray = attrsOp.getResAttrsAttr()) {
+ if (!resAttrsArray.empty()) {
+ auto resAttrs = cast<DictionaryAttr>(resAttrsArray[0]);
+ FailureOr<llvm::AttrBuilder> attrBuilder =
+ convertParameterAttrs(attrsOp->getLoc(), resAttrs);
+ if (failed(attrBuilder))
+ return failure();
+ call->addRetAttrs(*attrBuilder);
+ }
+ }
+
+ return success();
+}
+
FailureOr<llvm::AttrBuilder>
ModuleTranslation::convertParameterAttrs(Location loc,
DictionaryAttr paramAttrs) {
@@ -2276,6 +2318,25 @@ prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext,
llvmModule->setTargetTriple(
llvm::Triple(cast<StringAttr>(targetTripleAttr).getValue()));
+ if (auto asmAttr = m->getDiscardableAttr(
+ LLVM::LLVMDialect::getModuleLevelAsmAttrName())) {
+ auto asmArrayAttr = dyn_cast<ArrayAttr>(asmAttr);
+ if (!asmArrayAttr) {
+ m->emitError("expected an array attribute for a module level asm");
+ return nullptr;
+ }
+
+ for (Attribute elt : asmArrayAttr) {
+ auto asmStrAttr = dyn_cast<StringAttr>(elt);
+ if (!asmStrAttr) {
+ m->emitError(
+ "expected a string attribute for each entry of a module level asm");
+ return nullptr;
+ }
+ llvmModule->appendModuleInlineAsm(asmStrAttr.getValue());
+ }
+ }
+
return llvmModule;
}
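
Round-trip note: this export path mirrors ModuleImport::convertModuleLevelAsm above. A hedged sketch, spelling the attribute name as llvm.module_asm for illustration (the real name comes from LLVMDialect::getModuleLevelAsmAttrName()):

    // module attributes {llvm.module_asm = [".globl foo", "foo:"]}
    // re-emits as one LLVM `module asm` line per array entry:
    for (Attribute line : asmArrayAttr)
      llvmModule->appendModuleInlineAsm(cast<StringAttr>(line).getValue());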
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index e5934bb..88931b5 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -347,10 +347,6 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
return emitError(unknownLoc, "OpDecoration with ")
<< decorationName << "needs a single target <id>";
}
- // Block decoration does not affect spirv.struct type, but is still stored
- // for verification.
- // TODO: Update StructType to contain this information since
- // it is needed for many validation rules.
decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
break;
case spirv::Decoration::Location:
@@ -993,7 +989,8 @@ spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) {
if (failed(structType.trySetBody(
deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
- deferredStructIt->memberDecorationsInfo)))
+ deferredStructIt->memberDecorationsInfo,
+ deferredStructIt->structDecorationsInfo)))
return failure();
deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
@@ -1203,24 +1200,37 @@ spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
}
}
+ SmallVector<spirv::StructType::StructDecorationInfo, 0> structDecorationsInfo;
+ if (decorations.count(operands[0])) {
+ NamedAttrList &allDecorations = decorations[operands[0]];
+ for (NamedAttribute &decorationAttr : allDecorations) {
+ std::optional<spirv::Decoration> decoration = spirv::symbolizeDecoration(
+ llvm::convertToCamelFromSnakeCase(decorationAttr.getName(), true));
+ assert(decoration.has_value());
+ structDecorationsInfo.emplace_back(decoration.value(),
+ decorationAttr.getValue());
+ }
+ }
+
uint32_t structID = operands[0];
std::string structIdentifier = nameMap.lookup(structID).str();
if (structIdentifier.empty()) {
assert(unresolvedMemberTypes.empty() &&
"didn't expect unresolved member types");
- typeMap[structID] =
- spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo);
+ typeMap[structID] = spirv::StructType::get(
+ memberTypes, offsetInfo, memberDecorationsInfo, structDecorationsInfo);
} else {
auto structTy = spirv::StructType::getIdentified(context, structIdentifier);
typeMap[structID] = structTy;
if (!unresolvedMemberTypes.empty())
- deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes,
- memberTypes, offsetInfo,
- memberDecorationsInfo});
+ deferredStructTypesInfos.push_back(
+ {structTy, unresolvedMemberTypes, memberTypes, offsetInfo,
+ memberDecorationsInfo, structDecorationsInfo});
else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
- memberDecorationsInfo)))
+ memberDecorationsInfo,
+ structDecorationsInfo)))
return failure();
}
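
For orientation, a minimal sketch of building a struct type that carries its own Block decoration, matching the get/trySetBody signatures this patch introduces (parameter spellings illustrative; unit decorations such as Block store a UnitAttr, as the deserializer above rebuilds from OpDecorate):

    SmallVector<spirv::StructType::StructDecorationInfo, 1> structDecorations;
    structDecorations.emplace_back(spirv::Decoration::Block,
                                   UnitAttr::get(context));
    auto structTy = spirv::StructType::get(memberTypes, offsetInfo,
                                           memberDecorationsInfo,
                                           structDecorations);
    assert(structTy.hasDecoration(spirv::Decoration::Block));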
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 20482bd..db1cc3f 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -95,6 +95,7 @@ struct DeferredStructTypeInfo {
SmallVector<Type, 4> memberTypes;
SmallVector<spirv::StructType::OffsetInfo, 0> offsetInfo;
SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo;
+ SmallVector<spirv::StructType::StructDecorationInfo, 0> structDecorationsInfo;
};
/// A struct that collects the info needed to materialize/emit a
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index a8a2b2e..737f296 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -318,6 +318,7 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
case spirv::Decoration::RestrictPointer:
case spirv::Decoration::NoContraction:
case spirv::Decoration::Constant:
+ case spirv::Decoration::Block:
// For unit attributes and decoration attributes, the args list
// has no values so we do nothing.
if (isa<UnitAttr, DecorationAttr>(attr))
@@ -630,11 +631,16 @@ LogicalResult Serializer::prepareBasicType(
operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass()));
operands.push_back(pointeeTypeID);
+ // TODO: Now that struct decorations are supported, this code may no
+ // longer be necessary. However, it is kept for backwards compatibility.
+ // Ideally, Block decorations should be inserted when converting to SPIR-V.
if (isInterfaceStructPtrType(ptrType)) {
- if (failed(emitDecoration(getTypeID(pointeeStruct),
- spirv::Decoration::Block)))
- return emitError(loc, "cannot decorate ")
- << pointeeStruct << " with Block decoration";
+ auto structType = cast<spirv::StructType>(ptrType.getPointeeType());
+ if (!structType.hasDecoration(spirv::Decoration::Block))
+ if (failed(emitDecoration(getTypeID(pointeeStruct),
+ spirv::Decoration::Block)))
+ return emitError(loc, "cannot decorate ")
+ << pointeeStruct << " with Block decoration";
}
return success();
@@ -704,6 +710,20 @@ LogicalResult Serializer::prepareBasicType(
}
}
+ SmallVector<spirv::StructType::StructDecorationInfo, 1> structDecorations;
+ structType.getStructDecorations(structDecorations);
+
+ for (spirv::StructType::StructDecorationInfo &structDecoration :
+ structDecorations) {
+ if (failed(processDecorationAttr(loc, resultID,
+ structDecoration.decoration,
+ structDecoration.decorationValue))) {
+ return emitError(loc, "cannot decorate struct ")
+ << structType << " with "
+ << stringifyDecoration(structDecoration.decoration);
+ }
+ }
+
typeEnum = spirv::Opcode::OpTypeStruct;
if (structType.isIdentified())
@@ -938,6 +958,25 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
} else {
return 0;
}
+ } else if (isa<spirv::TensorArmType>(constType)) {
+ numberOfConstituents = shapedType.getNumElements();
+ operands.reserve(numberOfConstituents + 2);
+ for (int i = 0; i < numberOfConstituents; ++i) {
+ uint32_t elementID = 0;
+ if (auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
+ elementID =
+ elementType.isInteger(1)
+ ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[i])
+ : prepareConstantInt(loc, attr.getValues<IntegerAttr>()[i]);
+ }
+ if (auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
+ elementID = prepareConstantFp(loc, attr.getValues<FloatAttr>()[i]);
+ }
+ if (!elementID) {
+ return 0;
+ }
+ operands.push_back(elementID);
+ }
} else {
operands.reserve(numberOfConstituents + 2);
for (int i = 0; i < numberOfConstituents; ++i) {
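
The TensorArmType branch above prepares one scalar constant ID per element. Condensed, the per-element dispatch is (same helpers as the surrounding code; an else-if spares the float path from re-testing an integer attribute):

    uint32_t elementID = 0;
    if (auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr))
      elementID =
          elementType.isInteger(1)
              ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[i])
              : prepareConstantInt(loc, attr.getValues<IntegerAttr>()[i]);
    else if (auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr))
      elementID = prepareConstantFp(loc, attr.getValues<FloatAttr>()[i]);
    // A zero ID signals failure and aborts the composite constant.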
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 1abe0fd..6e2352e 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -559,6 +559,23 @@ func.func @constant() {
return
}
+// CHECK-LABEL: @constant_8bit_float
+func.func @constant_8bit_float() {
+ // CHECK: spirv.Constant 56 : i8
+ %cst = arith.constant 1.0 : f8E4M3
+ // CHECK: spirv.Constant 56 : i8
+ %cst_i8 = arith.bitcast %cst : f8E4M3 to i8
+ // CHECK: spirv.Constant dense<56> : vector<4xi8>
+ %cst_vector = arith.constant dense<1.0> : vector<4xf8E4M3>
+ // CHECK: spirv.Constant dense<56> : vector<4xi8>
+ %cst_vector_i8 = arith.bitcast %cst_vector : vector<4xf8E4M3> to vector<4xi8>
+ // CHECK: spirv.Constant dense<60> : tensor<4xi8> : !spirv.array<4 x i8>
+ %cst_tensor = arith.constant dense<1.0> : tensor<4xf8E5M2>
+ // CHECK: spirv.Constant dense<60> : tensor<4xi8> : !spirv.array<4 x i8>
+ %cst_tensor_i8 = arith.bitcast %cst_tensor : tensor<4xf8E5M2> to tensor<4xi8>
+ return
+}
+
// CHECK-LABEL: @constant_16bit
func.func @constant_16bit() {
// CHECK: spirv.Constant 4 : i16
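
For the magic constants in the new test: 1.0 in f8E4M3 has sign 0, biased exponent 0111 (bias 7), and mantissa 000, i.e. 0b00111000 = 56; in f8E5M2 the biased exponent is 01111 (bias 15) with mantissa 00, i.e. 0b00111100 = 60. Hence the emulated i8 constants 56 and 60 checked above.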
diff --git a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
index bae7c59..ae59f28 100644
--- a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
+++ b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
@@ -2,8 +2,26 @@
// CHECK-DAG: @__ocml_cabs_f32(complex<f32>) -> f32
// CHECK-DAG: @__ocml_cabs_f64(complex<f64>) -> f64
+// CHECK-DAG: @__ocml_carg_f32(complex<f32>) -> f32
+// CHECK-DAG: @__ocml_carg_f64(complex<f64>) -> f64
+// CHECK-DAG: @__ocml_ccos_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_ccos_f64(complex<f64>) -> complex<f64>
// CHECK-DAG: @__ocml_cexp_f32(complex<f32>) -> complex<f32>
// CHECK-DAG: @__ocml_cexp_f64(complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_clog_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_clog_f64(complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_conj_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_conj_f64(complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_cpow_f32(complex<f32>, complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_cpow_f64(complex<f64>, complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_csin_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_csin_f64(complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_csqrt_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_csqrt_f64(complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_ctan_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_ctan_f64(complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_ctanh_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_ctanh_f64(complex<f64>) -> complex<f64>
//CHECK-LABEL: @abs_caller
func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
@@ -15,6 +33,26 @@ func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
return %rf, %rd : f32, f64
}
+//CHECK-LABEL: @angle_caller
+func.func @angle_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
+ // CHECK: %[[AF:.*]] = call @__ocml_carg_f32(%{{.*}})
+ %af = complex.angle %f : complex<f32>
+ // CHECK: %[[AD:.*]] = call @__ocml_carg_f64(%{{.*}})
+ %ad = complex.angle %d : complex<f64>
+ // CHECK: return %[[AF]], %[[AD]]
+ return %af, %ad : f32, f64
+}
+
+//CHECK-LABEL: @cos_caller
+func.func @cos_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+ // CHECK: %[[CF:.*]] = call @__ocml_ccos_f32(%{{.*}})
+ %cf = complex.cos %f : complex<f32>
+ // CHECK: %[[CD:.*]] = call @__ocml_ccos_f64(%{{.*}})
+ %cd = complex.cos %d : complex<f64>
+ // CHECK: return %[[CF]], %[[CD]]
+ return %cf, %cd : complex<f32>, complex<f64>
+}
+
//CHECK-LABEL: @exp_caller
func.func @exp_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
// CHECK: %[[EF:.*]] = call @__ocml_cexp_f32(%{{.*}})
@@ -24,3 +62,73 @@ func.func @exp_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, comp
// CHECK: return %[[EF]], %[[ED]]
return %ef, %ed : complex<f32>, complex<f64>
}
+
+//CHECK-LABEL: @log_caller
+func.func @log_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+ // CHECK: %[[LF:.*]] = call @__ocml_clog_f32(%{{.*}})
+ %lf = complex.log %f : complex<f32>
+ // CHECK: %[[LD:.*]] = call @__ocml_clog_f64(%{{.*}})
+ %ld = complex.log %d : complex<f64>
+ // CHECK: return %[[LF]], %[[LD]]
+ return %lf, %ld : complex<f32>, complex<f64>
+}
+
+//CHECK-LABEL: @conj_caller
+func.func @conj_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+ // CHECK: %[[CF:.*]] = call @__ocml_conj_f32(%{{.*}})
+ %cf2 = complex.conj %f : complex<f32>
+ // CHECK: %[[CD:.*]] = call @__ocml_conj_f64(%{{.*}})
+ %cd2 = complex.conj %d : complex<f64>
+ // CHECK: return %[[CF]], %[[CD]]
+ return %cf2, %cd2 : complex<f32>, complex<f64>
+}
+
+//CHECK-LABEL: @pow_caller
+func.func @pow_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+ // CHECK: %[[PF:.*]] = call @__ocml_cpow_f32(%{{.*}}, %{{.*}})
+ %pf = complex.pow %f, %f : complex<f32>
+ // CHECK: %[[PD:.*]] = call @__ocml_cpow_f64(%{{.*}}, %{{.*}})
+ %pd = complex.pow %d, %d : complex<f64>
+ // CHECK: return %[[PF]], %[[PD]]
+ return %pf, %pd : complex<f32>, complex<f64>
+}
+
+//CHECK-LABEL: @sin_caller
+func.func @sin_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+ // CHECK: %[[SF:.*]] = call @__ocml_csin_f32(%{{.*}})
+ %sf2 = complex.sin %f : complex<f32>
+ // CHECK: %[[SD:.*]] = call @__ocml_csin_f64(%{{.*}})
+ %sd2 = complex.sin %d : complex<f64>
+ // CHECK: return %[[SF]], %[[SD]]
+ return %sf2, %sd2 : complex<f32>, complex<f64>
+}
+
+//CHECK-LABEL: @sqrt_caller
+func.func @sqrt_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+ // CHECK: %[[SF:.*]] = call @__ocml_csqrt_f32(%{{.*}})
+ %sf = complex.sqrt %f : complex<f32>
+ // CHECK: %[[SD:.*]] = call @__ocml_csqrt_f64(%{{.*}})
+ %sd = complex.sqrt %d : complex<f64>
+ // CHECK: return %[[SF]], %[[SD]]
+ return %sf, %sd : complex<f32>, complex<f64>
+}
+
+//CHECK-LABEL: @tan_caller
+func.func @tan_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+ // CHECK: %[[TF:.*]] = call @__ocml_ctan_f32(%{{.*}})
+ %tf2 = complex.tan %f : complex<f32>
+ // CHECK: %[[TD:.*]] = call @__ocml_ctan_f64(%{{.*}})
+ %td2 = complex.tan %d : complex<f64>
+ // CHECK: return %[[TF]], %[[TD]]
+ return %tf2, %td2 : complex<f32>, complex<f64>
+}
+
+//CHECK-LABEL: @tanh_caller
+func.func @tanh_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+ // CHECK: %[[TF:.*]] = call @__ocml_ctanh_f32(%{{.*}})
+ %tf = complex.tanh %f : complex<f32>
+ // CHECK: %[[TD:.*]] = call @__ocml_ctanh_f64(%{{.*}})
+ %td = complex.tanh %d : complex<f64>
+ // CHECK: return %[[TF]], %[[TD]]
+ return %tf, %td : complex<f32>, complex<f64>
+}
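
The pattern across these new tests: each complex op lowers to the ROCm OCML routine named after the corresponding C complex function (cabs, carg, ccos, cexp, clog, conj, cpow, csin, csqrt, ctan, ctanh), suffixed with the element type; e.g. complex.tanh on complex<f64> becomes a call to @__ocml_ctanh_f64.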
diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
index 1737f4a..0c77c88 100644
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -1,6 +1,8 @@
// RUN: mlir-opt -split-input-file -convert-func-to-spirv %s -o - | FileCheck %s
// RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-lt-32-bit-scalar-types=false" %s | \
// RUN: FileCheck %s --check-prefix=NOEMU
+// RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-unsupported-float-types=false" %s | \
+// RUN: FileCheck %s --check-prefix=UNSUPPORTED_FLOAT
//===----------------------------------------------------------------------===//
// Integer types
@@ -944,3 +946,55 @@ func.func @unranked_tensor(%arg0: tensor<*xi32>) { return }
func.func @dynamic_dim_tensor(%arg0: tensor<8x?xi32>) { return }
} // end module
+
+
+// -----
+
+// Check that 8-bit float types are emulated as i8.
+module attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int8], []>, #spirv.resource_limits<>>
+} {
+
+ // CHECK: spirv.func @float8_to_integer8
+ // CHECK-SAME: (%arg0: i8
+ // CHECK-SAME: %arg1: i8
+ // CHECK-SAME: %arg2: i8
+ // CHECK-SAME: %arg3: i8
+ // CHECK-SAME: %arg4: i8
+ // CHECK-SAME: %arg5: i8
+ // CHECK-SAME: %arg6: i8
+ // CHECK-SAME: %arg7: i8
+ // CHECK-SAME: %arg8: vector<4xi8>
+ // CHECK-SAME: %arg9: !spirv.ptr<!spirv.struct<(!spirv.array<8 x i8, stride=1> [0])>, StorageBuffer>
+ // CHECK-SAME: %arg10: !spirv.array<4 x i8>
+ // UNSUPPORTED_FLOAT-LABEL: func.func @float8_to_integer8
+ // UNSUPPORTED_FLOAT-SAME: (%arg0: f8E5M2
+ // UNSUPPORTED_FLOAT-SAME: %arg1: f8E4M3
+ // UNSUPPORTED_FLOAT-SAME: %arg2: f8E4M3FN
+ // UNSUPPORTED_FLOAT-SAME: %arg3: f8E5M2FNUZ
+ // UNSUPPORTED_FLOAT-SAME: %arg4: f8E4M3FNUZ
+ // UNSUPPORTED_FLOAT-SAME: %arg5: f8E4M3B11FNUZ
+ // UNSUPPORTED_FLOAT-SAME: %arg6: f8E3M4
+ // UNSUPPORTED_FLOAT-SAME: %arg7: f8E8M0FNU
+ // UNSUPPORTED_FLOAT-SAME: %arg8: vector<4xf8E4M3B11FNUZ>
+ // UNSUPPORTED_FLOAT-SAME: %arg9: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>>
+ // UNSUPPORTED_FLOAT-SAME: %arg10: tensor<4xf8E5M2>
+ // UNSUPPORTED_FLOAT-SAME: ) {
+
+ func.func @float8_to_integer8(
+ %arg0: f8E5M2, // CHECK-NOT: f8E5M2
+ %arg1: f8E4M3, // CHECK-NOT: f8E4M3
+ %arg2: f8E4M3FN, // CHECK-NOT: f8E4M3FN
+ %arg3: f8E5M2FNUZ, // CHECK-NOT: f8E5M2FNUZ
+ %arg4: f8E4M3FNUZ, // CHECK-NOT: f8E4M3FNUZ
+ %arg5: f8E4M3B11FNUZ, // CHECK-NOT: f8E4M3B11FNUZ
+ %arg6: f8E3M4, // CHECK-NOT: f8E3M4
+ %arg7: f8E8M0FNU, // CHECK-NOT: f8E8M0FNU
+ %arg8: vector<4xf8E4M3B11FNUZ>, // CHECK-NOT: vector<4xf8E4M3B11FNUZ>
+ %arg9: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>>, // CHECK-NOT: memref
+ %arg10: tensor<4xf8E5M2> // CHECK-NOT: tensor
+ ) {
+ // CHECK: spirv.Return
+ return
+ }
+}
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-fpclassify-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-fpclassify-spirv.mlir
new file mode 100644
index 0000000..3e5f592
--- /dev/null
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-fpclassify-spirv.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt --convert-math-to-spirv %s | FileCheck %s
+
+module attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], []>, #spirv.resource_limits<>>
+} {
+
+ // CHECK-LABEL: @fpclassify
+ func.func @fpclassify(%x: f32, %v: vector<4xf32>) {
+ // CHECK: spirv.IsFinite %{{.*}} : f32
+ %0 = math.isfinite %x : f32
+ // CHECK: spirv.IsFinite %{{.*}} : vector<4xf32>
+ %1 = math.isfinite %v : vector<4xf32>
+
+ // CHECK: spirv.IsNan %{{.*}} : f32
+ %2 = math.isnan %x : f32
+ // CHECK: spirv.IsNan %{{.*}} : vector<4xf32>
+ %3 = math.isnan %v : vector<4xf32>
+
+ // CHECK: spirv.IsInf %{{.*}} : f32
+ %4 = math.isinf %x : f32
+ // CHECK: spirv.IsInf %{{.*}} : vector<4xf32>
+ %5 = math.isinf %v : vector<4xf32>
+
+ return
+ }
+
+}
diff --git a/mlir/test/Dialect/Async/canonicalize.mlir b/mlir/test/Dialect/Async/canonicalize.mlir
new file mode 100644
index 0000000..1a74eaa
--- /dev/null
+++ b/mlir/test/Dialect/Async/canonicalize.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s
+
+// CHECK-NOT: async.execute
+
+func.func @empty_execute() {
+ %token = async.execute {
+ async.yield
+ }
+ return
+}
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index a00c798..5f42938 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -1076,6 +1076,44 @@ func.func @drop_known_unit_constant_low_high(%arg0: tensor<1x383x128xf32>) -> te
// -----
+func.func @drop_unit_dim_mixed_static_dynamic(%arg0: tensor<1x?xf32>) -> tensor<1x16xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %padded = tensor.pad %arg0 low[%c0, %c1] high[%c0, %c0] {
+ ^bb0(%arg1: index, %arg2: index):
+ tensor.yield %cst : f32
+ } : tensor<1x?xf32> to tensor<1x16xf32>
+ return %padded : tensor<1x16xf32>
+}
+// CHECK-LABEL: func @drop_unit_dim_mixed_static_dynamic
+// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARGS:.*]] : tensor<1x?xf32> into tensor<?xf32>
+// CHECK: %[[PADDED:.*]] = tensor.pad %[[COLLAPSE]] low[1] high[0] {
+// CHECK: ^bb0(%[[IDX:.*]]: index):
+// CHECK: tensor.yield %[[CST]] : f32
+// CHECK: } : tensor<?xf32> to tensor<16xf32>
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[PADDED]] {{\[\[}}0, 1]] output_shape [1, 16] : tensor<16xf32> into tensor<1x16xf32>
+// CHECK: return %[[EXPANDED]] : tensor<1x16xf32>
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+module {
+ func.func @drop_unit_dim_corresponding_to_dynamic_dim(%arg0: tensor<1x?x?x1xf32>, %arg1: index) -> tensor<?x1x61x1xf32> {
+ %cst = arith.constant dense<1.000000e+00> : tensor<1x1x1x1xf32>
+ %0 = tensor.empty(%arg1) : tensor<?x1x61x1xf32>
+ %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %cst : tensor<1x?x?x1xf32>, tensor<1x1x1x1xf32>) outs(%0 : tensor<?x1x61x1xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %2 = arith.mulf %in, %in_0 : f32
+ %3 = arith.addf %out, %2 : f32
+ linalg.yield %3 : f32
+ } -> tensor<?x1x61x1xf32>
+ return %1 : tensor<?x1x61x1xf32>
+ }
+}
// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (0, d0)>
// CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> ()>
@@ -1097,23 +1135,6 @@ func.func @drop_known_unit_constant_low_high(%arg0: tensor<1x383x128xf32>) -> te
// CHECK: return %[[VAL_14]] : tensor<?x1x61x1xf32>
// CHECK: }
-#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
-#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
-#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
-module {
- func.func @drop_unit_dim_corresponding_to_dynamic_dim(%arg0: tensor<1x?x?x1xf32>, %arg1: index) -> tensor<?x1x61x1xf32> {
- %cst = arith.constant dense<1.000000e+00> : tensor<1x1x1x1xf32>
- %0 = tensor.empty(%arg1) : tensor<?x1x61x1xf32>
- %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %cst : tensor<1x?x?x1xf32>, tensor<1x1x1x1xf32>) outs(%0 : tensor<?x1x61x1xf32>) {
- ^bb0(%in: f32, %in_0: f32, %out: f32):
- %2 = arith.mulf %in, %in_0 : f32
- %3 = arith.addf %out, %2 : f32
- linalg.yield %3 : f32
- } -> tensor<?x1x61x1xf32>
- return %1 : tensor<?x1x61x1xf32>
- }
-}
-
// -----
func.func @no_fold_empty_tensor_dim_out_of_bounds(%arg0: tensor<1x?x10xf32>) -> tensor<1x?xf32> {
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir
index 78619b6..981f5dc 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir
@@ -52,22 +52,22 @@ module {
// CHECK-LABEL: @generic
// CHECK-SAME: %[[T0:.*]]: tensor<7x5xf32>,
-// CHECK-SAME: %[[T1:.*]]: tensor<7x11x12xf32>)
- func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x12xf32>) -> tensor<7x11x12xf32> {
+// CHECK-SAME: %[[T1:.*]]: tensor<7x11x11xf32>)
+ func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x11xf32>) -> tensor<7x11x11xf32> {
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.
// CHECK: %[[PAD0:.*]] = tensor.pad %[[T0]] low[0, 0] high[2, 0]
// CHECK: : tensor<7x5xf32> to tensor<9x5xf32>
// CHECK: %[[PAD1:.*]] = tensor.pad %[[T1]] low[0, 0, 0] high[2, 4, 2] {
- // CHECK: : tensor<7x11x12xf32> to tensor<9x15x14xf32>
+ // CHECK: : tensor<7x11x11xf32> to tensor<9x15x13xf32>
// CHECK-NEXT: linalg.generic
- // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 12] [1, 1, 1] : tensor<9x15x14xf32> to tensor<7x11x12xf32>
- %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x12xf32>) {
+ // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 11] [1, 1, 1] : tensor<9x15x13xf32> to tensor<7x11x11xf32>
+ %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x11xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
- } -> tensor<7x11x12xf32>
- return %0 : tensor<7x11x12xf32>
+ } -> tensor<7x11x11xf32>
+ return %0 : tensor<7x11x11xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
@@ -83,7 +83,7 @@ module {
// -----
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3 + 5)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3 + 4)>
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 5)>
#map = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -272,3 +272,136 @@ module attributes {transform.with_named_sequence} {
}
}
+// -----
+
+// CHECK-LABEL: pad_conv
+func.func @pad_conv(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 12]
+ // CHECK: : tensor<1x16x16x4xf32> to tensor<1x16x18x16xf32>
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+ // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0]
+ // CHECK: : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32>
+ // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+ // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32>
+
+ %0 = linalg.conv_2d_nhwc_fhwc
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32>)
+ outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+ return %0 : tensor<1x14x14x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of {
+ padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+ } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 16) * 16 + 2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 16) * 16)>
+
+// CHECK-LABEL: pad_conv_dynamic
+func.func @pad_conv_dynamic(%arg0: tensor<1x16x?x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x?x16xf32>) -> tensor<1x14x?x16xf32> {
+
+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK: %[[D0_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32>
+ // CHECK: %[[D0_1:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x16x?x4xf32>
+ // CHECK: %[[H0:.*]] = affine.apply #[[$MAP0]]()[%[[D0_0]], %[[D0_1]]]
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, %[[H0]], 12]
+ // CHECK: : tensor<1x16x?x4xf32> to tensor<1x16x?x16xf32>
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+ // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+ // CHECK: %[[D1_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32>
+ // CHECK: %[[H1:.*]] = affine.apply #[[$MAP1]]()[%[[D0_0]], %[[D1_0]]]
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, %[[H1]], 0]
+ // CHECK: : tensor<1x14x?x16xf32> to tensor<1x14x?x16xf32>
+ // CHECK: %[[D2_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32>
+ // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+ // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, %[[D2_0]], 16] [1, 1, 1, 1] : tensor<1x14x?x16xf32> to tensor<1x14x?x16xf32>
+
+ %0 = linalg.conv_2d_nhwc_fhwc
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%arg0, %arg1: tensor<1x16x?x4xf32>, tensor<16x3x3x4xf32>)
+ outs(%arg2: tensor<1x14x?x16xf32>) -> tensor<1x14x?x16xf32>
+ return %0 : tensor<1x14x?x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of {
+ padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+ } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: pad_conv_strided
+func.func @pad_conv_strided(%arg0: tensor<1x42x42x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 6, 12]
+ // CHECK: : tensor<1x42x42x4xf32> to tensor<1x42x48x16xf32>
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+ // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0]
+ // CHECK: : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32>
+ // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+ // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32>
+
+ %0 = linalg.conv_2d_nhwc_fhwc
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<3> : tensor<2xi64> }
+ ins(%arg0, %arg1: tensor<1x42x42x4xf32>, tensor<16x3x3x4xf32>)
+ outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+ return %0 : tensor<1x14x14x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of {
+ padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+ } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: pad_conv_dilated
+func.func @pad_conv_dilated(%arg0: tensor<1x18x18x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 12]
+ // CHECK: : tensor<1x18x18x4xf32> to tensor<1x18x20x16xf32>
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+ // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0]
+ // CHECK: : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32>
+ // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+ // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32>
+
+ %0 = linalg.conv_2d_nhwc_fhwc
+ {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%arg0, %arg1: tensor<1x18x18x4xf32>, tensor<16x3x3x4xf32>)
+ outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+ return %0 : tensor<1x14x14x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of {
+ padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+ } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
index 26c03ed..f741876 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
@@ -69,22 +69,22 @@ module {
// CHECK-LABEL: @generic
// CHECK-SAME: %[[T0:.*]]: tensor<7x5xf32>,
-// CHECK-SAME: %[[T1:.*]]: tensor<7x11x12xf32>)
- func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x12xf32>) -> tensor<7x11x12xf32> {
+// CHECK-SAME: %[[T1:.*]]: tensor<7x11x11xf32>)
+ func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x11xf32>) -> tensor<7x11x11xf32> {
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.
// CHECK: %[[PAD0:.*]] = tensor.pad %[[T0]] low[0, 0] high[1, 0]
// CHECK: : tensor<7x5xf32> to tensor<8x5xf32>
// CHECK: %[[PAD1:.*]] = tensor.pad %[[T1]] low[0, 0, 0] high[1, 3, 1] {
- // CHECK: : tensor<7x11x12xf32> to tensor<8x14x13xf32>
+ // CHECK: : tensor<7x11x11xf32> to tensor<8x14x12xf32>
// CHECK-NEXT: linalg.generic
- // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 12] [1, 1, 1] : tensor<8x14x13xf32> to tensor<7x11x12xf32>
- %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x12xf32>) {
+ // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 11] [1, 1, 1] : tensor<8x14x12xf32> to tensor<7x11x11xf32>
+ %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x11xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
- } -> tensor<7x11x12xf32>
- return %0 : tensor<7x11x12xf32>
+ } -> tensor<7x11x11xf32>
+ return %0 : tensor<7x11x11xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
@@ -102,7 +102,7 @@ module {
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (-s0 + 8)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (-s0 + 13)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (-s0 + 12)>
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 5)>
#map = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -127,13 +127,13 @@ module {
// CHECK: %[[D2_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<?x11x?xf32>
// CHECK: %[[H2:.*]] = affine.apply #[[$MAP1]]()[%[[D2_0]]]
// CHECK: tensor.pad %{{.*}} low[0, 0, 0] high[%[[H1]], 3, %[[H2]]] {
- // CHECK: : tensor<?x11x?xf32> to tensor<8x14x13xf32>
+ // CHECK: : tensor<?x11x?xf32> to tensor<8x14x12xf32>
//
// CHECK: %[[D0_2:.*]] = tensor.dim %{{.*}}, %[[C0]] : tensor<?x5xf32>
// CHECK: %[[D2_1:.*]] = affine.apply #[[$MAP2]]()[%[[D0_2]]]
- // CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<8x5xf32>) outs(%{{.*}} : tensor<8x14x13xf32>) {
- // CHECK: } -> tensor<8x14x13xf32>
- // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [%[[D0_2]], 11, %[[D2_1]]] [1, 1, 1] : tensor<8x14x13xf32> to tensor<?x11x?xf32>
+ // CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<8x5xf32>) outs(%{{.*}} : tensor<8x14x12xf32>) {
+ // CHECK: } -> tensor<8x14x12xf32>
+ // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [%[[D0_2]], 11, %[[D2_1]]] [1, 1, 1] : tensor<8x14x12xf32> to tensor<?x11x?xf32>
//
%0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<?x5xf32>) outs(%arg1 : tensor<?x11x?xf32>) {
^bb0(%in: f32, %out: f32):
diff --git a/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir
index c3ee892..d7722ea 100644
--- a/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir
@@ -230,18 +230,17 @@ func.func @vectorize_nd_tensor_extract_index_from_tensor(%arg0: tensor<3x3xf32>,
// CHECK-SAME: %[[ARG4:.*]]: tensor<4x7x3x2xf32>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[PV:.*]] = ub.poison : i32
-// CHECK-DAG: %[[CST:.*]] = arith.constant dense<3> : vector<7x2x4x3xindex>
+// CHECK-DAG: %[[CST:.*]] = arith.constant dense<3> : vector<4x3xindex>
// CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<true> : vector<4x7x3x2xi1>
// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<4x7x3x2xf32>
// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], %[[PV]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32>
// CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], %[[PV]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32>
// CHECK: %[[CAST:.*]] = arith.index_cast %[[V0]] : vector<4x3xi32> to vector<4x3xindex>
-// CHECK: %[[B1:.*]] = vector.broadcast %[[CAST]] : vector<4x3xindex> to vector<7x2x4x3xindex>
// CHECK: %[[CAST_1:.*]] = arith.index_cast %[[V1]] : vector<4x3xi32> to vector<4x3xindex>
-// CHECK: %[[B2:.*]] = vector.broadcast %[[CAST_1]] : vector<4x3xindex> to vector<7x2x4x3xindex>
-// CHECK: %[[MULI:.*]] = arith.muli %[[B1]], %[[CST]] : vector<7x2x4x3xindex>
-// CHECK: %[[ADDI:.*]] = arith.addi %[[B2]], %[[MULI]] : vector<7x2x4x3xindex>
-// CHECK: %[[T:.*]] = vector.transpose %[[ADDI]], [2, 0, 3, 1] : vector<7x2x4x3xindex> to vector<4x7x3x2xindex>
+// CHECK: %[[MULI:.*]] = arith.muli %[[CAST]], %[[CST]] : vector<4x3xindex>
+// CHECK: %[[ADDI:.*]] = arith.addi %[[CAST_1]], %[[MULI]] : vector<4x3xindex>
+// CHECK: %[[B:.*]] = vector.broadcast %[[ADDI]] : vector<4x3xindex> to vector<7x2x4x3xindex>
+// CHECK: %[[T:.*]] = vector.transpose %[[B]], [2, 0, 3, 1] : vector<7x2x4x3xindex> to vector<4x7x3x2xindex>
// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]]] [%[[T]]], %[[CST_1]], %[[PASSTHRU]] : tensor<3x3xf32>, vector<4x7x3x2xindex>, vector<4x7x3x2xi1>, vector<4x7x3x2xf32> into vector<4x7x3x2xf32>
// CHECK: vector.transfer_write %[[GATHER]], %[[ARG4]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true]} : vector<4x7x3x2xf32>, tensor<4x7x3x2xf32>
@@ -270,20 +269,16 @@ func.func @vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load(%
// CHECK-SAME: %[[ARG0:.*]]: tensor<8x128x768xf32>
// CHECK-SAME: %[[ARG1:.*]]: index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[CST:.*]] = arith.constant dense<768> : vector<1x8xindex>
-// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<128> : vector<1x8xindex>
// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
-// CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<true> : vector<8x1xi1>
-// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<true> : vector<8x1xi1>
+// CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<[0, 98304, 196608, 294912, 393216, 491520, 589824, 688128]> : vector<8xindex>
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<8x1xf32>
-// CHECK: %[[B1:.*]] = vector.broadcast %[[CST_3]] : vector<8xindex> to vector<1x8xindex>
// CHECK: %[[ADDI_ARG1:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : index
-// CHECK: %[[MULI_1:.*]] = arith.muli %[[B1]], %[[CST_0]] : vector<1x8xindex>
-// CHECK: %[[MULI_2:.*]] = arith.muli %[[MULI_1]], %[[CST]] : vector<1x8xindex>
-// CHECK: %[[T:.*]] = vector.transpose %[[MULI_2]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
+// CHECK: %[[B1:.*]] = vector.broadcast %[[CST_1]] : vector<8xindex> to vector<1x8xindex>
+// CHECK: %[[T:.*]] = vector.transpose %[[B1]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
// CHECK: %[[B2:.*]] = vector.broadcast %[[ADDI_ARG1]] : index to vector<8x1xindex>
// CHECK: %[[ADDI:.*]] = arith.addi %[[B2]], %[[T]] : vector<8x1xindex>
-// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_2]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
+// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_0]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
// CHECK: vector.transfer_write %[[GATHER]], %[[EMPTY]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
// -----
@@ -309,15 +304,13 @@ func.func @index_from_output_column_vector_gather_load(%src: tensor<8x128xf32>)
// CHECK-LABEL: func.func @index_from_output_column_vector_gather_load(
// CHECK-SAME: %[[SRC:.*]]: tensor<8x128xf32>) -> tensor<8x1xf32> {
-// CHECK: %[[C128:.*]] = arith.constant dense<128> : vector<1x8xindex>
+// CHECK: %[[IDX_VEC:.*]] = arith.constant dense<[0, 128, 256, 384, 512, 640, 768, 896]> : vector<8xindex>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[PASS_THRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8x1xi1>
-// CHECK: %[[IDX_VEC:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
// CHECK: %[[OUT:.*]] = tensor.empty() : tensor<8x1xf32>
// CHECK: %[[B:.*]] = vector.broadcast %[[IDX_VEC]] : vector<8xindex> to vector<1x8xindex>
-// CHECK: %[[MUL:.*]] = arith.muli %[[B]], %[[C128]] : vector<1x8xindex>
-// CHECK: %[[TR:.*]] = vector.transpose %[[MUL]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
+// CHECK: %[[TR:.*]] = vector.transpose %[[B]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
// CHECK: %[[GATHER:.*]] = vector.gather %[[SRC]]{{\[}}%[[C0]], %[[C0]]] {{\[}}%[[TR]]], %[[MASK]], %[[PASS_THRU]] : tensor<8x128xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
// CHECK: %[[RES:.*]] = vector.transfer_write %[[GATHER]], %[[OUT]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
// CHECK: return %[[RES]] : tensor<8x1xf32>
@@ -420,12 +413,12 @@ func.func @vectorize_nd_tensor_extract_with_affine_apply_gather(%6: tensor<80x16
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant dense<true> : vector<1x4xi1>
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32>
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_7:.*]] = arith.constant dense<16> : vector<1x4xindex>
+// CHECK-DAG: %[[VAL_7:.*]] = arith.constant dense<16> : vector<4xindex>
// CHECK: %[[VAL_8:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex>
// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : vector<4xindex>
-// CHECK: %[[VAL_10:.*]] = vector.broadcast %[[VAL_9]] : vector<4xindex> to vector<1x4xindex>
-// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_10]], %[[VAL_7]] : vector<1x4xindex>
-// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_7]] : vector<1x4xindex>
+// CHECK: %[[VAL_10:.*]] = arith.muli %[[VAL_9]], %[[VAL_7]] : vector<4xindex>
+// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_7]] : vector<4xindex>
+// CHECK: %[[VAL_12:.*]] = vector.broadcast %[[VAL_11]] : vector<4xindex> to vector<1x4xindex>
// CHECK: %[[VAL_13:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_12]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
// CHECK: %[[VAL_14:.*]] = vector.transfer_write %[[VAL_13]], %[[VAL_2]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
// CHECK: return %[[VAL_14]] : tensor<1x4xf32>
@@ -450,14 +443,12 @@ func.func @vectorize_nd_tensor_extract_with_maxsi_gather(%arg0: tensor<80x16xf32
// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_with_maxsi_gather(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<80x16xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant dense<1264> : vector<1x4xindex>
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[1264, 1265, 1266, 1267]> : vector<4xindex>
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant dense<true> : vector<1x4xi1>
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32>
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_7:.*]] = vector.broadcast %[[VAL_2]] : vector<4xindex> to vector<1x4xindex>
-// CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_3]] : vector<1x4xindex>
-// CHECK: %[[VAL_9:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_8]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[VAL_9:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_7]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
// CHECK: %[[VAL_10:.*]] = vector.transfer_write %[[VAL_9]], %[[VAL_1]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
// CHECK: return %[[VAL_10]] : tensor<1x4xf32>
// CHECK: }
@@ -519,13 +510,13 @@ func.func @vectorize_reverse_like_tensor_extract(%arg0: tensor<1x2x3xf32>, %arg1
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]
// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]
-// CHECK-DAG: %[[CST:.+]] = arith.constant dense<3> : vector<1x1x3xindex>
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x3xi1>
// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x3xf32>
// CHECK-DAG: %[[INIT_IDX:.+]] = arith.constant dense<[2, 1, 0]> : vector<3xindex>
-// CHECK: %[[T0:.+]] = vector.broadcast %[[ARG2]] : index to vector<1x1x3xindex>
-// CHECK: %[[T1:.+]] = arith.muli %[[T0]], %[[CST]] : vector<1x1x3xindex>
+// CHECK: %[[T0:.+]] = arith.muli %[[ARG2]], %[[C3]] : index
+// CHECK: %[[T1:.+]] = vector.broadcast %[[T0]] : index to vector<1x1x3xindex>
// CHECK: %[[T2:.+]] = vector.broadcast %[[INIT_IDX]]
// CHECK: %[[T3:.+]] = arith.addi %[[T2]], %[[T1]]
// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[T3]]], %[[MASK]], %[[PASSTHRU]]
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 12d30e17..308cf150 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1440,8 +1440,8 @@ func.func @propagate_into_execute_region() {
// -----
-// CHECK-LABEL: func @execute_region_elim
-func.func @execute_region_elim() {
+// CHECK-LABEL: func @execute_region_inline
+func.func @execute_region_inline() {
affine.for %i = 0 to 100 {
"test.foo"() : () -> ()
%v = scf.execute_region -> i64 {
@@ -1461,8 +1461,30 @@ func.func @execute_region_elim() {
// -----
-// CHECK-LABEL: func @func_execute_region_elim
-func.func @func_execute_region_elim() {
+// CHECK-LABEL: func @execute_region_no_inline
+func.func @execute_region_no_inline() {
+ affine.for %i = 0 to 100 {
+ "test.foo"() : () -> ()
+ %v = scf.execute_region -> i64 no_inline {
+ %x = "test.val"() : () -> i64
+ scf.yield %x : i64
+ }
+ "test.bar"(%v) : (i64) -> ()
+ }
+ return
+}
+
+// CHECK-NEXT: affine.for %arg0 = 0 to 100 {
+// CHECK-NEXT: "test.foo"() : () -> ()
+// CHECK-NEXT: scf.execute_region
+// CHECK-NEXT: %[[VAL:.*]] = "test.val"() : () -> i64
+// CHECK-NEXT: scf.yield %[[VAL]] : i64
+// CHECK-NEXT: }
+
+// -----
+
+// CHECK-LABEL: func @func_execute_region_inline
+func.func @func_execute_region_inline() {
"test.foo"() : () -> ()
%v = scf.execute_region -> i64 {
%c = "test.cmp"() : () -> i1
@@ -1496,8 +1518,8 @@ func.func @func_execute_region_elim() {
// -----
-// CHECK-LABEL: func @func_execute_region_elim_multi_yield
-func.func @func_execute_region_elim_multi_yield() {
+// CHECK-LABEL: func @func_execute_region_inline_multi_yield
+func.func @func_execute_region_inline_multi_yield() {
"test.foo"() : () -> ()
%v = scf.execute_region -> i64 {
%c = "test.cmp"() : () -> i1
diff --git a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
index d6c3464..58b8288 100644
--- a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
@@ -33,6 +33,24 @@ func.func @inotequal_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vecto
// -----
//===----------------------------------------------------------------------===//
+// spirv.IsFinite
+//===----------------------------------------------------------------------===//
+
+func.func @isfinite_scalar(%arg0: f32) -> i1 {
+ // CHECK: spirv.IsFinite {{.*}} : f32
+ %0 = spirv.IsFinite %arg0 : f32
+ return %0 : i1
+}
+
+func.func @isfinite_vector(%arg0: vector<2xf32>) -> vector<2xi1> {
+ // CHECK: spirv.IsFinite {{.*}} : vector<2xf32>
+ %0 = spirv.IsFinite %arg0 : vector<2xf32>
+ return %0 : vector<2xi1>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
// spirv.IsInf
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index 5d05a654..6d321af 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -296,6 +296,12 @@ func.func private @struct_type_with_matrix_2(!spirv.struct<(!spirv.matrix<3 x ve
// CHECK: func private @struct_empty(!spirv.struct<()>)
func.func private @struct_empty(!spirv.struct<()>)
+// CHECK: func.func private @struct_block(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>)
+func.func private @struct_block(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>)
+
+// CHECK: func.func private @struct_two_dec(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block, CPacked>)
+func.func private @struct_two_dec(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block, CPacked>)
+
// -----
// expected-error @+1 {{offset specification must be given for all members}}
diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index 2b23766..8d7f3da 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -178,7 +178,7 @@ spirv.module Logical GLSL450 attributes {
// Vulkan memory model requires SPV_KHR_vulkan_memory_model, which is enabled
// implicitly by v1.5.
-// CHECK: requires #spirv.vce<v1.0, [VulkanMemoryModel], [SPV_KHR_vulkan_memory_model]>
+// CHECK: requires #spirv.vce<v1.5, [VulkanMemoryModel], [SPV_KHR_vulkan_memory_model]>
spirv.module Logical Vulkan attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.5, [Shader, VulkanMemoryModel], []>, #spirv.resource_limits<>>
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index b90d6f5..3bccb32 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -2036,3 +2036,19 @@ func.func @test_scatter_duplicate_indices(%arg0: tensor<2x52x3xf32>, %arg2: tens
%0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi32>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32>
return %0 : tensor<2x52x3xf32>
}
+
+// -----
+
+func.func @test_reduce_all_unsupported_data_types(%arg0: tensor<2x12x11xf32>) -> tensor<1x12x11xf32> {
+ // expected-error@+1 {{'tosa.reduce_all' op illegal: operation operand/result data types did not align with any profile or extension, got (f32,f32), did you mean (i1,i1)?}}
+ %0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<2x12x11xf32>) -> tensor<1x12x11xf32>
+ return %0 : tensor<1x12x11xf32>
+}
+
+// -----
+
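+// tosa.rfft2d only accepts f32 operands and results; bf16 matches no profile or extension.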
+func.func @test_rfft2d(%arg0: tensor<13x8x16xbf16>) -> (tensor<13x8x9xbf16>, tensor<13x8x9xbf16>) {
+ // expected-error@+1 {{'tosa.rfft2d' op illegal: operation operand/result data types did not align with any profile or extension, got (bf16,bf16,bf16), did you mean (f32,f32,f32)?}}
+ %0, %1 = tosa.rfft2d %arg0 : (tensor<13x8x16xbf16>) -> (tensor<13x8x9xbf16>, tensor<13x8x9xbf16>)
+ return %0, %1 : tensor<13x8x9xbf16>, tensor<13x8x9xbf16>
+}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index cbe0056..bf9ed8a 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -48,10 +48,10 @@ func.func @test_add_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tens
// -----
-func.func @test_arithmetic_right_shift_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+func.func @test_arithmetic_right_shift_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi32>, %arg1: tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32> {
// expected-error@+1 {{'tosa.arithmetic_right_shift' op failed level check: operand rank(shape) <= MAX_RANK}}
- %0 = tosa.arithmetic_right_shift %arg0, %arg1 {round = false} : (tensor<1x1x1x1x13x21x3xf32>, tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
- return %0 : tensor<1x1x1x1x13x21x3xf32>
+ %0 = tosa.arithmetic_right_shift %arg0, %arg1 {round = false} : (tensor<1x1x1x1x13x21x3xi32>, tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32>
+ return %0 : tensor<1x1x1x1x13x21x3xi32>
}
// -----
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 9cfebd5..56996b5 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1330,11 +1330,11 @@ func.func @fold_consecutive_broadcasts(%a : i32) -> vector<4x16xi32> {
// -----
-// CHECK-LABEL: shape_cast_constant
+// CHECK-LABEL: shape_cast_splat_constant
// CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1> : vector<3x4x2xi32>
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<20x2xf32>
// CHECK: return %[[CST0]], %[[CST1]] : vector<20x2xf32>, vector<3x4x2xi32>
-func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
+func.func @shape_cast_splat_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
%cst = arith.constant dense<2.000000e+00> : vector<5x4x2xf32>
%cst_1 = arith.constant dense<1> : vector<12x2xi32>
%0 = vector.shape_cast %cst : vector<5x4x2xf32> to vector<20x2xf32>
@@ -1344,6 +1344,36 @@ func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
// -----
+// Test of shape_cast's fold method:
+// shape_cast(constant) -> constant.
+//
+// CHECK-LABEL: @shape_cast_dense_int_constant
+// CHECK: %[[CST:.*]] = arith.constant
+// CHECK-SAME{LITERAL}: dense<[[2, 3, 5], [7, 11, 13]]>
+// CHECK: return %[[CST]] : vector<2x3xi8>
+func.func @shape_cast_dense_int_constant() -> vector<2x3xi8> {
+ %cst = arith.constant dense<[2, 3, 5, 7, 11, 13]> : vector<6xi8>
+ %0 = vector.shape_cast %cst : vector<6xi8> to vector<2x3xi8>
+ return %0 : vector<2x3xi8>
+}
+
+// -----
+
+// Test of shape_cast's fold method:
+// (shape_cast(const_x), const_x) -> (const_x_folded, const_x)
+//
+// CHECK-LABEL: @shape_cast_dense_float_constant
+// CHECK-DAG: %[[CST0:.*]] = {{.*}}1.000000e+00, 2.000000e+00{{.*}} vector<1x2xf32>
+// CHECK-DAG: %[[CST1:.*]] = {{.*}}1.000000e+00, 2.000000e+00{{.*}} vector<2xf32>
+// CHECK: return %[[CST1]], %[[CST0]] : vector<2xf32>, vector<1x2xf32>
+func.func @shape_cast_dense_float_constant() -> (vector<2xf32>, vector<1x2xf32>) {
+ %cst = arith.constant dense<[[1.0, 2.0]]> : vector<1x2xf32>
+ %0 = vector.shape_cast %cst : vector<1x2xf32> to vector<2xf32>
+ return %0, %cst : vector<2xf32>, vector<1x2xf32>
+}
+
+// -----
+
// CHECK-LABEL: shape_cast_poison
// CHECK-DAG: %[[CST1:.*]] = ub.poison : vector<3x4x2xi32>
// CHECK-DAG: %[[CST0:.*]] = ub.poison : vector<20x2xf32>
diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir
index 2563b48..b2f16bb 100644
--- a/mlir/test/Dialect/Vector/int-range-interface.mlir
+++ b/mlir/test/Dialect/Vector/int-range-interface.mlir
@@ -51,6 +51,15 @@ func.func @vector_shape_cast() -> vector<4x4xindex> {
func.return %2 : vector<4x4xindex>
}
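+// Transposition permutes elements without changing their values, so the bounds carry over.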
+// CHECK-LABEL: func @vector_transpose
+// CHECK: test.reflect_bounds {smax = 8 : index, smin = 7 : index, umax = 8 : index, umin = 7 : index}
+func.func @vector_transpose() -> vector<2x4xindex> {
+ %0 = test.with_bounds { smax = 8 : index, smin = 7 : index, umax = 8 : index, umin = 7 : index } : vector<4x2xindex>
+ %1 = vector.transpose %0, [1, 0] : vector<4x2xindex> to vector<2x4xindex>
+ %2 = test.reflect_bounds %1 : vector<2x4xindex>
+ func.return %2 : vector<2x4xindex>
+}
+
// CHECK-LABEL: func @vector_extract
// CHECK: test.reflect_bounds {smax = 6 : index, smin = 5 : index, umax = 6 : index, umin = 5 : index}
func.func @vector_extract() -> index {
@@ -99,3 +108,11 @@ func.func @test_vector_extsi() -> vector<2xi32> {
%2 = test.reflect_bounds %1 : vector<2xi32>
func.return %2 : vector<2xi32>
}
+
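+// vector.step yields [0, 1, ..., 7], hence the inferred bounds are [0, 7].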
+// CHECK-LABEL: func @vector_step
+// CHECK: test.reflect_bounds {smax = 7 : index, smin = 0 : index, umax = 7 : index, umin = 0 : index}
+func.func @vector_step() -> vector<8xindex> {
+ %0 = vector.step : vector<8xindex>
+ %1 = test.reflect_bounds %0 : vector<8xindex>
+ func.return %1 : vector<8xindex>
+}
diff --git a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
index 8e167a5..d5e3443 100644
--- a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
@@ -2,7 +2,7 @@
// CHECK-LABEL: func @broadcast_vec1d_from_scalar
// CHECK-SAME: %[[A:.*0]]: f32
-// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2xf32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2xf32>
// CHECK: return %[[T0]] : vector<2xf32>
func.func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {
@@ -12,7 +12,7 @@ func.func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {
// CHECK-LABEL: func @broadcast_vec2d_from_scalar
// CHECK-SAME: %[[A:.*0]]: f32
-// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3xf32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3xf32>
// CHECK: return %[[T0]] : vector<2x3xf32>
func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
@@ -22,7 +22,7 @@ func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
// CHECK-LABEL: func @broadcast_vec3d_from_scalar
// CHECK-SAME: %[[A:.*0]]: f32
-// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3x4xf32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3x4xf32>
// CHECK: return %[[T0]] : vector<2x3x4xf32>
func.func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> {
@@ -87,7 +87,7 @@ func.func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf3
// CHECK-LABEL: func @broadcast_stretch
// CHECK-SAME: %[[A:.*0]]: vector<1xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<1xf32>
-// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<4xf32>
+// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<4xf32>
// CHECK: return %[[T1]] : vector<4xf32>
func.func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> {
@@ -113,16 +113,16 @@ func.func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32>
// CHECK-SAME: %[[A:.*0]]: vector<4x1xf32>
// CHECK: %[[U0:.*]] = ub.poison : vector<4x3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : f32 from vector<4x1xf32>
-// CHECK: %[[T2:.*]] = vector.splat %[[T0]] : vector<3xf32>
+// CHECK: %[[T2:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[U0]] [0] : vector<3xf32> into vector<4x3xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : f32 from vector<4x1xf32>
-// CHECK: %[[T6:.*]] = vector.splat %[[T4]] : vector<3xf32>
+// CHECK: %[[T6:.*]] = vector.broadcast %[[T4]] : f32 to vector<3xf32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<4x3xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : f32 from vector<4x1xf32>
-// CHECK: %[[T10:.*]] = vector.splat %[[T8]] : vector<3xf32>
+// CHECK: %[[T10:.*]] = vector.broadcast %[[T8]] : f32 to vector<3xf32>
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T7]] [2] : vector<3xf32> into vector<4x3xf32>
// CHECK: %[[T12:.*]] = vector.extract %[[A]][3, 0] : f32 from vector<4x1xf32>
-// CHECK: %[[T14:.*]] = vector.splat %[[T12]] : vector<3xf32>
+// CHECK: %[[T14:.*]] = vector.broadcast %[[T12]] : f32 to vector<3xf32>
// CHECK: %[[T15:.*]] = vector.insert %[[T14]], %[[T11]] [3] : vector<3xf32> into vector<4x3xf32>
// CHECK: return %[[T15]] : vector<4x3xf32>
diff --git a/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir
index 059d955..5a8125e 100644
--- a/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir
@@ -5,11 +5,11 @@
// CHECK-SAME: %[[B:.*1]]: vector<3xf32>
// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<2xf32>
-// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
+// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32>
// CHECK: %[[T2:.*]] = arith.mulf %[[T1]], %[[B]] : vector<3xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : f32 from vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xf32>
+// CHECK: %[[T5:.*]] = vector.broadcast %[[T4]] : f32 to vector<3xf32>
// CHECK: %[[T6:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<2x3xf32>
// CHECK: return %[[T7]] : vector<2x3xf32>
@@ -26,12 +26,12 @@ func.func @outerproduct_noacc(%arg0: vector<2xf32>,
// CHECK-SAME: %[[C:.*2]]: vector<2x3xf32>
// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<2xf32>
-// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
+// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32>
// CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xf32> from vector<2x3xf32>
// CHECK: %[[T3:.*]] = vector.fma %[[T1]], %[[B]], %[[T2]] : vector<3xf32>
// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.splat %[[T5]] : vector<3xf32>
+// CHECK: %[[T6:.*]] = vector.broadcast %[[T5]] : f32 to vector<3xf32>
// CHECK: %[[T7:.*]] = vector.extract %[[C]][1] : vector<3xf32> from vector<2x3xf32>
// CHECK: %[[T8:.*]] = vector.fma %[[T6]], %[[B]], %[[T7]] : vector<3xf32>
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : vector<3xf32> into vector<2x3xf32>
@@ -49,11 +49,11 @@ func.func @outerproduct_acc(%arg0: vector<2xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<3xi32>
// CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : i32 from vector<2xi32>
-// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
+// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : i32 to vector<3xi32>
// CHECK: %[[T2:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : i32 from vector<2xi32>
-// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xi32>
+// CHECK: %[[T5:.*]] = vector.broadcast %[[T4]] : i32 to vector<3xi32>
// CHECK: %[[T6:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xi32> into vector<2x3xi32>
// CHECK: return %[[T7]] : vector<2x3xi32>
@@ -69,13 +69,13 @@ func.func @outerproduct_noacc_int(%arg0: vector<2xi32>,
// CHECK-SAME: %[[C:.*2]]: vector<2x3xi32>
// CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : i32 from vector<2xi32>
-// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
+// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : i32 to vector<3xi32>
// CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xi32> from vector<2x3xi32>
// CHECK: %[[T3:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
// CHECK: %[[T4:.*]] = arith.addi %[[T3]], %[[T2]] : vector<3xi32>
// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : i32 from vector<2xi32>
-// CHECK: %[[T7:.*]] = vector.splat %[[T6]] : vector<3xi32>
+// CHECK: %[[T7:.*]] = vector.broadcast %[[T6]] : i32 to vector<3xi32>
// CHECK: %[[T8:.*]] = vector.extract %[[C]][1] : vector<3xi32> from vector<2x3xi32>
// CHECK: %[[T9:.*]] = arith.muli %[[T7]], %[[B]] : vector<3xi32>
// CHECK: %[[T10:.*]] = arith.addi %[[T9]], %[[T8]] : vector<3xi32>
@@ -91,7 +91,7 @@ func.func @outerproduct_acc_int(%arg0: vector<2xi32>,
// CHECK-LABEL: func @axpy_fp(
// CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
// CHECK-SAME: %[[B:.*1]]: f32)
-// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : f32 to vector<16xf32>
// CHECK: %[[T1:.*]] = arith.mulf %[[A]], %[[T0]] : vector<16xf32>
// CHECK: return %[[T1]] : vector<16xf32>
func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> {
@@ -103,7 +103,7 @@ func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> {
// CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
// CHECK-SAME: %[[B:.*1]]: f32,
// CHECK-SAME: %[[C:.*2]]: vector<16xf32>)
-// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : f32 to vector<16xf32>
// CHECK: %[[T1:.*]] = vector.fma %[[A]], %[[T0]], %[[C]] : vector<16xf32>
// CHECK: return %[[T1]] : vector<16xf32>
func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>) -> vector<16xf32> {
@@ -114,7 +114,7 @@ func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>
// CHECK-LABEL: func @axpy_int(
// CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
// CHECK-SAME: %[[B:.*1]]: i32)
-// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : i32 to vector<16xi32>
// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
// CHECK: return %[[T1]] : vector<16xi32>
func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> {
@@ -126,7 +126,7 @@ func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> {
// CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
// CHECK-SAME: %[[B:.*1]]: i32,
// CHECK-SAME: %[[C:.*2]]: vector<16xi32>)
-// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
+// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : i32 to vector<16xi32>
// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
// CHECK: %[[T2:.*]] = arith.addi %[[T1]], %[[C]] : vector<16xi32>
// CHECK: return %[[T2]] : vector<16xi32>
diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir
index b826cdc..ef881ba 100644
--- a/mlir/test/Dialect/Vector/vector-sink.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink.mlir
@@ -180,13 +180,14 @@ func.func @negative_not_elementwise() -> vector<2x2xf32> {
// -----
-// The source and the result for arith.cmp have different types - not supported
-
-// CHECK-LABEL: func.func @negative_source_and_result_mismatch
-// CHECK: %[[BROADCAST:.+]] = vector.broadcast
-// CHECK: %[[RETURN:.+]] = arith.cmpf uno, %[[BROADCAST]], %[[BROADCAST]]
-// CHECK: return %[[RETURN]]
-func.func @negative_source_and_result_mismatch(%arg0 : f32, %arg1 : vector<1xf32>) -> vector<1xi1> {
+// The source and the result for arith.cmpf have different element types; the broadcast is still sunk past the comparison.
+
+// CHECK-LABEL: func.func @source_and_result_mismatch(
+// CHECK-SAME: %[[ARG0:.+]]: f32)
+// CHECK: %[[COMPARE:.+]] = arith.cmpf uno, %[[ARG0]], %[[ARG0]]
+// CHECK: %[[BROADCAST:.+]] = vector.broadcast %[[COMPARE]] : i1 to vector<1xi1>
+// CHECK: return %[[BROADCAST]]
+func.func @source_and_result_mismatch(%arg0 : f32) -> vector<1xi1> {
%0 = vector.broadcast %arg0 : f32 to vector<1xf32>
%1 = arith.cmpf uno, %0, %0 : vector<1xf32>
return %1 : vector<1xi1>
@@ -210,6 +211,130 @@ func.func @negative_op_only_supports_vectors(%arg0 : f32) -> vector<1xf32> {
return %1 : vector<1xf32>
}
+// -----
+
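+// When the other operand is a splat constant, the elementwise op is rewritten on the
+// scalar operands and a single broadcast of the result is emitted.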
+// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const(
+// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> {
+// CHECK: %[[NEW_CST:.*]] = arith.constant 2 : index
+// CHECK: %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[NEW_CST]] : index
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
+// CHECK: return %[[BCAST]] : vector<1x4xindex>
+
+func.func @broadcast_scalar_and_splat_const(%arg0: index) -> vector<1x4xindex> {
+ %0 = vector.broadcast %arg0 : index to vector<1x4xindex>
+ %cst = arith.constant dense<2> : vector<1x4xindex>
+ %2 = arith.addi %0, %cst : vector<1x4xindex>
+ return %2 : vector<1x4xindex>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const_const_first(
+// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> {
+// CHECK: %[[NEW_CST:.*]] = arith.constant 2 : index
+// CHECK: %[[SUB:.*]] = arith.subi %[[NEW_CST]], %[[ARG_0]] : index
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[SUB]] : index to vector<1x4xindex>
+// CHECK: return %[[BCAST]] : vector<1x4xindex>
+
+func.func @broadcast_scalar_and_splat_const_const_first(%arg0: index) -> vector<1x4xindex> {
+ %0 = vector.broadcast %arg0 : index to vector<1x4xindex>
+ %cst = arith.constant dense<2> : vector<1x4xindex>
+ %2 = arith.subi %cst, %0 : vector<1x4xindex>
+ return %2 : vector<1x4xindex>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_vector_and_splat_const(
+// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>) -> vector<3x4xf32> {
+// CHECK: %[[NEW_CST:.*]] = arith.constant dense<2.000000e+00> : vector<4xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[ARG_0]], %[[NEW_CST]] : vector<4xf32>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[MUL]] : vector<4xf32> to vector<3x4xf32>
+// CHECK: return %[[BCAST]] : vector<3x4xf32>
+
+func.func @broadcast_vector_and_splat_const(%arg0: vector<4xf32>) -> vector<3x4xf32> {
+ %0 = vector.broadcast %arg0 : vector<4xf32> to vector<3x4xf32>
+ %cst = arith.constant dense<2.000000e+00> : vector<3x4xf32>
+ %2 = arith.mulf %0, %cst : vector<3x4xf32>
+ return %2 : vector<3x4xf32>
+}
+
+// -----
+
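+// Negative case: a non-splat constant has no single scalar equivalent, so the broadcast is not sunk.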
+// CHECK-LABEL: func.func @negative_broadcast_with_non_splat_const(
+// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> {
+// CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : index to vector<1x4xindex>
+// CHECK-DAG: %[[CST:.*]] = arith.constant dense<{{\[}}[0, 1, 2, 3]]> : vector<1x4xindex>
+// CHECK: %[[ADD:.*]] = arith.addi %[[BCAST]], %[[CST]] : vector<1x4xindex>
+// CHECK: return %[[ADD]] : vector<1x4xindex>
+
+func.func @negative_broadcast_with_non_splat_const(%arg0: index) -> vector<1x4xindex> {
+ %0 = vector.broadcast %arg0 : index to vector<1x4xindex>
+ %cst = arith.constant dense<[[0, 1, 2, 3]]> : vector<1x4xindex>
+ %2 = arith.addi %0, %cst : vector<1x4xindex>
+ return %2 : vector<1x4xindex>
+}
+
+// -----
+
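+// Type-changing elementwise ops such as arith.extf are sunk as well: the cast is applied to the scalar first, then broadcast.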
+// CHECK-LABEL: func.func @broadcast_scalar_mixed_type(
+// CHECK-SAME: %[[ARG_0:.*]]: f16) -> vector<1x4xf32> {
+// CHECK: %[[EXTF:.*]] = arith.extf %[[ARG_0]] : f16 to f32
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTF]] : f32 to vector<1x4xf32>
+// CHECK: return %[[BCAST]] : vector<1x4xf32>
+
+func.func @broadcast_scalar_mixed_type(%arg0: f16) -> vector<1x4xf32> {
+ %0 = vector.broadcast %arg0 : f16 to vector<1x4xf16>
+ %1 = arith.extf %0 : vector<1x4xf16> to vector<1x4xf32>
+ return %1 : vector<1x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_vector_mixed_type(
+// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf16>) -> vector<3x4xf32> {
+// CHECK: %[[EXTF:.*]] = arith.extf %[[ARG_0]] : vector<4xf16> to vector<4xf32>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTF]] : vector<4xf32> to vector<3x4xf32>
+// CHECK: return %[[BCAST]] : vector<3x4xf32>
+
+func.func @broadcast_vector_mixed_type(%arg0: vector<4xf16>) -> vector<3x4xf32> {
+ %0 = vector.broadcast %arg0 : vector<4xf16> to vector<3x4xf16>
+ %1 = arith.extf %0 : vector<3x4xf16> to vector<3x4xf32>
+ return %1 : vector<3x4xf32>
+}
+
+// -----
+
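+// Operands with different element types (f32 base, i32 exponent) are supported; the splat exponent becomes a scalar constant.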
+// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const_mixed_type(
+// CHECK-SAME: %[[ARG_0:.*]]: f32) -> vector<1x4xf32> {
+// CHECK: %[[NEW_CST:.*]] = arith.constant 3 : i32
+// CHECK: %[[POW:.*]] = math.fpowi %[[ARG_0]], %[[NEW_CST]] : f32, i32
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[POW]] : f32 to vector<1x4xf32>
+// CHECK: return %[[BCAST]] : vector<1x4xf32>
+
+func.func @broadcast_scalar_and_splat_const_mixed_type(%arg0: f32) -> vector<1x4xf32> {
+ %0 = vector.broadcast %arg0 : f32 to vector<1x4xf32>
+ %cst = arith.constant dense<3> : vector<1x4xi32>
+ %2 = math.fpowi %0, %cst : vector<1x4xf32>, vector<1x4xi32>
+ return %2 : vector<1x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_vector_and_splat_const_mixed_type(
+// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>) -> vector<3x4xf32> {
+// CHECK: %[[NEW_CST:.*]] = arith.constant dense<3> : vector<4xi32>
+// CHECK: %[[POW:.*]] = math.fpowi %[[ARG_0]], %[[NEW_CST]] : vector<4xf32>, vector<4xi32>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[POW]] : vector<4xf32> to vector<3x4xf32>
+// CHECK: return %[[BCAST]] : vector<3x4xf32>
+
+func.func @broadcast_vector_and_splat_const_mixed_type(%arg0: vector<4xf32>) -> vector<3x4xf32> {
+ %0 = vector.broadcast %arg0 : vector<4xf32> to vector<3x4xf32>
+ %cst = arith.constant dense<3> : vector<3x4xi32>
+ %2 = math.fpowi %0, %cst : vector<3x4xf32>, vector<3x4xi32>
+ return %2 : vector<3x4xf32>
+}
+
//===----------------------------------------------------------------------===//
// [Pattern: ReorderCastOpsOnBroadcast]
//
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 0160bfe..dff3ffa 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -385,6 +385,74 @@ func.func @load_gather_vc_3(%src: ui64) {
}
// -----
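+// Scattered forms require the base to be a 1D memref or a pointer; 2D memrefs are rejected.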
+func.func @prefetch_offset_wi_1(%src: memref<4x4xf32>) {
+ %offsets = arith.constant dense<[0]> : vector<1xindex>
+ // expected-error@+1 {{Expecting the source is a 1D memref or pointer}}
+ xegpu.prefetch %src[%offsets]: memref<4x4xf32>, vector<1xindex>
+ return
+}
+
+// -----
+func.func @load_gather_offset_sg(%src: memref<?xf16>) {
+ %offsets = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %mask = arith.constant dense<1>: vector<8xi1>
+ // expected-error@+1 {{Mask should match value except the chunk size dim}}
+ %2 = xegpu.load %src[%offsets], %mask
+ : memref<?xf16>, vector<4xindex>, vector<8xi1>
+ -> vector<4x2xf16>
+ return
+}
+
+// -----
+func.func @load_gather_offset_wi(%src: ui64) {
+ %mask = arith.constant dense<1>: vector<1xi1>
+ %offsets = arith.constant dense<[0]> : vector<1xindex>
+ // expected-error@+1 {{value elements must match chunk size}}
+ %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64, vector<1xindex>, vector<1xi1> -> vector<4xf32>
+ return
+}
+
+// -----
+func.func @store_scatter_offset_wi_1(%src: memref<?xf16>) {
+ %val = arith.constant dense<2.9>: vector<4xf16>
+ %offsets = arith.constant dense<[0]> : vector<1xindex>
+ %mask = arith.constant dense<1>: vector<1xi1>
+ // expected-error@+1 {{value elements must match chunk size}}
+ xegpu.store %val, %src[%offsets], %mask
+ : vector<4xf16>, memref<?xf16>, vector<1xindex>, vector<1xi1>
+ return
+}
+
+// -----
+func.func @store_scatter_offset_wi_2(%src: memref<4x4xf16>) {
+ %val = arith.constant dense<2.9>: vector<4xf16>
+ %offsets = arith.constant dense<[0]> : vector<1xindex>
+ %mask = arith.constant dense<1>: vector<1xi1>
+ // expected-error@+1 {{Expecting the dest is a 1D memref or pointer}}
+ xegpu.store %val, %src[%offsets], %mask
+ : vector<4xf16>, memref<4x4xf16>, vector<1xindex>, vector<1xi1>
+ return
+}
+
+// -----
+func.func @load_gather_offset_wi_2(%src: ui64) {
+ %mask = arith.constant dense<1>: vector<1xi1>
+ %offsets = arith.constant dense<[0]> : vector<1xindex>
+ // expected-error@+1 {{value elements must match chunk size}}
+ %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64, vector<1xindex>, vector<1xi1> -> vector<4xf16>
+ return
+}
+
+// -----
+func.func @load_gather_offset_wi_1(%src: memref<4x4xf32>) {
+ %mask = arith.constant dense<1>: vector<1xi1>
+ %offsets = arith.constant dense<[0]> : vector<1xindex>
+ // expected-error@+1 {{Expecting the source is a 1D memref or pointer}}
+ %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : memref<4x4xf32>, vector<1xindex>, vector<1xi1> -> vector<2xf32>
+ return
+}
+
+// -----
func.func @store_scatter_vc_1(%src: memref<24x32xf32>) {
%0 = arith.constant dense<1>: vector<4xi1>
%1 = arith.constant dense<2.9>: vector<4x2xf32>
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 3ebb1b969a..6be2371 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -521,6 +521,16 @@ gpu.func @subgroup_load_4(%src: ui64) {
gpu.return
}
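+// Gather-style load with explicit per-element offsets, a mask, and chunk_size = 2.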
+// CHECK: gpu.func @subgroup_load_offset_1(%arg0: memref<?xf16>) {
+gpu.func @subgroup_load_offset_1(%src: memref<?xf16>) {
+ %offset = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %mask = arith.constant dense<1>: vector<4xi1>
+ //CHECK: %[[R1:.*]] = xegpu.load %arg0[%cst], %cst_0 <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : memref<?xf16>, vector<4xindex>, vector<4xi1> -> vector<4x2xf16>
+ %val = xegpu.load %src[%offset], %mask <{chunk_size=2, l1_hint = #xegpu.cache_hint<cached>}>
+ : memref<?xf16>, vector<4xindex>, vector<4xi1> -> vector<4x2xf16>
+ gpu.return
+}
+
// CHECK: gpu.func @subgroup_store(%[[arg0:.*]]: ui64) {
gpu.func @subgroup_store(%src: ui64) {
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
@@ -626,6 +636,17 @@ gpu.func @subgroup_store_4(%src: ui64) {
gpu.return
}
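+// Scatter-style store with explicit per-element offsets, a mask, and chunk_size = 2.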
+// CHECK: gpu.func @subgroup_store_offset_1(%arg0: memref<?xf16>) {
+gpu.func @subgroup_store_offset_1(%dest: memref<?xf16>) {
+ %val = arith.constant dense<2.9>: vector<4x2xf16>
+ %offset = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %mask = arith.constant dense<1>: vector<4xi1>
+ //CHECK: xegpu.store %[[R0:.*]], %arg0[%cst_0], %cst_1 <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<4x2xf16>, memref<?xf16>, vector<4xindex>, vector<4xi1>
+ xegpu.store %val, %dest[%offset], %mask <{chunk_size=2, l1_hint = #xegpu.cache_hint<cached>}>
+ : vector<4x2xf16>, memref<?xf16>, vector<4xindex>, vector<4xi1>
+ gpu.return
+}
+
// CHECK: gpu.func @prefetch(%[[arg0:.*]]: ui64) {
gpu.func @prefetch(%src: ui64) {
//CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
@@ -637,6 +658,14 @@ gpu.func @prefetch(%src: ui64) {
gpu.return
}
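+// Prefetch with explicit per-element offsets and cache hints.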
+// CHECK: gpu.func @prefetch_offset(%[[arg0:.*]]: ui64) {
+gpu.func @prefetch_offset(%src: ui64) {
+ //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ // CHECK: xegpu.prefetch %[[arg0]][%cst] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : ui64, vector<4xindex>
+ xegpu.prefetch %src[%0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: ui64, vector<4xindex>
+ gpu.return
+}
// CHECK: gpu.func @create_update_tdesc(%[[arg0:.*]]: ui64) {
gpu.func @create_update_tdesc(%src: ui64) {
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index d67bdb4..628a485 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -2,122 +2,117 @@
gpu.module @test_round_robin_assignment {
// CHECK-LABEL: create_nd_tdesc
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
- gpu.func @create_nd_tdesc(%src: memref<24x32xf32>) {
- // CHECK-COUNT-12: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<24x32xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) {
+ // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf32>
+ // CHECK-SAME: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK-NOT: xegpu.create_nd_tdesc
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
// CHECK-LABEL: load_nd_tdesc
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
- gpu.func @load_nd_tdesc(%src: memref<24x32xf32>) {
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
- // CHECK-COUNT-12: xegpu.load_nd %{{.*}}
- // CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
- // CHECK-SAME-COUNT-12: -> vector<2x2xf32>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-COUNT-4: xegpu.load_nd %{{.*}}
+ // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-SAME-COUNT-4: -> vector<16x16xf32>
// CHECK-NOT: xegpu.load_nd
%load = xegpu.load_nd %tdesc
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
- -> vector<24x32xf32>
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<256x128xf32>
gpu.return
}
// CHECK-LABEL: store_nd
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
- gpu.func @store_nd(%src: memref<24x32xf32>) {
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
- // CHECK-COUNT-12: xegpu.store_nd %{{.*}}, %{{.*}}
- // CHECK-SAME-COUNT-12: : vector<2x2xf32>, !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @store_nd(%src: memref<256x128xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, %{{.*}}
+ // CHECK-SAME-COUNT-4: : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
  // CHECK-NOT: xegpu.store_nd
%load = xegpu.load_nd %tdesc
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
- -> vector<24x32xf32>
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<256x128xf32>
xegpu.store_nd %load, %tdesc
- : vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
// CHECK-LABEL: update_nd
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
- gpu.func @update_nd(%src: memref<24x32xf32>){
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
- // CHECK-COUNT-12: xegpu.update_nd_offset %{{.*}}, [0, 16]
- // CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @update_nd(%src: memref<256x128xf32>){
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-COUNT-4: xegpu.update_nd_offset %{{.*}}, [0, 16]
+ // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK-NOT: xegpu.update_nd_offset
%update = xegpu.update_nd_offset %tdesc, [0, 16]
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
// CHECK-LABEL: dpas
- // CHECK-SAME: (%[[ARG_0:.*]]: memref<8x8xf32>, %[[ARG_1:.*]]: memref<8x8xf32>, %[[ARG_2:.*]]: memref<8x8xf32>)
- gpu.func @dpas(%a: memref<8x8xf32>, %b: memref<8x8xf32>, %c: memref<8x8xf32>) {
- // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<8x8xf32>
- // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+ // CHECK-SAME: (%[[ARG_0:.*]]: memref<256x128xf16>, %[[ARG_1:.*]]: memref<128x256xf16>)
+ gpu.func @dpas(%a: memref<256x128xf16>, %b: memref<128x256xf16>) {
+ // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf16>
+ // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK-NOT: xegpu.create_nd_tdesc
- // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<8x8xf32>
- // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
- // CHECK-NOT: xegpu.create_nd_tdesc
- // CHECK-COUNT-4: xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}] : memref<8x8xf32>
- // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+ // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<128x256xf16>
+ // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [4, 8], lane_data = [1, 1]>>
// CHECK-NOT: xegpu.create_nd_tdesc
// CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}}
- // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
- // CHECK-SAME-COUNT-16: : vector<2x2xf32>, vector<2x2xf32> -> vector<2x2xf32>
+ // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ // CHECK-SAME-COUNT-16: : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
// CHECK-NOT: xegpu.dpas
- %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<8x8xf32>
- -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<256x128xf16>
+ -> !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
%load_a = xegpu.load_nd %tdesc_a
- : !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
- -> vector<8x8xf32>
- %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<8x8xf32>
- -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ : !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<256x128xf16>
+ %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<128x256xf16>
+ -> !xegpu.tensor_desc<128x256xf16, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
%load_b = xegpu.load_nd %tdesc_b
- : !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
- -> vector<8x8xf32>
- %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<8x8xf32>
- -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ : !xegpu.tensor_desc<128x256xf16, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
+ -> vector<128x256xf16>
%dpas = xegpu.dpas %load_a, %load_b
- {layout_result_0 = #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
- : vector<8x8xf32>, vector<8x8xf32> -> vector<8x8xf32>
+ {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>}
+ : vector<256x128xf16>, vector<128x256xf16> -> vector<256x256xf32>
gpu.return
}
// CHECK-LABEL: prefetch_nd_tdesc
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
- gpu.func @prefetch_nd_tdesc(%src: memref<24x32xf32>) {
- // CHECK-COUNT-12: xegpu.prefetch_nd %{{.*}}
- // CHECK-SAME-COUNT-12 : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @prefetch_nd_tdesc(%src: memref<256x128xf32>) {
+ // CHECK-COUNT-4: xegpu.prefetch_nd %{{.*}}
+ // CHECK-SAME-COUNT-4: !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK-NOT: xegpu.prefetch_nd
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
xegpu.prefetch_nd %tdesc
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
// CHECK-LABEL: broadcast
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
- gpu.func @broadcast(%src: memref<24x1xf32>) {
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
- -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<128x1xf32>
+ gpu.func @broadcast(%src: memref<128x1xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<128x1xf32>
+ -> !xegpu.tensor_desc<128x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
%load = xegpu.load_nd %tdesc
- : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
- -> vector<24x1xf32>
- // CHECK-COUNT-3: vector.broadcast {{.*}}
- // CHECK-SAME-COUNT-3: {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
- // CHECK-SAME-COUNT-3: : vector<2x1xf32> to vector<2x4xf32>
+ : !xegpu.tensor_desc<128x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
+ -> vector<128x1xf32>
+ // CHECK-COUNT-2: vector.broadcast {{.*}}
+ // CHECK-SAME-COUNT-2: {layout_result_0 = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>}
+ // CHECK-SAME-COUNT-2: : vector<16x1xf32> to vector<16x32xf32>
// CHECK-NOT: vector.broadcast
%broadcast = vector.broadcast %load
- {layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 4], lane_layout = [2, 1], lane_data = [1, 1]>}
- : vector<24x1xf32> to vector<24x8xf32>
+ {layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 32], lane_layout = [8, 1], lane_data = [1, 1]>}
+ : vector<128x1xf32> to vector<128x64xf32>
gpu.return
}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index d511224..d4b0037 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -4,201 +4,181 @@
//CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)>
gpu.module @test_1_1_assignment {
// CHECK-LABEL: create_nd_tdesc
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
- gpu.func @create_nd_tdesc(%src: memref<24x32xf32>) {
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) {
// CHECK: %[[SGID:.*]] = gpu.subgroup_id
- // CHECK: %[[C12:.*]] = arith.constant 12 : index
- // CHECK: %[[C4:.*]] = arith.constant 4 : index
// CHECK: %[[C8:.*]] = arith.constant 8 : index
+ // CHECK: %[[C32:.*]] = arith.constant 32 : index
+ // CHECK: %[[C4:.*]] = arith.constant 4 : index
+ // CHECK: %[[C32_0:.*]] = arith.constant 32 : index
+ // CHECK: %[[C4_1:.*]] = arith.constant 4 : index
// CHECK: %[[DIV:.*]] = affine.apply #map()[%[[SGID]]]
// CHECK: %[[REM:.*]] = affine.apply #map1()[%[[SGID]]]
- // CHECK: %[[MUL1:.*]] = index.mul %[[DIV]], %[[C12]]
- // CHECK: %[[MUL2:.*]] = index.mul %[[REM]], %[[C8]]
- // CHECK: %[[C24:.*]] = arith.constant 24 : index
- // CHECK: %[[MOD:.*]] = index.remu %[[MUL1]], %[[C24]]
+ // CHECK: %[[MUL1:.*]] = index.mul %[[DIV]], %[[C32]]
+ // CHECK: %[[MUL2:.*]] = index.mul %[[REM]], %[[C32_0]]
// CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: %[[ADD1:.*]] = index.add %[[MOD]], %[[C0]]
- // CHECK: %[[C32:.*]] = arith.constant 32 : index
- // CHECK: %[[MOD1:.*]] = index.remu %[[MUL2]], %[[C32]]
- // CHECK: %[[C0_1:.*]] = arith.constant 0 : index
- // CHECK: %[[ADD2:.*]] = index.add %[[MOD1]], %[[C0_1]]
- // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<24x32xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK: %[[C256:.*]] = arith.constant 256 : index
+ // CHECK: %[[MOD:.*]] = index.remu %[[MUL1]], %[[C256]]
+ // CHECK: %[[C0_2:.*]] = arith.constant 0 : index
+ // CHECK: %[[ADD1:.*]] = index.add %[[MOD]], %[[C0_2]]
+ // CHECK: %[[C0_3:.*]] = arith.constant 0 : index
+ // CHECK: %[[C128:.*]] = arith.constant 128 : index
+ // CHECK: %[[MOD1:.*]] = index.remu %[[MUL2]], %[[C128]]
+ // CHECK: %[[C0_4:.*]] = arith.constant 0 : index
+ // CHECK: %[[ADD2:.*]] = index.add %[[MOD1]], %[[C0_4]]
+ // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<256x128xf32>
+ // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK: gpu.return
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
// CHECK-LABEL: load_nd_tdesc
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
- gpu.func @load_nd_tdesc(%src: memref<24x32xf32>) {
- // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) {
+ // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
+ // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]]
- // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
- // CHECK-SAME: -> vector<12x8xf32>
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-SAME: -> vector<32x32xf32>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
%load = xegpu.load_nd %tdesc
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
- -> vector<24x32xf32>
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<256x128xf32>
gpu.return
}
// CHECK-LABEL: store_nd
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
- gpu.func @store_nd(%src: memref<24x32xf32>) {
- // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @store_nd(%src: memref<256x128xf32>) {
+ // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
+ // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]]
- // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
- // CHECK-SAME: -> vector<12x8xf32>
+ // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-SAME: -> vector<32x32xf32>
// CHECK: xegpu.store_nd %[[LOAD]], %[[TDESC]]
- // CHECK-SAME: : vector<12x8xf32>, !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK-SAME: : vector<32x32xf32>, !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
%load = xegpu.load_nd %tdesc
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
- -> vector<24x32xf32>
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<256x128xf32>
xegpu.store_nd %load, %tdesc
- : vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
// CHECK-LABEL: update_nd
-// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
-gpu.func @update_nd(%src: memref<24x32xf32>){
- // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+gpu.func @update_nd(%src: memref<256x128xf32>) {
+ // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
+ // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK: %[[UPDATE:.*]] = xegpu.update_nd_offset %[[TDESC]], [0, 16]
- // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
%update = xegpu.update_nd_offset %tdesc, [0, 16]
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
// CHECK-LABEL: dpas
-// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
-// CHECK-SAME: %[[ARG_1:.*]]: memref<32x24xf32>
-gpu.func @dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
- // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
- // CHECk-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
- // CHECK: %[[LOAD_A:.*]] = xegpu.load_nd %[[TDESC_A]]
- // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
- // CHECK-SAME: -> vector<12x8xf32>
- // CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][{{%.*}}, {{%.*}}] : memref<32x24xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
- // CHECK: %[[LOAD_B:.*]] = xegpu.load_nd %[[TDESC_B]]
- // CHECK-SAME: : !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
- // CHECK-SAME: -> vector<8x12xf32>
- // CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]]
- // CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
- // CHECK-SAME: : vector<12x8xf32>, vector<8x12xf32> -> vector<12x12xf32>
- %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+gpu.func @dpas(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
+ // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x128xf16>, vector<128x16xf16> -> vector<16x16xf32>
+ %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<128x128xf16>
+ -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>>
%load_a = xegpu.load_nd %tdesc_a
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
- -> vector<24x32xf32>
- %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32>
- -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>>
+ : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<128x128xf16>
+ %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<128x128xf16>
+ -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
%load_b = xegpu.load_nd %tdesc_b
- : !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>>
- -> vector<32x24xf32>
+ : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
+ -> vector<128x128xf16>
%dpas = xegpu.dpas %load_a, %load_b
- {layout_result_0 = #xegpu.layout<sg_layout = [2, 2], sg_data = [12, 12], lane_layout = [2, 2], lane_data = [1, 1]>}
- : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
+ {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>}
+ : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
gpu.return
}
// CHECK-LABEL: dpas_no_sg_data
-// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
-// CHECK-SAME: %[[ARG_1:.*]]: memref<32x24xf32>
-gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
- // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
- // CHECk-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
- // CHECK: %[[LOAD_A:.*]] = xegpu.load_nd %[[TDESC_A]]
- // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
- // CHECK-SAME: -> vector<12x8xf32>
- // CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][{{%.*}}, {{%.*}}] : memref<32x24xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
- // CHECK: %[[LOAD_B:.*]] = xegpu.load_nd %[[TDESC_B]]
- // CHECK-SAME: : !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
- // CHECK-SAME: -> vector<8x12xf32>
- // CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]]
- // CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
- // CHECK-SAME: : vector<12x8xf32>, vector<8x12xf32> -> vector<12x12xf32>
- %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], lane_layout = [2, 8], lane_data = [1, 1]>>
+gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
+ // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
+ %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<128x128xf16>
+ -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1],
+ order = [1, 0]>>
%load_a = xegpu.load_nd %tdesc_a
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], lane_layout = [2, 8], lane_data = [1, 1]>>
- -> vector<24x32xf32>
- %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32>
- -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], lane_layout = [8, 2], lane_data = [1, 1]>>
+ : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1],
+ order = [1, 0]>>
+ -> vector<128x128xf16>
+ %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<128x128xf16>
+ -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1],
+ order = [1, 0]>>
%load_b = xegpu.load_nd %tdesc_b
- : !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], lane_layout = [8, 2], lane_data = [1, 1]>>
- -> vector<32x24xf32>
+ : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1],
+ order = [1, 0]>>
+ -> vector<128x128xf16>
%dpas = xegpu.dpas %load_a, %load_b
- {layout_result_0 = #xegpu.layout<sg_layout = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
- : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
+ {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>}
+ : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
gpu.return
}
// CHECK-LABEL: prefetch_nd_tdesc
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
- gpu.func @prefetch_nd_tdesc(%src: memref<24x32xf32>) {
- // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @prefetch_nd_tdesc(%src: memref<256x128xf32>) {
+ // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
+ // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK: xegpu.prefetch_nd %[[TDESC]]
- // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
xegpu.prefetch_nd %tdesc
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
// CHECK-LABEL: dpas_with_no_create_nd_desc
- gpu.func @dpas_with_no_create_nd_desc(%a: vector<24x32xf32>, %b: vector<32x24xf32>) {
- // CHECK-NOT: vector<12x12xf32>
+ gpu.func @dpas_with_no_create_nd_desc(%a: vector<256x128xf32>, %b: vector<128x256xf32>) {
+ // CHECK-NOT: vector<32x32xf32>
%dpas = xegpu.dpas %a, %b
{layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [12, 12], lane_layout = [2, 2], lane_data = [1, 1]>}
- : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
+ : vector<256x128xf32>, vector<128x256xf32> -> vector<256x256xf32>
gpu.return
}
// CHECK-LABEL: broadcast_dim1
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
- gpu.func @broadcast_dim1(%src: memref<24x1xf32>) {
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
- -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x1xf32>
+ gpu.func @broadcast_dim1(%src: memref<256x1xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x1xf32>
+ -> !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
%load = xegpu.load_nd %tdesc
- : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
- -> vector<24x1xf32>
- // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
- // CHECK-SAME: : vector<12x1xf32> to vector<12x8xf32>
- %broadcast = vector.broadcast %load
- {layout_result_0 = #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 8], lane_layout = [2, 1], lane_data = [1, 1]>}
- : vector<24x1xf32> to vector<24x8xf32>
+ : !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
+ -> vector<256x1xf32>
+ // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>}
+ // CHECK-SAME: : vector<32x1xf32> to vector<32x32xf32>
+ %broadcast = vector.broadcast %load
+ {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 32], lane_layout = [8, 1], lane_data = [1, 1]>}
+ : vector<256x1xf32> to vector<256x32xf32>
gpu.return
}
// CHECK-LABEL: broadcast_dim0
- // CHECK-SAME: %[[ARG_0:.*]]: memref<1x32xf32>
- gpu.func @broadcast_dim0(%src: memref<1x32xf32>) {
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<1x32xf32>
- -> !xegpu.tensor_desc<1x32xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 8], lane_layout = [1, 8], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<1x128xf32>
+ gpu.func @broadcast_dim0(%src: memref<1x128xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<1x128xf32>
+ -> !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
%load = xegpu.load_nd %tdesc
- : !xegpu.tensor_desc<1x32xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 8], lane_layout = [1, 8], lane_data = [1, 1]>>
- -> vector<1x32xf32>
- // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 8], lane_data = [1, 1]>}
- // CHECK-SAME: : vector<1x8xf32> to vector<12x8xf32>
+ : !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<1x128xf32>
+ // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ // CHECK-SAME: : vector<1x32xf32> to vector<32x32xf32>
%broadcast = vector.broadcast %load
- {layout_result_0 = #xegpu.layout<sg_layout = [1, 4], sg_data = [12, 8], lane_layout = [1, 8], lane_data = [1, 1]>}
- : vector<1x32xf32> to vector<12x32xf32>
+ {layout_result_0 = #xegpu.layout<sg_layout = [1, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>}
+ : vector<1x128xf32> to vector<32x128xf32>
gpu.return
}
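A hedged sketch of the distribution arithmetic the retargeted tests above rely on (shapes mirror the prefetch test; the derivation is an inference, not stated by the patch): with sg_layout = [8, 4] and sg_data = [32, 32], a 256x128 workgroup-level descriptor is split across an 8x4 grid of subgroups, each owning one 32x32 tile, since 8*32 = 256 and 4*32 = 128; after distribution the per-subgroup type keeps only the lane-level layout fields.

// Workgroup-level descriptor (input to the pass):
%wg = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
  -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
// Per-subgroup descriptor (expected output): sg_layout/sg_data are dropped.
// !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>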
diff --git a/mlir/test/IR/top-level.mlir b/mlir/test/IR/top-level.mlir
index e0adb4d82..5389691 100644
--- a/mlir/test/IR/top-level.mlir
+++ b/mlir/test/IR/top-level.mlir
@@ -6,10 +6,10 @@ func.func private @foo()
// -----
-// expected-error@-9 {{source must contain a single top-level operation, found: 2}}
+// expected-error@-2 {{source must contain a single top-level operation, found: 2}}
func.func private @bar()
func.func private @baz()
// -----
-// expected-error@-15 {{source must contain a single top-level operation, found: 0}}
+// expected-error@-2 {{source must contain a single top-level operation, found: 0}}
diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
index 24380b5..a419d75 100644
--- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll
+++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
@@ -570,10 +570,10 @@ define void @trap_intrinsics() {
; CHECK-LABEL: llvm.func @memcpy_test
define void @memcpy_test(i32 %0, ptr %1, ptr %2) {
- ; CHECK: "llvm.intr.memcpy"(%{{.*}}, %{{.*}}, %{{.*}}) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
- call void @llvm.memcpy.p0.p0.i32(ptr %1, ptr %2, i32 %0, i1 false)
- ; CHECK: "llvm.intr.memcpy.inline"(%{{.*}}, %{{.*}}) <{isVolatile = false, len = 10 : i64}> : (!llvm.ptr, !llvm.ptr) -> ()
- call void @llvm.memcpy.inline.p0.p0.i64(ptr %1, ptr %2, i64 10, i1 false)
+ ; CHECK: "llvm.intr.memcpy"(%{{.*}}, %{{.*}}, %{{.*}}) <{arg_attrs = [{llvm.align = 4 : i64}, {llvm.align = 8 : i64}, {}], isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+ call void @llvm.memcpy.p0.p0.i32(ptr align 4 %1, ptr align 8 %2, i32 %0, i1 false)
+ ; CHECK: "llvm.intr.memcpy.inline"(%{{.*}}, %{{.*}}) <{arg_attrs = [{}, {llvm.align = 4 : i64}], isVolatile = false, len = 10 : i64}> : (!llvm.ptr, !llvm.ptr) -> ()
+ call void @llvm.memcpy.inline.p0.p0.i64(ptr %1, ptr align 4 %2, i64 10, i1 false)
; CHECK: "llvm.intr.memcpy.inline"(%{{.*}}, %{{.*}}) <{isVolatile = false, len = 10 : i32}> : (!llvm.ptr, !llvm.ptr) -> ()
call void @llvm.memcpy.inline.p0.p0.i32(ptr %1, ptr %2, i32 10, i1 false)
ret void
@@ -581,17 +581,17 @@ define void @memcpy_test(i32 %0, ptr %1, ptr %2) {
; CHECK-LABEL: llvm.func @memmove_test
define void @memmove_test(i32 %0, ptr %1, ptr %2) {
- ; CHECK: "llvm.intr.memmove"(%{{.*}}, %{{.*}}, %{{.*}}) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
- call void @llvm.memmove.p0.p0.i32(ptr %1, ptr %2, i32 %0, i1 false)
+ ; CHECK: "llvm.intr.memmove"(%{{.*}}, %{{.*}}, %{{.*}}) <{arg_attrs = [{llvm.align = 16 : i64}, {}, {}], isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+ call void @llvm.memmove.p0.p0.i32(ptr align 16 %1, ptr %2, i32 %0, i1 false)
ret void
}
; CHECK-LABEL: llvm.func @memset_test
define void @memset_test(i32 %0, ptr %1, i8 %2) {
- ; CHECK: "llvm.intr.memset"(%{{.*}}, %{{.*}}, %{{.*}}) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
- call void @llvm.memset.p0.i32(ptr %1, i8 %2, i32 %0, i1 false)
- ; CHECK: "llvm.intr.memset.inline"(%{{.*}}, %{{.*}}) <{isVolatile = false, len = 10 : i64}> : (!llvm.ptr, i8) -> ()
- call void @llvm.memset.inline.p0.i64(ptr %1, i8 %2, i64 10, i1 false)
+ ; CHECK: "llvm.intr.memset"(%{{.*}}, %{{.*}}, %{{.*}}) <{arg_attrs = [{llvm.align = 2 : i64}, {}, {}], isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+ call void @llvm.memset.p0.i32(ptr align 2 %1, i8 %2, i32 %0, i1 false)
+ ; CHECK: "llvm.intr.memset.inline"(%{{.*}}, %{{.*}}) <{arg_attrs = [{llvm.align = 4 : i64}, {}], isVolatile = false, len = 10 : i64}> : (!llvm.ptr, i8) -> ()
+ call void @llvm.memset.inline.p0.i64(ptr align 4 %1, i8 %2, i64 10, i1 false)
; CHECK: "llvm.intr.memset.inline"(%{{.*}}, %{{.*}}) <{isVolatile = false, len = 10 : i32}> : (!llvm.ptr, i8) -> ()
call void @llvm.memset.inline.p0.i32(ptr %1, i8 %2, i32 10, i1 false)
ret void
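A minimal sketch of the convention these import updates exercise (operand names illustrative): an `align N` annotation on a pointer operand of a memory intrinsic is imported as an {llvm.align = N : i64} dictionary at the matching position of arg_attrs, with an empty dictionary {} standing in for each unannotated operand.

// `ptr align 4 %dst, ptr align 8 %src` on the LLVM side becomes per-operand
// dictionaries; the trailing {} belongs to the unannotated length operand.
"llvm.intr.memcpy"(%dst, %src, %n) <{arg_attrs = [{llvm.align = 4 : i64}, {llvm.align = 8 : i64}, {}], isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()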
diff --git a/mlir/test/Target/LLVMIR/Import/module-asm.ll b/mlir/test/Target/LLVMIR/Import/module-asm.ll
new file mode 100644
index 0000000..38f6ea4
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/module-asm.ll
@@ -0,0 +1,5 @@
+; RUN: mlir-translate -import-llvm %s | FileCheck %s
+; CHECK: llvm.module_asm = ["foo", "bar"]
+
+module asm "foo"
+module asm "bar"
diff --git a/mlir/test/Target/LLVMIR/invalid-module.mlir b/mlir/test/Target/LLVMIR/invalid-module.mlir
index 7fd5f26..5ed6244 100644
--- a/mlir/test/Target/LLVMIR/invalid-module.mlir
+++ b/mlir/test/Target/LLVMIR/invalid-module.mlir
@@ -1,6 +1,16 @@
-// RUN: mlir-translate -verify-diagnostics -mlir-to-llvmir --no-implicit-module %s
+// RUN: mlir-translate -verify-diagnostics -mlir-to-llvmir --no-implicit-module -split-input-file %s
// expected-error@below {{'llvm.func' op can not be translated to an LLVMIR module}}
llvm.func @foo() {
llvm.return
}
+
+// -----
+
+// expected-error@below {{expected an array attribute for a module level asm}}
+module attributes {llvm.module_asm = "foo"} {}
+
+// -----
+
+// expected-error@below {{expected a string attribute for each entry of a module level asm}}
+module attributes {llvm.module_asm = [42]} {}
diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
index 44074ce..eb3510c 100644
--- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
@@ -601,29 +601,33 @@ llvm.func @trap_intrinsics() {
// CHECK-LABEL: @memcpy_test
llvm.func @memcpy_test(%arg0: i32, %arg2: !llvm.ptr, %arg3: !llvm.ptr) {
- // CHECK: call void @llvm.memcpy.p0.p0.i32(ptr %{{.*}}, ptr %{{.*}}, i32 %{{.*}}, i1 false
- "llvm.intr.memcpy"(%arg2, %arg3, %arg0) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
- // CHECK: call void @llvm.memcpy.inline.p0.p0.i32(ptr %{{.*}}, ptr %{{.*}}, i32 10, i1 true
- "llvm.intr.memcpy.inline"(%arg2, %arg3) <{isVolatile = true, len = 10 : i32}> : (!llvm.ptr, !llvm.ptr) -> ()
+ // CHECK: call void @llvm.memcpy.p0.p0.i32(ptr align 4 %{{.*}}, ptr align 8 %{{.*}}, i32 %{{.*}}, i1 false
+ "llvm.intr.memcpy"(%arg2, %arg3, %arg0) <{arg_attrs = [{llvm.align = 4 : i64}, {llvm.align = 8 : i64}, {}], isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+ // CHECK: call void @llvm.memcpy.inline.p0.p0.i32(ptr align 4 %{{.*}}, ptr %{{.*}}, i32 10, i1 true
+ "llvm.intr.memcpy.inline"(%arg2, %arg3) <{arg_attrs = [{llvm.align = 4 : i64}, {}], isVolatile = true, len = 10 : i32}> : (!llvm.ptr, !llvm.ptr) -> ()
// CHECK: call void @llvm.memcpy.inline.p0.p0.i64(ptr %{{.*}}, ptr %{{.*}}, i64 10, i1 true
"llvm.intr.memcpy.inline"(%arg2, %arg3) <{isVolatile = true, len = 10 : i64}> : (!llvm.ptr, !llvm.ptr) -> ()
+
+ // Verify that trailing empty argument attribute dictionaries can be omitted.
+ // CHECK: call void @llvm.memcpy.p0.p0.i32(ptr align 4 %{{.*}}, ptr align 8 %{{.*}}, i32 %{{.*}}, i1 false
+ "llvm.intr.memcpy"(%arg2, %arg3, %arg0) <{arg_attrs = [{llvm.align = 4 : i64}, {llvm.align = 8 : i64}], isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
llvm.return
}
// CHECK-LABEL: @memmove_test
llvm.func @memmove_test(%arg0: i32, %arg2: !llvm.ptr, %arg3: !llvm.ptr) {
- // CHECK: call void @llvm.memmove.p0.p0.i32(ptr %{{.*}}, ptr %{{.*}}, i32 %{{.*}}, i1 false
- "llvm.intr.memmove"(%arg2, %arg3, %arg0) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+ // CHECK: call void @llvm.memmove.p0.p0.i32(ptr align 4 %{{.*}}, ptr align 8 %{{.*}}, i32 %{{.*}}, i1 false
+ "llvm.intr.memmove"(%arg2, %arg3, %arg0) <{arg_attrs = [{llvm.align = 4 : i64}, {llvm.align = 8 : i64}, {}], isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
llvm.return
}
// CHECK-LABEL: @memset_test
llvm.func @memset_test(%arg0: i32, %arg2: !llvm.ptr, %arg3: i8) {
%i1 = llvm.mlir.constant(false) : i1
- // CHECK: call void @llvm.memset.p0.i32(ptr %{{.*}}, i8 %{{.*}}, i32 %{{.*}}, i1 false
- "llvm.intr.memset"(%arg2, %arg3, %arg0) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
- // CHECK: call void @llvm.memset.inline.p0.i32(ptr %{{.*}}, i8 %{{.*}}, i32 10, i1 true
- "llvm.intr.memset.inline"(%arg2, %arg3) <{isVolatile = true, len = 10 : i32}> : (!llvm.ptr, i8) -> ()
+ // CHECK: call void @llvm.memset.p0.i32(ptr align 8 %{{.*}}, i8 %{{.*}}, i32 %{{.*}}, i1 false
+ "llvm.intr.memset"(%arg2, %arg3, %arg0) <{arg_attrs = [{llvm.align = 8 : i64}, {}, {}], isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+ // CHECK: call void @llvm.memset.inline.p0.i32(ptr align 8 %{{.*}}, i8 %{{.*}}, i32 10, i1 true
+ "llvm.intr.memset.inline"(%arg2, %arg3) <{arg_attrs = [{llvm.align = 8 : i64}, {}], isVolatile = true, len = 10 : i32}> : (!llvm.ptr, i8) -> ()
// CHECK: call void @llvm.memset.inline.p0.i64(ptr %{{.*}}, i8 %{{.*}}, i64 10, i1 true
"llvm.intr.memset.inline"(%arg2, %arg3) <{isVolatile = true, len = 10 : i64}> : (!llvm.ptr, i8) -> ()
llvm.return
diff --git a/mlir/test/Target/LLVMIR/module-asm.mlir b/mlir/test/Target/LLVMIR/module-asm.mlir
new file mode 100644
index 0000000..2afb37c
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/module-asm.mlir
@@ -0,0 +1,6 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+module attributes {llvm.module_asm = ["foo", "bar"]} {}
+
+// CHECK: module asm "foo"
+// CHECK: module asm "bar"
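Taken together with the import test earlier, the convention is symmetric; a small sketch (asm strings illustrative): each string in the llvm.module_asm array becomes one `module asm` line of the emitted LLVM module, in array order.

// Emits `module asm ".inst 0x100"` followed by `module asm ".inst 0x200"`.
module attributes {llvm.module_asm = [".inst 0x100", ".inst 0x200"]} {}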
diff --git a/mlir/test/Target/LLVMIR/xevm.mlir b/mlir/test/Target/LLVMIR/xevm.mlir
new file mode 100644
index 0000000..a3dd0b6
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/xevm.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-translate --split-input-file -mlir-to-llvmir %s | FileCheck %s
+
+module {
+ llvm.func spir_funccc @_Z8prefetchPU3AS1Kcm(!llvm.ptr<1>, i64)
+ llvm.func @prefetch(%arg0: !llvm.ptr<1>) {
+ %0 = llvm.mlir.constant(1 : i64) : i64
+ // CHECK-LABEL: call spir_func void @_Z8prefetchPU3AS1Kcm
+ // CHECK-SAME: !spirv.DecorationCacheControlINTEL ![[DECO1:.*]]
+ llvm.call spir_funccc @_Z8prefetchPU3AS1Kcm(%arg0, %0)
+ {function_type = !llvm.func<void (ptr<1>, i64)>, linkage = #llvm.linkage<external>,
+ no_unwind, sym_name = "_Z8prefetchPU3AS1Kcm", visibility_ = 0 : i64,
+ xevm.DecorationCacheControl = [[6442 : i32, 0 : i32, 1 : i32, 0 : i32], [6442 : i32, 1 : i32, 1 : i32, 0 : i32]]}
+ : (!llvm.ptr<1>, i64) -> ()
+ llvm.return
+ }
+}
+
+// CHECK: ![[DECO1]] = !{![[DECO2:.*]], ![[DECO3:.*]]}
+// CHECK: ![[DECO2]] = !{i32 6442, i32 0, i32 1, i32 0}
+// CHECK: ![[DECO3]] = !{i32 6442, i32 1, i32 1, i32 0}
+
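A hedged reading of the new XeVM test (operand names illustrative; the meaning of the integer fields is an inference from the SPIR-V cache-control metadata form, not spelled out by the patch): each four-element entry of xevm.DecorationCacheControl is emitted verbatim as one i32 metadata tuple, and the tuples are collected under a single !spirv.DecorationCacheControlINTEL node attached to the call site.

// Illustrative: a single cache-control entry on a call; it would surface as
// metadata !{i32 6442, i32 0, i32 1, i32 0} on the translated call.
llvm.call spir_funccc @_Z8prefetchPU3AS1Kcm(%ptr, %len)
    {xevm.DecorationCacheControl = [[6442 : i32, 0 : i32, 1 : i32, 0 : i32]]}
    : (!llvm.ptr<1>, i64) -> ()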
diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir
index 6aca11e..1695d2a 100644
--- a/mlir/test/Target/SPIRV/constant.mlir
+++ b/mlir/test/Target/SPIRV/constant.mlir
@@ -307,6 +307,34 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader
spirv.ReturnValue %coop : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc>
}
+ // CHECK-LABEL: @arm_tensor_of_i32
+ spirv.func @arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
+ // CHECK: {{%.*}} = spirv.Constant dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : !spirv.arm.tensor<2x3xi32>
+ %0 = spirv.Constant dense<[[1, 2, 3], [4, 5, 6]]> : !spirv.arm.tensor<2x3xi32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
+ }
+
+ // CHECK-LABEL: @splat_arm_tensor_of_i32
+ spirv.func @splat_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
+ // CHECK: {{%.*}} = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32>
+ %0 = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
+ }
+
+ // CHECK-LABEL: @arm_tensor_of_f32
+ spirv.func @arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
+ // CHECK: {{%.*}} = spirv.Constant dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : !spirv.arm.tensor<2x3xf32>
+ %0 = spirv.Constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]>: !spirv.arm.tensor<2x3xf32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
+ }
+
+ // CHECK-LABEL: @splat_arm_tensor_of_f32
+ spirv.func @splat_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
+ // CHECK: {{%.*}} = spirv.Constant dense<2.000000e+00> : !spirv.arm.tensor<2x3xf32>
+ %0 = spirv.Constant dense<2.0> : !spirv.arm.tensor<2x3xf32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
+ }
+
spirv.EntryPoint "GLCompute" @bool_const
}
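For context on the splat cases above, a brief sketch (values illustrative): dense<2> with a shaped !spirv.arm.tensor type is the usual MLIR splat shorthand, so the two constants below denote the same 2x3 tensor.

// Splat and fully spelled-out forms of one constant.
%splat = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32>
%full = spirv.Constant dense<[[2, 2, 2], [2, 2, 2]]> : !spirv.arm.tensor<2x3xi32>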
diff --git a/mlir/test/Target/SPIRV/logical-ops.mlir b/mlir/test/Target/SPIRV/logical-ops.mlir
index b200871..05cbddc 100644
--- a/mlir/test/Target/SPIRV/logical-ops.mlir
+++ b/mlir/test/Target/SPIRV/logical-ops.mlir
@@ -84,6 +84,8 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
%15 = spirv.IsNan %arg0 : f32
// CHECK: spirv.IsInf
%16 = spirv.IsInf %arg1 : f32
+ // CHECK: spirv.IsFinite
+ %17 = spirv.IsFinite %arg0 : f32
spirv.Return
}
}
diff --git a/mlir/test/Target/SPIRV/memory-ops.mlir b/mlir/test/Target/SPIRV/memory-ops.mlir
index 6b50c39..786d07a2 100644
--- a/mlir/test/Target/SPIRV/memory-ops.mlir
+++ b/mlir/test/Target/SPIRV/memory-ops.mlir
@@ -37,32 +37,32 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// -----
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
- spirv.func @load_store_zero_rank_float(%arg0: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>) "None" {
- // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>
+ spirv.func @load_store_zero_rank_float(%arg0: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>) "None" {
+ // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>
// CHECK-NEXT: [[VAL:%.*]] = spirv.Load "StorageBuffer" [[LOAD_PTR]] : f32
%0 = spirv.Constant 0 : i32
- %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
+ %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
%2 = spirv.Load "StorageBuffer" %1 : f32
- // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>
+ // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>
// CHECK-NEXT: spirv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : f32
%3 = spirv.Constant 0 : i32
- %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
+ %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
spirv.Store "StorageBuffer" %4, %2 : f32
spirv.Return
}
- spirv.func @load_store_zero_rank_int(%arg0: !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>) "None" {
- // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>
+ spirv.func @load_store_zero_rank_int(%arg0: !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>) "None" {
+ // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>
// CHECK-NEXT: [[VAL:%.*]] = spirv.Load "StorageBuffer" [[LOAD_PTR]] : i32
%0 = spirv.Constant 0 : i32
- %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer>
+ %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer>
%2 = spirv.Load "StorageBuffer" %1 : i32
- // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>
+ // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>
// CHECK-NEXT: spirv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : i32
%3 = spirv.Constant 0 : i32
- %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer>
+ %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer>
spirv.Store "StorageBuffer" %4, %2 : i32
spirv.Return
}
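A minimal sketch of the pattern applied throughout these SPIR-V test updates, here and in the struct and undef tests below (variable name illustrative): interface structs reached through StorageBuffer pointers now spell the Block decoration inside the struct type itself rather than leaving it implicit.

// Before: !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>
// After:
spirv.GlobalVariable @buf : !spirv.ptr<!spirv.struct<(f32 [0]), Block>, StorageBuffer>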
diff --git a/mlir/test/Target/SPIRV/struct.mlir b/mlir/test/Target/SPIRV/struct.mlir
index 0db0c0b..4984ee7 100644
--- a/mlir/test/Target/SPIRV/struct.mlir
+++ b/mlir/test/Target/SPIRV/struct.mlir
@@ -7,23 +7,23 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// CHECK: !spirv.ptr<!spirv.struct<(f32 [0], !spirv.struct<(f32 [0], !spirv.array<16 x f32, stride=4> [4])> [4])>, Input>
spirv.GlobalVariable @var1 bind(0, 2) : !spirv.ptr<!spirv.struct<(f32 [0], !spirv.struct<(f32 [0], !spirv.array<16 x f32, stride=4> [4])> [4])>, Input>
- // CHECK: !spirv.ptr<!spirv.struct<(f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38])>, StorageBuffer>
- spirv.GlobalVariable @var2 : !spirv.ptr<!spirv.struct<(f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38])>, StorageBuffer>
+ // CHECK: !spirv.ptr<!spirv.struct<(f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38]), Block>, StorageBuffer>
+ spirv.GlobalVariable @var2 : !spirv.ptr<!spirv.struct<(f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38]), Block>, StorageBuffer>
- // CHECK: !spirv.ptr<!spirv.struct<(!spirv.array<128 x !spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, stride=512> [0])>, StorageBuffer>
- spirv.GlobalVariable @var3 : !spirv.ptr<!spirv.struct<(!spirv.array<128 x !spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, stride=512> [0])>, StorageBuffer>
+ // CHECK: !spirv.ptr<!spirv.struct<(!spirv.array<128 x !spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, stride=512> [0]), Block>, StorageBuffer>
+ spirv.GlobalVariable @var3 : !spirv.ptr<!spirv.struct<(!spirv.array<128 x !spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, stride=512> [0]), Block>, StorageBuffer>
- // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4])>, StorageBuffer>
- spirv.GlobalVariable @var4 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4])>, StorageBuffer>
+ // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4]), Block>, StorageBuffer>
+ spirv.GlobalVariable @var4 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4]), Block>, StorageBuffer>
- // CHECK: !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable])>, StorageBuffer>
- spirv.GlobalVariable @var5 : !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable])>, StorageBuffer>
+ // CHECK: !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable]), Block>, StorageBuffer>
+ spirv.GlobalVariable @var5 : !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable]), Block>, StorageBuffer>
- // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable])>, StorageBuffer>
- spirv.GlobalVariable @var6 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable])>, StorageBuffer>
+ // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable]), Block>, StorageBuffer>
+ spirv.GlobalVariable @var6 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable]), Block>, StorageBuffer>
- // CHECK: !spirv.ptr<!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16])>, StorageBuffer>
- spirv.GlobalVariable @var7 : !spirv.ptr<!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16])>, StorageBuffer>
+ // CHECK: !spirv.ptr<!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]), Block>, StorageBuffer>
+ spirv.GlobalVariable @var7 : !spirv.ptr<!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]), Block>, StorageBuffer>
// CHECK: !spirv.ptr<!spirv.struct<()>, StorageBuffer>
spirv.GlobalVariable @empty : !spirv.ptr<!spirv.struct<()>, StorageBuffer>
@@ -34,15 +34,17 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// CHECK: !spirv.ptr<!spirv.struct<test_id, (!spirv.array<128 x f32, stride=4> [0])>, Input>
spirv.GlobalVariable @id_var0 : !spirv.ptr<!spirv.struct<test_id, (!spirv.array<128 x f32, stride=4> [0])>, Input>
+ // CHECK: !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>), Block>, StorageBuffer>
+ spirv.GlobalVariable @recursive_simple : !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>), Block>, StorageBuffer>
- // CHECK: !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>)>, StorageBuffer>
- spirv.GlobalVariable @recursive_simple : !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>)>, StorageBuffer>
+ // CHECK: !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>), Block>, Uniform>), Block>, Uniform>
+ spirv.GlobalVariable @recursive_2 : !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>), Block>, Uniform>), Block>, Uniform>
- // CHECK: !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>)>, Uniform>)>, Uniform>
- spirv.GlobalVariable @recursive_2 : !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>)>, Uniform>)>, Uniform>
+ // CHECK: !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>), Block>, Uniform>), Block>, Uniform>
+ spirv.GlobalVariable @recursive_3 : !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>), Block>, Uniform>), Block>, Uniform>
- // CHECK: !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>)>, Uniform>)>, Uniform>
- spirv.GlobalVariable @recursive_3 : !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>)>, Uniform>)>, Uniform>
+ // CHECK: spirv.GlobalVariable @block : !spirv.ptr<!spirv.struct<vert, (vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>, Output>
+ spirv.GlobalVariable @block : !spirv.ptr<!spirv.struct<vert, (vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>, Output>
// CHECK: !spirv.ptr<!spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, Input>,
// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, Output>
diff --git a/mlir/test/Target/SPIRV/undef.mlir b/mlir/test/Target/SPIRV/undef.mlir
index b9044fe..8889b80 100644
--- a/mlir/test/Target/SPIRV/undef.mlir
+++ b/mlir/test/Target/SPIRV/undef.mlir
@@ -13,10 +13,10 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// CHECK: {{%.*}} = spirv.Undef : !spirv.array<4 x !spirv.array<4 x i32>>
%5 = spirv.Undef : !spirv.array<4x!spirv.array<4xi32>>
%6 = spirv.CompositeExtract %5[1 : i32, 2 : i32] : !spirv.array<4x!spirv.array<4xi32>>
- // CHECK: {{%.*}} = spirv.Undef : !spirv.ptr<!spirv.struct<(f32)>, StorageBuffer>
- %7 = spirv.Undef : !spirv.ptr<!spirv.struct<(f32)>, StorageBuffer>
+ // CHECK: {{%.*}} = spirv.Undef : !spirv.ptr<!spirv.struct<(f32), Block>, StorageBuffer>
+ %7 = spirv.Undef : !spirv.ptr<!spirv.struct<(f32), Block>, StorageBuffer>
%8 = spirv.Constant 0 : i32
- %9 = spirv.AccessChain %7[%8] : !spirv.ptr<!spirv.struct<(f32)>, StorageBuffer>, i32 -> !spirv.ptr<f32, StorageBuffer>
+ %9 = spirv.AccessChain %7[%8] : !spirv.ptr<!spirv.struct<(f32), Block>, StorageBuffer>, i32 -> !spirv.ptr<f32, StorageBuffer>
spirv.Return
}
}
diff --git a/mlir/test/mlir-tblgen/op-properties-predicates.td b/mlir/test/mlir-tblgen/op-properties-predicates.td
index 7cd24aa..af09ee7 100644
--- a/mlir/test/mlir-tblgen/op-properties-predicates.td
+++ b/mlir/test/mlir-tblgen/op-properties-predicates.td
@@ -70,6 +70,12 @@ def OpWithPredicates : NS_Op<"op_with_predicates"> {
// CHECK-NEXT: if (!(((!prop.has_value())) || ((::llvm::all_of((*(prop)), [](const int64_t& baseStore) -> bool { return [](int64_t baseIface) -> bool { return ((baseIface >= 0)); }(baseStore); })) && (!(((*(prop)).empty()))))))
// CHECK: failed to satisfy constraint: optional non-empty array of non-negative int64_
+// CHECK-LABEL: ::llvm::LogicalResult OpWithPredicatesAdaptor::verify
+// Note: comprehensive emission of verifiers is tested in verifyInvariantsImpl() below
+// CHECK: int64_t tblgen_scalar = this->getScalar();
+// CHECK: if (!((tblgen_scalar >= 0)))
+// CHECK: return emitError(loc, "'test.op_with_predicates' op ""property 'scalar' failed to satisfy constraint: non-negative int64_t");
+
// CHECK-LABEL: OpWithPredicates::verifyInvariantsImpl()
// Note: for test readability, we capture [[maybe_unused]] into the variable maybe_unused
// CHECK: [[maybe_unused:\[\[maybe_unused\]\]]] int64_t tblgen_scalar = this->getScalar();
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index f35cfa6..8ea4eb7 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1127,7 +1127,7 @@ static void genPropertyVerifier(
body << formatv(fetchProperty, varName, getterName,
prop.prop.getInterfaceType());
auto uniquedFn = staticVerifierEmitter.getPropConstraintFn(prop.prop);
- if (uniquedFn.has_value())
+ if (uniquedFn.has_value() && emitHelper.isEmittingForOp())
body << formatv(verifyPropertyUniqued, *uniquedFn, varName, prop.name);
else
body << formatv(
@@ -4764,6 +4764,7 @@ void OpOperandAdaptorEmitter::addVerification() {
FmtContext verifyCtx;
populateSubstitutions(emitHelper, verifyCtx);
+ genPropertyVerifier(emitHelper, verifyCtx, body, staticVerifierEmitter);
genAttributeVerifier(emitHelper, verifyCtx, body, staticVerifierEmitter,
useProperties);