diff options
Diffstat (limited to 'mlir')
249 files changed, 5288 insertions, 3023 deletions
diff --git a/mlir/docs/DefiningDialects/AttributesAndTypes.md b/mlir/docs/DefiningDialects/AttributesAndTypes.md index 022bdad..b991863 100644 --- a/mlir/docs/DefiningDialects/AttributesAndTypes.md +++ b/mlir/docs/DefiningDialects/AttributesAndTypes.md @@ -136,7 +136,7 @@ def My_IntegerAttr : MyDialect_Attr<"Integer", "int"> { /// Here we've defined two parameters, one is a "self" type parameter, and the /// other is the integer value of the attribute. The self type parameter is /// specially handled by the assembly format. - let parameters = (ins AttributeSelfTypeParameter<"">:$type, "APInt":$value); + let parameters = (ins AttributeSelfTypeParameter<"">:$type, APIntParameter<"">:$value); /// Here we've defined a custom builder for the type, that removes the need to pass /// in an MLIRContext instance; as it can be infered from the `type`. diff --git a/mlir/docs/Dialects/Vector.md b/mlir/docs/Dialects/Vector.md index ebeb0a2..6c8949d 100644 --- a/mlir/docs/Dialects/Vector.md +++ b/mlir/docs/Dialects/Vector.md @@ -294,7 +294,7 @@ LLVM instructions are prefixed by the `llvm.` dialect prefix (e.g. `llvm.insertvalue`). Such ops operate exclusively on 1-D vectors and aggregates following the [LLVM LangRef](https://llvm.org/docs/LangRef.html). MLIR operations are prefixed by the `vector.` dialect prefix (e.g. -`vector.insertelement`). Such ops operate exclusively on MLIR `n-D` `vector` +`vector.insert`). Such ops operate exclusively on MLIR `n-D` `vector` types. ### Alternatives For Lowering an n-D Vector Type to LLVM diff --git a/mlir/docs/Dialects/emitc.md b/mlir/docs/Dialects/emitc.md index e2288f5..6d09e93 100644 --- a/mlir/docs/Dialects/emitc.md +++ b/mlir/docs/Dialects/emitc.md @@ -18,6 +18,8 @@ The following convention is followed: GCC or Clang. * If `emitc.array` with a dimension of size zero is used, then the code requires [a GCC extension](https://gcc.gnu.org/onlinedocs/gcc/Zero-Length.html). +* If `aligned_alloc` is passed to an `emitc.call_opaque` operation, then C++17 + or C11 is required. * Else the generated code is compatible with C99. These restrictions are neither inherent to the EmitC dialect itself nor to the diff --git a/mlir/docs/Tutorials/transform/Ch0.md b/mlir/docs/Tutorials/transform/Ch0.md index ac3989a..dc4b753 100644 --- a/mlir/docs/Tutorials/transform/Ch0.md +++ b/mlir/docs/Tutorials/transform/Ch0.md @@ -46,7 +46,7 @@ When no support is available, such an operation can be transformed into a loop: %c8 = arith.constant 8 : index %init = arith.constant 0.0 : f32 %result = scf.for %i = %c0 to %c8 step %c1 iter_args(%partial = %init) -> (f32) { - %element = vector.extractelement %0[%i : index] : vector<8xf32> + %element = vector.extract %0[%i] : f32 into vector<8xf32> %updated = arith.addf %partial, %element : f32 scf.yield %updated : f32 } @@ -145,7 +145,7 @@ linalg.generic { %c0 = arith.constant 0.0 : f32 %0 = arith.cmpf ogt %in_one, %c0 : f32 %1 = arith.select %0, %in_one, %c0 : f32 - linalg.yield %1 : f32 + linalg.yield %1 : f32 } ``` @@ -185,7 +185,7 @@ In the case of `linalg.generic` operations, the iteration space is implicit and For example, tiling the matrix multiplication presented above with tile sizes `(2, 8)`, we obtain a loop nest around a `linalg.generic` expressing the same operation on a `2x8` tensor. ```mlir -// A special "multi-for" loop that supports tensor-insertion semantics +// A special "multi-for" loop that supports tensor-insertion semantics // as opposed to implicit updates. The resulting 8x16 tensor will be produced // by this loop. // The trip count of iterators is computed dividing the original tensor size, @@ -202,9 +202,9 @@ For example, tiling the matrix multiplication presented above with tile sizes `( // Take slices of inputs and outputs. Only the "i" and "j" dimensions are sliced. %lhs_slice = tensor.extract_slice %lhs[%3, 0] [2, 10] [1, 1] : tensor<8x10xf32> to tensor<2x10xf32> - %rhs_slice = tensor.extract_slice %rhs[0, %4] [10, 8] [1, 1] + %rhs_slice = tensor.extract_slice %rhs[0, %4] [10, 8] [1, 1] : tensor<10x16xf32> to tensor<10x8xf32> - %result_slice = tensor.extract_slice %shared[%3, %4] [2, 8] [1, 1] + %result_slice = tensor.extract_slice %shared[%3, %4] [2, 8] [1, 1] : tensor<8x16xf32> to tensor<2x8xf32> // This is exactly the same operation as before, but now operating on smaller @@ -214,7 +214,7 @@ For example, tiling the matrix multiplication presented above with tile sizes `( affine_map<(i, j, k) -> (k, j)>, affine_map<(i, j, k) -> (i, j)>], iterator_types = ["parallel", "parallel", "reduction"] - } ins(%lhs_slice, %rhs_slice : tensor<2x10xf32>, tensor<10x8xf32>) + } ins(%lhs_slice, %rhs_slice : tensor<2x10xf32>, tensor<10x8xf32>) outs(%result_slice : tensor<2x8xf32>) -> tensor<2x8xf32> { ^bb0(%lhs_one: f32, %rhs_one: f32, %init_one: f32): %0 = arith.mulf %lhs_one, %rhs_one : f32 @@ -238,15 +238,15 @@ After materializing loops with tiling, another key code generation transformatio 1. the subset (slice) of the operand that is used by the tile, and 2. the tensor-level structured operation producing the whole tensor that is being sliced. -By inverting the `indexing_map` and applying it to the set of elements accessed through the slice, we can compute the part of the iteration space of the operation defining the full tensor necessary to compute the tile. Thus fusion boils down to replacing the `tensor.extract_slice` operation with the tile of the `linalg.generic` producing the original operand. +By inverting the `indexing_map` and applying it to the set of elements accessed through the slice, we can compute the part of the iteration space of the operation defining the full tensor necessary to compute the tile. Thus fusion boils down to replacing the `tensor.extract_slice` operation with the tile of the `linalg.generic` producing the original operand. Let us assume that the matrix multiplication operation is followed by another operation that multiplies each element of the resulting matrix with itself. This trailing elementwise operation has a 2D iteration space, unlike the 3D one in matrix multiplication. Nevertheless, it is possible to tile the trailing operation and then fuse the producer of its operand, the matmul, into the loop generated by tiling. The untiled dimension will be used in its entirety. ```mlir // Same loop as before. -%0 = scf.forall (%i, %j) in (4, 2) - shared_outs(%shared = %init) +%0 = scf.forall (%i, %j) in (4, 2) + shared_outs(%shared = %init) -> (tensor<8x16xf32>, tensor<8x16xf32>) { // Scale the loop induction variables by the tile sizes. %1 = affine.apply affine_map<(d0) -> (d0 * 2)>(%i) @@ -286,7 +286,7 @@ Let us assume that the matrix multiplication operation is followed by another op indexing_maps = [affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>], iterator_types = ["parallel", "parallel"] - } ins(%partial : tensor<2x8xf32>) + } ins(%partial : tensor<2x8xf32>) outs(%shared_slice : tensor<2x8xf32>) { ^bb0(%in: f32, %out: f32): %5 = arith.mulf %in, %in : f32 diff --git a/mlir/examples/standalone/standalone-opt/CMakeLists.txt b/mlir/examples/standalone/standalone-opt/CMakeLists.txt index 27f8128..4b38de7 100644 --- a/mlir/examples/standalone/standalone-opt/CMakeLists.txt +++ b/mlir/examples/standalone/standalone-opt/CMakeLists.txt @@ -1,12 +1,10 @@ -get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) -get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) set(LIBS - ${dialect_libs} - ${conversion_libs} - MLIRArithDialect - MLIROptLib - MLIRStandalone - ) + MLIRArithDialect + MLIROptLib + MLIRRegisterAllDialects + MLIRRegisterAllPasses + MLIRStandalone + ) add_llvm_executable(standalone-opt standalone-opt.cpp) llvm_update_compile_flags(standalone-opt) diff --git a/mlir/examples/standalone/standalone-opt/standalone-opt.cpp b/mlir/examples/standalone/standalone-opt/standalone-opt.cpp index e39fa96..eebfcb7 100644 --- a/mlir/examples/standalone/standalone-opt/standalone-opt.cpp +++ b/mlir/examples/standalone/standalone-opt/standalone-opt.cpp @@ -6,6 +6,8 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/MLIRContext.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" diff --git a/mlir/examples/toy/Ch5/CMakeLists.txt b/mlir/examples/toy/Ch5/CMakeLists.txt index f4f0fec..454ca56 100644 --- a/mlir/examples/toy/Ch5/CMakeLists.txt +++ b/mlir/examples/toy/Ch5/CMakeLists.txt @@ -27,12 +27,8 @@ add_toy_chapter(toyc-ch5 include_directories(${CMAKE_CURRENT_BINARY_DIR}) include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/) -get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) -get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) target_link_libraries(toyc-ch5 PRIVATE - ${dialect_libs} - ${extension_libs} MLIRAnalysis MLIRCallInterfaces MLIRCastInterfaces @@ -40,6 +36,9 @@ target_link_libraries(toyc-ch5 MLIRIR MLIRParser MLIRPass + MLIRRegisterAllDialects + MLIRRegisterAllExtensions MLIRSideEffectInterfaces MLIRSupport - MLIRTransforms) + MLIRTransforms + ) diff --git a/mlir/examples/toy/Ch5/toyc.cpp b/mlir/examples/toy/Ch5/toyc.cpp index 6a0c631..afdf782 100644 --- a/mlir/examples/toy/Ch5/toyc.cpp +++ b/mlir/examples/toy/Ch5/toyc.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Func/Extensions/AllExtensions.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Diagnostics.h" #include "toy/AST.h" #include "toy/Dialect.h" diff --git a/mlir/examples/toy/Ch6/CMakeLists.txt b/mlir/examples/toy/Ch6/CMakeLists.txt index 283b895..73df602 100644 --- a/mlir/examples/toy/Ch6/CMakeLists.txt +++ b/mlir/examples/toy/Ch6/CMakeLists.txt @@ -37,14 +37,8 @@ add_toy_chapter(toyc-ch6 include_directories(${CMAKE_CURRENT_BINARY_DIR}) include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/) -get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) -get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) -get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) target_link_libraries(toyc-ch6 PRIVATE - ${dialect_libs} - ${conversion_libs} - ${extension_libs} MLIRAnalysis MLIRBuiltinToLLVMIRTranslation MLIRCallInterfaces @@ -58,8 +52,11 @@ target_link_libraries(toyc-ch6 MLIRMemRefDialect MLIRParser MLIRPass + MLIRRegisterAllDialects + MLIRRegisterAllExtensions + MLIRRegisterAllPasses MLIRSideEffectInterfaces MLIRSupport MLIRTargetLLVMIRExport MLIRTransforms - ) + ) diff --git a/mlir/examples/toy/Ch6/toyc.cpp b/mlir/examples/toy/Ch6/toyc.cpp index dccab91..4a5e109 100644 --- a/mlir/examples/toy/Ch6/toyc.cpp +++ b/mlir/examples/toy/Ch6/toyc.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Func/Extensions/AllExtensions.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" #include "toy/AST.h" diff --git a/mlir/examples/toy/Ch7/CMakeLists.txt b/mlir/examples/toy/Ch7/CMakeLists.txt index 362ab51..a489ae5 100644 --- a/mlir/examples/toy/Ch7/CMakeLists.txt +++ b/mlir/examples/toy/Ch7/CMakeLists.txt @@ -36,14 +36,8 @@ add_toy_chapter(toyc-ch7 include_directories(${CMAKE_CURRENT_BINARY_DIR}) include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/) -get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) -get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) -get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) target_link_libraries(toyc-ch7 PRIVATE - ${dialect_libs} - ${conversion_libs} - ${extension_libs} MLIRAnalysis MLIRBuiltinToLLVMIRTranslation MLIRCallInterfaces @@ -56,7 +50,10 @@ target_link_libraries(toyc-ch7 MLIRMemRefDialect MLIRParser MLIRPass + MLIRRegisterAllDialects + MLIRRegisterAllExtensions + MLIRRegisterAllPasses MLIRSideEffectInterfaces MLIRTargetLLVMIRExport MLIRTransforms - ) + ) diff --git a/mlir/examples/toy/Ch7/toyc.cpp b/mlir/examples/toy/Ch7/toyc.cpp index dd86265..32208ecca 100644 --- a/mlir/examples/toy/Ch7/toyc.cpp +++ b/mlir/examples/toy/Ch7/toyc.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Func/Extensions/AllExtensions.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" #include "toy/AST.h" diff --git a/mlir/examples/transform-opt/CMakeLists.txt b/mlir/examples/transform-opt/CMakeLists.txt index 8e23555..07d58f6 100644 --- a/mlir/examples/transform-opt/CMakeLists.txt +++ b/mlir/examples/transform-opt/CMakeLists.txt @@ -1,18 +1,14 @@ -get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) -get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) -get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) - set(LIBS MLIRAnalysis MLIRIR MLIRParser + MLIRRegisterAllDialects + MLIRRegisterAllExtensions + MLIRRegisterAllPasses MLIRSupport MLIRTransformDialect MLIRTransformDialectTransforms MLIRTransforms - ${dialect_libs} - ${conversion_libs} - ${extension_libs} ) add_mlir_tool(mlir-transform-opt diff --git a/mlir/examples/transform-opt/mlir-transform-opt.cpp b/mlir/examples/transform-opt/mlir-transform-opt.cpp index 1a29913..4b12e76 100644 --- a/mlir/examples/transform-opt/mlir-transform-opt.cpp +++ b/mlir/examples/transform-opt/mlir-transform-opt.cpp @@ -22,6 +22,7 @@ #include "mlir/Tools/mlir-opt/MlirOptMain.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/InitLLVM.h" +#include "llvm/Support/ManagedStatic.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/ToolOutputFile.h" #include <cstdlib> diff --git a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h index 364a70c..b595b6a3 100644 --- a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h +++ b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h @@ -8,6 +8,11 @@ #ifndef MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H #define MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H +constexpr const char *alignedAllocFunctionName = "aligned_alloc"; +constexpr const char *mallocFunctionName = "malloc"; +constexpr const char *cppStandardLibraryHeader = "cstdlib"; +constexpr const char *cStandardLibraryHeader = "stdlib.h"; + namespace mlir { class DialectRegistry; class RewritePatternSet; diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index eb18160..6e1baaf 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -196,6 +196,10 @@ def ConvertArithToSPIRVPass : Pass<"convert-arith-to-spirv"> { "bool", /*default=*/"true", "Emulate narrower scalar types with 32-bit ones if not supported by " "the target">, + Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types", + "bool", /*default=*/"true", + "Emulate unsupported float types by representing them with integer " + "types of same bit width"> ]; } @@ -416,7 +420,11 @@ def ConvertControlFlowToSPIRVPass : Pass<"convert-cf-to-spirv"> { Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types", "bool", /*default=*/"true", "Emulate narrower scalar types with 32-bit ones if not supported by" - " the target"> + " the target">, + Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types", + "bool", /*default=*/"true", + "Emulate unsupported float types by representing them with integer " + "types of same bit width"> ]; } @@ -500,7 +508,11 @@ def ConvertFuncToSPIRVPass : Pass<"convert-func-to-spirv"> { Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types", "bool", /*default=*/"true", "Emulate narrower scalar types with 32-bit ones if not supported by" - " the target"> + " the target">, + Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types", + "bool", /*default=*/"true", + "Emulate unsupported float types by representing them with integer " + "types of same bit width"> ]; } @@ -841,9 +853,13 @@ def ConvertMathToFuncs : Pass<"convert-math-to-funcs", "ModuleOp"> { // MemRefToEmitC //===----------------------------------------------------------------------===// -def ConvertMemRefToEmitC : Pass<"convert-memref-to-emitc"> { +def ConvertMemRefToEmitC : Pass<"convert-memref-to-emitc", "ModuleOp"> { let summary = "Convert MemRef dialect to EmitC dialect"; let dependentDialects = ["emitc::EmitCDialect"]; + let options = [Option< + "lowerToCpp", "lower-to-cpp", "bool", + /*default=*/"false", + /*description=*/"Target C++ (true) instead of C (false)">]; } //===----------------------------------------------------------------------===// @@ -1163,7 +1179,11 @@ def ConvertTensorToSPIRVPass : Pass<"convert-tensor-to-spirv"> { Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types", "bool", /*default=*/"true", "Emulate narrower scalar types with 32-bit ones if not supported by" - " the target"> + " the target">, + Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types", + "bool", /*default=*/"true", + "Emulate unsupported float types by representing them with integer " + "types of same bit width"> ]; } diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td index e81db32..06fb851 100644 --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td @@ -71,6 +71,7 @@ class ArmSME_IntrOp<string mnemonic, /*bit requiresAccessGroup=*/0, /*bit requiresAliasAnalysis=*/0, /*bit requiresFastmath=*/0, + /*bit requiresArgAndResultAttrs=*/0, /*bit requiresOpBundles=*/0, /*list<int> immArgPositions=*/immArgPositions, /*list<string> immArgAttrNames=*/immArgAttrNames>; diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td index 8988df6..d055bb4 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td +++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td @@ -92,6 +92,7 @@ class ArmSVE_IntrOp<string mnemonic, /*bit requiresAccessGroup=*/0, /*bit requiresAliasAnalysis=*/0, /*bit requiresFastmath=*/0, + /*bit requiresArgAndResultAttrs=*/0, /*bit requiresOpBundles=*/0, /*list<int> immArgPositions=*/immArgPositions, /*list<string> immArgAttrNames=*/immArgAttrNames>; diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td index a8455c2..b52f136 100644 --- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td @@ -38,7 +38,8 @@ def Async_ExecuteOp : ["getEntrySuccessorOperands", "areTypesCompatible"]>, AttrSizedOperandSegments, - AutomaticAllocationScope]> { + AutomaticAllocationScope, + RecursiveMemoryEffects]> { let summary = "Asynchronous execute operation"; let description = [{ The `body` region attached to the `async.execute` operation semantically diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h index 2cf801d..09700f8 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h @@ -14,7 +14,7 @@ struct LogicalResult; } // namespace llvm namespace mlir { -class ModuleOp; +class Operation; namespace bufferization { struct BufferizationStatistics; @@ -23,12 +23,13 @@ struct OneShotBufferizationOptions; class BufferizationState; /// Analyze `moduleOp` and its nested ops. Bufferization decisions are stored in -/// `state`. +/// `state`. This operates on any `SymbolTable` op. llvm::LogicalResult -analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state, +analyzeModuleOp(Operation *moduleOp, OneShotAnalysisState &state, BufferizationStatistics *statistics = nullptr); -/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`. +/// Bufferize an `op`s nested ops that implement `BufferizableOpInterface`. +/// This operates on any `SymbolTable` op. /// /// Note: This function does not run One-Shot Analysis. No buffer copies are /// inserted except two cases: @@ -37,20 +38,20 @@ analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state, /// - `options.copyBeforeWrite` is not set and `options.noAnalysisFuncFilter` /// is not empty. The FuncOps it contains were not analyzed. Buffer copies /// will be inserted only to these FuncOps. -llvm::LogicalResult -bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options, - BufferizationState &state, - BufferizationStatistics *statistics = nullptr); +llvm::LogicalResult bufferizeModuleOp( + Operation *moduleOp, const OneShotBufferizationOptions &options, + BufferizationState &state, BufferizationStatistics *statistics = nullptr); -/// Remove bufferization attributes on every FuncOp arguments in the ModuleOp. -void removeBufferizationAttributesInModule(ModuleOp moduleOp); +/// Remove bufferization attributes on every FuncOp arguments in the SymbolTable +/// op. +void removeBufferizationAttributesInModule(Operation *moduleOp); -/// Run One-Shot Module Bufferization on the given module. Performs a simple -/// function call analysis to determine which function arguments are +/// Run One-Shot Module Bufferization on the given SymbolTable. Performs a +/// simple function call analysis to determine which function arguments are /// inplaceable. Then analyzes and bufferizes FuncOps one-by-one with One-Shot /// Bufferize. llvm::LogicalResult runOneShotModuleBufferize( - ModuleOp moduleOp, + Operation *moduleOp, const bufferization::OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics = nullptr); diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index 7fe2da8..937b34a6 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -1659,13 +1659,22 @@ def EmitC_FieldOp : EmitC_Op<"field", [Symbol]> { emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.opaque = "another_feature"} // Example with no attribute: emitc.field @fieldName0 : !emitc.array<1xf32> + // Example with an initial value: + emitc.field @fieldName0 : !emitc.array<1xf32> = dense<0.0> + // Example with an initial value and attributes: + emitc.field @fieldName0 : !emitc.array<1xf32> = dense<0.0> { + emitc.opaque = "input_tensor"} ``` }]; let arguments = (ins SymbolNameAttr:$sym_name, TypeAttr:$type, - OptionalAttr<AnyAttr>:$attrs); + OptionalAttr<EmitC_OpaqueOrTypedAttr>:$initial_value); - let assemblyFormat = [{ $sym_name `:` $type ($attrs^)? attr-dict}]; + let assemblyFormat = [{ + $sym_name + `:` custom<EmitCFieldOpTypeAndInitialValue>($type, $initial_value) + attr-dict + }]; let hasVerifier = 1; } @@ -1686,7 +1695,7 @@ def EmitC_GetFieldOp }]; let arguments = (ins FlatSymbolRefAttr:$field_name); - let results = (outs AnyTypeOf<[EmitC_ArrayType, EmitC_LValueType]>:$result); + let results = (outs EmitCType:$result); let assemblyFormat = "$field_name `:` type($result) attr-dict"; } diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td index 1dbaf5d..2ed7d38 100644 --- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td @@ -1368,12 +1368,14 @@ def GPU_ShuffleOp : GPU_Op< def GPU_RotateOp : GPU_Op< "rotate", [Pure, AllTypesMatch<["value", "rotateResult"]>]>, - Arguments<(ins AnyIntegerOrFloatOr1DVector:$value, I32:$offset, I32:$width)>, + Arguments<(ins AnyIntegerOrFloatOr1DVector:$value, + ConfinedAttr<I32Attr, [IntMinValue<0>]>:$offset, + ConfinedAttr<I32Attr, [IntPowerOf2]>:$width)>, Results<(outs AnyIntegerOrFloatOr1DVector:$rotateResult, I1:$valid)> { let summary = "Rotate values within a subgroup."; let description = [{ The "rotate" op moves values across lanes in a subgroup (a.k.a., local - invocations) within the same subgroup. The `width` argument specifies the + invocations) within the same subgroup. The `width` attribute specifies the number of lanes that participate in the rotation, and must be uniform across all participating lanes. Further, the first `width` lanes of the subgroup must be active. @@ -1394,9 +1396,7 @@ def GPU_RotateOp : GPU_Op< example: ```mlir - %offset = arith.constant 1 : i32 - %width = arith.constant 16 : i32 - %1, %2 = gpu.rotate %0, %offset, %width : f32 + %1, %2 = gpu.rotate %0, 1, 16 : f32 ``` For lane `k`, returns the value from lane `(k + cst1) % width`. @@ -1406,11 +1406,6 @@ def GPU_RotateOp : GPU_Op< $value `,` $offset `,` $width attr-dict `:` type($value) }]; - let builders = [ - // Helper function that creates a rotate with constant offset/width. - OpBuilder<(ins "Value":$value, "int32_t":$offset, "int32_t":$width)> - ]; - let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td index b5ea8fc..107bf3e 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td @@ -83,6 +83,9 @@ def LLVM_Dialect : Dialect { return "llvm.emit_c_interface"; } + /// Name of the module level assembly attribute. + static StringRef getModuleLevelAsmAttrName() { return "llvm.module_asm"; } + /// Name of the dependent libraries attribute. static StringRef getDependentLibrariesAttrName() { return "llvm.dependent_libraries"; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td index 8c6f1ee..d38298f 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td @@ -140,8 +140,8 @@ def LLVM_Log2Op : LLVM_UnaryIntrOpF<"log2">; def LLVM_LogOp : LLVM_UnaryIntrOpF<"log">; def LLVM_Prefetch : LLVM_ZeroResultIntrOp<"prefetch", [0], /*traits=*/[], /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0, - /*requiresOpBundles=*/0, /*immArgPositions=*/[1, 2, 3], - /*immArgAttrNames=*/["rw", "hint", "cache"] + /*requiresArgAndResultAttrs=*/0, /*requiresOpBundles=*/0, + /*immArgPositions=*/[1, 2, 3], /*immArgAttrNames=*/["rw", "hint", "cache"] > { let arguments = (ins LLVM_AnyPointer:$addr, I32Attr:$rw, I32Attr:$hint, I32Attr:$cache); } @@ -200,13 +200,13 @@ class LLVM_MemcpyIntrOpBase<string name> : DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>, DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>], /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1, - /*requiresOpBundles=*/0, /*immArgPositions=*/[3], - /*immArgAttrNames=*/["isVolatile"]> { + /*requiresArgAndResultAttrs=*/1, /*requiresOpBundles=*/0, + /*immArgPositions=*/[3], /*immArgAttrNames=*/["isVolatile"]> { dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst, Arg<LLVM_AnyPointer,"",[MemRead]>:$src, AnySignlessInteger:$len, I1Attr:$isVolatile); - // Append the alias attributes defined by LLVM_IntrOpBase. - let arguments = !con(args, aliasAttrs); + // Append the arguments defined by LLVM_IntrOpBase. + let arguments = !con(args, baseArgs); let builders = [ OpBuilder<(ins "Value":$dst, "Value":$src, "Value":$len, "bool":$isVolatile), [{ @@ -217,7 +217,8 @@ class LLVM_MemcpyIntrOpBase<string name> : "IntegerAttr":$isVolatile), [{ build($_builder, $_state, dst, src, len, isVolatile, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, - /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); + /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr); }]> ]; } @@ -231,13 +232,13 @@ def LLVM_MemcpyInlineOp : DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>, DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>], /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1, - /*requiresOpBundles=*/0, /*immArgPositions=*/[2, 3], - /*immArgAttrNames=*/["len", "isVolatile"]> { + /*requiresArgAndResultAttrs=*/1, /*requiresOpBundles=*/0, + /*immArgPositions=*/[2, 3], /*immArgAttrNames=*/["len", "isVolatile"]> { dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst, Arg<LLVM_AnyPointer,"",[MemRead]>:$src, APIntAttr:$len, I1Attr:$isVolatile); - // Append the alias attributes defined by LLVM_IntrOpBase. - let arguments = !con(args, aliasAttrs); + // Append the arguments defined by LLVM_IntrOpBase. + let arguments = !con(args, baseArgs); let builders = [ OpBuilder<(ins "Value":$dst, "Value":$src, "IntegerAttr":$len, "bool":$isVolatile), [{ @@ -248,7 +249,8 @@ def LLVM_MemcpyInlineOp : "IntegerAttr":$isVolatile), [{ build($_builder, $_state, dst, src, len, isVolatile, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, - /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); + /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr); }]> ]; } @@ -258,12 +260,12 @@ def LLVM_MemsetOp : LLVM_ZeroResultIntrOp<"memset", [0, 2], DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>, DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>], /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1, - /*requiresOpBundles=*/0, /*immArgPositions=*/[3], - /*immArgAttrNames=*/["isVolatile"]> { + /*requiresArgAndResultAttrs=*/1, /*requiresOpBundles=*/0, + /*immArgPositions=*/[3], /*immArgAttrNames=*/["isVolatile"]> { dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst, I8:$val, AnySignlessInteger:$len, I1Attr:$isVolatile); - // Append the alias attributes defined by LLVM_IntrOpBase. - let arguments = !con(args, aliasAttrs); + // Append the arguments defined by LLVM_IntrOpBase. + let arguments = !con(args, baseArgs); let builders = [ OpBuilder<(ins "Value":$dst, "Value":$val, "Value":$len, "bool":$isVolatile), [{ @@ -274,7 +276,8 @@ def LLVM_MemsetOp : LLVM_ZeroResultIntrOp<"memset", [0, 2], "IntegerAttr":$isVolatile), [{ build($_builder, $_state, dst, val, len, isVolatile, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, - /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); + /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr); }]> ]; } @@ -284,12 +287,12 @@ def LLVM_MemsetInlineOp : LLVM_ZeroResultIntrOp<"memset.inline", [0, 2], DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>, DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>], /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1, - /*requiresOpBundles=*/0, /*immArgPositions=*/[2, 3], - /*immArgAttrNames=*/["len", "isVolatile"]> { + /*requiresArgAndResultAttrs=*/1, /*requiresOpBundles=*/0, + /*immArgPositions=*/[2, 3], /*immArgAttrNames=*/["len", "isVolatile"]> { dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst, I8:$val, APIntAttr:$len, I1Attr:$isVolatile); - // Append the alias attributes defined by LLVM_IntrOpBase. - let arguments = !con(args, aliasAttrs); + // Append the arguments defined by LLVM_IntrOpBase. + let arguments = !con(args, baseArgs); let builders = [ OpBuilder<(ins "Value":$dst, "Value":$val, "IntegerAttr":$len, "bool":$isVolatile), [{ @@ -300,7 +303,8 @@ def LLVM_MemsetInlineOp : LLVM_ZeroResultIntrOp<"memset.inline", [0, 2], "IntegerAttr":$isVolatile), [{ build($_builder, $_state, dst, val, len, isVolatile, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, - /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); + /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr); }]> ]; } @@ -349,8 +353,8 @@ def LLVM_PtrMaskOp class LLVM_LifetimeBaseOp<string opName> : LLVM_ZeroResultIntrOp<opName, [1], [DeclareOpInterfaceMethods<PromotableOpInterface>], /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0, - /*requiresOpBundles=*/0, /*immArgPositions=*/[0], - /*immArgAttrNames=*/["size"]> { + /*requiresArgAndResultAttrs=*/0, /*requiresOpBundles=*/0, + /*immArgPositions=*/[0], /*immArgAttrNames=*/["size"]> { let arguments = (ins I64Attr:$size, LLVM_AnyPointer:$ptr); let assemblyFormat = "$size `,` $ptr attr-dict `:` qualified(type($ptr))"; } @@ -370,8 +374,8 @@ def LLVM_InvariantStartOp : LLVM_OneResultIntrOp<"invariant.start", [], [1], def LLVM_InvariantEndOp : LLVM_ZeroResultIntrOp<"invariant.end", [2], [DeclareOpInterfaceMethods<PromotableOpInterface>], /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0, - /*requiresOpBundles=*/0, /*immArgPositions=*/[1], - /*immArgAttrNames=*/["size"]> { + /*requiresArgAndResultAttrs=*/0, /*requiresOpBundles=*/0, + /*immArgPositions=*/[1], /*immArgAttrNames=*/["size"]> { let arguments = (ins LLVM_DefaultPointer:$start, I64Attr:$size, LLVM_AnyPointer:$ptr); @@ -542,9 +546,10 @@ def LLVM_AssumeOp : LLVM_ZeroResultIntrOp<"assume", /*overloadedOperands=*/[], /*traits=*/[], /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0, + /*requiresArgAndResultAttrs=*/0, /*requiresOpBundles=*/1> { dag args = (ins I1:$cond); - let arguments = !con(args, opBundleArgs); + let arguments = !con(args, baseArgs); let assemblyFormat = [{ $cond @@ -1126,8 +1131,8 @@ def LLVM_DebugTrap : LLVM_ZeroResultIntrOp<"debugtrap">; def LLVM_UBSanTrap : LLVM_ZeroResultIntrOp<"ubsantrap", /*overloadedOperands=*/[], /*traits=*/[], /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0, - /*requiresOpBundles=*/0, /*immArgPositions=*/[0], - /*immArgAttrNames=*/["failureKind"]> { + /*requiresArgAndResultAttrs=*/0, /*requiresOpBundles=*/0, + /*immArgPositions=*/[0], /*immArgAttrNames=*/["failureKind"]> { let arguments = (ins I8Attr:$failureKind); } diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td index e845ea9f..a8d7cf2 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -18,6 +18,7 @@ include "mlir/Dialect/LLVMIR/LLVMAttrDefs.td" include "mlir/Dialect/LLVMIR/LLVMInterfaces.td" include "mlir/IR/OpBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/CallInterfaces.td" //===----------------------------------------------------------------------===// // LLVM dialect type constraints. @@ -286,22 +287,26 @@ class LLVM_MemAccessOpBase<string mnemonic, list<Trait> traits = []> : // intrinsic and "enumName" contains the name of the intrinsic as appears in // `llvm::Intrinsic` enum; one usually wants these to be related. Additionally, // the base class also defines the "mlirBuilder" field to support the inverse -// translation starting from an LLVM IR intrinsic. The "requiresAccessGroup", -// "requiresAliasAnalysis", and "requiresFastmath" flags specify which -// interfaces the intrinsic implements. If the corresponding flags are set, the -// "aliasAttrs" list contains the arguments required by the access group and -// alias analysis interfaces. Derived intrinsics should append the "aliasAttrs" -// to their argument list if they set one of the flags. LLVM `immargs` can be -// represented as MLIR attributes by providing both the `immArgPositions` and -// `immArgAttrNames` lists. These two lists should have equal length, with -// `immArgPositions` containing the argument positions on the LLVM IR attribute -// that are `immargs`, and `immArgAttrNames` mapping these to corresponding -// MLIR attributes. +// translation starting from an LLVM IR intrinsic. +// +// The flags "requiresAccessGroup", "requiresAliasAnalysis", +// "requiresFastmath", and "requiresArgAndResultAttrs" indicate which +// interfaces the intrinsic implements. When a flag is set, the "baseArgs" +// list includes the arguments required by the corresponding interface. +// Derived intrinsics must append "baseArgs" to their argument list if they +// enable any of these flags. +// +// LLVM `immargs` can be represented as MLIR attributes by providing both +// the `immArgPositions` and `immArgAttrNames` lists. These two lists should +// have equal length, with `immArgPositions` containing the argument +// positions on the LLVM IR attribute that are `immargs`, and +// `immArgAttrNames` mapping these to corresponding MLIR attributes. class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName, list<int> overloadedResults, list<int> overloadedOperands, list<Trait> traits, int numResults, bit requiresAccessGroup = 0, bit requiresAliasAnalysis = 0, - bit requiresFastmath = 0, bit requiresOpBundles = 0, + bit requiresFastmath = 0, bit requiresArgAndResultAttrs = 0, + bit requiresOpBundles = 0, list<int> immArgPositions = [], list<string> immArgAttrNames = []> : LLVM_OpBase<dialect, opName, !listconcat( @@ -311,10 +316,12 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName, [DeclareOpInterfaceMethods<AliasAnalysisOpInterface>], []), !if(!gt(requiresFastmath, 0), [DeclareOpInterfaceMethods<FastmathFlagsInterface>], []), + !if(!gt(requiresArgAndResultAttrs, 0), + [DeclareOpInterfaceMethods<ArgAndResultAttrsOpInterface>], []), traits)>, LLVM_MemOpPatterns, Results<!if(!gt(numResults, 0), (outs LLVM_Type:$res), (outs))> { - dag aliasAttrs = !con( + dag baseArgs = !con( !if(!gt(requiresAccessGroup, 0), (ins OptionalAttr<LLVM_AccessGroupArrayAttr>:$access_groups), (ins )), @@ -322,13 +329,17 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName, (ins OptionalAttr<LLVM_AliasScopeArrayAttr>:$alias_scopes, OptionalAttr<LLVM_AliasScopeArrayAttr>:$noalias_scopes, OptionalAttr<LLVM_TBAATagArrayAttr>:$tbaa), + (ins )), + !if(!gt(requiresArgAndResultAttrs, 0), + (ins OptionalAttr<DictArrayAttr>:$arg_attrs, + OptionalAttr<DictArrayAttr>:$res_attrs), + (ins )), + !if(!gt(requiresOpBundles, 0), + (ins VariadicOfVariadic<LLVM_Type, + "op_bundle_sizes">:$op_bundle_operands, + DenseI32ArrayAttr:$op_bundle_sizes, + OptionalAttr<ArrayAttr>:$op_bundle_tags), (ins ))); - dag opBundleArgs = !if(!gt(requiresOpBundles, 0), - (ins VariadicOfVariadic<LLVM_Type, - "op_bundle_sizes">:$op_bundle_operands, - DenseI32ArrayAttr:$op_bundle_sizes, - OptionalAttr<ArrayAttr>:$op_bundle_tags), - (ins )); string llvmEnumName = enumName; string overloadedResultsCpp = "{" # !interleave(overloadedResults, ", ") # "}"; string overloadedOperandsCpp = "{" # !interleave(overloadedOperands, ", ") # "}"; @@ -342,23 +353,35 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName, immArgPositionsCpp, immArgAttrNamesCpp], ",") # [{); (void) inst; }]; + string baseLlvmBuilderArgAndResultAttrs = [{ + if (failed(moduleTranslation.convertArgAndResultAttrs( + op, + inst, + }] # immArgPositionsCpp # [{))) { + return failure(); + } + }]; string baseLlvmBuilderCoda = !if(!gt(numResults, 0), "$res = inst;", ""); - let llvmBuilder = baseLlvmBuilder # !if(!gt(requiresAccessGroup, 0), setAccessGroupsMetadataCode, "") - # !if(!gt(requiresAliasAnalysis, 0), setAliasAnalysisMetadataCode, "") - # baseLlvmBuilderCoda; + let llvmBuilder = baseLlvmBuilder + # !if(!gt(requiresAccessGroup, 0), + setAccessGroupsMetadataCode, "") + # !if(!gt(requiresAliasAnalysis, 0), + setAliasAnalysisMetadataCode, "") + # !if(!gt(requiresArgAndResultAttrs, 0), + baseLlvmBuilderArgAndResultAttrs, "") + # baseLlvmBuilderCoda; string baseMlirBuilder = [{ SmallVector<Value> mlirOperands; SmallVector<NamedAttribute> mlirAttrs; if (failed(moduleImport.convertIntrinsicArguments( - llvmOperands, - llvmOpBundles, - }] # !if(!gt(requiresOpBundles, 0), "true", "false") # [{, - }] # immArgPositionsCpp # [{, - }] # immArgAttrNamesCpp # [{, - mlirOperands, - mlirAttrs)) - ) { + llvmOperands, + llvmOpBundles, + }] # !if(!gt(requiresOpBundles, 0), "true", "false") # [{, + }] # immArgPositionsCpp # [{, + }] # immArgAttrNamesCpp # [{, + mlirOperands, + mlirAttrs))) { return failure(); } SmallVector<Type> resultTypes = @@ -366,9 +389,16 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName, auto op = $_qualCppClassName::create($_builder, $_location, resultTypes, mlirOperands, mlirAttrs); }]; + string baseMlirBuilderArgAndResultAttrs = [{ + moduleImport.convertArgAndResultAttrs( + inst, op, }] # immArgPositionsCpp # [{); + }]; string baseMlirBuilderCoda = !if(!gt(numResults, 0), "$res = op;", "$_op = op;"); - let mlirBuilder = baseMlirBuilder # !if(!gt(requiresFastmath, 0), + let mlirBuilder = baseMlirBuilder + # !if(!gt(requiresFastmath, 0), "moduleImport.setFastmathFlagsAttr(inst, op);", "") + # !if(!gt(requiresArgAndResultAttrs, 0), + baseMlirBuilderArgAndResultAttrs, "") # baseMlirBuilderCoda; // Code for handling a `range` attribute that holds the constant range of the @@ -399,14 +429,14 @@ class LLVM_IntrOp<string mnem, list<int> overloadedResults, list<int> overloadedOperands, list<Trait> traits, int numResults, bit requiresAccessGroup = 0, bit requiresAliasAnalysis = 0, bit requiresFastmath = 0, - bit requiresOpBundles = 0, + bit requiresArgAndResultAttrs = 0, bit requiresOpBundles = 0, list<int> immArgPositions = [], list<string> immArgAttrNames = []> : LLVM_IntrOpBase<LLVM_Dialect, "intr." # mnem, !subst(".", "_", mnem), overloadedResults, overloadedOperands, traits, numResults, requiresAccessGroup, requiresAliasAnalysis, - requiresFastmath, requiresOpBundles, immArgPositions, - immArgAttrNames>; + requiresFastmath, requiresArgAndResultAttrs, + requiresOpBundles, immArgPositions, immArgAttrNames>; // Base class for LLVM intrinsic operations returning no results. Places the // intrinsic into the LLVM dialect and prefixes its name with "intr.". @@ -426,13 +456,14 @@ class LLVM_ZeroResultIntrOp<string mnem, list<int> overloadedOperands = [], list<Trait> traits = [], bit requiresAccessGroup = 0, bit requiresAliasAnalysis = 0, + bit requiresArgAndResultAttrs = 0, bit requiresOpBundles = 0, list<int> immArgPositions = [], list<string> immArgAttrNames = []> : LLVM_IntrOp<mnem, [], overloadedOperands, traits, /*numResults=*/0, requiresAccessGroup, requiresAliasAnalysis, - /*requiresFastMath=*/0, requiresOpBundles, immArgPositions, - immArgAttrNames>; + /*requiresFastMath=*/0, requiresArgAndResultAttrs, + requiresOpBundles, immArgPositions, immArgAttrNames>; // Base class for LLVM intrinsic operations returning one result. Places the // intrinsic into the LLVM dialect and prefixes its name with "intr.". This is @@ -448,7 +479,8 @@ class LLVM_OneResultIntrOp<string mnem, list<int> overloadedResults = [], list<string> immArgAttrNames = []> : LLVM_IntrOp<mnem, overloadedResults, overloadedOperands, traits, 1, /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0, - requiresFastmath, /*requiresOpBundles=*/0, immArgPositions, + requiresFastmath, /*requiresArgAndResultAttrs=*/0, + /*requiresOpBundles=*/0, immArgPositions, immArgAttrNames>; // Base class for LLVM intrinsic operations returning two results. Places the @@ -465,7 +497,8 @@ class LLVM_TwoResultIntrOp<string mnem, list<int> overloadedResults = [], list<string> immArgAttrNames = []> : LLVM_IntrOp<mnem, overloadedResults, overloadedOperands, traits, 2, /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0, - requiresFastmath, /*requiresOpBundles=*/0, immArgPositions, + requiresFastmath, /*requiresArgAndResultAttrs=*/0, + /*requiresOpBundles=*/0, immArgPositions, immArgAttrNames>; def LLVM_OneResultOpBuilder : diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 51004f5..3f27f6d 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -2405,7 +2405,8 @@ def LLVM_InlineAsmOp : LLVM_Op<"inline_asm", [DeclareOpInterfaceMethods<MemoryEf def LLVM_CallIntrinsicOp : LLVM_Op<"call_intrinsic", - [AttrSizedOperandSegments, + [ArgAndResultAttrsOpInterface, + AttrSizedOperandSegments, DeclareOpInterfaceMethods<FastmathFlagsInterface>]> { let summary = "Call to an LLVM intrinsic function."; let description = [{ diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 45a8904..30df3b7 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -1990,10 +1990,30 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">, let hasVerifier = 1; } -def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">, - Arguments<(ins LLVM_PointerShared:$ptr, - Variadic<I32>:$sources, - MMALayoutAttr:$layout)> { +def LdStMatrixShapeAttr : NVVM_Attr<"LdStMatrixShape", "ld_st_matrix_shape"> { + let summary = "Matrix shape for ldmatrix and stmatrix"; + let parameters = (ins "int":$m, "int":$n); + let assemblyFormat = "`<` struct(params) `>`"; +} + +def LdStMatrixEltTypeB16 : I32EnumAttrCase<"B16", 0, "b16">; +def LdStMatrixEltTypeB8 : I32EnumAttrCase<"B8", 1, "b8">; +def LdStMatrixEltTypeB8X16_B6X16_P32 : I32EnumAttrCase<"B8X16_B6X16_P32", 2, "b8x16.b6x16_p32">; +def LdStMatrixEltTypeB8X16_B4X16_P64 : I32EnumAttrCase<"B8X16_B4X16_P64", 3, "b8x16.b4x16_p64">; + +def LdStMatrixEltType : I32EnumAttr<"LdStMatrixEltType", "Element type for ldmatrix and stmatrix", + [LdStMatrixEltTypeB16, LdStMatrixEltTypeB8, + LdStMatrixEltTypeB8X16_B6X16_P32, LdStMatrixEltTypeB8X16_B4X16_P64]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::NVVM"; +} +def LdStMatrixEltTypeAttr : EnumAttr<NVVM_Dialect, LdStMatrixEltType, "ld_st_matrix_elt_type"> { + let assemblyFormat = "`<` $value `>`"; +} + +def NVVM_StMatrixOp: NVVM_Op<"stmatrix">, + Arguments<(ins LLVM_PointerShared: $ptr, Variadic<I32>:$sources, MMALayoutAttr:$layout, + LdStMatrixShapeAttr:$shape, LdStMatrixEltTypeAttr:$eltType)> { let summary = "cooperative matrix store"; let description = [{ Collectively store one or more matrices across all threads in a warp to the @@ -2001,21 +2021,12 @@ def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">, [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix) }]; - - let assemblyFormat = "$ptr `,` $sources attr-dict `:` type(operands)"; - let extraClassDefinition = [{ - std::string $cppClass::getPtx() { - int d = getSources().size(); - std::string ptx = "stmatrix.sync.aligned"; - ptx += ".x" + std::to_string(d); - if (getLayout() == NVVM::MMALayout::col) - ptx += ".trans"; - if(d == 1) ptx += ".m8n8.shared.b16 [%0], {%1};"; - if(d == 2) ptx += ".m8n8.shared.b16 [%0], {%1, %2};"; - if(d == 4) ptx += ".m8n8.shared.b16 [%0], {%1, %2, %3, %4};"; - return ptx; - } + string llvmBuilder = [{ + auto operands = moduleTranslation.lookupValues(opInst.getOperands()); + auto intId = getStMatrixIntrinsicId($layout, $sources.size(), $shape, $eltType); + createIntrinsicCall(builder, intId, operands, operands[0]->getType()); }]; + let assemblyFormat = "$ptr `,` $sources attr-dict `:` type(operands)"; let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index 04a0b58..a2354e2 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -98,7 +98,7 @@ class ROCDL_IntrOp<string mnemonic, list<int> overloadedResults, LLVM_IntrOpBase<ROCDL_Dialect, mnemonic, "amdgcn_" # !subst(".", "_", mnemonic), overloadedResults, overloadedOperands, traits, numResults, requiresAccessGroup, - requiresAliasAnalysis, 0, 0, immArgPositions, immArgAttrNames>; + requiresAliasAnalysis, 0, 0, 0, immArgPositions, immArgAttrNames>; // Subclass to save typing and ease readibility when there aren't overloaded // operands or memory accesses. @@ -482,7 +482,7 @@ def ROCDLBufferLDS : LLVM_PointerInAddressSpace<3>; class ROCDL_LDS_Read_Tr_IntrOp<string mnemonic> : ROCDL_IntrOp<mnemonic, [1], [], [], 1, 0, 1> { dag args = (ins Arg<ROCDLBufferLDS, "", [MemRead]>:$ptr); - let arguments = !con(args, aliasAttrs); + let arguments = !con(args, baseArgs); let assemblyFormat = "$ptr attr-dict `:` type($ptr) `->` type($res)"; let extraClassDefinition = [{ ::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() { @@ -507,7 +507,7 @@ def ROCDL_LoadToLDSOp : I32Attr:$size, I32Attr:$offset, I32Attr:$aux); - let arguments = !con(args, aliasAttrs); + let arguments = !con(args, baseArgs); let assemblyFormat = [{ $globalPtr `,` $ldsPtr `,` $size `,` $offset `,` $aux attr-dict `:` type($globalPtr) @@ -526,7 +526,7 @@ def ROCDL_GlobalLoadLDSOp : I32Attr:$size, I32Attr:$offset, I32Attr:$aux); - let arguments = !con(args, aliasAttrs); + let arguments = !con(args, baseArgs); let assemblyFormat = [{ $globalPtr `,` $ldsPtr `,` $size `,` $offset `,` $aux attr-dict @@ -561,7 +561,7 @@ def ROCDL_RawPtrBufferLoadOp : I32:$offset, I32:$soffset, I32:$aux); - let arguments = !con(args, aliasAttrs); + let arguments = !con(args, baseArgs); let assemblyFormat = "operands attr-dict `:` type($res)"; let extraClassDefinition = [{ ::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() { @@ -579,7 +579,7 @@ def ROCDL_RawPtrBufferLoadLdsOp : I32:$soffset, I32:$offset, I32:$aux); - let arguments = !con(args, aliasAttrs); + let arguments = !con(args, baseArgs); let assemblyFormat = "operands attr-dict"; let extraClassDefinition = [{ ::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() { @@ -595,7 +595,7 @@ def ROCDL_RawPtrBufferStoreOp : I32:$offset, I32:$soffset, I32:$aux); - let arguments = !con(args, aliasAttrs); + let arguments = !con(args, baseArgs); let assemblyFormat = "operands attr-dict `:` type($vdata)"; let extraClassDefinition = [{ ::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() { @@ -614,7 +614,7 @@ def ROCDL_RawPtrBufferAtomicCmpSwap : I32:$offset, I32:$soffset, I32:$aux); - let arguments = !con(args, aliasAttrs); + let arguments = !con(args, baseArgs); let assemblyFormat = "operands attr-dict `:` type($res)"; let extraClassDefinition = [{ ::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() { @@ -630,7 +630,7 @@ class ROCDL_RawPtrBufferAtomicNoRet<string op> : I32:$offset, I32:$soffset, I32:$aux); - let arguments = !con(args, aliasAttrs); + let arguments = !con(args, baseArgs); let assemblyFormat = "operands attr-dict `:` type($vdata)"; let extraClassDefinition = [{ ::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() { diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 8d45c40..61ce23f 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -1191,6 +1191,7 @@ def PadTilingInterfaceOp : Op<Transform_Dialect, "structured.pad_tiling_interfac iteration domain induces a padding of the operands that is consistent across the op semantics and, unlike for simple elementwise ops, may not be trivially deducible or specifiable on operands only (e.g. convolutions). + Currently, only a limited set of projected permutation maps are supported. The specification of `padding_sizes` follows that of `tile_sizes` during tiling: the value "0" on a particular iterator encode "no padding". Like in diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index e625eef..d4ffe0a 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -611,6 +611,13 @@ LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, /// affine.apply operations. /// The `indexingMap` + `indexingSizes` encoding suits StructuredOps and /// provides a gentle portability path for Linalg-like ops with affine maps. +/// The padded shape is computed by evaluating the maximum accessed index per +/// dimension, which may involve multiplying by constant factors derived from +/// the affine indexing expressions. Currently, only a limited set of projected +/// permuation indexing maps are supported, such as +/// - affine_map<(d0, d1, d2) -> (d0, d1)> +/// - affine_map<(d0, d1, d2) -> (d0, d1 + d2)> +/// - affine_map<(d0, d1) -> (d0 * 3 + d1)> /// In the future, more general interfaces can be devised to encode similar /// shape evolutions and map between an op and its operands. SmallVector<OpFoldResult> diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td index 96b9adc..18d5f2d 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -134,6 +134,24 @@ def OpenACC_VariableTypeCategory : I32BitEnumAttr< let printBitEnumPrimaryGroups = 1; } +// These are parallelism determination modes for `acc loop`. +// In the enum names, we use the "loop_" prefix because "auto" is +// a language keyword - and thus for consistency all other cases +// do the same. +def OpenACC_LoopSeq : I32EnumAttrCase<"loop_seq", 0>; +def OpenACC_LoopAuto : I32EnumAttrCase<"loop_auto", 1>; +def OpenACC_LoopIndependent : I32EnumAttrCase<"loop_independent", 2>; + +def OpenACC_LoopParMode : I32EnumAttr< + "LoopParMode", + "Encodes the options for loop parallelism determination mode", + [ + OpenACC_LoopAuto, OpenACC_LoopIndependent, + OpenACC_LoopSeq]> { + let cppNamespace = "::mlir::acc"; + let genSpecializedAttr = 0; +} + // Type used in operation below. def IntOrIndex : AnyTypeOf<[AnyInteger, Index]>; @@ -1514,6 +1532,10 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel", /// types. void addWaitOperands(MLIRContext *, bool hasDevnum, mlir::ValueRange, llvm::ArrayRef<DeviceType>); + + /// Adds a private clause variable to this operation, including its recipe. + void addPrivatization(MLIRContext *, mlir::acc::PrivateOp op, + mlir::acc::PrivateRecipeOp recipe); }]; let assemblyFormat = [{ @@ -1656,6 +1678,9 @@ def OpenACC_SerialOp : OpenACC_Op<"serial", /// types. void addWaitOperands(MLIRContext *, bool hasDevnum, mlir::ValueRange, llvm::ArrayRef<DeviceType>); + /// Adds a private clause variable to this operation, including its recipe. + void addPrivatization(MLIRContext *, mlir::acc::PrivateOp op, + mlir::acc::PrivateRecipeOp recipe); }]; let assemblyFormat = [{ @@ -2373,6 +2398,15 @@ def OpenACC_LoopOp : OpenACC_Op<"loop", // Return whether this LoopOp has a gang, worker, or vector applying to the // 'default'/None device-type. bool hasDefaultGangWorkerVector(); + + // Used to obtain the parallelism mode for the requested device type. + // This first checks if the mode is set for the device_type requested. + // And if not, it returns the non-device_type mode. + LoopParMode getDefaultOrDeviceTypeParallelism(DeviceType); + + /// Adds a private clause variable to this operation, including its recipe. + void addPrivatization(MLIRContext *, mlir::acc::PrivateOp op, + mlir::acc::PrivateRecipeOp recipe); }]; let hasCustomAssemblyFormat = 1; @@ -2404,6 +2438,53 @@ def OpenACC_LoopOp : OpenACC_Op<"loop", }]; let hasVerifier = 1; + + let builders = [ + OpBuilder<(ins "::mlir::ValueRange":$lowerbounds, + "::mlir::ValueRange":$upperbounds, + "::mlir::ValueRange":$steps, + "LoopParMode":$parMode), [{ + auto deviceNoneAttr = mlir::acc::DeviceTypeAttr::get( + $_builder.getContext(), mlir::acc::DeviceType::None); + auto arrOfDeviceNone = mlir::ArrayAttr::get( + $_builder.getContext(), deviceNoneAttr); + build($_builder, $_state, + /*results=*/{}, + /*lowerbound=*/lowerbounds, + /*upperbound=*/upperbounds, + /*step=*/steps, + /*inclusiveUpperbound=*/nullptr, + /*collapse=*/nullptr, + /*collapseDeviceType=*/nullptr, + /*gangOperands=*/{}, + /*gangOperandsArgType=*/nullptr, + /*gangOperandsSegments=*/nullptr, + /*gangOperandsDeviceType=*/nullptr, + /*workerNumOperands=*/{}, + /*workerNumOperandsDeviceType=*/nullptr, + /*vectorOperands=*/{}, + /*vectorOperandsDeviceType=*/nullptr, + /*seq=*/parMode == LoopParMode::loop_seq ? + arrOfDeviceNone : nullptr, + /*independent=*/parMode == LoopParMode::loop_independent ? + arrOfDeviceNone : nullptr, + /*auto_=*/parMode == LoopParMode::loop_auto ? + arrOfDeviceNone : nullptr, + /*gang=*/nullptr, + /*worker=*/nullptr, + /*vector=*/nullptr, + /*tileOperands=*/{}, + /*tileOperandsSegments=*/nullptr, + /*tileOperandsDeviceType=*/nullptr, + /*cacheOperands=*/{}, + /*privateOperands=*/{}, + /*privatizationRecipes=*/nullptr, + /*reductionOperands=*/{}, + /*reductionRecipes=*/nullptr, + /*combined=*/nullptr); + }] + > + ]; } // Yield operation for the acc.loop and acc.parallel operations. diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index 2d15544..0c1c15b 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -87,6 +87,9 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [ be accessed inside the op. The op's region can have multiple blocks and the blocks can have multiple distinct terminators. Values returned from this op's region define the op's results. + The optional 'no_inline' flag can be set to request the ExecuteRegionOp to be + preserved as much as possible and not being inlined in the parent block until + an explicit lowering step. Example: @@ -98,6 +101,14 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [ } } + // the same as above but with no_inline attribute + scf.for %i = 0 to 128 step %c1 { + %y = scf.execute_region -> i32 no_inline { + %x = load %A[%i] : memref<128xi32> + scf.yield %x : i32 + } + } + affine.for %i = 0 to 100 { "foo"() : () -> () %v = scf.execute_region -> i64 { @@ -119,6 +130,10 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [ ``` }]; + let arguments = (ins + UnitAttr:$no_inline + ); + let results = (outs Variadic<AnyType>); let regions = (region AnyRegion:$region); diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index 9038326..bdfd728 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -405,6 +405,7 @@ def SPV_INTEL_memory_access_aliasing : I32EnumAttrCase<"SPV_INTEL_me def SPV_INTEL_split_barrier : I32EnumAttrCase<"SPV_INTEL_split_barrier", 4029>; def SPV_INTEL_bfloat16_conversion : I32EnumAttrCase<"SPV_INTEL_bfloat16_conversion", 4031>; def SPV_INTEL_cache_controls : I32EnumAttrCase<"SPV_INTEL_cache_controls", 4032>; +def SPV_INTEL_tensor_float32_conversion : I32EnumAttrCase<"SPV_INTEL_tensor_float32_conversion", 4033>; def SPV_NV_compute_shader_derivatives : I32EnumAttrCase<"SPV_NV_compute_shader_derivatives", 5000>; def SPV_NV_cooperative_matrix : I32EnumAttrCase<"SPV_NV_cooperative_matrix", 5001>; @@ -468,6 +469,7 @@ def SPIRV_ExtensionAttr : SPV_INTEL_debug_module, SPV_INTEL_fp_fast_math_mode, SPV_INTEL_memory_access_aliasing, SPV_INTEL_split_barrier, SPV_INTEL_bfloat16_conversion, SPV_INTEL_cache_controls, + SPV_INTEL_tensor_float32_conversion, SPV_NV_compute_shader_derivatives, SPV_NV_cooperative_matrix, SPV_NV_fragment_shader_barycentric, SPV_NV_geometry_shader_passthrough, SPV_NV_ray_tracing, SPV_NV_sample_mask_override_coverage, @@ -1465,6 +1467,12 @@ def SPIRV_C_Bfloat16ConversionINTEL : I32EnumAttrCase<"B ]; } +def SPIRV_C_TensorFloat32RoundingINTEL : I32EnumAttrCase<"TensorFloat32RoundingINTEL", 6425> { + list<Availability> availability = [ + Extension<[SPV_INTEL_tensor_float32_conversion]> + ]; +} + def SPIRV_C_CacheControlsINTEL : I32EnumAttrCase<"CacheControlsINTEL", 6441> { list<Availability> availability = [ Extension<[SPV_INTEL_cache_controls]> @@ -1567,7 +1575,8 @@ def SPIRV_CapabilityAttr : SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV, SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL, SPIRV_C_CacheControlsINTEL, SPIRV_C_BFloat16TypeKHR, - SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR + SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR, + SPIRV_C_TensorFloat32RoundingINTEL ]>; def SPIRV_AM_Logical : I32EnumAttrCase<"Logical", 0>; @@ -4277,7 +4286,7 @@ class SPIRV_MatrixOfType<list<Type> allowedTypes> : "Matrix">; class SPIRV_VectorOf<Type type> : - VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>; + FixedVectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>; class SPIRV_ScalarOrVectorOf<Type type> : AnyTypeOf<[type, SPIRV_VectorOf<type>]>; @@ -4448,6 +4457,7 @@ def SPIRV_OC_OpUMulExtended : I32EnumAttrCase<"OpUMulExtended" def SPIRV_OC_OpSMulExtended : I32EnumAttrCase<"OpSMulExtended", 152>; def SPIRV_OC_OpIsNan : I32EnumAttrCase<"OpIsNan", 156>; def SPIRV_OC_OpIsInf : I32EnumAttrCase<"OpIsInf", 157>; +def SPIRV_OC_OpIsFinite : I32EnumAttrCase<"OpIsFinite", 158>; def SPIRV_OC_OpOrdered : I32EnumAttrCase<"OpOrdered", 162>; def SPIRV_OC_OpUnordered : I32EnumAttrCase<"OpUnordered", 163>; def SPIRV_OC_OpLogicalEqual : I32EnumAttrCase<"OpLogicalEqual", 164>; @@ -4586,6 +4596,7 @@ def SPIRV_OC_OpControlBarrierArriveINTEL : I32EnumAttrCase<"OpControlBarrie def SPIRV_OC_OpControlBarrierWaitINTEL : I32EnumAttrCase<"OpControlBarrierWaitINTEL", 6143>; def SPIRV_OC_OpGroupIMulKHR : I32EnumAttrCase<"OpGroupIMulKHR", 6401>; def SPIRV_OC_OpGroupFMulKHR : I32EnumAttrCase<"OpGroupFMulKHR", 6402>; +def SPIRV_OC_OpRoundFToTF32INTEL : I32EnumAttrCase<"OpRoundFToTF32INTEL", 6426>; def SPIRV_OpcodeAttr : SPIRV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [ @@ -4630,7 +4641,8 @@ def SPIRV_OpcodeAttr : SPIRV_OC_OpVectorTimesMatrix, SPIRV_OC_OpMatrixTimesVector, SPIRV_OC_OpMatrixTimesMatrix, SPIRV_OC_OpDot, SPIRV_OC_OpIAddCarry, SPIRV_OC_OpISubBorrow, SPIRV_OC_OpUMulExtended, SPIRV_OC_OpSMulExtended, - SPIRV_OC_OpIsNan, SPIRV_OC_OpIsInf, SPIRV_OC_OpOrdered, SPIRV_OC_OpUnordered, + SPIRV_OC_OpIsNan, SPIRV_OC_OpIsInf, SPIRV_OC_OpIsFinite, + SPIRV_OC_OpOrdered, SPIRV_OC_OpUnordered, SPIRV_OC_OpLogicalEqual, SPIRV_OC_OpLogicalNotEqual, SPIRV_OC_OpLogicalOr, SPIRV_OC_OpLogicalAnd, SPIRV_OC_OpLogicalNot, SPIRV_OC_OpSelect, SPIRV_OC_OpIEqual, SPIRV_OC_OpINotEqual, SPIRV_OC_OpUGreaterThan, @@ -4690,7 +4702,8 @@ def SPIRV_OpcodeAttr : SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT, SPIRV_OC_OpConvertFToBF16INTEL, SPIRV_OC_OpConvertBF16ToFINTEL, SPIRV_OC_OpControlBarrierArriveINTEL, SPIRV_OC_OpControlBarrierWaitINTEL, - SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR + SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR, + SPIRV_OC_OpRoundFToTF32INTEL ]>; // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY! diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td index 82d26e3..2a7fa53 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td @@ -11,6 +11,7 @@ // at (https://github.com/intel/llvm) // Supported extensions // * SPV_INTEL_bfloat16_conversion +// * SPV_INTEL_tensor_float32_conversion //===----------------------------------------------------------------------===// @@ -19,7 +20,7 @@ // ----- -def SPIRV_INTELConvertFToBF16Op : SPIRV_IntelVendorOp<"ConvertFToBF16", []> { +def SPIRV_INTELConvertFToBF16Op : SPIRV_IntelVendorOp<"ConvertFToBF16", [SameOperandsAndResultShape]> { let summary = "See extension SPV_INTEL_bfloat16_conversion"; let description = [{ @@ -58,16 +59,17 @@ def SPIRV_INTELConvertFToBF16Op : SPIRV_IntelVendorOp<"ConvertFToBF16", []> { let results = (outs SPIRV_ScalarOrVectorOf<SPIRV_Int16>:$result ); + let assemblyFormat = [{ $operand attr-dict `:` type($operand) `to` type($result) }]; - let hasVerifier = 1; + let hasVerifier = 0; } // ----- -def SPIRV_INTELConvertBF16ToFOp : SPIRV_IntelVendorOp<"ConvertBF16ToF", []> { +def SPIRV_INTELConvertBF16ToFOp : SPIRV_IntelVendorOp<"ConvertBF16ToF", [SameOperandsAndResultShape]> { let summary = "See extension SPV_INTEL_bfloat16_conversion"; let description = [{ @@ -107,9 +109,57 @@ def SPIRV_INTELConvertBF16ToFOp : SPIRV_IntelVendorOp<"ConvertBF16ToF", []> { let assemblyFormat = [{ $operand attr-dict `:` type($operand) `to` type($result) }]; - let hasVerifier = 1; + + let hasVerifier = 0; } +// ----- + +def SPIRV_INTELRoundFToTF32Op : SPIRV_IntelVendorOp<"RoundFToTF32", [SameOperandsAndResultShape]> { + let summary = "See extension SPV_INTEL_tensor_float32_conversion"; + + let description = [{ + Convert value numerically from a 32-bit floating point type to tensor float32, + with rounding to the nearest even. + + Result Type must be a scalar or vector of 32-bit floating-point type. + The component width must be 32 bits. Bit pattern in the Result represents a tensor float32 value. + + Float Value must be a scalar or vector of floating-point type. + It must have the same number of components as Result Type. The component width must be 32 bits. + + Results are computed per component. + + #### Example: + + ```mlir + %1 = spirv.RoundFToTF32 %0 : f32 to f32 + %3 = spirv.RoundFToTF32 %2 : vector<3xf32> to vector<3xf32> + ``` + + }]; + + let availability = [ + MinVersion<SPIRV_V_1_0>, + MaxVersion<SPIRV_V_1_6>, + Extension<[SPV_INTEL_tensor_float32_conversion]>, + Capability<[SPIRV_C_TensorFloat32RoundingINTEL]> + ]; + + let arguments = (ins + SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$operand + ); + + let results = (outs + SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$result + ); + + let assemblyFormat = [{ + $operand attr-dict `:` type($operand) `to` type($result) + }]; + + let hasVerifier = 0; +} // ----- diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td index ab535d7..9331fc5 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td @@ -403,6 +403,28 @@ def SPIRV_INotEqualOp : SPIRV_LogicalBinaryOp<"INotEqual", // ----- +def SPIRV_IsFiniteOp : SPIRV_LogicalUnaryOp<"IsFinite", SPIRV_Float, []> { + let summary = "Result is true if x is an IEEE Finite, otherwise result is false"; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + x must be a scalar or vector of floating-point type. It must have the + same number of components as Result Type. + + Results are computed per component. + + #### Example: + + ```mlir + %2 = spirv.IsFinite %0: f32 + %3 = spirv.IsFinite %1: vector<4xf32> + ``` + }]; +} + +// ----- + def SPIRV_IsInfOp : SPIRV_LogicalUnaryOp<"IsInf", SPIRV_Float, []> { let summary = "Result is true if x is an IEEE Inf, otherwise result is false"; @@ -418,7 +440,7 @@ def SPIRV_IsInfOp : SPIRV_LogicalUnaryOp<"IsInf", SPIRV_Float, []> { ```mlir %2 = spirv.IsInf %0: f32 - %3 = spirv.IsInf %1: vector<4xi32> + %3 = spirv.IsInf %1: vector<4xf32> ``` }]; } @@ -442,7 +464,7 @@ def SPIRV_IsNanOp : SPIRV_LogicalUnaryOp<"IsNan", SPIRV_Float, []> { ```mlir %2 = spirv.IsNan %0: f32 - %3 = spirv.IsNan %1: vector<4xi32> + %3 = spirv.IsNan %1: vector<4xf32> ``` }]; } diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h index c691d59..531fecc 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h @@ -330,10 +330,34 @@ public: bool hasValue() const { return !isa<UnitAttr>(decorationValue); } }; + // Type for specifying the decoration(s) on the struct itself. + struct StructDecorationInfo { + Decoration decoration; + Attribute decorationValue; + + StructDecorationInfo(Decoration decoration, Attribute decorationValue) + : decoration(decoration), decorationValue(decorationValue) {} + + friend bool operator==(const StructDecorationInfo &lhs, + const StructDecorationInfo &rhs) { + return lhs.decoration == rhs.decoration && + lhs.decorationValue == rhs.decorationValue; + } + + friend bool operator<(const StructDecorationInfo &lhs, + const StructDecorationInfo &rhs) { + return llvm::to_underlying(lhs.decoration) < + llvm::to_underlying(rhs.decoration); + } + + bool hasValue() const { return !isa<UnitAttr>(decorationValue); } + }; + /// Construct a literal StructType with at least one member. static StructType get(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {}, - ArrayRef<MemberDecorationInfo> memberDecorations = {}); + ArrayRef<MemberDecorationInfo> memberDecorations = {}, + ArrayRef<StructDecorationInfo> structDecorations = {}); /// Construct an identified StructType. This creates a StructType whose body /// (member types, offset info, and decorations) is not set yet. A call to @@ -367,6 +391,9 @@ public: bool hasOffset() const; + /// Returns true if the struct has a specified decoration. + bool hasDecoration(spirv::Decoration decoration) const; + uint64_t getMemberOffset(unsigned) const; // Returns in `memberDecorations` the Decorations (apart from Offset) @@ -380,12 +407,18 @@ public: unsigned i, SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const; + // Returns in `structDecorations` the Decorations associated with the + // StructType. + void getStructDecorations(SmallVectorImpl<StructType::StructDecorationInfo> + &structDecorations) const; + /// Sets the contents of an incomplete identified StructType. This method must /// be called only for identified StructTypes and it must be called only once /// per instance. Otherwise, failure() is returned. LogicalResult trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {}, - ArrayRef<MemberDecorationInfo> memberDecorations = {}); + ArrayRef<MemberDecorationInfo> memberDecorations = {}, + ArrayRef<StructDecorationInfo> structDecorations = {}); void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional<StorageClass> storage = std::nullopt); @@ -396,6 +429,9 @@ public: llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo); +llvm::hash_code +hash_value(const StructType::StructDecorationInfo &structDecorationInfo); + // SPIR-V KHR cooperative matrix type class CooperativeMatrixType : public Type::TypeBase<CooperativeMatrixType, CompositeType, diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h index 3d22ec9..03ae54a 100644 --- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h +++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h @@ -39,6 +39,10 @@ struct SPIRVConversionOptions { /// The number of bits to store a boolean value. unsigned boolNumBits{8}; + /// Whether to emulate unsupported floats with integer types of same bit + /// width. + bool emulateUnsupportedFloatTypes{true}; + /// How sub-byte values are storaged in memory. SPIRVSubByteTypeStorage subByteTypeStorage{SPIRVSubByteTypeStorage::Packed}; diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td index a534381b..2513e10 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td @@ -380,7 +380,7 @@ def SparseVectorization : Pass<"sparse-vectorization", "ModuleOp"> { After: %3 = memref.load %2[] : memref<f32> - %4 = vector.insertelement %3, %cst[%c0 : index] : vector<32xf32> + %4 = vector.insert %3, %cst [0] : f32 into vector<32xf32> %5 = scf.for %arg3 = %c0 to %c1024 step %c32 iter_args(%arg4 = %4) -> (vector<32xf32>) { %8 = vector.load %0[%arg3] : memref<?xf32>, vector<32xf32> %9 = vector.load %1[%arg3] : memref<1024xf32>, vector<32xf32> diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 0a5c1e5..dc55704 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -646,55 +646,6 @@ def Vector_DeinterleaveOp : }]; } -def Vector_ExtractElementOp : - Vector_Op<"extractelement", [Pure, - DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>, - TypesMatchWith<"result type matches element type of vector operand", - "vector", "result", - "::llvm::cast<VectorType>($_self).getElementType()">]>, - Arguments<(ins AnyVectorOfAnyRank:$vector, - Optional<AnySignlessIntegerOrIndex>:$position)>, - Results<(outs AnyType:$result)> { - let summary = "extractelement operation"; - let description = [{ - Note: This operation is deprecated. Please use vector.extract insert. - - Takes a 0-D or 1-D vector and a optional dynamic index position and - extracts the scalar at that position. - - Note that this instruction resembles vector.extract, but is restricted to - 0-D and 1-D vectors. - If the vector is 0-D, the position must be std::nullopt. - - - It is meant to be closer to LLVM's version: - https://llvm.org/docs/LangRef.html#extractelement-instruction - - Example: - - ```mlir - %c = arith.constant 15 : i32 - %1 = vector.extractelement %0[%c : i32]: vector<16xf32> - %2 = vector.extractelement %z[]: vector<f32> - ``` - }]; - let assemblyFormat = [{ - $vector `[` ($position^ `:` type($position))? `]` attr-dict `:` type($vector) - }]; - - let builders = [ - // 0-D builder. - OpBuilder<(ins "Value":$source)>, - ]; - let extraClassDeclaration = [{ - VectorType getSourceVectorType() { - return ::llvm::cast<VectorType>(getVector().getType()); - } - }]; - let hasVerifier = 1; - let hasFolder = 1; -} - def Vector_ExtractOp : Vector_Op<"extract", [Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>, @@ -890,57 +841,6 @@ def Vector_FromElementsOp : Vector_Op<"from_elements", [ let hasCanonicalizer = 1; } -def Vector_InsertElementOp : - Vector_Op<"insertelement", [Pure, - DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>, - TypesMatchWith<"source operand type matches element type of result", - "result", "source", - "::llvm::cast<VectorType>($_self).getElementType()">, - AllTypesMatch<["dest", "result"]>]>, - Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest, - Optional<AnySignlessIntegerOrIndex>:$position)>, - Results<(outs AnyVectorOfAnyRank:$result)> { - let summary = "insertelement operation"; - let description = [{ - Note: This operation is deprecated. Please use vector.insert instead. - - Takes a scalar source, a 0-D or 1-D destination vector and a dynamic index - position and inserts the source into the destination at the proper position. - - Note that this instruction resembles vector.insert, but is restricted to 0-D - and 1-D vectors. - - It is meant to be closer to LLVM's version: - https://llvm.org/docs/LangRef.html#insertelement-instruction - - Example: - - ```mlir - %c = arith.constant 15 : i32 - %f = arith.constant 0.0f : f32 - %1 = vector.insertelement %f, %0[%c : i32]: vector<16xf32> - %2 = vector.insertelement %f, %z[]: vector<f32> - ``` - }]; - let assemblyFormat = [{ - $source `,` $dest `[` ($position^ `:` type($position))? `]` attr-dict `:` - type($result) - }]; - - let builders = [ - // 0-D builder. - OpBuilder<(ins "Value":$source, "Value":$dest)>, - ]; - let extraClassDeclaration = [{ - Type getSourceType() { return getSource().getType(); } - VectorType getDestVectorType() { - return ::llvm::cast<VectorType>(getDest().getType()); - } - }]; - let hasVerifier = 1; - let hasFolder = 1; -} - def Vector_InsertOp : Vector_Op<"insert", [Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>, @@ -2695,6 +2595,7 @@ def Vector_MaskOp : Vector_Op<"mask", [ def Vector_TransposeOp : Vector_Op<"transpose", [Pure, + DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>, DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>, PredOpTrait<"operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>]> { @@ -2879,6 +2780,10 @@ def Vector_SplatOp : Vector_Op<"splat", [ let assemblyFormat = "$input attr-dict `:` type($aggregate)"; let hasFolder = 1; + + // vector.splat is deprecated, and vector.broadcast should be used instead. + // Canonicalize vector.splat to vector.broadcast. + let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// @@ -2976,7 +2881,10 @@ def Vector_ScanOp : // VectorStepOp //===----------------------------------------------------------------------===// -def Vector_StepOp : Vector_Op<"step", [Pure]> { +def Vector_StepOp : Vector_Op<"step", [ + Pure, + DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]> + ]> { let summary = "A linear sequence of values from 0 to N"; let description = [{ A `step` operation produces an index vector, i.e. a 1-D vector of values of diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td index 73f6877..38c217f 100644 --- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td +++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td @@ -397,8 +397,8 @@ def DotOp : AVX_LowOp<"dot", [Pure, ```mlir %0 = x86vector.avx.intr.dot %a, %b : vector<8xf32> - %1 = vector.extractelement %0[%i0 : i32]: vector<8xf32> - %2 = vector.extractelement %0[%i4 : i32]: vector<8xf32> + %1 = vector.extract %0[%i0] : f32 from vector<8xf32> + %2 = vector.extract %0[%i4] : f32 from vector<8xf32> %d = arith.addf %1, %2 : f32 ``` }]; diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index 91d6b2a..75b16a87 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -628,35 +628,71 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> { As compared to prefetch_nd, which works on non-scattered TensorDesc, it works on scattered TensorDesc instead. - Example: + Example 1: ```mlir xegpu.prefetch %tdesc {l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>} : !xegpu.tensor_desc<16xf16> ``` + + Example 2: + A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc. + It combines "create scattered TensorTdesc" and "prefetch with scattered TensorTdesc". + The source operand could be a raw pointer (uint64_t). + Please refer to create_tdesc for the restriction of memref. + ```mlir + %a = memref.alloc() : memref<1024xf32> + %0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex> + xegpu.prefetch %a[%0] {l1_hint = #xegpu.cache_hint<cached>, + l2_hint = #xegpu.cache_hint<cached>, + l3_hint = #xegpu.cache_hint<cached>} + : memref<1024xf32>, vector<4xindex> + ``` }]; - let arguments = (ins XeGPU_TensorDesc: $TensorDesc, + let arguments = (ins XeGPU_GatherScatterSourceType: $source, + Optional<XeGPU_OffsetType>: $offsets, OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint, OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint, OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint); let extraClassDeclaration = extraBaseClassDeclaration # [{ + Type getSourceType() { + return getSource().getType(); + } + + TypedValue<xegpu::TensorDescType> getTensorDesc() { + if (auto tdescType = getTensorDescType()) { + return llvm::cast<TypedValue<xegpu::TensorDescType>>(getSource()); + } + return TypedValue<xegpu::TensorDescType>(); + } + xegpu::TensorDescType getTensorDescType() { - return getTensorDesc().getType(); + return dyn_cast<xegpu::TensorDescType>(getSourceType()); } }]; - let assemblyFormat = "$TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc))"; + let assemblyFormat = [{ + $source + (`[` $offsets^ `]`)? + prop-dict + attr-dict `:` type(operands) + }]; + + let builders = [ + OpBuilder<(ins "Value": $source, + "xegpu::CachePolicyAttr": $l1_hint, + "xegpu::CachePolicyAttr": $l2_hint, + "xegpu::CachePolicyAttr": $l3_hint)> + ]; let hasVerifier = 1; } -def XeGPU_LoadGatherOp : XeGPU_Op<"load", [ - AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemRead]> - ]> { +def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> { let summary = "load a set of scattered data points from memory."; let description = [{ It (aka. load) load data per each work-item. The output @@ -687,6 +723,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [ : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>, vector<16xi1> -> vector<16x8xf32> ``` + Example 3 (SIMT mode): ```mlir %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, @@ -695,19 +732,48 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [ : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>> vector<16xi1> -> vector<8xf32> ``` + + Example 4: + A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc. + It combines "create scattered TensorTdesc" and "load with scattered TensorTdesc". + The source operand could be a raw pointer (uint64_t). Please refer to create_tdesc + for the restriction of memref. + ```mlir + %a = memref.alloc() : memref<1024xf32> + %offsets = vector.step : vector<16xindex> + %mask = vector.constant_mask [16]: vector<16xi1> + %val = xegpu.load %a[%offsets], %mask {l1_hint = #xegpu.cache_hint<cached>, + l2_hint = #xegpu.cache_hint<cached>, + l3_hint = #xegpu.cache_hint<cached>} + : memref<1024xf32>, vector<16xi1>, vector<16xindex> -> vector<16xf32> + ``` }]; - let arguments = (ins XeGPU_TensorDesc: $TensorDesc, + let arguments = (ins XeGPU_GatherScatterSourceType: $source, + Optional<XeGPU_OffsetType>: $offsets, XeGPU_MaskType: $mask, + OptionalAttr<I64Attr>: $chunk_size, OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint, OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint, OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint); let results = (outs XeGPU_ValueType: $value); let extraClassDeclaration = extraBaseClassDeclaration # [{ + + Type getSourceType() { + return getSource().getType(); + } + + TypedValue<xegpu::TensorDescType> getTensorDesc() { + if (auto tdescType = getTensorDescType()) { + return llvm::cast<TypedValue<xegpu::TensorDescType>>(getSource()); + } + return TypedValue<xegpu::TensorDescType>(); + } + xegpu::TensorDescType getTensorDescType() { - return getTensorDesc().getType(); + return dyn_cast<xegpu::TensorDescType>(getSourceType()); } mlir::Type getElementType() { @@ -725,15 +791,24 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [ }]; - let assemblyFormat = [{$TensorDesc `,` $mask prop-dict attr-dict - `:` qualified(type($TensorDesc)) `,` type($mask) `->` type($value)}]; + let assemblyFormat = [{ + $source + (`[` $offsets^ `]`)? `,` + $mask prop-dict + attr-dict `:` type(operands) `->` type($value) + }]; + + let builders = [ + OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask, + "xegpu::CachePolicyAttr": $l1_hint, + "xegpu::CachePolicyAttr": $l2_hint, + "xegpu::CachePolicyAttr": $l3_hint)> + ]; let hasVerifier = 1; } -def XeGPU_StoreScatterOp : XeGPU_Op<"store", [ - AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemWrite]> - ]> { +def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> { let summary = "store data to scattered memory locations."; let description = [{ It (aka. store) stores data to scattered memory locations. The value is typically a 1D vector. But when the chunk size of the TensorDesc is larger than 1, it will be @@ -768,19 +843,49 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [ l3_hint = #xegpu.cache_hint<write_through>}> : vector<8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>> vector<16xi1> ``` + + Example 4: + A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc. + It combines "create scattered TensorTdesc" and "store with scattered TensorTdesc". + The dest operand could be a raw pointer (uint64_t). + Please refer to create_tdesc for the restriction of memref. + ```mlir + %a = memref.alloc() : memref<1024xf32> + %val = arith.constant dense<0.0> : vector<16xf32> + %offsets = vector.step : vector<16xindex> + %mask = vector.constant_mask [16]: vector<16xi1> + xegpu.store %val, %a[%offsets], %mask {l1_hint = #xegpu.cache_hint<cached>, + l2_hint = #xegpu.cache_hint<cached>, + l3_hint = #xegpu.cache_hint<cached>} + : memref<1024xf32>, vector<16xi1>, vector<16xindex> -> vector<16xf32> + ``` + }]; let arguments = (ins XeGPU_ValueType: $value, - XeGPU_TensorDesc: $TensorDesc, + XeGPU_GatherScatterSourceType: $dest, + Optional<XeGPU_OffsetType>: $offsets, XeGPU_MaskType: $mask, + OptionalAttr<I64Attr>: $chunk_size, OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint, OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint, OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint); let extraClassDeclaration = extraBaseClassDeclaration # [{ + Type getDestType() { + return getDest().getType(); + } + + TypedValue<xegpu::TensorDescType> getTensorDesc() { + if (auto tdescType = getTensorDescType()) { + return llvm::cast<TypedValue<xegpu::TensorDescType>>(getDest()); + } + return TypedValue<xegpu::TensorDescType>(); + } + xegpu::TensorDescType getTensorDescType() { - return getTensorDesc().getType(); + return dyn_cast<xegpu::TensorDescType>(getDestType()); } VectorType getValueType() { @@ -792,8 +897,21 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [ } }]; - let assemblyFormat = [{$value `,` $TensorDesc `,` $mask prop-dict attr-dict - `:` type($value) `,` qualified(type($TensorDesc)) `,` type($mask)}]; + let assemblyFormat = [{ + $value `,` + $dest + (`[` $offsets^ `]`)? `,` + $mask + prop-dict + attr-dict `:` type(operands) + }]; + + let builders = [ + OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask, + "xegpu::CachePolicyAttr": $l1_hint, + "xegpu::CachePolicyAttr": $l2_hint, + "xegpu::CachePolicyAttr": $l3_hint)> + ]; let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td index 20916ae..b268cab 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td @@ -189,6 +189,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc", let genVerifyDecl = 1; } +def XeGPU_GatherScatterSourceType : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64]>; def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> { let summary = "!xegpu.nbarrier a custom XeGPU type representing a barrier."; diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h index 4ed0423..7ff718a 100644 --- a/mlir/include/mlir/IR/Diagnostics.h +++ b/mlir/include/mlir/IR/Diagnostics.h @@ -639,6 +639,10 @@ public: /// verified correctly, failure otherwise. LogicalResult verify(); + /// Register this handler with the given context. This is intended for use + /// with the splitAndProcessBuffer function. + void registerInContext(MLIRContext *ctx); + private: /// Process a single diagnostic. void process(Diagnostic &diag); diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index edc8ab4..4f89f8b 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -1125,6 +1125,26 @@ inline raw_ostream &operator<<(raw_ostream &os, return os; } +/// A wrapper class that allows for printing an operation with a custom +/// AsmState, useful to act as a "stream modifier" to customize printing an +/// operation with a stream using the operator<< overload, e.g.: +/// llvm::dbgs() << OpWithState(op, OpPrintingFlags().skipRegions()); +class OpWithState { +public: + OpWithState(Operation *op, AsmState &state) : op(op), theState(state) {} + +private: + Operation *op; + AsmState &theState; + friend raw_ostream &operator<<(raw_ostream &os, const OpWithState &op); +}; + +inline raw_ostream &operator<<(raw_ostream &os, + const OpWithState &opWithState) { + opWithState.op->print(os, const_cast<OpWithState &>(opWithState).theState); + return os; +} + } // namespace mlir namespace llvm { diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h index 2162a74..8959dab 100644 --- a/mlir/include/mlir/IR/StorageUniquerSupport.h +++ b/mlir/include/mlir/IR/StorageUniquerSupport.h @@ -200,7 +200,7 @@ public: // If the construction invariants fail then we return a null attribute. if (failed(ConcreteT::verifyInvariants(emitErrorFn, args...))) return ConcreteT(); - return UniquerT::template get<ConcreteT>(ctx, args...); + return UniquerT::template get<ConcreteT>(ctx, std::forward<Args>(args)...); } /// Get an instance of the concrete type from a void pointer. diff --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td index a8b04d0..bbfa308 100644 --- a/mlir/include/mlir/IR/SymbolInterfaces.td +++ b/mlir/include/mlir/IR/SymbolInterfaces.td @@ -55,19 +55,19 @@ def Symbol : OpInterface<"SymbolOpInterface"> { InterfaceMethod<"Returns true if this symbol has nested visibility.", "bool", "isNested", (ins), [{}], /*defaultImplementation=*/[{ - return getVisibility() == mlir::SymbolTable::Visibility::Nested; + return $_op.getVisibility() == mlir::SymbolTable::Visibility::Nested; }] >, InterfaceMethod<"Returns true if this symbol has private visibility.", "bool", "isPrivate", (ins), [{}], /*defaultImplementation=*/[{ - return getVisibility() == mlir::SymbolTable::Visibility::Private; + return $_op.getVisibility() == mlir::SymbolTable::Visibility::Private; }] >, InterfaceMethod<"Returns true if this symbol has public visibility.", "bool", "isPublic", (ins), [{}], /*defaultImplementation=*/[{ - return getVisibility() == mlir::SymbolTable::Visibility::Public; + return $_op.getVisibility() == mlir::SymbolTable::Visibility::Public; }] >, InterfaceMethod<"Sets the visibility of this symbol.", @@ -79,19 +79,19 @@ def Symbol : OpInterface<"SymbolOpInterface"> { InterfaceMethod<"Sets the visibility of this symbol to be nested.", "void", "setNested", (ins), [{}], /*defaultImplementation=*/[{ - setVisibility(mlir::SymbolTable::Visibility::Nested); + $_op.setVisibility(mlir::SymbolTable::Visibility::Nested); }] >, InterfaceMethod<"Sets the visibility of this symbol to be private.", "void", "setPrivate", (ins), [{}], /*defaultImplementation=*/[{ - setVisibility(mlir::SymbolTable::Visibility::Private); + $_op.setVisibility(mlir::SymbolTable::Visibility::Private); }] >, InterfaceMethod<"Sets the visibility of this symbol to be public.", "void", "setPublic", (ins), [{}], /*defaultImplementation=*/[{ - setVisibility(mlir::SymbolTable::Visibility::Public); + $_op.setVisibility(mlir::SymbolTable::Visibility::Public); }] >, InterfaceMethod<[{ @@ -144,7 +144,7 @@ def Symbol : OpInterface<"SymbolOpInterface"> { // By default, base this on the visibility alone. A symbol can be // discarded as long as it is not public. Only public symbols may be // visible from outside of the IR. - return getVisibility() != ::mlir::SymbolTable::Visibility::Public; + return $_op.getVisibility() != ::mlir::SymbolTable::Visibility::Public; }] >, InterfaceMethod<[{ diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h index 856170e..7628171 100644 --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -14,200 +14,15 @@ #ifndef MLIR_INITALLDIALECTS_H_ #define MLIR_INITALLDIALECTS_H_ -#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" -#include "mlir/Dialect/AMX/AMXDialect.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h" -#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h" -#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h" -#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h" -#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" -#include "mlir/Dialect/ArmSME/IR/ArmSME.h" -#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" -#include "mlir/Dialect/Async/IR/Async.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/Complex/IR/Complex.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" -#include "mlir/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.h" -#include "mlir/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/DLTI/DLTI.h" -#include "mlir/Dialect/EmitC/IR/EmitC.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/Dialect/GPU/IR/ValueBoundsOpInterfaceImpl.h" -#include "mlir/Dialect/GPU/Transforms/BufferDeallocationOpInterfaceImpl.h" -#include "mlir/Dialect/IRDL/IR/IRDL.h" -#include "mlir/Dialect/Index/IR/IndexDialect.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/LLVMIR/NVVMDialect.h" -#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" -#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" -#include "mlir/Dialect/LLVMIR/XeVMDialect.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h" -#include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h" -#include "mlir/Dialect/MLProgram/IR/MLProgram.h" -#include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/MPI/IR/MPI.h" -#include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h" -#include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h" -#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h" -#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h" -#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h" -#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" -#include "mlir/Dialect/OpenACC/OpenACC.h" -#include "mlir/Dialect/OpenMP/OpenMPDialect.h" -#include "mlir/Dialect/PDL/IR/PDL.h" -#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" -#include "mlir/Dialect/Ptr/IR/PtrDialect.h" -#include "mlir/Dialect/Quant/IR/Quant.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h" -#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" -#include "mlir/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.h" -#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/SMT/IR/SMTDialect.h" -#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" -#include "mlir/Dialect/Shape/IR/Shape.h" -#include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/Shard/IR/ShardDialect.h" -#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" -#include "mlir/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h" -#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h" -#include "mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h" -#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h" -#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/Tensor/Transforms/RuntimeOpVerification.h" -#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h" -#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h" -#include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/Dialect/Transform/IR/TransformDialect.h" -#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h" -#include "mlir/Dialect/UB/IR/UBOps.h" -#include "mlir/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h" -#include "mlir/Dialect/X86Vector/X86VectorDialect.h" -#include "mlir/Dialect/XeGPU/IR/XeGPU.h" -#include "mlir/IR/Dialect.h" -#include "mlir/Interfaces/CastInterfaces.h" -#include "mlir/Target/LLVM/NVVM/Target.h" -#include "mlir/Target/LLVM/ROCDL/Target.h" -#include "mlir/Target/SPIRV/Target.h" - namespace mlir { +class DialectRegistry; +class MLIRContext; /// Add all the MLIR dialects to the provided registry. -inline void registerAllDialects(DialectRegistry ®istry) { - // clang-format off - registry.insert<acc::OpenACCDialect, - affine::AffineDialect, - amdgpu::AMDGPUDialect, - amx::AMXDialect, - arith::ArithDialect, - arm_neon::ArmNeonDialect, - arm_sme::ArmSMEDialect, - arm_sve::ArmSVEDialect, - async::AsyncDialect, - bufferization::BufferizationDialect, - cf::ControlFlowDialect, - complex::ComplexDialect, - DLTIDialect, - emitc::EmitCDialect, - func::FuncDialect, - gpu::GPUDialect, - index::IndexDialect, - irdl::IRDLDialect, - linalg::LinalgDialect, - LLVM::LLVMDialect, - math::MathDialect, - memref::MemRefDialect, - shard::ShardDialect, - ml_program::MLProgramDialect, - mpi::MPIDialect, - nvgpu::NVGPUDialect, - NVVM::NVVMDialect, - omp::OpenMPDialect, - pdl::PDLDialect, - pdl_interp::PDLInterpDialect, - ptr::PtrDialect, - quant::QuantDialect, - ROCDL::ROCDLDialect, - scf::SCFDialect, - shape::ShapeDialect, - smt::SMTDialect, - sparse_tensor::SparseTensorDialect, - spirv::SPIRVDialect, - tensor::TensorDialect, - tosa::TosaDialect, - transform::TransformDialect, - ub::UBDialect, - vector::VectorDialect, - x86vector::X86VectorDialect, - xegpu::XeGPUDialect, - xevm::XeVMDialect>(); - // clang-format on - - // Register all external models. - affine::registerValueBoundsOpInterfaceExternalModels(registry); - arith::registerBufferDeallocationOpInterfaceExternalModels(registry); - arith::registerBufferizableOpInterfaceExternalModels(registry); - arith::registerBufferViewFlowOpInterfaceExternalModels(registry); - arith::registerShardingInterfaceExternalModels(registry); - arith::registerValueBoundsOpInterfaceExternalModels(registry); - bufferization::func_ext::registerBufferizableOpInterfaceExternalModels( - registry); - builtin::registerCastOpInterfaceExternalModels(registry); - cf::registerBufferizableOpInterfaceExternalModels(registry); - cf::registerBufferDeallocationOpInterfaceExternalModels(registry); - gpu::registerBufferDeallocationOpInterfaceExternalModels(registry); - gpu::registerValueBoundsOpInterfaceExternalModels(registry); - LLVM::registerInlinerInterface(registry); - NVVM::registerInlinerInterface(registry); - linalg::registerAllDialectInterfaceImplementations(registry); - linalg::registerRuntimeVerifiableOpInterfaceExternalModels(registry); - memref::registerAllocationOpInterfaceExternalModels(registry); - memref::registerBufferViewFlowOpInterfaceExternalModels(registry); - memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry); - memref::registerValueBoundsOpInterfaceExternalModels(registry); - memref::registerMemorySlotExternalModels(registry); - ml_program::registerBufferizableOpInterfaceExternalModels(registry); - scf::registerBufferDeallocationOpInterfaceExternalModels(registry); - scf::registerBufferizableOpInterfaceExternalModels(registry); - scf::registerValueBoundsOpInterfaceExternalModels(registry); - shape::registerBufferizableOpInterfaceExternalModels(registry); - sparse_tensor::registerBufferizableOpInterfaceExternalModels(registry); - tensor::registerBufferizableOpInterfaceExternalModels(registry); - tensor::registerFindPayloadReplacementOpInterfaceExternalModels(registry); - tensor::registerInferTypeOpInterfaceExternalModels(registry); - tensor::registerRuntimeVerifiableOpInterfaceExternalModels(registry); - tensor::registerSubsetOpInterfaceExternalModels(registry); - tensor::registerTilingInterfaceExternalModels(registry); - tensor::registerValueBoundsOpInterfaceExternalModels(registry); - tosa::registerShardingInterfaceExternalModels(registry); - vector::registerBufferizableOpInterfaceExternalModels(registry); - vector::registerSubsetOpInterfaceExternalModels(registry); - vector::registerValueBoundsOpInterfaceExternalModels(registry); - NVVM::registerNVVMTargetInterfaceExternalModels(registry); - ROCDL::registerROCDLTargetInterfaceExternalModels(registry); - spirv::registerSPIRVTargetInterfaceExternalModels(registry); -} +void registerAllDialects(DialectRegistry ®istry); /// Append all the MLIR dialects to the registry contained in the given context. -inline void registerAllDialects(MLIRContext &context) { - DialectRegistry registry; - registerAllDialects(registry); - context.appendDialectRegistry(registry); -} +void registerAllDialects(MLIRContext &context); } // namespace mlir diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h index d5a9a2c..a7f64d9 100644 --- a/mlir/include/mlir/InitAllExtensions.h +++ b/mlir/include/mlir/InitAllExtensions.h @@ -14,110 +14,15 @@ #ifndef MLIR_INITALLEXTENSIONS_H_ #define MLIR_INITALLEXTENSIONS_H_ -#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h" -#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" -#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" -#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" -#include "mlir/Conversion/FuncToEmitC/FuncToEmitC.h" -#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" -#include "mlir/Conversion/GPUCommon/GPUToLLVM.h" -#include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h" -#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" -#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h" -#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" -#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h" -#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" -#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" -#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" -#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h" -#include "mlir/Conversion/UBToLLVM/UBToLLVM.h" -#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" -#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h" -#include "mlir/Dialect/AMX/Transforms.h" -#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h" -#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h" -#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h" -#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h" -#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h" -#include "mlir/Dialect/Func/Extensions/AllExtensions.h" -#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h" -#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h" -#include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h" -#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h" -#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h" -#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" -#include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h" -#include "mlir/Dialect/Tensor/Extensions/AllExtensions.h" -#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h" -#include "mlir/Dialect/Transform/DebugExtension/DebugExtension.h" -#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h" -#include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h" -#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h" -#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h" -#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" -#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" -#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" -#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" -#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" - -#include <cstdlib> - namespace mlir { +class DialectRegistry; /// This function may be called to register all MLIR dialect extensions with the /// provided registry. /// If you're building a compiler, you generally shouldn't use this: you would /// individually register the specific extensions that are useful for the /// pipelines and transformations you are using. -inline void registerAllExtensions(DialectRegistry ®istry) { - // Register all conversions to LLVM extensions. - registerConvertArithToEmitCInterface(registry); - arith::registerConvertArithToLLVMInterface(registry); - registerConvertComplexToLLVMInterface(registry); - cf::registerConvertControlFlowToLLVMInterface(registry); - func::registerAllExtensions(registry); - tensor::registerAllExtensions(registry); - registerConvertFuncToEmitCInterface(registry); - registerConvertFuncToLLVMInterface(registry); - index::registerConvertIndexToLLVMInterface(registry); - registerConvertMathToLLVMInterface(registry); - mpi::registerConvertMPIToLLVMInterface(registry); - registerConvertMemRefToEmitCInterface(registry); - registerConvertMemRefToLLVMInterface(registry); - registerConvertNVVMToLLVMInterface(registry); - registerConvertOpenMPToLLVMInterface(registry); - registerConvertSCFToEmitCInterface(registry); - ub::registerConvertUBToLLVMInterface(registry); - registerConvertAMXToLLVMInterface(registry); - gpu::registerConvertGpuToLLVMInterface(registry); - NVVM::registerConvertGpuToNVVMInterface(registry); - vector::registerConvertVectorToLLVMInterface(registry); - registerConvertXeVMToLLVMInterface(registry); - - // Register all transform dialect extensions. - affine::registerTransformDialectExtension(registry); - bufferization::registerTransformDialectExtension(registry); - dlti::registerTransformDialectExtension(registry); - func::registerTransformDialectExtension(registry); - gpu::registerTransformDialectExtension(registry); - linalg::registerTransformDialectExtension(registry); - memref::registerTransformDialectExtension(registry); - nvgpu::registerTransformDialectExtension(registry); - scf::registerTransformDialectExtension(registry); - sparse_tensor::registerTransformDialectExtension(registry); - tensor::registerTransformDialectExtension(registry); - transform::registerDebugExtension(registry); - transform::registerIRDLExtension(registry); - transform::registerLoopExtension(registry); - transform::registerPDLExtension(registry); - transform::registerTuneExtension(registry); - vector::registerTransformDialectExtension(registry); - arm_neon::registerTransformDialectExtension(registry); - arm_sve::registerTransformDialectExtension(registry); - - // Translation extensions need to be registered by calling - // `registerAllToLLVMIRTranslations` (see All.h). -} +void registerAllExtensions(DialectRegistry ®istry); } // namespace mlir diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h index 002ff61..4554290 100644 --- a/mlir/include/mlir/InitAllPasses.h +++ b/mlir/include/mlir/InitAllPasses.h @@ -1,4 +1,4 @@ -//===- LinkAllPassesAndDialects.h - MLIR Registration -----------*- C++ -*-===// +//===- InitAllPasses.h - MLIR Registration ----------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,50 +6,14 @@ // //===----------------------------------------------------------------------===// // -// This file defines a helper to trigger the registration of all dialects and -// passes to the system. +// This file defines a helper to trigger the registration of all passes to the +// system. // //===----------------------------------------------------------------------===// #ifndef MLIR_INITALLPASSES_H_ #define MLIR_INITALLPASSES_H_ -#include "mlir/Conversion/Passes.h" -#include "mlir/Dialect/AMDGPU/Transforms/Passes.h" -#include "mlir/Dialect/Affine/Passes.h" -#include "mlir/Dialect/Arith/Transforms/Passes.h" -#include "mlir/Dialect/ArmSME/Transforms/Passes.h" -#include "mlir/Dialect/ArmSVE/Transforms/Passes.h" -#include "mlir/Dialect/Async/Passes.h" -#include "mlir/Dialect/Bufferization/Pipelines/Passes.h" -#include "mlir/Dialect/Bufferization/Transforms/Passes.h" -#include "mlir/Dialect/EmitC/Transforms/Passes.h" -#include "mlir/Dialect/Func/Transforms/Passes.h" -#include "mlir/Dialect/GPU/Pipelines/Passes.h" -#include "mlir/Dialect/GPU/Transforms/Passes.h" -#include "mlir/Dialect/LLVMIR/Transforms/Passes.h" -#include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Dialect/MLProgram/Transforms/Passes.h" -#include "mlir/Dialect/Math/Transforms/Passes.h" -#include "mlir/Dialect/MemRef/Transforms/Passes.h" -#include "mlir/Dialect/NVGPU/Transforms/Passes.h" -#include "mlir/Dialect/OpenACC/Transforms/Passes.h" -#include "mlir/Dialect/Quant/Transforms/Passes.h" -#include "mlir/Dialect/SCF/Transforms/Passes.h" -#include "mlir/Dialect/SPIRV/Transforms/Passes.h" -#include "mlir/Dialect/Shape/Transforms/Passes.h" -#include "mlir/Dialect/Shard/Transforms/Passes.h" -#include "mlir/Dialect/SparseTensor/Pipelines/Passes.h" -#include "mlir/Dialect/SparseTensor/Transforms/Passes.h" -#include "mlir/Dialect/Tensor/Transforms/Passes.h" -#include "mlir/Dialect/Tosa/Transforms/Passes.h" -#include "mlir/Dialect/Transform/Transforms/Passes.h" -#include "mlir/Dialect/Vector/Transforms/Passes.h" -#include "mlir/Dialect/XeGPU/Transforms/Passes.h" -#include "mlir/Transforms/Passes.h" - -#include <cstdlib> - namespace mlir { // This function may be called to register the MLIR passes with the @@ -59,49 +23,7 @@ namespace mlir { // registry, since it would already be calling the creation routine of the // individual passes. // The global registry is interesting to interact with the command-line tools. -inline void registerAllPasses() { - // General passes - registerTransformsPasses(); - - // Conversion passes - registerConversionPasses(); - - // Dialect passes - acc::registerOpenACCPasses(); - affine::registerAffinePasses(); - amdgpu::registerAMDGPUPasses(); - registerAsyncPasses(); - arith::registerArithPasses(); - bufferization::registerBufferizationPasses(); - func::registerFuncPasses(); - registerGPUPasses(); - registerLinalgPasses(); - registerNVGPUPasses(); - registerSparseTensorPasses(); - LLVM::registerLLVMPasses(); - math::registerMathPasses(); - memref::registerMemRefPasses(); - shard::registerShardPasses(); - ml_program::registerMLProgramPasses(); - quant::registerQuantPasses(); - registerSCFPasses(); - registerShapePasses(); - spirv::registerSPIRVPasses(); - tensor::registerTensorPasses(); - tosa::registerTosaOptPasses(); - transform::registerTransformPasses(); - vector::registerVectorPasses(); - arm_sme::registerArmSMEPasses(); - arm_sve::registerArmSVEPasses(); - emitc::registerEmitCPasses(); - xegpu::registerXeGPUPasses(); - - // Dialect pipelines - bufferization::registerBufferizationPipelines(); - sparse_tensor::registerSparseTensorPipelines(); - tosa::registerTosaToLinalgPipelines(); - gpu::registerGPUToNVVMPipeline(); -} +void registerAllPasses(); } // namespace mlir diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.td b/mlir/include/mlir/Interfaces/CallInterfaces.td index e3c2aec..19d3afe 100644 --- a/mlir/include/mlir/Interfaces/CallInterfaces.td +++ b/mlir/include/mlir/Interfaces/CallInterfaces.td @@ -18,9 +18,15 @@ include "mlir/IR/OpBase.td" -/// Interface for operations with arguments attributes (both call-like -/// and callable operations). -def ArgumentAttributesMethods { +/// Interface for operations with result and argument attributes. +def ArgAndResultAttrsOpInterface : OpInterface<"ArgAndResultAttrsOpInterface"> { + let description = [{ + An operation that has argument and result attributes. This interface + provides functions to access and modify the argument and result + attributes of the operation. + }]; + let cppNamespace = "::mlir"; + list<InterfaceMethod> methods = [ InterfaceMethod<[{ Get the array of argument attribute dictionaries. The method should @@ -64,7 +70,8 @@ def ArgumentAttributesMethods { // a call-like operation. This represents the destination of the call. /// Interface for call-like operations. -def CallOpInterface : OpInterface<"CallOpInterface"> { +def CallOpInterface : OpInterface<"CallOpInterface", + [ArgAndResultAttrsOpInterface]> { let description = [{ A call-like operation is one that transfers control from one sub-routine to another. These operations may be traditional direct calls `call @foo`, or @@ -123,11 +130,12 @@ def CallOpInterface : OpInterface<"CallOpInterface"> { return ::mlir::call_interface_impl::resolveCallable($_op); }] > - ] # ArgumentAttributesMethods.methods; + ]; } /// Interface for callable operations. -def CallableOpInterface : OpInterface<"CallableOpInterface"> { +def CallableOpInterface : OpInterface<"CallableOpInterface", + [ArgAndResultAttrsOpInterface]> { let description = [{ A callable operation is one who represents a potential sub-routine, and may be a target for a call-like operation (those providing the CallOpInterface @@ -140,11 +148,11 @@ def CallableOpInterface : OpInterface<"CallableOpInterface"> { let methods = [ InterfaceMethod<[{ - Returns the region on the current operation that is callable. This may - return null in the case of an external callable object, e.g. an external - function. - }], - "::mlir::Region *", "getCallableRegion">, + Returns the region on the current operation that is callable. This may + return null in the case of an external callable object, e.g. an external + function. + }], + "::mlir::Region *", "getCallableRegion">, InterfaceMethod<[{ Returns the callable's argument types based exclusively on the type (to allow for this method may be called on function declarations). @@ -155,7 +163,7 @@ def CallableOpInterface : OpInterface<"CallableOpInterface"> { allow for this method may be called on function declarations). }], "::llvm::ArrayRef<::mlir::Type>", "getResultTypes">, - ] # ArgumentAttributesMethods.methods; + ]; } #endif // MLIR_INTERFACES_CALLINTERFACES diff --git a/mlir/include/mlir/Support/ToolUtilities.h b/mlir/include/mlir/Support/ToolUtilities.h index cb6ba29..657f117 100644 --- a/mlir/include/mlir/Support/ToolUtilities.h +++ b/mlir/include/mlir/Support/ToolUtilities.h @@ -21,10 +21,16 @@ namespace llvm { class MemoryBuffer; +class MemoryBufferRef; } // namespace llvm namespace mlir { +// A function that processes a chunk of a buffer and writes the result to an +// output stream. using ChunkBufferHandler = function_ref<LogicalResult( + std::unique_ptr<llvm::MemoryBuffer> chunkBuffer, + const llvm::MemoryBufferRef &sourceBuffer, raw_ostream &os)>; +using NoSourceChunkBufferHandler = function_ref<LogicalResult( std::unique_ptr<llvm::MemoryBuffer> chunkBuffer, raw_ostream &os)>; extern inline const char *const kDefaultSplitMarker = "// -----"; @@ -45,6 +51,15 @@ splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer, ChunkBufferHandler processChunkBuffer, raw_ostream &os, llvm::StringRef inputSplitMarker = kDefaultSplitMarker, llvm::StringRef outputSplitMarker = ""); + +/// Same as above, but for case where the original buffer is not used while +/// processing the chunk. +LogicalResult +splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer, + NoSourceChunkBufferHandler processChunkBuffer, + raw_ostream &os, + llvm::StringRef inputSplitMarker = kDefaultSplitMarker, + llvm::StringRef outputSplitMarker = ""); } // namespace mlir #endif // MLIR_SUPPORT_TOOLUTILITIES_H diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h index 60615cf6..e4670cb 100644 --- a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h +++ b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h @@ -28,6 +28,7 @@ #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/SPIRV/SPIRVToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/VCIX/VCIXToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h" namespace mlir { class DialectRegistry; @@ -47,6 +48,7 @@ static inline void registerAllToLLVMIRTranslations(DialectRegistry ®istry) { registerROCDLDialectTranslation(registry); registerSPIRVDialectTranslation(registry); registerVCIXDialectTranslation(registry); + registerXeVMDialectTranslation(registry); // Extension required for translating GPU offloading Ops. gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(registry); @@ -63,6 +65,7 @@ registerAllGPUToLLVMIRTranslations(DialectRegistry ®istry) { registerNVVMDialectTranslation(registry); registerROCDLDialectTranslation(registry); registerSPIRVDialectTranslation(registry); + registerXeVMDialectTranslation(registry); // Extension required for translating GPU offloading Ops. gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(registry); diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h new file mode 100644 index 0000000..b4f6750 --- /dev/null +++ b/mlir/include/mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h @@ -0,0 +1,31 @@ +//===-- XeVMToLLVMIRTranslation.h - XeVM to LLVM IR -------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This provides registration calls for XeVM dialect to LLVM IR translation. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TARGET_LLVMIR_DIALECT_XEVM_XEVMTOLLVMIRTRANSLATION_H +#define MLIR_TARGET_LLVMIR_DIALECT_XEVM_XEVMTOLLVMIRTRANSLATION_H + +namespace mlir { + +class DialectRegistry; +class MLIRContext; + +/// Register the XeVM dialect and the translation from it to the LLVM IR in the +/// given registry; +void registerXeVMDialectTranslation(mlir::DialectRegistry ®istry); + +/// Register the XeVM dialect and the translation from it in the registry +/// associated with the given context. +void registerXeVMDialectTranslation(mlir::MLIRContext &context); + +} // namespace mlir + +#endif // MLIR_TARGET_LLVMIR_DIALECT_XEVM_XEVMTOLLVMIRTRANSLATION_H diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h index 17ef8e4..09d819a 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h @@ -83,6 +83,10 @@ public: /// specification. void convertTargetTriple(); + /// Converts the module level asm of the LLVM module to an MLIR module + /// level asm specification. + void convertModuleLevelAsm(); + /// Stores the mapping between an LLVM value and its MLIR counterpart. void mapValue(llvm::Value *llvm, Value mlir) { mapValue(llvm) = mlir; } @@ -291,10 +295,12 @@ public: SmallVectorImpl<Value> &valuesOut, SmallVectorImpl<NamedAttribute> &attrsOut); - /// Converts the parameter and result attributes in `argsAttr` and `resAttr` - /// and add them to the `callOp`. - void convertParameterAttributes(llvm::CallBase *call, ArrayAttr &argsAttr, - ArrayAttr &resAttr, OpBuilder &builder); + /// Converts the argument and result attributes attached to `call` and adds + /// them to `attrsOp`. For intrinsic calls, filters out attributes + /// corresponding to immediate arguments specified by `immArgPositions`. + void convertArgAndResultAttrs(llvm::CallBase *call, + ArgAndResultAttrsOpInterface attrsOp, + ArrayRef<unsigned> immArgPositions = {}); /// Whether the importer should try to convert all intrinsics to /// llvm.call_intrinsic instead of dialect supported operations. @@ -378,19 +384,12 @@ private: bool &isIncompatibleCall); /// Returns the callee name, or an empty symbol if the call is not direct. FlatSymbolRefAttr convertCalleeName(llvm::CallBase *callInst); - /// Converts the parameter and result attributes attached to `func` and adds + /// Converts the argument and result attributes attached to `func` and adds /// them to the `funcOp`. - void convertParameterAttributes(llvm::Function *func, LLVMFuncOp funcOp, - OpBuilder &builder); - /// Converts the AttributeSet of one parameter in LLVM IR to a corresponding - /// DictionaryAttr for the LLVM dialect. - DictionaryAttr convertParameterAttribute(llvm::AttributeSet llvmParamAttrs, - OpBuilder &builder); - /// Converts the parameter and result attributes attached to `call` and adds - /// them to the `callOp`. Implemented in terms of the the public definition of - /// convertParameterAttributes. - void convertParameterAttributes(llvm::CallBase *call, CallOpInterface callOp, - OpBuilder &builder); + void convertArgAndResultAttrs(llvm::Function *func, LLVMFuncOp funcOp); + /// Converts the argument or result attributes in `llvmAttrSet` to a + /// corresponding MLIR LLVM dialect attribute dictionary. + DictionaryAttr convertArgOrResultAttrSet(llvm::AttributeSet llvmAttrSet); /// Converts the attributes attached to `inst` and adds them to the `op`. LogicalResult convertCallAttributes(llvm::CallInst *inst, CallOp op); /// Converts the attributes attached to `inst` and adds them to the `op`. diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h index f3f73f4..eb7dfa7 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -25,11 +25,13 @@ #include "mlir/Target/LLVMIR/TypeToLLVM.h" #include "llvm/ADT/SetVector.h" -#include "llvm/Frontend/OpenMP/OMPIRBuilder.h" #include "llvm/IR/FPEnv.h" +#include "llvm/IR/Module.h" namespace llvm { class BasicBlock; +class CallBase; +class CanonicalLoopInfo; class Function; class IRBuilderBase; class OpenMPIRBuilder; @@ -306,10 +308,16 @@ public: /*recordInsertions=*/false); } - /// Translates parameter attributes of a call and adds them to the returned - /// AttrBuilder. Returns failure if any of the translations failed. - FailureOr<llvm::AttrBuilder> convertParameterAttrs(mlir::Location loc, - DictionaryAttr paramAttrs); + /// Converts argument and result attributes from `attrsOp` to LLVM IR + /// attributes on the `call` instruction. Returns failure if conversion fails. + /// The `immArgPositions` parameter is only relevant for intrinsics. It + /// specifies the positions of immediate arguments, which do not have + /// associated argument attributes in MLIR and should be skipped during + /// attribute mapping. + LogicalResult + convertArgAndResultAttrs(ArgAndResultAttrsOpInterface attrsOp, + llvm::CallBase *call, + ArrayRef<unsigned> immArgPositions = {}); /// Gets the named metadata in the LLVM IR module being constructed, creating /// it if it does not exist. @@ -389,6 +397,11 @@ private: convertDialectAttributes(Operation *op, ArrayRef<llvm::Instruction *> instructions); + /// Translates parameter attributes of a call and adds them to the returned + /// AttrBuilder. Returns failure if any of the translations failed. + FailureOr<llvm::AttrBuilder> convertParameterAttrs(mlir::Location loc, + DictionaryAttr paramAttrs); + /// Translates parameter attributes of a function and adds them to the /// returned AttrBuilder. Returns failure if any of the translations failed. FailureOr<llvm::AttrBuilder> diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp index 9f4a87a..8b14e71 100644 --- a/mlir/lib/AsmParser/DialectSymbolParser.cpp +++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp @@ -89,6 +89,7 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body, nestedPunctuation.pop_back(); return success(); }; + const char *curBufferEnd = state.lex.getBufferEnd(); do { // Handle code completions, which may appear in the middle of the symbol // body. @@ -98,6 +99,12 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body, break; } + if (curBufferEnd == curPtr) { + if (!nestedPunctuation.empty()) + return emitPunctError(); + return emitError("unexpected nul or EOF in pretty dialect name"); + } + char c = *curPtr++; switch (c) { case '\0': diff --git a/mlir/lib/AsmParser/Lexer.cpp b/mlir/lib/AsmParser/Lexer.cpp index 751bd63..8f53529 100644 --- a/mlir/lib/AsmParser/Lexer.cpp +++ b/mlir/lib/AsmParser/Lexer.cpp @@ -37,6 +37,18 @@ Lexer::Lexer(const llvm::SourceMgr &sourceMgr, MLIRContext *context, AsmParserCodeCompleteContext *codeCompleteContext) : sourceMgr(sourceMgr), context(context), codeCompleteLoc(nullptr) { auto bufferID = sourceMgr.getMainFileID(); + + // Check to see if the main buffer contains the last buffer, and if so the + // last buffer should be used as main file for parsing. + if (sourceMgr.getNumBuffers() > 1) { + unsigned lastFileID = sourceMgr.getNumBuffers(); + const llvm::MemoryBuffer *main = sourceMgr.getMemoryBuffer(bufferID); + const llvm::MemoryBuffer *last = sourceMgr.getMemoryBuffer(lastFileID); + if (main->getBufferStart() <= last->getBufferStart() && + main->getBufferEnd() >= last->getBufferEnd()) { + bufferID = lastFileID; + } + } curBuffer = sourceMgr.getMemoryBuffer(bufferID)->getBuffer(); curPtr = curBuffer.begin(); @@ -71,6 +83,7 @@ Token Lexer::emitError(const char *loc, const Twine &message) { } Token Lexer::lexToken() { + const char *curBufferEnd = curBuffer.end(); while (true) { const char *tokStart = curPtr; @@ -78,6 +91,9 @@ Token Lexer::lexToken() { if (tokStart == codeCompleteLoc) return formToken(Token::code_complete, tokStart); + if (tokStart == curBufferEnd) + return formToken(Token::eof, tokStart); + // Lex the next token. switch (*curPtr++) { default: @@ -102,7 +118,7 @@ Token Lexer::lexToken() { case 0: // This may either be a nul character in the source file or may be the EOF // marker that llvm::MemoryBuffer guarantees will be there. - if (curPtr - 1 == curBuffer.end()) + if (curPtr - 1 == curBufferEnd) return formToken(Token::eof, tokStart); continue; @@ -259,7 +275,11 @@ void Lexer::skipComment() { assert(*curPtr == '/'); ++curPtr; + const char *curBufferEnd = curBuffer.end(); while (true) { + if (curPtr == curBufferEnd) + return; + switch (*curPtr++) { case '\n': case '\r': @@ -267,7 +287,7 @@ void Lexer::skipComment() { return; case 0: // If this is the end of the buffer, end the comment. - if (curPtr - 1 == curBuffer.end()) { + if (curPtr - 1 == curBufferEnd) { --curPtr; return; } @@ -405,6 +425,7 @@ Token Lexer::lexPrefixedIdentifier(const char *tokStart) { Token Lexer::lexString(const char *tokStart) { assert(curPtr[-1] == '"'); + const char *curBufferEnd = curBuffer.end(); while (true) { // Check to see if there is a code completion location within the string. In // these cases we generate a completion location and place the currently @@ -419,7 +440,7 @@ Token Lexer::lexString(const char *tokStart) { case 0: // If this is a random nul character in the middle of a string, just // include it. If it is the end of file, then it is an error. - if (curPtr - 1 != curBuffer.end()) + if (curPtr - 1 != curBufferEnd) continue; [[fallthrough]]; case '\n': diff --git a/mlir/lib/AsmParser/Lexer.h b/mlir/lib/AsmParser/Lexer.h index 4085a9b..670444e 100644 --- a/mlir/lib/AsmParser/Lexer.h +++ b/mlir/lib/AsmParser/Lexer.h @@ -40,6 +40,9 @@ public: /// Returns the start of the buffer. const char *getBufferBegin() { return curBuffer.data(); } + /// Returns the end of the buffer. + const char *getBufferEnd() { return curBuffer.end(); } + /// Return the code completion location of the lexer, or nullptr if there is /// none. const char *getCodeCompleteLoc() const { return codeCompleteLoc; } diff --git a/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt b/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt index 8b9a395..ccda668 100644 --- a/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt +++ b/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt @@ -1,19 +1,16 @@ # Dialect registration. -get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) get_property(translation_libs GLOBAL PROPERTY MLIR_TRANSLATION_LIBS) -get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) -get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) add_mlir_upstream_c_api_library(MLIRCAPIRegisterEverything RegisterEverything.cpp LINK_LIBS PUBLIC - ${dialect_libs} ${translation_libs} - ${conversion_libs} - ${extension_libs} MLIRBuiltinToLLVMIRTranslation MLIRCAPIIR - MLIRLLVMToLLVMIRTranslation MLIRCAPITransforms + MLIRLLVMToLLVMIRTranslation + MLIRRegisterAllDialects + MLIRRegisterAllExtensions + MLIRRegisterAllPasses ) diff --git a/mlir/lib/CMakeLists.txt b/mlir/lib/CMakeLists.txt index d25c84a..191b5ab6 100644 --- a/mlir/lib/CMakeLists.txt +++ b/mlir/lib/CMakeLists.txt @@ -20,3 +20,37 @@ add_subdirectory(Target) add_subdirectory(Tools) add_subdirectory(Transforms) add_subdirectory(ExecutionEngine) + +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) +get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) + +add_mlir_library(MLIRRegisterAllDialects + RegisterAllDialects.cpp + + PARTIAL_SOURCES_INTENDED + + LINK_LIBS PUBLIC + ${dialect_libs} + ) + +add_mlir_library(MLIRRegisterAllPasses + RegisterAllPasses.cpp + + PARTIAL_SOURCES_INTENDED + + LINK_LIBS PUBLIC + ${dialect_libs} # Some passes are part of the dialect libs + ${conversion_libs} + ) + +add_mlir_library(MLIRRegisterAllExtensions + RegisterAllExtensions.cpp + + PARTIAL_SOURCES_INTENDED + + LINK_LIBS PUBLIC + ${dialect_libs} + ${conversion_libs} + ${extension_libs} + ) diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp index d43e681..265293b 100644 --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -99,6 +99,17 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, return builder.getF32FloatAttr(dstVal.convertToFloat()); } +// Get in IntegerAttr from FloatAttr while preserving the bits. +// Useful for converting float constants to integer constants while preserving +// the bits. +static IntegerAttr +getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType, + ConversionPatternRewriter &rewriter) { + APFloat floatVal = floatAttr.getValue(); + APInt intVal = floatVal.bitcastToAPInt(); + return rewriter.getIntegerAttr(dstType, intVal); +} + /// Returns true if the given `type` is a boolean scalar or vector type. static bool isBoolScalarOrVector(Type type) { assert(type && "Not a valid type"); @@ -296,8 +307,18 @@ struct ConstantCompositeOpPattern final SmallVector<Attribute, 8> elements; if (isa<FloatType>(srcElemType)) { for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) { - FloatAttr dstAttr = - convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter); + Attribute dstAttr = nullptr; + // Handle 8-bit float conversion to 8-bit integer. + auto *typeConverter = getTypeConverter<SPIRVTypeConverter>(); + if (typeConverter->getOptions().emulateUnsupportedFloatTypes && + srcElemType.getIntOrFloatBitWidth() == 8 && + isa<IntegerType>(dstElemType)) { + dstAttr = + getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter); + } else { + dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), + rewriter); + } if (!dstAttr) return failure(); elements.push_back(dstAttr); @@ -361,11 +382,19 @@ struct ConstantScalarOpPattern final // Floating-point types. if (isa<FloatType>(srcType)) { auto srcAttr = cast<FloatAttr>(cstAttr); - auto dstAttr = srcAttr; + Attribute dstAttr = srcAttr; // Floating-point types not supported in the target environment are all // converted to float type. - if (srcType != dstType) { + auto *typeConverter = getTypeConverter<SPIRVTypeConverter>(); + if (typeConverter->getOptions().emulateUnsupportedFloatTypes && + srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) && + dstType.getIntOrFloatBitWidth() == 8) { + // If the source is an 8-bit float, convert it to a 8-bit integer. + dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter); + if (!dstAttr) + return failure(); + } else if (srcType != dstType) { dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter); if (!dstAttr) return failure(); @@ -1352,6 +1381,7 @@ struct ConvertArithToSPIRVPass SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; + options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes; SPIRVTypeConverter typeConverter(targetAttr, options); // Use UnrealizedConversionCast as the bridge so that we don't need to pull diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp index 6f0fc29..35ad99c 100644 --- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp +++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp @@ -64,10 +64,46 @@ void mlir::populateComplexToROCDLLibraryCallsConversionPatterns( patterns.getContext(), "__ocml_cabs_f32"); patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>( patterns.getContext(), "__ocml_cabs_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float32Type>>( + patterns.getContext(), "__ocml_carg_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float64Type>>( + patterns.getContext(), "__ocml_carg_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float32Type>>( + patterns.getContext(), "__ocml_conj_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float64Type>>( + patterns.getContext(), "__ocml_conj_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float32Type>>( + patterns.getContext(), "__ocml_ccos_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float64Type>>( + patterns.getContext(), "__ocml_ccos_f64"); patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float32Type>>( patterns.getContext(), "__ocml_cexp_f32"); patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float64Type>>( patterns.getContext(), "__ocml_cexp_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float32Type>>( + patterns.getContext(), "__ocml_clog_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float64Type>>( + patterns.getContext(), "__ocml_clog_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float32Type>>( + patterns.getContext(), "__ocml_cpow_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float64Type>>( + patterns.getContext(), "__ocml_cpow_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float32Type>>( + patterns.getContext(), "__ocml_csin_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float64Type>>( + patterns.getContext(), "__ocml_csin_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float32Type>>( + patterns.getContext(), "__ocml_csqrt_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float64Type>>( + patterns.getContext(), "__ocml_csqrt_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float32Type>>( + patterns.getContext(), "__ocml_ctan_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float64Type>>( + patterns.getContext(), "__ocml_ctan_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float32Type>>( + patterns.getContext(), "__ocml_ctanh_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float64Type>>( + patterns.getContext(), "__ocml_ctanh_f64"); } namespace { @@ -86,7 +122,10 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() { ConversionTarget target(getContext()); target.addLegalDialect<func::FuncDialect>(); - target.addIllegalOp<complex::AbsOp, complex::ExpOp>(); + target.addIllegalOp<complex::AbsOp, complex::AngleOp, complex::ConjOp, + complex::CosOp, complex::ExpOp, complex::LogOp, + complex::PowOp, complex::SinOp, complex::SqrtOp, + complex::TanOp, complex::TanhOp>(); if (failed(applyPartialConversion(op, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp index 03f4bf4..56b6181 100644 --- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp +++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp @@ -43,6 +43,7 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() { SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; + options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes; SPIRVTypeConverter typeConverter(targetAttr, options); // TODO: We should also take care of block argument type conversion. diff --git a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp index 8ed9f65..c0439a4 100644 --- a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp +++ b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp @@ -42,6 +42,7 @@ void ConvertFuncToSPIRVPass::runOnOperation() { SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; + options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes; SPIRVTypeConverter typeConverter(targetAttr, options); RewritePatternSet patterns(context); diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index 75e6563..3545acb 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -385,6 +385,14 @@ LogicalResult GPUModuleConversion::matchAndRewrite( if (auto attr = moduleOp->getAttrOfType<spirv::TargetEnvAttr>( spirv::getTargetEnvAttrName())) spvModule->setAttr(spirv::getTargetEnvAttrName(), attr); + if (ArrayAttr targets = moduleOp.getTargetsAttr()) { + for (Attribute targetAttr : targets) + if (auto spirvTargetEnvAttr = + dyn_cast<spirv::TargetEnvAttr>(targetAttr)) { + spvModule->setAttr(spirv::getTargetEnvAttrName(), spirvTargetEnvAttr); + break; + } + } rewriter.eraseOp(moduleOp); return success(); @@ -507,25 +515,27 @@ LogicalResult GPURotateConversion::matchAndRewrite( getTypeConverter<SPIRVTypeConverter>()->getTargetEnv(); unsigned subgroupSize = targetEnv.getAttr().getResourceLimits().getSubgroupSize(); - IntegerAttr widthAttr; - if (!matchPattern(rotateOp.getWidth(), m_Constant(&widthAttr)) || - widthAttr.getValue().getZExtValue() > subgroupSize) + unsigned width = rotateOp.getWidth(); + if (width > subgroupSize) return rewriter.notifyMatchFailure( - rotateOp, - "rotate width is not a constant or larger than target subgroup size"); + rotateOp, "rotate width is larger than target subgroup size"); Location loc = rotateOp.getLoc(); auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup); + Value offsetVal = + arith::ConstantOp::create(rewriter, loc, adaptor.getOffsetAttr()); + Value widthVal = + arith::ConstantOp::create(rewriter, loc, adaptor.getWidthAttr()); Value rotateResult = spirv::GroupNonUniformRotateKHROp::create( - rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset(), - adaptor.getWidth()); + rewriter, loc, scope, adaptor.getValue(), offsetVal, widthVal); Value validVal; - if (widthAttr.getValue().getZExtValue() == subgroupSize) { + if (width == subgroupSize) { validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), loc, rewriter); } else { + IntegerAttr widthAttr = adaptor.getWidthAttr(); Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr); validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult, - laneId, adaptor.getWidth()); + laneId, widthVal); } rewriter.replaceOp(rotateOp, {rotateResult, validVal}); diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp index a344f88..5eab057 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp @@ -48,9 +48,36 @@ struct GPUToSPIRVPass final : impl::ConvertGPUToSPIRVBase<GPUToSPIRVPass> { void runOnOperation() override; private: + /// Queries the target environment from 'targets' attribute of the given + /// `moduleOp`. + spirv::TargetEnvAttr lookupTargetEnvInTargets(gpu::GPUModuleOp moduleOp); + + /// Queries the target environment from 'targets' attribute of the given + /// `moduleOp` or returns target environment as returned by + /// `spirv::lookupTargetEnvOrDefault` if not provided by 'targets'. + spirv::TargetEnvAttr lookupTargetEnvOrDefault(gpu::GPUModuleOp moduleOp); bool mapMemorySpace; }; +spirv::TargetEnvAttr +GPUToSPIRVPass::lookupTargetEnvInTargets(gpu::GPUModuleOp moduleOp) { + if (ArrayAttr targets = moduleOp.getTargetsAttr()) { + for (Attribute targetAttr : targets) + if (auto spirvTargetEnvAttr = dyn_cast<spirv::TargetEnvAttr>(targetAttr)) + return spirvTargetEnvAttr; + } + + return {}; +} + +spirv::TargetEnvAttr +GPUToSPIRVPass::lookupTargetEnvOrDefault(gpu::GPUModuleOp moduleOp) { + if (spirv::TargetEnvAttr targetEnvAttr = lookupTargetEnvInTargets(moduleOp)) + return targetEnvAttr; + + return spirv::lookupTargetEnvOrDefault(moduleOp); +} + void GPUToSPIRVPass::runOnOperation() { MLIRContext *context = &getContext(); ModuleOp module = getOperation(); @@ -58,9 +85,8 @@ void GPUToSPIRVPass::runOnOperation() { SmallVector<Operation *, 1> gpuModules; OpBuilder builder(context); - auto targetEnvSupportsKernelCapability = [](gpu::GPUModuleOp moduleOp) { - Operation *gpuModule = moduleOp.getOperation(); - auto targetAttr = spirv::lookupTargetEnvOrDefault(gpuModule); + auto targetEnvSupportsKernelCapability = [this](gpu::GPUModuleOp moduleOp) { + auto targetAttr = lookupTargetEnvOrDefault(moduleOp); spirv::TargetEnv targetEnv(targetAttr); return targetEnv.allows(spirv::Capability::Kernel); }; @@ -86,7 +112,7 @@ void GPUToSPIRVPass::runOnOperation() { // TargetEnv attributes. for (Operation *gpuModule : gpuModules) { spirv::TargetEnvAttr targetAttr = - spirv::lookupTargetEnvOrDefault(gpuModule); + lookupTargetEnvOrDefault(cast<gpu::GPUModuleOp>(gpuModule)); // Map MemRef memory space to SPIR-V storage class first if requested. if (mapMemorySpace) { diff --git a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp index 855c582..cde2340 100644 --- a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp +++ b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp @@ -22,7 +22,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" namespace mlir { #define GEN_PASS_DEF_CONVERTMATHTOFUNCS @@ -32,7 +32,6 @@ namespace mlir { using namespace mlir; #define DEBUG_TYPE "math-to-funcs" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") namespace { // Pattern to convert vector operations to scalar operations. @@ -653,10 +652,8 @@ FPowIOpLowering::matchAndRewrite(math::FPowIOp op, /// } static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) { if (!isa<IntegerType>(elementType)) { - LLVM_DEBUG({ - DBGS() << "non-integer element type for CtlzFunc; type was: "; - elementType.print(llvm::dbgs()); - }); + LDBG() << "non-integer element type for CtlzFunc; type was: " + << elementType; llvm_unreachable("non-integer element type"); } int64_t bitWidth = elementType.getIntOrFloatBitWidth(); diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp index 93d8b49..df219f3 100644 --- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp +++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/MathToROCDL/MathToROCDL.h" +#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -21,7 +22,6 @@ #include "../GPUCommon/GPUOpsLowering.h" #include "../GPUCommon/OpToFuncCallLowering.h" -#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" namespace mlir { #define GEN_PASS_DEF_CONVERTMATHTOROCDL @@ -31,7 +31,6 @@ namespace mlir { using namespace mlir; #define DEBUG_TYPE "math-to-rocdl" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") template <typename OpTy> static void populateOpPatterns(const LLVMTypeConverter &converter, diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp index a877ad2..1787e0a 100644 --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -488,7 +488,12 @@ namespace mlir { void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { // Core patterns - patterns.add<CopySignPattern>(typeConverter, patterns.getContext()); + patterns + .add<CopySignPattern, + CheckedElementwiseOpPattern<math::IsInfOp, spirv::IsInfOp>, + CheckedElementwiseOpPattern<math::IsNaNOp, spirv::IsNanOp>, + CheckedElementwiseOpPattern<math::IsFiniteOp, spirv::IsFiniteOp>>( + typeConverter, patterns.getContext()); // GLSL patterns patterns diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index e882845..6bd0e2d 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -19,10 +19,18 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeRange.h" +#include "mlir/IR/Value.h" #include "mlir/Transforms/DialectConversion.h" +#include <cstdint> using namespace mlir; +static bool isMemRefTypeLegalForEmitC(MemRefType memRefType) { + return memRefType.hasStaticShape() && memRefType.getLayout().isIdentity() && + memRefType.getRank() != 0 && + !llvm::is_contained(memRefType.getShape(), 0); +} + namespace { /// Implement the interface to convert MemRef to EmitC. struct MemRefToEmitCDialectInterface : public ConvertToEmitCPatternInterface { @@ -89,6 +97,68 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) { return resultTy; } +struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(memref::AllocOp allocOp, OpAdaptor operands, + ConversionPatternRewriter &rewriter) const override { + Location loc = allocOp.getLoc(); + MemRefType memrefType = allocOp.getType(); + if (!isMemRefTypeLegalForEmitC(memrefType)) { + return rewriter.notifyMatchFailure( + loc, "incompatible memref type for EmitC conversion"); + } + + Type sizeTType = emitc::SizeTType::get(rewriter.getContext()); + Type elementType = memrefType.getElementType(); + IndexType indexType = rewriter.getIndexType(); + emitc::CallOpaqueOp sizeofElementOp = rewriter.create<emitc::CallOpaqueOp>( + loc, sizeTType, rewriter.getStringAttr("sizeof"), ValueRange{}, + ArrayAttr::get(rewriter.getContext(), {TypeAttr::get(elementType)})); + + int64_t numElements = 1; + for (int64_t dimSize : memrefType.getShape()) { + numElements *= dimSize; + } + Value numElementsValue = rewriter.create<emitc::ConstantOp>( + loc, indexType, rewriter.getIndexAttr(numElements)); + + Value totalSizeBytes = rewriter.create<emitc::MulOp>( + loc, sizeTType, sizeofElementOp.getResult(0), numElementsValue); + + emitc::CallOpaqueOp allocCall; + StringAttr allocFunctionName; + Value alignmentValue; + SmallVector<Value, 2> argsVec; + if (allocOp.getAlignment()) { + allocFunctionName = rewriter.getStringAttr(alignedAllocFunctionName); + alignmentValue = rewriter.create<emitc::ConstantOp>( + loc, sizeTType, + rewriter.getIntegerAttr(indexType, + allocOp.getAlignment().value_or(0))); + argsVec.push_back(alignmentValue); + } else { + allocFunctionName = rewriter.getStringAttr(mallocFunctionName); + } + + argsVec.push_back(totalSizeBytes); + ValueRange args(argsVec); + + allocCall = rewriter.create<emitc::CallOpaqueOp>( + loc, + emitc::PointerType::get( + emitc::OpaqueType::get(rewriter.getContext(), "void")), + allocFunctionName, args); + + emitc::PointerType targetPointerType = emitc::PointerType::get(elementType); + emitc::CastOp castOp = rewriter.create<emitc::CastOp>( + loc, targetPointerType, allocCall.getResult(0)); + + rewriter.replaceOp(allocOp, castOp); + return success(); + } +}; + struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> { using OpConversionPattern::OpConversionPattern; @@ -223,9 +293,7 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> { void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { typeConverter.addConversion( [&](MemRefType memRefType) -> std::optional<Type> { - if (!memRefType.hasStaticShape() || - !memRefType.getLayout().isIdentity() || memRefType.getRank() == 0 || - llvm::is_contained(memRefType.getShape(), 0)) { + if (!isMemRefTypeLegalForEmitC(memRefType)) { return {}; } Type convertedElementType = @@ -252,6 +320,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { void mlir::populateMemRefToEmitCConversionPatterns( RewritePatternSet &patterns, const TypeConverter &converter) { - patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad, - ConvertStore>(converter, patterns.getContext()); + patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal, + ConvertLoad, ConvertStore>(converter, patterns.getContext()); } diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp index cf25c09..e78dd76 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp @@ -15,6 +15,7 @@ #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Attributes.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" @@ -28,9 +29,11 @@ using namespace mlir; namespace { struct ConvertMemRefToEmitCPass : public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> { + using Base::Base; void runOnOperation() override { TypeConverter converter; - + ConvertMemRefToEmitCOptions options; + options.lowerToCpp = this->lowerToCpp; // Fallback for other types. converter.addConversion([](Type type) -> std::optional<Type> { if (!emitc::isSupportedEmitCType(type)) @@ -50,6 +53,37 @@ struct ConvertMemRefToEmitCPass if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) return signalPassFailure(); + + mlir::ModuleOp module = getOperation(); + module.walk([&](mlir::emitc::CallOpaqueOp callOp) { + if (callOp.getCallee() != alignedAllocFunctionName && + callOp.getCallee() != mallocFunctionName) { + return mlir::WalkResult::advance(); + } + + for (auto &op : *module.getBody()) { + emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op); + if (!includeOp) { + continue; + } + if (includeOp.getIsStandardInclude() && + ((options.lowerToCpp && + includeOp.getInclude() == cppStandardLibraryHeader) || + (!options.lowerToCpp && + includeOp.getInclude() == cStandardLibraryHeader))) { + return mlir::WalkResult::interrupt(); + } + } + + mlir::OpBuilder builder(module.getBody(), module.getBody()->begin()); + StringAttr includeAttr = + builder.getStringAttr(options.lowerToCpp ? cppStandardLibraryHeader + : cStandardLibraryHeader); + builder.create<mlir::emitc::IncludeOp>( + module.getLoc(), includeAttr, + /*is_standard_include=*/builder.getUnitAttr()); + return mlir::WalkResult::interrupt(); + }); } }; } // namespace diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 6ba5bfe4..dc2035b 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -24,11 +24,12 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/IRMapping.h" #include "mlir/Pass/Pass.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/MathExtras.h" + #include <optional> #define DEBUG_TYPE "memref-to-llvm" -#define DBGS() llvm::dbgs() << "[" DEBUG_TYPE "] " namespace mlir { #define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS @@ -1848,8 +1849,8 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) { return LLVM::AtomicBinOp::xchg; case arith::AtomicRMWKind::maximumf: // TODO: remove this by end of 2025. - LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw maximumf changed " - "from fmax to fmaximum, expect more NaNs"); + LDBG() << "the lowering of memref.atomicrmw maximumf changed " + "from fmax to fmaximum, expect more NaNs"; return LLVM::AtomicBinOp::fmaximum; case arith::AtomicRMWKind::maxnumf: return LLVM::AtomicBinOp::fmax; @@ -1859,8 +1860,8 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) { return LLVM::AtomicBinOp::umax; case arith::AtomicRMWKind::minimumf: // TODO: remove this by end of 2025. - LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw minimum changed " - "from fmin to fminimum, expect more NaNs"); + LDBG() << "the lowering of memref.atomicrmw minimum changed " + "from fmin to fminimum, expect more NaNs"; return LLVM::AtomicBinOp::fminimum; case arith::AtomicRMWKind::minnumf: return LLVM::AtomicBinOp::fmin; diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index 5d13353..2549a9c 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -26,13 +26,12 @@ #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include <optional> #define DEBUG_TYPE "nvgpu-to-nvvm" -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") -#define DBGSE() (llvm::dbgs()) namespace mlir { #define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS @@ -1105,13 +1104,13 @@ struct NVGPUGenerateWarpgroupDescriptorLowering // // [0,14) start_address dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit); - LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: " - << "leading_off:" << leadDimVal << "\t" - << "stride_off :" << strideDimVal << "\t" - << "base_offset:" << offsetVal << "\t" - << "layout_type:" << swizzle << " (" - << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind) - << ")\n start_addr : " << baseAddr << "\n"); + LDBG() << "Generating warpgroup.descriptor: " + << "leading_off:" << leadDimVal << "\t" + << "stride_off :" << strideDimVal << "\t" + << "base_offset:" << offsetVal << "\t" + << "layout_type:" << swizzle << " (" + << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind) + << ")\n start_addr : " << baseAddr; rewriter.replaceOp(op, dsc); return success(); @@ -1281,8 +1280,8 @@ struct NVGPUWarpgroupMmaOpLowering } else { llvm_unreachable("msg: not supported K shape"); } - LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM - << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n"); + LDBG() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM + << ", n = " << wgmmaN << ", k = " << wgmmaK << "]"; } /// Generates WGMMATypesAttr from MLIR Type @@ -1366,9 +1365,9 @@ struct NVGPUWarpgroupMmaOpLowering int tileShapeA = matrixTypeA.getDimSize(1); int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte; incrementVal = incrementVal >> exclude4LSB; - LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k - << "] [wgmma descriptors] Descriptor A + " - << incrementVal << " | \t "); + LDBG() << "\t\t[m: " << i << " n: " << j << " k: " << k + << "] [wgmma descriptors] Descriptor A + " << incrementVal + << " | \t "; if (!incrementVal) return desc; return makeAdd(desc, makeI64Const(b, incrementVal)); @@ -1391,7 +1390,7 @@ struct NVGPUWarpgroupMmaOpLowering int byte = elemB.getIntOrFloatBitWidth() / 8; int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte; incrementVal = incrementVal >> exclude4LSB; - LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n"); + LDBG() << "Descriptor B + " << incrementVal; if (!incrementVal) return desc; return makeAdd(desc, makeI64Const(b, incrementVal)); @@ -1400,15 +1399,14 @@ struct NVGPUWarpgroupMmaOpLowering /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix /// descriptors and arranges them based on induction variables: i, j, and k. Value generateWgmma(int i, int j, int k, Value matrixC) { - LLVM_DEBUG(DBGS() << "\t wgmma." - << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK - << "(A[" << (iterationM * wgmmaM) << ":" - << (iterationM * wgmmaM) + wgmmaM << "][" - << (iterationK * wgmmaK) << ":" - << (iterationK * wgmmaK + wgmmaK) << "] * " - << " B[" << (iterationK * wgmmaK) << ":" - << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":" - << wgmmaN << "])\n"); + LDBG() << "\t wgmma." + << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK << "(A[" + << (iterationM * wgmmaM) << ":" << (iterationM * wgmmaM) + wgmmaM + << "][" << (iterationK * wgmmaK) << ":" + << (iterationK * wgmmaK + wgmmaK) << "] * " + << " B[" << (iterationK * wgmmaK) << ":" + << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":" << wgmmaN + << "])"; Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k); Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k); @@ -1467,9 +1465,9 @@ struct NVGPUWarpgroupMmaOpLowering totalM = op.getDescriptorA().getType().getTensor().getDimSize(0); totalN = op.getDescriptorB().getType().getTensor().getDimSize(1); totalK = op.getDescriptorA().getType().getTensor().getDimSize(1); - LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN - << "] += A[" << totalM << "][" << totalK << "] * B[" - << totalK << "][" << totalN << "] ---===\n"); + LDBG() << "===--- GEMM D[" << totalM << "][" << totalN << "] += A[" + << totalM << "][" << totalK << "] * B[" << totalK << "][" << totalN + << "] ---==="; // Find the shape for one wgmma instruction findWgmmaShape( diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp index 662ee9e..91788f9 100644 --- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp +++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp @@ -25,11 +25,10 @@ #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/raw_ostream.h" #define DEBUG_TYPE "nvvm-to-llvm" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define DBGSNL() (llvm::dbgs() << "\n") namespace mlir { #define GEN_PASS_DEF_CONVERTNVVMTOLLVMPASS @@ -52,17 +51,17 @@ struct PtxLowering LogicalResult matchAndRewrite(BasicPtxBuilderInterface op, PatternRewriter &rewriter) const override { if (op.hasIntrinsic()) { - LLVM_DEBUG(DBGS() << "Ptx Builder does not lower \n\t" << op << "\n"); + LDBG() << "Ptx Builder does not lower \n\t" << op; return failure(); } SmallVector<std::pair<Value, PTXRegisterMod>> asmValues; - LLVM_DEBUG(DBGS() << op.getPtx() << "\n"); + LDBG() << op.getPtx(); PtxBuilder generator(op, rewriter); op.getAsmValues(rewriter, asmValues); for (auto &[asmValue, modifier] : asmValues) { - LLVM_DEBUG(DBGSNL() << asmValue << "\t Modifier : " << &modifier); + LDBG() << asmValue << "\t Modifier : " << &modifier; generator.insertValue(asmValue, modifier); } diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp index 807be7e..ba448e4 100644 --- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp +++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp @@ -312,6 +312,19 @@ struct ForallLowering : public OpRewritePattern<mlir::scf::ForallOp> { } // namespace +static void propagateLoopAttrs(Operation *scfOp, Operation *brOp) { + // Let the CondBranchOp carry the LLVM attributes from the ForOp, such as the + // llvm.loop_annotation attribute. + // LLVM requires the loop metadata to be attached on the "latch" block. Which + // is the back-edge to the header block (conditionBlock) + SmallVector<NamedAttribute> llvmAttrs; + llvm::copy_if(scfOp->getAttrs(), std::back_inserter(llvmAttrs), + [](auto attr) { + return isa<LLVM::LLVMDialect>(attr.getValue().getDialect()); + }); + brOp->setDiscardableAttrs(llvmAttrs); +} + LogicalResult ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const { Location loc = forOp.getLoc(); @@ -350,17 +363,7 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp, auto branchOp = cf::BranchOp::create(rewriter, loc, conditionBlock, loopCarried); - // Let the CondBranchOp carry the LLVM attributes from the ForOp, such as the - // llvm.loop_annotation attribute. - // LLVM requires the loop metadata to be attached on the "latch" block. Which - // is the back-edge to the header block (conditionBlock) - SmallVector<NamedAttribute> llvmAttrs; - llvm::copy_if(forOp->getAttrs(), std::back_inserter(llvmAttrs), - [](auto attr) { - return isa<LLVM::LLVMDialect>(attr.getValue().getDialect()); - }); - branchOp->setDiscardableAttrs(llvmAttrs); - + propagateLoopAttrs(forOp, branchOp); rewriter.eraseOp(terminator); // Compute loop bounds before branching to the condition. @@ -589,9 +592,10 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp, rewriter.setInsertionPointToEnd(after); auto yieldOp = cast<scf::YieldOp>(after->getTerminator()); - rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before, - yieldOp.getResults()); + auto latch = rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before, + yieldOp.getResults()); + propagateLoopAttrs(whileOp, latch); // Replace the op with values "yielded" from the "before" region, which are // visible by dominance. rewriter.replaceOp(whileOp, args); @@ -631,10 +635,11 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp, // Loop around the "before" region based on condition. rewriter.setInsertionPointToEnd(before); auto condOp = cast<ConditionOp>(before->getTerminator()); - cf::CondBranchOp::create(rewriter, condOp.getLoc(), condOp.getCondition(), - before, condOp.getArgs(), continuation, - ValueRange()); + auto latch = cf::CondBranchOp::create( + rewriter, condOp.getLoc(), condOp.getCondition(), before, + condOp.getArgs(), continuation, ValueRange()); + propagateLoopAttrs(whileOp, latch); // Replace the op with values "yielded" from the "before" region, which are // visible by dominance. rewriter.replaceOp(whileOp, condOp.getArgs()); diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp index fd40e7c..fa9e544 100644 --- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp +++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp @@ -36,7 +36,6 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #define DEBUG_TYPE "shard-to-mpi" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") namespace mlir { #define GEN_PASS_DEF_CONVERTSHARDTOMPIPASS diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp index f07386e..8cd650e 100644 --- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp +++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp @@ -41,6 +41,7 @@ class ConvertTensorToSPIRVPass SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; + options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes; SPIRVTypeConverter typeConverter(targetAttr, options); RewritePatternSet patterns(context); diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp index a425eff..1d1904f 100644 --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -31,10 +31,9 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/DebugLog.h" #define DEBUG_TYPE "vector-to-gpu" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define DBGSNL() (llvm::dbgs() << "\n") namespace mlir { #define GEN_PASS_DEF_CONVERTVECTORTOGPU @@ -366,7 +365,7 @@ static SetVector<Operation *> getOpToConvert(mlir::Operation *op, // by all operations. if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) { if (!supportsMMaMatrixType(op, useNvGpu)) { - LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n"); + LDBG() << "cannot convert op: " << *op; return true; } return false; @@ -548,7 +547,7 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op, std::optional<int64_t> stride = getStaticallyKnownRowStride(op.getShapedType()); if (!stride.has_value()) { - LLVM_DEBUG(DBGS() << "no stride\n"); + LDBG() << "no stride"; return rewriter.notifyMatchFailure(op, "no stride"); } @@ -583,7 +582,7 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op, isTranspose ? rewriter.getUnitAttr() : UnitAttr()); valueMapping[mappingResult] = load; - LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n"); + LDBG() << "transfer read to: " << load; return success(); } @@ -597,13 +596,13 @@ convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op, std::optional<int64_t> stride = getStaticallyKnownRowStride(op.getShapedType()); if (!stride.has_value()) { - LLVM_DEBUG(DBGS() << "no stride\n"); + LDBG() << "no stride"; return rewriter.notifyMatchFailure(op, "no stride"); } auto it = valueMapping.find(op.getVector()); if (it == valueMapping.end()) { - LLVM_DEBUG(DBGS() << "no mapping\n"); + LDBG() << "no mapping"; return rewriter.notifyMatchFailure(op, "no mapping"); } @@ -613,9 +612,9 @@ convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op, rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr()); (void)store; - LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n"); + LDBG() << "transfer write to: " << store; - LLVM_DEBUG(DBGS() << "erase: " << op << "\n"); + LDBG() << "erase: " << op; rewriter.eraseOp(op); return success(); } @@ -641,21 +640,21 @@ convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op, FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = nvgpu::getWarpMatrixInfo(op); if (failed(warpMatrixInfo)) { - LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n"); + LDBG() << "no warpMatrixInfo"; return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); } FailureOr<nvgpu::FragmentElementInfo> regInfo = nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); if (failed(regInfo)) { - LLVM_DEBUG(DBGS() << "not mma sync reg info\n"); + LDBG() << "not mma sync reg info"; return rewriter.notifyMatchFailure(op, "not mma sync reg info"); } VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); auto dense = dyn_cast<SplatElementsAttr>(op.getValue()); if (!dense) { - LLVM_DEBUG(DBGS() << "not a splat\n"); + LDBG() << "not a splat"; return rewriter.notifyMatchFailure(op, "not a splat"); } @@ -677,8 +676,8 @@ static FailureOr<bool> isTransposed(vector::TransferReadOp op) { mlir::AffineMap map = op.getPermutationMap(); if (map.getNumResults() != 2) { - LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` " - "is not a 2d operand\n"); + LDBG() << "Failed because the result of `vector.transfer_read` " + "is not a 2d operand"; return failure(); } @@ -691,8 +690,8 @@ static FailureOr<bool> isTransposed(vector::TransferReadOp op) { auto exprN = dyn_cast<AffineDimExpr>(dN); if (!exprM || !exprN) { - LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim " - "expressions, then transpose cannot be determined.\n"); + LDBG() << "Failed because expressions are not affine dim " + "expressions, then transpose cannot be determined."; return failure(); } @@ -709,20 +708,20 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op, FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = nvgpu::getWarpMatrixInfo(op); if (failed(warpMatrixInfo)) { - LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n"); + LDBG() << "no warpMatrixInfo"; return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); } FailureOr<nvgpu::FragmentElementInfo> regInfo = nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); if (failed(regInfo)) { - LLVM_DEBUG(DBGS() << "not mma sync reg info\n"); + LDBG() << "not mma sync reg info"; return rewriter.notifyMatchFailure(op, "not mma sync reg info"); } FailureOr<bool> transpose = isTransposed(op); if (failed(transpose)) { - LLVM_DEBUG(DBGS() << "failed to determine the transpose\n"); + LDBG() << "failed to determine the transpose"; return rewriter.notifyMatchFailure( op, "Op should likely not be converted to a nvgpu.ldmatrix call."); } @@ -731,10 +730,8 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op, nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose); if (failed(params)) { - LLVM_DEBUG( - DBGS() - << "failed to convert vector.transfer_read to ldmatrix. " - << "Op should likely not be converted to a nvgpu.ldmatrix call.\n"); + LDBG() << "failed to convert vector.transfer_read to ldmatrix. " + << "Op should likely not be converted to a nvgpu.ldmatrix call."; return rewriter.notifyMatchFailure( op, "failed to convert vector.transfer_read to ldmatrix; this op " "likely should not be converted to a nvgpu.ldmatrix call."); @@ -745,7 +742,7 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op, FailureOr<AffineMap> offsets = nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params); if (failed(offsets)) { - LLVM_DEBUG(DBGS() << "no offsets\n"); + LDBG() << "no offsets"; return rewriter.notifyMatchFailure(op, "no offsets"); } @@ -934,7 +931,7 @@ convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op, vector::StoreOp::create(rewriter, loc, el, op.getBase(), newIndices); } - LLVM_DEBUG(DBGS() << "erase: " << op << "\n"); + LDBG() << "erase: " << op; rewriter.eraseOp(op); return success(); } @@ -1132,9 +1129,9 @@ static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter, loop.getNumResults()))) rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it)); - LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n"); - LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n"); - LLVM_DEBUG(DBGS() << "erase: " << loop); + LDBG() << "newLoop now: " << newLoop; + LDBG() << "stripped scf.for: " << loop; + LDBG() << "erase: " << loop; rewriter.eraseOp(loop); return newLoop; @@ -1150,7 +1147,7 @@ static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op, for (const auto &operand : llvm::enumerate(op.getInitArgs())) { auto it = valueMapping.find(operand.value()); if (it == valueMapping.end()) { - LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n"); + LDBG() << "no value mapping for: " << operand.value(); continue; } argMapping.push_back(std::make_pair( @@ -1168,7 +1165,7 @@ static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op, loopBody.getArgument(mapping.second + newForOp.getNumInductionVars()); } - LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n"); + LDBG() << "scf.for to: " << newForOp; return success(); } @@ -1191,7 +1188,7 @@ convertYieldOp(RewriterBase &rewriter, scf::YieldOp op, } scf::YieldOp::create(rewriter, op.getLoc(), yieldOperands); - LLVM_DEBUG(DBGS() << "erase: " << op << "\n"); + LDBG() << "erase: " << op; rewriter.eraseOp(op); return success(); } @@ -1244,7 +1241,7 @@ LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter, auto globalRes = LogicalResult::success(); for (Operation *op : ops) { - LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n"); + LDBG() << "Process op: " << *op; // Apparently callers do not want to early exit on failure here. auto res = LogicalResult::success(); if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) { diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 4307bc6..17a79e3 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1070,39 +1070,6 @@ public: } }; -class VectorExtractElementOpConversion - : public ConvertOpToLLVMPattern<vector::ExtractElementOp> { -public: - using ConvertOpToLLVMPattern< - vector::ExtractElementOp>::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto vectorType = extractEltOp.getSourceVectorType(); - auto llvmType = typeConverter->convertType(vectorType.getElementType()); - - // Bail if result type cannot be lowered. - if (!llvmType) - return failure(); - - if (vectorType.getRank() == 0) { - Location loc = extractEltOp.getLoc(); - auto idxType = rewriter.getIndexType(); - auto zero = LLVM::ConstantOp::create(rewriter, loc, - typeConverter->convertType(idxType), - rewriter.getIntegerAttr(idxType, 0)); - rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( - extractEltOp, llvmType, adaptor.getVector(), zero); - return success(); - } - - rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( - extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition()); - return success(); - } -}; - class VectorExtractOpConversion : public ConvertOpToLLVMPattern<vector::ExtractOp> { public: @@ -1206,39 +1173,6 @@ public: } }; -class VectorInsertElementOpConversion - : public ConvertOpToLLVMPattern<vector::InsertElementOp> { -public: - using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto vectorType = insertEltOp.getDestVectorType(); - auto llvmType = typeConverter->convertType(vectorType); - - // Bail if result type cannot be lowered. - if (!llvmType) - return failure(); - - if (vectorType.getRank() == 0) { - Location loc = insertEltOp.getLoc(); - auto idxType = rewriter.getIndexType(); - auto zero = LLVM::ConstantOp::create(rewriter, loc, - typeConverter->convertType(idxType), - rewriter.getIntegerAttr(idxType, 0)); - rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( - insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero); - return success(); - } - - rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( - insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), - adaptor.getPosition()); - return success(); - } -}; - class VectorInsertOpConversion : public ConvertOpToLLVMPattern<vector::InsertOp> { public: @@ -2244,8 +2178,7 @@ void mlir::populateVectorToLLVMConversionPatterns( VectorGatherOpConversion, VectorScatterOpConversion>( converter, useVectorAlignment); patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion, - VectorExtractElementOpConversion, VectorExtractOpConversion, - VectorFMAOp1DConversion, VectorInsertElementOpConversion, + VectorExtractOpConversion, VectorFMAOp1DConversion, VectorInsertOpConversion, VectorPrintOpConversion, VectorTypeCastOpConversion, VectorScaleOpConversion, VectorExpandLoadOpConversion, VectorCompressStoreOpConversion, diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp index b1af5f0..508f4e2 100644 --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -690,7 +690,7 @@ struct PrepareTransferWriteConversion /// %lastIndex = arith.subi %length, %c1 : index /// vector.print punctuation <open> /// scf.for %i = %c0 to %length step %c1 { -/// %el = vector.extractelement %v[%i : index] : vector<[4]xi32> +/// %el = vector.extract %v[%i] : i32 from vector<[4]xi32> /// vector.print %el : i32 punctuation <no_punctuation> /// %notLastIndex = arith.cmpi ult, %i, %lastIndex : index /// scf.if %notLastIndex { @@ -1643,7 +1643,7 @@ struct Strategy1d<TransferWriteOp> { /// Is rewritten to approximately the following pseudo-IR: /// ``` /// for i = 0 to 9 { -/// %t = vector.extractelement %vec[i] : vector<9xf32> +/// %t = vector.extract %vec[i] : f32 from vector<9xf32> /// memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32> /// } /// ``` diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 986eae3..a4be7d4 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -335,63 +335,6 @@ struct VectorInsertOpConvert final } }; -struct VectorExtractElementOpConvert final - : public OpConversionPattern<vector::ExtractElementOp> { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type resultType = getTypeConverter()->convertType(extractOp.getType()); - if (!resultType) - return failure(); - - if (isa<spirv::ScalarType>(adaptor.getVector().getType())) { - rewriter.replaceOp(extractOp, adaptor.getVector()); - return success(); - } - - APInt cstPos; - if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos))) - rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>( - extractOp, resultType, adaptor.getVector(), - rewriter.getI32ArrayAttr({static_cast<int>(cstPos.getSExtValue())})); - else - rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>( - extractOp, resultType, adaptor.getVector(), adaptor.getPosition()); - return success(); - } -}; - -struct VectorInsertElementOpConvert final - : public OpConversionPattern<vector::InsertElementOp> { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type vectorType = getTypeConverter()->convertType(insertOp.getType()); - if (!vectorType) - return failure(); - - if (isa<spirv::ScalarType>(vectorType)) { - rewriter.replaceOp(insertOp, adaptor.getSource()); - return success(); - } - - APInt cstPos; - if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos))) - rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>( - insertOp, adaptor.getSource(), adaptor.getDest(), - cstPos.getSExtValue()); - else - rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>( - insertOp, vectorType, insertOp.getDest(), adaptor.getSource(), - adaptor.getPosition()); - return success(); - } -}; - struct VectorInsertStridedSliceOpConvert final : public OpConversionPattern<vector::InsertStridedSliceOp> { using OpConversionPattern::OpConversionPattern; @@ -1107,12 +1050,11 @@ struct VectorToElementOpConvert final void mlir::populateVectorToSPIRVPatterns( const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { patterns.add< - VectorBitcastConvert, VectorBroadcastConvert, - VectorExtractElementOpConvert, VectorExtractOpConvert, + VectorBitcastConvert, VectorBroadcastConvert, VectorExtractOpConvert, VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>, VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert, - VectorToElementOpConvert, VectorInsertElementOpConvert, - VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>, + VectorToElementOpConvert, VectorInsertOpConvert, + VectorReductionPattern<GL_INT_MAX_MIN_OPS>, VectorReductionPattern<CL_INT_MAX_MIN_OPS>, VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>, VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast, diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 8d7053c..22608a1 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -26,7 +26,7 @@ #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallVectorExtras.h" #include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/LogicalResult.h" #include "llvm/Support/MathExtras.h" #include <numeric> @@ -40,7 +40,6 @@ using llvm::divideFloorSigned; using llvm::mod; #define DEBUG_TYPE "affine-ops" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") #include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc" @@ -1062,12 +1061,9 @@ static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp, AffineMap *map, ValueRange dims, ValueRange syms) { + LDBG() << "replaceAffineMinBoundingBoxExpression: `" << minOp << "`"; AffineMap affineMinMap = minOp.getAffineMap(); - LLVM_DEBUG({ - DBGS() << "replaceAffineMinBoundingBoxExpression: `" << minOp << "`\n"; - }); - // Check the value is positive. for (unsigned i = 0, e = affineMinMap.getNumResults(); i < e; ++i) { // Compare each expression in the minimum against 0. diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp index d1d1062..aa53f94 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -1,4 +1,5 @@ -//===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===// +//===- OneShotModuleBufferize.cpp - Bufferization across Func. Boundaries +//----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -8,12 +9,13 @@ // // Module Bufferization is an extension of One-Shot Bufferize that // bufferizes function boundaries. It provides `BufferizableOpInterface` -// implementations for FuncOp, CallOp and ReturnOp. +// implementations for FuncOp, CallOp and ReturnOp. Although it is named +// Module Bufferization, it may operate on any SymbolTable. // -// Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`. -// This function analyzes the given module and determines the order of analysis -// and bufferization: Functions that are called are processed before their -// respective callers. +// Module Bufferization is run via `runOneShotModuleBufferize(SymbolTableOp, +// ...)`. This function analyzes the given op and determines the order of +// analysis and bufferization: Functions that are called are processed before +// their respective callers. // // After analyzing a FuncOp, additional information about its bbArgs is // gathered and stored in `FuncAnalysisState`. @@ -309,7 +311,7 @@ static bool hasTensorSignature(func::FuncOp funcOp) { /// Return `failure()` if we are unable to retrieve the called FuncOp from /// any func::CallOp. static LogicalResult getFuncOpsOrderedByCalls( - ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps, + Operation *moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps, SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap, SymbolTableCollection &symbolTables) { // For each FuncOp, the set of functions called by it (i.e. the union of @@ -317,26 +319,29 @@ static LogicalResult getFuncOpsOrderedByCalls( DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy; // For each FuncOp, the number of func::CallOp it contains. DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp; - - for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) { - // Collect function calls and populate the caller map. - numberCallOpsContainedInFuncOp[funcOp] = 0; - WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult { - func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables); - assert(calledFunction && "could not retrieved called func::FuncOp"); - // If the called function does not have any tensors in its signature, then - // it is not necessary to bufferize the callee before the caller. - if (!hasTensorSignature(calledFunction)) - return WalkResult::skip(); - - callerMap[calledFunction].insert(callOp); - if (calledBy[calledFunction].insert(funcOp).second) { - numberCallOpsContainedInFuncOp[funcOp]++; + for (mlir::Region ®ion : moduleOp->getRegions()) { + for (mlir::Block &block : region.getBlocks()) { + for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) { + // Collect function calls and populate the caller map. + numberCallOpsContainedInFuncOp[funcOp] = 0; + WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult { + func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables); + assert(calledFunction && "could not retrieved called func::FuncOp"); + // If the called function does not have any tensors in its signature, + // then it is not necessary to bufferize the callee before the caller. + if (!hasTensorSignature(calledFunction)) + return WalkResult::skip(); + + callerMap[calledFunction].insert(callOp); + if (calledBy[calledFunction].insert(funcOp).second) { + numberCallOpsContainedInFuncOp[funcOp]++; + } + return WalkResult::advance(); + }); + if (res.wasInterrupted()) + return failure(); } - return WalkResult::advance(); - }); - if (res.wasInterrupted()) - return failure(); + } } // Iteratively remove function operations that do not call any of the @@ -447,7 +452,7 @@ static void foldMemRefCasts(func::FuncOp funcOp) { } LogicalResult -mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp, +mlir::bufferization::analyzeModuleOp(Operation *moduleOp, OneShotAnalysisState &state, BufferizationStatistics *statistics) { assert(state.getOptions().bufferizeFunctionBoundaries && @@ -512,19 +517,23 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp, } void mlir::bufferization::removeBufferizationAttributesInModule( - ModuleOp moduleOp) { - for (auto op : moduleOp.getOps<func::FuncOp>()) { - for (BlockArgument bbArg : op.getArguments()) - removeBufferizationAttributes(bbArg); + Operation *moduleOp) { + for (mlir::Region ®ion : moduleOp->getRegions()) { + for (mlir::Block &block : region.getBlocks()) { + for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) { + for (BlockArgument bbArg : funcOp.getArguments()) + removeBufferizationAttributes(bbArg); + } + } } } LogicalResult mlir::bufferization::bufferizeModuleOp( - ModuleOp moduleOp, const OneShotBufferizationOptions &options, + Operation *moduleOp, const OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics) { assert(options.bufferizeFunctionBoundaries && "expected that function boundary bufferization is activated"); - IRRewriter rewriter(moduleOp.getContext()); + IRRewriter rewriter(moduleOp->getContext()); // A list of non-circular functions in the order in which they are analyzed // and bufferized. @@ -571,12 +580,17 @@ LogicalResult mlir::bufferization::bufferizeModuleOp( } // Bufferize all other ops. - for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) { - // Functions were already bufferized. - if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>()) - continue; - if (failed(bufferizeOp(&op, options, state, statistics))) - return failure(); + for (mlir::Region ®ion : moduleOp->getRegions()) { + for (mlir::Block &block : region.getBlocks()) { + for (mlir::Operation &op : + llvm::make_early_inc_range(block.getOperations())) { + // Functions were already bufferized. + if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>()) + continue; + if (failed(bufferizeOp(&op, options, state, statistics))) + return failure(); + } + } } // Post-pass cleanup of function argument attributes. @@ -586,7 +600,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp( } LogicalResult mlir::bufferization::runOneShotModuleBufferize( - ModuleOp moduleOp, const OneShotBufferizationOptions &options, + Operation *moduleOp, const OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics) { assert(options.bufferizeFunctionBoundaries && "expected that function boundary bufferization is activated"); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp index f999c93..a6159ee 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp @@ -33,7 +33,7 @@ LogicalResult mlir::bufferization::insertTensorCopies( // analysis depending on whether function boundary bufferization is enabled or // not. if (options.bufferizeFunctionBoundaries) { - if (failed(analyzeModuleOp(cast<ModuleOp>(op), analysisState, statistics))) + if (failed(analyzeModuleOp(op, analysisState, statistics))) return failure(); } else { if (failed(analyzeOp(op, analysisState, statistics))) diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 4c09022..e6a3154 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -1398,6 +1398,45 @@ void FileOp::build(OpBuilder &builder, OperationState &state, StringRef id) { //===----------------------------------------------------------------------===// // FieldOp //===----------------------------------------------------------------------===// +static void printEmitCFieldOpTypeAndInitialValue(OpAsmPrinter &p, FieldOp op, + TypeAttr type, + Attribute initialValue) { + p << type; + if (initialValue) { + p << " = "; + p.printAttributeWithoutType(initialValue); + } +} + +static Type getInitializerTypeForField(Type type) { + if (auto array = llvm::dyn_cast<ArrayType>(type)) + return RankedTensorType::get(array.getShape(), array.getElementType()); + return type; +} + +static ParseResult +parseEmitCFieldOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, + Attribute &initialValue) { + Type type; + if (parser.parseType(type)) + return failure(); + + typeAttr = TypeAttr::get(type); + + if (parser.parseOptionalEqual()) + return success(); + + if (parser.parseAttribute(initialValue, getInitializerTypeForField(type))) + return failure(); + + if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>( + initialValue)) + return parser.emitError(parser.getNameLoc()) + << "initial value should be a integer, float, elements or opaque " + "attribute"; + return success(); +} + LogicalResult FieldOp::verify() { if (!isSupportedEmitCType(getType())) return emitOpError("expected valid emitc type"); @@ -1410,9 +1449,6 @@ LogicalResult FieldOp::verify() { if (!symName || symName.getValue().empty()) return emitOpError("field must have a non-empty symbol name"); - if (!getAttrs()) - return success(); - return success(); } diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp index fa05ad8..c55e26e 100644 --- a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp +++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp @@ -58,17 +58,18 @@ public: auto argAttrs = funcOp.getArgAttrs(); for (auto [idx, val] : llvm::enumerate(funcOp.getArguments())) { - StringAttr fieldName; - Attribute argAttr = nullptr; - - fieldName = rewriter.getStringAttr("fieldName" + std::to_string(idx)); - if (argAttrs && idx < argAttrs->size()) - argAttr = (*argAttrs)[idx]; + StringAttr fieldName = + rewriter.getStringAttr("fieldName" + std::to_string(idx)); TypeAttr typeAttr = TypeAttr::get(val.getType()); fields.push_back({fieldName, typeAttr}); - emitc::FieldOp::create(rewriter, funcOp.getLoc(), fieldName, typeAttr, - argAttr); + + FieldOp fieldop = rewriter.create<emitc::FieldOp>( + funcOp->getLoc(), fieldName, typeAttr, nullptr); + + if (argAttrs && idx < argAttrs->size()) { + fieldop->setDiscardableAttrs(funcOp.getArgAttrDict(idx)); + } } rewriter.setInsertionPointToEnd(&newClassOp.getBody().front()); diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index d186a48..5a72ef1 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -1395,40 +1395,12 @@ void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value, // RotateOp //===----------------------------------------------------------------------===// -void RotateOp::build(OpBuilder &builder, OperationState &result, Value value, - int32_t offset, int32_t width) { - build(builder, result, value, - arith::ConstantOp::create(builder, result.location, - builder.getI32IntegerAttr(offset)), - arith::ConstantOp::create(builder, result.location, - builder.getI32IntegerAttr(width))); -} - LogicalResult RotateOp::verify() { - auto offsetConstOp = getOffset().getDefiningOp<arith::ConstantOp>(); - if (!offsetConstOp) - return emitOpError() << "offset is not a constant value"; - - auto offsetIntAttr = - llvm::dyn_cast<mlir::IntegerAttr>(offsetConstOp.getValue()); - - auto widthConstOp = getWidth().getDefiningOp<arith::ConstantOp>(); - if (!widthConstOp) - return emitOpError() << "width is not a constant value"; - - auto widthIntAttr = - llvm::dyn_cast<mlir::IntegerAttr>(widthConstOp.getValue()); - - llvm::APInt offsetValue = offsetIntAttr.getValue(); - llvm::APInt widthValue = widthIntAttr.getValue(); - - if (!widthValue.isPowerOf2()) - return emitOpError() << "width must be a power of two"; + uint32_t offset = getOffset(); + uint32_t width = getWidth(); - if (offsetValue.sge(widthValue) || offsetValue.slt(0)) { - int64_t widthValueInt = widthValue.getSExtValue(); - return emitOpError() << "offset must be in the range [0, " << widthValueInt - << ")"; + if (offset >= width) { + return emitOpError() << "offset must be in the range [0, " << width << ")"; } return success(); diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index cffe310..e0977f5 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -30,6 +30,7 @@ #include "mlir/IR/Types.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" @@ -812,15 +813,26 @@ LogicalResult NVVM::LdMatrixOp::verify() { } LogicalResult NVVM::StMatrixOp::verify() { - unsigned addressSpace = - llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace(); - if (addressSpace != NVVM::kSharedMemorySpace) - return emitOpError("expected source pointer in memory space 3"); - int numMatrix = getSources().size(); if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4) return emitOpError("expected num attribute to be 1, 2 or 4"); + int m = getShape().getM(), n = getShape().getN(); + if (m == 8 && n == 8) { + if (getEltType() != NVVM::LdStMatrixEltType::B16) { + return emitOpError("expected element type to be B16 for 8x8 matrix"); + } + } else if (m == 16 && n == 8) { + if (getEltType() != NVVM::LdStMatrixEltType::B8) { + return emitOpError("expected element type to be B8 for 16x8 matrix"); + } + if (getLayout() != NVVM::MMALayout::col) { + return emitOpError("expected layout to be col for 16x8 matrix"); + } + } else { + return emitOpError("expected shape to be 8x8 or 16x8"); + } + return success(); } diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp index 935aa3c..b951df8 100644 --- a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp +++ b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp @@ -22,6 +22,8 @@ #include "llvm/ADT/ScopeExit.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" + #define DEBUG_TYPE "llvm-inliner" using namespace mlir; @@ -670,44 +672,42 @@ struct LLVMInlinerInterface : public DialectInlinerInterface { bool wouldBeCloned) const final { auto callOp = dyn_cast<LLVM::CallOp>(call); if (!callOp) { - LLVM_DEBUG(llvm::dbgs() << "Cannot inline: call is not an '" - << LLVM::CallOp::getOperationName() << "' op\n"); + LDBG() << "Cannot inline: call is not an '" + << LLVM::CallOp::getOperationName() << "' op"; return false; } if (callOp.getNoInline()) { - LLVM_DEBUG(llvm::dbgs() << "Cannot inline: call is marked no_inline\n"); + LDBG() << "Cannot inline: call is marked no_inline"; return false; } auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(callable); if (!funcOp) { - LLVM_DEBUG(llvm::dbgs() - << "Cannot inline: callable is not an '" - << LLVM::LLVMFuncOp::getOperationName() << "' op\n"); + LDBG() << "Cannot inline: callable is not an '" + << LLVM::LLVMFuncOp::getOperationName() << "' op"; return false; } if (funcOp.isNoInline()) { - LLVM_DEBUG(llvm::dbgs() - << "Cannot inline: function is marked no_inline\n"); + LDBG() << "Cannot inline: function is marked no_inline"; return false; } if (funcOp.isVarArg()) { - LLVM_DEBUG(llvm::dbgs() << "Cannot inline: callable is variadic\n"); + LDBG() << "Cannot inline: callable is variadic"; return false; } // TODO: Generate aliasing metadata from noalias result attributes. if (auto attrs = funcOp.getArgAttrs()) { for (DictionaryAttr attrDict : attrs->getAsRange<DictionaryAttr>()) { if (attrDict.contains(LLVM::LLVMDialect::getInAllocaAttrName())) { - LLVM_DEBUG(llvm::dbgs() << "Cannot inline " << funcOp.getSymName() - << ": inalloca arguments not supported\n"); + LDBG() << "Cannot inline " << funcOp.getSymName() + << ": inalloca arguments not supported"; return false; } } } // TODO: Handle exceptions. if (funcOp.getPersonality()) { - LLVM_DEBUG(llvm::dbgs() << "Cannot inline " << funcOp.getSymName() - << ": unhandled function personality\n"); + LDBG() << "Cannot inline " << funcOp.getSymName() + << ": unhandled function personality"; return false; } if (funcOp.getPassthrough()) { @@ -717,10 +717,8 @@ struct LLVMInlinerInterface : public DialectInlinerInterface { if (!stringAttr) return false; if (disallowedFunctionAttrs.contains(stringAttr)) { - LLVM_DEBUG(llvm::dbgs() - << "Cannot inline " << funcOp.getSymName() - << ": found disallowed function attribute " - << stringAttr << "\n"); + LDBG() << "Cannot inline " << funcOp.getSymName() + << ": found disallowed function attribute " << stringAttr; return true; } return false; diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp index f49d9a1..73ae029 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -476,10 +476,10 @@ inferContractionDimsImpl(ArrayRef<AffineMap> indexingMaps, SmallVector<unsigned, 2>(ac.begin(), ac.end()), SmallVector<unsigned, 2>(bc.begin(), bc.end()), SmallVector<unsigned, 2>(ra.begin(), ra.end())}; - llvm::sort(dimensions.batch.begin(), dimensions.batch.end()); - llvm::sort(dimensions.m.begin(), dimensions.m.end()); - llvm::sort(dimensions.n.begin(), dimensions.n.end()); - llvm::sort(dimensions.k.begin(), dimensions.k.end()); + llvm::sort(dimensions.batch); + llvm::sort(dimensions.m); + llvm::sort(dimensions.n); + llvm::sort(dimensions.k); return dimensions; } @@ -797,12 +797,12 @@ inferConvolutionDimsImpl(LinalgOp linalgOp, SmallVector<unsigned, 2>(depth.begin(), depth.end()), /*strides=*/SmallVector<int64_t, 2>{}, /*dilations=*/SmallVector<int64_t, 2>{}}; - llvm::sort(dimensions.batch.begin(), dimensions.batch.end()); - llvm::sort(dimensions.outputImage.begin(), dimensions.outputImage.end()); - llvm::sort(dimensions.outputChannel.begin(), dimensions.outputChannel.end()); - llvm::sort(dimensions.filterLoop.begin(), dimensions.filterLoop.end()); - llvm::sort(dimensions.inputChannel.begin(), dimensions.inputChannel.end()); - llvm::sort(dimensions.depth.begin(), dimensions.depth.end()); + llvm::sort(dimensions.batch); + llvm::sort(dimensions.outputImage); + llvm::sort(dimensions.outputChannel); + llvm::sort(dimensions.filterLoop); + llvm::sort(dimensions.inputChannel); + llvm::sort(dimensions.depth); // Use the op carried strides/dilations attribute if present. auto nativeStrides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides"); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index b56a212..34c63d3 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -2293,9 +2293,39 @@ Speculation::Speculatability BroadcastOp::getSpeculatability() { return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation())); } +/// Fold back-to-back broadcasts together. +struct FoldBroadcasts : OpRewritePattern<linalg::BroadcastOp> { + using OpRewritePattern<linalg::BroadcastOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp, + PatternRewriter &rewriter) const override { + auto defBroadcastOp = broadcastOp.getInput().getDefiningOp<BroadcastOp>(); + if (!defBroadcastOp) + return failure(); + ArrayRef<int64_t> defDimensions = defBroadcastOp.getDimensions(); + ArrayRef<int64_t> dimensions = broadcastOp.getDimensions(); + SmallVector<int64_t> foldedDims(dimensions); + Value init = broadcastOp.getInit(); + int64_t initRank = cast<ShapedType>(init.getType()).getRank(); + // Mapping from input dims to init dims. + SmallVector<int64_t> dimMap; + for (auto dim : llvm::seq<int64_t>(0, initRank)) { + if (!llvm::is_contained(dimensions, dim)) + dimMap.push_back(dim); + } + for (auto dim : defDimensions) + foldedDims.push_back(dimMap[dim]); + + llvm::sort(foldedDims); + rewriter.replaceOpWithNewOp<BroadcastOp>( + broadcastOp, defBroadcastOp.getInput(), init, foldedDims); + return success(); + } +}; + void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<EraseIdentityLinalgOp<BroadcastOp>>(context); + results.add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts>(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index 7f9ba1b..bf66ed0 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -637,6 +637,7 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> { } ArrayRef<int64_t> sourceShape = padOp.getSourceType().getShape(); + ArrayRef<int64_t> resultShape = padOp.getResultType().getShape(); int64_t padRank = sourceShape.size(); auto isStaticZero = [](OpFoldResult f) { @@ -647,16 +648,18 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> { allowedUnitDims.end()); llvm::SmallDenseSet<unsigned> unitDims; SmallVector<int64_t> newShape; + SmallVector<int64_t> newResultShape; SmallVector<OpFoldResult> newLowPad; SmallVector<OpFoldResult> newHighPad; - for (const auto [dim, size, low, high] : - zip_equal(llvm::seq(static_cast<int64_t>(0), padRank), sourceShape, - padOp.getMixedLowPad(), padOp.getMixedHighPad())) { + for (const auto [dim, size, outSize, low, high] : zip_equal( + llvm::seq(static_cast<int64_t>(0), padRank), sourceShape, + resultShape, padOp.getMixedLowPad(), padOp.getMixedHighPad())) { if (unitDimsFilter.contains(dim) && size == 1 && isStaticZero(low) && isStaticZero(high)) { unitDims.insert(dim); } else { newShape.push_back(size); + newResultShape.push_back(outSize); newLowPad.push_back(low); newHighPad.push_back(high); } @@ -686,8 +689,10 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> { collapseValue(rewriter, padOp.getLoc(), padOp.getSource(), newShape, reassociationMap, options.rankReductionStrategy); - auto newPadOp = tensor::PadOp::create( - rewriter, padOp.getLoc(), /*result=*/Type(), collapsedSource, newLowPad, + auto newResultType = RankedTensorType::get( + newResultShape, padOp.getResultType().getElementType()); + auto newPadOp = rewriter.create<tensor::PadOp>( + padOp.getLoc(), /*result=*/newResultType, collapsedSource, newLowPad, newHighPad, paddingVal, padOp.getNofold()); Value dest = padOp.getResult(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp index 2c62cb6..2e62523 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp @@ -55,6 +55,28 @@ getFullRankPaddingSizes(Builder &b, ArrayRef<OpFoldResult> indexingSizes, return paddingSizes; } +/// Extracts the constant multiplier from an affine expression of the form +/// `d * c` or `c * d`, where `d` is an AffineDimExpr and `c` is an +/// AffineConstantExpr. Returns 1 if the expression is not a simple +/// multiplication of a dimension and a constant. +static int64_t extractConstantMultiplier(AffineExpr expr) { + if (auto binOp = dyn_cast<AffineBinaryOpExpr>(expr)) { + if (binOp.getKind() == AffineExprKind::Mul) { + auto lhsD = dyn_cast<AffineDimExpr>(binOp.getLHS()); + auto rhsC = dyn_cast<AffineConstantExpr>(binOp.getRHS()); + if (lhsD && rhsC) { + return rhsC.getValue(); + } + auto lhsC = dyn_cast<AffineConstantExpr>(binOp.getLHS()); + auto rhsD = dyn_cast<AffineDimExpr>(binOp.getRHS()); + if (lhsC && rhsD) { + return lhsC.getValue(); + } + } + } + return 1; +} + /// Compute the padded shape of the given value `v` of `RankedTensorType` given /// - `indexingSizes` a list of OpFoldResult. /// - an `indexingMap` that encodes how the shape of varies with increases @@ -63,6 +85,13 @@ getFullRankPaddingSizes(Builder &b, ArrayRef<OpFoldResult> indexingSizes, /// The `indexingMap` + `indexingSizes` encoding suits StructuredOps. /// The implementaiton below iteratively combines increases from contributing /// dimensions using affine.apply operations. +/// The padded shape is computed by evaluating the maximum accessed index per +/// dimension, which may involve multiplying by constant factors derived from +/// the affine indexing expressions. Currently, only a limited set of projected +/// permutation indexing maps are supported, such as +/// - affine_map<(d0, d1, d2) -> (d0, d1)> +/// - affine_map<(d0, d1, d2) -> (d0, d1 + d2)> +/// - affine_map<(d0, d1) -> (d0 * 3 + d1)> /// In the future, more general interfaces can be devised to encode similar /// shape evolutions and map between an op and its operands. SmallVector<OpFoldResult> linalg::computePaddedShape( @@ -114,24 +143,33 @@ SmallVector<OpFoldResult> linalg::computePaddedShape( /*compressDims=*/true); // If we are padding to the next multiple of, compose with ceil(sz) * sz. + OpFoldResult paddingDimOfr; if (options.padToMultipleOf) { AffineExpr d0, s0; bindDims(rewriter.getContext(), d0); bindSymbols(rewriter.getContext(), s0); AffineMap ceilMap = AffineMap::get(1, 1, d0.ceilDiv(s0) * s0); AffineMap composedMap = projectedMap.compose(ceilMap); - OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply( + paddingDimOfr = affine::makeComposedFoldedAffineApply( rewriter, loc, composedMap, {indexingSizes[paddingDim], paddingSize}, /*composeAffineMin=*/true); - terms.push_back(paddingDimOfr); } else { // Otherwise just set to paddingSize. - OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply( + paddingDimOfr = affine::makeComposedFoldedAffineApply( rewriter, loc, projectedMap, paddingSize); - terms.push_back(paddingDimOfr); } + // Adjust for the maximum accessed index, which is (paddingSize - 1) * + // multiplier. + AffineExpr d0; + bindDims(rewriter.getContext(), d0); + int64_t multiplier = extractConstantMultiplier(projectedMap.getResult(0)); + AffineMap subtractMap = AffineMap::get(1, 0, d0 - multiplier); + OpFoldResult maxAccessIdx = affine::makeComposedFoldedAffineApply( + rewriter, loc, subtractMap, {paddingDimOfr}); + terms.push_back(maxAccessIdx); + LLVM_DEBUG(DBGS() << "------new term: " << terms.back() << "\n"); } @@ -148,8 +186,9 @@ SmallVector<OpFoldResult> linalg::computePaddedShape( AffineExpr sumExpr = dims.front(); for (unsigned i = 1; i < dims.size(); ++i) sumExpr = sumExpr + dims[i]; - OpFoldResult paddedDimOfr = - affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, terms); + // Add 1 to the maximum accessed index and get the final padded size. + OpFoldResult paddedDimOfr = affine::makeComposedFoldedAffineApply( + rewriter, loc, sumExpr + 1, terms); paddedShape[resultIndex] = paddedDimOfr; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index dad3526..57b610b 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -932,20 +932,6 @@ struct PackOpTiling continue; } - // If the dimension needs padding, it is not supported because there are - // iterations that only write padding values to the whole tile. The - // consumer fusion is driven by the source, so it is not possible to map - // an empty slice to the tile. - bool needExtraPadding = - ShapedType::isDynamic(destDimSize) || !cstInnerSize || - destDimSize * cstInnerSize.value() != srcDimSize; - // Prioritize the case that the op already says that it does not need - // padding. - if (!packOp.getPaddingValue()) - needExtraPadding = false; - if (needExtraPadding) - return failure(); - // Currently fusing `packOp` as consumer only expects perfect tiling // scenario because even if without padding semantic, the `packOp` may // also yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>, diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 793eec7..0860cea 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -1831,6 +1831,53 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp, return success(); } +/// Given the re-associations, "collapses" the input Vector type +/// +/// This is similar to CollapseShapeOp::inferCollapsedType with two notable +/// differences: +/// * We can safely assume that there are no dynamic sizes. +/// * Scalable flags are updated alongside regular dims. +/// +/// When collapsing scalable flags, conservatively avoids cases with two +/// scalable dims. We could re-visit this in the future. +/// +/// EXAMPLE: +/// type = vector<4x16x[8]x16xf32> +/// reassociation = [(d0, d1, d2, d3) -> (d0, d1), +/// (d0, d1, d2, d3) -> (d2, d3)] +/// Result: +/// vector<64x[128]xf32> +static VectorType getCollapsedVecType(VectorType type, + ArrayRef<AffineMap> reassociation) { + assert(type.getNumScalableDims() < 2 && + "Collapsing more than 1 scalable dim is not supported ATM"); + + // Use the fact that reassociation is valid to simplify the logic: only use + // each map's rank. + assert(isReassociationValid(reassociation) && "invalid reassociation"); + + auto shape = type.getShape(); + auto scalableFlags = type.getScalableDims(); + SmallVector<int64_t> newShape; + SmallVector<bool> newScalableFlags; + + unsigned currentDim = 0; + for (AffineMap m : reassociation) { + unsigned dim = m.getNumResults(); + int64_t size = 1; + bool flag = false; + for (unsigned d = 0; d < dim; ++d) { + size *= shape[currentDim + d]; + flag |= scalableFlags[currentDim + d]; + } + newShape.push_back(size); + newScalableFlags.push_back(flag); + currentDim += dim; + } + + return VectorType::get(newShape, type.getElementType(), newScalableFlags); +} + /// Vectorize a `linalg::UnPackOp` to these 4 Ops: /// Vector::TransferReadOp - Reads a vector from the source tensor /// vector::TransposeOp - Transpose the Source tensor @@ -1928,30 +1975,18 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp, PackingMetadata packMetadata; SmallVector<int64_t> lastDimToInsertPosPerm = getUnPackInverseSrcPerm(unpackOp, packMetadata); - ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType()); - SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape()); - mlir::Type stripMineElemType = maskedOpShapedType.getElementType(); - applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm); - RankedTensorType stripMineTensorType = - RankedTensorType::get(stripMineShape, stripMineElemType); // Transpose the appropriate rows to match output. vector::TransposeOp transposeOp = vector::TransposeOp::create( rewriter, loc, readResult, lastDimToInsertPosPerm); // Collapse the vector to the size required by result. - RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType( - stripMineTensorType, packMetadata.reassociations); - mlir::VectorType vecCollapsedType = - VectorType::get(collapsedType.getShape(), collapsedType.getElementType()); + VectorType collapsedVecType = getCollapsedVecType( + transposeOp.getType(), + getSymbolLessAffineMaps(convertReassociationIndicesToExprs( + rewriter.getContext(), packMetadata.reassociations))); vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create( - rewriter, loc, vecCollapsedType, transposeOp->getResult(0)); - - // writeVectorSizes had to match the shapecast shape for dynamic sizes, - // otherwise the validator complains that the mask size is invalid. - SmallVector<int64_t> writeVectorSizes( - unpackOp.getDestType().hasStaticShape() - ? vectorSizes - : shapeCastOp.getResultVectorType().getShape()); + rewriter, loc, collapsedVecType, transposeOp->getResult(0)); + Operation *write = createWriteOrMaskedWrite( rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(), /*writeIndices=*/{}, useInBoundsInsteadOfMasking); diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index e73bdd3..485bb73 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -1375,6 +1375,21 @@ void acc::ParallelOp::addWaitOperands( setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums)); } +void acc::ParallelOp::addPrivatization(MLIRContext *context, + mlir::acc::PrivateOp op, + mlir::acc::PrivateRecipeOp recipe) { + getPrivateOperandsMutable().append(op.getResult()); + + llvm::SmallVector<mlir::Attribute> recipes; + + if (getPrivatizationRecipesAttr()) + llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes)); + + recipes.push_back( + mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); + setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); +} + static ParseResult parseNumGangs( mlir::OpAsmParser &parser, llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands, @@ -2011,6 +2026,21 @@ void acc::SerialOp::addWaitOperands( setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums)); } +void acc::SerialOp::addPrivatization(MLIRContext *context, + mlir::acc::PrivateOp op, + mlir::acc::PrivateRecipeOp recipe) { + getPrivateOperandsMutable().append(op.getResult()); + + llvm::SmallVector<mlir::Attribute> recipes; + + if (getPrivatizationRecipesAttr()) + llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes)); + + recipes.push_back( + mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); + setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); +} + //===----------------------------------------------------------------------===// // KernelsOp //===----------------------------------------------------------------------===// @@ -2957,6 +2987,23 @@ bool acc::LoopOp::hasDefaultGangWorkerVector() { getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static); } +acc::LoopParMode +acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) { + if (hasSeq(deviceType)) + return LoopParMode::loop_seq; + if (hasAuto(deviceType)) + return LoopParMode::loop_auto; + if (hasIndependent(deviceType)) + return LoopParMode::loop_independent; + if (hasSeq()) + return LoopParMode::loop_seq; + if (hasAuto()) + return LoopParMode::loop_auto; + assert(hasIndependent() && + "loop must have default auto, seq, or independent"); + return LoopParMode::loop_independent; +} + void acc::LoopOp::addGangOperands( MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes, llvm::ArrayRef<GangArgType> argTypes, mlir::ValueRange values) { @@ -2997,6 +3044,21 @@ void acc::LoopOp::addGangOperands( } } +void acc::LoopOp::addPrivatization(MLIRContext *context, + mlir::acc::PrivateOp op, + mlir::acc::PrivateRecipeOp recipe) { + getPrivateOperandsMutable().append(op.getResult()); + + llvm::SmallVector<mlir::Attribute> recipes; + + if (getPrivatizationRecipesAttr()) + llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes)); + + recipes.push_back( + mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); + setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); +} + //===----------------------------------------------------------------------===// // DataOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 759e58b..0262a1b 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -137,6 +137,9 @@ ParseResult ExecuteRegionOp::parse(OpAsmParser &parser, if (parser.parseOptionalArrowTypeList(result.types)) return failure(); + if (succeeded(parser.parseOptionalKeyword("no_inline"))) + result.addAttribute("no_inline", parser.getBuilder().getUnitAttr()); + // Introduce the body region and parse it. Region *body = result.addRegion(); if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) || @@ -148,8 +151,9 @@ ParseResult ExecuteRegionOp::parse(OpAsmParser &parser, void ExecuteRegionOp::print(OpAsmPrinter &p) { p.printOptionalArrowTypeList(getResultTypes()); - p << ' '; + if (getNoInline()) + p << "no_inline "; p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/true); @@ -184,7 +188,7 @@ struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> { LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override { - if (!op.getRegion().hasOneBlock()) + if (!op.getRegion().hasOneBlock() || op.getNoInline()) return failure(); replaceOpWithRegion(rewriter, op, op.getRegion()); return success(); diff --git a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp index e27dc27..fcf4eb6 100644 --- a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp @@ -270,48 +270,6 @@ LogicalResult ConvertUToFOp::verify() { } //===----------------------------------------------------------------------===// -// spirv.INTELConvertBF16ToFOp -//===----------------------------------------------------------------------===// - -LogicalResult INTELConvertBF16ToFOp::verify() { - auto operandType = getOperand().getType(); - auto resultType = getResult().getType(); - // ODS checks that vector result type and vector operand type have the same - // shape. - if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) { - unsigned operandNumElements = vectorType.getNumElements(); - unsigned resultNumElements = - llvm::cast<VectorType>(resultType).getNumElements(); - if (operandNumElements != resultNumElements) { - return emitOpError( - "operand and result must have same number of elements"); - } - } - return success(); -} - -//===----------------------------------------------------------------------===// -// spirv.INTELConvertFToBF16Op -//===----------------------------------------------------------------------===// - -LogicalResult INTELConvertFToBF16Op::verify() { - auto operandType = getOperand().getType(); - auto resultType = getResult().getType(); - // ODS checks that vector result type and vector operand type have the same - // shape. - if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) { - unsigned operandNumElements = vectorType.getNumElements(); - unsigned resultNumElements = - llvm::cast<VectorType>(resultType).getNumElements(); - if (operandNumElements != resultNumElements) { - return emitOpError( - "operand and result must have same number of elements"); - } - } - return success(); -} - -//===----------------------------------------------------------------------===// // spirv.FConvertOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp index 9bee200..fcf1526 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -693,7 +693,9 @@ static ParseResult parseStructMemberDecorations( // `!spirv.struct<` (id `,`)? // `(` // (spirv-type (`[` struct-member-decoration `]`)?)* -// `)>` +// `)` +// (`,` struct-decoration)? +// `>` static Type parseStructType(SPIRVDialect const &dialect, DialectAsmParser &parser) { // TODO: This function is quite lengthy. Break it down into smaller chunks. @@ -767,17 +769,48 @@ static Type parseStructType(SPIRVDialect const &dialect, return Type(); } - if (failed(parser.parseRParen()) || failed(parser.parseGreater())) + if (failed(parser.parseRParen())) + return Type(); + + SmallVector<StructType::StructDecorationInfo, 1> structDecorationInfo; + + auto parseStructDecoration = [&]() { + std::optional<spirv::Decoration> decoration = + parseAndVerify<spirv::Decoration>(dialect, parser); + if (!decoration) + return failure(); + + // Parse decoration value if it exists. + if (succeeded(parser.parseOptionalEqual())) { + Attribute decorationValue; + if (failed(parser.parseAttribute(decorationValue))) + return failure(); + + structDecorationInfo.emplace_back(decoration.value(), decorationValue); + } else { + structDecorationInfo.emplace_back(decoration.value(), + UnitAttr::get(dialect.getContext())); + } + return success(); + }; + + while (succeeded(parser.parseOptionalComma())) + if (failed(parseStructDecoration())) + return Type(); + + if (failed(parser.parseGreater())) return Type(); if (!identifier.empty()) { if (failed(idStructTy.trySetBody(memberTypes, offsetInfo, - memberDecorationInfo))) + memberDecorationInfo, + structDecorationInfo))) return Type(); return idStructTy; } - return StructType::get(memberTypes, offsetInfo, memberDecorationInfo); + return StructType::get(memberTypes, offsetInfo, memberDecorationInfo, + structDecorationInfo); } // spirv-type ::= array-type @@ -893,7 +926,23 @@ static void print(StructType type, DialectAsmPrinter &os) { }; llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os, printMember); - os << ")>"; + os << ")"; + + SmallVector<spirv::StructType::StructDecorationInfo, 1> decorations; + type.getStructDecorations(decorations); + if (!decorations.empty()) { + os << ", "; + auto eachFn = [&os](spirv::StructType::StructDecorationInfo decoration) { + os << stringifyDecoration(decoration.decoration); + if (decoration.hasValue()) { + os << "="; + os.printAttributeWithoutType(decoration.decorationValue); + } + }; + llvm::interleaveComma(decorations, os, eachFn); + } + + os << ">"; } static void print(CooperativeMatrixType type, DialectAsmPrinter &os) { diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 52c672a..f993398 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -767,19 +767,25 @@ void mlir::spirv::AddressOfOp::getAsmResultNames( // spirv.EXTConstantCompositeReplicate //===----------------------------------------------------------------------===// +// Returns type of attribute. In case of a TypedAttr this will simply return +// the type. But for an ArrayAttr which is untyped and can be multidimensional +// it creates the ArrayType recursively. +static Type getValueType(Attribute attr) { + if (auto typedAttr = dyn_cast<TypedAttr>(attr)) { + return typedAttr.getType(); + } + + if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) { + return spirv::ArrayType::get(getValueType(arrayAttr[0]), arrayAttr.size()); + } + + return nullptr; +} + LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() { - Type valueType; - if (auto typedAttr = dyn_cast<TypedAttr>(getValue())) { - valueType = typedAttr.getType(); - } else if (auto arrayAttr = dyn_cast<ArrayAttr>(getValue())) { - auto typedElemAttr = dyn_cast<TypedAttr>(arrayAttr[0]); - if (!typedElemAttr) - return emitError("value attribute is not typed"); - valueType = - spirv::ArrayType::get(typedElemAttr.getType(), arrayAttr.size()); - } else { + Type valueType = getValueType(getValue()); + if (!valueType) return emitError("unknown value attribute type"); - } auto compositeType = dyn_cast<spirv::CompositeType>(getType()); if (!compositeType) diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp index 46739bc..ddb3426 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -835,12 +835,14 @@ void SampledImageType::getCapabilities( /// - for literal structs: /// - a list of member types; /// - a list of member offset info; -/// - a list of member decoration info. +/// - a list of member decoration info; +/// - a list of struct decoration info. /// /// Identified structures only have a mutable component consisting of: /// - a list of member types; /// - a list of member offset info; -/// - a list of member decoration info. +/// - a list of member decoration info; +/// - a list of struct decoration info. struct spirv::detail::StructTypeStorage : public TypeStorage { /// Construct a storage object for an identified struct type. A struct type /// associated with such storage must call StructType::trySetBody(...) later @@ -848,6 +850,7 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { StructTypeStorage(StringRef identifier) : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr), numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr), + numStructDecorations(0), structDecorationsInfo(nullptr), identifier(identifier) {} /// Construct a storage object for a literal struct type. A struct type @@ -855,10 +858,14 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { StructTypeStorage( unsigned numMembers, Type const *memberTypes, StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, - StructType::MemberDecorationInfo const *memberDecorationsInfo) + StructType::MemberDecorationInfo const *memberDecorationsInfo, + unsigned numStructDecorations, + StructType::StructDecorationInfo const *structDecorationsInfo) : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo), numMembers(numMembers), numMemberDecorations(numMemberDecorations), - memberDecorationsInfo(memberDecorationsInfo) {} + memberDecorationsInfo(memberDecorationsInfo), + numStructDecorations(numStructDecorations), + structDecorationsInfo(structDecorationsInfo) {} /// A storage key is divided into 2 parts: /// - for identified structs: @@ -867,16 +874,19 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { /// - an ArrayRef<Type> for member types; /// - an ArrayRef<StructType::OffsetInfo> for member offset info; /// - an ArrayRef<StructType::MemberDecorationInfo> for member decoration + /// info; + /// - an ArrayRef<StructType::StructDecorationInfo> for struct decoration /// info. /// /// An identified struct type is uniqued only by the first part (field 0) /// of the key. /// - /// A literal struct type is uniqued only by the second part (fields 1, 2, and - /// 3) of the key. The identifier field (field 0) must be empty. + /// A literal struct type is uniqued only by the second part (fields 1, 2, 3 + /// and 4) of the key. The identifier field (field 0) must be empty. using KeyTy = std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>, - ArrayRef<StructType::MemberDecorationInfo>>; + ArrayRef<StructType::MemberDecorationInfo>, + ArrayRef<StructType::StructDecorationInfo>>; /// For identified structs, return true if the given key contains the same /// identifier. @@ -890,7 +900,7 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { } return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(), - getMemberDecorationsInfo()); + getMemberDecorationsInfo(), getStructDecorationsInfo()); } /// If the given key contains a non-empty identifier, this method constructs @@ -937,9 +947,17 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { memberDecorationList = allocator.copyInto(keyMemberDecorations).data(); } - return new (allocator.allocate<StructTypeStorage>()) - StructTypeStorage(keyTypes.size(), typesList, offsetInfoList, - numMemberDecorations, memberDecorationList); + const StructType::StructDecorationInfo *structDecorationList = nullptr; + unsigned numStructDecorations = 0; + if (!std::get<4>(key).empty()) { + auto keyStructDecorations = std::get<4>(key); + numStructDecorations = keyStructDecorations.size(); + structDecorationList = allocator.copyInto(keyStructDecorations).data(); + } + + return new (allocator.allocate<StructTypeStorage>()) StructTypeStorage( + keyTypes.size(), typesList, offsetInfoList, numMemberDecorations, + memberDecorationList, numStructDecorations, structDecorationList); } ArrayRef<Type> getMemberTypes() const { @@ -961,6 +979,13 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { return {}; } + ArrayRef<StructType::StructDecorationInfo> getStructDecorationsInfo() const { + if (structDecorationsInfo) + return ArrayRef<StructType::StructDecorationInfo>(structDecorationsInfo, + numStructDecorations); + return {}; + } + StringRef getIdentifier() const { return identifier; } bool isIdentified() const { return !identifier.empty(); } @@ -973,17 +998,19 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { /// - If called for an identified struct whose body was set before (through a /// call to this method) but with different contents from the passed /// arguments. - LogicalResult mutate( - TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes, - ArrayRef<StructType::OffsetInfo> structOffsetInfo, - ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo) { + LogicalResult + mutate(TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes, + ArrayRef<StructType::OffsetInfo> structOffsetInfo, + ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo, + ArrayRef<StructType::StructDecorationInfo> structDecorationInfo) { if (!isIdentified()) return failure(); if (memberTypesAndIsBodySet.getInt() && (getMemberTypes() != structMemberTypes || getOffsetInfo() != structOffsetInfo || - getMemberDecorationsInfo() != structMemberDecorationInfo)) + getMemberDecorationsInfo() != structMemberDecorationInfo || + getStructDecorationsInfo() != structDecorationInfo)) return failure(); memberTypesAndIsBodySet.setInt(true); @@ -1007,6 +1034,11 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { allocator.copyInto(structMemberDecorationInfo).data(); } + if (!structDecorationInfo.empty()) { + numStructDecorations = structDecorationInfo.size(); + structDecorationsInfo = allocator.copyInto(structDecorationInfo).data(); + } + return success(); } @@ -1015,21 +1047,30 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { unsigned numMembers; unsigned numMemberDecorations; StructType::MemberDecorationInfo const *memberDecorationsInfo; + unsigned numStructDecorations; + StructType::StructDecorationInfo const *structDecorationsInfo; StringRef identifier; }; StructType StructType::get(ArrayRef<Type> memberTypes, ArrayRef<StructType::OffsetInfo> offsetInfo, - ArrayRef<StructType::MemberDecorationInfo> memberDecorations) { + ArrayRef<StructType::MemberDecorationInfo> memberDecorations, + ArrayRef<StructType::StructDecorationInfo> structDecorations) { assert(!memberTypes.empty() && "Struct needs at least one member type"); // Sort the decorations. - SmallVector<StructType::MemberDecorationInfo, 4> sortedDecorations( + SmallVector<StructType::MemberDecorationInfo, 4> sortedMemberDecorations( memberDecorations); - llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end()); + llvm::array_pod_sort(sortedMemberDecorations.begin(), + sortedMemberDecorations.end()); + SmallVector<StructType::StructDecorationInfo, 1> sortedStructDecorations( + structDecorations); + llvm::array_pod_sort(sortedStructDecorations.begin(), + sortedStructDecorations.end()); + return Base::get(memberTypes.vec().front().getContext(), /*identifier=*/StringRef(), memberTypes, offsetInfo, - sortedDecorations); + sortedMemberDecorations, sortedStructDecorations); } StructType StructType::getIdentified(MLIRContext *context, @@ -1039,18 +1080,21 @@ StructType StructType::getIdentified(MLIRContext *context, return Base::get(context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(), - ArrayRef<StructType::MemberDecorationInfo>()); + ArrayRef<StructType::MemberDecorationInfo>(), + ArrayRef<StructType::StructDecorationInfo>()); } StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) { StructType newStructType = Base::get( context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(), - ArrayRef<StructType::MemberDecorationInfo>()); + ArrayRef<StructType::MemberDecorationInfo>(), + ArrayRef<StructType::StructDecorationInfo>()); // Set an empty body in case this is a identified struct. if (newStructType.isIdentified() && failed(newStructType.trySetBody( ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(), - ArrayRef<StructType::MemberDecorationInfo>()))) + ArrayRef<StructType::MemberDecorationInfo>(), + ArrayRef<StructType::StructDecorationInfo>()))) return StructType(); return newStructType; @@ -1074,6 +1118,15 @@ TypeRange StructType::getElementTypes() const { bool StructType::hasOffset() const { return getImpl()->offsetInfo; } +bool StructType::hasDecoration(spirv::Decoration decoration) const { + for (StructType::StructDecorationInfo info : + getImpl()->getStructDecorationsInfo()) + if (info.decoration == decoration) + return true; + + return false; +} + uint64_t StructType::getMemberOffset(unsigned index) const { assert(getNumElements() > index && "member index out of range"); return getImpl()->offsetInfo[index]; @@ -1105,11 +1158,21 @@ void StructType::getMemberDecorations( } } +void StructType::getStructDecorations( + SmallVectorImpl<StructType::StructDecorationInfo> &structDecorations) + const { + structDecorations.clear(); + auto implDecorations = getImpl()->getStructDecorationsInfo(); + structDecorations.append(implDecorations.begin(), implDecorations.end()); +} + LogicalResult StructType::trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo, - ArrayRef<MemberDecorationInfo> memberDecorations) { - return Base::mutate(memberTypes, offsetInfo, memberDecorations); + ArrayRef<MemberDecorationInfo> memberDecorations, + ArrayRef<StructDecorationInfo> structDecorations) { + return Base::mutate(memberTypes, offsetInfo, memberDecorations, + structDecorations); } void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, @@ -1131,6 +1194,11 @@ llvm::hash_code spirv::hash_value( memberDecorationInfo.decoration); } +llvm::hash_code spirv::hash_value( + const StructType::StructDecorationInfo &structDecorationInfo) { + return llvm::hash_value(structDecorationInfo.decoration); +} + //===----------------------------------------------------------------------===// // MatrixType //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp index 81365b4..3911ec0 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -58,7 +58,17 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp, spirv::PointerType::get(spirv::StructType::get(varType), *storageClass); } auto varPtrType = cast<spirv::PointerType>(varType); - auto varPointeeType = cast<spirv::StructType>(varPtrType.getPointeeType()); + Type pointeeType = varPtrType.getPointeeType(); + + // Images are an opaque type and so we can just return a pointer to an image. + // Note that currently only sampled images are supported in the SPIR-V + // lowering. + if (isa<spirv::SampledImageType>(pointeeType)) + return spirv::GlobalVariableOp::create(builder, funcOp.getLoc(), varType, + varName, abiInfo.getDescriptorSet(), + abiInfo.getBinding()); + + auto varPointeeType = cast<spirv::StructType>(pointeeType); // Set the offset information. varPointeeType = diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 35ec019..8f4c4cc 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -182,6 +182,14 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) { return bitWidth / 8; } + // Handle 8-bit floats. + if (options.emulateUnsupportedFloatTypes && isa<FloatType>(type)) { + auto bitWidth = type.getIntOrFloatBitWidth(); + if (bitWidth == 8) + return bitWidth / 8; + return std::nullopt; + } + if (auto complexType = dyn_cast<ComplexType>(type)) { auto elementSize = getTypeNumBytes(options, complexType.getElementType()); if (!elementSize) @@ -318,6 +326,44 @@ static Type convertSubByteIntegerType(const SPIRVConversionOptions &options, type.getSignedness()); } +/// Converts 8-bit float types to integer types with the same bit width. +/// Returns a nullptr for unsupported 8-bit float types. +static Type convert8BitFloatType(const SPIRVConversionOptions &options, + FloatType type) { + if (!options.emulateUnsupportedFloatTypes) + return nullptr; + // F8 types are converted to integer types with the same bit width. + if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType, + Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type, + Float8E8M0FNUType>(type)) + return IntegerType::get(type.getContext(), type.getWidth()); + LLVM_DEBUG(llvm::dbgs() << "unsupported 8-bit float type: " << type << "\n"); + return nullptr; +} + +/// Returns a type with the same shape but with any 8-bit float element type +/// converted to the same bit width integer type. This is a noop when the +/// element type is not the 8-bit float type or emulation flag is set to false. +static ShapedType +convertShaped8BitFloatType(ShapedType type, + const SPIRVConversionOptions &options) { + if (!options.emulateUnsupportedFloatTypes) + return type; + Type srcElementType = type.getElementType(); + Type convertedElementType = nullptr; + // F8 types are converted to integer types with the same bit width. + if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType, + Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type, + Float8E8M0FNUType>(srcElementType)) + convertedElementType = IntegerType::get( + type.getContext(), srcElementType.getIntOrFloatBitWidth()); + + if (!convertedElementType) + return type; + + return type.clone(convertedElementType); +} + /// Returns a type with the same shape but with any index element type converted /// to the matching integer type. This is a noop when the element type is not /// the index type. @@ -337,6 +383,7 @@ convertVectorType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, VectorType type, std::optional<spirv::StorageClass> storageClass = {}) { type = cast<VectorType>(convertIndexElementType(type, options)); + type = cast<VectorType>(convertShaped8BitFloatType(type, options)); auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType()); if (!scalarType) { // If this is not a spec allowed scalar type, try to handle sub-byte integer @@ -433,6 +480,7 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv, } type = cast<TensorType>(convertIndexElementType(type, options)); + type = cast<TensorType>(convertShaped8BitFloatType(type, options)); auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType()); if (!scalarType) { LLVM_DEBUG(llvm::dbgs() @@ -596,6 +644,10 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv, } else if (auto indexType = dyn_cast<IndexType>(elementType)) { type = cast<MemRefType>(convertIndexElementType(type, options)); arrayElemType = type.getElementType(); + } else if (auto floatType = dyn_cast<FloatType>(elementType)) { + // Hnadle 8 bit float types. + type = cast<MemRefType>(convertShaped8BitFloatType(type, options)); + arrayElemType = type.getElementType(); } else { LLVM_DEBUG( llvm::dbgs() @@ -1444,6 +1496,8 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, addConversion([this](FloatType floatType) -> std::optional<Type> { if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType)) return convertScalarType(this->targetEnv, this->options, scalarType); + if (floatType.getWidth() == 8) + return convert8BitFloatType(this->options, floatType); return Type(); }); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp index 6a9b951..a53d0a7 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp @@ -174,6 +174,21 @@ void UpdateVCEPass::runOnOperation() { if (walkResult.wasInterrupted()) return signalPassFailure(); + // Update min version requirement for capabilities after deducing them. + for (spirv::Capability cap : deducedCapabilities) { + if (std::optional<spirv::Version> minVersion = spirv::getMinVersion(cap)) { + deducedVersion = std::max(deducedVersion, *minVersion); + if (deducedVersion > allowedVersion) { + module.emitError("Capability '") + << spirv::stringifyCapability(cap) << "' requires min version " + << spirv::stringifyVersion(deducedVersion) + << " but target environment allows up to " + << spirv::stringifyVersion(allowedVersion); + return signalPassFailure(); + } + } + } + // TODO: verify that the deduced version is consistent with // SPIR-V ops' maximal version requirements. diff --git a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp index e5a3b5d..08fccfa 100644 --- a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp +++ b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp @@ -38,7 +38,6 @@ #include <utility> #define DEBUG_TYPE "shard-ops" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") using namespace mlir; using namespace mlir::shard; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp index 0e96b59..869d27a 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp @@ -115,8 +115,7 @@ public: bufferization::BufferizationState bufferizationState; - if (failed(bufferization::bufferizeModuleOp(cast<ModuleOp>(getOperation()), - updatedOptions, + if (failed(bufferization::bufferizeModuleOp(getOperation(), updatedOptions, bufferizationState))) return failure(); diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 6d2cbb5..e3cba388 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -452,18 +452,14 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> { auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType()); auto inputElementType = inputType.getElementType(); - if (!inputType.hasStaticShape()) { - return failure(); - } - if (isa<FloatType>(inputElementType)) { // Unlike integer types, floating point types can represent infinity. - auto minClamp = + const auto minClamp = llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue(); - auto maxClamp = + const auto maxClamp = llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue(); - bool isMin = minClamp.isNegInfinity(); - bool isMax = maxClamp.isInfinity(); + const bool isMin = minClamp.isNegInfinity(); + const bool isMax = maxClamp.isInfinity(); if (isMin && isMax) { rewriter.replaceOp(op, input); @@ -472,18 +468,19 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> { return failure(); } - if (inputElementType.isUnsignedInteger()) { - int64_t minClamp = - llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getUInt(); - int64_t maxClamp = - llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getUInt(); + // i1 types are boolean in TOSA + const bool isBoolean = inputElementType.isInteger(1); + if (inputElementType.isUnsignedInteger() || isBoolean) { + const int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()) + .getValue() + .getZExtValue(); + const int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()) + .getValue() + .getZExtValue(); - int64_t intMin = - APInt::getMinValue(inputElementType.getIntOrFloatBitWidth()) - .getZExtValue(); - int64_t intMax = - APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth()) - .getZExtValue(); + const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth(); + const int64_t intMin = APInt::getMinValue(bitWidth).getZExtValue(); + const int64_t intMax = APInt::getMaxValue(bitWidth).getZExtValue(); if (minClamp <= intMin && maxClamp >= intMax) { rewriter.replaceOp(op, input); @@ -493,17 +490,14 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> { } if (llvm::isa<IntegerType>(inputElementType)) { - int64_t minClamp = + const int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt(); - int64_t maxClamp = + const int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt(); - int64_t intMin = - APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth()) - .getSExtValue(); - int64_t intMax = - APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth()) - .getSExtValue(); + const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth(); + const int64_t intMin = APInt::getSignedMinValue(bitWidth).getSExtValue(); + const int64_t intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue(); if (minClamp <= intMin && maxClamp >= intMax) { rewriter.replaceOp(op, input); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp index 88b0f36..9543fa1 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp @@ -464,9 +464,12 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) { CheckCondition condition = CheckCondition::invalid; const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition); const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition); + if (failed(maybeProfDef) && failed(maybeExtDef)) + return success(); - if (!failed(maybeProfDef) && !failed(maybeExtDef) && - !maybeProfDef.value().size() && !maybeExtDef.value().size()) { + const bool hasEntry = (succeeded(maybeProfDef) && !maybeProfDef->empty()) || + (succeeded(maybeExtDef) && !maybeExtDef->empty()); + if (!hasEntry) { std::string message; llvm::raw_string_ostream os(message); os << "illegal: operation operand/result data types did not align with any " diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index 8ec7765..c7b9534 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -1381,7 +1381,7 @@ void TosaValidation::runOnOperation() { // Some uses of TOSA rely on the constant operands of particular // operations. - if (strictOpSpecAlignment && failed(applyConstantOperandCheck(op))) + if (failed(applyConstantOperandCheck(op))) signalPassFailure(); // do level checks diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index bce358d..a450056 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1258,63 +1258,6 @@ void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results, CanonicalizeContractAdd<arith::AddFOp>>(context); } -//===----------------------------------------------------------------------===// -// ExtractElementOp -//===----------------------------------------------------------------------===// - -void ExtractElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, - SetIntRangeFn setResultRanges) { - setResultRanges(getResult(), argRanges.front()); -} - -void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result, - Value source) { - result.addOperands({source}); - result.addTypes(llvm::cast<VectorType>(source.getType()).getElementType()); -} - -LogicalResult vector::ExtractElementOp::verify() { - VectorType vectorType = getSourceVectorType(); - if (vectorType.getRank() == 0) { - if (getPosition()) - return emitOpError("expected position to be empty with 0-D vector"); - return success(); - } - if (vectorType.getRank() != 1) - return emitOpError("unexpected >1 vector rank"); - if (!getPosition()) - return emitOpError("expected position for 1-D vector"); - return success(); -} - -OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) { - // Skip the 0-D vector here now. - if (!adaptor.getPosition()) - return {}; - - // Fold extractelement (splat X) -> X. - if (auto splat = getVector().getDefiningOp<vector::SplatOp>()) - return splat.getInput(); - - // Fold extractelement(broadcast(X)) -> X. - if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>()) - if (!llvm::isa<VectorType>(broadcast.getSource().getType())) - return broadcast.getSource(); - - auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector()); - auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition()); - if (!pos || !src) - return {}; - - auto srcElements = src.getValues<Attribute>(); - - uint64_t posIdx = pos.getInt(); - if (posIdx >= srcElements.size()) - return {}; - - return srcElements[posIdx]; -} - // Returns `true` if `index` is either within [0, maxIndex) or equal to // `poisonValue`. static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue, @@ -2533,17 +2476,19 @@ OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) { return {}; } -/// Rewrite a vector.from_elements into a vector.splat if all elements are the -/// same SSA value. E.g.: -/// -/// %0 = vector.from_elements %a, %a, %a : vector<3xf32> -/// ==> rewrite to vector.splat %a : vector<3xf32> -static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp, - PatternRewriter &rewriter) { +/// Rewrite vector.from_elements as vector.broadcast if the elements are the +/// same. Example: +/// %0 = vector.from_elements %a, %a, %a : vector<3xf32> +/// => +/// %0 = vector.broadcast %a : f32 to vector<3xf32> +static LogicalResult +rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp, + PatternRewriter &rewriter) { if (!llvm::all_equal(fromElementsOp.getElements())) return failure(); - rewriter.replaceOpWithNewOp<SplatOp>(fromElementsOp, fromElementsOp.getType(), - fromElementsOp.getElements().front()); + rewriter.replaceOpWithNewOp<BroadcastOp>( + fromElementsOp, fromElementsOp.getType(), + fromElementsOp.getElements().front()); return success(); } @@ -2574,7 +2519,7 @@ class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> { LogicalResult matchAndRewrite(FromElementsOp fromElements, PatternRewriter &rewriter) const override { - // Handled by `rewriteFromElementsAsSplat` + // Handled by `rewriteFromElementsAsBroadcast`. if (fromElements.getType().getNumElements() == 1) return failure(); @@ -2667,7 +2612,7 @@ class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> { void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(rewriteFromElementsAsSplat); + results.add(rewriteFromElementsAsBroadcast); results.add<FromElementsToShapeCast>(context); } @@ -3115,23 +3060,50 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> { } }; -/// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp. +/// Consider the defining operation `defOp` of `value`. If `defOp` is a +/// vector.splat or a vector.broadcast with a scalar operand, return the scalar +/// value that is splatted. Otherwise return null. +/// +/// Examples: +/// +/// scalar_source --> vector.splat --> value - return scalar_source +/// scalar_source --> vector.broadcast --> value - return scalar_source +static Value getScalarSplatSource(Value value) { + // Block argument: + Operation *defOp = value.getDefiningOp(); + if (!defOp) + return {}; + + // Splat: + if (auto splat = dyn_cast<vector::SplatOp>(defOp)) + return splat.getInput(); + + auto broadcast = dyn_cast<vector::BroadcastOp>(defOp); + + // Not broadcast (and not splat): + if (!broadcast) + return {}; + + // Broadcast of a vector: + if (isa<VectorType>(broadcast.getSourceType())) + return {}; + + // Broadcast of a scalar: + return broadcast.getSource(); +} + +/// Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v). class ShuffleSplat final : public OpRewritePattern<ShuffleOp> { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ShuffleOp op, PatternRewriter &rewriter) const override { - auto v1Splat = op.getV1().getDefiningOp<SplatOp>(); - auto v2Splat = op.getV2().getDefiningOp<SplatOp>(); - - if (!v1Splat || !v2Splat) - return failure(); - - if (v1Splat.getInput() != v2Splat.getInput()) + Value splat = getScalarSplatSource(op.getV1()); + if (!splat || getScalarSplatSource(op.getV2()) != splat) return failure(); - rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), v1Splat.getInput()); + rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat); return success(); } }; @@ -3184,60 +3156,6 @@ void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results, } //===----------------------------------------------------------------------===// -// InsertElementOp -//===----------------------------------------------------------------------===// - -void InsertElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, - SetIntRangeFn setResultRanges) { - setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1])); -} - -void InsertElementOp::build(OpBuilder &builder, OperationState &result, - Value source, Value dest) { - build(builder, result, source, dest, {}); -} - -LogicalResult InsertElementOp::verify() { - auto dstVectorType = getDestVectorType(); - if (dstVectorType.getRank() == 0) { - if (getPosition()) - return emitOpError("expected position to be empty with 0-D vector"); - return success(); - } - if (dstVectorType.getRank() != 1) - return emitOpError("unexpected >1 vector rank"); - if (!getPosition()) - return emitOpError("expected position for 1-D vector"); - return success(); -} - -OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) { - // Skip the 0-D vector here. - if (!adaptor.getPosition()) - return {}; - - auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource()); - auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest()); - auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition()); - if (!src || !dst || !pos) - return {}; - - if (src.getType() != getDestVectorType().getElementType()) - return {}; - - auto dstElements = dst.getValues<Attribute>(); - - SmallVector<Attribute> results(dstElements); - - uint64_t posIdx = pos.getInt(); - if (posIdx >= results.size()) - return {}; - results[posIdx] = src; - - return DenseElementsAttr::get(getDestVectorType(), results); -} - -//===----------------------------------------------------------------------===// // InsertOp //===----------------------------------------------------------------------===// @@ -3341,23 +3259,19 @@ public: } }; -/// Pattern to rewrite a InsertOp(SplatOp, SplatOp) to SplatOp. +/// Pattern to rewrite a insert(splat-like(v), splat-like(v)) as broadcast(v). class InsertSplatToSplat final : public OpRewritePattern<InsertOp> { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(InsertOp op, PatternRewriter &rewriter) const override { - auto srcSplat = op.getValueToStore().getDefiningOp<SplatOp>(); - auto dstSplat = op.getDest().getDefiningOp<SplatOp>(); - - if (!srcSplat || !dstSplat) - return failure(); - if (srcSplat.getInput() != dstSplat.getInput()) + Value splat = getScalarSplatSource(op.getValueToStore()); + if (!splat || getScalarSplatSource(op.getDest()) != splat) return failure(); - rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput()); + rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat); return success(); } }; @@ -3625,8 +3539,7 @@ LogicalResult InsertStridedSliceOp::verify() { } namespace { -/// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type, -/// SplatOp(X):dst_type) to SplatOp(X):dst_type. +/// Rewrite insert_strided_slice(splat-like(v), splat-like(v)) as v. class FoldInsertStridedSliceSplat final : public OpRewritePattern<InsertStridedSliceOp> { public: @@ -3634,18 +3547,13 @@ public: LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp, PatternRewriter &rewriter) const override { - auto srcSplatOp = - insertStridedSliceOp.getValueToStore().getDefiningOp<vector::SplatOp>(); - auto destSplatOp = - insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>(); - if (!srcSplatOp || !destSplatOp) + auto dst = insertStridedSliceOp.getDest(); + auto splat = getScalarSplatSource(insertStridedSliceOp.getValueToStore()); + if (!splat || getScalarSplatSource(dst) != splat) return failure(); - if (srcSplatOp.getInput() != destSplatOp.getInput()) - return failure(); - - rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest()); + rewriter.replaceOp(insertStridedSliceOp, dst); return success(); } }; @@ -4300,17 +4208,18 @@ public: } }; -/// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp. +/// Rewrite extract_strided_slice(splat-like(v)) with broadcast(v). class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override { - auto splat = op.getVector().getDefiningOp<SplatOp>(); + + Value splat = getScalarSplatSource(op.getVector()); if (!splat) return failure(); - rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput()); + rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat); return success(); } }; @@ -6027,14 +5936,13 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) { } // shape_cast(constant) -> constant - if (auto splatAttr = - llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) - return splatAttr.reshape(getType()); + if (auto denseAttr = + dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource())) + return denseAttr.reshape(getType()); // shape_cast(poison) -> poison - if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) { + if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) return ub::PoisonAttr::get(getContext()); - } return {}; } @@ -6427,6 +6335,11 @@ std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() { return llvm::to_vector<4>(getResultVectorType().getShape()); } +void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, + SetIntRangeFn setResultRanges) { + setResultRanges(getResult(), argRanges.front()); +} + namespace { // Rewrites two back-to-back TransposeOp operations into a single TransposeOp. @@ -6461,19 +6374,19 @@ public: } }; -// Folds transpose(splat x : src_type) : res_type into splat x : res_type. +/// Replace transpose(splat-like(v)) with broadcast(v) class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TransposeOp transposeOp, PatternRewriter &rewriter) const override { - auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>(); - if (!splatOp) + Value splat = getScalarSplatSource(transposeOp.getVector()); + if (!splat) return failure(); - rewriter.replaceOpWithNewOp<vector::SplatOp>( - transposeOp, transposeOp.getResultVectorType(), splatOp.getInput()); + rewriter.replaceOpWithNewOp<vector::BroadcastOp>( + transposeOp, transposeOp.getResultVectorType(), splat); return success(); } }; @@ -7224,6 +7137,23 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) { return SplatElementsAttr::get(getType(), {constOperand}); } +// Canonicalizer for vector.splat. It always gets canonicalized to a +// vector.broadcast. +class SplatToBroadcastPattern final : public OpRewritePattern<SplatOp> { +public: + using OpRewritePattern<SplatOp>::OpRewritePattern; + LogicalResult matchAndRewrite(SplatOp splatOp, + PatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(), + splatOp.getOperand()); + return success(); + } +}; +void SplatOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add<SplatToBroadcastPattern>(context); +} + void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) { setResultRanges(getResult(), argRanges.front()); @@ -7309,6 +7239,23 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc, } //===----------------------------------------------------------------------===// +// StepOp +//===----------------------------------------------------------------------===// + +void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, + SetIntRangeFn setResultRanges) { + auto resultType = cast<VectorType>(getType()); + if (resultType.isScalable()) { + return; + } + unsigned bitwidth = ConstantIntRanges::getStorageBitwidth(resultType); + APInt zero(bitwidth, 0); + APInt high(bitwidth, resultType.getDimSize(0) - 1); + ConstantIntRanges result = {zero, high, zero, high}; + setResultRanges(getResult(), result); +} + +//===----------------------------------------------------------------------===// // Vector Masking Utilities //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp index cb8e566..dedc3b3 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp @@ -28,7 +28,10 @@ using namespace mlir; using namespace mlir::vector; namespace { -/// Progressive lowering of BroadcastOp. + +/// Convert a vector.broadcast with a vector operand to a lower rank +/// vector.broadcast. vector.broadcast with a scalar operand is expected to be +/// convertible to the lower level target dialect (LLVM, SPIR-V, etc.) directly. class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> { public: using OpRewritePattern::OpRewritePattern; @@ -40,20 +43,23 @@ public: VectorType srcType = dyn_cast<VectorType>(op.getSourceType()); Type eltType = dstType.getElementType(); - // Scalar to any vector can use splat. - if (!srcType) { - rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.getSource()); - return success(); - } + // A broadcast from a scalar is considered to be in the lowered form. + if (!srcType) + return rewriter.notifyMatchFailure( + op, "broadcast from scalar already in lowered form"); // Determine rank of source and destination. int64_t srcRank = srcType.getRank(); int64_t dstRank = dstType.getRank(); - // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat. + // Here we are broadcasting to a rank-1 vector. Ensure that the source is a + // scalar. if (srcRank <= 1 && dstRank == 1) { - Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource()); - rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext); + SmallVector<int64_t> fullRankPosition(srcRank, 0); + Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(), + fullRankPosition); + assert(!isa<VectorType>(ext.getType()) && "expected scalar"); + rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, dstType, ext); return success(); } diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp index 4baeb11..2cf8f0b 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -468,7 +468,7 @@ struct TransferReadToVectorLoadLowering read, "vector type is not rank 1, can't create masked load, needs " "VectorToSCF"); - Value fill = vector::SplatOp::create( + Value fill = vector::BroadcastOp::create( rewriter, read.getLoc(), unbroadcastedVectorType, read.getPadding()); res = vector::MaskedLoadOp::create( rewriter, read.getLoc(), unbroadcastedVectorType, read.getBase(), diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp index 72352d7..cbb9d4b 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp @@ -303,7 +303,7 @@ public: // Extract/insert on a lower ranked extract strided slice op. Value zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getZeroAttr(elemType)); - Value res = SplatOp::create(rewriter, loc, dstType, zero); + Value res = BroadcastOp::create(rewriter, loc, dstType, zero); for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; off += stride, ++idx) { Value one = ExtractOp::create(rewriter, loc, op.getVector(), off); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index 48d680c..c707f38 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -25,12 +25,10 @@ #include "mlir/Interfaces/SideEffectInterfaces.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #define DEBUG_TYPE "vector-transfer-opt" -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") - using namespace mlir; /// Return the ancestor op in the region or nullptr if the region is not @@ -88,8 +86,7 @@ bool TransferOptimization::isReachable(Operation *start, Operation *dest) { /// transfer_write is dead if all reads that can be reached from the potentially /// dead transfer_write are dominated by the overwriting transfer_write. void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) { - LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation() - << "\n"); + LDBG() << "Candidate for dead store: " << *write.getOperation(); llvm::SmallVector<Operation *, 8> blockingAccesses; Operation *firstOverwriteCandidate = nullptr; Value source = memref::skipViewLikeOps(cast<MemrefValue>(write.getBase())); @@ -150,13 +147,12 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) { !isReachable(writeAncestor, accessAncestor)) continue; if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) { - LLVM_DEBUG(DBGS() << "Store may not be dead due to op: " - << *accessAncestor << "\n"); + LDBG() << "Store may not be dead due to op: " << *accessAncestor; return; } } - LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation() - << " overwritten by: " << *firstOverwriteCandidate << "\n"); + LDBG() << "Found dead store: " << *write.getOperation() + << " overwritten by: " << *firstOverwriteCandidate; opToErase.push_back(write.getOperation()); } @@ -174,8 +170,7 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) { void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) { if (read.hasOutOfBoundsDim()) return; - LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation() - << "\n"); + LDBG() << "Candidate for Forwarding: " << *read.getOperation(); SmallVector<Operation *, 8> blockingWrites; vector::TransferWriteOp lastwrite = nullptr; Value source = memref::skipViewLikeOps(cast<MemrefValue>(read.getBase())); @@ -230,14 +225,13 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) { if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor)) continue; if (!postDominators.postDominates(lastwrite, write)) { - LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: " - << *write << "\n"); + LDBG() << "Fail to do write to read forwarding due to op: " << *write; return; } } - LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation() - << " to: " << *read.getOperation() << "\n"); + LDBG() << "Forward value from " << *lastwrite.getOperation() + << " to: " << *read.getOperation(); read.replaceAllUsesWith(lastwrite.getVector()); opToErase.push_back(read.getOperation()); } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 8de87fe..2269a40 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -939,7 +939,7 @@ public: Value zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getZeroAttr(elemType)); - Value res = SplatOp::create(rewriter, loc, castDstType, zero); + Value res = BroadcastOp::create(rewriter, loc, castDstType, zero); SmallVector<int64_t> sliceShape = {castDstLastDim}; SmallVector<int64_t> strides = {1}; @@ -965,6 +965,45 @@ private: std::function<bool(BitCastOp)> controlFn; }; +static bool haveSameShapeAndScaling(Type t, Type u) { + auto tVec = dyn_cast<VectorType>(t); + auto uVec = dyn_cast<VectorType>(u); + if (!tVec) { + return !uVec; + } + if (!uVec) { + return false; + } + return tVec.getShape() == uVec.getShape() && + tVec.getScalableDims() == uVec.getScalableDims(); +} + +/// If `type` is shaped, clone it with `newElementType`. Otherwise, +/// return `newElementType`. +static Type cloneOrReplace(Type type, Type newElementType) { + if (auto shapedType = dyn_cast<ShapedType>(type)) { + return shapedType.clone(newElementType); + } + return newElementType; +} + +/// If `value` is the result of a splat or broadcast operation, return the input +/// of the splat/broadcast operation. +static Value getBroadcastLikeSource(Value value) { + + Operation *op = value.getDefiningOp(); + if (!op) + return {}; + + if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) + return broadcast.getSource(); + + if (auto splat = dyn_cast<vector::SplatOp>(op)) + return splat.getInput(); + + return {}; +} + /// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex: /// /// Example: @@ -988,16 +1027,14 @@ struct ReorderElementwiseOpsOnBroadcast final PatternRewriter &rewriter) const override { if (op->getNumResults() != 1) return failure(); - if (!llvm::isa<ShapedType>(op->getResults()[0].getType())) + auto resultType = dyn_cast<VectorType>(op->getResult(0).getType()); + if (!resultType) return failure(); if (!OpTrait::hasElementwiseMappableTraits(op)) return rewriter.notifyMatchFailure( op, "Op doesn't have ElementwiseMappableTraits"); if (op->getNumOperands() == 0) return failure(); - if (op->getResults()[0].getType() != op->getOperand(0).getType()) - return rewriter.notifyMatchFailure(op, - "result and operand type mismatch"); if (isa<vector::FMAOp>(op)) { return rewriter.notifyMatchFailure( op, @@ -1005,45 +1042,71 @@ struct ReorderElementwiseOpsOnBroadcast final "might be a scalar"); } - // Get the type of the lhs operand - auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp(); - if (!lhsBcastOrSplat || - !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat)) + Type resultElemType = resultType.getElementType(); + + // Get the type of the first non-constant operand + Value splatSource; + for (Value operand : op->getOperands()) { + Operation *definingOp = operand.getDefiningOp(); + if (!definingOp) + return failure(); + if (definingOp->hasTrait<OpTrait::ConstantLike>()) + continue; + splatSource = getBroadcastLikeSource(operand); + break; + } + if (!splatSource) return failure(); - auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType(); + Type unbroadcastResultType = + cloneOrReplace(splatSource.getType(), resultElemType); - // Make sure that all operands are broadcast from identical types: + // Make sure that all operands are broadcast from identically-shaped types: // * scalar (`vector.broadcast` + `vector.splat`), or // * vector (`vector.broadcast`). // Otherwise the re-ordering wouldn't be safe. - if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) { - auto bcast = val.getDefiningOp<vector::BroadcastOp>(); - if (bcast) - return (bcast.getOperand().getType() == lhsBcastOrSplatType); - auto splat = val.getDefiningOp<vector::SplatOp>(); - if (splat) - return (splat.getOperand().getType() == lhsBcastOrSplatType); - return false; + if (!llvm::all_of(op->getOperands(), [splatSource](Value val) { + if (auto source = getBroadcastLikeSource(val)) + return haveSameShapeAndScaling(source.getType(), + splatSource.getType()); + SplatElementsAttr splatConst; + return matchPattern(val, m_Constant(&splatConst)); })) { - return failure(); + return rewriter.notifyMatchFailure( + op, + "not all operands are constants or broadcasts from the same type"); } // Collect the source values before broadcasting SmallVector<Value> srcValues; srcValues.reserve(op->getNumOperands()); for (Value operand : op->getOperands()) { - srcValues.push_back(operand.getDefiningOp()->getOperand(0)); + SplatElementsAttr splatConst; + if (matchPattern(operand, m_Constant(&splatConst))) { + Attribute newConst; + Type elementType = getElementTypeOrSelf(operand.getType()); + Type newType = cloneOrReplace(unbroadcastResultType, elementType); + if (auto newTypeShaped = dyn_cast<ShapedType>(newType)) { + newConst = splatConst.resizeSplat(newTypeShaped); + } else { + newConst = splatConst.getSplatValue<Attribute>(); + } + Operation *newConstOp = + operand.getDefiningOp()->getDialect()->materializeConstant( + rewriter, newConst, newType, operand.getLoc()); + srcValues.push_back(newConstOp->getResult(0)); + } else { + srcValues.push_back(operand.getDefiningOp()->getOperand(0)); + } } // Create the "elementwise" Op Operation *elementwiseOp = rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues, - lhsBcastOrSplatType, op->getAttrs()); + unbroadcastResultType, op->getAttrs()); // Replace the original Op with the elementwise Op - auto vectorType = op->getResultTypes()[0]; rewriter.replaceOpWithNewOp<vector::BroadcastOp>( - op, vectorType, elementwiseOp->getResults()); + op, resultType, elementwiseOp->getResults()); return success(); } @@ -1239,15 +1302,17 @@ public: return rewriter.notifyMatchFailure( op, "only 1-element vectors are supported"); - Operation *splat = op.getValueToStore().getDefiningOp(); - if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat)) - return rewriter.notifyMatchFailure(op, "neither a splat nor a broadcast"); + Value toStore = op.getValueToStore(); + Value source = getBroadcastLikeSource(toStore); + if (!source) + return rewriter.notifyMatchFailure( + op, "value to store is not from a broadcast"); // Checking for single use so we can remove splat. + Operation *splat = toStore.getDefiningOp(); if (!splat->hasOneUse()) return rewriter.notifyMatchFailure(op, "expected single op use"); - Value source = splat->getOperand(0); Value base = op.getBase(); ValueRange indices = op.getIndices(); @@ -1297,13 +1362,13 @@ static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op, // Add in an offset if requested. if (off) { Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off); - Value ov = vector::SplatOp::create(rewriter, loc, indices.getType(), o); + Value ov = vector::BroadcastOp::create(rewriter, loc, indices.getType(), o); indices = arith::AddIOp::create(rewriter, loc, ov, indices); } // Construct the vector comparison. Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b); Value bounds = - vector::SplatOp::create(rewriter, loc, indices.getType(), bound); + vector::BroadcastOp::create(rewriter, loc, indices.getType(), bound); return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, indices, bounds); } diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 704deea..33450f3 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -110,6 +110,34 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy, return success(); } +static LogicalResult +isValidGatherScatterBufferParams(Type maskTy, VectorType valueTy, + int64_t chunkSize, + function_ref<InFlightDiagnostic()> emitError) { + + if (!valueTy) + return emitError() << "Expecting a vector type result."; + + auto maskShape = getShapeOf(maskTy); + auto valueShape = getShapeOf(valueTy); + + // a valid shape for SIMT case + if (valueTy.getRank() == 1) { + if (valueTy.getNumElements() != chunkSize) + return emitError() << "value elements must match chunk size " << chunkSize + << " for SIMT code."; + return success(); + } + + llvm::SmallVector<int64_t> expectedMaskShape(valueShape); + if (chunkSize > 1) + expectedMaskShape.pop_back(); + if (expectedMaskShape != maskShape) + return emitError() << "Mask should match value except the chunk size dim."; + + return success(); +} + //===----------------------------------------------------------------------===// // XeGPU_CreateNdDescOp //===----------------------------------------------------------------------===// @@ -644,9 +672,14 @@ LogicalResult CreateDescOp::verify() { //===----------------------------------------------------------------------===// LogicalResult PrefetchOp::verify() { auto tdescTy = getTensorDescType(); - if (!tdescTy.isScattered()) + + if (tdescTy && !tdescTy.isScattered()) return emitOpError("Expects a scattered TensorDesc.\n"); + if (!tdescTy && getRankOf(getSource()) > 1) + return emitOpError( + "Expecting the source is a 1D memref or pointer (uint64_t)."); + if (!isReadHintOrNone(getL1HintAttr())) return emitOpError("invalid l1_hint: ") << getL1HintAttr(); @@ -659,6 +692,13 @@ LogicalResult PrefetchOp::verify() { return success(); } +void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source, + xegpu::CachePolicyAttr l1_hint, + xegpu::CachePolicyAttr l2_hint, + xegpu::CachePolicyAttr l3_hint) { + build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint); +} + //===----------------------------------------------------------------------===// // XeGPU_LoadGatherOp //===----------------------------------------------------------------------===// @@ -667,6 +707,13 @@ LogicalResult LoadGatherOp::verify() { auto maskTy = getMaskType(); auto valueTy = getValueType(); + if (tdescTy && !tdescTy.isScattered()) + return emitOpError("Expects a scattered TensorDesc."); + + if (!tdescTy && getRankOf(getSource()) > 1) + return emitOpError( + "Expecting the source is a 1D memref or pointer (uint64_t)."); + if (!isReadHintOrNone(getL1HintAttr())) return emitOpError("invalid l1_hint: ") << getL1HintAttr(); @@ -676,8 +723,27 @@ LogicalResult LoadGatherOp::verify() { if (!isReadHintOrNone(getL3HintAttr())) return emitOpError("invalid l3_hint: ") << getL3HintAttr(); - return isValidGatherScatterParams(maskTy, valueTy, tdescTy, - [&]() { return emitOpError(); }); + if (tdescTy) + return isValidGatherScatterParams(maskTy, valueTy, tdescTy, + [&]() { return emitOpError(); }); + auto srcTy = getSourceType(); + uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1)); + auto memTy = dyn_cast<MemRefType>(srcTy); + + if (memTy && (valueTy.getElementType() != memTy.getElementType())) + return emitError() << "Value should have the same element type as MemRef."; + + return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize, + [&]() { return emitOpError(); }); +} + +void LoadGatherOp::build(OpBuilder &builder, OperationState &state, + Type valueType, Value source, Value mask, + xegpu::CachePolicyAttr l1_hint, + xegpu::CachePolicyAttr l2_hint, + xegpu::CachePolicyAttr l3_hint) { + build(builder, state, valueType, source, Value(), mask, IntegerAttr(), + l1_hint, l2_hint, l3_hint); } //===----------------------------------------------------------------------===// @@ -688,6 +754,13 @@ LogicalResult StoreScatterOp::verify() { auto maskTy = getMaskType(); auto valueTy = getValueType(); + if (tdescTy && !tdescTy.isScattered()) + return emitOpError("Expects a scattered TensorDesc.\n"); + + if (!tdescTy && getRankOf(getDest()) > 1) + return emitOpError( + "Expecting the dest is a 1D memref or pointer (uint64_t)."); + if (!isWriteHintOrNone(getL1HintAttr())) return emitOpError("invalid l1_hint: ") << getL1HintAttr(); @@ -697,8 +770,28 @@ LogicalResult StoreScatterOp::verify() { if (!isWriteHintOrNone(getL3HintAttr())) return emitOpError("invalid l3_hint: ") << getL3HintAttr(); - return isValidGatherScatterParams(maskTy, valueTy, tdescTy, - [&]() { return emitOpError(); }); + if (tdescTy) + return isValidGatherScatterParams(maskTy, valueTy, tdescTy, + [&]() { return emitOpError(); }); + + auto destTy = getDestType(); + uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1)); + auto memTy = dyn_cast<MemRefType>(destTy); + + if (memTy && (valueTy.getElementType() != memTy.getElementType())) + return emitError() << "Value should have the same element type as MemRef."; + + return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize, + [&]() { return emitOpError(); }); +} + +void StoreScatterOp::build(OpBuilder &builder, OperationState &state, + Value value, Value dest, Value mask, + xegpu::CachePolicyAttr l1_hint, + xegpu::CachePolicyAttr l2_hint, + xegpu::CachePolicyAttr l3_hint) { + build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint, + l2_hint, l3_hint); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp index ec8fad4..c793b71 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp @@ -481,7 +481,8 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> { VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType()); xegpu::TensorDescType tdescTy = op.getTensorDescType(); - if (!tdescTy.isScattered()) + // TODO: handle the unstructure source case (!tdesTy) + if (!tdescTy || op.getOffsets()) return failure(); std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op); @@ -543,7 +544,8 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> { Location loc = op.getLoc(); xegpu::TensorDescType tdescTy = op.getTensorDescType(); - if (!tdescTy.isScattered()) + // TODO: handle the unstructure source case (!tdesTy) + if (!tdescTy || op.getOffsets()) return failure(); std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op); @@ -572,7 +574,8 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> { VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType()); xegpu::TensorDescType tdescTy = op.getTensorDescType(); - if (!tdescTy.isScattered()) + // TODO: handle the unstructure source case (!tdesTy) + if (!tdescTy || op.getOffsets()) return failure(); std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index f95ad29..de52fbd 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -40,7 +40,7 @@ #include "llvm/ADT/StringSet.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/CommandLine.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/Endian.h" #include "llvm/Support/ManagedStatic.h" #include "llvm/Support/Regex.h" @@ -2070,9 +2070,8 @@ static OpPrintingFlags verifyOpAndAdjustFlags(Operation *op, return failure(); }); if (failed(verify(op))) { - LLVM_DEBUG(llvm::dbgs() - << DEBUG_TYPE << ": '" << op->getName() - << "' failed to verify and will be printed in generic form\n"); + LDBG() << op->getName() + << "' failed to verify and will be printed in generic form"; printerFlags.printGenericOpForm(); } diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp index 3e33795..776b5c6 100644 --- a/mlir/lib/IR/Diagnostics.cpp +++ b/mlir/lib/IR/Diagnostics.cpp @@ -821,15 +821,7 @@ SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler( for (unsigned i = 0, e = mgr.getNumBuffers(); i != e; ++i) (void)impl->computeExpectedDiags(out, mgr, mgr.getMemoryBuffer(i + 1)); - // Register a handler to verify the diagnostics. - setHandler([&](Diagnostic &diag) { - // Process the main diagnostics. - process(diag); - - // Process each of the notes. - for (auto ¬e : diag.getNotes()) - process(note); - }); + registerInContext(ctx); } SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler( @@ -862,6 +854,17 @@ LogicalResult SourceMgrDiagnosticVerifierHandler::verify() { return impl->status; } +void SourceMgrDiagnosticVerifierHandler::registerInContext(MLIRContext *ctx) { + ctx->getDiagEngine().registerHandler([&](Diagnostic &diag) { + // Process the main diagnostics. + process(diag); + + // Process each of the notes. + for (auto ¬e : diag.getNotes()) + process(note); + }); +} + /// Process a single diagnostic. void SourceMgrDiagnosticVerifierHandler::process(Diagnostic &diag) { return process(diag.getLocation(), diag.str(), diag.getSeverity()); diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index e9b5e92..310680b 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -17,14 +17,32 @@ using namespace mlir; +static std::pair<int64_t, int64_t> +getLineAndColStart(const llvm::SourceMgr &sourceMgr) { + unsigned lastFileID = sourceMgr.getNumBuffers(); + if (lastFileID == 1) + return {0, 0}; + + auto bufferID = sourceMgr.getMainFileID(); + const llvm::MemoryBuffer *main = sourceMgr.getMemoryBuffer(bufferID); + const llvm::MemoryBuffer *last = sourceMgr.getMemoryBuffer(lastFileID); + // Exclude same start. + if (main->getBufferStart() < last->getBufferStart() && + main->getBufferEnd() >= last->getBufferEnd()) { + return sourceMgr.getLineAndColumn( + llvm::SMLoc::getFromPointer(last->getBufferStart()), bufferID); + } + return {0, 0}; +} + LogicalResult mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr, Block *block, const ParserConfig &config, LocationAttr *sourceFileLoc) { const auto *sourceBuf = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()); if (sourceFileLoc) { - *sourceFileLoc = FileLineColLoc::get(config.getContext(), - sourceBuf->getBufferIdentifier(), - /*line=*/0, /*column=*/0); + auto [line, column] = getLineAndColStart(sourceMgr); + *sourceFileLoc = FileLineColLoc::get( + config.getContext(), sourceBuf->getBufferIdentifier(), line, column); } if (isBytecode(*sourceBuf)) return readBytecodeFile(*sourceBuf, block, config); @@ -37,9 +55,9 @@ mlir::parseSourceFile(const std::shared_ptr<llvm::SourceMgr> &sourceMgr, const auto *sourceBuf = sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID()); if (sourceFileLoc) { - *sourceFileLoc = FileLineColLoc::get(config.getContext(), - sourceBuf->getBufferIdentifier(), - /*line=*/0, /*column=*/0); + auto [line, column] = getLineAndColStart(*sourceMgr); + *sourceFileLoc = FileLineColLoc::get( + config.getContext(), sourceBuf->getBufferIdentifier(), line, column); } if (isBytecode(*sourceBuf)) return readBytecodeFile(sourceMgr, block, config); diff --git a/mlir/lib/RegisterAllDialects.cpp b/mlir/lib/RegisterAllDialects.cpp new file mode 100644 index 0000000..7a345ed --- /dev/null +++ b/mlir/lib/RegisterAllDialects.cpp @@ -0,0 +1,207 @@ +//===- RegisterAllDialects.cpp - MLIR Dialects Registration -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines a helper to trigger the registration of all dialects and +// passes to the system. +// +//===----------------------------------------------------------------------===// + +#include "mlir/InitAllDialects.h" + +#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" +#include "mlir/Dialect/AMX/AMXDialect.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h" +#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h" +#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h" +#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h" +#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" +#include "mlir/Dialect/ArmSME/IR/ArmSME.h" +#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" +#include "mlir/Dialect/Async/IR/Async.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.h" +#include "mlir/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/DLTI/DLTI.h" +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/GPU/IR/ValueBoundsOpInterfaceImpl.h" +#include "mlir/Dialect/GPU/Transforms/BufferDeallocationOpInterfaceImpl.h" +#include "mlir/Dialect/IRDL/IR/IRDL.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" +#include "mlir/Dialect/LLVMIR/XeVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h" +#include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h" +#include "mlir/Dialect/MLProgram/IR/MLProgram.h" +#include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/MPI/IR/MPI.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h" +#include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h" +#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h" +#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h" +#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h" +#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/Dialect/PDL/IR/PDL.h" +#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" +#include "mlir/Dialect/Ptr/IR/PtrDialect.h" +#include "mlir/Dialect/Quant/IR/Quant.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h" +#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" +#include "mlir/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.h" +#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/SMT/IR/SMTDialect.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Shard/IR/ShardDialect.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h" +#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h" +#include "mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h" +#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h" +#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Tensor/Transforms/RuntimeOpVerification.h" +#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h" +#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h" +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h" +#include "mlir/Dialect/X86Vector/X86VectorDialect.h" +#include "mlir/Dialect/XeGPU/IR/XeGPU.h" +#include "mlir/IR/Dialect.h" +#include "mlir/Interfaces/CastInterfaces.h" +#include "mlir/Target/LLVM/NVVM/Target.h" +#include "mlir/Target/LLVM/ROCDL/Target.h" +#include "mlir/Target/SPIRV/Target.h" + +/// Add all the MLIR dialects to the provided registry. +void mlir::registerAllDialects(DialectRegistry ®istry) { + // clang-format off + registry.insert<acc::OpenACCDialect, + affine::AffineDialect, + amdgpu::AMDGPUDialect, + amx::AMXDialect, + arith::ArithDialect, + arm_neon::ArmNeonDialect, + arm_sme::ArmSMEDialect, + arm_sve::ArmSVEDialect, + async::AsyncDialect, + bufferization::BufferizationDialect, + cf::ControlFlowDialect, + complex::ComplexDialect, + DLTIDialect, + emitc::EmitCDialect, + func::FuncDialect, + gpu::GPUDialect, + index::IndexDialect, + irdl::IRDLDialect, + linalg::LinalgDialect, + LLVM::LLVMDialect, + math::MathDialect, + memref::MemRefDialect, + shard::ShardDialect, + ml_program::MLProgramDialect, + mpi::MPIDialect, + nvgpu::NVGPUDialect, + NVVM::NVVMDialect, + omp::OpenMPDialect, + pdl::PDLDialect, + pdl_interp::PDLInterpDialect, + ptr::PtrDialect, + quant::QuantDialect, + ROCDL::ROCDLDialect, + scf::SCFDialect, + shape::ShapeDialect, + smt::SMTDialect, + sparse_tensor::SparseTensorDialect, + spirv::SPIRVDialect, + tensor::TensorDialect, + tosa::TosaDialect, + transform::TransformDialect, + ub::UBDialect, + vector::VectorDialect, + x86vector::X86VectorDialect, + xegpu::XeGPUDialect, + xevm::XeVMDialect>(); + // clang-format on + + // Register all external models. + affine::registerValueBoundsOpInterfaceExternalModels(registry); + arith::registerBufferDeallocationOpInterfaceExternalModels(registry); + arith::registerBufferizableOpInterfaceExternalModels(registry); + arith::registerBufferViewFlowOpInterfaceExternalModels(registry); + arith::registerShardingInterfaceExternalModels(registry); + arith::registerValueBoundsOpInterfaceExternalModels(registry); + bufferization::func_ext::registerBufferizableOpInterfaceExternalModels( + registry); + builtin::registerCastOpInterfaceExternalModels(registry); + cf::registerBufferizableOpInterfaceExternalModels(registry); + cf::registerBufferDeallocationOpInterfaceExternalModels(registry); + gpu::registerBufferDeallocationOpInterfaceExternalModels(registry); + gpu::registerValueBoundsOpInterfaceExternalModels(registry); + LLVM::registerInlinerInterface(registry); + NVVM::registerInlinerInterface(registry); + linalg::registerAllDialectInterfaceImplementations(registry); + linalg::registerRuntimeVerifiableOpInterfaceExternalModels(registry); + memref::registerAllocationOpInterfaceExternalModels(registry); + memref::registerBufferViewFlowOpInterfaceExternalModels(registry); + memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry); + memref::registerValueBoundsOpInterfaceExternalModels(registry); + memref::registerMemorySlotExternalModels(registry); + ml_program::registerBufferizableOpInterfaceExternalModels(registry); + scf::registerBufferDeallocationOpInterfaceExternalModels(registry); + scf::registerBufferizableOpInterfaceExternalModels(registry); + scf::registerValueBoundsOpInterfaceExternalModels(registry); + shape::registerBufferizableOpInterfaceExternalModels(registry); + sparse_tensor::registerBufferizableOpInterfaceExternalModels(registry); + tensor::registerBufferizableOpInterfaceExternalModels(registry); + tensor::registerFindPayloadReplacementOpInterfaceExternalModels(registry); + tensor::registerInferTypeOpInterfaceExternalModels(registry); + tensor::registerRuntimeVerifiableOpInterfaceExternalModels(registry); + tensor::registerSubsetOpInterfaceExternalModels(registry); + tensor::registerTilingInterfaceExternalModels(registry); + tensor::registerValueBoundsOpInterfaceExternalModels(registry); + tosa::registerShardingInterfaceExternalModels(registry); + vector::registerBufferizableOpInterfaceExternalModels(registry); + vector::registerSubsetOpInterfaceExternalModels(registry); + vector::registerValueBoundsOpInterfaceExternalModels(registry); + NVVM::registerNVVMTargetInterfaceExternalModels(registry); + ROCDL::registerROCDLTargetInterfaceExternalModels(registry); + spirv::registerSPIRVTargetInterfaceExternalModels(registry); +} + +/// Append all the MLIR dialects to the registry contained in the given context. +void mlir::registerAllDialects(MLIRContext &context) { + DialectRegistry registry; + registerAllDialects(registry); + context.appendDialectRegistry(registry); +} diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp new file mode 100644 index 0000000..8f7c67c --- /dev/null +++ b/mlir/lib/RegisterAllExtensions.cpp @@ -0,0 +1,115 @@ +//===- RegisterAllExtensions.cpp - MLIR Extension Registration --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines a helper to trigger the registration of all dialect +// extensions to the system. +// +//===----------------------------------------------------------------------===// + +#include "mlir/InitAllExtensions.h" + +#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/FuncToEmitC/FuncToEmitC.h" +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/Conversion/GPUCommon/GPUToLLVM.h" +#include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h" +#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" +#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h" +#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" +#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h" +#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" +#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" +#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" +#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h" +#include "mlir/Conversion/UBToLLVM/UBToLLVM.h" +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" +#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h" +#include "mlir/Dialect/AMX/Transforms.h" +#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h" +#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h" +#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h" +#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h" +#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h" +#include "mlir/Dialect/Func/Extensions/AllExtensions.h" +#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h" +#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h" +#include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h" +#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h" +#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h" +#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" +#include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h" +#include "mlir/Dialect/Tensor/Extensions/AllExtensions.h" +#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h" +#include "mlir/Dialect/Transform/DebugExtension/DebugExtension.h" +#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h" +#include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h" +#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h" +#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h" +#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" + +/// This function may be called to register all MLIR dialect extensions with the +/// provided registry. +/// If you're building a compiler, you generally shouldn't use this: you would +/// individually register the specific extensions that are useful for the +/// pipelines and transformations you are using. +void mlir::registerAllExtensions(DialectRegistry ®istry) { + // Register all conversions to LLVM extensions. + registerConvertArithToEmitCInterface(registry); + arith::registerConvertArithToLLVMInterface(registry); + registerConvertComplexToLLVMInterface(registry); + cf::registerConvertControlFlowToLLVMInterface(registry); + func::registerAllExtensions(registry); + tensor::registerAllExtensions(registry); + registerConvertFuncToEmitCInterface(registry); + registerConvertFuncToLLVMInterface(registry); + index::registerConvertIndexToLLVMInterface(registry); + registerConvertMathToLLVMInterface(registry); + mpi::registerConvertMPIToLLVMInterface(registry); + registerConvertMemRefToEmitCInterface(registry); + registerConvertMemRefToLLVMInterface(registry); + registerConvertNVVMToLLVMInterface(registry); + registerConvertOpenMPToLLVMInterface(registry); + registerConvertSCFToEmitCInterface(registry); + ub::registerConvertUBToLLVMInterface(registry); + registerConvertAMXToLLVMInterface(registry); + gpu::registerConvertGpuToLLVMInterface(registry); + NVVM::registerConvertGpuToNVVMInterface(registry); + vector::registerConvertVectorToLLVMInterface(registry); + registerConvertXeVMToLLVMInterface(registry); + + // Register all transform dialect extensions. + affine::registerTransformDialectExtension(registry); + bufferization::registerTransformDialectExtension(registry); + dlti::registerTransformDialectExtension(registry); + func::registerTransformDialectExtension(registry); + gpu::registerTransformDialectExtension(registry); + linalg::registerTransformDialectExtension(registry); + memref::registerTransformDialectExtension(registry); + nvgpu::registerTransformDialectExtension(registry); + scf::registerTransformDialectExtension(registry); + sparse_tensor::registerTransformDialectExtension(registry); + tensor::registerTransformDialectExtension(registry); + transform::registerDebugExtension(registry); + transform::registerIRDLExtension(registry); + transform::registerLoopExtension(registry); + transform::registerPDLExtension(registry); + transform::registerTuneExtension(registry); + vector::registerTransformDialectExtension(registry); + arm_neon::registerTransformDialectExtension(registry); + arm_sve::registerTransformDialectExtension(registry); + + // Translation extensions need to be registered by calling + // `registerAllToLLVMIRTranslations` (see All.h). +} diff --git a/mlir/lib/RegisterAllPasses.cpp b/mlir/lib/RegisterAllPasses.cpp new file mode 100644 index 0000000..1ed3a37 --- /dev/null +++ b/mlir/lib/RegisterAllPasses.cpp @@ -0,0 +1,99 @@ +//===- RegisterAllPasses.cpp - MLIR Registration ----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines a helper to trigger the registration of all passes to the +// system. +// +//===----------------------------------------------------------------------===// + +#include "mlir/InitAllPasses.h" + +#include "mlir/Conversion/Passes.h" +#include "mlir/Dialect/AMDGPU/Transforms/Passes.h" +#include "mlir/Dialect/Affine/Passes.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/Dialect/ArmSME/Transforms/Passes.h" +#include "mlir/Dialect/ArmSVE/Transforms/Passes.h" +#include "mlir/Dialect/Async/Passes.h" +#include "mlir/Dialect/Bufferization/Pipelines/Passes.h" +#include "mlir/Dialect/Bufferization/Transforms/Passes.h" +#include "mlir/Dialect/EmitC/Transforms/Passes.h" +#include "mlir/Dialect/Func/Transforms/Passes.h" +#include "mlir/Dialect/GPU/Pipelines/Passes.h" +#include "mlir/Dialect/GPU/Transforms/Passes.h" +#include "mlir/Dialect/LLVMIR/Transforms/Passes.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/MLProgram/Transforms/Passes.h" +#include "mlir/Dialect/Math/Transforms/Passes.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/NVGPU/Transforms/Passes.h" +#include "mlir/Dialect/OpenACC/Transforms/Passes.h" +#include "mlir/Dialect/Quant/Transforms/Passes.h" +#include "mlir/Dialect/SCF/Transforms/Passes.h" +#include "mlir/Dialect/SPIRV/Transforms/Passes.h" +#include "mlir/Dialect/Shape/Transforms/Passes.h" +#include "mlir/Dialect/Shard/Transforms/Passes.h" +#include "mlir/Dialect/SparseTensor/Pipelines/Passes.h" +#include "mlir/Dialect/SparseTensor/Transforms/Passes.h" +#include "mlir/Dialect/Tensor/Transforms/Passes.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Dialect/Transform/Transforms/Passes.h" +#include "mlir/Dialect/Vector/Transforms/Passes.h" +#include "mlir/Dialect/XeGPU/Transforms/Passes.h" +#include "mlir/Transforms/Passes.h" + +// This function may be called to register the MLIR passes with the +// global registry. +// If you're building a compiler, you likely don't need this: you would build a +// pipeline programmatically without the need to register with the global +// registry, since it would already be calling the creation routine of the +// individual passes. +// The global registry is interesting to interact with the command-line tools. +void mlir::registerAllPasses() { + // General passes + registerTransformsPasses(); + + // Conversion passes + registerConversionPasses(); + + // Dialect passes + acc::registerOpenACCPasses(); + affine::registerAffinePasses(); + amdgpu::registerAMDGPUPasses(); + registerAsyncPasses(); + arith::registerArithPasses(); + bufferization::registerBufferizationPasses(); + func::registerFuncPasses(); + registerGPUPasses(); + registerLinalgPasses(); + registerNVGPUPasses(); + registerSparseTensorPasses(); + LLVM::registerLLVMPasses(); + math::registerMathPasses(); + memref::registerMemRefPasses(); + shard::registerShardPasses(); + ml_program::registerMLProgramPasses(); + quant::registerQuantPasses(); + registerSCFPasses(); + registerShapePasses(); + spirv::registerSPIRVPasses(); + tensor::registerTensorPasses(); + tosa::registerTosaOptPasses(); + transform::registerTransformPasses(); + vector::registerVectorPasses(); + arm_sme::registerArmSMEPasses(); + arm_sve::registerArmSVEPasses(); + emitc::registerEmitCPasses(); + xegpu::registerXeGPUPasses(); + + // Dialect pipelines + bufferization::registerBufferizationPipelines(); + sparse_tensor::registerSparseTensorPipelines(); + tosa::registerTosaToLinalgPipelines(); + gpu::registerGPUToNVVMPipeline(); +} diff --git a/mlir/lib/Rewrite/PatternApplicator.cpp b/mlir/lib/Rewrite/PatternApplicator.cpp index b2b372b..e13bcff 100644 --- a/mlir/lib/Rewrite/PatternApplicator.cpp +++ b/mlir/lib/Rewrite/PatternApplicator.cpp @@ -13,7 +13,7 @@ #include "mlir/Rewrite/PatternApplicator.h" #include "ByteCode.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #ifndef NDEBUG #include "llvm/ADT/ScopeExit.h" @@ -51,9 +51,7 @@ static Operation *getDumpRootOp(Operation *op) { return op; } static void logSucessfulPatternApplication(Operation *op) { - llvm::dbgs() << "// *** IR Dump After Pattern Application ***\n"; - op->dump(); - llvm::dbgs() << "\n\n"; + LDBG(2) << "// *** IR Dump After Pattern Application ***\n" << *op << "\n"; } #endif @@ -208,8 +206,8 @@ LogicalResult PatternApplicator::matchAndRewrite( result = bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState); } else { - LLVM_DEBUG(llvm::dbgs() << "Trying to match \"" - << bestPattern->getDebugName() << "\"\n"); + LDBG() << "Trying to match \"" << bestPattern->getDebugName() + << "\""; const auto *pattern = static_cast<const RewritePattern *>(bestPattern); @@ -223,9 +221,8 @@ LogicalResult PatternApplicator::matchAndRewrite( [&] { rewriter.setListener(oldListener); }); #endif result = pattern->matchAndRewrite(op, rewriter); - LLVM_DEBUG(llvm::dbgs() - << "\"" << bestPattern->getDebugName() << "\" result " - << succeeded(result) << "\n"); + LDBG() << " -> matchAndRewrite " + << (succeeded(result) ? "successful" : "failed"); } // Process the result of the pattern application. diff --git a/mlir/lib/Support/ToolUtilities.cpp b/mlir/lib/Support/ToolUtilities.cpp index 748f928..2cf33eb 100644 --- a/mlir/lib/Support/ToolUtilities.cpp +++ b/mlir/lib/Support/ToolUtilities.cpp @@ -14,6 +14,8 @@ #include "mlir/Support/LLVM.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" +#include <string> +#include <utility> using namespace mlir; @@ -22,18 +24,18 @@ mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer, ChunkBufferHandler processChunkBuffer, raw_ostream &os, llvm::StringRef inputSplitMarker, llvm::StringRef outputSplitMarker) { + llvm::MemoryBufferRef originalBufferRef = originalBuffer->getMemBufferRef(); // If splitting is disabled, we process the full input buffer. if (inputSplitMarker.empty()) - return processChunkBuffer(std::move(originalBuffer), os); + return processChunkBuffer(std::move(originalBuffer), originalBufferRef, os); const int inputSplitMarkerLen = inputSplitMarker.size(); - auto *origMemBuffer = originalBuffer.get(); SmallVector<StringRef, 8> rawSourceBuffers; const int checkLen = 2; // Split dropping the last checkLen chars to enable flagging near misses. - origMemBuffer->getBuffer().split(rawSourceBuffers, - inputSplitMarker.drop_back(checkLen)); + originalBufferRef.getBuffer().split(rawSourceBuffers, + inputSplitMarker.drop_back(checkLen)); if (rawSourceBuffers.empty()) return success(); @@ -79,11 +81,17 @@ mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer, auto interleaveFn = [&](StringRef subBuffer) { auto splitLoc = SMLoc::getFromPointer(subBuffer.data()); unsigned splitLine = fileSourceMgr.getLineAndColumn(splitLoc).first; - auto subMemBuffer = llvm::MemoryBuffer::getMemBufferCopy( - subBuffer, Twine("within split at ") + - origMemBuffer->getBufferIdentifier() + ":" + - Twine(splitLine) + " offset "); - if (failed(processChunkBuffer(std::move(subMemBuffer), os))) + std::string name((Twine("within split at ") + + originalBufferRef.getBufferIdentifier() + ":" + + Twine(splitLine) + " offset ") + .str()); + // Use MemoryBufferRef to avoid copying the buffer & keep at same location + // relative to the original buffer. + auto subMemBuffer = + llvm::MemoryBuffer::getMemBuffer(llvm::MemoryBufferRef(subBuffer, name), + /*RequiresNullTerminator=*/false); + if (failed( + processChunkBuffer(std::move(subMemBuffer), originalBufferRef, os))) hadFailure = true; }; llvm::interleave(sourceBuffers, os, interleaveFn, @@ -92,3 +100,16 @@ mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer, // If any fails, then return a failure of the tool. return failure(hadFailure); } + +LogicalResult +mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer, + NoSourceChunkBufferHandler processChunkBuffer, + raw_ostream &os, llvm::StringRef inputSplitMarker, + llvm::StringRef outputSplitMarker) { + auto process = [&](std::unique_ptr<llvm::MemoryBuffer> chunkBuffer, + const llvm::MemoryBufferRef &, raw_ostream &os) { + return processChunkBuffer(std::move(chunkBuffer), os); + }; + return splitAndProcessBuffer(std::move(originalBuffer), process, os, + inputSplitMarker, outputSplitMarker); +} diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index dcd2e11..8e83e45 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -333,7 +333,8 @@ private: /// Determine whether expression \p op should be emitted in a deferred way. static bool hasDeferredEmission(Operation *op) { return isa_and_nonnull<emitc::GetGlobalOp, emitc::LiteralOp, emitc::MemberOp, - emitc::MemberOfPtrOp, emitc::SubscriptOp>(op); + emitc::MemberOfPtrOp, emitc::SubscriptOp, + emitc::GetFieldOp>(op); } /// Determine whether expression \p expressionOp should be emitted inline, i.e. @@ -1049,25 +1050,17 @@ static LogicalResult printOperation(CppEmitter &emitter, ClassOp classOp) { static LogicalResult printOperation(CppEmitter &emitter, FieldOp fieldOp) { raw_ostream &os = emitter.ostream(); - if (failed(emitter.emitType(fieldOp->getLoc(), fieldOp.getType()))) + if (failed(emitter.emitVariableDeclaration( + fieldOp->getLoc(), fieldOp.getType(), fieldOp.getSymName()))) return failure(); - os << " " << fieldOp.getSymName() << ";"; - return success(); -} - -static LogicalResult printOperation(CppEmitter &emitter, - GetFieldOp getFieldOp) { - raw_indented_ostream &os = emitter.ostream(); - - Value result = getFieldOp.getResult(); - if (failed(emitter.emitType(getFieldOp->getLoc(), result.getType()))) - return failure(); - os << " "; - if (failed(emitter.emitOperand(result))) - return failure(); - os << " = "; + std::optional<Attribute> initialValue = fieldOp.getInitialValue(); + if (initialValue) { + os << " = "; + if (failed(emitter.emitAttribute(fieldOp->getLoc(), *initialValue))) + return failure(); + } - os << getFieldOp.getFieldName().str(); + os << ";"; return success(); } @@ -1204,7 +1197,7 @@ static LogicalResult printOperation(CppEmitter &emitter, os << ") {\n"; if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks()))) return failure(); - os << "}\n"; + os << "}"; return success(); } @@ -1245,7 +1238,7 @@ static LogicalResult printOperation(CppEmitter &emitter, os << ") {\n"; if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks()))) return failure(); - os << "}\n"; + os << "}"; return success(); } @@ -1700,12 +1693,11 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { emitc::CmpOp, emitc::ConditionalOp, emitc::ConstantOp, emitc::DeclareFuncOp, emitc::DivOp, emitc::ExpressionOp, emitc::FieldOp, emitc::FileOp, emitc::ForOp, emitc::FuncOp, - emitc::GetFieldOp, emitc::GlobalOp, emitc::IfOp, - emitc::IncludeOp, emitc::LoadOp, emitc::LogicalAndOp, - emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp, - emitc::RemOp, emitc::ReturnOp, emitc::SubOp, emitc::SwitchOp, - emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp, - emitc::VerbatimOp>( + emitc::GlobalOp, emitc::IfOp, emitc::IncludeOp, emitc::LoadOp, + emitc::LogicalAndOp, emitc::LogicalNotOp, emitc::LogicalOrOp, + emitc::MulOp, emitc::RemOp, emitc::ReturnOp, emitc::SubOp, + emitc::SwitchOp, emitc::UnaryMinusOp, emitc::UnaryPlusOp, + emitc::VariableOp, emitc::VerbatimOp>( [&](auto op) { return printOperation(*this, op); }) // Func ops. @@ -1715,6 +1707,10 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { cacheDeferredOpResult(op.getResult(), op.getName()); return success(); }) + .Case<emitc::GetFieldOp>([&](auto op) { + cacheDeferredOpResult(op.getResult(), op.getFieldName()); + return success(); + }) .Case<emitc::LiteralOp>([&](auto op) { cacheDeferredOpResult(op.getResult(), op.getValue()); return success(); diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt index af22a7f..9ea5c683 100644 --- a/mlir/lib/Target/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt @@ -60,6 +60,7 @@ add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration MLIRROCDLToLLVMIRTranslation MLIRSPIRVToLLVMIRTranslation MLIRVCIXToLLVMIRTranslation + MLIRXeVMToLLVMIRTranslation ) add_mlir_translation_library(MLIRTargetLLVMIRImport diff --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt index f030fa7..86c731a 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt @@ -10,3 +10,4 @@ add_subdirectory(OpenMP) add_subdirectory(ROCDL) add_subdirectory(SPIRV) add_subdirectory(VCIX) +add_subdirectory(XeVM) diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp index ff34a08..0f675a0 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -13,6 +13,7 @@ #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Operation.h" +#include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Support/LLVM.h" #include "mlir/Target/LLVMIR/ModuleTranslation.h" @@ -136,46 +137,6 @@ convertOperandBundles(OperandRangeRange bundleOperands, return convertOperandBundles(bundleOperands, *bundleTags, moduleTranslation); } -static LogicalResult -convertParameterAndResultAttrs(mlir::Location loc, ArrayAttr argAttrsArray, - ArrayAttr resAttrsArray, llvm::CallBase *call, - LLVM::ModuleTranslation &moduleTranslation) { - if (argAttrsArray) { - for (auto [argIdx, argAttrsAttr] : llvm::enumerate(argAttrsArray)) { - if (auto argAttrs = cast<DictionaryAttr>(argAttrsAttr); - !argAttrs.empty()) { - FailureOr<llvm::AttrBuilder> attrBuilder = - moduleTranslation.convertParameterAttrs(loc, argAttrs); - if (failed(attrBuilder)) - return failure(); - call->addParamAttrs(argIdx, *attrBuilder); - } - } - } - - if (resAttrsArray && resAttrsArray.size() > 0) { - if (resAttrsArray.size() != 1) - return mlir::emitError(loc, "llvm.func cannot have multiple results"); - if (auto resAttrs = cast<DictionaryAttr>(resAttrsArray[0]); - !resAttrs.empty()) { - FailureOr<llvm::AttrBuilder> attrBuilder = - moduleTranslation.convertParameterAttrs(loc, resAttrs); - if (failed(attrBuilder)) - return failure(); - call->addRetAttrs(*attrBuilder); - } - } - return success(); -} - -static LogicalResult -convertParameterAndResultAttrs(CallOpInterface callOp, llvm::CallBase *call, - LLVM::ModuleTranslation &moduleTranslation) { - return convertParameterAndResultAttrs( - callOp.getLoc(), callOp.getArgAttrsAttr(), callOp.getResAttrsAttr(), call, - moduleTranslation); -} - /// Builder for LLVM_CallIntrinsicOp static LogicalResult convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder, @@ -243,9 +204,7 @@ convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder, convertOperandBundles(op.getOpBundleOperands(), op.getOpBundleTags(), moduleTranslation)); - if (failed(convertParameterAndResultAttrs(op.getLoc(), op.getArgAttrsAttr(), - op.getResAttrsAttr(), inst, - moduleTranslation))) + if (failed(moduleTranslation.convertArgAndResultAttrs(op, inst))) return failure(); if (op.getNumResults() == 1) @@ -455,7 +414,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, if (callOp.getInlineHintAttr()) call->addFnAttr(llvm::Attribute::InlineHint); - if (failed(convertParameterAndResultAttrs(callOp, call, moduleTranslation))) + if (failed(moduleTranslation.convertArgAndResultAttrs(callOp, call))) return failure(); if (MemoryEffectsAttr memAttr = callOp.getMemoryEffectsAttr()) { @@ -569,8 +528,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, operandsRef.drop_front(), opBundles); } result->setCallingConv(convertCConvToLLVM(invOp.getCConv())); - if (failed( - convertParameterAndResultAttrs(invOp, result, moduleTranslation))) + if (failed(moduleTranslation.convertArgAndResultAttrs(invOp, result))) return failure(); moduleTranslation.mapBranch(invOp, result); // InvokeOp can only have 0 or 1 result diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp index 1c9e226..55e73e8 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp @@ -13,6 +13,7 @@ #include "mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Target/LLVMIR/ModuleImport.h" +#include "llvm/IR/ConstantRange.h" using namespace mlir; using namespace mlir::NVVM; diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp index b3577c6..90462d1 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp @@ -164,6 +164,42 @@ static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout, } } +/// Return the intrinsic ID associated with stmatrix for the given paramters. +static llvm::Intrinsic::ID +getStMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, + NVVM::LdStMatrixShapeAttr shape, + NVVM::LdStMatrixEltType eltType) { + if (shape.getM() == 8 && shape.getN() == 8) { + switch (num) { + case 1: + return (layout == NVVM::MMALayout::row) + ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16 + : llvm::Intrinsic:: + nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16; + case 2: + return (layout == NVVM::MMALayout::row) + ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16 + : llvm::Intrinsic:: + nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16; + case 4: + return (layout == NVVM::MMALayout::row) + ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16 + : llvm::Intrinsic:: + nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16; + } + } else if (shape.getM() == 16 && shape.getN() == 8) { + switch (num) { + case 1: + return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8; + case 2: + return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8; + case 4: + return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8; + } + } + llvm_unreachable("unknown stmatrix kind"); +} + /// Return the intrinsic ID associated with st.bulk for the given address type. static llvm::Intrinsic::ID getStBulkIntrinsicId(LLVM::LLVMPointerType addrType) { diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 9f18199..49e1e55 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -3877,29 +3877,28 @@ static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo, llvm::SmallVector<size_t> indices(indexAttr.size()); std::iota(indices.begin(), indices.end(), 0); - llvm::sort(indices.begin(), indices.end(), - [&](const size_t a, const size_t b) { - auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]); - auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]); - for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) { - int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt(); - int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt(); - - if (aIndex == bIndex) - continue; - - if (aIndex < bIndex) - return first; - - if (aIndex > bIndex) - return !first; - } - - // Iterated the up until the end of the smallest member and - // they were found to be equal up to that point, so select - // the member with the lowest index count, so the "parent" - return memberIndicesA.size() < memberIndicesB.size(); - }); + llvm::sort(indices, [&](const size_t a, const size_t b) { + auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]); + auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]); + for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) { + int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt(); + int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt(); + + if (aIndex == bIndex) + continue; + + if (aIndex < bIndex) + return first; + + if (aIndex > bIndex) + return !first; + } + + // Iterated the up until the end of the smallest member and + // they were found to be equal up to that point, so select + // the member with the lowest index count, so the "parent" + return memberIndicesA.size() < memberIndicesB.size(); + }); return llvm::cast<omp::MapInfoOp>( mapInfo.getMembers()[indices.front()].getDefiningOp()); diff --git a/mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt new file mode 100644 index 0000000..6308d7e --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt @@ -0,0 +1,21 @@ +set(LLVM_OPTIONAL_SOURCES + XeVMToLLVMIRTranslation.cpp +) + +add_mlir_translation_library(MLIRXeVMToLLVMIRTranslation + XeVMToLLVMIRTranslation.cpp + + DEPENDS + MLIRXeVMConversionsIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRDialectUtils + MLIRIR + MLIRLLVMDialect + MLIRXeVMDialect + MLIRSupport + MLIRTargetLLVMIRExport +) diff --git a/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp new file mode 100644 index 0000000..73b166d --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp @@ -0,0 +1,103 @@ +//===-- XeVMToLLVMIRTranslation.cpp - Translate XeVM to LLVM IR -*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a translation between the MLIR XeVM dialect and +// LLVM IR. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h" +#include "mlir/Dialect/LLVMIR/XeVMDialect.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Operation.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" + +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" + +#include "llvm/IR/ConstantRange.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; +using namespace mlir::LLVM; + +namespace { +/// Implementation of the dialect interface that converts operations belonging +/// to the XeVM dialect to LLVM IR. +class XeVMDialectLLVMIRTranslationInterface + : public LLVMTranslationDialectInterface { +public: + using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; + + /// Attaches module-level metadata for functions marked as kernels. + LogicalResult + amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions, + NamedAttribute attribute, + LLVM::ModuleTranslation &moduleTranslation) const final { + StringRef attrName = attribute.getName().getValue(); + if (attrName == mlir::xevm::XeVMDialect::getCacheControlsAttrName()) { + auto cacheControlsArray = dyn_cast<ArrayAttr>(attribute.getValue()); + if (cacheControlsArray.size() != 2) { + return op->emitOpError( + "Expected both L1 and L3 cache control attributes!"); + } + if (instructions.size() != 1) { + return op->emitOpError("Expecting a single instruction"); + } + return handleDecorationCacheControl(instructions.front(), + cacheControlsArray.getValue()); + } + auto func = dyn_cast<LLVM::LLVMFuncOp>(op); + if (!func) + return failure(); + + return success(); + } + +private: + static LogicalResult handleDecorationCacheControl(llvm::Instruction *inst, + ArrayRef<Attribute> attrs) { + SmallVector<llvm::Metadata *> decorations; + llvm::LLVMContext &ctx = inst->getContext(); + llvm::Type *i32Ty = llvm::IntegerType::getInt32Ty(ctx); + llvm::transform( + attrs, std::back_inserter(decorations), + [&ctx, i32Ty](Attribute attr) -> llvm::Metadata * { + auto valuesArray = dyn_cast<ArrayAttr>(attr).getValue(); + std::array<llvm::Metadata *, 4> metadata; + llvm::transform( + valuesArray, metadata.begin(), [i32Ty](Attribute valueAttr) { + return llvm::ConstantAsMetadata::get(llvm::ConstantInt::get( + i32Ty, cast<IntegerAttr>(valueAttr).getValue())); + }); + return llvm::MDNode::get(ctx, metadata); + }); + constexpr llvm::StringLiteral decorationCacheControlMDName = + "spirv.DecorationCacheControlINTEL"; + inst->setMetadata(decorationCacheControlMDName, + llvm::MDNode::get(ctx, decorations)); + return success(); + } +}; +} // namespace + +void mlir::registerXeVMDialectTranslation(::mlir::DialectRegistry ®istry) { + registry.insert<xevm::XeVMDialect>(); + registry.addExtension(+[](MLIRContext *ctx, xevm::XeVMDialect *dialect) { + dialect->addInterfaces<XeVMDialectLLVMIRTranslationInterface>(); + }); +} + +void mlir::registerXeVMDialectTranslation(::mlir::MLIRContext &context) { + DialectRegistry registry; + registerXeVMDialectTranslation(registry); + context.appendDialectRegistry(registry); +} diff --git a/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp b/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp index 580afdd..cb1f234 100644 --- a/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp +++ b/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp @@ -33,7 +33,9 @@ LogicalResult mlir::LLVMImportInterface::convertUnregisteredIntrinsic( SmallVector<Value> mlirOperands; SmallVector<NamedAttribute> mlirAttrs; if (failed(moduleImport.convertIntrinsicArguments( - llvmOperands, llvmOpBundles, false, {}, {}, mlirOperands, mlirAttrs))) + llvmOperands, llvmOpBundles, /*requiresOpBundles=*/false, + /*immArgPositions=*/{}, /*immArgAttrNames=*/{}, mlirOperands, + mlirAttrs))) return failure(); Type resultType = moduleImport.convertType(inst->getType()); @@ -44,11 +46,7 @@ LogicalResult mlir::LLVMImportInterface::convertUnregisteredIntrinsic( ValueRange{mlirOperands}, FastmathFlagsAttr{}); moduleImport.setFastmathFlagsAttr(inst, op); - - ArrayAttr argsAttr, resAttr; - moduleImport.convertParameterAttributes(inst, argsAttr, resAttr, builder); - op.setArgAttrsAttr(argsAttr); - op.setResAttrsAttr(resAttr); + moduleImport.convertArgAndResultAttrs(inst, op); // Update importer tracking of results. unsigned numRes = op.getNumResults(); diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 58e3c44..6325480 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -30,6 +30,7 @@ #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/IR/Comdat.h" #include "llvm/IR/Constants.h" @@ -1063,6 +1064,18 @@ void ModuleImport::convertTargetTriple() { builder.getStringAttr(llvmModule->getTargetTriple().str())); } +void ModuleImport::convertModuleLevelAsm() { + llvm::StringRef asmStr = llvmModule->getModuleInlineAsm(); + llvm::SmallVector<mlir::Attribute> asmArrayAttr; + + for (llvm::StringRef line : llvm::split(asmStr, '\n')) + if (!line.empty()) + asmArrayAttr.push_back(builder.getStringAttr(line)); + + mlirModule->setAttr(LLVM::LLVMDialect::getModuleLevelAsmAttrName(), + builder.getArrayAttr(asmArrayAttr)); +} + LogicalResult ModuleImport::convertFunctions() { for (llvm::Function &func : llvmModule->functions()) if (failed(processFunction(&func))) @@ -2267,7 +2280,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { // Handle parameter and result attributes unless it's an incompatible // call. if (!isIncompatibleCall) - convertParameterAttributes(callInst, callOp, builder); + convertArgAndResultAttrs(callInst, callOp); return callOp.getOperation(); }(); @@ -2364,7 +2377,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { // Handle parameter and result attributes unless it's an incompatible // invoke. if (!isIncompatibleInvoke) - convertParameterAttributes(invokeInst, invokeOp, builder); + convertArgAndResultAttrs(invokeInst, invokeOp); if (!invokeInst->getType()->isVoidTy()) mapValue(inst, invokeOp.getResults().front()); @@ -2730,11 +2743,10 @@ void ModuleImport::processFunctionAttributes(llvm::Function *func, } DictionaryAttr -ModuleImport::convertParameterAttribute(llvm::AttributeSet llvmParamAttrs, - OpBuilder &builder) { +ModuleImport::convertArgOrResultAttrSet(llvm::AttributeSet llvmAttrSet) { SmallVector<NamedAttribute> paramAttrs; for (auto [llvmKind, mlirName] : getAttrKindToNameMapping()) { - auto llvmAttr = llvmParamAttrs.getAttribute(llvmKind); + auto llvmAttr = llvmAttrSet.getAttribute(llvmKind); // Skip attributes that are not attached. if (!llvmAttr.isValid()) continue; @@ -2769,13 +2781,12 @@ ModuleImport::convertParameterAttribute(llvm::AttributeSet llvmParamAttrs, return builder.getDictionaryAttr(paramAttrs); } -void ModuleImport::convertParameterAttributes(llvm::Function *func, - LLVMFuncOp funcOp, - OpBuilder &builder) { +void ModuleImport::convertArgAndResultAttrs(llvm::Function *func, + LLVMFuncOp funcOp) { auto llvmAttrs = func->getAttributes(); for (size_t i = 0, e = funcOp.getNumArguments(); i < e; ++i) { llvm::AttributeSet llvmArgAttrs = llvmAttrs.getParamAttrs(i); - funcOp.setArgAttrs(i, convertParameterAttribute(llvmArgAttrs, builder)); + funcOp.setArgAttrs(i, convertArgOrResultAttrSet(llvmArgAttrs)); } // Convert the result attributes and attach them wrapped in an ArrayAttribute // to the funcOp. @@ -2783,17 +2794,23 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func, if (!llvmResAttr.hasAttributes()) return; funcOp.setResAttrsAttr( - builder.getArrayAttr(convertParameterAttribute(llvmResAttr, builder))); + builder.getArrayAttr({convertArgOrResultAttrSet(llvmResAttr)})); } -void ModuleImport::convertParameterAttributes(llvm::CallBase *call, - ArrayAttr &argsAttr, - ArrayAttr &resAttr, - OpBuilder &builder) { +void ModuleImport::convertArgAndResultAttrs( + llvm::CallBase *call, ArgAndResultAttrsOpInterface attrsOp, + ArrayRef<unsigned> immArgPositions) { + // Compute the set of immediate argument positions. + llvm::SmallDenseSet<unsigned> immArgPositionsSet(immArgPositions.begin(), + immArgPositions.end()); + // Convert the argument attributes and filter out immediate arguments. llvm::AttributeList llvmAttrs = call->getAttributes(); SmallVector<llvm::AttributeSet> llvmArgAttrsSet; bool anyArgAttrs = false; for (size_t i = 0, e = call->arg_size(); i < e; ++i) { + // Skip immediate arguments. + if (immArgPositionsSet.contains(i)) + continue; llvmArgAttrsSet.emplace_back(llvmAttrs.getParamAttrs(i)); if (llvmArgAttrsSet.back().hasAttributes()) anyArgAttrs = true; @@ -2807,24 +2824,16 @@ void ModuleImport::convertParameterAttributes(llvm::CallBase *call, if (anyArgAttrs) { SmallVector<DictionaryAttr> argAttrs; for (auto &llvmArgAttrs : llvmArgAttrsSet) - argAttrs.emplace_back(convertParameterAttribute(llvmArgAttrs, builder)); - argsAttr = getArrayAttr(argAttrs); + argAttrs.emplace_back(convertArgOrResultAttrSet(llvmArgAttrs)); + attrsOp.setArgAttrsAttr(getArrayAttr(argAttrs)); } + // Convert the result attributes. llvm::AttributeSet llvmResAttr = llvmAttrs.getRetAttrs(); if (!llvmResAttr.hasAttributes()) return; - DictionaryAttr resAttrs = convertParameterAttribute(llvmResAttr, builder); - resAttr = getArrayAttr({resAttrs}); -} - -void ModuleImport::convertParameterAttributes(llvm::CallBase *call, - CallOpInterface callOp, - OpBuilder &builder) { - ArrayAttr argsAttr, resAttr; - convertParameterAttributes(call, argsAttr, resAttr, builder); - callOp.setArgAttrsAttr(argsAttr); - callOp.setResAttrsAttr(resAttr); + DictionaryAttr resAttrs = convertArgOrResultAttrSet(llvmResAttr); + attrsOp.setResAttrsAttr(getArrayAttr({resAttrs})); } template <typename Op> @@ -2892,7 +2901,7 @@ LogicalResult ModuleImport::processFunction(llvm::Function *func) { builder, loc, func->getName(), functionType, convertLinkageFromLLVM(func->getLinkage()), dsoLocal, cconv); - convertParameterAttributes(func, funcOp, builder); + convertArgAndResultAttrs(func, funcOp); if (FlatSymbolRefAttr personality = getPersonalityAsAttr(func)) funcOp.setPersonalityAttr(personality); @@ -3199,5 +3208,6 @@ OwningOpRef<ModuleOp> mlir::translateLLVMIRToModule( if (failed(moduleImport.convertIFuncs())) return {}; moduleImport.convertTargetTriple(); + moduleImport.convertModuleLevelAsm(); return module; } diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index b997e55..b3a06e2 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -1758,6 +1758,48 @@ ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx, return attrBuilder; } +LogicalResult ModuleTranslation::convertArgAndResultAttrs( + ArgAndResultAttrsOpInterface attrsOp, llvm::CallBase *call, + ArrayRef<unsigned> immArgPositions) { + // Convert the argument attributes. + if (ArrayAttr argAttrsArray = attrsOp.getArgAttrsAttr()) { + unsigned argAttrIdx = 0; + llvm::SmallDenseSet<unsigned> immArgPositionsSet(immArgPositions.begin(), + immArgPositions.end()); + for (unsigned argIdx : llvm::seq<unsigned>(call->arg_size())) { + if (argAttrIdx >= argAttrsArray.size()) + break; + // Skip immediate arguments (they have no entries in argAttrsArray). + if (immArgPositionsSet.contains(argIdx)) + continue; + // Skip empty argument attributes. + auto argAttrs = cast<DictionaryAttr>(argAttrsArray[argAttrIdx++]); + if (argAttrs.empty()) + continue; + // Convert and add attributes to the call instruction. + FailureOr<llvm::AttrBuilder> attrBuilder = + convertParameterAttrs(attrsOp->getLoc(), argAttrs); + if (failed(attrBuilder)) + return failure(); + call->addParamAttrs(argIdx, *attrBuilder); + } + } + + // Convert the result attributes. + if (ArrayAttr resAttrsArray = attrsOp.getResAttrsAttr()) { + if (!resAttrsArray.empty()) { + auto resAttrs = cast<DictionaryAttr>(resAttrsArray[0]); + FailureOr<llvm::AttrBuilder> attrBuilder = + convertParameterAttrs(attrsOp->getLoc(), resAttrs); + if (failed(attrBuilder)) + return failure(); + call->addRetAttrs(*attrBuilder); + } + } + + return success(); +} + FailureOr<llvm::AttrBuilder> ModuleTranslation::convertParameterAttrs(Location loc, DictionaryAttr paramAttrs) { @@ -2276,6 +2318,25 @@ prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext, llvmModule->setTargetTriple( llvm::Triple(cast<StringAttr>(targetTripleAttr).getValue())); + if (auto asmAttr = m->getDiscardableAttr( + LLVM::LLVMDialect::getModuleLevelAsmAttrName())) { + auto asmArrayAttr = dyn_cast<ArrayAttr>(asmAttr); + if (!asmArrayAttr) { + m->emitError("expected an array attribute for a module level asm"); + return nullptr; + } + + for (Attribute elt : asmArrayAttr) { + auto asmStrAttr = dyn_cast<StringAttr>(elt); + if (!asmStrAttr) { + m->emitError( + "expected a string attribute for each entry of a module level asm"); + return nullptr; + } + llvmModule->appendModuleInlineAsm(asmStrAttr.getValue()); + } + } + return llvmModule; } diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index e5934bb..d0ae513 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -347,10 +347,6 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) { return emitError(unknownLoc, "OpDecoration with ") << decorationName << "needs a single target <id>"; } - // Block decoration does not affect spirv.struct type, but is still stored - // for verification. - // TODO: Update StructType to contain this information since - // it is needed for many validation rules. decorations[words[0]].set(symbol, opBuilder.getUnitAttr()); break; case spirv::Decoration::Location: @@ -993,7 +989,8 @@ spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) { if (failed(structType.trySetBody( deferredStructIt->memberTypes, deferredStructIt->offsetInfo, - deferredStructIt->memberDecorationsInfo))) + deferredStructIt->memberDecorationsInfo, + deferredStructIt->structDecorationsInfo))) return failure(); deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt); @@ -1203,24 +1200,37 @@ spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) { } } + SmallVector<spirv::StructType::StructDecorationInfo, 0> structDecorationsInfo; + if (decorations.count(operands[0])) { + NamedAttrList &allDecorations = decorations[operands[0]]; + for (NamedAttribute &decorationAttr : allDecorations) { + std::optional<spirv::Decoration> decoration = spirv::symbolizeDecoration( + llvm::convertToCamelFromSnakeCase(decorationAttr.getName(), true)); + assert(decoration.has_value()); + structDecorationsInfo.emplace_back(decoration.value(), + decorationAttr.getValue()); + } + } + uint32_t structID = operands[0]; std::string structIdentifier = nameMap.lookup(structID).str(); if (structIdentifier.empty()) { assert(unresolvedMemberTypes.empty() && "didn't expect unresolved member types"); - typeMap[structID] = - spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo); + typeMap[structID] = spirv::StructType::get( + memberTypes, offsetInfo, memberDecorationsInfo, structDecorationsInfo); } else { auto structTy = spirv::StructType::getIdentified(context, structIdentifier); typeMap[structID] = structTy; if (!unresolvedMemberTypes.empty()) - deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes, - memberTypes, offsetInfo, - memberDecorationsInfo}); + deferredStructTypesInfos.push_back( + {structTy, unresolvedMemberTypes, memberTypes, offsetInfo, + memberDecorationsInfo, structDecorationsInfo}); else if (failed(structTy.trySetBody(memberTypes, offsetInfo, - memberDecorationsInfo))) + memberDecorationsInfo, + structDecorationsInfo))) return failure(); } @@ -1769,7 +1779,7 @@ LogicalResult spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) { if (operands.size() != 2) { return emitError(unknownLoc, - "OpConstantNull must have type <id> and result <id>"); + "OpConstantNull must only have type <id> and result <id>"); } Type resultType = getType(operands[0]); @@ -1779,8 +1789,15 @@ spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) { } auto resultID = operands[1]; + Attribute attr; if (resultType.isIntOrFloat() || isa<VectorType>(resultType)) { - auto attr = opBuilder.getZeroAttr(resultType); + attr = opBuilder.getZeroAttr(resultType); + } else if (auto tensorType = dyn_cast<TensorArmType>(resultType)) { + if (auto element = opBuilder.getZeroAttr(tensorType.getElementType())) + attr = DenseElementsAttr::get(tensorType, element); + } + + if (attr) { // For normal constants, we just record the attribute (and its type) for // later materialization at use sites. constantMap.try_emplace(resultID, attr, resultType); diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h index 20482bd..db1cc3f 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h @@ -95,6 +95,7 @@ struct DeferredStructTypeInfo { SmallVector<Type, 4> memberTypes; SmallVector<spirv::StructType::OffsetInfo, 0> offsetInfo; SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo; + SmallVector<spirv::StructType::StructDecorationInfo, 0> structDecorationsInfo; }; /// A struct that collects the info needed to materialize/emit a diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index a8a2b2e..3053663 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -69,6 +69,25 @@ static Block *getPhiIncomingBlock(Block *block) { return block; } +static bool isZeroValue(Attribute attr) { + if (auto floatAttr = dyn_cast<FloatAttr>(attr)) { + return floatAttr.getValue().isZero(); + } + if (auto boolAttr = dyn_cast<BoolAttr>(attr)) { + return !boolAttr.getValue(); + } + if (auto intAttr = dyn_cast<IntegerAttr>(attr)) { + return intAttr.getValue().isZero(); + } + if (auto splatElemAttr = dyn_cast<SplatElementsAttr>(attr)) { + return isZeroValue(splatElemAttr.getSplatValue<Attribute>()); + } + if (auto denseElemAttr = dyn_cast<DenseElementsAttr>(attr)) { + return all_of(denseElemAttr.getValues<Attribute>(), isZeroValue); + } + return false; +} + namespace mlir { namespace spirv { @@ -318,6 +337,7 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID, case spirv::Decoration::RestrictPointer: case spirv::Decoration::NoContraction: case spirv::Decoration::Constant: + case spirv::Decoration::Block: // For unit attributes and decoration attributes, the args list // has no values so we do nothing. if (isa<UnitAttr, DecorationAttr>(attr)) @@ -630,11 +650,16 @@ LogicalResult Serializer::prepareBasicType( operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass())); operands.push_back(pointeeTypeID); + // TODO: Now struct decorations are supported this code may not be + // necessary. However, it is left to support backwards compatibility. + // Ideally, Block decorations should be inserted when converting to SPIR-V. if (isInterfaceStructPtrType(ptrType)) { - if (failed(emitDecoration(getTypeID(pointeeStruct), - spirv::Decoration::Block))) - return emitError(loc, "cannot decorate ") - << pointeeStruct << " with Block decoration"; + auto structType = cast<spirv::StructType>(ptrType.getPointeeType()); + if (!structType.hasDecoration(spirv::Decoration::Block)) + if (failed(emitDecoration(getTypeID(pointeeStruct), + spirv::Decoration::Block))) + return emitError(loc, "cannot decorate ") + << pointeeStruct << " with Block decoration"; } return success(); @@ -704,6 +729,20 @@ LogicalResult Serializer::prepareBasicType( } } + SmallVector<spirv::StructType::StructDecorationInfo, 1> structDecorations; + structType.getStructDecorations(structDecorations); + + for (spirv::StructType::StructDecorationInfo &structDecoration : + structDecorations) { + if (failed(processDecorationAttr(loc, resultID, + structDecoration.decoration, + structDecoration.decorationValue))) { + return emitError(loc, "cannot decorate struct ") + << structType << " with " + << stringifyDecoration(structDecoration.decoration); + } + } + typeEnum = spirv::Opcode::OpTypeStruct; if (structType.isIdentified()) @@ -938,6 +977,30 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType, } else { return 0; } + } else if (isa<spirv::TensorArmType>(constType)) { + if (isZeroValue(valueAttr)) { + encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull, + {typeID, resultID}); + return resultID; + } + numberOfConstituents = shapedType.getNumElements(); + operands.reserve(numberOfConstituents + 2); + for (int i = 0; i < numberOfConstituents; ++i) { + uint32_t elementID = 0; + if (auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) { + elementID = + elementType.isInteger(1) + ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[i]) + : prepareConstantInt(loc, attr.getValues<IntegerAttr>()[i]); + } + if (auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) { + elementID = prepareConstantFp(loc, attr.getValues<FloatAttr>()[i]); + } + if (!elementID) { + return 0; + } + operands.push_back(elementID); + } } else { operands.reserve(numberOfConstituents + 2); for (int i = 0; i < numberOfConstituents; ++i) { @@ -1124,6 +1187,21 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr, return resultID; } +// Returns type of attribute. In case of a TypedAttr this will simply return +// the type. But for an ArrayAttr which is untyped and can be multidimensional +// it creates the ArrayType recursively. +static Type getValueType(Attribute attr) { + if (auto typedAttr = dyn_cast<TypedAttr>(attr)) { + return typedAttr.getType(); + } + + if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) { + return spirv::ArrayType::get(getValueType(arrayAttr[0]), arrayAttr.size()); + } + + return nullptr; +} + uint32_t Serializer::prepareConstantCompositeReplicate(Location loc, Type resultType, Attribute valueAttr) { @@ -1137,18 +1215,9 @@ uint32_t Serializer::prepareConstantCompositeReplicate(Location loc, return 0; } - Type valueType; - if (auto typedAttr = dyn_cast<TypedAttr>(valueAttr)) { - valueType = typedAttr.getType(); - } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) { - auto typedElemAttr = dyn_cast<TypedAttr>(arrayAttr[0]); - if (!typedElemAttr) - return 0; - valueType = - spirv::ArrayType::get(typedElemAttr.getType(), arrayAttr.size()); - } else { + Type valueType = getValueType(valueAttr); + if (!valueAttr) return 0; - } auto compositeType = dyn_cast<CompositeType>(resultType); if (!compositeType) @@ -1163,11 +1232,14 @@ uint32_t Serializer::prepareConstantCompositeReplicate(Location loc, } uint32_t resultID = getNextID(); - uint32_t operands[] = {typeID, resultID, constandID}; - - encodeInstructionInto(typesGlobalValues, - spirv::Opcode::OpConstantCompositeReplicateEXT, - operands); + if (dyn_cast<spirv::TensorArmType>(resultType) && isZeroValue(valueAttr)) { + encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull, + {typeID, resultID}); + } else { + encodeInstructionInto(typesGlobalValues, + spirv::Opcode::OpConstantCompositeReplicateEXT, + {typeID, resultID, constandID}); + } constCompositeReplicateIDMap[valueTypePair] = resultID; return resultID; diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp index 8f78590..de714d8b 100644 --- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp +++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp @@ -501,20 +501,26 @@ performActions(raw_ostream &os, << "bytecode version while not emitting bytecode"; AsmState asmState(op.get(), OpPrintingFlags(), /*locationMap=*/nullptr, &fallbackResourceMap); - op.get()->print(os, asmState); - os << '\n'; + os << OpWithState(op.get(), asmState) << '\n'; return success(); } /// Parses the memory buffer. If successfully, run a series of passes against /// it and print the result. -static LogicalResult processBuffer(raw_ostream &os, - std::unique_ptr<MemoryBuffer> ownedBuffer, - const MlirOptMainConfig &config, - DialectRegistry ®istry, - llvm::ThreadPoolInterface *threadPool) { +static LogicalResult +processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer, + llvm::MemoryBufferRef sourceBuffer, + const MlirOptMainConfig &config, DialectRegistry ®istry, + SourceMgrDiagnosticVerifierHandler *verifyHandler, + llvm::ThreadPoolInterface *threadPool) { // Tell sourceMgr about this buffer, which is what the parser will pick up. auto sourceMgr = std::make_shared<SourceMgr>(); + // Add the original buffer to the source manager to use for determining + // locations. + sourceMgr->AddNewSourceBuffer( + llvm::MemoryBuffer::getMemBuffer(sourceBuffer, + /*RequiresNullTerminator=*/false), + SMLoc()); sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc()); // Create a context just for the current buffer. Disable threading on creation @@ -522,6 +528,8 @@ static LogicalResult processBuffer(raw_ostream &os, MLIRContext context(registry, MLIRContext::Threading::DISABLED); if (threadPool) context.setThreadPool(*threadPool); + if (verifyHandler) + verifyHandler->registerInContext(&context); StringRef irdlFile = config.getIrdlFile(); if (!irdlFile.empty() && failed(loadIRDLDialects(irdlFile, context))) @@ -545,17 +553,12 @@ static LogicalResult processBuffer(raw_ostream &os, return performActions(os, sourceMgr, &context, config); } - SourceMgrDiagnosticVerifierHandler sourceMgrHandler( - *sourceMgr, &context, config.verifyDiagnosticsLevel()); - // Do any processing requested by command line flags. We don't care whether // these actions succeed or fail, we only care what diagnostics they produce // and whether they match our expectations. (void)performActions(os, sourceMgr, &context, config); - // Verify the diagnostic handler to make sure that each of the diagnostics - // matched. - return sourceMgrHandler.verify(); + return success(); } std::pair<std::string, std::string> @@ -624,14 +627,31 @@ LogicalResult mlir::MlirOptMain(llvm::raw_ostream &outputStream, if (threadPoolCtx.isMultithreadingEnabled()) threadPool = &threadPoolCtx.getThreadPool(); + SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer( + llvm::MemoryBuffer::getMemBuffer(buffer->getMemBufferRef(), + /*RequiresNullTerminator=*/false), + SMLoc()); + // Note: this creates a verifier handler independent of the the flag set, as + // internally if the flag is not set, a new scoped diagnostic handler is + // created which would intercept the diagnostics and verify them. + SourceMgrDiagnosticVerifierHandler sourceMgrHandler( + sourceMgr, &threadPoolCtx, config.verifyDiagnosticsLevel()); auto chunkFn = [&](std::unique_ptr<MemoryBuffer> chunkBuffer, - raw_ostream &os) { - return processBuffer(os, std::move(chunkBuffer), config, registry, - threadPool); + llvm::MemoryBufferRef sourceBuffer, raw_ostream &os) { + return processBuffer( + os, std::move(chunkBuffer), sourceBuffer, config, registry, + config.shouldVerifyDiagnostics() ? &sourceMgrHandler : nullptr, + threadPool); }; - return splitAndProcessBuffer(std::move(buffer), chunkFn, outputStream, - config.inputSplitMarker(), - config.outputSplitMarker()); + LogicalResult status = splitAndProcessBuffer( + llvm::MemoryBuffer::getMemBuffer(buffer->getMemBufferRef(), + /*RequiresNullTerminator=*/false), + chunkFn, outputStream, config.inputSplitMarker(), + config.outputSplitMarker()); + if (config.shouldVerifyDiagnostics() && failed(sourceMgrHandler.verify())) + status = failure(); + return status; } LogicalResult mlir::MlirOptMain(int argc, char **argv, diff --git a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp index c11cb8d..e1c8afb 100644 --- a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp +++ b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp @@ -135,6 +135,13 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv, // Processes the memory buffer with a new MLIRContext. auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer, raw_ostream &os) { + // Many of the translations expect a null-terminated buffer while splitting + // the buffer does not guarantee null-termination. Make a copy of the buffer + // to ensure null-termination. + if (!ownedBuffer->getBuffer().ends_with('\0')) { + ownedBuffer = llvm::MemoryBuffer::getMemBufferCopy( + ownedBuffer->getBuffer(), ownedBuffer->getBufferIdentifier()); + } // Temporary buffers for chained translation processing. std::string dataIn; std::string dataOut; diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 08803e0..f23c619 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -20,6 +20,7 @@ #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/ScopedPrinter.h" @@ -1129,8 +1130,13 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// verification. SmallPtrSet<Operation *, 1> pendingRootUpdates; + /// A raw output stream used to prefix the debug log. + llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + "] ").str(), + llvm::dbgs(), /*HasPendingNewline=*/false}; + /// A logger used to emit diagnostics during the conversion process. - llvm::ScopedPrinter logger{llvm::dbgs()}; + llvm::ScopedPrinter logger{os}; + std::string logPrefix; #endif }; } // namespace detail diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir index 1abe0fd..6e2352e 100644 --- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir +++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir @@ -559,6 +559,23 @@ func.func @constant() { return } +// CHECK-LABEL: @constant_8bit_float +func.func @constant_8bit_float() { + // CHECK: spirv.Constant 56 : i8 + %cst = arith.constant 1.0 : f8E4M3 + // CHECK: spirv.Constant 56 : i8 + %cst_i8 = arith.bitcast %cst : f8E4M3 to i8 + // CHECK: spirv.Constant dense<56> : vector<4xi8> + %cst_vector = arith.constant dense<1.0> : vector<4xf8E4M3> + // CHECK: spirv.Constant dense<56> : vector<4xi8> + %cst_vector_i8 = arith.bitcast %cst_vector : vector<4xf8E4M3> to vector<4xi8> + // CHECK: spirv.Constant dense<60> : tensor<4xi8> : !spirv.array<4 x i8> + %cst_tensor = arith.constant dense<1.0> : tensor<4xf8E5M2> + // CHECK: spirv.Constant dense<60> : tensor<4xi8> : !spirv.array<4 x i8> + %cst_tensor_i8 = arith.bitcast %cst_tensor : tensor<4xf8E5M2> to tensor<4xi8> + return +} + // CHECK-LABEL: @constant_16bit func.func @constant_16bit() { // CHECK: spirv.Constant 4 : i16 diff --git a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir index bae7c59..ae59f28 100644 --- a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir +++ b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir @@ -2,8 +2,26 @@ // CHECK-DAG: @__ocml_cabs_f32(complex<f32>) -> f32 // CHECK-DAG: @__ocml_cabs_f64(complex<f64>) -> f64 +// CHECK-DAG: @__ocml_carg_f32(complex<f32>) -> f32 +// CHECK-DAG: @__ocml_carg_f64(complex<f64>) -> f64 +// CHECK-DAG: @__ocml_ccos_f32(complex<f32>) -> complex<f32> +// CHECK-DAG: @__ocml_ccos_f64(complex<f64>) -> complex<f64> // CHECK-DAG: @__ocml_cexp_f32(complex<f32>) -> complex<f32> // CHECK-DAG: @__ocml_cexp_f64(complex<f64>) -> complex<f64> +// CHECK-DAG: @__ocml_clog_f32(complex<f32>) -> complex<f32> +// CHECK-DAG: @__ocml_clog_f64(complex<f64>) -> complex<f64> +// CHECK-DAG: @__ocml_conj_f32(complex<f32>) -> complex<f32> +// CHECK-DAG: @__ocml_conj_f64(complex<f64>) -> complex<f64> +// CHECK-DAG: @__ocml_cpow_f32(complex<f32>, complex<f32>) -> complex<f32> +// CHECK-DAG: @__ocml_cpow_f64(complex<f64>, complex<f64>) -> complex<f64> +// CHECK-DAG: @__ocml_csin_f32(complex<f32>) -> complex<f32> +// CHECK-DAG: @__ocml_csin_f64(complex<f64>) -> complex<f64> +// CHECK-DAG: @__ocml_csqrt_f32(complex<f32>) -> complex<f32> +// CHECK-DAG: @__ocml_csqrt_f64(complex<f64>) -> complex<f64> +// CHECK-DAG: @__ocml_ctan_f32(complex<f32>) -> complex<f32> +// CHECK-DAG: @__ocml_ctan_f64(complex<f64>) -> complex<f64> +// CHECK-DAG: @__ocml_ctanh_f32(complex<f32>) -> complex<f32> +// CHECK-DAG: @__ocml_ctanh_f64(complex<f64>) -> complex<f64> //CHECK-LABEL: @abs_caller func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) { @@ -15,6 +33,26 @@ func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) { return %rf, %rd : f32, f64 } +//CHECK-LABEL: @angle_caller +func.func @angle_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) { + // CHECK: %[[AF:.*]] = call @__ocml_carg_f32(%{{.*}}) + %af = complex.angle %f : complex<f32> + // CHECK: %[[AD:.*]] = call @__ocml_carg_f64(%{{.*}}) + %ad = complex.angle %d : complex<f64> + // CHECK: return %[[AF]], %[[AD]] + return %af, %ad : f32, f64 +} + +//CHECK-LABEL: @cos_caller +func.func @cos_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { + // CHECK: %[[CF:.*]] = call @__ocml_ccos_f32(%{{.*}}) + %cf = complex.cos %f : complex<f32> + // CHECK: %[[CD:.*]] = call @__ocml_ccos_f64(%{{.*}}) + %cd = complex.cos %d : complex<f64> + // CHECK: return %[[CF]], %[[CD]] + return %cf, %cd : complex<f32>, complex<f64> +} + //CHECK-LABEL: @exp_caller func.func @exp_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { // CHECK: %[[EF:.*]] = call @__ocml_cexp_f32(%{{.*}}) @@ -24,3 +62,73 @@ func.func @exp_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, comp // CHECK: return %[[EF]], %[[ED]] return %ef, %ed : complex<f32>, complex<f64> } + +//CHECK-LABEL: @log_caller +func.func @log_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { + // CHECK: %[[LF:.*]] = call @__ocml_clog_f32(%{{.*}}) + %lf = complex.log %f : complex<f32> + // CHECK: %[[LD:.*]] = call @__ocml_clog_f64(%{{.*}}) + %ld = complex.log %d : complex<f64> + // CHECK: return %[[LF]], %[[LD]] + return %lf, %ld : complex<f32>, complex<f64> +} + +//CHECK-LABEL: @conj_caller +func.func @conj_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { + // CHECK: %[[CF:.*]] = call @__ocml_conj_f32(%{{.*}}) + %cf2 = complex.conj %f : complex<f32> + // CHECK: %[[CD:.*]] = call @__ocml_conj_f64(%{{.*}}) + %cd2 = complex.conj %d : complex<f64> + // CHECK: return %[[CF]], %[[CD]] + return %cf2, %cd2 : complex<f32>, complex<f64> +} + +//CHECK-LABEL: @pow_caller +func.func @pow_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { + // CHECK: %[[PF:.*]] = call @__ocml_cpow_f32(%{{.*}}, %{{.*}}) + %pf = complex.pow %f, %f : complex<f32> + // CHECK: %[[PD:.*]] = call @__ocml_cpow_f64(%{{.*}}, %{{.*}}) + %pd = complex.pow %d, %d : complex<f64> + // CHECK: return %[[PF]], %[[PD]] + return %pf, %pd : complex<f32>, complex<f64> +} + +//CHECK-LABEL: @sin_caller +func.func @sin_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { + // CHECK: %[[SF:.*]] = call @__ocml_csin_f32(%{{.*}}) + %sf2 = complex.sin %f : complex<f32> + // CHECK: %[[SD:.*]] = call @__ocml_csin_f64(%{{.*}}) + %sd2 = complex.sin %d : complex<f64> + // CHECK: return %[[SF]], %[[SD]] + return %sf2, %sd2 : complex<f32>, complex<f64> +} + +//CHECK-LABEL: @sqrt_caller +func.func @sqrt_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { + // CHECK: %[[SF:.*]] = call @__ocml_csqrt_f32(%{{.*}}) + %sf = complex.sqrt %f : complex<f32> + // CHECK: %[[SD:.*]] = call @__ocml_csqrt_f64(%{{.*}}) + %sd = complex.sqrt %d : complex<f64> + // CHECK: return %[[SF]], %[[SD]] + return %sf, %sd : complex<f32>, complex<f64> +} + +//CHECK-LABEL: @tan_caller +func.func @tan_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { + // CHECK: %[[TF:.*]] = call @__ocml_ctan_f32(%{{.*}}) + %tf2 = complex.tan %f : complex<f32> + // CHECK: %[[TD:.*]] = call @__ocml_ctan_f64(%{{.*}}) + %td2 = complex.tan %d : complex<f64> + // CHECK: return %[[TF]], %[[TD]] + return %tf2, %td2 : complex<f32>, complex<f64> +} + +//CHECK-LABEL: @tanh_caller +func.func @tanh_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { + // CHECK: %[[TF:.*]] = call @__ocml_ctanh_f32(%{{.*}}) + %tf = complex.tanh %f : complex<f32> + // CHECK: %[[TD:.*]] = call @__ocml_ctanh_f64(%{{.*}}) + %td = complex.tanh %d : complex<f64> + // CHECK: return %[[TF]], %[[TD]] + return %tf, %td : complex<f32>, complex<f64> +} diff --git a/mlir/test/Conversion/ConvertToSPIRV/convert-gpu-modules.mlir b/mlir/test/Conversion/ConvertToSPIRV/convert-gpu-modules.mlir index 00bbd1c..96ad107 100644 --- a/mlir/test/Conversion/ConvertToSPIRV/convert-gpu-modules.mlir +++ b/mlir/test/Conversion/ConvertToSPIRV/convert-gpu-modules.mlir @@ -85,11 +85,10 @@ module attributes { // CHECK: spirv.Load "StorageBuffer" %val = memref.load %arg0[%idx0] : memref<2xi32> // CHECK: spirv.CompositeInsert - %vec = vector.insertelement %val, %vec0[%idx0 : index] : vector<2xi32> + %vec = vector.insert %val, %vec0[%idx0] : i32 into vector<2xi32> // CHECK: spirv.VectorShuffle %shuffle = vector.shuffle %vec, %vec[3, 2, 1, 0] : vector<2xi32>, vector<2xi32> - // CHECK: spirv.CompositeExtract - %res = vector.extractelement %shuffle[%idx0 : index] : vector<4xi32> + %res = vector.extract %shuffle[%idx0] : i32 from vector<4xi32> // CHECK: spirv.AccessChain // CHECK: spirv.Store "StorageBuffer" memref.store %res, %arg1[%idx0]: memref<4xi32> @@ -102,9 +101,9 @@ module attributes { // CHECK-SAME: %{{.*}}: memref<2xi32>, %{{.*}}: memref<4xi32> // CHECK: arith.constant // CHECK: memref.load - // CHECK: vector.insertelement + // CHECK: vector.insert // CHECK: vector.shuffle - // CHECK: vector.extractelement + // CHECK: vector.extract // CHECK: memref.store // CHECK: gpu.return } diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir index fb14feb..eb9feaa 100644 --- a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir +++ b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir @@ -51,108 +51,6 @@ func.func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: f32) -> vector<1xf3 // ----- -// CHECK-LABEL: @extract_element -// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32 -// CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32 -func.func @extract_element(%arg0 : vector<4xf32>, %id : i32) -> f32 { - %0 = vector.extractelement %arg0[%id : i32] : vector<4xf32> - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @extract_element_cst -// CHECK-SAME: %[[V:.*]]: vector<4xf32> -// CHECK: spirv.CompositeExtract %[[V]][1 : i32] : vector<4xf32> -func.func @extract_element_cst(%arg0 : vector<4xf32>) -> f32 { - %idx = arith.constant 1 : i32 - %0 = vector.extractelement %arg0[%idx : i32] : vector<4xf32> - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @extract_element_index -func.func @extract_element_index(%arg0 : vector<4xf32>, %id : index) -> f32 { - // CHECK: spirv.VectorExtractDynamic - %0 = vector.extractelement %arg0[%id : index] : vector<4xf32> - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @extract_element_size1_vector -// CHECK-SAME:(%[[S:.+]]: f32, -func.func @extract_element_size1_vector(%arg0 : f32, %i: index) -> f32 { - %bcast = vector.broadcast %arg0 : f32 to vector<1xf32> - %0 = vector.extractelement %bcast[%i : index] : vector<1xf32> - // CHECK: spirv.ReturnValue %[[S]] - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @extract_element_0d_vector -// CHECK-SAME: (%[[S:.+]]: f32) -func.func @extract_element_0d_vector(%arg0 : f32) -> f32 { - %bcast = vector.broadcast %arg0 : f32 to vector<f32> - %0 = vector.extractelement %bcast[] : vector<f32> - // CHECK: spirv.ReturnValue %[[S]] - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @insert_element -// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32 -// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32 -func.func @insert_element(%val: f32, %arg0 : vector<4xf32>, %id : i32) -> vector<4xf32> { - %0 = vector.insertelement %val, %arg0[%id : i32] : vector<4xf32> - return %0: vector<4xf32> -} - -// ----- - -// CHECK-LABEL: @insert_element_cst -// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32> -// CHECK: spirv.CompositeInsert %[[VAL]], %[[V]][2 : i32] : f32 into vector<4xf32> -func.func @insert_element_cst(%val: f32, %arg0 : vector<4xf32>) -> vector<4xf32> { - %idx = arith.constant 2 : i32 - %0 = vector.insertelement %val, %arg0[%idx : i32] : vector<4xf32> - return %0: vector<4xf32> -} - -// ----- - -// CHECK-LABEL: @insert_element_index -func.func @insert_element_index(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> { - // CHECK: spirv.VectorInsertDynamic - %0 = vector.insertelement %val, %arg0[%id : index] : vector<4xf32> - return %0: vector<4xf32> -} - -// ----- - -// CHECK-LABEL: @insert_element_size1_vector -// CHECK-SAME: (%[[S:[a-z0-9]+]]: f32 -func.func @insert_element_size1_vector(%scalar: f32, %vector : vector<1xf32>, %i: index) -> vector<1xf32> { - %0 = vector.insertelement %scalar, %vector[%i : index] : vector<1xf32> - // CHECK: spirv.ReturnValue %[[S]] - return %0: vector<1xf32> -} - -// ----- - -// CHECK-LABEL: @insert_element_0d_vector -// CHECK-SAME: (%[[S:[a-z0-9]+]]: f32 -func.func @insert_element_0d_vector(%scalar: f32, %vector : vector<f32>) -> vector<f32> { - %0 = vector.insertelement %scalar, %vector[] : vector<f32> - // CHECK: spirv.ReturnValue %[[S]] - return %0: vector<f32> -} - -// ----- - // CHECK-LABEL: @insert_size1_vector // CHECK-SAME: %[[SUB:.*]]: f32, %[[FULL:.*]]: vector<3xf32> // CHECK: %[[RET:.*]] = spirv.CompositeInsert %[[SUB]], %[[FULL]][2 : i32] : f32 into vector<3xf32> diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir index 1737f4a..0c77c88 100644 --- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir +++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir @@ -1,6 +1,8 @@ // RUN: mlir-opt -split-input-file -convert-func-to-spirv %s -o - | FileCheck %s // RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-lt-32-bit-scalar-types=false" %s | \ // RUN: FileCheck %s --check-prefix=NOEMU +// RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-unsupported-float-types=false" %s | \ +// RUN: FileCheck %s --check-prefix=UNSUPPORTED_FLOAT //===----------------------------------------------------------------------===// // Integer types @@ -944,3 +946,55 @@ func.func @unranked_tensor(%arg0: tensor<*xi32>) { return } func.func @dynamic_dim_tensor(%arg0: tensor<8x?xi32>) { return } } // end module + + +// ----- + +// Check that 8-bit float types are emulated as i8. +module attributes { + spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int8], []>, #spirv.resource_limits<>> +} { + + // CHECK: spirv.func @float8_to_integer8 + // CHECK-SAME: (%arg0: i8 + // CHECK-SAME: %arg1: i8 + // CHECK-SAME: %arg2: i8 + // CHECK-SAME: %arg3: i8 + // CHECK-SAME: %arg4: i8 + // CHECK-SAME: %arg5: i8 + // CHECK-SAME: %arg6: i8 + // CHECK-SAME: %arg7: i8 + // CHECK-SAME: %arg8: vector<4xi8> + // CHECK-SAME: %arg9: !spirv.ptr<!spirv.struct<(!spirv.array<8 x i8, stride=1> [0])>, StorageBuffer> + // CHECK-SAME: %arg10: !spirv.array<4 x i8> + // UNSUPPORTED_FLOAT-LABEL: func.func @float8_to_integer8 + // UNSUPPORTED_FLOAT-SAME: (%arg0: f8E5M2 + // UNSUPPORTED_FLOAT-SAME: %arg1: f8E4M3 + // UNSUPPORTED_FLOAT-SAME: %arg2: f8E4M3FN + // UNSUPPORTED_FLOAT-SAME: %arg3: f8E5M2FNUZ + // UNSUPPORTED_FLOAT-SAME: %arg4: f8E4M3FNUZ + // UNSUPPORTED_FLOAT-SAME: %arg5: f8E4M3B11FNUZ + // UNSUPPORTED_FLOAT-SAME: %arg6: f8E3M4 + // UNSUPPORTED_FLOAT-SAME: %arg7: f8E8M0FNU + // UNSUPPORTED_FLOAT-SAME: %arg8: vector<4xf8E4M3B11FNUZ> + // UNSUPPORTED_FLOAT-SAME: %arg9: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>> + // UNSUPPORTED_FLOAT-SAME: %arg10: tensor<4xf8E5M2> + // UNSUPPORTED_FLOAT-SAME: ) { + + func.func @float8_to_integer8( + %arg0: f8E5M2, // CHECK-NOT: f8E5M2 + %arg1: f8E4M3, // CHECK-NOT: f8E4M3 + %arg2: f8E4M3FN, // CHECK-NOT: f8E4M3FN + %arg3: f8E5M2FNUZ, // CHECK-NOT: f8E5M2FNUZ + %arg4: f8E4M3FNUZ, // CHECK-NOT: f8E4M3FNUZ + %arg5: f8E4M3B11FNUZ, // CHECK-NOT: f8E4M3B11FNUZ + %arg6: f8E3M4, // CHECK-NOT: f8E3M4 + %arg7: f8E8M0FNU, // CHECK-NOT: f8E8M0FNU + %arg8: vector<4xf8E4M3B11FNUZ>, // CHECK-NOT: vector<4xf8E4M3B11FNUZ> + %arg9: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>>, // CHECK-NOT: memref + %arg10: tensor<4xf8E5M2> // CHECK-NOT: tensor + ) { + // CHECK: spirv.Return + return + } +} diff --git a/mlir/test/Conversion/GPUToSPIRV/lookup-target-env.mlir b/mlir/test/Conversion/GPUToSPIRV/lookup-target-env.mlir new file mode 100644 index 0000000..983747b --- /dev/null +++ b/mlir/test/Conversion/GPUToSPIRV/lookup-target-env.mlir @@ -0,0 +1,40 @@ +// RUN: mlir-opt --split-input-file --convert-gpu-to-spirv %s | FileCheck %s + +module attributes {gpu.container_module} { + // CHECK-LABEL: spirv.module @{{.*}} GLSL450 + gpu.module @kernels [#spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>] { + // CHECK: spirv.func @load_kernel + // CHECK-SAME: %[[ARG:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<48 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}) + gpu.func @load_kernel(%arg0: memref<12x4xf32>) kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} { + %c0 = arith.constant 0 : index + // CHECK: %[[PTR:.*]] = spirv.AccessChain %[[ARG]]{{\[}}{{%.*}}, {{%.*}}{{\]}} + // CHECK-NEXT: {{%.*}} = spirv.Load "StorageBuffer" %[[PTR]] : f32 + %0 = memref.load %arg0[%c0, %c0] : memref<12x4xf32> + // CHECK: spirv.Return + gpu.return + } + } +} + +// ----- +// Checks that the `-convert-gpu-to-spirv` pass selects the first +// `spirv.target_env` from the `targets` array attribute attached to `gpu.module`. +module attributes {gpu.container_module} { + // CHECK-LABEL: spirv.module @{{.*}} GLSL450 + // CHECK-SAME: #spirv.target_env<#spirv.vce<v1.4, [Shader], [SPV_KHR_storage_buffer_storage_class]> + gpu.module @kernels [ + #spirv.target_env<#spirv.vce<v1.4, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>, + #spirv.target_env<#spirv.vce<v1.0, [Kernel], []>, #spirv.resource_limits<>>, + #spirv.target_env<#spirv.vce<v1.0, [Shader], []>, #spirv.resource_limits<>>] { + // CHECK: spirv.func @load_kernel + // CHECK-SAME: %[[ARG:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<48 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}) + gpu.func @load_kernel(%arg0: memref<12x4xf32>) kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} { + %c0 = arith.constant 0 : index + // CHECK: %[[PTR:.*]] = spirv.AccessChain %[[ARG]]{{\[}}{{%.*}}, {{%.*}}{{\]}} + // CHECK-NEXT: {{%.*}} = spirv.Load "StorageBuffer" %[[PTR]] : f32 + %0 = memref.load %arg0[%c0, %c0] : memref<12x4xf32> + // CHECK: spirv.Return + gpu.return + } + } +} diff --git a/mlir/test/Conversion/GPUToSPIRV/rotate.mlir b/mlir/test/Conversion/GPUToSPIRV/rotate.mlir index b96dd37..c71d220 100644 --- a/mlir/test/Conversion/GPUToSPIRV/rotate.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/rotate.mlir @@ -10,16 +10,14 @@ gpu.module @kernels { // CHECK-LABEL: spirv.func @rotate() gpu.func @rotate() kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} { - %offset = arith.constant 4 : i32 - %width = arith.constant 16 : i32 %val = arith.constant 42.0 : f32 + // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32 // CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32 // CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32 - // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32 // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup> %[[VAL]], %[[OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32 // CHECK: %{{.+}} = spirv.Constant true - %result, %valid = gpu.rotate %val, %offset, %width : f32 + %result, %valid = gpu.rotate %val, 4, 16 : f32 gpu.return } } @@ -38,18 +36,16 @@ gpu.module @kernels { // CHECK-LABEL: spirv.func @rotate_width_less_than_subgroup_size() gpu.func @rotate_width_less_than_subgroup_size() kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} { - %offset = arith.constant 4 : i32 - %width = arith.constant 8 : i32 %val = arith.constant 42.0 : f32 + // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32 // CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32 // CHECK: %[[WIDTH:.+]] = spirv.Constant 8 : i32 - // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32 // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup> %[[VAL]], %[[OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32 // CHECK: %[[INVOCATION_ID_ADDR:.+]] = spirv.mlir.addressof @__builtin__SubgroupLocalInvocationId__ // CHECK: %[[INVOCATION_ID:.+]] = spirv.Load "Input" %[[INVOCATION_ID_ADDR]] // CHECK: %{{.+}} = spirv.ULessThan %[[INVOCATION_ID]], %[[WIDTH]] - %result, %valid = gpu.rotate %val, %offset, %width : f32 + %result, %valid = gpu.rotate %val, 4, 8 : f32 gpu.return } } @@ -67,34 +63,10 @@ module attributes { gpu.module @kernels { gpu.func @rotate_with_bigger_than_subgroup_size() kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} { - %offset = arith.constant 4 : i32 - %width = arith.constant 32 : i32 %val = arith.constant 42.0 : f32 // expected-error @+1 {{failed to legalize operation 'gpu.rotate'}} - %result, %valid = gpu.rotate %val, %offset, %width : f32 - gpu.return - } -} - -} - -// ----- - -module attributes { - gpu.container_module, - spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformRotateKHR], []>, - #spirv.resource_limits<subgroup_size = 16>> -} { - -gpu.module @kernels { - gpu.func @rotate_non_const_width(%width: i32) kernel - attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} { - %offset = arith.constant 4 : i32 - %val = arith.constant 42.0 : f32 - - // expected-error @+1 {{'gpu.rotate' op width is not a constant value}} - %result, %valid = gpu.rotate %val, %offset, %width : f32 + %result, %valid = gpu.rotate %val, 4, 32 : f32 gpu.return } } diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-fpclassify-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-fpclassify-spirv.mlir new file mode 100644 index 0000000..3e5f592 --- /dev/null +++ b/mlir/test/Conversion/MathToSPIRV/math-to-fpclassify-spirv.mlir @@ -0,0 +1,27 @@ +// RUN: mlir-opt --convert-math-to-spirv %s | FileCheck %s + +module attributes { + spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], []>, #spirv.resource_limits<>> +} { + + // CHECK-LABEL: @fpclassify + func.func @fpclassify(%x: f32, %v: vector<4xf32>) { + // CHECK: spirv.IsFinite %{{.*}} : f32 + %0 = math.isfinite %x : f32 + // CHECK: spirv.IsFinite %{{.*}} : vector<4xf32> + %1 = math.isfinite %v : vector<4xf32> + + // CHECK: spirv.IsNan %{{.*}} : f32 + %2 = math.isnan %x : f32 + // CHECK: spirv.IsNan %{{.*}} : vector<4xf32> + %3 = math.isnan %v : vector<4xf32> + + // CHECK: spirv.IsInf %{{.*}} : f32 + %4 = math.isinf %x : f32 + // CHECK: spirv.IsInf %{{.*}} : vector<4xf32> + %5 = math.isinf %v : vector<4xf32> + + return + } + +} diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc.mlir new file mode 100644 index 0000000..e391a89 --- /dev/null +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc.mlir @@ -0,0 +1,72 @@ +// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefix=CPP +// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s --check-prefix=NOCPP + +func.func @alloc() { + %alloc = memref.alloc() : memref<999xi32> + return +} + +// CPP: module { +// CPP-NEXT: emitc.include <"cstdlib"> +// CPP-LABEL: alloc() +// CPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t +// CPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index +// CPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t +// CPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "malloc"(%[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">> +// CPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32> +// CPP-NEXT: return + +// NOCPP: module { +// NOCPP-NEXT: emitc.include <"stdlib.h"> +// NOCPP-LABEL: alloc() +// NOCPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t +// NOCPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index +// NOCPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t +// NOCPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "malloc"(%[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">> +// NOCPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32> +// NOCPP-NEXT: return + +func.func @alloc_aligned() { + %alloc = memref.alloc() {alignment = 64 : i64} : memref<999xf32> + return +} + +// CPP-LABEL: alloc_aligned +// CPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t +// CPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index +// CPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t +// CPP-NEXT: %[[ALIGNMENT:.*]] = "emitc.constant"() <{value = 64 : index}> : () -> !emitc.size_t +// CPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "aligned_alloc"(%[[ALIGNMENT]], %[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t, !emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">> +// CPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<f32> +// CPP-NEXT: return + +// NOCPP-LABEL: alloc_aligned +// NOCPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t +// NOCPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index +// NOCPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t +// NOCPP-NEXT: %[[ALIGNMENT:.*]] = "emitc.constant"() <{value = 64 : index}> : () -> !emitc.size_t +// NOCPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "aligned_alloc"(%[[ALIGNMENT]], %[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t, !emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">> +// NOCPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<f32> +// NOCPP-NEXT: return + +func.func @allocating_multi() { + %alloc_5 = memref.alloc() : memref<7x999xi32> + return +} + +// CPP-LABEL: allocating_multi +// CPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t +// CPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 6993 : index}> : () -> index +// CPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t +// CPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "malloc"(%[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void"> +// CPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32> +// CPP-NEXT: return + +// NOCPP-LABEL: allocating_multi +// NOCPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t +// NOCPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 6993 : index}> : () -> index +// NOCPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t +// NOCPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "malloc"(%[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">> +// NOCPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32> +// NOCPP-NEXT: return + diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir index 8d720ce..580b09d 100644 --- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir +++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir @@ -580,30 +580,6 @@ func.func @elect_one_leader_sync() { // ----- -// CHECK-LABEL: @stmatrix( -// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !llvm.ptr<3>, -// CHECK-SAME: %[[arg1:[a-zA-Z0-9_]+]]: i32, -// CHECK-SAME: %[[arg2:[a-zA-Z0-9_]+]]: i32, -// CHECK-SAME: %[[arg3:[a-zA-Z0-9_]+]]: i32, -// CHECK-SAME: %[[arg4:[a-zA-Z0-9_]+]]: i32) -llvm.func @stmatrix(%arg0 : !llvm.ptr<3>, %m1 : i32, %m2 : i32, %m3 : i32, %m4 : i32) { -// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> () -// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> () -// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> () -// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> () -// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> () -// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> () - nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32 - nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32 - nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32, i32, i32 - nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32 - nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32 - nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32, i32, i32 - llvm.return -} - -// ----- - // CHECK-LABEL: @init_mbarrier_arrive_expect_tx llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) { //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "prefetch.tensormap [$0];", "l" diff --git a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir index e6fdb7a..ef0fa08 100644 --- a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir +++ b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir @@ -708,4 +708,45 @@ func.func @simple_std_for_loops_annotation(%arg0 : index, %arg1 : index, %arg2 : } {llvm.loop_annotation = #full_unroll} } {llvm.loop_annotation = #no_unroll} return -}
\ No newline at end of file +} + +// ----- + +// CHECK: #[[LOOP_UNROLL_DISABLE:.*]] = #llvm.loop_unroll<disable = true> +// CHECK: #[[NO_UNROLL:.*]] = #llvm.loop_annotation<unroll = #[[LOOP_UNROLL_DISABLE]]> +// CHECK: func @simple_while_loops_annotation +// CHECK: cf.br +// CHECK: cf.cond_br {{.*}} {llvm.loop_annotation = #[[NO_UNROLL]]} +// CHECK: return +#no_unroll = #llvm.loop_annotation<unroll = <disable = true>> +func.func @simple_while_loops_annotation(%arg0 : i1) { + scf.while : () -> () { + scf.condition(%arg0) + } do { + scf.yield + } attributes {llvm.loop_annotation = #no_unroll} + return +} + +// ----- + +// CHECK: #[[LOOP_UNROLL_DISABLE:.*]] = #llvm.loop_unroll<disable = true> +// CHECK: #[[NO_UNROLL:.*]] = #llvm.loop_annotation<unroll = #[[LOOP_UNROLL_DISABLE]]> +// CHECK: func @do_while_loops_annotation +// CHECK: cf.br +// CHECK: cf.cond_br +// CHECK: cf.br {{.*}} {llvm.loop_annotation = #[[NO_UNROLL]]} +// CHECK: return +#no_unroll = #llvm.loop_annotation<unroll = <disable = true>> +func.func @do_while_loops_annotation() { + %c0_i32 = arith.constant 0 : i32 + scf.while (%arg2 = %c0_i32) : (i32) -> (i32) { + %0 = "test.make_condition"() : () -> i1 + scf.condition(%0) %c0_i32 : i32 + } do { + ^bb0(%arg2: i32): + scf.yield %c0_i32: i32 + } attributes {llvm.loop_annotation = #no_unroll} + return +} + diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir index 8c135d5..31e17fb 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir @@ -274,73 +274,6 @@ func.func @shuffle_2D(%a: vector<1x4xf32>, %b: vector<2x4xf32>) -> vector<3x4xf3 // ----- //===----------------------------------------------------------------------===// -// vector.extractelement -//===----------------------------------------------------------------------===// - -func.func @extractelement_from_vec_0d_f32(%arg0: vector<f32>) -> f32 { - %1 = vector.extractelement %arg0[] : vector<f32> - return %1 : f32 -} -// CHECK-LABEL: @extractelement_from_vec_0d_f32 -// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64 -// CHECK: llvm.extractelement %{{.*}}[%[[C0]] : {{.*}}] : vector<1xf32> - -// ----- - -func.func @extractelement_from_vec_1d_f32_idx_as_i32(%arg0: vector<16xf32>) -> f32 { - %0 = arith.constant 15 : i32 - %1 = vector.extractelement %arg0[%0 : i32]: vector<16xf32> - return %1 : f32 -} -// CHECK-LABEL: @extractelement_from_vec_1d_f32_idx_as_i32( -// CHECK-SAME: %[[A:.*]]: vector<16xf32>) -// CHECK: %[[C:.*]] = arith.constant 15 : i32 -// CHECK: %[[X:.*]] = llvm.extractelement %[[A]][%[[C]] : i32] : vector<16xf32> -// CHECK: return %[[X]] : f32 - -// ----- - -func.func @extractelement_from_vec_1d_f32_idx_as_i32_scalable(%arg0: vector<[16]xf32>) -> f32 { - %0 = arith.constant 15 : i32 - %1 = vector.extractelement %arg0[%0 : i32]: vector<[16]xf32> - return %1 : f32 -} -// CHECK-LABEL: @extractelement_from_vec_1d_f32_idx_as_i32_scalable( -// CHECK-SAME: %[[A:.*]]: vector<[16]xf32>) -// CHECK: %[[C:.*]] = arith.constant 15 : i32 -// CHECK: %[[X:.*]] = llvm.extractelement %[[A]][%[[C]] : i32] : vector<[16]xf32> -// CHECK: return %[[X]] : f32 - -// ----- -func.func @extractelement_from_vec_1d_f32_idx_as_index(%arg0: vector<16xf32>) -> f32 { - %0 = arith.constant 15 : index - %1 = vector.extractelement %arg0[%0 : index]: vector<16xf32> - return %1 : f32 -} -// CHECK-LABEL: @extractelement_from_vec_1d_f32_idx_as_index( -// CHECK-SAME: %[[A:.*]]: vector<16xf32>) -// CHECK: %[[C:.*]] = arith.constant 15 : index -// CHECK: %[[I:.*]] = builtin.unrealized_conversion_cast %[[C]] : index to i64 -// CHECK: %[[X:.*]] = llvm.extractelement %[[A]][%[[I]] : i64] : vector<16xf32> -// CHECK: return %[[X]] : f32 - -// ----- - -func.func @extractelement_from_vec_1d_f32_idx_as_index_scalable(%arg0: vector<[16]xf32>) -> f32 { - %0 = arith.constant 15 : index - %1 = vector.extractelement %arg0[%0 : index]: vector<[16]xf32> - return %1 : f32 -} -// CHECK-LABEL: @extractelement_from_vec_1d_f32_idx_as_index_scalable( -// CHECK-SAME: %[[A:.*]]: vector<[16]xf32>) -// CHECK: %[[C:.*]] = arith.constant 15 : index -// CHECK: %[[I:.*]] = builtin.unrealized_conversion_cast %[[C]] : index to i64 -// CHECK: %[[X:.*]] = llvm.extractelement %[[A]][%[[I]] : i64] : vector<[16]xf32> -// CHECK: return %[[X]] : f32 - -// ----- - -//===----------------------------------------------------------------------===// // vector.extract //===----------------------------------------------------------------------===// @@ -592,81 +525,6 @@ func.func @extract_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const(%arg : // ----- //===----------------------------------------------------------------------===// -// vector.insertelement -//===----------------------------------------------------------------------===// - -func.func @insertelement_into_vec_0d_f32(%arg0: f32, %arg1: vector<f32>) -> vector<f32> { - %1 = vector.insertelement %arg0, %arg1[] : vector<f32> - return %1 : vector<f32> -} -// CHECK-LABEL: @insertelement_into_vec_0d_f32 -// CHECK-SAME: %[[A:.*]]: f32, -// CHECK: %[[B:.*]] = builtin.unrealized_conversion_cast %{{.*}} : -// CHECK: vector<f32> to vector<1xf32> -// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64 -// CHECK: %[[X:.*]] = llvm.insertelement %[[A]], %[[B]][%[[C0]] : {{.*}}] : vector<1xf32> - -// ----- - -func.func @insertelement_into_vec_1d_f32_idx_as_i32(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> { - %0 = arith.constant 3 : i32 - %1 = vector.insertelement %arg0, %arg1[%0 : i32] : vector<4xf32> - return %1 : vector<4xf32> -} -// CHECK-LABEL: @insertelement_into_vec_1d_f32_idx_as_i32( -// CHECK-SAME: %[[A:.*]]: f32, -// CHECK-SAME: %[[B:.*]]: vector<4xf32>) -// CHECK: %[[C:.*]] = arith.constant 3 : i32 -// CHECK: %[[X:.*]] = llvm.insertelement %[[A]], %[[B]][%[[C]] : i32] : vector<4xf32> -// CHECK: return %[[X]] : vector<4xf32> - -// ----- - -func.func @insertelement_into_vec_1d_f32_idx_as_i32_scalable(%arg0: f32, %arg1: vector<[4]xf32>) -> vector<[4]xf32> { - %0 = arith.constant 3 : i32 - %1 = vector.insertelement %arg0, %arg1[%0 : i32] : vector<[4]xf32> - return %1 : vector<[4]xf32> -} -// CHECK-LABEL: @insertelement_into_vec_1d_f32_idx_as_i32_scalable( -// CHECK-SAME: %[[A:.*]]: f32, -// CHECK-SAME: %[[B:.*]]: vector<[4]xf32>) -// CHECK: %[[C:.*]] = arith.constant 3 : i32 -// CHECK: %[[X:.*]] = llvm.insertelement %[[A]], %[[B]][%[[C]] : i32] : vector<[4]xf32> -// CHECK: return %[[X]] : vector<[4]xf32> - -// ----- - -func.func @insertelement_into_vec_1d_f32_scalable_idx_as_index(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> { - %0 = arith.constant 3 : index - %1 = vector.insertelement %arg0, %arg1[%0 : index] : vector<4xf32> - return %1 : vector<4xf32> -} -// CHECK-LABEL: @insertelement_into_vec_1d_f32_scalable_idx_as_index( -// CHECK-SAME: %[[A:.*]]: f32, -// CHECK-SAME: %[[B:.*]]: vector<4xf32>) -// CHECK: %[[C:.*]] = arith.constant 3 : index -// CHECK: %[[I:.*]] = builtin.unrealized_conversion_cast %[[C]] : index to i64 -// CHECK: %[[X:.*]] = llvm.insertelement %[[A]], %[[B]][%[[I]] : i64] : vector<4xf32> -// CHECK: return %[[X]] : vector<4xf32> - -// ----- - -func.func @insertelement_into_vec_1d_f32_scalable_idx_as_index_scalable(%arg0: f32, %arg1: vector<[4]xf32>) -> vector<[4]xf32> { - %0 = arith.constant 3 : index - %1 = vector.insertelement %arg0, %arg1[%0 : index] : vector<[4]xf32> - return %1 : vector<[4]xf32> -} -// CHECK-LABEL: @insertelement_into_vec_1d_f32_scalable_idx_as_index_scalable( -// CHECK-SAME: %[[A:.*]]: f32, -// CHECK-SAME: %[[B:.*]]: vector<[4]xf32>) -// CHECK: %[[C:.*]] = arith.constant 3 : index -// CHECK: %[[I:.*]] = builtin.unrealized_conversion_cast %[[C]] : index to i64 -// CHECK: %[[X:.*]] = llvm.insertelement %[[A]], %[[B]][%[[I]] : i64] : vector<[4]xf32> -// CHECK: return %[[X]] : vector<[4]xf32> - -// ----- - -//===----------------------------------------------------------------------===// // vector.insert //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index f43a41a..8918f91 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -400,67 +400,6 @@ func.func @insert_dynamic_cst(%val: f32, %arg0 : vector<4xf32>) -> vector<4xf32> // ----- -// CHECK-LABEL: @extract_element -// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32 -// CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32 -func.func @extract_element(%arg0 : vector<4xf32>, %id : i32) -> f32 { - %0 = vector.extractelement %arg0[%id : i32] : vector<4xf32> - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @extract_element_cst -// CHECK-SAME: %[[V:.*]]: vector<4xf32> -// CHECK: spirv.CompositeExtract %[[V]][1 : i32] : vector<4xf32> -func.func @extract_element_cst(%arg0 : vector<4xf32>) -> f32 { - %idx = arith.constant 1 : i32 - %0 = vector.extractelement %arg0[%idx : i32] : vector<4xf32> - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @extract_element_index -func.func @extract_element_index(%arg0 : vector<4xf32>, %id : index) -> f32 { - // CHECK: spirv.VectorExtractDynamic - %0 = vector.extractelement %arg0[%id : index] : vector<4xf32> - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @extract_element_size5_vector -func.func @extract_element_size5_vector(%arg0 : vector<5xf32>, %id : i32) -> f32 { - // CHECK: vector.extractelement - %0 = vector.extractelement %arg0[%id : i32] : vector<5xf32> - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @extract_element_size1_vector -// CHECK-SAME: (%[[S:.+]]: f32 -func.func @extract_element_size1_vector(%arg0 : f32, %i: index) -> f32 { - %bcast = vector.broadcast %arg0 : f32 to vector<1xf32> - %0 = vector.extractelement %bcast[%i : index] : vector<1xf32> - // CHECK: return %[[S]] - return %0: f32 -} - -// ----- - -// CHECK-LABEL: @extract_element_0d_vector -// CHECK-SAME: (%[[S:.+]]: f32) -func.func @extract_element_0d_vector(%arg0 : f32) -> f32 { - %bcast = vector.broadcast %arg0 : f32 to vector<f32> - %0 = vector.extractelement %bcast[] : vector<f32> - // CHECK: return %[[S]] - return %0: f32 -} - -// ----- - // CHECK-LABEL: @extract_strided_slice // CHECK-SAME: %[[ARG:.+]]: vector<4xf32> // CHECK: spirv.VectorShuffle [1 : i32, 2 : i32] %[[ARG]], %[[ARG]] : vector<4xf32>, vector<4xf32> -> vector<2xf32> @@ -473,67 +412,6 @@ func.func @extract_strided_slice(%arg0: vector<4xf32>) -> (vector<2xf32>, vector // ----- -// CHECK-LABEL: @insert_element -// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32 -// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32 -func.func @insert_element(%val: f32, %arg0 : vector<4xf32>, %id : i32) -> vector<4xf32> { - %0 = vector.insertelement %val, %arg0[%id : i32] : vector<4xf32> - return %0: vector<4xf32> -} - -// ----- - -// CHECK-LABEL: @insert_element_cst -// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32> -// CHECK: spirv.CompositeInsert %[[VAL]], %[[V]][2 : i32] : f32 into vector<4xf32> -func.func @insert_element_cst(%val: f32, %arg0 : vector<4xf32>) -> vector<4xf32> { - %idx = arith.constant 2 : i32 - %0 = vector.insertelement %val, %arg0[%idx : i32] : vector<4xf32> - return %0: vector<4xf32> -} - -// ----- - -// CHECK-LABEL: @insert_element_index -func.func @insert_element_index(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> { - // CHECK: spirv.VectorInsertDynamic - %0 = vector.insertelement %val, %arg0[%id : index] : vector<4xf32> - return %0: vector<4xf32> -} - -// ----- - -// CHECK-LABEL: @insert_element_size5_vector -func.func @insert_element_size5_vector(%val: f32, %arg0 : vector<5xf32>, %id : i32) -> vector<5xf32> { - // CHECK: vector.insertelement - %0 = vector.insertelement %val, %arg0[%id : i32] : vector<5xf32> - return %0 : vector<5xf32> -} - -// ----- - -// CHECK-LABEL: @insert_element_size1_vector -// CHECK-SAME: (%[[S:[a-z0-9]+]]: f32 -func.func @insert_element_size1_vector(%scalar: f32, %vector : vector<1xf32>, %i: index) -> vector<1xf32> { - %0 = vector.insertelement %scalar, %vector[%i : index] : vector<1xf32> - // CHECK: %[[V:.+]] = builtin.unrealized_conversion_cast %arg0 : f32 to vector<1xf32> - // CHECK: return %[[V]] - return %0: vector<1xf32> -} - -// ----- - -// CHECK-LABEL: @insert_element_0d_vector -// CHECK-SAME: (%[[S:[a-z0-9]+]]: f32 -func.func @insert_element_0d_vector(%scalar: f32, %vector : vector<f32>) -> vector<f32> { - %0 = vector.insertelement %scalar, %vector[] : vector<f32> - // CHECK: %[[V:.+]] = builtin.unrealized_conversion_cast %arg0 : f32 to vector<f32> - // CHECK: return %[[V]] - return %0: vector<f32> -} - -// ----- - // CHECK-LABEL: @insert_strided_slice // CHECK-SAME: %[[PART:.+]]: vector<2xf32>, %[[ALL:.+]]: vector<4xf32> // CHECK: spirv.VectorShuffle [0 : i32, 4 : i32, 5 : i32, 3 : i32] %[[ALL]], %[[PART]] : vector<4xf32>, vector<2xf32> -> vector<4xf32> diff --git a/mlir/test/Dialect/Async/canonicalize.mlir b/mlir/test/Dialect/Async/canonicalize.mlir new file mode 100644 index 0000000..1a74eaa --- /dev/null +++ b/mlir/test/Dialect/Async/canonicalize.mlir @@ -0,0 +1,10 @@ +// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s + +// CHECK-NOT: async.execute + +func.func @empty_execute() { + %token = async.execute { + async.yield + } + return +} diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir new file mode 100644 index 0000000..e2ab876 --- /dev/null +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir @@ -0,0 +1,33 @@ +// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='builtin.module(test.symbol_scope_isolated(test-one-shot-module-bufferize))' -split-input-file | FileCheck %s + +"test.symbol_scope_isolated"() ({ + // CHECK-LABEL: func @inner_func( + // CHECK-SAME: %[[arg0:.*]]: memref<?xf32 + func.func @inner_func(%t: tensor<?xf32>) -> (tensor<?xf32>, f32) { + // CHECK-NOT: copy + %f = arith.constant 1.0 : f32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + // CHECK: memref.store %{{.*}}, %[[arg0]] + %0 = tensor.insert %f into %t[%c0] : tensor<?xf32> + // CHECK: %[[load:.*]] = memref.load %[[arg0]] + %1 = tensor.extract %0[%c1] : tensor<?xf32> + // CHECK: return %[[arg0]], %[[load]] : memref<?xf32{{.*}}>, f32 + return %0, %1 : tensor<?xf32>, f32 + } + + // CHECK-LABEL: func @call_func_with_non_tensor_return( + // CHECK-SAME: %[[arg0:.*]]: memref<?xf32 + func.func @call_func_with_non_tensor_return( + %t0: tensor<?xf32> {bufferization.writable = true}) -> (f32, tensor<?xf32>) { + // CHECK-NOT: alloc + // CHECK-NOT: copy + // CHECK: %[[call:.*]]:2 = call @inner_func(%[[arg0]]) + %0, %1 = call @inner_func(%t0) : (tensor<?xf32>) -> (tensor<?xf32>, f32) + // CHECK: return %[[call]]#1, %[[call]]#0 : f32, memref<?xf32,{{.*}}> + return %1, %0 : f32, tensor<?xf32> + } + "test.finish" () : () -> () +}) : () -> () + + diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir index 162ff06..35381da 100644 --- a/mlir/test/Dialect/GPU/invalid.mlir +++ b/mlir/test/Dialect/GPU/invalid.mlir @@ -479,20 +479,16 @@ func.func @shuffle_unsupported_type_vec(%arg0 : vector<[4]xf32>, %arg1 : i32, %a // ----- func.func @rotate_mismatching_type(%arg0 : f32) { - %offset = arith.constant 4 : i32 - %width = arith.constant 16 : i32 // expected-error@+1 {{op failed to verify that all of {value, rotateResult} have same type}} - %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (i32, i1) + %rotate, %valid = "gpu.rotate"(%arg0) { offset = 4 : i32, width = 16 : i32 } : (f32) -> (i32, i1) return } // ----- func.func @rotate_unsupported_type(%arg0 : index) { - %offset = arith.constant 4 : i32 - %width = arith.constant 16 : i32 // expected-error@+1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'index'}} - %rotate, %valid = gpu.rotate %arg0, %offset, %width : index + %rotate, %valid = gpu.rotate %arg0, 4, 16 : index return } @@ -502,55 +498,31 @@ func.func @rotate_unsupported_type_vec(%arg0 : vector<[4]xf32>) { %offset = arith.constant 4 : i32 %width = arith.constant 16 : i32 // expected-error@+1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'vector<[4]xf32>'}} - %rotate, %valid = gpu.rotate %arg0, %offset, %width : vector<[4]xf32> + %rotate, %valid = gpu.rotate %arg0, 4, 16 : vector<[4]xf32> return } // ----- func.func @rotate_unsupported_width(%arg0 : f32) { - %offset = arith.constant 4 : i32 - %width = arith.constant 15 : i32 - // expected-error@+1 {{op width must be a power of two}} - %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1) + // expected-error@+1 {{'gpu.rotate' op attribute 'width' failed to satisfy constraint: 32-bit signless integer attribute whose value is a power of two > 0}} + %rotate, %valid = "gpu.rotate"(%arg0) { offset = 4 : i32, width = 15 : i32 } : (f32) -> (f32, i1) return } // ----- func.func @rotate_unsupported_offset(%arg0 : f32) { - %offset = arith.constant 16 : i32 - %width = arith.constant 16 : i32 // expected-error@+1 {{op offset must be in the range [0, 16)}} - %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1) + %rotate, %valid = "gpu.rotate"(%arg0) { offset = 16 : i32, width = 16 : i32 }: (f32) -> (f32, i1) return } // ----- func.func @rotate_unsupported_offset_minus(%arg0 : f32) { - %offset = arith.constant -1 : i32 - %width = arith.constant 16 : i32 - // expected-error@+1 {{op offset must be in the range [0, 16)}} - %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1) - return -} - -// ----- - -func.func @rotate_offset_non_constant(%arg0 : f32, %offset : i32) { - %width = arith.constant 16 : i32 - // expected-error@+1 {{op offset is not a constant value}} - %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1) - return -} - -// ----- - -func.func @rotate_width_non_constant(%arg0 : f32, %width : i32) { - %offset = arith.constant 0 : i32 - // expected-error@+1 {{op width is not a constant value}} - %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1) + // expected-error@+1 {{'gpu.rotate' op attribute 'offset' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 0}} + %rotate, %valid = "gpu.rotate"(%arg0) { offset = -1 : i32, width = 16 : i32 } : (f32) -> (f32, i1) return } diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir index 2aef80f..ee1fdfa 100644 --- a/mlir/test/Dialect/GPU/ops.mlir +++ b/mlir/test/Dialect/GPU/ops.mlir @@ -140,9 +140,8 @@ module attributes {gpu.container_module} { // CHECK: gpu.shuffle idx %{{.*}}, %{{.*}}, %{{.*}} : f32 %shfl3, %pred3 = gpu.shuffle idx %arg0, %offset, %width : f32 - // CHECK: gpu.rotate %{{.*}}, %{{.*}}, %{{.*}} : f32 - %rotate_width = arith.constant 16 : i32 - %rotate, %pred4 = gpu.rotate %arg0, %offset, %rotate_width : f32 + // CHECK: gpu.rotate %{{.*}}, 3, 16 : f32 + %rotate, %pred4 = gpu.rotate %arg0, 3, 16 : f32 "gpu.barrier"() : () -> () diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index 39a7b1b..5c5f7e8 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -1176,6 +1176,52 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>) // ----- +// CHECK-LABEL: @broadcast_broadcast_fold +// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32> +// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x3xf32> +// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x4xf32> +// CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<2x3x4xf32>) dimensions = [1, 2] +// CHECK-NOT: linalg.broadcast +// CHECK: return %[[BROADCAST]] : tensor<2x3x4xf32> +func.func @broadcast_broadcast_fold(%input: tensor<2xf32>, + %init1: tensor<2x3xf32>, + %init2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { + %broadcast1 = linalg.broadcast + ins(%input: tensor<2xf32>) + outs(%init1: tensor<2x3xf32>) + dimensions = [1] + %broadcast2 = linalg.broadcast + ins(%broadcast1: tensor<2x3xf32>) + outs(%init2: tensor<2x3x4xf32>) + dimensions = [2] + func.return %broadcast2 : tensor<2x3x4xf32> +} + +// ----- + +// CHECK-LABEL: @broadcast_broadcast_fold +// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32> +// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x4xf32> +// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x4xf32> +// CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<2x3x4xf32>) dimensions = [1, 2] +// CHECK-NOT: linalg.broadcast +// CHECK: return %[[BROADCAST]] : tensor<2x3x4xf32> +func.func @broadcast_broadcast_fold(%input: tensor<2xf32>, + %init1: tensor<2x4xf32>, + %init2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { + %broadcast1 = linalg.broadcast + ins(%input: tensor<2xf32>) + outs(%init1: tensor<2x4xf32>) + dimensions = [1] + %broadcast2 = linalg.broadcast + ins(%broadcast1: tensor<2x4xf32>) + outs(%init2: tensor<2x3x4xf32>) + dimensions = [1] + func.return %broadcast2 : tensor<2x3x4xf32> +} + +// ----- + func.func @transpose_1d(%input: tensor<16xf32>, %init: tensor<16xf32>) -> tensor<16xf32> { %transpose = linalg.transpose diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir index a00c798..5f42938 100644 --- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir +++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir @@ -1076,6 +1076,44 @@ func.func @drop_known_unit_constant_low_high(%arg0: tensor<1x383x128xf32>) -> te // ----- +func.func @drop_unit_dim_mixed_static_dynamic(%arg0: tensor<1x?xf32>) -> tensor<1x16xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %cst = arith.constant 0.000000e+00 : f32 + %padded = tensor.pad %arg0 low[%c0, %c1] high[%c0, %c0] { + ^bb0(%arg1: index, %arg2: index): + tensor.yield %cst : f32 + } : tensor<1x?xf32> to tensor<1x16xf32> + return %padded : tensor<1x16xf32> +} +// CHECK-LABEL: func @drop_unit_dim_mixed_static_dynamic +// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARGS:.*]] : tensor<1x?xf32> into tensor<?xf32> +// CHECK: %[[PADDED:.*]] = tensor.pad %[[COLLAPSE]] low[1] high[0] { +// CHECK: ^bb0(%[[IDX:.*]]: index): +// CHECK: tensor.yield %[[CST]] : f32 +// CHECK: } : tensor<?xf32> to tensor<16xf32> +// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[PADDED]] {{\[\[}}0, 1]] output_shape [1, 16] : tensor<16xf32> into tensor<1x16xf32> +// CHECK: return %[[EXPANDED]] : tensor<1x16xf32> + +// ----- + +#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> +module { + func.func @drop_unit_dim_corresponding_to_dynamic_dim(%arg0: tensor<1x?x?x1xf32>, %arg1: index) -> tensor<?x1x61x1xf32> { + %cst = arith.constant dense<1.000000e+00> : tensor<1x1x1x1xf32> + %0 = tensor.empty(%arg1) : tensor<?x1x61x1xf32> + %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %cst : tensor<1x?x?x1xf32>, tensor<1x1x1x1xf32>) outs(%0 : tensor<?x1x61x1xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %2 = arith.mulf %in, %in_0 : f32 + %3 = arith.addf %out, %2 : f32 + linalg.yield %3 : f32 + } -> tensor<?x1x61x1xf32> + return %1 : tensor<?x1x61x1xf32> + } +} // CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (0, d0)> // CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> ()> @@ -1097,23 +1135,6 @@ func.func @drop_known_unit_constant_low_high(%arg0: tensor<1x383x128xf32>) -> te // CHECK: return %[[VAL_14]] : tensor<?x1x61x1xf32> // CHECK: } -#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)> -#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)> -#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> -module { - func.func @drop_unit_dim_corresponding_to_dynamic_dim(%arg0: tensor<1x?x?x1xf32>, %arg1: index) -> tensor<?x1x61x1xf32> { - %cst = arith.constant dense<1.000000e+00> : tensor<1x1x1x1xf32> - %0 = tensor.empty(%arg1) : tensor<?x1x61x1xf32> - %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %cst : tensor<1x?x?x1xf32>, tensor<1x1x1x1xf32>) outs(%0 : tensor<?x1x61x1xf32>) { - ^bb0(%in: f32, %in_0: f32, %out: f32): - %2 = arith.mulf %in, %in_0 : f32 - %3 = arith.addf %out, %2 : f32 - linalg.yield %3 : f32 - } -> tensor<?x1x61x1xf32> - return %1 : tensor<?x1x61x1xf32> - } -} - // ----- func.func @no_fold_empty_tensor_dim_out_of_bounds(%arg0: tensor<1x?x10xf32>) -> tensor<1x?xf32> { diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir index 78619b6..981f5dc 100644 --- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface-multiple-of.mlir @@ -52,22 +52,22 @@ module { // CHECK-LABEL: @generic // CHECK-SAME: %[[T0:.*]]: tensor<7x5xf32>, -// CHECK-SAME: %[[T1:.*]]: tensor<7x11x12xf32>) - func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x12xf32>) -> tensor<7x11x12xf32> { +// CHECK-SAME: %[[T1:.*]]: tensor<7x11x11xf32>) + func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x11xf32>) -> tensor<7x11x11xf32> { // CHECK-DAG: %[[CST:.*]] = arith.constant 0. // CHECK: %[[PAD0:.*]] = tensor.pad %[[T0]] low[0, 0] high[2, 0] // CHECK: : tensor<7x5xf32> to tensor<9x5xf32> // CHECK: %[[PAD1:.*]] = tensor.pad %[[T1]] low[0, 0, 0] high[2, 4, 2] { - // CHECK: : tensor<7x11x12xf32> to tensor<9x15x14xf32> + // CHECK: : tensor<7x11x11xf32> to tensor<9x15x13xf32> // CHECK-NEXT: linalg.generic - // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 12] [1, 1, 1] : tensor<9x15x14xf32> to tensor<7x11x12xf32> - %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x12xf32>) { + // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 11] [1, 1, 1] : tensor<9x15x13xf32> to tensor<7x11x11xf32> + %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x11xf32>) { ^bb0(%in: f32, %out: f32): linalg.yield %in : f32 - } -> tensor<7x11x12xf32> - return %0 : tensor<7x11x12xf32> + } -> tensor<7x11x11xf32> + return %0 : tensor<7x11x11xf32> } module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { @@ -83,7 +83,7 @@ module { // ----- // CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3)> -// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3 + 5)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 3) * 3 + 4)> // CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 5)> #map = affine_map<(d0, d1, d2) -> (d0, d1)> @@ -272,3 +272,136 @@ module attributes {transform.with_named_sequence} { } } +// ----- + +// CHECK-LABEL: pad_conv +func.func @pad_conv(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> { + + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 12] + // CHECK: : tensor<1x16x16x4xf32> to tensor<1x16x18x16xf32> + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12] + // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32> + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0] + // CHECK: : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32> + // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc + // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32> + + %0 = linalg.conv_2d_nhwc_fhwc + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> } + ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32>) + outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> + return %0 : tensor<1x14x14x16xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of { + padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32] + } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 16) * 16 + 2)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 16) * 16)> + +// CHECK-LABEL: pad_conv_dynamic +func.func @pad_conv_dynamic(%arg0: tensor<1x16x?x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x?x16xf32>) -> tensor<1x14x?x16xf32> { + + // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index + // CHECK: %[[D0_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32> + // CHECK: %[[D0_1:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x16x?x4xf32> + // CHECK: %[[H0:.*]] = affine.apply #[[$MAP0]]()[%[[D0_0]], %[[D0_1]]] + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, %[[H0]], 12] + // CHECK: : tensor<1x16x?x4xf32> to tensor<1x16x?x16xf32> + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12] + // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32> + // CHECK: %[[D1_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32> + // CHECK: %[[H1:.*]] = affine.apply #[[$MAP1]]()[%[[D0_0]], %[[D1_0]]] + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, %[[H1]], 0] + // CHECK: : tensor<1x14x?x16xf32> to tensor<1x14x?x16xf32> + // CHECK: %[[D2_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x14x?x16xf32> + // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc + // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, %[[D2_0]], 16] [1, 1, 1, 1] : tensor<1x14x?x16xf32> to tensor<1x14x?x16xf32> + + %0 = linalg.conv_2d_nhwc_fhwc + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> } + ins(%arg0, %arg1: tensor<1x16x?x4xf32>, tensor<16x3x3x4xf32>) + outs(%arg2: tensor<1x14x?x16xf32>) -> tensor<1x14x?x16xf32> + return %0 : tensor<1x14x?x16xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of { + padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32] + } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +// CHECK-LABEL: pad_conv_strided +func.func @pad_conv_strided(%arg0: tensor<1x42x42x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> { + + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 6, 12] + // CHECK: : tensor<1x42x42x4xf32> to tensor<1x42x48x16xf32> + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12] + // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32> + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0] + // CHECK: : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32> + // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc + // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32> + + %0 = linalg.conv_2d_nhwc_fhwc + {dilations = dense<1> : tensor<2xi64>, strides = dense<3> : tensor<2xi64> } + ins(%arg0, %arg1: tensor<1x42x42x4xf32>, tensor<16x3x3x4xf32>) + outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> + return %0 : tensor<1x14x14x16xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of { + padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32] + } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +// CHECK-LABEL: pad_conv_dilated +func.func @pad_conv_dilated(%arg0: tensor<1x18x18x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> { + + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 12] + // CHECK: : tensor<1x18x18x4xf32> to tensor<1x18x20x16xf32> + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 0, 12] + // CHECK: : tensor<16x3x3x4xf32> to tensor<16x3x3x16xf32> + // CHECK: tensor.pad %{{.*}} low[0, 0, 0, 0] high[0, 0, 2, 0] + // CHECK: : tensor<1x14x14x16xf32> to tensor<1x14x16x16xf32> + // CHECK-NEXT: linalg.conv_2d_nhwc_fhwc + // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 14, 14, 16] [1, 1, 1, 1] : tensor<1x14x16x16xf32> to tensor<1x14x14x16xf32> + + %0 = linalg.conv_2d_nhwc_fhwc + {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> } + ins(%arg0, %arg1: tensor<1x18x18x4xf32>, tensor<16x3x3x4xf32>) + outs(%arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> + return %0 : tensor<1x14x14x16xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [0, 0, 16, 0, 0, 0, 16] pad_to_multiple_of { + padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32, 0.0 : f32] + } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir index 26c03ed..f741876 100644 --- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir @@ -69,22 +69,22 @@ module { // CHECK-LABEL: @generic // CHECK-SAME: %[[T0:.*]]: tensor<7x5xf32>, -// CHECK-SAME: %[[T1:.*]]: tensor<7x11x12xf32>) - func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x12xf32>) -> tensor<7x11x12xf32> { +// CHECK-SAME: %[[T1:.*]]: tensor<7x11x11xf32>) + func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x11xf32>) -> tensor<7x11x11xf32> { // CHECK-DAG: %[[CST:.*]] = arith.constant 0. // CHECK: %[[PAD0:.*]] = tensor.pad %[[T0]] low[0, 0] high[1, 0] // CHECK: : tensor<7x5xf32> to tensor<8x5xf32> // CHECK: %[[PAD1:.*]] = tensor.pad %[[T1]] low[0, 0, 0] high[1, 3, 1] { - // CHECK: : tensor<7x11x12xf32> to tensor<8x14x13xf32> + // CHECK: : tensor<7x11x11xf32> to tensor<8x14x12xf32> // CHECK-NEXT: linalg.generic - // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 12] [1, 1, 1] : tensor<8x14x13xf32> to tensor<7x11x12xf32> - %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x12xf32>) { + // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 11] [1, 1, 1] : tensor<8x14x12xf32> to tensor<7x11x11xf32> + %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x11xf32>) { ^bb0(%in: f32, %out: f32): linalg.yield %in : f32 - } -> tensor<7x11x12xf32> - return %0 : tensor<7x11x12xf32> + } -> tensor<7x11x11xf32> + return %0 : tensor<7x11x11xf32> } module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { @@ -102,7 +102,7 @@ module { // CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (-s0 + 8)> -// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (-s0 + 13)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (-s0 + 12)> // CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 5)> #map = affine_map<(d0, d1, d2) -> (d0, d1)> @@ -127,13 +127,13 @@ module { // CHECK: %[[D2_0:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<?x11x?xf32> // CHECK: %[[H2:.*]] = affine.apply #[[$MAP1]]()[%[[D2_0]]] // CHECK: tensor.pad %{{.*}} low[0, 0, 0] high[%[[H1]], 3, %[[H2]]] { - // CHECK: : tensor<?x11x?xf32> to tensor<8x14x13xf32> + // CHECK: : tensor<?x11x?xf32> to tensor<8x14x12xf32> // // CHECK: %[[D0_2:.*]] = tensor.dim %{{.*}}, %[[C0]] : tensor<?x5xf32> // CHECK: %[[D2_1:.*]] = affine.apply #[[$MAP2]]()[%[[D0_2]]] - // CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<8x5xf32>) outs(%{{.*}} : tensor<8x14x13xf32>) { - // CHECK: } -> tensor<8x14x13xf32> - // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [%[[D0_2]], 11, %[[D2_1]]] [1, 1, 1] : tensor<8x14x13xf32> to tensor<?x11x?xf32> + // CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<8x5xf32>) outs(%{{.*}} : tensor<8x14x12xf32>) { + // CHECK: } -> tensor<8x14x12xf32> + // CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [%[[D0_2]], 11, %[[D2_1]]] [1, 1, 1] : tensor<8x14x12xf32> to tensor<?x11x?xf32> // %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<?x5xf32>) outs(%arg1 : tensor<?x11x?xf32>) { ^bb0(%in: f32, %out: f32): diff --git a/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir index c3ee892..d7722ea 100644 --- a/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir +++ b/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir @@ -230,18 +230,17 @@ func.func @vectorize_nd_tensor_extract_index_from_tensor(%arg0: tensor<3x3xf32>, // CHECK-SAME: %[[ARG4:.*]]: tensor<4x7x3x2xf32> // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[PV:.*]] = ub.poison : i32 -// CHECK-DAG: %[[CST:.*]] = arith.constant dense<3> : vector<7x2x4x3xindex> +// CHECK-DAG: %[[CST:.*]] = arith.constant dense<3> : vector<4x3xindex> // CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<true> : vector<4x7x3x2xi1> // CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<4x7x3x2xf32> // CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], %[[PV]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32> // CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], %[[PV]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32> // CHECK: %[[CAST:.*]] = arith.index_cast %[[V0]] : vector<4x3xi32> to vector<4x3xindex> -// CHECK: %[[B1:.*]] = vector.broadcast %[[CAST]] : vector<4x3xindex> to vector<7x2x4x3xindex> // CHECK: %[[CAST_1:.*]] = arith.index_cast %[[V1]] : vector<4x3xi32> to vector<4x3xindex> -// CHECK: %[[B2:.*]] = vector.broadcast %[[CAST_1]] : vector<4x3xindex> to vector<7x2x4x3xindex> -// CHECK: %[[MULI:.*]] = arith.muli %[[B1]], %[[CST]] : vector<7x2x4x3xindex> -// CHECK: %[[ADDI:.*]] = arith.addi %[[B2]], %[[MULI]] : vector<7x2x4x3xindex> -// CHECK: %[[T:.*]] = vector.transpose %[[ADDI]], [2, 0, 3, 1] : vector<7x2x4x3xindex> to vector<4x7x3x2xindex> +// CHECK: %[[MULI:.*]] = arith.muli %[[CAST]], %[[CST]] : vector<4x3xindex> +// CHECK: %[[ADDI:.*]] = arith.addi %[[CAST_1]], %[[MULI]] : vector<4x3xindex> +// CHECK: %[[B:.*]] = vector.broadcast %[[ADDI]] : vector<4x3xindex> to vector<7x2x4x3xindex> +// CHECK: %[[T:.*]] = vector.transpose %[[B]], [2, 0, 3, 1] : vector<7x2x4x3xindex> to vector<4x7x3x2xindex> // CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]]] [%[[T]]], %[[CST_1]], %[[PASSTHRU]] : tensor<3x3xf32>, vector<4x7x3x2xindex>, vector<4x7x3x2xi1>, vector<4x7x3x2xf32> into vector<4x7x3x2xf32> // CHECK: vector.transfer_write %[[GATHER]], %[[ARG4]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true]} : vector<4x7x3x2xf32>, tensor<4x7x3x2xf32> @@ -270,20 +269,16 @@ func.func @vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load(% // CHECK-SAME: %[[ARG0:.*]]: tensor<8x128x768xf32> // CHECK-SAME: %[[ARG1:.*]]: index // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[CST:.*]] = arith.constant dense<768> : vector<1x8xindex> -// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<128> : vector<1x8xindex> // CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32> -// CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<true> : vector<8x1xi1> -// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex> +// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<true> : vector<8x1xi1> +// CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<[0, 98304, 196608, 294912, 393216, 491520, 589824, 688128]> : vector<8xindex> // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<8x1xf32> -// CHECK: %[[B1:.*]] = vector.broadcast %[[CST_3]] : vector<8xindex> to vector<1x8xindex> // CHECK: %[[ADDI_ARG1:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : index -// CHECK: %[[MULI_1:.*]] = arith.muli %[[B1]], %[[CST_0]] : vector<1x8xindex> -// CHECK: %[[MULI_2:.*]] = arith.muli %[[MULI_1]], %[[CST]] : vector<1x8xindex> -// CHECK: %[[T:.*]] = vector.transpose %[[MULI_2]], [1, 0] : vector<1x8xindex> to vector<8x1xindex> +// CHECK: %[[B1:.*]] = vector.broadcast %[[CST_1]] : vector<8xindex> to vector<1x8xindex> +// CHECK: %[[T:.*]] = vector.transpose %[[B1]], [1, 0] : vector<1x8xindex> to vector<8x1xindex> // CHECK: %[[B2:.*]] = vector.broadcast %[[ADDI_ARG1]] : index to vector<8x1xindex> // CHECK: %[[ADDI:.*]] = arith.addi %[[B2]], %[[T]] : vector<8x1xindex> -// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_2]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32> +// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_0]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32> // CHECK: vector.transfer_write %[[GATHER]], %[[EMPTY]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32> // ----- @@ -309,15 +304,13 @@ func.func @index_from_output_column_vector_gather_load(%src: tensor<8x128xf32>) // CHECK-LABEL: func.func @index_from_output_column_vector_gather_load( // CHECK-SAME: %[[SRC:.*]]: tensor<8x128xf32>) -> tensor<8x1xf32> { -// CHECK: %[[C128:.*]] = arith.constant dense<128> : vector<1x8xindex> +// CHECK: %[[IDX_VEC:.*]] = arith.constant dense<[0, 128, 256, 384, 512, 640, 768, 896]> : vector<8xindex> // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[PASS_THRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32> // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8x1xi1> -// CHECK: %[[IDX_VEC:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex> // CHECK: %[[OUT:.*]] = tensor.empty() : tensor<8x1xf32> // CHECK: %[[B:.*]] = vector.broadcast %[[IDX_VEC]] : vector<8xindex> to vector<1x8xindex> -// CHECK: %[[MUL:.*]] = arith.muli %[[B]], %[[C128]] : vector<1x8xindex> -// CHECK: %[[TR:.*]] = vector.transpose %[[MUL]], [1, 0] : vector<1x8xindex> to vector<8x1xindex> +// CHECK: %[[TR:.*]] = vector.transpose %[[B]], [1, 0] : vector<1x8xindex> to vector<8x1xindex> // CHECK: %[[GATHER:.*]] = vector.gather %[[SRC]]{{\[}}%[[C0]], %[[C0]]] {{\[}}%[[TR]]], %[[MASK]], %[[PASS_THRU]] : tensor<8x128xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32> // CHECK: %[[RES:.*]] = vector.transfer_write %[[GATHER]], %[[OUT]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32> // CHECK: return %[[RES]] : tensor<8x1xf32> @@ -420,12 +413,12 @@ func.func @vectorize_nd_tensor_extract_with_affine_apply_gather(%6: tensor<80x16 // CHECK-DAG: %[[VAL_4:.*]] = arith.constant dense<true> : vector<1x4xi1> // CHECK-DAG: %[[VAL_5:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32> // CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant dense<16> : vector<1x4xindex> +// CHECK-DAG: %[[VAL_7:.*]] = arith.constant dense<16> : vector<4xindex> // CHECK: %[[VAL_8:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex> // CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : vector<4xindex> -// CHECK: %[[VAL_10:.*]] = vector.broadcast %[[VAL_9]] : vector<4xindex> to vector<1x4xindex> -// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_10]], %[[VAL_7]] : vector<1x4xindex> -// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_7]] : vector<1x4xindex> +// CHECK: %[[VAL_10:.*]] = arith.muli %[[VAL_9]], %[[VAL_7]] : vector<4xindex> +// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_7]] : vector<4xindex> +// CHECK: %[[VAL_12:.*]] = vector.broadcast %[[VAL_11]] : vector<4xindex> to vector<1x4xindex> // CHECK: %[[VAL_13:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_12]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32> // CHECK: %[[VAL_14:.*]] = vector.transfer_write %[[VAL_13]], %[[VAL_2]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32> // CHECK: return %[[VAL_14]] : tensor<1x4xf32> @@ -450,14 +443,12 @@ func.func @vectorize_nd_tensor_extract_with_maxsi_gather(%arg0: tensor<80x16xf32 // CHECK-LABEL: func.func @vectorize_nd_tensor_extract_with_maxsi_gather( // CHECK-SAME: %[[VAL_0:.*]]: tensor<80x16xf32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> { -// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex> -// CHECK-DAG: %[[VAL_3:.*]] = arith.constant dense<1264> : vector<1x4xindex> +// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[1264, 1265, 1266, 1267]> : vector<4xindex> // CHECK-DAG: %[[VAL_4:.*]] = arith.constant dense<true> : vector<1x4xi1> // CHECK-DAG: %[[VAL_5:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32> // CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index // CHECK: %[[VAL_7:.*]] = vector.broadcast %[[VAL_2]] : vector<4xindex> to vector<1x4xindex> -// CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_3]] : vector<1x4xindex> -// CHECK: %[[VAL_9:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_8]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32> +// CHECK: %[[VAL_9:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_7]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32> // CHECK: %[[VAL_10:.*]] = vector.transfer_write %[[VAL_9]], %[[VAL_1]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32> // CHECK: return %[[VAL_10]] : tensor<1x4xf32> // CHECK: } @@ -519,13 +510,13 @@ func.func @vectorize_reverse_like_tensor_extract(%arg0: tensor<1x2x3xf32>, %arg1 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]] // CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]] // CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]] -// CHECK-DAG: %[[CST:.+]] = arith.constant dense<3> : vector<1x1x3xindex> +// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index // CHECK-DAG: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x3xi1> // CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x3xf32> // CHECK-DAG: %[[INIT_IDX:.+]] = arith.constant dense<[2, 1, 0]> : vector<3xindex> -// CHECK: %[[T0:.+]] = vector.broadcast %[[ARG2]] : index to vector<1x1x3xindex> -// CHECK: %[[T1:.+]] = arith.muli %[[T0]], %[[CST]] : vector<1x1x3xindex> +// CHECK: %[[T0:.+]] = arith.muli %[[ARG2]], %[[C3]] : index +// CHECK: %[[T1:.+]] = vector.broadcast %[[T0]] : index to vector<1x1x3xindex> // CHECK: %[[T2:.+]] = vector.broadcast %[[INIT_IDX]] // CHECK: %[[T3:.+]] = arith.addi %[[T2]], %[[T1]] // CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[T3]]], %[[MASK]], %[[PASSTHRU]] diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index 12d30e17..308cf150 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -1440,8 +1440,8 @@ func.func @propagate_into_execute_region() { // ----- -// CHECK-LABEL: func @execute_region_elim -func.func @execute_region_elim() { +// CHECK-LABEL: func @execute_region_inline +func.func @execute_region_inline() { affine.for %i = 0 to 100 { "test.foo"() : () -> () %v = scf.execute_region -> i64 { @@ -1461,8 +1461,30 @@ func.func @execute_region_elim() { // ----- -// CHECK-LABEL: func @func_execute_region_elim -func.func @func_execute_region_elim() { +// CHECK-LABEL: func @execute_region_no_inline +func.func @execute_region_no_inline() { + affine.for %i = 0 to 100 { + "test.foo"() : () -> () + %v = scf.execute_region -> i64 no_inline { + %x = "test.val"() : () -> i64 + scf.yield %x : i64 + } + "test.bar"(%v) : (i64) -> () + } + return +} + +// CHECK-NEXT: affine.for %arg0 = 0 to 100 { +// CHECK-NEXT: "test.foo"() : () -> () +// CHECK-NEXT: scf.execute_region +// CHECK-NEXT: %[[VAL:.*]] = "test.val"() : () -> i64 +// CHECK-NEXT: scf.yield %[[VAL]] : i64 +// CHECK-NEXT: } + +// ----- + +// CHECK-LABEL: func @func_execute_region_inline +func.func @func_execute_region_inline() { "test.foo"() : () -> () %v = scf.execute_region -> i64 { %c = "test.cmp"() : () -> i1 @@ -1496,8 +1518,8 @@ func.func @func_execute_region_elim() { // ----- -// CHECK-LABEL: func @func_execute_region_elim_multi_yield -func.func @func_execute_region_elim_multi_yield() { +// CHECK-LABEL: func @func_execute_region_inline_multi_yield +func.func @func_execute_region_inline_multi_yield() { "test.foo"() : () -> () %v = scf.execute_region -> i64 { %c = "test.cmp"() : () -> i1 diff --git a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir index 3adafc1..c703274 100644 --- a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir @@ -13,7 +13,7 @@ func.func @fadd_scalar(%arg: f32) -> f32 { // ----- func.func @fadd_bf16_scalar(%arg: bf16) -> bf16 { - // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}} + // expected-error @+1 {{operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}} %0 = spirv.FAdd %arg, %arg : bf16 return %0 : bf16 } @@ -33,7 +33,7 @@ func.func @fdiv_scalar(%arg: f32) -> f32 { // ----- func.func @fdiv_bf16_scalar(%arg: bf16) -> bf16 { - // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}} + // expected-error @+1 {{operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}} %0 = spirv.FDiv %arg, %arg : bf16 return %0 : bf16 } @@ -53,7 +53,7 @@ func.func @fmod_scalar(%arg: f32) -> f32 { // ----- func.func @fmod_bf16_scalar(%arg: bf16) -> bf16 { - // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}} + // expected-error @+1 {{operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}} %0 = spirv.FMod %arg, %arg : bf16 return %0 : bf16 } @@ -79,7 +79,7 @@ func.func @fmul_vector(%arg: vector<4xf32>) -> vector<4xf32> { // ----- func.func @fmul_i32(%arg: i32) -> i32 { - // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}} + // expected-error @+1 {{operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}} %0 = spirv.FMul %arg, %arg : i32 return %0 : i32 } @@ -87,7 +87,7 @@ func.func @fmul_i32(%arg: i32) -> i32 { // ----- func.func @fmul_bf16(%arg: bf16) -> bf16 { - // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}} + // expected-error @+1 {{operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}} %0 = spirv.FMul %arg, %arg : bf16 return %0 : bf16 } @@ -95,7 +95,7 @@ func.func @fmul_bf16(%arg: bf16) -> bf16 { // ----- func.func @fmul_bf16_vector(%arg: vector<4xbf16>) -> vector<4xbf16> { - // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}} + // expected-error @+1 {{operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}} %0 = spirv.FMul %arg, %arg : vector<4xbf16> return %0 : vector<4xbf16> } @@ -103,7 +103,7 @@ func.func @fmul_bf16_vector(%arg: vector<4xbf16>) -> vector<4xbf16> { // ----- func.func @fmul_tensor(%arg: tensor<4xf32>) -> tensor<4xf32> { - // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}} + // expected-error @+1 {{operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}} %0 = spirv.FMul %arg, %arg : tensor<4xf32> return %0 : tensor<4xf32> } @@ -123,7 +123,7 @@ func.func @fnegate_scalar(%arg: f32) -> f32 { // ----- func.func @fnegate_bf16_scalar(%arg: bf16) -> bf16 { - // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}} + // expected-error @+1 {{operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}} %0 = spirv.FNegate %arg : bf16 return %0 : bf16 } @@ -143,7 +143,7 @@ func.func @frem_scalar(%arg: f32) -> f32 { // ----- func.func @frem_bf16_scalar(%arg: bf16) -> bf16 { - // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}} + // expected-error @+1 {{operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}} %0 = spirv.FRem %arg, %arg : bf16 return %0 : bf16 } @@ -163,7 +163,7 @@ func.func @fsub_scalar(%arg: f32) -> f32 { // ----- func.func @fsub_bf16_scalar(%arg: bf16) -> bf16 { - // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}} + // expected-error @+1 {{operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}} %0 = spirv.FSub %arg, %arg : bf16 return %0 : bf16 } @@ -348,7 +348,7 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f16 { // ----- func.func @dot(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 { - // expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16}} + // expected-error @+1 {{'spirv.Dot' op operand #0 must be fixed-length vector of 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16}} %0 = spirv.Dot %arg0, %arg1 : vector<4xi32> -> i32 return %0 : i32 } @@ -558,7 +558,7 @@ func.func @vector_times_scalar(%vector: vector<4xf32>, %scalar: f32) -> vector<3 // ----- func.func @vector_bf16_times_scalar_bf16(%vector: vector<4xbf16>, %scalar: bf16) -> vector<4xbf16> { - // expected-error @+1 {{op operand #0 must be vector of 16/32/64-bit float values of length 2/3/4}} + // expected-error @+1 {{operand #0 must be vector of 16/32/64-bit float values of length 2/3/4}} %0 = spirv.VectorTimesScalar %vector, %scalar : (vector<4xbf16>, bf16) -> vector<4xbf16> return %0 : vector<4xbf16> } diff --git a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir index f3f0ebf..4bdac19 100644 --- a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir @@ -137,7 +137,7 @@ func.func @bitwise_or_all_ones_vector(%arg: vector<3xi8>) -> vector<3xi8> { // ----- func.func @bitwise_or_float(%arg0: f16, %arg1: f16) -> f16 { - // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}} + // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4}} %0 = spirv.BitwiseOr %arg0, %arg1 : f16 return %0 : f16 } @@ -165,7 +165,7 @@ func.func @bitwise_xor_vector(%arg: vector<4xi32>) -> vector<4xi32> { // ----- func.func @bitwise_xor_float(%arg0: f16, %arg1: f16) -> f16 { - // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}} + // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4}} %0 = spirv.BitwiseXor %arg0, %arg1 : f16 return %0 : f16 } @@ -274,7 +274,7 @@ func.func @bitwise_and_zext_vector(%arg: vector<2xi8>) -> vector<2xi32> { // ----- func.func @bitwise_and_float(%arg0: f16, %arg1: f16) -> f16 { - // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}} + // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4}} %0 = spirv.BitwiseAnd %arg0, %arg1 : f16 return %0 : f16 } diff --git a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir index 5c5d94c..fd8a2ff 100644 --- a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir @@ -19,7 +19,7 @@ func.func @expvec(%arg0 : vector<3xf16>) -> () { // ----- func.func @exp(%arg0 : i32) -> () { - // expected-error @+1 {{op operand #0 must be 16/32-bit float or vector of 16/32-bit float values}} + // expected-error @+1 {{op operand #0 must be 16/32-bit float or fixed-length vector of 16/32-bit float values}} %2 = spirv.GL.Exp %arg0 : i32 return } @@ -27,7 +27,7 @@ func.func @exp(%arg0 : i32) -> () { // ----- func.func @exp(%arg0 : vector<5xf32>) -> () { - // expected-error @+1 {{op operand #0 must be 16/32-bit float or vector of 16/32-bit float values of length 2/3/4}} + // expected-error @+1 {{op operand #0 must be 16/32-bit float or fixed-length vector of 16/32-bit float values of length 2/3/4}} %2 = spirv.GL.Exp %arg0 : vector<5xf32> return } @@ -51,7 +51,7 @@ func.func @exp(%arg0 : i32) -> () { // ----- func.func @exp_bf16(%arg0 : bf16) -> () { - // expected-error @+1 {{op operand #0 must be 16/32-bit float or vector of 16/32-bit float values of length 2/3/4}} + // expected-error @+1 {{op operand #0 must be 16/32-bit float or fixed-length vector of 16/32-bit float values of length 2/3/4}} %2 = spirv.GL.Exp %arg0 : bf16 return } @@ -101,7 +101,7 @@ func.func @iminmax(%arg0: i32, %arg1: i32) { // ----- func.func @fmaxminbf16vec(%arg0 : vector<3xbf16>, %arg1 : vector<3xbf16>) { - // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}} + // expected-error @+1 {{operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}} %1 = spirv.GL.FMax %arg0, %arg1 : vector<3xbf16> %2 = spirv.GL.FMin %arg0, %arg1 : vector<3xbf16> return @@ -499,7 +499,7 @@ func.func @frexp_struct_mismatch_type(%arg0 : f32) -> () { // ----- func.func @frexp_struct_wrong_type(%arg0 : i32) -> () { - // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}} + // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}} %2 = spirv.GL.FrexpStruct %arg0 : i32 -> !spirv.struct<(i32, i32)> return } @@ -614,7 +614,7 @@ func.func @findimsb_vector_i64(%arg0 : vector<3xi64>) -> () { // ----- func.func @findimsb_error_scalar_float(%arg0 : f32) -> () { - // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/1}} + // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/1}} %2 = spirv.GL.FindILsb %arg0 : f32 return } @@ -640,7 +640,7 @@ func.func @findsmsb_vector(%arg0 : vector<3xi32>) -> () { // ----- func.func @findsmsb_error_scalar_i64(%arg0 : i64) -> () { - // expected-error @+1 {{operand #0 must be Int32 or vector of Int32}} + // expected-error @+1 {{operand #0 must be Int32 or fixed-length vector of Int32}} %2 = spirv.GL.FindSMsb %arg0 : i64 return } @@ -666,7 +666,7 @@ func.func @findumsb_vector(%arg0 : vector<3xi32>) -> () { // ----- func.func @findumsb(%arg0 : i64) -> () { - // expected-error @+1 {{operand #0 must be Int32 or vector of Int32}} + // expected-error @+1 {{operand #0 must be Int32 or fixed-length vector of Int32}} %2 = spirv.GL.FindUMsb %arg0 : i64 return } @@ -692,7 +692,7 @@ func.func @distance_vector(%arg0 : vector<3xf32>, %arg1 : vector<3xf32>) { // ----- func.func @distance_invalid_type(%arg0 : i32, %arg1 : i32) { - // expected-error @+1 {{'spirv.GL.Distance' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}} + // expected-error @+1 {{'spirv.GL.Distance' op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16}} %0 = spirv.GL.Distance %arg0, %arg1 : i32, i32 -> f32 return } @@ -708,7 +708,7 @@ func.func @distance_arg_mismatch(%arg0 : vector<3xf32>, %arg1 : vector<4xf32>) { // ----- func.func @distance_invalid_vector_size(%arg0 : vector<5xf32>, %arg1 : vector<5xf32>) { - // expected-error @+1 {{'spirv.GL.Distance' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}} + // expected-error @+1 {{'spirv.GL.Distance' op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16}} %0 = spirv.GL.Distance %arg0, %arg1 : vector<5xf32>, vector<5xf32> -> f32 return } @@ -736,7 +736,7 @@ func.func @cross(%arg0 : vector<3xf32>, %arg1 : vector<3xf32>) { // ----- func.func @cross_invalid_type(%arg0 : vector<3xi32>, %arg1 : vector<3xi32>) { - // expected-error @+1 {{'spirv.GL.Cross' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'vector<3xi32>'}} + // expected-error @+1 {{'spirv.GL.Cross' op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'vector<3xi32>'}} %0 = spirv.GL.Cross %arg0, %arg1 : vector<3xi32> return } @@ -762,7 +762,7 @@ func.func @normalize_vector(%arg0 : vector<3xf32>) { // ----- func.func @normalize_invalid_type(%arg0 : i32) { - // expected-error @+1 {{'spirv.GL.Normalize' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}} + // expected-error @+1 {{'spirv.GL.Normalize' op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}} %0 = spirv.GL.Normalize %arg0 : i32 return } @@ -788,7 +788,7 @@ func.func @reflect_vector(%arg0 : vector<3xf32>, %arg1 : vector<3xf32>) { // ----- func.func @reflect_invalid_type(%arg0 : i32, %arg1 : i32) { - // expected-error @+1 {{'spirv.GL.Reflect' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}} + // expected-error @+1 {{'spirv.GL.Reflect' op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}} %0 = spirv.GL.Reflect %arg0, %arg1 : i32 return } @@ -814,7 +814,7 @@ func.func @fractvec(%arg0 : vector<3xf16>) -> () { // ----- func.func @fract_invalid_type(%arg0 : i32) { - // expected-error @+1 {{'spirv.GL.Fract' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}} + // expected-error @+1 {{'spirv.GL.Fract' op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}} %0 = spirv.GL.Fract %arg0 : i32 return } @@ -840,7 +840,7 @@ func.func @log2vec(%arg0 : vector<3xf16>) -> () { // ----- func.func @log2_invalid_type(%arg0 : i32) -> () { - // expected-error @+1 {{op operand #0 must be 16/32-bit float or vector of 16/32-bit float values}} + // expected-error @+1 {{op operand #0 must be 16/32-bit float or fixed-length vector of 16/32-bit float values}} %0 = spirv.GL.Log2 %arg0 : i32 return } @@ -866,7 +866,7 @@ func.func @tanhvec(%arg0 : vector<3xf16>) -> () { // ----- func.func @tanh_invalid_type(%arg0 : i32) -> () { - // expected-error @+1 {{op operand #0 must be 16/32-bit float or vector of 16/32-bit float values}} + // expected-error @+1 {{op operand #0 must be 16/32-bit float or fixed-length vector of 16/32-bit float values}} %0 = spirv.GL.Tanh %arg0 : i32 return } @@ -892,7 +892,7 @@ func.func @exp2vec(%arg0 : vector<3xf16>) -> () { // ----- func.func @exp2_invalid_type(%arg0 : i32) -> () { - // expected-error @+1 {{op operand #0 must be 16/32-bit float or vector of 16/32-bit float values}} + // expected-error @+1 {{op operand #0 must be 16/32-bit float or fixed-length vector of 16/32-bit float values}} %0 = spirv.GL.Exp2 %arg0 : i32 return } @@ -1022,7 +1022,7 @@ func.func @lengthvec(%arg0 : vector<3xf32>) -> () { // ----- func.func @length_i32_in(%arg0 : i32) -> () { - // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'i32'}} + // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'i32'}} %0 = spirv.GL.Length %arg0 : i32 -> f32 return } @@ -1038,7 +1038,7 @@ func.func @length_f16_in(%arg0 : f16) -> () { // ----- func.func @length_i32vec_in(%arg0 : vector<3xi32>) -> () { - // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'vector<3xi32>'}} + // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'vector<3xi32>'}} %0 = spirv.GL.Length %arg0 : vector<3xi32> -> f32 return } diff --git a/mlir/test/Dialect/SPIRV/IR/group-ops.mlir b/mlir/test/Dialect/SPIRV/IR/group-ops.mlir index d9957ad8..d7a4a6d 100644 --- a/mlir/test/Dialect/SPIRV/IR/group-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/group-ops.mlir @@ -49,7 +49,7 @@ func.func @group_broadcast_negative_scope(%value: f32, %localid: vector<3xi32> ) // ----- func.func @group_broadcast_negative_locid_dtype(%value: f32, %localid: vector<3xf32> ) -> f32 { - // expected-error @+1 {{operand #1 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values}} + // expected-error @+1 {{op operand #1 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values}} %0 = spirv.GroupBroadcast <Subgroup> %value, %localid : f32, vector<3xf32> return %0: f32 } diff --git a/mlir/test/Dialect/SPIRV/IR/image-ops.mlir b/mlir/test/Dialect/SPIRV/IR/image-ops.mlir index d3aaef7..320a8fa 100644 --- a/mlir/test/Dialect/SPIRV/IR/image-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/image-ops.mlir @@ -349,7 +349,7 @@ func.func @image_fetch_2d_result(%arg0: !spirv.image<f32, Dim2D, NoDepth, NonArr // ----- func.func @image_fetch_float_coords(%arg0: !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Rgba8>, %arg1: vector<2xf32>) -> () { - // expected-error @+1 {{op operand #1 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'vector<2xf32>'}} + // expected-error @+1 {{op operand #1 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'vector<2xf32>'}} %0 = spirv.ImageFetch %arg0, %arg1 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Rgba8>, vector<2xf32> -> vector<2xf32> spirv.Return } diff --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir index bb15d01..2e2fb1a 100644 --- a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir @@ -21,7 +21,7 @@ spirv.func @f32_to_bf16_vec(%arg0 : vector<2xf32>) "None" { // ----- spirv.func @f32_to_bf16_unsupported(%arg0 : f64) "None" { - // expected-error @+1 {{operand #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got}} + // expected-error @+1 {{operand #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16, but got}} %0 = spirv.INTEL.ConvertFToBF16 %arg0 : f64 to i16 spirv.Return } @@ -29,7 +29,7 @@ spirv.func @f32_to_bf16_unsupported(%arg0 : f64) "None" { // ----- spirv.func @f32_to_bf16_vec_unsupported(%arg0 : vector<2xf32>) "None" { - // expected-error @+1 {{operand and result must have same number of elements}} + // expected-error @+1 {{op requires the same shape for all operands and results}} %0 = spirv.INTEL.ConvertFToBF16 %arg0 : vector<2xf32> to vector<4xi16> spirv.Return } @@ -57,7 +57,7 @@ spirv.func @bf16_to_f32_vec(%arg0 : vector<2xi16>) "None" { // ----- spirv.func @bf16_to_f32_unsupported(%arg0 : i16) "None" { - // expected-error @+1 {{result #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got}} + // expected-error @+1 {{result #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16, but got}} %0 = spirv.INTEL.ConvertBF16ToF %arg0 : i16 to f16 spirv.Return } @@ -65,7 +65,7 @@ spirv.func @bf16_to_f32_unsupported(%arg0 : i16) "None" { // ----- spirv.func @bf16_to_f32_vec_unsupported(%arg0 : vector<2xi16>) "None" { - // expected-error @+1 {{operand and result must have same number of elements}} + // expected-error @+1 {{op requires the same shape for all operands and results}} %0 = spirv.INTEL.ConvertBF16ToF %arg0 : vector<2xi16> to vector<3xf32> spirv.Return } @@ -73,6 +73,42 @@ spirv.func @bf16_to_f32_vec_unsupported(%arg0 : vector<2xi16>) "None" { // ----- //===----------------------------------------------------------------------===// +// spirv.INTEL.RoundFToTF32 +//===----------------------------------------------------------------------===// + +spirv.func @f32_to_tf32(%arg0 : f32) "None" { + // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : f32 to f32 + %0 = spirv.INTEL.RoundFToTF32 %arg0 : f32 to f32 + spirv.Return +} + +// ----- + +spirv.func @f32_to_tf32_vec(%arg0 : vector<2xf32>) "None" { + // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : vector<2xf32> to vector<2xf32> + %0 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<2xf32> + spirv.Return +} + +// ----- + +spirv.func @f32_to_tf32_unsupported(%arg0 : f64) "None" { + // expected-error @+1 {{op operand #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16, but got 'f64'}} + %0 = spirv.INTEL.RoundFToTF32 %arg0 : f64 to f32 + spirv.Return +} + +// ----- + +spirv.func @f32_to_tf32_vec_unsupported(%arg0 : vector<2xf32>) "None" { + // expected-error @+1 {{op requires the same shape for all operands and results}} + %0 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<4xf32> + spirv.Return +} + +// ----- + +//===----------------------------------------------------------------------===// // spirv.INTEL.SplitBarrier //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir index 61a35b7..491c7a7 100644 --- a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir @@ -583,7 +583,7 @@ spirv.func @matrix_times_scalar(%a: !spirv.coopmatrix<2x2xf32, Workgroup, Matrix // These binary arithmetic instructions do not support coop matrix operands. spirv.func @fmod(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>) "None" { - // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}} + // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16}} %p = spirv.FMod %a, %b : !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA> spirv.Return } @@ -591,14 +591,14 @@ spirv.func @fmod(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, %b: !spirv.c // ----- spirv.func @frem(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>) "None" { - // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}} + // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16}} %p = spirv.FRem %a, %b : !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA> spirv.Return } // ----- spirv.func @smod(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>) "None" { - // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}} + // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}} %p = spirv.SMod %a, %b : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA> spirv.Return } @@ -606,7 +606,7 @@ spirv.func @smod(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.c // ----- spirv.func @srem(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>) "None" { - // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}} + // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}} %p = spirv.SRem %a, %b : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA> spirv.Return } @@ -614,7 +614,7 @@ spirv.func @srem(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.c // ----- spirv.func @umod(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>) "None" { - // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}} + // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}} %p = spirv.UMod %a, %b : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA> spirv.Return } diff --git a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir index d6c3464..d7f4ed0 100644 --- a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir @@ -33,6 +33,24 @@ func.func @inotequal_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vecto // ----- //===----------------------------------------------------------------------===// +// spirv.IsFinite +//===----------------------------------------------------------------------===// + +func.func @isfinite_scalar(%arg0: f32) -> i1 { + // CHECK: spirv.IsFinite {{.*}} : f32 + %0 = spirv.IsFinite %arg0 : f32 + return %0 : i1 +} + +func.func @isfinite_vector(%arg0: vector<2xf32>) -> vector<2xi1> { + // CHECK: spirv.IsFinite {{.*}} : vector<2xf32> + %0 = spirv.IsFinite %arg0 : vector<2xf32> + return %0 : vector<2xi1> +} + +// ----- + +//===----------------------------------------------------------------------===// // spirv.IsInf //===----------------------------------------------------------------------===// @@ -166,7 +184,7 @@ func.func @logicalUnary(%arg0 : i1) func.func @logicalUnary(%arg0 : i32) { - // expected-error @+1 {{'operand' must be bool or vector of bool values of length 2/3/4/8/16, but got 'i32'}} + // expected-error @+1 {{'operand' must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}} %0 = spirv.LogicalNot %arg0 : i32 return } diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir index 7ab94f1..bdb2abd 100644 --- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir @@ -185,7 +185,7 @@ func.func @group_non_uniform_fmul_clustered_reduce(%val: vector<2xf32>) -> vecto // ----- func.func @group_non_uniform_bf16_fmul_reduce(%val: bf16) -> bf16 { - // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'bf16'}} + // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'bf16'}} %0 = spirv.GroupNonUniformFMul <Workgroup> <Reduce> %val : bf16 -> bf16 return %0: bf16 } @@ -206,7 +206,7 @@ func.func @group_non_uniform_fmax_reduce(%val: f32) -> f32 { // ----- func.func @group_non_uniform_bf16_fmax_reduce(%val: bf16) -> bf16 { - // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'bf16'}} + // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'bf16'}} %0 = spirv.GroupNonUniformFMax <Workgroup> <Reduce> %val : bf16 -> bf16 return %0: bf16 } @@ -511,7 +511,7 @@ func.func @group_non_uniform_bitwise_and(%val: i32) -> i32 { // ----- func.func @group_non_uniform_bitwise_and(%val: i1) -> i1 { - // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}} + // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}} %0 = spirv.GroupNonUniformBitwiseAnd <Workgroup> <Reduce> %val : i1 -> i1 return %0: i1 } @@ -532,7 +532,7 @@ func.func @group_non_uniform_bitwise_or(%val: i32) -> i32 { // ----- func.func @group_non_uniform_bitwise_or(%val: i1) -> i1 { - // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}} + // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}} %0 = spirv.GroupNonUniformBitwiseOr <Workgroup> <Reduce> %val : i1 -> i1 return %0: i1 } @@ -553,7 +553,7 @@ func.func @group_non_uniform_bitwise_xor(%val: i32) -> i32 { // ----- func.func @group_non_uniform_bitwise_xor(%val: i1) -> i1 { - // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}} + // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'i1'}} %0 = spirv.GroupNonUniformBitwiseXor <Workgroup> <Reduce> %val : i1 -> i1 return %0: i1 } @@ -574,7 +574,7 @@ func.func @group_non_uniform_logical_and(%val: i1) -> i1 { // ----- func.func @group_non_uniform_logical_and(%val: i32) -> i32 { - // expected-error @+1 {{operand #0 must be bool or vector of bool values of length 2/3/4/8/16, but got 'i32'}} + // expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}} %0 = spirv.GroupNonUniformLogicalAnd <Workgroup> <Reduce> %val : i32 -> i32 return %0: i32 } @@ -595,7 +595,7 @@ func.func @group_non_uniform_logical_or(%val: i1) -> i1 { // ----- func.func @group_non_uniform_logical_or(%val: i32) -> i32 { - // expected-error @+1 {{operand #0 must be bool or vector of bool values of length 2/3/4/8/16, but got 'i32'}} + // expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}} %0 = spirv.GroupNonUniformLogicalOr <Workgroup> <Reduce> %val : i32 -> i32 return %0: i32 } @@ -616,7 +616,7 @@ func.func @group_non_uniform_logical_xor(%val: i1) -> i1 { // ----- func.func @group_non_uniform_logical_xor(%val: i32) -> i32 { - // expected-error @+1 {{operand #0 must be bool or vector of bool values of length 2/3/4/8/16, but got 'i32'}} + // expected-error @+1 {{operand #0 must be bool or fixed-length vector of bool values of length 2/3/4/8/16, but got 'i32'}} %0 = spirv.GroupNonUniformLogicalXor <Workgroup> <Reduce> %val : i32 -> i32 return %0: i32 } diff --git a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir index 8f021ed..6aaaa60 100644 --- a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir @@ -19,7 +19,7 @@ func.func @expvec(%arg0 : vector<3xf16>) -> () { // ----- func.func @exp(%arg0 : i32) -> () { - // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}} + // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}} %2 = spirv.CL.exp %arg0 : i32 return } @@ -27,7 +27,7 @@ func.func @exp(%arg0 : i32) -> () { // ----- func.func @exp(%arg0 : vector<5xf32>) -> () { - // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4}} + // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4}} %2 = spirv.CL.exp %arg0 : vector<5xf32> return } @@ -75,7 +75,7 @@ func.func @fabsf64(%arg0 : f64) -> () { // ----- func.func @fabs(%arg0 : i32) -> () { - // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}} + // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values}} %2 = spirv.CL.fabs %arg0 : i32 return } @@ -83,7 +83,7 @@ func.func @fabs(%arg0 : i32) -> () { // ----- func.func @fabs(%arg0 : vector<5xf32>) -> () { - // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4}} + // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or fixed-length vector of 16/32/64-bit float values of length 2/3/4}} %2 = spirv.CL.fabs %arg0 : vector<5xf32> return } @@ -137,7 +137,7 @@ func.func @sabsi8(%arg0 : i8) -> () { // ----- func.func @sabs(%arg0 : f32) -> () { - // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values}} + // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values}} %2 = spirv.CL.s_abs %arg0 : f32 return } @@ -145,7 +145,7 @@ func.func @sabs(%arg0 : f32) -> () { // ----- func.func @sabs(%arg0 : vector<5xi32>) -> () { - // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4}} + // expected-error @+1 {{op operand #0 must be 8/16/32/64-bit integer or fixed-length vector of 8/16/32/64-bit integer values of length 2/3/4}} %2 = spirv.CL.s_abs %arg0 : vector<5xi32> return } diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir index 5d05a654..6d321af 100644 --- a/mlir/test/Dialect/SPIRV/IR/types.mlir +++ b/mlir/test/Dialect/SPIRV/IR/types.mlir @@ -296,6 +296,12 @@ func.func private @struct_type_with_matrix_2(!spirv.struct<(!spirv.matrix<3 x ve // CHECK: func private @struct_empty(!spirv.struct<()>) func.func private @struct_empty(!spirv.struct<()>) +// CHECK: func.func private @struct_block(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>) +func.func private @struct_block(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>) + +// CHECK: func.func private @struct_two_dec(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block, CPacked>) +func.func private @struct_two_dec(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block, CPacked>) + // ----- // expected-error @+1 {{offset specification must be given for all members}} diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir index bd51a07..f3a3218 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir @@ -66,3 +66,27 @@ spirv.module Logical GLSL450 attributes {spirv.target_env = #spirv.target_env<#s // CHECK: spirv.EntryPoint "GLCompute" [[FN]], [[VAR0]], [[VAR1]] // CHECK: spirv.ExecutionMode [[FN]] "LocalSize", 32, 1, 1 } // end spirv.module + +// ----- + +module { + spirv.module Logical GLSL450 attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Sampled1D], []>, #spirv.resource_limits<>>} { + // CHECK-DAG: spirv.GlobalVariable @[[IMAGE_GV:.*]] bind(0, 0) : !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant> + // CHECK: spirv.func @read_image + spirv.func @read_image(%arg0: !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) "None" attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [1, 1, 1]>} { + // CHECK: %[[IMAGE_ADDR:.*]] = spirv.mlir.addressof @[[IMAGE_GV]] : !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant> + %cst0_i32 = spirv.Constant 0 : i32 + // CHECK: spirv.Load "UniformConstant" %[[IMAGE_ADDR]] + %0 = spirv.Load "UniformConstant" %arg0 : !spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>> + %1 = spirv.Image %0 : !spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>> + %2 = spirv.ImageFetch %1, %cst0_i32 : !spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>, i32 -> vector<4xf32> + %3 = spirv.CompositeExtract %2[0 : i32] : vector<4xf32> + %cst0_i32_0 = spirv.Constant 0 : i32 + %cst0_i32_1 = spirv.Constant 0 : i32 + %cst1_i32 = spirv.Constant 1 : i32 + %4 = spirv.AccessChain %arg1[%cst0_i32_0, %cst0_i32] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer> + spirv.Store "StorageBuffer" %4, %3 : f32 + spirv.Return + } + } +} diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir index 2b23766..8d7f3da 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir @@ -178,7 +178,7 @@ spirv.module Logical GLSL450 attributes { // Vulkan memory model requires SPV_KHR_vulkan_memory_model, which is enabled // implicitly by v1.5. -// CHECK: requires #spirv.vce<v1.0, [VulkanMemoryModel], [SPV_KHR_vulkan_memory_model]> +// CHECK: requires #spirv.vce<v1.5, [VulkanMemoryModel], [SPV_KHR_vulkan_memory_model]> spirv.module Logical Vulkan attributes { spirv.target_env = #spirv.target_env< #spirv.vce<v1.5, [Shader, VulkanMemoryModel], []>, #spirv.resource_limits<>> diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 6b55442..5150ee3 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -241,6 +241,26 @@ func.func @clamp_f32_is_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> { // ----- +// CHECK-LABEL: @clamp_boolean_is_noop +func.func @clamp_boolean_is_noop(%arg0: tensor<4xi1>) -> tensor<4xi1> { + // CHECK: return %arg0 + // CHECK-NOT: tosa.clamp + %0 = tosa.clamp %arg0 {min_val = false, max_val = true} : (tensor<4xi1>) -> tensor<4xi1> + return %0 : tensor<4xi1> +} + +// ----- + +// CHECK-LABEL: @clamp_boolean_dynamic_is_noop +func.func @clamp_boolean_dynamic_is_noop(%arg0: tensor<?xi1>) -> tensor<?xi1> { + // CHECK: return %arg0 + // CHECK-NOT: tosa.clamp + %0 = tosa.clamp %arg0 {min_val = false, max_val = true} : (tensor<?xi1>) -> tensor<?xi1> + return %0 : tensor<?xi1> +} + +// ----- + // CHECK-LABEL: @clamp_int8_is_noop func.func @clamp_int8_is_noop(%arg0: tensor<4xi8>) -> tensor<4xi8> { // CHECK: return %arg0 diff --git a/mlir/test/Dialect/Tosa/dynamic_extension.mlir b/mlir/test/Dialect/Tosa/dynamic_extension.mlir index 8739f97..e23ce430 100644 --- a/mlir/test/Dialect/Tosa/dynamic_extension.mlir +++ b/mlir/test/Dialect/Tosa/dynamic_extension.mlir @@ -2,7 +2,7 @@ // Check operations when the dynamic extension is enabled. //-------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp extension=dynamic strict-op-spec-alignment allow-invalid-op-datatype-combinations" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp extension=dynamic allow-invalid-op-datatype-combinations" // ----- diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index b90d6f5..3bccb32 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -2036,3 +2036,19 @@ func.func @test_scatter_duplicate_indices(%arg0: tensor<2x52x3xf32>, %arg2: tens %0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi32>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32> return %0 : tensor<2x52x3xf32> } + +// ----- + +func.func @test_reduce_all_unsupported_data_types(%arg0: tensor<2x12x11xf32>) -> tensor<1x12x11xf32> { + // expected-error@+1 {{'tosa.reduce_all' op illegal: operation operand/result data types did not align with any profile or extension, got (f32,f32), did you mean (i1,i1)?}} + %0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<2x12x11xf32>) -> tensor<1x12x11xf32> + return %0 : tensor<1x12x11xf32> +} + +// ----- + +func.func @test_rfft2d(%arg0: tensor<13x8x16xbf16>) -> (tensor<13x8x9xbf16>, tensor<13x8x9xbf16>) { + // expected-error@+1 {{'tosa.rfft2d' op illegal: operation operand/result data types did not align with any profile or extension, got (bf16,bf16,bf16), did you mean (f32,f32,f32)?}} + %0, %1 = tosa.rfft2d %arg0 : (tensor<13x8x16xbf16>) -> (tensor<13x8x9xbf16>, tensor<13x8x9xbf16>) + return %0, %1 : tensor<13x8x9xbf16>, tensor<13x8x9xbf16> +} diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index cbe0056..0184d2b 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate +// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="extension=dynamic" func.func @test_argmax_rank_invalid(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> { // expected-error@+1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}} @@ -48,10 +48,10 @@ func.func @test_add_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tens // ----- -func.func @test_arithmetic_right_shift_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>, %arg1: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> { +func.func @test_arithmetic_right_shift_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi32>, %arg1: tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32> { // expected-error@+1 {{'tosa.arithmetic_right_shift' op failed level check: operand rank(shape) <= MAX_RANK}} - %0 = tosa.arithmetic_right_shift %arg0, %arg1 {round = false} : (tensor<1x1x1x1x13x21x3xf32>, tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> - return %0 : tensor<1x1x1x1x13x21x3xf32> + %0 = tosa.arithmetic_right_shift %arg0, %arg1 {round = false} : (tensor<1x1x1x1x13x21x3xi32>, tensor<1x1x1x1x13x21x3xi32>) -> tensor<1x1x1x1x13x21x3xi32> + return %0 : tensor<1x1x1x1x13x21x3xi32> } // ----- diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 1461c30..f86fb38 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -823,11 +823,11 @@ func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32 // ----- -// CHECK-LABEL: fold_extract_scalar_from_splat +// CHECK-LABEL: fold_extract_splatlike // CHECK-SAME: %[[A:.*]]: f32 // CHECK: return %[[A]] : f32 -func.func @fold_extract_scalar_from_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 { - %b = vector.splat %a : vector<1x2x4xf32> +func.func @fold_extract_splatlike(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 { + %b = vector.broadcast %a : f32 to vector<1x2x4xf32> %r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32> return %r : f32 } @@ -1330,11 +1330,11 @@ func.func @fold_consecutive_broadcasts(%a : i32) -> vector<4x16xi32> { // ----- -// CHECK-LABEL: shape_cast_constant +// CHECK-LABEL: shape_cast_splat_constant // CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1> : vector<3x4x2xi32> // CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<20x2xf32> // CHECK: return %[[CST0]], %[[CST1]] : vector<20x2xf32>, vector<3x4x2xi32> -func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) { +func.func @shape_cast_splat_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) { %cst = arith.constant dense<2.000000e+00> : vector<5x4x2xf32> %cst_1 = arith.constant dense<1> : vector<12x2xi32> %0 = vector.shape_cast %cst : vector<5x4x2xf32> to vector<20x2xf32> @@ -1344,6 +1344,36 @@ func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) { // ----- +// Test of shape_cast's fold method: +// shape_cast(constant) -> constant. +// +// CHECK-LABEL: @shape_cast_dense_int_constant +// CHECK: %[[CST:.*]] = arith.constant +// CHECK-SAME{LITERAL}: dense<[[2, 3, 5], [7, 11, 13]]> +// CHECK: return %[[CST]] : vector<2x3xi8> +func.func @shape_cast_dense_int_constant() -> vector<2x3xi8> { + %cst = arith.constant dense<[2, 3, 5, 7, 11, 13]> : vector<6xi8> + %0 = vector.shape_cast %cst : vector<6xi8> to vector<2x3xi8> + return %0 : vector<2x3xi8> +} + +// ----- + +// Test of shape_cast fold's method: +// (shape_cast(const_x), const_x) -> (const_x_folded, const_x) +// +// CHECK-LABEL: @shape_cast_dense_float_constant +// CHECK-DAG: %[[CST0:.*]] = {{.*}}1.000000e+00, 2.000000e+00{{.*}} vector<1x2xf32> +// CHECK-DAG: %[[CST1:.*]] = {{.*}}1.000000e+00, 2.000000e+00{{.*}} vector<2xf32> +// CHECK: return %[[CST1]], %[[CST0]] : vector<2xf32>, vector<1x2xf32> +func.func @shape_cast_dense_float_constant() -> (vector<2xf32>, vector<1x2xf32>){ + %cst = arith.constant dense<[[1.0, 2.0]]> : vector<1x2xf32> + %0 = vector.shape_cast %cst : vector<1x2xf32> to vector<2xf32> + return %0, %cst : vector<2xf32>, vector<1x2xf32> +} + +// ----- + // CHECK-LABEL: shape_cast_poison // CHECK-DAG: %[[CST1:.*]] = ub.poison : vector<3x4x2xi32> // CHECK-DAG: %[[CST0:.*]] = ub.poison : vector<20x2xf32> @@ -2033,11 +2063,11 @@ func.func @insert_strided_slice_full_range(%source: vector<16x16xf16>, %dest: ve // ----- -// CHECK-LABEL: extract_strided_splat -// CHECK: %[[B:.*]] = vector.splat %{{.*}} : vector<2x4xf16> +// CHECK-LABEL: extract_strided_splatlike +// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} f16 to vector<2x4xf16> // CHECK-NEXT: return %[[B]] : vector<2x4xf16> -func.func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> { - %0 = vector.splat %arg0 : vector<16x4xf16> +func.func @extract_strided_splatlike(%arg0: f16) -> vector<2x4xf16> { + %0 = vector.broadcast %arg0 : f16 to vector<16x4xf16> %1 = vector.extract_strided_slice %0 {offsets = [1, 0], sizes = [2, 4], strides = [1, 1]} : vector<16x4xf16> to vector<2x4xf16> @@ -2323,14 +2353,14 @@ func.func @extract_extract_strided2(%A: vector<2x4xf32>) // ----- -// CHECK-LABEL: func @splat_fold -func.func @splat_fold() -> vector<4xf32> { +// CHECK-LABEL: func @splatlike_fold +// CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32> +// CHECK-NEXT: return [[V]] : vector<4xf32> +func.func @splatlike_fold() -> vector<4xf32> { %c = arith.constant 1.0 : f32 - %v = vector.splat %c : vector<4xf32> + %v = vector.broadcast %c : f32 to vector<4xf32> return %v : vector<4xf32> - // CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32> - // CHECK-NEXT: return [[V]] : vector<4xf32> } // ----- @@ -2469,10 +2499,10 @@ func.func @shuffle_nofold1(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<5 // ----- -// CHECK-LABEL: func @transpose_splat_constant +// CHECK-LABEL: func @transpose_splatlike_constant // CHECK: %[[CST:.+]] = arith.constant dense<5.000000e+00> : vector<8x4xf32> // CHECK: return %[[CST]] -func.func @transpose_splat_constant() -> vector<8x4xf32> { +func.func @transpose_splatlike_constant() -> vector<8x4xf32> { %cst = arith.constant dense<5.0> : vector<4x8xf32> %0 = vector.transpose %cst, [1, 0] : vector<4x8xf32> to vector<8x4xf32> return %0 : vector<8x4xf32> @@ -2480,13 +2510,13 @@ func.func @transpose_splat_constant() -> vector<8x4xf32> { // ----- -// CHECK-LABEL: func @transpose_splat2( -// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> { -// CHECK: %[[VAL_1:.*]] = vector.splat %[[VAL_0]] : vector<3x4xf32> -// CHECK: return %[[VAL_1]] : vector<3x4xf32> -// CHECK: } -func.func @transpose_splat2(%arg : f32) -> vector<3x4xf32> { - %splat = vector.splat %arg : vector<4x3xf32> +// CHECK-LABEL: func @transpose_splatlike2( +// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> { +// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : f32 to vector<3x4xf32> +// CHECK: return %[[VAL_1]] : vector<3x4xf32> +// CHECK: } +func.func @transpose_splatlike2(%arg : f32) -> vector<3x4xf32> { + %splat = vector.broadcast %arg : f32 to vector<4x3xf32> %0 = vector.transpose %splat, [1, 0] : vector<4x3xf32> to vector<3x4xf32> return %0 : vector<3x4xf32> } @@ -2562,118 +2592,6 @@ func.func @insert_2d_splat_constant() // ----- -// CHECK-LABEL: func @insert_element_fold -// CHECK: %[[V:.+]] = arith.constant dense<[0, 1, 7, 3]> : vector<4xi32> -// CHECK: return %[[V]] -func.func @insert_element_fold() -> vector<4xi32> { - %v = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32> - %s = arith.constant 7 : i32 - %i = arith.constant 2 : i32 - %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32> - return %1 : vector<4xi32> -} - -// ----- - -// CHECK-LABEL: func @insert_element_invalid_fold -func.func @insert_element_invalid_fold() -> vector<1xf32> { - // Out-of-bound index here. - %c26 = arith.constant 26 : index - %cst_2 = arith.constant 1.60215309E+9 : f32 - %cst_20 = arith.constant dense<1.60215309E+9> : vector<1xf32> -// CHECK: vector.insertelement - %46 = vector.insertelement %cst_2, %cst_20[%c26 : index] : vector<1xf32> - return %46 : vector<1xf32> -} - - -// ----- - -// Do not crash on poison -// CHECK-LABEL: func @insert_poison_fold1 -// CHECK: vector.insertelement -func.func @insert_poison_fold1() -> vector<4xi32> { - %v = ub.poison : vector<4xi32> - %s = arith.constant 7 : i32 - %i = arith.constant 2 : i32 - %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32> - return %1 : vector<4xi32> -} - -// ----- - -// Do not crash on poison -// CHECK-LABEL: func @insert_poison_fold2 -// CHECK: vector.insertelement -func.func @insert_poison_fold2() -> vector<4xi32> { - %v = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32> - %s = ub.poison : i32 - %i = arith.constant 2 : i32 - %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32> - return %1 : vector<4xi32> -} - -// ----- - -// Do not crash on poison -// CHECK-LABEL: func @insert_poison_fold3 -// CHECK: vector.insertelement -func.func @insert_poison_fold3() -> vector<4xi32> { - %v = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32> - %s = arith.constant 7 : i32 - %i = ub.poison : i32 - %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32> - return %1 : vector<4xi32> -} - -// ----- - -// CHECK-LABEL: func @extract_element_fold -// CHECK: %[[C:.+]] = arith.constant 5 : i32 -// CHECK: return %[[C]] -func.func @extract_element_fold() -> i32 { - %v = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32> - %i = arith.constant 2 : i32 - %1 = vector.extractelement %v[%i : i32] : vector<4xi32> - return %1 : i32 -} - -// CHECK-LABEL: func @extract_element_splat_fold -// CHECK-SAME: (%[[ARG:.+]]: i32) -// CHECK: return %[[ARG]] -func.func @extract_element_splat_fold(%a : i32) -> i32 { - %v = vector.splat %a : vector<4xi32> - %i = arith.constant 2 : i32 - %1 = vector.extractelement %v[%i : i32] : vector<4xi32> - return %1 : i32 -} - -// ----- - -// Do not crash on poison -// CHECK-LABEL: func @extract_element_poison_fold1 -// CHECK: vector.extractelement -func.func @extract_element_poison_fold1() -> i32 { - %v = ub.poison : vector<4xi32> - %i = arith.constant 2 : i32 - %1 = vector.extractelement %v[%i : i32] : vector<4xi32> - return %1 : i32 -} - -// ----- - -// Do not crash on poison -// CHECK-LABEL: func @extract_element_poison_fold2 -// CHECK: vector.extractelement -func.func @extract_element_poison_fold2() -> i32 { - %v = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32> - %i = ub.poison : i32 - %1 = vector.extractelement %v[%i : i32] : vector<4xi32> - return %1 : i32 -} - -// ----- - // CHECK-LABEL: func @reduce_one_element_vector_extract // CHECK-SAME: (%[[V:.+]]: vector<1xf32>) // CHECK: %[[S:.+]] = vector.extract %[[V]][0] : f32 from vector<1xf32> @@ -2781,13 +2699,13 @@ func.func @bitcast(%a: vector<4x8xf32>) -> vector<4x16xi16> { // ----- -// CHECK-LABEL: @insert_strided_slice_splat +// CHECK-LABEL: @insert_strided_slice_splatlike // CHECK-SAME: (%[[ARG:.*]]: f32) -// CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8x16xf32> +// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : f32 to vector<8x16xf32> // CHECK-NEXT: return %[[SPLAT]] : vector<8x16xf32> -func.func @insert_strided_slice_splat(%x: f32) -> (vector<8x16xf32>) { - %splat0 = vector.splat %x : vector<4x4xf32> - %splat1 = vector.splat %x : vector<8x16xf32> +func.func @insert_strided_slice_splatlike(%x: f32) -> (vector<8x16xf32>) { + %splat0 = vector.broadcast %x : f32 to vector<4x4xf32> + %splat1 = vector.broadcast %x : f32 to vector<8x16xf32> %0 = vector.insert_strided_slice %splat0, %splat1 {offsets = [2, 2], strides = [1, 1]} : vector<4x4xf32> into vector<8x16xf32> return %0 : vector<8x16xf32> @@ -2860,13 +2778,13 @@ func.func @insert_strided_2d_constant() -> // ----- -// CHECK-LABEL: func @shuffle_splat +// CHECK-LABEL: func @shuffle_splatlike // CHECK-SAME: (%[[ARG:.*]]: i32) -// CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<4xi32> +// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : i32 to vector<4xi32> // CHECK-NEXT: return %[[SPLAT]] : vector<4xi32> -func.func @shuffle_splat(%x : i32) -> vector<4xi32> { - %v0 = vector.splat %x : vector<4xi32> - %v1 = vector.splat %x : vector<2xi32> +func.func @shuffle_splatlike(%x : i32) -> vector<4xi32> { + %v0 = vector.broadcast %x : i32 to vector<4xi32> + %v1 = vector.broadcast %x : i32 to vector<2xi32> %shuffle = vector.shuffle %v0, %v1 [2, 3, 4, 5] : vector<4xi32>, vector<2xi32> return %shuffle : vector<4xi32> } @@ -2874,13 +2792,13 @@ func.func @shuffle_splat(%x : i32) -> vector<4xi32> { // ----- -// CHECK-LABEL: func @insert_splat +// CHECK-LABEL: func @insert_splatlike // CHECK-SAME: (%[[ARG:.*]]: i32) -// CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<2x4x3xi32> +// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : i32 to vector<2x4x3xi32> // CHECK-NEXT: return %[[SPLAT]] : vector<2x4x3xi32> -func.func @insert_splat(%x : i32) -> vector<2x4x3xi32> { - %v0 = vector.splat %x : vector<4x3xi32> - %v1 = vector.splat %x : vector<2x4x3xi32> +func.func @insert_splatlike(%x : i32) -> vector<2x4x3xi32> { + %v0 = vector.broadcast %x : i32 to vector<4x3xi32> + %v1 = vector.broadcast %x : i32 to vector<2x4x3xi32> %insert = vector.insert %v0, %v1[0] : vector<4x3xi32> into vector<2x4x3xi32> return %insert : vector<2x4x3xi32> } @@ -2933,18 +2851,6 @@ func.func @extract_strided_slice_of_constant_mask() -> vector<5x7xi1>{ // ----- -// CHECK-LABEL: func.func @fold_extractelement_of_broadcast( -// CHECK-SAME: %[[f:.*]]: f32 -// CHECK: return %[[f]] -func.func @fold_extractelement_of_broadcast(%f: f32) -> f32 { - %0 = vector.broadcast %f : f32 to vector<15xf32> - %c5 = arith.constant 5 : index - %1 = vector.extractelement %0 [%c5 : index] : vector<15xf32> - return %1 : f32 -} - -// ----- - // CHECK-LABEL: func.func @fold_0d_vector_reduction func.func @fold_0d_vector_reduction(%arg0: vector<f32>) -> f32 { // CHECK-NEXT: %[[RES:.*]] = vector.extract %arg{{.*}}[] : f32 from vector<f32> @@ -3124,11 +3030,11 @@ func.func @rank_1_shuffle_to_interleave(%arg0: vector<6xi32>, %arg1: vector<6xi3 // ----- -// CHECK-LABEL: func @extract_from_0d_splat_broadcast_regression( -// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: vector<f32>, %[[c:.*]]: vector<2xf32>) -func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>, %c: vector<2xf32>) -> (f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) { - // Splat scalar to 0D and extract scalar. - %0 = vector.splat %a : vector<f32> +// CHECK-LABEL: func @extract_from_0d_splatlike_broadcast_regression( +// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: vector<f32>, %[[C:.*]]: vector<2xf32>) +func.func @extract_from_0d_splatlike_broadcast_regression(%a: f32, %b: vector<f32>, %c: vector<2xf32>) -> (f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) { + // Splat/broadcast scalar to 0D and extract scalar. + %0 = vector.broadcast %a : f32 to vector<f32> %1 = vector.extract %0[] : f32 from vector<f32> // Broadcast scalar to 0D and extract scalar. @@ -3136,12 +3042,12 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>, %3 = vector.extract %2[] : f32 from vector<f32> // Broadcast 0D to 3D and extract scalar. - // CHECK: %[[extract1:.*]] = vector.extract %[[b]][] : f32 from vector<f32> + // CHECK: %[[EXTRACT1:.*]] = vector.extract %[[B]][] : f32 from vector<f32> %4 = vector.broadcast %b : vector<f32> to vector<1x2x4xf32> %5 = vector.extract %4[0, 0, 1] : f32 from vector<1x2x4xf32> - // Splat scalar to 2D and extract scalar. - %6 = vector.splat %a : vector<2x3xf32> + // Splat/broadcast scalar to 2D and extract scalar. + %6 = vector.broadcast %a : f32 to vector<2x3xf32> %7 = vector.extract %6[0, 1] : f32 from vector<2x3xf32> // Broadcast scalar to 3D and extract scalar. @@ -3149,14 +3055,14 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>, %9 = vector.extract %8[2, 1, 5] : f32 from vector<5x6x7xf32> // Extract 2D from 3D that was broadcasted from a scalar. - // CHECK: %[[extract2:.*]] = vector.broadcast %[[a]] : f32 to vector<6x7xf32> + // CHECK: %[[EXTRACT2:.*]] = vector.broadcast %[[A]] : f32 to vector<6x7xf32> %10 = vector.extract %8[2] : vector<6x7xf32> from vector<5x6x7xf32> // Extract 1D from 2D that was splat'ed from a scalar. - // CHECK: %[[extract3:.*]] = vector.broadcast %[[a]] : f32 to vector<3xf32> + // CHECK: %[[EXTRACT3:.*]] = vector.broadcast %[[A]] : f32 to vector<3xf32> %11 = vector.extract %6[1] : vector<3xf32> from vector<2x3xf32> - // CHECK: return %[[a]], %[[a]], %[[extract1]], %[[a]], %[[a]], %[[extract2]], %[[extract3]] + // CHECK: return %[[A]], %[[A]], %[[EXTRACT1]], %[[A]], %[[A]], %[[EXTRACT2]], %[[EXTRACT3]] return %1, %3, %5, %7, %9, %10, %11 : f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32> } @@ -3598,7 +3504,7 @@ func.func @fold_insert_use_chain(%arg : vector<4x4xf32>, %val : f32, %pos: index %v_0 = vector.insert %val, %arg[%pos, 0] : f32 into vector<4x4xf32> %v_1 = vector.insert %val, %v_0[%pos, 0] : f32 into vector<4x4xf32> %v_2 = vector.insert %val, %v_1[%pos, 0] : f32 into vector<4x4xf32> - return %v_2 : vector<4x4xf32> + return %v_2 : vector<4x4xf32> } // ----- @@ -3612,5 +3518,5 @@ func.func @fold_insert_use_chain(%arg : vector<4x4xf32>, %val : f32, %pos: index func.func @no_fold_insert_use_chain_mismatch_static_position(%arg : vector<4xf32>, %val : f32) -> vector<4xf32> { %v_0 = vector.insert %val, %arg[0] : f32 into vector<4xf32> %v_1 = vector.insert %val, %v_0[1] : f32 into vector<4xf32> - return %v_1 : vector<4xf32> + return %v_1 : vector<4xf32> } diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir index fdab2a8..f43328f 100644 --- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir +++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir @@ -36,9 +36,9 @@ func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32 // CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32) func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) { %0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32> - // CHECK: %[[SPLAT1:.*]] = vector.splat %[[A]] : vector<3xf32> + // CHECK: %[[SPLAT1:.*]] = vector.broadcast %[[A]] : f32 to vector<3xf32> %1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32> - // CHECK: %[[SPLAT2:.*]] = vector.splat %[[B]] : vector<3xf32> + // CHECK: %[[SPLAT2:.*]] = vector.broadcast %[[B]] : f32 to vector<3xf32> %2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32> // CHECK: return %[[SPLAT1]], %[[SPLAT2]] return %1, %2 : vector<3xf32>, vector<3xf32> @@ -63,11 +63,11 @@ func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, // CHECK-LABEL: func @from_elements_to_splat( // CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32) func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector<f32>) { - // CHECK: %[[SPLAT:.*]] = vector.splat %[[A]] : vector<2x3xf32> + // CHECK: %[[SPLAT:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3xf32> %0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32> // CHECK: %[[FROM_EL:.*]] = vector.from_elements {{.*}} : vector<2x3xf32> %1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32> - // CHECK: %[[SPLAT2:.*]] = vector.splat %[[A]] : vector<f32> + // CHECK: %[[SPLAT2:.*]] = vector.broadcast %[[A]] : f32 to vector<f32> %2 = vector.from_elements %a : vector<f32> // CHECK: return %[[SPLAT]], %[[FROM_EL]], %[[SPLAT2]] return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32> @@ -170,7 +170,7 @@ func.func @large_source_with_shape_cast_required(%arg0: vector<2x2x2x2xi8>) -> v // Could match, but handled by `rewriteFromElementsAsSplat`. // CHECK-LABEL: func @extract_single_elm( // CHECK-NEXT: vector.extract -// CHECK-NEXT: vector.splat +// CHECK-NEXT: vector.broadcast // CHECK-NEXT: return func.func @extract_single_elm(%arg0 : vector<2x3xi8>) -> vector<1xi8> { %0 = vector.extract %arg0[0, 0] : i8 from vector<2x3xi8> diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir new file mode 100644 index 0000000..e4a9391 --- /dev/null +++ b/mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir @@ -0,0 +1,126 @@ +// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s + +// This file should be removed when vector.splat is removed. +// This file tests canonicalization/folding with vector.splat. +// These tests all have equivalent tests using vector.broadcast in canonicalize.mlir + + +// CHECK-LABEL: fold_extract_splat +// CHECK-SAME: %[[A:.*]]: f32 +// CHECK: return %[[A]] : f32 +func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 { + %b = vector.splat %a : vector<1x2x4xf32> + %r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32> + return %r : f32 +} + +// ----- + +// CHECK-LABEL: extract_strided_splat +// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} f16 to vector<2x4xf16> +// CHECK-NEXT: return %[[B]] : vector<2x4xf16> +func.func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> { + %0 = vector.splat %arg0 : vector<16x4xf16> + %1 = vector.extract_strided_slice %0 + {offsets = [1, 0], sizes = [2, 4], strides = [1, 1]} : + vector<16x4xf16> to vector<2x4xf16> + return %1 : vector<2x4xf16> +} + +// ----- + +// CHECK-LABEL: func @splat_fold +// CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32> +// CHECK-NEXT: return [[V]] : vector<4xf32> +func.func @splat_fold() -> vector<4xf32> { + %c = arith.constant 1.0 : f32 + %v = vector.splat %c : vector<4xf32> + return %v : vector<4xf32> + +} + +// ----- + +// CHECK-LABEL: func @transpose_splat2( +// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> { +// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : f32 to vector<3x4xf32> +// CHECK: return %[[VAL_1]] : vector<3x4xf32> +func.func @transpose_splat2(%arg : f32) -> vector<3x4xf32> { + %splat = vector.splat %arg : vector<4x3xf32> + %0 = vector.transpose %splat, [1, 0] : vector<4x3xf32> to vector<3x4xf32> + return %0 : vector<3x4xf32> +} + +// ----- + +// CHECK-LABEL: @insert_strided_slice_splat +// CHECK-SAME: (%[[ARG:.*]]: f32) +// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : f32 to vector<8x16xf32> +// CHECK-NEXT: return %[[SPLAT]] : vector<8x16xf32> +func.func @insert_strided_slice_splat(%x: f32) -> (vector<8x16xf32>) { + %splat0 = vector.splat %x : vector<4x4xf32> + %splat1 = vector.splat %x : vector<8x16xf32> + %0 = vector.insert_strided_slice %splat0, %splat1 {offsets = [2, 2], strides = [1, 1]} + : vector<4x4xf32> into vector<8x16xf32> + return %0 : vector<8x16xf32> +} + +// ----- + +// CHECK-LABEL: func @shuffle_splat +// CHECK-SAME: (%[[ARG:.*]]: i32) +// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : i32 to vector<4xi32> +// CHECK-NEXT: return %[[SPLAT]] : vector<4xi32> +func.func @shuffle_splat(%x : i32) -> vector<4xi32> { + %v0 = vector.splat %x : vector<4xi32> + %v1 = vector.splat %x : vector<2xi32> + %shuffle = vector.shuffle %v0, %v1 [2, 3, 4, 5] : vector<4xi32>, vector<2xi32> + return %shuffle : vector<4xi32> +} + + +// ----- + +// CHECK-LABEL: func @insert_splat +// CHECK-SAME: (%[[ARG:.*]]: i32) +// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : i32 to vector<2x4x3xi32> +// CHECK-NEXT: return %[[SPLAT]] : vector<2x4x3xi32> +func.func @insert_splat(%x : i32) -> vector<2x4x3xi32> { + %v0 = vector.splat %x : vector<4x3xi32> + %v1 = vector.splat %x : vector<2x4x3xi32> + %insert = vector.insert %v0, %v1[0] : vector<4x3xi32> into vector<2x4x3xi32> + return %insert : vector<2x4x3xi32> +} + +// ----- + +// CHECK-LABEL: func @extract_from_0d_splat_broadcast_regression +// CHECK-SAME: (%[[A:.*]]: f32, %[[C:.*]]: vector<2xf32>) +func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %c: vector<2xf32>) -> (f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) { + // Splat scalar to 0D and extract scalar. + %0 = vector.splat %a : vector<f32> + %1 = vector.extract %0[] : f32 from vector<f32> + + // Broadcast scalar to 0D and extract scalar. + %2 = vector.splat %a : vector<f32> + %3 = vector.extract %2[] : f32 from vector<f32> + + // Splat scalar to 2D and extract scalar. + %6 = vector.splat %a : vector<2x3xf32> + %7 = vector.extract %6[0, 1] : f32 from vector<2x3xf32> + + // Broadcast scalar to 3D and extract scalar. + %8 = vector.splat %a : vector<5x6x7xf32> + %9 = vector.extract %8[2, 1, 5] : f32 from vector<5x6x7xf32> + + // Extract 2D from 3D that was broadcasted from a scalar. + // CHECK: %[[EXTRACT2:.*]] = vector.broadcast %[[A]] : f32 to vector<6x7xf32> + %10 = vector.extract %8[2] : vector<6x7xf32> from vector<5x6x7xf32> + + // Extract 1D from 2D that was splat'ed from a scalar. + // CHECK: %[[EXTRACT3:.*]] = vector.broadcast %[[A]] : f32 to vector<3xf32> + %11 = vector.extract %6[1] : vector<3xf32> from vector<2x3xf32> + + // CHECK: return %[[A]], %[[A]], %[[A]], %[[A]], %[[EXTRACT2]], %[[EXTRACT3]] + return %1, %3, %7, %9, %10, %11 : f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32> +} diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir index 0263193..b2f16bb 100644 --- a/mlir/test/Dialect/Vector/int-range-interface.mlir +++ b/mlir/test/Dialect/Vector/int-range-interface.mlir @@ -51,6 +51,15 @@ func.func @vector_shape_cast() -> vector<4x4xindex> { func.return %2 : vector<4x4xindex> } +// CHECK-LABEL: func @vector_transpose +// CHECK: test.reflect_bounds {smax = 8 : index, smin = 7 : index, umax = 8 : index, umin = 7 : index} +func.func @vector_transpose() -> vector<2x4xindex> { + %0 = test.with_bounds { smax = 8 : index, smin = 7 : index, umax = 8 : index, umin = 7 : index } : vector<4x2xindex> + %1 = vector.transpose %0, [1, 0] : vector<4x2xindex> to vector<2x4xindex> + %2 = test.reflect_bounds %1 : vector<2x4xindex> + func.return %2 : vector<2x4xindex> +} + // CHECK-LABEL: func @vector_extract // CHECK: test.reflect_bounds {smax = 6 : index, smin = 5 : index, umax = 6 : index, umin = 5 : index} func.func @vector_extract() -> index { @@ -60,16 +69,6 @@ func.func @vector_extract() -> index { func.return %2 : index } -// CHECK-LABEL: func @vector_extractelement -// CHECK: test.reflect_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index} -func.func @vector_extractelement() -> index { - %c0 = arith.constant 0 : index - %0 = test.with_bounds { umin = 6 : index, umax = 7 : index, smin = 6 : index, smax = 7 : index } : vector<4xindex> - %1 = vector.extractelement %0[%c0 : index] : vector<4xindex> - %2 = test.reflect_bounds %1 : index - func.return %2 : index -} - // CHECK-LABEL: func @vector_add // CHECK: test.reflect_bounds {smax = 12 : index, smin = 10 : index, umax = 12 : index, umin = 10 : index} func.func @vector_add() -> vector<4xindex> { @@ -90,17 +89,6 @@ func.func @vector_insert() -> vector<4xindex> { func.return %3 : vector<4xindex> } -// CHECK-LABEL: func @vector_insertelement -// CHECK: test.reflect_bounds {smax = 8 : index, smin = 5 : index, umax = 8 : index, umin = 5 : index} -func.func @vector_insertelement() -> vector<4xindex> { - %c0 = arith.constant 0 : index - %0 = test.with_bounds { umin = 5 : index, umax = 7 : index, smin = 5 : index, smax = 7 : index } : vector<4xindex> - %1 = test.with_bounds { umin = 6 : index, umax = 8 : index, smin = 6 : index, smax = 8 : index } : index - %2 = vector.insertelement %1, %0[%c0 : index] : vector<4xindex> - %3 = test.reflect_bounds %2 : vector<4xindex> - func.return %3 : vector<4xindex> -} - // CHECK-LABEL: func @test_loaded_vector_extract // No bounds // CHECK: test.reflect_bounds {smax = 2147483647 : si32, smin = -2147483648 : si32, umax = 4294967295 : ui32, umin = 0 : ui32} %{{.*}} : i32 @@ -120,3 +108,11 @@ func.func @test_vector_extsi() -> vector<2xi32> { %2 = test.reflect_bounds %1 : vector<2xi32> func.return %2 : vector<2xi32> } + +// CHECK-LABEL: func @vector_step +// CHECK: test.reflect_bounds {smax = 7 : index, smin = 0 : index, umax = 7 : index, umin = 0 : index} +func.func @vector_step() -> vector<8xindex> { + %0 = vector.step : vector<8xindex> + %1 = test.reflect_bounds %0 : vector<8xindex> + func.return %1 : vector<8xindex> +} diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index ca837d3..c21de56 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -119,30 +119,6 @@ func.func @shuffle_empty_mask(%arg0: vector<2xf32>, %arg1: vector<2xf32>) { // ----- -func.func @extract_element(%arg0: vector<f32>) { - %c = arith.constant 3 : i32 - // expected-error@+1 {{expected position to be empty with 0-D vector}} - %1 = vector.extractelement %arg0[%c : i32] : vector<f32> -} - -// ----- - -func.func @extract_element(%arg0: vector<4xf32>) { - %c = arith.constant 3 : i32 - // expected-error@+1 {{expected position for 1-D vector}} - %1 = vector.extractelement %arg0[] : vector<4xf32> -} - -// ----- - -func.func @extract_element(%arg0: vector<4x4xf32>) { - %c = arith.constant 3 : i32 - // expected-error@+1 {{unexpected >1 vector rank}} - %1 = vector.extractelement %arg0[%c : i32] : vector<4x4xf32> -} - -// ----- - func.func @extract_vector_type(%arg0: index) { // expected-error@+1 {{invalid kind of type specified: expected builtin.vector, but found 'index'}} %1 = vector.extract %arg0[] : index from index @@ -192,38 +168,6 @@ func.func @extract_position_overflow(%arg0: vector<4x8x16xf32>) { // ----- -func.func @insert_element(%arg0: f32, %arg1: vector<f32>) { - %c = arith.constant 3 : i32 - // expected-error@+1 {{expected position to be empty with 0-D vector}} - %0 = vector.insertelement %arg0, %arg1[%c : i32] : vector<f32> -} - -// ----- - -func.func @insert_element(%arg0: f32, %arg1: vector<4xf32>) { - %c = arith.constant 3 : i32 - // expected-error@+1 {{expected position for 1-D vector}} - %0 = vector.insertelement %arg0, %arg1[] : vector<4xf32> -} - -// ----- - -func.func @insert_element(%arg0: f32, %arg1: vector<4x4xf32>) { - %c = arith.constant 3 : i32 - // expected-error@+1 {{unexpected >1 vector rank}} - %0 = vector.insertelement %arg0, %arg1[%c : i32] : vector<4x4xf32> -} - -// ----- - -func.func @insert_element_wrong_type(%arg0: i32, %arg1: vector<4xf32>) { - %c = arith.constant 3 : i32 - // expected-error@+1 {{'vector.insertelement' op failed to verify that source operand type matches element type of result}} - %0 = "vector.insertelement" (%arg0, %arg1, %c) : (i32, vector<4xf32>, i32) -> (vector<4xf32>) -} - -// ----- - func.func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) { // expected-error@+1 {{expected position attribute of rank no greater than dest vector rank}} %1 = vector.insert %a, %b[3, 3, 3, 3, 3, 3] : f32 into vector<4x8x16xf32> diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index 6a56116..625ffc1 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -199,22 +199,6 @@ func.func @shuffle_poison_mask(%a: vector<4xf32>, %b: vector<4xf32>) -> vector<4 return %1 : vector<4xf32> } -// CHECK-LABEL: @extract_element_0d -func.func @extract_element_0d(%a: vector<f32>) -> f32 { - // CHECK-NEXT: vector.extractelement %{{.*}}[] : vector<f32> - %1 = vector.extractelement %a[] : vector<f32> - return %1 : f32 -} - -// CHECK-LABEL: @extract_element -func.func @extract_element(%a: vector<16xf32>) -> f32 { - // CHECK: %[[C15:.*]] = arith.constant 15 : i32 - %c = arith.constant 15 : i32 - // CHECK-NEXT: vector.extractelement %{{.*}}[%[[C15]] : i32] : vector<16xf32> - %1 = vector.extractelement %a[%c : i32] : vector<16xf32> - return %1 : f32 -} - // CHECK-LABEL: @extract_const_idx func.func @extract_const_idx(%arg0: vector<4x8x16xf32>) -> (vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32) { @@ -256,22 +240,6 @@ func.func @extract_poison_idx(%a: vector<4x5xf32>) -> f32 { return %0 : f32 } -// CHECK-LABEL: @insert_element_0d -func.func @insert_element_0d(%a: f32, %b: vector<f32>) -> vector<f32> { - // CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[] : vector<f32> - %1 = vector.insertelement %a, %b[] : vector<f32> - return %1 : vector<f32> -} - -// CHECK-LABEL: @insert_element -func.func @insert_element(%a: f32, %b: vector<16xf32>) -> vector<16xf32> { - // CHECK: %[[C15:.*]] = arith.constant 15 : i32 - %c = arith.constant 15 : i32 - // CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[%[[C15]] : i32] : vector<16xf32> - %1 = vector.insertelement %a, %b[%c : i32] : vector<16xf32> - return %1 : vector<16xf32> -} - // CHECK-LABEL: @insert_const_idx func.func @insert_const_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, %res: vector<4x8x16xf32>) -> vector<4x8x16xf32> { diff --git a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir index 8e167a5..d5e3443 100644 --- a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir @@ -2,7 +2,7 @@ // CHECK-LABEL: func @broadcast_vec1d_from_scalar // CHECK-SAME: %[[A:.*0]]: f32 -// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2xf32> +// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2xf32> // CHECK: return %[[T0]] : vector<2xf32> func.func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> { @@ -12,7 +12,7 @@ func.func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> { // CHECK-LABEL: func @broadcast_vec2d_from_scalar // CHECK-SAME: %[[A:.*0]]: f32 -// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3xf32> +// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3xf32> // CHECK: return %[[T0]] : vector<2x3xf32> func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> { @@ -22,7 +22,7 @@ func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> { // CHECK-LABEL: func @broadcast_vec3d_from_scalar // CHECK-SAME: %[[A:.*0]]: f32 -// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3x4xf32> +// CHECK: %[[T0:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3x4xf32> // CHECK: return %[[T0]] : vector<2x3x4xf32> func.func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> { @@ -87,7 +87,7 @@ func.func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf3 // CHECK-LABEL: func @broadcast_stretch // CHECK-SAME: %[[A:.*0]]: vector<1xf32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<1xf32> -// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<4xf32> +// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<4xf32> // CHECK: return %[[T1]] : vector<4xf32> func.func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> { @@ -113,16 +113,16 @@ func.func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32> // CHECK-SAME: %[[A:.*0]]: vector<4x1xf32> // CHECK: %[[U0:.*]] = ub.poison : vector<4x3xf32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : f32 from vector<4x1xf32> -// CHECK: %[[T2:.*]] = vector.splat %[[T0]] : vector<3xf32> +// CHECK: %[[T2:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32> // CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[U0]] [0] : vector<3xf32> into vector<4x3xf32> // CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : f32 from vector<4x1xf32> -// CHECK: %[[T6:.*]] = vector.splat %[[T4]] : vector<3xf32> +// CHECK: %[[T6:.*]] = vector.broadcast %[[T4]] : f32 to vector<3xf32> // CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<4x3xf32> // CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : f32 from vector<4x1xf32> -// CHECK: %[[T10:.*]] = vector.splat %[[T8]] : vector<3xf32> +// CHECK: %[[T10:.*]] = vector.broadcast %[[T8]] : f32 to vector<3xf32> // CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T7]] [2] : vector<3xf32> into vector<4x3xf32> // CHECK: %[[T12:.*]] = vector.extract %[[A]][3, 0] : f32 from vector<4x1xf32> -// CHECK: %[[T14:.*]] = vector.splat %[[T12]] : vector<3xf32> +// CHECK: %[[T14:.*]] = vector.broadcast %[[T12]] : f32 to vector<3xf32> // CHECK: %[[T15:.*]] = vector.insert %[[T14]], %[[T11]] [3] : vector<3xf32> into vector<4x3xf32> // CHECK: return %[[T15]] : vector<4x3xf32> diff --git a/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir index 059d955..5a8125e 100644 --- a/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir @@ -5,11 +5,11 @@ // CHECK-SAME: %[[B:.*1]]: vector<3xf32> // CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<2xf32> -// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32> +// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32> // CHECK: %[[T2:.*]] = arith.mulf %[[T1]], %[[B]] : vector<3xf32> // CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32> // CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : f32 from vector<2xf32> -// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xf32> +// CHECK: %[[T5:.*]] = vector.broadcast %[[T4]] : f32 to vector<3xf32> // CHECK: %[[T6:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32> // CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<2x3xf32> // CHECK: return %[[T7]] : vector<2x3xf32> @@ -26,12 +26,12 @@ func.func @outerproduct_noacc(%arg0: vector<2xf32>, // CHECK-SAME: %[[C:.*2]]: vector<2x3xf32> // CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<2xf32> -// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32> +// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32> // CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xf32> from vector<2x3xf32> // CHECK: %[[T3:.*]] = vector.fma %[[T1]], %[[B]], %[[T2]] : vector<3xf32> // CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32> // CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : f32 from vector<2xf32> -// CHECK: %[[T6:.*]] = vector.splat %[[T5]] : vector<3xf32> +// CHECK: %[[T6:.*]] = vector.broadcast %[[T5]] : f32 to vector<3xf32> // CHECK: %[[T7:.*]] = vector.extract %[[C]][1] : vector<3xf32> from vector<2x3xf32> // CHECK: %[[T8:.*]] = vector.fma %[[T6]], %[[B]], %[[T7]] : vector<3xf32> // CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : vector<3xf32> into vector<2x3xf32> @@ -49,11 +49,11 @@ func.func @outerproduct_acc(%arg0: vector<2xf32>, // CHECK-SAME: %[[B:.*1]]: vector<3xi32> // CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : i32 from vector<2xi32> -// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32> +// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : i32 to vector<3xi32> // CHECK: %[[T2:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32> // CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32> // CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : i32 from vector<2xi32> -// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xi32> +// CHECK: %[[T5:.*]] = vector.broadcast %[[T4]] : i32 to vector<3xi32> // CHECK: %[[T6:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32> // CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xi32> into vector<2x3xi32> // CHECK: return %[[T7]] : vector<2x3xi32> @@ -69,13 +69,13 @@ func.func @outerproduct_noacc_int(%arg0: vector<2xi32>, // CHECK-SAME: %[[C:.*2]]: vector<2x3xi32> // CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : i32 from vector<2xi32> -// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32> +// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : i32 to vector<3xi32> // CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<3xi32> from vector<2x3xi32> // CHECK: %[[T3:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32> // CHECK: %[[T4:.*]] = arith.addi %[[T3]], %[[T2]] : vector<3xi32> // CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32> // CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : i32 from vector<2xi32> -// CHECK: %[[T7:.*]] = vector.splat %[[T6]] : vector<3xi32> +// CHECK: %[[T7:.*]] = vector.broadcast %[[T6]] : i32 to vector<3xi32> // CHECK: %[[T8:.*]] = vector.extract %[[C]][1] : vector<3xi32> from vector<2x3xi32> // CHECK: %[[T9:.*]] = arith.muli %[[T7]], %[[B]] : vector<3xi32> // CHECK: %[[T10:.*]] = arith.addi %[[T9]], %[[T8]] : vector<3xi32> @@ -91,7 +91,7 @@ func.func @outerproduct_acc_int(%arg0: vector<2xi32>, // CHECK-LABEL: func @axpy_fp( // CHECK-SAME: %[[A:.*0]]: vector<16xf32>, // CHECK-SAME: %[[B:.*1]]: f32) -// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32> +// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : f32 to vector<16xf32> // CHECK: %[[T1:.*]] = arith.mulf %[[A]], %[[T0]] : vector<16xf32> // CHECK: return %[[T1]] : vector<16xf32> func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> { @@ -103,7 +103,7 @@ func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> { // CHECK-SAME: %[[A:.*0]]: vector<16xf32>, // CHECK-SAME: %[[B:.*1]]: f32, // CHECK-SAME: %[[C:.*2]]: vector<16xf32>) -// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32> +// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : f32 to vector<16xf32> // CHECK: %[[T1:.*]] = vector.fma %[[A]], %[[T0]], %[[C]] : vector<16xf32> // CHECK: return %[[T1]] : vector<16xf32> func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>) -> vector<16xf32> { @@ -114,7 +114,7 @@ func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32> // CHECK-LABEL: func @axpy_int( // CHECK-SAME: %[[A:.*0]]: vector<16xi32>, // CHECK-SAME: %[[B:.*1]]: i32) -// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32> +// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : i32 to vector<16xi32> // CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32> // CHECK: return %[[T1]] : vector<16xi32> func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> { @@ -126,7 +126,7 @@ func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> { // CHECK-SAME: %[[A:.*0]]: vector<16xi32>, // CHECK-SAME: %[[B:.*1]]: i32, // CHECK-SAME: %[[C:.*2]]: vector<16xi32>) -// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32> +// CHECK: %[[T0:.*]] = vector.broadcast %[[B]] : i32 to vector<16xi32> // CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32> // CHECK: %[[T2:.*]] = arith.addi %[[T1]], %[[C]] : vector<16xi32> // CHECK: return %[[T2]] : vector<16xi32> diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir index b826cdc..ef881ba 100644 --- a/mlir/test/Dialect/Vector/vector-sink.mlir +++ b/mlir/test/Dialect/Vector/vector-sink.mlir @@ -180,13 +180,14 @@ func.func @negative_not_elementwise() -> vector<2x2xf32> { // ----- -// The source and the result for arith.cmp have different types - not supported - -// CHECK-LABEL: func.func @negative_source_and_result_mismatch -// CHECK: %[[BROADCAST:.+]] = vector.broadcast -// CHECK: %[[RETURN:.+]] = arith.cmpf uno, %[[BROADCAST]], %[[BROADCAST]] -// CHECK: return %[[RETURN]] -func.func @negative_source_and_result_mismatch(%arg0 : f32, %arg1 : vector<1xf32>) -> vector<1xi1> { +// The source and the result for arith.cmp have different types + +// CHECK-LABEL: func.func @source_and_result_mismatch( +// CHECK-SAME: %[[ARG0:.+]]: f32) +// CHECK: %[[COMPARE:.+]] = arith.cmpf uno, %[[ARG0]], %[[ARG0]] +// CHECK: %[[BROADCAST:.+]] = vector.broadcast %[[COMPARE]] : i1 to vector<1xi1> +// CHECK: return %[[BROADCAST]] +func.func @source_and_result_mismatch(%arg0 : f32) -> vector<1xi1> { %0 = vector.broadcast %arg0 : f32 to vector<1xf32> %1 = arith.cmpf uno, %0, %0 : vector<1xf32> return %1 : vector<1xi1> @@ -210,6 +211,130 @@ func.func @negative_op_only_supports_vectors(%arg0 : f32) -> vector<1xf32> { return %1 : vector<1xf32> } +// ----- + +// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const( +// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> { +// CHECK: %[[NEW_CST:.*]] = arith.constant 2 : index +// CHECK: %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[NEW_CST]] : index +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex> +// CHECK: return %[[BCAST]] : vector<1x4xindex> + +func.func @broadcast_scalar_and_splat_const(%arg0: index) -> vector<1x4xindex> { + %0 = vector.broadcast %arg0 : index to vector<1x4xindex> + %cst = arith.constant dense<2> : vector<1x4xindex> + %2 = arith.addi %0, %cst : vector<1x4xindex> + return %2 : vector<1x4xindex> +} + +// ----- + +// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const_const_first( +// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> { +// CHECK: %[[NEW_CST:.*]] = arith.constant 2 : index +// CHECK: %[[SUB:.*]] = arith.subi %[[NEW_CST]], %[[ARG_0]] : index +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[SUB]] : index to vector<1x4xindex> +// CHECK: return %[[BCAST]] : vector<1x4xindex> + +func.func @broadcast_scalar_and_splat_const_const_first(%arg0: index) -> vector<1x4xindex> { + %0 = vector.broadcast %arg0 : index to vector<1x4xindex> + %cst = arith.constant dense<2> : vector<1x4xindex> + %2 = arith.subi %cst, %0 : vector<1x4xindex> + return %2 : vector<1x4xindex> +} + +// ----- + +// CHECK-LABEL: func.func @broadcast_vector_and_splat_const( +// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>) -> vector<3x4xf32> { +// CHECK: %[[NEW_CST:.*]] = arith.constant dense<2.000000e+00> : vector<4xf32> +// CHECK: %[[ADD:.*]] = arith.mulf %[[ARG_0]], %[[NEW_CST]] : vector<4xf32> +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : vector<4xf32> to vector<3x4xf32> +// CHECK: return %[[BCAST]] : vector<3x4xf32> + +func.func @broadcast_vector_and_splat_const(%arg0: vector<4xf32>) -> vector<3x4xf32> { + %0 = vector.broadcast %arg0 : vector<4xf32> to vector<3x4xf32> + %cst = arith.constant dense<2.000000e+00> : vector<3x4xf32> + %2 = arith.mulf %0, %cst : vector<3x4xf32> + return %2 : vector<3x4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @negative_broadcast_with_non_splat_const( +// CHECK-SAME: %[[ARG_0:.*]]: index) -> vector<1x4xindex> { +// CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : index to vector<1x4xindex> +// CHECK-DAG: %[[CST:.*]] = arith.constant dense<{{\[}}[0, 1, 2, 3]]> : vector<1x4xindex> +// CHECK: %[[ADD:.*]] = arith.addi %[[BCAST]], %[[CST]] : vector<1x4xindex> +// CHECK: return %[[ADD]] : vector<1x4xindex> + +func.func @negative_broadcast_with_non_splat_const(%arg0: index) -> vector<1x4xindex> { + %0 = vector.broadcast %arg0 : index to vector<1x4xindex> + %cst = arith.constant dense<[[0, 1, 2, 3]]> : vector<1x4xindex> + %2 = arith.addi %0, %cst : vector<1x4xindex> + return %2 : vector<1x4xindex> +} + +// ----- + +// CHECK-LABEL: func.func @broadcast_scalar_mixed_type( +// CHECK-SAME: %[[ARG_0:.*]]: f16) -> vector<1x4xf32> { +// CHECK: %[[EXTF:.*]] = arith.extf %[[ARG_0]] : f16 to f32 +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTF]] : f32 to vector<1x4xf32> +// CHECK: return %[[BCAST]] : vector<1x4xf32> + +func.func @broadcast_scalar_mixed_type(%arg0: f16) -> vector<1x4xf32> { + %0 = vector.broadcast %arg0 : f16 to vector<1x4xf16> + %1 = arith.extf %0 : vector<1x4xf16> to vector<1x4xf32> + return %1 : vector<1x4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @broadcast_vector_mixed_type( +// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf16>) -> vector<3x4xf32> { +// CHECK: %[[EXTF:.*]] = arith.extf %[[ARG_0]] : vector<4xf16> to vector<4xf32> +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTF]] : vector<4xf32> to vector<3x4xf32> +// CHECK: return %[[BCAST]] : vector<3x4xf32> + +func.func @broadcast_vector_mixed_type(%arg0: vector<4xf16>) -> vector<3x4xf32> { + %0 = vector.broadcast %arg0 : vector<4xf16> to vector<3x4xf16> + %1 = arith.extf %0 : vector<3x4xf16> to vector<3x4xf32> + return %1 : vector<3x4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @broadcast_scalar_and_splat_const_mixed_type( +// CHECK-SAME: %[[ARG_0:.*]]: f32) -> vector<1x4xf32> { +// CHECK: %[[NEW_CST:.*]] = arith.constant 3 : i32 +// CHECK: %[[POW:.*]] = math.fpowi %[[ARG_0]], %[[NEW_CST]] : f32, i32 +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[POW]] : f32 to vector<1x4xf32> +// CHECK: return %[[BCAST]] : vector<1x4xf32> + +func.func @broadcast_scalar_and_splat_const_mixed_type(%arg0: f32) -> vector<1x4xf32> { + %0 = vector.broadcast %arg0 : f32 to vector<1x4xf32> + %cst = arith.constant dense<3> : vector<1x4xi32> + %2 = math.fpowi %0, %cst : vector<1x4xf32>, vector<1x4xi32> + return %2 : vector<1x4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @broadcast_vector_and_splat_const_mixed_type( +// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>) -> vector<3x4xf32> { +// CHECK: %[[NEW_CST:.*]] = arith.constant dense<3> : vector<4xi32> +// CHECK: %[[POW:.*]] = math.fpowi %[[ARG_0]], %[[NEW_CST]] : vector<4xf32>, vector<4xi32> +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[POW]] : vector<4xf32> to vector<3x4xf32> +// CHECK: return %[[BCAST]] : vector<3x4xf32> + +func.func @broadcast_vector_and_splat_const_mixed_type(%arg0: vector<4xf32>) -> vector<3x4xf32> { + %0 = vector.broadcast %arg0 : vector<4xf32> to vector<3x4xf32> + %cst = arith.constant dense<3> : vector<3x4xi32> + %2 = math.fpowi %0, %cst : vector<3x4xf32>, vector<3x4xi32> + return %2 : vector<3x4xf32> +} + //===----------------------------------------------------------------------===// // [Pattern: ReorderCastOpsOnBroadcast] // diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir index 511ab70..1b54d54 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir @@ -284,19 +284,19 @@ func.func @transfer_read_permutations(%mem_0 : memref<?x?xf32>, %mem_1 : memref< %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index -// CHECK: %[[MASK0:.*]] = vector.splat %{{.*}} : vector<14x7xi1> +// CHECK: %[[MASK0:.*]] = vector.broadcast %{{.*}} : i1 to vector<14x7xi1> %mask0 = vector.splat %m : vector<14x7xi1> %0 = vector.transfer_read %mem_1[%c0, %c0, %c0, %c0], %cst, %mask0 {in_bounds = [true, false, true, true], permutation_map = #map0} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32> // CHECK: vector.transfer_read {{.*}} %[[MASK0]] {in_bounds = [false, true, true, true], permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<14x7x8x16xf32> // CHECK: vector.transpose %{{.*}}, [1, 0, 2, 3] : vector<14x7x8x16xf32> to vector<7x14x8x16xf32> -// CHECK: %[[MASK1:.*]] = vector.splat %{{.*}} : vector<16x14xi1> +// CHECK: %[[MASK1:.*]] = vector.broadcast %{{.*}} : i1 to vector<16x14xi1> %mask1 = vector.splat %m : vector<16x14xi1> %1 = vector.transfer_read %mem_1[%c0, %c0, %c0, %c0], %cst, %mask1 {in_bounds = [true, false, true, false], permutation_map = #map1} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32> // CHECK: vector.transfer_read {{.*}} %[[MASK1]] {in_bounds = [false, false, true, true], permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<16x14x7x8xf32> // CHECK: vector.transpose %{{.*}}, [2, 1, 3, 0] : vector<16x14x7x8xf32> to vector<7x14x8x16xf32> -// CHECK: %[[MASK3:.*]] = vector.splat %{{.*}} : vector<14x7xi1> +// CHECK: %[[MASK3:.*]] = vector.broadcast %{{.*}} : i1 to vector<14x7xi1> %mask2 = vector.splat %m : vector<14x7xi1> %2 = vector.transfer_read %mem_1[%c0, %c0, %c0, %c0], %cst, %mask2 {in_bounds = [true, false, true, true], permutation_map = #map2} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32> // CHECK: vector.transfer_read {{.*}} %[[MASK3]] {in_bounds = [false, true, true], permutation_map = #[[$MAP1]]} : memref<?x?x?x?xf32>, vector<14x16x7xf32> @@ -336,7 +336,7 @@ func.func @transfer_write_permutations_tensor_masked( // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index %c0 = arith.constant 0 : index - // CHECK: %[[MASK:.*]] = vector.splat %[[M]] : vector<16x14x7x8xi1> + // CHECK: %[[MASK:.*]] = vector.broadcast %[[M]] : i1 to vector<16x14x7x8xi1> %mask0 = vector.splat %m : vector<16x14x7x8xi1> %res = vector.transfer_write %vec, %dst[%c0, %c0, %c0, %c0], %mask0 {in_bounds = [true, false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>} : vector<7x14x8x16xf32>, tensor<?x?x?x?xf32> // CHECK: %[[NEW_VEC0:.*]] = vector.transpose %{{.*}} [3, 1, 0, 2] : vector<7x14x8x16xf32> to vector<16x14x7x8xf32> diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index 0160bfe..dff3ffa 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -385,6 +385,74 @@ func.func @load_gather_vc_3(%src: ui64) { } // ----- +func.func @prefetch_offset_wi_1(%src: memref<4x4xf32>) { + %offsets = arith.constant dense<[0]> : vector<1xindex> + // expected-error@+1 {{Expecting the source is a 1D memref or pointer}} + xegpu.prefetch %src[%offsets]: memref<4x4xf32>, vector<1xindex> + return +} + +// ----- +func.func @load_gather_offset_sg(%src: memref<?xf16>) { + %offsets = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %mask = arith.constant dense<1>: vector<8xi1> + // expected-error@+1 {{Mask should match value except the chunk size dim}} + %2 = xegpu.load %src[%offsets], %mask + : memref<?xf16>, vector<4xindex>, vector<8xi1> + -> vector<4x2xf16> + return +} + +// ----- +func.func @load_gather_offset_wi(%src: ui64) { + %mask = arith.constant dense<1>: vector<1xi1> + %offsets = arith.constant dense<[0]> : vector<1xindex> + // expected-error@+1 {{value elements must match chunk size}} + %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64, vector<1xindex>, vector<1xi1> -> vector<4xf32> + return +} + +// ----- +func.func @store_scatter_offset_wi_1(%src: memref<?xf16>) { + %val = arith.constant dense<2.9>: vector<4xf16> + %offsets = arith.constant dense<[0]> : vector<1xindex> + %mask = arith.constant dense<1>: vector<1xi1> + // expected-error@+1 {{value elements must match chunk size}} + xegpu.store %val, %src[%offsets], %mask + : vector<4xf16>, memref<?xf16>, vector<1xindex>, vector<1xi1> + return +} + +// ----- +func.func @store_scatter_offset_wi_2(%src: memref<4x4xf16>) { + %val = arith.constant dense<2.9>: vector<4xf16> + %offsets = arith.constant dense<[0]> : vector<1xindex> + %mask = arith.constant dense<1>: vector<1xi1> + // expected-error@+1 {{Expecting the dest is a 1D memref or pointer}} + xegpu.store %val, %src[%offsets], %mask + : vector<4xf16>, memref<4x4xf16>, vector<1xindex>, vector<1xi1> + return +} + +// ----- +func.func @load_gather_offset_wi_2(%src: ui64) { + %mask = arith.constant dense<1>: vector<1xi1> + %offsets = arith.constant dense<[0]> : vector<1xindex> + // expected-error@+1 {{value elements must match chunk size}} + %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64, vector<1xindex>, vector<1xi1> -> vector<4xf16> + return +} + +// ----- +func.func @load_gather_offset_wi_1(%src: memref<4x4xf32>) { + %mask = arith.constant dense<1>: vector<1xi1> + %offsets = arith.constant dense<[0]> : vector<1xindex> + // expected-error@+1 {{Expecting the source is a 1D memref or pointer}} + %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : memref<4x4xf32>, vector<1xindex>, vector<1xi1> -> vector<2xf32> + return +} + +// ----- func.func @store_scatter_vc_1(%src: memref<24x32xf32>) { %0 = arith.constant dense<1>: vector<4xi1> %1 = arith.constant dense<2.9>: vector<4x2xf32> diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir index 3ebb1b969a..6be2371 100644 --- a/mlir/test/Dialect/XeGPU/ops.mlir +++ b/mlir/test/Dialect/XeGPU/ops.mlir @@ -521,6 +521,16 @@ gpu.func @subgroup_load_4(%src: ui64) { gpu.return } +// CHECK: gpu.func @subgroup_load_offset_1(%arg0: memref<?xf16>) { +gpu.func @subgroup_load_offset_1(%src: memref<?xf16>) { + %offset = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %mask = arith.constant dense<1>: vector<4xi1> + //CHECK: %[[R1:.*]] = xegpu.load %arg0[%cst], %cst_0 <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : memref<?xf16>, vector<4xindex>, vector<4xi1> -> vector<4x2xf16> + %val = xegpu.load %src[%offset], %mask <{chunk_size=2, l1_hint = #xegpu.cache_hint<cached>}> + : memref<?xf16>, vector<4xindex>, vector<4xi1> -> vector<4x2xf16> + gpu.return +} + // CHECK: gpu.func @subgroup_store(%[[arg0:.*]]: ui64) { gpu.func @subgroup_store(%src: ui64) { //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> @@ -626,6 +636,17 @@ gpu.func @subgroup_store_4(%src: ui64) { gpu.return } +// CHECK: gpu.func @subgroup_store_offset_1(%arg0: memref<?xf16>) { +gpu.func @subgroup_store_offset_1(%dest: memref<?xf16>) { + %val = arith.constant dense<2.9>: vector<4x2xf16> + %offset = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %mask = arith.constant dense<1>: vector<4xi1> + //CHECK: xegpu.store %[[R0:.*]], %arg0[%cst_0], %cst_1 <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<4x2xf16>, memref<?xf16>, vector<4xindex>, vector<4xi1> + xegpu.store %val, %dest[%offset], %mask <{chunk_size=2, l1_hint = #xegpu.cache_hint<cached>}> + : vector<4x2xf16>, memref<?xf16>, vector<4xindex>, vector<4xi1> + gpu.return +} + // CHECK: gpu.func @prefetch(%[[arg0:.*]]: ui64) { gpu.func @prefetch(%src: ui64) { //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> @@ -637,6 +658,14 @@ gpu.func @prefetch(%src: ui64) { gpu.return } +// CHECK: gpu.func @prefetch_offset(%[[arg0:.*]]: ui64) { +gpu.func @prefetch_offset(%src: ui64) { + //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> + // CHECK: xegpu.prefetch %[[arg0]][%cst] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : ui64, vector<4xindex> + xegpu.prefetch %src[%0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: ui64, vector<4xindex> + gpu.return +} // CHECK: gpu.func @create_update_tdesc(%[[arg0:.*]]: ui64) { gpu.func @create_update_tdesc(%src: ui64) { diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir index d67bdb4..628a485 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir @@ -2,122 +2,117 @@ gpu.module @test_round_robin_assignment { // CHECK-LABEL: create_nd_tdesc - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> - gpu.func @create_nd_tdesc(%src: memref<24x32xf32>) { - // CHECK-COUNT-12: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<24x32xf32> - // CHECK-SAME: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> + gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) { + // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf32> + // CHECK-SAME: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK-NOT: xegpu.create_nd_tdesc - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> gpu.return } // CHECK-LABEL: load_nd_tdesc - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> - gpu.func @load_nd_tdesc(%src: memref<24x32xf32>) { - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> - // CHECK-COUNT-12: xegpu.load_nd %{{.*}} - // CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> - // CHECK-SAME-COUNT-12: -> vector<2x2xf32> + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> + gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) { + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK-COUNT-4: xegpu.load_nd %{{.*}} + // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK-SAME-COUNT-4: -> vector<16x16xf32> // CHECK-NOT: xegpu.load_nd %load = xegpu.load_nd %tdesc - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> - -> vector<24x32xf32> + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> + -> vector<256x128xf32> gpu.return } // CHECK-LABEL: store_nd - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> - gpu.func @store_nd(%src: memref<24x32xf32>) { - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> - // CHECK-COUNT-12: xegpu.store_nd %{{.*}}, %{{.*}} - // CHECK-SAME-COUNT-12: : vector<2x2xf32>, !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> + gpu.func @store_nd(%src: memref<256x128xf32>) { + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, %{{.*}} + // CHECK-SAME-COUNT-4: : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK-NOT : xegpu.store_nd %load = xegpu.load_nd %tdesc - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> - -> vector<24x32xf32> + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> + -> vector<256x128xf32> xegpu.store_nd %load, %tdesc - : vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> + : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> gpu.return } // CHECK-LABEL: update_nd - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> - gpu.func @update_nd(%src: memref<24x32xf32>){ - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> - // CHECK-COUNT-12: xegpu.update_nd_offset %{{.*}}, [0, 16] - // CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> + gpu.func @update_nd(%src: memref<256x128xf32>){ + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK-COUNT-4: xegpu.update_nd_offset %{{.*}}, [0, 16] + // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>> // CHECK-NOT: xegpu.update_nd_offset %update = xegpu.update_nd_offset %tdesc, [0, 16] - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> gpu.return } // CHECK-LABEL: dpas - // CHECK-SAME: (%[[ARG_0:.*]]: memref<8x8xf32>, %[[ARG_1:.*]]: memref<8x8xf32>, %[[ARG_2:.*]]: memref<8x8xf32>) - gpu.func @dpas(%a: memref<8x8xf32>, %b: memref<8x8xf32>, %c: memref<8x8xf32>) { - // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<8x8xf32> - // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> + // CHECK-SAME: (%[[ARG_0:.*]]: memref<256x128xf16>, %[[ARG_1:.*]]: memref<128x256xf16>) + gpu.func @dpas(%a: memref<256x128xf16>, %b: memref<128x256xf16>) { + // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf16> + // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK-NOT: xegpu.create_nd_tdesc - // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<8x8xf32> - // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> - // CHECK-NOT: xegpu.create_nd_tdesc - // CHECK-COUNT-4: xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}] : memref<8x8xf32> - // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> + // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<128x256xf16> + // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [4, 8], lane_data = [1, 1]>> // CHECK-NOT: xegpu.create_nd_tdesc // CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}} - // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} - // CHECK-SAME-COUNT-16: : vector<2x2xf32>, vector<2x2xf32> -> vector<2x2xf32> + // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} + // CHECK-SAME-COUNT-16: : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32> // CHECK-NOT: xegpu.dpas - %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<8x8xf32> - -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> + %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<256x128xf16> + -> !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> %load_a = xegpu.load_nd %tdesc_a - : !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> - -> vector<8x8xf32> - %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<8x8xf32> - -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> + : !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> + -> vector<256x128xf16> + %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<128x256xf16> + -> !xegpu.tensor_desc<128x256xf16, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>> %load_b = xegpu.load_nd %tdesc_b - : !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> - -> vector<8x8xf32> - %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<8x8xf32> - -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> + : !xegpu.tensor_desc<128x256xf16, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>> + -> vector<128x256xf16> %dpas = xegpu.dpas %load_a, %load_b - {layout_result_0 = #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>} - : vector<8x8xf32>, vector<8x8xf32> -> vector<8x8xf32> + {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>} + : vector<256x128xf16>, vector<128x256xf16> -> vector<256x256xf32> gpu.return } // CHECK-LABEL: prefetch_nd_tdesc - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> - gpu.func @prefetch_nd_tdesc(%src: memref<24x32xf32>) { - // CHECK-COUNT-12: xegpu.prefetch_nd %{{.*}} - // CHECK-SAME-COUNT-12 : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> + gpu.func @prefetch_nd_tdesc(%src: memref<256x128xf32>) { + // CHECK-COUNT-4: xegpu.prefetch_nd %{{.*}} + // CHECK-SAME-COUNT-4: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK-NOT: xegpu.prefetch_nd - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> xegpu.prefetch_nd %tdesc - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>> gpu.return } // CHECK-LABEL: broadcast - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32> - gpu.func @broadcast(%src: memref<24x1xf32>) { - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32> - -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<128x1xf32> + gpu.func @broadcast(%src: memref<128x1xf32>) { + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<128x1xf32> + -> !xegpu.tensor_desc<128x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 1], lane_layout = [8, 1], lane_data = [1, 1]>> %load = xegpu.load_nd %tdesc - : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>> - -> vector<24x1xf32> - // CHECK-COUNT-3: vector.broadcast {{.*}} - // CHECK-SAME-COUNT-3: {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>} - // CHECK-SAME-COUNT-3: : vector<2x1xf32> to vector<2x4xf32> + : !xegpu.tensor_desc<128x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 1], lane_layout = [8, 1], lane_data = [1, 1]>> + -> vector<128x1xf32> + // CHECK-COUNT-2: vector.broadcast {{.*}} + // CHECK-SAME-COUNT-2: {layout_result_0 = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>} + // CHECK-SAME-COUNT-2: : vector<16x1xf32> to vector<16x32xf32> // CHECK-NOT: vector.broadcast %broadcast = vector.broadcast %load - {layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 4], lane_layout = [2, 1], lane_data = [1, 1]>} - : vector<24x1xf32> to vector<24x8xf32> + {layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 32], lane_layout = [8, 1], lane_data = [1, 1]>} + : vector<128x1xf32> to vector<128x64xf32> gpu.return } diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir index d511224..d4b0037 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir @@ -4,201 +4,181 @@ //CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)> gpu.module @test_1_1_assignment { // CHECK-LABEL: create_nd_tdesc - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> - gpu.func @create_nd_tdesc(%src: memref<24x32xf32>) { + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> + gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) { // CHECK: %[[SGID:.*]] = gpu.subgroup_id - // CHECK: %[[C12:.*]] = arith.constant 12 : index - // CHECK: %[[C4:.*]] = arith.constant 4 : index // CHECK: %[[C8:.*]] = arith.constant 8 : index + // CHECK: %[[C32:.*]] = arith.constant 32 : index + // CHECK: %[[C4:.*]] = arith.constant 4 : index + // CHECK: %[[C32_0:.*]] = arith.constant 32 : index + // CHECK: %[[C4_1:.*]] = arith.constant 4 : index // CHECK: %[[DIV:.*]] = affine.apply #map()[%[[SGID]]] // CHECK: %[[REM:.*]] = affine.apply #map1()[%[[SGID]]] - // CHECK: %[[MUL1:.*]] = index.mul %[[DIV]], %[[C12]] - // CHECK: %[[MUL2:.*]] = index.mul %[[REM]], %[[C8]] - // CHECK: %[[C24:.*]] = arith.constant 24 : index - // CHECK: %[[MOD:.*]] = index.remu %[[MUL1]], %[[C24]] + // CHECK: %[[MUL1:.*]] = index.mul %[[DIV]], %[[C32]] + // CHECK: %[[MUL2:.*]] = index.mul %[[REM]], %[[C32_0]] // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[ADD1:.*]] = index.add %[[MOD]], %[[C0]] - // CHECK: %[[C32:.*]] = arith.constant 32 : index - // CHECK: %[[MOD1:.*]] = index.remu %[[MUL2]], %[[C32]] - // CHECK: %[[C0_1:.*]] = arith.constant 0 : index - // CHECK: %[[ADD2:.*]] = index.add %[[MOD1]], %[[C0_1]] - // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<24x32xf32> - // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> + // CHECK: %[[C256:.*]] = arith.constant 256 : index + // CHECK: %[[MOD:.*]] = index.remu %[[MUL1]], %[[C256]] + // CHECK: %[[C0_2:.*]] = arith.constant 0 : index + // CHECK: %[[ADD1:.*]] = index.add %[[MOD]], %[[C0_2]] + // CHECK: %[[C0_3:.*]] = arith.constant 0 : index + // CHECK: %[[C128:.*]] = arith.constant 128 : index + // CHECK: %[[MOD1:.*]] = index.remu %[[MUL2]], %[[C128]] + // CHECK: %[[C0_4:.*]] = arith.constant 0 : index + // CHECK: %[[ADD2:.*]] = index.add %[[MOD1]], %[[C0_4]] + // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<256x128xf32> + // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK: gpu.return - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> gpu.return } // CHECK-LABEL: load_nd_tdesc - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> - gpu.func @load_nd_tdesc(%src: memref<24x32xf32>) { - // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> - // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> + gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) { + // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32> + // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] - // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> - // CHECK-SAME: -> vector<12x8xf32> - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> + // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK-SAME: -> vector<32x32xf32> + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> %load = xegpu.load_nd %tdesc - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> - -> vector<24x32xf32> + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> + -> vector<256x128xf32> gpu.return } // CHECK-LABEL: store_nd - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> - gpu.func @store_nd(%src: memref<24x32xf32>) { - // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> - // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> + gpu.func @store_nd(%src: memref<256x128xf32>) { + // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32> + // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] - // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> - // CHECK-SAME: -> vector<12x8xf32> + // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + // CHECK-SAME: -> vector<32x32xf32> // CHECK: xegpu.store_nd %[[LOAD]], %[[TDESC]] - // CHECK-SAME: : vector<12x8xf32>, !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> + // CHECK-SAME: : vector<32x32xf32>, !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> %load = xegpu.load_nd %tdesc - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> - -> vector<24x32xf32> + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> + -> vector<256x128xf32> xegpu.store_nd %load, %tdesc - : vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> + : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> gpu.return } // CHECK-LABEL: update_nd -// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> -gpu.func @update_nd(%src: memref<24x32xf32>){ - // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> - // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> +// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> +gpu.func @update_nd(%src: memref<256x128xf32>){ + // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32> + // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK: %[[UPDATE:.*]] = xegpu.update_nd_offset %[[TDESC]], [0, 16] - // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> + // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> %update = xegpu.update_nd_offset %tdesc, [0, 16] - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> gpu.return } // CHECK-LABEL: dpas -// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> -// CHECK-SAME: %[[ARG_1:.*]]: memref<32x24xf32> -gpu.func @dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) { - // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> - // CHECk-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> - // CHECK: %[[LOAD_A:.*]] = xegpu.load_nd %[[TDESC_A]] - // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> - // CHECK-SAME: -> vector<12x8xf32> - // CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][{{%.*}}, {{%.*}}] : memref<32x24xf32> - // CHECK-SAME: -> !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>> - // CHECK: %[[LOAD_B:.*]] = xegpu.load_nd %[[TDESC_B]] - // CHECK-SAME: : !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>> - // CHECK-SAME: -> vector<8x12xf32> - // CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]] - // CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} - // CHECK-SAME: : vector<12x8xf32>, vector<8x12xf32> -> vector<12x12xf32> - %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> +gpu.func @dpas(%a: memref<128x128xf16>, %b: memref<128x128xf16>) { + // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x128xf16>, vector<128x16xf16> -> vector<16x16xf32> + %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<128x128xf16> + -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>> %load_a = xegpu.load_nd %tdesc_a - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> - -> vector<24x32xf32> - %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32> - -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>> + : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>> + -> vector<128x128xf16> + %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<128x128xf16> + -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>> %load_b = xegpu.load_nd %tdesc_b - : !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>> - -> vector<32x24xf32> + : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>> + -> vector<128x128xf16> %dpas = xegpu.dpas %load_a, %load_b - {layout_result_0 = #xegpu.layout<sg_layout = [2, 2], sg_data = [12, 12], lane_layout = [2, 2], lane_data = [1, 1]>} - : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32> + {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>} + : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32> gpu.return } // CHECK-LABEL: dpas_no_sg_data -// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> -// CHECK-SAME: %[[ARG_1:.*]]: memref<32x24xf32> -gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) { - // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> - // CHECk-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> - // CHECK: %[[LOAD_A:.*]] = xegpu.load_nd %[[TDESC_A]] - // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> - // CHECK-SAME: -> vector<12x8xf32> - // CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][{{%.*}}, {{%.*}}] : memref<32x24xf32> - // CHECK-SAME: -> !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>> - // CHECK: %[[LOAD_B:.*]] = xegpu.load_nd %[[TDESC_B]] - // CHECK-SAME: : !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>> - // CHECK-SAME: -> vector<8x12xf32> - // CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]] - // CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} - // CHECK-SAME: : vector<12x8xf32>, vector<8x12xf32> -> vector<12x12xf32> - %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], lane_layout = [2, 8], lane_data = [1, 1]>> +gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) { + // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32> + %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<128x128xf16> + -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1], + order = [1, 0]>> %load_a = xegpu.load_nd %tdesc_a - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], lane_layout = [2, 8], lane_data = [1, 1]>> - -> vector<24x32xf32> - %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32> - -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], lane_layout = [8, 2], lane_data = [1, 1]>> + : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1], + order = [1, 0]>> + -> vector<128x128xf16> + %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<128x128xf16> + -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1], + order = [1, 0]>> %load_b = xegpu.load_nd %tdesc_b - : !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], lane_layout = [8, 2], lane_data = [1, 1]>> - -> vector<32x24xf32> + : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1], + order = [1, 0]>> + -> vector<128x128xf16> %dpas = xegpu.dpas %load_a, %load_b - {layout_result_0 = #xegpu.layout<sg_layout = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>} - : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32> + {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} + : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32> gpu.return } // CHECK-LABEL: prefetch_nd_tdesc - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32> - gpu.func @prefetch_nd_tdesc(%src: memref<24x32xf32>) { - // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> - // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> + gpu.func @prefetch_nd_tdesc(%src: memref<256x128xf32>) { + // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32> + // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> // CHECK: xegpu.prefetch_nd %[[TDESC]] - // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> - -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> + // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> xegpu.prefetch_nd %tdesc - : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>> gpu.return } // CHECK-LABEL: dpas_with_no_create_nd_desc - gpu.func @dpas_with_no_create_nd_desc(%a: vector<24x32xf32>, %b: vector<32x24xf32>) { - // CHECK-NOT: vector<12x12xf32> + gpu.func @dpas_with_no_create_nd_desc(%a: vector<256x128xf32>, %b: vector<128x256xf32>) { + // CHECK-NOT: vector<32x32xf32> %dpas = xegpu.dpas %a, %b {layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [12, 12], lane_layout = [2, 2], lane_data = [1, 1]>} - : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32> + : vector<256x128xf32>, vector<128x256xf32> -> vector<256x256xf32> gpu.return } // CHECK-LABEL: broadcast_dim1 - // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32> - gpu.func @broadcast_dim1(%src: memref<24x1xf32>) { - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32> - -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<256x1xf32> + gpu.func @broadcast_dim1(%src: memref<256x1xf32>) { + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x1xf32> + -> !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 1], lane_layout = [8, 1], lane_data = [1, 1]>> %load = xegpu.load_nd %tdesc - : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>> - -> vector<24x1xf32> - // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>} - // CHECK-SAME: : vector<12x1xf32> to vector<12x8xf32> - %broadcast = vector.broadcast %load - {layout_result_0 = #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 8], lane_layout = [2, 1], lane_data = [1, 1]>} - : vector<24x1xf32> to vector<24x8xf32> + : !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 1], lane_layout = [8, 1], lane_data = [1, 1]>> + -> vector<256x1xf32> + // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>} + // CHECK-SAME: : vector<32x1xf32> to vector<32x32xf32> + %broadcast = vector.broadcast %load + {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 32], lane_layout = [8, 1], lane_data = [1, 1]>} + : vector<256x1xf32> to vector<256x32xf32> gpu.return } // CHECK-LABEL: broadcast_dim0 - // CHECK-SAME: %[[ARG_0:.*]]: memref<1x32xf32> - gpu.func @broadcast_dim0(%src: memref<1x32xf32>) { - %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<1x32xf32> - -> !xegpu.tensor_desc<1x32xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 8], lane_layout = [1, 8], lane_data = [1, 1]>> + // CHECK-SAME: %[[ARG_0:.*]]: memref<1x128xf32> + gpu.func @broadcast_dim0(%src: memref<1x128xf32>) { + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<1x128xf32> + -> !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 32], lane_layout = [1, 16], lane_data = [1, 1]>> %load = xegpu.load_nd %tdesc - : !xegpu.tensor_desc<1x32xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 8], lane_layout = [1, 8], lane_data = [1, 1]>> - -> vector<1x32xf32> - // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 8], lane_data = [1, 1]>} - // CHECK-SAME: : vector<1x8xf32> to vector<12x8xf32> + : !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 32], lane_layout = [1, 16], lane_data = [1, 1]>> + -> vector<1x128xf32> + // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} + // CHECK-SAME: : vector<1x32xf32> to vector<32x32xf32> %broadcast = vector.broadcast %load - {layout_result_0 = #xegpu.layout<sg_layout = [1, 4], sg_data = [12, 8], lane_layout = [1, 8], lane_data = [1, 1]>} - : vector<1x32xf32> to vector<12x32xf32> + {layout_result_0 = #xegpu.layout<sg_layout = [1, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>} + : vector<1x128xf32> to vector<32x128xf32> gpu.return } diff --git a/mlir/test/IR/diagnostic-nosplit.mlir b/mlir/test/IR/diagnostic-nosplit.mlir new file mode 100644 index 0000000..ecfb9c6 --- /dev/null +++ b/mlir/test/IR/diagnostic-nosplit.mlir @@ -0,0 +1,13 @@ +// RUN: not mlir-opt %s -o - --split-input-file 2>&1 | FileCheck %s +// This test verifies that diagnostic handler doesn't emit splits. + + +// ----- + + + +func.func @constant_out_of_range() { + // CHECK: mlir:11:8: error: 'arith.constant' + %x = "arith.constant"() {value = 100} : () -> i1 + return +} diff --git a/mlir/test/IR/top-level.mlir b/mlir/test/IR/top-level.mlir index b571d94..5389691 100644 --- a/mlir/test/IR/top-level.mlir +++ b/mlir/test/IR/top-level.mlir @@ -6,10 +6,10 @@ func.func private @foo() // ----- -// expected-error@-3 {{source must contain a single top-level operation, found: 2}} +// expected-error@-2 {{source must contain a single top-level operation, found: 2}} func.func private @bar() func.func private @baz() // ----- -// expected-error@-3 {{source must contain a single top-level operation, found: 0}} +// expected-error@-2 {{source must contain a single top-level operation, found: 0}} diff --git a/mlir/test/Integration/Dialect/Vector/CPU/0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/0-d-vectors.mlir index 6e2a82b..6ec1031 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/0-d-vectors.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/0-d-vectors.mlir @@ -4,14 +4,14 @@ // RUN: FileCheck %s func.func @extract_element_0d(%a: vector<f32>) { - %1 = vector.extractelement %a[] : vector<f32> + %1 = vector.extract %a[] : f32 from vector<f32> // CHECK: 42 vector.print %1: f32 return } func.func @insert_element_0d(%a: f32, %b: vector<f32>) -> (vector<f32>) { - %1 = vector.insertelement %a, %b[] : vector<f32> + %1 = vector.insert %a, %b[] : f32 into vector<f32> return %1: vector<f32> } @@ -58,9 +58,9 @@ func.func @broadcast_0d(%a: f32) { func.func @bitcast_0d() { %0 = arith.constant 42 : i32 %1 = arith.constant dense<0> : vector<i32> - %2 = vector.insertelement %0, %1[] : vector<i32> + %2 = vector.insert %0, %1[] : i32 into vector<i32> %3 = vector.bitcast %2 : vector<i32> to vector<f32> - %4 = vector.extractelement %3[] : vector<f32> + %4 = vector.extract %3[] : f32 from vector<f32> %5 = arith.bitcast %4 : f32 to i32 // CHECK: 42 vector.print %5: i32 diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir index b69a200..eb99886 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir @@ -72,7 +72,7 @@ func.func @za0_d_f64() -> i32 { %row = vector.load %mem2[%vnum, %c0] : memref<?x?xf64>, vector<[2]xf64> %inner_add_reduce = scf.for %offset = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %init_0_f64) -> (f64) { - %t = vector.extractelement %row[%offset : index] : vector<[2]xf64> + %t = vector.extract %row[%offset] : f64 from vector<[2]xf64> %inner_add_reduce_next = arith.addf %inner_iter, %t : f64 scf.yield %inner_add_reduce_next : f64 } @@ -102,7 +102,7 @@ func.func @za0_d_f64() -> i32 { %cmp = arith.cmpf one, %row_1, %row_2 : vector<[2]xf64> %inner_mul_reduce = scf.for %i = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %init_1) -> (i64) { - %t = vector.extractelement %cmp[%i : index] : vector<[2]xi1> + %t = vector.extract %cmp[%i] : i1 from vector<[2]xi1> %t_i64 = arith.extui %t : i1 to i64 %inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64 scf.yield %inner_mul_reduce_next : i64 @@ -125,7 +125,7 @@ func.func @za0_d_f64() -> i32 { %cmp = arith.cmpf oeq, %row_1, %row_2 : vector<[2]xf64> %inner_mul_reduce = scf.for %i = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %init_1) -> (i64) { - %t = vector.extractelement %cmp[%i : index] : vector<[2]xi1> + %t = vector.extract %cmp[%i] : i1 from vector<[2]xi1> %t_i64 = arith.extui %t : i1 to i64 %inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64 scf.yield %inner_mul_reduce_next : i64 diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir index 697fb90..ad8e321 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir @@ -36,7 +36,7 @@ func.func @entry() -> i32 { %row = vector.load %za_b[%vnum, %c0] : memref<?x?xi8>, vector<[16]xi8> %inner_mul_reduce = scf.for %offset = %c0 to %svl_b step %c1_index iter_args(%inner_iter = %init_1) -> (i64) { - %t = vector.extractelement %row[%offset : index] : vector<[16]xi8> + %t = vector.extract %row[%offset] : i8 from vector<[16]xi8> %t_i64 = arith.extui %t : i8 to i64 %inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64 scf.yield %inner_mul_reduce_next : i64 @@ -64,7 +64,7 @@ func.func @entry() -> i32 { %row = vector.load %za_b[%vnum, %c0] : memref<?x?xi8>, vector<[16]xi8> %inner_mul_reduce = scf.for %offset = %c0 to %svl_b step %c1_index iter_args(%inner_iter = %init_1) -> (i64) { - %t = vector.extractelement %row[%offset : index] : vector<[16]xi8> + %t = vector.extract %row[%offset] : i8 from vector<[16]xi8> %t_i64 = arith.extui %t : i8 to i64 %inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64 scf.yield %inner_mul_reduce_next : i64 diff --git a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/dot.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/dot.mlir index 53a7282..aff272c2 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/dot.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/dot.mlir @@ -11,8 +11,8 @@ func.func @entry() -> i32 { %b = arith.constant dense<[9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : vector<8xf32> %r = x86vector.avx.intr.dot %a, %b : vector<8xf32> - %1 = vector.extractelement %r[%i0 : i32]: vector<8xf32> - %2 = vector.extractelement %r[%i4 : i32]: vector<8xf32> + %1 = vector.extract %r[%i0] : f32 from vector<8xf32> + %2 = vector.extract %r[%i4] : f32 from vector<8xf32> %d = arith.addf %1, %2 : f32 // CHECK: ( 110, 110, 110, 110, 382, 382, 382, 382 ) diff --git a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/sparse-dot-product.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/sparse-dot-product.mlir index bf1caaa..1c56990 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/sparse-dot-product.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/sparse-dot-product.mlir @@ -196,13 +196,13 @@ func.func @memref_dot_optimized(%m_A : memref<?xi64>, %m_B : memref<?xf64>, iter_args(%sum0 = %data_zero, %b_start0 = %c0) -> (f64, index) { %v_A = vector.transfer_read %m_A[%a], %index_padding : memref<?xi64>, vector<8xi64> - %segA_min = vector.extractelement %v_A[%i0 : i32] : vector<8xi64> + %segA_min = vector.extract %v_A[%i0] : i64 from vector<8xi64> %r1, %next_b_start0 = scf.for %b = %b_start0 to %N step %c8 iter_args(%sum1 = %sum0, %b_start1 = %b_start0) -> (f64, index) { %v_C = vector.transfer_read %m_C[%b], %index_padding : memref<?xi64>, vector<8xi64> - %segB_max = vector.extractelement %v_C[%i7 : i32] : vector<8xi64> + %segB_max = vector.extract %v_C[%i7] : i64 from vector<8xi64> %seg1_done = arith.cmpi "slt", %segB_max, %segA_min : i64 %r2, %next_b_start1 = scf.if %seg1_done -> (f64, index) { @@ -273,10 +273,10 @@ func.func @memref_dot_while(%m_A : memref<?xi64>, %m_B : memref<?xf64>, %v_C = vector.transfer_read %m_C[%b1], %index_padding : memref<?xi64>, vector<8xi64> - %segA_min = vector.extractelement %v_A[%i0 : i32] : vector<8xi64> - %segA_max = vector.extractelement %v_A[%i7 : i32] : vector<8xi64> - %segB_min = vector.extractelement %v_C[%i0 : i32] : vector<8xi64> - %segB_max = vector.extractelement %v_C[%i7 : i32] : vector<8xi64> + %segA_min = vector.extract %v_A[%i0] : i64 from vector<8xi64> + %segA_max = vector.extract %v_A[%i7] : i64 from vector<8xi64> + %segB_min = vector.extract %v_C[%i0] : i64 from vector<8xi64> + %segB_max = vector.extract %v_C[%i7] : i64 from vector<8xi64> %seg1_done = arith.cmpi "slt", %segB_max, %segA_min : i64 %r2, %a2, %b2 = scf.if %seg1_done -> (f64, index, index) { @@ -370,8 +370,8 @@ func.func @memref_dot_while_branchless(%m_A : memref<?xi64>, %m_B : memref<?xf64 -> f64 %r2 = arith.addf %r1, %subresult : f64 - %segA_max = vector.extractelement %v_A[%i7 : i32] : vector<8xi64> - %segB_max = vector.extractelement %v_C[%i7 : i32] : vector<8xi64> + %segA_max = vector.extract %v_A[%i7] : i64 from vector<8xi64> + %segB_max = vector.extract %v_C[%i7] : i64 from vector<8xi64> %cond_a = arith.cmpi "sle", %segA_max, %segB_max : i64 %cond_a_i64 = arith.extui %cond_a : i1 to i64 diff --git a/mlir/test/Integration/Dialect/Vector/CPU/compress.mlir b/mlir/test/Integration/Dialect/Vector/CPU/compress.mlir index e9a66cc..1683fa5 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/compress.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/compress.mlir @@ -28,8 +28,7 @@ func.func @printmem16(%A: memref<?xf32>) { %mem = scf.for %i = %c0 to %c16 step %c1 iter_args(%m_iter = %m) -> (vector<16xf32>) { %c = memref.load %A[%i] : memref<?xf32> - %i32 = arith.index_cast %i : index to i32 - %m_new = vector.insertelement %c, %m_iter[%i32 : i32] : vector<16xf32> + %m_new = vector.insert %c, %m_iter[%i] : f32 into vector<16xf32> scf.yield %m_new : vector<16xf32> } vector.print %mem : vector<16xf32> @@ -49,7 +48,7 @@ func.func @entry() { memref.store %z, %A[%i] : memref<?xf32> %i32 = arith.index_cast %i : index to i32 %fi = arith.sitofp %i32 : i32 to f32 - %v_new = vector.insertelement %fi, %v_iter[%i32 : i32] : vector<16xf32> + %v_new = vector.insert %fi, %v_iter[%i] : f32 into vector<16xf32> scf.yield %v_new : vector<16xf32> } diff --git a/mlir/test/Integration/Dialect/Vector/CPU/maskedstore.mlir b/mlir/test/Integration/Dialect/Vector/CPU/maskedstore.mlir index 2dc00df..826da53 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/maskedstore.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/maskedstore.mlir @@ -28,8 +28,7 @@ func.func @printmem16(%A: memref<?xf32>) { %mem = scf.for %i = %c0 to %c16 step %c1 iter_args(%m_iter = %m) -> (vector<16xf32>) { %c = memref.load %A[%i] : memref<?xf32> - %i32 = arith.index_cast %i : index to i32 - %m_new = vector.insertelement %c, %m_iter[%i32 : i32] : vector<16xf32> + %m_new = vector.insert %c, %m_iter[%i] : f32 into vector<16xf32> scf.yield %m_new : vector<16xf32> } vector.print %mem : vector<16xf32> @@ -53,7 +52,7 @@ func.func @entry() { iter_args(%v_iter = %v) -> (vector<16xf32>) { %i32 = arith.index_cast %i : index to i32 %fi = arith.sitofp %i32 : i32 to f32 - %v_new = vector.insertelement %fi, %v_iter[%i32 : i32] : vector<16xf32> + %v_new = vector.insert %fi, %v_iter[%i] : f32 into vector<16xf32> scf.yield %v_new : vector<16xf32> } diff --git a/mlir/test/Integration/Dialect/Vector/CPU/scatter.mlir b/mlir/test/Integration/Dialect/Vector/CPU/scatter.mlir index 54b6e69..22b5eef 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/scatter.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/scatter.mlir @@ -21,8 +21,7 @@ func.func @printmem8(%A: memref<?xf32>) { %mem = scf.for %i = %c0 to %c8 step %c1 iter_args(%m_iter = %m) -> (vector<8xf32>) { %c = memref.load %A[%i] : memref<?xf32> - %i32 = arith.index_cast %i : index to i32 - %m_new = vector.insertelement %c, %m_iter[%i32 : i32] : vector<8xf32> + %m_new = vector.insert %c, %m_iter[%i] : f32 into vector<8xf32> scf.yield %m_new : vector<8xf32> } vector.print %mem : vector<8xf32> diff --git a/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-1d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-1d.mlir index 2393bd1..639eed4 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-1d.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/transfer-read-1d.mlir @@ -200,7 +200,7 @@ func.func @entry() { // CHECK: ( 2, 12, 22, -1, -1, -42, -42, -42, -42 ) // 6. Read a scalar from a 2D memref and broadcast the value to a 1D vector. - // Generates a loop with vector.insertelement. + // Generates a loop with vector.insert. call @transfer_read_1d_broadcast(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> () // CHECK: ( 12, 12, 12, 12, 12, 12, 12, 12, 12 ) diff --git a/mlir/test/Integration/GPU/Vulkan/vector-interleave.mlir b/mlir/test/Integration/GPU/Vulkan/vector-interleave.mlir index e665653..731bd5a 100644 --- a/mlir/test/Integration/GPU/Vulkan/vector-interleave.mlir +++ b/mlir/test/Integration/GPU/Vulkan/vector-interleave.mlir @@ -26,17 +26,17 @@ module attributes { %val2 = memref.load %arg1[%idx0] : memref<2xi32> %val3 = memref.load %arg1[%idx1] : memref<2xi32> - %lhs0 = vector.insertelement %val0, %lhs[%idx0 : index] : vector<2xi32> - %lhs1 = vector.insertelement %val1, %lhs0[%idx1 : index] : vector<2xi32> - %rhs0 = vector.insertelement %val2, %rhs[%idx0 : index] : vector<2xi32> - %rhs1 = vector.insertelement %val3, %rhs0[%idx1 : index] : vector<2xi32> + %lhs0 = vector.insert %val0, %lhs[%idx0] : i32 into vector<2xi32> + %lhs1 = vector.insert %val1, %lhs0[%idx1] : i32 into vector<2xi32> + %rhs0 = vector.insert %val2, %rhs[%idx0] : i32 into vector<2xi32> + %rhs1 = vector.insert %val3, %rhs0[%idx1] : i32 into vector<2xi32> %interleave = vector.interleave %lhs1, %rhs1 : vector<2xi32> -> vector<4xi32> - %res0 = vector.extractelement %interleave[%idx0 : index] : vector<4xi32> - %res1 = vector.extractelement %interleave[%idx1 : index] : vector<4xi32> - %res2 = vector.extractelement %interleave[%idx2 : index] : vector<4xi32> - %res3 = vector.extractelement %interleave[%idx3 : index] : vector<4xi32> + %res0 = vector.extract %interleave[%idx0] : i32 from vector<4xi32> + %res1 = vector.extract %interleave[%idx1] : i32 from vector<4xi32> + %res2 = vector.extract %interleave[%idx2] : i32 from vector<4xi32> + %res3 = vector.extract %interleave[%idx3] : i32 from vector<4xi32> memref.store %res0, %arg2[%idx0]: memref<4xi32> memref.store %res1, %arg2[%idx1]: memref<4xi32> diff --git a/mlir/test/Integration/GPU/Vulkan/vector-shuffle.mlir b/mlir/test/Integration/GPU/Vulkan/vector-shuffle.mlir index dc53fe3..c1b7dba 100644 --- a/mlir/test/Integration/GPU/Vulkan/vector-shuffle.mlir +++ b/mlir/test/Integration/GPU/Vulkan/vector-shuffle.mlir @@ -26,17 +26,17 @@ module attributes { %val2 = memref.load %arg1[%idx0] : memref<2xi32> %val3 = memref.load %arg1[%idx1] : memref<2xi32> - %lhs0 = vector.insertelement %val0, %lhs[%idx0 : index] : vector<2xi32> - %lhs1 = vector.insertelement %val1, %lhs0[%idx1 : index] : vector<2xi32> - %rhs0 = vector.insertelement %val2, %rhs[%idx0 : index] : vector<2xi32> - %rhs1 = vector.insertelement %val3, %rhs0[%idx1 : index] : vector<2xi32> + %lhs0 = vector.insert %val0, %lhs[%idx0] : i32 into vector<2xi32> + %lhs1 = vector.insert %val1, %lhs0[%idx1] : i32 into vector<2xi32> + %rhs0 = vector.insert %val2, %rhs[%idx0] : i32 into vector<2xi32> + %rhs1 = vector.insert %val3, %rhs0[%idx1] : i32 into vector<2xi32> %shuffle = vector.shuffle %lhs1, %rhs1[2, 1, 3, 3] : vector<2xi32>, vector<2xi32> - %res0 = vector.extractelement %shuffle[%idx0 : index] : vector<4xi32> - %res1 = vector.extractelement %shuffle[%idx1 : index] : vector<4xi32> - %res2 = vector.extractelement %shuffle[%idx2 : index] : vector<4xi32> - %res3 = vector.extractelement %shuffle[%idx3 : index] : vector<4xi32> + %res0 = vector.extract %shuffle[%idx0] : i32 from vector<4xi32> + %res1 = vector.extract %shuffle[%idx1] : i32 from vector<4xi32> + %res2 = vector.extract %shuffle[%idx2] : i32 from vector<4xi32> + %res3 = vector.extract %shuffle[%idx3] : i32 from vector<4xi32> memref.store %res0, %arg2[%idx0]: memref<4xi32> memref.store %res1, %arg2[%idx1]: memref<4xi32> diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir index e48e5c6..7888462 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir @@ -595,16 +595,17 @@ module attributes {transform.with_named_sequence} { // ----- -// It is valid to fuse the pack op with padding semantics if the tiled -// dimensions do not need padding. +// It is valid to fuse the pack op with padding semantics if it is a perfect +// tiling case. func.func @fuse_pack_consumer_with_padding_semantics(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<22x2x3x16xf32> { - %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) { - %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> - %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32> - %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32> + %0 = scf.forall (%arg2, %arg3) = (0, 0) to (64, 32) step (15, 16) shared_outs(%arg4 = %arg1) -> (tensor<64x32xf32>) { + %size = affine.min affine_map<(d0) -> (-d0 + 64, 15)>(%arg2) + %src = tensor.extract_slice %arg0[%arg2, %arg3] [%size, 16] [1, 1] : tensor<64x32xf32> to tensor<?x16xf32> + %dest = tensor.extract_slice %arg4[%arg2, %arg3] [%size, 16] [1, 1] : tensor<64x32xf32> to tensor<?x16xf32> + %2 = linalg.exp ins(%src : tensor<?x16xf32>) outs(%dest : tensor<?x16xf32>) -> tensor<?x16xf32> scf.forall.in_parallel { - tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32> + tensor.parallel_insert_slice %2 into %arg4[%arg2, %arg3] [%size, 16] [1, 1] : tensor<?x16xf32> into tensor<64x32xf32> } } %1 = tensor.empty() : tensor<22x2x3x16xf32> @@ -621,28 +622,39 @@ module attributes {transform.with_named_sequence} { transform.yield } } -// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)> +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (-d0 + 64, 15)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (d0 floordiv 3)> +// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0) -> (d0 ceildiv 3)> +// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0) -> (d0 floordiv 16)> // CHECK: func.func @fuse_pack_consumer_with_padding_semantics( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] // CHECK-DAG: %[[OUT_INIT:.*]] = tensor.empty() : tensor<22x2x3x16xf32> // CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (32) step (16) -// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]]) -// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1] -// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1] +// CHECK: %{{.*}}:2 = scf.forall (%[[I:.*]], %[[J:.*]]) = (0, 0) to (64, 32) step (15, 16) +// CHECK-SAME: shared_outs(%[[ELEM_OUT:.*]] = %[[ARG1]], %[[PACK_OUT:.*]] = %[[OUT_INIT]]) +// CHECK: %[[SIZE:.+]] = affine.min #[[MAP0]](%[[I]]) +// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]] +// CHECK-SAME: [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1] +// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[ELEM_OUT]] +// CHECK-SAME: [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1] // CHECK: %[[ELEM:.*]] = linalg.exp // CHECK-SAME: ins(%[[ELEM_SRC]] // CHECK-SAME: outs(%[[ELEM_DEST]] -// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]]) -// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 1, 3, 16] [1, 1, 1, 1] -// CHECK: %[[TILED_PACK_OUT:.*]] = linalg.pack %[[ELEM]] +// CHECK-DAG: %[[D0_OFFSET:.*]] = affine.apply #[[MAP1]](%[[I]]) +// CHECK-DAG: %[[D0_SIZE:.*]] = affine.apply #[[MAP2]](%[[SIZE]]) +// CHECK-DAG: %[[D1_OFFSET:.*]] = affine.apply #[[MAP3]](%[[J]]) +// CHECK-DAG: %[[PACK_INIT:.*]] = tensor.extract_slice %[[PACK_OUT]] +// CHECK-SAME: [%[[D0_OFFSET]], %[[D1_OFFSET]], 0, 0] [%[[D0_SIZE]], 1, 3, 16] [1, 1, 1, 1] +// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]] // CHECK-SAME: padding_value(%[[PAD_VAL]] : f32) // CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [3, 16] // CHECK-SAME: into %[[TILED_PACK_DEST]] // CHECK: scf.forall.in_parallel { -// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1] -// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [22, 1, 3, 16] [1, 1, 1, 1] +// CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[ELEM_OUT]] +// CHECK-SAME: [%[[I]], %[[J]]] [%[[SIZE]], 16] [1, 1] +// CHECK: tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT]] +// CHECK-SAME: [%[[D0_OFFSET]], %[[D1_OFFSET]], 0, 0] [%[[D0_SIZE]], 1, 3, 16] [1, 1, 1, 1] // ----- diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll index 24380b5..a419d75 100644 --- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll +++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll @@ -570,10 +570,10 @@ define void @trap_intrinsics() { ; CHECK-LABEL: llvm.func @memcpy_test define void @memcpy_test(i32 %0, ptr %1, ptr %2) { - ; CHECK: "llvm.intr.memcpy"(%{{.*}}, %{{.*}}, %{{.*}}) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> () - call void @llvm.memcpy.p0.p0.i32(ptr %1, ptr %2, i32 %0, i1 false) - ; CHECK: "llvm.intr.memcpy.inline"(%{{.*}}, %{{.*}}) <{isVolatile = false, len = 10 : i64}> : (!llvm.ptr, !llvm.ptr) -> () - call void @llvm.memcpy.inline.p0.p0.i64(ptr %1, ptr %2, i64 10, i1 false) + ; CHECK: "llvm.intr.memcpy"(%{{.*}}, %{{.*}}, %{{.*}}) <{arg_attrs = [{llvm.align = 4 : i64}, {llvm.align = 8 : i64}, {}], isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> () + call void @llvm.memcpy.p0.p0.i32(ptr align 4 %1, ptr align 8 %2, i32 %0, i1 false) + ; CHECK: "llvm.intr.memcpy.inline"(%{{.*}}, %{{.*}}) <{arg_attrs = [{}, {llvm.align = 4 : i64}], isVolatile = false, len = 10 : i64}> : (!llvm.ptr, !llvm.ptr) -> () + call void @llvm.memcpy.inline.p0.p0.i64(ptr %1, ptr align 4 %2, i64 10, i1 false) ; CHECK: "llvm.intr.memcpy.inline"(%{{.*}}, %{{.*}}) <{isVolatile = false, len = 10 : i32}> : (!llvm.ptr, !llvm.ptr) -> () call void @llvm.memcpy.inline.p0.p0.i32(ptr %1, ptr %2, i32 10, i1 false) ret void @@ -581,17 +581,17 @@ define void @memcpy_test(i32 %0, ptr %1, ptr %2) { ; CHECK-LABEL: llvm.func @memmove_test define void @memmove_test(i32 %0, ptr %1, ptr %2) { - ; CHECK: "llvm.intr.memmove"(%{{.*}}, %{{.*}}, %{{.*}}) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> () - call void @llvm.memmove.p0.p0.i32(ptr %1, ptr %2, i32 %0, i1 false) + ; CHECK: "llvm.intr.memmove"(%{{.*}}, %{{.*}}, %{{.*}}) <{arg_attrs = [{llvm.align = 16 : i64}, {}, {}], isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> () + call void @llvm.memmove.p0.p0.i32(ptr align 16 %1, ptr %2, i32 %0, i1 false) ret void } ; CHECK-LABEL: llvm.func @memset_test define void @memset_test(i32 %0, ptr %1, i8 %2) { - ; CHECK: "llvm.intr.memset"(%{{.*}}, %{{.*}}, %{{.*}}) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> () - call void @llvm.memset.p0.i32(ptr %1, i8 %2, i32 %0, i1 false) - ; CHECK: "llvm.intr.memset.inline"(%{{.*}}, %{{.*}}) <{isVolatile = false, len = 10 : i64}> : (!llvm.ptr, i8) -> () - call void @llvm.memset.inline.p0.i64(ptr %1, i8 %2, i64 10, i1 false) + ; CHECK: "llvm.intr.memset"(%{{.*}}, %{{.*}}, %{{.*}}) <{arg_attrs = [{llvm.align = 2 : i64}, {}, {}], isVolatile = false}> : (!llvm.ptr, i8, i32) -> () + call void @llvm.memset.p0.i32(ptr align 2 %1, i8 %2, i32 %0, i1 false) + ; CHECK: "llvm.intr.memset.inline"(%{{.*}}, %{{.*}}) <{arg_attrs = [{llvm.align = 4 : i64}, {}], isVolatile = false, len = 10 : i64}> : (!llvm.ptr, i8) -> () + call void @llvm.memset.inline.p0.i64(ptr align 4 %1, i8 %2, i64 10, i1 false) ; CHECK: "llvm.intr.memset.inline"(%{{.*}}, %{{.*}}) <{isVolatile = false, len = 10 : i32}> : (!llvm.ptr, i8) -> () call void @llvm.memset.inline.p0.i32(ptr %1, i8 %2, i32 10, i1 false) ret void diff --git a/mlir/test/Target/LLVMIR/Import/module-asm.ll b/mlir/test/Target/LLVMIR/Import/module-asm.ll new file mode 100644 index 0000000..38f6ea4 --- /dev/null +++ b/mlir/test/Target/LLVMIR/Import/module-asm.ll @@ -0,0 +1,5 @@ +; RUN: mlir-translate -import-llvm %s | FileCheck %s +; CHECK: llvm.module_asm = ["foo", "bar"] + +module asm "foo" +module asm "bar" diff --git a/mlir/test/Target/LLVMIR/invalid-module.mlir b/mlir/test/Target/LLVMIR/invalid-module.mlir index 7fd5f26..5ed6244 100644 --- a/mlir/test/Target/LLVMIR/invalid-module.mlir +++ b/mlir/test/Target/LLVMIR/invalid-module.mlir @@ -1,6 +1,16 @@ -// RUN: mlir-translate -verify-diagnostics -mlir-to-llvmir --no-implicit-module %s +// RUN: mlir-translate -verify-diagnostics -mlir-to-llvmir --no-implicit-module -split-input-file %s // expected-error@below {{'llvm.func' op can not be translated to an LLVMIR module}} llvm.func @foo() { llvm.return } + +// ----- + +// expected-error@below {{expected an array attribute for a module level asm}} +module attributes {llvm.module_asm = "foo"} {} + +// ----- + +// expected-error@below {{expected a string attribute for each entry of a module level asm}} +module attributes {llvm.module_asm = [42]} {} diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir index 44074ce..eb3510c 100644 --- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir +++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir @@ -601,29 +601,33 @@ llvm.func @trap_intrinsics() { // CHECK-LABEL: @memcpy_test llvm.func @memcpy_test(%arg0: i32, %arg2: !llvm.ptr, %arg3: !llvm.ptr) { - // CHECK: call void @llvm.memcpy.p0.p0.i32(ptr %{{.*}}, ptr %{{.*}}, i32 %{{.*}}, i1 false - "llvm.intr.memcpy"(%arg2, %arg3, %arg0) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> () - // CHECK: call void @llvm.memcpy.inline.p0.p0.i32(ptr %{{.*}}, ptr %{{.*}}, i32 10, i1 true - "llvm.intr.memcpy.inline"(%arg2, %arg3) <{isVolatile = true, len = 10 : i32}> : (!llvm.ptr, !llvm.ptr) -> () + // CHECK: call void @llvm.memcpy.p0.p0.i32(ptr align 4 %{{.*}}, ptr align 8 %{{.*}}, i32 %{{.*}}, i1 false + "llvm.intr.memcpy"(%arg2, %arg3, %arg0) <{arg_attrs = [{llvm.align = 4 : i64}, {llvm.align = 8 : i64}, {}], isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> () + // CHECK: call void @llvm.memcpy.inline.p0.p0.i32(ptr align 4 %{{.*}}, ptr %{{.*}}, i32 10, i1 true + "llvm.intr.memcpy.inline"(%arg2, %arg3) <{arg_attrs = [{llvm.align = 4 : i64}, {}], isVolatile = true, len = 10 : i32}> : (!llvm.ptr, !llvm.ptr) -> () // CHECK: call void @llvm.memcpy.inline.p0.p0.i64(ptr %{{.*}}, ptr %{{.*}}, i64 10, i1 true "llvm.intr.memcpy.inline"(%arg2, %arg3) <{isVolatile = true, len = 10 : i64}> : (!llvm.ptr, !llvm.ptr) -> () + + // Verify that trailing empty argument attribute dictionaries can be omitted. + // CHECK: call void @llvm.memcpy.p0.p0.i32(ptr align 4 %{{.*}}, ptr align 8 %{{.*}}, i32 %{{.*}}, i1 false + "llvm.intr.memcpy"(%arg2, %arg3, %arg0) <{arg_attrs = [{llvm.align = 4 : i64}, {llvm.align = 8 : i64}], isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> () llvm.return } // CHECK-LABEL: @memmove_test llvm.func @memmove_test(%arg0: i32, %arg2: !llvm.ptr, %arg3: !llvm.ptr) { - // CHECK: call void @llvm.memmove.p0.p0.i32(ptr %{{.*}}, ptr %{{.*}}, i32 %{{.*}}, i1 false - "llvm.intr.memmove"(%arg2, %arg3, %arg0) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> () + // CHECK: call void @llvm.memmove.p0.p0.i32(ptr align 4 %{{.*}}, ptr align 8 %{{.*}}, i32 %{{.*}}, i1 false + "llvm.intr.memmove"(%arg2, %arg3, %arg0) <{arg_attrs = [{llvm.align = 4 : i64}, {llvm.align = 8 : i64}, {}], isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> () llvm.return } // CHECK-LABEL: @memset_test llvm.func @memset_test(%arg0: i32, %arg2: !llvm.ptr, %arg3: i8) { %i1 = llvm.mlir.constant(false) : i1 - // CHECK: call void @llvm.memset.p0.i32(ptr %{{.*}}, i8 %{{.*}}, i32 %{{.*}}, i1 false - "llvm.intr.memset"(%arg2, %arg3, %arg0) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> () - // CHECK: call void @llvm.memset.inline.p0.i32(ptr %{{.*}}, i8 %{{.*}}, i32 10, i1 true - "llvm.intr.memset.inline"(%arg2, %arg3) <{isVolatile = true, len = 10 : i32}> : (!llvm.ptr, i8) -> () + // CHECK: call void @llvm.memset.p0.i32(ptr align 8 %{{.*}}, i8 %{{.*}}, i32 %{{.*}}, i1 false + "llvm.intr.memset"(%arg2, %arg3, %arg0) <{arg_attrs = [{llvm.align = 8 : i64}, {}, {}], isVolatile = false}> : (!llvm.ptr, i8, i32) -> () + // CHECK: call void @llvm.memset.inline.p0.i32(ptr align 8 %{{.*}}, i8 %{{.*}}, i32 10, i1 true + "llvm.intr.memset.inline"(%arg2, %arg3) <{arg_attrs = [{llvm.align = 8 : i64}, {}], isVolatile = true, len = 10 : i32}> : (!llvm.ptr, i8) -> () // CHECK: call void @llvm.memset.inline.p0.i64(ptr %{{.*}}, i8 %{{.*}}, i64 10, i1 true "llvm.intr.memset.inline"(%arg2, %arg3) <{isVolatile = true, len = 10 : i64}> : (!llvm.ptr, i8) -> () llvm.return diff --git a/mlir/test/Target/LLVMIR/module-asm.mlir b/mlir/test/Target/LLVMIR/module-asm.mlir new file mode 100644 index 0000000..2afb37c --- /dev/null +++ b/mlir/test/Target/LLVMIR/module-asm.mlir @@ -0,0 +1,6 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +module attributes {llvm.module_asm = ["foo", "bar"]} {} + +// CHECK: module asm "foo" +// CHECK: module asm "bar" diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir index 8c4f0aa..85478cc 100644 --- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir @@ -312,3 +312,42 @@ llvm.func @nvvm_prefetch_uniform_with_invalid_addr_space(%global_ptr: !llvm.ptr< nvvm.prefetch level = L1 uniform, %global_ptr : !llvm.ptr<1> llvm.return } + +// ----- + +llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { + // expected-error@+1 {{'nvvm.stmatrix' op expected num attribute to be 1, 2 or 4}} + nvvm.stmatrix %arg0, %r1, %r2, %r3 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32, i32, i32 + llvm.return +} + +// ----- + +llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { + // expected-error@+1 {{'nvvm.stmatrix' op expected shape to be 8x8 or 16x8}} + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32 + llvm.return +} + +// ----- + +llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { + // expected-error@+1 {{'nvvm.stmatrix' op expected element type to be B16 for 8x8 matrix}} + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : !llvm.ptr<3>, i32 + llvm.return +} +// ----- + +llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { + // expected-error@+1 {{'nvvm.stmatrix' op expected element type to be B8 for 16x8 matrix}} + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32 + llvm.return +} + +// ----- + +llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { + // expected-error@+1 {{'nvvm.stmatrix' op expected layout to be col for 16x8 matrix}} + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : !llvm.ptr<3>, i32 + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir index f86a041..5c2cfa4 100644 --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -573,6 +573,29 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { llvm.return } +// CHECK-LABEL: @st_matrix +llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) { + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x1.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : !llvm.ptr<3>, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x2.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : !llvm.ptr<3>, i32, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x4.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32, i32, i32, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x4.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b16>} : !llvm.ptr<3>, i32, i32, i32, i32 + // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x4.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) + nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 8>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : !llvm.ptr<3>, i32, i32, i32, i32 + llvm.return +} + // This function has the "kernel" attribute attached and should appear in the // NVVM annotations after conversion. llvm.func @kernel_func() attributes {nvvm.kernel} { diff --git a/mlir/test/Target/LLVMIR/xevm.mlir b/mlir/test/Target/LLVMIR/xevm.mlir new file mode 100644 index 0000000..a3dd0b6 --- /dev/null +++ b/mlir/test/Target/LLVMIR/xevm.mlir @@ -0,0 +1,21 @@ +// RUN: mlir-translate --split-input-file -mlir-to-llvmir %s | FileCheck %s + +module { + llvm.func spir_funccc @_Z8prefetchPU3AS1Kcm(!llvm.ptr<1>, i64) + llvm.func @prefetch(%arg0: !llvm.ptr<1>) { + %0 = llvm.mlir.constant(1 : i64) : i64 + // CHECK-LABEL: call spir_func void @_Z8prefetchPU3AS1Kcm + // CHECK-SAME: !spirv.DecorationCacheControlINTEL ![[DECO1:.*]] + llvm.call spir_funccc @_Z8prefetchPU3AS1Kcm(%arg0, %0) + {function_type = !llvm.func<void (ptr<1>, i64)>, linkage = #llvm.linkage<external>, + no_unwind, sym_name = "_Z8prefetchPU3AS1Kcm", visibility_ = 0 : i64, + xevm.DecorationCacheControl = [[6442 : i32, 0 : i32, 1 : i32, 0 : i32], [6442 : i32, 1 : i32, 1 : i32, 0 : i32]]} + : (!llvm.ptr<1>, i64) -> () + llvm.return + } +} + +// CHECK: ![[DECO1]] = !{![[DECO2:.*]], ![[DECO3:.*]]} +// CHECK: ![[DECO2]] = !{i32 6442, i32 0, i32 1, i32 0} +// CHECK: ![[DECO3]] = !{i32 6442, i32 1, i32 1, i32 0} + diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir index 6aca11e..c81ceac 100644 --- a/mlir/test/Target/SPIRV/constant.mlir +++ b/mlir/test/Target/SPIRV/constant.mlir @@ -307,6 +307,48 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader spirv.ReturnValue %coop : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc> } + // CHECK-LABEL: @arm_tensor_of_i32 + spirv.func @arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" { + // CHECK: {{%.*}} = spirv.Constant dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : !spirv.arm.tensor<2x3xi32> + %0 = spirv.Constant dense<[[1, 2, 3], [4, 5, 6]]> : !spirv.arm.tensor<2x3xi32> + spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32> + } + + // CHECK-LABEL: @splat_arm_tensor_of_i32 + spirv.func @splat_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" { + // CHECK: {{%.*}} = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32> + %0 = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32> + spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32> + } + + // CHECK-LABEL: @arm_tensor_of_f32 + spirv.func @arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" { + // CHECK: {{%.*}} = spirv.Constant dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : !spirv.arm.tensor<2x3xf32> + %0 = spirv.Constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]>: !spirv.arm.tensor<2x3xf32> + spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32> + } + + // CHECK-LABEL: @splat_arm_tensor_of_f32 + spirv.func @splat_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" { + // CHECK: {{%.*}} = spirv.Constant dense<2.000000e+00> : !spirv.arm.tensor<2x3xf32> + %0 = spirv.Constant dense<2.0> : !spirv.arm.tensor<2x3xf32> + spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32> + } + + // CHECK-LABEL: @null_arm_tensor_of_i32 + spirv.func @null_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" { + // CHECK: spirv.Constant dense<0> : !spirv.arm.tensor<2x3xi32> + %0 = spirv.Constant dense<0> : !spirv.arm.tensor<2x3xi32> + spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32> + } + + // CHECK-LABEL: @null_arm_tensor_of_f32 + spirv.func @null_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" { + // CHECK: spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<2x3xf32> + %0 = spirv.Constant dense<0.0> : !spirv.arm.tensor<2x3xf32> + spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32> + } + spirv.EntryPoint "GLCompute" @bool_const } @@ -363,6 +405,20 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32> } + // CHECK-LABEL: @splat_array_of_non_splat_array_of_arrays_of_i32 + spirv.func @splat_array_of_non_splat_array_of_arrays_of_i32() -> !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>> "None" { + // CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}{{\[}}[1 : i32, 2 : i32, 3 : i32], [4 : i32, 5 : i32, 6 : i32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>> + %0 = spirv.EXT.ConstantCompositeReplicate [[[1 : i32, 2 : i32, 3 : i32], [4 : i32, 5 : i32, 6 : i32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>> + spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x i32>>> + } + + // CHECK-LABEL: @null_cc_arm_tensor_of_i32 + spirv.func @null_cc_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" { + // CHECK: spirv.Constant dense<0> : !spirv.arm.tensor<2x3xi32> + %0 = spirv.EXT.ConstantCompositeReplicate [0 : i32] : !spirv.arm.tensor<2x3xi32> + spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32> + } + // CHECK-LABEL: @splat_vector_f32 spirv.func @splat_vector_f32() -> (vector<3xf32>) "None" { // CHECK: spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : vector<3xf32> @@ -411,4 +467,18 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos %0 = spirv.EXT.ConstantCompositeReplicate [2.0 : f32] : !spirv.arm.tensor<2x3xf32> spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32> } + + // CHECK-LABEL: @splat_array_of_non_splat_array_of_arrays_of_f32 + spirv.func @splat_array_of_non_splat_array_of_arrays_of_f32() -> !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>> "None" { + // CHECK: spirv.EXT.ConstantCompositeReplicate {{\[}}{{\[}}[1.000000e+00 : f32, 2.000000e+00 : f32, 3.000000e+00 : f32], [4.000000e+00 : f32, 5.000000e+00 : f32, 6.000000e+00 : f32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>> + %0 = spirv.EXT.ConstantCompositeReplicate [[[1.0 : f32, 2.0 : f32, 3.0 : f32], [4.0 : f32, 5.0 : f32, 6.0 : f32]]] : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>> + spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<2 x !spirv.array<3 x f32>>> + } + + // CHECK-LABEL: @null_cc_arm_tensor_of_f32 + spirv.func @null_cc_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" { + // CHECK: spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<2x3xf32> + %0 = spirv.EXT.ConstantCompositeReplicate [0.0 : f32] : !spirv.arm.tensor<2x3xf32> + spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32> + } } diff --git a/mlir/test/Target/SPIRV/intel-ext-ops.mlir b/mlir/test/Target/SPIRV/intel-ext-ops.mlir index 6d2fd32..53cf8bf 100644 --- a/mlir/test/Target/SPIRV/intel-ext-ops.mlir +++ b/mlir/test/Target/SPIRV/intel-ext-ops.mlir @@ -33,6 +33,28 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Bfloat16ConversionINTEL] // ----- //===----------------------------------------------------------------------===// +// spirv.INTEL.RoundFToTF32 +//===----------------------------------------------------------------------===// + +spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [TensorFloat32RoundingINTEL], [SPV_INTEL_tensor_float32_conversion]> { + // CHECK-LABEL: @f32_to_tf32 + spirv.func @f32_to_tf32(%arg0 : f32) "None" { + // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : f32 to f32 + %1 = spirv.INTEL.RoundFToTF32 %arg0 : f32 to f32 + spirv.Return + } + + // CHECK-LABEL: @f32_to_tf32_vec + spirv.func @f32_to_tf32_vec(%arg0 : vector<2xf32>) "None" { + // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : vector<2xf32> to vector<2xf32> + %1 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<2xf32> + spirv.Return + } +} + +// ----- + +//===----------------------------------------------------------------------===// // spirv.INTEL.SplitBarrier //===----------------------------------------------------------------------===// diff --git a/mlir/test/Target/SPIRV/logical-ops.mlir b/mlir/test/Target/SPIRV/logical-ops.mlir index b200871..05cbddc 100644 --- a/mlir/test/Target/SPIRV/logical-ops.mlir +++ b/mlir/test/Target/SPIRV/logical-ops.mlir @@ -84,6 +84,8 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { %15 = spirv.IsNan %arg0 : f32 // CHECK: spirv.IsInf %16 = spirv.IsInf %arg1 : f32 + // CHECK: spirv.IsFinite + %17 = spirv.IsFinite %arg0 : f32 spirv.Return } } diff --git a/mlir/test/Target/SPIRV/memory-ops.mlir b/mlir/test/Target/SPIRV/memory-ops.mlir index 6b50c39..786d07a2 100644 --- a/mlir/test/Target/SPIRV/memory-ops.mlir +++ b/mlir/test/Target/SPIRV/memory-ops.mlir @@ -37,32 +37,32 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { // ----- spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { - spirv.func @load_store_zero_rank_float(%arg0: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>) "None" { - // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])> + spirv.func @load_store_zero_rank_float(%arg0: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>) "None" { + // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer> // CHECK-NEXT: [[VAL:%.*]] = spirv.Load "StorageBuffer" [[LOAD_PTR]] : f32 %0 = spirv.Constant 0 : i32 - %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer> + %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer> %2 = spirv.Load "StorageBuffer" %1 : f32 - // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])> + // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer> // CHECK-NEXT: spirv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : f32 %3 = spirv.Constant 0 : i32 - %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer> + %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer> spirv.Store "StorageBuffer" %4, %2 : f32 spirv.Return } - spirv.func @load_store_zero_rank_int(%arg0: !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>) "None" { - // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])> + spirv.func @load_store_zero_rank_int(%arg0: !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>) "None" { + // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer> // CHECK-NEXT: [[VAL:%.*]] = spirv.Load "StorageBuffer" [[LOAD_PTR]] : i32 %0 = spirv.Constant 0 : i32 - %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer> + %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer> %2 = spirv.Load "StorageBuffer" %1 : i32 - // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])> + // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer> // CHECK-NEXT: spirv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : i32 %3 = spirv.Constant 0 : i32 - %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer> + %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr<i32, StorageBuffer> spirv.Store "StorageBuffer" %4, %2 : i32 spirv.Return } diff --git a/mlir/test/Target/SPIRV/struct.mlir b/mlir/test/Target/SPIRV/struct.mlir index 0db0c0b..4984ee7 100644 --- a/mlir/test/Target/SPIRV/struct.mlir +++ b/mlir/test/Target/SPIRV/struct.mlir @@ -7,23 +7,23 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { // CHECK: !spirv.ptr<!spirv.struct<(f32 [0], !spirv.struct<(f32 [0], !spirv.array<16 x f32, stride=4> [4])> [4])>, Input> spirv.GlobalVariable @var1 bind(0, 2) : !spirv.ptr<!spirv.struct<(f32 [0], !spirv.struct<(f32 [0], !spirv.array<16 x f32, stride=4> [4])> [4])>, Input> - // CHECK: !spirv.ptr<!spirv.struct<(f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38])>, StorageBuffer> - spirv.GlobalVariable @var2 : !spirv.ptr<!spirv.struct<(f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38])>, StorageBuffer> + // CHECK: !spirv.ptr<!spirv.struct<(f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38]), Block>, StorageBuffer> + spirv.GlobalVariable @var2 : !spirv.ptr<!spirv.struct<(f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38]), Block>, StorageBuffer> - // CHECK: !spirv.ptr<!spirv.struct<(!spirv.array<128 x !spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, stride=512> [0])>, StorageBuffer> - spirv.GlobalVariable @var3 : !spirv.ptr<!spirv.struct<(!spirv.array<128 x !spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, stride=512> [0])>, StorageBuffer> + // CHECK: !spirv.ptr<!spirv.struct<(!spirv.array<128 x !spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, stride=512> [0]), Block>, StorageBuffer> + spirv.GlobalVariable @var3 : !spirv.ptr<!spirv.struct<(!spirv.array<128 x !spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, stride=512> [0]), Block>, StorageBuffer> - // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4])>, StorageBuffer> - spirv.GlobalVariable @var4 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4])>, StorageBuffer> + // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4]), Block>, StorageBuffer> + spirv.GlobalVariable @var4 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4]), Block>, StorageBuffer> - // CHECK: !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable])>, StorageBuffer> - spirv.GlobalVariable @var5 : !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable])>, StorageBuffer> + // CHECK: !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable]), Block>, StorageBuffer> + spirv.GlobalVariable @var5 : !spirv.ptr<!spirv.struct<(f32 [NonWritable], i32 [NonWritable, NonReadable]), Block>, StorageBuffer> - // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable])>, StorageBuffer> - spirv.GlobalVariable @var6 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable])>, StorageBuffer> + // CHECK: !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable]), Block>, StorageBuffer> + spirv.GlobalVariable @var6 : !spirv.ptr<!spirv.struct<(f32 [0, NonWritable], i32 [4, NonWritable, NonReadable]), Block>, StorageBuffer> - // CHECK: !spirv.ptr<!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16])>, StorageBuffer> - spirv.GlobalVariable @var7 : !spirv.ptr<!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16])>, StorageBuffer> + // CHECK: !spirv.ptr<!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]), Block>, StorageBuffer> + spirv.GlobalVariable @var7 : !spirv.ptr<!spirv.struct<(!spirv.matrix<3 x vector<3xf32>> [0, ColMajor, MatrixStride=16]), Block>, StorageBuffer> // CHECK: !spirv.ptr<!spirv.struct<()>, StorageBuffer> spirv.GlobalVariable @empty : !spirv.ptr<!spirv.struct<()>, StorageBuffer> @@ -34,15 +34,17 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { // CHECK: !spirv.ptr<!spirv.struct<test_id, (!spirv.array<128 x f32, stride=4> [0])>, Input> spirv.GlobalVariable @id_var0 : !spirv.ptr<!spirv.struct<test_id, (!spirv.array<128 x f32, stride=4> [0])>, Input> + // CHECK: !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>), Block>, StorageBuffer> + spirv.GlobalVariable @recursive_simple : !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>), Block>, StorageBuffer> - // CHECK: !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>)>, StorageBuffer> - spirv.GlobalVariable @recursive_simple : !spirv.ptr<!spirv.struct<rec, (!spirv.ptr<!spirv.struct<rec>, StorageBuffer>)>, StorageBuffer> + // CHECK: !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>), Block>, Uniform>), Block>, Uniform> + spirv.GlobalVariable @recursive_2 : !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>), Block>, Uniform>), Block>, Uniform> - // CHECK: !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>)>, Uniform>)>, Uniform> - spirv.GlobalVariable @recursive_2 : !spirv.ptr<!spirv.struct<a, (!spirv.ptr<!spirv.struct<b, (!spirv.ptr<!spirv.struct<a>, Uniform>)>, Uniform>)>, Uniform> + // CHECK: !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>), Block>, Uniform>), Block>, Uniform> + spirv.GlobalVariable @recursive_3 : !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>), Block>, Uniform>), Block>, Uniform> - // CHECK: !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>)>, Uniform>)>, Uniform> - spirv.GlobalVariable @recursive_3 : !spirv.ptr<!spirv.struct<axx, (!spirv.ptr<!spirv.struct<bxx, (!spirv.ptr<!spirv.struct<axx>, Uniform>, !spirv.ptr<!spirv.struct<bxx>, Uniform>)>, Uniform>)>, Uniform> + // CHECK: spirv.GlobalVariable @block : !spirv.ptr<!spirv.struct<vert, (vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>, Output> + spirv.GlobalVariable @block : !spirv.ptr<!spirv.struct<vert, (vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>, Output> // CHECK: !spirv.ptr<!spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, Input>, // CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<128 x f32, stride=4> [0])>, Output> diff --git a/mlir/test/Target/SPIRV/undef.mlir b/mlir/test/Target/SPIRV/undef.mlir index b9044fe..8889b80 100644 --- a/mlir/test/Target/SPIRV/undef.mlir +++ b/mlir/test/Target/SPIRV/undef.mlir @@ -13,10 +13,10 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> { // CHECK: {{%.*}} = spirv.Undef : !spirv.array<4 x !spirv.array<4 x i32>> %5 = spirv.Undef : !spirv.array<4x!spirv.array<4xi32>> %6 = spirv.CompositeExtract %5[1 : i32, 2 : i32] : !spirv.array<4x!spirv.array<4xi32>> - // CHECK: {{%.*}} = spirv.Undef : !spirv.ptr<!spirv.struct<(f32)>, StorageBuffer> - %7 = spirv.Undef : !spirv.ptr<!spirv.struct<(f32)>, StorageBuffer> + // CHECK: {{%.*}} = spirv.Undef : !spirv.ptr<!spirv.struct<(f32), Block>, StorageBuffer> + %7 = spirv.Undef : !spirv.ptr<!spirv.struct<(f32), Block>, StorageBuffer> %8 = spirv.Constant 0 : i32 - %9 = spirv.AccessChain %7[%8] : !spirv.ptr<!spirv.struct<(f32)>, StorageBuffer>, i32 -> !spirv.ptr<f32, StorageBuffer> + %9 = spirv.AccessChain %7[%8] : !spirv.ptr<!spirv.struct<(f32), Block>, StorageBuffer>, i32 -> !spirv.ptr<f32, StorageBuffer> spirv.Return } } diff --git a/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt b/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt index 226e0bb..2ee3222 100644 --- a/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt @@ -1,5 +1,6 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRBufferizationTestPasses + TestOneShotModuleBufferize.cpp TestTensorCopyInsertion.cpp TestTensorLikeAndBufferLike.cpp diff --git a/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp b/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp new file mode 100644 index 0000000..1e2d4a7 --- /dev/null +++ b/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp @@ -0,0 +1,57 @@ +//===- TestOneShotModuleBufferzation.cpp - Bufferization Test -----*- c++ +//-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" +#include "mlir/Dialect/Bufferization/Transforms/Transforms.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace { +struct TestOneShotModuleBufferizePass + : public PassWrapper<TestOneShotModuleBufferizePass, OperationPass<>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOneShotModuleBufferizePass) + + TestOneShotModuleBufferizePass() = default; + TestOneShotModuleBufferizePass(const TestOneShotModuleBufferizePass &pass) + : PassWrapper(pass) {} + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert<bufferization::BufferizationDialect>(); + } + StringRef getArgument() const final { + return "test-one-shot-module-bufferize"; + } + StringRef getDescription() const final { + return "Pass to test One Shot Module Bufferization"; + } + + void runOnOperation() override { + + llvm::errs() << "Running TestOneShotModuleBufferize on: " + << getOperation()->getName() << "\n"; + bufferization::OneShotBufferizationOptions opt; + + opt.bufferizeFunctionBoundaries = true; + bufferization::BufferizationState bufferizationState; + + if (failed(bufferization::runOneShotModuleBufferize(getOperation(), opt, + bufferizationState))) + signalPassFailure(); + } +}; +} // namespace + +namespace mlir::test { +void registerTestOneShotModuleBufferizePass() { + PassRegistration<TestOneShotModuleBufferizePass>(); +} +} // namespace mlir::test diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td index 382da59..5685004 100644 --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -347,6 +347,7 @@ def TestCopyCount : Test_Attr<"TestCopyCount"> { let mnemonic = "copy_count"; let parameters = (ins TestParamCopyCount:$copy_count); let assemblyFormat = "`<` $copy_count `>`"; + let genVerifyDecl = 1; } def TestConditionalAliasAttr : Test_Attr<"TestConditionalAlias"> { diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp index b31e90f..5890913 100644 --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -214,6 +214,16 @@ static void printTrueFalse(AsmPrinter &p, std::optional<int> result) { } //===----------------------------------------------------------------------===// +// TestCopyCountAttr Implementation +//===----------------------------------------------------------------------===// + +LogicalResult TestCopyCountAttr::verify( + llvm::function_ref<::mlir::InFlightDiagnostic()> /*emitError*/, + CopyCount /*copy_count*/) { + return success(); +} + +//===----------------------------------------------------------------------===// // CopyCountAttr Implementation //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp index f79e2cf..53055fe 100644 --- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp +++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp @@ -18,6 +18,32 @@ using namespace mlir; using namespace test; //===----------------------------------------------------------------------===// +// OverridenSymbolVisibilityOp +//===----------------------------------------------------------------------===// + +SymbolTable::Visibility OverriddenSymbolVisibilityOp::getVisibility() { + return SymbolTable::Visibility::Private; +} + +static StringLiteral getVisibilityString(SymbolTable::Visibility visibility) { + switch (visibility) { + case SymbolTable::Visibility::Private: + return "private"; + case SymbolTable::Visibility::Nested: + return "nested"; + case SymbolTable::Visibility::Public: + return "public"; + } +} + +void OverriddenSymbolVisibilityOp::setVisibility( + SymbolTable::Visibility visibility) { + + emitOpError("cannot change visibility of symbol to ") + << getVisibilityString(visibility); +} + +//===----------------------------------------------------------------------===// // TestBranchOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index a7c6cd6..2eaad55 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -119,12 +119,28 @@ def SymbolOp : TEST_Op<"symbol", [NoMemoryEffect, Symbol]> { OptionalAttr<StrAttr>:$sym_visibility); } +def OverriddenSymbolVisibilityOp : TEST_Op<"overridden_symbol_visibility", [ + DeclareOpInterfaceMethods<Symbol, ["getVisibility", "setVisibility"]>, +]> { + let summary = "operation overridden symbol visibility accessors"; + let arguments = (ins StrAttr:$sym_name); +} + def SymbolScopeOp : TEST_Op<"symbol_scope", [SymbolTable, SingleBlockImplicitTerminator<"TerminatorOp">]> { let summary = "operation which defines a new symbol table"; let regions = (region SizedRegion<1>:$region); } +def SymbolScopeIsolatedOp + : TEST_Op<"symbol_scope_isolated", [IsolatedFromAbove, SymbolTable, + SingleBlockImplicitTerminator< + "TerminatorOp">]> { + let summary = + "operation which defines a new symbol table that is IsolatedFromAbove"; + let regions = (region SizedRegion<1>:$region); +} + def SymbolTableRegionOp : TEST_Op<"symbol_table_region", [SymbolTable]> { let summary = "operation which defines a new symbol table without a " "restriction on a terminator"; diff --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td index d47411d..a809611 100644 --- a/mlir/test/mlir-tblgen/attrdefs.td +++ b/mlir/test/mlir-tblgen/attrdefs.td @@ -115,6 +115,11 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> { // DEF: return new (allocator.allocate<CompoundAAttrStorage>()) // DEF-SAME: CompoundAAttrStorage(std::move(widthOfSomething), std::move(exampleTdType), std::move(apFloat), std::move(dims), std::move(inner)); +// DEF: CompoundAAttr CompoundAAttr::getChecked( +// DEF-SAME: int widthOfSomething, ::test::SimpleTypeA exampleTdType, ::llvm::APFloat apFloat, ::llvm::ArrayRef<int> dims, ::mlir::Type inner +// DEF-SAME: ) +// DEF-NEXT: return Base::getChecked(emitError, context, std::move(widthOfSomething), std::move(exampleTdType), std::move(apFloat), std::move(dims), std::move(inner)); + // DEF: ::mlir::Type CompoundAAttr::getInner() const { // DEF-NEXT: return getImpl()->inner; } diff --git a/mlir/test/mlir-tblgen/op-properties-predicates.td b/mlir/test/mlir-tblgen/op-properties-predicates.td index 7cd24aa..af09ee7 100644 --- a/mlir/test/mlir-tblgen/op-properties-predicates.td +++ b/mlir/test/mlir-tblgen/op-properties-predicates.td @@ -70,6 +70,12 @@ def OpWithPredicates : NS_Op<"op_with_predicates"> { // CHECK-NEXT: if (!(((!prop.has_value())) || ((::llvm::all_of((*(prop)), [](const int64_t& baseStore) -> bool { return [](int64_t baseIface) -> bool { return ((baseIface >= 0)); }(baseStore); })) && (!(((*(prop)).empty())))))) // CHECK: failed to satisfy constraint: optional non-empty array of non-negative int64_ +// CHECK-LABEL: ::llvm::LogicalResult OpWithPredicatesAdaptor::verify +// Note: comprehensive emission of verifiers is tested in verifyINvariantsImpl() below +// CHECK: int64_t tblgen_scalar = this->getScalar(); +// CHECK: if (!((tblgen_scalar >= 0))) +// CHECK: return emitError(loc, "'test.op_with_predicates' op ""property 'scalar' failed to satisfy constraint: non-negative int64_t"); + // CHECK-LABEL: OpWithPredicates::verifyInvariantsImpl() // Note: for test readability, we capture [[maybe_unused]] into the variable maybe_unused // CHECK: [[maybe_unused:\[\[maybe_unused\]\]]] int64_t tblgen_scalar = this->getScalar(); diff --git a/mlir/test/mlir-translate/emitc_classops.mlir b/mlir/test/mlir-translate/emitc_classops.mlir index 4b7ddf4..d880f9b 100644 --- a/mlir/test/mlir-translate/emitc_classops.mlir +++ b/mlir/test/mlir-translate/emitc_classops.mlir @@ -14,15 +14,12 @@ emitc.class @modelClass { // CHECK-LABEL: class modelClass { // CHECK-NEXT: public: -// CHECK-NEXT: float[1] fieldName0; -// CHECK-NEXT: float[1] fieldName1; +// CHECK-NEXT: float fieldName0[1]; +// CHECK-NEXT: float fieldName1[1]; // CHECK-NEXT: void execute() { // CHECK-NEXT: size_t v1 = 0; -// CHECK-NEXT: float[1] v2 = fieldName0; -// CHECK-NEXT: float[1] v3 = fieldName1; // CHECK-NEXT: return; // CHECK-NEXT: } -// CHECK-EMPTY: // CHECK-NEXT: }; emitc.class final @finalClass { @@ -39,13 +36,43 @@ emitc.class final @finalClass { // CHECK-LABEL: class finalClass final { // CHECK-NEXT: public: -// CHECK-NEXT: float[1] fieldName0; -// CHECK-NEXT: float[1] fieldName1; +// CHECK-NEXT: float fieldName0[1]; +// CHECK-NEXT: float fieldName1[1]; // CHECK-NEXT: void execute() { // CHECK-NEXT: size_t v1 = 0; -// CHECK-NEXT: float[1] v2 = fieldName0; -// CHECK-NEXT: float[1] v3 = fieldName1; // CHECK-NEXT: return; // CHECK-NEXT: } -// CHECK-EMPTY: // CHECK-NEXT: }; + +emitc.class @mainClass { + emitc.field @fieldName0 : !emitc.array<2xf32> = dense<0.0> {attrs = {emitc.name_hint = "another_feature"}} + emitc.func @get_fieldName0() { + %0 = emitc.get_field @fieldName0 : !emitc.array<2xf32> + return + } +} + +// CHECK-LABEL: class mainClass { +// CHECK-NEXT: public: +// CHECK-NEXT: float fieldName0[2] = {0.0e+00f, 0.0e+00f}; +// CHECK-NEXT: void get_fieldName0() { +// CHECK-NEXT: return; +// CHECK-NEXT: } +// CHECK-NEXT: }; + +emitc.class @reflectionClass { + emitc.field @reflectionMap : !emitc.opaque<"const std::map<std::string, std::string>"> = #emitc.opaque<"{ { \22another_feature\22, \22fieldName0\22 } }"> + emitc.func @get_reflectionMap() { + %0 = emitc.get_field @reflectionMap : !emitc.opaque<"const std::map<std::string, std::string>"> + return + } +} + +// CHECK-LABEL: class reflectionClass { +// CHECK-NEXT: public: +// CHECK-NEXT: const std::map<std::string, std::string> reflectionMap = { { "another_feature", "fieldName0" } }; +// CHECK-NEXT: void get_reflectionMap() { +// CHECK-NEXT: return; +// CHECK-NEXT: } +// CHECK-NEXT: }; + diff --git a/mlir/tools/mlir-lsp-server/CMakeLists.txt b/mlir/tools/mlir-lsp-server/CMakeLists.txt index 6932e0f..0518620 100644 --- a/mlir/tools/mlir-lsp-server/CMakeLists.txt +++ b/mlir/tools/mlir-lsp-server/CMakeLists.txt @@ -2,8 +2,6 @@ set(LLVM_OPTIONAL_SOURCES null.cpp ) -get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) -get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) set(LLVM_LINK_COMPONENTS Core Support @@ -35,22 +33,11 @@ if(MLIR_INCLUDE_TESTS) endif() set(LIBS - ${conversion_libs} - ${dialect_libs} - ${extension_libs} - - MLIRAffineAnalysis - MLIRAnalysis - MLIRDialect - MLIRFuncAllExtensions MLIRLspServerLib - MLIRParser - MLIRPass - MLIRTensorAllExtensions - MLIRTransforms - MLIRTransformUtils - MLIRSupport - MLIRIR + + MLIRRegisterAllDialects + MLIRRegisterAllExtensions + MLIRRegisterAllPasses ) add_mlir_tool(mlir-lsp-server diff --git a/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp b/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp index 6a759d9..10d602f 100644 --- a/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp +++ b/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/IR/DialectRegistry.h" #include "mlir/IR/MLIRContext.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllExtensions.h" diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt index 6958fe3..7cc6e78 100644 --- a/mlir/tools/mlir-opt/CMakeLists.txt +++ b/mlir/tools/mlir-opt/CMakeLists.txt @@ -2,9 +2,6 @@ set(LLVM_OPTIONAL_SOURCES null.cpp ) -get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) -get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) -get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) set(LLVM_LINK_COMPONENTS Core Support @@ -65,21 +62,11 @@ if(MLIR_INCLUDE_TESTS) endif() set(LIBS - ${dialect_libs} - ${conversion_libs} - ${extension_libs} - MLIRAffineAnalysis - MLIRAnalysis - MLIRCastInterfaces - MLIRDialect MLIROptLib - MLIRParser - MLIRPass - MLIRTransforms - MLIRTransformUtils - MLIRSupport - MLIRIR + MLIRRegisterAllDialects + MLIRRegisterAllExtensions + MLIRRegisterAllPasses # TODO: Remove when registerAllGPUToLLVMIRTranslations is no longer # registered directly in mlir-opt.cpp. diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index 2c09753..14714c45 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -135,6 +135,7 @@ void registerTestShardSimplificationsPass(); void registerTestMultiBuffering(); void registerTestNextAccessPass(); void registerTestNVGPULowerings(); +void registerTestOneShotModuleBufferizePass(); void registerTestOpaqueLoc(); void registerTestOpLoweringPasses(); void registerTestPadFusion(); @@ -281,6 +282,7 @@ void registerTestPasses() { mlir::test::registerTestMultiBuffering(); mlir::test::registerTestNextAccessPass(); mlir::test::registerTestNVGPULowerings(); + mlir::test::registerTestOneShotModuleBufferizePass(); mlir::test::registerTestOpaqueLoc(); mlir::test::registerTestOpLoweringPasses(); mlir::test::registerTestPadFusion(); diff --git a/mlir/tools/mlir-pdll/mlir-pdll.cpp b/mlir/tools/mlir-pdll/mlir-pdll.cpp index 88a5f36..f99dcdb 100644 --- a/mlir/tools/mlir-pdll/mlir-pdll.cpp +++ b/mlir/tools/mlir-pdll/mlir-pdll.cpp @@ -201,6 +201,12 @@ int main(int argc, char **argv) { llvm::raw_string_ostream outputStrOS(outputStr); auto processFn = [&](std::unique_ptr<llvm::MemoryBuffer> chunkBuffer, raw_ostream &os) { + // Split does not guarantee null-termination. Make a copy of the buffer to + // ensure null-termination. + if (!chunkBuffer->getBuffer().ends_with('\0')) { + chunkBuffer = llvm::MemoryBuffer::getMemBufferCopy( + chunkBuffer->getBuffer(), chunkBuffer->getBufferIdentifier()); + } return processBuffer(os, std::move(chunkBuffer), outputType, includeDirs, dumpODS, includedFiles); }; diff --git a/mlir/tools/mlir-query/CMakeLists.txt b/mlir/tools/mlir-query/CMakeLists.txt index 1826397..1668bba 100644 --- a/mlir/tools/mlir-query/CMakeLists.txt +++ b/mlir/tools/mlir-query/CMakeLists.txt @@ -1,5 +1,3 @@ -get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) - if(MLIR_INCLUDE_TESTS) set(test_libs MLIRTestDialect @@ -12,8 +10,8 @@ add_mlir_tool(mlir-query llvm_update_compile_flags(mlir-query) mlir_target_link_libraries(mlir-query PRIVATE - ${dialect_libs} MLIRQueryLib + MLIRRegisterAllDialects ) target_link_libraries(mlir-query PRIVATE ${test_libs}) diff --git a/mlir/tools/mlir-reduce/CMakeLists.txt b/mlir/tools/mlir-reduce/CMakeLists.txt index d71ac86..349d75b 100644 --- a/mlir/tools/mlir-reduce/CMakeLists.txt +++ b/mlir/tools/mlir-reduce/CMakeLists.txt @@ -1,6 +1,3 @@ -get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) -get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) - if(MLIR_INCLUDE_TESTS) set(test_libs MLIRTestDialect @@ -8,12 +5,9 @@ if(MLIR_INCLUDE_TESTS) endif() set(LIBS - ${conversion_libs} - ${dialect_libs} - MLIRDialect - MLIRIR - MLIRPass MLIRReduceLib + MLIRRegisterAllDialects + MLIRRegisterAllPasses ) add_mlir_tool(mlir-reduce diff --git a/mlir/tools/mlir-rewrite/CMakeLists.txt b/mlir/tools/mlir-rewrite/CMakeLists.txt index 216491e..4120b175 100644 --- a/mlir/tools/mlir-rewrite/CMakeLists.txt +++ b/mlir/tools/mlir-rewrite/CMakeLists.txt @@ -1,21 +1,19 @@ -get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) set(LLVM_LINK_COMPONENTS Support ) set(LIBS - ${dialect_libs} - MLIRAffineAnalysis MLIRAnalysis MLIRCastInterfaces MLIRDialect + MLIRIR MLIRParser MLIRPass - MLIRTransforms - MLIRTransformUtils + MLIRRegisterAllDialects MLIRSupport - MLIRIR + MLIRTransformUtils + MLIRTransforms ) include_directories(../../../clang/include) diff --git a/mlir/tools/mlir-rewrite/mlir-rewrite.cpp b/mlir/tools/mlir-rewrite/mlir-rewrite.cpp index 87df9e1..fd8ae7e 100644 --- a/mlir/tools/mlir-rewrite/mlir-rewrite.cpp +++ b/mlir/tools/mlir-rewrite/mlir-rewrite.cpp @@ -24,6 +24,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/LineIterator.h" +#include "llvm/Support/ManagedStatic.h" #include "llvm/Support/Regex.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/ToolOutputFile.h" diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp index dbae2143..3140f12 100644 --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -495,7 +495,7 @@ void DefGen::emitCheckedBuilder() { MethodBody &body = m->body().indent(); auto scope = body.scope("return Base::getChecked(emitError, context", ");"); for (const auto ¶m : params) - body << ", " << param.getName(); + body << ", std::move(" << param.getName() << ")"; } static SmallVector<MethodParameter> diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index f35cfa6..8ea4eb7 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -1127,7 +1127,7 @@ static void genPropertyVerifier( body << formatv(fetchProperty, varName, getterName, prop.prop.getInterfaceType()); auto uniquedFn = staticVerifierEmitter.getPropConstraintFn(prop.prop); - if (uniquedFn.has_value()) + if (uniquedFn.has_value() && emitHelper.isEmittingForOp()) body << formatv(verifyPropertyUniqued, *uniquedFn, varName, prop.name); else body << formatv( @@ -4764,6 +4764,7 @@ void OpOperandAdaptorEmitter::addVerification() { FmtContext verifyCtx; populateSubstitutions(emitHelper, verifyCtx); + genPropertyVerifier(emitHelper, verifyCtx, body, staticVerifierEmitter); genAttributeVerifier(emitHelper, verifyCtx, body, staticVerifierEmitter, useProperties); diff --git a/mlir/unittests/ExecutionEngine/CMakeLists.txt b/mlir/unittests/ExecutionEngine/CMakeLists.txt index 4ef69a8..b83163e 100644 --- a/mlir/unittests/ExecutionEngine/CMakeLists.txt +++ b/mlir/unittests/ExecutionEngine/CMakeLists.txt @@ -10,14 +10,13 @@ add_mlir_unittest(MLIRExecutionEngineTests StridedMemRef.cpp Invoke.cpp ) -get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) mlir_target_link_libraries(MLIRExecutionEngineTests PRIVATE MLIRArithToLLVM MLIRMemRefToLLVM MLIRReconcileUnrealizedCasts - ${dialect_libs} + MLIRRegisterAllDialects ) target_link_libraries(MLIRExecutionEngineTests PRIVATE diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp index a55592d..fd40404 100644 --- a/mlir/unittests/IR/AttributeTest.cpp +++ b/mlir/unittests/IR/AttributeTest.cpp @@ -477,8 +477,9 @@ TEST(SubElementTest, Nested) { {strAttr, trueAttr, falseAttr, boolArrayAttr, dictAttr})); } -// Test how many times we call copy-ctor when building an attribute. -TEST(CopyCountAttr, CopyCount) { +// Test how many times we call copy-ctor when building an attribute with the +// 'get' method. +TEST(CopyCountAttr, CopyCountGet) { MLIRContext context; context.loadDialect<test::TestDialect>(); @@ -489,15 +490,35 @@ TEST(CopyCountAttr, CopyCount) { test::CopyCount::counter = 0; test::TestCopyCountAttr::get(&context, std::move(copyCount)); #ifndef NDEBUG - // One verification enabled only in assert-mode requires a copy. - EXPECT_EQ(counter1, 1); - EXPECT_EQ(test::CopyCount::counter, 1); + // One verification enabled only in assert-mode requires two copies: one for + // calling 'verifyInvariants' and one for calling 'verify' inside + // 'verifyInvariants'. + EXPECT_EQ(counter1, 2); + EXPECT_EQ(test::CopyCount::counter, 2); #else EXPECT_EQ(counter1, 0); EXPECT_EQ(test::CopyCount::counter, 0); #endif } +// Test how many times we call copy-ctor when building an attribute with the +// 'getChecked' method. +TEST(CopyCountAttr, CopyCountGetChecked) { + MLIRContext context; + context.loadDialect<test::TestDialect>(); + test::CopyCount::counter = 0; + test::CopyCount copyCount("hello"); + auto loc = UnknownLoc::get(&context); + test::TestCopyCountAttr::getChecked(loc, &context, std::move(copyCount)); + int counter1 = test::CopyCount::counter; + test::CopyCount::counter = 0; + test::TestCopyCountAttr::getChecked(loc, &context, std::move(copyCount)); + // The verifiers require two copies: one for calling 'verifyInvariants' and + // one for calling 'verify' inside 'verifyInvariants'. + EXPECT_EQ(counter1, 2); + EXPECT_EQ(test::CopyCount::counter, 2); +} + // Test stripped printing using test dialect attribute. TEST(CopyCountAttr, PrintStripped) { MLIRContext context; diff --git a/mlir/unittests/IR/SymbolTableTest.cpp b/mlir/unittests/IR/SymbolTableTest.cpp index cfc3fe0..4b3545b 100644 --- a/mlir/unittests/IR/SymbolTableTest.cpp +++ b/mlir/unittests/IR/SymbolTableTest.cpp @@ -132,4 +132,38 @@ TEST_F(ReplaceAllSymbolUsesTest, StringAttrInFuncOp) { }); } +TEST(SymbolOpInterface, Visibility) { + DialectRegistry registry; + ::test::registerTestDialect(registry); + MLIRContext context(registry); + + constexpr static StringLiteral kInput = R"MLIR( + "test.overridden_symbol_visibility"() {sym_name = "symbol_name"} : () -> () + )MLIR"; + OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(kInput, &context); + auto symOp = cast<SymbolOpInterface>(module->getBody()->front()); + + ASSERT_TRUE(symOp.isPrivate()); + ASSERT_FALSE(symOp.isPublic()); + ASSERT_FALSE(symOp.isNested()); + ASSERT_TRUE(symOp.canDiscardOnUseEmpty()); + + std::string diagStr; + context.getDiagEngine().registerHandler( + [&](Diagnostic &diag) { diagStr += diag.str(); }); + + std::string expectedDiag; + symOp.setPublic(); + expectedDiag += "'test.overridden_symbol_visibility' op cannot change " + "visibility of symbol to public"; + symOp.setNested(); + expectedDiag += "'test.overridden_symbol_visibility' op cannot change " + "visibility of symbol to nested"; + symOp.setPrivate(); + expectedDiag += "'test.overridden_symbol_visibility' op cannot change " + "visibility of symbol to private"; + + ASSERT_EQ(diagStr, expectedDiag); +} + } // namespace diff --git a/mlir/unittests/Target/LLVM/CMakeLists.txt b/mlir/unittests/Target/LLVM/CMakeLists.txt index 0daac11..0a77deb 100644 --- a/mlir/unittests/Target/LLVM/CMakeLists.txt +++ b/mlir/unittests/Target/LLVM/CMakeLists.txt @@ -1,13 +1,11 @@ set(LLVM_LINK_COMPONENTS nativecodegen BitReader) -get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) - add_mlir_unittest(MLIRTargetLLVMTests SerializeNVVMTarget.cpp SerializeROCDLTarget.cpp SerializeToLLVMBitcode.cpp DEPENDS - ${dialect_libs} + MLIRRegisterAllDialects ) mlir_target_link_libraries(MLIRTargetLLVMTests |