diff options
author | jeanPerier <jperier@nvidia.com> | 2024-07-24 10:24:04 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-07-24 10:24:04 +0200 |
commit | 1ead51a86c6c746a1b9948ca1ee142df223ffebd (patch) | |
tree | 609c6c0863229448b8d6bb3835a154555e5e7664 | |
parent | a3de21cac1fb8f1dd98cfe1d1443e2d3f0a97351 (diff) | |
download | llvm-1ead51a86c6c746a1b9948ca1ee142df223ffebd.zip llvm-1ead51a86c6c746a1b9948ca1ee142df223ffebd.tar.gz llvm-1ead51a86c6c746a1b9948ca1ee142df223ffebd.tar.bz2 |
[flang] fix C_PTR function result lowering (#100082)
Functions returning C_PTR were lowered to function returning intptr (i64
on 64bit arch). This caused conflicts when these functions were defined
as returning !fir.ref<none>/llvm.ptr in other compiler generated
contexts (e.g., malloc).
Lower them to return !fir.ref<none>.
This should deal with https://github.com/llvm/llvm-project/issues/97325
and https://github.com/llvm/llvm-project/issues/98644.
-rw-r--r-- | flang/lib/Optimizer/Builder/FIRBuilder.cpp | 54 | ||||
-rw-r--r-- | flang/lib/Optimizer/Transforms/AbstractResult.cpp | 108 | ||||
-rw-r--r-- | flang/test/Fir/abstract-results.fir | 36 |
3 files changed, 110 insertions, 88 deletions
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp index 2961df9..fbe79d0 100644 --- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp +++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp @@ -1541,21 +1541,44 @@ mlir::Value fir::factory::genMaxWithZero(fir::FirOpBuilder &builder, zero); } +static std::pair<mlir::Value, mlir::Type> +genCPtrOrCFunptrFieldIndex(fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Type cptrTy) { + auto recTy = mlir::cast<fir::RecordType>(cptrTy); + assert(recTy.getTypeList().size() == 1); + auto addrFieldName = recTy.getTypeList()[0].first; + mlir::Type addrFieldTy = recTy.getTypeList()[0].second; + auto fieldIndexType = fir::FieldType::get(cptrTy.getContext()); + mlir::Value addrFieldIndex = builder.create<fir::FieldIndexOp>( + loc, fieldIndexType, addrFieldName, recTy, + /*typeParams=*/mlir::ValueRange{}); + return {addrFieldIndex, addrFieldTy}; +} + mlir::Value fir::factory::genCPtrOrCFunptrAddr(fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value cPtr, mlir::Type ty) { - assert(mlir::isa<fir::RecordType>(ty)); - auto recTy = mlir::dyn_cast<fir::RecordType>(ty); - assert(recTy.getTypeList().size() == 1); - auto fieldName = recTy.getTypeList()[0].first; - mlir::Type fieldTy = recTy.getTypeList()[0].second; - auto fieldIndexType = fir::FieldType::get(ty.getContext()); - mlir::Value field = - builder.create<fir::FieldIndexOp>(loc, fieldIndexType, fieldName, recTy, - /*typeParams=*/mlir::ValueRange{}); - return builder.create<fir::CoordinateOp>(loc, builder.getRefType(fieldTy), - cPtr, field); + auto [addrFieldIndex, addrFieldTy] = + genCPtrOrCFunptrFieldIndex(builder, loc, ty); + return builder.create<fir::CoordinateOp>(loc, builder.getRefType(addrFieldTy), + cPtr, addrFieldIndex); +} + +mlir::Value fir::factory::genCPtrOrCFunptrValue(fir::FirOpBuilder &builder, + mlir::Location loc, + mlir::Value cPtr) { + mlir::Type cPtrTy = fir::unwrapRefType(cPtr.getType()); + if (fir::isa_ref_type(cPtr.getType())) { + mlir::Value cPtrAddr = + fir::factory::genCPtrOrCFunptrAddr(builder, loc, cPtr, cPtrTy); + return builder.create<fir::LoadOp>(loc, cPtrAddr); + } + auto [addrFieldIndex, addrFieldTy] = + genCPtrOrCFunptrFieldIndex(builder, loc, cPtrTy); + auto arrayAttr = + builder.getArrayAttr({builder.getIntegerAttr(builder.getIndexType(), 0)}); + return builder.create<fir::ExtractValueOp>(loc, addrFieldTy, cPtr, arrayAttr); } fir::BoxValue fir::factory::createBoxValue(fir::FirOpBuilder &builder, @@ -1596,15 +1619,6 @@ fir::BoxValue fir::factory::createBoxValue(fir::FirOpBuilder &builder, return fir::BoxValue(box, lbounds, explicitTypeParams); } -mlir::Value fir::factory::genCPtrOrCFunptrValue(fir::FirOpBuilder &builder, - mlir::Location loc, - mlir::Value cPtr) { - mlir::Type cPtrTy = fir::unwrapRefType(cPtr.getType()); - mlir::Value cPtrAddr = - fir::factory::genCPtrOrCFunptrAddr(builder, loc, cPtr, cPtrTy); - return builder.create<fir::LoadOp>(loc, cPtrAddr); -} - mlir::Value fir::factory::createNullBoxProc(fir::FirOpBuilder &builder, mlir::Location loc, mlir::Type boxType) { diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp index 3906aa5..ff37310 100644 --- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp +++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp @@ -59,14 +59,16 @@ static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy, /*resultTypes=*/{}); } +static mlir::Type getVoidPtrType(mlir::MLIRContext *context) { + return fir::ReferenceType::get(mlir::NoneType::get(context)); +} + /// This is for function result types that are of type C_PTR from ISO_C_BINDING. /// Follow the ABI for interoperability with C. static mlir::FunctionType getCPtrFunctionType(mlir::FunctionType funcTy) { - auto resultType = funcTy.getResult(0); - assert(fir::isa_builtin_cptr_type(resultType)); - llvm::SmallVector<mlir::Type> outputTypes; - auto recTy = mlir::dyn_cast<fir::RecordType>(resultType); - outputTypes.emplace_back(recTy.getTypeList()[0].second); + assert(fir::isa_builtin_cptr_type(funcTy.getResult(0))); + llvm::SmallVector<mlir::Type> outputTypes{ + getVoidPtrType(funcTy.getContext())}; return mlir::FunctionType::get(funcTy.getContext(), funcTy.getInputs(), outputTypes); } @@ -109,15 +111,11 @@ public: saveResult.getTypeparams()); llvm::SmallVector<mlir::Type> newResultTypes; - // TODO: This should be generalized for derived types, and it is - // architecture and OS dependent. bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType()); - Op newOp; - if (isResultBuiltinCPtr) { - auto recTy = mlir::dyn_cast<fir::RecordType>(result.getType()); - newResultTypes.emplace_back(recTy.getTypeList()[0].second); - } + if (isResultBuiltinCPtr) + newResultTypes.emplace_back(getVoidPtrType(result.getContext())); + Op newOp; // fir::CallOp specific handling. if constexpr (std::is_same_v<Op, fir::CallOp>) { if (op.getCallee()) { @@ -175,7 +173,7 @@ public: FirOpBuilder builder(rewriter, module); mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr( builder, loc, save, result.getType()); - rewriter.create<fir::StoreOp>(loc, newOp->getResult(0), saveAddr); + builder.createStoreWithConvert(loc, newOp->getResult(0), saveAddr); } op->dropAllReferences(); rewriter.eraseOp(op); @@ -210,42 +208,52 @@ public: mlir::PatternRewriter &rewriter) const override { auto loc = ret.getLoc(); rewriter.setInsertionPoint(ret); - auto returnedValue = ret.getOperand(0); - bool replacedStorage = false; - if (auto *op = returnedValue.getDefiningOp()) - if (auto load = mlir::dyn_cast<fir::LoadOp>(op)) { - auto resultStorage = load.getMemref(); - // The result alloca may be behind a fir.declare, if any. - if (auto declare = mlir::dyn_cast_or_null<fir::DeclareOp>( - resultStorage.getDefiningOp())) - resultStorage = declare.getMemref(); - // TODO: This should be generalized for derived types, and it is - // architecture and OS dependent. - if (fir::isa_builtin_cptr_type(returnedValue.getType())) { - rewriter.eraseOp(load); - auto module = ret->getParentOfType<mlir::ModuleOp>(); - FirOpBuilder builder(rewriter, module); - mlir::Value retAddr = fir::factory::genCPtrOrCFunptrAddr( - builder, loc, resultStorage, returnedValue.getType()); - mlir::Value retValue = rewriter.create<fir::LoadOp>( - loc, fir::unwrapRefType(retAddr.getType()), retAddr); - rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>( - ret, mlir::ValueRange{retValue}); - return mlir::success(); - } - resultStorage.replaceAllUsesWith(newArg); - replacedStorage = true; - if (auto *alloc = resultStorage.getDefiningOp()) - if (alloc->use_empty()) - rewriter.eraseOp(alloc); + mlir::Value resultValue = ret.getOperand(0); + fir::LoadOp resultLoad; + mlir::Value resultStorage; + // Identify result local storage. + if (auto load = resultValue.getDefiningOp<fir::LoadOp>()) { + resultLoad = load; + resultStorage = load.getMemref(); + // The result alloca may be behind a fir.declare, if any. + if (auto declare = resultStorage.getDefiningOp<fir::DeclareOp>()) + resultStorage = declare.getMemref(); + } + // Replace old local storage with new storage argument, unless + // the derived type is C_PTR/C_FUN_PTR, in which case the return + // type is updated to return void* (no new argument is passed). + if (fir::isa_builtin_cptr_type(resultValue.getType())) { + auto module = ret->getParentOfType<mlir::ModuleOp>(); + FirOpBuilder builder(rewriter, module); + mlir::Value cptr = resultValue; + if (resultLoad) { + // Replace whole derived type load by component load. + cptr = resultLoad.getMemref(); + rewriter.setInsertionPoint(resultLoad); } - // The result storage may have been optimized out by a memory to - // register pass, this is possible for fir.box results, or fir.record - // with no length parameters. Simply store the result in the result storage. - // at the return point. - if (!replacedStorage) - rewriter.create<fir::StoreOp>(loc, returnedValue, newArg); - rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret); + mlir::Value newResultValue = + fir::factory::genCPtrOrCFunptrValue(builder, loc, cptr); + newResultValue = builder.createConvert( + loc, getVoidPtrType(ret.getContext()), newResultValue); + rewriter.setInsertionPoint(ret); + rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>( + ret, mlir::ValueRange{newResultValue}); + } else if (resultStorage) { + resultStorage.replaceAllUsesWith(newArg); + rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret); + } else { + // The result storage may have been optimized out by a memory to + // register pass, this is possible for fir.box results, or fir.record + // with no length parameters. Simply store the result in the result + // storage. at the return point. + rewriter.create<fir::StoreOp>(loc, resultValue, newArg); + rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret); + } + // Delete result old local storage if unused. + if (resultStorage) + if (auto alloc = resultStorage.getDefiningOp<fir::AllocaOp>()) + if (alloc->use_empty()) + rewriter.eraseOp(alloc); return mlir::success(); } @@ -263,8 +271,6 @@ public: mlir::PatternRewriter &rewriter) const override { auto oldFuncTy = mlir::cast<mlir::FunctionType>(addrOf.getType()); mlir::FunctionType newFuncTy; - // TODO: This should be generalized for derived types, and it is - // architecture and OS dependent. if (oldFuncTy.getNumResults() != 0 && fir::isa_builtin_cptr_type(oldFuncTy.getResult(0))) newFuncTy = getCPtrFunctionType(oldFuncTy); @@ -298,8 +304,6 @@ public: // Convert function type itself if it has an abstract result. auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType()); if (hasAbstractResult(funcTy)) { - // TODO: This should be generalized for derived types, and it is - // architecture and OS dependent. if (fir::isa_builtin_cptr_type(funcTy.getResult(0))) { func.setType(getCPtrFunctionType(funcTy)); patterns.insert<ReturnOpConversion>(context, mlir::Value{}); diff --git a/flang/test/Fir/abstract-results.fir b/flang/test/Fir/abstract-results.fir index 82f1cd3..93e63dc 100644 --- a/flang/test/Fir/abstract-results.fir +++ b/flang/test/Fir/abstract-results.fir @@ -87,8 +87,8 @@ func.func @boxfunc_callee() -> !fir.box<!fir.heap<f64>> { // FUNC-BOX: return } -// FUNC-REF-LABEL: func @retcptr() -> i64 -// FUNC-BOX-LABEL: func @retcptr() -> i64 +// FUNC-REF-LABEL: func @retcptr() -> !fir.ref<none> +// FUNC-BOX-LABEL: func @retcptr() -> !fir.ref<none> func.func @retcptr() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> { %0 = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = "rec", uniq_name = "_QFrecErec"} %1 = fir.load %0 : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>> @@ -98,12 +98,14 @@ func.func @retcptr() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__addres // FUNC-REF: %[[FIELD:.*]] = fir.field_index __address, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> // FUNC-REF: %[[ADDR:.*]] = fir.coordinate_of %[[ALLOC]], %[[FIELD]] : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64> // FUNC-REF: %[[VAL:.*]] = fir.load %[[ADDR]] : !fir.ref<i64> - // FUNC-REF: return %[[VAL]] : i64 + // FUNC-REF: %[[CAST:.*]] = fir.convert %[[VAL]] : (i64) -> !fir.ref<none> + // FUNC-REF: return %[[CAST]] : !fir.ref<none> // FUNC-BOX: %[[ALLOC:.*]] = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = "rec", uniq_name = "_QFrecErec"} // FUNC-BOX: %[[FIELD:.*]] = fir.field_index __address, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> // FUNC-BOX: %[[ADDR:.*]] = fir.coordinate_of %[[ALLOC]], %[[FIELD]] : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64> // FUNC-BOX: %[[VAL:.*]] = fir.load %[[ADDR]] : !fir.ref<i64> - // FUNC-BOX: return %[[VAL]] : i64 + // FUNC-BOX: %[[CAST:.*]] = fir.convert %[[VAL]] : (i64) -> !fir.ref<none> + // FUNC-BOX: return %[[CAST]] : !fir.ref<none> } // FUNC-REF-LABEL: func private @arrayfunc_callee_declare( @@ -311,8 +313,8 @@ func.func @test_address_of() { } -// FUNC-REF-LABEL: func.func private @returns_null() -> i64 -// FUNC-BOX-LABEL: func.func private @returns_null() -> i64 +// FUNC-REF-LABEL: func.func private @returns_null() -> !fir.ref<none> +// FUNC-BOX-LABEL: func.func private @returns_null() -> !fir.ref<none> func.func private @returns_null() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> // FUNC-REF-LABEL: func @test_address_of_cptr @@ -323,12 +325,12 @@ func.func @test_address_of_cptr() { fir.call @_QMtest_c_func_modPsubr(%1) : (() -> ()) -> () return - // FUNC-REF: %[[VAL_0:.*]] = fir.address_of(@returns_null) : () -> i64 - // FUNC-REF: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> i64) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) + // FUNC-REF: %[[VAL_0:.*]] = fir.address_of(@returns_null) : () -> !fir.ref<none> + // FUNC-REF: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> !fir.ref<none>) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) // FUNC-REF: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> ()) // FUNC-REF: fir.call @_QMtest_c_func_modPsubr(%[[VAL_2]]) : (() -> ()) -> () - // FUNC-BOX: %[[VAL_0:.*]] = fir.address_of(@returns_null) : () -> i64 - // FUNC-BOX: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> i64) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) + // FUNC-BOX: %[[VAL_0:.*]] = fir.address_of(@returns_null) : () -> !fir.ref<none> + // FUNC-BOX: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> !fir.ref<none>) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) // FUNC-BOX: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> ()) // FUNC-BOX: fir.call @_QMtest_c_func_modPsubr(%[[VAL_2]]) : (() -> ()) -> () } @@ -380,18 +382,20 @@ func.func @test_indirect_calls_return_cptr(%arg0: () -> ()) { // FUNC-REF: %[[VAL_0:.*]] = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = ".result"} // FUNC-REF: %[[VAL_1:.*]] = fir.convert %[[ARG0]] : (() -> ()) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) - // FUNC-REF: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> i64) - // FUNC-REF: %[[VAL_3:.*]] = fir.call %[[VAL_2]]() : () -> i64 + // FUNC-REF: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> !fir.ref<none>) + // FUNC-REF: %[[VAL_3:.*]] = fir.call %[[VAL_2]]() : () -> !fir.ref<none> // FUNC-REF: %[[VAL_4:.*]] = fir.field_index __address, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> // FUNC-REF: %[[VAL_5:.*]] = fir.coordinate_of %[[VAL_0]], %[[VAL_4]] : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64> - // FUNC-REF: fir.store %[[VAL_3]] to %[[VAL_5]] : !fir.ref<i64> + // FUNC-REF: %[[CAST:.*]] = fir.convert %[[VAL_3]] : (!fir.ref<none>) -> i64 + // FUNC-REF: fir.store %[[CAST]] to %[[VAL_5]] : !fir.ref<i64> // FUNC-BOX: %[[VAL_0:.*]] = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = ".result"} // FUNC-BOX: %[[VAL_1:.*]] = fir.convert %[[ARG0]] : (() -> ()) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) - // FUNC-BOX: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> i64) - // FUNC-BOX: %[[VAL_3:.*]] = fir.call %[[VAL_2]]() : () -> i64 + // FUNC-BOX: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> !fir.ref<none>) + // FUNC-BOX: %[[VAL_3:.*]] = fir.call %[[VAL_2]]() : () -> !fir.ref<none> // FUNC-BOX: %[[VAL_4:.*]] = fir.field_index __address, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> // FUNC-BOX: %[[VAL_5:.*]] = fir.coordinate_of %[[VAL_0]], %[[VAL_4]] : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64> - // FUNC-BOX: fir.store %[[VAL_3]] to %[[VAL_5]] : !fir.ref<i64> + // FUNC-BOX: %[[CAST:.*]] = fir.convert %[[VAL_3]] : (!fir.ref<none>) -> i64 + // FUNC-BOX: fir.store %[[CAST]] to %[[VAL_5]] : !fir.ref<i64> } // ----------------------- Test GlobalOp rewrite ------------------------ |