diff options
Diffstat (limited to 'mlir/lib')
-rw-r--r-- | mlir/lib/Dialect/CMakeLists.txt | 1 | ||||
-rw-r--r-- | mlir/lib/Dialect/Ptr/CMakeLists.txt | 1 | ||||
-rw-r--r-- | mlir/lib/Dialect/Ptr/IR/CMakeLists.txt | 16 | ||||
-rw-r--r-- | mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp | 40 | ||||
-rw-r--r-- | mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp | 55 | ||||
-rw-r--r-- | mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp | 144 |
6 files changed, 257 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt index a324ce7..80b0ef0 100644 --- a/mlir/lib/Dialect/CMakeLists.txt +++ b/mlir/lib/Dialect/CMakeLists.txt @@ -29,6 +29,7 @@ add_subdirectory(OpenMP) add_subdirectory(PDL) add_subdirectory(PDLInterp) add_subdirectory(Polynomial) +add_subdirectory(Ptr) add_subdirectory(Quant) add_subdirectory(SCF) add_subdirectory(Shape) diff --git a/mlir/lib/Dialect/Ptr/CMakeLists.txt b/mlir/lib/Dialect/Ptr/CMakeLists.txt new file mode 100644 index 0000000..f33061b2 --- /dev/null +++ b/mlir/lib/Dialect/Ptr/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt new file mode 100644 index 0000000..9cf3643 --- /dev/null +++ b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_dialect_library( + MLIRPtrDialect + PtrAttrs.cpp + PtrTypes.cpp + PtrDialect.cpp + + DEPENDS + MLIRPtrOpsAttributesIncGen + MLIRPtrOpsIncGen + + LINK_LIBS + PUBLIC + MLIRIR + MLIRDataLayoutInterfaces + MLIRMemorySlotInterfaces +) diff --git a/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp b/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp new file mode 100644 index 0000000..f8ce820 --- /dev/null +++ b/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp @@ -0,0 +1,40 @@ +//===- PtrAttrs.cpp - Pointer dialect attributes ----------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the Ptr dialect attributes. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Ptr/IR/PtrAttrs.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::ptr; + +constexpr const static unsigned kBitsInByte = 8; + +//===----------------------------------------------------------------------===// +// SpecAttr +//===----------------------------------------------------------------------===// + +LogicalResult SpecAttr::verify(function_ref<InFlightDiagnostic()> emitError, + uint32_t size, uint32_t abi, uint32_t preferred, + uint32_t index) { + if (size % kBitsInByte != 0) + return emitError() << "size entry must be divisible by 8"; + if (abi % kBitsInByte != 0) + return emitError() << "abi entry must be divisible by 8"; + if (preferred % kBitsInByte != 0) + return emitError() << "preferred entry must be divisible by 8"; + if (index != kOptionalSpecValue && index % kBitsInByte != 0) + return emitError() << "index entry must be divisible by 8"; + if (abi > preferred) + return emitError() << "preferred alignment is expected to be at least " + "as large as ABI alignment"; + return success(); +} diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp new file mode 100644 index 0000000..7830ffe --- /dev/null +++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp @@ -0,0 +1,55 @@ +//===- PtrDialect.cpp - Pointer dialect ---------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the Pointer dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Ptr/IR/PtrOps.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::ptr; + +//===----------------------------------------------------------------------===// +// Pointer dialect +//===----------------------------------------------------------------------===// + +void PtrDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/Ptr/IR/PtrOps.cpp.inc" + >(); + addAttributes< +#define GET_ATTRDEF_LIST +#include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.cpp.inc" + >(); + addTypes< +#define GET_TYPEDEF_LIST +#include "mlir/Dialect/Ptr/IR/PtrOpsTypes.cpp.inc" + >(); +} + +//===----------------------------------------------------------------------===// +// Pointer API. +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Ptr/IR/PtrOpsDialect.cpp.inc" + +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.cpp.inc" + +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/Ptr/IR/PtrOpsTypes.cpp.inc" + +#define GET_OP_CLASSES +#include "mlir/Dialect/Ptr/IR/PtrOps.cpp.inc" diff --git a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp new file mode 100644 index 0000000..2866d4e --- /dev/null +++ b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp @@ -0,0 +1,144 @@ +//===- PtrTypes.cpp - Pointer dialect types ---------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the Ptr dialect types. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Ptr/IR/PtrTypes.h" +#include "mlir/Dialect/Ptr/IR/PtrAttrs.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::ptr; + +//===----------------------------------------------------------------------===// +// Pointer type +//===----------------------------------------------------------------------===// + +constexpr const static unsigned kDefaultPointerSizeBits = 64; +constexpr const static unsigned kBitsInByte = 8; +constexpr const static unsigned kDefaultPointerAlignment = 8; + +static Attribute getDefaultMemorySpace(PtrType ptr) { return nullptr; } + +/// Searches the data layout for the pointer spec, returns nullptr if it is not +/// found. +static SpecAttr getPointerSpec(DataLayoutEntryListRef params, PtrType type) { + for (DataLayoutEntryInterface entry : params) { + if (!entry.isTypeEntry()) + continue; + if (cast<PtrType>(entry.getKey().get<Type>()).getMemorySpace() == + type.getMemorySpace()) { + if (auto spec = dyn_cast<SpecAttr>(entry.getValue())) + return spec; + } + } + // If not found, and this is the pointer to the default memory space, assume + // 64-bit pointers. + if (type.getMemorySpace() == getDefaultMemorySpace(type)) + return SpecAttr::get(type.getContext(), kDefaultPointerSizeBits, + kDefaultPointerAlignment, kDefaultPointerAlignment, + kDefaultPointerSizeBits); + return nullptr; +} + +bool PtrType::areCompatible(DataLayoutEntryListRef oldLayout, + DataLayoutEntryListRef newLayout) const { + for (DataLayoutEntryInterface newEntry : newLayout) { + if (!newEntry.isTypeEntry()) + continue; + uint32_t size = kDefaultPointerSizeBits; + uint32_t abi = kDefaultPointerAlignment; + auto newType = llvm::cast<PtrType>(newEntry.getKey().get<Type>()); + const auto *it = + llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) { + if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) { + return llvm::cast<PtrType>(type).getMemorySpace() == + newType.getMemorySpace(); + } + return false; + }); + if (it == oldLayout.end()) { + it = llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) { + if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) { + auto ptrTy = llvm::cast<PtrType>(type); + return ptrTy.getMemorySpace() == getDefaultMemorySpace(ptrTy); + } + return false; + }); + } + if (it != oldLayout.end()) { + auto spec = llvm::cast<SpecAttr>(*it); + size = spec.getSize(); + abi = spec.getAbi(); + } + + auto newSpec = llvm::cast<SpecAttr>(newEntry.getValue()); + uint32_t newSize = newSpec.getSize(); + uint32_t newAbi = newSpec.getAbi(); + if (size != newSize || abi < newAbi || abi % newAbi != 0) + return false; + } + return true; +} + +uint64_t PtrType::getABIAlignment(const DataLayout &dataLayout, + DataLayoutEntryListRef params) const { + if (SpecAttr spec = getPointerSpec(params, *this)) + return spec.getAbi() / kBitsInByte; + + return dataLayout.getTypeABIAlignment( + get(getContext(), getDefaultMemorySpace(*this))); +} + +std::optional<uint64_t> +PtrType::getIndexBitwidth(const DataLayout &dataLayout, + DataLayoutEntryListRef params) const { + if (SpecAttr spec = getPointerSpec(params, *this)) { + return spec.getIndex() == SpecAttr::kOptionalSpecValue ? spec.getSize() + : spec.getIndex(); + } + + return dataLayout.getTypeIndexBitwidth( + get(getContext(), getDefaultMemorySpace(*this))); +} + +llvm::TypeSize PtrType::getTypeSizeInBits(const DataLayout &dataLayout, + DataLayoutEntryListRef params) const { + if (SpecAttr spec = getPointerSpec(params, *this)) + return llvm::TypeSize::getFixed(spec.getSize()); + + // For other memory spaces, use the size of the pointer to the default memory + // space. + return dataLayout.getTypeSizeInBits( + get(getContext(), getDefaultMemorySpace(*this))); +} + +uint64_t PtrType::getPreferredAlignment(const DataLayout &dataLayout, + DataLayoutEntryListRef params) const { + if (SpecAttr spec = getPointerSpec(params, *this)) + return spec.getPreferred() / kBitsInByte; + + return dataLayout.getTypePreferredAlignment( + get(getContext(), getDefaultMemorySpace(*this))); +} + +LogicalResult PtrType::verifyEntries(DataLayoutEntryListRef entries, + Location loc) const { + for (DataLayoutEntryInterface entry : entries) { + if (!entry.isTypeEntry()) + continue; + auto key = entry.getKey().get<Type>(); + if (!llvm::isa<SpecAttr>(entry.getValue())) { + return emitError(loc) << "expected layout attribute for " << key + << " to be a #ptr.spec attribute"; + } + } + return success(); +} |