aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp')
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp25
1 files changed, 24 insertions, 1 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
index 7505507..ebd957c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
@@ -188,8 +188,31 @@ class SPIRVLegalizePointerCast : public FunctionPass {
FixedVectorType *SrcType = cast<FixedVectorType>(Src->getType());
FixedVectorType *DstType =
cast<FixedVectorType>(GR->findDeducedElementType(Dst));
- assert(DstType->getNumElements() >= SrcType->getNumElements());
+ auto dstNumElements = DstType->getNumElements();
+ auto srcNumElements = SrcType->getNumElements();
+
+ // if the element type differs, it is a bitcast.
+ if (DstType->getElementType() != SrcType->getElementType()) {
+ // Support bitcast between vectors of different sizes only if
+ // the total bitwidth is the same.
+ auto dstBitWidth =
+ DstType->getElementType()->getScalarSizeInBits() * dstNumElements;
+ auto srcBitWidth =
+ SrcType->getElementType()->getScalarSizeInBits() * srcNumElements;
+ assert(dstBitWidth == srcBitWidth &&
+ "Unsupported bitcast between vectors of different sizes.");
+
+ Src =
+ B.CreateIntrinsic(Intrinsic::spv_bitcast, {DstType, SrcType}, {Src});
+ buildAssignType(B, DstType, Src);
+ SrcType = DstType;
+
+ StoreInst *SI = B.CreateStore(Src, Dst);
+ SI->setAlignment(Alignment);
+ return SI;
+ }
+ assert(DstType->getNumElements() >= SrcType->getNumElements());
LoadInst *LI = B.CreateLoad(DstType, Dst);
LI->setAlignment(Alignment);
Value *OldValues = LI;