aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaul C Fuqua <paul.fuqua@amd.com>2023-12-20 09:35:42 -0600
committerGitHub <noreply@github.com>2023-12-20 09:35:42 -0600
commit11141bc68adc311afad1ff130e4fbbd1e3062e05 (patch)
tree29a4dbca5246aad8565b3b7edf6ce92987bb1134
parent300adbee88c53caef833cc240195b722cf76961d (diff)
downloadllvm-11141bc68adc311afad1ff130e4fbbd1e3062e05.zip
llvm-11141bc68adc311afad1ff130e4fbbd1e3062e05.tar.gz
llvm-11141bc68adc311afad1ff130e4fbbd1e3062e05.tar.bz2
Fix what seems to be a silly bug in gpu.set_default_device rewriting. Smoke test included. (#75756)
-rw-r--r--mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp5
-rw-r--r--mlir/test/Conversion/GPUCommon/set-default-device.mlir10
2 files changed, 13 insertions, 2 deletions
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index b68baff..94df376 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -1334,8 +1334,9 @@ LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
- setDefaultDeviceCallBuilder.create(loc, rewriter, {adaptor.getDevIndex()});
- rewriter.replaceOp(op, {});
+ auto call = setDefaultDeviceCallBuilder.create(loc, rewriter,
+ {adaptor.getDevIndex()});
+ rewriter.replaceOp(op, call);
return success();
}
diff --git a/mlir/test/Conversion/GPUCommon/set-default-device.mlir b/mlir/test/Conversion/GPUCommon/set-default-device.mlir
new file mode 100644
index 0000000..c23d8a3
--- /dev/null
+++ b/mlir/test/Conversion/GPUCommon/set-default-device.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-opt %s --gpu-to-llvm | FileCheck %s
+
+module attributes {gpu.container_module} {
+ // CHECK-LABEL: func @set_default_device
+ func.func @set_default_device(%arg0: i32) {
+ // CHECK: mgpuSetDefaultDevice
+ gpu.set_default_device %arg0
+ return
+ }
+}