aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib
diff options
context:
space:
mode:
authorChristian Ulmann <christianulmann@gmail.com>2024-03-20 14:21:53 +0100
committerGitHub <noreply@github.com>2024-03-20 14:21:53 +0100
commit98c6bc531d091215896087b94e4e047c67f892c2 (patch)
tree0b6e0adae180dac5b081b31195548c3075bfad33 /mlir/lib
parent3eb806373e3164b242db65f8c900e4adb5a2eddf (diff)
downloadllvm-98c6bc531d091215896087b94e4e047c67f892c2.zip
llvm-98c6bc531d091215896087b94e4e047c67f892c2.tar.gz
llvm-98c6bc531d091215896087b94e4e047c67f892c2.tar.bz2
[MLIR][SROA][Mem2Reg] Add data layout to interface methods (#85644)
This commit expends the Mem2Reg and SROA interface methods with passed in handles to a `DataLayout` structure. This is done to avoid superfluous retreiving of data layouts during each conversion of intrinsics. This change, additionally, enables subsequent changes to make the LLVM dialect implementation of these interfaces type agnostic.
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp153
-rw-r--r--mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp18
-rw-r--r--mlir/lib/Transforms/Mem2Reg.cpp24
-rw-r--r--mlir/lib/Transforms/SROA.cpp30
4 files changed, 140 insertions, 85 deletions
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 81766b3..00b4559 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -132,7 +132,8 @@ Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
bool LLVM::LoadOp::canUsesBeRemoved(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- SmallVectorImpl<OpOperand *> &newBlockingUses) {
+ SmallVectorImpl<OpOperand *> &newBlockingUses,
+ const DataLayout &dataLayout) {
if (blockingUses.size() != 1)
return false;
Value blockingUse = (*blockingUses.begin())->get();
@@ -155,7 +156,8 @@ DeletionKind LLVM::LoadOp::removeBlockingUses(
bool LLVM::StoreOp::canUsesBeRemoved(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- SmallVectorImpl<OpOperand *> &newBlockingUses) {
+ SmallVectorImpl<OpOperand *> &newBlockingUses,
+ const DataLayout &dataLayout) {
if (blockingUses.size() != 1)
return false;
Value blockingUse = (*blockingUses.begin())->get();
@@ -181,12 +183,14 @@ DeletionKind LLVM::StoreOp::removeBlockingUses(
}
LogicalResult LLVM::LoadOp::ensureOnlySafeAccesses(
- const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+ const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+ const DataLayout &dataLayout) {
return success(getAddr() != slot.ptr || getType() == slot.elemType);
}
LogicalResult LLVM::StoreOp::ensureOnlySafeAccesses(
- const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+ const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+ const DataLayout &dataLayout) {
return success(getAddr() != slot.ptr ||
getValue().getType() == slot.elemType);
}
@@ -206,7 +210,8 @@ static bool forwardToUsers(Operation *op,
bool LLVM::BitcastOp::canUsesBeRemoved(
const SmallPtrSetImpl<OpOperand *> &blockingUses,
- SmallVectorImpl<OpOperand *> &newBlockingUses) {
+ SmallVectorImpl<OpOperand *> &newBlockingUses,
+ const DataLayout &dataLayout) {
return forwardToUsers(*this, newBlockingUses);
}
@@ -217,7 +222,8 @@ DeletionKind LLVM::BitcastOp::removeBlockingUses(
bool LLVM::AddrSpaceCastOp::canUsesBeRemoved(
const SmallPtrSetImpl<OpOperand *> &blockingUses,
- SmallVectorImpl<OpOperand *> &newBlockingUses) {
+ SmallVectorImpl<OpOperand *> &newBlockingUses,
+ const DataLayout &dataLayout) {
return forwardToUsers(*this, newBlockingUses);
}
@@ -228,7 +234,8 @@ DeletionKind LLVM::AddrSpaceCastOp::removeBlockingUses(
bool LLVM::LifetimeStartOp::canUsesBeRemoved(
const SmallPtrSetImpl<OpOperand *> &blockingUses,
- SmallVectorImpl<OpOperand *> &newBlockingUses) {
+ SmallVectorImpl<OpOperand *> &newBlockingUses,
+ const DataLayout &dataLayout) {
return true;
}
@@ -239,7 +246,8 @@ DeletionKind LLVM::LifetimeStartOp::removeBlockingUses(
bool LLVM::LifetimeEndOp::canUsesBeRemoved(
const SmallPtrSetImpl<OpOperand *> &blockingUses,
- SmallVectorImpl<OpOperand *> &newBlockingUses) {
+ SmallVectorImpl<OpOperand *> &newBlockingUses,
+ const DataLayout &dataLayout) {
return true;
}
@@ -250,7 +258,8 @@ DeletionKind LLVM::LifetimeEndOp::removeBlockingUses(
bool LLVM::InvariantStartOp::canUsesBeRemoved(
const SmallPtrSetImpl<OpOperand *> &blockingUses,
- SmallVectorImpl<OpOperand *> &newBlockingUses) {
+ SmallVectorImpl<OpOperand *> &newBlockingUses,
+ const DataLayout &dataLayout) {
return true;
}
@@ -261,7 +270,8 @@ DeletionKind LLVM::InvariantStartOp::removeBlockingUses(
bool LLVM::InvariantEndOp::canUsesBeRemoved(
const SmallPtrSetImpl<OpOperand *> &blockingUses,
- SmallVectorImpl<OpOperand *> &newBlockingUses) {
+ SmallVectorImpl<OpOperand *> &newBlockingUses,
+ const DataLayout &dataLayout) {
return true;
}
@@ -272,7 +282,8 @@ DeletionKind LLVM::InvariantEndOp::removeBlockingUses(
bool LLVM::DbgDeclareOp::canUsesBeRemoved(
const SmallPtrSetImpl<OpOperand *> &blockingUses,
- SmallVectorImpl<OpOperand *> &newBlockingUses) {
+ SmallVectorImpl<OpOperand *> &newBlockingUses,
+ const DataLayout &dataLayout) {
return true;
}
@@ -283,7 +294,8 @@ DeletionKind LLVM::DbgDeclareOp::removeBlockingUses(
bool LLVM::DbgValueOp::canUsesBeRemoved(
const SmallPtrSetImpl<OpOperand *> &blockingUses,
- SmallVectorImpl<OpOperand *> &newBlockingUses) {
+ SmallVectorImpl<OpOperand *> &newBlockingUses,
+ const DataLayout &dataLayout) {
// There is only one operand that we can remove the use of.
if (blockingUses.size() != 1)
return false;
@@ -318,7 +330,8 @@ static bool hasAllZeroIndices(LLVM::GEPOp gepOp) {
bool LLVM::GEPOp::canUsesBeRemoved(
const SmallPtrSetImpl<OpOperand *> &blockingUses,
- SmallVectorImpl<OpOperand *> &newBlockingUses) {
+ SmallVectorImpl<OpOperand *> &newBlockingUses,
+ const DataLayout &dataLayout) {
// GEP can be removed as long as it is a no-op and its users can be removed.
if (!hasAllZeroIndices(*this))
return false;
@@ -337,7 +350,8 @@ static bool isFirstIndexZero(LLVM::GEPOp gep) {
}
LogicalResult LLVM::GEPOp::ensureOnlySafeAccesses(
- const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+ const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+ const DataLayout &dataLayout) {
if (getBase() != slot.ptr)
return success();
if (slot.elemType != getElemType())
@@ -357,9 +371,9 @@ LogicalResult LLVM::GEPOp::ensureOnlySafeAccesses(
bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
SmallPtrSetImpl<Attribute> &usedIndices,
- SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
- auto basePtrType = llvm::dyn_cast<LLVM::LLVMPointerType>(getBase().getType());
- if (!basePtrType)
+ SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+ const DataLayout &dataLayout) {
+ if (!isa<LLVM::LLVMPointerType>(getBase().getType()))
return false;
if (getBase() != slot.ptr || slot.elemType != getElemType())
@@ -386,7 +400,8 @@ bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot,
DenseMap<Attribute, MemorySlot> &subslots,
- RewriterBase &rewriter) {
+ RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
IntegerAttr firstLevelIndex =
llvm::dyn_cast_if_present<IntegerAttr>(getIndices()[1]);
const MemorySlot &newSlot = subslots.at(firstLevelIndex);
@@ -455,7 +470,7 @@ std::optional<uint64_t> getStaticMemIntrLen(LLVM::MemcpyInlineOp op) {
/// of the bounds of the given slot, on a best-effort basis.
template <class MemIntr>
static bool definitelyWritesOnlyWithinSlot(MemIntr op, const MemorySlot &slot,
- DataLayout &dataLayout) {
+ const DataLayout &dataLayout) {
if (!isa<LLVM::LLVMPointerType>(slot.ptr.getType()) ||
op.getDst() != slot.ptr)
return false;
@@ -520,7 +535,8 @@ Value LLVM::MemsetOp::getStored(const MemorySlot &slot,
bool LLVM::MemsetOp::canUsesBeRemoved(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- SmallVectorImpl<OpOperand *> &newBlockingUses) {
+ SmallVectorImpl<OpOperand *> &newBlockingUses,
+ const DataLayout &dataLayout) {
// TODO: Support non-integer types.
bool canConvertType =
TypeSwitch<Type, bool>(slot.elemType)
@@ -534,8 +550,7 @@ bool LLVM::MemsetOp::canUsesBeRemoved(
if (getIsVolatile())
return false;
- DataLayout layout = DataLayout::closest(*this);
- return getStaticMemIntrLen(*this) == layout.getTypeSize(slot.elemType);
+ return getStaticMemIntrLen(*this) == dataLayout.getTypeSize(slot.elemType);
}
DeletionKind LLVM::MemsetOp::removeBlockingUses(
@@ -545,14 +560,15 @@ DeletionKind LLVM::MemsetOp::removeBlockingUses(
}
LogicalResult LLVM::MemsetOp::ensureOnlySafeAccesses(
- const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
- DataLayout dataLayout = DataLayout::closest(*this);
+ const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+ const DataLayout &dataLayout) {
return success(definitelyWritesOnlyWithinSlot(*this, slot, dataLayout));
}
bool LLVM::MemsetOp::canRewire(const DestructurableMemorySlot &slot,
SmallPtrSetImpl<Attribute> &usedIndices,
- SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+ SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+ const DataLayout &dataLayout) {
if (&slot.elemType.getDialect() != getOperation()->getDialect())
return false;
@@ -566,13 +582,13 @@ bool LLVM::MemsetOp::canRewire(const DestructurableMemorySlot &slot,
if (!areAllIndicesI32(slot))
return false;
- DataLayout dataLayout = DataLayout::closest(*this);
return definitelyWritesOnlyWithinSlot(*this, slot, dataLayout);
}
DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
DenseMap<Attribute, MemorySlot> &subslots,
- RewriterBase &rewriter) {
+ RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
std::optional<DenseMap<Attribute, Type>> types =
slot.elemType.cast<DestructurableTypeInterface>().getSubelementIndexMap();
@@ -587,7 +603,6 @@ DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
packed = structType.isPacked();
Type i32 = IntegerType::get(getContext(), 32);
- DataLayout dataLayout = DataLayout::closest(*this);
uint64_t memsetLen = memsetLenAttr.getValue().getZExtValue();
uint64_t covered = 0;
for (size_t i = 0; i < types->size(); i++) {
@@ -650,7 +665,8 @@ template <class MemcpyLike>
static bool
memcpyCanUsesBeRemoved(MemcpyLike op, const MemorySlot &slot,
const SmallPtrSetImpl<OpOperand *> &blockingUses,
- SmallVectorImpl<OpOperand *> &newBlockingUses) {
+ SmallVectorImpl<OpOperand *> &newBlockingUses,
+ const DataLayout &dataLayout) {
// If source and destination are the same, memcpy behavior is undefined and
// memmove is a no-op. Because there is no memory change happening here,
// simplifying such operations is left to canonicalization.
@@ -660,8 +676,7 @@ memcpyCanUsesBeRemoved(MemcpyLike op, const MemorySlot &slot,
if (op.getIsVolatile())
return false;
- DataLayout layout = DataLayout::closest(op);
- return getStaticMemIntrLen(op) == layout.getTypeSize(slot.elemType);
+ return getStaticMemIntrLen(op) == dataLayout.getTypeSize(slot.elemType);
}
template <class MemcpyLike>
@@ -689,7 +704,8 @@ memcpyEnsureOnlySafeAccesses(MemcpyLike op, const MemorySlot &slot,
template <class MemcpyLike>
static bool memcpyCanRewire(MemcpyLike op, const DestructurableMemorySlot &slot,
SmallPtrSetImpl<Attribute> &usedIndices,
- SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+ SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+ const DataLayout &dataLayout) {
if (op.getIsVolatile())
return false;
@@ -701,7 +717,6 @@ static bool memcpyCanRewire(MemcpyLike op, const DestructurableMemorySlot &slot,
return false;
// Only full copies are supported.
- DataLayout dataLayout = DataLayout::closest(op);
if (getStaticMemIntrLen(op) != dataLayout.getTypeSize(slot.elemType))
return false;
@@ -741,15 +756,13 @@ void createMemcpyLikeToReplace(RewriterBase &rewriter, const DataLayout &layout,
/// Rewires a memcpy-like operation. Only copies to or from the full slot are
/// supported.
template <class MemcpyLike>
-static DeletionKind memcpyRewire(MemcpyLike op,
- const DestructurableMemorySlot &slot,
- DenseMap<Attribute, MemorySlot> &subslots,
- RewriterBase &rewriter) {
+static DeletionKind
+memcpyRewire(MemcpyLike op, const DestructurableMemorySlot &slot,
+ DenseMap<Attribute, MemorySlot> &subslots, RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
if (subslots.empty())
return DeletionKind::Delete;
- DataLayout layout = DataLayout::closest(op);
-
assert((slot.ptr == op.getDst()) != (slot.ptr == op.getSrc()));
bool isDst = slot.ptr == op.getDst();
@@ -780,7 +793,7 @@ static DeletionKind memcpyRewire(MemcpyLike op,
isDst ? op.getSrc() : op.getDst(), gepIndices);
// Then create a new memcpy out of this source pointer.
- createMemcpyLikeToReplace(rewriter, layout, op,
+ createMemcpyLikeToReplace(rewriter, dataLayout, op,
isDst ? subslot.ptr : subslotPtrInOther,
isDst ? subslotPtrInOther : subslot.ptr,
subslot.elemType, op.getIsVolatile());
@@ -806,8 +819,10 @@ Value LLVM::MemcpyOp::getStored(const MemorySlot &slot,
bool LLVM::MemcpyOp::canUsesBeRemoved(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- SmallVectorImpl<OpOperand *> &newBlockingUses) {
- return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses);
+ SmallVectorImpl<OpOperand *> &newBlockingUses,
+ const DataLayout &dataLayout) {
+ return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
+ dataLayout);
}
DeletionKind LLVM::MemcpyOp::removeBlockingUses(
@@ -818,20 +833,24 @@ DeletionKind LLVM::MemcpyOp::removeBlockingUses(
}
LogicalResult LLVM::MemcpyOp::ensureOnlySafeAccesses(
- const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+ const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+ const DataLayout &dataLayout) {
return memcpyEnsureOnlySafeAccesses(*this, slot, mustBeSafelyUsed);
}
bool LLVM::MemcpyOp::canRewire(const DestructurableMemorySlot &slot,
SmallPtrSetImpl<Attribute> &usedIndices,
- SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
- return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed);
+ SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+ const DataLayout &dataLayout) {
+ return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
+ dataLayout);
}
DeletionKind LLVM::MemcpyOp::rewire(const DestructurableMemorySlot &slot,
DenseMap<Attribute, MemorySlot> &subslots,
- RewriterBase &rewriter) {
- return memcpyRewire(*this, slot, subslots, rewriter);
+ RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
+ return memcpyRewire(*this, slot, subslots, rewriter, dataLayout);
}
bool LLVM::MemcpyInlineOp::loadsFrom(const MemorySlot &slot) {
@@ -849,8 +868,10 @@ Value LLVM::MemcpyInlineOp::getStored(const MemorySlot &slot,
bool LLVM::MemcpyInlineOp::canUsesBeRemoved(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- SmallVectorImpl<OpOperand *> &newBlockingUses) {
- return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses);
+ SmallVectorImpl<OpOperand *> &newBlockingUses,
+ const DataLayout &dataLayout) {
+ return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
+ dataLayout);
}
DeletionKind LLVM::MemcpyInlineOp::removeBlockingUses(
@@ -861,22 +882,26 @@ DeletionKind LLVM::MemcpyInlineOp::removeBlockingUses(
}
LogicalResult LLVM::MemcpyInlineOp::ensureOnlySafeAccesses(
- const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+ const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+ const DataLayout &dataLayout) {
return memcpyEnsureOnlySafeAccesses(*this, slot, mustBeSafelyUsed);
}
bool LLVM::MemcpyInlineOp::canRewire(
const DestructurableMemorySlot &slot,
SmallPtrSetImpl<Attribute> &usedIndices,
- SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
- return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed);
+ SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+ const DataLayout &dataLayout) {
+ return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
+ dataLayout);
}
DeletionKind
LLVM::MemcpyInlineOp::rewire(const DestructurableMemorySlot &slot,
DenseMap<Attribute, MemorySlot> &subslots,
- RewriterBase &rewriter) {
- return memcpyRewire(*this, slot, subslots, rewriter);
+ RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
+ return memcpyRewire(*this, slot, subslots, rewriter, dataLayout);
}
bool LLVM::MemmoveOp::loadsFrom(const MemorySlot &slot) {
@@ -894,8 +919,10 @@ Value LLVM::MemmoveOp::getStored(const MemorySlot &slot,
bool LLVM::MemmoveOp::canUsesBeRemoved(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- SmallVectorImpl<OpOperand *> &newBlockingUses) {
- return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses);
+ SmallVectorImpl<OpOperand *> &newBlockingUses,
+ const DataLayout &dataLayout) {
+ return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
+ dataLayout);
}
DeletionKind LLVM::MemmoveOp::removeBlockingUses(
@@ -906,20 +933,24 @@ DeletionKind LLVM::MemmoveOp::removeBlockingUses(
}
LogicalResult LLVM::MemmoveOp::ensureOnlySafeAccesses(
- const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+ const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+ const DataLayout &dataLayout) {
return memcpyEnsureOnlySafeAccesses(*this, slot, mustBeSafelyUsed);
}
bool LLVM::MemmoveOp::canRewire(const DestructurableMemorySlot &slot,
SmallPtrSetImpl<Attribute> &usedIndices,
- SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
- return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed);
+ SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+ const DataLayout &dataLayout) {
+ return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
+ dataLayout);
}
DeletionKind LLVM::MemmoveOp::rewire(const DestructurableMemorySlot &slot,
DenseMap<Attribute, MemorySlot> &subslots,
- RewriterBase &rewriter) {
- return memcpyRewire(*this, slot, subslots, rewriter);
+ RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
+ return memcpyRewire(*this, slot, subslots, rewriter, dataLayout);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index 561b861..7be4056 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -170,7 +170,8 @@ Value memref::LoadOp::getStored(const MemorySlot &slot,
bool memref::LoadOp::canUsesBeRemoved(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- SmallVectorImpl<OpOperand *> &newBlockingUses) {
+ SmallVectorImpl<OpOperand *> &newBlockingUses,
+ const DataLayout &dataLayout) {
if (blockingUses.size() != 1)
return false;
Value blockingUse = (*blockingUses.begin())->get();
@@ -210,7 +211,8 @@ static Attribute getAttributeIndexFromIndexOperands(MLIRContext *ctx,
bool memref::LoadOp::canRewire(const DestructurableMemorySlot &slot,
SmallPtrSetImpl<Attribute> &usedIndices,
- SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+ SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+ const DataLayout &dataLayout) {
if (slot.ptr != getMemRef())
return false;
Attribute index = getAttributeIndexFromIndexOperands(
@@ -223,7 +225,8 @@ bool memref::LoadOp::canRewire(const DestructurableMemorySlot &slot,
DeletionKind memref::LoadOp::rewire(const DestructurableMemorySlot &slot,
DenseMap<Attribute, MemorySlot> &subslots,
- RewriterBase &rewriter) {
+ RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
Attribute index = getAttributeIndexFromIndexOperands(
getContext(), getIndices(), getMemRefType());
const MemorySlot &memorySlot = subslots.at(index);
@@ -247,7 +250,8 @@ Value memref::StoreOp::getStored(const MemorySlot &slot,
bool memref::StoreOp::canUsesBeRemoved(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- SmallVectorImpl<OpOperand *> &newBlockingUses) {
+ SmallVectorImpl<OpOperand *> &newBlockingUses,
+ const DataLayout &dataLayout) {
if (blockingUses.size() != 1)
return false;
Value blockingUse = (*blockingUses.begin())->get();
@@ -263,7 +267,8 @@ DeletionKind memref::StoreOp::removeBlockingUses(
bool memref::StoreOp::canRewire(const DestructurableMemorySlot &slot,
SmallPtrSetImpl<Attribute> &usedIndices,
- SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+ SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+ const DataLayout &dataLayout) {
if (slot.ptr != getMemRef() || getValue() == slot.ptr)
return false;
Attribute index = getAttributeIndexFromIndexOperands(
@@ -276,7 +281,8 @@ bool memref::StoreOp::canRewire(const DestructurableMemorySlot &slot,
DeletionKind memref::StoreOp::rewire(const DestructurableMemorySlot &slot,
DenseMap<Attribute, MemorySlot> &subslots,
- RewriterBase &rewriter) {
+ RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
Attribute index = getAttributeIndexFromIndexOperands(
getContext(), getIndices(), getMemRefType());
const MemorySlot &memorySlot = subslots.at(index);
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 84ac69b..80e3b79 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/Mem2Reg.h"
+#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dominance.h"
@@ -117,8 +118,9 @@ struct MemorySlotPromotionInfo {
/// promotion. This does not mutate IR.
class MemorySlotPromotionAnalyzer {
public:
- MemorySlotPromotionAnalyzer(MemorySlot slot, DominanceInfo &dominance)
- : slot(slot), dominance(dominance) {}
+ MemorySlotPromotionAnalyzer(MemorySlot slot, DominanceInfo &dominance,
+ const DataLayout &dataLayout)
+ : slot(slot), dominance(dominance), dataLayout(dataLayout) {}
/// Computes the information for slot promotion if promotion is possible,
/// returns nothing otherwise.
@@ -153,6 +155,7 @@ private:
MemorySlot slot;
DominanceInfo &dominance;
+ const DataLayout &dataLayout;
};
/// The MemorySlotPromoter handles the state of promoting a memory slot. It
@@ -267,10 +270,12 @@ LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
// If the operation decides it cannot deal with removing the blocking uses,
// promotion must fail.
if (auto promotable = dyn_cast<PromotableOpInterface>(user)) {
- if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses))
+ if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses,
+ dataLayout))
return failure();
} else if (auto promotable = dyn_cast<PromotableMemOpInterface>(user)) {
- if (!promotable.canUsesBeRemoved(slot, blockingUses, newBlockingUses))
+ if (!promotable.canUsesBeRemoved(slot, blockingUses, newBlockingUses,
+ dataLayout))
return failure();
} else {
// An operation that has blocking uses must be promoted. If it is not
@@ -610,7 +615,8 @@ void MemorySlotPromoter::promoteSlot() {
LogicalResult mlir::tryToPromoteMemorySlots(
ArrayRef<PromotableAllocationOpInterface> allocators,
- RewriterBase &rewriter, Mem2RegStatistics statistics) {
+ RewriterBase &rewriter, const DataLayout &dataLayout,
+ Mem2RegStatistics statistics) {
bool promotedAny = false;
for (PromotableAllocationOpInterface allocator : allocators) {
@@ -619,7 +625,7 @@ LogicalResult mlir::tryToPromoteMemorySlots(
continue;
DominanceInfo dominance;
- MemorySlotPromotionAnalyzer analyzer(slot, dominance);
+ MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
if (info) {
MemorySlotPromoter(slot, allocator, rewriter, dominance,
@@ -661,8 +667,12 @@ struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
allocators.emplace_back(allocator);
});
+ auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
+ const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp);
+
// Attempt promoting until no promotion succeeds.
- if (failed(tryToPromoteMemorySlots(allocators, rewriter, statistics)))
+ if (failed(tryToPromoteMemorySlots(allocators, rewriter, dataLayout,
+ statistics)))
break;
changed = true;
diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp
index 6111489..f24cbb7 100644
--- a/mlir/lib/Transforms/SROA.cpp
+++ b/mlir/lib/Transforms/SROA.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/SROA.h"
+#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Transforms/Passes.h"
@@ -42,7 +43,8 @@ struct MemorySlotDestructuringInfo {
/// nothing if the slot cannot be destructured or if there is no useful work to
/// be done.
static std::optional<MemorySlotDestructuringInfo>
-computeDestructuringInfo(DestructurableMemorySlot &slot) {
+computeDestructuringInfo(DestructurableMemorySlot &slot,
+ const DataLayout &dataLayout) {
assert(isa<DestructurableTypeInterface>(slot.elemType));
if (slot.ptr.use_empty())
@@ -62,7 +64,8 @@ computeDestructuringInfo(DestructurableMemorySlot &slot) {
for (OpOperand &use : slot.ptr.getUses()) {
if (auto accessor =
dyn_cast<DestructurableAccessorOpInterface>(use.getOwner())) {
- if (accessor.canRewire(slot, info.usedIndices, usedSafelyWorklist)) {
+ if (accessor.canRewire(slot, info.usedIndices, usedSafelyWorklist,
+ dataLayout)) {
info.accessors.push_back(accessor);
continue;
}
@@ -82,8 +85,8 @@ computeDestructuringInfo(DestructurableMemorySlot &slot) {
Operation *subslotUser = subslotUse.getOwner();
if (auto memOp = dyn_cast<SafeMemorySlotAccessOpInterface>(subslotUser))
- if (succeeded(memOp.ensureOnlySafeAccesses(mustBeUsedSafely,
- usedSafelyWorklist)))
+ if (succeeded(memOp.ensureOnlySafeAccesses(
+ mustBeUsedSafely, usedSafelyWorklist, dataLayout)))
continue;
// If it cannot be shown that the operation uses the slot safely, maybe it
@@ -110,7 +113,7 @@ computeDestructuringInfo(DestructurableMemorySlot &slot) {
SmallVector<OpOperand *> newBlockingUses;
// If the operation decides it cannot deal with removing the blocking uses,
// destructuring must fail.
- if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses))
+ if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses, dataLayout))
return {};
// Then, register any new blocking uses for coming operations.
@@ -132,6 +135,7 @@ computeDestructuringInfo(DestructurableMemorySlot &slot) {
static void destructureSlot(DestructurableMemorySlot &slot,
DestructurableAllocationOpInterface allocator,
RewriterBase &rewriter,
+ const DataLayout &dataLayout,
MemorySlotDestructuringInfo &info,
const SROAStatistics &statistics) {
RewriterBase::InsertionGuard guard(rewriter);
@@ -158,7 +162,8 @@ static void destructureSlot(DestructurableMemorySlot &slot,
for (Operation *toRewire : llvm::reverse(usersToRewire)) {
rewriter.setInsertionPointAfter(toRewire);
if (auto accessor = dyn_cast<DestructurableAccessorOpInterface>(toRewire)) {
- if (accessor.rewire(slot, subslots, rewriter) == DeletionKind::Delete)
+ if (accessor.rewire(slot, subslots, rewriter, dataLayout) ==
+ DeletionKind::Delete)
toErase.push_back(accessor);
continue;
}
@@ -186,17 +191,18 @@ static void destructureSlot(DestructurableMemorySlot &slot,
LogicalResult mlir::tryToDestructureMemorySlots(
ArrayRef<DestructurableAllocationOpInterface> allocators,
- RewriterBase &rewriter, SROAStatistics statistics) {
+ RewriterBase &rewriter, const DataLayout &dataLayout,
+ SROAStatistics statistics) {
bool destructuredAny = false;
for (DestructurableAllocationOpInterface allocator : allocators) {
for (DestructurableMemorySlot slot : allocator.getDestructurableSlots()) {
std::optional<MemorySlotDestructuringInfo> info =
- computeDestructuringInfo(slot);
+ computeDestructuringInfo(slot, dataLayout);
if (!info)
continue;
- destructureSlot(slot, allocator, rewriter, *info, statistics);
+ destructureSlot(slot, allocator, rewriter, dataLayout, *info, statistics);
destructuredAny = true;
}
}
@@ -215,6 +221,8 @@ struct SROA : public impl::SROABase<SROA> {
SROAStatistics statistics{&destructuredAmount, &slotsWithMemoryBenefit,
&maxSubelementAmount};
+ auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
+ const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp);
bool changed = false;
for (Region &region : scopeOp->getRegions()) {
@@ -235,8 +243,8 @@ struct SROA : public impl::SROABase<SROA> {
allocators.emplace_back(allocator);
});
- if (failed(
- tryToDestructureMemorySlots(allocators, rewriter, statistics)))
+ if (failed(tryToDestructureMemorySlots(allocators, rewriter, dataLayout,
+ statistics)))
break;
changed = true;