diff options
Diffstat (limited to 'mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp')
-rw-r--r-- | mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp | 102 |
1 files changed, 88 insertions, 14 deletions
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 646d0ed73084..08ec57803aff 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -804,13 +804,13 @@ convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder, /// Allocate space for privatized reduction variables. template <typename T> -static void -allocByValReductionVars(T loop, llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation, - llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, - SmallVector<omp::DeclareReductionOp> &reductionDecls, - SmallVector<llvm::Value *> &privateReductionVariables, - DenseMap<Value, llvm::Value *> &reductionVariableMap) { +static void allocByValReductionVars( + T loop, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation, + llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, + SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls, + SmallVectorImpl<llvm::Value *> &privateReductionVariables, + DenseMap<Value, llvm::Value *> &reductionVariableMap) { llvm::IRBuilderBase::InsertPointGuard guard(builder); builder.restoreIP(allocaIP); auto args = @@ -825,6 +825,25 @@ allocByValReductionVars(T loop, llvm::IRBuilderBase &builder, } } +/// Map input argument to all reduction initialization regions +template <typename T> +static void +mapInitializationArg(T loop, LLVM::ModuleTranslation &moduleTranslation, + SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls, + unsigned i) { + // map input argument to the initialization region + mlir::omp::DeclareReductionOp &reduction = reductionDecls[i]; + Region &initializerRegion = reduction.getInitializerRegion(); + Block &entry = initializerRegion.front(); + assert(entry.getNumArguments() == 1 && + "the initialization region has one argument"); + + mlir::Value mlirSource = loop.getReductionVars()[i]; + llvm::Value *llvmSource = moduleTranslation.lookupValue(mlirSource); + assert(llvmSource && "lookup reduction var"); + moduleTranslation.mapValue(entry.getArgument(0), llvmSource); +} + /// Collect reduction info template <typename T> static void collectReductionInfo( @@ -858,6 +877,32 @@ static void collectReductionInfo( } } +/// handling of DeclareReductionOp's cleanup region +static LogicalResult inlineReductionCleanup( + llvm::SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls, + llvm::ArrayRef<llvm::Value *> privateReductionVariables, + LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder) { + for (auto [i, reductionDecl] : llvm::enumerate(reductionDecls)) { + Region &cleanupRegion = reductionDecl.getCleanupRegion(); + if (cleanupRegion.empty()) + continue; + + // map the argument to the cleanup region + Block &entry = cleanupRegion.front(); + moduleTranslation.mapValue(entry.getArgument(0), + privateReductionVariables[i]); + + if (failed(inlineConvertOmpRegions(cleanupRegion, "omp.reduction.cleanup", + builder, moduleTranslation))) + return failure(); + + // clear block argument mapping in case it needs to be re-created with a + // different source for another use of the same reduction decl + moduleTranslation.forgetMapping(cleanupRegion); + } + return success(); +} + /// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder. static LogicalResult convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, @@ -902,6 +947,10 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, loop.getRegion().getArguments().take_back(loop.getNumReductionVars()); for (unsigned i = 0; i < loop.getNumReductionVars(); ++i) { SmallVector<llvm::Value *> phis; + + // map block argument to initializer region + mapInitializationArg(loop, moduleTranslation, reductionDecls, i); + if (failed(inlineConvertOmpRegions(reductionDecls[i].getInitializerRegion(), "omp.reduction.neutral", builder, moduleTranslation, &phis))) @@ -925,6 +974,11 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, builder.CreateStore(phis[0], privateReductionVariables[i]); // the rest was handled in allocByValReductionVars } + + // forget the mapping for the initializer region because we might need a + // different mapping if this reduction declaration is re-used for a + // different variable + moduleTranslation.forgetMapping(reductionDecls[i].getInitializerRegion()); } // Store the mapping between reduction variables and their private copies on @@ -1044,7 +1098,9 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, tempTerminator->eraseFromParent(); builder.restoreIP(nextInsertionPoint); - return success(); + // after the workshare loop, deallocate private reduction variables + return inlineReductionCleanup(reductionDecls, privateReductionVariables, + moduleTranslation, builder); } /// A RAII class that on construction replaces the region arguments of the @@ -1097,13 +1153,13 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, LogicalResult bodyGenStatus = success(); llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); - auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) { - // Collect reduction declarations - SmallVector<omp::DeclareReductionOp> reductionDecls; - collectReductionDecls(opInst, reductionDecls); + // Collect reduction declarations + SmallVector<omp::DeclareReductionOp> reductionDecls; + collectReductionDecls(opInst, reductionDecls); + SmallVector<llvm::Value *> privateReductionVariables; + auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) { // Allocate reduction vars - SmallVector<llvm::Value *> privateReductionVariables; DenseMap<Value, llvm::Value *> reductionVariableMap; if (!isByRef) { allocByValReductionVars(opInst, builder, moduleTranslation, allocaIP, @@ -1118,6 +1174,9 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, opInst.getNumReductionVars()); for (unsigned i = 0; i < opInst.getNumReductionVars(); ++i) { SmallVector<llvm::Value *> phis; + + // map the block argument + mapInitializationArg(opInst, moduleTranslation, reductionDecls, i); if (failed(inlineConvertOmpRegions( reductionDecls[i].getInitializerRegion(), "omp.reduction.neutral", builder, moduleTranslation, &phis))) @@ -1144,6 +1203,10 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, builder.CreateStore(phis[0], privateReductionVariables[i]); // the rest is done in allocByValReductionVars } + + // clear block argument mapping in case it needs to be re-created with a + // different source for another use of the same reduction decl + moduleTranslation.forgetMapping(reductionDecls[i].getInitializerRegion()); } // Store the mapping between reduction variables and their private copies on @@ -1296,7 +1359,18 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, // TODO: Perform finalization actions for variables. This has to be // called for variables which have destructors/finalizers. - auto finiCB = [&](InsertPointTy codeGenIP) {}; + auto finiCB = [&](InsertPointTy codeGenIP) { + InsertPointTy oldIP = builder.saveIP(); + builder.restoreIP(codeGenIP); + + // if the reduction has a cleanup region, inline it here to finalize the + // reduction variables + if (failed(inlineReductionCleanup(reductionDecls, privateReductionVariables, + moduleTranslation, builder))) + bodyGenStatus = failure(); + + builder.restoreIP(oldIP); + }; llvm::Value *ifCond = nullptr; if (auto ifExprVar = opInst.getIfExprVar()) |