aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/python/dialects
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/test/python/dialects')
-rw-r--r--mlir/test/python/dialects/python_test.py31
1 files changed, 22 insertions, 9 deletions
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 5a9acc7..1194e32 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -1,4 +1,5 @@
-# RUN: %PYTHON %s | FileCheck %s
+# RUN: %PYTHON %s pybind11 | FileCheck %s
+# RUN: %PYTHON %s nanobind | FileCheck %s
import sys
import typing
from typing import Union, Optional
@@ -9,14 +10,26 @@ import mlir.dialects.python_test as test
import mlir.dialects.tensor as tensor
import mlir.dialects.arith as arith
-from mlir._mlir_libs._mlirPythonTestNanobind import (
- TestAttr,
- TestType,
- TestTensorValue,
- TestIntegerRankedTensorType,
-)
-
-test.register_python_test_dialect(get_dialect_registry())
+if sys.argv[1] == "pybind11":
+ from mlir._mlir_libs._mlirPythonTestPybind11 import (
+ TestAttr,
+ TestType,
+ TestTensorValue,
+ TestIntegerRankedTensorType,
+ )
+
+ test.register_python_test_dialect(get_dialect_registry(), use_nanobind=False)
+elif sys.argv[1] == "nanobind":
+ from mlir._mlir_libs._mlirPythonTestNanobind import (
+ TestAttr,
+ TestType,
+ TestTensorValue,
+ TestIntegerRankedTensorType,
+ )
+
+ test.register_python_test_dialect(get_dialect_registry(), use_nanobind=True)
+else:
+ raise ValueError("Expected pybind11 or nanobind as argument")
def run(f):