aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib
diff options
context:
space:
mode:
authorChristian Ulmann <christianulmann@gmail.com>2024-03-22 08:31:17 +0100
committerGitHub <noreply@github.com>2024-03-22 08:31:17 +0100
commit0289ae51aa375fd297f1d03d27ff517223e5e998 (patch)
tree606d7f4db5f899a392f97d4a8fa488b0a7c2411e /mlir/lib
parent90454a609894ab278a87be2b9f5c49714caba8df (diff)
downloadllvm-0289ae51aa375fd297f1d03d27ff517223e5e998.zip
llvm-0289ae51aa375fd297f1d03d27ff517223e5e998.tar.gz
llvm-0289ae51aa375fd297f1d03d27ff517223e5e998.tar.bz2
[MLIR][LLVM][SROA] Support incorrectly typed memory accesses (#85813)
This commit relaxes the assumption of type consistency for LLVM dialect load and store operations in SROA. Instead, there is now a check that loads and stores are in the bounds specified by the sub-slot they access. This commit additionally removes the corresponding patterns from the type consistency pass, as they are no longer necessary. Note: It will be necessary to extend Mem2Reg with the logic for differently sized accesses as well. This is non-the-less a strict upgrade for productive flows, as the type consistency pass can produce invalid IR for some odd cases.
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp108
-rw-r--r--mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp101
-rw-r--r--mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp7
3 files changed, 97 insertions, 119 deletions
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 00b4559..0ef1d10 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -13,10 +13,8 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
-#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "llvm/ADT/STLExtras.h"
@@ -71,12 +69,8 @@ SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {
if (!destructuredType)
return {};
- DenseMap<Attribute, Type> allocaTypeMap;
- for (Attribute index : llvm::make_first_range(destructuredType.value()))
- allocaTypeMap.insert({index, LLVM::LLVMPointerType::get(getContext())});
-
- return {
- DestructurableMemorySlot{{getResult(), getElemType()}, {allocaTypeMap}}};
+ return {DestructurableMemorySlot{{getResult(), getElemType()},
+ *destructuredType}};
}
DenseMap<Attribute, MemorySlot>
@@ -182,17 +176,107 @@ DeletionKind LLVM::StoreOp::removeBlockingUses(
return DeletionKind::Delete;
}
+/// Checks if `slot` can be accessed through the provided access type.
+static bool isValidAccessType(const MemorySlot &slot, Type accessType,
+ const DataLayout &dataLayout) {
+ return dataLayout.getTypeSize(accessType) <=
+ dataLayout.getTypeSize(slot.elemType);
+}
+
LogicalResult LLVM::LoadOp::ensureOnlySafeAccesses(
const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
const DataLayout &dataLayout) {
- return success(getAddr() != slot.ptr || getType() == slot.elemType);
+ return success(getAddr() != slot.ptr ||
+ isValidAccessType(slot, getType(), dataLayout));
}
LogicalResult LLVM::StoreOp::ensureOnlySafeAccesses(
const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
const DataLayout &dataLayout) {
return success(getAddr() != slot.ptr ||
- getValue().getType() == slot.elemType);
+ isValidAccessType(slot, getValue().getType(), dataLayout));
+}
+
+/// Returns the subslot's type at the requested index.
+static Type getTypeAtIndex(const DestructurableMemorySlot &slot,
+ Attribute index) {
+ auto subelementIndexMap =
+ slot.elemType.cast<DestructurableTypeInterface>().getSubelementIndexMap();
+ if (!subelementIndexMap)
+ return {};
+ assert(!subelementIndexMap->empty());
+
+ // Note: Returns a null-type when no entry was found.
+ return subelementIndexMap->lookup(index);
+}
+
+bool LLVM::LoadOp::canRewire(const DestructurableMemorySlot &slot,
+ SmallPtrSetImpl<Attribute> &usedIndices,
+ SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+ const DataLayout &dataLayout) {
+ if (getVolatile_())
+ return false;
+
+ // A load always accesses the first element of the destructured slot.
+ auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
+ Type subslotType = getTypeAtIndex(slot, index);
+ if (!subslotType)
+ return false;
+
+ // The access can only be replaced when the subslot is read within its bounds.
+ if (dataLayout.getTypeSize(getType()) > dataLayout.getTypeSize(subslotType))
+ return false;
+
+ usedIndices.insert(index);
+ return true;
+}
+
+DeletionKind LLVM::LoadOp::rewire(const DestructurableMemorySlot &slot,
+ DenseMap<Attribute, MemorySlot> &subslots,
+ RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
+ auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
+ auto it = subslots.find(index);
+ assert(it != subslots.end());
+
+ rewriter.modifyOpInPlace(
+ *this, [&]() { getAddrMutable().set(it->getSecond().ptr); });
+ return DeletionKind::Keep;
+}
+
+bool LLVM::StoreOp::canRewire(const DestructurableMemorySlot &slot,
+ SmallPtrSetImpl<Attribute> &usedIndices,
+ SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+ const DataLayout &dataLayout) {
+ if (getVolatile_())
+ return false;
+
+ // A store always accesses the first element of the destructured slot.
+ auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
+ Type subslotType = getTypeAtIndex(slot, index);
+ if (!subslotType)
+ return false;
+
+ // The access can only be replaced when the subslot is read within its bounds.
+ if (dataLayout.getTypeSize(getValue().getType()) >
+ dataLayout.getTypeSize(subslotType))
+ return false;
+
+ usedIndices.insert(index);
+ return true;
+}
+
+DeletionKind LLVM::StoreOp::rewire(const DestructurableMemorySlot &slot,
+ DenseMap<Attribute, MemorySlot> &subslots,
+ RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
+ auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
+ auto it = subslots.find(index);
+ assert(it != subslots.end());
+
+ rewriter.modifyOpInPlace(
+ *this, [&]() { getAddrMutable().set(it->getSecond().ptr); });
+ return DeletionKind::Keep;
}
//===----------------------------------------------------------------------===//
@@ -390,10 +474,8 @@ bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
auto firstLevelIndex = dyn_cast<IntegerAttr>(getIndices()[1]);
if (!firstLevelIndex)
return false;
- assert(slot.elementPtrs.contains(firstLevelIndex));
- if (!llvm::isa<LLVM::LLVMPointerType>(slot.elementPtrs.at(firstLevelIndex)))
- return false;
mustBeSafelyUsed.emplace_back<MemorySlot>({getResult(), reachedType});
+ assert(slot.elementPtrs.contains(firstLevelIndex));
usedIndices.insert(firstLevelIndex);
return true;
}
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
index b25c831..3d700fe 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
@@ -50,104 +50,6 @@ static bool areBitcastCompatible(DataLayout &layout, Type lhs, Type rhs) {
}
//===----------------------------------------------------------------------===//
-// AddFieldGetterToStructDirectUse
-//===----------------------------------------------------------------------===//
-
-/// Gets the type of the first subelement of `type` if `type` is destructurable,
-/// nullptr otherwise.
-static Type getFirstSubelementType(Type type) {
- auto destructurable = dyn_cast<DestructurableTypeInterface>(type);
- if (!destructurable)
- return nullptr;
-
- Type subelementType = destructurable.getTypeAtIndex(
- IntegerAttr::get(IntegerType::get(type.getContext(), 32), 0));
- if (subelementType)
- return subelementType;
-
- return nullptr;
-}
-
-/// Extracts a pointer to the first field of an `elemType` from the address
-/// pointer of the provided MemOp, and rewires the MemOp so it uses that pointer
-/// instead.
-template <class MemOp>
-static void insertFieldIndirection(MemOp op, PatternRewriter &rewriter,
- Type elemType) {
- PatternRewriter::InsertionGuard guard(rewriter);
-
- rewriter.setInsertionPointAfterValue(op.getAddr());
- SmallVector<GEPArg> firstTypeIndices{0, 0};
-
- Value properPtr = rewriter.create<GEPOp>(
- op->getLoc(), LLVM::LLVMPointerType::get(op.getContext()), elemType,
- op.getAddr(), firstTypeIndices);
-
- rewriter.modifyOpInPlace(op,
- [&]() { op.getAddrMutable().assign(properPtr); });
-}
-
-template <>
-LogicalResult AddFieldGetterToStructDirectUse<LoadOp>::matchAndRewrite(
- LoadOp load, PatternRewriter &rewriter) const {
- PatternRewriter::InsertionGuard guard(rewriter);
-
- Type inconsistentElementType =
- isElementTypeInconsistent(load.getAddr(), load.getType());
- if (!inconsistentElementType)
- return failure();
- Type firstType = getFirstSubelementType(inconsistentElementType);
- if (!firstType)
- return failure();
- DataLayout layout = DataLayout::closest(load);
- if (!areBitcastCompatible(layout, firstType, load.getResult().getType()))
- return failure();
-
- insertFieldIndirection<LoadOp>(load, rewriter, inconsistentElementType);
-
- // If the load does not use the first type but a type that can be casted from
- // it, add a bitcast and change the load type.
- if (firstType != load.getResult().getType()) {
- rewriter.setInsertionPointAfterValue(load.getResult());
- BitcastOp bitcast = rewriter.create<BitcastOp>(
- load->getLoc(), load.getResult().getType(), load.getResult());
- rewriter.modifyOpInPlace(load,
- [&]() { load.getResult().setType(firstType); });
- rewriter.replaceAllUsesExcept(load.getResult(), bitcast.getResult(),
- bitcast);
- }
-
- return success();
-}
-
-template <>
-LogicalResult AddFieldGetterToStructDirectUse<StoreOp>::matchAndRewrite(
- StoreOp store, PatternRewriter &rewriter) const {
- PatternRewriter::InsertionGuard guard(rewriter);
-
- Type inconsistentElementType =
- isElementTypeInconsistent(store.getAddr(), store.getValue().getType());
- if (!inconsistentElementType)
- return failure();
- Type firstType = getFirstSubelementType(inconsistentElementType);
- if (!firstType)
- return failure();
-
- DataLayout layout = DataLayout::closest(store);
- // Check that the first field has the right type or can at least be bitcast
- // to the right type.
- if (!areBitcastCompatible(layout, firstType, store.getValue().getType()))
- return failure();
-
- insertFieldIndirection<StoreOp>(store, rewriter, inconsistentElementType);
-
- rewriter.modifyOpInPlace(
- store, [&]() { store.getValueMutable().assign(store.getValue()); });
-
- return success();
-}
-
-//===----------------------------------------------------------------------===//
// CanonicalizeAlignedGep
//===----------------------------------------------------------------------===//
@@ -684,9 +586,6 @@ struct LLVMTypeConsistencyPass
: public LLVM::impl::LLVMTypeConsistencyBase<LLVMTypeConsistencyPass> {
void runOnOperation() override {
RewritePatternSet rewritePatterns(&getContext());
- rewritePatterns.add<AddFieldGetterToStructDirectUse<LoadOp>>(&getContext());
- rewritePatterns.add<AddFieldGetterToStructDirectUse<StoreOp>>(
- &getContext());
rewritePatterns.add<CanonicalizeAlignedGep>(&getContext());
rewritePatterns.add<SplitStores>(&getContext(), maxVectorSplitSize);
rewritePatterns.add<BitcastStores>(&getContext());
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index 7be4056..6c5250d 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -120,11 +120,8 @@ memref::AllocaOp::getDestructurableSlots() {
if (!destructuredType)
return {};
- DenseMap<Attribute, Type> indexMap;
- for (auto const &[index, type] : *destructuredType)
- indexMap.insert({index, MemRefType::get({}, type)});
-
- return {DestructurableMemorySlot{{getMemref(), memrefType}, indexMap}};
+ return {
+ DestructurableMemorySlot{{getMemref(), memrefType}, *destructuredType}};
}
DenseMap<Attribute, MemorySlot>