diff options
author | Mats Petersson <mats.petersson@arm.com> | 2024-04-08 10:18:14 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-08 10:18:14 +0100 |
commit | 221f438af1c1292d787b58da99a5a7b371888456 (patch) | |
tree | 369993eda4431924d7bc72ac2a41fed0ba9cfb27 | |
parent | 364028a1a51689d2b33d3ec50c426fbeac269679 (diff) |
[flang][OpenMP] Add support for complex reductions (#87488)
This adds support for the complex type to OpenMP reductions.
Note that more work is needed to give decent error messages when complex
is used in ways that need client-supplied functions (e.g. MAX or MIN). At present such uses do fail, but with
a message that is not very user-friendly.
-rw-r--r-- | flang/lib/Lower/OpenMP/ReductionProcessor.cpp | 22 | ||||
-rw-r--r-- | flang/lib/Lower/OpenMP/ReductionProcessor.h | 21 | ||||
-rw-r--r-- | flang/test/Lower/OpenMP/parallel-reduction-complex-mul.f90 | 50 | ||||
-rw-r--r-- | flang/test/Lower/OpenMP/parallel-reduction-complex.f90 | 50 |
4 files changed, 137 insertions, 6 deletions
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp index c1c94119fd90..0453c0152277 100644 --- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp @@ -13,7 +13,9 @@ #include "ReductionProcessor.h" #include "flang/Lower/AbstractConverter.h" +#include "flang/Lower/ConvertType.h" #include "flang/Lower/SymbolMap.h" +#include "flang/Optimizer/Builder/Complex.h" #include "flang/Optimizer/Builder/HLFIRTools.h" #include "flang/Optimizer/Builder/Todo.h" #include "flang/Optimizer/Dialect/FIRType.h" @@ -131,7 +133,7 @@ ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type, fir::FirOpBuilder &builder) { type = fir::unwrapRefType(type); if (!fir::isa_integer(type) && !fir::isa_real(type) && - !mlir::isa<fir::LogicalType>(type)) + !fir::isa_complex(type) && !mlir::isa<fir::LogicalType>(type)) TODO(loc, "Reduction of some types is not supported"); switch (redId) { case ReductionIdentifier::MAX: { @@ -175,6 +177,16 @@ ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type, case ReductionIdentifier::OR: case ReductionIdentifier::EQV: case ReductionIdentifier::NEQV: + if (auto cplxTy = mlir::dyn_cast<fir::ComplexType>(type)) { + mlir::Type realTy = + Fortran::lower::convertReal(builder.getContext(), cplxTy.getFKind()); + mlir::Value initRe = builder.createRealConstant( + loc, realTy, getOperationIdentity(redId, loc)); + mlir::Value initIm = builder.createRealConstant(loc, realTy, 0); + + return fir::factory::Complex{builder, loc}.createComplex(type, initRe, + initIm); + } if (type.isa<mlir::FloatType>()) return builder.create<mlir::arith::ConstantOp>( loc, type, @@ -229,13 +241,13 @@ mlir::Value ReductionProcessor::createScalarCombiner( break; case ReductionIdentifier::ADD: reductionOp = - getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>( - builder, type, loc, op1, op2); + getReductionOperation<mlir::arith::AddFOp, 
mlir::arith::AddIOp, + fir::AddcOp>(builder, type, loc, op1, op2); break; case ReductionIdentifier::MULTIPLY: reductionOp = - getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>( - builder, type, loc, op1, op2); + getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp, + fir::MulcOp>(builder, type, loc, op1, op2); break; case ReductionIdentifier::AND: { mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h index ee2732547fc2..7ea252fde360 100644 --- a/flang/lib/Lower/OpenMP/ReductionProcessor.h +++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h @@ -100,6 +100,10 @@ public: static mlir::Value getReductionOperation(fir::FirOpBuilder &builder, mlir::Type type, mlir::Location loc, mlir::Value op1, mlir::Value op2); + template <typename FloatOp, typename IntegerOp, typename ComplexOp> + static mlir::Value getReductionOperation(fir::FirOpBuilder &builder, + mlir::Type type, mlir::Location loc, + mlir::Value op1, mlir::Value op2); static mlir::Value createScalarCombiner(fir::FirOpBuilder &builder, mlir::Location loc, @@ -136,12 +140,27 @@ ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder, mlir::Value op1, mlir::Value op2) { type = fir::unwrapRefType(type); assert(type.isIntOrIndexOrFloat() && - "only integer and float types are currently supported"); + "only integer, float and complex types are currently supported"); if (type.isIntOrIndex()) return builder.create<IntegerOp>(loc, op1, op2); return builder.create<FloatOp>(loc, op1, op2); } +template <typename FloatOp, typename IntegerOp, typename ComplexOp> +mlir::Value +ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder, + mlir::Type type, mlir::Location loc, + mlir::Value op1, mlir::Value op2) { + assert(type.isIntOrIndexOrFloat() || + fir::isa_complex(type) && + "only integer, float and complex types are currently supported"); + if 
(type.isIntOrIndex()) + return builder.create<IntegerOp>(loc, op1, op2); + if (fir::isa_real(type)) + return builder.create<FloatOp>(loc, op1, op2); + return builder.create<ComplexOp>(loc, op1, op2); +} + } // namespace omp } // namespace lower } // namespace Fortran diff --git a/flang/test/Lower/OpenMP/parallel-reduction-complex-mul.f90 b/flang/test/Lower/OpenMP/parallel-reduction-complex-mul.f90 new file mode 100644 index 000000000000..376defb82358 --- /dev/null +++ b/flang/test/Lower/OpenMP/parallel-reduction-complex-mul.f90 @@ -0,0 +1,50 @@ +! RUN: bbc -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s +! RUN: %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s + +!CHECK-LABEL: omp.declare_reduction +!CHECK-SAME: @[[RED_NAME:.*]] : !fir.complex<8> init { +!CHECK: ^bb0(%{{.*}}: !fir.complex<8>): +!CHECK: %[[C0_1:.*]] = arith.constant 1.000000e+00 : f64 +!CHECK: %[[C0_2:.*]] = arith.constant 0.000000e+00 : f64 +!CHECK: %[[UNDEF:.*]] = fir.undefined !fir.complex<8> +!CHECK: %[[RES_1:.*]] = fir.insert_value %[[UNDEF]], %[[C0_1]], [0 : index] +!CHECK: %[[RES_2:.*]] = fir.insert_value %[[RES_1]], %[[C0_2]], [1 : index] +!CHECK: omp.yield(%[[RES_2]] : !fir.complex<8>) +!CHECK: } combiner { +!CHECK: ^bb0(%[[ARG0:.*]]: !fir.complex<8>, %[[ARG1:.*]]: !fir.complex<8>): +!CHECK: %[[RES:.*]] = fir.mulc %[[ARG0]], %[[ARG1]] {{.*}}: !fir.complex<8> +!CHECK: omp.yield(%[[RES]] : !fir.complex<8>) +!CHECK: } + +!CHECK-LABEL: func.func @_QPsimple_complex_mul +!CHECK: %[[CREF:.*]] = fir.alloca !fir.complex<8> {bindc_name = "c", {{.*}}} +!CHECK: %[[C_DECL:.*]]:2 = hlfir.declare %[[CREF]] {uniq_name = "_QFsimple_complex_mulEc"} : (!fir.ref<!fir.complex<8>>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>) +!CHECK: %[[C_START_RE:.*]] = arith.constant 0.000000e+00 : f64 +!CHECK: %[[C_START_IM:.*]] = arith.constant 0.000000e+00 : f64 +!CHECK: %[[UNDEF_1:.*]] = fir.undefined !fir.complex<8> +!CHECK: %[[VAL_1:.*]] = fir.insert_value %[[UNDEF_1]], %[[C_START_RE]], [0 : 
index] +!CHECK: %[[VAL_2:.*]] = fir.insert_value %[[VAL_1]], %[[C_START_IM]], [1 : index] +!CHECK: hlfir.assign %[[VAL_2]] to %[[C_DECL]]#0 : !fir.complex<8>, !fir.ref<!fir.complex<8>> +!CHECK: omp.parallel reduction(@[[RED_NAME]] %[[C_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<!fir.complex<8>>) { +!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<!fir.complex<8>>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>) +!CHECK: %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<!fir.complex<8>> +!CHECK: %[[C_INCR_RE:.*]] = arith.constant 1.000000e+00 : f64 +!CHECK: %[[C_INCR_IM:.*]] = arith.constant -2.000000e+00 : f64 +!CHECK: %[[UNDEF_2:.*]] = fir.undefined !fir.complex<8> +!CHECK: %[[INCR_1:.*]] = fir.insert_value %[[UNDEF_2]], %[[C_INCR_RE]], [0 : index] +!CHECK: %[[INCR_2:.*]] = fir.insert_value %[[INCR_1]], %[[C_INCR_IM]], [1 : index] +!CHECK: %[[RES:.+]] = fir.mulc %[[LPRV]], %[[INCR_2]] {{.*}} : !fir.complex<8> +!CHECK: hlfir.assign %[[RES]] to %[[P_DECL]]#0 : !fir.complex<8>, !fir.ref<!fir.complex<8>> +!CHECK: omp.terminator +!CHECK: } +!CHECK: return +subroutine simple_complex_mul + complex(8) :: c + c = 0 + + !$omp parallel reduction(*:c) + c = c * cmplx(1, -2) + !$omp end parallel + + print *, c +end subroutine diff --git a/flang/test/Lower/OpenMP/parallel-reduction-complex.f90 b/flang/test/Lower/OpenMP/parallel-reduction-complex.f90 new file mode 100644 index 000000000000..bc5a6b475e25 --- /dev/null +++ b/flang/test/Lower/OpenMP/parallel-reduction-complex.f90 @@ -0,0 +1,50 @@ +! RUN: bbc -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s +! 
RUN: %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s + +!CHECK-LABEL: omp.declare_reduction +!CHECK-SAME: @[[RED_NAME:.*]] : !fir.complex<8> init { +!CHECK: ^bb0(%{{.*}}: !fir.complex<8>): +!CHECK: %[[C0_1:.*]] = arith.constant 0.000000e+00 : f64 +!CHECK: %[[C0_2:.*]] = arith.constant 0.000000e+00 : f64 +!CHECK: %[[UNDEF:.*]] = fir.undefined !fir.complex<8> +!CHECK: %[[RES_1:.*]] = fir.insert_value %[[UNDEF]], %[[C0_1]], [0 : index] +!CHECK: %[[RES_2:.*]] = fir.insert_value %[[RES_1]], %[[C0_2]], [1 : index] +!CHECK: omp.yield(%[[RES_2]] : !fir.complex<8>) +!CHECK: } combiner { +!CHECK: ^bb0(%[[ARG0:.*]]: !fir.complex<8>, %[[ARG1:.*]]: !fir.complex<8>): +!CHECK: %[[RES:.*]] = fir.addc %[[ARG0]], %[[ARG1]] {{.*}}: !fir.complex<8> +!CHECK: omp.yield(%[[RES]] : !fir.complex<8>) +!CHECK: } + +!CHECK-LABEL: func.func @_QPsimple_complex_add +!CHECK: %[[CREF:.*]] = fir.alloca !fir.complex<8> {bindc_name = "c", {{.*}}} +!CHECK: %[[C_DECL:.*]]:2 = hlfir.declare %[[CREF]] {uniq_name = "_QFsimple_complex_addEc"} : (!fir.ref<!fir.complex<8>>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>) +!CHECK: %[[C_START_RE:.*]] = arith.constant 0.000000e+00 : f64 +!CHECK: %[[C_START_IM:.*]] = arith.constant 0.000000e+00 : f64 +!CHECK: %[[UNDEF_1:.*]] = fir.undefined !fir.complex<8> +!CHECK: %[[VAL_1:.*]] = fir.insert_value %[[UNDEF_1]], %[[C_START_RE]], [0 : index] +!CHECK: %[[VAL_2:.*]] = fir.insert_value %[[VAL_1]], %[[C_START_IM]], [1 : index] +!CHECK: hlfir.assign %[[VAL_2]] to %[[C_DECL]]#0 : !fir.complex<8>, !fir.ref<!fir.complex<8>> +!CHECK: omp.parallel reduction(@[[RED_NAME]] %[[C_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<!fir.complex<8>>) { +!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<!fir.complex<8>>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>) +!CHECK: %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<!fir.complex<8>> +!CHECK: %[[C_INCR_RE:.*]] = arith.constant 1.000000e+00 : f64 +!CHECK: %[[C_INCR_IM:.*]] = arith.constant 
0.000000e+00 : f64 +!CHECK: %[[UNDEF_2:.*]] = fir.undefined !fir.complex<8> +!CHECK: %[[INCR_1:.*]] = fir.insert_value %[[UNDEF_2]], %[[C_INCR_RE]], [0 : index] +!CHECK: %[[INCR_2:.*]] = fir.insert_value %[[INCR_1]], %[[C_INCR_IM]], [1 : index] +!CHECK: %[[RES:.+]] = fir.addc %[[LPRV]], %[[INCR_2]] {{.*}} : !fir.complex<8> +!CHECK: hlfir.assign %[[RES]] to %[[P_DECL]]#0 : !fir.complex<8>, !fir.ref<!fir.complex<8>> +!CHECK: omp.terminator +!CHECK: } +!CHECK: return +subroutine simple_complex_add + complex(8) :: c + c = 0 + + !$omp parallel reduction(+:c) + c = c + 1 + !$omp end parallel + + print *, c +end subroutine |