Diffstat (limited to 'mlir')
-rw-r--r--  mlir/docs/DefiningDialects/AttributesAndTypes.md | 2
-rw-r--r--  mlir/examples/standalone/standalone-opt/CMakeLists.txt | 14
-rw-r--r--  mlir/examples/standalone/standalone-opt/standalone-opt.cpp | 2
-rw-r--r--  mlir/examples/toy/Ch5/CMakeLists.txt | 9
-rw-r--r--  mlir/examples/toy/Ch5/toyc.cpp | 1
-rw-r--r--  mlir/examples/toy/Ch6/CMakeLists.txt | 11
-rw-r--r--  mlir/examples/toy/Ch6/toyc.cpp | 1
-rw-r--r--  mlir/examples/toy/Ch7/CMakeLists.txt | 11
-rw-r--r--  mlir/examples/toy/Ch7/toyc.cpp | 1
-rw-r--r--  mlir/examples/transform-opt/CMakeLists.txt | 10
-rw-r--r--  mlir/examples/transform-opt/mlir-transform-opt.cpp | 1
-rw-r--r--  mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td | 1
-rw-r--r--  mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td | 1
-rw-r--r--  mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td | 63
-rw-r--r--  mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td | 109
-rw-r--r--  mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 3
-rw-r--r--  mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 18
-rw-r--r--  mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td | 1
-rw-r--r--  mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h | 7
-rw-r--r--  mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td | 70
-rw-r--r--  mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h | 40
-rw-r--r--  mlir/include/mlir/IR/Diagnostics.h | 4
-rw-r--r--  mlir/include/mlir/IR/StorageUniquerSupport.h | 2
-rw-r--r--  mlir/include/mlir/InitAllDialects.h | 193
-rw-r--r--  mlir/include/mlir/InitAllExtensions.h | 99
-rw-r--r--  mlir/include/mlir/InitAllPasses.h | 86
-rw-r--r--  mlir/include/mlir/Interfaces/CallInterfaces.td | 32
-rw-r--r--  mlir/include/mlir/Support/ToolUtilities.h | 15
-rw-r--r--  mlir/include/mlir/Target/LLVMIR/Dialect/All.h | 3
-rw-r--r--  mlir/include/mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h | 31
-rw-r--r--  mlir/include/mlir/Target/LLVMIR/ModuleImport.h | 27
-rw-r--r--  mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h | 23
-rw-r--r--  mlir/lib/CAPI/RegisterEverything/CMakeLists.txt | 11
-rw-r--r--  mlir/lib/CMakeLists.txt | 34
-rw-r--r--  mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp | 41
-rw-r--r--  mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp | 9
-rw-r--r--  mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp | 3
-rw-r--r--  mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 11
-rw-r--r--  mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 52
-rw-r--r--  mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp | 9
-rw-r--r--  mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp | 1
-rw-r--r--  mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp | 61
-rw-r--r--  mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 8
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 1
-rw-r--r--  mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp | 32
-rw-r--r--  mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 20
-rw-r--r--  mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 32
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp | 51
-rw-r--r--  mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 17
-rw-r--r--  mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 59
-rw-r--r--  mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 118
-rw-r--r--  mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp | 12
-rw-r--r--  mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp | 15
-rw-r--r--  mlir/lib/Dialect/Shard/IR/ShardOps.cpp | 1
-rw-r--r--  mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp | 7
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp | 24
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 92
-rw-r--r--  mlir/lib/IR/AsmPrinter.cpp | 7
-rw-r--r--  mlir/lib/IR/Diagnostics.cpp | 21
-rw-r--r--  mlir/lib/RegisterAllDialects.cpp | 207
-rw-r--r--  mlir/lib/RegisterAllExtensions.cpp | 115
-rw-r--r--  mlir/lib/RegisterAllPasses.cpp | 99
-rw-r--r--  mlir/lib/Support/ToolUtilities.cpp | 39
-rw-r--r--  mlir/lib/Target/LLVMIR/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp | 50
-rw-r--r--  mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp | 1
-rw-r--r--  mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp | 45
-rw-r--r--  mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt | 21
-rw-r--r--  mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp | 103
-rw-r--r--  mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp | 10
-rw-r--r--  mlir/lib/Target/LLVMIR/ModuleImport.cpp | 52
-rw-r--r--  mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 42
-rw-r--r--  mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp | 32
-rw-r--r--  mlir/lib/Target/SPIRV/Deserialization/Deserializer.h | 1
-rw-r--r--  mlir/lib/Target/SPIRV/Serialization/Serializer.cpp | 47
-rw-r--r--  mlir/lib/Tools/mlir-opt/MlirOptMain.cpp | 55
-rw-r--r--  mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp | 7
-rw-r--r--  mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir | 108
-rw-r--r--  mlir/test/Dialect/Linalg/canonicalize.mlir | 46
-rw-r--r--  mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir | 149
-rw-r--r--  mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir | 24
-rw-r--r--  mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir | 51
-rw-r--r--  mlir/test/Dialect/SPIRV/IR/types.mlir | 6
-rw-r--r--  mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir | 24
-rw-r--r--  mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir | 2
-rw-r--r--  mlir/test/Dialect/Tosa/invalid.mlir | 16
-rw-r--r--  mlir/test/Dialect/Tosa/level_check.mlir | 6
-rw-r--r--  mlir/test/Dialect/Vector/vector-sink.mlir | 139
-rw-r--r--  mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir | 139
-rw-r--r--  mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir | 242
-rw-r--r--  mlir/test/IR/diagnostic-nosplit.mlir | 13
-rw-r--r--  mlir/test/IR/top-level.mlir | 4
-rw-r--r--  mlir/test/Target/LLVMIR/Import/intrinsic.ll | 20
-rw-r--r--  mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir | 24
-rw-r--r--  mlir/test/Target/LLVMIR/xevm.mlir | 21
-rw-r--r--  mlir/test/Target/SPIRV/constant.mlir | 28
-rw-r--r--  mlir/test/Target/SPIRV/memory-ops.mlir | 20
-rw-r--r--  mlir/test/Target/SPIRV/struct.mlir | 38
-rw-r--r--  mlir/test/Target/SPIRV/undef.mlir | 6
-rw-r--r--  mlir/test/lib/Dialect/Test/TestAttrDefs.td | 1
-rw-r--r--  mlir/test/lib/Dialect/Test/TestAttributes.cpp | 10
-rw-r--r--  mlir/test/mlir-tblgen/attrdefs.td | 5
-rw-r--r--  mlir/tools/mlir-lsp-server/CMakeLists.txt | 21
-rw-r--r--  mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp | 1
-rw-r--r--  mlir/tools/mlir-opt/CMakeLists.txt | 19
-rw-r--r--  mlir/tools/mlir-pdll/mlir-pdll.cpp | 6
-rw-r--r--  mlir/tools/mlir-query/CMakeLists.txt | 4
-rw-r--r--  mlir/tools/mlir-reduce/CMakeLists.txt | 10
-rw-r--r--  mlir/tools/mlir-rewrite/CMakeLists.txt | 10
-rw-r--r--  mlir/tools/mlir-rewrite/mlir-rewrite.cpp | 1
-rw-r--r--  mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp | 2
-rw-r--r--  mlir/unittests/ExecutionEngine/CMakeLists.txt | 3
-rw-r--r--  mlir/unittests/IR/AttributeTest.cpp | 31
-rw-r--r--  mlir/unittests/Target/LLVM/CMakeLists.txt | 4
115 files changed, 2574 insertions(+), 1248 deletions(-)
diff --git a/mlir/docs/DefiningDialects/AttributesAndTypes.md b/mlir/docs/DefiningDialects/AttributesAndTypes.md
index 022bdad..b991863 100644
--- a/mlir/docs/DefiningDialects/AttributesAndTypes.md
+++ b/mlir/docs/DefiningDialects/AttributesAndTypes.md
@@ -136,7 +136,7 @@ def My_IntegerAttr : MyDialect_Attr<"Integer", "int"> {
/// Here we've defined two parameters: one is a "self" type parameter, and the
/// other is the integer value of the attribute. The self type parameter is
/// specially handled by the assembly format.
- let parameters = (ins AttributeSelfTypeParameter<"">:$type, "APInt":$value);
+ let parameters = (ins AttributeSelfTypeParameter<"">:$type, APIntParameter<"">:$value);
/// Here we've defined a custom builder for the type that removes the need to pass
/// in an MLIRContext instance, as it can be inferred from the `type`.
diff --git a/mlir/examples/standalone/standalone-opt/CMakeLists.txt b/mlir/examples/standalone/standalone-opt/CMakeLists.txt
index 27f8128..4b38de7 100644
--- a/mlir/examples/standalone/standalone-opt/CMakeLists.txt
+++ b/mlir/examples/standalone/standalone-opt/CMakeLists.txt
@@ -1,12 +1,10 @@
-get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
-get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
set(LIBS
- ${dialect_libs}
- ${conversion_libs}
- MLIRArithDialect
- MLIROptLib
- MLIRStandalone
- )
+ MLIRArithDialect
+ MLIROptLib
+ MLIRRegisterAllDialects
+ MLIRRegisterAllPasses
+ MLIRStandalone
+ )
add_llvm_executable(standalone-opt standalone-opt.cpp)
llvm_update_compile_flags(standalone-opt)
diff --git a/mlir/examples/standalone/standalone-opt/standalone-opt.cpp b/mlir/examples/standalone/standalone-opt/standalone-opt.cpp
index e39fa96..eebfcb7 100644
--- a/mlir/examples/standalone/standalone-opt/standalone-opt.cpp
+++ b/mlir/examples/standalone/standalone-opt/standalone-opt.cpp
@@ -6,6 +6,8 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
diff --git a/mlir/examples/toy/Ch5/CMakeLists.txt b/mlir/examples/toy/Ch5/CMakeLists.txt
index f4f0fec..454ca56 100644
--- a/mlir/examples/toy/Ch5/CMakeLists.txt
+++ b/mlir/examples/toy/Ch5/CMakeLists.txt
@@ -27,12 +27,8 @@ add_toy_chapter(toyc-ch5
include_directories(${CMAKE_CURRENT_BINARY_DIR})
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/)
-get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
-get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
target_link_libraries(toyc-ch5
PRIVATE
- ${dialect_libs}
- ${extension_libs}
MLIRAnalysis
MLIRCallInterfaces
MLIRCastInterfaces
@@ -40,6 +36,9 @@ target_link_libraries(toyc-ch5
MLIRIR
MLIRParser
MLIRPass
+ MLIRRegisterAllDialects
+ MLIRRegisterAllExtensions
MLIRSideEffectInterfaces
MLIRSupport
- MLIRTransforms)
+ MLIRTransforms
+ )
diff --git a/mlir/examples/toy/Ch5/toyc.cpp b/mlir/examples/toy/Ch5/toyc.cpp
index 6a0c631..afdf782 100644
--- a/mlir/examples/toy/Ch5/toyc.cpp
+++ b/mlir/examples/toy/Ch5/toyc.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Diagnostics.h"
#include "toy/AST.h"
#include "toy/Dialect.h"
diff --git a/mlir/examples/toy/Ch6/CMakeLists.txt b/mlir/examples/toy/Ch6/CMakeLists.txt
index 283b895..73df602 100644
--- a/mlir/examples/toy/Ch6/CMakeLists.txt
+++ b/mlir/examples/toy/Ch6/CMakeLists.txt
@@ -37,14 +37,8 @@ add_toy_chapter(toyc-ch6
include_directories(${CMAKE_CURRENT_BINARY_DIR})
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/)
-get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
-get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
-get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
target_link_libraries(toyc-ch6
PRIVATE
- ${dialect_libs}
- ${conversion_libs}
- ${extension_libs}
MLIRAnalysis
MLIRBuiltinToLLVMIRTranslation
MLIRCallInterfaces
@@ -58,8 +52,11 @@ target_link_libraries(toyc-ch6
MLIRMemRefDialect
MLIRParser
MLIRPass
+ MLIRRegisterAllDialects
+ MLIRRegisterAllExtensions
+ MLIRRegisterAllPasses
MLIRSideEffectInterfaces
MLIRSupport
MLIRTargetLLVMIRExport
MLIRTransforms
- )
+ )
diff --git a/mlir/examples/toy/Ch6/toyc.cpp b/mlir/examples/toy/Ch6/toyc.cpp
index dccab91..4a5e109 100644
--- a/mlir/examples/toy/Ch6/toyc.cpp
+++ b/mlir/examples/toy/Ch6/toyc.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
#include "toy/AST.h"
diff --git a/mlir/examples/toy/Ch7/CMakeLists.txt b/mlir/examples/toy/Ch7/CMakeLists.txt
index 362ab51..a489ae5 100644
--- a/mlir/examples/toy/Ch7/CMakeLists.txt
+++ b/mlir/examples/toy/Ch7/CMakeLists.txt
@@ -36,14 +36,8 @@ add_toy_chapter(toyc-ch7
include_directories(${CMAKE_CURRENT_BINARY_DIR})
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/)
-get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
-get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
-get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
target_link_libraries(toyc-ch7
PRIVATE
- ${dialect_libs}
- ${conversion_libs}
- ${extension_libs}
MLIRAnalysis
MLIRBuiltinToLLVMIRTranslation
MLIRCallInterfaces
@@ -56,7 +50,10 @@ target_link_libraries(toyc-ch7
MLIRMemRefDialect
MLIRParser
MLIRPass
+ MLIRRegisterAllDialects
+ MLIRRegisterAllExtensions
+ MLIRRegisterAllPasses
MLIRSideEffectInterfaces
MLIRTargetLLVMIRExport
MLIRTransforms
- )
+ )
diff --git a/mlir/examples/toy/Ch7/toyc.cpp b/mlir/examples/toy/Ch7/toyc.cpp
index dd86265..32208ecca 100644
--- a/mlir/examples/toy/Ch7/toyc.cpp
+++ b/mlir/examples/toy/Ch7/toyc.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
#include "toy/AST.h"
diff --git a/mlir/examples/transform-opt/CMakeLists.txt b/mlir/examples/transform-opt/CMakeLists.txt
index 8e23555..07d58f6 100644
--- a/mlir/examples/transform-opt/CMakeLists.txt
+++ b/mlir/examples/transform-opt/CMakeLists.txt
@@ -1,18 +1,14 @@
-get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
-get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
-get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
-
set(LIBS
MLIRAnalysis
MLIRIR
MLIRParser
+ MLIRRegisterAllDialects
+ MLIRRegisterAllExtensions
+ MLIRRegisterAllPasses
MLIRSupport
MLIRTransformDialect
MLIRTransformDialectTransforms
MLIRTransforms
- ${dialect_libs}
- ${conversion_libs}
- ${extension_libs}
)
add_mlir_tool(mlir-transform-opt
diff --git a/mlir/examples/transform-opt/mlir-transform-opt.cpp b/mlir/examples/transform-opt/mlir-transform-opt.cpp
index 1a29913..4b12e76 100644
--- a/mlir/examples/transform-opt/mlir-transform-opt.cpp
+++ b/mlir/examples/transform-opt/mlir-transform-opt.cpp
@@ -22,6 +22,7 @@
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
#include <cstdlib>
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
index e81db32..06fb851 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -71,6 +71,7 @@ class ArmSME_IntrOp<string mnemonic,
/*bit requiresAccessGroup=*/0,
/*bit requiresAliasAnalysis=*/0,
/*bit requiresFastmath=*/0,
+ /*bit requiresArgAndResultAttrs=*/0,
/*bit requiresOpBundles=*/0,
/*list<int> immArgPositions=*/immArgPositions,
/*list<string> immArgAttrNames=*/immArgAttrNames>;
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index 8988df6..d055bb4 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -92,6 +92,7 @@ class ArmSVE_IntrOp<string mnemonic,
/*bit requiresAccessGroup=*/0,
/*bit requiresAliasAnalysis=*/0,
/*bit requiresFastmath=*/0,
+ /*bit requiresArgAndResultAttrs=*/0,
/*bit requiresOpBundles=*/0,
/*list<int> immArgPositions=*/immArgPositions,
/*list<string> immArgAttrNames=*/immArgAttrNames>;
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index 8c6f1ee..d38298f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -140,8 +140,8 @@ def LLVM_Log2Op : LLVM_UnaryIntrOpF<"log2">;
def LLVM_LogOp : LLVM_UnaryIntrOpF<"log">;
def LLVM_Prefetch : LLVM_ZeroResultIntrOp<"prefetch", [0],
/*traits=*/[], /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
- /*requiresOpBundles=*/0, /*immArgPositions=*/[1, 2, 3],
- /*immArgAttrNames=*/["rw", "hint", "cache"]
+ /*requiresArgAndResultAttrs=*/0, /*requiresOpBundles=*/0,
+ /*immArgPositions=*/[1, 2, 3], /*immArgAttrNames=*/["rw", "hint", "cache"]
> {
let arguments = (ins LLVM_AnyPointer:$addr, I32Attr:$rw, I32Attr:$hint, I32Attr:$cache);
}
@@ -200,13 +200,13 @@ class LLVM_MemcpyIntrOpBase<string name> :
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>],
/*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1,
- /*requiresOpBundles=*/0, /*immArgPositions=*/[3],
- /*immArgAttrNames=*/["isVolatile"]> {
+ /*requiresArgAndResultAttrs=*/1, /*requiresOpBundles=*/0,
+ /*immArgPositions=*/[3], /*immArgAttrNames=*/["isVolatile"]> {
dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst,
Arg<LLVM_AnyPointer,"",[MemRead]>:$src,
AnySignlessInteger:$len, I1Attr:$isVolatile);
- // Append the alias attributes defined by LLVM_IntrOpBase.
- let arguments = !con(args, aliasAttrs);
+ // Append the arguments defined by LLVM_IntrOpBase.
+ let arguments = !con(args, baseArgs);
let builders = [
OpBuilder<(ins "Value":$dst, "Value":$src, "Value":$len,
"bool":$isVolatile), [{
@@ -217,7 +217,8 @@ class LLVM_MemcpyIntrOpBase<string name> :
"IntegerAttr":$isVolatile), [{
build($_builder, $_state, dst, src, len, isVolatile,
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
- /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
+ /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr,
+ /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr);
}]>
];
}
@@ -231,13 +232,13 @@ def LLVM_MemcpyInlineOp :
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>],
/*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1,
- /*requiresOpBundles=*/0, /*immArgPositions=*/[2, 3],
- /*immArgAttrNames=*/["len", "isVolatile"]> {
+ /*requiresArgAndResultAttrs=*/1, /*requiresOpBundles=*/0,
+ /*immArgPositions=*/[2, 3], /*immArgAttrNames=*/["len", "isVolatile"]> {
dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst,
Arg<LLVM_AnyPointer,"",[MemRead]>:$src,
APIntAttr:$len, I1Attr:$isVolatile);
- // Append the alias attributes defined by LLVM_IntrOpBase.
- let arguments = !con(args, aliasAttrs);
+ // Append the arguments defined by LLVM_IntrOpBase.
+ let arguments = !con(args, baseArgs);
let builders = [
OpBuilder<(ins "Value":$dst, "Value":$src, "IntegerAttr":$len,
"bool":$isVolatile), [{
@@ -248,7 +249,8 @@ def LLVM_MemcpyInlineOp :
"IntegerAttr":$isVolatile), [{
build($_builder, $_state, dst, src, len, isVolatile,
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
- /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
+ /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr,
+ /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr);
}]>
];
}
@@ -258,12 +260,12 @@ def LLVM_MemsetOp : LLVM_ZeroResultIntrOp<"memset", [0, 2],
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>],
/*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1,
- /*requiresOpBundles=*/0, /*immArgPositions=*/[3],
- /*immArgAttrNames=*/["isVolatile"]> {
+ /*requiresArgAndResultAttrs=*/1, /*requiresOpBundles=*/0,
+ /*immArgPositions=*/[3], /*immArgAttrNames=*/["isVolatile"]> {
dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst,
I8:$val, AnySignlessInteger:$len, I1Attr:$isVolatile);
- // Append the alias attributes defined by LLVM_IntrOpBase.
- let arguments = !con(args, aliasAttrs);
+ // Append the arguments defined by LLVM_IntrOpBase.
+ let arguments = !con(args, baseArgs);
let builders = [
OpBuilder<(ins "Value":$dst, "Value":$val, "Value":$len,
"bool":$isVolatile), [{
@@ -274,7 +276,8 @@ def LLVM_MemsetOp : LLVM_ZeroResultIntrOp<"memset", [0, 2],
"IntegerAttr":$isVolatile), [{
build($_builder, $_state, dst, val, len, isVolatile,
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
- /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
+ /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr,
+ /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr);
}]>
];
}
@@ -284,12 +287,12 @@ def LLVM_MemsetInlineOp : LLVM_ZeroResultIntrOp<"memset.inline", [0, 2],
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>],
/*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1,
- /*requiresOpBundles=*/0, /*immArgPositions=*/[2, 3],
- /*immArgAttrNames=*/["len", "isVolatile"]> {
+ /*requiresArgAndResultAttrs=*/1, /*requiresOpBundles=*/0,
+ /*immArgPositions=*/[2, 3], /*immArgAttrNames=*/["len", "isVolatile"]> {
dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst,
I8:$val, APIntAttr:$len, I1Attr:$isVolatile);
- // Append the alias attributes defined by LLVM_IntrOpBase.
- let arguments = !con(args, aliasAttrs);
+ // Append the arguments defined by LLVM_IntrOpBase.
+ let arguments = !con(args, baseArgs);
let builders = [
OpBuilder<(ins "Value":$dst, "Value":$val, "IntegerAttr":$len,
"bool":$isVolatile), [{
@@ -300,7 +303,8 @@ def LLVM_MemsetInlineOp : LLVM_ZeroResultIntrOp<"memset.inline", [0, 2],
"IntegerAttr":$isVolatile), [{
build($_builder, $_state, dst, val, len, isVolatile,
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
- /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
+ /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr,
+ /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr);
}]>
];
}
@@ -349,8 +353,8 @@ def LLVM_PtrMaskOp
class LLVM_LifetimeBaseOp<string opName> : LLVM_ZeroResultIntrOp<opName, [1],
[DeclareOpInterfaceMethods<PromotableOpInterface>],
/*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
- /*requiresOpBundles=*/0, /*immArgPositions=*/[0],
- /*immArgAttrNames=*/["size"]> {
+ /*requiresArgAndResultAttrs=*/0, /*requiresOpBundles=*/0,
+ /*immArgPositions=*/[0], /*immArgAttrNames=*/["size"]> {
let arguments = (ins I64Attr:$size, LLVM_AnyPointer:$ptr);
let assemblyFormat = "$size `,` $ptr attr-dict `:` qualified(type($ptr))";
}
@@ -370,8 +374,8 @@ def LLVM_InvariantStartOp : LLVM_OneResultIntrOp<"invariant.start", [], [1],
def LLVM_InvariantEndOp : LLVM_ZeroResultIntrOp<"invariant.end", [2],
[DeclareOpInterfaceMethods<PromotableOpInterface>],
/*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
- /*requiresOpBundles=*/0, /*immArgPositions=*/[1],
- /*immArgAttrNames=*/["size"]> {
+ /*requiresArgAndResultAttrs=*/0, /*requiresOpBundles=*/0,
+ /*immArgPositions=*/[1], /*immArgAttrNames=*/["size"]> {
let arguments = (ins LLVM_DefaultPointer:$start,
I64Attr:$size,
LLVM_AnyPointer:$ptr);
@@ -542,9 +546,10 @@ def LLVM_AssumeOp
: LLVM_ZeroResultIntrOp<"assume", /*overloadedOperands=*/[], /*traits=*/[],
/*requiresAccessGroup=*/0,
/*requiresAliasAnalysis=*/0,
+ /*requiresArgAndResultAttrs=*/0,
/*requiresOpBundles=*/1> {
dag args = (ins I1:$cond);
- let arguments = !con(args, opBundleArgs);
+ let arguments = !con(args, baseArgs);
let assemblyFormat = [{
$cond
@@ -1126,8 +1131,8 @@ def LLVM_DebugTrap : LLVM_ZeroResultIntrOp<"debugtrap">;
def LLVM_UBSanTrap : LLVM_ZeroResultIntrOp<"ubsantrap",
/*overloadedOperands=*/[], /*traits=*/[],
/*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
- /*requiresOpBundles=*/0, /*immArgPositions=*/[0],
- /*immArgAttrNames=*/["failureKind"]> {
+ /*requiresArgAndResultAttrs=*/0, /*requiresOpBundles=*/0,
+ /*immArgPositions=*/[0], /*immArgAttrNames=*/["failureKind"]> {
let arguments = (ins I8Attr:$failureKind);
}
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index e845ea9f..a8d7cf2 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -18,6 +18,7 @@ include "mlir/Dialect/LLVMIR/LLVMAttrDefs.td"
include "mlir/Dialect/LLVMIR/LLVMInterfaces.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/CallInterfaces.td"
//===----------------------------------------------------------------------===//
// LLVM dialect type constraints.
@@ -286,22 +287,26 @@ class LLVM_MemAccessOpBase<string mnemonic, list<Trait> traits = []> :
// intrinsic and "enumName" contains the name of the intrinsic as it appears in
// `llvm::Intrinsic` enum; one usually wants these to be related. Additionally,
// the base class also defines the "mlirBuilder" field to support the inverse
-// translation starting from an LLVM IR intrinsic. The "requiresAccessGroup",
-// "requiresAliasAnalysis", and "requiresFastmath" flags specify which
-// interfaces the intrinsic implements. If the corresponding flags are set, the
-// "aliasAttrs" list contains the arguments required by the access group and
-// alias analysis interfaces. Derived intrinsics should append the "aliasAttrs"
-// to their argument list if they set one of the flags. LLVM `immargs` can be
-// represented as MLIR attributes by providing both the `immArgPositions` and
-// `immArgAttrNames` lists. These two lists should have equal length, with
-// `immArgPositions` containing the argument positions on the LLVM IR attribute
-// that are `immargs`, and `immArgAttrNames` mapping these to corresponding
-// MLIR attributes.
+// translation starting from an LLVM IR intrinsic.
+//
+// The flags "requiresAccessGroup", "requiresAliasAnalysis",
+// "requiresFastmath", and "requiresArgAndResultAttrs" indicate which
+// interfaces the intrinsic implements. When a flag is set, the "baseArgs"
+// list includes the arguments required by the corresponding interface.
+// Derived intrinsics must append "baseArgs" to their argument list if they
+// enable any of these flags.
+//
+// LLVM `immargs` can be represented as MLIR attributes by providing both
+// the `immArgPositions` and `immArgAttrNames` lists. These two lists should
+// have equal length, with `immArgPositions` containing the argument
+// positions on the LLVM IR intrinsic that are `immargs`, and
+// `immArgAttrNames` mapping these to corresponding MLIR attributes.
class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
list<int> overloadedResults, list<int> overloadedOperands,
list<Trait> traits, int numResults,
bit requiresAccessGroup = 0, bit requiresAliasAnalysis = 0,
- bit requiresFastmath = 0, bit requiresOpBundles = 0,
+ bit requiresFastmath = 0, bit requiresArgAndResultAttrs = 0,
+ bit requiresOpBundles = 0,
list<int> immArgPositions = [],
list<string> immArgAttrNames = []>
: LLVM_OpBase<dialect, opName, !listconcat(
@@ -311,10 +316,12 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
[DeclareOpInterfaceMethods<AliasAnalysisOpInterface>], []),
!if(!gt(requiresFastmath, 0),
[DeclareOpInterfaceMethods<FastmathFlagsInterface>], []),
+ !if(!gt(requiresArgAndResultAttrs, 0),
+ [DeclareOpInterfaceMethods<ArgAndResultAttrsOpInterface>], []),
traits)>,
LLVM_MemOpPatterns,
Results<!if(!gt(numResults, 0), (outs LLVM_Type:$res), (outs))> {
- dag aliasAttrs = !con(
+ dag baseArgs = !con(
!if(!gt(requiresAccessGroup, 0),
(ins OptionalAttr<LLVM_AccessGroupArrayAttr>:$access_groups),
(ins )),
@@ -322,13 +329,17 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
(ins OptionalAttr<LLVM_AliasScopeArrayAttr>:$alias_scopes,
OptionalAttr<LLVM_AliasScopeArrayAttr>:$noalias_scopes,
OptionalAttr<LLVM_TBAATagArrayAttr>:$tbaa),
+ (ins )),
+ !if(!gt(requiresArgAndResultAttrs, 0),
+ (ins OptionalAttr<DictArrayAttr>:$arg_attrs,
+ OptionalAttr<DictArrayAttr>:$res_attrs),
+ (ins )),
+ !if(!gt(requiresOpBundles, 0),
+ (ins VariadicOfVariadic<LLVM_Type,
+ "op_bundle_sizes">:$op_bundle_operands,
+ DenseI32ArrayAttr:$op_bundle_sizes,
+ OptionalAttr<ArrayAttr>:$op_bundle_tags),
(ins )));
- dag opBundleArgs = !if(!gt(requiresOpBundles, 0),
- (ins VariadicOfVariadic<LLVM_Type,
- "op_bundle_sizes">:$op_bundle_operands,
- DenseI32ArrayAttr:$op_bundle_sizes,
- OptionalAttr<ArrayAttr>:$op_bundle_tags),
- (ins ));
string llvmEnumName = enumName;
string overloadedResultsCpp = "{" # !interleave(overloadedResults, ", ") # "}";
string overloadedOperandsCpp = "{" # !interleave(overloadedOperands, ", ") # "}";
@@ -342,23 +353,35 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
immArgPositionsCpp, immArgAttrNamesCpp], ",") # [{);
(void) inst;
}];
+ string baseLlvmBuilderArgAndResultAttrs = [{
+ if (failed(moduleTranslation.convertArgAndResultAttrs(
+ op,
+ inst,
+ }] # immArgPositionsCpp # [{))) {
+ return failure();
+ }
+ }];
string baseLlvmBuilderCoda = !if(!gt(numResults, 0), "$res = inst;", "");
- let llvmBuilder = baseLlvmBuilder # !if(!gt(requiresAccessGroup, 0), setAccessGroupsMetadataCode, "")
- # !if(!gt(requiresAliasAnalysis, 0), setAliasAnalysisMetadataCode, "")
- # baseLlvmBuilderCoda;
+ let llvmBuilder = baseLlvmBuilder
+ # !if(!gt(requiresAccessGroup, 0),
+ setAccessGroupsMetadataCode, "")
+ # !if(!gt(requiresAliasAnalysis, 0),
+ setAliasAnalysisMetadataCode, "")
+ # !if(!gt(requiresArgAndResultAttrs, 0),
+ baseLlvmBuilderArgAndResultAttrs, "")
+ # baseLlvmBuilderCoda;
string baseMlirBuilder = [{
SmallVector<Value> mlirOperands;
SmallVector<NamedAttribute> mlirAttrs;
if (failed(moduleImport.convertIntrinsicArguments(
- llvmOperands,
- llvmOpBundles,
- }] # !if(!gt(requiresOpBundles, 0), "true", "false") # [{,
- }] # immArgPositionsCpp # [{,
- }] # immArgAttrNamesCpp # [{,
- mlirOperands,
- mlirAttrs))
- ) {
+ llvmOperands,
+ llvmOpBundles,
+ }] # !if(!gt(requiresOpBundles, 0), "true", "false") # [{,
+ }] # immArgPositionsCpp # [{,
+ }] # immArgAttrNamesCpp # [{,
+ mlirOperands,
+ mlirAttrs))) {
return failure();
}
SmallVector<Type> resultTypes =
@@ -366,9 +389,16 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
auto op = $_qualCppClassName::create($_builder,
$_location, resultTypes, mlirOperands, mlirAttrs);
}];
+ string baseMlirBuilderArgAndResultAttrs = [{
+ moduleImport.convertArgAndResultAttrs(
+ inst, op, }] # immArgPositionsCpp # [{);
+ }];
string baseMlirBuilderCoda = !if(!gt(numResults, 0), "$res = op;", "$_op = op;");
- let mlirBuilder = baseMlirBuilder # !if(!gt(requiresFastmath, 0),
+ let mlirBuilder = baseMlirBuilder
+ # !if(!gt(requiresFastmath, 0),
"moduleImport.setFastmathFlagsAttr(inst, op);", "")
+ # !if(!gt(requiresArgAndResultAttrs, 0),
+ baseMlirBuilderArgAndResultAttrs, "")
# baseMlirBuilderCoda;
// Code for handling a `range` attribute that holds the constant range of the
@@ -399,14 +429,14 @@ class LLVM_IntrOp<string mnem, list<int> overloadedResults,
list<int> overloadedOperands, list<Trait> traits,
int numResults, bit requiresAccessGroup = 0,
bit requiresAliasAnalysis = 0, bit requiresFastmath = 0,
- bit requiresOpBundles = 0,
+ bit requiresArgAndResultAttrs = 0, bit requiresOpBundles = 0,
list<int> immArgPositions = [],
list<string> immArgAttrNames = []>
: LLVM_IntrOpBase<LLVM_Dialect, "intr." # mnem, !subst(".", "_", mnem),
overloadedResults, overloadedOperands, traits,
numResults, requiresAccessGroup, requiresAliasAnalysis,
- requiresFastmath, requiresOpBundles, immArgPositions,
- immArgAttrNames>;
+ requiresFastmath, requiresArgAndResultAttrs,
+ requiresOpBundles, immArgPositions, immArgAttrNames>;
// Base class for LLVM intrinsic operations returning no results. Places the
// intrinsic into the LLVM dialect and prefixes its name with "intr.".
@@ -426,13 +456,14 @@ class LLVM_ZeroResultIntrOp<string mnem, list<int> overloadedOperands = [],
list<Trait> traits = [],
bit requiresAccessGroup = 0,
bit requiresAliasAnalysis = 0,
+ bit requiresArgAndResultAttrs = 0,
bit requiresOpBundles = 0,
list<int> immArgPositions = [],
list<string> immArgAttrNames = []>
: LLVM_IntrOp<mnem, [], overloadedOperands, traits, /*numResults=*/0,
requiresAccessGroup, requiresAliasAnalysis,
- /*requiresFastMath=*/0, requiresOpBundles, immArgPositions,
- immArgAttrNames>;
+ /*requiresFastMath=*/0, requiresArgAndResultAttrs,
+ requiresOpBundles, immArgPositions, immArgAttrNames>;
// Base class for LLVM intrinsic operations returning one result. Places the
// intrinsic into the LLVM dialect and prefixes its name with "intr.". This is
@@ -448,7 +479,8 @@ class LLVM_OneResultIntrOp<string mnem, list<int> overloadedResults = [],
list<string> immArgAttrNames = []>
: LLVM_IntrOp<mnem, overloadedResults, overloadedOperands, traits, 1,
/*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
- requiresFastmath, /*requiresOpBundles=*/0, immArgPositions,
+ requiresFastmath, /*requiresArgAndResultAttrs=*/0,
+ /*requiresOpBundles=*/0, immArgPositions,
immArgAttrNames>;
// Base class for LLVM intrinsic operations returning two results. Places the
@@ -465,7 +497,8 @@ class LLVM_TwoResultIntrOp<string mnem, list<int> overloadedResults = [],
list<string> immArgAttrNames = []>
: LLVM_IntrOp<mnem, overloadedResults, overloadedOperands, traits, 2,
/*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
- requiresFastmath, /*requiresOpBundles=*/0, immArgPositions,
+ requiresFastmath, /*requiresArgAndResultAttrs=*/0,
+ /*requiresOpBundles=*/0, immArgPositions,
immArgAttrNames>;
def LLVM_OneResultOpBuilder :
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 51004f5..3f27f6d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -2405,7 +2405,8 @@ def LLVM_InlineAsmOp : LLVM_Op<"inline_asm", [DeclareOpInterfaceMethods<MemoryEf
def LLVM_CallIntrinsicOp
: LLVM_Op<"call_intrinsic",
- [AttrSizedOperandSegments,
+ [ArgAndResultAttrsOpInterface,
+ AttrSizedOperandSegments,
DeclareOpInterfaceMethods<FastmathFlagsInterface>]> {
let summary = "Call to an LLVM intrinsic function.";
let description = [{
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 04a0b58..a2354e2 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -98,7 +98,7 @@ class ROCDL_IntrOp<string mnemonic, list<int> overloadedResults,
LLVM_IntrOpBase<ROCDL_Dialect, mnemonic,
"amdgcn_" # !subst(".", "_", mnemonic), overloadedResults,
overloadedOperands, traits, numResults, requiresAccessGroup,
- requiresAliasAnalysis, 0, 0, immArgPositions, immArgAttrNames>;
+ requiresAliasAnalysis, 0, 0, 0, immArgPositions, immArgAttrNames>;
// Subclass to save typing and ease readability when there aren't overloaded
// operands or memory accesses.
@@ -482,7 +482,7 @@ def ROCDLBufferLDS : LLVM_PointerInAddressSpace<3>;
class ROCDL_LDS_Read_Tr_IntrOp<string mnemonic> :
ROCDL_IntrOp<mnemonic, [1], [], [], 1, 0, 1> {
dag args = (ins Arg<ROCDLBufferLDS, "", [MemRead]>:$ptr);
- let arguments = !con(args, aliasAttrs);
+ let arguments = !con(args, baseArgs);
let assemblyFormat = "$ptr attr-dict `:` type($ptr) `->` type($res)";
let extraClassDefinition = [{
::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() {
@@ -507,7 +507,7 @@ def ROCDL_LoadToLDSOp :
I32Attr:$size,
I32Attr:$offset,
I32Attr:$aux);
- let arguments = !con(args, aliasAttrs);
+ let arguments = !con(args, baseArgs);
let assemblyFormat = [{
$globalPtr `,` $ldsPtr `,` $size `,` $offset `,` $aux
attr-dict `:` type($globalPtr)
@@ -526,7 +526,7 @@ def ROCDL_GlobalLoadLDSOp :
I32Attr:$size,
I32Attr:$offset,
I32Attr:$aux);
- let arguments = !con(args, aliasAttrs);
+ let arguments = !con(args, baseArgs);
let assemblyFormat = [{
$globalPtr `,` $ldsPtr `,` $size `,` $offset `,` $aux
attr-dict
@@ -561,7 +561,7 @@ def ROCDL_RawPtrBufferLoadOp :
I32:$offset,
I32:$soffset,
I32:$aux);
- let arguments = !con(args, aliasAttrs);
+ let arguments = !con(args, baseArgs);
let assemblyFormat = "operands attr-dict `:` type($res)";
let extraClassDefinition = [{
::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() {
@@ -579,7 +579,7 @@ def ROCDL_RawPtrBufferLoadLdsOp :
I32:$soffset,
I32:$offset,
I32:$aux);
- let arguments = !con(args, aliasAttrs);
+ let arguments = !con(args, baseArgs);
let assemblyFormat = "operands attr-dict";
let extraClassDefinition = [{
::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() {
@@ -595,7 +595,7 @@ def ROCDL_RawPtrBufferStoreOp :
I32:$offset,
I32:$soffset,
I32:$aux);
- let arguments = !con(args, aliasAttrs);
+ let arguments = !con(args, baseArgs);
let assemblyFormat = "operands attr-dict `:` type($vdata)";
let extraClassDefinition = [{
::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() {
@@ -614,7 +614,7 @@ def ROCDL_RawPtrBufferAtomicCmpSwap :
I32:$offset,
I32:$soffset,
I32:$aux);
- let arguments = !con(args, aliasAttrs);
+ let arguments = !con(args, baseArgs);
let assemblyFormat = "operands attr-dict `:` type($res)";
let extraClassDefinition = [{
::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() {
@@ -630,7 +630,7 @@ class ROCDL_RawPtrBufferAtomicNoRet<string op> :
I32:$offset,
I32:$soffset,
I32:$aux);
- let arguments = !con(args, aliasAttrs);
+ let arguments = !con(args, baseArgs);
let assemblyFormat = "operands attr-dict `:` type($vdata)";
let extraClassDefinition = [{
::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() {
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 8d45c40..61ce23f 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1191,6 +1191,7 @@ def PadTilingInterfaceOp : Op<Transform_Dialect, "structured.pad_tiling_interfac
iteration domain induces a padding of the operands that is consistent
across the op semantics and, unlike for simple elementwise ops, may not be
trivially deducible or specifiable on operands only (e.g. convolutions).
+ Currently, only a limited set of projected permutation maps are supported.
The specification of `padding_sizes` follows that of `tile_sizes` during
tiling: the value "0" on a particular iterator encodes "no padding". Like in
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index e625eef..d4ffe0a 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -611,6 +611,13 @@ LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
/// affine.apply operations.
/// The `indexingMap` + `indexingSizes` encoding suits StructuredOps and
/// provides a gentle portability path for Linalg-like ops with affine maps.
+/// The padded shape is computed by evaluating the maximum accessed index per
+/// dimension, which may involve multiplying by constant factors derived from
+/// the affine indexing expressions. Currently, only a limited set of projected
/// permutation indexing maps are supported, such as
+/// - affine_map<(d0, d1, d2) -> (d0, d1)>
+/// - affine_map<(d0, d1, d2) -> (d0, d1 + d2)>
+/// - affine_map<(d0, d1) -> (d0 * 3 + d1)>
/// In the future, more general interfaces can be devised to encode similar
/// shape evolutions and map between an op and its operands.
SmallVector<OpFoldResult>
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 96b9adc..e1e99c3 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -134,6 +134,24 @@ def OpenACC_VariableTypeCategory : I32BitEnumAttr<
let printBitEnumPrimaryGroups = 1;
}
+// These are parallelism determination modes for `acc loop`.
+// In the enum names, we use the "loop_" prefix because "auto" is a
+// language keyword; for consistency, all other cases use the same prefix.
+def OpenACC_LoopSeq : I32EnumAttrCase<"loop_seq", 0>;
+def OpenACC_LoopAuto : I32EnumAttrCase<"loop_auto", 1>;
+def OpenACC_LoopIndependent : I32EnumAttrCase<"loop_independent", 2>;
+
+def OpenACC_LoopParMode : I32EnumAttr<
+ "LoopParMode",
+ "Encodes the options for loop parallelism determination mode",
+ [
+ OpenACC_LoopAuto, OpenACC_LoopIndependent,
+ OpenACC_LoopSeq]> {
+ let cppNamespace = "::mlir::acc";
+ let genSpecializedAttr = 0;
+}
+
// Type used in operation below.
def IntOrIndex : AnyTypeOf<[AnyInteger, Index]>;
@@ -2373,6 +2391,11 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
// Return whether this LoopOp has a gang, worker, or vector applying to the
// 'default'/None device-type.
bool hasDefaultGangWorkerVector();
+
+ // Returns the parallelism mode for the requested device type. This first
+ // checks whether the mode is set for the requested device_type; if not,
+ // it returns the mode set without a device_type.
+ LoopParMode getDefaultOrDeviceTypeParallelism(DeviceType);
}];
let hasCustomAssemblyFormat = 1;
@@ -2404,6 +2427,53 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
}];
let hasVerifier = 1;
+
+ let builders = [
+ OpBuilder<(ins "::mlir::ValueRange":$lowerbounds,
+ "::mlir::ValueRange":$upperbounds,
+ "::mlir::ValueRange":$steps,
+ "LoopParMode":$parMode), [{
+ auto deviceNoneAttr = mlir::acc::DeviceTypeAttr::get(
+ $_builder.getContext(), mlir::acc::DeviceType::None);
+ auto arrOfDeviceNone = mlir::ArrayAttr::get(
+ $_builder.getContext(), deviceNoneAttr);
+ build($_builder, $_state,
+ /*results=*/{},
+ /*lowerbound=*/lowerbounds,
+ /*upperbound=*/upperbounds,
+ /*step=*/steps,
+ /*inclusiveUpperbound=*/nullptr,
+ /*collapse=*/nullptr,
+ /*collapseDeviceType=*/nullptr,
+ /*gangOperands=*/{},
+ /*gangOperandsArgType=*/nullptr,
+ /*gangOperandsSegments=*/nullptr,
+ /*gangOperandsDeviceType=*/nullptr,
+ /*workerNumOperands=*/{},
+ /*workerNumOperandsDeviceType=*/nullptr,
+ /*vectorOperands=*/{},
+ /*vectorOperandsDeviceType=*/nullptr,
+ /*seq=*/parMode == LoopParMode::loop_seq ?
+ arrOfDeviceNone : nullptr,
+ /*independent=*/parMode == LoopParMode::loop_independent ?
+ arrOfDeviceNone : nullptr,
+ /*auto_=*/parMode == LoopParMode::loop_auto ?
+ arrOfDeviceNone : nullptr,
+ /*gang=*/nullptr,
+ /*worker=*/nullptr,
+ /*vector=*/nullptr,
+ /*tileOperands=*/{},
+ /*tileOperandsSegments=*/nullptr,
+ /*tileOperandsDeviceType=*/nullptr,
+ /*cacheOperands=*/{},
+ /*privateOperands=*/{},
+ /*privatizationRecipes=*/nullptr,
+ /*reductionOperands=*/{},
+ /*reductionRecipes=*/nullptr,
+ /*combined=*/nullptr);
+ }]
+ >
+ ];
}
// Yield operation for the acc.loop and acc.parallel operations.
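A short sketch of how the new builder and query could be used from C++ (assumes an OpBuilder `builder`, a Location `loc`, and index-typed ValueRanges `lbs`, `ubs`, `steps`; all names are illustrative):

  // Build an acc.loop marked `independent` for the default (no device_type)
  // case via the convenience builder added above.
  auto loop = mlir::acc::LoopOp::create(
      builder, loc, lbs, ubs, steps,
      mlir::acc::LoopParMode::loop_independent);
  // Resolves to the device_type-specific mode when one is set, and falls
  // back to the mode set without a device_type otherwise.
  mlir::acc::LoopParMode mode =
      loop.getDefaultOrDeviceTypeParallelism(mlir::acc::DeviceType::Nvidia);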
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index c691d59..531fecc 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -330,10 +330,34 @@ public:
bool hasValue() const { return !isa<UnitAttr>(decorationValue); }
};
+ // Type for specifying the decoration(s) on the struct itself.
+ struct StructDecorationInfo {
+ Decoration decoration;
+ Attribute decorationValue;
+
+ StructDecorationInfo(Decoration decoration, Attribute decorationValue)
+ : decoration(decoration), decorationValue(decorationValue) {}
+
+ friend bool operator==(const StructDecorationInfo &lhs,
+ const StructDecorationInfo &rhs) {
+ return lhs.decoration == rhs.decoration &&
+ lhs.decorationValue == rhs.decorationValue;
+ }
+
+ friend bool operator<(const StructDecorationInfo &lhs,
+ const StructDecorationInfo &rhs) {
+ return llvm::to_underlying(lhs.decoration) <
+ llvm::to_underlying(rhs.decoration);
+ }
+
+ bool hasValue() const { return !isa<UnitAttr>(decorationValue); }
+ };
+
/// Construct a literal StructType with at least one member.
static StructType get(ArrayRef<Type> memberTypes,
ArrayRef<OffsetInfo> offsetInfo = {},
- ArrayRef<MemberDecorationInfo> memberDecorations = {});
+ ArrayRef<MemberDecorationInfo> memberDecorations = {},
+ ArrayRef<StructDecorationInfo> structDecorations = {});
/// Construct an identified StructType. This creates a StructType whose body
/// (member types, offset info, and decorations) is not set yet. A call to
@@ -367,6 +391,9 @@ public:
bool hasOffset() const;
+ /// Returns true if the struct has a specified decoration.
+ bool hasDecoration(spirv::Decoration decoration) const;
+
uint64_t getMemberOffset(unsigned) const;
// Returns in `memberDecorations` the Decorations (apart from Offset)
@@ -380,12 +407,18 @@ public:
unsigned i,
SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const;
+ // Returns in `structDecorations` the Decorations associated with the
+ // StructType.
+ void getStructDecorations(SmallVectorImpl<StructType::StructDecorationInfo>
+ &structDecorations) const;
+
/// Sets the contents of an incomplete identified StructType. This method must
/// be called only for identified StructTypes and it must be called only once
/// per instance. Otherwise, failure() is returned.
LogicalResult
trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {},
- ArrayRef<MemberDecorationInfo> memberDecorations = {});
+ ArrayRef<MemberDecorationInfo> memberDecorations = {},
+ ArrayRef<StructDecorationInfo> structDecorations = {});
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage = std::nullopt);
@@ -396,6 +429,9 @@ public:
llvm::hash_code
hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo);
+llvm::hash_code
+hash_value(const StructType::StructDecorationInfo &structDecorationInfo);
+
// SPIR-V KHR cooperative matrix type
class CooperativeMatrixType
: public Type::TypeBase<CooperativeMatrixType, CompositeType,
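A sketch of the new struct-decoration API under the signatures above (assumes an MLIRContext `ctx`; the Block decoration and f32 member are illustrative):

  mlir::Type f32 = mlir::Float32Type::get(&ctx);
  // A unit decoration carries no value, so it is encoded with a UnitAttr.
  mlir::spirv::StructType::StructDecorationInfo blockDeco(
      mlir::spirv::Decoration::Block, mlir::UnitAttr::get(&ctx));
  auto structTy = mlir::spirv::StructType::get(
      {f32}, /*offsetInfo=*/{}, /*memberDecorations=*/{},
      /*structDecorations=*/{blockDeco});
  assert(structTy.hasDecoration(mlir::spirv::Decoration::Block));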
diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h
index 4ed0423..7ff718a 100644
--- a/mlir/include/mlir/IR/Diagnostics.h
+++ b/mlir/include/mlir/IR/Diagnostics.h
@@ -639,6 +639,10 @@ public:
/// verified correctly, failure otherwise.
LogicalResult verify();
+ /// Register this handler with the given context. This is intended for use
+ /// with the splitAndProcessBuffer function.
+ void registerInContext(MLIRContext *ctx);
+
private:
/// Process a single diagnostic.
void process(Diagnostic &diag);
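The new hook supports the splitAndProcessBuffer flow, where each split chunk may be parsed in its own fresh MLIRContext; a handler constructed against the main context can now be attached to a per-split context as well. A sketch under those assumptions (this is the SourceMgrDiagnosticVerifierHandler declared in this header; the setup names are illustrative):

  llvm::SourceMgr sourceMgr;
  sourceMgr.AddNewSourceBuffer(std::move(inputBuffer), llvm::SMLoc());
  mlir::SourceMgrDiagnosticVerifierHandler handler(sourceMgr, &mainContext);
  // For a chunk processed in a separate context, attach the same handler.
  mlir::MLIRContext splitContext;
  handler.registerInContext(&splitContext);
  // ... parse and process the chunk within splitContext ...
  mlir::LogicalResult result = handler.verify();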
diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h
index 2162a74..8959dab 100644
--- a/mlir/include/mlir/IR/StorageUniquerSupport.h
+++ b/mlir/include/mlir/IR/StorageUniquerSupport.h
@@ -200,7 +200,7 @@ public:
// If the construction invariants fail then we return a null attribute.
if (failed(ConcreteT::verifyInvariants(emitErrorFn, args...)))
return ConcreteT();
- return UniquerT::template get<ConcreteT>(ctx, args...);
+ return UniquerT::template get<ConcreteT>(ctx, std::forward<Args>(args)...);
}
/// Get an instance of the concrete type from a void pointer.
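The one-line change above replaces a by-name re-pass of the parameter pack with perfect forwarding. A self-contained illustration of why that matters (plain C++, independent of MLIR):

  #include <string>
  #include <utility>

  static std::string sink(std::string s) { return s; }

  template <typename... Args>
  static std::string callSink(Args &&...args) {
    // A named rvalue-reference parameter is an lvalue, so re-passing it by
    // name copies; std::forward restores the original value category and
    // lets a moved-in std::string be moved instead of copied.
    return sink(std::forward<Args>(args)...);
  }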
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 856170e..7628171 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -14,200 +14,15 @@
#ifndef MLIR_INITALLDIALECTS_H_
#define MLIR_INITALLDIALECTS_H_
-#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
-#include "mlir/Dialect/AMX/AMXDialect.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"
-#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
-#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h"
-#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
-#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h"
-#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
-#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
-#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
-#include "mlir/Dialect/Async/IR/Async.h"
-#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
-#include "mlir/Dialect/Complex/IR/Complex.h"
-#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
-#include "mlir/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.h"
-#include "mlir/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.h"
-#include "mlir/Dialect/DLTI/DLTI.h"
-#include "mlir/Dialect/EmitC/IR/EmitC.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/GPU/IR/ValueBoundsOpInterfaceImpl.h"
-#include "mlir/Dialect/GPU/Transforms/BufferDeallocationOpInterfaceImpl.h"
-#include "mlir/Dialect/IRDL/IR/IRDL.h"
-#include "mlir/Dialect/Index/IR/IndexDialect.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
-#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
-#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
-#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h"
-#include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h"
-#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
-#include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
-#include "mlir/Dialect/MPI/IR/MPI.h"
-#include "mlir/Dialect/Math/IR/Math.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
-#include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
-#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
-#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h"
-#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
-#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
-#include "mlir/Dialect/OpenACC/OpenACC.h"
-#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
-#include "mlir/Dialect/PDL/IR/PDL.h"
-#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
-#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
-#include "mlir/Dialect/Quant/IR/Quant.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h"
-#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
-#include "mlir/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.h"
-#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"
-#include "mlir/Dialect/SMT/IR/SMTDialect.h"
-#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
-#include "mlir/Dialect/Shape/IR/Shape.h"
-#include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"
-#include "mlir/Dialect/Shard/IR/ShardDialect.h"
-#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
-#include "mlir/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h"
-#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
-#include "mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h"
-#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
-#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
-#include "mlir/Dialect/Tensor/Transforms/RuntimeOpVerification.h"
-#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
-#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h"
-#include "mlir/Dialect/Tosa/IR/TosaOps.h"
-#include "mlir/Dialect/Transform/IR/TransformDialect.h"
-#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
-#include "mlir/Dialect/UB/IR/UBOps.h"
-#include "mlir/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
-#include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h"
-#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
-#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
-#include "mlir/IR/Dialect.h"
-#include "mlir/Interfaces/CastInterfaces.h"
-#include "mlir/Target/LLVM/NVVM/Target.h"
-#include "mlir/Target/LLVM/ROCDL/Target.h"
-#include "mlir/Target/SPIRV/Target.h"
-
namespace mlir {
+class DialectRegistry;
+class MLIRContext;
/// Add all the MLIR dialects to the provided registry.
-inline void registerAllDialects(DialectRegistry &registry) {
- // clang-format off
- registry.insert<acc::OpenACCDialect,
- affine::AffineDialect,
- amdgpu::AMDGPUDialect,
- amx::AMXDialect,
- arith::ArithDialect,
- arm_neon::ArmNeonDialect,
- arm_sme::ArmSMEDialect,
- arm_sve::ArmSVEDialect,
- async::AsyncDialect,
- bufferization::BufferizationDialect,
- cf::ControlFlowDialect,
- complex::ComplexDialect,
- DLTIDialect,
- emitc::EmitCDialect,
- func::FuncDialect,
- gpu::GPUDialect,
- index::IndexDialect,
- irdl::IRDLDialect,
- linalg::LinalgDialect,
- LLVM::LLVMDialect,
- math::MathDialect,
- memref::MemRefDialect,
- shard::ShardDialect,
- ml_program::MLProgramDialect,
- mpi::MPIDialect,
- nvgpu::NVGPUDialect,
- NVVM::NVVMDialect,
- omp::OpenMPDialect,
- pdl::PDLDialect,
- pdl_interp::PDLInterpDialect,
- ptr::PtrDialect,
- quant::QuantDialect,
- ROCDL::ROCDLDialect,
- scf::SCFDialect,
- shape::ShapeDialect,
- smt::SMTDialect,
- sparse_tensor::SparseTensorDialect,
- spirv::SPIRVDialect,
- tensor::TensorDialect,
- tosa::TosaDialect,
- transform::TransformDialect,
- ub::UBDialect,
- vector::VectorDialect,
- x86vector::X86VectorDialect,
- xegpu::XeGPUDialect,
- xevm::XeVMDialect>();
- // clang-format on
-
- // Register all external models.
- affine::registerValueBoundsOpInterfaceExternalModels(registry);
- arith::registerBufferDeallocationOpInterfaceExternalModels(registry);
- arith::registerBufferizableOpInterfaceExternalModels(registry);
- arith::registerBufferViewFlowOpInterfaceExternalModels(registry);
- arith::registerShardingInterfaceExternalModels(registry);
- arith::registerValueBoundsOpInterfaceExternalModels(registry);
- bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
- registry);
- builtin::registerCastOpInterfaceExternalModels(registry);
- cf::registerBufferizableOpInterfaceExternalModels(registry);
- cf::registerBufferDeallocationOpInterfaceExternalModels(registry);
- gpu::registerBufferDeallocationOpInterfaceExternalModels(registry);
- gpu::registerValueBoundsOpInterfaceExternalModels(registry);
- LLVM::registerInlinerInterface(registry);
- NVVM::registerInlinerInterface(registry);
- linalg::registerAllDialectInterfaceImplementations(registry);
- linalg::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
- memref::registerAllocationOpInterfaceExternalModels(registry);
- memref::registerBufferViewFlowOpInterfaceExternalModels(registry);
- memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
- memref::registerValueBoundsOpInterfaceExternalModels(registry);
- memref::registerMemorySlotExternalModels(registry);
- ml_program::registerBufferizableOpInterfaceExternalModels(registry);
- scf::registerBufferDeallocationOpInterfaceExternalModels(registry);
- scf::registerBufferizableOpInterfaceExternalModels(registry);
- scf::registerValueBoundsOpInterfaceExternalModels(registry);
- shape::registerBufferizableOpInterfaceExternalModels(registry);
- sparse_tensor::registerBufferizableOpInterfaceExternalModels(registry);
- tensor::registerBufferizableOpInterfaceExternalModels(registry);
- tensor::registerFindPayloadReplacementOpInterfaceExternalModels(registry);
- tensor::registerInferTypeOpInterfaceExternalModels(registry);
- tensor::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
- tensor::registerSubsetOpInterfaceExternalModels(registry);
- tensor::registerTilingInterfaceExternalModels(registry);
- tensor::registerValueBoundsOpInterfaceExternalModels(registry);
- tosa::registerShardingInterfaceExternalModels(registry);
- vector::registerBufferizableOpInterfaceExternalModels(registry);
- vector::registerSubsetOpInterfaceExternalModels(registry);
- vector::registerValueBoundsOpInterfaceExternalModels(registry);
- NVVM::registerNVVMTargetInterfaceExternalModels(registry);
- ROCDL::registerROCDLTargetInterfaceExternalModels(registry);
- spirv::registerSPIRVTargetInterfaceExternalModels(registry);
-}
+void registerAllDialects(DialectRegistry &registry);
/// Append all the MLIR dialects to the registry contained in the given context.
-inline void registerAllDialects(MLIRContext &context) {
- DialectRegistry registry;
- registerAllDialects(registry);
- context.appendDialectRegistry(registry);
-}
+void registerAllDialects(MLIRContext &context);
} // namespace mlir
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index d5a9a2c..a7f64d9 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -14,110 +14,15 @@
#ifndef MLIR_INITALLEXTENSIONS_H_
#define MLIR_INITALLEXTENSIONS_H_
-#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
-#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
-#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
-#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
-#include "mlir/Conversion/FuncToEmitC/FuncToEmitC.h"
-#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
-#include "mlir/Conversion/GPUCommon/GPUToLLVM.h"
-#include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h"
-#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
-#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
-#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
-#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
-#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
-#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
-#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
-#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
-#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
-#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
-#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
-#include "mlir/Dialect/AMX/Transforms.h"
-#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
-#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
-#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h"
-#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
-#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
-#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
-#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h"
-#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
-#include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h"
-#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
-#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"
-#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
-#include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h"
-#include "mlir/Dialect/Tensor/Extensions/AllExtensions.h"
-#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
-#include "mlir/Dialect/Transform/DebugExtension/DebugExtension.h"
-#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h"
-#include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h"
-#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
-#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"
-#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
-#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
-#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
-#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
-#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
-
-#include <cstdlib>
-
namespace mlir {
+class DialectRegistry;
/// This function may be called to register all MLIR dialect extensions with the
/// provided registry.
/// If you're building a compiler, you generally shouldn't use this: you would
/// individually register the specific extensions that are useful for the
/// pipelines and transformations you are using.
-inline void registerAllExtensions(DialectRegistry &registry) {
- // Register all conversions to LLVM extensions.
- registerConvertArithToEmitCInterface(registry);
- arith::registerConvertArithToLLVMInterface(registry);
- registerConvertComplexToLLVMInterface(registry);
- cf::registerConvertControlFlowToLLVMInterface(registry);
- func::registerAllExtensions(registry);
- tensor::registerAllExtensions(registry);
- registerConvertFuncToEmitCInterface(registry);
- registerConvertFuncToLLVMInterface(registry);
- index::registerConvertIndexToLLVMInterface(registry);
- registerConvertMathToLLVMInterface(registry);
- mpi::registerConvertMPIToLLVMInterface(registry);
- registerConvertMemRefToEmitCInterface(registry);
- registerConvertMemRefToLLVMInterface(registry);
- registerConvertNVVMToLLVMInterface(registry);
- registerConvertOpenMPToLLVMInterface(registry);
- registerConvertSCFToEmitCInterface(registry);
- ub::registerConvertUBToLLVMInterface(registry);
- registerConvertAMXToLLVMInterface(registry);
- gpu::registerConvertGpuToLLVMInterface(registry);
- NVVM::registerConvertGpuToNVVMInterface(registry);
- vector::registerConvertVectorToLLVMInterface(registry);
- registerConvertXeVMToLLVMInterface(registry);
-
- // Register all transform dialect extensions.
- affine::registerTransformDialectExtension(registry);
- bufferization::registerTransformDialectExtension(registry);
- dlti::registerTransformDialectExtension(registry);
- func::registerTransformDialectExtension(registry);
- gpu::registerTransformDialectExtension(registry);
- linalg::registerTransformDialectExtension(registry);
- memref::registerTransformDialectExtension(registry);
- nvgpu::registerTransformDialectExtension(registry);
- scf::registerTransformDialectExtension(registry);
- sparse_tensor::registerTransformDialectExtension(registry);
- tensor::registerTransformDialectExtension(registry);
- transform::registerDebugExtension(registry);
- transform::registerIRDLExtension(registry);
- transform::registerLoopExtension(registry);
- transform::registerPDLExtension(registry);
- transform::registerTuneExtension(registry);
- vector::registerTransformDialectExtension(registry);
- arm_neon::registerTransformDialectExtension(registry);
- arm_sve::registerTransformDialectExtension(registry);
-
- // Translation extensions need to be registered by calling
- // `registerAllToLLVMIRTranslations` (see All.h).
-}
+void registerAllExtensions(DialectRegistry &registry);
} // namespace mlir
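
Downstream compilers that heed the advice in the comment above can now include just the specific conversion or extension headers that the old inline body pulled in wholesale. A hedged sketch registering only two of the interfaces listed above, with header paths taken from the removed includes:

    #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
    #include "mlir/Dialect/Tensor/Extensions/AllExtensions.h"
    #include "mlir/IR/DialectRegistry.h"

    static void registerNeededExtensions(mlir::DialectRegistry &registry) {
      // Register only what the pipeline actually uses.
      mlir::registerConvertFuncToLLVMInterface(registry);
      mlir::tensor::registerAllExtensions(registry);
    }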
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 002ff61..4554290 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -1,4 +1,4 @@
-//===- LinkAllPassesAndDialects.h - MLIR Registration -----------*- C++ -*-===//
+//===- InitAllPasses.h - MLIR Registration ----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,50 +6,14 @@
//
//===----------------------------------------------------------------------===//
//
-// This file defines a helper to trigger the registration of all dialects and
-// passes to the system.
+// This file defines a helper to trigger the registration of all passes to the
+// system.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_INITALLPASSES_H_
#define MLIR_INITALLPASSES_H_
-#include "mlir/Conversion/Passes.h"
-#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
-#include "mlir/Dialect/Affine/Passes.h"
-#include "mlir/Dialect/Arith/Transforms/Passes.h"
-#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
-#include "mlir/Dialect/ArmSVE/Transforms/Passes.h"
-#include "mlir/Dialect/Async/Passes.h"
-#include "mlir/Dialect/Bufferization/Pipelines/Passes.h"
-#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
-#include "mlir/Dialect/EmitC/Transforms/Passes.h"
-#include "mlir/Dialect/Func/Transforms/Passes.h"
-#include "mlir/Dialect/GPU/Pipelines/Passes.h"
-#include "mlir/Dialect/GPU/Transforms/Passes.h"
-#include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
-#include "mlir/Dialect/Linalg/Passes.h"
-#include "mlir/Dialect/MLProgram/Transforms/Passes.h"
-#include "mlir/Dialect/Math/Transforms/Passes.h"
-#include "mlir/Dialect/MemRef/Transforms/Passes.h"
-#include "mlir/Dialect/NVGPU/Transforms/Passes.h"
-#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
-#include "mlir/Dialect/Quant/Transforms/Passes.h"
-#include "mlir/Dialect/SCF/Transforms/Passes.h"
-#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
-#include "mlir/Dialect/Shape/Transforms/Passes.h"
-#include "mlir/Dialect/Shard/Transforms/Passes.h"
-#include "mlir/Dialect/SparseTensor/Pipelines/Passes.h"
-#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
-#include "mlir/Dialect/Tensor/Transforms/Passes.h"
-#include "mlir/Dialect/Tosa/Transforms/Passes.h"
-#include "mlir/Dialect/Transform/Transforms/Passes.h"
-#include "mlir/Dialect/Vector/Transforms/Passes.h"
-#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
-#include "mlir/Transforms/Passes.h"
-
-#include <cstdlib>
-
namespace mlir {
// This function may be called to register the MLIR passes with the
@@ -59,49 +23,7 @@ namespace mlir {
// registry, since it would already be calling the creation routine of the
// individual passes.
// The global registry is useful for interacting with the command-line tools.
-inline void registerAllPasses() {
- // General passes
- registerTransformsPasses();
-
- // Conversion passes
- registerConversionPasses();
-
- // Dialect passes
- acc::registerOpenACCPasses();
- affine::registerAffinePasses();
- amdgpu::registerAMDGPUPasses();
- registerAsyncPasses();
- arith::registerArithPasses();
- bufferization::registerBufferizationPasses();
- func::registerFuncPasses();
- registerGPUPasses();
- registerLinalgPasses();
- registerNVGPUPasses();
- registerSparseTensorPasses();
- LLVM::registerLLVMPasses();
- math::registerMathPasses();
- memref::registerMemRefPasses();
- shard::registerShardPasses();
- ml_program::registerMLProgramPasses();
- quant::registerQuantPasses();
- registerSCFPasses();
- registerShapePasses();
- spirv::registerSPIRVPasses();
- tensor::registerTensorPasses();
- tosa::registerTosaOptPasses();
- transform::registerTransformPasses();
- vector::registerVectorPasses();
- arm_sme::registerArmSMEPasses();
- arm_sve::registerArmSVEPasses();
- emitc::registerEmitCPasses();
- xegpu::registerXeGPUPasses();
-
- // Dialect pipelines
- bufferization::registerBufferizationPipelines();
- sparse_tensor::registerSparseTensorPipelines();
- tosa::registerTosaToLinalgPipelines();
- gpu::registerGPUToNVVMPipeline();
-}
+void registerAllPasses();
} // namespace mlir
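
Together with the two headers above, an opt-style tool now links three focused registration libraries (MLIRRegisterAllDialects, MLIRRegisterAllExtensions, MLIRRegisterAllPasses) instead of recompiling the large inline bodies everywhere. A sketch of the usual setup; the MlirOptMain entry point is assumed from mlir/Tools/mlir-opt:

    #include "mlir/InitAllDialects.h"
    #include "mlir/InitAllExtensions.h"
    #include "mlir/InitAllPasses.h"
    #include "mlir/Tools/mlir-opt/MlirOptMain.h"

    int main(int argc, char **argv) {
      mlir::registerAllPasses();
      mlir::DialectRegistry registry;
      mlir::registerAllDialects(registry);
      mlir::registerAllExtensions(registry);
      return mlir::failed(mlir::MlirOptMain(argc, argv, "demo-opt", registry))
                 ? 1
                 : 0;
    }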
diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.td b/mlir/include/mlir/Interfaces/CallInterfaces.td
index e3c2aec..19d3afe 100644
--- a/mlir/include/mlir/Interfaces/CallInterfaces.td
+++ b/mlir/include/mlir/Interfaces/CallInterfaces.td
@@ -18,9 +18,15 @@
include "mlir/IR/OpBase.td"
-/// Interface for operations with arguments attributes (both call-like
-/// and callable operations).
-def ArgumentAttributesMethods {
+/// Interface for operations with result and argument attributes.
+def ArgAndResultAttrsOpInterface : OpInterface<"ArgAndResultAttrsOpInterface"> {
+ let description = [{
+ An operation that has argument and result attributes. This interface
+ provides functions to access and modify the argument and result
+ attributes of the operation.
+ }];
+ let cppNamespace = "::mlir";
+
list<InterfaceMethod> methods = [
InterfaceMethod<[{
Get the array of argument attribute dictionaries. The method should
@@ -64,7 +70,8 @@ def ArgumentAttributesMethods {
// a call-like operation. This represents the destination of the call.
/// Interface for call-like operations.
-def CallOpInterface : OpInterface<"CallOpInterface"> {
+def CallOpInterface : OpInterface<"CallOpInterface",
+ [ArgAndResultAttrsOpInterface]> {
let description = [{
A call-like operation is one that transfers control from one sub-routine to
another. These operations may be traditional direct calls `call @foo`, or
@@ -123,11 +130,12 @@ def CallOpInterface : OpInterface<"CallOpInterface"> {
return ::mlir::call_interface_impl::resolveCallable($_op);
}]
>
- ] # ArgumentAttributesMethods.methods;
+ ];
}
/// Interface for callable operations.
-def CallableOpInterface : OpInterface<"CallableOpInterface"> {
+def CallableOpInterface : OpInterface<"CallableOpInterface",
+ [ArgAndResultAttrsOpInterface]> {
let description = [{
A callable operation is one who represents a potential sub-routine, and may
be a target for a call-like operation (those providing the CallOpInterface
@@ -140,11 +148,11 @@ def CallableOpInterface : OpInterface<"CallableOpInterface"> {
let methods = [
InterfaceMethod<[{
- Returns the region on the current operation that is callable. This may
- return null in the case of an external callable object, e.g. an external
- function.
- }],
- "::mlir::Region *", "getCallableRegion">,
+ Returns the region on the current operation that is callable. This may
+ return null in the case of an external callable object, e.g. an external
+ function.
+ }],
+ "::mlir::Region *", "getCallableRegion">,
InterfaceMethod<[{
      Returns the callable's argument types based exclusively on the type (to
      allow for this method to be called on function declarations).
@@ -155,7 +163,7 @@ def CallableOpInterface : OpInterface<"CallableOpInterface"> {
      allow for this method to be called on function declarations).
}],
"::llvm::ArrayRef<::mlir::Type>", "getResultTypes">,
- ] # ArgumentAttributesMethods.methods;
+ ];
}
#endif // MLIR_INTERFACES_CALLINTERFACES
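
With the methods hoisted into a real interface, generic code can query argument and result attributes without first dispatching on CallOpInterface vs. CallableOpInterface. A hedged sketch; the accessor name getArgAttrsAttr is assumed from the method list this interface absorbs:

    #include "mlir/Interfaces/CallInterfaces.h"
    #include "llvm/Support/raw_ostream.h"

    // Works uniformly for call-like and callable ops.
    void dumpArgAttrs(mlir::Operation *op) {
      auto attrsOp = llvm::dyn_cast<mlir::ArgAndResultAttrsOpInterface>(op);
      if (!attrsOp)
        return;
      if (mlir::ArrayAttr argAttrs = attrsOp.getArgAttrsAttr())
        for (mlir::Attribute dictAttr : argAttrs)
          llvm::errs() << dictAttr << "\n";
    }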
diff --git a/mlir/include/mlir/Support/ToolUtilities.h b/mlir/include/mlir/Support/ToolUtilities.h
index cb6ba29..657f117 100644
--- a/mlir/include/mlir/Support/ToolUtilities.h
+++ b/mlir/include/mlir/Support/ToolUtilities.h
@@ -21,10 +21,16 @@
namespace llvm {
class MemoryBuffer;
+class MemoryBufferRef;
} // namespace llvm
namespace mlir {
+/// A function that processes a chunk of a buffer and writes the result to an
+/// output stream.
using ChunkBufferHandler = function_ref<LogicalResult(
+ std::unique_ptr<llvm::MemoryBuffer> chunkBuffer,
+ const llvm::MemoryBufferRef &sourceBuffer, raw_ostream &os)>;
+using NoSourceChunkBufferHandler = function_ref<LogicalResult(
std::unique_ptr<llvm::MemoryBuffer> chunkBuffer, raw_ostream &os)>;
extern inline const char *const kDefaultSplitMarker = "// -----";
@@ -45,6 +51,15 @@ splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer,
ChunkBufferHandler processChunkBuffer, raw_ostream &os,
llvm::StringRef inputSplitMarker = kDefaultSplitMarker,
llvm::StringRef outputSplitMarker = "");
+
+/// Same as above, but for the case where the original buffer is not used while
+/// processing the chunk.
+LogicalResult
+splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer,
+ NoSourceChunkBufferHandler processChunkBuffer,
+ raw_ostream &os,
+ llvm::StringRef inputSplitMarker = kDefaultSplitMarker,
+ llvm::StringRef outputSplitMarker = "");
} // namespace mlir
#endif // MLIR_SUPPORT_TOOLUTILITIES_H
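
The new overload spares handlers from taking a source-buffer parameter they never read. A minimal sketch of a splitter that just echoes each chunk (hypothetical driver code):

    #include "mlir/Support/ToolUtilities.h"
    #include "llvm/Support/MemoryBuffer.h"
    #include "llvm/Support/raw_ostream.h"

    mlir::LogicalResult echoChunks(std::unique_ptr<llvm::MemoryBuffer> input) {
      auto echo = [](std::unique_ptr<llvm::MemoryBuffer> chunk,
                     llvm::raw_ostream &os) -> mlir::LogicalResult {
        os << chunk->getBuffer(); // the original buffer is never consulted
        return mlir::success();
      };
      return mlir::splitAndProcessBuffer(std::move(input), echo, llvm::outs());
    }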
diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
index 60615cf6..e4670cb 100644
--- a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
+++ b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
@@ -28,6 +28,7 @@
#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/SPIRV/SPIRVToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/VCIX/VCIXToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h"
namespace mlir {
class DialectRegistry;
@@ -47,6 +48,7 @@ static inline void registerAllToLLVMIRTranslations(DialectRegistry &registry) {
registerROCDLDialectTranslation(registry);
registerSPIRVDialectTranslation(registry);
registerVCIXDialectTranslation(registry);
+ registerXeVMDialectTranslation(registry);
// Extension required for translating GPU offloading Ops.
gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(registry);
@@ -63,6 +65,7 @@ registerAllGPUToLLVMIRTranslations(DialectRegistry &registry) {
registerNVVMDialectTranslation(registry);
registerROCDLDialectTranslation(registry);
registerSPIRVDialectTranslation(registry);
+ registerXeVMDialectTranslation(registry);
// Extension required for translating GPU offloading Ops.
gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(registry);
diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h
new file mode 100644
index 0000000..b4f6750
--- /dev/null
+++ b/mlir/include/mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h
@@ -0,0 +1,31 @@
+//===-- XeVMToLLVMIRTranslation.h - XeVM to LLVM IR -------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This provides registration calls for XeVM dialect to LLVM IR translation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_LLVMIR_DIALECT_XEVM_XEVMTOLLVMIRTRANSLATION_H
+#define MLIR_TARGET_LLVMIR_DIALECT_XEVM_XEVMTOLLVMIRTRANSLATION_H
+
+namespace mlir {
+
+class DialectRegistry;
+class MLIRContext;
+
+/// Register the XeVM dialect and the translation from it to LLVM IR in the
+/// given registry.
+void registerXeVMDialectTranslation(mlir::DialectRegistry &registry);
+
+/// Register the XeVM dialect and the translation from it in the registry
+/// associated with the given context.
+void registerXeVMDialectTranslation(mlir::MLIRContext &context);
+
+} // namespace mlir
+
+#endif // MLIR_TARGET_LLVMIR_DIALECT_XEVM_XEVMTOLLVMIRTRANSLATION_H
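
Usage mirrors the sibling *ToLLVMIRTranslation headers; a one-liner sketch for a tool that only needs XeVM:

    #include "mlir/IR/DialectRegistry.h"
    #include "mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h"

    void addXeVMTranslation(mlir::DialectRegistry &registry) {
      // Also reachable via registerAllToLLVMIRTranslations() in Dialect/All.h.
      mlir::registerXeVMDialectTranslation(registry);
    }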
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index 17ef8e4..b22ed60 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -291,10 +291,12 @@ public:
SmallVectorImpl<Value> &valuesOut,
SmallVectorImpl<NamedAttribute> &attrsOut);
- /// Converts the parameter and result attributes in `argsAttr` and `resAttr`
- /// and add them to the `callOp`.
- void convertParameterAttributes(llvm::CallBase *call, ArrayAttr &argsAttr,
- ArrayAttr &resAttr, OpBuilder &builder);
+ /// Converts the argument and result attributes attached to `call` and adds
+ /// them to `attrsOp`. For intrinsic calls, filters out attributes
+ /// corresponding to immediate arguments specified by `immArgPositions`.
+ void convertArgAndResultAttrs(llvm::CallBase *call,
+ ArgAndResultAttrsOpInterface attrsOp,
+ ArrayRef<unsigned> immArgPositions = {});
/// Whether the importer should try to convert all intrinsics to
/// llvm.call_intrinsic instead of dialect supported operations.
@@ -378,19 +380,12 @@ private:
bool &isIncompatibleCall);
/// Returns the callee name, or an empty symbol if the call is not direct.
FlatSymbolRefAttr convertCalleeName(llvm::CallBase *callInst);
- /// Converts the parameter and result attributes attached to `func` and adds
+ /// Converts the argument and result attributes attached to `func` and adds
/// them to the `funcOp`.
- void convertParameterAttributes(llvm::Function *func, LLVMFuncOp funcOp,
- OpBuilder &builder);
- /// Converts the AttributeSet of one parameter in LLVM IR to a corresponding
- /// DictionaryAttr for the LLVM dialect.
- DictionaryAttr convertParameterAttribute(llvm::AttributeSet llvmParamAttrs,
- OpBuilder &builder);
- /// Converts the parameter and result attributes attached to `call` and adds
- /// them to the `callOp`. Implemented in terms of the the public definition of
- /// convertParameterAttributes.
- void convertParameterAttributes(llvm::CallBase *call, CallOpInterface callOp,
- OpBuilder &builder);
+ void convertArgAndResultAttrs(llvm::Function *func, LLVMFuncOp funcOp);
+ /// Converts the argument or result attributes in `llvmAttrSet` to a
+ /// corresponding MLIR LLVM dialect attribute dictionary.
+ DictionaryAttr convertArgOrResultAttrSet(llvm::AttributeSet llvmAttrSet);
/// Converts the attributes attached to `inst` and adds them to the `op`.
LogicalResult convertCallAttributes(llvm::CallInst *inst, CallOp op);
/// Converts the attributes attached to `inst` and adds them to the `op`.
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index f3f73f4..eb7dfa7 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -25,11 +25,13 @@
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
#include "llvm/ADT/SetVector.h"
-#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/IR/FPEnv.h"
+#include "llvm/IR/Module.h"
namespace llvm {
class BasicBlock;
+class CallBase;
+class CanonicalLoopInfo;
class Function;
class IRBuilderBase;
class OpenMPIRBuilder;
@@ -306,10 +308,16 @@ public:
/*recordInsertions=*/false);
}
- /// Translates parameter attributes of a call and adds them to the returned
- /// AttrBuilder. Returns failure if any of the translations failed.
- FailureOr<llvm::AttrBuilder> convertParameterAttrs(mlir::Location loc,
- DictionaryAttr paramAttrs);
+ /// Converts argument and result attributes from `attrsOp` to LLVM IR
+ /// attributes on the `call` instruction. Returns failure if conversion fails.
+ /// The `immArgPositions` parameter is only relevant for intrinsics. It
+ /// specifies the positions of immediate arguments, which do not have
+ /// associated argument attributes in MLIR and should be skipped during
+ /// attribute mapping.
+ LogicalResult
+ convertArgAndResultAttrs(ArgAndResultAttrsOpInterface attrsOp,
+ llvm::CallBase *call,
+ ArrayRef<unsigned> immArgPositions = {});
/// Gets the named metadata in the LLVM IR module being constructed, creating
/// it if it does not exist.
@@ -389,6 +397,11 @@ private:
convertDialectAttributes(Operation *op,
ArrayRef<llvm::Instruction *> instructions);
+ /// Translates parameter attributes of a call and adds them to the returned
+ /// AttrBuilder. Returns failure if any of the translations failed.
+ FailureOr<llvm::AttrBuilder> convertParameterAttrs(mlir::Location loc,
+ DictionaryAttr paramAttrs);
+
/// Translates parameter attributes of a function and adds them to the
/// returned AttrBuilder. Returns failure if any of the translations failed.
FailureOr<llvm::AttrBuilder>
diff --git a/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt b/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt
index 8b9a395..ccda668 100644
--- a/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt
+++ b/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt
@@ -1,19 +1,16 @@
# Dialect registration.
-get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(translation_libs GLOBAL PROPERTY MLIR_TRANSLATION_LIBS)
-get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
-get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
add_mlir_upstream_c_api_library(MLIRCAPIRegisterEverything
RegisterEverything.cpp
LINK_LIBS PUBLIC
- ${dialect_libs}
${translation_libs}
- ${conversion_libs}
- ${extension_libs}
MLIRBuiltinToLLVMIRTranslation
MLIRCAPIIR
- MLIRLLVMToLLVMIRTranslation
MLIRCAPITransforms
+ MLIRLLVMToLLVMIRTranslation
+ MLIRRegisterAllDialects
+ MLIRRegisterAllExtensions
+ MLIRRegisterAllPasses
)
diff --git a/mlir/lib/CMakeLists.txt b/mlir/lib/CMakeLists.txt
index d25c84a..191b5ab6 100644
--- a/mlir/lib/CMakeLists.txt
+++ b/mlir/lib/CMakeLists.txt
@@ -20,3 +20,37 @@ add_subdirectory(Target)
add_subdirectory(Tools)
add_subdirectory(Transforms)
add_subdirectory(ExecutionEngine)
+
+get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
+get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
+get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
+
+add_mlir_library(MLIRRegisterAllDialects
+ RegisterAllDialects.cpp
+
+ PARTIAL_SOURCES_INTENDED
+
+ LINK_LIBS PUBLIC
+ ${dialect_libs}
+ )
+
+add_mlir_library(MLIRRegisterAllPasses
+ RegisterAllPasses.cpp
+
+ PARTIAL_SOURCES_INTENDED
+
+ LINK_LIBS PUBLIC
+ ${dialect_libs} # Some passes are part of the dialect libs
+ ${conversion_libs}
+ )
+
+add_mlir_library(MLIRRegisterAllExtensions
+ RegisterAllExtensions.cpp
+
+ PARTIAL_SOURCES_INTENDED
+
+ LINK_LIBS PUBLIC
+ ${dialect_libs}
+ ${conversion_libs}
+ ${extension_libs}
+ )
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index 6f0fc29..35ad99c 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -64,10 +64,46 @@ void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
patterns.getContext(), "__ocml_cabs_f32");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>(
patterns.getContext(), "__ocml_cabs_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float32Type>>(
+ patterns.getContext(), "__ocml_carg_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float64Type>>(
+ patterns.getContext(), "__ocml_carg_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float32Type>>(
+ patterns.getContext(), "__ocml_conj_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float64Type>>(
+ patterns.getContext(), "__ocml_conj_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float32Type>>(
+ patterns.getContext(), "__ocml_ccos_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float64Type>>(
+ patterns.getContext(), "__ocml_ccos_f64");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float32Type>>(
patterns.getContext(), "__ocml_cexp_f32");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float64Type>>(
patterns.getContext(), "__ocml_cexp_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float32Type>>(
+ patterns.getContext(), "__ocml_clog_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float64Type>>(
+ patterns.getContext(), "__ocml_clog_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float32Type>>(
+ patterns.getContext(), "__ocml_cpow_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float64Type>>(
+ patterns.getContext(), "__ocml_cpow_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float32Type>>(
+ patterns.getContext(), "__ocml_csin_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float64Type>>(
+ patterns.getContext(), "__ocml_csin_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float32Type>>(
+ patterns.getContext(), "__ocml_csqrt_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float64Type>>(
+ patterns.getContext(), "__ocml_csqrt_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float32Type>>(
+ patterns.getContext(), "__ocml_ctan_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float64Type>>(
+ patterns.getContext(), "__ocml_ctan_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float32Type>>(
+ patterns.getContext(), "__ocml_ctanh_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float64Type>>(
+ patterns.getContext(), "__ocml_ctanh_f64");
}
namespace {
@@ -86,7 +122,10 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
ConversionTarget target(getContext());
target.addLegalDialect<func::FuncDialect>();
- target.addIllegalOp<complex::AbsOp, complex::ExpOp>();
+ target.addIllegalOp<complex::AbsOp, complex::AngleOp, complex::ConjOp,
+ complex::CosOp, complex::ExpOp, complex::LogOp,
+ complex::PowOp, complex::SinOp, complex::SqrtOp,
+ complex::TanOp, complex::TanhOp>();
if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
}
diff --git a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
index 855c582..cde2340 100644
--- a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
+++ b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
@@ -22,7 +22,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTMATHTOFUNCS
@@ -32,7 +32,6 @@ namespace mlir {
using namespace mlir;
#define DEBUG_TYPE "math-to-funcs"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
namespace {
// Pattern to convert vector operations to scalar operations.
@@ -653,10 +652,8 @@ FPowIOpLowering::matchAndRewrite(math::FPowIOp op,
/// }
static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) {
if (!isa<IntegerType>(elementType)) {
- LLVM_DEBUG({
- DBGS() << "non-integer element type for CtlzFunc; type was: ";
- elementType.print(llvm::dbgs());
- });
+ LDBG() << "non-integer element type for CtlzFunc; type was: "
+ << elementType;
llvm_unreachable("non-integer element type");
}
int64_t bitWidth = elementType.getIntOrFloatBitWidth();
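
This is the first of many mechanical conversions in this patch from LLVM_DEBUG plus a per-file DBGS() macro to the newer LDBG() stream in llvm/Support/DebugLog.h, which derives its prefix from DEBUG_TYPE and appends the newline itself. A side-by-side sketch of the two idioms:

    #include "llvm/Support/Debug.h"    // LLVM_DEBUG, llvm::dbgs
    #include "llvm/Support/DebugLog.h" // LDBG
    #define DEBUG_TYPE "math-to-funcs"

    void logWidth(int bitWidth) {
      // Old: explicit prefix macro and trailing newline at every call site.
      LLVM_DEBUG(llvm::dbgs() << "[" DEBUG_TYPE "]: width " << bitWidth << "\n");
      // New: prefix and newline are supplied by the macro.
      LDBG() << "width " << bitWidth;
    }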
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index 93d8b49..df219f3 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
+#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -21,7 +22,6 @@
#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/OpToFuncCallLowering.h"
-#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTMATHTOROCDL
@@ -31,7 +31,6 @@ namespace mlir {
using namespace mlir;
#define DEBUG_TYPE "math-to-rocdl"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
template <typename OpTy>
static void populateOpPatterns(const LLVMTypeConverter &converter,
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 6ba5bfe4..dc2035b 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -24,11 +24,12 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/MathExtras.h"
+
#include <optional>
#define DEBUG_TYPE "memref-to-llvm"
-#define DBGS() llvm::dbgs() << "[" DEBUG_TYPE "] "
namespace mlir {
#define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
@@ -1848,8 +1849,8 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
return LLVM::AtomicBinOp::xchg;
case arith::AtomicRMWKind::maximumf:
// TODO: remove this by end of 2025.
- LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw maximumf changed "
- "from fmax to fmaximum, expect more NaNs");
+ LDBG() << "the lowering of memref.atomicrmw maximumf changed "
+ "from fmax to fmaximum, expect more NaNs";
return LLVM::AtomicBinOp::fmaximum;
case arith::AtomicRMWKind::maxnumf:
return LLVM::AtomicBinOp::fmax;
@@ -1859,8 +1860,8 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
return LLVM::AtomicBinOp::umax;
case arith::AtomicRMWKind::minimumf:
// TODO: remove this by end of 2025.
- LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw minimum changed "
- "from fmin to fminimum, expect more NaNs");
+ LDBG() << "the lowering of memref.atomicrmw minimum changed "
+ "from fmin to fminimum, expect more NaNs";
return LLVM::AtomicBinOp::fminimum;
case arith::AtomicRMWKind::minnumf:
return LLVM::AtomicBinOp::fmin;
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 5d13353..2549a9c 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -26,13 +26,12 @@
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
#define DEBUG_TYPE "nvgpu-to-nvvm"
-#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
-#define DBGSE() (llvm::dbgs())
namespace mlir {
#define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
@@ -1105,13 +1104,13 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
// // [0,14) start_address
dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
- LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
- << "leading_off:" << leadDimVal << "\t"
- << "stride_off :" << strideDimVal << "\t"
- << "base_offset:" << offsetVal << "\t"
- << "layout_type:" << swizzle << " ("
- << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
- << ")\n start_addr : " << baseAddr << "\n");
+ LDBG() << "Generating warpgroup.descriptor: "
+ << "leading_off:" << leadDimVal << "\t"
+ << "stride_off :" << strideDimVal << "\t"
+ << "base_offset:" << offsetVal << "\t"
+ << "layout_type:" << swizzle << " ("
+ << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
+ << ")\n start_addr : " << baseAddr;
rewriter.replaceOp(op, dsc);
return success();
@@ -1281,8 +1280,8 @@ struct NVGPUWarpgroupMmaOpLowering
} else {
llvm_unreachable("msg: not supported K shape");
}
- LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
- << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
+ LDBG() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
+ << ", n = " << wgmmaN << ", k = " << wgmmaK << "]";
}
/// Generates WGMMATypesAttr from MLIR Type
@@ -1366,9 +1365,9 @@ struct NVGPUWarpgroupMmaOpLowering
int tileShapeA = matrixTypeA.getDimSize(1);
int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
incrementVal = incrementVal >> exclude4LSB;
- LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
- << "] [wgmma descriptors] Descriptor A + "
- << incrementVal << " | \t ");
+ LDBG() << "\t\t[m: " << i << " n: " << j << " k: " << k
+ << "] [wgmma descriptors] Descriptor A + " << incrementVal
+ << " | \t ";
if (!incrementVal)
return desc;
return makeAdd(desc, makeI64Const(b, incrementVal));
@@ -1391,7 +1390,7 @@ struct NVGPUWarpgroupMmaOpLowering
int byte = elemB.getIntOrFloatBitWidth() / 8;
int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
incrementVal = incrementVal >> exclude4LSB;
- LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
+ LDBG() << "Descriptor B + " << incrementVal;
if (!incrementVal)
return desc;
return makeAdd(desc, makeI64Const(b, incrementVal));
@@ -1400,15 +1399,14 @@ struct NVGPUWarpgroupMmaOpLowering
/// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
/// descriptors and arranges them based on induction variables: i, j, and k.
Value generateWgmma(int i, int j, int k, Value matrixC) {
- LLVM_DEBUG(DBGS() << "\t wgmma."
- << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
- << "(A[" << (iterationM * wgmmaM) << ":"
- << (iterationM * wgmmaM) + wgmmaM << "]["
- << (iterationK * wgmmaK) << ":"
- << (iterationK * wgmmaK + wgmmaK) << "] * "
- << " B[" << (iterationK * wgmmaK) << ":"
- << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
- << wgmmaN << "])\n");
+ LDBG() << "\t wgmma."
+ << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK << "(A["
+ << (iterationM * wgmmaM) << ":" << (iterationM * wgmmaM) + wgmmaM
+ << "][" << (iterationK * wgmmaK) << ":"
+ << (iterationK * wgmmaK + wgmmaK) << "] * "
+ << " B[" << (iterationK * wgmmaK) << ":"
+ << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":" << wgmmaN
+ << "])";
Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
@@ -1467,9 +1465,9 @@ struct NVGPUWarpgroupMmaOpLowering
totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
- LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
- << "] += A[" << totalM << "][" << totalK << "] * B["
- << totalK << "][" << totalN << "] ---===\n");
+ LDBG() << "===--- GEMM D[" << totalM << "][" << totalN << "] += A["
+ << totalM << "][" << totalK << "] * B[" << totalK << "][" << totalN
+ << "] ---===";
// Find the shape for one wgmma instruction
findWgmmaShape(
diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
index 662ee9e..91788f9 100644
--- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
+++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
@@ -25,11 +25,10 @@
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/raw_ostream.h"
#define DEBUG_TYPE "nvvm-to-llvm"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-#define DBGSNL() (llvm::dbgs() << "\n")
namespace mlir {
#define GEN_PASS_DEF_CONVERTNVVMTOLLVMPASS
@@ -52,17 +51,17 @@ struct PtxLowering
LogicalResult matchAndRewrite(BasicPtxBuilderInterface op,
PatternRewriter &rewriter) const override {
if (op.hasIntrinsic()) {
- LLVM_DEBUG(DBGS() << "Ptx Builder does not lower \n\t" << op << "\n");
+ LDBG() << "Ptx Builder does not lower \n\t" << op;
return failure();
}
SmallVector<std::pair<Value, PTXRegisterMod>> asmValues;
- LLVM_DEBUG(DBGS() << op.getPtx() << "\n");
+ LDBG() << op.getPtx();
PtxBuilder generator(op, rewriter);
op.getAsmValues(rewriter, asmValues);
for (auto &[asmValue, modifier] : asmValues) {
- LLVM_DEBUG(DBGSNL() << asmValue << "\t Modifier : " << &modifier);
+ LDBG() << asmValue << "\t Modifier : " << &modifier;
generator.insertValue(asmValue, modifier);
}
diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
index fd40e7c..fa9e544 100644
--- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -36,7 +36,6 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "shard-to-mpi"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
namespace mlir {
#define GEN_PASS_DEF_CONVERTSHARDTOMPIPASS
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index a425eff..1d1904f 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -31,10 +31,9 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/DebugLog.h"
#define DEBUG_TYPE "vector-to-gpu"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-#define DBGSNL() (llvm::dbgs() << "\n")
namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOGPU
@@ -366,7 +365,7 @@ static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
// by all operations.
if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
if (!supportsMMaMatrixType(op, useNvGpu)) {
- LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n");
+ LDBG() << "cannot convert op: " << *op;
return true;
}
return false;
@@ -548,7 +547,7 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
std::optional<int64_t> stride =
getStaticallyKnownRowStride(op.getShapedType());
if (!stride.has_value()) {
- LLVM_DEBUG(DBGS() << "no stride\n");
+ LDBG() << "no stride";
return rewriter.notifyMatchFailure(op, "no stride");
}
@@ -583,7 +582,7 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
isTranspose ? rewriter.getUnitAttr() : UnitAttr());
valueMapping[mappingResult] = load;
- LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n");
+ LDBG() << "transfer read to: " << load;
return success();
}
@@ -597,13 +596,13 @@ convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
std::optional<int64_t> stride =
getStaticallyKnownRowStride(op.getShapedType());
if (!stride.has_value()) {
- LLVM_DEBUG(DBGS() << "no stride\n");
+ LDBG() << "no stride";
return rewriter.notifyMatchFailure(op, "no stride");
}
auto it = valueMapping.find(op.getVector());
if (it == valueMapping.end()) {
- LLVM_DEBUG(DBGS() << "no mapping\n");
+ LDBG() << "no mapping";
return rewriter.notifyMatchFailure(op, "no mapping");
}
@@ -613,9 +612,9 @@ convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
(void)store;
- LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n");
+ LDBG() << "transfer write to: " << store;
- LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
+ LDBG() << "erase: " << op;
rewriter.eraseOp(op);
return success();
}
@@ -641,21 +640,21 @@ convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
nvgpu::getWarpMatrixInfo(op);
if (failed(warpMatrixInfo)) {
- LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
+ LDBG() << "no warpMatrixInfo";
return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
}
FailureOr<nvgpu::FragmentElementInfo> regInfo =
nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
if (failed(regInfo)) {
- LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
+ LDBG() << "not mma sync reg info";
return rewriter.notifyMatchFailure(op, "not mma sync reg info");
}
VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
if (!dense) {
- LLVM_DEBUG(DBGS() << "not a splat\n");
+ LDBG() << "not a splat";
return rewriter.notifyMatchFailure(op, "not a splat");
}
@@ -677,8 +676,8 @@ static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
mlir::AffineMap map = op.getPermutationMap();
if (map.getNumResults() != 2) {
- LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` "
- "is not a 2d operand\n");
+ LDBG() << "Failed because the result of `vector.transfer_read` "
+ "is not a 2d operand";
return failure();
}
@@ -691,8 +690,8 @@ static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
auto exprN = dyn_cast<AffineDimExpr>(dN);
if (!exprM || !exprN) {
- LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim "
- "expressions, then transpose cannot be determined.\n");
+ LDBG() << "Failed because expressions are not affine dim "
+ "expressions, then transpose cannot be determined.";
return failure();
}
@@ -709,20 +708,20 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
nvgpu::getWarpMatrixInfo(op);
if (failed(warpMatrixInfo)) {
- LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
+ LDBG() << "no warpMatrixInfo";
return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
}
FailureOr<nvgpu::FragmentElementInfo> regInfo =
nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
if (failed(regInfo)) {
- LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
+ LDBG() << "not mma sync reg info";
return rewriter.notifyMatchFailure(op, "not mma sync reg info");
}
FailureOr<bool> transpose = isTransposed(op);
if (failed(transpose)) {
- LLVM_DEBUG(DBGS() << "failed to determine the transpose\n");
+ LDBG() << "failed to determine the transpose";
return rewriter.notifyMatchFailure(
op, "Op should likely not be converted to a nvgpu.ldmatrix call.");
}
@@ -731,10 +730,8 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose);
if (failed(params)) {
- LLVM_DEBUG(
- DBGS()
- << "failed to convert vector.transfer_read to ldmatrix. "
- << "Op should likely not be converted to a nvgpu.ldmatrix call.\n");
+ LDBG() << "failed to convert vector.transfer_read to ldmatrix. "
+ << "Op should likely not be converted to a nvgpu.ldmatrix call.";
return rewriter.notifyMatchFailure(
op, "failed to convert vector.transfer_read to ldmatrix; this op "
"likely should not be converted to a nvgpu.ldmatrix call.");
@@ -745,7 +742,7 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
FailureOr<AffineMap> offsets =
nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
if (failed(offsets)) {
- LLVM_DEBUG(DBGS() << "no offsets\n");
+ LDBG() << "no offsets";
return rewriter.notifyMatchFailure(op, "no offsets");
}
@@ -934,7 +931,7 @@ convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
vector::StoreOp::create(rewriter, loc, el, op.getBase(), newIndices);
}
- LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
+ LDBG() << "erase: " << op;
rewriter.eraseOp(op);
return success();
}
@@ -1132,9 +1129,9 @@ static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
loop.getNumResults())))
rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
- LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n");
- LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n");
- LLVM_DEBUG(DBGS() << "erase: " << loop);
+ LDBG() << "newLoop now: " << newLoop;
+ LDBG() << "stripped scf.for: " << loop;
+ LDBG() << "erase: " << loop;
rewriter.eraseOp(loop);
return newLoop;
@@ -1150,7 +1147,7 @@ static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
for (const auto &operand : llvm::enumerate(op.getInitArgs())) {
auto it = valueMapping.find(operand.value());
if (it == valueMapping.end()) {
- LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n");
+ LDBG() << "no value mapping for: " << operand.value();
continue;
}
argMapping.push_back(std::make_pair(
@@ -1168,7 +1165,7 @@ static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
}
- LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n");
+ LDBG() << "scf.for to: " << newForOp;
return success();
}
@@ -1191,7 +1188,7 @@ convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
}
scf::YieldOp::create(rewriter, op.getLoc(), yieldOperands);
- LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
+ LDBG() << "erase: " << op;
rewriter.eraseOp(op);
return success();
}
@@ -1244,7 +1241,7 @@ LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
auto globalRes = LogicalResult::success();
for (Operation *op : ops) {
- LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n");
+ LDBG() << "Process op: " << *op;
// Apparently callers do not want to early exit on failure here.
auto res = LogicalResult::success();
if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 8d7053c..22608a1 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -26,7 +26,7 @@
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include <numeric>
@@ -40,7 +40,6 @@ using llvm::divideFloorSigned;
using llvm::mod;
#define DEBUG_TYPE "affine-ops"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
#include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
@@ -1062,12 +1061,9 @@ static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp,
AffineMap *map,
ValueRange dims,
ValueRange syms) {
+ LDBG() << "replaceAffineMinBoundingBoxExpression: `" << minOp << "`";
AffineMap affineMinMap = minOp.getAffineMap();
- LLVM_DEBUG({
- DBGS() << "replaceAffineMinBoundingBoxExpression: `" << minOp << "`\n";
- });
-
// Check the value is positive.
for (unsigned i = 0, e = affineMinMap.getNumResults(); i < e; ++i) {
// Compare each expression in the minimum against 0.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index cffe310..52cd0ce 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -30,6 +30,7 @@
#include "mlir/IR/Types.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/IR/IRBuilder.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
index 935aa3c..b951df8 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
@@ -22,6 +22,8 @@
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
+
#define DEBUG_TYPE "llvm-inliner"
using namespace mlir;
@@ -670,44 +672,42 @@ struct LLVMInlinerInterface : public DialectInlinerInterface {
bool wouldBeCloned) const final {
auto callOp = dyn_cast<LLVM::CallOp>(call);
if (!callOp) {
- LLVM_DEBUG(llvm::dbgs() << "Cannot inline: call is not an '"
- << LLVM::CallOp::getOperationName() << "' op\n");
+ LDBG() << "Cannot inline: call is not an '"
+ << LLVM::CallOp::getOperationName() << "' op";
return false;
}
if (callOp.getNoInline()) {
- LLVM_DEBUG(llvm::dbgs() << "Cannot inline: call is marked no_inline\n");
+ LDBG() << "Cannot inline: call is marked no_inline";
return false;
}
auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(callable);
if (!funcOp) {
- LLVM_DEBUG(llvm::dbgs()
- << "Cannot inline: callable is not an '"
- << LLVM::LLVMFuncOp::getOperationName() << "' op\n");
+ LDBG() << "Cannot inline: callable is not an '"
+ << LLVM::LLVMFuncOp::getOperationName() << "' op";
return false;
}
if (funcOp.isNoInline()) {
- LLVM_DEBUG(llvm::dbgs()
- << "Cannot inline: function is marked no_inline\n");
+ LDBG() << "Cannot inline: function is marked no_inline";
return false;
}
if (funcOp.isVarArg()) {
- LLVM_DEBUG(llvm::dbgs() << "Cannot inline: callable is variadic\n");
+ LDBG() << "Cannot inline: callable is variadic";
return false;
}
// TODO: Generate aliasing metadata from noalias result attributes.
if (auto attrs = funcOp.getArgAttrs()) {
for (DictionaryAttr attrDict : attrs->getAsRange<DictionaryAttr>()) {
if (attrDict.contains(LLVM::LLVMDialect::getInAllocaAttrName())) {
- LLVM_DEBUG(llvm::dbgs() << "Cannot inline " << funcOp.getSymName()
- << ": inalloca arguments not supported\n");
+ LDBG() << "Cannot inline " << funcOp.getSymName()
+ << ": inalloca arguments not supported";
return false;
}
}
}
// TODO: Handle exceptions.
if (funcOp.getPersonality()) {
- LLVM_DEBUG(llvm::dbgs() << "Cannot inline " << funcOp.getSymName()
- << ": unhandled function personality\n");
+ LDBG() << "Cannot inline " << funcOp.getSymName()
+ << ": unhandled function personality";
return false;
}
if (funcOp.getPassthrough()) {
@@ -717,10 +717,8 @@ struct LLVMInlinerInterface : public DialectInlinerInterface {
if (!stringAttr)
return false;
if (disallowedFunctionAttrs.contains(stringAttr)) {
- LLVM_DEBUG(llvm::dbgs()
- << "Cannot inline " << funcOp.getSymName()
- << ": found disallowed function attribute "
- << stringAttr << "\n");
+ LDBG() << "Cannot inline " << funcOp.getSymName()
+ << ": found disallowed function attribute " << stringAttr;
return true;
}
return false;
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index f49d9a1..73ae029 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -476,10 +476,10 @@ inferContractionDimsImpl(ArrayRef<AffineMap> indexingMaps,
SmallVector<unsigned, 2>(ac.begin(), ac.end()),
SmallVector<unsigned, 2>(bc.begin(), bc.end()),
SmallVector<unsigned, 2>(ra.begin(), ra.end())};
- llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
- llvm::sort(dimensions.m.begin(), dimensions.m.end());
- llvm::sort(dimensions.n.begin(), dimensions.n.end());
- llvm::sort(dimensions.k.begin(), dimensions.k.end());
+ llvm::sort(dimensions.batch);
+ llvm::sort(dimensions.m);
+ llvm::sort(dimensions.n);
+ llvm::sort(dimensions.k);
return dimensions;
}
@@ -797,12 +797,12 @@ inferConvolutionDimsImpl(LinalgOp linalgOp,
SmallVector<unsigned, 2>(depth.begin(), depth.end()),
/*strides=*/SmallVector<int64_t, 2>{},
/*dilations=*/SmallVector<int64_t, 2>{}};
- llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
- llvm::sort(dimensions.outputImage.begin(), dimensions.outputImage.end());
- llvm::sort(dimensions.outputChannel.begin(), dimensions.outputChannel.end());
- llvm::sort(dimensions.filterLoop.begin(), dimensions.filterLoop.end());
- llvm::sort(dimensions.inputChannel.begin(), dimensions.inputChannel.end());
- llvm::sort(dimensions.depth.begin(), dimensions.depth.end());
+ llvm::sort(dimensions.batch);
+ llvm::sort(dimensions.outputImage);
+ llvm::sort(dimensions.outputChannel);
+ llvm::sort(dimensions.filterLoop);
+ llvm::sort(dimensions.inputChannel);
+ llvm::sort(dimensions.depth);
// Use the op carried strides/dilations attribute if present.
auto nativeStrides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
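
The sort changes here are pure cleanups: llvm::sort from llvm/ADT/STLExtras.h has a range overload, so passing the container directly is equivalent to the iterator-pair form. A tiny sketch:

    #include "llvm/ADT/STLExtras.h"
    #include "llvm/ADT/SmallVector.h"

    void sortDims(llvm::SmallVectorImpl<unsigned> &dims) {
      // Same result as llvm::sort(dims.begin(), dims.end()).
      llvm::sort(dims);
    }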
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b56a212..34c63d3 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2293,9 +2293,39 @@ Speculation::Speculatability BroadcastOp::getSpeculatability() {
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
+/// Fold back-to-back broadcasts together.
+struct FoldBroadcasts : OpRewritePattern<linalg::BroadcastOp> {
+ using OpRewritePattern<linalg::BroadcastOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp,
+ PatternRewriter &rewriter) const override {
+ auto defBroadcastOp = broadcastOp.getInput().getDefiningOp<BroadcastOp>();
+ if (!defBroadcastOp)
+ return failure();
+ ArrayRef<int64_t> defDimensions = defBroadcastOp.getDimensions();
+ ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
+ SmallVector<int64_t> foldedDims(dimensions);
+ Value init = broadcastOp.getInit();
+ int64_t initRank = cast<ShapedType>(init.getType()).getRank();
+ // Mapping from input dims to init dims.
+ SmallVector<int64_t> dimMap;
+ for (auto dim : llvm::seq<int64_t>(0, initRank)) {
+ if (!llvm::is_contained(dimensions, dim))
+ dimMap.push_back(dim);
+ }
+ for (auto dim : defDimensions)
+ foldedDims.push_back(dimMap[dim]);
+
+ llvm::sort(foldedDims);
+ rewriter.replaceOpWithNewOp<BroadcastOp>(
+ broadcastOp, defBroadcastOp.getInput(), init, foldedDims);
+ return success();
+ }
+};
+
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
+ results.add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts>(context);
}
//===----------------------------------------------------------------------===//
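
A worked example of the dimension bookkeeping in FoldBroadcasts: if the inner broadcast adds dim 1 going from rank 2 to rank 3, and the outer one adds dim 2 going from rank 3 to rank 4, then dimMap sends the outer op's three input dims to init dims {0, 1, 3}, so the folded broadcast adds dims {1, 2}. A standalone sketch of just that mapping (hypothetical helper mirroring the pattern above):

    #include "llvm/ADT/ArrayRef.h"
    #include "llvm/ADT/STLExtras.h"
    #include "llvm/ADT/SmallVector.h"

    llvm::SmallVector<int64_t>
    foldBroadcastDims(llvm::ArrayRef<int64_t> defDims, // inner broadcast
                      llvm::ArrayRef<int64_t> dims,    // outer broadcast
                      int64_t initRank) {
      llvm::SmallVector<int64_t> folded(dims);
      llvm::SmallVector<int64_t> dimMap; // outer-op input dim -> init dim
      for (int64_t d = 0; d < initRank; ++d)
        if (!llvm::is_contained(dims, d))
          dimMap.push_back(d);
      for (int64_t d : defDims)
        folded.push_back(dimMap[d]);
      llvm::sort(folded);
      return folded; // defDims={1}, dims={2}, initRank=4 -> {1, 2}
    }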
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
index 2c62cb6..2e62523 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -55,6 +55,28 @@ getFullRankPaddingSizes(Builder &b, ArrayRef<OpFoldResult> indexingSizes,
return paddingSizes;
}
+/// Extracts the constant multiplier from an affine expression of the form
+/// `d * c` or `c * d`, where `d` is an AffineDimExpr and `c` is an
+/// AffineConstantExpr. Returns 1 if the expression is not a simple
+/// multiplication of a dimension and a constant.
+static int64_t extractConstantMultiplier(AffineExpr expr) {
+ if (auto binOp = dyn_cast<AffineBinaryOpExpr>(expr)) {
+ if (binOp.getKind() == AffineExprKind::Mul) {
+ auto lhsD = dyn_cast<AffineDimExpr>(binOp.getLHS());
+ auto rhsC = dyn_cast<AffineConstantExpr>(binOp.getRHS());
+ if (lhsD && rhsC) {
+ return rhsC.getValue();
+ }
+ auto lhsC = dyn_cast<AffineConstantExpr>(binOp.getLHS());
+ auto rhsD = dyn_cast<AffineDimExpr>(binOp.getRHS());
+ if (lhsC && rhsD) {
+ return lhsC.getValue();
+ }
+ }
+ }
+ return 1;
+}
+
/// Compute the padded shape of the given value `v` of `RankedTensorType` given
/// - `indexingSizes` a list of OpFoldResult.
/// - an `indexingMap` that encodes how the shape of `v` varies with increases
@@ -63,6 +85,13 @@ getFullRankPaddingSizes(Builder &b, ArrayRef<OpFoldResult> indexingSizes,
/// The `indexingMap` + `indexingSizes` encoding suits StructuredOps.
/// The implementation below iteratively combines increases from contributing
/// dimensions using affine.apply operations.
+/// The padded shape is computed by evaluating the maximum accessed index per
+/// dimension, which may involve multiplying by constant factors derived from
+/// the affine indexing expressions. Currently, only a limited set of projected
+/// permutation indexing maps are supported, such as
+/// - affine_map<(d0, d1, d2) -> (d0, d1)>
+/// - affine_map<(d0, d1, d2) -> (d0, d1 + d2)>
+/// - affine_map<(d0, d1) -> (d0 * 3 + d1)>
/// In the future, more general interfaces can be devised to encode similar
/// shape evolutions and map between an op and its operands.
SmallVector<OpFoldResult> linalg::computePaddedShape(
@@ -114,24 +143,33 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
/*compressDims=*/true);
// If we are padding to the next multiple of, compose with ceil(sz) * sz.
+ OpFoldResult paddingDimOfr;
if (options.padToMultipleOf) {
AffineExpr d0, s0;
bindDims(rewriter.getContext(), d0);
bindSymbols(rewriter.getContext(), s0);
AffineMap ceilMap = AffineMap::get(1, 1, d0.ceilDiv(s0) * s0);
AffineMap composedMap = projectedMap.compose(ceilMap);
- OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply(
+ paddingDimOfr = affine::makeComposedFoldedAffineApply(
rewriter, loc, composedMap,
{indexingSizes[paddingDim], paddingSize},
/*composeAffineMin=*/true);
- terms.push_back(paddingDimOfr);
} else {
// Otherwise just set to paddingSize.
- OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply(
+ paddingDimOfr = affine::makeComposedFoldedAffineApply(
rewriter, loc, projectedMap, paddingSize);
- terms.push_back(paddingDimOfr);
}
+ // Adjust for the maximum accessed index, which is (paddingSize - 1) *
+ // multiplier.
+ AffineExpr d0;
+ bindDims(rewriter.getContext(), d0);
+ int64_t multiplier = extractConstantMultiplier(projectedMap.getResult(0));
+ AffineMap subtractMap = AffineMap::get(1, 0, d0 - multiplier);
+ OpFoldResult maxAccessIdx = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, subtractMap, {paddingDimOfr});
+ terms.push_back(maxAccessIdx);
+
LLVM_DEBUG(DBGS() << "------new term: " << terms.back() << "\n");
}
@@ -148,8 +186,9 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
AffineExpr sumExpr = dims.front();
for (unsigned i = 1; i < dims.size(); ++i)
sumExpr = sumExpr + dims[i];
- OpFoldResult paddedDimOfr =
- affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, terms);
+ // Add 1 to the maximum accessed index and get the final padded size.
+ OpFoldResult paddedDimOfr = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, sumExpr + 1, terms);
paddedShape[resultIndex] = paddedDimOfr;
}
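A worked instance of the new arithmetic, using the documented map `affine_map<(d0, d1) -> (d0 * 3 + d1)>` with hypothetical sizes (d0 = 4 padded to a multiple of 8, d1 = 5); each term is the maximum index contributed by one dimension, and the final size is their sum plus one:

  constexpr int64_t paddedD0 = ((4 + 8 - 1) / 8) * 8; // ceil(4/8) * 8 == 8
  constexpr int64_t term0 = paddedD0 * 3 - 3;         // multiplier 3 -> 21
  constexpr int64_t term1 = 5 - 1;                    // multiplier 1 -> 4
  static_assert(term0 + term1 + 1 == 26, "padded size along d0 * 3 + d1");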
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index e73bdd3..9d5dfc1 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -2957,6 +2957,23 @@ bool acc::LoopOp::hasDefaultGangWorkerVector() {
getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static);
}
+acc::LoopParMode
+acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) {
+ if (hasSeq(deviceType))
+ return LoopParMode::loop_seq;
+ if (hasAuto(deviceType))
+ return LoopParMode::loop_auto;
+ if (hasIndependent(deviceType))
+ return LoopParMode::loop_independent;
+ if (hasSeq())
+ return LoopParMode::loop_seq;
+ if (hasAuto())
+ return LoopParMode::loop_auto;
+ assert(hasIndependent() &&
+ "loop must have default auto, seq, or independent");
+ return LoopParMode::loop_independent;
+}
+
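A hedged usage sketch of the new query (the device type and `loopOp` are illustrative):

  // Resolve the effective parallelism for one device type, falling back to
  // the loop's default clauses when no device-specific clause matches.
  acc::LoopParMode mode =
      loopOp.getDefaultOrDeviceTypeParallelism(acc::DeviceType::Nvidia);
  if (mode == acc::LoopParMode::loop_seq) {
    // Treat the loop as sequential for this target.
  }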
void acc::LoopOp::addGangOperands(
MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
llvm::ArrayRef<GangArgType> argTypes, mlir::ValueRange values) {
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 9bee200..fcf1526 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -693,7 +693,9 @@ static ParseResult parseStructMemberDecorations(
// `!spirv.struct<` (id `,`)?
// `(`
// (spirv-type (`[` struct-member-decoration `]`)?)*
-// `)>`
+// `)`
+// (`,` struct-decoration)?
+// `>`
static Type parseStructType(SPIRVDialect const &dialect,
DialectAsmParser &parser) {
// TODO: This function is quite lengthy. Break it down into smaller chunks.
@@ -767,17 +769,48 @@ static Type parseStructType(SPIRVDialect const &dialect,
return Type();
}
- if (failed(parser.parseRParen()) || failed(parser.parseGreater()))
+ if (failed(parser.parseRParen()))
+ return Type();
+
+ SmallVector<StructType::StructDecorationInfo, 1> structDecorationInfo;
+
+ auto parseStructDecoration = [&]() {
+ std::optional<spirv::Decoration> decoration =
+ parseAndVerify<spirv::Decoration>(dialect, parser);
+ if (!decoration)
+ return failure();
+
+ // Parse decoration value if it exists.
+ if (succeeded(parser.parseOptionalEqual())) {
+ Attribute decorationValue;
+ if (failed(parser.parseAttribute(decorationValue)))
+ return failure();
+
+ structDecorationInfo.emplace_back(decoration.value(), decorationValue);
+ } else {
+ structDecorationInfo.emplace_back(decoration.value(),
+ UnitAttr::get(dialect.getContext()));
+ }
+ return success();
+ };
+
+ while (succeeded(parser.parseOptionalComma()))
+ if (failed(parseStructDecoration()))
+ return Type();
+
+ if (failed(parser.parseGreater()))
return Type();
if (!identifier.empty()) {
if (failed(idStructTy.trySetBody(memberTypes, offsetInfo,
- memberDecorationInfo)))
+ memberDecorationInfo,
+ structDecorationInfo)))
return Type();
return idStructTy;
}
- return StructType::get(memberTypes, offsetInfo, memberDecorationInfo);
+ return StructType::get(memberTypes, offsetInfo, memberDecorationInfo,
+ structDecorationInfo);
}
// spirv-type ::= array-type
@@ -893,7 +926,23 @@ static void print(StructType type, DialectAsmPrinter &os) {
};
llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
printMember);
- os << ")>";
+ os << ")";
+
+ SmallVector<spirv::StructType::StructDecorationInfo, 1> decorations;
+ type.getStructDecorations(decorations);
+ if (!decorations.empty()) {
+ os << ", ";
+ auto eachFn = [&os](spirv::StructType::StructDecorationInfo decoration) {
+ os << stringifyDecoration(decoration.decoration);
+ if (decoration.hasValue()) {
+ os << "=";
+ os.printAttributeWithoutType(decoration.decorationValue);
+ }
+ };
+ llvm::interleaveComma(decorations, os, eachFn);
+ }
+
+ os << ">";
}
static void print(CooperativeMatrixType type, DialectAsmPrinter &os) {
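With the grammar change above, a struct-level decoration round-trips through the type syntax, e.g. `!spirv.struct<(f32), Block>`. A hedged C++ sketch of building such a type through the extended `StructType::get` (the `StructDecorationInfo` constructor shape is inferred from the parser code above):

  Builder b(&ctx); // assumes an MLIRContext with the SPIR-V dialect loaded
  spirv::StructType::StructDecorationInfo blockDecoration(
      spirv::Decoration::Block, UnitAttr::get(&ctx));
  auto structTy = spirv::StructType::get(
      {b.getF32Type()}, /*offsetInfo=*/{},
      /*memberDecorations=*/{}, /*structDecorations=*/{blockDecoration});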
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 46739bc..ddb3426 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -835,12 +835,14 @@ void SampledImageType::getCapabilities(
/// - for literal structs:
/// - a list of member types;
/// - a list of member offset info;
-/// - a list of member decoration info.
+/// - a list of member decoration info;
+/// - a list of struct decoration info.
///
/// Identified structures only have a mutable component consisting of:
/// - a list of member types;
/// - a list of member offset info;
-/// - a list of member decoration info.
+/// - a list of member decoration info;
+/// - a list of struct decoration info.
struct spirv::detail::StructTypeStorage : public TypeStorage {
/// Construct a storage object for an identified struct type. A struct type
/// associated with such storage must call StructType::trySetBody(...) later
@@ -848,6 +850,7 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
StructTypeStorage(StringRef identifier)
: memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr),
+ numStructDecorations(0), structDecorationsInfo(nullptr),
identifier(identifier) {}
/// Construct a storage object for a literal struct type. A struct type
@@ -855,10 +858,14 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
StructTypeStorage(
unsigned numMembers, Type const *memberTypes,
StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
- StructType::MemberDecorationInfo const *memberDecorationsInfo)
+ StructType::MemberDecorationInfo const *memberDecorationsInfo,
+ unsigned numStructDecorations,
+ StructType::StructDecorationInfo const *structDecorationsInfo)
: memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
numMembers(numMembers), numMemberDecorations(numMemberDecorations),
- memberDecorationsInfo(memberDecorationsInfo) {}
+ memberDecorationsInfo(memberDecorationsInfo),
+ numStructDecorations(numStructDecorations),
+ structDecorationsInfo(structDecorationsInfo) {}
/// A storage key is divided into 2 parts:
/// - for identified structs:
@@ -867,16 +874,19 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
/// - an ArrayRef<Type> for member types;
/// - an ArrayRef<StructType::OffsetInfo> for member offset info;
/// - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
+ /// info;
+ /// - an ArrayRef<StructType::StructDecorationInfo> for struct decoration
/// info.
///
/// An identified struct type is uniqued only by the first part (field 0)
/// of the key.
///
- /// A literal struct type is uniqued only by the second part (fields 1, 2, and
- /// 3) of the key. The identifier field (field 0) must be empty.
+  /// A literal struct type is uniqued only by the second part (fields 1, 2, 3,
+  /// and 4) of the key. The identifier field (field 0) must be empty.
using KeyTy =
std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
- ArrayRef<StructType::MemberDecorationInfo>>;
+ ArrayRef<StructType::MemberDecorationInfo>,
+ ArrayRef<StructType::StructDecorationInfo>>;
/// For identified structs, return true if the given key contains the same
/// identifier.
@@ -890,7 +900,7 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
}
return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
- getMemberDecorationsInfo());
+ getMemberDecorationsInfo(), getStructDecorationsInfo());
}
/// If the given key contains a non-empty identifier, this method constructs
@@ -937,9 +947,17 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
}
- return new (allocator.allocate<StructTypeStorage>())
- StructTypeStorage(keyTypes.size(), typesList, offsetInfoList,
- numMemberDecorations, memberDecorationList);
+ const StructType::StructDecorationInfo *structDecorationList = nullptr;
+ unsigned numStructDecorations = 0;
+ if (!std::get<4>(key).empty()) {
+ auto keyStructDecorations = std::get<4>(key);
+ numStructDecorations = keyStructDecorations.size();
+ structDecorationList = allocator.copyInto(keyStructDecorations).data();
+ }
+
+ return new (allocator.allocate<StructTypeStorage>()) StructTypeStorage(
+ keyTypes.size(), typesList, offsetInfoList, numMemberDecorations,
+ memberDecorationList, numStructDecorations, structDecorationList);
}
ArrayRef<Type> getMemberTypes() const {
@@ -961,6 +979,13 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
return {};
}
+ ArrayRef<StructType::StructDecorationInfo> getStructDecorationsInfo() const {
+ if (structDecorationsInfo)
+ return ArrayRef<StructType::StructDecorationInfo>(structDecorationsInfo,
+ numStructDecorations);
+ return {};
+ }
+
StringRef getIdentifier() const { return identifier; }
bool isIdentified() const { return !identifier.empty(); }
@@ -973,17 +998,19 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
/// - If called for an identified struct whose body was set before (through a
/// call to this method) but with different contents from the passed
/// arguments.
- LogicalResult mutate(
- TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
- ArrayRef<StructType::OffsetInfo> structOffsetInfo,
- ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo) {
+ LogicalResult
+ mutate(TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
+ ArrayRef<StructType::OffsetInfo> structOffsetInfo,
+ ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo,
+ ArrayRef<StructType::StructDecorationInfo> structDecorationInfo) {
if (!isIdentified())
return failure();
if (memberTypesAndIsBodySet.getInt() &&
(getMemberTypes() != structMemberTypes ||
getOffsetInfo() != structOffsetInfo ||
- getMemberDecorationsInfo() != structMemberDecorationInfo))
+ getMemberDecorationsInfo() != structMemberDecorationInfo ||
+ getStructDecorationsInfo() != structDecorationInfo))
return failure();
memberTypesAndIsBodySet.setInt(true);
@@ -1007,6 +1034,11 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
allocator.copyInto(structMemberDecorationInfo).data();
}
+ if (!structDecorationInfo.empty()) {
+ numStructDecorations = structDecorationInfo.size();
+ structDecorationsInfo = allocator.copyInto(structDecorationInfo).data();
+ }
+
return success();
}
@@ -1015,21 +1047,30 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
unsigned numMembers;
unsigned numMemberDecorations;
StructType::MemberDecorationInfo const *memberDecorationsInfo;
+ unsigned numStructDecorations;
+ StructType::StructDecorationInfo const *structDecorationsInfo;
StringRef identifier;
};
StructType
StructType::get(ArrayRef<Type> memberTypes,
ArrayRef<StructType::OffsetInfo> offsetInfo,
- ArrayRef<StructType::MemberDecorationInfo> memberDecorations) {
+ ArrayRef<StructType::MemberDecorationInfo> memberDecorations,
+ ArrayRef<StructType::StructDecorationInfo> structDecorations) {
assert(!memberTypes.empty() && "Struct needs at least one member type");
// Sort the decorations.
- SmallVector<StructType::MemberDecorationInfo, 4> sortedDecorations(
+ SmallVector<StructType::MemberDecorationInfo, 4> sortedMemberDecorations(
memberDecorations);
- llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
+ llvm::array_pod_sort(sortedMemberDecorations.begin(),
+ sortedMemberDecorations.end());
+ SmallVector<StructType::StructDecorationInfo, 1> sortedStructDecorations(
+ structDecorations);
+ llvm::array_pod_sort(sortedStructDecorations.begin(),
+ sortedStructDecorations.end());
+
return Base::get(memberTypes.vec().front().getContext(),
/*identifier=*/StringRef(), memberTypes, offsetInfo,
- sortedDecorations);
+ sortedMemberDecorations, sortedStructDecorations);
}
StructType StructType::getIdentified(MLIRContext *context,
@@ -1039,18 +1080,21 @@ StructType StructType::getIdentified(MLIRContext *context,
return Base::get(context, identifier, ArrayRef<Type>(),
ArrayRef<StructType::OffsetInfo>(),
- ArrayRef<StructType::MemberDecorationInfo>());
+ ArrayRef<StructType::MemberDecorationInfo>(),
+ ArrayRef<StructType::StructDecorationInfo>());
}
StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
StructType newStructType = Base::get(
context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
- ArrayRef<StructType::MemberDecorationInfo>());
+ ArrayRef<StructType::MemberDecorationInfo>(),
+ ArrayRef<StructType::StructDecorationInfo>());
  // Set an empty body in case this is an identified struct.
if (newStructType.isIdentified() &&
failed(newStructType.trySetBody(
ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
- ArrayRef<StructType::MemberDecorationInfo>())))
+ ArrayRef<StructType::MemberDecorationInfo>(),
+ ArrayRef<StructType::StructDecorationInfo>())))
return StructType();
return newStructType;
@@ -1074,6 +1118,15 @@ TypeRange StructType::getElementTypes() const {
bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
+bool StructType::hasDecoration(spirv::Decoration decoration) const {
+ for (StructType::StructDecorationInfo info :
+ getImpl()->getStructDecorationsInfo())
+ if (info.decoration == decoration)
+ return true;
+
+ return false;
+}
+
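The new accessor enables checks such as the following sketch:

  // Query whether the struct carries a given SPIR-V decoration.
  if (structTy.hasDecoration(spirv::Decoration::Block)) {
    // Handle interface-block structs specially.
  }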
uint64_t StructType::getMemberOffset(unsigned index) const {
assert(getNumElements() > index && "member index out of range");
return getImpl()->offsetInfo[index];
@@ -1105,11 +1158,21 @@ void StructType::getMemberDecorations(
}
}
+void StructType::getStructDecorations(
+ SmallVectorImpl<StructType::StructDecorationInfo> &structDecorations)
+ const {
+ structDecorations.clear();
+ auto implDecorations = getImpl()->getStructDecorationsInfo();
+ structDecorations.append(implDecorations.begin(), implDecorations.end());
+}
+
LogicalResult
StructType::trySetBody(ArrayRef<Type> memberTypes,
ArrayRef<OffsetInfo> offsetInfo,
- ArrayRef<MemberDecorationInfo> memberDecorations) {
- return Base::mutate(memberTypes, offsetInfo, memberDecorations);
+ ArrayRef<MemberDecorationInfo> memberDecorations,
+ ArrayRef<StructDecorationInfo> structDecorations) {
+ return Base::mutate(memberTypes, offsetInfo, memberDecorations,
+ structDecorations);
}
void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
@@ -1131,6 +1194,11 @@ llvm::hash_code spirv::hash_value(
memberDecorationInfo.decoration);
}
+llvm::hash_code spirv::hash_value(
+ const StructType::StructDecorationInfo &structDecorationInfo) {
+ return llvm::hash_value(structDecorationInfo.decoration);
+}
+
//===----------------------------------------------------------------------===//
// MatrixType
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index 81365b4..3911ec0 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -58,7 +58,17 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
spirv::PointerType::get(spirv::StructType::get(varType), *storageClass);
}
auto varPtrType = cast<spirv::PointerType>(varType);
- auto varPointeeType = cast<spirv::StructType>(varPtrType.getPointeeType());
+ Type pointeeType = varPtrType.getPointeeType();
+
+  // Images are opaque types, so we can just return a pointer to the image.
+ // Note that currently only sampled images are supported in the SPIR-V
+ // lowering.
+ if (isa<spirv::SampledImageType>(pointeeType))
+ return spirv::GlobalVariableOp::create(builder, funcOp.getLoc(), varType,
+ varName, abiInfo.getDescriptorSet(),
+ abiInfo.getBinding());
+
+ auto varPointeeType = cast<spirv::StructType>(pointeeType);
// Set the offset information.
varPointeeType =
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
index 6a9b951..a53d0a7 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
@@ -174,6 +174,21 @@ void UpdateVCEPass::runOnOperation() {
if (walkResult.wasInterrupted())
return signalPassFailure();
+ // Update min version requirement for capabilities after deducing them.
+ for (spirv::Capability cap : deducedCapabilities) {
+ if (std::optional<spirv::Version> minVersion = spirv::getMinVersion(cap)) {
+ deducedVersion = std::max(deducedVersion, *minVersion);
+ if (deducedVersion > allowedVersion) {
+ module.emitError("Capability '")
+ << spirv::stringifyCapability(cap) << "' requires min version "
+ << spirv::stringifyVersion(deducedVersion)
+ << " but target environment allows up to "
+ << spirv::stringifyVersion(allowedVersion);
+ return signalPassFailure();
+ }
+ }
+ }
+
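For illustration, the same check in isolation (GroupNonUniform's 1.1 minimum comes from the SPIR-V spec; treat the pairing as an example, not part of the patch):

  // Deducing a capability can raise the module's minimum version.
  if (std::optional<spirv::Version> v =
          spirv::getMinVersion(spirv::Capability::GroupNonUniform))
    deducedVersion = std::max(deducedVersion, *v);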
// TODO: verify that the deduced version is consistent with
// SPIR-V ops' maximal version requirements.
diff --git a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
index e5a3b5d..08fccfa 100644
--- a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
+++ b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
@@ -38,7 +38,6 @@
#include <utility>
#define DEBUG_TYPE "shard-ops"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
using namespace mlir;
using namespace mlir::shard;
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 88b0f36..9543fa1 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -464,9 +464,12 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
CheckCondition condition = CheckCondition::invalid;
const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);
+ if (failed(maybeProfDef) && failed(maybeExtDef))
+ return success();
- if (!failed(maybeProfDef) && !failed(maybeExtDef) &&
- !maybeProfDef.value().size() && !maybeExtDef.value().size()) {
+ const bool hasEntry = (succeeded(maybeProfDef) && !maybeProfDef->empty()) ||
+ (succeeded(maybeExtDef) && !maybeExtDef->empty());
+ if (!hasEntry) {
std::string message;
llvm::raw_string_ostream os(message);
os << "illegal: operation operand/result data types did not align with any "
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 48d680c..c707f38 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -25,12 +25,10 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#define DEBUG_TYPE "vector-transfer-opt"
-#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
-
using namespace mlir;
/// Return the ancestor op in the region or nullptr if the region is not
@@ -88,8 +86,7 @@ bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
/// transfer_write is dead if all reads that can be reached from the potentially
/// dead transfer_write are dominated by the overwriting transfer_write.
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
- LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
- << "\n");
+ LDBG() << "Candidate for dead store: " << *write.getOperation();
llvm::SmallVector<Operation *, 8> blockingAccesses;
Operation *firstOverwriteCandidate = nullptr;
Value source = memref::skipViewLikeOps(cast<MemrefValue>(write.getBase()));
@@ -150,13 +147,12 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
!isReachable(writeAncestor, accessAncestor))
continue;
if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
- LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
- << *accessAncestor << "\n");
+ LDBG() << "Store may not be dead due to op: " << *accessAncestor;
return;
}
}
- LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
- << " overwritten by: " << *firstOverwriteCandidate << "\n");
+ LDBG() << "Found dead store: " << *write.getOperation()
+ << " overwritten by: " << *firstOverwriteCandidate;
opToErase.push_back(write.getOperation());
}
@@ -174,8 +170,7 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
if (read.hasOutOfBoundsDim())
return;
- LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
- << "\n");
+ LDBG() << "Candidate for Forwarding: " << *read.getOperation();
SmallVector<Operation *, 8> blockingWrites;
vector::TransferWriteOp lastwrite = nullptr;
Value source = memref::skipViewLikeOps(cast<MemrefValue>(read.getBase()));
@@ -230,14 +225,13 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
continue;
if (!postDominators.postDominates(lastwrite, write)) {
- LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
- << *write << "\n");
+ LDBG() << "Fail to do write to read forwarding due to op: " << *write;
return;
}
}
- LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
- << " to: " << *read.getOperation() << "\n");
+ LDBG() << "Forward value from " << *lastwrite.getOperation()
+ << " to: " << *read.getOperation();
read.replaceAllUsesWith(lastwrite.getVector());
opToErase.push_back(read.getOperation());
}
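The migration from the local DBGS macro to DebugLog.h is mechanical: LDBG() already prefixes DEBUG_TYPE and terminates the line, so call sites drop both the wrapper macro and the trailing newline:

  #include "llvm/Support/DebugLog.h"
  #define DEBUG_TYPE "vector-transfer-opt"
  // Old: LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *op << "\n");
  LDBG() << "Candidate for dead store: " << *op;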
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 8de87fe..7500bf7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -965,6 +965,28 @@ private:
std::function<bool(BitCastOp)> controlFn;
};
+static bool haveSameShapeAndScaling(Type t, Type u) {
+ auto tVec = dyn_cast<VectorType>(t);
+ auto uVec = dyn_cast<VectorType>(u);
+ if (!tVec) {
+ return !uVec;
+ }
+ if (!uVec) {
+ return false;
+ }
+ return tVec.getShape() == uVec.getShape() &&
+ tVec.getScalableDims() == uVec.getScalableDims();
+}
+
+/// If `type` is shaped, clone it with `newElementType`. Otherwise,
+/// return `newElementType`.
+static Type cloneOrReplace(Type type, Type newElementType) {
+ if (auto shapedType = dyn_cast<ShapedType>(type)) {
+ return shapedType.clone(newElementType);
+ }
+ return newElementType;
+}
+
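Hedged examples of the two helpers (a Builder `b` is assumed):

  // cloneOrReplace keeps the container shape but swaps the element type.
  auto v4f32 = VectorType::get({4}, b.getF32Type());
  Type asI1 = cloneOrReplace(v4f32, b.getI1Type());            // vector<4xi1>
  Type scalar = cloneOrReplace(b.getF32Type(), b.getI1Type()); // i1
  // haveSameShapeAndScaling compares shape and scalability only, so a type
  // and its element-type-swapped clone agree.
  assert(haveSameShapeAndScaling(v4f32, asI1));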
/// Reorders elementwise(broadcast/splat) to broadcast(elementwise).
///
/// Example:
@@ -988,16 +1010,14 @@ struct ReorderElementwiseOpsOnBroadcast final
PatternRewriter &rewriter) const override {
if (op->getNumResults() != 1)
return failure();
- if (!llvm::isa<ShapedType>(op->getResults()[0].getType()))
+ auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
+ if (!resultType)
return failure();
if (!OpTrait::hasElementwiseMappableTraits(op))
return rewriter.notifyMatchFailure(
op, "Op doesn't have ElementwiseMappableTraits");
if (op->getNumOperands() == 0)
return failure();
- if (op->getResults()[0].getType() != op->getOperand(0).getType())
- return rewriter.notifyMatchFailure(op,
- "result and operand type mismatch");
if (isa<vector::FMAOp>(op)) {
return rewriter.notifyMatchFailure(
op,
@@ -1005,25 +1025,38 @@ struct ReorderElementwiseOpsOnBroadcast final
"might be a scalar");
}
- // Get the type of the lhs operand
- auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
- if (!lhsBcastOrSplat ||
- !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
+ Type resultElemType = resultType.getElementType();
+ // Get the type of the first non-constant operand
+ Operation *firstBroadcastOrSplat = nullptr;
+ for (Value operand : op->getOperands()) {
+ Operation *definingOp = operand.getDefiningOp();
+ if (!definingOp)
+ return failure();
+ if (definingOp->hasTrait<OpTrait::ConstantLike>())
+ continue;
+ if (!isa<vector::BroadcastOp, vector::SplatOp>(*definingOp))
+ return failure();
+ firstBroadcastOrSplat = definingOp;
+ break;
+ }
+ if (!firstBroadcastOrSplat)
return failure();
- auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
+ Type unbroadcastResultType = cloneOrReplace(
+ firstBroadcastOrSplat->getOperand(0).getType(), resultElemType);
- // Make sure that all operands are broadcast from identical types:
+ // Make sure that all operands are broadcast from identically-shaped types:
// * scalar (`vector.broadcast` + `vector.splat`), or
// * vector (`vector.broadcast`).
// Otherwise the re-ordering wouldn't be safe.
- if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
- auto bcast = val.getDefiningOp<vector::BroadcastOp>();
- if (bcast)
- return (bcast.getOperand().getType() == lhsBcastOrSplatType);
- auto splat = val.getDefiningOp<vector::SplatOp>();
- if (splat)
- return (splat.getOperand().getType() == lhsBcastOrSplatType);
- return false;
+ if (!llvm::all_of(op->getOperands(), [&unbroadcastResultType](Value val) {
+ if (auto bcastOp = val.getDefiningOp<vector::BroadcastOp>())
+ return haveSameShapeAndScaling(bcastOp.getOperand().getType(),
+ unbroadcastResultType);
+ if (auto splatOp = val.getDefiningOp<vector::SplatOp>())
+ return haveSameShapeAndScaling(splatOp.getOperand().getType(),
+ unbroadcastResultType);
+ SplatElementsAttr splatConst;
+ return matchPattern(val, m_Constant(&splatConst));
})) {
return failure();
}
@@ -1032,18 +1065,33 @@ struct ReorderElementwiseOpsOnBroadcast final
SmallVector<Value> srcValues;
srcValues.reserve(op->getNumOperands());
for (Value operand : op->getOperands()) {
- srcValues.push_back(operand.getDefiningOp()->getOperand(0));
+ SplatElementsAttr splatConst;
+ if (matchPattern(operand, m_Constant(&splatConst))) {
+ Attribute newConst;
+ Type elementType = getElementTypeOrSelf(operand.getType());
+ Type newType = cloneOrReplace(unbroadcastResultType, elementType);
+ if (auto newTypeShaped = dyn_cast<ShapedType>(newType)) {
+ newConst = splatConst.resizeSplat(newTypeShaped);
+ } else {
+ newConst = splatConst.getSplatValue<Attribute>();
+ }
+ Operation *newConstOp =
+ operand.getDefiningOp()->getDialect()->materializeConstant(
+ rewriter, newConst, newType, operand.getLoc());
+ srcValues.push_back(newConstOp->getResult(0));
+ } else {
+ srcValues.push_back(operand.getDefiningOp()->getOperand(0));
+ }
}
// Create the "elementwise" Op
Operation *elementwiseOp =
rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
- lhsBcastOrSplatType, op->getAttrs());
+ unbroadcastResultType, op->getAttrs());
// Replace the original Op with the elementwise Op
- auto vectorType = op->getResultTypes()[0];
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
- op, vectorType, elementwiseOp->getResults());
+ op, resultType, elementwiseOp->getResults());
return success();
}
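The practical effect of accepting splat constants is easiest to see on IR (illustrative input/output, not taken from the patch):

  // Before: the dense splat constant used to block the reorder because it is
  // not produced by vector.broadcast or vector.splat.
  //   %b = vector.broadcast %arg : f32 to vector<4xf32>
  //   %c = arith.constant dense<2.0> : vector<4xf32>
  //   %r = arith.mulf %b, %c : vector<4xf32>
  // After: the constant is re-materialized at the unbroadcast type via
  // materializeConstant, and a single broadcast follows the elementwise op.
  //   %c = arith.constant 2.000000e+00 : f32
  //   %e = arith.mulf %arg, %c : f32
  //   %r = vector.broadcast %e : f32 to vector<4xf32>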
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index f95ad29..de52fbd 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -40,7 +40,7 @@
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/Regex.h"
@@ -2070,9 +2070,8 @@ static OpPrintingFlags verifyOpAndAdjustFlags(Operation *op,
return failure();
});
if (failed(verify(op))) {
- LLVM_DEBUG(llvm::dbgs()
- << DEBUG_TYPE << ": '" << op->getName()
- << "' failed to verify and will be printed in generic form\n");
+ LDBG() << op->getName()
+ << "' failed to verify and will be printed in generic form";
printerFlags.printGenericOpForm();
}
diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp
index 3e33795..776b5c6 100644
--- a/mlir/lib/IR/Diagnostics.cpp
+++ b/mlir/lib/IR/Diagnostics.cpp
@@ -821,15 +821,7 @@ SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
for (unsigned i = 0, e = mgr.getNumBuffers(); i != e; ++i)
(void)impl->computeExpectedDiags(out, mgr, mgr.getMemoryBuffer(i + 1));
- // Register a handler to verify the diagnostics.
- setHandler([&](Diagnostic &diag) {
- // Process the main diagnostics.
- process(diag);
-
- // Process each of the notes.
- for (auto &note : diag.getNotes())
- process(note);
- });
+ registerInContext(ctx);
}
SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
@@ -862,6 +854,17 @@ LogicalResult SourceMgrDiagnosticVerifierHandler::verify() {
return impl->status;
}
+void SourceMgrDiagnosticVerifierHandler::registerInContext(MLIRContext *ctx) {
+ ctx->getDiagEngine().registerHandler([&](Diagnostic &diag) {
+ // Process the main diagnostics.
+ process(diag);
+
+ // Process each of the notes.
+ for (auto &note : diag.getNotes())
+ process(note);
+ });
+}
+
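A hedged sketch of the new entry point, which lets the same handler be attached to a context explicitly (the second context is illustrative):

  SourceMgrDiagnosticVerifierHandler handler(sourceMgr, &ctx);
  // The constructor still self-registers in `ctx`; registerInContext makes
  // the same hookup available for another context.
  handler.registerInContext(&otherCtx);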
/// Process a single diagnostic.
void SourceMgrDiagnosticVerifierHandler::process(Diagnostic &diag) {
return process(diag.getLocation(), diag.str(), diag.getSeverity());
diff --git a/mlir/lib/RegisterAllDialects.cpp b/mlir/lib/RegisterAllDialects.cpp
new file mode 100644
index 0000000..7a345ed
--- /dev/null
+++ b/mlir/lib/RegisterAllDialects.cpp
@@ -0,0 +1,207 @@
+//===- RegisterAllDialects.cpp - MLIR Dialects Registration -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a helper to trigger the registration of all dialects and
+// passes to the system.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/InitAllDialects.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/Async/IR/Async.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.h"
+#include "mlir/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/GPU/Transforms/BufferDeallocationOpInterfaceImpl.h"
+#include "mlir/Dialect/IRDL/IR/IRDL.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h"
+#include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h"
+#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
+#include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
+#include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
+#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h"
+#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
+#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
+#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
+#include "mlir/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.h"
+#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/SMT/IR/SMTDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Shard/IR/ShardDialect.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
+#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/Transforms/RuntimeOpVerification.h"
+#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
+#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/Interfaces/CastInterfaces.h"
+#include "mlir/Target/LLVM/NVVM/Target.h"
+#include "mlir/Target/LLVM/ROCDL/Target.h"
+#include "mlir/Target/SPIRV/Target.h"
+
+/// Add all the MLIR dialects to the provided registry.
+void mlir::registerAllDialects(DialectRegistry &registry) {
+ // clang-format off
+ registry.insert<acc::OpenACCDialect,
+ affine::AffineDialect,
+ amdgpu::AMDGPUDialect,
+ amx::AMXDialect,
+ arith::ArithDialect,
+ arm_neon::ArmNeonDialect,
+ arm_sme::ArmSMEDialect,
+ arm_sve::ArmSVEDialect,
+ async::AsyncDialect,
+ bufferization::BufferizationDialect,
+ cf::ControlFlowDialect,
+ complex::ComplexDialect,
+ DLTIDialect,
+ emitc::EmitCDialect,
+ func::FuncDialect,
+ gpu::GPUDialect,
+ index::IndexDialect,
+ irdl::IRDLDialect,
+ linalg::LinalgDialect,
+ LLVM::LLVMDialect,
+ math::MathDialect,
+ memref::MemRefDialect,
+ shard::ShardDialect,
+ ml_program::MLProgramDialect,
+ mpi::MPIDialect,
+ nvgpu::NVGPUDialect,
+ NVVM::NVVMDialect,
+ omp::OpenMPDialect,
+ pdl::PDLDialect,
+ pdl_interp::PDLInterpDialect,
+ ptr::PtrDialect,
+ quant::QuantDialect,
+ ROCDL::ROCDLDialect,
+ scf::SCFDialect,
+ shape::ShapeDialect,
+ smt::SMTDialect,
+ sparse_tensor::SparseTensorDialect,
+ spirv::SPIRVDialect,
+ tensor::TensorDialect,
+ tosa::TosaDialect,
+ transform::TransformDialect,
+ ub::UBDialect,
+ vector::VectorDialect,
+ x86vector::X86VectorDialect,
+ xegpu::XeGPUDialect,
+ xevm::XeVMDialect>();
+ // clang-format on
+
+ // Register all external models.
+ affine::registerValueBoundsOpInterfaceExternalModels(registry);
+ arith::registerBufferDeallocationOpInterfaceExternalModels(registry);
+ arith::registerBufferizableOpInterfaceExternalModels(registry);
+ arith::registerBufferViewFlowOpInterfaceExternalModels(registry);
+ arith::registerShardingInterfaceExternalModels(registry);
+ arith::registerValueBoundsOpInterfaceExternalModels(registry);
+ bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
+ registry);
+ builtin::registerCastOpInterfaceExternalModels(registry);
+ cf::registerBufferizableOpInterfaceExternalModels(registry);
+ cf::registerBufferDeallocationOpInterfaceExternalModels(registry);
+ gpu::registerBufferDeallocationOpInterfaceExternalModels(registry);
+ gpu::registerValueBoundsOpInterfaceExternalModels(registry);
+ LLVM::registerInlinerInterface(registry);
+ NVVM::registerInlinerInterface(registry);
+ linalg::registerAllDialectInterfaceImplementations(registry);
+ linalg::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
+ memref::registerAllocationOpInterfaceExternalModels(registry);
+ memref::registerBufferViewFlowOpInterfaceExternalModels(registry);
+ memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
+ memref::registerValueBoundsOpInterfaceExternalModels(registry);
+ memref::registerMemorySlotExternalModels(registry);
+ ml_program::registerBufferizableOpInterfaceExternalModels(registry);
+ scf::registerBufferDeallocationOpInterfaceExternalModels(registry);
+ scf::registerBufferizableOpInterfaceExternalModels(registry);
+ scf::registerValueBoundsOpInterfaceExternalModels(registry);
+ shape::registerBufferizableOpInterfaceExternalModels(registry);
+ sparse_tensor::registerBufferizableOpInterfaceExternalModels(registry);
+ tensor::registerBufferizableOpInterfaceExternalModels(registry);
+ tensor::registerFindPayloadReplacementOpInterfaceExternalModels(registry);
+ tensor::registerInferTypeOpInterfaceExternalModels(registry);
+ tensor::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
+ tensor::registerSubsetOpInterfaceExternalModels(registry);
+ tensor::registerTilingInterfaceExternalModels(registry);
+ tensor::registerValueBoundsOpInterfaceExternalModels(registry);
+ tosa::registerShardingInterfaceExternalModels(registry);
+ vector::registerBufferizableOpInterfaceExternalModels(registry);
+ vector::registerSubsetOpInterfaceExternalModels(registry);
+ vector::registerValueBoundsOpInterfaceExternalModels(registry);
+ NVVM::registerNVVMTargetInterfaceExternalModels(registry);
+ ROCDL::registerROCDLTargetInterfaceExternalModels(registry);
+ spirv::registerSPIRVTargetInterfaceExternalModels(registry);
+}
+
+/// Append all the MLIR dialects to the registry contained in the given context.
+void mlir::registerAllDialects(MLIRContext &context) {
+ DialectRegistry registry;
+ registerAllDialects(registry);
+ context.appendDialectRegistry(registry);
+}
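Moving the definition out of the header changes nothing for callers; the usual tool setup remains (a sketch):

  DialectRegistry registry;
  mlir::registerAllDialects(registry);
  mlir::registerAllExtensions(registry);
  MLIRContext context(registry);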
diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp
new file mode 100644
index 0000000..8f7c67c
--- /dev/null
+++ b/mlir/lib/RegisterAllExtensions.cpp
@@ -0,0 +1,115 @@
+//===- RegisterAllExtensions.cpp - MLIR Extension Registration --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a helper to trigger the registration of all dialect
+// extensions to the system.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/InitAllExtensions.h"
+
+#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
+#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
+#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
+#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
+#include "mlir/Conversion/FuncToEmitC/FuncToEmitC.h"
+#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
+#include "mlir/Conversion/GPUCommon/GPUToLLVM.h"
+#include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h"
+#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
+#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
+#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
+#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
+#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
+#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
+#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
+#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
+#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
+#include "mlir/Dialect/AMX/Transforms.h"
+#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
+#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
+#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h"
+#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
+#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
+#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
+#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h"
+#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
+#include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h"
+#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
+#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"
+#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
+#include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h"
+#include "mlir/Dialect/Tensor/Extensions/AllExtensions.h"
+#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
+#include "mlir/Dialect/Transform/DebugExtension/DebugExtension.h"
+#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h"
+#include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h"
+#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
+#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"
+#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
+#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
+
+/// This function may be called to register all MLIR dialect extensions with the
+/// provided registry.
+/// If you're building a compiler, you generally shouldn't use this: you would
+/// individually register the specific extensions that are useful for the
+/// pipelines and transformations you are using.
+void mlir::registerAllExtensions(DialectRegistry &registry) {
+ // Register all conversions to LLVM extensions.
+ registerConvertArithToEmitCInterface(registry);
+ arith::registerConvertArithToLLVMInterface(registry);
+ registerConvertComplexToLLVMInterface(registry);
+ cf::registerConvertControlFlowToLLVMInterface(registry);
+ func::registerAllExtensions(registry);
+ tensor::registerAllExtensions(registry);
+ registerConvertFuncToEmitCInterface(registry);
+ registerConvertFuncToLLVMInterface(registry);
+ index::registerConvertIndexToLLVMInterface(registry);
+ registerConvertMathToLLVMInterface(registry);
+ mpi::registerConvertMPIToLLVMInterface(registry);
+ registerConvertMemRefToEmitCInterface(registry);
+ registerConvertMemRefToLLVMInterface(registry);
+ registerConvertNVVMToLLVMInterface(registry);
+ registerConvertOpenMPToLLVMInterface(registry);
+ registerConvertSCFToEmitCInterface(registry);
+ ub::registerConvertUBToLLVMInterface(registry);
+ registerConvertAMXToLLVMInterface(registry);
+ gpu::registerConvertGpuToLLVMInterface(registry);
+ NVVM::registerConvertGpuToNVVMInterface(registry);
+ vector::registerConvertVectorToLLVMInterface(registry);
+ registerConvertXeVMToLLVMInterface(registry);
+
+ // Register all transform dialect extensions.
+ affine::registerTransformDialectExtension(registry);
+ bufferization::registerTransformDialectExtension(registry);
+ dlti::registerTransformDialectExtension(registry);
+ func::registerTransformDialectExtension(registry);
+ gpu::registerTransformDialectExtension(registry);
+ linalg::registerTransformDialectExtension(registry);
+ memref::registerTransformDialectExtension(registry);
+ nvgpu::registerTransformDialectExtension(registry);
+ scf::registerTransformDialectExtension(registry);
+ sparse_tensor::registerTransformDialectExtension(registry);
+ tensor::registerTransformDialectExtension(registry);
+ transform::registerDebugExtension(registry);
+ transform::registerIRDLExtension(registry);
+ transform::registerLoopExtension(registry);
+ transform::registerPDLExtension(registry);
+ transform::registerTuneExtension(registry);
+ vector::registerTransformDialectExtension(registry);
+ arm_neon::registerTransformDialectExtension(registry);
+ arm_sve::registerTransformDialectExtension(registry);
+
+ // Translation extensions need to be registered by calling
+ // `registerAllToLLVMIRTranslations` (see All.h).
+}
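As the closing comment notes, LLVM IR translations are registered through a separate call; the companion snippet is:

  #include "mlir/Target/LLVMIR/Dialect/All.h"
  // Registers the to-LLVM-IR translation interfaces for all dialects.
  mlir::registerAllToLLVMIRTranslations(registry);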
diff --git a/mlir/lib/RegisterAllPasses.cpp b/mlir/lib/RegisterAllPasses.cpp
new file mode 100644
index 0000000..1ed3a37
--- /dev/null
+++ b/mlir/lib/RegisterAllPasses.cpp
@@ -0,0 +1,99 @@
+//===- RegisterAllPasses.cpp - MLIR Registration ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a helper to trigger the registration of all passes to the
+// system.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/InitAllPasses.h"
+
+#include "mlir/Conversion/Passes.h"
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
+#include "mlir/Dialect/Affine/Passes.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Passes.h"
+#include "mlir/Dialect/Async/Passes.h"
+#include "mlir/Dialect/Bufferization/Pipelines/Passes.h"
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
+#include "mlir/Dialect/EmitC/Transforms/Passes.h"
+#include "mlir/Dialect/Func/Transforms/Passes.h"
+#include "mlir/Dialect/GPU/Pipelines/Passes.h"
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/MLProgram/Transforms/Passes.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/NVGPU/Transforms/Passes.h"
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+#include "mlir/Dialect/Quant/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
+#include "mlir/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/Dialect/Shard/Transforms/Passes.h"
+#include "mlir/Dialect/SparseTensor/Pipelines/Passes.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Transform/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Transforms/Passes.h"
+
+// This function may be called to register the MLIR passes with the
+// global registry.
+// If you're building a compiler, you likely don't need this: you would build a
+// pipeline programmatically without the need to register with the global
+// registry, since it would already be calling the creation routine of the
+// individual passes.
+// The global registry is mainly useful for interacting with command-line
+// tools.
+void mlir::registerAllPasses() {
+ // General passes
+ registerTransformsPasses();
+
+ // Conversion passes
+ registerConversionPasses();
+
+ // Dialect passes
+ acc::registerOpenACCPasses();
+ affine::registerAffinePasses();
+ amdgpu::registerAMDGPUPasses();
+ registerAsyncPasses();
+ arith::registerArithPasses();
+ bufferization::registerBufferizationPasses();
+ func::registerFuncPasses();
+ registerGPUPasses();
+ registerLinalgPasses();
+ registerNVGPUPasses();
+ registerSparseTensorPasses();
+ LLVM::registerLLVMPasses();
+ math::registerMathPasses();
+ memref::registerMemRefPasses();
+ shard::registerShardPasses();
+ ml_program::registerMLProgramPasses();
+ quant::registerQuantPasses();
+ registerSCFPasses();
+ registerShapePasses();
+ spirv::registerSPIRVPasses();
+ tensor::registerTensorPasses();
+ tosa::registerTosaOptPasses();
+ transform::registerTransformPasses();
+ vector::registerVectorPasses();
+ arm_sme::registerArmSMEPasses();
+ arm_sve::registerArmSVEPasses();
+ emitc::registerEmitCPasses();
+ xegpu::registerXeGPUPasses();
+
+ // Dialect pipelines
+ bufferization::registerBufferizationPipelines();
+ sparse_tensor::registerSparseTensorPipelines();
+ tosa::registerTosaToLinalgPipelines();
+ gpu::registerGPUToNVVMPipeline();
+}
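These three entry points are what an opt-style driver combines; a minimal sketch (tool name assumed):

  #include "mlir/InitAllDialects.h"
  #include "mlir/InitAllExtensions.h"
  #include "mlir/InitAllPasses.h"
  #include "mlir/Tools/mlir-opt/MlirOptMain.h"

  int main(int argc, char **argv) {
    mlir::registerAllPasses();
    mlir::DialectRegistry registry;
    mlir::registerAllDialects(registry);
    mlir::registerAllExtensions(registry);
    return mlir::asMainReturnCode(
        mlir::MlirOptMain(argc, argv, "example-opt", registry));
  }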
diff --git a/mlir/lib/Support/ToolUtilities.cpp b/mlir/lib/Support/ToolUtilities.cpp
index 748f928..2cf33eb 100644
--- a/mlir/lib/Support/ToolUtilities.cpp
+++ b/mlir/lib/Support/ToolUtilities.cpp
@@ -14,6 +14,8 @@
#include "mlir/Support/LLVM.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
+#include <string>
+#include <utility>
using namespace mlir;
@@ -22,18 +24,18 @@ mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer,
ChunkBufferHandler processChunkBuffer,
raw_ostream &os, llvm::StringRef inputSplitMarker,
llvm::StringRef outputSplitMarker) {
+ llvm::MemoryBufferRef originalBufferRef = originalBuffer->getMemBufferRef();
// If splitting is disabled, we process the full input buffer.
if (inputSplitMarker.empty())
- return processChunkBuffer(std::move(originalBuffer), os);
+ return processChunkBuffer(std::move(originalBuffer), originalBufferRef, os);
const int inputSplitMarkerLen = inputSplitMarker.size();
- auto *origMemBuffer = originalBuffer.get();
SmallVector<StringRef, 8> rawSourceBuffers;
const int checkLen = 2;
// Split dropping the last checkLen chars to enable flagging near misses.
- origMemBuffer->getBuffer().split(rawSourceBuffers,
- inputSplitMarker.drop_back(checkLen));
+ originalBufferRef.getBuffer().split(rawSourceBuffers,
+ inputSplitMarker.drop_back(checkLen));
if (rawSourceBuffers.empty())
return success();
@@ -79,11 +81,17 @@ mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer,
auto interleaveFn = [&](StringRef subBuffer) {
auto splitLoc = SMLoc::getFromPointer(subBuffer.data());
unsigned splitLine = fileSourceMgr.getLineAndColumn(splitLoc).first;
- auto subMemBuffer = llvm::MemoryBuffer::getMemBufferCopy(
- subBuffer, Twine("within split at ") +
- origMemBuffer->getBufferIdentifier() + ":" +
- Twine(splitLine) + " offset ");
- if (failed(processChunkBuffer(std::move(subMemBuffer), os)))
+ std::string name((Twine("within split at ") +
+ originalBufferRef.getBufferIdentifier() + ":" +
+ Twine(splitLine) + " offset ")
+ .str());
+    // Use a MemoryBufferRef to avoid copying the buffer and to keep the chunk
+    // at the same location relative to the original buffer.
+ auto subMemBuffer =
+ llvm::MemoryBuffer::getMemBuffer(llvm::MemoryBufferRef(subBuffer, name),
+ /*RequiresNullTerminator=*/false);
+ if (failed(
+ processChunkBuffer(std::move(subMemBuffer), originalBufferRef, os)))
hadFailure = true;
};
llvm::interleave(sourceBuffers, os, interleaveFn,
@@ -92,3 +100,16 @@ mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer,
// If any fails, then return a failure of the tool.
return failure(hadFailure);
}
+
+LogicalResult
+mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer,
+ NoSourceChunkBufferHandler processChunkBuffer,
+ raw_ostream &os, llvm::StringRef inputSplitMarker,
+ llvm::StringRef outputSplitMarker) {
+ auto process = [&](std::unique_ptr<llvm::MemoryBuffer> chunkBuffer,
+ const llvm::MemoryBufferRef &, raw_ostream &os) {
+ return processChunkBuffer(std::move(chunkBuffer), os);
+ };
+ return splitAndProcessBuffer(std::move(originalBuffer), process, os,
+ inputSplitMarker, outputSplitMarker);
+}
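The new overload keeps the old two-argument handler shape for callers that do not need the original buffer; a usage sketch (marker values assumed):

  auto handler = [](std::unique_ptr<llvm::MemoryBuffer> chunk,
                    llvm::raw_ostream &os) -> mlir::LogicalResult {
    os << chunk->getBuffer();
    return mlir::success();
  };
  (void)mlir::splitAndProcessBuffer(std::move(buffer), handler, llvm::outs(),
                                    /*inputSplitMarker=*/"// -----",
                                    /*outputSplitMarker=*/"// -----");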
diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt
index af22a7f..9ea5c683 100644
--- a/mlir/lib/Target/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt
@@ -60,6 +60,7 @@ add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration
MLIRROCDLToLLVMIRTranslation
MLIRSPIRVToLLVMIRTranslation
MLIRVCIXToLLVMIRTranslation
+ MLIRXeVMToLLVMIRTranslation
)
add_mlir_translation_library(MLIRTargetLLVMIRImport
diff --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
index f030fa7..86c731a 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
@@ -10,3 +10,4 @@ add_subdirectory(OpenMP)
add_subdirectory(ROCDL)
add_subdirectory(SPIRV)
add_subdirectory(VCIX)
+add_subdirectory(XeVM)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index ff34a08..0f675a0 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -13,6 +13,7 @@
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
@@ -136,46 +137,6 @@ convertOperandBundles(OperandRangeRange bundleOperands,
return convertOperandBundles(bundleOperands, *bundleTags, moduleTranslation);
}
-static LogicalResult
-convertParameterAndResultAttrs(mlir::Location loc, ArrayAttr argAttrsArray,
- ArrayAttr resAttrsArray, llvm::CallBase *call,
- LLVM::ModuleTranslation &moduleTranslation) {
- if (argAttrsArray) {
- for (auto [argIdx, argAttrsAttr] : llvm::enumerate(argAttrsArray)) {
- if (auto argAttrs = cast<DictionaryAttr>(argAttrsAttr);
- !argAttrs.empty()) {
- FailureOr<llvm::AttrBuilder> attrBuilder =
- moduleTranslation.convertParameterAttrs(loc, argAttrs);
- if (failed(attrBuilder))
- return failure();
- call->addParamAttrs(argIdx, *attrBuilder);
- }
- }
- }
-
- if (resAttrsArray && resAttrsArray.size() > 0) {
- if (resAttrsArray.size() != 1)
- return mlir::emitError(loc, "llvm.func cannot have multiple results");
- if (auto resAttrs = cast<DictionaryAttr>(resAttrsArray[0]);
- !resAttrs.empty()) {
- FailureOr<llvm::AttrBuilder> attrBuilder =
- moduleTranslation.convertParameterAttrs(loc, resAttrs);
- if (failed(attrBuilder))
- return failure();
- call->addRetAttrs(*attrBuilder);
- }
- }
- return success();
-}
-
-static LogicalResult
-convertParameterAndResultAttrs(CallOpInterface callOp, llvm::CallBase *call,
- LLVM::ModuleTranslation &moduleTranslation) {
- return convertParameterAndResultAttrs(
- callOp.getLoc(), callOp.getArgAttrsAttr(), callOp.getResAttrsAttr(), call,
- moduleTranslation);
-}
-
/// Builder for LLVM_CallIntrinsicOp
static LogicalResult
convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
@@ -243,9 +204,7 @@ convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
convertOperandBundles(op.getOpBundleOperands(), op.getOpBundleTags(),
moduleTranslation));
- if (failed(convertParameterAndResultAttrs(op.getLoc(), op.getArgAttrsAttr(),
- op.getResAttrsAttr(), inst,
- moduleTranslation)))
+ if (failed(moduleTranslation.convertArgAndResultAttrs(op, inst)))
return failure();
if (op.getNumResults() == 1)
@@ -455,7 +414,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
if (callOp.getInlineHintAttr())
call->addFnAttr(llvm::Attribute::InlineHint);
- if (failed(convertParameterAndResultAttrs(callOp, call, moduleTranslation)))
+ if (failed(moduleTranslation.convertArgAndResultAttrs(callOp, call)))
return failure();
if (MemoryEffectsAttr memAttr = callOp.getMemoryEffectsAttr()) {
@@ -569,8 +528,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
operandsRef.drop_front(), opBundles);
}
result->setCallingConv(convertCConvToLLVM(invOp.getCConv()));
- if (failed(
- convertParameterAndResultAttrs(invOp, result, moduleTranslation)))
+ if (failed(moduleTranslation.convertArgAndResultAttrs(invOp, result)))
return failure();
moduleTranslation.mapBranch(invOp, result);
// InvokeOp can only have 0 or 1 result
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp
index 1c9e226..55e73e8 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp
@@ -13,6 +13,7 @@
#include "mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Target/LLVMIR/ModuleImport.h"
+#include "llvm/IR/ConstantRange.h"
using namespace mlir;
using namespace mlir::NVVM;
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 9f18199..49e1e55 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -3877,29 +3877,28 @@ static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
llvm::SmallVector<size_t> indices(indexAttr.size());
std::iota(indices.begin(), indices.end(), 0);
- llvm::sort(indices.begin(), indices.end(),
- [&](const size_t a, const size_t b) {
- auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
- auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
- for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
- int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt();
- int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt();
-
- if (aIndex == bIndex)
- continue;
-
- if (aIndex < bIndex)
- return first;
-
- if (aIndex > bIndex)
- return !first;
- }
-
- // Iterated the up until the end of the smallest member and
- // they were found to be equal up to that point, so select
- // the member with the lowest index count, so the "parent"
- return memberIndicesA.size() < memberIndicesB.size();
- });
+ llvm::sort(indices, [&](const size_t a, const size_t b) {
+ auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
+ auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
+ for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
+ int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt();
+ int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt();
+
+ if (aIndex == bIndex)
+ continue;
+
+ if (aIndex < bIndex)
+ return first;
+
+ if (aIndex > bIndex)
+ return !first;
+ }
+
+    // Iterated up to the end of the smaller member index list and found the
+    // entries equal so far, so select the member with the lower index count,
+    // i.e. the "parent".
+ return memberIndicesA.size() < memberIndicesB.size();
+ });
return llvm::cast<omp::MapInfoOp>(
mapInfo.getMembers()[indices.front()].getDefiningOp());
diff --git a/mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt
new file mode 100644
index 0000000..6308d7e
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt
@@ -0,0 +1,21 @@
+set(LLVM_OPTIONAL_SOURCES
+ XeVMToLLVMIRTranslation.cpp
+)
+
+add_mlir_translation_library(MLIRXeVMToLLVMIRTranslation
+ XeVMToLLVMIRTranslation.cpp
+
+ DEPENDS
+ MLIRXeVMConversionsIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRDialectUtils
+ MLIRIR
+ MLIRLLVMDialect
+ MLIRXeVMDialect
+ MLIRSupport
+ MLIRTargetLLVMIRExport
+)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp
new file mode 100644
index 0000000..73b166d
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp
@@ -0,0 +1,103 @@
+//===-- XeVMToLLVMIRTranslation.cpp - Translate XeVM to LLVM IR -*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation between the MLIR XeVM dialect and
+// LLVM IR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Metadata.h"
+
+#include "llvm/IR/ConstantRange.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+namespace {
+/// Implementation of the dialect interface that converts operations belonging
+/// to the XeVM dialect to LLVM IR.
+class XeVMDialectLLVMIRTranslationInterface
+ : public LLVMTranslationDialectInterface {
+public:
+ using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
+
+  /// Converts XeVM dialect attributes (e.g. cache controls) to LLVM metadata.
+ LogicalResult
+ amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
+ NamedAttribute attribute,
+ LLVM::ModuleTranslation &moduleTranslation) const final {
+ StringRef attrName = attribute.getName().getValue();
+ if (attrName == mlir::xevm::XeVMDialect::getCacheControlsAttrName()) {
+      auto cacheControlsArray = dyn_cast<ArrayAttr>(attribute.getValue());
+      if (!cacheControlsArray || cacheControlsArray.size() != 2) {
+        return op->emitOpError(
+            "expected both L1 and L3 cache control attributes");
+      }
+      if (instructions.size() != 1) {
+        return op->emitOpError("expected a single instruction");
+      }
+ return handleDecorationCacheControl(instructions.front(),
+ cacheControlsArray.getValue());
+ }
+ auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
+ if (!func)
+ return failure();
+
+ return success();
+ }
+
+private:
+ static LogicalResult handleDecorationCacheControl(llvm::Instruction *inst,
+ ArrayRef<Attribute> attrs) {
+ SmallVector<llvm::Metadata *> decorations;
+ llvm::LLVMContext &ctx = inst->getContext();
+ llvm::Type *i32Ty = llvm::IntegerType::getInt32Ty(ctx);
+ llvm::transform(
+ attrs, std::back_inserter(decorations),
+ [&ctx, i32Ty](Attribute attr) -> llvm::Metadata * {
+            auto valuesArray = cast<ArrayAttr>(attr).getValue();
+ std::array<llvm::Metadata *, 4> metadata;
+ llvm::transform(
+ valuesArray, metadata.begin(), [i32Ty](Attribute valueAttr) {
+ return llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(
+ i32Ty, cast<IntegerAttr>(valueAttr).getValue()));
+ });
+ return llvm::MDNode::get(ctx, metadata);
+ });
+ constexpr llvm::StringLiteral decorationCacheControlMDName =
+ "spirv.DecorationCacheControlINTEL";
+ inst->setMetadata(decorationCacheControlMDName,
+ llvm::MDNode::get(ctx, decorations));
+ return success();
+ }
+};
+} // namespace
+
+void mlir::registerXeVMDialectTranslation(::mlir::DialectRegistry &registry) {
+ registry.insert<xevm::XeVMDialect>();
+ registry.addExtension(+[](MLIRContext *ctx, xevm::XeVMDialect *dialect) {
+ dialect->addInterfaces<XeVMDialectLLVMIRTranslationInterface>();
+ });
+}
+
+void mlir::registerXeVMDialectTranslation(::mlir::MLIRContext &context) {
+ DialectRegistry registry;
+ registerXeVMDialectTranslation(registry);
+ context.appendDialectRegistry(registry);
+}
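
As a usage sketch (hypothetical tool code, assuming only the two entry points defined above), wiring the XeVM translation into a context follows the usual dialect-translation pattern:

#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h"

// Hypothetical setup making XeVM ops translatable to LLVM IR.
static void setUpForTranslation(mlir::MLIRContext &context) {
  mlir::DialectRegistry registry;
  mlir::registerXeVMDialectTranslation(registry);
  context.appendDialectRegistry(registry);
  // Equivalently, the context-based overload added above:
  // mlir::registerXeVMDialectTranslation(context);
}
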
diff --git a/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp b/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp
index 580afdd..cb1f234 100644
--- a/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp
+++ b/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp
@@ -33,7 +33,9 @@ LogicalResult mlir::LLVMImportInterface::convertUnregisteredIntrinsic(
SmallVector<Value> mlirOperands;
SmallVector<NamedAttribute> mlirAttrs;
if (failed(moduleImport.convertIntrinsicArguments(
- llvmOperands, llvmOpBundles, false, {}, {}, mlirOperands, mlirAttrs)))
+ llvmOperands, llvmOpBundles, /*requiresOpBundles=*/false,
+ /*immArgPositions=*/{}, /*immArgAttrNames=*/{}, mlirOperands,
+ mlirAttrs)))
return failure();
Type resultType = moduleImport.convertType(inst->getType());
@@ -44,11 +46,7 @@ LogicalResult mlir::LLVMImportInterface::convertUnregisteredIntrinsic(
ValueRange{mlirOperands}, FastmathFlagsAttr{});
moduleImport.setFastmathFlagsAttr(inst, op);
-
- ArrayAttr argsAttr, resAttr;
- moduleImport.convertParameterAttributes(inst, argsAttr, resAttr, builder);
- op.setArgAttrsAttr(argsAttr);
- op.setResAttrsAttr(resAttr);
+ moduleImport.convertArgAndResultAttrs(inst, op);
// Update importer tracking of results.
unsigned numRes = op.getNumResults();
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 58e3c44..a207cce 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -2267,7 +2267,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
// Handle parameter and result attributes unless it's an incompatible
// call.
if (!isIncompatibleCall)
- convertParameterAttributes(callInst, callOp, builder);
+ convertArgAndResultAttrs(callInst, callOp);
return callOp.getOperation();
}();
@@ -2364,7 +2364,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
// Handle parameter and result attributes unless it's an incompatible
// invoke.
if (!isIncompatibleInvoke)
- convertParameterAttributes(invokeInst, invokeOp, builder);
+ convertArgAndResultAttrs(invokeInst, invokeOp);
if (!invokeInst->getType()->isVoidTy())
mapValue(inst, invokeOp.getResults().front());
@@ -2730,11 +2730,10 @@ void ModuleImport::processFunctionAttributes(llvm::Function *func,
}
DictionaryAttr
-ModuleImport::convertParameterAttribute(llvm::AttributeSet llvmParamAttrs,
- OpBuilder &builder) {
+ModuleImport::convertArgOrResultAttrSet(llvm::AttributeSet llvmAttrSet) {
SmallVector<NamedAttribute> paramAttrs;
for (auto [llvmKind, mlirName] : getAttrKindToNameMapping()) {
- auto llvmAttr = llvmParamAttrs.getAttribute(llvmKind);
+ auto llvmAttr = llvmAttrSet.getAttribute(llvmKind);
// Skip attributes that are not attached.
if (!llvmAttr.isValid())
continue;
@@ -2769,13 +2768,12 @@ ModuleImport::convertParameterAttribute(llvm::AttributeSet llvmParamAttrs,
return builder.getDictionaryAttr(paramAttrs);
}
-void ModuleImport::convertParameterAttributes(llvm::Function *func,
- LLVMFuncOp funcOp,
- OpBuilder &builder) {
+void ModuleImport::convertArgAndResultAttrs(llvm::Function *func,
+ LLVMFuncOp funcOp) {
auto llvmAttrs = func->getAttributes();
for (size_t i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
llvm::AttributeSet llvmArgAttrs = llvmAttrs.getParamAttrs(i);
- funcOp.setArgAttrs(i, convertParameterAttribute(llvmArgAttrs, builder));
+ funcOp.setArgAttrs(i, convertArgOrResultAttrSet(llvmArgAttrs));
}
// Convert the result attributes and attach them wrapped in an ArrayAttribute
// to the funcOp.
@@ -2783,17 +2781,23 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func,
if (!llvmResAttr.hasAttributes())
return;
funcOp.setResAttrsAttr(
- builder.getArrayAttr(convertParameterAttribute(llvmResAttr, builder)));
+ builder.getArrayAttr({convertArgOrResultAttrSet(llvmResAttr)}));
}
-void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
- ArrayAttr &argsAttr,
- ArrayAttr &resAttr,
- OpBuilder &builder) {
+void ModuleImport::convertArgAndResultAttrs(
+ llvm::CallBase *call, ArgAndResultAttrsOpInterface attrsOp,
+ ArrayRef<unsigned> immArgPositions) {
+ // Compute the set of immediate argument positions.
+ llvm::SmallDenseSet<unsigned> immArgPositionsSet(immArgPositions.begin(),
+ immArgPositions.end());
+ // Convert the argument attributes and filter out immediate arguments.
llvm::AttributeList llvmAttrs = call->getAttributes();
SmallVector<llvm::AttributeSet> llvmArgAttrsSet;
bool anyArgAttrs = false;
for (size_t i = 0, e = call->arg_size(); i < e; ++i) {
+ // Skip immediate arguments.
+ if (immArgPositionsSet.contains(i))
+ continue;
llvmArgAttrsSet.emplace_back(llvmAttrs.getParamAttrs(i));
if (llvmArgAttrsSet.back().hasAttributes())
anyArgAttrs = true;
@@ -2807,24 +2811,16 @@ void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
if (anyArgAttrs) {
SmallVector<DictionaryAttr> argAttrs;
for (auto &llvmArgAttrs : llvmArgAttrsSet)
- argAttrs.emplace_back(convertParameterAttribute(llvmArgAttrs, builder));
- argsAttr = getArrayAttr(argAttrs);
+ argAttrs.emplace_back(convertArgOrResultAttrSet(llvmArgAttrs));
+ attrsOp.setArgAttrsAttr(getArrayAttr(argAttrs));
}
+ // Convert the result attributes.
llvm::AttributeSet llvmResAttr = llvmAttrs.getRetAttrs();
if (!llvmResAttr.hasAttributes())
return;
- DictionaryAttr resAttrs = convertParameterAttribute(llvmResAttr, builder);
- resAttr = getArrayAttr({resAttrs});
-}
-
-void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
- CallOpInterface callOp,
- OpBuilder &builder) {
- ArrayAttr argsAttr, resAttr;
- convertParameterAttributes(call, argsAttr, resAttr, builder);
- callOp.setArgAttrsAttr(argsAttr);
- callOp.setResAttrsAttr(resAttr);
+ DictionaryAttr resAttrs = convertArgOrResultAttrSet(llvmResAttr);
+ attrsOp.setResAttrsAttr(getArrayAttr({resAttrs}));
}
template <typename Op>
@@ -2892,7 +2888,7 @@ LogicalResult ModuleImport::processFunction(llvm::Function *func) {
builder, loc, func->getName(), functionType,
convertLinkageFromLLVM(func->getLinkage()), dsoLocal, cconv);
- convertParameterAttributes(func, funcOp, builder);
+ convertArgAndResultAttrs(func, funcOp);
if (FlatSymbolRefAttr personality = getPersonalityAsAttr(func))
funcOp.setPersonalityAttr(personality);
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index b997e55..2685b5c9 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1758,6 +1758,48 @@ ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
return attrBuilder;
}
+LogicalResult ModuleTranslation::convertArgAndResultAttrs(
+ ArgAndResultAttrsOpInterface attrsOp, llvm::CallBase *call,
+ ArrayRef<unsigned> immArgPositions) {
+ // Convert the argument attributes.
+ if (ArrayAttr argAttrsArray = attrsOp.getArgAttrsAttr()) {
+ unsigned argAttrIdx = 0;
+ llvm::SmallDenseSet<unsigned> immArgPositionsSet(immArgPositions.begin(),
+ immArgPositions.end());
+ for (unsigned argIdx : llvm::seq<unsigned>(call->arg_size())) {
+ if (argAttrIdx >= argAttrsArray.size())
+ break;
+ // Skip immediate arguments (they have no entries in argAttrsArray).
+ if (immArgPositionsSet.contains(argIdx))
+ continue;
+ // Skip empty argument attributes.
+ auto argAttrs = cast<DictionaryAttr>(argAttrsArray[argAttrIdx++]);
+ if (argAttrs.empty())
+ continue;
+ // Convert and add attributes to the call instruction.
+ FailureOr<llvm::AttrBuilder> attrBuilder =
+ convertParameterAttrs(attrsOp->getLoc(), argAttrs);
+ if (failed(attrBuilder))
+ return failure();
+ call->addParamAttrs(argIdx, *attrBuilder);
+ }
+ }
+
+ // Convert the result attributes.
+ if (ArrayAttr resAttrsArray = attrsOp.getResAttrsAttr()) {
+ if (!resAttrsArray.empty()) {
+ auto resAttrs = cast<DictionaryAttr>(resAttrsArray[0]);
+ FailureOr<llvm::AttrBuilder> attrBuilder =
+ convertParameterAttrs(attrsOp->getLoc(), resAttrs);
+ if (failed(attrBuilder))
+ return failure();
+ call->addRetAttrs(*attrBuilder);
+ }
+ }
+
+ return success();
+}
+
FailureOr<llvm::AttrBuilder>
ModuleTranslation::convertParameterAttrs(Location loc,
DictionaryAttr paramAttrs) {
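
To make the index bookkeeping concrete, a standalone sketch (illustrative only; the helper name is hypothetical) of how entries in argAttrsArray map to LLVM call argument positions once immediate arguments are skipped:

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallVector.h"

// For numArgs = 3 and immArgPositions = {1}, returns {0, 2}: the first
// attribute dictionary applies to argument 0, the second to argument 2,
// mirroring the argIdx/argAttrIdx walk in convertArgAndResultAttrs.
static llvm::SmallVector<unsigned>
attrIndexToArgIndex(unsigned numArgs,
                    llvm::ArrayRef<unsigned> immArgPositions) {
  llvm::SmallDenseSet<unsigned> imm(immArgPositions.begin(),
                                    immArgPositions.end());
  llvm::SmallVector<unsigned> mapping;
  for (unsigned argIdx = 0; argIdx < numArgs; ++argIdx)
    if (!imm.contains(argIdx))
      mapping.push_back(argIdx);
  return mapping;
}
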
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index e5934bb..88931b5 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -347,10 +347,6 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
return emitError(unknownLoc, "OpDecoration with ")
<< decorationName << "needs a single target <id>";
}
- // Block decoration does not affect spirv.struct type, but is still stored
- // for verification.
- // TODO: Update StructType to contain this information since
- // it is needed for many validation rules.
decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
break;
case spirv::Decoration::Location:
@@ -993,7 +989,8 @@ spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) {
if (failed(structType.trySetBody(
deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
- deferredStructIt->memberDecorationsInfo)))
+ deferredStructIt->memberDecorationsInfo,
+ deferredStructIt->structDecorationsInfo)))
return failure();
deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
@@ -1203,24 +1200,37 @@ spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
}
}
+ SmallVector<spirv::StructType::StructDecorationInfo, 0> structDecorationsInfo;
+ if (decorations.count(operands[0])) {
+ NamedAttrList &allDecorations = decorations[operands[0]];
+ for (NamedAttribute &decorationAttr : allDecorations) {
+ std::optional<spirv::Decoration> decoration = spirv::symbolizeDecoration(
+ llvm::convertToCamelFromSnakeCase(decorationAttr.getName(), true));
+      assert(decoration.has_value() && "unknown decoration");
+ structDecorationsInfo.emplace_back(decoration.value(),
+ decorationAttr.getValue());
+ }
+ }
+
uint32_t structID = operands[0];
std::string structIdentifier = nameMap.lookup(structID).str();
if (structIdentifier.empty()) {
assert(unresolvedMemberTypes.empty() &&
"didn't expect unresolved member types");
- typeMap[structID] =
- spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo);
+ typeMap[structID] = spirv::StructType::get(
+ memberTypes, offsetInfo, memberDecorationsInfo, structDecorationsInfo);
} else {
auto structTy = spirv::StructType::getIdentified(context, structIdentifier);
typeMap[structID] = structTy;
if (!unresolvedMemberTypes.empty())
- deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes,
- memberTypes, offsetInfo,
- memberDecorationsInfo});
+ deferredStructTypesInfos.push_back(
+ {structTy, unresolvedMemberTypes, memberTypes, offsetInfo,
+ memberDecorationsInfo, structDecorationsInfo});
else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
- memberDecorationsInfo)))
+ memberDecorationsInfo,
+ structDecorationsInfo)))
return failure();
}
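
For orientation, a hypothetical construction of a struct type carrying a Block decoration through the new StructDecorationInfo (member and offset values are illustrative; only the APIs added in this patch are assumed):

#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVector.h"

// Illustrative: a literal struct with one i32 member, decorated with Block.
static mlir::spirv::StructType makeBlockStruct(mlir::MLIRContext *ctx) {
  mlir::Type i32 = mlir::IntegerType::get(ctx, 32);
  llvm::SmallVector<mlir::spirv::StructType::StructDecorationInfo, 1> decos;
  decos.emplace_back(mlir::spirv::Decoration::Block, mlir::UnitAttr::get(ctx));
  return mlir::spirv::StructType::get({i32}, /*offsetInfo=*/{},
                                      /*memberDecorations=*/{}, decos);
}
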
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 20482bd..db1cc3f 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -95,6 +95,7 @@ struct DeferredStructTypeInfo {
SmallVector<Type, 4> memberTypes;
SmallVector<spirv::StructType::OffsetInfo, 0> offsetInfo;
SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo;
+ SmallVector<spirv::StructType::StructDecorationInfo, 0> structDecorationsInfo;
};
/// A struct that collects the info needed to materialize/emit a
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index a8a2b2e..737f296 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -318,6 +318,7 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
case spirv::Decoration::RestrictPointer:
case spirv::Decoration::NoContraction:
case spirv::Decoration::Constant:
+ case spirv::Decoration::Block:
// For unit attributes and decoration attributes, the args list
// has no values so we do nothing.
if (isa<UnitAttr, DecorationAttr>(attr))
@@ -630,11 +631,16 @@ LogicalResult Serializer::prepareBasicType(
operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass()));
operands.push_back(pointeeTypeID);
+    // TODO: Now that struct decorations are supported, this code may no
+    // longer be necessary. However, it is kept for backwards compatibility.
+    // Ideally, Block decorations should be inserted when converting to SPIR-V.
if (isInterfaceStructPtrType(ptrType)) {
- if (failed(emitDecoration(getTypeID(pointeeStruct),
- spirv::Decoration::Block)))
- return emitError(loc, "cannot decorate ")
- << pointeeStruct << " with Block decoration";
+ auto structType = cast<spirv::StructType>(ptrType.getPointeeType());
+ if (!structType.hasDecoration(spirv::Decoration::Block))
+ if (failed(emitDecoration(getTypeID(pointeeStruct),
+ spirv::Decoration::Block)))
+ return emitError(loc, "cannot decorate ")
+ << pointeeStruct << " with Block decoration";
}
return success();
@@ -704,6 +710,20 @@ LogicalResult Serializer::prepareBasicType(
}
}
+ SmallVector<spirv::StructType::StructDecorationInfo, 1> structDecorations;
+ structType.getStructDecorations(structDecorations);
+
+ for (spirv::StructType::StructDecorationInfo &structDecoration :
+ structDecorations) {
+ if (failed(processDecorationAttr(loc, resultID,
+ structDecoration.decoration,
+ structDecoration.decorationValue))) {
+ return emitError(loc, "cannot decorate struct ")
+ << structType << " with "
+ << stringifyDecoration(structDecoration.decoration);
+ }
+ }
+
typeEnum = spirv::Opcode::OpTypeStruct;
if (structType.isIdentified())
@@ -938,6 +958,25 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
} else {
return 0;
}
+ } else if (isa<spirv::TensorArmType>(constType)) {
+ numberOfConstituents = shapedType.getNumElements();
+ operands.reserve(numberOfConstituents + 2);
+ for (int i = 0; i < numberOfConstituents; ++i) {
+ uint32_t elementID = 0;
+ if (auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
+ elementID =
+ elementType.isInteger(1)
+ ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[i])
+ : prepareConstantInt(loc, attr.getValues<IntegerAttr>()[i]);
+ }
+ if (auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
+ elementID = prepareConstantFp(loc, attr.getValues<FloatAttr>()[i]);
+ }
+ if (!elementID) {
+ return 0;
+ }
+ operands.push_back(elementID);
+ }
} else {
operands.reserve(numberOfConstituents + 2);
for (int i = 0; i < numberOfConstituents; ++i) {
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 8f78590..bdcdaa4 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -508,13 +508,20 @@ performActions(raw_ostream &os,
/// Parses the memory buffer. If successfully, run a series of passes against
/// it and print the result.
-static LogicalResult processBuffer(raw_ostream &os,
- std::unique_ptr<MemoryBuffer> ownedBuffer,
- const MlirOptMainConfig &config,
- DialectRegistry &registry,
- llvm::ThreadPoolInterface *threadPool) {
+static LogicalResult
+processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
+ llvm::MemoryBufferRef sourceBuffer,
+ const MlirOptMainConfig &config, DialectRegistry &registry,
+ SourceMgrDiagnosticVerifierHandler *verifyHandler,
+ llvm::ThreadPoolInterface *threadPool) {
// Tell sourceMgr about this buffer, which is what the parser will pick up.
auto sourceMgr = std::make_shared<SourceMgr>();
+  // Add the original buffer to the source manager so it can be used to
+  // determine diagnostic locations.
+ sourceMgr->AddNewSourceBuffer(
+ llvm::MemoryBuffer::getMemBuffer(sourceBuffer,
+ /*RequiresNullTerminator=*/false),
+ SMLoc());
sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
// Create a context just for the current buffer. Disable threading on creation
@@ -522,6 +529,8 @@ static LogicalResult processBuffer(raw_ostream &os,
MLIRContext context(registry, MLIRContext::Threading::DISABLED);
if (threadPool)
context.setThreadPool(*threadPool);
+ if (verifyHandler)
+ verifyHandler->registerInContext(&context);
StringRef irdlFile = config.getIrdlFile();
if (!irdlFile.empty() && failed(loadIRDLDialects(irdlFile, context)))
@@ -545,17 +554,12 @@ static LogicalResult processBuffer(raw_ostream &os,
return performActions(os, sourceMgr, &context, config);
}
- SourceMgrDiagnosticVerifierHandler sourceMgrHandler(
- *sourceMgr, &context, config.verifyDiagnosticsLevel());
-
// Do any processing requested by command line flags. We don't care whether
// these actions succeed or fail, we only care what diagnostics they produce
// and whether they match our expectations.
(void)performActions(os, sourceMgr, &context, config);
- // Verify the diagnostic handler to make sure that each of the diagnostics
- // matched.
- return sourceMgrHandler.verify();
+ return success();
}
std::pair<std::string, std::string>
@@ -624,14 +628,31 @@ LogicalResult mlir::MlirOptMain(llvm::raw_ostream &outputStream,
if (threadPoolCtx.isMultithreadingEnabled())
threadPool = &threadPoolCtx.getThreadPool();
+ SourceMgr sourceMgr;
+ sourceMgr.AddNewSourceBuffer(
+ llvm::MemoryBuffer::getMemBuffer(buffer->getMemBufferRef(),
+ /*RequiresNullTerminator=*/false),
+ SMLoc());
+  // Note: this creates a verifier handler independent of the flag set, as
+  // internally, if the flag is not set, a new scoped diagnostic handler
+  // would be created that intercepts and verifies the diagnostics.
+ SourceMgrDiagnosticVerifierHandler sourceMgrHandler(
+ sourceMgr, &threadPoolCtx, config.verifyDiagnosticsLevel());
auto chunkFn = [&](std::unique_ptr<MemoryBuffer> chunkBuffer,
- raw_ostream &os) {
- return processBuffer(os, std::move(chunkBuffer), config, registry,
- threadPool);
+ llvm::MemoryBufferRef sourceBuffer, raw_ostream &os) {
+ return processBuffer(
+ os, std::move(chunkBuffer), sourceBuffer, config, registry,
+ config.shouldVerifyDiagnostics() ? &sourceMgrHandler : nullptr,
+ threadPool);
};
- return splitAndProcessBuffer(std::move(buffer), chunkFn, outputStream,
- config.inputSplitMarker(),
- config.outputSplitMarker());
+ LogicalResult status = splitAndProcessBuffer(
+ llvm::MemoryBuffer::getMemBuffer(buffer->getMemBufferRef(),
+ /*RequiresNullTerminator=*/false),
+ chunkFn, outputStream, config.inputSplitMarker(),
+ config.outputSplitMarker());
+ if (config.shouldVerifyDiagnostics() && failed(sourceMgrHandler.verify()))
+ status = failure();
+ return status;
}
LogicalResult mlir::MlirOptMain(int argc, char **argv,
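
The non-owning re-wrap used twice above can be isolated as a small sketch (hypothetical helper; getMemBuffer with RequiresNullTerminator=false creates a view over the same bytes without copying):

#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"

// Hypothetical helper: register the same underlying bytes with a SourceMgr
// without taking ownership of them or requiring a null terminator.
static void addSharedBuffer(llvm::SourceMgr &mgr, llvm::MemoryBufferRef ref) {
  mgr.AddNewSourceBuffer(
      llvm::MemoryBuffer::getMemBuffer(ref, /*RequiresNullTerminator=*/false),
      llvm::SMLoc());
}
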
diff --git a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
index c11cb8d..e1c8afb 100644
--- a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
+++ b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
@@ -135,6 +135,13 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv,
// Processes the memory buffer with a new MLIRContext.
auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
raw_ostream &os) {
+    // Many of the translations expect a null-terminated buffer, but splitting
+    // the buffer does not guarantee null-termination. Make a copy of the
+    // buffer to ensure it is null-terminated.
+ if (!ownedBuffer->getBuffer().ends_with('\0')) {
+ ownedBuffer = llvm::MemoryBuffer::getMemBufferCopy(
+ ownedBuffer->getBuffer(), ownedBuffer->getBufferIdentifier());
+ }
// Temporary buffers for chained translation processing.
std::string dataIn;
std::string dataOut;
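
The copy-on-demand above relies on getMemBufferCopy allocating fresh storage with a trailing '\0'; the same idiom as a minimal sketch (hypothetical helper name):

#include "llvm/Support/MemoryBuffer.h"
#include <memory>

// Hypothetical helper: return a null-terminated copy only when needed.
static std::unique_ptr<llvm::MemoryBuffer>
ensureNullTerminated(std::unique_ptr<llvm::MemoryBuffer> buf) {
  if (buf->getBuffer().ends_with('\0'))
    return buf;
  return llvm::MemoryBuffer::getMemBufferCopy(buf->getBuffer(),
                                              buf->getBufferIdentifier());
}
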
diff --git a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
index bae7c59..ae59f28 100644
--- a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
+++ b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
@@ -2,8 +2,26 @@
// CHECK-DAG: @__ocml_cabs_f32(complex<f32>) -> f32
// CHECK-DAG: @__ocml_cabs_f64(complex<f64>) -> f64
+// CHECK-DAG: @__ocml_carg_f32(complex<f32>) -> f32
+// CHECK-DAG: @__ocml_carg_f64(complex<f64>) -> f64
+// CHECK-DAG: @__ocml_ccos_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_ccos_f64(complex<f64>) -> complex<f64>
// CHECK-DAG: @__ocml_cexp_f32(complex<f32>) -> complex<f32>
// CHECK-DAG: @__ocml_cexp_f64(complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_clog_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_clog_f64(complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_conj_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_conj_f64(complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_cpow_f32(complex<f32>, complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_cpow_f64(complex<f64>, complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_csin_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_csin_f64(complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_csqrt_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_csqrt_f64(complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_ctan_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_ctan_f64(complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_ctanh_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_ctanh_f64(complex<f64>) -> complex<f64>
//CHECK-LABEL: @abs_caller
func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
@@ -15,6 +33,26 @@ func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
return %rf, %rd : f32, f64
}
+//CHECK-LABEL: @angle_caller
+func.func @angle_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
+ // CHECK: %[[AF:.*]] = call @__ocml_carg_f32(%{{.*}})
+ %af = complex.angle %f : complex<f32>
+ // CHECK: %[[AD:.*]] = call @__ocml_carg_f64(%{{.*}})
+ %ad = complex.angle %d : complex<f64>
+ // CHECK: return %[[AF]], %[[AD]]
+ return %af, %ad : f32, f64
+}
+
+//CHECK-LABEL: @cos_caller
+func.func @cos_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+ // CHECK: %[[CF:.*]] = call @__ocml_ccos_f32(%{{.*}})
+ %cf = complex.cos %f : complex<f32>
+ // CHECK: %[[CD:.*]] = call @__ocml_ccos_f64(%{{.*}})
+ %cd = complex.cos %d : complex<f64>
+ // CHECK: return %[[CF]], %[[CD]]
+ return %cf, %cd : complex<f32>, complex<f64>
+}
+
//CHECK-LABEL: @exp_caller
func.func @exp_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
// CHECK: %[[EF:.*]] = call @__ocml_cexp_f32(%{{.*}})
@@ -24,3 +62,73 @@ func.func @exp_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, comp
// CHECK: return %[[EF]], %[[ED]]
return %ef, %ed : complex<f32>, complex<f64>
}
+
+//CHECK-LABEL: @log_caller
+func.func @log_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+ // CHECK: %[[LF:.*]] = call @__ocml_clog_f32(%{{.*}})
+ %lf = complex.log %f : complex<f32>
+ // CHECK: %[[LD:.*]] = call @__ocml_clog_f64(%{{.*}})
+ %ld = complex.log %d : complex<f64>
+ // CHECK: return %[[LF]], %[[LD]]
+ return %lf, %ld : complex<f32>, complex<f64>
+}
+
+//CHECK-LABEL: @conj_caller
+func.func @conj_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+ // CHECK: %[[CF:.*]] = call @__ocml_conj_f32(%{{.*}})
+ %cf2 = complex.conj %f : complex<f32>
+ // CHECK: %[[CD:.*]] = call @__ocml_conj_f64(%{{.*}})
+ %cd2 = complex.conj %d : complex<f64>
+ // CHECK: return %[[CF]], %[[CD]]
+ return %cf2, %cd2 : complex<f32>, complex<f64>
+}
+
+//CHECK-LABEL: @pow_caller
+func.func @pow_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+ // CHECK: %[[PF:.*]] = call @__ocml_cpow_f32(%{{.*}}, %{{.*}})
+ %pf = complex.pow %f, %f : complex<f32>
+ // CHECK: %[[PD:.*]] = call @__ocml_cpow_f64(%{{.*}}, %{{.*}})
+ %pd = complex.pow %d, %d : complex<f64>
+ // CHECK: return %[[PF]], %[[PD]]
+ return %pf, %pd : complex<f32>, complex<f64>
+}
+
+//CHECK-LABEL: @sin_caller
+func.func @sin_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+ // CHECK: %[[SF:.*]] = call @__ocml_csin_f32(%{{.*}})
+ %sf2 = complex.sin %f : complex<f32>
+ // CHECK: %[[SD:.*]] = call @__ocml_csin_f64(%{{.*}})
+ %sd2 = complex.sin %d : complex<f64>
+ // CHECK: return %[[SF]], %[[SD]]
+ return %sf2, %sd2 : complex<f32>, complex<f64>
+}
+
+//CHECK-LABEL: @sqrt_caller
+func.func @sqrt_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+ // CHECK: %[[SF:.*]] = call @__ocml_csqrt_f32(%{{.*}})
+ %sf = complex.sqrt %f : complex<f32>
+ // CHECK: %[[SD:.*]] = call @__ocml_csqrt_f64(%{{.*}})
+ %sd = complex.sqrt %d : complex<f64>
+ // CHECK: return %[[SF]], %[[SD]]
+ return %sf, %sd : complex<f32>, complex<f64>
+}
+
+//CHECK-LABEL: @tan_caller
+func.func @tan_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+ // CHECK: %[[TF:.*]] = call @__ocml_ctan_f32(%{{.*}})
+ %tf2 = complex.tan %f : complex<f32>
+ // CHECK: %[[TD:.*]] = call @__ocml_ctan_f64(%{{.*}})
+ %td2 = complex.tan %d : complex<f64>
+ // CHECK: return %[[TF]], %[[TD]]
+ return %tf2, %td2 : complex<f32>, complex<f64>
+}
+
+//CHECK-LABEL: @tanh_caller
+func.func @tanh_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+ // CHECK: %[[TF:.*]] = call @__ocml_ctanh_f32(%{{.*}})
+ %tf = complex.tanh %f : complex<f32>
+ // CHECK: %[[TD:.*]] = call @__ocml_ctanh_f64(%{{.*}})
+ %td = complex.tanh %d : complex<f64>
+ // CHECK: return %[[TF]], %[[TD]]
+ return %tf, %td : complex<f32>, complex<f64>
+}
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 39a7b1b..5c5f7e8 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1176,6 +1176,52 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>)
// -----
+// CHECK-LABEL: @broadcast_broadcast_fold
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32>
+// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x3xf32>
+// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x4xf32>
+// CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<2x3x4xf32>) dimensions = [1, 2]
+// CHECK-NOT: linalg.broadcast
+// CHECK: return %[[BROADCAST]] : tensor<2x3x4xf32>
+func.func @broadcast_broadcast_fold(%input: tensor<2xf32>,
+ %init1: tensor<2x3xf32>,
+ %init2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+ %broadcast1 = linalg.broadcast
+ ins(%input: tensor<2xf32>)
+ outs(%init1: tensor<2x3xf32>)
+ dimensions = [1]
+ %broadcast2 = linalg.broadcast
+ ins(%broadcast1: tensor<2x3xf32>)
+ outs(%init2: tensor<2x3x4xf32>)
+ dimensions = [2]
+ func.return %broadcast2 : tensor<2x3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast_broadcast_fold
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32>
+// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x4xf32>
+// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x4xf32>
+// CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<2x3x4xf32>) dimensions = [1, 2]
+// CHECK-NOT: linalg.broadcast
+// CHECK: return %[[BROADCAST]] : tensor<2x3x4xf32>
+func.func @broadcast_broadcast_fold(%input: tensor<2xf32>,
+ %init1: tensor<2x4xf32>,
+ %init2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+ %broadcast1 = linalg.broadcast
+ ins(%input: tensor<2xf32>)
+ outs(%init1: tensor<2x4xf32>)
+ dimensions = [1]
+ %broadcast2 = linalg.broadcast
+ ins(%broadcast1: tensor<2x4xf32>)
+ outs(%init2: tensor<2x3x4xf32>)
+ dimensions = [1]
+ func.return %broadcast2 : tensor<2x3x4xf32>
+}
+
+// -----
+
func.func @transpose_1d(%input: tensor<16xf32>,
%init: tensor<16xf32>) -> tensor<16xf32> {
%transpose = linalg.transpose
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir
index 78619b6..981f5dc 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir
@@ -52,22 +52,22 @@ module {
// CHECK-LABEL: @generic
// CHECK-SAME: %[[T0:.*]]: tensor<7x5xf32>,
-// CHECK-SAME: %[[T1:.*]]: tensor<7x11x12xf32>)
- func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x12xf32>) -> tensor<7x11x12xf32> {
+// CHECK-SAME: %[[T1:.*]]: tensor<7x11x11xf32>)
+ func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x11xf32>) -> tensor<7x11x11xf32> {
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.
// CHECK: %[[PAD0:.*]] = tensor.pad %[[T0]] low[0, 0] high[2, 0]
// CHECK: : tensor<7x5xf32> to tensor<9x5xf32>
// CHECK: %[[PAD1:.*]] = tensor.pad %[[T1]] low[0, 0, 0] high[2, 4, 2] {
- // CHECK: : tensor<7x11x12xf32> to tensor<9x15x14xf32>
+ // CHECK: : tensor<7x11x11xf32> to tensor<9x15x13xf32>
// CHECK-NEXT: linalg.generic
- // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 12] [1, 1, 1] : tensor<9x15x14xf32> to tensor<7x11x12xf32>
- %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x12xf32>) {
+ // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 11] [1, 1, 1] : tensor<9x15x13xf32> to tensor<7x11x11xf32>
+ %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x11xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
- } -> tensor<7x11x12xf32>
- return %0 : tensor<7x11x12xf32>
+ } -> tensor<7x11x11xf32>
+ return %0 : tensor<7x11x11xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
@@ -83,7 +83,7 @@ module {
// -----
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3 + 5)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3 + 4)>
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 5)>
#map = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -272,3 +272,136 @@ module attributes {transform.with_named_sequence} {
}
}
+// -----
+
+// CHECK-LABEL: pad_conv
+func.func @pad_conv(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 12]
+ // CHECK: : tensor<1x16x16x4xf32> to tensor<1x16x18x16xf32>
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+ // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0]
+ // CHECK: : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32>
+ // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+ // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32>
+
+ %0 = linalg.conv_2d_nhwc_fhwc
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32>)
+ outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+ return %0 : tensor<1x14x14x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of {
+ padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+ } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 16) * 16 + 2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 16) * 16)>
+
+// CHECK-LABEL: pad_conv_dynamic
+func.func @pad_conv_dynamic(%arg0: tensor<1x16x?x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x?x16xf32>) -> tensor<1x14x?x16xf32> {
+
+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK: %[[D0_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32>
+ // CHECK: %[[D0_1:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x16x?x4xf32>
+ // CHECK: %[[H0:.*]] = affine.apply #[[$MAP0]]()[%[[D0_0]], %[[D0_1]]]
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, %[[H0]], 12]
+ // CHECK: : tensor<1x16x?x4xf32> to tensor<1x16x?x16xf32>
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+ // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+ // CHECK: %[[D1_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32>
+ // CHECK: %[[H1:.*]] = affine.apply #[[$MAP1]]()[%[[D0_0]], %[[D1_0]]]
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, %[[H1]], 0]
+ // CHECK: : tensor<1x14x?x16xf32> to tensor<1x14x?x16xf32>
+ // CHECK: %[[D2_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32>
+ // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+ // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, %[[D2_0]], 16] [1, 1, 1, 1] : tensor<1x14x?x16xf32> to tensor<1x14x?x16xf32>
+
+ %0 = linalg.conv_2d_nhwc_fhwc
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%arg0, %arg1: tensor<1x16x?x4xf32>, tensor<16x3x3x4xf32>)
+ outs(%arg2: tensor<1x14x?x16xf32>) -> tensor<1x14x?x16xf32>
+ return %0 : tensor<1x14x?x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of {
+ padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+ } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: pad_conv_strided
+func.func @pad_conv_strided(%arg0: tensor<1x42x42x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 6, 12]
+ // CHECK: : tensor<1x42x42x4xf32> to tensor<1x42x48x16xf32>
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+ // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0]
+ // CHECK: : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32>
+ // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+ // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32>
+
+ %0 = linalg.conv_2d_nhwc_fhwc
+ {dilations = dense<1> : tensor<2xi64>, strides = dense<3> : tensor<2xi64> }
+ ins(%arg0, %arg1: tensor<1x42x42x4xf32>, tensor<16x3x3x4xf32>)
+ outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+ return %0 : tensor<1x14x14x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of {
+ padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+ } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: pad_conv_dilated
+func.func @pad_conv_dilated(%arg0: tensor<1x18x18x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
+
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 12]
+ // CHECK: : tensor<1x18x18x4xf32> to tensor<1x18x20x16xf32>
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12]
+ // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32>
+ // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0]
+ // CHECK: : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32>
+ // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc
+ // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32>
+
+ %0 = linalg.conv_2d_nhwc_fhwc
+ {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%arg0, %arg1: tensor<1x18x18x4xf32>, tensor<16x3x3x4xf32>)
+ outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
+ return %0 : tensor<1x14x14x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of {
+ padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32]
+ } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
index 26c03ed..f741876 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
@@ -69,22 +69,22 @@ module {
// CHECK-LABEL: @generic
// CHECK-SAME: %[[T0:.*]]: tensor<7x5xf32>,
-// CHECK-SAME: %[[T1:.*]]: tensor<7x11x12xf32>)
- func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x12xf32>) -> tensor<7x11x12xf32> {
+// CHECK-SAME: %[[T1:.*]]: tensor<7x11x11xf32>)
+ func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x11xf32>) -> tensor<7x11x11xf32> {
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.
// CHECK: %[[PAD0:.*]] = tensor.pad %[[T0]] low[0, 0] high[1, 0]
// CHECK: : tensor<7x5xf32> to tensor<8x5xf32>
// CHECK: %[[PAD1:.*]] = tensor.pad %[[T1]] low[0, 0, 0] high[1, 3, 1] {
- // CHECK: : tensor<7x11x12xf32> to tensor<8x14x13xf32>
+ // CHECK: : tensor<7x11x11xf32> to tensor<8x14x12xf32>
// CHECK-NEXT: linalg.generic
- // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 12] [1, 1, 1] : tensor<8x14x13xf32> to tensor<7x11x12xf32>
- %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x12xf32>) {
+ // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 11] [1, 1, 1] : tensor<8x14x12xf32> to tensor<7x11x11xf32>
+ %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x11xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
- } -> tensor<7x11x12xf32>
- return %0 : tensor<7x11x12xf32>
+ } -> tensor<7x11x11xf32>
+ return %0 : tensor<7x11x11xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
@@ -102,7 +102,7 @@ module {
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (-s0 + 8)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (-s0 + 13)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (-s0 + 12)>
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 5)>
#map = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -127,13 +127,13 @@ module {
// CHECK: %[[D2_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<?x11x?xf32>
// CHECK: %[[H2:.*]] = affine.apply #[[$MAP1]]()[%[[D2_0]]]
// CHECK: tensor.pad %{{.*}} low[0, 0, 0] high[%[[H1]], 3, %[[H2]]] {
- // CHECK: : tensor<?x11x?xf32> to tensor<8x14x13xf32>
+ // CHECK: : tensor<?x11x?xf32> to tensor<8x14x12xf32>
//
// CHECK: %[[D0_2:.*]] = tensor.dim %{{.*}}, %[[C0]] : tensor<?x5xf32>
// CHECK: %[[D2_1:.*]] = affine.apply #[[$MAP2]]()[%[[D0_2]]]
- // CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<8x5xf32>) outs(%{{.*}} : tensor<8x14x13xf32>) {
- // CHECK: } -> tensor<8x14x13xf32>
- // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [%[[D0_2]], 11, %[[D2_1]]] [1, 1, 1] : tensor<8x14x13xf32> to tensor<?x11x?xf32>
+ // CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<8x5xf32>) outs(%{{.*}} : tensor<8x14x12xf32>) {
+ // CHECK: } -> tensor<8x14x12xf32>
+ // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [%[[D0_2]], 11, %[[D2_1]]] [1, 1, 1] : tensor<8x14x12xf32> to tensor<?x11x?xf32>
//
%0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<?x5xf32>) outs(%arg1 : tensor<?x11x?xf32>) {
^bb0(%in: f32, %out: f32):
diff --git a/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir
index c3ee892..d7722ea 100644
--- a/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir
@@ -230,18 +230,17 @@ func.func @vectorize_nd_tensor_extract_index_from_tensor(%arg0: tensor<3x3xf32>,
// CHECK-SAME: %[[ARG4:.*]]: tensor<4x7x3x2xf32>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[PV:.*]] = ub.poison : i32
-// CHECK-DAG: %[[CST:.*]] = arith.constant dense<3> : vector<7x2x4x3xindex>
+// CHECK-DAG: %[[CST:.*]] = arith.constant dense<3> : vector<4x3xindex>
// CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<true> : vector<4x7x3x2xi1>
// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<4x7x3x2xf32>
// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], %[[PV]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32>
// CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], %[[PV]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32>
// CHECK: %[[CAST:.*]] = arith.index_cast %[[V0]] : vector<4x3xi32> to vector<4x3xindex>
-// CHECK: %[[B1:.*]] = vector.broadcast %[[CAST]] : vector<4x3xindex> to vector<7x2x4x3xindex>
// CHECK: %[[CAST_1:.*]] = arith.index_cast %[[V1]] : vector<4x3xi32> to vector<4x3xindex>
-// CHECK: %[[B2:.*]] = vector.broadcast %[[CAST_1]] : vector<4x3xindex> to vector<7x2x4x3xindex>
-// CHECK: %[[MULI:.*]] = arith.muli %[[B1]], %[[CST]] : vector<7x2x4x3xindex>
-// CHECK: %[[ADDI:.*]] = arith.addi %[[B2]], %[[MULI]] : vector<7x2x4x3xindex>
-// CHECK: %[[T:.*]] = vector.transpose %[[ADDI]], [2, 0, 3, 1] : vector<7x2x4x3xindex> to vector<4x7x3x2xindex>
+// CHECK: %[[MULI:.*]] = arith.muli %[[CAST]], %[[CST]] : vector<4x3xindex>
+// CHECK: %[[ADDI:.*]] = arith.addi %[[CAST_1]], %[[MULI]] : vector<4x3xindex>
+// CHECK: %[[B:.*]] = vector.broadcast %[[ADDI]] : vector<4x3xindex> to vector<7x2x4x3xindex>
+// CHECK: %[[T:.*]] = vector.transpose %[[B]], [2, 0, 3, 1] : vector<7x2x4x3xindex> to vector<4x7x3x2xindex>
// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]]] [%[[T]]], %[[CST_1]], %[[PASSTHRU]] : tensor<3x3xf32>, vector<4x7x3x2xindex>, vector<4x7x3x2xi1>, vector<4x7x3x2xf32> into vector<4x7x3x2xf32>
// CHECK: vector.transfer_write %[[GATHER]], %[[ARG4]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true]} : vector<4x7x3x2xf32>, tensor<4x7x3x2xf32>
@@ -270,20 +269,16 @@ func.func @vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load(%
// CHECK-SAME: %[[ARG0:.*]]: tensor<8x128x768xf32>
// CHECK-SAME: %[[ARG1:.*]]: index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[CST:.*]] = arith.constant dense<768> : vector<1x8xindex>
-// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<128> : vector<1x8xindex>
// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
-// CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<true> : vector<8x1xi1>
-// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<true> : vector<8x1xi1>
+// CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<[0, 98304, 196608, 294912, 393216, 491520, 589824, 688128]> : vector<8xindex>
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<8x1xf32>
-// CHECK: %[[B1:.*]] = vector.broadcast %[[CST_3]] : vector<8xindex> to vector<1x8xindex>
// CHECK: %[[ADDI_ARG1:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : index
-// CHECK: %[[MULI_1:.*]] = arith.muli %[[B1]], %[[CST_0]] : vector<1x8xindex>
-// CHECK: %[[MULI_2:.*]] = arith.muli %[[MULI_1]], %[[CST]] : vector<1x8xindex>
-// CHECK: %[[T:.*]] = vector.transpose %[[MULI_2]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
+// CHECK: %[[B1:.*]] = vector.broadcast %[[CST_1]] : vector<8xindex> to vector<1x8xindex>
+// CHECK: %[[T:.*]] = vector.transpose %[[B1]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
// CHECK: %[[B2:.*]] = vector.broadcast %[[ADDI_ARG1]] : index to vector<8x1xindex>
// CHECK: %[[ADDI:.*]] = arith.addi %[[B2]], %[[T]] : vector<8x1xindex>
-// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_2]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
+// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_0]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
// CHECK: vector.transfer_write %[[GATHER]], %[[EMPTY]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
// -----
@@ -309,15 +304,13 @@ func.func @index_from_output_column_vector_gather_load(%src: tensor<8x128xf32>)
// CHECK-LABEL: func.func @index_from_output_column_vector_gather_load(
// CHECK-SAME: %[[SRC:.*]]: tensor<8x128xf32>) -> tensor<8x1xf32> {
-// CHECK: %[[C128:.*]] = arith.constant dense<128> : vector<1x8xindex>
+// CHECK: %[[IDX_VEC:.*]] = arith.constant dense<[0, 128, 256, 384, 512, 640, 768, 896]> : vector<8xindex>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[PASS_THRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8x1xi1>
-// CHECK: %[[IDX_VEC:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
// CHECK: %[[OUT:.*]] = tensor.empty() : tensor<8x1xf32>
// CHECK: %[[B:.*]] = vector.broadcast %[[IDX_VEC]] : vector<8xindex> to vector<1x8xindex>
-// CHECK: %[[MUL:.*]] = arith.muli %[[B]], %[[C128]] : vector<1x8xindex>
-// CHECK: %[[TR:.*]] = vector.transpose %[[MUL]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
+// CHECK: %[[TR:.*]] = vector.transpose %[[B]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
// CHECK: %[[GATHER:.*]] = vector.gather %[[SRC]]{{\[}}%[[C0]], %[[C0]]] {{\[}}%[[TR]]], %[[MASK]], %[[PASS_THRU]] : tensor<8x128xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
// CHECK: %[[RES:.*]] = vector.transfer_write %[[GATHER]], %[[OUT]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
// CHECK: return %[[RES]] : tensor<8x1xf32>
@@ -420,12 +413,12 @@ func.func @vectorize_nd_tensor_extract_with_affine_apply_gather(%6: tensor<80x16
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant dense<true> : vector<1x4xi1>
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32>
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_7:.*]] = arith.constant dense<16> : vector<1x4xindex>
+// CHECK-DAG: %[[VAL_7:.*]] = arith.constant dense<16> : vector<4xindex>
// CHECK: %[[VAL_8:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex>
// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : vector<4xindex>
-// CHECK: %[[VAL_10:.*]] = vector.broadcast %[[VAL_9]] : vector<4xindex> to vector<1x4xindex>
-// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_10]], %[[VAL_7]] : vector<1x4xindex>
-// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_7]] : vector<1x4xindex>
+// CHECK: %[[VAL_10:.*]] = arith.muli %[[VAL_9]], %[[VAL_7]] : vector<4xindex>
+// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_7]] : vector<4xindex>
+// CHECK: %[[VAL_12:.*]] = vector.broadcast %[[VAL_11]] : vector<4xindex> to vector<1x4xindex>
// CHECK: %[[VAL_13:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_12]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
// CHECK: %[[VAL_14:.*]] = vector.transfer_write %[[VAL_13]], %[[VAL_2]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
// CHECK: return %[[VAL_14]] : tensor<1x4xf32>
@@ -450,14 +443,12 @@ func.func @vectorize_nd_tensor_extract_with_maxsi_gather(%arg0: tensor<80x16xf32
// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_with_maxsi_gather(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<80x16xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant dense<1264> : vector<1x4xindex>
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[1264, 1265, 1266, 1267]> : vector<4xindex>
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant dense<true> : vector<1x4xi1>
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32>
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_7:.*]] = vector.broadcast %[[VAL_2]] : vector<4xindex> to vector<1x4xindex>
-// CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_3]] : vector<1x4xindex>
-// CHECK: %[[VAL_9:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_8]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK: %[[VAL_9:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_7]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
// CHECK: %[[VAL_10:.*]] = vector.transfer_write %[[VAL_9]], %[[VAL_1]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
// CHECK: return %[[VAL_10]] : tensor<1x4xf32>
// CHECK: }
@@ -519,13 +510,13 @@ func.func @vectorize_reverse_like_tensor_extract(%arg0: tensor<1x2x3xf32>, %arg1
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]
// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]
-// CHECK-DAG: %[[CST:.+]] = arith.constant dense<3> : vector<1x1x3xindex>
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x3xi1>
// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x3xf32>
// CHECK-DAG: %[[INIT_IDX:.+]] = arith.constant dense<[2, 1, 0]> : vector<3xindex>
-// CHECK: %[[T0:.+]] = vector.broadcast %[[ARG2]] : index to vector<1x1x3xindex>
-// CHECK: %[[T1:.+]] = arith.muli %[[T0]], %[[CST]] : vector<1x1x3xindex>
+// CHECK: %[[T0:.+]] = arith.muli %[[ARG2]], %[[C3]] : index
+// CHECK: %[[T1:.+]] = vector.broadcast %[[T0]] : index to vector<1x1x3xindex>
// CHECK: %[[T2:.+]] = vector.broadcast %[[INIT_IDX]]
// CHECK: %[[T3:.+]] = arith.addi %[[T2]], %[[T1]]
// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[T3]]], %[[MASK]], %[[PASSTHRU]]
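Note: the updated checks above all encode one simplification: index arithmetic that used to run on the already-broadcasted n-D vector (or be folded into an n-D splat constant) now runs on the scalar or 1-D operand, with the broadcast and, where needed, the transpose applied last. A minimal sketch of the shape of the rewrite, on hypothetical values %i, %cst, and %c128:

  // Before: multiply on the broadcasted 2-D vector.
  %b0 = vector.broadcast %i : index to vector<1x8xindex>
  %m0 = arith.muli %b0, %cst : vector<1x8xindex>
  // After: multiply on the scalar; broadcast last.
  %m1 = arith.muli %i, %c128 : index
  %b1 = vector.broadcast %m1 : index to vector<1x8xindex>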
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index 5d05a654..6d321af 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -296,6 +296,12 @@ func.func private @struct_type_with_matrix_2(!spirv.struct<(!spirv.matrix<3 x ve
// CHECK: func private @struct_empty(!spirv.struct<()>)
func.func private @struct_empty(!spirv.struct<()>)
+// CHECK: func.func private @struct_block(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>)
+func.func private @struct_block(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>)
+
+// CHECK: func.func private @struct_two_dec(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block, CPacked>)
+func.func private @struct_two_dec(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block, CPacked>)
+
// -----
// expected-error @+1 {{offset specification must be given for all members}}
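Note: the two new round-trip cases above pin down where struct-level decorations print: after the member list, comma-separated, with multiple decorations (here Block and CPacked) kept in order.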
diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
index bd51a07..f3a3218 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
@@ -66,3 +66,27 @@ spirv.module Logical GLSL450 attributes {spirv.target_env = #spirv.target_env<#s
// CHECK: spirv.EntryPoint "GLCompute" [[FN]], [[VAR0]], [[VAR1]]
// CHECK: spirv.ExecutionMode [[FN]] "LocalSize", 32, 1, 1
} // end spirv.module
+
+// -----
+
+module {
+ spirv.module Logical GLSL450 attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Sampled1D], []>, #spirv.resource_limits<>>} {
+ // CHECK-DAG: spirv.GlobalVariable @[[IMAGE_GV:.*]] bind(0, 0) : !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
+ // CHECK: spirv.func @read_image
+ spirv.func @read_image(%arg0: !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) "None" attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [1, 1, 1]>} {
+ // CHECK: %[[IMAGE_ADDR:.*]] = spirv.mlir.addressof @[[IMAGE_GV]] : !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
+ %cst0_i32 = spirv.Constant 0 : i32
+ // CHECK: spirv.Load "UniformConstant" %[[IMAGE_ADDR]]
+ %0 = spirv.Load "UniformConstant" %arg0 : !spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>
+ %1 = spirv.Image %0 : !spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>
+ %2 = spirv.ImageFetch %1, %cst0_i32 : !spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>, i32 -> vector<4xf32>
+ %3 = spirv.CompositeExtract %2[0 : i32] : vector<4xf32>
+ %cst0_i32_0 = spirv.Constant 0 : i32
+ %cst0_i32_1 = spirv.Constant 0 : i32
+ %cst1_i32 = spirv.Constant 1 : i32
+ %4 = spirv.AccessChain %arg1[%cst0_i32_0, %cst0_i32] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
+ spirv.Store "StorageBuffer" %4, %3 : f32
+ spirv.Return
+ }
+ }
+}
diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index 2b23766..8d7f3da 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -178,7 +178,7 @@ spirv.module Logical GLSL450 attributes {
// Vulkan memory model requires SPV_KHR_vulkan_memory_model, which is enabled
// implicitly by v1.5.
-// CHECK: requires #spirv.vce<v1.0, [VulkanMemoryModel], [SPV_KHR_vulkan_memory_model]>
+// CHECK: requires #spirv.vce<v1.5, [VulkanMemoryModel], [SPV_KHR_vulkan_memory_model]>
spirv.module Logical Vulkan attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.5, [Shader, VulkanMemoryModel], []>, #spirv.resource_limits<>>
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index b90d6f5..3bccb32 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -2036,3 +2036,19 @@ func.func @test_scatter_duplicate_indices(%arg0: tensor<2x52x3xf32>, %arg2: tens
%0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi32>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32>
return %0 : tensor<2x52x3xf32>
}
+
+// -----
+
+func.func @test_reduce_all_unsupported_data_types(%arg0: tensor<2x12x11xf32>) -> tensor<1x12x11xf32> {
+ // expected-error@+1 {{'tosa.reduce_all' op illegal: operation operand/result data types did not align with any profile or extension, got (f32,f32), did you mean (i1,i1)?}}
+ %0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<2x12x11xf32>) -> tensor<1x12x11xf32>
+ return %0 : tensor<1x12x11xf32>
+}
+
+// -----
+
+func.func @test_rfft2d(%arg0: tensor<13x8x16xbf16>) -> (tensor<13x8x9xbf16>, tensor<13x8x9xbf16>) {
+ // expected-error@+1 {{'tosa.rfft2d' op illegal: operation operand/result data types did not align with any profile or extension, got (bf16,bf16,bf16), did you mean (f32,f32,f32)?}}
+ %0, %1 = tosa.rfft2d %arg0 : (tensor<13x8x16xbf16>) -> (tensor<13x8x9xbf16>, tensor<13x8x9xbf16>)
+ return %0, %1 : tensor<13x8x9xbf16>, tensor<13x8x9xbf16>
+}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index cbe0056..bf9ed8a 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -48,10 +48,10 @@ func.func @test_add_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tens
// -----
-func.func @test_arithmetic_right_shift_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> {
+func.func @test_arithmetic_right_shift_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi32>, %arg1: tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32> {
// expected-error@+1 {{'tosa.arithmetic_right_shift' op failed level check: operand rank(shape) <= MAX_RANK}}
- %0 = tosa.arithmetic_right_shift %arg0, %arg1 {round = false} : (tensor<1x1x1x1x13x21x3xf32>, tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32>
- return %0 : tensor<1x1x1x1x13x21x3xf32>
+ %0 = tosa.arithmetic_right_shift %arg0, %arg1 {round = false} : (tensor<1x1x1x1x13x21x3xi32>, tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32>
+ return %0 : tensor<1x1x1x1x13x21x3xi32>
}
// -----
diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir
index b826cdc..ef881ba 100644
--- a/mlir/test/Dialect/Vector/vector-sink.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink.mlir
@@ -180,13 +180,14 @@ func.func @negative_not_elementwise() -> vector<2x2xf32> {
// -----
-// The source and the result for arith.cmp have different types - not supported
-
-// CHECK-LABEL: func.func @negative_source_and_result_mismatch
-// CHECK: %[[BROADCAST:.+]] = vector.broadcast
-// CHECK: %[[RETURN:.+]] = arith.cmpf uno, %[[BROADCAST]], %[[BROADCAST]]
-// CHECK: return %[[RETURN]]
-func.func @negative_source_and_result_mismatch(%arg0 : f32, %arg1 : vector<1xf32>) -> vector<1xi1> {
+// The source and the result of arith.cmpf have different element types; the broadcast is now sunk past the comparison.
+
+// CHECK-LABEL: func.func @source_and_result_mismatch(
+// CHECK-SAME: %[[ARG0:.+]]: f32)
+// CHECK: %[[COMPARE:.+]] = arith.cmpf uno, %[[ARG0]], %[[ARG0]]
+// CHECK: %[[BROADCAST:.+]] = vector.broadcast %[[COMPARE]] : i1 to vector<1xi1>
+// CHECK: return %[[BROADCAST]]
+func.func @source_and_result_mismatch(%arg0 : f32) -> vector<1xi1> {
%0 = vector.broadcast %arg0 : f32 to vector<1xf32>
%1 = arith.cmpf uno, %0, %0 : vector<1xf32>
return %1 : vector<1xi1>
@@ -210,6 +211,130 @@ func.func @negative_op_only_supports_vectors(%arg0 : f32) -> vector<1xf32> {
return %1 : vector<1xf32>
}
+// -----
+
+// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const(
+// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> {
+// CHECK: %[[NEW_CST:.*]] = arith.constant 2 : index
+// CHECK: %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[NEW_CST]] : index
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
+// CHECK: return %[[BCAST]] : vector<1x4xindex>
+
+func.func @broadcast_scalar_and_splat_const(%arg0: index) -> vector<1x4xindex> {
+ %0 = vector.broadcast %arg0 : index to vector<1x4xindex>
+ %cst = arith.constant dense<2> : vector<1x4xindex>
+ %2 = arith.addi %0, %cst : vector<1x4xindex>
+ return %2 : vector<1x4xindex>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const_const_first(
+// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> {
+// CHECK: %[[NEW_CST:.*]] = arith.constant 2 : index
+// CHECK: %[[SUB:.*]] = arith.subi %[[NEW_CST]], %[[ARG_0]] : index
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[SUB]] : index to vector<1x4xindex>
+// CHECK: return %[[BCAST]] : vector<1x4xindex>
+
+func.func @broadcast_scalar_and_splat_const_const_first(%arg0: index) -> vector<1x4xindex> {
+ %0 = vector.broadcast %arg0 : index to vector<1x4xindex>
+ %cst = arith.constant dense<2> : vector<1x4xindex>
+ %2 = arith.subi %cst, %0 : vector<1x4xindex>
+ return %2 : vector<1x4xindex>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_vector_and_splat_const(
+// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>) -> vector<3x4xf32> {
+// CHECK: %[[NEW_CST:.*]] = arith.constant dense<2.000000e+00> : vector<4xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[ARG_0]], %[[NEW_CST]] : vector<4xf32>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[MUL]] : vector<4xf32> to vector<3x4xf32>
+// CHECK: return %[[BCAST]] : vector<3x4xf32>
+
+func.func @broadcast_vector_and_splat_const(%arg0: vector<4xf32>) -> vector<3x4xf32> {
+ %0 = vector.broadcast %arg0 : vector<4xf32> to vector<3x4xf32>
+ %cst = arith.constant dense<2.000000e+00> : vector<3x4xf32>
+ %2 = arith.mulf %0, %cst : vector<3x4xf32>
+ return %2 : vector<3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @negative_broadcast_with_non_splat_const(
+// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> {
+// CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : index to vector<1x4xindex>
+// CHECK-DAG: %[[CST:.*]] = arith.constant dense<{{\[}}[0, 1, 2, 3]]> : vector<1x4xindex>
+// CHECK: %[[ADD:.*]] = arith.addi %[[BCAST]], %[[CST]] : vector<1x4xindex>
+// CHECK: return %[[ADD]] : vector<1x4xindex>
+
+func.func @negative_broadcast_with_non_splat_const(%arg0: index) -> vector<1x4xindex> {
+ %0 = vector.broadcast %arg0 : index to vector<1x4xindex>
+ %cst = arith.constant dense<[[0, 1, 2, 3]]> : vector<1x4xindex>
+ %2 = arith.addi %0, %cst : vector<1x4xindex>
+ return %2 : vector<1x4xindex>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_scalar_mixed_type(
+// CHECK-SAME: %[[ARG_0:.*]]: f16) -> vector<1x4xf32> {
+// CHECK: %[[EXTF:.*]] = arith.extf %[[ARG_0]] : f16 to f32
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTF]] : f32 to vector<1x4xf32>
+// CHECK: return %[[BCAST]] : vector<1x4xf32>
+
+func.func @broadcast_scalar_mixed_type(%arg0: f16) -> vector<1x4xf32> {
+ %0 = vector.broadcast %arg0 : f16 to vector<1x4xf16>
+ %1 = arith.extf %0 : vector<1x4xf16> to vector<1x4xf32>
+ return %1 : vector<1x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_vector_mixed_type(
+// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf16>) -> vector<3x4xf32> {
+// CHECK: %[[EXTF:.*]] = arith.extf %[[ARG_0]] : vector<4xf16> to vector<4xf32>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTF]] : vector<4xf32> to vector<3x4xf32>
+// CHECK: return %[[BCAST]] : vector<3x4xf32>
+
+func.func @broadcast_vector_mixed_type(%arg0: vector<4xf16>) -> vector<3x4xf32> {
+ %0 = vector.broadcast %arg0 : vector<4xf16> to vector<3x4xf16>
+ %1 = arith.extf %0 : vector<3x4xf16> to vector<3x4xf32>
+ return %1 : vector<3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const_mixed_type(
+// CHECK-SAME: %[[ARG_0:.*]]: f32) -> vector<1x4xf32> {
+// CHECK: %[[NEW_CST:.*]] = arith.constant 3 : i32
+// CHECK: %[[POW:.*]] = math.fpowi %[[ARG_0]], %[[NEW_CST]] : f32, i32
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[POW]] : f32 to vector<1x4xf32>
+// CHECK: return %[[BCAST]] : vector<1x4xf32>
+
+func.func @broadcast_scalar_and_splat_const_mixed_type(%arg0: f32) -> vector<1x4xf32> {
+ %0 = vector.broadcast %arg0 : f32 to vector<1x4xf32>
+ %cst = arith.constant dense<3> : vector<1x4xi32>
+ %2 = math.fpowi %0, %cst : vector<1x4xf32>, vector<1x4xi32>
+ return %2 : vector<1x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_vector_and_splat_const_mixed_type(
+// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>) -> vector<3x4xf32> {
+// CHECK: %[[NEW_CST:.*]] = arith.constant dense<3> : vector<4xi32>
+// CHECK: %[[POW:.*]] = math.fpowi %[[ARG_0]], %[[NEW_CST]] : vector<4xf32>, vector<4xi32>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[POW]] : vector<4xf32> to vector<3x4xf32>
+// CHECK: return %[[BCAST]] : vector<3x4xf32>
+
+func.func @broadcast_vector_and_splat_const_mixed_type(%arg0: vector<4xf32>) -> vector<3x4xf32> {
+ %0 = vector.broadcast %arg0 : vector<4xf32> to vector<3x4xf32>
+ %cst = arith.constant dense<3> : vector<3x4xi32>
+ %2 = math.fpowi %0, %cst : vector<3x4xf32>, vector<3x4xi32>
+ return %2 : vector<3x4xf32>
+}
+
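Note: a hedged summary of the pattern the new cases above exercise: when an elementwise op mixes a vector.broadcast with a splat constant, the op is rebuilt on the narrower operands (the constant is rematerialized at the narrow type, as the NEW_CST captures show) and a single broadcast is applied to the result; a non-splat constant, as in the negative test, blocks the rewrite, while mixed element types (extf, fpowi) are handled as long as the shapes agree. Schematically, for a hypothetical %x:

  // Before:
  %b = vector.broadcast %x : index to vector<1x4xindex>
  %cst = arith.constant dense<2> : vector<1x4xindex>
  %r = arith.addi %b, %cst : vector<1x4xindex>
  // After:
  %c2 = arith.constant 2 : index
  %s = arith.addi %x, %c2 : index
  %r2 = vector.broadcast %s : index to vector<1x4xindex>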
//===----------------------------------------------------------------------===//
// [Pattern: ReorderCastOpsOnBroadcast]
//
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index d67bdb4..628a485 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -2,122 +2,117 @@
gpu.module @test_round_robin_assignment {
// CHECK-LABEL: create_nd_tdesc
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
- gpu.func @create_nd_tdesc(%src: memref<24x32xf32>) {
- // CHECK-COUNT-12: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<24x32xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) {
+ // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf32>
+ // CHECK-SAME: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK-NOT: xegpu.create_nd_tdesc
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
// CHECK-LABEL: load_nd_tdesc
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
- gpu.func @load_nd_tdesc(%src: memref<24x32xf32>) {
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
- // CHECK-COUNT-12: xegpu.load_nd %{{.*}}
- // CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
- // CHECK-SAME-COUNT-12: -> vector<2x2xf32>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-COUNT-4: xegpu.load_nd %{{.*}}
+ // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-SAME-COUNT-4: -> vector<16x16xf32>
// CHECK-NOT: xegpu.load_nd
%load = xegpu.load_nd %tdesc
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
- -> vector<24x32xf32>
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<256x128xf32>
gpu.return
}
// CHECK-LABEL: store_nd
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
- gpu.func @store_nd(%src: memref<24x32xf32>) {
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
- // CHECK-COUNT-12: xegpu.store_nd %{{.*}}, %{{.*}}
- // CHECK-SAME-COUNT-12: : vector<2x2xf32>, !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @store_nd(%src: memref<256x128xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, %{{.*}}
+ // CHECK-SAME-COUNT-4: : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK-NOT: xegpu.store_nd
%load = xegpu.load_nd %tdesc
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
- -> vector<24x32xf32>
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<256x128xf32>
xegpu.store_nd %load, %tdesc
- : vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
// CHECK-LABEL: update_nd
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
- gpu.func @update_nd(%src: memref<24x32xf32>){
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
- // CHECK-COUNT-12: xegpu.update_nd_offset %{{.*}}, [0, 16]
- // CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @update_nd(%src: memref<256x128xf32>){
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-COUNT-4: xegpu.update_nd_offset %{{.*}}, [0, 16]
+ // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK-NOT: xegpu.update_nd_offset
%update = xegpu.update_nd_offset %tdesc, [0, 16]
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
// CHECK-LABEL: dpas
- // CHECK-SAME: (%[[ARG_0:.*]]: memref<8x8xf32>, %[[ARG_1:.*]]: memref<8x8xf32>, %[[ARG_2:.*]]: memref<8x8xf32>)
- gpu.func @dpas(%a: memref<8x8xf32>, %b: memref<8x8xf32>, %c: memref<8x8xf32>) {
- // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<8x8xf32>
- // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+ // CHECK-SAME: (%[[ARG_0:.*]]: memref<256x128xf16>, %[[ARG_1:.*]]: memref<128x256xf16>)
+ gpu.func @dpas(%a: memref<256x128xf16>, %b: memref<128x256xf16>) {
+ // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf16>
+ // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK-NOT: xegpu.create_nd_tdesc
- // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<8x8xf32>
- // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
- // CHECK-NOT: xegpu.create_nd_tdesc
- // CHECK-COUNT-4: xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}] : memref<8x8xf32>
- // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+ // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<128x256xf16>
+ // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
// CHECK-NOT: xegpu.create_nd_tdesc
// CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}}
- // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
- // CHECK-SAME-COUNT-16: : vector<2x2xf32>, vector<2x2xf32> -> vector<2x2xf32>
+ // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ // CHECK-SAME-COUNT-16: : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
// CHECK-NOT: xegpu.dpas
- %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<8x8xf32>
- -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<256x128xf16>
+ -> !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
%load_a = xegpu.load_nd %tdesc_a
- : !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
- -> vector<8x8xf32>
- %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<8x8xf32>
- -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ : !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<256x128xf16>
+ %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<128x256xf16>
+ -> !xegpu.tensor_desc<128x256xf16, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
%load_b = xegpu.load_nd %tdesc_b
- : !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
- -> vector<8x8xf32>
- %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<8x8xf32>
- -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ : !xegpu.tensor_desc<128x256xf16, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
+ -> vector<128x256xf16>
%dpas = xegpu.dpas %load_a, %load_b
- {layout_result_0 = #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
- : vector<8x8xf32>, vector<8x8xf32> -> vector<8x8xf32>
+ {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>}
+ : vector<256x128xf16>, vector<128x256xf16> -> vector<256x256xf32>
gpu.return
}
// CHECK-LABEL: prefetch_nd_tdesc
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
- gpu.func @prefetch_nd_tdesc(%src: memref<24x32xf32>) {
- // CHECK-COUNT-12: xegpu.prefetch_nd %{{.*}}
- // CHECK-SAME-COUNT-12 : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @prefetch_nd_tdesc(%src: memref<256x128xf32>) {
+ // CHECK-COUNT-4: xegpu.prefetch_nd %{{.*}}
+ // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK-NOT: xegpu.prefetch_nd
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
xegpu.prefetch_nd %tdesc
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
// CHECK-LABEL: broadcast
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
- gpu.func @broadcast(%src: memref<24x1xf32>) {
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
- -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<128x1xf32>
+ gpu.func @broadcast(%src: memref<128x1xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<128x1xf32>
+ -> !xegpu.tensor_desc<128x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
%load = xegpu.load_nd %tdesc
- : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
- -> vector<24x1xf32>
- // CHECK-COUNT-3: vector.broadcast {{.*}}
- // CHECK-SAME-COUNT-3: {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
- // CHECK-SAME-COUNT-3: : vector<2x1xf32> to vector<2x4xf32>
+ : !xegpu.tensor_desc<128x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
+ -> vector<128x1xf32>
+ // CHECK-COUNT-2: vector.broadcast {{.*}}
+ // CHECK-SAME-COUNT-2: {layout_result_0 = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>}
+ // CHECK-SAME-COUNT-2: : vector<16x1xf32> to vector<16x32xf32>
// CHECK-NOT: vector.broadcast
%broadcast = vector.broadcast %load
- {layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 4], lane_layout = [2, 1], lane_data = [1, 1]>}
- : vector<24x1xf32> to vector<24x8xf32>
+ {layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 32], lane_layout = [8, 1], lane_data = [1, 1]>}
+ : vector<128x1xf32> to vector<128x64xf32>
gpu.return
}
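Note: the new CHECK-COUNT figures follow from the round-robin arithmetic: a 256x128 workgroup tile with sg_data = [16, 16] splits into (256/16) x (128/16) = 128 subtiles, and with sg_layout = [8, 4], i.e. 32 subgroups, each subgroup is assigned 128/32 = 4 of them, hence CHECK-COUNT-4; likewise the broadcast case (128x1 with sg_data = [16, 1] and 4 subgroups) gives 8/4 = 2 per subgroup.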
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index d511224..d4b0037 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -4,201 +4,181 @@
//CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)>
gpu.module @test_1_1_assignment {
// CHECK-LABEL: create_nd_tdesc
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
- gpu.func @create_nd_tdesc(%src: memref<24x32xf32>) {
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) {
// CHECK: %[[SGID:.*]] = gpu.subgroup_id
- // CHECK: %[[C12:.*]] = arith.constant 12 : index
- // CHECK: %[[C4:.*]] = arith.constant 4 : index
// CHECK: %[[C8:.*]] = arith.constant 8 : index
+ // CHECK: %[[C32:.*]] = arith.constant 32 : index
+ // CHECK: %[[C4:.*]] = arith.constant 4 : index
+ // CHECK: %[[C32_0:.*]] = arith.constant 32 : index
+ // CHECK: %[[C4_1:.*]] = arith.constant 4 : index
// CHECK: %[[DIV:.*]] = affine.apply #map()[%[[SGID]]]
// CHECK: %[[REM:.*]] = affine.apply #map1()[%[[SGID]]]
- // CHECK: %[[MUL1:.*]] = index.mul %[[DIV]], %[[C12]]
- // CHECK: %[[MUL2:.*]] = index.mul %[[REM]], %[[C8]]
- // CHECK: %[[C24:.*]] = arith.constant 24 : index
- // CHECK: %[[MOD:.*]] = index.remu %[[MUL1]], %[[C24]]
+ // CHECK: %[[MUL1:.*]] = index.mul %[[DIV]], %[[C32]]
+ // CHECK: %[[MUL2:.*]] = index.mul %[[REM]], %[[C32_0]]
// CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: %[[ADD1:.*]] = index.add %[[MOD]], %[[C0]]
- // CHECK: %[[C32:.*]] = arith.constant 32 : index
- // CHECK: %[[MOD1:.*]] = index.remu %[[MUL2]], %[[C32]]
- // CHECK: %[[C0_1:.*]] = arith.constant 0 : index
- // CHECK: %[[ADD2:.*]] = index.add %[[MOD1]], %[[C0_1]]
- // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<24x32xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK: %[[C256:.*]] = arith.constant 256 : index
+ // CHECK: %[[MOD:.*]] = index.remu %[[MUL1]], %[[C256]]
+ // CHECK: %[[C0_2:.*]] = arith.constant 0 : index
+ // CHECK: %[[ADD1:.*]] = index.add %[[MOD]], %[[C0_2]]
+ // CHECK: %[[C0_3:.*]] = arith.constant 0 : index
+ // CHECK: %[[C128:.*]] = arith.constant 128 : index
+ // CHECK: %[[MOD1:.*]] = index.remu %[[MUL2]], %[[C128]]
+ // CHECK: %[[C0_4:.*]] = arith.constant 0 : index
+ // CHECK: %[[ADD2:.*]] = index.add %[[MOD1]], %[[C0_4]]
+ // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<256x128xf32>
+ // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK: gpu.return
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
// CHECK-LABEL: load_nd_tdesc
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
- gpu.func @load_nd_tdesc(%src: memref<24x32xf32>) {
- // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) {
+ // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
+ // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]]
- // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
- // CHECK-SAME: -> vector<12x8xf32>
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-SAME: -> vector<32x32xf32>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
%load = xegpu.load_nd %tdesc
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
- -> vector<24x32xf32>
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<256x128xf32>
gpu.return
}
// CHECK-LABEL: store_nd
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
- gpu.func @store_nd(%src: memref<24x32xf32>) {
- // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @store_nd(%src: memref<256x128xf32>) {
+ // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
+ // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]]
- // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
- // CHECK-SAME: -> vector<12x8xf32>
+ // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ // CHECK-SAME: -> vector<32x32xf32>
// CHECK: xegpu.store_nd %[[LOAD]], %[[TDESC]]
- // CHECK-SAME: : vector<12x8xf32>, !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK-SAME: : vector<32x32xf32>, !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
%load = xegpu.load_nd %tdesc
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
- -> vector<24x32xf32>
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<256x128xf32>
xegpu.store_nd %load, %tdesc
- : vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
// CHECK-LABEL: update_nd
-// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
-gpu.func @update_nd(%src: memref<24x32xf32>){
- // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+gpu.func @update_nd(%src: memref<256x128xf32>){
+ // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
+ // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK: %[[UPDATE:.*]] = xegpu.update_nd_offset %[[TDESC]], [0, 16]
- // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
%update = xegpu.update_nd_offset %tdesc, [0, 16]
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
// CHECK-LABEL: dpas
-// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
-// CHECK-SAME: %[[ARG_1:.*]]: memref<32x24xf32>
-gpu.func @dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
- // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
- // CHECk-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
- // CHECK: %[[LOAD_A:.*]] = xegpu.load_nd %[[TDESC_A]]
- // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
- // CHECK-SAME: -> vector<12x8xf32>
- // CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][{{%.*}}, {{%.*}}] : memref<32x24xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
- // CHECK: %[[LOAD_B:.*]] = xegpu.load_nd %[[TDESC_B]]
- // CHECK-SAME: : !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
- // CHECK-SAME: -> vector<8x12xf32>
- // CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]]
- // CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
- // CHECK-SAME: : vector<12x8xf32>, vector<8x12xf32> -> vector<12x12xf32>
- %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+gpu.func @dpas(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
+ // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x128xf16>, vector<128x16xf16> -> vector<16x16xf32>
+ %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<128x128xf16>
+ -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>>
%load_a = xegpu.load_nd %tdesc_a
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
- -> vector<24x32xf32>
- %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32>
- -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>>
+ : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<128x128xf16>
+ %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<128x128xf16>
+ -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
%load_b = xegpu.load_nd %tdesc_b
- : !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>>
- -> vector<32x24xf32>
+ : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
+ -> vector<128x128xf16>
%dpas = xegpu.dpas %load_a, %load_b
- {layout_result_0 = #xegpu.layout<sg_layout = [2, 2], sg_data = [12, 12], lane_layout = [2, 2], lane_data = [1, 1]>}
- : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
+ {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>}
+ : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
gpu.return
}
// CHECK-LABEL: dpas_no_sg_data
-// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
-// CHECK-SAME: %[[ARG_1:.*]]: memref<32x24xf32>
-gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
- // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
- // CHECk-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
- // CHECK: %[[LOAD_A:.*]] = xegpu.load_nd %[[TDESC_A]]
- // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
- // CHECK-SAME: -> vector<12x8xf32>
- // CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][{{%.*}}, {{%.*}}] : memref<32x24xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
- // CHECK: %[[LOAD_B:.*]] = xegpu.load_nd %[[TDESC_B]]
- // CHECK-SAME: : !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
- // CHECK-SAME: -> vector<8x12xf32>
- // CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]]
- // CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
- // CHECK-SAME: : vector<12x8xf32>, vector<8x12xf32> -> vector<12x12xf32>
- %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], lane_layout = [2, 8], lane_data = [1, 1]>>
+gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
+ // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
+ %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<128x128xf16>
+ -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1],
+ order = [1, 0]>>
%load_a = xegpu.load_nd %tdesc_a
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], lane_layout = [2, 8], lane_data = [1, 1]>>
- -> vector<24x32xf32>
- %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32>
- -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], lane_layout = [8, 2], lane_data = [1, 1]>>
+ : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1],
+ order = [1, 0]>>
+ -> vector<128x128xf16>
+ %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<128x128xf16>
+ -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1],
+ order = [1, 0]>>
%load_b = xegpu.load_nd %tdesc_b
- : !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], lane_layout = [8, 2], lane_data = [1, 1]>>
- -> vector<32x24xf32>
+ : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1],
+ order = [1, 0]>>
+ -> vector<128x128xf16>
%dpas = xegpu.dpas %load_a, %load_b
- {layout_result_0 = #xegpu.layout<sg_layout = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
- : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
+ {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>}
+ : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
gpu.return
}
// CHECK-LABEL: prefetch_nd_tdesc
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
- gpu.func @prefetch_nd_tdesc(%src: memref<24x32xf32>) {
- // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ gpu.func @prefetch_nd_tdesc(%src: memref<256x128xf32>) {
+ // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
+ // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK: xegpu.prefetch_nd %[[TDESC]]
- // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
- -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
xegpu.prefetch_nd %tdesc
- : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
// CHECK-LABEL: dpas_with_no_create_nd_desc
- gpu.func @dpas_with_no_create_nd_desc(%a: vector<24x32xf32>, %b: vector<32x24xf32>) {
- // CHECK-NOT: vector<12x12xf32>
+ gpu.func @dpas_with_no_create_nd_desc(%a: vector<256x128xf32>, %b: vector<128x256xf32>) {
+ // CHECK-NOT: vector<32x32xf32>
%dpas = xegpu.dpas %a, %b
{layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [12, 12], lane_layout = [2, 2], lane_data = [1, 1]>}
- : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
+ : vector<256x128xf32>, vector<128x256xf32> -> vector<256x256xf32>
gpu.return
}
// CHECK-LABEL: broadcast_dim1
- // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
- gpu.func @broadcast_dim1(%src: memref<24x1xf32>) {
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
- -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x1xf32>
+ gpu.func @broadcast_dim1(%src: memref<256x1xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x1xf32>
+ -> !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
%load = xegpu.load_nd %tdesc
- : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
- -> vector<24x1xf32>
- // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
- // CHECK-SAME: : vector<12x1xf32> to vector<12x8xf32>
- %broadcast = vector.broadcast %load
- {layout_result_0 = #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 8], lane_layout = [2, 1], lane_data = [1, 1]>}
- : vector<24x1xf32> to vector<24x8xf32>
+ : !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
+ -> vector<256x1xf32>
+ // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>}
+ // CHECK-SAME: : vector<32x1xf32> to vector<32x32xf32>
+ %broadcast = vector.broadcast %load
+ {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 32], lane_layout = [8, 1], lane_data = [1, 1]>}
+ : vector<256x1xf32> to vector<256x32xf32>
gpu.return
}
// CHECK-LABEL: broadcast_dim0
- // CHECK-SAME: %[[ARG_0:.*]]: memref<1x32xf32>
- gpu.func @broadcast_dim0(%src: memref<1x32xf32>) {
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<1x32xf32>
- -> !xegpu.tensor_desc<1x32xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 8], lane_layout = [1, 8], lane_data = [1, 1]>>
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<1x128xf32>
+ gpu.func @broadcast_dim0(%src: memref<1x128xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<1x128xf32>
+ -> !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
%load = xegpu.load_nd %tdesc
- : !xegpu.tensor_desc<1x32xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 8], lane_layout = [1, 8], lane_data = [1, 1]>>
- -> vector<1x32xf32>
- // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 8], lane_data = [1, 1]>}
- // CHECK-SAME: : vector<1x8xf32> to vector<12x8xf32>
+ : !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<1x128xf32>
+ // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ // CHECK-SAME: : vector<1x32xf32> to vector<32x32xf32>
%broadcast = vector.broadcast %load
- {layout_result_0 = #xegpu.layout<sg_layout = [1, 4], sg_data = [12, 8], lane_layout = [1, 8], lane_data = [1, 1]>}
- : vector<1x32xf32> to vector<12x32xf32>
+ {layout_result_0 = #xegpu.layout<sg_layout = [1, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>}
+ : vector<1x128xf32> to vector<32x128xf32>
gpu.return
}
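Note: in the 1:1 assignment tests the tiling is exact by construction: 256x128 with sg_layout = [8, 4] and sg_data = [32, 32] yields (256/32) x (128/32) = 8 x 4 subtiles, one per subgroup, which is why each op is matched once rather than with CHECK-COUNT.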
diff --git a/mlir/test/IR/diagnostic-nosplit.mlir b/mlir/test/IR/diagnostic-nosplit.mlir
new file mode 100644
index 0000000..ecfb9c6
--- /dev/null
+++ b/mlir/test/IR/diagnostic-nosplit.mlir
@@ -0,0 +1,13 @@
+// RUN: not mlir-opt %s -o - --split-input-file 2>&1 | FileCheck %s
+// This test verifies that the diagnostic handler reports locations in the
+// original buffer rather than renumbering them per split chunk.
+
+
+// -----
+
+
+
+func.func @constant_out_of_range() {
+ // CHECK: mlir:11:8: error: 'arith.constant'
+ %x = "arith.constant"() {value = 100} : () -> i1
+ return
+}
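Note: the line number in the CHECK is the position of the invalid op in the original, unsplit buffer: counting the RUN line, the comment, the blank lines, and the // ----- marker puts the "arith.constant" on line 11, so matching mlir:11:8 confirms the handler did not renumber the location relative to the split chunk.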
diff --git a/mlir/test/IR/top-level.mlir b/mlir/test/IR/top-level.mlir
index b571d94..e0adb4d82 100644
--- a/mlir/test/IR/top-level.mlir
+++ b/mlir/test/IR/top-level.mlir
@@ -6,10 +6,10 @@ func.func private @foo()
// -----
-// expected-error@-3 {{source must contain a single top-level operation, found: 2}}
+// expected-error@-9 {{source must contain a single top-level operation, found: 2}}
func.func private @bar()
func.func private @baz()
// -----
-// expected-error@-3 {{source must contain a single top-level operation, found: 0}}
+// expected-error@-15 {{source must contain a single top-level operation, found: 0}}
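Note: a hedged reading of the larger offsets: with locations now reported against the whole file rather than each split section, the reported error line sits further above each expected-error comment, so the relative offsets grow from -3 to -9 and -15.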
diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
index 24380b5..a419d75 100644
--- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll
+++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
@@ -570,10 +570,10 @@ define void @trap_intrinsics() {
; CHECK-LABEL: llvm.func @memcpy_test
define void @memcpy_test(i32 %0, ptr %1, ptr %2) {
- ; CHECK: "llvm.intr.memcpy"(%{{.*}}, %{{.*}}, %{{.*}}) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
- call void @llvm.memcpy.p0.p0.i32(ptr %1, ptr %2, i32 %0, i1 false)
- ; CHECK: "llvm.intr.memcpy.inline"(%{{.*}}, %{{.*}}) <{isVolatile = false, len = 10 : i64}> : (!llvm.ptr, !llvm.ptr) -> ()
- call void @llvm.memcpy.inline.p0.p0.i64(ptr %1, ptr %2, i64 10, i1 false)
+ ; CHECK: "llvm.intr.memcpy"(%{{.*}}, %{{.*}}, %{{.*}}) <{arg_attrs = [{llvm.align = 4 : i64}, {llvm.align = 8 : i64}, {}], isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+ call void @llvm.memcpy.p0.p0.i32(ptr align 4 %1, ptr align 8 %2, i32 %0, i1 false)
+ ; CHECK: "llvm.intr.memcpy.inline"(%{{.*}}, %{{.*}}) <{arg_attrs = [{}, {llvm.align = 4 : i64}], isVolatile = false, len = 10 : i64}> : (!llvm.ptr, !llvm.ptr) -> ()
+ call void @llvm.memcpy.inline.p0.p0.i64(ptr %1, ptr align 4 %2, i64 10, i1 false)
; CHECK: "llvm.intr.memcpy.inline"(%{{.*}}, %{{.*}}) <{isVolatile = false, len = 10 : i32}> : (!llvm.ptr, !llvm.ptr) -> ()
call void @llvm.memcpy.inline.p0.p0.i32(ptr %1, ptr %2, i32 10, i1 false)
ret void
@@ -581,17 +581,17 @@ define void @memcpy_test(i32 %0, ptr %1, ptr %2) {
; CHECK-LABEL: llvm.func @memmove_test
define void @memmove_test(i32 %0, ptr %1, ptr %2) {
- ; CHECK: "llvm.intr.memmove"(%{{.*}}, %{{.*}}, %{{.*}}) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
- call void @llvm.memmove.p0.p0.i32(ptr %1, ptr %2, i32 %0, i1 false)
+ ; CHECK: "llvm.intr.memmove"(%{{.*}}, %{{.*}}, %{{.*}}) <{arg_attrs = [{llvm.align = 16 : i64}, {}, {}], isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+ call void @llvm.memmove.p0.p0.i32(ptr align 16 %1, ptr %2, i32 %0, i1 false)
ret void
}
; CHECK-LABEL: llvm.func @memset_test
define void @memset_test(i32 %0, ptr %1, i8 %2) {
- ; CHECK: "llvm.intr.memset"(%{{.*}}, %{{.*}}, %{{.*}}) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
- call void @llvm.memset.p0.i32(ptr %1, i8 %2, i32 %0, i1 false)
- ; CHECK: "llvm.intr.memset.inline"(%{{.*}}, %{{.*}}) <{isVolatile = false, len = 10 : i64}> : (!llvm.ptr, i8) -> ()
- call void @llvm.memset.inline.p0.i64(ptr %1, i8 %2, i64 10, i1 false)
+ ; CHECK: "llvm.intr.memset"(%{{.*}}, %{{.*}}, %{{.*}}) <{arg_attrs = [{llvm.align = 2 : i64}, {}, {}], isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+ call void @llvm.memset.p0.i32(ptr align 2 %1, i8 %2, i32 %0, i1 false)
+ ; CHECK: "llvm.intr.memset.inline"(%{{.*}}, %{{.*}}) <{arg_attrs = [{llvm.align = 4 : i64}, {}], isVolatile = false, len = 10 : i64}> : (!llvm.ptr, i8) -> ()
+ call void @llvm.memset.inline.p0.i64(ptr align 4 %1, i8 %2, i64 10, i1 false)
; CHECK: "llvm.intr.memset.inline"(%{{.*}}, %{{.*}}) <{isVolatile = false, len = 10 : i32}> : (!llvm.ptr, i8) -> ()
call void @llvm.memset.inline.p0.i32(ptr %1, i8 %2, i32 10, i1 false)
ret void
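Note: the import mapping above is positional: each pointer operand carrying an align N annotation contributes a {llvm.align = N : i64} dictionary at its index in arg_attrs (e.g. ptr align 16 %1 becomes the leading entry {llvm.align = 16 : i64}), unannotated operands contribute {}, and calls with no alignment annotations at all keep the old form without arg_attrs.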
diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
index 44074ce..eb3510c 100644
--- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
@@ -601,29 +601,33 @@ llvm.func @trap_intrinsics() {
// CHECK-LABEL: @memcpy_test
llvm.func @memcpy_test(%arg0: i32, %arg2: !llvm.ptr, %arg3: !llvm.ptr) {
- // CHECK: call void @llvm.memcpy.p0.p0.i32(ptr %{{.*}}, ptr %{{.*}}, i32 %{{.*}}, i1 false
- "llvm.intr.memcpy"(%arg2, %arg3, %arg0) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
- // CHECK: call void @llvm.memcpy.inline.p0.p0.i32(ptr %{{.*}}, ptr %{{.*}}, i32 10, i1 true
- "llvm.intr.memcpy.inline"(%arg2, %arg3) <{isVolatile = true, len = 10 : i32}> : (!llvm.ptr, !llvm.ptr) -> ()
+ // CHECK: call void @llvm.memcpy.p0.p0.i32(ptr align 4 %{{.*}}, ptr align 8 %{{.*}}, i32 %{{.*}}, i1 false
+ "llvm.intr.memcpy"(%arg2, %arg3, %arg0) <{arg_attrs = [{llvm.align = 4 : i64}, {llvm.align = 8 : i64}, {}], isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+ // CHECK: call void @llvm.memcpy.inline.p0.p0.i32(ptr align 4 %{{.*}}, ptr %{{.*}}, i32 10, i1 true
+ "llvm.intr.memcpy.inline"(%arg2, %arg3) <{arg_attrs = [{llvm.align = 4 : i64}, {}], isVolatile = true, len = 10 : i32}> : (!llvm.ptr, !llvm.ptr) -> ()
// CHECK: call void @llvm.memcpy.inline.p0.p0.i64(ptr %{{.*}}, ptr %{{.*}}, i64 10, i1 true
"llvm.intr.memcpy.inline"(%arg2, %arg3) <{isVolatile = true, len = 10 : i64}> : (!llvm.ptr, !llvm.ptr) -> ()
+
+ // Verify that trailing empty argument attribute dictionaries can be omitted.
+ // CHECK: call void @llvm.memcpy.p0.p0.i32(ptr align 4 %{{.*}}, ptr align 8 %{{.*}}, i32 %{{.*}}, i1 false
+ "llvm.intr.memcpy"(%arg2, %arg3, %arg0) <{arg_attrs = [{llvm.align = 4 : i64}, {llvm.align = 8 : i64}], isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
llvm.return
}
// CHECK-LABEL: @memmove_test
llvm.func @memmove_test(%arg0: i32, %arg2: !llvm.ptr, %arg3: !llvm.ptr) {
- // CHECK: call void @llvm.memmove.p0.p0.i32(ptr %{{.*}}, ptr %{{.*}}, i32 %{{.*}}, i1 false
- "llvm.intr.memmove"(%arg2, %arg3, %arg0) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+ // CHECK: call void @llvm.memmove.p0.p0.i32(ptr align 4 %{{.*}}, ptr align 8 %{{.*}}, i32 %{{.*}}, i1 false
+ "llvm.intr.memmove"(%arg2, %arg3, %arg0) <{arg_attrs = [{llvm.align = 4 : i64}, {llvm.align = 8 : i64}, {}], isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
llvm.return
}
// CHECK-LABEL: @memset_test
llvm.func @memset_test(%arg0: i32, %arg2: !llvm.ptr, %arg3: i8) {
%i1 = llvm.mlir.constant(false) : i1
- // CHECK: call void @llvm.memset.p0.i32(ptr %{{.*}}, i8 %{{.*}}, i32 %{{.*}}, i1 false
- "llvm.intr.memset"(%arg2, %arg3, %arg0) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
- // CHECK: call void @llvm.memset.inline.p0.i32(ptr %{{.*}}, i8 %{{.*}}, i32 10, i1 true
- "llvm.intr.memset.inline"(%arg2, %arg3) <{isVolatile = true, len = 10 : i32}> : (!llvm.ptr, i8) -> ()
+ // CHECK: call void @llvm.memset.p0.i32(ptr align 8 %{{.*}}, i8 %{{.*}}, i32 %{{.*}}, i1 false
+ "llvm.intr.memset"(%arg2, %arg3, %arg0) <{arg_attrs = [{llvm.align = 8 : i64}, {}, {}], isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+ // CHECK: call void @llvm.memset.inline.p0.i32(ptr align 8 %{{.*}}, i8 %{{.*}}, i32 10, i1 true
+ "llvm.intr.memset.inline"(%arg2, %arg3) <{arg_attrs = [{llvm.align = 8 : i64}, {}], isVolatile = true, len = 10 : i32}> : (!llvm.ptr, i8) -> ()
// CHECK: call void @llvm.memset.inline.p0.i64(ptr %{{.*}}, i8 %{{.*}}, i64 10, i1 true
"llvm.intr.memset.inline"(%arg2, %arg3) <{isVolatile = true, len = 10 : i64}> : (!llvm.ptr, i8) -> ()
llvm.return
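The import and export hunks above are two sides of the same round trip: LLVM's per-argument `align` attribute becomes an `llvm.align` entry in the op's `arg_attrs` array, and vice versa. A minimal sketch (hypothetical function name, same op syntax as the tests above):

  llvm.func @align_roundtrip(%dst: !llvm.ptr, %src: !llvm.ptr, %len: i32) {
    // Exports as: call void @llvm.memcpy.p0.p0.i32(ptr align 16 %0, ptr %1, i32 %2, i1 false)
    "llvm.intr.memcpy"(%dst, %src, %len) <{arg_attrs = [{llvm.align = 16 : i64}, {}, {}], isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
    llvm.return
  }

Importing the exported call reproduces the same `arg_attrs`, so alignment on memcpy/memmove/memset operands survives the translation in both directions.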
diff --git a/mlir/test/Target/LLVMIR/xevm.mlir b/mlir/test/Target/LLVMIR/xevm.mlir
new file mode 100644
index 0000000..a3dd0b6
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/xevm.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-translate --split-input-file -mlir-to-llvmir %s | FileCheck %s
+
+module {
+ llvm.func spir_funccc @_Z8prefetchPU3AS1Kcm(!llvm.ptr<1>, i64)
+ llvm.func @prefetch(%arg0: !llvm.ptr<1>) {
+ %0 = llvm.mlir.constant(1 : i64) : i64
+ // CHECK-LABEL: call spir_func void @_Z8prefetchPU3AS1Kcm
+ // CHECK-SAME: !spirv.DecorationCacheControlINTEL ![[DECO1:.*]]
+ llvm.call spir_funccc @_Z8prefetchPU3AS1Kcm(%arg0, %0)
+ {function_type = !llvm.func<void (ptr<1>, i64)>, linkage = #llvm.linkage<external>,
+ no_unwind, sym_name = "_Z8prefetchPU3AS1Kcm", visibility_ = 0 : i64,
+ xevm.DecorationCacheControl = [[6442 : i32, 0 : i32, 1 : i32, 0 : i32], [6442 : i32, 1 : i32, 1 : i32, 0 : i32]]}
+ : (!llvm.ptr<1>, i64) -> ()
+ llvm.return
+ }
+}
+
+// CHECK: ![[DECO1]] = !{![[DECO2:.*]], ![[DECO3:.*]]}
+// CHECK: ![[DECO2]] = !{i32 6442, i32 0, i32 1, i32 0}
+// CHECK: ![[DECO3]] = !{i32 6442, i32 1, i32 1, i32 0}
+
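For reference, the discardable `xevm.DecorationCacheControl` attribute drives the metadata checked above: each inner four-element list is emitted as one `!{i32, i32, i32, i32}` node attached via `!spirv.DecorationCacheControlINTEL`. A stripped-down sketch of the attribute on a call (values copied from the test; the other call attributes are omitted):

  llvm.call spir_funccc @_Z8prefetchPU3AS1Kcm(%ptr, %len)
      {xevm.DecorationCacheControl = [[6442 : i32, 0 : i32, 1 : i32, 0 : i32]]}
      : (!llvm.ptr<1>, i64) -> ()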
diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir
index 6aca11e..1695d2a 100644
--- a/mlir/test/Target/SPIRV/constant.mlir
+++ b/mlir/test/Target/SPIRV/constant.mlir
@@ -307,6 +307,34 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader
spirv.ReturnValue %coop : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc>
}
+ // CHECK-LABEL: @arm_tensor_of_i32
+ spirv.func @arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
+ // CHECK: {{%.*}} = spirv.Constant dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : !spirv.arm.tensor<2x3xi32>
+ %0 = spirv.Constant dense<[[1, 2, 3], [4, 5, 6]]> : !spirv.arm.tensor<2x3xi32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
+ }
+
+ // CHECK-LABEL: @splat_arm_tensor_of_i32
+ spirv.func @splat_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
+ // CHECK: {{%.*}} = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32>
+ %0 = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
+ }
+
+ // CHECK-LABEL: @arm_tensor_of_f32
+ spirv.func @arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
+ // CHECK: {{%.*}} = spirv.Constant dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : !spirv.arm.tensor<2x3xf32>
+ %0 = spirv.Constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]>: !spirv.arm.tensor<2x3xf32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
+ }
+
+ // CHECK-LABEL: @splat_arm_tensor_of_f32
+ spirv.func @splat_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
+ // CHECK: {{%.*}} = spirv.Constant dense<2.000000e+00> : !spirv.arm.tensor<2x3xf32>
+ %0 = spirv.Constant dense<2.0> : !spirv.arm.tensor<2x3xf32>
+ spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
+ }
+
spirv.EntryPoint "GLCompute" @bool_const
}
diff --git a/mlir/test/Target/SPIRV/memory-ops.mlir b/mlir/test/Target/SPIRV/memory-ops.mlir
index 6b50c39..786d07a2 100644
--- a/mlir/test/Target/SPIRV/memory-ops.mlir
+++ b/mlir/test/Target/SPIRV/memory-ops.mlir
@@ -37,32 +37,32 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// -----
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
- spirv.func @load_store_zero_rank_float(%arg0: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>) "None" {
- // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>
+ spirv.func @load_store_zero_rank_float(%arg0: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>) "None" {
+ // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>
// CHECK-NEXT: [[VAL:%.*]] = spirv.Load "StorageBuffer" [[LOAD_PTR]] : f32
%0 = spirv.Constant 0 : i32
- %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
+ %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
%2 = spirv.Load "StorageBuffer" %1 : f32
- // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>
+ // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>
// CHECK-NEXT: spirv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : f32
%3 = spirv.Constant 0 : i32
- %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
+ %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
spirv.Store "StorageBuffer" %4, %2 : f32
spirv.Return
}
- spirv.func @load_store_zero_rank_int(%arg0: !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>) "None" {
- // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>
+ spirv.func @load_store_zero_rank_int(%arg0: !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>) "None" {
+ // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>
// CHECK-NEXT: [[VAL:%.*]] = spirv.Load "StorageBuffer" [[LOAD_PTR]] : i32
%0 = spirv.Constant 0 : i32
- %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer>
+ %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer>
%2 = spirv.Load "StorageBuffer" %1 : i32
- // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>
+ // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>
// CHECK-NEXT: spirv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : i32
%3 = spirv.Constant 0 : i32
- %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer>
+ %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer>
spirv.Store "StorageBuffer" %4, %2 : i32
spirv.Return
}
diff --git a/mlir/test/Target/SPIRV/struct.mlir b/mlir/test/Target/SPIRV/struct.mlir
index 0db0c0b..4984ee7 100644
--- a/mlir/test/Target/SPIRV/struct.mlir
+++ b/mlir/test/Target/SPIRV/struct.mlir
@@ -7,23 +7,23 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// CHECK: !spirv.ptr<!spirv.struct<(f32 [0], !spirv.struct<(f32 [0], !spirv.array<16 x f32, stride=4> [4])> [4])>, Input>
spirv.GlobalVariable @var1 bind(0, 2) : !spirv.ptr<!spirv.struct<(f32 [0], !spirv.struct<(f32 [0], !spirv.array<16 x f32, stride=4> [4])> [4])>, Input>
- // CHECK: !spirv.ptr<!spirv.struct<(f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38])>, StorageBuffer>
- spirv.GlobalVariable @var2 : !spirv.ptr<!spirv.struct<(f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38])>, StorageBuffer>
+ // CHECK: !spirv.ptr<!spirv.struct<(f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38]), Block>, StorageBuffer>
+ spirv.GlobalVariable @var2 : !spirv.ptr<!spirv.struct<(f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38]), Block>, StorageBuffer>
- // CHECK: !spirv.ptr<!spirv.struct<(!spirv.array<128 x !spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, stride=512> [0])>, StorageBuffer>
- spirv.GlobalVariable @var3 : !spirv.ptr<!spirv.struct<(!spirv.array<128 x !spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, stride=512> [0])>, StorageBuffer>
+ // CHECK: !spirv.ptr<!spirv.struct<(!spirv.array<128 x !spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, stride=512> [0]), Block>, StorageBuffer>
+ spirv.GlobalVariable @var3 : !spirv.ptr<!spirv.struct<(!spirv.array<128 x !spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, stride=512> [0]), Block>, StorageBuffer>
- // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4])>, StorageBuffer>
- spirv.GlobalVariable @var4 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4])>, StorageBuffer>
+ // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4]), Block>, StorageBuffer>
+ spirv.GlobalVariable @var4 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4]), Block>, StorageBuffer>
- // CHECK: !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable])>, StorageBuffer>
- spirv.GlobalVariable @var5 : !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable])>, StorageBuffer>
+ // CHECK: !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable]), Block>, StorageBuffer>
+ spirv.GlobalVariable @var5 : !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable]), Block>, StorageBuffer>
- // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable])>, StorageBuffer>
- spirv.GlobalVariable @var6 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable])>, StorageBuffer>
+ // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable]), Block>, StorageBuffer>
+ spirv.GlobalVariable @var6 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable]), Block>, StorageBuffer>
- // CHECK: !spirv.ptr<!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16])>, StorageBuffer>
- spirv.GlobalVariable @var7 : !spirv.ptr<!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16])>, StorageBuffer>
+ // CHECK: !spirv.ptr<!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]), Block>, StorageBuffer>
+ spirv.GlobalVariable @var7 : !spirv.ptr<!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]), Block>, StorageBuffer>
// CHECK: !spirv.ptr<!spirv.struct<()>, StorageBuffer>
spirv.GlobalVariable @empty : !spirv.ptr<!spirv.struct<()>, StorageBuffer>
@@ -34,15 +34,17 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// CHECK: !spirv.ptr<!spirv.struct<test_id, (!spirv.array<128 x f32, stride=4> [0])>, Input>
spirv.GlobalVariable @id_var0 : !spirv.ptr<!spirv.struct<test_id, (!spirv.array<128 x f32, stride=4> [0])>, Input>
+ // CHECK: !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>), Block>, StorageBuffer>
+ spirv.GlobalVariable @recursive_simple : !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>), Block>, StorageBuffer>
- // CHECK: !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>)>, StorageBuffer>
- spirv.GlobalVariable @recursive_simple : !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>)>, StorageBuffer>
+ // CHECK: !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>), Block>, Uniform>), Block>, Uniform>
+ spirv.GlobalVariable @recursive_2 : !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>), Block>, Uniform>), Block>, Uniform>
- // CHECK: !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>)>, Uniform>)>, Uniform>
- spirv.GlobalVariable @recursive_2 : !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>)>, Uniform>)>, Uniform>
+ // CHECK: !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>), Block>, Uniform>), Block>, Uniform>
+ spirv.GlobalVariable @recursive_3 : !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>), Block>, Uniform>), Block>, Uniform>
- // CHECK: !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>)>, Uniform>)>, Uniform>
- spirv.GlobalVariable @recursive_3 : !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>)>, Uniform>)>, Uniform>
+ // CHECK: spirv.GlobalVariable @block : !spirv.ptr<!spirv.struct<vert, (vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>, Output>
+ spirv.GlobalVariable @block : !spirv.ptr<!spirv.struct<vert, (vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>, Output>
// CHECK: !spirv.ptr<!spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, Input>,
// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, Output>
diff --git a/mlir/test/Target/SPIRV/undef.mlir b/mlir/test/Target/SPIRV/undef.mlir
index b9044fe..8889b80 100644
--- a/mlir/test/Target/SPIRV/undef.mlir
+++ b/mlir/test/Target/SPIRV/undef.mlir
@@ -13,10 +13,10 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// CHECK: {{%.*}} = spirv.Undef : !spirv.array<4 x !spirv.array<4 x i32>>
%5 = spirv.Undef : !spirv.array<4x!spirv.array<4xi32>>
%6 = spirv.CompositeExtract %5[1 : i32, 2 : i32] : !spirv.array<4x!spirv.array<4xi32>>
- // CHECK: {{%.*}} = spirv.Undef : !spirv.ptr<!spirv.struct<(f32)>, StorageBuffer>
- %7 = spirv.Undef : !spirv.ptr<!spirv.struct<(f32)>, StorageBuffer>
+ // CHECK: {{%.*}} = spirv.Undef : !spirv.ptr<!spirv.struct<(f32), Block>, StorageBuffer>
+ %7 = spirv.Undef : !spirv.ptr<!spirv.struct<(f32), Block>, StorageBuffer>
%8 = spirv.Constant 0 : i32
- %9 = spirv.AccessChain %7[%8] : !spirv.ptr<!spirv.struct<(f32)>, StorageBuffer>, i32 -> !spirv.ptr<f32, StorageBuffer>
+ %9 = spirv.AccessChain %7[%8] : !spirv.ptr<!spirv.struct<(f32), Block>, StorageBuffer>, i32 -> !spirv.ptr<f32, StorageBuffer>
spirv.Return
}
}
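The struct-type updates in the memory-ops, struct, and undef tests above all stem from one change: (de)serialization now keeps the Block decoration as part of the struct type instead of dropping it. A minimal sketch of the decorated type syntax, with a hypothetical variable name:

  spirv.GlobalVariable @buffer : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0]), Block>, StorageBuffer>

The decoration list printed after the member list is what distinguishes an interface `Block` struct from the plain `!spirv.struct<(...)>` spelling.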
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 382da59..5685004 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -347,6 +347,7 @@ def TestCopyCount : Test_Attr<"TestCopyCount"> {
let mnemonic = "copy_count";
let parameters = (ins TestParamCopyCount:$copy_count);
let assemblyFormat = "`<` $copy_count `>`";
+ let genVerifyDecl = 1;
}
def TestConditionalAliasAttr : Test_Attr<"TestConditionalAlias"> {
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index b31e90f..5890913 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -214,6 +214,16 @@ static void printTrueFalse(AsmPrinter &p, std::optional<int> result) {
}
//===----------------------------------------------------------------------===//
+// TestCopyCountAttr Implementation
+//===----------------------------------------------------------------------===//
+
+LogicalResult TestCopyCountAttr::verify(
+ llvm::function_ref<::mlir::InFlightDiagnostic()> /*emitError*/,
+ CopyCount /*copy_count*/) {
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// CopyCountAttr Implementation
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td
index d47411d..a809611 100644
--- a/mlir/test/mlir-tblgen/attrdefs.td
+++ b/mlir/test/mlir-tblgen/attrdefs.td
@@ -115,6 +115,11 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> {
// DEF: return new (allocator.allocate<CompoundAAttrStorage>())
// DEF-SAME: CompoundAAttrStorage(std::move(widthOfSomething), std::move(exampleTdType), std::move(apFloat), std::move(dims), std::move(inner));
+// DEF: CompoundAAttr CompoundAAttr::getChecked(
+// DEF-SAME: int widthOfSomething, ::test::SimpleTypeA exampleTdType, ::llvm::APFloat apFloat, ::llvm::ArrayRef<int> dims, ::mlir::Type inner
+// DEF-SAME: )
+// DEF-NEXT: return Base::getChecked(emitError, context, std::move(widthOfSomething), std::move(exampleTdType), std::move(apFloat), std::move(dims), std::move(inner));
+
// DEF: ::mlir::Type CompoundAAttr::getInner() const {
// DEF-NEXT: return getImpl()->inner;
}
diff --git a/mlir/tools/mlir-lsp-server/CMakeLists.txt b/mlir/tools/mlir-lsp-server/CMakeLists.txt
index 6932e0f..0518620 100644
--- a/mlir/tools/mlir-lsp-server/CMakeLists.txt
+++ b/mlir/tools/mlir-lsp-server/CMakeLists.txt
@@ -2,8 +2,6 @@ set(LLVM_OPTIONAL_SOURCES
null.cpp
)
-get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
-get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
set(LLVM_LINK_COMPONENTS
Core
Support
@@ -35,22 +33,11 @@ if(MLIR_INCLUDE_TESTS)
endif()
set(LIBS
- ${conversion_libs}
- ${dialect_libs}
- ${extension_libs}
-
- MLIRAffineAnalysis
- MLIRAnalysis
- MLIRDialect
- MLIRFuncAllExtensions
MLIRLspServerLib
- MLIRParser
- MLIRPass
- MLIRTensorAllExtensions
- MLIRTransforms
- MLIRTransformUtils
- MLIRSupport
- MLIRIR
+
+ MLIRRegisterAllDialects
+ MLIRRegisterAllExtensions
+ MLIRRegisterAllPasses
)
add_mlir_tool(mlir-lsp-server
diff --git a/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp b/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp
index 6a759d9..10d602f 100644
--- a/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp
+++ b/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp
@@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllExtensions.h"
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index 6958fe3..7cc6e78 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -2,9 +2,6 @@ set(LLVM_OPTIONAL_SOURCES
null.cpp
)
-get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
-get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
-get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
set(LLVM_LINK_COMPONENTS
Core
Support
@@ -65,21 +62,11 @@ if(MLIR_INCLUDE_TESTS)
endif()
set(LIBS
- ${dialect_libs}
- ${conversion_libs}
- ${extension_libs}
- MLIRAffineAnalysis
- MLIRAnalysis
- MLIRCastInterfaces
- MLIRDialect
MLIROptLib
- MLIRParser
- MLIRPass
- MLIRTransforms
- MLIRTransformUtils
- MLIRSupport
- MLIRIR
+ MLIRRegisterAllDialects
+ MLIRRegisterAllExtensions
+ MLIRRegisterAllPasses
# TODO: Remove when registerAllGPUToLLVMIRTranslations is no longer
# registered directly in mlir-opt.cpp.
diff --git a/mlir/tools/mlir-pdll/mlir-pdll.cpp b/mlir/tools/mlir-pdll/mlir-pdll.cpp
index 88a5f36..f99dcdb 100644
--- a/mlir/tools/mlir-pdll/mlir-pdll.cpp
+++ b/mlir/tools/mlir-pdll/mlir-pdll.cpp
@@ -201,6 +201,12 @@ int main(int argc, char **argv) {
llvm::raw_string_ostream outputStrOS(outputStr);
auto processFn = [&](std::unique_ptr<llvm::MemoryBuffer> chunkBuffer,
raw_ostream &os) {
+ // Split does not guarantee null-termination. Copy the chunk buffer to
+ // ensure it is null-terminated.
+ if (!chunkBuffer->getBuffer().ends_with('\0')) {
+ chunkBuffer = llvm::MemoryBuffer::getMemBufferCopy(
+ chunkBuffer->getBuffer(), chunkBuffer->getBufferIdentifier());
+ }
return processBuffer(os, std::move(chunkBuffer), outputType, includeDirs,
dumpODS, includedFiles);
};
diff --git a/mlir/tools/mlir-query/CMakeLists.txt b/mlir/tools/mlir-query/CMakeLists.txt
index 1826397..1668bba 100644
--- a/mlir/tools/mlir-query/CMakeLists.txt
+++ b/mlir/tools/mlir-query/CMakeLists.txt
@@ -1,5 +1,3 @@
-get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
-
if(MLIR_INCLUDE_TESTS)
set(test_libs
MLIRTestDialect
@@ -12,8 +10,8 @@ add_mlir_tool(mlir-query
llvm_update_compile_flags(mlir-query)
mlir_target_link_libraries(mlir-query
PRIVATE
- ${dialect_libs}
MLIRQueryLib
+ MLIRRegisterAllDialects
)
target_link_libraries(mlir-query PRIVATE ${test_libs})
diff --git a/mlir/tools/mlir-reduce/CMakeLists.txt b/mlir/tools/mlir-reduce/CMakeLists.txt
index d71ac86..349d75b 100644
--- a/mlir/tools/mlir-reduce/CMakeLists.txt
+++ b/mlir/tools/mlir-reduce/CMakeLists.txt
@@ -1,6 +1,3 @@
-get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
-get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
-
if(MLIR_INCLUDE_TESTS)
set(test_libs
MLIRTestDialect
@@ -8,12 +5,9 @@ if(MLIR_INCLUDE_TESTS)
endif()
set(LIBS
- ${conversion_libs}
- ${dialect_libs}
- MLIRDialect
- MLIRIR
- MLIRPass
MLIRReduceLib
+ MLIRRegisterAllDialects
+ MLIRRegisterAllPasses
)
add_mlir_tool(mlir-reduce
diff --git a/mlir/tools/mlir-rewrite/CMakeLists.txt b/mlir/tools/mlir-rewrite/CMakeLists.txt
index 216491e..4120b175 100644
--- a/mlir/tools/mlir-rewrite/CMakeLists.txt
+++ b/mlir/tools/mlir-rewrite/CMakeLists.txt
@@ -1,21 +1,19 @@
-get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
set(LLVM_LINK_COMPONENTS
Support
)
set(LIBS
- ${dialect_libs}
-
MLIRAffineAnalysis
MLIRAnalysis
MLIRCastInterfaces
MLIRDialect
+ MLIRIR
MLIRParser
MLIRPass
- MLIRTransforms
- MLIRTransformUtils
+ MLIRRegisterAllDialects
MLIRSupport
- MLIRIR
+ MLIRTransformUtils
+ MLIRTransforms
)
include_directories(../../../clang/include)
diff --git a/mlir/tools/mlir-rewrite/mlir-rewrite.cpp b/mlir/tools/mlir-rewrite/mlir-rewrite.cpp
index 87df9e1..fd8ae7e 100644
--- a/mlir/tools/mlir-rewrite/mlir-rewrite.cpp
+++ b/mlir/tools/mlir-rewrite/mlir-rewrite.cpp
@@ -24,6 +24,7 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/LineIterator.h"
+#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index dbae2143..3140f12 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -495,7 +495,7 @@ void DefGen::emitCheckedBuilder() {
MethodBody &body = m->body().indent();
auto scope = body.scope("return Base::getChecked(emitError, context", ");");
for (const auto &param : params)
- body << ", " << param.getName();
+ body << ", std::move(" << param.getName() << ")";
}
static SmallVector<MethodParameter>
diff --git a/mlir/unittests/ExecutionEngine/CMakeLists.txt b/mlir/unittests/ExecutionEngine/CMakeLists.txt
index 4ef69a8..b83163e 100644
--- a/mlir/unittests/ExecutionEngine/CMakeLists.txt
+++ b/mlir/unittests/ExecutionEngine/CMakeLists.txt
@@ -10,14 +10,13 @@ add_mlir_unittest(MLIRExecutionEngineTests
StridedMemRef.cpp
Invoke.cpp
)
-get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
mlir_target_link_libraries(MLIRExecutionEngineTests
PRIVATE
MLIRArithToLLVM
MLIRMemRefToLLVM
MLIRReconcileUnrealizedCasts
- ${dialect_libs}
+ MLIRRegisterAllDialects
)
target_link_libraries(MLIRExecutionEngineTests
PRIVATE
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index a55592d..fd40404 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -477,8 +477,9 @@ TEST(SubElementTest, Nested) {
{strAttr, trueAttr, falseAttr, boolArrayAttr, dictAttr}));
}
-// Test how many times we call copy-ctor when building an attribute.
-TEST(CopyCountAttr, CopyCount) {
+// Test how many times we call copy-ctor when building an attribute with the
+// 'get' method.
+TEST(CopyCountAttr, CopyCountGet) {
MLIRContext context;
context.loadDialect<test::TestDialect>();
@@ -489,15 +490,35 @@ TEST(CopyCountAttr, CopyCount) {
test::CopyCount::counter = 0;
test::TestCopyCountAttr::get(&context, std::move(copyCount));
#ifndef NDEBUG
- // One verification enabled only in assert-mode requires a copy.
- EXPECT_EQ(counter1, 1);
- EXPECT_EQ(test::CopyCount::counter, 1);
+ // The verification enabled only in assert mode requires two copies: one
+ // for calling 'verifyInvariants' and one for calling 'verify' inside
+ // 'verifyInvariants'.
+ EXPECT_EQ(counter1, 2);
+ EXPECT_EQ(test::CopyCount::counter, 2);
#else
EXPECT_EQ(counter1, 0);
EXPECT_EQ(test::CopyCount::counter, 0);
#endif
}
+// Test how many times we call copy-ctor when building an attribute with the
+// 'getChecked' method.
+TEST(CopyCountAttr, CopyCountGetChecked) {
+ MLIRContext context;
+ context.loadDialect<test::TestDialect>();
+ test::CopyCount::counter = 0;
+ test::CopyCount copyCount("hello");
+ auto loc = UnknownLoc::get(&context);
+ test::TestCopyCountAttr::getChecked(loc, &context, std::move(copyCount));
+ int counter1 = test::CopyCount::counter;
+ test::CopyCount::counter = 0;
+ test::TestCopyCountAttr::getChecked(loc, &context, std::move(copyCount));
+ // The verifiers require two copies: one for calling 'verifyInvariants' and
+ // one for calling 'verify' inside 'verifyInvariants'.
+ EXPECT_EQ(counter1, 2);
+ EXPECT_EQ(test::CopyCount::counter, 2);
+}
+
// Test stripped printing using test dialect attribute.
TEST(CopyCountAttr, PrintStripped) {
MLIRContext context;
diff --git a/mlir/unittests/Target/LLVM/CMakeLists.txt b/mlir/unittests/Target/LLVM/CMakeLists.txt
index 0daac11..0a77deb 100644
--- a/mlir/unittests/Target/LLVM/CMakeLists.txt
+++ b/mlir/unittests/Target/LLVM/CMakeLists.txt
@@ -1,13 +1,11 @@
set(LLVM_LINK_COMPONENTS nativecodegen BitReader)
-get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
-
add_mlir_unittest(MLIRTargetLLVMTests
SerializeNVVMTarget.cpp
SerializeROCDLTarget.cpp
SerializeToLLVMBitcode.cpp
DEPENDS
- ${dialect_libs}
+ MLIRRegisterAllDialects
)
mlir_target_link_libraries(MLIRTargetLLVMTests