aboutsummaryrefslogtreecommitdiff
path: root/mlir/test
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/test')
-rw-r--r--mlir/test/Dialect/LLVMIR/invalid.mlir4
-rw-r--r--mlir/test/Dialect/LLVMIR/rocdl.mlir33
-rw-r--r--mlir/test/Dialect/Tosa/dynamic_extension.mlir2
-rw-r--r--mlir/test/Dialect/Tosa/error_if_check.mlir2
-rw-r--r--mlir/test/Dialect/Tosa/invalid.mlir2
-rw-r--r--mlir/test/Dialect/Tosa/invalid_extension.mlir2
-rw-r--r--mlir/test/Dialect/Tosa/level_check.mlir2
-rw-r--r--mlir/test/Dialect/Tosa/profile_all_unsupported.mlir2
-rw-r--r--mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir2
-rw-r--r--mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir2
-rw-r--r--mlir/test/Dialect/Tosa/tosa-attach-target.mlir14
-rw-r--r--mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir2
-rw-r--r--mlir/test/Dialect/Tosa/tosa-validation-valid.mlir2
-rw-r--r--mlir/test/Target/LLVMIR/rocdl.mlir33
-rw-r--r--mlir/test/python/dialects/python_test.py31
-rw-r--r--mlir/test/python/lib/CMakeLists.txt1
-rw-r--r--mlir/test/python/lib/PythonTestModulePybind11.cpp118
17 files changed, 233 insertions, 21 deletions
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index b7ca71a..aaf9f80 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1973,14 +1973,14 @@ llvm.func @invalid_xevm_prefetch(%arg0: !llvm.ptr) {
// -----
llvm.func @invalid_xevm_blockload(%arg0: !llvm.ptr<1>) {
- // expected-error@+1 {{op vector size must be 1, 2, 4 or 8 for element type > 8 bits}}
+ // expected-error@+1 {{op vector size must be 2, 4 or 8 for element type > 8 bits}}
%0 = xevm.blockload %arg0 : (!llvm.ptr<1>) -> vector<3xi16>
llvm.return
}
// -----
llvm.func @invalid_xevm_blockstore(%arg0: !llvm.ptr<1>, %arg1: vector<5xi8>) {
- // expected-error@+1 {{op vector size must be 1, 2, 4, 8 or 16 for 8-bit element type}}
+ // expected-error@+1 {{op vector size must be 2, 4, 8 or 16 for 8-bit element type}}
xevm.blockstore %arg0, %arg1 : (!llvm.ptr<1>, vector<5xi8>)
llvm.return
}
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 6134695..a88b59a 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -1100,6 +1100,39 @@ llvm.func @rocdl.cvt.scalef32.pk8(%v8xf32: vector<8xf32>,
// -----
+// CHECK-LABEL: rocdl.cvt.scalef32.sr.pk8
+llvm.func @rocdl.cvt.scalef32.sr.pk8(%v8xf32: vector<8xf32>,
+ %v8xf16: vector<8xf16>,
+ %v8xbf16: vector<8xbf16>,
+ %seed: i32,
+ %scale: f32) {
+
+ // CHECK: rocdl.cvt.scalef32.sr.pk8.fp8.f32
+ %0 = rocdl.cvt.scalef32.sr.pk8.fp8.f32 %v8xf32, %seed, %scale : vector<2xi32>
+ // CHECK: rocdl.cvt.scalef32.sr.pk8.bf8.f32
+ %1 = rocdl.cvt.scalef32.sr.pk8.bf8.f32 %v8xf32, %seed, %scale : vector<2xi32>
+ // CHECK: rocdl.cvt.scalef32.sr.pk8.fp4.f32
+ %2 = rocdl.cvt.scalef32.sr.pk8.fp4.f32 %v8xf32, %seed, %scale : i32
+
+ // CHECK: rocdl.cvt.scalef32.sr.pk8.fp8.f16
+ %3 = rocdl.cvt.scalef32.sr.pk8.fp8.f16 %v8xf16, %seed, %scale : vector<2xi32>
+ // CHECK: rocdl.cvt.scalef32.sr.pk8.bf8.f16
+ %4 = rocdl.cvt.scalef32.sr.pk8.bf8.f16 %v8xf16, %seed, %scale : vector<2xi32>
+ // CHECK: rocdl.cvt.scalef32.sr.pk8.fp4.f16
+ %5 = rocdl.cvt.scalef32.sr.pk8.fp4.f16 %v8xf16, %seed, %scale : i32
+
+ // CHECK: rocdl.cvt.scalef32.sr.pk8.fp8.bf16
+ %6 = rocdl.cvt.scalef32.sr.pk8.fp8.bf16 %v8xbf16, %seed, %scale : vector<2xi32>
+ // CHECK: rocdl.cvt.scalef32.sr.pk8.bf8.bf16
+ %7 = rocdl.cvt.scalef32.sr.pk8.bf8.bf16 %v8xbf16, %seed, %scale : vector<2xi32>
+ // CHECK: rocdl.cvt.scalef32.sr.pk8.fp4.bf16
+ %8 = rocdl.cvt.scalef32.sr.pk8.fp4.bf16 %v8xbf16, %seed, %scale : i32
+
+ llvm.return
+}
+
+// -----
+
// CHECK-LABEL: rocdl.cvt.scale.pk16
llvm.func @rocdl.cvt.scale.pk16(%v3xi32: vector<3xi32>, %scale:i32) {
diff --git a/mlir/test/Dialect/Tosa/dynamic_extension.mlir b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
index aaf8371..60b70b8 100644
--- a/mlir/test/Dialect/Tosa/dynamic_extension.mlir
+++ b/mlir/test/Dialect/Tosa/dynamic_extension.mlir
@@ -2,7 +2,7 @@
// Check operations when the dynamic extension is enabled.
//--------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp extension=dynamic allow-invalid-op-datatype-combinations"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=dynamic" -tosa-validate="strict-op-spec-alignment allow-invalid-op-datatype-combinations"
// -----
diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index 2f9421c..334f52a 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="level=none profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="level=none profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic" -tosa-validate="strict-op-spec-alignment"
// -----
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index e60f1c9b..2a3985c 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -4,7 +4,7 @@
// validation flow.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
func.func @test_cast(%arg0: tensor<i1>) -> tensor<5xi32> {
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 1daabe9..e5c9402 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -2,7 +2,7 @@
// Enable all supported profiles to focus the verification of expected extension requirement errors.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp" -tosa-validate="strict-op-spec-alignment"
// -----
func.func @test_argmax(%arg0: tensor<14x19xbf16>) -> tensor<14xi32> {
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 5bf2dbb8..8cc357e 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="extension=dynamic"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="extensions=dynamic" -tosa-validate
func.func @test_argmax_rank_invalid(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> {
// expected-error@+1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}}
diff --git a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
index 225b962..09e96ec 100644
--- a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
@@ -2,7 +2,7 @@
// Enable all supported extensions to focus the verification of expected profile requirement errors.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
// -----
func.func @test_add_i32(%arg0: tensor<13x21x1xi32>, %arg1: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
index 58a73d6..7ff8065 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
@@ -2,7 +2,7 @@
// Enable all supported extensions to focus the verification of expected profile requirement errors.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
// -----
func.func @test_const_f16() -> tensor<3x11x11x3xf16> {
diff --git a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
index a5784b3..48e79e4 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
@@ -2,7 +2,7 @@
// Enable all supported extensions to focus the verification of expected profile requirement errors.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
// -----
func.func @test_const_i1() -> tensor<3x11x11x3xi1> {
diff --git a/mlir/test/Dialect/Tosa/tosa-attach-target.mlir b/mlir/test/Dialect/Tosa/tosa-attach-target.mlir
new file mode 100644
index 0000000..d6c886c
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-attach-target.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s -split-input-file -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,dynamic level=none" | FileCheck %s --check-prefix=CHECK-ALL
+// RUN: mlir-opt %s -split-input-file -tosa-attach-target="level=8k" | FileCheck %s --check-prefix=CHECK-LVL-8K
+// RUN: mlir-opt %s -split-input-file -tosa-attach-target | FileCheck %s --check-prefix=CHECK-DEFAULT
+
+// -----
+
+// CHECK-ALL: module attributes {tosa.target_env = #tosa.target_env<level = none, profiles = [pro_int, pro_fp], extensions = [int16, int4, bf16, fp8e4m3, fp8e5m2, fft, variable, controlflow, doubleround, inexactround, dynamic]>}
+// CHECK-LVL-8K: module attributes {tosa.target_env = #tosa.target_env<level = "8k", profiles = [], extensions = []>}
+// CHECK-DEFAULT: module attributes {tosa.target_env = #tosa.target_env<level = "8k", profiles = [], extensions = []>}
+// CHECK-LABEL: test_simple
+func.func @test_simple(%arg0 : tensor<1x1x1x1xf32>, %arg1 : tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> {
+ %1 = tosa.add %arg0, %arg1 : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
+ return %1 : tensor<1x1x1x1xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir b/mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir
index f05ae7f..8e0ad0a 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround strict-op-spec-alignment" | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" --tosa-validate="strict-op-spec-alignment" | FileCheck %s
// -----
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir
index 88ec027..663159e 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir
@@ -4,7 +4,7 @@
// validation flow.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate | FileCheck %s
// -----
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 00ee6b7..1c0c2eb 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -1368,6 +1368,39 @@ llvm.func @rocdl.cvt.scalef32.pk8(%v8xf32: vector<8xf32>, %v8xf16: vector<8xf16>
llvm.return
}
+// CHECK-LABEL: rocdl.cvt.scalef32.sr.pk8
+// CHECK-SAME:(<8 x float> %[[V8F32:.+]], <8 x half> %[[V8F16:.+]], <8 x bfloat> %[[V8BF16:.+]], i32 %[[SEED:.+]], float %[[SCALE:.+]])
+llvm.func @rocdl.cvt.scalef32.sr.pk8(%v8xf32: vector<8xf32>,
+ %v8xf16: vector<8xf16>,
+ %v8xbf16: vector<8xbf16>,
+ %seed: i32,
+ %scale: f32) {
+
+ // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.fp8.f32(<8 x float> %[[V8F32]], i32 %[[SEED]], float %[[SCALE]])
+ %0 = rocdl.cvt.scalef32.sr.pk8.fp8.f32 %v8xf32, %seed, %scale : vector<2xi32>
+ // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.bf8.f32(<8 x float> %[[V8F32]], i32 %[[SEED]], float %[[SCALE]])
+ %1 = rocdl.cvt.scalef32.sr.pk8.bf8.f32 %v8xf32, %seed, %scale : vector<2xi32>
+ // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk8.fp4.f32(<8 x float> %[[V8F32]], i32 %[[SEED]], float %[[SCALE]])
+ %2 = rocdl.cvt.scalef32.sr.pk8.fp4.f32 %v8xf32, %seed, %scale : i32
+
+ // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.fp8.f16(<8 x half> %[[V8F16]], i32 %[[SEED]], float %[[SCALE]])
+ %3 = rocdl.cvt.scalef32.sr.pk8.fp8.f16 %v8xf16, %seed, %scale : vector<2xi32>
+ // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.bf8.f16(<8 x half> %[[V8F16]], i32 %[[SEED]], float %[[SCALE]])
+ %4 = rocdl.cvt.scalef32.sr.pk8.bf8.f16 %v8xf16, %seed, %scale : vector<2xi32>
+ // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk8.fp4.f16(<8 x half> %[[V8F16]], i32 %[[SEED]], float %[[SCALE]])
+ %5 = rocdl.cvt.scalef32.sr.pk8.fp4.f16 %v8xf16, %seed, %scale : i32
+
+ // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.fp8.bf16(<8 x bfloat> %[[V8BF16]], i32 %[[SEED]], float %[[SCALE]])
+ %6 = rocdl.cvt.scalef32.sr.pk8.fp8.bf16 %v8xbf16, %seed, %scale : vector<2xi32>
+ // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.bf8.bf16(<8 x bfloat> %[[V8BF16]], i32 %[[SEED]], float %[[SCALE]])
+ %7 = rocdl.cvt.scalef32.sr.pk8.bf8.bf16 %v8xbf16, %seed, %scale : vector<2xi32>
+ // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk8.fp4.bf16(<8 x bfloat> %[[V8BF16]], i32 %[[SEED]], float %[[SCALE]])
+ %8 = rocdl.cvt.scalef32.sr.pk8.fp4.bf16 %v8xbf16, %seed, %scale : i32
+
+ llvm.return
+}
+
+
// CHECK-LABEL: @rocdl.cvt.scale.pk16
// CHECK-SAME:(<3 x i32> %[[SRC0:.+]], i32 %[[SCALE:.+]])
llvm.func @rocdl.cvt.scale.pk16(%v3xi32: vector<3xi32>, %scale:i32) {
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 5a9acc7..1194e32 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -1,4 +1,5 @@
-# RUN: %PYTHON %s | FileCheck %s
+# RUN: %PYTHON %s pybind11 | FileCheck %s
+# RUN: %PYTHON %s nanobind | FileCheck %s
import sys
import typing
from typing import Union, Optional
@@ -9,14 +10,26 @@ import mlir.dialects.python_test as test
import mlir.dialects.tensor as tensor
import mlir.dialects.arith as arith
-from mlir._mlir_libs._mlirPythonTestNanobind import (
- TestAttr,
- TestType,
- TestTensorValue,
- TestIntegerRankedTensorType,
-)
-
-test.register_python_test_dialect(get_dialect_registry())
+if sys.argv[1] == "pybind11":
+ from mlir._mlir_libs._mlirPythonTestPybind11 import (
+ TestAttr,
+ TestType,
+ TestTensorValue,
+ TestIntegerRankedTensorType,
+ )
+
+ test.register_python_test_dialect(get_dialect_registry(), use_nanobind=False)
+elif sys.argv[1] == "nanobind":
+ from mlir._mlir_libs._mlirPythonTestNanobind import (
+ TestAttr,
+ TestType,
+ TestTensorValue,
+ TestIntegerRankedTensorType,
+ )
+
+ test.register_python_test_dialect(get_dialect_registry(), use_nanobind=True)
+else:
+ raise ValueError("Expected pybind11 or nanobind as argument")
def run(f):
diff --git a/mlir/test/python/lib/CMakeLists.txt b/mlir/test/python/lib/CMakeLists.txt
index f51a7b4..9a813da 100644
--- a/mlir/test/python/lib/CMakeLists.txt
+++ b/mlir/test/python/lib/CMakeLists.txt
@@ -1,6 +1,7 @@
set(LLVM_OPTIONAL_SOURCES
PythonTestCAPI.cpp
PythonTestDialect.cpp
+ PythonTestModulePybind11.cpp
PythonTestModuleNanobind.cpp
)
diff --git a/mlir/test/python/lib/PythonTestModulePybind11.cpp b/mlir/test/python/lib/PythonTestModulePybind11.cpp
new file mode 100644
index 0000000..94a5f51
--- /dev/null
+++ b/mlir/test/python/lib/PythonTestModulePybind11.cpp
@@ -0,0 +1,118 @@
+//===- PythonTestModule.cpp - Python extension for the PythonTest dialect -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// This is the pybind11 edition of the PythonTest dialect module.
+//===----------------------------------------------------------------------===//
+
+#include "PythonTestCAPI.h"
+#include "mlir-c/BuiltinAttributes.h"
+#include "mlir-c/BuiltinTypes.h"
+#include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/PybindAdaptors.h"
+
+namespace py = pybind11;
+using namespace mlir::python::adaptors;
+using namespace pybind11::literals;
+
+static bool mlirTypeIsARankedIntegerTensor(MlirType t) {
+ return mlirTypeIsARankedTensor(t) &&
+ mlirTypeIsAInteger(mlirShapedTypeGetElementType(t));
+}
+
+PYBIND11_MODULE(_mlirPythonTestPybind11, m) {
+ m.def(
+ "register_python_test_dialect",
+ [](MlirContext context, bool load) {
+ MlirDialectHandle pythonTestDialect =
+ mlirGetDialectHandle__python_test__();
+ mlirDialectHandleRegisterDialect(pythonTestDialect, context);
+ if (load) {
+ mlirDialectHandleLoadDialect(pythonTestDialect, context);
+ }
+ },
+ py::arg("context"), py::arg("load") = true);
+
+ m.def(
+ "register_dialect",
+ [](MlirDialectRegistry registry) {
+ MlirDialectHandle pythonTestDialect =
+ mlirGetDialectHandle__python_test__();
+ mlirDialectHandleInsertDialect(pythonTestDialect, registry);
+ },
+ py::arg("registry"));
+
+ mlir_attribute_subclass(m, "TestAttr",
+ mlirAttributeIsAPythonTestTestAttribute,
+ mlirPythonTestTestAttributeGetTypeID)
+ .def_classmethod(
+ "get",
+ [](const py::object &cls, MlirContext ctx) {
+ return cls(mlirPythonTestTestAttributeGet(ctx));
+ },
+ py::arg("cls"), py::arg("context") = py::none());
+
+ mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType,
+ mlirPythonTestTestTypeGetTypeID)
+ .def_classmethod(
+ "get",
+ [](const py::object &cls, MlirContext ctx) {
+ return cls(mlirPythonTestTestTypeGet(ctx));
+ },
+ py::arg("cls"), py::arg("context") = py::none());
+
+ auto typeCls =
+ mlir_type_subclass(m, "TestIntegerRankedTensorType",
+ mlirTypeIsARankedIntegerTensor,
+ py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("RankedTensorType"))
+ .def_classmethod(
+ "get",
+ [](const py::object &cls, std::vector<int64_t> shape,
+ unsigned width, MlirContext ctx) {
+ MlirAttribute encoding = mlirAttributeGetNull();
+ return cls(mlirRankedTensorTypeGet(
+ shape.size(), shape.data(), mlirIntegerTypeGet(ctx, width),
+ encoding));
+ },
+ "cls"_a, "shape"_a, "width"_a, "context"_a = py::none());
+
+ assert(py::hasattr(typeCls.get_class(), "static_typeid") &&
+ "TestIntegerRankedTensorType has no static_typeid");
+
+ MlirTypeID mlirRankedTensorTypeID = mlirRankedTensorTypeGetTypeID();
+
+ py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(mlirRankedTensorTypeID,
+ "replace"_a = true)(
+ pybind11::cpp_function([typeCls](const py::object &mlirType) {
+ return typeCls.get_class()(mlirType);
+ }));
+
+ auto valueCls = mlir_value_subclass(m, "TestTensorValue",
+ mlirTypeIsAPythonTestTestTensorValue)
+ .def("is_null", [](MlirValue &self) {
+ return mlirValueIsNull(self);
+ });
+
+ py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr(MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR)(
+ mlirRankedTensorTypeID)(
+ pybind11::cpp_function([valueCls](const py::object &valueObj) {
+ py::object capsule = mlirApiObjectToCapsule(valueObj);
+ MlirValue v = mlirPythonCapsuleToValue(capsule.ptr());
+ MlirType t = mlirValueGetType(v);
+ // This is hyper-specific in order to exercise/test registering a
+ // value caster from cpp (but only for a single test case; see
+ // testTensorValue python_test.py).
+ if (mlirShapedTypeHasStaticShape(t) &&
+ mlirShapedTypeGetDimSize(t, 0) == 1 &&
+ mlirShapedTypeGetDimSize(t, 1) == 2 &&
+ mlirShapedTypeGetDimSize(t, 2) == 3)
+ return valueCls.get_class()(valueObj);
+ return valueObj;
+ }));
+}