aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/CodeGen/AtomicExpandPass.cpp
diff options
context:
space:
mode:
authorMatt Arsenault <Matthew.Arsenault@amd.com>2024-04-06 15:27:45 -0400
committerGitHub <noreply@github.com>2024-04-06 15:27:45 -0400
commit4cb110a84f587d3c65b85d79ab6fc8aa5489fb86 (patch)
treedf02e1df31b98f5b2a241401d465c5ecae304a73 /llvm/lib/CodeGen/AtomicExpandPass.cpp
parentbd589f5c7a079d8829fcf994b746634eaaea24ff (diff)
downloadllvm-4cb110a84f587d3c65b85d79ab6fc8aa5489fb86.zip
llvm-4cb110a84f587d3c65b85d79ab6fc8aa5489fb86.tar.gz
llvm-4cb110a84f587d3c65b85d79ab6fc8aa5489fb86.tar.bz2
[RFC] IR: Support atomicrmw FP ops with vector types (#86796)
Allow using atomicrmw fadd, fsub, fmin, and fmax with vectors of floating-point type. AMDGPU supports atomic fadd for <2 x half> and <2 x bfloat> on some targets and address spaces. Note this only supports the proper floating-point operations; float vector typed xchg is still not supported. cmpxchg still only supports integers, so this inserts bitcasts for the loop expansion. I have support for fp vector typed xchg, and vector of int/ptr separately implemented but I don't have an immediate need for those beyond feature consistency.
Diffstat (limited to 'llvm/lib/CodeGen/AtomicExpandPass.cpp')
-rw-r--r--llvm/lib/CodeGen/AtomicExpandPass.cpp6
1 files changed, 3 insertions, 3 deletions
diff --git a/llvm/lib/CodeGen/AtomicExpandPass.cpp b/llvm/lib/CodeGen/AtomicExpandPass.cpp
index d5db79d..0aa89ea 100644
--- a/llvm/lib/CodeGen/AtomicExpandPass.cpp
+++ b/llvm/lib/CodeGen/AtomicExpandPass.cpp
@@ -562,9 +562,9 @@ static void createCmpXchgInstFun(IRBuilderBase &Builder, Value *Addr,
Value *&Success, Value *&NewLoaded) {
Type *OrigTy = NewVal->getType();
- // This code can go away when cmpxchg supports FP types.
+ // This code can go away when cmpxchg supports FP and vector types.
assert(!OrigTy->isPointerTy());
- bool NeedBitcast = OrigTy->isFloatingPointTy();
+ bool NeedBitcast = OrigTy->isFloatingPointTy() || OrigTy->isVectorTy();
if (NeedBitcast) {
IntegerType *IntTy = Builder.getIntNTy(OrigTy->getPrimitiveSizeInBits());
NewVal = Builder.CreateBitCast(NewVal, IntTy);
@@ -731,7 +731,7 @@ static PartwordMaskValues createMaskInstrs(IRBuilderBase &Builder,
unsigned ValueSize = DL.getTypeStoreSize(ValueType);
PMV.ValueType = PMV.IntValueType = ValueType;
- if (PMV.ValueType->isFloatingPointTy())
+ if (PMV.ValueType->isFloatingPointTy() || PMV.ValueType->isVectorTy())
PMV.IntValueType =
Type::getIntNTy(Ctx, ValueType->getPrimitiveSizeInBits());