aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
diff options
context:
space:
mode:
authorChristopher Bate <cbate@nvidia.com>2023-06-30 16:04:08 -0600
committerChristopher Bate <cbate@nvidia.com>2023-07-03 13:26:51 -0600
commit14858cf05dc7cbc0f34629d693b0039c3d15c34f (patch)
tree3bbac9292454b7a4a842f45267b1f4a38ccfff3b /mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
parent6a66673765b2bf45f412ab4261a72704805dd526 (diff)
downloadllvm-14858cf05dc7cbc0f34629d693b0039c3d15c34f.zip
llvm-14858cf05dc7cbc0f34629d693b0039c3d15c34f.tar.gz
llvm-14858cf05dc7cbc0f34629d693b0039c3d15c34f.tar.bz2
[mlir][Conversion/GPUCommon] Fix bug in conversion of `math` ops
The common GPU operation transformation that lowers `math` operations to function calls in the `gpu-to-nvvm` and `gpu-to-rocdl` passes handles `vector` types by applying the function to each scalar and returning a new vector. However, there was a typo that results in incorrectly accumulating the result vector, and the rewrite returns an `llvm.mlir.undef` result instead of the correct vector. A patch is added and tests are strengthened. Reviewed By: ThomasRaoux Differential Revision: https://reviews.llvm.org/D154269
Diffstat (limited to 'mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp')
-rw-r--r--mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp4
1 files changed, 2 insertions, 2 deletions
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 38b7248..2fe1c7c 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -485,8 +485,8 @@ LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
auto scalarOperands = llvm::map_to_vector(operands, extractElement);
Operation *scalarOp =
rewriter.create(loc, name, scalarOperands, elementType, op->getAttrs());
- rewriter.create<LLVM::InsertElementOp>(loc, result, scalarOp->getResult(0),
- index);
+ result = rewriter.create<LLVM::InsertElementOp>(
+ loc, result, scalarOp->getResult(0), index);
}
rewriter.replaceOp(op, result);