diff options
author | Christian Ulmann <christianulmann@gmail.com> | 2024-03-22 08:31:17 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-22 08:31:17 +0100 |
commit | 0289ae51aa375fd297f1d03d27ff517223e5e998 (patch) | |
tree | 606d7f4db5f899a392f97d4a8fa488b0a7c2411e /mlir/lib | |
parent | 90454a609894ab278a87be2b9f5c49714caba8df (diff) | |
download | llvm-0289ae51aa375fd297f1d03d27ff517223e5e998.zip llvm-0289ae51aa375fd297f1d03d27ff517223e5e998.tar.gz llvm-0289ae51aa375fd297f1d03d27ff517223e5e998.tar.bz2 |
[MLIR][LLVM][SROA] Support incorrectly typed memory accesses (#85813)
This commit relaxes the assumption of type consistency for LLVM dialect
load and store operations in SROA. Instead, there is now a check that
loads and stores are in the bounds specified by the sub-slot they
access.
This commit additionally removes the corresponding patterns from the
type consistency pass, as they are no longer necessary.
Note: It will be necessary to extend Mem2Reg with the logic for
differently sized accesses as well. This is nonetheless a strict
upgrade for productive flows, as the type consistency pass can produce
invalid IR for some odd cases.
Diffstat (limited to 'mlir/lib')
-rw-r--r-- | mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 108 | ||||
-rw-r--r-- | mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp | 101 | ||||
-rw-r--r-- | mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp | 7 |
3 files changed, 97 insertions, 119 deletions
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp index 00b4559..0ef1d10 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -13,10 +13,8 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" -#include "mlir/IR/IRMapping.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" -#include "mlir/IR/ValueRange.h" #include "mlir/Interfaces/DataLayoutInterfaces.h" #include "mlir/Interfaces/MemorySlotInterfaces.h" #include "llvm/ADT/STLExtras.h" @@ -71,12 +69,8 @@ SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() { if (!destructuredType) return {}; - DenseMap<Attribute, Type> allocaTypeMap; - for (Attribute index : llvm::make_first_range(destructuredType.value())) - allocaTypeMap.insert({index, LLVM::LLVMPointerType::get(getContext())}); - - return { - DestructurableMemorySlot{{getResult(), getElemType()}, {allocaTypeMap}}}; + return {DestructurableMemorySlot{{getResult(), getElemType()}, + *destructuredType}}; } DenseMap<Attribute, MemorySlot> @@ -182,17 +176,107 @@ DeletionKind LLVM::StoreOp::removeBlockingUses( return DeletionKind::Delete; } +/// Checks if `slot` can be accessed through the provided access type. 
+static bool isValidAccessType(const MemorySlot &slot, Type accessType, + const DataLayout &dataLayout) { + return dataLayout.getTypeSize(accessType) <= + dataLayout.getTypeSize(slot.elemType); +} + LogicalResult LLVM::LoadOp::ensureOnlySafeAccesses( const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed, const DataLayout &dataLayout) { - return success(getAddr() != slot.ptr || getType() == slot.elemType); + return success(getAddr() != slot.ptr || + isValidAccessType(slot, getType(), dataLayout)); } LogicalResult LLVM::StoreOp::ensureOnlySafeAccesses( const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed, const DataLayout &dataLayout) { return success(getAddr() != slot.ptr || - getValue().getType() == slot.elemType); + isValidAccessType(slot, getValue().getType(), dataLayout)); +} + +/// Returns the subslot's type at the requested index. +static Type getTypeAtIndex(const DestructurableMemorySlot &slot, + Attribute index) { + auto subelementIndexMap = + slot.elemType.cast<DestructurableTypeInterface>().getSubelementIndexMap(); + if (!subelementIndexMap) + return {}; + assert(!subelementIndexMap->empty()); + + // Note: Returns a null-type when no entry was found. + return subelementIndexMap->lookup(index); +} + +bool LLVM::LoadOp::canRewire(const DestructurableMemorySlot &slot, + SmallPtrSetImpl<Attribute> &usedIndices, + SmallVectorImpl<MemorySlot> &mustBeSafelyUsed, + const DataLayout &dataLayout) { + if (getVolatile_()) + return false; + + // A load always accesses the first element of the destructured slot. + auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0); + Type subslotType = getTypeAtIndex(slot, index); + if (!subslotType) + return false; + + // The access can only be replaced when the subslot is read within its bounds. 
+ if (dataLayout.getTypeSize(getType()) > dataLayout.getTypeSize(subslotType)) + return false; + + usedIndices.insert(index); + return true; +} + +DeletionKind LLVM::LoadOp::rewire(const DestructurableMemorySlot &slot, + DenseMap<Attribute, MemorySlot> &subslots, + RewriterBase &rewriter, + const DataLayout &dataLayout) { + auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0); + auto it = subslots.find(index); + assert(it != subslots.end()); + + rewriter.modifyOpInPlace( + *this, [&]() { getAddrMutable().set(it->getSecond().ptr); }); + return DeletionKind::Keep; +} + +bool LLVM::StoreOp::canRewire(const DestructurableMemorySlot &slot, + SmallPtrSetImpl<Attribute> &usedIndices, + SmallVectorImpl<MemorySlot> &mustBeSafelyUsed, + const DataLayout &dataLayout) { + if (getVolatile_()) + return false; + + // A store always accesses the first element of the destructured slot. + auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0); + Type subslotType = getTypeAtIndex(slot, index); + if (!subslotType) + return false; + + // The access can only be replaced when the subslot is read within its bounds. 
+ if (dataLayout.getTypeSize(getValue().getType()) > + dataLayout.getTypeSize(subslotType)) + return false; + + usedIndices.insert(index); + return true; +} + +DeletionKind LLVM::StoreOp::rewire(const DestructurableMemorySlot &slot, + DenseMap<Attribute, MemorySlot> &subslots, + RewriterBase &rewriter, + const DataLayout &dataLayout) { + auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0); + auto it = subslots.find(index); + assert(it != subslots.end()); + + rewriter.modifyOpInPlace( + *this, [&]() { getAddrMutable().set(it->getSecond().ptr); }); + return DeletionKind::Keep; } //===----------------------------------------------------------------------===// @@ -390,10 +474,8 @@ bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot, auto firstLevelIndex = dyn_cast<IntegerAttr>(getIndices()[1]); if (!firstLevelIndex) return false; - assert(slot.elementPtrs.contains(firstLevelIndex)); - if (!llvm::isa<LLVM::LLVMPointerType>(slot.elementPtrs.at(firstLevelIndex))) - return false; mustBeSafelyUsed.emplace_back<MemorySlot>({getResult(), reachedType}); + assert(slot.elementPtrs.contains(firstLevelIndex)); usedIndices.insert(firstLevelIndex); return true; } diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp index b25c831..3d700fe 100644 --- a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp +++ b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp @@ -50,104 +50,6 @@ static bool areBitcastCompatible(DataLayout &layout, Type lhs, Type rhs) { } //===----------------------------------------------------------------------===// -// AddFieldGetterToStructDirectUse -//===----------------------------------------------------------------------===// - -/// Gets the type of the first subelement of `type` if `type` is destructurable, -/// nullptr otherwise. 
-static Type getFirstSubelementType(Type type) { - auto destructurable = dyn_cast<DestructurableTypeInterface>(type); - if (!destructurable) - return nullptr; - - Type subelementType = destructurable.getTypeAtIndex( - IntegerAttr::get(IntegerType::get(type.getContext(), 32), 0)); - if (subelementType) - return subelementType; - - return nullptr; -} - -/// Extracts a pointer to the first field of an `elemType` from the address -/// pointer of the provided MemOp, and rewires the MemOp so it uses that pointer -/// instead. -template <class MemOp> -static void insertFieldIndirection(MemOp op, PatternRewriter &rewriter, - Type elemType) { - PatternRewriter::InsertionGuard guard(rewriter); - - rewriter.setInsertionPointAfterValue(op.getAddr()); - SmallVector<GEPArg> firstTypeIndices{0, 0}; - - Value properPtr = rewriter.create<GEPOp>( - op->getLoc(), LLVM::LLVMPointerType::get(op.getContext()), elemType, - op.getAddr(), firstTypeIndices); - - rewriter.modifyOpInPlace(op, - [&]() { op.getAddrMutable().assign(properPtr); }); -} - -template <> -LogicalResult AddFieldGetterToStructDirectUse<LoadOp>::matchAndRewrite( - LoadOp load, PatternRewriter &rewriter) const { - PatternRewriter::InsertionGuard guard(rewriter); - - Type inconsistentElementType = - isElementTypeInconsistent(load.getAddr(), load.getType()); - if (!inconsistentElementType) - return failure(); - Type firstType = getFirstSubelementType(inconsistentElementType); - if (!firstType) - return failure(); - DataLayout layout = DataLayout::closest(load); - if (!areBitcastCompatible(layout, firstType, load.getResult().getType())) - return failure(); - - insertFieldIndirection<LoadOp>(load, rewriter, inconsistentElementType); - - // If the load does not use the first type but a type that can be casted from - // it, add a bitcast and change the load type. 
- if (firstType != load.getResult().getType()) { - rewriter.setInsertionPointAfterValue(load.getResult()); - BitcastOp bitcast = rewriter.create<BitcastOp>( - load->getLoc(), load.getResult().getType(), load.getResult()); - rewriter.modifyOpInPlace(load, - [&]() { load.getResult().setType(firstType); }); - rewriter.replaceAllUsesExcept(load.getResult(), bitcast.getResult(), - bitcast); - } - - return success(); -} - -template <> -LogicalResult AddFieldGetterToStructDirectUse<StoreOp>::matchAndRewrite( - StoreOp store, PatternRewriter &rewriter) const { - PatternRewriter::InsertionGuard guard(rewriter); - - Type inconsistentElementType = - isElementTypeInconsistent(store.getAddr(), store.getValue().getType()); - if (!inconsistentElementType) - return failure(); - Type firstType = getFirstSubelementType(inconsistentElementType); - if (!firstType) - return failure(); - - DataLayout layout = DataLayout::closest(store); - // Check that the first field has the right type or can at least be bitcast - // to the right type. 
- if (!areBitcastCompatible(layout, firstType, store.getValue().getType())) - return failure(); - - insertFieldIndirection<StoreOp>(store, rewriter, inconsistentElementType); - - rewriter.modifyOpInPlace( - store, [&]() { store.getValueMutable().assign(store.getValue()); }); - - return success(); -} - -//===----------------------------------------------------------------------===// // CanonicalizeAlignedGep //===----------------------------------------------------------------------===// @@ -684,9 +586,6 @@ struct LLVMTypeConsistencyPass : public LLVM::impl::LLVMTypeConsistencyBase<LLVMTypeConsistencyPass> { void runOnOperation() override { RewritePatternSet rewritePatterns(&getContext()); - rewritePatterns.add<AddFieldGetterToStructDirectUse<LoadOp>>(&getContext()); - rewritePatterns.add<AddFieldGetterToStructDirectUse<StoreOp>>( - &getContext()); rewritePatterns.add<CanonicalizeAlignedGep>(&getContext()); rewritePatterns.add<SplitStores>(&getContext(), maxVectorSplitSize); rewritePatterns.add<BitcastStores>(&getContext()); diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp index 7be4056..6c5250d 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp @@ -120,11 +120,8 @@ memref::AllocaOp::getDestructurableSlots() { if (!destructuredType) return {}; - DenseMap<Attribute, Type> indexMap; - for (auto const &[index, type] : *destructuredType) - indexMap.insert({index, MemRefType::get({}, type)}); - - return {DestructurableMemorySlot{{getMemref(), memrefType}, indexMap}}; + return { + DestructurableMemorySlot{{getMemref(), memrefType}, *destructuredType}}; } DenseMap<Attribute, MemorySlot> |