aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Bindings/Python/DialectAMDGPU.cpp
blob: e5a969fffe9402e79ae162fc97dfbad4bcf3779c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
//===--- DialectAMDGPU.cpp - Pybind module for AMDGPU dialect API support -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir-c/Dialect/AMDGPU.h"
#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "nanobind/nanobind.h"

namespace nb = nanobind;
using namespace llvm;
using namespace mlir::python::nanobind_adaptors;

namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
namespace amdgpu {
/// Python binding for the AMDGPU TDM base type.
///
/// Exposes the static factory `TDMBaseType.get(element_type, context=None)`.
/// It builds the type via the C API `mlirAMDGPUTDMBaseTypeGet` in the context
/// supplied by `DefaultingPyMlirContext`. When `context` is omitted, that
/// helper resolves a default (presumably the current thread's context — TODO
/// confirm in IRCore.h).
struct TDMBaseType : PyConcreteType<TDMBaseType> {
  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAAMDGPUTDMBaseType;
  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
      mlirAMDGPUTDMBaseTypeGetTypeID;
  static constexpr const char *pyClassName = "TDMBaseType";
  static inline const MlirStringRef name = mlirAMDGPUTDMBaseTypeGetName();
  using Base::Base;

  static void bindDerived(ClassTy &c) {
    // Implementation of `get`. The resolved Python context is used twice:
    // once for the C API call, and once for the reference kept by the
    // returned wrapper object.
    auto getImpl = [](const PyType &elemTy, DefaultingPyMlirContext ctx) {
      auto *pyCtx = ctx.get();
      auto built = mlirAMDGPUTDMBaseTypeGet(pyCtx->get(), elemTy);
      return TDMBaseType(pyCtx->getRef(), built);
    };
    c.def_static("get", getImpl,
                 "Gets an instance of TDMBaseType in the same context",
                 nb::arg("element_type"),
                 nb::arg("context").none() = nb::none());
  }
};

/// Python binding for the AMDGPU TDM descriptor type.
///
/// Exposes the static factory `TDMDescriptorType.get(context=None)`. It takes
/// no type operands and builds the type via the C API
/// `mlirAMDGPUTDMDescriptorTypeGet` in the context supplied by
/// `DefaultingPyMlirContext`.
struct TDMDescriptorType : PyConcreteType<TDMDescriptorType> {
  static constexpr IsAFunctionTy isaFunction =
      mlirTypeIsAAMDGPUTDMDescriptorType;
  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
      mlirAMDGPUTDMDescriptorTypeGetTypeID;
  static constexpr const char *pyClassName = "TDMDescriptorType";
  static inline const MlirStringRef name = mlirAMDGPUTDMDescriptorTypeGetName();
  using Base::Base;

  static void bindDerived(ClassTy &c) {
    // Implementation of `get`: resolve the Python context once, build the C
    // type in it, then wrap the result together with a context reference.
    auto getImpl = [](DefaultingPyMlirContext ctx) {
      auto *pyCtx = ctx.get();
      auto built = mlirAMDGPUTDMDescriptorTypeGet(pyCtx->get());
      return TDMDescriptorType(pyCtx->getRef(), built);
    };
    c.def_static("get", getImpl,
                 "Gets an instance of TDMDescriptorType in the same context",
                 nb::arg("context").none() = nb::none());
  }
};

/// Python binding for the AMDGPU TDM gather base type.
///
/// Exposes the static factory
/// `TDMGatherBaseType.get(element_type, index_type, context=None)`. It
/// forwards both type operands to the C API `mlirAMDGPUTDMGatherBaseTypeGet`,
/// building in the context supplied by `DefaultingPyMlirContext`.
struct TDMGatherBaseType : PyConcreteType<TDMGatherBaseType> {
  static constexpr IsAFunctionTy isaFunction =
      mlirTypeIsAAMDGPUTDMGatherBaseType;
  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
      mlirAMDGPUTDMGatherBaseTypeGetTypeID;
  static constexpr const char *pyClassName = "TDMGatherBaseType";
  static inline const MlirStringRef name = mlirAMDGPUTDMGatherBaseTypeGetName();
  using Base::Base;

  static void bindDerived(ClassTy &c) {
    // Implementation of `get`: both operands are passed straight through to
    // the C constructor, in the resolved Python context.
    auto getImpl = [](const PyType &elemTy, const PyType &idxTy,
                      DefaultingPyMlirContext ctx) {
      auto *pyCtx = ctx.get();
      auto built = mlirAMDGPUTDMGatherBaseTypeGet(pyCtx->get(), elemTy, idxTy);
      return TDMGatherBaseType(pyCtx->getRef(), built);
    };
    c.def_static("get", getImpl,
                 "Gets an instance of TDMGatherBaseType in the same context",
                 nb::arg("element_type"), nb::arg("index_type"),
                 nb::arg("context").none() = nb::none());
  }
};

static void populateDialectAMDGPUSubmodule(nb::module_ &m) {
  TDMBaseType::bind(m);
  TDMDescriptorType::bind(m);
  TDMGatherBaseType::bind(m);
}
} // namespace amdgpu
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir

// Extension-module entry point for `_mlirDialectsAMDGPU`. It sets the module
// docstring, then registers the AMDGPU type bindings defined above.
NB_MODULE(_mlirDialectsAMDGPU, m) {
  namespace amdgpu_py = mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::amdgpu;
  m.doc() = "MLIR AMDGPU dialect.";
  amdgpu_py::populateDialectAMDGPUSubmodule(m);
}