aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp')
-rw-r--r--mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp21
1 files changed, 21 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 0b3f4b9..24719fe 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -32,6 +32,17 @@ void XeGPUDialect::initialize() {
//===----------------------------------------------------------------------===//
// XeGPU_TensorDescAttr
//===----------------------------------------------------------------------===//
+TensorDescAttr TensorDescAttr::get(mlir::MLIRContext *context,
+ xegpu::MemoryScope memory_scope,
+ int array_length, bool boundary_check,
+ bool scattered) {
+ auto scopeAttr = MemoryScopeAttr::get(context, memory_scope);
+ auto lengthAttr =
+ IntegerAttr::get(IntegerType::get(context, 64), array_length);
+ auto boundaryAttr = BoolAttr::get(context, boundary_check);
+ auto scatteredAttr = BoolAttr::get(context, scattered);
+ return Base::get(context, scopeAttr, lengthAttr, boundaryAttr, scatteredAttr);
+}
//===----------------------------------------------------------------------===//
// XeGPU_TensorDescType
@@ -96,6 +107,16 @@ void TensorDescType::print(::mlir::AsmPrinter &printer) const {
printer << ">";
}
+TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
+ mlir::Type elementType, bool scattered,
+ int array_length, MemoryScope memory_scope,
+ bool boundary_check) {
+ auto context = elementType.getContext();
+ auto attr = TensorDescAttr::get(context, memory_scope, array_length,
+ boundary_check, scattered);
+ return Base::get(context, shape, elementType, attr);
+}
+
} // namespace xegpu
} // namespace mlir