diff options
Diffstat (limited to 'mlir')
3 files changed, 32 insertions, 17 deletions
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index b7b2f8c..cbfc649 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -770,7 +770,7 @@ static void allocByValReductionVars( DenseMap<Value, llvm::Value *> &reductionVariableMap, llvm::ArrayRef<bool> isByRefs) { llvm::IRBuilderBase::InsertPointGuard guard(builder); - builder.restoreIP(allocaIP); + builder.SetInsertPoint(allocaIP.getBlock()->getTerminator()); auto args = loop.getRegion().getArguments().take_back(loop.getNumReductionVars()); @@ -780,7 +780,7 @@ static void allocByValReductionVars( llvm::Value *var = builder.CreateAlloca( moduleTranslation.convertType(reductionDecls[i].getType())); moduleTranslation.mapValue(args[i], var); - privateReductionVariables.push_back(var); + privateReductionVariables[i] = var; reductionVariableMap.try_emplace(loop.getReductionVars()[i], var); } } @@ -911,7 +911,8 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder::InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation); - SmallVector<llvm::Value *> privateReductionVariables; + SmallVector<llvm::Value *> privateReductionVariables( + wsloopOp.getNumReductionVars()); DenseMap<Value, llvm::Value *> reductionVariableMap; allocByValReductionVars(wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls, privateReductionVariables, @@ -942,7 +943,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, // ptr builder.CreateStore(phis[0], var); - privateReductionVariables.push_back(var); + privateReductionVariables[i] = var; moduleTranslation.mapValue(reductionArgs[i], phis[0]); reductionVariableMap.try_emplace(wsloopOp.getReductionVars()[i], phis[0]); } else { @@ -1140,7 +1141,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, // Collect reduction declarations SmallVector<omp::DeclareReductionOp> reductionDecls; collectReductionDecls(opInst, reductionDecls); - SmallVector<llvm::Value *> privateReductionVariables; + SmallVector<llvm::Value *> privateReductionVariables( + opInst.getNumReductionVars()); auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) { // Allocate reduction vars @@ -1154,6 +1156,21 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, MutableArrayRef<BlockArgument> reductionArgs = opInst.getRegion().getArguments().take_back( opInst.getNumReductionVars()); + + llvm::BasicBlock *initBlock = splitBB(builder, true, "omp.reduction.init"); + allocaIP = + InsertPointTy(allocaIP.getBlock(), + allocaIP.getBlock()->getTerminator()->getIterator()); + SmallVector<llvm::Value *> byRefVars(opInst.getNumReductionVars()); + for (unsigned i = 0; i < opInst.getNumReductionVars(); ++i) { + if (isByRef[i]) { + // Allocate reduction variable (which is a pointer to the real reduciton + // variable allocated in the inlined region) + byRefVars[i] = builder.CreateAlloca( + moduleTranslation.convertType(reductionDecls[i].getType())); + } + } + for (unsigned i = 0; i < opInst.getNumReductionVars(); ++i) { SmallVector<llvm::Value *> phis; @@ -1166,18 +1183,14 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, assert(phis.size() == 1 && "expected one value to be yielded from the " "reduction neutral element declaration region"); - builder.restoreIP(allocaIP); + builder.SetInsertPoint(initBlock->getTerminator()); if (isByRef[i]) { - // Allocate reduction variable (which is a pointer to the real reduciton - // variable allocated in the inlined region) - llvm::Value *var = builder.CreateAlloca( - moduleTranslation.convertType(reductionDecls[i].getType())); // Store the result of the inlined region to the allocated reduction var // ptr - builder.CreateStore(phis[0], var); + builder.CreateStore(phis[0], byRefVars[i]); - privateReductionVariables.push_back(var); + privateReductionVariables[i] = byRefVars[i]; moduleTranslation.mapValue(reductionArgs[i], phis[0]); reductionVariableMap.try_emplace(opInst.getReductionVars()[i], phis[0]); } else { diff --git a/mlir/test/Target/LLVMIR/openmp-parallel-reduction-cleanup.mlir b/mlir/test/Target/LLVMIR/openmp-parallel-reduction-cleanup.mlir index 84a487c..8afa89f 100644 --- a/mlir/test/Target/LLVMIR/openmp-parallel-reduction-cleanup.mlir +++ b/mlir/test/Target/LLVMIR/openmp-parallel-reduction-cleanup.mlir @@ -55,11 +55,11 @@ // Private reduction variable and its initialization. // CHECK: %tid.addr.local = alloca i32 -// CHECK: %[[MALLOC_I:.+]] = call ptr @malloc(i64 4) // CHECK: %[[PRIV_PTR_I:.+]] = alloca ptr +// CHECK: %[[PRIV_PTR_J:.+]] = alloca ptr +// CHECK: %[[MALLOC_I:.+]] = call ptr @malloc(i64 4) // CHECK: store ptr %[[MALLOC_I]], ptr %[[PRIV_PTR_I]] // CHECK: %[[MALLOC_J:.+]] = call ptr @malloc(i64 4) -// CHECK: %[[PRIV_PTR_J:.+]] = alloca ptr // CHECK: store ptr %[[MALLOC_J]], ptr %[[PRIV_PTR_J]] // Call to the reduction function. diff --git a/mlir/test/Target/LLVMIR/openmp-reduction-init-arg.mlir b/mlir/test/Target/LLVMIR/openmp-reduction-init-arg.mlir index f4b77cb..361905f 100644 --- a/mlir/test/Target/LLVMIR/openmp-reduction-init-arg.mlir +++ b/mlir/test/Target/LLVMIR/openmp-reduction-init-arg.mlir @@ -59,13 +59,15 @@ module { // CHECK: %[[VAL_17:.*]] = load i32, ptr %[[VAL_18:.*]], align 4 // CHECK: store i32 %[[VAL_17]], ptr %[[VAL_16]], align 4 // CHECK: %[[VAL_19:.*]] = load i32, ptr %[[VAL_16]], align 4 -// CHECK: %[[VAL_20:.*]] = load { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, ptr %[[VAL_13]], align 8 // CHECK: %[[VAL_21:.*]] = alloca ptr, align 8 +// CHECK: %[[VAL_23:.*]] = alloca ptr, align 8 +// CHECK: %[[VAL_20:.*]] = load { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, ptr %[[VAL_13]], align 8 +// CHECK: %[[VAL_24:.*]] = alloca [2 x ptr], align 8 +// CHECK: br label %[[INIT_LABEL:.*]] +// CHECK: [[INIT_LABEL]]: // CHECK: store ptr %[[VAL_13]], ptr %[[VAL_21]], align 8 // CHECK: %[[VAL_22:.*]] = load { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, ptr %[[VAL_15]], align 8 -// CHECK: %[[VAL_23:.*]] = alloca ptr, align 8 // CHECK: store ptr %[[VAL_15]], ptr %[[VAL_23]], align 8 -// CHECK: %[[VAL_24:.*]] = alloca [2 x ptr], align 8 // CHECK: br label %[[VAL_25:.*]] // CHECK: omp.par.region: ; preds = %[[VAL_26:.*]] // CHECK: br label %[[VAL_27:.*]] |