diff options
Diffstat (limited to 'llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp')
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp | 16 |
1 files changed, 10 insertions, 6 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp index 5cda6a0..7505507 100644 --- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp @@ -74,17 +74,20 @@ class SPIRVLegalizePointerCast : public FunctionPass { // Returns the loaded value. Value *loadVectorFromVector(IRBuilder<> &B, FixedVectorType *SourceType, FixedVectorType *TargetType, Value *Source) { - // We expect the codegen to avoid doing implicit bitcast from a load. - assert(TargetType->getElementType() == SourceType->getElementType()); - assert(TargetType->getNumElements() < SourceType->getNumElements()); - + assert(TargetType->getNumElements() <= SourceType->getNumElements()); LoadInst *NewLoad = B.CreateLoad(SourceType, Source); buildAssignType(B, SourceType, NewLoad); + Value *AssignValue = NewLoad; + if (TargetType->getElementType() != SourceType->getElementType()) { + AssignValue = B.CreateIntrinsic(Intrinsic::spv_bitcast, + {TargetType, SourceType}, {NewLoad}); + buildAssignType(B, TargetType, AssignValue); + } SmallVector<int> Mask(/* Size= */ TargetType->getNumElements()); for (unsigned I = 0; I < TargetType->getNumElements(); ++I) Mask[I] = I; - Value *Output = B.CreateShuffleVector(NewLoad, NewLoad, Mask); + Value *Output = B.CreateShuffleVector(AssignValue, AssignValue, Mask); buildAssignType(B, TargetType, Output); return Output; } @@ -135,8 +138,9 @@ class SPIRVLegalizePointerCast : public FunctionPass { Output = loadFirstValueFromAggregate(B, SVT->getElementType(), OriginalOperand, LI); } - // Destination is a smaller vector than source. + // Destination is a smaller vector than source or different vector type. // - float3 v3 = vector4; + // - float4 v2 = int4; else if (SVT && DVT) Output = loadVectorFromVector(B, SVT, DVT, OriginalOperand); // Destination is the scalar type stored at the start of an aggregate. |