diff options
author | Vyacheslav Levytskyy <vyacheslav.levytskyy@intel.com> | 2024-03-25 10:14:46 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-25 10:14:46 +0100 |
commit | b0d03ccc0855f2bff39160f25fcde06aae07cace (patch) | |
tree | e34aed05ae556e3aa666ed13ead4ed157a986f08 /llvm/lib/Target/SPIRV | |
parent | 1d250d9099a9ba8b53add7eb7db6827e8fc0c8fd (diff) | |
download | llvm-b0d03ccc0855f2bff39160f25fcde06aae07cace.zip llvm-b0d03ccc0855f2bff39160f25fcde06aae07cace.tar.gz llvm-b0d03ccc0855f2bff39160f25fcde06aae07cace.tar.bz2 |
[SPIR-V] Fix illegal OpConstantComposite instruction with non-const constituents in SPIR-V Backend (#86352)
This PR fixes illegal use of OpConstantComposite with non-constant
constituents. The test attached to the PR is able now to satisfy
`spirv-val` check. Before the fix SPIR-V Backend produced for the
attached test case a pattern like
```
%a = OpVariable %_ptr_CrossWorkgroup_uint CrossWorkgroup %uint_123
%11 = OpConstantComposite %_struct_6 %a %a
```
so that `spirv-val` complained with
```
error: line 25: OpConstantComposite Constituent <id> '10[%a]' is not a constant or undef.
%11 = OpConstantComposite %_struct_6 %a %a
```
Diffstat (limited to 'llvm/lib/Target/SPIRV')
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp | 1 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h | 9 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 1 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 8 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp | 90 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp | 1 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td | 1 |
7 files changed, 95 insertions, 16 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp index d82fb2df..7c32bb1 100644 --- a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp @@ -39,6 +39,7 @@ void SPIRVGeneralDuplicatesTracker::buildDepsGraph( prebuildReg2Entry(GT, Reg2Entry); prebuildReg2Entry(FT, Reg2Entry); prebuildReg2Entry(AT, Reg2Entry); + prebuildReg2Entry(MT, Reg2Entry); prebuildReg2Entry(ST, Reg2Entry); for (auto &Op2E : Reg2Entry) { diff --git a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h index 96cc621..2ec3fb3 100644 --- a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h +++ b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h @@ -262,6 +262,7 @@ class SPIRVGeneralDuplicatesTracker { SPIRVDuplicatesTracker<GlobalVariable> GT; SPIRVDuplicatesTracker<Function> FT; SPIRVDuplicatesTracker<Argument> AT; + SPIRVDuplicatesTracker<MachineInstr> MT; SPIRVDuplicatesTracker<SPIRV::SpecialTypeDescriptor> ST; // NOTE: using MOs instead of regs to get rid of MF dependency to be able @@ -306,6 +307,10 @@ public: AT.add(Arg, MF, R); } + void add(const MachineInstr *MI, const MachineFunction *MF, Register R) { + MT.add(MI, MF, R); + } + void add(const SPIRV::SpecialTypeDescriptor &TD, const MachineFunction *MF, Register R) { ST.add(TD, MF, R); @@ -337,6 +342,10 @@ public: return AT.find(const_cast<Argument *>(Arg), MF); } + Register find(const MachineInstr *MI, const MachineFunction *MF) { + return MT.find(const_cast<MachineInstr *>(MI), MF); + } + Register find(const SPIRV::SpecialTypeDescriptor &TD, const MachineFunction *MF) { return ST.find(TD, MF); diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp index ee52163..db66ed4 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -123,6 +123,7 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems, SPIRVType *ElemType, MachineIRBuilder &MIRBuilder) { auto EleOpc = ElemType->getOpcode(); + (void)EleOpc; assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat || EleOpc == SPIRV::OpTypeBool) && "Invalid vector element type"); diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h index da480b2..ed0f90f 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h @@ -94,6 +94,14 @@ public: DT.add(Arg, MF, R); } + void add(const MachineInstr *MI, MachineFunction *MF, Register R) { + DT.add(MI, MF, R); + } + + Register find(const MachineInstr *MI, MachineFunction *MF) { + return DT.find(MI, MF); + } + Register find(const Constant *C, MachineFunction *MF) { return DT.find(C, MF); } diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 39228e2..505b19a 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -231,6 +231,9 @@ private: Register buildZerosVal(const SPIRVType *ResType, MachineInstr &I) const; Register buildOnesVal(bool AllOnes, const SPIRVType *ResType, MachineInstr &I) const; + + bool wrapIntoSpecConstantOp(MachineInstr &I, + SmallVector<Register> &CompositeArgs) const; }; } // end anonymous namespace @@ -1249,6 +1252,24 @@ static unsigned getArrayComponentCount(MachineRegisterInfo *MRI, return N; } +// Return true if the type represents a constant register +static bool isConstReg(MachineRegisterInfo *MRI, SPIRVType *OpDef) { + if (OpDef->getOpcode() == SPIRV::ASSIGN_TYPE && + OpDef->getOperand(1).isReg()) { + if (SPIRVType *RefDef = MRI->getVRegDef(OpDef->getOperand(1).getReg())) + OpDef = RefDef; + } + return OpDef->getOpcode() == TargetOpcode::G_CONSTANT || + OpDef->getOpcode() == TargetOpcode::G_FCONSTANT; +} + +// Return true if the virtual register represents a constant +static bool isConstReg(MachineRegisterInfo *MRI, Register OpReg) { + if (SPIRVType *OpDef = MRI->getVRegDef(OpReg)) + return isConstReg(MRI, OpDef); + return false; +} + bool SPIRVInstructionSelector::selectSplatVector(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const { @@ -1266,16 +1287,7 @@ bool SPIRVInstructionSelector::selectSplatVector(Register ResVReg, // check if we may construct a constant vector Register OpReg = I.getOperand(OpIdx).getReg(); - bool IsConst = false; - if (SPIRVType *OpDef = MRI->getVRegDef(OpReg)) { - if (OpDef->getOpcode() == SPIRV::ASSIGN_TYPE && - OpDef->getOperand(1).isReg()) { - if (SPIRVType *RefDef = MRI->getVRegDef(OpDef->getOperand(1).getReg())) - OpDef = RefDef; - } - IsConst = OpDef->getOpcode() == TargetOpcode::G_CONSTANT || - OpDef->getOpcode() == TargetOpcode::G_FCONSTANT; - } + bool IsConst = isConstReg(MRI, OpReg); if (!IsConst && N < 2) report_fatal_error( @@ -1628,6 +1640,48 @@ bool SPIRVInstructionSelector::selectGEP(Register ResVReg, return Res.constrainAllUses(TII, TRI, RBI); } +// Maybe wrap a value into OpSpecConstantOp +bool SPIRVInstructionSelector::wrapIntoSpecConstantOp( + MachineInstr &I, SmallVector<Register> &CompositeArgs) const { + bool Result = true; + unsigned Lim = I.getNumExplicitOperands(); + for (unsigned i = I.getNumExplicitDefs() + 1; i < Lim; ++i) { + Register OpReg = I.getOperand(i).getReg(); + SPIRVType *OpDefine = MRI->getVRegDef(OpReg); + SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpReg); + if (!OpDefine || !OpType || isConstReg(MRI, OpDefine) || + OpDefine->getOpcode() == TargetOpcode::G_ADDRSPACE_CAST) { + // The case of G_ADDRSPACE_CAST inside spv_const_composite() is processed + // by selectAddrSpaceCast() + CompositeArgs.push_back(OpReg); + continue; + } + MachineFunction *MF = I.getMF(); + Register WrapReg = GR.find(OpDefine, MF); + if (WrapReg.isValid()) { + CompositeArgs.push_back(WrapReg); + continue; + } + // Create a new register for the wrapper + WrapReg = MRI->createVirtualRegister(&SPIRV::IDRegClass); + GR.add(OpDefine, MF, WrapReg); + CompositeArgs.push_back(WrapReg); + // Decorate the wrapper register and generate a new instruction + MRI->setType(WrapReg, LLT::pointer(0, 32)); + GR.assignSPIRVTypeToVReg(OpType, WrapReg, *MF); + MachineBasicBlock &BB = *I.getParent(); + Result = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpSpecConstantOp)) + .addDef(WrapReg) + .addUse(GR.getSPIRVTypeID(OpType)) + .addImm(static_cast<uint32_t>(SPIRV::Opcode::Bitcast)) + .addUse(OpReg) + .constrainAllUses(TII, TRI, RBI); + if (!Result) + break; + } + return Result; +} + bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const { @@ -1666,17 +1720,21 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, case Intrinsic::spv_const_composite: { // If no values are attached, the composite is null constant. bool IsNull = I.getNumExplicitDefs() + 1 == I.getNumExplicitOperands(); - unsigned Opcode = - IsNull ? SPIRV::OpConstantNull : SPIRV::OpConstantComposite; + // Select a proper instruction. + unsigned Opcode = SPIRV::OpConstantNull; + SmallVector<Register> CompositeArgs; + if (!IsNull) { + Opcode = SPIRV::OpConstantComposite; + if (!wrapIntoSpecConstantOp(I, CompositeArgs)) + return false; + } auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode)) .addDef(ResVReg) .addUse(GR.getSPIRVTypeID(ResType)); // skip type MD node we already used when generated assign.type for this if (!IsNull) { - for (unsigned i = I.getNumExplicitDefs() + 1; - i < I.getNumExplicitOperands(); ++i) { - MIB.addUse(I.getOperand(i).getReg()); - } + for (Register OpReg : CompositeArgs) + MIB.addUse(OpReg); } return MIB.constrainAllUses(TII, TRI, RBI); } diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp index d547f91..1f0d8d8 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp @@ -543,6 +543,7 @@ static void processSwitches(MachineFunction &MF, SPIRVGlobalRegistry *GR, Register Dst = ICMP->getOperand(0).getReg(); MachineOperand &PredOp = ICMP->getOperand(1); const auto CC = static_cast<CmpInst::Predicate>(PredOp.getPredicate()); + (void)CC; assert((CC == CmpInst::ICMP_EQ || CC == CmpInst::ICMP_ULE) && MRI.hasOneUse(Dst) && MRI.hasOneDef(CompareReg)); uint64_t Value = getIConstVal(ICMP->getOperand(3).getReg(), &MRI); diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td index 8dbbd90..ff102e3 100644 --- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td +++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td @@ -1611,3 +1611,4 @@ multiclass OpcodeOperand<bits<32> value> { // TODO: implement other mnemonics. defm InBoundsPtrAccessChain : OpcodeOperand<70>; defm PtrCastToGeneric : OpcodeOperand<121>; +defm Bitcast : OpcodeOperand<124>; |