aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--flang/lib/Lower/OpenMP/ReductionProcessor.cpp22
-rw-r--r--flang/lib/Lower/OpenMP/ReductionProcessor.h21
-rw-r--r--flang/test/Lower/OpenMP/parallel-reduction-complex-mul.f9050
-rw-r--r--flang/test/Lower/OpenMP/parallel-reduction-complex.f9050
4 files changed, 137 insertions, 6 deletions
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index c1c9411..0453c01 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -13,7 +13,9 @@
#include "ReductionProcessor.h"
#include "flang/Lower/AbstractConverter.h"
+#include "flang/Lower/ConvertType.h"
#include "flang/Lower/SymbolMap.h"
+#include "flang/Optimizer/Builder/Complex.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIRType.h"
@@ -131,7 +133,7 @@ ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
fir::FirOpBuilder &builder) {
type = fir::unwrapRefType(type);
if (!fir::isa_integer(type) && !fir::isa_real(type) &&
- !mlir::isa<fir::LogicalType>(type))
+ !fir::isa_complex(type) && !mlir::isa<fir::LogicalType>(type))
TODO(loc, "Reduction of some types is not supported");
switch (redId) {
case ReductionIdentifier::MAX: {
@@ -175,6 +177,16 @@ ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
case ReductionIdentifier::OR:
case ReductionIdentifier::EQV:
case ReductionIdentifier::NEQV:
+ if (auto cplxTy = mlir::dyn_cast<fir::ComplexType>(type)) {
+ mlir::Type realTy =
+ Fortran::lower::convertReal(builder.getContext(), cplxTy.getFKind());
+ mlir::Value initRe = builder.createRealConstant(
+ loc, realTy, getOperationIdentity(redId, loc));
+ mlir::Value initIm = builder.createRealConstant(loc, realTy, 0);
+
+ return fir::factory::Complex{builder, loc}.createComplex(type, initRe,
+ initIm);
+ }
if (type.isa<mlir::FloatType>())
return builder.create<mlir::arith::ConstantOp>(
loc, type,
@@ -229,13 +241,13 @@ mlir::Value ReductionProcessor::createScalarCombiner(
break;
case ReductionIdentifier::ADD:
reductionOp =
- getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>(
- builder, type, loc, op1, op2);
+ getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp,
+ fir::AddcOp>(builder, type, loc, op1, op2);
break;
case ReductionIdentifier::MULTIPLY:
reductionOp =
- getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>(
- builder, type, loc, op1, op2);
+ getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp,
+ fir::MulcOp>(builder, type, loc, op1, op2);
break;
case ReductionIdentifier::AND: {
mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h
index ee27325..7ea252f 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.h
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h
@@ -100,6 +100,10 @@ public:
static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
mlir::Type type, mlir::Location loc,
mlir::Value op1, mlir::Value op2);
+ template <typename FloatOp, typename IntegerOp, typename ComplexOp>
+ static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
+ mlir::Type type, mlir::Location loc,
+ mlir::Value op1, mlir::Value op2);
static mlir::Value createScalarCombiner(fir::FirOpBuilder &builder,
mlir::Location loc,
@@ -136,12 +140,27 @@ ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
mlir::Value op1, mlir::Value op2) {
type = fir::unwrapRefType(type);
assert(type.isIntOrIndexOrFloat() &&
- "only integer and float types are currently supported");
+ "only integer, float and complex types are currently supported");
if (type.isIntOrIndex())
return builder.create<IntegerOp>(loc, op1, op2);
return builder.create<FloatOp>(loc, op1, op2);
}
+template <typename FloatOp, typename IntegerOp, typename ComplexOp>
+mlir::Value
+ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
+ mlir::Type type, mlir::Location loc,
+ mlir::Value op1, mlir::Value op2) {
+ assert(type.isIntOrIndexOrFloat() ||
+ fir::isa_complex(type) &&
+ "only integer, float and complex types are currently supported");
+ if (type.isIntOrIndex())
+ return builder.create<IntegerOp>(loc, op1, op2);
+ if (fir::isa_real(type))
+ return builder.create<FloatOp>(loc, op1, op2);
+ return builder.create<ComplexOp>(loc, op1, op2);
+}
+
} // namespace omp
} // namespace lower
} // namespace Fortran
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-complex-mul.f90 b/flang/test/Lower/OpenMP/parallel-reduction-complex-mul.f90
new file mode 100644
index 0000000..376defb
--- /dev/null
+++ b/flang/test/Lower/OpenMP/parallel-reduction-complex-mul.f90
@@ -0,0 +1,50 @@
+! RUN: bbc -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
+
+!CHECK-LABEL: omp.declare_reduction
+!CHECK-SAME: @[[RED_NAME:.*]] : !fir.complex<8> init {
+!CHECK: ^bb0(%{{.*}}: !fir.complex<8>):
+!CHECK: %[[C0_1:.*]] = arith.constant 1.000000e+00 : f64
+!CHECK: %[[C0_2:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK: %[[UNDEF:.*]] = fir.undefined !fir.complex<8>
+!CHECK: %[[RES_1:.*]] = fir.insert_value %[[UNDEF]], %[[C0_1]], [0 : index]
+!CHECK: %[[RES_2:.*]] = fir.insert_value %[[RES_1]], %[[C0_2]], [1 : index]
+!CHECK: omp.yield(%[[RES_2]] : !fir.complex<8>)
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[ARG0:.*]]: !fir.complex<8>, %[[ARG1:.*]]: !fir.complex<8>):
+!CHECK: %[[RES:.*]] = fir.mulc %[[ARG0]], %[[ARG1]] {{.*}}: !fir.complex<8>
+!CHECK: omp.yield(%[[RES]] : !fir.complex<8>)
+!CHECK: }
+
+!CHECK-LABEL: func.func @_QPsimple_complex_mul
+!CHECK: %[[CREF:.*]] = fir.alloca !fir.complex<8> {bindc_name = "c", {{.*}}}
+!CHECK: %[[C_DECL:.*]]:2 = hlfir.declare %[[CREF]] {uniq_name = "_QFsimple_complex_mulEc"} : (!fir.ref<!fir.complex<8>>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>)
+!CHECK: %[[C_START_RE:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK: %[[C_START_IM:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK: %[[UNDEF_1:.*]] = fir.undefined !fir.complex<8>
+!CHECK: %[[VAL_1:.*]] = fir.insert_value %[[UNDEF_1]], %[[C_START_RE]], [0 : index]
+!CHECK: %[[VAL_2:.*]] = fir.insert_value %[[VAL_1]], %[[C_START_IM]], [1 : index]
+!CHECK: hlfir.assign %[[VAL_2]] to %[[C_DECL]]#0 : !fir.complex<8>, !fir.ref<!fir.complex<8>>
+!CHECK: omp.parallel reduction(@[[RED_NAME]] %[[C_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<!fir.complex<8>>) {
+!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<!fir.complex<8>>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>)
+!CHECK: %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<!fir.complex<8>>
+!CHECK: %[[C_INCR_RE:.*]] = arith.constant 1.000000e+00 : f64
+!CHECK: %[[C_INCR_IM:.*]] = arith.constant -2.000000e+00 : f64
+!CHECK: %[[UNDEF_2:.*]] = fir.undefined !fir.complex<8>
+!CHECK: %[[INCR_1:.*]] = fir.insert_value %[[UNDEF_2]], %[[C_INCR_RE]], [0 : index]
+!CHECK: %[[INCR_2:.*]] = fir.insert_value %[[INCR_1]], %[[C_INCR_IM]], [1 : index]
+!CHECK: %[[RES:.+]] = fir.mulc %[[LPRV]], %[[INCR_2]] {{.*}} : !fir.complex<8>
+!CHECK: hlfir.assign %[[RES]] to %[[P_DECL]]#0 : !fir.complex<8>, !fir.ref<!fir.complex<8>>
+!CHECK: omp.terminator
+!CHECK: }
+!CHECK: return
+subroutine simple_complex_mul
+ complex(8) :: c
+ c = 0
+
+ !$omp parallel reduction(*:c)
+ c = c * cmplx(1, -2)
+ !$omp end parallel
+
+ print *, c
+end subroutine
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-complex.f90 b/flang/test/Lower/OpenMP/parallel-reduction-complex.f90
new file mode 100644
index 0000000..bc5a6b4
--- /dev/null
+++ b/flang/test/Lower/OpenMP/parallel-reduction-complex.f90
@@ -0,0 +1,50 @@
+! RUN: bbc -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
+
+!CHECK-LABEL: omp.declare_reduction
+!CHECK-SAME: @[[RED_NAME:.*]] : !fir.complex<8> init {
+!CHECK: ^bb0(%{{.*}}: !fir.complex<8>):
+!CHECK: %[[C0_1:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK: %[[C0_2:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK: %[[UNDEF:.*]] = fir.undefined !fir.complex<8>
+!CHECK: %[[RES_1:.*]] = fir.insert_value %[[UNDEF]], %[[C0_1]], [0 : index]
+!CHECK: %[[RES_2:.*]] = fir.insert_value %[[RES_1]], %[[C0_2]], [1 : index]
+!CHECK: omp.yield(%[[RES_2]] : !fir.complex<8>)
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[ARG0:.*]]: !fir.complex<8>, %[[ARG1:.*]]: !fir.complex<8>):
+!CHECK: %[[RES:.*]] = fir.addc %[[ARG0]], %[[ARG1]] {{.*}}: !fir.complex<8>
+!CHECK: omp.yield(%[[RES]] : !fir.complex<8>)
+!CHECK: }
+
+!CHECK-LABEL: func.func @_QPsimple_complex_add
+!CHECK: %[[CREF:.*]] = fir.alloca !fir.complex<8> {bindc_name = "c", {{.*}}}
+!CHECK: %[[C_DECL:.*]]:2 = hlfir.declare %[[CREF]] {uniq_name = "_QFsimple_complex_addEc"} : (!fir.ref<!fir.complex<8>>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>)
+!CHECK: %[[C_START_RE:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK: %[[C_START_IM:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK: %[[UNDEF_1:.*]] = fir.undefined !fir.complex<8>
+!CHECK: %[[VAL_1:.*]] = fir.insert_value %[[UNDEF_1]], %[[C_START_RE]], [0 : index]
+!CHECK: %[[VAL_2:.*]] = fir.insert_value %[[VAL_1]], %[[C_START_IM]], [1 : index]
+!CHECK: hlfir.assign %[[VAL_2]] to %[[C_DECL]]#0 : !fir.complex<8>, !fir.ref<!fir.complex<8>>
+!CHECK: omp.parallel reduction(@[[RED_NAME]] %[[C_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<!fir.complex<8>>) {
+!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<!fir.complex<8>>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>)
+!CHECK: %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<!fir.complex<8>>
+!CHECK: %[[C_INCR_RE:.*]] = arith.constant 1.000000e+00 : f64
+!CHECK: %[[C_INCR_IM:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK: %[[UNDEF_2:.*]] = fir.undefined !fir.complex<8>
+!CHECK: %[[INCR_1:.*]] = fir.insert_value %[[UNDEF_2]], %[[C_INCR_RE]], [0 : index]
+!CHECK: %[[INCR_2:.*]] = fir.insert_value %[[INCR_1]], %[[C_INCR_IM]], [1 : index]
+!CHECK: %[[RES:.+]] = fir.addc %[[LPRV]], %[[INCR_2]] {{.*}} : !fir.complex<8>
+!CHECK: hlfir.assign %[[RES]] to %[[P_DECL]]#0 : !fir.complex<8>, !fir.ref<!fir.complex<8>>
+!CHECK: omp.terminator
+!CHECK: }
+!CHECK: return
+subroutine simple_complex_add
+ complex(8) :: c
+ c = 0
+
+ !$omp parallel reduction(+:c)
+ c = c + 1
+ !$omp end parallel
+
+ print *, c
+end subroutine