diff options
author | Matthias Springer <mspringer@nvidia.com> | 2025-03-08 12:21:15 +0100 |
---|---|---|
committer | Matthias Springer <mspringer@nvidia.com> | 2025-04-01 01:52:53 +0200 |
commit | 4e7246ae3ac7166b40432828dcdc7123dffaadd6 (patch) | |
tree | 49dbf45b8abb6d88071ffdfddcf5776846d867a0 | |
parent | 799e9053641a6478d3144866a97737b37b87c260 (diff) | |
download | llvm-users/matthias-springer/memref_1_to_n.zip llvm-users/matthias-springer/memref_1_to_n.tar.gz llvm-users/matthias-springer/memref_1_to_n.tar.bz2 |
1:N memref to LLVMusers/matthias-springer/memref_1_to_n
update some more code
update
update
update
update
some progress
update
update
more improements
-rw-r--r-- | mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h | 57 | ||||
-rw-r--r-- | mlir/include/mlir/Conversion/LLVMCommon/Pattern.h | 2 | ||||
-rw-r--r-- | mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h | 20 | ||||
-rw-r--r-- | mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 5 | ||||
-rw-r--r-- | mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp | 40 | ||||
-rw-r--r-- | mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 80 | ||||
-rw-r--r-- | mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp | 304 | ||||
-rw-r--r-- | mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp | 24 | ||||
-rw-r--r-- | mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp | 263 | ||||
-rw-r--r-- | mlir/lib/Conversion/LLVMCommon/Pattern.cpp | 45 | ||||
-rw-r--r-- | mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp | 224 | ||||
-rw-r--r-- | mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp | 2 | ||||
-rw-r--r-- | mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 422 | ||||
-rw-r--r-- | mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 6 | ||||
-rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 77 |
15 files changed, 669 insertions, 902 deletions
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h index d5055f0..119106e 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h @@ -30,13 +30,13 @@ class LLVMPointerType; /// Helper class to produce LLVM dialect operations extracting or inserting /// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor. /// The Value may be null, in which case none of the operations are valid. -class MemRefDescriptor : public StructBuilder { +class MemRefDescriptor { public: /// Construct a helper for the given descriptor value. - explicit MemRefDescriptor(Value descriptor); + explicit MemRefDescriptor(ValueRange elements); /// Builds IR creating a `poison` value of the descriptor type. static MemRefDescriptor poison(OpBuilder &builder, Location loc, - Type descriptorType); + TypeRange descriptorTypes); /// Builds IR creating a MemRef descriptor that represents `type` and /// populates it with static shape and stride information extracted from the /// type. @@ -49,6 +49,11 @@ public: const LLVMTypeConverter &typeConverter, MemRefType type, Value memory, Value alignedMemory); + /// Builds IR extracting individual elements of a MemRef descriptor structure + /// and returning them as `results` list. + static MemRefDescriptor fromPackedStruct(OpBuilder &builder, Location loc, + Value packed); + /// Builds IR extracting the allocated pointer from the descriptor. Value allocatedPtr(OpBuilder &builder, Location loc); /// Builds IR inserting the allocated pointer into the descriptor. @@ -98,6 +103,8 @@ public: Value bufferPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type); + int64_t getRank(); + /// Builds IR populating a MemRef descriptor structure from a list of /// individual values composing that descriptor, in the following order: /// - allocated pointer; @@ -106,20 +113,21 @@ public: /// - <rank> sizes; /// - <rank> strides; /// where <rank> is the MemRef rank as provided in `type`. - static Value pack(OpBuilder &builder, Location loc, - const LLVMTypeConverter &converter, MemRefType type, - ValueRange values); - - /// Builds IR extracting individual elements of a MemRef descriptor structure - /// and returning them as `results` list. - static void unpack(OpBuilder &builder, Location loc, Value packed, - MemRefType type, SmallVectorImpl<Value> &results); + Value packStruct(OpBuilder &builder, Location loc); /// Returns the number of non-aggregate values that would be produced by /// `unpack`. static unsigned getNumUnpackedValues(MemRefType type); + ValueRange getElements() { return elements; } + + /*implicit*/ operator ValueRange() { return elements; } + private: + SmallVector<Value> elements; + // Value allocatedPtrVal, alignedPtrVal, offsetVal; + // SmallVector<Value> sizeVals, strideVals; + // Cached index type. Type indexType; }; @@ -155,13 +163,18 @@ private: ValueRange elements; }; -class UnrankedMemRefDescriptor : public StructBuilder { +class UnrankedMemRefDescriptor { public: /// Construct a helper for the given descriptor value. - explicit UnrankedMemRefDescriptor(Value descriptor); + explicit UnrankedMemRefDescriptor(ValueRange elements); /// Builds IR creating an `undef` value of the descriptor type. static UnrankedMemRefDescriptor poison(OpBuilder &builder, Location loc, - Type descriptorType); + TypeRange descriptorType); + + /// Builds IR extracting individual elements of a MemRef descriptor structure + /// and returning them as `results` list. + static UnrankedMemRefDescriptor fromPackedStruct(OpBuilder &builder, + Location loc, Value packed); /// Builds IR extracting the rank from the descriptor Value rank(OpBuilder &builder, Location loc) const; @@ -176,14 +189,7 @@ public: /// of individual constituent values in the following order: /// - rank of the memref; /// - pointer to the memref descriptor. - static Value pack(OpBuilder &builder, Location loc, - const LLVMTypeConverter &converter, UnrankedMemRefType type, - ValueRange values); - - /// Builds IR extracting individual elements that compose an unranked memref - /// descriptor and returns them as `results` list. - static void unpack(OpBuilder &builder, Location loc, Value packed, - SmallVectorImpl<Value> &results); + Value packStruct(OpBuilder &builder, Location loc); /// Returns the number of non-aggregate values that would be produced by /// `unpack`. @@ -269,6 +275,13 @@ public: static void setStride(OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, Value strideBasePtr, Value index, Value stride); + + ValueRange getElements() { return elements; } + + /*implicit*/ operator ValueRange() { return elements; } + +private: + SmallVector<Value> elements; }; } // namespace mlir diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h index e78f174..2d743a9 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h @@ -76,7 +76,7 @@ protected: // This is a strided getElementPtr variant that linearizes subscripts as: // `base_offset + index_0 * stride_0 + ... + index_n * stride_n`. - Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc, + Value getStridedElementPtr(Location loc, MemRefType type, ValueRange memRefDesc, ValueRange indices, ConversionPatternRewriter &rewriter) const; diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h index 38b5e49..a65f136 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h @@ -91,6 +91,17 @@ public: Type convertCallingConventionType(Type type, bool useBarePointerCallConv = false) const; + /// Convert a memref type into an LLVM type that captures the relevant data. + LogicalResult convertMemRefType(MemRefType type, + SmallVectorImpl<Type> &result, + bool packed = false) const; + + /// Convert an unranked memref type to an LLVM type that captures the + /// runtime rank and a pointer to the static ranked memref desc + LogicalResult convertUnrankedMemRefType(UnrankedMemRefType type, + SmallVectorImpl<Type> &result, + bool packed = false) const; + /// Promote the bare pointers in 'values' that resulted from memrefs to /// descriptors. 'stdTypes' holds the types of 'values' before the conversion /// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type). @@ -111,7 +122,7 @@ public: /// of the platform-specific C/C++ ABI lowering related to struct argument /// passing. SmallVector<Value, 4> promoteOperands(Location loc, ValueRange opOperands, - ValueRange operands, OpBuilder &builder, + ArrayRef<ValueRange> operands, OpBuilder &builder, bool useBarePtrCallConv = false) const; /// Promote the LLVM struct representation of one MemRef descriptor to stack @@ -245,13 +256,6 @@ private: /// `!llvm<"{ double, double }">`. `complex<bf16>` is not supported. Type convertComplexType(ComplexType type) const; - /// Convert a memref type into an LLVM type that captures the relevant data. - Type convertMemRefType(MemRefType type) const; - - /// Convert an unranked memref type to an LLVM type that captures the - /// runtime rank and a pointer to the static ranked memref desc - Type convertUnrankedMemRefType(UnrankedMemRefType type) const; - /// Convert a memref type to a bare pointer to the memref element type. Type convertMemRefToBarePtr(BaseMemRefType type) const; diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 3acd470..e5f70c4 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -178,6 +178,7 @@ struct FatRawBufferCastLowering LogicalResult matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + /* Location loc = op.getLoc(); Value memRef = adaptor.getSource(); Value unconvertedMemref = op.getSource(); @@ -222,7 +223,7 @@ struct FatRawBufferCastLowering Value fatPtr = makeBufferRsrc( rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(), - chipset, adaptor.getCacheSwizzleStride(), /*addressSpace=*/7); + chipset, adaptor.getCacheSwizzleStride(), /*addressSpace=7); Value result = MemRefDescriptor::poison( rewriter, loc, @@ -241,6 +242,8 @@ struct FatRawBufferCastLowering } rewriter.replaceOp(op, result); return success(); + */ + return failure(); } }; diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp index debfd00..bc6613d 100644 --- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp +++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp @@ -125,24 +125,35 @@ static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter, return rewriter.applySignatureConversion(block, *conversion, converter); } +static SmallVector<Value> flattenValueRanges(ArrayRef<ValueRange> ranges) { + SmallVector<Value> result; + for (ValueRange range : ranges) + llvm::append_range(result, range); + return result; +} + /// Convert the destination block signature (if necessary) and lower the branch /// op to llvm.br. struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> { using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern; + using Adaptor = + typename ConvertOpToLLVMPattern<cf::BranchOp>::OneToNOpAdaptor; LogicalResult - matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor, + matchAndRewrite(cf::BranchOp op, Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { + SmallVector<Value> flattenedOperands = + flattenValueRanges(adaptor.getOperands()); FailureOr<Block *> convertedBlock = getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(), - TypeRange(adaptor.getOperands())); + TypeRange(ValueRange(flattenedOperands))); if (failed(convertedBlock)) return failure(); Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>( - op, adaptor.getOperands(), *convertedBlock); + op, flattenedOperands, *convertedBlock); // TODO: We should not just forward all attributes like that. But there are // existing Flang tests that depend on this behavior. - newOp->setAttrs(op->getAttrDictionary()); + newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary()); return success(); } }; @@ -151,28 +162,33 @@ struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> { /// branch op to llvm.cond_br. struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> { using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern; + using Adaptor = + typename ConvertOpToLLVMPattern<cf::CondBranchOp>::OneToNOpAdaptor; LogicalResult - matchAndRewrite(cf::CondBranchOp op, - typename cf::CondBranchOp::Adaptor adaptor, + matchAndRewrite(cf::CondBranchOp op, Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { + SmallVector<Value> flattenedTrueDestOperands = + flattenValueRanges(adaptor.getTrueDestOperands()); FailureOr<Block *> convertedTrueBlock = getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(), - TypeRange(adaptor.getTrueDestOperands())); + TypeRange(ValueRange(flattenedTrueDestOperands))); if (failed(convertedTrueBlock)) return failure(); + SmallVector<Value> flattenedFalseDestOperands = + flattenValueRanges(adaptor.getFalseDestOperands()); FailureOr<Block *> convertedFalseBlock = getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(), - TypeRange(adaptor.getFalseDestOperands())); + TypeRange(ValueRange(flattenedFalseDestOperands))); if (failed(convertedFalseBlock)) return failure(); Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( - op, adaptor.getCondition(), *convertedTrueBlock, - adaptor.getTrueDestOperands(), *convertedFalseBlock, - adaptor.getFalseDestOperands()); + op, llvm::getSingleElement(adaptor.getCondition()), *convertedTrueBlock, + flattenedTrueDestOperands, *convertedFalseBlock, + flattenedFalseDestOperands); // TODO: We should not just forward all attributes like that. But there are // existing Flang tests that depend on this behavior. - newOp->setAttrs(op->getAttrDictionary()); + newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary()); return success(); } }; diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp index 55f0a9a..c5c0817 100644 --- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp +++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp @@ -140,15 +140,23 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, for (auto [index, argType] : llvm::enumerate(type.getInputs())) { Value arg = wrapperFuncOp.getArgument(index + argOffset); if (auto memrefType = dyn_cast<MemRefType>(argType)) { + SmallVector<Type> convertedType; + LogicalResult status = typeConverter.convertMemRefType(memrefType, convertedType, /*packed=*/true); + (void)status; + assert(succeeded(status) && "failed to convert memref type"); Value loaded = rewriter.create<LLVM::LoadOp>( - loc, typeConverter.convertType(memrefType), arg); - MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args); + loc, llvm::getSingleElement(convertedType), arg); + llvm::append_range(args, MemRefDescriptor::fromPackedStruct(rewriter, loc, loaded).getElements()); continue; } - if (isa<UnrankedMemRefType>(argType)) { + if (auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(argType)) { + SmallVector<Type> convertedType; + LogicalResult status = typeConverter.convertUnrankedMemRefType(unrankedMemrefType, convertedType, /*packed=*/true); + (void)status; + assert(succeeded(status) && "failed to convert memref type"); Value loaded = rewriter.create<LLVM::LoadOp>( - loc, typeConverter.convertType(argType), arg); - UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args); + loc, llvm::getSingleElement(convertedType), arg); + llvm::append_range(args, UnrankedMemRefDescriptor::fromPackedStruct(rewriter, loc, loaded).getElements()); continue; } @@ -231,14 +239,12 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc, numToDrop = memRefType ? MemRefDescriptor::getNumUnpackedValues(memRefType) : UnrankedMemRefDescriptor::getNumUnpackedValues(); - Value packed = - memRefType - ? MemRefDescriptor::pack(builder, loc, typeConverter, memRefType, - wrapperArgsRange.take_front(numToDrop)) - : UnrankedMemRefDescriptor::pack( - builder, loc, typeConverter, unrankedMemRefType, - wrapperArgsRange.take_front(numToDrop)); - + Value packed; + if (memRefType) { + packed = MemRefDescriptor(wrapperArgsRange.take_front(numToDrop)).packStruct(builder, loc); + } else { + packed = UnrankedMemRefDescriptor(wrapperArgsRange.take_front(numToDrop)).packStruct(builder, loc); + } auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext()); Value one = builder.create<LLVM::ConstantOp>( loc, typeConverter.convertType(builder.getIndexType()), @@ -515,9 +521,10 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> { using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern; using Super = CallOpInterfaceLowering<CallOpType>; using Base = ConvertOpToLLVMPattern<CallOpType>; + using Adaptor = typename ConvertOpToLLVMPattern<CallOpType>::OneToNOpAdaptor; LogicalResult matchAndRewriteImpl(CallOpType callOp, - typename CallOpType::Adaptor adaptor, + Adaptor adaptor, ConversionPatternRewriter &rewriter, bool useBarePtrCallConv = false) const { // Pack the result types into a struct. @@ -579,7 +586,18 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> { return failure(); } - rewriter.replaceOp(callOp, results); + SmallVector<SmallVector<Value>> unpackedResults; + for (auto it : llvm::zip_equal(resultTypes, results)) { + SmallVector<Value> &result = unpackedResults.emplace_back(); + if (isa<MemRefType>(std::get<0>(it))) { + llvm::append_range(result, MemRefDescriptor::fromPackedStruct(rewriter, callOp.getLoc(), std::get<1>(it)).getElements()); + } else if (isa<UnrankedMemRefType>(std::get<0>(it))) { + llvm::append_range(result, UnrankedMemRefDescriptor::fromPackedStruct(rewriter, callOp.getLoc(), std::get<1>(it)).getElements()); + } else { + result.push_back(std::get<1>(it)); + } + } + rewriter.replaceOpWithMultiple(callOp, unpackedResults); return success(); } }; @@ -593,7 +611,7 @@ public: symbolTable(symbolTable) {} LogicalResult - matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor, + matchAndRewrite(func::CallOp callOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { bool useBarePtrCallConv = false; if (getTypeConverter()->getOptions().useBarePtrCallConv) { @@ -623,7 +641,7 @@ struct CallIndirectOpLowering using Super::Super; LogicalResult - matchAndRewrite(func::CallIndirectOp callIndirectOp, OpAdaptor adaptor, + matchAndRewrite(func::CallIndirectOp callIndirectOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { return matchAndRewriteImpl(callIndirectOp, adaptor, rewriter); } @@ -666,7 +684,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> { using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, + matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); unsigned numArguments = op.getNumOperands(); @@ -680,20 +698,36 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> { // be returned from the memref descriptor. for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) { Type oldTy = std::get<0>(it).getType(); - Value newOperand = std::get<1>(it); + ValueRange adaptorVal = std::get<1>(it); if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr( cast<BaseMemRefType>(oldTy))) { - MemRefDescriptor memrefDesc(newOperand); - newOperand = memrefDesc.allocatedPtr(rewriter, loc); + MemRefDescriptor memrefDesc(adaptorVal); + updatedOperands.push_back( memrefDesc.allocatedPtr(rewriter, loc)); } else if (isa<UnrankedMemRefType>(oldTy)) { // Unranked memref is not supported in the bare pointer calling // convention. return failure(); + } else { + assert(adaptorVal.size() == 1 && "1:N conversion not supported for non-memref types"); + updatedOperands.push_back(adaptorVal.front()); } - updatedOperands.push_back(newOperand); } } else { - updatedOperands = llvm::to_vector<4>(adaptor.getOperands()); + // Pack operands. + for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) { + Value operand = std::get<0>(it); + ValueRange adaptorVal = std::get<1>(it); + if (isa<MemRefType>(operand.getType())) { + MemRefDescriptor memrefDesc(adaptorVal); + updatedOperands.push_back(memrefDesc.packStruct(rewriter, loc)); + } else if (isa<UnrankedMemRefType>(operand.getType())) { + UnrankedMemRefDescriptor unrankedMemrefDesc(adaptorVal); + updatedOperands.push_back(unrankedMemrefDesc.packStruct(rewriter, loc)); + } else { + assert(adaptorVal.size() == 1 && "1:N conversion not supported for non-memref types"); + updatedOperands.push_back(adaptorVal.front()); + } + } (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(), updatedOperands, /*toDynamic=*/true); diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp index f22ad1f..79bd1582 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -76,310 +76,8 @@ mlir::getOrCreateStringConstant(OpBuilder &b, Location loc, LogicalResult GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - Location loc = gpuFuncOp.getLoc(); - - SmallVector<LLVM::GlobalOp, 3> workgroupBuffers; - if (encodeWorkgroupAttributionsAsArguments) { - // Append an `llvm.ptr` argument to the function signature to encode - // workgroup attributions. - - ArrayRef<BlockArgument> workgroupAttributions = - gpuFuncOp.getWorkgroupAttributions(); - size_t numAttributions = workgroupAttributions.size(); - - // Insert all arguments at the end. - unsigned index = gpuFuncOp.getNumArguments(); - SmallVector<unsigned> argIndices(numAttributions, index); - - // New arguments will simply be `llvm.ptr` with the correct address space - Type workgroupPtrType = - rewriter.getType<LLVM::LLVMPointerType>(workgroupAddrSpace); - SmallVector<Type> argTypes(numAttributions, workgroupPtrType); - - // Attributes: noalias, llvm.mlir.workgroup_attribution(<size>, <type>) - std::array attrs{ - rewriter.getNamedAttr(LLVM::LLVMDialect::getNoAliasAttrName(), - rewriter.getUnitAttr()), - rewriter.getNamedAttr( - getDialect().getWorkgroupAttributionAttrHelper().getName(), - rewriter.getUnitAttr()), - }; - SmallVector<DictionaryAttr> argAttrs; - for (BlockArgument attribution : workgroupAttributions) { - auto attributionType = cast<MemRefType>(attribution.getType()); - IntegerAttr numElements = - rewriter.getI64IntegerAttr(attributionType.getNumElements()); - Type llvmElementType = - getTypeConverter()->convertType(attributionType.getElementType()); - if (!llvmElementType) - return failure(); - TypeAttr type = TypeAttr::get(llvmElementType); - attrs.back().setValue( - rewriter.getAttr<LLVM::WorkgroupAttributionAttr>(numElements, type)); - argAttrs.push_back(rewriter.getDictionaryAttr(attrs)); - } - // Location match function location - SmallVector<Location> argLocs(numAttributions, gpuFuncOp.getLoc()); - - // Perform signature modification - rewriter.modifyOpInPlace( - gpuFuncOp, [gpuFuncOp, &argIndices, &argTypes, &argAttrs, &argLocs]() { - static_cast<FunctionOpInterface>(gpuFuncOp).insertArguments( - argIndices, argTypes, argAttrs, argLocs); - }); - } else { - workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions()); - for (auto [idx, attribution] : - llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) { - auto type = dyn_cast<MemRefType>(attribution.getType()); - assert(type && type.hasStaticShape() && "unexpected type in attribution"); - - uint64_t numElements = type.getNumElements(); - - auto elementType = - cast<Type>(typeConverter->convertType(type.getElementType())); - auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements); - std::string name = - std::string(llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), idx)); - uint64_t alignment = 0; - if (auto alignAttr = dyn_cast_or_null<IntegerAttr>( - gpuFuncOp.getWorkgroupAttributionAttr( - idx, LLVM::LLVMDialect::getAlignAttrName()))) - alignment = alignAttr.getInt(); - auto globalOp = rewriter.create<LLVM::GlobalOp>( - gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false, - LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignment, - workgroupAddrSpace); - workgroupBuffers.push_back(globalOp); - } - } - - // Remap proper input types. - TypeConverter::SignatureConversion signatureConversion( - gpuFuncOp.front().getNumArguments()); - - Type funcType = getTypeConverter()->convertFunctionSignature( - gpuFuncOp.getFunctionType(), /*isVariadic=*/false, - getTypeConverter()->getOptions().useBarePtrCallConv, signatureConversion); - if (!funcType) { - return rewriter.notifyMatchFailure(gpuFuncOp, [&](Diagnostic &diag) { - diag << "failed to convert function signature type for: " - << gpuFuncOp.getFunctionType(); - }); - } - - // Create the new function operation. Only copy those attributes that are - // not specific to function modeling. - SmallVector<NamedAttribute, 4> attributes; - ArrayAttr argAttrs; - for (const auto &attr : gpuFuncOp->getAttrs()) { - if (attr.getName() == SymbolTable::getSymbolAttrName() || - attr.getName() == gpuFuncOp.getFunctionTypeAttrName() || - attr.getName() == - gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName() || - attr.getName() == gpuFuncOp.getWorkgroupAttribAttrsAttrName() || - attr.getName() == gpuFuncOp.getPrivateAttribAttrsAttrName() || - attr.getName() == gpuFuncOp.getKnownBlockSizeAttrName() || - attr.getName() == gpuFuncOp.getKnownGridSizeAttrName()) - continue; - if (attr.getName() == gpuFuncOp.getArgAttrsAttrName()) { - argAttrs = gpuFuncOp.getArgAttrsAttr(); - continue; - } - attributes.push_back(attr); - } - - DenseI32ArrayAttr knownBlockSize = gpuFuncOp.getKnownBlockSizeAttr(); - DenseI32ArrayAttr knownGridSize = gpuFuncOp.getKnownGridSizeAttr(); - // Ensure we don't lose information if the function is lowered before its - // surrounding context. - auto *gpuDialect = cast<gpu::GPUDialect>(gpuFuncOp->getDialect()); - if (knownBlockSize) - attributes.emplace_back(gpuDialect->getKnownBlockSizeAttrHelper().getName(), - knownBlockSize); - if (knownGridSize) - attributes.emplace_back(gpuDialect->getKnownGridSizeAttrHelper().getName(), - knownGridSize); - - // Add a dialect specific kernel attribute in addition to GPU kernel - // attribute. The former is necessary for further translation while the - // latter is expected by gpu.launch_func. - if (gpuFuncOp.isKernel()) { - if (kernelAttributeName) - attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr()); - // Set the dialect-specific block size attribute if there is one. - if (kernelBlockSizeAttributeName && knownBlockSize) { - attributes.emplace_back(kernelBlockSizeAttributeName, knownBlockSize); - } - } - LLVM::CConv callingConvention = gpuFuncOp.isKernel() - ? kernelCallingConvention - : nonKernelCallingConvention; - auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>( - gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType, - LLVM::Linkage::External, /*dsoLocal=*/false, callingConvention, - /*comdat=*/nullptr, attributes); - - { - // Insert operations that correspond to converted workgroup and private - // memory attributions to the body of the function. This must operate on - // the original function, before the body region is inlined in the new - // function to maintain the relation between block arguments and the - // parent operation that assigns their semantics. - OpBuilder::InsertionGuard guard(rewriter); - - // Rewrite workgroup memory attributions to addresses of global buffers. - rewriter.setInsertionPointToStart(&gpuFuncOp.front()); - unsigned numProperArguments = gpuFuncOp.getNumArguments(); - - if (encodeWorkgroupAttributionsAsArguments) { - // Build a MemRefDescriptor with each of the arguments added above. - - unsigned numAttributions = gpuFuncOp.getNumWorkgroupAttributions(); - assert(numProperArguments >= numAttributions && - "Expecting attributions to be encoded as arguments already"); - - // Arguments encoding workgroup attributions will be in positions - // [numProperArguments, numProperArguments+numAttributions) - ArrayRef<BlockArgument> attributionArguments = - gpuFuncOp.getArguments().slice(numProperArguments - numAttributions, - numAttributions); - for (auto [idx, vals] : llvm::enumerate(llvm::zip_equal( - gpuFuncOp.getWorkgroupAttributions(), attributionArguments))) { - auto [attribution, arg] = vals; - auto type = cast<MemRefType>(attribution.getType()); - - // Arguments are of llvm.ptr type and attributions are of memref type: - // we need to wrap them in memref descriptors. - Value descr = MemRefDescriptor::fromStaticShape( - rewriter, loc, *getTypeConverter(), type, arg); - - // And remap the arguments - signatureConversion.remapInput(numProperArguments + idx, descr); - } - } else { - for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) { - auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(), - global.getAddrSpace()); - Value address = rewriter.create<LLVM::AddressOfOp>( - loc, ptrType, global.getSymNameAttr()); - Value memory = - rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getType(), - address, ArrayRef<LLVM::GEPArg>{0, 0}); - - // Build a memref descriptor pointing to the buffer to plug with the - // existing memref infrastructure. This may use more registers than - // otherwise necessary given that memref sizes are fixed, but we can try - // and canonicalize that away later. - Value attribution = gpuFuncOp.getWorkgroupAttributions()[idx]; - auto type = cast<MemRefType>(attribution.getType()); - Value descr = MemRefDescriptor::fromStaticShape( - rewriter, loc, *getTypeConverter(), type, memory); - signatureConversion.remapInput(numProperArguments + idx, descr); - } - } - - // Rewrite private memory attributions to alloca'ed buffers. - unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions(); - auto int64Ty = IntegerType::get(rewriter.getContext(), 64); - for (const auto [idx, attribution] : - llvm::enumerate(gpuFuncOp.getPrivateAttributions())) { - auto type = cast<MemRefType>(attribution.getType()); - assert(type && type.hasStaticShape() && "unexpected type in attribution"); - - // Explicitly drop memory space when lowering private memory - // attributions since NVVM models it as `alloca`s in the default - // memory space and does not support `alloca`s with addrspace(5). - Type elementType = typeConverter->convertType(type.getElementType()); - auto ptrType = - LLVM::LLVMPointerType::get(rewriter.getContext(), allocaAddrSpace); - Value numElements = rewriter.create<LLVM::ConstantOp>( - gpuFuncOp.getLoc(), int64Ty, type.getNumElements()); - uint64_t alignment = 0; - if (auto alignAttr = - dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getPrivateAttributionAttr( - idx, LLVM::LLVMDialect::getAlignAttrName()))) - alignment = alignAttr.getInt(); - Value allocated = rewriter.create<LLVM::AllocaOp>( - gpuFuncOp.getLoc(), ptrType, elementType, numElements, alignment); - Value descr = MemRefDescriptor::fromStaticShape( - rewriter, loc, *getTypeConverter(), type, allocated); - signatureConversion.remapInput( - numProperArguments + numWorkgroupAttributions + idx, descr); - } - } - - // Move the region to the new function, update the entry block signature. - rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(), - llvmFuncOp.end()); - if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), *typeConverter, - &signatureConversion))) - return failure(); - - // Get memref type from function arguments and set the noalias to - // pointer arguments. - for (const auto [idx, argTy] : - llvm::enumerate(gpuFuncOp.getArgumentTypes())) { - auto remapping = signatureConversion.getInputMapping(idx); - NamedAttrList argAttr = - argAttrs ? cast<DictionaryAttr>(argAttrs[idx]) : NamedAttrList(); - auto copyAttribute = [&](StringRef attrName) { - Attribute attr = argAttr.erase(attrName); - if (!attr) - return; - for (size_t i = 0, e = remapping->size; i < e; ++i) - llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr); - }; - auto copyPointerAttribute = [&](StringRef attrName) { - Attribute attr = argAttr.erase(attrName); - - if (!attr) - return; - if (remapping->size > 1 && - attrName == LLVM::LLVMDialect::getNoAliasAttrName()) { - emitWarning(llvmFuncOp.getLoc(), - "Cannot copy noalias with non-bare pointers.\n"); - return; - } - for (size_t i = 0, e = remapping->size; i < e; ++i) { - if (isa<LLVM::LLVMPointerType>( - llvmFuncOp.getArgument(remapping->inputNo + i).getType())) { - llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr); - } - } - }; - - if (argAttr.empty()) - continue; - - copyAttribute(LLVM::LLVMDialect::getReturnedAttrName()); - copyAttribute(LLVM::LLVMDialect::getNoUndefAttrName()); - copyAttribute(LLVM::LLVMDialect::getInRegAttrName()); - bool lowersToPointer = false; - for (size_t i = 0, e = remapping->size; i < e; ++i) { - lowersToPointer |= isa<LLVM::LLVMPointerType>( - llvmFuncOp.getArgument(remapping->inputNo + i).getType()); - } - - if (lowersToPointer) { - copyPointerAttribute(LLVM::LLVMDialect::getNoAliasAttrName()); - copyPointerAttribute(LLVM::LLVMDialect::getNoCaptureAttrName()); - copyPointerAttribute(LLVM::LLVMDialect::getNoFreeAttrName()); - copyPointerAttribute(LLVM::LLVMDialect::getAlignAttrName()); - copyPointerAttribute(LLVM::LLVMDialect::getReadonlyAttrName()); - copyPointerAttribute(LLVM::LLVMDialect::getWriteOnlyAttrName()); - copyPointerAttribute(LLVM::LLVMDialect::getReadnoneAttrName()); - copyPointerAttribute(LLVM::LLVMDialect::getNonNullAttrName()); - copyPointerAttribute(LLVM::LLVMDialect::getDereferenceableAttrName()); - copyPointerAttribute( - LLVM::LLVMDialect::getDereferenceableOrNullAttrName()); - copyPointerAttribute( - LLVM::LLVMDialect::WorkgroupAttributionAttrHelper::getNameStr()); - } - } - rewriter.eraseOp(gpuFuncOp); - return success(); + return failure(); } LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp index 512820b..f0b1602 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -723,8 +723,10 @@ LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite( auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType(); auto elementSize = getSizeInBytes(loc, elementType, rewriter); - auto arguments = getTypeConverter()->promoteOperands( - loc, op->getOperands(), adaptor.getOperands(), rewriter); + llvm_unreachable("TODO"); + SmallVector<Value> arguments; + //auto arguments = getTypeConverter()->promoteOperands( + // loc, op->getOperands(), adaptor.getOperands(), rewriter); arguments.push_back(elementSize); hostRegisterCallBuilder.create(loc, rewriter, arguments); @@ -745,8 +747,10 @@ LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite( auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType(); auto elementSize = getSizeInBytes(loc, elementType, rewriter); - auto arguments = getTypeConverter()->promoteOperands( - loc, op->getOperands(), adaptor.getOperands(), rewriter); + llvm_unreachable("TODO"); + SmallVector<Value> arguments; + //auto arguments = getTypeConverter()->promoteOperands( + // loc, op->getOperands(), adaptor.getOperands(), rewriter); arguments.push_back(elementSize); hostUnregisterCallBuilder.create(loc, rewriter, arguments); @@ -805,9 +809,9 @@ LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite( if (allocOp.getAsyncToken()) { // Async alloc: make dependent ops use the same stream. - rewriter.replaceOp(allocOp, {memRefDescriptor, stream}); + //rewriter.replaceOp(allocOp, {memRefDescriptor, stream}); } else { - rewriter.replaceOp(allocOp, {memRefDescriptor}); + //rewriter.replaceOp(allocOp, {memRefDescriptor}); } return success(); @@ -977,9 +981,11 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite( // Note: If `useBarePtrCallConv` is set in the type converter's options, // the value of `kernelBarePtrCallConv` will be ignored. OperandRange origArguments = launchOp.getKernelOperands(); - SmallVector<Value, 8> llvmArguments = getTypeConverter()->promoteOperands( - loc, origArguments, adaptor.getKernelOperands(), rewriter, - /*useBarePtrCallConv=*/kernelBarePtrCallConv); + llvm_unreachable("TODO"); + SmallVector<Value,8> llvmArguments; + //SmallVector<Value, 8> llvmArguments = getTypeConverter()->promoteOperands( + // loc, origArguments, adaptor.getKernelOperands(), rewriter, + // /*useBarePtrCallConv=*/kernelBarePtrCallConv); SmallVector<Value, 8> llvmArgumentsWithSizes; // Intersperse size information if requested. diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp index 86d6643..9f8030a 100644 --- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp +++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp @@ -21,19 +21,23 @@ using namespace mlir; //===----------------------------------------------------------------------===// /// Construct a helper for the given descriptor value. -MemRefDescriptor::MemRefDescriptor(Value descriptor) - : StructBuilder(descriptor) { - assert(value != nullptr && "value cannot be null"); - indexType = cast<LLVM::LLVMStructType>(value.getType()) - .getBody()[kOffsetPosInMemRefDescriptor]; +MemRefDescriptor::MemRefDescriptor(ValueRange elements) : elements(elements) { + indexType = elements[kOffsetPosInMemRefDescriptor].getType(); } /// Builds IR creating an `undef` value of the descriptor type. MemRefDescriptor MemRefDescriptor::poison(OpBuilder &builder, Location loc, - Type descriptorType) { - - Value descriptor = builder.create<LLVM::PoisonOp>(loc, descriptorType); - return MemRefDescriptor(descriptor); + TypeRange descriptorTypes) { + DenseMap<Type, Value> poisonValues; + SmallVector<Value> elements; + for (Type t : descriptorTypes) { + auto it = poisonValues.find(t); + if (it == poisonValues.end()) { + poisonValues[t] = builder.create<LLVM::PoisonOp>(loc, t); + } + elements.push_back(poisonValues[t]); + } + return MemRefDescriptor(elements); } /// Builds IR creating a MemRef descriptor that represents `type` and @@ -57,10 +61,11 @@ MemRefDescriptor MemRefDescriptor::fromStaticShape( assert(!llvm::any_of(strides, ShapedType::isDynamic) && "expected static strides"); - auto convertedType = typeConverter.convertType(type); - assert(convertedType && "unexpected failure in memref type conversion"); + SmallVector<Type> convertedTypes; + LogicalResult status = typeConverter.convertType(type, convertedTypes); + assert(succeeded(status) && "unexpected failure in memref type conversion"); - auto descr = MemRefDescriptor::poison(builder, loc, convertedType); + auto descr = MemRefDescriptor::poison(builder, loc, convertedTypes); descr.setAllocatedPtr(builder, loc, memory); descr.setAlignedPtr(builder, loc, alignedMemory); descr.setConstantOffset(builder, loc, offset); @@ -73,26 +78,81 @@ MemRefDescriptor MemRefDescriptor::fromStaticShape( return descr; } +static Value extractStructElement(OpBuilder &builder, Location loc, + Value packed, ArrayRef<int64_t> idx) { + return builder.create<LLVM::ExtractValueOp>(loc, packed, idx); +} + +static Value insertStructElement(OpBuilder &builder, Location loc, Value packed, + Value val, ArrayRef<int64_t> idx) { + return builder.create<LLVM::InsertValueOp>(loc, packed, val, idx); +} +MemRefDescriptor MemRefDescriptor::fromPackedStruct(OpBuilder &builder, + Location loc, + Value packed) { + auto llvmStruct = cast<LLVM::LLVMStructType>(packed.getType()); + SmallVector<Value> elements; + elements.push_back(extractStructElement(builder, loc, packed, 0)); + elements.push_back(extractStructElement(builder, loc, packed, 1)); + elements.push_back(extractStructElement(builder, loc, packed, 2)); + if (llvmStruct.getBody().size() > 3) { + auto llvmArray = cast<LLVM::LLVMArrayType>(llvmStruct.getBody()[3]); + int64_t rank = llvmArray.getNumElements(); + for (int i = 0; i < rank; ++i) + elements.push_back(extractStructElement(builder, loc, packed, {3, i})); + for (int i = 0; i < rank; ++i) + elements.push_back(extractStructElement(builder, loc, packed, {4, i})); + } + return MemRefDescriptor(elements); +} + +Value MemRefDescriptor::packStruct(OpBuilder &builder, Location loc) { + Type offsetStrideTy = elements[2].getType(); + SmallVector<Type> fields; + fields.push_back(elements[0].getType()); + fields.push_back(elements[1].getType()); + fields.push_back(offsetStrideTy); + if (getRank() > 0) { + auto llvmArray = LLVM::LLVMArrayType::get(builder.getContext(), + offsetStrideTy, getRank()); + fields.push_back(llvmArray); + fields.push_back(llvmArray); + } + Value desc = builder.create<LLVM::UndefOp>( + loc, LLVM::LLVMStructType::getLiteral(builder.getContext(), fields)); + desc = insertStructElement(builder, loc, desc, elements[0], 0); + desc = insertStructElement(builder, loc, desc, elements[1], 1); + desc = insertStructElement(builder, loc, desc, elements[2], 2); + if(getRank() > 0) { + for (int i = 0; i < getRank(); ++i) + desc = insertStructElement(builder, loc, desc, elements[3 + i], {3, i}); + for (int i = 0; i < getRank(); ++i) + desc = insertStructElement(builder, loc, desc, elements[3 + getRank() + i], + {4, i}); + } + return desc; +} + /// Builds IR extracting the allocated pointer from the descriptor. Value MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) { - return extractPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor); + return elements[kAllocatedPtrPosInMemRefDescriptor]; } /// Builds IR inserting the allocated pointer into the descriptor. void MemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc, Value ptr) { - setPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor, ptr); + elements[kAllocatedPtrPosInMemRefDescriptor] = ptr; } /// Builds IR extracting the aligned pointer from the descriptor. Value MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) { - return extractPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor); + return elements[kAlignedPtrPosInMemRefDescriptor]; } /// Builds IR inserting the aligned pointer into the descriptor. void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc, Value ptr) { - setPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor, ptr); + elements[kAlignedPtrPosInMemRefDescriptor] = ptr; } // Creates a constant Op producing a value of `resultType` from an index-typed @@ -105,28 +165,25 @@ static Value createIndexAttrConstant(OpBuilder &builder, Location loc, /// Builds IR extracting the offset from the descriptor. Value MemRefDescriptor::offset(OpBuilder &builder, Location loc) { - return builder.create<LLVM::ExtractValueOp>(loc, value, - kOffsetPosInMemRefDescriptor); + return elements[kOffsetPosInMemRefDescriptor]; } /// Builds IR inserting the offset into the descriptor. void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc, Value offset) { - value = builder.create<LLVM::InsertValueOp>(loc, value, offset, - kOffsetPosInMemRefDescriptor); + elements[kOffsetPosInMemRefDescriptor] = offset; } /// Builds IR inserting the offset into the descriptor. void MemRefDescriptor::setConstantOffset(OpBuilder &builder, Location loc, uint64_t offset) { - setOffset(builder, loc, - createIndexAttrConstant(builder, loc, indexType, offset)); + elements[kOffsetPosInMemRefDescriptor] = + createIndexAttrConstant(builder, loc, indexType, offset); } /// Builds IR extracting the pos-th size from the descriptor. Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) { - return builder.create<LLVM::ExtractValueOp>( - loc, value, ArrayRef<int64_t>({kSizePosInMemRefDescriptor, pos})); + return elements[kSizePosInMemRefDescriptor + pos]; } Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos, @@ -137,8 +194,14 @@ Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos, // Copy size values to stack-allocated memory. auto one = createIndexAttrConstant(builder, loc, indexType, 1); - auto sizes = builder.create<LLVM::ExtractValueOp>( - loc, value, llvm::ArrayRef<int64_t>({kSizePosInMemRefDescriptor})); + SmallVector<Type> structElems(rank, indexType); + Value sizes = builder.create<LLVM::UndefOp>( + loc, LLVM::LLVMStructType::getLiteral(builder.getContext(), structElems)); + ValueRange sizeVals = + ValueRange(elements).slice(kSizePosInMemRefDescriptor, rank); + for (auto it : llvm::enumerate(sizeVals)) + sizes = + builder.create<LLVM::InsertValueOp>(loc, sizes, it.value(), it.index()); auto sizesPtr = builder.create<LLVM::AllocaOp>(loc, ptrTy, arrayTy, one, /*alignment=*/0); builder.create<LLVM::StoreOp>(loc, sizes, sizesPtr); @@ -152,40 +215,35 @@ Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos, /// Builds IR inserting the pos-th size into the descriptor void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos, Value size) { - value = builder.create<LLVM::InsertValueOp>( - loc, value, size, ArrayRef<int64_t>({kSizePosInMemRefDescriptor, pos})); + elements[kSizePosInMemRefDescriptor + pos] = size; } void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc, unsigned pos, uint64_t size) { - setSize(builder, loc, pos, - createIndexAttrConstant(builder, loc, indexType, size)); + elements[kSizePosInMemRefDescriptor + pos] = + createIndexAttrConstant(builder, loc, indexType, size); } /// Builds IR extracting the pos-th stride from the descriptor. Value MemRefDescriptor::stride(OpBuilder &builder, Location loc, unsigned pos) { - return builder.create<LLVM::ExtractValueOp>( - loc, value, ArrayRef<int64_t>({kStridePosInMemRefDescriptor, pos})); + return elements[kSizePosInMemRefDescriptor + getRank() + pos]; } /// Builds IR inserting the pos-th stride into the descriptor void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos, Value stride) { - value = builder.create<LLVM::InsertValueOp>( - loc, value, stride, - ArrayRef<int64_t>({kStridePosInMemRefDescriptor, pos})); + elements[kSizePosInMemRefDescriptor + getRank() + pos] = stride; } void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc, unsigned pos, uint64_t stride) { - setStride(builder, loc, pos, - createIndexAttrConstant(builder, loc, indexType, stride)); + elements[kSizePosInMemRefDescriptor + getRank() + pos] = + createIndexAttrConstant(builder, loc, indexType, stride); } LLVM::LLVMPointerType MemRefDescriptor::getElementPtrType() { return cast<LLVM::LLVMPointerType>( - cast<LLVM::LLVMStructType>(value.getType()) - .getBody()[kAlignedPtrPosInMemRefDescriptor]); + elements[kAlignedPtrPosInMemRefDescriptor].getType()); } Value MemRefDescriptor::bufferPtr(OpBuilder &builder, Location loc, @@ -212,51 +270,6 @@ Value MemRefDescriptor::bufferPtr(OpBuilder &builder, Location loc, return ptr; } -/// Creates a MemRef descriptor structure from a list of individual values -/// composing that descriptor, in the following order: -/// - allocated pointer; -/// - aligned pointer; -/// - offset; -/// - <rank> sizes; -/// - <rank> strides; -/// where <rank> is the MemRef rank as provided in `type`. -Value MemRefDescriptor::pack(OpBuilder &builder, Location loc, - const LLVMTypeConverter &converter, - MemRefType type, ValueRange values) { - Type llvmType = converter.convertType(type); - auto d = MemRefDescriptor::poison(builder, loc, llvmType); - - d.setAllocatedPtr(builder, loc, values[kAllocatedPtrPosInMemRefDescriptor]); - d.setAlignedPtr(builder, loc, values[kAlignedPtrPosInMemRefDescriptor]); - d.setOffset(builder, loc, values[kOffsetPosInMemRefDescriptor]); - - int64_t rank = type.getRank(); - for (unsigned i = 0; i < rank; ++i) { - d.setSize(builder, loc, i, values[kSizePosInMemRefDescriptor + i]); - d.setStride(builder, loc, i, values[kSizePosInMemRefDescriptor + rank + i]); - } - - return d; -} - -/// Builds IR extracting individual elements of a MemRef descriptor structure -/// and returning them as `results` list. -void MemRefDescriptor::unpack(OpBuilder &builder, Location loc, Value packed, - MemRefType type, - SmallVectorImpl<Value> &results) { - int64_t rank = type.getRank(); - results.reserve(results.size() + getNumUnpackedValues(type)); - - MemRefDescriptor d(packed); - results.push_back(d.allocatedPtr(builder, loc)); - results.push_back(d.alignedPtr(builder, loc)); - results.push_back(d.offset(builder, loc)); - for (int64_t i = 0; i < rank; ++i) - results.push_back(d.size(builder, loc, i)); - for (int64_t i = 0; i < rank; ++i) - results.push_back(d.stride(builder, loc, i)); -} - /// Returns the number of non-aggregate values that would be produced by /// `unpack`. unsigned MemRefDescriptor::getNumUnpackedValues(MemRefType type) { @@ -264,6 +277,8 @@ unsigned MemRefDescriptor::getNumUnpackedValues(MemRefType type) { return 3 + 2 * type.getRank(); } +int64_t MemRefDescriptor::getRank() { return (elements.size() - 3) / 2; } + //===----------------------------------------------------------------------===// // MemRefDescriptorView implementation. //===----------------------------------------------------------------------===// @@ -296,57 +311,61 @@ Value MemRefDescriptorView::stride(unsigned pos) { //===----------------------------------------------------------------------===// /// Construct a helper for the given descriptor value. -UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value descriptor) - : StructBuilder(descriptor) {} +UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(ValueRange elements) + : elements(elements) {} /// Builds IR creating an `undef` value of the descriptor type. -UnrankedMemRefDescriptor UnrankedMemRefDescriptor::poison(OpBuilder &builder, - Location loc, - Type descriptorType) { - Value descriptor = builder.create<LLVM::PoisonOp>(loc, descriptorType); - return UnrankedMemRefDescriptor(descriptor); +UnrankedMemRefDescriptor +UnrankedMemRefDescriptor::poison(OpBuilder &builder, Location loc, + TypeRange descriptorTypes) { + DenseMap<Type, Value> poisonValues; + SmallVector<Value> elements; + for (Type t : descriptorTypes) { + auto it = poisonValues.find(t); + if (it == poisonValues.end()) { + poisonValues[t] = builder.create<LLVM::PoisonOp>(loc, t); + } + elements.push_back(poisonValues[t]); + } + return UnrankedMemRefDescriptor(elements); +} + +/// Builds IR extracting individual elements of a MemRef descriptor structure +/// and returning them as `results` list. +UnrankedMemRefDescriptor +UnrankedMemRefDescriptor::fromPackedStruct(OpBuilder &builder, Location loc, + Value packed) { + SmallVector<Value> elements; + elements.push_back(extractStructElement(builder, loc, packed, 0)); + elements.push_back(extractStructElement(builder, loc, packed, 1)); + return UnrankedMemRefDescriptor(elements); +} + +Value UnrankedMemRefDescriptor::packStruct(OpBuilder &builder, Location loc) { + SmallVector<Type> fields; + fields.push_back(elements[0].getType()); + fields.push_back(elements[1].getType()); + Value desc = builder.create<LLVM::UndefOp>( + loc, LLVM::LLVMStructType::getLiteral(builder.getContext(), fields)); + desc = insertStructElement(builder, loc, desc, elements[0], 0); + desc = insertStructElement(builder, loc, desc, elements[1], 1); + return desc; } + Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) const { - return extractPtr(builder, loc, kRankInUnrankedMemRefDescriptor); + return elements[kRankInUnrankedMemRefDescriptor]; } void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc, Value v) { - setPtr(builder, loc, kRankInUnrankedMemRefDescriptor, v); + elements[kRankInUnrankedMemRefDescriptor] = v; } Value UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder, Location loc) const { - return extractPtr(builder, loc, kPtrInUnrankedMemRefDescriptor); + return elements[kPtrInUnrankedMemRefDescriptor]; } void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder, Location loc, Value v) { - setPtr(builder, loc, kPtrInUnrankedMemRefDescriptor, v); -} - -/// Builds IR populating an unranked MemRef descriptor structure from a list -/// of individual constituent values in the following order: -/// - rank of the memref; -/// - pointer to the memref descriptor. -Value UnrankedMemRefDescriptor::pack(OpBuilder &builder, Location loc, - const LLVMTypeConverter &converter, - UnrankedMemRefType type, - ValueRange values) { - Type llvmType = converter.convertType(type); - auto d = UnrankedMemRefDescriptor::poison(builder, loc, llvmType); - - d.setRank(builder, loc, values[kRankInUnrankedMemRefDescriptor]); - d.setMemRefDescPtr(builder, loc, values[kPtrInUnrankedMemRefDescriptor]); - return d; -} - -/// Builds IR extracting individual elements that compose an unranked memref -/// descriptor and returns them as `results` list. -void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc, - Value packed, - SmallVectorImpl<Value> &results) { - UnrankedMemRefDescriptor d(packed); - results.reserve(results.size() + 2); - results.push_back(d.rank(builder, loc)); - results.push_back(d.memRefDescPtr(builder, loc)); + elements[kPtrInUnrankedMemRefDescriptor] = v; } void UnrankedMemRefDescriptor::computeSizes( diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp index 71b68619..c5af470 100644 --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -59,7 +59,7 @@ Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder, } Value ConvertToLLVMPattern::getStridedElementPtr( - Location loc, MemRefType type, Value memRefDesc, ValueRange indices, + Location loc, MemRefType type, ValueRange memRefDesc, ValueRange indices, ConversionPatternRewriter &rewriter) const { auto [strides, offset] = type.getStridesAndOffset(); @@ -217,34 +217,20 @@ MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor( Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, ArrayRef<Value> sizes, ArrayRef<Value> strides, ConversionPatternRewriter &rewriter) const { - auto structType = typeConverter->convertType(memRefType); - auto memRefDescriptor = MemRefDescriptor::poison(rewriter, loc, structType); - - // Field 1: Allocated pointer, used for malloc/free. - memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr); - - // Field 2: Actual aligned pointer to payload. - memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr); - - // Field 3: Offset in aligned pointer. + SmallVector<Value> elements; + elements.push_back(allocatedPtr); + elements.push_back(alignedPtr); Type indexType = getIndexType(); - memRefDescriptor.setOffset( - rewriter, loc, createIndexAttrConstant(rewriter, loc, indexType, 0)); - - // Fields 4: Sizes. - for (const auto &en : llvm::enumerate(sizes)) - memRefDescriptor.setSize(rewriter, loc, en.index(), en.value()); - - // Field 5: Strides. - for (const auto &en : llvm::enumerate(strides)) - memRefDescriptor.setStride(rewriter, loc, en.index(), en.value()); - - return memRefDescriptor; + elements.push_back(createIndexAttrConstant(rewriter, loc, indexType, 0)); + llvm::append_range(elements, sizes); + llvm::append_range(elements, strides); + return MemRefDescriptor(elements); } LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( OpBuilder &builder, Location loc, TypeRange origTypes, SmallVectorImpl<Value> &operands, bool toDynamic) const { + // TODO: Pass unpacked structs to this function. assert(origTypes.size() == operands.size() && "expected as may original types as operands"); @@ -253,7 +239,7 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( SmallVector<unsigned> unrankedAddressSpaces; for (unsigned i = 0, e = operands.size(); i < e; ++i) { if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) { - unrankedMemrefs.emplace_back(operands[i]); + unrankedMemrefs.push_back(UnrankedMemRefDescriptor::fromPackedStruct(builder, loc, operands[i])); FailureOr<unsigned> addressSpace = getTypeConverter()->getMemRefAddressSpace(memRefType); if (failed(addressSpace)) @@ -294,7 +280,7 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( if (!isa<UnrankedMemRefType>(type)) continue; Value allocationSize = sizes[unrankedMemrefPos++]; - UnrankedMemRefDescriptor desc(operands[i]); + UnrankedMemRefDescriptor desc = UnrankedMemRefDescriptor::fromPackedStruct(builder, loc, operands[i]); // Allocate memory, copy, and free the source if necessary. Value memory = @@ -315,16 +301,15 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( // times, attempting to modify its pointer can lead to memory leaks // (allocated twice and overwritten) or double frees (the caller does not // know if the descriptor points to the same memory). - Type descriptorType = getTypeConverter()->convertType(type); - if (!descriptorType) + SmallVector<Type> descriptorTypes; + if (failed(getTypeConverter()->convertType(type, descriptorTypes))) return failure(); auto updatedDesc = - UnrankedMemRefDescriptor::poison(builder, loc, descriptorType); + UnrankedMemRefDescriptor::poison(builder, loc, descriptorTypes); Value rank = desc.rank(builder, loc); updatedDesc.setRank(builder, loc, rank); updatedDesc.setMemRefDescPtr(builder, loc, memory); - - operands[i] = updatedDesc; + operands[i] = updatedDesc.packStruct(builder, loc); } return success(); diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp index ea251e4..2113bd3 100644 --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -50,68 +50,6 @@ static bool isBarePointer(ValueRange values) { isa<LLVM::LLVMPointerType>(values.front().getType()); } -/// Pack SSA values into an unranked memref descriptor struct. -static Value packUnrankedMemRefDesc(OpBuilder &builder, - UnrankedMemRefType resultType, - ValueRange inputs, Location loc, - const LLVMTypeConverter &converter) { - // Note: Bare pointers are not supported for unranked memrefs because a - // memref descriptor cannot be built just from a bare pointer. - if (TypeRange(inputs) != converter.getUnrankedMemRefDescriptorFields()) - return Value(); - return UnrankedMemRefDescriptor::pack(builder, loc, converter, resultType, - inputs); -} - -/// Pack SSA values into a ranked memref descriptor struct. -static Value packRankedMemRefDesc(OpBuilder &builder, MemRefType resultType, - ValueRange inputs, Location loc, - const LLVMTypeConverter &converter) { - assert(resultType && "expected non-null result type"); - if (isBarePointer(inputs)) - return MemRefDescriptor::fromStaticShape(builder, loc, converter, - resultType, inputs[0]); - if (TypeRange(inputs) == - converter.getMemRefDescriptorFields(resultType, - /*unpackAggregates=*/true)) - return MemRefDescriptor::pack(builder, loc, converter, resultType, inputs); - // The inputs are neither a bare pointer nor an unpacked memref descriptor. - // This materialization function cannot be used. - return Value(); -} - -/// MemRef descriptor elements -> UnrankedMemRefType -static Value unrankedMemRefMaterialization(OpBuilder &builder, - UnrankedMemRefType resultType, - ValueRange inputs, Location loc, - const LLVMTypeConverter &converter) { - // A source materialization must return a value of type - // `resultType`, so insert a cast from the memref descriptor type - // (!llvm.struct) to the original memref type. - Value packed = - packUnrankedMemRefDesc(builder, resultType, inputs, loc, converter); - if (!packed) - return Value(); - return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed) - .getResult(0); -} - -/// MemRef descriptor elements -> MemRefType -static Value rankedMemRefMaterialization(OpBuilder &builder, - MemRefType resultType, - ValueRange inputs, Location loc, - const LLVMTypeConverter &converter) { - // A source materialization must return a value of type `resultType`, - // so insert a cast from the memref descriptor type (!llvm.struct) to the - // original memref type. - Value packed = - packRankedMemRefDesc(builder, resultType, inputs, loc, converter); - if (!packed) - return Value(); - return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed) - .getResult(0); -} - /// Create an LLVMTypeConverter using custom LowerToLLVMOptions. LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options, @@ -126,9 +64,22 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, addConversion([&](FunctionType type) { return convertFunctionType(type); }); addConversion([&](IndexType type) { return convertIndexType(type); }); addConversion([&](IntegerType type) { return convertIntegerType(type); }); - addConversion([&](MemRefType type) { return convertMemRefType(type); }); addConversion( - [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); }); + [&](MemRefType type, + SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> { + LogicalResult status = convertMemRefType(type, result); + if (failed(status)) + return std::nullopt; + return success(); + }); + addConversion( + [&](UnrankedMemRefType type, + SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> { + LogicalResult status = convertUnrankedMemRefType(type, result); + if (failed(status)) + return std::nullopt; + return success(); + }); addConversion([&](VectorType type) -> std::optional<Type> { FailureOr<Type> llvmType = convertVectorType(type); if (failed(llvmType)) @@ -228,42 +179,26 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs) .getResult(0); }); - addTargetMaterialization([&](OpBuilder &builder, Type resultType, - ValueRange inputs, Location loc) { - return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs) - .getResult(0); + addTargetMaterialization([&](OpBuilder &builder, TypeRange resultTypes, + ValueRange inputs, + Location loc) -> SmallVector<Value> { + auto castOp = + builder.create<UnrealizedConversionCastOp>(loc, resultTypes, inputs); + return llvm::map_to_vector(castOp.getResults(), + [](OpResult r) -> Value { return r; }); }); - // Source materializations convert from the new block argument types - // (multiple SSA values that make up a memref descriptor) back to the - // original block argument type. - addSourceMaterialization([&](OpBuilder &builder, - UnrankedMemRefType resultType, ValueRange inputs, - Location loc) { - return unrankedMemRefMaterialization(builder, resultType, inputs, loc, - *this); - }); addSourceMaterialization([&](OpBuilder &builder, MemRefType resultType, ValueRange inputs, Location loc) { - return rankedMemRefMaterialization(builder, resultType, inputs, loc, *this); - }); - - // Bare pointer -> Packed MemRef descriptor - addTargetMaterialization([&](OpBuilder &builder, Type resultType, - ValueRange inputs, Location loc, - Type originalType) -> Value { - // The original MemRef type is required to build a MemRef descriptor - // because the sizes/strides of the MemRef cannot be inferred from just the - // bare pointer. - if (!originalType) - return Value(); - if (resultType != convertType(originalType)) - return Value(); - if (auto memrefType = dyn_cast<MemRefType>(originalType)) - return packRankedMemRefDesc(builder, memrefType, inputs, loc, *this); - if (auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(originalType)) - return packUnrankedMemRefDesc(builder, unrankedMemrefType, inputs, loc, - *this); + if (isBarePointer(inputs)) { + MemRefDescriptor desc = MemRefDescriptor::fromStaticShape( + builder, loc, *this, resultType, inputs[0]); + return builder + .create<UnrealizedConversionCastOp>(loc, resultType, + desc.getElements()) + .getResult(0); + } + // Default materialization creates unrealized_conversion_cast. return Value(); }); @@ -430,8 +365,10 @@ LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) const { Type resultType = type.getNumResults() == 0 ? LLVM::LLVMVoidType::get(&getContext()) : packFunctionResults(type.getResults()); - if (!resultType) + if (!resultType) { + llvm_unreachable("no result type!"); return {}; + } auto ptrType = LLVM::LLVMPointerType::get(type.getContext()); auto structType = dyn_cast<LLVM::LLVMStructType>(resultType); @@ -443,9 +380,11 @@ LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) const { } for (Type t : type.getInputs()) { - auto converted = convertType(t); - if (!converted || !LLVM::isCompatibleType(converted)) + auto converted = convertCallingConventionType(t); + if (!converted || !LLVM::isCompatibleType(converted)) { + llvm_unreachable("could not convert input!"); return {}; + } if (isa<MemRefType, UnrankedMemRefType>(t)) converted = ptrType; inputs.push_back(converted); @@ -533,14 +472,18 @@ LLVMTypeConverter::getMemRefDescriptorSize(MemRefType type, /// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that /// packs the descriptor fields as defined by `getMemRefDescriptorFields`. -Type LLVMTypeConverter::convertMemRefType(MemRefType type) const { - // When converting a MemRefType to a struct with descriptor fields, do not - // unpack the `sizes` and `strides` arrays. - SmallVector<Type, 5> types = - getMemRefDescriptorFields(type, /*unpackAggregates=*/false); - if (types.empty()) - return {}; - return LLVM::LLVMStructType::getLiteral(&getContext(), types); +LogicalResult LLVMTypeConverter::convertMemRefType( + MemRefType type, SmallVectorImpl<Type> &result, bool packed) const { + SmallVector<Type, 5> fields = + getMemRefDescriptorFields(type, /*unpackAggregates=*/!packed); + if (fields.empty()) + return failure(); + if (packed) { + result.push_back(LLVM::LLVMStructType::getLiteral(&getContext(), fields)); + } else { + llvm::append_range(result, fields); + } + return success(); } /// Convert an unranked memref type into a list of non-aggregate LLVM IR types @@ -563,12 +506,17 @@ unsigned LLVMTypeConverter::getUnrankedMemRefDescriptorSize( llvm::divideCeil(getPointerBitwidth(space), 8); } -Type LLVMTypeConverter::convertUnrankedMemRefType( - UnrankedMemRefType type) const { +LogicalResult LLVMTypeConverter::convertUnrankedMemRefType( + UnrankedMemRefType type, SmallVectorImpl<Type> &result, bool packed) const { if (!convertType(type.getElementType())) - return {}; - return LLVM::LLVMStructType::getLiteral(&getContext(), - getUnrankedMemRefDescriptorFields()); + return failure(); + if (packed) { + result.push_back(LLVM::LLVMStructType::getLiteral( + &getContext(), getUnrankedMemRefDescriptorFields())); + } else { + llvm::append_range(result, getUnrankedMemRefDescriptorFields()); + } + return success(); } FailureOr<unsigned> @@ -665,6 +613,20 @@ Type LLVMTypeConverter::convertCallingConventionType( if (auto memrefTy = dyn_cast<BaseMemRefType>(type)) return convertMemRefToBarePtr(memrefTy); + if (auto memrefTy = dyn_cast<MemRefType>(type)) { + SmallVector<Type> convertedType; + LogicalResult status = convertMemRefType(memrefTy, convertedType, true); + if (failed(status)) return Type(); + return llvm::getSingleElement(convertedType); + } + + if (auto unrankedMemrefTy = dyn_cast<UnrankedMemRefType>(type)) { + SmallVector<Type> convertedType; + LogicalResult status = convertUnrankedMemRefType(unrankedMemrefTy, convertedType, true); + if (failed(status)) return Type(); + return llvm::getSingleElement(convertedType); + } + return convertType(type); } @@ -674,12 +636,15 @@ Type LLVMTypeConverter::convertCallingConventionType( void LLVMTypeConverter::promoteBarePtrsToDescriptors( ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes, SmallVectorImpl<Value> &values) const { - assert(stdTypes.size() == values.size() && - "The number of types and values doesn't match"); - for (unsigned i = 0, end = values.size(); i < end; ++i) - if (auto memrefTy = dyn_cast<MemRefType>(stdTypes[i])) - values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this, - memrefTy, values[i]); + /* + assert(stdTypes.size() == values.size() && + "The number of types and values doesn't match"); + for (unsigned i = 0, end = values.size(); i < end; ++i) + if (auto memrefTy = dyn_cast<MemRefType>(stdTypes[i])) + values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this, + memrefTy, values[i]); + */ + llvm_unreachable("not implemented"); } /// Convert a non-empty list of types of values produced by an operation into an @@ -743,38 +708,27 @@ Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand, SmallVector<Value, 4> LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands, - ValueRange operands, OpBuilder &builder, + ArrayRef<ValueRange> operands, OpBuilder &builder, bool useBarePtrCallConv) const { SmallVector<Value, 4> promotedOperands; promotedOperands.reserve(operands.size()); useBarePtrCallConv |= options.useBarePtrCallConv; for (auto it : llvm::zip(opOperands, operands)) { auto operand = std::get<0>(it); - auto llvmOperand = std::get<1>(it); + auto llvmOperands = std::get<1>(it); if (useBarePtrCallConv) { // For the bare-ptr calling convention, we only have to extract the // aligned pointer of a memref. if (dyn_cast<MemRefType>(operand.getType())) { - MemRefDescriptor desc(llvmOperand); - llvmOperand = desc.alignedPtr(builder, loc); + MemRefDescriptor desc(llvmOperands); + promotedOperands.push_back(desc.alignedPtr(builder, loc)); + continue; } else if (isa<UnrankedMemRefType>(operand.getType())) { llvm_unreachable("Unranked memrefs are not supported"); } - } else { - if (isa<UnrankedMemRefType>(operand.getType())) { - UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand, - promotedOperands); - continue; - } - if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) { - MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType, - promotedOperands); - continue; - } } - - promotedOperands.push_back(llvmOperand); + llvm::append_range(promotedOperands, llvmOperands); } return promotedOperands; } diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp index c5b2e83..c072723 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp @@ -195,6 +195,6 @@ LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite( loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter); // Return the final value of the descriptor. - rewriter.replaceOp(op, {memRefDescriptor}); + rewriter.replaceOpWithMultiple(op, {memRefDescriptor}); return success(); } diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index cb4317e..a12507b 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -25,6 +25,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/IRMapping.h" #include "mlir/Pass/Pass.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallBitVector.h" #include "llvm/Support/MathExtras.h" #include <optional> @@ -185,15 +186,14 @@ struct AssumeAlignmentOpLowering : ConvertOpToLLVMPattern<memref::AssumeAlignmentOp>(converter) {} LogicalResult - matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor, + matchAndRewrite(memref::AssumeAlignmentOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Value memref = adaptor.getMemref(); unsigned alignment = op.getAlignment(); auto loc = op.getLoc(); auto srcMemRefType = cast<MemRefType>(op.getMemref().getType()); - Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, /*indices=*/{}, - rewriter); + Value ptr = getStridedElementPtr(loc, srcMemRefType, adaptor.getMemref(), + /*indices=*/{}, rewriter); // Emit llvm.assume(true) ["align"(memref, alignment)]. // This is more direct than ptrtoint-based checks, is explicitly supported, @@ -220,7 +220,7 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> { : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {} LogicalResult - matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor, + matchAndRewrite(memref::DeallocOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Insert the `free` declaration if it is not already present. FailureOr<LLVM::LLVMFuncOp> freeFunc = @@ -253,21 +253,20 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> { using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor, + matchAndRewrite(memref::DimOp dimOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type operandType = dimOp.getSource().getType(); if (isa<UnrankedMemRefType>(operandType)) { - FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef( - operandType, dimOp, adaptor.getOperands(), rewriter); + FailureOr<Value> extractedSize = + extractSizeOfUnrankedMemRef(operandType, dimOp, adaptor, rewriter); if (failed(extractedSize)) return failure(); rewriter.replaceOp(dimOp, {*extractedSize}); return success(); } if (isa<MemRefType>(operandType)) { - rewriter.replaceOp( - dimOp, {extractSizeOfRankedMemRef(operandType, dimOp, - adaptor.getOperands(), rewriter)}); + rewriter.replaceOp(dimOp, {extractSizeOfRankedMemRef(operandType, dimOp, + adaptor, rewriter)}); return success(); } llvm_unreachable("expected MemRefType or UnrankedMemRefType"); @@ -276,7 +275,7 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> { private: FailureOr<Value> extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp, - OpAdaptor adaptor, + OneToNOpAdaptor &adaptor, ConversionPatternRewriter &rewriter) const { Location loc = dimOp.getLoc(); @@ -298,20 +297,24 @@ private: UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource()); Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc); - Type elementType = typeConverter->convertType(scalarMemRefType); + SmallVector<Type> convertedMemRefType; + if (failed(static_cast<const LLVMTypeConverter *>(typeConverter) + ->convertMemRefType(scalarMemRefType, convertedMemRefType, + /*packed=*/true))) + return failure(); // Get pointer to offset field of memref<element_type> descriptor. auto indexPtrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace); Value offsetPtr = rewriter.create<LLVM::GEPOp>( - loc, indexPtrTy, elementType, underlyingRankedDesc, - ArrayRef<LLVM::GEPArg>{0, 2}); + loc, indexPtrTy, llvm::getSingleElement(convertedMemRefType), + underlyingRankedDesc, ArrayRef<LLVM::GEPArg>{0, 2}); // The size value that we have to extract can be obtained using GEPop with // `dimOp.index() + 1` index argument. Value idxPlusOne = rewriter.create<LLVM::AddOp>( loc, createIndexAttrConstant(rewriter, loc, getIndexType(), 1), - adaptor.getIndex()); + llvm::getSingleElement(adaptor.getIndex())); Value sizePtr = rewriter.create<LLVM::GEPOp>( loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr, idxPlusOne); @@ -331,7 +334,7 @@ private: } Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp, - OpAdaptor adaptor, + OneToNOpAdaptor &adaptor, ConversionPatternRewriter &rewriter) const { Location loc = dimOp.getLoc(); @@ -351,7 +354,7 @@ private: return createIndexAttrConstant(rewriter, loc, indexType, dimSize); } } - Value index = adaptor.getIndex(); + Value index = llvm::getSingleElement(adaptor.getIndex()); int64_t rank = memRefType.getRank(); MemRefDescriptor memrefDescriptor(adaptor.getSource()); return memrefDescriptor.size(rewriter, loc, index, rank); @@ -400,7 +403,7 @@ struct GenericAtomicRMWOpLowering using Base::Base; LogicalResult - matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor, + matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = atomicOp.getLoc(); Type valueType = typeConverter->convertType(atomicOp.getResult().getType()); @@ -416,8 +419,12 @@ struct GenericAtomicRMWOpLowering // Compute the loaded value and branch to the loop block. rewriter.setInsertionPointToEnd(initBlock); auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType()); + SmallVector<Value> indices = + llvm::map_to_vector(adaptor.getIndices(), [](ValueRange r) { + return llvm::getSingleElement(r); + }); auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(), - adaptor.getIndices(), rewriter); + indices, rewriter); Value init = rewriter.create<LLVM::LoadOp>( loc, typeConverter->convertType(memRefType.getElementType()), dataPtr); rewriter.create<LLVM::BrOp>(loc, init, loopBlock); @@ -579,13 +586,15 @@ struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> { using Base::Base; LogicalResult - matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, + matchAndRewrite(memref::LoadOp loadOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto type = loadOp.getMemRefType(); - - Value dataPtr = - getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(), - adaptor.getIndices(), rewriter); + SmallVector<Value> indices = + llvm::map_to_vector(adaptor.getIndices(), [](ValueRange r) { + return llvm::getSingleElement(r); + }); + Value dataPtr = getStridedElementPtr( + loadOp.getLoc(), type, adaptor.getMemref(), indices, rewriter); rewriter.replaceOpWithNewOp<LLVM::LoadOp>( loadOp, typeConverter->convertType(type.getElementType()), dataPtr, 0, false, loadOp.getNontemporal()); @@ -599,14 +608,18 @@ struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> { using Base::Base; LogicalResult - matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor, + matchAndRewrite(memref::StoreOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto type = op.getMemRefType(); - + SmallVector<Value> indices = + llvm::map_to_vector(adaptor.getIndices(), [](ValueRange r) { + return llvm::getSingleElement(r); + }); Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(), - adaptor.getIndices(), rewriter); - rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr, - 0, false, op.getNontemporal()); + indices, rewriter); + rewriter.replaceOpWithNewOp<LLVM::StoreOp>( + op, llvm::getSingleElement(adaptor.getValue()), dataPtr, 0, false, + op.getNontemporal()); return success(); } }; @@ -617,13 +630,16 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> { using Base::Base; LogicalResult - matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor, + matchAndRewrite(memref::PrefetchOp prefetchOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto type = prefetchOp.getMemRefType(); auto loc = prefetchOp.getLoc(); - - Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(), - adaptor.getIndices(), rewriter); + SmallVector<Value> indices = + llvm::map_to_vector(adaptor.getIndices(), [](ValueRange r) { + return llvm::getSingleElement(r); + }); + Value dataPtr = + getStridedElementPtr(loc, type, adaptor.getMemref(), indices, rewriter); // Replace with llvm.prefetch. IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite()); @@ -640,7 +656,7 @@ struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> { using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(memref::RankOp op, OpAdaptor adaptor, + matchAndRewrite(memref::RankOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); Type operandType = op.getMemref().getType(); @@ -664,8 +680,9 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor, + matchAndRewrite(memref::CastOp memRefCastOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + auto loc = memRefCastOp.getLoc(); Type srcType = memRefCastOp.getOperand().getType(); Type dstType = memRefCastOp.getType(); @@ -674,21 +691,21 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> { // and require source and result type to have the same rank. Therefore, // perform a sanity check that the underlying structs are the same. Once op // semantics are relaxed we can revisit. + SmallVector<Type> convertedSrc, convertedDst; + if (failed(typeConverter->convertType(srcType, convertedSrc)) || + failed(typeConverter->convertType(dstType, convertedDst))) + return failure(); if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType)) - if (typeConverter->convertType(srcType) != - typeConverter->convertType(dstType)) + if (!llvm::equal(convertedSrc, convertedDst)) return failure(); // Unranked to unranked cast is disallowed if (isa<UnrankedMemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) return failure(); - auto targetStructType = typeConverter->convertType(memRefCastOp.getType()); - auto loc = memRefCastOp.getLoc(); - // For ranked/ranked case, just keep the original descriptor. if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType)) { - rewriter.replaceOp(memRefCastOp, {adaptor.getSource()}); + rewriter.replaceOpWithMultiple(memRefCastOp, {adaptor.getSource()}); return success(); } @@ -701,19 +718,20 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> { int64_t rank = srcMemRefType.getRank(); // ptr = AllocaOp sizeof(MemRefDescriptor) auto ptr = getTypeConverter()->promoteOneMemRefDescriptor( - loc, adaptor.getSource(), rewriter); + loc, MemRefDescriptor(adaptor.getSource()).packStruct(rewriter, loc), + rewriter); // rank = ConstantOp srcRank auto rankVal = rewriter.create<LLVM::ConstantOp>( loc, getIndexType(), rewriter.getIndexAttr(rank)); // poison = PoisonOp UnrankedMemRefDescriptor memRefDesc = - UnrankedMemRefDescriptor::poison(rewriter, loc, targetStructType); + UnrankedMemRefDescriptor::poison(rewriter, loc, convertedDst); // d1 = InsertValueOp poison, rank, 0 memRefDesc.setRank(rewriter, loc, rankVal); // d2 = InsertValueOp d1, ptr, 1 memRefDesc.setMemRefDescPtr(rewriter, loc, ptr); - rewriter.replaceOp(memRefCastOp, (Value)memRefDesc); + rewriter.replaceOpWithMultiple(memRefCastOp, {memRefDesc}); } else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) { // Casting from unranked type to ranked. @@ -722,10 +740,16 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> { UnrankedMemRefDescriptor memRefDesc(adaptor.getSource()); // ptr = ExtractValueOp src, 1 auto ptr = memRefDesc.memRefDescPtr(rewriter, loc); - // struct = LoadOp ptr - auto loadOp = rewriter.create<LLVM::LoadOp>(loc, targetStructType, ptr); - rewriter.replaceOp(memRefCastOp, loadOp.getResult()); + SmallVector<Type> targetStructType; + if (failed(getTypeConverter()->convertMemRefType( + cast<MemRefType>(dstType), targetStructType, /*packed=*/true))) + return failure(); + auto loadOp = rewriter.create<LLVM::LoadOp>( + loc, llvm::getSingleElement(targetStructType), ptr); + rewriter.replaceOpWithMultiple(memRefCastOp, + {MemRefDescriptor::fromPackedStruct( + rewriter, loc, loadOp.getResult())}); } else { llvm_unreachable("Unsupported unranked memref to unranked memref cast"); } @@ -743,7 +767,7 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> { using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern; LogicalResult - lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor, + lowerToMemCopyIntrinsic(memref::CopyOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto loc = op.getLoc(); auto srcType = dyn_cast<MemRefType>(op.getSource().getType()); @@ -782,74 +806,75 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> { return success(); } - LogicalResult - lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto loc = op.getLoc(); - auto srcType = cast<BaseMemRefType>(op.getSource().getType()); - auto targetType = cast<BaseMemRefType>(op.getTarget().getType()); - - // First make sure we have an unranked memref descriptor representation. - auto makeUnranked = [&, this](Value ranked, MemRefType type) { - auto rank = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(), - type.getRank()); - auto *typeConverter = getTypeConverter(); - auto ptr = - typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter); - - auto unrankedType = - UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace()); - return UnrankedMemRefDescriptor::pack( - rewriter, loc, *typeConverter, unrankedType, ValueRange{rank, ptr}); - }; - - // Save stack position before promoting descriptors - auto stackSaveOp = - rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType()); - - auto srcMemRefType = dyn_cast<MemRefType>(srcType); - Value unrankedSource = - srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType) - : adaptor.getSource(); - auto targetMemRefType = dyn_cast<MemRefType>(targetType); - Value unrankedTarget = - targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType) - : adaptor.getTarget(); - - // Now promote the unranked descriptors to the stack. - auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(), - rewriter.getIndexAttr(1)); - auto promote = [&](Value desc) { - auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); - auto allocated = - rewriter.create<LLVM::AllocaOp>(loc, ptrType, desc.getType(), one); - rewriter.create<LLVM::StoreOp>(loc, desc, allocated); - return allocated; - }; - - auto sourcePtr = promote(unrankedSource); - auto targetPtr = promote(unrankedTarget); - - // Derive size from llvm.getelementptr which will account for any - // potential alignment - auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter); - auto copyFn = LLVM::lookupOrCreateMemRefCopyFn( - op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType()); - if (failed(copyFn)) - return failure(); - rewriter.create<LLVM::CallOp>(loc, copyFn.value(), - ValueRange{elemSize, sourcePtr, targetPtr}); - - // Restore stack used for descriptors - rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp); + /* + LogicalResult + lowerToMemCopyFunctionCall(memref::CopyOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto srcType = cast<BaseMemRefType>(op.getSource().getType()); + auto targetType = cast<BaseMemRefType>(op.getTarget().getType()); + + // First make sure we have an unranked memref descriptor representation. + auto makeUnranked = [&, this](Value ranked, MemRefType type) { + auto rank = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(), + type.getRank()); + auto *typeConverter = getTypeConverter(); + auto ptr = + typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter); + + auto unrankedType = + UnrankedMemRefType::get(type.getElementType(), + type.getMemorySpace()); return UnrankedMemRefDescriptor::pack( rewriter, + loc, *typeConverter, unrankedType, ValueRange{rank, ptr}); + }; + + // Save stack position before promoting descriptors + auto stackSaveOp = + rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType()); + + auto srcMemRefType = dyn_cast<MemRefType>(srcType); + Value unrankedSource = + srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType) + : adaptor.getSource(); + auto targetMemRefType = dyn_cast<MemRefType>(targetType); + Value unrankedTarget = + targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType) + : adaptor.getTarget(); + + // Now promote the unranked descriptors to the stack. + auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(), + rewriter.getIndexAttr(1)); + auto promote = [&](Value desc) { + auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto allocated = + rewriter.create<LLVM::AllocaOp>(loc, ptrType, desc.getType(), one); + rewriter.create<LLVM::StoreOp>(loc, desc, allocated); + return allocated; + }; + + auto sourcePtr = promote(unrankedSource); + auto targetPtr = promote(unrankedTarget); + + // Derive size from llvm.getelementptr which will account for any + // potential alignment + auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter); + auto copyFn = LLVM::lookupOrCreateMemRefCopyFn( + op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType()); + if (failed(copyFn)) + return failure(); + rewriter.create<LLVM::CallOp>(loc, copyFn.value(), + ValueRange{elemSize, sourcePtr, targetPtr}); - rewriter.eraseOp(op); + // Restore stack used for descriptors + rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp); - return success(); - } + rewriter.eraseOp(op); + return success(); + } + */ LogicalResult - matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor, + matchAndRewrite(memref::CopyOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto srcType = cast<BaseMemRefType>(op.getSource().getType()); auto targetType = cast<BaseMemRefType>(op.getTarget().getType()); @@ -868,7 +893,8 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> { if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType)) return lowerToMemCopyIntrinsic(op, adaptor, rewriter); - return lowerToMemCopyFunctionCall(op, adaptor, rewriter); + return failure(); + // return lowerToMemCopyFunctionCall(op, adaptor, rewriter); } }; @@ -878,26 +904,23 @@ struct MemorySpaceCastOpLowering memref::MemorySpaceCastOp>::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor, + matchAndRewrite(memref::MemorySpaceCastOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); - Type resultType = op.getDest().getType(); + SmallVector<Type> convertedResultTypes; + if (failed(typeConverter->convertType(resultType, convertedResultTypes))) + return failure(); + if (auto resultTypeR = dyn_cast<MemRefType>(resultType)) { - auto resultDescType = - cast<LLVM::LLVMStructType>(typeConverter->convertType(resultTypeR)); - Type newPtrType = resultDescType.getBody()[0]; + Type newPtrType = convertedResultTypes[0]; - SmallVector<Value> descVals; - MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR, - descVals); + SmallVector<Value> descVals = llvm::to_vector(adaptor.getSource()); descVals[0] = rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[0]); descVals[1] = rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[1]); - Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(), - resultTypeR, descVals); - rewriter.replaceOp(op, result); + rewriter.replaceOpWithMultiple(op, {descVals}); return success(); } if (auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) { @@ -922,8 +945,8 @@ struct MemorySpaceCastOpLowering Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc); // Create and allocate storage for new memref descriptor. - auto result = UnrankedMemRefDescriptor::poison( - rewriter, loc, typeConverter->convertType(resultTypeU)); + auto result = + UnrankedMemRefDescriptor::poison(rewriter, loc, convertedResultTypes); result.setRank(rewriter, loc, rank); SmallVector<Value, 1> sizes; UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(), @@ -972,7 +995,7 @@ struct MemorySpaceCastOpLowering rewriter.create<LLVM::MemcpyOp>(loc, resultIndexVals, sourceIndexVals, copySize, /*isVolatile=*/false); - rewriter.replaceOp(op, ValueRange{result}); + rewriter.replaceOpWithMultiple(op, ValueRange{result}); return success(); } return rewriter.notifyMatchFailure(loc, "unexpected memref type"); @@ -986,7 +1009,7 @@ static void extractPointersAndOffset(Location loc, ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Value originalOperand, - Value convertedOperand, + ValueRange convertedOperand, Value *allocatedPtr, Value *alignedPtr, Value *offset = nullptr) { Type operandType = originalOperand.getType(); @@ -1026,33 +1049,32 @@ struct MemRefReinterpretCastOpLowering memref::ReinterpretCastOp>::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor, + matchAndRewrite(memref::ReinterpretCastOp castOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type srcType = castOp.getSource().getType(); - Value descriptor; + SmallVector<Value> descriptor; if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp, adaptor, &descriptor))) return failure(); - rewriter.replaceOp(castOp, {descriptor}); + rewriter.replaceOpWithMultiple(castOp, {descriptor}); return success(); } private: LogicalResult convertSourceMemRefToDescriptor( ConversionPatternRewriter &rewriter, Type srcType, - memref::ReinterpretCastOp castOp, - memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const { + memref::ReinterpretCastOp castOp, OneToNOpAdaptor adaptor, + SmallVector<Value> *descriptor) const { MemRefType targetMemRefType = cast<MemRefType>(castOp.getResult().getType()); - auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>( - typeConverter->convertType(targetMemRefType)); - if (!llvmTargetDescriptorTy) + SmallVector<Type> convertedTypes; + if (failed(typeConverter->convertType(targetMemRefType, convertedTypes))) return failure(); // Create descriptor. Location loc = castOp.getLoc(); - auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy); + auto desc = MemRefDescriptor::poison(rewriter, loc, convertedTypes); // Set allocated and aligned pointers. Value allocatedPtr, alignedPtr; @@ -1064,7 +1086,8 @@ private: // Set offset. if (castOp.isDynamicOffset(0)) - desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]); + desc.setOffset(rewriter, loc, + llvm::getSingleElement(adaptor.getOffsets()[0])); else desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0)); @@ -1073,16 +1096,19 @@ private: unsigned dynStrideId = 0; for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) { if (castOp.isDynamicSize(i)) - desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]); + desc.setSize(rewriter, loc, i, + llvm::getSingleElement(adaptor.getSizes()[dynSizeId++])); else desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i)); if (castOp.isDynamicStride(i)) - desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]); + desc.setStride( + rewriter, loc, i, + llvm::getSingleElement(adaptor.getStrides()[dynStrideId++])); else desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i)); } - *descriptor = desc; + llvm::append_range(*descriptor, desc.getElements()); return success(); } }; @@ -1092,15 +1118,15 @@ struct MemRefReshapeOpLowering using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor, + matchAndRewrite(memref::ReshapeOp reshapeOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type srcType = reshapeOp.getSource().getType(); - Value descriptor; + SmallVector<Value> descriptor; if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp, adaptor, &descriptor))) return failure(); - rewriter.replaceOp(reshapeOp, {descriptor}); + rewriter.replaceOpWithMultiple(reshapeOp, {descriptor}); return success(); } @@ -1108,21 +1134,19 @@ private: LogicalResult convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter, Type srcType, memref::ReshapeOp reshapeOp, - memref::ReshapeOp::Adaptor adaptor, - Value *descriptor) const { + OneToNOpAdaptor adaptor, + SmallVector<Value> *descriptor) const { auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType()); if (shapeMemRefType.hasStaticShape()) { MemRefType targetMemRefType = cast<MemRefType>(reshapeOp.getResult().getType()); - auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>( - typeConverter->convertType(targetMemRefType)); - if (!llvmTargetDescriptorTy) + SmallVector<Type> convertedTypes; + if (failed(typeConverter->convertType(targetMemRefType, convertedTypes))) return failure(); // Create descriptor. Location loc = reshapeOp.getLoc(); - auto desc = - MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy); + auto desc = MemRefDescriptor::poison(rewriter, loc, convertedTypes); // Set allocated and aligned pointers. Value allocatedPtr, alignedPtr; @@ -1188,7 +1212,7 @@ private: stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize); } - *descriptor = desc; + llvm::append_range(*descriptor, desc.getElements()); return success(); } @@ -1204,8 +1228,11 @@ private: // Create the unranked memref descriptor that holds the ranked one. The // inner descriptor is allocated on stack. + SmallVector<Type> convertedTypes; + if (failed(typeConverter->convertType(targetType, convertedTypes))) + return failure(); auto targetDesc = UnrankedMemRefDescriptor::poison( - rewriter, loc, typeConverter->convertType(targetType)); + rewriter, loc, convertedTypes); targetDesc.setRank(rewriter, loc, resultRank); SmallVector<Value, 4> sizes; UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(), @@ -1303,7 +1330,7 @@ private: // Reset position to beginning of new remainder block. rewriter.setInsertionPointToStart(remainder); - *descriptor = targetDesc; + llvm::append_range(*descriptor, targetDesc.getElements()); return success(); } }; @@ -1315,10 +1342,11 @@ class ReassociatingReshapeOpConversion : public ConvertOpToLLVMPattern<ReshapeOp> { public: using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern; - using ReshapeOpAdaptor = typename ReshapeOp::Adaptor; + using ReshapeOpAdaptor = + typename ConvertOpToLLVMPattern<ReshapeOp>::OneToNOpAdaptor; LogicalResult - matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor, + matchAndRewrite(ReshapeOp reshapeOp, ReshapeOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { return rewriter.notifyMatchFailure( reshapeOp, @@ -1332,7 +1360,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> { using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor, + matchAndRewrite(memref::SubViewOp subViewOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { return rewriter.notifyMatchFailure( subViewOp, "subview operations should have been expanded beforehand"); @@ -1351,7 +1379,7 @@ public: using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor, + matchAndRewrite(memref::TransposeOp transposeOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = transposeOp.getLoc(); MemRefDescriptor viewMemRef(adaptor.getIn()); @@ -1360,9 +1388,11 @@ public: if (transposeOp.getPermutation().isIdentity()) return rewriter.replaceOp(transposeOp, {viewMemRef}), success(); - auto targetMemRef = MemRefDescriptor::poison( - rewriter, loc, - typeConverter->convertType(transposeOp.getIn().getType())); + SmallVector<Type> convertedTypes; + if (failed(typeConverter->convertType(transposeOp.getIn().getType(), + convertedTypes))) + return failure(); + auto targetMemRef = MemRefDescriptor::poison(rewriter, loc, convertedTypes); // Copy the base and aligned pointers from the old descriptor to the new // one. @@ -1388,7 +1418,7 @@ public: viewMemRef.stride(rewriter, loc, sourcePos)); } - rewriter.replaceOp(transposeOp, {targetMemRef}); + rewriter.replaceOpWithMultiple(transposeOp, {targetMemRef}); return success(); } }; @@ -1434,17 +1464,19 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> { } LogicalResult - matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor, + matchAndRewrite(memref::ViewOp viewOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = viewOp.getLoc(); auto viewMemRefType = viewOp.getType(); auto targetElementTy = typeConverter->convertType(viewMemRefType.getElementType()); - auto targetDescTy = typeConverter->convertType(viewMemRefType); - if (!targetDescTy || !targetElementTy || - !LLVM::isCompatibleType(targetElementTy) || - !LLVM::isCompatibleType(targetDescTy)) + SmallVector<Type> targetDescTy; + if (failed(typeConverter->convertType(viewMemRefType, targetDescTy))) + return viewOp.emitWarning("Target descriptor type not converted to LLVM"), + failure(); + // TODO: Check targetDescTy is LLVM compatible. + if (!targetElementTy || !LLVM::isCompatibleType(targetElementTy)) return viewOp.emitWarning("Target descriptor type not converted to LLVM"), failure(); @@ -1475,7 +1507,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> { alignedPtr = rewriter.create<LLVM::GEPOp>( loc, alignedPtr.getType(), typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr, - adaptor.getByteShift()); + llvm::getSingleElement(adaptor.getByteShift())); targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr); @@ -1493,10 +1525,14 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> { // Fields 4 and 5: Update sizes and strides. Value stride = nullptr, nextSize = nullptr; + SmallVector<Value> sizes = + llvm::map_to_vector(adaptor.getSizes(), [](ValueRange r) { + return llvm::getSingleElement(r); + }); for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) { // Update size. - Value size = getSize(rewriter, loc, viewMemRefType.getShape(), - adaptor.getSizes(), i, indexType); + Value size = getSize(rewriter, loc, viewMemRefType.getShape(), sizes, i, + indexType); targetMemRef.setSize(rewriter, loc, i, size); // Update stride. stride = @@ -1505,7 +1541,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> { nextSize = size; } - rewriter.replaceOp(viewOp, {targetMemRef}); + rewriter.replaceOpWithMultiple(viewOp, {targetMemRef.getElements()}); return success(); } }; @@ -1551,7 +1587,7 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> { using Base::Base; LogicalResult - matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor, + matchAndRewrite(memref::AtomicRMWOp atomicOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto maybeKind = matchSimpleAtomicOp(atomicOp); if (!maybeKind) @@ -1561,11 +1597,15 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> { int64_t offset; if (failed(memRefType.getStridesAndOffset(strides, offset))) return failure(); - auto dataPtr = - getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(), - adaptor.getIndices(), rewriter); + SmallVector<Value> indices = + llvm::map_to_vector(adaptor.getIndices(), [](ValueRange r) { + return llvm::getSingleElement(r); + }); + auto dataPtr = getStridedElementPtr(atomicOp.getLoc(), memRefType, + adaptor.getMemref(), indices, rewriter); rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>( - atomicOp, *maybeKind, dataPtr, adaptor.getValue(), + atomicOp, *maybeKind, dataPtr, + llvm::getSingleElement(adaptor.getValue()), LLVM::AtomicOrdering::acq_rel); return success(); } @@ -1580,7 +1620,7 @@ public: LogicalResult matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp, - OpAdaptor adaptor, + OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { BaseMemRefType sourceTy = extractOp.getSource().getType(); @@ -1616,12 +1656,8 @@ public: LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp, - OpAdaptor adaptor, + OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - - if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType())) - return failure(); - // Create the descriptor. MemRefDescriptor sourceMemRef(adaptor.getSource()); Location loc = extractStridedMetadataOp.getLoc(); @@ -1629,7 +1665,7 @@ public: auto sourceMemRefType = cast<MemRefType>(source.getType()); int64_t rank = sourceMemRefType.getRank(); - SmallVector<Value> results; + SmallVector<ValueRange> results; results.reserve(2 + rank * 2); // Base buffer. @@ -1639,19 +1675,11 @@ public: rewriter, loc, *getTypeConverter(), cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()), baseBuffer, alignedBuffer); - results.push_back((Value)dstMemRef); - - // Offset. - results.push_back(sourceMemRef.offset(rewriter, loc)); - - // Sizes. - for (unsigned i = 0; i < rank; ++i) - results.push_back(sourceMemRef.size(rewriter, loc, i)); - // Strides. - for (unsigned i = 0; i < rank; ++i) - results.push_back(sourceMemRef.stride(rewriter, loc, i)); - - rewriter.replaceOp(extractStridedMetadataOp, results); + results.push_back(dstMemRef.getElements()); + // Offset, sizes, strides of the source memref. + for (size_t i = 2, e = sourceMemRef.getElements().size(); i < e; ++i) + results.push_back(sourceMemRef.getElements().slice(i, 1)); + rewriter.replaceOpWithMultiple(extractStridedMetadataOp, results); return success(); } }; diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index 51507c6..4613b90 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -1179,8 +1179,10 @@ struct NVGPUTmaCreateDescriptorOpLowering Value tensorElementType = elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType()); - auto promotedOperands = getTypeConverter()->promoteOperands( - b.getLoc(), op->getOperands(), adaptor.getOperands(), b); + llvm_unreachable("TODO"); + SmallVector<Value> promotedOperands; + //auto promotedOperands = getTypeConverter()->promoteOperands( + // b.getLoc(), op->getOperands(), adaptor.getOperands(), b); Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type, makeI64Const(b, 5)); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 213f737..23525cc 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -95,7 +95,7 @@ static LogicalResult isMemRefTypeSupported(MemRefType memRefType, // Add an index vector component to a base pointer. static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, const LLVMTypeConverter &typeConverter, - MemRefType memRefType, Value llvmMemref, Value base, + MemRefType memRefType, ValueRange llvmMemref, Value base, Value index, VectorType vectorType) { assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) && "unsupported memref type"); @@ -185,8 +185,9 @@ public: /// Overloaded utility that replaces a vector.load, vector.store, /// vector.maskedload and vector.maskedstore with their respective LLVM /// couterparts. +template<typename Adaptor> static void replaceLoadOrStoreOp(vector::LoadOp loadOp, - vector::LoadOpAdaptor adaptor, + Adaptor adaptor, VectorType vectorTy, Value ptr, unsigned align, ConversionPatternRewriter &rewriter) { rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align, @@ -194,29 +195,32 @@ static void replaceLoadOrStoreOp(vector::LoadOp loadOp, loadOp.getNontemporal()); } +template<typename Adaptor> static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp, - vector::MaskedLoadOpAdaptor adaptor, + Adaptor adaptor, VectorType vectorTy, Value ptr, unsigned align, ConversionPatternRewriter &rewriter) { rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>( - loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align); + loadOp, vectorTy, ptr, llvm::getSingleElement(adaptor.getMask()), llvm::getSingleElement(adaptor.getPassThru()), align); } +template<typename Adaptor> static void replaceLoadOrStoreOp(vector::StoreOp storeOp, - vector::StoreOpAdaptor adaptor, + Adaptor adaptor, VectorType vectorTy, Value ptr, unsigned align, ConversionPatternRewriter &rewriter) { - rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(), + rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, llvm::getSingleElement(adaptor.getValueToStore()), ptr, align, /*volatile_=*/false, storeOp.getNontemporal()); } +template<typename Adaptor> static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp, - vector::MaskedStoreOpAdaptor adaptor, + Adaptor adaptor, VectorType vectorTy, Value ptr, unsigned align, ConversionPatternRewriter &rewriter) { rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>( - storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align); + storeOp, llvm::getSingleElement(adaptor.getValueToStore()), ptr, llvm::getSingleElement(adaptor.getMask()), align); } /// Conversion pattern for a vector.load, vector.store, vector.maskedload, and @@ -225,10 +229,11 @@ template <class LoadOrStoreOp> class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> { public: using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern; + using Adaptor = typename ConvertOpToLLVMPattern<LoadOrStoreOp>::OneToNOpAdaptor; LogicalResult matchAndRewrite(LoadOrStoreOp loadOrStoreOp, - typename LoadOrStoreOp::Adaptor adaptor, + Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Only 1-D vectors can be lowered to LLVM. VectorType vectorTy = loadOrStoreOp.getVectorType(); @@ -244,10 +249,11 @@ public: return failure(); // Resolve address. + SmallVector<Value> indices = llvm::map_to_vector(adaptor.getIndices(), [](ValueRange range) { return llvm::getSingleElement(range); }); auto vtype = cast<VectorType>( this->typeConverter->convertType(loadOrStoreOp.getVectorType())); Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(), - adaptor.getIndices(), rewriter); + indices, rewriter); replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align, rewriter); return success(); @@ -261,7 +267,7 @@ public: using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor, + matchAndRewrite(vector::GatherOp gather, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = gather->getLoc(); MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType()); @@ -284,17 +290,18 @@ public: } // Resolve address. + SmallVector<Value> indices = llvm::map_to_vector(adaptor.getIndices(), [](ValueRange range) { return llvm::getSingleElement(range); }); Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), - adaptor.getIndices(), rewriter); - Value base = adaptor.getBase(); + indices, rewriter); + ValueRange base = adaptor.getBase(); Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType, - base, ptr, adaptor.getIndexVec(), vType); + base, ptr, llvm::getSingleElement(adaptor.getIndexVec()), vType); // Replace with the gather intrinsic. rewriter.replaceOpWithNewOp<LLVM::masked_gather>( - gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(), - adaptor.getPassThru(), rewriter.getI32IntegerAttr(align)); + gather, typeConverter->convertType(vType), ptrs, llvm::getSingleElement(adaptor.getMask()), + llvm::getSingleElement(adaptor.getPassThru()), rewriter.getI32IntegerAttr(align)); return success(); } }; @@ -306,7 +313,7 @@ public: using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor, + matchAndRewrite(vector::ScatterOp scatter, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = scatter->getLoc(); MemRefType memRefType = scatter.getMemRefType(); @@ -328,15 +335,16 @@ public: } // Resolve address. + SmallVector<Value> indices = llvm::map_to_vector(adaptor.getIndices(), [](ValueRange range) { return llvm::getSingleElement(range); }); Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), - adaptor.getIndices(), rewriter); + indices, rewriter); Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType, - adaptor.getBase(), ptr, adaptor.getIndexVec(), vType); + adaptor.getBase(), ptr, llvm::getSingleElement(adaptor.getIndexVec()), vType); // Replace with the scatter intrinsic. rewriter.replaceOpWithNewOp<LLVM::masked_scatter>( - scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(), + scatter, llvm::getSingleElement(adaptor.getValueToStore()), ptrs, llvm::getSingleElement(adaptor.getMask()), rewriter.getI32IntegerAttr(align)); return success(); } @@ -349,18 +357,19 @@ public: using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor, + matchAndRewrite(vector::ExpandLoadOp expand, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = expand->getLoc(); MemRefType memRefType = expand.getMemRefType(); // Resolve address. auto vtype = typeConverter->convertType(expand.getVectorType()); + SmallVector<Value> indices = llvm::map_to_vector(adaptor.getIndices(), [](ValueRange range) { return llvm::getSingleElement(range); }); Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), - adaptor.getIndices(), rewriter); + indices, rewriter); rewriter.replaceOpWithNewOp<LLVM::masked_expandload>( - expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru()); + expand, vtype, ptr, llvm::getSingleElement(adaptor.getMask()), llvm::getSingleElement(adaptor.getPassThru())); return success(); } }; @@ -372,17 +381,18 @@ public: using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor, + matchAndRewrite(vector::CompressStoreOp compress, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = compress->getLoc(); MemRefType memRefType = compress.getMemRefType(); // Resolve address. + SmallVector<Value> indices = llvm::map_to_vector(adaptor.getIndices(), [](ValueRange range) { return llvm::getSingleElement(range); }); Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), - adaptor.getIndices(), rewriter); + indices, rewriter); rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>( - compress, adaptor.getValueToStore(), ptr, adaptor.getMask()); + compress, llvm::getSingleElement(adaptor.getValueToStore()), ptr, llvm::getSingleElement(adaptor.getMask())); return success(); } }; @@ -1416,7 +1426,7 @@ public: using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor, + matchAndRewrite(vector::TypeCastOp castOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = castOp->getLoc(); MemRefType sourceMemRefType = @@ -1428,15 +1438,10 @@ public: !targetMemRefType.hasStaticShape()) return failure(); - auto llvmSourceDescriptorTy = - dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType()); - if (!llvmSourceDescriptorTy) - return failure(); MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]); - auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>( - typeConverter->convertType(targetMemRefType)); - if (!llvmTargetDescriptorTy) + SmallVector<Type> llvmTargetDescriptorTypes; + if (failed(typeConverter->convertType(targetMemRefType, llvmTargetDescriptorTypes))) return failure(); // Only contiguous source buffers supported atm. @@ -1453,7 +1458,7 @@ public: auto int64Ty = IntegerType::get(rewriter.getContext(), 64); // Create descriptor. - auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy); + auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTypes); // Set allocated ptr. Value allocated = sourceMemRef.allocatedPtr(rewriter, loc); desc.setAllocatedPtr(rewriter, loc, allocated); @@ -1480,7 +1485,7 @@ public: desc.setStride(rewriter, loc, index, stride); } - rewriter.replaceOp(castOp, {desc}); + rewriter.replaceOpWithMultiple(castOp, {desc.getElements()}); return success(); } }; |