//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/MathExtras.h"

#include <optional>

#define DEBUG_TYPE "memref-to-llvm"
#define DBGS() llvm::dbgs() << "[" DEBUG_TYPE "] "

namespace mlir {
#define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

static constexpr LLVM::GEPNoWrapFlags kNoWrapFlags =
    LLVM::GEPNoWrapFlags::inbounds | LLVM::GEPNoWrapFlags::nuw;

namespace {

static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
  return ShapedType::isStatic(strideOrOffset);
}

static FailureOr<LLVM::LLVMFuncOp>
getFreeFn(OpBuilder &b, const LLVMTypeConverter *typeConverter, ModuleOp module,
          SymbolTableCollection *symbolTables) {
  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;

  if (useGenericFn)
    return LLVM::lookupOrCreateGenericFreeFn(b, module, symbolTables);

  return LLVM::lookupOrCreateFreeFn(b, module, symbolTables);
}

static FailureOr<LLVM::LLVMFuncOp>
getNotalignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
                     Operation *module, Type indexType,
                     SymbolTableCollection *symbolTables) {
  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
  if (useGenericFn)
    return LLVM::lookupOrCreateGenericAllocFn(b, module, indexType,
                                              symbolTables);

  return LLVM::lookupOrCreateMallocFn(b, module, indexType, symbolTables);
}

static FailureOr<LLVM::LLVMFuncOp>
getAlignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
                  Operation *module, Type indexType,
                  SymbolTableCollection *symbolTables) {
  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;

  if (useGenericFn)
    return LLVM::lookupOrCreateGenericAlignedAllocFn(b, module, indexType,
                                                     symbolTables);

  return LLVM::lookupOrCreateAlignedAllocFn(b, module, indexType,
                                            symbolTables);
}

/// Computes the aligned value for 'input' as follows:
///   bumped = input + alignment - 1
///   aligned = bumped - bumped % alignment
static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
                           Value input, Value alignment) {
  Value one = LLVM::ConstantOp::create(rewriter, loc, alignment.getType(),
                                       rewriter.getIndexAttr(1));
  Value bump = LLVM::SubOp::create(rewriter, loc, alignment, one);
  Value bumped = LLVM::AddOp::create(rewriter, loc, input, bump);
  Value mod = LLVM::URemOp::create(rewriter, loc, bumped, alignment);
  return LLVM::SubOp::create(rewriter, loc, bumped, mod);
}
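
// For illustration (not part of the lowering itself): with input = 30 and
// alignment = 16, bumped = 30 + 15 = 45, bumped % 16 = 13, and the result is
// 45 - 13 = 32, i.e. the smallest multiple of 16 that is >= 30.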

/// Computes the byte size for the MemRef element type.
static unsigned getMemRefEltSizeInBytes(const LLVMTypeConverter *typeConverter,
                                        MemRefType memRefType, Operation *op,
                                        const DataLayout *defaultLayout) {
  const DataLayout *layout = defaultLayout;
  if (const DataLayoutAnalysis *analysis =
          typeConverter->getDataLayoutAnalysis()) {
    layout = &analysis->getAbove(op);
  }
  Type elementType = memRefType.getElementType();
  if (auto memRefElementType = dyn_cast<MemRefType>(elementType))
    return typeConverter->getMemRefDescriptorSize(memRefElementType, *layout);
  if (auto memRefElementType = dyn_cast<UnrankedMemRefType>(elementType))
    return typeConverter->getUnrankedMemRefDescriptorSize(memRefElementType,
                                                          *layout);
  return layout->getTypeSize(elementType);
}

static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
                                 Location loc, Value allocatedPtr,
                                 MemRefType memRefType, Type elementPtrType,
                                 const LLVMTypeConverter &typeConverter) {
  auto allocatedPtrTy = cast<LLVM::LLVMPointerType>(allocatedPtr.getType());
  FailureOr<unsigned> maybeMemrefAddrSpace =
      typeConverter.getMemRefAddressSpace(memRefType);
  assert(succeeded(maybeMemrefAddrSpace) && "unsupported address space");
  unsigned memrefAddrSpace = *maybeMemrefAddrSpace;
  if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
    allocatedPtr = LLVM::AddrSpaceCastOp::create(
        rewriter, loc,
        LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace),
        allocatedPtr);
  return allocatedPtr;
}

class AllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
  SymbolTableCollection *symbolTables = nullptr;

public:
  explicit AllocOpLowering(const LLVMTypeConverter &typeConverter,
                           SymbolTableCollection *symbolTables = nullptr,
                           PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern(typeConverter, benefit),
        symbolTables(symbolTables) {}

  LogicalResult
  matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(memRefType))
      return rewriter.notifyMatchFailure(op, "incompatible memref type");

    // Get or insert alloc function into the module.
    FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getNotalignedAllocFn(
        rewriter, getTypeConverter(),
        op->getParentWithTrait<OpTrait::SymbolTable>(), getIndexType(),
        symbolTables);
    if (failed(allocFuncOp))
      return failure();

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands. In case of
    // zero-dimensional memref, assume a scalar (size 1).
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value sizeBytes;
    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
                                   rewriter, sizes, strides, sizeBytes,
                                   /*sizeInBytes=*/true);

    Value alignment = getAlignment(rewriter, loc, op);
    if (alignment) {
      // Adjust the allocation size to consider alignment.
      sizeBytes = LLVM::AddOp::create(rewriter, loc, sizeBytes, alignment);
    }

    // Allocate the underlying buffer.
    Type elementPtrType = this->getElementPtrType(memRefType);
    assert(elementPtrType && "could not compute element ptr type");
    auto results =
        LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(), sizeBytes);

    Value allocatedPtr =
        castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
                            elementPtrType, *getTypeConverter());
    Value alignedPtr = allocatedPtr;
    if (alignment) {
      // Compute the aligned pointer.
      Value allocatedInt =
          LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(),
                                   allocatedPtr);
      Value alignmentInt =
          createAligned(rewriter, loc, allocatedInt, alignment);
      alignedPtr =
          LLVM::IntToPtrOp::create(rewriter, loc, elementPtrType,
                                   alignmentInt);
    }
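
    // A sketch of the over-allocation scheme used above: requesting
    // `sizeBytes + alignment` bytes from malloc guarantees that some address
    // in [allocatedPtr, allocatedPtr + alignment) is suitably aligned, and
    // createAligned rounds the raw address up to it. The original (unaligned)
    // pointer is kept in the descriptor so `free` can later be called on it.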

    // Create the MemRef descriptor.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, {memRefDescriptor});
    return success();
  }

  /// Computes the alignment for the given memory allocation op.
  template <typename OpType>
  Value getAlignment(ConversionPatternRewriter &rewriter, Location loc,
                     OpType op) const {
    MemRefType memRefType = op.getType();
    Value alignment;
    if (auto alignmentAttr = op.getAlignment()) {
      Type indexType = getIndexType();
      alignment =
          createIndexAttrConstant(rewriter, loc, indexType, *alignmentAttr);
    } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
      // In the case where no alignment is specified, we may want to override
      // `malloc`'s behavior. `malloc` typically aligns at the size of the
      // biggest scalar on a target HW. For non-scalars, use the natural
      // alignment of the LLVM type given by the LLVM DataLayout.
      alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
    }
    return alignment;
  }
};

class AlignedAllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
  SymbolTableCollection *symbolTables = nullptr;

public:
  explicit AlignedAllocOpLowering(const LLVMTypeConverter &typeConverter,
                                  SymbolTableCollection *symbolTables = nullptr,
                                  PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern(typeConverter, benefit),
        symbolTables(symbolTables) {}

  LogicalResult
  matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(memRefType))
      return rewriter.notifyMatchFailure(op, "incompatible memref type");

    // Get or insert alloc function into module.
    FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getAlignedAllocFn(
        rewriter, getTypeConverter(),
        op->getParentWithTrait<OpTrait::SymbolTable>(), getIndexType(),
        symbolTables);
    if (failed(allocFuncOp))
      return failure();

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands. In case of
    // zero-dimensional memref, assume a scalar (size 1).
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value sizeBytes;
    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
                                   rewriter, sizes, strides, sizeBytes,
                                   /*sizeInBytes=*/true);

    int64_t alignment = alignedAllocationGetAlignment(op, &defaultLayout);
    Value allocAlignment =
        createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);

    // Function aligned_alloc requires size to be a multiple of alignment; we
    // pad the size to the next multiple if necessary.
    if (!isMemRefSizeMultipleOf(memRefType, alignment, op, &defaultLayout))
      sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);

    Type elementPtrType = this->getElementPtrType(memRefType);
    auto results =
        LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(),
                             ValueRange({allocAlignment, sizeBytes}));

    Value ptr =
        castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
                            elementPtrType, *getTypeConverter());

    // Create the MemRef descriptor.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, ptr, ptr, sizes, strides, rewriter);

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, {memRefDescriptor});
    return success();
  }

  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
  static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
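
  // For illustration: aligned_alloc(16, 48) is a valid call because 48 is a
  // multiple of 16, whereas a 20-byte payload with 16-byte alignment must be
  // padded to 32 bytes first; createAligned performs exactly that rounding.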

  /// Computes the alignment for aligned_alloc used to allocate the buffer for
  /// the memory allocation op.
  ///
  /// Aligned_alloc requires the allocation size to be a power of two, and the
  /// allocation size to be a multiple of the alignment.
  int64_t alignedAllocationGetAlignment(memref::AllocOp op,
                                        const DataLayout *defaultLayout) const {
    if (std::optional<uint64_t> alignment = op.getAlignment())
      return *alignment;

    // Whenever we don't have alignment set, we will use an alignment
    // consistent with the element type; since the allocation size has to be a
    // power of two, we will bump to the next power of two if it isn't.
    unsigned eltSizeBytes = getMemRefEltSizeInBytes(
        getTypeConverter(), op.getType(), op, defaultLayout);
    return std::max(kMinAlignedAllocAlignment,
                    llvm::PowerOf2Ceil(eltSizeBytes));
  }

  /// Returns true if the memref size in bytes is known to be a multiple of
  /// factor.
  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor, Operation *op,
                              const DataLayout *defaultLayout) const {
    uint64_t sizeDivisor =
        getMemRefEltSizeInBytes(getTypeConverter(), type, op, defaultLayout);
    for (unsigned i = 0, e = type.getRank(); i < e; i++) {
      if (type.isDynamicDim(i))
        continue;
      sizeDivisor = sizeDivisor * type.getDimSize(i);
    }
    return sizeDivisor % factor == 0;
  }

private:
  /// Default layout to use in absence of the corresponding analysis.
  DataLayout defaultLayout;
};

struct AllocaOpLowering : public ConvertOpToLLVMPattern<memref::AllocaOp> {
  using ConvertOpToLLVMPattern<memref::AllocaOp>::ConvertOpToLLVMPattern;

  /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
  /// is set to null for stack allocations. `accessAlignment` is set if
  /// alignment is needed post allocation (e.g., in conjunction with malloc).
  LogicalResult
  matchAndRewrite(memref::AllocaOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(memRefType))
      return rewriter.notifyMatchFailure(op, "incompatible memref type");

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands. In case of
    // zero-dimensional memref, assume a scalar (size 1).
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value size;
    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
                                   rewriter, sizes, strides, size,
                                   /*sizeInBytes=*/false);

    // With alloca, one gets a pointer to the element type right away.
    // For stack allocations.
    auto elementType =
        typeConverter->convertType(op.getType().getElementType());
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(op.getType());
    assert(succeeded(maybeAddressSpace) && "unsupported address space");
    unsigned addrSpace = *maybeAddressSpace;
    auto elementPtrType =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace);

    auto allocatedElementPtr =
        LLVM::AllocaOp::create(rewriter, loc, elementPtrType, elementType,
                               size, op.getAlignment().value_or(0));

    // Create the MemRef descriptor.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, allocatedElementPtr, allocatedElementPtr, sizes,
        strides, rewriter);
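
    // For illustration, a sketch of what this produces for
    // `%m = memref.alloca() : memref<4x8xf32>` (names are illustrative):
    //   %sz  = llvm.mlir.constant(32 : index) : i64   // element count
    //   %buf = llvm.alloca %sz x f32 : (i64) -> !llvm.ptr
    //   ... descriptor insertvalues with sizes [4, 8], strides [8, 1] ...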

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, {memRefDescriptor});
    return success();
  }
};

struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(
          remainingOpsBlock, allocaScopeOp.getResultTypes(),
          SmallVector<Location>(allocaScopeOp->getNumResults(),
                                allocaScopeOp.getLoc()));
      LLVM::BrOp::create(rewriter, loc, ValueRange(), remainingOpsBlock);
    }

    // Inline body region.
    Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
    Block *afterBody = &allocaScopeOp.getBodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);

    // Save stack and then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp = LLVM::StackSaveOp::create(rewriter, loc, getPtrType());
    LLVM::BrOp::create(rewriter, loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the
    // body. Stack restore before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.getResults(), continueBlock);

    // Insert stack restore before jumping out of the body of the region.
    rewriter.setInsertionPoint(branchOp);
    LLVM::StackRestoreOp::create(rewriter, loc, stackSaveOp);

    // Replace the op with values returned from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};

struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
  explicit AssumeAlignmentOpLowering(const LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::AssumeAlignmentOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value memref = adaptor.getMemref();
    unsigned alignment = op.getAlignment();
    auto loc = op.getLoc();

    auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
    Value ptr = getStridedElementPtr(rewriter, loc, srcMemRefType, memref,
                                     /*indices=*/{});

    // Emit llvm.assume(true) ["align"(memref, alignment)].
    // This is more direct than ptrtoint-based checks, is explicitly supported,
    // and works with non-integral address spaces.
    Value trueCond =
        LLVM::ConstantOp::create(rewriter, loc, rewriter.getBoolAttr(true));
    Value alignmentConst =
        createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
    LLVM::AssumeOp::create(rewriter, loc, trueCond, LLVM::AssumeAlignTag(),
                           ptr, alignmentConst);
    rewriter.replaceOp(op, memref);
    return success();
  }
};
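
// For illustration: `memref.assume_alignment %m, 16` becomes an
// `llvm.intr.assume` of the constant true carrying an "align"(%ptr, 16)
// operand bundle, which LLVM understands natively and which, unlike
// ptrtoint-based tricks, remains valid in non-integral address spaces.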

// A `dealloc` is converted into a call to `free` on the underlying data
// buffer. The memref descriptor being an SSA value, there is no need to clean
// it up in any way.
class DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  SymbolTableCollection *symbolTables = nullptr;

public:
  explicit DeallocOpLowering(const LLVMTypeConverter &typeConverter,
                             SymbolTableCollection *symbolTables = nullptr,
                             PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern(typeConverter, benefit),
        symbolTables(symbolTables) {}

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Insert the `free` declaration if it is not already present.
    FailureOr<LLVM::LLVMFuncOp> freeFunc =
        getFreeFn(rewriter, getTypeConverter(),
                  op->getParentOfType<ModuleOp>(), symbolTables);
    if (failed(freeFunc))
      return failure();
    Value allocatedPtr;
    if (auto unrankedTy =
            llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
      auto elementPtrTy = LLVM::LLVMPointerType::get(
          rewriter.getContext(), unrankedTy.getMemorySpaceAsInt());
      allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
          rewriter, op.getLoc(),
          UnrankedMemRefDescriptor(adaptor.getMemref())
              .memRefDescPtr(rewriter, op.getLoc()),
          elementPtrTy);
    } else {
      allocatedPtr = MemRefDescriptor(adaptor.getMemref())
                         .allocatedPtr(rewriter, op.getLoc());
    }
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc.value(),
                                              allocatedPtr);
    return success();
  }
};

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.getSource().getType();
    if (isa<UnrankedMemRefType>(operandType)) {
      FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
          operandType, dimOp, adaptor.getOperands(), rewriter);
      if (failed(extractedSize))
        return failure();
      rewriter.replaceOp(dimOp, {*extractedSize});
      return success();
    }
    if (isa<MemRefType>(operandType)) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
                                            adaptor.getOperands(), rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  FailureOr<Value>
  extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                              OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    auto unrankedMemRefType = cast<UnrankedMemRefType>(operandType);
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
    if (failed(maybeAddressSpace)) {
      dimOp.emitOpError("memref memory space must be convertible to an "
                        "integer address space");
      return failure();
    }
    unsigned addressSpace = *maybeAddressSpace;

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // memref<element_type> descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);

    Type elementType = typeConverter->convertType(scalarMemRefType);

    // Get pointer to offset field of memref descriptor.
    auto indexPtrTy =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
    Value offsetPtr =
        LLVM::GEPOp::create(rewriter, loc, indexPtrTy, elementType,
                            underlyingRankedDesc,
                            ArrayRef<LLVM::GEPArg>{0, 2});
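
    // As a reminder of the descriptor layout this GEP relies on: the ranked
    // memref descriptor is a struct { allocated ptr, aligned ptr, index
    // offset, index sizes[rank], index strides[rank] }, so struct index 2 is
    // the offset field and the size of dimension `d` lives right after it,
    // at position `d + 1` in index units.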

    // The size value that we have to extract can be obtained using GEPOp with
    // `dimOp.index() + 1` index argument.
    Value idxPlusOne = LLVM::AddOp::create(
        rewriter, loc,
        createIndexAttrConstant(rewriter, loc, getIndexType(), 1),
        adaptor.getIndex());
    Value sizePtr =
        LLVM::GEPOp::create(rewriter, loc, indexPtrTy,
                            getTypeConverter()->getIndexType(), offsetPtr,
                            idxPlusOne);
    return LLVM::LoadOp::create(rewriter, loc,
                                getTypeConverter()->getIndexType(), sizePtr)
        .getResult();
  }

  std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (auto idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
      return cast<IntegerAttr>(constantOp.getValue())
          .getValue()
          .getSExtValue();

    return std::nullopt;
  }

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    // Take advantage if index is constant.
    MemRefType memRefType = cast<MemRefType>(operandType);
    Type indexType = getIndexType();
    if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = *index;
      if (i >= 0 && i < memRefType.getRank()) {
        if (memRefType.isDynamicDim(i)) {
          // Extract dynamic size from the memref descriptor.
          MemRefDescriptor descriptor(adaptor.getSource());
          return descriptor.size(rewriter, loc, i);
        }
        // Use constant for static size.
        int64_t dimSize = memRefType.getDimSize(i);
        return createIndexAttrConstant(rewriter, loc, indexType, dimSize);
      }
    }
    Value index = adaptor.getIndex();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(adaptor.getSource());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};

/// Common base for load and store operations on MemRefs. Restricts the match
/// to supported MemRef types. Provides functionality to emit code accessing a
/// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;
};

/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |  <code before the atomic_rmw>   |
///      |  <compute initial %loaded>      |
///      |  cf.br loop(%loaded)            |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +--------------------------------+
///  |   |  loop(%loaded):                |
///  |   |  <body contents>               |
///  |   |  %pair = cmpxchg               |
///  |   |  %ok = %pair[0]                |
///  |   |  %new = %pair[1]               |
///  |   |  cf.cond_br %ok, end, loop(%new)|
///  |   +--------------------------------+
///  |          |        |
///  |-----------        |
///                      v
///      +--------------------------------+
///      |  end:                          |
///      |  <code after the atomic_rmw>   |
///      +--------------------------------+
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType =
        typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock =
        rewriter.splitBlock(initBlock, Block::iterator(atomicOp));
    loopBlock->addArgument(valueType, loc);
    auto *endBlock =
        rewriter.splitBlock(loopBlock, Block::iterator(atomicOp)++);

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
    auto dataPtr = getStridedElementPtr(
        rewriter, loc, memRefType, adaptor.getMemref(), adaptor.getIndices());
    Value init = LLVM::LoadOp::create(
        rewriter, loc,
        typeConverter->convertType(memRefType.getElementType()), dataPtr);
    LLVM::BrOp::create(rewriter, loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    auto loopArgument = loopBlock->getArgument(0);
    IRMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto cmpxchg = LLVM::AtomicCmpXchgOp::create(rewriter, loc, dataPtr,
                                                 loopArgument, result,
                                                 successOrdering,
                                                 failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = LLVM::ExtractValueOp::create(rewriter, loc, cmpxchg, 0);
    Value ok = LLVM::ExtractValueOp::create(rewriter, loc, cmpxchg, 1);

    // Conditionally branch to the end or back to the loop depending on %ok.
    LLVM::CondBrOp::create(rewriter, loc, ok, endBlock, ArrayRef<Value>(),
                           loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type
convertGlobalMemrefTypeToLLVM(MemRefType type,
                              const LLVMTypeConverter &typeConverter) {
  // LLVM type for a global memref will be a multi-dimension array. For
  // declarations or uninitialized global memrefs, we can potentially flatten
  // this to a 1D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimension array.
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // Shape has the outermost dim at index 0, so need to walk it backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}

/// GlobalMemrefOp is lowered to a LLVM Global Variable.
class GlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  SymbolTableCollection *symbolTables = nullptr;

public:
  explicit GlobalMemrefOpLowering(const LLVMTypeConverter &typeConverter,
                                  SymbolTableCollection *symbolTables = nullptr,
                                  PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern(typeConverter, benefit),
        symbolTables(symbolTables) {}

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.getType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
    bool isExternal = global.isExternal();
    bool isUninitialized = global.isUninitialized();

    Attribute initialValue = nullptr;
    if (!isExternal && !isUninitialized) {
      auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element
      // type, so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getSplatValue<Attribute>();
    }

    uint64_t alignment = global.getAlignment().value_or(0);
    FailureOr<unsigned> addressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    if (failed(addressSpace))
      return global.emitOpError(
          "memory space cannot be converted to an integer address space");

    // Remove old operation from symbol table.
    SymbolTable *symbolTable = nullptr;
    if (symbolTables) {
      Operation *symbolTableOp =
          global->getParentWithTrait<OpTrait::SymbolTable>();
      symbolTable = &symbolTables->getSymbolTable(symbolTableOp);
      symbolTable->remove(global);
    }

    // Create new operation.
    auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.getConstant(), linkage, global.getSymName(),
        initialValue, alignment, *addressSpace);

    // Insert new operation into symbol table.
    if (symbolTable)
      symbolTable->insert(newGlobal, rewriter.getInsertionPoint());

    if (!isExternal && isUninitialized) {
      rewriter.createBlock(&newGlobal.getInitializerRegion());
      Value undef[] = {
          LLVM::UndefOp::create(rewriter, newGlobal.getLoc(), arrayTy)};
      LLVM::ReturnOp::create(rewriter, newGlobal.getLoc(), undef);
    }
    return success();
  }
};

/// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to
/// the first element stashed into the descriptor. This reuses
/// `AllocLikeOpLowering`'s Memref descriptor construction.
struct GetGlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GetGlobalOp> {
  using ConvertOpToLLVMPattern<memref::GetGlobalOp>::ConvertOpToLLVMPattern;

  /// Buffer "allocation" for memref.get_global op is getting the address of
  /// the global variable referenced.
  LogicalResult
  matchAndRewrite(memref::GetGlobalOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(memRefType))
      return rewriter.notifyMatchFailure(op, "incompatible memref type");

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands. In case of
    // zero-dimensional memref, assume a scalar (size 1).
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value sizeBytes;
    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
                                   rewriter, sizes, strides, sizeBytes,
                                   /*sizeInBytes=*/true);

    MemRefType type = cast<MemRefType>(op.getResult().getType());

    // This is called after a type conversion, which would have failed if this
    // call fails.
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    assert(succeeded(maybeAddressSpace) && "unsupported address space");
    unsigned memSpace = *maybeAddressSpace;

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace);
    auto addressOf =
        LLVM::AddressOfOp::create(rewriter, loc, ptrTy, op.getName());
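
    // For illustration: for a global of type memref<2x3xf32>, the LLVM global
    // has type !llvm.array<2 x array<3 x f32>>, and a GEP with three zero
    // indices [0, 0, 0] (rank + 1 zeros) steps through the pointer and both
    // array levels to yield the address of the first f32 element.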

    // Get the address of the first element in the array by creating a GEP
    // with the address of the GV as the base, and (rank + 1) number of 0
    // indices.
    auto gep = LLVM::GEPOp::create(
        rewriter, loc, ptrTy, arrayTy, addressOf,
        SmallVector<LLVM::GEPArg>(type.getRank() + 1, 0));

    // We do not expect the memref obtained using `memref.get_global` to be
    // ever deallocated. Set the allocated pointer to be a known bad value to
    // help debug if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType,
                                0xdeadbeef);
    auto deadBeefPtr =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTy, deadBeefConst);

    // Both allocated and aligned pointers are the same. We could potentially
    // stash a nullptr for the allocated pointer since we do not expect any
    // dealloc.
    // Create the MemRef descriptor.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, deadBeefPtr, gep, sizes, strides, rewriter);

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, {memRefDescriptor});
    return success();
  }
};

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();

    // Per memref.load spec, the indices must be in-bounds:
    // 0 <= idx < dim_size, and additionally all offsets are non-negative,
    // hence inbounds and nuw are used when lowering to llvm.getelementptr.
    Value dataPtr = getStridedElementPtr(rewriter, loadOp.getLoc(), type,
                                         adaptor.getMemref(),
                                         adaptor.getIndices(), kNoWrapFlags);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
        loadOp, typeConverter->convertType(type.getElementType()), dataPtr, 0,
        false, loadOp.getNontemporal());
    return success();
  }
};

// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();

    // Per memref.store spec, the indices must be in-bounds:
    // 0 <= idx < dim_size, and additionally all offsets are non-negative,
    // hence inbounds and nuw are used when lowering to llvm.getelementptr.
    Value dataPtr =
        getStridedElementPtr(rewriter, op.getLoc(), type, adaptor.getMemref(),
                             adaptor.getIndices(), kNoWrapFlags);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
                                               0, false, op.getNontemporal());
    return success();
  }
};

// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(
        rewriter, loc, type, adaptor.getMemref(), adaptor.getIndices());
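
    // For reference, the llvm.prefetch intrinsic takes, besides the address:
    // a read/write flag (0 = read, 1 = write), a temporal locality hint in
    // [0, 3] (3 = keep in cache as long as possible), and a cache-type flag
    // (1 = data cache, 0 = instruction cache); the attributes below map
    // one-to-one onto those operands.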

    // Replace with llvm.prefetch.
    IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite());
    IntegerAttr localityHint = prefetchOp.getLocalityHintAttr();
    IntegerAttr isData =
        rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache());
    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};

struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
  using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type operandType = op.getMemref().getType();
    if (isa<UnrankedMemRefType>(operandType)) {
      UnrankedMemRefDescriptor desc(adaptor.getMemref());
      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
      return success();
    }
    if (auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) {
      Type indexType = getIndexType();
      rewriter.replaceOp(
          op, {createIndexAttrConstant(rewriter, loc, indexType,
                                       rankedMemRefType.getRank())});
      return success();
    }
    return failure();
  }
};

struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduces to bitcast in the ranked MemRef case and can be
    // used for type erasure. For now it must preserve the underlying element
    // type and requires source and result type to have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once op semantics are relaxed we can revisit.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
      if (typeConverter->convertType(srcType) !=
          typeConverter->convertType(dstType))
        return failure();

    // Unranked to unranked cast is disallowed.
    if (isa<UnrankedMemRefType>(srcType) && isa<UnrankedMemRefType>(dstType))
      return failure();

    auto targetStructType =
        typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For ranked/ranked case, just keep the original descriptor.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType)) {
      rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});
      return success();
    }

    if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
      // Casting ranked to unranked memref type:
      // Set the rank in the destination from the memref type.
      // Allocate space on the stack and copy the src memref descriptor.
      // Set the ptr in the destination to the stack space.
      auto srcMemRefType = cast<MemRefType>(srcType);
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, adaptor.getSource(), rewriter);

      // rank = ConstantOp srcRank
      auto rankVal = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                              rewriter.getIndexAttr(rank));
      // poison = PoisonOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::poison(rewriter, loc, targetStructType);
      // d1 = InsertValueOp poison, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, ptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);
    } else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
      // Casting from unranked type to ranked.
      // The operation is assumed to be doing a correct cast. If the
      // destination type mismatches the unranked type, it is undefined
      // behavior.
      UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
      // struct = LoadOp ptr
      auto loadOp = LLVM::LoadOp::create(rewriter, loc, targetStructType, ptr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }

    return success();
  }
};

/// Pattern to lower a `memref.copy` to llvm.
///
/// For memrefs with identity layouts, the copy is lowered to the llvm
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a
/// call to the generic `MemrefCopyFn`.
class MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  SymbolTableCollection *symbolTables = nullptr;

public:
  explicit MemRefCopyOpLowering(const LLVMTypeConverter &typeConverter,
                                SymbolTableCollection *symbolTables = nullptr,
                                PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern(typeConverter, benefit),
        symbolTables(symbolTables) {}

  LogicalResult
  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = dyn_cast<MemRefType>(op.getSource().getType());

    MemRefDescriptor srcDesc(adaptor.getSource());

    // Compute number of elements.
    Value numElements = LLVM::ConstantOp::create(
        rewriter, loc, getIndexType(), rewriter.getIndexAttr(1));
    for (int pos = 0; pos < srcType.getRank(); ++pos) {
      auto size = srcDesc.size(rewriter, loc, pos);
      numElements = LLVM::MulOp::create(rewriter, loc, numElements, size);
    }

    // Get element size.
    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    // Compute total.
    Value totalSize =
        LLVM::MulOp::create(rewriter, loc, numElements, sizeInBytes);

    Type elementType = typeConverter->convertType(srcType.getElementType());

    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
    Value srcOffset = srcDesc.offset(rewriter, loc);
    Value srcPtr = LLVM::GEPOp::create(rewriter, loc, srcBasePtr.getType(),
                                       elementType, srcBasePtr, srcOffset);
    MemRefDescriptor targetDesc(adaptor.getTarget());
    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
    Value targetOffset = targetDesc.offset(rewriter, loc);
    Value targetPtr =
        LLVM::GEPOp::create(rewriter, loc, targetBasePtr.getType(),
                            elementType, targetBasePtr, targetOffset);
    LLVM::MemcpyOp::create(rewriter, loc, targetPtr, srcPtr, totalSize,
                           /*isVolatile=*/false);
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, MemRefType type) {
      auto rank = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                           type.getRank());
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);

      auto unrankedType = UnrankedMemRefType::get(type.getElementType(),
                                                  type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(
          rewriter, loc, *typeConverter, unrankedType, ValueRange{rank, ptr});
    };

    // Save stack position before promoting descriptors.
    auto stackSaveOp = LLVM::StackSaveOp::create(rewriter, loc, getPtrType());

    auto srcMemRefType = dyn_cast<MemRefType>(srcType);
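
    // Note on the technique: the descriptors promoted below live in
    // function-local allocas. Bracketing them with stacksave/stackrestore
    // keeps repeated copies (e.g. inside a loop) from growing the stack
    // without bound, since each iteration releases its temporaries.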

    Value unrankedSource =
        srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
                      : adaptor.getSource();
    auto targetMemRefType = dyn_cast<MemRefType>(targetType);
    Value unrankedTarget =
        targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
                         : adaptor.getTarget();

    // Now promote the unranked descriptors to the stack.
    auto one = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                        rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
      auto allocated =
          LLVM::AllocaOp::create(rewriter, loc, ptrType, desc.getType(), one);
      LLVM::StoreOp::create(rewriter, loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    // Derive size from llvm.getelementptr which will account for any
    // potential alignment.
    auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        rewriter, op->getParentOfType<ModuleOp>(), getIndexType(),
        sourcePtr.getType(), symbolTables);
    if (failed(copyFn))
      return failure();
    LLVM::CallOp::create(rewriter, loc, copyFn.value(),
                         ValueRange{elemSize, sourcePtr, targetPtr});

    // Restore stack used for descriptors.
    LLVM::StackRestoreOp::create(rewriter, loc, stackSaveOp);

    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    auto isContiguousMemrefType = [&](BaseMemRefType type) {
      auto memrefType = dyn_cast<MemRefType>(type);
      // We can use memcpy for memrefs if they have an identity layout or are
      // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
      // special case handled by memrefCopy.
      return memrefType &&
             (memrefType.getLayout().isIdentity() ||
              (memrefType.hasStaticShape() &&
               memrefType.getNumElements() > 0 &&
               memref::isStaticShapeAndContiguousRowMajor(memrefType)));
    };

    if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);

    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
  }
};

struct MemorySpaceCastOpLowering
    : public ConvertOpToLLVMPattern<memref::MemorySpaceCastOp> {
  using ConvertOpToLLVMPattern<
      memref::MemorySpaceCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    Type resultType = op.getDest().getType();
    if (auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
      auto resultDescType =
          cast<LLVM::LLVMStructType>(typeConverter->convertType(resultTypeR));
      Type newPtrType = resultDescType.getBody()[0];

      SmallVector<Value> descVals;
      MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(),
                               resultTypeR, descVals);
      descVals[0] = LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType,
                                                  descVals[0]);
      descVals[1] = LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType,
                                                  descVals[1]);
      Value result = MemRefDescriptor::pack(rewriter, loc,
                                            *getTypeConverter(), resultTypeR,
                                            descVals);
      rewriter.replaceOp(op, result);
      return success();
    }
    if (auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
      // Since the type converter won't be doing this for us, get the address
      // space.
      auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType());
      FailureOr<unsigned> maybeSourceAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(sourceType);
      if (failed(maybeSourceAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer source address space");
      unsigned sourceAddrSpace = *maybeSourceAddrSpace;
      FailureOr<unsigned> maybeResultAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(resultTypeU);
      if (failed(maybeResultAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer result address space");
      unsigned resultAddrSpace = *maybeResultAddrSpace;

      UnrankedMemRefDescriptor sourceDesc(adaptor.getSource());
      Value rank = sourceDesc.rank(rewriter, loc);
      Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);

      // Create and allocate storage for new memref descriptor.
      auto result = UnrankedMemRefDescriptor::poison(
          rewriter, loc, typeConverter->convertType(resultTypeU));
      result.setRank(rewriter, loc, rank);
      SmallVector<Value> sizes;
      UnrankedMemRefDescriptor::computeSizes(rewriter, loc,
                                             *getTypeConverter(), result,
                                             resultAddrSpace, sizes);
      Value resultUnderlyingSize = sizes.front();
      Value resultUnderlyingDesc =
          LLVM::AllocaOp::create(rewriter, loc, getPtrType(),
                                 rewriter.getI8Type(), resultUnderlyingSize);
      result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);

      // Copy pointers, performing address space casts.
      auto sourceElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), sourceAddrSpace);
      auto resultElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), resultAddrSpace);

      Value allocatedPtr = sourceDesc.allocatedPtr(
          rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
      Value alignedPtr =
          sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
                                sourceUnderlyingDesc, sourceElemPtrType);

      allocatedPtr = LLVM::AddrSpaceCastOp::create(
          rewriter, loc, resultElemPtrType, allocatedPtr);
      alignedPtr = LLVM::AddrSpaceCastOp::create(
          rewriter, loc, resultElemPtrType, alignedPtr);

      result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
                             resultElemPtrType, allocatedPtr);
      result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
                           resultUnderlyingDesc, resultElemPtrType,
                           alignedPtr);

      // Copy all the index-valued operands.
      Value sourceIndexVals =
          sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                                   sourceUnderlyingDesc, sourceElemPtrType);
      Value resultIndexVals =
          result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                               resultUnderlyingDesc, resultElemPtrType);

      int64_t bytesToSkip =
          2 * llvm::divideCeil(
                  getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
      Value bytesToSkipConst = LLVM::ConstantOp::create(
          rewriter, loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
      Value copySize =
          LLVM::SubOp::create(rewriter, loc, getIndexType(),
                              resultUnderlyingSize, bytesToSkipConst);
      LLVM::MemcpyOp::create(rewriter, loc, resultIndexVals, sourceIndexVals,
                             copySize, /*isVolatile=*/false);

      rewriter.replaceOp(op, ValueRange{result});
      return success();
    }
    return rewriter.notifyMatchFailure(loc, "unexpected memref type");
  }
};
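
// A note on the memcpy above: the underlying ranked descriptor is laid out as
// { allocated ptr, aligned ptr, offset, sizes[rank], strides[rank] }. The two
// pointers were already copied (with address space casts) field by field, so
// the bulk copy starts at the offset field and `bytesToSkip` is exactly the
// byte size of the two leading pointers.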

/// Extracts allocated, aligned pointers and offset from a ranked or unranked
/// memref type. In unranked case, the fields are extracted from the
/// underlying ranked descriptor.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     const LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (isa<MemRefType>(operandType)) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  // These will all cause assert()s on unconvertible types.
  unsigned memorySpace = *typeConverter.getMemRefAddressSpace(
      cast<UnrankedMemRefType>(operandType));
  auto elementPtrType =
      LLVM::LLVMPointerType::get(rewriter.getContext(), memorySpace);

  // Extract pointer to the underlying ranked memref descriptor and cast it to
  // ElemType**.
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
  }
}

struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
  using ConvertOpToLLVMPattern<
      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = castOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  LogicalResult convertSourceMemRefToDescriptor(
      ConversionPatternRewriter &rewriter, Type srcType,
      memref::ReinterpretCastOp castOp,
      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
    MemRefType targetMemRefType =
        cast<MemRefType>(castOp.getResult().getType());
    auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
        typeConverter->convertType(targetMemRefType));
    if (!llvmTargetDescriptorTy)
      return failure();

    // Create descriptor.
    Location loc = castOp.getLoc();
    auto desc =
        MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);

    // Set allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set sizes and strides.
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};

struct MemRefReshapeOpLowering
    : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = reshapeOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(reshapeOp, {descriptor});
    return success();
  }

private:
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, memref::ReshapeOp reshapeOp,
                                  memref::ReshapeOp::Adaptor adaptor,
                                  Value *descriptor) const {
    auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
    if (shapeMemRefType.hasStaticShape()) {
      MemRefType targetMemRefType =
          cast<MemRefType>(reshapeOp.getResult().getType());
      auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
          typeConverter->convertType(targetMemRefType));
      if (!llvmTargetDescriptorTy)
        return failure();

      // Create descriptor.
      Location loc = reshapeOp.getLoc();
      auto desc =
          MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);

      // Set allocated and aligned pointers.
      Value allocatedPtr, alignedPtr;
      extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                               reshapeOp.getSource(), adaptor.getSource(),
                               &allocatedPtr, &alignedPtr);
      desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
      desc.setAlignedPtr(rewriter, loc, alignedPtr);

      // Extract the offset and strides from the type.
      int64_t offset;
      SmallVector<int64_t> strides;
      if (failed(targetMemRefType.getStridesAndOffset(strides, offset)))
        return rewriter.notifyMatchFailure(
            reshapeOp, "failed to get stride and offset exprs");

      if (!isStaticStrideOrOffset(offset))
        return rewriter.notifyMatchFailure(reshapeOp,
                                           "dynamic offset is unsupported");

      desc.setConstantOffset(rewriter, loc, offset);

      assert(targetMemRefType.getLayout().isIdentity() &&
             "Identity layout map is a precondition of a valid reshape op");

      Type indexType = getIndexType();
      Value stride = nullptr;
      int64_t targetRank = targetMemRefType.getRank();
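
      // For illustration: an identity-layout memref<2x3x4xf32> has row-major
      // strides [12, 4, 1]. The loop below reproduces exactly that recurrence
      // at runtime, walking dimensions from innermost to outermost and
      // multiplying the running stride by each dimension size.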

      for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
        if (ShapedType::isStatic(strides[i])) {
          // If the stride for this dimension is static, materialize it as a
          // constant.
          stride =
              createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
        } else if (!stride) {
          // `stride` is null only in the first iteration of the loop.
          // However, since the target memref has an identity layout, we can
          // safely set the innermost stride to 1.
          stride = createIndexAttrConstant(rewriter, loc, indexType, 1);
        }

        Value dimSize;
        // If the size of this dimension is dynamic, then load it at runtime
        // from the shape operand.
        if (!targetMemRefType.isDynamicDim(i)) {
          dimSize = createIndexAttrConstant(rewriter, loc, indexType,
                                            targetMemRefType.getDimSize(i));
        } else {
          Value shapeOp = reshapeOp.getShape();
          Value index = createIndexAttrConstant(rewriter, loc, indexType, i);
          dimSize = memref::LoadOp::create(rewriter, loc, shapeOp, index);
          Type indexType = getIndexType();
          if (dimSize.getType() != indexType)
            dimSize = typeConverter->materializeTargetConversion(
                rewriter, loc, indexType, dimSize);
          assert(dimSize && "Invalid memref element type");
        }

        desc.setSize(rewriter, loc, i, dimSize);
        desc.setStride(rewriter, loc, i, stride);

        // Prepare the stride value for the next dimension.
        stride = LLVM::MulOp::create(rewriter, loc, stride, dimSize);
      }

      *descriptor = desc;
      return success();
    }

    // The shape is a rank-1 tensor with unknown length.
    Location loc = reshapeOp.getLoc();
    MemRefDescriptor shapeDesc(adaptor.getShape());
    Value resultRank = shapeDesc.size(rewriter, loc, 0);

    // Extract address space and element type.
    auto targetType =
        cast<UnrankedMemRefType>(reshapeOp.getResult().getType());
    unsigned addressSpace =
        *getTypeConverter()->getMemRefAddressSpace(targetType);

    // Create the unranked memref descriptor that holds the ranked one. The
    // inner descriptor is allocated on stack.
    auto targetDesc = UnrankedMemRefDescriptor::poison(
        rewriter, loc, typeConverter->convertType(targetType));
    targetDesc.setRank(rewriter, loc, resultRank);
    SmallVector<Value> sizes;
    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                           targetDesc, addressSpace, sizes);
    Value underlyingDescPtr = LLVM::AllocaOp::create(
        rewriter, loc, getPtrType(), IntegerType::get(getContext(), 8),
        sizes.front());
    targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);

    // Extract pointers and offset from the source memref.
    Value allocatedPtr, alignedPtr, offset;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             reshapeOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr, &offset);

    // Set pointers and offset.
    auto elementPtrType =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);

    UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc,
                                              underlyingDescPtr,
                                              elementPtrType, allocatedPtr);
    UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc,
                                            *getTypeConverter(),
                                            underlyingDescPtr, elementPtrType,
                                            alignedPtr);
    UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
                                        underlyingDescPtr, elementPtrType,
                                        offset);

    // Use the offset pointer as base for further addressing. Copy over the
    // new shape and compute strides. For this, we create a loop from rank-1
    // to 0.
    Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
        rewriter, loc, *getTypeConverter(), underlyingDescPtr,
        elementPtrType);
    Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
        rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
    Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
    Value oneIndex = createIndexAttrConstant(rewriter, loc, getIndexType(),
                                             1);
    Value resultRankMinusOne =
        LLVM::SubOp::create(rewriter, loc, resultRank, oneIndex);

    Block *initBlock = rewriter.getInsertionBlock();
    Type indexType = getTypeConverter()->getIndexType();
    Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());

    Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
                                            {indexType, indexType},
                                            {loc, loc});

    // Move the remaining initBlock ops to condBlock.
    Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
    rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());

    rewriter.setInsertionPointToEnd(initBlock);
    LLVM::BrOp::create(rewriter, loc,
                       ValueRange({resultRankMinusOne, oneIndex}), condBlock);
    rewriter.setInsertionPointToStart(condBlock);
    Value indexArg = condBlock->getArgument(0);
    Value strideArg = condBlock->getArgument(1);

    Value zeroIndex = createIndexAttrConstant(rewriter, loc, indexType, 0);
    Value pred = LLVM::ICmpOp::create(
        rewriter, loc, IntegerType::get(rewriter.getContext(), 1),
        LLVM::ICmpPredicate::sge, indexArg, zeroIndex);

    Block *bodyBlock =
        rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
    rewriter.setInsertionPointToStart(bodyBlock);

    // Copy size from shape to descriptor.
    auto llvmIndexPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
    Value sizeLoadGep = LLVM::GEPOp::create(
        rewriter, loc, llvmIndexPtrType,
        typeConverter->convertType(shapeMemRefType.getElementType()),
        shapeOperandPtr, indexArg);
    Value size = LLVM::LoadOp::create(rewriter, loc, indexType, sizeLoadGep);
    UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
                                      targetSizesBase, indexArg, size);

    // Write stride value and compute next one.
    UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
                                        targetStridesBase, indexArg,
                                        strideArg);
    Value nextStride = LLVM::MulOp::create(rewriter, loc, strideArg, size);

    // Decrement loop counter and branch back.
    Value decrement = LLVM::SubOp::create(rewriter, loc, indexArg, oneIndex);
    LLVM::BrOp::create(rewriter, loc, ValueRange({decrement, nextStride}),
                       condBlock);

    Block *remainder =
        rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());

    // Hook up the cond exit to the remainder.
    rewriter.setInsertionPointToEnd(condBlock);
    LLVM::CondBrOp::create(rewriter, loc, pred, bodyBlock, ValueRange(),
                           remainder, ValueRange());

    // Reset position to beginning of new remainder block.
    rewriter.setInsertionPointToStart(remainder);

    *descriptor = targetDesc;
    return success();
  }
};

/// ReassociatingReshapeOp must be expanded before we reach this stage.
/// Report that information.
template <typename ReshapeOp>
class ReassociatingReshapeOpConversion
    : public ConvertOpToLLVMPattern<ReshapeOp> {
public:
  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;

  LogicalResult
  matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    return rewriter.notifyMatchFailure(
        reshapeOp,
        "reassociation operations should have been expanded beforehand");
  }
};

/// Subviews must be expanded before we reach this stage.
/// Report that information.
struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
  using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    return rewriter.notifyMatchFailure(
        subViewOp, "subview operations should have been expanded beforehand");
  }
};
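
// For illustration of the size/stride permutation performed below: lowering
// `memref.transpose %m (d0, d1) -> (d1, d0)` simply swaps the two size and
// the two stride entries of the descriptor; no data is moved.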
class TransposeOpLowering
    : public ConvertOpToLLVMPattern<memref::TransposeOp> {
public:
  using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = transposeOp.getLoc();
    MemRefDescriptor viewMemRef(adaptor.getIn());

    // No permutation, early exit.
    if (transposeOp.getPermutation().isIdentity())
      return rewriter.replaceOp(transposeOp, {viewMemRef}), success();

    auto targetMemRef = MemRefDescriptor::poison(
        rewriter, loc,
        typeConverter->convertType(transposeOp.getIn().getType()));

    // Copy the base and aligned pointers from the old descriptor to the new
    // one.
    targetMemRef.setAllocatedPtr(rewriter, loc,
                                 viewMemRef.allocatedPtr(rewriter, loc));
    targetMemRef.setAlignedPtr(rewriter, loc,
                               viewMemRef.alignedPtr(rewriter, loc));

    // Copy the offset from the old descriptor to the new one.
    targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));

    // Iterate over the dimensions and apply size/stride permutation: when
    // enumerating the results of the permutation map, the enumeration index
    // is the index into the target dimensions and the DimExpr points to the
    // dimension of the source memref.
    for (const auto &en :
         llvm::enumerate(transposeOp.getPermutation().getResults())) {
      int targetPos = en.index();
      int sourcePos = cast<AffineDimExpr>(en.value()).getPosition();
      targetMemRef.setSize(rewriter, loc, targetPos,
                           viewMemRef.size(rewriter, loc, sourcePos));
      targetMemRef.setStride(rewriter, loc, targetPos,
                             viewMemRef.stride(rewriter, loc, sourcePos));
    }

    rewriter.replaceOp(transposeOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a view op into:
///   1. An `llvm.mlir.poison` operation to create a memref descriptor.
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The view op is replaced by the descriptor.
struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;

  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or by fetching the corresponding
  // dynamic size operand.
  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<int64_t> shape, ValueRange dynamicSizes, unsigned idx,
                Type indexType) const {
    assert(idx < shape.size());
    if (ShapedType::isStatic(shape[idx]))
      return createIndexAttrConstant(rewriter, loc, indexType, shape[idx]);
    // Count the number of dynamic dims in range [0, idx].
    unsigned nDynamic =
        llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
    return dynamicSizes[nDynamic];
  }

  // Build and return the idx^th stride, either by returning the constant
  // stride or by computing the dynamic stride from the current
  // `runningStride` and `nextSize`. The caller should keep a running stride
  // and update it with the result returned by this function.
  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
                  ArrayRef<int64_t> strides, Value nextSize,
                  Value runningStride, unsigned idx, Type indexType) const {
    assert(idx < strides.size());
    if (ShapedType::isStatic(strides[idx]))
      return createIndexAttrConstant(rewriter, loc, indexType, strides[idx]);
    if (nextSize)
      return runningStride
                 ? LLVM::MulOp::create(rewriter, loc, runningStride, nextSize)
                 : nextSize;
    assert(!runningStride);
    return createIndexAttrConstant(rewriter, loc, indexType, 1);
  }

  LogicalResult
  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = viewOp.getLoc();

    auto viewMemRefType = viewOp.getType();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("target descriptor type not converted to LLVM"),
             failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = viewMemRefType.getStridesAndOffset(strides, offset);
    if (failed(successStrides))
      return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
    assert(offset == 0 && "expected offset to be 0");

    // Target memref must be contiguous in memory (innermost stride is 1), or
    // empty (special case when at least one of the memref dimensions is 0).
    if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
      return viewOp.emitWarning("cannot cast to non-contiguous shape"),
             failure();

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.getSource());
    auto targetMemRef = MemRefDescriptor::poison(rewriter, loc, targetDescTy);

    // Field 1: Copy the allocated pointer, used for malloc/free.
    Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
    auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType());
    targetMemRef.setAllocatedPtr(rewriter, loc, allocatedPtr);

    // Field 2: Copy the actual aligned pointer to payload.
    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
    alignedPtr = LLVM::GEPOp::create(
        rewriter, loc, alignedPtr.getType(),
        typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr,
        adaptor.getByteShift());
    targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr);

    Type indexType = getIndexType();
    // Field 3: The offset in the resulting type must be 0. This is because of
    // the type change: an offset on srcType* may not be expressible as an
    // offset on dstType*.
    targetMemRef.setOffset(
        rewriter, loc,
        createIndexAttrConstant(rewriter, loc, indexType, offset));

    // Early exit for 0-D corner case.
    if (viewMemRefType.getRank() == 0)
      return rewriter.replaceOp(viewOp, {targetMemRef}), success();

    // Fields 4 and 5: Update sizes and strides.
    Value stride = nullptr, nextSize = nullptr;
    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
      // Update size.
      Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
                           adaptor.getSizes(), i, indexType);
      targetMemRef.setSize(rewriter, loc, i, size);
      // Update stride.
      stride =
          getStride(rewriter, loc, strides, nextSize, stride, i, indexType);
      targetMemRef.setStride(rewriter, loc, i, stride);
      nextSize = size;
    }

    rewriter.replaceOp(viewOp, {targetMemRef});
    return success();
  }
};

//===----------------------------------------------------------------------===//
// AtomicRMWOpLowering
//===----------------------------------------------------------------------===//

/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
/// lowering to llvm.atomicrmw or fall back to llvm.cmpxchg.
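///
/// For example, an op like the following (a sketch) maps directly to
/// llvm.atomicrmw fadd:
///   %old = memref.atomic_rmw addf %val, %mem[%i]
///       : (f32, memref<10xf32>) -> f32
/// Kinds with no direct llvm.atomicrmw counterpart return std::nullopt and
/// are instead handled by the cmpxchg-based GenericAtomicRMWOpLowering.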
static std::optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
  switch (atomicOp.getKind()) {
  case arith::AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case arith::AtomicRMWKind::addi:
    return LLVM::AtomicBinOp::add;
  case arith::AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case arith::AtomicRMWKind::maximumf:
    // TODO: remove this by end of 2025.
    LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw maximumf changed "
                         "from fmax to fmaximum, expect more NaNs");
    return LLVM::AtomicBinOp::fmaximum;
  case arith::AtomicRMWKind::maxnumf:
    return LLVM::AtomicBinOp::fmax;
  case arith::AtomicRMWKind::maxs:
    return LLVM::AtomicBinOp::max;
  case arith::AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case arith::AtomicRMWKind::minimumf:
    // TODO: remove this by end of 2025.
    LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw minimumf changed "
                         "from fmin to fminimum, expect more NaNs");
    return LLVM::AtomicBinOp::fminimum;
  case arith::AtomicRMWKind::minnumf:
    return LLVM::AtomicBinOp::fmin;
  case arith::AtomicRMWKind::mins:
    return LLVM::AtomicBinOp::min;
  case arith::AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  case arith::AtomicRMWKind::ori:
    return LLVM::AtomicBinOp::_or;
  case arith::AtomicRMWKind::andi:
    return LLVM::AtomicBinOp::_and;
  default:
    return std::nullopt;
  }
  llvm_unreachable("Invalid AtomicRMWKind");
}

struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto maybeKind = matchSimpleAtomicOp(atomicOp);
    if (!maybeKind)
      return failure();
    auto memRefType = atomicOp.getMemRefType();
    SmallVector<int64_t> strides;
    int64_t offset;
    if (failed(memRefType.getStridesAndOffset(strides, offset)))
      return failure();
    auto dataPtr =
        getStridedElementPtr(rewriter, atomicOp.getLoc(), memRefType,
                             adaptor.getMemref(), adaptor.getIndices());
    rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
        atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
        LLVM::AtomicOrdering::acq_rel);
    return success();
  }
};

/// Unpack the pointer returned by a memref.extract_aligned_pointer_as_index.
class ConvertExtractAlignedPointerAsIndex
    : public ConvertOpToLLVMPattern<memref::ExtractAlignedPointerAsIndexOp> {
public:
  using ConvertOpToLLVMPattern<
      memref::ExtractAlignedPointerAsIndexOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    BaseMemRefType sourceTy = extractOp.getSource().getType();

    Value alignedPtr;
    if (sourceTy.hasRank()) {
      MemRefDescriptor desc(adaptor.getSource());
      alignedPtr = desc.alignedPtr(rewriter, extractOp->getLoc());
    } else {
      auto elementPtrTy = LLVM::LLVMPointerType::get(
          rewriter.getContext(), sourceTy.getMemorySpaceAsInt());

      UnrankedMemRefDescriptor desc(adaptor.getSource());
      Value descPtr = desc.memRefDescPtr(rewriter, extractOp->getLoc());

      alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
          rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr,
          elementPtrTy);
    }

    rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(
        extractOp, getTypeConverter()->getIndexType(), alignedPtr);
    return success();
  }
};

/// Materialize the MemRef descriptor represented by the results of
/// ExtractStridedMetadataOp.
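///
/// For example (a sketch; the trailing type list is abbreviated):
///   %base, %offset, %sizes:2, %strides:2 =
///       memref.extract_strided_metadata %arg : memref<?x?xf32> -> ...
/// is rewritten into direct reads of the base/aligned pointers, offset,
/// sizes, and strides already stored in the LLVM descriptor of %arg.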
class ExtractStridedMetadataOpLowering
    : public ConvertOpToLLVMPattern<memref::ExtractStridedMetadataOp> {
public:
  using ConvertOpToLLVMPattern<
      memref::ExtractStridedMetadataOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
      return failure();

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.getSource());
    Location loc = extractStridedMetadataOp.getLoc();
    Value source = extractStridedMetadataOp.getSource();

    auto sourceMemRefType = cast<MemRefType>(source.getType());
    int64_t rank = sourceMemRefType.getRank();
    SmallVector<Value> results;
    results.reserve(2 + rank * 2);

    // Base buffer.
    Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc);
    Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc);
    MemRefDescriptor dstMemRef = MemRefDescriptor::fromStaticShape(
        rewriter, loc, *getTypeConverter(),
        cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()),
        baseBuffer, alignedBuffer);
    results.push_back((Value)dstMemRef);

    // Offset.
    results.push_back(sourceMemRef.offset(rewriter, loc));

    // Sizes.
    for (unsigned i = 0; i < rank; ++i)
      results.push_back(sourceMemRef.size(rewriter, loc, i));
    // Strides.
    for (unsigned i = 0; i < rank; ++i)
      results.push_back(sourceMemRef.stride(rewriter, loc, i));

    rewriter.replaceOp(extractStridedMetadataOp, results);
    return success();
  }
};

} // namespace

void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    SymbolTableCollection *symbolTables) {
  // clang-format off
  patterns.add<
      AllocaOpLowering,
      AllocaScopeOpLowering,
      AtomicRMWOpLowering,
      AssumeAlignmentOpLowering,
      ConvertExtractAlignedPointerAsIndex,
      DimOpLowering,
      ExtractStridedMetadataOpLowering,
      GenericAtomicRMWOpLowering,
      GetGlobalMemrefOpLowering,
      LoadOpLowering,
      MemRefCastOpLowering,
      MemorySpaceCastOpLowering,
      MemRefReinterpretCastOpLowering,
      MemRefReshapeOpLowering,
      PrefetchOpLowering,
      RankOpLowering,
      ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
      ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
      StoreOpLowering,
      SubViewOpLowering,
      TransposeOpLowering,
      ViewOpLowering>(converter);
  // clang-format on
  patterns.add<DeallocOpLowering>(converter, symbolTables);

  auto allocLowering = converter.getOptions().allocLowering;
  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
    patterns.add<AlignedAllocOpLowering>(converter, symbolTables);
  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
    patterns.add<AllocOpLowering>(converter, symbolTables);
}

namespace {
struct FinalizeMemRefToLLVMConversionPass
    : public impl::FinalizeMemRefToLLVMConversionPassBase<
          FinalizeMemRefToLLVMConversionPass> {
  using FinalizeMemRefToLLVMConversionPassBase::
      FinalizeMemRefToLLVMConversionPassBase;

  void runOnOperation() override {
    Operation *op = getOperation();
    const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
    LowerToLLVMOptions options(&getContext(),
                               dataLayoutAnalysis.getAtOrAbove(op));
    options.allocLowering =
        (useAlignedAlloc
             ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
             : LowerToLLVMOptions::AllocLowering::Malloc);

    options.useGenericFunctions = useGenericFunctions;
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    LLVMTypeConverter typeConverter(&getContext(), options,
                                    &dataLayoutAnalysis);
    RewritePatternSet patterns(&getContext());
    SymbolTableCollection symbolTables;
    populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns,
                                                   &symbolTables);
    LLVMConversionTarget target(getContext());
    target.addLegalOp<func::FuncOp>();
    if (failed(applyPartialConversion(op, target, std::move(patterns))))
      signalPassFailure();
  }
};

/// Implement the interface to convert MemRef to LLVM.
struct MemRefToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;

  void loadDependentDialects(MLIRContext *context) const final {
    context->loadDialect<LLVM::LLVMDialect>();
  }

  /// Hook for derived dialect interface to provide conversion patterns
  /// and mark dialect legal for the conversion target.
  void populateConvertToLLVMConversionPatterns(
      ConversionTarget &target, LLVMTypeConverter &typeConverter,
      RewritePatternSet &patterns) const final {
    populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
  }
};
} // namespace

void mlir::registerConvertMemRefToLLVMInterface(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
    dialect->addInterfaces<MemRefToLLVMDialectInterface>();
  });
}
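// Example invocation of the pass defined above (a sketch; assumes a standard
// mlir-opt build and the option names declared for this pass in Passes.td):
//
//   mlir-opt --finalize-memref-to-llvm="use-aligned-alloc=1" input.mlir
//
// The `use-aligned-alloc`, `use-generic-functions`, and `index-bitwidth`
// options map to the pass members read in runOnOperation() above.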