aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
Diffstat (limited to 'mlir')
-rw-r--r--mlir/cmake/modules/AddMLIRPython.cmake127
-rw-r--r--mlir/cmake/modules/MLIRDetectPythonEnv.cmake91
-rw-r--r--mlir/docs/Dialects/Linalg/OpDSL.md4
-rw-r--r--mlir/docs/Dialects/Transform.md6
-rw-r--r--mlir/examples/standalone/pyproject.toml4
-rw-r--r--mlir/examples/standalone/python/CMakeLists.txt18
-rw-r--r--mlir/examples/standalone/python/StandaloneExtensionPybind11.cpp38
-rw-r--r--mlir/examples/standalone/python/mlir_standalone/dialects/standalone_pybind11.py6
-rw-r--r--mlir/examples/standalone/test/python/smoketest.py11
-rw-r--r--mlir/include/mlir/Bindings/Python/PybindAdaptors.h616
-rw-r--r--mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h3
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td18
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td9
-rw-r--r--mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h56
-rw-r--r--mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td50
-rw-r--r--mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt2
-rw-r--r--mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h1
-rw-r--r--mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td64
-rw-r--r--mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp3
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp18
-rw-r--r--mlir/lib/Dialect/Tosa/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp42
-rw-r--r--mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp87
-rw-r--r--mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp94
-rw-r--r--mlir/python/CMakeLists.txt61
-rw-r--r--mlir/python/mlir/dialects/python_test.py11
-rw-r--r--mlir/python/mlir/ir.py14
-rw-r--r--mlir/python/requirements.txt2
-rw-r--r--mlir/test/Dialect/LLVMIR/invalid.mlir4
-rw-r--r--mlir/test/Dialect/LLVMIR/rocdl.mlir33
-rw-r--r--mlir/test/Dialect/Tosa/dynamic_extension.mlir2
-rw-r--r--mlir/test/Dialect/Tosa/error_if_check.mlir2
-rw-r--r--mlir/test/Dialect/Tosa/invalid.mlir2
-rw-r--r--mlir/test/Dialect/Tosa/invalid_extension.mlir2
-rw-r--r--mlir/test/Dialect/Tosa/level_check.mlir2
-rw-r--r--mlir/test/Dialect/Tosa/profile_all_unsupported.mlir2
-rw-r--r--mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir2
-rw-r--r--mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir2
-rw-r--r--mlir/test/Dialect/Tosa/tosa-attach-target.mlir14
-rw-r--r--mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir2
-rw-r--r--mlir/test/Dialect/Tosa/tosa-validation-valid.mlir2
-rw-r--r--mlir/test/Target/LLVMIR/rocdl.mlir33
-rw-r--r--mlir/test/python/dialects/python_test.py31
-rw-r--r--mlir/test/python/lib/CMakeLists.txt1
-rw-r--r--mlir/test/python/lib/PythonTestModulePybind11.cpp118
-rwxr-xr-xmlir/tools/mlir-linalg-ods-gen/update_core_linalg_named_ops.sh.in2
47 files changed, 1461 insertions, 253 deletions
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index ea34f94..fa6aec8 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -123,12 +123,12 @@ function(mlir_generate_type_stubs)
"IMPORT_PATHS;DEPENDS_TARGETS;OUTPUTS;DEPENDS_TARGET_SRC_DEPS"
${ARGN})
- # for people installing a distro (e.g., pip install) of nanobind
+ # for people doing find_package(nanobind)
if(EXISTS ${nanobind_DIR}/../src/stubgen.py)
set(NB_STUBGEN "${nanobind_DIR}/../src/stubgen.py")
elseif(EXISTS ${nanobind_DIR}/../stubgen.py)
set(NB_STUBGEN "${nanobind_DIR}/../stubgen.py")
- # for people using nanobind git source tree (e.g., FetchContent_Declare and FetchContent_MakeAvailable)
+ # for people using FetchContent_Declare and FetchContent_MakeAvailable
elseif(EXISTS ${nanobind_SOURCE_DIR}/src/stubgen.py)
set(NB_STUBGEN "${nanobind_SOURCE_DIR}/src/stubgen.py")
elseif(EXISTS ${nanobind_SOURCE_DIR}/stubgen.py)
@@ -226,10 +226,11 @@ endfunction()
# EMBED_CAPI_LINK_LIBS: Dependent CAPI libraries that this extension depends
# on. These will be collected for all extensions and put into an
# aggregate dylib that is linked against.
+# PYTHON_BINDINGS_LIBRARY: Either pybind11 or nanobind.
function(declare_mlir_python_extension name)
cmake_parse_arguments(ARG
""
- "ROOT_DIR;MODULE_NAME;ADD_TO_PARENT"
+ "ROOT_DIR;MODULE_NAME;ADD_TO_PARENT;PYTHON_BINDINGS_LIBRARY"
"SOURCES;PRIVATE_LINK_LIBS;EMBED_CAPI_LINK_LIBS"
${ARGN})
@@ -238,15 +239,20 @@ function(declare_mlir_python_extension name)
endif()
set(_install_destination "src/python/${name}")
+ if(NOT ARG_PYTHON_BINDINGS_LIBRARY)
+ set(ARG_PYTHON_BINDINGS_LIBRARY "pybind11")
+ endif()
+
add_library(${name} INTERFACE)
set_target_properties(${name} PROPERTIES
# Yes: Leading-lowercase property names are load bearing and the recommended
# way to do this: https://gitlab.kitware.com/cmake/cmake/-/issues/19261
- EXPORT_PROPERTIES "mlir_python_SOURCES_TYPE;mlir_python_EXTENSION_MODULE_NAME;mlir_python_EMBED_CAPI_LINK_LIBS;mlir_python_DEPENDS"
+ EXPORT_PROPERTIES "mlir_python_SOURCES_TYPE;mlir_python_EXTENSION_MODULE_NAME;mlir_python_EMBED_CAPI_LINK_LIBS;mlir_python_DEPENDS;mlir_python_BINDINGS_LIBRARY"
mlir_python_SOURCES_TYPE extension
mlir_python_EXTENSION_MODULE_NAME "${ARG_MODULE_NAME}"
mlir_python_EMBED_CAPI_LINK_LIBS "${ARG_EMBED_CAPI_LINK_LIBS}"
mlir_python_DEPENDS ""
+ mlir_python_BINDINGS_LIBRARY "${ARG_PYTHON_BINDINGS_LIBRARY}"
)
# Set the interface source and link_libs properties of the target
@@ -335,12 +341,14 @@ function(add_mlir_python_modules name)
elseif(_source_type STREQUAL "extension")
# Native CPP extension.
get_target_property(_module_name ${sources_target} mlir_python_EXTENSION_MODULE_NAME)
+ get_target_property(_bindings_library ${sources_target} mlir_python_BINDINGS_LIBRARY)
# Transform relative source to based on root dir.
set(_extension_target "${modules_target}.extension.${_module_name}.dso")
add_mlir_python_extension(${_extension_target} "${_module_name}"
INSTALL_COMPONENT ${modules_target}
INSTALL_DIR "${ARG_INSTALL_PREFIX}/_mlir_libs"
OUTPUT_DIRECTORY "${ARG_ROOT_PREFIX}/_mlir_libs"
+ PYTHON_BINDINGS_LIBRARY ${_bindings_library}
LINK_LIBS PRIVATE
${sources_target}
${ARG_COMMON_CAPI_LINK_LIBS}
@@ -745,7 +753,7 @@ endfunction()
function(add_mlir_python_extension libname extname)
cmake_parse_arguments(ARG
""
- "INSTALL_COMPONENT;INSTALL_DIR;OUTPUT_DIRECTORY"
+ "INSTALL_COMPONENT;INSTALL_DIR;OUTPUT_DIRECTORY;PYTHON_BINDINGS_LIBRARY"
"SOURCES;LINK_LIBS"
${ARGN})
if(ARG_UNPARSED_ARGUMENTS)
@@ -753,7 +761,7 @@ function(add_mlir_python_extension libname extname)
endif()
# The extension itself must be compiled with RTTI and exceptions enabled.
- # Also, some warning classes triggered by nanobind are disabled.
+ # Also, some warning classes triggered by pybind11 are disabled.
set(eh_rtti_enable)
if (MSVC)
set(eh_rtti_enable /EHsc /GR)
@@ -761,53 +769,62 @@ function(add_mlir_python_extension libname extname)
set(eh_rtti_enable -frtti -fexceptions)
endif ()
- nanobind_add_module(${libname}
- NB_DOMAIN ${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
- FREE_THREADED
- ${ARG_SOURCES}
- )
+ # The actual extension library produces a shared-object or DLL and has
+ # sources that must be compiled in accordance with pybind11 needs (RTTI and
+ # exceptions).
+ if(NOT DEFINED ARG_PYTHON_BINDINGS_LIBRARY OR ARG_PYTHON_BINDINGS_LIBRARY STREQUAL "pybind11")
+ pybind11_add_module(${libname}
+ ${ARG_SOURCES}
+ )
+ elseif(ARG_PYTHON_BINDINGS_LIBRARY STREQUAL "nanobind")
+ nanobind_add_module(${libname}
+ NB_DOMAIN ${MLIR_BINDINGS_PYTHON_NB_DOMAIN}
+ FREE_THREADED
+ ${ARG_SOURCES}
+ )
- if (NOT MLIR_DISABLE_CONFIGURE_PYTHON_DEV_PACKAGES
- AND (LLVM_COMPILER_IS_GCC_COMPATIBLE OR CLANG_CL))
- # Avoid some warnings from upstream nanobind.
- # If a superproject set MLIR_DISABLE_CONFIGURE_PYTHON_DEV_PACKAGES, let
- # the super project handle compile options as it wishes.
- get_property(NB_LIBRARY_TARGET_NAME TARGET ${libname} PROPERTY LINK_LIBRARIES)
- target_compile_options(${NB_LIBRARY_TARGET_NAME}
- PRIVATE
- -Wall -Wextra -Wpedantic
- -Wno-c++98-compat-extra-semi
- -Wno-cast-qual
- -Wno-covered-switch-default
- -Wno-deprecated-literal-operator
- -Wno-nested-anon-types
- -Wno-unused-parameter
- -Wno-zero-length-array
- ${eh_rtti_enable})
-
- target_compile_options(${libname}
- PRIVATE
- -Wall -Wextra -Wpedantic
- -Wno-c++98-compat-extra-semi
- -Wno-cast-qual
- -Wno-covered-switch-default
- -Wno-deprecated-literal-operator
- -Wno-nested-anon-types
- -Wno-unused-parameter
- -Wno-zero-length-array
- ${eh_rtti_enable})
- endif()
+ if (NOT MLIR_DISABLE_CONFIGURE_PYTHON_DEV_PACKAGES
+ AND (LLVM_COMPILER_IS_GCC_COMPATIBLE OR CLANG_CL))
+ # Avoid some warnings from upstream nanobind.
+ # If a superproject set MLIR_DISABLE_CONFIGURE_PYTHON_DEV_PACKAGES, let
+ # the super project handle compile options as it wishes.
+ get_property(NB_LIBRARY_TARGET_NAME TARGET ${libname} PROPERTY LINK_LIBRARIES)
+ target_compile_options(${NB_LIBRARY_TARGET_NAME}
+ PRIVATE
+ -Wall -Wextra -Wpedantic
+ -Wno-c++98-compat-extra-semi
+ -Wno-cast-qual
+ -Wno-covered-switch-default
+ -Wno-deprecated-literal-operator
+ -Wno-nested-anon-types
+ -Wno-unused-parameter
+ -Wno-zero-length-array
+ ${eh_rtti_enable})
+
+ target_compile_options(${libname}
+ PRIVATE
+ -Wall -Wextra -Wpedantic
+ -Wno-c++98-compat-extra-semi
+ -Wno-cast-qual
+ -Wno-covered-switch-default
+ -Wno-deprecated-literal-operator
+ -Wno-nested-anon-types
+ -Wno-unused-parameter
+ -Wno-zero-length-array
+ ${eh_rtti_enable})
+ endif()
- if(APPLE)
- # NanobindAdaptors.h uses PyClassMethod_New to build `pure_subclass`es but nanobind
- # doesn't declare this API as undefined in its linker flags. So we need to declare it as such
- # for downstream users that do not do something like `-undefined dynamic_lookup`.
- # Same for the rest.
- target_link_options(${libname} PUBLIC
- "LINKER:-U,_PyClassMethod_New"
- "LINKER:-U,_PyCode_Addr2Location"
- "LINKER:-U,_PyFrame_GetLasti"
- )
+ if(APPLE)
+ # NanobindAdaptors.h uses PyClassMethod_New to build `pure_subclass`es but nanobind
+ # doesn't declare this API as undefined in its linker flags. So we need to declare it as such
+ # for downstream users that do not do something like `-undefined dynamic_lookup`.
+ # Same for the rest.
+ target_link_options(${libname} PUBLIC
+ "LINKER:-U,_PyClassMethod_New"
+ "LINKER:-U,_PyCode_Addr2Location"
+ "LINKER:-U,_PyFrame_GetLasti"
+ )
+ endif()
endif()
target_compile_options(${libname} PRIVATE ${eh_rtti_enable})
@@ -845,11 +862,11 @@ function(add_mlir_python_extension libname extname)
if(WIN32)
# On Windows, pyconfig.h (and by extension python.h) hardcode the version of the
# python library which will be used for linkage depending on the flavor of the build.
- # nanobind has a workaround which depends on the definition of Py_DEBUG (if Py_DEBUG
- # is not passed in as a compile definition, nanobind undefs _DEBUG when including
+ # pybind11 has a workaround which depends on the definition of Py_DEBUG (if Py_DEBUG
+ # is not passed in as a compile definition, pybind11 undefs _DEBUG when including
# python.h, so that the release python library would be used).
- # Since mlir uses nanobind, we can leverage their workaround by never directly
- # pyconfig.h or python.h and instead relying on the nanobind headers to include the
+ # Since mlir uses pybind11, we can leverage their workaround by never directly
+ # pyconfig.h or python.h and instead relying on the pybind11 headers to include the
# necessary python headers. This results in mlir always linking against the
# release python library via the (undocumented) cmake property Python3_LIBRARY_RELEASE.
target_link_libraries(${libname} PRIVATE ${Python3_LIBRARY_RELEASE})
diff --git a/mlir/cmake/modules/MLIRDetectPythonEnv.cmake b/mlir/cmake/modules/MLIRDetectPythonEnv.cmake
index edbad2e..d18f8c0 100644
--- a/mlir/cmake/modules/MLIRDetectPythonEnv.cmake
+++ b/mlir/cmake/modules/MLIRDetectPythonEnv.cmake
@@ -46,20 +46,81 @@ macro(mlir_configure_python_dev_packages)
message(STATUS "Found python include dirs: ${Python3_INCLUDE_DIRS}")
message(STATUS "Found python libraries: ${Python3_LIBRARIES}")
message(STATUS "Found numpy v${Python3_NumPy_VERSION}: ${Python3_NumPy_INCLUDE_DIRS}")
- message(STATUS "Python extension suffix for modules: '${Python3_SOABI}'")
- if(nanobind_DIR)
- message(STATUS "Using explicit nanobind cmake directory: ${nanobind_DIR} (-Dnanobind_DIR to change)")
- find_package(nanobind 2.9 CONFIG REQUIRED)
- else()
- include(FetchContent)
- FetchContent_Declare(
- nanobind
- GIT_REPOSITORY https://github.com/wjakob/nanobind.git
- GIT_TAG v2.9.0
- GIT_SHALLOW TRUE
- )
- FetchContent_MakeAvailable(nanobind)
- endif()
- message(STATUS "Found nanobind: ${NB_DIR}")
+ mlir_detect_pybind11_install()
+ find_package(pybind11 2.10 CONFIG REQUIRED)
+ message(STATUS "Found pybind11 v${pybind11_VERSION}: ${pybind11_INCLUDE_DIR}")
+ message(STATUS "Python prefix = '${PYTHON_MODULE_PREFIX}', "
+ "suffix = '${PYTHON_MODULE_SUFFIX}', "
+ "extension = '${PYTHON_MODULE_EXTENSION}")
+
+ mlir_detect_nanobind_install()
+ find_package(nanobind 2.9 CONFIG REQUIRED)
+ message(STATUS "Found nanobind v${nanobind_VERSION}: ${nanobind_INCLUDE_DIR}")
+ message(STATUS "Python prefix = '${PYTHON_MODULE_PREFIX}', "
+ "suffix = '${PYTHON_MODULE_SUFFIX}', "
+ "extension = '${PYTHON_MODULE_EXTENSION}")
endif()
endmacro()
+
+# Detects a pybind11 package installed in the current python environment
+# and sets variables to allow it to be found. This allows pybind11 to be
+# installed via pip, which typically yields a much more recent version than
+# the OS install, which will be available otherwise.
+function(mlir_detect_pybind11_install)
+ if(pybind11_DIR)
+ message(STATUS "Using explicit pybind11 cmake directory: ${pybind11_DIR} (-Dpybind11_DIR to change)")
+ else()
+ message(STATUS "Checking for pybind11 in python path...")
+ execute_process(
+ COMMAND "${Python3_EXECUTABLE}"
+ -c "import pybind11;print(pybind11.get_cmake_dir(), end='')"
+ WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
+ RESULT_VARIABLE STATUS
+ OUTPUT_VARIABLE PACKAGE_DIR
+ ERROR_QUIET)
+ if(NOT STATUS EQUAL "0")
+ message(STATUS "not found (install via 'pip install pybind11' or set pybind11_DIR)")
+ return()
+ endif()
+ message(STATUS "found (${PACKAGE_DIR})")
+ set(pybind11_DIR "${PACKAGE_DIR}" PARENT_SCOPE)
+ endif()
+endfunction()
+
+
+# Detects a nanobind package installed in the current python environment
+# and sets variables to allow it to be found. This allows nanobind to be
+# installed via pip, which typically yields a much more recent version than
+# the OS install, which will be available otherwise.
+function(mlir_detect_nanobind_install)
+ if(nanobind_DIR)
+ message(STATUS "Using explicit nanobind cmake directory: ${nanobind_DIR} (-Dnanobind_DIR to change)")
+ else()
+ message(STATUS "Checking for nanobind in python path...")
+ execute_process(
+ COMMAND "${Python3_EXECUTABLE}"
+ -c "import nanobind;print(nanobind.cmake_dir(), end='')"
+ WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
+ RESULT_VARIABLE STATUS
+ OUTPUT_VARIABLE PACKAGE_DIR
+ ERROR_QUIET)
+ if(NOT STATUS EQUAL "0")
+ message(STATUS "not found (install via 'pip install nanobind' or set nanobind_DIR)")
+ return()
+ endif()
+ message(STATUS "found (${PACKAGE_DIR})")
+ set(nanobind_DIR "${PACKAGE_DIR}" PARENT_SCOPE)
+ execute_process(
+ COMMAND "${Python3_EXECUTABLE}"
+ -c "import nanobind;print(nanobind.include_dir(), end='')"
+ WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
+ RESULT_VARIABLE STATUS
+ OUTPUT_VARIABLE PACKAGE_DIR
+ ERROR_QUIET)
+ if(NOT STATUS EQUAL "0")
+ message(STATUS "not found (install via 'pip install nanobind' or set nanobind_DIR)")
+ return()
+ endif()
+ set(nanobind_INCLUDE_DIR "${PACKAGE_DIR}" PARENT_SCOPE)
+ endif()
+endfunction()
diff --git a/mlir/docs/Dialects/Linalg/OpDSL.md b/mlir/docs/Dialects/Linalg/OpDSL.md
index 5d7e274..b892bbe 100644
--- a/mlir/docs/Dialects/Linalg/OpDSL.md
+++ b/mlir/docs/Dialects/Linalg/OpDSL.md
@@ -16,7 +16,7 @@ corresponding `linalg.generic` IR for the composition.
## Basic usage
The tool is bundled with the MLIR Python bindings. To use from the CMake build
-tree, MLIR must be built with Python bindings enabled
+tree, MLIR must be build with Python bindings enabled
(`-DMLIR_ENABLE_BINDINGS_PYTHON=ON`). Then add the `python` directory in the
build tree to your `PYTHONPATH` environment variable (i.e. `export
PYTHONPATH=$PWD/build/tools/mlir/python_packages/mlir_core`). Optionally, use an
@@ -24,7 +24,7 @@ installed MLIR package, if available, to avoid building.
```shell
# Dump the `core_named_ops.py` module as YAML.
-python -m mlir.dialects.linalg.opdsl.dump_oplib.ops.core_named_ops
+python -m mlir.dialects.linalg.opdsl.dump_oplib .ops.core_named_ops
```
Alternatively, run the `$PWD/build/bin/update_core_linalg_named_ops.sh` script,
diff --git a/mlir/docs/Dialects/Transform.md b/mlir/docs/Dialects/Transform.md
index 7164cb7..2133b81 100644
--- a/mlir/docs/Dialects/Transform.md
+++ b/mlir/docs/Dialects/Transform.md
@@ -415,10 +415,14 @@ ops rather than having the methods directly act on the payload IR.
[include "Dialects/TransformOps.md"]
-## Tuning Extension Operaiton
+## Tune Extension Operations
[include "Dialects/TuneExtensionOps.md"]
+## SMT Extension Operations
+
+[include "Dialects/SMTExtensionOps.md"]
+
## Affine Transform Operations
[include "Dialects/AffineLoopTransformOps.md"]
diff --git a/mlir/examples/standalone/pyproject.toml b/mlir/examples/standalone/pyproject.toml
index 75e2153..5a1e6e8 100644
--- a/mlir/examples/standalone/pyproject.toml
+++ b/mlir/examples/standalone/pyproject.toml
@@ -23,7 +23,9 @@ Discussions = "https://discourse.llvm.org/"
[build-system]
requires = [
"scikit-build-core>=0.10.7",
- "typing_extensions>=4.12.2"
+ "typing_extensions>=4.12.2",
+ "nanobind>=2.9, <3.0",
+ "pybind11>=2.10.0, <=2.13.6",
]
build-backend = "scikit_build_core.build"
diff --git a/mlir/examples/standalone/python/CMakeLists.txt b/mlir/examples/standalone/python/CMakeLists.txt
index 108c343..905c9449 100644
--- a/mlir/examples/standalone/python/CMakeLists.txt
+++ b/mlir/examples/standalone/python/CMakeLists.txt
@@ -16,10 +16,27 @@ declare_mlir_dialect_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir_standalone"
TD_FILE dialects/StandaloneOps.td
SOURCES
+ dialects/standalone_pybind11.py
dialects/standalone_nanobind.py
_mlir_libs/_standaloneDialectsNanobind/py.typed
DIALECT_NAME standalone)
+
+declare_mlir_python_extension(StandalonePythonSources.Pybind11Extension
+ MODULE_NAME _standaloneDialectsPybind11
+ ADD_TO_PARENT StandalonePythonSources
+ SOURCES
+ StandaloneExtensionPybind11.cpp
+ PRIVATE_LINK_LIBS
+ LLVMSupport
+ EMBED_CAPI_LINK_LIBS
+ MLIRCAPIIR
+ MLIRCAPIArith
+ MLIRCAPITransforms
+ StandaloneCAPI
+ PYTHON_BINDINGS_LIBRARY pybind11
+)
+
declare_mlir_python_extension(StandalonePythonSources.NanobindExtension
MODULE_NAME _standaloneDialectsNanobind
ADD_TO_PARENT StandalonePythonSources
@@ -32,6 +49,7 @@ declare_mlir_python_extension(StandalonePythonSources.NanobindExtension
MLIRCAPIArith
MLIRCAPITransforms
StandaloneCAPI
+ PYTHON_BINDINGS_LIBRARY nanobind
)
diff --git a/mlir/examples/standalone/python/StandaloneExtensionPybind11.cpp b/mlir/examples/standalone/python/StandaloneExtensionPybind11.cpp
new file mode 100644
index 0000000..da8c216
--- /dev/null
+++ b/mlir/examples/standalone/python/StandaloneExtensionPybind11.cpp
@@ -0,0 +1,38 @@
+//===- StandaloneExtensionPybind11.cpp - Extension module -----------------===//
+//
+// This is the pybind11 version of the example module. There is also a nanobind
+// example in StandaloneExtensionNanobind.cpp.
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "Standalone-c/Dialects.h"
+#include "mlir-c/Dialect/Arith.h"
+#include "mlir/Bindings/Python/PybindAdaptors.h"
+
+using namespace mlir::python::adaptors;
+
+PYBIND11_MODULE(_standaloneDialectsPybind11, m) {
+ //===--------------------------------------------------------------------===//
+ // standalone dialect
+ //===--------------------------------------------------------------------===//
+ auto standaloneM = m.def_submodule("standalone");
+
+ standaloneM.def(
+ "register_dialects",
+ [](MlirContext context, bool load) {
+ MlirDialectHandle arithHandle = mlirGetDialectHandle__arith__();
+ MlirDialectHandle standaloneHandle =
+ mlirGetDialectHandle__standalone__();
+ mlirDialectHandleRegisterDialect(arithHandle, context);
+ mlirDialectHandleRegisterDialect(standaloneHandle, context);
+ if (load) {
+ mlirDialectHandleLoadDialect(arithHandle, context);
+ mlirDialectHandleRegisterDialect(standaloneHandle, context);
+ }
+ },
+ py::arg("context") = py::none(), py::arg("load") = true);
+}
diff --git a/mlir/examples/standalone/python/mlir_standalone/dialects/standalone_pybind11.py b/mlir/examples/standalone/python/mlir_standalone/dialects/standalone_pybind11.py
new file mode 100644
index 0000000..bfb98e40
--- /dev/null
+++ b/mlir/examples/standalone/python/mlir_standalone/dialects/standalone_pybind11.py
@@ -0,0 +1,6 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ._standalone_ops_gen import *
+from .._mlir_libs._standaloneDialectsPybind11.standalone import *
diff --git a/mlir/examples/standalone/test/python/smoketest.py b/mlir/examples/standalone/test/python/smoketest.py
index f881984..26d84fd 100644
--- a/mlir/examples/standalone/test/python/smoketest.py
+++ b/mlir/examples/standalone/test/python/smoketest.py
@@ -1,7 +1,16 @@
+# RUN: %python %s pybind11 | FileCheck %s
# RUN: %python %s nanobind | FileCheck %s
+import sys
from mlir_standalone.ir import *
-from mlir_standalone.dialects import standalone_nanobind as standalone_d
+
+if sys.argv[1] == "pybind11":
+ from mlir_standalone.dialects import standalone_pybind11 as standalone_d
+elif sys.argv[1] == "nanobind":
+ from mlir_standalone.dialects import standalone_nanobind as standalone_d
+else:
+ raise ValueError("Expected either pybind11 or nanobind as arguments")
+
with Context():
standalone_d.register_dialects()
diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
new file mode 100644
index 0000000..edc6977
--- /dev/null
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -0,0 +1,616 @@
+//===- PybindAdaptors.h - Interop with MLIR APIs via pybind11 -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// This file contains adaptors for clients of the core MLIR Python APIs to
+// interop via MLIR CAPI types, using pybind11. The facilities here do not
+// depend on implementation details of the MLIR Python API and do not introduce
+// C++-level dependencies with it (requiring only Python and CAPI-level
+// dependencies).
+//
+// It is encouraged to be used both in-tree and out-of-tree. For in-tree use
+// cases, it should be used for dialect implementations (versus relying on
+// Pybind-based internals of the core libraries).
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H
+#define MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H
+
+#include <pybind11/functional.h>
+#include <pybind11/pybind11.h>
+#include <pybind11/pytypes.h>
+#include <pybind11/stl.h>
+
+#include "mlir-c/Bindings/Python/Interop.h"
+#include "mlir-c/Diagnostics.h"
+#include "mlir-c/IR.h"
+
+#include "llvm/ADT/Twine.h"
+
+namespace py = pybind11;
+using namespace py::literals;
+
+// Raw CAPI type casters need to be declared before use, so always include them
+// first.
+namespace pybind11 {
+namespace detail {
+
+/// Helper to convert a presumed MLIR API object to a capsule, accepting either
+/// an explicit Capsule (which can happen when two C APIs are communicating
+/// directly via Python) or indirectly by querying the MLIR_PYTHON_CAPI_PTR_ATTR
+/// attribute (through which supported MLIR Python API objects export their
+/// contained API pointer as a capsule). Throws a type error if the object is
+/// neither. This is intended to be used from type casters, which are invoked
+/// with a raw handle (unowned). The returned object's lifetime may not extend
+/// beyond the apiObject handle without explicitly having its refcount increased
+/// (i.e. on return).
+static py::object mlirApiObjectToCapsule(py::handle apiObject) {
+ if (PyCapsule_CheckExact(apiObject.ptr()))
+ return py::reinterpret_borrow<py::object>(apiObject);
+ if (!py::hasattr(apiObject, MLIR_PYTHON_CAPI_PTR_ATTR)) {
+ auto repr = py::repr(apiObject).cast<std::string>();
+ throw py::type_error(
+ (llvm::Twine("Expected an MLIR object (got ") + repr + ").").str());
+ }
+ return apiObject.attr(MLIR_PYTHON_CAPI_PTR_ATTR);
+}
+
+// Note: Currently all of the following support cast from py::object to the
+// Mlir* C-API type, but only a few light-weight, context-bound ones
+// implicitly cast the other way because the use case has not yet emerged and
+// ownership is unclear.
+
+/// Casts object <-> MlirAffineMap.
+template <>
+struct type_caster<MlirAffineMap> {
+ PYBIND11_TYPE_CASTER(MlirAffineMap, _("MlirAffineMap"));
+ bool load(handle src, bool) {
+ py::object capsule = mlirApiObjectToCapsule(src);
+ value = mlirPythonCapsuleToAffineMap(capsule.ptr());
+ if (mlirAffineMapIsNull(value)) {
+ return false;
+ }
+ return !mlirAffineMapIsNull(value);
+ }
+ static handle cast(MlirAffineMap v, return_value_policy, handle) {
+ py::object capsule =
+ py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(v));
+ return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("AffineMap")
+ .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
+ .release();
+ }
+};
+
+/// Casts object <-> MlirAttribute.
+template <>
+struct type_caster<MlirAttribute> {
+ PYBIND11_TYPE_CASTER(MlirAttribute, _("MlirAttribute"));
+ bool load(handle src, bool) {
+ py::object capsule = mlirApiObjectToCapsule(src);
+ value = mlirPythonCapsuleToAttribute(capsule.ptr());
+ return !mlirAttributeIsNull(value);
+ }
+ static handle cast(MlirAttribute v, return_value_policy, handle) {
+ py::object capsule =
+ py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(v));
+ return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("Attribute")
+ .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
+ .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
+ .release();
+ }
+};
+
+/// Casts object -> MlirBlock.
+template <>
+struct type_caster<MlirBlock> {
+ PYBIND11_TYPE_CASTER(MlirBlock, _("MlirBlock"));
+ bool load(handle src, bool) {
+ py::object capsule = mlirApiObjectToCapsule(src);
+ value = mlirPythonCapsuleToBlock(capsule.ptr());
+ return !mlirBlockIsNull(value);
+ }
+};
+
+/// Casts object -> MlirContext.
+template <>
+struct type_caster<MlirContext> {
+ PYBIND11_TYPE_CASTER(MlirContext, _("MlirContext"));
+ bool load(handle src, bool) {
+ if (src.is_none()) {
+ // Gets the current thread-bound context.
+ // TODO: This raises an error of "No current context" currently.
+ // Update the implementation to pretty-print the helpful error that the
+ // core implementations print in this case.
+ src = py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("Context")
+ .attr("current");
+ }
+ py::object capsule = mlirApiObjectToCapsule(src);
+ value = mlirPythonCapsuleToContext(capsule.ptr());
+ return !mlirContextIsNull(value);
+ }
+};
+
+/// Casts object <-> MlirDialectRegistry.
+template <>
+struct type_caster<MlirDialectRegistry> {
+ PYBIND11_TYPE_CASTER(MlirDialectRegistry, _("MlirDialectRegistry"));
+ bool load(handle src, bool) {
+ py::object capsule = mlirApiObjectToCapsule(src);
+ value = mlirPythonCapsuleToDialectRegistry(capsule.ptr());
+ return !mlirDialectRegistryIsNull(value);
+ }
+ static handle cast(MlirDialectRegistry v, return_value_policy, handle) {
+ py::object capsule = py::reinterpret_steal<py::object>(
+ mlirPythonDialectRegistryToCapsule(v));
+ return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("DialectRegistry")
+ .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
+ .release();
+ }
+};
+
+/// Casts object <-> MlirLocation.
+template <>
+struct type_caster<MlirLocation> {
+ PYBIND11_TYPE_CASTER(MlirLocation, _("MlirLocation"));
+ bool load(handle src, bool) {
+ if (src.is_none()) {
+ // Gets the current thread-bound context.
+ src = py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("Location")
+ .attr("current");
+ }
+ py::object capsule = mlirApiObjectToCapsule(src);
+ value = mlirPythonCapsuleToLocation(capsule.ptr());
+ return !mlirLocationIsNull(value);
+ }
+ static handle cast(MlirLocation v, return_value_policy, handle) {
+ py::object capsule =
+ py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(v));
+ return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("Location")
+ .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
+ .release();
+ }
+};
+
+/// Casts object <-> MlirModule.
+template <>
+struct type_caster<MlirModule> {
+ PYBIND11_TYPE_CASTER(MlirModule, _("MlirModule"));
+ bool load(handle src, bool) {
+ py::object capsule = mlirApiObjectToCapsule(src);
+ value = mlirPythonCapsuleToModule(capsule.ptr());
+ return !mlirModuleIsNull(value);
+ }
+ static handle cast(MlirModule v, return_value_policy, handle) {
+ py::object capsule =
+ py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(v));
+ return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("Module")
+ .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
+ .release();
+ };
+};
+
+/// Casts object <-> MlirFrozenRewritePatternSet.
+template <>
+struct type_caster<MlirFrozenRewritePatternSet> {
+ PYBIND11_TYPE_CASTER(MlirFrozenRewritePatternSet,
+ _("MlirFrozenRewritePatternSet"));
+ bool load(handle src, bool) {
+ py::object capsule = mlirApiObjectToCapsule(src);
+ value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
+ return value.ptr != nullptr;
+ }
+ static handle cast(MlirFrozenRewritePatternSet v, return_value_policy,
+ handle) {
+ py::object capsule = py::reinterpret_steal<py::object>(
+ mlirPythonFrozenRewritePatternSetToCapsule(v));
+ return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("rewrite"))
+ .attr("FrozenRewritePatternSet")
+ .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
+ .release();
+ };
+};
+
+/// Casts object <-> MlirOperation.
+template <>
+struct type_caster<MlirOperation> {
+ PYBIND11_TYPE_CASTER(MlirOperation, _("MlirOperation"));
+ bool load(handle src, bool) {
+ py::object capsule = mlirApiObjectToCapsule(src);
+ value = mlirPythonCapsuleToOperation(capsule.ptr());
+ return !mlirOperationIsNull(value);
+ }
+ static handle cast(MlirOperation v, return_value_policy, handle) {
+ if (v.ptr == nullptr)
+ return py::none();
+ py::object capsule =
+ py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(v));
+ return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("Operation")
+ .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
+ .release();
+ };
+};
+
+/// Casts object <-> MlirValue.
+template <>
+struct type_caster<MlirValue> {
+ PYBIND11_TYPE_CASTER(MlirValue, _("MlirValue"));
+ bool load(handle src, bool) {
+ py::object capsule = mlirApiObjectToCapsule(src);
+ value = mlirPythonCapsuleToValue(capsule.ptr());
+ return !mlirValueIsNull(value);
+ }
+ static handle cast(MlirValue v, return_value_policy, handle) {
+ if (v.ptr == nullptr)
+ return py::none();
+ py::object capsule =
+ py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(v));
+ return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("Value")
+ .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
+ .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
+ .release();
+ };
+};
+
+/// Casts object -> MlirPassManager.
+template <>
+struct type_caster<MlirPassManager> {
+ PYBIND11_TYPE_CASTER(MlirPassManager, _("MlirPassManager"));
+ bool load(handle src, bool) {
+ py::object capsule = mlirApiObjectToCapsule(src);
+ value = mlirPythonCapsuleToPassManager(capsule.ptr());
+ return !mlirPassManagerIsNull(value);
+ }
+};
+
+/// Casts object <-> MlirTypeID.
+template <>
+struct type_caster<MlirTypeID> {
+ PYBIND11_TYPE_CASTER(MlirTypeID, _("MlirTypeID"));
+ bool load(handle src, bool) {
+ py::object capsule = mlirApiObjectToCapsule(src);
+ value = mlirPythonCapsuleToTypeID(capsule.ptr());
+ return !mlirTypeIDIsNull(value);
+ }
+ static handle cast(MlirTypeID v, return_value_policy, handle) {
+ if (v.ptr == nullptr)
+ return py::none();
+ py::object capsule =
+ py::reinterpret_steal<py::object>(mlirPythonTypeIDToCapsule(v));
+ return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("TypeID")
+ .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
+ .release();
+ };
+};
+
+/// Casts object <-> MlirType.
+template <>
+struct type_caster<MlirType> {
+ PYBIND11_TYPE_CASTER(MlirType, _("MlirType"));
+ bool load(handle src, bool) {
+ py::object capsule = mlirApiObjectToCapsule(src);
+ value = mlirPythonCapsuleToType(capsule.ptr());
+ return !mlirTypeIsNull(value);
+ }
+ static handle cast(MlirType t, return_value_policy, handle) {
+ py::object capsule =
+ py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(t));
+ return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("Type")
+ .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
+ .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
+ .release();
+ }
+};
+
+} // namespace detail
+} // namespace pybind11
+
+namespace mlir {
+namespace python {
+namespace adaptors {
+
+/// Provides a facility like py::class_ for defining a new class in a scope,
+/// but this allows extension of an arbitrary Python class, defining methods
+/// on it is a similar way. Classes defined in this way are very similar to
+/// if defined in Python in the usual way but use Pybind11 machinery to do
+/// it. These are not "real" Pybind11 classes but pure Python classes with no
+/// relation to a concrete C++ class.
+///
+/// Derived from a discussion upstream:
+/// https://github.com/pybind/pybind11/issues/1193
+/// (plus a fair amount of extra curricular poking)
+/// TODO: If this proves useful, see about including it in pybind11.
+class pure_subclass {
+public:
+ pure_subclass(py::handle scope, const char *derivedClassName,
+ const py::object &superClass) {
+ py::object pyType =
+ py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
+ py::object metaclass = pyType(superClass);
+ py::dict attributes;
+
+ thisClass =
+ metaclass(derivedClassName, py::make_tuple(superClass), attributes);
+ scope.attr(derivedClassName) = thisClass;
+ }
+
+ template <typename Func, typename... Extra>
+ pure_subclass &def(const char *name, Func &&f, const Extra &...extra) {
+ py::cpp_function cf(
+ std::forward<Func>(f), py::name(name), py::is_method(thisClass),
+ py::sibling(py::getattr(thisClass, name, py::none())), extra...);
+ thisClass.attr(cf.name()) = cf;
+ return *this;
+ }
+
+ template <typename Func, typename... Extra>
+ pure_subclass &def_property_readonly(const char *name, Func &&f,
+ const Extra &...extra) {
+ py::cpp_function cf(
+ std::forward<Func>(f), py::name(name), py::is_method(thisClass),
+ py::sibling(py::getattr(thisClass, name, py::none())), extra...);
+ auto builtinProperty =
+ py::reinterpret_borrow<py::object>((PyObject *)&PyProperty_Type);
+ thisClass.attr(name) = builtinProperty(cf);
+ return *this;
+ }
+
+ template <typename Func, typename... Extra>
+ pure_subclass &def_staticmethod(const char *name, Func &&f,
+ const Extra &...extra) {
+ static_assert(!std::is_member_function_pointer<Func>::value,
+ "def_staticmethod(...) called with a non-static member "
+ "function pointer");
+ py::cpp_function cf(std::forward<Func>(f), py::name(name),
+ py::scope(thisClass), extra...);
+ thisClass.attr(cf.name()) = py::staticmethod(cf);
+ return *this;
+ }
+
+ template <typename Func, typename... Extra>
+ pure_subclass &def_classmethod(const char *name, Func &&f,
+ const Extra &...extra) {
+ static_assert(!std::is_member_function_pointer<Func>::value,
+ "def_classmethod(...) called with a non-static member "
+ "function pointer");
+ py::cpp_function cf(std::forward<Func>(f), py::name(name),
+ py::scope(thisClass), extra...);
+ thisClass.attr(cf.name()) =
+ py::reinterpret_borrow<py::object>(PyClassMethod_New(cf.ptr()));
+ return *this;
+ }
+
+ py::object get_class() const { return thisClass; }
+
+protected:
+ py::object superClass;
+ py::object thisClass;
+};
+
+/// Creates a custom subclass of mlir.ir.Attribute, implementing a casting
+/// constructor and type checking methods.
+class mlir_attribute_subclass : public pure_subclass {
+public:
+ using IsAFunctionTy = bool (*)(MlirAttribute);
+ using GetTypeIDFunctionTy = MlirTypeID (*)();
+
+ /// Subclasses by looking up the super-class dynamically.
+ mlir_attribute_subclass(py::handle scope, const char *attrClassName,
+ IsAFunctionTy isaFunction,
+ GetTypeIDFunctionTy getTypeIDFunction = nullptr)
+ : mlir_attribute_subclass(
+ scope, attrClassName, isaFunction,
+ py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("Attribute"),
+ getTypeIDFunction) {}
+
+ /// Subclasses with a provided mlir.ir.Attribute super-class. This must
+ /// be used if the subclass is being defined in the same extension module
+ /// as the mlir.ir class (otherwise, it will trigger a recursive
+ /// initialization).
+ mlir_attribute_subclass(py::handle scope, const char *typeClassName,
+ IsAFunctionTy isaFunction, const py::object &superCls,
+ GetTypeIDFunctionTy getTypeIDFunction = nullptr)
+ : pure_subclass(scope, typeClassName, superCls) {
+ // Casting constructor. Note that it hard, if not impossible, to properly
+ // call chain to parent `__init__` in pybind11 due to its special handling
+ // for init functions that don't have a fully constructed self-reference,
+ // which makes it impossible to forward it to `__init__` of a superclass.
+ // Instead, provide a custom `__new__` and call that of a superclass, which
+ // eventually calls `__init__` of the superclass. Since attribute subclasses
+ // have no additional members, we can just return the instance thus created
+ // without amending it.
+ std::string captureTypeName(
+ typeClassName); // As string in case if typeClassName is not static.
+ py::cpp_function newCf(
+ [superCls, isaFunction, captureTypeName](py::object cls,
+ py::object otherAttribute) {
+ MlirAttribute rawAttribute = py::cast<MlirAttribute>(otherAttribute);
+ if (!isaFunction(rawAttribute)) {
+ auto origRepr = py::repr(otherAttribute).cast<std::string>();
+ throw std::invalid_argument(
+ (llvm::Twine("Cannot cast attribute to ") + captureTypeName +
+ " (from " + origRepr + ")")
+ .str());
+ }
+ py::object self = superCls.attr("__new__")(cls, otherAttribute);
+ return self;
+ },
+ py::name("__new__"), py::arg("cls"), py::arg("cast_from_attr"));
+ thisClass.attr("__new__") = newCf;
+
+ // 'isinstance' method.
+ def_staticmethod(
+ "isinstance",
+ [isaFunction](MlirAttribute other) { return isaFunction(other); },
+ py::arg("other_attribute"));
+ def("__repr__", [superCls, captureTypeName](py::object self) {
+ return py::repr(superCls(self))
+ .attr("replace")(superCls.attr("__name__"), captureTypeName);
+ });
+ if (getTypeIDFunction) {
+ def_staticmethod("get_static_typeid",
+ [getTypeIDFunction]() { return getTypeIDFunction(); });
+ py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
+ getTypeIDFunction())(pybind11::cpp_function(
+ [thisClass = thisClass](const py::object &mlirAttribute) {
+ return thisClass(mlirAttribute);
+ }));
+ }
+ }
+};
+
+/// Creates a custom subclass of mlir.ir.Type, implementing a casting
+/// constructor and type checking methods.
+class mlir_type_subclass : public pure_subclass {
+public:
+ using IsAFunctionTy = bool (*)(MlirType);
+ using GetTypeIDFunctionTy = MlirTypeID (*)();
+
+ /// Subclasses by looking up the super-class dynamically.
+ mlir_type_subclass(py::handle scope, const char *typeClassName,
+ IsAFunctionTy isaFunction,
+ GetTypeIDFunctionTy getTypeIDFunction = nullptr)
+ : mlir_type_subclass(
+ scope, typeClassName, isaFunction,
+ py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Type"),
+ getTypeIDFunction) {}
+
+ /// Subclasses with a provided mlir.ir.Type super-class. This must
+ /// be used if the subclass is being defined in the same extension module
+ /// as the mlir.ir class (otherwise, it will trigger a recursive
+ /// initialization).
+ mlir_type_subclass(py::handle scope, const char *typeClassName,
+ IsAFunctionTy isaFunction, const py::object &superCls,
+ GetTypeIDFunctionTy getTypeIDFunction = nullptr)
+ : pure_subclass(scope, typeClassName, superCls) {
+ // Casting constructor. Note that it hard, if not impossible, to properly
+ // call chain to parent `__init__` in pybind11 due to its special handling
+ // for init functions that don't have a fully constructed self-reference,
+ // which makes it impossible to forward it to `__init__` of a superclass.
+ // Instead, provide a custom `__new__` and call that of a superclass, which
+ // eventually calls `__init__` of the superclass. Since attribute subclasses
+ // have no additional members, we can just return the instance thus created
+ // without amending it.
+ std::string captureTypeName(
+ typeClassName); // As string in case if typeClassName is not static.
+ py::cpp_function newCf(
+ [superCls, isaFunction, captureTypeName](py::object cls,
+ py::object otherType) {
+ MlirType rawType = py::cast<MlirType>(otherType);
+ if (!isaFunction(rawType)) {
+ auto origRepr = py::repr(otherType).cast<std::string>();
+ throw std::invalid_argument((llvm::Twine("Cannot cast type to ") +
+ captureTypeName + " (from " +
+ origRepr + ")")
+ .str());
+ }
+ py::object self = superCls.attr("__new__")(cls, otherType);
+ return self;
+ },
+ py::name("__new__"), py::arg("cls"), py::arg("cast_from_type"));
+ thisClass.attr("__new__") = newCf;
+
+ // 'isinstance' method.
+ def_staticmethod(
+ "isinstance",
+ [isaFunction](MlirType other) { return isaFunction(other); },
+ py::arg("other_type"));
+ def("__repr__", [superCls, captureTypeName](py::object self) {
+ return py::repr(superCls(self))
+ .attr("replace")(superCls.attr("__name__"), captureTypeName);
+ });
+ if (getTypeIDFunction) {
+ // 'get_static_typeid' method.
+ // This is modeled as a static method instead of a static property because
+ // `def_property_readonly_static` is not available in `pure_subclass` and
+ // we do not want to introduce the complexity that pybind uses to
+ // implement it.
+ def_staticmethod("get_static_typeid",
+ [getTypeIDFunction]() { return getTypeIDFunction(); });
+ py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
+ getTypeIDFunction())(pybind11::cpp_function(
+ [thisClass = thisClass](const py::object &mlirType) {
+ return thisClass(mlirType);
+ }));
+ }
+ }
+};
+
+/// Creates a custom subclass of mlir.ir.Value, implementing a casting
+/// constructor and type checking methods.
+class mlir_value_subclass : public pure_subclass {
+public:
+ using IsAFunctionTy = bool (*)(MlirValue);
+
+ /// Subclasses by looking up the super-class dynamically.
+ mlir_value_subclass(py::handle scope, const char *valueClassName,
+ IsAFunctionTy isaFunction)
+ : mlir_value_subclass(
+ scope, valueClassName, isaFunction,
+ py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Value")) {
+ }
+
+ /// Subclasses with a provided mlir.ir.Value super-class. This must
+ /// be used if the subclass is being defined in the same extension module
+ /// as the mlir.ir class (otherwise, it will trigger a recursive
+ /// initialization).
+ mlir_value_subclass(py::handle scope, const char *valueClassName,
+ IsAFunctionTy isaFunction, const py::object &superCls)
+ : pure_subclass(scope, valueClassName, superCls) {
+ // Casting constructor. Note that it hard, if not impossible, to properly
+ // call chain to parent `__init__` in pybind11 due to its special handling
+ // for init functions that don't have a fully constructed self-reference,
+ // which makes it impossible to forward it to `__init__` of a superclass.
+ // Instead, provide a custom `__new__` and call that of a superclass, which
+ // eventually calls `__init__` of the superclass. Since attribute subclasses
+ // have no additional members, we can just return the instance thus created
+ // without amending it.
+ std::string captureValueName(
+ valueClassName); // As string in case if valueClassName is not static.
+ py::cpp_function newCf(
+ [superCls, isaFunction, captureValueName](py::object cls,
+ py::object otherValue) {
+ MlirValue rawValue = py::cast<MlirValue>(otherValue);
+ if (!isaFunction(rawValue)) {
+ auto origRepr = py::repr(otherValue).cast<std::string>();
+ throw std::invalid_argument((llvm::Twine("Cannot cast value to ") +
+ captureValueName + " (from " +
+ origRepr + ")")
+ .str());
+ }
+ py::object self = superCls.attr("__new__")(cls, otherValue);
+ return self;
+ },
+ py::name("__new__"), py::arg("cls"), py::arg("cast_from_value"));
+ thisClass.attr("__new__") = newCf;
+
+ // 'isinstance' method.
+ def_staticmethod(
+ "isinstance",
+ [isaFunction](MlirValue other) { return isaFunction(other); },
+ py::arg("other_value"));
+ }
+};
+
+} // namespace adaptors
+
+} // namespace python
+} // namespace mlir
+
+#endif // MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index f482385..ab9b9f2 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -39,8 +39,7 @@ void addTosaToLinalgPasses(
TosaToLinalgNamedOptions(),
// Note: Default to 'none' level unless otherwise specified.
std::optional<tosa::TosaValidationOptions> validationOptions =
- tosa::TosaValidationOptions{
- {"none"}, {"none"}, false, false, tosa::TosaLevelEnum::None});
+ tosa::TosaValidationOptions{false, false});
/// Populates TOSA to linalg pipelines
/// Currently, this includes only the "tosa-to-linalg-pipeline".
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 29001e2..db1b7e3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -1029,6 +1029,24 @@ foreach smallT = [
attr-dict $src `,` $scale `:` type($res)
}];
}
+
+
+ def ROCDL_CvtScaleF32SrPk8 # smallT.nameForOp # largeT.nameForOp # Op :
+ ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.sr.pk8." # smallT.name # "." # largeT.name,
+ [Pure], 1>,
+ Arguments<(ins largeT.type:$src, I32:$seed, F32:$scale)> {
+ let results = (outs smallT.type:$res);
+ let summary = "Scale and convert packed "
+ # largeT.name # " to packed " # smallT.name # " with stochastic rounding";
+ let description = [{
+ Convert 8 packed }] # largeT.name # [{ values to packed }]
+ # smallT.name # [{, multiplying by the exponent part of `scale`
+ before doing so and apply stochastic rounding. This op is for gfx1250+ arch.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `,` $seed `,` $scale `:` type($res)
+ }];
+ }
} // foreach largeT
} // foreach smallTOp
diff --git a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
index 4f7a842..2dd6121 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
@@ -190,8 +190,9 @@ def XeVM_StoreCacheControlAttr
def XeVM_BlockLoadOp
: XeVM_Op<"blockload">,
- Results<(
- outs FixedVectorOfRankAndType<[1], [XeVM_1DBlockElemType]>:$res)>,
+ Results<(outs AnyTypeOf<
+ [XeVM_1DBlockElemType,
+ FixedVectorOfRankAndType<[1], [XeVM_1DBlockElemType]>]>:$res)>,
Arguments<(ins Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr,
OptionalAttr<XeVM_LoadCacheControlAttr>:$cache_control)> {
let summary = "subgroup block load";
@@ -228,7 +229,9 @@ def XeVM_BlockLoadOp
def XeVM_BlockStoreOp
: XeVM_Op<"blockstore">,
Arguments<(ins Arg<LLVM_AnyPointer, "", [MemWrite]>:$ptr,
- FixedVectorOfRankAndType<[1], [XeVM_1DBlockElemType]>:$val,
+ AnyTypeOf<[XeVM_1DBlockElemType,
+ FixedVectorOfRankAndType<[1],
+ [XeVM_1DBlockElemType]>]>:$val,
OptionalAttr<XeVM_StoreCacheControlAttr>:$cache_control)> {
let summary = "subgroup block store";
let description = [{
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
index 9ee5079..10491f6 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
@@ -20,24 +20,67 @@
namespace mlir {
namespace tosa {
+struct TosaLevel {
+ int32_t MAX_RANK = 0;
+ int32_t MAX_KERNEL = 0;
+ int32_t MAX_STRIDE = 0;
+ int32_t MAX_SCALE = 0;
+ int32_t MAX_LOG2_SIZE = 0;
+ int32_t MAX_NESTING = 0;
+ int32_t MAX_TENSOR_LIST_SIZE = 0;
+
+ bool operator==(const TosaLevel &rhs) {
+ return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
+ MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
+ MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
+ MAX_NESTING == rhs.MAX_NESTING &&
+ MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE;
+ }
+};
+
+static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64};
+static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048,
+ 63, 256, 256};
+
+TargetEnvAttr lookupTargetEnv(Operation *op);
+TargetEnvAttr getDefaultTargetEnv(MLIRContext *context);
+
+/// Queries the target environment recursively from enclosing symbol table ops
+/// containing the given `op` or returns the default target environment as
+/// returned by getDefaultTargetEnv() if not provided.
+TargetEnvAttr lookupTargetEnvOrDefault(Operation *op);
+
/// This class represents the capability enabled in the target implementation
-/// such as profile, extension, and level.
+/// such as profile, extension, and level. It's a wrapper class around
+/// tosa::TargetEnvAttr.
class TargetEnv {
public:
TargetEnv() {}
- explicit TargetEnv(const SmallVectorImpl<Profile> &profiles,
- const SmallVectorImpl<Extension> &extensions) {
+ explicit TargetEnv(Level level, const ArrayRef<Profile> &profiles,
+ const ArrayRef<Extension> &extensions)
+ : level(level) {
enabledProfiles.insert_range(profiles);
-
enabledExtensions.insert_range(extensions);
}
+ explicit TargetEnv(TargetEnvAttr targetAttr)
+ : TargetEnv(targetAttr.getLevel(), targetAttr.getProfiles(),
+ targetAttr.getExtensions()) {}
+
void addProfile(Profile p) { enabledProfiles.insert(p); }
void addExtension(Extension e) { enabledExtensions.insert(e); }
// TODO implement the following utilities.
// Version getSpecVersion() const;
- // TosaLevel getLevel() const;
+
+ TosaLevel getLevel() const {
+ if (level == Level::eightK)
+ return TOSA_LEVEL_EIGHTK;
+ else if (level == Level::none)
+ return TOSA_LEVEL_NONE;
+ else
+ llvm_unreachable("Unknown TOSA level");
+ };
// Returns true if the given profile is allowed.
bool allows(Profile prof) const { return enabledProfiles.count(prof) != 0; }
@@ -62,8 +105,9 @@ public:
}
private:
+ Level level;
llvm::SmallSet<Profile, 3> enabledProfiles;
- llvm::SmallSet<Extension, 8> enabledExtensions;
+ llvm::SmallSet<Extension, 13> enabledExtensions;
};
} // namespace tosa
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 80337fc..38cb293 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -245,6 +245,19 @@ def Tosa_NONE : I32EnumAttrCase<"none", 0>;
def Tosa_PRO_INT : I32EnumAttrCase<"pro_int", 1>;
def Tosa_PRO_FP : I32EnumAttrCase<"pro_fp", 2>;
+def Tosa_ProfileAttr
+ : Tosa_I32EnumAttr<"Profile", "supported TOSA profiles", "prof",
+ [Tosa_PRO_INT, Tosa_PRO_FP, Tosa_NONE]> {
+ let extraClassDeclaration = [{
+ static llvm::SmallVector<Profile, 2> getAllValues() {
+ return {Profile::pro_int, Profile::pro_fp};
+ }
+ }];
+}
+
+def Tosa_ProfileArrayAttr
+ : TypedArrayAttrBase<Tosa_ProfileAttr, "TOSA profile array attribute">;
+
def Tosa_EXT_NONE : I32EnumAttrCase<"none", 0>;
def Tosa_EXT_INT16 : I32EnumAttrCase<"int16", 1>;
def Tosa_EXT_INT4 : I32EnumAttrCase<"int4", 2>;
@@ -264,17 +277,27 @@ def Tosa_ExtensionAttr
Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE,
Tosa_EXT_CONTROLFLOW, Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND,
Tosa_EXT_DYNAMIC
- ]>;
+ ]> {
+ let extraClassDeclaration = [{
+ static llvm::SmallVector<Extension, 11> getAllValues() {
+ return {
+ Extension::int16, Extension::int4, Extension::bf16,
+ Extension::fp8e4m3, Extension::fp8e5m2, Extension::fft,
+ Extension::variable, Extension::controlflow, Extension::doubleround,
+ Extension::inexactround, Extension::dynamic
+ };
+ }
+ }];
+}
def Tosa_ExtensionArrayAttr
: TypedArrayAttrBase<Tosa_ExtensionAttr, "TOSA extension array attribute">;
-def Tosa_ProfileAttr
- : Tosa_I32EnumAttr<"Profile", "supported TOSA profiles", "prof",
- [Tosa_PRO_INT, Tosa_PRO_FP, Tosa_NONE]>;
+def Tosa_LVL_NONE : I32EnumAttrCase<"none", 0>;
+def Tosa_LVL_8K : I32EnumAttrCase<"eightK", 1, "8k">;
-def Tosa_ProfileArrayAttr
- : TypedArrayAttrBase<Tosa_ProfileAttr, "TOSA profile array attribute">;
+def Tosa_LevelAttr
+ : Tosa_I32EnumAttr<"Level", "supported TOSA levels", "level", [Tosa_LVL_NONE, Tosa_LVL_8K]>;
// The base class for defining op availability dimensions.
class Availability {
@@ -382,6 +405,21 @@ class Extension<list<I32EnumAttrCase> extensions> : Availability {
}
//===----------------------------------------------------------------------===//
+// TOSA target environment.
+//===----------------------------------------------------------------------===//
+def Tosa_TargetEnv : Tosa_Attr<"TargetEnv", "target_env"> {
+ let summary = "Target environment information.";
+ let parameters = ( ins
+ "Level": $level,
+ ArrayRefParameter<"Profile">: $profiles,
+ ArrayRefParameter<"Extension">: $extensions
+ );
+
+ let assemblyFormat = "`<` `level` `=` $level `,` `profiles` `=` `[` $profiles `]` `,` "
+ "`extensions` `=` `[` $extensions `]` `>`";
+}
+
+//===----------------------------------------------------------------------===//
// Iterable attributes.
//===----------------------------------------------------------------------===//
// Defined in `section 3. Enumerations` of the TOSA specification.
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt
index 7484473..f52b82a 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -1,7 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name TosaOpt)
-mlir_tablegen(PassesEnums.h.inc -gen-enum-decls)
-mlir_tablegen(PassesEnums.cpp.inc -gen-enum-defs)
add_mlir_dialect_tablegen_target(MLIRTosaPassIncGen)
add_mlir_doc(Passes TosaPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index 306e4b1..ba99d2f 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -15,7 +15,6 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
-#include "mlir/Dialect/Tosa/Transforms/PassesEnums.h.inc"
#include "mlir/Pass/Pass.h"
namespace mlir {
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index b966828..6ae19d8 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -65,14 +65,6 @@ def TosaOptionalDecompositionsPass
}];
}
-def TosaLevelType : I32EnumAttr<"TosaLevelEnum", "Tosa level",
- [
- I32EnumAttrCase<"None", 0, "none">,
- I32EnumAttrCase<"EightK", 1, "8k">,
- ]>{
- let cppNamespace = "mlir::tosa";
-}
-
def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
let summary = "Validates TOSA dialect";
let description = [{
@@ -81,10 +73,6 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
}];
let options = [
- ListOption<"profile", "profile", "std::string",
- "Validate if operations match for the given profile set">,
- ListOption<"extension", "extension", "std::string",
- "Validate if operations match for the given extension set">,
Option<"strictOpSpecAlignment", "strict-op-spec-alignment", "bool",
/*default=*/"false",
"Verify if the properties of certain operations align the spec requirement">,
@@ -92,17 +80,7 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
/*default=*/"false",
"Disable checks for operations that are determined to be invalid due to their "
"operand/result datatypes not aligning with the 'Supported Data Types' "
- "sections of the specifciation">,
- Option<"level", "level", "mlir::tosa::TosaLevelEnum",
- /*default=*/"mlir::tosa::TosaLevelEnum::EightK",
- "Validate if operator parameters are within specfication for the given level",
- [{::llvm::cl::values(
- clEnumValN(mlir::tosa::TosaLevelEnum::EightK, "8k",
- "Ranges are expected to be sufficient for applications with frame sizes up to 8K."),
- clEnumValN(mlir::tosa::TosaLevelEnum::None, "none",
- "Allows the full range of arguments specified by the operations according "
- "to the operation data types.")
- )}]>
+ "sections of the specifciation">
];
}
@@ -141,4 +119,44 @@ def TosaConvertIntegerTypeToSignless : Pass<"tosa-convert-integer-type-to-signle
}];
}
+def TosaAttachTarget : Pass<"tosa-attach-target", "ModuleOp"> {
+ let summary = "Attach tosa.target_env information to the given module.";
+
+ let description = [{
+ This pass allows the user to specify a TOSA target environment consisting of
+ the following components: level, profiles and extensions.
+
+ The target environment is attached to the module as an attribute, allowing other
+ transformations to query the selected target and adapt their behaviour based on
+ this information.
+ }];
+
+ let dependentDialects = [
+ "func::FuncDialect",
+ "tosa::TosaDialect",
+ ];
+
+ let options = [
+ Option<"level", "level", "mlir::tosa::Level",
+ /*default=*/"mlir::tosa::Level::eightK",
+ "The TOSA level that operators should conform to. A TOSA level defines "
+ "operator argument ranges that an implementation shall support.",
+ [{::llvm::cl::values(
+ clEnumValN(mlir::tosa::Level::eightK, "8k",
+ "Ranges are expected to be sufficient for applications with frame "
+ "sizes up to 8K."),
+ clEnumValN(mlir::tosa::Level::none, "none",
+ "Allows the full range of arguments specified by the operations according "
+ "to the operation data types.")
+ )}]>,
+ ListOption<"profiles", "profiles", "std::string",
+ "The TOSA profile(s) that operators should conform to. TOSA profiles "
+ "enable efficient implementation on different classes of device. Each "
+ "profile is an independent set of operations and data type combinations.">,
+ ListOption<"extensions", "extensions", "std::string",
+ "The TOSA extension(s) that operators should conform to. TOSA profile "
+ "extensions define optional operation and data type combinations.">
+ ];
+}
+
#endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index c6a3ba9..e7602b4 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -115,11 +115,8 @@ void mlir::tosa::registerTosaToLinalgPipelines() {
TosaToLinalgOptions tosaToLinalgOptions;
TosaToLinalgNamedOptions tosaToLinalgNamedOptions;
TosaValidationOptions validationOptions;
- validationOptions.profile = {"none"};
- validationOptions.extension = {"none"};
validationOptions.strictOpSpecAlignment = false;
validationOptions.allowInvalidOpDatatypeCombinations = false;
- validationOptions.level = tosa::TosaLevelEnum::EightK;
tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions,
tosaToLinalgNamedOptions,
validationOptions);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
index 8295492..04e8836 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
@@ -310,26 +310,30 @@ LogicalResult BlockPrefetch2dOp::verify() {
template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
OpType, BlockLoadOp, BlockStoreOp>::value>>
LogicalResult verify1DBlockArg(OpType op) {
- VectorType vTy;
+ Type srcOrDstTy;
if constexpr (std::is_same_v<OpType, BlockLoadOp>)
- vTy = op.getResult().getType();
+ srcOrDstTy = op.getResult().getType();
else
- vTy = op.getVal().getType();
+ srcOrDstTy = op.getVal().getType();
+ VectorType vTy = dyn_cast<VectorType>(srcOrDstTy);
+ // scalar case is always valid
+ if (!vTy)
+ return success();
int elemTySize = vTy.getElementType().getIntOrFloatBitWidth() / 8;
if (elemTySize == 1) {
- llvm::SmallSet<int, 5> validSizes{1, 2, 4, 8, 16};
+ llvm::SmallSet<int, 4> validSizes{2, 4, 8, 16};
if (validSizes.contains(vTy.getNumElements()))
return success();
else
return op.emitOpError(
- "vector size must be 1, 2, 4, 8 or 16 for 8-bit element type");
+ "vector size must be 2, 4, 8 or 16 for 8-bit element type");
} else {
- llvm::SmallSet<int, 4> validSizes{1, 2, 4, 8};
+ llvm::SmallSet<int, 3> validSizes{2, 4, 8};
if (validSizes.contains(vTy.getNumElements()))
return success();
else
return op.emitOpError(
- "vector size must be 1, 2, 4 or 8 for element type > 8 bits");
+ "vector size must be 2, 4 or 8 for element type > 8 bits");
}
}
diff --git a/mlir/lib/Dialect/Tosa/CMakeLists.txt b/mlir/lib/Dialect/Tosa/CMakeLists.txt
index c6a438d..a95906a 100644
--- a/mlir/lib/Dialect/Tosa/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRTosaDialect
IR/TosaOps.cpp
IR/TosaCanonicalizations.cpp
+ IR/TargetEnv.cpp
Utils/ConversionUtils.cpp
Utils/QuantUtils.cpp
diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
new file mode 100644
index 0000000..5aad671
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
@@ -0,0 +1,42 @@
+//===-------------- TosaTarget.cpp - TOSA Target utilities ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
+
+namespace mlir {
+namespace tosa {
+
+TargetEnvAttr lookupTargetEnv(Operation *op) {
+ while (op) {
+ op = SymbolTable::getNearestSymbolTable(op);
+ if (!op)
+ break;
+
+ if (auto attr = op->getAttrOfType<TargetEnvAttr>(TargetEnvAttr::name))
+ return attr;
+
+ op = op->getParentOp();
+ }
+
+ return {};
+}
+
+TargetEnvAttr getDefaultTargetEnv(MLIRContext *context) {
+ return TargetEnvAttr::get(context, Level::eightK,
+ {Profile::pro_int, Profile::pro_fp}, {});
+}
+
+TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) {
+ if (auto attr = lookupTargetEnv(op))
+ return attr;
+
+ return getDefaultTargetEnv(op->getContext());
+}
+
+} // namespace tosa
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index 803993b..41b338d 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRTosaTransforms
+ TosaAttachTarget.cpp
TosaConvertIntegerTypeToSignless.cpp
TosaDecomposeTransposeConv.cpp
TosaDecomposeDepthwise.cpp
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp
new file mode 100644
index 0000000..bcb880a
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp
@@ -0,0 +1,87 @@
+//===- TosaAttachTarget.cpp
+//------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Attach target information to a TOSA module.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace tosa {
+
+#define GEN_PASS_DEF_TOSAATTACHTARGET
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+
+namespace {
+
+class TosaAttachTarget
+ : public tosa::impl::TosaAttachTargetBase<TosaAttachTarget> {
+ using Base::Base;
+
+public:
+ void runOnOperation() override {
+ llvm::SmallVector<Profile, 2> selectedProfiles;
+ if (!profiles.empty()) {
+ for (const std::string &prof : profiles) {
+ std::optional<Profile> profSymbol = symbolizeProfile(prof);
+ if (!profSymbol) {
+ llvm::SmallVector<Profile> allProfiles = ProfileAttr::getAllValues();
+ llvm::errs() << buildUnkownParameterErrorMessage(allProfiles,
+ "profile", prof);
+ return signalPassFailure();
+ }
+ selectedProfiles.push_back(profSymbol.value());
+ }
+ }
+
+ llvm::SmallVector<Extension, 10> selectedExtensions;
+ if (!extensions.empty()) {
+ for (const std::string &ext : extensions) {
+ std::optional<Extension> extSymbol = symbolizeExtension(ext);
+ if (!extSymbol) {
+ llvm::SmallVector<Extension> allExtensions =
+ ExtensionAttr::getAllValues();
+ llvm::errs() << buildUnkownParameterErrorMessage(allExtensions,
+ "extension", ext);
+ return signalPassFailure();
+ }
+ selectedExtensions.push_back(extSymbol.value());
+ }
+ }
+
+ ModuleOp mod = getOperation();
+ MLIRContext *ctx = &getContext();
+ const auto targetEnvAttr =
+ TargetEnvAttr::get(ctx, level, selectedProfiles, selectedExtensions);
+ mod->setAttr(TargetEnvAttr::name, targetEnvAttr);
+ }
+
+private:
+ template <typename T>
+ std::string buildUnkownParameterErrorMessage(llvm::SmallVector<T> &enumValues,
+ std::string enumName,
+ std::string unknownArgument) {
+ std::string message;
+ llvm::raw_string_ostream os(message);
+ os << "Unknown TOSA " << enumName << " name passed in '" << unknownArgument
+ << "', supported " << enumName << "s are: ";
+ llvm::interleaveComma(enumValues, os);
+ os << "\n";
+ return message;
+ }
+};
+
+} // namespace
+
+} // namespace tosa
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 4fc7ce8..82f2f7e 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -14,7 +14,6 @@
#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
#include "mlir/Dialect/Tosa/IR/TosaProfileCompliance.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
-#include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"
#include <string>
@@ -130,28 +129,6 @@ static LogicalResult checkConstantOperandNegate(Operation *op,
return success();
}
-struct TosaLevel {
- int32_t MAX_RANK = 0;
- int32_t MAX_KERNEL = 0;
- int32_t MAX_STRIDE = 0;
- int32_t MAX_SCALE = 0;
- int32_t MAX_LOG2_SIZE = 0;
- int32_t MAX_NESTING = 0;
- int32_t MAX_TENSOR_LIST_SIZE = 0;
-
- bool operator==(const TosaLevel &rhs) {
- return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
- MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
- MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
- MAX_NESTING == rhs.MAX_NESTING &&
- MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE;
- }
-};
-
-static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64};
-static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048,
- 63, 256, 256};
-
//===----------------------------------------------------------------------===//
// TOSA Validation Pass.
//===----------------------------------------------------------------------===//
@@ -162,12 +139,9 @@ public:
explicit TosaValidation(const TosaValidationOptions &options)
: TosaValidation() {
- this->profile = options.profile;
- this->extension = options.extension;
this->strictOpSpecAlignment = options.strictOpSpecAlignment;
this->allowInvalidOpDatatypeCombinations =
options.allowInvalidOpDatatypeCombinations;
- this->level = options.level;
}
void runOnOperation() final;
@@ -207,28 +181,28 @@ private:
LogicalResult levelCheckKernel(Operation *op, int32_t v,
const StringRef checkDesc) {
- if (v > tosaLevel.MAX_KERNEL)
+ if (v > targetEnv.getLevel().MAX_KERNEL)
return op->emitOpError() << "failed level check: " << checkDesc;
return success();
}
LogicalResult levelCheckStride(Operation *op, int32_t v,
const StringRef checkDesc) {
- if (v > tosaLevel.MAX_STRIDE)
+ if (v > targetEnv.getLevel().MAX_STRIDE)
return op->emitOpError() << "failed level check: " << checkDesc;
return success();
}
LogicalResult levelCheckScale(Operation *op, int32_t v,
const StringRef checkDesc) {
- if (v > tosaLevel.MAX_SCALE)
+ if (v > targetEnv.getLevel().MAX_SCALE)
return op->emitOpError() << "failed level check: " << checkDesc;
return success();
}
LogicalResult levelCheckListSize(Operation *op, int32_t v,
const StringRef checkDesc) {
- if (v > tosaLevel.MAX_TENSOR_LIST_SIZE)
+ if (v > targetEnv.getLevel().MAX_TENSOR_LIST_SIZE)
return op->emitOpError()
<< "failed level check for MAX_TENSOR_LIST_SIZE: " << checkDesc;
return success();
@@ -285,6 +259,7 @@ private:
template <typename T>
LogicalResult levelCheckRanks(T tosaOp) {
auto op = tosaOp.getOperation();
+ const TosaLevel tosaLevel = targetEnv.getLevel();
for (auto v : op->getOperands()) {
if (failed(levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK)))
return failure();
@@ -466,7 +441,7 @@ private:
int32_t maxNestedDepth = 0;
getMaxNestedDepth(op, maxNestedDepth);
- if (maxNestedDepth >= tosaLevel.MAX_NESTING) {
+ if (maxNestedDepth >= targetEnv.getLevel().MAX_NESTING) {
op->emitOpError() << "failed level check: " << maxNestedDepth
<< " >= MAX_NESTING";
return failure();
@@ -523,43 +498,6 @@ private:
return success();
}
- // configure profile and level values from pass options profileName and
- // levelName
- void configLevelAndProfile() {
- tosaLevel = TOSA_LEVEL_NONE;
- if (level == TosaLevelEnum::EightK) {
- tosaLevel = TOSA_LEVEL_EIGHTK;
- }
-
- if (!profile.empty()) {
- for (std::string &prof : profile) {
- auto profSymbol = symbolizeProfile(prof);
- if (profSymbol) {
- targetEnv.addProfile(profSymbol.value());
- } else {
- llvm::errs() << "unknown TOSA profile name passed in: " << prof
- << ", supported profiles are `pro_int` and `pro_fp`\n";
- return signalPassFailure();
- }
- }
- }
-
- if (!extension.empty()) {
- for (std::string &ext : extension) {
- auto extSymbol = symbolizeExtension(ext);
- if (extSymbol) {
- targetEnv.addExtension(extSymbol.value());
- } else {
- llvm::errs() << "unknown TOSA extension name passed in: " << ext
- << ", supported extension are int16, int4, bf16, "
- << "fp8e4m3, fp8e5m2, fft, variable, controlflow, "
- << "doubleround, inexactround and dynamic\n";
- return signalPassFailure();
- }
- }
- }
- }
-
LogicalResult CheckVariable(Operation *op);
LogicalResult CheckVariableReadOrWrite(Operation *op);
bool isValidElementType(Type type, const bool allowUnsigned = false);
@@ -567,7 +505,6 @@ private:
SmallVector<
std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>>
constCheckers;
- TosaLevel tosaLevel;
DenseMap<StringAttr, mlir::Type> variablesMap;
TosaProfileCompliance profileComp;
tosa::TargetEnv targetEnv;
@@ -576,13 +513,13 @@ private:
template <>
LogicalResult TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
auto *op = tosaOp.getOperation();
- if (failed(
- levelCheckRank(op, tosaOp.getInput(), "operand", tosaLevel.MAX_RANK)))
+ if (failed(levelCheckRank(op, tosaOp.getInput(), "operand",
+ targetEnv.getLevel().MAX_RANK)))
return failure();
// rank(output) = rank(input) - 1
if (failed(levelCheckRank(op, tosaOp.getOutput(), "result",
- tosaLevel.MAX_RANK - 1)))
+ targetEnv.getLevel().MAX_RANK - 1)))
return failure();
return success();
@@ -594,7 +531,7 @@ LogicalResult TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
// Only the condition input has rank limitation.
if (failed(levelCheckRank(op, tosaOp.getCondition(), "operand",
- tosaLevel.MAX_RANK)))
+ targetEnv.getLevel().MAX_RANK)))
return failure();
return success();
@@ -605,7 +542,7 @@ LogicalResult TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) {
auto *op = tosaOp.getOperation();
auto variableType = getVariableType(tosaOp);
if (failed(levelCheckRank(op, variableType, "variable type",
- tosaLevel.MAX_RANK)))
+ targetEnv.getLevel().MAX_RANK)))
return failure();
return success();
@@ -762,7 +699,8 @@ LogicalResult TosaValidation::levelCheckSize(Operation *op,
// defined in 1.7. Levels.
// For each tensor, the number of tensor elements multiplied by the
// element size in bytes must be representable as a tensor_size_t.
- const int64_t max_size = (INT64_C(1) << tosaLevel.MAX_LOG2_SIZE) - 1;
+ const int64_t max_size =
+ (INT64_C(1) << targetEnv.getLevel().MAX_LOG2_SIZE) - 1;
if (size > max_size)
return op->emitOpError()
<< "failed level check: " << operandOrResult
@@ -772,7 +710,7 @@ LogicalResult TosaValidation::levelCheckSize(Operation *op,
}
LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
- if (tosaLevel == TOSA_LEVEL_NONE) {
+ if (targetEnv.getLevel() == TOSA_LEVEL_NONE) {
// no need to do level checks
return success();
}
@@ -1282,12 +1220,12 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
}
void TosaValidation::runOnOperation() {
- configLevelAndProfile();
-
TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>();
if (!tosaDialect)
return;
+ targetEnv = tosa::TargetEnv(lookupTargetEnvOrDefault(getOperation()));
+
getOperation().walk([&](Operation *op) {
if (op->getDialect() != tosaDialect)
return;
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index cea5b25..9f5246d 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -440,11 +440,11 @@ declare_mlir_dialect_python_bindings(
DIALECT_NAME smt)
declare_mlir_dialect_python_bindings(
- ADD_TO_PARENT MLIRPythonSources.Dialects
- ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
- TD_FILE dialects/SPIRVOps.td
- SOURCES dialects/spirv.py
- DIALECT_NAME spirv)
+ ADD_TO_PARENT MLIRPythonSources.Dialects
+ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+ TD_FILE dialects/SPIRVOps.td
+ SOURCES dialects/spirv.py
+ DIALECT_NAME spirv)
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
@@ -501,6 +501,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
MODULE_NAME _mlir
ADD_TO_PARENT MLIRPythonSources.Core
ROOT_DIR "${PYTHON_SOURCE_DIR}"
+ PYTHON_BINDINGS_LIBRARY nanobind
SOURCES
MainModule.cpp
IRAffine.cpp
@@ -539,6 +540,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
declare_mlir_python_extension(MLIRPythonExtension.RegisterEverything
MODULE_NAME _mlirRegisterEverything
ROOT_DIR "${PYTHON_SOURCE_DIR}"
+ PYTHON_BINDINGS_LIBRARY nanobind
SOURCES
RegisterEverything.cpp
PRIVATE_LINK_LIBS
@@ -549,10 +551,11 @@ declare_mlir_python_extension(MLIRPythonExtension.RegisterEverything
MLIRCAPIRegisterEverything
)
-declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Nanobind
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind
MODULE_NAME _mlirDialectsLinalg
ADD_TO_PARENT MLIRPythonSources.Dialects.linalg
ROOT_DIR "${PYTHON_SOURCE_DIR}"
+ PYTHON_BINDINGS_LIBRARY nanobind
SOURCES
DialectLinalg.cpp
PRIVATE_LINK_LIBS
@@ -562,10 +565,11 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Nanobind
MLIRCAPILinalg
)
-declare_mlir_python_extension(MLIRPythonExtension.Dialects.GPU.Nanobind
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.GPU.Pybind
MODULE_NAME _mlirDialectsGPU
ADD_TO_PARENT MLIRPythonSources.Dialects.gpu
ROOT_DIR "${PYTHON_SOURCE_DIR}"
+ PYTHON_BINDINGS_LIBRARY nanobind
SOURCES
DialectGPU.cpp
PRIVATE_LINK_LIBS
@@ -575,10 +579,11 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.GPU.Nanobind
MLIRCAPIGPU
)
-declare_mlir_python_extension(MLIRPythonExtension.Dialects.LLVM.Nanobind
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.LLVM.Pybind
MODULE_NAME _mlirDialectsLLVM
ADD_TO_PARENT MLIRPythonSources.Dialects.llvm
ROOT_DIR "${PYTHON_SOURCE_DIR}"
+ PYTHON_BINDINGS_LIBRARY nanobind
SOURCES
DialectLLVM.cpp
PRIVATE_LINK_LIBS
@@ -588,10 +593,11 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.LLVM.Nanobind
MLIRCAPILLVM
)
-declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Nanobind
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Pybind
MODULE_NAME _mlirDialectsQuant
ADD_TO_PARENT MLIRPythonSources.Dialects.quant
ROOT_DIR "${PYTHON_SOURCE_DIR}"
+ PYTHON_BINDINGS_LIBRARY nanobind
SOURCES
DialectQuant.cpp
PRIVATE_LINK_LIBS
@@ -601,10 +607,11 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Nanobind
MLIRCAPIQuant
)
-declare_mlir_python_extension(MLIRPythonExtension.Dialects.NVGPU.Nanobind
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.NVGPU.Pybind
MODULE_NAME _mlirDialectsNVGPU
ADD_TO_PARENT MLIRPythonSources.Dialects.nvgpu
ROOT_DIR "${PYTHON_SOURCE_DIR}"
+ PYTHON_BINDINGS_LIBRARY nanobind
SOURCES
DialectNVGPU.cpp
PRIVATE_LINK_LIBS
@@ -614,10 +621,11 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.NVGPU.Nanobind
MLIRCAPINVGPU
)
-declare_mlir_python_extension(MLIRPythonExtension.Dialects.PDL.Nanobind
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.PDL.Pybind
MODULE_NAME _mlirDialectsPDL
ADD_TO_PARENT MLIRPythonSources.Dialects.pdl
ROOT_DIR "${PYTHON_SOURCE_DIR}"
+ PYTHON_BINDINGS_LIBRARY nanobind
SOURCES
DialectPDL.cpp
PRIVATE_LINK_LIBS
@@ -627,10 +635,11 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.PDL.Nanobind
MLIRCAPIPDL
)
-declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Nanobind
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Pybind
MODULE_NAME _mlirDialectsSparseTensor
ADD_TO_PARENT MLIRPythonSources.Dialects.sparse_tensor
ROOT_DIR "${PYTHON_SOURCE_DIR}"
+ PYTHON_BINDINGS_LIBRARY nanobind
SOURCES
DialectSparseTensor.cpp
PRIVATE_LINK_LIBS
@@ -640,10 +649,11 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Nanobind
MLIRCAPISparseTensor
)
-declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Nanobind
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind
MODULE_NAME _mlirDialectsTransform
ADD_TO_PARENT MLIRPythonSources.Dialects.transform
ROOT_DIR "${PYTHON_SOURCE_DIR}"
+ PYTHON_BINDINGS_LIBRARY nanobind
SOURCES
DialectTransform.cpp
PRIVATE_LINK_LIBS
@@ -653,10 +663,11 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Nanobind
MLIRCAPITransformDialect
)
-declare_mlir_python_extension(MLIRPythonExtension.Dialects.IRDL.Nanobind
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.IRDL.Pybind
MODULE_NAME _mlirDialectsIRDL
ADD_TO_PARENT MLIRPythonSources.Dialects.irdl
ROOT_DIR "${PYTHON_SOURCE_DIR}"
+ PYTHON_BINDINGS_LIBRARY nanobind
SOURCES
DialectIRDL.cpp
PRIVATE_LINK_LIBS
@@ -670,6 +681,7 @@ declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses
MODULE_NAME _mlirAsyncPasses
ADD_TO_PARENT MLIRPythonSources.Dialects.async
ROOT_DIR "${PYTHON_SOURCE_DIR}"
+ PYTHON_BINDINGS_LIBRARY nanobind
SOURCES
AsyncPasses.cpp
PRIVATE_LINK_LIBS
@@ -683,6 +695,7 @@ if(MLIR_ENABLE_EXECUTION_ENGINE)
MODULE_NAME _mlirExecutionEngine
ADD_TO_PARENT MLIRPythonSources.ExecutionEngine
ROOT_DIR "${PYTHON_SOURCE_DIR}"
+ PYTHON_BINDINGS_LIBRARY nanobind
SOURCES
ExecutionEngineModule.cpp
PRIVATE_LINK_LIBS
@@ -696,6 +709,7 @@ declare_mlir_python_extension(MLIRPythonExtension.GPUDialectPasses
MODULE_NAME _mlirGPUPasses
ADD_TO_PARENT MLIRPythonSources.Dialects.gpu
ROOT_DIR "${PYTHON_SOURCE_DIR}"
+ PYTHON_BINDINGS_LIBRARY nanobind
SOURCES
GPUPasses.cpp
PRIVATE_LINK_LIBS
@@ -708,6 +722,7 @@ declare_mlir_python_extension(MLIRPythonExtension.LinalgPasses
MODULE_NAME _mlirLinalgPasses
ADD_TO_PARENT MLIRPythonSources.Dialects.linalg
ROOT_DIR "${PYTHON_SOURCE_DIR}"
+ PYTHON_BINDINGS_LIBRARY nanobind
SOURCES
LinalgPasses.cpp
PRIVATE_LINK_LIBS
@@ -716,10 +731,11 @@ declare_mlir_python_extension(MLIRPythonExtension.LinalgPasses
MLIRCAPILinalg
)
-declare_mlir_python_extension(MLIRPythonExtension.Dialects.SMT.Nanobind
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.SMT.Pybind
MODULE_NAME _mlirDialectsSMT
ADD_TO_PARENT MLIRPythonSources.Dialects.smt
ROOT_DIR "${PYTHON_SOURCE_DIR}"
+ PYTHON_BINDINGS_LIBRARY nanobind
SOURCES
DialectSMT.cpp
# Headers must be included explicitly so they are installed.
@@ -736,6 +752,7 @@ declare_mlir_python_extension(MLIRPythonExtension.SparseTensorDialectPasses
MODULE_NAME _mlirSparseTensorPasses
ADD_TO_PARENT MLIRPythonSources.Dialects.sparse_tensor
ROOT_DIR "${PYTHON_SOURCE_DIR}"
+ PYTHON_BINDINGS_LIBRARY nanobind
SOURCES
SparseTensorPasses.cpp
PRIVATE_LINK_LIBS
@@ -748,6 +765,7 @@ declare_mlir_python_extension(MLIRPythonExtension.TransformInterpreter
MODULE_NAME _mlirTransformInterpreter
ADD_TO_PARENT MLIRPythonSources.Dialects.transform
ROOT_DIR "${PYTHON_SOURCE_DIR}"
+ PYTHON_BINDINGS_LIBRARY nanobind
SOURCES
TransformInterpreter.cpp
PRIVATE_LINK_LIBS
@@ -789,10 +807,23 @@ if(MLIR_INCLUDE_TESTS)
ADD_TO_PARENT MLIRPythonTestSources.Dialects.PythonTest
SOURCES "dialects/_python_test_ops_gen.py")
+ declare_mlir_python_extension(MLIRPythonTestSources.PythonTestExtensionPybind11
+ MODULE_NAME _mlirPythonTestPybind11
+ ADD_TO_PARENT MLIRPythonTestSources.Dialects
+ ROOT_DIR "${MLIR_SOURCE_DIR}/test/python/lib"
+ PYTHON_BINDINGS_LIBRARY pybind11
+ SOURCES
+ PythonTestModulePybind11.cpp
+ PRIVATE_LINK_LIBS
+ LLVMSupport
+ EMBED_CAPI_LINK_LIBS
+ MLIRCAPIPythonTestDialect
+ )
declare_mlir_python_extension(MLIRPythonTestSources.PythonTestExtensionNanobind
MODULE_NAME _mlirPythonTestNanobind
ADD_TO_PARENT MLIRPythonTestSources.Dialects
ROOT_DIR "${MLIR_SOURCE_DIR}/test/python/lib"
+ PYTHON_BINDINGS_LIBRARY nanobind
SOURCES
PythonTestModuleNanobind.cpp
PRIVATE_LINK_LIBS
diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py
index 56d3c0f..9380896 100644
--- a/mlir/python/mlir/dialects/python_test.py
+++ b/mlir/python/mlir/dialects/python_test.py
@@ -5,7 +5,12 @@
from ._python_test_ops_gen import *
-def register_python_test_dialect(registry):
- from .._mlir_libs import _mlirPythonTestNanobind
+def register_python_test_dialect(registry, use_nanobind):
+ if use_nanobind:
+ from .._mlir_libs import _mlirPythonTestNanobind
- _mlirPythonTestNanobind.register_dialect(registry)
+ _mlirPythonTestNanobind.register_dialect(registry)
+ else:
+ from .._mlir_libs import _mlirPythonTestPybind11
+
+ _mlirPythonTestPybind11.register_dialect(registry)
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 7ddc70a..11477d0 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -12,7 +12,7 @@ from ._mlir_libs._mlir.ir import _GlobalDebug
from ._mlir_libs._mlir import (
register_type_caster,
register_value_caster,
- globals,
+ globals as _globals,
)
from ._mlir_libs import (
get_dialect_registry,
@@ -32,17 +32,17 @@ def loc_tracebacks(*, max_depth: int | None = None) -> Iterable[None]:
max_depth: Maximum number of frames to include in the location.
If None, the default limit is used.
"""
- old_enabled = globals.loc_tracebacks_enabled()
- old_limit = globals.loc_tracebacks_frame_limit()
+ old_enabled = _globals.loc_tracebacks_enabled()
+ old_limit = _globals.loc_tracebacks_frame_limit()
try:
- globals.set_loc_tracebacks_frame_limit(max_depth)
+ _globals.set_loc_tracebacks_frame_limit(max_depth)
if not old_enabled:
- globals.set_loc_tracebacks_enabled(True)
+ _globals.set_loc_tracebacks_enabled(True)
yield
finally:
if not old_enabled:
- globals.set_loc_tracebacks_enabled(False)
- globals.set_loc_tracebacks_frame_limit(old_limit)
+ _globals.set_loc_tracebacks_enabled(False)
+ _globals.set_loc_tracebacks_frame_limit(old_limit)
# Convenience decorator for registering user-friendly Attribute builders.
diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt
index 5ff9500..abe0925 100644
--- a/mlir/python/requirements.txt
+++ b/mlir/python/requirements.txt
@@ -1,4 +1,6 @@
+nanobind>=2.9, <3.0
numpy>=1.19.5, <=2.1.2
+pybind11>=2.10.0, <=2.13.6
PyYAML>=5.4.0, <=6.0.1
ml_dtypes>=0.1.0, <=0.6.0; python_version<"3.13" # provides several NumPy dtype extensions, including the bf16
ml_dtypes>=0.5.0, <=0.6.0; python_version>="3.13"
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index b7ca71a..aaf9f80 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1973,14 +1973,14 @@ llvm.func @invalid_xevm_prefetch(%arg0: !llvm.ptr) {
// -----
llvm.func @invalid_xevm_blockload(%arg0: !llvm.ptr<1>) {
- // expected-error@+1 {{op vector size must be 1, 2, 4 or 8 for element type > 8 bits}}
+ // expected-error@+1 {{op vector size must be 2, 4 or 8 for element type > 8 bits}}
%0 = xevm.blockload %arg0 : (!llvm.ptr<1>) -> vector<3xi16>
llvm.return
}
// -----
llvm.func @invalid_xevm_blockstore(%arg0: !llvm.ptr<1>, %arg1: vector<5xi8>) {
- // expected-error@+1 {{op vector size must be 1, 2, 4, 8 or 16 for 8-bit element type}}
+ // expected-error@+1 {{op vector size must be 2, 4, 8 or 16 for 8-bit element type}}
xevm.blockstore %arg0, %arg1 : (!llvm.ptr<1>, vector<5xi8>)
llvm.return
}
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 6134695..a88b59a 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -1100,6 +1100,39 @@ llvm.func @rocdl.cvt.scalef32.pk8(%v8xf32: vector<8xf32>,
// -----
+// CHECK-LABEL: rocdl.cvt.scalef32.sr.pk8
+llvm.func @rocdl.cvt.scalef32.sr.pk8(%v8xf32: vector<8xf32>,
+ %v8xf16: vector<8xf16>,
+ %v8xbf16: vector<8xbf16>,
+ %seed: i32,
+ %scale: f32) {
+
+ // CHECK: rocdl.cvt.scalef32.sr.pk8.fp8.f32
+ %0 = rocdl.cvt.scalef32.sr.pk8.fp8.f32 %v8xf32, %seed, %scale : vector<2xi32>
+ // CHECK: rocdl.cvt.scalef32.sr.pk8.bf8.f32
+ %1 = rocdl.cvt.scalef32.sr.pk8.bf8.f32 %v8xf32, %seed, %scale : vector<2xi32>
+ // CHECK: rocdl.cvt.scalef32.sr.pk8.fp4.f32
+ %2 = rocdl.cvt.scalef32.sr.pk8.fp4.f32 %v8xf32, %seed, %scale : i32
+
+ // CHECK: rocdl.cvt.scalef32.sr.pk8.fp8.f16
+ %3 = rocdl.cvt.scalef32.sr.pk8.fp8.f16 %v8xf16, %seed, %scale : vector<2xi32>
+ // CHECK: rocdl.cvt.scalef32.sr.pk8.bf8.f16
+ %4 = rocdl.cvt.scalef32.sr.pk8.bf8.f16 %v8xf16, %seed, %scale : vector<2xi32>
+ // CHECK: rocdl.cvt.scalef32.sr.pk8.fp4.f16
+ %5 = rocdl.cvt.scalef32.sr.pk8.fp4.f16 %v8xf16, %seed, %scale : i32
+
+ // CHECK: rocdl.cvt.scalef32.sr.pk8.fp8.bf16
+ %6 = rocdl.cvt.scalef32.sr.pk8.fp8.bf16 %v8xbf16, %seed, %scale : vector<2xi32>
+ // CHECK: rocdl.cvt.scalef32.sr.pk8.bf8.bf16
+ %7 = rocdl.cvt.scalef32.sr.pk8.bf8.bf16 %v8xbf16, %seed, %scale : vector<2xi32>
+ // CHECK: rocdl.cvt.scalef32.sr.pk8.fp4.bf16
+ %8 = rocdl.cvt.scalef32.sr.pk8.fp4.bf16 %v8xbf16, %seed, %scale : i32
+
+ llvm.return
+}
+
+// -----
+
// CHECK-LABEL: rocdl.cvt.scale.pk16
llvm.func @rocdl.cvt.scale.pk16(%v3xi32: vector<3xi32>, %scale:i32) {
diff --git a/mlir/test/Dialect/Tosa/dynamic_extension.mlir b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
index aaf8371..60b70b8 100644
--- a/mlir/test/Dialect/Tosa/dynamic_extension.mlir
+++ b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
@@ -2,7 +2,7 @@
// Check operations when the dynamic extension is enabled.
//--------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp extension=dynamic allow-invalid-op-datatype-combinations"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=dynamic" -tosa-validate="strict-op-spec-alignment allow-invalid-op-datatype-combinations"
// -----
diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index 2f9421c..334f52a 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="level=none profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="level=none profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic" -tosa-validate="strict-op-spec-alignment"
// -----
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index e60f1c9b..2a3985c 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -4,7 +4,7 @@
// validation flow.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
func.func @test_cast(%arg0: tensor<i1>) -> tensor<5xi32> {
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 1daabe9..e5c9402 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -2,7 +2,7 @@
// Enable all supported profiles to focus the verification of expected extension requirement errors.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp" -tosa-validate="strict-op-spec-alignment"
// -----
func.func @test_argmax(%arg0: tensor<14x19xbf16>) -> tensor<14xi32> {
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 5bf2dbb8..8cc357e 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="extension=dynamic"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="extensions=dynamic" -tosa-validate
func.func @test_argmax_rank_invalid(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> {
// expected-error@+1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}}
diff --git a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
index 225b962..09e96ec 100644
--- a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
@@ -2,7 +2,7 @@
// Enable all supported extensions to focus the verification of expected profile requirement errors.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
// -----
func.func @test_add_i32(%arg0: tensor<13x21x1xi32>, %arg1: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
index 58a73d6..7ff8065 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
@@ -2,7 +2,7 @@
// Enable all supported extensions to focus the verification of expected profile requirement errors.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
// -----
func.func @test_const_f16() -> tensor<3x11x11x3xf16> {
diff --git a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
index a5784b3..48e79e4 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
@@ -2,7 +2,7 @@
// Enable all supported extensions to focus the verification of expected profile requirement errors.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
// -----
func.func @test_const_i1() -> tensor<3x11x11x3xi1> {
diff --git a/mlir/test/Dialect/Tosa/tosa-attach-target.mlir b/mlir/test/Dialect/Tosa/tosa-attach-target.mlir
new file mode 100644
index 0000000..d6c886c
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-attach-target.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s -split-input-file -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,dynamic level=none" | FileCheck %s --check-prefix=CHECK-ALL
+// RUN: mlir-opt %s -split-input-file -tosa-attach-target="level=8k" | FileCheck %s --check-prefix=CHECK-LVL-8K
+// RUN: mlir-opt %s -split-input-file -tosa-attach-target | FileCheck %s --check-prefix=CHECK-DEFAULT
+
+// -----
+
+// CHECK-ALL: module attributes {tosa.target_env = #tosa.target_env<level = none, profiles = [pro_int, pro_fp], extensions = [int16, int4, bf16, fp8e4m3, fp8e5m2, fft, variable, controlflow, doubleround, inexactround, dynamic]>}
+// CHECK-LVL-8K: module attributes {tosa.target_env = #tosa.target_env<level = "8k", profiles = [], extensions = []>}
+// CHECK-DEFAULT: module attributes {tosa.target_env = #tosa.target_env<level = "8k", profiles = [], extensions = []>}
+// CHECK-LABEL: test_simple
+func.func @test_simple(%arg0 : tensor<1x1x1x1xf32>, %arg1 : tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> {
+ %1 = tosa.add %arg0, %arg1 : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
+ return %1 : tensor<1x1x1x1xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir b/mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir
index f05ae7f..8e0ad0a 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround strict-op-spec-alignment" | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" --tosa-validate="strict-op-spec-alignment" | FileCheck %s
// -----
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir
index 88ec027..663159e 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir
@@ -4,7 +4,7 @@
// validation flow.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate | FileCheck %s
// -----
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 00ee6b7..1c0c2eb 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -1368,6 +1368,39 @@ llvm.func @rocdl.cvt.scalef32.pk8(%v8xf32: vector<8xf32>, %v8xf16: vector<8xf16>
llvm.return
}
+// CHECK-LABEL: rocdl.cvt.scalef32.sr.pk8
+// CHECK-SAME:(<8 x float> %[[V8F32:.+]], <8 x half> %[[V8F16:.+]], <8 x bfloat> %[[V8BF16:.+]], i32 %[[SEED:.+]], float %[[SCALE:.+]])
+llvm.func @rocdl.cvt.scalef32.sr.pk8(%v8xf32: vector<8xf32>,
+ %v8xf16: vector<8xf16>,
+ %v8xbf16: vector<8xbf16>,
+ %seed: i32,
+ %scale: f32) {
+
+ // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.fp8.f32(<8 x float> %[[V8F32]], i32 %[[SEED]], float %[[SCALE]])
+ %0 = rocdl.cvt.scalef32.sr.pk8.fp8.f32 %v8xf32, %seed, %scale : vector<2xi32>
+ // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.bf8.f32(<8 x float> %[[V8F32]], i32 %[[SEED]], float %[[SCALE]])
+ %1 = rocdl.cvt.scalef32.sr.pk8.bf8.f32 %v8xf32, %seed, %scale : vector<2xi32>
+ // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk8.fp4.f32(<8 x float> %[[V8F32]], i32 %[[SEED]], float %[[SCALE]])
+ %2 = rocdl.cvt.scalef32.sr.pk8.fp4.f32 %v8xf32, %seed, %scale : i32
+
+ // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.fp8.f16(<8 x half> %[[V8F16]], i32 %[[SEED]], float %[[SCALE]])
+ %3 = rocdl.cvt.scalef32.sr.pk8.fp8.f16 %v8xf16, %seed, %scale : vector<2xi32>
+ // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.bf8.f16(<8 x half> %[[V8F16]], i32 %[[SEED]], float %[[SCALE]])
+ %4 = rocdl.cvt.scalef32.sr.pk8.bf8.f16 %v8xf16, %seed, %scale : vector<2xi32>
+ // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk8.fp4.f16(<8 x half> %[[V8F16]], i32 %[[SEED]], float %[[SCALE]])
+ %5 = rocdl.cvt.scalef32.sr.pk8.fp4.f16 %v8xf16, %seed, %scale : i32
+
+ // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.fp8.bf16(<8 x bfloat> %[[V8BF16]], i32 %[[SEED]], float %[[SCALE]])
+ %6 = rocdl.cvt.scalef32.sr.pk8.fp8.bf16 %v8xbf16, %seed, %scale : vector<2xi32>
+ // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.bf8.bf16(<8 x bfloat> %[[V8BF16]], i32 %[[SEED]], float %[[SCALE]])
+ %7 = rocdl.cvt.scalef32.sr.pk8.bf8.bf16 %v8xbf16, %seed, %scale : vector<2xi32>
+ // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk8.fp4.bf16(<8 x bfloat> %[[V8BF16]], i32 %[[SEED]], float %[[SCALE]])
+ %8 = rocdl.cvt.scalef32.sr.pk8.fp4.bf16 %v8xbf16, %seed, %scale : i32
+
+ llvm.return
+}
+
+
// CHECK-LABEL: @rocdl.cvt.scale.pk16
// CHECK-SAME:(<3 x i32> %[[SRC0:.+]], i32 %[[SCALE:.+]])
llvm.func @rocdl.cvt.scale.pk16(%v3xi32: vector<3xi32>, %scale:i32) {
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 5a9acc7..1194e32 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -1,4 +1,5 @@
-# RUN: %PYTHON %s | FileCheck %s
+# RUN: %PYTHON %s pybind11 | FileCheck %s
+# RUN: %PYTHON %s nanobind | FileCheck %s
import sys
import typing
from typing import Union, Optional
@@ -9,14 +10,26 @@ import mlir.dialects.python_test as test
import mlir.dialects.tensor as tensor
import mlir.dialects.arith as arith
-from mlir._mlir_libs._mlirPythonTestNanobind import (
- TestAttr,
- TestType,
- TestTensorValue,
- TestIntegerRankedTensorType,
-)
-
-test.register_python_test_dialect(get_dialect_registry())
+if sys.argv[1] == "pybind11":
+ from mlir._mlir_libs._mlirPythonTestPybind11 import (
+ TestAttr,
+ TestType,
+ TestTensorValue,
+ TestIntegerRankedTensorType,
+ )
+
+ test.register_python_test_dialect(get_dialect_registry(), use_nanobind=False)
+elif sys.argv[1] == "nanobind":
+ from mlir._mlir_libs._mlirPythonTestNanobind import (
+ TestAttr,
+ TestType,
+ TestTensorValue,
+ TestIntegerRankedTensorType,
+ )
+
+ test.register_python_test_dialect(get_dialect_registry(), use_nanobind=True)
+else:
+ raise ValueError("Expected pybind11 or nanobind as argument")
def run(f):
diff --git a/mlir/test/python/lib/CMakeLists.txt b/mlir/test/python/lib/CMakeLists.txt
index f51a7b4..9a813da 100644
--- a/mlir/test/python/lib/CMakeLists.txt
+++ b/mlir/test/python/lib/CMakeLists.txt
@@ -1,6 +1,7 @@
set(LLVM_OPTIONAL_SOURCES
PythonTestCAPI.cpp
PythonTestDialect.cpp
+ PythonTestModulePybind11.cpp
PythonTestModuleNanobind.cpp
)
diff --git a/mlir/test/python/lib/PythonTestModulePybind11.cpp b/mlir/test/python/lib/PythonTestModulePybind11.cpp
new file mode 100644
index 0000000..94a5f51
--- /dev/null
+++ b/mlir/test/python/lib/PythonTestModulePybind11.cpp
@@ -0,0 +1,118 @@
+//===- PythonTestModule.cpp - Python extension for the PythonTest dialect -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// This is the pybind11 edition of the PythonTest dialect module.
+//===----------------------------------------------------------------------===//
+
+#include "PythonTestCAPI.h"
+#include "mlir-c/BuiltinAttributes.h"
+#include "mlir-c/BuiltinTypes.h"
+#include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/PybindAdaptors.h"
+
+namespace py = pybind11;
+using namespace mlir::python::adaptors;
+using namespace pybind11::literals;
+
+static bool mlirTypeIsARankedIntegerTensor(MlirType t) {
+ return mlirTypeIsARankedTensor(t) &&
+ mlirTypeIsAInteger(mlirShapedTypeGetElementType(t));
+}
+
+PYBIND11_MODULE(_mlirPythonTestPybind11, m) {
+ m.def(
+ "register_python_test_dialect",
+ [](MlirContext context, bool load) {
+ MlirDialectHandle pythonTestDialect =
+ mlirGetDialectHandle__python_test__();
+ mlirDialectHandleRegisterDialect(pythonTestDialect, context);
+ if (load) {
+ mlirDialectHandleLoadDialect(pythonTestDialect, context);
+ }
+ },
+ py::arg("context"), py::arg("load") = true);
+
+ m.def(
+ "register_dialect",
+ [](MlirDialectRegistry registry) {
+ MlirDialectHandle pythonTestDialect =
+ mlirGetDialectHandle__python_test__();
+ mlirDialectHandleInsertDialect(pythonTestDialect, registry);
+ },
+ py::arg("registry"));
+
+ mlir_attribute_subclass(m, "TestAttr",
+ mlirAttributeIsAPythonTestTestAttribute,
+ mlirPythonTestTestAttributeGetTypeID)
+ .def_classmethod(
+ "get",
+ [](const py::object &cls, MlirContext ctx) {
+ return cls(mlirPythonTestTestAttributeGet(ctx));
+ },
+ py::arg("cls"), py::arg("context") = py::none());
+
+ mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType,
+ mlirPythonTestTestTypeGetTypeID)
+ .def_classmethod(
+ "get",
+ [](const py::object &cls, MlirContext ctx) {
+ return cls(mlirPythonTestTestTypeGet(ctx));
+ },
+ py::arg("cls"), py::arg("context") = py::none());
+
+ auto typeCls =
+ mlir_type_subclass(m, "TestIntegerRankedTensorType",
+ mlirTypeIsARankedIntegerTensor,
+ py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("RankedTensorType"))
+ .def_classmethod(
+ "get",
+ [](const py::object &cls, std::vector<int64_t> shape,
+ unsigned width, MlirContext ctx) {
+ MlirAttribute encoding = mlirAttributeGetNull();
+ return cls(mlirRankedTensorTypeGet(
+ shape.size(), shape.data(), mlirIntegerTypeGet(ctx, width),
+ encoding));
+ },
+ "cls"_a, "shape"_a, "width"_a, "context"_a = py::none());
+
+ assert(py::hasattr(typeCls.get_class(), "static_typeid") &&
+ "TestIntegerRankedTensorType has no static_typeid");
+
+ MlirTypeID mlirRankedTensorTypeID = mlirRankedTensorTypeGetTypeID();
+
+ py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(mlirRankedTensorTypeID,
+ "replace"_a = true)(
+ pybind11::cpp_function([typeCls](const py::object &mlirType) {
+ return typeCls.get_class()(mlirType);
+ }));
+
+ auto valueCls = mlir_value_subclass(m, "TestTensorValue",
+ mlirTypeIsAPythonTestTestTensorValue)
+ .def("is_null", [](MlirValue &self) {
+ return mlirValueIsNull(self);
+ });
+
+ py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr(MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR)(
+ mlirRankedTensorTypeID)(
+ pybind11::cpp_function([valueCls](const py::object &valueObj) {
+ py::object capsule = mlirApiObjectToCapsule(valueObj);
+ MlirValue v = mlirPythonCapsuleToValue(capsule.ptr());
+ MlirType t = mlirValueGetType(v);
+ // This is hyper-specific in order to exercise/test registering a
+ // value caster from cpp (but only for a single test case; see
+ // testTensorValue python_test.py).
+ if (mlirShapedTypeHasStaticShape(t) &&
+ mlirShapedTypeGetDimSize(t, 0) == 1 &&
+ mlirShapedTypeGetDimSize(t, 1) == 2 &&
+ mlirShapedTypeGetDimSize(t, 2) == 3)
+ return valueCls.get_class()(valueObj);
+ return valueObj;
+ }));
+}
diff --git a/mlir/tools/mlir-linalg-ods-gen/update_core_linalg_named_ops.sh.in b/mlir/tools/mlir-linalg-ods-gen/update_core_linalg_named_ops.sh.in
index 0bb6a20..da4db39 100755
--- a/mlir/tools/mlir-linalg-ods-gen/update_core_linalg_named_ops.sh.in
+++ b/mlir/tools/mlir-linalg-ods-gen/update_core_linalg_named_ops.sh.in
@@ -26,7 +26,7 @@ export PYTHONPATH="$python_package_dir"
OUTPUT="$(
echo "### AUTOGENERATED from core_named_ops.py" && \
echo "### To regenerate, run: bin/update_core_linalg_named_ops.sh" && \
- "$python_exe" -m mlir.dialects.linalg.opdsl.dump_oplib.ops.core_named_ops \
+ "$python_exe" -m mlir.dialects.linalg.opdsl.dump_oplib .ops.core_named_ops \
)"
echo "$OUTPUT" > "$dest_file"
echo "Success."