diff options
author | Vyacheslav Levytskyy <vyacheslav.levytskyy@intel.com> | 2024-03-13 08:32:01 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-13 08:32:01 +0100 |
commit | 0a443f13b49b3f392461a0bb60b0146cfc4607c7 (patch) | |
tree | 960853562f89e2452d82cd9879e03b153def0edc /llvm/lib/Target/SPIRV | |
parent | cd2f6163137dce45d909aa445cfd57b7188f8ed1 (diff) | |
download | llvm-0a443f13b49b3f392461a0bb60b0146cfc4607c7.zip llvm-0a443f13b49b3f392461a0bb60b0146cfc4607c7.tar.gz llvm-0a443f13b49b3f392461a0bb60b0146cfc4607c7.tar.bz2 |
[SPIR-V] Add implementation of G_SPLAT_VECTOR opcode and fix invalid types processing (#84766)
This PR:
* adds support for G_SPLAT_VECTOR generic opcode that may be legally
generated instead of G_BUILD_VECTOR by previous passes of the translator
(see https://github.com/llvm/llvm-project/pull/80378 for the source of
breaking changes);
* improves deduction of types for opaque pointers.
This PR also fixes the following issues:
* if a function has ptr argument(s), two functions that have different
SPIR-V type definitions may get identical LLVM function types and break
agreements of global register and duplicate checker;
* checks for pointer types do not account for TypedPointerType.
Update of tests:
* A test case is added to cover the issue with function ptr parameters.
* The first case, that is support for G_SPLAT_VECTOR generic opcode, is
covered by existing test cases.
* Multiple additional checks by `spirv-val` is added to cover more
possibilities of generation of invalid code.
Diffstat (limited to 'llvm/lib/Target/SPIRV')
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp | 49 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 136 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 32 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 3 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp | 41 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp | 4 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVUtils.h | 26 |
7 files changed, 236 insertions, 55 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp index 2d7a00b..f1fbe2b 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp @@ -85,6 +85,42 @@ static ConstantInt *getConstInt(MDNode *MD, unsigned NumOp) { return nullptr; } +// If the function has pointer arguments, we are forced to re-create this +// function type from the very beginning, changing PointerType by +// TypedPointerType for each pointer argument. Otherwise, the same `Type*` +// potentially corresponds to different SPIR-V function type, effectively +// invalidating logic behind global registry and duplicates tracker. +static FunctionType * +fixFunctionTypeIfPtrArgs(SPIRVGlobalRegistry *GR, const Function &F, + FunctionType *FTy, const SPIRVType *SRetTy, + const SmallVector<SPIRVType *, 4> &SArgTys) { + if (F.getParent()->getNamedMetadata("spv.cloned_funcs")) + return FTy; + + bool hasArgPtrs = false; + for (auto &Arg : F.args()) { + // check if it's an instance of a non-typed PointerType + if (Arg.getType()->isPointerTy()) { + hasArgPtrs = true; + break; + } + } + if (!hasArgPtrs) { + Type *RetTy = FTy->getReturnType(); + // check if it's an instance of a non-typed PointerType + if (!RetTy->isPointerTy()) + return FTy; + } + + // re-create function type, using TypedPointerType instead of PointerType to + // properly trace argument types + const Type *RetTy = GR->getTypeForSPIRVType(SRetTy); + SmallVector<Type *, 4> ArgTys; + for (auto SArgTy : SArgTys) + ArgTys.push_back(const_cast<Type *>(GR->getTypeForSPIRVType(SArgTy))); + return FunctionType::get(const_cast<Type *>(RetTy), ArgTys, false); +} + // This code restores function args/retvalue types for composite cases // because the final types should still be aggregate whereas they're i32 // during the translation to cope with aggregate flattening etc. @@ -162,7 +198,7 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx, // If OriginalArgType is non-pointer, use the OriginalArgType (the type cannot // be legally reassigned later). - if (!OriginalArgType->isPointerTy()) + if (!isPointerTy(OriginalArgType)) return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual); // In case OriginalArgType is of pointer type, there are three possibilities: @@ -179,8 +215,7 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx, SPIRVType *ElementType = GR->getOrCreateSPIRVType(ByValRefType, MIRBuilder); return GR->getOrCreateSPIRVPointerType( ElementType, MIRBuilder, - addressSpaceToStorageClass(Arg->getType()->getPointerAddressSpace(), - ST)); + addressSpaceToStorageClass(getPointerAddressSpace(Arg->getType()), ST)); } for (auto User : Arg->users()) { @@ -240,7 +275,6 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget()); // Assign types and names to all args, and store their types for later. - FunctionType *FTy = getOriginalFunctionType(F); SmallVector<SPIRVType *, 4> ArgTypeVRegs; if (VRegs.size() > 0) { unsigned i = 0; @@ -255,7 +289,7 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, if (Arg.hasName()) buildOpName(VRegs[i][0], Arg.getName(), MIRBuilder); - if (Arg.getType()->isPointerTy()) { + if (isPointerTy(Arg.getType())) { auto DerefBytes = static_cast<unsigned>(Arg.getDereferenceableBytes()); if (DerefBytes != 0) buildOpDecorate(VRegs[i][0], MIRBuilder, @@ -322,7 +356,9 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass); if (F.isDeclaration()) GR->add(&F, &MIRBuilder.getMF(), FuncVReg); + FunctionType *FTy = getOriginalFunctionType(F); SPIRVType *RetTy = GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder); + FTy = fixFunctionTypeIfPtrArgs(GR, F, FTy, RetTy, ArgTypeVRegs); SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs( FTy, RetTy, ArgTypeVRegs, MIRBuilder); uint32_t FuncControl = getFunctionControl(F); @@ -429,7 +465,6 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, return false; MachineFunction &MF = MIRBuilder.getMF(); GR->setCurrentFunc(MF); - FunctionType *FTy = nullptr; const Function *CF = nullptr; std::string DemangledName; const Type *OrigRetTy = Info.OrigRet.Ty; @@ -444,7 +479,7 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, // TODO: support constexpr casts and indirect calls. if (CF == nullptr) return false; - if ((FTy = getOriginalFunctionType(*CF)) != nullptr) + if (FunctionType *FTy = getOriginalFunctionType(*CF)) OrigRetTy = FTy->getReturnType(); } diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp index 575e903..c5b9012 100644 --- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp @@ -57,8 +57,14 @@ class SPIRVEmitIntrinsics bool TrackConstants = true; DenseMap<Instruction *, Constant *> AggrConsts; DenseSet<Instruction *> AggrStores; + + // deduce values type + DenseMap<Value *, Type *> DeducedElTys; + Type *deduceElementType(Value *I); + void preprocessCompositeConstants(IRBuilder<> &B); void preprocessUndefs(IRBuilder<> &B); + CallInst *buildIntrWithMD(Intrinsic::ID IntrID, ArrayRef<Type *> Types, Value *Arg, Value *Arg2, ArrayRef<Constant *> Imms, IRBuilder<> &B) { @@ -72,6 +78,7 @@ class SPIRVEmitIntrinsics Args.push_back(Imm); return B.CreateIntrinsic(IntrID, {Types}, Args); } + void replaceMemInstrUses(Instruction *Old, Instruction *New, IRBuilder<> &B); void processInstrAfterVisit(Instruction *I, IRBuilder<> &B); void insertAssignPtrTypeIntrs(Instruction *I, IRBuilder<> &B); @@ -156,6 +163,48 @@ static inline void reportFatalOnTokenType(const Instruction *I) { false); } +// Deduce and return a successfully deduced Type of the Instruction, +// or nullptr otherwise. +static Type *deduceElementTypeHelper(Value *I, + std::unordered_set<Value *> &Visited, + DenseMap<Value *, Type *> &DeducedElTys) { + // maybe already known + auto It = DeducedElTys.find(I); + if (It != DeducedElTys.end()) + return It->second; + + // maybe a cycle + if (Visited.find(I) != Visited.end()) + return nullptr; + Visited.insert(I); + + // fallback value in case when we fail to deduce a type + Type *Ty = nullptr; + // look for known basic patterns of type inference + if (auto *Ref = dyn_cast<AllocaInst>(I)) + Ty = Ref->getAllocatedType(); + else if (auto *Ref = dyn_cast<GetElementPtrInst>(I)) + Ty = Ref->getResultElementType(); + else if (auto *Ref = dyn_cast<GlobalValue>(I)) + Ty = Ref->getValueType(); + else if (auto *Ref = dyn_cast<AddrSpaceCastInst>(I)) + Ty = deduceElementTypeHelper(Ref->getPointerOperand(), Visited, + DeducedElTys); + + // remember the found relationship + if (Ty) + DeducedElTys[I] = Ty; + + return Ty; +} + +Type *SPIRVEmitIntrinsics::deduceElementType(Value *I) { + std::unordered_set<Value *> Visited; + if (Type *Ty = deduceElementTypeHelper(I, Visited, DeducedElTys)) + return Ty; + return IntegerType::getInt8Ty(I->getContext()); +} + void SPIRVEmitIntrinsics::replaceMemInstrUses(Instruction *Old, Instruction *New, IRBuilder<> &B) { @@ -280,7 +329,7 @@ Instruction *SPIRVEmitIntrinsics::visitBitCastInst(BitCastInst &I) { // varying element types. In case of IR coming from older versions of LLVM // such bitcasts do not provide sufficient information, should be just skipped // here, and handled in insertPtrCastOrAssignTypeInstr. - if (I.getType()->isPointerTy()) { + if (isPointerTy(I.getType())) { I.replaceAllUsesWith(Source); I.eraseFromParent(); return nullptr; @@ -333,20 +382,10 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast( while (BitCastInst *BC = dyn_cast<BitCastInst>(Pointer)) Pointer = BC->getOperand(0); - // Do not emit spv_ptrcast if Pointer is a GlobalValue of expected type. - GlobalValue *GV = dyn_cast<GlobalValue>(Pointer); - if (GV && GV->getValueType() == ExpectedElementType) - return; - - // Do not emit spv_ptrcast if Pointer is a result of alloca with expected - // type. - AllocaInst *A = dyn_cast<AllocaInst>(Pointer); - if (A && A->getAllocatedType() == ExpectedElementType) - return; - - // Do not emit spv_ptrcast if Pointer is a result of GEP of expected type. - GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(Pointer); - if (GEPI && GEPI->getResultElementType() == ExpectedElementType) + // Do not emit spv_ptrcast if Pointer's element type is ExpectedElementType + std::unordered_set<Value *> Visited; + Type *PointerElemTy = deduceElementTypeHelper(Pointer, Visited, DeducedElTys); + if (PointerElemTy == ExpectedElementType) return; setInsertPointSkippingPhis(B, I); @@ -356,7 +395,7 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast( ValueAsMetadata::getConstant(ExpectedElementTypeConst); MDTuple *TyMD = MDNode::get(F->getContext(), CM); MetadataAsValue *VMD = MetadataAsValue::get(F->getContext(), TyMD); - unsigned AddressSpace = Pointer->getType()->getPointerAddressSpace(); + unsigned AddressSpace = getPointerAddressSpace(Pointer->getType()); bool FirstPtrCastOrAssignPtrType = true; // Do not emit new spv_ptrcast if equivalent one already exists or when @@ -401,9 +440,11 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast( // spv_assign_ptr_type instead. if (FirstPtrCastOrAssignPtrType && (isa<Instruction>(Pointer) || isa<Argument>(Pointer))) { - buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {Pointer->getType()}, - ExpectedElementTypeConst, Pointer, - {B.getInt32(AddressSpace)}, B); + CallInst *CI = buildIntrWithMD( + Intrinsic::spv_assign_ptr_type, {Pointer->getType()}, + ExpectedElementTypeConst, Pointer, {B.getInt32(AddressSpace)}, B); + DeducedElTys[CI] = ExpectedElementType; + DeducedElTys[Pointer] = ExpectedElementType; return; } @@ -419,7 +460,7 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I, // Handle basic instructions: StoreInst *SI = dyn_cast<StoreInst>(I); if (SI && F->getCallingConv() == CallingConv::SPIR_KERNEL && - SI->getValueOperand()->getType()->isPointerTy() && + isPointerTy(SI->getValueOperand()->getType()) && isa<Argument>(SI->getValueOperand())) { return replacePointerOperandWithPtrCast( I, SI->getValueOperand(), IntegerType::getInt8Ty(F->getContext()), 0, @@ -440,9 +481,34 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I, if (!CI || CI->isIndirectCall() || CI->getCalledFunction()->isIntrinsic()) return; + // collect information about formal parameter types + Function *CalledF = CI->getCalledFunction(); + SmallVector<Type *, 4> CalledArgTys; + bool HaveTypes = false; + for (auto &CalledArg : CalledF->args()) { + if (!isPointerTy(CalledArg.getType())) { + CalledArgTys.push_back(nullptr); + continue; + } + auto It = DeducedElTys.find(&CalledArg); + Type *ParamTy = It != DeducedElTys.end() ? It->second : nullptr; + if (!ParamTy) { + for (User *U : CalledArg.users()) { + if (Instruction *Inst = dyn_cast<Instruction>(U)) { + std::unordered_set<Value *> Visited; + ParamTy = deduceElementTypeHelper(Inst, Visited, DeducedElTys); + if (ParamTy) + break; + } + } + } + HaveTypes |= ParamTy != nullptr; + CalledArgTys.push_back(ParamTy); + } + std::string DemangledName = getOclOrSpirvBuiltinDemangledName(CI->getCalledFunction()->getName()); - if (DemangledName.empty()) + if (DemangledName.empty() && !HaveTypes) return; for (unsigned OpIdx = 0; OpIdx < CI->arg_size(); OpIdx++) { @@ -455,8 +521,11 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I, if (!isa<Instruction>(ArgOperand) && !isa<Argument>(ArgOperand)) continue; - Type *ExpectedType = SPIRV::parseBuiltinCallArgumentBaseType( - DemangledName, OpIdx, I->getContext()); + Type *ExpectedType = + OpIdx < CalledArgTys.size() ? CalledArgTys[OpIdx] : nullptr; + if (!ExpectedType && !DemangledName.empty()) + ExpectedType = SPIRV::parseBuiltinCallArgumentBaseType( + DemangledName, OpIdx, I->getContext()); if (!ExpectedType) continue; @@ -639,30 +708,25 @@ void SPIRVEmitIntrinsics::processGlobalValue(GlobalVariable &GV, void SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I, IRBuilder<> &B) { reportFatalOnTokenType(I); - if (!I->getType()->isPointerTy() || !requireAssignType(I) || + if (!isPointerTy(I->getType()) || !requireAssignType(I) || isa<BitCastInst>(I)) return; setInsertPointSkippingPhis(B, I->getNextNode()); - Constant *EltTyConst; - unsigned AddressSpace = I->getType()->getPointerAddressSpace(); - if (auto *AI = dyn_cast<AllocaInst>(I)) - EltTyConst = UndefValue::get(AI->getAllocatedType()); - else if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) - EltTyConst = UndefValue::get(GEP->getResultElementType()); - else - EltTyConst = UndefValue::get(IntegerType::getInt8Ty(I->getContext())); - - buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {I->getType()}, EltTyConst, I, - {B.getInt32(AddressSpace)}, B); + Type *ElemTy = deduceElementType(I); + Constant *EltTyConst = UndefValue::get(ElemTy); + unsigned AddressSpace = getPointerAddressSpace(I->getType()); + CallInst *CI = buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {I->getType()}, + EltTyConst, I, {B.getInt32(AddressSpace)}, B); + DeducedElTys[CI] = ElemTy; } void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I, IRBuilder<> &B) { reportFatalOnTokenType(I); Type *Ty = I->getType(); - if (!Ty->isVoidTy() && !Ty->isPointerTy() && requireAssignType(I)) { + if (!Ty->isVoidTy() && !isPointerTy(Ty) && requireAssignType(I)) { setInsertPointSkippingPhis(B, I->getNextNode()); Type *TypeToAssign = Ty; if (auto *II = dyn_cast<IntrinsicInst>(I)) { diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp index 8556581..bda9c57 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -750,7 +750,7 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType( SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType( const Type *Ty, MachineIRBuilder &MIRBuilder, SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) { - if (TypesInProcessing.count(Ty) && !Ty->isPointerTy()) + if (TypesInProcessing.count(Ty) && !isPointerTy(Ty)) return nullptr; TypesInProcessing.insert(Ty); SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR); @@ -762,11 +762,15 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType( // will be added later. For special types it is already added to DT. if (SpirvType->getOpcode() != SPIRV::OpTypeForwardPointer && !Reg.isValid() && !isSpecialOpaqueType(Ty)) { - if (!Ty->isPointerTy()) + if (!isPointerTy(Ty)) DT.add(Ty, &MIRBuilder.getMF(), getSPIRVTypeID(SpirvType)); + else if (isTypedPointerTy(Ty)) + DT.add(cast<TypedPointerType>(Ty)->getElementType(), + getPointerAddressSpace(Ty), &MIRBuilder.getMF(), + getSPIRVTypeID(SpirvType)); else DT.add(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()), - Ty->getPointerAddressSpace(), &MIRBuilder.getMF(), + getPointerAddressSpace(Ty), &MIRBuilder.getMF(), getSPIRVTypeID(SpirvType)); } @@ -787,12 +791,15 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType( const Type *Ty, MachineIRBuilder &MIRBuilder, SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) { Register Reg; - if (!Ty->isPointerTy()) + if (!isPointerTy(Ty)) Reg = DT.find(Ty, &MIRBuilder.getMF()); + else if (isTypedPointerTy(Ty)) + Reg = DT.find(cast<TypedPointerType>(Ty)->getElementType(), + getPointerAddressSpace(Ty), &MIRBuilder.getMF()); else Reg = DT.find(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()), - Ty->getPointerAddressSpace(), &MIRBuilder.getMF()); + getPointerAddressSpace(Ty), &MIRBuilder.getMF()); if (Reg.isValid() && !isSpecialOpaqueType(Ty)) return getSPIRVTypeForVReg(Reg); @@ -836,11 +843,16 @@ bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg, unsigned SPIRVGlobalRegistry::getScalarOrVectorComponentCount(Register VReg) const { - if (SPIRVType *Type = getSPIRVTypeForVReg(VReg)) - return Type->getOpcode() == SPIRV::OpTypeVector - ? static_cast<unsigned>(Type->getOperand(2).getImm()) - : 1; - return 0; + return getScalarOrVectorComponentCount(getSPIRVTypeForVReg(VReg)); +} + +unsigned +SPIRVGlobalRegistry::getScalarOrVectorComponentCount(SPIRVType *Type) const { + if (!Type) + return 0; + return Type->getOpcode() == SPIRV::OpTypeVector + ? static_cast<unsigned>(Type->getOperand(2).getImm()) + : 1; } unsigned diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h index 9c0061d..25d82eb 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h @@ -198,9 +198,10 @@ public: // opcode (e.g. OpTypeBool, or OpTypeVector %x 4, where %x is OpTypeBool). bool isScalarOrVectorOfType(Register VReg, unsigned TypeOpcode) const; - // Return number of elements in a vector if the given VReg is associated with + // Return number of elements in a vector if the argument is associated with // a vector type. Return 1 for a scalar type, and 0 for a missing type. unsigned getScalarOrVectorComponentCount(Register VReg) const; + unsigned getScalarOrVectorComponentCount(SPIRVType *Type) const; // For vectors or scalars of booleans, integers and floats, return the scalar // type's bitwidth. Otherwise calls llvm_unreachable(). diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 74df8de..fd19b74 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -125,6 +125,8 @@ private: bool selectConstVector(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const; + bool selectSplatVector(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; bool selectCmp(Register ResVReg, const SPIRVType *ResType, unsigned comparisonOpcode, MachineInstr &I) const; @@ -313,6 +315,8 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg, case TargetOpcode::G_BUILD_VECTOR: return selectConstVector(ResVReg, ResType, I); + case TargetOpcode::G_SPLAT_VECTOR: + return selectSplatVector(ResVReg, ResType, I); case TargetOpcode::G_SHUFFLE_VECTOR: { MachineBasicBlock &BB = *I.getParent(); @@ -1185,6 +1189,43 @@ bool SPIRVInstructionSelector::selectConstVector(Register ResVReg, return MIB.constrainAllUses(TII, TRI, RBI); } +bool SPIRVInstructionSelector::selectSplatVector(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + if (ResType->getOpcode() != SPIRV::OpTypeVector) + report_fatal_error("Cannot select G_SPLAT_VECTOR with a non-vector result"); + unsigned N = GR.getScalarOrVectorComponentCount(ResType); + unsigned OpIdx = I.getNumExplicitDefs(); + if (!I.getOperand(OpIdx).isReg()) + report_fatal_error("Unexpected argument in G_SPLAT_VECTOR"); + + // check if we may construct a constant vector + Register OpReg = I.getOperand(OpIdx).getReg(); + bool IsConst = false; + if (SPIRVType *OpDef = MRI->getVRegDef(OpReg)) { + if (OpDef->getOpcode() == SPIRV::ASSIGN_TYPE && + OpDef->getOperand(1).isReg()) { + if (SPIRVType *RefDef = MRI->getVRegDef(OpDef->getOperand(1).getReg())) + OpDef = RefDef; + } + IsConst = OpDef->getOpcode() == TargetOpcode::G_CONSTANT || + OpDef->getOpcode() == TargetOpcode::G_FCONSTANT; + } + + if (!IsConst && N < 2) + report_fatal_error( + "There must be at least two constituent operands in a vector"); + + auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), + TII.get(IsConst ? SPIRV::OpConstantComposite + : SPIRV::OpCompositeConstruct)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)); + for (unsigned i = 0; i < N; ++i) + MIB.addUse(OpReg); + return MIB.constrainAllUses(TII, TRI, RBI); +} + bool SPIRVInstructionSelector::selectCmp(Register ResVReg, const SPIRVType *ResType, unsigned CmpOpc, diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp index f815487..4b871bd 100644 --- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp @@ -149,7 +149,9 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal(); // TODO: add proper rules for vectors legalization. - getActionDefinitionsBuilder({G_BUILD_VECTOR, G_SHUFFLE_VECTOR}).alwaysLegal(); + getActionDefinitionsBuilder( + {G_BUILD_VECTOR, G_SHUFFLE_VECTOR, G_SPLAT_VECTOR}) + .alwaysLegal(); // Vector Reduction Operations getActionDefinitionsBuilder( diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h index e5f35aa..d5ed501 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.h +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h @@ -15,6 +15,7 @@ #include "MCTargetDesc/SPIRVBaseInfo.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/TypedPointerType.h" #include <string> namespace llvm { @@ -100,5 +101,30 @@ bool isEntryPoint(const Function &F); // Parse basic scalar type name, substring TypeName, and return LLVM type. Type *parseBasicTypeName(StringRef TypeName, LLVMContext &Ctx); + +// True if this is an instance of TypedPointerType. +inline bool isTypedPointerTy(const Type *T) { + return T->getTypeID() == Type::TypedPointerTyID; +} + +// True if this is an instance of PointerType. +inline bool isUntypedPointerTy(const Type *T) { + return T->getTypeID() == Type::PointerTyID; +} + +// True if this is an instance of PointerType or TypedPointerType. +inline bool isPointerTy(const Type *T) { + return isUntypedPointerTy(T) || isTypedPointerTy(T); +} + +// Get the address space of this pointer or pointer vector type for instances of +// PointerType or TypedPointerType. +inline unsigned getPointerAddressSpace(const Type *T) { + Type *SubT = T->getScalarType(); + return SubT->getTypeID() == Type::PointerTyID + ? cast<PointerType>(SubT)->getAddressSpace() + : cast<TypedPointerType>(SubT)->getAddressSpace(); +} + } // namespace llvm #endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H |