diff options
author | Fabian Mora <fmora.dev@gmail.com> | 2024-06-27 07:14:34 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-27 07:14:34 -0500 |
commit | e035ef0e7423c1a4c78e922508da817dbd5b6a02 (patch) | |
tree | 001989d2b7fec85d70fb173cf0fd653b11f85c17 /mlir/lib | |
parent | 0f5fa3558eb36823c16ba81a4c6e6e23a5f9df24 (diff) | |
download | llvm-e035ef0e7423c1a4c78e922508da817dbd5b6a02.zip llvm-e035ef0e7423c1a4c78e922508da817dbd5b6a02.tar.gz llvm-e035ef0e7423c1a4c78e922508da817dbd5b6a02.tar.bz2 |
[mlir][Ptr] Init the Ptr dialect with the `!ptr.ptr` type. (#86860)
This patch initializes the `ptr` dialect directories and base files,
adding the `!ptr.ptr` type and the `#ptr.spec<...>` data layout spec
attribute.
The `!ptr.ptr` type is an opaque pointer type optionally parameterized
by a memory space. This type typically represents a handle to an object
in memory or target-dependent values like `nullptr`.
The implementation of the `DataLayoutTypeInterface` interface for
`!ptr.ptr` was adapted from `!llvm.ptr`'s implementation. This
implementation uses the `#ptr.spec<...>` attribute for defining the data
layout specification.
See [[RFC] `ptr` dialect & modularizing ptr ops in the LLVM
dialect](https://discourse.llvm.org/t/rfc-ptr-dialect-modularizing-ptr-ops-in-the-llvm-dialect/75142)
for rationale and roadmap.
Diffstat (limited to 'mlir/lib')
-rw-r--r-- | mlir/lib/Dialect/CMakeLists.txt | 1 | ||||
-rw-r--r-- | mlir/lib/Dialect/Ptr/CMakeLists.txt | 1 | ||||
-rw-r--r-- | mlir/lib/Dialect/Ptr/IR/CMakeLists.txt | 16 | ||||
-rw-r--r-- | mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp | 40 | ||||
-rw-r--r-- | mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp | 55 | ||||
-rw-r--r-- | mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp | 144 |
6 files changed, 257 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt index a324ce7..80b0ef0 100644 --- a/mlir/lib/Dialect/CMakeLists.txt +++ b/mlir/lib/Dialect/CMakeLists.txt @@ -29,6 +29,7 @@ add_subdirectory(OpenMP) add_subdirectory(PDL) add_subdirectory(PDLInterp) add_subdirectory(Polynomial) +add_subdirectory(Ptr) add_subdirectory(Quant) add_subdirectory(SCF) add_subdirectory(Shape) diff --git a/mlir/lib/Dialect/Ptr/CMakeLists.txt b/mlir/lib/Dialect/Ptr/CMakeLists.txt new file mode 100644 index 0000000..f33061b2 --- /dev/null +++ b/mlir/lib/Dialect/Ptr/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt new file mode 100644 index 0000000..9cf3643 --- /dev/null +++ b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_dialect_library( + MLIRPtrDialect + PtrAttrs.cpp + PtrTypes.cpp + PtrDialect.cpp + + DEPENDS + MLIRPtrOpsAttributesIncGen + MLIRPtrOpsIncGen + + LINK_LIBS + PUBLIC + MLIRIR + MLIRDataLayoutInterfaces + MLIRMemorySlotInterfaces +) diff --git a/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp b/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp new file mode 100644 index 0000000..f8ce820 --- /dev/null +++ b/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp @@ -0,0 +1,40 @@ +//===- PtrAttrs.cpp - Pointer dialect attributes ----------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the Ptr dialect attributes. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Ptr/IR/PtrAttrs.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::ptr; + +constexpr const static unsigned kBitsInByte = 8; + +//===----------------------------------------------------------------------===// +// SpecAttr +//===----------------------------------------------------------------------===// + +LogicalResult SpecAttr::verify(function_ref<InFlightDiagnostic()> emitError, + uint32_t size, uint32_t abi, uint32_t preferred, + uint32_t index) { + if (size % kBitsInByte != 0) + return emitError() << "size entry must be divisible by 8"; + if (abi % kBitsInByte != 0) + return emitError() << "abi entry must be divisible by 8"; + if (preferred % kBitsInByte != 0) + return emitError() << "preferred entry must be divisible by 8"; + if (index != kOptionalSpecValue && index % kBitsInByte != 0) + return emitError() << "index entry must be divisible by 8"; + if (abi > preferred) + return emitError() << "preferred alignment is expected to be at least " + "as large as ABI alignment"; + return success(); +} diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp new file mode 100644 index 0000000..7830ffe --- /dev/null +++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp @@ -0,0 +1,55 @@ +//===- PtrDialect.cpp - Pointer dialect ---------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the Pointer dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Ptr/IR/PtrOps.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::ptr; + +//===----------------------------------------------------------------------===// +// Pointer dialect +//===----------------------------------------------------------------------===// + +void PtrDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/Ptr/IR/PtrOps.cpp.inc" + >(); + addAttributes< +#define GET_ATTRDEF_LIST +#include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.cpp.inc" + >(); + addTypes< +#define GET_TYPEDEF_LIST +#include "mlir/Dialect/Ptr/IR/PtrOpsTypes.cpp.inc" + >(); +} + +//===----------------------------------------------------------------------===// +// Pointer API. +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Ptr/IR/PtrOpsDialect.cpp.inc" + +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.cpp.inc" + +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/Ptr/IR/PtrOpsTypes.cpp.inc" + +#define GET_OP_CLASSES +#include "mlir/Dialect/Ptr/IR/PtrOps.cpp.inc" diff --git a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp new file mode 100644 index 0000000..2866d4e --- /dev/null +++ b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp @@ -0,0 +1,144 @@ +//===- PtrTypes.cpp - Pointer dialect types ---------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the Ptr dialect types. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Ptr/IR/PtrTypes.h" +#include "mlir/Dialect/Ptr/IR/PtrAttrs.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::ptr; + +//===----------------------------------------------------------------------===// +// Pointer type +//===----------------------------------------------------------------------===// + +constexpr const static unsigned kDefaultPointerSizeBits = 64; +constexpr const static unsigned kBitsInByte = 8; +constexpr const static unsigned kDefaultPointerAlignment = 8; + +static Attribute getDefaultMemorySpace(PtrType ptr) { return nullptr; } + +/// Searches the data layout for the pointer spec, returns nullptr if it is not +/// found. +static SpecAttr getPointerSpec(DataLayoutEntryListRef params, PtrType type) { + for (DataLayoutEntryInterface entry : params) { + if (!entry.isTypeEntry()) + continue; + if (cast<PtrType>(entry.getKey().get<Type>()).getMemorySpace() == + type.getMemorySpace()) { + if (auto spec = dyn_cast<SpecAttr>(entry.getValue())) + return spec; + } + } + // If not found, and this is the pointer to the default memory space, assume + // 64-bit pointers. + if (type.getMemorySpace() == getDefaultMemorySpace(type)) + return SpecAttr::get(type.getContext(), kDefaultPointerSizeBits, + kDefaultPointerAlignment, kDefaultPointerAlignment, + kDefaultPointerSizeBits); + return nullptr; +} + +bool PtrType::areCompatible(DataLayoutEntryListRef oldLayout, + DataLayoutEntryListRef newLayout) const { + for (DataLayoutEntryInterface newEntry : newLayout) { + if (!newEntry.isTypeEntry()) + continue; + uint32_t size = kDefaultPointerSizeBits; + uint32_t abi = kDefaultPointerAlignment; + auto newType = llvm::cast<PtrType>(newEntry.getKey().get<Type>()); + const auto *it = + llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) { + if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) { + return llvm::cast<PtrType>(type).getMemorySpace() == + newType.getMemorySpace(); + } + return false; + }); + if (it == oldLayout.end()) { + it = llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) { + if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) { + auto ptrTy = llvm::cast<PtrType>(type); + return ptrTy.getMemorySpace() == getDefaultMemorySpace(ptrTy); + } + return false; + }); + } + if (it != oldLayout.end()) { + auto spec = llvm::cast<SpecAttr>(*it); + size = spec.getSize(); + abi = spec.getAbi(); + } + + auto newSpec = llvm::cast<SpecAttr>(newEntry.getValue()); + uint32_t newSize = newSpec.getSize(); + uint32_t newAbi = newSpec.getAbi(); + if (size != newSize || abi < newAbi || abi % newAbi != 0) + return false; + } + return true; +} + +uint64_t PtrType::getABIAlignment(const DataLayout &dataLayout, + DataLayoutEntryListRef params) const { + if (SpecAttr spec = getPointerSpec(params, *this)) + return spec.getAbi() / kBitsInByte; + + return dataLayout.getTypeABIAlignment( + get(getContext(), getDefaultMemorySpace(*this))); +} + +std::optional<uint64_t> +PtrType::getIndexBitwidth(const DataLayout &dataLayout, + DataLayoutEntryListRef params) const { + if (SpecAttr spec = getPointerSpec(params, *this)) { + return spec.getIndex() == SpecAttr::kOptionalSpecValue ? spec.getSize() + : spec.getIndex(); + } + + return dataLayout.getTypeIndexBitwidth( + get(getContext(), getDefaultMemorySpace(*this))); +} + +llvm::TypeSize PtrType::getTypeSizeInBits(const DataLayout &dataLayout, + DataLayoutEntryListRef params) const { + if (SpecAttr spec = getPointerSpec(params, *this)) + return llvm::TypeSize::getFixed(spec.getSize()); + + // For other memory spaces, use the size of the pointer to the default memory + // space. + return dataLayout.getTypeSizeInBits( + get(getContext(), getDefaultMemorySpace(*this))); +} + +uint64_t PtrType::getPreferredAlignment(const DataLayout &dataLayout, + DataLayoutEntryListRef params) const { + if (SpecAttr spec = getPointerSpec(params, *this)) + return spec.getPreferred() / kBitsInByte; + + return dataLayout.getTypePreferredAlignment( + get(getContext(), getDefaultMemorySpace(*this))); +} + +LogicalResult PtrType::verifyEntries(DataLayoutEntryListRef entries, + Location loc) const { + for (DataLayoutEntryInterface entry : entries) { + if (!entry.isTypeEntry()) + continue; + auto key = entry.getKey().get<Type>(); + if (!llvm::isa<SpecAttr>(entry.getValue())) { + return emitError(loc) << "expected layout attribute for " << key + << " to be a #ptr.spec attribute"; + } + } + return success(); +} |