diff options
Diffstat (limited to 'mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp')
-rw-r--r-- | mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp | 56 |
1 files changed, 56 insertions, 0 deletions
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp index 052a48c..3e6fcc0 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -684,6 +684,62 @@ LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite( return success(); } +LogicalResult GPUReturnOpLowering::matchAndRewrite( + gpu::ReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = op.getLoc(); + unsigned numArguments = op.getNumOperands(); + SmallVector<Value, 4> updatedOperands; + + bool useBarePtrCallConv = getTypeConverter()->getOptions().useBarePtrCallConv; + if (useBarePtrCallConv) { + // For the bare-ptr calling convention, extract the aligned pointer to + // be returned from the memref descriptor. + for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) { + Type oldTy = std::get<0>(it).getType(); + Value newOperand = std::get<1>(it); + if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr( + cast<BaseMemRefType>(oldTy))) { + MemRefDescriptor memrefDesc(newOperand); + newOperand = memrefDesc.allocatedPtr(rewriter, loc); + } else if (isa<UnrankedMemRefType>(oldTy)) { + // Unranked memref is not supported in the bare pointer calling + // convention. + return failure(); + } + updatedOperands.push_back(newOperand); + } + } else { + updatedOperands = llvm::to_vector<4>(adaptor.getOperands()); + (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(), + updatedOperands, + /*toDynamic=*/true); + } + + // If ReturnOp has 0 or 1 operand, create it and return immediately. + if (numArguments <= 1) { + rewriter.replaceOpWithNewOp<LLVM::ReturnOp>( + op, TypeRange(), updatedOperands, op->getAttrs()); + return success(); + } + + // Otherwise, we need to pack the arguments into an LLVM struct type before + // returning. + auto packedType = getTypeConverter()->packFunctionResults( + op.getOperandTypes(), useBarePtrCallConv); + if (!packedType) { + return rewriter.notifyMatchFailure(op, "could not convert result types"); + } + + Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType); + for (auto [idx, operand] : llvm::enumerate(updatedOperands)) { + packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx); + } + rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed, + op->getAttrs()); + return success(); +} + void mlir::populateGpuMemorySpaceAttributeConversions( TypeConverter &typeConverter, const MemorySpaceMapping &mapping) { typeConverter.addTypeAttributeConversion( |