aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/MemRef/IR
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/MemRef/IR')
-rw-r--r--mlir/lib/Dialect/MemRef/IR/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp1
-rw-r--r--mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp19
-rw-r--r--mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp9
-rw-r--r--mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp25
5 files changed, 29 insertions, 26 deletions
diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
index 1382c7ac..d358362 100644
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -25,6 +25,7 @@ add_mlir_dialect_library(MLIRMemRefDialect
MLIRMemorySlotInterfaces
MLIRShapedOpInterfaces
MLIRSideEffectInterfaces
+ MLIRUBDialect
MLIRValueBoundsOpInterface
MLIRViewLikeInterface
)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp
index 6ff63df..a1e3f10 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp
@@ -10,6 +10,7 @@
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index dfa2e4e..5404238 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
@@ -61,15 +62,8 @@ static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape,
// Interfaces for AllocaOp
//===----------------------------------------------------------------------===//
-static bool isSupportedElementType(Type type) {
- return llvm::isa<MemRefType>(type) ||
- OpBuilder(type.getContext()).getZeroAttr(type);
-}
-
SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() {
MemRefType type = getType();
- if (!isSupportedElementType(type.getElementType()))
- return {};
if (!type.hasStaticShape())
return {};
// Make sure the memref contains only a single element.
@@ -81,16 +75,7 @@ SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() {
Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
OpBuilder &builder) {
- assert(isSupportedElementType(slot.elemType));
- // TODO: support more types.
- return TypeSwitch<Type, Value>(slot.elemType)
- .Case([&](MemRefType t) {
- return memref::AllocaOp::create(builder, getLoc(), t);
- })
- .Default([&](Type t) {
- return arith::ConstantOp::create(builder, getLoc(), t,
- builder.getZeroAttr(t));
- });
+ return ub::PoisonOp::create(builder, getLoc(), slot.elemType);
}
std::optional<PromotableAllocationOpInterface>
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index c551fba..1035d7c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -405,7 +405,7 @@ ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
void AllocaScopeOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
if (!point.isParent()) {
- regions.push_back(RegionSuccessor(getResults()));
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
return;
}
@@ -1074,13 +1074,6 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
return subview.getDynamicSize(sourceIndex);
}
- if (auto sizeInterface =
- dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
- assert(sizeInterface.isDynamicSize(unsignedIndex) &&
- "Expected dynamic subview size");
- return sizeInterface.getDynamicSize(unsignedIndex);
- }
-
// dim(memrefcast) -> dim
if (succeeded(foldMemRefCast(*this)))
return getResult();
diff --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
index a15bf89..69afbca 100644
--- a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -66,7 +66,7 @@ struct ExpandShapeOpInterface
ValueBoundsConstraintSet &cstr) const {
auto expandOp = cast<memref::ExpandShapeOp>(op);
assert(value == expandOp.getResult() && "invalid value");
- cstr.bound(value)[dim] == expandOp.getOutputShape()[dim];
+ cstr.bound(value)[dim] == expandOp.getMixedOutputShape()[dim];
}
};
@@ -98,6 +98,27 @@ struct RankOpInterface
}
};
+struct CollapseShapeOpInterface
+ : public ValueBoundsOpInterface::ExternalModel<CollapseShapeOpInterface,
+ memref::CollapseShapeOp> {
+ void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+ ValueBoundsConstraintSet &cstr) const {
+ auto collapseOp = cast<memref::CollapseShapeOp>(op);
+ assert(value == collapseOp.getResult() && "invalid value");
+
+ // Multiply the expressions for the dimensions in the reassociation group.
+ const ReassociationIndices reassocIndices =
+ collapseOp.getReassociationIndices()[dim];
+ AffineExpr productExpr =
+ cstr.getExpr(collapseOp.getSrc(), reassocIndices[0]);
+ for (size_t i = 1; i < reassocIndices.size(); ++i) {
+ productExpr =
+ productExpr * cstr.getExpr(collapseOp.getSrc(), reassocIndices[i]);
+ }
+ cstr.bound(value)[dim] == productExpr;
+ }
+};
+
struct SubViewOpInterface
: public ValueBoundsOpInterface::ExternalModel<SubViewOpInterface,
SubViewOp> {
@@ -134,6 +155,8 @@ void mlir::memref::registerValueBoundsOpInterfaceExternalModels(
memref::AllocOpInterface<memref::AllocaOp>>(*ctx);
memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx);
memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx);
+ memref::CollapseShapeOp::attachInterface<memref::CollapseShapeOpInterface>(
+ *ctx);
memref::ExpandShapeOp::attachInterface<memref::ExpandShapeOpInterface>(
*ctx);
memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx);