aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/MemRef/IR
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/MemRef/IR')
-rw-r--r--mlir/lib/Dialect/MemRef/IR/CMakeLists.txt3
-rw-r--r--mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp95
2 files changed, 96 insertions, 2 deletions
diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
index e25a012..1382c7ac 100644
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -5,7 +5,7 @@ add_mlir_dialect_library(MLIRMemRefDialect
ValueBoundsOpInterfaceImpl.cpp
ADDITIONAL_HEADER_DIRS
- ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRefDialect
+ ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRef/IR
DEPENDS
MLIRMemRefOpsIncGen
@@ -18,6 +18,7 @@ add_mlir_dialect_library(MLIRMemRefDialect
MLIRDialectUtils
MLIRInferIntRangeCommon
MLIRInferIntRangeInterface
+ MLIRInferStridedMetadataInterface
MLIRInferTypeOpInterface
MLIRIR
MLIRMemOpInterfaces
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index e9bdcda..94947b7 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2158,11 +2158,45 @@ public:
return success();
}
};
+
+struct ReinterpretCastOpConstantFolder
+ : public OpRewritePattern<ReinterpretCastOp> {
+public:
+ using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ReinterpretCastOp op,
+ PatternRewriter &rewriter) const override {
+ unsigned srcStaticCount = llvm::count_if(
+ llvm::concat<OpFoldResult>(op.getMixedOffsets(), op.getMixedSizes(),
+ op.getMixedStrides()),
+ [](OpFoldResult ofr) { return isa<Attribute>(ofr); });
+
+ SmallVector<OpFoldResult> offsets = {op.getConstifiedMixedOffset()};
+ SmallVector<OpFoldResult> sizes = op.getConstifiedMixedSizes();
+ SmallVector<OpFoldResult> strides = op.getConstifiedMixedStrides();
+
+ // TODO: Using counting comparison instead of direct comparison because
+ // getMixedValues (and therefore ReinterpretCastOp::getMixed...) returns
+ // IntegerAttrs, while constifyIndexValues (and therefore
+ // ReinterpretCastOp::getConstifiedMixed...) returns IndexAttrs.
+ if (srcStaticCount ==
+ llvm::count_if(llvm::concat<OpFoldResult>(offsets, sizes, strides),
+ [](OpFoldResult ofr) { return isa<Attribute>(ofr); }))
+ return failure();
+
+ auto newReinterpretCast = ReinterpretCastOp::create(
+ rewriter, op->getLoc(), op.getSource(), offsets[0], sizes, strides);
+
+ rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newReinterpretCast);
+ return success();
+ }
+};
} // namespace
void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
+ results.add<ReinterpretCastOpExtractStridedMetadataFolder,
+ ReinterpretCastOpConstantFolder>(context);
}
FailureOr<std::optional<SmallVector<Value>>>
@@ -3437,6 +3471,65 @@ SubViewOp::bubbleDownCasts(OpBuilder &builder) {
return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
}
+void SubViewOp::inferStridedMetadataRanges(
+ ArrayRef<StridedMetadataRange> ranges, GetIntRangeFn getIntRange,
+ SetStridedMetadataRangeFn setMetadata, int32_t indexBitwidth) {
+ auto isUninitialized =
+ +[](IntegerValueRange range) { return range.isUninitialized(); };
+
+ // Bail early if any of the operands metadata is not ready:
+ SmallVector<IntegerValueRange> offsetOperands =
+ getIntValueRanges(getMixedOffsets(), getIntRange, indexBitwidth);
+ if (llvm::any_of(offsetOperands, isUninitialized))
+ return;
+
+ SmallVector<IntegerValueRange> sizeOperands =
+ getIntValueRanges(getMixedSizes(), getIntRange, indexBitwidth);
+ if (llvm::any_of(sizeOperands, isUninitialized))
+ return;
+
+ SmallVector<IntegerValueRange> stridesOperands =
+ getIntValueRanges(getMixedStrides(), getIntRange, indexBitwidth);
+ if (llvm::any_of(stridesOperands, isUninitialized))
+ return;
+
+ StridedMetadataRange sourceRange =
+ ranges[getSourceMutable().getOperandNumber()];
+ if (sourceRange.isUninitialized())
+ return;
+
+ ArrayRef<ConstantIntRanges> srcStrides = sourceRange.getStrides();
+
+ // Get the dropped dims.
+ llvm::SmallBitVector droppedDims = getDroppedDims();
+
+ // Compute the new offset, strides and sizes.
+ ConstantIntRanges offset = sourceRange.getOffsets()[0];
+ SmallVector<ConstantIntRanges> strides, sizes;
+
+ for (size_t i = 0, e = droppedDims.size(); i < e; ++i) {
+ bool dropped = droppedDims.test(i);
+ // Compute the new offset.
+ ConstantIntRanges off =
+ intrange::inferMul({offsetOperands[i].getValue(), srcStrides[i]});
+ offset = intrange::inferAdd({offset, off});
+
+ // Skip dropped dimensions.
+ if (dropped)
+ continue;
+ // Multiply the strides.
+ strides.push_back(
+ intrange::inferMul({stridesOperands[i].getValue(), srcStrides[i]}));
+ // Get the sizes.
+ sizes.push_back(sizeOperands[i].getValue());
+ }
+
+ setMetadata(getResult(),
+ StridedMetadataRange::getRanked(
+ SmallVector<ConstantIntRanges>({std::move(offset)}),
+ std::move(sizes), std::move(strides)));
+}
+
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//