diff options
Diffstat (limited to 'mlir/lib/Conversion')
-rw-r--r-- | mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 52 |
1 files changed, 49 insertions, 3 deletions
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index cc6314c..a6f816a 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -465,6 +465,51 @@ struct AssumeAlignmentOpLowering } }; +struct DistinctObjectsOpLowering + : public ConvertOpToLLVMPattern<memref::DistinctObjectsOp> { + using ConvertOpToLLVMPattern< + memref::DistinctObjectsOp>::ConvertOpToLLVMPattern; + explicit DistinctObjectsOpLowering(const LLVMTypeConverter &converter) + : ConvertOpToLLVMPattern<memref::DistinctObjectsOp>(converter) {} + + LogicalResult + matchAndRewrite(memref::DistinctObjectsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ValueRange operands = adaptor.getOperands(); + if (operands.size() <= 1) { + // Fast path. + rewriter.replaceOp(op, operands); + return success(); + } + + Location loc = op.getLoc(); + SmallVector<Value> ptrs; + for (auto [origOperand, newOperand] : + llvm::zip_equal(op.getOperands(), operands)) { + auto memrefType = cast<MemRefType>(origOperand.getType()); + MemRefDescriptor memRefDescriptor(newOperand); + Value ptr = memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(), + memrefType); + ptrs.push_back(ptr); + } + + auto cond = + LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(), 1); + // Generate separate_storage assumptions for each pair of pointers. + for (auto i : llvm::seq<size_t>(ptrs.size() - 1)) { + for (auto j : llvm::seq<size_t>(i + 1, ptrs.size())) { + Value ptr1 = ptrs[i]; + Value ptr2 = ptrs[j]; + LLVM::AssumeOp::create(rewriter, loc, cond, + LLVM::AssumeSeparateStorageTag{}, ptr1, ptr2); + } + } + + rewriter.replaceOp(op, operands); + return success(); + } +}; + // A `dealloc` is converted into a call to `free` on the underlying data buffer. // The memref descriptor being an SSA value, there is no need to clean it up // in any way. @@ -1997,22 +2042,23 @@ void mlir::populateFinalizeMemRefToLLVMConversionPatterns( patterns.add< AllocaOpLowering, AllocaScopeOpLowering, - AtomicRMWOpLowering, AssumeAlignmentOpLowering, + AtomicRMWOpLowering, ConvertExtractAlignedPointerAsIndex, DimOpLowering, + DistinctObjectsOpLowering, ExtractStridedMetadataOpLowering, GenericAtomicRMWOpLowering, GetGlobalMemrefOpLowering, LoadOpLowering, MemRefCastOpLowering, - MemorySpaceCastOpLowering, MemRefReinterpretCastOpLowering, MemRefReshapeOpLowering, + MemorySpaceCastOpLowering, PrefetchOpLowering, RankOpLowering, - ReassociatingReshapeOpConversion<memref::ExpandShapeOp>, ReassociatingReshapeOpConversion<memref::CollapseShapeOp>, + ReassociatingReshapeOpConversion<memref::ExpandShapeOp>, StoreOpLowering, SubViewOpLowering, TransposeOpLowering, |