diff options
Diffstat (limited to 'mlir/lib/Dialect/MemRef/IR')
-rw-r--r-- | mlir/lib/Dialect/MemRef/IR/CMakeLists.txt | 3 | ||||
-rw-r--r-- | mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 59 |
2 files changed, 61 insertions, 1 deletions
diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt index e25a012..1382c7ac 100644 --- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt @@ -5,7 +5,7 @@ add_mlir_dialect_library(MLIRMemRefDialect ValueBoundsOpInterfaceImpl.cpp ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRefDialect + ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRef/IR DEPENDS MLIRMemRefOpsIncGen @@ -18,6 +18,7 @@ add_mlir_dialect_library(MLIRMemRefDialect MLIRDialectUtils MLIRInferIntRangeCommon MLIRInferIntRangeInterface + MLIRInferStridedMetadataInterface MLIRInferTypeOpInterface MLIRIR MLIRMemOpInterfaces diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index e9bdcda..507597b 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -3437,6 +3437,65 @@ SubViewOp::bubbleDownCasts(OpBuilder &builder) { return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable()); } +void SubViewOp::inferStridedMetadataRanges( + ArrayRef<StridedMetadataRange> ranges, GetIntRangeFn getIntRange, + SetStridedMetadataRangeFn setMetadata, int32_t indexBitwidth) { + auto isUninitialized = + +[](IntegerValueRange range) { return range.isUninitialized(); }; + + // Bail early if any of the operands metadata is not ready: + SmallVector<IntegerValueRange> offsetOperands = + getIntValueRanges(getMixedOffsets(), getIntRange, indexBitwidth); + if (llvm::any_of(offsetOperands, isUninitialized)) + return; + + SmallVector<IntegerValueRange> sizeOperands = + getIntValueRanges(getMixedSizes(), getIntRange, indexBitwidth); + if (llvm::any_of(sizeOperands, isUninitialized)) + return; + + SmallVector<IntegerValueRange> stridesOperands = + getIntValueRanges(getMixedStrides(), getIntRange, indexBitwidth); + if (llvm::any_of(stridesOperands, isUninitialized)) + return; + + StridedMetadataRange sourceRange = + ranges[getSourceMutable().getOperandNumber()]; + if (sourceRange.isUninitialized()) + return; + + ArrayRef<ConstantIntRanges> srcStrides = sourceRange.getStrides(); + + // Get the dropped dims. + llvm::SmallBitVector droppedDims = getDroppedDims(); + + // Compute the new offset, strides and sizes. + ConstantIntRanges offset = sourceRange.getOffsets()[0]; + SmallVector<ConstantIntRanges> strides, sizes; + + for (size_t i = 0, e = droppedDims.size(); i < e; ++i) { + bool dropped = droppedDims.test(i); + // Compute the new offset. + ConstantIntRanges off = + intrange::inferMul({offsetOperands[i].getValue(), srcStrides[i]}); + offset = intrange::inferAdd({offset, off}); + + // Skip dropped dimensions. + if (dropped) + continue; + // Multiply the strides. + strides.push_back( + intrange::inferMul({stridesOperands[i].getValue(), srcStrides[i]})); + // Get the sizes. + sizes.push_back(sizeOperands[i].getValue()); + } + + setMetadata(getResult(), + StridedMetadataRange::getRanked( + SmallVector<ConstantIntRanges>({std::move(offset)}), + std::move(sizes), std::move(strides))); +} + //===----------------------------------------------------------------------===// // TransposeOp //===----------------------------------------------------------------------===// |