diff options
Diffstat (limited to 'mlir/test')
-rw-r--r-- | mlir/test/Dialect/LLVMIR/invalid.mlir | 4 | ||||
-rw-r--r-- | mlir/test/Dialect/LLVMIR/rocdl.mlir | 33 | ||||
-rw-r--r-- | mlir/test/Dialect/Tosa/dynamic_extension.mlir | 2 | ||||
-rw-r--r-- | mlir/test/Dialect/Tosa/error_if_check.mlir | 2 | ||||
-rw-r--r-- | mlir/test/Dialect/Tosa/invalid.mlir | 2 | ||||
-rw-r--r-- | mlir/test/Dialect/Tosa/invalid_extension.mlir | 2 | ||||
-rw-r--r-- | mlir/test/Dialect/Tosa/level_check.mlir | 2 | ||||
-rw-r--r-- | mlir/test/Dialect/Tosa/profile_all_unsupported.mlir | 2 | ||||
-rw-r--r-- | mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir | 2 | ||||
-rw-r--r-- | mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir | 2 | ||||
-rw-r--r-- | mlir/test/Dialect/Tosa/tosa-attach-target.mlir | 14 | ||||
-rw-r--r-- | mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir | 2 | ||||
-rw-r--r-- | mlir/test/Dialect/Tosa/tosa-validation-valid.mlir | 2 | ||||
-rw-r--r-- | mlir/test/Target/LLVMIR/rocdl.mlir | 33 | ||||
-rw-r--r-- | mlir/test/python/dialects/python_test.py | 31 | ||||
-rw-r--r-- | mlir/test/python/lib/CMakeLists.txt | 1 | ||||
-rw-r--r-- | mlir/test/python/lib/PythonTestModulePybind11.cpp | 118 |
17 files changed, 233 insertions, 21 deletions
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index b7ca71a..aaf9f80 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1973,14 +1973,14 @@ llvm.func @invalid_xevm_prefetch(%arg0: !llvm.ptr) { // ----- llvm.func @invalid_xevm_blockload(%arg0: !llvm.ptr<1>) { - // expected-error@+1 {{op vector size must be 1, 2, 4 or 8 for element type > 8 bits}} + // expected-error@+1 {{op vector size must be 2, 4 or 8 for element type > 8 bits}} %0 = xevm.blockload %arg0 : (!llvm.ptr<1>) -> vector<3xi16> llvm.return } // ----- llvm.func @invalid_xevm_blockstore(%arg0: !llvm.ptr<1>, %arg1: vector<5xi8>) { - // expected-error@+1 {{op vector size must be 1, 2, 4, 8 or 16 for 8-bit element type}} + // expected-error@+1 {{op vector size must be 2, 4, 8 or 16 for 8-bit element type}} xevm.blockstore %arg0, %arg1 : (!llvm.ptr<1>, vector<5xi8>) llvm.return } diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir index 6134695..a88b59a 100644 --- a/mlir/test/Dialect/LLVMIR/rocdl.mlir +++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir @@ -1100,6 +1100,39 @@ llvm.func @rocdl.cvt.scalef32.pk8(%v8xf32: vector<8xf32>, // ----- +// CHECK-LABEL: rocdl.cvt.scalef32.sr.pk8 +llvm.func @rocdl.cvt.scalef32.sr.pk8(%v8xf32: vector<8xf32>, + %v8xf16: vector<8xf16>, + %v8xbf16: vector<8xbf16>, + %seed: i32, + %scale: f32) { + + // CHECK: rocdl.cvt.scalef32.sr.pk8.fp8.f32 + %0 = rocdl.cvt.scalef32.sr.pk8.fp8.f32 %v8xf32, %seed, %scale : vector<2xi32> + // CHECK: rocdl.cvt.scalef32.sr.pk8.bf8.f32 + %1 = rocdl.cvt.scalef32.sr.pk8.bf8.f32 %v8xf32, %seed, %scale : vector<2xi32> + // CHECK: rocdl.cvt.scalef32.sr.pk8.fp4.f32 + %2 = rocdl.cvt.scalef32.sr.pk8.fp4.f32 %v8xf32, %seed, %scale : i32 + + // CHECK: rocdl.cvt.scalef32.sr.pk8.fp8.f16 + %3 = rocdl.cvt.scalef32.sr.pk8.fp8.f16 %v8xf16, %seed, %scale : vector<2xi32> + // CHECK: rocdl.cvt.scalef32.sr.pk8.bf8.f16 + %4 = rocdl.cvt.scalef32.sr.pk8.bf8.f16 %v8xf16, %seed, %scale : vector<2xi32> + // CHECK: rocdl.cvt.scalef32.sr.pk8.fp4.f16 + %5 = rocdl.cvt.scalef32.sr.pk8.fp4.f16 %v8xf16, %seed, %scale : i32 + + // CHECK: rocdl.cvt.scalef32.sr.pk8.fp8.bf16 + %6 = rocdl.cvt.scalef32.sr.pk8.fp8.bf16 %v8xbf16, %seed, %scale : vector<2xi32> + // CHECK: rocdl.cvt.scalef32.sr.pk8.bf8.bf16 + %7 = rocdl.cvt.scalef32.sr.pk8.bf8.bf16 %v8xbf16, %seed, %scale : vector<2xi32> + // CHECK: rocdl.cvt.scalef32.sr.pk8.fp4.bf16 + %8 = rocdl.cvt.scalef32.sr.pk8.fp4.bf16 %v8xbf16, %seed, %scale : i32 + + llvm.return +} + +// ----- + // CHECK-LABEL: rocdl.cvt.scale.pk16 llvm.func @rocdl.cvt.scale.pk16(%v3xi32: vector<3xi32>, %scale:i32) { diff --git a/mlir/test/Dialect/Tosa/dynamic_extension.mlir b/mlir/test/Dialect/Tosa/dynamic_extension.mlir index aaf8371..60b70b8 100644 --- a/mlir/test/Dialect/Tosa/dynamic_extension.mlir +++ b/mlir/test/Dialect/Tosa/dynamic_extension.mlir @@ -2,7 +2,7 @@ // Check operations when the dynamic extension is enabled. //-------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp extension=dynamic allow-invalid-op-datatype-combinations" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=dynamic" -tosa-validate="strict-op-spec-alignment allow-invalid-op-datatype-combinations" // ----- diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir index 2f9421c..334f52a 100644 --- a/mlir/test/Dialect/Tosa/error_if_check.mlir +++ b/mlir/test/Dialect/Tosa/error_if_check.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="level=none profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="level=none profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic" -tosa-validate="strict-op-spec-alignment" // ----- diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index e60f1c9b..2a3985c 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -4,7 +4,7 @@ // validation flow. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" func.func @test_cast(%arg0: tensor<i1>) -> tensor<5xi32> { diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir index 1daabe9..e5c9402 100644 --- a/mlir/test/Dialect/Tosa/invalid_extension.mlir +++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir @@ -2,7 +2,7 @@ // Enable all supported profiles to focus the verification of expected extension requirement errors. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp" -tosa-validate="strict-op-spec-alignment" // ----- func.func @test_argmax(%arg0: tensor<14x19xbf16>) -> tensor<14xi32> { diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index 5bf2dbb8..8cc357e 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="extension=dynamic" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="extensions=dynamic" -tosa-validate func.func @test_argmax_rank_invalid(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> { // expected-error@+1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}} diff --git a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir index 225b962..09e96ec 100644 --- a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir +++ b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir @@ -2,7 +2,7 @@ // Enable all supported extensions to focus the verification of expected profile requirement errors. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" // ----- func.func @test_add_i32(%arg0: tensor<13x21x1xi32>, %arg1: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> { diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir index 58a73d6..7ff8065 100644 --- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir +++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir @@ -2,7 +2,7 @@ // Enable all supported extensions to focus the verification of expected profile requirement errors. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" // ----- func.func @test_const_f16() -> tensor<3x11x11x3xf16> { diff --git a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir index a5784b3..48e79e4 100644 --- a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir +++ b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir @@ -2,7 +2,7 @@ // Enable all supported extensions to focus the verification of expected profile requirement errors. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" // ----- func.func @test_const_i1() -> tensor<3x11x11x3xi1> { diff --git a/mlir/test/Dialect/Tosa/tosa-attach-target.mlir b/mlir/test/Dialect/Tosa/tosa-attach-target.mlir new file mode 100644 index 0000000..d6c886c --- /dev/null +++ b/mlir/test/Dialect/Tosa/tosa-attach-target.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-opt %s -split-input-file -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,dynamic level=none" | FileCheck %s --check-prefix=CHECK-ALL +// RUN: mlir-opt %s -split-input-file -tosa-attach-target="level=8k" | FileCheck %s --check-prefix=CHECK-LVL-8K +// RUN: mlir-opt %s -split-input-file -tosa-attach-target | FileCheck %s --check-prefix=CHECK-DEFAULT + +// ----- + +// CHECK-ALL: module attributes {tosa.target_env = #tosa.target_env<level = none, profiles = [pro_int, pro_fp], extensions = [int16, int4, bf16, fp8e4m3, fp8e5m2, fft, variable, controlflow, doubleround, inexactround, dynamic]>} +// CHECK-LVL-8K: module attributes {tosa.target_env = #tosa.target_env<level = "8k", profiles = [], extensions = []>} +// CHECK-DEFAULT: module attributes {tosa.target_env = #tosa.target_env<level = "8k", profiles = [], extensions = []>} +// CHECK-LABEL: test_simple +func.func @test_simple(%arg0 : tensor<1x1x1x1xf32>, %arg1 : tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> { + %1 = tosa.add %arg0, %arg1 : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> + return %1 : tensor<1x1x1x1xf32> +} diff --git a/mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir b/mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir index f05ae7f..8e0ad0a 100644 --- a/mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir +++ b/mlir/test/Dialect/Tosa/tosa-validation-valid-strict.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround strict-op-spec-alignment" | FileCheck %s +// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" --tosa-validate="strict-op-spec-alignment" | FileCheck %s // ----- diff --git a/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir index 88ec027..663159e 100644 --- a/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir +++ b/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir @@ -4,7 +4,7 @@ // validation flow. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" | FileCheck %s +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate | FileCheck %s // ----- diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir index 00ee6b7..1c0c2eb 100644 --- a/mlir/test/Target/LLVMIR/rocdl.mlir +++ b/mlir/test/Target/LLVMIR/rocdl.mlir @@ -1368,6 +1368,39 @@ llvm.func @rocdl.cvt.scalef32.pk8(%v8xf32: vector<8xf32>, %v8xf16: vector<8xf16> llvm.return } +// CHECK-LABEL: rocdl.cvt.scalef32.sr.pk8 +// CHECK-SAME:(<8 x float> %[[V8F32:.+]], <8 x half> %[[V8F16:.+]], <8 x bfloat> %[[V8BF16:.+]], i32 %[[SEED:.+]], float %[[SCALE:.+]]) +llvm.func @rocdl.cvt.scalef32.sr.pk8(%v8xf32: vector<8xf32>, + %v8xf16: vector<8xf16>, + %v8xbf16: vector<8xbf16>, + %seed: i32, + %scale: f32) { + + // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.fp8.f32(<8 x float> %[[V8F32]], i32 %[[SEED]], float %[[SCALE]]) + %0 = rocdl.cvt.scalef32.sr.pk8.fp8.f32 %v8xf32, %seed, %scale : vector<2xi32> + // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.bf8.f32(<8 x float> %[[V8F32]], i32 %[[SEED]], float %[[SCALE]]) + %1 = rocdl.cvt.scalef32.sr.pk8.bf8.f32 %v8xf32, %seed, %scale : vector<2xi32> + // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk8.fp4.f32(<8 x float> %[[V8F32]], i32 %[[SEED]], float %[[SCALE]]) + %2 = rocdl.cvt.scalef32.sr.pk8.fp4.f32 %v8xf32, %seed, %scale : i32 + + // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.fp8.f16(<8 x half> %[[V8F16]], i32 %[[SEED]], float %[[SCALE]]) + %3 = rocdl.cvt.scalef32.sr.pk8.fp8.f16 %v8xf16, %seed, %scale : vector<2xi32> + // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.bf8.f16(<8 x half> %[[V8F16]], i32 %[[SEED]], float %[[SCALE]]) + %4 = rocdl.cvt.scalef32.sr.pk8.bf8.f16 %v8xf16, %seed, %scale : vector<2xi32> + // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk8.fp4.f16(<8 x half> %[[V8F16]], i32 %[[SEED]], float %[[SCALE]]) + %5 = rocdl.cvt.scalef32.sr.pk8.fp4.f16 %v8xf16, %seed, %scale : i32 + + // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.fp8.bf16(<8 x bfloat> %[[V8BF16]], i32 %[[SEED]], float %[[SCALE]]) + %6 = rocdl.cvt.scalef32.sr.pk8.fp8.bf16 %v8xbf16, %seed, %scale : vector<2xi32> + // CHECK: call <2 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk8.bf8.bf16(<8 x bfloat> %[[V8BF16]], i32 %[[SEED]], float %[[SCALE]]) + %7 = rocdl.cvt.scalef32.sr.pk8.bf8.bf16 %v8xbf16, %seed, %scale : vector<2xi32> + // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk8.fp4.bf16(<8 x bfloat> %[[V8BF16]], i32 %[[SEED]], float %[[SCALE]]) + %8 = rocdl.cvt.scalef32.sr.pk8.fp4.bf16 %v8xbf16, %seed, %scale : i32 + + llvm.return +} + + // CHECK-LABEL: @rocdl.cvt.scale.pk16 // CHECK-SAME:(<3 x i32> %[[SRC0:.+]], i32 %[[SCALE:.+]]) llvm.func @rocdl.cvt.scale.pk16(%v3xi32: vector<3xi32>, %scale:i32) { diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py index 5a9acc7..1194e32 100644 --- a/mlir/test/python/dialects/python_test.py +++ b/mlir/test/python/dialects/python_test.py @@ -1,4 +1,5 @@ -# RUN: %PYTHON %s | FileCheck %s +# RUN: %PYTHON %s pybind11 | FileCheck %s +# RUN: %PYTHON %s nanobind | FileCheck %s import sys import typing from typing import Union, Optional @@ -9,14 +10,26 @@ import mlir.dialects.python_test as test import mlir.dialects.tensor as tensor import mlir.dialects.arith as arith -from mlir._mlir_libs._mlirPythonTestNanobind import ( - TestAttr, - TestType, - TestTensorValue, - TestIntegerRankedTensorType, -) - -test.register_python_test_dialect(get_dialect_registry()) +if sys.argv[1] == "pybind11": + from mlir._mlir_libs._mlirPythonTestPybind11 import ( + TestAttr, + TestType, + TestTensorValue, + TestIntegerRankedTensorType, + ) + + test.register_python_test_dialect(get_dialect_registry(), use_nanobind=False) +elif sys.argv[1] == "nanobind": + from mlir._mlir_libs._mlirPythonTestNanobind import ( + TestAttr, + TestType, + TestTensorValue, + TestIntegerRankedTensorType, + ) + + test.register_python_test_dialect(get_dialect_registry(), use_nanobind=True) +else: + raise ValueError("Expected pybind11 or nanobind as argument") def run(f): diff --git a/mlir/test/python/lib/CMakeLists.txt b/mlir/test/python/lib/CMakeLists.txt index f51a7b4..9a813da 100644 --- a/mlir/test/python/lib/CMakeLists.txt +++ b/mlir/test/python/lib/CMakeLists.txt @@ -1,6 +1,7 @@ set(LLVM_OPTIONAL_SOURCES PythonTestCAPI.cpp PythonTestDialect.cpp + PythonTestModulePybind11.cpp PythonTestModuleNanobind.cpp ) diff --git a/mlir/test/python/lib/PythonTestModulePybind11.cpp b/mlir/test/python/lib/PythonTestModulePybind11.cpp new file mode 100644 index 0000000..94a5f51 --- /dev/null +++ b/mlir/test/python/lib/PythonTestModulePybind11.cpp @@ -0,0 +1,118 @@ +//===- PythonTestModule.cpp - Python extension for the PythonTest dialect -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// This is the pybind11 edition of the PythonTest dialect module. +//===----------------------------------------------------------------------===// + +#include "PythonTestCAPI.h" +#include "mlir-c/BuiltinAttributes.h" +#include "mlir-c/BuiltinTypes.h" +#include "mlir-c/IR.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" + +namespace py = pybind11; +using namespace mlir::python::adaptors; +using namespace pybind11::literals; + +static bool mlirTypeIsARankedIntegerTensor(MlirType t) { + return mlirTypeIsARankedTensor(t) && + mlirTypeIsAInteger(mlirShapedTypeGetElementType(t)); +} + +PYBIND11_MODULE(_mlirPythonTestPybind11, m) { + m.def( + "register_python_test_dialect", + [](MlirContext context, bool load) { + MlirDialectHandle pythonTestDialect = + mlirGetDialectHandle__python_test__(); + mlirDialectHandleRegisterDialect(pythonTestDialect, context); + if (load) { + mlirDialectHandleLoadDialect(pythonTestDialect, context); + } + }, + py::arg("context"), py::arg("load") = true); + + m.def( + "register_dialect", + [](MlirDialectRegistry registry) { + MlirDialectHandle pythonTestDialect = + mlirGetDialectHandle__python_test__(); + mlirDialectHandleInsertDialect(pythonTestDialect, registry); + }, + py::arg("registry")); + + mlir_attribute_subclass(m, "TestAttr", + mlirAttributeIsAPythonTestTestAttribute, + mlirPythonTestTestAttributeGetTypeID) + .def_classmethod( + "get", + [](const py::object &cls, MlirContext ctx) { + return cls(mlirPythonTestTestAttributeGet(ctx)); + }, + py::arg("cls"), py::arg("context") = py::none()); + + mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType, + mlirPythonTestTestTypeGetTypeID) + .def_classmethod( + "get", + [](const py::object &cls, MlirContext ctx) { + return cls(mlirPythonTestTestTypeGet(ctx)); + }, + py::arg("cls"), py::arg("context") = py::none()); + + auto typeCls = + mlir_type_subclass(m, "TestIntegerRankedTensorType", + mlirTypeIsARankedIntegerTensor, + py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("RankedTensorType")) + .def_classmethod( + "get", + [](const py::object &cls, std::vector<int64_t> shape, + unsigned width, MlirContext ctx) { + MlirAttribute encoding = mlirAttributeGetNull(); + return cls(mlirRankedTensorTypeGet( + shape.size(), shape.data(), mlirIntegerTypeGet(ctx, width), + encoding)); + }, + "cls"_a, "shape"_a, "width"_a, "context"_a = py::none()); + + assert(py::hasattr(typeCls.get_class(), "static_typeid") && + "TestIntegerRankedTensorType has no static_typeid"); + + MlirTypeID mlirRankedTensorTypeID = mlirRankedTensorTypeGetTypeID(); + + py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(mlirRankedTensorTypeID, + "replace"_a = true)( + pybind11::cpp_function([typeCls](const py::object &mlirType) { + return typeCls.get_class()(mlirType); + })); + + auto valueCls = mlir_value_subclass(m, "TestTensorValue", + mlirTypeIsAPythonTestTestTensorValue) + .def("is_null", [](MlirValue &self) { + return mlirValueIsNull(self); + }); + + py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr(MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR)( + mlirRankedTensorTypeID)( + pybind11::cpp_function([valueCls](const py::object &valueObj) { + py::object capsule = mlirApiObjectToCapsule(valueObj); + MlirValue v = mlirPythonCapsuleToValue(capsule.ptr()); + MlirType t = mlirValueGetType(v); + // This is hyper-specific in order to exercise/test registering a + // value caster from cpp (but only for a single test case; see + // testTensorValue python_test.py). + if (mlirShapedTypeHasStaticShape(t) && + mlirShapedTypeGetDimSize(t, 0) == 1 && + mlirShapedTypeGetDimSize(t, 1) == 2 && + mlirShapedTypeGetDimSize(t, 2) == 3) + return valueCls.get_class()(valueObj); + return valueObj; + })); +} |