aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp')
-rw-r--r--mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp56
1 files changed, 56 insertions, 0 deletions
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 052a48c..3e6fcc0 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -684,6 +684,62 @@ LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
return success();
}
+LogicalResult GPUReturnOpLowering::matchAndRewrite(
+ gpu::ReturnOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ Location loc = op.getLoc();
+ unsigned numArguments = op.getNumOperands();
+ SmallVector<Value, 4> updatedOperands;
+
+ bool useBarePtrCallConv = getTypeConverter()->getOptions().useBarePtrCallConv;
+ if (useBarePtrCallConv) {
+ // For the bare-ptr calling convention, extract the aligned pointer to
+ // be returned from the memref descriptor.
+ for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
+ Type oldTy = std::get<0>(it).getType();
+ Value newOperand = std::get<1>(it);
+ if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr(
+ cast<BaseMemRefType>(oldTy))) {
+ MemRefDescriptor memrefDesc(newOperand);
+ newOperand = memrefDesc.allocatedPtr(rewriter, loc);
+ } else if (isa<UnrankedMemRefType>(oldTy)) {
+ // Unranked memref is not supported in the bare pointer calling
+ // convention.
+ return failure();
+ }
+ updatedOperands.push_back(newOperand);
+ }
+ } else {
+ updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
+ (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
+ updatedOperands,
+ /*toDynamic=*/true);
+ }
+
+ // If ReturnOp has 0 or 1 operand, create it and return immediately.
+ if (numArguments <= 1) {
+ rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
+ op, TypeRange(), updatedOperands, op->getAttrs());
+ return success();
+ }
+
+ // Otherwise, we need to pack the arguments into an LLVM struct type before
+ // returning.
+ auto packedType = getTypeConverter()->packFunctionResults(
+ op.getOperandTypes(), useBarePtrCallConv);
+ if (!packedType) {
+ return rewriter.notifyMatchFailure(op, "could not convert result types");
+ }
+
+ Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
+ for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
+ packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx);
+ }
+ rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
+ op->getAttrs());
+ return success();
+}
+
void mlir::populateGpuMemorySpaceAttributeConversions(
TypeConverter &typeConverter, const MemorySpaceMapping &mapping) {
typeConverter.addTypeAttributeConversion(