aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/SPIRV
diff options
context:
space:
mode:
authorVyacheslav Levytskyy <vyacheslav.levytskyy@intel.com>2024-03-13 08:32:01 +0100
committerGitHub <noreply@github.com>2024-03-13 08:32:01 +0100
commit0a443f13b49b3f392461a0bb60b0146cfc4607c7 (patch)
tree960853562f89e2452d82cd9879e03b153def0edc /llvm/lib/Target/SPIRV
parentcd2f6163137dce45d909aa445cfd57b7188f8ed1 (diff)
downloadllvm-0a443f13b49b3f392461a0bb60b0146cfc4607c7.zip
llvm-0a443f13b49b3f392461a0bb60b0146cfc4607c7.tar.gz
llvm-0a443f13b49b3f392461a0bb60b0146cfc4607c7.tar.bz2
[SPIR-V] Add implementation of G_SPLAT_VECTOR opcode and fix invalid types processing (#84766)
This PR: * adds support for G_SPLAT_VECTOR generic opcode that may be legally generated instead of G_BUILD_VECTOR by previous passes of the translator (see https://github.com/llvm/llvm-project/pull/80378 for the source of breaking changes); * improves deduction of types for opaque pointers. This PR also fixes the following issues: * if a function has ptr argument(s), two functions that have different SPIR-V type definitions may get identical LLVM function types and break agreements of global register and duplicate checker; * checks for pointer types do not account for TypedPointerType. Update of tests: * A test case is added to cover the issue with function ptr parameters. * The first case, that is support for G_SPLAT_VECTOR generic opcode, is covered by existing test cases. * Multiple additional checks by `spirv-val` is added to cover more possibilities of generation of invalid code.
Diffstat (limited to 'llvm/lib/Target/SPIRV')
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp49
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp136
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp32
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h3
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp41
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp4
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVUtils.h26
7 files changed, 236 insertions, 55 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 2d7a00b..f1fbe2b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -85,6 +85,42 @@ static ConstantInt *getConstInt(MDNode *MD, unsigned NumOp) {
return nullptr;
}
+// If the function has pointer arguments, we are forced to re-create this
+// function type from the very beginning, changing PointerType by
+// TypedPointerType for each pointer argument. Otherwise, the same `Type*`
+// potentially corresponds to different SPIR-V function type, effectively
+// invalidating logic behind global registry and duplicates tracker.
+static FunctionType *
+fixFunctionTypeIfPtrArgs(SPIRVGlobalRegistry *GR, const Function &F,
+ FunctionType *FTy, const SPIRVType *SRetTy,
+ const SmallVector<SPIRVType *, 4> &SArgTys) {
+ if (F.getParent()->getNamedMetadata("spv.cloned_funcs"))
+ return FTy;
+
+ bool hasArgPtrs = false;
+ for (auto &Arg : F.args()) {
+ // check if it's an instance of a non-typed PointerType
+ if (Arg.getType()->isPointerTy()) {
+ hasArgPtrs = true;
+ break;
+ }
+ }
+ if (!hasArgPtrs) {
+ Type *RetTy = FTy->getReturnType();
+ // check if it's an instance of a non-typed PointerType
+ if (!RetTy->isPointerTy())
+ return FTy;
+ }
+
+ // re-create function type, using TypedPointerType instead of PointerType to
+ // properly trace argument types
+ const Type *RetTy = GR->getTypeForSPIRVType(SRetTy);
+ SmallVector<Type *, 4> ArgTys;
+ for (auto SArgTy : SArgTys)
+ ArgTys.push_back(const_cast<Type *>(GR->getTypeForSPIRVType(SArgTy)));
+ return FunctionType::get(const_cast<Type *>(RetTy), ArgTys, false);
+}
+
// This code restores function args/retvalue types for composite cases
// because the final types should still be aggregate whereas they're i32
// during the translation to cope with aggregate flattening etc.
@@ -162,7 +198,7 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
// If OriginalArgType is non-pointer, use the OriginalArgType (the type cannot
// be legally reassigned later).
- if (!OriginalArgType->isPointerTy())
+ if (!isPointerTy(OriginalArgType))
return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
// In case OriginalArgType is of pointer type, there are three possibilities:
@@ -179,8 +215,7 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
SPIRVType *ElementType = GR->getOrCreateSPIRVType(ByValRefType, MIRBuilder);
return GR->getOrCreateSPIRVPointerType(
ElementType, MIRBuilder,
- addressSpaceToStorageClass(Arg->getType()->getPointerAddressSpace(),
- ST));
+ addressSpaceToStorageClass(getPointerAddressSpace(Arg->getType()), ST));
}
for (auto User : Arg->users()) {
@@ -240,7 +275,6 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
// Assign types and names to all args, and store their types for later.
- FunctionType *FTy = getOriginalFunctionType(F);
SmallVector<SPIRVType *, 4> ArgTypeVRegs;
if (VRegs.size() > 0) {
unsigned i = 0;
@@ -255,7 +289,7 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
if (Arg.hasName())
buildOpName(VRegs[i][0], Arg.getName(), MIRBuilder);
- if (Arg.getType()->isPointerTy()) {
+ if (isPointerTy(Arg.getType())) {
auto DerefBytes = static_cast<unsigned>(Arg.getDereferenceableBytes());
if (DerefBytes != 0)
buildOpDecorate(VRegs[i][0], MIRBuilder,
@@ -322,7 +356,9 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass);
if (F.isDeclaration())
GR->add(&F, &MIRBuilder.getMF(), FuncVReg);
+ FunctionType *FTy = getOriginalFunctionType(F);
SPIRVType *RetTy = GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder);
+ FTy = fixFunctionTypeIfPtrArgs(GR, F, FTy, RetTy, ArgTypeVRegs);
SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
FTy, RetTy, ArgTypeVRegs, MIRBuilder);
uint32_t FuncControl = getFunctionControl(F);
@@ -429,7 +465,6 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
return false;
MachineFunction &MF = MIRBuilder.getMF();
GR->setCurrentFunc(MF);
- FunctionType *FTy = nullptr;
const Function *CF = nullptr;
std::string DemangledName;
const Type *OrigRetTy = Info.OrigRet.Ty;
@@ -444,7 +479,7 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
// TODO: support constexpr casts and indirect calls.
if (CF == nullptr)
return false;
- if ((FTy = getOriginalFunctionType(*CF)) != nullptr)
+ if (FunctionType *FTy = getOriginalFunctionType(*CF))
OrigRetTy = FTy->getReturnType();
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 575e903..c5b9012 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -57,8 +57,14 @@ class SPIRVEmitIntrinsics
bool TrackConstants = true;
DenseMap<Instruction *, Constant *> AggrConsts;
DenseSet<Instruction *> AggrStores;
+
+ // deduce values type
+ DenseMap<Value *, Type *> DeducedElTys;
+ Type *deduceElementType(Value *I);
+
void preprocessCompositeConstants(IRBuilder<> &B);
void preprocessUndefs(IRBuilder<> &B);
+
CallInst *buildIntrWithMD(Intrinsic::ID IntrID, ArrayRef<Type *> Types,
Value *Arg, Value *Arg2, ArrayRef<Constant *> Imms,
IRBuilder<> &B) {
@@ -72,6 +78,7 @@ class SPIRVEmitIntrinsics
Args.push_back(Imm);
return B.CreateIntrinsic(IntrID, {Types}, Args);
}
+
void replaceMemInstrUses(Instruction *Old, Instruction *New, IRBuilder<> &B);
void processInstrAfterVisit(Instruction *I, IRBuilder<> &B);
void insertAssignPtrTypeIntrs(Instruction *I, IRBuilder<> &B);
@@ -156,6 +163,48 @@ static inline void reportFatalOnTokenType(const Instruction *I) {
false);
}
+// Deduce and return a successfully deduced Type of the Instruction,
+// or nullptr otherwise.
+static Type *deduceElementTypeHelper(Value *I,
+ std::unordered_set<Value *> &Visited,
+ DenseMap<Value *, Type *> &DeducedElTys) {
+ // maybe already known
+ auto It = DeducedElTys.find(I);
+ if (It != DeducedElTys.end())
+ return It->second;
+
+ // maybe a cycle
+ if (Visited.find(I) != Visited.end())
+ return nullptr;
+ Visited.insert(I);
+
+ // fallback value in case when we fail to deduce a type
+ Type *Ty = nullptr;
+ // look for known basic patterns of type inference
+ if (auto *Ref = dyn_cast<AllocaInst>(I))
+ Ty = Ref->getAllocatedType();
+ else if (auto *Ref = dyn_cast<GetElementPtrInst>(I))
+ Ty = Ref->getResultElementType();
+ else if (auto *Ref = dyn_cast<GlobalValue>(I))
+ Ty = Ref->getValueType();
+ else if (auto *Ref = dyn_cast<AddrSpaceCastInst>(I))
+ Ty = deduceElementTypeHelper(Ref->getPointerOperand(), Visited,
+ DeducedElTys);
+
+ // remember the found relationship
+ if (Ty)
+ DeducedElTys[I] = Ty;
+
+ return Ty;
+}
+
+Type *SPIRVEmitIntrinsics::deduceElementType(Value *I) {
+ std::unordered_set<Value *> Visited;
+ if (Type *Ty = deduceElementTypeHelper(I, Visited, DeducedElTys))
+ return Ty;
+ return IntegerType::getInt8Ty(I->getContext());
+}
+
void SPIRVEmitIntrinsics::replaceMemInstrUses(Instruction *Old,
Instruction *New,
IRBuilder<> &B) {
@@ -280,7 +329,7 @@ Instruction *SPIRVEmitIntrinsics::visitBitCastInst(BitCastInst &I) {
// varying element types. In case of IR coming from older versions of LLVM
// such bitcasts do not provide sufficient information, should be just skipped
// here, and handled in insertPtrCastOrAssignTypeInstr.
- if (I.getType()->isPointerTy()) {
+ if (isPointerTy(I.getType())) {
I.replaceAllUsesWith(Source);
I.eraseFromParent();
return nullptr;
@@ -333,20 +382,10 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
while (BitCastInst *BC = dyn_cast<BitCastInst>(Pointer))
Pointer = BC->getOperand(0);
- // Do not emit spv_ptrcast if Pointer is a GlobalValue of expected type.
- GlobalValue *GV = dyn_cast<GlobalValue>(Pointer);
- if (GV && GV->getValueType() == ExpectedElementType)
- return;
-
- // Do not emit spv_ptrcast if Pointer is a result of alloca with expected
- // type.
- AllocaInst *A = dyn_cast<AllocaInst>(Pointer);
- if (A && A->getAllocatedType() == ExpectedElementType)
- return;
-
- // Do not emit spv_ptrcast if Pointer is a result of GEP of expected type.
- GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(Pointer);
- if (GEPI && GEPI->getResultElementType() == ExpectedElementType)
+ // Do not emit spv_ptrcast if Pointer's element type is ExpectedElementType
+ std::unordered_set<Value *> Visited;
+ Type *PointerElemTy = deduceElementTypeHelper(Pointer, Visited, DeducedElTys);
+ if (PointerElemTy == ExpectedElementType)
return;
setInsertPointSkippingPhis(B, I);
@@ -356,7 +395,7 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
ValueAsMetadata::getConstant(ExpectedElementTypeConst);
MDTuple *TyMD = MDNode::get(F->getContext(), CM);
MetadataAsValue *VMD = MetadataAsValue::get(F->getContext(), TyMD);
- unsigned AddressSpace = Pointer->getType()->getPointerAddressSpace();
+ unsigned AddressSpace = getPointerAddressSpace(Pointer->getType());
bool FirstPtrCastOrAssignPtrType = true;
// Do not emit new spv_ptrcast if equivalent one already exists or when
@@ -401,9 +440,11 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
// spv_assign_ptr_type instead.
if (FirstPtrCastOrAssignPtrType &&
(isa<Instruction>(Pointer) || isa<Argument>(Pointer))) {
- buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {Pointer->getType()},
- ExpectedElementTypeConst, Pointer,
- {B.getInt32(AddressSpace)}, B);
+ CallInst *CI = buildIntrWithMD(
+ Intrinsic::spv_assign_ptr_type, {Pointer->getType()},
+ ExpectedElementTypeConst, Pointer, {B.getInt32(AddressSpace)}, B);
+ DeducedElTys[CI] = ExpectedElementType;
+ DeducedElTys[Pointer] = ExpectedElementType;
return;
}
@@ -419,7 +460,7 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
// Handle basic instructions:
StoreInst *SI = dyn_cast<StoreInst>(I);
if (SI && F->getCallingConv() == CallingConv::SPIR_KERNEL &&
- SI->getValueOperand()->getType()->isPointerTy() &&
+ isPointerTy(SI->getValueOperand()->getType()) &&
isa<Argument>(SI->getValueOperand())) {
return replacePointerOperandWithPtrCast(
I, SI->getValueOperand(), IntegerType::getInt8Ty(F->getContext()), 0,
@@ -440,9 +481,34 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
if (!CI || CI->isIndirectCall() || CI->getCalledFunction()->isIntrinsic())
return;
+ // collect information about formal parameter types
+ Function *CalledF = CI->getCalledFunction();
+ SmallVector<Type *, 4> CalledArgTys;
+ bool HaveTypes = false;
+ for (auto &CalledArg : CalledF->args()) {
+ if (!isPointerTy(CalledArg.getType())) {
+ CalledArgTys.push_back(nullptr);
+ continue;
+ }
+ auto It = DeducedElTys.find(&CalledArg);
+ Type *ParamTy = It != DeducedElTys.end() ? It->second : nullptr;
+ if (!ParamTy) {
+ for (User *U : CalledArg.users()) {
+ if (Instruction *Inst = dyn_cast<Instruction>(U)) {
+ std::unordered_set<Value *> Visited;
+ ParamTy = deduceElementTypeHelper(Inst, Visited, DeducedElTys);
+ if (ParamTy)
+ break;
+ }
+ }
+ }
+ HaveTypes |= ParamTy != nullptr;
+ CalledArgTys.push_back(ParamTy);
+ }
+
std::string DemangledName =
getOclOrSpirvBuiltinDemangledName(CI->getCalledFunction()->getName());
- if (DemangledName.empty())
+ if (DemangledName.empty() && !HaveTypes)
return;
for (unsigned OpIdx = 0; OpIdx < CI->arg_size(); OpIdx++) {
@@ -455,8 +521,11 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
if (!isa<Instruction>(ArgOperand) && !isa<Argument>(ArgOperand))
continue;
- Type *ExpectedType = SPIRV::parseBuiltinCallArgumentBaseType(
- DemangledName, OpIdx, I->getContext());
+ Type *ExpectedType =
+ OpIdx < CalledArgTys.size() ? CalledArgTys[OpIdx] : nullptr;
+ if (!ExpectedType && !DemangledName.empty())
+ ExpectedType = SPIRV::parseBuiltinCallArgumentBaseType(
+ DemangledName, OpIdx, I->getContext());
if (!ExpectedType)
continue;
@@ -639,30 +708,25 @@ void SPIRVEmitIntrinsics::processGlobalValue(GlobalVariable &GV,
void SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I,
IRBuilder<> &B) {
reportFatalOnTokenType(I);
- if (!I->getType()->isPointerTy() || !requireAssignType(I) ||
+ if (!isPointerTy(I->getType()) || !requireAssignType(I) ||
isa<BitCastInst>(I))
return;
setInsertPointSkippingPhis(B, I->getNextNode());
- Constant *EltTyConst;
- unsigned AddressSpace = I->getType()->getPointerAddressSpace();
- if (auto *AI = dyn_cast<AllocaInst>(I))
- EltTyConst = UndefValue::get(AI->getAllocatedType());
- else if (auto *GEP = dyn_cast<GetElementPtrInst>(I))
- EltTyConst = UndefValue::get(GEP->getResultElementType());
- else
- EltTyConst = UndefValue::get(IntegerType::getInt8Ty(I->getContext()));
-
- buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {I->getType()}, EltTyConst, I,
- {B.getInt32(AddressSpace)}, B);
+ Type *ElemTy = deduceElementType(I);
+ Constant *EltTyConst = UndefValue::get(ElemTy);
+ unsigned AddressSpace = getPointerAddressSpace(I->getType());
+ CallInst *CI = buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {I->getType()},
+ EltTyConst, I, {B.getInt32(AddressSpace)}, B);
+ DeducedElTys[CI] = ElemTy;
}
void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
IRBuilder<> &B) {
reportFatalOnTokenType(I);
Type *Ty = I->getType();
- if (!Ty->isVoidTy() && !Ty->isPointerTy() && requireAssignType(I)) {
+ if (!Ty->isVoidTy() && !isPointerTy(Ty) && requireAssignType(I)) {
setInsertPointSkippingPhis(B, I->getNextNode());
Type *TypeToAssign = Ty;
if (auto *II = dyn_cast<IntrinsicInst>(I)) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 8556581..bda9c57 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -750,7 +750,7 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
const Type *Ty, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
- if (TypesInProcessing.count(Ty) && !Ty->isPointerTy())
+ if (TypesInProcessing.count(Ty) && !isPointerTy(Ty))
return nullptr;
TypesInProcessing.insert(Ty);
SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
@@ -762,11 +762,15 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
// will be added later. For special types it is already added to DT.
if (SpirvType->getOpcode() != SPIRV::OpTypeForwardPointer && !Reg.isValid() &&
!isSpecialOpaqueType(Ty)) {
- if (!Ty->isPointerTy())
+ if (!isPointerTy(Ty))
DT.add(Ty, &MIRBuilder.getMF(), getSPIRVTypeID(SpirvType));
+ else if (isTypedPointerTy(Ty))
+ DT.add(cast<TypedPointerType>(Ty)->getElementType(),
+ getPointerAddressSpace(Ty), &MIRBuilder.getMF(),
+ getSPIRVTypeID(SpirvType));
else
DT.add(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()),
- Ty->getPointerAddressSpace(), &MIRBuilder.getMF(),
+ getPointerAddressSpace(Ty), &MIRBuilder.getMF(),
getSPIRVTypeID(SpirvType));
}
@@ -787,12 +791,15 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
const Type *Ty, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
Register Reg;
- if (!Ty->isPointerTy())
+ if (!isPointerTy(Ty))
Reg = DT.find(Ty, &MIRBuilder.getMF());
+ else if (isTypedPointerTy(Ty))
+ Reg = DT.find(cast<TypedPointerType>(Ty)->getElementType(),
+ getPointerAddressSpace(Ty), &MIRBuilder.getMF());
else
Reg =
DT.find(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()),
- Ty->getPointerAddressSpace(), &MIRBuilder.getMF());
+ getPointerAddressSpace(Ty), &MIRBuilder.getMF());
if (Reg.isValid() && !isSpecialOpaqueType(Ty))
return getSPIRVTypeForVReg(Reg);
@@ -836,11 +843,16 @@ bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg,
unsigned
SPIRVGlobalRegistry::getScalarOrVectorComponentCount(Register VReg) const {
- if (SPIRVType *Type = getSPIRVTypeForVReg(VReg))
- return Type->getOpcode() == SPIRV::OpTypeVector
- ? static_cast<unsigned>(Type->getOperand(2).getImm())
- : 1;
- return 0;
+ return getScalarOrVectorComponentCount(getSPIRVTypeForVReg(VReg));
+}
+
+unsigned
+SPIRVGlobalRegistry::getScalarOrVectorComponentCount(SPIRVType *Type) const {
+ if (!Type)
+ return 0;
+ return Type->getOpcode() == SPIRV::OpTypeVector
+ ? static_cast<unsigned>(Type->getOperand(2).getImm())
+ : 1;
}
unsigned
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 9c0061d..25d82eb 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -198,9 +198,10 @@ public:
// opcode (e.g. OpTypeBool, or OpTypeVector %x 4, where %x is OpTypeBool).
bool isScalarOrVectorOfType(Register VReg, unsigned TypeOpcode) const;
- // Return number of elements in a vector if the given VReg is associated with
+ // Return number of elements in a vector if the argument is associated with
// a vector type. Return 1 for a scalar type, and 0 for a missing type.
unsigned getScalarOrVectorComponentCount(Register VReg) const;
+ unsigned getScalarOrVectorComponentCount(SPIRVType *Type) const;
// For vectors or scalars of booleans, integers and floats, return the scalar
// type's bitwidth. Otherwise calls llvm_unreachable().
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 74df8de..fd19b74 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -125,6 +125,8 @@ private:
bool selectConstVector(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
+ bool selectSplatVector(Register ResVReg, const SPIRVType *ResType,
+ MachineInstr &I) const;
bool selectCmp(Register ResVReg, const SPIRVType *ResType,
unsigned comparisonOpcode, MachineInstr &I) const;
@@ -313,6 +315,8 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
case TargetOpcode::G_BUILD_VECTOR:
return selectConstVector(ResVReg, ResType, I);
+ case TargetOpcode::G_SPLAT_VECTOR:
+ return selectSplatVector(ResVReg, ResType, I);
case TargetOpcode::G_SHUFFLE_VECTOR: {
MachineBasicBlock &BB = *I.getParent();
@@ -1185,6 +1189,43 @@ bool SPIRVInstructionSelector::selectConstVector(Register ResVReg,
return MIB.constrainAllUses(TII, TRI, RBI);
}
+bool SPIRVInstructionSelector::selectSplatVector(Register ResVReg,
+ const SPIRVType *ResType,
+ MachineInstr &I) const {
+ if (ResType->getOpcode() != SPIRV::OpTypeVector)
+ report_fatal_error("Cannot select G_SPLAT_VECTOR with a non-vector result");
+ unsigned N = GR.getScalarOrVectorComponentCount(ResType);
+ unsigned OpIdx = I.getNumExplicitDefs();
+ if (!I.getOperand(OpIdx).isReg())
+ report_fatal_error("Unexpected argument in G_SPLAT_VECTOR");
+
+ // check if we may construct a constant vector
+ Register OpReg = I.getOperand(OpIdx).getReg();
+ bool IsConst = false;
+ if (SPIRVType *OpDef = MRI->getVRegDef(OpReg)) {
+ if (OpDef->getOpcode() == SPIRV::ASSIGN_TYPE &&
+ OpDef->getOperand(1).isReg()) {
+ if (SPIRVType *RefDef = MRI->getVRegDef(OpDef->getOperand(1).getReg()))
+ OpDef = RefDef;
+ }
+ IsConst = OpDef->getOpcode() == TargetOpcode::G_CONSTANT ||
+ OpDef->getOpcode() == TargetOpcode::G_FCONSTANT;
+ }
+
+ if (!IsConst && N < 2)
+ report_fatal_error(
+ "There must be at least two constituent operands in a vector");
+
+ auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
+ TII.get(IsConst ? SPIRV::OpConstantComposite
+ : SPIRV::OpCompositeConstruct))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType));
+ for (unsigned i = 0; i < N; ++i)
+ MIB.addUse(OpReg);
+ return MIB.constrainAllUses(TII, TRI, RBI);
+}
+
bool SPIRVInstructionSelector::selectCmp(Register ResVReg,
const SPIRVType *ResType,
unsigned CmpOpc,
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index f815487..4b871bd 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -149,7 +149,9 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal();
// TODO: add proper rules for vectors legalization.
- getActionDefinitionsBuilder({G_BUILD_VECTOR, G_SHUFFLE_VECTOR}).alwaysLegal();
+ getActionDefinitionsBuilder(
+ {G_BUILD_VECTOR, G_SHUFFLE_VECTOR, G_SPLAT_VECTOR})
+ .alwaysLegal();
// Vector Reduction Operations
getActionDefinitionsBuilder(
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index e5f35aa..d5ed501 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -15,6 +15,7 @@
#include "MCTargetDesc/SPIRVBaseInfo.h"
#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/TypedPointerType.h"
#include <string>
namespace llvm {
@@ -100,5 +101,30 @@ bool isEntryPoint(const Function &F);
// Parse basic scalar type name, substring TypeName, and return LLVM type.
Type *parseBasicTypeName(StringRef TypeName, LLVMContext &Ctx);
+
+// True if this is an instance of TypedPointerType.
+inline bool isTypedPointerTy(const Type *T) {
+ return T->getTypeID() == Type::TypedPointerTyID;
+}
+
+// True if this is an instance of PointerType.
+inline bool isUntypedPointerTy(const Type *T) {
+ return T->getTypeID() == Type::PointerTyID;
+}
+
+// True if this is an instance of PointerType or TypedPointerType.
+inline bool isPointerTy(const Type *T) {
+ return isUntypedPointerTy(T) || isTypedPointerTy(T);
+}
+
+// Get the address space of this pointer or pointer vector type for instances of
+// PointerType or TypedPointerType.
+inline unsigned getPointerAddressSpace(const Type *T) {
+ Type *SubT = T->getScalarType();
+ return SubT->getTypeID() == Type::PointerTyID
+ ? cast<PointerType>(SubT)->getAddressSpace()
+ : cast<TypedPointerType>(SubT)->getAddressSpace();
+}
+
} // namespace llvm
#endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H