diff options
| author | Shilei Tian <i@tianshilei.me> | 2024-06-03 11:17:36 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-06-03 11:17:36 -0400 |
| commit | b448efb8eafef7df2c8d467bbb9cd0fc1e2ea7d5 (patch) | |
| tree | 7f5cefe0604ef7bab28f6e5f46e808340975fdb2 /openmp/runtime/src/include | |
| parent | 539dbfcfcf5705cf100999ad2483318192418e21 (diff) | |
| download | llvm-b448efb8eafef7df2c8d467bbb9cd0fc1e2ea7d5.zip llvm-b448efb8eafef7df2c8d467bbb9cd0fc1e2ea7d5.tar.gz llvm-b448efb8eafef7df2c8d467bbb9cd0fc1e2ea7d5.tar.bz2 | |
Reapply "[OpenMP][OMPX] Add shfl_down_sync (#93311)" (#94139)
Diffstat (limited to 'openmp/runtime/src/include')
| -rw-r--r-- | openmp/runtime/src/include/ompx.h.var | 68 |
1 files changed, 60 insertions, 8 deletions
diff --git a/openmp/runtime/src/include/ompx.h.var b/openmp/runtime/src/include/ompx.h.var index 1985188..623f0b9 100644 --- a/openmp/runtime/src/include/ompx.h.var +++ b/openmp/runtime/src/include/ompx.h.var @@ -9,6 +9,12 @@ #ifndef __OMPX_H #define __OMPX_H +#ifdef __AMDGCN_WAVEFRONT_SIZE +#define __WARP_SIZE __AMDGCN_WAVEFRONT_SIZE +#else +#define __WARP_SIZE 32 +#endif + typedef unsigned long uint64_t; #ifdef __cplusplus @@ -75,11 +81,11 @@ _TGT_KERNEL_LANGUAGE_HOST_IMPL_GRID_C(grid_dim, 1) static inline RETTY ompx_##NAME(ARGS) { BODY; } _TGT_KERNEL_LANGUAGE_HOST_IMPL_SYNC_C(void, sync_block, int Ordering, - _Pragma("omp barrier")); + _Pragma("omp barrier")) _TGT_KERNEL_LANGUAGE_HOST_IMPL_SYNC_C(void, sync_block_acq_rel, void, - ompx_sync_block(ompx_acq_rel)); + ompx_sync_block(ompx_acq_rel)) _TGT_KERNEL_LANGUAGE_HOST_IMPL_SYNC_C(void, sync_block_divergent, int Ordering, - ompx_sync_block(Ordering)); + ompx_sync_block(Ordering)) #undef _TGT_KERNEL_LANGUAGE_HOST_IMPL_SYNC_C ///} @@ -87,6 +93,22 @@ static inline uint64_t ompx_ballot_sync(uint64_t mask, int pred) { __builtin_trap(); } +/// ompx_shfl_down_sync_{i,f,l,d} +///{ +#define _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_HOST_IMPL(TYPE, TY) \ + static inline TYPE ompx_shfl_down_sync_##TY(uint64_t mask, TYPE var, \ + unsigned delta, int width) { \ + __builtin_trap(); \ + } + +_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_HOST_IMPL(int, i) +_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_HOST_IMPL(float, f) +_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_HOST_IMPL(long, l) +_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_HOST_IMPL(double, d) + +#undef _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC_HOST_IMPL +///} + #pragma omp end declare variant /// ompx_{sync_block}_{,divergent} @@ -94,9 +116,9 @@ static inline uint64_t ompx_ballot_sync(uint64_t mask, int pred) { #define _TGT_KERNEL_LANGUAGE_DECL_SYNC_C(RETTY, NAME, ARGS) \ RETTY ompx_##NAME(ARGS); -_TGT_KERNEL_LANGUAGE_DECL_SYNC_C(void, sync_block, int Ordering); -_TGT_KERNEL_LANGUAGE_DECL_SYNC_C(void, sync_block_acq_rel, void); -_TGT_KERNEL_LANGUAGE_DECL_SYNC_C(void, sync_block_divergent, int Ordering); +_TGT_KERNEL_LANGUAGE_DECL_SYNC_C(void, sync_block, int Ordering) +_TGT_KERNEL_LANGUAGE_DECL_SYNC_C(void, sync_block_acq_rel, void) +_TGT_KERNEL_LANGUAGE_DECL_SYNC_C(void, sync_block_divergent, int Ordering) #undef _TGT_KERNEL_LANGUAGE_DECL_SYNC_C ///} @@ -117,6 +139,20 @@ _TGT_KERNEL_LANGUAGE_DECL_GRID_C(grid_dim) uint64_t ompx_ballot_sync(uint64_t mask, int pred); +/// ompx_shfl_down_sync_{i,f,l,d} +///{ +#define _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(TYPE, TY) \ + TYPE ompx_shfl_down_sync_##TY(uint64_t mask, TYPE var, unsigned delta, \ + int width); + +_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(int, i) +_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(float, f) +_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(long, l) +_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(double, d) + +#undef _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC +///} + #ifdef __cplusplus } #endif @@ -162,9 +198,9 @@ _TGT_KERNEL_LANGUAGE_HOST_IMPL_GRID_CXX(grid_dim) } _TGT_KERNEL_LANGUAGE_HOST_IMPL_SYNC_CXX(void, sync_block, int Ordering = acc_rel, - Ordering); + Ordering) _TGT_KERNEL_LANGUAGE_HOST_IMPL_SYNC_CXX(void, sync_block_divergent, - int Ordering = acc_rel, Ordering); + int Ordering = acc_rel, Ordering) #undef _TGT_KERNEL_LANGUAGE_HOST_IMPL_SYNC_CXX ///} @@ -172,6 +208,22 @@ static inline uint64_t ballot_sync(uint64_t mask, int pred) { return ompx_ballot_sync(mask, pred); } +/// shfl_down_sync +///{ +#define _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(TYPE, TY) \ + static inline TYPE shfl_down_sync(uint64_t mask, TYPE var, unsigned delta, \ + int width = __WARP_SIZE) { \ + return ompx_shfl_down_sync_##TY(mask, var, delta, width); \ + } + +_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(int, i) +_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(float, f) +_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(long, l) +_TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC(double, d) + +#undef _TGT_KERNEL_LANGUAGE_SHFL_DOWN_SYNC +///} + } // namespace ompx #endif |
