diff options
Diffstat (limited to 'mlir')
47 files changed, 1461 insertions, 253 deletions
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake index ea34f94..fa6aec8 100644 --- a/mlir/cmake/modules/AddMLIRPython.cmake +++ b/mlir/cmake/modules/AddMLIRPython.cmake @@ -123,12 +123,12 @@ function(mlir_generate_type_stubs) "IMPORT_PATHS;DEPENDS_TARGETS;OUTPUTS;DEPENDS_TARGET_SRC_DEPS" ${ARGN}) - # for people installing a distro (e.g., pip install) of nanobind + # for people doing find_package(nanobind) if(EXISTS ${nanobind_DIR}/../src/stubgen.py) set(NB_STUBGEN "${nanobind_DIR}/../src/stubgen.py") elseif(EXISTS ${nanobind_DIR}/../stubgen.py) set(NB_STUBGEN "${nanobind_DIR}/../stubgen.py") - # for people using nanobind git source tree (e.g., FetchContent_Declare and FetchContent_MakeAvailable) + # for people using FetchContent_Declare and FetchContent_MakeAvailable elseif(EXISTS ${nanobind_SOURCE_DIR}/src/stubgen.py) set(NB_STUBGEN "${nanobind_SOURCE_DIR}/src/stubgen.py") elseif(EXISTS ${nanobind_SOURCE_DIR}/stubgen.py) @@ -226,10 +226,11 @@ endfunction() # EMBED_CAPI_LINK_LIBS: Dependent CAPI libraries that this extension depends # on. These will be collected for all extensions and put into an # aggregate dylib that is linked against. +# PYTHON_BINDINGS_LIBRARY: Either pybind11 or nanobind. function(declare_mlir_python_extension name) cmake_parse_arguments(ARG "" - "ROOT_DIR;MODULE_NAME;ADD_TO_PARENT" + "ROOT_DIR;MODULE_NAME;ADD_TO_PARENT;PYTHON_BINDINGS_LIBRARY" "SOURCES;PRIVATE_LINK_LIBS;EMBED_CAPI_LINK_LIBS" ${ARGN}) @@ -238,15 +239,20 @@ function(declare_mlir_python_extension name) endif() set(_install_destination "src/python/${name}") + if(NOT ARG_PYTHON_BINDINGS_LIBRARY) + set(ARG_PYTHON_BINDINGS_LIBRARY "pybind11") + endif() + add_library(${name} INTERFACE) set_target_properties(${name} PROPERTIES # Yes: Leading-lowercase property names are load bearing and the recommended # way to do this: https://gitlab.kitware.com/cmake/cmake/-/issues/19261 - EXPORT_PROPERTIES "mlir_python_SOURCES_TYPE;mlir_python_EXTENSION_MODULE_NAME;mlir_python_EMBED_CAPI_LINK_LIBS;mlir_python_DEPENDS" + EXPORT_PROPERTIES "mlir_python_SOURCES_TYPE;mlir_python_EXTENSION_MODULE_NAME;mlir_python_EMBED_CAPI_LINK_LIBS;mlir_python_DEPENDS;mlir_python_BINDINGS_LIBRARY" mlir_python_SOURCES_TYPE extension mlir_python_EXTENSION_MODULE_NAME "${ARG_MODULE_NAME}" mlir_python_EMBED_CAPI_LINK_LIBS "${ARG_EMBED_CAPI_LINK_LIBS}" mlir_python_DEPENDS "" + mlir_python_BINDINGS_LIBRARY "${ARG_PYTHON_BINDINGS_LIBRARY}" ) # Set the interface source and link_libs properties of the target @@ -335,12 +341,14 @@ function(add_mlir_python_modules name) elseif(_source_type STREQUAL "extension") # Native CPP extension. get_target_property(_module_name ${sources_target} mlir_python_EXTENSION_MODULE_NAME) + get_target_property(_bindings_library ${sources_target} mlir_python_BINDINGS_LIBRARY) # Transform relative source to based on root dir. set(_extension_target "${modules_target}.extension.${_module_name}.dso") add_mlir_python_extension(${_extension_target} "${_module_name}" INSTALL_COMPONENT ${modules_target} INSTALL_DIR "${ARG_INSTALL_PREFIX}/_mlir_libs" OUTPUT_DIRECTORY "${ARG_ROOT_PREFIX}/_mlir_libs" + PYTHON_BINDINGS_LIBRARY ${_bindings_library} LINK_LIBS PRIVATE ${sources_target} ${ARG_COMMON_CAPI_LINK_LIBS} @@ -745,7 +753,7 @@ endfunction() function(add_mlir_python_extension libname extname) cmake_parse_arguments(ARG "" - "INSTALL_COMPONENT;INSTALL_DIR;OUTPUT_DIRECTORY" + "INSTALL_COMPONENT;INSTALL_DIR;OUTPUT_DIRECTORY;PYTHON_BINDINGS_LIBRARY" "SOURCES;LINK_LIBS" ${ARGN}) if(ARG_UNPARSED_ARGUMENTS) @@ -753,7 +761,7 @@ function(add_mlir_python_extension libname extname) endif() # The extension itself must be compiled with RTTI and exceptions enabled. - # Also, some warning classes triggered by nanobind are disabled. + # Also, some warning classes triggered by pybind11 are disabled. set(eh_rtti_enable) if (MSVC) set(eh_rtti_enable /EHsc /GR) @@ -761,53 +769,62 @@ function(add_mlir_python_extension libname extname) set(eh_rtti_enable -frtti -fexceptions) endif () - nanobind_add_module(${libname} - NB_DOMAIN ${MLIR_BINDINGS_PYTHON_NB_DOMAIN} - FREE_THREADED - ${ARG_SOURCES} - ) + # The actual extension library produces a shared-object or DLL and has + # sources that must be compiled in accordance with pybind11 needs (RTTI and + # exceptions). + if(NOT DEFINED ARG_PYTHON_BINDINGS_LIBRARY OR ARG_PYTHON_BINDINGS_LIBRARY STREQUAL "pybind11") + pybind11_add_module(${libname} + ${ARG_SOURCES} + ) + elseif(ARG_PYTHON_BINDINGS_LIBRARY STREQUAL "nanobind") + nanobind_add_module(${libname} + NB_DOMAIN ${MLIR_BINDINGS_PYTHON_NB_DOMAIN} + FREE_THREADED + ${ARG_SOURCES} + ) - if (NOT MLIR_DISABLE_CONFIGURE_PYTHON_DEV_PACKAGES - AND (LLVM_COMPILER_IS_GCC_COMPATIBLE OR CLANG_CL)) - # Avoid some warnings from upstream nanobind. - # If a superproject set MLIR_DISABLE_CONFIGURE_PYTHON_DEV_PACKAGES, let - # the super project handle compile options as it wishes. - get_property(NB_LIBRARY_TARGET_NAME TARGET ${libname} PROPERTY LINK_LIBRARIES) - target_compile_options(${NB_LIBRARY_TARGET_NAME} - PRIVATE - -Wall -Wextra -Wpedantic - -Wno-c++98-compat-extra-semi - -Wno-cast-qual - -Wno-covered-switch-default - -Wno-deprecated-literal-operator - -Wno-nested-anon-types - -Wno-unused-parameter - -Wno-zero-length-array - ${eh_rtti_enable}) - - target_compile_options(${libname} - PRIVATE - -Wall -Wextra -Wpedantic - -Wno-c++98-compat-extra-semi - -Wno-cast-qual - -Wno-covered-switch-default - -Wno-deprecated-literal-operator - -Wno-nested-anon-types - -Wno-unused-parameter - -Wno-zero-length-array - ${eh_rtti_enable}) - endif() + if (NOT MLIR_DISABLE_CONFIGURE_PYTHON_DEV_PACKAGES + AND (LLVM_COMPILER_IS_GCC_COMPATIBLE OR CLANG_CL)) + # Avoid some warnings from upstream nanobind. + # If a superproject set MLIR_DISABLE_CONFIGURE_PYTHON_DEV_PACKAGES, let + # the super project handle compile options as it wishes. + get_property(NB_LIBRARY_TARGET_NAME TARGET ${libname} PROPERTY LINK_LIBRARIES) + target_compile_options(${NB_LIBRARY_TARGET_NAME} + PRIVATE + -Wall -Wextra -Wpedantic + -Wno-c++98-compat-extra-semi + -Wno-cast-qual + -Wno-covered-switch-default + -Wno-deprecated-literal-operator + -Wno-nested-anon-types + -Wno-unused-parameter + -Wno-zero-length-array + ${eh_rtti_enable}) + + target_compile_options(${libname} + PRIVATE + -Wall -Wextra -Wpedantic + -Wno-c++98-compat-extra-semi + -Wno-cast-qual + -Wno-covered-switch-default + -Wno-deprecated-literal-operator + -Wno-nested-anon-types + -Wno-unused-parameter + -Wno-zero-length-array + ${eh_rtti_enable}) + endif() - if(APPLE) - # NanobindAdaptors.h uses PyClassMethod_New to build `pure_subclass`es but nanobind - # doesn't declare this API as undefined in its linker flags. So we need to declare it as such - # for downstream users that do not do something like `-undefined dynamic_lookup`. - # Same for the rest. - target_link_options(${libname} PUBLIC - "LINKER:-U,_PyClassMethod_New" - "LINKER:-U,_PyCode_Addr2Location" - "LINKER:-U,_PyFrame_GetLasti" - ) + if(APPLE) + # NanobindAdaptors.h uses PyClassMethod_New to build `pure_subclass`es but nanobind + # doesn't declare this API as undefined in its linker flags. So we need to declare it as such + # for downstream users that do not do something like `-undefined dynamic_lookup`. + # Same for the rest. + target_link_options(${libname} PUBLIC + "LINKER:-U,_PyClassMethod_New" + "LINKER:-U,_PyCode_Addr2Location" + "LINKER:-U,_PyFrame_GetLasti" + ) + endif() endif() target_compile_options(${libname} PRIVATE ${eh_rtti_enable}) @@ -845,11 +862,11 @@ function(add_mlir_python_extension libname extname) if(WIN32) # On Windows, pyconfig.h (and by extension python.h) hardcode the version of the # python library which will be used for linkage depending on the flavor of the build. - # nanobind has a workaround which depends on the definition of Py_DEBUG (if Py_DEBUG - # is not passed in as a compile definition, nanobind undefs _DEBUG when including + # pybind11 has a workaround which depends on the definition of Py_DEBUG (if Py_DEBUG + # is not passed in as a compile definition, pybind11 undefs _DEBUG when including # python.h, so that the release python library would be used). - # Since mlir uses nanobind, we can leverage their workaround by never directly - # pyconfig.h or python.h and instead relying on the nanobind headers to include the + # Since mlir uses pybind11, we can leverage their workaround by never directly + # pyconfig.h or python.h and instead relying on the pybind11 headers to include the # necessary python headers. This results in mlir always linking against the # release python library via the (undocumented) cmake property Python3_LIBRARY_RELEASE. target_link_libraries(${libname} PRIVATE ${Python3_LIBRARY_RELEASE}) diff --git a/mlir/cmake/modules/MLIRDetectPythonEnv.cmake b/mlir/cmake/modules/MLIRDetectPythonEnv.cmake index edbad2e..d18f8c0 100644 --- a/mlir/cmake/modules/MLIRDetectPythonEnv.cmake +++ b/mlir/cmake/modules/MLIRDetectPythonEnv.cmake @@ -46,20 +46,81 @@ macro(mlir_configure_python_dev_packages) message(STATUS "Found python include dirs: ${Python3_INCLUDE_DIRS}") message(STATUS "Found python libraries: ${Python3_LIBRARIES}") message(STATUS "Found numpy v${Python3_NumPy_VERSION}: ${Python3_NumPy_INCLUDE_DIRS}") - message(STATUS "Python extension suffix for modules: '${Python3_SOABI}'") - if(nanobind_DIR) - message(STATUS "Using explicit nanobind cmake directory: ${nanobind_DIR} (-Dnanobind_DIR to change)") - find_package(nanobind 2.9 CONFIG REQUIRED) - else() - include(FetchContent) - FetchContent_Declare( - nanobind - GIT_REPOSITORY https://github.com/wjakob/nanobind.git - GIT_TAG v2.9.0 - GIT_SHALLOW TRUE - ) - FetchContent_MakeAvailable(nanobind) - endif() - message(STATUS "Found nanobind: ${NB_DIR}") + mlir_detect_pybind11_install() + find_package(pybind11 2.10 CONFIG REQUIRED) + message(STATUS "Found pybind11 v${pybind11_VERSION}: ${pybind11_INCLUDE_DIR}") + message(STATUS "Python prefix = '${PYTHON_MODULE_PREFIX}', " + "suffix = '${PYTHON_MODULE_SUFFIX}', " + "extension = '${PYTHON_MODULE_EXTENSION}") + + mlir_detect_nanobind_install() + find_package(nanobind 2.9 CONFIG REQUIRED) + message(STATUS "Found nanobind v${nanobind_VERSION}: ${nanobind_INCLUDE_DIR}") + message(STATUS "Python prefix = '${PYTHON_MODULE_PREFIX}', " + "suffix = '${PYTHON_MODULE_SUFFIX}', " + "extension = '${PYTHON_MODULE_EXTENSION}") endif() endmacro() + +# Detects a pybind11 package installed in the current python environment +# and sets variables to allow it to be found. This allows pybind11 to be +# installed via pip, which typically yields a much more recent version than +# the OS install, which will be available otherwise. +function(mlir_detect_pybind11_install) + if(pybind11_DIR) + message(STATUS "Using explicit pybind11 cmake directory: ${pybind11_DIR} (-Dpybind11_DIR to change)") + else() + message(STATUS "Checking for pybind11 in python path...") + execute_process( + COMMAND "${Python3_EXECUTABLE}" + -c "import pybind11;print(pybind11.get_cmake_dir(), end='')" + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + RESULT_VARIABLE STATUS + OUTPUT_VARIABLE PACKAGE_DIR + ERROR_QUIET) + if(NOT STATUS EQUAL "0") + message(STATUS "not found (install via 'pip install pybind11' or set pybind11_DIR)") + return() + endif() + message(STATUS "found (${PACKAGE_DIR})") + set(pybind11_DIR "${PACKAGE_DIR}" PARENT_SCOPE) + endif() +endfunction() + + +# Detects a nanobind package installed in the current python environment +# and sets variables to allow it to be found. This allows nanobind to be +# installed via pip, which typically yields a much more recent version than +# the OS install, which will be available otherwise. +function(mlir_detect_nanobind_install) + if(nanobind_DIR) + message(STATUS "Using explicit nanobind cmake directory: ${nanobind_DIR} (-Dnanobind_DIR to change)") + else() + message(STATUS "Checking for nanobind in python path...") + execute_process( + COMMAND "${Python3_EXECUTABLE}" + -c "import nanobind;print(nanobind.cmake_dir(), end='')" + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + RESULT_VARIABLE STATUS + OUTPUT_VARIABLE PACKAGE_DIR + ERROR_QUIET) + if(NOT STATUS EQUAL "0") + message(STATUS "not found (install via 'pip install nanobind' or set nanobind_DIR)") + return() + endif() + message(STATUS "found (${PACKAGE_DIR})") + set(nanobind_DIR "${PACKAGE_DIR}" PARENT_SCOPE) + execute_process( + COMMAND "${Python3_EXECUTABLE}" + -c "import nanobind;print(nanobind.include_dir(), end='')" + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + RESULT_VARIABLE STATUS + OUTPUT_VARIABLE PACKAGE_DIR + ERROR_QUIET) + if(NOT STATUS EQUAL "0") + message(STATUS "not found (install via 'pip install nanobind' or set nanobind_DIR)") + return() + endif() + set(nanobind_INCLUDE_DIR "${PACKAGE_DIR}" PARENT_SCOPE) + endif() +endfunction() diff --git a/mlir/docs/Dialects/Linalg/OpDSL.md b/mlir/docs/Dialects/Linalg/OpDSL.md index 5d7e274..b892bbe 100644 --- a/mlir/docs/Dialects/Linalg/OpDSL.md +++ b/mlir/docs/Dialects/Linalg/OpDSL.md @@ -16,7 +16,7 @@ corresponding `linalg.generic` IR for the composition. ## Basic usage The tool is bundled with the MLIR Python bindings. To use from the CMake build -tree, MLIR must be built with Python bindings enabled +tree, MLIR must be build with Python bindings enabled (`-DMLIR_ENABLE_BINDINGS_PYTHON=ON`). Then add the `python` directory in the build tree to your `PYTHONPATH` environment variable (i.e. `export PYTHONPATH=$PWD/build/tools/mlir/python_packages/mlir_core`). Optionally, use an @@ -24,7 +24,7 @@ installed MLIR package, if available, to avoid building. ```shell # Dump the `core_named_ops.py` module as YAML. -python -m mlir.dialects.linalg.opdsl.dump_oplib.ops.core_named_ops +python -m mlir.dialects.linalg.opdsl.dump_oplib .ops.core_named_ops ``` Alternatively, run the `$PWD/build/bin/update_core_linalg_named_ops.sh` script, diff --git a/mlir/docs/Dialects/Transform.md b/mlir/docs/Dialects/Transform.md index 7164cb7..2133b81 100644 --- a/mlir/docs/Dialects/Transform.md +++ b/mlir/docs/Dialects/Transform.md @@ -415,10 +415,14 @@ ops rather than having the methods directly act on the payload IR. [include "Dialects/TransformOps.md"] -## Tuning Extension Operaiton +## Tune Extension Operations [include "Dialects/TuneExtensionOps.md"] +## SMT Extension Operations + +[include "Dialects/SMTExtensionOps.md"] + ## Affine Transform Operations [include "Dialects/AffineLoopTransformOps.md"] diff --git a/mlir/examples/standalone/pyproject.toml b/mlir/examples/standalone/pyproject.toml index 75e2153..5a1e6e8 100644 --- a/mlir/examples/standalone/pyproject.toml +++ b/mlir/examples/standalone/pyproject.toml @@ -23,7 +23,9 @@ Discussions = "https://discourse.llvm.org/" [build-system] requires = [ "scikit-build-core>=0.10.7", - "typing_extensions>=4.12.2" + "typing_extensions>=4.12.2", + "nanobind>=2.9, <3.0", + "pybind11>=2.10.0, <=2.13.6", ] build-backend = "scikit_build_core.build" diff --git a/mlir/examples/standalone/python/CMakeLists.txt b/mlir/examples/standalone/python/CMakeLists.txt index 108c343..905c9449 100644 --- a/mlir/examples/standalone/python/CMakeLists.txt +++ b/mlir/examples/standalone/python/CMakeLists.txt @@ -16,10 +16,27 @@ declare_mlir_dialect_python_bindings( ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir_standalone" TD_FILE dialects/StandaloneOps.td SOURCES + dialects/standalone_pybind11.py dialects/standalone_nanobind.py _mlir_libs/_standaloneDialectsNanobind/py.typed DIALECT_NAME standalone) + +declare_mlir_python_extension(StandalonePythonSources.Pybind11Extension + MODULE_NAME _standaloneDialectsPybind11 + ADD_TO_PARENT StandalonePythonSources + SOURCES + StandaloneExtensionPybind11.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIIR + MLIRCAPIArith + MLIRCAPITransforms + StandaloneCAPI + PYTHON_BINDINGS_LIBRARY pybind11 +) + declare_mlir_python_extension(StandalonePythonSources.NanobindExtension MODULE_NAME _standaloneDialectsNanobind ADD_TO_PARENT StandalonePythonSources @@ -32,6 +49,7 @@ declare_mlir_python_extension(StandalonePythonSources.NanobindExtension MLIRCAPIArith MLIRCAPITransforms StandaloneCAPI + PYTHON_BINDINGS_LIBRARY nanobind ) diff --git a/mlir/examples/standalone/python/StandaloneExtensionPybind11.cpp b/mlir/examples/standalone/python/StandaloneExtensionPybind11.cpp new file mode 100644 index 0000000..da8c216 --- /dev/null +++ b/mlir/examples/standalone/python/StandaloneExtensionPybind11.cpp @@ -0,0 +1,38 @@ +//===- StandaloneExtensionPybind11.cpp - Extension module -----------------===// +// +// This is the pybind11 version of the example module. There is also a nanobind +// example in StandaloneExtensionNanobind.cpp. +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Standalone-c/Dialects.h" +#include "mlir-c/Dialect/Arith.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" + +using namespace mlir::python::adaptors; + +PYBIND11_MODULE(_standaloneDialectsPybind11, m) { + //===--------------------------------------------------------------------===// + // standalone dialect + //===--------------------------------------------------------------------===// + auto standaloneM = m.def_submodule("standalone"); + + standaloneM.def( + "register_dialects", + [](MlirContext context, bool load) { + MlirDialectHandle arithHandle = mlirGetDialectHandle__arith__(); + MlirDialectHandle standaloneHandle = + mlirGetDialectHandle__standalone__(); + mlirDialectHandleRegisterDialect(arithHandle, context); + mlirDialectHandleRegisterDialect(standaloneHandle, context); + if (load) { + mlirDialectHandleLoadDialect(arithHandle, context); + mlirDialectHandleRegisterDialect(standaloneHandle, context); + } + }, + py::arg("context") = py::none(), py::arg("load") = true); +} diff --git a/mlir/examples/standalone/python/mlir_standalone/dialects/standalone_pybind11.py b/mlir/examples/standalone/python/mlir_standalone/dialects/standalone_pybind11.py new file mode 100644 index 0000000..bfb98e40 --- /dev/null +++ b/mlir/examples/standalone/python/mlir_standalone/dialects/standalone_pybind11.py @@ -0,0 +1,6 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._standalone_ops_gen import * +from .._mlir_libs._standaloneDialectsPybind11.standalone import * diff --git a/mlir/examples/standalone/test/python/smoketest.py b/mlir/examples/standalone/test/python/smoketest.py index f881984..26d84fd 100644 --- a/mlir/examples/standalone/test/python/smoketest.py +++ b/mlir/examples/standalone/test/python/smoketest.py @@ -1,7 +1,16 @@ +# RUN: %python %s pybind11 | FileCheck %s # RUN: %python %s nanobind | FileCheck %s +import sys from mlir_standalone.ir import * -from mlir_standalone.dialects import standalone_nanobind as standalone_d + +if sys.argv[1] == "pybind11": + from mlir_standalone.dialects import standalone_pybind11 as standalone_d +elif sys.argv[1] == "nanobind": + from mlir_standalone.dialects import standalone_nanobind as standalone_d +else: + raise ValueError("Expected either pybind11 or nanobind as arguments") + with Context(): standalone_d.register_dialects() diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h new file mode 100644 index 0000000..edc6977 --- /dev/null +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -0,0 +1,616 @@ +//===- PybindAdaptors.h - Interop with MLIR APIs via pybind11 -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// This file contains adaptors for clients of the core MLIR Python APIs to +// interop via MLIR CAPI types, using pybind11. The facilities here do not +// depend on implementation details of the MLIR Python API and do not introduce +// C++-level dependencies with it (requiring only Python and CAPI-level +// dependencies). +// +// It is encouraged to be used both in-tree and out-of-tree. For in-tree use +// cases, it should be used for dialect implementations (versus relying on +// Pybind-based internals of the core libraries). +//===----------------------------------------------------------------------===// + +#ifndef MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H +#define MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H + +#include <pybind11/functional.h> +#include <pybind11/pybind11.h> +#include <pybind11/pytypes.h> +#include <pybind11/stl.h> + +#include "mlir-c/Bindings/Python/Interop.h" +#include "mlir-c/Diagnostics.h" +#include "mlir-c/IR.h" + +#include "llvm/ADT/Twine.h" + +namespace py = pybind11; +using namespace py::literals; + +// Raw CAPI type casters need to be declared before use, so always include them +// first. +namespace pybind11 { +namespace detail { + +/// Helper to convert a presumed MLIR API object to a capsule, accepting either +/// an explicit Capsule (which can happen when two C APIs are communicating +/// directly via Python) or indirectly by querying the MLIR_PYTHON_CAPI_PTR_ATTR +/// attribute (through which supported MLIR Python API objects export their +/// contained API pointer as a capsule). Throws a type error if the object is +/// neither. This is intended to be used from type casters, which are invoked +/// with a raw handle (unowned). The returned object's lifetime may not extend +/// beyond the apiObject handle without explicitly having its refcount increased +/// (i.e. on return). +static py::object mlirApiObjectToCapsule(py::handle apiObject) { + if (PyCapsule_CheckExact(apiObject.ptr())) + return py::reinterpret_borrow<py::object>(apiObject); + if (!py::hasattr(apiObject, MLIR_PYTHON_CAPI_PTR_ATTR)) { + auto repr = py::repr(apiObject).cast<std::string>(); + throw py::type_error( + (llvm::Twine("Expected an MLIR object (got ") + repr + ").").str()); + } + return apiObject.attr(MLIR_PYTHON_CAPI_PTR_ATTR); +} + +// Note: Currently all of the following support cast from py::object to the +// Mlir* C-API type, but only a few light-weight, context-bound ones +// implicitly cast the other way because the use case has not yet emerged and +// ownership is unclear. + +/// Casts object <-> MlirAffineMap. +template <> +struct type_caster<MlirAffineMap> { + PYBIND11_TYPE_CASTER(MlirAffineMap, _("MlirAffineMap")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToAffineMap(capsule.ptr()); + if (mlirAffineMapIsNull(value)) { + return false; + } + return !mlirAffineMapIsNull(value); + } + static handle cast(MlirAffineMap v, return_value_policy, handle) { + py::object capsule = + py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(v)); + return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("AffineMap") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + } +}; + +/// Casts object <-> MlirAttribute. +template <> +struct type_caster<MlirAttribute> { + PYBIND11_TYPE_CASTER(MlirAttribute, _("MlirAttribute")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToAttribute(capsule.ptr()); + return !mlirAttributeIsNull(value); + } + static handle cast(MlirAttribute v, return_value_policy, handle) { + py::object capsule = + py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(v)); + return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Attribute") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() + .release(); + } +}; + +/// Casts object -> MlirBlock. +template <> +struct type_caster<MlirBlock> { + PYBIND11_TYPE_CASTER(MlirBlock, _("MlirBlock")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToBlock(capsule.ptr()); + return !mlirBlockIsNull(value); + } +}; + +/// Casts object -> MlirContext. +template <> +struct type_caster<MlirContext> { + PYBIND11_TYPE_CASTER(MlirContext, _("MlirContext")); + bool load(handle src, bool) { + if (src.is_none()) { + // Gets the current thread-bound context. + // TODO: This raises an error of "No current context" currently. + // Update the implementation to pretty-print the helpful error that the + // core implementations print in this case. + src = py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Context") + .attr("current"); + } + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToContext(capsule.ptr()); + return !mlirContextIsNull(value); + } +}; + +/// Casts object <-> MlirDialectRegistry. +template <> +struct type_caster<MlirDialectRegistry> { + PYBIND11_TYPE_CASTER(MlirDialectRegistry, _("MlirDialectRegistry")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToDialectRegistry(capsule.ptr()); + return !mlirDialectRegistryIsNull(value); + } + static handle cast(MlirDialectRegistry v, return_value_policy, handle) { + py::object capsule = py::reinterpret_steal<py::object>( + mlirPythonDialectRegistryToCapsule(v)); + return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("DialectRegistry") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + } +}; + +/// Casts object <-> MlirLocation. +template <> +struct type_caster<MlirLocation> { + PYBIND11_TYPE_CASTER(MlirLocation, _("MlirLocation")); + bool load(handle src, bool) { + if (src.is_none()) { + // Gets the current thread-bound context. + src = py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Location") + .attr("current"); + } + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToLocation(capsule.ptr()); + return !mlirLocationIsNull(value); + } + static handle cast(MlirLocation v, return_value_policy, handle) { + py::object capsule = + py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(v)); + return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Location") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + } +}; + +/// Casts object <-> MlirModule. +template <> +struct type_caster<MlirModule> { + PYBIND11_TYPE_CASTER(MlirModule, _("MlirModule")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToModule(capsule.ptr()); + return !mlirModuleIsNull(value); + } + static handle cast(MlirModule v, return_value_policy, handle) { + py::object capsule = + py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(v)); + return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Module") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + }; +}; + +/// Casts object <-> MlirFrozenRewritePatternSet. +template <> +struct type_caster<MlirFrozenRewritePatternSet> { + PYBIND11_TYPE_CASTER(MlirFrozenRewritePatternSet, + _("MlirFrozenRewritePatternSet")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr()); + return value.ptr != nullptr; + } + static handle cast(MlirFrozenRewritePatternSet v, return_value_policy, + handle) { + py::object capsule = py::reinterpret_steal<py::object>( + mlirPythonFrozenRewritePatternSetToCapsule(v)); + return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("rewrite")) + .attr("FrozenRewritePatternSet") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + }; +}; + +/// Casts object <-> MlirOperation. +template <> +struct type_caster<MlirOperation> { + PYBIND11_TYPE_CASTER(MlirOperation, _("MlirOperation")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToOperation(capsule.ptr()); + return !mlirOperationIsNull(value); + } + static handle cast(MlirOperation v, return_value_policy, handle) { + if (v.ptr == nullptr) + return py::none(); + py::object capsule = + py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(v)); + return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Operation") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + }; +}; + +/// Casts object <-> MlirValue. +template <> +struct type_caster<MlirValue> { + PYBIND11_TYPE_CASTER(MlirValue, _("MlirValue")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToValue(capsule.ptr()); + return !mlirValueIsNull(value); + } + static handle cast(MlirValue v, return_value_policy, handle) { + if (v.ptr == nullptr) + return py::none(); + py::object capsule = + py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(v)); + return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Value") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() + .release(); + }; +}; + +/// Casts object -> MlirPassManager. +template <> +struct type_caster<MlirPassManager> { + PYBIND11_TYPE_CASTER(MlirPassManager, _("MlirPassManager")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToPassManager(capsule.ptr()); + return !mlirPassManagerIsNull(value); + } +}; + +/// Casts object <-> MlirTypeID. +template <> +struct type_caster<MlirTypeID> { + PYBIND11_TYPE_CASTER(MlirTypeID, _("MlirTypeID")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToTypeID(capsule.ptr()); + return !mlirTypeIDIsNull(value); + } + static handle cast(MlirTypeID v, return_value_policy, handle) { + if (v.ptr == nullptr) + return py::none(); + py::object capsule = + py::reinterpret_steal<py::object>(mlirPythonTypeIDToCapsule(v)); + return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("TypeID") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + }; +}; + +/// Casts object <-> MlirType. +template <> +struct type_caster<MlirType> { + PYBIND11_TYPE_CASTER(MlirType, _("MlirType")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToType(capsule.ptr()); + return !mlirTypeIsNull(value); + } + static handle cast(MlirType t, return_value_policy, handle) { + py::object capsule = + py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(t)); + return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Type") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() + .release(); + } +}; + +} // namespace detail +} // namespace pybind11 + +namespace mlir { +namespace python { +namespace adaptors { + +/// Provides a facility like py::class_ for defining a new class in a scope, +/// but this allows extension of an arbitrary Python class, defining methods +/// on it is a similar way. Classes defined in this way are very similar to +/// if defined in Python in the usual way but use Pybind11 machinery to do +/// it. These are not "real" Pybind11 classes but pure Python classes with no +/// relation to a concrete C++ class. +/// +/// Derived from a discussion upstream: +/// https://github.com/pybind/pybind11/issues/1193 +/// (plus a fair amount of extra curricular poking) +/// TODO: If this proves useful, see about including it in pybind11. +class pure_subclass { +public: + pure_subclass(py::handle scope, const char *derivedClassName, + const py::object &superClass) { + py::object pyType = + py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type); + py::object metaclass = pyType(superClass); + py::dict attributes; + + thisClass = + metaclass(derivedClassName, py::make_tuple(superClass), attributes); + scope.attr(derivedClassName) = thisClass; + } + + template <typename Func, typename... Extra> + pure_subclass &def(const char *name, Func &&f, const Extra &...extra) { + py::cpp_function cf( + std::forward<Func>(f), py::name(name), py::is_method(thisClass), + py::sibling(py::getattr(thisClass, name, py::none())), extra...); + thisClass.attr(cf.name()) = cf; + return *this; + } + + template <typename Func, typename... Extra> + pure_subclass &def_property_readonly(const char *name, Func &&f, + const Extra &...extra) { + py::cpp_function cf( + std::forward<Func>(f), py::name(name), py::is_method(thisClass), + py::sibling(py::getattr(thisClass, name, py::none())), extra...); + auto builtinProperty = + py::reinterpret_borrow<py::object>((PyObject *)&PyProperty_Type); + thisClass.attr(name) = builtinProperty(cf); + return *this; + } + + template <typename Func, typename... Extra> + pure_subclass &def_staticmethod(const char *name, Func &&f, + const Extra &...extra) { + static_assert(!std::is_member_function_pointer<Func>::value, + "def_staticmethod(...) called with a non-static member " + "function pointer"); + py::cpp_function cf(std::forward<Func>(f), py::name(name), + py::scope(thisClass), extra...); + thisClass.attr(cf.name()) = py::staticmethod(cf); + return *this; + } + + template <typename Func, typename... Extra> + pure_subclass &def_classmethod(const char *name, Func &&f, + const Extra &...extra) { + static_assert(!std::is_member_function_pointer<Func>::value, + "def_classmethod(...) called with a non-static member " + "function pointer"); + py::cpp_function cf(std::forward<Func>(f), py::name(name), + py::scope(thisClass), extra...); + thisClass.attr(cf.name()) = + py::reinterpret_borrow<py::object>(PyClassMethod_New(cf.ptr())); + return *this; + } + + py::object get_class() const { return thisClass; } + +protected: + py::object superClass; + py::object thisClass; +}; + +/// Creates a custom subclass of mlir.ir.Attribute, implementing a casting +/// constructor and type checking methods. +class mlir_attribute_subclass : public pure_subclass { +public: + using IsAFunctionTy = bool (*)(MlirAttribute); + using GetTypeIDFunctionTy = MlirTypeID (*)(); + + /// Subclasses by looking up the super-class dynamically. + mlir_attribute_subclass(py::handle scope, const char *attrClassName, + IsAFunctionTy isaFunction, + GetTypeIDFunctionTy getTypeIDFunction = nullptr) + : mlir_attribute_subclass( + scope, attrClassName, isaFunction, + py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Attribute"), + getTypeIDFunction) {} + + /// Subclasses with a provided mlir.ir.Attribute super-class. This must + /// be used if the subclass is being defined in the same extension module + /// as the mlir.ir class (otherwise, it will trigger a recursive + /// initialization). + mlir_attribute_subclass(py::handle scope, const char *typeClassName, + IsAFunctionTy isaFunction, const py::object &superCls, + GetTypeIDFunctionTy getTypeIDFunction = nullptr) + : pure_subclass(scope, typeClassName, superCls) { + // Casting constructor. Note that it hard, if not impossible, to properly + // call chain to parent `__init__` in pybind11 due to its special handling + // for init functions that don't have a fully constructed self-reference, + // which makes it impossible to forward it to `__init__` of a superclass. + // Instead, provide a custom `__new__` and call that of a superclass, which + // eventually calls `__init__` of the superclass. Since attribute subclasses + // have no additional members, we can just return the instance thus created + // without amending it. + std::string captureTypeName( + typeClassName); // As string in case if typeClassName is not static. + py::cpp_function newCf( + [superCls, isaFunction, captureTypeName](py::object cls, + py::object otherAttribute) { + MlirAttribute rawAttribute = py::cast<MlirAttribute>(otherAttribute); + if (!isaFunction(rawAttribute)) { + auto origRepr = py::repr(otherAttribute).cast<std::string>(); + throw std::invalid_argument( + (llvm::Twine("Cannot cast attribute to ") + captureTypeName + + " (from " + origRepr + ")") + .str()); + } + py::object self = superCls.attr("__new__")(cls, otherAttribute); + return self; + }, + py::name("__new__"), py::arg("cls"), py::arg("cast_from_attr")); + thisClass.attr("__new__") = newCf; + + // 'isinstance' method. + def_staticmethod( + "isinstance", + [isaFunction](MlirAttribute other) { return isaFunction(other); }, + py::arg("other_attribute")); + def("__repr__", [superCls, captureTypeName](py::object self) { + return py::repr(superCls(self)) + .attr("replace")(superCls.attr("__name__"), captureTypeName); + }); + if (getTypeIDFunction) { + def_staticmethod("get_static_typeid", + [getTypeIDFunction]() { return getTypeIDFunction(); }); + py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)( + getTypeIDFunction())(pybind11::cpp_function( + [thisClass = thisClass](const py::object &mlirAttribute) { + return thisClass(mlirAttribute); + })); + } + } +}; + +/// Creates a custom subclass of mlir.ir.Type, implementing a casting +/// constructor and type checking methods. +class mlir_type_subclass : public pure_subclass { +public: + using IsAFunctionTy = bool (*)(MlirType); + using GetTypeIDFunctionTy = MlirTypeID (*)(); + + /// Subclasses by looking up the super-class dynamically. + mlir_type_subclass(py::handle scope, const char *typeClassName, + IsAFunctionTy isaFunction, + GetTypeIDFunctionTy getTypeIDFunction = nullptr) + : mlir_type_subclass( + scope, typeClassName, isaFunction, + py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Type"), + getTypeIDFunction) {} + + /// Subclasses with a provided mlir.ir.Type super-class. This must + /// be used if the subclass is being defined in the same extension module + /// as the mlir.ir class (otherwise, it will trigger a recursive + /// initialization). + mlir_type_subclass(py::handle scope, const char *typeClassName, + IsAFunctionTy isaFunction, const py::object &superCls, + GetTypeIDFunctionTy getTypeIDFunction = nullptr) + : pure_subclass(scope, typeClassName, superCls) { + // Casting constructor. Note that it hard, if not impossible, to properly + // call chain to parent `__init__` in pybind11 due to its special handling + // for init functions that don't have a fully constructed self-reference, + // which makes it impossible to forward it to `__init__` of a superclass. + // Instead, provide a custom `__new__` and call that of a superclass, which + // eventually calls `__init__` of the superclass. Since attribute subclasses + // have no additional members, we can just return the instance thus created + // without amending it. + std::string captureTypeName( + typeClassName); // As string in case if typeClassName is not static. + py::cpp_function newCf( + [superCls, isaFunction, captureTypeName](py::object cls, + py::object otherType) { + MlirType rawType = py::cast<MlirType>(otherType); + if (!isaFunction(rawType)) { + auto origRepr = py::repr(otherType).cast<std::string>(); + throw std::invalid_argument((llvm::Twine("Cannot cast type to ") + + captureTypeName + " (from " + + origRepr + ")") + .str()); + } + py::object self = superCls.attr("__new__")(cls, otherType); + return self; + }, + py::name("__new__"), py::arg("cls"), py::arg("cast_from_type")); + thisClass.attr("__new__") = newCf; + + // 'isinstance' method. + def_staticmethod( + "isinstance", + [isaFunction](MlirType other) { return isaFunction(other); }, + py::arg("other_type")); + def("__repr__", [superCls, captureTypeName](py::object self) { + return py::repr(superCls(self)) + .attr("replace")(superCls.attr("__name__"), captureTypeName); + }); + if (getTypeIDFunction) { + // 'get_static_typeid' method. + // This is modeled as a static method instead of a static property because + // `def_property_readonly_static` is not available in `pure_subclass` and + // we do not want to introduce the complexity that pybind uses to + // implement it. + def_staticmethod("get_static_typeid", + [getTypeIDFunction]() { return getTypeIDFunction(); }); + py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)( + getTypeIDFunction())(pybind11::cpp_function( + [thisClass = thisClass](const py::object &mlirType) { + return thisClass(mlirType); + })); + } + } +}; + +/// Creates a custom subclass of mlir.ir.Value, implementing a casting +/// constructor and type checking methods. +class mlir_value_subclass : public pure_subclass { +public: + using IsAFunctionTy = bool (*)(MlirValue); + + /// Subclasses by looking up the super-class dynamically. + mlir_value_subclass(py::handle scope, const char *valueClassName, + IsAFunctionTy isaFunction) + : mlir_value_subclass( + scope, valueClassName, isaFunction, + py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Value")) { + } + + /// Subclasses with a provided mlir.ir.Value super-class. This must + /// be used if the subclass is being defined in the same extension module + /// as the mlir.ir class (otherwise, it will trigger a recursive + /// initialization). + mlir_value_subclass(py::handle scope, const char *valueClassName, + IsAFunctionTy isaFunction, const py::object &superCls) + : pure_subclass(scope, valueClassName, superCls) { + // Casting constructor. Note that it hard, if not impossible, to properly + // call chain to parent `__init__` in pybind11 due to its special handling + // for init functions that don't have a fully constructed self-reference, + // which makes it impossible to forward it to `__init__` of a superclass. + // Instead, provide a custom `__new__` and call that of a superclass, which + // eventually calls `__init__` of the superclass. Since attribute subclasses + // have no additional members, we can just return the instance thus created + // without amending it. + std::string captureValueName( + valueClassName); // As string in case if valueClassName is not static. + py::cpp_function newCf( + [superCls, isaFunction, captureValueName](py::object cls, + py::object otherValue) { + MlirValue rawValue = py::cast<MlirValue>(otherValue); + if (!isaFunction(rawValue)) { + auto origRepr = py::repr(otherValue).cast<std::string>(); + throw std::invalid_argument((llvm::Twine("Cannot cast value to ") + + captureValueName + " (from " + + origRepr + ")") + .str()); + } + py::object self = superCls.attr("__new__")(cls, otherValue); + return self; + }, + py::name("__new__"), py::arg("cls"), py::arg("cast_from_value")); + thisClass.attr("__new__") = newCf; + + // 'isinstance' method. + def_staticmethod( + "isinstance", + [isaFunction](MlirValue other) { return isaFunction(other); }, + py::arg("other_value")); + } +}; + +} // namespace adaptors + +} // namespace python +} // namespace mlir + +#endif // MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h index f482385..ab9b9f2 100644 --- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h +++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h @@ -39,8 +39,7 @@ void addTosaToLinalgPasses( TosaToLinalgNamedOptions(), // Note: Default to 'none' level unless otherwise specified. std::optional<tosa::TosaValidationOptions> validationOptions = - tosa::TosaValidationOptions{ - {"none"}, {"none"}, false, false, tosa::TosaLevelEnum::None}); + tosa::TosaValidationOptions{false, false}); /// Populates TOSA to linalg pipelines /// Currently, this includes only the "tosa-to-linalg-pipeline". diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index 29001e2..db1b7e3 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -1029,6 +1029,24 @@ foreach smallT = [ attr-dict $src `,` $scale `:` type($res) }]; } + + + def ROCDL_CvtScaleF32SrPk8 # smallT.nameForOp # largeT.nameForOp # Op : + ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.sr.pk8." # smallT.name # "." # largeT.name, + [Pure], 1>, + Arguments<(ins largeT.type:$src, I32:$seed, F32:$scale)> { + let results = (outs smallT.type:$res); + let summary = "Scale and convert packed " + # largeT.name # " to packed " # smallT.name # " with stochastic rounding"; + let description = [{ + Convert 8 packed }] # largeT.name # [{ values to packed }] + # smallT.name # [{, multiplying by the exponent part of `scale` + before doing so and apply stochastic rounding. This op is for gfx1250+ arch. + }]; + let assemblyFormat = [{ + attr-dict $src `,` $seed `,` $scale `:` type($res) + }]; + } } // foreach largeT } // foreach smallTOp diff --git a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td index 4f7a842..2dd6121 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td @@ -190,8 +190,9 @@ def XeVM_StoreCacheControlAttr def XeVM_BlockLoadOp : XeVM_Op<"blockload">, - Results<( - outs FixedVectorOfRankAndType<[1], [XeVM_1DBlockElemType]>:$res)>, + Results<(outs AnyTypeOf< + [XeVM_1DBlockElemType, + FixedVectorOfRankAndType<[1], [XeVM_1DBlockElemType]>]>:$res)>, Arguments<(ins Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr, OptionalAttr<XeVM_LoadCacheControlAttr>:$cache_control)> { let summary = "subgroup block load"; @@ -228,7 +229,9 @@ def XeVM_BlockLoadOp def XeVM_BlockStoreOp : XeVM_Op<"blockstore">, Arguments<(ins Arg<LLVM_AnyPointer, "", [MemWrite]>:$ptr, - FixedVectorOfRankAndType<[1], [XeVM_1DBlockElemType]>:$val, + AnyTypeOf<[XeVM_1DBlockElemType, + FixedVectorOfRankAndType<[1], + [XeVM_1DBlockElemType]>]>:$val, OptionalAttr<XeVM_StoreCacheControlAttr>:$cache_control)> { let summary = "subgroup block store"; let description = [{ diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h index 9ee5079..10491f6 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h @@ -20,24 +20,67 @@ namespace mlir { namespace tosa { +struct TosaLevel { + int32_t MAX_RANK = 0; + int32_t MAX_KERNEL = 0; + int32_t MAX_STRIDE = 0; + int32_t MAX_SCALE = 0; + int32_t MAX_LOG2_SIZE = 0; + int32_t MAX_NESTING = 0; + int32_t MAX_TENSOR_LIST_SIZE = 0; + + bool operator==(const TosaLevel &rhs) { + return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL && + MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE && + MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE && + MAX_NESTING == rhs.MAX_NESTING && + MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE; + } +}; + +static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64}; +static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048, + 63, 256, 256}; + +TargetEnvAttr lookupTargetEnv(Operation *op); +TargetEnvAttr getDefaultTargetEnv(MLIRContext *context); + +/// Queries the target environment recursively from enclosing symbol table ops +/// containing the given `op` or returns the default target environment as +/// returned by getDefaultTargetEnv() if not provided. +TargetEnvAttr lookupTargetEnvOrDefault(Operation *op); + /// This class represents the capability enabled in the target implementation -/// such as profile, extension, and level. +/// such as profile, extension, and level. It's a wrapper class around +/// tosa::TargetEnvAttr. class TargetEnv { public: TargetEnv() {} - explicit TargetEnv(const SmallVectorImpl<Profile> &profiles, - const SmallVectorImpl<Extension> &extensions) { + explicit TargetEnv(Level level, const ArrayRef<Profile> &profiles, + const ArrayRef<Extension> &extensions) + : level(level) { enabledProfiles.insert_range(profiles); - enabledExtensions.insert_range(extensions); } + explicit TargetEnv(TargetEnvAttr targetAttr) + : TargetEnv(targetAttr.getLevel(), targetAttr.getProfiles(), + targetAttr.getExtensions()) {} + void addProfile(Profile p) { enabledProfiles.insert(p); } void addExtension(Extension e) { enabledExtensions.insert(e); } // TODO implement the following utilities. // Version getSpecVersion() const; - // TosaLevel getLevel() const; + + TosaLevel getLevel() const { + if (level == Level::eightK) + return TOSA_LEVEL_EIGHTK; + else if (level == Level::none) + return TOSA_LEVEL_NONE; + else + llvm_unreachable("Unknown TOSA level"); + }; // Returns true if the given profile is allowed. bool allows(Profile prof) const { return enabledProfiles.count(prof) != 0; } @@ -62,8 +105,9 @@ public: } private: + Level level; llvm::SmallSet<Profile, 3> enabledProfiles; - llvm::SmallSet<Extension, 8> enabledExtensions; + llvm::SmallSet<Extension, 13> enabledExtensions; }; } // namespace tosa diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td index 80337fc..38cb293 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -245,6 +245,19 @@ def Tosa_NONE : I32EnumAttrCase<"none", 0>; def Tosa_PRO_INT : I32EnumAttrCase<"pro_int", 1>; def Tosa_PRO_FP : I32EnumAttrCase<"pro_fp", 2>; +def Tosa_ProfileAttr + : Tosa_I32EnumAttr<"Profile", "supported TOSA profiles", "prof", + [Tosa_PRO_INT, Tosa_PRO_FP, Tosa_NONE]> { + let extraClassDeclaration = [{ + static llvm::SmallVector<Profile, 2> getAllValues() { + return {Profile::pro_int, Profile::pro_fp}; + } + }]; +} + +def Tosa_ProfileArrayAttr + : TypedArrayAttrBase<Tosa_ProfileAttr, "TOSA profile array attribute">; + def Tosa_EXT_NONE : I32EnumAttrCase<"none", 0>; def Tosa_EXT_INT16 : I32EnumAttrCase<"int16", 1>; def Tosa_EXT_INT4 : I32EnumAttrCase<"int4", 2>; @@ -264,17 +277,27 @@ def Tosa_ExtensionAttr Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_CONTROLFLOW, Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND, Tosa_EXT_DYNAMIC - ]>; + ]> { + let extraClassDeclaration = [{ + static llvm::SmallVector<Extension, 11> getAllValues() { + return { + Extension::int16, Extension::int4, Extension::bf16, + Extension::fp8e4m3, Extension::fp8e5m2, Extension::fft, + Extension::variable, Extension::controlflow, Extension::doubleround, + Extension::inexactround, Extension::dynamic + }; + } + }]; +} def Tosa_ExtensionArrayAttr : TypedArrayAttrBase<Tosa_ExtensionAttr, "TOSA extension array attribute">; -def Tosa_ProfileAttr - : Tosa_I32EnumAttr<"Profile", "supported TOSA profiles", "prof", - [Tosa_PRO_INT, Tosa_PRO_FP, Tosa_NONE]>; +def Tosa_LVL_NONE : I32EnumAttrCase<"none", 0>; +def Tosa_LVL_8K : I32EnumAttrCase<"eightK", 1, "8k">; -def Tosa_ProfileArrayAttr - : TypedArrayAttrBase<Tosa_ProfileAttr, "TOSA profile array attribute">; +def Tosa_LevelAttr + : Tosa_I32EnumAttr<"Level", "supported TOSA levels", "level", [Tosa_LVL_NONE, Tosa_LVL_8K]>; // The base class for defining op availability dimensions. class Availability { @@ -382,6 +405,21 @@ class Extension<list<I32EnumAttrCase> extensions> : Availability { } //===----------------------------------------------------------------------===// +// TOSA target environment. +//===----------------------------------------------------------------------===// +def Tosa_TargetEnv : Tosa_Attr<"TargetEnv", "target_env"> { + let summary = "Target environment information."; + let parameters = ( ins + "Level": $level, + ArrayRefParameter<"Profile">: $profiles, + ArrayRefParameter<"Extension">: $extensions + ); + + let assemblyFormat = "`<` `level` `=` $level `,` `profiles` `=` `[` $profiles `]` `,` " + "`extensions` `=` `[` $extensions `]` `>`"; +} + +//===----------------------------------------------------------------------===// // Iterable attributes. //===----------------------------------------------------------------------===// // Defined in `section 3. Enumerations` of the TOSA specification. diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt index 7484473..f52b82a 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt @@ -1,7 +1,5 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls -name TosaOpt) -mlir_tablegen(PassesEnums.h.inc -gen-enum-decls) -mlir_tablegen(PassesEnums.cpp.inc -gen-enum-defs) add_mlir_dialect_tablegen_target(MLIRTosaPassIncGen) add_mlir_doc(Passes TosaPasses ./ -gen-pass-doc) diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h index 306e4b1..ba99d2f 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h @@ -15,7 +15,6 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/Dialect/Tosa/Transforms/PassesEnums.h.inc" #include "mlir/Pass/Pass.h" namespace mlir { diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td index b966828..6ae19d8 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td @@ -65,14 +65,6 @@ def TosaOptionalDecompositionsPass }]; } -def TosaLevelType : I32EnumAttr<"TosaLevelEnum", "Tosa level", - [ - I32EnumAttrCase<"None", 0, "none">, - I32EnumAttrCase<"EightK", 1, "8k">, - ]>{ - let cppNamespace = "mlir::tosa"; -} - def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> { let summary = "Validates TOSA dialect"; let description = [{ @@ -81,10 +73,6 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> { }]; let options = [ - ListOption<"profile", "profile", "std::string", - "Validate if operations match for the given profile set">, - ListOption<"extension", "extension", "std::string", - "Validate if operations match for the given extension set">, Option<"strictOpSpecAlignment", "strict-op-spec-alignment", "bool", /*default=*/"false", "Verify if the properties of certain operations align the spec requirement">, @@ -92,17 +80,7 @@ def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> { /*default=*/"false", "Disable checks for operations that are determined to be invalid due to their " "operand/result datatypes not aligning with the 'Supported Data Types' " - "sections of the specifciation">, - Option<"level", "level", "mlir::tosa::TosaLevelEnum", - /*default=*/"mlir::tosa::TosaLevelEnum::EightK", - "Validate if operator parameters are within specfication for the given level", - [{::llvm::cl::values( - clEnumValN(mlir::tosa::TosaLevelEnum::EightK, "8k", - "Ranges are expected to be sufficient for applications with frame sizes up to 8K."), - clEnumValN(mlir::tosa::TosaLevelEnum::None, "none", - "Allows the full range of arguments specified by the operations according " - "to the operation data types.") - )}]> + "sections of the specifciation"> ]; } @@ -141,4 +119,44 @@ def TosaConvertIntegerTypeToSignless : Pass<"tosa-convert-integer-type-to-signle }]; } +def TosaAttachTarget : Pass<"tosa-attach-target", "ModuleOp"> { + let summary = "Attach tosa.target_env information to the given module."; + + let description = [{ + This pass allows the user to specify a TOSA target environment consisting of + the following components: level, profiles and extensions. + + The target environment is attached to the module as an attribute, allowing other + transformations to query the selected target and adapt their behaviour based on + this information. + }]; + + let dependentDialects = [ + "func::FuncDialect", + "tosa::TosaDialect", + ]; + + let options = [ + Option<"level", "level", "mlir::tosa::Level", + /*default=*/"mlir::tosa::Level::eightK", + "The TOSA level that operators should conform to. A TOSA level defines " + "operator argument ranges that an implementation shall support.", + [{::llvm::cl::values( + clEnumValN(mlir::tosa::Level::eightK, "8k", + "Ranges are expected to be sufficient for applications with frame " + "sizes up to 8K."), + clEnumValN(mlir::tosa::Level::none, "none", + "Allows the full range of arguments specified by the operations according " + "to the operation data types.") + )}]>, + ListOption<"profiles", "profiles", "std::string", + "The TOSA profile(s) that operators should conform to. TOSA profiles " + "enable efficient implementation on different classes of device. Each " + "profile is an independent set of operations and data type combinations.">, + ListOption<"extensions", "extensions", "std::string", + "The TOSA extension(s) that operators should conform to. TOSA profile " + "extensions define optional operation and data type combinations."> + ]; +} + #endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp index c6a3ba9..e7602b4 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp @@ -115,11 +115,8 @@ void mlir::tosa::registerTosaToLinalgPipelines() { TosaToLinalgOptions tosaToLinalgOptions; TosaToLinalgNamedOptions tosaToLinalgNamedOptions; TosaValidationOptions validationOptions; - validationOptions.profile = {"none"}; - validationOptions.extension = {"none"}; validationOptions.strictOpSpecAlignment = false; validationOptions.allowInvalidOpDatatypeCombinations = false; - validationOptions.level = tosa::TosaLevelEnum::EightK; tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions, tosaToLinalgNamedOptions, validationOptions); diff --git a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp index 8295492..04e8836 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp @@ -310,26 +310,30 @@ LogicalResult BlockPrefetch2dOp::verify() { template <typename OpType, typename = std::enable_if_t<llvm::is_one_of< OpType, BlockLoadOp, BlockStoreOp>::value>> LogicalResult verify1DBlockArg(OpType op) { - VectorType vTy; + Type srcOrDstTy; if constexpr (std::is_same_v<OpType, BlockLoadOp>) - vTy = op.getResult().getType(); + srcOrDstTy = op.getResult().getType(); else - vTy = op.getVal().getType(); + srcOrDstTy = op.getVal().getType(); + VectorType vTy = dyn_cast<VectorType>(srcOrDstTy); + // scalar case is always valid + if (!vTy) + return success(); int elemTySize = vTy.getElementType().getIntOrFloatBitWidth() / 8; if (elemTySize == 1) { - llvm::SmallSet<int, 5> validSizes{1, 2, 4, 8, 16}; + llvm::SmallSet<int, 4> validSizes{2, 4, 8, 16}; if (validSizes.contains(vTy.getNumElements())) return success(); else return op.emitOpError( - "vector size must be 1, 2, 4, 8 or 16 for 8-bit element type"); + "vector size must be 2, 4, 8 or 16 for 8-bit element type"); } else { - llvm::SmallSet<int, 4> validSizes{1, 2, 4, 8}; + llvm::SmallSet<int, 3> validSizes{2, 4, 8}; if (validSizes.contains(vTy.getNumElements())) return success(); else return op.emitOpError( - "vector size must be 1, 2, 4 or 8 for element type > 8 bits"); + "vector size must be 2, 4 or 8 for element type > 8 bits"); } } diff --git a/mlir/lib/Dialect/Tosa/CMakeLists.txt b/mlir/lib/Dialect/Tosa/CMakeLists.txt index c6a438d..a95906a 100644 --- a/mlir/lib/Dialect/Tosa/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(MLIRTosaDialect IR/TosaOps.cpp IR/TosaCanonicalizations.cpp + IR/TargetEnv.cpp Utils/ConversionUtils.cpp Utils/QuantUtils.cpp diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp new file mode 100644 index 0000000..5aad671 --- /dev/null +++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp @@ -0,0 +1,42 @@ +//===-------------- TosaTarget.cpp - TOSA Target utilities ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TargetEnv.h" + +namespace mlir { +namespace tosa { + +TargetEnvAttr lookupTargetEnv(Operation *op) { + while (op) { + op = SymbolTable::getNearestSymbolTable(op); + if (!op) + break; + + if (auto attr = op->getAttrOfType<TargetEnvAttr>(TargetEnvAttr::name)) + return attr; + + op = op->getParentOp(); + } + + return {}; +} + +TargetEnvAttr getDefaultTargetEnv(MLIRContext *context) { + return TargetEnvAttr::get(context, Level::eightK, + {Profile::pro_int, Profile::pro_fp}, {}); +} + +TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) { + if (auto attr = lookupTargetEnv(op)) + return attr; + + return getDefaultTargetEnv(op->getContext()); +} + +} // namespace tosa +} // namespace mlir diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt index 803993b..41b338d 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_dialect_library(MLIRTosaTransforms + TosaAttachTarget.cpp TosaConvertIntegerTypeToSignless.cpp TosaDecomposeTransposeConv.cpp TosaDecomposeDepthwise.cpp diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp new file mode 100644 index 0000000..bcb880a --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp @@ -0,0 +1,87 @@ +//===- TosaAttachTarget.cpp +//------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Attach target information to a TOSA module. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tosa/IR/TargetEnv.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace tosa { + +#define GEN_PASS_DEF_TOSAATTACHTARGET +#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" + +namespace { + +class TosaAttachTarget + : public tosa::impl::TosaAttachTargetBase<TosaAttachTarget> { + using Base::Base; + +public: + void runOnOperation() override { + llvm::SmallVector<Profile, 2> selectedProfiles; + if (!profiles.empty()) { + for (const std::string &prof : profiles) { + std::optional<Profile> profSymbol = symbolizeProfile(prof); + if (!profSymbol) { + llvm::SmallVector<Profile> allProfiles = ProfileAttr::getAllValues(); + llvm::errs() << buildUnkownParameterErrorMessage(allProfiles, + "profile", prof); + return signalPassFailure(); + } + selectedProfiles.push_back(profSymbol.value()); + } + } + + llvm::SmallVector<Extension, 10> selectedExtensions; + if (!extensions.empty()) { + for (const std::string &ext : extensions) { + std::optional<Extension> extSymbol = symbolizeExtension(ext); + if (!extSymbol) { + llvm::SmallVector<Extension> allExtensions = + ExtensionAttr::getAllValues(); + llvm::errs() << buildUnkownParameterErrorMessage(allExtensions, + "extension", ext); + return signalPassFailure(); + } + selectedExtensions.push_back(extSymbol.value()); + } + } + + ModuleOp mod = getOperation(); + MLIRContext *ctx = &getContext(); + const auto targetEnvAttr = + TargetEnvAttr::get(ctx, level, selectedProfiles, selectedExtensions); + mod->setAttr(TargetEnvAttr::name, targetEnvAttr); + } + +private: + template <typename T> + std::string buildUnkownParameterErrorMessage(llvm::SmallVector<T> &enumValues, + std::string enumName, + std::string unknownArgument) { + std::string message; + llvm::raw_string_ostream os(message); + os << "Unknown TOSA " << enumName << " name passed in '" << unknownArgument + << "', supported " << enumName << "s are: "; + llvm::interleaveComma(enumValues, os); + os << "\n"; + return message; + } +}; + +} // namespace + +} // namespace tosa +} // namespace mlir diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index 4fc7ce8..82f2f7e 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -14,7 +14,6 @@ #include "mlir/Dialect/Tosa/IR/TargetEnv.h" #include "mlir/Dialect/Tosa/IR/TosaProfileCompliance.h" #include "mlir/Dialect/Tosa/Transforms/Passes.h" -#include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc" #include <string> @@ -130,28 +129,6 @@ static LogicalResult checkConstantOperandNegate(Operation *op, return success(); } -struct TosaLevel { - int32_t MAX_RANK = 0; - int32_t MAX_KERNEL = 0; - int32_t MAX_STRIDE = 0; - int32_t MAX_SCALE = 0; - int32_t MAX_LOG2_SIZE = 0; - int32_t MAX_NESTING = 0; - int32_t MAX_TENSOR_LIST_SIZE = 0; - - bool operator==(const TosaLevel &rhs) { - return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL && - MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE && - MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE && - MAX_NESTING == rhs.MAX_NESTING && - MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE; - } -}; - -static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64}; -static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048, - 63, 256, 256}; - //===----------------------------------------------------------------------===// // TOSA Validation Pass. //===----------------------------------------------------------------------===// @@ -162,12 +139,9 @@ public: explicit TosaValidation(const TosaValidationOptions &options) : TosaValidation() { - this->profile = options.profile; - this->extension = options.extension; this->strictOpSpecAlignment = options.strictOpSpecAlignment; this->allowInvalidOpDatatypeCombinations = options.allowInvalidOpDatatypeCombinations; - this->level = options.level; } void runOnOperation() final; @@ -207,28 +181,28 @@ private: LogicalResult levelCheckKernel(Operation *op, int32_t v, const StringRef checkDesc) { - if (v > tosaLevel.MAX_KERNEL) + if (v > targetEnv.getLevel().MAX_KERNEL) return op->emitOpError() << "failed level check: " << checkDesc; return success(); } LogicalResult levelCheckStride(Operation *op, int32_t v, const StringRef checkDesc) { - if (v > tosaLevel.MAX_STRIDE) + if (v > targetEnv.getLevel().MAX_STRIDE) return op->emitOpError() << "failed level check: " << checkDesc; return success(); } LogicalResult levelCheckScale(Operation *op, int32_t v, const StringRef checkDesc) { - if (v > tosaLevel.MAX_SCALE) + if (v > targetEnv.getLevel().MAX_SCALE) return op->emitOpError() << "failed level check: " << checkDesc; return success(); } LogicalResult levelCheckListSize(Operation *op, int32_t v, const StringRef checkDesc) { - if (v > tosaLevel.MAX_TENSOR_LIST_SIZE) + if (v > targetEnv.getLevel().MAX_TENSOR_LIST_SIZE) return op->emitOpError() << "failed level check for MAX_TENSOR_LIST_SIZE: " << checkDesc; return success(); @@ -285,6 +259,7 @@ private: template <typename T> LogicalResult levelCheckRanks(T tosaOp) { auto op = tosaOp.getOperation(); + const TosaLevel tosaLevel = targetEnv.getLevel(); for (auto v : op->getOperands()) { if (failed(levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK))) return failure(); @@ -466,7 +441,7 @@ private: int32_t maxNestedDepth = 0; getMaxNestedDepth(op, maxNestedDepth); - if (maxNestedDepth >= tosaLevel.MAX_NESTING) { + if (maxNestedDepth >= targetEnv.getLevel().MAX_NESTING) { op->emitOpError() << "failed level check: " << maxNestedDepth << " >= MAX_NESTING"; return failure(); @@ -523,43 +498,6 @@ private: return success(); } - // configure profile and level values from pass options profileName and - // levelName - void configLevelAndProfile() { - tosaLevel = TOSA_LEVEL_NONE; - if (level == TosaLevelEnum::EightK) { - tosaLevel = TOSA_LEVEL_EIGHTK; - } - - if (!profile.empty()) { - for (std::string &prof : profile) { - auto profSymbol = symbolizeProfile(prof); - if (profSymbol) { - targetEnv.addProfile(profSymbol.value()); - } else { - llvm::errs() << "unknown TOSA profile name passed in: " << prof - << ", supported profiles are `pro_int` and `pro_fp`\n"; - return signalPassFailure(); - } - } - } - - if (!extension.empty()) { - for (std::string &ext : extension) { - auto extSymbol = symbolizeExtension(ext); - if (extSymbol) { - targetEnv.addExtension(extSymbol.value()); - } else { - llvm::errs() << "unknown TOSA extension name passed in: " << ext - << ", supported extension are int16, int4, bf16, " - << "fp8e4m3, fp8e5m2, fft, variable, controlflow, " - << "doubleround, inexactround and dynamic\n"; - return signalPassFailure(); - } - } - } - } - LogicalResult CheckVariable(Operation *op); LogicalResult CheckVariableReadOrWrite(Operation *op); bool isValidElementType(Type type, const bool allowUnsigned = false); @@ -567,7 +505,6 @@ private: SmallVector< std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>> constCheckers; - TosaLevel tosaLevel; DenseMap<StringAttr, mlir::Type> variablesMap; TosaProfileCompliance profileComp; tosa::TargetEnv targetEnv; @@ -576,13 +513,13 @@ private: template <> LogicalResult TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) { auto *op = tosaOp.getOperation(); - if (failed( - levelCheckRank(op, tosaOp.getInput(), "operand", tosaLevel.MAX_RANK))) + if (failed(levelCheckRank(op, tosaOp.getInput(), "operand", + targetEnv.getLevel().MAX_RANK))) return failure(); // rank(output) = rank(input) - 1 if (failed(levelCheckRank(op, tosaOp.getOutput(), "result", - tosaLevel.MAX_RANK - 1))) + targetEnv.getLevel().MAX_RANK - 1))) return failure(); return success(); @@ -594,7 +531,7 @@ LogicalResult TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) { // Only the condition input has rank limitation. if (failed(levelCheckRank(op, tosaOp.getCondition(), "operand", - tosaLevel.MAX_RANK))) + targetEnv.getLevel().MAX_RANK))) return failure(); return success(); @@ -605,7 +542,7 @@ LogicalResult TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) { auto *op = tosaOp.getOperation(); auto variableType = getVariableType(tosaOp); if (failed(levelCheckRank(op, variableType, "variable type", - tosaLevel.MAX_RANK))) + targetEnv.getLevel().MAX_RANK))) return failure(); return success(); @@ -762,7 +699,8 @@ LogicalResult TosaValidation::levelCheckSize(Operation *op, // defined in 1.7. Levels. // For each tensor, the number of tensor elements multiplied by the // element size in bytes must be representable as a tensor_size_t. - const int64_t max_size = (INT64_C(1) << tosaLevel.MAX_LOG2_SIZE) - 1; + const int64_t max_size = + (INT64_C(1) << targetEnv.getLevel().MAX_LOG2_SIZE) - 1; if (size > max_size) return op->emitOpError() << "failed level check: " << operandOrResult @@ -772,7 +710,7 @@ LogicalResult TosaValidation::levelCheckSize(Operation *op, } LogicalResult TosaValidation::applyLevelCheck(Operation *op) { - if (tosaLevel == TOSA_LEVEL_NONE) { + if (targetEnv.getLevel() == TOSA_LEVEL_NONE) { // no need to do level checks return success(); } @@ -1282,12 +1220,12 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) { } void TosaValidation::runOnOperation() { - configLevelAndProfile(); - TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>(); if (!tosaDialect) return; + targetEnv = tosa::TargetEnv(lookupTargetEnvOrDefault(getOperation())); + getOperation().walk([&](Operation *op) { if (op->getDialect() != tosaDialect) return; diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index cea5b25..9f5246d 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -440,11 +440,11 @@ declare_mlir_dialect_python_bindings( DIALECT_NAME smt) declare_mlir_dialect_python_bindings( - ADD_TO_PARENT MLIRPythonSources.Dialects - ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" - TD_FILE dialects/SPIRVOps.td - SOURCES dialects/spirv.py - DIALECT_NAME spirv) + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/SPIRVOps.td + SOURCES dialects/spirv.py + DIALECT_NAME spirv) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -501,6 +501,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core MODULE_NAME _mlir ADD_TO_PARENT MLIRPythonSources.Core ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES MainModule.cpp IRAffine.cpp @@ -539,6 +540,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core declare_mlir_python_extension(MLIRPythonExtension.RegisterEverything MODULE_NAME _mlirRegisterEverything ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES RegisterEverything.cpp PRIVATE_LINK_LIBS @@ -549,10 +551,11 @@ declare_mlir_python_extension(MLIRPythonExtension.RegisterEverything MLIRCAPIRegisterEverything ) -declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Nanobind +declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind MODULE_NAME _mlirDialectsLinalg ADD_TO_PARENT MLIRPythonSources.Dialects.linalg ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES DialectLinalg.cpp PRIVATE_LINK_LIBS @@ -562,10 +565,11 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Nanobind MLIRCAPILinalg ) -declare_mlir_python_extension(MLIRPythonExtension.Dialects.GPU.Nanobind +declare_mlir_python_extension(MLIRPythonExtension.Dialects.GPU.Pybind MODULE_NAME _mlirDialectsGPU ADD_TO_PARENT MLIRPythonSources.Dialects.gpu ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES DialectGPU.cpp PRIVATE_LINK_LIBS @@ -575,10 +579,11 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.GPU.Nanobind MLIRCAPIGPU ) -declare_mlir_python_extension(MLIRPythonExtension.Dialects.LLVM.Nanobind +declare_mlir_python_extension(MLIRPythonExtension.Dialects.LLVM.Pybind MODULE_NAME _mlirDialectsLLVM ADD_TO_PARENT MLIRPythonSources.Dialects.llvm ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES DialectLLVM.cpp PRIVATE_LINK_LIBS @@ -588,10 +593,11 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.LLVM.Nanobind MLIRCAPILLVM ) -declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Nanobind +declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Pybind MODULE_NAME _mlirDialectsQuant ADD_TO_PARENT MLIRPythonSources.Dialects.quant ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES DialectQuant.cpp PRIVATE_LINK_LIBS @@ -601,10 +607,11 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Nanobind MLIRCAPIQuant ) -declare_mlir_python_extension(MLIRPythonExtension.Dialects.NVGPU.Nanobind +declare_mlir_python_extension(MLIRPythonExtension.Dialects.NVGPU.Pybind MODULE_NAME _mlirDialectsNVGPU ADD_TO_PARENT MLIRPythonSources.Dialects.nvgpu ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES DialectNVGPU.cpp PRIVATE_LINK_LIBS @@ -614,10 +621,11 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.NVGPU.Nanobind MLIRCAPINVGPU ) -declare_mlir_python_extension(MLIRPythonExtension.Dialects.PDL.Nanobind +declare_mlir_python_extension(MLIRPythonExtension.Dialects.PDL.Pybind MODULE_NAME _mlirDialectsPDL ADD_TO_PARENT MLIRPythonSources.Dialects.pdl ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES DialectPDL.cpp PRIVATE_LINK_LIBS @@ -627,10 +635,11 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.PDL.Nanobind MLIRCAPIPDL ) -declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Nanobind +declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Pybind MODULE_NAME _mlirDialectsSparseTensor ADD_TO_PARENT MLIRPythonSources.Dialects.sparse_tensor ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES DialectSparseTensor.cpp PRIVATE_LINK_LIBS @@ -640,10 +649,11 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Nanobind MLIRCAPISparseTensor ) -declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Nanobind +declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind MODULE_NAME _mlirDialectsTransform ADD_TO_PARENT MLIRPythonSources.Dialects.transform ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES DialectTransform.cpp PRIVATE_LINK_LIBS @@ -653,10 +663,11 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Nanobind MLIRCAPITransformDialect ) -declare_mlir_python_extension(MLIRPythonExtension.Dialects.IRDL.Nanobind +declare_mlir_python_extension(MLIRPythonExtension.Dialects.IRDL.Pybind MODULE_NAME _mlirDialectsIRDL ADD_TO_PARENT MLIRPythonSources.Dialects.irdl ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES DialectIRDL.cpp PRIVATE_LINK_LIBS @@ -670,6 +681,7 @@ declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses MODULE_NAME _mlirAsyncPasses ADD_TO_PARENT MLIRPythonSources.Dialects.async ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES AsyncPasses.cpp PRIVATE_LINK_LIBS @@ -683,6 +695,7 @@ if(MLIR_ENABLE_EXECUTION_ENGINE) MODULE_NAME _mlirExecutionEngine ADD_TO_PARENT MLIRPythonSources.ExecutionEngine ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES ExecutionEngineModule.cpp PRIVATE_LINK_LIBS @@ -696,6 +709,7 @@ declare_mlir_python_extension(MLIRPythonExtension.GPUDialectPasses MODULE_NAME _mlirGPUPasses ADD_TO_PARENT MLIRPythonSources.Dialects.gpu ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES GPUPasses.cpp PRIVATE_LINK_LIBS @@ -708,6 +722,7 @@ declare_mlir_python_extension(MLIRPythonExtension.LinalgPasses MODULE_NAME _mlirLinalgPasses ADD_TO_PARENT MLIRPythonSources.Dialects.linalg ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES LinalgPasses.cpp PRIVATE_LINK_LIBS @@ -716,10 +731,11 @@ declare_mlir_python_extension(MLIRPythonExtension.LinalgPasses MLIRCAPILinalg ) -declare_mlir_python_extension(MLIRPythonExtension.Dialects.SMT.Nanobind +declare_mlir_python_extension(MLIRPythonExtension.Dialects.SMT.Pybind MODULE_NAME _mlirDialectsSMT ADD_TO_PARENT MLIRPythonSources.Dialects.smt ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES DialectSMT.cpp # Headers must be included explicitly so they are installed. @@ -736,6 +752,7 @@ declare_mlir_python_extension(MLIRPythonExtension.SparseTensorDialectPasses MODULE_NAME _mlirSparseTensorPasses ADD_TO_PARENT MLIRPythonSources.Dialects.sparse_tensor ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES SparseTensorPasses.cpp PRIVATE_LINK_LIBS @@ -748,6 +765,7 @@ declare_mlir_python_extension(MLIRPythonExtension.TransformInterpreter MODULE_NAME _mlirTransformInterpreter ADD_TO_PARENT MLIRPythonSources.Dialects.transform ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES TransformInterpreter.cpp PRIVATE_LINK_LIBS @@ -789,10 +807,23 @@ if(MLIR_INCLUDE_TESTS) ADD_TO_PARENT MLIRPythonTestSources.Dialects.PythonTest SOURCES "dialects/_python_test_ops_gen.py") + declare_mlir_python_extension(MLIRPythonTestSources.PythonTestExtensionPybind11 + MODULE_NAME _mlirPythonTestPybind11 + ADD_TO_PARENT MLIRPythonTestSources.Dialects + ROOT_DIR "${MLIR_SOURCE_DIR}/test/python/lib" + PYTHON_BINDINGS_LIBRARY pybind11 + SOURCES + PythonTestModulePybind11.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIPythonTestDialect + ) declare_mlir_python_extension(MLIRPythonTestSources.PythonTestExtensionNanobind MODULE_NAME _mlirPythonTestNanobind ADD_TO_PARENT MLIRPythonTestSources.Dialects ROOT_DIR "${MLIR_SOURCE_DIR}/test/python/lib" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES PythonTestModuleNanobind.cpp PRIVATE_LINK_LIBS diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py index 56d3c0f..9380896 100644 --- a/mlir/python/mlir/dialects/python_test.py +++ b/mlir/python/mlir/dialects/python_test.py @@ -5,7 +5,12 @@ from ._python_test_ops_gen import * -def register_python_test_dialect(registry): - from .._mlir_libs import _mlirPythonTestNanobind +def register_python_test_dialect(registry, use_nanobind): + if use_nanobind: + from .._mlir_libs import _mlirPythonTestNanobind - _mlirPythonTestNanobind.register_dialect(registry) + _mlirPythonTestNanobind.register_dialect(registry) + else: + from .._mlir_libs import _mlirPythonTestPybind11 + + _mlirPythonTestPybind11.register_dialect(registry) diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index 7ddc70a..11477d0 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -12,7 +12,7 @@ from ._mlir_libs._mlir.ir import _GlobalDebug from ._mlir_libs._mlir import ( register_type_caster, register_value_caster, - globals, + globals as _globals, ) from ._mlir_libs import ( get_dialect_registry, @@ -32,17 +32,17 @@ def loc_tracebacks(*, max_depth: int | None = None) -> Iterable[None]: max_depth: Maximum number of frames to include in the location. If None, the default limit is used. """ - old_enabled = globals.loc_tracebacks_enabled() - old_limit = globals.loc_tracebacks_frame_limit() + old_enabled = _globals.loc_tracebacks_enabled() + old_limit = _globals.loc_tracebacks_frame_limit() try: - globals.set_loc_tracebacks_frame_limit(max_depth) + _globals.set_loc_tracebacks_frame_limit(max_depth) if not old_enabled: - globals.set_loc_tracebacks_enabled(True) + _globals.set_loc_tracebacks_enabled(True) yield finally: if not old_enabled: - globals.set_loc_tracebacks_enabled(False) - globals.set_loc_tracebacks_frame_limit(old_limit) + _globals.set_loc_tracebacks_enabled(False) + _globals.set_loc_tracebacks_frame_limit(old_limit) # Convenience decorator for registering user-friendly Attribute builders. diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt index 5ff9500..abe0925 100644 --- a/mlir/python/requirements.txt +++ b/mlir/python/requirements.txt @@ -1,4 +1,6 @@ +nanobind>=2.9, <3.0 numpy>=1.19.5, <=2.1.2 +pybind11>=2.10.0, <=2.13.6 PyYAML>=5.4.0, <=6.0.1 ml_dtypes>=0.1.0, <=0.6.0; python_version<"3.13" # provides several NumPy dtype extensions, including the bf16 ml_dtypes>=0.5.0, <=0.6.0; python_version>="3.13" diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index b7ca71a..aaf9f80 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1973,14 +1973,14 @@ llvm.func @invalid_xevm_prefetch(%arg0: !llvm.ptr) { // ----- llvm.func @invalid_xevm_blockload(%arg0: !llvm.ptr<1>) { - // expected-error@+1 {{op vector size must be 1, 2, 4 or 8 for element type > 8 bits}} + // expected-error@+1 {{op vector size must be 2, 4 or 8 for element type > 8 bits}} %0 = xevm.blockload %arg0 : (!llvm.ptr<1>) -> vector<3xi16> llvm.return } // ----- llvm.func @invalid_xevm_blockstore(%arg0: !llvm.ptr<1>, %arg1: vector<5xi8>) { - // expected-error@+1 {{op vector size must be 1, 2, 4, 8 or 16 for 8-bit element type}} + // expected-error@+1 {{op vector size must be 2, 4, 8 or 16 for 8-bit element type}} xevm.blockstore %arg0, %arg1 : (!llvm.ptr<1>, vector<5xi8>) llvm.return } diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir index 6134695..a88b59a 100644 --- a/mlir/test/Dialect/LLVMIR/rocdl.mlir +++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir @@ -1100,6 +1100,39 @@ llvm.func @rocdl.cvt.scalef32.pk8(%v8xf32: vector<8xf32>, // ----- +// CHECK-LABEL: rocdl.cvt.scalef32.sr.pk8 +llvm.func @rocdl.cvt.scalef32.sr.pk8(%v8xf32: vector<8xf32>, + %v8xf16: vector<8xf16>, + %v8xbf16: vector<8xbf16>, + %seed: i32, + %scale: f32) { + + // CHECK: rocdl.cvt.scalef32.sr.pk8.fp8.f32 + %0 = rocdl.cvt.scalef32.sr.pk8.fp8.f32 %v8xf32, %seed, %scale : vector<2xi32> + // CHECK: rocdl.cvt.scalef32.sr.pk8.bf8.f32 + %1 = rocdl.cvt.scalef32.sr.pk8.bf8.f32 %v8xf32, %seed, %scale : vector<2xi32> + // CHECK: rocdl.cvt.scalef32.sr.pk8.fp4.f32 + %2 = rocdl.cvt.scalef32.sr.pk8.fp4.f32 %v8xf32, %seed, %scale : i32 + + // CHECK: rocdl.cvt.scalef32.sr.pk8.fp8.f16 + %3 = rocdl.cvt.scalef32.sr.pk8.fp8.f16 %v8xf16, %seed, %scale : vector<2xi32> + // CHECK: rocdl.cvt.scalef32.sr.pk8.bf8.f16 + %4 = rocdl.cvt.scalef32.sr.pk8.bf8.f16 %v8xf16, %seed, %scale : vector<2xi32> + // CHECK: rocdl.cvt.scalef32.sr.pk8.fp4.f16 + %5 = rocdl.cvt.scalef32.sr.pk8.fp4.f16 %v8xf16, %seed, %scale : i32 + + // CHECK: rocdl.cvt.scalef32.sr.pk8.fp8.bf16 + %6 = rocdl.cvt.scalef32.sr.pk8.fp8.bf16 %v8xbf16, %seed, %scale : vector<2xi32> + // CHECK: rocdl.cvt.scalef32.sr.pk8.bf8.bf16 + %7 = rocdl.cvt.scalef32.sr.pk8.bf8.bf16 %v8xbf16, %seed, %scale : vector<2xi32> + // CHECK: rocdl.cvt.scalef32.sr.pk8.fp4.bf16 + %8 = rocdl.cvt.scalef32.sr.pk8.fp4.bf16 %v8xbf16, %seed, %scale : i32 + + llvm.return +} + +// ----- + // CHECK-LABEL: rocdl.cvt.scale.pk16 llvm.func @rocdl.cvt.scale.pk16(%v3xi32: vector<3xi32>, %scale:i32) { diff --git a/mlir/test/Dialect/Tosa/dynamic_extension.mlir b/mlir/test/Dialect/Tosa/dynamic_extension.mlir index aaf8371..60b70b8 100644 --- a/mlir/test/Dialect/Tosa/dynamic_extension.mlir +++ b/mlir/test/Dialect/Tosa/dynamic_extension.mlir @@ -2,7 +2,7 @@ // Check operations when the dynamic extension is enabled. //-------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp extension=dynamic allow-invalid-op-datatype-combinations" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=dynamic" -tosa-validate="strict-op-spec-alignment allow-invalid-op-datatype-combinations" // ----- diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir index 2f9421c..334f52a 100644 --- a/mlir/test/Dialect/Tosa/error_if_check.mlir +++ b/mlir/test/Dialect/Tosa/error_if_check.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="level=none profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="level=none profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic" -tosa-validate="strict-op-spec-alignment" // ----- diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index e60f1c9b..2a3985c 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -4,7 +4,7 @@ // validation flow. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" func.func @test_cast(%arg0: tensor<i1>) -> tensor<5xi32> { diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir index 1daabe9..e5c9402 100644 --- a/mlir/test/Dialect/Tosa/invalid_extension.mlir +++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir @@ -2,7 +2,7 @@ // Enable all supported profiles to focus the verification of expected extension requirement errors. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp" -tosa-validate="strict-op-spec-alignment" // ----- func.func @test_argmax(%arg0: tensor<14x19xbf16>) -> tensor<14xi32> { diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index 5bf2dbb8..8cc357e 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="extension=dynamic" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="extensions=dynamic" -tosa-validate func.func @test_argmax_rank_invalid(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> { // expected-error@+1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}} diff --git a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir index 225b962..09e96ec 100644 --- a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir +++ b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir @@ -2,7 +2,7 @@ // Enable all supported extensions to focus the verification of expected profile requirement errors. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" // ----- func.func @test_add_i32(%arg0: tensor<13x21x1xi32>, %arg1: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> { diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir index 58a73d6..7ff8065 100644 --- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir +++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir @@ -2,7 +2,7 @@ // Enable all supported extensions to focus the verification of expected profile requirement errors. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" // ----- func.func @test_const_f16() -> tensor<3x11x11x3xf16> { diff --git a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir index a5784b3..48e79e4 100644 --- a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir +++ b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir @@ -2,7 +2,7 @@ // Enable all supported extensions to focus the verification of expected profile requirement errors. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" // ----- func.func @test_const_i1() -> tensor<3x11x11x3xi1> { diff --git a/mlir/test/Dialect/Tosa/tosa-attach-target.mlir b/mlir/test/Dialect/Tosa/tosa-attach-target.mlir new file mode 100644 index 0000000..d6c886c --- /dev/null +++ b/mlir/test/Dialect/Tosa/tosa-attach-target.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-opt %s -split-input-file -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,dynamic level=none" | FileCheck %s --check-prefix=CHECK-ALL +// RUN: mlir-opt %s -split-input-file -tosa-attach-target="level=8k" | FileCheck %s --check-prefix=CHECK-LVL-8K +// RUN: mlir-opt %s -split-input-file -tosa-attach-target | FileCheck %s --check-prefix=CHECK-DEFAULT + +// ----- + +// CHECK-ALL: module attributes {tosa.target_env = #tosa.target_env<level = none, profiles = [pro_int, pro_fp], extensions = [int16, int4, bf16, fp8e4m3, fp8e5m2, fft, variable, controlflow, doubleround, inexactround, dynamic]>} +// CHECK-LVL-8K: module attributes {tosa.target_env = #tosa.target_env<level = "8k", profiles = [], extensions = []>} +// CHECK-DEFAULT: module attributes {tosa.target_env = #tosa.target_env<level = "8k", profiles = [], extensions = []>} +// CHECK-LABEL: test_simple +func.func @test_simple(%arg0 : tensor<1x1x1x1xf32>, %arg1 : tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> { + %1 = tosa.add %arg0, %arg1 : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> + return %1 : tensor<1x1x1x1xf32> +} diff --git a/mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir b/mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir index f05ae7f..8e0ad0a 100644 --- a/mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir +++ b/mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround strict-op-spec-alignment" | FileCheck %s +// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" --tosa-validate="strict-op-spec-alignment" | FileCheck %s // ----- diff --git a/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir index 88ec027..663159e 100644 --- a/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir +++ b/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir @@ -4,7 +4,7 @@ // validation flow. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" | FileCheck %s +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate | FileCheck %s // ----- diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir index 00ee6b7..1c0c2eb 100644 --- a/mlir/test/Target/LLVMIR/rocdl.mlir +++ b/mlir/test/Target/LLVMIR/rocdl.mlir @@ -1368,6 +1368,39 @@ llvm.func @rocdl.cvt.scalef32.pk8(%v8xf32: vector<8xf32>, %v8xf16: vector<8xf16> llvm.return } +// CHECK-LABEL: rocdl.cvt.scalef32.sr.pk8 +// CHECK-SAME:(<8 x float> %[[V8F32:.+]], <8 x half> %[[V8F16:.+]], <8 x bfloat> %[[V8BF16:.+]], i32 %[[SEED:.+]], float %[[SCALE:.+]]) +llvm.func @rocdl.cvt.scalef32.sr.pk8(%v8xf32: vector<8xf32>, + %v8xf16: vector<8xf16>, + %v8xbf16: vector<8xbf16>, + %seed: i32, + %scale: f32) { + + // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.fp8.f32(<8 x float> %[[V8F32]], i32 %[[SEED]], float %[[SCALE]]) + %0 = rocdl.cvt.scalef32.sr.pk8.fp8.f32 %v8xf32, %seed, %scale : vector<2xi32> + // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.bf8.f32(<8 x float> %[[V8F32]], i32 %[[SEED]], float %[[SCALE]]) + %1 = rocdl.cvt.scalef32.sr.pk8.bf8.f32 %v8xf32, %seed, %scale : vector<2xi32> + // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk8.fp4.f32(<8 x float> %[[V8F32]], i32 %[[SEED]], float %[[SCALE]]) + %2 = rocdl.cvt.scalef32.sr.pk8.fp4.f32 %v8xf32, %seed, %scale : i32 + + // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.fp8.f16(<8 x half> %[[V8F16]], i32 %[[SEED]], float %[[SCALE]]) + %3 = rocdl.cvt.scalef32.sr.pk8.fp8.f16 %v8xf16, %seed, %scale : vector<2xi32> + // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.bf8.f16(<8 x half> %[[V8F16]], i32 %[[SEED]], float %[[SCALE]]) + %4 = rocdl.cvt.scalef32.sr.pk8.bf8.f16 %v8xf16, %seed, %scale : vector<2xi32> + // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk8.fp4.f16(<8 x half> %[[V8F16]], i32 %[[SEED]], float %[[SCALE]]) + %5 = rocdl.cvt.scalef32.sr.pk8.fp4.f16 %v8xf16, %seed, %scale : i32 + + // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.fp8.bf16(<8 x bfloat> %[[V8BF16]], i32 %[[SEED]], float %[[SCALE]]) + %6 = rocdl.cvt.scalef32.sr.pk8.fp8.bf16 %v8xbf16, %seed, %scale : vector<2xi32> + // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.bf8.bf16(<8 x bfloat> %[[V8BF16]], i32 %[[SEED]], float %[[SCALE]]) + %7 = rocdl.cvt.scalef32.sr.pk8.bf8.bf16 %v8xbf16, %seed, %scale : vector<2xi32> + // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk8.fp4.bf16(<8 x bfloat> %[[V8BF16]], i32 %[[SEED]], float %[[SCALE]]) + %8 = rocdl.cvt.scalef32.sr.pk8.fp4.bf16 %v8xbf16, %seed, %scale : i32 + + llvm.return +} + + // CHECK-LABEL: @rocdl.cvt.scale.pk16 // CHECK-SAME:(<3 x i32> %[[SRC0:.+]], i32 %[[SCALE:.+]]) llvm.func @rocdl.cvt.scale.pk16(%v3xi32: vector<3xi32>, %scale:i32) { diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py index 5a9acc7..1194e32 100644 --- a/mlir/test/python/dialects/python_test.py +++ b/mlir/test/python/dialects/python_test.py @@ -1,4 +1,5 @@ -# RUN: %PYTHON %s | FileCheck %s +# RUN: %PYTHON %s pybind11 | FileCheck %s +# RUN: %PYTHON %s nanobind | FileCheck %s import sys import typing from typing import Union, Optional @@ -9,14 +10,26 @@ import mlir.dialects.python_test as test import mlir.dialects.tensor as tensor import mlir.dialects.arith as arith -from mlir._mlir_libs._mlirPythonTestNanobind import ( - TestAttr, - TestType, - TestTensorValue, - TestIntegerRankedTensorType, -) - -test.register_python_test_dialect(get_dialect_registry()) +if sys.argv[1] == "pybind11": + from mlir._mlir_libs._mlirPythonTestPybind11 import ( + TestAttr, + TestType, + TestTensorValue, + TestIntegerRankedTensorType, + ) + + test.register_python_test_dialect(get_dialect_registry(), use_nanobind=False) +elif sys.argv[1] == "nanobind": + from mlir._mlir_libs._mlirPythonTestNanobind import ( + TestAttr, + TestType, + TestTensorValue, + TestIntegerRankedTensorType, + ) + + test.register_python_test_dialect(get_dialect_registry(), use_nanobind=True) +else: + raise ValueError("Expected pybind11 or nanobind as argument") def run(f): diff --git a/mlir/test/python/lib/CMakeLists.txt b/mlir/test/python/lib/CMakeLists.txt index f51a7b4..9a813da 100644 --- a/mlir/test/python/lib/CMakeLists.txt +++ b/mlir/test/python/lib/CMakeLists.txt @@ -1,6 +1,7 @@ set(LLVM_OPTIONAL_SOURCES PythonTestCAPI.cpp PythonTestDialect.cpp + PythonTestModulePybind11.cpp PythonTestModuleNanobind.cpp ) diff --git a/mlir/test/python/lib/PythonTestModulePybind11.cpp b/mlir/test/python/lib/PythonTestModulePybind11.cpp new file mode 100644 index 0000000..94a5f51 --- /dev/null +++ b/mlir/test/python/lib/PythonTestModulePybind11.cpp @@ -0,0 +1,118 @@ +//===- PythonTestModule.cpp - Python extension for the PythonTest dialect -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// This is the pybind11 edition of the PythonTest dialect module. +//===----------------------------------------------------------------------===// + +#include "PythonTestCAPI.h" +#include "mlir-c/BuiltinAttributes.h" +#include "mlir-c/BuiltinTypes.h" +#include "mlir-c/IR.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" + +namespace py = pybind11; +using namespace mlir::python::adaptors; +using namespace pybind11::literals; + +static bool mlirTypeIsARankedIntegerTensor(MlirType t) { + return mlirTypeIsARankedTensor(t) && + mlirTypeIsAInteger(mlirShapedTypeGetElementType(t)); +} + +PYBIND11_MODULE(_mlirPythonTestPybind11, m) { + m.def( + "register_python_test_dialect", + [](MlirContext context, bool load) { + MlirDialectHandle pythonTestDialect = + mlirGetDialectHandle__python_test__(); + mlirDialectHandleRegisterDialect(pythonTestDialect, context); + if (load) { + mlirDialectHandleLoadDialect(pythonTestDialect, context); + } + }, + py::arg("context"), py::arg("load") = true); + + m.def( + "register_dialect", + [](MlirDialectRegistry registry) { + MlirDialectHandle pythonTestDialect = + mlirGetDialectHandle__python_test__(); + mlirDialectHandleInsertDialect(pythonTestDialect, registry); + }, + py::arg("registry")); + + mlir_attribute_subclass(m, "TestAttr", + mlirAttributeIsAPythonTestTestAttribute, + mlirPythonTestTestAttributeGetTypeID) + .def_classmethod( + "get", + [](const py::object &cls, MlirContext ctx) { + return cls(mlirPythonTestTestAttributeGet(ctx)); + }, + py::arg("cls"), py::arg("context") = py::none()); + + mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType, + mlirPythonTestTestTypeGetTypeID) + .def_classmethod( + "get", + [](const py::object &cls, MlirContext ctx) { + return cls(mlirPythonTestTestTypeGet(ctx)); + }, + py::arg("cls"), py::arg("context") = py::none()); + + auto typeCls = + mlir_type_subclass(m, "TestIntegerRankedTensorType", + mlirTypeIsARankedIntegerTensor, + py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("RankedTensorType")) + .def_classmethod( + "get", + [](const py::object &cls, std::vector<int64_t> shape, + unsigned width, MlirContext ctx) { + MlirAttribute encoding = mlirAttributeGetNull(); + return cls(mlirRankedTensorTypeGet( + shape.size(), shape.data(), mlirIntegerTypeGet(ctx, width), + encoding)); + }, + "cls"_a, "shape"_a, "width"_a, "context"_a = py::none()); + + assert(py::hasattr(typeCls.get_class(), "static_typeid") && + "TestIntegerRankedTensorType has no static_typeid"); + + MlirTypeID mlirRankedTensorTypeID = mlirRankedTensorTypeGetTypeID(); + + py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(mlirRankedTensorTypeID, + "replace"_a = true)( + pybind11::cpp_function([typeCls](const py::object &mlirType) { + return typeCls.get_class()(mlirType); + })); + + auto valueCls = mlir_value_subclass(m, "TestTensorValue", + mlirTypeIsAPythonTestTestTensorValue) + .def("is_null", [](MlirValue &self) { + return mlirValueIsNull(self); + }); + + py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr(MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR)( + mlirRankedTensorTypeID)( + pybind11::cpp_function([valueCls](const py::object &valueObj) { + py::object capsule = mlirApiObjectToCapsule(valueObj); + MlirValue v = mlirPythonCapsuleToValue(capsule.ptr()); + MlirType t = mlirValueGetType(v); + // This is hyper-specific in order to exercise/test registering a + // value caster from cpp (but only for a single test case; see + // testTensorValue python_test.py). + if (mlirShapedTypeHasStaticShape(t) && + mlirShapedTypeGetDimSize(t, 0) == 1 && + mlirShapedTypeGetDimSize(t, 1) == 2 && + mlirShapedTypeGetDimSize(t, 2) == 3) + return valueCls.get_class()(valueObj); + return valueObj; + })); +} diff --git a/mlir/tools/mlir-linalg-ods-gen/update_core_linalg_named_ops.sh.in b/mlir/tools/mlir-linalg-ods-gen/update_core_linalg_named_ops.sh.in index 0bb6a20..da4db39 100755 --- a/mlir/tools/mlir-linalg-ods-gen/update_core_linalg_named_ops.sh.in +++ b/mlir/tools/mlir-linalg-ods-gen/update_core_linalg_named_ops.sh.in @@ -26,7 +26,7 @@ export PYTHONPATH="$python_package_dir" OUTPUT="$( echo "### AUTOGENERATED from core_named_ops.py" && \ echo "### To regenerate, run: bin/update_core_linalg_named_ops.sh" && \ - "$python_exe" -m mlir.dialects.linalg.opdsl.dump_oplib.ops.core_named_ops \ + "$python_exe" -m mlir.dialects.linalg.opdsl.dump_oplib .ops.core_named_ops \ )" echo "$OUTPUT" > "$dest_file" echo "Success." |