aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChristian Sigg <csigg@google.com>2021-02-02 10:43:32 +0100
committerChristian Sigg <csigg@google.com>2021-02-03 20:00:36 +0100
commit8a43ec7faa274257325823f312ca2e7657c79785 (patch)
tree4be98da868cd28b365e54eeb3cb311fc40972e84
parent39fbb5c3e307ac06c7ca83aca8e3c76ed99b25f3 (diff)
downloadllvm-8a43ec7faa274257325823f312ca2e7657c79785.zip
llvm-8a43ec7faa274257325823f312ca2e7657c79785.tar.gz
llvm-8a43ec7faa274257325823f312ca2e7657c79785.tar.bz2
Set GPU context before {cu,hip}MemHostRegister.
Differential Revision: https://reviews.llvm.org/D95856
-rw-r--r--mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp1
-rw-r--r--mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp1
2 files changed, 2 insertions, 0 deletions
diff --git a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
index d4360de..b8554bb 100644
--- a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
@@ -148,6 +148,7 @@ extern "C" void mgpuMemcpy(void *dst, void *src, uint64_t sizeBytes,
// Allows to register byte array with the CUDA runtime. Helpful until we have
// transfer functions implemented.
extern "C" void mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) {
+ ScopedContext scopedContext;
CUDA_REPORT_IF_ERROR(cuMemHostRegister(ptr, sizeBytes, /*flags=*/0));
}
diff --git a/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp b/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
index cf3c757..361ba8f 100644
--- a/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
@@ -148,6 +148,7 @@ extern "C" void mgpuMemcpy(void *dst, void *src, uint64_t sizeBytes,
// Allows to register byte array with the ROCM runtime. Helpful until we have
// transfer functions implemented.
extern "C" void mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) {
+ ScopedContext scopedContext;
HIP_REPORT_IF_ERROR(hipHostRegister(ptr, sizeBytes, /*flags=*/0));
}