diff options
Diffstat (limited to 'flang')
-rw-r--r-- | flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 14 | ||||
-rw-r--r-- | flang/test/Lower/CUDA/cuda-device-proc.cuf | 9 |
2 files changed, 14 insertions, 9 deletions
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 8aed288..4988b6b 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -6508,12 +6508,13 @@ IntrinsicLibrary::genMatchAllSync(mlir::Type resultType,
 }
 
 static mlir::Value genVoteSync(fir::FirOpBuilder &builder, mlir::Location loc,
-                               llvm::StringRef funcName,
+                               llvm::StringRef funcName, mlir::Type resTy,
                                llvm::ArrayRef<mlir::Value> args) {
   mlir::MLIRContext *context = builder.getContext();
   mlir::Type i32Ty = builder.getI32Type();
+  mlir::Type i1Ty = builder.getI1Type();
   mlir::FunctionType ftype =
-      mlir::FunctionType::get(context, {i32Ty, i32Ty}, {i32Ty});
+      mlir::FunctionType::get(context, {i32Ty, i1Ty}, {resTy});
   auto funcOp = builder.createFunction(loc, funcName, ftype);
   llvm::SmallVector<mlir::Value> filteredArgs;
   return builder.create<fir::CallOp>(loc, funcOp, args).getResult(0);
@@ -6523,14 +6524,16 @@ static mlir::Value genVoteSync(fir::FirOpBuilder &builder, mlir::Location loc,
 mlir::Value IntrinsicLibrary::genVoteAllSync(mlir::Type resultType,
                                              llvm::ArrayRef<mlir::Value> args) {
   assert(args.size() == 2);
-  return genVoteSync(builder, loc, "llvm.nvvm.vote.all.sync", args);
+  return genVoteSync(builder, loc, "llvm.nvvm.vote.all.sync",
+                     builder.getI1Type(), args);
 }
 
 // ANY_SYNC
 mlir::Value IntrinsicLibrary::genVoteAnySync(mlir::Type resultType,
                                              llvm::ArrayRef<mlir::Value> args) {
   assert(args.size() == 2);
-  return genVoteSync(builder, loc, "llvm.nvvm.vote.any.sync", args);
+  return genVoteSync(builder, loc, "llvm.nvvm.vote.any.sync",
+                     builder.getI1Type(), args);
 }
 
 // BALLOT_SYNC
@@ -6538,7 +6541,8 @@ mlir::Value
 IntrinsicLibrary::genVoteBallotSync(mlir::Type resultType,
                                     llvm::ArrayRef<mlir::Value> args) {
   assert(args.size() == 2);
-  return genVoteSync(builder, loc, "llvm.nvvm.vote.ballot.sync", args);
+  return genVoteSync(builder, loc, "llvm.nvvm.vote.ballot.sync",
+                     builder.getI32Type(), args);
 }
 
 // MATCH_ANY_SYNC
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index 6a7fee7..a4a4750 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -297,10 +297,11 @@ end
 ! CHECK: fir.call @__ldcv_r8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<?xf64>>) -> ()
 
 attributes(device) subroutine testVote()
-  integer :: a, ipred, mask, v32
-  a = all_sync(mask, v32)
-  a = any_sync(mask, v32)
-  a = ballot_sync(mask, v32)
+  integer :: a, ipred, mask
+  logical(4) :: pred
+  a = all_sync(mask, pred)
+  a = any_sync(mask, pred)
+  a = ballot_sync(mask, pred)
 end subroutine
 
 ! CHECK-LABEL: func.func @_QPtestvote()