diff options
Diffstat (limited to 'mlir')
150 files changed, 5848 insertions, 4860 deletions
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake index 85c8027..2b88355 100644 --- a/mlir/cmake/modules/AddMLIRPython.cmake +++ b/mlir/cmake/modules/AddMLIRPython.cmake @@ -99,62 +99,6 @@ function(declare_mlir_python_sources name) endif() endfunction() -# Function: generate_type_stubs -# Turns on automatic type stub generation (via nanobind's stubgen) for extension modules. -# Arguments: -# MODULE_NAME: The name of the extension module as specified in declare_mlir_python_extension. -# DEPENDS_TARGET: The dso target corresponding to the extension module -# (e.g., something like StandalonePythonModules.extension._standaloneDialectsNanobind.dso) -# MLIR_DEPENDS_TARGET: The dso target corresponding to the main/core extension module -# (e.g., something like StandalonePythonModules.extension._mlir.dso) -# OUTPUT_DIR: The root output directory to emit the type stubs into. -# Outputs: -# NB_STUBGEN_CUSTOM_TARGET: The target corresponding to generation which other targets can depend on. -function(generate_type_stubs MODULE_NAME DEPENDS_TARGET MLIR_DEPENDS_TARGET OUTPUT_DIR) - cmake_parse_arguments(ARG - "" - "" - "OUTPUTS" - ${ARGN}) - if(EXISTS ${nanobind_DIR}/../src/stubgen.py) - set(NB_STUBGEN "${nanobind_DIR}/../src/stubgen.py") - elseif(EXISTS ${nanobind_DIR}/../stubgen.py) - set(NB_STUBGEN "${nanobind_DIR}/../stubgen.py") - else() - message(FATAL_ERROR "generate_type_stubs(): could not locate 'stubgen.py'!") - endif() - file(REAL_PATH "${NB_STUBGEN}" NB_STUBGEN) - - set(_module "${MLIR_PYTHON_PACKAGE_PREFIX}._mlir_libs.${MODULE_NAME}") - file(REAL_PATH "${MLIR_BINARY_DIR}/${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}/.." _import_path) - - set(NB_STUBGEN_CMD - "${Python_EXECUTABLE}" - "${NB_STUBGEN}" - --module - "${_module}" - -i - "${_import_path}" - --recursive - --include-private - --output-dir - "${OUTPUT_DIR}") - - list(TRANSFORM ARG_OUTPUTS PREPEND "${OUTPUT_DIR}/" OUTPUT_VARIABLE _generated_type_stubs) - add_custom_command( - OUTPUT ${_generated_type_stubs} - COMMAND ${NB_STUBGEN_CMD} - WORKING_DIRECTORY "${CMAKE_CURRENT_FUNCTION_LIST_DIR}" - DEPENDS - "${MLIR_DEPENDS_TARGET}.extension._mlir.dso" - "${MLIR_DEPENDS_TARGET}.sources.MLIRPythonSources.Core.Python" - "${DEPENDS_TARGET}" - ) - set(_name "MLIRPythonModuleStubs_${_module}") - add_custom_target("${_name}" ALL DEPENDS ${_generated_type_stubs}) - set(NB_STUBGEN_CUSTOM_TARGET "${_name}" PARENT_SCOPE) -endfunction() - # Function: declare_mlir_python_extension # Declares a buildable python extension from C++ source files. The built # module is considered a python source file and included as everything else. @@ -171,12 +115,11 @@ endfunction() # on. These will be collected for all extensions and put into an # aggregate dylib that is linked against. # PYTHON_BINDINGS_LIBRARY: Either pybind11 or nanobind. -# GENERATE_TYPE_STUBS: List of generated type stubs expected from stubgen relative to _mlir_libs. function(declare_mlir_python_extension name) cmake_parse_arguments(ARG "" "ROOT_DIR;MODULE_NAME;ADD_TO_PARENT;PYTHON_BINDINGS_LIBRARY" - "SOURCES;PRIVATE_LINK_LIBS;EMBED_CAPI_LINK_LIBS;GENERATE_TYPE_STUBS" + "SOURCES;PRIVATE_LINK_LIBS;EMBED_CAPI_LINK_LIBS" ${ARGN}) if(NOT ARG_ROOT_DIR) @@ -192,13 +135,12 @@ function(declare_mlir_python_extension name) set_target_properties(${name} PROPERTIES # Yes: Leading-lowercase property names are load bearing and the recommended # way to do this: https://gitlab.kitware.com/cmake/cmake/-/issues/19261 - EXPORT_PROPERTIES "mlir_python_SOURCES_TYPE;mlir_python_EXTENSION_MODULE_NAME;mlir_python_EMBED_CAPI_LINK_LIBS;mlir_python_DEPENDS;mlir_python_BINDINGS_LIBRARY;mlir_python_GENERATE_TYPE_STUBS" + EXPORT_PROPERTIES "mlir_python_SOURCES_TYPE;mlir_python_EXTENSION_MODULE_NAME;mlir_python_EMBED_CAPI_LINK_LIBS;mlir_python_DEPENDS;mlir_python_BINDINGS_LIBRARY" mlir_python_SOURCES_TYPE extension mlir_python_EXTENSION_MODULE_NAME "${ARG_MODULE_NAME}" mlir_python_EMBED_CAPI_LINK_LIBS "${ARG_EMBED_CAPI_LINK_LIBS}" mlir_python_DEPENDS "" mlir_python_BINDINGS_LIBRARY "${ARG_PYTHON_BINDINGS_LIBRARY}" - mlir_python_GENERATE_TYPE_STUBS "${ARG_GENERATE_TYPE_STUBS}" ) # Set the interface source and link_libs properties of the target @@ -301,32 +243,6 @@ function(add_mlir_python_modules name) ) add_dependencies(${modules_target} ${_extension_target}) mlir_python_setup_extension_rpath(${_extension_target}) - get_target_property(_generate_type_stubs ${sources_target} mlir_python_GENERATE_TYPE_STUBS) - if(_generate_type_stubs) - generate_type_stubs( - ${_module_name} - ${_extension_target} - ${name} - "${CMAKE_CURRENT_SOURCE_DIR}/mlir/_mlir_libs" - OUTPUTS "${_generate_type_stubs}" - ) - add_dependencies("${modules_target}" "${NB_STUBGEN_CUSTOM_TARGET}") - set(_stubgen_target "${MLIR_PYTHON_PACKAGE_PREFIX}.${_module_name}_type_stub_gen") - declare_mlir_python_sources( - ${_stubgen_target} - ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir/_mlir_libs" - ADD_TO_PARENT "${sources_target}" - SOURCES "${_generate_type_stubs}" - ) - set(_pure_sources_target "${modules_target}.sources.${sources_target}_type_stub_gen") - add_mlir_python_sources_target(${_pure_sources_target} - INSTALL_COMPONENT ${modules_target} - INSTALL_DIR "${ARG_INSTALL_PREFIX}/_mlir_libs" - OUTPUT_DIRECTORY "${ARG_ROOT_PREFIX}/_mlir_libs" - SOURCES_TARGETS ${_stubgen_target} - ) - add_dependencies(${modules_target} ${_pure_sources_target}) - endif() else() message(SEND_ERROR "Unrecognized source type '${_source_type}' for python source target ${sources_target}") return() @@ -762,28 +678,26 @@ function(add_mlir_python_extension libname extname) # the super project handle compile options as it wishes. get_property(NB_LIBRARY_TARGET_NAME TARGET ${libname} PROPERTY LINK_LIBRARIES) target_compile_options(${NB_LIBRARY_TARGET_NAME} - PRIVATE - -Wall -Wextra -Wpedantic - -Wno-c++98-compat-extra-semi - -Wno-cast-qual - -Wno-covered-switch-default - -Wno-deprecated-literal-operator - -Wno-nested-anon-types - -Wno-unused-parameter - -Wno-zero-length-array - ${eh_rtti_enable}) + PRIVATE + -Wall -Wextra -Wpedantic + -Wno-c++98-compat-extra-semi + -Wno-cast-qual + -Wno-covered-switch-default + -Wno-nested-anon-types + -Wno-unused-parameter + -Wno-zero-length-array + ${eh_rtti_enable}) target_compile_options(${libname} - PRIVATE - -Wall -Wextra -Wpedantic - -Wno-c++98-compat-extra-semi - -Wno-cast-qual - -Wno-covered-switch-default - -Wno-deprecated-literal-operator - -Wno-nested-anon-types - -Wno-unused-parameter - -Wno-zero-length-array - ${eh_rtti_enable}) + PRIVATE + -Wall -Wextra -Wpedantic + -Wno-c++98-compat-extra-semi + -Wno-cast-qual + -Wno-covered-switch-default + -Wno-nested-anon-types + -Wno-unused-parameter + -Wno-zero-length-array + ${eh_rtti_enable}) endif() if(APPLE) diff --git a/mlir/docs/BytecodeFormat.md b/mlir/docs/BytecodeFormat.md index 9846df8..f50ddeb 100644 --- a/mlir/docs/BytecodeFormat.md +++ b/mlir/docs/BytecodeFormat.md @@ -125,7 +125,9 @@ lazy-loading, and more. Each section contains a Section ID, whose high bit indicates if the section has alignment requirements, a length (which allows for skipping over the section), and an optional alignment. When an alignment is present, a variable number of padding bytes (0xCB) may appear before the section -data. The alignment of a section must be a power of 2. The input bytecode buffer must satisfy the same alignment requirements as those of every section. +data. The alignment of a section must be a power of 2. +The input bytecode buffer must satisfy the same alignment requirements as +those of every section. ## MLIR Encoding diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md index 7070351..5ae3515 100644 --- a/mlir/docs/DialectConversion.md +++ b/mlir/docs/DialectConversion.md @@ -285,9 +285,13 @@ conversions. A context-unaware conversion function converts a `Type` into a `Type`. A context-aware conversion function converts a `Value` into a type. The latter allows users to customize type conversion rules based on the IR. -Note: When there is at least one context-aware type conversion function, the -result of type conversions can no longer be cached, which can increase -compilation time. Use this feature with caution! +Note: context-aware type conversion functions impact the ability of the +framework to cache the conversion result. In the absence of a context-aware +conversion, all context-free type conversions can be cached. Otherwise only the +context-free conversions added after a context-aware type conversion can be +cached (conversions are applied in reverse order). +As such it is advised to add context-aware conversions as early as possible in +the sequence of `addConversion` calls (so that they apply last). A `materialization` describes how a list of values should be converted to a list of values with specific types. An important distinction from a diff --git a/mlir/docs/Dialects/IRDL.md b/mlir/docs/Dialects/IRDL.md new file mode 100644 index 0000000..c09457a --- /dev/null +++ b/mlir/docs/Dialects/IRDL.md @@ -0,0 +1,123 @@ +# 'irdl' Dialect + +[TOC] + +## Basics + +The IRDL (*Intermediate Representation Definition Language*) dialect allows +defining MLIR dialects as MLIR programs. Nested operations are used to +represent dialect structure: dialects contain operations, types and +attributes, themselves containing type parameters, operands, results, etc. +Each of those concepts are mapped to MLIR operations in the IRDL dialect, as +shown in the example dialect below: + +```mlir +irdl.dialect @cmath { + irdl.type @complex { + %0 = irdl.is f32 + %1 = irdl.is f64 + %2 = irdl.any_of(%0, %1) + irdl.parameters(%2) + } + + irdl.operation @mul { + %0 = irdl.is f32 + %1 = irdl.is f64 + %2 = irdl.any_of(%0, %1) + %3 = irdl.parametric @cmath::@complex<%2> + irdl.operands(%3, %3) + irdl.results(%3) + } +} +``` + +This program defines a `cmath` dialect that defines a `complex` type, and +a `mul` operation. Both express constraints over their parameters using +SSA constraint operations. Informally, one can see those SSA values as +constraint variables that evaluate to a single type at constraint +evaluation. For example, the result of the `irdl.any_of` stored in `%2` +in the `mul` operation will collapse into either `f32` or `f64` for the +entirety of this instance of `mul` constraint evaluation. As such, +both operands and the result of `mul` must be of equal type (and not just +satisfy the same constraint). For more information, see +[constraints and combinators](#constraints-and-combinators). + +In order to simplify the dialect, IRDL variables are handles over +`mlir::Attribute`. In order to support manipulating `mlir::Type`, +IRDL wraps all types in an `mlir::TypeAttr` attribute. + +## Principles + +The core principles of IRDL are the following, in no particular order: + +- **Portability.** IRDL dialects should be self-contained, such that dialects + can be easily distributed with minimal assumptions on which compiler + infrastructure (or which commit of MLIR) is used. +- **Introspection.** The IRDL dialect definition mechanism should strive + towards offering as much introspection abilities as possible. Dialects + should be as easy to manipulate, generate, and analyze as possible. +- **Runtime declaration support**. The specification of IRDL dialects should + offer the ability to have them be loaded at runtime, via dynamic registration + or JIT compilation. Compatibility with dynamic workflows should not hinder + the ability to compile IRDL dialects into ahead-of-time declarations. +- **Reliability.** Concepts in IRDL should be consistent and predictable, with + as much focus on high-level simplicity as possible. Consequently, IRDL + definitions that verify should work out of the box, and those that do not + verify should provide clear and understandable errors in all circumstances. + +While IRDL simplifies IR definition, it remains an IR itself and thus does not +require to be comfortably user-writeable. + +## Constraints and combinators + +Attribute, type and operation verifiers are expressed in terms of constraint +variables. Constraint variables are defined as the results of constraint +operations (like `irdl.is` or constraint combinators). + +Constraint variables act as variables: as such, matching against the same +constraint variable multiple times can only succeed if the matching type or +attribute is the same as the one that previously matched. In the following +example: + +```mlir +irdl.type @foo { + %ty = irdl.any_type + irdl.parameters(param1: %ty, param2: %ty) +} +``` + +only types with two equal parameters will successfully match (`foo<i32, i32>` +would match while `foo<i32, i64>` would fail, even though both i32 and i64 +individually satisfy the `irdl.any_type` constraint). This constraint variable +mechanism allows to easily express a requirement on type or attribute equality. + +To declare more complex verifiers, IRDL provides constraint-combinator +operations such as `irdl.any_of`, `irdl.all_of` or `irdl.parametric`. These +combinators can be used to combine constraint variables into new constraint +variables. Like all uses of constraint variables, their constraint variable +operands enforce equality of matched types of attributes as explained in the +previous paragraph. + +## Motivating use cases + +To illustrate the rationale behind IRDL, the following list describes examples +of intended use cases for IRDL, in no particular order: + +- **Fuzzer generation.** With declarative verifier definitions, it is possible + to compile IRDL dialects into compiler fuzzers that generate only programs + passing verifiers. +- **Portable dialects between compiler infrastructures.** Some compiler + infrastructures are independent from MLIR but are otherwise IR-compatible. + Portable IRDL dialects allow to share the dialect definitions between MLIR + and other compiler infrastructures without needing to maintain multiple + potentially out-of-sync definitions. +- **Dialect simplification.** Because IRDL definitions can easily be + mechanically modified, it is possible to simplify the definition of dialects + based on which operations are actually used, leading to smaller compilers. +- **SMT analysis.** Because IRDL dialect definitions are declarative, their + definition can be lowered to alternative representations like SMT, allowing + analysis of the behavior of transforms taking verifiers into account. + +## Operations + +[include "Dialects/IRDLOps.md"] diff --git a/mlir/examples/standalone/python/CMakeLists.txt b/mlir/examples/standalone/python/CMakeLists.txt index cb10518..a0eca9c 100644 --- a/mlir/examples/standalone/python/CMakeLists.txt +++ b/mlir/examples/standalone/python/CMakeLists.txt @@ -39,7 +39,6 @@ declare_mlir_python_extension(StandalonePythonSources.NanobindExtension EMBED_CAPI_LINK_LIBS StandaloneCAPI PYTHON_BINDINGS_LIBRARY nanobind - GENERATE_TYPE_STUBS ) diff --git a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h index 2250db8..c7c405e1 100644 --- a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h @@ -229,6 +229,13 @@ private: /// considered an external callable. Operation *analysisScope; + /// Whether the analysis scope has a symbol table. This is used to avoid + /// resolving callables outside the analysis scope. + /// It is updated when recursing into a region in case where the top-level + /// operation does not have a symbol table, but one is encountered in a nested + /// region. + bool hasSymbolTable = false; + /// A symbol table used for O(1) symbol lookups during simplification. SymbolTableCollection symbolTable; }; diff --git a/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h b/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h index 291b809..220da0a 100644 --- a/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h +++ b/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h @@ -13,6 +13,7 @@ #include <memory> namespace mlir { +class Pass; class LLVMTypeConverter; class ConversionTarget; class RewritePatternSet; @@ -42,16 +43,6 @@ void populateGpuToROCDLConversionPatterns(const LLVMTypeConverter &converter, /// Configure target to convert from the GPU dialect to ROCDL. void configureGpuToROCDLConversionLegality(ConversionTarget &target); -/// Creates a pass that lowers GPU dialect operations to ROCDL counterparts. The -/// index bitwidth used for the lowering of the device side index computations -/// is configurable. -std::unique_ptr<OperationPass<gpu::GPUModuleOp>> -createLowerGpuOpsToROCDLOpsPass( - const std::string &chipset = "gfx900", - unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout, - bool useBarePtrCallConv = false, - gpu::amd::Runtime runtime = gpu::amd::Runtime::Unknown); - } // namespace mlir #endif // MLIR_CONVERSION_GPUTOROCDL_GPUTOROCDLPASS_H_ diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 44dc1bc..1a37d05 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -624,7 +624,6 @@ def ConvertGpuOpsToNVVMOps : Pass<"convert-gpu-to-nvvm", "gpu::GPUModuleOp"> { def ConvertGpuOpsToROCDLOps : Pass<"convert-gpu-to-rocdl", "gpu::GPUModuleOp"> { let summary = "Generate ROCDL operations for gpu operations"; - let constructor = "mlir::createLowerGpuOpsToROCDLOpsPass()"; let dependentDialects = [ "ROCDL::ROCDLDialect", "amdgpu::AMDGPUDialect", diff --git a/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h b/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h index 7ffdbd4..f591407 100644 --- a/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h +++ b/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h @@ -11,6 +11,7 @@ #include <memory> namespace mlir { +class ConversionTarget; class DialectRegistry; class LLVMTypeConverter; class RewritePatternSet; @@ -19,7 +20,8 @@ class Pass; #define GEN_PASS_DECL_CONVERTXEVMTOLLVMPASS #include "mlir/Conversion/Passes.h.inc" -void populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns); +void populateXeVMToLLVMConversionPatterns(ConversionTarget &target, + RewritePatternSet &patterns); void registerConvertXeVMToLLVMInterface(DialectRegistry ®istry); } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h index 0dd8de4..df4145d 100644 --- a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h +++ b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h @@ -153,9 +153,17 @@ public: MemRefDependenceGraph(Block &block) : block(block) {} - // Initializes the dependence graph based on operations in `block'. - // Returns true on success, false otherwise. - bool init(); + // Initializes the data dependence graph by iterating over the operations of + // the MDG's `block`. A `Node` is created for every top-level op except for + // side-effect-free operations with zero results and no regions. Assigns each + // node in the graph a node id based on the order in block. Fails if certain + // kinds of operations, for which `Node` creation isn't supported, are + // encountered (unknown region holding ops). If `fullAffineDependences` is + // set, affine memory dependence analysis is performed before concluding that + // conflicting affine memory accesses lead to a dependence check; otherwise, a + // pair of conflicting affine memory accesses (where one of them is a store + // and they are to the same memref) always leads to an edge (conservatively). + bool init(bool fullAffineDependences = true); // Returns the graph node for 'id'. const Node *getNode(unsigned id) const; diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h index 1ef5370..e735651 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h @@ -12,7 +12,6 @@ #include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" -#include "mlir/Interfaces/CopyOpInterface.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SubsetOpInterface.h" diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td index 271b420..6724d4c 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td @@ -18,7 +18,6 @@ include "mlir/Interfaces/DestinationStyleOpInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/SubsetOpInterface.td" -include "mlir/Interfaces/CopyOpInterface.td" class Bufferization_Op<string mnemonic, list<Trait> traits = []> : Op<Bufferization_Dialect, mnemonic, traits>; @@ -171,7 +170,6 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor", //===----------------------------------------------------------------------===// def Bufferization_CloneOp : Bufferization_Op<"clone", [ - CopyOpInterface, MemoryEffectsOpInterface, DeclareOpInterfaceMethods<AllocationOpInterface, ["buildDealloc", "buildClone"]> ]> { diff --git a/mlir/include/mlir/Dialect/IRDL/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/IRDL/IR/CMakeLists.txt index 9dcadec..1c08f0f 100644 --- a/mlir/include/mlir/Dialect/IRDL/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/IRDL/IR/CMakeLists.txt @@ -1,5 +1,5 @@ add_mlir_dialect(IRDL irdl) -add_mlir_doc(IRDLOps IRDL Dialects/ -gen-dialect-doc -dialect=irdl) +add_mlir_doc(IRDLOps IRDLOps Dialects/ -gen-op-doc -dialect=irdl) # Add IRDL interfaces set(LLVM_TARGET_DEFINITIONS IRDLInterfaces.td) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td index ac99b8a..a8c9ef7 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td @@ -1233,6 +1233,47 @@ def LLVM_TBAATagArrayAttr } //===----------------------------------------------------------------------===// +// MMRATagAttr +//===----------------------------------------------------------------------===// + +def LLVM_MMRATagAttr : LLVM_Attr<"MMRATag", "mmra_tag"> { + let parameters = (ins + StringRefParameter<>:$prefix, + StringRefParameter<>:$suffix + ); + + let summary = "MLIR wrapper around a prefix:suffix MMRA tag"; + + let description = [{ + Defines a single memory model relaxation annotation (MMRA) entry + with prefix `$prefix` and suffix `$suffix`. This corresponds directly + to a LLVM `!{prefix, suffix}` metadata tuple, which is often written + `prefix:shuffix` as shorthand. + + Example: + ```mlir + #mmra_tag = #llvm.mmmra_tag<"amdgpu-synchronize-as":"local"> + #mmra_tag1 = #llvm.mmra_tag<"foo":"bar"> + ``` + + Either one MMRA tag or an array of them may be added to any LLVM + operation that operates on memory. + + ```mlir + %v = llvm.load %ptr {llvm.mmra = #mmra_tag} : !llvm.ptr -> i8 + llvm.store %v, %ptr2 {llvm.mmra [#mmra_tag, #mmra_tag1]} : i8, !llvm.ptr + ``` + + See the following link for more details: + https://llvm.org/docs/MemoryModelRelaxationAnnotations.html + }]; + + let assemblyFormat = "`<` $prefix `` `:` `` $suffix `>`"; + + let genMnemonicAlias = 1; +} + +//===----------------------------------------------------------------------===// // ConstantRangeAttr //===----------------------------------------------------------------------===// def LLVM_ConstantRangeAttr : LLVM_Attr<"ConstantRange", "constant_range"> { diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td index ab0462f..d2d7131 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td @@ -36,6 +36,7 @@ def LLVM_Dialect : Dialect { static StringRef getIdentAttrName() { return "llvm.ident"; } static StringRef getModuleFlags() { return "llvm.module.flags"; } static StringRef getCommandlineAttrName() { return "llvm.commandline"; } + static StringRef getMmraAttrName() { return "llvm.mmra"; } /// Names of llvm parameter attributes. static StringRef getAlignAttrName() { return "llvm.align"; } diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h index eb4e381..9de6d8f 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h @@ -22,7 +22,6 @@ #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" -#include "mlir/Interfaces/CopyOpInterface.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index a19cce4..8f3232f 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -2017,8 +2017,8 @@ def TileReductionUsingForallOp : DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes, OptionalAttr<DeviceMappingArrayAttr>:$mapping); let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op, - TransformHandleTypeInterface:$split_linalg_op, - TransformHandleTypeInterface:$combining_linalg_op, + TransformHandleTypeInterface:$split_op, + TransformHandleTypeInterface:$combining_op, TransformHandleTypeInterface:$forall_op); let builders = [ @@ -2042,7 +2042,7 @@ def TileReductionUsingForallOp : let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( ::mlir::transform::TransformRewriter &rewriter, - ::mlir::linalg::LinalgOp target, + Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); }]; diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h index ac383ab4..bdec699 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h @@ -16,7 +16,6 @@ #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" -#include "mlir/Interfaces/CopyOpInterface.h" #include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/MemorySlotInterfaces.h" diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index d6b7a97..513a9a1 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -13,7 +13,6 @@ include "mlir/Dialect/Arith/IR/ArithBase.td" include "mlir/Dialect/MemRef/IR/MemRefBase.td" include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" -include "mlir/Interfaces/CopyOpInterface.td" include "mlir/Interfaces/InferIntRangeInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/MemorySlotInterfaces.td" @@ -530,7 +529,7 @@ def MemRef_CastOp : MemRef_Op<"cast", [ // CopyOp //===----------------------------------------------------------------------===// -def CopyOp : MemRef_Op<"copy", [CopyOpInterface, SameOperandsElementType, +def CopyOp : MemRef_Op<"copy", [SameOperandsElementType, SameOperandsShape]> { let description = [{ diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h index faf820d..6a92b13 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h @@ -40,7 +40,7 @@ struct DeviceTypeClauseOps { /// Clauses that correspond to operations other than omp.target, but might have /// to be evaluated outside of a parent target region. using HostEvaluatedOperands = - detail::Clauses<LoopRelatedClauseOps, NumTeamsClauseOps, + detail::Clauses<CollapseClauseOps, LoopRelatedClauseOps, NumTeamsClauseOps, NumThreadsClauseOps, ThreadLimitClauseOps>; // TODO: Add `indirect` clause. diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index 311c57f..5f40abe 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -210,6 +210,23 @@ class OpenMP_BindClauseSkip< def OpenMP_BindClause : OpenMP_BindClauseSkip<>; //===----------------------------------------------------------------------===// +// V5.2: [4.4.3] `collapse` clause +//===----------------------------------------------------------------------===// + +class OpenMP_CollapseClauseSkip< + bit traits = false, bit arguments = false, bit assemblyFormat = false, + bit description = false, bit extraClassDeclaration = false + > : OpenMP_Clause<traits, arguments, assemblyFormat, description, + extraClassDeclaration> { + let arguments = (ins + ConfinedAttr<DefaultValuedOptionalAttr<I64Attr, "1">, [IntMinValue<1>]> + :$collapse_num_loops + ); +} + +def OpenMP_CollapseClause : OpenMP_CollapseClauseSkip<>; + +//===----------------------------------------------------------------------===// // V5.2: [5.7.2] `copyprivate` clause //===----------------------------------------------------------------------===// @@ -1386,6 +1403,22 @@ class OpenMP_ThreadLimitClauseSkip< def OpenMP_ThreadLimitClause : OpenMP_ThreadLimitClauseSkip<>; //===----------------------------------------------------------------------===// +// V5.2: [9.1.1] `sizes` clause +//===----------------------------------------------------------------------===// + +class OpenMP_TileSizesClauseSkip< + bit traits = false, bit arguments = false, bit assemblyFormat = false, + bit description = false, bit extraClassDeclaration = false + > : OpenMP_Clause<traits, arguments, assemblyFormat, description, + extraClassDeclaration> { + let arguments = (ins + OptionalAttr<DenseI64ArrayAttr>:$tile_sizes + ); +} + +def OpenMP_TileSizesClause : OpenMP_TileSizesClauseSkip<>; + +//===----------------------------------------------------------------------===// // V5.2: [12.1] `untied` clause //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 2548a8a..830b36f 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -614,13 +614,18 @@ def WorkshareLoopWrapperOp : OpenMP_Op<"workshare.loop_wrapper", traits = [ def LoopNestOp : OpenMP_Op<"loop_nest", traits = [ RecursiveMemoryEffects, SameVariadicOperandSize ], clauses = [ - OpenMP_LoopRelatedClause + OpenMP_CollapseClause, + OpenMP_LoopRelatedClause, + OpenMP_TileSizesClause ], singleRegion = true> { let summary = "rectangular loop nest"; let description = [{ - This operation represents a collapsed rectangular loop nest. For each - rectangular loop of the nest represented by an instance of this operation, - lower and upper bounds, as well as a step variable, must be defined. + This operation represents a rectangular loop nest which may be collapsed + and/or tiled. For each rectangular loop of the nest represented by an + instance of this operation, lower and upper bounds, as well as a step + variable, must be defined. The collapse clause specifies how many loops + that should be collapsed (1 if no collapse is done) after any tiling is + performed. The tiling sizes is represented by the tile sizes clause. The lower and upper bounds specify a half-open range: the range includes the lower bound but does not include the upper bound. If the `loop_inclusive` @@ -633,7 +638,7 @@ def LoopNestOp : OpenMP_Op<"loop_nest", traits = [ `loop_steps` arguments. ```mlir - omp.loop_nest (%i1, %i2) : i32 = (%c0, %c0) to (%c10, %c10) step (%c1, %c1) { + omp.loop_nest (%i1, %i2) : i32 = (%c0, %c0) to (%c10, %c10) step (%c1, %c1) collapse(2) tiles(5,5) { %a = load %arrA[%i1, %i2] : memref<?x?xf32> %b = load %arrB[%i1, %i2] : memref<?x?xf32> %sum = arith.addf %a, %b : f32 diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 77e26cc..65ba7e0 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -972,10 +972,6 @@ def Vector_ScalableInsertOp : VectorType getDestVectorType() { return ::llvm::cast<VectorType>(getDest().getType()); } - /// Wrapper for getResult, which replaced getRes. - [[deprecated("Use getResult instead!")]] ::mlir::Value getRes() { - return getResult(); - } }]; } @@ -1027,10 +1023,6 @@ def Vector_ScalableExtractOp : VectorType getResultVectorType() { return ::llvm::cast<VectorType>(getResult().getType()); } - /// Wrapper for getResult, which replaced getRes. - [[deprecated("Use getResult instead!")]] ::mlir::Value getRes() { - return getResult(); - } }]; } @@ -1085,10 +1077,6 @@ def Vector_InsertStridedSliceOp : return ::llvm::cast<IntegerAttr>(attr).getInt() != 1; }); } - /// Wrapper for getResult, which replaced getRes. - [[deprecated("Use getResult instead!")]] ::mlir::Value getRes() { - return getResult(); - } }]; let hasFolder = 1; diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td index 07a4117..72a69a0 100644 --- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td @@ -265,6 +265,17 @@ def ApplyUnrollFromElementsPatternsOp : Op<Transform_Dialect, let assemblyFormat = "attr-dict"; } +def ApplyUnrollToElementsPatternsOp : Op<Transform_Dialect, + "apply_patterns.vector.unroll_to_elements", + [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> { + let description = [{ + Indicates that vector to_elements operations should be unrolled + along the outermost dimension. + }]; + + let assemblyFormat = "attr-dict"; +} + def ApplyLowerScanPatternsOp : Op<Transform_Dialect, "apply_patterns.vector.lower_scan", [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> { diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h index 47f9611..f56124c 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h @@ -313,6 +313,12 @@ void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns, /// Populate the pattern set with the following patterns: /// +/// [UnrollToElements] +void populateVectorToElementsLoweringPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); + +/// Populate the pattern set with the following patterns: +/// /// [ContractionOpToMatmulOpLowering] /// Lowers `vector.contract` to `llvm.intr.matrix.multiply`. /// diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h index ace2699..97163c4 100644 --- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h +++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h @@ -255,6 +255,12 @@ using UnrollVectorOpFn = LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter, UnrollVectorOpFn unrollFn); +/// Generic utility for unrolling values of type vector<NxAxBx...> +/// to N values of type vector<AxBx...> using vector.extract. If the input +/// is rank-1 or has leading scalable dimension, failure is returned. +FailureOr<SmallVector<Value>> unrollVectorValue(TypedValue<VectorType>, + RewriterBase &); + } // namespace vector /// Constructs a permutation map of invariant memref indices to vector diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt index 20cc267..2add220 100644 --- a/mlir/include/mlir/Interfaces/CMakeLists.txt +++ b/mlir/include/mlir/Interfaces/CMakeLists.txt @@ -1,7 +1,6 @@ add_mlir_interface(CallInterfaces) add_mlir_interface(CastInterfaces) add_mlir_interface(ControlFlowInterfaces) -add_mlir_interface(CopyOpInterface) add_mlir_interface(DerivedAttributeOpInterface) add_mlir_interface(DestinationStyleOpInterface) add_mlir_interface(FunctionInterfaces) diff --git a/mlir/include/mlir/Interfaces/CopyOpInterface.h b/mlir/include/mlir/Interfaces/CopyOpInterface.h deleted file mode 100644 index 2f38eb3..0000000 --- a/mlir/include/mlir/Interfaces/CopyOpInterface.h +++ /dev/null @@ -1,21 +0,0 @@ -//===- CopyOpInterface.h - copy operations interface ----------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements the operation interface for copy-like operations. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_INTERFACES_COPYOPINTERFACE_H_ -#define MLIR_INTERFACES_COPYOPINTERFACE_H_ - -#include "mlir/IR/OpDefinition.h" - -/// Include the generated interface declarations. -#include "mlir/Interfaces/CopyOpInterface.h.inc" - -#endif // MLIR_INTERFACES_COPYOPINTERFACE_H_ diff --git a/mlir/include/mlir/Interfaces/CopyOpInterface.td b/mlir/include/mlir/Interfaces/CopyOpInterface.td deleted file mode 100644 index f6c5a6f..0000000 --- a/mlir/include/mlir/Interfaces/CopyOpInterface.td +++ /dev/null @@ -1,38 +0,0 @@ -//===- CopyOpInterface.td - Copy operation interface -------*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Defines the interface for copy-like operations. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_INTERFACES_COPYOPINTERFACE -#define MLIR_INTERFACES_COPYOPINTERFACE - -include "mlir/IR/OpBase.td" - -def CopyOpInterface : OpInterface<"CopyOpInterface"> { - let description = [{ - A copy-like operation is one that copies from source value to target value. - }]; - let cppNamespace = "::mlir"; - - let methods = [ - InterfaceMethod< - /*desc=*/"Returns the source value for this copy operation", - /*retTy=*/"::mlir::Value", - /*methodName=*/"getSource" - >, - InterfaceMethod< - /*desc=*/"Returns the target value for this copy operation", - /*retTy=*/"::mlir::Value", - /*methodName=*/"getTarget" - > - ]; -} - -#endif // MLIR_INTERFACES_COPYOPINTERFACE diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h index d168735..58852239 100644 --- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h +++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h @@ -217,7 +217,7 @@ public: /// `closedUB` is set to "true", upper bounds are also closed. static FailureOr<int64_t> computeConstantBound(presburger::BoundType type, const Variable &var, - StopConditionFn stopCondition = nullptr, + const StopConditionFn &stopCondition = nullptr, bool closedUB = false); /// Compute a constant delta between the given two values. Return "failure" @@ -282,18 +282,18 @@ public: /// /// Slice are non-overlapping if the above constraint is not satisfied for /// at least one dimension. - static FailureOr<bool> areOverlappingSlices(MLIRContext *ctx, - HyperrectangularSlice slice1, - HyperrectangularSlice slice2); + static FailureOr<bool> + areOverlappingSlices(MLIRContext *ctx, const HyperrectangularSlice &slice1, + const HyperrectangularSlice &slice2); /// Return "true" if the given slices are guaranteed to be equivalent. /// Return "false" if the given slices are guaranteed to be non-equivalent. /// Return "failure" if unknown. /// /// Slices are equivalent if their offsets, sizes and strices are equal. - static FailureOr<bool> areEquivalentSlices(MLIRContext *ctx, - HyperrectangularSlice slice1, - HyperrectangularSlice slice2); + static FailureOr<bool> + areEquivalentSlices(MLIRContext *ctx, const HyperrectangularSlice &slice1, + const HyperrectangularSlice &slice2); /// Add a bound for the given index-typed value or shaped value. This function /// returns a builder that adds the bound. @@ -326,7 +326,8 @@ protected: /// An index-typed value or the dimension of a shaped-type value. using ValueDim = std::pair<Value, int64_t>; - ValueBoundsConstraintSet(MLIRContext *ctx, StopConditionFn stopCondition, + ValueBoundsConstraintSet(MLIRContext *ctx, + const StopConditionFn &stopCondition, bool addConservativeSemiAffineBounds = false); /// Return "true" if, based on the current state of the constraint system, @@ -401,7 +402,8 @@ protected: /// Insert the given affine map and its bound operands as a new column in the /// constraint system. Return the position of the new column. Any operands /// that were not analyzed yet are put on the worklist. - int64_t insert(AffineMap map, ValueDimList operands, bool isSymbol = true); + int64_t insert(AffineMap map, const ValueDimList &operands, + bool isSymbol = true); int64_t insert(const Variable &var, bool isSymbol = true); /// Project out the given column in the constraint set. diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td index c72ca58..6838c16 100644 --- a/mlir/include/mlir/Interfaces/VectorInterfaces.td +++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td @@ -187,12 +187,6 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> { return inBounds; } - /// Wrapper for getBase, which replaced getSource. - [[deprecated("Use getBase instead!")]] - ::mlir::Value getSource() { - return $_op.getBase(); - } - /// Return the number of leading shaped dimensions (of the "source" operand) /// that do not participate in the permutation map. unsigned getLeadingShapedRank() { diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h index 9e57037..f0514d8 100644 --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -323,21 +323,22 @@ public: /// Requires: all result types are known. const InferredResultType &getInferredResultType(int index) const; - /// Pair consisting kind of argument and index into operands or attributes. - struct OperandOrAttribute { - enum class Kind { Operand, Attribute }; - OperandOrAttribute(Kind kind, int index) { - packed = (index << 1) | (kind == Kind::Attribute); + /// Pair consisting kind of argument and index into operands, attributes, or + /// properties. + struct OperandAttrOrProp { + enum class Kind { Operand = 0x0, Attribute = 0x1, Property = 0x2 }; + OperandAttrOrProp(Kind kind, int index) { + packed = (index << 2) | static_cast<int>(kind); } - int operandOrAttributeIndex() const { return (packed >> 1); } - Kind kind() { return (packed & 0x1) ? Kind::Attribute : Kind::Operand; } + int operandOrAttributeIndex() const { return (packed >> 2); } + Kind kind() const { return static_cast<Kind>(packed & 0x3); } private: int packed; }; - /// Returns the OperandOrAttribute corresponding to the index. - OperandOrAttribute getArgToOperandOrAttribute(int index) const; + /// Returns the OperandAttrOrProp corresponding to the index. + OperandAttrOrProp getArgToOperandAttrOrProp(int index) const; /// Returns the builders of this operation. ArrayRef<Builder> getBuilders() const { return builders; } @@ -405,8 +406,8 @@ private: /// The argument with the same type as the result. SmallVector<InferredResultType> resultTypeMapping; - /// Map from argument to attribute or operand number. - SmallVector<OperandOrAttribute, 4> attrOrOperandMapping; + /// Map from argument to attribute, property, or operand number. + SmallVector<OperandAttrOrProp, 4> attrPropOrOperandMapping; /// The builders of this operator. SmallVector<Builder> builders; diff --git a/mlir/include/mlir/Tools/lsp-server-support/Logging.h b/mlir/include/mlir/Tools/lsp-server-support/Logging.h deleted file mode 100644 index 9b090d0..0000000 --- a/mlir/include/mlir/Tools/lsp-server-support/Logging.h +++ /dev/null @@ -1,65 +0,0 @@ -//===- Logging.h - MLIR LSP Server Logging ----------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_TOOLS_LSPSERVERSUPPORT_LOGGING_H -#define MLIR_TOOLS_LSPSERVERSUPPORT_LOGGING_H - -#include "mlir/Support/LLVM.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/FormatVariadic.h" -#include <memory> -#include <mutex> - -namespace mlir { -namespace lsp { - -/// This class represents the main interface for logging, and allows for -/// filtering logging based on different levels of severity or significance. -class Logger { -public: - /// The level of significance for a log message. - enum class Level { Debug, Info, Error }; - - /// Set the severity level of the logger. - static void setLogLevel(Level logLevel); - - /// Initiate a log message at various severity levels. These should be called - /// after a call to `initialize`. - template <typename... Ts> - static void debug(const char *fmt, Ts &&...vals) { - log(Level::Debug, fmt, llvm::formatv(fmt, std::forward<Ts>(vals)...)); - } - template <typename... Ts> - static void info(const char *fmt, Ts &&...vals) { - log(Level::Info, fmt, llvm::formatv(fmt, std::forward<Ts>(vals)...)); - } - template <typename... Ts> - static void error(const char *fmt, Ts &&...vals) { - log(Level::Error, fmt, llvm::formatv(fmt, std::forward<Ts>(vals)...)); - } - -private: - Logger() = default; - - /// Return the main logger instance. - static Logger &get(); - - /// Start a log message with the given severity level. - static void log(Level logLevel, const char *fmt, - const llvm::formatv_object_base &message); - - /// The minimum logging level. Messages with lower level are ignored. - Level logLevel = Level::Error; - - /// A mutex used to guard logging. - std::mutex mutex; -}; -} // namespace lsp -} // namespace mlir - -#endif // MLIR_TOOLS_LSPSERVERSUPPORT_LOGGING_H diff --git a/mlir/include/mlir/Tools/lsp-server-support/Protocol.h b/mlir/include/mlir/Tools/lsp-server-support/Protocol.h deleted file mode 100644 index cc06dbf..0000000 --- a/mlir/include/mlir/Tools/lsp-server-support/Protocol.h +++ /dev/null @@ -1,1257 +0,0 @@ -//===--- Protocol.h - Language Server Protocol Implementation ---*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file contains structs based on the LSP specification at -// https://microsoft.github.io/language-server-protocol/specification -// -// This is not meant to be a complete implementation, new interfaces are added -// when they're needed. -// -// Each struct has a toJSON and fromJSON function, that converts between -// the struct and a JSON representation. (See JSON.h) -// -// Some structs also have operator<< serialization. This is for debugging and -// tests, and is not generally machine-readable. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_TOOLS_LSPSERVERSUPPORT_PROTOCOL_H -#define MLIR_TOOLS_LSPSERVERSUPPORT_PROTOCOL_H - -#include "mlir/Support/LLVM.h" -#include "llvm/Support/JSON.h" -#include "llvm/Support/SourceMgr.h" -#include "llvm/Support/raw_ostream.h" -#include <bitset> -#include <optional> -#include <string> -#include <utility> -#include <vector> - -namespace mlir { -namespace lsp { - -enum class ErrorCode { - // Defined by JSON RPC. - ParseError = -32700, - InvalidRequest = -32600, - MethodNotFound = -32601, - InvalidParams = -32602, - InternalError = -32603, - - ServerNotInitialized = -32002, - UnknownErrorCode = -32001, - - // Defined by the protocol. - RequestCancelled = -32800, - ContentModified = -32801, - RequestFailed = -32803, -}; - -/// Defines how the host (editor) should sync document changes to the language -/// server. -enum class TextDocumentSyncKind { - /// Documents should not be synced at all. - None = 0, - - /// Documents are synced by always sending the full content of the document. - Full = 1, - - /// Documents are synced by sending the full content on open. After that only - /// incremental updates to the document are sent. - Incremental = 2, -}; - -//===----------------------------------------------------------------------===// -// LSPError -//===----------------------------------------------------------------------===// - -/// This class models an LSP error as an llvm::Error. -class LSPError : public llvm::ErrorInfo<LSPError> { -public: - std::string message; - ErrorCode code; - static char ID; - - LSPError(std::string message, ErrorCode code) - : message(std::move(message)), code(code) {} - - void log(raw_ostream &os) const override { - os << int(code) << ": " << message; - } - std::error_code convertToErrorCode() const override { - return llvm::inconvertibleErrorCode(); - } -}; - -//===----------------------------------------------------------------------===// -// URIForFile -//===----------------------------------------------------------------------===// - -/// URI in "file" scheme for a file. -class URIForFile { -public: - URIForFile() = default; - - /// Try to build a URIForFile from the given URI string. - static llvm::Expected<URIForFile> fromURI(StringRef uri); - - /// Try to build a URIForFile from the given absolute file path and optional - /// scheme. - static llvm::Expected<URIForFile> fromFile(StringRef absoluteFilepath, - StringRef scheme = "file"); - - /// Returns the absolute path to the file. - StringRef file() const { return filePath; } - - /// Returns the original uri of the file. - StringRef uri() const { return uriStr; } - - /// Return the scheme of the uri. - StringRef scheme() const; - - explicit operator bool() const { return !filePath.empty(); } - - friend bool operator==(const URIForFile &lhs, const URIForFile &rhs) { - return lhs.filePath == rhs.filePath; - } - friend bool operator!=(const URIForFile &lhs, const URIForFile &rhs) { - return !(lhs == rhs); - } - friend bool operator<(const URIForFile &lhs, const URIForFile &rhs) { - return lhs.filePath < rhs.filePath; - } - - /// Register a supported URI scheme. The protocol supports `file` by default, - /// so this is only necessary for any additional schemes that a server wants - /// to support. - static void registerSupportedScheme(StringRef scheme); - -private: - explicit URIForFile(std::string &&filePath, std::string &&uriStr) - : filePath(std::move(filePath)), uriStr(uriStr) {} - - std::string filePath; - std::string uriStr; -}; - -/// Add support for JSON serialization. -llvm::json::Value toJSON(const URIForFile &value); -bool fromJSON(const llvm::json::Value &value, URIForFile &result, - llvm::json::Path path); -raw_ostream &operator<<(raw_ostream &os, const URIForFile &value); - -//===----------------------------------------------------------------------===// -// ClientCapabilities -//===----------------------------------------------------------------------===// - -struct ClientCapabilities { - /// Client supports hierarchical document symbols. - /// textDocument.documentSymbol.hierarchicalDocumentSymbolSupport - bool hierarchicalDocumentSymbol = false; - - /// Client supports CodeAction return value for textDocument/codeAction. - /// textDocument.codeAction.codeActionLiteralSupport. - bool codeActionStructure = false; - - /// Client supports server-initiated progress via the - /// window/workDoneProgress/create method. - /// - /// window.workDoneProgress - bool workDoneProgress = false; -}; - -/// Add support for JSON serialization. -bool fromJSON(const llvm::json::Value &value, ClientCapabilities &result, - llvm::json::Path path); - -//===----------------------------------------------------------------------===// -// ClientInfo -//===----------------------------------------------------------------------===// - -struct ClientInfo { - /// The name of the client as defined by the client. - std::string name; - - /// The client's version as defined by the client. - std::optional<std::string> version; -}; - -/// Add support for JSON serialization. -bool fromJSON(const llvm::json::Value &value, ClientInfo &result, - llvm::json::Path path); - -//===----------------------------------------------------------------------===// -// InitializeParams -//===----------------------------------------------------------------------===// - -enum class TraceLevel { - Off = 0, - Messages = 1, - Verbose = 2, -}; - -/// Add support for JSON serialization. -bool fromJSON(const llvm::json::Value &value, TraceLevel &result, - llvm::json::Path path); - -struct InitializeParams { - /// The capabilities provided by the client (editor or tool). - ClientCapabilities capabilities; - - /// Information about the client. - std::optional<ClientInfo> clientInfo; - - /// The initial trace setting. If omitted trace is disabled ('off'). - std::optional<TraceLevel> trace; -}; - -/// Add support for JSON serialization. -bool fromJSON(const llvm::json::Value &value, InitializeParams &result, - llvm::json::Path path); - -//===----------------------------------------------------------------------===// -// InitializedParams -//===----------------------------------------------------------------------===// - -struct NoParams {}; -inline bool fromJSON(const llvm::json::Value &, NoParams &, llvm::json::Path) { - return true; -} -using InitializedParams = NoParams; - -//===----------------------------------------------------------------------===// -// TextDocumentItem -//===----------------------------------------------------------------------===// - -struct TextDocumentItem { - /// The text document's URI. - URIForFile uri; - - /// The text document's language identifier. - std::string languageId; - - /// The content of the opened text document. - std::string text; - - /// The version number of this document. - int64_t version; -}; - -/// Add support for JSON serialization. -bool fromJSON(const llvm::json::Value &value, TextDocumentItem &result, - llvm::json::Path path); - -//===----------------------------------------------------------------------===// -// TextDocumentIdentifier -//===----------------------------------------------------------------------===// - -struct TextDocumentIdentifier { - /// The text document's URI. - URIForFile uri; -}; - -/// Add support for JSON serialization. -llvm::json::Value toJSON(const TextDocumentIdentifier &value); -bool fromJSON(const llvm::json::Value &value, TextDocumentIdentifier &result, - llvm::json::Path path); - -//===----------------------------------------------------------------------===// -// VersionedTextDocumentIdentifier -//===----------------------------------------------------------------------===// - -struct VersionedTextDocumentIdentifier { - /// The text document's URI. - URIForFile uri; - /// The version number of this document. - int64_t version; -}; - -/// Add support for JSON serialization. -llvm::json::Value toJSON(const VersionedTextDocumentIdentifier &value); -bool fromJSON(const llvm::json::Value &value, - VersionedTextDocumentIdentifier &result, llvm::json::Path path); - -//===----------------------------------------------------------------------===// -// Position -//===----------------------------------------------------------------------===// - -struct Position { - Position(int line = 0, int character = 0) - : line(line), character(character) {} - - /// Construct a position from the given source location. - Position(llvm::SourceMgr &mgr, SMLoc loc) { - std::pair<unsigned, unsigned> lineAndCol = mgr.getLineAndColumn(loc); - line = lineAndCol.first - 1; - character = lineAndCol.second - 1; - } - - /// Line position in a document (zero-based). - int line = 0; - - /// Character offset on a line in a document (zero-based). - int character = 0; - - friend bool operator==(const Position &lhs, const Position &rhs) { - return std::tie(lhs.line, lhs.character) == - std::tie(rhs.line, rhs.character); - } - friend bool operator!=(const Position &lhs, const Position &rhs) { - return !(lhs == rhs); - } - friend bool operator<(const Position &lhs, const Position &rhs) { - return std::tie(lhs.line, lhs.character) < - std::tie(rhs.line, rhs.character); - } - friend bool operator<=(const Position &lhs, const Position &rhs) { - return std::tie(lhs.line, lhs.character) <= - std::tie(rhs.line, rhs.character); - } - - /// Convert this position into a source location in the main file of the given - /// source manager. - SMLoc getAsSMLoc(llvm::SourceMgr &mgr) const { - return mgr.FindLocForLineAndColumn(mgr.getMainFileID(), line + 1, - character + 1); - } -}; - -/// Add support for JSON serialization. -bool fromJSON(const llvm::json::Value &value, Position &result, - llvm::json::Path path); -llvm::json::Value toJSON(const Position &value); -raw_ostream &operator<<(raw_ostream &os, const Position &value); - -//===----------------------------------------------------------------------===// -// Range -//===----------------------------------------------------------------------===// - -struct Range { - Range() = default; - Range(Position start, Position end) : start(start), end(end) {} - Range(Position loc) : Range(loc, loc) {} - - /// Construct a range from the given source range. - Range(llvm::SourceMgr &mgr, SMRange range) - : Range(Position(mgr, range.Start), Position(mgr, range.End)) {} - - /// The range's start position. - Position start; - - /// The range's end position. - Position end; - - friend bool operator==(const Range &lhs, const Range &rhs) { - return std::tie(lhs.start, lhs.end) == std::tie(rhs.start, rhs.end); - } - friend bool operator!=(const Range &lhs, const Range &rhs) { - return !(lhs == rhs); - } - friend bool operator<(const Range &lhs, const Range &rhs) { - return std::tie(lhs.start, lhs.end) < std::tie(rhs.start, rhs.end); - } - - bool contains(Position pos) const { return start <= pos && pos < end; } - bool contains(Range range) const { - return start <= range.start && range.end <= end; - } - - /// Convert this range into a source range in the main file of the given - /// source manager. - SMRange getAsSMRange(llvm::SourceMgr &mgr) const { - SMLoc startLoc = start.getAsSMLoc(mgr); - SMLoc endLoc = end.getAsSMLoc(mgr); - // Check that the start and end locations are valid. - if (!startLoc.isValid() || !endLoc.isValid() || - startLoc.getPointer() > endLoc.getPointer()) - return SMRange(); - return SMRange(startLoc, endLoc); - } -}; - -/// Add support for JSON serialization. -bool fromJSON(const llvm::json::Value &value, Range &result, - llvm::json::Path path); -llvm::json::Value toJSON(const Range &value); -raw_ostream &operator<<(raw_ostream &os, const Range &value); - -//===----------------------------------------------------------------------===// -// Location -//===----------------------------------------------------------------------===// - -struct Location { - Location() = default; - Location(const URIForFile &uri, Range range) : uri(uri), range(range) {} - - /// Construct a Location from the given source range. - Location(const URIForFile &uri, llvm::SourceMgr &mgr, SMRange range) - : Location(uri, Range(mgr, range)) {} - - /// The text document's URI. - URIForFile uri; - Range range; - - friend bool operator==(const Location &lhs, const Location &rhs) { - return lhs.uri == rhs.uri && lhs.range == rhs.range; - } - - friend bool operator!=(const Location &lhs, const Location &rhs) { - return !(lhs == rhs); - } - - friend bool operator<(const Location &lhs, const Location &rhs) { - return std::tie(lhs.uri, lhs.range) < std::tie(rhs.uri, rhs.range); - } -}; - -/// Add support for JSON serialization. -bool fromJSON(const llvm::json::Value &value, Location &result, - llvm::json::Path path); -llvm::json::Value toJSON(const Location &value); -raw_ostream &operator<<(raw_ostream &os, const Location &value); - -//===----------------------------------------------------------------------===// -// TextDocumentPositionParams -//===----------------------------------------------------------------------===// - -struct TextDocumentPositionParams { - /// The text document. - TextDocumentIdentifier textDocument; - - /// The position inside the text document. - Position position; -}; - -/// Add support for JSON serialization. -bool fromJSON(const llvm::json::Value &value, - TextDocumentPositionParams &result, llvm::json::Path path); - -//===----------------------------------------------------------------------===// -// ReferenceParams -//===----------------------------------------------------------------------===// - -struct ReferenceContext { - /// Include the declaration of the current symbol. - bool includeDeclaration = false; -}; - -/// Add support for JSON serialization. -bool fromJSON(const llvm::json::Value &value, ReferenceContext &result, - llvm::json::Path path); - -struct ReferenceParams : public TextDocumentPositionParams { - ReferenceContext context; -}; - -/// Add support for JSON serialization. -bool fromJSON(const llvm::json::Value &value, ReferenceParams &result, - llvm::json::Path path); - -//===----------------------------------------------------------------------===// -// DidOpenTextDocumentParams -//===----------------------------------------------------------------------===// - -struct DidOpenTextDocumentParams { - /// The document that was opened. - TextDocumentItem textDocument; -}; - -/// Add support for JSON serialization. -bool fromJSON(const llvm::json::Value &value, DidOpenTextDocumentParams &result, - llvm::json::Path path); - -//===----------------------------------------------------------------------===// -// DidCloseTextDocumentParams -//===----------------------------------------------------------------------===// - -struct DidCloseTextDocumentParams { - /// The document that was closed. - TextDocumentIdentifier textDocument; -}; - -/// Add support for JSON serialization. -bool fromJSON(const llvm::json::Value &value, - DidCloseTextDocumentParams &result, llvm::json::Path path); - -//===----------------------------------------------------------------------===// -// DidChangeTextDocumentParams -//===----------------------------------------------------------------------===// - -struct TextDocumentContentChangeEvent { - /// Try to apply this change to the given contents string. - LogicalResult applyTo(std::string &contents) const; - /// Try to apply a set of changes to the given contents string. - static LogicalResult applyTo(ArrayRef<TextDocumentContentChangeEvent> changes, - std::string &contents); - - /// The range of the document that changed. - std::optional<Range> range; - - /// The length of the range that got replaced. - std::optional<int> rangeLength; - - /// The new text of the range/document. - std::string text; -}; - -/// Add support for JSON serialization. -bool fromJSON(const llvm::json::Value &value, - TextDocumentContentChangeEvent &result, llvm::json::Path path); - -struct DidChangeTextDocumentParams { - /// The document that changed. - VersionedTextDocumentIdentifier textDocument; - - /// The actual content changes. - std::vector<TextDocumentContentChangeEvent> contentChanges; -}; - -/// Add support for JSON serialization. -bool fromJSON(const llvm::json::Value &value, - DidChangeTextDocumentParams &result, llvm::json::Path path); - -//===----------------------------------------------------------------------===// -// MarkupContent -//===----------------------------------------------------------------------===// - -/// Describes the content type that a client supports in various result literals -/// like `Hover`. -enum class MarkupKind { - PlainText, - Markdown, -}; -raw_ostream &operator<<(raw_ostream &os, MarkupKind kind); - -struct MarkupContent { - MarkupKind kind = MarkupKind::PlainText; - std::string value; -}; - -/// Add support for JSON serialization. -llvm::json::Value toJSON(const MarkupContent &mc); - -//===----------------------------------------------------------------------===// -// Hover -//===----------------------------------------------------------------------===// - -struct Hover { - /// Construct a default hover with the given range that uses Markdown content. - Hover(Range range) : contents{MarkupKind::Markdown, ""}, range(range) {} - - /// The hover's content. - MarkupContent contents; - - /// An optional range is a range inside a text document that is used to - /// visualize a hover, e.g. by changing the background color. - std::optional<Range> range; -}; - -/// Add support for JSON serialization. -llvm::json::Value toJSON(const Hover &hover); - -//===----------------------------------------------------------------------===// -// SymbolKind -//===----------------------------------------------------------------------===// - -enum class SymbolKind { - File = 1, - Module = 2, - Namespace = 3, - Package = 4, - Class = 5, - Method = 6, - Property = 7, - Field = 8, - Constructor = 9, - Enum = 10, - Interface = 11, - Function = 12, - Variable = 13, - Constant = 14, - String = 15, - Number = 16, - Boolean = 17, - Array = 18, - Object = 19, - Key = 20, - Null = 21, - EnumMember = 22, - Struct = 23, - Event = 24, - Operator = 25, - TypeParameter = 26 -}; - -//===----------------------------------------------------------------------===// -// DocumentSymbol -//===----------------------------------------------------------------------===// - -/// Represents programming constructs like variables, classes, interfaces etc. -/// that appear in a document. Document symbols can be hierarchical and they -/// have two ranges: one that encloses its definition and one that points to its -/// most interesting range, e.g. the range of an identifier. -struct DocumentSymbol { - DocumentSymbol() = default; - DocumentSymbol(DocumentSymbol &&) = default; - DocumentSymbol(const Twine &name, SymbolKind kind, Range range, - Range selectionRange) - : name(name.str()), kind(kind), range(range), - selectionRange(selectionRange) {} - - /// The name of this symbol. - std::string name; - - /// More detail for this symbol, e.g the signature of a function. - std::string detail; - - /// The kind of this symbol. - SymbolKind kind; - - /// The range enclosing this symbol not including leading/trailing whitespace - /// but everything else like comments. This information is typically used to - /// determine if the clients cursor is inside the symbol to reveal in the - /// symbol in the UI. - Range range; - - /// The range that should be selected and revealed when this symbol is being - /// picked, e.g the name of a function. Must be contained by the `range`. - Range selectionRange; - - /// Children of this symbol, e.g. properties of a class. - std::vector<DocumentSymbol> children; -}; - -/// Add support for JSON serialization. -llvm::json::Value toJSON(const DocumentSymbol &symbol); - -//===----------------------------------------------------------------------===// -// DocumentSymbolParams -//===----------------------------------------------------------------------===// - -struct DocumentSymbolParams { - // The text document to find symbols in. - TextDocumentIdentifier textDocument; -}; - -/// Add support for JSON serialization. -bool fromJSON(const llvm::json::Value &value, DocumentSymbolParams &result, - llvm::json::Path path); - -//===----------------------------------------------------------------------===// -// DiagnosticRelatedInformation -//===----------------------------------------------------------------------===// - -/// Represents a related message and source code location for a diagnostic. -/// This should be used to point to code locations that cause or related to a -/// diagnostics, e.g. when duplicating a symbol in a scope. -struct DiagnosticRelatedInformation { - DiagnosticRelatedInformation() = default; - DiagnosticRelatedInformation(Location location, std::string message) - : location(std::move(location)), message(std::move(message)) {} - - /// The location of this related diagnostic information. - Location location; - /// The message of this related diagnostic information. - std::string message; -}; - -/// Add support for JSON serialization. -bool fromJSON(const llvm::json::Value &value, - DiagnosticRelatedInformation &result, llvm::json::Path path); -llvm::json::Value toJSON(const DiagnosticRelatedInformation &info); - -//===----------------------------------------------------------------------===// -// Diagnostic -//===----------------------------------------------------------------------===// - -enum class DiagnosticSeverity { - /// It is up to the client to interpret diagnostics as error, warning, info or - /// hint. - Undetermined = 0, - Error = 1, - Warning = 2, - Information = 3, - Hint = 4 -}; - -enum class DiagnosticTag { - Unnecessary = 1, - Deprecated = 2, -}; - -/// Add support for JSON serialization. -llvm::json::Value toJSON(DiagnosticTag tag); -bool fromJSON(const llvm::json::Value &value, DiagnosticTag &result, - llvm::json::Path path); - -struct Diagnostic { - /// The source range where the message applies. - Range range; - - /// The diagnostic's severity. Can be omitted. If omitted it is up to the - /// client to interpret diagnostics as error, warning, info or hint. - DiagnosticSeverity severity = DiagnosticSeverity::Undetermined; - - /// A human-readable string describing the source of this diagnostic, e.g. - /// 'typescript' or 'super lint'. - std::string source; - - /// The diagnostic's message. - std::string message; - - /// An array of related diagnostic information, e.g. when symbol-names within - /// a scope collide all definitions can be marked via this property. - std::optional<std::vector<DiagnosticRelatedInformation>> relatedInformation; - - /// Additional metadata about the diagnostic. - std::vector<DiagnosticTag> tags; - - /// The diagnostic's category. Can be omitted. - /// An LSP extension that's used to send the name of the category over to the - /// client. The category typically describes the compilation stage during - /// which the issue was produced, e.g. "Semantic Issue" or "Parse Issue". - std::optional<std::string> category; -}; - -/// Add support for JSON serialization. -llvm::json::Value toJSON(const Diagnostic &diag); -bool fromJSON(const llvm::json::Value &value, Diagnostic &result, - llvm::json::Path path); - -//===----------------------------------------------------------------------===// -// PublishDiagnosticsParams -//===----------------------------------------------------------------------===// - -struct PublishDiagnosticsParams { - PublishDiagnosticsParams(URIForFile uri, int64_t version) - : uri(std::move(uri)), version(version) {} - - /// The URI for which diagnostic information is reported. - URIForFile uri; - /// The list of reported diagnostics. - std::vector<Diagnostic> diagnostics; - /// The version number of the document the diagnostics are published for. - int64_t version; -}; - -/// Add support for JSON serialization. -llvm::json::Value toJSON(const PublishDiagnosticsParams ¶ms); - -//===----------------------------------------------------------------------===// -// TextEdit -//===----------------------------------------------------------------------===// - -struct TextEdit { - /// The range of the text document to be manipulated. To insert - /// text into a document create a range where start === end. - Range range; - - /// The string to be inserted. For delete operations use an - /// empty string. - std::string newText; -}; - -inline bool operator==(const TextEdit &lhs, const TextEdit &rhs) { - return std::tie(lhs.newText, lhs.range) == std::tie(rhs.newText, rhs.range); -} - -bool fromJSON(const llvm::json::Value &value, TextEdit &result, - llvm::json::Path path); -llvm::json::Value toJSON(const TextEdit &value); -raw_ostream &operator<<(raw_ostream &os, const TextEdit &value); - -//===----------------------------------------------------------------------===// -// CompletionItemKind -//===----------------------------------------------------------------------===// - -/// The kind of a completion entry. -enum class CompletionItemKind { - Missing = 0, - Text = 1, - Method = 2, - Function = 3, - Constructor = 4, - Field = 5, - Variable = 6, - Class = 7, - Interface = 8, - Module = 9, - Property = 10, - Unit = 11, - Value = 12, - Enum = 13, - Keyword = 14, - Snippet = 15, - Color = 16, - File = 17, - Reference = 18, - Folder = 19, - EnumMember = 20, - Constant = 21, - Struct = 22, - Event = 23, - Operator = 24, - TypeParameter = 25, -}; -bool fromJSON(const llvm::json::Value &value, CompletionItemKind &result, - llvm::json::Path path); - -constexpr auto kCompletionItemKindMin = - static_cast<size_t>(CompletionItemKind::Text); -constexpr auto kCompletionItemKindMax = - static_cast<size_t>(CompletionItemKind::TypeParameter); -using CompletionItemKindBitset = std::bitset<kCompletionItemKindMax + 1>; -bool fromJSON(const llvm::json::Value &value, CompletionItemKindBitset &result, - llvm::json::Path path); - -CompletionItemKind -adjustKindToCapability(CompletionItemKind kind, - CompletionItemKindBitset &supportedCompletionItemKinds); - -//===----------------------------------------------------------------------===// -// CompletionItem -//===----------------------------------------------------------------------===// - -/// Defines whether the insert text in a completion item should be interpreted -/// as plain text or a snippet. -enum class InsertTextFormat { - Missing = 0, - /// The primary text to be inserted is treated as a plain string. - PlainText = 1, - /// The primary text to be inserted is treated as a snippet. - /// - /// A snippet can define tab stops and placeholders with `$1`, `$2` - /// and `${3:foo}`. `$0` defines the final tab stop, it defaults to the end - /// of the snippet. Placeholders with equal identifiers are linked, that is - /// typing in one will update others too. - /// - /// See also: - /// https//github.com/Microsoft/vscode/blob/master/src/vs/editor/contrib/snippet/common/snippet.md - Snippet = 2, -}; - -struct CompletionItem { - CompletionItem() = default; - CompletionItem(const Twine &label, CompletionItemKind kind, - StringRef sortText = "") - : label(label.str()), kind(kind), sortText(sortText.str()), - insertTextFormat(InsertTextFormat::PlainText) {} - - /// The label of this completion item. By default also the text that is - /// inserted when selecting this completion. - std::string label; - - /// The kind of this completion item. Based of the kind an icon is chosen by - /// the editor. - CompletionItemKind kind = CompletionItemKind::Missing; - - /// A human-readable string with additional information about this item, like - /// type or symbol information. - std::string detail; - - /// A human-readable string that represents a doc-comment. - std::optional<MarkupContent> documentation; - - /// A string that should be used when comparing this item with other items. - /// When `falsy` the label is used. - std::string sortText; - - /// A string that should be used when filtering a set of completion items. - /// When `falsy` the label is used. - std::string filterText; - - /// A string that should be inserted to a document when selecting this - /// completion. When `falsy` the label is used. - std::string insertText; - - /// The format of the insert text. The format applies to both the `insertText` - /// property and the `newText` property of a provided `textEdit`. - InsertTextFormat insertTextFormat = InsertTextFormat::Missing; - - /// An edit which is applied to a document when selecting this completion. - /// When an edit is provided `insertText` is ignored. - /// - /// Note: The range of the edit must be a single line range and it must - /// contain the position at which completion has been requested. - std::optional<TextEdit> textEdit; - - /// An optional array of additional text edits that are applied when selecting - /// this completion. Edits must not overlap with the main edit nor with - /// themselves. - std::vector<TextEdit> additionalTextEdits; - - /// Indicates if this item is deprecated. - bool deprecated = false; -}; - -/// Add support for JSON serialization. -llvm::json::Value toJSON(const CompletionItem &value); -raw_ostream &operator<<(raw_ostream &os, const CompletionItem &value); -bool operator<(const CompletionItem &lhs, const CompletionItem &rhs); - -//===----------------------------------------------------------------------===// -// CompletionList -//===----------------------------------------------------------------------===// - -/// Represents a collection of completion items to be presented in the editor. -struct CompletionList { - /// The list is not complete. Further typing should result in recomputing the - /// list. - bool isIncomplete = false; - - /// The completion items. - std::vector<CompletionItem> items; -}; - -/// Add support for JSON serialization. -llvm::json::Value toJSON(const CompletionList &value); - -//===----------------------------------------------------------------------===// -// CompletionContext -//===----------------------------------------------------------------------===// - -enum class CompletionTriggerKind { - /// Completion was triggered by typing an identifier (24x7 code - /// complete), manual invocation (e.g Ctrl+Space) or via API. - Invoked = 1, - - /// Completion was triggered by a trigger character specified by - /// the `triggerCharacters` properties of the `CompletionRegistrationOptions`. - TriggerCharacter = 2, - - /// Completion was re-triggered as the current completion list is incomplete. - TriggerTriggerForIncompleteCompletions = 3 -}; - -struct CompletionContext { - /// How the completion was triggered. - CompletionTriggerKind triggerKind = CompletionTriggerKind::Invoked; - - /// The trigger character (a single character) that has trigger code complete. - /// Is undefined if `triggerKind !== CompletionTriggerKind.TriggerCharacter` - std::string triggerCharacter; -}; - -/// Add support for JSON serialization. -bool fromJSON(const llvm::json::Value &value, CompletionContext &result, - llvm::json::Path path); - -//===----------------------------------------------------------------------===// -// CompletionParams -//===----------------------------------------------------------------------===// - -struct CompletionParams : TextDocumentPositionParams { - CompletionContext context; -}; - -/// Add support for JSON serialization. -bool fromJSON(const llvm::json::Value &value, CompletionParams &result, - llvm::json::Path path); - -//===----------------------------------------------------------------------===// -// ParameterInformation -//===----------------------------------------------------------------------===// - -/// A single parameter of a particular signature. -struct ParameterInformation { - /// The label of this parameter. Ignored when labelOffsets is set. - std::string labelString; - - /// Inclusive start and exclusive end offsets withing the containing signature - /// label. - std::optional<std::pair<unsigned, unsigned>> labelOffsets; - - /// The documentation of this parameter. Optional. - std::string documentation; -}; - -/// Add support for JSON serialization. -llvm::json::Value toJSON(const ParameterInformation &value); - -//===----------------------------------------------------------------------===// -// SignatureInformation -//===----------------------------------------------------------------------===// - -/// Represents the signature of something callable. -struct SignatureInformation { - /// The label of this signature. Mandatory. - std::string label; - - /// The documentation of this signature. Optional. - std::string documentation; - - /// The parameters of this signature. - std::vector<ParameterInformation> parameters; -}; - -/// Add support for JSON serialization. -llvm::json::Value toJSON(const SignatureInformation &value); -raw_ostream &operator<<(raw_ostream &os, const SignatureInformation &value); - -//===----------------------------------------------------------------------===// -// SignatureHelp -//===----------------------------------------------------------------------===// - -/// Represents the signature of a callable. -struct SignatureHelp { - /// The resulting signatures. - std::vector<SignatureInformation> signatures; - - /// The active signature. - int activeSignature = 0; - - /// The active parameter of the active signature. - int activeParameter = 0; -}; - -/// Add support for JSON serialization. -llvm::json::Value toJSON(const SignatureHelp &value); - -//===----------------------------------------------------------------------===// -// DocumentLinkParams -//===----------------------------------------------------------------------===// - -/// Parameters for the document link request. -struct DocumentLinkParams { - /// The document to provide document links for. - TextDocumentIdentifier textDocument; -}; - -/// Add support for JSON serialization. -bool fromJSON(const llvm::json::Value &value, DocumentLinkParams &result, - llvm::json::Path path); - -//===----------------------------------------------------------------------===// -// DocumentLink -//===----------------------------------------------------------------------===// - -/// A range in a text document that links to an internal or external resource, -/// like another text document or a web site. -struct DocumentLink { - DocumentLink() = default; - DocumentLink(Range range, URIForFile target) - : range(range), target(std::move(target)) {} - - /// The range this link applies to. - Range range; - - /// The uri this link points to. If missing a resolve request is sent later. - URIForFile target; - - // TODO: The following optional fields defined by the language server protocol - // are unsupported: - // - // data?: any - A data entry field that is preserved on a document link - // between a DocumentLinkRequest and a - // DocumentLinkResolveRequest. - - friend bool operator==(const DocumentLink &lhs, const DocumentLink &rhs) { - return lhs.range == rhs.range && lhs.target == rhs.target; - } - - friend bool operator!=(const DocumentLink &lhs, const DocumentLink &rhs) { - return !(lhs == rhs); - } -}; - -/// Add support for JSON serialization. -llvm::json::Value toJSON(const DocumentLink &value); - -//===----------------------------------------------------------------------===// -// InlayHintsParams -//===----------------------------------------------------------------------===// - -/// A parameter literal used in inlay hint requests. -struct InlayHintsParams { - /// The text document. - TextDocumentIdentifier textDocument; - - /// The visible document range for which inlay hints should be computed. - Range range; -}; - -/// Add support for JSON serialization. -bool fromJSON(const llvm::json::Value &value, InlayHintsParams &result, - llvm::json::Path path); - -//===----------------------------------------------------------------------===// -// InlayHintKind -//===----------------------------------------------------------------------===// - -/// Inlay hint kinds. -enum class InlayHintKind { - /// An inlay hint that for a type annotation. - /// - /// An example of a type hint is a hint in this position: - /// auto var ^ = expr; - /// which shows the deduced type of the variable. - Type = 1, - - /// An inlay hint that is for a parameter. - /// - /// An example of a parameter hint is a hint in this position: - /// func(^arg); - /// which shows the name of the corresponding parameter. - Parameter = 2, -}; - -//===----------------------------------------------------------------------===// -// InlayHint -//===----------------------------------------------------------------------===// - -/// Inlay hint information. -struct InlayHint { - InlayHint(InlayHintKind kind, Position pos) : position(pos), kind(kind) {} - - /// The position of this hint. - Position position; - - /// The label of this hint. A human readable string or an array of - /// InlayHintLabelPart label parts. - /// - /// *Note* that neither the string nor the label part can be empty. - std::string label; - - /// The kind of this hint. Can be omitted in which case the client should fall - /// back to a reasonable default. - InlayHintKind kind; - - /// Render padding before the hint. - /// - /// Note: Padding should use the editor's background color, not the - /// background color of the hint itself. That means padding can be used - /// to visually align/separate an inlay hint. - bool paddingLeft = false; - - /// Render padding after the hint. - /// - /// Note: Padding should use the editor's background color, not the - /// background color of the hint itself. That means padding can be used - /// to visually align/separate an inlay hint. - bool paddingRight = false; -}; - -/// Add support for JSON serialization. -llvm::json::Value toJSON(const InlayHint &); -bool operator==(const InlayHint &lhs, const InlayHint &rhs); -bool operator<(const InlayHint &lhs, const InlayHint &rhs); -llvm::raw_ostream &operator<<(llvm::raw_ostream &os, InlayHintKind value); - -//===----------------------------------------------------------------------===// -// CodeActionContext -//===----------------------------------------------------------------------===// - -struct CodeActionContext { - /// An array of diagnostics known on the client side overlapping the range - /// provided to the `textDocument/codeAction` request. They are provided so - /// that the server knows which errors are currently presented to the user for - /// the given range. There is no guarantee that these accurately reflect the - /// error state of the resource. The primary parameter to compute code actions - /// is the provided range. - std::vector<Diagnostic> diagnostics; - - /// Requested kind of actions to return. - /// - /// Actions not of this kind are filtered out by the client before being - /// shown. So servers can omit computing them. - std::vector<std::string> only; -}; - -/// Add support for JSON serialization. -bool fromJSON(const llvm::json::Value &value, CodeActionContext &result, - llvm::json::Path path); - -//===----------------------------------------------------------------------===// -// CodeActionParams -//===----------------------------------------------------------------------===// - -struct CodeActionParams { - /// The document in which the command was invoked. - TextDocumentIdentifier textDocument; - - /// The range for which the command was invoked. - Range range; - - /// Context carrying additional information. - CodeActionContext context; -}; - -/// Add support for JSON serialization. -bool fromJSON(const llvm::json::Value &value, CodeActionParams &result, - llvm::json::Path path); - -//===----------------------------------------------------------------------===// -// WorkspaceEdit -//===----------------------------------------------------------------------===// - -struct WorkspaceEdit { - /// Holds changes to existing resources. - std::map<std::string, std::vector<TextEdit>> changes; - - /// Note: "documentChanges" is not currently used because currently there is - /// no support for versioned edits. -}; - -/// Add support for JSON serialization. -bool fromJSON(const llvm::json::Value &value, WorkspaceEdit &result, - llvm::json::Path path); -llvm::json::Value toJSON(const WorkspaceEdit &value); - -//===----------------------------------------------------------------------===// -// CodeAction -//===----------------------------------------------------------------------===// - -/// A code action represents a change that can be performed in code, e.g. to fix -/// a problem or to refactor code. -/// -/// A CodeAction must set either `edit` and/or a `command`. If both are -/// supplied, the `edit` is applied first, then the `command` is executed. -struct CodeAction { - /// A short, human-readable, title for this code action. - std::string title; - - /// The kind of the code action. - /// Used to filter code actions. - std::optional<std::string> kind; - const static llvm::StringLiteral kQuickFix; - const static llvm::StringLiteral kRefactor; - const static llvm::StringLiteral kInfo; - - /// The diagnostics that this code action resolves. - std::optional<std::vector<Diagnostic>> diagnostics; - - /// Marks this as a preferred action. Preferred actions are used by the - /// `auto fix` command and can be targeted by keybindings. - /// A quick fix should be marked preferred if it properly addresses the - /// underlying error. A refactoring should be marked preferred if it is the - /// most reasonable choice of actions to take. - bool isPreferred = false; - - /// The workspace edit this code action performs. - std::optional<WorkspaceEdit> edit; -}; - -/// Add support for JSON serialization. -llvm::json::Value toJSON(const CodeAction &); - -} // namespace lsp -} // namespace mlir - -namespace llvm { -template <> -struct format_provider<mlir::lsp::Position> { - static void format(const mlir::lsp::Position &pos, raw_ostream &os, - StringRef style) { - assert(style.empty() && "style modifiers for this type are not supported"); - os << pos; - } -}; -} // namespace llvm - -#endif diff --git a/mlir/include/mlir/Tools/lsp-server-support/SourceMgrUtils.h b/mlir/include/mlir/Tools/lsp-server-support/SourceMgrUtils.h index 9ed8326..920ce83 100644 --- a/mlir/include/mlir/Tools/lsp-server-support/SourceMgrUtils.h +++ b/mlir/include/mlir/Tools/lsp-server-support/SourceMgrUtils.h @@ -14,7 +14,8 @@ #ifndef MLIR_TOOLS_LSPSERVERSUPPORT_SOURCEMGRUTILS_H #define MLIR_TOOLS_LSPSERVERSUPPORT_SOURCEMGRUTILS_H -#include "mlir/Tools/lsp-server-support/Protocol.h" +#include "mlir/Support/LLVM.h" +#include "llvm/Support/LSP/Protocol.h" #include "llvm/Support/SourceMgr.h" #include <optional> @@ -45,17 +46,18 @@ bool contains(SMRange range, SMLoc loc); /// This class represents a single include within a root file. struct SourceMgrInclude { - SourceMgrInclude(const lsp::URIForFile &uri, const lsp::Range &range) + SourceMgrInclude(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Range &range) : uri(uri), range(range) {} /// Build a hover for the current include file. - Hover buildHover() const; + llvm::lsp::Hover buildHover() const; /// The URI of the file that is included. - lsp::URIForFile uri; + llvm::lsp::URIForFile uri; /// The range of the include directive. - lsp::Range range; + llvm::lsp::Range range; }; /// Given a source manager, gather all of the processed include files. These are diff --git a/mlir/include/mlir/Tools/lsp-server-support/Transport.h b/mlir/include/mlir/Tools/lsp-server-support/Transport.h deleted file mode 100644 index 0010a47..0000000 --- a/mlir/include/mlir/Tools/lsp-server-support/Transport.h +++ /dev/null @@ -1,283 +0,0 @@ -//===--- Transport.h - Sending and Receiving LSP messages -------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// The language server protocol is usually implemented by writing messages as -// JSON-RPC over the stdin/stdout of a subprocess. This file contains a JSON -// transport interface that handles this communication. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_TOOLS_LSPSERVERSUPPORT_TRANSPORT_H -#define MLIR_TOOLS_LSPSERVERSUPPORT_TRANSPORT_H - -#include "mlir/Support/DebugStringHelper.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Tools/lsp-server-support/Logging.h" -#include "mlir/Tools/lsp-server-support/Protocol.h" -#include "llvm/ADT/FunctionExtras.h" -#include "llvm/ADT/StringMap.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/FormatAdapters.h" -#include "llvm/Support/JSON.h" -#include "llvm/Support/raw_ostream.h" -#include <atomic> - -namespace mlir { -namespace lsp { -class MessageHandler; - -//===----------------------------------------------------------------------===// -// JSONTransport -//===----------------------------------------------------------------------===// - -/// The encoding style of the JSON-RPC messages (both input and output). -enum JSONStreamStyle { - /// Encoding per the LSP specification, with mandatory Content-Length header. - Standard, - /// Messages are delimited by a '// -----' line. Comment lines start with //. - Delimited -}; - -/// An abstract class used by the JSONTransport to read JSON message. -class JSONTransportInput { -public: - explicit JSONTransportInput(JSONStreamStyle style = JSONStreamStyle::Standard) - : style(style) {} - virtual ~JSONTransportInput() = default; - - virtual bool hasError() const = 0; - virtual bool isEndOfInput() const = 0; - - /// Read in a message from the input stream. - LogicalResult readMessage(std::string &json) { - return style == JSONStreamStyle::Delimited ? readDelimitedMessage(json) - : readStandardMessage(json); - } - virtual LogicalResult readDelimitedMessage(std::string &json) = 0; - virtual LogicalResult readStandardMessage(std::string &json) = 0; - -private: - /// The JSON stream style to use. - JSONStreamStyle style; -}; - -/// Concrete implementation of the JSONTransportInput that reads from a file. -class JSONTransportInputOverFile : public JSONTransportInput { -public: - explicit JSONTransportInputOverFile( - std::FILE *in, JSONStreamStyle style = JSONStreamStyle::Standard) - : JSONTransportInput(style), in(in) {} - - bool hasError() const final { return ferror(in); } - bool isEndOfInput() const final { return feof(in); } - - LogicalResult readDelimitedMessage(std::string &json) final; - LogicalResult readStandardMessage(std::string &json) final; - -private: - std::FILE *in; -}; - -/// A transport class that performs the JSON-RPC communication with the LSP -/// client. -class JSONTransport { -public: - JSONTransport(std::unique_ptr<JSONTransportInput> in, raw_ostream &out, - bool prettyOutput = false) - : in(std::move(in)), out(out), prettyOutput(prettyOutput) {} - - JSONTransport(std::FILE *in, raw_ostream &out, - JSONStreamStyle style = JSONStreamStyle::Standard, - bool prettyOutput = false) - : in(std::make_unique<JSONTransportInputOverFile>(in, style)), out(out), - prettyOutput(prettyOutput) {} - - /// The following methods are used to send a message to the LSP client. - void notify(StringRef method, llvm::json::Value params); - void call(StringRef method, llvm::json::Value params, llvm::json::Value id); - void reply(llvm::json::Value id, llvm::Expected<llvm::json::Value> result); - - /// Start executing the JSON-RPC transport. - llvm::Error run(MessageHandler &handler); - -private: - /// Dispatches the given incoming json message to the message handler. - bool handleMessage(llvm::json::Value msg, MessageHandler &handler); - /// Writes the given message to the output stream. - void sendMessage(llvm::json::Value msg); - -private: - /// The input to read a message from. - std::unique_ptr<JSONTransportInput> in; - SmallVector<char, 0> outputBuffer; - /// The output file stream. - raw_ostream &out; - /// If the output JSON should be formatted for easier readability. - bool prettyOutput; -}; - -//===----------------------------------------------------------------------===// -// MessageHandler -//===----------------------------------------------------------------------===// - -/// A Callback<T> is a void function that accepts Expected<T>. This is -/// accepted by functions that logically return T. -template <typename T> -using Callback = llvm::unique_function<void(llvm::Expected<T>)>; - -/// An OutgoingNotification<T> is a function used for outgoing notifications -/// send to the client. -template <typename T> -using OutgoingNotification = llvm::unique_function<void(const T &)>; - -/// An OutgoingRequest<T> is a function used for outgoing requests to send to -/// the client. -template <typename T> -using OutgoingRequest = - llvm::unique_function<void(const T &, llvm::json::Value id)>; - -/// An `OutgoingRequestCallback` is invoked when an outgoing request to the -/// client receives a response in turn. It is passed the original request's ID, -/// as well as the response result. -template <typename T> -using OutgoingRequestCallback = - std::function<void(llvm::json::Value, llvm::Expected<T>)>; - -/// A handler used to process the incoming transport messages. -class MessageHandler { -public: - MessageHandler(JSONTransport &transport) : transport(transport) {} - - bool onNotify(StringRef method, llvm::json::Value value); - bool onCall(StringRef method, llvm::json::Value params, llvm::json::Value id); - bool onReply(llvm::json::Value id, llvm::Expected<llvm::json::Value> result); - - template <typename T> - static llvm::Expected<T> parse(const llvm::json::Value &raw, - StringRef payloadName, StringRef payloadKind) { - T result; - llvm::json::Path::Root root; - if (fromJSON(raw, result, root)) - return std::move(result); - - // Dump the relevant parts of the broken message. - std::string context; - llvm::raw_string_ostream os(context); - root.printErrorContext(raw, os); - - // Report the error (e.g. to the client). - return llvm::make_error<LSPError>( - llvm::formatv("failed to decode {0} {1}: {2}", payloadName, payloadKind, - fmt_consume(root.getError())), - ErrorCode::InvalidParams); - } - - template <typename Param, typename Result, typename ThisT> - void method(llvm::StringLiteral method, ThisT *thisPtr, - void (ThisT::*handler)(const Param &, Callback<Result>)) { - methodHandlers[method] = [method, handler, - thisPtr](llvm::json::Value rawParams, - Callback<llvm::json::Value> reply) { - llvm::Expected<Param> param = parse<Param>(rawParams, method, "request"); - if (!param) - return reply(param.takeError()); - (thisPtr->*handler)(*param, std::move(reply)); - }; - } - - template <typename Param, typename ThisT> - void notification(llvm::StringLiteral method, ThisT *thisPtr, - void (ThisT::*handler)(const Param &)) { - notificationHandlers[method] = [method, handler, - thisPtr](llvm::json::Value rawParams) { - llvm::Expected<Param> param = - parse<Param>(rawParams, method, "notification"); - if (!param) { - return llvm::consumeError( - llvm::handleErrors(param.takeError(), [](const LSPError &lspError) { - Logger::error("JSON parsing error: {0}", - lspError.message.c_str()); - })); - } - (thisPtr->*handler)(*param); - }; - } - - /// Create an OutgoingNotification object used for the given method. - template <typename T> - OutgoingNotification<T> outgoingNotification(llvm::StringLiteral method) { - return [&, method](const T ¶ms) { - std::lock_guard<std::mutex> transportLock(transportOutputMutex); - Logger::info("--> {0}", method); - transport.notify(method, llvm::json::Value(params)); - }; - } - - /// Create an OutgoingRequest function that, when called, sends a request with - /// the given method via the transport. Should the outgoing request be - /// met with a response, the result JSON is parsed and the response callback - /// is invoked. - template <typename Param, typename Result> - OutgoingRequest<Param> - outgoingRequest(llvm::StringLiteral method, - OutgoingRequestCallback<Result> callback) { - return [&, method, callback](const Param ¶m, llvm::json::Value id) { - auto callbackWrapper = [method, callback = std::move(callback)]( - llvm::json::Value id, - llvm::Expected<llvm::json::Value> value) { - if (!value) - return callback(std::move(id), value.takeError()); - - std::string responseName = llvm::formatv("reply:{0}({1})", method, id); - llvm::Expected<Result> result = - parse<Result>(*value, responseName, "response"); - if (!result) - return callback(std::move(id), result.takeError()); - - return callback(std::move(id), *result); - }; - - { - std::lock_guard<std::mutex> lock(responseHandlersMutex); - responseHandlers.insert( - {debugString(id), std::make_pair(method.str(), callbackWrapper)}); - } - - std::lock_guard<std::mutex> transportLock(transportOutputMutex); - Logger::info("--> {0}({1})", method, id); - transport.call(method, llvm::json::Value(param), id); - }; - } - -private: - template <typename HandlerT> - using HandlerMap = llvm::StringMap<llvm::unique_function<HandlerT>>; - - HandlerMap<void(llvm::json::Value)> notificationHandlers; - HandlerMap<void(llvm::json::Value, Callback<llvm::json::Value>)> - methodHandlers; - - /// A pair of (1) the original request's method name, and (2) the callback - /// function to be invoked for responses. - using ResponseHandlerTy = - std::pair<std::string, OutgoingRequestCallback<llvm::json::Value>>; - /// A mapping from request/response ID to response handler. - llvm::StringMap<ResponseHandlerTy> responseHandlers; - /// Mutex to guard insertion into the response handler map. - std::mutex responseHandlersMutex; - - JSONTransport &transport; - - /// Mutex to guard sending output messages to the transport. - std::mutex transportOutputMutex; -}; - -} // namespace lsp -} // namespace mlir - -#endif diff --git a/mlir/include/mlir/Tools/mlir-lsp-server/MlirLspRegistryFunction.h b/mlir/include/mlir/Tools/mlir-lsp-server/MlirLspRegistryFunction.h index 4811ecb..0d9ba2a 100644 --- a/mlir/include/mlir/Tools/mlir-lsp-server/MlirLspRegistryFunction.h +++ b/mlir/include/mlir/Tools/mlir-lsp-server/MlirLspRegistryFunction.h @@ -16,14 +16,16 @@ namespace llvm { template <typename Fn> class function_ref; +namespace lsp { +class URIForFile; +} // namespace lsp } // namespace llvm namespace mlir { class DialectRegistry; namespace lsp { -class URIForFile; using DialectRegistryFn = - llvm::function_ref<DialectRegistry &(const URIForFile &uri)>; + llvm::function_ref<DialectRegistry &(const llvm::lsp::URIForFile &uri)>; } // namespace lsp } // namespace mlir diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 6949f4a..a096f82 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -433,7 +433,7 @@ private: std::is_same_v<T, Value>, ConversionCallbackFn> wrapCallback(FnT &&callback) { - hasContextAwareTypeConversions = true; + contextAwareTypeConversionsIndex = conversions.size(); return [callback = std::forward<FnT>(callback)]( PointerUnion<Type, Value> typeOrValue, SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> { @@ -555,6 +555,10 @@ private: cachedMultiConversions.clear(); } + /// Internal implementation of the type conversion. + LogicalResult convertTypeImpl(PointerUnion<Type, Value> t, + SmallVectorImpl<Type> &results) const; + /// The set of registered conversion functions. SmallVector<ConversionCallbackFn, 4> conversions; @@ -575,10 +579,13 @@ private: mutable llvm::sys::SmartRWMutex<true> cacheMutex; /// Whether the type converter has context-aware type conversions. I.e., /// conversion rules that depend on the SSA value instead of just the type. - /// Type conversion caching is deactivated when there are context-aware - /// conversions because the type converter may return different results for - /// the same input type. - bool hasContextAwareTypeConversions = false; + /// We store here the index in the `conversions` vector of the last added + /// context-aware conversion, if any. This is useful because we can't cache + /// the result of type conversion happening after context-aware conversions, + /// because the type converter may return different results for the same input + /// type. This is why it is recommened to add context-aware conversions first, + /// any context-free conversions after will benefit from caching. + int contextAwareTypeConversionsIndex = -1; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp index 9424eff..131c49c 100644 --- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp @@ -22,6 +22,7 @@ #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Support/LLVM.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/DebugLog.h" @@ -159,6 +160,7 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) { LDBG() << "[init] Entering initializeSymbolCallables for top-level op: " << OpWithFlags(top, OpPrintingFlags().skipRegions()); analysisScope = top; + hasSymbolTable = top->hasTrait<OpTrait::SymbolTable>(); auto walkFn = [&](Operation *symTable, bool allUsesVisible) { LDBG() << "[init] Processing symbol table op: " << OpWithFlags(symTable, OpPrintingFlags().skipRegions()); @@ -260,14 +262,25 @@ LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) { return failure(); } // Recurse on nested operations. - for (Region ®ion : op->getRegions()) { - LDBG() << "[init] Recursing into region of op: " - << OpWithFlags(op, OpPrintingFlags().skipRegions()); - for (Operation &nestedOp : region.getOps()) { - LDBG() << "[init] Recursing into nested op: " - << OpWithFlags(&nestedOp, OpPrintingFlags().skipRegions()); - if (failed(initializeRecursively(&nestedOp))) - return failure(); + if (op->getNumRegions()) { + // If we haven't seen a symbol table yet, check if the current operation + // has one. If so, update the flag to allow for resolving callables in + // nested regions. + bool savedHasSymbolTable = hasSymbolTable; + auto restoreHasSymbolTable = + llvm::make_scope_exit([&]() { hasSymbolTable = savedHasSymbolTable; }); + if (!hasSymbolTable && op->hasTrait<OpTrait::SymbolTable>()) + hasSymbolTable = true; + + for (Region ®ion : op->getRegions()) { + LDBG() << "[init] Recursing into region of op: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); + for (Operation &nestedOp : region.getOps()) { + LDBG() << "[init] Recursing into nested op: " + << OpWithFlags(&nestedOp, OpPrintingFlags().skipRegions()); + if (failed(initializeRecursively(&nestedOp))) + return failure(); + } } } LDBG() << "[init] Finished initializeRecursively for op: " @@ -388,7 +401,13 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) { void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) { LDBG() << "visitCallOperation: " << OpWithFlags(call.getOperation(), OpPrintingFlags().skipRegions()); - Operation *callableOp = call.resolveCallableInTable(&symbolTable); + + Operation *callableOp = nullptr; + if (hasSymbolTable) + callableOp = call.resolveCallableInTable(&symbolTable); + else + LDBG() + << "No symbol table present in analysis scope, can't resolve callable"; // A call to a externally-defined callable has unknown predecessors. const auto isExternalCallable = [this](Operation *op) { diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp index d05374f..b51465b 100644 --- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp @@ -64,10 +64,12 @@ void AbstractDenseForwardDataFlowAnalysis::visitCallOperation( AbstractDenseLattice *after) { // Allow for customizing the behavior of calls to external symbols, including // when the analysis is explicitly marked as non-interprocedural. - auto callable = - dyn_cast_if_present<CallableOpInterface>(call.resolveCallable()); - if (!getSolverConfig().isInterprocedural() || - (callable && !callable.getCallableRegion())) { + auto isExternalCallable = [&]() { + auto callable = + dyn_cast_if_present<CallableOpInterface>(call.resolveCallable()); + return callable && !callable.getCallableRegion(); + }; + if (!getSolverConfig().isInterprocedural() || isExternalCallable()) { return visitCallControlFlowTransfer( call, CallControlFlowAction::ExternalCallee, before, after); } @@ -290,6 +292,12 @@ AbstractDenseBackwardDataFlowAnalysis::visit(ProgramPoint *point) { void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation( CallOpInterface call, const AbstractDenseLattice &after, AbstractDenseLattice *before) { + // If the solver is not interprocedural, let the hook handle it as an external + // callee. + if (!getSolverConfig().isInterprocedural()) + return visitCallControlFlowTransfer( + call, CallControlFlowAction::ExternalCallee, after, before); + // Find the callee. Operation *callee = call.resolveCallableInTable(&symbolTable); @@ -297,12 +305,10 @@ void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation( // No region means the callee is only declared in this module. // If that is the case or if the solver is not interprocedural, // let the hook handle it. - if (!getSolverConfig().isInterprocedural() || - (callable && (!callable.getCallableRegion() || - callable.getCallableRegion()->empty()))) { + if (callable && + (!callable.getCallableRegion() || callable.getCallableRegion()->empty())) return visitCallControlFlowTransfer( call, CallControlFlowAction::ExternalCallee, after, before); - } if (!callable) return setToExitState(before); diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp index 13a3e14..0d2e2ed 100644 --- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp @@ -228,10 +228,12 @@ LogicalResult AbstractSparseForwardDataFlowAnalysis::visitCallOperation( ArrayRef<AbstractSparseLattice *> resultLattices) { // If the call operation is to an external function, attempt to infer the // results from the call arguments. - auto callable = - dyn_cast_if_present<CallableOpInterface>(call.resolveCallable()); - if (!getSolverConfig().isInterprocedural() || - (callable && !callable.getCallableRegion())) { + auto isExternalCallable = [&]() { + auto callable = + dyn_cast_if_present<CallableOpInterface>(call.resolveCallable()); + return callable && !callable.getCallableRegion(); + }; + if (!getSolverConfig().isInterprocedural() || isExternalCallable()) { visitExternalCallImpl(call, operandLattices, resultLattices); return success(); } diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp index 7e1b405..9352ab0 100644 --- a/mlir/lib/Analysis/DataFlowFramework.cpp +++ b/mlir/lib/Analysis/DataFlowFramework.cpp @@ -9,6 +9,7 @@ #include "mlir/Analysis/DataFlowFramework.h" #include "mlir/IR/Location.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/SymbolTable.h" #include "mlir/IR/Value.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/iterator.h" @@ -109,6 +110,12 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) { isRunning = true; auto guard = llvm::make_scope_exit([&]() { isRunning = false; }); + bool isInterprocedural = config.isInterprocedural(); + auto restoreInterprocedural = llvm::make_scope_exit( + [&]() { config.setInterprocedural(isInterprocedural); }); + if (isInterprocedural && !top->hasTrait<OpTrait::SymbolTable>()) + config.setInterprocedural(false); + // Initialize equivalent lattice anchors. for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) { analysis.initializeEquivalentLatticeAnchor(top); diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index d7282b3..a14f09f 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -52,9 +52,14 @@ NB_MODULE(_mlir, m) { [](PyGlobals &self, bool enabled) { self.getTracebackLoc().setLocTracebacksEnabled(enabled); }) + .def("loc_tracebacks_frame_limit", + [](PyGlobals &self) { + return self.getTracebackLoc().locTracebackFramesLimit(); + }) .def("set_loc_tracebacks_frame_limit", - [](PyGlobals &self, int n) { - self.getTracebackLoc().setLocTracebackFramesLimit(n); + [](PyGlobals &self, std::optional<int> n) { + self.getTracebackLoc().setLocTracebackFramesLimit( + n.value_or(PyGlobals::TracebackLoc::kMaxFrames)); }) .def("register_traceback_file_inclusion", [](PyGlobals &self, const std::string &filename) { diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index 6ee85e8..47ef5d8 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -57,6 +57,13 @@ private: /// Create the `mlir.passmanager` here. void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { //---------------------------------------------------------------------------- + // Mapping of MlirExternalPass + //---------------------------------------------------------------------------- + nb::class_<MlirExternalPass>(m, "ExternalPass") + .def("signal_pass_failure", + [](MlirExternalPass pass) { mlirExternalPassSignalFailure(pass); }); + + //---------------------------------------------------------------------------- // Mapping of the top-level PassManager //---------------------------------------------------------------------------- nb::class_<PyPassManager>(m, "PassManager") @@ -182,9 +189,9 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { callbacks.clone = [](void *) -> void * { throw std::runtime_error("Cloning Python passes not supported"); }; - callbacks.run = [](MlirOperation op, MlirExternalPass, + callbacks.run = [](MlirOperation op, MlirExternalPass pass, void *userData) { - nb::borrow<nb::callable>(static_cast<PyObject *>(userData))(op); + nb::handle(static_cast<PyObject *>(userData))(op, pass); }; auto externalPass = mlirCreateExternalPass( passID, mlirStringRefCreate(name->data(), name->length()), diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp index d29053a..1659437 100644 --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -22,8 +22,6 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Endian.h" -#include "llvm/Support/Format.h" -#include "llvm/Support/LogicalResult.h" #include "llvm/Support/MemoryBufferRef.h" #include "llvm/Support/SourceMgr.h" @@ -296,12 +294,38 @@ public: if (failed(parseVarInt(alignment))) return failure(); - // Check that the requested alignment is less than or equal to the - // alignment of the root buffer. If it is not, we cannot safely guarantee - // that the specified alignment is globally correct. + // Check that the requested alignment must not exceed the alignment of + // the root buffer itself. Otherwise we cannot guarantee that pointers + // derived from this buffer will actually satisfy the requested alignment + // globally. // - // E.g. if the buffer is 8k aligned and the section is 16k aligned, - // we could end up at an offset of 24k, which is not globally 16k aligned. + // Consider a bytecode buffer that is guaranteed to be 8k aligned, but not + // 16k aligned (e.g. absolute address 40960. If a section inside this + // buffer declares a 16k alignment requirement, two problems can arise: + // + // (a) If we "align forward" the current pointer to the next + // 16k boundary, the amount of padding we skip depends on the + // buffer's starting address. For example: + // + // buffer_start = 40960 + // next 16k boundary = 49152 + // bytes skipped = 49152 - 40960 = 8192 + // + // This leaves behind variable padding that could be misinterpreted + // as part of the next section. + // + // (b) If we align relative to the buffer start, we may + // obtain addresses that are multiples of "buffer_start + + // section_alignment" rather than truly globally aligned + // addresses. For example: + // + // buffer_start = 40960 (5×8k, 8k aligned but not 16k) + // offset = 16384 (first multiple of 16k) + // section_ptr = 40960 + 16384 = 57344 + // + // 57344 is 8k aligned but not 16k aligned. + // Any consumer expecting true 16k alignment would see this as a + // violation. if (failed(alignmentValidator(alignment))) return emitError("failed to align section ID: ", unsigned(sectionID)); diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp index 807d1f5..bbfa3d1 100644 --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -36,7 +36,6 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinAttributes.h" -#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -287,19 +286,7 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> { // code. struct LowerGpuOpsToROCDLOpsPass final : public impl::ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> { - LowerGpuOpsToROCDLOpsPass() = default; - LowerGpuOpsToROCDLOpsPass(const std::string &chipset, unsigned indexBitwidth, - bool useBarePtrCallConv, - gpu::amd::Runtime runtime) { - if (this->chipset.getNumOccurrences() == 0) - this->chipset = chipset; - if (this->indexBitwidth.getNumOccurrences() == 0) - this->indexBitwidth = indexBitwidth; - if (this->useBarePtrCallConv.getNumOccurrences() == 0) - this->useBarePtrCallConv = useBarePtrCallConv; - if (this->runtime.getNumOccurrences() == 0) - this->runtime = runtime; - } + using Base::Base; void getDependentDialects(DialectRegistry ®istry) const override { Base::getDependentDialects(registry); @@ -499,12 +486,3 @@ void mlir::populateGpuToROCDLConversionPatterns( populateMathToROCDLConversionPatterns(converter, patterns); } - -std::unique_ptr<OperationPass<gpu::GPUModuleOp>> -mlir::createLowerGpuOpsToROCDLOpsPass(const std::string &chipset, - unsigned indexBitwidth, - bool useBarePtrCallConv, - gpu::amd::Runtime runtime) { - return std::make_unique<LowerGpuOpsToROCDLOpsPass>( - chipset, indexBitwidth, useBarePtrCallConv, runtime); -} diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp index e5496e5..aa47e39 100644 --- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp +++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp @@ -405,7 +405,8 @@ std::unique_ptr<MPIImplTraits> MPIImplTraits::get(ModuleOp &moduleOp) { return std::make_unique<OMPIImplTraits>(moduleOp); if (!strAttr || strAttr.getValue() != "MPICH") moduleOp.emitWarning() << "Unknown \"MPI:Implementation\" value in DLTI (" - << strAttr.getValue() << "), defaulting to MPICH"; + << (strAttr ? strAttr.getValue() : "<NULL>") + << "), defaulting to MPICH"; return std::make_unique<MPICHImplTraits>(moduleOp); } diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp index c4a9fc2..460595b 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -492,8 +492,10 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { // Create loop nest and populate region with contents of scf.parallel. auto loopOp = omp::LoopNestOp::create( - rewriter, parallelOp.getLoc(), parallelOp.getLowerBound(), - parallelOp.getUpperBound(), parallelOp.getStep()); + rewriter, parallelOp.getLoc(), parallelOp.getLowerBound().size(), + parallelOp.getLowerBound(), parallelOp.getUpperBound(), + parallelOp.getStep(), /*loop_inclusive=*/false, + /*tile_sizes=*/nullptr); rewriter.inlineRegionBefore(parallelOp.getRegion(), loopOp.getRegion(), loopOp.getRegion().begin()); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 9852df6..0b44ca7 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -95,6 +95,7 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorRankReducingFMAPattern(patterns); populateVectorGatherLoweringPatterns(patterns); populateVectorFromElementsLoweringPatterns(patterns); + populateVectorToElementsLoweringPatterns(patterns); if (armI8MM) { if (armNeon) arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns); diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index a7f2dc2..9ead1d8 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -154,6 +154,9 @@ class CreateNdDescToXeVMPattern matchAndRewrite(xegpu::CreateNdDescOp op, xegpu::CreateNdDescOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { + SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets(); + if (mixedOffsets.size() != 0) + return rewriter.notifyMatchFailure(op, "Offsets not supported."); auto loc = op.getLoc(); auto source = op.getSource(); // Op is lowered to a code sequence that populates payload. @@ -177,7 +180,6 @@ class CreateNdDescToXeVMPattern // Source can be a memref or a pointer (ui64, ui32, i64 or i32). SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes(); - SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets(); // Descriptor shape is expected to be 2D. int64_t rank = mixedSizes.size(); if (rank != 2) @@ -202,17 +204,9 @@ class CreateNdDescToXeVMPattern val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val); return val; }; - // Offsets can be either 2D or not provided (0 is used). - if (mixedOffsets.size() == 2) { - offsetW = createOffset(mixedOffsets, 1); - offsetH = createOffset(mixedOffsets, 0); - } else if (mixedOffsets.size() == 0) { - offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0); - offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0); - } else { - return rewriter.notifyMatchFailure(op, - "Expected 2D offsets or no offsets."); - } + // Offsets are not supported (0 is used). + offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0); + offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0); // Get shape values from op fold results. baseShapeW = createOffset(mixedSizes, 1); baseShapeH = createOffset(mixedSizes, 0); @@ -247,39 +241,6 @@ class CreateNdDescToXeVMPattern } }; -class UpdateNdOffsetToXeVMPattern - : public OpConversionPattern<xegpu::UpdateNdOffsetOp> { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(xegpu::UpdateNdOffsetOp op, - xegpu::UpdateNdOffsetOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto mixedOffsets = op.getMixedOffsets(); - // Only 2D offsets are supported for now. - if (mixedOffsets.size() != 2) - return rewriter.notifyMatchFailure(op, "Expected 2D offsets."); - auto payload = adaptor.getTensorDesc(); - // Utility for updating payload offset values from op fold result. - auto updateOffset = [&](unsigned idx, int payloadPos) -> Value { - Value offset = - getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[idx]); - offset = getValueOrCreateCastToIndexLike(rewriter, loc, - rewriter.getI32Type(), offset); - Value oldOffset = - vector::ExtractOp::create(rewriter, loc, payload, payloadPos); - Value newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset); - return vector::InsertOp::create(rewriter, loc, newOffset, payload, - payloadPos); - }; - // Update offsets in the payload. - payload = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH)); - payload = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW)); - rewriter.replaceOp(op, payload); - return success(); - } -}; - template < typename OpType, typename = std::enable_if_t<llvm::is_one_of< @@ -289,6 +250,10 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> { LogicalResult matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { + auto mixedOffsets = op.getMixedOffsets(); + int64_t opOffsetsSize = mixedOffsets.size(); + if (opOffsetsSize != 2) + return rewriter.notifyMatchFailure(op, "Expected 2D offsets."); auto loc = op.getLoc(); auto ctxt = rewriter.getContext(); @@ -311,32 +276,16 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> { rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW)); Value baseShapeH = vector::ExtractOp::create( rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH)); - // Offsets provided in two ways: - // 1. Offsets are extracted from the tensor descriptor. - // 2. (Mixed) offsets which are provided by the op. - Value offsetW; - Value offsetH; - auto mixedOffsets = op.getMixedOffsets(); - int64_t opOffsetsSize = mixedOffsets.size(); - if (opOffsetsSize != 0 && opOffsetsSize != 2) - return rewriter.notifyMatchFailure(op, - "Expected 2D offsets or no offsets."); - if (opOffsetsSize) { - // If mixed offsets are provided by the op convert them to i32. - offsetW = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]); - offsetW = getValueOrCreateCastToIndexLike(rewriter, loc, - rewriter.getI32Type(), offsetW); - offsetH = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]); - offsetH = getValueOrCreateCastToIndexLike(rewriter, loc, - rewriter.getI32Type(), offsetH); - } else { - // If offsets are not available, we need to extract them from the tensor - // descriptor. - offsetW = vector::ExtractOp::create( - rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetW)); - offsetH = vector::ExtractOp::create( - rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetH)); - } + // Offsets are provided by the op. + // convert them to i32. + Value offsetW = + getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]); + offsetW = getValueOrCreateCastToIndexLike(rewriter, loc, + rewriter.getI32Type(), offsetW); + Value offsetH = + getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]); + offsetH = getValueOrCreateCastToIndexLike(rewriter, loc, + rewriter.getI32Type(), offsetH); // Get address space from tensor descriptor memory space. auto ptrTypeLLVM = LLVM::LLVMPointerType::get( ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace())); @@ -422,54 +371,6 @@ static Value addOffset(ConversionPatternRewriter &rewriter, Location loc, return newAddr; } -class CreateDescToXeVMPattern - : public OpConversionPattern<xegpu::CreateDescOp> { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(xegpu::CreateDescOp op, xegpu::CreateDescOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto eTy = op.getTensorDescType().getElementType(); - auto eBw = eTy.getIntOrFloatBitWidth(); - if (eBw % 8 != 0) - return rewriter.notifyMatchFailure( - op, "Expected element type bit width to be multiple of 8."); - auto loc = op.getLoc(); - // Offsets are provided as scalar i64 by type converter. - auto offsets = adaptor.getOffsets(); - // Source type can be a 1D memref or pointer type (ui64, ui32, i64 or i32). - // But type converter will convert them to integer types. - Value addr = adaptor.getSource(); - // ui32 or i32 are passed as i32 so they need to be casted to i64. - if (addr.getType() != rewriter.getI64Type()) - addr = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), addr); - auto laneAddr = addOffset(rewriter, loc, addr, offsets, eBw / 8); - rewriter.replaceOp(op, laneAddr); - return success(); - } -}; - -class UpdateOffsetToXeVMPattern - : public OpConversionPattern<xegpu::UpdateOffsetOp> { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(xegpu::UpdateOffsetOp op, - xegpu::UpdateOffsetOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto eTy = op.getTensorDescType().getElementType(); - auto eBw = eTy.getIntOrFloatBitWidth(); - if (eBw % 8 != 0) - return rewriter.notifyMatchFailure( - op, "Expected element type bit width to be multiple of 8."); - auto loc = op.getLoc(); - // Scatter descriptor is provided as scalar i64 by type converter. - // Offsets are provided as scalar i64 by type converter. - Value newOffset = addOffset(rewriter, loc, adaptor.getTensorDesc(), - adaptor.getOffsets(), eBw / 8); - rewriter.replaceOp(op, newOffset); - return success(); - } -}; - template <typename OpType, typename = std::enable_if_t<llvm::is_one_of< OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>> @@ -478,6 +379,9 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> { LogicalResult matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { + Value offset = adaptor.getOffsets(); + if (!offset) + return rewriter.notifyMatchFailure(op, "Expected offset to be provided."); auto loc = op.getLoc(); auto ctxt = rewriter.getContext(); auto tdescTy = op.getTensorDescType(); @@ -527,21 +431,16 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> { basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), basePtrI64); } - Value offsets = adaptor.getOffsets(); Value mask = adaptor.getMask(); - if (offsets) { - if (dyn_cast<VectorType>(offsets.getType())) { - // Offset needs be scalar. Single element vector is converted to scalar - // by type converter. - return rewriter.notifyMatchFailure(op, - "Expected offsets to be a scalar."); - } else { - // If offsets are provided, we add them to the base pointer. - // Offsets are in number of elements, we need to multiply by - // element byte size. - basePtrI64 = - addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize); - } + if (dyn_cast<VectorType>(offset.getType())) { + // Offset needs be scalar. Single element vector is converted to scalar + // by type converter. + return rewriter.notifyMatchFailure(op, "Expected offset to be a scalar."); + } else { + // If offset is provided, we add them to the base pointer. + // Offset is in number of elements, we need to multiply by + // element byte size. + basePtrI64 = addOffset(rewriter, loc, basePtrI64, offset, elemByteSize); } // Convert base pointer (i64) to LLVM pointer type. Value basePtrLLVM = @@ -1011,13 +910,12 @@ struct ConvertXeGPUToXeVMPass //===----------------------------------------------------------------------===// void mlir::populateXeGPUToXeVMConversionPatterns( const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) { - patterns.add<CreateNdDescToXeVMPattern, UpdateNdOffsetToXeVMPattern, + patterns.add<CreateNdDescToXeVMPattern, LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>, LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>, LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>( typeConverter, patterns.getContext()); - patterns.add<CreateDescToXeVMPattern, UpdateOffsetToXeVMPattern, - AtomicRMWToXeVMPattern, PrefetchToXeVMPattern, + patterns.add<AtomicRMWToXeVMPattern, PrefetchToXeVMPattern, LoadStoreToXeVMPattern<xegpu::LoadGatherOp>, LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>( typeConverter, patterns.getContext()); diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp index 4dfcb2b..0f90acf 100644 --- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp +++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp @@ -98,127 +98,179 @@ std::string mangle(StringRef baseName, ArrayRef<Type> types, return os.str(); } -template <bool isLoad, typename OpType> -int32_t getL1CacheControl(OpType op) { +static int32_t getL1CacheControl(LoadCacheControl cc) { int32_t control = 0; - if constexpr (isLoad) { - switch (*op.getCacheControl()) { - case LoadCacheControl::L1UC_L2UC_L3UC: - case LoadCacheControl::L1UC_L2UC_L3C: - case LoadCacheControl::L1UC_L2C_L3UC: - case LoadCacheControl::L1UC_L2C_L3C: - control = 1; - break; - case LoadCacheControl::L1C_L2UC_L3UC: - case LoadCacheControl::L1C_L2UC_L3C: - case LoadCacheControl::L1C_L2C_L3UC: - case LoadCacheControl::L1C_L2C_L3C: - control = 2; - break; - case LoadCacheControl::L1S_L2UC_L3UC: - case LoadCacheControl::L1S_L2UC_L3C: - case LoadCacheControl::L1S_L2C_L3UC: - case LoadCacheControl::L1S_L2C_L3C: - control = 3; - break; - case LoadCacheControl::INVALIDATE_READ: - control = 4; - break; - } - } else { - switch (*op.getCacheControl()) { - case StoreCacheControl::L1UC_L2UC_L3UC: - case StoreCacheControl::L1UC_L2UC_L3WB: - case StoreCacheControl::L1UC_L2WB_L3UC: - case StoreCacheControl::L1UC_L2WB_L3WB: - control = 1; - break; - case StoreCacheControl::L1WT_L2UC_L3UC: - case StoreCacheControl::L1WT_L2UC_L3WB: - case StoreCacheControl::L1WT_L2WB_L3UC: - case StoreCacheControl::L1WT_L2WB_L3WB: - control = 2; - break; - case StoreCacheControl::L1S_L2UC_L3UC: - case StoreCacheControl::L1S_L2UC_L3WB: - case StoreCacheControl::L1S_L2WB_L3UC: - case StoreCacheControl::L1S_L2WB_L3WB: - control = 3; - break; - case StoreCacheControl::L1WB_L2UC_L3UC: - case StoreCacheControl::L1WB_L2WB_L3UC: - case StoreCacheControl::L1WB_L2UC_L3WB: - control = 4; - break; - } + switch (cc) { + case LoadCacheControl::L1UC_L2UC_L3UC: + case LoadCacheControl::L1UC_L2UC_L3C: + case LoadCacheControl::L1UC_L2C_L3UC: + case LoadCacheControl::L1UC_L2C_L3C: + control = 1; + break; + case LoadCacheControl::L1C_L2UC_L3UC: + case LoadCacheControl::L1C_L2UC_L3C: + case LoadCacheControl::L1C_L2C_L3UC: + case LoadCacheControl::L1C_L2C_L3C: + control = 2; + break; + case LoadCacheControl::L1S_L2UC_L3UC: + case LoadCacheControl::L1S_L2UC_L3C: + case LoadCacheControl::L1S_L2C_L3UC: + case LoadCacheControl::L1S_L2C_L3C: + control = 3; + break; + case LoadCacheControl::INVALIDATE_READ: + control = 4; + break; } return control; } -template <bool isLoad, typename OpType> -int32_t getL3CacheControl(OpType op) { +static int32_t getL1CacheControl(StoreCacheControl cc) { int32_t control = 0; - if constexpr (isLoad) { - switch (*op.getCacheControl()) { - case LoadCacheControl::L1UC_L2UC_L3UC: - case LoadCacheControl::L1UC_L2C_L3UC: - case LoadCacheControl::L1C_L2UC_L3UC: - case LoadCacheControl::L1C_L2C_L3UC: - case LoadCacheControl::L1S_L2UC_L3UC: - case LoadCacheControl::L1S_L2C_L3UC: - control = 1; - break; - case LoadCacheControl::L1UC_L2UC_L3C: - case LoadCacheControl::L1UC_L2C_L3C: - case LoadCacheControl::L1C_L2UC_L3C: - case LoadCacheControl::L1C_L2C_L3C: - case LoadCacheControl::L1S_L2UC_L3C: - case LoadCacheControl::L1S_L2C_L3C: - control = 2; - break; - case LoadCacheControl::INVALIDATE_READ: - control = 4; - break; - } - } else { - switch (*op.getCacheControl()) { - case StoreCacheControl::L1UC_L2UC_L3UC: - case StoreCacheControl::L1UC_L2WB_L3UC: - case StoreCacheControl::L1WT_L2UC_L3UC: - case StoreCacheControl::L1WT_L2WB_L3UC: - case StoreCacheControl::L1S_L2UC_L3UC: - case StoreCacheControl::L1S_L2WB_L3UC: - case StoreCacheControl::L1WB_L2UC_L3UC: - case StoreCacheControl::L1WB_L2WB_L3UC: - control = 1; - break; - case StoreCacheControl::L1UC_L2UC_L3WB: - case StoreCacheControl::L1UC_L2WB_L3WB: - case StoreCacheControl::L1WT_L2UC_L3WB: - case StoreCacheControl::L1WT_L2WB_L3WB: - case StoreCacheControl::L1S_L2UC_L3WB: - case StoreCacheControl::L1S_L2WB_L3WB: - case StoreCacheControl::L1WB_L2UC_L3WB: - control = 2; - break; - } + switch (cc) { + case StoreCacheControl::L1UC_L2UC_L3UC: + case StoreCacheControl::L1UC_L2UC_L3WB: + case StoreCacheControl::L1UC_L2WB_L3UC: + case StoreCacheControl::L1UC_L2WB_L3WB: + control = 1; + break; + case StoreCacheControl::L1WT_L2UC_L3UC: + case StoreCacheControl::L1WT_L2UC_L3WB: + case StoreCacheControl::L1WT_L2WB_L3UC: + case StoreCacheControl::L1WT_L2WB_L3WB: + control = 2; + break; + case StoreCacheControl::L1S_L2UC_L3UC: + case StoreCacheControl::L1S_L2UC_L3WB: + case StoreCacheControl::L1S_L2WB_L3UC: + case StoreCacheControl::L1S_L2WB_L3WB: + control = 3; + break; + case StoreCacheControl::L1WB_L2UC_L3UC: + case StoreCacheControl::L1WB_L2WB_L3UC: + case StoreCacheControl::L1WB_L2UC_L3WB: + control = 4; + break; } return control; } -template <bool isLoad, typename OpType> +static int32_t getL3CacheControl(LoadCacheControl cc) { + int32_t control = 0; + switch (cc) { + case LoadCacheControl::L1UC_L2UC_L3UC: + case LoadCacheControl::L1UC_L2C_L3UC: + case LoadCacheControl::L1C_L2UC_L3UC: + case LoadCacheControl::L1C_L2C_L3UC: + case LoadCacheControl::L1S_L2UC_L3UC: + case LoadCacheControl::L1S_L2C_L3UC: + control = 1; + break; + case LoadCacheControl::L1UC_L2UC_L3C: + case LoadCacheControl::L1UC_L2C_L3C: + case LoadCacheControl::L1C_L2UC_L3C: + case LoadCacheControl::L1C_L2C_L3C: + case LoadCacheControl::L1S_L2UC_L3C: + case LoadCacheControl::L1S_L2C_L3C: + control = 2; + break; + case LoadCacheControl::INVALIDATE_READ: + control = 4; + break; + } + return control; +} + +static int32_t getL3CacheControl(StoreCacheControl cc) { + int32_t control = 0; + switch (cc) { + case StoreCacheControl::L1UC_L2UC_L3UC: + case StoreCacheControl::L1UC_L2WB_L3UC: + case StoreCacheControl::L1WT_L2UC_L3UC: + case StoreCacheControl::L1WT_L2WB_L3UC: + case StoreCacheControl::L1S_L2UC_L3UC: + case StoreCacheControl::L1S_L2WB_L3UC: + case StoreCacheControl::L1WB_L2UC_L3UC: + case StoreCacheControl::L1WB_L2WB_L3UC: + control = 1; + break; + case StoreCacheControl::L1UC_L2UC_L3WB: + case StoreCacheControl::L1UC_L2WB_L3WB: + case StoreCacheControl::L1WT_L2UC_L3WB: + case StoreCacheControl::L1WT_L2WB_L3WB: + case StoreCacheControl::L1S_L2UC_L3WB: + case StoreCacheControl::L1S_L2WB_L3WB: + case StoreCacheControl::L1WB_L2UC_L3WB: + control = 2; + break; + } + return control; +} + +static std::optional<LoadCacheControl> getCacheControl(PrefetchOp op) { + return op.getCacheControl(); +} + +static std::optional<LoadCacheControl> getCacheControl(BlockLoad2dOp op) { + return op.getCacheControl(); +} + +static std::optional<LoadCacheControl> getCacheControl(BlockPrefetch2dOp op) { + return op.getCacheControl(); +} + +static std::optional<StoreCacheControl> getCacheControl(BlockStore2dOp op) { + return op.getCacheControl(); +} + +static std::optional<LoadCacheControl> getCacheControl(LLVM::LoadOp op) { + if (op->hasAttr("cache_control")) { + auto attr = op->getAttrOfType<xevm::LoadCacheControlAttr>("cache_control"); + if (!attr) + return std::nullopt; + return std::optional<LoadCacheControl>(attr.getValue()); + } + return std::nullopt; +} + +static std::optional<StoreCacheControl> getCacheControl(LLVM::StoreOp op) { + if (op->hasAttr("cache_control")) { + auto attr = op->getAttrOfType<xevm::StoreCacheControlAttr>("cache_control"); + if (!attr) + return std::nullopt; + return std::optional<StoreCacheControl>(attr.getValue()); + } + return std::nullopt; +} + +template <typename OpType> +int32_t getL1CacheControl(OpType op) { + return getL1CacheControl(*getCacheControl(op)); +} + +template <typename OpType> +int32_t getL3CacheControl(OpType op) { + return getL3CacheControl(*getCacheControl(op)); +} + +template <typename OpType> static std::optional<ArrayAttr> getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) { - if (!op.getCacheControl()) + if (!getCacheControl(op)) return {}; constexpr int32_t decorationCacheControlArity{4}; constexpr int32_t loadCacheControlKey{6442}; constexpr int32_t storeCacheControlKey{6443}; + constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> || + std::is_same_v<OpType, BlockPrefetch2dOp> || + std::is_same_v<OpType, LLVM::LoadOp> || + std::is_same_v<OpType, PrefetchOp>; const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey}; SmallVector<int32_t, decorationCacheControlArity> decorationsL1{ - controlKey, 0, getL1CacheControl<isLoad, OpType>(op), 0}; + controlKey, 0, getL1CacheControl<OpType>(op), 0}; SmallVector<int32_t, decorationCacheControlArity> decorationsL3{ - controlKey, 1, getL3CacheControl<isLoad, OpType>(op), 0}; + controlKey, 1, getL3CacheControl<OpType>(op), 0}; auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1); auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3); @@ -398,7 +450,7 @@ class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> { rewriter, fnName, LLVM::LLVMVoidType::get(rewriter.getContext()), argTypes, args, {}, funcAttr, op.getOperation()); if (std::optional<ArrayAttr> optCacheControls = - getCacheControlMetadata<true>(rewriter, op)) + getCacheControlMetadata(rewriter, op)) call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls); rewriter.eraseOp(op); return success(); @@ -557,7 +609,7 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> { rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()), argTypes, args, paramAttrs, funcAttr, op.getOperation()); if (std::optional<ArrayAttr> optCacheControls = - getCacheControlMetadata < isLoad || isPrefetch > (rewriter, op)) { + getCacheControlMetadata(rewriter, op)) { call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls); } if constexpr (isLoad) @@ -568,6 +620,21 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> { return success(); } }; +template <typename OpType> +class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> { + using OpConversionPattern<OpType>::OpConversionPattern; + LogicalResult + matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op->hasAttr("cache_control")) + return failure(); + std::optional<ArrayAttr> optCacheControls = + getCacheControlMetadata(rewriter, op); + op->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls); + op->removeAttr("cache_control"); + return success(); + } +}; //===----------------------------------------------------------------------===// // Pass Definition @@ -583,10 +650,8 @@ struct ConvertXeVMToLLVMPass void runOnOperation() override { ConversionTarget target(getContext()); - target.addLegalDialect<LLVM::LLVMDialect>(); - target.addIllegalDialect<XeVMDialect>(); RewritePatternSet patterns(&getContext()); - populateXeVMToLLVMConversionPatterns(patterns); + populateXeVMToLLVMConversionPatterns(target, patterns); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); @@ -611,7 +676,7 @@ struct XeVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface { void populateConvertToLLVMConversionPatterns( ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const final { - populateXeVMToLLVMConversionPatterns(patterns); + populateXeVMToLLVMConversionPatterns(target, patterns); } }; } // namespace @@ -620,12 +685,17 @@ struct XeVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface { // Pattern Population //===----------------------------------------------------------------------===// -void ::mlir::populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns) { +void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target, + RewritePatternSet &patterns) { + target.addDynamicallyLegalDialect<LLVM::LLVMDialect>( + [](Operation *op) { return !op->hasAttr("cache_control"); }); + target.addIllegalDialect<XeVMDialect>(); patterns.add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>, LoadStorePrefetchToOCLPattern<BlockStore2dOp>, LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>, - MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern>( - patterns.getContext()); + MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern, + LLVMLoadStoreToOCLPattern<LLVM::LoadOp>, + LLVMLoadStoreToOCLPattern<LLVM::StoreOp>>(patterns.getContext()); } void ::mlir::registerConvertXeVMToLLVMInterface(DialectRegistry ®istry) { diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp index 99ea20b..f38493b 100644 --- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/IntegerSet.h" #include "llvm/ADT/SetVector.h" @@ -241,7 +242,98 @@ addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg, return &node; } -bool MemRefDependenceGraph::init() { +/// Returns the memref being read/written by a memref/affine load/store op. +static Value getMemRef(Operation *memOp) { + if (auto memrefLoad = dyn_cast<memref::LoadOp>(memOp)) + return memrefLoad.getMemRef(); + if (auto affineLoad = dyn_cast<AffineReadOpInterface>(memOp)) + return affineLoad.getMemRef(); + if (auto memrefStore = dyn_cast<memref::StoreOp>(memOp)) + return memrefStore.getMemRef(); + if (auto affineStore = dyn_cast<AffineWriteOpInterface>(memOp)) + return affineStore.getMemRef(); + llvm_unreachable("unexpected op"); +} + +/// Returns true if there may be a dependence on `memref` from srcNode's +/// memory ops to dstNode's memory ops, while using the affine memory +/// dependence analysis checks. The method assumes that there is at least one +/// memory op in srcNode's loads and stores on `memref`, and similarly for +/// `dstNode`. `srcNode.op` and `destNode.op` are expected to be nested in the +/// same block and so the dependences are tested at the depth of that block. +static bool mayDependence(const Node &srcNode, const Node &dstNode, + Value memref) { + assert(srcNode.op->getBlock() == dstNode.op->getBlock()); + if (!isa<AffineForOp>(srcNode.op) || !isa<AffineForOp>(dstNode.op)) + return true; + + // Conservatively handle dependences involving non-affine load/stores. Return + // true if there exists a conflicting read/write access involving such. + + // Check whether there is a dependence from a source read/write op to a + // destination read/write one; all expected to be memref/affine load/store. + auto hasNonAffineDep = [&](ArrayRef<Operation *> srcMemOps, + ArrayRef<Operation *> dstMemOps) { + return llvm::any_of(srcMemOps, [&](Operation *srcOp) { + Value srcMemref = getMemRef(srcOp); + if (srcMemref != memref) + return false; + return llvm::find_if(dstMemOps, [&](Operation *dstOp) { + return srcMemref == getMemRef(dstOp); + }) != dstMemOps.end(); + }); + }; + + SmallVector<Operation *> dstOps; + // Between non-affine src stores and dst load/store. + llvm::append_range(dstOps, llvm::concat<Operation *const>( + dstNode.loads, dstNode.stores, + dstNode.memrefLoads, dstNode.memrefStores)); + if (hasNonAffineDep(srcNode.memrefStores, dstOps)) + return true; + // Between non-affine loads and dst stores. + dstOps.clear(); + llvm::append_range(dstOps, llvm::concat<Operation *const>( + dstNode.stores, dstNode.memrefStores)); + if (hasNonAffineDep(srcNode.memrefLoads, dstOps)) + return true; + // Between affine stores and memref load/stores. + dstOps.clear(); + llvm::append_range(dstOps, llvm::concat<Operation *const>( + dstNode.memrefLoads, dstNode.memrefStores)); + if (hasNonAffineDep(srcNode.stores, dstOps)) + return true; + // Between affine loads and memref stores. + dstOps.clear(); + llvm::append_range(dstOps, dstNode.memrefStores); + if (hasNonAffineDep(srcNode.loads, dstOps)) + return true; + + // Affine load/store pairs. We don't need to check for locally allocated + // memrefs since the dependence analysis here is between mem ops from + // srcNode's for op to dstNode's for op at the depth at which those + // `affine.for` ops are nested, i.e., dependences at depth `d + 1` where + // `d` is the number of common surrounding loops. + for (auto *srcMemOp : + llvm::concat<Operation *const>(srcNode.stores, srcNode.loads)) { + MemRefAccess srcAcc(srcMemOp); + if (srcAcc.memref != memref) + continue; + for (auto *destMemOp : + llvm::concat<Operation *const>(dstNode.stores, dstNode.loads)) { + MemRefAccess destAcc(destMemOp); + if (destAcc.memref != memref) + continue; + // Check for a top-level dependence between srcNode and destNode's ops. + if (!noDependence(checkMemrefAccessDependence( + srcAcc, destAcc, getNestingDepth(srcNode.op) + 1))) + return true; + } + } + return false; +} + +bool MemRefDependenceGraph::init(bool fullAffineDependences) { LDBG() << "--- Initializing MDG ---"; // Map from a memref to the set of ids of the nodes that have ops accessing // the memref. @@ -344,8 +436,12 @@ bool MemRefDependenceGraph::init() { Node *dstNode = getNode(dstId); bool dstHasStoreOrFree = dstNode->hasStore(srcMemRef) || dstNode->hasFree(srcMemRef); - if (srcHasStoreOrFree || dstHasStoreOrFree) - addEdge(srcId, dstId, srcMemRef); + if ((srcHasStoreOrFree || dstHasStoreOrFree)) { + // Check precise affine deps if asked for; otherwise, conservative. + if (!fullAffineDependences || + mayDependence(*srcNode, *dstNode, srcMemRef)) + addEdge(srcId, dstId, srcMemRef); + } } } } @@ -562,13 +658,13 @@ MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId, } // Build set of insts in range (srcId, dstId) which depend on 'srcId'. - SmallPtrSet<Operation *, 2> srcDepInsts; + llvm::SmallPtrSet<Operation *, 2> srcDepInsts; for (auto &outEdge : outEdges.lookup(srcId)) if (outEdge.id != dstId) srcDepInsts.insert(getNode(outEdge.id)->op); // Build set of insts in range (srcId, dstId) on which 'dstId' depends. - SmallPtrSet<Operation *, 2> dstDepInsts; + llvm::SmallPtrSet<Operation *, 2> dstDepInsts; for (auto &inEdge : inEdges.lookup(dstId)) if (inEdge.id != srcId) dstDepInsts.insert(getNode(inEdge.id)->op); diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index f0c1f44..f3db8f7c 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3033,10 +3033,17 @@ void transform::TileReductionUsingForallOp::build( } DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne( - transform::TransformRewriter &rewriter, LinalgOp target, + transform::TransformRewriter &rewriter, Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { rewriter.setInsertionPoint(target); + + auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target); + if (!partialReductionOp) { + return emitSilenceableFailure( + target->getLoc(), + "Operation should implement PartialReductionOpInterface"); + } SmallVector<OpFoldResult> numThreads = getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads())); SmallVector<OpFoldResult> tileSizes = @@ -3058,14 +3065,14 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne( extractFromIntegerArrayAttr<unsigned>(getReductionDims()); if (reductionDims.empty()) { for (auto [idx, iteratorType] : - llvm::enumerate(target.getIteratorTypesArray())) { + llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) { if (iteratorType == utils::IteratorType::reduction) reductionDims.push_back(idx); } } options.setReductionDims(reductionDims); - FailureOr<scf::SCFTilingResult> result = scf::tileUsingSCF( - rewriter, cast<TilingInterface>(target.getOperation()), options); + FailureOr<scf::SCFTilingResult> result = + scf::tileUsingSCF(rewriter, partialReductionOp, options); if (failed(result)) { auto diag = emitSilenceableError() << "could not tile reduction"; diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp index b7da20c..9015cbb 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp @@ -208,7 +208,7 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite( // Does it require broadcast? if (!broadcastedDims.empty()) { - assert(broadcastedDims.size() && "should have non size broadcast"); + assert(!broadcastedDims.empty() && "should have non size broadcast"); Value emptyTensor = tensor::EmptyOp::create(rewriter, loc, outputShape, inputRTType.getElementType()); diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 6e43f28..3d70e28 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -56,6 +56,11 @@ makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef<bool> boolArray) { return boolArray.empty() ? nullptr : DenseBoolArrayAttr::get(ctx, boolArray); } +static DenseI64ArrayAttr +makeDenseI64ArrayAttr(MLIRContext *ctx, const ArrayRef<int64_t> intArray) { + return intArray.empty() ? nullptr : DenseI64ArrayAttr::get(ctx, intArray); +} + namespace { struct MemRefPointerLikeModel : public PointerLikeType::ExternalModel<MemRefPointerLikeModel, @@ -2956,10 +2961,10 @@ ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) { for (auto &iv : ivs) iv.type = loopVarType; + auto *ctx = parser.getBuilder().getContext(); // Parse "inclusive" flag. if (succeeded(parser.parseOptionalKeyword("inclusive"))) - result.addAttribute("loop_inclusive", - UnitAttr::get(parser.getBuilder().getContext())); + result.addAttribute("loop_inclusive", UnitAttr::get(ctx)); // Parse step values. SmallVector<OpAsmParser::UnresolvedOperand> steps; @@ -2967,6 +2972,35 @@ ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) { parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren)) return failure(); + // Parse collapse + int64_t value = 0; + if (!parser.parseOptionalKeyword("collapse") && + (parser.parseLParen() || parser.parseInteger(value) || + parser.parseRParen())) + return failure(); + if (value > 1) + result.addAttribute( + "collapse_num_loops", + IntegerAttr::get(parser.getBuilder().getI64Type(), value)); + + // Parse tiles + SmallVector<int64_t> tiles; + auto parseTiles = [&]() -> ParseResult { + int64_t tile; + if (parser.parseInteger(tile)) + return failure(); + tiles.push_back(tile); + return success(); + }; + + if (!parser.parseOptionalKeyword("tiles") && + (parser.parseLParen() || parser.parseCommaSeparatedList(parseTiles) || + parser.parseRParen())) + return failure(); + + if (tiles.size() > 0) + result.addAttribute("tile_sizes", DenseI64ArrayAttr::get(ctx, tiles)); + // Parse the body. Region *region = result.addRegion(); if (parser.parseRegion(*region, ivs)) @@ -2990,14 +3024,23 @@ void LoopNestOp::print(OpAsmPrinter &p) { if (getLoopInclusive()) p << "inclusive "; p << "step (" << getLoopSteps() << ") "; + if (int64_t numCollapse = getCollapseNumLoops()) + if (numCollapse > 1) + p << "collapse(" << numCollapse << ") "; + + if (const auto tiles = getTileSizes()) + p << "tiles(" << tiles.value() << ") "; + p.printRegion(region, /*printEntryBlockArgs=*/false); } void LoopNestOp::build(OpBuilder &builder, OperationState &state, const LoopNestOperands &clauses) { - LoopNestOp::build(builder, state, clauses.loopLowerBounds, - clauses.loopUpperBounds, clauses.loopSteps, - clauses.loopInclusive); + MLIRContext *ctx = builder.getContext(); + LoopNestOp::build(builder, state, clauses.collapseNumLoops, + clauses.loopLowerBounds, clauses.loopUpperBounds, + clauses.loopSteps, clauses.loopInclusive, + makeDenseI64ArrayAttr(ctx, clauses.tileSizes)); } LogicalResult LoopNestOp::verify() { @@ -3013,6 +3056,17 @@ LogicalResult LoopNestOp::verify() { << "range argument type does not match corresponding IV type"; } + uint64_t numIVs = getIVs().size(); + + if (const auto &numCollapse = getCollapseNumLoops()) + if (numCollapse > numIVs) + return emitOpError() + << "collapse value is larger than the number of loops"; + + if (const auto &tiles = getTileSizes()) + if (tiles.value().size() > numIVs) + return emitOpError() << "too few canonical loops for tile dimensions"; + if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp())) return emitOpError() << "expects parent op to be a loop wrapper"; @@ -3142,7 +3196,7 @@ void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { } SmallString<64> Name("canonloop"); - for (std::string s : reverse(components)) { + for (const std::string &s : reverse(components)) { Name += '_'; Name += s; } diff --git a/mlir/lib/Dialect/Quant/Transforms/NormalizeQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/NormalizeQuantTypes.cpp index 1c3971d2..1940e6d 100644 --- a/mlir/lib/Dialect/Quant/Transforms/NormalizeQuantTypes.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/NormalizeQuantTypes.cpp @@ -88,10 +88,10 @@ class NormalizedQuantTypesConverter : public TypeConverter { llvm::find_if(shape, [](int64_t dim) { return dim != 1; }); auto scales = llvm::to_vector(llvm::map_range( subChannelType.getScales().getValues<APFloat>(), - [](APFloat scale) { return scale.convertToDouble(); })); + [](const APFloat &scale) { return scale.convertToDouble(); })); auto zeroPoints = llvm::to_vector(llvm::map_range( subChannelType.getZeroPoints().getValues<APInt>(), - [](APInt zeroPoint) { return zeroPoint.getSExtValue(); })); + [](const APInt &zeroPoint) { return zeroPoint.getSExtValue(); })); auto perAxisType = UniformQuantizedPerAxisType::get( subChannelType.getFlags(), subChannelType.getStorageType(), subChannelType.getExpressedType(), scales, zeroPoints, diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index f993398..5511998 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -723,7 +723,9 @@ void mlir::spirv::ConstantOp::getAsmResultNames( IntegerType intTy = llvm::dyn_cast<IntegerType>(type); if (IntegerAttr intCst = llvm::dyn_cast<IntegerAttr>(getValue())) { - if (intTy && intTy.getWidth() == 1) { + assert(intTy); + + if (intTy.getWidth() == 1) { return setNameFn(getResult(), (intCst.getInt() ? "true" : "false")); } diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 87bed81..b4c87a3 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -512,20 +512,20 @@ void ReduceMinOp::print(OpAsmPrinter &parser) { // Tosa utilities. //===----------------------------------------------------------------------===// -std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) { +static std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) { if (lhs % rhs != 0) return std::nullopt; return lhs / rhs; } -Type getStorageElementTypeOrSelf(Type type) { +static Type getStorageElementTypeOrSelf(Type type) { auto srcType = getElementTypeOrSelf(type); if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType)) srcType = quantType.getStorageType(); return srcType; } -Type getStorageElementTypeOrSelf(Value value) { +static Type getStorageElementTypeOrSelf(Value value) { return getStorageElementTypeOrSelf(value.getType()); } diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 706076c..aba6178 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -1639,8 +1639,8 @@ transform::ForeachOp::apply(transform::TransformRewriter &rewriter, return a.size() < b.size(); })->size(); - for (size_t argIdx = 0; argIdx < payloads.size(); argIdx++) - payloads[argIdx].resize(numIterations); + for (auto &payload : payloads) + payload.resize(numIterations); } // As we will be "zipping" over them, check all payloads have the same size. diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp index fe066dc..6bb390a 100644 --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -144,6 +144,11 @@ void transform::ApplyUnrollFromElementsPatternsOp::populatePatterns( vector::populateVectorFromElementsLoweringPatterns(patterns); } +void transform::ApplyUnrollToElementsPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + vector::populateVectorToElementsLoweringPatterns(patterns); +} + void transform::ApplyLowerScanPatternsOp::populatePatterns( RewritePatternSet &patterns) { vector::populateVectorScanLoweringPatterns(patterns); diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt index acbf2b7..d74007f1 100644 --- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt @@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRVectorTransforms LowerVectorScan.cpp LowerVectorShapeCast.cpp LowerVectorStep.cpp + LowerVectorToElements.cpp LowerVectorToFromElementsToShuffleTree.cpp LowerVectorTransfer.cpp LowerVectorTranspose.cpp diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp new file mode 100644 index 0000000..a53a183 --- /dev/null +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp @@ -0,0 +1,53 @@ +//===- LowerVectorToElements.cpp - Lower 'vector.to_elements' op ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements target-independent rewrites and utilities to lower the +// 'vector.to_elements' operation. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" + +#define DEBUG_TYPE "lower-vector-to-elements" + +using namespace mlir; + +namespace { + +struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ToElementsOp op, + PatternRewriter &rewriter) const override { + + TypedValue<VectorType> source = op.getSource(); + FailureOr<SmallVector<Value>> result = + vector::unrollVectorValue(source, rewriter); + if (failed(result)) { + return failure(); + } + SmallVector<Value> vectors = *result; + + SmallVector<Value> results; + for (const Value &vector : vectors) { + auto subElements = + vector::ToElementsOp::create(rewriter, op.getLoc(), vector); + llvm::append_range(results, subElements.getResults()); + } + rewriter.replaceOp(op, results); + return success(); + } +}; + +} // namespace + +void mlir::vector::populateVectorToElementsLoweringPatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add<UnrollToElements>(patterns.getContext(), benefit); +} diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index c84eb2c..995a259 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -371,6 +371,38 @@ static VectorType getDistributedType(VectorType originalType, AffineMap map, return targetType; } +/// Given a warpOp that contains ops with regions, the corresponding op's +/// "inner" region and the distributionMapFn, get all values used by the op's +/// region that are defined within the warpOp, but outside the inner region. +/// Return the set of values, their types and their distributed types. +std::tuple<llvm::SmallSetVector<Value, 32>, SmallVector<Type>, + SmallVector<Type>> +getInnerRegionEscapingValues(WarpExecuteOnLane0Op warpOp, Region &innerRegion, + DistributionMapFn distributionMapFn) { + llvm::SmallSetVector<Value, 32> escapingValues; + SmallVector<Type> escapingValueTypes; + SmallVector<Type> escapingValueDistTypes; // to yield from the new warpOp + if (innerRegion.empty()) + return {std::move(escapingValues), std::move(escapingValueTypes), + std::move(escapingValueDistTypes)}; + mlir::visitUsedValuesDefinedAbove(innerRegion, [&](OpOperand *operand) { + Operation *parent = operand->get().getParentRegion()->getParentOp(); + if (warpOp->isAncestor(parent)) { + if (!escapingValues.insert(operand->get())) + return; + Type distType = operand->get().getType(); + if (auto vecType = dyn_cast<VectorType>(distType)) { + AffineMap map = distributionMapFn(operand->get()); + distType = getDistributedType(vecType, map, warpOp.getWarpSize()); + } + escapingValueTypes.push_back(operand->get().getType()); + escapingValueDistTypes.push_back(distType); + } + }); + return {std::move(escapingValues), std::move(escapingValueTypes), + std::move(escapingValueDistTypes)}; +} + /// Distribute transfer_write ops based on the affine map returned by /// `distributionMapFn`. Writes of size more than `maxNumElementToExtract` /// will not be distributed (it should be less than the warp size). @@ -1713,6 +1745,215 @@ struct WarpOpInsert : public WarpDistributionPattern { } }; +/// Sink scf.if out of WarpExecuteOnLane0Op. This can be done only if +/// the scf.if is the last operation in the region so that it doesn't +/// change the order of execution. This creates a new scf.if after the +/// WarpExecuteOnLane0Op. Each branch of the new scf.if is enclosed in +/// the "inner" WarpExecuteOnLane0Op. Example: +/// ``` +/// gpu.warp_execute_on_lane_0(%laneid)[32] { +/// %payload = ... : vector<32xindex> +/// scf.if %pred { +/// vector.store %payload, %buffer[%idx] : memref<128xindex>, +/// vector<32xindex> +/// } +/// gpu.yield +/// } +/// ``` +/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] { +/// %payload = ... : vector<32xindex> +/// gpu.yield %payload : vector<32xindex> +/// } +/// scf.if %pred { +/// gpu.warp_execute_on_lane_0(%laneid)[32] args(%r : vector<1xindex>) { +/// ^bb0(%arg1: vector<32xindex>): +/// vector.store %arg1, %buffer[%idx] : memref<128xindex>, vector<32xindex> +/// } +/// } +/// ``` +struct WarpOpScfIfOp : public WarpDistributionPattern { + WarpOpScfIfOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1) + : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {} + LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, + PatternRewriter &rewriter) const override { + gpu::YieldOp warpOpYield = warpOp.getTerminator(); + // Only pick up `IfOp` if it is the last op in the region. + Operation *lastNode = warpOpYield->getPrevNode(); + auto ifOp = dyn_cast_or_null<scf::IfOp>(lastNode); + if (!ifOp) + return failure(); + + // The current `WarpOp` can yield two types of values: + // 1. Not results of `IfOp`: + // Preserve them in the new `WarpOp`. + // Collect their yield index to remap the usages. + // 2. Results of `IfOp`: + // They are not part of the new `WarpOp` results. + // Map current warp's yield operand index to `IfOp` result idx. + SmallVector<Value> nonIfYieldValues; + SmallVector<unsigned> nonIfYieldIndices; + llvm::SmallDenseMap<unsigned, unsigned> ifResultMapping; + llvm::SmallDenseMap<unsigned, VectorType> ifResultDistTypes; + for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) { + const unsigned yieldOperandIdx = yieldOperand.getOperandNumber(); + if (yieldOperand.get().getDefiningOp() != ifOp.getOperation()) { + nonIfYieldValues.push_back(yieldOperand.get()); + nonIfYieldIndices.push_back(yieldOperandIdx); + continue; + } + OpResult ifResult = cast<OpResult>(yieldOperand.get()); + const unsigned ifResultIdx = ifResult.getResultNumber(); + ifResultMapping[yieldOperandIdx] = ifResultIdx; + // If this `ifOp` result is vector type and it is yielded by the + // `WarpOp`, we keep track the distributed type for this result. + if (!isa<VectorType>(ifResult.getType())) + continue; + VectorType distType = + cast<VectorType>(warpOp.getResult(yieldOperandIdx).getType()); + ifResultDistTypes[ifResultIdx] = distType; + } + + // Collect `WarpOp`-defined values used in `ifOp`, the new warp op returns + // them + auto [escapingValuesThen, escapingValueInputTypesThen, + escapingValueDistTypesThen] = + getInnerRegionEscapingValues(warpOp, ifOp.getThenRegion(), + distributionMapFn); + auto [escapingValuesElse, escapingValueInputTypesElse, + escapingValueDistTypesElse] = + getInnerRegionEscapingValues(warpOp, ifOp.getElseRegion(), + distributionMapFn); + if (llvm::is_contained(escapingValueDistTypesThen, Type{}) || + llvm::is_contained(escapingValueDistTypesElse, Type{})) + return failure(); + + // The new `WarpOp` groups yields values in following order: + // 1. Branch condition + // 2. Escaping values then branch + // 3. Escaping values else branch + // 4. All non-`ifOp` yielded values. + SmallVector<Value> newWarpOpYieldValues{ifOp.getCondition()}; + newWarpOpYieldValues.append(escapingValuesThen.begin(), + escapingValuesThen.end()); + newWarpOpYieldValues.append(escapingValuesElse.begin(), + escapingValuesElse.end()); + SmallVector<Type> newWarpOpDistTypes{ifOp.getCondition().getType()}; + newWarpOpDistTypes.append(escapingValueDistTypesThen.begin(), + escapingValueDistTypesThen.end()); + newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(), + escapingValueDistTypesElse.end()); + + llvm::SmallDenseMap<unsigned, unsigned> origToNewYieldIdx; + for (auto [idx, val] : + llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) { + origToNewYieldIdx[idx] = newWarpOpYieldValues.size(); + newWarpOpYieldValues.push_back(val); + newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType()); + } + // Create the new `WarpOp` with the updated yield values and types. + WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns( + rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes); + // `ifOp` returns the result of the inner warp op. + SmallVector<Type> newIfOpDistResTypes; + for (auto [i, res] : llvm::enumerate(ifOp.getResults())) { + Type distType = cast<Value>(res).getType(); + if (auto vecType = dyn_cast<VectorType>(distType)) { + AffineMap map = distributionMapFn(cast<Value>(res)); + // Fallback to affine map if the dist result was not previously recorded + distType = ifResultDistTypes.count(i) + ? ifResultDistTypes[i] + : getDistributedType(vecType, map, warpOp.getWarpSize()); + } + newIfOpDistResTypes.push_back(distType); + } + // Create a new `IfOp` outside the new `WarpOp` region. + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPointAfter(newWarpOp); + auto newIfOp = scf::IfOp::create( + rewriter, ifOp.getLoc(), newIfOpDistResTypes, newWarpOp.getResult(0), + static_cast<bool>(ifOp.thenBlock()), + static_cast<bool>(ifOp.elseBlock())); + auto encloseRegionInWarpOp = + [&](Block *oldIfBranch, Block *newIfBranch, + llvm::SmallSetVector<Value, 32> &escapingValues, + SmallVector<Type> &escapingValueInputTypes, + size_t warpResRangeStart) { + OpBuilder::InsertionGuard g(rewriter); + if (!newIfBranch) + return; + rewriter.setInsertionPointToStart(newIfBranch); + llvm::SmallDenseMap<Value, int64_t> escapeValToBlockArgIndex; + SmallVector<Value> innerWarpInputVals; + SmallVector<Type> innerWarpInputTypes; + for (size_t i = 0; i < escapingValues.size(); + ++i, ++warpResRangeStart) { + innerWarpInputVals.push_back( + newWarpOp.getResult(warpResRangeStart)); + escapeValToBlockArgIndex[escapingValues[i]] = + innerWarpInputTypes.size(); + innerWarpInputTypes.push_back(escapingValueInputTypes[i]); + } + auto innerWarp = WarpExecuteOnLane0Op::create( + rewriter, newWarpOp.getLoc(), newIfOp.getResultTypes(), + newWarpOp.getLaneid(), newWarpOp.getWarpSize(), + innerWarpInputVals, innerWarpInputTypes); + + innerWarp.getWarpRegion().takeBody(*oldIfBranch->getParent()); + innerWarp.getWarpRegion().addArguments( + innerWarpInputTypes, + SmallVector<Location>(innerWarpInputTypes.size(), ifOp.getLoc())); + + SmallVector<Value> yieldOperands; + for (Value operand : oldIfBranch->getTerminator()->getOperands()) + yieldOperands.push_back(operand); + rewriter.eraseOp(oldIfBranch->getTerminator()); + + rewriter.setInsertionPointToEnd(innerWarp.getBody()); + gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands); + rewriter.setInsertionPointAfter(innerWarp); + scf::YieldOp::create(rewriter, ifOp.getLoc(), innerWarp.getResults()); + + // Update any users of escaping values that were forwarded to the + // inner `WarpOp`. These values are arguments of the inner `WarpOp`. + innerWarp.walk([&](Operation *op) { + for (OpOperand &operand : op->getOpOperands()) { + auto it = escapeValToBlockArgIndex.find(operand.get()); + if (it == escapeValToBlockArgIndex.end()) + continue; + operand.set(innerWarp.getBodyRegion().getArgument(it->second)); + } + }); + mlir::vector::moveScalarUniformCode(innerWarp); + }; + encloseRegionInWarpOp(&ifOp.getThenRegion().front(), + &newIfOp.getThenRegion().front(), escapingValuesThen, + escapingValueInputTypesThen, 1); + if (!ifOp.getElseRegion().empty()) + encloseRegionInWarpOp(&ifOp.getElseRegion().front(), + &newIfOp.getElseRegion().front(), + escapingValuesElse, escapingValueInputTypesElse, + 1 + escapingValuesThen.size()); + // Update the users of `<- WarpOp.yield <- IfOp.yield` to use the new `IfOp` + // result. + for (auto [origIdx, newIdx] : ifResultMapping) + rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx), + newIfOp.getResult(newIdx), newIfOp); + // Similarly, update any users of the `WarpOp` results that were not + // results of the `IfOp`. + for (auto [origIdx, newIdx] : origToNewYieldIdx) + rewriter.replaceAllUsesWith(warpOp.getResult(origIdx), + newWarpOp.getResult(newIdx)); + // Remove the original `WarpOp` and `IfOp`, they should not have any uses + // at this point. + rewriter.eraseOp(ifOp); + rewriter.eraseOp(warpOp); + return success(); + } + +private: + DistributionMapFn distributionMapFn; +}; + /// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if /// the scf.ForOp is the last operation in the region so that it doesn't /// change the order of execution. This creates a new scf.for region after the @@ -1759,25 +2000,9 @@ struct WarpOpScfForOp : public WarpDistributionPattern { return failure(); // Collect Values that come from the `WarpOp` but are outside the `ForOp`. // Those Values need to be returned by the new warp op. - llvm::SmallSetVector<Value, 32> escapingValues; - SmallVector<Type> escapingValueInputTypes; - SmallVector<Type> escapingValueDistTypes; - mlir::visitUsedValuesDefinedAbove( - forOp.getBodyRegion(), [&](OpOperand *operand) { - Operation *parent = operand->get().getParentRegion()->getParentOp(); - if (warpOp->isAncestor(parent)) { - if (!escapingValues.insert(operand->get())) - return; - Type distType = operand->get().getType(); - if (auto vecType = dyn_cast<VectorType>(distType)) { - AffineMap map = distributionMapFn(operand->get()); - distType = getDistributedType(vecType, map, warpOp.getWarpSize()); - } - escapingValueInputTypes.push_back(operand->get().getType()); - escapingValueDistTypes.push_back(distType); - } - }); - + auto [escapingValues, escapingValueInputTypes, escapingValueDistTypes] = + getInnerRegionEscapingValues(warpOp, forOp.getBodyRegion(), + distributionMapFn); if (llvm::is_contained(escapingValueDistTypes, Type{})) return failure(); // `WarpOp` can yield two types of values: @@ -2068,6 +2293,8 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns( benefit); patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn, benefit); + patterns.add<WarpOpScfIfOp>(patterns.getContext(), distributionMapFn, + benefit); } void mlir::vector::populateDistributeReduction( diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 7dde631..12acf4b 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -798,6 +798,51 @@ struct LinearizeVectorFromElements final } }; +/// This pattern linearizes the operand in `vector.to_elements` operations +/// by converting the source type to a 1-D vector while preserving all element +/// values. The transformation creates a linearized `vector.shape_cast` +/// followed by a `vector.to_elements`. +/// +/// Example: +/// +/// %0:4 = vector.to_elements %v : vector<2x2xf32> +/// +/// is converted to: +/// +/// %vector_cast = vector.shape_cast %v : vector<2x2xf32> to vector<4xf32> +/// %0:4 = vector.to_elements %vector_cast : vector<4xf32> +/// +struct LinearizeVectorToElements final + : public OpConversionPattern<vector::ToElementsOp> { + using OpConversionPattern::OpConversionPattern; + + LinearizeVectorToElements(const TypeConverter &typeConverter, + MLIRContext *context, PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit) {} + + LogicalResult + matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + VectorType vecType = toElementsOp.getSource().getType(); + if (vecType.getRank() <= 1) + return rewriter.notifyMatchFailure( + toElementsOp, "the rank is already less than or equal to 1"); + + assert(vecType.getNumScalableDims() == 0 && + "to_elements does not support scalable vectors"); + auto vec1DType = + VectorType::get({vecType.getNumElements()}, vecType.getElementType()); + Value shapeCast = vector::ShapeCastOp::create( + rewriter, toElementsOp.getLoc(), vec1DType, toElementsOp.getSource()); + auto newToElementsOp = + vector::ToElementsOp::create(rewriter, toElementsOp.getLoc(), + toElementsOp.getResultTypes(), shapeCast); + rewriter.replaceOp(toElementsOp, newToElementsOp); + return success(); + } +}; + } // namespace /// This method defines the set of operations that are linearizable, and hence @@ -890,8 +935,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns( patterns .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast, LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad, - LinearizeVectorStore, LinearizeVectorFromElements>( - typeConverter, patterns.getContext()); + LinearizeVectorStore, LinearizeVectorFromElements, + LinearizeVectorToElements>(typeConverter, patterns.getContext()); } void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns( diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp index 841e138..39dc7a4f 100644 --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -393,6 +393,41 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape, return success(); } +/// Takes a 2+ dimensional vector as an input +/// returns n vector values produced by n vector.extract operations. +/// I.e. calling unrollVectorValue([[%v]], rewriter) such that +/// +/// %v : vector<nxaxb...> +/// +/// will produce the following IR changes +/// +/// %v0 = vector.extract %v[0] : vector<axbx...> from vector<nxaxb...> +/// %v1 = vector.extract %v[1] : vector<axbx...> from vector<nxaxb...> +/// ... +/// %vnminusone = vector.extract %v[n-1] : vector<axbx...> from ... +/// +/// and returns SmallVector<Value> r = {[[%v0]], [[%v1]], ..., [[%vnminusone]]} +FailureOr<SmallVector<Value>> +vector::unrollVectorValue(TypedValue<VectorType> vector, + RewriterBase &rewriter) { + SmallVector<Value> subvectors; + VectorType ty = cast<VectorType>(vector.getType()); + Location loc = vector.getLoc(); + if (ty.getRank() < 2) + return rewriter.notifyMatchFailure(loc, "already 1-D"); + + // Unrolling doesn't take vscale into account. Pattern is disabled for + // vectors with leading scalable dim(s). + if (ty.getScalableDims().front()) + return rewriter.notifyMatchFailure(loc, "cannot unroll scalable dim"); + + for (int64_t i = 0, e = ty.getShape().front(); i < e; ++i) { + subvectors.push_back(vector::ExtractOp::create(rewriter, loc, vector, i)); + } + + return subvectors; +} + LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter, vector::UnrollVectorOpFn unrollFn) { assert(op->getNumResults() == 1 && "expected single result"); diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt index af923d9..fdc1984 100644 --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -2,7 +2,6 @@ set(LLVM_OPTIONAL_SOURCES CallInterfaces.cpp CastInterfaces.cpp ControlFlowInterfaces.cpp - CopyOpInterface.cpp DataLayoutInterfaces.cpp DerivedAttributeOpInterface.cpp DestinationStyleOpInterface.cpp @@ -43,7 +42,6 @@ endfunction(add_mlir_interface_library) add_mlir_interface_library(CallInterfaces) add_mlir_interface_library(CastInterfaces) add_mlir_interface_library(ControlFlowInterfaces) -add_mlir_interface_library(CopyOpInterface) add_mlir_interface_library(DataLayoutInterfaces) add_mlir_interface_library(DerivedAttributeOpInterface) add_mlir_interface_library(DestinationStyleOpInterface) diff --git a/mlir/lib/Interfaces/CopyOpInterface.cpp b/mlir/lib/Interfaces/CopyOpInterface.cpp deleted file mode 100644 index 8e6132c..0000000 --- a/mlir/lib/Interfaces/CopyOpInterface.cpp +++ /dev/null @@ -1,18 +0,0 @@ -//===- CopyOpInterface.cpp - Copy operations interface in MLIR ------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Interfaces/CopyOpInterface.h" - -using namespace mlir; - -//===----------------------------------------------------------------------===// -// CopyOp Interface -//===----------------------------------------------------------------------===// - -/// Include the definitions of the copy operation interface. -#include "mlir/Interfaces/CopyOpInterface.cpp.inc" diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp index caa9091..d2bafb7 100644 --- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp +++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp @@ -6,6 +6,8 @@ // //===----------------------------------------------------------------------===// +#include <utility> + #include "mlir/Interfaces/ValueBoundsOpInterface.h" #include "mlir/IR/BuiltinTypes.h" @@ -151,7 +153,7 @@ ValueBoundsConstraintSet::Variable::Variable(AffineMap map, [](Value v) { return Variable(v); })) {} ValueBoundsConstraintSet::ValueBoundsConstraintSet( - MLIRContext *ctx, StopConditionFn stopCondition, + MLIRContext *ctx, const StopConditionFn &stopCondition, bool addConservativeSemiAffineBounds) : builder(ctx), stopCondition(stopCondition), addConservativeSemiAffineBounds(addConservativeSemiAffineBounds) { @@ -302,7 +304,8 @@ int64_t ValueBoundsConstraintSet::insert(bool isSymbol) { return pos; } -int64_t ValueBoundsConstraintSet::insert(AffineMap map, ValueDimList operands, +int64_t ValueBoundsConstraintSet::insert(AffineMap map, + const ValueDimList &operands, bool isSymbol) { assert(map.getNumResults() == 1 && "expected affine map with one result"); int64_t pos = insert(isSymbol); @@ -629,7 +632,7 @@ LogicalResult ValueBoundsConstraintSet::computeIndependentBound( FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound( presburger::BoundType type, const Variable &var, - StopConditionFn stopCondition, bool closedUB) { + const StopConditionFn &stopCondition, bool closedUB) { // Default stop condition if none was specified: Keep adding constraints until // a bound could be computed. int64_t pos = 0; @@ -666,7 +669,7 @@ void ValueBoundsConstraintSet::populateConstraints(Value value, int64_t ValueBoundsConstraintSet::populateConstraints(AffineMap map, ValueDimList operands) { - int64_t pos = insert(map, operands, /*isSymbol=*/false); + int64_t pos = insert(map, std::move(operands), /*isSymbol=*/false); // Process the backward slice of `operands` (i.e., reverse use-def chain) // until `stopCondition` is met. processWorklist(); @@ -826,10 +829,9 @@ FailureOr<bool> ValueBoundsConstraintSet::areEqual(const Variable &var1, return strongCompare(var1, ComparisonOperator::EQ, var2); } -FailureOr<bool> -ValueBoundsConstraintSet::areOverlappingSlices(MLIRContext *ctx, - HyperrectangularSlice slice1, - HyperrectangularSlice slice2) { +FailureOr<bool> ValueBoundsConstraintSet::areOverlappingSlices( + MLIRContext *ctx, const HyperrectangularSlice &slice1, + const HyperrectangularSlice &slice2) { assert(slice1.getMixedOffsets().size() == slice2.getMixedOffsets().size() && "expected slices of same rank"); assert(slice1.getMixedSizes().size() == slice2.getMixedSizes().size() && @@ -891,10 +893,9 @@ ValueBoundsConstraintSet::areOverlappingSlices(MLIRContext *ctx, return true; } -FailureOr<bool> -ValueBoundsConstraintSet::areEquivalentSlices(MLIRContext *ctx, - HyperrectangularSlice slice1, - HyperrectangularSlice slice2) { +FailureOr<bool> ValueBoundsConstraintSet::areEquivalentSlices( + MLIRContext *ctx, const HyperrectangularSlice &slice1, + const HyperrectangularSlice &slice2) { assert(slice1.getMixedOffsets().size() == slice2.getMixedOffsets().size() && "expected slices of same rank"); assert(slice1.getMixedSizes().size() == slice2.getMixedSizes().size() && diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp index da86b00..926ffd0 100644 --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -385,7 +385,8 @@ void Operator::populateTypeInferenceInfo( if (getTrait("::mlir::OpTrait::SameOperandsAndResultType")) { // Check for a non-variable length operand to use as the type anchor. auto *operandI = llvm::find_if(arguments, [](const Argument &arg) { - NamedTypeConstraint *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg); + NamedTypeConstraint *operand = + llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg); return operand && !operand->isVariableLength(); }); if (operandI == arguments.end()) @@ -663,15 +664,17 @@ void Operator::populateOpStructure() { argDef = argDef->getValueAsDef("constraint"); if (argDef->isSubClassOf(typeConstraintClass)) { - attrOrOperandMapping.push_back( - {OperandOrAttribute::Kind::Operand, operandIndex}); + attrPropOrOperandMapping.push_back( + {OperandAttrOrProp::Kind::Operand, operandIndex}); arguments.emplace_back(&operands[operandIndex++]); } else if (argDef->isSubClassOf(attrClass)) { - attrOrOperandMapping.push_back( - {OperandOrAttribute::Kind::Attribute, attrIndex}); + attrPropOrOperandMapping.push_back( + {OperandAttrOrProp::Kind::Attribute, attrIndex}); arguments.emplace_back(&attributes[attrIndex++]); } else { assert(argDef->isSubClassOf(propertyClass)); + attrPropOrOperandMapping.push_back( + {OperandAttrOrProp::Kind::Property, propIndex}); arguments.emplace_back(&properties[propIndex++]); } } @@ -867,9 +870,8 @@ auto Operator::VariableDecoratorIterator::unwrap(const Init *init) return VariableDecorator(cast<DefInit>(init)->getDef()); } -auto Operator::getArgToOperandOrAttribute(int index) const - -> OperandOrAttribute { - return attrOrOperandMapping[index]; +auto Operator::getArgToOperandAttrOrProp(int index) const -> OperandAttrOrProp { + return attrPropOrOperandMapping[index]; } std::string Operator::getGetterName(StringRef name) const { diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp index 8d5d7f9..44732d5 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMInterfaces.h" #include "mlir/Support/LLVM.h" @@ -21,6 +22,7 @@ #include "llvm/IR/InlineAsm.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/MemoryModelRelaxationAnnotations.h" using namespace mlir; using namespace mlir::LLVM; @@ -88,6 +90,7 @@ static ArrayRef<unsigned> getSupportedMetadataImpl(llvm::LLVMContext &context) { llvm::LLVMContext::MD_alias_scope, llvm::LLVMContext::MD_dereferenceable, llvm::LLVMContext::MD_dereferenceable_or_null, + llvm::LLVMContext::MD_mmra, context.getMDKindID(vecTypeHintMDName), context.getMDKindID(workGroupSizeHintMDName), context.getMDKindID(reqdWorkGroupSizeMDName), @@ -212,6 +215,39 @@ static LogicalResult setDereferenceableAttr(const llvm::MDNode *node, return success(); } +/// Convert the given MMRA metadata (either an MMRA tag or an array of them) +/// into corresponding MLIR attributes and set them on the given operation as a +/// discardable `llvm.mmra` attribute. +static LogicalResult setMmraAttr(llvm::MDNode *node, Operation *op, + LLVM::ModuleImport &moduleImport) { + if (!node) + return success(); + + // We don't use the LLVM wrappers here becasue we care about the order + // of the metadata for deterministic roundtripping. + MLIRContext *ctx = op->getContext(); + auto toAttribute = [&](llvm::MDNode *tag) -> Attribute { + return LLVM::MMRATagAttr::get( + ctx, cast<llvm::MDString>(tag->getOperand(0))->getString(), + cast<llvm::MDString>(tag->getOperand(1))->getString()); + }; + Attribute mlirMmra; + if (llvm::MMRAMetadata::isTagMD(node)) { + mlirMmra = toAttribute(node); + } else { + SmallVector<Attribute> tags; + for (const llvm::MDOperand &operand : node->operands()) { + auto *tagNode = dyn_cast<llvm::MDNode>(operand.get()); + if (!tagNode || !llvm::MMRAMetadata::isTagMD(tagNode)) + return failure(); + tags.push_back(toAttribute(tagNode)); + } + mlirMmra = ArrayAttr::get(ctx, tags); + } + op->setAttr(LLVMDialect::getMmraAttrName(), mlirMmra); + return success(); +} + /// Converts the given loop metadata node to an MLIR loop annotation attribute /// and attaches it to the imported operation if the translation succeeds. /// Returns failure otherwise. @@ -432,7 +468,8 @@ public: return setDereferenceableAttr( node, llvm::LLVMContext::MD_dereferenceable_or_null, op, moduleImport); - + if (kind == llvm::LLVMContext::MD_mmra) + return setMmraAttr(node, op, moduleImport); llvm::LLVMContext &context = node->getContext(); if (kind == context.getMDKindID(vecTypeHintMDName)) return setVecTypeHintAttr(builder, node, op, moduleImport); diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp index fd8463a..eaf1d20 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -24,6 +24,8 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/MatrixBuilder.h" +#include "llvm/IR/MemoryModelRelaxationAnnotations.h" +#include "llvm/Support/LogicalResult.h" using namespace mlir; using namespace mlir::LLVM; @@ -723,6 +725,40 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, return failure(); } +static LogicalResult +amendOperationImpl(Operation &op, ArrayRef<llvm::Instruction *> instructions, + NamedAttribute attribute, + LLVM::ModuleTranslation &moduleTranslation) { + StringRef name = attribute.getName(); + if (name == LLVMDialect::getMmraAttrName()) { + SmallVector<llvm::MMRAMetadata::TagT> tags; + if (auto oneTag = dyn_cast<LLVM::MMRATagAttr>(attribute.getValue())) { + tags.emplace_back(oneTag.getPrefix(), oneTag.getSuffix()); + } else if (auto manyTags = dyn_cast<ArrayAttr>(attribute.getValue())) { + for (Attribute attr : manyTags) { + auto tag = dyn_cast<MMRATagAttr>(attr); + if (!tag) + return op.emitOpError( + "MMRA annotations array contains value that isn't an MMRA tag"); + tags.emplace_back(tag.getPrefix(), tag.getSuffix()); + } + } else { + return op.emitOpError( + "llvm.mmra is something other than an MMRA tag or an array of them"); + } + llvm::MDTuple *mmraMd = + llvm::MMRAMetadata::getMD(moduleTranslation.getLLVMContext(), tags); + if (!mmraMd) { + // Empty list, canonicalizes to nothing + return success(); + } + for (llvm::Instruction *inst : instructions) + inst->setMetadata(llvm::LLVMContext::MD_mmra, mmraMd); + return success(); + } + return success(); +} + namespace { /// Implementation of the dialect interface that converts operations belonging /// to the LLVM dialect to LLVM IR. @@ -738,6 +774,14 @@ public: LLVM::ModuleTranslation &moduleTranslation) const final { return convertOperationImpl(*op, builder, moduleTranslation); } + + /// Handle some metadata that is represented as a discardable attribute. + LogicalResult + amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions, + NamedAttribute attribute, + LLVM::ModuleTranslation &moduleTranslation) const final { + return amendOperationImpl(*op, instructions, attribute, moduleTranslation); + } }; } // namespace diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 4e26e65..5e194dc 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -33,6 +33,7 @@ #include "llvm/IR/MDBuilder.h" #include "llvm/IR/ReplaceConstant.h" #include "llvm/Support/FileSystem.h" +#include "llvm/Support/VirtualFileSystem.h" #include "llvm/TargetParser/Triple.h" #include "llvm/Transforms/Utils/ModuleUtils.h" @@ -3041,16 +3042,46 @@ convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder, loopInfos.push_back(*loopResult); } - // Collapse loops. Store the insertion point because LoopInfos may get - // invalidated. llvm::OpenMPIRBuilder::InsertPointTy afterIP = loopInfos.front()->getAfterIP(); - // Update the stack frame created for this loop to point to the resulting loop - // after applying transformations. + // Do tiling. + if (const auto &tiles = loopOp.getTileSizes()) { + llvm::Type *ivType = loopInfos.front()->getIndVarType(); + SmallVector<llvm::Value *> tileSizes; + + for (auto tile : tiles.value()) { + llvm::Value *tileVal = llvm::ConstantInt::get(ivType, tile); + tileSizes.push_back(tileVal); + } + + std::vector<llvm::CanonicalLoopInfo *> newLoops = + ompBuilder->tileLoops(ompLoc.DL, loopInfos, tileSizes); + + // Update afterIP to get the correct insertion point after + // tiling. + llvm::BasicBlock *afterBB = newLoops.front()->getAfter(); + llvm::BasicBlock *afterAfterBB = afterBB->getSingleSuccessor(); + afterIP = {afterAfterBB, afterAfterBB->begin()}; + + // Update the loop infos. + loopInfos.clear(); + for (const auto &newLoop : newLoops) + loopInfos.push_back(newLoop); + } // Tiling done. + + // Do collapse. + const auto &numCollapse = loopOp.getCollapseNumLoops(); + SmallVector<llvm::CanonicalLoopInfo *> collapseLoopInfos( + loopInfos.begin(), loopInfos.begin() + (numCollapse)); + + auto newTopLoopInfo = + ompBuilder->collapseLoops(ompLoc.DL, collapseLoopInfos, {}); + + assert(newTopLoopInfo && "New top loop information is missing"); moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>( [&](OpenMPLoopInfoStackFrame &frame) { - frame.loopInfo = ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {}); + frame.loopInfo = newTopLoopInfo; return WalkResult::interrupt(); }); @@ -6304,7 +6335,9 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation( if (auto filepathAttr = dyn_cast<StringAttr>(attr)) { llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); - ompBuilder->loadOffloadInfoMetadata(filepathAttr.getValue()); + auto VFS = llvm::vfs::getRealFileSystem(); + ompBuilder->loadOffloadInfoMetadata(*VFS, + filepathAttr.getValue()); return success(); } return failure(); diff --git a/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp index 73b166d..7e9318a 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp @@ -55,10 +55,6 @@ public: return handleDecorationCacheControl(instructions.front(), cacheControlsArray.getValue()); } - auto func = dyn_cast<LLVM::LLVMFuncOp>(op); - if (!func) - return failure(); - return success(); } diff --git a/mlir/lib/Tools/lsp-server-support/CMakeLists.txt b/mlir/lib/Tools/lsp-server-support/CMakeLists.txt index 48a9601..2fe29f1 100644 --- a/mlir/lib/Tools/lsp-server-support/CMakeLists.txt +++ b/mlir/lib/Tools/lsp-server-support/CMakeLists.txt @@ -1,13 +1,13 @@ add_mlir_library(MLIRLspServerSupportLib CompilationDatabase.cpp - Logging.cpp - Protocol.cpp SourceMgrUtils.cpp - Transport.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Tools/lsp-server-support + LINK_COMPONENTS + SupportLSP + LINK_LIBS PUBLIC MLIRSupport - ) +) diff --git a/mlir/lib/Tools/lsp-server-support/CompilationDatabase.cpp b/mlir/lib/Tools/lsp-server-support/CompilationDatabase.cpp index 9ae0674..67b8ef6 100644 --- a/mlir/lib/Tools/lsp-server-support/CompilationDatabase.cpp +++ b/mlir/lib/Tools/lsp-server-support/CompilationDatabase.cpp @@ -8,14 +8,15 @@ #include "mlir/Tools/lsp-server-support/CompilationDatabase.h" #include "mlir/Support/FileUtilities.h" -#include "mlir/Tools/lsp-server-support/Logging.h" -#include "mlir/Tools/lsp-server-support/Protocol.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/LSP/Logging.h" +#include "llvm/Support/LSP/Protocol.h" #include "llvm/Support/YAMLTraits.h" using namespace mlir; using namespace mlir::lsp; +using llvm::lsp::Logger; //===----------------------------------------------------------------------===// // YamlFileInfo diff --git a/mlir/lib/Tools/lsp-server-support/Logging.cpp b/mlir/lib/Tools/lsp-server-support/Logging.cpp deleted file mode 100644 index 373e216..0000000 --- a/mlir/lib/Tools/lsp-server-support/Logging.cpp +++ /dev/null @@ -1,41 +0,0 @@ -//===- Logging.cpp --------------------------------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Tools/lsp-server-support/Logging.h" -#include "llvm/Support/Chrono.h" -#include "llvm/Support/raw_ostream.h" - -using namespace mlir; -using namespace mlir::lsp; - -void Logger::setLogLevel(Level logLevel) { get().logLevel = logLevel; } - -Logger &Logger::get() { - static Logger logger; - return logger; -} - -void Logger::log(Level logLevel, const char *fmt, - const llvm::formatv_object_base &message) { - Logger &logger = get(); - - // Ignore messages with log levels below the current setting in the logger. - if (logLevel < logger.logLevel) - return; - - // An indicator character for each log level. - const char *logLevelIndicators = "DIE"; - - // Format the message and print to errs. - llvm::sys::TimePoint<> timestamp = std::chrono::system_clock::now(); - std::lock_guard<std::mutex> logGuard(logger.mutex); - llvm::errs() << llvm::formatv( - "{0}[{1:%H:%M:%S.%L}] {2}\n", - logLevelIndicators[static_cast<unsigned>(logLevel)], timestamp, message); - llvm::errs().flush(); -} diff --git a/mlir/lib/Tools/lsp-server-support/Protocol.cpp b/mlir/lib/Tools/lsp-server-support/Protocol.cpp deleted file mode 100644 index 9828704..0000000 --- a/mlir/lib/Tools/lsp-server-support/Protocol.cpp +++ /dev/null @@ -1,1043 +0,0 @@ -//===--- Protocol.cpp - Language Server Protocol Implementation -----------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file contains the serialization code for the LSP structs. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Tools/lsp-server-support/Protocol.h" -#include "llvm/ADT/StringExtras.h" -#include "llvm/ADT/StringSet.h" -#include "llvm/Support/ErrorHandling.h" -#include "llvm/Support/JSON.h" -#include "llvm/Support/MemoryBuffer.h" -#include "llvm/Support/Path.h" -#include "llvm/Support/raw_ostream.h" - -using namespace mlir; -using namespace mlir::lsp; - -// Helper that doesn't treat `null` and absent fields as failures. -template <typename T> -static bool mapOptOrNull(const llvm::json::Value ¶ms, - llvm::StringLiteral prop, T &out, - llvm::json::Path path) { - const llvm::json::Object *o = params.getAsObject(); - assert(o); - - // Field is missing or null. - auto *v = o->get(prop); - if (!v || v->getAsNull()) - return true; - return fromJSON(*v, out, path.field(prop)); -} - -//===----------------------------------------------------------------------===// -// LSPError -//===----------------------------------------------------------------------===// - -char LSPError::ID; - -//===----------------------------------------------------------------------===// -// URIForFile -//===----------------------------------------------------------------------===// - -static bool isWindowsPath(StringRef path) { - return path.size() > 1 && llvm::isAlpha(path[0]) && path[1] == ':'; -} - -static bool isNetworkPath(StringRef path) { - return path.size() > 2 && path[0] == path[1] && - llvm::sys::path::is_separator(path[0]); -} - -static bool shouldEscapeInURI(unsigned char c) { - // Unreserved characters. - if ((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || - (c >= '0' && c <= '9')) - return false; - - switch (c) { - case '-': - case '_': - case '.': - case '~': - // '/' is only reserved when parsing. - case '/': - // ':' is only reserved for relative URI paths, which we doesn't produce. - case ':': - return false; - } - return true; -} - -/// Encodes a string according to percent-encoding. -/// - Unreserved characters are not escaped. -/// - Reserved characters always escaped with exceptions like '/'. -/// - All other characters are escaped. -static void percentEncode(StringRef content, std::string &out) { - for (unsigned char c : content) { - if (shouldEscapeInURI(c)) { - out.push_back('%'); - out.push_back(llvm::hexdigit(c / 16)); - out.push_back(llvm::hexdigit(c % 16)); - } else { - out.push_back(c); - } - } -} - -/// Decodes a string according to percent-encoding. -static std::string percentDecode(StringRef content) { - std::string result; - for (auto i = content.begin(), e = content.end(); i != e; ++i) { - if (*i != '%') { - result += *i; - continue; - } - if (*i == '%' && i + 2 < content.end() && llvm::isHexDigit(*(i + 1)) && - llvm::isHexDigit(*(i + 2))) { - result.push_back(llvm::hexFromNibbles(*(i + 1), *(i + 2))); - i += 2; - } else { - result.push_back(*i); - } - } - return result; -} - -/// Return the set containing the supported URI schemes. -static StringSet<> &getSupportedSchemes() { - static StringSet<> schemes({"file", "test"}); - return schemes; -} - -/// Returns true if the given scheme is structurally valid, i.e. it does not -/// contain any invalid scheme characters. This does not check that the scheme -/// is actually supported. -static bool isStructurallyValidScheme(StringRef scheme) { - if (scheme.empty()) - return false; - if (!llvm::isAlpha(scheme[0])) - return false; - return llvm::all_of(llvm::drop_begin(scheme), [](char c) { - return llvm::isAlnum(c) || c == '+' || c == '.' || c == '-'; - }); -} - -static llvm::Expected<std::string> uriFromAbsolutePath(StringRef absolutePath, - StringRef scheme) { - std::string body; - StringRef authority; - StringRef root = llvm::sys::path::root_name(absolutePath); - if (isNetworkPath(root)) { - // Windows UNC paths e.g. \\server\share => file://server/share - authority = root.drop_front(2); - absolutePath.consume_front(root); - } else if (isWindowsPath(root)) { - // Windows paths e.g. X:\path => file:///X:/path - body = "/"; - } - body += llvm::sys::path::convert_to_slash(absolutePath); - - std::string uri = scheme.str() + ":"; - if (authority.empty() && body.empty()) - return uri; - - // If authority if empty, we only print body if it starts with "/"; otherwise, - // the URI is invalid. - if (!authority.empty() || StringRef(body).starts_with("/")) { - uri.append("//"); - percentEncode(authority, uri); - } - percentEncode(body, uri); - return uri; -} - -static llvm::Expected<std::string> getAbsolutePath(StringRef authority, - StringRef body) { - if (!body.starts_with("/")) - return llvm::createStringError( - llvm::inconvertibleErrorCode(), - "File scheme: expect body to be an absolute path starting " - "with '/': " + - body); - SmallString<128> path; - if (!authority.empty()) { - // Windows UNC paths e.g. file://server/share => \\server\share - ("//" + authority).toVector(path); - } else if (isWindowsPath(body.substr(1))) { - // Windows paths e.g. file:///X:/path => X:\path - body.consume_front("/"); - } - path.append(body); - llvm::sys::path::native(path); - return std::string(path); -} - -static llvm::Expected<std::string> parseFilePathFromURI(StringRef origUri) { - StringRef uri = origUri; - - // Decode the scheme of the URI. - size_t pos = uri.find(':'); - if (pos == StringRef::npos) - return llvm::createStringError(llvm::inconvertibleErrorCode(), - "Scheme must be provided in URI: " + - origUri); - StringRef schemeStr = uri.substr(0, pos); - std::string uriScheme = percentDecode(schemeStr); - if (!isStructurallyValidScheme(uriScheme)) - return llvm::createStringError(llvm::inconvertibleErrorCode(), - "Invalid scheme: " + schemeStr + - " (decoded: " + uriScheme + ")"); - uri = uri.substr(pos + 1); - - // Decode the authority of the URI. - std::string uriAuthority; - if (uri.consume_front("//")) { - pos = uri.find('/'); - uriAuthority = percentDecode(uri.substr(0, pos)); - uri = uri.substr(pos); - } - - // Decode the body of the URI. - std::string uriBody = percentDecode(uri); - - // Compute the absolute path for this uri. - if (!getSupportedSchemes().contains(uriScheme)) { - return llvm::createStringError(llvm::inconvertibleErrorCode(), - "unsupported URI scheme `" + uriScheme + - "' for workspace files"); - } - return getAbsolutePath(uriAuthority, uriBody); -} - -llvm::Expected<URIForFile> URIForFile::fromURI(StringRef uri) { - llvm::Expected<std::string> filePath = parseFilePathFromURI(uri); - if (!filePath) - return filePath.takeError(); - return URIForFile(std::move(*filePath), uri.str()); -} - -llvm::Expected<URIForFile> URIForFile::fromFile(StringRef absoluteFilepath, - StringRef scheme) { - llvm::Expected<std::string> uri = - uriFromAbsolutePath(absoluteFilepath, scheme); - if (!uri) - return uri.takeError(); - return fromURI(*uri); -} - -StringRef URIForFile::scheme() const { return uri().split(':').first; } - -void URIForFile::registerSupportedScheme(StringRef scheme) { - getSupportedSchemes().insert(scheme); -} - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, URIForFile &result, - llvm::json::Path path) { - if (std::optional<StringRef> str = value.getAsString()) { - llvm::Expected<URIForFile> expectedURI = URIForFile::fromURI(*str); - if (!expectedURI) { - path.report("unresolvable URI"); - consumeError(expectedURI.takeError()); - return false; - } - result = std::move(*expectedURI); - return true; - } - return false; -} - -llvm::json::Value mlir::lsp::toJSON(const URIForFile &value) { - return value.uri(); -} - -raw_ostream &mlir::lsp::operator<<(raw_ostream &os, const URIForFile &value) { - return os << value.uri(); -} - -//===----------------------------------------------------------------------===// -// ClientCapabilities -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - ClientCapabilities &result, llvm::json::Path path) { - const llvm::json::Object *o = value.getAsObject(); - if (!o) { - path.report("expected object"); - return false; - } - if (const llvm::json::Object *textDocument = o->getObject("textDocument")) { - if (const llvm::json::Object *documentSymbol = - textDocument->getObject("documentSymbol")) { - if (std::optional<bool> hierarchicalSupport = - documentSymbol->getBoolean("hierarchicalDocumentSymbolSupport")) - result.hierarchicalDocumentSymbol = *hierarchicalSupport; - } - if (auto *codeAction = textDocument->getObject("codeAction")) { - if (codeAction->getObject("codeActionLiteralSupport")) - result.codeActionStructure = true; - } - } - if (auto *window = o->getObject("window")) { - if (std::optional<bool> workDoneProgressSupport = - window->getBoolean("workDoneProgress")) - result.workDoneProgress = *workDoneProgressSupport; - } - return true; -} - -//===----------------------------------------------------------------------===// -// ClientInfo -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, ClientInfo &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - if (!o || !o.map("name", result.name)) - return false; - - // Don't fail if we can't parse version. - o.map("version", result.version); - return true; -} - -//===----------------------------------------------------------------------===// -// InitializeParams -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, TraceLevel &result, - llvm::json::Path path) { - if (std::optional<StringRef> str = value.getAsString()) { - if (*str == "off") { - result = TraceLevel::Off; - return true; - } - if (*str == "messages") { - result = TraceLevel::Messages; - return true; - } - if (*str == "verbose") { - result = TraceLevel::Verbose; - return true; - } - } - return false; -} - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - InitializeParams &result, llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - if (!o) - return false; - // We deliberately don't fail if we can't parse individual fields. - o.map("capabilities", result.capabilities); - o.map("trace", result.trace); - mapOptOrNull(value, "clientInfo", result.clientInfo, path); - - return true; -} - -//===----------------------------------------------------------------------===// -// TextDocumentItem -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - TextDocumentItem &result, llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("uri", result.uri) && - o.map("languageId", result.languageId) && o.map("text", result.text) && - o.map("version", result.version); -} - -//===----------------------------------------------------------------------===// -// TextDocumentIdentifier -//===----------------------------------------------------------------------===// - -llvm::json::Value mlir::lsp::toJSON(const TextDocumentIdentifier &value) { - return llvm::json::Object{{"uri", value.uri}}; -} - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - TextDocumentIdentifier &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("uri", result.uri); -} - -//===----------------------------------------------------------------------===// -// VersionedTextDocumentIdentifier -//===----------------------------------------------------------------------===// - -llvm::json::Value -mlir::lsp::toJSON(const VersionedTextDocumentIdentifier &value) { - return llvm::json::Object{ - {"uri", value.uri}, - {"version", value.version}, - }; -} - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - VersionedTextDocumentIdentifier &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("uri", result.uri) && o.map("version", result.version); -} - -//===----------------------------------------------------------------------===// -// Position -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, Position &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("line", result.line) && - o.map("character", result.character); -} - -llvm::json::Value mlir::lsp::toJSON(const Position &value) { - return llvm::json::Object{ - {"line", value.line}, - {"character", value.character}, - }; -} - -raw_ostream &mlir::lsp::operator<<(raw_ostream &os, const Position &value) { - return os << value.line << ':' << value.character; -} - -//===----------------------------------------------------------------------===// -// Range -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, Range &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("start", result.start) && o.map("end", result.end); -} - -llvm::json::Value mlir::lsp::toJSON(const Range &value) { - return llvm::json::Object{ - {"start", value.start}, - {"end", value.end}, - }; -} - -raw_ostream &mlir::lsp::operator<<(raw_ostream &os, const Range &value) { - return os << value.start << '-' << value.end; -} - -//===----------------------------------------------------------------------===// -// Location -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, Location &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("uri", result.uri) && o.map("range", result.range); -} - -llvm::json::Value mlir::lsp::toJSON(const Location &value) { - return llvm::json::Object{ - {"uri", value.uri}, - {"range", value.range}, - }; -} - -raw_ostream &mlir::lsp::operator<<(raw_ostream &os, const Location &value) { - return os << value.range << '@' << value.uri; -} - -//===----------------------------------------------------------------------===// -// TextDocumentPositionParams -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - TextDocumentPositionParams &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("textDocument", result.textDocument) && - o.map("position", result.position); -} - -//===----------------------------------------------------------------------===// -// ReferenceParams -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - ReferenceContext &result, llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.mapOptional("includeDeclaration", result.includeDeclaration); -} - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - ReferenceParams &result, llvm::json::Path path) { - TextDocumentPositionParams &base = result; - llvm::json::ObjectMapper o(value, path); - return fromJSON(value, base, path) && o && - o.mapOptional("context", result.context); -} - -//===----------------------------------------------------------------------===// -// DidOpenTextDocumentParams -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - DidOpenTextDocumentParams &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("textDocument", result.textDocument); -} - -//===----------------------------------------------------------------------===// -// DidCloseTextDocumentParams -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - DidCloseTextDocumentParams &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("textDocument", result.textDocument); -} - -//===----------------------------------------------------------------------===// -// DidChangeTextDocumentParams -//===----------------------------------------------------------------------===// - -LogicalResult -TextDocumentContentChangeEvent::applyTo(std::string &contents) const { - // If there is no range, the full document changed. - if (!range) { - contents = text; - return success(); - } - - // Try to map the replacement range to the content. - llvm::SourceMgr tmpScrMgr; - tmpScrMgr.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(contents), - SMLoc()); - SMRange rangeLoc = range->getAsSMRange(tmpScrMgr); - if (!rangeLoc.isValid()) - return failure(); - - contents.replace(rangeLoc.Start.getPointer() - contents.data(), - rangeLoc.End.getPointer() - rangeLoc.Start.getPointer(), - text); - return success(); -} - -LogicalResult TextDocumentContentChangeEvent::applyTo( - ArrayRef<TextDocumentContentChangeEvent> changes, std::string &contents) { - for (const auto &change : changes) - if (failed(change.applyTo(contents))) - return failure(); - return success(); -} - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - TextDocumentContentChangeEvent &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("range", result.range) && - o.map("rangeLength", result.rangeLength) && o.map("text", result.text); -} - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - DidChangeTextDocumentParams &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("textDocument", result.textDocument) && - o.map("contentChanges", result.contentChanges); -} - -//===----------------------------------------------------------------------===// -// MarkupContent -//===----------------------------------------------------------------------===// - -static llvm::StringRef toTextKind(MarkupKind kind) { - switch (kind) { - case MarkupKind::PlainText: - return "plaintext"; - case MarkupKind::Markdown: - return "markdown"; - } - llvm_unreachable("Invalid MarkupKind"); -} - -raw_ostream &mlir::lsp::operator<<(raw_ostream &os, MarkupKind kind) { - return os << toTextKind(kind); -} - -llvm::json::Value mlir::lsp::toJSON(const MarkupContent &mc) { - if (mc.value.empty()) - return nullptr; - - return llvm::json::Object{ - {"kind", toTextKind(mc.kind)}, - {"value", mc.value}, - }; -} - -//===----------------------------------------------------------------------===// -// Hover -//===----------------------------------------------------------------------===// - -llvm::json::Value mlir::lsp::toJSON(const Hover &hover) { - llvm::json::Object result{{"contents", toJSON(hover.contents)}}; - if (hover.range) - result["range"] = toJSON(*hover.range); - return std::move(result); -} - -//===----------------------------------------------------------------------===// -// DocumentSymbol -//===----------------------------------------------------------------------===// - -llvm::json::Value mlir::lsp::toJSON(const DocumentSymbol &symbol) { - llvm::json::Object result{{"name", symbol.name}, - {"kind", static_cast<int>(symbol.kind)}, - {"range", symbol.range}, - {"selectionRange", symbol.selectionRange}}; - - if (!symbol.detail.empty()) - result["detail"] = symbol.detail; - if (!symbol.children.empty()) - result["children"] = symbol.children; - return std::move(result); -} - -//===----------------------------------------------------------------------===// -// DocumentSymbolParams -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - DocumentSymbolParams &result, llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("textDocument", result.textDocument); -} - -//===----------------------------------------------------------------------===// -// DiagnosticRelatedInformation -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - DiagnosticRelatedInformation &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("location", result.location) && - o.map("message", result.message); -} - -llvm::json::Value mlir::lsp::toJSON(const DiagnosticRelatedInformation &info) { - return llvm::json::Object{ - {"location", info.location}, - {"message", info.message}, - }; -} - -//===----------------------------------------------------------------------===// -// Diagnostic -//===----------------------------------------------------------------------===// - -llvm::json::Value mlir::lsp::toJSON(DiagnosticTag tag) { - return static_cast<int>(tag); -} - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, DiagnosticTag &result, - llvm::json::Path path) { - if (std::optional<int64_t> i = value.getAsInteger()) { - result = (DiagnosticTag)*i; - return true; - } - - return false; -} - -llvm::json::Value mlir::lsp::toJSON(const Diagnostic &diag) { - llvm::json::Object result{ - {"range", diag.range}, - {"severity", (int)diag.severity}, - {"message", diag.message}, - }; - if (diag.category) - result["category"] = *diag.category; - if (!diag.source.empty()) - result["source"] = diag.source; - if (diag.relatedInformation) - result["relatedInformation"] = *diag.relatedInformation; - if (!diag.tags.empty()) - result["tags"] = diag.tags; - return std::move(result); -} - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, Diagnostic &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - if (!o) - return false; - int severity = 0; - if (!mapOptOrNull(value, "severity", severity, path)) - return false; - result.severity = (DiagnosticSeverity)severity; - - return o.map("range", result.range) && o.map("message", result.message) && - mapOptOrNull(value, "category", result.category, path) && - mapOptOrNull(value, "source", result.source, path) && - mapOptOrNull(value, "relatedInformation", result.relatedInformation, - path) && - mapOptOrNull(value, "tags", result.tags, path); -} - -//===----------------------------------------------------------------------===// -// PublishDiagnosticsParams -//===----------------------------------------------------------------------===// - -llvm::json::Value mlir::lsp::toJSON(const PublishDiagnosticsParams ¶ms) { - return llvm::json::Object{ - {"uri", params.uri}, - {"diagnostics", params.diagnostics}, - {"version", params.version}, - }; -} - -//===----------------------------------------------------------------------===// -// TextEdit -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, TextEdit &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("range", result.range) && o.map("newText", result.newText); -} - -llvm::json::Value mlir::lsp::toJSON(const TextEdit &value) { - return llvm::json::Object{ - {"range", value.range}, - {"newText", value.newText}, - }; -} - -raw_ostream &mlir::lsp::operator<<(raw_ostream &os, const TextEdit &value) { - os << value.range << " => \""; - llvm::printEscapedString(value.newText, os); - return os << '"'; -} - -//===----------------------------------------------------------------------===// -// CompletionItemKind -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - CompletionItemKind &result, llvm::json::Path path) { - if (std::optional<int64_t> intValue = value.getAsInteger()) { - if (*intValue < static_cast<int>(CompletionItemKind::Text) || - *intValue > static_cast<int>(CompletionItemKind::TypeParameter)) - return false; - result = static_cast<CompletionItemKind>(*intValue); - return true; - } - return false; -} - -CompletionItemKind mlir::lsp::adjustKindToCapability( - CompletionItemKind kind, - CompletionItemKindBitset &supportedCompletionItemKinds) { - size_t kindVal = static_cast<size_t>(kind); - if (kindVal >= kCompletionItemKindMin && - kindVal <= supportedCompletionItemKinds.size() && - supportedCompletionItemKinds[kindVal]) - return kind; - - // Provide some fall backs for common kinds that are close enough. - switch (kind) { - case CompletionItemKind::Folder: - return CompletionItemKind::File; - case CompletionItemKind::EnumMember: - return CompletionItemKind::Enum; - case CompletionItemKind::Struct: - return CompletionItemKind::Class; - default: - return CompletionItemKind::Text; - } -} - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - CompletionItemKindBitset &result, - llvm::json::Path path) { - if (const llvm::json::Array *arrayValue = value.getAsArray()) { - for (size_t i = 0, e = arrayValue->size(); i < e; ++i) { - CompletionItemKind kindOut; - if (fromJSON((*arrayValue)[i], kindOut, path.index(i))) - result.set(size_t(kindOut)); - } - return true; - } - return false; -} - -//===----------------------------------------------------------------------===// -// CompletionItem -//===----------------------------------------------------------------------===// - -llvm::json::Value mlir::lsp::toJSON(const CompletionItem &value) { - assert(!value.label.empty() && "completion item label is required"); - llvm::json::Object result{{"label", value.label}}; - if (value.kind != CompletionItemKind::Missing) - result["kind"] = static_cast<int>(value.kind); - if (!value.detail.empty()) - result["detail"] = value.detail; - if (value.documentation) - result["documentation"] = value.documentation; - if (!value.sortText.empty()) - result["sortText"] = value.sortText; - if (!value.filterText.empty()) - result["filterText"] = value.filterText; - if (!value.insertText.empty()) - result["insertText"] = value.insertText; - if (value.insertTextFormat != InsertTextFormat::Missing) - result["insertTextFormat"] = static_cast<int>(value.insertTextFormat); - if (value.textEdit) - result["textEdit"] = *value.textEdit; - if (!value.additionalTextEdits.empty()) { - result["additionalTextEdits"] = - llvm::json::Array(value.additionalTextEdits); - } - if (value.deprecated) - result["deprecated"] = value.deprecated; - return std::move(result); -} - -raw_ostream &mlir::lsp::operator<<(raw_ostream &os, - const CompletionItem &value) { - return os << value.label << " - " << toJSON(value); -} - -bool mlir::lsp::operator<(const CompletionItem &lhs, - const CompletionItem &rhs) { - return (lhs.sortText.empty() ? lhs.label : lhs.sortText) < - (rhs.sortText.empty() ? rhs.label : rhs.sortText); -} - -//===----------------------------------------------------------------------===// -// CompletionList -//===----------------------------------------------------------------------===// - -llvm::json::Value mlir::lsp::toJSON(const CompletionList &value) { - return llvm::json::Object{ - {"isIncomplete", value.isIncomplete}, - {"items", llvm::json::Array(value.items)}, - }; -} - -//===----------------------------------------------------------------------===// -// CompletionContext -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - CompletionContext &result, llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - int triggerKind; - if (!o || !o.map("triggerKind", triggerKind) || - !mapOptOrNull(value, "triggerCharacter", result.triggerCharacter, path)) - return false; - result.triggerKind = static_cast<CompletionTriggerKind>(triggerKind); - return true; -} - -//===----------------------------------------------------------------------===// -// CompletionParams -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - CompletionParams &result, llvm::json::Path path) { - if (!fromJSON(value, static_cast<TextDocumentPositionParams &>(result), path)) - return false; - if (const llvm::json::Value *context = value.getAsObject()->get("context")) - return fromJSON(*context, result.context, path.field("context")); - return true; -} - -//===----------------------------------------------------------------------===// -// ParameterInformation -//===----------------------------------------------------------------------===// - -llvm::json::Value mlir::lsp::toJSON(const ParameterInformation &value) { - assert((value.labelOffsets || !value.labelString.empty()) && - "parameter information label is required"); - llvm::json::Object result; - if (value.labelOffsets) - result["label"] = llvm::json::Array( - {value.labelOffsets->first, value.labelOffsets->second}); - else - result["label"] = value.labelString; - if (!value.documentation.empty()) - result["documentation"] = value.documentation; - return std::move(result); -} - -//===----------------------------------------------------------------------===// -// SignatureInformation -//===----------------------------------------------------------------------===// - -llvm::json::Value mlir::lsp::toJSON(const SignatureInformation &value) { - assert(!value.label.empty() && "signature information label is required"); - llvm::json::Object result{ - {"label", value.label}, - {"parameters", llvm::json::Array(value.parameters)}, - }; - if (!value.documentation.empty()) - result["documentation"] = value.documentation; - return std::move(result); -} - -raw_ostream &mlir::lsp::operator<<(raw_ostream &os, - const SignatureInformation &value) { - return os << value.label << " - " << toJSON(value); -} - -//===----------------------------------------------------------------------===// -// SignatureHelp -//===----------------------------------------------------------------------===// - -llvm::json::Value mlir::lsp::toJSON(const SignatureHelp &value) { - assert(value.activeSignature >= 0 && - "Unexpected negative value for number of active signatures."); - assert(value.activeParameter >= 0 && - "Unexpected negative value for active parameter index"); - return llvm::json::Object{ - {"activeSignature", value.activeSignature}, - {"activeParameter", value.activeParameter}, - {"signatures", llvm::json::Array(value.signatures)}, - }; -} - -//===----------------------------------------------------------------------===// -// DocumentLinkParams -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - DocumentLinkParams &result, llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("textDocument", result.textDocument); -} - -//===----------------------------------------------------------------------===// -// DocumentLink -//===----------------------------------------------------------------------===// - -llvm::json::Value mlir::lsp::toJSON(const DocumentLink &value) { - return llvm::json::Object{ - {"range", value.range}, - {"target", value.target}, - }; -} - -//===----------------------------------------------------------------------===// -// InlayHintsParams -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - InlayHintsParams &result, llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("textDocument", result.textDocument) && - o.map("range", result.range); -} - -//===----------------------------------------------------------------------===// -// InlayHint -//===----------------------------------------------------------------------===// - -llvm::json::Value mlir::lsp::toJSON(const InlayHint &value) { - return llvm::json::Object{{"position", value.position}, - {"kind", (int)value.kind}, - {"label", value.label}, - {"paddingLeft", value.paddingLeft}, - {"paddingRight", value.paddingRight}}; -} -bool mlir::lsp::operator==(const InlayHint &lhs, const InlayHint &rhs) { - return std::tie(lhs.position, lhs.kind, lhs.label) == - std::tie(rhs.position, rhs.kind, rhs.label); -} -bool mlir::lsp::operator<(const InlayHint &lhs, const InlayHint &rhs) { - return std::tie(lhs.position, lhs.kind, lhs.label) < - std::tie(rhs.position, rhs.kind, rhs.label); -} - -llvm::raw_ostream &mlir::lsp::operator<<(llvm::raw_ostream &os, - InlayHintKind value) { - switch (value) { - case InlayHintKind::Parameter: - return os << "parameter"; - case InlayHintKind::Type: - return os << "type"; - } - llvm_unreachable("Unknown InlayHintKind"); -} - -//===----------------------------------------------------------------------===// -// CodeActionContext -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - CodeActionContext &result, llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - if (!o || !o.map("diagnostics", result.diagnostics)) - return false; - o.map("only", result.only); - return true; -} - -//===----------------------------------------------------------------------===// -// CodeActionParams -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - CodeActionParams &result, llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("textDocument", result.textDocument) && - o.map("range", result.range) && o.map("context", result.context); -} - -//===----------------------------------------------------------------------===// -// WorkspaceEdit -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, WorkspaceEdit &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("changes", result.changes); -} - -llvm::json::Value mlir::lsp::toJSON(const WorkspaceEdit &value) { - llvm::json::Object fileChanges; - for (auto &change : value.changes) - fileChanges[change.first] = llvm::json::Array(change.second); - return llvm::json::Object{{"changes", std::move(fileChanges)}}; -} - -//===----------------------------------------------------------------------===// -// CodeAction -//===----------------------------------------------------------------------===// - -const llvm::StringLiteral CodeAction::kQuickFix = "quickfix"; -const llvm::StringLiteral CodeAction::kRefactor = "refactor"; -const llvm::StringLiteral CodeAction::kInfo = "info"; - -llvm::json::Value mlir::lsp::toJSON(const CodeAction &value) { - llvm::json::Object codeAction{{"title", value.title}}; - if (value.kind) - codeAction["kind"] = *value.kind; - if (value.diagnostics) - codeAction["diagnostics"] = llvm::json::Array(*value.diagnostics); - if (value.isPreferred) - codeAction["isPreferred"] = true; - if (value.edit) - codeAction["edit"] = *value.edit; - return std::move(codeAction); -} diff --git a/mlir/lib/Tools/lsp-server-support/SourceMgrUtils.cpp b/mlir/lib/Tools/lsp-server-support/SourceMgrUtils.cpp index f1a3623..5cd1c85 100644 --- a/mlir/lib/Tools/lsp-server-support/SourceMgrUtils.cpp +++ b/mlir/lib/Tools/lsp-server-support/SourceMgrUtils.cpp @@ -14,6 +14,10 @@ using namespace mlir; using namespace mlir::lsp; +using llvm::lsp::Hover; +using llvm::lsp::Range; +using llvm::lsp::URIForFile; + //===----------------------------------------------------------------------===// // Utils //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Tools/lsp-server-support/Transport.cpp b/mlir/lib/Tools/lsp-server-support/Transport.cpp deleted file mode 100644 index 5a098b2..0000000 --- a/mlir/lib/Tools/lsp-server-support/Transport.cpp +++ /dev/null @@ -1,369 +0,0 @@ -//===--- JSONTransport.cpp - sending and receiving LSP messages over JSON -===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Tools/lsp-server-support/Transport.h" -#include "mlir/Support/ToolUtilities.h" -#include "mlir/Tools/lsp-server-support/Logging.h" -#include "mlir/Tools/lsp-server-support/Protocol.h" -#include "llvm/ADT/SmallString.h" -#include "llvm/Support/Error.h" -#include <optional> -#include <system_error> -#include <utility> - -using namespace mlir; -using namespace mlir::lsp; - -//===----------------------------------------------------------------------===// -// Reply -//===----------------------------------------------------------------------===// - -namespace { -/// Function object to reply to an LSP call. -/// Each instance must be called exactly once, otherwise: -/// - if there was no reply, an error reply is sent -/// - if there were multiple replies, only the first is sent -class Reply { -public: - Reply(const llvm::json::Value &id, StringRef method, JSONTransport &transport, - std::mutex &transportOutputMutex); - Reply(Reply &&other); - Reply &operator=(Reply &&) = delete; - Reply(const Reply &) = delete; - Reply &operator=(const Reply &) = delete; - - void operator()(llvm::Expected<llvm::json::Value> reply); - -private: - std::string method; - std::atomic<bool> replied = {false}; - llvm::json::Value id; - JSONTransport *transport; - std::mutex &transportOutputMutex; -}; -} // namespace - -Reply::Reply(const llvm::json::Value &id, llvm::StringRef method, - JSONTransport &transport, std::mutex &transportOutputMutex) - : method(method), id(id), transport(&transport), - transportOutputMutex(transportOutputMutex) {} - -Reply::Reply(Reply &&other) - : method(other.method), replied(other.replied.load()), - id(std::move(other.id)), transport(other.transport), - transportOutputMutex(other.transportOutputMutex) { - other.transport = nullptr; -} - -void Reply::operator()(llvm::Expected<llvm::json::Value> reply) { - if (replied.exchange(true)) { - Logger::error("Replied twice to message {0}({1})", method, id); - assert(false && "must reply to each call only once!"); - return; - } - assert(transport && "expected valid transport to reply to"); - - std::lock_guard<std::mutex> transportLock(transportOutputMutex); - if (reply) { - Logger::info("--> reply:{0}({1})", method, id); - transport->reply(std::move(id), std::move(reply)); - } else { - llvm::Error error = reply.takeError(); - Logger::info("--> reply:{0}({1}): {2}", method, id, error); - transport->reply(std::move(id), std::move(error)); - } -} - -//===----------------------------------------------------------------------===// -// MessageHandler -//===----------------------------------------------------------------------===// - -bool MessageHandler::onNotify(llvm::StringRef method, llvm::json::Value value) { - Logger::info("--> {0}", method); - - if (method == "exit") - return false; - if (method == "$cancel") { - // TODO: Add support for cancelling requests. - } else { - auto it = notificationHandlers.find(method); - if (it != notificationHandlers.end()) - it->second(std::move(value)); - } - return true; -} - -bool MessageHandler::onCall(llvm::StringRef method, llvm::json::Value params, - llvm::json::Value id) { - Logger::info("--> {0}({1})", method, id); - - Reply reply(id, method, transport, transportOutputMutex); - - auto it = methodHandlers.find(method); - if (it != methodHandlers.end()) { - it->second(std::move(params), std::move(reply)); - } else { - reply(llvm::make_error<LSPError>("method not found: " + method.str(), - ErrorCode::MethodNotFound)); - } - return true; -} - -bool MessageHandler::onReply(llvm::json::Value id, - llvm::Expected<llvm::json::Value> result) { - // Find the response handler in the mapping. If it exists, move it out of the - // mapping and erase it. - ResponseHandlerTy responseHandler; - { - std::lock_guard<std::mutex> responseHandlersLock(responseHandlersMutex); - auto it = responseHandlers.find(debugString(id)); - if (it != responseHandlers.end()) { - responseHandler = std::move(it->second); - responseHandlers.erase(it); - } - } - - // If we found a response handler, invoke it. Otherwise, log an error. - if (responseHandler.second) { - Logger::info("--> reply:{0}({1})", responseHandler.first, id); - responseHandler.second(std::move(id), std::move(result)); - } else { - Logger::error( - "received a reply with ID {0}, but there was no such outgoing request", - id); - if (!result) - llvm::consumeError(result.takeError()); - } - return true; -} - -//===----------------------------------------------------------------------===// -// JSONTransport -//===----------------------------------------------------------------------===// - -/// Encode the given error as a JSON object. -static llvm::json::Object encodeError(llvm::Error error) { - std::string message; - ErrorCode code = ErrorCode::UnknownErrorCode; - auto handlerFn = [&](const LSPError &lspError) -> llvm::Error { - message = lspError.message; - code = lspError.code; - return llvm::Error::success(); - }; - if (llvm::Error unhandled = llvm::handleErrors(std::move(error), handlerFn)) - message = llvm::toString(std::move(unhandled)); - - return llvm::json::Object{ - {"message", std::move(message)}, - {"code", int64_t(code)}, - }; -} - -/// Decode the given JSON object into an error. -llvm::Error decodeError(const llvm::json::Object &o) { - StringRef msg = o.getString("message").value_or("Unspecified error"); - if (std::optional<int64_t> code = o.getInteger("code")) - return llvm::make_error<LSPError>(msg.str(), ErrorCode(*code)); - return llvm::make_error<llvm::StringError>(llvm::inconvertibleErrorCode(), - msg.str()); -} - -void JSONTransport::notify(StringRef method, llvm::json::Value params) { - sendMessage(llvm::json::Object{ - {"jsonrpc", "2.0"}, - {"method", method}, - {"params", std::move(params)}, - }); -} -void JSONTransport::call(StringRef method, llvm::json::Value params, - llvm::json::Value id) { - sendMessage(llvm::json::Object{ - {"jsonrpc", "2.0"}, - {"id", std::move(id)}, - {"method", method}, - {"params", std::move(params)}, - }); -} -void JSONTransport::reply(llvm::json::Value id, - llvm::Expected<llvm::json::Value> result) { - if (result) { - return sendMessage(llvm::json::Object{ - {"jsonrpc", "2.0"}, - {"id", std::move(id)}, - {"result", std::move(*result)}, - }); - } - - sendMessage(llvm::json::Object{ - {"jsonrpc", "2.0"}, - {"id", std::move(id)}, - {"error", encodeError(result.takeError())}, - }); -} - -llvm::Error JSONTransport::run(MessageHandler &handler) { - std::string json; - while (!in->isEndOfInput()) { - if (in->hasError()) { - return llvm::errorCodeToError( - std::error_code(errno, std::system_category())); - } - - if (succeeded(in->readMessage(json))) { - if (llvm::Expected<llvm::json::Value> doc = llvm::json::parse(json)) { - if (!handleMessage(std::move(*doc), handler)) - return llvm::Error::success(); - } else { - Logger::error("JSON parse error: {0}", llvm::toString(doc.takeError())); - } - } - } - return llvm::errorCodeToError(std::make_error_code(std::errc::io_error)); -} - -void JSONTransport::sendMessage(llvm::json::Value msg) { - outputBuffer.clear(); - llvm::raw_svector_ostream os(outputBuffer); - os << llvm::formatv(prettyOutput ? "{0:2}\n" : "{0}", msg); - out << "Content-Length: " << outputBuffer.size() << "\r\n\r\n" - << outputBuffer; - out.flush(); - Logger::debug(">>> {0}\n", outputBuffer); -} - -bool JSONTransport::handleMessage(llvm::json::Value msg, - MessageHandler &handler) { - // Message must be an object with "jsonrpc":"2.0". - llvm::json::Object *object = msg.getAsObject(); - if (!object || - object->getString("jsonrpc") != std::optional<StringRef>("2.0")) - return false; - - // `id` may be any JSON value. If absent, this is a notification. - std::optional<llvm::json::Value> id; - if (llvm::json::Value *i = object->get("id")) - id = std::move(*i); - std::optional<StringRef> method = object->getString("method"); - - // This is a response. - if (!method) { - if (!id) - return false; - if (auto *err = object->getObject("error")) - return handler.onReply(std::move(*id), decodeError(*err)); - // result should be given, use null if not. - llvm::json::Value result = nullptr; - if (llvm::json::Value *r = object->get("result")) - result = std::move(*r); - return handler.onReply(std::move(*id), std::move(result)); - } - - // Params should be given, use null if not. - llvm::json::Value params = nullptr; - if (llvm::json::Value *p = object->get("params")) - params = std::move(*p); - - if (id) - return handler.onCall(*method, std::move(params), std::move(*id)); - return handler.onNotify(*method, std::move(params)); -} - -/// Tries to read a line up to and including \n. -/// If failing, feof(), ferror(), or shutdownRequested() will be set. -LogicalResult readLine(std::FILE *in, SmallVectorImpl<char> &out) { - // Big enough to hold any reasonable header line. May not fit content lines - // in delimited mode, but performance doesn't matter for that mode. - static constexpr int bufSize = 128; - size_t size = 0; - out.clear(); - for (;;) { - out.resize_for_overwrite(size + bufSize); - if (!std::fgets(&out[size], bufSize, in)) - return failure(); - - clearerr(in); - - // If the line contained null bytes, anything after it (including \n) will - // be ignored. Fortunately this is not a legal header or JSON. - size_t read = std::strlen(&out[size]); - if (read > 0 && out[size + read - 1] == '\n') { - out.resize(size + read); - return success(); - } - size += read; - } -} - -// Returns std::nullopt when: -// - ferror(), feof(), or shutdownRequested() are set. -// - Content-Length is missing or empty (protocol error) -LogicalResult -JSONTransportInputOverFile::readStandardMessage(std::string &json) { - // A Language Server Protocol message starts with a set of HTTP headers, - // delimited by \r\n, and terminated by an empty line (\r\n). - unsigned long long contentLength = 0; - llvm::SmallString<128> line; - while (true) { - if (feof(in) || hasError() || failed(readLine(in, line))) - return failure(); - - // Content-Length is a mandatory header, and the only one we handle. - StringRef lineRef = line; - if (lineRef.consume_front("Content-Length: ")) { - llvm::getAsUnsignedInteger(lineRef.trim(), 0, contentLength); - } else if (!lineRef.trim().empty()) { - // It's another header, ignore it. - continue; - } else { - // An empty line indicates the end of headers. Go ahead and read the JSON. - break; - } - } - - // The fuzzer likes crashing us by sending "Content-Length: 9999999999999999" - if (contentLength == 0 || contentLength > 1 << 30) - return failure(); - - json.resize(contentLength); - for (size_t pos = 0, read; pos < contentLength; pos += read) { - read = std::fread(&json[pos], 1, contentLength - pos, in); - if (read == 0) - return failure(); - - // If we're done, the error was transient. If we're not done, either it was - // transient or we'll see it again on retry. - clearerr(in); - pos += read; - } - return success(); -} - -/// For lit tests we support a simplified syntax: -/// - messages are delimited by '// -----' on a line by itself -/// - lines starting with // are ignored. -/// This is a testing path, so favor simplicity over performance here. -/// When returning failure: feof(), ferror(), or shutdownRequested() will be -/// set. -LogicalResult -JSONTransportInputOverFile::readDelimitedMessage(std::string &json) { - json.clear(); - llvm::SmallString<128> line; - while (succeeded(readLine(in, line))) { - StringRef lineRef = line.str().trim(); - if (lineRef.starts_with("//")) { - // Found a delimiter for the message. - if (lineRef == kDefaultSplitMarker) - break; - continue; - } - - json += line; - } - - return failure(ferror(in)); -} diff --git a/mlir/lib/Tools/mlir-lsp-server/CMakeLists.txt b/mlir/lib/Tools/mlir-lsp-server/CMakeLists.txt index d04d5156f..e2acba5 100644 --- a/mlir/lib/Tools/mlir-lsp-server/CMakeLists.txt +++ b/mlir/lib/Tools/mlir-lsp-server/CMakeLists.txt @@ -7,6 +7,9 @@ add_mlir_library(MLIRLspServerLib ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Tools/mlir-lsp-server + LINK_COMPONENTS + SupportLSP + LINK_LIBS PUBLIC MLIRBytecodeWriter MLIRFunctionInterfaces diff --git a/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp b/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp index 9b937db..1bbbcde 100644 --- a/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp +++ b/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp @@ -9,8 +9,8 @@ #include "LSPServer.h" #include "MLIRServer.h" #include "Protocol.h" -#include "mlir/Tools/lsp-server-support/Logging.h" -#include "mlir/Tools/lsp-server-support/Transport.h" +#include "llvm/Support/LSP/Logging.h" +#include "llvm/Support/LSP/Transport.h" #include <optional> #define DEBUG_TYPE "mlir-lsp-server" @@ -18,6 +18,33 @@ using namespace mlir; using namespace mlir::lsp; +using llvm::lsp::Callback; +using llvm::lsp::CodeAction; +using llvm::lsp::CodeActionParams; +using llvm::lsp::CompletionList; +using llvm::lsp::CompletionParams; +using llvm::lsp::DidChangeTextDocumentParams; +using llvm::lsp::DidCloseTextDocumentParams; +using llvm::lsp::DidOpenTextDocumentParams; +using llvm::lsp::DocumentSymbol; +using llvm::lsp::DocumentSymbolParams; +using llvm::lsp::Hover; +using llvm::lsp::InitializedParams; +using llvm::lsp::InitializeParams; +using llvm::lsp::JSONTransport; +using llvm::lsp::Location; +using llvm::lsp::Logger; +using llvm::lsp::MessageHandler; +using llvm::lsp::MLIRConvertBytecodeParams; +using llvm::lsp::MLIRConvertBytecodeResult; +using llvm::lsp::NoParams; +using llvm::lsp::OutgoingNotification; +using llvm::lsp::PublishDiagnosticsParams; +using llvm::lsp::ReferenceParams; +using llvm::lsp::TextDocumentPositionParams; +using llvm::lsp::TextDocumentSyncKind; +using llvm::lsp::URIForFile; + //===----------------------------------------------------------------------===// // LSPServer //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Tools/mlir-lsp-server/LSPServer.h b/mlir/lib/Tools/mlir-lsp-server/LSPServer.h index 2c50c6b..d652899 100644 --- a/mlir/lib/Tools/mlir-lsp-server/LSPServer.h +++ b/mlir/lib/Tools/mlir-lsp-server/LSPServer.h @@ -13,17 +13,19 @@ namespace llvm { struct LogicalResult; +namespace lsp { +class JSONTransport; +} // namespace lsp } // namespace llvm namespace mlir { namespace lsp { -class JSONTransport; class MLIRServer; /// Run the main loop of the LSP server using the given MLIR server and /// transport. llvm::LogicalResult runMlirLSPServer(MLIRServer &server, - JSONTransport &transport); + llvm::lsp::JSONTransport &transport); } // namespace lsp } // namespace mlir diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp index 6198752..47b4328 100644 --- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp +++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp @@ -16,10 +16,10 @@ #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Parser/Parser.h" #include "mlir/Support/ToolUtilities.h" -#include "mlir/Tools/lsp-server-support/Logging.h" #include "mlir/Tools/lsp-server-support/SourceMgrUtils.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/Base64.h" +#include "llvm/Support/LSP/Logging.h" #include "llvm/Support/SourceMgr.h" #include <optional> @@ -39,9 +39,9 @@ static std::optional<lsp::Location> getLocationFromLoc(StringRef uriScheme, llvm::Expected<lsp::URIForFile> sourceURI = lsp::URIForFile::fromFile(loc.getFilename(), uriScheme); if (!sourceURI) { - lsp::Logger::error("Failed to create URI for file `{0}`: {1}", - loc.getFilename(), - llvm::toString(sourceURI.takeError())); + llvm::lsp::Logger::error("Failed to create URI for file `{0}`: {1}", + loc.getFilename(), + llvm::toString(sourceURI.takeError())); return std::nullopt; } @@ -217,22 +217,22 @@ static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, // Convert the severity for the diagnostic. switch (diag.getSeverity()) { - case DiagnosticSeverity::Note: + case mlir::DiagnosticSeverity::Note: llvm_unreachable("expected notes to be handled separately"); - case DiagnosticSeverity::Warning: - lspDiag.severity = lsp::DiagnosticSeverity::Warning; + case mlir::DiagnosticSeverity::Warning: + lspDiag.severity = llvm::lsp::DiagnosticSeverity::Warning; break; - case DiagnosticSeverity::Error: - lspDiag.severity = lsp::DiagnosticSeverity::Error; + case mlir::DiagnosticSeverity::Error: + lspDiag.severity = llvm::lsp::DiagnosticSeverity::Error; break; - case DiagnosticSeverity::Remark: - lspDiag.severity = lsp::DiagnosticSeverity::Information; + case mlir::DiagnosticSeverity::Remark: + lspDiag.severity = llvm::lsp::DiagnosticSeverity::Information; break; } lspDiag.message = diag.str(); // Attach any notes to the main diagnostic as related information. - std::vector<lsp::DiagnosticRelatedInformation> relatedDiags; + std::vector<llvm::lsp::DiagnosticRelatedInformation> relatedDiags; for (Diagnostic ¬e : diag.getNotes()) { lsp::Location noteLoc; if (std::optional<lsp::Location> loc = @@ -317,7 +317,7 @@ struct MLIRDocument { void getCodeActionForDiagnostic(const lsp::URIForFile &uri, lsp::Position &pos, StringRef severity, StringRef message, - std::vector<lsp::TextEdit> &edits); + std::vector<llvm::lsp::TextEdit> &edits); //===--------------------------------------------------------------------===// // Bytecode @@ -355,7 +355,8 @@ MLIRDocument::MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri, // Try to parsed the given IR string. auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file()); if (!memBuffer) { - lsp::Logger::error("Failed to create memory buffer for file", uri.file()); + llvm::lsp::Logger::error("Failed to create memory buffer for file", + uri.file()); return; } @@ -695,8 +696,8 @@ void MLIRDocument::findDocumentSymbols( if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op)) { symbols.emplace_back(symbol.getName(), isa<FunctionOpInterface>(op) - ? lsp::SymbolKind::Function - : lsp::SymbolKind::Class, + ? llvm::lsp::SymbolKind::Function + : llvm::lsp::SymbolKind::Class, lsp::Range(sourceMgr, def->scopeLoc), lsp::Range(sourceMgr, def->loc)); childSymbols = &symbols.back().children; @@ -704,9 +705,9 @@ void MLIRDocument::findDocumentSymbols( } else if (op->hasTrait<OpTrait::SymbolTable>()) { // Otherwise, if this is a symbol table push an anonymous document symbol. symbols.emplace_back("<" + op->getName().getStringRef() + ">", - lsp::SymbolKind::Namespace, - lsp::Range(sourceMgr, def->scopeLoc), - lsp::Range(sourceMgr, def->loc)); + llvm::lsp::SymbolKind::Namespace, + llvm::lsp::Range(sourceMgr, def->scopeLoc), + llvm::lsp::Range(sourceMgr, def->loc)); childSymbols = &symbols.back().children; } } @@ -734,9 +735,9 @@ public: /// Signal code completion for a dialect name, with an optional prefix. void completeDialectName(StringRef prefix) final { for (StringRef dialect : ctx->getAvailableDialects()) { - lsp::CompletionItem item(prefix + dialect, - lsp::CompletionItemKind::Module, - /*sortText=*/"3"); + llvm::lsp::CompletionItem item(prefix + dialect, + llvm::lsp::CompletionItemKind::Module, + /*sortText=*/"3"); item.detail = "dialect"; completionList.items.emplace_back(item); } @@ -753,9 +754,9 @@ public: if (&op.getDialect() != dialect) continue; - lsp::CompletionItem item( + llvm::lsp::CompletionItem item( op.getStringRef().drop_front(dialectName.size() + 1), - lsp::CompletionItemKind::Field, + llvm::lsp::CompletionItemKind::Field, /*sortText=*/"1"); item.detail = "operation"; completionList.items.emplace_back(item); @@ -768,7 +769,8 @@ public: // Check if we need to insert the `%` or not. bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '%'; - lsp::CompletionItem item(name, lsp::CompletionItemKind::Variable); + llvm::lsp::CompletionItem item(name, + llvm::lsp::CompletionItemKind::Variable); if (stripPrefix) item.insertText = name.drop_front(1).str(); item.detail = std::move(typeData); @@ -781,7 +783,7 @@ public: // Check if we need to insert the `^` or not. bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '^'; - lsp::CompletionItem item(name, lsp::CompletionItemKind::Field); + llvm::lsp::CompletionItem item(name, llvm::lsp::CompletionItemKind::Field); if (stripPrefix) item.insertText = name.drop_front(1).str(); completionList.items.emplace_back(item); @@ -790,8 +792,9 @@ public: /// Signal a completion for the given expected token. void completeExpectedTokens(ArrayRef<StringRef> tokens, bool optional) final { for (StringRef token : tokens) { - lsp::CompletionItem item(token, lsp::CompletionItemKind::Keyword, - /*sortText=*/"0"); + llvm::lsp::CompletionItem item(token, + llvm::lsp::CompletionItemKind::Keyword, + /*sortText=*/"0"); item.detail = optional ? "optional" : ""; completionList.items.emplace_back(item); } @@ -802,7 +805,7 @@ public: appendSimpleCompletions({"affine_set", "affine_map", "dense", "dense_resource", "false", "loc", "sparse", "true", "unit"}, - lsp::CompletionItemKind::Field, + llvm::lsp::CompletionItemKind::Field, /*sortText=*/"1"); completeDialectName("#"); @@ -820,13 +823,14 @@ public: appendSimpleCompletions({"memref", "tensor", "complex", "tuple", "vector", "bf16", "f16", "f32", "f64", "f80", "f128", "index", "none"}, - lsp::CompletionItemKind::Field, + llvm::lsp::CompletionItemKind::Field, /*sortText=*/"1"); // Handle the builtin integer types. for (StringRef type : {"i", "si", "ui"}) { - lsp::CompletionItem item(type + "<N>", lsp::CompletionItemKind::Field, - /*sortText=*/"1"); + llvm::lsp::CompletionItem item(type + "<N>", + llvm::lsp::CompletionItemKind::Field, + /*sortText=*/"1"); item.insertText = type.str(); completionList.items.emplace_back(item); } @@ -846,9 +850,9 @@ public: void completeAliases(const llvm::StringMap<T> &aliases, StringRef prefix = "") { for (const auto &alias : aliases) { - lsp::CompletionItem item(prefix + alias.getKey(), - lsp::CompletionItemKind::Field, - /*sortText=*/"2"); + llvm::lsp::CompletionItem item(prefix + alias.getKey(), + llvm::lsp::CompletionItemKind::Field, + /*sortText=*/"2"); llvm::raw_string_ostream(item.detail) << "alias: " << alias.getValue(); completionList.items.emplace_back(item); } @@ -856,7 +860,7 @@ public: /// Add a set of simple completions that all have the same kind. void appendSimpleCompletions(ArrayRef<StringRef> completions, - lsp::CompletionItemKind kind, + llvm::lsp::CompletionItemKind kind, StringRef sortText = "") { for (StringRef completion : completions) completionList.items.emplace_back(completion, kind, sortText); @@ -897,7 +901,7 @@ MLIRDocument::getCodeCompletion(const lsp::URIForFile &uri, void MLIRDocument::getCodeActionForDiagnostic( const lsp::URIForFile &uri, lsp::Position &pos, StringRef severity, - StringRef message, std::vector<lsp::TextEdit> &edits) { + StringRef message, std::vector<llvm::lsp::TextEdit> &edits) { // Ignore diagnostics that print the current operation. These are always // enabled for the language server, but not generally during normal // parsing/verification. @@ -913,7 +917,7 @@ void MLIRDocument::getCodeActionForDiagnostic( // Add a text edit for adding an expected-* diagnostic check for this // diagnostic. - lsp::TextEdit edit; + llvm::lsp::TextEdit edit; edit.range = lsp::Range(lsp::Position(pos.line, 0)); // Use the indent of the current line for the expected-* diagnostic. @@ -937,13 +941,14 @@ MLIRDocument::convertToBytecode() { // conceptually be relaxed. if (!llvm::hasSingleElement(parsedIR)) { if (parsedIR.empty()) { - return llvm::make_error<lsp::LSPError>( + return llvm::make_error<llvm::lsp::LSPError>( "expected a single and valid top-level operation, please ensure " "there are no errors", - lsp::ErrorCode::RequestFailed); + llvm::lsp::ErrorCode::RequestFailed); } - return llvm::make_error<lsp::LSPError>( - "expected a single top-level operation", lsp::ErrorCode::RequestFailed); + return llvm::make_error<llvm::lsp::LSPError>( + "expected a single top-level operation", + llvm::lsp::ErrorCode::RequestFailed); } lsp::MLIRConvertBytecodeResult result; @@ -1134,7 +1139,7 @@ void MLIRTextFile::findDocumentSymbols( lsp::Position endPos((i == e - 1) ? totalNumLines - 1 : chunks[i + 1]->lineOffset); lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">", - lsp::SymbolKind::Namespace, + llvm::lsp::SymbolKind::Namespace, /*range=*/lsp::Range(startPos, endPos), /*selectionRange=*/lsp::Range(startPos)); chunk.document.findDocumentSymbols(symbol.children); @@ -1167,10 +1172,10 @@ lsp::CompletionList MLIRTextFile::getCodeCompletion(const lsp::URIForFile &uri, uri, completePos, context.getDialectRegistry()); // Adjust any completion locations. - for (lsp::CompletionItem &item : completionList.items) { + for (llvm::lsp::CompletionItem &item : completionList.items) { if (item.textEdit) chunk.adjustLocForChunkOffset(item.textEdit->range); - for (lsp::TextEdit &edit : item.additionalTextEdits) + for (llvm::lsp::TextEdit &edit : item.additionalTextEdits) chunk.adjustLocForChunkOffset(edit.range); } return completionList; @@ -1194,10 +1199,10 @@ void MLIRTextFile::getCodeActions(const lsp::URIForFile &uri, StringRef severity; switch (diag.severity) { - case lsp::DiagnosticSeverity::Error: + case llvm::lsp::DiagnosticSeverity::Error: severity = "error"; break; - case lsp::DiagnosticSeverity::Warning: + case llvm::lsp::DiagnosticSeverity::Warning: severity = "warning"; break; default: @@ -1205,7 +1210,7 @@ void MLIRTextFile::getCodeActions(const lsp::URIForFile &uri, } // Get edits for the diagnostic. - std::vector<lsp::TextEdit> edits; + std::vector<llvm::lsp::TextEdit> edits; chunk.document.getCodeActionForDiagnostic(uri, diagPos, severity, diag.message, edits); @@ -1221,7 +1226,7 @@ void MLIRTextFile::getCodeActions(const lsp::URIForFile &uri, } } // Fixup the locations for any edits. - for (lsp::TextEdit &edit : edits) + for (llvm::lsp::TextEdit &edit : edits) chunk.adjustLocForChunkOffset(edit.range); action.edit.emplace(); @@ -1236,9 +1241,9 @@ llvm::Expected<lsp::MLIRConvertBytecodeResult> MLIRTextFile::convertToBytecode() { // Bail out if there is more than one chunk, bytecode wants a single module. if (chunks.size() != 1) { - return llvm::make_error<lsp::LSPError>( + return llvm::make_error<llvm::lsp::LSPError>( "unexpected split file, please remove all `// -----`", - lsp::ErrorCode::RequestFailed); + llvm::lsp::ErrorCode::RequestFailed); } return chunks.front()->document.convertToBytecode(); } @@ -1283,7 +1288,7 @@ lsp::MLIRServer::~MLIRServer() = default; void lsp::MLIRServer::addOrUpdateDocument( const URIForFile &uri, StringRef contents, int64_t version, - std::vector<Diagnostic> &diagnostics) { + std::vector<llvm::lsp::Diagnostic> &diagnostics) { impl->files[uri.file()] = std::make_unique<MLIRTextFile>( uri, contents, version, impl->registry_fn, diagnostics); } @@ -1298,17 +1303,17 @@ std::optional<int64_t> lsp::MLIRServer::removeDocument(const URIForFile &uri) { return version; } -void lsp::MLIRServer::getLocationsOf(const URIForFile &uri, - const Position &defPos, - std::vector<Location> &locations) { +void lsp::MLIRServer::getLocationsOf( + const URIForFile &uri, const Position &defPos, + std::vector<llvm::lsp::Location> &locations) { auto fileIt = impl->files.find(uri.file()); if (fileIt != impl->files.end()) fileIt->second->getLocationsOf(uri, defPos, locations); } -void lsp::MLIRServer::findReferencesOf(const URIForFile &uri, - const Position &pos, - std::vector<Location> &references) { +void lsp::MLIRServer::findReferencesOf( + const URIForFile &uri, const Position &pos, + std::vector<llvm::lsp::Location> &references) { auto fileIt = impl->files.find(uri.file()); if (fileIt != impl->files.end()) fileIt->second->findReferencesOf(uri, pos, references); @@ -1367,17 +1372,17 @@ lsp::MLIRServer::convertFromBytecode(const URIForFile &uri) { // Try to parse the given source file. Block parsedBlock; if (failed(parseSourceFile(uri.file(), &parsedBlock, parserConfig))) { - return llvm::make_error<lsp::LSPError>( + return llvm::make_error<llvm::lsp::LSPError>( "failed to parse bytecode source file: " + errorMsg, - lsp::ErrorCode::RequestFailed); + llvm::lsp::ErrorCode::RequestFailed); } // TODO: We currently expect a single top-level operation, but this could // conceptually be relaxed. if (!llvm::hasSingleElement(parsedBlock)) { - return llvm::make_error<lsp::LSPError>( + return llvm::make_error<llvm::lsp::LSPError>( "expected bytecode to contain a single top-level operation", - lsp::ErrorCode::RequestFailed); + llvm::lsp::ErrorCode::RequestFailed); } // Print the module to a buffer. @@ -1401,9 +1406,9 @@ llvm::Expected<lsp::MLIRConvertBytecodeResult> lsp::MLIRServer::convertToBytecode(const URIForFile &uri) { auto fileIt = impl->files.find(uri.file()); if (fileIt == impl->files.end()) { - return llvm::make_error<lsp::LSPError>( + return llvm::make_error<llvm::lsp::LSPError>( "language server does not contain an entry for this source file", - lsp::ErrorCode::RequestFailed); + llvm::lsp::ErrorCode::RequestFailed); } return fileIt->second->convertToBytecode(); } diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h index 85e69e6..31a01fe 100644 --- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h +++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h @@ -9,6 +9,7 @@ #ifndef LIB_MLIR_TOOLS_MLIRLSPSERVER_SERVER_H_ #define LIB_MLIR_TOOLS_MLIRLSPSERVER_SERVER_H_ +#include "Protocol.h" #include "mlir/Support/LLVM.h" #include "mlir/Tools/mlir-lsp-server/MlirLspRegistryFunction.h" #include "llvm/Support/Error.h" @@ -19,16 +20,17 @@ namespace mlir { class DialectRegistry; namespace lsp { -struct CodeAction; -struct CodeActionContext; -struct CompletionList; -struct Diagnostic; -struct DocumentSymbol; -struct Hover; -struct Location; -struct MLIRConvertBytecodeResult; -struct Position; -struct Range; +using llvm::lsp::CodeAction; +using llvm::lsp::CodeActionContext; +using llvm::lsp::CompletionList; +using llvm::lsp::Diagnostic; +using llvm::lsp::DocumentSymbol; +using llvm::lsp::Hover; +using llvm::lsp::Location; +using llvm::lsp::MLIRConvertBytecodeResult; +using llvm::lsp::Position; +using llvm::lsp::Range; +using llvm::lsp::URIForFile; /// This class implements all of the MLIR related functionality necessary for a /// language server. This class allows for keeping the MLIR specific logic diff --git a/mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp b/mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp index f1dc326..d4589b2 100644 --- a/mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp +++ b/mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp @@ -9,14 +9,18 @@ #include "mlir/Tools/mlir-lsp-server/MlirLspServerMain.h" #include "LSPServer.h" #include "MLIRServer.h" -#include "mlir/Tools/lsp-server-support/Logging.h" -#include "mlir/Tools/lsp-server-support/Transport.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/LSP/Logging.h" +#include "llvm/Support/LSP/Transport.h" #include "llvm/Support/Program.h" using namespace mlir; using namespace mlir::lsp; +using llvm::lsp::JSONStreamStyle; +using llvm::lsp::JSONTransport; +using llvm::lsp::Logger; + LogicalResult mlir::MlirLspServerMain(int argc, char **argv, DialectRegistryFn registry_fn) { llvm::cl::opt<JSONStreamStyle> inputStyle{ diff --git a/mlir/lib/Tools/mlir-lsp-server/Protocol.cpp b/mlir/lib/Tools/mlir-lsp-server/Protocol.cpp index a56e9a1..28aded3 100644 --- a/mlir/lib/Tools/mlir-lsp-server/Protocol.cpp +++ b/mlir/lib/Tools/mlir-lsp-server/Protocol.cpp @@ -13,14 +13,11 @@ #include "Protocol.h" #include "llvm/Support/JSON.h" -using namespace mlir; -using namespace mlir::lsp; - //===----------------------------------------------------------------------===// // MLIRConvertBytecodeParams //===----------------------------------------------------------------------===// -bool mlir::lsp::fromJSON(const llvm::json::Value &value, +bool llvm::lsp::fromJSON(const llvm::json::Value &value, MLIRConvertBytecodeParams &result, llvm::json::Path path) { llvm::json::ObjectMapper o(value, path); @@ -31,6 +28,6 @@ bool mlir::lsp::fromJSON(const llvm::json::Value &value, // MLIRConvertBytecodeResult //===----------------------------------------------------------------------===// -llvm::json::Value mlir::lsp::toJSON(const MLIRConvertBytecodeResult &value) { +llvm::json::Value llvm::lsp::toJSON(const MLIRConvertBytecodeResult &value) { return llvm::json::Object{{"output", value.output}}; } diff --git a/mlir/lib/Tools/mlir-lsp-server/Protocol.h b/mlir/lib/Tools/mlir-lsp-server/Protocol.h index d910780..ed0db4e 100644 --- a/mlir/lib/Tools/mlir-lsp-server/Protocol.h +++ b/mlir/lib/Tools/mlir-lsp-server/Protocol.h @@ -20,9 +20,9 @@ #ifndef LIB_MLIR_TOOLS_MLIRLSPSERVER_PROTOCOL_H_ #define LIB_MLIR_TOOLS_MLIRLSPSERVER_PROTOCOL_H_ -#include "mlir/Tools/lsp-server-support/Protocol.h" +#include "llvm/Support/LSP/Protocol.h" -namespace mlir { +namespace llvm { namespace lsp { //===----------------------------------------------------------------------===// // MLIRConvertBytecodeParams @@ -54,6 +54,6 @@ struct MLIRConvertBytecodeResult { llvm::json::Value toJSON(const MLIRConvertBytecodeResult &value); } // namespace lsp -} // namespace mlir +} // namespace llvm #endif diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/CMakeLists.txt b/mlir/lib/Tools/mlir-pdll-lsp-server/CMakeLists.txt index bf25b7e..b41603f 100644 --- a/mlir/lib/Tools/mlir-pdll-lsp-server/CMakeLists.txt +++ b/mlir/lib/Tools/mlir-pdll-lsp-server/CMakeLists.txt @@ -7,6 +7,9 @@ llvm_add_library(MLIRPdllLspServerLib ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Tools/mlir-pdll-lsp-server + LINK_COMPONENTS + SupportLSP + LINK_LIBS PUBLIC MLIRPDLLCodeGen MLIRPDLLParser diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp index 82542a1..7b23adc 100644 --- a/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp +++ b/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp @@ -10,8 +10,9 @@ #include "PDLLServer.h" #include "Protocol.h" -#include "mlir/Tools/lsp-server-support/Logging.h" -#include "mlir/Tools/lsp-server-support/Transport.h" +#include "llvm/Support/LSP/Logging.h" +#include "llvm/Support/LSP/Protocol.h" +#include "llvm/Support/LSP/Transport.h" #include <optional> #define DEBUG_TYPE "pdll-lsp-server" @@ -19,6 +20,30 @@ using namespace mlir; using namespace mlir::lsp; +using llvm::lsp::Callback; +using llvm::lsp::CompletionList; +using llvm::lsp::CompletionParams; +using llvm::lsp::DidChangeTextDocumentParams; +using llvm::lsp::DidCloseTextDocumentParams; +using llvm::lsp::DidOpenTextDocumentParams; +using llvm::lsp::DocumentLinkParams; +using llvm::lsp::DocumentSymbol; +using llvm::lsp::DocumentSymbolParams; +using llvm::lsp::Hover; +using llvm::lsp::InitializedParams; +using llvm::lsp::InitializeParams; +using llvm::lsp::InlayHintsParams; +using llvm::lsp::JSONTransport; +using llvm::lsp::Location; +using llvm::lsp::Logger; +using llvm::lsp::MessageHandler; +using llvm::lsp::NoParams; +using llvm::lsp::OutgoingNotification; +using llvm::lsp::PublishDiagnosticsParams; +using llvm::lsp::ReferenceParams; +using llvm::lsp::TextDocumentPositionParams; +using llvm::lsp::TextDocumentSyncKind; + //===----------------------------------------------------------------------===// // LSPServer //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.h b/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.h index 78c4c31..42c0a5d 100644 --- a/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.h +++ b/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.h @@ -13,17 +13,19 @@ namespace llvm { struct LogicalResult; +namespace lsp { +class JSONTransport; +} // namespace lsp } // namespace llvm namespace mlir { namespace lsp { -class JSONTransport; class PDLLServer; /// Run the main loop of the LSP server using the given PDLL server and /// transport. llvm::LogicalResult runPdllLSPServer(PDLLServer &server, - JSONTransport &transport); + llvm::lsp::JSONTransport &transport); } // namespace lsp } // namespace mlir diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/MlirPdllLspServerMain.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/MlirPdllLspServerMain.cpp index 287a131..5dea130 100644 --- a/mlir/lib/Tools/mlir-pdll-lsp-server/MlirPdllLspServerMain.cpp +++ b/mlir/lib/Tools/mlir-pdll-lsp-server/MlirPdllLspServerMain.cpp @@ -9,14 +9,17 @@ #include "mlir/Tools/mlir-pdll-lsp-server/MlirPdllLspServerMain.h" #include "LSPServer.h" #include "PDLLServer.h" -#include "mlir/Tools/lsp-server-support/Logging.h" -#include "mlir/Tools/lsp-server-support/Transport.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/LSP/Logging.h" +#include "llvm/Support/LSP/Transport.h" #include "llvm/Support/Program.h" using namespace mlir; using namespace mlir::lsp; +using llvm::lsp::JSONStreamStyle; +using llvm::lsp::Logger; + LogicalResult mlir::MlirPdllLspServerMain(int argc, char **argv) { llvm::cl::opt<JSONStreamStyle> inputStyle{ "input-style", @@ -72,7 +75,8 @@ LogicalResult mlir::MlirPdllLspServerMain(int argc, char **argv) { // Configure the transport used for communication. llvm::sys::ChangeStdinToBinary(); - JSONTransport transport(stdin, llvm::outs(), inputStyle, prettyPrint); + llvm::lsp::JSONTransport transport(stdin, llvm::outs(), inputStyle, + prettyPrint); // Configure the servers and start the main language server. PDLLServer::Options options(compilationDatabases, extraIncludeDirs); diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp index 84f529a..60b9567 100644 --- a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp +++ b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp @@ -23,13 +23,13 @@ #include "mlir/Tools/PDLL/Parser/CodeComplete.h" #include "mlir/Tools/PDLL/Parser/Parser.h" #include "mlir/Tools/lsp-server-support/CompilationDatabase.h" -#include "mlir/Tools/lsp-server-support/Logging.h" #include "mlir/Tools/lsp-server-support/SourceMgrUtils.h" #include "llvm/ADT/IntervalMap.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringSet.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/FileSystem.h" +#include "llvm/Support/LSP/Logging.h" #include "llvm/Support/Path.h" #include <optional> @@ -38,17 +38,19 @@ using namespace mlir::pdll; /// Returns a language server uri for the given source location. `mainFileURI` /// corresponds to the uri for the main file of the source manager. -static lsp::URIForFile getURIFromLoc(llvm::SourceMgr &mgr, SMRange loc, - const lsp::URIForFile &mainFileURI) { +static llvm::lsp::URIForFile +getURIFromLoc(llvm::SourceMgr &mgr, SMRange loc, + const llvm::lsp::URIForFile &mainFileURI) { int bufferId = mgr.FindBufferContainingLoc(loc.Start); if (bufferId == 0 || bufferId == static_cast<int>(mgr.getMainFileID())) return mainFileURI; - llvm::Expected<lsp::URIForFile> fileForLoc = lsp::URIForFile::fromFile( - mgr.getBufferInfo(bufferId).Buffer->getBufferIdentifier()); + llvm::Expected<llvm::lsp::URIForFile> fileForLoc = + llvm::lsp::URIForFile::fromFile( + mgr.getBufferInfo(bufferId).Buffer->getBufferIdentifier()); if (fileForLoc) return *fileForLoc; - lsp::Logger::error("Failed to create URI for include file: {0}", - llvm::toString(fileForLoc.takeError())); + llvm::lsp::Logger::error("Failed to create URI for include file: {0}", + llvm::toString(fileForLoc.takeError())); return mainFileURI; } @@ -59,16 +61,18 @@ static bool isMainFileLoc(llvm::SourceMgr &mgr, SMRange loc) { } /// Returns a language server location from the given source range. -static lsp::Location getLocationFromLoc(llvm::SourceMgr &mgr, SMRange range, - const lsp::URIForFile &uri) { - return lsp::Location(getURIFromLoc(mgr, range, uri), lsp::Range(mgr, range)); +static llvm::lsp::Location +getLocationFromLoc(llvm::SourceMgr &mgr, SMRange range, + const llvm::lsp::URIForFile &uri) { + return llvm::lsp::Location(getURIFromLoc(mgr, range, uri), + llvm::lsp::Range(mgr, range)); } /// Convert the given MLIR diagnostic to the LSP form. -static std::optional<lsp::Diagnostic> +static std::optional<llvm::lsp::Diagnostic> getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, const ast::Diagnostic &diag, - const lsp::URIForFile &uri) { - lsp::Diagnostic lspDiag; + const llvm::lsp::URIForFile &uri) { + llvm::lsp::Diagnostic lspDiag; lspDiag.source = "pdll"; // FIXME: Right now all of the diagnostics are treated as parser issues, but @@ -76,7 +80,8 @@ getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, const ast::Diagnostic &diag, lspDiag.category = "Parse Error"; // Try to grab a file location for this diagnostic. - lsp::Location loc = getLocationFromLoc(sourceMgr, diag.getLocation(), uri); + llvm::lsp::Location loc = + getLocationFromLoc(sourceMgr, diag.getLocation(), uri); lspDiag.range = loc.range; // Skip diagnostics that weren't emitted within the main file. @@ -88,19 +93,19 @@ getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, const ast::Diagnostic &diag, case ast::Diagnostic::Severity::DK_Note: llvm_unreachable("expected notes to be handled separately"); case ast::Diagnostic::Severity::DK_Warning: - lspDiag.severity = lsp::DiagnosticSeverity::Warning; + lspDiag.severity = llvm::lsp::DiagnosticSeverity::Warning; break; case ast::Diagnostic::Severity::DK_Error: - lspDiag.severity = lsp::DiagnosticSeverity::Error; + lspDiag.severity = llvm::lsp::DiagnosticSeverity::Error; break; case ast::Diagnostic::Severity::DK_Remark: - lspDiag.severity = lsp::DiagnosticSeverity::Information; + lspDiag.severity = llvm::lsp::DiagnosticSeverity::Information; break; } lspDiag.message = diag.getMessage().str(); // Attach any notes to the main diagnostic as related information. - std::vector<lsp::DiagnosticRelatedInformation> relatedDiags; + std::vector<llvm::lsp::DiagnosticRelatedInformation> relatedDiags; for (const ast::Diagnostic ¬e : diag.getNotes()) { relatedDiags.emplace_back( getLocationFromLoc(sourceMgr, note.getLocation(), uri), @@ -259,9 +264,9 @@ namespace { /// This class represents all of the information pertaining to a specific PDL /// document. struct PDLDocument { - PDLDocument(const lsp::URIForFile &uri, StringRef contents, + PDLDocument(const llvm::lsp::URIForFile &uri, StringRef contents, const std::vector<std::string> &extraDirs, - std::vector<lsp::Diagnostic> &diagnostics); + std::vector<llvm::lsp::Diagnostic> &diagnostics); PDLDocument(const PDLDocument &) = delete; PDLDocument &operator=(const PDLDocument &) = delete; @@ -269,76 +274,83 @@ struct PDLDocument { // Definitions and References //===--------------------------------------------------------------------===// - void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos, - std::vector<lsp::Location> &locations); - void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos, - std::vector<lsp::Location> &references); + void getLocationsOf(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &defPos, + std::vector<llvm::lsp::Location> &locations); + void findReferencesOf(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &pos, + std::vector<llvm::lsp::Location> &references); //===--------------------------------------------------------------------===// // Document Links //===--------------------------------------------------------------------===// - void getDocumentLinks(const lsp::URIForFile &uri, - std::vector<lsp::DocumentLink> &links); + void getDocumentLinks(const llvm::lsp::URIForFile &uri, + std::vector<llvm::lsp::DocumentLink> &links); //===--------------------------------------------------------------------===// // Hover //===--------------------------------------------------------------------===// - std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri, - const lsp::Position &hoverPos); - std::optional<lsp::Hover> findHover(const ast::Decl *decl, - const SMRange &hoverRange); - lsp::Hover buildHoverForOpName(const ods::Operation *op, - const SMRange &hoverRange); - lsp::Hover buildHoverForVariable(const ast::VariableDecl *varDecl, - const SMRange &hoverRange); - lsp::Hover buildHoverForPattern(const ast::PatternDecl *decl, - const SMRange &hoverRange); - lsp::Hover buildHoverForCoreConstraint(const ast::CoreConstraintDecl *decl, + std::optional<llvm::lsp::Hover> + findHover(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &hoverPos); + std::optional<llvm::lsp::Hover> findHover(const ast::Decl *decl, + const SMRange &hoverRange); + llvm::lsp::Hover buildHoverForOpName(const ods::Operation *op, + const SMRange &hoverRange); + llvm::lsp::Hover buildHoverForVariable(const ast::VariableDecl *varDecl, const SMRange &hoverRange); + llvm::lsp::Hover buildHoverForPattern(const ast::PatternDecl *decl, + const SMRange &hoverRange); + llvm::lsp::Hover + buildHoverForCoreConstraint(const ast::CoreConstraintDecl *decl, + const SMRange &hoverRange); template <typename T> - lsp::Hover buildHoverForUserConstraintOrRewrite(StringRef typeName, - const T *decl, - const SMRange &hoverRange); + llvm::lsp::Hover + buildHoverForUserConstraintOrRewrite(StringRef typeName, const T *decl, + const SMRange &hoverRange); //===--------------------------------------------------------------------===// // Document Symbols //===--------------------------------------------------------------------===// - void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols); + void findDocumentSymbols(std::vector<llvm::lsp::DocumentSymbol> &symbols); //===--------------------------------------------------------------------===// // Code Completion //===--------------------------------------------------------------------===// - lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri, - const lsp::Position &completePos); + llvm::lsp::CompletionList + getCodeCompletion(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &completePos); //===--------------------------------------------------------------------===// // Signature Help //===--------------------------------------------------------------------===// - lsp::SignatureHelp getSignatureHelp(const lsp::URIForFile &uri, - const lsp::Position &helpPos); + llvm::lsp::SignatureHelp getSignatureHelp(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &helpPos); //===--------------------------------------------------------------------===// // Inlay Hints //===--------------------------------------------------------------------===// - void getInlayHints(const lsp::URIForFile &uri, const lsp::Range &range, - std::vector<lsp::InlayHint> &inlayHints); + void getInlayHints(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Range &range, + std::vector<llvm::lsp::InlayHint> &inlayHints); void getInlayHintsFor(const ast::VariableDecl *decl, - const lsp::URIForFile &uri, - std::vector<lsp::InlayHint> &inlayHints); - void getInlayHintsFor(const ast::CallExpr *expr, const lsp::URIForFile &uri, - std::vector<lsp::InlayHint> &inlayHints); + const llvm::lsp::URIForFile &uri, + std::vector<llvm::lsp::InlayHint> &inlayHints); + void getInlayHintsFor(const ast::CallExpr *expr, + const llvm::lsp::URIForFile &uri, + std::vector<llvm::lsp::InlayHint> &inlayHints); void getInlayHintsFor(const ast::OperationExpr *expr, - const lsp::URIForFile &uri, - std::vector<lsp::InlayHint> &inlayHints); + const llvm::lsp::URIForFile &uri, + std::vector<llvm::lsp::InlayHint> &inlayHints); /// Add a parameter hint for the given expression using `label`. - void addParameterHintFor(std::vector<lsp::InlayHint> &inlayHints, + void addParameterHintFor(std::vector<llvm::lsp::InlayHint> &inlayHints, const ast::Expr *expr, StringRef label); //===--------------------------------------------------------------------===// @@ -372,13 +384,14 @@ struct PDLDocument { }; } // namespace -PDLDocument::PDLDocument(const lsp::URIForFile &uri, StringRef contents, +PDLDocument::PDLDocument(const llvm::lsp::URIForFile &uri, StringRef contents, const std::vector<std::string> &extraDirs, - std::vector<lsp::Diagnostic> &diagnostics) + std::vector<llvm::lsp::Diagnostic> &diagnostics) : astContext(odsContext) { auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file()); if (!memBuffer) { - lsp::Logger::error("Failed to create memory buffer for file", uri.file()); + llvm::lsp::Logger::error("Failed to create memory buffer for file", + uri.file()); return; } @@ -412,9 +425,9 @@ PDLDocument::PDLDocument(const lsp::URIForFile &uri, StringRef contents, // PDLDocument: Definitions and References //===----------------------------------------------------------------------===// -void PDLDocument::getLocationsOf(const lsp::URIForFile &uri, - const lsp::Position &defPos, - std::vector<lsp::Location> &locations) { +void PDLDocument::getLocationsOf(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &defPos, + std::vector<llvm::lsp::Location> &locations) { SMLoc posLoc = defPos.getAsSMLoc(sourceMgr); const PDLIndexSymbol *symbol = index.lookup(posLoc); if (!symbol) @@ -423,9 +436,9 @@ void PDLDocument::getLocationsOf(const lsp::URIForFile &uri, locations.push_back(getLocationFromLoc(sourceMgr, symbol->getDefLoc(), uri)); } -void PDLDocument::findReferencesOf(const lsp::URIForFile &uri, - const lsp::Position &pos, - std::vector<lsp::Location> &references) { +void PDLDocument::findReferencesOf( + const llvm::lsp::URIForFile &uri, const llvm::lsp::Position &pos, + std::vector<llvm::lsp::Location> &references) { SMLoc posLoc = pos.getAsSMLoc(sourceMgr); const PDLIndexSymbol *symbol = index.lookup(posLoc); if (!symbol) @@ -440,8 +453,9 @@ void PDLDocument::findReferencesOf(const lsp::URIForFile &uri, // PDLDocument: Document Links //===--------------------------------------------------------------------===// -void PDLDocument::getDocumentLinks(const lsp::URIForFile &uri, - std::vector<lsp::DocumentLink> &links) { +void PDLDocument::getDocumentLinks( + const llvm::lsp::URIForFile &uri, + std::vector<llvm::lsp::DocumentLink> &links) { for (const lsp::SourceMgrInclude &include : parsedIncludes) links.emplace_back(include.range, include.uri); } @@ -450,9 +464,9 @@ void PDLDocument::getDocumentLinks(const lsp::URIForFile &uri, // PDLDocument: Hover //===----------------------------------------------------------------------===// -std::optional<lsp::Hover> -PDLDocument::findHover(const lsp::URIForFile &uri, - const lsp::Position &hoverPos) { +std::optional<llvm::lsp::Hover> +PDLDocument::findHover(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &hoverPos) { SMLoc posLoc = hoverPos.getAsSMLoc(sourceMgr); // Check for a reference to an include. @@ -474,8 +488,8 @@ PDLDocument::findHover(const lsp::URIForFile &uri, return findHover(decl, hoverRange); } -std::optional<lsp::Hover> PDLDocument::findHover(const ast::Decl *decl, - const SMRange &hoverRange) { +std::optional<llvm::lsp::Hover> +PDLDocument::findHover(const ast::Decl *decl, const SMRange &hoverRange) { // Add hover for variables. if (const auto *varDecl = dyn_cast<ast::VariableDecl>(decl)) return buildHoverForVariable(varDecl, hoverRange); @@ -499,9 +513,9 @@ std::optional<lsp::Hover> PDLDocument::findHover(const ast::Decl *decl, return std::nullopt; } -lsp::Hover PDLDocument::buildHoverForOpName(const ods::Operation *op, - const SMRange &hoverRange) { - lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); +llvm::lsp::Hover PDLDocument::buildHoverForOpName(const ods::Operation *op, + const SMRange &hoverRange) { + llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange)); { llvm::raw_string_ostream hoverOS(hover.contents.value); hoverOS << "**OpName**: `" << op->getName() << "`\n***\n" @@ -511,9 +525,10 @@ lsp::Hover PDLDocument::buildHoverForOpName(const ods::Operation *op, return hover; } -lsp::Hover PDLDocument::buildHoverForVariable(const ast::VariableDecl *varDecl, - const SMRange &hoverRange) { - lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); +llvm::lsp::Hover +PDLDocument::buildHoverForVariable(const ast::VariableDecl *varDecl, + const SMRange &hoverRange) { + llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange)); { llvm::raw_string_ostream hoverOS(hover.contents.value); hoverOS << "**Variable**: `" << varDecl->getName().getName() << "`\n***\n" @@ -522,9 +537,9 @@ lsp::Hover PDLDocument::buildHoverForVariable(const ast::VariableDecl *varDecl, return hover; } -lsp::Hover PDLDocument::buildHoverForPattern(const ast::PatternDecl *decl, - const SMRange &hoverRange) { - lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); +llvm::lsp::Hover PDLDocument::buildHoverForPattern(const ast::PatternDecl *decl, + const SMRange &hoverRange) { + llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange)); { llvm::raw_string_ostream hoverOS(hover.contents.value); hoverOS << "**Pattern**"; @@ -545,10 +560,10 @@ lsp::Hover PDLDocument::buildHoverForPattern(const ast::PatternDecl *decl, return hover; } -lsp::Hover +llvm::lsp::Hover PDLDocument::buildHoverForCoreConstraint(const ast::CoreConstraintDecl *decl, const SMRange &hoverRange) { - lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); + llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange)); { llvm::raw_string_ostream hoverOS(hover.contents.value); hoverOS << "**Constraint**: `"; @@ -573,9 +588,9 @@ PDLDocument::buildHoverForCoreConstraint(const ast::CoreConstraintDecl *decl, } template <typename T> -lsp::Hover PDLDocument::buildHoverForUserConstraintOrRewrite( +llvm::lsp::Hover PDLDocument::buildHoverForUserConstraintOrRewrite( StringRef typeName, const T *decl, const SMRange &hoverRange) { - lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); + llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange)); { llvm::raw_string_ostream hoverOS(hover.contents.value); hoverOS << "**" << typeName << "**: `" << decl->getName().getName() @@ -617,7 +632,7 @@ lsp::Hover PDLDocument::buildHoverForUserConstraintOrRewrite( //===----------------------------------------------------------------------===// void PDLDocument::findDocumentSymbols( - std::vector<lsp::DocumentSymbol> &symbols) { + std::vector<llvm::lsp::DocumentSymbol> &symbols) { if (failed(astModule)) return; @@ -631,25 +646,28 @@ void PDLDocument::findDocumentSymbols( SMRange nameLoc = name ? name->getLoc() : patternDecl->getLoc(); SMRange bodyLoc(nameLoc.Start, patternDecl->getBody()->getLoc().End); - symbols.emplace_back( - name ? name->getName() : "<pattern>", lsp::SymbolKind::Class, - lsp::Range(sourceMgr, bodyLoc), lsp::Range(sourceMgr, nameLoc)); + symbols.emplace_back(name ? name->getName() : "<pattern>", + llvm::lsp::SymbolKind::Class, + llvm::lsp::Range(sourceMgr, bodyLoc), + llvm::lsp::Range(sourceMgr, nameLoc)); } else if (const auto *cDecl = dyn_cast<ast::UserConstraintDecl>(decl)) { // TODO: Add source information for the code block body. SMRange nameLoc = cDecl->getName().getLoc(); SMRange bodyLoc = nameLoc; - symbols.emplace_back( - cDecl->getName().getName(), lsp::SymbolKind::Function, - lsp::Range(sourceMgr, bodyLoc), lsp::Range(sourceMgr, nameLoc)); + symbols.emplace_back(cDecl->getName().getName(), + llvm::lsp::SymbolKind::Function, + llvm::lsp::Range(sourceMgr, bodyLoc), + llvm::lsp::Range(sourceMgr, nameLoc)); } else if (const auto *cDecl = dyn_cast<ast::UserRewriteDecl>(decl)) { // TODO: Add source information for the code block body. SMRange nameLoc = cDecl->getName().getLoc(); SMRange bodyLoc = nameLoc; - symbols.emplace_back( - cDecl->getName().getName(), lsp::SymbolKind::Function, - lsp::Range(sourceMgr, bodyLoc), lsp::Range(sourceMgr, nameLoc)); + symbols.emplace_back(cDecl->getName().getName(), + llvm::lsp::SymbolKind::Function, + llvm::lsp::Range(sourceMgr, bodyLoc), + llvm::lsp::Range(sourceMgr, nameLoc)); } } } @@ -662,7 +680,7 @@ namespace { class LSPCodeCompleteContext : public CodeCompleteContext { public: LSPCodeCompleteContext(SMLoc completeLoc, llvm::SourceMgr &sourceMgr, - lsp::CompletionList &completionList, + llvm::lsp::CompletionList &completionList, ods::Context &odsContext, ArrayRef<std::string> includeDirs) : CodeCompleteContext(completeLoc), sourceMgr(sourceMgr), @@ -674,13 +692,13 @@ public: ArrayRef<StringRef> elementNames = tupleType.getElementNames(); for (unsigned i = 0, e = tupleType.size(); i < e; ++i) { // Push back a completion item that uses the result index. - lsp::CompletionItem item; + llvm::lsp::CompletionItem item; item.label = llvm::formatv("{0} (field #{0})", i).str(); item.insertText = Twine(i).str(); item.filterText = item.sortText = item.insertText; - item.kind = lsp::CompletionItemKind::Field; + item.kind = llvm::lsp::CompletionItemKind::Field; item.detail = llvm::formatv("{0}: {1}", i, elementTypes[i]); - item.insertTextFormat = lsp::InsertTextFormat::PlainText; + item.insertTextFormat = llvm::lsp::InsertTextFormat::PlainText; completionList.items.emplace_back(item); // If the element has a name, push back a completion item with that name. @@ -705,11 +723,11 @@ public: const ods::TypeConstraint &constraint = result.getConstraint(); // Push back a completion item that uses the result index. - lsp::CompletionItem item; + llvm::lsp::CompletionItem item; item.label = llvm::formatv("{0} (field #{0})", it.index()).str(); item.insertText = Twine(it.index()).str(); item.filterText = item.sortText = item.insertText; - item.kind = lsp::CompletionItemKind::Field; + item.kind = llvm::lsp::CompletionItemKind::Field; switch (result.getVariableLengthKind()) { case ods::VariableLengthKind::Single: item.detail = llvm::formatv("{0}: Value", it.index()).str(); @@ -721,12 +739,12 @@ public: item.detail = llvm::formatv("{0}: ValueRange", it.index()).str(); break; } - item.documentation = lsp::MarkupContent{ - lsp::MarkupKind::Markdown, + item.documentation = llvm::lsp::MarkupContent{ + llvm::lsp::MarkupKind::Markdown, llvm::formatv("{0}\n\n```c++\n{1}\n```\n", constraint.getSummary(), constraint.getCppClass()) .str()}; - item.insertTextFormat = lsp::InsertTextFormat::PlainText; + item.insertTextFormat = llvm::lsp::InsertTextFormat::PlainText; completionList.items.emplace_back(item); // If the result has a name, push back a completion item with the result @@ -750,16 +768,16 @@ public: for (const ods::Attribute &attr : odsOp->getAttributes()) { const ods::AttributeConstraint &constraint = attr.getConstraint(); - lsp::CompletionItem item; + llvm::lsp::CompletionItem item; item.label = attr.getName().str(); - item.kind = lsp::CompletionItemKind::Field; + item.kind = llvm::lsp::CompletionItemKind::Field; item.detail = attr.isOptional() ? "optional" : ""; - item.documentation = lsp::MarkupContent{ - lsp::MarkupKind::Markdown, + item.documentation = llvm::lsp::MarkupContent{ + llvm::lsp::MarkupKind::Markdown, llvm::formatv("{0}\n\n```c++\n{1}\n```\n", constraint.getSummary(), constraint.getCppClass()) .str()}; - item.insertTextFormat = lsp::InsertTextFormat::PlainText; + item.insertTextFormat = llvm::lsp::InsertTextFormat::PlainText; completionList.items.emplace_back(item); } } @@ -769,18 +787,18 @@ public: const ast::DeclScope *scope) final { auto addCoreConstraint = [&](StringRef constraint, StringRef mlirType, StringRef snippetText = "") { - lsp::CompletionItem item; + llvm::lsp::CompletionItem item; item.label = constraint.str(); - item.kind = lsp::CompletionItemKind::Class; + item.kind = llvm::lsp::CompletionItemKind::Class; item.detail = (constraint + " constraint").str(); - item.documentation = lsp::MarkupContent{ - lsp::MarkupKind::Markdown, + item.documentation = llvm::lsp::MarkupContent{ + llvm::lsp::MarkupKind::Markdown, ("A single entity core constraint of type `" + mlirType + "`").str()}; item.sortText = "0"; item.insertText = snippetText.str(); item.insertTextFormat = snippetText.empty() - ? lsp::InsertTextFormat::PlainText - : lsp::InsertTextFormat::Snippet; + ? llvm::lsp::InsertTextFormat::PlainText + : llvm::lsp::InsertTextFormat::Snippet; completionList.items.emplace_back(item); }; @@ -812,9 +830,9 @@ public: while (scope) { for (const ast::Decl *decl : scope->getDecls()) { if (const auto *cst = dyn_cast<ast::UserConstraintDecl>(decl)) { - lsp::CompletionItem item; + llvm::lsp::CompletionItem item; item.label = cst->getName().getName().str(); - item.kind = lsp::CompletionItemKind::Interface; + item.kind = llvm::lsp::CompletionItemKind::Interface; item.sortText = "2_" + item.label; // Skip constraints that are not single-arg. We currently only @@ -841,8 +859,8 @@ public: // Format the documentation for the constraint. if (std::optional<std::string> doc = getDocumentationFor(sourceMgr, cst)) { - item.documentation = - lsp::MarkupContent{lsp::MarkupKind::Markdown, std::move(*doc)}; + item.documentation = llvm::lsp::MarkupContent{ + llvm::lsp::MarkupKind::Markdown, std::move(*doc)}; } completionList.items.emplace_back(item); @@ -856,10 +874,10 @@ public: void codeCompleteDialectName() final { // Code complete known dialects. for (const ods::Dialect &dialect : odsContext.getDialects()) { - lsp::CompletionItem item; + llvm::lsp::CompletionItem item; item.label = dialect.getName().str(); - item.kind = lsp::CompletionItemKind::Class; - item.insertTextFormat = lsp::InsertTextFormat::PlainText; + item.kind = llvm::lsp::CompletionItemKind::Class; + item.insertTextFormat = llvm::lsp::InsertTextFormat::PlainText; completionList.items.emplace_back(item); } } @@ -872,10 +890,10 @@ public: for (const auto &it : dialect->getOperations()) { const ods::Operation &op = *it.second; - lsp::CompletionItem item; + llvm::lsp::CompletionItem item; item.label = op.getName().drop_front(dialectName.size() + 1).str(); - item.kind = lsp::CompletionItemKind::Field; - item.insertTextFormat = lsp::InsertTextFormat::PlainText; + item.kind = llvm::lsp::CompletionItemKind::Field; + item.insertTextFormat = llvm::lsp::InsertTextFormat::PlainText; completionList.items.emplace_back(item); } } @@ -883,16 +901,16 @@ public: void codeCompletePatternMetadata() final { auto addSimpleConstraint = [&](StringRef constraint, StringRef desc, StringRef snippetText = "") { - lsp::CompletionItem item; + llvm::lsp::CompletionItem item; item.label = constraint.str(); - item.kind = lsp::CompletionItemKind::Class; + item.kind = llvm::lsp::CompletionItemKind::Class; item.detail = "pattern metadata"; item.documentation = - lsp::MarkupContent{lsp::MarkupKind::Markdown, desc.str()}; + llvm::lsp::MarkupContent{llvm::lsp::MarkupKind::Markdown, desc.str()}; item.insertText = snippetText.str(); item.insertTextFormat = snippetText.empty() - ? lsp::InsertTextFormat::PlainText - : lsp::InsertTextFormat::Snippet; + ? llvm::lsp::InsertTextFormat::PlainText + : llvm::lsp::InsertTextFormat::Snippet; completionList.items.emplace_back(item); }; @@ -913,10 +931,10 @@ public: // Functor used to add a single include completion item. auto addIncludeCompletion = [&](StringRef path, bool isDirectory) { - lsp::CompletionItem item; + llvm::lsp::CompletionItem item; item.label = path.str(); - item.kind = isDirectory ? lsp::CompletionItemKind::Folder - : lsp::CompletionItemKind::File; + item.kind = isDirectory ? llvm::lsp::CompletionItemKind::Folder + : llvm::lsp::CompletionItemKind::File; if (seenResults.insert(item.label).second) completionList.items.emplace_back(item); }; @@ -961,31 +979,31 @@ public: // Sort the completion results to make sure the output is deterministic in // the face of different iteration schemes for different platforms. - llvm::sort(completionList.items, [](const lsp::CompletionItem &lhs, - const lsp::CompletionItem &rhs) { + llvm::sort(completionList.items, [](const llvm::lsp::CompletionItem &lhs, + const llvm::lsp::CompletionItem &rhs) { return lhs.label < rhs.label; }); } private: llvm::SourceMgr &sourceMgr; - lsp::CompletionList &completionList; + llvm::lsp::CompletionList &completionList; ods::Context &odsContext; ArrayRef<std::string> includeDirs; }; } // namespace -lsp::CompletionList -PDLDocument::getCodeCompletion(const lsp::URIForFile &uri, - const lsp::Position &completePos) { +llvm::lsp::CompletionList +PDLDocument::getCodeCompletion(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &completePos) { SMLoc posLoc = completePos.getAsSMLoc(sourceMgr); if (!posLoc.isValid()) - return lsp::CompletionList(); + return llvm::lsp::CompletionList(); // To perform code completion, we run another parse of the module with the // code completion context provided. ods::Context tmpODSContext; - lsp::CompletionList completionList; + llvm::lsp::CompletionList completionList; LSPCodeCompleteContext lspCompleteContext(posLoc, sourceMgr, completionList, tmpODSContext, sourceMgr.getIncludeDirs()); @@ -1005,7 +1023,7 @@ namespace { class LSPSignatureHelpContext : public CodeCompleteContext { public: LSPSignatureHelpContext(SMLoc completeLoc, llvm::SourceMgr &sourceMgr, - lsp::SignatureHelp &signatureHelp, + llvm::lsp::SignatureHelp &signatureHelp, ods::Context &odsContext) : CodeCompleteContext(completeLoc), sourceMgr(sourceMgr), signatureHelp(signatureHelp), odsContext(odsContext) {} @@ -1014,7 +1032,7 @@ public: unsigned currentNumArgs) final { signatureHelp.activeParameter = currentNumArgs; - lsp::SignatureInformation signatureInfo; + llvm::lsp::SignatureInformation signatureInfo; { llvm::raw_string_ostream strOS(signatureInfo.label); strOS << callable->getName()->getName() << "("; @@ -1022,7 +1040,7 @@ public: unsigned paramStart = strOS.str().size(); strOS << var->getName().getName() << ": " << var->getType(); unsigned paramEnd = strOS.str().size(); - signatureInfo.parameters.emplace_back(lsp::ParameterInformation{ + signatureInfo.parameters.emplace_back(llvm::lsp::ParameterInformation{ StringRef(strOS.str()).slice(paramStart, paramEnd).str(), std::make_pair(paramStart, paramEnd), /*paramDoc*/ std::string()}); }; @@ -1070,7 +1088,7 @@ public: // not more than what is defined in ODS, as this will result in an error // anyways. if (odsOp && currentValue < values.size()) { - lsp::SignatureInformation signatureInfo; + llvm::lsp::SignatureInformation signatureInfo; // Build the signature label. { @@ -1099,7 +1117,7 @@ public: } unsigned paramEnd = strOS.str().size(); - signatureInfo.parameters.emplace_back(lsp::ParameterInformation{ + signatureInfo.parameters.emplace_back(llvm::lsp::ParameterInformation{ StringRef(strOS.str()).slice(paramStart, paramEnd).str(), std::make_pair(paramStart, paramEnd), paramDoc}); }; @@ -1114,12 +1132,12 @@ public: // If there aren't any arguments yet, we also add the generic signature. if (currentValue == 0 && (!odsOp || !values.empty())) { - lsp::SignatureInformation signatureInfo; + llvm::lsp::SignatureInformation signatureInfo; signatureInfo.label = llvm::formatv("(<{0}s>: {1}Range)", label, dataType).str(); signatureInfo.documentation = ("Generic operation " + label + " specification").str(); - signatureInfo.parameters.emplace_back(lsp::ParameterInformation{ + signatureInfo.parameters.emplace_back(llvm::lsp::ParameterInformation{ StringRef(signatureInfo.label).drop_front().drop_back().str(), std::pair<unsigned, unsigned>(1, signatureInfo.label.size() - 1), ("All of the " + label + "s of the operation.").str()}); @@ -1129,21 +1147,22 @@ public: private: llvm::SourceMgr &sourceMgr; - lsp::SignatureHelp &signatureHelp; + llvm::lsp::SignatureHelp &signatureHelp; ods::Context &odsContext; }; } // namespace -lsp::SignatureHelp PDLDocument::getSignatureHelp(const lsp::URIForFile &uri, - const lsp::Position &helpPos) { +llvm::lsp::SignatureHelp +PDLDocument::getSignatureHelp(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &helpPos) { SMLoc posLoc = helpPos.getAsSMLoc(sourceMgr); if (!posLoc.isValid()) - return lsp::SignatureHelp(); + return llvm::lsp::SignatureHelp(); // To perform code completion, we run another parse of the module with the // code completion context provided. ods::Context tmpODSContext; - lsp::SignatureHelp signatureHelp; + llvm::lsp::SignatureHelp signatureHelp; LSPSignatureHelpContext completeContext(posLoc, sourceMgr, signatureHelp, tmpODSContext); @@ -1173,9 +1192,9 @@ static bool shouldAddHintFor(const ast::Expr *expr, StringRef name) { return true; } -void PDLDocument::getInlayHints(const lsp::URIForFile &uri, - const lsp::Range &range, - std::vector<lsp::InlayHint> &inlayHints) { +void PDLDocument::getInlayHints(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Range &range, + std::vector<llvm::lsp::InlayHint> &inlayHints) { if (failed(astModule)) return; SMRange rangeLoc = range.getAsSMRange(sourceMgr); @@ -1198,9 +1217,9 @@ void PDLDocument::getInlayHints(const lsp::URIForFile &uri, }); } -void PDLDocument::getInlayHintsFor(const ast::VariableDecl *decl, - const lsp::URIForFile &uri, - std::vector<lsp::InlayHint> &inlayHints) { +void PDLDocument::getInlayHintsFor( + const ast::VariableDecl *decl, const llvm::lsp::URIForFile &uri, + std::vector<llvm::lsp::InlayHint> &inlayHints) { // Check to see if the variable has a constraint list, if it does we don't // provide initializer hints. if (!decl->getConstraints().empty()) @@ -1215,8 +1234,8 @@ void PDLDocument::getInlayHintsFor(const ast::VariableDecl *decl, return; } - lsp::InlayHint hint(lsp::InlayHintKind::Type, - lsp::Position(sourceMgr, decl->getLoc().End)); + llvm::lsp::InlayHint hint(llvm::lsp::InlayHintKind::Type, + llvm::lsp::Position(sourceMgr, decl->getLoc().End)); { llvm::raw_string_ostream labelOS(hint.label); labelOS << ": " << decl->getType(); @@ -1225,9 +1244,9 @@ void PDLDocument::getInlayHintsFor(const ast::VariableDecl *decl, inlayHints.emplace_back(std::move(hint)); } -void PDLDocument::getInlayHintsFor(const ast::CallExpr *expr, - const lsp::URIForFile &uri, - std::vector<lsp::InlayHint> &inlayHints) { +void PDLDocument::getInlayHintsFor( + const ast::CallExpr *expr, const llvm::lsp::URIForFile &uri, + std::vector<llvm::lsp::InlayHint> &inlayHints) { // Try to extract the callable of this call. const auto *callableRef = dyn_cast<ast::DeclRefExpr>(expr->getCallableExpr()); const auto *callable = @@ -1242,9 +1261,9 @@ void PDLDocument::getInlayHintsFor(const ast::CallExpr *expr, std::get<1>(it)->getName().getName()); } -void PDLDocument::getInlayHintsFor(const ast::OperationExpr *expr, - const lsp::URIForFile &uri, - std::vector<lsp::InlayHint> &inlayHints) { +void PDLDocument::getInlayHintsFor( + const ast::OperationExpr *expr, const llvm::lsp::URIForFile &uri, + std::vector<llvm::lsp::InlayHint> &inlayHints) { // Check for ODS information. ast::OperationType opType = dyn_cast<ast::OperationType>(expr->getType()); const auto *odsOp = opType ? opType.getODSOperation() : nullptr; @@ -1290,13 +1309,15 @@ void PDLDocument::getInlayHintsFor(const ast::OperationExpr *expr, "results"); } -void PDLDocument::addParameterHintFor(std::vector<lsp::InlayHint> &inlayHints, - const ast::Expr *expr, StringRef label) { +void PDLDocument::addParameterHintFor( + std::vector<llvm::lsp::InlayHint> &inlayHints, const ast::Expr *expr, + StringRef label) { if (!shouldAddHintFor(expr, label)) return; - lsp::InlayHint hint(lsp::InlayHintKind::Parameter, - lsp::Position(sourceMgr, expr->getLoc().Start)); + llvm::lsp::InlayHint hint( + llvm::lsp::InlayHintKind::Parameter, + llvm::lsp::Position(sourceMgr, expr->getLoc().Start)); hint.label = (label + ":").str(); hint.paddingRight = true; inlayHints.emplace_back(std::move(hint)); @@ -1342,22 +1363,24 @@ void PDLDocument::getPDLLViewOutput(raw_ostream &os, namespace { /// This class represents a single chunk of an PDL text file. struct PDLTextFileChunk { - PDLTextFileChunk(uint64_t lineOffset, const lsp::URIForFile &uri, + PDLTextFileChunk(uint64_t lineOffset, const llvm::lsp::URIForFile &uri, StringRef contents, const std::vector<std::string> &extraDirs, - std::vector<lsp::Diagnostic> &diagnostics) + std::vector<llvm::lsp::Diagnostic> &diagnostics) : lineOffset(lineOffset), document(uri, contents, extraDirs, diagnostics) {} /// Adjust the line number of the given range to anchor at the beginning of /// the file, instead of the beginning of this chunk. - void adjustLocForChunkOffset(lsp::Range &range) { + void adjustLocForChunkOffset(llvm::lsp::Range &range) { adjustLocForChunkOffset(range.start); adjustLocForChunkOffset(range.end); } /// Adjust the line number of the given position to anchor at the beginning of /// the file, instead of the beginning of this chunk. - void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; } + void adjustLocForChunkOffset(llvm::lsp::Position &pos) { + pos.line += lineOffset; + } /// The line offset of this chunk from the beginning of the file. uint64_t lineOffset; @@ -1374,38 +1397,41 @@ namespace { /// This class represents a text file containing one or more PDL documents. class PDLTextFile { public: - PDLTextFile(const lsp::URIForFile &uri, StringRef fileContents, + PDLTextFile(const llvm::lsp::URIForFile &uri, StringRef fileContents, int64_t version, const std::vector<std::string> &extraDirs, - std::vector<lsp::Diagnostic> &diagnostics); + std::vector<llvm::lsp::Diagnostic> &diagnostics); /// Return the current version of this text file. int64_t getVersion() const { return version; } /// Update the file to the new version using the provided set of content /// changes. Returns failure if the update was unsuccessful. - LogicalResult update(const lsp::URIForFile &uri, int64_t newVersion, - ArrayRef<lsp::TextDocumentContentChangeEvent> changes, - std::vector<lsp::Diagnostic> &diagnostics); + LogicalResult + update(const llvm::lsp::URIForFile &uri, int64_t newVersion, + ArrayRef<llvm::lsp::TextDocumentContentChangeEvent> changes, + std::vector<llvm::lsp::Diagnostic> &diagnostics); //===--------------------------------------------------------------------===// // LSP Queries //===--------------------------------------------------------------------===// - void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos, - std::vector<lsp::Location> &locations); - void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos, - std::vector<lsp::Location> &references); - void getDocumentLinks(const lsp::URIForFile &uri, - std::vector<lsp::DocumentLink> &links); - std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri, - lsp::Position hoverPos); - void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols); - lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri, - lsp::Position completePos); - lsp::SignatureHelp getSignatureHelp(const lsp::URIForFile &uri, - lsp::Position helpPos); - void getInlayHints(const lsp::URIForFile &uri, lsp::Range range, - std::vector<lsp::InlayHint> &inlayHints); + void getLocationsOf(const llvm::lsp::URIForFile &uri, + llvm::lsp::Position defPos, + std::vector<llvm::lsp::Location> &locations); + void findReferencesOf(const llvm::lsp::URIForFile &uri, + llvm::lsp::Position pos, + std::vector<llvm::lsp::Location> &references); + void getDocumentLinks(const llvm::lsp::URIForFile &uri, + std::vector<llvm::lsp::DocumentLink> &links); + std::optional<llvm::lsp::Hover> findHover(const llvm::lsp::URIForFile &uri, + llvm::lsp::Position hoverPos); + void findDocumentSymbols(std::vector<llvm::lsp::DocumentSymbol> &symbols); + llvm::lsp::CompletionList getCodeCompletion(const llvm::lsp::URIForFile &uri, + llvm::lsp::Position completePos); + llvm::lsp::SignatureHelp getSignatureHelp(const llvm::lsp::URIForFile &uri, + llvm::lsp::Position helpPos); + void getInlayHints(const llvm::lsp::URIForFile &uri, llvm::lsp::Range range, + std::vector<llvm::lsp::InlayHint> &inlayHints); lsp::PDLLViewOutputResult getPDLLViewOutput(lsp::PDLLViewOutputKind kind); private: @@ -1413,14 +1439,14 @@ private: std::vector<std::unique_ptr<PDLTextFileChunk>>::iterator>; /// Initialize the text file from the given file contents. - void initialize(const lsp::URIForFile &uri, int64_t newVersion, - std::vector<lsp::Diagnostic> &diagnostics); + void initialize(const llvm::lsp::URIForFile &uri, int64_t newVersion, + std::vector<llvm::lsp::Diagnostic> &diagnostics); /// Find the PDL document that contains the given position, and update the /// position to be anchored at the start of the found chunk instead of the /// beginning of the file. - ChunkIterator getChunkItFor(lsp::Position &pos); - PDLTextFileChunk &getChunkFor(lsp::Position &pos) { + ChunkIterator getChunkItFor(llvm::lsp::Position &pos); + PDLTextFileChunk &getChunkFor(llvm::lsp::Position &pos) { return *getChunkItFor(pos); } @@ -1442,20 +1468,21 @@ private: }; } // namespace -PDLTextFile::PDLTextFile(const lsp::URIForFile &uri, StringRef fileContents, - int64_t version, +PDLTextFile::PDLTextFile(const llvm::lsp::URIForFile &uri, + StringRef fileContents, int64_t version, const std::vector<std::string> &extraDirs, - std::vector<lsp::Diagnostic> &diagnostics) + std::vector<llvm::lsp::Diagnostic> &diagnostics) : contents(fileContents.str()), extraIncludeDirs(extraDirs) { initialize(uri, version, diagnostics); } LogicalResult -PDLTextFile::update(const lsp::URIForFile &uri, int64_t newVersion, - ArrayRef<lsp::TextDocumentContentChangeEvent> changes, - std::vector<lsp::Diagnostic> &diagnostics) { - if (failed(lsp::TextDocumentContentChangeEvent::applyTo(changes, contents))) { - lsp::Logger::error("Failed to update contents of {0}", uri.file()); +PDLTextFile::update(const llvm::lsp::URIForFile &uri, int64_t newVersion, + ArrayRef<llvm::lsp::TextDocumentContentChangeEvent> changes, + std::vector<llvm::lsp::Diagnostic> &diagnostics) { + if (failed(llvm::lsp::TextDocumentContentChangeEvent::applyTo(changes, + contents))) { + llvm::lsp::Logger::error("Failed to update contents of {0}", uri.file()); return failure(); } @@ -1464,36 +1491,37 @@ PDLTextFile::update(const lsp::URIForFile &uri, int64_t newVersion, return success(); } -void PDLTextFile::getLocationsOf(const lsp::URIForFile &uri, - lsp::Position defPos, - std::vector<lsp::Location> &locations) { +void PDLTextFile::getLocationsOf(const llvm::lsp::URIForFile &uri, + llvm::lsp::Position defPos, + std::vector<llvm::lsp::Location> &locations) { PDLTextFileChunk &chunk = getChunkFor(defPos); chunk.document.getLocationsOf(uri, defPos, locations); // Adjust any locations within this file for the offset of this chunk. if (chunk.lineOffset == 0) return; - for (lsp::Location &loc : locations) + for (llvm::lsp::Location &loc : locations) if (loc.uri == uri) chunk.adjustLocForChunkOffset(loc.range); } -void PDLTextFile::findReferencesOf(const lsp::URIForFile &uri, - lsp::Position pos, - std::vector<lsp::Location> &references) { +void PDLTextFile::findReferencesOf( + const llvm::lsp::URIForFile &uri, llvm::lsp::Position pos, + std::vector<llvm::lsp::Location> &references) { PDLTextFileChunk &chunk = getChunkFor(pos); chunk.document.findReferencesOf(uri, pos, references); // Adjust any locations within this file for the offset of this chunk. if (chunk.lineOffset == 0) return; - for (lsp::Location &loc : references) + for (llvm::lsp::Location &loc : references) if (loc.uri == uri) chunk.adjustLocForChunkOffset(loc.range); } -void PDLTextFile::getDocumentLinks(const lsp::URIForFile &uri, - std::vector<lsp::DocumentLink> &links) { +void PDLTextFile::getDocumentLinks( + const llvm::lsp::URIForFile &uri, + std::vector<llvm::lsp::DocumentLink> &links) { chunks.front()->document.getDocumentLinks(uri, links); for (const auto &it : llvm::drop_begin(chunks)) { size_t currentNumLinks = links.size(); @@ -1506,10 +1534,12 @@ void PDLTextFile::getDocumentLinks(const lsp::URIForFile &uri, } } -std::optional<lsp::Hover> PDLTextFile::findHover(const lsp::URIForFile &uri, - lsp::Position hoverPos) { +std::optional<llvm::lsp::Hover> +PDLTextFile::findHover(const llvm::lsp::URIForFile &uri, + llvm::lsp::Position hoverPos) { PDLTextFileChunk &chunk = getChunkFor(hoverPos); - std::optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos); + std::optional<llvm::lsp::Hover> hoverInfo = + chunk.document.findHover(uri, hoverPos); // Adjust any locations within this file for the offset of this chunk. if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range) @@ -1518,7 +1548,7 @@ std::optional<lsp::Hover> PDLTextFile::findHover(const lsp::URIForFile &uri, } void PDLTextFile::findDocumentSymbols( - std::vector<lsp::DocumentSymbol> &symbols) { + std::vector<llvm::lsp::DocumentSymbol> &symbols) { if (chunks.size() == 1) return chunks.front()->document.findDocumentSymbols(symbols); @@ -1526,27 +1556,27 @@ void PDLTextFile::findDocumentSymbols( // each chunk. for (unsigned i = 0, e = chunks.size(); i < e; ++i) { PDLTextFileChunk &chunk = *chunks[i]; - lsp::Position startPos(chunk.lineOffset); - lsp::Position endPos((i == e - 1) ? totalNumLines - 1 - : chunks[i + 1]->lineOffset); - lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">", - lsp::SymbolKind::Namespace, - /*range=*/lsp::Range(startPos, endPos), - /*selectionRange=*/lsp::Range(startPos)); + llvm::lsp::Position startPos(chunk.lineOffset); + llvm::lsp::Position endPos((i == e - 1) ? totalNumLines - 1 + : chunks[i + 1]->lineOffset); + llvm::lsp::DocumentSymbol symbol( + "<file-split-" + Twine(i) + ">", llvm::lsp::SymbolKind::Namespace, + /*range=*/llvm::lsp::Range(startPos, endPos), + /*selectionRange=*/llvm::lsp::Range(startPos)); chunk.document.findDocumentSymbols(symbol.children); // Fixup the locations of document symbols within this chunk. if (i != 0) { - SmallVector<lsp::DocumentSymbol *> symbolsToFix; - for (lsp::DocumentSymbol &childSymbol : symbol.children) + SmallVector<llvm::lsp::DocumentSymbol *> symbolsToFix; + for (llvm::lsp::DocumentSymbol &childSymbol : symbol.children) symbolsToFix.push_back(&childSymbol); while (!symbolsToFix.empty()) { - lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val(); + llvm::lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val(); chunk.adjustLocForChunkOffset(symbol->range); chunk.adjustLocForChunkOffset(symbol->selectionRange); - for (lsp::DocumentSymbol &childSymbol : symbol->children) + for (llvm::lsp::DocumentSymbol &childSymbol : symbol->children) symbolsToFix.push_back(&childSymbol); } } @@ -1556,34 +1586,37 @@ void PDLTextFile::findDocumentSymbols( } } -lsp::CompletionList PDLTextFile::getCodeCompletion(const lsp::URIForFile &uri, - lsp::Position completePos) { +llvm::lsp::CompletionList +PDLTextFile::getCodeCompletion(const llvm::lsp::URIForFile &uri, + llvm::lsp::Position completePos) { PDLTextFileChunk &chunk = getChunkFor(completePos); - lsp::CompletionList completionList = + llvm::lsp::CompletionList completionList = chunk.document.getCodeCompletion(uri, completePos); // Adjust any completion locations. - for (lsp::CompletionItem &item : completionList.items) { + for (llvm::lsp::CompletionItem &item : completionList.items) { if (item.textEdit) chunk.adjustLocForChunkOffset(item.textEdit->range); - for (lsp::TextEdit &edit : item.additionalTextEdits) + for (llvm::lsp::TextEdit &edit : item.additionalTextEdits) chunk.adjustLocForChunkOffset(edit.range); } return completionList; } -lsp::SignatureHelp PDLTextFile::getSignatureHelp(const lsp::URIForFile &uri, - lsp::Position helpPos) { +llvm::lsp::SignatureHelp +PDLTextFile::getSignatureHelp(const llvm::lsp::URIForFile &uri, + llvm::lsp::Position helpPos) { return getChunkFor(helpPos).document.getSignatureHelp(uri, helpPos); } -void PDLTextFile::getInlayHints(const lsp::URIForFile &uri, lsp::Range range, - std::vector<lsp::InlayHint> &inlayHints) { +void PDLTextFile::getInlayHints(const llvm::lsp::URIForFile &uri, + llvm::lsp::Range range, + std::vector<llvm::lsp::InlayHint> &inlayHints) { auto startIt = getChunkItFor(range.start); auto endIt = getChunkItFor(range.end); // Functor used to get the chunks for a given file, and fixup any locations - auto getHintsForChunk = [&](ChunkIterator chunkIt, lsp::Range range) { + auto getHintsForChunk = [&](ChunkIterator chunkIt, llvm::lsp::Range range) { size_t currentNumHints = inlayHints.size(); chunkIt->document.getInlayHints(uri, range, inlayHints); @@ -1605,15 +1638,16 @@ void PDLTextFile::getInlayHints(const lsp::URIForFile &uri, lsp::Range range, // Otherwise, the range is split between multiple chunks. The first chunk // has the correct range start, but covers the total document. - getHintsForChunk(startIt, lsp::Range(range.start, getNumLines(startIt))); + getHintsForChunk(startIt, + llvm::lsp::Range(range.start, getNumLines(startIt))); // Every chunk in between uses the full document. for (++startIt; startIt != endIt; ++startIt) - getHintsForChunk(startIt, lsp::Range(0, getNumLines(startIt))); + getHintsForChunk(startIt, llvm::lsp::Range(0, getNumLines(startIt))); // The range for the last chunk starts at the beginning of the document, up // through the end of the input range. - getHintsForChunk(startIt, lsp::Range(0, range.end)); + getHintsForChunk(startIt, llvm::lsp::Range(0, range.end)); } lsp::PDLLViewOutputResult @@ -1632,8 +1666,9 @@ PDLTextFile::getPDLLViewOutput(lsp::PDLLViewOutputKind kind) { return result; } -void PDLTextFile::initialize(const lsp::URIForFile &uri, int64_t newVersion, - std::vector<lsp::Diagnostic> &diagnostics) { +void PDLTextFile::initialize(const llvm::lsp::URIForFile &uri, + int64_t newVersion, + std::vector<llvm::lsp::Diagnostic> &diagnostics) { version = newVersion; chunks.clear(); @@ -1653,7 +1688,7 @@ void PDLTextFile::initialize(const lsp::URIForFile &uri, int64_t newVersion, // Adjust locations used in diagnostics to account for the offset from the // beginning of the file. - for (lsp::Diagnostic &diag : + for (llvm::lsp::Diagnostic &diag : llvm::drop_begin(diagnostics, currentNumDiags)) { chunk->adjustLocForChunkOffset(diag.range); @@ -1668,14 +1703,15 @@ void PDLTextFile::initialize(const lsp::URIForFile &uri, int64_t newVersion, totalNumLines = lineOffset; } -PDLTextFile::ChunkIterator PDLTextFile::getChunkItFor(lsp::Position &pos) { +PDLTextFile::ChunkIterator +PDLTextFile::getChunkItFor(llvm::lsp::Position &pos) { if (chunks.size() == 1) return chunks.begin(); // Search for the first chunk with a greater line offset, the previous chunk // is the one that contains `pos`. auto it = llvm::upper_bound( - chunks, pos, [](const lsp::Position &pos, const auto &chunk) { + chunks, pos, [](const llvm::lsp::Position &pos, const auto &chunk) { return static_cast<uint64_t>(pos.line) < chunk->lineOffset; }); ChunkIterator chunkIt(it == chunks.end() ? (chunks.end() - 1) : --it); @@ -1710,9 +1746,9 @@ lsp::PDLLServer::PDLLServer(const Options &options) : impl(std::make_unique<Impl>(options)) {} lsp::PDLLServer::~PDLLServer() = default; -void lsp::PDLLServer::addDocument(const URIForFile &uri, StringRef contents, - int64_t version, - std::vector<Diagnostic> &diagnostics) { +void lsp::PDLLServer::addDocument( + const URIForFile &uri, StringRef contents, int64_t version, + std::vector<llvm::lsp::Diagnostic> &diagnostics) { // Build the set of additional include directories. std::vector<std::string> additionalIncludeDirs = impl->options.extraDirs; const auto &fileInfo = impl->compilationDatabase.getFileInfo(uri.file()); @@ -1724,7 +1760,7 @@ void lsp::PDLLServer::addDocument(const URIForFile &uri, StringRef contents, void lsp::PDLLServer::updateDocument( const URIForFile &uri, ArrayRef<TextDocumentContentChangeEvent> changes, - int64_t version, std::vector<Diagnostic> &diagnostics) { + int64_t version, std::vector<llvm::lsp::Diagnostic> &diagnostics) { // Check that we actually have a document for this uri. auto it = impl->files.find(uri.file()); if (it == impl->files.end()) @@ -1746,17 +1782,17 @@ std::optional<int64_t> lsp::PDLLServer::removeDocument(const URIForFile &uri) { return version; } -void lsp::PDLLServer::getLocationsOf(const URIForFile &uri, - const Position &defPos, - std::vector<Location> &locations) { +void lsp::PDLLServer::getLocationsOf( + const URIForFile &uri, const Position &defPos, + std::vector<llvm::lsp::Location> &locations) { auto fileIt = impl->files.find(uri.file()); if (fileIt != impl->files.end()) fileIt->second->getLocationsOf(uri, defPos, locations); } -void lsp::PDLLServer::findReferencesOf(const URIForFile &uri, - const Position &pos, - std::vector<Location> &references) { +void lsp::PDLLServer::findReferencesOf( + const URIForFile &uri, const Position &pos, + std::vector<llvm::lsp::Location> &references) { auto fileIt = impl->files.find(uri.file()); if (fileIt != impl->files.end()) fileIt->second->findReferencesOf(uri, pos, references); @@ -1769,8 +1805,8 @@ void lsp::PDLLServer::getDocumentLinks( return fileIt->second->getDocumentLinks(uri, documentLinks); } -std::optional<lsp::Hover> lsp::PDLLServer::findHover(const URIForFile &uri, - const Position &hoverPos) { +std::optional<llvm::lsp::Hover> +lsp::PDLLServer::findHover(const URIForFile &uri, const Position &hoverPos) { auto fileIt = impl->files.find(uri.file()); if (fileIt != impl->files.end()) return fileIt->second->findHover(uri, hoverPos); @@ -1793,8 +1829,9 @@ lsp::PDLLServer::getCodeCompletion(const URIForFile &uri, return CompletionList(); } -lsp::SignatureHelp lsp::PDLLServer::getSignatureHelp(const URIForFile &uri, - const Position &helpPos) { +llvm::lsp::SignatureHelp +lsp::PDLLServer::getSignatureHelp(const URIForFile &uri, + const Position &helpPos) { auto fileIt = impl->files.find(uri.file()); if (fileIt != impl->files.end()) return fileIt->second->getSignatureHelp(uri, helpPos); diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.h b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.h index 134431f..d82014d 100644 --- a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.h +++ b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.h @@ -11,6 +11,7 @@ #include "mlir/Support/LLVM.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/LSP/Protocol.h" #include <memory> #include <optional> #include <string> @@ -18,21 +19,22 @@ namespace mlir { namespace lsp { -struct Diagnostic; +using llvm::lsp::CompletionList; +using llvm::lsp::Diagnostic; +using llvm::lsp::DocumentLink; +using llvm::lsp::DocumentSymbol; +using llvm::lsp::Hover; +using llvm::lsp::InlayHint; +using llvm::lsp::Location; +using llvm::lsp::Position; +using llvm::lsp::Range; +using llvm::lsp::SignatureHelp; +using llvm::lsp::TextDocumentContentChangeEvent; +using llvm::lsp::URIForFile; + class CompilationDatabase; struct PDLLViewOutputResult; enum class PDLLViewOutputKind; -struct CompletionList; -struct DocumentLink; -struct DocumentSymbol; -struct Hover; -struct InlayHint; -struct Location; -struct Position; -struct Range; -struct SignatureHelp; -struct TextDocumentContentChangeEvent; -class URIForFile; /// This class implements all of the PDLL related functionality necessary for a /// language server. This class allows for keeping the PDLL specific logic diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.cpp index 0c9896e..ace4605 100644 --- a/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.cpp +++ b/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "Protocol.h" +#include "mlir/Support/LLVM.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/JSON.h" diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.h b/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.h index 0706316..a2775f8 100644 --- a/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.h +++ b/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.h @@ -20,10 +20,12 @@ #ifndef LIB_MLIR_TOOLS_MLIRPDLLLSPSERVER_PROTOCOL_H_ #define LIB_MLIR_TOOLS_MLIRPDLLLSPSERVER_PROTOCOL_H_ -#include "mlir/Tools/lsp-server-support/Protocol.h" +#include "llvm/Support/LSP/Protocol.h" namespace mlir { namespace lsp { +using llvm::lsp::URIForFile; + //===----------------------------------------------------------------------===// // PDLLViewOutputParams //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Tools/tblgen-lsp-server/CMakeLists.txt b/mlir/lib/Tools/tblgen-lsp-server/CMakeLists.txt index 80fc1ff..b21650e 100644 --- a/mlir/lib/Tools/tblgen-lsp-server/CMakeLists.txt +++ b/mlir/lib/Tools/tblgen-lsp-server/CMakeLists.txt @@ -2,6 +2,7 @@ set(LLVM_LINK_COMPONENTS Demangle Support TableGen + SupportLSP ) llvm_add_library(TableGenLspServerLib diff --git a/mlir/lib/Tools/tblgen-lsp-server/LSPServer.cpp b/mlir/lib/Tools/tblgen-lsp-server/LSPServer.cpp index bb3c0a7..95a457f 100644 --- a/mlir/lib/Tools/tblgen-lsp-server/LSPServer.cpp +++ b/mlir/lib/Tools/tblgen-lsp-server/LSPServer.cpp @@ -9,14 +9,33 @@ #include "LSPServer.h" #include "TableGenServer.h" -#include "mlir/Tools/lsp-server-support/Logging.h" -#include "mlir/Tools/lsp-server-support/Protocol.h" -#include "mlir/Tools/lsp-server-support/Transport.h" +#include "llvm/Support/LSP/Logging.h" +#include "llvm/Support/LSP/Protocol.h" +#include "llvm/Support/LSP/Transport.h" #include <optional> using namespace mlir; using namespace mlir::lsp; +using llvm::lsp::Callback; +using llvm::lsp::DidChangeTextDocumentParams; +using llvm::lsp::DidCloseTextDocumentParams; +using llvm::lsp::DidOpenTextDocumentParams; +using llvm::lsp::DocumentLinkParams; +using llvm::lsp::Hover; +using llvm::lsp::InitializedParams; +using llvm::lsp::InitializeParams; +using llvm::lsp::JSONTransport; +using llvm::lsp::Location; +using llvm::lsp::Logger; +using llvm::lsp::MessageHandler; +using llvm::lsp::NoParams; +using llvm::lsp::OutgoingNotification; +using llvm::lsp::PublishDiagnosticsParams; +using llvm::lsp::ReferenceParams; +using llvm::lsp::TextDocumentPositionParams; +using llvm::lsp::TextDocumentSyncKind; + //===----------------------------------------------------------------------===// // LSPServer //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Tools/tblgen-lsp-server/LSPServer.h b/mlir/lib/Tools/tblgen-lsp-server/LSPServer.h index 501a9da..596688b 100644 --- a/mlir/lib/Tools/tblgen-lsp-server/LSPServer.h +++ b/mlir/lib/Tools/tblgen-lsp-server/LSPServer.h @@ -13,17 +13,19 @@ namespace llvm { struct LogicalResult; +namespace lsp { +class JSONTransport; +} // namespace lsp } // namespace llvm namespace mlir { namespace lsp { -class JSONTransport; class TableGenServer; /// Run the main loop of the LSP server using the given TableGen server and /// transport. llvm::LogicalResult runTableGenLSPServer(TableGenServer &server, - JSONTransport &transport); + llvm::lsp::JSONTransport &transport); } // namespace lsp } // namespace mlir diff --git a/mlir/lib/Tools/tblgen-lsp-server/TableGenLspServerMain.cpp b/mlir/lib/Tools/tblgen-lsp-server/TableGenLspServerMain.cpp index 21af78c..8014b8d 100644 --- a/mlir/lib/Tools/tblgen-lsp-server/TableGenLspServerMain.cpp +++ b/mlir/lib/Tools/tblgen-lsp-server/TableGenLspServerMain.cpp @@ -9,14 +9,18 @@ #include "mlir/Tools/tblgen-lsp-server/TableGenLspServerMain.h" #include "LSPServer.h" #include "TableGenServer.h" -#include "mlir/Tools/lsp-server-support/Logging.h" -#include "mlir/Tools/lsp-server-support/Transport.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/LSP/Logging.h" +#include "llvm/Support/LSP/Transport.h" #include "llvm/Support/Program.h" using namespace mlir; using namespace mlir::lsp; +using llvm::lsp::JSONStreamStyle; +using llvm::lsp::JSONTransport; +using llvm::lsp::Logger; + LogicalResult mlir::TableGenLspServerMain(int argc, char **argv) { llvm::cl::opt<JSONStreamStyle> inputStyle{ "input-style", diff --git a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp index 5faeeae..3080b78 100644 --- a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp +++ b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp @@ -10,12 +10,12 @@ #include "mlir/Support/IndentedOstream.h" #include "mlir/Tools/lsp-server-support/CompilationDatabase.h" -#include "mlir/Tools/lsp-server-support/Logging.h" -#include "mlir/Tools/lsp-server-support/Protocol.h" #include "mlir/Tools/lsp-server-support/SourceMgrUtils.h" #include "llvm/ADT/IntervalMap.h" #include "llvm/ADT/PointerUnion.h" #include "llvm/ADT/StringMap.h" +#include "llvm/Support/LSP/Logging.h" +#include "llvm/Support/LSP/Protocol.h" #include "llvm/Support/Path.h" #include "llvm/TableGen/Parser.h" #include "llvm/TableGen/Record.h" @@ -36,45 +36,49 @@ static SMRange convertTokenLocToRange(SMLoc loc) { /// Returns a language server uri for the given source location. `mainFileURI` /// corresponds to the uri for the main file of the source manager. -static lsp::URIForFile getURIFromLoc(const SourceMgr &mgr, SMLoc loc, - const lsp::URIForFile &mainFileURI) { +static llvm::lsp::URIForFile +getURIFromLoc(const SourceMgr &mgr, SMLoc loc, + const llvm::lsp::URIForFile &mainFileURI) { int bufferId = mgr.FindBufferContainingLoc(loc); if (bufferId == 0 || bufferId == static_cast<int>(mgr.getMainFileID())) return mainFileURI; - llvm::Expected<lsp::URIForFile> fileForLoc = lsp::URIForFile::fromFile( - mgr.getBufferInfo(bufferId).Buffer->getBufferIdentifier()); + llvm::Expected<llvm::lsp::URIForFile> fileForLoc = + llvm::lsp::URIForFile::fromFile( + mgr.getBufferInfo(bufferId).Buffer->getBufferIdentifier()); if (fileForLoc) return *fileForLoc; - lsp::Logger::error("Failed to create URI for include file: {0}", - llvm::toString(fileForLoc.takeError())); + llvm::lsp::Logger::error("Failed to create URI for include file: {0}", + llvm::toString(fileForLoc.takeError())); return mainFileURI; } /// Returns a language server location from the given source range. -static lsp::Location getLocationFromLoc(SourceMgr &mgr, SMRange loc, - const lsp::URIForFile &uri) { - return lsp::Location(getURIFromLoc(mgr, loc.Start, uri), - lsp::Range(mgr, loc)); +static llvm::lsp::Location +getLocationFromLoc(SourceMgr &mgr, SMRange loc, + const llvm::lsp::URIForFile &uri) { + return llvm::lsp::Location(getURIFromLoc(mgr, loc.Start, uri), + llvm::lsp::Range(mgr, loc)); } -static lsp::Location getLocationFromLoc(SourceMgr &mgr, SMLoc loc, - const lsp::URIForFile &uri) { +static llvm::lsp::Location +getLocationFromLoc(SourceMgr &mgr, SMLoc loc, + const llvm::lsp::URIForFile &uri) { return getLocationFromLoc(mgr, convertTokenLocToRange(loc), uri); } /// Convert the given TableGen diagnostic to the LSP form. -static std::optional<lsp::Diagnostic> +static std::optional<llvm::lsp::Diagnostic> getLspDiagnoticFromDiag(const llvm::SMDiagnostic &diag, - const lsp::URIForFile &uri) { + const llvm::lsp::URIForFile &uri) { auto *sourceMgr = const_cast<SourceMgr *>(diag.getSourceMgr()); if (!sourceMgr || !diag.getLoc().isValid()) return std::nullopt; - lsp::Diagnostic lspDiag; + llvm::lsp::Diagnostic lspDiag; lspDiag.source = "tablegen"; lspDiag.category = "Parse Error"; // Try to grab a file location for this diagnostic. - lsp::Location loc = getLocationFromLoc(*sourceMgr, diag.getLoc(), uri); + llvm::lsp::Location loc = getLocationFromLoc(*sourceMgr, diag.getLoc(), uri); lspDiag.range = loc.range; // Skip diagnostics that weren't emitted within the main file. @@ -84,17 +88,17 @@ getLspDiagnoticFromDiag(const llvm::SMDiagnostic &diag, // Convert the severity for the diagnostic. switch (diag.getKind()) { case SourceMgr::DK_Warning: - lspDiag.severity = lsp::DiagnosticSeverity::Warning; + lspDiag.severity = llvm::lsp::DiagnosticSeverity::Warning; break; case SourceMgr::DK_Error: - lspDiag.severity = lsp::DiagnosticSeverity::Error; + lspDiag.severity = llvm::lsp::DiagnosticSeverity::Error; break; case SourceMgr::DK_Note: // Notes are emitted separately from the main diagnostic, so we just treat // them as remarks given that we can't determine the diagnostic to relate // them to. case SourceMgr::DK_Remark: - lspDiag.severity = lsp::DiagnosticSeverity::Information; + lspDiag.severity = llvm::lsp::DiagnosticSeverity::Information; break; } lspDiag.message = diag.getMessage().str(); @@ -322,54 +326,59 @@ namespace { /// This class represents a text file containing one or more TableGen documents. class TableGenTextFile { public: - TableGenTextFile(const lsp::URIForFile &uri, StringRef fileContents, + TableGenTextFile(const llvm::lsp::URIForFile &uri, StringRef fileContents, int64_t version, const std::vector<std::string> &extraIncludeDirs, - std::vector<lsp::Diagnostic> &diagnostics); + std::vector<llvm::lsp::Diagnostic> &diagnostics); /// Return the current version of this text file. int64_t getVersion() const { return version; } /// Update the file to the new version using the provided set of content /// changes. Returns failure if the update was unsuccessful. - LogicalResult update(const lsp::URIForFile &uri, int64_t newVersion, - ArrayRef<lsp::TextDocumentContentChangeEvent> changes, - std::vector<lsp::Diagnostic> &diagnostics); + LogicalResult + update(const llvm::lsp::URIForFile &uri, int64_t newVersion, + ArrayRef<llvm::lsp::TextDocumentContentChangeEvent> changes, + std::vector<llvm::lsp::Diagnostic> &diagnostics); //===--------------------------------------------------------------------===// // Definitions and References //===--------------------------------------------------------------------===// - void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos, - std::vector<lsp::Location> &locations); - void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos, - std::vector<lsp::Location> &references); + void getLocationsOf(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &defPos, + std::vector<llvm::lsp::Location> &locations); + void findReferencesOf(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &pos, + std::vector<llvm::lsp::Location> &references); //===--------------------------------------------------------------------===// // Document Links //===--------------------------------------------------------------------===// - void getDocumentLinks(const lsp::URIForFile &uri, - std::vector<lsp::DocumentLink> &links); + void getDocumentLinks(const llvm::lsp::URIForFile &uri, + std::vector<llvm::lsp::DocumentLink> &links); //===--------------------------------------------------------------------===// // Hover //===--------------------------------------------------------------------===// - std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri, - const lsp::Position &hoverPos); - lsp::Hover buildHoverForRecord(const Record *record, - const SMRange &hoverRange); - lsp::Hover buildHoverForTemplateArg(const Record *record, + std::optional<llvm::lsp::Hover> + findHover(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &hoverPos); + llvm::lsp::Hover buildHoverForRecord(const Record *record, + const SMRange &hoverRange); + llvm::lsp::Hover buildHoverForTemplateArg(const Record *record, + const RecordVal *value, + const SMRange &hoverRange); + llvm::lsp::Hover buildHoverForField(const Record *record, const RecordVal *value, const SMRange &hoverRange); - lsp::Hover buildHoverForField(const Record *record, const RecordVal *value, - const SMRange &hoverRange); private: /// Initialize the text file from the given file contents. - void initialize(const lsp::URIForFile &uri, int64_t newVersion, - std::vector<lsp::Diagnostic> &diagnostics); + void initialize(const llvm::lsp::URIForFile &uri, int64_t newVersion, + std::vector<llvm::lsp::Diagnostic> &diagnostics); /// The full string contents of the file. std::string contents; @@ -395,9 +404,9 @@ private: } // namespace TableGenTextFile::TableGenTextFile( - const lsp::URIForFile &uri, StringRef fileContents, int64_t version, + const llvm::lsp::URIForFile &uri, StringRef fileContents, int64_t version, const std::vector<std::string> &extraIncludeDirs, - std::vector<lsp::Diagnostic> &diagnostics) + std::vector<llvm::lsp::Diagnostic> &diagnostics) : contents(fileContents.str()), version(version) { // Build the set of include directories for this file. llvm::SmallString<32> uriDirectory(uri.file()); @@ -409,12 +418,13 @@ TableGenTextFile::TableGenTextFile( initialize(uri, version, diagnostics); } -LogicalResult -TableGenTextFile::update(const lsp::URIForFile &uri, int64_t newVersion, - ArrayRef<lsp::TextDocumentContentChangeEvent> changes, - std::vector<lsp::Diagnostic> &diagnostics) { - if (failed(lsp::TextDocumentContentChangeEvent::applyTo(changes, contents))) { - lsp::Logger::error("Failed to update contents of {0}", uri.file()); +LogicalResult TableGenTextFile::update( + const llvm::lsp::URIForFile &uri, int64_t newVersion, + ArrayRef<llvm::lsp::TextDocumentContentChangeEvent> changes, + std::vector<llvm::lsp::Diagnostic> &diagnostics) { + if (failed(llvm::lsp::TextDocumentContentChangeEvent::applyTo(changes, + contents))) { + llvm::lsp::Logger::error("Failed to update contents of {0}", uri.file()); return failure(); } @@ -423,9 +433,9 @@ TableGenTextFile::update(const lsp::URIForFile &uri, int64_t newVersion, return success(); } -void TableGenTextFile::initialize(const lsp::URIForFile &uri, - int64_t newVersion, - std::vector<lsp::Diagnostic> &diagnostics) { +void TableGenTextFile::initialize( + const llvm::lsp::URIForFile &uri, int64_t newVersion, + std::vector<llvm::lsp::Diagnostic> &diagnostics) { version = newVersion; sourceMgr = SourceMgr(); recordKeeper = std::make_unique<RecordKeeper>(); @@ -433,7 +443,8 @@ void TableGenTextFile::initialize(const lsp::URIForFile &uri, // Build a buffer for this file. auto memBuffer = llvm::MemoryBuffer::getMemBuffer(contents, uri.file()); if (!memBuffer) { - lsp::Logger::error("Failed to create memory buffer for file", uri.file()); + llvm::lsp::Logger::error("Failed to create memory buffer for file", + uri.file()); return; } sourceMgr.setIncludeDirs(includeDirs); @@ -442,8 +453,8 @@ void TableGenTextFile::initialize(const lsp::URIForFile &uri, // This class provides a context argument for the SourceMgr diagnostic // handler. struct DiagHandlerContext { - std::vector<lsp::Diagnostic> &diagnostics; - const lsp::URIForFile &uri; + std::vector<llvm::lsp::Diagnostic> &diagnostics; + const llvm::lsp::URIForFile &uri; } handlerContext{diagnostics, uri}; // Set the diagnostic handler for the tablegen source manager. @@ -469,9 +480,9 @@ void TableGenTextFile::initialize(const lsp::URIForFile &uri, // TableGenTextFile: Definitions and References //===----------------------------------------------------------------------===// -void TableGenTextFile::getLocationsOf(const lsp::URIForFile &uri, - const lsp::Position &defPos, - std::vector<lsp::Location> &locations) { +void TableGenTextFile::getLocationsOf( + const llvm::lsp::URIForFile &uri, const llvm::lsp::Position &defPos, + std::vector<llvm::lsp::Location> &locations) { SMLoc posLoc = defPos.getAsSMLoc(sourceMgr); const TableGenIndexSymbol *symbol = index.lookup(posLoc); if (!symbol) @@ -492,8 +503,8 @@ void TableGenTextFile::getLocationsOf(const lsp::URIForFile &uri, } void TableGenTextFile::findReferencesOf( - const lsp::URIForFile &uri, const lsp::Position &pos, - std::vector<lsp::Location> &references) { + const llvm::lsp::URIForFile &uri, const llvm::lsp::Position &pos, + std::vector<llvm::lsp::Location> &references) { SMLoc posLoc = pos.getAsSMLoc(sourceMgr); const TableGenIndexSymbol *symbol = index.lookup(posLoc); if (!symbol) @@ -508,8 +519,9 @@ void TableGenTextFile::findReferencesOf( // TableGenTextFile: Document Links //===--------------------------------------------------------------------===// -void TableGenTextFile::getDocumentLinks(const lsp::URIForFile &uri, - std::vector<lsp::DocumentLink> &links) { +void TableGenTextFile::getDocumentLinks( + const llvm::lsp::URIForFile &uri, + std::vector<llvm::lsp::DocumentLink> &links) { for (const lsp::SourceMgrInclude &include : parsedIncludes) links.emplace_back(include.range, include.uri); } @@ -518,9 +530,9 @@ void TableGenTextFile::getDocumentLinks(const lsp::URIForFile &uri, // TableGenTextFile: Hover //===----------------------------------------------------------------------===// -std::optional<lsp::Hover> -TableGenTextFile::findHover(const lsp::URIForFile &uri, - const lsp::Position &hoverPos) { +std::optional<llvm::lsp::Hover> +TableGenTextFile::findHover(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &hoverPos) { // Check for a reference to an include. for (const lsp::SourceMgrInclude &include : parsedIncludes) if (include.range.contains(hoverPos)) @@ -546,9 +558,10 @@ TableGenTextFile::findHover(const lsp::URIForFile &uri, return buildHoverForField(recordVal->record, value, hoverRange); } -lsp::Hover TableGenTextFile::buildHoverForRecord(const Record *record, - const SMRange &hoverRange) { - lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); +llvm::lsp::Hover +TableGenTextFile::buildHoverForRecord(const Record *record, + const SMRange &hoverRange) { + llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange)); { llvm::raw_string_ostream hoverOS(hover.contents.value); @@ -590,9 +603,9 @@ lsp::Hover TableGenTextFile::buildHoverForRecord(const Record *record, return hover; } -lsp::Hover TableGenTextFile::buildHoverForTemplateArg( +llvm::lsp::Hover TableGenTextFile::buildHoverForTemplateArg( const Record *record, const RecordVal *value, const SMRange &hoverRange) { - lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); + llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange)); { llvm::raw_string_ostream hoverOS(hover.contents.value); StringRef name = value->getName().rsplit(':').second; @@ -604,10 +617,9 @@ lsp::Hover TableGenTextFile::buildHoverForTemplateArg( return hover; } -lsp::Hover TableGenTextFile::buildHoverForField(const Record *record, - const RecordVal *value, - const SMRange &hoverRange) { - lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); +llvm::lsp::Hover TableGenTextFile::buildHoverForField( + const Record *record, const RecordVal *value, const SMRange &hoverRange) { + llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange)); { llvm::raw_string_ostream hoverOS(hover.contents.value); hoverOS << "**field** `" << value->getName() << "`\n***\nType: `"; @@ -722,7 +734,7 @@ void lsp::TableGenServer::getDocumentLinks( return fileIt->second->getDocumentLinks(uri, documentLinks); } -std::optional<lsp::Hover> +std::optional<llvm::lsp::Hover> lsp::TableGenServer::findHover(const URIForFile &uri, const Position &hoverPos) { auto fileIt = impl->files.find(uri.file()); diff --git a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.h b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.h index bdc8510..e54b8bc 100644 --- a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.h +++ b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.h @@ -11,6 +11,7 @@ #include "mlir/Support/LLVM.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/LSP/Protocol.h" #include <memory> #include <optional> #include <string> @@ -18,13 +19,13 @@ namespace mlir { namespace lsp { -struct Diagnostic; -struct DocumentLink; -struct Hover; -struct Location; -struct Position; -struct TextDocumentContentChangeEvent; -class URIForFile; +using llvm::lsp::Diagnostic; +using llvm::lsp::DocumentLink; +using llvm::lsp::Hover; +using llvm::lsp::Location; +using llvm::lsp::Position; +using llvm::lsp::TextDocumentContentChangeEvent; +using llvm::lsp::URIForFile; /// This class implements all of the TableGen related functionality necessary /// for a language server. This class allows for keeping the TableGen specific diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 36ee87b..df9700f 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -3406,10 +3406,19 @@ void TypeConverter::SignatureConversion::remapInput( SmallVector<Value, 1>(replacements.begin(), replacements.end())}; } -LogicalResult TypeConverter::convertType(Type t, - SmallVectorImpl<Type> &results) const { - assert(t && "expected non-null type"); - +/// Internal implementation of the type conversion. +/// This is used with either a Type or a Value as the first argument. +/// - we can cache the context-free conversions until the last registered +/// context-aware conversion. +/// - we can't cache the result of type conversion happening after context-aware +/// conversions, because the type converter may return different results for the +/// same input type. +LogicalResult +TypeConverter::convertTypeImpl(PointerUnion<Type, Value> typeOrValue, + SmallVectorImpl<Type> &results) const { + assert(typeOrValue && "expected non-null type"); + Type t = (isa<Value>(typeOrValue)) ? cast<Value>(typeOrValue).getType() + : cast<Type>(typeOrValue); { std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex, std::defer_lock); @@ -3431,52 +3440,53 @@ LogicalResult TypeConverter::convertType(Type t, // registered first. size_t currentCount = results.size(); + // We can cache the context-free conversions until the last registered + // context-aware conversion. But only if we're processing a Value right now. + auto isCacheable = [&](int index) { + int numberOfConversionsUntilContextAware = + conversions.size() - 1 - contextAwareTypeConversionsIndex; + return index < numberOfConversionsUntilContextAware; + }; + std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex, std::defer_lock); - for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) { - if (std::optional<LogicalResult> result = converter(t, results)) { - if (t.getContext()->isMultithreadingEnabled()) - cacheWriteLock.lock(); - if (!succeeded(*result)) { - assert(results.size() == currentCount && - "failed type conversion should not change results"); - cachedDirectConversions.try_emplace(t, nullptr); - return failure(); - } - auto newTypes = ArrayRef<Type>(results).drop_front(currentCount); - if (newTypes.size() == 1) - cachedDirectConversions.try_emplace(t, newTypes.front()); - else - cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes)); + for (auto indexedConverter : llvm::enumerate(llvm::reverse(conversions))) { + const ConversionCallbackFn &converter = indexedConverter.value(); + std::optional<LogicalResult> result = converter(typeOrValue, results); + if (!result) { + assert(results.size() == currentCount && + "failed type conversion should not change results"); + continue; + } + if (!isCacheable(indexedConverter.index())) return success(); - } else { + if (t.getContext()->isMultithreadingEnabled()) + cacheWriteLock.lock(); + if (!succeeded(*result)) { assert(results.size() == currentCount && "failed type conversion should not change results"); + cachedDirectConversions.try_emplace(t, nullptr); + return failure(); } + auto newTypes = ArrayRef<Type>(results).drop_front(currentCount); + if (newTypes.size() == 1) + cachedDirectConversions.try_emplace(t, newTypes.front()); + else + cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes)); + return success(); } return failure(); } -LogicalResult TypeConverter::convertType(Value v, +LogicalResult TypeConverter::convertType(Type t, SmallVectorImpl<Type> &results) const { - assert(v && "expected non-null value"); - - // If this type converter does not have context-aware type conversions, call - // the type-based overload, which has caching. - if (!hasContextAwareTypeConversions) - return convertType(v.getType(), results); + return convertTypeImpl(t, results); +} - // Walk the added converters in reverse order to apply the most recently - // registered first. - for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) { - if (std::optional<LogicalResult> result = converter(v, results)) { - if (!succeeded(*result)) - return failure(); - return success(); - } - } - return failure(); +LogicalResult TypeConverter::convertType(Value v, + SmallVectorImpl<Type> &results) const { + return convertTypeImpl(v, results); } Type TypeConverter::convertType(Type t) const { diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 8e79494..c983914 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -1,9 +1,5 @@ include(AddMLIRPython) -# Specifies that all MLIR packages are co-located under the `mlir_standalone` -# top level package (the API has been embedded in a relocatable way). -add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=${MLIR_PYTHON_PACKAGE_PREFIX}.") - ################################################################################ # Structural groupings. ################################################################################ @@ -27,6 +23,11 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python passmanager.py rewrite.py dialects/_ods_common.py + + # The main _mlir module has submodules: include stubs from each. + _mlir_libs/_mlir/__init__.pyi + _mlir_libs/_mlir/ir.pyi + _mlir_libs/_mlir/passmanager.pyi ) declare_mlir_python_sources(MLIRPythonSources.Core.Python.Extras @@ -42,6 +43,7 @@ declare_mlir_python_sources(MLIRPythonSources.ExecutionEngine ADD_TO_PARENT MLIRPythonSources SOURCES execution_engine.py + _mlir_libs/_mlirExecutionEngine.pyi SOURCES_GLOB runtime/*.py ) @@ -193,6 +195,7 @@ declare_mlir_dialect_python_bindings( TD_FILE dialects/TransformOps.td SOURCES dialects/transform/__init__.py + _mlir_libs/_mlir/dialects/transform/__init__.pyi DIALECT_NAME transform GEN_ENUM_BINDINGS_TD_FILE "../../include/mlir/Dialect/Transform/IR/TransformAttrs.td" @@ -364,7 +367,8 @@ declare_mlir_python_sources( ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" GEN_ENUM_BINDINGS SOURCES - dialects/quant.py) + dialects/quant.py + _mlir_libs/_mlir/dialects/quant.pyi) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -380,6 +384,7 @@ declare_mlir_dialect_python_bindings( TD_FILE dialects/PDLOps.td SOURCES dialects/pdl.py + _mlir_libs/_mlir/dialects/pdl.pyi DIALECT_NAME pdl) declare_mlir_dialect_python_bindings( @@ -505,11 +510,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Core # Dialects MLIRCAPIFunc - GENERATE_TYPE_STUBS - "_mlir/__init__.pyi" - "_mlir/ir.pyi" - "_mlir/passmanager.pyi" - "_mlir/rewrite.pyi" ) # This extension exposes an API to register all dialects, extensions, and passes @@ -531,8 +531,6 @@ declare_mlir_python_extension(MLIRPythonExtension.RegisterEverything MLIRCAPIConversion MLIRCAPITransforms MLIRCAPIRegisterEverything - GENERATE_TYPE_STUBS - "_mlirRegisterEverything.pyi" ) declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind @@ -547,8 +545,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind EMBED_CAPI_LINK_LIBS MLIRCAPIIR MLIRCAPILinalg - GENERATE_TYPE_STUBS - "_mlirDialectsLinalg.pyi" ) declare_mlir_python_extension(MLIRPythonExtension.Dialects.GPU.Pybind @@ -563,8 +559,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.GPU.Pybind EMBED_CAPI_LINK_LIBS MLIRCAPIIR MLIRCAPIGPU - GENERATE_TYPE_STUBS - "_mlirDialectsGPU.pyi" ) declare_mlir_python_extension(MLIRPythonExtension.Dialects.LLVM.Pybind @@ -579,8 +573,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.LLVM.Pybind EMBED_CAPI_LINK_LIBS MLIRCAPIIR MLIRCAPILLVM - GENERATE_TYPE_STUBS - "_mlirDialectsLLVM.pyi" ) declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Pybind @@ -595,8 +587,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Pybind EMBED_CAPI_LINK_LIBS MLIRCAPIIR MLIRCAPIQuant - GENERATE_TYPE_STUBS - "_mlirDialectsQuant.pyi" ) declare_mlir_python_extension(MLIRPythonExtension.Dialects.NVGPU.Pybind @@ -611,8 +601,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.NVGPU.Pybind EMBED_CAPI_LINK_LIBS MLIRCAPIIR MLIRCAPINVGPU - GENERATE_TYPE_STUBS - "_mlirDialectsNVGPU.pyi" ) declare_mlir_python_extension(MLIRPythonExtension.Dialects.PDL.Pybind @@ -627,8 +615,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.PDL.Pybind EMBED_CAPI_LINK_LIBS MLIRCAPIIR MLIRCAPIPDL - GENERATE_TYPE_STUBS - "_mlirDialectsPDL.pyi" ) declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Pybind @@ -643,8 +629,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Pybind EMBED_CAPI_LINK_LIBS MLIRCAPIIR MLIRCAPISparseTensor - GENERATE_TYPE_STUBS - "_mlirDialectsSparseTensor.pyi" ) declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind @@ -659,8 +643,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind EMBED_CAPI_LINK_LIBS MLIRCAPIIR MLIRCAPITransformDialect - GENERATE_TYPE_STUBS - "_mlirDialectsTransform.pyi" ) declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses @@ -674,8 +656,6 @@ declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses LLVMSupport EMBED_CAPI_LINK_LIBS MLIRCAPIAsync - GENERATE_TYPE_STUBS - "_mlirAsyncPasses.pyi" ) if(MLIR_ENABLE_EXECUTION_ENGINE) @@ -690,8 +670,6 @@ if(MLIR_ENABLE_EXECUTION_ENGINE) LLVMSupport EMBED_CAPI_LINK_LIBS MLIRCAPIExecutionEngine - GENERATE_TYPE_STUBS - "_mlirExecutionEngine.pyi" ) endif() @@ -706,8 +684,6 @@ declare_mlir_python_extension(MLIRPythonExtension.GPUDialectPasses LLVMSupport EMBED_CAPI_LINK_LIBS MLIRCAPIGPU - GENERATE_TYPE_STUBS - "_mlirGPUPasses.pyi" ) declare_mlir_python_extension(MLIRPythonExtension.LinalgPasses @@ -721,8 +697,6 @@ declare_mlir_python_extension(MLIRPythonExtension.LinalgPasses LLVMSupport EMBED_CAPI_LINK_LIBS MLIRCAPILinalg - GENERATE_TYPE_STUBS - "_mlirLinalgPasses.pyi" ) declare_mlir_python_extension(MLIRPythonExtension.Dialects.SMT.Pybind @@ -740,8 +714,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.SMT.Pybind MLIRCAPIIR MLIRCAPISMT MLIRCAPIExportSMTLIB - GENERATE_TYPE_STUBS - "_mlirDialectsSMT.pyi" ) declare_mlir_python_extension(MLIRPythonExtension.SparseTensorDialectPasses @@ -755,8 +727,6 @@ declare_mlir_python_extension(MLIRPythonExtension.SparseTensorDialectPasses LLVMSupport EMBED_CAPI_LINK_LIBS MLIRCAPISparseTensor - GENERATE_TYPE_STUBS - "_mlirSparseTensorPasses.pyi" ) declare_mlir_python_extension(MLIRPythonExtension.TransformInterpreter @@ -770,8 +740,6 @@ declare_mlir_python_extension(MLIRPythonExtension.TransformInterpreter LLVMSupport EMBED_CAPI_LINK_LIBS MLIRCAPITransformDialectTransforms - GENERATE_TYPE_STUBS - "_mlirTransformInterpreter.pyi" ) # TODO: Figure out how to put this in the test tree. @@ -830,8 +798,6 @@ if(MLIR_INCLUDE_TESTS) LLVMSupport EMBED_CAPI_LINK_LIBS MLIRCAPIPythonTestDialect - GENERATE_TYPE_STUBS - "_mlirPythonTestNanobind.pyi" ) endif() @@ -851,7 +817,7 @@ endif() add_mlir_python_common_capi_library(MLIRPythonCAPI INSTALL_COMPONENT MLIRPythonModules INSTALL_DESTINATION "${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}/_mlir_libs" - OUTPUT_DIRECTORY "${MLIR_BINARY_DIR}/${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}/_mlir_libs" + OUTPUT_DIRECTORY "${MLIR_BINARY_DIR}/python_packages/mlir_core/mlir/_mlir_libs" RELATIVE_INSTALL_ROOT "../../../.." DECLARED_HEADERS MLIRPythonCAPI.HeaderSources @@ -880,7 +846,7 @@ endif() ################################################################################ add_mlir_python_modules(MLIRPythonModules - ROOT_PREFIX "${MLIR_BINARY_DIR}/${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}" + ROOT_PREFIX "${MLIR_BINARY_DIR}/python_packages/mlir_core/mlir" INSTALL_PREFIX "${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}" DECLARED_SOURCES MLIRPythonSources diff --git a/mlir/python/mlir/_mlir_libs/.gitignore b/mlir/python/mlir/_mlir_libs/.gitignore deleted file mode 100644 index 8f0c82a..0000000 --- a/mlir/python/mlir/_mlir_libs/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -_mlir/**/*.pyi -*.pyi diff --git a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi new file mode 100644 index 0000000..03449b7 --- /dev/null +++ b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi @@ -0,0 +1,12 @@ + +globals: "_Globals" + +class _Globals: + dialect_search_modules: list[str] + def _register_dialect_impl(self, dialect_namespace: str, dialect_class: type) -> None: ... + def _register_operation_impl(self, operation_name: str, operation_class: type) -> None: ... + def append_dialect_search_prefix(self, module_name: str) -> None: ... + def _check_dialect_module_loaded(self, dialect_namespace: str) -> bool: ... + +def register_dialect(dialect_class: type) -> type: ... +def register_operation(dialect_class: type, *, replace: bool = ...) -> type: ... diff --git a/mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi b/mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi new file mode 100644 index 0000000..d12c683 --- /dev/null +++ b/mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi @@ -0,0 +1,63 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + + +from mlir.ir import Type, Context + +__all__ = [ + 'PDLType', + 'AttributeType', + 'OperationType', + 'RangeType', + 'TypeType', + 'ValueType', +] + + +class PDLType(Type): + @staticmethod + def isinstance(type: Type) -> bool: ... + + +class AttributeType(Type): + @staticmethod + def isinstance(type: Type) -> bool: ... + + @staticmethod + def get(context: Context | None = None) -> AttributeType: ... + + +class OperationType(Type): + @staticmethod + def isinstance(type: Type) -> bool: ... + + @staticmethod + def get(context: Context | None = None) -> OperationType: ... + + +class RangeType(Type): + @staticmethod + def isinstance(type: Type) -> bool: ... + + @staticmethod + def get(element_type: Type) -> RangeType: ... + + @property + def element_type(self) -> Type: ... + + +class TypeType(Type): + @staticmethod + def isinstance(type: Type) -> bool: ... + + @staticmethod + def get(context: Context | None = None) -> TypeType: ... + + +class ValueType(Type): + @staticmethod + def isinstance(type: Type) -> bool: ... + + @staticmethod + def get(context: Context | None = None) -> ValueType: ... diff --git a/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi b/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi new file mode 100644 index 0000000..3f53045 --- /dev/null +++ b/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi @@ -0,0 +1,142 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + + +from mlir.ir import DenseElementsAttr, Type + +__all__ = [ + "QuantizedType", + "AnyQuantizedType", + "UniformQuantizedType", + "UniformQuantizedPerAxisType", + "CalibratedQuantizedType", +] + +class QuantizedType(Type): + @staticmethod + def isinstance(type: Type) -> bool: ... + + @staticmethod + def default_minimum_for_integer(is_signed: bool, integral_width: int) -> int: + ... + + @staticmethod + def default_maximum_for_integer(is_signed: bool, integral_width: int) -> int: + ... + + @property + def expressed_type(self) -> Type: ... + + @property + def flags(self) -> int: ... + + @property + def is_signed(self) -> bool: ... + + @property + def storage_type(self) -> Type: ... + + @property + def storage_type_min(self) -> int: ... + + @property + def storage_type_max(self) -> int: ... + + @property + def storage_type_integral_width(self) -> int: ... + + def is_compatible_expressed_type(self, candidate: Type) -> bool: ... + + @property + def quantized_element_type(self) -> Type: ... + + def cast_from_storage_type(self, candidate: Type) -> Type: ... + + @staticmethod + def cast_to_storage_type(type: Type) -> Type: ... + + def cast_from_expressed_type(self, candidate: Type) -> Type: ... + + @staticmethod + def cast_to_expressed_type(type: Type) -> Type: ... + + def cast_expressed_to_storage_type(self, candidate: Type) -> Type: ... + + +class AnyQuantizedType(QuantizedType): + + @classmethod + def get(cls, flags: int, storage_type: Type, expressed_type: Type, + storage_type_min: int, storage_type_max: int) -> Type: + ... + + +class UniformQuantizedType(QuantizedType): + + @classmethod + def get(cls, flags: int, storage_type: Type, expressed_type: Type, + scale: float, zero_point: int, storage_type_min: int, + storage_type_max: int) -> Type: ... + + @property + def scale(self) -> float: ... + + @property + def zero_point(self) -> int: ... + + @property + def is_fixed_point(self) -> bool: ... + + +class UniformQuantizedPerAxisType(QuantizedType): + + @classmethod + def get(cls, flags: int, storage_type: Type, expressed_type: Type, + scales: list[float], zero_points: list[int], quantized_dimension: int, + storage_type_min: int, storage_type_max: int): + ... + + @property + def scales(self) -> list[float]: ... + + @property + def zero_points(self) -> list[int]: ... + + @property + def quantized_dimension(self) -> int: ... + + @property + def is_fixed_point(self) -> bool: ... + +class UniformQuantizedSubChannelType(QuantizedType): + + @classmethod + def get(cls, flags: int, storage_type: Type, expressed_type: Type, + scales: DenseElementsAttr, zero_points: DenseElementsAttr, + quantized_dimensions: list[int], block_sizes: list[int], + storage_type_min: int, storage_type_max: int): + ... + + @property + def quantized_dimensions(self) -> list[int]: ... + + @property + def block_sizes(self) -> list[int]: ... + + @property + def scales(self) -> DenseElementsAttr: ... + + @property + def zero_points(self) -> DenseElementsAttr: ... + +def CalibratedQuantizedType(QuantizedType): + + @classmethod + def get(cls, expressed_type: Type, min: float, max: float): ... + + @property + def min(self) -> float: ... + + @property + def max(self) -> float: ...
\ No newline at end of file diff --git a/mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi b/mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi new file mode 100644 index 0000000..a3f1b09 --- /dev/null +++ b/mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi @@ -0,0 +1,25 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + + +from mlir.ir import Type, Context + + +class AnyOpType(Type): + @staticmethod + def isinstance(type: Type) -> bool: ... + + @staticmethod + def get(context: Context | None = None) -> AnyOpType: ... + + +class OperationType(Type): + @staticmethod + def isinstance(type: Type) -> bool: ... + + @staticmethod + def get(operation_name: str, context: Context | None = None) -> OperationType: ... + + @property + def operation_name(self) -> str: ... diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi new file mode 100644 index 0000000..dcae3dd --- /dev/null +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -0,0 +1,2846 @@ +# Originally imported via: +# pybind11-stubgen --print-invalid-expressions-as-is mlir._mlir_libs._mlir.ir +# but with the following diff (in order to remove pipes from types, +# which we won't support until bumping minimum python to 3.10) +# +# --------------------- diff begins ------------------------------------ +# +# diff --git a/pybind11_stubgen/printer.py b/pybind11_stubgen/printer.py +# index 1f755aa..4924927 100644 +# --- a/pybind11_stubgen/printer.py +# +++ b/pybind11_stubgen/printer.py +# @@ -283,14 +283,6 @@ class Printer: +# return split[0] + "..." +# +# def print_type(self, type_: ResolvedType) -> str: +# - if ( +# - str(type_.name) == "typing.Optional" +# - and type_.parameters is not None +# - and len(type_.parameters) == 1 +# - ): +# - return f"{self.print_annotation(type_.parameters[0])} | None" +# - if str(type_.name) == "typing.Union" and type_.parameters is not None: +# - return " | ".join(self.print_annotation(p) for p in type_.parameters) +# if type_.parameters: +# param_str = ( +# "[" +# +# --------------------- diff ends ------------------------------------ +# +# Local modifications: +# * Rewrite references to 'mlir.ir' to local types. +# * Drop `typing.` everywhere (top-level import instead). +# * List -> List, dict -> Dict, Tuple -> Tuple. +# * copy-paste Buffer type from typing_extensions. +# * Shuffle _OperationBase, AffineExpr, Attribute, Type, Value to the top. +# * Patch raw C++ types (like "PyAsmState") with a regex like `Py(.*)`. +# * _BaseContext -> Context, MlirType -> Type, MlirTypeID -> TypeID, MlirAttribute -> Attribute. +# * Local edits to signatures and types that pybind11-stubgen did not auto detect (or detected incorrectly). +# * Add MLIRError, _GlobalDebug, _OperationBase to __all__ by hand. +# * Fill in `Any`s where possible. +# * black formatting. + +from __future__ import annotations + +import abc +import collections +from collections.abc import Callable, Sequence +from pathlib import Path +from typing import Any, BinaryIO, ClassVar, Literal, TypeVar, overload + +__all__ = [ + "AffineAddExpr", + "AffineBinaryExpr", + "AffineCeilDivExpr", + "AffineConstantExpr", + "AffineDimExpr", + "AffineExpr", + "AffineExprList", + "AffineFloorDivExpr", + "AffineMap", + "AffineMapAttr", + "AffineModExpr", + "AffineMulExpr", + "AffineSymbolExpr", + "ArrayAttr", + "ArrayAttributeIterator", + "AsmState", + "AttrBuilder", + "Attribute", + "BF16Type", + "Block", + "BlockArgument", + "BlockArgumentList", + "BlockIterator", + "BlockList", + "BoolAttr", + "ComplexType", + "Context", + "DenseBoolArrayAttr", + "DenseBoolArrayIterator", + "DenseElementsAttr", + "DenseF32ArrayAttr", + "DenseF32ArrayIterator", + "DenseF64ArrayAttr", + "DenseF64ArrayIterator", + "DenseFPElementsAttr", + "DenseI16ArrayAttr", + "DenseI16ArrayIterator", + "DenseI32ArrayAttr", + "DenseI32ArrayIterator", + "DenseI64ArrayAttr", + "DenseI64ArrayIterator", + "DenseI8ArrayAttr", + "DenseI8ArrayIterator", + "DenseIntElementsAttr", + "DenseResourceElementsAttr", + "Diagnostic", + "DiagnosticHandler", + "DiagnosticInfo", + "DiagnosticSeverity", + "Dialect", + "DialectDescriptor", + "DialectRegistry", + "Dialects", + "DictAttr", + "F16Type", + "F32Type", + "F64Type", + "FlatSymbolRefAttr", + "Float4E2M1FNType", + "Float6E2M3FNType", + "Float6E3M2FNType", + "Float8E3M4Type", + "Float8E4M3B11FNUZType", + "Float8E4M3FNType", + "Float8E4M3FNUZType", + "Float8E4M3Type", + "Float8E5M2FNUZType", + "Float8E5M2Type", + "Float8E8M0FNUType", + "FloatAttr", + "FloatTF32Type", + "FloatType", + "FunctionType", + "IndexType", + "InferShapedTypeOpInterface", + "InferTypeOpInterface", + "InsertionPoint", + "IntegerAttr", + "IntegerSet", + "IntegerSetAttr", + "IntegerSetConstraint", + "IntegerSetConstraintList", + "IntegerType", + "Location", + "MemRefType", + "Module", + "NamedAttribute", + "NoneType", + "OpAttributeMap", + "OpOperand", + "OpOperandIterator", + "OpOperandList", + "OpResult", + "OpResultList", + "OpSuccessors", + "OpView", + "OpaqueAttr", + "OpaqueType", + "Operation", + "OperationIterator", + "OperationList", + "RankedTensorType", + "Region", + "RegionIterator", + "RegionSequence", + "ShapedType", + "ShapedTypeComponents", + "StridedLayoutAttr", + "StringAttr", + "SymbolRefAttr", + "SymbolTable", + "TupleType", + "Type", + "TypeAttr", + "TypeID", + "UnitAttr", + "UnrankedMemRefType", + "UnrankedTensorType", + "Value", + "VectorType", + "_GlobalDebug", + "_OperationBase", +] + +if hasattr(collections.abc, "Buffer"): + Buffer = collections.abc.Buffer +else: + class Buffer(abc.ABC): + pass + +class _OperationBase: + @overload + def __eq__(self, arg0: _OperationBase) -> bool: ... + @overload + def __eq__(self, arg0: _OperationBase) -> bool: ... + def __hash__(self) -> int: ... + def __str__(self) -> str: + """ + Returns the assembly form of the operation. + """ + def clone(self, ip: InsertionPoint = None) -> OpView: ... + def detach_from_parent(self) -> OpView: + """ + Detaches the operation from its parent block. + """ + + @property + def attached(self) -> bool: + """ + Reports if the operation is attached to its parent block. + """ + + def erase(self) -> None: ... + + @overload + def get_asm( + binary: Literal[True], + large_elements_limit: int | None = None, + large_resource_limit: int | None = None, + enable_debug_info: bool = False, + pretty_debug_info: bool = False, + print_generic_op_form: bool = False, + use_local_scope: bool = False, + assume_verified: bool = False, + skip_regions: bool = False, + ) -> bytes: ... + @overload + def get_asm( + self, + binary: bool = False, + large_elements_limit: int | None = None, + large_resource_limit: int | None = None, + enable_debug_info: bool = False, + pretty_debug_info: bool = False, + print_generic_op_form: bool = False, + use_local_scope: bool = False, + assume_verified: bool = False, + skip_regions: bool = False, + ) -> str: + """ + Returns the assembly form of the operation. + + See the print() method for common keyword arguments for configuring + the output. + """ + + def move_after(self, other: _OperationBase) -> None: + """ + Puts self immediately after the other operation in its parent block. + """ + def move_before(self, other: _OperationBase) -> None: + """ + Puts self immediately before the other operation in its parent block. + """ + @overload + def print( + self, + state: AsmState, + file: Any | None = None, + binary: bool = False, + ) -> None: + """ + Prints the assembly form of the operation to a file like object. + + Args: + file: The file like object to write to. Defaults to sys.stdout. + binary: Whether to write bytes (True) or str (False). Defaults to False. + state: AsmState capturing the operation numbering and flags. + """ + @overload + def print( + self, + large_elements_limit: int | None = None, + large_resource_limit: int | None = None, + enable_debug_info: bool = False, + pretty_debug_info: bool = False, + print_generic_op_form: bool = False, + use_local_scope: bool = False, + assume_verified: bool = False, + file: Any | None = None, + binary: bool = False, + skip_regions: bool = False, + ) -> None: + """ + Prints the assembly form of the operation to a file like object. + + Args: + file: The file like object to write to. Defaults to sys.stdout. + binary: Whether to write bytes (True) or str (False). Defaults to False. + large_elements_limit: Whether to elide elements attributes above this + number of elements. Defaults to None (no limit). + large_resource_limit: Whether to elide resource strings above this + number of characters. Defaults to None (no limit). If large_elements_limit + is set and this is None, the behavior will be to use large_elements_limit + as large_resource_limit. + enable_debug_info: Whether to print debug/location information. Defaults + to False. + pretty_debug_info: Whether to format debug information for easier reading + by a human (warning: the result is unparseable). + print_generic_op_form: Whether to print the generic assembly forms of all + ops. Defaults to False. + use_local_Scope: Whether to print in a way that is more optimized for + multi-threaded access but may not be consistent with how the overall + module prints. + assume_verified: By default, if not printing generic form, the verifier + will be run and if it fails, generic form will be printed with a comment + about failed verification. While a reasonable default for interactive use, + for systematic use, it is often better for the caller to verify explicitly + and report failures in a more robust fashion. Set this to True if doing this + in order to avoid running a redundant verification. If the IR is actually + invalid, behavior is undefined. + skip_regions: Whether to skip printing regions. Defaults to False. + """ + def verify(self) -> bool: + """ + Verify the operation. Raises MLIRError if verification fails, and returns true otherwise. + """ + def write_bytecode(self, file: BinaryIO | str, desired_version: int | None = None) -> None: + """ + Write the bytecode form of the operation to a file like object. + + Args: + file: The file like object or path to write to. + desired_version: The version of bytecode to emit. + Returns: + The bytecode writer status. + """ + @property + def _CAPIPtr(self) -> object: ... + @property + def attributes(self) -> OpAttributeMap: ... + @property + def context(self) -> Context: + """ + Context that owns the Operation + """ + @property + def location(self) -> Location: + """ + Returns the source location the operation was defined or derived from. + """ + @property + def name(self) -> str: ... + @property + def operands(self) -> OpOperandList: ... + @property + def parent(self) -> _OperationBase | None: ... + @property + def regions(self) -> RegionSequence: ... + @property + def result(self) -> OpResult: + """ + Shortcut to get an op result if it has only one (throws an error otherwise). + """ + @property + def results(self) -> OpResultList: + """ + Returns the List of Operation results. + """ + +_TOperation = TypeVar("_TOperation", bound=_OperationBase) + +class AffineExpr: + @staticmethod + @overload + def get_add(arg0: AffineExpr, arg1: AffineExpr) -> AffineAddExpr: + """ + Gets an affine expression containing a sum of two expressions. + """ + @staticmethod + @overload + def get_add(arg0: int, arg1: AffineExpr) -> AffineAddExpr: + """ + Gets an affine expression containing a sum of a constant and another expression. + """ + @staticmethod + @overload + def get_add(arg0: AffineExpr, arg1: int) -> AffineAddExpr: + """ + Gets an affine expression containing a sum of an expression and a constant. + """ + @staticmethod + @overload + def get_ceil_div(arg0: AffineExpr, arg1: AffineExpr) -> AffineCeilDivExpr: + """ + Gets an affine expression containing the rounded-up result of dividing one expression by another. + """ + @staticmethod + @overload + def get_ceil_div(arg0: int, arg1: AffineExpr) -> AffineCeilDivExpr: + """ + Gets a semi-affine expression containing the rounded-up result of dividing a constant by an expression. + """ + @staticmethod + @overload + def get_ceil_div(arg0: AffineExpr, arg1: int) -> AffineCeilDivExpr: + """ + Gets an affine expression containing the rounded-up result of dividing an expression by a constant. + """ + @staticmethod + def get_constant( + value: int, context: Context | None = None + ) -> AffineConstantExpr: + """ + Gets a constant affine expression with the given value. + """ + @staticmethod + def get_dim(position: int, context: Context | None = None) -> AffineDimExpr: + """ + Gets an affine expression of a dimension at the given position. + """ + @staticmethod + @overload + def get_floor_div(arg0: AffineExpr, arg1: AffineExpr) -> AffineFloorDivExpr: + """ + Gets an affine expression containing the rounded-down result of dividing one expression by another. + """ + @staticmethod + @overload + def get_floor_div(arg0: int, arg1: AffineExpr) -> AffineFloorDivExpr: + """ + Gets a semi-affine expression containing the rounded-down result of dividing a constant by an expression. + """ + @staticmethod + @overload + def get_floor_div(arg0: AffineExpr, arg1: int) -> AffineFloorDivExpr: + """ + Gets an affine expression containing the rounded-down result of dividing an expression by a constant. + """ + @staticmethod + @overload + def get_mod(arg0: AffineExpr, arg1: AffineExpr) -> AffineModExpr: + """ + Gets an affine expression containing the modulo of dividing one expression by another. + """ + @staticmethod + @overload + def get_mod(arg0: int, arg1: AffineExpr) -> AffineModExpr: + """ + Gets a semi-affine expression containing the modulo of dividing a constant by an expression. + """ + @staticmethod + @overload + def get_mod(arg0: AffineExpr, arg1: int) -> AffineModExpr: + """ + Gets an affine expression containing the module of dividingan expression by a constant. + """ + @staticmethod + @overload + def get_mul(arg0: AffineExpr, arg1: AffineExpr) -> AffineMulExpr: + """ + Gets an affine expression containing a product of two expressions. + """ + @staticmethod + @overload + def get_mul(arg0: int, arg1: AffineExpr) -> AffineMulExpr: + """ + Gets an affine expression containing a product of a constant and another expression. + """ + @staticmethod + @overload + def get_mul(arg0: AffineExpr, arg1: int) -> AffineMulExpr: + """ + Gets an affine expression containing a product of an expression and a constant. + """ + @staticmethod + def get_symbol( + position: int, context: Context | None = None + ) -> AffineSymbolExpr: + """ + Gets an affine expression of a symbol at the given position. + """ + def _CAPICreate(self) -> AffineExpr: ... + @overload + def __add__(self, arg0: AffineExpr) -> AffineAddExpr: ... + @overload + def __add__(self, arg0: int) -> AffineAddExpr: ... + @overload + def __eq__(self, arg0: AffineExpr) -> bool: ... + @overload + def __eq__(self, arg0: Any) -> bool: ... + def __hash__(self) -> int: ... + @overload + def __mod__(self, arg0: AffineExpr) -> AffineModExpr: ... + @overload + def __mod__(self, arg0: int) -> AffineModExpr: ... + @overload + def __mul__(self, arg0: AffineExpr) -> AffineMulExpr: ... + @overload + def __mul__(self, arg0: int) -> AffineMulExpr: ... + def __radd__(self, arg0: int) -> AffineAddExpr: ... + def __rmod__(self, arg0: int) -> AffineModExpr: ... + def __rmul__(self, arg0: int) -> AffineMulExpr: ... + def __rsub__(self, arg0: int) -> AffineAddExpr: ... + @overload + def __sub__(self, arg0: AffineExpr) -> AffineAddExpr: ... + @overload + def __sub__(self, arg0: int) -> AffineAddExpr: ... + def compose(self, arg0: AffineMap) -> AffineExpr: ... + def dump(self) -> None: + """ + Dumps a debug representation of the object to stderr. + """ + @property + def _CAPIPtr(self) -> object: ... + @property + def context(self) -> Context: ... + +class Attribute: + @staticmethod + def parse(asm: str | bytes, context: Context | None = None) -> Attribute: + """ + Parses an attribute from an assembly form. Raises an MLIRError on failure. + """ + def _CAPICreate(self) -> Attribute: ... + @overload + def __eq__(self, arg0: Attribute) -> bool: ... + @overload + def __eq__(self, arg0: object) -> bool: ... + def __hash__(self) -> int: ... + def __init__(self, cast_from_type: Attribute) -> None: + """ + Casts the passed attribute to the generic Attribute + """ + def __str__(self) -> str: + """ + Returns the assembly form of the Attribute. + """ + def dump(self) -> None: + """ + Dumps a debug representation of the object to stderr. + """ + def get_named(self, arg0: str) -> NamedAttribute: + """ + Binds a name to the attribute + """ + def maybe_downcast(self) -> Any: ... + @property + def _CAPIPtr(self) -> object: ... + @property + def context(self) -> Context: + """ + Context that owns the Attribute + """ + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class Type: + @staticmethod + def parse(asm: str | bytes, context: Context | None = None) -> Type: + """ + Parses the assembly form of a type. + + Returns a Type object or raises an MLIRError if the type cannot be parsed. + + See also: https://mlir.llvm.org/docs/LangRef/#type-system + """ + def _CAPICreate(self) -> Type: ... + @overload + def __eq__(self, arg0: Type) -> bool: ... + @overload + def __eq__(self, arg0: object) -> bool: ... + def __hash__(self) -> int: ... + def __init__(self, cast_from_type: Type) -> None: + """ + Casts the passed type to the generic Type + """ + def __str__(self) -> str: + """ + Returns the assembly form of the type. + """ + def dump(self) -> None: + """ + Dumps a debug representation of the object to stderr. + """ + def maybe_downcast(self) -> Any: ... + @property + def _CAPIPtr(self) -> object: ... + @property + def context(self) -> Context: + """ + Context that owns the Type + """ + @property + def typeid(self) -> TypeID: ... + +class Value: + def _CAPICreate(self) -> Value: ... + @overload + def __eq__(self, arg0: Value) -> bool: ... + @overload + def __eq__(self, arg0: object) -> bool: ... + def __hash__(self) -> int: ... + def __init__(self, value: Value) -> None: ... + def __str__(self) -> str: + """ + Returns the string form of the value. + + If the value is a block argument, this is the assembly form of its type and the + position in the argument List. If the value is an operation result, this is + equivalent to printing the operation that produced it. + """ + def dump(self) -> None: + """ + Dumps a debug representation of the object to stderr. + """ + @overload + def get_name(self, use_local_scope: bool = False, use_name_loc_as_prefix: bool = True) -> str: ... + @overload + def get_name(self, state: AsmState) -> str: + """ + Returns the string form of value as an operand (i.e., the ValueID). + """ + def maybe_downcast(self) -> Any: ... + def replace_all_uses_with(self, arg0: Value) -> None: + """ + Replace all uses of value with the new value, updating anything in + the IR that uses 'self' to use the other value instead. + """ + def set_type(self, type: Type) -> None: ... + @property + def _CAPIPtr(self) -> object: ... + @property + def context(self) -> Context: + """ + Context in which the value lives. + """ + @property + def owner(self) -> _OperationBase: ... + @property + def type(self) -> Type: ... + @property + def uses(self) -> OpOperandIterator: ... + +class AffineAddExpr(AffineBinaryExpr): + @staticmethod + def get(arg0: AffineExpr, arg1: AffineExpr) -> AffineAddExpr: ... + @staticmethod + def isinstance(other: AffineExpr) -> bool: ... + def __init__(self, expr: AffineExpr) -> None: ... + +class AffineBinaryExpr(AffineExpr): + @staticmethod + def isinstance(other: AffineExpr) -> bool: ... + def __init__(self, expr: AffineExpr) -> None: ... + @property + def lhs(self) -> AffineExpr: ... + @property + def rhs(self) -> AffineExpr: ... + +class AffineCeilDivExpr(AffineBinaryExpr): + @staticmethod + def get(arg0: AffineExpr, arg1: AffineExpr) -> AffineCeilDivExpr: ... + @staticmethod + def isinstance(other: AffineExpr) -> bool: ... + def __init__(self, expr: AffineExpr) -> None: ... + +class AffineConstantExpr(AffineExpr): + @staticmethod + def get(value: int, context: Context | None = None) -> AffineConstantExpr: ... + @staticmethod + def isinstance(other: AffineExpr) -> bool: ... + def __init__(self, expr: AffineExpr) -> None: ... + @property + def value(self) -> int: ... + +class AffineDimExpr(AffineExpr): + @staticmethod + def get(position: int, context: Context | None = None) -> AffineDimExpr: ... + @staticmethod + def isinstance(other: AffineExpr) -> bool: ... + def __init__(self, expr: AffineExpr) -> None: ... + @property + def position(self) -> int: ... + +class AffineExprList: + def __add__(self, arg0: AffineExprList) -> list[AffineExpr]: ... + +class AffineFloorDivExpr(AffineBinaryExpr): + @staticmethod + def get(arg0: AffineExpr, arg1: AffineExpr) -> AffineFloorDivExpr: ... + @staticmethod + def isinstance(other: AffineExpr) -> bool: ... + def __init__(self, expr: AffineExpr) -> None: ... + +class AffineMap: + @staticmethod + def compress_unused_symbols( + arg0: list, arg1: Context | None + ) -> list[AffineMap]: ... + @staticmethod + def get( + dim_count: int, + symbol_count: int, + exprs: list, + context: Context | None = None, + ) -> AffineMap: + """ + Gets a map with the given expressions as results. + """ + @staticmethod + def get_constant(value: int, context: Context | None = None) -> AffineMap: + """ + Gets an affine map with a single constant result + """ + @staticmethod + def get_empty(context: Context | None = None) -> AffineMap: + """ + Gets an empty affine map. + """ + @staticmethod + def get_identity(n_dims: int, context: Context | None = None) -> AffineMap: + """ + Gets an identity map with the given number of dimensions. + """ + @staticmethod + def get_minor_identity( + n_dims: int, n_results: int, context: Context | None = None + ) -> AffineMap: + """ + Gets a minor identity map with the given number of dimensions and results. + """ + @staticmethod + def get_permutation( + permutation: list[int], context: Context | None = None + ) -> AffineMap: + """ + Gets an affine map that permutes its inputs. + """ + def _CAPICreate(self) -> AffineMap: ... + @overload + def __eq__(self, arg0: AffineMap) -> bool: ... + @overload + def __eq__(self, arg0: object) -> bool: ... + def __hash__(self) -> int: ... + def dump(self) -> None: + """ + Dumps a debug representation of the object to stderr. + """ + def get_major_submap(self, n_results: int) -> AffineMap: ... + def get_minor_submap(self, n_results: int) -> AffineMap: ... + def get_submap(self, result_positions: list[int]) -> AffineMap: ... + def replace( + self, + expr: AffineExpr, + replacement: AffineExpr, + n_result_dims: int, + n_result_syms: int, + ) -> AffineMap: ... + @property + def _CAPIPtr(self) -> object: ... + @property + def context(self) -> Context: + """ + Context that owns the Affine Map + """ + @property + def is_permutation(self) -> bool: ... + @property + def is_projected_permutation(self) -> bool: ... + @property + def n_dims(self) -> int: ... + @property + def n_inputs(self) -> int: ... + @property + def n_symbols(self) -> int: ... + @property + def results(self) -> AffineMapExprList: ... + +class AffineMapAttr(Attribute): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(affine_map: AffineMap) -> AffineMapAttr: + """ + Gets an attribute wrapping an AffineMap. + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class AffineModExpr(AffineBinaryExpr): + @staticmethod + def get(arg0: AffineExpr, arg1: AffineExpr) -> AffineModExpr: ... + @staticmethod + def isinstance(other: AffineExpr) -> bool: ... + def __init__(self, expr: AffineExpr) -> None: ... + +class AffineMulExpr(AffineBinaryExpr): + @staticmethod + def get(arg0: AffineExpr, arg1: AffineExpr) -> AffineMulExpr: ... + @staticmethod + def isinstance(other: AffineExpr) -> bool: ... + def __init__(self, expr: AffineExpr) -> None: ... + +class AffineSymbolExpr(AffineExpr): + @staticmethod + def get(position: int, context: Context | None = None) -> AffineSymbolExpr: ... + @staticmethod + def isinstance(other: AffineExpr) -> bool: ... + def __init__(self, expr: AffineExpr) -> None: ... + @property + def position(self) -> int: ... + +class ArrayAttr(Attribute): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(attributes: list, context: Context | None = None) -> ArrayAttr: + """ + Gets a uniqued Array attribute + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __add__(self, arg0: list) -> ArrayAttr: ... + def __getitem__(self, arg0: int) -> Attribute: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __iter__( + self, + ) -> ArrayAttributeIterator: ... + def __len__(self) -> int: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class ArrayAttributeIterator: + def __iter__(self) -> ArrayAttributeIterator: ... + def __next__(self) -> Attribute: ... + +class AsmState: + @overload + def __init__(self, value: Value, use_local_scope: bool = False) -> None: ... + @overload + def __init__(self, op: _OperationBase, use_local_scope: bool = False) -> None: ... + +class AttrBuilder: + @staticmethod + def contains(arg0: str) -> bool: ... + @staticmethod + def get(arg0: str) -> Callable: ... + @staticmethod + def insert( + attribute_kind: str, attr_builder: Callable, replace: bool = False + ) -> None: + """ + Register an attribute builder for building MLIR attributes from python values. + """ + +class BF16Type(Type): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> BF16Type: + """ + Create a bf16 type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + +class Block: + @staticmethod + def create_at_start( + parent: Region, + arg_types: list[Type], + arg_locs: Sequence | None = None, + ) -> Block: + """ + Creates and returns a new Block at the beginning of the given region (with given argument types and locations). + """ + @overload + def __eq__(self, arg0: Block) -> bool: ... + @overload + def __eq__(self, arg0: Any) -> bool: ... + def __hash__(self) -> int: ... + def __iter__(self) -> OperationIterator: + """ + Iterates over operations in the block. + """ + def __str__(self) -> str: + """ + Returns the assembly form of the block. + """ + def append(self, operation: _OperationBase) -> None: + """ + Appends an operation to this block. If the operation is currently in another block, it will be moved. + """ + def append_to(self, arg0: Region) -> None: + """ + Append this block to a region, transferring ownership if necessary + """ + def create_after(self, *args, arg_locs: Sequence | None = None) -> Block: + """ + Creates and returns a new Block after this block (with given argument types and locations). + """ + def create_before(self, *args, arg_locs: Sequence | None = None) -> Block: + """ + Creates and returns a new Block before this block (with given argument types and locations). + """ + @property + def _CAPIPtr(self) -> object: ... + @property + def arguments(self) -> BlockArgumentList: + """ + Returns a List of block arguments. + """ + @property + def operations(self) -> OperationList: + """ + Returns a forward-optimized sequence of operations. + """ + @property + def owner(self) -> OpView: + """ + Returns the owning operation of this block. + """ + @property + def region(self) -> Region: + """ + Returns the owning region of this block. + """ + +class BlockArgument(Value): + @staticmethod + def isinstance(other_value: Value) -> bool: ... + def __init__(self, value: Value) -> None: ... + def maybe_downcast(self) -> Any: ... + def set_type(self, type: Type) -> None: ... + @property + def arg_number(self) -> int: ... + @property + def owner(self) -> Block: ... + +class BlockArgumentList: + @overload + def __getitem__(self, arg0: int) -> BlockArgument: ... + @overload + def __getitem__(self, arg0: slice) -> BlockArgumentList: ... + def __len__(self) -> int: ... + def __add__(self, arg0: BlockArgumentList) -> list[BlockArgument]: ... + @property + def types(self) -> list[Type]: ... + +class BlockIterator: + def __iter__(self) -> BlockIterator: ... + def __next__(self) -> Block: ... + +class BlockList: + def __getitem__(self, arg0: int) -> Block: ... + def __iter__(self) -> BlockIterator: ... + def __len__(self) -> int: ... + def append(self, *args, arg_locs: Sequence | None = None) -> Block: + """ + Appends a new block, with argument types as positional args. + + Returns: + The created block. + """ + +class BoolAttr(Attribute): + @staticmethod + def get(value: bool, context: Context | None = None) -> BoolAttr: + """ + Gets an uniqued bool attribute + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __bool__(self: Attribute) -> bool: + """ + Converts the value of the bool attribute to a Python bool + """ + def __init__(self, cast_from_attr: Attribute) -> None: ... + @property + def static_typeid(self) -> TypeID: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + @property + def value(self) -> bool: + """ + Returns the value of the bool attribute + """ + +class ComplexType(Type): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(arg0: Type) -> ComplexType: + """ + Create a complex type + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def element_type(self) -> Type: + """ + Returns element type. + """ + @property + def typeid(self) -> TypeID: ... + +class Context: + current: ClassVar[Context] = ... # read-only + allow_unregistered_dialects: bool + @staticmethod + def _get_live_count() -> int: ... + def _CAPICreate(self) -> object: ... + def __enter__(self) -> Context: ... + def __exit__(self, arg0: Any, arg1: Any, arg2: Any) -> None: ... + def __init__(self) -> None: ... + def _clear_live_operations(self) -> int: ... + def _get_context_again(self) -> Context: ... + def _get_live_module_count(self) -> int: ... + def _get_live_operation_count(self) -> int: ... + def _get_live_operation_objects(self) -> list[Operation]: ... + def append_dialect_registry(self, registry: DialectRegistry) -> None: ... + def attach_diagnostic_handler( + self, callback: Callable[[Diagnostic], bool] + ) -> DiagnosticHandler: + """ + Attaches a diagnostic handler that will receive callbacks + """ + def enable_multithreading(self, enable: bool) -> None: ... + def get_dialect_descriptor(self, dialect_name: str) -> DialectDescriptor: + """ + Gets or loads a dialect by name, returning its descriptor object + """ + def is_registered_operation(self, operation_name: str) -> bool: ... + def load_all_available_dialects(self) -> None: ... + @property + def _CAPIPtr(self) -> object: ... + @property + def d(self) -> Dialects: + """ + Alias for 'dialect' + """ + @property + def dialects(self) -> Dialects: + """ + Gets a container for accessing dialects by name + """ + +class DenseBoolArrayAttr(Attribute): + @staticmethod + def get( + values: Sequence[bool], context: Context | None = None + ) -> DenseBoolArrayAttr: + """ + Gets a uniqued dense array attribute + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __add__(self, arg0: list) -> DenseBoolArrayAttr: ... + def __getitem__(self, arg0: int) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __iter__( + self, + ) -> DenseBoolArrayIterator: ... + def __len__(self) -> int: ... + @property + def static_typeid(self) -> TypeID: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class DenseBoolArrayIterator: + def __iter__(self) -> DenseBoolArrayIterator: ... + def __next__(self) -> bool: ... + +class DenseElementsAttr(Attribute): + @staticmethod + def get( + array: Buffer, + signless: bool = True, + type: Type | None = None, + shape: list[int] | None = None, + context: Context | None = None, + ) -> DenseElementsAttr: + """ + Gets a DenseElementsAttr from a Python buffer or array. + + When `type` is not provided, then some limited type inferencing is done based + on the buffer format. Support presently exists for 8/16/32/64 signed and + unsigned integers and float16/float32/float64. DenseElementsAttrs of these + types can also be converted back to a corresponding buffer. + + For conversions outside of these types, a `type=` must be explicitly provided + and the buffer contents must be bit-castable to the MLIR internal + representation: + + * Integer types (except for i1): the buffer must be byte aligned to the + next byte boundary. + * Floating point types: Must be bit-castable to the given floating point + size. + * i1 (bool): Bit packed into 8bit words where the bit pattern matches a + row major ordering. An arbitrary Numpy `bool_` array can be bit packed to + this specification with: `np.packbits(ary, axis=None, bitorder='little')`. + + If a single element buffer is passed (or for i1, a single byte with value 0 + or 255), then a splat will be created. + + Args: + array: The array or buffer to convert. + signless: If inferring an appropriate MLIR type, use signless types for + integers (defaults True). + type: Skips inference of the MLIR element type and uses this instead. The + storage size must be consistent with the actual contents of the buffer. + shape: Overrides the shape of the buffer when constructing the MLIR + shaped type. This is needed when the physical and logical shape differ (as + for i1). + context: Explicit context, if not from context manager. + + Returns: + DenseElementsAttr on success. + + Raises: + ValueError: If the type of the buffer or array cannot be matched to an MLIR + type or if the buffer does not meet expectations. + """ + @staticmethod + def get_splat(shaped_type: Type, element_attr: Attribute) -> DenseElementsAttr: + """ + Gets a DenseElementsAttr where all values are the same + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __len__(self) -> int: ... + def get_splat_value(self) -> Attribute: ... + @property + def is_splat(self) -> bool: ... + @property + def static_typeid(self) -> TypeID: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class DenseF32ArrayAttr(Attribute): + @staticmethod + def get( + values: Sequence[float], context: Context | None = None + ) -> DenseF32ArrayAttr: + """ + Gets a uniqued dense array attribute + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __add__(self, arg0: list) -> DenseF32ArrayAttr: ... + def __getitem__(self, arg0: int) -> float: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __iter__( + self, + ) -> DenseF32ArrayIterator: ... + def __len__(self) -> int: ... + @property + def static_typeid(self) -> TypeID: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class DenseF32ArrayIterator: + def __iter__(self) -> DenseF32ArrayIterator: ... + def __next__(self) -> float: ... + +class DenseF64ArrayAttr(Attribute): + @staticmethod + def get( + values: Sequence[float], context: Context | None = None + ) -> DenseF64ArrayAttr: + """ + Gets a uniqued dense array attribute + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __add__(self, arg0: list) -> DenseF64ArrayAttr: ... + def __getitem__(self, arg0: int) -> float: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __iter__( + self, + ) -> DenseF64ArrayIterator: ... + def __len__(self) -> int: ... + @property + def static_typeid(self) -> TypeID: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class DenseF64ArrayIterator: + def __iter__(self) -> DenseF64ArrayIterator: ... + def __next__(self) -> float: ... + +class DenseFPElementsAttr(DenseElementsAttr): + @staticmethod + def get( + array: Buffer, + signless: bool = True, + type: Type | None = None, + shape: list[int] | None = None, + context: Context | None = None, + ) -> DenseFPElementsAttr: ... + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __getitem__(self, arg0: int) -> float: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + @property + def static_typeid(self) -> TypeID: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class DenseI16ArrayAttr(Attribute): + @staticmethod + def get(values: Sequence[int], context: Context | None = None) -> DenseI16ArrayAttr: + """ + Gets a uniqued dense array attribute + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __add__(self, arg0: list) -> DenseI16ArrayAttr: ... + def __getitem__(self, arg0: int) -> int: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __iter__( + self, + ) -> DenseI16ArrayIterator: ... + def __len__(self) -> int: ... + @property + def static_typeid(self) -> TypeID: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class DenseI16ArrayIterator: + def __iter__(self) -> DenseI16ArrayIterator: ... + def __next__(self) -> int: ... + +class DenseI32ArrayAttr(Attribute): + @staticmethod + def get(values: Sequence[int], context: Context | None = None) -> DenseI32ArrayAttr: + """ + Gets a uniqued dense array attribute + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __add__(self, arg0: list) -> DenseI32ArrayAttr: ... + def __getitem__(self, arg0: int) -> int: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __iter__( + self, + ) -> DenseI32ArrayIterator: ... + def __len__(self) -> int: ... + @property + def static_typeid(self) -> TypeID: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class DenseI32ArrayIterator: + def __iter__(self) -> DenseI32ArrayIterator: ... + def __next__(self) -> int: ... + +class DenseI64ArrayAttr(Attribute): + @staticmethod + def get(values: Sequence[int], context: Context | None = None) -> DenseI64ArrayAttr: + """ + Gets a uniqued dense array attribute + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __add__(self, arg0: list) -> DenseI64ArrayAttr: ... + def __getitem__(self, arg0: int) -> int: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __iter__( + self, + ) -> DenseI16ArrayIterator: ... + def __len__(self) -> int: ... + @property + def static_typeid(self) -> TypeID: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class DenseI64ArrayIterator: + def __iter__(self) -> DenseI64ArrayIterator: ... + def __next__(self) -> int: ... + +class DenseI8ArrayAttr(Attribute): + @staticmethod + def get(values: Sequence[int], context: Context | None = None) -> DenseI8ArrayAttr: + """ + Gets a uniqued dense array attribute + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __add__(self, arg0: list) -> DenseI8ArrayAttr: ... + def __getitem__(self, arg0: int) -> int: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __iter__( + self, + ) -> DenseI8ArrayIterator: ... + def __len__(self) -> int: ... + @property + def static_typeid(self) -> TypeID: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class DenseI8ArrayIterator: + def __iter__(self) -> DenseI8ArrayIterator: ... + def __next__(self) -> int: ... + +class DenseIntElementsAttr(DenseElementsAttr): + @staticmethod + def get( + array: Buffer, + signless: bool = True, + type: Type | None = None, + shape: list[int] | None = None, + context: Context | None = None, + ) -> DenseIntElementsAttr: ... + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __getitem__(self, arg0: int) -> int: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + @property + def static_typeid(self) -> TypeID: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class DenseResourceElementsAttr(Attribute): + @staticmethod + def get_from_buffer( + array: Buffer, + name: str, + type: Type, + alignment: int | None = None, + is_mutable: bool = False, + context: Context | None = None, + ) -> DenseResourceElementsAttr: + """ + Gets a DenseResourceElementsAttr from a Python buffer or array. + + This function does minimal validation or massaging of the data, and it is + up to the caller to ensure that the buffer meets the characteristics + implied by the shape. + + The backing buffer and any user objects will be retained for the lifetime + of the resource blob. This is typically bounded to the context but the + resource can have a shorter lifespan depending on how it is used in + subsequent processing. + + Args: + buffer: The array or buffer to convert. + name: Name to provide to the resource (may be changed upon collision). + type: The explicit ShapedType to construct the attribute with. + context: Explicit context, if not from context manager. + + Returns: + DenseResourceElementsAttr on success. + + Raises: + ValueError: If the type of the buffer or array cannot be matched to an MLIR + type or if the buffer does not meet expectations. + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + @property + def static_typeid(self) -> TypeID: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class Diagnostic: + @property + def location(self) -> Location: ... + @property + def message(self) -> str: ... + @property + def notes(self) -> tuple[Diagnostic]: ... + @property + def severity(self) -> DiagnosticSeverity: ... + +class DiagnosticHandler: + def __enter__(self) -> DiagnosticHandler: ... + def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ... + def detach(self) -> None: ... + @property + def attached(self) -> bool: ... + @property + def had_error(self) -> bool: ... + +class DiagnosticInfo: + def __init__(self, arg0: Diagnostic) -> None: ... + @property + def location(self) -> Location: ... + @property + def message(self) -> str: ... + @property + def notes(self) -> list[DiagnosticInfo]: ... + @property + def severity(self) -> DiagnosticSeverity: ... + +class DiagnosticSeverity: + """ + Members: + + ERROR + + WARNING + + NOTE + + REMARK + """ + + ERROR: ClassVar[DiagnosticSeverity] # value = <DiagnosticSeverity.ERROR: 0> + NOTE: ClassVar[DiagnosticSeverity] # value = <DiagnosticSeverity.NOTE: 2> + REMARK: ClassVar[DiagnosticSeverity] # value = <DiagnosticSeverity.REMARK: 3> + WARNING: ClassVar[DiagnosticSeverity] # value = <DiagnosticSeverity.WARNING: 1> + __members__: ClassVar[ + dict[str, DiagnosticSeverity] + ] # value = {'ERROR': <DiagnosticSeverity.ERROR: 0>, 'WARNING': <DiagnosticSeverity.WARNING: 1>, 'NOTE': <DiagnosticSeverity.NOTE: 2>, 'REMARK': <DiagnosticSeverity.REMARK: 3>} + def __eq__(self, other: Any) -> bool: ... + def __getstate__(self) -> int: ... + def __hash__(self) -> int: ... + def __index__(self) -> int: ... + def __init__(self, value: int) -> None: ... + def __int__(self) -> int: ... + def __ne__(self, other: Any) -> bool: ... + def __setstate__(self, state: int) -> None: ... + @property + def name(self) -> str: ... + @property + def value(self) -> int: ... + +class Dialect: + def __init__(self, descriptor: DialectDescriptor) -> None: ... + @property + def descriptor(self) -> DialectDescriptor: ... + +class DialectDescriptor: + @property + def namespace(self) -> str: ... + +class DialectRegistry: + def _CAPICreate(self) -> DialectRegistry: ... + def __init__(self) -> None: ... + @property + def _CAPIPtr(self) -> object: ... + +class Dialects: + def __getattr__(self, arg0: str) -> Dialect: ... + def __getitem__(self, arg0: str) -> Dialect: ... + +class DictAttr(Attribute): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(value: dict = {}, context: Context | None = None) -> DictAttr: + """ + Gets an uniqued Dict attribute + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __contains__(self, arg0: str) -> bool: ... + @overload + def __getitem__(self, arg0: str) -> Attribute: ... + @overload + def __getitem__(self, arg0: int) -> NamedAttribute: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __len__(self) -> int: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class FloatType(Type): + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def width(self) -> int: + """ + Returns the width of the floating-point type. + """ + +class F16Type(FloatType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> F16Type: + """ + Create a f16 type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + +class F32Type(FloatType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> F32Type: + """ + Create a f32 type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + +class F64Type(FloatType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> F64Type: + """ + Create a f64 type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + +class FlatSymbolRefAttr(Attribute): + @staticmethod + def get(value: str, context: Context | None = None) -> FlatSymbolRefAttr: + """ + Gets a uniqued FlatSymbolRef attribute + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + @property + def static_typeid(self) -> TypeID: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + @property + def value(self) -> str: + """ + Returns the value of the FlatSymbolRef attribute as a string + """ + +class Float4E2M1FNType(FloatType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> Float4E2M1FNType: + """ + Create a float4_e2m1fn type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + +class Float6E2M3FNType(FloatType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> Float6E2M3FNType: + """ + Create a float6_e2m3fn type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + +class Float6E3M2FNType(FloatType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> Float6E3M2FNType: + """ + Create a float6_e3m2fn type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + +class Float8E3M4Type(FloatType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> Float8E3M4Type: + """ + Create a float8_e3m4 type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + +class Float8E4M3B11FNUZType(FloatType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> Float8E4M3B11FNUZType: + """ + Create a float8_e4m3b11fnuz type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + +class Float8E4M3FNType(FloatType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> Float8E4M3FNType: + """ + Create a float8_e4m3fn type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + +class Float8E4M3FNUZType(FloatType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> Float8E4M3FNUZType: + """ + Create a float8_e4m3fnuz type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + +class Float8E4M3Type(FloatType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> Float8E4M3Type: + """ + Create a float8_e4m3 type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + +class Float8E5M2FNUZType(FloatType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> Float8E5M2FNUZType: + """ + Create a float8_e5m2fnuz type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + +class Float8E5M2Type(FloatType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> Float8E5M2Type: + """ + Create a float8_e5m2 type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + +class Float8E8M0FNUType(FloatType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> Float8E8M0FNUType: + """ + Create a float8_e8m0fnu type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + +class FloatAttr(Attribute): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(type: Type, value: float, loc: Location | None = None) -> FloatAttr: + """ + Gets an uniqued float point attribute associated to a type + """ + @staticmethod + def get_f32(value: float, context: Context | None = None) -> FloatAttr: + """ + Gets an uniqued float point attribute associated to a f32 type + """ + @staticmethod + def get_f64(value: float, context: Context | None = None) -> FloatAttr: + """ + Gets an uniqued float point attribute associated to a f64 type + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __float__(self: Attribute) -> float: + """ + Converts the value of the float attribute to a Python float + """ + def __init__(self, cast_from_attr: Attribute) -> None: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + @property + def value(self) -> float: + """ + Returns the value of the float attribute + """ + +class FloatTF32Type(FloatType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> FloatTF32Type: + """ + Create a tf32 type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + +class FunctionType(Type): + static_typeid: ClassVar[TypeID] + @staticmethod + def get( + inputs: list[Type], results: list[Type], context: Context | None = None + ) -> FunctionType: + """ + Gets a FunctionType from a List of input and result types + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def inputs(self) -> list: + """ + Returns the List of input types in the FunctionType. + """ + @property + def results(self) -> list: + """ + Returns the List of result types in the FunctionType. + """ + @property + def typeid(self) -> TypeID: ... + +class IndexType(Type): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> IndexType: + """ + Create a index type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + +class InferShapedTypeOpInterface: + def __init__(self, object: object, context: Context | None = None) -> None: + """ + Creates an interface from a given operation/opview object or from a + subclass of OpView. Raises ValueError if the operation does not implement the + interface. + """ + def inferReturnTypeComponents( + self, + operands: list | None = None, + attributes: Attribute | None = None, + properties=None, + regions: list[Region] | None = None, + context: Context | None = None, + loc: Location | None = None, + ) -> list[ShapedTypeComponents]: + """ + Given the arguments required to build an operation, attempts to infer + its return shaped type components. Raises ValueError on failure. + """ + @property + def operation(self) -> Operation: + """ + Returns an Operation for which the interface was constructed. + """ + @property + def opview(self) -> OpView: + """ + Returns an OpView subclass _instance_ for which the interface was + constructed + """ + +class InferTypeOpInterface: + def __init__(self, object: object, context: Context | None = None) -> None: + """ + Creates an interface from a given operation/opview object or from a + subclass of OpView. Raises ValueError if the operation does not implement the + interface. + """ + def inferReturnTypes( + self, + operands: list | None = None, + attributes: Attribute | None = None, + properties=None, + regions: list[Region] | None = None, + context: Context | None = None, + loc: Location | None = None, + ) -> list[Type]: + """ + Given the arguments required to build an operation, attempts to infer + its return types. Raises ValueError on failure. + """ + @property + def operation(self) -> Operation: + """ + Returns an Operation for which the interface was constructed. + """ + @property + def opview(self) -> OpView: + """ + Returns an OpView subclass _instance_ for which the interface was + constructed + """ + +class InsertionPoint: + current: ClassVar[InsertionPoint] = ... # read-only + @staticmethod + def at_block_begin(block: Block) -> InsertionPoint: + """ + Inserts at the beginning of the block. + """ + @staticmethod + def at_block_terminator(block: Block) -> InsertionPoint: + """ + Inserts before the block terminator. + """ + def __enter__(self) -> InsertionPoint: ... + def __exit__(self, arg0: Any, arg1: Any, arg2: Any) -> None: ... + @overload + def __init__(self, block: Block) -> None: + """ + Inserts after the last operation but still inside the block. + """ + @overload + def __init__(self, beforeOperation: _OperationBase) -> None: + """ + Inserts before a referenced operation. + """ + def insert(self, operation: _OperationBase) -> None: + """ + Inserts an operation. + """ + @property + def block(self) -> Block: + """ + Returns the block that this InsertionPoint points to. + """ + @property + def ref_operation(self) -> _OperationBase | None: + """ + The reference operation before which new operations are inserted, or None if the insertion point is at the end of the block + """ + +class IntegerAttr(Attribute): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(type: Type, value: int) -> IntegerAttr: + """ + Gets an uniqued integer attribute associated to a type + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __int__(self) -> int: + """ + Converts the value of the integer attribute to a Python int + """ + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + @property + def value(self) -> int: + """ + Returns the value of the integer attribute + """ + +class IntegerSet: + @staticmethod + def get( + num_dims: int, + num_symbols: int, + exprs: list, + eq_flags: list[bool], + context: Context | None = None, + ) -> IntegerSet: ... + @staticmethod + def get_empty( + num_dims: int, num_symbols: int, context: Context | None = None + ) -> IntegerSet: ... + def _CAPICreate(self) -> IntegerSet: ... + @overload + def __eq__(self, arg0: IntegerSet) -> bool: ... + @overload + def __eq__(self, arg0: object) -> bool: ... + def __hash__(self) -> int: ... + def dump(self) -> None: + """ + Dumps a debug representation of the object to stderr. + """ + def get_replaced( + self, + dim_exprs: list, + symbol_exprs: list, + num_result_dims: int, + num_result_symbols: int, + ) -> IntegerSet: ... + @property + def _CAPIPtr(self) -> object: ... + @property + def constraints(self) -> IntegerSetConstraintList: ... + @property + def context(self) -> Context: ... + @property + def is_canonical_empty(self) -> bool: ... + @property + def n_dims(self) -> int: ... + @property + def n_equalities(self) -> int: ... + @property + def n_inequalities(self) -> int: ... + @property + def n_inputs(self) -> int: ... + @property + def n_symbols(self) -> int: ... + +class IntegerSetAttr(Attribute): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(integer_set) -> IntegerSetAttr: + """ + Gets an attribute wrapping an IntegerSet. + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class IntegerSetConstraint: + def __init__(self, *args, **kwargs) -> None: ... + @property + def expr(self) -> AffineExpr: ... + @property + def is_eq(self) -> bool: ... + +class IntegerSetConstraintList: + def __init__(self, *args, **kwargs) -> None: ... + def __add__(self, arg0: IntegerSetConstraintList) -> list[IntegerSetConstraint]: ... + @overload + def __getitem__(self, arg0: int) -> IntegerSetConstraint: ... + @overload + def __getitem__(self, arg0: slice) -> IntegerSetConstraintList: ... + def __len__(self) -> int: ... + +class IntegerType(Type): + static_typeid: ClassVar[TypeID] + @staticmethod + def get_signed(width: int, context: Context | None = None) -> IntegerType: + """ + Create a signed integer type + """ + @staticmethod + def get_signless(width: int, context: Context | None = None) -> IntegerType: + """ + Create a signless integer type + """ + @staticmethod + def get_unsigned(width: int, context: Context | None = None) -> IntegerType: + """ + Create an unsigned integer type + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def is_signed(self) -> bool: + """ + Returns whether this is a signed integer + """ + @property + def is_signless(self) -> bool: + """ + Returns whether this is a signless integer + """ + @property + def is_unsigned(self) -> bool: + """ + Returns whether this is an unsigned integer + """ + @property + def typeid(self) -> TypeID: ... + @property + def width(self) -> int: + """ + Returns the width of the integer type + """ + +class Location: + current: ClassVar[Location] = ... # read-only + __hash__: ClassVar[None] = None + @staticmethod + def callsite( + callee: Location, frames: Sequence[Location], context: Context | None = None + ) -> Location: + """ + Gets a Location representing a caller and callsite + """ + @staticmethod + def file( + filename: str, line: int, col: int, context: Context | None = None + ) -> Location: + """ + Gets a Location representing a file, line and column + """ + @staticmethod + def from_attr(attribute: Attribute, context: Context | None = None) -> Location: + """ + Gets a Location from a LocationAttr + """ + @staticmethod + def fused( + locations: Sequence[Location], + metadata: Attribute | None = None, + context: Context | None = None, + ) -> Location: + """ + Gets a Location representing a fused location with optional metadata + """ + @staticmethod + def name( + name: str, + childLoc: Location | None = None, + context: Context | None = None, + ) -> Location: + """ + Gets a Location representing a named location with optional child location + """ + @staticmethod + def unknown(context: Context | None = None) -> Location: + """ + Gets a Location representing an unknown location + """ + def _CAPICreate(self) -> Location: ... + def __enter__(self) -> Location: ... + @overload + def __eq__(self, arg0: Location) -> bool: ... + @overload + def __eq__(self, arg0: Location) -> bool: ... + def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ... + def emit_error(self, message: str) -> None: + """ + Emits an error at this location + """ + @property + def _CAPIPtr(self) -> object: ... + @property + def attr(self) -> Attribute: + """ + Get the underlying LocationAttr + """ + @property + def context(self) -> Context: + """ + Context that owns the Location + """ + +class MemRefType(ShapedType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get( + shape: list[int], + element_type: Type, + layout: Attribute = None, + memory_space: Attribute = None, + loc: Location | None = None, + ) -> MemRefType: + """ + Create a memref type + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def affine_map(self) -> AffineMap: + """ + The layout of the MemRef type as an affine map. + """ + @property + def layout(self) -> Attribute: + """ + The layout of the MemRef type. + """ + @property + def memory_space(self) -> Attribute | None: + """ + Returns the memory space of the given MemRef type. + """ + @property + def typeid(self) -> TypeID: ... + def get_strides_and_offset(self) -> tuple[list[int], int]: + """ + The strides and offset of the MemRef type. + """ + +class Module: + @staticmethod + def create(loc: Location | None = None) -> Module: + """ + Creates an empty module + """ + @staticmethod + def parse(asm: str | bytes, context: Context | None = None) -> Module: + """ + Parses a module's assembly format from a string. + + Returns a new MlirModule or raises an MLIRError if the parsing fails. + + See also: https://mlir.llvm.org/docs/LangRef/ + """ + @staticmethod + def parseFile(path: str, context: Context | None = None) -> Module: + """ + Parses a module's assembly format from file. + + Returns a new MlirModule or raises an MLIRError if the parsing fails. + + See also: https://mlir.llvm.org/docs/LangRef/ + """ + def _CAPICreate(self) -> Any: ... + def __str__(self) -> str: + """ + Gets the assembly form of the operation with default options. + + If more advanced control over the assembly formatting or I/O options is needed, + use the dedicated print or get_asm method, which supports keyword arguments to + customize behavior. + """ + def dump(self) -> None: + """ + Dumps a debug representation of the object to stderr. + """ + @property + def _CAPIPtr(self) -> object: ... + @property + def body(self) -> Block: + """ + Return the block for this module + """ + @property + def context(self) -> Context: + """ + Context that created the Module + """ + @property + def operation(self) -> Operation: + """ + Accesses the module as an operation + """ + +class MLIRError(Exception): + def __init__( + self, message: str, error_diagnostics: list[DiagnosticInfo] + ) -> None: ... + +class NamedAttribute: + @property + def attr(self) -> Attribute: + """ + The underlying generic attribute of the NamedAttribute binding + """ + @property + def name(self) -> str: + """ + The name of the NamedAttribute binding + """ + +class NoneType(Type): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> NoneType: + """ + Create a none type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + +class OpAttributeMap: + def __contains__(self, arg0: str) -> bool: ... + def __delitem__(self, arg0: str) -> None: ... + @overload + def __getitem__(self, arg0: str) -> Attribute: ... + @overload + def __getitem__(self, arg0: int) -> NamedAttribute: ... + def __len__(self) -> int: ... + def __setitem__(self, arg0: str, arg1: Attribute) -> None: ... + +class OpOperand: + @property + def operand_number(self) -> int: ... + @property + def owner(self) -> _OperationBase: ... + +class OpOperandIterator: + def __iter__(self) -> OpOperandIterator: ... + def __next__(self) -> OpOperand: ... + +class OpOperandList: + def __add__(self, arg0: OpOperandList) -> list[Value]: ... + @overload + def __getitem__(self, arg0: int) -> Value: ... + @overload + def __getitem__(self, arg0: slice) -> OpOperandList: ... + def __len__(self) -> int: ... + def __setitem__(self, arg0: int, arg1: Value) -> None: ... + +class OpResult(Value): + @staticmethod + def isinstance(other_value: Value) -> bool: ... + def __init__(self, value: Value) -> None: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... + @property + def owner(self) -> _OperationBase: ... + @property + def result_number(self) -> int: ... + +class OpResultList: + def __add__(self, arg0: OpResultList) -> list[OpResult]: ... + @overload + def __getitem__(self, arg0: int) -> OpResult: ... + @overload + def __getitem__(self, arg0: slice) -> OpResultList: ... + def __len__(self) -> int: ... + @property + def owner(self) -> _OperationBase: ... + @property + def types(self) -> list[Type]: ... + +class OpSuccessors: + def __add__(self, arg0: OpSuccessors) -> list[Block]: ... + @overload + def __getitem__(self, arg0: int) -> Block: ... + @overload + def __getitem__(self, arg0: slice) -> OpSuccessors: ... + def __setitem__(self, arg0: int, arg1: Block) -> None: ... + def __len__(self) -> int: ... + +class OpView(_OperationBase): + _ODS_OPERAND_SEGMENTS: ClassVar[None] = ... + _ODS_REGIONS: ClassVar[tuple] = ... + _ODS_RESULT_SEGMENTS: ClassVar[None] = ... + def __init__(self, operation: _OperationBase) -> None: ... + @classmethod + def build_generic( + cls: type[_TOperation], + results: Sequence[Type] | None = None, + operands: Sequence[Value] | None = None, + attributes: dict[str, Attribute] | None = None, + successors: Sequence[Block] | None = None, + regions: int | None = None, + loc: Location | None = None, + ip: InsertionPoint | None = None, + ) -> _TOperation: + """ + Builds a specific, generated OpView based on class level attributes. + """ + @classmethod + def parse( + cls: type[_TOperation], + source: str | bytes, + *, + source_name: str = "", + context: Context | None = None, + ) -> _TOperation: + """ + Parses a specific, generated OpView based on class level attributes + """ + def __init__(self, operation: _OperationBase) -> None: ... + @property + def operation(self) -> _OperationBase: ... + @property + def opview(self) -> OpView: ... + @property + def successors(self) -> OpSuccessors: + """ + Returns the List of Operation successors. + """ + +class OpaqueAttr(Attribute): + static_typeid: ClassVar[TypeID] + @staticmethod + def get( + dialect_namespace: str, + buffer: Buffer, + type: Type, + context: Context | None = None, + ) -> OpaqueAttr: + """ + Gets an Opaque attribute. + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + @property + def data(self) -> bytes: + """ + Returns the data for the Opaqued attributes as `bytes` + """ + @property + def dialect_namespace(self) -> str: + """ + Returns the dialect namespace for the Opaque attribute as a string + """ + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class OpaqueType(Type): + static_typeid: ClassVar[TypeID] + @staticmethod + def get( + dialect_namespace: str, buffer: str, context: Context | None = None + ) -> OpaqueType: + """ + Create an unregistered (opaque) dialect type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def data(self) -> str: + """ + Returns the data for the Opaque type as a string. + """ + @property + def dialect_namespace(self) -> str: + """ + Returns the dialect namespace for the Opaque type as a string. + """ + @property + def typeid(self) -> TypeID: ... + +class Operation(_OperationBase): + def _CAPICreate(self) -> object: ... + @staticmethod + def create( + name: str, + results: Sequence[Type] | None = None, + operands: Sequence[Value] | None = None, + attributes: dict[str, Attribute] | None = None, + successors: Sequence[Block] | None = None, + regions: int = 0, + loc: Location | None = None, + ip: InsertionPoint | None = None, + infer_type: bool = False, + ) -> Operation: + """ + Creates a new operation. + + Args: + name: Operation name (e.g. "dialect.operation"). + results: Sequence of Type representing op result types. + attributes: Dict of str:Attribute. + successors: List of Block for the operation's successors. + regions: Number of regions to create. + loc: A Location object (defaults to resolve from context manager). + ip: An InsertionPoint (defaults to resolve from context manager or set to + False to disable insertion, even with an insertion point set in the + context manager). + infer_type: Whether to infer result types. + Returns: + A new "detached" Operation object. Detached operations can be added + to blocks, which causes them to become "attached." + """ + @staticmethod + def parse( + source: str | bytes, *, source_name: str = "", context: Context | None = None + ) -> Operation: + """ + Parses an operation. Supports both text assembly format and binary bytecode format. + """ + def _CAPICreate(self) -> object: ... + @property + def _CAPIPtr(self) -> object: ... + @property + def operation(self) -> Operation: ... + @property + def opview(self) -> OpView: ... + @property + def successors(self) -> OpSuccessors: + """ + Returns the List of Operation successors. + """ + +class OperationIterator: + def __iter__(self) -> OperationIterator: ... + def __next__(self) -> OpView: ... + +class OperationList: + def __getitem__(self, arg0: int) -> OpView: ... + def __iter__(self) -> OperationIterator: ... + def __len__(self) -> int: ... + +class RankedTensorType(ShapedType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get( + shape: list[int], + element_type: Type, + encoding: Attribute | None = None, + loc: Location | None = None, + ) -> RankedTensorType: + """ + Create a ranked tensor type + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def encoding(self) -> Attribute | None: ... + @property + def typeid(self) -> TypeID: ... + +class Region: + __hash__: ClassVar[None] = None + @overload + def __eq__(self, arg0: Region) -> bool: ... + @overload + def __eq__(self, arg0: object) -> bool: ... + def __iter__(self) -> BlockIterator: + """ + Iterates over blocks in the region. + """ + @property + def blocks(self) -> BlockList: + """ + Returns a forward-optimized sequence of blocks. + """ + @property + def owner(self) -> OpView: + """ + Returns the operation owning this region. + """ + +class RegionIterator: + def __iter__(self) -> RegionIterator: ... + def __next__(self) -> Region: ... + +class RegionSequence: + @overload + def __getitem__(self, arg0: int) -> Region: ... + @overload + def __getitem__(self, arg0: slice) -> Sequence[Region]: ... + def __iter__(self) -> RegionIterator: ... + def __len__(self) -> int: ... + +class ShapedType(Type): + @staticmethod + def get_dynamic_size() -> int: + """ + Returns the value used to indicate dynamic dimensions in shaped types. + """ + @staticmethod + def get_dynamic_stride_or_offset() -> int: + """ + Returns the value used to indicate dynamic strides or offsets in shaped types. + """ + @staticmethod + def is_dynamic_size(dim_size: int) -> bool: + """ + Returns whether the given dimension size indicates a dynamic dimension. + """ + @staticmethod + def is_static_size(dim_size: int) -> bool: + """ + Returns whether the given dimension size indicates a static dimension. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + def get_dim_size(self, dim: int) -> int: + """ + Returns the dim-th dimension of the given ranked shaped type. + """ + def is_dynamic_dim(self, dim: int) -> bool: + """ + Returns whether the dim-th dimension of the given shaped type is dynamic. + """ + def is_static_dim(self, dim: int) -> bool: + """ + Returns whether the dim-th dimension of the given shaped type is static. + """ + def is_dynamic_stride_or_offset(self, dim_size: int) -> bool: + """ + Returns whether the given value is used as a placeholder for dynamic strides and offsets in shaped types. + """ + def is_static_stride_or_offset(self, dim_size: int) -> bool: + """ + Returns whether the given shaped type stride or offset value is statically-sized. + """ + @property + def element_type(self) -> Type: + """ + Returns the element type of the shaped type. + """ + @property + def has_rank(self) -> bool: + """ + Returns whether the given shaped type is ranked. + """ + @property + def has_static_shape(self) -> bool: + """ + Returns whether the given shaped type has a static shape. + """ + @property + def rank(self) -> int: + """ + Returns the rank of the given ranked shaped type. + """ + @property + def shape(self) -> list[int]: + """ + Returns the shape of the ranked shaped type as a List of integers. + """ + @property + def static_typeid(self) -> TypeID: ... + @property + def typeid(self) -> TypeID: ... + +class ShapedTypeComponents: + @staticmethod + @overload + def get(element_type: Type) -> ShapedTypeComponents: + """ + Create an shaped type components object with only the element type. + """ + @staticmethod + @overload + def get(shape: list, element_type: Type) -> ShapedTypeComponents: + """ + Create a ranked shaped type components object. + """ + @staticmethod + @overload + def get( + shape: list, element_type: Type, attribute: Attribute + ) -> ShapedTypeComponents: + """ + Create a ranked shaped type components object with attribute. + """ + @property + def element_type(self) -> Type: + """ + Returns the element type of the shaped type components. + """ + @property + def has_rank(self) -> bool: + """ + Returns whether the given shaped type component is ranked. + """ + @property + def rank(self) -> int: + """ + Returns the rank of the given ranked shaped type components. If the shaped type components does not have a rank, None is returned. + """ + @property + def shape(self) -> list[int]: + """ + Returns the shape of the ranked shaped type components as a List of integers. Returns none if the shaped type component does not have a rank. + """ + +class StridedLayoutAttr(Attribute): + static_typeid: ClassVar[TypeID] + @staticmethod + def get( + offset: int, strides: list[int], context: Context | None = None + ) -> StridedLayoutAttr: + """ + Gets a strided layout attribute. + """ + @staticmethod + def get_fully_dynamic( + rank: int, context: Context | None = None + ) -> StridedLayoutAttr: + """ + Gets a strided layout attribute with dynamic offset and strides of a given rank. + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + @property + def offset(self) -> int: + """ + Returns the value of the float point attribute + """ + @property + def strides(self) -> list[int]: + """ + Returns the value of the float point attribute + """ + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class StringAttr(Attribute): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(value: str | bytes, context: Context | None = None) -> StringAttr: + """ + Gets a uniqued string attribute + """ + @staticmethod + def get_typed(type: Type, value: str) -> StringAttr: + """ + Gets a uniqued string attribute associated to a type + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + @property + def value(self) -> str: + """ + Returns the value of the string attribute + """ + @property + def value_bytes(self) -> bytes: + """ + Returns the value of the string attribute as `bytes` + """ + +class SymbolRefAttr(Attribute): + @staticmethod + def get(symbols: list[str], context: Context | None = None) -> Attribute: + """ + Gets a uniqued SymbolRef attribute from a List of symbol names + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + @property + def static_typeid(self) -> TypeID: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + @property + def value(self) -> list[str]: + """ + Returns the value of the SymbolRef attribute as a List[str] + """ + +class SymbolTable: + @staticmethod + def get_symbol_name(symbol: _OperationBase) -> Attribute: ... + @staticmethod + def get_visibility(symbol: _OperationBase) -> Attribute: ... + @staticmethod + def replace_all_symbol_uses( + old_symbol: str, new_symbol: str, from_op: _OperationBase + ) -> None: ... + @staticmethod + def set_symbol_name(symbol: _OperationBase, name: str) -> None: ... + @staticmethod + def set_visibility(symbol: _OperationBase, visibility: str) -> None: ... + @staticmethod + def walk_symbol_tables( + from_op: _OperationBase, + all_sym_uses_visible: bool, + callback: Callable[[_OperationBase, bool], None], + ) -> None: ... + def __contains__(self, arg0: str) -> bool: ... + def __delitem__(self, arg0: str) -> None: ... + def __getitem__(self, arg0: str) -> OpView: ... + def __init__(self, arg0: _OperationBase) -> None: ... + def erase(self, operation: _OperationBase) -> None: ... + def insert(self, operation: _OperationBase) -> Attribute: ... + +class TupleType(Type): + static_typeid: ClassVar[TypeID] + @staticmethod + def get_tuple(elements: list[Type], context: Context | None = None) -> TupleType: + """ + Create a Tuple type + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + def get_type(self, pos: int) -> Type: + """ + Returns the pos-th type in the Tuple type. + """ + @property + def num_types(self) -> int: + """ + Returns the number of types contained in a Tuple. + """ + @property + def typeid(self) -> TypeID: ... + +class TypeAttr(Attribute): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(value: Type, context: Context | None = None) -> TypeAttr: + """ + Gets a uniqued Type attribute + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + @property + def value(self) -> Type: ... + +class TypeID: + def _CAPICreate(self) -> TypeID: ... + @overload + def __eq__(self, arg0: TypeID) -> bool: ... + @overload + def __eq__(self, arg0: Any) -> bool: ... + def __hash__(self) -> int: ... + @property + def _CAPIPtr(self) -> object: ... + +class UnitAttr(Attribute): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> UnitAttr: + """ + Create a Unit attribute. + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class UnrankedMemRefType(ShapedType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get( + element_type: Type, memory_space: Attribute, loc: Location | None = None + ) -> UnrankedMemRefType: + """ + Create a unranked memref type + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def memory_space(self) -> Attribute | None: + """ + Returns the memory space of the given Unranked MemRef type. + """ + @property + def typeid(self) -> TypeID: ... + +class UnrankedTensorType(ShapedType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(element_type: Type, loc: Location | None = None) -> UnrankedTensorType: + """ + Create a unranked tensor type + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + +class VectorType(ShapedType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get( + shape: list[int], + element_type: Type, + *, + scalable: list | None = None, + scalable_dims: list[int] | None = None, + loc: Location | None = None, + ) -> VectorType: + """ + Create a vector type + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def scalable(self) -> bool: ... + @property + def scalable_dims(self) -> list[bool]: ... + @property + def typeid(self) -> TypeID: ... + +class _GlobalDebug: + flag: ClassVar[bool] = False diff --git a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi new file mode 100644 index 0000000..1010dad --- /dev/null +++ b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi @@ -0,0 +1,36 @@ +# Originally imported via: +# stubgen {...} -m mlir._mlir_libs._mlir.passmanager +# Local modifications: +# * Relative imports for cross-module references. +# * Add __all__ + + +from . import ir as _ir + +__all__ = [ + "PassManager", +] + +class PassManager: + def __init__(self, context: _ir.Context | None = None) -> None: ... + def _CAPICreate(self) -> object: ... + def _testing_release(self) -> None: ... + def enable_ir_printing( + self, + print_before_all: bool = False, + print_after_all: bool = True, + print_module_scope: bool = False, + print_after_change: bool = False, + print_after_failure: bool = False, + large_elements_limit: int | None = None, + large_resource_limit: int | None = None, + enable_debug_info: bool = False, + print_generic_op_form: bool = False, + tree_printing_dir_path: str | None = None, + ) -> None: ... + def enable_verifier(self, enable: bool) -> None: ... + @staticmethod + def parse(pipeline: str, context: _ir.Context | None = None) -> PassManager: ... + def run(self, module: _ir._OperationBase) -> None: ... + @property + def _CAPIPtr(self) -> object: ... diff --git a/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi b/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi new file mode 100644 index 0000000..4b82c78 --- /dev/null +++ b/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi @@ -0,0 +1,24 @@ +# Originally imported via: +# stubgen {...} -m mlir._mlir_libs._mlirExecutionEngine +# Local modifications: +# * Relative imports for cross-module references. +# * Add __all__ + +from collections.abc import Sequence + +from ._mlir import ir as _ir + +__all__ = [ + "ExecutionEngine", +] + +class ExecutionEngine: + def __init__(self, module: _ir.Module, opt_level: int = 2, shared_libs: Sequence[str] = ...) -> None: ... + def _CAPICreate(self) -> object: ... + def _testing_release(self) -> None: ... + def dump_to_object_file(self, file_name: str) -> None: ... + def raw_lookup(self, func_name: str) -> int: ... + def raw_register_runtime(self, name: str, callback: object) -> None: ... + def init() -> None: ... + @property + def _CAPIPtr(self) -> object: ... diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index 6f37266..7ddc70a 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -2,9 +2,18 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +from __future__ import annotations + +from collections.abc import Iterable +from contextlib import contextmanager + from ._mlir_libs._mlir.ir import * from ._mlir_libs._mlir.ir import _GlobalDebug -from ._mlir_libs._mlir import register_type_caster, register_value_caster +from ._mlir_libs._mlir import ( + register_type_caster, + register_value_caster, + globals, +) from ._mlir_libs import ( get_dialect_registry, append_load_on_create_dialect, @@ -12,6 +21,30 @@ from ._mlir_libs import ( ) +@contextmanager +def loc_tracebacks(*, max_depth: int | None = None) -> Iterable[None]: + """Enables automatic traceback-based locations for MLIR operations. + + Operations created within this context will have their location + automatically set based on the Python call stack. + + Args: + max_depth: Maximum number of frames to include in the location. + If None, the default limit is used. + """ + old_enabled = globals.loc_tracebacks_enabled() + old_limit = globals.loc_tracebacks_frame_limit() + try: + globals.set_loc_tracebacks_frame_limit(max_depth) + if not old_enabled: + globals.set_loc_tracebacks_enabled(True) + yield + finally: + if not old_enabled: + globals.set_loc_tracebacks_enabled(False) + globals.set_loc_tracebacks_frame_limit(old_limit) + + # Convenience decorator for registering user-friendly Attribute builders. def register_attribute_builder(kind, replace=False): def decorator_builder(func): diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt index abe0925..1a0075e 100644 --- a/mlir/python/requirements.txt +++ b/mlir/python/requirements.txt @@ -1,7 +1,6 @@ -nanobind>=2.9, <3.0 +nanobind>=2.4, <3.0 numpy>=1.19.5, <=2.1.2 pybind11>=2.10.0, <=2.13.6 PyYAML>=5.4.0, <=6.0.1 ml_dtypes>=0.1.0, <=0.6.0; python_version<"3.13" # provides several NumPy dtype extensions, including the bf16 -ml_dtypes>=0.5.0, <=0.6.0; python_version>="3.13" -typing_extensions>=4.12.2 +ml_dtypes>=0.5.0, <=0.6.0; python_version>="3.13"
\ No newline at end of file diff --git a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir index a722acb..d362bb6 100644 --- a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir +++ b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir @@ -6,7 +6,7 @@ func.func @parallel(%arg0: index, %arg1: index, %arg2: index, // CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32 // CHECK: omp.parallel num_threads(%[[FOUR]] : i32) { // CHECK: omp.wsloop { - // CHECK: omp.loop_nest (%[[LVAR1:.*]], %[[LVAR2:.*]]) : index = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) { + // CHECK: omp.loop_nest (%[[LVAR1:.*]], %[[LVAR2:.*]]) : index = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) collapse(2) { // CHECK: memref.alloca_scope scf.parallel (%i, %j) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) { // CHECK: "test.payload"(%[[LVAR1]], %[[LVAR2]]) : (index, index) -> () diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 07d3351..2d33888 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1774,3 +1774,45 @@ func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> v %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32> return %0 : vector<2x1x2xf32> } + +// ----- + +//===----------------------------------------------------------------------===// +// vector.to_elements +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @to_elements_1d( +// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32> +// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK: %[[V0:.+]] = llvm.extractelement %[[ARG0]][%[[C0]] : i64] : vector<2xf32> +// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64 +// CHECK: %[[V1:.+]] = llvm.extractelement %[[ARG0]][%[[C1]] : i64] : vector<2xf32> +// CHECK: return %[[V0]], %[[V1]] +func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) { + %0:2 = vector.to_elements %arg0 : vector<2xf32> + return %0#0, %0#1 : f32, f32 +} + +// ----- + +// NOTE: We unroll multi-dimensional to_elements ops with pattern +// `UnrollToElements` and then convert the 1-D to_elements ops to llvm. + +// CHECK-LABEL: func @to_elements_2d( +// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32> +// CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<2x2xf32> to !llvm.array<2 x vector<2xf32>> +// CHECK: %[[V0:.+]] = llvm.extractvalue %[[CAST]][0] : !llvm.array<2 x vector<2xf32>> +// CHECK: %[[V1:.+]] = llvm.extractvalue %[[CAST]][1] : !llvm.array<2 x vector<2xf32>> +// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK: %[[R0:.+]] = llvm.extractelement %[[V0]][%[[C0]] : i64] : vector<2xf32> +// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64 +// CHECK: %[[R1:.+]] = llvm.extractelement %[[V0]][%[[C1]] : i64] : vector<2xf32> +// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK: %[[R2:.+]] = llvm.extractelement %[[V1]][%[[C0]] : i64] : vector<2xf32> +// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64 +// CHECK: %[[R3:.+]] = llvm.extractelement %[[V1]][%[[C1]] : i64] : vector<2xf32> +// CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]] +func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) { + %0:4 = vector.to_elements %arg0 : vector<2x2xf32> + return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32 +} diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir index ed664a7..d6e36fa 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir @@ -43,38 +43,6 @@ gpu.module @create_nd_tdesc { // CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32> // CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32> %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32> - - // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32> - // CHECK: %[[INTPTR_2:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index - // CHECK: %[[OFFSET_W3:.*]] = arith.index_cast %[[ARG7]] : index to i32 - // CHECK: %[[OFFSET_H3:.*]] = arith.index_cast %[[ARG6]] : index to i32 - // CHECK: %[[C32_I64_6:.*]] = arith.constant 32 : i64 - // CHECK: %[[SHAPE_W3:.*]] = arith.trunci %[[C32_I64_6]] : i64 to i32 - // CHECK: %[[C16_I64_7:.*]] = arith.constant 16 : i64 - // CHECK: %[[SHAPE_H3:.*]] = arith.trunci %[[C16_I64_7]] : i64 to i32 - // CHECK: %[[BASE_ADDR3:.*]] = arith.index_castui %[[INTPTR_2]] : index to i64 - // CHECK: %[[VAR26:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64> - // CHECK: %[[VAR27:.*]] = vector.insert %[[BASE_ADDR3]], %[[VAR26]] [0] : i64 into vector<4xi64> - // CHECK: %[[VAR28:.*]] = vector.bitcast %[[VAR27]] : vector<4xi64> to vector<8xi32> - // CHECK: %[[VAR29:.*]] = vector.insert %[[SHAPE_W3]], %[[VAR28]] [2] : i32 into vector<8xi32> - // CHECK: %[[VAR30:.*]] = vector.insert %[[SHAPE_H3]], %[[VAR29]] [3] : i32 into vector<8xi32> - // CHECK: %[[VAR31:.*]] = vector.insert %[[OFFSET_W3]], %[[VAR30]] [4] : i32 into vector<8xi32> - // CHECK: %[[VAR32:.*]] = vector.insert %[[OFFSET_H3]], %[[VAR31]] [5] : i32 into vector<8xi32> - %src_tdesc2 = xegpu.create_nd_tdesc %srcce[%offset1, %offset2] : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32> - - // CHECK: %[[C8:.*]] = arith.constant 8 : index - %c8 = arith.constant 8 : index - // CHECK: %[[C16:.*]] = arith.constant 16 : index - %c16 = arith.constant 16 : index - // CHECK: %[[VAR33:.*]] = arith.index_cast %[[C8]] : index to i32 - // CHECK: %[[OLD_OFFSET_H:.*]] = vector.extract %[[PAYLOAD]][5] : i32 from vector<8xi32> - // CHECK: %[[NEW_OFFSET_H:.*]] = arith.addi %[[OLD_OFFSET_H]], %[[VAR33]] : i32 - // CHECK: %[[NEW_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_H]], %[[PAYLOAD]] [5] : i32 into vector<8xi32> - // CHECK: %[[VAR37:.*]] = arith.index_cast %[[C16]] : index to i32 - // CHECK: %[[OLD_OFFSET_H:.*]] = vector.extract %[[NEW_PAYLOAD]][4] : i32 from vector<8xi32> - // CHECK: %[[NEW_OFFSET_H:.*]] = arith.addi %[[OLD_OFFSET_H]], %[[VAR37]] : i32 - // CHECK: %[[FINAL_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_H]], %[[NEW_PAYLOAD]] [4] : i32 into vector<8xi32> - %updated_tdesc = xegpu.update_nd_offset %src_tdesc, [%c8, %c16] : !xegpu.tensor_desc<8x16xf32> gpu.return } } diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir index 0f67dc2..0b150e9 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir @@ -1,239 +1,73 @@ // RUN: mlir-opt %s --split-input-file -convert-xegpu-to-xevm | FileCheck %s gpu.module @test { -// CHECK-LABEL: @load_gather_ui64_src_constant_offset -// CHECK-SAME: %[[ARG0:.*]]: ui64 -gpu.func @load_gather_ui64_src_constant_offset(%src: ui64) { - // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index - // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 - // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> - // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> - // CHECK: %[[VAR3:.*]] = arith.index_castui %[[VAR2]] : index to i64 - %0 = arith.constant dense<0> : vector<1xindex> - // CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<1xi1> - // CHECK: %[[VAR4:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1> - %1 = arith.constant dense<1>: vector<1xi1> - // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 - // CHECK: %[[VAR5:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64 - // CHECK: %[[VAR6:.*]] = arith.addi %[[VAR1]], %[[VAR5]] : i64 - %2 = xegpu.create_tdesc %src, %0 : ui64, vector<1xindex> - -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>> - // CHECK: %[[VAR7:.*]] = llvm.inttoptr %[[VAR6]] : i64 to !llvm.ptr<1> - // CHECK: %[[VAR8:.*]] = scf.if %[[VAR4]] -> (vector<2xf32>) { - // CHECK: %[[VAR9:.*]] = llvm.load %[[VAR7]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>} - // CHECK-SAME: : !llvm.ptr<1> -> vector<2xf32> - // CHECK: scf.yield %[[VAR9]] : vector<2xf32> - // CHECK: } else { - // CHECK: %[[CST_1:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32> - // CHECK: scf.yield %[[CST_1]] : vector<2xf32> - %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> - : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<1xi1> -> vector<2xf32> - gpu.return -} -} -// ----- - -gpu.module @test { -// CHECK-LABEL: @load_gather_memref_src_constant_offset -// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32> -gpu.func @load_gather_memref_src_constant_offset(%src: memref<256xf32>) { - // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index - // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64 - // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> - // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> - // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 - %0 = arith.constant dense<0> : vector<1xindex> - // CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<1xi1> - // CHECK: %[[VAR2:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1> - %1 = arith.constant dense<1>: vector<1xi1> - // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 - // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64 - // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64 - %2 = xegpu.create_tdesc %src, %0 : memref<256xf32>, vector<1xindex> - -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> - // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1> - // CHECK: %[[VAR7:.*]] = scf.if %[[VAR2]] -> (f32) { - // CHECK: %[[VAR8:.*]] = llvm.load %[[VAR6]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>} - // CHECK-SAME: : !llvm.ptr<1> -> vector<1xf32> - // CHECK: %[[VAR9:.*]] = vector.extract %[[VAR8]][0] : f32 from vector<1xf32> - // CHECK: scf.yield %[[VAR9]] : f32 - // CHECK: } else { - // CHECK: %[[CST_1:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32> - // CHECK: %[[VAR8:.*]] = vector.extract %[[CST_1:.*]][0] : f32 from vector<1xf32> - // CHECK: scf.yield %[[VAR8]] : f32 - %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> - : !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>, vector<1xi1> -> vector<1xf32> - gpu.return -} -} -// ----- - -gpu.module @test { -// CHECK-LABEL: @load_gather_memref_src_value_offset -// CHECK-SAME: %[[ARG0:.*]]: memref<256xf16>, %[[ARG1:.*]]: vector<1xindex> -gpu.func @load_gather_memref_src_value_offset(%src: memref<256xf16>, %offset: vector<1xindex>) { +// CHECK-LABEL: @load_gather_i64_src_value_offset +// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex> +gpu.func @load_gather_i64_src_value_offset(%src: i64, %offset: vector<1xindex>) { // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex> // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 - // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf16> -> index - // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64 // CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1> // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1> %1 = arith.constant dense<1>: vector<1xi1> // CHECK: %[[C2_I64:.*]] = arith.constant 2 : i64 - // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C2_I64]] : i64 - // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64 - %2 = xegpu.create_tdesc %src, %offset : memref<256xf16>, vector<1xindex> - -> !xegpu.tensor_desc<1x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>> - // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1> - // CHECK: %[[VAR7:.*]] = scf.if %[[VAR2]] -> (vector<8xf16>) { - // CHECK: %[[VAR8:.*]] = llvm.load %[[VAR6]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>} - // CHECK-SAME: : !llvm.ptr<1> -> vector<8xf16> - // CHECK: scf.yield %[[VAR8]] : vector<8xf16> - // CHECK: } else { - // CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<8xf16> - // CHECK: scf.yield %[[CST_0]] : vector<8xf16> - %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> - : !xegpu.tensor_desc<1x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<1xi1> -> vector<8xf16> - gpu.return -} -} -// ----- - -gpu.module @test { -// CHECK-LABEL: @store_scatter_ui64_src_constant_offset -// CHECK-SAME: %[[ARG0:.*]]: ui64 -gpu.func @store_scatter_ui64_src_constant_offset(%src: ui64) { - // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index - // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 - // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> - // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> - // CHECK: %[[VAR3:.*]] = arith.index_castui %[[VAR2]] : index to i64 - %0 = arith.constant dense<0> : vector<1xindex> - // CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<1xi1> - // CHECK: %[[VAR4:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1> - %1 = arith.constant dense<1>: vector<1xi1> - // CHECK: %[[CST_1:.*]] = arith.constant dense<2.900000e+00> : vector<2xf32> - %2 = arith.constant dense<2.9>: vector<2xf32> - // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 - // CHECK: %[[VAR5:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64 - // CHECK: %[[VAR6:.*]] = arith.addi %[[VAR1]], %[[VAR5]] : i64 - %3 = xegpu.create_tdesc %src, %0 : ui64, vector<1xindex> - -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>> - // CHECK: %[[VAR7:.*]] = llvm.inttoptr %[[VAR6]] : i64 to !llvm.ptr<1> - // CHECK: scf.if %[[VAR4]] { - // CHECK: llvm.store %[[CST_1]], %[[VAR7]] {cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>} - // CHECK-SAME: : vector<2xf32>, !llvm.ptr<1> - xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> - : vector<2xf32>, !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<1xi1> - gpu.return -} -} -// ----- - -gpu.module @test { -// CHECK-LABEL: @store_scatter_memref_src_constant_offset -// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32> -gpu.func @store_scatter_memref_src_constant_offset(%src: memref<256xf32>) { - // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index - // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64 - // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> - // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> - // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 - %0 = arith.constant dense<0> : vector<1xindex> - // CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<1xi1> - // CHECK: %[[VAR2:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1> - %1 = arith.constant dense<1>: vector<1xi1> - // CHECK: %[[CST_1:.*]] = arith.constant dense<2.900390e+00> : vector<2xf16> - %2 = arith.constant dense<2.9>: vector<2xf16> - // CHECK: %[[C2_I64:.*]] = arith.constant 2 : i64 - // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C2_I64]] : i64 - // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64 - %3 = xegpu.create_tdesc %src, %0 : memref<256xf32>, vector<1xindex> - -> !xegpu.tensor_desc<1x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>> - // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1> - // CHECK: scf.if %[[VAR2]] { - // CHECK: llvm.store %[[CST_1]], %[[VAR6]] {cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>} - // CHECK-SAME: : vector<2xf16>, !llvm.ptr<1> - xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> - : vector<2xf16>, !xegpu.tensor_desc<1x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<1xi1> + // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C2_I64]] : i64 + // CHECK: %[[VAR4:.*]] = arith.addi %[[ARG0]], %[[VAR3]] : i64 + // CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR4]] : i64 to !llvm.ptr<1> + // CHECK: %[[VAR6:.*]] = scf.if %[[VAR2]] -> (f16) { + // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR5]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>} : !llvm.ptr<1> -> vector<1xf16> + // CHECK: %[[VAR8:.*]] = vector.extract %[[VAR7]][0] : f16 from vector<1xf16> + // CHECK: scf.yield %[[VAR8]] : f16 + // CHECK: } else { + // CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<1xf16> + // CHECK: %[[VAR7:.*]] = vector.extract %[[CST_0]][0] : f16 from vector<1xf16> + // CHECK: scf.yield %[[VAR7]] : f16 + // CHECK: } + %3 = xegpu.load %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> + : i64, vector<1xindex>, vector<1xi1> -> vector<1xf16> gpu.return } } // ----- gpu.module @test { -// CHECK-LABEL: @store_scatter_memref_src_value_offset -// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>, %[[ARG1:.*]]: vector<1xindex> -gpu.func @store_scatter_memref_src_value_offset(%src: memref<256xf32>, %offset: vector<1xindex>) { +// CHECK-LABEL: @store_scatter_i64_src_value_offset +// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex> +gpu.func @store_scatter_i64_src_value_offset(%src: i64, %offset: vector<1xindex>) { // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex> // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 - // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index - // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64 // CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1> // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1> %1 = arith.constant dense<1>: vector<1xi1> // CHECK: %[[CST_0:.*]] = arith.constant dense<2.900000e+00> : vector<1xf32> - // CHECK: %[[VAR7:.*]] = vector.extract %[[CST_0]][0] : f32 from vector<1xf32> + // CHECK: %[[VAR3:.*]] = vector.extract %[[CST_0]][0] : f32 from vector<1xf32> %2 = arith.constant dense<2.9>: vector<1xf32> // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64 - // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64 - %3 = xegpu.create_tdesc %src, %offset : memref<256xf32>, vector<1xindex> - -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> + // CHECK: %[[VAR5:.*]] = arith.addi %[[ARG0]], %[[VAR4]] : i64 // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1> // CHECK: scf.if %[[VAR2]] { - // CHECK: llvm.store %[[VAR7]], %[[VAR6]] {cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>} - // CHECK-SAME: : f32, !llvm.ptr<1> - xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> - : vector<1xf32>, !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>, vector<1xi1> + // CHECK: llvm.store %[[VAR3]], %[[VAR6]] {cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>} : f32, !llvm.ptr<1> + // CHECK: } + xegpu.store %2, %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> + : vector<1xf32>, i64, vector<1xindex>, vector<1xi1> gpu.return } } // ----- gpu.module @test { -// CHECK-LABEL: @prefetch_ui64_src_constant_offset -// CHECK-SAME: %[[ARG0:.*]]: ui64 -gpu.func @prefetch_ui64_src_constant_offset(%src: ui64) { - // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index - // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 - // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> - // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> - // CHECK: %[[VAR3:.*]] = arith.index_castui %[[VAR2]] : index to i64 - %0 = arith.constant dense<0> : vector<1xindex> - // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 - // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64 - // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR1]], %[[VAR4]] : i64 - %1 = xegpu.create_tdesc %src, %0 : ui64, vector<1xindex> - -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>> - // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1> - // CHECK: xevm.prefetch %[[VAR6]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}> : (!llvm.ptr<1>) - xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> - : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>> - gpu.return -} -} -// ----- - -gpu.module @test { -// CHECK-LABEL: @prefetch_memref_src_constant_offset -// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32> -gpu.func @prefetch_memref_src_constant_offset(%src: memref<256xf32>) { - // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index - // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64 - // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> - // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> +// CHECK-LABEL: @prefetch_i64_src_value_offset +// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex> +gpu.func @prefetch_i64_src_value_offset(%src: i64, %offset: vector<1xindex>) { + // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex> // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 - %0 = arith.constant dense<0> : vector<1xindex> // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 - // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64 - // CHECK: %[[VAR4:.*]] = arith.addi %[[VAR2]], %[[VAR3]] : i64 - %1 = xegpu.create_tdesc %src, %0 : memref<256xf32>, vector<1xindex> - -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>> - // CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR4]] : i64 to !llvm.ptr<1> - // CHECK: xevm.prefetch %[[VAR5]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}> : (!llvm.ptr<1>) - xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> - : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>> + // CHECK: %[[VAR2:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64 + // CHECK: %[[VAR3:.*]] = arith.addi %[[ARG0]], %[[VAR2]] : i64 + // CHECK: %[[VAR4:.*]] = llvm.inttoptr %[[VAR3]] : i64 to !llvm.ptr<1> + // CHECK: xevm.prefetch %[[VAR4]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}> : (!llvm.ptr<1>) + xegpu.prefetch %src[%offset] <{offset_align_byte=4, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> + : i64, vector<1xindex> gpu.return } } @@ -250,12 +84,10 @@ gpu.func @prefetch_memref_src_value_offset(%src: memref<256xf32>, %offset: vecto // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64 // CHECK: %[[VAR4:.*]] = arith.addi %[[VAR2]], %[[VAR3]] : i64 - %1 = xegpu.create_tdesc %src, %offset : memref<256xf32>, vector<1xindex> - -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>> // CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR4]] : i64 to !llvm.ptr<1> // CHECK: xevm.prefetch %[[VAR5]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}> : (!llvm.ptr<1>) - xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> - : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>> + xegpu.prefetch %src[%offset] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> + : memref<256xf32>, vector<1xindex> gpu.return } } diff --git a/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir b/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir index b28a8c2..2a2b99f 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir @@ -9,9 +9,9 @@ gpu.module @materializecast { gpu.func @materialize_memref(%src: memref<128xf32>) kernel { // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<128xf32> -> index // CHECK: %[[CASTED:.*]] = arith.index_castui %[[INTPTR]] : index to i64 - %offset = arith.constant dense<0> : vector<1xindex> - %src_tdesc = xegpu.create_tdesc %src, %offset : memref<128xf32>, vector<1xindex> - -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> + %offset = arith.constant 0 : index + %mask = arith.constant 1 : i1 + %val = xegpu.load %src[%offset], %mask : memref<128xf32>, index, i1 -> f32 gpu.return } } @@ -23,9 +23,9 @@ gpu.module @materializecast { gpu.func @materialize_ui64(%src: ui64) kernel { // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 - %offset = arith.constant dense<0> : vector<1xindex> - %src_tdesc = xegpu.create_tdesc %src, %offset : ui64, vector<1xindex> - -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> + %offset = arith.constant 0 : index + %mask = arith.constant 1 : i1 + %val = xegpu.load %src[%offset], %mask : ui64, index, i1 -> vector<1xf32> gpu.return } } @@ -37,9 +37,9 @@ gpu.module @materializecast { gpu.func @materialize_ui32(%src: ui32) kernel { // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui32 to index // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i32 - %offset = arith.constant dense<0> : vector<1xindex> - %src_tdesc = xegpu.create_tdesc %src, %offset : ui32, vector<1xindex> - -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> + %offset = arith.constant 0 : index + %mask = arith.constant 1 : i1 + %val = xegpu.load %src[%offset], %mask : ui32, index, i1 -> vector<1xf32> gpu.return } } @@ -52,24 +52,12 @@ gpu.module @materializecast { // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> // CHECK: %[[VAR1:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> // CHECK: %[[VAR2:.*]] = arith.index_castui %[[VAR1]] : index to i64 + // CHECK: %[[CST1:.*]] = arith.constant dense<true> : vector<1xi1> + // CHECK: %[[VAR3:.*]] = vector.extract %[[CST1]][0] : i1 from vector<1xi1> %offset = arith.constant dense<0> : vector<1xindex> - %src_tdesc = xegpu.create_tdesc %src, %offset : memref<128xf32>, vector<1xindex> - -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> + %mask = arith.constant dense<1> : vector<1xi1> + %val = xegpu.load %src[%offset], %mask : memref<128xf32>, vector<1xindex>, vector<1xi1> -> vector<1xf32> gpu.return } } -// ----- -gpu.module @materializecast { - // CHECK-LABEL: gpu.func @materialize_single_elem_vector - // CHECK-SAME: %[[ARG0:.*]]: memref<128xf32> - gpu.func @materialize_single_elem_vector(%src: memref<128xf32>) kernel { - // CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1> - // CHECK: %[[VAR1:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1> - %mask = arith.constant dense<1>: vector<1xi1> - %offset = arith.constant dense<0> : vector<1xindex> - %0 = xegpu.load %src[%offset], %mask <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> - : memref<128xf32>, vector<1xindex>, vector<1xi1> -> vector<8xf32> - gpu.return - } -} diff --git a/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir b/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir deleted file mode 100644 index 6e59414..0000000 --- a/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir +++ /dev/null @@ -1,25 +0,0 @@ -// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s - -gpu.module @update_offset { - // CHECK-LABEL: gpu.func @update_offset - // CHECK-SAME: %[[ARG0:.*]]: memref<128xf32> - gpu.func @update_offset(%src: memref<128xf32>) kernel { - // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<128xf32> -> index - // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64 - // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> - %offset = arith.constant dense<0> : vector<1xindex> - // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> - // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 - // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 - // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64 - // CHECK: %[[VAR4:.*]] = arith.addi %[[VAR2]], %[[VAR3]] : i64 - %src_tdesc = xegpu.create_tdesc %src, %offset : memref<128xf32>, vector<1xindex> - -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> - // CHECK: %[[C4_I64_0:.*]] = arith.constant 4 : i64 - // CHECK: %[[VAR5:.*]] = arith.muli %[[VAR1]], %[[C4_I64_0]] : i64 - // CHECK: %[[VAR6:.*]] = arith.addi %[[VAR4]], %[[VAR5]] : i64 - %new_tdesc = xegpu.update_offset %src_tdesc, %offset : !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> - , vector<1xindex> - gpu.return - } -} diff --git a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir index bdbb12b..8f60a07 100644 --- a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir +++ b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir @@ -242,3 +242,22 @@ llvm.func @prefetch(%ptr: !llvm.ptr<1>) { llvm.return } +// ----- +// CHECK-LABEL: llvm.func @llvm.load +llvm.func @llvm.load(%a: !llvm.ptr<1>) -> i32 { + // CHECK: xevm.DecorationCacheControl = + // CHECK-SAME: 6442 : i32, 0 : i32, 1 : i32, 0 : i32 + // CHECK-SAME: 6442 : i32, 1 : i32, 1 : i32, 0 : i32 + %val = llvm.load %a {cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>} : !llvm.ptr<1> -> i32 + llvm.return %val : i32 +} + +// ----- +// CHECK-LABEL: llvm.func @llvm.store +llvm.func @llvm.store(%a: !llvm.ptr<1>, %val: i32) { + // CHECK: xevm.DecorationCacheControl = + // CHECK-SAME: 6443 : i32, 0 : i32, 2 : i32, 0 : i32 + // CHECK-SAME: 6443 : i32, 1 : i32, 2 : i32, 0 : i32 + llvm.store %val, %a {cache_control=#xevm.store_cache_control<L1wt_L2uc_L3wb>} : i32, !llvm.ptr<1> + llvm.return +} diff --git a/mlir/test/Dialect/Affine/loop-fusion-sibling.mlir b/mlir/test/Dialect/Affine/loop-fusion-sibling.mlir new file mode 100644 index 0000000..937c855 --- /dev/null +++ b/mlir/test/Dialect/Affine/loop-fusion-sibling.mlir @@ -0,0 +1,23 @@ +// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{maximal mode=sibling}))' | FileCheck %s + +// Test cases specifically for sibling fusion. Note that sibling fusion test +// cases also exist in loop-fusion*.mlir. + +// CHECK-LABEL: func @disjoint_stores +func.func @disjoint_stores(%0: memref<8xf32>) { + %alloc_1 = memref.alloc() : memref<16xf32> + // The affine stores below are to different parts of the memrefs. Sibling + // fusion helps improve reuse and is valid. + affine.for %arg2 = 0 to 8 { + %2 = affine.load %0[%arg2] : memref<8xf32> + affine.store %2, %alloc_1[%arg2] : memref<16xf32> + } + affine.for %arg2 = 0 to 8 { + %2 = affine.load %0[%arg2] : memref<8xf32> + %3 = arith.negf %2 : f32 + affine.store %3, %alloc_1[%arg2 + 8] : memref<16xf32> + } + // CHECK: affine.for + // CHECK-NOT: affine.for + return +} diff --git a/mlir/test/Dialect/LLVMIR/mmra.mlir b/mlir/test/Dialect/LLVMIR/mmra.mlir new file mode 100644 index 0000000..95da966 --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/mmra.mlir @@ -0,0 +1,29 @@ +// RUN: mlir-opt %s -split-input-file --verify-roundtrip --mlir-print-local-scope | FileCheck %s + +// CHECK-LABEL: llvm.func @native +// CHECK: llvm.load +// CHECK-SAME: llvm.mmra = #llvm.mmra_tag<"foo":"bar"> +// CHECK: llvm.fence +// CHECK-SAME: llvm.mmra = [#llvm.mmra_tag<"amdgpu-synchronize-as":"local">, #llvm.mmra_tag<"foo":"bar">] +// CHECK: llvm.store +// CHECK-SAME: llvm.mmra = #llvm.mmra_tag<"foo":"bar"> + +#mmra_tag = #llvm.mmra_tag<"foo":"bar"> + +llvm.func @native(%x: !llvm.ptr, %y: !llvm.ptr) { + %0 = llvm.load %x {llvm.mmra = #mmra_tag} : !llvm.ptr -> i32 + llvm.fence syncscope("workgroup-one-as") release + {llvm.mmra = [#llvm.mmra_tag<"amdgpu-synchronize-as":"local">, #mmra_tag]} + llvm.store %0, %y {llvm.mmra = #llvm.mmra_tag<"foo":"bar">} : i32, !llvm.ptr + llvm.return +} + +// ----- + +// CHECK-LABEL: llvm.func @foreign_op +// CHECK: rocdl.load.to.lds +// CHECK-SAME: llvm.mmra = #llvm.mmra_tag<"fake":"example"> +llvm.func @foreign_op(%g: !llvm.ptr<1>, %l: !llvm.ptr<3>) { + rocdl.load.to.lds %g, %l, 4, 0, 0 {llvm.mmra = #llvm.mmra_tag<"fake":"example">} : !llvm.ptr<1> + llvm.return +} diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index 986c384..763f41c 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -159,6 +159,29 @@ func.func @no_loops(%lb : index, %ub : index, %step : index) { // ----- +func.func @collapse_size(%lb : index, %ub : index, %step : index) { + omp.wsloop { + // expected-error@+1 {{collapse value is larger than the number of loops}} + omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) collapse(4) { + omp.yield + } + } +} + +// ----- + +func.func @tiles_length(%lb : index, %ub : index, %step : index) { + omp.wsloop { + // expected-error@+1 {{op too few canonical loops for tile dimensions}} + omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) tiles(2, 4) { + omp.yield + } + } +} + + +// ----- + func.func @inclusive_not_a_clause(%lb : index, %ub : index, %step : index) { // expected-error @below {{expected '{'}} omp.wsloop nowait inclusive { diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 3c2e0a3..60b1f61 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -376,6 +376,60 @@ func.func @omp_loop_nest_pretty_multiple(%lb1 : i32, %ub1 : i32, %step1 : i32, return } +// CHECK-LABEL: omp_loop_nest_pretty_multiple_collapse +func.func @omp_loop_nest_pretty_multiple_collapse(%lb1 : i32, %ub1 : i32, %step1 : i32, + %lb2 : i32, %ub2 : i32, %step2 : i32, %data1 : memref<?xi32>) -> () { + + omp.wsloop { + // CHECK: omp.loop_nest (%{{.*}}, %{{.*}}) : i32 = (%{{.*}}, %{{.*}}) to (%{{.*}}, %{{.*}}) step (%{{.*}}, %{{.*}}) collapse(2) + omp.loop_nest (%iv1, %iv2) : i32 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) collapse(2) { + %1 = "test.payload"(%iv1) : (i32) -> (index) + %2 = "test.payload"(%iv2) : (i32) -> (index) + memref.store %iv1, %data1[%1] : memref<?xi32> + memref.store %iv2, %data1[%2] : memref<?xi32> + omp.yield + } + } + + return +} + +// CHECK-LABEL: omp_loop_nest_pretty_multiple_tiles +func.func @omp_loop_nest_pretty_multiple_tiles(%lb1 : i32, %ub1 : i32, %step1 : i32, + %lb2 : i32, %ub2 : i32, %step2 : i32, %data1 : memref<?xi32>) -> () { + + omp.wsloop { + // CHECK: omp.loop_nest (%{{.*}}, %{{.*}}) : i32 = (%{{.*}}, %{{.*}}) to (%{{.*}}, %{{.*}}) step (%{{.*}}, %{{.*}}) tiles(5, 10) + omp.loop_nest (%iv1, %iv2) : i32 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) tiles(5, 10) { + %1 = "test.payload"(%iv1) : (i32) -> (index) + %2 = "test.payload"(%iv2) : (i32) -> (index) + memref.store %iv1, %data1[%1] : memref<?xi32> + memref.store %iv2, %data1[%2] : memref<?xi32> + omp.yield + } + } + + return +} + +// CHECK-LABEL: omp_loop_nest_pretty_multiple_collapse_tiles +func.func @omp_loop_nest_pretty_multiple_collapse_tiles(%lb1 : i32, %ub1 : i32, %step1 : i32, + %lb2 : i32, %ub2 : i32, %step2 : i32, %data1 : memref<?xi32>) -> () { + + omp.wsloop { + // CHECK: omp.loop_nest (%{{.*}}, %{{.*}}) : i32 = (%{{.*}}, %{{.*}}) to (%{{.*}}, %{{.*}}) step (%{{.*}}, %{{.*}}) collapse(2) tiles(5, 10) + omp.loop_nest (%iv1, %iv2) : i32 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) collapse(2) tiles(5, 10) { + %1 = "test.payload"(%iv1) : (i32) -> (index) + %2 = "test.payload"(%iv2) : (i32) -> (index) + memref.store %iv1, %data1[%1] : memref<?xi32> + memref.store %iv2, %data1[%2] : memref<?xi32> + omp.yield + } + } + + return +} + // CHECK-LABEL: omp_wsloop func.func @omp_wsloop(%lb : index, %ub : index, %step : index, %data_var : memref<i32>, %linear_var : i32, %chunk_var : i32) -> () { diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index 5e8bfd0..fe697c8 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -538,3 +538,26 @@ func.func @test_vector_from_elements(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: %1 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x2xf32> return %1 : vector<2x2xf32> } + +// ----- + +// CHECK-LABEL: func.func @to_elements_1d( +// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32> +// CHECK: %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32> +// CHECK: return %[[RES]]#0, %[[RES]]#1 +func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) { + %0:2 = vector.to_elements %arg0 : vector<2xf32> + return %0#0, %0#1 : f32, f32 +} + +// ----- + +// CHECK-LABEL: func.func @to_elements_2d( +// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32> +// CHECK: %[[CAST:.+]] = vector.shape_cast %[[ARG0]] +// CHECK: %[[RES:.+]]:4 = vector.to_elements %[[CAST]] : vector<4xf32> +// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3 +func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) { + %0:4 = vector.to_elements %arg0 : vector<2x2xf32> + return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32 +} diff --git a/mlir/test/Dialect/Vector/lit.local.cfg b/mlir/test/Dialect/Vector/lit.local.cfg new file mode 100644 index 0000000..3e9e8f8 --- /dev/null +++ b/mlir/test/Dialect/Vector/lit.local.cfg @@ -0,0 +1,2 @@ +# Skip the directory with input TD sequences. +config.excludes = ["td"] diff --git a/mlir/test/Dialect/Vector/td/unroll-elements.mlir b/mlir/test/Dialect/Vector/td/unroll-elements.mlir new file mode 100644 index 0000000..40a90a3 --- /dev/null +++ b/mlir/test/Dialect/Vector/td/unroll-elements.mlir @@ -0,0 +1,11 @@ +module attributes {transform.with_named_sequence} { + transform.named_sequence @unroll_to_elements(%module_op: !transform.any_op {transform.readonly}) { + %f = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %f { + transform.apply_patterns.vector.transfer_permutation_patterns + transform.apply_patterns.vector.unroll_to_elements + } : !transform.any_op + transform.yield + } +} diff --git a/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir new file mode 100644 index 0000000..9ec0d76 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir @@ -0,0 +1,26 @@ +// RUN: mlir-opt %s -test-unroll-vector-to-elements -split-input-file | FileCheck %s +// RUN: mlir-opt %s -transform-preload-library='transform-library-paths=%p/td/unroll-elements.mlir' \ +// RUN: -transform-interpreter=entry-point=unroll_to_elements | FileCheck %s + +// CHECK-LABEL: func.func @to_elements_1d( +// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32> +// CHECK: %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32> +// CHECK: return %[[RES]]#0, %[[RES]]#1 +func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) { + %0:2 = vector.to_elements %arg0 : vector<2xf32> + return %0#0, %0#1 : f32, f32 +} + +// ----- + +// CHECK-LABEL: func.func @to_elements_2d( +// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32> +// CHECK: %[[VEC0:.+]] = vector.extract %[[ARG0]][0] : vector<2xf32> from vector<2x2xf32> +// CHECK: %[[VEC1:.+]] = vector.extract %[[ARG0]][1] : vector<2xf32> from vector<2x2xf32> +// CHECK: %[[RES0:.+]]:2 = vector.to_elements %[[VEC0]] : vector<2xf32> +// CHECK: %[[RES1:.+]]:2 = vector.to_elements %[[VEC1]] : vector<2xf32> +// CHECK: return %[[RES0]]#0, %[[RES0]]#1, %[[RES1]]#0, %[[RES1]]#1 +func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) { + %0:4 = vector.to_elements %arg0 : vector<2x2xf32> + return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32 +} diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir index 8750582..bb76392 100644 --- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir +++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir @@ -1856,3 +1856,72 @@ func.func @negative_warp_step_more_than_warp_size(%laneid: index, %buffer: memre // CHECK-PROP-LABEL: @negative_warp_step_more_than_warp_size // CHECK-PROP-NOT: vector.broadcast // CHECK-PROP: vector.step : vector<64xindex> + +// ----- + +func.func @warp_scf_if_no_yield_distribute(%buffer: memref<128xindex>, %pred : i1) { + %laneid = gpu.lane_id + %c0 = arith.constant 0 : index + + gpu.warp_execute_on_lane_0(%laneid)[32] { + %seq = vector.step : vector<32xindex> + scf.if %pred { + vector.store %seq, %buffer[%c0] : memref<128xindex>, vector<32xindex> + } + gpu.yield + } + return +} + +// CHECK-PROP-LABEL: func.func @warp_scf_if_no_yield_distribute( +// CHECK-PROP-SAME: %[[ARG0:.+]]: memref<128xindex>, %[[ARG1:.+]]: i1 +// CHECK-PROP: scf.if %[[ARG1]] { +// CHECK-PROP: gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : vector<1xindex>) { +// CHECK-PROP: ^bb0(%[[ARG2:.+]]: vector<32xindex>): +// CHECK-PROP: vector.store %[[ARG2]], %[[ARG0]][%{{.*}}] : memref<128xindex>, vector<32xindex> + +// ----- + +func.func @warp_scf_if_distribute(%pred : i1) { + %laneid = gpu.lane_id + %c0 = arith.constant 0 : index + + %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> vector<1xf32> { + %seq1 = vector.step : vector<32xindex> + %seq2 = arith.constant dense<2> : vector<32xindex> + %0 = scf.if %pred -> (vector<32xf32>) { + %1 = "some_op"(%seq1) : (vector<32xindex>) -> (vector<32xf32>) + scf.yield %1 : vector<32xf32> + } else { + %2 = "other_op"(%seq2) : (vector<32xindex>) -> (vector<32xf32>) + scf.yield %2 : vector<32xf32> + } + gpu.yield %0 : vector<32xf32> + } + "some_use"(%0) : (vector<1xf32>) -> () + + return +} + +// CHECK-PROP-LABEL: func.func @warp_scf_if_distribute( +// CHECK-PROP-SAME: %[[ARG0:.+]]: i1 +// CHECK-PROP: %[[SEQ2:.+]] = arith.constant dense<2> : vector<32xindex> +// CHECK-PROP: %[[LANE_ID:.+]] = gpu.lane_id +// CHECK-PROP: %[[SEQ1:.+]] = vector.broadcast %[[LANE_ID]] : index to vector<1xindex> +// CHECK-PROP: %[[IF_YIELD_DIST:.+]] = scf.if %[[ARG0]] -> (vector<1xf32>) { +// CHECK-PROP: %[[THEN_DIST:.+]] = gpu.warp_execute_on_lane_0(%[[LANE_ID]])[32] args(%[[SEQ1]] : vector<1xindex>) -> (vector<1xf32>) { +// CHECK-PROP: ^bb0(%[[ARG1:.+]]: vector<32xindex>): +// CHECK-PROP: %{{.*}} = "some_op"(%[[ARG1]]) : (vector<32xindex>) -> vector<32xf32> +// CHECK-PROP: gpu.yield %{{.*}} : vector<32xf32> +// CHECK-PROP: } +// CHECK-PROP: scf.yield %[[THEN_DIST]] : vector<1xf32> +// CHECK-PROP: } else { +// CHECK-PROP: %[[ELSE_DIST:.+]] = gpu.warp_execute_on_lane_0(%[[LANE_ID]])[32] -> (vector<1xf32>) { +// CHECK-PROP: %{{.*}} = "other_op"(%[[SEQ2]]) : (vector<32xindex>) -> vector<32xf32> +// CHECK-PROP: gpu.yield %{{.*}} : vector<32xf32> +// CHECK-PROP: } +// CHECK-PROP: scf.yield %[[ELSE_DIST]] : vector<1xf32> +// CHECK-PROP: } +// CHECK-PROP: "some_use"(%[[IF_YIELD_DIST]]) : (vector<1xf32>) -> () +// CHECK-PROP: return +// CHECK-PROP: } diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir index a39aa90..60acea0 100644 --- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir +++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir @@ -339,6 +339,63 @@ gpu.module @test { } // ----- +// CHECK-LABEL: gpu.func @scatter_ops_scf_yield({{.*}}, +// CHECK-SAME: %[[PREDICATE:.*]]: i1) { +// CHECK: %[[DEFAULT:.*]] = arith.constant dense<1.200000e+01> : vector<8xf16> +// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1xindex> +// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1> +// CHECK: %[[PREDICATED_LOAD:.*]] = scf.if %[[PREDICATE]] -> (vector<8xf16>) { +// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16> +// CHECK-NEXT: scf.yield %[[LOADED]] : vector<8xf16> +// CHECK-NEXT: } else { +// CHECK-NEXT: scf.yield %[[DEFAULT]] : vector<8xf16> +// CHECK-NEXT: } +// CHECK-NEXT: xegpu.store %[[PREDICATED_LOAD]], %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1> +gpu.module @test { + gpu.func @scatter_ops_scf_yield(%src: memref<256xf16>, %pred : i1) { + %1 = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1>: vector<16xi1> + %offset = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex> + %loaded = scf.if %pred -> (vector<16x8xf16>) { + %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}> { + layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]> + } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16> + scf.yield %3 : vector<16x8xf16> + } else { + %3 = arith.constant { + layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]> + } dense<12.> : vector<16x8xf16> + scf.yield %3 : vector<16x8xf16> + } { layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]> } + xegpu.store %loaded, %src[%offset], %1 <{chunk_size=8}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> + gpu.return + } +} + +// ----- +// CHECK-LABEL: gpu.func @scatter_ops_scf_non_yield({{.*}}) { +// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1xindex> +// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1> +// CHECK: %[[PREDICATE:.*]] = llvm.mlir.poison : i1 +// CHECK: scf.if %[[PREDICATE]] { +// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16> +// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1> +// CHECK-NEXT: } +gpu.module @test { + gpu.func @scatter_ops_scf_non_yield(%src: memref<256xf16>) { + %pred = llvm.mlir.poison : i1 + %1 = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1>: vector<16xi1> + %offset = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex> + scf.if %pred { + %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}> { + layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]> + } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16> + xegpu.store %3, %src[%offset], %1 <{chunk_size=8}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> + } + gpu.return + } +} + +// ----- // CHECK-LABEL: gpu.func @scatter_ops({{.*}}) { // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1> // CHECK-NEXT: %[[LANE_OFFSET:.*]] = arith.constant dense<12> : vector<1xindex> diff --git a/mlir/test/Target/LLVMIR/Import/metadata-mmra.ll b/mlir/test/Target/LLVMIR/Import/metadata-mmra.ll new file mode 100644 index 0000000..5e1ed37 --- /dev/null +++ b/mlir/test/Target/LLVMIR/Import/metadata-mmra.ll @@ -0,0 +1,22 @@ +; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s + +; CHECK-DAG: #[[$MMRA0:.+]] = #llvm.mmra_tag<"foo":"bar"> +; CHECK-DAG: #[[$MMRA1:.+]] = #llvm.mmra_tag<"amdgpu-synchronize-as":"local"> + +; CHECK-LABEL: llvm.func @native +define void @native(ptr %x, ptr %y) { + ; CHECK: llvm.load + ; CHECK-SAME: llvm.mmra = #[[$MMRA0]] + %v = load i32, ptr %x, align 4, !mmra !0 + ; CHECK: llvm.fence + ; CHECK-SAME: llvm.mmra = [#[[$MMRA1]], #[[$MMRA0]]] + fence syncscope("workgroup-one-as") release, !mmra !2 + ; CHECK: llvm.store {{.*}}, !llvm.ptr{{$}} + store i32 %v, ptr %y, align 4, !mmra !3 + ret void +} + +!0 = !{!"foo", !"bar"} +!1 = !{!"amdgpu-synchronize-as", !"local"} +!2 = !{!1, !0} +!3 = !{} diff --git a/mlir/test/Target/LLVMIR/mmra.mlir b/mlir/test/Target/LLVMIR/mmra.mlir new file mode 100644 index 0000000..5864e0e --- /dev/null +++ b/mlir/test/Target/LLVMIR/mmra.mlir @@ -0,0 +1,35 @@ +// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s + +// CHECK-LABEL: define void @native +// CHECK: load +// CHECK-SAME: !mmra ![[MMRA0:[0-9]+]] +// CHECK: fence +// CHECK-SAME: !mmra ![[MMRA1:[0-9]+]] +// CHECK: store {{.*}}, align 4{{$}} + +#mmra_tag = #llvm.mmra_tag<"foo":"bar"> + +llvm.func @native(%x: !llvm.ptr, %y: !llvm.ptr) { + %0 = llvm.load %x {llvm.mmra = #mmra_tag} : !llvm.ptr -> i32 + llvm.fence syncscope("workgroup-one-as") release + {llvm.mmra = [#llvm.mmra_tag<"amdgpu-synchronize-as":"local">, #mmra_tag]} + llvm.store %0, %y {llvm.mmra = []} : i32, !llvm.ptr + llvm.return +} + +// Actual MMRA metadata +// CHECK-DAG: ![[MMRA0]] = !{!"foo", !"bar"} +// CHECK-DAG: ![[MMRA_PART0:[0-9]+]] = !{!"amdgpu-synchronize-as", !"local"} +// CHECK-DAG: ![[MMRA1]] = !{![[MMRA_PART0]], ![[MMRA0]]} + +// ----- + +// CHECK-LABEL: define void @foreign_op +// CHECK: call void @llvm.amdgcn.load.to.lds +// CHECK-SAME: !mmra ![[MMRA0:[0-9]+]] +llvm.func @foreign_op(%g: !llvm.ptr<1>, %l: !llvm.ptr<3>) { + rocdl.load.to.lds %g, %l, 4, 0, 0 {llvm.mmra = #llvm.mmra_tag<"fake":"example">} : !llvm.ptr<1> + llvm.return +} + +// CHECK: ![[MMRA0]] = !{!"fake", !"example"} diff --git a/mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir b/mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir index b42e387..d84641f 100644 --- a/mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir @@ -9,7 +9,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo %loop_lb = llvm.mlir.constant(0 : i32) : i32 %loop_step = llvm.mlir.constant(1 : index) : i32 omp.wsloop { - omp.loop_nest (%arg1, %arg2) : i32 = (%loop_lb, %loop_lb) to (%loop_ub, %loop_ub) inclusive step (%loop_step, %loop_step) { + omp.loop_nest (%arg1, %arg2) : i32 = (%loop_lb, %loop_lb) to (%loop_ub, %loop_ub) inclusive step (%loop_step, %loop_step) collapse(2) { %1 = llvm.add %arg1, %arg2 : i32 %2 = llvm.mul %arg2, %loop_ub overflow<nsw> : i32 %3 = llvm.add %arg1, %2 :i32 diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir index 3f4dcd5..27210bc 100644 --- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir +++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir @@ -698,7 +698,7 @@ llvm.func @simd_simple(%lb : i64, %ub : i64, %step : i64, %arg0: !llvm.ptr) { // CHECK-LABEL: @simd_simple_multiple llvm.func @simd_simple_multiple(%lb1 : i64, %ub1 : i64, %step1 : i64, %lb2 : i64, %ub2 : i64, %step2 : i64, %arg0: !llvm.ptr, %arg1: !llvm.ptr) { omp.simd { - omp.loop_nest (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) inclusive step (%step1, %step2) { + omp.loop_nest (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) inclusive step (%step1, %step2) collapse(2) { %3 = llvm.mlir.constant(2.000000e+00 : f32) : f32 // The form of the emitted IR is controlled by OpenMPIRBuilder and // tested there. Just check that the right metadata is added and collapsed @@ -736,7 +736,7 @@ llvm.func @simd_simple_multiple(%lb1 : i64, %ub1 : i64, %step1 : i64, %lb2 : i64 // CHECK-LABEL: @simd_simple_multiple_simdlen llvm.func @simd_simple_multiple_simdlen(%lb1 : i64, %ub1 : i64, %step1 : i64, %lb2 : i64, %ub2 : i64, %step2 : i64, %arg0: !llvm.ptr, %arg1: !llvm.ptr) { omp.simd simdlen(2) { - omp.loop_nest (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) { + omp.loop_nest (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) collapse(2) { %3 = llvm.mlir.constant(2.000000e+00 : f32) : f32 // The form of the emitted IR is controlled by OpenMPIRBuilder and // tested there. Just check that the right metadata is added. @@ -760,7 +760,7 @@ llvm.func @simd_simple_multiple_simdlen(%lb1 : i64, %ub1 : i64, %step1 : i64, %l // CHECK-LABEL: @simd_simple_multiple_safelen llvm.func @simd_simple_multiple_safelen(%lb1 : i64, %ub1 : i64, %step1 : i64, %lb2 : i64, %ub2 : i64, %step2 : i64, %arg0: !llvm.ptr, %arg1: !llvm.ptr) { omp.simd safelen(2) { - omp.loop_nest (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) { + omp.loop_nest (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) collapse(2) { %3 = llvm.mlir.constant(2.000000e+00 : f32) : f32 %4 = llvm.getelementptr %arg0[%iv1] : (!llvm.ptr, i64) -> !llvm.ptr, f32 %5 = llvm.getelementptr %arg1[%iv2] : (!llvm.ptr, i64) -> !llvm.ptr, f32 @@ -779,7 +779,7 @@ llvm.func @simd_simple_multiple_safelen(%lb1 : i64, %ub1 : i64, %step1 : i64, %l // CHECK-LABEL: @simd_simple_multiple_simdlen_safelen llvm.func @simd_simple_multiple_simdlen_safelen(%lb1 : i64, %ub1 : i64, %step1 : i64, %lb2 : i64, %ub2 : i64, %step2 : i64, %arg0: !llvm.ptr, %arg1: !llvm.ptr) { omp.simd simdlen(1) safelen(2) { - omp.loop_nest (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) { + omp.loop_nest (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) collapse(2) { %3 = llvm.mlir.constant(2.000000e+00 : f32) : f32 %4 = llvm.getelementptr %arg0[%iv1] : (!llvm.ptr, i64) -> !llvm.ptr, f32 %5 = llvm.getelementptr %arg1[%iv2] : (!llvm.ptr, i64) -> !llvm.ptr, f32 @@ -1177,7 +1177,7 @@ llvm.func @collapse_wsloop( // CHECK: store i32 %[[TOTAL_SUB_1]], ptr // CHECK: call void @__kmpc_for_static_init_4u omp.wsloop { - omp.loop_nest (%arg0, %arg1, %arg2) : i32 = (%0, %1, %2) to (%3, %4, %5) step (%6, %7, %8) { + omp.loop_nest (%arg0, %arg1, %arg2) : i32 = (%0, %1, %2) to (%3, %4, %5) step (%6, %7, %8) collapse(3) { %31 = llvm.load %20 : !llvm.ptr -> i32 %32 = llvm.add %31, %arg0 : i32 %33 = llvm.add %32, %arg1 : i32 @@ -1239,7 +1239,7 @@ llvm.func @collapse_wsloop_dynamic( // CHECK: store i32 %[[TOTAL]], ptr // CHECK: call void @__kmpc_dispatch_init_4u omp.wsloop schedule(dynamic) { - omp.loop_nest (%arg0, %arg1, %arg2) : i32 = (%0, %1, %2) to (%3, %4, %5) step (%6, %7, %8) { + omp.loop_nest (%arg0, %arg1, %arg2) : i32 = (%0, %1, %2) to (%3, %4, %5) step (%6, %7, %8) collapse(3) { %31 = llvm.load %20 : !llvm.ptr -> i32 %32 = llvm.add %31, %arg0 : i32 %33 = llvm.add %32, %arg1 : i32 diff --git a/mlir/test/Target/LLVMIR/xevm.mlir b/mlir/test/Target/LLVMIR/xevm.mlir index a3dd0b6..112d923 100644 --- a/mlir/test/Target/LLVMIR/xevm.mlir +++ b/mlir/test/Target/LLVMIR/xevm.mlir @@ -19,3 +19,35 @@ module { // CHECK: ![[DECO2]] = !{i32 6442, i32 0, i32 1, i32 0} // CHECK: ![[DECO3]] = !{i32 6442, i32 1, i32 1, i32 0} +// ----- +module { + // CHECK-LABEL: define i32 @load(ptr addrspace(1) + // CHECK-SAME: %[[ARG0:.*]]) { + llvm.func @load(%arg0: !llvm.ptr<1>) -> i32 { + // CHECK: load i32, ptr addrspace(1) %[[ARG0]], align 4, + // CHECK-SAME: !spirv.DecorationCacheControlINTEL ![[DECO1:.*]] + %0 = llvm.load %arg0 {xevm.DecorationCacheControl = [[6442 : i32, 0 : i32, 1 : i32, 0 : i32], [6442 : i32, 1 : i32, 1 : i32, 0 : i32]]} : !llvm.ptr<1> -> i32 + llvm.return %0 : i32 + } +} + +// CHECK: ![[DECO1]] = !{![[DECO2:.*]], ![[DECO3:.*]]} +// CHECK: ![[DECO2]] = !{i32 6442, i32 0, i32 1, i32 0} +// CHECK: ![[DECO3]] = !{i32 6442, i32 1, i32 1, i32 0} + +// ----- +module { + // CHECK-LABEL: define void @store(ptr addrspace(1) + // CHECK-SAME: %[[ARG0:.*]], i32 %[[ARG1:.*]]) { + llvm.func @store(%arg0: !llvm.ptr<1>, %arg1: i32) { + // CHECK: store i32 %[[ARG1]], ptr addrspace(1) %[[ARG0]], align 4, + // CHECK-SAME: !spirv.DecorationCacheControlINTEL ![[DECO1:.*]] + llvm.store %arg1, %arg0 {xevm.DecorationCacheControl = [[6443 : i32, 0 : i32, 2 : i32, 0 : i32], [6443 : i32, 1 : i32, 2 : i32, 0 : i32]]} : i32, !llvm.ptr<1> + llvm.return + } +} + +// CHECK: ![[DECO1]] = !{![[DECO2:.*]], ![[DECO3:.*]]} +// CHECK: ![[DECO2]] = !{i32 6443, i32 0, i32 2, i32 0} +// CHECK: ![[DECO3]] = !{i32 6443, i32 1, i32 2, i32 0} + diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h index c05e15f..f2adca6 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.h +++ b/mlir/test/lib/Dialect/Test/TestDialect.h @@ -35,7 +35,6 @@ #include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" -#include "mlir/Interfaces/CopyOpInterface.h" #include "mlir/Interfaces/DerivedAttributeOpInterface.h" #include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" diff --git a/mlir/test/lib/Dialect/Test/TestOps.h b/mlir/test/lib/Dialect/Test/TestOps.h index b414b47..4201ade 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.h +++ b/mlir/test/lib/Dialect/Test/TestOps.h @@ -33,7 +33,6 @@ #include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" -#include "mlir/Interfaces/CopyOpInterface.h" #include "mlir/Interfaces/DerivedAttributeOpInterface.h" #include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 231400e..5564264 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -23,7 +23,6 @@ include "mlir/IR/RegionKindInterface.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" -include "mlir/Interfaces/CopyOpInterface.td" include "mlir/Interfaces/DataLayoutInterfaces.td" include "mlir/Interfaces/DestinationStyleOpInterface.td" include "mlir/Interfaces/InferIntRangeInterface.td" @@ -2322,10 +2321,10 @@ def SideEffectWithRegionOp : TEST_Op<"side_effect_with_region_op", } //===----------------------------------------------------------------------===// -// Test CopyOpInterface +// Copy Operation Test //===----------------------------------------------------------------------===// -def CopyOp : TEST_Op<"copy", [CopyOpInterface]> { +def CopyOp : TEST_Op<"copy", []> { let description = [{ Represents a copy operation. }]; diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index bb1598e..d6596cd 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -808,6 +808,28 @@ struct TestUnrollVectorFromElements } }; +struct TestUnrollVectorToElements + : public PassWrapper<TestUnrollVectorToElements, + OperationPass<func::FuncOp>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnrollVectorToElements) + + StringRef getArgument() const final { + return "test-unroll-vector-to-elements"; + } + StringRef getDescription() const final { + return "Test unrolling patterns for to_elements ops"; + } + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert<func::FuncDialect, vector::VectorDialect>(); + } + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + populateVectorToElementsLoweringPatterns(patterns); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); + } +}; + struct TestFoldArithExtensionIntoVectorContractPatterns : public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns, OperationPass<func::FuncOp>> { @@ -1083,6 +1105,8 @@ void registerTestVectorLowerings() { PassRegistration<TestUnrollVectorFromElements>(); + PassRegistration<TestUnrollVectorToElements>(); + PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>(); PassRegistration<TestVectorEmulateMaskedLoadStore>(); diff --git a/mlir/test/mlir-tblgen/op-decl-and-defs.td b/mlir/test/mlir-tblgen/op-decl-and-defs.td index f213f50..87b41f9 100644 --- a/mlir/test/mlir-tblgen/op-decl-and-defs.td +++ b/mlir/test/mlir-tblgen/op-decl-and-defs.td @@ -543,3 +543,12 @@ def _BOp : NS_Op<"_op_with_leading_underscore_and_no_namespace", []>; // REDUCE_EXC-NOT: NS::AOp declarations // REDUCE_EXC-LABEL: NS::BOp declarations + +// CHECK-LABEL: _TypeInferredPropOp declarations +def _TypeInferredPropOp : NS_Op<"type_inferred_prop_op_with_properties", [ + AllTypesMatch<["value", "result"]> + ]> { + let arguments = (ins Property<"unsigned">:$prop, AnyType:$value); + let results = (outs AnyType:$result); + let hasCustomAssemblyFormat = 1; +} diff --git a/mlir/test/python/dialects/transform_vector_ext.py b/mlir/test/python/dialects/transform_vector_ext.py index 5a648fe..28902b0 100644 --- a/mlir/test/python/dialects/transform_vector_ext.py +++ b/mlir/test/python/dialects/transform_vector_ext.py @@ -48,6 +48,8 @@ def non_configurable_patterns(): vector.ApplyLowerGatherPatternsOp() # CHECK: transform.apply_patterns.vector.unroll_from_elements vector.ApplyUnrollFromElementsPatternsOp() + # CHECK: transform.apply_patterns.vector.unroll_to_elements + vector.ApplyUnrollToElementsPatternsOp() # CHECK: transform.apply_patterns.vector.lower_scan vector.ApplyLowerScanPatternsOp() # CHECK: transform.apply_patterns.vector.lower_shape_cast diff --git a/mlir/test/python/python_pass.py b/mlir/test/python/python_pass.py index c94f96e..50c4210 100644 --- a/mlir/test/python/python_pass.py +++ b/mlir/test/python/python_pass.py @@ -64,12 +64,12 @@ def testCustomPass(): """ ) - def custom_pass_1(op): + def custom_pass_1(op, pass_): print("hello from pass 1!!!", file=sys.stderr) class CustomPass2: - def __call__(self, m): - apply_patterns_and_fold_greedily(m, frozen) + def __call__(self, op, pass_): + apply_patterns_and_fold_greedily(op, frozen) custom_pass_2 = CustomPass2() @@ -86,3 +86,17 @@ def testCustomPass(): # CHECK: llvm.mul pm.add("convert-arith-to-llvm") pm.run(module) + + # test signal_pass_failure + def custom_pass_that_fails(op, pass_): + print("hello from pass that fails") + pass_.signal_pass_failure() + + pm = PassManager("any") + pm.add(custom_pass_that_fails, "CustomPassThatFails") + # CHECK: hello from pass that fails + # CHECK: caught exception: Failure while executing pass pipeline + try: + pm.run(module) + except Exception as e: + print(f"caught exception: {e}") diff --git a/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp b/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp index 10d602f..712237b 100644 --- a/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp +++ b/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp @@ -10,8 +10,8 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllExtensions.h" -#include "mlir/Tools/lsp-server-support/Protocol.h" #include "mlir/Tools/mlir-lsp-server/MlirLspServerMain.h" +#include "llvm/Support/LSP/Protocol.h" using namespace mlir; @@ -37,8 +37,8 @@ int main(int argc, char **argv) { // Returns the registry, except in testing mode when the URI contains // "-disable-lsp-registration". Testing for/example of registering dialects // based on URI. - auto registryFn = [®istry, - &empty](const lsp::URIForFile &uri) -> DialectRegistry & { + auto registryFn = [®istry, &empty]( + const llvm::lsp::URIForFile &uri) -> DialectRegistry & { (void)empty; #ifdef MLIR_INCLUDE_TESTS if (uri.uri().contains("-disable-lsp-registration")) diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 8ea4eb7..4fdde76 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -3849,9 +3849,9 @@ void OpEmitter::genTypeInterfaceMethods() { const InferredResultType &infer = op.getInferredResultType(i); if (!infer.isArg()) continue; - Operator::OperandOrAttribute arg = - op.getArgToOperandOrAttribute(infer.getIndex()); - if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) { + Operator::OperandAttrOrProp arg = + op.getArgToOperandAttrOrProp(infer.getIndex()); + if (arg.kind() == Operator::OperandAttrOrProp::Kind::Operand) { maxAccessedIndex = std::max(maxAccessedIndex, arg.operandOrAttributeIndex()); } @@ -3877,17 +3877,16 @@ void OpEmitter::genTypeInterfaceMethods() { if (infer.isArg()) { // If this is an operand, just index into operand list to access the // type. - Operator::OperandOrAttribute arg = - op.getArgToOperandOrAttribute(infer.getIndex()); - if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) { + Operator::OperandAttrOrProp arg = + op.getArgToOperandAttrOrProp(infer.getIndex()); + if (arg.kind() == Operator::OperandAttrOrProp::Kind::Operand) { typeStr = ("operands[" + Twine(arg.operandOrAttributeIndex()) + "].getType()") .str(); // If this is an attribute, index into the attribute dictionary. - } else { - auto *attr = - cast<NamedAttribute *>(op.getArg(arg.operandOrAttributeIndex())); + } else if (auto *attr = dyn_cast<NamedAttribute *>( + op.getArg(arg.operandOrAttributeIndex()))) { body << " ::mlir::TypedAttr odsInferredTypeAttr" << inferredTypeIdx << " = "; if (op.getDialect().usePropertiesForAttributes()) { @@ -3907,6 +3906,9 @@ void OpEmitter::genTypeInterfaceMethods() { typeStr = ("odsInferredTypeAttr" + Twine(inferredTypeIdx) + ".getType()") .str(); + } else { + llvm::PrintFatalError(&op.getDef(), + "Properties cannot be used for type inference"); } } else if (std::optional<StringRef> builder = op.getResult(infer.getResultIndex()) diff --git a/mlir/unittests/Bytecode/BytecodeTest.cpp b/mlir/unittests/Bytecode/BytecodeTest.cpp index 9ea6560..d7b442f 100644 --- a/mlir/unittests/Bytecode/BytecodeTest.cpp +++ b/mlir/unittests/Bytecode/BytecodeTest.cpp @@ -72,6 +72,8 @@ TEST(Bytecode, MultiModuleWithResource) { ASSERT_TRUE(module); // Write the module to bytecode. + // Ensure that reserveExtraSpace is called with the size needed to write the + // bytecode buffer. MockOstream ostream; EXPECT_CALL(ostream, reserveExtraSpace).WillOnce([&](uint64_t space) { ostream.buffer = std::make_unique<std::byte[]>(space); @@ -128,31 +130,28 @@ TEST(Bytecode, AlignmentFailure) { ASSERT_TRUE(module); // Write the module to bytecode. - MockOstream ostream; - EXPECT_CALL(ostream, reserveExtraSpace).WillOnce([&](uint64_t space) { - ostream.buffer = std::make_unique<std::byte[]>(space); - ostream.size = space; - }); + std::string serializedBytecode; + llvm::raw_string_ostream ostream(serializedBytecode); ASSERT_TRUE(succeeded(writeBytecodeToFile(module.get(), ostream))); // Create copy of buffer which is not aligned to requested resource alignment. - std::string buffer((char *)ostream.buffer.get(), - (char *)ostream.buffer.get() + ostream.size); + std::string buffer(serializedBytecode); size_t bufferSize = buffer.size(); - // Increment into the buffer until we get to a power of 2 alignment that is - // not 32 bit aligned. + // Increment into the buffer until we get to an address that is 2 byte aligned + // but not 32 byte aligned. size_t pad = 0; while (true) { - if (llvm::isAddrAligned(Align(2), &buffer[pad]) && - !llvm::isAddrAligned(Align(32), &buffer[pad])) + if (llvm::isAddrAligned(Align(2), buffer.data() + pad) && + !llvm::isAddrAligned(Align(32), buffer.data() + pad)) break; pad++; - buffer.reserve(bufferSize + pad); + // Pad the beginning of the buffer to push the start point to an unaligned + // value. + buffer.insert(0, 1, ' '); } - buffer.insert(0, pad, ' '); StringRef alignedBuffer(buffer.data() + pad, bufferSize); // Attach a diagnostic handler to get the error message. diff --git a/mlir/unittests/CMakeLists.txt b/mlir/unittests/CMakeLists.txt index c5f0d7e..89332bc 100644 --- a/mlir/unittests/CMakeLists.txt +++ b/mlir/unittests/CMakeLists.txt @@ -18,7 +18,6 @@ add_subdirectory(Support) add_subdirectory(Rewrite) add_subdirectory(TableGen) add_subdirectory(Target) -add_subdirectory(Tools) add_subdirectory(Transforms) if(MLIR_ENABLE_EXECUTION_ENGINE) diff --git a/mlir/unittests/Tools/CMakeLists.txt b/mlir/unittests/Tools/CMakeLists.txt deleted file mode 100644 index a97588d..0000000 --- a/mlir/unittests/Tools/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(lsp-server-support) diff --git a/mlir/unittests/Tools/lsp-server-support/CMakeLists.txt b/mlir/unittests/Tools/lsp-server-support/CMakeLists.txt deleted file mode 100644 index c539c9b..0000000 --- a/mlir/unittests/Tools/lsp-server-support/CMakeLists.txt +++ /dev/null @@ -1,7 +0,0 @@ -add_mlir_unittest(MLIRLspServerSupportTests - Protocol.cpp - Transport.cpp -) -mlir_target_link_libraries(MLIRLspServerSupportTests - PRIVATE - MLIRLspServerSupportLib) diff --git a/mlir/unittests/Tools/lsp-server-support/Protocol.cpp b/mlir/unittests/Tools/lsp-server-support/Protocol.cpp deleted file mode 100644 index 04d7b2f..0000000 --- a/mlir/unittests/Tools/lsp-server-support/Protocol.cpp +++ /dev/null @@ -1,51 +0,0 @@ -//===- Protocol.cpp - LSP JSON protocol unit tests ------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Tools/lsp-server-support/Protocol.h" - -#include "gtest/gtest.h" - -using namespace mlir; -using namespace mlir::lsp; -using namespace testing; - -namespace { - -TEST(ProtocolTest, DiagnosticTagPresent) { - Diagnostic diagnostic; - diagnostic.tags.push_back(DiagnosticTag::Unnecessary); - - llvm::json::Value json = toJSON(diagnostic); - const llvm::json::Object *o = json.getAsObject(); - const llvm::json::Array *v = o->get("tags")->getAsArray(); - EXPECT_EQ(*v, llvm::json::Array{1}); - - Diagnostic parsed; - llvm::json::Path::Root root = llvm::json::Path::Root(); - bool success = fromJSON(json, parsed, llvm::json::Path(root)); - EXPECT_TRUE(success); - ASSERT_EQ(parsed.tags.size(), (size_t)1); - EXPECT_EQ(parsed.tags.at(0), DiagnosticTag::Unnecessary); -} - -TEST(ProtocolTest, DiagnosticTagNotPresent) { - Diagnostic diagnostic; - - llvm::json::Value json = toJSON(diagnostic); - const llvm::json::Object *o = json.getAsObject(); - const llvm::json::Value *v = o->get("tags"); - EXPECT_EQ(v, nullptr); - - Diagnostic parsed; - llvm::json::Path::Root root = llvm::json::Path::Root(); - bool success = fromJSON(json, parsed, llvm::json::Path(root)); - EXPECT_TRUE(success); - EXPECT_TRUE(parsed.tags.empty()); -} - -} // namespace diff --git a/mlir/unittests/Tools/lsp-server-support/Transport.cpp b/mlir/unittests/Tools/lsp-server-support/Transport.cpp deleted file mode 100644 index 92581bd..0000000 --- a/mlir/unittests/Tools/lsp-server-support/Transport.cpp +++ /dev/null @@ -1,205 +0,0 @@ -//===- Transport.cpp - LSP JSON transport unit tests ----------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Tools/lsp-server-support/Transport.h" -#include "mlir/Tools/lsp-server-support/Logging.h" -#include "mlir/Tools/lsp-server-support/Protocol.h" -#include "llvm/Support/FileSystem.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -using namespace mlir; -using namespace mlir::lsp; -using namespace testing; - -namespace { - -TEST(TransportTest, SendReply) { - std::string out; - llvm::raw_string_ostream os(out); - JSONTransport transport(nullptr, os); - MessageHandler handler(transport); - - transport.reply(1989, nullptr); - EXPECT_THAT(out, HasSubstr("\"id\":1989")); - EXPECT_THAT(out, HasSubstr("\"result\":null")); -} - -class TransportInputTest : public Test { - llvm::SmallVector<char> inputPath; - std::FILE *in = nullptr; - std::string output = ""; - llvm::raw_string_ostream os; - std::optional<JSONTransport> transport = std::nullopt; - std::optional<MessageHandler> messageHandler = std::nullopt; - -protected: - TransportInputTest() : os(output) {} - - void SetUp() override { - std::error_code ec = - llvm::sys::fs::createTemporaryFile("lsp-unittest", "json", inputPath); - ASSERT_FALSE(ec) << "Could not create temporary file: " << ec.message(); - - in = std::fopen(inputPath.data(), "r"); - ASSERT_TRUE(in) << "Could not open temporary file: " - << std::strerror(errno); - transport.emplace(in, os, JSONStreamStyle::Delimited); - messageHandler.emplace(*transport); - } - - void TearDown() override { - EXPECT_EQ(std::fclose(in), 0) - << "Could not close temporary file FD: " << std::strerror(errno); - std::error_code ec = - llvm::sys::fs::remove(inputPath, /*IgnoreNonExisting=*/false); - EXPECT_FALSE(ec) << "Could not remove temporary file '" << inputPath.data() - << "': " << ec.message(); - } - - void writeInput(StringRef buffer) { - std::error_code ec; - llvm::raw_fd_ostream os(inputPath.data(), ec); - ASSERT_FALSE(ec) << "Could not write to '" << inputPath.data() - << "': " << ec.message(); - os << buffer; - os.close(); - } - - StringRef getOutput() const { return output; } - MessageHandler &getMessageHandler() { return *messageHandler; } - - void runTransport() { - bool gotEOF = false; - llvm::Error err = llvm::handleErrors( - transport->run(*messageHandler), [&](const llvm::ECError &ecErr) { - gotEOF = ecErr.convertToErrorCode() == std::errc::io_error; - }); - llvm::consumeError(std::move(err)); - EXPECT_TRUE(gotEOF); - } -}; - -TEST_F(TransportInputTest, RequestWithInvalidParams) { - struct Handler { - void onMethod(const TextDocumentItem ¶ms, - mlir::lsp::Callback<TextDocumentIdentifier> callback) {} - } handler; - getMessageHandler().method("invalid-params-request", &handler, - &Handler::onMethod); - - writeInput("{\"jsonrpc\":\"2.0\",\"id\":92," - "\"method\":\"invalid-params-request\",\"params\":{}}\n"); - runTransport(); - EXPECT_THAT(getOutput(), HasSubstr("error")); - EXPECT_THAT(getOutput(), HasSubstr("missing value at (root).uri")); -} - -TEST_F(TransportInputTest, NotificationWithInvalidParams) { - // JSON parsing errors are only reported via error logging. As a result, this - // test can't make any expectations -- but it prints the output anyway, by way - // of demonstration. - Logger::setLogLevel(Logger::Level::Error); - - struct Handler { - void onNotification(const TextDocumentItem ¶ms) {} - } handler; - getMessageHandler().notification("invalid-params-notification", &handler, - &Handler::onNotification); - - writeInput("{\"jsonrpc\":\"2.0\",\"method\":\"invalid-params-notification\"," - "\"params\":{}}\n"); - runTransport(); -} - -TEST_F(TransportInputTest, MethodNotFound) { - writeInput("{\"jsonrpc\":\"2.0\",\"id\":29,\"method\":\"ack\"}\n"); - runTransport(); - EXPECT_THAT(getOutput(), HasSubstr("\"id\":29")); - EXPECT_THAT(getOutput(), HasSubstr("\"error\"")); - EXPECT_THAT(getOutput(), HasSubstr("\"message\":\"method not found: ack\"")); -} - -TEST_F(TransportInputTest, OutgoingNotification) { - auto notifyFn = getMessageHandler().outgoingNotification<CompletionList>( - "outgoing-notification"); - notifyFn(CompletionList{}); - EXPECT_THAT(getOutput(), HasSubstr("\"method\":\"outgoing-notification\"")); -} - -TEST_F(TransportInputTest, ResponseHandlerNotFound) { - // Unhandled responses are only reported via error logging. As a result, this - // test can't make any expectations -- but it prints the output anyway, by way - // of demonstration. - Logger::setLogLevel(Logger::Level::Error); - writeInput("{\"jsonrpc\":\"2.0\",\"id\":81,\"result\":null}\n"); - runTransport(); -} - -TEST_F(TransportInputTest, OutgoingRequest) { - // Make some outgoing requests. - int responseCallbackInvoked = 0; - auto callFn = - getMessageHandler().outgoingRequest<CompletionList, CompletionContext>( - "outgoing-request", - [&responseCallbackInvoked](const llvm::json::Value &id, - llvm::Expected<CompletionContext> result) { - // Make expectations on the expected response. - EXPECT_EQ(id, 83); - ASSERT_TRUE((bool)result); - EXPECT_EQ(result->triggerKind, CompletionTriggerKind::Invoked); - responseCallbackInvoked += 1; - }); - callFn({}, 82); - callFn({}, 83); - callFn({}, 84); - EXPECT_THAT(getOutput(), HasSubstr("\"method\":\"outgoing-request\"")); - EXPECT_EQ(responseCallbackInvoked, 0); - - // One of the requests receives a response. The message handler handles this - // response by invoking the callback from above. Subsequent responses with the - // same ID are ignored. - writeInput( - "{\"jsonrpc\":\"2.0\",\"id\":83,\"result\":{\"triggerKind\":1}}\n" - "// -----\n" - "{\"jsonrpc\":\"2.0\",\"id\":83,\"result\":{\"triggerKind\":3}}\n"); - runTransport(); - EXPECT_EQ(responseCallbackInvoked, 1); -} - -TEST_F(TransportInputTest, OutgoingRequestJSONParseFailure) { - // Make an outgoing request that expects a failure response. - bool responseCallbackInvoked = false; - auto callFn = getMessageHandler().outgoingRequest<CompletionList, Position>( - "outgoing-request-json-parse-failure", - [&responseCallbackInvoked](const llvm::json::Value &id, - llvm::Expected<Position> result) { - llvm::Error err = result.takeError(); - EXPECT_EQ(id, 109); - ASSERT_TRUE((bool)err); - EXPECT_THAT(debugString(err), - HasSubstr("failed to decode " - "reply:outgoing-request-json-parse-failure(109) " - "response: missing value at (root).character")); - llvm::consumeError(std::move(err)); - responseCallbackInvoked += 1; - }); - callFn({}, 109); - EXPECT_EQ(responseCallbackInvoked, 0); - - // The request receives multiple responses, but only the first one triggers - // the response callback. The first response has erroneous JSON that causes a - // parse failure. - writeInput("{\"jsonrpc\":\"2.0\",\"id\":109,\"result\":{\"line\":7}}\n" - "// -----\n" - "{\"jsonrpc\":\"2.0\",\"id\":109,\"result\":{\"line\":3," - "\"character\":2}}\n"); - runTransport(); - EXPECT_EQ(responseCallbackInvoked, 1); -} -} // namespace |