aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
Diffstat (limited to 'mlir')
-rw-r--r--mlir/cmake/modules/AddMLIRPython.cmake126
-rw-r--r--mlir/docs/BytecodeFormat.md4
-rw-r--r--mlir/docs/DialectConversion.md10
-rw-r--r--mlir/docs/Dialects/IRDL.md123
-rw-r--r--mlir/examples/standalone/python/CMakeLists.txt1
-rw-r--r--mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h7
-rw-r--r--mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h11
-rw-r--r--mlir/include/mlir/Conversion/Passes.td1
-rw-r--r--mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h4
-rw-r--r--mlir/include/mlir/Dialect/Affine/Analysis/Utils.h14
-rw-r--r--mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h1
-rw-r--r--mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td2
-rw-r--r--mlir/include/mlir/Dialect/IRDL/IR/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td41
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td1
-rw-r--r--mlir/include/mlir/Dialect/Linalg/IR/Linalg.h1
-rw-r--r--mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td6
-rw-r--r--mlir/include/mlir/Dialect/MemRef/IR/MemRef.h1
-rw-r--r--mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td3
-rw-r--r--mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h2
-rw-r--r--mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td33
-rw-r--r--mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td15
-rw-r--r--mlir/include/mlir/Dialect/Vector/IR/VectorOps.td12
-rw-r--r--mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td11
-rw-r--r--mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h6
-rw-r--r--mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h6
-rw-r--r--mlir/include/mlir/Interfaces/CMakeLists.txt1
-rw-r--r--mlir/include/mlir/Interfaces/CopyOpInterface.h21
-rw-r--r--mlir/include/mlir/Interfaces/CopyOpInterface.td38
-rw-r--r--mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h20
-rw-r--r--mlir/include/mlir/Interfaces/VectorInterfaces.td6
-rw-r--r--mlir/include/mlir/TableGen/Operator.h23
-rw-r--r--mlir/include/mlir/Tools/lsp-server-support/Logging.h65
-rw-r--r--mlir/include/mlir/Tools/lsp-server-support/Protocol.h1257
-rw-r--r--mlir/include/mlir/Tools/lsp-server-support/SourceMgrUtils.h12
-rw-r--r--mlir/include/mlir/Tools/lsp-server-support/Transport.h283
-rw-r--r--mlir/include/mlir/Tools/mlir-lsp-server/MlirLspRegistryFunction.h6
-rw-r--r--mlir/include/mlir/Transforms/DialectConversion.h17
-rw-r--r--mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp37
-rw-r--r--mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp22
-rw-r--r--mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp10
-rw-r--r--mlir/lib/Analysis/DataFlowFramework.cpp7
-rw-r--r--mlir/lib/Bindings/Python/MainModule.cpp9
-rw-r--r--mlir/lib/Bindings/Python/Pass.cpp11
-rw-r--r--mlir/lib/Bytecode/Reader/BytecodeReader.cpp38
-rw-r--r--mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp24
-rw-r--r--mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp3
-rw-r--r--mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp6
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp1
-rw-r--r--mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp170
-rw-r--r--mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp292
-rw-r--r--mlir/lib/Dialect/Affine/Analysis/Utils.cpp106
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp15
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp2
-rw-r--r--mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp66
-rw-r--r--mlir/lib/Dialect/Quant/Transforms/NormalizeQuantTypes.cpp4
-rw-r--r--mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp4
-rw-r--r--mlir/lib/Dialect/Tosa/IR/TosaOps.cpp6
-rw-r--r--mlir/lib/Dialect/Transform/IR/TransformOps.cpp4
-rw-r--r--mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp5
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp53
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp265
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp49
-rw-r--r--mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp35
-rw-r--r--mlir/lib/Interfaces/CMakeLists.txt2
-rw-r--r--mlir/lib/Interfaces/CopyOpInterface.cpp18
-rw-r--r--mlir/lib/Interfaces/ValueBoundsOpInterface.cpp25
-rw-r--r--mlir/lib/TableGen/Operator.cpp18
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp39
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp44
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp45
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp4
-rw-r--r--mlir/lib/Tools/lsp-server-support/CMakeLists.txt8
-rw-r--r--mlir/lib/Tools/lsp-server-support/CompilationDatabase.cpp5
-rw-r--r--mlir/lib/Tools/lsp-server-support/Logging.cpp41
-rw-r--r--mlir/lib/Tools/lsp-server-support/Protocol.cpp1043
-rw-r--r--mlir/lib/Tools/lsp-server-support/SourceMgrUtils.cpp4
-rw-r--r--mlir/lib/Tools/lsp-server-support/Transport.cpp369
-rw-r--r--mlir/lib/Tools/mlir-lsp-server/CMakeLists.txt3
-rw-r--r--mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp31
-rw-r--r--mlir/lib/Tools/mlir-lsp-server/LSPServer.h6
-rw-r--r--mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp133
-rw-r--r--mlir/lib/Tools/mlir-lsp-server/MLIRServer.h22
-rw-r--r--mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp8
-rw-r--r--mlir/lib/Tools/mlir-lsp-server/Protocol.cpp7
-rw-r--r--mlir/lib/Tools/mlir-lsp-server/Protocol.h6
-rw-r--r--mlir/lib/Tools/mlir-pdll-lsp-server/CMakeLists.txt3
-rw-r--r--mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp29
-rw-r--r--mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.h6
-rw-r--r--mlir/lib/Tools/mlir-pdll-lsp-server/MlirPdllLspServerMain.cpp10
-rw-r--r--mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp561
-rw-r--r--mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.h26
-rw-r--r--mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.cpp1
-rw-r--r--mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.h4
-rw-r--r--mlir/lib/Tools/tblgen-lsp-server/CMakeLists.txt1
-rw-r--r--mlir/lib/Tools/tblgen-lsp-server/LSPServer.cpp25
-rw-r--r--mlir/lib/Tools/tblgen-lsp-server/LSPServer.h6
-rw-r--r--mlir/lib/Tools/tblgen-lsp-server/TableGenLspServerMain.cpp8
-rw-r--r--mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp162
-rw-r--r--mlir/lib/Tools/tblgen-lsp-server/TableGenServer.h15
-rw-r--r--mlir/lib/Transforms/Utils/DialectConversion.cpp84
-rw-r--r--mlir/python/CMakeLists.txt58
-rw-r--r--mlir/python/mlir/_mlir_libs/.gitignore2
-rw-r--r--mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi12
-rw-r--r--mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi63
-rw-r--r--mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi142
-rw-r--r--mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi25
-rw-r--r--mlir/python/mlir/_mlir_libs/_mlir/ir.pyi2846
-rw-r--r--mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi36
-rw-r--r--mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi24
-rw-r--r--mlir/python/mlir/ir.py35
-rw-r--r--mlir/python/requirements.txt5
-rw-r--r--mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir2
-rw-r--r--mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir42
-rw-r--r--mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir32
-rw-r--r--mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir244
-rw-r--r--mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir38
-rw-r--r--mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir25
-rw-r--r--mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir19
-rw-r--r--mlir/test/Dialect/Affine/loop-fusion-sibling.mlir23
-rw-r--r--mlir/test/Dialect/LLVMIR/mmra.mlir29
-rw-r--r--mlir/test/Dialect/OpenMP/invalid.mlir23
-rw-r--r--mlir/test/Dialect/OpenMP/ops.mlir54
-rw-r--r--mlir/test/Dialect/Vector/linearize.mlir23
-rw-r--r--mlir/test/Dialect/Vector/lit.local.cfg2
-rw-r--r--mlir/test/Dialect/Vector/td/unroll-elements.mlir11
-rw-r--r--mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir26
-rw-r--r--mlir/test/Dialect/Vector/vector-warp-distribute.mlir69
-rw-r--r--mlir/test/Dialect/XeGPU/subgroup-distribute.mlir57
-rw-r--r--mlir/test/Target/LLVMIR/Import/metadata-mmra.ll22
-rw-r--r--mlir/test/Target/LLVMIR/mmra.mlir35
-rw-r--r--mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir2
-rw-r--r--mlir/test/Target/LLVMIR/openmp-llvm.mlir12
-rw-r--r--mlir/test/Target/LLVMIR/xevm.mlir32
-rw-r--r--mlir/test/lib/Dialect/Test/TestDialect.h1
-rw-r--r--mlir/test/lib/Dialect/Test/TestOps.h1
-rw-r--r--mlir/test/lib/Dialect/Test/TestOps.td5
-rw-r--r--mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp24
-rw-r--r--mlir/test/mlir-tblgen/op-decl-and-defs.td9
-rw-r--r--mlir/test/python/dialects/transform_vector_ext.py2
-rw-r--r--mlir/test/python/python_pass.py20
-rw-r--r--mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp6
-rw-r--r--mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp20
-rw-r--r--mlir/unittests/Bytecode/BytecodeTest.cpp25
-rw-r--r--mlir/unittests/CMakeLists.txt1
-rw-r--r--mlir/unittests/Tools/CMakeLists.txt1
-rw-r--r--mlir/unittests/Tools/lsp-server-support/CMakeLists.txt7
-rw-r--r--mlir/unittests/Tools/lsp-server-support/Protocol.cpp51
-rw-r--r--mlir/unittests/Tools/lsp-server-support/Transport.cpp205
150 files changed, 5848 insertions, 4860 deletions
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 85c8027..2b88355 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -99,62 +99,6 @@ function(declare_mlir_python_sources name)
endif()
endfunction()
-# Function: generate_type_stubs
-# Turns on automatic type stub generation (via nanobind's stubgen) for extension modules.
-# Arguments:
-# MODULE_NAME: The name of the extension module as specified in declare_mlir_python_extension.
-# DEPENDS_TARGET: The dso target corresponding to the extension module
-# (e.g., something like StandalonePythonModules.extension._standaloneDialectsNanobind.dso)
-# MLIR_DEPENDS_TARGET: The dso target corresponding to the main/core extension module
-# (e.g., something like StandalonePythonModules.extension._mlir.dso)
-# OUTPUT_DIR: The root output directory to emit the type stubs into.
-# Outputs:
-# NB_STUBGEN_CUSTOM_TARGET: The target corresponding to generation which other targets can depend on.
-function(generate_type_stubs MODULE_NAME DEPENDS_TARGET MLIR_DEPENDS_TARGET OUTPUT_DIR)
- cmake_parse_arguments(ARG
- ""
- ""
- "OUTPUTS"
- ${ARGN})
- if(EXISTS ${nanobind_DIR}/../src/stubgen.py)
- set(NB_STUBGEN "${nanobind_DIR}/../src/stubgen.py")
- elseif(EXISTS ${nanobind_DIR}/../stubgen.py)
- set(NB_STUBGEN "${nanobind_DIR}/../stubgen.py")
- else()
- message(FATAL_ERROR "generate_type_stubs(): could not locate 'stubgen.py'!")
- endif()
- file(REAL_PATH "${NB_STUBGEN}" NB_STUBGEN)
-
- set(_module "${MLIR_PYTHON_PACKAGE_PREFIX}._mlir_libs.${MODULE_NAME}")
- file(REAL_PATH "${MLIR_BINARY_DIR}/${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}/.." _import_path)
-
- set(NB_STUBGEN_CMD
- "${Python_EXECUTABLE}"
- "${NB_STUBGEN}"
- --module
- "${_module}"
- -i
- "${_import_path}"
- --recursive
- --include-private
- --output-dir
- "${OUTPUT_DIR}")
-
- list(TRANSFORM ARG_OUTPUTS PREPEND "${OUTPUT_DIR}/" OUTPUT_VARIABLE _generated_type_stubs)
- add_custom_command(
- OUTPUT ${_generated_type_stubs}
- COMMAND ${NB_STUBGEN_CMD}
- WORKING_DIRECTORY "${CMAKE_CURRENT_FUNCTION_LIST_DIR}"
- DEPENDS
- "${MLIR_DEPENDS_TARGET}.extension._mlir.dso"
- "${MLIR_DEPENDS_TARGET}.sources.MLIRPythonSources.Core.Python"
- "${DEPENDS_TARGET}"
- )
- set(_name "MLIRPythonModuleStubs_${_module}")
- add_custom_target("${_name}" ALL DEPENDS ${_generated_type_stubs})
- set(NB_STUBGEN_CUSTOM_TARGET "${_name}" PARENT_SCOPE)
-endfunction()
-
# Function: declare_mlir_python_extension
# Declares a buildable python extension from C++ source files. The built
# module is considered a python source file and included as everything else.
@@ -171,12 +115,11 @@ endfunction()
# on. These will be collected for all extensions and put into an
# aggregate dylib that is linked against.
# PYTHON_BINDINGS_LIBRARY: Either pybind11 or nanobind.
-# GENERATE_TYPE_STUBS: List of generated type stubs expected from stubgen relative to _mlir_libs.
function(declare_mlir_python_extension name)
cmake_parse_arguments(ARG
""
"ROOT_DIR;MODULE_NAME;ADD_TO_PARENT;PYTHON_BINDINGS_LIBRARY"
- "SOURCES;PRIVATE_LINK_LIBS;EMBED_CAPI_LINK_LIBS;GENERATE_TYPE_STUBS"
+ "SOURCES;PRIVATE_LINK_LIBS;EMBED_CAPI_LINK_LIBS"
${ARGN})
if(NOT ARG_ROOT_DIR)
@@ -192,13 +135,12 @@ function(declare_mlir_python_extension name)
set_target_properties(${name} PROPERTIES
# Yes: Leading-lowercase property names are load bearing and the recommended
# way to do this: https://gitlab.kitware.com/cmake/cmake/-/issues/19261
- EXPORT_PROPERTIES "mlir_python_SOURCES_TYPE;mlir_python_EXTENSION_MODULE_NAME;mlir_python_EMBED_CAPI_LINK_LIBS;mlir_python_DEPENDS;mlir_python_BINDINGS_LIBRARY;mlir_python_GENERATE_TYPE_STUBS"
+ EXPORT_PROPERTIES "mlir_python_SOURCES_TYPE;mlir_python_EXTENSION_MODULE_NAME;mlir_python_EMBED_CAPI_LINK_LIBS;mlir_python_DEPENDS;mlir_python_BINDINGS_LIBRARY"
mlir_python_SOURCES_TYPE extension
mlir_python_EXTENSION_MODULE_NAME "${ARG_MODULE_NAME}"
mlir_python_EMBED_CAPI_LINK_LIBS "${ARG_EMBED_CAPI_LINK_LIBS}"
mlir_python_DEPENDS ""
mlir_python_BINDINGS_LIBRARY "${ARG_PYTHON_BINDINGS_LIBRARY}"
- mlir_python_GENERATE_TYPE_STUBS "${ARG_GENERATE_TYPE_STUBS}"
)
# Set the interface source and link_libs properties of the target
@@ -301,32 +243,6 @@ function(add_mlir_python_modules name)
)
add_dependencies(${modules_target} ${_extension_target})
mlir_python_setup_extension_rpath(${_extension_target})
- get_target_property(_generate_type_stubs ${sources_target} mlir_python_GENERATE_TYPE_STUBS)
- if(_generate_type_stubs)
- generate_type_stubs(
- ${_module_name}
- ${_extension_target}
- ${name}
- "${CMAKE_CURRENT_SOURCE_DIR}/mlir/_mlir_libs"
- OUTPUTS "${_generate_type_stubs}"
- )
- add_dependencies("${modules_target}" "${NB_STUBGEN_CUSTOM_TARGET}")
- set(_stubgen_target "${MLIR_PYTHON_PACKAGE_PREFIX}.${_module_name}_type_stub_gen")
- declare_mlir_python_sources(
- ${_stubgen_target}
- ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir/_mlir_libs"
- ADD_TO_PARENT "${sources_target}"
- SOURCES "${_generate_type_stubs}"
- )
- set(_pure_sources_target "${modules_target}.sources.${sources_target}_type_stub_gen")
- add_mlir_python_sources_target(${_pure_sources_target}
- INSTALL_COMPONENT ${modules_target}
- INSTALL_DIR "${ARG_INSTALL_PREFIX}/_mlir_libs"
- OUTPUT_DIRECTORY "${ARG_ROOT_PREFIX}/_mlir_libs"
- SOURCES_TARGETS ${_stubgen_target}
- )
- add_dependencies(${modules_target} ${_pure_sources_target})
- endif()
else()
message(SEND_ERROR "Unrecognized source type '${_source_type}' for python source target ${sources_target}")
return()
@@ -762,28 +678,26 @@ function(add_mlir_python_extension libname extname)
# the super project handle compile options as it wishes.
get_property(NB_LIBRARY_TARGET_NAME TARGET ${libname} PROPERTY LINK_LIBRARIES)
target_compile_options(${NB_LIBRARY_TARGET_NAME}
- PRIVATE
- -Wall -Wextra -Wpedantic
- -Wno-c++98-compat-extra-semi
- -Wno-cast-qual
- -Wno-covered-switch-default
- -Wno-deprecated-literal-operator
- -Wno-nested-anon-types
- -Wno-unused-parameter
- -Wno-zero-length-array
- ${eh_rtti_enable})
+ PRIVATE
+ -Wall -Wextra -Wpedantic
+ -Wno-c++98-compat-extra-semi
+ -Wno-cast-qual
+ -Wno-covered-switch-default
+ -Wno-nested-anon-types
+ -Wno-unused-parameter
+ -Wno-zero-length-array
+ ${eh_rtti_enable})
target_compile_options(${libname}
- PRIVATE
- -Wall -Wextra -Wpedantic
- -Wno-c++98-compat-extra-semi
- -Wno-cast-qual
- -Wno-covered-switch-default
- -Wno-deprecated-literal-operator
- -Wno-nested-anon-types
- -Wno-unused-parameter
- -Wno-zero-length-array
- ${eh_rtti_enable})
+ PRIVATE
+ -Wall -Wextra -Wpedantic
+ -Wno-c++98-compat-extra-semi
+ -Wno-cast-qual
+ -Wno-covered-switch-default
+ -Wno-nested-anon-types
+ -Wno-unused-parameter
+ -Wno-zero-length-array
+ ${eh_rtti_enable})
endif()
if(APPLE)
diff --git a/mlir/docs/BytecodeFormat.md b/mlir/docs/BytecodeFormat.md
index 9846df8..f50ddeb 100644
--- a/mlir/docs/BytecodeFormat.md
+++ b/mlir/docs/BytecodeFormat.md
@@ -125,7 +125,9 @@ lazy-loading, and more. Each section contains a Section ID, whose high bit
indicates if the section has alignment requirements, a length (which allows for
skipping over the section), and an optional alignment. When an alignment is
present, a variable number of padding bytes (0xCB) may appear before the section
-data. The alignment of a section must be a power of 2. The input bytecode buffer must satisfy the same alignment requirements as those of every section.
+data. The alignment of a section must be a power of 2.
+The input bytecode buffer must satisfy the same alignment requirements as
+those of every section.
## MLIR Encoding
diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index 7070351..5ae3515 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -285,9 +285,13 @@ conversions. A context-unaware conversion function converts a `Type` into a
`Type`. A context-aware conversion function converts a `Value` into a type. The
latter allows users to customize type conversion rules based on the IR.
-Note: When there is at least one context-aware type conversion function, the
-result of type conversions can no longer be cached, which can increase
-compilation time. Use this feature with caution!
+Note: context-aware type conversion functions impact the ability of the
+framework to cache the conversion result. In the absence of a context-aware
+conversion, all context-free type conversions can be cached. Otherwise only the
+context-free conversions added after a context-aware type conversion can be
+cached (conversions are applied in reverse order).
+As such it is advised to add context-aware conversions as early as possible in
+the sequence of `addConversion` calls (so that they apply last).
A `materialization` describes how a list of values should be converted to a
list of values with specific types. An important distinction from a
diff --git a/mlir/docs/Dialects/IRDL.md b/mlir/docs/Dialects/IRDL.md
new file mode 100644
index 0000000..c09457a
--- /dev/null
+++ b/mlir/docs/Dialects/IRDL.md
@@ -0,0 +1,123 @@
+# 'irdl' Dialect
+
+[TOC]
+
+## Basics
+
+The IRDL (*Intermediate Representation Definition Language*) dialect allows
+defining MLIR dialects as MLIR programs. Nested operations are used to
+represent dialect structure: dialects contain operations, types and
+attributes, themselves containing type parameters, operands, results, etc.
+Each of those concepts are mapped to MLIR operations in the IRDL dialect, as
+shown in the example dialect below:
+
+```mlir
+irdl.dialect @cmath {
+ irdl.type @complex {
+ %0 = irdl.is f32
+ %1 = irdl.is f64
+ %2 = irdl.any_of(%0, %1)
+ irdl.parameters(%2)
+ }
+
+ irdl.operation @mul {
+ %0 = irdl.is f32
+ %1 = irdl.is f64
+ %2 = irdl.any_of(%0, %1)
+ %3 = irdl.parametric @cmath::@complex<%2>
+ irdl.operands(%3, %3)
+ irdl.results(%3)
+ }
+}
+```
+
+This program defines a `cmath` dialect that defines a `complex` type, and
+a `mul` operation. Both express constraints over their parameters using
+SSA constraint operations. Informally, one can see those SSA values as
+constraint variables that evaluate to a single type at constraint
+evaluation. For example, the result of the `irdl.any_of` stored in `%2`
+in the `mul` operation will collapse into either `f32` or `f64` for the
+entirety of this instance of `mul` constraint evaluation. As such,
+both operands and the result of `mul` must be of equal type (and not just
+satisfy the same constraint). For more information, see
+[constraints and combinators](#constraints-and-combinators).
+
+In order to simplify the dialect, IRDL variables are handles over
+`mlir::Attribute`. In order to support manipulating `mlir::Type`,
+IRDL wraps all types in an `mlir::TypeAttr` attribute.
+
+## Principles
+
+The core principles of IRDL are the following, in no particular order:
+
+- **Portability.** IRDL dialects should be self-contained, such that dialects
+ can be easily distributed with minimal assumptions on which compiler
+ infrastructure (or which commit of MLIR) is used.
+- **Introspection.** The IRDL dialect definition mechanism should strive
+ towards offering as much introspection abilities as possible. Dialects
+ should be as easy to manipulate, generate, and analyze as possible.
+- **Runtime declaration support**. The specification of IRDL dialects should
+ offer the ability to have them be loaded at runtime, via dynamic registration
+ or JIT compilation. Compatibility with dynamic workflows should not hinder
+ the ability to compile IRDL dialects into ahead-of-time declarations.
+- **Reliability.** Concepts in IRDL should be consistent and predictable, with
+ as much focus on high-level simplicity as possible. Consequently, IRDL
+ definitions that verify should work out of the box, and those that do not
+ verify should provide clear and understandable errors in all circumstances.
+
+While IRDL simplifies IR definition, it remains an IR itself and thus does not
+require to be comfortably user-writeable.
+
+## Constraints and combinators
+
+Attribute, type and operation verifiers are expressed in terms of constraint
+variables. Constraint variables are defined as the results of constraint
+operations (like `irdl.is` or constraint combinators).
+
+Constraint variables act as variables: as such, matching against the same
+constraint variable multiple times can only succeed if the matching type or
+attribute is the same as the one that previously matched. In the following
+example:
+
+```mlir
+irdl.type @foo {
+ %ty = irdl.any_type
+ irdl.parameters(param1: %ty, param2: %ty)
+}
+```
+
+only types with two equal parameters will successfully match (`foo<i32, i32>`
+would match while `foo<i32, i64>` would fail, even though both i32 and i64
+individually satisfy the `irdl.any_type` constraint). This constraint variable
+mechanism allows to easily express a requirement on type or attribute equality.
+
+To declare more complex verifiers, IRDL provides constraint-combinator
+operations such as `irdl.any_of`, `irdl.all_of` or `irdl.parametric`. These
+combinators can be used to combine constraint variables into new constraint
+variables. Like all uses of constraint variables, their constraint variable
+operands enforce equality of matched types of attributes as explained in the
+previous paragraph.
+
+## Motivating use cases
+
+To illustrate the rationale behind IRDL, the following list describes examples
+of intended use cases for IRDL, in no particular order:
+
+- **Fuzzer generation.** With declarative verifier definitions, it is possible
+ to compile IRDL dialects into compiler fuzzers that generate only programs
+ passing verifiers.
+- **Portable dialects between compiler infrastructures.** Some compiler
+ infrastructures are independent from MLIR but are otherwise IR-compatible.
+ Portable IRDL dialects allow to share the dialect definitions between MLIR
+ and other compiler infrastructures without needing to maintain multiple
+ potentially out-of-sync definitions.
+- **Dialect simplification.** Because IRDL definitions can easily be
+ mechanically modified, it is possible to simplify the definition of dialects
+ based on which operations are actually used, leading to smaller compilers.
+- **SMT analysis.** Because IRDL dialect definitions are declarative, their
+ definition can be lowered to alternative representations like SMT, allowing
+ analysis of the behavior of transforms taking verifiers into account.
+
+## Operations
+
+[include "Dialects/IRDLOps.md"]
diff --git a/mlir/examples/standalone/python/CMakeLists.txt b/mlir/examples/standalone/python/CMakeLists.txt
index cb10518..a0eca9c 100644
--- a/mlir/examples/standalone/python/CMakeLists.txt
+++ b/mlir/examples/standalone/python/CMakeLists.txt
@@ -39,7 +39,6 @@ declare_mlir_python_extension(StandalonePythonSources.NanobindExtension
EMBED_CAPI_LINK_LIBS
StandaloneCAPI
PYTHON_BINDINGS_LIBRARY nanobind
- GENERATE_TYPE_STUBS
)
diff --git a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
index 2250db8..c7c405e1 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
@@ -229,6 +229,13 @@ private:
/// considered an external callable.
Operation *analysisScope;
+ /// Whether the analysis scope has a symbol table. This is used to avoid
+ /// resolving callables outside the analysis scope.
+ /// It is updated when recursing into a region in case where the top-level
+ /// operation does not have a symbol table, but one is encountered in a nested
+ /// region.
+ bool hasSymbolTable = false;
+
/// A symbol table used for O(1) symbol lookups during simplification.
SymbolTableCollection symbolTable;
};
diff --git a/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h b/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h
index 291b809..220da0a 100644
--- a/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h
+++ b/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h
@@ -13,6 +13,7 @@
#include <memory>
namespace mlir {
+class Pass;
class LLVMTypeConverter;
class ConversionTarget;
class RewritePatternSet;
@@ -42,16 +43,6 @@ void populateGpuToROCDLConversionPatterns(const LLVMTypeConverter &converter,
/// Configure target to convert from the GPU dialect to ROCDL.
void configureGpuToROCDLConversionLegality(ConversionTarget &target);
-/// Creates a pass that lowers GPU dialect operations to ROCDL counterparts. The
-/// index bitwidth used for the lowering of the device side index computations
-/// is configurable.
-std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
-createLowerGpuOpsToROCDLOpsPass(
- const std::string &chipset = "gfx900",
- unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout,
- bool useBarePtrCallConv = false,
- gpu::amd::Runtime runtime = gpu::amd::Runtime::Unknown);
-
} // namespace mlir
#endif // MLIR_CONVERSION_GPUTOROCDL_GPUTOROCDLPASS_H_
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 44dc1bc..1a37d05 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -624,7 +624,6 @@ def ConvertGpuOpsToNVVMOps : Pass<"convert-gpu-to-nvvm", "gpu::GPUModuleOp"> {
def ConvertGpuOpsToROCDLOps : Pass<"convert-gpu-to-rocdl", "gpu::GPUModuleOp"> {
let summary = "Generate ROCDL operations for gpu operations";
- let constructor = "mlir::createLowerGpuOpsToROCDLOpsPass()";
let dependentDialects = [
"ROCDL::ROCDLDialect",
"amdgpu::AMDGPUDialect",
diff --git a/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h b/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h
index 7ffdbd4..f591407 100644
--- a/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h
+++ b/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h
@@ -11,6 +11,7 @@
#include <memory>
namespace mlir {
+class ConversionTarget;
class DialectRegistry;
class LLVMTypeConverter;
class RewritePatternSet;
@@ -19,7 +20,8 @@ class Pass;
#define GEN_PASS_DECL_CONVERTXEVMTOLLVMPASS
#include "mlir/Conversion/Passes.h.inc"
-void populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns);
+void populateXeVMToLLVMConversionPatterns(ConversionTarget &target,
+ RewritePatternSet &patterns);
void registerConvertXeVMToLLVMInterface(DialectRegistry &registry);
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
index 0dd8de4..df4145d 100644
--- a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
@@ -153,9 +153,17 @@ public:
MemRefDependenceGraph(Block &block) : block(block) {}
- // Initializes the dependence graph based on operations in `block'.
- // Returns true on success, false otherwise.
- bool init();
+ // Initializes the data dependence graph by iterating over the operations of
+ // the MDG's `block`. A `Node` is created for every top-level op except for
+ // side-effect-free operations with zero results and no regions. Assigns each
+ // node in the graph a node id based on the order in block. Fails if certain
+ // kinds of operations, for which `Node` creation isn't supported, are
+ // encountered (unknown region holding ops). If `fullAffineDependences` is
+ // set, affine memory dependence analysis is performed before concluding that
+ // conflicting affine memory accesses lead to a dependence check; otherwise, a
+ // pair of conflicting affine memory accesses (where one of them is a store
+ // and they are to the same memref) always leads to an edge (conservatively).
+ bool init(bool fullAffineDependences = true);
// Returns the graph node for 'id'.
const Node *getNode(unsigned id) const;
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
index 1ef5370..e735651 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
@@ -12,7 +12,6 @@
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
-#include "mlir/Interfaces/CopyOpInterface.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 271b420..6724d4c 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -18,7 +18,6 @@ include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/SubsetOpInterface.td"
-include "mlir/Interfaces/CopyOpInterface.td"
class Bufferization_Op<string mnemonic, list<Trait> traits = []>
: Op<Bufferization_Dialect, mnemonic, traits>;
@@ -171,7 +170,6 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
//===----------------------------------------------------------------------===//
def Bufferization_CloneOp : Bufferization_Op<"clone", [
- CopyOpInterface,
MemoryEffectsOpInterface,
DeclareOpInterfaceMethods<AllocationOpInterface, ["buildDealloc", "buildClone"]>
]> {
diff --git a/mlir/include/mlir/Dialect/IRDL/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/IRDL/IR/CMakeLists.txt
index 9dcadec..1c08f0f 100644
--- a/mlir/include/mlir/Dialect/IRDL/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/IRDL/IR/CMakeLists.txt
@@ -1,5 +1,5 @@
add_mlir_dialect(IRDL irdl)
-add_mlir_doc(IRDLOps IRDL Dialects/ -gen-dialect-doc -dialect=irdl)
+add_mlir_doc(IRDLOps IRDLOps Dialects/ -gen-op-doc -dialect=irdl)
# Add IRDL interfaces
set(LLVM_TARGET_DEFINITIONS IRDLInterfaces.td)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index ac99b8a..a8c9ef7 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -1233,6 +1233,47 @@ def LLVM_TBAATagArrayAttr
}
//===----------------------------------------------------------------------===//
+// MMRATagAttr
+//===----------------------------------------------------------------------===//
+
+def LLVM_MMRATagAttr : LLVM_Attr<"MMRATag", "mmra_tag"> {
+ let parameters = (ins
+ StringRefParameter<>:$prefix,
+ StringRefParameter<>:$suffix
+ );
+
+ let summary = "MLIR wrapper around a prefix:suffix MMRA tag";
+
+ let description = [{
+ Defines a single memory model relaxation annotation (MMRA) entry
+ with prefix `$prefix` and suffix `$suffix`. This corresponds directly
+ to a LLVM `!{prefix, suffix}` metadata tuple, which is often written
+ `prefix:shuffix` as shorthand.
+
+ Example:
+ ```mlir
+ #mmra_tag = #llvm.mmmra_tag<"amdgpu-synchronize-as":"local">
+ #mmra_tag1 = #llvm.mmra_tag<"foo":"bar">
+ ```
+
+ Either one MMRA tag or an array of them may be added to any LLVM
+ operation that operates on memory.
+
+ ```mlir
+ %v = llvm.load %ptr {llvm.mmra = #mmra_tag} : !llvm.ptr -> i8
+ llvm.store %v, %ptr2 {llvm.mmra [#mmra_tag, #mmra_tag1]} : i8, !llvm.ptr
+ ```
+
+ See the following link for more details:
+ https://llvm.org/docs/MemoryModelRelaxationAnnotations.html
+ }];
+
+ let assemblyFormat = "`<` $prefix `` `:` `` $suffix `>`";
+
+ let genMnemonicAlias = 1;
+}
+
+//===----------------------------------------------------------------------===//
// ConstantRangeAttr
//===----------------------------------------------------------------------===//
def LLVM_ConstantRangeAttr : LLVM_Attr<"ConstantRange", "constant_range"> {
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td
index ab0462f..d2d7131 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td
@@ -36,6 +36,7 @@ def LLVM_Dialect : Dialect {
static StringRef getIdentAttrName() { return "llvm.ident"; }
static StringRef getModuleFlags() { return "llvm.module.flags"; }
static StringRef getCommandlineAttrName() { return "llvm.commandline"; }
+ static StringRef getMmraAttrName() { return "llvm.mmra"; }
/// Names of llvm parameter attributes.
static StringRef getAlignAttrName() { return "llvm.align"; }
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
index eb4e381..9de6d8f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
@@ -22,7 +22,6 @@
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
-#include "mlir/Interfaces/CopyOpInterface.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index a19cce4..8f3232f 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2017,8 +2017,8 @@ def TileReductionUsingForallOp :
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes,
OptionalAttr<DeviceMappingArrayAttr>:$mapping);
let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
- TransformHandleTypeInterface:$split_linalg_op,
- TransformHandleTypeInterface:$combining_linalg_op,
+ TransformHandleTypeInterface:$split_op,
+ TransformHandleTypeInterface:$combining_op,
TransformHandleTypeInterface:$forall_op);
let builders = [
@@ -2042,7 +2042,7 @@ def TileReductionUsingForallOp :
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::transform::TransformRewriter &rewriter,
- ::mlir::linalg::LinalgOp target,
+ Operation *target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
}];
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
index ac383ab4..bdec699 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -16,7 +16,6 @@
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
-#include "mlir/Interfaces/CopyOpInterface.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index d6b7a97..513a9a1 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -13,7 +13,6 @@ include "mlir/Dialect/Arith/IR/ArithBase.td"
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
-include "mlir/Interfaces/CopyOpInterface.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/MemorySlotInterfaces.td"
@@ -530,7 +529,7 @@ def MemRef_CastOp : MemRef_Op<"cast", [
// CopyOp
//===----------------------------------------------------------------------===//
-def CopyOp : MemRef_Op<"copy", [CopyOpInterface, SameOperandsElementType,
+def CopyOp : MemRef_Op<"copy", [SameOperandsElementType,
SameOperandsShape]> {
let description = [{
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
index faf820d..6a92b13 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
@@ -40,7 +40,7 @@ struct DeviceTypeClauseOps {
/// Clauses that correspond to operations other than omp.target, but might have
/// to be evaluated outside of a parent target region.
using HostEvaluatedOperands =
- detail::Clauses<LoopRelatedClauseOps, NumTeamsClauseOps,
+ detail::Clauses<CollapseClauseOps, LoopRelatedClauseOps, NumTeamsClauseOps,
NumThreadsClauseOps, ThreadLimitClauseOps>;
// TODO: Add `indirect` clause.
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 311c57f..5f40abe 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -210,6 +210,23 @@ class OpenMP_BindClauseSkip<
def OpenMP_BindClause : OpenMP_BindClauseSkip<>;
//===----------------------------------------------------------------------===//
+// V5.2: [4.4.3] `collapse` clause
+//===----------------------------------------------------------------------===//
+
+class OpenMP_CollapseClauseSkip<
+ bit traits = false, bit arguments = false, bit assemblyFormat = false,
+ bit description = false, bit extraClassDeclaration = false
+ > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+ extraClassDeclaration> {
+ let arguments = (ins
+ ConfinedAttr<DefaultValuedOptionalAttr<I64Attr, "1">, [IntMinValue<1>]>
+ :$collapse_num_loops
+ );
+}
+
+def OpenMP_CollapseClause : OpenMP_CollapseClauseSkip<>;
+
+//===----------------------------------------------------------------------===//
// V5.2: [5.7.2] `copyprivate` clause
//===----------------------------------------------------------------------===//
@@ -1386,6 +1403,22 @@ class OpenMP_ThreadLimitClauseSkip<
def OpenMP_ThreadLimitClause : OpenMP_ThreadLimitClauseSkip<>;
//===----------------------------------------------------------------------===//
+// V5.2: [9.1.1] `sizes` clause
+//===----------------------------------------------------------------------===//
+
+class OpenMP_TileSizesClauseSkip<
+ bit traits = false, bit arguments = false, bit assemblyFormat = false,
+ bit description = false, bit extraClassDeclaration = false
+ > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+ extraClassDeclaration> {
+ let arguments = (ins
+ OptionalAttr<DenseI64ArrayAttr>:$tile_sizes
+ );
+}
+
+def OpenMP_TileSizesClause : OpenMP_TileSizesClauseSkip<>;
+
+//===----------------------------------------------------------------------===//
// V5.2: [12.1] `untied` clause
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 2548a8a..830b36f 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -614,13 +614,18 @@ def WorkshareLoopWrapperOp : OpenMP_Op<"workshare.loop_wrapper", traits = [
def LoopNestOp : OpenMP_Op<"loop_nest", traits = [
RecursiveMemoryEffects, SameVariadicOperandSize
], clauses = [
- OpenMP_LoopRelatedClause
+ OpenMP_CollapseClause,
+ OpenMP_LoopRelatedClause,
+ OpenMP_TileSizesClause
], singleRegion = true> {
let summary = "rectangular loop nest";
let description = [{
- This operation represents a collapsed rectangular loop nest. For each
- rectangular loop of the nest represented by an instance of this operation,
- lower and upper bounds, as well as a step variable, must be defined.
+ This operation represents a rectangular loop nest which may be collapsed
+ and/or tiled. For each rectangular loop of the nest represented by an
+ instance of this operation, lower and upper bounds, as well as a step
+ variable, must be defined. The collapse clause specifies how many loops
+ that should be collapsed (1 if no collapse is done) after any tiling is
+ performed. The tiling sizes is represented by the tile sizes clause.
The lower and upper bounds specify a half-open range: the range includes the
lower bound but does not include the upper bound. If the `loop_inclusive`
@@ -633,7 +638,7 @@ def LoopNestOp : OpenMP_Op<"loop_nest", traits = [
`loop_steps` arguments.
```mlir
- omp.loop_nest (%i1, %i2) : i32 = (%c0, %c0) to (%c10, %c10) step (%c1, %c1) {
+ omp.loop_nest (%i1, %i2) : i32 = (%c0, %c0) to (%c10, %c10) step (%c1, %c1) collapse(2) tiles(5,5) {
%a = load %arrA[%i1, %i2] : memref<?x?xf32>
%b = load %arrB[%i1, %i2] : memref<?x?xf32>
%sum = arith.addf %a, %b : f32
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 77e26cc..65ba7e0 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -972,10 +972,6 @@ def Vector_ScalableInsertOp :
VectorType getDestVectorType() {
return ::llvm::cast<VectorType>(getDest().getType());
}
- /// Wrapper for getResult, which replaced getRes.
- [[deprecated("Use getResult instead!")]] ::mlir::Value getRes() {
- return getResult();
- }
}];
}
@@ -1027,10 +1023,6 @@ def Vector_ScalableExtractOp :
VectorType getResultVectorType() {
return ::llvm::cast<VectorType>(getResult().getType());
}
- /// Wrapper for getResult, which replaced getRes.
- [[deprecated("Use getResult instead!")]] ::mlir::Value getRes() {
- return getResult();
- }
}];
}
@@ -1085,10 +1077,6 @@ def Vector_InsertStridedSliceOp :
return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
});
}
- /// Wrapper for getResult, which replaced getRes.
- [[deprecated("Use getResult instead!")]] ::mlir::Value getRes() {
- return getResult();
- }
}];
let hasFolder = 1;
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 07a4117..72a69a0 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -265,6 +265,17 @@ def ApplyUnrollFromElementsPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+def ApplyUnrollToElementsPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.vector.unroll_to_elements",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that vector to_elements operations should be unrolled
+ along the outermost dimension.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
def ApplyLowerScanPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.lower_scan",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 47f9611..f56124c 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -313,6 +313,12 @@ void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns,
/// Populate the pattern set with the following patterns:
///
+/// [UnrollToElements]
+void populateVectorToElementsLoweringPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
+/// Populate the pattern set with the following patterns:
+///
/// [ContractionOpToMatmulOpLowering]
/// Lowers `vector.contract` to `llvm.intr.matrix.multiply`.
///
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index ace2699..97163c4 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -255,6 +255,12 @@ using UnrollVectorOpFn =
LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter,
UnrollVectorOpFn unrollFn);
+/// Generic utility for unrolling values of type vector<NxAxBx...>
+/// to N values of type vector<AxBx...> using vector.extract. If the input
+/// is rank-1 or has leading scalable dimension, failure is returned.
+FailureOr<SmallVector<Value>> unrollVectorValue(TypedValue<VectorType>,
+ RewriterBase &);
+
} // namespace vector
/// Constructs a permutation map of invariant memref indices to vector
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index 20cc267..2add220 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -1,7 +1,6 @@
add_mlir_interface(CallInterfaces)
add_mlir_interface(CastInterfaces)
add_mlir_interface(ControlFlowInterfaces)
-add_mlir_interface(CopyOpInterface)
add_mlir_interface(DerivedAttributeOpInterface)
add_mlir_interface(DestinationStyleOpInterface)
add_mlir_interface(FunctionInterfaces)
diff --git a/mlir/include/mlir/Interfaces/CopyOpInterface.h b/mlir/include/mlir/Interfaces/CopyOpInterface.h
deleted file mode 100644
index 2f38eb3..0000000
--- a/mlir/include/mlir/Interfaces/CopyOpInterface.h
+++ /dev/null
@@ -1,21 +0,0 @@
-//===- CopyOpInterface.h - copy operations interface ----------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements the operation interface for copy-like operations.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_INTERFACES_COPYOPINTERFACE_H_
-#define MLIR_INTERFACES_COPYOPINTERFACE_H_
-
-#include "mlir/IR/OpDefinition.h"
-
-/// Include the generated interface declarations.
-#include "mlir/Interfaces/CopyOpInterface.h.inc"
-
-#endif // MLIR_INTERFACES_COPYOPINTERFACE_H_
diff --git a/mlir/include/mlir/Interfaces/CopyOpInterface.td b/mlir/include/mlir/Interfaces/CopyOpInterface.td
deleted file mode 100644
index f6c5a6f..0000000
--- a/mlir/include/mlir/Interfaces/CopyOpInterface.td
+++ /dev/null
@@ -1,38 +0,0 @@
-//===- CopyOpInterface.td - Copy operation interface -------*- tablegen -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Defines the interface for copy-like operations.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_INTERFACES_COPYOPINTERFACE
-#define MLIR_INTERFACES_COPYOPINTERFACE
-
-include "mlir/IR/OpBase.td"
-
-def CopyOpInterface : OpInterface<"CopyOpInterface"> {
- let description = [{
- A copy-like operation is one that copies from source value to target value.
- }];
- let cppNamespace = "::mlir";
-
- let methods = [
- InterfaceMethod<
- /*desc=*/"Returns the source value for this copy operation",
- /*retTy=*/"::mlir::Value",
- /*methodName=*/"getSource"
- >,
- InterfaceMethod<
- /*desc=*/"Returns the target value for this copy operation",
- /*retTy=*/"::mlir::Value",
- /*methodName=*/"getTarget"
- >
- ];
-}
-
-#endif // MLIR_INTERFACES_COPYOPINTERFACE
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index d168735..58852239 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -217,7 +217,7 @@ public:
/// `closedUB` is set to "true", upper bounds are also closed.
static FailureOr<int64_t>
computeConstantBound(presburger::BoundType type, const Variable &var,
- StopConditionFn stopCondition = nullptr,
+ const StopConditionFn &stopCondition = nullptr,
bool closedUB = false);
/// Compute a constant delta between the given two values. Return "failure"
@@ -282,18 +282,18 @@ public:
///
/// Slice are non-overlapping if the above constraint is not satisfied for
/// at least one dimension.
- static FailureOr<bool> areOverlappingSlices(MLIRContext *ctx,
- HyperrectangularSlice slice1,
- HyperrectangularSlice slice2);
+ static FailureOr<bool>
+ areOverlappingSlices(MLIRContext *ctx, const HyperrectangularSlice &slice1,
+ const HyperrectangularSlice &slice2);
/// Return "true" if the given slices are guaranteed to be equivalent.
/// Return "false" if the given slices are guaranteed to be non-equivalent.
/// Return "failure" if unknown.
///
/// Slices are equivalent if their offsets, sizes and strices are equal.
- static FailureOr<bool> areEquivalentSlices(MLIRContext *ctx,
- HyperrectangularSlice slice1,
- HyperrectangularSlice slice2);
+ static FailureOr<bool>
+ areEquivalentSlices(MLIRContext *ctx, const HyperrectangularSlice &slice1,
+ const HyperrectangularSlice &slice2);
/// Add a bound for the given index-typed value or shaped value. This function
/// returns a builder that adds the bound.
@@ -326,7 +326,8 @@ protected:
/// An index-typed value or the dimension of a shaped-type value.
using ValueDim = std::pair<Value, int64_t>;
- ValueBoundsConstraintSet(MLIRContext *ctx, StopConditionFn stopCondition,
+ ValueBoundsConstraintSet(MLIRContext *ctx,
+ const StopConditionFn &stopCondition,
bool addConservativeSemiAffineBounds = false);
/// Return "true" if, based on the current state of the constraint system,
@@ -401,7 +402,8 @@ protected:
/// Insert the given affine map and its bound operands as a new column in the
/// constraint system. Return the position of the new column. Any operands
/// that were not analyzed yet are put on the worklist.
- int64_t insert(AffineMap map, ValueDimList operands, bool isSymbol = true);
+ int64_t insert(AffineMap map, const ValueDimList &operands,
+ bool isSymbol = true);
int64_t insert(const Variable &var, bool isSymbol = true);
/// Project out the given column in the constraint set.
diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td
index c72ca58..6838c16 100644
--- a/mlir/include/mlir/Interfaces/VectorInterfaces.td
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -187,12 +187,6 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
return inBounds;
}
- /// Wrapper for getBase, which replaced getSource.
- [[deprecated("Use getBase instead!")]]
- ::mlir::Value getSource() {
- return $_op.getBase();
- }
-
/// Return the number of leading shaped dimensions (of the "source" operand)
/// that do not participate in the permutation map.
unsigned getLeadingShapedRank() {
diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h
index 9e57037..f0514d8 100644
--- a/mlir/include/mlir/TableGen/Operator.h
+++ b/mlir/include/mlir/TableGen/Operator.h
@@ -323,21 +323,22 @@ public:
/// Requires: all result types are known.
const InferredResultType &getInferredResultType(int index) const;
- /// Pair consisting kind of argument and index into operands or attributes.
- struct OperandOrAttribute {
- enum class Kind { Operand, Attribute };
- OperandOrAttribute(Kind kind, int index) {
- packed = (index << 1) | (kind == Kind::Attribute);
+ /// Pair consisting kind of argument and index into operands, attributes, or
+ /// properties.
+ struct OperandAttrOrProp {
+ enum class Kind { Operand = 0x0, Attribute = 0x1, Property = 0x2 };
+ OperandAttrOrProp(Kind kind, int index) {
+ packed = (index << 2) | static_cast<int>(kind);
}
- int operandOrAttributeIndex() const { return (packed >> 1); }
- Kind kind() { return (packed & 0x1) ? Kind::Attribute : Kind::Operand; }
+ int operandOrAttributeIndex() const { return (packed >> 2); }
+ Kind kind() const { return static_cast<Kind>(packed & 0x3); }
private:
int packed;
};
- /// Returns the OperandOrAttribute corresponding to the index.
- OperandOrAttribute getArgToOperandOrAttribute(int index) const;
+ /// Returns the OperandAttrOrProp corresponding to the index.
+ OperandAttrOrProp getArgToOperandAttrOrProp(int index) const;
/// Returns the builders of this operation.
ArrayRef<Builder> getBuilders() const { return builders; }
@@ -405,8 +406,8 @@ private:
/// The argument with the same type as the result.
SmallVector<InferredResultType> resultTypeMapping;
- /// Map from argument to attribute or operand number.
- SmallVector<OperandOrAttribute, 4> attrOrOperandMapping;
+ /// Map from argument to attribute, property, or operand number.
+ SmallVector<OperandAttrOrProp, 4> attrPropOrOperandMapping;
/// The builders of this operator.
SmallVector<Builder> builders;
diff --git a/mlir/include/mlir/Tools/lsp-server-support/Logging.h b/mlir/include/mlir/Tools/lsp-server-support/Logging.h
deleted file mode 100644
index 9b090d0..0000000
--- a/mlir/include/mlir/Tools/lsp-server-support/Logging.h
+++ /dev/null
@@ -1,65 +0,0 @@
-//===- Logging.h - MLIR LSP Server Logging ----------------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_TOOLS_LSPSERVERSUPPORT_LOGGING_H
-#define MLIR_TOOLS_LSPSERVERSUPPORT_LOGGING_H
-
-#include "mlir/Support/LLVM.h"
-#include "llvm/Support/Debug.h"
-#include "llvm/Support/FormatVariadic.h"
-#include <memory>
-#include <mutex>
-
-namespace mlir {
-namespace lsp {
-
-/// This class represents the main interface for logging, and allows for
-/// filtering logging based on different levels of severity or significance.
-class Logger {
-public:
- /// The level of significance for a log message.
- enum class Level { Debug, Info, Error };
-
- /// Set the severity level of the logger.
- static void setLogLevel(Level logLevel);
-
- /// Initiate a log message at various severity levels. These should be called
- /// after a call to `initialize`.
- template <typename... Ts>
- static void debug(const char *fmt, Ts &&...vals) {
- log(Level::Debug, fmt, llvm::formatv(fmt, std::forward<Ts>(vals)...));
- }
- template <typename... Ts>
- static void info(const char *fmt, Ts &&...vals) {
- log(Level::Info, fmt, llvm::formatv(fmt, std::forward<Ts>(vals)...));
- }
- template <typename... Ts>
- static void error(const char *fmt, Ts &&...vals) {
- log(Level::Error, fmt, llvm::formatv(fmt, std::forward<Ts>(vals)...));
- }
-
-private:
- Logger() = default;
-
- /// Return the main logger instance.
- static Logger &get();
-
- /// Start a log message with the given severity level.
- static void log(Level logLevel, const char *fmt,
- const llvm::formatv_object_base &message);
-
- /// The minimum logging level. Messages with lower level are ignored.
- Level logLevel = Level::Error;
-
- /// A mutex used to guard logging.
- std::mutex mutex;
-};
-} // namespace lsp
-} // namespace mlir
-
-#endif // MLIR_TOOLS_LSPSERVERSUPPORT_LOGGING_H
diff --git a/mlir/include/mlir/Tools/lsp-server-support/Protocol.h b/mlir/include/mlir/Tools/lsp-server-support/Protocol.h
deleted file mode 100644
index cc06dbf..0000000
--- a/mlir/include/mlir/Tools/lsp-server-support/Protocol.h
+++ /dev/null
@@ -1,1257 +0,0 @@
-//===--- Protocol.h - Language Server Protocol Implementation ---*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file contains structs based on the LSP specification at
-// https://microsoft.github.io/language-server-protocol/specification
-//
-// This is not meant to be a complete implementation, new interfaces are added
-// when they're needed.
-//
-// Each struct has a toJSON and fromJSON function, that converts between
-// the struct and a JSON representation. (See JSON.h)
-//
-// Some structs also have operator<< serialization. This is for debugging and
-// tests, and is not generally machine-readable.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_TOOLS_LSPSERVERSUPPORT_PROTOCOL_H
-#define MLIR_TOOLS_LSPSERVERSUPPORT_PROTOCOL_H
-
-#include "mlir/Support/LLVM.h"
-#include "llvm/Support/JSON.h"
-#include "llvm/Support/SourceMgr.h"
-#include "llvm/Support/raw_ostream.h"
-#include <bitset>
-#include <optional>
-#include <string>
-#include <utility>
-#include <vector>
-
-namespace mlir {
-namespace lsp {
-
-enum class ErrorCode {
- // Defined by JSON RPC.
- ParseError = -32700,
- InvalidRequest = -32600,
- MethodNotFound = -32601,
- InvalidParams = -32602,
- InternalError = -32603,
-
- ServerNotInitialized = -32002,
- UnknownErrorCode = -32001,
-
- // Defined by the protocol.
- RequestCancelled = -32800,
- ContentModified = -32801,
- RequestFailed = -32803,
-};
-
-/// Defines how the host (editor) should sync document changes to the language
-/// server.
-enum class TextDocumentSyncKind {
- /// Documents should not be synced at all.
- None = 0,
-
- /// Documents are synced by always sending the full content of the document.
- Full = 1,
-
- /// Documents are synced by sending the full content on open. After that only
- /// incremental updates to the document are sent.
- Incremental = 2,
-};
-
-//===----------------------------------------------------------------------===//
-// LSPError
-//===----------------------------------------------------------------------===//
-
-/// This class models an LSP error as an llvm::Error.
-class LSPError : public llvm::ErrorInfo<LSPError> {
-public:
- std::string message;
- ErrorCode code;
- static char ID;
-
- LSPError(std::string message, ErrorCode code)
- : message(std::move(message)), code(code) {}
-
- void log(raw_ostream &os) const override {
- os << int(code) << ": " << message;
- }
- std::error_code convertToErrorCode() const override {
- return llvm::inconvertibleErrorCode();
- }
-};
-
-//===----------------------------------------------------------------------===//
-// URIForFile
-//===----------------------------------------------------------------------===//
-
-/// URI in "file" scheme for a file.
-class URIForFile {
-public:
- URIForFile() = default;
-
- /// Try to build a URIForFile from the given URI string.
- static llvm::Expected<URIForFile> fromURI(StringRef uri);
-
- /// Try to build a URIForFile from the given absolute file path and optional
- /// scheme.
- static llvm::Expected<URIForFile> fromFile(StringRef absoluteFilepath,
- StringRef scheme = "file");
-
- /// Returns the absolute path to the file.
- StringRef file() const { return filePath; }
-
- /// Returns the original uri of the file.
- StringRef uri() const { return uriStr; }
-
- /// Return the scheme of the uri.
- StringRef scheme() const;
-
- explicit operator bool() const { return !filePath.empty(); }
-
- friend bool operator==(const URIForFile &lhs, const URIForFile &rhs) {
- return lhs.filePath == rhs.filePath;
- }
- friend bool operator!=(const URIForFile &lhs, const URIForFile &rhs) {
- return !(lhs == rhs);
- }
- friend bool operator<(const URIForFile &lhs, const URIForFile &rhs) {
- return lhs.filePath < rhs.filePath;
- }
-
- /// Register a supported URI scheme. The protocol supports `file` by default,
- /// so this is only necessary for any additional schemes that a server wants
- /// to support.
- static void registerSupportedScheme(StringRef scheme);
-
-private:
- explicit URIForFile(std::string &&filePath, std::string &&uriStr)
- : filePath(std::move(filePath)), uriStr(uriStr) {}
-
- std::string filePath;
- std::string uriStr;
-};
-
-/// Add support for JSON serialization.
-llvm::json::Value toJSON(const URIForFile &value);
-bool fromJSON(const llvm::json::Value &value, URIForFile &result,
- llvm::json::Path path);
-raw_ostream &operator<<(raw_ostream &os, const URIForFile &value);
-
-//===----------------------------------------------------------------------===//
-// ClientCapabilities
-//===----------------------------------------------------------------------===//
-
-struct ClientCapabilities {
- /// Client supports hierarchical document symbols.
- /// textDocument.documentSymbol.hierarchicalDocumentSymbolSupport
- bool hierarchicalDocumentSymbol = false;
-
- /// Client supports CodeAction return value for textDocument/codeAction.
- /// textDocument.codeAction.codeActionLiteralSupport.
- bool codeActionStructure = false;
-
- /// Client supports server-initiated progress via the
- /// window/workDoneProgress/create method.
- ///
- /// window.workDoneProgress
- bool workDoneProgress = false;
-};
-
-/// Add support for JSON serialization.
-bool fromJSON(const llvm::json::Value &value, ClientCapabilities &result,
- llvm::json::Path path);
-
-//===----------------------------------------------------------------------===//
-// ClientInfo
-//===----------------------------------------------------------------------===//
-
-struct ClientInfo {
- /// The name of the client as defined by the client.
- std::string name;
-
- /// The client's version as defined by the client.
- std::optional<std::string> version;
-};
-
-/// Add support for JSON serialization.
-bool fromJSON(const llvm::json::Value &value, ClientInfo &result,
- llvm::json::Path path);
-
-//===----------------------------------------------------------------------===//
-// InitializeParams
-//===----------------------------------------------------------------------===//
-
-enum class TraceLevel {
- Off = 0,
- Messages = 1,
- Verbose = 2,
-};
-
-/// Add support for JSON serialization.
-bool fromJSON(const llvm::json::Value &value, TraceLevel &result,
- llvm::json::Path path);
-
-struct InitializeParams {
- /// The capabilities provided by the client (editor or tool).
- ClientCapabilities capabilities;
-
- /// Information about the client.
- std::optional<ClientInfo> clientInfo;
-
- /// The initial trace setting. If omitted trace is disabled ('off').
- std::optional<TraceLevel> trace;
-};
-
-/// Add support for JSON serialization.
-bool fromJSON(const llvm::json::Value &value, InitializeParams &result,
- llvm::json::Path path);
-
-//===----------------------------------------------------------------------===//
-// InitializedParams
-//===----------------------------------------------------------------------===//
-
-struct NoParams {};
-inline bool fromJSON(const llvm::json::Value &, NoParams &, llvm::json::Path) {
- return true;
-}
-using InitializedParams = NoParams;
-
-//===----------------------------------------------------------------------===//
-// TextDocumentItem
-//===----------------------------------------------------------------------===//
-
-struct TextDocumentItem {
- /// The text document's URI.
- URIForFile uri;
-
- /// The text document's language identifier.
- std::string languageId;
-
- /// The content of the opened text document.
- std::string text;
-
- /// The version number of this document.
- int64_t version;
-};
-
-/// Add support for JSON serialization.
-bool fromJSON(const llvm::json::Value &value, TextDocumentItem &result,
- llvm::json::Path path);
-
-//===----------------------------------------------------------------------===//
-// TextDocumentIdentifier
-//===----------------------------------------------------------------------===//
-
-struct TextDocumentIdentifier {
- /// The text document's URI.
- URIForFile uri;
-};
-
-/// Add support for JSON serialization.
-llvm::json::Value toJSON(const TextDocumentIdentifier &value);
-bool fromJSON(const llvm::json::Value &value, TextDocumentIdentifier &result,
- llvm::json::Path path);
-
-//===----------------------------------------------------------------------===//
-// VersionedTextDocumentIdentifier
-//===----------------------------------------------------------------------===//
-
-struct VersionedTextDocumentIdentifier {
- /// The text document's URI.
- URIForFile uri;
- /// The version number of this document.
- int64_t version;
-};
-
-/// Add support for JSON serialization.
-llvm::json::Value toJSON(const VersionedTextDocumentIdentifier &value);
-bool fromJSON(const llvm::json::Value &value,
- VersionedTextDocumentIdentifier &result, llvm::json::Path path);
-
-//===----------------------------------------------------------------------===//
-// Position
-//===----------------------------------------------------------------------===//
-
-struct Position {
- Position(int line = 0, int character = 0)
- : line(line), character(character) {}
-
- /// Construct a position from the given source location.
- Position(llvm::SourceMgr &mgr, SMLoc loc) {
- std::pair<unsigned, unsigned> lineAndCol = mgr.getLineAndColumn(loc);
- line = lineAndCol.first - 1;
- character = lineAndCol.second - 1;
- }
-
- /// Line position in a document (zero-based).
- int line = 0;
-
- /// Character offset on a line in a document (zero-based).
- int character = 0;
-
- friend bool operator==(const Position &lhs, const Position &rhs) {
- return std::tie(lhs.line, lhs.character) ==
- std::tie(rhs.line, rhs.character);
- }
- friend bool operator!=(const Position &lhs, const Position &rhs) {
- return !(lhs == rhs);
- }
- friend bool operator<(const Position &lhs, const Position &rhs) {
- return std::tie(lhs.line, lhs.character) <
- std::tie(rhs.line, rhs.character);
- }
- friend bool operator<=(const Position &lhs, const Position &rhs) {
- return std::tie(lhs.line, lhs.character) <=
- std::tie(rhs.line, rhs.character);
- }
-
- /// Convert this position into a source location in the main file of the given
- /// source manager.
- SMLoc getAsSMLoc(llvm::SourceMgr &mgr) const {
- return mgr.FindLocForLineAndColumn(mgr.getMainFileID(), line + 1,
- character + 1);
- }
-};
-
-/// Add support for JSON serialization.
-bool fromJSON(const llvm::json::Value &value, Position &result,
- llvm::json::Path path);
-llvm::json::Value toJSON(const Position &value);
-raw_ostream &operator<<(raw_ostream &os, const Position &value);
-
-//===----------------------------------------------------------------------===//
-// Range
-//===----------------------------------------------------------------------===//
-
-struct Range {
- Range() = default;
- Range(Position start, Position end) : start(start), end(end) {}
- Range(Position loc) : Range(loc, loc) {}
-
- /// Construct a range from the given source range.
- Range(llvm::SourceMgr &mgr, SMRange range)
- : Range(Position(mgr, range.Start), Position(mgr, range.End)) {}
-
- /// The range's start position.
- Position start;
-
- /// The range's end position.
- Position end;
-
- friend bool operator==(const Range &lhs, const Range &rhs) {
- return std::tie(lhs.start, lhs.end) == std::tie(rhs.start, rhs.end);
- }
- friend bool operator!=(const Range &lhs, const Range &rhs) {
- return !(lhs == rhs);
- }
- friend bool operator<(const Range &lhs, const Range &rhs) {
- return std::tie(lhs.start, lhs.end) < std::tie(rhs.start, rhs.end);
- }
-
- bool contains(Position pos) const { return start <= pos && pos < end; }
- bool contains(Range range) const {
- return start <= range.start && range.end <= end;
- }
-
- /// Convert this range into a source range in the main file of the given
- /// source manager.
- SMRange getAsSMRange(llvm::SourceMgr &mgr) const {
- SMLoc startLoc = start.getAsSMLoc(mgr);
- SMLoc endLoc = end.getAsSMLoc(mgr);
- // Check that the start and end locations are valid.
- if (!startLoc.isValid() || !endLoc.isValid() ||
- startLoc.getPointer() > endLoc.getPointer())
- return SMRange();
- return SMRange(startLoc, endLoc);
- }
-};
-
-/// Add support for JSON serialization.
-bool fromJSON(const llvm::json::Value &value, Range &result,
- llvm::json::Path path);
-llvm::json::Value toJSON(const Range &value);
-raw_ostream &operator<<(raw_ostream &os, const Range &value);
-
-//===----------------------------------------------------------------------===//
-// Location
-//===----------------------------------------------------------------------===//
-
-struct Location {
- Location() = default;
- Location(const URIForFile &uri, Range range) : uri(uri), range(range) {}
-
- /// Construct a Location from the given source range.
- Location(const URIForFile &uri, llvm::SourceMgr &mgr, SMRange range)
- : Location(uri, Range(mgr, range)) {}
-
- /// The text document's URI.
- URIForFile uri;
- Range range;
-
- friend bool operator==(const Location &lhs, const Location &rhs) {
- return lhs.uri == rhs.uri && lhs.range == rhs.range;
- }
-
- friend bool operator!=(const Location &lhs, const Location &rhs) {
- return !(lhs == rhs);
- }
-
- friend bool operator<(const Location &lhs, const Location &rhs) {
- return std::tie(lhs.uri, lhs.range) < std::tie(rhs.uri, rhs.range);
- }
-};
-
-/// Add support for JSON serialization.
-bool fromJSON(const llvm::json::Value &value, Location &result,
- llvm::json::Path path);
-llvm::json::Value toJSON(const Location &value);
-raw_ostream &operator<<(raw_ostream &os, const Location &value);
-
-//===----------------------------------------------------------------------===//
-// TextDocumentPositionParams
-//===----------------------------------------------------------------------===//
-
-struct TextDocumentPositionParams {
- /// The text document.
- TextDocumentIdentifier textDocument;
-
- /// The position inside the text document.
- Position position;
-};
-
-/// Add support for JSON serialization.
-bool fromJSON(const llvm::json::Value &value,
- TextDocumentPositionParams &result, llvm::json::Path path);
-
-//===----------------------------------------------------------------------===//
-// ReferenceParams
-//===----------------------------------------------------------------------===//
-
-struct ReferenceContext {
- /// Include the declaration of the current symbol.
- bool includeDeclaration = false;
-};
-
-/// Add support for JSON serialization.
-bool fromJSON(const llvm::json::Value &value, ReferenceContext &result,
- llvm::json::Path path);
-
-struct ReferenceParams : public TextDocumentPositionParams {
- ReferenceContext context;
-};
-
-/// Add support for JSON serialization.
-bool fromJSON(const llvm::json::Value &value, ReferenceParams &result,
- llvm::json::Path path);
-
-//===----------------------------------------------------------------------===//
-// DidOpenTextDocumentParams
-//===----------------------------------------------------------------------===//
-
-struct DidOpenTextDocumentParams {
- /// The document that was opened.
- TextDocumentItem textDocument;
-};
-
-/// Add support for JSON serialization.
-bool fromJSON(const llvm::json::Value &value, DidOpenTextDocumentParams &result,
- llvm::json::Path path);
-
-//===----------------------------------------------------------------------===//
-// DidCloseTextDocumentParams
-//===----------------------------------------------------------------------===//
-
-struct DidCloseTextDocumentParams {
- /// The document that was closed.
- TextDocumentIdentifier textDocument;
-};
-
-/// Add support for JSON serialization.
-bool fromJSON(const llvm::json::Value &value,
- DidCloseTextDocumentParams &result, llvm::json::Path path);
-
-//===----------------------------------------------------------------------===//
-// DidChangeTextDocumentParams
-//===----------------------------------------------------------------------===//
-
-struct TextDocumentContentChangeEvent {
- /// Try to apply this change to the given contents string.
- LogicalResult applyTo(std::string &contents) const;
- /// Try to apply a set of changes to the given contents string.
- static LogicalResult applyTo(ArrayRef<TextDocumentContentChangeEvent> changes,
- std::string &contents);
-
- /// The range of the document that changed.
- std::optional<Range> range;
-
- /// The length of the range that got replaced.
- std::optional<int> rangeLength;
-
- /// The new text of the range/document.
- std::string text;
-};
-
-/// Add support for JSON serialization.
-bool fromJSON(const llvm::json::Value &value,
- TextDocumentContentChangeEvent &result, llvm::json::Path path);
-
-struct DidChangeTextDocumentParams {
- /// The document that changed.
- VersionedTextDocumentIdentifier textDocument;
-
- /// The actual content changes.
- std::vector<TextDocumentContentChangeEvent> contentChanges;
-};
-
-/// Add support for JSON serialization.
-bool fromJSON(const llvm::json::Value &value,
- DidChangeTextDocumentParams &result, llvm::json::Path path);
-
-//===----------------------------------------------------------------------===//
-// MarkupContent
-//===----------------------------------------------------------------------===//
-
-/// Describes the content type that a client supports in various result literals
-/// like `Hover`.
-enum class MarkupKind {
- PlainText,
- Markdown,
-};
-raw_ostream &operator<<(raw_ostream &os, MarkupKind kind);
-
-struct MarkupContent {
- MarkupKind kind = MarkupKind::PlainText;
- std::string value;
-};
-
-/// Add support for JSON serialization.
-llvm::json::Value toJSON(const MarkupContent &mc);
-
-//===----------------------------------------------------------------------===//
-// Hover
-//===----------------------------------------------------------------------===//
-
-struct Hover {
- /// Construct a default hover with the given range that uses Markdown content.
- Hover(Range range) : contents{MarkupKind::Markdown, ""}, range(range) {}
-
- /// The hover's content.
- MarkupContent contents;
-
- /// An optional range is a range inside a text document that is used to
- /// visualize a hover, e.g. by changing the background color.
- std::optional<Range> range;
-};
-
-/// Add support for JSON serialization.
-llvm::json::Value toJSON(const Hover &hover);
-
-//===----------------------------------------------------------------------===//
-// SymbolKind
-//===----------------------------------------------------------------------===//
-
-enum class SymbolKind {
- File = 1,
- Module = 2,
- Namespace = 3,
- Package = 4,
- Class = 5,
- Method = 6,
- Property = 7,
- Field = 8,
- Constructor = 9,
- Enum = 10,
- Interface = 11,
- Function = 12,
- Variable = 13,
- Constant = 14,
- String = 15,
- Number = 16,
- Boolean = 17,
- Array = 18,
- Object = 19,
- Key = 20,
- Null = 21,
- EnumMember = 22,
- Struct = 23,
- Event = 24,
- Operator = 25,
- TypeParameter = 26
-};
-
-//===----------------------------------------------------------------------===//
-// DocumentSymbol
-//===----------------------------------------------------------------------===//
-
-/// Represents programming constructs like variables, classes, interfaces etc.
-/// that appear in a document. Document symbols can be hierarchical and they
-/// have two ranges: one that encloses its definition and one that points to its
-/// most interesting range, e.g. the range of an identifier.
-struct DocumentSymbol {
- DocumentSymbol() = default;
- DocumentSymbol(DocumentSymbol &&) = default;
- DocumentSymbol(const Twine &name, SymbolKind kind, Range range,
- Range selectionRange)
- : name(name.str()), kind(kind), range(range),
- selectionRange(selectionRange) {}
-
- /// The name of this symbol.
- std::string name;
-
- /// More detail for this symbol, e.g the signature of a function.
- std::string detail;
-
- /// The kind of this symbol.
- SymbolKind kind;
-
- /// The range enclosing this symbol not including leading/trailing whitespace
- /// but everything else like comments. This information is typically used to
- /// determine if the clients cursor is inside the symbol to reveal in the
- /// symbol in the UI.
- Range range;
-
- /// The range that should be selected and revealed when this symbol is being
- /// picked, e.g the name of a function. Must be contained by the `range`.
- Range selectionRange;
-
- /// Children of this symbol, e.g. properties of a class.
- std::vector<DocumentSymbol> children;
-};
-
-/// Add support for JSON serialization.
-llvm::json::Value toJSON(const DocumentSymbol &symbol);
-
-//===----------------------------------------------------------------------===//
-// DocumentSymbolParams
-//===----------------------------------------------------------------------===//
-
-struct DocumentSymbolParams {
- // The text document to find symbols in.
- TextDocumentIdentifier textDocument;
-};
-
-/// Add support for JSON serialization.
-bool fromJSON(const llvm::json::Value &value, DocumentSymbolParams &result,
- llvm::json::Path path);
-
-//===----------------------------------------------------------------------===//
-// DiagnosticRelatedInformation
-//===----------------------------------------------------------------------===//
-
-/// Represents a related message and source code location for a diagnostic.
-/// This should be used to point to code locations that cause or related to a
-/// diagnostics, e.g. when duplicating a symbol in a scope.
-struct DiagnosticRelatedInformation {
- DiagnosticRelatedInformation() = default;
- DiagnosticRelatedInformation(Location location, std::string message)
- : location(std::move(location)), message(std::move(message)) {}
-
- /// The location of this related diagnostic information.
- Location location;
- /// The message of this related diagnostic information.
- std::string message;
-};
-
-/// Add support for JSON serialization.
-bool fromJSON(const llvm::json::Value &value,
- DiagnosticRelatedInformation &result, llvm::json::Path path);
-llvm::json::Value toJSON(const DiagnosticRelatedInformation &info);
-
-//===----------------------------------------------------------------------===//
-// Diagnostic
-//===----------------------------------------------------------------------===//
-
-enum class DiagnosticSeverity {
- /// It is up to the client to interpret diagnostics as error, warning, info or
- /// hint.
- Undetermined = 0,
- Error = 1,
- Warning = 2,
- Information = 3,
- Hint = 4
-};
-
-enum class DiagnosticTag {
- Unnecessary = 1,
- Deprecated = 2,
-};
-
-/// Add support for JSON serialization.
-llvm::json::Value toJSON(DiagnosticTag tag);
-bool fromJSON(const llvm::json::Value &value, DiagnosticTag &result,
- llvm::json::Path path);
-
-struct Diagnostic {
- /// The source range where the message applies.
- Range range;
-
- /// The diagnostic's severity. Can be omitted. If omitted it is up to the
- /// client to interpret diagnostics as error, warning, info or hint.
- DiagnosticSeverity severity = DiagnosticSeverity::Undetermined;
-
- /// A human-readable string describing the source of this diagnostic, e.g.
- /// 'typescript' or 'super lint'.
- std::string source;
-
- /// The diagnostic's message.
- std::string message;
-
- /// An array of related diagnostic information, e.g. when symbol-names within
- /// a scope collide all definitions can be marked via this property.
- std::optional<std::vector<DiagnosticRelatedInformation>> relatedInformation;
-
- /// Additional metadata about the diagnostic.
- std::vector<DiagnosticTag> tags;
-
- /// The diagnostic's category. Can be omitted.
- /// An LSP extension that's used to send the name of the category over to the
- /// client. The category typically describes the compilation stage during
- /// which the issue was produced, e.g. "Semantic Issue" or "Parse Issue".
- std::optional<std::string> category;
-};
-
-/// Add support for JSON serialization.
-llvm::json::Value toJSON(const Diagnostic &diag);
-bool fromJSON(const llvm::json::Value &value, Diagnostic &result,
- llvm::json::Path path);
-
-//===----------------------------------------------------------------------===//
-// PublishDiagnosticsParams
-//===----------------------------------------------------------------------===//
-
-struct PublishDiagnosticsParams {
- PublishDiagnosticsParams(URIForFile uri, int64_t version)
- : uri(std::move(uri)), version(version) {}
-
- /// The URI for which diagnostic information is reported.
- URIForFile uri;
- /// The list of reported diagnostics.
- std::vector<Diagnostic> diagnostics;
- /// The version number of the document the diagnostics are published for.
- int64_t version;
-};
-
-/// Add support for JSON serialization.
-llvm::json::Value toJSON(const PublishDiagnosticsParams &params);
-
-//===----------------------------------------------------------------------===//
-// TextEdit
-//===----------------------------------------------------------------------===//
-
-struct TextEdit {
- /// The range of the text document to be manipulated. To insert
- /// text into a document create a range where start === end.
- Range range;
-
- /// The string to be inserted. For delete operations use an
- /// empty string.
- std::string newText;
-};
-
-inline bool operator==(const TextEdit &lhs, const TextEdit &rhs) {
- return std::tie(lhs.newText, lhs.range) == std::tie(rhs.newText, rhs.range);
-}
-
-bool fromJSON(const llvm::json::Value &value, TextEdit &result,
- llvm::json::Path path);
-llvm::json::Value toJSON(const TextEdit &value);
-raw_ostream &operator<<(raw_ostream &os, const TextEdit &value);
-
-//===----------------------------------------------------------------------===//
-// CompletionItemKind
-//===----------------------------------------------------------------------===//
-
-/// The kind of a completion entry.
-enum class CompletionItemKind {
- Missing = 0,
- Text = 1,
- Method = 2,
- Function = 3,
- Constructor = 4,
- Field = 5,
- Variable = 6,
- Class = 7,
- Interface = 8,
- Module = 9,
- Property = 10,
- Unit = 11,
- Value = 12,
- Enum = 13,
- Keyword = 14,
- Snippet = 15,
- Color = 16,
- File = 17,
- Reference = 18,
- Folder = 19,
- EnumMember = 20,
- Constant = 21,
- Struct = 22,
- Event = 23,
- Operator = 24,
- TypeParameter = 25,
-};
-bool fromJSON(const llvm::json::Value &value, CompletionItemKind &result,
- llvm::json::Path path);
-
-constexpr auto kCompletionItemKindMin =
- static_cast<size_t>(CompletionItemKind::Text);
-constexpr auto kCompletionItemKindMax =
- static_cast<size_t>(CompletionItemKind::TypeParameter);
-using CompletionItemKindBitset = std::bitset<kCompletionItemKindMax + 1>;
-bool fromJSON(const llvm::json::Value &value, CompletionItemKindBitset &result,
- llvm::json::Path path);
-
-CompletionItemKind
-adjustKindToCapability(CompletionItemKind kind,
- CompletionItemKindBitset &supportedCompletionItemKinds);
-
-//===----------------------------------------------------------------------===//
-// CompletionItem
-//===----------------------------------------------------------------------===//
-
-/// Defines whether the insert text in a completion item should be interpreted
-/// as plain text or a snippet.
-enum class InsertTextFormat {
- Missing = 0,
- /// The primary text to be inserted is treated as a plain string.
- PlainText = 1,
- /// The primary text to be inserted is treated as a snippet.
- ///
- /// A snippet can define tab stops and placeholders with `$1`, `$2`
- /// and `${3:foo}`. `$0` defines the final tab stop, it defaults to the end
- /// of the snippet. Placeholders with equal identifiers are linked, that is
- /// typing in one will update others too.
- ///
- /// See also:
- /// https//github.com/Microsoft/vscode/blob/master/src/vs/editor/contrib/snippet/common/snippet.md
- Snippet = 2,
-};
-
-struct CompletionItem {
- CompletionItem() = default;
- CompletionItem(const Twine &label, CompletionItemKind kind,
- StringRef sortText = "")
- : label(label.str()), kind(kind), sortText(sortText.str()),
- insertTextFormat(InsertTextFormat::PlainText) {}
-
- /// The label of this completion item. By default also the text that is
- /// inserted when selecting this completion.
- std::string label;
-
- /// The kind of this completion item. Based of the kind an icon is chosen by
- /// the editor.
- CompletionItemKind kind = CompletionItemKind::Missing;
-
- /// A human-readable string with additional information about this item, like
- /// type or symbol information.
- std::string detail;
-
- /// A human-readable string that represents a doc-comment.
- std::optional<MarkupContent> documentation;
-
- /// A string that should be used when comparing this item with other items.
- /// When `falsy` the label is used.
- std::string sortText;
-
- /// A string that should be used when filtering a set of completion items.
- /// When `falsy` the label is used.
- std::string filterText;
-
- /// A string that should be inserted to a document when selecting this
- /// completion. When `falsy` the label is used.
- std::string insertText;
-
- /// The format of the insert text. The format applies to both the `insertText`
- /// property and the `newText` property of a provided `textEdit`.
- InsertTextFormat insertTextFormat = InsertTextFormat::Missing;
-
- /// An edit which is applied to a document when selecting this completion.
- /// When an edit is provided `insertText` is ignored.
- ///
- /// Note: The range of the edit must be a single line range and it must
- /// contain the position at which completion has been requested.
- std::optional<TextEdit> textEdit;
-
- /// An optional array of additional text edits that are applied when selecting
- /// this completion. Edits must not overlap with the main edit nor with
- /// themselves.
- std::vector<TextEdit> additionalTextEdits;
-
- /// Indicates if this item is deprecated.
- bool deprecated = false;
-};
-
-/// Add support for JSON serialization.
-llvm::json::Value toJSON(const CompletionItem &value);
-raw_ostream &operator<<(raw_ostream &os, const CompletionItem &value);
-bool operator<(const CompletionItem &lhs, const CompletionItem &rhs);
-
-//===----------------------------------------------------------------------===//
-// CompletionList
-//===----------------------------------------------------------------------===//
-
-/// Represents a collection of completion items to be presented in the editor.
-struct CompletionList {
- /// The list is not complete. Further typing should result in recomputing the
- /// list.
- bool isIncomplete = false;
-
- /// The completion items.
- std::vector<CompletionItem> items;
-};
-
-/// Add support for JSON serialization.
-llvm::json::Value toJSON(const CompletionList &value);
-
-//===----------------------------------------------------------------------===//
-// CompletionContext
-//===----------------------------------------------------------------------===//
-
-enum class CompletionTriggerKind {
- /// Completion was triggered by typing an identifier (24x7 code
- /// complete), manual invocation (e.g Ctrl+Space) or via API.
- Invoked = 1,
-
- /// Completion was triggered by a trigger character specified by
- /// the `triggerCharacters` properties of the `CompletionRegistrationOptions`.
- TriggerCharacter = 2,
-
- /// Completion was re-triggered as the current completion list is incomplete.
- TriggerTriggerForIncompleteCompletions = 3
-};
-
-struct CompletionContext {
- /// How the completion was triggered.
- CompletionTriggerKind triggerKind = CompletionTriggerKind::Invoked;
-
- /// The trigger character (a single character) that has trigger code complete.
- /// Is undefined if `triggerKind !== CompletionTriggerKind.TriggerCharacter`
- std::string triggerCharacter;
-};
-
-/// Add support for JSON serialization.
-bool fromJSON(const llvm::json::Value &value, CompletionContext &result,
- llvm::json::Path path);
-
-//===----------------------------------------------------------------------===//
-// CompletionParams
-//===----------------------------------------------------------------------===//
-
-struct CompletionParams : TextDocumentPositionParams {
- CompletionContext context;
-};
-
-/// Add support for JSON serialization.
-bool fromJSON(const llvm::json::Value &value, CompletionParams &result,
- llvm::json::Path path);
-
-//===----------------------------------------------------------------------===//
-// ParameterInformation
-//===----------------------------------------------------------------------===//
-
-/// A single parameter of a particular signature.
-struct ParameterInformation {
- /// The label of this parameter. Ignored when labelOffsets is set.
- std::string labelString;
-
- /// Inclusive start and exclusive end offsets withing the containing signature
- /// label.
- std::optional<std::pair<unsigned, unsigned>> labelOffsets;
-
- /// The documentation of this parameter. Optional.
- std::string documentation;
-};
-
-/// Add support for JSON serialization.
-llvm::json::Value toJSON(const ParameterInformation &value);
-
-//===----------------------------------------------------------------------===//
-// SignatureInformation
-//===----------------------------------------------------------------------===//
-
-/// Represents the signature of something callable.
-struct SignatureInformation {
- /// The label of this signature. Mandatory.
- std::string label;
-
- /// The documentation of this signature. Optional.
- std::string documentation;
-
- /// The parameters of this signature.
- std::vector<ParameterInformation> parameters;
-};
-
-/// Add support for JSON serialization.
-llvm::json::Value toJSON(const SignatureInformation &value);
-raw_ostream &operator<<(raw_ostream &os, const SignatureInformation &value);
-
-//===----------------------------------------------------------------------===//
-// SignatureHelp
-//===----------------------------------------------------------------------===//
-
-/// Represents the signature of a callable.
-struct SignatureHelp {
- /// The resulting signatures.
- std::vector<SignatureInformation> signatures;
-
- /// The active signature.
- int activeSignature = 0;
-
- /// The active parameter of the active signature.
- int activeParameter = 0;
-};
-
-/// Add support for JSON serialization.
-llvm::json::Value toJSON(const SignatureHelp &value);
-
-//===----------------------------------------------------------------------===//
-// DocumentLinkParams
-//===----------------------------------------------------------------------===//
-
-/// Parameters for the document link request.
-struct DocumentLinkParams {
- /// The document to provide document links for.
- TextDocumentIdentifier textDocument;
-};
-
-/// Add support for JSON serialization.
-bool fromJSON(const llvm::json::Value &value, DocumentLinkParams &result,
- llvm::json::Path path);
-
-//===----------------------------------------------------------------------===//
-// DocumentLink
-//===----------------------------------------------------------------------===//
-
-/// A range in a text document that links to an internal or external resource,
-/// like another text document or a web site.
-struct DocumentLink {
- DocumentLink() = default;
- DocumentLink(Range range, URIForFile target)
- : range(range), target(std::move(target)) {}
-
- /// The range this link applies to.
- Range range;
-
- /// The uri this link points to. If missing a resolve request is sent later.
- URIForFile target;
-
- // TODO: The following optional fields defined by the language server protocol
- // are unsupported:
- //
- // data?: any - A data entry field that is preserved on a document link
- // between a DocumentLinkRequest and a
- // DocumentLinkResolveRequest.
-
- friend bool operator==(const DocumentLink &lhs, const DocumentLink &rhs) {
- return lhs.range == rhs.range && lhs.target == rhs.target;
- }
-
- friend bool operator!=(const DocumentLink &lhs, const DocumentLink &rhs) {
- return !(lhs == rhs);
- }
-};
-
-/// Add support for JSON serialization.
-llvm::json::Value toJSON(const DocumentLink &value);
-
-//===----------------------------------------------------------------------===//
-// InlayHintsParams
-//===----------------------------------------------------------------------===//
-
-/// A parameter literal used in inlay hint requests.
-struct InlayHintsParams {
- /// The text document.
- TextDocumentIdentifier textDocument;
-
- /// The visible document range for which inlay hints should be computed.
- Range range;
-};
-
-/// Add support for JSON serialization.
-bool fromJSON(const llvm::json::Value &value, InlayHintsParams &result,
- llvm::json::Path path);
-
-//===----------------------------------------------------------------------===//
-// InlayHintKind
-//===----------------------------------------------------------------------===//
-
-/// Inlay hint kinds.
-enum class InlayHintKind {
- /// An inlay hint that for a type annotation.
- ///
- /// An example of a type hint is a hint in this position:
- /// auto var ^ = expr;
- /// which shows the deduced type of the variable.
- Type = 1,
-
- /// An inlay hint that is for a parameter.
- ///
- /// An example of a parameter hint is a hint in this position:
- /// func(^arg);
- /// which shows the name of the corresponding parameter.
- Parameter = 2,
-};
-
-//===----------------------------------------------------------------------===//
-// InlayHint
-//===----------------------------------------------------------------------===//
-
-/// Inlay hint information.
-struct InlayHint {
- InlayHint(InlayHintKind kind, Position pos) : position(pos), kind(kind) {}
-
- /// The position of this hint.
- Position position;
-
- /// The label of this hint. A human readable string or an array of
- /// InlayHintLabelPart label parts.
- ///
- /// *Note* that neither the string nor the label part can be empty.
- std::string label;
-
- /// The kind of this hint. Can be omitted in which case the client should fall
- /// back to a reasonable default.
- InlayHintKind kind;
-
- /// Render padding before the hint.
- ///
- /// Note: Padding should use the editor's background color, not the
- /// background color of the hint itself. That means padding can be used
- /// to visually align/separate an inlay hint.
- bool paddingLeft = false;
-
- /// Render padding after the hint.
- ///
- /// Note: Padding should use the editor's background color, not the
- /// background color of the hint itself. That means padding can be used
- /// to visually align/separate an inlay hint.
- bool paddingRight = false;
-};
-
-/// Add support for JSON serialization.
-llvm::json::Value toJSON(const InlayHint &);
-bool operator==(const InlayHint &lhs, const InlayHint &rhs);
-bool operator<(const InlayHint &lhs, const InlayHint &rhs);
-llvm::raw_ostream &operator<<(llvm::raw_ostream &os, InlayHintKind value);
-
-//===----------------------------------------------------------------------===//
-// CodeActionContext
-//===----------------------------------------------------------------------===//
-
-struct CodeActionContext {
- /// An array of diagnostics known on the client side overlapping the range
- /// provided to the `textDocument/codeAction` request. They are provided so
- /// that the server knows which errors are currently presented to the user for
- /// the given range. There is no guarantee that these accurately reflect the
- /// error state of the resource. The primary parameter to compute code actions
- /// is the provided range.
- std::vector<Diagnostic> diagnostics;
-
- /// Requested kind of actions to return.
- ///
- /// Actions not of this kind are filtered out by the client before being
- /// shown. So servers can omit computing them.
- std::vector<std::string> only;
-};
-
-/// Add support for JSON serialization.
-bool fromJSON(const llvm::json::Value &value, CodeActionContext &result,
- llvm::json::Path path);
-
-//===----------------------------------------------------------------------===//
-// CodeActionParams
-//===----------------------------------------------------------------------===//
-
-struct CodeActionParams {
- /// The document in which the command was invoked.
- TextDocumentIdentifier textDocument;
-
- /// The range for which the command was invoked.
- Range range;
-
- /// Context carrying additional information.
- CodeActionContext context;
-};
-
-/// Add support for JSON serialization.
-bool fromJSON(const llvm::json::Value &value, CodeActionParams &result,
- llvm::json::Path path);
-
-//===----------------------------------------------------------------------===//
-// WorkspaceEdit
-//===----------------------------------------------------------------------===//
-
-struct WorkspaceEdit {
- /// Holds changes to existing resources.
- std::map<std::string, std::vector<TextEdit>> changes;
-
- /// Note: "documentChanges" is not currently used because currently there is
- /// no support for versioned edits.
-};
-
-/// Add support for JSON serialization.
-bool fromJSON(const llvm::json::Value &value, WorkspaceEdit &result,
- llvm::json::Path path);
-llvm::json::Value toJSON(const WorkspaceEdit &value);
-
-//===----------------------------------------------------------------------===//
-// CodeAction
-//===----------------------------------------------------------------------===//
-
-/// A code action represents a change that can be performed in code, e.g. to fix
-/// a problem or to refactor code.
-///
-/// A CodeAction must set either `edit` and/or a `command`. If both are
-/// supplied, the `edit` is applied first, then the `command` is executed.
-struct CodeAction {
- /// A short, human-readable, title for this code action.
- std::string title;
-
- /// The kind of the code action.
- /// Used to filter code actions.
- std::optional<std::string> kind;
- const static llvm::StringLiteral kQuickFix;
- const static llvm::StringLiteral kRefactor;
- const static llvm::StringLiteral kInfo;
-
- /// The diagnostics that this code action resolves.
- std::optional<std::vector<Diagnostic>> diagnostics;
-
- /// Marks this as a preferred action. Preferred actions are used by the
- /// `auto fix` command and can be targeted by keybindings.
- /// A quick fix should be marked preferred if it properly addresses the
- /// underlying error. A refactoring should be marked preferred if it is the
- /// most reasonable choice of actions to take.
- bool isPreferred = false;
-
- /// The workspace edit this code action performs.
- std::optional<WorkspaceEdit> edit;
-};
-
-/// Add support for JSON serialization.
-llvm::json::Value toJSON(const CodeAction &);
-
-} // namespace lsp
-} // namespace mlir
-
-namespace llvm {
-template <>
-struct format_provider<mlir::lsp::Position> {
- static void format(const mlir::lsp::Position &pos, raw_ostream &os,
- StringRef style) {
- assert(style.empty() && "style modifiers for this type are not supported");
- os << pos;
- }
-};
-} // namespace llvm
-
-#endif
diff --git a/mlir/include/mlir/Tools/lsp-server-support/SourceMgrUtils.h b/mlir/include/mlir/Tools/lsp-server-support/SourceMgrUtils.h
index 9ed8326..920ce83 100644
--- a/mlir/include/mlir/Tools/lsp-server-support/SourceMgrUtils.h
+++ b/mlir/include/mlir/Tools/lsp-server-support/SourceMgrUtils.h
@@ -14,7 +14,8 @@
#ifndef MLIR_TOOLS_LSPSERVERSUPPORT_SOURCEMGRUTILS_H
#define MLIR_TOOLS_LSPSERVERSUPPORT_SOURCEMGRUTILS_H
-#include "mlir/Tools/lsp-server-support/Protocol.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/LSP/Protocol.h"
#include "llvm/Support/SourceMgr.h"
#include <optional>
@@ -45,17 +46,18 @@ bool contains(SMRange range, SMLoc loc);
/// This class represents a single include within a root file.
struct SourceMgrInclude {
- SourceMgrInclude(const lsp::URIForFile &uri, const lsp::Range &range)
+ SourceMgrInclude(const llvm::lsp::URIForFile &uri,
+ const llvm::lsp::Range &range)
: uri(uri), range(range) {}
/// Build a hover for the current include file.
- Hover buildHover() const;
+ llvm::lsp::Hover buildHover() const;
/// The URI of the file that is included.
- lsp::URIForFile uri;
+ llvm::lsp::URIForFile uri;
/// The range of the include directive.
- lsp::Range range;
+ llvm::lsp::Range range;
};
/// Given a source manager, gather all of the processed include files. These are
diff --git a/mlir/include/mlir/Tools/lsp-server-support/Transport.h b/mlir/include/mlir/Tools/lsp-server-support/Transport.h
deleted file mode 100644
index 0010a47..0000000
--- a/mlir/include/mlir/Tools/lsp-server-support/Transport.h
+++ /dev/null
@@ -1,283 +0,0 @@
-//===--- Transport.h - Sending and Receiving LSP messages -------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// The language server protocol is usually implemented by writing messages as
-// JSON-RPC over the stdin/stdout of a subprocess. This file contains a JSON
-// transport interface that handles this communication.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_TOOLS_LSPSERVERSUPPORT_TRANSPORT_H
-#define MLIR_TOOLS_LSPSERVERSUPPORT_TRANSPORT_H
-
-#include "mlir/Support/DebugStringHelper.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Tools/lsp-server-support/Logging.h"
-#include "mlir/Tools/lsp-server-support/Protocol.h"
-#include "llvm/ADT/FunctionExtras.h"
-#include "llvm/ADT/StringMap.h"
-#include "llvm/ADT/StringRef.h"
-#include "llvm/Support/FormatAdapters.h"
-#include "llvm/Support/JSON.h"
-#include "llvm/Support/raw_ostream.h"
-#include <atomic>
-
-namespace mlir {
-namespace lsp {
-class MessageHandler;
-
-//===----------------------------------------------------------------------===//
-// JSONTransport
-//===----------------------------------------------------------------------===//
-
-/// The encoding style of the JSON-RPC messages (both input and output).
-enum JSONStreamStyle {
- /// Encoding per the LSP specification, with mandatory Content-Length header.
- Standard,
- /// Messages are delimited by a '// -----' line. Comment lines start with //.
- Delimited
-};
-
-/// An abstract class used by the JSONTransport to read JSON message.
-class JSONTransportInput {
-public:
- explicit JSONTransportInput(JSONStreamStyle style = JSONStreamStyle::Standard)
- : style(style) {}
- virtual ~JSONTransportInput() = default;
-
- virtual bool hasError() const = 0;
- virtual bool isEndOfInput() const = 0;
-
- /// Read in a message from the input stream.
- LogicalResult readMessage(std::string &json) {
- return style == JSONStreamStyle::Delimited ? readDelimitedMessage(json)
- : readStandardMessage(json);
- }
- virtual LogicalResult readDelimitedMessage(std::string &json) = 0;
- virtual LogicalResult readStandardMessage(std::string &json) = 0;
-
-private:
- /// The JSON stream style to use.
- JSONStreamStyle style;
-};
-
-/// Concrete implementation of the JSONTransportInput that reads from a file.
-class JSONTransportInputOverFile : public JSONTransportInput {
-public:
- explicit JSONTransportInputOverFile(
- std::FILE *in, JSONStreamStyle style = JSONStreamStyle::Standard)
- : JSONTransportInput(style), in(in) {}
-
- bool hasError() const final { return ferror(in); }
- bool isEndOfInput() const final { return feof(in); }
-
- LogicalResult readDelimitedMessage(std::string &json) final;
- LogicalResult readStandardMessage(std::string &json) final;
-
-private:
- std::FILE *in;
-};
-
-/// A transport class that performs the JSON-RPC communication with the LSP
-/// client.
-class JSONTransport {
-public:
- JSONTransport(std::unique_ptr<JSONTransportInput> in, raw_ostream &out,
- bool prettyOutput = false)
- : in(std::move(in)), out(out), prettyOutput(prettyOutput) {}
-
- JSONTransport(std::FILE *in, raw_ostream &out,
- JSONStreamStyle style = JSONStreamStyle::Standard,
- bool prettyOutput = false)
- : in(std::make_unique<JSONTransportInputOverFile>(in, style)), out(out),
- prettyOutput(prettyOutput) {}
-
- /// The following methods are used to send a message to the LSP client.
- void notify(StringRef method, llvm::json::Value params);
- void call(StringRef method, llvm::json::Value params, llvm::json::Value id);
- void reply(llvm::json::Value id, llvm::Expected<llvm::json::Value> result);
-
- /// Start executing the JSON-RPC transport.
- llvm::Error run(MessageHandler &handler);
-
-private:
- /// Dispatches the given incoming json message to the message handler.
- bool handleMessage(llvm::json::Value msg, MessageHandler &handler);
- /// Writes the given message to the output stream.
- void sendMessage(llvm::json::Value msg);
-
-private:
- /// The input to read a message from.
- std::unique_ptr<JSONTransportInput> in;
- SmallVector<char, 0> outputBuffer;
- /// The output file stream.
- raw_ostream &out;
- /// If the output JSON should be formatted for easier readability.
- bool prettyOutput;
-};
-
-//===----------------------------------------------------------------------===//
-// MessageHandler
-//===----------------------------------------------------------------------===//
-
-/// A Callback<T> is a void function that accepts Expected<T>. This is
-/// accepted by functions that logically return T.
-template <typename T>
-using Callback = llvm::unique_function<void(llvm::Expected<T>)>;
-
-/// An OutgoingNotification<T> is a function used for outgoing notifications
-/// send to the client.
-template <typename T>
-using OutgoingNotification = llvm::unique_function<void(const T &)>;
-
-/// An OutgoingRequest<T> is a function used for outgoing requests to send to
-/// the client.
-template <typename T>
-using OutgoingRequest =
- llvm::unique_function<void(const T &, llvm::json::Value id)>;
-
-/// An `OutgoingRequestCallback` is invoked when an outgoing request to the
-/// client receives a response in turn. It is passed the original request's ID,
-/// as well as the response result.
-template <typename T>
-using OutgoingRequestCallback =
- std::function<void(llvm::json::Value, llvm::Expected<T>)>;
-
-/// A handler used to process the incoming transport messages.
-class MessageHandler {
-public:
- MessageHandler(JSONTransport &transport) : transport(transport) {}
-
- bool onNotify(StringRef method, llvm::json::Value value);
- bool onCall(StringRef method, llvm::json::Value params, llvm::json::Value id);
- bool onReply(llvm::json::Value id, llvm::Expected<llvm::json::Value> result);
-
- template <typename T>
- static llvm::Expected<T> parse(const llvm::json::Value &raw,
- StringRef payloadName, StringRef payloadKind) {
- T result;
- llvm::json::Path::Root root;
- if (fromJSON(raw, result, root))
- return std::move(result);
-
- // Dump the relevant parts of the broken message.
- std::string context;
- llvm::raw_string_ostream os(context);
- root.printErrorContext(raw, os);
-
- // Report the error (e.g. to the client).
- return llvm::make_error<LSPError>(
- llvm::formatv("failed to decode {0} {1}: {2}", payloadName, payloadKind,
- fmt_consume(root.getError())),
- ErrorCode::InvalidParams);
- }
-
- template <typename Param, typename Result, typename ThisT>
- void method(llvm::StringLiteral method, ThisT *thisPtr,
- void (ThisT::*handler)(const Param &, Callback<Result>)) {
- methodHandlers[method] = [method, handler,
- thisPtr](llvm::json::Value rawParams,
- Callback<llvm::json::Value> reply) {
- llvm::Expected<Param> param = parse<Param>(rawParams, method, "request");
- if (!param)
- return reply(param.takeError());
- (thisPtr->*handler)(*param, std::move(reply));
- };
- }
-
- template <typename Param, typename ThisT>
- void notification(llvm::StringLiteral method, ThisT *thisPtr,
- void (ThisT::*handler)(const Param &)) {
- notificationHandlers[method] = [method, handler,
- thisPtr](llvm::json::Value rawParams) {
- llvm::Expected<Param> param =
- parse<Param>(rawParams, method, "notification");
- if (!param) {
- return llvm::consumeError(
- llvm::handleErrors(param.takeError(), [](const LSPError &lspError) {
- Logger::error("JSON parsing error: {0}",
- lspError.message.c_str());
- }));
- }
- (thisPtr->*handler)(*param);
- };
- }
-
- /// Create an OutgoingNotification object used for the given method.
- template <typename T>
- OutgoingNotification<T> outgoingNotification(llvm::StringLiteral method) {
- return [&, method](const T &params) {
- std::lock_guard<std::mutex> transportLock(transportOutputMutex);
- Logger::info("--> {0}", method);
- transport.notify(method, llvm::json::Value(params));
- };
- }
-
- /// Create an OutgoingRequest function that, when called, sends a request with
- /// the given method via the transport. Should the outgoing request be
- /// met with a response, the result JSON is parsed and the response callback
- /// is invoked.
- template <typename Param, typename Result>
- OutgoingRequest<Param>
- outgoingRequest(llvm::StringLiteral method,
- OutgoingRequestCallback<Result> callback) {
- return [&, method, callback](const Param &param, llvm::json::Value id) {
- auto callbackWrapper = [method, callback = std::move(callback)](
- llvm::json::Value id,
- llvm::Expected<llvm::json::Value> value) {
- if (!value)
- return callback(std::move(id), value.takeError());
-
- std::string responseName = llvm::formatv("reply:{0}({1})", method, id);
- llvm::Expected<Result> result =
- parse<Result>(*value, responseName, "response");
- if (!result)
- return callback(std::move(id), result.takeError());
-
- return callback(std::move(id), *result);
- };
-
- {
- std::lock_guard<std::mutex> lock(responseHandlersMutex);
- responseHandlers.insert(
- {debugString(id), std::make_pair(method.str(), callbackWrapper)});
- }
-
- std::lock_guard<std::mutex> transportLock(transportOutputMutex);
- Logger::info("--> {0}({1})", method, id);
- transport.call(method, llvm::json::Value(param), id);
- };
- }
-
-private:
- template <typename HandlerT>
- using HandlerMap = llvm::StringMap<llvm::unique_function<HandlerT>>;
-
- HandlerMap<void(llvm::json::Value)> notificationHandlers;
- HandlerMap<void(llvm::json::Value, Callback<llvm::json::Value>)>
- methodHandlers;
-
- /// A pair of (1) the original request's method name, and (2) the callback
- /// function to be invoked for responses.
- using ResponseHandlerTy =
- std::pair<std::string, OutgoingRequestCallback<llvm::json::Value>>;
- /// A mapping from request/response ID to response handler.
- llvm::StringMap<ResponseHandlerTy> responseHandlers;
- /// Mutex to guard insertion into the response handler map.
- std::mutex responseHandlersMutex;
-
- JSONTransport &transport;
-
- /// Mutex to guard sending output messages to the transport.
- std::mutex transportOutputMutex;
-};
-
-} // namespace lsp
-} // namespace mlir
-
-#endif
diff --git a/mlir/include/mlir/Tools/mlir-lsp-server/MlirLspRegistryFunction.h b/mlir/include/mlir/Tools/mlir-lsp-server/MlirLspRegistryFunction.h
index 4811ecb..0d9ba2a 100644
--- a/mlir/include/mlir/Tools/mlir-lsp-server/MlirLspRegistryFunction.h
+++ b/mlir/include/mlir/Tools/mlir-lsp-server/MlirLspRegistryFunction.h
@@ -16,14 +16,16 @@
namespace llvm {
template <typename Fn>
class function_ref;
+namespace lsp {
+class URIForFile;
+} // namespace lsp
} // namespace llvm
namespace mlir {
class DialectRegistry;
namespace lsp {
-class URIForFile;
using DialectRegistryFn =
- llvm::function_ref<DialectRegistry &(const URIForFile &uri)>;
+ llvm::function_ref<DialectRegistry &(const llvm::lsp::URIForFile &uri)>;
} // namespace lsp
} // namespace mlir
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 6949f4a..a096f82 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -433,7 +433,7 @@ private:
std::is_same_v<T, Value>,
ConversionCallbackFn>
wrapCallback(FnT &&callback) {
- hasContextAwareTypeConversions = true;
+ contextAwareTypeConversionsIndex = conversions.size();
return [callback = std::forward<FnT>(callback)](
PointerUnion<Type, Value> typeOrValue,
SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
@@ -555,6 +555,10 @@ private:
cachedMultiConversions.clear();
}
+ /// Internal implementation of the type conversion.
+ LogicalResult convertTypeImpl(PointerUnion<Type, Value> t,
+ SmallVectorImpl<Type> &results) const;
+
/// The set of registered conversion functions.
SmallVector<ConversionCallbackFn, 4> conversions;
@@ -575,10 +579,13 @@ private:
mutable llvm::sys::SmartRWMutex<true> cacheMutex;
/// Whether the type converter has context-aware type conversions. I.e.,
/// conversion rules that depend on the SSA value instead of just the type.
- /// Type conversion caching is deactivated when there are context-aware
- /// conversions because the type converter may return different results for
- /// the same input type.
- bool hasContextAwareTypeConversions = false;
+ /// We store here the index in the `conversions` vector of the last added
+ /// context-aware conversion, if any. This is useful because we can't cache
+ /// the result of type conversion happening after context-aware conversions,
+ /// because the type converter may return different results for the same input
+ /// type. This is why it is recommened to add context-aware conversions first,
+ /// any context-free conversions after will benefit from caching.
+ int contextAwareTypeConversionsIndex = -1;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
index 9424eff..131c49c 100644
--- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
@@ -22,6 +22,7 @@
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugLog.h"
@@ -159,6 +160,7 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
LDBG() << "[init] Entering initializeSymbolCallables for top-level op: "
<< OpWithFlags(top, OpPrintingFlags().skipRegions());
analysisScope = top;
+ hasSymbolTable = top->hasTrait<OpTrait::SymbolTable>();
auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
LDBG() << "[init] Processing symbol table op: "
<< OpWithFlags(symTable, OpPrintingFlags().skipRegions());
@@ -260,14 +262,25 @@ LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
return failure();
}
// Recurse on nested operations.
- for (Region &region : op->getRegions()) {
- LDBG() << "[init] Recursing into region of op: "
- << OpWithFlags(op, OpPrintingFlags().skipRegions());
- for (Operation &nestedOp : region.getOps()) {
- LDBG() << "[init] Recursing into nested op: "
- << OpWithFlags(&nestedOp, OpPrintingFlags().skipRegions());
- if (failed(initializeRecursively(&nestedOp)))
- return failure();
+ if (op->getNumRegions()) {
+ // If we haven't seen a symbol table yet, check if the current operation
+ // has one. If so, update the flag to allow for resolving callables in
+ // nested regions.
+ bool savedHasSymbolTable = hasSymbolTable;
+ auto restoreHasSymbolTable =
+ llvm::make_scope_exit([&]() { hasSymbolTable = savedHasSymbolTable; });
+ if (!hasSymbolTable && op->hasTrait<OpTrait::SymbolTable>())
+ hasSymbolTable = true;
+
+ for (Region &region : op->getRegions()) {
+ LDBG() << "[init] Recursing into region of op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
+ for (Operation &nestedOp : region.getOps()) {
+ LDBG() << "[init] Recursing into nested op: "
+ << OpWithFlags(&nestedOp, OpPrintingFlags().skipRegions());
+ if (failed(initializeRecursively(&nestedOp)))
+ return failure();
+ }
}
}
LDBG() << "[init] Finished initializeRecursively for op: "
@@ -388,7 +401,13 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {
void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
LDBG() << "visitCallOperation: "
<< OpWithFlags(call.getOperation(), OpPrintingFlags().skipRegions());
- Operation *callableOp = call.resolveCallableInTable(&symbolTable);
+
+ Operation *callableOp = nullptr;
+ if (hasSymbolTable)
+ callableOp = call.resolveCallableInTable(&symbolTable);
+ else
+ LDBG()
+ << "No symbol table present in analysis scope, can't resolve callable";
// A call to a externally-defined callable has unknown predecessors.
const auto isExternalCallable = [this](Operation *op) {
diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
index d05374f..b51465b 100644
--- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
@@ -64,10 +64,12 @@ void AbstractDenseForwardDataFlowAnalysis::visitCallOperation(
AbstractDenseLattice *after) {
// Allow for customizing the behavior of calls to external symbols, including
// when the analysis is explicitly marked as non-interprocedural.
- auto callable =
- dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
- if (!getSolverConfig().isInterprocedural() ||
- (callable && !callable.getCallableRegion())) {
+ auto isExternalCallable = [&]() {
+ auto callable =
+ dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
+ return callable && !callable.getCallableRegion();
+ };
+ if (!getSolverConfig().isInterprocedural() || isExternalCallable()) {
return visitCallControlFlowTransfer(
call, CallControlFlowAction::ExternalCallee, before, after);
}
@@ -290,6 +292,12 @@ AbstractDenseBackwardDataFlowAnalysis::visit(ProgramPoint *point) {
void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation(
CallOpInterface call, const AbstractDenseLattice &after,
AbstractDenseLattice *before) {
+ // If the solver is not interprocedural, let the hook handle it as an external
+ // callee.
+ if (!getSolverConfig().isInterprocedural())
+ return visitCallControlFlowTransfer(
+ call, CallControlFlowAction::ExternalCallee, after, before);
+
// Find the callee.
Operation *callee = call.resolveCallableInTable(&symbolTable);
@@ -297,12 +305,10 @@ void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation(
// No region means the callee is only declared in this module.
// If that is the case or if the solver is not interprocedural,
// let the hook handle it.
- if (!getSolverConfig().isInterprocedural() ||
- (callable && (!callable.getCallableRegion() ||
- callable.getCallableRegion()->empty()))) {
+ if (callable &&
+ (!callable.getCallableRegion() || callable.getCallableRegion()->empty()))
return visitCallControlFlowTransfer(
call, CallControlFlowAction::ExternalCallee, after, before);
- }
if (!callable)
return setToExitState(before);
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 13a3e14..0d2e2ed 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -228,10 +228,12 @@ LogicalResult AbstractSparseForwardDataFlowAnalysis::visitCallOperation(
ArrayRef<AbstractSparseLattice *> resultLattices) {
// If the call operation is to an external function, attempt to infer the
// results from the call arguments.
- auto callable =
- dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
- if (!getSolverConfig().isInterprocedural() ||
- (callable && !callable.getCallableRegion())) {
+ auto isExternalCallable = [&]() {
+ auto callable =
+ dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
+ return callable && !callable.getCallableRegion();
+ };
+ if (!getSolverConfig().isInterprocedural() || isExternalCallable()) {
visitExternalCallImpl(call, operandLattices, resultLattices);
return success();
}
diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp
index 7e1b405..9352ab0 100644
--- a/mlir/lib/Analysis/DataFlowFramework.cpp
+++ b/mlir/lib/Analysis/DataFlowFramework.cpp
@@ -9,6 +9,7 @@
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Operation.h"
+#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/iterator.h"
@@ -109,6 +110,12 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
isRunning = true;
auto guard = llvm::make_scope_exit([&]() { isRunning = false; });
+ bool isInterprocedural = config.isInterprocedural();
+ auto restoreInterprocedural = llvm::make_scope_exit(
+ [&]() { config.setInterprocedural(isInterprocedural); });
+ if (isInterprocedural && !top->hasTrait<OpTrait::SymbolTable>())
+ config.setInterprocedural(false);
+
// Initialize equivalent lattice anchors.
for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) {
analysis.initializeEquivalentLatticeAnchor(top);
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index d7282b3..a14f09f 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -52,9 +52,14 @@ NB_MODULE(_mlir, m) {
[](PyGlobals &self, bool enabled) {
self.getTracebackLoc().setLocTracebacksEnabled(enabled);
})
+ .def("loc_tracebacks_frame_limit",
+ [](PyGlobals &self) {
+ return self.getTracebackLoc().locTracebackFramesLimit();
+ })
.def("set_loc_tracebacks_frame_limit",
- [](PyGlobals &self, int n) {
- self.getTracebackLoc().setLocTracebackFramesLimit(n);
+ [](PyGlobals &self, std::optional<int> n) {
+ self.getTracebackLoc().setLocTracebackFramesLimit(
+ n.value_or(PyGlobals::TracebackLoc::kMaxFrames));
})
.def("register_traceback_file_inclusion",
[](PyGlobals &self, const std::string &filename) {
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 6ee85e8..47ef5d8 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -57,6 +57,13 @@ private:
/// Create the `mlir.passmanager` here.
void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
//----------------------------------------------------------------------------
+ // Mapping of MlirExternalPass
+ //----------------------------------------------------------------------------
+ nb::class_<MlirExternalPass>(m, "ExternalPass")
+ .def("signal_pass_failure",
+ [](MlirExternalPass pass) { mlirExternalPassSignalFailure(pass); });
+
+ //----------------------------------------------------------------------------
// Mapping of the top-level PassManager
//----------------------------------------------------------------------------
nb::class_<PyPassManager>(m, "PassManager")
@@ -182,9 +189,9 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
callbacks.clone = [](void *) -> void * {
throw std::runtime_error("Cloning Python passes not supported");
};
- callbacks.run = [](MlirOperation op, MlirExternalPass,
+ callbacks.run = [](MlirOperation op, MlirExternalPass pass,
void *userData) {
- nb::borrow<nb::callable>(static_cast<PyObject *>(userData))(op);
+ nb::handle(static_cast<PyObject *>(userData))(op, pass);
};
auto externalPass = mlirCreateExternalPass(
passID, mlirStringRefCreate(name->data(), name->length()),
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index d29053a..1659437 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -22,8 +22,6 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Endian.h"
-#include "llvm/Support/Format.h"
-#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MemoryBufferRef.h"
#include "llvm/Support/SourceMgr.h"
@@ -296,12 +294,38 @@ public:
if (failed(parseVarInt(alignment)))
return failure();
- // Check that the requested alignment is less than or equal to the
- // alignment of the root buffer. If it is not, we cannot safely guarantee
- // that the specified alignment is globally correct.
+ // Check that the requested alignment must not exceed the alignment of
+ // the root buffer itself. Otherwise we cannot guarantee that pointers
+ // derived from this buffer will actually satisfy the requested alignment
+ // globally.
//
- // E.g. if the buffer is 8k aligned and the section is 16k aligned,
- // we could end up at an offset of 24k, which is not globally 16k aligned.
+ // Consider a bytecode buffer that is guaranteed to be 8k aligned, but not
+ // 16k aligned (e.g. absolute address 40960. If a section inside this
+ // buffer declares a 16k alignment requirement, two problems can arise:
+ //
+ // (a) If we "align forward" the current pointer to the next
+ // 16k boundary, the amount of padding we skip depends on the
+ // buffer's starting address. For example:
+ //
+ // buffer_start = 40960
+ // next 16k boundary = 49152
+ // bytes skipped = 49152 - 40960 = 8192
+ //
+ // This leaves behind variable padding that could be misinterpreted
+ // as part of the next section.
+ //
+ // (b) If we align relative to the buffer start, we may
+ // obtain addresses that are multiples of "buffer_start +
+ // section_alignment" rather than truly globally aligned
+ // addresses. For example:
+ //
+ // buffer_start = 40960 (5×8k, 8k aligned but not 16k)
+ // offset = 16384 (first multiple of 16k)
+ // section_ptr = 40960 + 16384 = 57344
+ //
+ // 57344 is 8k aligned but not 16k aligned.
+ // Any consumer expecting true 16k alignment would see this as a
+ // violation.
if (failed(alignmentValidator(alignment)))
return emitError("failed to align section ID: ", unsigned(sectionID));
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 807d1f5..bbfa3d1 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -36,7 +36,6 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -287,19 +286,7 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
// code.
struct LowerGpuOpsToROCDLOpsPass final
: public impl::ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
- LowerGpuOpsToROCDLOpsPass() = default;
- LowerGpuOpsToROCDLOpsPass(const std::string &chipset, unsigned indexBitwidth,
- bool useBarePtrCallConv,
- gpu::amd::Runtime runtime) {
- if (this->chipset.getNumOccurrences() == 0)
- this->chipset = chipset;
- if (this->indexBitwidth.getNumOccurrences() == 0)
- this->indexBitwidth = indexBitwidth;
- if (this->useBarePtrCallConv.getNumOccurrences() == 0)
- this->useBarePtrCallConv = useBarePtrCallConv;
- if (this->runtime.getNumOccurrences() == 0)
- this->runtime = runtime;
- }
+ using Base::Base;
void getDependentDialects(DialectRegistry &registry) const override {
Base::getDependentDialects(registry);
@@ -499,12 +486,3 @@ void mlir::populateGpuToROCDLConversionPatterns(
populateMathToROCDLConversionPatterns(converter, patterns);
}
-
-std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
-mlir::createLowerGpuOpsToROCDLOpsPass(const std::string &chipset,
- unsigned indexBitwidth,
- bool useBarePtrCallConv,
- gpu::amd::Runtime runtime) {
- return std::make_unique<LowerGpuOpsToROCDLOpsPass>(
- chipset, indexBitwidth, useBarePtrCallConv, runtime);
-}
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index e5496e5..aa47e39 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -405,7 +405,8 @@ std::unique_ptr<MPIImplTraits> MPIImplTraits::get(ModuleOp &moduleOp) {
return std::make_unique<OMPIImplTraits>(moduleOp);
if (!strAttr || strAttr.getValue() != "MPICH")
moduleOp.emitWarning() << "Unknown \"MPI:Implementation\" value in DLTI ("
- << strAttr.getValue() << "), defaulting to MPICH";
+ << (strAttr ? strAttr.getValue() : "<NULL>")
+ << "), defaulting to MPICH";
return std::make_unique<MPICHImplTraits>(moduleOp);
}
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index c4a9fc2..460595b 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -492,8 +492,10 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
// Create loop nest and populate region with contents of scf.parallel.
auto loopOp = omp::LoopNestOp::create(
- rewriter, parallelOp.getLoc(), parallelOp.getLowerBound(),
- parallelOp.getUpperBound(), parallelOp.getStep());
+ rewriter, parallelOp.getLoc(), parallelOp.getLowerBound().size(),
+ parallelOp.getLowerBound(), parallelOp.getUpperBound(),
+ parallelOp.getStep(), /*loop_inclusive=*/false,
+ /*tile_sizes=*/nullptr);
rewriter.inlineRegionBefore(parallelOp.getRegion(), loopOp.getRegion(),
loopOp.getRegion().begin());
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 9852df6..0b44ca7 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -95,6 +95,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorRankReducingFMAPattern(patterns);
populateVectorGatherLoweringPatterns(patterns);
populateVectorFromElementsLoweringPatterns(patterns);
+ populateVectorToElementsLoweringPatterns(patterns);
if (armI8MM) {
if (armNeon)
arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index a7f2dc2..9ead1d8 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -154,6 +154,9 @@ class CreateNdDescToXeVMPattern
matchAndRewrite(xegpu::CreateNdDescOp op,
xegpu::CreateNdDescOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
+ if (mixedOffsets.size() != 0)
+ return rewriter.notifyMatchFailure(op, "Offsets not supported.");
auto loc = op.getLoc();
auto source = op.getSource();
// Op is lowered to a code sequence that populates payload.
@@ -177,7 +180,6 @@ class CreateNdDescToXeVMPattern
// Source can be a memref or a pointer (ui64, ui32, i64 or i32).
SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
- SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
// Descriptor shape is expected to be 2D.
int64_t rank = mixedSizes.size();
if (rank != 2)
@@ -202,17 +204,9 @@ class CreateNdDescToXeVMPattern
val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
return val;
};
- // Offsets can be either 2D or not provided (0 is used).
- if (mixedOffsets.size() == 2) {
- offsetW = createOffset(mixedOffsets, 1);
- offsetH = createOffset(mixedOffsets, 0);
- } else if (mixedOffsets.size() == 0) {
- offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
- offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
- } else {
- return rewriter.notifyMatchFailure(op,
- "Expected 2D offsets or no offsets.");
- }
+ // Offsets are not supported (0 is used).
+ offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
+ offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
// Get shape values from op fold results.
baseShapeW = createOffset(mixedSizes, 1);
baseShapeH = createOffset(mixedSizes, 0);
@@ -247,39 +241,6 @@ class CreateNdDescToXeVMPattern
}
};
-class UpdateNdOffsetToXeVMPattern
- : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(xegpu::UpdateNdOffsetOp op,
- xegpu::UpdateNdOffsetOp::Adaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto loc = op.getLoc();
- auto mixedOffsets = op.getMixedOffsets();
- // Only 2D offsets are supported for now.
- if (mixedOffsets.size() != 2)
- return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
- auto payload = adaptor.getTensorDesc();
- // Utility for updating payload offset values from op fold result.
- auto updateOffset = [&](unsigned idx, int payloadPos) -> Value {
- Value offset =
- getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[idx]);
- offset = getValueOrCreateCastToIndexLike(rewriter, loc,
- rewriter.getI32Type(), offset);
- Value oldOffset =
- vector::ExtractOp::create(rewriter, loc, payload, payloadPos);
- Value newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset);
- return vector::InsertOp::create(rewriter, loc, newOffset, payload,
- payloadPos);
- };
- // Update offsets in the payload.
- payload = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH));
- payload = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW));
- rewriter.replaceOp(op, payload);
- return success();
- }
-};
-
template <
typename OpType,
typename = std::enable_if_t<llvm::is_one_of<
@@ -289,6 +250,10 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
LogicalResult
matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ auto mixedOffsets = op.getMixedOffsets();
+ int64_t opOffsetsSize = mixedOffsets.size();
+ if (opOffsetsSize != 2)
+ return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
auto loc = op.getLoc();
auto ctxt = rewriter.getContext();
@@ -311,32 +276,16 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
Value baseShapeH = vector::ExtractOp::create(
rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
- // Offsets provided in two ways:
- // 1. Offsets are extracted from the tensor descriptor.
- // 2. (Mixed) offsets which are provided by the op.
- Value offsetW;
- Value offsetH;
- auto mixedOffsets = op.getMixedOffsets();
- int64_t opOffsetsSize = mixedOffsets.size();
- if (opOffsetsSize != 0 && opOffsetsSize != 2)
- return rewriter.notifyMatchFailure(op,
- "Expected 2D offsets or no offsets.");
- if (opOffsetsSize) {
- // If mixed offsets are provided by the op convert them to i32.
- offsetW = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
- offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
- rewriter.getI32Type(), offsetW);
- offsetH = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
- offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
- rewriter.getI32Type(), offsetH);
- } else {
- // If offsets are not available, we need to extract them from the tensor
- // descriptor.
- offsetW = vector::ExtractOp::create(
- rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetW));
- offsetH = vector::ExtractOp::create(
- rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetH));
- }
+ // Offsets are provided by the op.
+ // convert them to i32.
+ Value offsetW =
+ getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
+ offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
+ rewriter.getI32Type(), offsetW);
+ Value offsetH =
+ getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+ offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
+ rewriter.getI32Type(), offsetH);
// Get address space from tensor descriptor memory space.
auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
@@ -422,54 +371,6 @@ static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
return newAddr;
}
-class CreateDescToXeVMPattern
- : public OpConversionPattern<xegpu::CreateDescOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(xegpu::CreateDescOp op, xegpu::CreateDescOp::Adaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto eTy = op.getTensorDescType().getElementType();
- auto eBw = eTy.getIntOrFloatBitWidth();
- if (eBw % 8 != 0)
- return rewriter.notifyMatchFailure(
- op, "Expected element type bit width to be multiple of 8.");
- auto loc = op.getLoc();
- // Offsets are provided as scalar i64 by type converter.
- auto offsets = adaptor.getOffsets();
- // Source type can be a 1D memref or pointer type (ui64, ui32, i64 or i32).
- // But type converter will convert them to integer types.
- Value addr = adaptor.getSource();
- // ui32 or i32 are passed as i32 so they need to be casted to i64.
- if (addr.getType() != rewriter.getI64Type())
- addr = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), addr);
- auto laneAddr = addOffset(rewriter, loc, addr, offsets, eBw / 8);
- rewriter.replaceOp(op, laneAddr);
- return success();
- }
-};
-
-class UpdateOffsetToXeVMPattern
- : public OpConversionPattern<xegpu::UpdateOffsetOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(xegpu::UpdateOffsetOp op,
- xegpu::UpdateOffsetOp::Adaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto eTy = op.getTensorDescType().getElementType();
- auto eBw = eTy.getIntOrFloatBitWidth();
- if (eBw % 8 != 0)
- return rewriter.notifyMatchFailure(
- op, "Expected element type bit width to be multiple of 8.");
- auto loc = op.getLoc();
- // Scatter descriptor is provided as scalar i64 by type converter.
- // Offsets are provided as scalar i64 by type converter.
- Value newOffset = addOffset(rewriter, loc, adaptor.getTensorDesc(),
- adaptor.getOffsets(), eBw / 8);
- rewriter.replaceOp(op, newOffset);
- return success();
- }
-};
-
template <typename OpType,
typename = std::enable_if_t<llvm::is_one_of<
OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
@@ -478,6 +379,9 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
LogicalResult
matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ Value offset = adaptor.getOffsets();
+ if (!offset)
+ return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
auto loc = op.getLoc();
auto ctxt = rewriter.getContext();
auto tdescTy = op.getTensorDescType();
@@ -527,21 +431,16 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
basePtrI64);
}
- Value offsets = adaptor.getOffsets();
Value mask = adaptor.getMask();
- if (offsets) {
- if (dyn_cast<VectorType>(offsets.getType())) {
- // Offset needs be scalar. Single element vector is converted to scalar
- // by type converter.
- return rewriter.notifyMatchFailure(op,
- "Expected offsets to be a scalar.");
- } else {
- // If offsets are provided, we add them to the base pointer.
- // Offsets are in number of elements, we need to multiply by
- // element byte size.
- basePtrI64 =
- addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
- }
+ if (dyn_cast<VectorType>(offset.getType())) {
+ // Offset needs be scalar. Single element vector is converted to scalar
+ // by type converter.
+ return rewriter.notifyMatchFailure(op, "Expected offset to be a scalar.");
+ } else {
+ // If offset is provided, we add them to the base pointer.
+ // Offset is in number of elements, we need to multiply by
+ // element byte size.
+ basePtrI64 = addOffset(rewriter, loc, basePtrI64, offset, elemByteSize);
}
// Convert base pointer (i64) to LLVM pointer type.
Value basePtrLLVM =
@@ -1011,13 +910,12 @@ struct ConvertXeGPUToXeVMPass
//===----------------------------------------------------------------------===//
void mlir::populateXeGPUToXeVMConversionPatterns(
const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
- patterns.add<CreateNdDescToXeVMPattern, UpdateNdOffsetToXeVMPattern,
+ patterns.add<CreateNdDescToXeVMPattern,
LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>,
LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>,
LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>(
typeConverter, patterns.getContext());
- patterns.add<CreateDescToXeVMPattern, UpdateOffsetToXeVMPattern,
- AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
+ patterns.add<AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
typeConverter, patterns.getContext());
diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
index 4dfcb2b..0f90acf 100644
--- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
+++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
@@ -98,127 +98,179 @@ std::string mangle(StringRef baseName, ArrayRef<Type> types,
return os.str();
}
-template <bool isLoad, typename OpType>
-int32_t getL1CacheControl(OpType op) {
+static int32_t getL1CacheControl(LoadCacheControl cc) {
int32_t control = 0;
- if constexpr (isLoad) {
- switch (*op.getCacheControl()) {
- case LoadCacheControl::L1UC_L2UC_L3UC:
- case LoadCacheControl::L1UC_L2UC_L3C:
- case LoadCacheControl::L1UC_L2C_L3UC:
- case LoadCacheControl::L1UC_L2C_L3C:
- control = 1;
- break;
- case LoadCacheControl::L1C_L2UC_L3UC:
- case LoadCacheControl::L1C_L2UC_L3C:
- case LoadCacheControl::L1C_L2C_L3UC:
- case LoadCacheControl::L1C_L2C_L3C:
- control = 2;
- break;
- case LoadCacheControl::L1S_L2UC_L3UC:
- case LoadCacheControl::L1S_L2UC_L3C:
- case LoadCacheControl::L1S_L2C_L3UC:
- case LoadCacheControl::L1S_L2C_L3C:
- control = 3;
- break;
- case LoadCacheControl::INVALIDATE_READ:
- control = 4;
- break;
- }
- } else {
- switch (*op.getCacheControl()) {
- case StoreCacheControl::L1UC_L2UC_L3UC:
- case StoreCacheControl::L1UC_L2UC_L3WB:
- case StoreCacheControl::L1UC_L2WB_L3UC:
- case StoreCacheControl::L1UC_L2WB_L3WB:
- control = 1;
- break;
- case StoreCacheControl::L1WT_L2UC_L3UC:
- case StoreCacheControl::L1WT_L2UC_L3WB:
- case StoreCacheControl::L1WT_L2WB_L3UC:
- case StoreCacheControl::L1WT_L2WB_L3WB:
- control = 2;
- break;
- case StoreCacheControl::L1S_L2UC_L3UC:
- case StoreCacheControl::L1S_L2UC_L3WB:
- case StoreCacheControl::L1S_L2WB_L3UC:
- case StoreCacheControl::L1S_L2WB_L3WB:
- control = 3;
- break;
- case StoreCacheControl::L1WB_L2UC_L3UC:
- case StoreCacheControl::L1WB_L2WB_L3UC:
- case StoreCacheControl::L1WB_L2UC_L3WB:
- control = 4;
- break;
- }
+ switch (cc) {
+ case LoadCacheControl::L1UC_L2UC_L3UC:
+ case LoadCacheControl::L1UC_L2UC_L3C:
+ case LoadCacheControl::L1UC_L2C_L3UC:
+ case LoadCacheControl::L1UC_L2C_L3C:
+ control = 1;
+ break;
+ case LoadCacheControl::L1C_L2UC_L3UC:
+ case LoadCacheControl::L1C_L2UC_L3C:
+ case LoadCacheControl::L1C_L2C_L3UC:
+ case LoadCacheControl::L1C_L2C_L3C:
+ control = 2;
+ break;
+ case LoadCacheControl::L1S_L2UC_L3UC:
+ case LoadCacheControl::L1S_L2UC_L3C:
+ case LoadCacheControl::L1S_L2C_L3UC:
+ case LoadCacheControl::L1S_L2C_L3C:
+ control = 3;
+ break;
+ case LoadCacheControl::INVALIDATE_READ:
+ control = 4;
+ break;
}
return control;
}
-template <bool isLoad, typename OpType>
-int32_t getL3CacheControl(OpType op) {
+static int32_t getL1CacheControl(StoreCacheControl cc) {
int32_t control = 0;
- if constexpr (isLoad) {
- switch (*op.getCacheControl()) {
- case LoadCacheControl::L1UC_L2UC_L3UC:
- case LoadCacheControl::L1UC_L2C_L3UC:
- case LoadCacheControl::L1C_L2UC_L3UC:
- case LoadCacheControl::L1C_L2C_L3UC:
- case LoadCacheControl::L1S_L2UC_L3UC:
- case LoadCacheControl::L1S_L2C_L3UC:
- control = 1;
- break;
- case LoadCacheControl::L1UC_L2UC_L3C:
- case LoadCacheControl::L1UC_L2C_L3C:
- case LoadCacheControl::L1C_L2UC_L3C:
- case LoadCacheControl::L1C_L2C_L3C:
- case LoadCacheControl::L1S_L2UC_L3C:
- case LoadCacheControl::L1S_L2C_L3C:
- control = 2;
- break;
- case LoadCacheControl::INVALIDATE_READ:
- control = 4;
- break;
- }
- } else {
- switch (*op.getCacheControl()) {
- case StoreCacheControl::L1UC_L2UC_L3UC:
- case StoreCacheControl::L1UC_L2WB_L3UC:
- case StoreCacheControl::L1WT_L2UC_L3UC:
- case StoreCacheControl::L1WT_L2WB_L3UC:
- case StoreCacheControl::L1S_L2UC_L3UC:
- case StoreCacheControl::L1S_L2WB_L3UC:
- case StoreCacheControl::L1WB_L2UC_L3UC:
- case StoreCacheControl::L1WB_L2WB_L3UC:
- control = 1;
- break;
- case StoreCacheControl::L1UC_L2UC_L3WB:
- case StoreCacheControl::L1UC_L2WB_L3WB:
- case StoreCacheControl::L1WT_L2UC_L3WB:
- case StoreCacheControl::L1WT_L2WB_L3WB:
- case StoreCacheControl::L1S_L2UC_L3WB:
- case StoreCacheControl::L1S_L2WB_L3WB:
- case StoreCacheControl::L1WB_L2UC_L3WB:
- control = 2;
- break;
- }
+ switch (cc) {
+ case StoreCacheControl::L1UC_L2UC_L3UC:
+ case StoreCacheControl::L1UC_L2UC_L3WB:
+ case StoreCacheControl::L1UC_L2WB_L3UC:
+ case StoreCacheControl::L1UC_L2WB_L3WB:
+ control = 1;
+ break;
+ case StoreCacheControl::L1WT_L2UC_L3UC:
+ case StoreCacheControl::L1WT_L2UC_L3WB:
+ case StoreCacheControl::L1WT_L2WB_L3UC:
+ case StoreCacheControl::L1WT_L2WB_L3WB:
+ control = 2;
+ break;
+ case StoreCacheControl::L1S_L2UC_L3UC:
+ case StoreCacheControl::L1S_L2UC_L3WB:
+ case StoreCacheControl::L1S_L2WB_L3UC:
+ case StoreCacheControl::L1S_L2WB_L3WB:
+ control = 3;
+ break;
+ case StoreCacheControl::L1WB_L2UC_L3UC:
+ case StoreCacheControl::L1WB_L2WB_L3UC:
+ case StoreCacheControl::L1WB_L2UC_L3WB:
+ control = 4;
+ break;
}
return control;
}
-template <bool isLoad, typename OpType>
+static int32_t getL3CacheControl(LoadCacheControl cc) {
+ int32_t control = 0;
+ switch (cc) {
+ case LoadCacheControl::L1UC_L2UC_L3UC:
+ case LoadCacheControl::L1UC_L2C_L3UC:
+ case LoadCacheControl::L1C_L2UC_L3UC:
+ case LoadCacheControl::L1C_L2C_L3UC:
+ case LoadCacheControl::L1S_L2UC_L3UC:
+ case LoadCacheControl::L1S_L2C_L3UC:
+ control = 1;
+ break;
+ case LoadCacheControl::L1UC_L2UC_L3C:
+ case LoadCacheControl::L1UC_L2C_L3C:
+ case LoadCacheControl::L1C_L2UC_L3C:
+ case LoadCacheControl::L1C_L2C_L3C:
+ case LoadCacheControl::L1S_L2UC_L3C:
+ case LoadCacheControl::L1S_L2C_L3C:
+ control = 2;
+ break;
+ case LoadCacheControl::INVALIDATE_READ:
+ control = 4;
+ break;
+ }
+ return control;
+}
+
+static int32_t getL3CacheControl(StoreCacheControl cc) {
+ int32_t control = 0;
+ switch (cc) {
+ case StoreCacheControl::L1UC_L2UC_L3UC:
+ case StoreCacheControl::L1UC_L2WB_L3UC:
+ case StoreCacheControl::L1WT_L2UC_L3UC:
+ case StoreCacheControl::L1WT_L2WB_L3UC:
+ case StoreCacheControl::L1S_L2UC_L3UC:
+ case StoreCacheControl::L1S_L2WB_L3UC:
+ case StoreCacheControl::L1WB_L2UC_L3UC:
+ case StoreCacheControl::L1WB_L2WB_L3UC:
+ control = 1;
+ break;
+ case StoreCacheControl::L1UC_L2UC_L3WB:
+ case StoreCacheControl::L1UC_L2WB_L3WB:
+ case StoreCacheControl::L1WT_L2UC_L3WB:
+ case StoreCacheControl::L1WT_L2WB_L3WB:
+ case StoreCacheControl::L1S_L2UC_L3WB:
+ case StoreCacheControl::L1S_L2WB_L3WB:
+ case StoreCacheControl::L1WB_L2UC_L3WB:
+ control = 2;
+ break;
+ }
+ return control;
+}
+
+static std::optional<LoadCacheControl> getCacheControl(PrefetchOp op) {
+ return op.getCacheControl();
+}
+
+static std::optional<LoadCacheControl> getCacheControl(BlockLoad2dOp op) {
+ return op.getCacheControl();
+}
+
+static std::optional<LoadCacheControl> getCacheControl(BlockPrefetch2dOp op) {
+ return op.getCacheControl();
+}
+
+static std::optional<StoreCacheControl> getCacheControl(BlockStore2dOp op) {
+ return op.getCacheControl();
+}
+
+static std::optional<LoadCacheControl> getCacheControl(LLVM::LoadOp op) {
+ if (op->hasAttr("cache_control")) {
+ auto attr = op->getAttrOfType<xevm::LoadCacheControlAttr>("cache_control");
+ if (!attr)
+ return std::nullopt;
+ return std::optional<LoadCacheControl>(attr.getValue());
+ }
+ return std::nullopt;
+}
+
+static std::optional<StoreCacheControl> getCacheControl(LLVM::StoreOp op) {
+ if (op->hasAttr("cache_control")) {
+ auto attr = op->getAttrOfType<xevm::StoreCacheControlAttr>("cache_control");
+ if (!attr)
+ return std::nullopt;
+ return std::optional<StoreCacheControl>(attr.getValue());
+ }
+ return std::nullopt;
+}
+
+template <typename OpType>
+int32_t getL1CacheControl(OpType op) {
+ return getL1CacheControl(*getCacheControl(op));
+}
+
+template <typename OpType>
+int32_t getL3CacheControl(OpType op) {
+ return getL3CacheControl(*getCacheControl(op));
+}
+
+template <typename OpType>
static std::optional<ArrayAttr>
getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
- if (!op.getCacheControl())
+ if (!getCacheControl(op))
return {};
constexpr int32_t decorationCacheControlArity{4};
constexpr int32_t loadCacheControlKey{6442};
constexpr int32_t storeCacheControlKey{6443};
+ constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> ||
+ std::is_same_v<OpType, BlockPrefetch2dOp> ||
+ std::is_same_v<OpType, LLVM::LoadOp> ||
+ std::is_same_v<OpType, PrefetchOp>;
const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
SmallVector<int32_t, decorationCacheControlArity> decorationsL1{
- controlKey, 0, getL1CacheControl<isLoad, OpType>(op), 0};
+ controlKey, 0, getL1CacheControl<OpType>(op), 0};
SmallVector<int32_t, decorationCacheControlArity> decorationsL3{
- controlKey, 1, getL3CacheControl<isLoad, OpType>(op), 0};
+ controlKey, 1, getL3CacheControl<OpType>(op), 0};
auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1);
auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3);
@@ -398,7 +450,7 @@ class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
rewriter, fnName, LLVM::LLVMVoidType::get(rewriter.getContext()),
argTypes, args, {}, funcAttr, op.getOperation());
if (std::optional<ArrayAttr> optCacheControls =
- getCacheControlMetadata<true>(rewriter, op))
+ getCacheControlMetadata(rewriter, op))
call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
rewriter.eraseOp(op);
return success();
@@ -557,7 +609,7 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()),
argTypes, args, paramAttrs, funcAttr, op.getOperation());
if (std::optional<ArrayAttr> optCacheControls =
- getCacheControlMetadata < isLoad || isPrefetch > (rewriter, op)) {
+ getCacheControlMetadata(rewriter, op)) {
call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
}
if constexpr (isLoad)
@@ -568,6 +620,21 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
return success();
}
};
+template <typename OpType>
+class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> {
+ using OpConversionPattern<OpType>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (!op->hasAttr("cache_control"))
+ return failure();
+ std::optional<ArrayAttr> optCacheControls =
+ getCacheControlMetadata(rewriter, op);
+ op->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
+ op->removeAttr("cache_control");
+ return success();
+ }
+};
//===----------------------------------------------------------------------===//
// Pass Definition
@@ -583,10 +650,8 @@ struct ConvertXeVMToLLVMPass
void runOnOperation() override {
ConversionTarget target(getContext());
- target.addLegalDialect<LLVM::LLVMDialect>();
- target.addIllegalDialect<XeVMDialect>();
RewritePatternSet patterns(&getContext());
- populateXeVMToLLVMConversionPatterns(patterns);
+ populateXeVMToLLVMConversionPatterns(target, patterns);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
@@ -611,7 +676,7 @@ struct XeVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
void populateConvertToLLVMConversionPatterns(
ConversionTarget &target, LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns) const final {
- populateXeVMToLLVMConversionPatterns(patterns);
+ populateXeVMToLLVMConversionPatterns(target, patterns);
}
};
} // namespace
@@ -620,12 +685,17 @@ struct XeVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
// Pattern Population
//===----------------------------------------------------------------------===//
-void ::mlir::populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns) {
+void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target,
+ RewritePatternSet &patterns) {
+ target.addDynamicallyLegalDialect<LLVM::LLVMDialect>(
+ [](Operation *op) { return !op->hasAttr("cache_control"); });
+ target.addIllegalDialect<XeVMDialect>();
patterns.add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>,
- MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern>(
- patterns.getContext());
+ MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern,
+ LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
+ LLVMLoadStoreToOCLPattern<LLVM::StoreOp>>(patterns.getContext());
}
void ::mlir::registerConvertXeVMToLLVMInterface(DialectRegistry &registry) {
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index 99ea20b..f38493b 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/IntegerSet.h"
#include "llvm/ADT/SetVector.h"
@@ -241,7 +242,98 @@ addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
return &node;
}
-bool MemRefDependenceGraph::init() {
+/// Returns the memref being read/written by a memref/affine load/store op.
+static Value getMemRef(Operation *memOp) {
+ if (auto memrefLoad = dyn_cast<memref::LoadOp>(memOp))
+ return memrefLoad.getMemRef();
+ if (auto affineLoad = dyn_cast<AffineReadOpInterface>(memOp))
+ return affineLoad.getMemRef();
+ if (auto memrefStore = dyn_cast<memref::StoreOp>(memOp))
+ return memrefStore.getMemRef();
+ if (auto affineStore = dyn_cast<AffineWriteOpInterface>(memOp))
+ return affineStore.getMemRef();
+ llvm_unreachable("unexpected op");
+}
+
+/// Returns true if there may be a dependence on `memref` from srcNode's
+/// memory ops to dstNode's memory ops, while using the affine memory
+/// dependence analysis checks. The method assumes that there is at least one
+/// memory op in srcNode's loads and stores on `memref`, and similarly for
+/// `dstNode`. `srcNode.op` and `destNode.op` are expected to be nested in the
+/// same block and so the dependences are tested at the depth of that block.
+static bool mayDependence(const Node &srcNode, const Node &dstNode,
+ Value memref) {
+ assert(srcNode.op->getBlock() == dstNode.op->getBlock());
+ if (!isa<AffineForOp>(srcNode.op) || !isa<AffineForOp>(dstNode.op))
+ return true;
+
+ // Conservatively handle dependences involving non-affine load/stores. Return
+ // true if there exists a conflicting read/write access involving such.
+
+ // Check whether there is a dependence from a source read/write op to a
+ // destination read/write one; all expected to be memref/affine load/store.
+ auto hasNonAffineDep = [&](ArrayRef<Operation *> srcMemOps,
+ ArrayRef<Operation *> dstMemOps) {
+ return llvm::any_of(srcMemOps, [&](Operation *srcOp) {
+ Value srcMemref = getMemRef(srcOp);
+ if (srcMemref != memref)
+ return false;
+ return llvm::find_if(dstMemOps, [&](Operation *dstOp) {
+ return srcMemref == getMemRef(dstOp);
+ }) != dstMemOps.end();
+ });
+ };
+
+ SmallVector<Operation *> dstOps;
+ // Between non-affine src stores and dst load/store.
+ llvm::append_range(dstOps, llvm::concat<Operation *const>(
+ dstNode.loads, dstNode.stores,
+ dstNode.memrefLoads, dstNode.memrefStores));
+ if (hasNonAffineDep(srcNode.memrefStores, dstOps))
+ return true;
+ // Between non-affine loads and dst stores.
+ dstOps.clear();
+ llvm::append_range(dstOps, llvm::concat<Operation *const>(
+ dstNode.stores, dstNode.memrefStores));
+ if (hasNonAffineDep(srcNode.memrefLoads, dstOps))
+ return true;
+ // Between affine stores and memref load/stores.
+ dstOps.clear();
+ llvm::append_range(dstOps, llvm::concat<Operation *const>(
+ dstNode.memrefLoads, dstNode.memrefStores));
+ if (hasNonAffineDep(srcNode.stores, dstOps))
+ return true;
+ // Between affine loads and memref stores.
+ dstOps.clear();
+ llvm::append_range(dstOps, dstNode.memrefStores);
+ if (hasNonAffineDep(srcNode.loads, dstOps))
+ return true;
+
+ // Affine load/store pairs. We don't need to check for locally allocated
+ // memrefs since the dependence analysis here is between mem ops from
+ // srcNode's for op to dstNode's for op at the depth at which those
+ // `affine.for` ops are nested, i.e., dependences at depth `d + 1` where
+ // `d` is the number of common surrounding loops.
+ for (auto *srcMemOp :
+ llvm::concat<Operation *const>(srcNode.stores, srcNode.loads)) {
+ MemRefAccess srcAcc(srcMemOp);
+ if (srcAcc.memref != memref)
+ continue;
+ for (auto *destMemOp :
+ llvm::concat<Operation *const>(dstNode.stores, dstNode.loads)) {
+ MemRefAccess destAcc(destMemOp);
+ if (destAcc.memref != memref)
+ continue;
+ // Check for a top-level dependence between srcNode and destNode's ops.
+ if (!noDependence(checkMemrefAccessDependence(
+ srcAcc, destAcc, getNestingDepth(srcNode.op) + 1)))
+ return true;
+ }
+ }
+ return false;
+}
+
+bool MemRefDependenceGraph::init(bool fullAffineDependences) {
LDBG() << "--- Initializing MDG ---";
// Map from a memref to the set of ids of the nodes that have ops accessing
// the memref.
@@ -344,8 +436,12 @@ bool MemRefDependenceGraph::init() {
Node *dstNode = getNode(dstId);
bool dstHasStoreOrFree =
dstNode->hasStore(srcMemRef) || dstNode->hasFree(srcMemRef);
- if (srcHasStoreOrFree || dstHasStoreOrFree)
- addEdge(srcId, dstId, srcMemRef);
+ if ((srcHasStoreOrFree || dstHasStoreOrFree)) {
+ // Check precise affine deps if asked for; otherwise, conservative.
+ if (!fullAffineDependences ||
+ mayDependence(*srcNode, *dstNode, srcMemRef))
+ addEdge(srcId, dstId, srcMemRef);
+ }
}
}
}
@@ -562,13 +658,13 @@ MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
}
// Build set of insts in range (srcId, dstId) which depend on 'srcId'.
- SmallPtrSet<Operation *, 2> srcDepInsts;
+ llvm::SmallPtrSet<Operation *, 2> srcDepInsts;
for (auto &outEdge : outEdges.lookup(srcId))
if (outEdge.id != dstId)
srcDepInsts.insert(getNode(outEdge.id)->op);
// Build set of insts in range (srcId, dstId) on which 'dstId' depends.
- SmallPtrSet<Operation *, 2> dstDepInsts;
+ llvm::SmallPtrSet<Operation *, 2> dstDepInsts;
for (auto &inEdge : inEdges.lookup(dstId))
if (inEdge.id != srcId)
dstDepInsts.insert(getNode(inEdge.id)->op);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index f0c1f44..f3db8f7c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3033,10 +3033,17 @@ void transform::TileReductionUsingForallOp::build(
}
DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
- transform::TransformRewriter &rewriter, LinalgOp target,
+ transform::TransformRewriter &rewriter, Operation *target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
rewriter.setInsertionPoint(target);
+
+ auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
+ if (!partialReductionOp) {
+ return emitSilenceableFailure(
+ target->getLoc(),
+ "Operation should implement PartialReductionOpInterface");
+ }
SmallVector<OpFoldResult> numThreads =
getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
SmallVector<OpFoldResult> tileSizes =
@@ -3058,14 +3065,14 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
extractFromIntegerArrayAttr<unsigned>(getReductionDims());
if (reductionDims.empty()) {
for (auto [idx, iteratorType] :
- llvm::enumerate(target.getIteratorTypesArray())) {
+ llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
if (iteratorType == utils::IteratorType::reduction)
reductionDims.push_back(idx);
}
}
options.setReductionDims(reductionDims);
- FailureOr<scf::SCFTilingResult> result = scf::tileUsingSCF(
- rewriter, cast<TilingInterface>(target.getOperation()), options);
+ FailureOr<scf::SCFTilingResult> result =
+ scf::tileUsingSCF(rewriter, partialReductionOp, options);
if (failed(result)) {
auto diag = emitSilenceableError() << "could not tile reduction";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
index b7da20c..9015cbb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp
@@ -208,7 +208,7 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
// Does it require broadcast?
if (!broadcastedDims.empty()) {
- assert(broadcastedDims.size() && "should have non size broadcast");
+ assert(!broadcastedDims.empty() && "should have non size broadcast");
Value emptyTensor = tensor::EmptyOp::create(rewriter, loc, outputShape,
inputRTType.getElementType());
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 6e43f28..3d70e28 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -56,6 +56,11 @@ makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef<bool> boolArray) {
return boolArray.empty() ? nullptr : DenseBoolArrayAttr::get(ctx, boolArray);
}
+static DenseI64ArrayAttr
+makeDenseI64ArrayAttr(MLIRContext *ctx, const ArrayRef<int64_t> intArray) {
+ return intArray.empty() ? nullptr : DenseI64ArrayAttr::get(ctx, intArray);
+}
+
namespace {
struct MemRefPointerLikeModel
: public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
@@ -2956,10 +2961,10 @@ ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) {
for (auto &iv : ivs)
iv.type = loopVarType;
+ auto *ctx = parser.getBuilder().getContext();
// Parse "inclusive" flag.
if (succeeded(parser.parseOptionalKeyword("inclusive")))
- result.addAttribute("loop_inclusive",
- UnitAttr::get(parser.getBuilder().getContext()));
+ result.addAttribute("loop_inclusive", UnitAttr::get(ctx));
// Parse step values.
SmallVector<OpAsmParser::UnresolvedOperand> steps;
@@ -2967,6 +2972,35 @@ ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) {
parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
return failure();
+ // Parse collapse
+ int64_t value = 0;
+ if (!parser.parseOptionalKeyword("collapse") &&
+ (parser.parseLParen() || parser.parseInteger(value) ||
+ parser.parseRParen()))
+ return failure();
+ if (value > 1)
+ result.addAttribute(
+ "collapse_num_loops",
+ IntegerAttr::get(parser.getBuilder().getI64Type(), value));
+
+ // Parse tiles
+ SmallVector<int64_t> tiles;
+ auto parseTiles = [&]() -> ParseResult {
+ int64_t tile;
+ if (parser.parseInteger(tile))
+ return failure();
+ tiles.push_back(tile);
+ return success();
+ };
+
+ if (!parser.parseOptionalKeyword("tiles") &&
+ (parser.parseLParen() || parser.parseCommaSeparatedList(parseTiles) ||
+ parser.parseRParen()))
+ return failure();
+
+ if (tiles.size() > 0)
+ result.addAttribute("tile_sizes", DenseI64ArrayAttr::get(ctx, tiles));
+
// Parse the body.
Region *region = result.addRegion();
if (parser.parseRegion(*region, ivs))
@@ -2990,14 +3024,23 @@ void LoopNestOp::print(OpAsmPrinter &p) {
if (getLoopInclusive())
p << "inclusive ";
p << "step (" << getLoopSteps() << ") ";
+ if (int64_t numCollapse = getCollapseNumLoops())
+ if (numCollapse > 1)
+ p << "collapse(" << numCollapse << ") ";
+
+ if (const auto tiles = getTileSizes())
+ p << "tiles(" << tiles.value() << ") ";
+
p.printRegion(region, /*printEntryBlockArgs=*/false);
}
void LoopNestOp::build(OpBuilder &builder, OperationState &state,
const LoopNestOperands &clauses) {
- LoopNestOp::build(builder, state, clauses.loopLowerBounds,
- clauses.loopUpperBounds, clauses.loopSteps,
- clauses.loopInclusive);
+ MLIRContext *ctx = builder.getContext();
+ LoopNestOp::build(builder, state, clauses.collapseNumLoops,
+ clauses.loopLowerBounds, clauses.loopUpperBounds,
+ clauses.loopSteps, clauses.loopInclusive,
+ makeDenseI64ArrayAttr(ctx, clauses.tileSizes));
}
LogicalResult LoopNestOp::verify() {
@@ -3013,6 +3056,17 @@ LogicalResult LoopNestOp::verify() {
<< "range argument type does not match corresponding IV type";
}
+ uint64_t numIVs = getIVs().size();
+
+ if (const auto &numCollapse = getCollapseNumLoops())
+ if (numCollapse > numIVs)
+ return emitOpError()
+ << "collapse value is larger than the number of loops";
+
+ if (const auto &tiles = getTileSizes())
+ if (tiles.value().size() > numIVs)
+ return emitOpError() << "too few canonical loops for tile dimensions";
+
if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
return emitOpError() << "expects parent op to be a loop wrapper";
@@ -3142,7 +3196,7 @@ void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
}
SmallString<64> Name("canonloop");
- for (std::string s : reverse(components)) {
+ for (const std::string &s : reverse(components)) {
Name += '_';
Name += s;
}
diff --git a/mlir/lib/Dialect/Quant/Transforms/NormalizeQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/NormalizeQuantTypes.cpp
index 1c3971d2..1940e6d 100644
--- a/mlir/lib/Dialect/Quant/Transforms/NormalizeQuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/NormalizeQuantTypes.cpp
@@ -88,10 +88,10 @@ class NormalizedQuantTypesConverter : public TypeConverter {
llvm::find_if(shape, [](int64_t dim) { return dim != 1; });
auto scales = llvm::to_vector(llvm::map_range(
subChannelType.getScales().getValues<APFloat>(),
- [](APFloat scale) { return scale.convertToDouble(); }));
+ [](const APFloat &scale) { return scale.convertToDouble(); }));
auto zeroPoints = llvm::to_vector(llvm::map_range(
subChannelType.getZeroPoints().getValues<APInt>(),
- [](APInt zeroPoint) { return zeroPoint.getSExtValue(); }));
+ [](const APInt &zeroPoint) { return zeroPoint.getSExtValue(); }));
auto perAxisType = UniformQuantizedPerAxisType::get(
subChannelType.getFlags(), subChannelType.getStorageType(),
subChannelType.getExpressedType(), scales, zeroPoints,
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index f993398..5511998 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -723,7 +723,9 @@ void mlir::spirv::ConstantOp::getAsmResultNames(
IntegerType intTy = llvm::dyn_cast<IntegerType>(type);
if (IntegerAttr intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
- if (intTy && intTy.getWidth() == 1) {
+ assert(intTy);
+
+ if (intTy.getWidth() == 1) {
return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
}
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 87bed81..b4c87a3 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -512,20 +512,20 @@ void ReduceMinOp::print(OpAsmPrinter &parser) {
// Tosa utilities.
//===----------------------------------------------------------------------===//
-std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
+static std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
if (lhs % rhs != 0)
return std::nullopt;
return lhs / rhs;
}
-Type getStorageElementTypeOrSelf(Type type) {
+static Type getStorageElementTypeOrSelf(Type type) {
auto srcType = getElementTypeOrSelf(type);
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
srcType = quantType.getStorageType();
return srcType;
}
-Type getStorageElementTypeOrSelf(Value value) {
+static Type getStorageElementTypeOrSelf(Value value) {
return getStorageElementTypeOrSelf(value.getType());
}
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 706076c..aba6178 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1639,8 +1639,8 @@ transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
return a.size() < b.size();
})->size();
- for (size_t argIdx = 0; argIdx < payloads.size(); argIdx++)
- payloads[argIdx].resize(numIterations);
+ for (auto &payload : payloads)
+ payload.resize(numIterations);
}
// As we will be "zipping" over them, check all payloads have the same size.
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index fe066dc..6bb390a 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -144,6 +144,11 @@ void transform::ApplyUnrollFromElementsPatternsOp::populatePatterns(
vector::populateVectorFromElementsLoweringPatterns(patterns);
}
+void transform::ApplyUnrollToElementsPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ vector::populateVectorToElementsLoweringPatterns(patterns);
+}
+
void transform::ApplyLowerScanPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::populateVectorScanLoweringPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index acbf2b7..d74007f1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
LowerVectorScan.cpp
LowerVectorShapeCast.cpp
LowerVectorStep.cpp
+ LowerVectorToElements.cpp
LowerVectorToFromElementsToShuffleTree.cpp
LowerVectorTransfer.cpp
LowerVectorTranspose.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
new file mode 100644
index 0000000..a53a183
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
@@ -0,0 +1,53 @@
+//===- LowerVectorToElements.cpp - Lower 'vector.to_elements' op ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.to_elements' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+
+#define DEBUG_TYPE "lower-vector-to-elements"
+
+using namespace mlir;
+
+namespace {
+
+struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ToElementsOp op,
+ PatternRewriter &rewriter) const override {
+
+ TypedValue<VectorType> source = op.getSource();
+ FailureOr<SmallVector<Value>> result =
+ vector::unrollVectorValue(source, rewriter);
+ if (failed(result)) {
+ return failure();
+ }
+ SmallVector<Value> vectors = *result;
+
+ SmallVector<Value> results;
+ for (const Value &vector : vectors) {
+ auto subElements =
+ vector::ToElementsOp::create(rewriter, op.getLoc(), vector);
+ llvm::append_range(results, subElements.getResults());
+ }
+ rewriter.replaceOp(op, results);
+ return success();
+ }
+};
+
+} // namespace
+
+void mlir::vector::populateVectorToElementsLoweringPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<UnrollToElements>(patterns.getContext(), benefit);
+}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index c84eb2c..995a259 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -371,6 +371,38 @@ static VectorType getDistributedType(VectorType originalType, AffineMap map,
return targetType;
}
+/// Given a warpOp that contains ops with regions, the corresponding op's
+/// "inner" region and the distributionMapFn, get all values used by the op's
+/// region that are defined within the warpOp, but outside the inner region.
+/// Return the set of values, their types and their distributed types.
+std::tuple<llvm::SmallSetVector<Value, 32>, SmallVector<Type>,
+ SmallVector<Type>>
+getInnerRegionEscapingValues(WarpExecuteOnLane0Op warpOp, Region &innerRegion,
+ DistributionMapFn distributionMapFn) {
+ llvm::SmallSetVector<Value, 32> escapingValues;
+ SmallVector<Type> escapingValueTypes;
+ SmallVector<Type> escapingValueDistTypes; // to yield from the new warpOp
+ if (innerRegion.empty())
+ return {std::move(escapingValues), std::move(escapingValueTypes),
+ std::move(escapingValueDistTypes)};
+ mlir::visitUsedValuesDefinedAbove(innerRegion, [&](OpOperand *operand) {
+ Operation *parent = operand->get().getParentRegion()->getParentOp();
+ if (warpOp->isAncestor(parent)) {
+ if (!escapingValues.insert(operand->get()))
+ return;
+ Type distType = operand->get().getType();
+ if (auto vecType = dyn_cast<VectorType>(distType)) {
+ AffineMap map = distributionMapFn(operand->get());
+ distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+ }
+ escapingValueTypes.push_back(operand->get().getType());
+ escapingValueDistTypes.push_back(distType);
+ }
+ });
+ return {std::move(escapingValues), std::move(escapingValueTypes),
+ std::move(escapingValueDistTypes)};
+}
+
/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`. Writes of size more than `maxNumElementToExtract`
/// will not be distributed (it should be less than the warp size).
@@ -1713,6 +1745,215 @@ struct WarpOpInsert : public WarpDistributionPattern {
}
};
+/// Sink scf.if out of WarpExecuteOnLane0Op. This can be done only if
+/// the scf.if is the last operation in the region so that it doesn't
+/// change the order of execution. This creates a new scf.if after the
+/// WarpExecuteOnLane0Op. Each branch of the new scf.if is enclosed in
+/// the "inner" WarpExecuteOnLane0Op. Example:
+/// ```
+/// gpu.warp_execute_on_lane_0(%laneid)[32] {
+/// %payload = ... : vector<32xindex>
+/// scf.if %pred {
+/// vector.store %payload, %buffer[%idx] : memref<128xindex>,
+/// vector<32xindex>
+/// }
+/// gpu.yield
+/// }
+/// ```
+/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] {
+/// %payload = ... : vector<32xindex>
+/// gpu.yield %payload : vector<32xindex>
+/// }
+/// scf.if %pred {
+/// gpu.warp_execute_on_lane_0(%laneid)[32] args(%r : vector<1xindex>) {
+/// ^bb0(%arg1: vector<32xindex>):
+/// vector.store %arg1, %buffer[%idx] : memref<128xindex>, vector<32xindex>
+/// }
+/// }
+/// ```
+struct WarpOpScfIfOp : public WarpDistributionPattern {
+ WarpOpScfIfOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
+ : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
+ LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ gpu::YieldOp warpOpYield = warpOp.getTerminator();
+ // Only pick up `IfOp` if it is the last op in the region.
+ Operation *lastNode = warpOpYield->getPrevNode();
+ auto ifOp = dyn_cast_or_null<scf::IfOp>(lastNode);
+ if (!ifOp)
+ return failure();
+
+ // The current `WarpOp` can yield two types of values:
+ // 1. Not results of `IfOp`:
+ // Preserve them in the new `WarpOp`.
+ // Collect their yield index to remap the usages.
+ // 2. Results of `IfOp`:
+ // They are not part of the new `WarpOp` results.
+ // Map current warp's yield operand index to `IfOp` result idx.
+ SmallVector<Value> nonIfYieldValues;
+ SmallVector<unsigned> nonIfYieldIndices;
+ llvm::SmallDenseMap<unsigned, unsigned> ifResultMapping;
+ llvm::SmallDenseMap<unsigned, VectorType> ifResultDistTypes;
+ for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
+ const unsigned yieldOperandIdx = yieldOperand.getOperandNumber();
+ if (yieldOperand.get().getDefiningOp() != ifOp.getOperation()) {
+ nonIfYieldValues.push_back(yieldOperand.get());
+ nonIfYieldIndices.push_back(yieldOperandIdx);
+ continue;
+ }
+ OpResult ifResult = cast<OpResult>(yieldOperand.get());
+ const unsigned ifResultIdx = ifResult.getResultNumber();
+ ifResultMapping[yieldOperandIdx] = ifResultIdx;
+ // If this `ifOp` result is vector type and it is yielded by the
+ // `WarpOp`, we keep track the distributed type for this result.
+ if (!isa<VectorType>(ifResult.getType()))
+ continue;
+ VectorType distType =
+ cast<VectorType>(warpOp.getResult(yieldOperandIdx).getType());
+ ifResultDistTypes[ifResultIdx] = distType;
+ }
+
+ // Collect `WarpOp`-defined values used in `ifOp`, the new warp op returns
+ // them
+ auto [escapingValuesThen, escapingValueInputTypesThen,
+ escapingValueDistTypesThen] =
+ getInnerRegionEscapingValues(warpOp, ifOp.getThenRegion(),
+ distributionMapFn);
+ auto [escapingValuesElse, escapingValueInputTypesElse,
+ escapingValueDistTypesElse] =
+ getInnerRegionEscapingValues(warpOp, ifOp.getElseRegion(),
+ distributionMapFn);
+ if (llvm::is_contained(escapingValueDistTypesThen, Type{}) ||
+ llvm::is_contained(escapingValueDistTypesElse, Type{}))
+ return failure();
+
+ // The new `WarpOp` groups yields values in following order:
+ // 1. Branch condition
+ // 2. Escaping values then branch
+ // 3. Escaping values else branch
+ // 4. All non-`ifOp` yielded values.
+ SmallVector<Value> newWarpOpYieldValues{ifOp.getCondition()};
+ newWarpOpYieldValues.append(escapingValuesThen.begin(),
+ escapingValuesThen.end());
+ newWarpOpYieldValues.append(escapingValuesElse.begin(),
+ escapingValuesElse.end());
+ SmallVector<Type> newWarpOpDistTypes{ifOp.getCondition().getType()};
+ newWarpOpDistTypes.append(escapingValueDistTypesThen.begin(),
+ escapingValueDistTypesThen.end());
+ newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(),
+ escapingValueDistTypesElse.end());
+
+ llvm::SmallDenseMap<unsigned, unsigned> origToNewYieldIdx;
+ for (auto [idx, val] :
+ llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) {
+ origToNewYieldIdx[idx] = newWarpOpYieldValues.size();
+ newWarpOpYieldValues.push_back(val);
+ newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType());
+ }
+ // Create the new `WarpOp` with the updated yield values and types.
+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
+ rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
+ // `ifOp` returns the result of the inner warp op.
+ SmallVector<Type> newIfOpDistResTypes;
+ for (auto [i, res] : llvm::enumerate(ifOp.getResults())) {
+ Type distType = cast<Value>(res).getType();
+ if (auto vecType = dyn_cast<VectorType>(distType)) {
+ AffineMap map = distributionMapFn(cast<Value>(res));
+ // Fallback to affine map if the dist result was not previously recorded
+ distType = ifResultDistTypes.count(i)
+ ? ifResultDistTypes[i]
+ : getDistributedType(vecType, map, warpOp.getWarpSize());
+ }
+ newIfOpDistResTypes.push_back(distType);
+ }
+ // Create a new `IfOp` outside the new `WarpOp` region.
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPointAfter(newWarpOp);
+ auto newIfOp = scf::IfOp::create(
+ rewriter, ifOp.getLoc(), newIfOpDistResTypes, newWarpOp.getResult(0),
+ static_cast<bool>(ifOp.thenBlock()),
+ static_cast<bool>(ifOp.elseBlock()));
+ auto encloseRegionInWarpOp =
+ [&](Block *oldIfBranch, Block *newIfBranch,
+ llvm::SmallSetVector<Value, 32> &escapingValues,
+ SmallVector<Type> &escapingValueInputTypes,
+ size_t warpResRangeStart) {
+ OpBuilder::InsertionGuard g(rewriter);
+ if (!newIfBranch)
+ return;
+ rewriter.setInsertionPointToStart(newIfBranch);
+ llvm::SmallDenseMap<Value, int64_t> escapeValToBlockArgIndex;
+ SmallVector<Value> innerWarpInputVals;
+ SmallVector<Type> innerWarpInputTypes;
+ for (size_t i = 0; i < escapingValues.size();
+ ++i, ++warpResRangeStart) {
+ innerWarpInputVals.push_back(
+ newWarpOp.getResult(warpResRangeStart));
+ escapeValToBlockArgIndex[escapingValues[i]] =
+ innerWarpInputTypes.size();
+ innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
+ }
+ auto innerWarp = WarpExecuteOnLane0Op::create(
+ rewriter, newWarpOp.getLoc(), newIfOp.getResultTypes(),
+ newWarpOp.getLaneid(), newWarpOp.getWarpSize(),
+ innerWarpInputVals, innerWarpInputTypes);
+
+ innerWarp.getWarpRegion().takeBody(*oldIfBranch->getParent());
+ innerWarp.getWarpRegion().addArguments(
+ innerWarpInputTypes,
+ SmallVector<Location>(innerWarpInputTypes.size(), ifOp.getLoc()));
+
+ SmallVector<Value> yieldOperands;
+ for (Value operand : oldIfBranch->getTerminator()->getOperands())
+ yieldOperands.push_back(operand);
+ rewriter.eraseOp(oldIfBranch->getTerminator());
+
+ rewriter.setInsertionPointToEnd(innerWarp.getBody());
+ gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
+ rewriter.setInsertionPointAfter(innerWarp);
+ scf::YieldOp::create(rewriter, ifOp.getLoc(), innerWarp.getResults());
+
+ // Update any users of escaping values that were forwarded to the
+ // inner `WarpOp`. These values are arguments of the inner `WarpOp`.
+ innerWarp.walk([&](Operation *op) {
+ for (OpOperand &operand : op->getOpOperands()) {
+ auto it = escapeValToBlockArgIndex.find(operand.get());
+ if (it == escapeValToBlockArgIndex.end())
+ continue;
+ operand.set(innerWarp.getBodyRegion().getArgument(it->second));
+ }
+ });
+ mlir::vector::moveScalarUniformCode(innerWarp);
+ };
+ encloseRegionInWarpOp(&ifOp.getThenRegion().front(),
+ &newIfOp.getThenRegion().front(), escapingValuesThen,
+ escapingValueInputTypesThen, 1);
+ if (!ifOp.getElseRegion().empty())
+ encloseRegionInWarpOp(&ifOp.getElseRegion().front(),
+ &newIfOp.getElseRegion().front(),
+ escapingValuesElse, escapingValueInputTypesElse,
+ 1 + escapingValuesThen.size());
+ // Update the users of `<- WarpOp.yield <- IfOp.yield` to use the new `IfOp`
+ // result.
+ for (auto [origIdx, newIdx] : ifResultMapping)
+ rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
+ newIfOp.getResult(newIdx), newIfOp);
+ // Similarly, update any users of the `WarpOp` results that were not
+ // results of the `IfOp`.
+ for (auto [origIdx, newIdx] : origToNewYieldIdx)
+ rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
+ newWarpOp.getResult(newIdx));
+ // Remove the original `WarpOp` and `IfOp`, they should not have any uses
+ // at this point.
+ rewriter.eraseOp(ifOp);
+ rewriter.eraseOp(warpOp);
+ return success();
+ }
+
+private:
+ DistributionMapFn distributionMapFn;
+};
+
/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
/// the scf.ForOp is the last operation in the region so that it doesn't
/// change the order of execution. This creates a new scf.for region after the
@@ -1759,25 +2000,9 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
return failure();
// Collect Values that come from the `WarpOp` but are outside the `ForOp`.
// Those Values need to be returned by the new warp op.
- llvm::SmallSetVector<Value, 32> escapingValues;
- SmallVector<Type> escapingValueInputTypes;
- SmallVector<Type> escapingValueDistTypes;
- mlir::visitUsedValuesDefinedAbove(
- forOp.getBodyRegion(), [&](OpOperand *operand) {
- Operation *parent = operand->get().getParentRegion()->getParentOp();
- if (warpOp->isAncestor(parent)) {
- if (!escapingValues.insert(operand->get()))
- return;
- Type distType = operand->get().getType();
- if (auto vecType = dyn_cast<VectorType>(distType)) {
- AffineMap map = distributionMapFn(operand->get());
- distType = getDistributedType(vecType, map, warpOp.getWarpSize());
- }
- escapingValueInputTypes.push_back(operand->get().getType());
- escapingValueDistTypes.push_back(distType);
- }
- });
-
+ auto [escapingValues, escapingValueInputTypes, escapingValueDistTypes] =
+ getInnerRegionEscapingValues(warpOp, forOp.getBodyRegion(),
+ distributionMapFn);
if (llvm::is_contained(escapingValueDistTypes, Type{}))
return failure();
// `WarpOp` can yield two types of values:
@@ -2068,6 +2293,8 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
benefit);
patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
benefit);
+ patterns.add<WarpOpScfIfOp>(patterns.getContext(), distributionMapFn,
+ benefit);
}
void mlir::vector::populateDistributeReduction(
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 7dde631..12acf4b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -798,6 +798,51 @@ struct LinearizeVectorFromElements final
}
};
+/// This pattern linearizes the operand in `vector.to_elements` operations
+/// by converting the source type to a 1-D vector while preserving all element
+/// values. The transformation creates a linearized `vector.shape_cast`
+/// followed by a `vector.to_elements`.
+///
+/// Example:
+///
+/// %0:4 = vector.to_elements %v : vector<2x2xf32>
+///
+/// is converted to:
+///
+/// %vector_cast = vector.shape_cast %v : vector<2x2xf32> to vector<4xf32>
+/// %0:4 = vector.to_elements %vector_cast : vector<4xf32>
+///
+struct LinearizeVectorToElements final
+ : public OpConversionPattern<vector::ToElementsOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LinearizeVectorToElements(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+
+ LogicalResult
+ matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ VectorType vecType = toElementsOp.getSource().getType();
+ if (vecType.getRank() <= 1)
+ return rewriter.notifyMatchFailure(
+ toElementsOp, "the rank is already less than or equal to 1");
+
+ assert(vecType.getNumScalableDims() == 0 &&
+ "to_elements does not support scalable vectors");
+ auto vec1DType =
+ VectorType::get({vecType.getNumElements()}, vecType.getElementType());
+ Value shapeCast = vector::ShapeCastOp::create(
+ rewriter, toElementsOp.getLoc(), vec1DType, toElementsOp.getSource());
+ auto newToElementsOp =
+ vector::ToElementsOp::create(rewriter, toElementsOp.getLoc(),
+ toElementsOp.getResultTypes(), shapeCast);
+ rewriter.replaceOp(toElementsOp, newToElementsOp);
+ return success();
+ }
+};
+
} // namespace
/// This method defines the set of operations that are linearizable, and hence
@@ -890,8 +935,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
patterns
.add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
- LinearizeVectorStore, LinearizeVectorFromElements>(
- typeConverter, patterns.getContext());
+ LinearizeVectorStore, LinearizeVectorFromElements,
+ LinearizeVectorToElements>(typeConverter, patterns.getContext());
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 841e138..39dc7a4f 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -393,6 +393,41 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
return success();
}
+/// Takes a 2+ dimensional vector as an input
+/// returns n vector values produced by n vector.extract operations.
+/// I.e. calling unrollVectorValue([[%v]], rewriter) such that
+///
+/// %v : vector<nxaxb...>
+///
+/// will produce the following IR changes
+///
+/// %v0 = vector.extract %v[0] : vector<axbx...> from vector<nxaxb...>
+/// %v1 = vector.extract %v[1] : vector<axbx...> from vector<nxaxb...>
+/// ...
+/// %vnminusone = vector.extract %v[n-1] : vector<axbx...> from ...
+///
+/// and returns SmallVector<Value> r = {[[%v0]], [[%v1]], ..., [[%vnminusone]]}
+FailureOr<SmallVector<Value>>
+vector::unrollVectorValue(TypedValue<VectorType> vector,
+ RewriterBase &rewriter) {
+ SmallVector<Value> subvectors;
+ VectorType ty = cast<VectorType>(vector.getType());
+ Location loc = vector.getLoc();
+ if (ty.getRank() < 2)
+ return rewriter.notifyMatchFailure(loc, "already 1-D");
+
+ // Unrolling doesn't take vscale into account. Pattern is disabled for
+ // vectors with leading scalable dim(s).
+ if (ty.getScalableDims().front())
+ return rewriter.notifyMatchFailure(loc, "cannot unroll scalable dim");
+
+ for (int64_t i = 0, e = ty.getShape().front(); i < e; ++i) {
+ subvectors.push_back(vector::ExtractOp::create(rewriter, loc, vector, i));
+ }
+
+ return subvectors;
+}
+
LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter,
vector::UnrollVectorOpFn unrollFn) {
assert(op->getNumResults() == 1 && "expected single result");
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index af923d9..fdc1984 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -2,7 +2,6 @@ set(LLVM_OPTIONAL_SOURCES
CallInterfaces.cpp
CastInterfaces.cpp
ControlFlowInterfaces.cpp
- CopyOpInterface.cpp
DataLayoutInterfaces.cpp
DerivedAttributeOpInterface.cpp
DestinationStyleOpInterface.cpp
@@ -43,7 +42,6 @@ endfunction(add_mlir_interface_library)
add_mlir_interface_library(CallInterfaces)
add_mlir_interface_library(CastInterfaces)
add_mlir_interface_library(ControlFlowInterfaces)
-add_mlir_interface_library(CopyOpInterface)
add_mlir_interface_library(DataLayoutInterfaces)
add_mlir_interface_library(DerivedAttributeOpInterface)
add_mlir_interface_library(DestinationStyleOpInterface)
diff --git a/mlir/lib/Interfaces/CopyOpInterface.cpp b/mlir/lib/Interfaces/CopyOpInterface.cpp
deleted file mode 100644
index 8e6132c..0000000
--- a/mlir/lib/Interfaces/CopyOpInterface.cpp
+++ /dev/null
@@ -1,18 +0,0 @@
-//===- CopyOpInterface.cpp - Copy operations interface in MLIR ------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Interfaces/CopyOpInterface.h"
-
-using namespace mlir;
-
-//===----------------------------------------------------------------------===//
-// CopyOp Interface
-//===----------------------------------------------------------------------===//
-
-/// Include the definitions of the copy operation interface.
-#include "mlir/Interfaces/CopyOpInterface.cpp.inc"
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index caa9091..d2bafb7 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -6,6 +6,8 @@
//
//===----------------------------------------------------------------------===//
+#include <utility>
+
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -151,7 +153,7 @@ ValueBoundsConstraintSet::Variable::Variable(AffineMap map,
[](Value v) { return Variable(v); })) {}
ValueBoundsConstraintSet::ValueBoundsConstraintSet(
- MLIRContext *ctx, StopConditionFn stopCondition,
+ MLIRContext *ctx, const StopConditionFn &stopCondition,
bool addConservativeSemiAffineBounds)
: builder(ctx), stopCondition(stopCondition),
addConservativeSemiAffineBounds(addConservativeSemiAffineBounds) {
@@ -302,7 +304,8 @@ int64_t ValueBoundsConstraintSet::insert(bool isSymbol) {
return pos;
}
-int64_t ValueBoundsConstraintSet::insert(AffineMap map, ValueDimList operands,
+int64_t ValueBoundsConstraintSet::insert(AffineMap map,
+ const ValueDimList &operands,
bool isSymbol) {
assert(map.getNumResults() == 1 && "expected affine map with one result");
int64_t pos = insert(isSymbol);
@@ -629,7 +632,7 @@ LogicalResult ValueBoundsConstraintSet::computeIndependentBound(
FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
presburger::BoundType type, const Variable &var,
- StopConditionFn stopCondition, bool closedUB) {
+ const StopConditionFn &stopCondition, bool closedUB) {
// Default stop condition if none was specified: Keep adding constraints until
// a bound could be computed.
int64_t pos = 0;
@@ -666,7 +669,7 @@ void ValueBoundsConstraintSet::populateConstraints(Value value,
int64_t ValueBoundsConstraintSet::populateConstraints(AffineMap map,
ValueDimList operands) {
- int64_t pos = insert(map, operands, /*isSymbol=*/false);
+ int64_t pos = insert(map, std::move(operands), /*isSymbol=*/false);
// Process the backward slice of `operands` (i.e., reverse use-def chain)
// until `stopCondition` is met.
processWorklist();
@@ -826,10 +829,9 @@ FailureOr<bool> ValueBoundsConstraintSet::areEqual(const Variable &var1,
return strongCompare(var1, ComparisonOperator::EQ, var2);
}
-FailureOr<bool>
-ValueBoundsConstraintSet::areOverlappingSlices(MLIRContext *ctx,
- HyperrectangularSlice slice1,
- HyperrectangularSlice slice2) {
+FailureOr<bool> ValueBoundsConstraintSet::areOverlappingSlices(
+ MLIRContext *ctx, const HyperrectangularSlice &slice1,
+ const HyperrectangularSlice &slice2) {
assert(slice1.getMixedOffsets().size() == slice2.getMixedOffsets().size() &&
"expected slices of same rank");
assert(slice1.getMixedSizes().size() == slice2.getMixedSizes().size() &&
@@ -891,10 +893,9 @@ ValueBoundsConstraintSet::areOverlappingSlices(MLIRContext *ctx,
return true;
}
-FailureOr<bool>
-ValueBoundsConstraintSet::areEquivalentSlices(MLIRContext *ctx,
- HyperrectangularSlice slice1,
- HyperrectangularSlice slice2) {
+FailureOr<bool> ValueBoundsConstraintSet::areEquivalentSlices(
+ MLIRContext *ctx, const HyperrectangularSlice &slice1,
+ const HyperrectangularSlice &slice2) {
assert(slice1.getMixedOffsets().size() == slice2.getMixedOffsets().size() &&
"expected slices of same rank");
assert(slice1.getMixedSizes().size() == slice2.getMixedSizes().size() &&
diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
index da86b00..926ffd0 100644
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -385,7 +385,8 @@ void Operator::populateTypeInferenceInfo(
if (getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
// Check for a non-variable length operand to use as the type anchor.
auto *operandI = llvm::find_if(arguments, [](const Argument &arg) {
- NamedTypeConstraint *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg);
+ NamedTypeConstraint *operand =
+ llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg);
return operand && !operand->isVariableLength();
});
if (operandI == arguments.end())
@@ -663,15 +664,17 @@ void Operator::populateOpStructure() {
argDef = argDef->getValueAsDef("constraint");
if (argDef->isSubClassOf(typeConstraintClass)) {
- attrOrOperandMapping.push_back(
- {OperandOrAttribute::Kind::Operand, operandIndex});
+ attrPropOrOperandMapping.push_back(
+ {OperandAttrOrProp::Kind::Operand, operandIndex});
arguments.emplace_back(&operands[operandIndex++]);
} else if (argDef->isSubClassOf(attrClass)) {
- attrOrOperandMapping.push_back(
- {OperandOrAttribute::Kind::Attribute, attrIndex});
+ attrPropOrOperandMapping.push_back(
+ {OperandAttrOrProp::Kind::Attribute, attrIndex});
arguments.emplace_back(&attributes[attrIndex++]);
} else {
assert(argDef->isSubClassOf(propertyClass));
+ attrPropOrOperandMapping.push_back(
+ {OperandAttrOrProp::Kind::Property, propIndex});
arguments.emplace_back(&properties[propIndex++]);
}
}
@@ -867,9 +870,8 @@ auto Operator::VariableDecoratorIterator::unwrap(const Init *init)
return VariableDecorator(cast<DefInit>(init)->getDef());
}
-auto Operator::getArgToOperandOrAttribute(int index) const
- -> OperandOrAttribute {
- return attrOrOperandMapping[index];
+auto Operator::getArgToOperandAttrOrProp(int index) const -> OperandAttrOrProp {
+ return attrPropOrOperandMapping[index];
}
std::string Operator::getGetterName(StringRef name) const {
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
index 8d5d7f9..44732d5 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h"
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
#include "mlir/Support/LLVM.h"
@@ -21,6 +22,7 @@
#include "llvm/IR/InlineAsm.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/MemoryModelRelaxationAnnotations.h"
using namespace mlir;
using namespace mlir::LLVM;
@@ -88,6 +90,7 @@ static ArrayRef<unsigned> getSupportedMetadataImpl(llvm::LLVMContext &context) {
llvm::LLVMContext::MD_alias_scope,
llvm::LLVMContext::MD_dereferenceable,
llvm::LLVMContext::MD_dereferenceable_or_null,
+ llvm::LLVMContext::MD_mmra,
context.getMDKindID(vecTypeHintMDName),
context.getMDKindID(workGroupSizeHintMDName),
context.getMDKindID(reqdWorkGroupSizeMDName),
@@ -212,6 +215,39 @@ static LogicalResult setDereferenceableAttr(const llvm::MDNode *node,
return success();
}
+/// Convert the given MMRA metadata (either an MMRA tag or an array of them)
+/// into corresponding MLIR attributes and set them on the given operation as a
+/// discardable `llvm.mmra` attribute.
+static LogicalResult setMmraAttr(llvm::MDNode *node, Operation *op,
+ LLVM::ModuleImport &moduleImport) {
+ if (!node)
+ return success();
+
+ // We don't use the LLVM wrappers here becasue we care about the order
+ // of the metadata for deterministic roundtripping.
+ MLIRContext *ctx = op->getContext();
+ auto toAttribute = [&](llvm::MDNode *tag) -> Attribute {
+ return LLVM::MMRATagAttr::get(
+ ctx, cast<llvm::MDString>(tag->getOperand(0))->getString(),
+ cast<llvm::MDString>(tag->getOperand(1))->getString());
+ };
+ Attribute mlirMmra;
+ if (llvm::MMRAMetadata::isTagMD(node)) {
+ mlirMmra = toAttribute(node);
+ } else {
+ SmallVector<Attribute> tags;
+ for (const llvm::MDOperand &operand : node->operands()) {
+ auto *tagNode = dyn_cast<llvm::MDNode>(operand.get());
+ if (!tagNode || !llvm::MMRAMetadata::isTagMD(tagNode))
+ return failure();
+ tags.push_back(toAttribute(tagNode));
+ }
+ mlirMmra = ArrayAttr::get(ctx, tags);
+ }
+ op->setAttr(LLVMDialect::getMmraAttrName(), mlirMmra);
+ return success();
+}
+
/// Converts the given loop metadata node to an MLIR loop annotation attribute
/// and attaches it to the imported operation if the translation succeeds.
/// Returns failure otherwise.
@@ -432,7 +468,8 @@ public:
return setDereferenceableAttr(
node, llvm::LLVMContext::MD_dereferenceable_or_null, op,
moduleImport);
-
+ if (kind == llvm::LLVMContext::MD_mmra)
+ return setMmraAttr(node, op, moduleImport);
llvm::LLVMContext &context = node->getContext();
if (kind == context.getMDKindID(vecTypeHintMDName))
return setVecTypeHintAttr(builder, node, op, moduleImport);
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index fd8463a..eaf1d20 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -24,6 +24,8 @@
#include "llvm/IR/Instructions.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/MatrixBuilder.h"
+#include "llvm/IR/MemoryModelRelaxationAnnotations.h"
+#include "llvm/Support/LogicalResult.h"
using namespace mlir;
using namespace mlir::LLVM;
@@ -723,6 +725,40 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
return failure();
}
+static LogicalResult
+amendOperationImpl(Operation &op, ArrayRef<llvm::Instruction *> instructions,
+ NamedAttribute attribute,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ StringRef name = attribute.getName();
+ if (name == LLVMDialect::getMmraAttrName()) {
+ SmallVector<llvm::MMRAMetadata::TagT> tags;
+ if (auto oneTag = dyn_cast<LLVM::MMRATagAttr>(attribute.getValue())) {
+ tags.emplace_back(oneTag.getPrefix(), oneTag.getSuffix());
+ } else if (auto manyTags = dyn_cast<ArrayAttr>(attribute.getValue())) {
+ for (Attribute attr : manyTags) {
+ auto tag = dyn_cast<MMRATagAttr>(attr);
+ if (!tag)
+ return op.emitOpError(
+ "MMRA annotations array contains value that isn't an MMRA tag");
+ tags.emplace_back(tag.getPrefix(), tag.getSuffix());
+ }
+ } else {
+ return op.emitOpError(
+ "llvm.mmra is something other than an MMRA tag or an array of them");
+ }
+ llvm::MDTuple *mmraMd =
+ llvm::MMRAMetadata::getMD(moduleTranslation.getLLVMContext(), tags);
+ if (!mmraMd) {
+ // Empty list, canonicalizes to nothing
+ return success();
+ }
+ for (llvm::Instruction *inst : instructions)
+ inst->setMetadata(llvm::LLVMContext::MD_mmra, mmraMd);
+ return success();
+ }
+ return success();
+}
+
namespace {
/// Implementation of the dialect interface that converts operations belonging
/// to the LLVM dialect to LLVM IR.
@@ -738,6 +774,14 @@ public:
LLVM::ModuleTranslation &moduleTranslation) const final {
return convertOperationImpl(*op, builder, moduleTranslation);
}
+
+ /// Handle some metadata that is represented as a discardable attribute.
+ LogicalResult
+ amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
+ NamedAttribute attribute,
+ LLVM::ModuleTranslation &moduleTranslation) const final {
+ return amendOperationImpl(*op, instructions, attribute, moduleTranslation);
+ }
};
} // namespace
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 4e26e65..5e194dc 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -33,6 +33,7 @@
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/ReplaceConstant.h"
#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/VirtualFileSystem.h"
#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
@@ -3041,16 +3042,46 @@ convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
loopInfos.push_back(*loopResult);
}
- // Collapse loops. Store the insertion point because LoopInfos may get
- // invalidated.
llvm::OpenMPIRBuilder::InsertPointTy afterIP =
loopInfos.front()->getAfterIP();
- // Update the stack frame created for this loop to point to the resulting loop
- // after applying transformations.
+ // Do tiling.
+ if (const auto &tiles = loopOp.getTileSizes()) {
+ llvm::Type *ivType = loopInfos.front()->getIndVarType();
+ SmallVector<llvm::Value *> tileSizes;
+
+ for (auto tile : tiles.value()) {
+ llvm::Value *tileVal = llvm::ConstantInt::get(ivType, tile);
+ tileSizes.push_back(tileVal);
+ }
+
+ std::vector<llvm::CanonicalLoopInfo *> newLoops =
+ ompBuilder->tileLoops(ompLoc.DL, loopInfos, tileSizes);
+
+ // Update afterIP to get the correct insertion point after
+ // tiling.
+ llvm::BasicBlock *afterBB = newLoops.front()->getAfter();
+ llvm::BasicBlock *afterAfterBB = afterBB->getSingleSuccessor();
+ afterIP = {afterAfterBB, afterAfterBB->begin()};
+
+ // Update the loop infos.
+ loopInfos.clear();
+ for (const auto &newLoop : newLoops)
+ loopInfos.push_back(newLoop);
+ } // Tiling done.
+
+ // Do collapse.
+ const auto &numCollapse = loopOp.getCollapseNumLoops();
+ SmallVector<llvm::CanonicalLoopInfo *> collapseLoopInfos(
+ loopInfos.begin(), loopInfos.begin() + (numCollapse));
+
+ auto newTopLoopInfo =
+ ompBuilder->collapseLoops(ompLoc.DL, collapseLoopInfos, {});
+
+ assert(newTopLoopInfo && "New top loop information is missing");
moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
[&](OpenMPLoopInfoStackFrame &frame) {
- frame.loopInfo = ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});
+ frame.loopInfo = newTopLoopInfo;
return WalkResult::interrupt();
});
@@ -6304,7 +6335,9 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
if (auto filepathAttr = dyn_cast<StringAttr>(attr)) {
llvm::OpenMPIRBuilder *ompBuilder =
moduleTranslation.getOpenMPBuilder();
- ompBuilder->loadOffloadInfoMetadata(filepathAttr.getValue());
+ auto VFS = llvm::vfs::getRealFileSystem();
+ ompBuilder->loadOffloadInfoMetadata(*VFS,
+ filepathAttr.getValue());
return success();
}
return failure();
diff --git a/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp
index 73b166d..7e9318a 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp
@@ -55,10 +55,6 @@ public:
return handleDecorationCacheControl(instructions.front(),
cacheControlsArray.getValue());
}
- auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
- if (!func)
- return failure();
-
return success();
}
diff --git a/mlir/lib/Tools/lsp-server-support/CMakeLists.txt b/mlir/lib/Tools/lsp-server-support/CMakeLists.txt
index 48a9601..2fe29f1 100644
--- a/mlir/lib/Tools/lsp-server-support/CMakeLists.txt
+++ b/mlir/lib/Tools/lsp-server-support/CMakeLists.txt
@@ -1,13 +1,13 @@
add_mlir_library(MLIRLspServerSupportLib
CompilationDatabase.cpp
- Logging.cpp
- Protocol.cpp
SourceMgrUtils.cpp
- Transport.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Tools/lsp-server-support
+ LINK_COMPONENTS
+ SupportLSP
+
LINK_LIBS PUBLIC
MLIRSupport
- )
+)
diff --git a/mlir/lib/Tools/lsp-server-support/CompilationDatabase.cpp b/mlir/lib/Tools/lsp-server-support/CompilationDatabase.cpp
index 9ae0674..67b8ef6 100644
--- a/mlir/lib/Tools/lsp-server-support/CompilationDatabase.cpp
+++ b/mlir/lib/Tools/lsp-server-support/CompilationDatabase.cpp
@@ -8,14 +8,15 @@
#include "mlir/Tools/lsp-server-support/CompilationDatabase.h"
#include "mlir/Support/FileUtilities.h"
-#include "mlir/Tools/lsp-server-support/Logging.h"
-#include "mlir/Tools/lsp-server-support/Protocol.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/LSP/Logging.h"
+#include "llvm/Support/LSP/Protocol.h"
#include "llvm/Support/YAMLTraits.h"
using namespace mlir;
using namespace mlir::lsp;
+using llvm::lsp::Logger;
//===----------------------------------------------------------------------===//
// YamlFileInfo
diff --git a/mlir/lib/Tools/lsp-server-support/Logging.cpp b/mlir/lib/Tools/lsp-server-support/Logging.cpp
deleted file mode 100644
index 373e216..0000000
--- a/mlir/lib/Tools/lsp-server-support/Logging.cpp
+++ /dev/null
@@ -1,41 +0,0 @@
-//===- Logging.cpp --------------------------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Tools/lsp-server-support/Logging.h"
-#include "llvm/Support/Chrono.h"
-#include "llvm/Support/raw_ostream.h"
-
-using namespace mlir;
-using namespace mlir::lsp;
-
-void Logger::setLogLevel(Level logLevel) { get().logLevel = logLevel; }
-
-Logger &Logger::get() {
- static Logger logger;
- return logger;
-}
-
-void Logger::log(Level logLevel, const char *fmt,
- const llvm::formatv_object_base &message) {
- Logger &logger = get();
-
- // Ignore messages with log levels below the current setting in the logger.
- if (logLevel < logger.logLevel)
- return;
-
- // An indicator character for each log level.
- const char *logLevelIndicators = "DIE";
-
- // Format the message and print to errs.
- llvm::sys::TimePoint<> timestamp = std::chrono::system_clock::now();
- std::lock_guard<std::mutex> logGuard(logger.mutex);
- llvm::errs() << llvm::formatv(
- "{0}[{1:%H:%M:%S.%L}] {2}\n",
- logLevelIndicators[static_cast<unsigned>(logLevel)], timestamp, message);
- llvm::errs().flush();
-}
diff --git a/mlir/lib/Tools/lsp-server-support/Protocol.cpp b/mlir/lib/Tools/lsp-server-support/Protocol.cpp
deleted file mode 100644
index 9828704..0000000
--- a/mlir/lib/Tools/lsp-server-support/Protocol.cpp
+++ /dev/null
@@ -1,1043 +0,0 @@
-//===--- Protocol.cpp - Language Server Protocol Implementation -----------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file contains the serialization code for the LSP structs.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Tools/lsp-server-support/Protocol.h"
-#include "llvm/ADT/StringExtras.h"
-#include "llvm/ADT/StringSet.h"
-#include "llvm/Support/ErrorHandling.h"
-#include "llvm/Support/JSON.h"
-#include "llvm/Support/MemoryBuffer.h"
-#include "llvm/Support/Path.h"
-#include "llvm/Support/raw_ostream.h"
-
-using namespace mlir;
-using namespace mlir::lsp;
-
-// Helper that doesn't treat `null` and absent fields as failures.
-template <typename T>
-static bool mapOptOrNull(const llvm::json::Value &params,
- llvm::StringLiteral prop, T &out,
- llvm::json::Path path) {
- const llvm::json::Object *o = params.getAsObject();
- assert(o);
-
- // Field is missing or null.
- auto *v = o->get(prop);
- if (!v || v->getAsNull())
- return true;
- return fromJSON(*v, out, path.field(prop));
-}
-
-//===----------------------------------------------------------------------===//
-// LSPError
-//===----------------------------------------------------------------------===//
-
-char LSPError::ID;
-
-//===----------------------------------------------------------------------===//
-// URIForFile
-//===----------------------------------------------------------------------===//
-
-static bool isWindowsPath(StringRef path) {
- return path.size() > 1 && llvm::isAlpha(path[0]) && path[1] == ':';
-}
-
-static bool isNetworkPath(StringRef path) {
- return path.size() > 2 && path[0] == path[1] &&
- llvm::sys::path::is_separator(path[0]);
-}
-
-static bool shouldEscapeInURI(unsigned char c) {
- // Unreserved characters.
- if ((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') ||
- (c >= '0' && c <= '9'))
- return false;
-
- switch (c) {
- case '-':
- case '_':
- case '.':
- case '~':
- // '/' is only reserved when parsing.
- case '/':
- // ':' is only reserved for relative URI paths, which we doesn't produce.
- case ':':
- return false;
- }
- return true;
-}
-
-/// Encodes a string according to percent-encoding.
-/// - Unreserved characters are not escaped.
-/// - Reserved characters always escaped with exceptions like '/'.
-/// - All other characters are escaped.
-static void percentEncode(StringRef content, std::string &out) {
- for (unsigned char c : content) {
- if (shouldEscapeInURI(c)) {
- out.push_back('%');
- out.push_back(llvm::hexdigit(c / 16));
- out.push_back(llvm::hexdigit(c % 16));
- } else {
- out.push_back(c);
- }
- }
-}
-
-/// Decodes a string according to percent-encoding.
-static std::string percentDecode(StringRef content) {
- std::string result;
- for (auto i = content.begin(), e = content.end(); i != e; ++i) {
- if (*i != '%') {
- result += *i;
- continue;
- }
- if (*i == '%' && i + 2 < content.end() && llvm::isHexDigit(*(i + 1)) &&
- llvm::isHexDigit(*(i + 2))) {
- result.push_back(llvm::hexFromNibbles(*(i + 1), *(i + 2)));
- i += 2;
- } else {
- result.push_back(*i);
- }
- }
- return result;
-}
-
-/// Return the set containing the supported URI schemes.
-static StringSet<> &getSupportedSchemes() {
- static StringSet<> schemes({"file", "test"});
- return schemes;
-}
-
-/// Returns true if the given scheme is structurally valid, i.e. it does not
-/// contain any invalid scheme characters. This does not check that the scheme
-/// is actually supported.
-static bool isStructurallyValidScheme(StringRef scheme) {
- if (scheme.empty())
- return false;
- if (!llvm::isAlpha(scheme[0]))
- return false;
- return llvm::all_of(llvm::drop_begin(scheme), [](char c) {
- return llvm::isAlnum(c) || c == '+' || c == '.' || c == '-';
- });
-}
-
-static llvm::Expected<std::string> uriFromAbsolutePath(StringRef absolutePath,
- StringRef scheme) {
- std::string body;
- StringRef authority;
- StringRef root = llvm::sys::path::root_name(absolutePath);
- if (isNetworkPath(root)) {
- // Windows UNC paths e.g. \\server\share => file://server/share
- authority = root.drop_front(2);
- absolutePath.consume_front(root);
- } else if (isWindowsPath(root)) {
- // Windows paths e.g. X:\path => file:///X:/path
- body = "/";
- }
- body += llvm::sys::path::convert_to_slash(absolutePath);
-
- std::string uri = scheme.str() + ":";
- if (authority.empty() && body.empty())
- return uri;
-
- // If authority if empty, we only print body if it starts with "/"; otherwise,
- // the URI is invalid.
- if (!authority.empty() || StringRef(body).starts_with("/")) {
- uri.append("//");
- percentEncode(authority, uri);
- }
- percentEncode(body, uri);
- return uri;
-}
-
-static llvm::Expected<std::string> getAbsolutePath(StringRef authority,
- StringRef body) {
- if (!body.starts_with("/"))
- return llvm::createStringError(
- llvm::inconvertibleErrorCode(),
- "File scheme: expect body to be an absolute path starting "
- "with '/': " +
- body);
- SmallString<128> path;
- if (!authority.empty()) {
- // Windows UNC paths e.g. file://server/share => \\server\share
- ("//" + authority).toVector(path);
- } else if (isWindowsPath(body.substr(1))) {
- // Windows paths e.g. file:///X:/path => X:\path
- body.consume_front("/");
- }
- path.append(body);
- llvm::sys::path::native(path);
- return std::string(path);
-}
-
-static llvm::Expected<std::string> parseFilePathFromURI(StringRef origUri) {
- StringRef uri = origUri;
-
- // Decode the scheme of the URI.
- size_t pos = uri.find(':');
- if (pos == StringRef::npos)
- return llvm::createStringError(llvm::inconvertibleErrorCode(),
- "Scheme must be provided in URI: " +
- origUri);
- StringRef schemeStr = uri.substr(0, pos);
- std::string uriScheme = percentDecode(schemeStr);
- if (!isStructurallyValidScheme(uriScheme))
- return llvm::createStringError(llvm::inconvertibleErrorCode(),
- "Invalid scheme: " + schemeStr +
- " (decoded: " + uriScheme + ")");
- uri = uri.substr(pos + 1);
-
- // Decode the authority of the URI.
- std::string uriAuthority;
- if (uri.consume_front("//")) {
- pos = uri.find('/');
- uriAuthority = percentDecode(uri.substr(0, pos));
- uri = uri.substr(pos);
- }
-
- // Decode the body of the URI.
- std::string uriBody = percentDecode(uri);
-
- // Compute the absolute path for this uri.
- if (!getSupportedSchemes().contains(uriScheme)) {
- return llvm::createStringError(llvm::inconvertibleErrorCode(),
- "unsupported URI scheme `" + uriScheme +
- "' for workspace files");
- }
- return getAbsolutePath(uriAuthority, uriBody);
-}
-
-llvm::Expected<URIForFile> URIForFile::fromURI(StringRef uri) {
- llvm::Expected<std::string> filePath = parseFilePathFromURI(uri);
- if (!filePath)
- return filePath.takeError();
- return URIForFile(std::move(*filePath), uri.str());
-}
-
-llvm::Expected<URIForFile> URIForFile::fromFile(StringRef absoluteFilepath,
- StringRef scheme) {
- llvm::Expected<std::string> uri =
- uriFromAbsolutePath(absoluteFilepath, scheme);
- if (!uri)
- return uri.takeError();
- return fromURI(*uri);
-}
-
-StringRef URIForFile::scheme() const { return uri().split(':').first; }
-
-void URIForFile::registerSupportedScheme(StringRef scheme) {
- getSupportedSchemes().insert(scheme);
-}
-
-bool mlir::lsp::fromJSON(const llvm::json::Value &value, URIForFile &result,
- llvm::json::Path path) {
- if (std::optional<StringRef> str = value.getAsString()) {
- llvm::Expected<URIForFile> expectedURI = URIForFile::fromURI(*str);
- if (!expectedURI) {
- path.report("unresolvable URI");
- consumeError(expectedURI.takeError());
- return false;
- }
- result = std::move(*expectedURI);
- return true;
- }
- return false;
-}
-
-llvm::json::Value mlir::lsp::toJSON(const URIForFile &value) {
- return value.uri();
-}
-
-raw_ostream &mlir::lsp::operator<<(raw_ostream &os, const URIForFile &value) {
- return os << value.uri();
-}
-
-//===----------------------------------------------------------------------===//
-// ClientCapabilities
-//===----------------------------------------------------------------------===//
-
-bool mlir::lsp::fromJSON(const llvm::json::Value &value,
- ClientCapabilities &result, llvm::json::Path path) {
- const llvm::json::Object *o = value.getAsObject();
- if (!o) {
- path.report("expected object");
- return false;
- }
- if (const llvm::json::Object *textDocument = o->getObject("textDocument")) {
- if (const llvm::json::Object *documentSymbol =
- textDocument->getObject("documentSymbol")) {
- if (std::optional<bool> hierarchicalSupport =
- documentSymbol->getBoolean("hierarchicalDocumentSymbolSupport"))
- result.hierarchicalDocumentSymbol = *hierarchicalSupport;
- }
- if (auto *codeAction = textDocument->getObject("codeAction")) {
- if (codeAction->getObject("codeActionLiteralSupport"))
- result.codeActionStructure = true;
- }
- }
- if (auto *window = o->getObject("window")) {
- if (std::optional<bool> workDoneProgressSupport =
- window->getBoolean("workDoneProgress"))
- result.workDoneProgress = *workDoneProgressSupport;
- }
- return true;
-}
-
-//===----------------------------------------------------------------------===//
-// ClientInfo
-//===----------------------------------------------------------------------===//
-
-bool mlir::lsp::fromJSON(const llvm::json::Value &value, ClientInfo &result,
- llvm::json::Path path) {
- llvm::json::ObjectMapper o(value, path);
- if (!o || !o.map("name", result.name))
- return false;
-
- // Don't fail if we can't parse version.
- o.map("version", result.version);
- return true;
-}
-
-//===----------------------------------------------------------------------===//
-// InitializeParams
-//===----------------------------------------------------------------------===//
-
-bool mlir::lsp::fromJSON(const llvm::json::Value &value, TraceLevel &result,
- llvm::json::Path path) {
- if (std::optional<StringRef> str = value.getAsString()) {
- if (*str == "off") {
- result = TraceLevel::Off;
- return true;
- }
- if (*str == "messages") {
- result = TraceLevel::Messages;
- return true;
- }
- if (*str == "verbose") {
- result = TraceLevel::Verbose;
- return true;
- }
- }
- return false;
-}
-
-bool mlir::lsp::fromJSON(const llvm::json::Value &value,
- InitializeParams &result, llvm::json::Path path) {
- llvm::json::ObjectMapper o(value, path);
- if (!o)
- return false;
- // We deliberately don't fail if we can't parse individual fields.
- o.map("capabilities", result.capabilities);
- o.map("trace", result.trace);
- mapOptOrNull(value, "clientInfo", result.clientInfo, path);
-
- return true;
-}
-
-//===----------------------------------------------------------------------===//
-// TextDocumentItem
-//===----------------------------------------------------------------------===//
-
-bool mlir::lsp::fromJSON(const llvm::json::Value &value,
- TextDocumentItem &result, llvm::json::Path path) {
- llvm::json::ObjectMapper o(value, path);
- return o && o.map("uri", result.uri) &&
- o.map("languageId", result.languageId) && o.map("text", result.text) &&
- o.map("version", result.version);
-}
-
-//===----------------------------------------------------------------------===//
-// TextDocumentIdentifier
-//===----------------------------------------------------------------------===//
-
-llvm::json::Value mlir::lsp::toJSON(const TextDocumentIdentifier &value) {
- return llvm::json::Object{{"uri", value.uri}};
-}
-
-bool mlir::lsp::fromJSON(const llvm::json::Value &value,
- TextDocumentIdentifier &result,
- llvm::json::Path path) {
- llvm::json::ObjectMapper o(value, path);
- return o && o.map("uri", result.uri);
-}
-
-//===----------------------------------------------------------------------===//
-// VersionedTextDocumentIdentifier
-//===----------------------------------------------------------------------===//
-
-llvm::json::Value
-mlir::lsp::toJSON(const VersionedTextDocumentIdentifier &value) {
- return llvm::json::Object{
- {"uri", value.uri},
- {"version", value.version},
- };
-}
-
-bool mlir::lsp::fromJSON(const llvm::json::Value &value,
- VersionedTextDocumentIdentifier &result,
- llvm::json::Path path) {
- llvm::json::ObjectMapper o(value, path);
- return o && o.map("uri", result.uri) && o.map("version", result.version);
-}
-
-//===----------------------------------------------------------------------===//
-// Position
-//===----------------------------------------------------------------------===//
-
-bool mlir::lsp::fromJSON(const llvm::json::Value &value, Position &result,
- llvm::json::Path path) {
- llvm::json::ObjectMapper o(value, path);
- return o && o.map("line", result.line) &&
- o.map("character", result.character);
-}
-
-llvm::json::Value mlir::lsp::toJSON(const Position &value) {
- return llvm::json::Object{
- {"line", value.line},
- {"character", value.character},
- };
-}
-
-raw_ostream &mlir::lsp::operator<<(raw_ostream &os, const Position &value) {
- return os << value.line << ':' << value.character;
-}
-
-//===----------------------------------------------------------------------===//
-// Range
-//===----------------------------------------------------------------------===//
-
-bool mlir::lsp::fromJSON(const llvm::json::Value &value, Range &result,
- llvm::json::Path path) {
- llvm::json::ObjectMapper o(value, path);
- return o && o.map("start", result.start) && o.map("end", result.end);
-}
-
-llvm::json::Value mlir::lsp::toJSON(const Range &value) {
- return llvm::json::Object{
- {"start", value.start},
- {"end", value.end},
- };
-}
-
-raw_ostream &mlir::lsp::operator<<(raw_ostream &os, const Range &value) {
- return os << value.start << '-' << value.end;
-}
-
-//===----------------------------------------------------------------------===//
-// Location
-//===----------------------------------------------------------------------===//
-
-bool mlir::lsp::fromJSON(const llvm::json::Value &value, Location &result,
- llvm::json::Path path) {
- llvm::json::ObjectMapper o(value, path);
- return o && o.map("uri", result.uri) && o.map("range", result.range);
-}
-
-llvm::json::Value mlir::lsp::toJSON(const Location &value) {
- return llvm::json::Object{
- {"uri", value.uri},
- {"range", value.range},
- };
-}
-
-raw_ostream &mlir::lsp::operator<<(raw_ostream &os, const Location &value) {
- return os << value.range << '@' << value.uri;
-}
-
-//===----------------------------------------------------------------------===//
-// TextDocumentPositionParams
-//===----------------------------------------------------------------------===//
-
-bool mlir::lsp::fromJSON(const llvm::json::Value &value,
- TextDocumentPositionParams &result,
- llvm::json::Path path) {
- llvm::json::ObjectMapper o(value, path);
- return o && o.map("textDocument", result.textDocument) &&
- o.map("position", result.position);
-}
-
-//===----------------------------------------------------------------------===//
-// ReferenceParams
-//===----------------------------------------------------------------------===//
-
-bool mlir::lsp::fromJSON(const llvm::json::Value &value,
- ReferenceContext &result, llvm::json::Path path) {
- llvm::json::ObjectMapper o(value, path);
- return o && o.mapOptional("includeDeclaration", result.includeDeclaration);
-}
-
-bool mlir::lsp::fromJSON(const llvm::json::Value &value,
- ReferenceParams &result, llvm::json::Path path) {
- TextDocumentPositionParams &base = result;
- llvm::json::ObjectMapper o(value, path);
- return fromJSON(value, base, path) && o &&
- o.mapOptional("context", result.context);
-}
-
-//===----------------------------------------------------------------------===//
-// DidOpenTextDocumentParams
-//===----------------------------------------------------------------------===//
-
-bool mlir::lsp::fromJSON(const llvm::json::Value &value,
- DidOpenTextDocumentParams &result,
- llvm::json::Path path) {
- llvm::json::ObjectMapper o(value, path);
- return o && o.map("textDocument", result.textDocument);
-}
-
-//===----------------------------------------------------------------------===//
-// DidCloseTextDocumentParams
-//===----------------------------------------------------------------------===//
-
-bool mlir::lsp::fromJSON(const llvm::json::Value &value,
- DidCloseTextDocumentParams &result,
- llvm::json::Path path) {
- llvm::json::ObjectMapper o(value, path);
- return o && o.map("textDocument", result.textDocument);
-}
-
-//===----------------------------------------------------------------------===//
-// DidChangeTextDocumentParams
-//===----------------------------------------------------------------------===//
-
-LogicalResult
-TextDocumentContentChangeEvent::applyTo(std::string &contents) const {
- // If there is no range, the full document changed.
- if (!range) {
- contents = text;
- return success();
- }
-
- // Try to map the replacement range to the content.
- llvm::SourceMgr tmpScrMgr;
- tmpScrMgr.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(contents),
- SMLoc());
- SMRange rangeLoc = range->getAsSMRange(tmpScrMgr);
- if (!rangeLoc.isValid())
- return failure();
-
- contents.replace(rangeLoc.Start.getPointer() - contents.data(),
- rangeLoc.End.getPointer() - rangeLoc.Start.getPointer(),
- text);
- return success();
-}
-
-LogicalResult TextDocumentContentChangeEvent::applyTo(
- ArrayRef<TextDocumentContentChangeEvent> changes, std::string &contents) {
- for (const auto &change : changes)
- if (failed(change.applyTo(contents)))
- return failure();
- return success();
-}
-
-bool mlir::lsp::fromJSON(const llvm::json::Value &value,
- TextDocumentContentChangeEvent &result,
- llvm::json::Path path) {
- llvm::json::ObjectMapper o(value, path);
- return o && o.map("range", result.range) &&
- o.map("rangeLength", result.rangeLength) && o.map("text", result.text);
-}
-
-bool mlir::lsp::fromJSON(const llvm::json::Value &value,
- DidChangeTextDocumentParams &result,
- llvm::json::Path path) {
- llvm::json::ObjectMapper o(value, path);
- return o && o.map("textDocument", result.textDocument) &&
- o.map("contentChanges", result.contentChanges);
-}
-
-//===----------------------------------------------------------------------===//
-// MarkupContent
-//===----------------------------------------------------------------------===//
-
-static llvm::StringRef toTextKind(MarkupKind kind) {
- switch (kind) {
- case MarkupKind::PlainText:
- return "plaintext";
- case MarkupKind::Markdown:
- return "markdown";
- }
- llvm_unreachable("Invalid MarkupKind");
-}
-
-raw_ostream &mlir::lsp::operator<<(raw_ostream &os, MarkupKind kind) {
- return os << toTextKind(kind);
-}
-
-llvm::json::Value mlir::lsp::toJSON(const MarkupContent &mc) {
- if (mc.value.empty())
- return nullptr;
-
- return llvm::json::Object{
- {"kind", toTextKind(mc.kind)},
- {"value", mc.value},
- };
-}
-
-//===----------------------------------------------------------------------===//
-// Hover
-//===----------------------------------------------------------------------===//
-
-llvm::json::Value mlir::lsp::toJSON(const Hover &hover) {
- llvm::json::Object result{{"contents", toJSON(hover.contents)}};
- if (hover.range)
- result["range"] = toJSON(*hover.range);
- return std::move(result);
-}
-
-//===----------------------------------------------------------------------===//
-// DocumentSymbol
-//===----------------------------------------------------------------------===//
-
-llvm::json::Value mlir::lsp::toJSON(const DocumentSymbol &symbol) {
- llvm::json::Object result{{"name", symbol.name},
- {"kind", static_cast<int>(symbol.kind)},
- {"range", symbol.range},
- {"selectionRange", symbol.selectionRange}};
-
- if (!symbol.detail.empty())
- result["detail"] = symbol.detail;
- if (!symbol.children.empty())
- result["children"] = symbol.children;
- return std::move(result);
-}
-
-//===----------------------------------------------------------------------===//
-// DocumentSymbolParams
-//===----------------------------------------------------------------------===//
-
-bool mlir::lsp::fromJSON(const llvm::json::Value &value,
- DocumentSymbolParams &result, llvm::json::Path path) {
- llvm::json::ObjectMapper o(value, path);
- return o && o.map("textDocument", result.textDocument);
-}
-
-//===----------------------------------------------------------------------===//
-// DiagnosticRelatedInformation
-//===----------------------------------------------------------------------===//
-
-bool mlir::lsp::fromJSON(const llvm::json::Value &value,
- DiagnosticRelatedInformation &result,
- llvm::json::Path path) {
- llvm::json::ObjectMapper o(value, path);
- return o && o.map("location", result.location) &&
- o.map("message", result.message);
-}
-
-llvm::json::Value mlir::lsp::toJSON(const DiagnosticRelatedInformation &info) {
- return llvm::json::Object{
- {"location", info.location},
- {"message", info.message},
- };
-}
-
-//===----------------------------------------------------------------------===//
-// Diagnostic
-//===----------------------------------------------------------------------===//
-
-llvm::json::Value mlir::lsp::toJSON(DiagnosticTag tag) {
- return static_cast<int>(tag);
-}
-
-bool mlir::lsp::fromJSON(const llvm::json::Value &value, DiagnosticTag &result,
- llvm::json::Path path) {
- if (std::optional<int64_t> i = value.getAsInteger()) {
- result = (DiagnosticTag)*i;
- return true;
- }
-
- return false;
-}
-
-llvm::json::Value mlir::lsp::toJSON(const Diagnostic &diag) {
- llvm::json::Object result{
- {"range", diag.range},
- {"severity", (int)diag.severity},
- {"message", diag.message},
- };
- if (diag.category)
- result["category"] = *diag.category;
- if (!diag.source.empty())
- result["source"] = diag.source;
- if (diag.relatedInformation)
- result["relatedInformation"] = *diag.relatedInformation;
- if (!diag.tags.empty())
- result["tags"] = diag.tags;
- return std::move(result);
-}
-
-bool mlir::lsp::fromJSON(const llvm::json::Value &value, Diagnostic &result,
- llvm::json::Path path) {
- llvm::json::ObjectMapper o(value, path);
- if (!o)
- return false;
- int severity = 0;
- if (!mapOptOrNull(value, "severity", severity, path))
- return false;
- result.severity = (DiagnosticSeverity)severity;
-
- return o.map("range", result.range) && o.map("message", result.message) &&
- mapOptOrNull(value, "category", result.category, path) &&
- mapOptOrNull(value, "source", result.source, path) &&
- mapOptOrNull(value, "relatedInformation", result.relatedInformation,
- path) &&
- mapOptOrNull(value, "tags", result.tags, path);
-}
-
-//===----------------------------------------------------------------------===//
-// PublishDiagnosticsParams
-//===----------------------------------------------------------------------===//
-
-llvm::json::Value mlir::lsp::toJSON(const PublishDiagnosticsParams &params) {
- return llvm::json::Object{
- {"uri", params.uri},
- {"diagnostics", params.diagnostics},
- {"version", params.version},
- };
-}
-
-//===----------------------------------------------------------------------===//
-// TextEdit
-//===----------------------------------------------------------------------===//
-
-bool mlir::lsp::fromJSON(const llvm::json::Value &value, TextEdit &result,
- llvm::json::Path path) {
- llvm::json::ObjectMapper o(value, path);
- return o && o.map("range", result.range) && o.map("newText", result.newText);
-}
-
-llvm::json::Value mlir::lsp::toJSON(const TextEdit &value) {
- return llvm::json::Object{
- {"range", value.range},
- {"newText", value.newText},
- };
-}
-
-raw_ostream &mlir::lsp::operator<<(raw_ostream &os, const TextEdit &value) {
- os << value.range << " => \"";
- llvm::printEscapedString(value.newText, os);
- return os << '"';
-}
-
-//===----------------------------------------------------------------------===//
-// CompletionItemKind
-//===----------------------------------------------------------------------===//
-
-bool mlir::lsp::fromJSON(const llvm::json::Value &value,
- CompletionItemKind &result, llvm::json::Path path) {
- if (std::optional<int64_t> intValue = value.getAsInteger()) {
- if (*intValue < static_cast<int>(CompletionItemKind::Text) ||
- *intValue > static_cast<int>(CompletionItemKind::TypeParameter))
- return false;
- result = static_cast<CompletionItemKind>(*intValue);
- return true;
- }
- return false;
-}
-
-CompletionItemKind mlir::lsp::adjustKindToCapability(
- CompletionItemKind kind,
- CompletionItemKindBitset &supportedCompletionItemKinds) {
- size_t kindVal = static_cast<size_t>(kind);
- if (kindVal >= kCompletionItemKindMin &&
- kindVal <= supportedCompletionItemKinds.size() &&
- supportedCompletionItemKinds[kindVal])
- return kind;
-
- // Provide some fall backs for common kinds that are close enough.
- switch (kind) {
- case CompletionItemKind::Folder:
- return CompletionItemKind::File;
- case CompletionItemKind::EnumMember:
- return CompletionItemKind::Enum;
- case CompletionItemKind::Struct:
- return CompletionItemKind::Class;
- default:
- return CompletionItemKind::Text;
- }
-}
-
-bool mlir::lsp::fromJSON(const llvm::json::Value &value,
- CompletionItemKindBitset &result,
- llvm::json::Path path) {
- if (const llvm::json::Array *arrayValue = value.getAsArray()) {
- for (size_t i = 0, e = arrayValue->size(); i < e; ++i) {
- CompletionItemKind kindOut;
- if (fromJSON((*arrayValue)[i], kindOut, path.index(i)))
- result.set(size_t(kindOut));
- }
- return true;
- }
- return false;
-}
-
-//===----------------------------------------------------------------------===//
-// CompletionItem
-//===----------------------------------------------------------------------===//
-
-llvm::json::Value mlir::lsp::toJSON(const CompletionItem &value) {
- assert(!value.label.empty() && "completion item label is required");
- llvm::json::Object result{{"label", value.label}};
- if (value.kind != CompletionItemKind::Missing)
- result["kind"] = static_cast<int>(value.kind);
- if (!value.detail.empty())
- result["detail"] = value.detail;
- if (value.documentation)
- result["documentation"] = value.documentation;
- if (!value.sortText.empty())
- result["sortText"] = value.sortText;
- if (!value.filterText.empty())
- result["filterText"] = value.filterText;
- if (!value.insertText.empty())
- result["insertText"] = value.insertText;
- if (value.insertTextFormat != InsertTextFormat::Missing)
- result["insertTextFormat"] = static_cast<int>(value.insertTextFormat);
- if (value.textEdit)
- result["textEdit"] = *value.textEdit;
- if (!value.additionalTextEdits.empty()) {
- result["additionalTextEdits"] =
- llvm::json::Array(value.additionalTextEdits);
- }
- if (value.deprecated)
- result["deprecated"] = value.deprecated;
- return std::move(result);
-}
-
-raw_ostream &mlir::lsp::operator<<(raw_ostream &os,
- const CompletionItem &value) {
- return os << value.label << " - " << toJSON(value);
-}
-
-bool mlir::lsp::operator<(const CompletionItem &lhs,
- const CompletionItem &rhs) {
- return (lhs.sortText.empty() ? lhs.label : lhs.sortText) <
- (rhs.sortText.empty() ? rhs.label : rhs.sortText);
-}
-
-//===----------------------------------------------------------------------===//
-// CompletionList
-//===----------------------------------------------------------------------===//
-
-llvm::json::Value mlir::lsp::toJSON(const CompletionList &value) {
- return llvm::json::Object{
- {"isIncomplete", value.isIncomplete},
- {"items", llvm::json::Array(value.items)},
- };
-}
-
-//===----------------------------------------------------------------------===//
-// CompletionContext
-//===----------------------------------------------------------------------===//
-
-bool mlir::lsp::fromJSON(const llvm::json::Value &value,
- CompletionContext &result, llvm::json::Path path) {
- llvm::json::ObjectMapper o(value, path);
- int triggerKind;
- if (!o || !o.map("triggerKind", triggerKind) ||
- !mapOptOrNull(value, "triggerCharacter", result.triggerCharacter, path))
- return false;
- result.triggerKind = static_cast<CompletionTriggerKind>(triggerKind);
- return true;
-}
-
-//===----------------------------------------------------------------------===//
-// CompletionParams
-//===----------------------------------------------------------------------===//
-
-bool mlir::lsp::fromJSON(const llvm::json::Value &value,
- CompletionParams &result, llvm::json::Path path) {
- if (!fromJSON(value, static_cast<TextDocumentPositionParams &>(result), path))
- return false;
- if (const llvm::json::Value *context = value.getAsObject()->get("context"))
- return fromJSON(*context, result.context, path.field("context"));
- return true;
-}
-
-//===----------------------------------------------------------------------===//
-// ParameterInformation
-//===----------------------------------------------------------------------===//
-
-llvm::json::Value mlir::lsp::toJSON(const ParameterInformation &value) {
- assert((value.labelOffsets || !value.labelString.empty()) &&
- "parameter information label is required");
- llvm::json::Object result;
- if (value.labelOffsets)
- result["label"] = llvm::json::Array(
- {value.labelOffsets->first, value.labelOffsets->second});
- else
- result["label"] = value.labelString;
- if (!value.documentation.empty())
- result["documentation"] = value.documentation;
- return std::move(result);
-}
-
-//===----------------------------------------------------------------------===//
-// SignatureInformation
-//===----------------------------------------------------------------------===//
-
-llvm::json::Value mlir::lsp::toJSON(const SignatureInformation &value) {
- assert(!value.label.empty() && "signature information label is required");
- llvm::json::Object result{
- {"label", value.label},
- {"parameters", llvm::json::Array(value.parameters)},
- };
- if (!value.documentation.empty())
- result["documentation"] = value.documentation;
- return std::move(result);
-}
-
-raw_ostream &mlir::lsp::operator<<(raw_ostream &os,
- const SignatureInformation &value) {
- return os << value.label << " - " << toJSON(value);
-}
-
-//===----------------------------------------------------------------------===//
-// SignatureHelp
-//===----------------------------------------------------------------------===//
-
-llvm::json::Value mlir::lsp::toJSON(const SignatureHelp &value) {
- assert(value.activeSignature >= 0 &&
- "Unexpected negative value for number of active signatures.");
- assert(value.activeParameter >= 0 &&
- "Unexpected negative value for active parameter index");
- return llvm::json::Object{
- {"activeSignature", value.activeSignature},
- {"activeParameter", value.activeParameter},
- {"signatures", llvm::json::Array(value.signatures)},
- };
-}
-
-//===----------------------------------------------------------------------===//
-// DocumentLinkParams
-//===----------------------------------------------------------------------===//
-
-bool mlir::lsp::fromJSON(const llvm::json::Value &value,
- DocumentLinkParams &result, llvm::json::Path path) {
- llvm::json::ObjectMapper o(value, path);
- return o && o.map("textDocument", result.textDocument);
-}
-
-//===----------------------------------------------------------------------===//
-// DocumentLink
-//===----------------------------------------------------------------------===//
-
-llvm::json::Value mlir::lsp::toJSON(const DocumentLink &value) {
- return llvm::json::Object{
- {"range", value.range},
- {"target", value.target},
- };
-}
-
-//===----------------------------------------------------------------------===//
-// InlayHintsParams
-//===----------------------------------------------------------------------===//
-
-bool mlir::lsp::fromJSON(const llvm::json::Value &value,
- InlayHintsParams &result, llvm::json::Path path) {
- llvm::json::ObjectMapper o(value, path);
- return o && o.map("textDocument", result.textDocument) &&
- o.map("range", result.range);
-}
-
-//===----------------------------------------------------------------------===//
-// InlayHint
-//===----------------------------------------------------------------------===//
-
-llvm::json::Value mlir::lsp::toJSON(const InlayHint &value) {
- return llvm::json::Object{{"position", value.position},
- {"kind", (int)value.kind},
- {"label", value.label},
- {"paddingLeft", value.paddingLeft},
- {"paddingRight", value.paddingRight}};
-}
-bool mlir::lsp::operator==(const InlayHint &lhs, const InlayHint &rhs) {
- return std::tie(lhs.position, lhs.kind, lhs.label) ==
- std::tie(rhs.position, rhs.kind, rhs.label);
-}
-bool mlir::lsp::operator<(const InlayHint &lhs, const InlayHint &rhs) {
- return std::tie(lhs.position, lhs.kind, lhs.label) <
- std::tie(rhs.position, rhs.kind, rhs.label);
-}
-
-llvm::raw_ostream &mlir::lsp::operator<<(llvm::raw_ostream &os,
- InlayHintKind value) {
- switch (value) {
- case InlayHintKind::Parameter:
- return os << "parameter";
- case InlayHintKind::Type:
- return os << "type";
- }
- llvm_unreachable("Unknown InlayHintKind");
-}
-
-//===----------------------------------------------------------------------===//
-// CodeActionContext
-//===----------------------------------------------------------------------===//
-
-bool mlir::lsp::fromJSON(const llvm::json::Value &value,
- CodeActionContext &result, llvm::json::Path path) {
- llvm::json::ObjectMapper o(value, path);
- if (!o || !o.map("diagnostics", result.diagnostics))
- return false;
- o.map("only", result.only);
- return true;
-}
-
-//===----------------------------------------------------------------------===//
-// CodeActionParams
-//===----------------------------------------------------------------------===//
-
-bool mlir::lsp::fromJSON(const llvm::json::Value &value,
- CodeActionParams &result, llvm::json::Path path) {
- llvm::json::ObjectMapper o(value, path);
- return o && o.map("textDocument", result.textDocument) &&
- o.map("range", result.range) && o.map("context", result.context);
-}
-
-//===----------------------------------------------------------------------===//
-// WorkspaceEdit
-//===----------------------------------------------------------------------===//
-
-bool mlir::lsp::fromJSON(const llvm::json::Value &value, WorkspaceEdit &result,
- llvm::json::Path path) {
- llvm::json::ObjectMapper o(value, path);
- return o && o.map("changes", result.changes);
-}
-
-llvm::json::Value mlir::lsp::toJSON(const WorkspaceEdit &value) {
- llvm::json::Object fileChanges;
- for (auto &change : value.changes)
- fileChanges[change.first] = llvm::json::Array(change.second);
- return llvm::json::Object{{"changes", std::move(fileChanges)}};
-}
-
-//===----------------------------------------------------------------------===//
-// CodeAction
-//===----------------------------------------------------------------------===//
-
-const llvm::StringLiteral CodeAction::kQuickFix = "quickfix";
-const llvm::StringLiteral CodeAction::kRefactor = "refactor";
-const llvm::StringLiteral CodeAction::kInfo = "info";
-
-llvm::json::Value mlir::lsp::toJSON(const CodeAction &value) {
- llvm::json::Object codeAction{{"title", value.title}};
- if (value.kind)
- codeAction["kind"] = *value.kind;
- if (value.diagnostics)
- codeAction["diagnostics"] = llvm::json::Array(*value.diagnostics);
- if (value.isPreferred)
- codeAction["isPreferred"] = true;
- if (value.edit)
- codeAction["edit"] = *value.edit;
- return std::move(codeAction);
-}
diff --git a/mlir/lib/Tools/lsp-server-support/SourceMgrUtils.cpp b/mlir/lib/Tools/lsp-server-support/SourceMgrUtils.cpp
index f1a3623..5cd1c85 100644
--- a/mlir/lib/Tools/lsp-server-support/SourceMgrUtils.cpp
+++ b/mlir/lib/Tools/lsp-server-support/SourceMgrUtils.cpp
@@ -14,6 +14,10 @@
using namespace mlir;
using namespace mlir::lsp;
+using llvm::lsp::Hover;
+using llvm::lsp::Range;
+using llvm::lsp::URIForFile;
+
//===----------------------------------------------------------------------===//
// Utils
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Tools/lsp-server-support/Transport.cpp b/mlir/lib/Tools/lsp-server-support/Transport.cpp
deleted file mode 100644
index 5a098b2..0000000
--- a/mlir/lib/Tools/lsp-server-support/Transport.cpp
+++ /dev/null
@@ -1,369 +0,0 @@
-//===--- JSONTransport.cpp - sending and receiving LSP messages over JSON -===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Tools/lsp-server-support/Transport.h"
-#include "mlir/Support/ToolUtilities.h"
-#include "mlir/Tools/lsp-server-support/Logging.h"
-#include "mlir/Tools/lsp-server-support/Protocol.h"
-#include "llvm/ADT/SmallString.h"
-#include "llvm/Support/Error.h"
-#include <optional>
-#include <system_error>
-#include <utility>
-
-using namespace mlir;
-using namespace mlir::lsp;
-
-//===----------------------------------------------------------------------===//
-// Reply
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// Function object to reply to an LSP call.
-/// Each instance must be called exactly once, otherwise:
-/// - if there was no reply, an error reply is sent
-/// - if there were multiple replies, only the first is sent
-class Reply {
-public:
- Reply(const llvm::json::Value &id, StringRef method, JSONTransport &transport,
- std::mutex &transportOutputMutex);
- Reply(Reply &&other);
- Reply &operator=(Reply &&) = delete;
- Reply(const Reply &) = delete;
- Reply &operator=(const Reply &) = delete;
-
- void operator()(llvm::Expected<llvm::json::Value> reply);
-
-private:
- std::string method;
- std::atomic<bool> replied = {false};
- llvm::json::Value id;
- JSONTransport *transport;
- std::mutex &transportOutputMutex;
-};
-} // namespace
-
-Reply::Reply(const llvm::json::Value &id, llvm::StringRef method,
- JSONTransport &transport, std::mutex &transportOutputMutex)
- : method(method), id(id), transport(&transport),
- transportOutputMutex(transportOutputMutex) {}
-
-Reply::Reply(Reply &&other)
- : method(other.method), replied(other.replied.load()),
- id(std::move(other.id)), transport(other.transport),
- transportOutputMutex(other.transportOutputMutex) {
- other.transport = nullptr;
-}
-
-void Reply::operator()(llvm::Expected<llvm::json::Value> reply) {
- if (replied.exchange(true)) {
- Logger::error("Replied twice to message {0}({1})", method, id);
- assert(false && "must reply to each call only once!");
- return;
- }
- assert(transport && "expected valid transport to reply to");
-
- std::lock_guard<std::mutex> transportLock(transportOutputMutex);
- if (reply) {
- Logger::info("--> reply:{0}({1})", method, id);
- transport->reply(std::move(id), std::move(reply));
- } else {
- llvm::Error error = reply.takeError();
- Logger::info("--> reply:{0}({1}): {2}", method, id, error);
- transport->reply(std::move(id), std::move(error));
- }
-}
-
-//===----------------------------------------------------------------------===//
-// MessageHandler
-//===----------------------------------------------------------------------===//
-
-bool MessageHandler::onNotify(llvm::StringRef method, llvm::json::Value value) {
- Logger::info("--> {0}", method);
-
- if (method == "exit")
- return false;
- if (method == "$cancel") {
- // TODO: Add support for cancelling requests.
- } else {
- auto it = notificationHandlers.find(method);
- if (it != notificationHandlers.end())
- it->second(std::move(value));
- }
- return true;
-}
-
-bool MessageHandler::onCall(llvm::StringRef method, llvm::json::Value params,
- llvm::json::Value id) {
- Logger::info("--> {0}({1})", method, id);
-
- Reply reply(id, method, transport, transportOutputMutex);
-
- auto it = methodHandlers.find(method);
- if (it != methodHandlers.end()) {
- it->second(std::move(params), std::move(reply));
- } else {
- reply(llvm::make_error<LSPError>("method not found: " + method.str(),
- ErrorCode::MethodNotFound));
- }
- return true;
-}
-
-bool MessageHandler::onReply(llvm::json::Value id,
- llvm::Expected<llvm::json::Value> result) {
- // Find the response handler in the mapping. If it exists, move it out of the
- // mapping and erase it.
- ResponseHandlerTy responseHandler;
- {
- std::lock_guard<std::mutex> responseHandlersLock(responseHandlersMutex);
- auto it = responseHandlers.find(debugString(id));
- if (it != responseHandlers.end()) {
- responseHandler = std::move(it->second);
- responseHandlers.erase(it);
- }
- }
-
- // If we found a response handler, invoke it. Otherwise, log an error.
- if (responseHandler.second) {
- Logger::info("--> reply:{0}({1})", responseHandler.first, id);
- responseHandler.second(std::move(id), std::move(result));
- } else {
- Logger::error(
- "received a reply with ID {0}, but there was no such outgoing request",
- id);
- if (!result)
- llvm::consumeError(result.takeError());
- }
- return true;
-}
-
-//===----------------------------------------------------------------------===//
-// JSONTransport
-//===----------------------------------------------------------------------===//
-
-/// Encode the given error as a JSON object.
-static llvm::json::Object encodeError(llvm::Error error) {
- std::string message;
- ErrorCode code = ErrorCode::UnknownErrorCode;
- auto handlerFn = [&](const LSPError &lspError) -> llvm::Error {
- message = lspError.message;
- code = lspError.code;
- return llvm::Error::success();
- };
- if (llvm::Error unhandled = llvm::handleErrors(std::move(error), handlerFn))
- message = llvm::toString(std::move(unhandled));
-
- return llvm::json::Object{
- {"message", std::move(message)},
- {"code", int64_t(code)},
- };
-}
-
-/// Decode the given JSON object into an error.
-llvm::Error decodeError(const llvm::json::Object &o) {
- StringRef msg = o.getString("message").value_or("Unspecified error");
- if (std::optional<int64_t> code = o.getInteger("code"))
- return llvm::make_error<LSPError>(msg.str(), ErrorCode(*code));
- return llvm::make_error<llvm::StringError>(llvm::inconvertibleErrorCode(),
- msg.str());
-}
-
-void JSONTransport::notify(StringRef method, llvm::json::Value params) {
- sendMessage(llvm::json::Object{
- {"jsonrpc", "2.0"},
- {"method", method},
- {"params", std::move(params)},
- });
-}
-void JSONTransport::call(StringRef method, llvm::json::Value params,
- llvm::json::Value id) {
- sendMessage(llvm::json::Object{
- {"jsonrpc", "2.0"},
- {"id", std::move(id)},
- {"method", method},
- {"params", std::move(params)},
- });
-}
-void JSONTransport::reply(llvm::json::Value id,
- llvm::Expected<llvm::json::Value> result) {
- if (result) {
- return sendMessage(llvm::json::Object{
- {"jsonrpc", "2.0"},
- {"id", std::move(id)},
- {"result", std::move(*result)},
- });
- }
-
- sendMessage(llvm::json::Object{
- {"jsonrpc", "2.0"},
- {"id", std::move(id)},
- {"error", encodeError(result.takeError())},
- });
-}
-
-llvm::Error JSONTransport::run(MessageHandler &handler) {
- std::string json;
- while (!in->isEndOfInput()) {
- if (in->hasError()) {
- return llvm::errorCodeToError(
- std::error_code(errno, std::system_category()));
- }
-
- if (succeeded(in->readMessage(json))) {
- if (llvm::Expected<llvm::json::Value> doc = llvm::json::parse(json)) {
- if (!handleMessage(std::move(*doc), handler))
- return llvm::Error::success();
- } else {
- Logger::error("JSON parse error: {0}", llvm::toString(doc.takeError()));
- }
- }
- }
- return llvm::errorCodeToError(std::make_error_code(std::errc::io_error));
-}
-
-void JSONTransport::sendMessage(llvm::json::Value msg) {
- outputBuffer.clear();
- llvm::raw_svector_ostream os(outputBuffer);
- os << llvm::formatv(prettyOutput ? "{0:2}\n" : "{0}", msg);
- out << "Content-Length: " << outputBuffer.size() << "\r\n\r\n"
- << outputBuffer;
- out.flush();
- Logger::debug(">>> {0}\n", outputBuffer);
-}
-
-bool JSONTransport::handleMessage(llvm::json::Value msg,
- MessageHandler &handler) {
- // Message must be an object with "jsonrpc":"2.0".
- llvm::json::Object *object = msg.getAsObject();
- if (!object ||
- object->getString("jsonrpc") != std::optional<StringRef>("2.0"))
- return false;
-
- // `id` may be any JSON value. If absent, this is a notification.
- std::optional<llvm::json::Value> id;
- if (llvm::json::Value *i = object->get("id"))
- id = std::move(*i);
- std::optional<StringRef> method = object->getString("method");
-
- // This is a response.
- if (!method) {
- if (!id)
- return false;
- if (auto *err = object->getObject("error"))
- return handler.onReply(std::move(*id), decodeError(*err));
- // result should be given, use null if not.
- llvm::json::Value result = nullptr;
- if (llvm::json::Value *r = object->get("result"))
- result = std::move(*r);
- return handler.onReply(std::move(*id), std::move(result));
- }
-
- // Params should be given, use null if not.
- llvm::json::Value params = nullptr;
- if (llvm::json::Value *p = object->get("params"))
- params = std::move(*p);
-
- if (id)
- return handler.onCall(*method, std::move(params), std::move(*id));
- return handler.onNotify(*method, std::move(params));
-}
-
-/// Tries to read a line up to and including \n.
-/// If failing, feof(), ferror(), or shutdownRequested() will be set.
-LogicalResult readLine(std::FILE *in, SmallVectorImpl<char> &out) {
- // Big enough to hold any reasonable header line. May not fit content lines
- // in delimited mode, but performance doesn't matter for that mode.
- static constexpr int bufSize = 128;
- size_t size = 0;
- out.clear();
- for (;;) {
- out.resize_for_overwrite(size + bufSize);
- if (!std::fgets(&out[size], bufSize, in))
- return failure();
-
- clearerr(in);
-
- // If the line contained null bytes, anything after it (including \n) will
- // be ignored. Fortunately this is not a legal header or JSON.
- size_t read = std::strlen(&out[size]);
- if (read > 0 && out[size + read - 1] == '\n') {
- out.resize(size + read);
- return success();
- }
- size += read;
- }
-}
-
-// Returns std::nullopt when:
-// - ferror(), feof(), or shutdownRequested() are set.
-// - Content-Length is missing or empty (protocol error)
-LogicalResult
-JSONTransportInputOverFile::readStandardMessage(std::string &json) {
- // A Language Server Protocol message starts with a set of HTTP headers,
- // delimited by \r\n, and terminated by an empty line (\r\n).
- unsigned long long contentLength = 0;
- llvm::SmallString<128> line;
- while (true) {
- if (feof(in) || hasError() || failed(readLine(in, line)))
- return failure();
-
- // Content-Length is a mandatory header, and the only one we handle.
- StringRef lineRef = line;
- if (lineRef.consume_front("Content-Length: ")) {
- llvm::getAsUnsignedInteger(lineRef.trim(), 0, contentLength);
- } else if (!lineRef.trim().empty()) {
- // It's another header, ignore it.
- continue;
- } else {
- // An empty line indicates the end of headers. Go ahead and read the JSON.
- break;
- }
- }
-
- // The fuzzer likes crashing us by sending "Content-Length: 9999999999999999"
- if (contentLength == 0 || contentLength > 1 << 30)
- return failure();
-
- json.resize(contentLength);
- for (size_t pos = 0, read; pos < contentLength; pos += read) {
- read = std::fread(&json[pos], 1, contentLength - pos, in);
- if (read == 0)
- return failure();
-
- // If we're done, the error was transient. If we're not done, either it was
- // transient or we'll see it again on retry.
- clearerr(in);
- pos += read;
- }
- return success();
-}
-
-/// For lit tests we support a simplified syntax:
-/// - messages are delimited by '// -----' on a line by itself
-/// - lines starting with // are ignored.
-/// This is a testing path, so favor simplicity over performance here.
-/// When returning failure: feof(), ferror(), or shutdownRequested() will be
-/// set.
-LogicalResult
-JSONTransportInputOverFile::readDelimitedMessage(std::string &json) {
- json.clear();
- llvm::SmallString<128> line;
- while (succeeded(readLine(in, line))) {
- StringRef lineRef = line.str().trim();
- if (lineRef.starts_with("//")) {
- // Found a delimiter for the message.
- if (lineRef == kDefaultSplitMarker)
- break;
- continue;
- }
-
- json += line;
- }
-
- return failure(ferror(in));
-}
diff --git a/mlir/lib/Tools/mlir-lsp-server/CMakeLists.txt b/mlir/lib/Tools/mlir-lsp-server/CMakeLists.txt
index d04d5156f..e2acba5 100644
--- a/mlir/lib/Tools/mlir-lsp-server/CMakeLists.txt
+++ b/mlir/lib/Tools/mlir-lsp-server/CMakeLists.txt
@@ -7,6 +7,9 @@ add_mlir_library(MLIRLspServerLib
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Tools/mlir-lsp-server
+ LINK_COMPONENTS
+ SupportLSP
+
LINK_LIBS PUBLIC
MLIRBytecodeWriter
MLIRFunctionInterfaces
diff --git a/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp b/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp
index 9b937db..1bbbcde 100644
--- a/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp
+++ b/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp
@@ -9,8 +9,8 @@
#include "LSPServer.h"
#include "MLIRServer.h"
#include "Protocol.h"
-#include "mlir/Tools/lsp-server-support/Logging.h"
-#include "mlir/Tools/lsp-server-support/Transport.h"
+#include "llvm/Support/LSP/Logging.h"
+#include "llvm/Support/LSP/Transport.h"
#include <optional>
#define DEBUG_TYPE "mlir-lsp-server"
@@ -18,6 +18,33 @@
using namespace mlir;
using namespace mlir::lsp;
+using llvm::lsp::Callback;
+using llvm::lsp::CodeAction;
+using llvm::lsp::CodeActionParams;
+using llvm::lsp::CompletionList;
+using llvm::lsp::CompletionParams;
+using llvm::lsp::DidChangeTextDocumentParams;
+using llvm::lsp::DidCloseTextDocumentParams;
+using llvm::lsp::DidOpenTextDocumentParams;
+using llvm::lsp::DocumentSymbol;
+using llvm::lsp::DocumentSymbolParams;
+using llvm::lsp::Hover;
+using llvm::lsp::InitializedParams;
+using llvm::lsp::InitializeParams;
+using llvm::lsp::JSONTransport;
+using llvm::lsp::Location;
+using llvm::lsp::Logger;
+using llvm::lsp::MessageHandler;
+using llvm::lsp::MLIRConvertBytecodeParams;
+using llvm::lsp::MLIRConvertBytecodeResult;
+using llvm::lsp::NoParams;
+using llvm::lsp::OutgoingNotification;
+using llvm::lsp::PublishDiagnosticsParams;
+using llvm::lsp::ReferenceParams;
+using llvm::lsp::TextDocumentPositionParams;
+using llvm::lsp::TextDocumentSyncKind;
+using llvm::lsp::URIForFile;
+
//===----------------------------------------------------------------------===//
// LSPServer
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Tools/mlir-lsp-server/LSPServer.h b/mlir/lib/Tools/mlir-lsp-server/LSPServer.h
index 2c50c6b..d652899 100644
--- a/mlir/lib/Tools/mlir-lsp-server/LSPServer.h
+++ b/mlir/lib/Tools/mlir-lsp-server/LSPServer.h
@@ -13,17 +13,19 @@
namespace llvm {
struct LogicalResult;
+namespace lsp {
+class JSONTransport;
+} // namespace lsp
} // namespace llvm
namespace mlir {
namespace lsp {
-class JSONTransport;
class MLIRServer;
/// Run the main loop of the LSP server using the given MLIR server and
/// transport.
llvm::LogicalResult runMlirLSPServer(MLIRServer &server,
- JSONTransport &transport);
+ llvm::lsp::JSONTransport &transport);
} // namespace lsp
} // namespace mlir
diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
index 6198752..47b4328 100644
--- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
+++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
@@ -16,10 +16,10 @@
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Support/ToolUtilities.h"
-#include "mlir/Tools/lsp-server-support/Logging.h"
#include "mlir/Tools/lsp-server-support/SourceMgrUtils.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Base64.h"
+#include "llvm/Support/LSP/Logging.h"
#include "llvm/Support/SourceMgr.h"
#include <optional>
@@ -39,9 +39,9 @@ static std::optional<lsp::Location> getLocationFromLoc(StringRef uriScheme,
llvm::Expected<lsp::URIForFile> sourceURI =
lsp::URIForFile::fromFile(loc.getFilename(), uriScheme);
if (!sourceURI) {
- lsp::Logger::error("Failed to create URI for file `{0}`: {1}",
- loc.getFilename(),
- llvm::toString(sourceURI.takeError()));
+ llvm::lsp::Logger::error("Failed to create URI for file `{0}`: {1}",
+ loc.getFilename(),
+ llvm::toString(sourceURI.takeError()));
return std::nullopt;
}
@@ -217,22 +217,22 @@ static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr,
// Convert the severity for the diagnostic.
switch (diag.getSeverity()) {
- case DiagnosticSeverity::Note:
+ case mlir::DiagnosticSeverity::Note:
llvm_unreachable("expected notes to be handled separately");
- case DiagnosticSeverity::Warning:
- lspDiag.severity = lsp::DiagnosticSeverity::Warning;
+ case mlir::DiagnosticSeverity::Warning:
+ lspDiag.severity = llvm::lsp::DiagnosticSeverity::Warning;
break;
- case DiagnosticSeverity::Error:
- lspDiag.severity = lsp::DiagnosticSeverity::Error;
+ case mlir::DiagnosticSeverity::Error:
+ lspDiag.severity = llvm::lsp::DiagnosticSeverity::Error;
break;
- case DiagnosticSeverity::Remark:
- lspDiag.severity = lsp::DiagnosticSeverity::Information;
+ case mlir::DiagnosticSeverity::Remark:
+ lspDiag.severity = llvm::lsp::DiagnosticSeverity::Information;
break;
}
lspDiag.message = diag.str();
// Attach any notes to the main diagnostic as related information.
- std::vector<lsp::DiagnosticRelatedInformation> relatedDiags;
+ std::vector<llvm::lsp::DiagnosticRelatedInformation> relatedDiags;
for (Diagnostic &note : diag.getNotes()) {
lsp::Location noteLoc;
if (std::optional<lsp::Location> loc =
@@ -317,7 +317,7 @@ struct MLIRDocument {
void getCodeActionForDiagnostic(const lsp::URIForFile &uri,
lsp::Position &pos, StringRef severity,
StringRef message,
- std::vector<lsp::TextEdit> &edits);
+ std::vector<llvm::lsp::TextEdit> &edits);
//===--------------------------------------------------------------------===//
// Bytecode
@@ -355,7 +355,8 @@ MLIRDocument::MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
// Try to parsed the given IR string.
auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file());
if (!memBuffer) {
- lsp::Logger::error("Failed to create memory buffer for file", uri.file());
+ llvm::lsp::Logger::error("Failed to create memory buffer for file",
+ uri.file());
return;
}
@@ -695,8 +696,8 @@ void MLIRDocument::findDocumentSymbols(
if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op)) {
symbols.emplace_back(symbol.getName(),
isa<FunctionOpInterface>(op)
- ? lsp::SymbolKind::Function
- : lsp::SymbolKind::Class,
+ ? llvm::lsp::SymbolKind::Function
+ : llvm::lsp::SymbolKind::Class,
lsp::Range(sourceMgr, def->scopeLoc),
lsp::Range(sourceMgr, def->loc));
childSymbols = &symbols.back().children;
@@ -704,9 +705,9 @@ void MLIRDocument::findDocumentSymbols(
} else if (op->hasTrait<OpTrait::SymbolTable>()) {
// Otherwise, if this is a symbol table push an anonymous document symbol.
symbols.emplace_back("<" + op->getName().getStringRef() + ">",
- lsp::SymbolKind::Namespace,
- lsp::Range(sourceMgr, def->scopeLoc),
- lsp::Range(sourceMgr, def->loc));
+ llvm::lsp::SymbolKind::Namespace,
+ llvm::lsp::Range(sourceMgr, def->scopeLoc),
+ llvm::lsp::Range(sourceMgr, def->loc));
childSymbols = &symbols.back().children;
}
}
@@ -734,9 +735,9 @@ public:
/// Signal code completion for a dialect name, with an optional prefix.
void completeDialectName(StringRef prefix) final {
for (StringRef dialect : ctx->getAvailableDialects()) {
- lsp::CompletionItem item(prefix + dialect,
- lsp::CompletionItemKind::Module,
- /*sortText=*/"3");
+ llvm::lsp::CompletionItem item(prefix + dialect,
+ llvm::lsp::CompletionItemKind::Module,
+ /*sortText=*/"3");
item.detail = "dialect";
completionList.items.emplace_back(item);
}
@@ -753,9 +754,9 @@ public:
if (&op.getDialect() != dialect)
continue;
- lsp::CompletionItem item(
+ llvm::lsp::CompletionItem item(
op.getStringRef().drop_front(dialectName.size() + 1),
- lsp::CompletionItemKind::Field,
+ llvm::lsp::CompletionItemKind::Field,
/*sortText=*/"1");
item.detail = "operation";
completionList.items.emplace_back(item);
@@ -768,7 +769,8 @@ public:
// Check if we need to insert the `%` or not.
bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '%';
- lsp::CompletionItem item(name, lsp::CompletionItemKind::Variable);
+ llvm::lsp::CompletionItem item(name,
+ llvm::lsp::CompletionItemKind::Variable);
if (stripPrefix)
item.insertText = name.drop_front(1).str();
item.detail = std::move(typeData);
@@ -781,7 +783,7 @@ public:
// Check if we need to insert the `^` or not.
bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '^';
- lsp::CompletionItem item(name, lsp::CompletionItemKind::Field);
+ llvm::lsp::CompletionItem item(name, llvm::lsp::CompletionItemKind::Field);
if (stripPrefix)
item.insertText = name.drop_front(1).str();
completionList.items.emplace_back(item);
@@ -790,8 +792,9 @@ public:
/// Signal a completion for the given expected token.
void completeExpectedTokens(ArrayRef<StringRef> tokens, bool optional) final {
for (StringRef token : tokens) {
- lsp::CompletionItem item(token, lsp::CompletionItemKind::Keyword,
- /*sortText=*/"0");
+ llvm::lsp::CompletionItem item(token,
+ llvm::lsp::CompletionItemKind::Keyword,
+ /*sortText=*/"0");
item.detail = optional ? "optional" : "";
completionList.items.emplace_back(item);
}
@@ -802,7 +805,7 @@ public:
appendSimpleCompletions({"affine_set", "affine_map", "dense",
"dense_resource", "false", "loc", "sparse", "true",
"unit"},
- lsp::CompletionItemKind::Field,
+ llvm::lsp::CompletionItemKind::Field,
/*sortText=*/"1");
completeDialectName("#");
@@ -820,13 +823,14 @@ public:
appendSimpleCompletions({"memref", "tensor", "complex", "tuple", "vector",
"bf16", "f16", "f32", "f64", "f80", "f128",
"index", "none"},
- lsp::CompletionItemKind::Field,
+ llvm::lsp::CompletionItemKind::Field,
/*sortText=*/"1");
// Handle the builtin integer types.
for (StringRef type : {"i", "si", "ui"}) {
- lsp::CompletionItem item(type + "<N>", lsp::CompletionItemKind::Field,
- /*sortText=*/"1");
+ llvm::lsp::CompletionItem item(type + "<N>",
+ llvm::lsp::CompletionItemKind::Field,
+ /*sortText=*/"1");
item.insertText = type.str();
completionList.items.emplace_back(item);
}
@@ -846,9 +850,9 @@ public:
void completeAliases(const llvm::StringMap<T> &aliases,
StringRef prefix = "") {
for (const auto &alias : aliases) {
- lsp::CompletionItem item(prefix + alias.getKey(),
- lsp::CompletionItemKind::Field,
- /*sortText=*/"2");
+ llvm::lsp::CompletionItem item(prefix + alias.getKey(),
+ llvm::lsp::CompletionItemKind::Field,
+ /*sortText=*/"2");
llvm::raw_string_ostream(item.detail) << "alias: " << alias.getValue();
completionList.items.emplace_back(item);
}
@@ -856,7 +860,7 @@ public:
/// Add a set of simple completions that all have the same kind.
void appendSimpleCompletions(ArrayRef<StringRef> completions,
- lsp::CompletionItemKind kind,
+ llvm::lsp::CompletionItemKind kind,
StringRef sortText = "") {
for (StringRef completion : completions)
completionList.items.emplace_back(completion, kind, sortText);
@@ -897,7 +901,7 @@ MLIRDocument::getCodeCompletion(const lsp::URIForFile &uri,
void MLIRDocument::getCodeActionForDiagnostic(
const lsp::URIForFile &uri, lsp::Position &pos, StringRef severity,
- StringRef message, std::vector<lsp::TextEdit> &edits) {
+ StringRef message, std::vector<llvm::lsp::TextEdit> &edits) {
// Ignore diagnostics that print the current operation. These are always
// enabled for the language server, but not generally during normal
// parsing/verification.
@@ -913,7 +917,7 @@ void MLIRDocument::getCodeActionForDiagnostic(
// Add a text edit for adding an expected-* diagnostic check for this
// diagnostic.
- lsp::TextEdit edit;
+ llvm::lsp::TextEdit edit;
edit.range = lsp::Range(lsp::Position(pos.line, 0));
// Use the indent of the current line for the expected-* diagnostic.
@@ -937,13 +941,14 @@ MLIRDocument::convertToBytecode() {
// conceptually be relaxed.
if (!llvm::hasSingleElement(parsedIR)) {
if (parsedIR.empty()) {
- return llvm::make_error<lsp::LSPError>(
+ return llvm::make_error<llvm::lsp::LSPError>(
"expected a single and valid top-level operation, please ensure "
"there are no errors",
- lsp::ErrorCode::RequestFailed);
+ llvm::lsp::ErrorCode::RequestFailed);
}
- return llvm::make_error<lsp::LSPError>(
- "expected a single top-level operation", lsp::ErrorCode::RequestFailed);
+ return llvm::make_error<llvm::lsp::LSPError>(
+ "expected a single top-level operation",
+ llvm::lsp::ErrorCode::RequestFailed);
}
lsp::MLIRConvertBytecodeResult result;
@@ -1134,7 +1139,7 @@ void MLIRTextFile::findDocumentSymbols(
lsp::Position endPos((i == e - 1) ? totalNumLines - 1
: chunks[i + 1]->lineOffset);
lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">",
- lsp::SymbolKind::Namespace,
+ llvm::lsp::SymbolKind::Namespace,
/*range=*/lsp::Range(startPos, endPos),
/*selectionRange=*/lsp::Range(startPos));
chunk.document.findDocumentSymbols(symbol.children);
@@ -1167,10 +1172,10 @@ lsp::CompletionList MLIRTextFile::getCodeCompletion(const lsp::URIForFile &uri,
uri, completePos, context.getDialectRegistry());
// Adjust any completion locations.
- for (lsp::CompletionItem &item : completionList.items) {
+ for (llvm::lsp::CompletionItem &item : completionList.items) {
if (item.textEdit)
chunk.adjustLocForChunkOffset(item.textEdit->range);
- for (lsp::TextEdit &edit : item.additionalTextEdits)
+ for (llvm::lsp::TextEdit &edit : item.additionalTextEdits)
chunk.adjustLocForChunkOffset(edit.range);
}
return completionList;
@@ -1194,10 +1199,10 @@ void MLIRTextFile::getCodeActions(const lsp::URIForFile &uri,
StringRef severity;
switch (diag.severity) {
- case lsp::DiagnosticSeverity::Error:
+ case llvm::lsp::DiagnosticSeverity::Error:
severity = "error";
break;
- case lsp::DiagnosticSeverity::Warning:
+ case llvm::lsp::DiagnosticSeverity::Warning:
severity = "warning";
break;
default:
@@ -1205,7 +1210,7 @@ void MLIRTextFile::getCodeActions(const lsp::URIForFile &uri,
}
// Get edits for the diagnostic.
- std::vector<lsp::TextEdit> edits;
+ std::vector<llvm::lsp::TextEdit> edits;
chunk.document.getCodeActionForDiagnostic(uri, diagPos, severity,
diag.message, edits);
@@ -1221,7 +1226,7 @@ void MLIRTextFile::getCodeActions(const lsp::URIForFile &uri,
}
}
// Fixup the locations for any edits.
- for (lsp::TextEdit &edit : edits)
+ for (llvm::lsp::TextEdit &edit : edits)
chunk.adjustLocForChunkOffset(edit.range);
action.edit.emplace();
@@ -1236,9 +1241,9 @@ llvm::Expected<lsp::MLIRConvertBytecodeResult>
MLIRTextFile::convertToBytecode() {
// Bail out if there is more than one chunk, bytecode wants a single module.
if (chunks.size() != 1) {
- return llvm::make_error<lsp::LSPError>(
+ return llvm::make_error<llvm::lsp::LSPError>(
"unexpected split file, please remove all `// -----`",
- lsp::ErrorCode::RequestFailed);
+ llvm::lsp::ErrorCode::RequestFailed);
}
return chunks.front()->document.convertToBytecode();
}
@@ -1283,7 +1288,7 @@ lsp::MLIRServer::~MLIRServer() = default;
void lsp::MLIRServer::addOrUpdateDocument(
const URIForFile &uri, StringRef contents, int64_t version,
- std::vector<Diagnostic> &diagnostics) {
+ std::vector<llvm::lsp::Diagnostic> &diagnostics) {
impl->files[uri.file()] = std::make_unique<MLIRTextFile>(
uri, contents, version, impl->registry_fn, diagnostics);
}
@@ -1298,17 +1303,17 @@ std::optional<int64_t> lsp::MLIRServer::removeDocument(const URIForFile &uri) {
return version;
}
-void lsp::MLIRServer::getLocationsOf(const URIForFile &uri,
- const Position &defPos,
- std::vector<Location> &locations) {
+void lsp::MLIRServer::getLocationsOf(
+ const URIForFile &uri, const Position &defPos,
+ std::vector<llvm::lsp::Location> &locations) {
auto fileIt = impl->files.find(uri.file());
if (fileIt != impl->files.end())
fileIt->second->getLocationsOf(uri, defPos, locations);
}
-void lsp::MLIRServer::findReferencesOf(const URIForFile &uri,
- const Position &pos,
- std::vector<Location> &references) {
+void lsp::MLIRServer::findReferencesOf(
+ const URIForFile &uri, const Position &pos,
+ std::vector<llvm::lsp::Location> &references) {
auto fileIt = impl->files.find(uri.file());
if (fileIt != impl->files.end())
fileIt->second->findReferencesOf(uri, pos, references);
@@ -1367,17 +1372,17 @@ lsp::MLIRServer::convertFromBytecode(const URIForFile &uri) {
// Try to parse the given source file.
Block parsedBlock;
if (failed(parseSourceFile(uri.file(), &parsedBlock, parserConfig))) {
- return llvm::make_error<lsp::LSPError>(
+ return llvm::make_error<llvm::lsp::LSPError>(
"failed to parse bytecode source file: " + errorMsg,
- lsp::ErrorCode::RequestFailed);
+ llvm::lsp::ErrorCode::RequestFailed);
}
// TODO: We currently expect a single top-level operation, but this could
// conceptually be relaxed.
if (!llvm::hasSingleElement(parsedBlock)) {
- return llvm::make_error<lsp::LSPError>(
+ return llvm::make_error<llvm::lsp::LSPError>(
"expected bytecode to contain a single top-level operation",
- lsp::ErrorCode::RequestFailed);
+ llvm::lsp::ErrorCode::RequestFailed);
}
// Print the module to a buffer.
@@ -1401,9 +1406,9 @@ llvm::Expected<lsp::MLIRConvertBytecodeResult>
lsp::MLIRServer::convertToBytecode(const URIForFile &uri) {
auto fileIt = impl->files.find(uri.file());
if (fileIt == impl->files.end()) {
- return llvm::make_error<lsp::LSPError>(
+ return llvm::make_error<llvm::lsp::LSPError>(
"language server does not contain an entry for this source file",
- lsp::ErrorCode::RequestFailed);
+ llvm::lsp::ErrorCode::RequestFailed);
}
return fileIt->second->convertToBytecode();
}
diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h
index 85e69e6..31a01fe 100644
--- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h
+++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h
@@ -9,6 +9,7 @@
#ifndef LIB_MLIR_TOOLS_MLIRLSPSERVER_SERVER_H_
#define LIB_MLIR_TOOLS_MLIRLSPSERVER_SERVER_H_
+#include "Protocol.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Tools/mlir-lsp-server/MlirLspRegistryFunction.h"
#include "llvm/Support/Error.h"
@@ -19,16 +20,17 @@ namespace mlir {
class DialectRegistry;
namespace lsp {
-struct CodeAction;
-struct CodeActionContext;
-struct CompletionList;
-struct Diagnostic;
-struct DocumentSymbol;
-struct Hover;
-struct Location;
-struct MLIRConvertBytecodeResult;
-struct Position;
-struct Range;
+using llvm::lsp::CodeAction;
+using llvm::lsp::CodeActionContext;
+using llvm::lsp::CompletionList;
+using llvm::lsp::Diagnostic;
+using llvm::lsp::DocumentSymbol;
+using llvm::lsp::Hover;
+using llvm::lsp::Location;
+using llvm::lsp::MLIRConvertBytecodeResult;
+using llvm::lsp::Position;
+using llvm::lsp::Range;
+using llvm::lsp::URIForFile;
/// This class implements all of the MLIR related functionality necessary for a
/// language server. This class allows for keeping the MLIR specific logic
diff --git a/mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp b/mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp
index f1dc326..d4589b2 100644
--- a/mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp
+++ b/mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp
@@ -9,14 +9,18 @@
#include "mlir/Tools/mlir-lsp-server/MlirLspServerMain.h"
#include "LSPServer.h"
#include "MLIRServer.h"
-#include "mlir/Tools/lsp-server-support/Logging.h"
-#include "mlir/Tools/lsp-server-support/Transport.h"
#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/LSP/Logging.h"
+#include "llvm/Support/LSP/Transport.h"
#include "llvm/Support/Program.h"
using namespace mlir;
using namespace mlir::lsp;
+using llvm::lsp::JSONStreamStyle;
+using llvm::lsp::JSONTransport;
+using llvm::lsp::Logger;
+
LogicalResult mlir::MlirLspServerMain(int argc, char **argv,
DialectRegistryFn registry_fn) {
llvm::cl::opt<JSONStreamStyle> inputStyle{
diff --git a/mlir/lib/Tools/mlir-lsp-server/Protocol.cpp b/mlir/lib/Tools/mlir-lsp-server/Protocol.cpp
index a56e9a1..28aded3 100644
--- a/mlir/lib/Tools/mlir-lsp-server/Protocol.cpp
+++ b/mlir/lib/Tools/mlir-lsp-server/Protocol.cpp
@@ -13,14 +13,11 @@
#include "Protocol.h"
#include "llvm/Support/JSON.h"
-using namespace mlir;
-using namespace mlir::lsp;
-
//===----------------------------------------------------------------------===//
// MLIRConvertBytecodeParams
//===----------------------------------------------------------------------===//
-bool mlir::lsp::fromJSON(const llvm::json::Value &value,
+bool llvm::lsp::fromJSON(const llvm::json::Value &value,
MLIRConvertBytecodeParams &result,
llvm::json::Path path) {
llvm::json::ObjectMapper o(value, path);
@@ -31,6 +28,6 @@ bool mlir::lsp::fromJSON(const llvm::json::Value &value,
// MLIRConvertBytecodeResult
//===----------------------------------------------------------------------===//
-llvm::json::Value mlir::lsp::toJSON(const MLIRConvertBytecodeResult &value) {
+llvm::json::Value llvm::lsp::toJSON(const MLIRConvertBytecodeResult &value) {
return llvm::json::Object{{"output", value.output}};
}
diff --git a/mlir/lib/Tools/mlir-lsp-server/Protocol.h b/mlir/lib/Tools/mlir-lsp-server/Protocol.h
index d910780..ed0db4e 100644
--- a/mlir/lib/Tools/mlir-lsp-server/Protocol.h
+++ b/mlir/lib/Tools/mlir-lsp-server/Protocol.h
@@ -20,9 +20,9 @@
#ifndef LIB_MLIR_TOOLS_MLIRLSPSERVER_PROTOCOL_H_
#define LIB_MLIR_TOOLS_MLIRLSPSERVER_PROTOCOL_H_
-#include "mlir/Tools/lsp-server-support/Protocol.h"
+#include "llvm/Support/LSP/Protocol.h"
-namespace mlir {
+namespace llvm {
namespace lsp {
//===----------------------------------------------------------------------===//
// MLIRConvertBytecodeParams
@@ -54,6 +54,6 @@ struct MLIRConvertBytecodeResult {
llvm::json::Value toJSON(const MLIRConvertBytecodeResult &value);
} // namespace lsp
-} // namespace mlir
+} // namespace llvm
#endif
diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/CMakeLists.txt b/mlir/lib/Tools/mlir-pdll-lsp-server/CMakeLists.txt
index bf25b7e..b41603f 100644
--- a/mlir/lib/Tools/mlir-pdll-lsp-server/CMakeLists.txt
+++ b/mlir/lib/Tools/mlir-pdll-lsp-server/CMakeLists.txt
@@ -7,6 +7,9 @@ llvm_add_library(MLIRPdllLspServerLib
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Tools/mlir-pdll-lsp-server
+ LINK_COMPONENTS
+ SupportLSP
+
LINK_LIBS PUBLIC
MLIRPDLLCodeGen
MLIRPDLLParser
diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp
index 82542a1..7b23adc 100644
--- a/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp
+++ b/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp
@@ -10,8 +10,9 @@
#include "PDLLServer.h"
#include "Protocol.h"
-#include "mlir/Tools/lsp-server-support/Logging.h"
-#include "mlir/Tools/lsp-server-support/Transport.h"
+#include "llvm/Support/LSP/Logging.h"
+#include "llvm/Support/LSP/Protocol.h"
+#include "llvm/Support/LSP/Transport.h"
#include <optional>
#define DEBUG_TYPE "pdll-lsp-server"
@@ -19,6 +20,30 @@
using namespace mlir;
using namespace mlir::lsp;
+using llvm::lsp::Callback;
+using llvm::lsp::CompletionList;
+using llvm::lsp::CompletionParams;
+using llvm::lsp::DidChangeTextDocumentParams;
+using llvm::lsp::DidCloseTextDocumentParams;
+using llvm::lsp::DidOpenTextDocumentParams;
+using llvm::lsp::DocumentLinkParams;
+using llvm::lsp::DocumentSymbol;
+using llvm::lsp::DocumentSymbolParams;
+using llvm::lsp::Hover;
+using llvm::lsp::InitializedParams;
+using llvm::lsp::InitializeParams;
+using llvm::lsp::InlayHintsParams;
+using llvm::lsp::JSONTransport;
+using llvm::lsp::Location;
+using llvm::lsp::Logger;
+using llvm::lsp::MessageHandler;
+using llvm::lsp::NoParams;
+using llvm::lsp::OutgoingNotification;
+using llvm::lsp::PublishDiagnosticsParams;
+using llvm::lsp::ReferenceParams;
+using llvm::lsp::TextDocumentPositionParams;
+using llvm::lsp::TextDocumentSyncKind;
+
//===----------------------------------------------------------------------===//
// LSPServer
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.h b/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.h
index 78c4c31..42c0a5d 100644
--- a/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.h
+++ b/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.h
@@ -13,17 +13,19 @@
namespace llvm {
struct LogicalResult;
+namespace lsp {
+class JSONTransport;
+} // namespace lsp
} // namespace llvm
namespace mlir {
namespace lsp {
-class JSONTransport;
class PDLLServer;
/// Run the main loop of the LSP server using the given PDLL server and
/// transport.
llvm::LogicalResult runPdllLSPServer(PDLLServer &server,
- JSONTransport &transport);
+ llvm::lsp::JSONTransport &transport);
} // namespace lsp
} // namespace mlir
diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/MlirPdllLspServerMain.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/MlirPdllLspServerMain.cpp
index 287a131..5dea130 100644
--- a/mlir/lib/Tools/mlir-pdll-lsp-server/MlirPdllLspServerMain.cpp
+++ b/mlir/lib/Tools/mlir-pdll-lsp-server/MlirPdllLspServerMain.cpp
@@ -9,14 +9,17 @@
#include "mlir/Tools/mlir-pdll-lsp-server/MlirPdllLspServerMain.h"
#include "LSPServer.h"
#include "PDLLServer.h"
-#include "mlir/Tools/lsp-server-support/Logging.h"
-#include "mlir/Tools/lsp-server-support/Transport.h"
#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/LSP/Logging.h"
+#include "llvm/Support/LSP/Transport.h"
#include "llvm/Support/Program.h"
using namespace mlir;
using namespace mlir::lsp;
+using llvm::lsp::JSONStreamStyle;
+using llvm::lsp::Logger;
+
LogicalResult mlir::MlirPdllLspServerMain(int argc, char **argv) {
llvm::cl::opt<JSONStreamStyle> inputStyle{
"input-style",
@@ -72,7 +75,8 @@ LogicalResult mlir::MlirPdllLspServerMain(int argc, char **argv) {
// Configure the transport used for communication.
llvm::sys::ChangeStdinToBinary();
- JSONTransport transport(stdin, llvm::outs(), inputStyle, prettyPrint);
+ llvm::lsp::JSONTransport transport(stdin, llvm::outs(), inputStyle,
+ prettyPrint);
// Configure the servers and start the main language server.
PDLLServer::Options options(compilationDatabases, extraIncludeDirs);
diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
index 84f529a..60b9567 100644
--- a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
+++ b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
@@ -23,13 +23,13 @@
#include "mlir/Tools/PDLL/Parser/CodeComplete.h"
#include "mlir/Tools/PDLL/Parser/Parser.h"
#include "mlir/Tools/lsp-server-support/CompilationDatabase.h"
-#include "mlir/Tools/lsp-server-support/Logging.h"
#include "mlir/Tools/lsp-server-support/SourceMgrUtils.h"
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/LSP/Logging.h"
#include "llvm/Support/Path.h"
#include <optional>
@@ -38,17 +38,19 @@ using namespace mlir::pdll;
/// Returns a language server uri for the given source location. `mainFileURI`
/// corresponds to the uri for the main file of the source manager.
-static lsp::URIForFile getURIFromLoc(llvm::SourceMgr &mgr, SMRange loc,
- const lsp::URIForFile &mainFileURI) {
+static llvm::lsp::URIForFile
+getURIFromLoc(llvm::SourceMgr &mgr, SMRange loc,
+ const llvm::lsp::URIForFile &mainFileURI) {
int bufferId = mgr.FindBufferContainingLoc(loc.Start);
if (bufferId == 0 || bufferId == static_cast<int>(mgr.getMainFileID()))
return mainFileURI;
- llvm::Expected<lsp::URIForFile> fileForLoc = lsp::URIForFile::fromFile(
- mgr.getBufferInfo(bufferId).Buffer->getBufferIdentifier());
+ llvm::Expected<llvm::lsp::URIForFile> fileForLoc =
+ llvm::lsp::URIForFile::fromFile(
+ mgr.getBufferInfo(bufferId).Buffer->getBufferIdentifier());
if (fileForLoc)
return *fileForLoc;
- lsp::Logger::error("Failed to create URI for include file: {0}",
- llvm::toString(fileForLoc.takeError()));
+ llvm::lsp::Logger::error("Failed to create URI for include file: {0}",
+ llvm::toString(fileForLoc.takeError()));
return mainFileURI;
}
@@ -59,16 +61,18 @@ static bool isMainFileLoc(llvm::SourceMgr &mgr, SMRange loc) {
}
/// Returns a language server location from the given source range.
-static lsp::Location getLocationFromLoc(llvm::SourceMgr &mgr, SMRange range,
- const lsp::URIForFile &uri) {
- return lsp::Location(getURIFromLoc(mgr, range, uri), lsp::Range(mgr, range));
+static llvm::lsp::Location
+getLocationFromLoc(llvm::SourceMgr &mgr, SMRange range,
+ const llvm::lsp::URIForFile &uri) {
+ return llvm::lsp::Location(getURIFromLoc(mgr, range, uri),
+ llvm::lsp::Range(mgr, range));
}
/// Convert the given MLIR diagnostic to the LSP form.
-static std::optional<lsp::Diagnostic>
+static std::optional<llvm::lsp::Diagnostic>
getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, const ast::Diagnostic &diag,
- const lsp::URIForFile &uri) {
- lsp::Diagnostic lspDiag;
+ const llvm::lsp::URIForFile &uri) {
+ llvm::lsp::Diagnostic lspDiag;
lspDiag.source = "pdll";
// FIXME: Right now all of the diagnostics are treated as parser issues, but
@@ -76,7 +80,8 @@ getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, const ast::Diagnostic &diag,
lspDiag.category = "Parse Error";
// Try to grab a file location for this diagnostic.
- lsp::Location loc = getLocationFromLoc(sourceMgr, diag.getLocation(), uri);
+ llvm::lsp::Location loc =
+ getLocationFromLoc(sourceMgr, diag.getLocation(), uri);
lspDiag.range = loc.range;
// Skip diagnostics that weren't emitted within the main file.
@@ -88,19 +93,19 @@ getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, const ast::Diagnostic &diag,
case ast::Diagnostic::Severity::DK_Note:
llvm_unreachable("expected notes to be handled separately");
case ast::Diagnostic::Severity::DK_Warning:
- lspDiag.severity = lsp::DiagnosticSeverity::Warning;
+ lspDiag.severity = llvm::lsp::DiagnosticSeverity::Warning;
break;
case ast::Diagnostic::Severity::DK_Error:
- lspDiag.severity = lsp::DiagnosticSeverity::Error;
+ lspDiag.severity = llvm::lsp::DiagnosticSeverity::Error;
break;
case ast::Diagnostic::Severity::DK_Remark:
- lspDiag.severity = lsp::DiagnosticSeverity::Information;
+ lspDiag.severity = llvm::lsp::DiagnosticSeverity::Information;
break;
}
lspDiag.message = diag.getMessage().str();
// Attach any notes to the main diagnostic as related information.
- std::vector<lsp::DiagnosticRelatedInformation> relatedDiags;
+ std::vector<llvm::lsp::DiagnosticRelatedInformation> relatedDiags;
for (const ast::Diagnostic &note : diag.getNotes()) {
relatedDiags.emplace_back(
getLocationFromLoc(sourceMgr, note.getLocation(), uri),
@@ -259,9 +264,9 @@ namespace {
/// This class represents all of the information pertaining to a specific PDL
/// document.
struct PDLDocument {
- PDLDocument(const lsp::URIForFile &uri, StringRef contents,
+ PDLDocument(const llvm::lsp::URIForFile &uri, StringRef contents,
const std::vector<std::string> &extraDirs,
- std::vector<lsp::Diagnostic> &diagnostics);
+ std::vector<llvm::lsp::Diagnostic> &diagnostics);
PDLDocument(const PDLDocument &) = delete;
PDLDocument &operator=(const PDLDocument &) = delete;
@@ -269,76 +274,83 @@ struct PDLDocument {
// Definitions and References
//===--------------------------------------------------------------------===//
- void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos,
- std::vector<lsp::Location> &locations);
- void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos,
- std::vector<lsp::Location> &references);
+ void getLocationsOf(const llvm::lsp::URIForFile &uri,
+ const llvm::lsp::Position &defPos,
+ std::vector<llvm::lsp::Location> &locations);
+ void findReferencesOf(const llvm::lsp::URIForFile &uri,
+ const llvm::lsp::Position &pos,
+ std::vector<llvm::lsp::Location> &references);
//===--------------------------------------------------------------------===//
// Document Links
//===--------------------------------------------------------------------===//
- void getDocumentLinks(const lsp::URIForFile &uri,
- std::vector<lsp::DocumentLink> &links);
+ void getDocumentLinks(const llvm::lsp::URIForFile &uri,
+ std::vector<llvm::lsp::DocumentLink> &links);
//===--------------------------------------------------------------------===//
// Hover
//===--------------------------------------------------------------------===//
- std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
- const lsp::Position &hoverPos);
- std::optional<lsp::Hover> findHover(const ast::Decl *decl,
- const SMRange &hoverRange);
- lsp::Hover buildHoverForOpName(const ods::Operation *op,
- const SMRange &hoverRange);
- lsp::Hover buildHoverForVariable(const ast::VariableDecl *varDecl,
- const SMRange &hoverRange);
- lsp::Hover buildHoverForPattern(const ast::PatternDecl *decl,
- const SMRange &hoverRange);
- lsp::Hover buildHoverForCoreConstraint(const ast::CoreConstraintDecl *decl,
+ std::optional<llvm::lsp::Hover>
+ findHover(const llvm::lsp::URIForFile &uri,
+ const llvm::lsp::Position &hoverPos);
+ std::optional<llvm::lsp::Hover> findHover(const ast::Decl *decl,
+ const SMRange &hoverRange);
+ llvm::lsp::Hover buildHoverForOpName(const ods::Operation *op,
+ const SMRange &hoverRange);
+ llvm::lsp::Hover buildHoverForVariable(const ast::VariableDecl *varDecl,
const SMRange &hoverRange);
+ llvm::lsp::Hover buildHoverForPattern(const ast::PatternDecl *decl,
+ const SMRange &hoverRange);
+ llvm::lsp::Hover
+ buildHoverForCoreConstraint(const ast::CoreConstraintDecl *decl,
+ const SMRange &hoverRange);
template <typename T>
- lsp::Hover buildHoverForUserConstraintOrRewrite(StringRef typeName,
- const T *decl,
- const SMRange &hoverRange);
+ llvm::lsp::Hover
+ buildHoverForUserConstraintOrRewrite(StringRef typeName, const T *decl,
+ const SMRange &hoverRange);
//===--------------------------------------------------------------------===//
// Document Symbols
//===--------------------------------------------------------------------===//
- void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
+ void findDocumentSymbols(std::vector<llvm::lsp::DocumentSymbol> &symbols);
//===--------------------------------------------------------------------===//
// Code Completion
//===--------------------------------------------------------------------===//
- lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
- const lsp::Position &completePos);
+ llvm::lsp::CompletionList
+ getCodeCompletion(const llvm::lsp::URIForFile &uri,
+ const llvm::lsp::Position &completePos);
//===--------------------------------------------------------------------===//
// Signature Help
//===--------------------------------------------------------------------===//
- lsp::SignatureHelp getSignatureHelp(const lsp::URIForFile &uri,
- const lsp::Position &helpPos);
+ llvm::lsp::SignatureHelp getSignatureHelp(const llvm::lsp::URIForFile &uri,
+ const llvm::lsp::Position &helpPos);
//===--------------------------------------------------------------------===//
// Inlay Hints
//===--------------------------------------------------------------------===//
- void getInlayHints(const lsp::URIForFile &uri, const lsp::Range &range,
- std::vector<lsp::InlayHint> &inlayHints);
+ void getInlayHints(const llvm::lsp::URIForFile &uri,
+ const llvm::lsp::Range &range,
+ std::vector<llvm::lsp::InlayHint> &inlayHints);
void getInlayHintsFor(const ast::VariableDecl *decl,
- const lsp::URIForFile &uri,
- std::vector<lsp::InlayHint> &inlayHints);
- void getInlayHintsFor(const ast::CallExpr *expr, const lsp::URIForFile &uri,
- std::vector<lsp::InlayHint> &inlayHints);
+ const llvm::lsp::URIForFile &uri,
+ std::vector<llvm::lsp::InlayHint> &inlayHints);
+ void getInlayHintsFor(const ast::CallExpr *expr,
+ const llvm::lsp::URIForFile &uri,
+ std::vector<llvm::lsp::InlayHint> &inlayHints);
void getInlayHintsFor(const ast::OperationExpr *expr,
- const lsp::URIForFile &uri,
- std::vector<lsp::InlayHint> &inlayHints);
+ const llvm::lsp::URIForFile &uri,
+ std::vector<llvm::lsp::InlayHint> &inlayHints);
/// Add a parameter hint for the given expression using `label`.
- void addParameterHintFor(std::vector<lsp::InlayHint> &inlayHints,
+ void addParameterHintFor(std::vector<llvm::lsp::InlayHint> &inlayHints,
const ast::Expr *expr, StringRef label);
//===--------------------------------------------------------------------===//
@@ -372,13 +384,14 @@ struct PDLDocument {
};
} // namespace
-PDLDocument::PDLDocument(const lsp::URIForFile &uri, StringRef contents,
+PDLDocument::PDLDocument(const llvm::lsp::URIForFile &uri, StringRef contents,
const std::vector<std::string> &extraDirs,
- std::vector<lsp::Diagnostic> &diagnostics)
+ std::vector<llvm::lsp::Diagnostic> &diagnostics)
: astContext(odsContext) {
auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file());
if (!memBuffer) {
- lsp::Logger::error("Failed to create memory buffer for file", uri.file());
+ llvm::lsp::Logger::error("Failed to create memory buffer for file",
+ uri.file());
return;
}
@@ -412,9 +425,9 @@ PDLDocument::PDLDocument(const lsp::URIForFile &uri, StringRef contents,
// PDLDocument: Definitions and References
//===----------------------------------------------------------------------===//
-void PDLDocument::getLocationsOf(const lsp::URIForFile &uri,
- const lsp::Position &defPos,
- std::vector<lsp::Location> &locations) {
+void PDLDocument::getLocationsOf(const llvm::lsp::URIForFile &uri,
+ const llvm::lsp::Position &defPos,
+ std::vector<llvm::lsp::Location> &locations) {
SMLoc posLoc = defPos.getAsSMLoc(sourceMgr);
const PDLIndexSymbol *symbol = index.lookup(posLoc);
if (!symbol)
@@ -423,9 +436,9 @@ void PDLDocument::getLocationsOf(const lsp::URIForFile &uri,
locations.push_back(getLocationFromLoc(sourceMgr, symbol->getDefLoc(), uri));
}
-void PDLDocument::findReferencesOf(const lsp::URIForFile &uri,
- const lsp::Position &pos,
- std::vector<lsp::Location> &references) {
+void PDLDocument::findReferencesOf(
+ const llvm::lsp::URIForFile &uri, const llvm::lsp::Position &pos,
+ std::vector<llvm::lsp::Location> &references) {
SMLoc posLoc = pos.getAsSMLoc(sourceMgr);
const PDLIndexSymbol *symbol = index.lookup(posLoc);
if (!symbol)
@@ -440,8 +453,9 @@ void PDLDocument::findReferencesOf(const lsp::URIForFile &uri,
// PDLDocument: Document Links
//===--------------------------------------------------------------------===//
-void PDLDocument::getDocumentLinks(const lsp::URIForFile &uri,
- std::vector<lsp::DocumentLink> &links) {
+void PDLDocument::getDocumentLinks(
+ const llvm::lsp::URIForFile &uri,
+ std::vector<llvm::lsp::DocumentLink> &links) {
for (const lsp::SourceMgrInclude &include : parsedIncludes)
links.emplace_back(include.range, include.uri);
}
@@ -450,9 +464,9 @@ void PDLDocument::getDocumentLinks(const lsp::URIForFile &uri,
// PDLDocument: Hover
//===----------------------------------------------------------------------===//
-std::optional<lsp::Hover>
-PDLDocument::findHover(const lsp::URIForFile &uri,
- const lsp::Position &hoverPos) {
+std::optional<llvm::lsp::Hover>
+PDLDocument::findHover(const llvm::lsp::URIForFile &uri,
+ const llvm::lsp::Position &hoverPos) {
SMLoc posLoc = hoverPos.getAsSMLoc(sourceMgr);
// Check for a reference to an include.
@@ -474,8 +488,8 @@ PDLDocument::findHover(const lsp::URIForFile &uri,
return findHover(decl, hoverRange);
}
-std::optional<lsp::Hover> PDLDocument::findHover(const ast::Decl *decl,
- const SMRange &hoverRange) {
+std::optional<llvm::lsp::Hover>
+PDLDocument::findHover(const ast::Decl *decl, const SMRange &hoverRange) {
// Add hover for variables.
if (const auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
return buildHoverForVariable(varDecl, hoverRange);
@@ -499,9 +513,9 @@ std::optional<lsp::Hover> PDLDocument::findHover(const ast::Decl *decl,
return std::nullopt;
}
-lsp::Hover PDLDocument::buildHoverForOpName(const ods::Operation *op,
- const SMRange &hoverRange) {
- lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
+llvm::lsp::Hover PDLDocument::buildHoverForOpName(const ods::Operation *op,
+ const SMRange &hoverRange) {
+ llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange));
{
llvm::raw_string_ostream hoverOS(hover.contents.value);
hoverOS << "**OpName**: `" << op->getName() << "`\n***\n"
@@ -511,9 +525,10 @@ lsp::Hover PDLDocument::buildHoverForOpName(const ods::Operation *op,
return hover;
}
-lsp::Hover PDLDocument::buildHoverForVariable(const ast::VariableDecl *varDecl,
- const SMRange &hoverRange) {
- lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
+llvm::lsp::Hover
+PDLDocument::buildHoverForVariable(const ast::VariableDecl *varDecl,
+ const SMRange &hoverRange) {
+ llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange));
{
llvm::raw_string_ostream hoverOS(hover.contents.value);
hoverOS << "**Variable**: `" << varDecl->getName().getName() << "`\n***\n"
@@ -522,9 +537,9 @@ lsp::Hover PDLDocument::buildHoverForVariable(const ast::VariableDecl *varDecl,
return hover;
}
-lsp::Hover PDLDocument::buildHoverForPattern(const ast::PatternDecl *decl,
- const SMRange &hoverRange) {
- lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
+llvm::lsp::Hover PDLDocument::buildHoverForPattern(const ast::PatternDecl *decl,
+ const SMRange &hoverRange) {
+ llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange));
{
llvm::raw_string_ostream hoverOS(hover.contents.value);
hoverOS << "**Pattern**";
@@ -545,10 +560,10 @@ lsp::Hover PDLDocument::buildHoverForPattern(const ast::PatternDecl *decl,
return hover;
}
-lsp::Hover
+llvm::lsp::Hover
PDLDocument::buildHoverForCoreConstraint(const ast::CoreConstraintDecl *decl,
const SMRange &hoverRange) {
- lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
+ llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange));
{
llvm::raw_string_ostream hoverOS(hover.contents.value);
hoverOS << "**Constraint**: `";
@@ -573,9 +588,9 @@ PDLDocument::buildHoverForCoreConstraint(const ast::CoreConstraintDecl *decl,
}
template <typename T>
-lsp::Hover PDLDocument::buildHoverForUserConstraintOrRewrite(
+llvm::lsp::Hover PDLDocument::buildHoverForUserConstraintOrRewrite(
StringRef typeName, const T *decl, const SMRange &hoverRange) {
- lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
+ llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange));
{
llvm::raw_string_ostream hoverOS(hover.contents.value);
hoverOS << "**" << typeName << "**: `" << decl->getName().getName()
@@ -617,7 +632,7 @@ lsp::Hover PDLDocument::buildHoverForUserConstraintOrRewrite(
//===----------------------------------------------------------------------===//
void PDLDocument::findDocumentSymbols(
- std::vector<lsp::DocumentSymbol> &symbols) {
+ std::vector<llvm::lsp::DocumentSymbol> &symbols) {
if (failed(astModule))
return;
@@ -631,25 +646,28 @@ void PDLDocument::findDocumentSymbols(
SMRange nameLoc = name ? name->getLoc() : patternDecl->getLoc();
SMRange bodyLoc(nameLoc.Start, patternDecl->getBody()->getLoc().End);
- symbols.emplace_back(
- name ? name->getName() : "<pattern>", lsp::SymbolKind::Class,
- lsp::Range(sourceMgr, bodyLoc), lsp::Range(sourceMgr, nameLoc));
+ symbols.emplace_back(name ? name->getName() : "<pattern>",
+ llvm::lsp::SymbolKind::Class,
+ llvm::lsp::Range(sourceMgr, bodyLoc),
+ llvm::lsp::Range(sourceMgr, nameLoc));
} else if (const auto *cDecl = dyn_cast<ast::UserConstraintDecl>(decl)) {
// TODO: Add source information for the code block body.
SMRange nameLoc = cDecl->getName().getLoc();
SMRange bodyLoc = nameLoc;
- symbols.emplace_back(
- cDecl->getName().getName(), lsp::SymbolKind::Function,
- lsp::Range(sourceMgr, bodyLoc), lsp::Range(sourceMgr, nameLoc));
+ symbols.emplace_back(cDecl->getName().getName(),
+ llvm::lsp::SymbolKind::Function,
+ llvm::lsp::Range(sourceMgr, bodyLoc),
+ llvm::lsp::Range(sourceMgr, nameLoc));
} else if (const auto *cDecl = dyn_cast<ast::UserRewriteDecl>(decl)) {
// TODO: Add source information for the code block body.
SMRange nameLoc = cDecl->getName().getLoc();
SMRange bodyLoc = nameLoc;
- symbols.emplace_back(
- cDecl->getName().getName(), lsp::SymbolKind::Function,
- lsp::Range(sourceMgr, bodyLoc), lsp::Range(sourceMgr, nameLoc));
+ symbols.emplace_back(cDecl->getName().getName(),
+ llvm::lsp::SymbolKind::Function,
+ llvm::lsp::Range(sourceMgr, bodyLoc),
+ llvm::lsp::Range(sourceMgr, nameLoc));
}
}
}
@@ -662,7 +680,7 @@ namespace {
class LSPCodeCompleteContext : public CodeCompleteContext {
public:
LSPCodeCompleteContext(SMLoc completeLoc, llvm::SourceMgr &sourceMgr,
- lsp::CompletionList &completionList,
+ llvm::lsp::CompletionList &completionList,
ods::Context &odsContext,
ArrayRef<std::string> includeDirs)
: CodeCompleteContext(completeLoc), sourceMgr(sourceMgr),
@@ -674,13 +692,13 @@ public:
ArrayRef<StringRef> elementNames = tupleType.getElementNames();
for (unsigned i = 0, e = tupleType.size(); i < e; ++i) {
// Push back a completion item that uses the result index.
- lsp::CompletionItem item;
+ llvm::lsp::CompletionItem item;
item.label = llvm::formatv("{0} (field #{0})", i).str();
item.insertText = Twine(i).str();
item.filterText = item.sortText = item.insertText;
- item.kind = lsp::CompletionItemKind::Field;
+ item.kind = llvm::lsp::CompletionItemKind::Field;
item.detail = llvm::formatv("{0}: {1}", i, elementTypes[i]);
- item.insertTextFormat = lsp::InsertTextFormat::PlainText;
+ item.insertTextFormat = llvm::lsp::InsertTextFormat::PlainText;
completionList.items.emplace_back(item);
// If the element has a name, push back a completion item with that name.
@@ -705,11 +723,11 @@ public:
const ods::TypeConstraint &constraint = result.getConstraint();
// Push back a completion item that uses the result index.
- lsp::CompletionItem item;
+ llvm::lsp::CompletionItem item;
item.label = llvm::formatv("{0} (field #{0})", it.index()).str();
item.insertText = Twine(it.index()).str();
item.filterText = item.sortText = item.insertText;
- item.kind = lsp::CompletionItemKind::Field;
+ item.kind = llvm::lsp::CompletionItemKind::Field;
switch (result.getVariableLengthKind()) {
case ods::VariableLengthKind::Single:
item.detail = llvm::formatv("{0}: Value", it.index()).str();
@@ -721,12 +739,12 @@ public:
item.detail = llvm::formatv("{0}: ValueRange", it.index()).str();
break;
}
- item.documentation = lsp::MarkupContent{
- lsp::MarkupKind::Markdown,
+ item.documentation = llvm::lsp::MarkupContent{
+ llvm::lsp::MarkupKind::Markdown,
llvm::formatv("{0}\n\n```c++\n{1}\n```\n", constraint.getSummary(),
constraint.getCppClass())
.str()};
- item.insertTextFormat = lsp::InsertTextFormat::PlainText;
+ item.insertTextFormat = llvm::lsp::InsertTextFormat::PlainText;
completionList.items.emplace_back(item);
// If the result has a name, push back a completion item with the result
@@ -750,16 +768,16 @@ public:
for (const ods::Attribute &attr : odsOp->getAttributes()) {
const ods::AttributeConstraint &constraint = attr.getConstraint();
- lsp::CompletionItem item;
+ llvm::lsp::CompletionItem item;
item.label = attr.getName().str();
- item.kind = lsp::CompletionItemKind::Field;
+ item.kind = llvm::lsp::CompletionItemKind::Field;
item.detail = attr.isOptional() ? "optional" : "";
- item.documentation = lsp::MarkupContent{
- lsp::MarkupKind::Markdown,
+ item.documentation = llvm::lsp::MarkupContent{
+ llvm::lsp::MarkupKind::Markdown,
llvm::formatv("{0}\n\n```c++\n{1}\n```\n", constraint.getSummary(),
constraint.getCppClass())
.str()};
- item.insertTextFormat = lsp::InsertTextFormat::PlainText;
+ item.insertTextFormat = llvm::lsp::InsertTextFormat::PlainText;
completionList.items.emplace_back(item);
}
}
@@ -769,18 +787,18 @@ public:
const ast::DeclScope *scope) final {
auto addCoreConstraint = [&](StringRef constraint, StringRef mlirType,
StringRef snippetText = "") {
- lsp::CompletionItem item;
+ llvm::lsp::CompletionItem item;
item.label = constraint.str();
- item.kind = lsp::CompletionItemKind::Class;
+ item.kind = llvm::lsp::CompletionItemKind::Class;
item.detail = (constraint + " constraint").str();
- item.documentation = lsp::MarkupContent{
- lsp::MarkupKind::Markdown,
+ item.documentation = llvm::lsp::MarkupContent{
+ llvm::lsp::MarkupKind::Markdown,
("A single entity core constraint of type `" + mlirType + "`").str()};
item.sortText = "0";
item.insertText = snippetText.str();
item.insertTextFormat = snippetText.empty()
- ? lsp::InsertTextFormat::PlainText
- : lsp::InsertTextFormat::Snippet;
+ ? llvm::lsp::InsertTextFormat::PlainText
+ : llvm::lsp::InsertTextFormat::Snippet;
completionList.items.emplace_back(item);
};
@@ -812,9 +830,9 @@ public:
while (scope) {
for (const ast::Decl *decl : scope->getDecls()) {
if (const auto *cst = dyn_cast<ast::UserConstraintDecl>(decl)) {
- lsp::CompletionItem item;
+ llvm::lsp::CompletionItem item;
item.label = cst->getName().getName().str();
- item.kind = lsp::CompletionItemKind::Interface;
+ item.kind = llvm::lsp::CompletionItemKind::Interface;
item.sortText = "2_" + item.label;
// Skip constraints that are not single-arg. We currently only
@@ -841,8 +859,8 @@ public:
// Format the documentation for the constraint.
if (std::optional<std::string> doc =
getDocumentationFor(sourceMgr, cst)) {
- item.documentation =
- lsp::MarkupContent{lsp::MarkupKind::Markdown, std::move(*doc)};
+ item.documentation = llvm::lsp::MarkupContent{
+ llvm::lsp::MarkupKind::Markdown, std::move(*doc)};
}
completionList.items.emplace_back(item);
@@ -856,10 +874,10 @@ public:
void codeCompleteDialectName() final {
// Code complete known dialects.
for (const ods::Dialect &dialect : odsContext.getDialects()) {
- lsp::CompletionItem item;
+ llvm::lsp::CompletionItem item;
item.label = dialect.getName().str();
- item.kind = lsp::CompletionItemKind::Class;
- item.insertTextFormat = lsp::InsertTextFormat::PlainText;
+ item.kind = llvm::lsp::CompletionItemKind::Class;
+ item.insertTextFormat = llvm::lsp::InsertTextFormat::PlainText;
completionList.items.emplace_back(item);
}
}
@@ -872,10 +890,10 @@ public:
for (const auto &it : dialect->getOperations()) {
const ods::Operation &op = *it.second;
- lsp::CompletionItem item;
+ llvm::lsp::CompletionItem item;
item.label = op.getName().drop_front(dialectName.size() + 1).str();
- item.kind = lsp::CompletionItemKind::Field;
- item.insertTextFormat = lsp::InsertTextFormat::PlainText;
+ item.kind = llvm::lsp::CompletionItemKind::Field;
+ item.insertTextFormat = llvm::lsp::InsertTextFormat::PlainText;
completionList.items.emplace_back(item);
}
}
@@ -883,16 +901,16 @@ public:
void codeCompletePatternMetadata() final {
auto addSimpleConstraint = [&](StringRef constraint, StringRef desc,
StringRef snippetText = "") {
- lsp::CompletionItem item;
+ llvm::lsp::CompletionItem item;
item.label = constraint.str();
- item.kind = lsp::CompletionItemKind::Class;
+ item.kind = llvm::lsp::CompletionItemKind::Class;
item.detail = "pattern metadata";
item.documentation =
- lsp::MarkupContent{lsp::MarkupKind::Markdown, desc.str()};
+ llvm::lsp::MarkupContent{llvm::lsp::MarkupKind::Markdown, desc.str()};
item.insertText = snippetText.str();
item.insertTextFormat = snippetText.empty()
- ? lsp::InsertTextFormat::PlainText
- : lsp::InsertTextFormat::Snippet;
+ ? llvm::lsp::InsertTextFormat::PlainText
+ : llvm::lsp::InsertTextFormat::Snippet;
completionList.items.emplace_back(item);
};
@@ -913,10 +931,10 @@ public:
// Functor used to add a single include completion item.
auto addIncludeCompletion = [&](StringRef path, bool isDirectory) {
- lsp::CompletionItem item;
+ llvm::lsp::CompletionItem item;
item.label = path.str();
- item.kind = isDirectory ? lsp::CompletionItemKind::Folder
- : lsp::CompletionItemKind::File;
+ item.kind = isDirectory ? llvm::lsp::CompletionItemKind::Folder
+ : llvm::lsp::CompletionItemKind::File;
if (seenResults.insert(item.label).second)
completionList.items.emplace_back(item);
};
@@ -961,31 +979,31 @@ public:
// Sort the completion results to make sure the output is deterministic in
// the face of different iteration schemes for different platforms.
- llvm::sort(completionList.items, [](const lsp::CompletionItem &lhs,
- const lsp::CompletionItem &rhs) {
+ llvm::sort(completionList.items, [](const llvm::lsp::CompletionItem &lhs,
+ const llvm::lsp::CompletionItem &rhs) {
return lhs.label < rhs.label;
});
}
private:
llvm::SourceMgr &sourceMgr;
- lsp::CompletionList &completionList;
+ llvm::lsp::CompletionList &completionList;
ods::Context &odsContext;
ArrayRef<std::string> includeDirs;
};
} // namespace
-lsp::CompletionList
-PDLDocument::getCodeCompletion(const lsp::URIForFile &uri,
- const lsp::Position &completePos) {
+llvm::lsp::CompletionList
+PDLDocument::getCodeCompletion(const llvm::lsp::URIForFile &uri,
+ const llvm::lsp::Position &completePos) {
SMLoc posLoc = completePos.getAsSMLoc(sourceMgr);
if (!posLoc.isValid())
- return lsp::CompletionList();
+ return llvm::lsp::CompletionList();
// To perform code completion, we run another parse of the module with the
// code completion context provided.
ods::Context tmpODSContext;
- lsp::CompletionList completionList;
+ llvm::lsp::CompletionList completionList;
LSPCodeCompleteContext lspCompleteContext(posLoc, sourceMgr, completionList,
tmpODSContext,
sourceMgr.getIncludeDirs());
@@ -1005,7 +1023,7 @@ namespace {
class LSPSignatureHelpContext : public CodeCompleteContext {
public:
LSPSignatureHelpContext(SMLoc completeLoc, llvm::SourceMgr &sourceMgr,
- lsp::SignatureHelp &signatureHelp,
+ llvm::lsp::SignatureHelp &signatureHelp,
ods::Context &odsContext)
: CodeCompleteContext(completeLoc), sourceMgr(sourceMgr),
signatureHelp(signatureHelp), odsContext(odsContext) {}
@@ -1014,7 +1032,7 @@ public:
unsigned currentNumArgs) final {
signatureHelp.activeParameter = currentNumArgs;
- lsp::SignatureInformation signatureInfo;
+ llvm::lsp::SignatureInformation signatureInfo;
{
llvm::raw_string_ostream strOS(signatureInfo.label);
strOS << callable->getName()->getName() << "(";
@@ -1022,7 +1040,7 @@ public:
unsigned paramStart = strOS.str().size();
strOS << var->getName().getName() << ": " << var->getType();
unsigned paramEnd = strOS.str().size();
- signatureInfo.parameters.emplace_back(lsp::ParameterInformation{
+ signatureInfo.parameters.emplace_back(llvm::lsp::ParameterInformation{
StringRef(strOS.str()).slice(paramStart, paramEnd).str(),
std::make_pair(paramStart, paramEnd), /*paramDoc*/ std::string()});
};
@@ -1070,7 +1088,7 @@ public:
// not more than what is defined in ODS, as this will result in an error
// anyways.
if (odsOp && currentValue < values.size()) {
- lsp::SignatureInformation signatureInfo;
+ llvm::lsp::SignatureInformation signatureInfo;
// Build the signature label.
{
@@ -1099,7 +1117,7 @@ public:
}
unsigned paramEnd = strOS.str().size();
- signatureInfo.parameters.emplace_back(lsp::ParameterInformation{
+ signatureInfo.parameters.emplace_back(llvm::lsp::ParameterInformation{
StringRef(strOS.str()).slice(paramStart, paramEnd).str(),
std::make_pair(paramStart, paramEnd), paramDoc});
};
@@ -1114,12 +1132,12 @@ public:
// If there aren't any arguments yet, we also add the generic signature.
if (currentValue == 0 && (!odsOp || !values.empty())) {
- lsp::SignatureInformation signatureInfo;
+ llvm::lsp::SignatureInformation signatureInfo;
signatureInfo.label =
llvm::formatv("(<{0}s>: {1}Range)", label, dataType).str();
signatureInfo.documentation =
("Generic operation " + label + " specification").str();
- signatureInfo.parameters.emplace_back(lsp::ParameterInformation{
+ signatureInfo.parameters.emplace_back(llvm::lsp::ParameterInformation{
StringRef(signatureInfo.label).drop_front().drop_back().str(),
std::pair<unsigned, unsigned>(1, signatureInfo.label.size() - 1),
("All of the " + label + "s of the operation.").str()});
@@ -1129,21 +1147,22 @@ public:
private:
llvm::SourceMgr &sourceMgr;
- lsp::SignatureHelp &signatureHelp;
+ llvm::lsp::SignatureHelp &signatureHelp;
ods::Context &odsContext;
};
} // namespace
-lsp::SignatureHelp PDLDocument::getSignatureHelp(const lsp::URIForFile &uri,
- const lsp::Position &helpPos) {
+llvm::lsp::SignatureHelp
+PDLDocument::getSignatureHelp(const llvm::lsp::URIForFile &uri,
+ const llvm::lsp::Position &helpPos) {
SMLoc posLoc = helpPos.getAsSMLoc(sourceMgr);
if (!posLoc.isValid())
- return lsp::SignatureHelp();
+ return llvm::lsp::SignatureHelp();
// To perform code completion, we run another parse of the module with the
// code completion context provided.
ods::Context tmpODSContext;
- lsp::SignatureHelp signatureHelp;
+ llvm::lsp::SignatureHelp signatureHelp;
LSPSignatureHelpContext completeContext(posLoc, sourceMgr, signatureHelp,
tmpODSContext);
@@ -1173,9 +1192,9 @@ static bool shouldAddHintFor(const ast::Expr *expr, StringRef name) {
return true;
}
-void PDLDocument::getInlayHints(const lsp::URIForFile &uri,
- const lsp::Range &range,
- std::vector<lsp::InlayHint> &inlayHints) {
+void PDLDocument::getInlayHints(const llvm::lsp::URIForFile &uri,
+ const llvm::lsp::Range &range,
+ std::vector<llvm::lsp::InlayHint> &inlayHints) {
if (failed(astModule))
return;
SMRange rangeLoc = range.getAsSMRange(sourceMgr);
@@ -1198,9 +1217,9 @@ void PDLDocument::getInlayHints(const lsp::URIForFile &uri,
});
}
-void PDLDocument::getInlayHintsFor(const ast::VariableDecl *decl,
- const lsp::URIForFile &uri,
- std::vector<lsp::InlayHint> &inlayHints) {
+void PDLDocument::getInlayHintsFor(
+ const ast::VariableDecl *decl, const llvm::lsp::URIForFile &uri,
+ std::vector<llvm::lsp::InlayHint> &inlayHints) {
// Check to see if the variable has a constraint list, if it does we don't
// provide initializer hints.
if (!decl->getConstraints().empty())
@@ -1215,8 +1234,8 @@ void PDLDocument::getInlayHintsFor(const ast::VariableDecl *decl,
return;
}
- lsp::InlayHint hint(lsp::InlayHintKind::Type,
- lsp::Position(sourceMgr, decl->getLoc().End));
+ llvm::lsp::InlayHint hint(llvm::lsp::InlayHintKind::Type,
+ llvm::lsp::Position(sourceMgr, decl->getLoc().End));
{
llvm::raw_string_ostream labelOS(hint.label);
labelOS << ": " << decl->getType();
@@ -1225,9 +1244,9 @@ void PDLDocument::getInlayHintsFor(const ast::VariableDecl *decl,
inlayHints.emplace_back(std::move(hint));
}
-void PDLDocument::getInlayHintsFor(const ast::CallExpr *expr,
- const lsp::URIForFile &uri,
- std::vector<lsp::InlayHint> &inlayHints) {
+void PDLDocument::getInlayHintsFor(
+ const ast::CallExpr *expr, const llvm::lsp::URIForFile &uri,
+ std::vector<llvm::lsp::InlayHint> &inlayHints) {
// Try to extract the callable of this call.
const auto *callableRef = dyn_cast<ast::DeclRefExpr>(expr->getCallableExpr());
const auto *callable =
@@ -1242,9 +1261,9 @@ void PDLDocument::getInlayHintsFor(const ast::CallExpr *expr,
std::get<1>(it)->getName().getName());
}
-void PDLDocument::getInlayHintsFor(const ast::OperationExpr *expr,
- const lsp::URIForFile &uri,
- std::vector<lsp::InlayHint> &inlayHints) {
+void PDLDocument::getInlayHintsFor(
+ const ast::OperationExpr *expr, const llvm::lsp::URIForFile &uri,
+ std::vector<llvm::lsp::InlayHint> &inlayHints) {
// Check for ODS information.
ast::OperationType opType = dyn_cast<ast::OperationType>(expr->getType());
const auto *odsOp = opType ? opType.getODSOperation() : nullptr;
@@ -1290,13 +1309,15 @@ void PDLDocument::getInlayHintsFor(const ast::OperationExpr *expr,
"results");
}
-void PDLDocument::addParameterHintFor(std::vector<lsp::InlayHint> &inlayHints,
- const ast::Expr *expr, StringRef label) {
+void PDLDocument::addParameterHintFor(
+ std::vector<llvm::lsp::InlayHint> &inlayHints, const ast::Expr *expr,
+ StringRef label) {
if (!shouldAddHintFor(expr, label))
return;
- lsp::InlayHint hint(lsp::InlayHintKind::Parameter,
- lsp::Position(sourceMgr, expr->getLoc().Start));
+ llvm::lsp::InlayHint hint(
+ llvm::lsp::InlayHintKind::Parameter,
+ llvm::lsp::Position(sourceMgr, expr->getLoc().Start));
hint.label = (label + ":").str();
hint.paddingRight = true;
inlayHints.emplace_back(std::move(hint));
@@ -1342,22 +1363,24 @@ void PDLDocument::getPDLLViewOutput(raw_ostream &os,
namespace {
/// This class represents a single chunk of an PDL text file.
struct PDLTextFileChunk {
- PDLTextFileChunk(uint64_t lineOffset, const lsp::URIForFile &uri,
+ PDLTextFileChunk(uint64_t lineOffset, const llvm::lsp::URIForFile &uri,
StringRef contents,
const std::vector<std::string> &extraDirs,
- std::vector<lsp::Diagnostic> &diagnostics)
+ std::vector<llvm::lsp::Diagnostic> &diagnostics)
: lineOffset(lineOffset),
document(uri, contents, extraDirs, diagnostics) {}
/// Adjust the line number of the given range to anchor at the beginning of
/// the file, instead of the beginning of this chunk.
- void adjustLocForChunkOffset(lsp::Range &range) {
+ void adjustLocForChunkOffset(llvm::lsp::Range &range) {
adjustLocForChunkOffset(range.start);
adjustLocForChunkOffset(range.end);
}
/// Adjust the line number of the given position to anchor at the beginning of
/// the file, instead of the beginning of this chunk.
- void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; }
+ void adjustLocForChunkOffset(llvm::lsp::Position &pos) {
+ pos.line += lineOffset;
+ }
/// The line offset of this chunk from the beginning of the file.
uint64_t lineOffset;
@@ -1374,38 +1397,41 @@ namespace {
/// This class represents a text file containing one or more PDL documents.
class PDLTextFile {
public:
- PDLTextFile(const lsp::URIForFile &uri, StringRef fileContents,
+ PDLTextFile(const llvm::lsp::URIForFile &uri, StringRef fileContents,
int64_t version, const std::vector<std::string> &extraDirs,
- std::vector<lsp::Diagnostic> &diagnostics);
+ std::vector<llvm::lsp::Diagnostic> &diagnostics);
/// Return the current version of this text file.
int64_t getVersion() const { return version; }
/// Update the file to the new version using the provided set of content
/// changes. Returns failure if the update was unsuccessful.
- LogicalResult update(const lsp::URIForFile &uri, int64_t newVersion,
- ArrayRef<lsp::TextDocumentContentChangeEvent> changes,
- std::vector<lsp::Diagnostic> &diagnostics);
+ LogicalResult
+ update(const llvm::lsp::URIForFile &uri, int64_t newVersion,
+ ArrayRef<llvm::lsp::TextDocumentContentChangeEvent> changes,
+ std::vector<llvm::lsp::Diagnostic> &diagnostics);
//===--------------------------------------------------------------------===//
// LSP Queries
//===--------------------------------------------------------------------===//
- void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos,
- std::vector<lsp::Location> &locations);
- void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos,
- std::vector<lsp::Location> &references);
- void getDocumentLinks(const lsp::URIForFile &uri,
- std::vector<lsp::DocumentLink> &links);
- std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
- lsp::Position hoverPos);
- void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
- lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
- lsp::Position completePos);
- lsp::SignatureHelp getSignatureHelp(const lsp::URIForFile &uri,
- lsp::Position helpPos);
- void getInlayHints(const lsp::URIForFile &uri, lsp::Range range,
- std::vector<lsp::InlayHint> &inlayHints);
+ void getLocationsOf(const llvm::lsp::URIForFile &uri,
+ llvm::lsp::Position defPos,
+ std::vector<llvm::lsp::Location> &locations);
+ void findReferencesOf(const llvm::lsp::URIForFile &uri,
+ llvm::lsp::Position pos,
+ std::vector<llvm::lsp::Location> &references);
+ void getDocumentLinks(const llvm::lsp::URIForFile &uri,
+ std::vector<llvm::lsp::DocumentLink> &links);
+ std::optional<llvm::lsp::Hover> findHover(const llvm::lsp::URIForFile &uri,
+ llvm::lsp::Position hoverPos);
+ void findDocumentSymbols(std::vector<llvm::lsp::DocumentSymbol> &symbols);
+ llvm::lsp::CompletionList getCodeCompletion(const llvm::lsp::URIForFile &uri,
+ llvm::lsp::Position completePos);
+ llvm::lsp::SignatureHelp getSignatureHelp(const llvm::lsp::URIForFile &uri,
+ llvm::lsp::Position helpPos);
+ void getInlayHints(const llvm::lsp::URIForFile &uri, llvm::lsp::Range range,
+ std::vector<llvm::lsp::InlayHint> &inlayHints);
lsp::PDLLViewOutputResult getPDLLViewOutput(lsp::PDLLViewOutputKind kind);
private:
@@ -1413,14 +1439,14 @@ private:
std::vector<std::unique_ptr<PDLTextFileChunk>>::iterator>;
/// Initialize the text file from the given file contents.
- void initialize(const lsp::URIForFile &uri, int64_t newVersion,
- std::vector<lsp::Diagnostic> &diagnostics);
+ void initialize(const llvm::lsp::URIForFile &uri, int64_t newVersion,
+ std::vector<llvm::lsp::Diagnostic> &diagnostics);
/// Find the PDL document that contains the given position, and update the
/// position to be anchored at the start of the found chunk instead of the
/// beginning of the file.
- ChunkIterator getChunkItFor(lsp::Position &pos);
- PDLTextFileChunk &getChunkFor(lsp::Position &pos) {
+ ChunkIterator getChunkItFor(llvm::lsp::Position &pos);
+ PDLTextFileChunk &getChunkFor(llvm::lsp::Position &pos) {
return *getChunkItFor(pos);
}
@@ -1442,20 +1468,21 @@ private:
};
} // namespace
-PDLTextFile::PDLTextFile(const lsp::URIForFile &uri, StringRef fileContents,
- int64_t version,
+PDLTextFile::PDLTextFile(const llvm::lsp::URIForFile &uri,
+ StringRef fileContents, int64_t version,
const std::vector<std::string> &extraDirs,
- std::vector<lsp::Diagnostic> &diagnostics)
+ std::vector<llvm::lsp::Diagnostic> &diagnostics)
: contents(fileContents.str()), extraIncludeDirs(extraDirs) {
initialize(uri, version, diagnostics);
}
LogicalResult
-PDLTextFile::update(const lsp::URIForFile &uri, int64_t newVersion,
- ArrayRef<lsp::TextDocumentContentChangeEvent> changes,
- std::vector<lsp::Diagnostic> &diagnostics) {
- if (failed(lsp::TextDocumentContentChangeEvent::applyTo(changes, contents))) {
- lsp::Logger::error("Failed to update contents of {0}", uri.file());
+PDLTextFile::update(const llvm::lsp::URIForFile &uri, int64_t newVersion,
+ ArrayRef<llvm::lsp::TextDocumentContentChangeEvent> changes,
+ std::vector<llvm::lsp::Diagnostic> &diagnostics) {
+ if (failed(llvm::lsp::TextDocumentContentChangeEvent::applyTo(changes,
+ contents))) {
+ llvm::lsp::Logger::error("Failed to update contents of {0}", uri.file());
return failure();
}
@@ -1464,36 +1491,37 @@ PDLTextFile::update(const lsp::URIForFile &uri, int64_t newVersion,
return success();
}
-void PDLTextFile::getLocationsOf(const lsp::URIForFile &uri,
- lsp::Position defPos,
- std::vector<lsp::Location> &locations) {
+void PDLTextFile::getLocationsOf(const llvm::lsp::URIForFile &uri,
+ llvm::lsp::Position defPos,
+ std::vector<llvm::lsp::Location> &locations) {
PDLTextFileChunk &chunk = getChunkFor(defPos);
chunk.document.getLocationsOf(uri, defPos, locations);
// Adjust any locations within this file for the offset of this chunk.
if (chunk.lineOffset == 0)
return;
- for (lsp::Location &loc : locations)
+ for (llvm::lsp::Location &loc : locations)
if (loc.uri == uri)
chunk.adjustLocForChunkOffset(loc.range);
}
-void PDLTextFile::findReferencesOf(const lsp::URIForFile &uri,
- lsp::Position pos,
- std::vector<lsp::Location> &references) {
+void PDLTextFile::findReferencesOf(
+ const llvm::lsp::URIForFile &uri, llvm::lsp::Position pos,
+ std::vector<llvm::lsp::Location> &references) {
PDLTextFileChunk &chunk = getChunkFor(pos);
chunk.document.findReferencesOf(uri, pos, references);
// Adjust any locations within this file for the offset of this chunk.
if (chunk.lineOffset == 0)
return;
- for (lsp::Location &loc : references)
+ for (llvm::lsp::Location &loc : references)
if (loc.uri == uri)
chunk.adjustLocForChunkOffset(loc.range);
}
-void PDLTextFile::getDocumentLinks(const lsp::URIForFile &uri,
- std::vector<lsp::DocumentLink> &links) {
+void PDLTextFile::getDocumentLinks(
+ const llvm::lsp::URIForFile &uri,
+ std::vector<llvm::lsp::DocumentLink> &links) {
chunks.front()->document.getDocumentLinks(uri, links);
for (const auto &it : llvm::drop_begin(chunks)) {
size_t currentNumLinks = links.size();
@@ -1506,10 +1534,12 @@ void PDLTextFile::getDocumentLinks(const lsp::URIForFile &uri,
}
}
-std::optional<lsp::Hover> PDLTextFile::findHover(const lsp::URIForFile &uri,
- lsp::Position hoverPos) {
+std::optional<llvm::lsp::Hover>
+PDLTextFile::findHover(const llvm::lsp::URIForFile &uri,
+ llvm::lsp::Position hoverPos) {
PDLTextFileChunk &chunk = getChunkFor(hoverPos);
- std::optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos);
+ std::optional<llvm::lsp::Hover> hoverInfo =
+ chunk.document.findHover(uri, hoverPos);
// Adjust any locations within this file for the offset of this chunk.
if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range)
@@ -1518,7 +1548,7 @@ std::optional<lsp::Hover> PDLTextFile::findHover(const lsp::URIForFile &uri,
}
void PDLTextFile::findDocumentSymbols(
- std::vector<lsp::DocumentSymbol> &symbols) {
+ std::vector<llvm::lsp::DocumentSymbol> &symbols) {
if (chunks.size() == 1)
return chunks.front()->document.findDocumentSymbols(symbols);
@@ -1526,27 +1556,27 @@ void PDLTextFile::findDocumentSymbols(
// each chunk.
for (unsigned i = 0, e = chunks.size(); i < e; ++i) {
PDLTextFileChunk &chunk = *chunks[i];
- lsp::Position startPos(chunk.lineOffset);
- lsp::Position endPos((i == e - 1) ? totalNumLines - 1
- : chunks[i + 1]->lineOffset);
- lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">",
- lsp::SymbolKind::Namespace,
- /*range=*/lsp::Range(startPos, endPos),
- /*selectionRange=*/lsp::Range(startPos));
+ llvm::lsp::Position startPos(chunk.lineOffset);
+ llvm::lsp::Position endPos((i == e - 1) ? totalNumLines - 1
+ : chunks[i + 1]->lineOffset);
+ llvm::lsp::DocumentSymbol symbol(
+ "<file-split-" + Twine(i) + ">", llvm::lsp::SymbolKind::Namespace,
+ /*range=*/llvm::lsp::Range(startPos, endPos),
+ /*selectionRange=*/llvm::lsp::Range(startPos));
chunk.document.findDocumentSymbols(symbol.children);
// Fixup the locations of document symbols within this chunk.
if (i != 0) {
- SmallVector<lsp::DocumentSymbol *> symbolsToFix;
- for (lsp::DocumentSymbol &childSymbol : symbol.children)
+ SmallVector<llvm::lsp::DocumentSymbol *> symbolsToFix;
+ for (llvm::lsp::DocumentSymbol &childSymbol : symbol.children)
symbolsToFix.push_back(&childSymbol);
while (!symbolsToFix.empty()) {
- lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val();
+ llvm::lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val();
chunk.adjustLocForChunkOffset(symbol->range);
chunk.adjustLocForChunkOffset(symbol->selectionRange);
- for (lsp::DocumentSymbol &childSymbol : symbol->children)
+ for (llvm::lsp::DocumentSymbol &childSymbol : symbol->children)
symbolsToFix.push_back(&childSymbol);
}
}
@@ -1556,34 +1586,37 @@ void PDLTextFile::findDocumentSymbols(
}
}
-lsp::CompletionList PDLTextFile::getCodeCompletion(const lsp::URIForFile &uri,
- lsp::Position completePos) {
+llvm::lsp::CompletionList
+PDLTextFile::getCodeCompletion(const llvm::lsp::URIForFile &uri,
+ llvm::lsp::Position completePos) {
PDLTextFileChunk &chunk = getChunkFor(completePos);
- lsp::CompletionList completionList =
+ llvm::lsp::CompletionList completionList =
chunk.document.getCodeCompletion(uri, completePos);
// Adjust any completion locations.
- for (lsp::CompletionItem &item : completionList.items) {
+ for (llvm::lsp::CompletionItem &item : completionList.items) {
if (item.textEdit)
chunk.adjustLocForChunkOffset(item.textEdit->range);
- for (lsp::TextEdit &edit : item.additionalTextEdits)
+ for (llvm::lsp::TextEdit &edit : item.additionalTextEdits)
chunk.adjustLocForChunkOffset(edit.range);
}
return completionList;
}
-lsp::SignatureHelp PDLTextFile::getSignatureHelp(const lsp::URIForFile &uri,
- lsp::Position helpPos) {
+llvm::lsp::SignatureHelp
+PDLTextFile::getSignatureHelp(const llvm::lsp::URIForFile &uri,
+ llvm::lsp::Position helpPos) {
return getChunkFor(helpPos).document.getSignatureHelp(uri, helpPos);
}
-void PDLTextFile::getInlayHints(const lsp::URIForFile &uri, lsp::Range range,
- std::vector<lsp::InlayHint> &inlayHints) {
+void PDLTextFile::getInlayHints(const llvm::lsp::URIForFile &uri,
+ llvm::lsp::Range range,
+ std::vector<llvm::lsp::InlayHint> &inlayHints) {
auto startIt = getChunkItFor(range.start);
auto endIt = getChunkItFor(range.end);
// Functor used to get the chunks for a given file, and fixup any locations
- auto getHintsForChunk = [&](ChunkIterator chunkIt, lsp::Range range) {
+ auto getHintsForChunk = [&](ChunkIterator chunkIt, llvm::lsp::Range range) {
size_t currentNumHints = inlayHints.size();
chunkIt->document.getInlayHints(uri, range, inlayHints);
@@ -1605,15 +1638,16 @@ void PDLTextFile::getInlayHints(const lsp::URIForFile &uri, lsp::Range range,
// Otherwise, the range is split between multiple chunks. The first chunk
// has the correct range start, but covers the total document.
- getHintsForChunk(startIt, lsp::Range(range.start, getNumLines(startIt)));
+ getHintsForChunk(startIt,
+ llvm::lsp::Range(range.start, getNumLines(startIt)));
// Every chunk in between uses the full document.
for (++startIt; startIt != endIt; ++startIt)
- getHintsForChunk(startIt, lsp::Range(0, getNumLines(startIt)));
+ getHintsForChunk(startIt, llvm::lsp::Range(0, getNumLines(startIt)));
// The range for the last chunk starts at the beginning of the document, up
// through the end of the input range.
- getHintsForChunk(startIt, lsp::Range(0, range.end));
+ getHintsForChunk(startIt, llvm::lsp::Range(0, range.end));
}
lsp::PDLLViewOutputResult
@@ -1632,8 +1666,9 @@ PDLTextFile::getPDLLViewOutput(lsp::PDLLViewOutputKind kind) {
return result;
}
-void PDLTextFile::initialize(const lsp::URIForFile &uri, int64_t newVersion,
- std::vector<lsp::Diagnostic> &diagnostics) {
+void PDLTextFile::initialize(const llvm::lsp::URIForFile &uri,
+ int64_t newVersion,
+ std::vector<llvm::lsp::Diagnostic> &diagnostics) {
version = newVersion;
chunks.clear();
@@ -1653,7 +1688,7 @@ void PDLTextFile::initialize(const lsp::URIForFile &uri, int64_t newVersion,
// Adjust locations used in diagnostics to account for the offset from the
// beginning of the file.
- for (lsp::Diagnostic &diag :
+ for (llvm::lsp::Diagnostic &diag :
llvm::drop_begin(diagnostics, currentNumDiags)) {
chunk->adjustLocForChunkOffset(diag.range);
@@ -1668,14 +1703,15 @@ void PDLTextFile::initialize(const lsp::URIForFile &uri, int64_t newVersion,
totalNumLines = lineOffset;
}
-PDLTextFile::ChunkIterator PDLTextFile::getChunkItFor(lsp::Position &pos) {
+PDLTextFile::ChunkIterator
+PDLTextFile::getChunkItFor(llvm::lsp::Position &pos) {
if (chunks.size() == 1)
return chunks.begin();
// Search for the first chunk with a greater line offset, the previous chunk
// is the one that contains `pos`.
auto it = llvm::upper_bound(
- chunks, pos, [](const lsp::Position &pos, const auto &chunk) {
+ chunks, pos, [](const llvm::lsp::Position &pos, const auto &chunk) {
return static_cast<uint64_t>(pos.line) < chunk->lineOffset;
});
ChunkIterator chunkIt(it == chunks.end() ? (chunks.end() - 1) : --it);
@@ -1710,9 +1746,9 @@ lsp::PDLLServer::PDLLServer(const Options &options)
: impl(std::make_unique<Impl>(options)) {}
lsp::PDLLServer::~PDLLServer() = default;
-void lsp::PDLLServer::addDocument(const URIForFile &uri, StringRef contents,
- int64_t version,
- std::vector<Diagnostic> &diagnostics) {
+void lsp::PDLLServer::addDocument(
+ const URIForFile &uri, StringRef contents, int64_t version,
+ std::vector<llvm::lsp::Diagnostic> &diagnostics) {
// Build the set of additional include directories.
std::vector<std::string> additionalIncludeDirs = impl->options.extraDirs;
const auto &fileInfo = impl->compilationDatabase.getFileInfo(uri.file());
@@ -1724,7 +1760,7 @@ void lsp::PDLLServer::addDocument(const URIForFile &uri, StringRef contents,
void lsp::PDLLServer::updateDocument(
const URIForFile &uri, ArrayRef<TextDocumentContentChangeEvent> changes,
- int64_t version, std::vector<Diagnostic> &diagnostics) {
+ int64_t version, std::vector<llvm::lsp::Diagnostic> &diagnostics) {
// Check that we actually have a document for this uri.
auto it = impl->files.find(uri.file());
if (it == impl->files.end())
@@ -1746,17 +1782,17 @@ std::optional<int64_t> lsp::PDLLServer::removeDocument(const URIForFile &uri) {
return version;
}
-void lsp::PDLLServer::getLocationsOf(const URIForFile &uri,
- const Position &defPos,
- std::vector<Location> &locations) {
+void lsp::PDLLServer::getLocationsOf(
+ const URIForFile &uri, const Position &defPos,
+ std::vector<llvm::lsp::Location> &locations) {
auto fileIt = impl->files.find(uri.file());
if (fileIt != impl->files.end())
fileIt->second->getLocationsOf(uri, defPos, locations);
}
-void lsp::PDLLServer::findReferencesOf(const URIForFile &uri,
- const Position &pos,
- std::vector<Location> &references) {
+void lsp::PDLLServer::findReferencesOf(
+ const URIForFile &uri, const Position &pos,
+ std::vector<llvm::lsp::Location> &references) {
auto fileIt = impl->files.find(uri.file());
if (fileIt != impl->files.end())
fileIt->second->findReferencesOf(uri, pos, references);
@@ -1769,8 +1805,8 @@ void lsp::PDLLServer::getDocumentLinks(
return fileIt->second->getDocumentLinks(uri, documentLinks);
}
-std::optional<lsp::Hover> lsp::PDLLServer::findHover(const URIForFile &uri,
- const Position &hoverPos) {
+std::optional<llvm::lsp::Hover>
+lsp::PDLLServer::findHover(const URIForFile &uri, const Position &hoverPos) {
auto fileIt = impl->files.find(uri.file());
if (fileIt != impl->files.end())
return fileIt->second->findHover(uri, hoverPos);
@@ -1793,8 +1829,9 @@ lsp::PDLLServer::getCodeCompletion(const URIForFile &uri,
return CompletionList();
}
-lsp::SignatureHelp lsp::PDLLServer::getSignatureHelp(const URIForFile &uri,
- const Position &helpPos) {
+llvm::lsp::SignatureHelp
+lsp::PDLLServer::getSignatureHelp(const URIForFile &uri,
+ const Position &helpPos) {
auto fileIt = impl->files.find(uri.file());
if (fileIt != impl->files.end())
return fileIt->second->getSignatureHelp(uri, helpPos);
diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.h b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.h
index 134431f..d82014d 100644
--- a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.h
+++ b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.h
@@ -11,6 +11,7 @@
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/LSP/Protocol.h"
#include <memory>
#include <optional>
#include <string>
@@ -18,21 +19,22 @@
namespace mlir {
namespace lsp {
-struct Diagnostic;
+using llvm::lsp::CompletionList;
+using llvm::lsp::Diagnostic;
+using llvm::lsp::DocumentLink;
+using llvm::lsp::DocumentSymbol;
+using llvm::lsp::Hover;
+using llvm::lsp::InlayHint;
+using llvm::lsp::Location;
+using llvm::lsp::Position;
+using llvm::lsp::Range;
+using llvm::lsp::SignatureHelp;
+using llvm::lsp::TextDocumentContentChangeEvent;
+using llvm::lsp::URIForFile;
+
class CompilationDatabase;
struct PDLLViewOutputResult;
enum class PDLLViewOutputKind;
-struct CompletionList;
-struct DocumentLink;
-struct DocumentSymbol;
-struct Hover;
-struct InlayHint;
-struct Location;
-struct Position;
-struct Range;
-struct SignatureHelp;
-struct TextDocumentContentChangeEvent;
-class URIForFile;
/// This class implements all of the PDLL related functionality necessary for a
/// language server. This class allows for keeping the PDLL specific logic
diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.cpp
index 0c9896e..ace4605 100644
--- a/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.cpp
+++ b/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "Protocol.h"
+#include "mlir/Support/LLVM.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/JSON.h"
diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.h b/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.h
index 0706316..a2775f8 100644
--- a/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.h
+++ b/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.h
@@ -20,10 +20,12 @@
#ifndef LIB_MLIR_TOOLS_MLIRPDLLLSPSERVER_PROTOCOL_H_
#define LIB_MLIR_TOOLS_MLIRPDLLLSPSERVER_PROTOCOL_H_
-#include "mlir/Tools/lsp-server-support/Protocol.h"
+#include "llvm/Support/LSP/Protocol.h"
namespace mlir {
namespace lsp {
+using llvm::lsp::URIForFile;
+
//===----------------------------------------------------------------------===//
// PDLLViewOutputParams
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Tools/tblgen-lsp-server/CMakeLists.txt b/mlir/lib/Tools/tblgen-lsp-server/CMakeLists.txt
index 80fc1ff..b21650e 100644
--- a/mlir/lib/Tools/tblgen-lsp-server/CMakeLists.txt
+++ b/mlir/lib/Tools/tblgen-lsp-server/CMakeLists.txt
@@ -2,6 +2,7 @@ set(LLVM_LINK_COMPONENTS
Demangle
Support
TableGen
+ SupportLSP
)
llvm_add_library(TableGenLspServerLib
diff --git a/mlir/lib/Tools/tblgen-lsp-server/LSPServer.cpp b/mlir/lib/Tools/tblgen-lsp-server/LSPServer.cpp
index bb3c0a7..95a457f 100644
--- a/mlir/lib/Tools/tblgen-lsp-server/LSPServer.cpp
+++ b/mlir/lib/Tools/tblgen-lsp-server/LSPServer.cpp
@@ -9,14 +9,33 @@
#include "LSPServer.h"
#include "TableGenServer.h"
-#include "mlir/Tools/lsp-server-support/Logging.h"
-#include "mlir/Tools/lsp-server-support/Protocol.h"
-#include "mlir/Tools/lsp-server-support/Transport.h"
+#include "llvm/Support/LSP/Logging.h"
+#include "llvm/Support/LSP/Protocol.h"
+#include "llvm/Support/LSP/Transport.h"
#include <optional>
using namespace mlir;
using namespace mlir::lsp;
+using llvm::lsp::Callback;
+using llvm::lsp::DidChangeTextDocumentParams;
+using llvm::lsp::DidCloseTextDocumentParams;
+using llvm::lsp::DidOpenTextDocumentParams;
+using llvm::lsp::DocumentLinkParams;
+using llvm::lsp::Hover;
+using llvm::lsp::InitializedParams;
+using llvm::lsp::InitializeParams;
+using llvm::lsp::JSONTransport;
+using llvm::lsp::Location;
+using llvm::lsp::Logger;
+using llvm::lsp::MessageHandler;
+using llvm::lsp::NoParams;
+using llvm::lsp::OutgoingNotification;
+using llvm::lsp::PublishDiagnosticsParams;
+using llvm::lsp::ReferenceParams;
+using llvm::lsp::TextDocumentPositionParams;
+using llvm::lsp::TextDocumentSyncKind;
+
//===----------------------------------------------------------------------===//
// LSPServer
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Tools/tblgen-lsp-server/LSPServer.h b/mlir/lib/Tools/tblgen-lsp-server/LSPServer.h
index 501a9da..596688b 100644
--- a/mlir/lib/Tools/tblgen-lsp-server/LSPServer.h
+++ b/mlir/lib/Tools/tblgen-lsp-server/LSPServer.h
@@ -13,17 +13,19 @@
namespace llvm {
struct LogicalResult;
+namespace lsp {
+class JSONTransport;
+} // namespace lsp
} // namespace llvm
namespace mlir {
namespace lsp {
-class JSONTransport;
class TableGenServer;
/// Run the main loop of the LSP server using the given TableGen server and
/// transport.
llvm::LogicalResult runTableGenLSPServer(TableGenServer &server,
- JSONTransport &transport);
+ llvm::lsp::JSONTransport &transport);
} // namespace lsp
} // namespace mlir
diff --git a/mlir/lib/Tools/tblgen-lsp-server/TableGenLspServerMain.cpp b/mlir/lib/Tools/tblgen-lsp-server/TableGenLspServerMain.cpp
index 21af78c..8014b8d 100644
--- a/mlir/lib/Tools/tblgen-lsp-server/TableGenLspServerMain.cpp
+++ b/mlir/lib/Tools/tblgen-lsp-server/TableGenLspServerMain.cpp
@@ -9,14 +9,18 @@
#include "mlir/Tools/tblgen-lsp-server/TableGenLspServerMain.h"
#include "LSPServer.h"
#include "TableGenServer.h"
-#include "mlir/Tools/lsp-server-support/Logging.h"
-#include "mlir/Tools/lsp-server-support/Transport.h"
#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/LSP/Logging.h"
+#include "llvm/Support/LSP/Transport.h"
#include "llvm/Support/Program.h"
using namespace mlir;
using namespace mlir::lsp;
+using llvm::lsp::JSONStreamStyle;
+using llvm::lsp::JSONTransport;
+using llvm::lsp::Logger;
+
LogicalResult mlir::TableGenLspServerMain(int argc, char **argv) {
llvm::cl::opt<JSONStreamStyle> inputStyle{
"input-style",
diff --git a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp
index 5faeeae..3080b78 100644
--- a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp
+++ b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp
@@ -10,12 +10,12 @@
#include "mlir/Support/IndentedOstream.h"
#include "mlir/Tools/lsp-server-support/CompilationDatabase.h"
-#include "mlir/Tools/lsp-server-support/Logging.h"
-#include "mlir/Tools/lsp-server-support/Protocol.h"
#include "mlir/Tools/lsp-server-support/SourceMgrUtils.h"
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/StringMap.h"
+#include "llvm/Support/LSP/Logging.h"
+#include "llvm/Support/LSP/Protocol.h"
#include "llvm/Support/Path.h"
#include "llvm/TableGen/Parser.h"
#include "llvm/TableGen/Record.h"
@@ -36,45 +36,49 @@ static SMRange convertTokenLocToRange(SMLoc loc) {
/// Returns a language server uri for the given source location. `mainFileURI`
/// corresponds to the uri for the main file of the source manager.
-static lsp::URIForFile getURIFromLoc(const SourceMgr &mgr, SMLoc loc,
- const lsp::URIForFile &mainFileURI) {
+static llvm::lsp::URIForFile
+getURIFromLoc(const SourceMgr &mgr, SMLoc loc,
+ const llvm::lsp::URIForFile &mainFileURI) {
int bufferId = mgr.FindBufferContainingLoc(loc);
if (bufferId == 0 || bufferId == static_cast<int>(mgr.getMainFileID()))
return mainFileURI;
- llvm::Expected<lsp::URIForFile> fileForLoc = lsp::URIForFile::fromFile(
- mgr.getBufferInfo(bufferId).Buffer->getBufferIdentifier());
+ llvm::Expected<llvm::lsp::URIForFile> fileForLoc =
+ llvm::lsp::URIForFile::fromFile(
+ mgr.getBufferInfo(bufferId).Buffer->getBufferIdentifier());
if (fileForLoc)
return *fileForLoc;
- lsp::Logger::error("Failed to create URI for include file: {0}",
- llvm::toString(fileForLoc.takeError()));
+ llvm::lsp::Logger::error("Failed to create URI for include file: {0}",
+ llvm::toString(fileForLoc.takeError()));
return mainFileURI;
}
/// Returns a language server location from the given source range.
-static lsp::Location getLocationFromLoc(SourceMgr &mgr, SMRange loc,
- const lsp::URIForFile &uri) {
- return lsp::Location(getURIFromLoc(mgr, loc.Start, uri),
- lsp::Range(mgr, loc));
+static llvm::lsp::Location
+getLocationFromLoc(SourceMgr &mgr, SMRange loc,
+ const llvm::lsp::URIForFile &uri) {
+ return llvm::lsp::Location(getURIFromLoc(mgr, loc.Start, uri),
+ llvm::lsp::Range(mgr, loc));
}
-static lsp::Location getLocationFromLoc(SourceMgr &mgr, SMLoc loc,
- const lsp::URIForFile &uri) {
+static llvm::lsp::Location
+getLocationFromLoc(SourceMgr &mgr, SMLoc loc,
+ const llvm::lsp::URIForFile &uri) {
return getLocationFromLoc(mgr, convertTokenLocToRange(loc), uri);
}
/// Convert the given TableGen diagnostic to the LSP form.
-static std::optional<lsp::Diagnostic>
+static std::optional<llvm::lsp::Diagnostic>
getLspDiagnoticFromDiag(const llvm::SMDiagnostic &diag,
- const lsp::URIForFile &uri) {
+ const llvm::lsp::URIForFile &uri) {
auto *sourceMgr = const_cast<SourceMgr *>(diag.getSourceMgr());
if (!sourceMgr || !diag.getLoc().isValid())
return std::nullopt;
- lsp::Diagnostic lspDiag;
+ llvm::lsp::Diagnostic lspDiag;
lspDiag.source = "tablegen";
lspDiag.category = "Parse Error";
// Try to grab a file location for this diagnostic.
- lsp::Location loc = getLocationFromLoc(*sourceMgr, diag.getLoc(), uri);
+ llvm::lsp::Location loc = getLocationFromLoc(*sourceMgr, diag.getLoc(), uri);
lspDiag.range = loc.range;
// Skip diagnostics that weren't emitted within the main file.
@@ -84,17 +88,17 @@ getLspDiagnoticFromDiag(const llvm::SMDiagnostic &diag,
// Convert the severity for the diagnostic.
switch (diag.getKind()) {
case SourceMgr::DK_Warning:
- lspDiag.severity = lsp::DiagnosticSeverity::Warning;
+ lspDiag.severity = llvm::lsp::DiagnosticSeverity::Warning;
break;
case SourceMgr::DK_Error:
- lspDiag.severity = lsp::DiagnosticSeverity::Error;
+ lspDiag.severity = llvm::lsp::DiagnosticSeverity::Error;
break;
case SourceMgr::DK_Note:
// Notes are emitted separately from the main diagnostic, so we just treat
// them as remarks given that we can't determine the diagnostic to relate
// them to.
case SourceMgr::DK_Remark:
- lspDiag.severity = lsp::DiagnosticSeverity::Information;
+ lspDiag.severity = llvm::lsp::DiagnosticSeverity::Information;
break;
}
lspDiag.message = diag.getMessage().str();
@@ -322,54 +326,59 @@ namespace {
/// This class represents a text file containing one or more TableGen documents.
class TableGenTextFile {
public:
- TableGenTextFile(const lsp::URIForFile &uri, StringRef fileContents,
+ TableGenTextFile(const llvm::lsp::URIForFile &uri, StringRef fileContents,
int64_t version,
const std::vector<std::string> &extraIncludeDirs,
- std::vector<lsp::Diagnostic> &diagnostics);
+ std::vector<llvm::lsp::Diagnostic> &diagnostics);
/// Return the current version of this text file.
int64_t getVersion() const { return version; }
/// Update the file to the new version using the provided set of content
/// changes. Returns failure if the update was unsuccessful.
- LogicalResult update(const lsp::URIForFile &uri, int64_t newVersion,
- ArrayRef<lsp::TextDocumentContentChangeEvent> changes,
- std::vector<lsp::Diagnostic> &diagnostics);
+ LogicalResult
+ update(const llvm::lsp::URIForFile &uri, int64_t newVersion,
+ ArrayRef<llvm::lsp::TextDocumentContentChangeEvent> changes,
+ std::vector<llvm::lsp::Diagnostic> &diagnostics);
//===--------------------------------------------------------------------===//
// Definitions and References
//===--------------------------------------------------------------------===//
- void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos,
- std::vector<lsp::Location> &locations);
- void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos,
- std::vector<lsp::Location> &references);
+ void getLocationsOf(const llvm::lsp::URIForFile &uri,
+ const llvm::lsp::Position &defPos,
+ std::vector<llvm::lsp::Location> &locations);
+ void findReferencesOf(const llvm::lsp::URIForFile &uri,
+ const llvm::lsp::Position &pos,
+ std::vector<llvm::lsp::Location> &references);
//===--------------------------------------------------------------------===//
// Document Links
//===--------------------------------------------------------------------===//
- void getDocumentLinks(const lsp::URIForFile &uri,
- std::vector<lsp::DocumentLink> &links);
+ void getDocumentLinks(const llvm::lsp::URIForFile &uri,
+ std::vector<llvm::lsp::DocumentLink> &links);
//===--------------------------------------------------------------------===//
// Hover
//===--------------------------------------------------------------------===//
- std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
- const lsp::Position &hoverPos);
- lsp::Hover buildHoverForRecord(const Record *record,
- const SMRange &hoverRange);
- lsp::Hover buildHoverForTemplateArg(const Record *record,
+ std::optional<llvm::lsp::Hover>
+ findHover(const llvm::lsp::URIForFile &uri,
+ const llvm::lsp::Position &hoverPos);
+ llvm::lsp::Hover buildHoverForRecord(const Record *record,
+ const SMRange &hoverRange);
+ llvm::lsp::Hover buildHoverForTemplateArg(const Record *record,
+ const RecordVal *value,
+ const SMRange &hoverRange);
+ llvm::lsp::Hover buildHoverForField(const Record *record,
const RecordVal *value,
const SMRange &hoverRange);
- lsp::Hover buildHoverForField(const Record *record, const RecordVal *value,
- const SMRange &hoverRange);
private:
/// Initialize the text file from the given file contents.
- void initialize(const lsp::URIForFile &uri, int64_t newVersion,
- std::vector<lsp::Diagnostic> &diagnostics);
+ void initialize(const llvm::lsp::URIForFile &uri, int64_t newVersion,
+ std::vector<llvm::lsp::Diagnostic> &diagnostics);
/// The full string contents of the file.
std::string contents;
@@ -395,9 +404,9 @@ private:
} // namespace
TableGenTextFile::TableGenTextFile(
- const lsp::URIForFile &uri, StringRef fileContents, int64_t version,
+ const llvm::lsp::URIForFile &uri, StringRef fileContents, int64_t version,
const std::vector<std::string> &extraIncludeDirs,
- std::vector<lsp::Diagnostic> &diagnostics)
+ std::vector<llvm::lsp::Diagnostic> &diagnostics)
: contents(fileContents.str()), version(version) {
// Build the set of include directories for this file.
llvm::SmallString<32> uriDirectory(uri.file());
@@ -409,12 +418,13 @@ TableGenTextFile::TableGenTextFile(
initialize(uri, version, diagnostics);
}
-LogicalResult
-TableGenTextFile::update(const lsp::URIForFile &uri, int64_t newVersion,
- ArrayRef<lsp::TextDocumentContentChangeEvent> changes,
- std::vector<lsp::Diagnostic> &diagnostics) {
- if (failed(lsp::TextDocumentContentChangeEvent::applyTo(changes, contents))) {
- lsp::Logger::error("Failed to update contents of {0}", uri.file());
+LogicalResult TableGenTextFile::update(
+ const llvm::lsp::URIForFile &uri, int64_t newVersion,
+ ArrayRef<llvm::lsp::TextDocumentContentChangeEvent> changes,
+ std::vector<llvm::lsp::Diagnostic> &diagnostics) {
+ if (failed(llvm::lsp::TextDocumentContentChangeEvent::applyTo(changes,
+ contents))) {
+ llvm::lsp::Logger::error("Failed to update contents of {0}", uri.file());
return failure();
}
@@ -423,9 +433,9 @@ TableGenTextFile::update(const lsp::URIForFile &uri, int64_t newVersion,
return success();
}
-void TableGenTextFile::initialize(const lsp::URIForFile &uri,
- int64_t newVersion,
- std::vector<lsp::Diagnostic> &diagnostics) {
+void TableGenTextFile::initialize(
+ const llvm::lsp::URIForFile &uri, int64_t newVersion,
+ std::vector<llvm::lsp::Diagnostic> &diagnostics) {
version = newVersion;
sourceMgr = SourceMgr();
recordKeeper = std::make_unique<RecordKeeper>();
@@ -433,7 +443,8 @@ void TableGenTextFile::initialize(const lsp::URIForFile &uri,
// Build a buffer for this file.
auto memBuffer = llvm::MemoryBuffer::getMemBuffer(contents, uri.file());
if (!memBuffer) {
- lsp::Logger::error("Failed to create memory buffer for file", uri.file());
+ llvm::lsp::Logger::error("Failed to create memory buffer for file",
+ uri.file());
return;
}
sourceMgr.setIncludeDirs(includeDirs);
@@ -442,8 +453,8 @@ void TableGenTextFile::initialize(const lsp::URIForFile &uri,
// This class provides a context argument for the SourceMgr diagnostic
// handler.
struct DiagHandlerContext {
- std::vector<lsp::Diagnostic> &diagnostics;
- const lsp::URIForFile &uri;
+ std::vector<llvm::lsp::Diagnostic> &diagnostics;
+ const llvm::lsp::URIForFile &uri;
} handlerContext{diagnostics, uri};
// Set the diagnostic handler for the tablegen source manager.
@@ -469,9 +480,9 @@ void TableGenTextFile::initialize(const lsp::URIForFile &uri,
// TableGenTextFile: Definitions and References
//===----------------------------------------------------------------------===//
-void TableGenTextFile::getLocationsOf(const lsp::URIForFile &uri,
- const lsp::Position &defPos,
- std::vector<lsp::Location> &locations) {
+void TableGenTextFile::getLocationsOf(
+ const llvm::lsp::URIForFile &uri, const llvm::lsp::Position &defPos,
+ std::vector<llvm::lsp::Location> &locations) {
SMLoc posLoc = defPos.getAsSMLoc(sourceMgr);
const TableGenIndexSymbol *symbol = index.lookup(posLoc);
if (!symbol)
@@ -492,8 +503,8 @@ void TableGenTextFile::getLocationsOf(const lsp::URIForFile &uri,
}
void TableGenTextFile::findReferencesOf(
- const lsp::URIForFile &uri, const lsp::Position &pos,
- std::vector<lsp::Location> &references) {
+ const llvm::lsp::URIForFile &uri, const llvm::lsp::Position &pos,
+ std::vector<llvm::lsp::Location> &references) {
SMLoc posLoc = pos.getAsSMLoc(sourceMgr);
const TableGenIndexSymbol *symbol = index.lookup(posLoc);
if (!symbol)
@@ -508,8 +519,9 @@ void TableGenTextFile::findReferencesOf(
// TableGenTextFile: Document Links
//===--------------------------------------------------------------------===//
-void TableGenTextFile::getDocumentLinks(const lsp::URIForFile &uri,
- std::vector<lsp::DocumentLink> &links) {
+void TableGenTextFile::getDocumentLinks(
+ const llvm::lsp::URIForFile &uri,
+ std::vector<llvm::lsp::DocumentLink> &links) {
for (const lsp::SourceMgrInclude &include : parsedIncludes)
links.emplace_back(include.range, include.uri);
}
@@ -518,9 +530,9 @@ void TableGenTextFile::getDocumentLinks(const lsp::URIForFile &uri,
// TableGenTextFile: Hover
//===----------------------------------------------------------------------===//
-std::optional<lsp::Hover>
-TableGenTextFile::findHover(const lsp::URIForFile &uri,
- const lsp::Position &hoverPos) {
+std::optional<llvm::lsp::Hover>
+TableGenTextFile::findHover(const llvm::lsp::URIForFile &uri,
+ const llvm::lsp::Position &hoverPos) {
// Check for a reference to an include.
for (const lsp::SourceMgrInclude &include : parsedIncludes)
if (include.range.contains(hoverPos))
@@ -546,9 +558,10 @@ TableGenTextFile::findHover(const lsp::URIForFile &uri,
return buildHoverForField(recordVal->record, value, hoverRange);
}
-lsp::Hover TableGenTextFile::buildHoverForRecord(const Record *record,
- const SMRange &hoverRange) {
- lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
+llvm::lsp::Hover
+TableGenTextFile::buildHoverForRecord(const Record *record,
+ const SMRange &hoverRange) {
+ llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange));
{
llvm::raw_string_ostream hoverOS(hover.contents.value);
@@ -590,9 +603,9 @@ lsp::Hover TableGenTextFile::buildHoverForRecord(const Record *record,
return hover;
}
-lsp::Hover TableGenTextFile::buildHoverForTemplateArg(
+llvm::lsp::Hover TableGenTextFile::buildHoverForTemplateArg(
const Record *record, const RecordVal *value, const SMRange &hoverRange) {
- lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
+ llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange));
{
llvm::raw_string_ostream hoverOS(hover.contents.value);
StringRef name = value->getName().rsplit(':').second;
@@ -604,10 +617,9 @@ lsp::Hover TableGenTextFile::buildHoverForTemplateArg(
return hover;
}
-lsp::Hover TableGenTextFile::buildHoverForField(const Record *record,
- const RecordVal *value,
- const SMRange &hoverRange) {
- lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
+llvm::lsp::Hover TableGenTextFile::buildHoverForField(
+ const Record *record, const RecordVal *value, const SMRange &hoverRange) {
+ llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange));
{
llvm::raw_string_ostream hoverOS(hover.contents.value);
hoverOS << "**field** `" << value->getName() << "`\n***\nType: `";
@@ -722,7 +734,7 @@ void lsp::TableGenServer::getDocumentLinks(
return fileIt->second->getDocumentLinks(uri, documentLinks);
}
-std::optional<lsp::Hover>
+std::optional<llvm::lsp::Hover>
lsp::TableGenServer::findHover(const URIForFile &uri,
const Position &hoverPos) {
auto fileIt = impl->files.find(uri.file());
diff --git a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.h b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.h
index bdc8510..e54b8bc 100644
--- a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.h
+++ b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.h
@@ -11,6 +11,7 @@
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/LSP/Protocol.h"
#include <memory>
#include <optional>
#include <string>
@@ -18,13 +19,13 @@
namespace mlir {
namespace lsp {
-struct Diagnostic;
-struct DocumentLink;
-struct Hover;
-struct Location;
-struct Position;
-struct TextDocumentContentChangeEvent;
-class URIForFile;
+using llvm::lsp::Diagnostic;
+using llvm::lsp::DocumentLink;
+using llvm::lsp::Hover;
+using llvm::lsp::Location;
+using llvm::lsp::Position;
+using llvm::lsp::TextDocumentContentChangeEvent;
+using llvm::lsp::URIForFile;
/// This class implements all of the TableGen related functionality necessary
/// for a language server. This class allows for keeping the TableGen specific
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 36ee87b..df9700f 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -3406,10 +3406,19 @@ void TypeConverter::SignatureConversion::remapInput(
SmallVector<Value, 1>(replacements.begin(), replacements.end())};
}
-LogicalResult TypeConverter::convertType(Type t,
- SmallVectorImpl<Type> &results) const {
- assert(t && "expected non-null type");
-
+/// Internal implementation of the type conversion.
+/// This is used with either a Type or a Value as the first argument.
+/// - we can cache the context-free conversions until the last registered
+/// context-aware conversion.
+/// - we can't cache the result of type conversion happening after context-aware
+/// conversions, because the type converter may return different results for the
+/// same input type.
+LogicalResult
+TypeConverter::convertTypeImpl(PointerUnion<Type, Value> typeOrValue,
+ SmallVectorImpl<Type> &results) const {
+ assert(typeOrValue && "expected non-null type");
+ Type t = (isa<Value>(typeOrValue)) ? cast<Value>(typeOrValue).getType()
+ : cast<Type>(typeOrValue);
{
std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
std::defer_lock);
@@ -3431,52 +3440,53 @@ LogicalResult TypeConverter::convertType(Type t,
// registered first.
size_t currentCount = results.size();
+ // We can cache the context-free conversions until the last registered
+ // context-aware conversion. But only if we're processing a Value right now.
+ auto isCacheable = [&](int index) {
+ int numberOfConversionsUntilContextAware =
+ conversions.size() - 1 - contextAwareTypeConversionsIndex;
+ return index < numberOfConversionsUntilContextAware;
+ };
+
std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
std::defer_lock);
- for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
- if (std::optional<LogicalResult> result = converter(t, results)) {
- if (t.getContext()->isMultithreadingEnabled())
- cacheWriteLock.lock();
- if (!succeeded(*result)) {
- assert(results.size() == currentCount &&
- "failed type conversion should not change results");
- cachedDirectConversions.try_emplace(t, nullptr);
- return failure();
- }
- auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
- if (newTypes.size() == 1)
- cachedDirectConversions.try_emplace(t, newTypes.front());
- else
- cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
+ for (auto indexedConverter : llvm::enumerate(llvm::reverse(conversions))) {
+ const ConversionCallbackFn &converter = indexedConverter.value();
+ std::optional<LogicalResult> result = converter(typeOrValue, results);
+ if (!result) {
+ assert(results.size() == currentCount &&
+ "failed type conversion should not change results");
+ continue;
+ }
+ if (!isCacheable(indexedConverter.index()))
return success();
- } else {
+ if (t.getContext()->isMultithreadingEnabled())
+ cacheWriteLock.lock();
+ if (!succeeded(*result)) {
assert(results.size() == currentCount &&
"failed type conversion should not change results");
+ cachedDirectConversions.try_emplace(t, nullptr);
+ return failure();
}
+ auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
+ if (newTypes.size() == 1)
+ cachedDirectConversions.try_emplace(t, newTypes.front());
+ else
+ cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
+ return success();
}
return failure();
}
-LogicalResult TypeConverter::convertType(Value v,
+LogicalResult TypeConverter::convertType(Type t,
SmallVectorImpl<Type> &results) const {
- assert(v && "expected non-null value");
-
- // If this type converter does not have context-aware type conversions, call
- // the type-based overload, which has caching.
- if (!hasContextAwareTypeConversions)
- return convertType(v.getType(), results);
+ return convertTypeImpl(t, results);
+}
- // Walk the added converters in reverse order to apply the most recently
- // registered first.
- for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
- if (std::optional<LogicalResult> result = converter(v, results)) {
- if (!succeeded(*result))
- return failure();
- return success();
- }
- }
- return failure();
+LogicalResult TypeConverter::convertType(Value v,
+ SmallVectorImpl<Type> &results) const {
+ return convertTypeImpl(v, results);
}
Type TypeConverter::convertType(Type t) const {
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 8e79494..c983914 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -1,9 +1,5 @@
include(AddMLIRPython)
-# Specifies that all MLIR packages are co-located under the `mlir_standalone`
-# top level package (the API has been embedded in a relocatable way).
-add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=${MLIR_PYTHON_PACKAGE_PREFIX}.")
-
################################################################################
# Structural groupings.
################################################################################
@@ -27,6 +23,11 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
passmanager.py
rewrite.py
dialects/_ods_common.py
+
+ # The main _mlir module has submodules: include stubs from each.
+ _mlir_libs/_mlir/__init__.pyi
+ _mlir_libs/_mlir/ir.pyi
+ _mlir_libs/_mlir/passmanager.pyi
)
declare_mlir_python_sources(MLIRPythonSources.Core.Python.Extras
@@ -42,6 +43,7 @@ declare_mlir_python_sources(MLIRPythonSources.ExecutionEngine
ADD_TO_PARENT MLIRPythonSources
SOURCES
execution_engine.py
+ _mlir_libs/_mlirExecutionEngine.pyi
SOURCES_GLOB
runtime/*.py
)
@@ -193,6 +195,7 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/TransformOps.td
SOURCES
dialects/transform/__init__.py
+ _mlir_libs/_mlir/dialects/transform/__init__.pyi
DIALECT_NAME transform
GEN_ENUM_BINDINGS_TD_FILE
"../../include/mlir/Dialect/Transform/IR/TransformAttrs.td"
@@ -364,7 +367,8 @@ declare_mlir_python_sources(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
GEN_ENUM_BINDINGS
SOURCES
- dialects/quant.py)
+ dialects/quant.py
+ _mlir_libs/_mlir/dialects/quant.pyi)
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
@@ -380,6 +384,7 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/PDLOps.td
SOURCES
dialects/pdl.py
+ _mlir_libs/_mlir/dialects/pdl.pyi
DIALECT_NAME pdl)
declare_mlir_dialect_python_bindings(
@@ -505,11 +510,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
# Dialects
MLIRCAPIFunc
- GENERATE_TYPE_STUBS
- "_mlir/__init__.pyi"
- "_mlir/ir.pyi"
- "_mlir/passmanager.pyi"
- "_mlir/rewrite.pyi"
)
# This extension exposes an API to register all dialects, extensions, and passes
@@ -531,8 +531,6 @@ declare_mlir_python_extension(MLIRPythonExtension.RegisterEverything
MLIRCAPIConversion
MLIRCAPITransforms
MLIRCAPIRegisterEverything
- GENERATE_TYPE_STUBS
- "_mlirRegisterEverything.pyi"
)
declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind
@@ -547,8 +545,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind
EMBED_CAPI_LINK_LIBS
MLIRCAPIIR
MLIRCAPILinalg
- GENERATE_TYPE_STUBS
- "_mlirDialectsLinalg.pyi"
)
declare_mlir_python_extension(MLIRPythonExtension.Dialects.GPU.Pybind
@@ -563,8 +559,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.GPU.Pybind
EMBED_CAPI_LINK_LIBS
MLIRCAPIIR
MLIRCAPIGPU
- GENERATE_TYPE_STUBS
- "_mlirDialectsGPU.pyi"
)
declare_mlir_python_extension(MLIRPythonExtension.Dialects.LLVM.Pybind
@@ -579,8 +573,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.LLVM.Pybind
EMBED_CAPI_LINK_LIBS
MLIRCAPIIR
MLIRCAPILLVM
- GENERATE_TYPE_STUBS
- "_mlirDialectsLLVM.pyi"
)
declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Pybind
@@ -595,8 +587,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Pybind
EMBED_CAPI_LINK_LIBS
MLIRCAPIIR
MLIRCAPIQuant
- GENERATE_TYPE_STUBS
- "_mlirDialectsQuant.pyi"
)
declare_mlir_python_extension(MLIRPythonExtension.Dialects.NVGPU.Pybind
@@ -611,8 +601,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.NVGPU.Pybind
EMBED_CAPI_LINK_LIBS
MLIRCAPIIR
MLIRCAPINVGPU
- GENERATE_TYPE_STUBS
- "_mlirDialectsNVGPU.pyi"
)
declare_mlir_python_extension(MLIRPythonExtension.Dialects.PDL.Pybind
@@ -627,8 +615,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.PDL.Pybind
EMBED_CAPI_LINK_LIBS
MLIRCAPIIR
MLIRCAPIPDL
- GENERATE_TYPE_STUBS
- "_mlirDialectsPDL.pyi"
)
declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Pybind
@@ -643,8 +629,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Pybind
EMBED_CAPI_LINK_LIBS
MLIRCAPIIR
MLIRCAPISparseTensor
- GENERATE_TYPE_STUBS
- "_mlirDialectsSparseTensor.pyi"
)
declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind
@@ -659,8 +643,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind
EMBED_CAPI_LINK_LIBS
MLIRCAPIIR
MLIRCAPITransformDialect
- GENERATE_TYPE_STUBS
- "_mlirDialectsTransform.pyi"
)
declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses
@@ -674,8 +656,6 @@ declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses
LLVMSupport
EMBED_CAPI_LINK_LIBS
MLIRCAPIAsync
- GENERATE_TYPE_STUBS
- "_mlirAsyncPasses.pyi"
)
if(MLIR_ENABLE_EXECUTION_ENGINE)
@@ -690,8 +670,6 @@ if(MLIR_ENABLE_EXECUTION_ENGINE)
LLVMSupport
EMBED_CAPI_LINK_LIBS
MLIRCAPIExecutionEngine
- GENERATE_TYPE_STUBS
- "_mlirExecutionEngine.pyi"
)
endif()
@@ -706,8 +684,6 @@ declare_mlir_python_extension(MLIRPythonExtension.GPUDialectPasses
LLVMSupport
EMBED_CAPI_LINK_LIBS
MLIRCAPIGPU
- GENERATE_TYPE_STUBS
- "_mlirGPUPasses.pyi"
)
declare_mlir_python_extension(MLIRPythonExtension.LinalgPasses
@@ -721,8 +697,6 @@ declare_mlir_python_extension(MLIRPythonExtension.LinalgPasses
LLVMSupport
EMBED_CAPI_LINK_LIBS
MLIRCAPILinalg
- GENERATE_TYPE_STUBS
- "_mlirLinalgPasses.pyi"
)
declare_mlir_python_extension(MLIRPythonExtension.Dialects.SMT.Pybind
@@ -740,8 +714,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.SMT.Pybind
MLIRCAPIIR
MLIRCAPISMT
MLIRCAPIExportSMTLIB
- GENERATE_TYPE_STUBS
- "_mlirDialectsSMT.pyi"
)
declare_mlir_python_extension(MLIRPythonExtension.SparseTensorDialectPasses
@@ -755,8 +727,6 @@ declare_mlir_python_extension(MLIRPythonExtension.SparseTensorDialectPasses
LLVMSupport
EMBED_CAPI_LINK_LIBS
MLIRCAPISparseTensor
- GENERATE_TYPE_STUBS
- "_mlirSparseTensorPasses.pyi"
)
declare_mlir_python_extension(MLIRPythonExtension.TransformInterpreter
@@ -770,8 +740,6 @@ declare_mlir_python_extension(MLIRPythonExtension.TransformInterpreter
LLVMSupport
EMBED_CAPI_LINK_LIBS
MLIRCAPITransformDialectTransforms
- GENERATE_TYPE_STUBS
- "_mlirTransformInterpreter.pyi"
)
# TODO: Figure out how to put this in the test tree.
@@ -830,8 +798,6 @@ if(MLIR_INCLUDE_TESTS)
LLVMSupport
EMBED_CAPI_LINK_LIBS
MLIRCAPIPythonTestDialect
- GENERATE_TYPE_STUBS
- "_mlirPythonTestNanobind.pyi"
)
endif()
@@ -851,7 +817,7 @@ endif()
add_mlir_python_common_capi_library(MLIRPythonCAPI
INSTALL_COMPONENT MLIRPythonModules
INSTALL_DESTINATION "${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}/_mlir_libs"
- OUTPUT_DIRECTORY "${MLIR_BINARY_DIR}/${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}/_mlir_libs"
+ OUTPUT_DIRECTORY "${MLIR_BINARY_DIR}/python_packages/mlir_core/mlir/_mlir_libs"
RELATIVE_INSTALL_ROOT "../../../.."
DECLARED_HEADERS
MLIRPythonCAPI.HeaderSources
@@ -880,7 +846,7 @@ endif()
################################################################################
add_mlir_python_modules(MLIRPythonModules
- ROOT_PREFIX "${MLIR_BINARY_DIR}/${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}"
+ ROOT_PREFIX "${MLIR_BINARY_DIR}/python_packages/mlir_core/mlir"
INSTALL_PREFIX "${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}"
DECLARED_SOURCES
MLIRPythonSources
diff --git a/mlir/python/mlir/_mlir_libs/.gitignore b/mlir/python/mlir/_mlir_libs/.gitignore
deleted file mode 100644
index 8f0c82a..0000000
--- a/mlir/python/mlir/_mlir_libs/.gitignore
+++ /dev/null
@@ -1,2 +0,0 @@
-_mlir/**/*.pyi
-*.pyi
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi
new file mode 100644
index 0000000..03449b7
--- /dev/null
+++ b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi
@@ -0,0 +1,12 @@
+
+globals: "_Globals"
+
+class _Globals:
+ dialect_search_modules: list[str]
+ def _register_dialect_impl(self, dialect_namespace: str, dialect_class: type) -> None: ...
+ def _register_operation_impl(self, operation_name: str, operation_class: type) -> None: ...
+ def append_dialect_search_prefix(self, module_name: str) -> None: ...
+ def _check_dialect_module_loaded(self, dialect_namespace: str) -> bool: ...
+
+def register_dialect(dialect_class: type) -> type: ...
+def register_operation(dialect_class: type, *, replace: bool = ...) -> type: ...
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi b/mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi
new file mode 100644
index 0000000..d12c683
--- /dev/null
+++ b/mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi
@@ -0,0 +1,63 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+
+from mlir.ir import Type, Context
+
+__all__ = [
+ 'PDLType',
+ 'AttributeType',
+ 'OperationType',
+ 'RangeType',
+ 'TypeType',
+ 'ValueType',
+]
+
+
+class PDLType(Type):
+ @staticmethod
+ def isinstance(type: Type) -> bool: ...
+
+
+class AttributeType(Type):
+ @staticmethod
+ def isinstance(type: Type) -> bool: ...
+
+ @staticmethod
+ def get(context: Context | None = None) -> AttributeType: ...
+
+
+class OperationType(Type):
+ @staticmethod
+ def isinstance(type: Type) -> bool: ...
+
+ @staticmethod
+ def get(context: Context | None = None) -> OperationType: ...
+
+
+class RangeType(Type):
+ @staticmethod
+ def isinstance(type: Type) -> bool: ...
+
+ @staticmethod
+ def get(element_type: Type) -> RangeType: ...
+
+ @property
+ def element_type(self) -> Type: ...
+
+
+class TypeType(Type):
+ @staticmethod
+ def isinstance(type: Type) -> bool: ...
+
+ @staticmethod
+ def get(context: Context | None = None) -> TypeType: ...
+
+
+class ValueType(Type):
+ @staticmethod
+ def isinstance(type: Type) -> bool: ...
+
+ @staticmethod
+ def get(context: Context | None = None) -> ValueType: ...
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi b/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi
new file mode 100644
index 0000000..3f53045
--- /dev/null
+++ b/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi
@@ -0,0 +1,142 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+
+from mlir.ir import DenseElementsAttr, Type
+
+__all__ = [
+ "QuantizedType",
+ "AnyQuantizedType",
+ "UniformQuantizedType",
+ "UniformQuantizedPerAxisType",
+ "CalibratedQuantizedType",
+]
+
+class QuantizedType(Type):
+ @staticmethod
+ def isinstance(type: Type) -> bool: ...
+
+ @staticmethod
+ def default_minimum_for_integer(is_signed: bool, integral_width: int) -> int:
+ ...
+
+ @staticmethod
+ def default_maximum_for_integer(is_signed: bool, integral_width: int) -> int:
+ ...
+
+ @property
+ def expressed_type(self) -> Type: ...
+
+ @property
+ def flags(self) -> int: ...
+
+ @property
+ def is_signed(self) -> bool: ...
+
+ @property
+ def storage_type(self) -> Type: ...
+
+ @property
+ def storage_type_min(self) -> int: ...
+
+ @property
+ def storage_type_max(self) -> int: ...
+
+ @property
+ def storage_type_integral_width(self) -> int: ...
+
+ def is_compatible_expressed_type(self, candidate: Type) -> bool: ...
+
+ @property
+ def quantized_element_type(self) -> Type: ...
+
+ def cast_from_storage_type(self, candidate: Type) -> Type: ...
+
+ @staticmethod
+ def cast_to_storage_type(type: Type) -> Type: ...
+
+ def cast_from_expressed_type(self, candidate: Type) -> Type: ...
+
+ @staticmethod
+ def cast_to_expressed_type(type: Type) -> Type: ...
+
+ def cast_expressed_to_storage_type(self, candidate: Type) -> Type: ...
+
+
+class AnyQuantizedType(QuantizedType):
+
+ @classmethod
+ def get(cls, flags: int, storage_type: Type, expressed_type: Type,
+ storage_type_min: int, storage_type_max: int) -> Type:
+ ...
+
+
+class UniformQuantizedType(QuantizedType):
+
+ @classmethod
+ def get(cls, flags: int, storage_type: Type, expressed_type: Type,
+ scale: float, zero_point: int, storage_type_min: int,
+ storage_type_max: int) -> Type: ...
+
+ @property
+ def scale(self) -> float: ...
+
+ @property
+ def zero_point(self) -> int: ...
+
+ @property
+ def is_fixed_point(self) -> bool: ...
+
+
+class UniformQuantizedPerAxisType(QuantizedType):
+
+ @classmethod
+ def get(cls, flags: int, storage_type: Type, expressed_type: Type,
+ scales: list[float], zero_points: list[int], quantized_dimension: int,
+ storage_type_min: int, storage_type_max: int):
+ ...
+
+ @property
+ def scales(self) -> list[float]: ...
+
+ @property
+ def zero_points(self) -> list[int]: ...
+
+ @property
+ def quantized_dimension(self) -> int: ...
+
+ @property
+ def is_fixed_point(self) -> bool: ...
+
+class UniformQuantizedSubChannelType(QuantizedType):
+
+ @classmethod
+ def get(cls, flags: int, storage_type: Type, expressed_type: Type,
+ scales: DenseElementsAttr, zero_points: DenseElementsAttr,
+ quantized_dimensions: list[int], block_sizes: list[int],
+ storage_type_min: int, storage_type_max: int):
+ ...
+
+ @property
+ def quantized_dimensions(self) -> list[int]: ...
+
+ @property
+ def block_sizes(self) -> list[int]: ...
+
+ @property
+ def scales(self) -> DenseElementsAttr: ...
+
+ @property
+ def zero_points(self) -> DenseElementsAttr: ...
+
+def CalibratedQuantizedType(QuantizedType):
+
+ @classmethod
+ def get(cls, expressed_type: Type, min: float, max: float): ...
+
+ @property
+ def min(self) -> float: ...
+
+ @property
+ def max(self) -> float: ... \ No newline at end of file
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi b/mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi
new file mode 100644
index 0000000..a3f1b09
--- /dev/null
+++ b/mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi
@@ -0,0 +1,25 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+
+from mlir.ir import Type, Context
+
+
+class AnyOpType(Type):
+ @staticmethod
+ def isinstance(type: Type) -> bool: ...
+
+ @staticmethod
+ def get(context: Context | None = None) -> AnyOpType: ...
+
+
+class OperationType(Type):
+ @staticmethod
+ def isinstance(type: Type) -> bool: ...
+
+ @staticmethod
+ def get(operation_name: str, context: Context | None = None) -> OperationType: ...
+
+ @property
+ def operation_name(self) -> str: ...
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
new file mode 100644
index 0000000..dcae3dd
--- /dev/null
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -0,0 +1,2846 @@
+# Originally imported via:
+# pybind11-stubgen --print-invalid-expressions-as-is mlir._mlir_libs._mlir.ir
+# but with the following diff (in order to remove pipes from types,
+# which we won't support until bumping minimum python to 3.10)
+#
+# --------------------- diff begins ------------------------------------
+#
+# diff --git a/pybind11_stubgen/printer.py b/pybind11_stubgen/printer.py
+# index 1f755aa..4924927 100644
+# --- a/pybind11_stubgen/printer.py
+# +++ b/pybind11_stubgen/printer.py
+# @@ -283,14 +283,6 @@ class Printer:
+# return split[0] + "..."
+#
+# def print_type(self, type_: ResolvedType) -> str:
+# - if (
+# - str(type_.name) == "typing.Optional"
+# - and type_.parameters is not None
+# - and len(type_.parameters) == 1
+# - ):
+# - return f"{self.print_annotation(type_.parameters[0])} | None"
+# - if str(type_.name) == "typing.Union" and type_.parameters is not None:
+# - return " | ".join(self.print_annotation(p) for p in type_.parameters)
+# if type_.parameters:
+# param_str = (
+# "["
+#
+# --------------------- diff ends ------------------------------------
+#
+# Local modifications:
+# * Rewrite references to 'mlir.ir' to local types.
+# * Drop `typing.` everywhere (top-level import instead).
+# * List -> List, dict -> Dict, Tuple -> Tuple.
+# * copy-paste Buffer type from typing_extensions.
+# * Shuffle _OperationBase, AffineExpr, Attribute, Type, Value to the top.
+# * Patch raw C++ types (like "PyAsmState") with a regex like `Py(.*)`.
+# * _BaseContext -> Context, MlirType -> Type, MlirTypeID -> TypeID, MlirAttribute -> Attribute.
+# * Local edits to signatures and types that pybind11-stubgen did not auto detect (or detected incorrectly).
+# * Add MLIRError, _GlobalDebug, _OperationBase to __all__ by hand.
+# * Fill in `Any`s where possible.
+# * black formatting.
+
+from __future__ import annotations
+
+import abc
+import collections
+from collections.abc import Callable, Sequence
+from pathlib import Path
+from typing import Any, BinaryIO, ClassVar, Literal, TypeVar, overload
+
+__all__ = [
+ "AffineAddExpr",
+ "AffineBinaryExpr",
+ "AffineCeilDivExpr",
+ "AffineConstantExpr",
+ "AffineDimExpr",
+ "AffineExpr",
+ "AffineExprList",
+ "AffineFloorDivExpr",
+ "AffineMap",
+ "AffineMapAttr",
+ "AffineModExpr",
+ "AffineMulExpr",
+ "AffineSymbolExpr",
+ "ArrayAttr",
+ "ArrayAttributeIterator",
+ "AsmState",
+ "AttrBuilder",
+ "Attribute",
+ "BF16Type",
+ "Block",
+ "BlockArgument",
+ "BlockArgumentList",
+ "BlockIterator",
+ "BlockList",
+ "BoolAttr",
+ "ComplexType",
+ "Context",
+ "DenseBoolArrayAttr",
+ "DenseBoolArrayIterator",
+ "DenseElementsAttr",
+ "DenseF32ArrayAttr",
+ "DenseF32ArrayIterator",
+ "DenseF64ArrayAttr",
+ "DenseF64ArrayIterator",
+ "DenseFPElementsAttr",
+ "DenseI16ArrayAttr",
+ "DenseI16ArrayIterator",
+ "DenseI32ArrayAttr",
+ "DenseI32ArrayIterator",
+ "DenseI64ArrayAttr",
+ "DenseI64ArrayIterator",
+ "DenseI8ArrayAttr",
+ "DenseI8ArrayIterator",
+ "DenseIntElementsAttr",
+ "DenseResourceElementsAttr",
+ "Diagnostic",
+ "DiagnosticHandler",
+ "DiagnosticInfo",
+ "DiagnosticSeverity",
+ "Dialect",
+ "DialectDescriptor",
+ "DialectRegistry",
+ "Dialects",
+ "DictAttr",
+ "F16Type",
+ "F32Type",
+ "F64Type",
+ "FlatSymbolRefAttr",
+ "Float4E2M1FNType",
+ "Float6E2M3FNType",
+ "Float6E3M2FNType",
+ "Float8E3M4Type",
+ "Float8E4M3B11FNUZType",
+ "Float8E4M3FNType",
+ "Float8E4M3FNUZType",
+ "Float8E4M3Type",
+ "Float8E5M2FNUZType",
+ "Float8E5M2Type",
+ "Float8E8M0FNUType",
+ "FloatAttr",
+ "FloatTF32Type",
+ "FloatType",
+ "FunctionType",
+ "IndexType",
+ "InferShapedTypeOpInterface",
+ "InferTypeOpInterface",
+ "InsertionPoint",
+ "IntegerAttr",
+ "IntegerSet",
+ "IntegerSetAttr",
+ "IntegerSetConstraint",
+ "IntegerSetConstraintList",
+ "IntegerType",
+ "Location",
+ "MemRefType",
+ "Module",
+ "NamedAttribute",
+ "NoneType",
+ "OpAttributeMap",
+ "OpOperand",
+ "OpOperandIterator",
+ "OpOperandList",
+ "OpResult",
+ "OpResultList",
+ "OpSuccessors",
+ "OpView",
+ "OpaqueAttr",
+ "OpaqueType",
+ "Operation",
+ "OperationIterator",
+ "OperationList",
+ "RankedTensorType",
+ "Region",
+ "RegionIterator",
+ "RegionSequence",
+ "ShapedType",
+ "ShapedTypeComponents",
+ "StridedLayoutAttr",
+ "StringAttr",
+ "SymbolRefAttr",
+ "SymbolTable",
+ "TupleType",
+ "Type",
+ "TypeAttr",
+ "TypeID",
+ "UnitAttr",
+ "UnrankedMemRefType",
+ "UnrankedTensorType",
+ "Value",
+ "VectorType",
+ "_GlobalDebug",
+ "_OperationBase",
+]
+
+if hasattr(collections.abc, "Buffer"):
+ Buffer = collections.abc.Buffer
+else:
+ class Buffer(abc.ABC):
+ pass
+
+class _OperationBase:
+ @overload
+ def __eq__(self, arg0: _OperationBase) -> bool: ...
+ @overload
+ def __eq__(self, arg0: _OperationBase) -> bool: ...
+ def __hash__(self) -> int: ...
+ def __str__(self) -> str:
+ """
+ Returns the assembly form of the operation.
+ """
+ def clone(self, ip: InsertionPoint = None) -> OpView: ...
+ def detach_from_parent(self) -> OpView:
+ """
+ Detaches the operation from its parent block.
+ """
+
+ @property
+ def attached(self) -> bool:
+ """
+ Reports if the operation is attached to its parent block.
+ """
+
+ def erase(self) -> None: ...
+
+ @overload
+ def get_asm(
+ binary: Literal[True],
+ large_elements_limit: int | None = None,
+ large_resource_limit: int | None = None,
+ enable_debug_info: bool = False,
+ pretty_debug_info: bool = False,
+ print_generic_op_form: bool = False,
+ use_local_scope: bool = False,
+ assume_verified: bool = False,
+ skip_regions: bool = False,
+ ) -> bytes: ...
+ @overload
+ def get_asm(
+ self,
+ binary: bool = False,
+ large_elements_limit: int | None = None,
+ large_resource_limit: int | None = None,
+ enable_debug_info: bool = False,
+ pretty_debug_info: bool = False,
+ print_generic_op_form: bool = False,
+ use_local_scope: bool = False,
+ assume_verified: bool = False,
+ skip_regions: bool = False,
+ ) -> str:
+ """
+ Returns the assembly form of the operation.
+
+ See the print() method for common keyword arguments for configuring
+ the output.
+ """
+
+ def move_after(self, other: _OperationBase) -> None:
+ """
+ Puts self immediately after the other operation in its parent block.
+ """
+ def move_before(self, other: _OperationBase) -> None:
+ """
+ Puts self immediately before the other operation in its parent block.
+ """
+ @overload
+ def print(
+ self,
+ state: AsmState,
+ file: Any | None = None,
+ binary: bool = False,
+ ) -> None:
+ """
+ Prints the assembly form of the operation to a file like object.
+
+ Args:
+ file: The file like object to write to. Defaults to sys.stdout.
+ binary: Whether to write bytes (True) or str (False). Defaults to False.
+ state: AsmState capturing the operation numbering and flags.
+ """
+ @overload
+ def print(
+ self,
+ large_elements_limit: int | None = None,
+ large_resource_limit: int | None = None,
+ enable_debug_info: bool = False,
+ pretty_debug_info: bool = False,
+ print_generic_op_form: bool = False,
+ use_local_scope: bool = False,
+ assume_verified: bool = False,
+ file: Any | None = None,
+ binary: bool = False,
+ skip_regions: bool = False,
+ ) -> None:
+ """
+ Prints the assembly form of the operation to a file like object.
+
+ Args:
+ file: The file like object to write to. Defaults to sys.stdout.
+ binary: Whether to write bytes (True) or str (False). Defaults to False.
+ large_elements_limit: Whether to elide elements attributes above this
+ number of elements. Defaults to None (no limit).
+ large_resource_limit: Whether to elide resource strings above this
+ number of characters. Defaults to None (no limit). If large_elements_limit
+ is set and this is None, the behavior will be to use large_elements_limit
+ as large_resource_limit.
+ enable_debug_info: Whether to print debug/location information. Defaults
+ to False.
+ pretty_debug_info: Whether to format debug information for easier reading
+ by a human (warning: the result is unparseable).
+ print_generic_op_form: Whether to print the generic assembly forms of all
+ ops. Defaults to False.
+ use_local_Scope: Whether to print in a way that is more optimized for
+ multi-threaded access but may not be consistent with how the overall
+ module prints.
+ assume_verified: By default, if not printing generic form, the verifier
+ will be run and if it fails, generic form will be printed with a comment
+ about failed verification. While a reasonable default for interactive use,
+ for systematic use, it is often better for the caller to verify explicitly
+ and report failures in a more robust fashion. Set this to True if doing this
+ in order to avoid running a redundant verification. If the IR is actually
+ invalid, behavior is undefined.
+ skip_regions: Whether to skip printing regions. Defaults to False.
+ """
+ def verify(self) -> bool:
+ """
+ Verify the operation. Raises MLIRError if verification fails, and returns true otherwise.
+ """
+ def write_bytecode(self, file: BinaryIO | str, desired_version: int | None = None) -> None:
+ """
+ Write the bytecode form of the operation to a file like object.
+
+ Args:
+ file: The file like object or path to write to.
+ desired_version: The version of bytecode to emit.
+ Returns:
+ The bytecode writer status.
+ """
+ @property
+ def _CAPIPtr(self) -> object: ...
+ @property
+ def attributes(self) -> OpAttributeMap: ...
+ @property
+ def context(self) -> Context:
+ """
+ Context that owns the Operation
+ """
+ @property
+ def location(self) -> Location:
+ """
+ Returns the source location the operation was defined or derived from.
+ """
+ @property
+ def name(self) -> str: ...
+ @property
+ def operands(self) -> OpOperandList: ...
+ @property
+ def parent(self) -> _OperationBase | None: ...
+ @property
+ def regions(self) -> RegionSequence: ...
+ @property
+ def result(self) -> OpResult:
+ """
+ Shortcut to get an op result if it has only one (throws an error otherwise).
+ """
+ @property
+ def results(self) -> OpResultList:
+ """
+ Returns the List of Operation results.
+ """
+
+_TOperation = TypeVar("_TOperation", bound=_OperationBase)
+
+class AffineExpr:
+ @staticmethod
+ @overload
+ def get_add(arg0: AffineExpr, arg1: AffineExpr) -> AffineAddExpr:
+ """
+ Gets an affine expression containing a sum of two expressions.
+ """
+ @staticmethod
+ @overload
+ def get_add(arg0: int, arg1: AffineExpr) -> AffineAddExpr:
+ """
+ Gets an affine expression containing a sum of a constant and another expression.
+ """
+ @staticmethod
+ @overload
+ def get_add(arg0: AffineExpr, arg1: int) -> AffineAddExpr:
+ """
+ Gets an affine expression containing a sum of an expression and a constant.
+ """
+ @staticmethod
+ @overload
+ def get_ceil_div(arg0: AffineExpr, arg1: AffineExpr) -> AffineCeilDivExpr:
+ """
+ Gets an affine expression containing the rounded-up result of dividing one expression by another.
+ """
+ @staticmethod
+ @overload
+ def get_ceil_div(arg0: int, arg1: AffineExpr) -> AffineCeilDivExpr:
+ """
+ Gets a semi-affine expression containing the rounded-up result of dividing a constant by an expression.
+ """
+ @staticmethod
+ @overload
+ def get_ceil_div(arg0: AffineExpr, arg1: int) -> AffineCeilDivExpr:
+ """
+ Gets an affine expression containing the rounded-up result of dividing an expression by a constant.
+ """
+ @staticmethod
+ def get_constant(
+ value: int, context: Context | None = None
+ ) -> AffineConstantExpr:
+ """
+ Gets a constant affine expression with the given value.
+ """
+ @staticmethod
+ def get_dim(position: int, context: Context | None = None) -> AffineDimExpr:
+ """
+ Gets an affine expression of a dimension at the given position.
+ """
+ @staticmethod
+ @overload
+ def get_floor_div(arg0: AffineExpr, arg1: AffineExpr) -> AffineFloorDivExpr:
+ """
+ Gets an affine expression containing the rounded-down result of dividing one expression by another.
+ """
+ @staticmethod
+ @overload
+ def get_floor_div(arg0: int, arg1: AffineExpr) -> AffineFloorDivExpr:
+ """
+ Gets a semi-affine expression containing the rounded-down result of dividing a constant by an expression.
+ """
+ @staticmethod
+ @overload
+ def get_floor_div(arg0: AffineExpr, arg1: int) -> AffineFloorDivExpr:
+ """
+ Gets an affine expression containing the rounded-down result of dividing an expression by a constant.
+ """
+ @staticmethod
+ @overload
+ def get_mod(arg0: AffineExpr, arg1: AffineExpr) -> AffineModExpr:
+ """
+ Gets an affine expression containing the modulo of dividing one expression by another.
+ """
+ @staticmethod
+ @overload
+ def get_mod(arg0: int, arg1: AffineExpr) -> AffineModExpr:
+ """
+ Gets a semi-affine expression containing the modulo of dividing a constant by an expression.
+ """
+ @staticmethod
+ @overload
+ def get_mod(arg0: AffineExpr, arg1: int) -> AffineModExpr:
+ """
+ Gets an affine expression containing the module of dividingan expression by a constant.
+ """
+ @staticmethod
+ @overload
+ def get_mul(arg0: AffineExpr, arg1: AffineExpr) -> AffineMulExpr:
+ """
+ Gets an affine expression containing a product of two expressions.
+ """
+ @staticmethod
+ @overload
+ def get_mul(arg0: int, arg1: AffineExpr) -> AffineMulExpr:
+ """
+ Gets an affine expression containing a product of a constant and another expression.
+ """
+ @staticmethod
+ @overload
+ def get_mul(arg0: AffineExpr, arg1: int) -> AffineMulExpr:
+ """
+ Gets an affine expression containing a product of an expression and a constant.
+ """
+ @staticmethod
+ def get_symbol(
+ position: int, context: Context | None = None
+ ) -> AffineSymbolExpr:
+ """
+ Gets an affine expression of a symbol at the given position.
+ """
+ def _CAPICreate(self) -> AffineExpr: ...
+ @overload
+ def __add__(self, arg0: AffineExpr) -> AffineAddExpr: ...
+ @overload
+ def __add__(self, arg0: int) -> AffineAddExpr: ...
+ @overload
+ def __eq__(self, arg0: AffineExpr) -> bool: ...
+ @overload
+ def __eq__(self, arg0: Any) -> bool: ...
+ def __hash__(self) -> int: ...
+ @overload
+ def __mod__(self, arg0: AffineExpr) -> AffineModExpr: ...
+ @overload
+ def __mod__(self, arg0: int) -> AffineModExpr: ...
+ @overload
+ def __mul__(self, arg0: AffineExpr) -> AffineMulExpr: ...
+ @overload
+ def __mul__(self, arg0: int) -> AffineMulExpr: ...
+ def __radd__(self, arg0: int) -> AffineAddExpr: ...
+ def __rmod__(self, arg0: int) -> AffineModExpr: ...
+ def __rmul__(self, arg0: int) -> AffineMulExpr: ...
+ def __rsub__(self, arg0: int) -> AffineAddExpr: ...
+ @overload
+ def __sub__(self, arg0: AffineExpr) -> AffineAddExpr: ...
+ @overload
+ def __sub__(self, arg0: int) -> AffineAddExpr: ...
+ def compose(self, arg0: AffineMap) -> AffineExpr: ...
+ def dump(self) -> None:
+ """
+ Dumps a debug representation of the object to stderr.
+ """
+ @property
+ def _CAPIPtr(self) -> object: ...
+ @property
+ def context(self) -> Context: ...
+
+class Attribute:
+ @staticmethod
+ def parse(asm: str | bytes, context: Context | None = None) -> Attribute:
+ """
+ Parses an attribute from an assembly form. Raises an MLIRError on failure.
+ """
+ def _CAPICreate(self) -> Attribute: ...
+ @overload
+ def __eq__(self, arg0: Attribute) -> bool: ...
+ @overload
+ def __eq__(self, arg0: object) -> bool: ...
+ def __hash__(self) -> int: ...
+ def __init__(self, cast_from_type: Attribute) -> None:
+ """
+ Casts the passed attribute to the generic Attribute
+ """
+ def __str__(self) -> str:
+ """
+ Returns the assembly form of the Attribute.
+ """
+ def dump(self) -> None:
+ """
+ Dumps a debug representation of the object to stderr.
+ """
+ def get_named(self, arg0: str) -> NamedAttribute:
+ """
+ Binds a name to the attribute
+ """
+ def maybe_downcast(self) -> Any: ...
+ @property
+ def _CAPIPtr(self) -> object: ...
+ @property
+ def context(self) -> Context:
+ """
+ Context that owns the Attribute
+ """
+ @property
+ def type(self) -> Type: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class Type:
+ @staticmethod
+ def parse(asm: str | bytes, context: Context | None = None) -> Type:
+ """
+ Parses the assembly form of a type.
+
+ Returns a Type object or raises an MLIRError if the type cannot be parsed.
+
+ See also: https://mlir.llvm.org/docs/LangRef/#type-system
+ """
+ def _CAPICreate(self) -> Type: ...
+ @overload
+ def __eq__(self, arg0: Type) -> bool: ...
+ @overload
+ def __eq__(self, arg0: object) -> bool: ...
+ def __hash__(self) -> int: ...
+ def __init__(self, cast_from_type: Type) -> None:
+ """
+ Casts the passed type to the generic Type
+ """
+ def __str__(self) -> str:
+ """
+ Returns the assembly form of the type.
+ """
+ def dump(self) -> None:
+ """
+ Dumps a debug representation of the object to stderr.
+ """
+ def maybe_downcast(self) -> Any: ...
+ @property
+ def _CAPIPtr(self) -> object: ...
+ @property
+ def context(self) -> Context:
+ """
+ Context that owns the Type
+ """
+ @property
+ def typeid(self) -> TypeID: ...
+
+class Value:
+ def _CAPICreate(self) -> Value: ...
+ @overload
+ def __eq__(self, arg0: Value) -> bool: ...
+ @overload
+ def __eq__(self, arg0: object) -> bool: ...
+ def __hash__(self) -> int: ...
+ def __init__(self, value: Value) -> None: ...
+ def __str__(self) -> str:
+ """
+ Returns the string form of the value.
+
+ If the value is a block argument, this is the assembly form of its type and the
+ position in the argument List. If the value is an operation result, this is
+ equivalent to printing the operation that produced it.
+ """
+ def dump(self) -> None:
+ """
+ Dumps a debug representation of the object to stderr.
+ """
+ @overload
+ def get_name(self, use_local_scope: bool = False, use_name_loc_as_prefix: bool = True) -> str: ...
+ @overload
+ def get_name(self, state: AsmState) -> str:
+ """
+ Returns the string form of value as an operand (i.e., the ValueID).
+ """
+ def maybe_downcast(self) -> Any: ...
+ def replace_all_uses_with(self, arg0: Value) -> None:
+ """
+ Replace all uses of value with the new value, updating anything in
+ the IR that uses 'self' to use the other value instead.
+ """
+ def set_type(self, type: Type) -> None: ...
+ @property
+ def _CAPIPtr(self) -> object: ...
+ @property
+ def context(self) -> Context:
+ """
+ Context in which the value lives.
+ """
+ @property
+ def owner(self) -> _OperationBase: ...
+ @property
+ def type(self) -> Type: ...
+ @property
+ def uses(self) -> OpOperandIterator: ...
+
+class AffineAddExpr(AffineBinaryExpr):
+ @staticmethod
+ def get(arg0: AffineExpr, arg1: AffineExpr) -> AffineAddExpr: ...
+ @staticmethod
+ def isinstance(other: AffineExpr) -> bool: ...
+ def __init__(self, expr: AffineExpr) -> None: ...
+
+class AffineBinaryExpr(AffineExpr):
+ @staticmethod
+ def isinstance(other: AffineExpr) -> bool: ...
+ def __init__(self, expr: AffineExpr) -> None: ...
+ @property
+ def lhs(self) -> AffineExpr: ...
+ @property
+ def rhs(self) -> AffineExpr: ...
+
+class AffineCeilDivExpr(AffineBinaryExpr):
+ @staticmethod
+ def get(arg0: AffineExpr, arg1: AffineExpr) -> AffineCeilDivExpr: ...
+ @staticmethod
+ def isinstance(other: AffineExpr) -> bool: ...
+ def __init__(self, expr: AffineExpr) -> None: ...
+
+class AffineConstantExpr(AffineExpr):
+ @staticmethod
+ def get(value: int, context: Context | None = None) -> AffineConstantExpr: ...
+ @staticmethod
+ def isinstance(other: AffineExpr) -> bool: ...
+ def __init__(self, expr: AffineExpr) -> None: ...
+ @property
+ def value(self) -> int: ...
+
+class AffineDimExpr(AffineExpr):
+ @staticmethod
+ def get(position: int, context: Context | None = None) -> AffineDimExpr: ...
+ @staticmethod
+ def isinstance(other: AffineExpr) -> bool: ...
+ def __init__(self, expr: AffineExpr) -> None: ...
+ @property
+ def position(self) -> int: ...
+
+class AffineExprList:
+ def __add__(self, arg0: AffineExprList) -> list[AffineExpr]: ...
+
+class AffineFloorDivExpr(AffineBinaryExpr):
+ @staticmethod
+ def get(arg0: AffineExpr, arg1: AffineExpr) -> AffineFloorDivExpr: ...
+ @staticmethod
+ def isinstance(other: AffineExpr) -> bool: ...
+ def __init__(self, expr: AffineExpr) -> None: ...
+
+class AffineMap:
+ @staticmethod
+ def compress_unused_symbols(
+ arg0: list, arg1: Context | None
+ ) -> list[AffineMap]: ...
+ @staticmethod
+ def get(
+ dim_count: int,
+ symbol_count: int,
+ exprs: list,
+ context: Context | None = None,
+ ) -> AffineMap:
+ """
+ Gets a map with the given expressions as results.
+ """
+ @staticmethod
+ def get_constant(value: int, context: Context | None = None) -> AffineMap:
+ """
+ Gets an affine map with a single constant result
+ """
+ @staticmethod
+ def get_empty(context: Context | None = None) -> AffineMap:
+ """
+ Gets an empty affine map.
+ """
+ @staticmethod
+ def get_identity(n_dims: int, context: Context | None = None) -> AffineMap:
+ """
+ Gets an identity map with the given number of dimensions.
+ """
+ @staticmethod
+ def get_minor_identity(
+ n_dims: int, n_results: int, context: Context | None = None
+ ) -> AffineMap:
+ """
+ Gets a minor identity map with the given number of dimensions and results.
+ """
+ @staticmethod
+ def get_permutation(
+ permutation: list[int], context: Context | None = None
+ ) -> AffineMap:
+ """
+ Gets an affine map that permutes its inputs.
+ """
+ def _CAPICreate(self) -> AffineMap: ...
+ @overload
+ def __eq__(self, arg0: AffineMap) -> bool: ...
+ @overload
+ def __eq__(self, arg0: object) -> bool: ...
+ def __hash__(self) -> int: ...
+ def dump(self) -> None:
+ """
+ Dumps a debug representation of the object to stderr.
+ """
+ def get_major_submap(self, n_results: int) -> AffineMap: ...
+ def get_minor_submap(self, n_results: int) -> AffineMap: ...
+ def get_submap(self, result_positions: list[int]) -> AffineMap: ...
+ def replace(
+ self,
+ expr: AffineExpr,
+ replacement: AffineExpr,
+ n_result_dims: int,
+ n_result_syms: int,
+ ) -> AffineMap: ...
+ @property
+ def _CAPIPtr(self) -> object: ...
+ @property
+ def context(self) -> Context:
+ """
+ Context that owns the Affine Map
+ """
+ @property
+ def is_permutation(self) -> bool: ...
+ @property
+ def is_projected_permutation(self) -> bool: ...
+ @property
+ def n_dims(self) -> int: ...
+ @property
+ def n_inputs(self) -> int: ...
+ @property
+ def n_symbols(self) -> int: ...
+ @property
+ def results(self) -> AffineMapExprList: ...
+
+class AffineMapAttr(Attribute):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(affine_map: AffineMap) -> AffineMapAttr:
+ """
+ Gets an attribute wrapping an AffineMap.
+ """
+ @staticmethod
+ def isinstance(other: Attribute) -> bool: ...
+ def __init__(self, cast_from_attr: Attribute) -> None: ...
+ @property
+ def type(self) -> Type: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class AffineModExpr(AffineBinaryExpr):
+ @staticmethod
+ def get(arg0: AffineExpr, arg1: AffineExpr) -> AffineModExpr: ...
+ @staticmethod
+ def isinstance(other: AffineExpr) -> bool: ...
+ def __init__(self, expr: AffineExpr) -> None: ...
+
+class AffineMulExpr(AffineBinaryExpr):
+ @staticmethod
+ def get(arg0: AffineExpr, arg1: AffineExpr) -> AffineMulExpr: ...
+ @staticmethod
+ def isinstance(other: AffineExpr) -> bool: ...
+ def __init__(self, expr: AffineExpr) -> None: ...
+
+class AffineSymbolExpr(AffineExpr):
+ @staticmethod
+ def get(position: int, context: Context | None = None) -> AffineSymbolExpr: ...
+ @staticmethod
+ def isinstance(other: AffineExpr) -> bool: ...
+ def __init__(self, expr: AffineExpr) -> None: ...
+ @property
+ def position(self) -> int: ...
+
+class ArrayAttr(Attribute):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(attributes: list, context: Context | None = None) -> ArrayAttr:
+ """
+ Gets a uniqued Array attribute
+ """
+ @staticmethod
+ def isinstance(other: Attribute) -> bool: ...
+ def __add__(self, arg0: list) -> ArrayAttr: ...
+ def __getitem__(self, arg0: int) -> Attribute: ...
+ def __init__(self, cast_from_attr: Attribute) -> None: ...
+ def __iter__(
+ self,
+ ) -> ArrayAttributeIterator: ...
+ def __len__(self) -> int: ...
+ @property
+ def type(self) -> Type: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class ArrayAttributeIterator:
+ def __iter__(self) -> ArrayAttributeIterator: ...
+ def __next__(self) -> Attribute: ...
+
+class AsmState:
+ @overload
+ def __init__(self, value: Value, use_local_scope: bool = False) -> None: ...
+ @overload
+ def __init__(self, op: _OperationBase, use_local_scope: bool = False) -> None: ...
+
+class AttrBuilder:
+ @staticmethod
+ def contains(arg0: str) -> bool: ...
+ @staticmethod
+ def get(arg0: str) -> Callable: ...
+ @staticmethod
+ def insert(
+ attribute_kind: str, attr_builder: Callable, replace: bool = False
+ ) -> None:
+ """
+ Register an attribute builder for building MLIR attributes from python values.
+ """
+
+class BF16Type(Type):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(context: Context | None = None) -> BF16Type:
+ """
+ Create a bf16 type.
+ """
+ @staticmethod
+ def isinstance(other: Type) -> bool: ...
+ def __init__(self, cast_from_type: Type) -> None: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class Block:
+ @staticmethod
+ def create_at_start(
+ parent: Region,
+ arg_types: list[Type],
+ arg_locs: Sequence | None = None,
+ ) -> Block:
+ """
+ Creates and returns a new Block at the beginning of the given region (with given argument types and locations).
+ """
+ @overload
+ def __eq__(self, arg0: Block) -> bool: ...
+ @overload
+ def __eq__(self, arg0: Any) -> bool: ...
+ def __hash__(self) -> int: ...
+ def __iter__(self) -> OperationIterator:
+ """
+ Iterates over operations in the block.
+ """
+ def __str__(self) -> str:
+ """
+ Returns the assembly form of the block.
+ """
+ def append(self, operation: _OperationBase) -> None:
+ """
+ Appends an operation to this block. If the operation is currently in another block, it will be moved.
+ """
+ def append_to(self, arg0: Region) -> None:
+ """
+ Append this block to a region, transferring ownership if necessary
+ """
+ def create_after(self, *args, arg_locs: Sequence | None = None) -> Block:
+ """
+ Creates and returns a new Block after this block (with given argument types and locations).
+ """
+ def create_before(self, *args, arg_locs: Sequence | None = None) -> Block:
+ """
+ Creates and returns a new Block before this block (with given argument types and locations).
+ """
+ @property
+ def _CAPIPtr(self) -> object: ...
+ @property
+ def arguments(self) -> BlockArgumentList:
+ """
+ Returns a List of block arguments.
+ """
+ @property
+ def operations(self) -> OperationList:
+ """
+ Returns a forward-optimized sequence of operations.
+ """
+ @property
+ def owner(self) -> OpView:
+ """
+ Returns the owning operation of this block.
+ """
+ @property
+ def region(self) -> Region:
+ """
+ Returns the owning region of this block.
+ """
+
+class BlockArgument(Value):
+ @staticmethod
+ def isinstance(other_value: Value) -> bool: ...
+ def __init__(self, value: Value) -> None: ...
+ def maybe_downcast(self) -> Any: ...
+ def set_type(self, type: Type) -> None: ...
+ @property
+ def arg_number(self) -> int: ...
+ @property
+ def owner(self) -> Block: ...
+
+class BlockArgumentList:
+ @overload
+ def __getitem__(self, arg0: int) -> BlockArgument: ...
+ @overload
+ def __getitem__(self, arg0: slice) -> BlockArgumentList: ...
+ def __len__(self) -> int: ...
+ def __add__(self, arg0: BlockArgumentList) -> list[BlockArgument]: ...
+ @property
+ def types(self) -> list[Type]: ...
+
+class BlockIterator:
+ def __iter__(self) -> BlockIterator: ...
+ def __next__(self) -> Block: ...
+
+class BlockList:
+ def __getitem__(self, arg0: int) -> Block: ...
+ def __iter__(self) -> BlockIterator: ...
+ def __len__(self) -> int: ...
+ def append(self, *args, arg_locs: Sequence | None = None) -> Block:
+ """
+ Appends a new block, with argument types as positional args.
+
+ Returns:
+ The created block.
+ """
+
+class BoolAttr(Attribute):
+ @staticmethod
+ def get(value: bool, context: Context | None = None) -> BoolAttr:
+ """
+ Gets an uniqued bool attribute
+ """
+ @staticmethod
+ def isinstance(other: Attribute) -> bool: ...
+ def __bool__(self: Attribute) -> bool:
+ """
+ Converts the value of the bool attribute to a Python bool
+ """
+ def __init__(self, cast_from_attr: Attribute) -> None: ...
+ @property
+ def static_typeid(self) -> TypeID: ...
+ @property
+ def type(self) -> Type: ...
+ @property
+ def typeid(self) -> TypeID: ...
+ @property
+ def value(self) -> bool:
+ """
+ Returns the value of the bool attribute
+ """
+
+class ComplexType(Type):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(arg0: Type) -> ComplexType:
+ """
+ Create a complex type
+ """
+ @staticmethod
+ def isinstance(other: Type) -> bool: ...
+ def __init__(self, cast_from_type: Type) -> None: ...
+ @property
+ def element_type(self) -> Type:
+ """
+ Returns element type.
+ """
+ @property
+ def typeid(self) -> TypeID: ...
+
+class Context:
+ current: ClassVar[Context] = ... # read-only
+ allow_unregistered_dialects: bool
+ @staticmethod
+ def _get_live_count() -> int: ...
+ def _CAPICreate(self) -> object: ...
+ def __enter__(self) -> Context: ...
+ def __exit__(self, arg0: Any, arg1: Any, arg2: Any) -> None: ...
+ def __init__(self) -> None: ...
+ def _clear_live_operations(self) -> int: ...
+ def _get_context_again(self) -> Context: ...
+ def _get_live_module_count(self) -> int: ...
+ def _get_live_operation_count(self) -> int: ...
+ def _get_live_operation_objects(self) -> list[Operation]: ...
+ def append_dialect_registry(self, registry: DialectRegistry) -> None: ...
+ def attach_diagnostic_handler(
+ self, callback: Callable[[Diagnostic], bool]
+ ) -> DiagnosticHandler:
+ """
+ Attaches a diagnostic handler that will receive callbacks
+ """
+ def enable_multithreading(self, enable: bool) -> None: ...
+ def get_dialect_descriptor(self, dialect_name: str) -> DialectDescriptor:
+ """
+ Gets or loads a dialect by name, returning its descriptor object
+ """
+ def is_registered_operation(self, operation_name: str) -> bool: ...
+ def load_all_available_dialects(self) -> None: ...
+ @property
+ def _CAPIPtr(self) -> object: ...
+ @property
+ def d(self) -> Dialects:
+ """
+ Alias for 'dialect'
+ """
+ @property
+ def dialects(self) -> Dialects:
+ """
+ Gets a container for accessing dialects by name
+ """
+
+class DenseBoolArrayAttr(Attribute):
+ @staticmethod
+ def get(
+ values: Sequence[bool], context: Context | None = None
+ ) -> DenseBoolArrayAttr:
+ """
+ Gets a uniqued dense array attribute
+ """
+ @staticmethod
+ def isinstance(other: Attribute) -> bool: ...
+ def __add__(self, arg0: list) -> DenseBoolArrayAttr: ...
+ def __getitem__(self, arg0: int) -> bool: ...
+ def __init__(self, cast_from_attr: Attribute) -> None: ...
+ def __iter__(
+ self,
+ ) -> DenseBoolArrayIterator: ...
+ def __len__(self) -> int: ...
+ @property
+ def static_typeid(self) -> TypeID: ...
+ @property
+ def type(self) -> Type: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class DenseBoolArrayIterator:
+ def __iter__(self) -> DenseBoolArrayIterator: ...
+ def __next__(self) -> bool: ...
+
+class DenseElementsAttr(Attribute):
+ @staticmethod
+ def get(
+ array: Buffer,
+ signless: bool = True,
+ type: Type | None = None,
+ shape: list[int] | None = None,
+ context: Context | None = None,
+ ) -> DenseElementsAttr:
+ """
+ Gets a DenseElementsAttr from a Python buffer or array.
+
+ When `type` is not provided, then some limited type inferencing is done based
+ on the buffer format. Support presently exists for 8/16/32/64 signed and
+ unsigned integers and float16/float32/float64. DenseElementsAttrs of these
+ types can also be converted back to a corresponding buffer.
+
+ For conversions outside of these types, a `type=` must be explicitly provided
+ and the buffer contents must be bit-castable to the MLIR internal
+ representation:
+
+ * Integer types (except for i1): the buffer must be byte aligned to the
+ next byte boundary.
+ * Floating point types: Must be bit-castable to the given floating point
+ size.
+ * i1 (bool): Bit packed into 8bit words where the bit pattern matches a
+ row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
+ this specification with: `np.packbits(ary, axis=None, bitorder='little')`.
+
+ If a single element buffer is passed (or for i1, a single byte with value 0
+ or 255), then a splat will be created.
+
+ Args:
+ array: The array or buffer to convert.
+ signless: If inferring an appropriate MLIR type, use signless types for
+ integers (defaults True).
+ type: Skips inference of the MLIR element type and uses this instead. The
+ storage size must be consistent with the actual contents of the buffer.
+ shape: Overrides the shape of the buffer when constructing the MLIR
+ shaped type. This is needed when the physical and logical shape differ (as
+ for i1).
+ context: Explicit context, if not from context manager.
+
+ Returns:
+ DenseElementsAttr on success.
+
+ Raises:
+ ValueError: If the type of the buffer or array cannot be matched to an MLIR
+ type or if the buffer does not meet expectations.
+ """
+ @staticmethod
+ def get_splat(shaped_type: Type, element_attr: Attribute) -> DenseElementsAttr:
+ """
+ Gets a DenseElementsAttr where all values are the same
+ """
+ @staticmethod
+ def isinstance(other: Attribute) -> bool: ...
+ def __init__(self, cast_from_attr: Attribute) -> None: ...
+ def __len__(self) -> int: ...
+ def get_splat_value(self) -> Attribute: ...
+ @property
+ def is_splat(self) -> bool: ...
+ @property
+ def static_typeid(self) -> TypeID: ...
+ @property
+ def type(self) -> Type: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class DenseF32ArrayAttr(Attribute):
+ @staticmethod
+ def get(
+ values: Sequence[float], context: Context | None = None
+ ) -> DenseF32ArrayAttr:
+ """
+ Gets a uniqued dense array attribute
+ """
+ @staticmethod
+ def isinstance(other: Attribute) -> bool: ...
+ def __add__(self, arg0: list) -> DenseF32ArrayAttr: ...
+ def __getitem__(self, arg0: int) -> float: ...
+ def __init__(self, cast_from_attr: Attribute) -> None: ...
+ def __iter__(
+ self,
+ ) -> DenseF32ArrayIterator: ...
+ def __len__(self) -> int: ...
+ @property
+ def static_typeid(self) -> TypeID: ...
+ @property
+ def type(self) -> Type: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class DenseF32ArrayIterator:
+ def __iter__(self) -> DenseF32ArrayIterator: ...
+ def __next__(self) -> float: ...
+
+class DenseF64ArrayAttr(Attribute):
+ @staticmethod
+ def get(
+ values: Sequence[float], context: Context | None = None
+ ) -> DenseF64ArrayAttr:
+ """
+ Gets a uniqued dense array attribute
+ """
+ @staticmethod
+ def isinstance(other: Attribute) -> bool: ...
+ def __add__(self, arg0: list) -> DenseF64ArrayAttr: ...
+ def __getitem__(self, arg0: int) -> float: ...
+ def __init__(self, cast_from_attr: Attribute) -> None: ...
+ def __iter__(
+ self,
+ ) -> DenseF64ArrayIterator: ...
+ def __len__(self) -> int: ...
+ @property
+ def static_typeid(self) -> TypeID: ...
+ @property
+ def type(self) -> Type: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class DenseF64ArrayIterator:
+ def __iter__(self) -> DenseF64ArrayIterator: ...
+ def __next__(self) -> float: ...
+
+class DenseFPElementsAttr(DenseElementsAttr):
+ @staticmethod
+ def get(
+ array: Buffer,
+ signless: bool = True,
+ type: Type | None = None,
+ shape: list[int] | None = None,
+ context: Context | None = None,
+ ) -> DenseFPElementsAttr: ...
+ @staticmethod
+ def isinstance(other: Attribute) -> bool: ...
+ def __getitem__(self, arg0: int) -> float: ...
+ def __init__(self, cast_from_attr: Attribute) -> None: ...
+ @property
+ def static_typeid(self) -> TypeID: ...
+ @property
+ def type(self) -> Type: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class DenseI16ArrayAttr(Attribute):
+ @staticmethod
+ def get(values: Sequence[int], context: Context | None = None) -> DenseI16ArrayAttr:
+ """
+ Gets a uniqued dense array attribute
+ """
+ @staticmethod
+ def isinstance(other: Attribute) -> bool: ...
+ def __add__(self, arg0: list) -> DenseI16ArrayAttr: ...
+ def __getitem__(self, arg0: int) -> int: ...
+ def __init__(self, cast_from_attr: Attribute) -> None: ...
+ def __iter__(
+ self,
+ ) -> DenseI16ArrayIterator: ...
+ def __len__(self) -> int: ...
+ @property
+ def static_typeid(self) -> TypeID: ...
+ @property
+ def type(self) -> Type: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class DenseI16ArrayIterator:
+ def __iter__(self) -> DenseI16ArrayIterator: ...
+ def __next__(self) -> int: ...
+
+class DenseI32ArrayAttr(Attribute):
+ @staticmethod
+ def get(values: Sequence[int], context: Context | None = None) -> DenseI32ArrayAttr:
+ """
+ Gets a uniqued dense array attribute
+ """
+ @staticmethod
+ def isinstance(other: Attribute) -> bool: ...
+ def __add__(self, arg0: list) -> DenseI32ArrayAttr: ...
+ def __getitem__(self, arg0: int) -> int: ...
+ def __init__(self, cast_from_attr: Attribute) -> None: ...
+ def __iter__(
+ self,
+ ) -> DenseI32ArrayIterator: ...
+ def __len__(self) -> int: ...
+ @property
+ def static_typeid(self) -> TypeID: ...
+ @property
+ def type(self) -> Type: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class DenseI32ArrayIterator:
+ def __iter__(self) -> DenseI32ArrayIterator: ...
+ def __next__(self) -> int: ...
+
+class DenseI64ArrayAttr(Attribute):
+ @staticmethod
+ def get(values: Sequence[int], context: Context | None = None) -> DenseI64ArrayAttr:
+ """
+ Gets a uniqued dense array attribute
+ """
+ @staticmethod
+ def isinstance(other: Attribute) -> bool: ...
+ def __add__(self, arg0: list) -> DenseI64ArrayAttr: ...
+ def __getitem__(self, arg0: int) -> int: ...
+ def __init__(self, cast_from_attr: Attribute) -> None: ...
+ def __iter__(
+ self,
+ ) -> DenseI16ArrayIterator: ...
+ def __len__(self) -> int: ...
+ @property
+ def static_typeid(self) -> TypeID: ...
+ @property
+ def type(self) -> Type: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class DenseI64ArrayIterator:
+ def __iter__(self) -> DenseI64ArrayIterator: ...
+ def __next__(self) -> int: ...
+
+class DenseI8ArrayAttr(Attribute):
+ @staticmethod
+ def get(values: Sequence[int], context: Context | None = None) -> DenseI8ArrayAttr:
+ """
+ Gets a uniqued dense array attribute
+ """
+ @staticmethod
+ def isinstance(other: Attribute) -> bool: ...
+ def __add__(self, arg0: list) -> DenseI8ArrayAttr: ...
+ def __getitem__(self, arg0: int) -> int: ...
+ def __init__(self, cast_from_attr: Attribute) -> None: ...
+ def __iter__(
+ self,
+ ) -> DenseI8ArrayIterator: ...
+ def __len__(self) -> int: ...
+ @property
+ def static_typeid(self) -> TypeID: ...
+ @property
+ def type(self) -> Type: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class DenseI8ArrayIterator:
+ def __iter__(self) -> DenseI8ArrayIterator: ...
+ def __next__(self) -> int: ...
+
+class DenseIntElementsAttr(DenseElementsAttr):
+ @staticmethod
+ def get(
+ array: Buffer,
+ signless: bool = True,
+ type: Type | None = None,
+ shape: list[int] | None = None,
+ context: Context | None = None,
+ ) -> DenseIntElementsAttr: ...
+ @staticmethod
+ def isinstance(other: Attribute) -> bool: ...
+ def __getitem__(self, arg0: int) -> int: ...
+ def __init__(self, cast_from_attr: Attribute) -> None: ...
+ @property
+ def static_typeid(self) -> TypeID: ...
+ @property
+ def type(self) -> Type: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class DenseResourceElementsAttr(Attribute):
+ @staticmethod
+ def get_from_buffer(
+ array: Buffer,
+ name: str,
+ type: Type,
+ alignment: int | None = None,
+ is_mutable: bool = False,
+ context: Context | None = None,
+ ) -> DenseResourceElementsAttr:
+ """
+ Gets a DenseResourceElementsAttr from a Python buffer or array.
+
+ This function does minimal validation or massaging of the data, and it is
+ up to the caller to ensure that the buffer meets the characteristics
+ implied by the shape.
+
+ The backing buffer and any user objects will be retained for the lifetime
+ of the resource blob. This is typically bounded to the context but the
+ resource can have a shorter lifespan depending on how it is used in
+ subsequent processing.
+
+ Args:
+ buffer: The array or buffer to convert.
+ name: Name to provide to the resource (may be changed upon collision).
+ type: The explicit ShapedType to construct the attribute with.
+ context: Explicit context, if not from context manager.
+
+ Returns:
+ DenseResourceElementsAttr on success.
+
+ Raises:
+ ValueError: If the type of the buffer or array cannot be matched to an MLIR
+ type or if the buffer does not meet expectations.
+ """
+ @staticmethod
+ def isinstance(other: Attribute) -> bool: ...
+ def __init__(self, cast_from_attr: Attribute) -> None: ...
+ @property
+ def static_typeid(self) -> TypeID: ...
+ @property
+ def type(self) -> Type: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class Diagnostic:
+ @property
+ def location(self) -> Location: ...
+ @property
+ def message(self) -> str: ...
+ @property
+ def notes(self) -> tuple[Diagnostic]: ...
+ @property
+ def severity(self) -> DiagnosticSeverity: ...
+
+class DiagnosticHandler:
+ def __enter__(self) -> DiagnosticHandler: ...
+ def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ...
+ def detach(self) -> None: ...
+ @property
+ def attached(self) -> bool: ...
+ @property
+ def had_error(self) -> bool: ...
+
+class DiagnosticInfo:
+ def __init__(self, arg0: Diagnostic) -> None: ...
+ @property
+ def location(self) -> Location: ...
+ @property
+ def message(self) -> str: ...
+ @property
+ def notes(self) -> list[DiagnosticInfo]: ...
+ @property
+ def severity(self) -> DiagnosticSeverity: ...
+
+class DiagnosticSeverity:
+ """
+ Members:
+
+ ERROR
+
+ WARNING
+
+ NOTE
+
+ REMARK
+ """
+
+ ERROR: ClassVar[DiagnosticSeverity] # value = <DiagnosticSeverity.ERROR: 0>
+ NOTE: ClassVar[DiagnosticSeverity] # value = <DiagnosticSeverity.NOTE: 2>
+ REMARK: ClassVar[DiagnosticSeverity] # value = <DiagnosticSeverity.REMARK: 3>
+ WARNING: ClassVar[DiagnosticSeverity] # value = <DiagnosticSeverity.WARNING: 1>
+ __members__: ClassVar[
+ dict[str, DiagnosticSeverity]
+ ] # value = {'ERROR': <DiagnosticSeverity.ERROR: 0>, 'WARNING': <DiagnosticSeverity.WARNING: 1>, 'NOTE': <DiagnosticSeverity.NOTE: 2>, 'REMARK': <DiagnosticSeverity.REMARK: 3>}
+ def __eq__(self, other: Any) -> bool: ...
+ def __getstate__(self) -> int: ...
+ def __hash__(self) -> int: ...
+ def __index__(self) -> int: ...
+ def __init__(self, value: int) -> None: ...
+ def __int__(self) -> int: ...
+ def __ne__(self, other: Any) -> bool: ...
+ def __setstate__(self, state: int) -> None: ...
+ @property
+ def name(self) -> str: ...
+ @property
+ def value(self) -> int: ...
+
+class Dialect:
+ def __init__(self, descriptor: DialectDescriptor) -> None: ...
+ @property
+ def descriptor(self) -> DialectDescriptor: ...
+
+class DialectDescriptor:
+ @property
+ def namespace(self) -> str: ...
+
+class DialectRegistry:
+ def _CAPICreate(self) -> DialectRegistry: ...
+ def __init__(self) -> None: ...
+ @property
+ def _CAPIPtr(self) -> object: ...
+
+class Dialects:
+ def __getattr__(self, arg0: str) -> Dialect: ...
+ def __getitem__(self, arg0: str) -> Dialect: ...
+
+class DictAttr(Attribute):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(value: dict = {}, context: Context | None = None) -> DictAttr:
+ """
+ Gets an uniqued Dict attribute
+ """
+ @staticmethod
+ def isinstance(other: Attribute) -> bool: ...
+ def __contains__(self, arg0: str) -> bool: ...
+ @overload
+ def __getitem__(self, arg0: str) -> Attribute: ...
+ @overload
+ def __getitem__(self, arg0: int) -> NamedAttribute: ...
+ def __init__(self, cast_from_attr: Attribute) -> None: ...
+ def __len__(self) -> int: ...
+ @property
+ def type(self) -> Type: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class FloatType(Type):
+ @staticmethod
+ def isinstance(other: Type) -> bool: ...
+ def __init__(self, cast_from_type: Type) -> None: ...
+ @property
+ def width(self) -> int:
+ """
+ Returns the width of the floating-point type.
+ """
+
+class F16Type(FloatType):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(context: Context | None = None) -> F16Type:
+ """
+ Create a f16 type.
+ """
+ @staticmethod
+ def isinstance(other: Type) -> bool: ...
+ def __init__(self, cast_from_type: Type) -> None: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class F32Type(FloatType):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(context: Context | None = None) -> F32Type:
+ """
+ Create a f32 type.
+ """
+ @staticmethod
+ def isinstance(other: Type) -> bool: ...
+ def __init__(self, cast_from_type: Type) -> None: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class F64Type(FloatType):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(context: Context | None = None) -> F64Type:
+ """
+ Create a f64 type.
+ """
+ @staticmethod
+ def isinstance(other: Type) -> bool: ...
+ def __init__(self, cast_from_type: Type) -> None: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class FlatSymbolRefAttr(Attribute):
+ @staticmethod
+ def get(value: str, context: Context | None = None) -> FlatSymbolRefAttr:
+ """
+ Gets a uniqued FlatSymbolRef attribute
+ """
+ @staticmethod
+ def isinstance(other: Attribute) -> bool: ...
+ def __init__(self, cast_from_attr: Attribute) -> None: ...
+ @property
+ def static_typeid(self) -> TypeID: ...
+ @property
+ def type(self) -> Type: ...
+ @property
+ def typeid(self) -> TypeID: ...
+ @property
+ def value(self) -> str:
+ """
+ Returns the value of the FlatSymbolRef attribute as a string
+ """
+
+class Float4E2M1FNType(FloatType):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(context: Context | None = None) -> Float4E2M1FNType:
+ """
+ Create a float4_e2m1fn type.
+ """
+ @staticmethod
+ def isinstance(other: Type) -> bool: ...
+ def __init__(self, cast_from_type: Type) -> None: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class Float6E2M3FNType(FloatType):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(context: Context | None = None) -> Float6E2M3FNType:
+ """
+ Create a float6_e2m3fn type.
+ """
+ @staticmethod
+ def isinstance(other: Type) -> bool: ...
+ def __init__(self, cast_from_type: Type) -> None: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class Float6E3M2FNType(FloatType):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(context: Context | None = None) -> Float6E3M2FNType:
+ """
+ Create a float6_e3m2fn type.
+ """
+ @staticmethod
+ def isinstance(other: Type) -> bool: ...
+ def __init__(self, cast_from_type: Type) -> None: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class Float8E3M4Type(FloatType):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(context: Context | None = None) -> Float8E3M4Type:
+ """
+ Create a float8_e3m4 type.
+ """
+ @staticmethod
+ def isinstance(other: Type) -> bool: ...
+ def __init__(self, cast_from_type: Type) -> None: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class Float8E4M3B11FNUZType(FloatType):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(context: Context | None = None) -> Float8E4M3B11FNUZType:
+ """
+ Create a float8_e4m3b11fnuz type.
+ """
+ @staticmethod
+ def isinstance(other: Type) -> bool: ...
+ def __init__(self, cast_from_type: Type) -> None: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class Float8E4M3FNType(FloatType):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(context: Context | None = None) -> Float8E4M3FNType:
+ """
+ Create a float8_e4m3fn type.
+ """
+ @staticmethod
+ def isinstance(other: Type) -> bool: ...
+ def __init__(self, cast_from_type: Type) -> None: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class Float8E4M3FNUZType(FloatType):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(context: Context | None = None) -> Float8E4M3FNUZType:
+ """
+ Create a float8_e4m3fnuz type.
+ """
+ @staticmethod
+ def isinstance(other: Type) -> bool: ...
+ def __init__(self, cast_from_type: Type) -> None: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class Float8E4M3Type(FloatType):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(context: Context | None = None) -> Float8E4M3Type:
+ """
+ Create a float8_e4m3 type.
+ """
+ @staticmethod
+ def isinstance(other: Type) -> bool: ...
+ def __init__(self, cast_from_type: Type) -> None: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class Float8E5M2FNUZType(FloatType):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(context: Context | None = None) -> Float8E5M2FNUZType:
+ """
+ Create a float8_e5m2fnuz type.
+ """
+ @staticmethod
+ def isinstance(other: Type) -> bool: ...
+ def __init__(self, cast_from_type: Type) -> None: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class Float8E5M2Type(FloatType):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(context: Context | None = None) -> Float8E5M2Type:
+ """
+ Create a float8_e5m2 type.
+ """
+ @staticmethod
+ def isinstance(other: Type) -> bool: ...
+ def __init__(self, cast_from_type: Type) -> None: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class Float8E8M0FNUType(FloatType):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(context: Context | None = None) -> Float8E8M0FNUType:
+ """
+ Create a float8_e8m0fnu type.
+ """
+ @staticmethod
+ def isinstance(other: Type) -> bool: ...
+ def __init__(self, cast_from_type: Type) -> None: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class FloatAttr(Attribute):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(type: Type, value: float, loc: Location | None = None) -> FloatAttr:
+ """
+ Gets an uniqued float point attribute associated to a type
+ """
+ @staticmethod
+ def get_f32(value: float, context: Context | None = None) -> FloatAttr:
+ """
+ Gets an uniqued float point attribute associated to a f32 type
+ """
+ @staticmethod
+ def get_f64(value: float, context: Context | None = None) -> FloatAttr:
+ """
+ Gets an uniqued float point attribute associated to a f64 type
+ """
+ @staticmethod
+ def isinstance(other: Attribute) -> bool: ...
+ def __float__(self: Attribute) -> float:
+ """
+ Converts the value of the float attribute to a Python float
+ """
+ def __init__(self, cast_from_attr: Attribute) -> None: ...
+ @property
+ def type(self) -> Type: ...
+ @property
+ def typeid(self) -> TypeID: ...
+ @property
+ def value(self) -> float:
+ """
+ Returns the value of the float attribute
+ """
+
+class FloatTF32Type(FloatType):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(context: Context | None = None) -> FloatTF32Type:
+ """
+ Create a tf32 type.
+ """
+ @staticmethod
+ def isinstance(other: Type) -> bool: ...
+ def __init__(self, cast_from_type: Type) -> None: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class FunctionType(Type):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(
+ inputs: list[Type], results: list[Type], context: Context | None = None
+ ) -> FunctionType:
+ """
+ Gets a FunctionType from a List of input and result types
+ """
+ @staticmethod
+ def isinstance(other: Type) -> bool: ...
+ def __init__(self, cast_from_type: Type) -> None: ...
+ @property
+ def inputs(self) -> list:
+ """
+ Returns the List of input types in the FunctionType.
+ """
+ @property
+ def results(self) -> list:
+ """
+ Returns the List of result types in the FunctionType.
+ """
+ @property
+ def typeid(self) -> TypeID: ...
+
+class IndexType(Type):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(context: Context | None = None) -> IndexType:
+ """
+ Create a index type.
+ """
+ @staticmethod
+ def isinstance(other: Type) -> bool: ...
+ def __init__(self, cast_from_type: Type) -> None: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class InferShapedTypeOpInterface:
+ def __init__(self, object: object, context: Context | None = None) -> None:
+ """
+ Creates an interface from a given operation/opview object or from a
+ subclass of OpView. Raises ValueError if the operation does not implement the
+ interface.
+ """
+ def inferReturnTypeComponents(
+ self,
+ operands: list | None = None,
+ attributes: Attribute | None = None,
+ properties=None,
+ regions: list[Region] | None = None,
+ context: Context | None = None,
+ loc: Location | None = None,
+ ) -> list[ShapedTypeComponents]:
+ """
+ Given the arguments required to build an operation, attempts to infer
+ its return shaped type components. Raises ValueError on failure.
+ """
+ @property
+ def operation(self) -> Operation:
+ """
+ Returns an Operation for which the interface was constructed.
+ """
+ @property
+ def opview(self) -> OpView:
+ """
+ Returns an OpView subclass _instance_ for which the interface was
+ constructed
+ """
+
+class InferTypeOpInterface:
+ def __init__(self, object: object, context: Context | None = None) -> None:
+ """
+ Creates an interface from a given operation/opview object or from a
+ subclass of OpView. Raises ValueError if the operation does not implement the
+ interface.
+ """
+ def inferReturnTypes(
+ self,
+ operands: list | None = None,
+ attributes: Attribute | None = None,
+ properties=None,
+ regions: list[Region] | None = None,
+ context: Context | None = None,
+ loc: Location | None = None,
+ ) -> list[Type]:
+ """
+ Given the arguments required to build an operation, attempts to infer
+ its return types. Raises ValueError on failure.
+ """
+ @property
+ def operation(self) -> Operation:
+ """
+ Returns an Operation for which the interface was constructed.
+ """
+ @property
+ def opview(self) -> OpView:
+ """
+ Returns an OpView subclass _instance_ for which the interface was
+ constructed
+ """
+
+class InsertionPoint:
+ current: ClassVar[InsertionPoint] = ... # read-only
+ @staticmethod
+ def at_block_begin(block: Block) -> InsertionPoint:
+ """
+ Inserts at the beginning of the block.
+ """
+ @staticmethod
+ def at_block_terminator(block: Block) -> InsertionPoint:
+ """
+ Inserts before the block terminator.
+ """
+ def __enter__(self) -> InsertionPoint: ...
+ def __exit__(self, arg0: Any, arg1: Any, arg2: Any) -> None: ...
+ @overload
+ def __init__(self, block: Block) -> None:
+ """
+ Inserts after the last operation but still inside the block.
+ """
+ @overload
+ def __init__(self, beforeOperation: _OperationBase) -> None:
+ """
+ Inserts before a referenced operation.
+ """
+ def insert(self, operation: _OperationBase) -> None:
+ """
+ Inserts an operation.
+ """
+ @property
+ def block(self) -> Block:
+ """
+ Returns the block that this InsertionPoint points to.
+ """
+ @property
+ def ref_operation(self) -> _OperationBase | None:
+ """
+ The reference operation before which new operations are inserted, or None if the insertion point is at the end of the block
+ """
+
+class IntegerAttr(Attribute):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(type: Type, value: int) -> IntegerAttr:
+ """
+ Gets an uniqued integer attribute associated to a type
+ """
+ @staticmethod
+ def isinstance(other: Attribute) -> bool: ...
+ def __init__(self, cast_from_attr: Attribute) -> None: ...
+ def __int__(self) -> int:
+ """
+ Converts the value of the integer attribute to a Python int
+ """
+ @property
+ def type(self) -> Type: ...
+ @property
+ def typeid(self) -> TypeID: ...
+ @property
+ def value(self) -> int:
+ """
+ Returns the value of the integer attribute
+ """
+
+class IntegerSet:
+ @staticmethod
+ def get(
+ num_dims: int,
+ num_symbols: int,
+ exprs: list,
+ eq_flags: list[bool],
+ context: Context | None = None,
+ ) -> IntegerSet: ...
+ @staticmethod
+ def get_empty(
+ num_dims: int, num_symbols: int, context: Context | None = None
+ ) -> IntegerSet: ...
+ def _CAPICreate(self) -> IntegerSet: ...
+ @overload
+ def __eq__(self, arg0: IntegerSet) -> bool: ...
+ @overload
+ def __eq__(self, arg0: object) -> bool: ...
+ def __hash__(self) -> int: ...
+ def dump(self) -> None:
+ """
+ Dumps a debug representation of the object to stderr.
+ """
+ def get_replaced(
+ self,
+ dim_exprs: list,
+ symbol_exprs: list,
+ num_result_dims: int,
+ num_result_symbols: int,
+ ) -> IntegerSet: ...
+ @property
+ def _CAPIPtr(self) -> object: ...
+ @property
+ def constraints(self) -> IntegerSetConstraintList: ...
+ @property
+ def context(self) -> Context: ...
+ @property
+ def is_canonical_empty(self) -> bool: ...
+ @property
+ def n_dims(self) -> int: ...
+ @property
+ def n_equalities(self) -> int: ...
+ @property
+ def n_inequalities(self) -> int: ...
+ @property
+ def n_inputs(self) -> int: ...
+ @property
+ def n_symbols(self) -> int: ...
+
+class IntegerSetAttr(Attribute):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(integer_set) -> IntegerSetAttr:
+ """
+ Gets an attribute wrapping an IntegerSet.
+ """
+ @staticmethod
+ def isinstance(other: Attribute) -> bool: ...
+ def __init__(self, cast_from_attr: Attribute) -> None: ...
+ @property
+ def type(self) -> Type: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class IntegerSetConstraint:
+ def __init__(self, *args, **kwargs) -> None: ...
+ @property
+ def expr(self) -> AffineExpr: ...
+ @property
+ def is_eq(self) -> bool: ...
+
+class IntegerSetConstraintList:
+ def __init__(self, *args, **kwargs) -> None: ...
+ def __add__(self, arg0: IntegerSetConstraintList) -> list[IntegerSetConstraint]: ...
+ @overload
+ def __getitem__(self, arg0: int) -> IntegerSetConstraint: ...
+ @overload
+ def __getitem__(self, arg0: slice) -> IntegerSetConstraintList: ...
+ def __len__(self) -> int: ...
+
+class IntegerType(Type):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get_signed(width: int, context: Context | None = None) -> IntegerType:
+ """
+ Create a signed integer type
+ """
+ @staticmethod
+ def get_signless(width: int, context: Context | None = None) -> IntegerType:
+ """
+ Create a signless integer type
+ """
+ @staticmethod
+ def get_unsigned(width: int, context: Context | None = None) -> IntegerType:
+ """
+ Create an unsigned integer type
+ """
+ @staticmethod
+ def isinstance(other: Type) -> bool: ...
+ def __init__(self, cast_from_type: Type) -> None: ...
+ @property
+ def is_signed(self) -> bool:
+ """
+ Returns whether this is a signed integer
+ """
+ @property
+ def is_signless(self) -> bool:
+ """
+ Returns whether this is a signless integer
+ """
+ @property
+ def is_unsigned(self) -> bool:
+ """
+ Returns whether this is an unsigned integer
+ """
+ @property
+ def typeid(self) -> TypeID: ...
+ @property
+ def width(self) -> int:
+ """
+ Returns the width of the integer type
+ """
+
+class Location:
+ current: ClassVar[Location] = ... # read-only
+ __hash__: ClassVar[None] = None
+ @staticmethod
+ def callsite(
+ callee: Location, frames: Sequence[Location], context: Context | None = None
+ ) -> Location:
+ """
+ Gets a Location representing a caller and callsite
+ """
+ @staticmethod
+ def file(
+ filename: str, line: int, col: int, context: Context | None = None
+ ) -> Location:
+ """
+ Gets a Location representing a file, line and column
+ """
+ @staticmethod
+ def from_attr(attribute: Attribute, context: Context | None = None) -> Location:
+ """
+ Gets a Location from a LocationAttr
+ """
+ @staticmethod
+ def fused(
+ locations: Sequence[Location],
+ metadata: Attribute | None = None,
+ context: Context | None = None,
+ ) -> Location:
+ """
+ Gets a Location representing a fused location with optional metadata
+ """
+ @staticmethod
+ def name(
+ name: str,
+ childLoc: Location | None = None,
+ context: Context | None = None,
+ ) -> Location:
+ """
+ Gets a Location representing a named location with optional child location
+ """
+ @staticmethod
+ def unknown(context: Context | None = None) -> Location:
+ """
+ Gets a Location representing an unknown location
+ """
+ def _CAPICreate(self) -> Location: ...
+ def __enter__(self) -> Location: ...
+ @overload
+ def __eq__(self, arg0: Location) -> bool: ...
+ @overload
+ def __eq__(self, arg0: Location) -> bool: ...
+ def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ...
+ def emit_error(self, message: str) -> None:
+ """
+ Emits an error at this location
+ """
+ @property
+ def _CAPIPtr(self) -> object: ...
+ @property
+ def attr(self) -> Attribute:
+ """
+ Get the underlying LocationAttr
+ """
+ @property
+ def context(self) -> Context:
+ """
+ Context that owns the Location
+ """
+
+class MemRefType(ShapedType):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(
+ shape: list[int],
+ element_type: Type,
+ layout: Attribute = None,
+ memory_space: Attribute = None,
+ loc: Location | None = None,
+ ) -> MemRefType:
+ """
+ Create a memref type
+ """
+ @staticmethod
+ def isinstance(other: Type) -> bool: ...
+ def __init__(self, cast_from_type: Type) -> None: ...
+ @property
+ def affine_map(self) -> AffineMap:
+ """
+ The layout of the MemRef type as an affine map.
+ """
+ @property
+ def layout(self) -> Attribute:
+ """
+ The layout of the MemRef type.
+ """
+ @property
+ def memory_space(self) -> Attribute | None:
+ """
+ Returns the memory space of the given MemRef type.
+ """
+ @property
+ def typeid(self) -> TypeID: ...
+ def get_strides_and_offset(self) -> tuple[list[int], int]:
+ """
+ The strides and offset of the MemRef type.
+ """
+
+class Module:
+ @staticmethod
+ def create(loc: Location | None = None) -> Module:
+ """
+ Creates an empty module
+ """
+ @staticmethod
+ def parse(asm: str | bytes, context: Context | None = None) -> Module:
+ """
+ Parses a module's assembly format from a string.
+
+ Returns a new MlirModule or raises an MLIRError if the parsing fails.
+
+ See also: https://mlir.llvm.org/docs/LangRef/
+ """
+ @staticmethod
+ def parseFile(path: str, context: Context | None = None) -> Module:
+ """
+ Parses a module's assembly format from file.
+
+ Returns a new MlirModule or raises an MLIRError if the parsing fails.
+
+ See also: https://mlir.llvm.org/docs/LangRef/
+ """
+ def _CAPICreate(self) -> Any: ...
+ def __str__(self) -> str:
+ """
+ Gets the assembly form of the operation with default options.
+
+ If more advanced control over the assembly formatting or I/O options is needed,
+ use the dedicated print or get_asm method, which supports keyword arguments to
+ customize behavior.
+ """
+ def dump(self) -> None:
+ """
+ Dumps a debug representation of the object to stderr.
+ """
+ @property
+ def _CAPIPtr(self) -> object: ...
+ @property
+ def body(self) -> Block:
+ """
+ Return the block for this module
+ """
+ @property
+ def context(self) -> Context:
+ """
+ Context that created the Module
+ """
+ @property
+ def operation(self) -> Operation:
+ """
+ Accesses the module as an operation
+ """
+
+class MLIRError(Exception):
+ def __init__(
+ self, message: str, error_diagnostics: list[DiagnosticInfo]
+ ) -> None: ...
+
+class NamedAttribute:
+ @property
+ def attr(self) -> Attribute:
+ """
+ The underlying generic attribute of the NamedAttribute binding
+ """
+ @property
+ def name(self) -> str:
+ """
+ The name of the NamedAttribute binding
+ """
+
+class NoneType(Type):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(context: Context | None = None) -> NoneType:
+ """
+ Create a none type.
+ """
+ @staticmethod
+ def isinstance(other: Type) -> bool: ...
+ def __init__(self, cast_from_type: Type) -> None: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class OpAttributeMap:
+ def __contains__(self, arg0: str) -> bool: ...
+ def __delitem__(self, arg0: str) -> None: ...
+ @overload
+ def __getitem__(self, arg0: str) -> Attribute: ...
+ @overload
+ def __getitem__(self, arg0: int) -> NamedAttribute: ...
+ def __len__(self) -> int: ...
+ def __setitem__(self, arg0: str, arg1: Attribute) -> None: ...
+
+class OpOperand:
+ @property
+ def operand_number(self) -> int: ...
+ @property
+ def owner(self) -> _OperationBase: ...
+
+class OpOperandIterator:
+ def __iter__(self) -> OpOperandIterator: ...
+ def __next__(self) -> OpOperand: ...
+
+class OpOperandList:
+ def __add__(self, arg0: OpOperandList) -> list[Value]: ...
+ @overload
+ def __getitem__(self, arg0: int) -> Value: ...
+ @overload
+ def __getitem__(self, arg0: slice) -> OpOperandList: ...
+ def __len__(self) -> int: ...
+ def __setitem__(self, arg0: int, arg1: Value) -> None: ...
+
+class OpResult(Value):
+ @staticmethod
+ def isinstance(other_value: Value) -> bool: ...
+ def __init__(self, value: Value) -> None: ...
+ @staticmethod
+ def isinstance(arg: Any) -> bool: ...
+ @property
+ def owner(self) -> _OperationBase: ...
+ @property
+ def result_number(self) -> int: ...
+
+class OpResultList:
+ def __add__(self, arg0: OpResultList) -> list[OpResult]: ...
+ @overload
+ def __getitem__(self, arg0: int) -> OpResult: ...
+ @overload
+ def __getitem__(self, arg0: slice) -> OpResultList: ...
+ def __len__(self) -> int: ...
+ @property
+ def owner(self) -> _OperationBase: ...
+ @property
+ def types(self) -> list[Type]: ...
+
+class OpSuccessors:
+ def __add__(self, arg0: OpSuccessors) -> list[Block]: ...
+ @overload
+ def __getitem__(self, arg0: int) -> Block: ...
+ @overload
+ def __getitem__(self, arg0: slice) -> OpSuccessors: ...
+ def __setitem__(self, arg0: int, arg1: Block) -> None: ...
+ def __len__(self) -> int: ...
+
+class OpView(_OperationBase):
+ _ODS_OPERAND_SEGMENTS: ClassVar[None] = ...
+ _ODS_REGIONS: ClassVar[tuple] = ...
+ _ODS_RESULT_SEGMENTS: ClassVar[None] = ...
+ def __init__(self, operation: _OperationBase) -> None: ...
+ @classmethod
+ def build_generic(
+ cls: type[_TOperation],
+ results: Sequence[Type] | None = None,
+ operands: Sequence[Value] | None = None,
+ attributes: dict[str, Attribute] | None = None,
+ successors: Sequence[Block] | None = None,
+ regions: int | None = None,
+ loc: Location | None = None,
+ ip: InsertionPoint | None = None,
+ ) -> _TOperation:
+ """
+ Builds a specific, generated OpView based on class level attributes.
+ """
+ @classmethod
+ def parse(
+ cls: type[_TOperation],
+ source: str | bytes,
+ *,
+ source_name: str = "",
+ context: Context | None = None,
+ ) -> _TOperation:
+ """
+ Parses a specific, generated OpView based on class level attributes
+ """
+ def __init__(self, operation: _OperationBase) -> None: ...
+ @property
+ def operation(self) -> _OperationBase: ...
+ @property
+ def opview(self) -> OpView: ...
+ @property
+ def successors(self) -> OpSuccessors:
+ """
+ Returns the List of Operation successors.
+ """
+
+class OpaqueAttr(Attribute):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(
+ dialect_namespace: str,
+ buffer: Buffer,
+ type: Type,
+ context: Context | None = None,
+ ) -> OpaqueAttr:
+ """
+ Gets an Opaque attribute.
+ """
+ @staticmethod
+ def isinstance(other: Attribute) -> bool: ...
+ def __init__(self, cast_from_attr: Attribute) -> None: ...
+ @property
+ def data(self) -> bytes:
+ """
+ Returns the data for the Opaqued attributes as `bytes`
+ """
+ @property
+ def dialect_namespace(self) -> str:
+ """
+ Returns the dialect namespace for the Opaque attribute as a string
+ """
+ @property
+ def type(self) -> Type: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class OpaqueType(Type):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(
+ dialect_namespace: str, buffer: str, context: Context | None = None
+ ) -> OpaqueType:
+ """
+ Create an unregistered (opaque) dialect type.
+ """
+ @staticmethod
+ def isinstance(other: Type) -> bool: ...
+ def __init__(self, cast_from_type: Type) -> None: ...
+ @property
+ def data(self) -> str:
+ """
+ Returns the data for the Opaque type as a string.
+ """
+ @property
+ def dialect_namespace(self) -> str:
+ """
+ Returns the dialect namespace for the Opaque type as a string.
+ """
+ @property
+ def typeid(self) -> TypeID: ...
+
+class Operation(_OperationBase):
+ def _CAPICreate(self) -> object: ...
+ @staticmethod
+ def create(
+ name: str,
+ results: Sequence[Type] | None = None,
+ operands: Sequence[Value] | None = None,
+ attributes: dict[str, Attribute] | None = None,
+ successors: Sequence[Block] | None = None,
+ regions: int = 0,
+ loc: Location | None = None,
+ ip: InsertionPoint | None = None,
+ infer_type: bool = False,
+ ) -> Operation:
+ """
+ Creates a new operation.
+
+ Args:
+ name: Operation name (e.g. "dialect.operation").
+ results: Sequence of Type representing op result types.
+ attributes: Dict of str:Attribute.
+ successors: List of Block for the operation's successors.
+ regions: Number of regions to create.
+ loc: A Location object (defaults to resolve from context manager).
+ ip: An InsertionPoint (defaults to resolve from context manager or set to
+ False to disable insertion, even with an insertion point set in the
+ context manager).
+ infer_type: Whether to infer result types.
+ Returns:
+ A new "detached" Operation object. Detached operations can be added
+ to blocks, which causes them to become "attached."
+ """
+ @staticmethod
+ def parse(
+ source: str | bytes, *, source_name: str = "", context: Context | None = None
+ ) -> Operation:
+ """
+ Parses an operation. Supports both text assembly format and binary bytecode format.
+ """
+ def _CAPICreate(self) -> object: ...
+ @property
+ def _CAPIPtr(self) -> object: ...
+ @property
+ def operation(self) -> Operation: ...
+ @property
+ def opview(self) -> OpView: ...
+ @property
+ def successors(self) -> OpSuccessors:
+ """
+ Returns the List of Operation successors.
+ """
+
+class OperationIterator:
+ def __iter__(self) -> OperationIterator: ...
+ def __next__(self) -> OpView: ...
+
+class OperationList:
+ def __getitem__(self, arg0: int) -> OpView: ...
+ def __iter__(self) -> OperationIterator: ...
+ def __len__(self) -> int: ...
+
+class RankedTensorType(ShapedType):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(
+ shape: list[int],
+ element_type: Type,
+ encoding: Attribute | None = None,
+ loc: Location | None = None,
+ ) -> RankedTensorType:
+ """
+ Create a ranked tensor type
+ """
+ @staticmethod
+ def isinstance(other: Type) -> bool: ...
+ def __init__(self, cast_from_type: Type) -> None: ...
+ @property
+ def encoding(self) -> Attribute | None: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class Region:
+ __hash__: ClassVar[None] = None
+ @overload
+ def __eq__(self, arg0: Region) -> bool: ...
+ @overload
+ def __eq__(self, arg0: object) -> bool: ...
+ def __iter__(self) -> BlockIterator:
+ """
+ Iterates over blocks in the region.
+ """
+ @property
+ def blocks(self) -> BlockList:
+ """
+ Returns a forward-optimized sequence of blocks.
+ """
+ @property
+ def owner(self) -> OpView:
+ """
+ Returns the operation owning this region.
+ """
+
+class RegionIterator:
+ def __iter__(self) -> RegionIterator: ...
+ def __next__(self) -> Region: ...
+
+class RegionSequence:
+ @overload
+ def __getitem__(self, arg0: int) -> Region: ...
+ @overload
+ def __getitem__(self, arg0: slice) -> Sequence[Region]: ...
+ def __iter__(self) -> RegionIterator: ...
+ def __len__(self) -> int: ...
+
+class ShapedType(Type):
+ @staticmethod
+ def get_dynamic_size() -> int:
+ """
+ Returns the value used to indicate dynamic dimensions in shaped types.
+ """
+ @staticmethod
+ def get_dynamic_stride_or_offset() -> int:
+ """
+ Returns the value used to indicate dynamic strides or offsets in shaped types.
+ """
+ @staticmethod
+ def is_dynamic_size(dim_size: int) -> bool:
+ """
+ Returns whether the given dimension size indicates a dynamic dimension.
+ """
+ @staticmethod
+ def is_static_size(dim_size: int) -> bool:
+ """
+ Returns whether the given dimension size indicates a static dimension.
+ """
+ @staticmethod
+ def isinstance(other: Type) -> bool: ...
+ def __init__(self, cast_from_type: Type) -> None: ...
+ def get_dim_size(self, dim: int) -> int:
+ """
+ Returns the dim-th dimension of the given ranked shaped type.
+ """
+ def is_dynamic_dim(self, dim: int) -> bool:
+ """
+ Returns whether the dim-th dimension of the given shaped type is dynamic.
+ """
+ def is_static_dim(self, dim: int) -> bool:
+ """
+ Returns whether the dim-th dimension of the given shaped type is static.
+ """
+ def is_dynamic_stride_or_offset(self, dim_size: int) -> bool:
+ """
+ Returns whether the given value is used as a placeholder for dynamic strides and offsets in shaped types.
+ """
+ def is_static_stride_or_offset(self, dim_size: int) -> bool:
+ """
+ Returns whether the given shaped type stride or offset value is statically-sized.
+ """
+ @property
+ def element_type(self) -> Type:
+ """
+ Returns the element type of the shaped type.
+ """
+ @property
+ def has_rank(self) -> bool:
+ """
+ Returns whether the given shaped type is ranked.
+ """
+ @property
+ def has_static_shape(self) -> bool:
+ """
+ Returns whether the given shaped type has a static shape.
+ """
+ @property
+ def rank(self) -> int:
+ """
+ Returns the rank of the given ranked shaped type.
+ """
+ @property
+ def shape(self) -> list[int]:
+ """
+ Returns the shape of the ranked shaped type as a List of integers.
+ """
+ @property
+ def static_typeid(self) -> TypeID: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class ShapedTypeComponents:
+ @staticmethod
+ @overload
+ def get(element_type: Type) -> ShapedTypeComponents:
+ """
+ Create an shaped type components object with only the element type.
+ """
+ @staticmethod
+ @overload
+ def get(shape: list, element_type: Type) -> ShapedTypeComponents:
+ """
+ Create a ranked shaped type components object.
+ """
+ @staticmethod
+ @overload
+ def get(
+ shape: list, element_type: Type, attribute: Attribute
+ ) -> ShapedTypeComponents:
+ """
+ Create a ranked shaped type components object with attribute.
+ """
+ @property
+ def element_type(self) -> Type:
+ """
+ Returns the element type of the shaped type components.
+ """
+ @property
+ def has_rank(self) -> bool:
+ """
+ Returns whether the given shaped type component is ranked.
+ """
+ @property
+ def rank(self) -> int:
+ """
+ Returns the rank of the given ranked shaped type components. If the shaped type components does not have a rank, None is returned.
+ """
+ @property
+ def shape(self) -> list[int]:
+ """
+ Returns the shape of the ranked shaped type components as a List of integers. Returns none if the shaped type component does not have a rank.
+ """
+
+class StridedLayoutAttr(Attribute):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(
+ offset: int, strides: list[int], context: Context | None = None
+ ) -> StridedLayoutAttr:
+ """
+ Gets a strided layout attribute.
+ """
+ @staticmethod
+ def get_fully_dynamic(
+ rank: int, context: Context | None = None
+ ) -> StridedLayoutAttr:
+ """
+ Gets a strided layout attribute with dynamic offset and strides of a given rank.
+ """
+ @staticmethod
+ def isinstance(other: Attribute) -> bool: ...
+ def __init__(self, cast_from_attr: Attribute) -> None: ...
+ @property
+ def offset(self) -> int:
+ """
+ Returns the value of the float point attribute
+ """
+ @property
+ def strides(self) -> list[int]:
+ """
+ Returns the value of the float point attribute
+ """
+ @property
+ def type(self) -> Type: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class StringAttr(Attribute):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(value: str | bytes, context: Context | None = None) -> StringAttr:
+ """
+ Gets a uniqued string attribute
+ """
+ @staticmethod
+ def get_typed(type: Type, value: str) -> StringAttr:
+ """
+ Gets a uniqued string attribute associated to a type
+ """
+ @staticmethod
+ def isinstance(other: Attribute) -> bool: ...
+ def __init__(self, cast_from_attr: Attribute) -> None: ...
+ @property
+ def type(self) -> Type: ...
+ @property
+ def typeid(self) -> TypeID: ...
+ @property
+ def value(self) -> str:
+ """
+ Returns the value of the string attribute
+ """
+ @property
+ def value_bytes(self) -> bytes:
+ """
+ Returns the value of the string attribute as `bytes`
+ """
+
+class SymbolRefAttr(Attribute):
+ @staticmethod
+ def get(symbols: list[str], context: Context | None = None) -> Attribute:
+ """
+ Gets a uniqued SymbolRef attribute from a List of symbol names
+ """
+ @staticmethod
+ def isinstance(other: Attribute) -> bool: ...
+ def __init__(self, cast_from_attr: Attribute) -> None: ...
+ @property
+ def static_typeid(self) -> TypeID: ...
+ @property
+ def type(self) -> Type: ...
+ @property
+ def typeid(self) -> TypeID: ...
+ @property
+ def value(self) -> list[str]:
+ """
+ Returns the value of the SymbolRef attribute as a List[str]
+ """
+
+class SymbolTable:
+ @staticmethod
+ def get_symbol_name(symbol: _OperationBase) -> Attribute: ...
+ @staticmethod
+ def get_visibility(symbol: _OperationBase) -> Attribute: ...
+ @staticmethod
+ def replace_all_symbol_uses(
+ old_symbol: str, new_symbol: str, from_op: _OperationBase
+ ) -> None: ...
+ @staticmethod
+ def set_symbol_name(symbol: _OperationBase, name: str) -> None: ...
+ @staticmethod
+ def set_visibility(symbol: _OperationBase, visibility: str) -> None: ...
+ @staticmethod
+ def walk_symbol_tables(
+ from_op: _OperationBase,
+ all_sym_uses_visible: bool,
+ callback: Callable[[_OperationBase, bool], None],
+ ) -> None: ...
+ def __contains__(self, arg0: str) -> bool: ...
+ def __delitem__(self, arg0: str) -> None: ...
+ def __getitem__(self, arg0: str) -> OpView: ...
+ def __init__(self, arg0: _OperationBase) -> None: ...
+ def erase(self, operation: _OperationBase) -> None: ...
+ def insert(self, operation: _OperationBase) -> Attribute: ...
+
+class TupleType(Type):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get_tuple(elements: list[Type], context: Context | None = None) -> TupleType:
+ """
+ Create a Tuple type
+ """
+ @staticmethod
+ def isinstance(other: Type) -> bool: ...
+ def __init__(self, cast_from_type: Type) -> None: ...
+ def get_type(self, pos: int) -> Type:
+ """
+ Returns the pos-th type in the Tuple type.
+ """
+ @property
+ def num_types(self) -> int:
+ """
+ Returns the number of types contained in a Tuple.
+ """
+ @property
+ def typeid(self) -> TypeID: ...
+
+class TypeAttr(Attribute):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(value: Type, context: Context | None = None) -> TypeAttr:
+ """
+ Gets a uniqued Type attribute
+ """
+ @staticmethod
+ def isinstance(other: Attribute) -> bool: ...
+ def __init__(self, cast_from_attr: Attribute) -> None: ...
+ @property
+ def type(self) -> Type: ...
+ @property
+ def typeid(self) -> TypeID: ...
+ @property
+ def value(self) -> Type: ...
+
+class TypeID:
+ def _CAPICreate(self) -> TypeID: ...
+ @overload
+ def __eq__(self, arg0: TypeID) -> bool: ...
+ @overload
+ def __eq__(self, arg0: Any) -> bool: ...
+ def __hash__(self) -> int: ...
+ @property
+ def _CAPIPtr(self) -> object: ...
+
+class UnitAttr(Attribute):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(context: Context | None = None) -> UnitAttr:
+ """
+ Create a Unit attribute.
+ """
+ @staticmethod
+ def isinstance(other: Attribute) -> bool: ...
+ def __init__(self, cast_from_attr: Attribute) -> None: ...
+ @property
+ def type(self) -> Type: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class UnrankedMemRefType(ShapedType):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(
+ element_type: Type, memory_space: Attribute, loc: Location | None = None
+ ) -> UnrankedMemRefType:
+ """
+ Create a unranked memref type
+ """
+ @staticmethod
+ def isinstance(other: Type) -> bool: ...
+ def __init__(self, cast_from_type: Type) -> None: ...
+ @property
+ def memory_space(self) -> Attribute | None:
+ """
+ Returns the memory space of the given Unranked MemRef type.
+ """
+ @property
+ def typeid(self) -> TypeID: ...
+
+class UnrankedTensorType(ShapedType):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(element_type: Type, loc: Location | None = None) -> UnrankedTensorType:
+ """
+ Create a unranked tensor type
+ """
+ @staticmethod
+ def isinstance(other: Type) -> bool: ...
+ def __init__(self, cast_from_type: Type) -> None: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class VectorType(ShapedType):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(
+ shape: list[int],
+ element_type: Type,
+ *,
+ scalable: list | None = None,
+ scalable_dims: list[int] | None = None,
+ loc: Location | None = None,
+ ) -> VectorType:
+ """
+ Create a vector type
+ """
+ @staticmethod
+ def isinstance(other: Type) -> bool: ...
+ def __init__(self, cast_from_type: Type) -> None: ...
+ @property
+ def scalable(self) -> bool: ...
+ @property
+ def scalable_dims(self) -> list[bool]: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
+class _GlobalDebug:
+ flag: ClassVar[bool] = False
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi
new file mode 100644
index 0000000..1010dad
--- /dev/null
+++ b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi
@@ -0,0 +1,36 @@
+# Originally imported via:
+# stubgen {...} -m mlir._mlir_libs._mlir.passmanager
+# Local modifications:
+# * Relative imports for cross-module references.
+# * Add __all__
+
+
+from . import ir as _ir
+
+__all__ = [
+ "PassManager",
+]
+
+class PassManager:
+ def __init__(self, context: _ir.Context | None = None) -> None: ...
+ def _CAPICreate(self) -> object: ...
+ def _testing_release(self) -> None: ...
+ def enable_ir_printing(
+ self,
+ print_before_all: bool = False,
+ print_after_all: bool = True,
+ print_module_scope: bool = False,
+ print_after_change: bool = False,
+ print_after_failure: bool = False,
+ large_elements_limit: int | None = None,
+ large_resource_limit: int | None = None,
+ enable_debug_info: bool = False,
+ print_generic_op_form: bool = False,
+ tree_printing_dir_path: str | None = None,
+ ) -> None: ...
+ def enable_verifier(self, enable: bool) -> None: ...
+ @staticmethod
+ def parse(pipeline: str, context: _ir.Context | None = None) -> PassManager: ...
+ def run(self, module: _ir._OperationBase) -> None: ...
+ @property
+ def _CAPIPtr(self) -> object: ...
diff --git a/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi b/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi
new file mode 100644
index 0000000..4b82c78
--- /dev/null
+++ b/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi
@@ -0,0 +1,24 @@
+# Originally imported via:
+# stubgen {...} -m mlir._mlir_libs._mlirExecutionEngine
+# Local modifications:
+# * Relative imports for cross-module references.
+# * Add __all__
+
+from collections.abc import Sequence
+
+from ._mlir import ir as _ir
+
+__all__ = [
+ "ExecutionEngine",
+]
+
+class ExecutionEngine:
+ def __init__(self, module: _ir.Module, opt_level: int = 2, shared_libs: Sequence[str] = ...) -> None: ...
+ def _CAPICreate(self) -> object: ...
+ def _testing_release(self) -> None: ...
+ def dump_to_object_file(self, file_name: str) -> None: ...
+ def raw_lookup(self, func_name: str) -> int: ...
+ def raw_register_runtime(self, name: str, callback: object) -> None: ...
+ def init() -> None: ...
+ @property
+ def _CAPIPtr(self) -> object: ...
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 6f37266..7ddc70a 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -2,9 +2,18 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from __future__ import annotations
+
+from collections.abc import Iterable
+from contextlib import contextmanager
+
from ._mlir_libs._mlir.ir import *
from ._mlir_libs._mlir.ir import _GlobalDebug
-from ._mlir_libs._mlir import register_type_caster, register_value_caster
+from ._mlir_libs._mlir import (
+ register_type_caster,
+ register_value_caster,
+ globals,
+)
from ._mlir_libs import (
get_dialect_registry,
append_load_on_create_dialect,
@@ -12,6 +21,30 @@ from ._mlir_libs import (
)
+@contextmanager
+def loc_tracebacks(*, max_depth: int | None = None) -> Iterable[None]:
+ """Enables automatic traceback-based locations for MLIR operations.
+
+ Operations created within this context will have their location
+ automatically set based on the Python call stack.
+
+ Args:
+ max_depth: Maximum number of frames to include in the location.
+ If None, the default limit is used.
+ """
+ old_enabled = globals.loc_tracebacks_enabled()
+ old_limit = globals.loc_tracebacks_frame_limit()
+ try:
+ globals.set_loc_tracebacks_frame_limit(max_depth)
+ if not old_enabled:
+ globals.set_loc_tracebacks_enabled(True)
+ yield
+ finally:
+ if not old_enabled:
+ globals.set_loc_tracebacks_enabled(False)
+ globals.set_loc_tracebacks_frame_limit(old_limit)
+
+
# Convenience decorator for registering user-friendly Attribute builders.
def register_attribute_builder(kind, replace=False):
def decorator_builder(func):
diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt
index abe0925..1a0075e 100644
--- a/mlir/python/requirements.txt
+++ b/mlir/python/requirements.txt
@@ -1,7 +1,6 @@
-nanobind>=2.9, <3.0
+nanobind>=2.4, <3.0
numpy>=1.19.5, <=2.1.2
pybind11>=2.10.0, <=2.13.6
PyYAML>=5.4.0, <=6.0.1
ml_dtypes>=0.1.0, <=0.6.0; python_version<"3.13" # provides several NumPy dtype extensions, including the bf16
-ml_dtypes>=0.5.0, <=0.6.0; python_version>="3.13"
-typing_extensions>=4.12.2
+ml_dtypes>=0.5.0, <=0.6.0; python_version>="3.13" \ No newline at end of file
diff --git a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
index a722acb..d362bb6 100644
--- a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
+++ b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
@@ -6,7 +6,7 @@ func.func @parallel(%arg0: index, %arg1: index, %arg2: index,
// CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
// CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
// CHECK: omp.wsloop {
- // CHECK: omp.loop_nest (%[[LVAR1:.*]], %[[LVAR2:.*]]) : index = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
+ // CHECK: omp.loop_nest (%[[LVAR1:.*]], %[[LVAR2:.*]]) : index = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) collapse(2) {
// CHECK: memref.alloca_scope
scf.parallel (%i, %j) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
// CHECK: "test.payload"(%[[LVAR1]], %[[LVAR2]]) : (index, index) -> ()
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 07d3351..2d33888 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1774,3 +1774,45 @@ func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> v
%0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32>
return %0 : vector<2x1x2xf32>
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.to_elements
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @to_elements_1d(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>
+// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[V0:.+]] = llvm.extractelement %[[ARG0]][%[[C0]] : i64] : vector<2xf32>
+// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[V1:.+]] = llvm.extractelement %[[ARG0]][%[[C1]] : i64] : vector<2xf32>
+// CHECK: return %[[V0]], %[[V1]]
+func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
+ %0:2 = vector.to_elements %arg0 : vector<2xf32>
+ return %0#0, %0#1 : f32, f32
+}
+
+// -----
+
+// NOTE: We unroll multi-dimensional to_elements ops with pattern
+// `UnrollToElements` and then convert the 1-D to_elements ops to llvm.
+
+// CHECK-LABEL: func @to_elements_2d(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32>
+// CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<2x2xf32> to !llvm.array<2 x vector<2xf32>>
+// CHECK: %[[V0:.+]] = llvm.extractvalue %[[CAST]][0] : !llvm.array<2 x vector<2xf32>>
+// CHECK: %[[V1:.+]] = llvm.extractvalue %[[CAST]][1] : !llvm.array<2 x vector<2xf32>>
+// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[R0:.+]] = llvm.extractelement %[[V0]][%[[C0]] : i64] : vector<2xf32>
+// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[R1:.+]] = llvm.extractelement %[[V0]][%[[C1]] : i64] : vector<2xf32>
+// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[R2:.+]] = llvm.extractelement %[[V1]][%[[C0]] : i64] : vector<2xf32>
+// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[R3:.+]] = llvm.extractelement %[[V1]][%[[C1]] : i64] : vector<2xf32>
+// CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]]
+func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
+ %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
+ return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index ed664a7..d6e36fa 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -43,38 +43,6 @@ gpu.module @create_nd_tdesc {
// CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
// CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
%src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-
- // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
- // CHECK: %[[INTPTR_2:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
- // CHECK: %[[OFFSET_W3:.*]] = arith.index_cast %[[ARG7]] : index to i32
- // CHECK: %[[OFFSET_H3:.*]] = arith.index_cast %[[ARG6]] : index to i32
- // CHECK: %[[C32_I64_6:.*]] = arith.constant 32 : i64
- // CHECK: %[[SHAPE_W3:.*]] = arith.trunci %[[C32_I64_6]] : i64 to i32
- // CHECK: %[[C16_I64_7:.*]] = arith.constant 16 : i64
- // CHECK: %[[SHAPE_H3:.*]] = arith.trunci %[[C16_I64_7]] : i64 to i32
- // CHECK: %[[BASE_ADDR3:.*]] = arith.index_castui %[[INTPTR_2]] : index to i64
- // CHECK: %[[VAR26:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
- // CHECK: %[[VAR27:.*]] = vector.insert %[[BASE_ADDR3]], %[[VAR26]] [0] : i64 into vector<4xi64>
- // CHECK: %[[VAR28:.*]] = vector.bitcast %[[VAR27]] : vector<4xi64> to vector<8xi32>
- // CHECK: %[[VAR29:.*]] = vector.insert %[[SHAPE_W3]], %[[VAR28]] [2] : i32 into vector<8xi32>
- // CHECK: %[[VAR30:.*]] = vector.insert %[[SHAPE_H3]], %[[VAR29]] [3] : i32 into vector<8xi32>
- // CHECK: %[[VAR31:.*]] = vector.insert %[[OFFSET_W3]], %[[VAR30]] [4] : i32 into vector<8xi32>
- // CHECK: %[[VAR32:.*]] = vector.insert %[[OFFSET_H3]], %[[VAR31]] [5] : i32 into vector<8xi32>
- %src_tdesc2 = xegpu.create_nd_tdesc %srcce[%offset1, %offset2] : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-
- // CHECK: %[[C8:.*]] = arith.constant 8 : index
- %c8 = arith.constant 8 : index
- // CHECK: %[[C16:.*]] = arith.constant 16 : index
- %c16 = arith.constant 16 : index
- // CHECK: %[[VAR33:.*]] = arith.index_cast %[[C8]] : index to i32
- // CHECK: %[[OLD_OFFSET_H:.*]] = vector.extract %[[PAYLOAD]][5] : i32 from vector<8xi32>
- // CHECK: %[[NEW_OFFSET_H:.*]] = arith.addi %[[OLD_OFFSET_H]], %[[VAR33]] : i32
- // CHECK: %[[NEW_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_H]], %[[PAYLOAD]] [5] : i32 into vector<8xi32>
- // CHECK: %[[VAR37:.*]] = arith.index_cast %[[C16]] : index to i32
- // CHECK: %[[OLD_OFFSET_H:.*]] = vector.extract %[[NEW_PAYLOAD]][4] : i32 from vector<8xi32>
- // CHECK: %[[NEW_OFFSET_H:.*]] = arith.addi %[[OLD_OFFSET_H]], %[[VAR37]] : i32
- // CHECK: %[[FINAL_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_H]], %[[NEW_PAYLOAD]] [4] : i32 into vector<8xi32>
- %updated_tdesc = xegpu.update_nd_offset %src_tdesc, [%c8, %c16] : !xegpu.tensor_desc<8x16xf32>
gpu.return
}
}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
index 0f67dc2..0b150e9 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
@@ -1,239 +1,73 @@
// RUN: mlir-opt %s --split-input-file -convert-xegpu-to-xevm | FileCheck %s
gpu.module @test {
-// CHECK-LABEL: @load_gather_ui64_src_constant_offset
-// CHECK-SAME: %[[ARG0:.*]]: ui64
-gpu.func @load_gather_ui64_src_constant_offset(%src: ui64) {
- // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index
- // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
- // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
- // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
- // CHECK: %[[VAR3:.*]] = arith.index_castui %[[VAR2]] : index to i64
- %0 = arith.constant dense<0> : vector<1xindex>
- // CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<1xi1>
- // CHECK: %[[VAR4:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1>
- %1 = arith.constant dense<1>: vector<1xi1>
- // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
- // CHECK: %[[VAR5:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64
- // CHECK: %[[VAR6:.*]] = arith.addi %[[VAR1]], %[[VAR5]] : i64
- %2 = xegpu.create_tdesc %src, %0 : ui64, vector<1xindex>
- -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
- // CHECK: %[[VAR7:.*]] = llvm.inttoptr %[[VAR6]] : i64 to !llvm.ptr<1>
- // CHECK: %[[VAR8:.*]] = scf.if %[[VAR4]] -> (vector<2xf32>) {
- // CHECK: %[[VAR9:.*]] = llvm.load %[[VAR7]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}
- // CHECK-SAME: : !llvm.ptr<1> -> vector<2xf32>
- // CHECK: scf.yield %[[VAR9]] : vector<2xf32>
- // CHECK: } else {
- // CHECK: %[[CST_1:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
- // CHECK: scf.yield %[[CST_1]] : vector<2xf32>
- %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
- : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<1xi1> -> vector<2xf32>
- gpu.return
-}
-}
-// -----
-
-gpu.module @test {
-// CHECK-LABEL: @load_gather_memref_src_constant_offset
-// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>
-gpu.func @load_gather_memref_src_constant_offset(%src: memref<256xf32>) {
- // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
- // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
- // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
- // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
- // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
- %0 = arith.constant dense<0> : vector<1xindex>
- // CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<1xi1>
- // CHECK: %[[VAR2:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1>
- %1 = arith.constant dense<1>: vector<1xi1>
- // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
- // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
- // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64
- %2 = xegpu.create_tdesc %src, %0 : memref<256xf32>, vector<1xindex>
- -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
- // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1>
- // CHECK: %[[VAR7:.*]] = scf.if %[[VAR2]] -> (f32) {
- // CHECK: %[[VAR8:.*]] = llvm.load %[[VAR6]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}
- // CHECK-SAME: : !llvm.ptr<1> -> vector<1xf32>
- // CHECK: %[[VAR9:.*]] = vector.extract %[[VAR8]][0] : f32 from vector<1xf32>
- // CHECK: scf.yield %[[VAR9]] : f32
- // CHECK: } else {
- // CHECK: %[[CST_1:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
- // CHECK: %[[VAR8:.*]] = vector.extract %[[CST_1:.*]][0] : f32 from vector<1xf32>
- // CHECK: scf.yield %[[VAR8]] : f32
- %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
- : !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>, vector<1xi1> -> vector<1xf32>
- gpu.return
-}
-}
-// -----
-
-gpu.module @test {
-// CHECK-LABEL: @load_gather_memref_src_value_offset
-// CHECK-SAME: %[[ARG0:.*]]: memref<256xf16>, %[[ARG1:.*]]: vector<1xindex>
-gpu.func @load_gather_memref_src_value_offset(%src: memref<256xf16>, %offset: vector<1xindex>) {
+// CHECK-LABEL: @load_gather_i64_src_value_offset
+// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex>
+gpu.func @load_gather_i64_src_value_offset(%src: i64, %offset: vector<1xindex>) {
// CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
// CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
- // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf16> -> index
- // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1>
// CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1>
%1 = arith.constant dense<1>: vector<1xi1>
// CHECK: %[[C2_I64:.*]] = arith.constant 2 : i64
- // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C2_I64]] : i64
- // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64
- %2 = xegpu.create_tdesc %src, %offset : memref<256xf16>, vector<1xindex>
- -> !xegpu.tensor_desc<1x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
- // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1>
- // CHECK: %[[VAR7:.*]] = scf.if %[[VAR2]] -> (vector<8xf16>) {
- // CHECK: %[[VAR8:.*]] = llvm.load %[[VAR6]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}
- // CHECK-SAME: : !llvm.ptr<1> -> vector<8xf16>
- // CHECK: scf.yield %[[VAR8]] : vector<8xf16>
- // CHECK: } else {
- // CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<8xf16>
- // CHECK: scf.yield %[[CST_0]] : vector<8xf16>
- %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
- : !xegpu.tensor_desc<1x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<1xi1> -> vector<8xf16>
- gpu.return
-}
-}
-// -----
-
-gpu.module @test {
-// CHECK-LABEL: @store_scatter_ui64_src_constant_offset
-// CHECK-SAME: %[[ARG0:.*]]: ui64
-gpu.func @store_scatter_ui64_src_constant_offset(%src: ui64) {
- // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index
- // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
- // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
- // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
- // CHECK: %[[VAR3:.*]] = arith.index_castui %[[VAR2]] : index to i64
- %0 = arith.constant dense<0> : vector<1xindex>
- // CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<1xi1>
- // CHECK: %[[VAR4:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1>
- %1 = arith.constant dense<1>: vector<1xi1>
- // CHECK: %[[CST_1:.*]] = arith.constant dense<2.900000e+00> : vector<2xf32>
- %2 = arith.constant dense<2.9>: vector<2xf32>
- // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
- // CHECK: %[[VAR5:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64
- // CHECK: %[[VAR6:.*]] = arith.addi %[[VAR1]], %[[VAR5]] : i64
- %3 = xegpu.create_tdesc %src, %0 : ui64, vector<1xindex>
- -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
- // CHECK: %[[VAR7:.*]] = llvm.inttoptr %[[VAR6]] : i64 to !llvm.ptr<1>
- // CHECK: scf.if %[[VAR4]] {
- // CHECK: llvm.store %[[CST_1]], %[[VAR7]] {cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>}
- // CHECK-SAME: : vector<2xf32>, !llvm.ptr<1>
- xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
- : vector<2xf32>, !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<1xi1>
- gpu.return
-}
-}
-// -----
-
-gpu.module @test {
-// CHECK-LABEL: @store_scatter_memref_src_constant_offset
-// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>
-gpu.func @store_scatter_memref_src_constant_offset(%src: memref<256xf32>) {
- // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
- // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
- // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
- // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
- // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
- %0 = arith.constant dense<0> : vector<1xindex>
- // CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<1xi1>
- // CHECK: %[[VAR2:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1>
- %1 = arith.constant dense<1>: vector<1xi1>
- // CHECK: %[[CST_1:.*]] = arith.constant dense<2.900390e+00> : vector<2xf16>
- %2 = arith.constant dense<2.9>: vector<2xf16>
- // CHECK: %[[C2_I64:.*]] = arith.constant 2 : i64
- // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C2_I64]] : i64
- // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64
- %3 = xegpu.create_tdesc %src, %0 : memref<256xf32>, vector<1xindex>
- -> !xegpu.tensor_desc<1x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
- // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1>
- // CHECK: scf.if %[[VAR2]] {
- // CHECK: llvm.store %[[CST_1]], %[[VAR6]] {cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>}
- // CHECK-SAME: : vector<2xf16>, !llvm.ptr<1>
- xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
- : vector<2xf16>, !xegpu.tensor_desc<1x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<1xi1>
+ // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C2_I64]] : i64
+ // CHECK: %[[VAR4:.*]] = arith.addi %[[ARG0]], %[[VAR3]] : i64
+ // CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR4]] : i64 to !llvm.ptr<1>
+ // CHECK: %[[VAR6:.*]] = scf.if %[[VAR2]] -> (f16) {
+ // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR5]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>} : !llvm.ptr<1> -> vector<1xf16>
+ // CHECK: %[[VAR8:.*]] = vector.extract %[[VAR7]][0] : f16 from vector<1xf16>
+ // CHECK: scf.yield %[[VAR8]] : f16
+ // CHECK: } else {
+ // CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<1xf16>
+ // CHECK: %[[VAR7:.*]] = vector.extract %[[CST_0]][0] : f16 from vector<1xf16>
+ // CHECK: scf.yield %[[VAR7]] : f16
+ // CHECK: }
+ %3 = xegpu.load %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : i64, vector<1xindex>, vector<1xi1> -> vector<1xf16>
gpu.return
}
}
// -----
gpu.module @test {
-// CHECK-LABEL: @store_scatter_memref_src_value_offset
-// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>, %[[ARG1:.*]]: vector<1xindex>
-gpu.func @store_scatter_memref_src_value_offset(%src: memref<256xf32>, %offset: vector<1xindex>) {
+// CHECK-LABEL: @store_scatter_i64_src_value_offset
+// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex>
+gpu.func @store_scatter_i64_src_value_offset(%src: i64, %offset: vector<1xindex>) {
// CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
// CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
- // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
- // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1>
// CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1>
%1 = arith.constant dense<1>: vector<1xi1>
// CHECK: %[[CST_0:.*]] = arith.constant dense<2.900000e+00> : vector<1xf32>
- // CHECK: %[[VAR7:.*]] = vector.extract %[[CST_0]][0] : f32 from vector<1xf32>
+ // CHECK: %[[VAR3:.*]] = vector.extract %[[CST_0]][0] : f32 from vector<1xf32>
%2 = arith.constant dense<2.9>: vector<1xf32>
// CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
// CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
- // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64
- %3 = xegpu.create_tdesc %src, %offset : memref<256xf32>, vector<1xindex>
- -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+ // CHECK: %[[VAR5:.*]] = arith.addi %[[ARG0]], %[[VAR4]] : i64
// CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1>
// CHECK: scf.if %[[VAR2]] {
- // CHECK: llvm.store %[[VAR7]], %[[VAR6]] {cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>}
- // CHECK-SAME: : f32, !llvm.ptr<1>
- xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
- : vector<1xf32>, !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>, vector<1xi1>
+ // CHECK: llvm.store %[[VAR3]], %[[VAR6]] {cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>} : f32, !llvm.ptr<1>
+ // CHECK: }
+ xegpu.store %2, %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : vector<1xf32>, i64, vector<1xindex>, vector<1xi1>
gpu.return
}
}
// -----
gpu.module @test {
-// CHECK-LABEL: @prefetch_ui64_src_constant_offset
-// CHECK-SAME: %[[ARG0:.*]]: ui64
-gpu.func @prefetch_ui64_src_constant_offset(%src: ui64) {
- // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index
- // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
- // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
- // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
- // CHECK: %[[VAR3:.*]] = arith.index_castui %[[VAR2]] : index to i64
- %0 = arith.constant dense<0> : vector<1xindex>
- // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
- // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64
- // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR1]], %[[VAR4]] : i64
- %1 = xegpu.create_tdesc %src, %0 : ui64, vector<1xindex>
- -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
- // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1>
- // CHECK: xevm.prefetch %[[VAR6]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}> : (!llvm.ptr<1>)
- xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
- : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
- gpu.return
-}
-}
-// -----
-
-gpu.module @test {
-// CHECK-LABEL: @prefetch_memref_src_constant_offset
-// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>
-gpu.func @prefetch_memref_src_constant_offset(%src: memref<256xf32>) {
- // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
- // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
- // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
- // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
+// CHECK-LABEL: @prefetch_i64_src_value_offset
+// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex>
+gpu.func @prefetch_i64_src_value_offset(%src: i64, %offset: vector<1xindex>) {
+ // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
// CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
- %0 = arith.constant dense<0> : vector<1xindex>
// CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
- // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
- // CHECK: %[[VAR4:.*]] = arith.addi %[[VAR2]], %[[VAR3]] : i64
- %1 = xegpu.create_tdesc %src, %0 : memref<256xf32>, vector<1xindex>
- -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
- // CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR4]] : i64 to !llvm.ptr<1>
- // CHECK: xevm.prefetch %[[VAR5]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}> : (!llvm.ptr<1>)
- xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
- : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+ // CHECK: %[[VAR2:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
+ // CHECK: %[[VAR3:.*]] = arith.addi %[[ARG0]], %[[VAR2]] : i64
+ // CHECK: %[[VAR4:.*]] = llvm.inttoptr %[[VAR3]] : i64 to !llvm.ptr<1>
+ // CHECK: xevm.prefetch %[[VAR4]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}> : (!llvm.ptr<1>)
+ xegpu.prefetch %src[%offset] <{offset_align_byte=4, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : i64, vector<1xindex>
gpu.return
}
}
@@ -250,12 +84,10 @@ gpu.func @prefetch_memref_src_value_offset(%src: memref<256xf32>, %offset: vecto
// CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
// CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
// CHECK: %[[VAR4:.*]] = arith.addi %[[VAR2]], %[[VAR3]] : i64
- %1 = xegpu.create_tdesc %src, %offset : memref<256xf32>, vector<1xindex>
- -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
// CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR4]] : i64 to !llvm.ptr<1>
// CHECK: xevm.prefetch %[[VAR5]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}> : (!llvm.ptr<1>)
- xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
- : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+ xegpu.prefetch %src[%offset] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : memref<256xf32>, vector<1xindex>
gpu.return
}
}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir b/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir
index b28a8c2..2a2b99f 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir
@@ -9,9 +9,9 @@ gpu.module @materializecast {
gpu.func @materialize_memref(%src: memref<128xf32>) kernel {
// CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<128xf32> -> index
// CHECK: %[[CASTED:.*]] = arith.index_castui %[[INTPTR]] : index to i64
- %offset = arith.constant dense<0> : vector<1xindex>
- %src_tdesc = xegpu.create_tdesc %src, %offset : memref<128xf32>, vector<1xindex>
- -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+ %offset = arith.constant 0 : index
+ %mask = arith.constant 1 : i1
+ %val = xegpu.load %src[%offset], %mask : memref<128xf32>, index, i1 -> f32
gpu.return
}
}
@@ -23,9 +23,9 @@ gpu.module @materializecast {
gpu.func @materialize_ui64(%src: ui64) kernel {
// CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index
// CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
- %offset = arith.constant dense<0> : vector<1xindex>
- %src_tdesc = xegpu.create_tdesc %src, %offset : ui64, vector<1xindex>
- -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+ %offset = arith.constant 0 : index
+ %mask = arith.constant 1 : i1
+ %val = xegpu.load %src[%offset], %mask : ui64, index, i1 -> vector<1xf32>
gpu.return
}
}
@@ -37,9 +37,9 @@ gpu.module @materializecast {
gpu.func @materialize_ui32(%src: ui32) kernel {
// CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui32 to index
// CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i32
- %offset = arith.constant dense<0> : vector<1xindex>
- %src_tdesc = xegpu.create_tdesc %src, %offset : ui32, vector<1xindex>
- -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+ %offset = arith.constant 0 : index
+ %mask = arith.constant 1 : i1
+ %val = xegpu.load %src[%offset], %mask : ui32, index, i1 -> vector<1xf32>
gpu.return
}
}
@@ -52,24 +52,12 @@ gpu.module @materializecast {
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
// CHECK: %[[VAR1:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
// CHECK: %[[VAR2:.*]] = arith.index_castui %[[VAR1]] : index to i64
+ // CHECK: %[[CST1:.*]] = arith.constant dense<true> : vector<1xi1>
+ // CHECK: %[[VAR3:.*]] = vector.extract %[[CST1]][0] : i1 from vector<1xi1>
%offset = arith.constant dense<0> : vector<1xindex>
- %src_tdesc = xegpu.create_tdesc %src, %offset : memref<128xf32>, vector<1xindex>
- -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+ %mask = arith.constant dense<1> : vector<1xi1>
+ %val = xegpu.load %src[%offset], %mask : memref<128xf32>, vector<1xindex>, vector<1xi1> -> vector<1xf32>
gpu.return
}
}
-// -----
-gpu.module @materializecast {
- // CHECK-LABEL: gpu.func @materialize_single_elem_vector
- // CHECK-SAME: %[[ARG0:.*]]: memref<128xf32>
- gpu.func @materialize_single_elem_vector(%src: memref<128xf32>) kernel {
- // CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1>
- // CHECK: %[[VAR1:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1>
- %mask = arith.constant dense<1>: vector<1xi1>
- %offset = arith.constant dense<0> : vector<1xindex>
- %0 = xegpu.load %src[%offset], %mask <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
- : memref<128xf32>, vector<1xindex>, vector<1xi1> -> vector<8xf32>
- gpu.return
- }
-}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir b/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir
deleted file mode 100644
index 6e59414..0000000
--- a/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir
+++ /dev/null
@@ -1,25 +0,0 @@
-// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
-
-gpu.module @update_offset {
- // CHECK-LABEL: gpu.func @update_offset
- // CHECK-SAME: %[[ARG0:.*]]: memref<128xf32>
- gpu.func @update_offset(%src: memref<128xf32>) kernel {
- // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<128xf32> -> index
- // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
- // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
- %offset = arith.constant dense<0> : vector<1xindex>
- // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
- // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
- // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
- // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
- // CHECK: %[[VAR4:.*]] = arith.addi %[[VAR2]], %[[VAR3]] : i64
- %src_tdesc = xegpu.create_tdesc %src, %offset : memref<128xf32>, vector<1xindex>
- -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
- // CHECK: %[[C4_I64_0:.*]] = arith.constant 4 : i64
- // CHECK: %[[VAR5:.*]] = arith.muli %[[VAR1]], %[[C4_I64_0]] : i64
- // CHECK: %[[VAR6:.*]] = arith.addi %[[VAR4]], %[[VAR5]] : i64
- %new_tdesc = xegpu.update_offset %src_tdesc, %offset : !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
- , vector<1xindex>
- gpu.return
- }
-}
diff --git a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
index bdbb12b..8f60a07 100644
--- a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
+++ b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
@@ -242,3 +242,22 @@ llvm.func @prefetch(%ptr: !llvm.ptr<1>) {
llvm.return
}
+// -----
+// CHECK-LABEL: llvm.func @llvm.load
+llvm.func @llvm.load(%a: !llvm.ptr<1>) -> i32 {
+ // CHECK: xevm.DecorationCacheControl =
+ // CHECK-SAME: 6442 : i32, 0 : i32, 1 : i32, 0 : i32
+ // CHECK-SAME: 6442 : i32, 1 : i32, 1 : i32, 0 : i32
+ %val = llvm.load %a {cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>} : !llvm.ptr<1> -> i32
+ llvm.return %val : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func @llvm.store
+llvm.func @llvm.store(%a: !llvm.ptr<1>, %val: i32) {
+ // CHECK: xevm.DecorationCacheControl =
+ // CHECK-SAME: 6443 : i32, 0 : i32, 2 : i32, 0 : i32
+ // CHECK-SAME: 6443 : i32, 1 : i32, 2 : i32, 0 : i32
+ llvm.store %val, %a {cache_control=#xevm.store_cache_control<L1wt_L2uc_L3wb>} : i32, !llvm.ptr<1>
+ llvm.return
+}
diff --git a/mlir/test/Dialect/Affine/loop-fusion-sibling.mlir b/mlir/test/Dialect/Affine/loop-fusion-sibling.mlir
new file mode 100644
index 0000000..937c855
--- /dev/null
+++ b/mlir/test/Dialect/Affine/loop-fusion-sibling.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{maximal mode=sibling}))' | FileCheck %s
+
+// Test cases specifically for sibling fusion. Note that sibling fusion test
+// cases also exist in loop-fusion*.mlir.
+
+// CHECK-LABEL: func @disjoint_stores
+func.func @disjoint_stores(%0: memref<8xf32>) {
+ %alloc_1 = memref.alloc() : memref<16xf32>
+ // The affine stores below are to different parts of the memrefs. Sibling
+ // fusion helps improve reuse and is valid.
+ affine.for %arg2 = 0 to 8 {
+ %2 = affine.load %0[%arg2] : memref<8xf32>
+ affine.store %2, %alloc_1[%arg2] : memref<16xf32>
+ }
+ affine.for %arg2 = 0 to 8 {
+ %2 = affine.load %0[%arg2] : memref<8xf32>
+ %3 = arith.negf %2 : f32
+ affine.store %3, %alloc_1[%arg2 + 8] : memref<16xf32>
+ }
+ // CHECK: affine.for
+ // CHECK-NOT: affine.for
+ return
+}
diff --git a/mlir/test/Dialect/LLVMIR/mmra.mlir b/mlir/test/Dialect/LLVMIR/mmra.mlir
new file mode 100644
index 0000000..95da966
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/mmra.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-opt %s -split-input-file --verify-roundtrip --mlir-print-local-scope | FileCheck %s
+
+// CHECK-LABEL: llvm.func @native
+// CHECK: llvm.load
+// CHECK-SAME: llvm.mmra = #llvm.mmra_tag<"foo":"bar">
+// CHECK: llvm.fence
+// CHECK-SAME: llvm.mmra = [#llvm.mmra_tag<"amdgpu-synchronize-as":"local">, #llvm.mmra_tag<"foo":"bar">]
+// CHECK: llvm.store
+// CHECK-SAME: llvm.mmra = #llvm.mmra_tag<"foo":"bar">
+
+#mmra_tag = #llvm.mmra_tag<"foo":"bar">
+
+llvm.func @native(%x: !llvm.ptr, %y: !llvm.ptr) {
+ %0 = llvm.load %x {llvm.mmra = #mmra_tag} : !llvm.ptr -> i32
+ llvm.fence syncscope("workgroup-one-as") release
+ {llvm.mmra = [#llvm.mmra_tag<"amdgpu-synchronize-as":"local">, #mmra_tag]}
+ llvm.store %0, %y {llvm.mmra = #llvm.mmra_tag<"foo":"bar">} : i32, !llvm.ptr
+ llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @foreign_op
+// CHECK: rocdl.load.to.lds
+// CHECK-SAME: llvm.mmra = #llvm.mmra_tag<"fake":"example">
+llvm.func @foreign_op(%g: !llvm.ptr<1>, %l: !llvm.ptr<3>) {
+ rocdl.load.to.lds %g, %l, 4, 0, 0 {llvm.mmra = #llvm.mmra_tag<"fake":"example">} : !llvm.ptr<1>
+ llvm.return
+}
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 986c384..763f41c 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -159,6 +159,29 @@ func.func @no_loops(%lb : index, %ub : index, %step : index) {
// -----
+func.func @collapse_size(%lb : index, %ub : index, %step : index) {
+ omp.wsloop {
+ // expected-error@+1 {{collapse value is larger than the number of loops}}
+ omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) collapse(4) {
+ omp.yield
+ }
+ }
+}
+
+// -----
+
+func.func @tiles_length(%lb : index, %ub : index, %step : index) {
+ omp.wsloop {
+ // expected-error@+1 {{op too few canonical loops for tile dimensions}}
+ omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) tiles(2, 4) {
+ omp.yield
+ }
+ }
+}
+
+
+// -----
+
func.func @inclusive_not_a_clause(%lb : index, %ub : index, %step : index) {
// expected-error @below {{expected '{'}}
omp.wsloop nowait inclusive {
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 3c2e0a3..60b1f61 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -376,6 +376,60 @@ func.func @omp_loop_nest_pretty_multiple(%lb1 : i32, %ub1 : i32, %step1 : i32,
return
}
+// CHECK-LABEL: omp_loop_nest_pretty_multiple_collapse
+func.func @omp_loop_nest_pretty_multiple_collapse(%lb1 : i32, %ub1 : i32, %step1 : i32,
+ %lb2 : i32, %ub2 : i32, %step2 : i32, %data1 : memref<?xi32>) -> () {
+
+ omp.wsloop {
+ // CHECK: omp.loop_nest (%{{.*}}, %{{.*}}) : i32 = (%{{.*}}, %{{.*}}) to (%{{.*}}, %{{.*}}) step (%{{.*}}, %{{.*}}) collapse(2)
+ omp.loop_nest (%iv1, %iv2) : i32 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) collapse(2) {
+ %1 = "test.payload"(%iv1) : (i32) -> (index)
+ %2 = "test.payload"(%iv2) : (i32) -> (index)
+ memref.store %iv1, %data1[%1] : memref<?xi32>
+ memref.store %iv2, %data1[%2] : memref<?xi32>
+ omp.yield
+ }
+ }
+
+ return
+}
+
+// CHECK-LABEL: omp_loop_nest_pretty_multiple_tiles
+func.func @omp_loop_nest_pretty_multiple_tiles(%lb1 : i32, %ub1 : i32, %step1 : i32,
+ %lb2 : i32, %ub2 : i32, %step2 : i32, %data1 : memref<?xi32>) -> () {
+
+ omp.wsloop {
+ // CHECK: omp.loop_nest (%{{.*}}, %{{.*}}) : i32 = (%{{.*}}, %{{.*}}) to (%{{.*}}, %{{.*}}) step (%{{.*}}, %{{.*}}) tiles(5, 10)
+ omp.loop_nest (%iv1, %iv2) : i32 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) tiles(5, 10) {
+ %1 = "test.payload"(%iv1) : (i32) -> (index)
+ %2 = "test.payload"(%iv2) : (i32) -> (index)
+ memref.store %iv1, %data1[%1] : memref<?xi32>
+ memref.store %iv2, %data1[%2] : memref<?xi32>
+ omp.yield
+ }
+ }
+
+ return
+}
+
+// CHECK-LABEL: omp_loop_nest_pretty_multiple_collapse_tiles
+func.func @omp_loop_nest_pretty_multiple_collapse_tiles(%lb1 : i32, %ub1 : i32, %step1 : i32,
+ %lb2 : i32, %ub2 : i32, %step2 : i32, %data1 : memref<?xi32>) -> () {
+
+ omp.wsloop {
+ // CHECK: omp.loop_nest (%{{.*}}, %{{.*}}) : i32 = (%{{.*}}, %{{.*}}) to (%{{.*}}, %{{.*}}) step (%{{.*}}, %{{.*}}) collapse(2) tiles(5, 10)
+ omp.loop_nest (%iv1, %iv2) : i32 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) collapse(2) tiles(5, 10) {
+ %1 = "test.payload"(%iv1) : (i32) -> (index)
+ %2 = "test.payload"(%iv2) : (i32) -> (index)
+ memref.store %iv1, %data1[%1] : memref<?xi32>
+ memref.store %iv2, %data1[%2] : memref<?xi32>
+ omp.yield
+ }
+ }
+
+ return
+}
+
// CHECK-LABEL: omp_wsloop
func.func @omp_wsloop(%lb : index, %ub : index, %step : index, %data_var : memref<i32>, %linear_var : i32, %chunk_var : i32) -> () {
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 5e8bfd0..fe697c8 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -538,3 +538,26 @@ func.func @test_vector_from_elements(%arg0: f32, %arg1: f32, %arg2: f32, %arg3:
%1 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x2xf32>
return %1 : vector<2x2xf32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @to_elements_1d(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>
+// CHECK: %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32>
+// CHECK: return %[[RES]]#0, %[[RES]]#1
+func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
+ %0:2 = vector.to_elements %arg0 : vector<2xf32>
+ return %0#0, %0#1 : f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @to_elements_2d(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32>
+// CHECK: %[[CAST:.+]] = vector.shape_cast %[[ARG0]]
+// CHECK: %[[RES:.+]]:4 = vector.to_elements %[[CAST]] : vector<4xf32>
+// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3
+func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
+ %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
+ return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
diff --git a/mlir/test/Dialect/Vector/lit.local.cfg b/mlir/test/Dialect/Vector/lit.local.cfg
new file mode 100644
index 0000000..3e9e8f8
--- /dev/null
+++ b/mlir/test/Dialect/Vector/lit.local.cfg
@@ -0,0 +1,2 @@
+# Skip the directory with input TD sequences.
+config.excludes = ["td"]
diff --git a/mlir/test/Dialect/Vector/td/unroll-elements.mlir b/mlir/test/Dialect/Vector/td/unroll-elements.mlir
new file mode 100644
index 0000000..40a90a3
--- /dev/null
+++ b/mlir/test/Dialect/Vector/td/unroll-elements.mlir
@@ -0,0 +1,11 @@
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @unroll_to_elements(%module_op: !transform.any_op {transform.readonly}) {
+ %f = transform.structured.match ops{["func.func"]} in %module_op
+ : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %f {
+ transform.apply_patterns.vector.transfer_permutation_patterns
+ transform.apply_patterns.vector.unroll_to_elements
+ } : !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
new file mode 100644
index 0000000..9ec0d76
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s -test-unroll-vector-to-elements -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -transform-preload-library='transform-library-paths=%p/td/unroll-elements.mlir' \
+// RUN: -transform-interpreter=entry-point=unroll_to_elements | FileCheck %s
+
+// CHECK-LABEL: func.func @to_elements_1d(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>
+// CHECK: %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32>
+// CHECK: return %[[RES]]#0, %[[RES]]#1
+func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
+ %0:2 = vector.to_elements %arg0 : vector<2xf32>
+ return %0#0, %0#1 : f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @to_elements_2d(
+// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32>
+// CHECK: %[[VEC0:.+]] = vector.extract %[[ARG0]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[VEC1:.+]] = vector.extract %[[ARG0]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[RES0:.+]]:2 = vector.to_elements %[[VEC0]] : vector<2xf32>
+// CHECK: %[[RES1:.+]]:2 = vector.to_elements %[[VEC1]] : vector<2xf32>
+// CHECK: return %[[RES0]]#0, %[[RES0]]#1, %[[RES1]]#0, %[[RES1]]#1
+func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
+ %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
+ return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 8750582..bb76392 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1856,3 +1856,72 @@ func.func @negative_warp_step_more_than_warp_size(%laneid: index, %buffer: memre
// CHECK-PROP-LABEL: @negative_warp_step_more_than_warp_size
// CHECK-PROP-NOT: vector.broadcast
// CHECK-PROP: vector.step : vector<64xindex>
+
+// -----
+
+func.func @warp_scf_if_no_yield_distribute(%buffer: memref<128xindex>, %pred : i1) {
+ %laneid = gpu.lane_id
+ %c0 = arith.constant 0 : index
+
+ gpu.warp_execute_on_lane_0(%laneid)[32] {
+ %seq = vector.step : vector<32xindex>
+ scf.if %pred {
+ vector.store %seq, %buffer[%c0] : memref<128xindex>, vector<32xindex>
+ }
+ gpu.yield
+ }
+ return
+}
+
+// CHECK-PROP-LABEL: func.func @warp_scf_if_no_yield_distribute(
+// CHECK-PROP-SAME: %[[ARG0:.+]]: memref<128xindex>, %[[ARG1:.+]]: i1
+// CHECK-PROP: scf.if %[[ARG1]] {
+// CHECK-PROP: gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : vector<1xindex>) {
+// CHECK-PROP: ^bb0(%[[ARG2:.+]]: vector<32xindex>):
+// CHECK-PROP: vector.store %[[ARG2]], %[[ARG0]][%{{.*}}] : memref<128xindex>, vector<32xindex>
+
+// -----
+
+func.func @warp_scf_if_distribute(%pred : i1) {
+ %laneid = gpu.lane_id
+ %c0 = arith.constant 0 : index
+
+ %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> vector<1xf32> {
+ %seq1 = vector.step : vector<32xindex>
+ %seq2 = arith.constant dense<2> : vector<32xindex>
+ %0 = scf.if %pred -> (vector<32xf32>) {
+ %1 = "some_op"(%seq1) : (vector<32xindex>) -> (vector<32xf32>)
+ scf.yield %1 : vector<32xf32>
+ } else {
+ %2 = "other_op"(%seq2) : (vector<32xindex>) -> (vector<32xf32>)
+ scf.yield %2 : vector<32xf32>
+ }
+ gpu.yield %0 : vector<32xf32>
+ }
+ "some_use"(%0) : (vector<1xf32>) -> ()
+
+ return
+}
+
+// CHECK-PROP-LABEL: func.func @warp_scf_if_distribute(
+// CHECK-PROP-SAME: %[[ARG0:.+]]: i1
+// CHECK-PROP: %[[SEQ2:.+]] = arith.constant dense<2> : vector<32xindex>
+// CHECK-PROP: %[[LANE_ID:.+]] = gpu.lane_id
+// CHECK-PROP: %[[SEQ1:.+]] = vector.broadcast %[[LANE_ID]] : index to vector<1xindex>
+// CHECK-PROP: %[[IF_YIELD_DIST:.+]] = scf.if %[[ARG0]] -> (vector<1xf32>) {
+// CHECK-PROP: %[[THEN_DIST:.+]] = gpu.warp_execute_on_lane_0(%[[LANE_ID]])[32] args(%[[SEQ1]] : vector<1xindex>) -> (vector<1xf32>) {
+// CHECK-PROP: ^bb0(%[[ARG1:.+]]: vector<32xindex>):
+// CHECK-PROP: %{{.*}} = "some_op"(%[[ARG1]]) : (vector<32xindex>) -> vector<32xf32>
+// CHECK-PROP: gpu.yield %{{.*}} : vector<32xf32>
+// CHECK-PROP: }
+// CHECK-PROP: scf.yield %[[THEN_DIST]] : vector<1xf32>
+// CHECK-PROP: } else {
+// CHECK-PROP: %[[ELSE_DIST:.+]] = gpu.warp_execute_on_lane_0(%[[LANE_ID]])[32] -> (vector<1xf32>) {
+// CHECK-PROP: %{{.*}} = "other_op"(%[[SEQ2]]) : (vector<32xindex>) -> vector<32xf32>
+// CHECK-PROP: gpu.yield %{{.*}} : vector<32xf32>
+// CHECK-PROP: }
+// CHECK-PROP: scf.yield %[[ELSE_DIST]] : vector<1xf32>
+// CHECK-PROP: }
+// CHECK-PROP: "some_use"(%[[IF_YIELD_DIST]]) : (vector<1xf32>) -> ()
+// CHECK-PROP: return
+// CHECK-PROP: }
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index a39aa90..60acea0 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -339,6 +339,63 @@ gpu.module @test {
}
// -----
+// CHECK-LABEL: gpu.func @scatter_ops_scf_yield({{.*}},
+// CHECK-SAME: %[[PREDICATE:.*]]: i1) {
+// CHECK: %[[DEFAULT:.*]] = arith.constant dense<1.200000e+01> : vector<8xf16>
+// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK: %[[PREDICATED_LOAD:.*]] = scf.if %[[PREDICATE]] -> (vector<8xf16>) {
+// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+// CHECK-NEXT: scf.yield %[[LOADED]] : vector<8xf16>
+// CHECK-NEXT: } else {
+// CHECK-NEXT: scf.yield %[[DEFAULT]] : vector<8xf16>
+// CHECK-NEXT: }
+// CHECK-NEXT: xegpu.store %[[PREDICATED_LOAD]], %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+gpu.module @test {
+ gpu.func @scatter_ops_scf_yield(%src: memref<256xf16>, %pred : i1) {
+ %1 = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1>: vector<16xi1>
+ %offset = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
+ %loaded = scf.if %pred -> (vector<16x8xf16>) {
+ %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}> {
+ layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+ } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+ scf.yield %3 : vector<16x8xf16>
+ } else {
+ %3 = arith.constant {
+ layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+ } dense<12.> : vector<16x8xf16>
+ scf.yield %3 : vector<16x8xf16>
+ } { layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]> }
+ xegpu.store %loaded, %src[%offset], %1 <{chunk_size=8}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+ gpu.return
+ }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @scatter_ops_scf_non_yield({{.*}}) {
+// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK: %[[PREDICATE:.*]] = llvm.mlir.poison : i1
+// CHECK: scf.if %[[PREDICATE]] {
+// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+// CHECK-NEXT: }
+gpu.module @test {
+ gpu.func @scatter_ops_scf_non_yield(%src: memref<256xf16>) {
+ %pred = llvm.mlir.poison : i1
+ %1 = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1>: vector<16xi1>
+ %offset = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
+ scf.if %pred {
+ %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}> {
+ layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+ } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+ xegpu.store %3, %src[%offset], %1 <{chunk_size=8}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+ }
+ gpu.return
+ }
+}
+
+// -----
// CHECK-LABEL: gpu.func @scatter_ops({{.*}}) {
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
// CHECK-NEXT: %[[LANE_OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
diff --git a/mlir/test/Target/LLVMIR/Import/metadata-mmra.ll b/mlir/test/Target/LLVMIR/Import/metadata-mmra.ll
new file mode 100644
index 0000000..5e1ed37
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/metadata-mmra.ll
@@ -0,0 +1,22 @@
+; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
+
+; CHECK-DAG: #[[$MMRA0:.+]] = #llvm.mmra_tag<"foo":"bar">
+; CHECK-DAG: #[[$MMRA1:.+]] = #llvm.mmra_tag<"amdgpu-synchronize-as":"local">
+
+; CHECK-LABEL: llvm.func @native
+define void @native(ptr %x, ptr %y) {
+ ; CHECK: llvm.load
+ ; CHECK-SAME: llvm.mmra = #[[$MMRA0]]
+ %v = load i32, ptr %x, align 4, !mmra !0
+ ; CHECK: llvm.fence
+ ; CHECK-SAME: llvm.mmra = [#[[$MMRA1]], #[[$MMRA0]]]
+ fence syncscope("workgroup-one-as") release, !mmra !2
+ ; CHECK: llvm.store {{.*}}, !llvm.ptr{{$}}
+ store i32 %v, ptr %y, align 4, !mmra !3
+ ret void
+}
+
+!0 = !{!"foo", !"bar"}
+!1 = !{!"amdgpu-synchronize-as", !"local"}
+!2 = !{!1, !0}
+!3 = !{}
diff --git a/mlir/test/Target/LLVMIR/mmra.mlir b/mlir/test/Target/LLVMIR/mmra.mlir
new file mode 100644
index 0000000..5864e0e
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/mmra.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: define void @native
+// CHECK: load
+// CHECK-SAME: !mmra ![[MMRA0:[0-9]+]]
+// CHECK: fence
+// CHECK-SAME: !mmra ![[MMRA1:[0-9]+]]
+// CHECK: store {{.*}}, align 4{{$}}
+
+#mmra_tag = #llvm.mmra_tag<"foo":"bar">
+
+llvm.func @native(%x: !llvm.ptr, %y: !llvm.ptr) {
+ %0 = llvm.load %x {llvm.mmra = #mmra_tag} : !llvm.ptr -> i32
+ llvm.fence syncscope("workgroup-one-as") release
+ {llvm.mmra = [#llvm.mmra_tag<"amdgpu-synchronize-as":"local">, #mmra_tag]}
+ llvm.store %0, %y {llvm.mmra = []} : i32, !llvm.ptr
+ llvm.return
+}
+
+// Actual MMRA metadata
+// CHECK-DAG: ![[MMRA0]] = !{!"foo", !"bar"}
+// CHECK-DAG: ![[MMRA_PART0:[0-9]+]] = !{!"amdgpu-synchronize-as", !"local"}
+// CHECK-DAG: ![[MMRA1]] = !{![[MMRA_PART0]], ![[MMRA0]]}
+
+// -----
+
+// CHECK-LABEL: define void @foreign_op
+// CHECK: call void @llvm.amdgcn.load.to.lds
+// CHECK-SAME: !mmra ![[MMRA0:[0-9]+]]
+llvm.func @foreign_op(%g: !llvm.ptr<1>, %l: !llvm.ptr<3>) {
+ rocdl.load.to.lds %g, %l, 4, 0, 0 {llvm.mmra = #llvm.mmra_tag<"fake":"example">} : !llvm.ptr<1>
+ llvm.return
+}
+
+// CHECK: ![[MMRA0]] = !{!"fake", !"example"}
diff --git a/mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir b/mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir
index b42e387..d84641f 100644
--- a/mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir
@@ -9,7 +9,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
%loop_lb = llvm.mlir.constant(0 : i32) : i32
%loop_step = llvm.mlir.constant(1 : index) : i32
omp.wsloop {
- omp.loop_nest (%arg1, %arg2) : i32 = (%loop_lb, %loop_lb) to (%loop_ub, %loop_ub) inclusive step (%loop_step, %loop_step) {
+ omp.loop_nest (%arg1, %arg2) : i32 = (%loop_lb, %loop_lb) to (%loop_ub, %loop_ub) inclusive step (%loop_step, %loop_step) collapse(2) {
%1 = llvm.add %arg1, %arg2 : i32
%2 = llvm.mul %arg2, %loop_ub overflow<nsw> : i32
%3 = llvm.add %arg1, %2 :i32
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index 3f4dcd5..27210bc 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -698,7 +698,7 @@ llvm.func @simd_simple(%lb : i64, %ub : i64, %step : i64, %arg0: !llvm.ptr) {
// CHECK-LABEL: @simd_simple_multiple
llvm.func @simd_simple_multiple(%lb1 : i64, %ub1 : i64, %step1 : i64, %lb2 : i64, %ub2 : i64, %step2 : i64, %arg0: !llvm.ptr, %arg1: !llvm.ptr) {
omp.simd {
- omp.loop_nest (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) inclusive step (%step1, %step2) {
+ omp.loop_nest (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) inclusive step (%step1, %step2) collapse(2) {
%3 = llvm.mlir.constant(2.000000e+00 : f32) : f32
// The form of the emitted IR is controlled by OpenMPIRBuilder and
// tested there. Just check that the right metadata is added and collapsed
@@ -736,7 +736,7 @@ llvm.func @simd_simple_multiple(%lb1 : i64, %ub1 : i64, %step1 : i64, %lb2 : i64
// CHECK-LABEL: @simd_simple_multiple_simdlen
llvm.func @simd_simple_multiple_simdlen(%lb1 : i64, %ub1 : i64, %step1 : i64, %lb2 : i64, %ub2 : i64, %step2 : i64, %arg0: !llvm.ptr, %arg1: !llvm.ptr) {
omp.simd simdlen(2) {
- omp.loop_nest (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+ omp.loop_nest (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) collapse(2) {
%3 = llvm.mlir.constant(2.000000e+00 : f32) : f32
// The form of the emitted IR is controlled by OpenMPIRBuilder and
// tested there. Just check that the right metadata is added.
@@ -760,7 +760,7 @@ llvm.func @simd_simple_multiple_simdlen(%lb1 : i64, %ub1 : i64, %step1 : i64, %l
// CHECK-LABEL: @simd_simple_multiple_safelen
llvm.func @simd_simple_multiple_safelen(%lb1 : i64, %ub1 : i64, %step1 : i64, %lb2 : i64, %ub2 : i64, %step2 : i64, %arg0: !llvm.ptr, %arg1: !llvm.ptr) {
omp.simd safelen(2) {
- omp.loop_nest (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+ omp.loop_nest (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) collapse(2) {
%3 = llvm.mlir.constant(2.000000e+00 : f32) : f32
%4 = llvm.getelementptr %arg0[%iv1] : (!llvm.ptr, i64) -> !llvm.ptr, f32
%5 = llvm.getelementptr %arg1[%iv2] : (!llvm.ptr, i64) -> !llvm.ptr, f32
@@ -779,7 +779,7 @@ llvm.func @simd_simple_multiple_safelen(%lb1 : i64, %ub1 : i64, %step1 : i64, %l
// CHECK-LABEL: @simd_simple_multiple_simdlen_safelen
llvm.func @simd_simple_multiple_simdlen_safelen(%lb1 : i64, %ub1 : i64, %step1 : i64, %lb2 : i64, %ub2 : i64, %step2 : i64, %arg0: !llvm.ptr, %arg1: !llvm.ptr) {
omp.simd simdlen(1) safelen(2) {
- omp.loop_nest (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+ omp.loop_nest (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) collapse(2) {
%3 = llvm.mlir.constant(2.000000e+00 : f32) : f32
%4 = llvm.getelementptr %arg0[%iv1] : (!llvm.ptr, i64) -> !llvm.ptr, f32
%5 = llvm.getelementptr %arg1[%iv2] : (!llvm.ptr, i64) -> !llvm.ptr, f32
@@ -1177,7 +1177,7 @@ llvm.func @collapse_wsloop(
// CHECK: store i32 %[[TOTAL_SUB_1]], ptr
// CHECK: call void @__kmpc_for_static_init_4u
omp.wsloop {
- omp.loop_nest (%arg0, %arg1, %arg2) : i32 = (%0, %1, %2) to (%3, %4, %5) step (%6, %7, %8) {
+ omp.loop_nest (%arg0, %arg1, %arg2) : i32 = (%0, %1, %2) to (%3, %4, %5) step (%6, %7, %8) collapse(3) {
%31 = llvm.load %20 : !llvm.ptr -> i32
%32 = llvm.add %31, %arg0 : i32
%33 = llvm.add %32, %arg1 : i32
@@ -1239,7 +1239,7 @@ llvm.func @collapse_wsloop_dynamic(
// CHECK: store i32 %[[TOTAL]], ptr
// CHECK: call void @__kmpc_dispatch_init_4u
omp.wsloop schedule(dynamic) {
- omp.loop_nest (%arg0, %arg1, %arg2) : i32 = (%0, %1, %2) to (%3, %4, %5) step (%6, %7, %8) {
+ omp.loop_nest (%arg0, %arg1, %arg2) : i32 = (%0, %1, %2) to (%3, %4, %5) step (%6, %7, %8) collapse(3) {
%31 = llvm.load %20 : !llvm.ptr -> i32
%32 = llvm.add %31, %arg0 : i32
%33 = llvm.add %32, %arg1 : i32
diff --git a/mlir/test/Target/LLVMIR/xevm.mlir b/mlir/test/Target/LLVMIR/xevm.mlir
index a3dd0b6..112d923 100644
--- a/mlir/test/Target/LLVMIR/xevm.mlir
+++ b/mlir/test/Target/LLVMIR/xevm.mlir
@@ -19,3 +19,35 @@ module {
// CHECK: ![[DECO2]] = !{i32 6442, i32 0, i32 1, i32 0}
// CHECK: ![[DECO3]] = !{i32 6442, i32 1, i32 1, i32 0}
+// -----
+module {
+ // CHECK-LABEL: define i32 @load(ptr addrspace(1)
+ // CHECK-SAME: %[[ARG0:.*]]) {
+ llvm.func @load(%arg0: !llvm.ptr<1>) -> i32 {
+ // CHECK: load i32, ptr addrspace(1) %[[ARG0]], align 4,
+ // CHECK-SAME: !spirv.DecorationCacheControlINTEL ![[DECO1:.*]]
+ %0 = llvm.load %arg0 {xevm.DecorationCacheControl = [[6442 : i32, 0 : i32, 1 : i32, 0 : i32], [6442 : i32, 1 : i32, 1 : i32, 0 : i32]]} : !llvm.ptr<1> -> i32
+ llvm.return %0 : i32
+ }
+}
+
+// CHECK: ![[DECO1]] = !{![[DECO2:.*]], ![[DECO3:.*]]}
+// CHECK: ![[DECO2]] = !{i32 6442, i32 0, i32 1, i32 0}
+// CHECK: ![[DECO3]] = !{i32 6442, i32 1, i32 1, i32 0}
+
+// -----
+module {
+ // CHECK-LABEL: define void @store(ptr addrspace(1)
+ // CHECK-SAME: %[[ARG0:.*]], i32 %[[ARG1:.*]]) {
+ llvm.func @store(%arg0: !llvm.ptr<1>, %arg1: i32) {
+ // CHECK: store i32 %[[ARG1]], ptr addrspace(1) %[[ARG0]], align 4,
+ // CHECK-SAME: !spirv.DecorationCacheControlINTEL ![[DECO1:.*]]
+ llvm.store %arg1, %arg0 {xevm.DecorationCacheControl = [[6443 : i32, 0 : i32, 2 : i32, 0 : i32], [6443 : i32, 1 : i32, 2 : i32, 0 : i32]]} : i32, !llvm.ptr<1>
+ llvm.return
+ }
+}
+
+// CHECK: ![[DECO1]] = !{![[DECO2:.*]], ![[DECO3:.*]]}
+// CHECK: ![[DECO2]] = !{i32 6443, i32 0, i32 2, i32 0}
+// CHECK: ![[DECO3]] = !{i32 6443, i32 1, i32 2, i32 0}
+
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h
index c05e15f..f2adca6 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.h
+++ b/mlir/test/lib/Dialect/Test/TestDialect.h
@@ -35,7 +35,6 @@
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
-#include "mlir/Interfaces/CopyOpInterface.h"
#include "mlir/Interfaces/DerivedAttributeOpInterface.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
diff --git a/mlir/test/lib/Dialect/Test/TestOps.h b/mlir/test/lib/Dialect/Test/TestOps.h
index b414b47..4201ade 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.h
+++ b/mlir/test/lib/Dialect/Test/TestOps.h
@@ -33,7 +33,6 @@
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
-#include "mlir/Interfaces/CopyOpInterface.h"
#include "mlir/Interfaces/DerivedAttributeOpInterface.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 231400e..5564264 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -23,7 +23,6 @@ include "mlir/IR/RegionKindInterface.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
-include "mlir/Interfaces/CopyOpInterface.td"
include "mlir/Interfaces/DataLayoutInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
@@ -2322,10 +2321,10 @@ def SideEffectWithRegionOp : TEST_Op<"side_effect_with_region_op",
}
//===----------------------------------------------------------------------===//
-// Test CopyOpInterface
+// Copy Operation Test
//===----------------------------------------------------------------------===//
-def CopyOp : TEST_Op<"copy", [CopyOpInterface]> {
+def CopyOp : TEST_Op<"copy", []> {
let description = [{
Represents a copy operation.
}];
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index bb1598e..d6596cd 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -808,6 +808,28 @@ struct TestUnrollVectorFromElements
}
};
+struct TestUnrollVectorToElements
+ : public PassWrapper<TestUnrollVectorToElements,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnrollVectorToElements)
+
+ StringRef getArgument() const final {
+ return "test-unroll-vector-to-elements";
+ }
+ StringRef getDescription() const final {
+ return "Test unrolling patterns for to_elements ops";
+ }
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<func::FuncDialect, vector::VectorDialect>();
+ }
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateVectorToElementsLoweringPatterns(patterns);
+ (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+ }
+};
+
struct TestFoldArithExtensionIntoVectorContractPatterns
: public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
OperationPass<func::FuncOp>> {
@@ -1083,6 +1105,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestUnrollVectorFromElements>();
+ PassRegistration<TestUnrollVectorToElements>();
+
PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
PassRegistration<TestVectorEmulateMaskedLoadStore>();
diff --git a/mlir/test/mlir-tblgen/op-decl-and-defs.td b/mlir/test/mlir-tblgen/op-decl-and-defs.td
index f213f50..87b41f9 100644
--- a/mlir/test/mlir-tblgen/op-decl-and-defs.td
+++ b/mlir/test/mlir-tblgen/op-decl-and-defs.td
@@ -543,3 +543,12 @@ def _BOp : NS_Op<"_op_with_leading_underscore_and_no_namespace", []>;
// REDUCE_EXC-NOT: NS::AOp declarations
// REDUCE_EXC-LABEL: NS::BOp declarations
+
+// CHECK-LABEL: _TypeInferredPropOp declarations
+def _TypeInferredPropOp : NS_Op<"type_inferred_prop_op_with_properties", [
+ AllTypesMatch<["value", "result"]>
+ ]> {
+ let arguments = (ins Property<"unsigned">:$prop, AnyType:$value);
+ let results = (outs AnyType:$result);
+ let hasCustomAssemblyFormat = 1;
+}
diff --git a/mlir/test/python/dialects/transform_vector_ext.py b/mlir/test/python/dialects/transform_vector_ext.py
index 5a648fe..28902b0 100644
--- a/mlir/test/python/dialects/transform_vector_ext.py
+++ b/mlir/test/python/dialects/transform_vector_ext.py
@@ -48,6 +48,8 @@ def non_configurable_patterns():
vector.ApplyLowerGatherPatternsOp()
# CHECK: transform.apply_patterns.vector.unroll_from_elements
vector.ApplyUnrollFromElementsPatternsOp()
+ # CHECK: transform.apply_patterns.vector.unroll_to_elements
+ vector.ApplyUnrollToElementsPatternsOp()
# CHECK: transform.apply_patterns.vector.lower_scan
vector.ApplyLowerScanPatternsOp()
# CHECK: transform.apply_patterns.vector.lower_shape_cast
diff --git a/mlir/test/python/python_pass.py b/mlir/test/python/python_pass.py
index c94f96e..50c4210 100644
--- a/mlir/test/python/python_pass.py
+++ b/mlir/test/python/python_pass.py
@@ -64,12 +64,12 @@ def testCustomPass():
"""
)
- def custom_pass_1(op):
+ def custom_pass_1(op, pass_):
print("hello from pass 1!!!", file=sys.stderr)
class CustomPass2:
- def __call__(self, m):
- apply_patterns_and_fold_greedily(m, frozen)
+ def __call__(self, op, pass_):
+ apply_patterns_and_fold_greedily(op, frozen)
custom_pass_2 = CustomPass2()
@@ -86,3 +86,17 @@ def testCustomPass():
# CHECK: llvm.mul
pm.add("convert-arith-to-llvm")
pm.run(module)
+
+ # test signal_pass_failure
+ def custom_pass_that_fails(op, pass_):
+ print("hello from pass that fails")
+ pass_.signal_pass_failure()
+
+ pm = PassManager("any")
+ pm.add(custom_pass_that_fails, "CustomPassThatFails")
+ # CHECK: hello from pass that fails
+ # CHECK: caught exception: Failure while executing pass pipeline
+ try:
+ pm.run(module)
+ except Exception as e:
+ print(f"caught exception: {e}")
diff --git a/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp b/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp
index 10d602f..712237b 100644
--- a/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp
+++ b/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp
@@ -10,8 +10,8 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllExtensions.h"
-#include "mlir/Tools/lsp-server-support/Protocol.h"
#include "mlir/Tools/mlir-lsp-server/MlirLspServerMain.h"
+#include "llvm/Support/LSP/Protocol.h"
using namespace mlir;
@@ -37,8 +37,8 @@ int main(int argc, char **argv) {
// Returns the registry, except in testing mode when the URI contains
// "-disable-lsp-registration". Testing for/example of registering dialects
// based on URI.
- auto registryFn = [&registry,
- &empty](const lsp::URIForFile &uri) -> DialectRegistry & {
+ auto registryFn = [&registry, &empty](
+ const llvm::lsp::URIForFile &uri) -> DialectRegistry & {
(void)empty;
#ifdef MLIR_INCLUDE_TESTS
if (uri.uri().contains("-disable-lsp-registration"))
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 8ea4eb7..4fdde76 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -3849,9 +3849,9 @@ void OpEmitter::genTypeInterfaceMethods() {
const InferredResultType &infer = op.getInferredResultType(i);
if (!infer.isArg())
continue;
- Operator::OperandOrAttribute arg =
- op.getArgToOperandOrAttribute(infer.getIndex());
- if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
+ Operator::OperandAttrOrProp arg =
+ op.getArgToOperandAttrOrProp(infer.getIndex());
+ if (arg.kind() == Operator::OperandAttrOrProp::Kind::Operand) {
maxAccessedIndex =
std::max(maxAccessedIndex, arg.operandOrAttributeIndex());
}
@@ -3877,17 +3877,16 @@ void OpEmitter::genTypeInterfaceMethods() {
if (infer.isArg()) {
// If this is an operand, just index into operand list to access the
// type.
- Operator::OperandOrAttribute arg =
- op.getArgToOperandOrAttribute(infer.getIndex());
- if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
+ Operator::OperandAttrOrProp arg =
+ op.getArgToOperandAttrOrProp(infer.getIndex());
+ if (arg.kind() == Operator::OperandAttrOrProp::Kind::Operand) {
typeStr = ("operands[" + Twine(arg.operandOrAttributeIndex()) +
"].getType()")
.str();
// If this is an attribute, index into the attribute dictionary.
- } else {
- auto *attr =
- cast<NamedAttribute *>(op.getArg(arg.operandOrAttributeIndex()));
+ } else if (auto *attr = dyn_cast<NamedAttribute *>(
+ op.getArg(arg.operandOrAttributeIndex()))) {
body << " ::mlir::TypedAttr odsInferredTypeAttr" << inferredTypeIdx
<< " = ";
if (op.getDialect().usePropertiesForAttributes()) {
@@ -3907,6 +3906,9 @@ void OpEmitter::genTypeInterfaceMethods() {
typeStr =
("odsInferredTypeAttr" + Twine(inferredTypeIdx) + ".getType()")
.str();
+ } else {
+ llvm::PrintFatalError(&op.getDef(),
+ "Properties cannot be used for type inference");
}
} else if (std::optional<StringRef> builder =
op.getResult(infer.getResultIndex())
diff --git a/mlir/unittests/Bytecode/BytecodeTest.cpp b/mlir/unittests/Bytecode/BytecodeTest.cpp
index 9ea6560..d7b442f 100644
--- a/mlir/unittests/Bytecode/BytecodeTest.cpp
+++ b/mlir/unittests/Bytecode/BytecodeTest.cpp
@@ -72,6 +72,8 @@ TEST(Bytecode, MultiModuleWithResource) {
ASSERT_TRUE(module);
// Write the module to bytecode.
+ // Ensure that reserveExtraSpace is called with the size needed to write the
+ // bytecode buffer.
MockOstream ostream;
EXPECT_CALL(ostream, reserveExtraSpace).WillOnce([&](uint64_t space) {
ostream.buffer = std::make_unique<std::byte[]>(space);
@@ -128,31 +130,28 @@ TEST(Bytecode, AlignmentFailure) {
ASSERT_TRUE(module);
// Write the module to bytecode.
- MockOstream ostream;
- EXPECT_CALL(ostream, reserveExtraSpace).WillOnce([&](uint64_t space) {
- ostream.buffer = std::make_unique<std::byte[]>(space);
- ostream.size = space;
- });
+ std::string serializedBytecode;
+ llvm::raw_string_ostream ostream(serializedBytecode);
ASSERT_TRUE(succeeded(writeBytecodeToFile(module.get(), ostream)));
// Create copy of buffer which is not aligned to requested resource alignment.
- std::string buffer((char *)ostream.buffer.get(),
- (char *)ostream.buffer.get() + ostream.size);
+ std::string buffer(serializedBytecode);
size_t bufferSize = buffer.size();
- // Increment into the buffer until we get to a power of 2 alignment that is
- // not 32 bit aligned.
+ // Increment into the buffer until we get to an address that is 2 byte aligned
+ // but not 32 byte aligned.
size_t pad = 0;
while (true) {
- if (llvm::isAddrAligned(Align(2), &buffer[pad]) &&
- !llvm::isAddrAligned(Align(32), &buffer[pad]))
+ if (llvm::isAddrAligned(Align(2), buffer.data() + pad) &&
+ !llvm::isAddrAligned(Align(32), buffer.data() + pad))
break;
pad++;
- buffer.reserve(bufferSize + pad);
+ // Pad the beginning of the buffer to push the start point to an unaligned
+ // value.
+ buffer.insert(0, 1, ' ');
}
- buffer.insert(0, pad, ' ');
StringRef alignedBuffer(buffer.data() + pad, bufferSize);
// Attach a diagnostic handler to get the error message.
diff --git a/mlir/unittests/CMakeLists.txt b/mlir/unittests/CMakeLists.txt
index c5f0d7e..89332bc 100644
--- a/mlir/unittests/CMakeLists.txt
+++ b/mlir/unittests/CMakeLists.txt
@@ -18,7 +18,6 @@ add_subdirectory(Support)
add_subdirectory(Rewrite)
add_subdirectory(TableGen)
add_subdirectory(Target)
-add_subdirectory(Tools)
add_subdirectory(Transforms)
if(MLIR_ENABLE_EXECUTION_ENGINE)
diff --git a/mlir/unittests/Tools/CMakeLists.txt b/mlir/unittests/Tools/CMakeLists.txt
deleted file mode 100644
index a97588d..0000000
--- a/mlir/unittests/Tools/CMakeLists.txt
+++ /dev/null
@@ -1 +0,0 @@
-add_subdirectory(lsp-server-support)
diff --git a/mlir/unittests/Tools/lsp-server-support/CMakeLists.txt b/mlir/unittests/Tools/lsp-server-support/CMakeLists.txt
deleted file mode 100644
index c539c9b..0000000
--- a/mlir/unittests/Tools/lsp-server-support/CMakeLists.txt
+++ /dev/null
@@ -1,7 +0,0 @@
-add_mlir_unittest(MLIRLspServerSupportTests
- Protocol.cpp
- Transport.cpp
-)
-mlir_target_link_libraries(MLIRLspServerSupportTests
- PRIVATE
- MLIRLspServerSupportLib)
diff --git a/mlir/unittests/Tools/lsp-server-support/Protocol.cpp b/mlir/unittests/Tools/lsp-server-support/Protocol.cpp
deleted file mode 100644
index 04d7b2f..0000000
--- a/mlir/unittests/Tools/lsp-server-support/Protocol.cpp
+++ /dev/null
@@ -1,51 +0,0 @@
-//===- Protocol.cpp - LSP JSON protocol unit tests ------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Tools/lsp-server-support/Protocol.h"
-
-#include "gtest/gtest.h"
-
-using namespace mlir;
-using namespace mlir::lsp;
-using namespace testing;
-
-namespace {
-
-TEST(ProtocolTest, DiagnosticTagPresent) {
- Diagnostic diagnostic;
- diagnostic.tags.push_back(DiagnosticTag::Unnecessary);
-
- llvm::json::Value json = toJSON(diagnostic);
- const llvm::json::Object *o = json.getAsObject();
- const llvm::json::Array *v = o->get("tags")->getAsArray();
- EXPECT_EQ(*v, llvm::json::Array{1});
-
- Diagnostic parsed;
- llvm::json::Path::Root root = llvm::json::Path::Root();
- bool success = fromJSON(json, parsed, llvm::json::Path(root));
- EXPECT_TRUE(success);
- ASSERT_EQ(parsed.tags.size(), (size_t)1);
- EXPECT_EQ(parsed.tags.at(0), DiagnosticTag::Unnecessary);
-}
-
-TEST(ProtocolTest, DiagnosticTagNotPresent) {
- Diagnostic diagnostic;
-
- llvm::json::Value json = toJSON(diagnostic);
- const llvm::json::Object *o = json.getAsObject();
- const llvm::json::Value *v = o->get("tags");
- EXPECT_EQ(v, nullptr);
-
- Diagnostic parsed;
- llvm::json::Path::Root root = llvm::json::Path::Root();
- bool success = fromJSON(json, parsed, llvm::json::Path(root));
- EXPECT_TRUE(success);
- EXPECT_TRUE(parsed.tags.empty());
-}
-
-} // namespace
diff --git a/mlir/unittests/Tools/lsp-server-support/Transport.cpp b/mlir/unittests/Tools/lsp-server-support/Transport.cpp
deleted file mode 100644
index 92581bd..0000000
--- a/mlir/unittests/Tools/lsp-server-support/Transport.cpp
+++ /dev/null
@@ -1,205 +0,0 @@
-//===- Transport.cpp - LSP JSON transport unit tests ----------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Tools/lsp-server-support/Transport.h"
-#include "mlir/Tools/lsp-server-support/Logging.h"
-#include "mlir/Tools/lsp-server-support/Protocol.h"
-#include "llvm/Support/FileSystem.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-using namespace mlir;
-using namespace mlir::lsp;
-using namespace testing;
-
-namespace {
-
-TEST(TransportTest, SendReply) {
- std::string out;
- llvm::raw_string_ostream os(out);
- JSONTransport transport(nullptr, os);
- MessageHandler handler(transport);
-
- transport.reply(1989, nullptr);
- EXPECT_THAT(out, HasSubstr("\"id\":1989"));
- EXPECT_THAT(out, HasSubstr("\"result\":null"));
-}
-
-class TransportInputTest : public Test {
- llvm::SmallVector<char> inputPath;
- std::FILE *in = nullptr;
- std::string output = "";
- llvm::raw_string_ostream os;
- std::optional<JSONTransport> transport = std::nullopt;
- std::optional<MessageHandler> messageHandler = std::nullopt;
-
-protected:
- TransportInputTest() : os(output) {}
-
- void SetUp() override {
- std::error_code ec =
- llvm::sys::fs::createTemporaryFile("lsp-unittest", "json", inputPath);
- ASSERT_FALSE(ec) << "Could not create temporary file: " << ec.message();
-
- in = std::fopen(inputPath.data(), "r");
- ASSERT_TRUE(in) << "Could not open temporary file: "
- << std::strerror(errno);
- transport.emplace(in, os, JSONStreamStyle::Delimited);
- messageHandler.emplace(*transport);
- }
-
- void TearDown() override {
- EXPECT_EQ(std::fclose(in), 0)
- << "Could not close temporary file FD: " << std::strerror(errno);
- std::error_code ec =
- llvm::sys::fs::remove(inputPath, /*IgnoreNonExisting=*/false);
- EXPECT_FALSE(ec) << "Could not remove temporary file '" << inputPath.data()
- << "': " << ec.message();
- }
-
- void writeInput(StringRef buffer) {
- std::error_code ec;
- llvm::raw_fd_ostream os(inputPath.data(), ec);
- ASSERT_FALSE(ec) << "Could not write to '" << inputPath.data()
- << "': " << ec.message();
- os << buffer;
- os.close();
- }
-
- StringRef getOutput() const { return output; }
- MessageHandler &getMessageHandler() { return *messageHandler; }
-
- void runTransport() {
- bool gotEOF = false;
- llvm::Error err = llvm::handleErrors(
- transport->run(*messageHandler), [&](const llvm::ECError &ecErr) {
- gotEOF = ecErr.convertToErrorCode() == std::errc::io_error;
- });
- llvm::consumeError(std::move(err));
- EXPECT_TRUE(gotEOF);
- }
-};
-
-TEST_F(TransportInputTest, RequestWithInvalidParams) {
- struct Handler {
- void onMethod(const TextDocumentItem &params,
- mlir::lsp::Callback<TextDocumentIdentifier> callback) {}
- } handler;
- getMessageHandler().method("invalid-params-request", &handler,
- &Handler::onMethod);
-
- writeInput("{\"jsonrpc\":\"2.0\",\"id\":92,"
- "\"method\":\"invalid-params-request\",\"params\":{}}\n");
- runTransport();
- EXPECT_THAT(getOutput(), HasSubstr("error"));
- EXPECT_THAT(getOutput(), HasSubstr("missing value at (root).uri"));
-}
-
-TEST_F(TransportInputTest, NotificationWithInvalidParams) {
- // JSON parsing errors are only reported via error logging. As a result, this
- // test can't make any expectations -- but it prints the output anyway, by way
- // of demonstration.
- Logger::setLogLevel(Logger::Level::Error);
-
- struct Handler {
- void onNotification(const TextDocumentItem &params) {}
- } handler;
- getMessageHandler().notification("invalid-params-notification", &handler,
- &Handler::onNotification);
-
- writeInput("{\"jsonrpc\":\"2.0\",\"method\":\"invalid-params-notification\","
- "\"params\":{}}\n");
- runTransport();
-}
-
-TEST_F(TransportInputTest, MethodNotFound) {
- writeInput("{\"jsonrpc\":\"2.0\",\"id\":29,\"method\":\"ack\"}\n");
- runTransport();
- EXPECT_THAT(getOutput(), HasSubstr("\"id\":29"));
- EXPECT_THAT(getOutput(), HasSubstr("\"error\""));
- EXPECT_THAT(getOutput(), HasSubstr("\"message\":\"method not found: ack\""));
-}
-
-TEST_F(TransportInputTest, OutgoingNotification) {
- auto notifyFn = getMessageHandler().outgoingNotification<CompletionList>(
- "outgoing-notification");
- notifyFn(CompletionList{});
- EXPECT_THAT(getOutput(), HasSubstr("\"method\":\"outgoing-notification\""));
-}
-
-TEST_F(TransportInputTest, ResponseHandlerNotFound) {
- // Unhandled responses are only reported via error logging. As a result, this
- // test can't make any expectations -- but it prints the output anyway, by way
- // of demonstration.
- Logger::setLogLevel(Logger::Level::Error);
- writeInput("{\"jsonrpc\":\"2.0\",\"id\":81,\"result\":null}\n");
- runTransport();
-}
-
-TEST_F(TransportInputTest, OutgoingRequest) {
- // Make some outgoing requests.
- int responseCallbackInvoked = 0;
- auto callFn =
- getMessageHandler().outgoingRequest<CompletionList, CompletionContext>(
- "outgoing-request",
- [&responseCallbackInvoked](const llvm::json::Value &id,
- llvm::Expected<CompletionContext> result) {
- // Make expectations on the expected response.
- EXPECT_EQ(id, 83);
- ASSERT_TRUE((bool)result);
- EXPECT_EQ(result->triggerKind, CompletionTriggerKind::Invoked);
- responseCallbackInvoked += 1;
- });
- callFn({}, 82);
- callFn({}, 83);
- callFn({}, 84);
- EXPECT_THAT(getOutput(), HasSubstr("\"method\":\"outgoing-request\""));
- EXPECT_EQ(responseCallbackInvoked, 0);
-
- // One of the requests receives a response. The message handler handles this
- // response by invoking the callback from above. Subsequent responses with the
- // same ID are ignored.
- writeInput(
- "{\"jsonrpc\":\"2.0\",\"id\":83,\"result\":{\"triggerKind\":1}}\n"
- "// -----\n"
- "{\"jsonrpc\":\"2.0\",\"id\":83,\"result\":{\"triggerKind\":3}}\n");
- runTransport();
- EXPECT_EQ(responseCallbackInvoked, 1);
-}
-
-TEST_F(TransportInputTest, OutgoingRequestJSONParseFailure) {
- // Make an outgoing request that expects a failure response.
- bool responseCallbackInvoked = false;
- auto callFn = getMessageHandler().outgoingRequest<CompletionList, Position>(
- "outgoing-request-json-parse-failure",
- [&responseCallbackInvoked](const llvm::json::Value &id,
- llvm::Expected<Position> result) {
- llvm::Error err = result.takeError();
- EXPECT_EQ(id, 109);
- ASSERT_TRUE((bool)err);
- EXPECT_THAT(debugString(err),
- HasSubstr("failed to decode "
- "reply:outgoing-request-json-parse-failure(109) "
- "response: missing value at (root).character"));
- llvm::consumeError(std::move(err));
- responseCallbackInvoked += 1;
- });
- callFn({}, 109);
- EXPECT_EQ(responseCallbackInvoked, 0);
-
- // The request receives multiple responses, but only the first one triggers
- // the response callback. The first response has erroneous JSON that causes a
- // parse failure.
- writeInput("{\"jsonrpc\":\"2.0\",\"id\":109,\"result\":{\"line\":7}}\n"
- "// -----\n"
- "{\"jsonrpc\":\"2.0\",\"id\":109,\"result\":{\"line\":3,"
- "\"character\":2}}\n");
- runTransport();
- EXPECT_EQ(responseCallbackInvoked, 1);
-}
-} // namespace