diff options
author | Christopher Bate <cbate@nvidia.com> | 2023-06-30 16:04:08 -0600 |
---|---|---|
committer | Christopher Bate <cbate@nvidia.com> | 2023-07-03 13:26:51 -0600 |
commit | 14858cf05dc7cbc0f34629d693b0039c3d15c34f (patch) | |
tree | 3bbac9292454b7a4a842f45267b1f4a38ccfff3b /mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp | |
parent | 6a66673765b2bf45f412ab4261a72704805dd526 (diff) | |
download | llvm-14858cf05dc7cbc0f34629d693b0039c3d15c34f.zip llvm-14858cf05dc7cbc0f34629d693b0039c3d15c34f.tar.gz llvm-14858cf05dc7cbc0f34629d693b0039c3d15c34f.tar.bz2 |
[mlir][Conversion/GPUCommon] Fix bug in conversion of `math` ops
The common GPU operation transformation that lowers `math` operations
to function calls in the `gpu-to-nvvm` and `gpu-to-rocdl` passes handles
`vector` types by applying the function to each scalar and returning a
new vector. However, there was a typo that results in incorrectly
accumulating the result vector, and the rewrite returns an `llvm.mlir.undef`
result instead of the correct vector. A patch is added and tests are
strengthened.
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D154269
Diffstat (limited to 'mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp')
-rw-r--r-- | mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp index 38b7248..2fe1c7c 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -485,8 +485,8 @@ LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands, auto scalarOperands = llvm::map_to_vector(operands, extractElement); Operation *scalarOp = rewriter.create(loc, name, scalarOperands, elementType, op->getAttrs()); - rewriter.create<LLVM::InsertElementOp>(loc, result, scalarOp->getResult(0), - index); + result = rewriter.create<LLVM::InsertElementOp>( + loc, result, scalarOp->getResult(0), index); } rewriter.replaceOp(op, result); |