diff options
-rw-r--r-- | flang/lib/Lower/OpenMP/ReductionProcessor.cpp | 22 | ||||
-rw-r--r-- | flang/lib/Lower/OpenMP/ReductionProcessor.h | 21 | ||||
-rw-r--r-- | flang/test/Lower/OpenMP/parallel-reduction-complex-mul.f90 | 50 | ||||
-rw-r--r-- | flang/test/Lower/OpenMP/parallel-reduction-complex.f90 | 50 |
4 files changed, 137 insertions, 6 deletions
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp index c1c9411..0453c01 100644 --- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp @@ -13,7 +13,9 @@ #include "ReductionProcessor.h" #include "flang/Lower/AbstractConverter.h" +#include "flang/Lower/ConvertType.h" #include "flang/Lower/SymbolMap.h" +#include "flang/Optimizer/Builder/Complex.h" #include "flang/Optimizer/Builder/HLFIRTools.h" #include "flang/Optimizer/Builder/Todo.h" #include "flang/Optimizer/Dialect/FIRType.h" @@ -131,7 +133,7 @@ ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type, fir::FirOpBuilder &builder) { type = fir::unwrapRefType(type); if (!fir::isa_integer(type) && !fir::isa_real(type) && - !mlir::isa<fir::LogicalType>(type)) + !fir::isa_complex(type) && !mlir::isa<fir::LogicalType>(type)) TODO(loc, "Reduction of some types is not supported"); switch (redId) { case ReductionIdentifier::MAX: { @@ -175,6 +177,16 @@ ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type, case ReductionIdentifier::OR: case ReductionIdentifier::EQV: case ReductionIdentifier::NEQV: + if (auto cplxTy = mlir::dyn_cast<fir::ComplexType>(type)) { + mlir::Type realTy = + Fortran::lower::convertReal(builder.getContext(), cplxTy.getFKind()); + mlir::Value initRe = builder.createRealConstant( + loc, realTy, getOperationIdentity(redId, loc)); + mlir::Value initIm = builder.createRealConstant(loc, realTy, 0); + + return fir::factory::Complex{builder, loc}.createComplex(type, initRe, + initIm); + } if (type.isa<mlir::FloatType>()) return builder.create<mlir::arith::ConstantOp>( loc, type, @@ -229,13 +241,13 @@ mlir::Value ReductionProcessor::createScalarCombiner( break; case ReductionIdentifier::ADD: reductionOp = - getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>( - builder, type, loc, op1, op2); + getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp, + fir::AddcOp>(builder, type, loc, op1, op2); break; case ReductionIdentifier::MULTIPLY: reductionOp = - getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>( - builder, type, loc, op1, op2); + getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp, + fir::MulcOp>(builder, type, loc, op1, op2); break; case ReductionIdentifier::AND: { mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h index ee27325..7ea252f 100644 --- a/flang/lib/Lower/OpenMP/ReductionProcessor.h +++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h @@ -100,6 +100,10 @@ public: static mlir::Value getReductionOperation(fir::FirOpBuilder &builder, mlir::Type type, mlir::Location loc, mlir::Value op1, mlir::Value op2); + template <typename FloatOp, typename IntegerOp, typename ComplexOp> + static mlir::Value getReductionOperation(fir::FirOpBuilder &builder, + mlir::Type type, mlir::Location loc, + mlir::Value op1, mlir::Value op2); static mlir::Value createScalarCombiner(fir::FirOpBuilder &builder, mlir::Location loc, @@ -136,12 +140,27 @@ ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder, mlir::Value op1, mlir::Value op2) { type = fir::unwrapRefType(type); assert(type.isIntOrIndexOrFloat() && - "only integer and float types are currently supported"); + "only integer, float and complex types are currently supported"); if (type.isIntOrIndex()) return builder.create<IntegerOp>(loc, op1, op2); return builder.create<FloatOp>(loc, op1, op2); } +template <typename FloatOp, typename IntegerOp, typename ComplexOp> +mlir::Value +ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder, + mlir::Type type, mlir::Location loc, + mlir::Value op1, mlir::Value op2) { + assert(type.isIntOrIndexOrFloat() || + fir::isa_complex(type) && + "only integer, float and complex types are currently supported"); + if (type.isIntOrIndex()) + return builder.create<IntegerOp>(loc, op1, op2); + if (fir::isa_real(type)) + return builder.create<FloatOp>(loc, op1, op2); + return builder.create<ComplexOp>(loc, op1, op2); +} + } // namespace omp } // namespace lower } // namespace Fortran diff --git a/flang/test/Lower/OpenMP/parallel-reduction-complex-mul.f90 b/flang/test/Lower/OpenMP/parallel-reduction-complex-mul.f90 new file mode 100644 index 0000000..376defb --- /dev/null +++ b/flang/test/Lower/OpenMP/parallel-reduction-complex-mul.f90 @@ -0,0 +1,50 @@ +! RUN: bbc -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s +! RUN: %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s + +!CHECK-LABEL: omp.declare_reduction +!CHECK-SAME: @[[RED_NAME:.*]] : !fir.complex<8> init { +!CHECK: ^bb0(%{{.*}}: !fir.complex<8>): +!CHECK: %[[C0_1:.*]] = arith.constant 1.000000e+00 : f64 +!CHECK: %[[C0_2:.*]] = arith.constant 0.000000e+00 : f64 +!CHECK: %[[UNDEF:.*]] = fir.undefined !fir.complex<8> +!CHECK: %[[RES_1:.*]] = fir.insert_value %[[UNDEF]], %[[C0_1]], [0 : index] +!CHECK: %[[RES_2:.*]] = fir.insert_value %[[RES_1]], %[[C0_2]], [1 : index] +!CHECK: omp.yield(%[[RES_2]] : !fir.complex<8>) +!CHECK: } combiner { +!CHECK: ^bb0(%[[ARG0:.*]]: !fir.complex<8>, %[[ARG1:.*]]: !fir.complex<8>): +!CHECK: %[[RES:.*]] = fir.mulc %[[ARG0]], %[[ARG1]] {{.*}}: !fir.complex<8> +!CHECK: omp.yield(%[[RES]] : !fir.complex<8>) +!CHECK: } + +!CHECK-LABEL: func.func @_QPsimple_complex_mul +!CHECK: %[[CREF:.*]] = fir.alloca !fir.complex<8> {bindc_name = "c", {{.*}}} +!CHECK: %[[C_DECL:.*]]:2 = hlfir.declare %[[CREF]] {uniq_name = "_QFsimple_complex_mulEc"} : (!fir.ref<!fir.complex<8>>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>) +!CHECK: %[[C_START_RE:.*]] = arith.constant 0.000000e+00 : f64 +!CHECK: %[[C_START_IM:.*]] = arith.constant 0.000000e+00 : f64 +!CHECK: %[[UNDEF_1:.*]] = fir.undefined !fir.complex<8> +!CHECK: %[[VAL_1:.*]] = fir.insert_value %[[UNDEF_1]], %[[C_START_RE]], [0 : index] +!CHECK: %[[VAL_2:.*]] = fir.insert_value %[[VAL_1]], %[[C_START_IM]], [1 : index] +!CHECK: hlfir.assign %[[VAL_2]] to %[[C_DECL]]#0 : !fir.complex<8>, !fir.ref<!fir.complex<8>> +!CHECK: omp.parallel reduction(@[[RED_NAME]] %[[C_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<!fir.complex<8>>) { +!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<!fir.complex<8>>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>) +!CHECK: %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<!fir.complex<8>> +!CHECK: %[[C_INCR_RE:.*]] = arith.constant 1.000000e+00 : f64 +!CHECK: %[[C_INCR_IM:.*]] = arith.constant -2.000000e+00 : f64 +!CHECK: %[[UNDEF_2:.*]] = fir.undefined !fir.complex<8> +!CHECK: %[[INCR_1:.*]] = fir.insert_value %[[UNDEF_2]], %[[C_INCR_RE]], [0 : index] +!CHECK: %[[INCR_2:.*]] = fir.insert_value %[[INCR_1]], %[[C_INCR_IM]], [1 : index] +!CHECK: %[[RES:.+]] = fir.mulc %[[LPRV]], %[[INCR_2]] {{.*}} : !fir.complex<8> +!CHECK: hlfir.assign %[[RES]] to %[[P_DECL]]#0 : !fir.complex<8>, !fir.ref<!fir.complex<8>> +!CHECK: omp.terminator +!CHECK: } +!CHECK: return +subroutine simple_complex_mul + complex(8) :: c + c = 0 + + !$omp parallel reduction(*:c) + c = c * cmplx(1, -2) + !$omp end parallel + + print *, c +end subroutine diff --git a/flang/test/Lower/OpenMP/parallel-reduction-complex.f90 b/flang/test/Lower/OpenMP/parallel-reduction-complex.f90 new file mode 100644 index 0000000..bc5a6b4 --- /dev/null +++ b/flang/test/Lower/OpenMP/parallel-reduction-complex.f90 @@ -0,0 +1,50 @@ +! RUN: bbc -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s +! RUN: %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s + +!CHECK-LABEL: omp.declare_reduction +!CHECK-SAME: @[[RED_NAME:.*]] : !fir.complex<8> init { +!CHECK: ^bb0(%{{.*}}: !fir.complex<8>): +!CHECK: %[[C0_1:.*]] = arith.constant 0.000000e+00 : f64 +!CHECK: %[[C0_2:.*]] = arith.constant 0.000000e+00 : f64 +!CHECK: %[[UNDEF:.*]] = fir.undefined !fir.complex<8> +!CHECK: %[[RES_1:.*]] = fir.insert_value %[[UNDEF]], %[[C0_1]], [0 : index] +!CHECK: %[[RES_2:.*]] = fir.insert_value %[[RES_1]], %[[C0_2]], [1 : index] +!CHECK: omp.yield(%[[RES_2]] : !fir.complex<8>) +!CHECK: } combiner { +!CHECK: ^bb0(%[[ARG0:.*]]: !fir.complex<8>, %[[ARG1:.*]]: !fir.complex<8>): +!CHECK: %[[RES:.*]] = fir.addc %[[ARG0]], %[[ARG1]] {{.*}}: !fir.complex<8> +!CHECK: omp.yield(%[[RES]] : !fir.complex<8>) +!CHECK: } + +!CHECK-LABEL: func.func @_QPsimple_complex_add +!CHECK: %[[CREF:.*]] = fir.alloca !fir.complex<8> {bindc_name = "c", {{.*}}} +!CHECK: %[[C_DECL:.*]]:2 = hlfir.declare %[[CREF]] {uniq_name = "_QFsimple_complex_addEc"} : (!fir.ref<!fir.complex<8>>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>) +!CHECK: %[[C_START_RE:.*]] = arith.constant 0.000000e+00 : f64 +!CHECK: %[[C_START_IM:.*]] = arith.constant 0.000000e+00 : f64 +!CHECK: %[[UNDEF_1:.*]] = fir.undefined !fir.complex<8> +!CHECK: %[[VAL_1:.*]] = fir.insert_value %[[UNDEF_1]], %[[C_START_RE]], [0 : index] +!CHECK: %[[VAL_2:.*]] = fir.insert_value %[[VAL_1]], %[[C_START_IM]], [1 : index] +!CHECK: hlfir.assign %[[VAL_2]] to %[[C_DECL]]#0 : !fir.complex<8>, !fir.ref<!fir.complex<8>> +!CHECK: omp.parallel reduction(@[[RED_NAME]] %[[C_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<!fir.complex<8>>) { +!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<!fir.complex<8>>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>) +!CHECK: %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<!fir.complex<8>> +!CHECK: %[[C_INCR_RE:.*]] = arith.constant 1.000000e+00 : f64 +!CHECK: %[[C_INCR_IM:.*]] = arith.constant 0.000000e+00 : f64 +!CHECK: %[[UNDEF_2:.*]] = fir.undefined !fir.complex<8> +!CHECK: %[[INCR_1:.*]] = fir.insert_value %[[UNDEF_2]], %[[C_INCR_RE]], [0 : index] +!CHECK: %[[INCR_2:.*]] = fir.insert_value %[[INCR_1]], %[[C_INCR_IM]], [1 : index] +!CHECK: %[[RES:.+]] = fir.addc %[[LPRV]], %[[INCR_2]] {{.*}} : !fir.complex<8> +!CHECK: hlfir.assign %[[RES]] to %[[P_DECL]]#0 : !fir.complex<8>, !fir.ref<!fir.complex<8>> +!CHECK: omp.terminator +!CHECK: } +!CHECK: return +subroutine simple_complex_add + complex(8) :: c + c = 0 + + !$omp parallel reduction(+:c) + c = c + 1 + !$omp end parallel + + print *, c +end subroutine |