From 0289ae51aa375fd297f1d03d27ff517223e5e998 Mon Sep 17 00:00:00 2001 From: Christian Ulmann Date: Fri, 22 Mar 2024 08:31:17 +0100 Subject: [MLIR][LLVM][SROA] Support incorrectly typed memory accesses (#85813) This commit relaxes the assumption of type consistency for LLVM dialect load and store operations in SROA. Instead, there is now a check that loads and stores are in the bounds specified by the sub-slot they access. This commit additionally removes the corresponding patterns from the type consistency pass, as they are no longer necessary. Note: It will be necessary to extend Mem2Reg with the logic for differently sized accesses as well. This is non-the-less a strict upgrade for productive flows, as the type consistency pass can produce invalid IR for some odd cases. --- mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 108 ++++++++++++++++++--- .../Dialect/LLVMIR/Transforms/TypeConsistency.cpp | 101 ------------------- mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp | 7 +- 3 files changed, 97 insertions(+), 119 deletions(-) (limited to 'mlir/lib') diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp index 00b4559..0ef1d10 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -13,10 +13,8 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" -#include "mlir/IR/IRMapping.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" -#include "mlir/IR/ValueRange.h" #include "mlir/Interfaces/DataLayoutInterfaces.h" #include "mlir/Interfaces/MemorySlotInterfaces.h" #include "llvm/ADT/STLExtras.h" @@ -71,12 +69,8 @@ SmallVector LLVM::AllocaOp::getDestructurableSlots() { if (!destructuredType) return {}; - DenseMap allocaTypeMap; - for (Attribute index : llvm::make_first_range(destructuredType.value())) - allocaTypeMap.insert({index, LLVM::LLVMPointerType::get(getContext())}); - - return { - DestructurableMemorySlot{{getResult(), getElemType()}, {allocaTypeMap}}}; + return {DestructurableMemorySlot{{getResult(), getElemType()}, + *destructuredType}}; } DenseMap @@ -182,17 +176,107 @@ DeletionKind LLVM::StoreOp::removeBlockingUses( return DeletionKind::Delete; } +/// Checks if `slot` can be accessed through the provided access type. +static bool isValidAccessType(const MemorySlot &slot, Type accessType, + const DataLayout &dataLayout) { + return dataLayout.getTypeSize(accessType) <= + dataLayout.getTypeSize(slot.elemType); +} + LogicalResult LLVM::LoadOp::ensureOnlySafeAccesses( const MemorySlot &slot, SmallVectorImpl &mustBeSafelyUsed, const DataLayout &dataLayout) { - return success(getAddr() != slot.ptr || getType() == slot.elemType); + return success(getAddr() != slot.ptr || + isValidAccessType(slot, getType(), dataLayout)); } LogicalResult LLVM::StoreOp::ensureOnlySafeAccesses( const MemorySlot &slot, SmallVectorImpl &mustBeSafelyUsed, const DataLayout &dataLayout) { return success(getAddr() != slot.ptr || - getValue().getType() == slot.elemType); + isValidAccessType(slot, getValue().getType(), dataLayout)); +} + +/// Returns the subslot's type at the requested index. +static Type getTypeAtIndex(const DestructurableMemorySlot &slot, + Attribute index) { + auto subelementIndexMap = + slot.elemType.cast().getSubelementIndexMap(); + if (!subelementIndexMap) + return {}; + assert(!subelementIndexMap->empty()); + + // Note: Returns a null-type when no entry was found. + return subelementIndexMap->lookup(index); +} + +bool LLVM::LoadOp::canRewire(const DestructurableMemorySlot &slot, + SmallPtrSetImpl &usedIndices, + SmallVectorImpl &mustBeSafelyUsed, + const DataLayout &dataLayout) { + if (getVolatile_()) + return false; + + // A load always accesses the first element of the destructured slot. + auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0); + Type subslotType = getTypeAtIndex(slot, index); + if (!subslotType) + return false; + + // The access can only be replaced when the subslot is read within its bounds. + if (dataLayout.getTypeSize(getType()) > dataLayout.getTypeSize(subslotType)) + return false; + + usedIndices.insert(index); + return true; +} + +DeletionKind LLVM::LoadOp::rewire(const DestructurableMemorySlot &slot, + DenseMap &subslots, + RewriterBase &rewriter, + const DataLayout &dataLayout) { + auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0); + auto it = subslots.find(index); + assert(it != subslots.end()); + + rewriter.modifyOpInPlace( + *this, [&]() { getAddrMutable().set(it->getSecond().ptr); }); + return DeletionKind::Keep; +} + +bool LLVM::StoreOp::canRewire(const DestructurableMemorySlot &slot, + SmallPtrSetImpl &usedIndices, + SmallVectorImpl &mustBeSafelyUsed, + const DataLayout &dataLayout) { + if (getVolatile_()) + return false; + + // A store always accesses the first element of the destructured slot. + auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0); + Type subslotType = getTypeAtIndex(slot, index); + if (!subslotType) + return false; + + // The access can only be replaced when the subslot is read within its bounds. + if (dataLayout.getTypeSize(getValue().getType()) > + dataLayout.getTypeSize(subslotType)) + return false; + + usedIndices.insert(index); + return true; +} + +DeletionKind LLVM::StoreOp::rewire(const DestructurableMemorySlot &slot, + DenseMap &subslots, + RewriterBase &rewriter, + const DataLayout &dataLayout) { + auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0); + auto it = subslots.find(index); + assert(it != subslots.end()); + + rewriter.modifyOpInPlace( + *this, [&]() { getAddrMutable().set(it->getSecond().ptr); }); + return DeletionKind::Keep; } //===----------------------------------------------------------------------===// @@ -390,10 +474,8 @@ bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot, auto firstLevelIndex = dyn_cast(getIndices()[1]); if (!firstLevelIndex) return false; - assert(slot.elementPtrs.contains(firstLevelIndex)); - if (!llvm::isa(slot.elementPtrs.at(firstLevelIndex))) - return false; mustBeSafelyUsed.emplace_back({getResult(), reachedType}); + assert(slot.elementPtrs.contains(firstLevelIndex)); usedIndices.insert(firstLevelIndex); return true; } diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp index b25c831..3d700fe 100644 --- a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp +++ b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp @@ -50,104 +50,6 @@ static bool areBitcastCompatible(DataLayout &layout, Type lhs, Type rhs) { } //===----------------------------------------------------------------------===// -// AddFieldGetterToStructDirectUse -//===----------------------------------------------------------------------===// - -/// Gets the type of the first subelement of `type` if `type` is destructurable, -/// nullptr otherwise. -static Type getFirstSubelementType(Type type) { - auto destructurable = dyn_cast(type); - if (!destructurable) - return nullptr; - - Type subelementType = destructurable.getTypeAtIndex( - IntegerAttr::get(IntegerType::get(type.getContext(), 32), 0)); - if (subelementType) - return subelementType; - - return nullptr; -} - -/// Extracts a pointer to the first field of an `elemType` from the address -/// pointer of the provided MemOp, and rewires the MemOp so it uses that pointer -/// instead. -template -static void insertFieldIndirection(MemOp op, PatternRewriter &rewriter, - Type elemType) { - PatternRewriter::InsertionGuard guard(rewriter); - - rewriter.setInsertionPointAfterValue(op.getAddr()); - SmallVector firstTypeIndices{0, 0}; - - Value properPtr = rewriter.create( - op->getLoc(), LLVM::LLVMPointerType::get(op.getContext()), elemType, - op.getAddr(), firstTypeIndices); - - rewriter.modifyOpInPlace(op, - [&]() { op.getAddrMutable().assign(properPtr); }); -} - -template <> -LogicalResult AddFieldGetterToStructDirectUse::matchAndRewrite( - LoadOp load, PatternRewriter &rewriter) const { - PatternRewriter::InsertionGuard guard(rewriter); - - Type inconsistentElementType = - isElementTypeInconsistent(load.getAddr(), load.getType()); - if (!inconsistentElementType) - return failure(); - Type firstType = getFirstSubelementType(inconsistentElementType); - if (!firstType) - return failure(); - DataLayout layout = DataLayout::closest(load); - if (!areBitcastCompatible(layout, firstType, load.getResult().getType())) - return failure(); - - insertFieldIndirection(load, rewriter, inconsistentElementType); - - // If the load does not use the first type but a type that can be casted from - // it, add a bitcast and change the load type. - if (firstType != load.getResult().getType()) { - rewriter.setInsertionPointAfterValue(load.getResult()); - BitcastOp bitcast = rewriter.create( - load->getLoc(), load.getResult().getType(), load.getResult()); - rewriter.modifyOpInPlace(load, - [&]() { load.getResult().setType(firstType); }); - rewriter.replaceAllUsesExcept(load.getResult(), bitcast.getResult(), - bitcast); - } - - return success(); -} - -template <> -LogicalResult AddFieldGetterToStructDirectUse::matchAndRewrite( - StoreOp store, PatternRewriter &rewriter) const { - PatternRewriter::InsertionGuard guard(rewriter); - - Type inconsistentElementType = - isElementTypeInconsistent(store.getAddr(), store.getValue().getType()); - if (!inconsistentElementType) - return failure(); - Type firstType = getFirstSubelementType(inconsistentElementType); - if (!firstType) - return failure(); - - DataLayout layout = DataLayout::closest(store); - // Check that the first field has the right type or can at least be bitcast - // to the right type. - if (!areBitcastCompatible(layout, firstType, store.getValue().getType())) - return failure(); - - insertFieldIndirection(store, rewriter, inconsistentElementType); - - rewriter.modifyOpInPlace( - store, [&]() { store.getValueMutable().assign(store.getValue()); }); - - return success(); -} - -//===----------------------------------------------------------------------===// // CanonicalizeAlignedGep //===----------------------------------------------------------------------===// @@ -684,9 +586,6 @@ struct LLVMTypeConsistencyPass : public LLVM::impl::LLVMTypeConsistencyBase { void runOnOperation() override { RewritePatternSet rewritePatterns(&getContext()); - rewritePatterns.add>(&getContext()); - rewritePatterns.add>( - &getContext()); rewritePatterns.add(&getContext()); rewritePatterns.add(&getContext(), maxVectorSplitSize); rewritePatterns.add(&getContext()); diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp index 7be4056..6c5250d 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp @@ -120,11 +120,8 @@ memref::AllocaOp::getDestructurableSlots() { if (!destructuredType) return {}; - DenseMap indexMap; - for (auto const &[index, type] : *destructuredType) - indexMap.insert({index, MemRefType::get({}, type)}); - - return {DestructurableMemorySlot{{getMemref(), memrefType}, indexMap}}; + return { + DestructurableMemorySlot{{getMemref(), memrefType}, *destructuredType}}; } DenseMap -- cgit v1.1