aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Dialect/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/Ptr/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/Ptr/IR/CMakeLists.txt16
-rw-r--r--mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp40
-rw-r--r--mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp55
-rw-r--r--mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp144
6 files changed, 257 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index a324ce7..80b0ef0 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -29,6 +29,7 @@ add_subdirectory(OpenMP)
add_subdirectory(PDL)
add_subdirectory(PDLInterp)
add_subdirectory(Polynomial)
+add_subdirectory(Ptr)
add_subdirectory(Quant)
add_subdirectory(SCF)
add_subdirectory(Shape)
diff --git a/mlir/lib/Dialect/Ptr/CMakeLists.txt b/mlir/lib/Dialect/Ptr/CMakeLists.txt
new file mode 100644
index 0000000..f33061b2
--- /dev/null
+++ b/mlir/lib/Dialect/Ptr/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
new file mode 100644
index 0000000..9cf3643
--- /dev/null
+++ b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_dialect_library(
+ MLIRPtrDialect
+ PtrAttrs.cpp
+ PtrTypes.cpp
+ PtrDialect.cpp
+
+ DEPENDS
+ MLIRPtrOpsAttributesIncGen
+ MLIRPtrOpsIncGen
+
+ LINK_LIBS
+ PUBLIC
+ MLIRIR
+ MLIRDataLayoutInterfaces
+ MLIRMemorySlotInterfaces
+)
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp b/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp
new file mode 100644
index 0000000..f8ce820
--- /dev/null
+++ b/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp
@@ -0,0 +1,40 @@
+//===- PtrAttrs.cpp - Pointer dialect attributes ----------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the Ptr dialect attributes.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Ptr/IR/PtrAttrs.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+using namespace mlir::ptr;
+
+constexpr const static unsigned kBitsInByte = 8;
+
+//===----------------------------------------------------------------------===//
+// SpecAttr
+//===----------------------------------------------------------------------===//
+
+LogicalResult SpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+ uint32_t size, uint32_t abi, uint32_t preferred,
+ uint32_t index) {
+ if (size % kBitsInByte != 0)
+ return emitError() << "size entry must be divisible by 8";
+ if (abi % kBitsInByte != 0)
+ return emitError() << "abi entry must be divisible by 8";
+ if (preferred % kBitsInByte != 0)
+ return emitError() << "preferred entry must be divisible by 8";
+ if (index != kOptionalSpecValue && index % kBitsInByte != 0)
+ return emitError() << "index entry must be divisible by 8";
+ if (abi > preferred)
+ return emitError() << "preferred alignment is expected to be at least "
+ "as large as ABI alignment";
+ return success();
+}
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
new file mode 100644
index 0000000..7830ffe
--- /dev/null
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -0,0 +1,55 @@
+//===- PtrDialect.cpp - Pointer dialect ---------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the Pointer dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Ptr/IR/PtrOps.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+using namespace mlir::ptr;
+
+//===----------------------------------------------------------------------===//
+// Pointer dialect
+//===----------------------------------------------------------------------===//
+
+void PtrDialect::initialize() {
+ addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/Ptr/IR/PtrOps.cpp.inc"
+ >();
+ addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.cpp.inc"
+ >();
+ addTypes<
+#define GET_TYPEDEF_LIST
+#include "mlir/Dialect/Ptr/IR/PtrOpsTypes.cpp.inc"
+ >();
+}
+
+//===----------------------------------------------------------------------===//
+// Pointer API.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Ptr/IR/PtrOpsDialect.cpp.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.cpp.inc"
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/Ptr/IR/PtrOpsTypes.cpp.inc"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Ptr/IR/PtrOps.cpp.inc"
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
new file mode 100644
index 0000000..2866d4e
--- /dev/null
+++ b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
@@ -0,0 +1,144 @@
+//===- PtrTypes.cpp - Pointer dialect types ---------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the Ptr dialect types.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Ptr/IR/PtrTypes.h"
+#include "mlir/Dialect/Ptr/IR/PtrAttrs.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+using namespace mlir::ptr;
+
+//===----------------------------------------------------------------------===//
+// Pointer type
+//===----------------------------------------------------------------------===//
+
+constexpr const static unsigned kDefaultPointerSizeBits = 64;
+constexpr const static unsigned kBitsInByte = 8;
+constexpr const static unsigned kDefaultPointerAlignment = 8;
+
+static Attribute getDefaultMemorySpace(PtrType ptr) { return nullptr; }
+
+/// Searches the data layout for the pointer spec, returns nullptr if it is not
+/// found.
+static SpecAttr getPointerSpec(DataLayoutEntryListRef params, PtrType type) {
+ for (DataLayoutEntryInterface entry : params) {
+ if (!entry.isTypeEntry())
+ continue;
+ if (cast<PtrType>(entry.getKey().get<Type>()).getMemorySpace() ==
+ type.getMemorySpace()) {
+ if (auto spec = dyn_cast<SpecAttr>(entry.getValue()))
+ return spec;
+ }
+ }
+ // If not found, and this is the pointer to the default memory space, assume
+ // 64-bit pointers.
+ if (type.getMemorySpace() == getDefaultMemorySpace(type))
+ return SpecAttr::get(type.getContext(), kDefaultPointerSizeBits,
+ kDefaultPointerAlignment, kDefaultPointerAlignment,
+ kDefaultPointerSizeBits);
+ return nullptr;
+}
+
+bool PtrType::areCompatible(DataLayoutEntryListRef oldLayout,
+ DataLayoutEntryListRef newLayout) const {
+ for (DataLayoutEntryInterface newEntry : newLayout) {
+ if (!newEntry.isTypeEntry())
+ continue;
+ uint32_t size = kDefaultPointerSizeBits;
+ uint32_t abi = kDefaultPointerAlignment;
+ auto newType = llvm::cast<PtrType>(newEntry.getKey().get<Type>());
+ const auto *it =
+ llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
+ if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
+ return llvm::cast<PtrType>(type).getMemorySpace() ==
+ newType.getMemorySpace();
+ }
+ return false;
+ });
+ if (it == oldLayout.end()) {
+ it = llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
+ if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
+ auto ptrTy = llvm::cast<PtrType>(type);
+ return ptrTy.getMemorySpace() == getDefaultMemorySpace(ptrTy);
+ }
+ return false;
+ });
+ }
+ if (it != oldLayout.end()) {
+ auto spec = llvm::cast<SpecAttr>(*it);
+ size = spec.getSize();
+ abi = spec.getAbi();
+ }
+
+ auto newSpec = llvm::cast<SpecAttr>(newEntry.getValue());
+ uint32_t newSize = newSpec.getSize();
+ uint32_t newAbi = newSpec.getAbi();
+ if (size != newSize || abi < newAbi || abi % newAbi != 0)
+ return false;
+ }
+ return true;
+}
+
+uint64_t PtrType::getABIAlignment(const DataLayout &dataLayout,
+ DataLayoutEntryListRef params) const {
+ if (SpecAttr spec = getPointerSpec(params, *this))
+ return spec.getAbi() / kBitsInByte;
+
+ return dataLayout.getTypeABIAlignment(
+ get(getContext(), getDefaultMemorySpace(*this)));
+}
+
+std::optional<uint64_t>
+PtrType::getIndexBitwidth(const DataLayout &dataLayout,
+ DataLayoutEntryListRef params) const {
+ if (SpecAttr spec = getPointerSpec(params, *this)) {
+ return spec.getIndex() == SpecAttr::kOptionalSpecValue ? spec.getSize()
+ : spec.getIndex();
+ }
+
+ return dataLayout.getTypeIndexBitwidth(
+ get(getContext(), getDefaultMemorySpace(*this)));
+}
+
+llvm::TypeSize PtrType::getTypeSizeInBits(const DataLayout &dataLayout,
+ DataLayoutEntryListRef params) const {
+ if (SpecAttr spec = getPointerSpec(params, *this))
+ return llvm::TypeSize::getFixed(spec.getSize());
+
+ // For other memory spaces, use the size of the pointer to the default memory
+ // space.
+ return dataLayout.getTypeSizeInBits(
+ get(getContext(), getDefaultMemorySpace(*this)));
+}
+
+uint64_t PtrType::getPreferredAlignment(const DataLayout &dataLayout,
+ DataLayoutEntryListRef params) const {
+ if (SpecAttr spec = getPointerSpec(params, *this))
+ return spec.getPreferred() / kBitsInByte;
+
+ return dataLayout.getTypePreferredAlignment(
+ get(getContext(), getDefaultMemorySpace(*this)));
+}
+
+LogicalResult PtrType::verifyEntries(DataLayoutEntryListRef entries,
+ Location loc) const {
+ for (DataLayoutEntryInterface entry : entries) {
+ if (!entry.isTypeEntry())
+ continue;
+ auto key = entry.getKey().get<Type>();
+ if (!llvm::isa<SpecAttr>(entry.getValue())) {
+ return emitError(loc) << "expected layout attribute for " << key
+ << " to be a #ptr.spec attribute";
+ }
+ }
+ return success();
+}