diff options
author | Akash Banerjee <akash.banerjee@amd.com> | 2025-08-20 18:18:30 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-08-20 17:18:30 +0000 |
commit | d69ccded4ff14644245990e4ecc4f96e9610dd3d (patch) | |
tree | f9d2cb36f42e70f9f8c3fcba2139a54fbab4af3b | |
parent | 65de318d186c815f43b892aa20b98c50f22ab6fe (diff) | |
download | llvm-d69ccded4ff14644245990e4ecc4f96e9610dd3d.zip llvm-d69ccded4ff14644245990e4ecc4f96e9610dd3d.tar.gz llvm-d69ccded4ff14644245990e4ecc4f96e9610dd3d.tar.bz2 |
[MLIR] Add cpow support in ComplexToROCDLLibraryCalls (#153183)
This PR adds support for complex power operations (`cpow`) in the
`ComplexToROCDLLibraryCalls` conversion pass, specifically targeting
AMDGPU architectures. The implementation optimises complex
exponentiation by using mathematical identities and special-case
handling for small integer powers.
- Force lowering to `complex.pow` operations for the `amdgcn-amd-amdhsa`
target instead of using library calls
- Convert `complex.pow(z, w)` to `complex.exp(w * complex.log(z))` using
mathematical identity
4 files changed, 87 insertions, 30 deletions
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp index 8aacdb1..cc3ae31 100644 --- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp +++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp @@ -1287,6 +1287,26 @@ mlir::Value genComplexMathOp(fir::FirOpBuilder &builder, mlir::Location loc, return result; } +mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc, + const MathOperation &mathOp, + mlir::FunctionType mathLibFuncType, + llvm::ArrayRef<mlir::Value> args) { + bool isAMDGPU = fir::getTargetTriple(builder.getModule()).isAMDGCN(); + if (!isAMDGPU) + return genLibCall(builder, loc, mathOp, mathLibFuncType, args); + + auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0)); + auto realTy = complexTy.getElementType(); + mlir::Value realExp = builder.createConvert(loc, realTy, args[1]); + mlir::Value zero = builder.createRealConstant(loc, realTy, 0); + mlir::Value complexExp = + builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero); + mlir::Value result = + builder.create<mlir::complex::PowOp>(loc, args[0], complexExp); + result = builder.createConvert(loc, mathLibFuncType.getResult(0), result); + return result; +} + /// Mapping between mathematical intrinsic operations and MLIR operations /// of some appropriate dialect (math, complex, etc.) or libm calls. /// TODO: support remaining Fortran math intrinsics. @@ -1636,15 +1656,19 @@ static constexpr MathOperation mathOperations[] = { genFuncType<Ty::Real<16>, Ty::Real<16>, Ty::Integer<8>>, genMathOp<mlir::math::FPowIOp>}, {"pow", RTNAME_STRING(cpowi), - genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<4>>, genLibCall}, + genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<4>>, + genComplexPow}, {"pow", RTNAME_STRING(zpowi), - genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>, genLibCall}, + genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>, + genComplexPow}, {"pow", RTNAME_STRING(cqpowi), FuncTypeComplex16Complex16Integer4, genLibF128Call}, {"pow", RTNAME_STRING(cpowk), - genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>, genLibCall}, + genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>, + genComplexPow}, {"pow", RTNAME_STRING(zpowk), - genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>, genLibCall}, + genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>, + genComplexPow}, {"pow", RTNAME_STRING(cqpowk), FuncTypeComplex16Complex16Integer8, genLibF128Call}, {"remainder", "remainderf", @@ -4044,21 +4068,20 @@ void IntrinsicLibrary::genExecuteCommandLine( mlir::Value waitAddr = fir::getBase(wait); mlir::Value waitIsPresentAtRuntime = builder.genIsNotNullAddr(loc, waitAddr); - waitBool = builder - .genIfOp(loc, {i1Ty}, waitIsPresentAtRuntime, - /*withElseRegion=*/true) - .genThen([&]() { - auto waitLoad = - fir::LoadOp::create(builder, loc, waitAddr); - mlir::Value cast = - builder.createConvert(loc, i1Ty, waitLoad); - fir::ResultOp::create(builder, loc, cast); - }) - .genElse([&]() { - mlir::Value trueVal = builder.createBool(loc, true); - fir::ResultOp::create(builder, loc, trueVal); - }) - .getResults()[0]; + waitBool = + builder + .genIfOp(loc, {i1Ty}, waitIsPresentAtRuntime, + /*withElseRegion=*/true) + .genThen([&]() { + auto waitLoad = fir::LoadOp::create(builder, loc, waitAddr); + mlir::Value cast = builder.createConvert(loc, i1Ty, waitLoad); + fir::ResultOp::create(builder, loc, cast); + }) + .genElse([&]() { + mlir::Value trueVal = builder.createBool(loc, true); + fir::ResultOp::create(builder, loc, trueVal); + }) + .getResults()[0]; } mlir::Value exitstatBox = diff --git a/flang/test/Lower/amdgcn-complex.f90 b/flang/test/Lower/amdgcn-complex.f90 index f15c7db..4ee5de4 100644 --- a/flang/test/Lower/amdgcn-complex.f90 +++ b/flang/test/Lower/amdgcn-complex.f90 @@ -1,21 +1,27 @@ ! REQUIRES: amdgpu-registered-target -! RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -emit-fir -flang-deprecated-no-hlfir %s -o - | FileCheck %s +! RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -emit-fir %s -o - | FileCheck %s +! CHECK-LABEL: func @_QPcabsf_test( +! CHECK: complex.abs +! CHECK-NOT: fir.call @cabsf subroutine cabsf_test(a, b) complex :: a real :: b b = abs(a) end subroutine -! CHECK-LABEL: func @_QPcabsf_test( -! CHECK: complex.abs -! CHECK-NOT: fir.call @cabsf - +! CHECK-LABEL: func @_QPcexpf_test( +! CHECK: complex.exp +! CHECK-NOT: fir.call @cexpf subroutine cexpf_test(a, b) complex :: a, b b = exp(a) end subroutine -! CHECK-LABEL: func @_QPcexpf_test( -! CHECK: complex.exp -! CHECK-NOT: fir.call @cexpf +! CHECK-LABEL: func @_QPpow_test( +! CHECK: complex.pow +! CHECK-NOT: fir.call @_FortranAcpowi +subroutine pow_test(a, b, c) + complex :: a, b, c + a = b**c +end subroutine pow_test diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp index b3d6d59..7a3a7fd 100644 --- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp +++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp @@ -56,10 +56,26 @@ struct ComplexOpToROCDLLibraryCalls : public OpRewritePattern<Op> { private: std::string funcName; }; + +// Rewrite complex.pow(z, w) -> complex.exp(w * complex.log(z)) +struct PowOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowOp> { + using OpRewritePattern<complex::PowOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(complex::PowOp op, + PatternRewriter &rewriter) const final { + Location loc = op.getLoc(); + Value logBase = rewriter.create<complex::LogOp>(loc, op.getLhs()); + Value mul = rewriter.create<complex::MulOp>(loc, op.getRhs(), logBase); + Value exp = rewriter.create<complex::ExpOp>(loc, mul); + rewriter.replaceOp(op, exp); + return success(); + } +}; } // namespace void mlir::populateComplexToROCDLLibraryCallsConversionPatterns( RewritePatternSet &patterns) { + patterns.add<PowOpToROCDLLibraryCalls>(patterns.getContext()); patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float32Type>>( patterns.getContext(), "__ocml_cabs_f32"); patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>( @@ -110,9 +126,10 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() { ConversionTarget target(getContext()); target.addLegalDialect<func::FuncDialect>(); + target.addLegalOp<complex::MulOp>(); target.addIllegalOp<complex::AbsOp, complex::CosOp, complex::ExpOp, - complex::LogOp, complex::SinOp, complex::SqrtOp, - complex::TanOp, complex::TanhOp>(); + complex::LogOp, complex::PowOp, complex::SinOp, + complex::SqrtOp, complex::TanOp, complex::TanhOp>(); if (failed(applyPartialConversion(op, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir index 82936d8..080ba4f 100644 --- a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir +++ b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -convert-complex-to-rocdl-library-calls | FileCheck %s +// RUN: mlir-opt %s --allow-unregistered-dialect -convert-complex-to-rocdl-library-calls | FileCheck %s // CHECK-DAG: @__ocml_cabs_f32(complex<f32>) -> f32 // CHECK-DAG: @__ocml_cabs_f64(complex<f64>) -> f64 @@ -57,6 +57,17 @@ func.func @log_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, comp return %lf, %ld : complex<f32>, complex<f64> } +//CHECK-LABEL: @pow_caller +//CHECK: (%[[Z:.*]]: complex<f32>, %[[W:.*]]: complex<f32>) +func.func @pow_caller(%z: complex<f32>, %w: complex<f32>) -> complex<f32> { + // CHECK: %[[LOG:.*]] = call @__ocml_clog_f32(%[[Z]]) + // CHECK: %[[MUL:.*]] = complex.mul %[[W]], %[[LOG]] + // CHECK: %[[EXP:.*]] = call @__ocml_cexp_f32(%[[MUL]]) + // CHECK: return %[[EXP]] + %r = complex.pow %z, %w : complex<f32> + return %r : complex<f32> +} + //CHECK-LABEL: @sin_caller func.func @sin_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) { // CHECK: %[[SF:.*]] = call @__ocml_csin_f32(%{{.*}}) |