aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp')
-rw-r--r--mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp11
1 files changed, 9 insertions, 2 deletions
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 94947b7..c551fba 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1437,6 +1437,13 @@ ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
atLeastOneReplacement |= replaceConstantUsesOf(
builder, getLoc(), getStrides(), getConstifiedMixedStrides());
+ // extract_strided_metadata(cast(x)) -> extract_strided_metadata(x).
+ if (auto prev = getSource().getDefiningOp<CastOp>())
+ if (isa<MemRefType>(prev.getSource().getType())) {
+ getSourceMutable().assign(prev.getSource());
+ atLeastOneReplacement = true;
+ }
+
return success(atLeastOneReplacement);
}
@@ -1744,11 +1751,11 @@ OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
}
TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getSourcePtr() {
- return cast<TypedValue<PtrLikeTypeInterface>>(getSource());
+ return getSource();
}
TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getTargetPtr() {
- return cast<TypedValue<PtrLikeTypeInterface>>(getDest());
+ return getDest();
}
bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,