diff options
Diffstat (limited to 'flang/lib/Lower/OpenMP/OpenMP.cpp')
-rw-r--r-- | flang/lib/Lower/OpenMP/OpenMP.cpp | 2129 |
1 files changed, 1076 insertions, 1053 deletions
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 3dcfe0fd775d..9b9975223666 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -103,21 +103,6 @@ static fir::GlobalOp globalInitialization( return global; } -static mlir::Operation *getCompareFromReductionOp(mlir::Operation *reductionOp, - mlir::Value loadVal) { - for (mlir::Value reductionOperand : reductionOp->getOperands()) { - if (mlir::Operation *compareOp = reductionOperand.getDefiningOp()) { - if (compareOp->getOperand(0) == loadVal || - compareOp->getOperand(1) == loadVal) - assert((mlir::isa<mlir::arith::CmpIOp>(compareOp) || - mlir::isa<mlir::arith::CmpFOp>(compareOp)) && - "Expected comparison not found in reduction intrinsic"); - return compareOp; - } - } - return nullptr; -} - // Get the extended value for \p val by extracting additional variable // information from \p base. static fir::ExtendedValue getExtendedValue(fir::ExtendedValue base, @@ -237,213 +222,276 @@ createAndSetPrivatizedLoopVar(Fortran::lower::AbstractConverter &converter, return storeOp; } -static mlir::Operation * -findReductionChain(mlir::Value loadVal, mlir::Value *reductionVal = nullptr) { - for (mlir::OpOperand &loadOperand : loadVal.getUses()) { - if (mlir::Operation *reductionOp = loadOperand.getOwner()) { - if (auto convertOp = mlir::dyn_cast<fir::ConvertOp>(reductionOp)) { - for (mlir::OpOperand &convertOperand : convertOp.getRes().getUses()) { - if (mlir::Operation *reductionOp = convertOperand.getOwner()) - return reductionOp; - } - } - for (mlir::OpOperand &reductionOperand : reductionOp->getUses()) { - if (auto store = - mlir::dyn_cast<fir::StoreOp>(reductionOperand.getOwner())) { - if (store.getMemref() == *reductionVal) { - store.erase(); - return reductionOp; - } - } - if (auto assign = - mlir::dyn_cast<hlfir::AssignOp>(reductionOperand.getOwner())) { - if (assign.getLhs() == *reductionVal) { - assign.erase(); - return reductionOp; - } - } - } 
+// This helper function implements the functionality of "promoting" +// non-CPTR arguments of use_device_ptr to use_device_addr +// arguments (automagic conversion of use_device_ptr -> +// use_device_addr in these cases). The way we do so currently is +// through the shuffling of operands from the devicePtrOperands to +// deviceAddrOperands where neccesary and re-organizing the types, +// locations and symbols to maintain the correct ordering of ptr/addr +// input -> BlockArg. +// +// This effectively implements some deprecated OpenMP functionality +// that some legacy applications unfortunately depend on +// (deprecated in specification version 5.2): +// +// "If a list item in a use_device_ptr clause is not of type C_PTR, +// the behavior is as if the list item appeared in a use_device_addr +// clause. Support for such list items in a use_device_ptr clause +// is deprecated." +static void promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr( + mlir::omp::UseDeviceClauseOps &clauseOps, + llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes, + llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs, + llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> + &useDeviceSymbols) { + auto moveElementToBack = [](size_t idx, auto &vector) { + auto *iter = std::next(vector.begin(), idx); + vector.push_back(*iter); + vector.erase(iter); + }; + + // Iterate over our use_device_ptr list and shift all non-cptr arguments into + // use_device_addr. + for (auto *it = clauseOps.useDevicePtrVars.begin(); + it != clauseOps.useDevicePtrVars.end();) { + if (!fir::isa_builtin_cptr_type(fir::unwrapRefType(it->getType()))) { + clauseOps.useDeviceAddrVars.push_back(*it); + // We have to shuffle the symbols around as well, to maintain + // the correct Input -> BlockArg for use_device_ptr/use_device_addr. + // NOTE: However, as map's do not seem to be included currently + // this isn't as pertinent, but we must try to maintain for + // future alterations. 
I believe the reason they are not currently + // is that the BlockArg assign/lowering needs to be extended + // to a greater set of types. + auto idx = std::distance(clauseOps.useDevicePtrVars.begin(), it); + moveElementToBack(idx, useDeviceTypes); + moveElementToBack(idx, useDeviceLocs); + moveElementToBack(idx, useDeviceSymbols); + it = clauseOps.useDevicePtrVars.erase(it); + continue; } + ++it; } - return nullptr; } -// for a logical operator 'op' reduction X = X op Y -// This function returns the operation responsible for converting Y from -// fir.logical<4> to i1 -static fir::ConvertOp getConvertFromReductionOp(mlir::Operation *reductionOp, - mlir::Value loadVal) { - for (mlir::Value reductionOperand : reductionOp->getOperands()) { - if (auto convertOp = - mlir::dyn_cast<fir::ConvertOp>(reductionOperand.getDefiningOp())) { - if (convertOp.getOperand() == loadVal) - continue; - return convertOp; +/// Extract the list of function and variable symbols affected by the given +/// 'declare target' directive and return the intended device type for them. 
+static void getDeclareTargetInfo( + Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct, + mlir::omp::DeclareTargetClauseOps &clauseOps, + llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) { + const auto &spec = std::get<Fortran::parser::OmpDeclareTargetSpecifier>( + declareTargetConstruct.t); + if (const auto *objectList{ + Fortran::parser::Unwrap<Fortran::parser::OmpObjectList>(spec.u)}) { + ObjectList objects{makeObjects(*objectList, semaCtx)}; + // Case: declare target(func, var1, var2) + gatherFuncAndVarSyms(objects, mlir::omp::DeclareTargetCaptureClause::to, + symbolAndClause); + } else if (const auto *clauseList{ + Fortran::parser::Unwrap<Fortran::parser::OmpClauseList>( + spec.u)}) { + if (clauseList->v.empty()) { + // Case: declare target, implicit capture of function + symbolAndClause.emplace_back( + mlir::omp::DeclareTargetCaptureClause::to, + eval.getOwningProcedure()->getSubprogramSymbol()); } + + ClauseProcessor cp(converter, semaCtx, *clauseList); + cp.processDeviceType(clauseOps); + cp.processEnter(symbolAndClause); + cp.processLink(symbolAndClause); + cp.processTo(symbolAndClause); + + cp.processTODO<clause::Indirect>(converter.getCurrentLocation(), + llvm::omp::Directive::OMPD_declare_target); } - return nullptr; } -static void updateReduction(mlir::Operation *op, - fir::FirOpBuilder &firOpBuilder, - mlir::Value loadVal, mlir::Value reductionVal, - fir::ConvertOp *convertOp = nullptr) { - mlir::OpBuilder::InsertPoint insertPtDel = firOpBuilder.saveInsertionPoint(); - firOpBuilder.setInsertionPoint(op); +static void collectDeferredDeclareTargets( + Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct, + 
llvm::SmallVectorImpl<Fortran::lower::OMPDeferredDeclareTargetInfo> + &deferredDeclareTarget) { + mlir::omp::DeclareTargetClauseOps clauseOps; + llvm::SmallVector<DeclareTargetCapturePair> symbolAndClause; + getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct, + clauseOps, symbolAndClause); + // Return the device type only if at least one of the targets for the + // directive is a function or subroutine + mlir::ModuleOp mod = converter.getFirOpBuilder().getModule(); - mlir::Value reductionOp; - if (convertOp) - reductionOp = convertOp->getOperand(); - else if (op->getOperand(0) == loadVal) - reductionOp = op->getOperand(1); - else - reductionOp = op->getOperand(0); - - firOpBuilder.create<mlir::omp::ReductionOp>(op->getLoc(), reductionOp, - reductionVal); - firOpBuilder.restoreInsertionPoint(insertPtDel); -} - -static void removeStoreOp(mlir::Operation *reductionOp, mlir::Value symVal) { - for (mlir::Operation *reductionOpUse : reductionOp->getUsers()) { - if (auto convertReduction = - mlir::dyn_cast<fir::ConvertOp>(reductionOpUse)) { - for (mlir::Operation *convertReductionUse : - convertReduction.getRes().getUsers()) { - if (auto storeOp = mlir::dyn_cast<fir::StoreOp>(convertReductionUse)) { - if (storeOp.getMemref() == symVal) - storeOp.erase(); - } - if (auto assignOp = - mlir::dyn_cast<hlfir::AssignOp>(convertReductionUse)) { - if (assignOp.getLhs() == symVal) - assignOp.erase(); - } - } + for (const DeclareTargetCapturePair &symClause : symbolAndClause) { + mlir::Operation *op = mod.lookupSymbol(converter.mangleName( + std::get<const Fortran::semantics::Symbol &>(symClause))); + + if (!op) { + deferredDeclareTarget.push_back({std::get<0>(symClause), + clauseOps.deviceType, + std::get<1>(symClause)}); } } } -// Generate an OpenMP reduction operation. -// TODO: Currently assumes it is either an integer addition/multiplication -// reduction, or a logical and reduction. Generalize this for various reduction -// operation types. 
-// TODO: Generate the reduction operation during lowering instead of creating -// and removing operations since this is not a robust approach. Also, removing -// ops in the builder (instead of a rewriter) is probably not the best approach. -static void -genOpenMPReduction(Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - const Fortran::parser::OmpClauseList &clauseList) { +static std::optional<mlir::omp::DeclareTargetDeviceType> +getDeclareTargetFunctionDevice( + Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPDeclareTargetConstruct + &declareTargetConstruct) { + mlir::omp::DeclareTargetClauseOps clauseOps; + llvm::SmallVector<DeclareTargetCapturePair> symbolAndClause; + getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct, + clauseOps, symbolAndClause); + + // Return the device type only if at least one of the targets for the + // directive is a function or subroutine + mlir::ModuleOp mod = converter.getFirOpBuilder().getModule(); + for (const DeclareTargetCapturePair &symClause : symbolAndClause) { + mlir::Operation *op = mod.lookupSymbol(converter.mangleName( + std::get<const Fortran::semantics::Symbol &>(symClause))); + + if (mlir::isa_and_nonnull<mlir::func::FuncOp>(op)) + return clauseOps.deviceType; + } + + return std::nullopt; +} + +static llvm::SmallVector<const Fortran::semantics::Symbol *> +genLoopVars(mlir::Operation *op, Fortran::lower::AbstractConverter &converter, + mlir::Location &loc, + llvm::ArrayRef<const Fortran::semantics::Symbol *> args) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + auto &region = op->getRegion(0); - List<Clause> clauses{makeClauses(clauseList, semaCtx)}; - - for (const Clause &clause : clauses) { - if (const auto &reductionClause = - std::get_if<clause::Reduction>(&clause.u)) { - const auto &redOperatorList{
std::get<clause::Reduction::ReductionIdentifiers>( - reductionClause->t)}; - assert(redOperatorList.size() == 1 && "Expecting single operator"); - const auto &redOperator = redOperatorList.front(); - const auto &objects{std::get<ObjectList>(reductionClause->t)}; - if (const auto *reductionOp = - std::get_if<clause::DefinedOperator>(&redOperator.u)) { - const auto &intrinsicOp{ - std::get<clause::DefinedOperator::IntrinsicOperator>( - reductionOp->u)}; - - switch (intrinsicOp) { - case clause::DefinedOperator::IntrinsicOperator::Add: - case clause::DefinedOperator::IntrinsicOperator::Multiply: - case clause::DefinedOperator::IntrinsicOperator::AND: - case clause::DefinedOperator::IntrinsicOperator::EQV: - case clause::DefinedOperator::IntrinsicOperator::OR: - case clause::DefinedOperator::IntrinsicOperator::NEQV: - break; - default: - continue; - } - for (const Object &object : objects) { - if (const Fortran::semantics::Symbol *symbol = object.id()) { - mlir::Value reductionVal = converter.getSymbolAddress(*symbol); - if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>()) - reductionVal = declOp.getBase(); - mlir::Type reductionType = - reductionVal.getType().cast<fir::ReferenceType>().getEleTy(); - if (!reductionType.isa<fir::LogicalType>()) { - if (!reductionType.isIntOrIndexOrFloat()) - continue; - } - for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) { - if (auto loadOp = - mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) { - mlir::Value loadVal = loadOp.getRes(); - if (reductionType.isa<fir::LogicalType>()) { - mlir::Operation *reductionOp = findReductionChain(loadVal); - fir::ConvertOp convertOp = - getConvertFromReductionOp(reductionOp, loadVal); - updateReduction(reductionOp, firOpBuilder, loadVal, - reductionVal, &convertOp); - removeStoreOp(reductionOp, reductionVal); - } else if (mlir::Operation *reductionOp = - findReductionChain(loadVal, &reductionVal)) { - updateReduction(reductionOp, firOpBuilder, loadVal, - 
reductionVal); - } - } - } - } - } - } else if (const auto *reductionIntrinsic = - std::get_if<clause::ProcedureDesignator>(&redOperator.u)) { - if (!ReductionProcessor::supportedIntrinsicProcReduction( - *reductionIntrinsic)) - continue; - ReductionProcessor::ReductionIdentifier redId = - ReductionProcessor::getReductionType(*reductionIntrinsic); - for (const Object &object : objects) { - if (const Fortran::semantics::Symbol *symbol = object.id()) { - mlir::Value reductionVal = converter.getSymbolAddress(*symbol); - if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>()) - reductionVal = declOp.getBase(); - for (const mlir::OpOperand &reductionValUse : - reductionVal.getUses()) { - if (auto loadOp = - mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) { - mlir::Value loadVal = loadOp.getRes(); - // Max is lowered as a compare -> select. - // Match the pattern here. - mlir::Operation *reductionOp = - findReductionChain(loadVal, &reductionVal); - if (reductionOp == nullptr) - continue; - - if (redId == ReductionProcessor::ReductionIdentifier::MAX || - redId == ReductionProcessor::ReductionIdentifier::MIN) { - assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) && - "Selection Op not found in reduction intrinsic"); - mlir::Operation *compareOp = - getCompareFromReductionOp(reductionOp, loadVal); - updateReduction(compareOp, firOpBuilder, loadVal, - reductionVal); - } - if (redId == ReductionProcessor::ReductionIdentifier::IOR || - redId == ReductionProcessor::ReductionIdentifier::IEOR || - redId == ReductionProcessor::ReductionIdentifier::IAND) { - updateReduction(reductionOp, firOpBuilder, loadVal, - reductionVal); - } - } - } - } - } - } + std::size_t loopVarTypeSize = 0; + for (const Fortran::semantics::Symbol *arg : args) + loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size()); + mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize); + llvm::SmallVector<mlir::Type> tiv(args.size(), loopVarType); + 
llvm::SmallVector<mlir::Location> locs(args.size(), loc); + firOpBuilder.createBlock(&region, {}, tiv, locs); + // The argument is not currently in memory, so make a temporary for the + // argument, and store it there, then bind that location to the argument. + mlir::Operation *storeOp = nullptr; + for (auto [argIndex, argSymbol] : llvm::enumerate(args)) { + mlir::Value indexVal = fir::getBase(region.front().getArgument(argIndex)); + storeOp = + createAndSetPrivatizedLoopVar(converter, loc, indexVal, argSymbol); + } + firOpBuilder.setInsertionPointAfter(storeOp); + return llvm::SmallVector<const Fortran::semantics::Symbol *>(args); +} + +static void genReductionVars( + mlir::Operation *op, Fortran::lower::AbstractConverter &converter, + mlir::Location &loc, + llvm::ArrayRef<const Fortran::semantics::Symbol *> reductionArgs, + llvm::ArrayRef<mlir::Type> reductionTypes) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + llvm::SmallVector<mlir::Location> blockArgLocs(reductionArgs.size(), loc); + + mlir::Block *entryBlock = firOpBuilder.createBlock( + &op->getRegion(0), {}, reductionTypes, blockArgLocs); + + // Bind the reduction arguments to their block arguments.
+ for (auto [arg, prv] : + llvm::zip_equal(reductionArgs, entryBlock->getArguments())) { + converter.bindSymbol(*arg, prv); + } +} + +static llvm::SmallVector<const Fortran::semantics::Symbol *> +genLoopAndReductionVars( + mlir::Operation *op, Fortran::lower::AbstractConverter &converter, + mlir::Location &loc, + llvm::ArrayRef<const Fortran::semantics::Symbol *> loopArgs, + llvm::ArrayRef<const Fortran::semantics::Symbol *> reductionArgs, + llvm::ArrayRef<mlir::Type> reductionTypes) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + + llvm::SmallVector<mlir::Type> blockArgTypes; + llvm::SmallVector<mlir::Location> blockArgLocs; + blockArgTypes.reserve(loopArgs.size() + reductionArgs.size()); + blockArgLocs.reserve(blockArgTypes.size()); + mlir::Block *entryBlock; + + if (loopArgs.size()) { + std::size_t loopVarTypeSize = 0; + for (const Fortran::semantics::Symbol *arg : loopArgs) + loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size()); + mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize); + std::fill_n(std::back_inserter(blockArgTypes), loopArgs.size(), + loopVarType); + std::fill_n(std::back_inserter(blockArgLocs), loopArgs.size(), loc); + } + if (reductionArgs.size()) { + llvm::copy(reductionTypes, std::back_inserter(blockArgTypes)); + std::fill_n(std::back_inserter(blockArgLocs), reductionArgs.size(), loc); + } + entryBlock = firOpBuilder.createBlock(&op->getRegion(0), {}, blockArgTypes, + blockArgLocs); + // The argument is not currently in memory, so make a temporary for the + // argument, and store it there, then bind that location to the argument. 
+ if (loopArgs.size()) { + mlir::Operation *storeOp = nullptr; + for (auto [argIndex, argSymbol] : llvm::enumerate(loopArgs)) { + mlir::Value indexVal = + fir::getBase(op->getRegion(0).front().getArgument(argIndex)); + storeOp = + createAndSetPrivatizedLoopVar(converter, loc, indexVal, argSymbol); } + firOpBuilder.setInsertionPointAfter(storeOp); + } + // Bind the reduction arguments to their block arguments + for (auto [arg, prv] : llvm::zip_equal( + reductionArgs, + llvm::drop_begin(entryBlock->getArguments(), loopArgs.size()))) { + converter.bindSymbol(*arg, prv); + } + + return llvm::SmallVector<const Fortran::semantics::Symbol *>(loopArgs); +} + +static void +markDeclareTarget(mlir::Operation *op, + Fortran::lower::AbstractConverter &converter, + mlir::omp::DeclareTargetCaptureClause captureClause, + mlir::omp::DeclareTargetDeviceType deviceType) { + // TODO: Add support for program local variables with declare target applied + auto declareTargetOp = llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(op); + if (!declareTargetOp) + fir::emitFatalError( + converter.getCurrentLocation(), + "Attempt to apply declare target on unsupported operation"); + + // The function or global already has a declare target applied to it, very + // likely through implicit capture (usage in another declare target + // function/subroutine). 
It should be marked as any if it has been assigned + // both host and nohost, else we skip, as there is no change + if (declareTargetOp.isDeclareTarget()) { + if (declareTargetOp.getDeclareTargetDeviceType() != deviceType) + declareTargetOp.setDeclareTarget(mlir::omp::DeclareTargetDeviceType::any, + captureClause); + return; } + + declareTargetOp.setDeclareTarget(deviceType, captureClause); } +//===----------------------------------------------------------------------===// +// Op body generation helper structures and functions +//===----------------------------------------------------------------------===// + struct OpWithBodyGenInfo { /// A type for a code-gen callback function. This takes as argument the op for /// which the code is being generated and returns the arguments of the op's @@ -715,362 +763,6 @@ static void genBodyOfTargetDataOp( genNestedEvaluations(converter, eval); } -template <typename OpTy, typename... Args> -static OpTy genOpWithBody(OpWithBodyGenInfo &info, Args &&...args) { - auto op = info.converter.getFirOpBuilder().create<OpTy>( - info.loc, std::forward<Args>(args)...); - createBodyOfOp<OpTy>(op, info); - return op; -} - -static mlir::omp::MasterOp -genMasterOp(Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location currentLocation) { - return genOpWithBody<mlir::omp::MasterOp>( - OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) - .setGenNested(genNested)); -} - -static mlir::omp::OrderedRegionOp -genOrderedRegionOp(Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location currentLocation, - const Fortran::parser::OmpClauseList &clauseList) { - mlir::omp::OrderedRegionClauseOps clauseOps; - - ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processTODO<clause::Simd>(currentLocation, - 
llvm::omp::Directive::OMPD_ordered); - - return genOpWithBody<mlir::omp::OrderedRegionOp>( - OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) - .setGenNested(genNested), - clauseOps); -} - -static mlir::omp::ParallelOp -genParallelOp(Fortran::lower::AbstractConverter &converter, - Fortran::lower::SymMap &symTable, - Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location currentLocation, - const Fortran::parser::OmpClauseList &clauseList, - bool outerCombined = false) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - Fortran::lower::StatementContext stmtCtx; - mlir::omp::ParallelClauseOps clauseOps; - llvm::SmallVector<const Fortran::semantics::Symbol *> privateSyms; - llvm::SmallVector<mlir::Type> reductionTypes; - llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSyms; - - ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(llvm::omp::Directive::OMPD_parallel, clauseOps); - cp.processNumThreads(stmtCtx, clauseOps); - cp.processProcBind(clauseOps); - cp.processDefault(); - cp.processAllocate(clauseOps); - - if (!outerCombined) - cp.processReduction(currentLocation, clauseOps, &reductionTypes, - &reductionSyms); - - if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars)) - clauseOps.reductionByRefAttr = firOpBuilder.getUnitAttr(); - - auto reductionCallback = [&](mlir::Operation *op) { - llvm::SmallVector<mlir::Location> locs(clauseOps.reductionVars.size(), - currentLocation); - auto *block = - firOpBuilder.createBlock(&op->getRegion(0), {}, reductionTypes, locs); - for (auto [arg, prv] : - llvm::zip_equal(reductionSyms, block->getArguments())) { - converter.bindSymbol(*arg, prv); - } - return reductionSyms; - }; - - OpWithBodyGenInfo genInfo = - OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) - .setGenNested(genNested) - .setOuterCombined(outerCombined) - .setClauses(&clauseList) - .setReductions(&reductionSyms, 
&reductionTypes) - .setGenRegionEntryCb(reductionCallback); - - if (!enableDelayedPrivatization) - return genOpWithBody<mlir::omp::ParallelOp>(genInfo, clauseOps); - - bool privatize = !outerCombined; - DataSharingProcessor dsp(converter, semaCtx, clauseList, eval, - /*useDelayedPrivatization=*/true, &symTable); - - if (privatize) - dsp.processStep1(&clauseOps, &privateSyms); - - auto genRegionEntryCB = [&](mlir::Operation *op) { - auto parallelOp = llvm::cast<mlir::omp::ParallelOp>(op); - - llvm::SmallVector<mlir::Location> reductionLocs( - clauseOps.reductionVars.size(), currentLocation); - - mlir::OperandRange privateVars = parallelOp.getPrivateVars(); - mlir::Region &region = parallelOp.getRegion(); - - llvm::SmallVector<mlir::Type> privateVarTypes = reductionTypes; - privateVarTypes.reserve(privateVarTypes.size() + privateVars.size()); - llvm::transform(privateVars, std::back_inserter(privateVarTypes), - [](mlir::Value v) { return v.getType(); }); - - llvm::SmallVector<mlir::Location> privateVarLocs = reductionLocs; - privateVarLocs.reserve(privateVarLocs.size() + privateVars.size()); - llvm::transform(privateVars, std::back_inserter(privateVarLocs), - [](mlir::Value v) { return v.getLoc(); }); - - firOpBuilder.createBlock(&region, /*insertPt=*/{}, privateVarTypes, - privateVarLocs); - - llvm::SmallVector<const Fortran::semantics::Symbol *> allSymbols = - reductionSyms; - allSymbols.append(privateSyms); - for (auto [arg, prv] : llvm::zip_equal(allSymbols, region.getArguments())) { - converter.bindSymbol(*arg, prv); - } - - return allSymbols; - }; - - // TODO Merge with the reduction CB.
- genInfo.setGenRegionEntryCb(genRegionEntryCB).setDataSharingProcessor(&dsp); - return genOpWithBody<mlir::omp::ParallelOp>(genInfo, clauseOps); -} - -static mlir::omp::SectionOp -genSectionOp(Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location currentLocation, - const Fortran::parser::OmpClauseList &sectionsClauseList) { - // Currently only private/firstprivate clause is handled, and - // all privatization is done within `omp.section` operations. - return genOpWithBody<mlir::omp::SectionOp>( - OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) - .setGenNested(genNested) - .setClauses(&sectionsClauseList)); -} - -static mlir::omp::SingleOp -genSingleOp(Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location currentLocation, - const Fortran::parser::OmpClauseList &beginClauseList, - const Fortran::parser::OmpClauseList &endClauseList) { - mlir::omp::SingleClauseOps clauseOps; - - ClauseProcessor cp(converter, semaCtx, beginClauseList); - cp.processAllocate(clauseOps); - // TODO Support delayed privatization.
- - ClauseProcessor ecp(converter, semaCtx, endClauseList); - ecp.processNowait(clauseOps); - ecp.processCopyprivate(currentLocation, clauseOps); - - return genOpWithBody<mlir::omp::SingleOp>( - OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) - .setGenNested(genNested) - .setClauses(&beginClauseList), - clauseOps); -} - -static mlir::omp::TaskOp -genTaskOp(Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location currentLocation, - const Fortran::parser::OmpClauseList &clauseList) { - Fortran::lower::StatementContext stmtCtx; - mlir::omp::TaskClauseOps clauseOps; - - ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(llvm::omp::Directive::OMPD_task, clauseOps); - cp.processAllocate(clauseOps); - cp.processDefault(); - cp.processFinal(stmtCtx, clauseOps); - cp.processUntied(clauseOps); - cp.processMergeable(clauseOps); - cp.processPriority(stmtCtx, clauseOps); - cp.processDepend(clauseOps); - // TODO Support delayed privatization. 
- - cp.processTODO<clause::InReduction, clause::Detach, clause::Affinity>( - currentLocation, llvm::omp::Directive::OMPD_task); - - return genOpWithBody<mlir::omp::TaskOp>( - OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) - .setGenNested(genNested) - .setClauses(&clauseList), - clauseOps); -} - -static mlir::omp::TaskgroupOp -genTaskgroupOp(Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location currentLocation, - const Fortran::parser::OmpClauseList &clauseList) { - mlir::omp::TaskgroupClauseOps clauseOps; - - ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processAllocate(clauseOps); - cp.processTODO<clause::TaskReduction>(currentLocation, - llvm::omp::Directive::OMPD_taskgroup); - - return genOpWithBody<mlir::omp::TaskgroupOp>( - OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) - .setGenNested(genNested) - .setClauses(&clauseList), - clauseOps); -} - -// This helper function implements the functionality of "promoting" -// non-CPTR arguments of use_device_ptr to use_device_addr -// arguments (automagic conversion of use_device_ptr -> -// use_device_addr in these cases). The way we do so currently is -// through the shuffling of operands from the devicePtrOperands to -// deviceAddrOperands where neccesary and re-organizing the types, -// locations and symbols to maintain the correct ordering of ptr/addr -// input -> BlockArg. -// -// This effectively implements some deprecated OpenMP functionality -// that some legacy applications unfortunately depend on -// (deprecated in specification version 5.2): -// -// "If a list item in a use_device_ptr clause is not of type C_PTR, -// the behavior is as if the list item appeared in a use_device_addr -// clause. Support for such list items in a use_device_ptr clause -// is deprecated." 
-static void promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr( - mlir::omp::UseDeviceClauseOps &clauseOps, - llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes, - llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs, - llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> - &useDeviceSymbols) { - auto moveElementToBack = [](size_t idx, auto &vector) { - auto *iter = std::next(vector.begin(), idx); - vector.push_back(*iter); - vector.erase(iter); - }; - - // Iterate over our use_device_ptr list and shift all non-cptr arguments into - // use_device_addr. - for (auto *it = clauseOps.useDevicePtrVars.begin(); - it != clauseOps.useDevicePtrVars.end();) { - if (!fir::isa_builtin_cptr_type(fir::unwrapRefType(it->getType()))) { - clauseOps.useDeviceAddrVars.push_back(*it); - // We have to shuffle the symbols around as well, to maintain - // the correct Input -> BlockArg for use_device_ptr/use_device_addr. - // NOTE: However, as map's do not seem to be included currently - // this isn't as pertinent, but we must try to maintain for - // future alterations. I believe the reason they are not currently - // is that the BlockArg assign/lowering needs to be extended - // to a greater set of types. 
- auto idx = std::distance(clauseOps.useDevicePtrVars.begin(), it); - moveElementToBack(idx, useDeviceTypes); - moveElementToBack(idx, useDeviceLocs); - moveElementToBack(idx, useDeviceSymbols); - it = clauseOps.useDevicePtrVars.erase(it); - continue; - } - ++it; - } -} - -static mlir::omp::TargetDataOp -genTargetDataOp(Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location currentLocation, - const Fortran::parser::OmpClauseList &clauseList) { - Fortran::lower::StatementContext stmtCtx; - mlir::omp::TargetDataClauseOps clauseOps; - llvm::SmallVector<mlir::Type> useDeviceTypes; - llvm::SmallVector<mlir::Location> useDeviceLocs; - llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSyms; - - ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(llvm::omp::Directive::OMPD_target_data, clauseOps); - cp.processDevice(stmtCtx, clauseOps); - cp.processUseDevicePtr(clauseOps, useDeviceTypes, useDeviceLocs, - useDeviceSyms); - cp.processUseDeviceAddr(clauseOps, useDeviceTypes, useDeviceLocs, - useDeviceSyms); - - // This function implements the deprecated functionality of use_device_ptr - // that allows users to provide non-CPTR arguments to it with the caveat - // that the compiler will treat them as use_device_addr. A lot of legacy - // code may still depend on this functionality, so we should support it - // in some manner. We do so currently by simply shifting non-cptr operands - // from the use_device_ptr list into the front of the use_device_addr list - // whilst maintaining the ordering of useDeviceLocs, useDeviceSymbols and - // useDeviceTypes to use_device_ptr/use_device_addr input for BlockArg - // ordering. - // TODO: Perhaps create a user provideable compiler option that will - // re-introduce a hard-error rather than a warning in these cases. 
- promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr(clauseOps, useDeviceTypes, - useDeviceLocs, useDeviceSyms); - cp.processMap(currentLocation, llvm::omp::Directive::OMPD_target_data, - stmtCtx, clauseOps); - - auto dataOp = converter.getFirOpBuilder().create<mlir::omp::TargetDataOp>( - currentLocation, clauseOps); - - genBodyOfTargetDataOp(converter, semaCtx, eval, genNested, dataOp, - useDeviceTypes, useDeviceLocs, useDeviceSyms, - currentLocation); - return dataOp; -} - -template <typename OpTy> -static OpTy genTargetEnterExitDataUpdateOp( - Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - mlir::Location currentLocation, - const Fortran::parser::OmpClauseList &clauseList) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - Fortran::lower::StatementContext stmtCtx; - mlir::omp::TargetEnterExitUpdateDataClauseOps clauseOps; - - // GCC 9.3.0 emits a (probably) bogus warning about an unused variable. - [[maybe_unused]] llvm::omp::Directive directive; - if constexpr (std::is_same_v<OpTy, mlir::omp::TargetEnterDataOp>) { - directive = llvm::omp::Directive::OMPD_target_enter_data; - } else if constexpr (std::is_same_v<OpTy, mlir::omp::TargetExitDataOp>) { - directive = llvm::omp::Directive::OMPD_target_exit_data; - } else if constexpr (std::is_same_v<OpTy, mlir::omp::TargetUpdateOp>) { - directive = llvm::omp::Directive::OMPD_target_update; - } else { - return nullptr; - } - - ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(directive, clauseOps); - cp.processDevice(stmtCtx, clauseOps); - cp.processDepend(clauseOps); - cp.processNowait(clauseOps); - - if constexpr (std::is_same_v<OpTy, mlir::omp::TargetUpdateOp>) { - cp.processMotionClauses<clause::To>(stmtCtx, clauseOps); - cp.processMotionClauses<clause::From>(stmtCtx, clauseOps); - } else { - cp.processMap(currentLocation, directive, stmtCtx, clauseOps); - } - - return firOpBuilder.create<OpTy>(currentLocation, clauseOps); -} - // This 
functions creates a block for the body of the targetOp's region. It adds // all the symbols present in mapSymbols as block arguments to this block. static void @@ -1225,38 +917,583 @@ genBodyOfTargetOp(Fortran::lower::AbstractConverter &converter, genNestedEvaluations(converter, eval); } +template <typename OpTy, typename... Args> +static OpTy genOpWithBody(OpWithBodyGenInfo &info, Args &&...args) { + auto op = info.converter.getFirOpBuilder().create<OpTy>( + info.loc, std::forward<Args>(args)...); + createBodyOfOp<OpTy>(op, info); + return op; +} + +//===----------------------------------------------------------------------===// +// Code generation functions for clauses +//===----------------------------------------------------------------------===// + +static void genCriticalDeclareClauses( + Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + const Fortran::parser::OmpClauseList &clauses, mlir::Location loc, + mlir::omp::CriticalClauseOps &clauseOps, llvm::StringRef name) { + ClauseProcessor cp(converter, semaCtx, clauses); + cp.processHint(clauseOps); + clauseOps.nameAttr = + mlir::StringAttr::get(converter.getFirOpBuilder().getContext(), name); +} + +static void genFlushClauses( + Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + const std::optional<Fortran::parser::OmpObjectList> &objects, + const std::optional<std::list<Fortran::parser::OmpMemoryOrderClause>> + &clauses, + mlir::Location loc, llvm::SmallVectorImpl<mlir::Value> &operandRange) { + if (objects) + genObjectList2(*objects, converter, operandRange); + + if (clauses && clauses->size() > 0) + TODO(converter.getCurrentLocation(), "Handle OmpMemoryOrderClause"); +} + +static void +genOrderedRegionClauses(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + const Fortran::parser::OmpClauseList &clauses, + mlir::Location loc, + mlir::omp::OrderedRegionClauseOps 
&clauseOps) { + ClauseProcessor cp(converter, semaCtx, clauses); + cp.processTODO<clause::Simd>(loc, llvm::omp::Directive::OMPD_ordered); +} + +static void genParallelClauses( + Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::StatementContext &stmtCtx, + const Fortran::parser::OmpClauseList &clauses, mlir::Location loc, + bool processReduction, mlir::omp::ParallelClauseOps &clauseOps, + llvm::SmallVectorImpl<mlir::Type> &reductionTypes, + llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &reductionSyms) { + ClauseProcessor cp(converter, semaCtx, clauses); + cp.processAllocate(clauseOps); + cp.processDefault(); + cp.processIf(llvm::omp::Directive::OMPD_parallel, clauseOps); + cp.processNumThreads(stmtCtx, clauseOps); + cp.processProcBind(clauseOps); + + if (processReduction) { + cp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms); + if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars)) + clauseOps.reductionByRefAttr = converter.getFirOpBuilder().getUnitAttr(); + } +} + +static void genSectionsClauses(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + const Fortran::parser::OmpClauseList &clauses, + mlir::Location loc, + bool clausesFromBeginSections, + mlir::omp::SectionsClauseOps &clauseOps) { + ClauseProcessor cp(converter, semaCtx, clauses); + if (clausesFromBeginSections) { + cp.processAllocate(clauseOps); + cp.processSectionsReduction(loc, clauseOps); + // TODO Support delayed privatization. 
+ } else { + cp.processNowait(clauseOps); + } +} + +static void genSimdLoopClauses( + Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::StatementContext &stmtCtx, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OmpClauseList &clauses, mlir::Location loc, + mlir::omp::SimdLoopClauseOps &clauseOps, + llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv) { + ClauseProcessor cp(converter, semaCtx, clauses); + cp.processCollapse(loc, eval, clauseOps, iv); + cp.processIf(llvm::omp::Directive::OMPD_simd, clauseOps); + cp.processReduction(loc, clauseOps); + cp.processSafelen(clauseOps); + cp.processSimdlen(clauseOps); + clauseOps.loopInclusiveAttr = converter.getFirOpBuilder().getUnitAttr(); + // TODO Support delayed privatization. + + cp.processTODO<clause::Aligned, clause::Allocate, clause::Linear, + clause::Nontemporal, clause::Order>( + loc, llvm::omp::Directive::OMPD_simd); +} + +static void genSingleClauses(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + const Fortran::parser::OmpClauseList &beginClauses, + const Fortran::parser::OmpClauseList &endClauses, + mlir::Location loc, + mlir::omp::SingleClauseOps &clauseOps) { + ClauseProcessor bcp(converter, semaCtx, beginClauses); + bcp.processAllocate(clauseOps); + // TODO Support delayed privatization. 
+ + ClauseProcessor ecp(converter, semaCtx, endClauses); + ecp.processCopyprivate(loc, clauseOps); + ecp.processNowait(clauseOps); +} + +static void genTargetClauses( + Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::StatementContext &stmtCtx, + const Fortran::parser::OmpClauseList &clauses, mlir::Location loc, + bool processHostOnlyClauses, bool processReduction, + mlir::omp::TargetClauseOps &clauseOps, + llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &mapSyms, + llvm::SmallVectorImpl<mlir::Location> &mapLocs, + llvm::SmallVectorImpl<mlir::Type> &mapTypes, + llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &deviceAddrSyms, + llvm::SmallVectorImpl<mlir::Location> &deviceAddrLocs, + llvm::SmallVectorImpl<mlir::Type> &deviceAddrTypes, + llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &devicePtrSyms, + llvm::SmallVectorImpl<mlir::Location> &devicePtrLocs, + llvm::SmallVectorImpl<mlir::Type> &devicePtrTypes) { + ClauseProcessor cp(converter, semaCtx, clauses); + cp.processDepend(clauseOps); + cp.processDevice(stmtCtx, clauseOps); + cp.processHasDeviceAddr(clauseOps, deviceAddrTypes, deviceAddrLocs, + deviceAddrSyms); + cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps); + cp.processIsDevicePtr(clauseOps, devicePtrTypes, devicePtrLocs, + devicePtrSyms); + cp.processMap(loc, stmtCtx, clauseOps, &mapSyms, &mapLocs, &mapTypes); + cp.processThreadLimit(stmtCtx, clauseOps); + // TODO Support delayed privatization. 
+ + if (processHostOnlyClauses) + cp.processNowait(clauseOps); + + cp.processTODO<clause::Allocate, clause::Defaultmap, clause::Firstprivate, + clause::InReduction, clause::Private, clause::Reduction, + clause::UsesAllocators>(loc, + llvm::omp::Directive::OMPD_target); +} + +static void genTargetDataClauses( + Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::StatementContext &stmtCtx, + const Fortran::parser::OmpClauseList &clauses, mlir::Location loc, + mlir::omp::TargetDataClauseOps &clauseOps, + llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes, + llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs, + llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSyms) { + ClauseProcessor cp(converter, semaCtx, clauses); + cp.processDevice(stmtCtx, clauseOps); + cp.processIf(llvm::omp::Directive::OMPD_target_data, clauseOps); + cp.processMap(loc, stmtCtx, clauseOps); + cp.processUseDeviceAddr(clauseOps, useDeviceTypes, useDeviceLocs, + useDeviceSyms); + cp.processUseDevicePtr(clauseOps, useDeviceTypes, useDeviceLocs, + useDeviceSyms); + + // This function implements the deprecated functionality of use_device_ptr + // that allows users to provide non-CPTR arguments to it with the caveat + // that the compiler will treat them as use_device_addr. A lot of legacy + // code may still depend on this functionality, so we should support it + // in some manner. We do so currently by simply shifting non-cptr operands + // from the use_device_ptr list into the front of the use_device_addr list + // whilst maintaining the ordering of useDeviceLocs, useDeviceSyms and + // useDeviceTypes to use_device_ptr/use_device_addr input for BlockArg + // ordering. + // TODO: Perhaps create a user provideable compiler option that will + // re-introduce a hard-error rather than a warning in these cases. 
+ promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr(clauseOps, useDeviceTypes, + useDeviceLocs, useDeviceSyms); +} + +static void genTargetEnterExitUpdateDataClauses( + Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::StatementContext &stmtCtx, + const Fortran::parser::OmpClauseList &clauses, mlir::Location loc, + llvm::omp::Directive directive, + mlir::omp::TargetEnterExitUpdateDataClauseOps &clauseOps) { + ClauseProcessor cp(converter, semaCtx, clauses); + cp.processDepend(clauseOps); + cp.processDevice(stmtCtx, clauseOps); + cp.processIf(directive, clauseOps); + cp.processNowait(clauseOps); + + if (directive == llvm::omp::Directive::OMPD_target_update) { + cp.processMotionClauses<clause::To>(stmtCtx, clauseOps); + cp.processMotionClauses<clause::From>(stmtCtx, clauseOps); + } else { + cp.processMap(loc, stmtCtx, clauseOps); + } +} + +static void genTaskClauses(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::StatementContext &stmtCtx, + const Fortran::parser::OmpClauseList &clauses, + mlir::Location loc, + mlir::omp::TaskClauseOps &clauseOps) { + ClauseProcessor cp(converter, semaCtx, clauses); + cp.processAllocate(clauseOps); + cp.processDefault(); + cp.processDepend(clauseOps); + cp.processFinal(stmtCtx, clauseOps); + cp.processIf(llvm::omp::Directive::OMPD_task, clauseOps); + cp.processMergeable(clauseOps); + cp.processPriority(stmtCtx, clauseOps); + cp.processUntied(clauseOps); + // TODO Support delayed privatization. 
+ + cp.processTODO<clause::Affinity, clause::Detach, clause::InReduction>( + loc, llvm::omp::Directive::OMPD_task); +} + +static void genTaskgroupClauses(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + const Fortran::parser::OmpClauseList &clauses, + mlir::Location loc, + mlir::omp::TaskgroupClauseOps &clauseOps) { + ClauseProcessor cp(converter, semaCtx, clauses); + cp.processAllocate(clauseOps); + + cp.processTODO<clause::TaskReduction>(loc, + llvm::omp::Directive::OMPD_taskgroup); +} + +static void genTaskwaitClauses(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + const Fortran::parser::OmpClauseList &clauses, + mlir::Location loc, + mlir::omp::TaskwaitClauseOps &clauseOps) { + ClauseProcessor cp(converter, semaCtx, clauses); + cp.processTODO<clause::Depend, clause::Nowait>( + loc, llvm::omp::Directive::OMPD_taskwait); +} + +static void genTeamsClauses(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::StatementContext &stmtCtx, + const Fortran::parser::OmpClauseList &clauses, + mlir::Location loc, + mlir::omp::TeamsClauseOps &clauseOps) { + ClauseProcessor cp(converter, semaCtx, clauses); + cp.processAllocate(clauseOps); + cp.processDefault(); + cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps); + cp.processNumTeams(stmtCtx, clauseOps); + cp.processThreadLimit(stmtCtx, clauseOps); + // TODO Support delayed privatization. 
+ + cp.processTODO<clause::Reduction>(loc, llvm::omp::Directive::OMPD_teams); +} + +static void genWsloopClauses( + Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::StatementContext &stmtCtx, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OmpClauseList &beginClauses, + const Fortran::parser::OmpClauseList *endClauses, mlir::Location loc, + mlir::omp::WsloopClauseOps &clauseOps, + llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv, + llvm::SmallVectorImpl<mlir::Type> &reductionTypes, + llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &reductionSyms) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + ClauseProcessor bcp(converter, semaCtx, beginClauses); + bcp.processCollapse(loc, eval, clauseOps, iv); + bcp.processOrdered(clauseOps); + bcp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms); + bcp.processSchedule(stmtCtx, clauseOps); + clauseOps.loopInclusiveAttr = firOpBuilder.getUnitAttr(); + // TODO Support delayed privatization. 
+ + if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars)) + clauseOps.reductionByRefAttr = firOpBuilder.getUnitAttr(); + + if (endClauses) { + ClauseProcessor ecp(converter, semaCtx, *endClauses); + ecp.processNowait(clauseOps); + } + + bcp.processTODO<clause::Allocate, clause::Linear, clause::Order>( + loc, llvm::omp::Directive::OMPD_do); +} + +//===----------------------------------------------------------------------===// +// Code generation functions for leaf constructs +//===----------------------------------------------------------------------===// + +static mlir::omp::BarrierOp +genBarrierOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, mlir::Location loc) { + return converter.getFirOpBuilder().create<mlir::omp::BarrierOp>(loc); +} + +static mlir::omp::CriticalOp +genCriticalOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, bool genNested, + mlir::Location loc, + const Fortran::parser::OmpClauseList &clauseList, + const std::optional<Fortran::parser::Name> &name) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + mlir::FlatSymbolRefAttr nameAttr; + + if (name) { + std::string nameStr = name->ToString(); + mlir::ModuleOp mod = firOpBuilder.getModule(); + auto global = mod.lookupSymbol<mlir::omp::CriticalDeclareOp>(nameStr); + if (!global) { + mlir::omp::CriticalClauseOps clauseOps; + genCriticalDeclareClauses(converter, semaCtx, clauseList, loc, clauseOps, + nameStr); + + mlir::OpBuilder modBuilder(mod.getBodyRegion()); + global = modBuilder.create<mlir::omp::CriticalDeclareOp>(loc, clauseOps); + } + nameAttr = mlir::FlatSymbolRefAttr::get(firOpBuilder.getContext(), + global.getSymName()); + } + + return genOpWithBody<mlir::omp::CriticalOp>( + OpWithBodyGenInfo(converter, semaCtx, loc, eval).setGenNested(genNested), + nameAttr); +} + +static 
mlir::omp::DistributeOp +genDistributeOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, bool genNested, + mlir::Location loc, + const Fortran::parser::OmpClauseList &clauseList) { + TODO(loc, "Distribute construct"); + return nullptr; +} + +static mlir::omp::FlushOp +genFlushOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, mlir::Location loc, + const std::optional<Fortran::parser::OmpObjectList> &objectList, + const std::optional<std::list<Fortran::parser::OmpMemoryOrderClause>> + &clauseList) { + llvm::SmallVector<mlir::Value> operandRange; + genFlushClauses(converter, semaCtx, objectList, clauseList, loc, + operandRange); + + return converter.getFirOpBuilder().create<mlir::omp::FlushOp>( + converter.getCurrentLocation(), operandRange); +} + +static mlir::omp::MasterOp +genMasterOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, bool genNested, + mlir::Location loc) { + return genOpWithBody<mlir::omp::MasterOp>( + OpWithBodyGenInfo(converter, semaCtx, loc, eval).setGenNested(genNested)); +} + +static mlir::omp::OrderedOp +genOrderedOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, mlir::Location loc, + const Fortran::parser::OmpClauseList &clauseList) { + TODO(loc, "OMPD_ordered"); + return nullptr; +} + +static mlir::omp::OrderedRegionOp +genOrderedRegionOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, bool genNested, + mlir::Location loc, + const Fortran::parser::OmpClauseList &clauseList) { + mlir::omp::OrderedRegionClauseOps clauseOps; + genOrderedRegionClauses(converter, semaCtx, clauseList, loc, clauseOps); + + return 
genOpWithBody<mlir::omp::OrderedRegionOp>( + OpWithBodyGenInfo(converter, semaCtx, loc, eval).setGenNested(genNested), + clauseOps); +} + +static mlir::omp::ParallelOp +genParallelOp(Fortran::lower::AbstractConverter &converter, + Fortran::lower::SymMap &symTable, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, bool genNested, + mlir::Location loc, + const Fortran::parser::OmpClauseList &clauseList, + bool outerCombined = false) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + Fortran::lower::StatementContext stmtCtx; + mlir::omp::ParallelClauseOps clauseOps; + llvm::SmallVector<const Fortran::semantics::Symbol *> privateSyms; + llvm::SmallVector<mlir::Type> reductionTypes; + llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSyms; + genParallelClauses(converter, semaCtx, stmtCtx, clauseList, loc, + /*processReduction=*/!outerCombined, clauseOps, + reductionTypes, reductionSyms); + + auto reductionCallback = [&](mlir::Operation *op) { + genReductionVars(op, converter, loc, reductionSyms, reductionTypes); + return reductionSyms; + }; + + OpWithBodyGenInfo genInfo = + OpWithBodyGenInfo(converter, semaCtx, loc, eval) + .setGenNested(genNested) + .setOuterCombined(outerCombined) + .setClauses(&clauseList) + .setReductions(&reductionSyms, &reductionTypes) + .setGenRegionEntryCb(reductionCallback); + + if (!enableDelayedPrivatization) + return genOpWithBody<mlir::omp::ParallelOp>(genInfo, clauseOps); + + bool privatize = !outerCombined; + DataSharingProcessor dsp(converter, semaCtx, clauseList, eval, + /*useDelayedPrivatization=*/true, &symTable); + + if (privatize) + dsp.processStep1(&clauseOps, &privateSyms); + + auto genRegionEntryCB = [&](mlir::Operation *op) { + auto parallelOp = llvm::cast<mlir::omp::ParallelOp>(op); + + llvm::SmallVector<mlir::Location> reductionLocs( + clauseOps.reductionVars.size(), loc); + + mlir::OperandRange privateVars = parallelOp.getPrivateVars(); + mlir::Region &region 
= parallelOp.getRegion(); + + llvm::SmallVector<mlir::Type> privateVarTypes = reductionTypes; + privateVarTypes.reserve(privateVarTypes.size() + privateVars.size()); + llvm::transform(privateVars, std::back_inserter(privateVarTypes), + [](mlir::Value v) { return v.getType(); }); + + llvm::SmallVector<mlir::Location> privateVarLocs = reductionLocs; + privateVarLocs.reserve(privateVarLocs.size() + privateVars.size()); + llvm::transform(privateVars, std::back_inserter(privateVarLocs), + [](mlir::Value v) { return v.getLoc(); }); + + firOpBuilder.createBlock(&region, /*insertPt=*/{}, privateVarTypes, + privateVarLocs); + + llvm::SmallVector<const Fortran::semantics::Symbol *> allSymbols = + reductionSyms; + allSymbols.append(privateSyms); + for (auto [arg, prv] : llvm::zip_equal(allSymbols, region.getArguments())) { + converter.bindSymbol(*arg, prv); + } + + return allSymbols; + }; + + // TODO Merge with the reduction CB. + genInfo.setGenRegionEntryCb(genRegionEntryCB).setDataSharingProcessor(&dsp); + return genOpWithBody<mlir::omp::ParallelOp>(genInfo, clauseOps); +} + +static mlir::omp::SectionOp +genSectionOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, bool genNested, + mlir::Location loc, + const Fortran::parser::OmpClauseList &clauseList) { + // Currently only private/firstprivate clause is handled, and + // all privatization is done within `omp.section` operations. 
+ return genOpWithBody<mlir::omp::SectionOp>( + OpWithBodyGenInfo(converter, semaCtx, loc, eval) + .setGenNested(genNested) + .setClauses(&clauseList)); +} + +static mlir::omp::SectionsOp +genSectionsOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, mlir::Location loc, + const mlir::omp::SectionsClauseOps &clauseOps) { + return genOpWithBody<mlir::omp::SectionsOp>( + OpWithBodyGenInfo(converter, semaCtx, loc, eval).setGenNested(false), + clauseOps); +} + +static mlir::omp::SimdLoopOp +genSimdLoopOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, mlir::Location loc, + const Fortran::parser::OmpClauseList &clauseList) { + DataSharingProcessor dsp(converter, semaCtx, clauseList, eval); + dsp.processStep1(); + + Fortran::lower::StatementContext stmtCtx; + mlir::omp::SimdLoopClauseOps clauseOps; + llvm::SmallVector<const Fortran::semantics::Symbol *> iv; + genSimdLoopClauses(converter, semaCtx, stmtCtx, eval, clauseList, loc, + clauseOps, iv); + + auto *nestedEval = + getCollapsedLoopEval(eval, Fortran::lower::getCollapseValue(clauseList)); + + auto ivCallback = [&](mlir::Operation *op) { + return genLoopVars(op, converter, loc, iv); + }; + + return genOpWithBody<mlir::omp::SimdLoopOp>( + OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval) + .setClauses(&clauseList) + .setDataSharingProcessor(&dsp) + .setGenRegionEntryCb(ivCallback), + clauseOps); +} + +static mlir::omp::SingleOp +genSingleOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, bool genNested, + mlir::Location loc, + const Fortran::parser::OmpClauseList &beginClauseList, + const Fortran::parser::OmpClauseList &endClauseList) { + mlir::omp::SingleClauseOps clauseOps; + genSingleClauses(converter, semaCtx, beginClauseList, endClauseList, loc, + 
clauseOps); + + return genOpWithBody<mlir::omp::SingleOp>( + OpWithBodyGenInfo(converter, semaCtx, loc, eval) + .setGenNested(genNested) + .setClauses(&beginClauseList), + clauseOps); +} + static mlir::omp::TargetOp genTargetOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location currentLocation, + mlir::Location loc, const Fortran::parser::OmpClauseList &clauseList, - llvm::omp::Directive directive, bool outerCombined = false) { + bool outerCombined = false) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); Fortran::lower::StatementContext stmtCtx; + + bool processHostOnlyClauses = + !llvm::cast<mlir::omp::OffloadModuleInterface>(*converter.getModuleOp()) + .getIsTargetDevice(); + mlir::omp::TargetClauseOps clauseOps; - llvm::SmallVector<mlir::Type> mapTypes, devicePtrTypes, deviceAddrTypes; - llvm::SmallVector<mlir::Location> mapLocs, devicePtrLocs, deviceAddrLocs; llvm::SmallVector<const Fortran::semantics::Symbol *> mapSyms, devicePtrSyms, deviceAddrSyms; - - ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps); - cp.processDevice(stmtCtx, clauseOps); - cp.processThreadLimit(stmtCtx, clauseOps); - cp.processDepend(clauseOps); - cp.processNowait(clauseOps); - cp.processMap(currentLocation, directive, stmtCtx, clauseOps, &mapSyms, - &mapLocs, &mapTypes); - cp.processIsDevicePtr(clauseOps, devicePtrTypes, devicePtrLocs, - devicePtrSyms); - cp.processHasDeviceAddr(clauseOps, deviceAddrTypes, deviceAddrLocs, - deviceAddrSyms); - // TODO Support delayed privatization. 
- - cp.processTODO<clause::Private, clause::Firstprivate, clause::Reduction, - clause::InReduction, clause::Allocate, clause::UsesAllocators, - clause::Defaultmap>(currentLocation, - llvm::omp::Directive::OMPD_target); + llvm::SmallVector<mlir::Location> mapLocs, devicePtrLocs, deviceAddrLocs; + llvm::SmallVector<mlir::Type> mapTypes, devicePtrTypes, deviceAddrTypes; + genTargetClauses(converter, semaCtx, stmtCtx, clauseList, loc, + processHostOnlyClauses, /*processReduction=*/outerCombined, + clauseOps, mapSyms, mapLocs, mapTypes, deviceAddrSyms, + deviceAddrLocs, deviceAddrTypes, devicePtrSyms, + devicePtrLocs, devicePtrTypes); // 5.8.1 Implicit Data-Mapping Attribute Rules // The following code follows the implicit data-mapping rules to map all the @@ -1278,22 +1515,21 @@ genTargetOp(Fortran::lower::AbstractConverter &converter, fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(sym); name << sym.name().ToString(); - Fortran::lower::AddrAndBoundsInfo info = - getDataOperandBaseAddr(converter, converter.getFirOpBuilder(), sym, - converter.getCurrentLocation()); + Fortran::lower::AddrAndBoundsInfo info = getDataOperandBaseAddr( + converter, firOpBuilder, sym, converter.getCurrentLocation()); if (fir::unwrapRefType(info.addr.getType()).isa<fir::BaseBoxType>()) bounds = Fortran::lower::genBoundsOpsFromBox<mlir::omp::MapBoundsOp, mlir::omp::MapBoundsType>( - converter.getFirOpBuilder(), converter.getCurrentLocation(), - converter, dataExv, info); + firOpBuilder, converter.getCurrentLocation(), converter, + dataExv, info); if (fir::unwrapRefType(info.addr.getType()).isa<fir::SequenceType>()) { bool dataExvIsAssumedSize = Fortran::semantics::IsAssumedSizeArray(sym.GetUltimate()); bounds = Fortran::lower::genBaseBoundsOps<mlir::omp::MapBoundsOp, mlir::omp::MapBoundsType>( - converter.getFirOpBuilder(), converter.getCurrentLocation(), - converter, dataExv, dataExvIsAssumedSize); + firOpBuilder, converter.getCurrentLocation(), converter, dataExv, + 
dataExvIsAssumedSize); } llvm::omp::OpenMPOffloadMappingFlags mapFlag = @@ -1307,7 +1543,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter, // If a variable is specified in declare target link and if device // type is not specified as `nohost`, it needs to be mapped tofrom - mlir::ModuleOp mod = converter.getFirOpBuilder().getModule(); + mlir::ModuleOp mod = firOpBuilder.getModule(); mlir::Operation *op = mod.lookupSymbol(converter.mangleName(sym)); auto declareTargetOp = llvm::dyn_cast_if_present<mlir::omp::DeclareTargetInterface>(op); @@ -1327,8 +1563,8 @@ genTargetOp(Fortran::lower::AbstractConverter &converter, } mlir::Value mapOp = createMapInfoOp( - converter.getFirOpBuilder(), baseOp.getLoc(), baseOp, mlir::Value{}, - name.str(), bounds, {}, + firOpBuilder, baseOp.getLoc(), baseOp, mlir::Value{}, name.str(), + bounds, {}, static_cast< std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( mapFlag), @@ -1343,338 +1579,144 @@ genTargetOp(Fortran::lower::AbstractConverter &converter, }; Fortran::lower::pft::visitAllSymbols(eval, captureImplicitMap); - auto targetOp = converter.getFirOpBuilder().create<mlir::omp::TargetOp>( - currentLocation, clauseOps); - + auto targetOp = firOpBuilder.create<mlir::omp::TargetOp>(loc, clauseOps); genBodyOfTargetOp(converter, semaCtx, eval, genNested, targetOp, mapSyms, - mapLocs, mapTypes, currentLocation); - + mapLocs, mapTypes, loc); return targetOp; } -static mlir::omp::TeamsOp -genTeamsOp(Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location currentLocation, - const Fortran::parser::OmpClauseList &clauseList, - bool outerCombined = false) { +static mlir::omp::TargetDataOp +genTargetDataOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, bool genNested, + mlir::Location loc, + const 
Fortran::parser::OmpClauseList &clauseList) { Fortran::lower::StatementContext stmtCtx; - mlir::omp::TeamsClauseOps clauseOps; - - ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps); - cp.processAllocate(clauseOps); - cp.processDefault(); - cp.processNumTeams(stmtCtx, clauseOps); - cp.processThreadLimit(stmtCtx, clauseOps); - // TODO Support delayed privatization. - - cp.processTODO<clause::Reduction>(currentLocation, - llvm::omp::Directive::OMPD_teams); + mlir::omp::TargetDataClauseOps clauseOps; + llvm::SmallVector<mlir::Type> useDeviceTypes; + llvm::SmallVector<mlir::Location> useDeviceLocs; + llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSyms; + genTargetDataClauses(converter, semaCtx, stmtCtx, clauseList, loc, clauseOps, + useDeviceTypes, useDeviceLocs, useDeviceSyms); - return genOpWithBody<mlir::omp::TeamsOp>( - OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) - .setGenNested(genNested) - .setOuterCombined(outerCombined) - .setClauses(&clauseList), - clauseOps); + auto targetDataOp = + converter.getFirOpBuilder().create<mlir::omp::TargetDataOp>(loc, + clauseOps); + genBodyOfTargetDataOp(converter, semaCtx, eval, genNested, targetDataOp, + useDeviceTypes, useDeviceLocs, useDeviceSyms, loc); + return targetDataOp; } -/// Extract the list of function and variable symbols affected by the given -/// 'declare target' directive and return the intended device type for them. 
-static void getDeclareTargetInfo( +template <typename OpTy> +static OpTy genTargetEnterExitUpdateDataOp( Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct, - mlir::omp::DeclareTargetClauseOps &clauseOps, - llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) { - const auto &spec = std::get<Fortran::parser::OmpDeclareTargetSpecifier>( - declareTargetConstruct.t); - if (const auto *objectList{ - Fortran::parser::Unwrap<Fortran::parser::OmpObjectList>(spec.u)}) { - ObjectList objects{makeObjects(*objectList, semaCtx)}; - // Case: declare target(func, var1, var2) - gatherFuncAndVarSyms(objects, mlir::omp::DeclareTargetCaptureClause::to, - symbolAndClause); - } else if (const auto *clauseList{ - Fortran::parser::Unwrap<Fortran::parser::OmpClauseList>( - spec.u)}) { - if (clauseList->v.empty()) { - // Case: declare target, implicit capture of function - symbolAndClause.emplace_back( - mlir::omp::DeclareTargetCaptureClause::to, - eval.getOwningProcedure()->getSubprogramSymbol()); - } + Fortran::semantics::SemanticsContext &semaCtx, mlir::Location loc, + const Fortran::parser::OmpClauseList &clauseList) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + Fortran::lower::StatementContext stmtCtx; - ClauseProcessor cp(converter, semaCtx, *clauseList); - cp.processTo(symbolAndClause); - cp.processEnter(symbolAndClause); - cp.processLink(symbolAndClause); - cp.processDeviceType(clauseOps); - cp.processTODO<clause::Indirect>(converter.getCurrentLocation(), - llvm::omp::Directive::OMPD_declare_target); + // GCC 9.3.0 emits a (probably) bogus warning about an unused variable. 
+ [[maybe_unused]] llvm::omp::Directive directive; + if constexpr (std::is_same_v<OpTy, mlir::omp::TargetEnterDataOp>) { + directive = llvm::omp::Directive::OMPD_target_enter_data; + } else if constexpr (std::is_same_v<OpTy, mlir::omp::TargetExitDataOp>) { + directive = llvm::omp::Directive::OMPD_target_exit_data; + } else if constexpr (std::is_same_v<OpTy, mlir::omp::TargetUpdateOp>) { + directive = llvm::omp::Directive::OMPD_target_update; + } else { + llvm_unreachable("Unexpected TARGET DATA construct"); } -} - -static void collectDeferredDeclareTargets( - Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct, - llvm::SmallVectorImpl<Fortran::lower::OMPDeferredDeclareTargetInfo> - &deferredDeclareTarget) { - mlir::omp::DeclareTargetClauseOps clauseOps; - llvm::SmallVector<DeclareTargetCapturePair> symbolAndClause; - getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct, - clauseOps, symbolAndClause); - // Return the device type only if at least one of the targets for the - // directive is a function or subroutine - mlir::ModuleOp mod = converter.getFirOpBuilder().getModule(); - for (const DeclareTargetCapturePair &symClause : symbolAndClause) { - mlir::Operation *op = mod.lookupSymbol(converter.mangleName( - std::get<const Fortran::semantics::Symbol &>(symClause))); + mlir::omp::TargetEnterExitUpdateDataClauseOps clauseOps; + genTargetEnterExitUpdateDataClauses(converter, semaCtx, stmtCtx, clauseList, + loc, directive, clauseOps); - if (!op) { - deferredDeclareTarget.push_back({std::get<0>(symClause), - clauseOps.deviceType, - std::get<1>(symClause)}); - } - } + return firOpBuilder.create<OpTy>(loc, clauseOps); } -static std::optional<mlir::omp::DeclareTargetDeviceType> -getDeclareTargetFunctionDevice( - Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext 
&semaCtx, - Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OpenMPDeclareTargetConstruct - &declareTargetConstruct) { - mlir::omp::DeclareTargetClauseOps clauseOps; - llvm::SmallVector<DeclareTargetCapturePair> symbolAndClause; - getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct, - clauseOps, symbolAndClause); - - // Return the device type only if at least one of the targets for the - // directive is a function or subroutine - mlir::ModuleOp mod = converter.getFirOpBuilder().getModule(); - for (const DeclareTargetCapturePair &symClause : symbolAndClause) { - mlir::Operation *op = mod.lookupSymbol(converter.mangleName( - std::get<const Fortran::semantics::Symbol &>(symClause))); - - if (mlir::isa_and_nonnull<mlir::func::FuncOp>(op)) - return clauseOps.deviceType; - } +static mlir::omp::TaskOp +genTaskOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, bool genNested, + mlir::Location loc, + const Fortran::parser::OmpClauseList &clauseList) { + Fortran::lower::StatementContext stmtCtx; + mlir::omp::TaskClauseOps clauseOps; + genTaskClauses(converter, semaCtx, stmtCtx, clauseList, loc, clauseOps); - return std::nullopt; + return genOpWithBody<mlir::omp::TaskOp>( + OpWithBodyGenInfo(converter, semaCtx, loc, eval) + .setGenNested(genNested) + .setClauses(&clauseList), + clauseOps); } -//===----------------------------------------------------------------------===// -// genOMP() Code generation helper functions -//===----------------------------------------------------------------------===// - -static void -genOmpSimpleStandalone(Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, bool genNested, - const Fortran::parser::OpenMPSimpleStandaloneConstruct - &simpleStandaloneConstruct) { - const auto &directive = - std::get<Fortran::parser::OmpSimpleStandaloneDirective>( - 
simpleStandaloneConstruct.t); - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - const auto &opClauseList = - std::get<Fortran::parser::OmpClauseList>(simpleStandaloneConstruct.t); - mlir::Location currentLocation = converter.genLocation(directive.source); - - switch (directive.v) { - default: - break; - case llvm::omp::Directive::OMPD_barrier: - firOpBuilder.create<mlir::omp::BarrierOp>(currentLocation); - break; - case llvm::omp::Directive::OMPD_taskwait: { - mlir::omp::TaskwaitClauseOps clauseOps; - ClauseProcessor cp(converter, semaCtx, opClauseList); - cp.processTODO<clause::Depend, clause::Nowait>( - currentLocation, llvm::omp::Directive::OMPD_taskwait); - firOpBuilder.create<mlir::omp::TaskwaitOp>(currentLocation, clauseOps); - break; - } - case llvm::omp::Directive::OMPD_taskyield: - firOpBuilder.create<mlir::omp::TaskyieldOp>(currentLocation); - break; - case llvm::omp::Directive::OMPD_target_data: - genTargetDataOp(converter, semaCtx, eval, genNested, currentLocation, - opClauseList); - break; - case llvm::omp::Directive::OMPD_target_enter_data: - genTargetEnterExitDataUpdateOp<mlir::omp::TargetEnterDataOp>( - converter, semaCtx, currentLocation, opClauseList); - break; - case llvm::omp::Directive::OMPD_target_exit_data: - genTargetEnterExitDataUpdateOp<mlir::omp::TargetExitDataOp>( - converter, semaCtx, currentLocation, opClauseList); - break; - case llvm::omp::Directive::OMPD_target_update: - genTargetEnterExitDataUpdateOp<mlir::omp::TargetUpdateOp>( - converter, semaCtx, currentLocation, opClauseList); - break; - case llvm::omp::Directive::OMPD_ordered: - TODO(currentLocation, "OMPD_ordered"); - } -} +static mlir::omp::TaskgroupOp +genTaskgroupOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, bool genNested, + mlir::Location loc, + const Fortran::parser::OmpClauseList &clauseList) { + mlir::omp::TaskgroupClauseOps clauseOps; + 
genTaskgroupClauses(converter, semaCtx, clauseList, loc, clauseOps); -static void -genOmpFlush(Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OpenMPFlushConstruct &flushConstruct) { - llvm::SmallVector<mlir::Value, 4> operandRange; - if (const auto &ompObjectList = - std::get<std::optional<Fortran::parser::OmpObjectList>>( - flushConstruct.t)) - genObjectList2(*ompObjectList, converter, operandRange); - const auto &memOrderClause = - std::get<std::optional<std::list<Fortran::parser::OmpMemoryOrderClause>>>( - flushConstruct.t); - if (memOrderClause && memOrderClause->size() > 0) - TODO(converter.getCurrentLocation(), "Handle OmpMemoryOrderClause"); - converter.getFirOpBuilder().create<mlir::omp::FlushOp>( - converter.getCurrentLocation(), operandRange); + return genOpWithBody<mlir::omp::TaskgroupOp>( + OpWithBodyGenInfo(converter, semaCtx, loc, eval) + .setGenNested(genNested) + .setClauses(&clauseList), + clauseOps); } -static llvm::SmallVector<const Fortran::semantics::Symbol *> -genLoopVars(mlir::Operation *op, Fortran::lower::AbstractConverter &converter, - mlir::Location &loc, - llvm::ArrayRef<const Fortran::semantics::Symbol *> args) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - auto &region = op->getRegion(0); - - std::size_t loopVarTypeSize = 0; - for (const Fortran::semantics::Symbol *arg : args) - loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size()); - mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize); - llvm::SmallVector<mlir::Type> tiv(args.size(), loopVarType); - llvm::SmallVector<mlir::Location> locs(args.size(), loc); - firOpBuilder.createBlock(&region, {}, tiv, locs); - // The argument is not currently in memory, so make a temporary for the - // argument, and store it there, then bind that location to the argument. 
- mlir::Operation *storeOp = nullptr; - for (auto [argIndex, argSymbol] : llvm::enumerate(args)) { - mlir::Value indexVal = fir::getBase(region.front().getArgument(argIndex)); - storeOp = - createAndSetPrivatizedLoopVar(converter, loc, indexVal, argSymbol); - } - firOpBuilder.setInsertionPointAfter(storeOp); - - return llvm::SmallVector<const Fortran::semantics::Symbol *>(args); +static mlir::omp::TaskloopOp +genTaskloopOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, mlir::Location loc, + const Fortran::parser::OmpClauseList &clauseList) { + TODO(loc, "Taskloop construct"); } -static llvm::SmallVector<const Fortran::semantics::Symbol *> -genLoopAndReductionVars( - mlir::Operation *op, Fortran::lower::AbstractConverter &converter, - mlir::Location &loc, - llvm::ArrayRef<const Fortran::semantics::Symbol *> loopArgs, - llvm::ArrayRef<const Fortran::semantics::Symbol *> reductionArgs, - llvm::ArrayRef<mlir::Type> reductionTypes) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - - llvm::SmallVector<mlir::Type> blockArgTypes; - llvm::SmallVector<mlir::Location> blockArgLocs; - blockArgTypes.reserve(loopArgs.size() + reductionArgs.size()); - blockArgLocs.reserve(blockArgTypes.size()); - mlir::Block *entryBlock; - - if (loopArgs.size()) { - std::size_t loopVarTypeSize = 0; - for (const Fortran::semantics::Symbol *arg : loopArgs) - loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size()); - mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize); - std::fill_n(std::back_inserter(blockArgTypes), loopArgs.size(), - loopVarType); - std::fill_n(std::back_inserter(blockArgLocs), loopArgs.size(), loc); - } - if (reductionArgs.size()) { - llvm::copy(reductionTypes, std::back_inserter(blockArgTypes)); - std::fill_n(std::back_inserter(blockArgLocs), reductionArgs.size(), loc); - } - entryBlock = firOpBuilder.createBlock(&op->getRegion(0), {}, 
blockArgTypes, - blockArgLocs); - // The argument is not currently in memory, so make a temporary for the - // argument, and store it there, then bind that location to the argument. - if (loopArgs.size()) { - mlir::Operation *storeOp = nullptr; - for (auto [argIndex, argSymbol] : llvm::enumerate(loopArgs)) { - mlir::Value indexVal = - fir::getBase(op->getRegion(0).front().getArgument(argIndex)); - storeOp = - createAndSetPrivatizedLoopVar(converter, loc, indexVal, argSymbol); - } - firOpBuilder.setInsertionPointAfter(storeOp); - } - // Bind the reduction arguments to their block arguments - for (auto [arg, prv] : llvm::zip_equal( - reductionArgs, - llvm::drop_begin(entryBlock->getArguments(), loopArgs.size()))) { - converter.bindSymbol(*arg, prv); - } - - return llvm::SmallVector<const Fortran::semantics::Symbol *>(loopArgs); +static mlir::omp::TaskwaitOp +genTaskwaitOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, mlir::Location loc, + const Fortran::parser::OmpClauseList &clauseList) { + mlir::omp::TaskwaitClauseOps clauseOps; + genTaskwaitClauses(converter, semaCtx, clauseList, loc, clauseOps); + return converter.getFirOpBuilder().create<mlir::omp::TaskwaitOp>(loc, + clauseOps); } -static void -createSimdLoop(Fortran::lower::AbstractConverter &converter, +static mlir::omp::TaskyieldOp +genTaskyieldOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, - llvm::omp::Directive ompDirective, - const Fortran::parser::OmpClauseList &loopOpClauseList, - mlir::Location loc) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - DataSharingProcessor dsp(converter, semaCtx, loopOpClauseList, eval); - dsp.processStep1(); + Fortran::lower::pft::Evaluation &eval, mlir::Location loc) { + return converter.getFirOpBuilder().create<mlir::omp::TaskyieldOp>(loc); +} +static mlir::omp::TeamsOp 
+genTeamsOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, bool genNested, + mlir::Location loc, const Fortran::parser::OmpClauseList &clauseList, + bool outerCombined = false) { Fortran::lower::StatementContext stmtCtx; - mlir::omp::SimdLoopClauseOps clauseOps; - llvm::SmallVector<const Fortran::semantics::Symbol *> iv; - - ClauseProcessor cp(converter, semaCtx, loopOpClauseList); - cp.processCollapse(loc, eval, clauseOps, iv); - cp.processReduction(loc, clauseOps); - cp.processIf(llvm::omp::Directive::OMPD_simd, clauseOps); - cp.processSimdlen(clauseOps); - cp.processSafelen(clauseOps); - clauseOps.loopInclusiveAttr = firOpBuilder.getUnitAttr(); - // TODO Support delayed privatization. - - cp.processTODO<clause::Aligned, clause::Allocate, clause::Linear, - clause::Nontemporal, clause::Order>(loc, ompDirective); - - auto *nestedEval = getCollapsedLoopEval( - eval, Fortran::lower::getCollapseValue(loopOpClauseList)); - - auto ivCallback = [&](mlir::Operation *op) { - return genLoopVars(op, converter, loc, iv); - }; + mlir::omp::TeamsClauseOps clauseOps; + genTeamsClauses(converter, semaCtx, stmtCtx, clauseList, loc, clauseOps); - genOpWithBody<mlir::omp::SimdLoopOp>( - OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval) - .setClauses(&loopOpClauseList) - .setDataSharingProcessor(&dsp) - .setGenRegionEntryCb(ivCallback), + return genOpWithBody<mlir::omp::TeamsOp>( + OpWithBodyGenInfo(converter, semaCtx, loc, eval) + .setGenNested(genNested) + .setOuterCombined(outerCombined) + .setClauses(&clauseList), clauseOps); } -static void createWsloop(Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, - llvm::omp::Directive ompDirective, - const Fortran::parser::OmpClauseList &beginClauseList, - const Fortran::parser::OmpClauseList *endClauseList, - mlir::Location loc) { - fir::FirOpBuilder 
&firOpBuilder = converter.getFirOpBuilder(); +static mlir::omp::WsloopOp +genWsloopOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, mlir::Location loc, + const Fortran::parser::OmpClauseList &beginClauseList, + const Fortran::parser::OmpClauseList *endClauseList) { DataSharingProcessor dsp(converter, semaCtx, beginClauseList, eval); dsp.processStep1(); @@ -1683,30 +1725,9 @@ static void createWsloop(Fortran::lower::AbstractConverter &converter, llvm::SmallVector<const Fortran::semantics::Symbol *> iv; llvm::SmallVector<mlir::Type> reductionTypes; llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSyms; - - ClauseProcessor cp(converter, semaCtx, beginClauseList); - cp.processCollapse(loc, eval, clauseOps, iv); - cp.processSchedule(stmtCtx, clauseOps); - cp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms); - cp.processOrdered(clauseOps); - clauseOps.loopInclusiveAttr = firOpBuilder.getUnitAttr(); - // TODO Support delayed privatization. - - if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars)) - clauseOps.reductionByRefAttr = firOpBuilder.getUnitAttr(); - - cp.processTODO<clause::Allocate, clause::Linear, clause::Order>(loc, - ompDirective); - - // In FORTRAN `nowait` clause occur at the end of `omp do` directive. 
- // i.e - // !$omp do - // <...> - // !$omp end do nowait - if (endClauseList) { - ClauseProcessor ecp(converter, semaCtx, *endClauseList); - ecp.processNowait(clauseOps); - } + genWsloopClauses(converter, semaCtx, stmtCtx, eval, beginClauseList, + endClauseList, loc, clauseOps, iv, reductionTypes, + reductionSyms); auto *nestedEval = getCollapsedLoopEval( eval, Fortran::lower::getCollapseValue(beginClauseList)); @@ -1716,7 +1737,7 @@ static void createWsloop(Fortran::lower::AbstractConverter &converter, reductionTypes); }; - genOpWithBody<mlir::omp::WsloopOp>( + return genOpWithBody<mlir::omp::WsloopOp>( OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval) .setClauses(&beginClauseList) .setDataSharingProcessor(&dsp) @@ -1725,7 +1746,11 @@ static void createWsloop(Fortran::lower::AbstractConverter &converter, clauseOps); } -static void createSimdWsloop( +//===----------------------------------------------------------------------===// +// Code generation functions for composite constructs +//===----------------------------------------------------------------------===// + +static void genCompositeDoSimd( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, llvm::omp::Directive ompDirective, @@ -1733,7 +1758,7 @@ static void createSimdWsloop( const Fortran::parser::OmpClauseList *endClauseList, mlir::Location loc) { ClauseProcessor cp(converter, semaCtx, beginClauseList); cp.processTODO<clause::Aligned, clause::Allocate, clause::Linear, - clause::Safelen, clause::Simdlen, clause::Order>(loc, + clause::Order, clause::Safelen, clause::Simdlen>(loc, ompDirective); // TODO: Add support for vectorization - add vectorization hints inside loop // body. @@ -1743,34 +1768,7 @@ static void createSimdWsloop( // When support for vectorization is enabled, then we need to add handling of // if clause. Currently if clause can be skipped because we always assume // SIMD length = 1. 
- createWsloop(converter, semaCtx, eval, ompDirective, beginClauseList, - endClauseList, loc); -} - -static void -markDeclareTarget(mlir::Operation *op, - Fortran::lower::AbstractConverter &converter, - mlir::omp::DeclareTargetCaptureClause captureClause, - mlir::omp::DeclareTargetDeviceType deviceType) { - // TODO: Add support for program local variables with declare target applied - auto declareTargetOp = llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(op); - if (!declareTargetOp) - fir::emitFatalError( - converter.getCurrentLocation(), - "Attempt to apply declare target on unsupported operation"); - - // The function or global already has a declare target applied to it, very - // likely through implicit capture (usage in another declare target - // function/subroutine). It should be marked as any if it has been assigned - // both host and nohost, else we skip, as there is no change - if (declareTargetOp.isDeclareTarget()) { - if (declareTargetOp.getDeclareTargetDeviceType() != deviceType) - declareTargetOp.setDeclareTarget(mlir::omp::DeclareTargetDeviceType::any, - captureClause); - return; - } - - declareTargetOp.setDeclareTarget(deviceType, captureClause); + genWsloopOp(converter, semaCtx, eval, loc, beginClauseList, endClauseList); } //===----------------------------------------------------------------------===// @@ -1866,6 +1864,102 @@ genOMP(Fortran::lower::AbstractConverter &converter, } //===----------------------------------------------------------------------===// +// OpenMPStandaloneConstruct visitors +//===----------------------------------------------------------------------===// + +static void genOMP(Fortran::lower::AbstractConverter &converter, + Fortran::lower::SymMap &symTable, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPSimpleStandaloneConstruct + &simpleStandaloneConstruct) { + const auto &directive = + std::get<Fortran::parser::OmpSimpleStandaloneDirective>( + 
simpleStandaloneConstruct.t); + const auto &clauseList = + std::get<Fortran::parser::OmpClauseList>(simpleStandaloneConstruct.t); + mlir::Location currentLocation = converter.genLocation(directive.source); + + switch (directive.v) { + default: + break; + case llvm::omp::Directive::OMPD_barrier: + genBarrierOp(converter, semaCtx, eval, currentLocation); + break; + case llvm::omp::Directive::OMPD_taskwait: + genTaskwaitOp(converter, semaCtx, eval, currentLocation, clauseList); + break; + case llvm::omp::Directive::OMPD_taskyield: + genTaskyieldOp(converter, semaCtx, eval, currentLocation); + break; + case llvm::omp::Directive::OMPD_target_data: + genTargetDataOp(converter, semaCtx, eval, /*genNested=*/true, + currentLocation, clauseList); + break; + case llvm::omp::Directive::OMPD_target_enter_data: + genTargetEnterExitUpdateDataOp<mlir::omp::TargetEnterDataOp>( + converter, semaCtx, currentLocation, clauseList); + break; + case llvm::omp::Directive::OMPD_target_exit_data: + genTargetEnterExitUpdateDataOp<mlir::omp::TargetExitDataOp>( + converter, semaCtx, currentLocation, clauseList); + break; + case llvm::omp::Directive::OMPD_target_update: + genTargetEnterExitUpdateDataOp<mlir::omp::TargetUpdateOp>( + converter, semaCtx, currentLocation, clauseList); + break; + case llvm::omp::Directive::OMPD_ordered: + genOrderedOp(converter, semaCtx, eval, currentLocation, clauseList); + break; + } +} + +static void +genOMP(Fortran::lower::AbstractConverter &converter, + Fortran::lower::SymMap &symTable, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPFlushConstruct &flushConstruct) { + const auto &verbatim = std::get<Fortran::parser::Verbatim>(flushConstruct.t); + const auto &objectList = + std::get<std::optional<Fortran::parser::OmpObjectList>>(flushConstruct.t); + const auto &clauseList = + std::get<std::optional<std::list<Fortran::parser::OmpMemoryOrderClause>>>( + flushConstruct.t); + mlir::Location 
currentLocation = converter.genLocation(verbatim.source); + genFlushOp(converter, semaCtx, eval, currentLocation, objectList, clauseList); +} + +static void +genOMP(Fortran::lower::AbstractConverter &converter, + Fortran::lower::SymMap &symTable, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPCancelConstruct &cancelConstruct) { + TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct"); +} + +static void genOMP(Fortran::lower::AbstractConverter &converter, + Fortran::lower::SymMap &symTable, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPCancellationPointConstruct + &cancellationPointConstruct) { + TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct"); +} + +static void +genOMP(Fortran::lower::AbstractConverter &converter, + Fortran::lower::SymMap &symTable, + Fortran::semantics::SemanticsContext &semaCtx, + Fortran::lower::pft::Evaluation &eval, + const Fortran::parser::OpenMPStandaloneConstruct &standaloneConstruct) { + std::visit( + [&](auto &&s) { return genOMP(converter, symTable, semaCtx, eval, s); }, + standaloneConstruct.u); +} + +//===----------------------------------------------------------------------===// // OpenMPConstruct visitors //===----------------------------------------------------------------------===// @@ -1996,7 +2090,7 @@ genOMP(Fortran::lower::AbstractConverter &converter, break; case llvm::omp::Directive::OMPD_target: genTargetOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation, - beginClauseList, directive.v); + beginClauseList); break; case llvm::omp::Directive::OMPD_target_data: genTargetDataOp(converter, semaCtx, eval, /*genNested=*/true, @@ -2012,8 +2106,7 @@ genOMP(Fortran::lower::AbstractConverter &converter, break; case llvm::omp::Directive::OMPD_teams: genTeamsOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation, - beginClauseList, - 
/*outerCombined=*/false); + beginClauseList); break; case llvm::omp::Directive::OMPD_workshare: // FIXME: Workshare is not a commonly used OpenMP construct, an @@ -2035,8 +2128,7 @@ genOMP(Fortran::lower::AbstractConverter &converter, if ((llvm::omp::allTargetSet & llvm::omp::blockConstructSet) .test(directive.v)) { genTargetOp(converter, semaCtx, eval, /*genNested=*/false, currentLocation, - beginClauseList, directive.v, - /*outerCombined=*/true); + beginClauseList, /*outerCombined=*/true); combinedDirective = true; } if ((llvm::omp::allTeamsSet & llvm::omp::blockConstructSet) @@ -2073,44 +2165,13 @@ genOMP(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPCriticalConstruct &criticalConstruct) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - mlir::Location currentLocation = converter.getCurrentLocation(); - std::string name; - const Fortran::parser::OmpCriticalDirective &cd = + const auto &cd = std::get<Fortran::parser::OmpCriticalDirective>(criticalConstruct.t); - if (std::get<std::optional<Fortran::parser::Name>>(cd.t).has_value()) { - name = - std::get<std::optional<Fortran::parser::Name>>(cd.t).value().ToString(); - } - - mlir::omp::CriticalOp criticalOp = [&]() { - if (name.empty()) { - return firOpBuilder.create<mlir::omp::CriticalOp>( - currentLocation, mlir::FlatSymbolRefAttr()); - } - - mlir::ModuleOp module = firOpBuilder.getModule(); - mlir::OpBuilder modBuilder(module.getBodyRegion()); - auto global = module.lookupSymbol<mlir::omp::CriticalDeclareOp>(name); - if (!global) { - mlir::omp::CriticalClauseOps clauseOps; - const auto &clauseList = std::get<Fortran::parser::OmpClauseList>(cd.t); - - ClauseProcessor cp(converter, semaCtx, clauseList); - cp.processHint(clauseOps); - clauseOps.nameAttr = - mlir::StringAttr::get(firOpBuilder.getContext(), name); - - global = 
modBuilder.create<mlir::omp::CriticalDeclareOp>(currentLocation, - clauseOps); - } - - return firOpBuilder.create<mlir::omp::CriticalOp>( - currentLocation, mlir::FlatSymbolRefAttr::get(firOpBuilder.getContext(), - global.getSymName())); - }(); - auto genInfo = OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval); - createBodyOfOp<mlir::omp::CriticalOp>(criticalOp, genInfo); + const auto &clauseList = std::get<Fortran::parser::OmpClauseList>(cd.t); + const auto &name = std::get<std::optional<Fortran::parser::Name>>(cd.t); + mlir::Location currentLocation = converter.getCurrentLocation(); + genCriticalOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation, + clauseList, name); } static void @@ -2129,7 +2190,7 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, const Fortran::parser::OpenMPLoopConstruct &loopConstruct) { const auto &beginLoopDirective = std::get<Fortran::parser::OmpBeginLoopDirective>(loopConstruct.t); - const auto &loopOpClauseList = + const auto &beginClauseList = std::get<Fortran::parser::OmpClauseList>(beginLoopDirective.t); mlir::Location currentLocation = converter.genLocation(beginLoopDirective.source); @@ -2150,33 +2211,31 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, bool validDirective = false; if (llvm::omp::topTaskloopSet.test(ompDirective)) { validDirective = true; - TODO(currentLocation, "Taskloop construct"); + genTaskloopOp(converter, semaCtx, eval, currentLocation, beginClauseList); } else { // Create omp.{target, teams, distribute, parallel} nested operations if ((llvm::omp::allTargetSet & llvm::omp::loopConstructSet) .test(ompDirective)) { validDirective = true; genTargetOp(converter, semaCtx, eval, /*genNested=*/false, - currentLocation, loopOpClauseList, ompDirective, - /*outerCombined=*/true); + currentLocation, beginClauseList, /*outerCombined=*/true); } if ((llvm::omp::allTeamsSet & llvm::omp::loopConstructSet) .test(ompDirective)) { validDirective = true; 
genTeamsOp(converter, semaCtx, eval, /*genNested=*/false, currentLocation, - loopOpClauseList, - /*outerCombined=*/true); + beginClauseList, /*outerCombined=*/true); } if (llvm::omp::allDistributeSet.test(ompDirective)) { validDirective = true; - TODO(currentLocation, "Distribute construct"); + genDistributeOp(converter, semaCtx, eval, /*genNested=*/false, + currentLocation, beginClauseList); } if ((llvm::omp::allParallelSet & llvm::omp::loopConstructSet) .test(ompDirective)) { validDirective = true; genParallelOp(converter, symTable, semaCtx, eval, /*genNested=*/false, - currentLocation, loopOpClauseList, - /*outerCombined=*/true); + currentLocation, beginClauseList, /*outerCombined=*/true); } } if ((llvm::omp::allDoSet | llvm::omp::allSimdSet).test(ompDirective)) @@ -2190,17 +2249,14 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, if (llvm::omp::allDoSimdSet.test(ompDirective)) { // 2.9.3.2 Workshare SIMD construct - createSimdWsloop(converter, semaCtx, eval, ompDirective, loopOpClauseList, - endClauseList, currentLocation); - + genCompositeDoSimd(converter, semaCtx, eval, ompDirective, beginClauseList, + endClauseList, currentLocation); } else if (llvm::omp::allSimdSet.test(ompDirective)) { // 2.9.3.1 SIMD construct - createSimdLoop(converter, semaCtx, eval, ompDirective, loopOpClauseList, - currentLocation); - genOpenMPReduction(converter, semaCtx, loopOpClauseList); + genSimdLoopOp(converter, semaCtx, eval, currentLocation, beginClauseList); } else { - createWsloop(converter, semaCtx, eval, ompDirective, loopOpClauseList, - endClauseList, currentLocation); + genWsloopOp(converter, semaCtx, eval, currentLocation, beginClauseList, + endClauseList); } } @@ -2220,44 +2276,39 @@ genOMP(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPSectionsConstruct &sectionsConstruct) { - mlir::Location currentLocation = converter.getCurrentLocation(); 
- mlir::omp::SectionsClauseOps clauseOps; const auto &beginSectionsDirective = std::get<Fortran::parser::OmpBeginSectionsDirective>(sectionsConstruct.t); - const auto &sectionsClauseList = + const auto &beginClauseList = std::get<Fortran::parser::OmpClauseList>(beginSectionsDirective.t); // Process clauses before optional omp.parallel, so that new variables are // allocated outside of the parallel region - ClauseProcessor cp(converter, semaCtx, sectionsClauseList); - cp.processSectionsReduction(currentLocation, clauseOps); - cp.processAllocate(clauseOps); - // TODO Support delayed privatization. + mlir::Location currentLocation = converter.getCurrentLocation(); + mlir::omp::SectionsClauseOps clauseOps; + genSectionsClauses(converter, semaCtx, beginClauseList, currentLocation, + /*clausesFromBeginSections=*/true, clauseOps); + // Parallel wrapper of PARALLEL SECTIONS construct llvm::omp::Directive dir = std::get<Fortran::parser::OmpSectionsDirective>(beginSectionsDirective.t) .v; - - // Parallel wrapper of PARALLEL SECTIONS construct if (dir == llvm::omp::Directive::OMPD_parallel_sections) { genParallelOp(converter, symTable, semaCtx, eval, - /*genNested=*/false, currentLocation, sectionsClauseList, + /*genNested=*/false, currentLocation, beginClauseList, /*outerCombined=*/true); } else { const auto &endSectionsDirective = std::get<Fortran::parser::OmpEndSectionsDirective>(sectionsConstruct.t); - const auto &endSectionsClauseList = + const auto &endClauseList = std::get<Fortran::parser::OmpClauseList>(endSectionsDirective.t); - ClauseProcessor(converter, semaCtx, endSectionsClauseList) - .processNowait(clauseOps); + genSectionsClauses(converter, semaCtx, endClauseList, currentLocation, + /*clausesFromBeginSections=*/false, clauseOps); } - // SECTIONS construct - genOpWithBody<mlir::omp::SectionsOp>( - OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval) - .setGenNested(false), - clauseOps); + // SECTIONS construct. 
+ genSectionsOp(converter, semaCtx, eval, currentLocation, clauseOps); + // Generate nested SECTION operations recursively. const auto &sectionBlocks = std::get<Fortran::parser::OmpSectionBlocks>(sectionsConstruct.t); auto &firOpBuilder = converter.getFirOpBuilder(); @@ -2266,40 +2317,12 @@ genOMP(Fortran::lower::AbstractConverter &converter, llvm::zip(sectionBlocks.v, eval.getNestedEvaluations())) { symTable.pushScope(); genSectionOp(converter, semaCtx, neval, /*genNested=*/true, currentLocation, - sectionsClauseList); + beginClauseList); symTable.popScope(); firOpBuilder.restoreInsertionPoint(ip); } } -static void -genOMP(Fortran::lower::AbstractConverter &converter, - Fortran::lower::SymMap &symTable, - Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OpenMPStandaloneConstruct &standaloneConstruct) { - std::visit( - Fortran::common::visitors{ - [&](const Fortran::parser::OpenMPSimpleStandaloneConstruct - &simpleStandaloneConstruct) { - genOmpSimpleStandalone(converter, semaCtx, eval, - /*genNested=*/true, - simpleStandaloneConstruct); - }, - [&](const Fortran::parser::OpenMPFlushConstruct &flushConstruct) { - genOmpFlush(converter, semaCtx, eval, flushConstruct); - }, - [&](const Fortran::parser::OpenMPCancelConstruct &cancelConstruct) { - TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct"); - }, - [&](const Fortran::parser::OpenMPCancellationPointConstruct - &cancellationPointConstruct) { - TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct"); - }, - }, - standaloneConstruct.u); -} - static void genOMP(Fortran::lower::AbstractConverter &converter, Fortran::lower::SymMap &symTable, Fortran::semantics::SemanticsContext &semaCtx, |