diff options
Diffstat (limited to 'mlir/lib/Dialect/MemRef/IR')
| -rw-r--r-- | mlir/lib/Dialect/MemRef/IR/CMakeLists.txt | 1 | ||||
| -rw-r--r-- | mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp | 1 | ||||
| -rw-r--r-- | mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp | 19 | ||||
| -rw-r--r-- | mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 9 | ||||
| -rw-r--r-- | mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp | 25 |
5 files changed, 29 insertions, 26 deletions
diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt index 1382c7ac..d358362 100644 --- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt @@ -25,6 +25,7 @@ add_mlir_dialect_library(MLIRMemRefDialect MLIRMemorySlotInterfaces MLIRShapedOpInterfaces MLIRSideEffectInterfaces + MLIRUBDialect MLIRValueBoundsOpInterface MLIRViewLikeInterface ) diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp index 6ff63df..a1e3f10 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp @@ -10,6 +10,7 @@ #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Interfaces/MemorySlotInterfaces.h" #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h" diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp index dfa2e4e..5404238 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" @@ -61,15 +62,8 @@ static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape, // Interfaces for AllocaOp //===----------------------------------------------------------------------===// -static bool isSupportedElementType(Type type) { - return llvm::isa<MemRefType>(type) || - OpBuilder(type.getContext()).getZeroAttr(type); -} - SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() { MemRefType type = getType(); - if (!isSupportedElementType(type.getElementType())) - return {}; if (!type.hasStaticShape()) return {}; // Make sure the memref contains only a single element. @@ -81,16 +75,7 @@ SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() { Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot, OpBuilder &builder) { - assert(isSupportedElementType(slot.elemType)); - // TODO: support more types. - return TypeSwitch<Type, Value>(slot.elemType) - .Case([&](MemRefType t) { - return memref::AllocaOp::create(builder, getLoc(), t); - }) - .Default([&](Type t) { - return arith::ConstantOp::create(builder, getLoc(), t, - builder.getZeroAttr(t)); - }); + return ub::PoisonOp::create(builder, getLoc(), slot.elemType); } std::optional<PromotableAllocationOpInterface> diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index c551fba..1035d7c 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -405,7 +405,7 @@ ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) { void AllocaScopeOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { if (!point.isParent()) { - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); return; } @@ -1074,13 +1074,6 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) { return subview.getDynamicSize(sourceIndex); } - if (auto sizeInterface = - dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) { - assert(sizeInterface.isDynamicSize(unsignedIndex) && - "Expected dynamic subview size"); - return sizeInterface.getDynamicSize(unsignedIndex); - } - // dim(memrefcast) -> dim if (succeeded(foldMemRefCast(*this))) return getResult(); diff --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp index a15bf89..69afbca 100644 --- a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp @@ -66,7 +66,7 @@ struct ExpandShapeOpInterface ValueBoundsConstraintSet &cstr) const { auto expandOp = cast<memref::ExpandShapeOp>(op); assert(value == expandOp.getResult() && "invalid value"); - cstr.bound(value)[dim] == expandOp.getOutputShape()[dim]; + cstr.bound(value)[dim] == expandOp.getMixedOutputShape()[dim]; } }; @@ -98,6 +98,27 @@ struct RankOpInterface } }; +struct CollapseShapeOpInterface + : public ValueBoundsOpInterface::ExternalModel<CollapseShapeOpInterface, + memref::CollapseShapeOp> { + void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, + ValueBoundsConstraintSet &cstr) const { + auto collapseOp = cast<memref::CollapseShapeOp>(op); + assert(value == collapseOp.getResult() && "invalid value"); + + // Multiply the expressions for the dimensions in the reassociation group. + const ReassociationIndices reassocIndices = + collapseOp.getReassociationIndices()[dim]; + AffineExpr productExpr = + cstr.getExpr(collapseOp.getSrc(), reassocIndices[0]); + for (size_t i = 1; i < reassocIndices.size(); ++i) { + productExpr = + productExpr * cstr.getExpr(collapseOp.getSrc(), reassocIndices[i]); + } + cstr.bound(value)[dim] == productExpr; + } +}; + struct SubViewOpInterface : public ValueBoundsOpInterface::ExternalModel<SubViewOpInterface, SubViewOp> { @@ -134,6 +155,8 @@ void mlir::memref::registerValueBoundsOpInterfaceExternalModels( memref::AllocOpInterface<memref::AllocaOp>>(*ctx); memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx); memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx); + memref::CollapseShapeOp::attachInterface<memref::CollapseShapeOpInterface>( + *ctx); memref::ExpandShapeOp::attachInterface<memref::ExpandShapeOpInterface>( *ctx); memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx); |
