summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMats Petersson <mats.petersson@arm.com>2024-04-08 10:18:14 +0100
committerGitHub <noreply@github.com>2024-04-08 10:18:14 +0100
commit221f438af1c1292d787b58da99a5a7b371888456 (patch)
tree369993eda4431924d7bc72ac2a41fed0ba9cfb27
parent364028a1a51689d2b33d3ec50c426fbeac269679 (diff)
[flang][OpenMP] Add support for complex reductions (#87488)
This adds support for complex type to the OpenMP reductions. Note that some more work would be needed to give decent error messages when complex is used in ways that need client supplied functions (e.g. MAX or MIN). It does fail these with a not so user friendly message at present.
-rw-r--r--flang/lib/Lower/OpenMP/ReductionProcessor.cpp22
-rw-r--r--flang/lib/Lower/OpenMP/ReductionProcessor.h21
-rw-r--r--flang/test/Lower/OpenMP/parallel-reduction-complex-mul.f9050
-rw-r--r--flang/test/Lower/OpenMP/parallel-reduction-complex.f9050
4 files changed, 137 insertions, 6 deletions
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index c1c94119fd90..0453c0152277 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -13,7 +13,9 @@
#include "ReductionProcessor.h"
#include "flang/Lower/AbstractConverter.h"
+#include "flang/Lower/ConvertType.h"
#include "flang/Lower/SymbolMap.h"
+#include "flang/Optimizer/Builder/Complex.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIRType.h"
@@ -131,7 +133,7 @@ ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
fir::FirOpBuilder &builder) {
type = fir::unwrapRefType(type);
if (!fir::isa_integer(type) && !fir::isa_real(type) &&
- !mlir::isa<fir::LogicalType>(type))
+ !fir::isa_complex(type) && !mlir::isa<fir::LogicalType>(type))
TODO(loc, "Reduction of some types is not supported");
switch (redId) {
case ReductionIdentifier::MAX: {
@@ -175,6 +177,16 @@ ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
case ReductionIdentifier::OR:
case ReductionIdentifier::EQV:
case ReductionIdentifier::NEQV:
+ if (auto cplxTy = mlir::dyn_cast<fir::ComplexType>(type)) {
+ mlir::Type realTy =
+ Fortran::lower::convertReal(builder.getContext(), cplxTy.getFKind());
+ mlir::Value initRe = builder.createRealConstant(
+ loc, realTy, getOperationIdentity(redId, loc));
+ mlir::Value initIm = builder.createRealConstant(loc, realTy, 0);
+
+ return fir::factory::Complex{builder, loc}.createComplex(type, initRe,
+ initIm);
+ }
if (type.isa<mlir::FloatType>())
return builder.create<mlir::arith::ConstantOp>(
loc, type,
@@ -229,13 +241,13 @@ mlir::Value ReductionProcessor::createScalarCombiner(
break;
case ReductionIdentifier::ADD:
reductionOp =
- getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>(
- builder, type, loc, op1, op2);
+ getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp,
+ fir::AddcOp>(builder, type, loc, op1, op2);
break;
case ReductionIdentifier::MULTIPLY:
reductionOp =
- getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>(
- builder, type, loc, op1, op2);
+ getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp,
+ fir::MulcOp>(builder, type, loc, op1, op2);
break;
case ReductionIdentifier::AND: {
mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h
index ee2732547fc2..7ea252fde360 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.h
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h
@@ -100,6 +100,10 @@ public:
static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
mlir::Type type, mlir::Location loc,
mlir::Value op1, mlir::Value op2);
+ template <typename FloatOp, typename IntegerOp, typename ComplexOp>
+ static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
+ mlir::Type type, mlir::Location loc,
+ mlir::Value op1, mlir::Value op2);
static mlir::Value createScalarCombiner(fir::FirOpBuilder &builder,
mlir::Location loc,
@@ -136,12 +140,27 @@ ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
mlir::Value op1, mlir::Value op2) {
type = fir::unwrapRefType(type);
assert(type.isIntOrIndexOrFloat() &&
- "only integer and float types are currently supported");
+ "only integer, float and complex types are currently supported");
if (type.isIntOrIndex())
return builder.create<IntegerOp>(loc, op1, op2);
return builder.create<FloatOp>(loc, op1, op2);
}
+template <typename FloatOp, typename IntegerOp, typename ComplexOp>
+mlir::Value
+ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
+ mlir::Type type, mlir::Location loc,
+ mlir::Value op1, mlir::Value op2) {
+ assert(type.isIntOrIndexOrFloat() ||
+ fir::isa_complex(type) &&
+ "only integer, float and complex types are currently supported");
+ if (type.isIntOrIndex())
+ return builder.create<IntegerOp>(loc, op1, op2);
+ if (fir::isa_real(type))
+ return builder.create<FloatOp>(loc, op1, op2);
+ return builder.create<ComplexOp>(loc, op1, op2);
+}
+
} // namespace omp
} // namespace lower
} // namespace Fortran
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-complex-mul.f90 b/flang/test/Lower/OpenMP/parallel-reduction-complex-mul.f90
new file mode 100644
index 000000000000..376defb82358
--- /dev/null
+++ b/flang/test/Lower/OpenMP/parallel-reduction-complex-mul.f90
@@ -0,0 +1,50 @@
+! RUN: bbc -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
+
+!CHECK-LABEL: omp.declare_reduction
+!CHECK-SAME: @[[RED_NAME:.*]] : !fir.complex<8> init {
+!CHECK: ^bb0(%{{.*}}: !fir.complex<8>):
+!CHECK: %[[C0_1:.*]] = arith.constant 1.000000e+00 : f64
+!CHECK: %[[C0_2:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK: %[[UNDEF:.*]] = fir.undefined !fir.complex<8>
+!CHECK: %[[RES_1:.*]] = fir.insert_value %[[UNDEF]], %[[C0_1]], [0 : index]
+!CHECK: %[[RES_2:.*]] = fir.insert_value %[[RES_1]], %[[C0_2]], [1 : index]
+!CHECK: omp.yield(%[[RES_2]] : !fir.complex<8>)
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[ARG0:.*]]: !fir.complex<8>, %[[ARG1:.*]]: !fir.complex<8>):
+!CHECK: %[[RES:.*]] = fir.mulc %[[ARG0]], %[[ARG1]] {{.*}}: !fir.complex<8>
+!CHECK: omp.yield(%[[RES]] : !fir.complex<8>)
+!CHECK: }
+
+!CHECK-LABEL: func.func @_QPsimple_complex_mul
+!CHECK: %[[CREF:.*]] = fir.alloca !fir.complex<8> {bindc_name = "c", {{.*}}}
+!CHECK: %[[C_DECL:.*]]:2 = hlfir.declare %[[CREF]] {uniq_name = "_QFsimple_complex_mulEc"} : (!fir.ref<!fir.complex<8>>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>)
+!CHECK: %[[C_START_RE:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK: %[[C_START_IM:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK: %[[UNDEF_1:.*]] = fir.undefined !fir.complex<8>
+!CHECK: %[[VAL_1:.*]] = fir.insert_value %[[UNDEF_1]], %[[C_START_RE]], [0 : index]
+!CHECK: %[[VAL_2:.*]] = fir.insert_value %[[VAL_1]], %[[C_START_IM]], [1 : index]
+!CHECK: hlfir.assign %[[VAL_2]] to %[[C_DECL]]#0 : !fir.complex<8>, !fir.ref<!fir.complex<8>>
+!CHECK: omp.parallel reduction(@[[RED_NAME]] %[[C_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<!fir.complex<8>>) {
+!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<!fir.complex<8>>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>)
+!CHECK: %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<!fir.complex<8>>
+!CHECK: %[[C_INCR_RE:.*]] = arith.constant 1.000000e+00 : f64
+!CHECK: %[[C_INCR_IM:.*]] = arith.constant -2.000000e+00 : f64
+!CHECK: %[[UNDEF_2:.*]] = fir.undefined !fir.complex<8>
+!CHECK: %[[INCR_1:.*]] = fir.insert_value %[[UNDEF_2]], %[[C_INCR_RE]], [0 : index]
+!CHECK: %[[INCR_2:.*]] = fir.insert_value %[[INCR_1]], %[[C_INCR_IM]], [1 : index]
+!CHECK: %[[RES:.+]] = fir.mulc %[[LPRV]], %[[INCR_2]] {{.*}} : !fir.complex<8>
+!CHECK: hlfir.assign %[[RES]] to %[[P_DECL]]#0 : !fir.complex<8>, !fir.ref<!fir.complex<8>>
+!CHECK: omp.terminator
+!CHECK: }
+!CHECK: return
+subroutine simple_complex_mul
+ complex(8) :: c
+ c = 0
+
+ !$omp parallel reduction(*:c)
+ c = c * cmplx(1, -2)
+ !$omp end parallel
+
+ print *, c
+end subroutine
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-complex.f90 b/flang/test/Lower/OpenMP/parallel-reduction-complex.f90
new file mode 100644
index 000000000000..bc5a6b475e25
--- /dev/null
+++ b/flang/test/Lower/OpenMP/parallel-reduction-complex.f90
@@ -0,0 +1,50 @@
+! RUN: bbc -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
+
+!CHECK-LABEL: omp.declare_reduction
+!CHECK-SAME: @[[RED_NAME:.*]] : !fir.complex<8> init {
+!CHECK: ^bb0(%{{.*}}: !fir.complex<8>):
+!CHECK: %[[C0_1:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK: %[[C0_2:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK: %[[UNDEF:.*]] = fir.undefined !fir.complex<8>
+!CHECK: %[[RES_1:.*]] = fir.insert_value %[[UNDEF]], %[[C0_1]], [0 : index]
+!CHECK: %[[RES_2:.*]] = fir.insert_value %[[RES_1]], %[[C0_2]], [1 : index]
+!CHECK: omp.yield(%[[RES_2]] : !fir.complex<8>)
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[ARG0:.*]]: !fir.complex<8>, %[[ARG1:.*]]: !fir.complex<8>):
+!CHECK: %[[RES:.*]] = fir.addc %[[ARG0]], %[[ARG1]] {{.*}}: !fir.complex<8>
+!CHECK: omp.yield(%[[RES]] : !fir.complex<8>)
+!CHECK: }
+
+!CHECK-LABEL: func.func @_QPsimple_complex_add
+!CHECK: %[[CREF:.*]] = fir.alloca !fir.complex<8> {bindc_name = "c", {{.*}}}
+!CHECK: %[[C_DECL:.*]]:2 = hlfir.declare %[[CREF]] {uniq_name = "_QFsimple_complex_addEc"} : (!fir.ref<!fir.complex<8>>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>)
+!CHECK: %[[C_START_RE:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK: %[[C_START_IM:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK: %[[UNDEF_1:.*]] = fir.undefined !fir.complex<8>
+!CHECK: %[[VAL_1:.*]] = fir.insert_value %[[UNDEF_1]], %[[C_START_RE]], [0 : index]
+!CHECK: %[[VAL_2:.*]] = fir.insert_value %[[VAL_1]], %[[C_START_IM]], [1 : index]
+!CHECK: hlfir.assign %[[VAL_2]] to %[[C_DECL]]#0 : !fir.complex<8>, !fir.ref<!fir.complex<8>>
+!CHECK: omp.parallel reduction(@[[RED_NAME]] %[[C_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<!fir.complex<8>>) {
+!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<!fir.complex<8>>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>)
+!CHECK: %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<!fir.complex<8>>
+!CHECK: %[[C_INCR_RE:.*]] = arith.constant 1.000000e+00 : f64
+!CHECK: %[[C_INCR_IM:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK: %[[UNDEF_2:.*]] = fir.undefined !fir.complex<8>
+!CHECK: %[[INCR_1:.*]] = fir.insert_value %[[UNDEF_2]], %[[C_INCR_RE]], [0 : index]
+!CHECK: %[[INCR_2:.*]] = fir.insert_value %[[INCR_1]], %[[C_INCR_IM]], [1 : index]
+!CHECK: %[[RES:.+]] = fir.addc %[[LPRV]], %[[INCR_2]] {{.*}} : !fir.complex<8>
+!CHECK: hlfir.assign %[[RES]] to %[[P_DECL]]#0 : !fir.complex<8>, !fir.ref<!fir.complex<8>>
+!CHECK: omp.terminator
+!CHECK: }
+!CHECK: return
+subroutine simple_complex_add
+ complex(8) :: c
+ c = 0
+
+ !$omp parallel reduction(+:c)
+ c = c + 1
+ !$omp end parallel
+
+ print *, c
+end subroutine