diff options
Diffstat (limited to 'mlir/lib/Dialect/MemRef')
-rw-r--r-- | mlir/lib/Dialect/MemRef/IR/CMakeLists.txt | 3 | ||||
-rw-r--r-- | mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 95 | ||||
-rw-r--r-- | mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp | 2 |
3 files changed, 97 insertions, 3 deletions
diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt index e25a012..1382c7ac 100644 --- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt @@ -5,7 +5,7 @@ add_mlir_dialect_library(MLIRMemRefDialect ValueBoundsOpInterfaceImpl.cpp ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRefDialect + ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRef/IR DEPENDS MLIRMemRefOpsIncGen @@ -18,6 +18,7 @@ add_mlir_dialect_library(MLIRMemRefDialect MLIRDialectUtils MLIRInferIntRangeCommon MLIRInferIntRangeInterface + MLIRInferStridedMetadataInterface MLIRInferTypeOpInterface MLIRIR MLIRMemOpInterfaces diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index e9bdcda..94947b7 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -2158,11 +2158,45 @@ public: return success(); } }; + +struct ReinterpretCastOpConstantFolder + : public OpRewritePattern<ReinterpretCastOp> { +public: + using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(ReinterpretCastOp op, + PatternRewriter &rewriter) const override { + unsigned srcStaticCount = llvm::count_if( + llvm::concat<OpFoldResult>(op.getMixedOffsets(), op.getMixedSizes(), + op.getMixedStrides()), + [](OpFoldResult ofr) { return isa<Attribute>(ofr); }); + + SmallVector<OpFoldResult> offsets = {op.getConstifiedMixedOffset()}; + SmallVector<OpFoldResult> sizes = op.getConstifiedMixedSizes(); + SmallVector<OpFoldResult> strides = op.getConstifiedMixedStrides(); + + // TODO: Using counting comparison instead of direct comparison because + // getMixedValues (and therefore ReinterpretCastOp::getMixed...) returns + // IntegerAttrs, while constifyIndexValues (and therefore + // ReinterpretCastOp::getConstifiedMixed...) returns IndexAttrs. + if (srcStaticCount == + llvm::count_if(llvm::concat<OpFoldResult>(offsets, sizes, strides), + [](OpFoldResult ofr) { return isa<Attribute>(ofr); })) + return failure(); + + auto newReinterpretCast = ReinterpretCastOp::create( + rewriter, op->getLoc(), op.getSource(), offsets[0], sizes, strides); + + rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newReinterpretCast); + return success(); + } +}; } // namespace void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context); + results.add<ReinterpretCastOpExtractStridedMetadataFolder, + ReinterpretCastOpConstantFolder>(context); } FailureOr<std::optional<SmallVector<Value>>> @@ -3437,6 +3471,65 @@ SubViewOp::bubbleDownCasts(OpBuilder &builder) { return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable()); } +void SubViewOp::inferStridedMetadataRanges( + ArrayRef<StridedMetadataRange> ranges, GetIntRangeFn getIntRange, + SetStridedMetadataRangeFn setMetadata, int32_t indexBitwidth) { + auto isUninitialized = + +[](IntegerValueRange range) { return range.isUninitialized(); }; + + // Bail early if any of the operands metadata is not ready: + SmallVector<IntegerValueRange> offsetOperands = + getIntValueRanges(getMixedOffsets(), getIntRange, indexBitwidth); + if (llvm::any_of(offsetOperands, isUninitialized)) + return; + + SmallVector<IntegerValueRange> sizeOperands = + getIntValueRanges(getMixedSizes(), getIntRange, indexBitwidth); + if (llvm::any_of(sizeOperands, isUninitialized)) + return; + + SmallVector<IntegerValueRange> stridesOperands = + getIntValueRanges(getMixedStrides(), getIntRange, indexBitwidth); + if (llvm::any_of(stridesOperands, isUninitialized)) + return; + + StridedMetadataRange sourceRange = + ranges[getSourceMutable().getOperandNumber()]; + if (sourceRange.isUninitialized()) + return; + + ArrayRef<ConstantIntRanges> srcStrides = sourceRange.getStrides(); + + // Get the dropped dims. + llvm::SmallBitVector droppedDims = getDroppedDims(); + + // Compute the new offset, strides and sizes. + ConstantIntRanges offset = sourceRange.getOffsets()[0]; + SmallVector<ConstantIntRanges> strides, sizes; + + for (size_t i = 0, e = droppedDims.size(); i < e; ++i) { + bool dropped = droppedDims.test(i); + // Compute the new offset. + ConstantIntRanges off = + intrange::inferMul({offsetOperands[i].getValue(), srcStrides[i]}); + offset = intrange::inferAdd({offset, off}); + + // Skip dropped dimensions. + if (dropped) + continue; + // Multiply the strides. + strides.push_back( + intrange::inferMul({stridesOperands[i].getValue(), srcStrides[i]})); + // Get the sizes. + sizes.push_back(sizeOperands[i].getValue()); + } + + setMetadata(getResult(), + StridedMetadataRange::getRanked( + SmallVector<ConstantIntRanges>({std::move(offset)}), + std::move(sizes), std::move(strides))); +} + //===----------------------------------------------------------------------===// // TransposeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp index 49b7162..6f815ae 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp @@ -121,7 +121,7 @@ struct EmulateWideIntPass final [&typeConverter](Operation *op) { return typeConverter.isLegal(op); }); RewritePatternSet patterns(ctx); - // Add common pattenrs to support contants, functions, etc. + // Add common patterns to support contants, functions, etc. arith::populateArithWideIntEmulationPatterns(typeConverter, patterns); memref::populateMemRefWideIntEmulationPatterns(typeConverter, patterns); |