diff options
Diffstat (limited to 'mlir/test/python/dialects/python_test.py')
-rw-r--r-- | mlir/test/python/dialects/python_test.py | 31 |
1 files changed, 22 insertions, 9 deletions
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py index 5a9acc7..1194e32 100644 --- a/mlir/test/python/dialects/python_test.py +++ b/mlir/test/python/dialects/python_test.py @@ -1,4 +1,5 @@ -# RUN: %PYTHON %s | FileCheck %s +# RUN: %PYTHON %s pybind11 | FileCheck %s +# RUN: %PYTHON %s nanobind | FileCheck %s import sys import typing from typing import Union, Optional @@ -9,14 +10,26 @@ import mlir.dialects.python_test as test import mlir.dialects.tensor as tensor import mlir.dialects.arith as arith -from mlir._mlir_libs._mlirPythonTestNanobind import ( - TestAttr, - TestType, - TestTensorValue, - TestIntegerRankedTensorType, -) - -test.register_python_test_dialect(get_dialect_registry()) +if sys.argv[1] == "pybind11": + from mlir._mlir_libs._mlirPythonTestPybind11 import ( + TestAttr, + TestType, + TestTensorValue, + TestIntegerRankedTensorType, + ) + + test.register_python_test_dialect(get_dialect_registry(), use_nanobind=False) +elif sys.argv[1] == "nanobind": + from mlir._mlir_libs._mlirPythonTestNanobind import ( + TestAttr, + TestType, + TestTensorValue, + TestIntegerRankedTensorType, + ) + + test.register_python_test_dialect(get_dialect_registry(), use_nanobind=True) +else: + raise ValueError("Expected pybind11 or nanobind as argument") def run(f): |