aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/SPIRV
diff options
context:
space:
mode:
authorVyacheslav Levytskyy <vyacheslav.levytskyy@intel.com>2024-03-25 10:14:46 +0100
committerGitHub <noreply@github.com>2024-03-25 10:14:46 +0100
commitb0d03ccc0855f2bff39160f25fcde06aae07cace (patch)
treee34aed05ae556e3aa666ed13ead4ed157a986f08 /llvm/lib/Target/SPIRV
parent1d250d9099a9ba8b53add7eb7db6827e8fc0c8fd (diff)
downloadllvm-b0d03ccc0855f2bff39160f25fcde06aae07cace.zip
llvm-b0d03ccc0855f2bff39160f25fcde06aae07cace.tar.gz
llvm-b0d03ccc0855f2bff39160f25fcde06aae07cace.tar.bz2
[SPIR-V] Fix illegal OpConstantComposite instruction with non-const constituents in SPIR-V Backend (#86352)
This PR fixes illegal use of OpConstantComposite with non-constant constituents. The test attached to the PR is able now to satisfy `spirv-val` check. Before the fix SPIR-V Backend produced for the attached test case a pattern like ``` %a = OpVariable %_ptr_CrossWorkgroup_uint CrossWorkgroup %uint_123 %11 = OpConstantComposite %_struct_6 %a %a ``` so that `spirv-val` complained with ``` error: line 25: OpConstantComposite Constituent <id> '10[%a]' is not a constant or undef. %11 = OpConstantComposite %_struct_6 %a %a ```
Diffstat (limited to 'llvm/lib/Target/SPIRV')
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp1
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h9
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp1
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h8
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp90
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp1
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td1
7 files changed, 95 insertions, 16 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp
index d82fb2df..7c32bb1 100644
--- a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp
@@ -39,6 +39,7 @@ void SPIRVGeneralDuplicatesTracker::buildDepsGraph(
prebuildReg2Entry(GT, Reg2Entry);
prebuildReg2Entry(FT, Reg2Entry);
prebuildReg2Entry(AT, Reg2Entry);
+ prebuildReg2Entry(MT, Reg2Entry);
prebuildReg2Entry(ST, Reg2Entry);
for (auto &Op2E : Reg2Entry) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
index 96cc621..2ec3fb3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
+++ b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
@@ -262,6 +262,7 @@ class SPIRVGeneralDuplicatesTracker {
SPIRVDuplicatesTracker<GlobalVariable> GT;
SPIRVDuplicatesTracker<Function> FT;
SPIRVDuplicatesTracker<Argument> AT;
+ SPIRVDuplicatesTracker<MachineInstr> MT;
SPIRVDuplicatesTracker<SPIRV::SpecialTypeDescriptor> ST;
// NOTE: using MOs instead of regs to get rid of MF dependency to be able
@@ -306,6 +307,10 @@ public:
AT.add(Arg, MF, R);
}
+ void add(const MachineInstr *MI, const MachineFunction *MF, Register R) {
+ MT.add(MI, MF, R);
+ }
+
void add(const SPIRV::SpecialTypeDescriptor &TD, const MachineFunction *MF,
Register R) {
ST.add(TD, MF, R);
@@ -337,6 +342,10 @@ public:
return AT.find(const_cast<Argument *>(Arg), MF);
}
+ Register find(const MachineInstr *MI, const MachineFunction *MF) {
+ return MT.find(const_cast<MachineInstr *>(MI), MF);
+ }
+
Register find(const SPIRV::SpecialTypeDescriptor &TD,
const MachineFunction *MF) {
return ST.find(TD, MF);
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index ee52163..db66ed4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -123,6 +123,7 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
SPIRVType *ElemType,
MachineIRBuilder &MIRBuilder) {
auto EleOpc = ElemType->getOpcode();
+ (void)EleOpc;
assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat ||
EleOpc == SPIRV::OpTypeBool) &&
"Invalid vector element type");
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index da480b2..ed0f90f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -94,6 +94,14 @@ public:
DT.add(Arg, MF, R);
}
+ void add(const MachineInstr *MI, MachineFunction *MF, Register R) {
+ DT.add(MI, MF, R);
+ }
+
+ Register find(const MachineInstr *MI, MachineFunction *MF) {
+ return DT.find(MI, MF);
+ }
+
Register find(const Constant *C, MachineFunction *MF) {
return DT.find(C, MF);
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 39228e2..505b19a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -231,6 +231,9 @@ private:
Register buildZerosVal(const SPIRVType *ResType, MachineInstr &I) const;
Register buildOnesVal(bool AllOnes, const SPIRVType *ResType,
MachineInstr &I) const;
+
+ bool wrapIntoSpecConstantOp(MachineInstr &I,
+ SmallVector<Register> &CompositeArgs) const;
};
} // end anonymous namespace
@@ -1249,6 +1252,24 @@ static unsigned getArrayComponentCount(MachineRegisterInfo *MRI,
return N;
}
+// Return true if the type represents a constant register
+static bool isConstReg(MachineRegisterInfo *MRI, SPIRVType *OpDef) {
+ if (OpDef->getOpcode() == SPIRV::ASSIGN_TYPE &&
+ OpDef->getOperand(1).isReg()) {
+ if (SPIRVType *RefDef = MRI->getVRegDef(OpDef->getOperand(1).getReg()))
+ OpDef = RefDef;
+ }
+ return OpDef->getOpcode() == TargetOpcode::G_CONSTANT ||
+ OpDef->getOpcode() == TargetOpcode::G_FCONSTANT;
+}
+
+// Return true if the virtual register represents a constant
+static bool isConstReg(MachineRegisterInfo *MRI, Register OpReg) {
+ if (SPIRVType *OpDef = MRI->getVRegDef(OpReg))
+ return isConstReg(MRI, OpDef);
+ return false;
+}
+
bool SPIRVInstructionSelector::selectSplatVector(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
@@ -1266,16 +1287,7 @@ bool SPIRVInstructionSelector::selectSplatVector(Register ResVReg,
// check if we may construct a constant vector
Register OpReg = I.getOperand(OpIdx).getReg();
- bool IsConst = false;
- if (SPIRVType *OpDef = MRI->getVRegDef(OpReg)) {
- if (OpDef->getOpcode() == SPIRV::ASSIGN_TYPE &&
- OpDef->getOperand(1).isReg()) {
- if (SPIRVType *RefDef = MRI->getVRegDef(OpDef->getOperand(1).getReg()))
- OpDef = RefDef;
- }
- IsConst = OpDef->getOpcode() == TargetOpcode::G_CONSTANT ||
- OpDef->getOpcode() == TargetOpcode::G_FCONSTANT;
- }
+ bool IsConst = isConstReg(MRI, OpReg);
if (!IsConst && N < 2)
report_fatal_error(
@@ -1628,6 +1640,48 @@ bool SPIRVInstructionSelector::selectGEP(Register ResVReg,
return Res.constrainAllUses(TII, TRI, RBI);
}
+// Maybe wrap a value into OpSpecConstantOp
+bool SPIRVInstructionSelector::wrapIntoSpecConstantOp(
+ MachineInstr &I, SmallVector<Register> &CompositeArgs) const {
+ bool Result = true;
+ unsigned Lim = I.getNumExplicitOperands();
+ for (unsigned i = I.getNumExplicitDefs() + 1; i < Lim; ++i) {
+ Register OpReg = I.getOperand(i).getReg();
+ SPIRVType *OpDefine = MRI->getVRegDef(OpReg);
+ SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpReg);
+ if (!OpDefine || !OpType || isConstReg(MRI, OpDefine) ||
+ OpDefine->getOpcode() == TargetOpcode::G_ADDRSPACE_CAST) {
+ // The case of G_ADDRSPACE_CAST inside spv_const_composite() is processed
+ // by selectAddrSpaceCast()
+ CompositeArgs.push_back(OpReg);
+ continue;
+ }
+ MachineFunction *MF = I.getMF();
+ Register WrapReg = GR.find(OpDefine, MF);
+ if (WrapReg.isValid()) {
+ CompositeArgs.push_back(WrapReg);
+ continue;
+ }
+ // Create a new register for the wrapper
+ WrapReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+ GR.add(OpDefine, MF, WrapReg);
+ CompositeArgs.push_back(WrapReg);
+ // Decorate the wrapper register and generate a new instruction
+ MRI->setType(WrapReg, LLT::pointer(0, 32));
+ GR.assignSPIRVTypeToVReg(OpType, WrapReg, *MF);
+ MachineBasicBlock &BB = *I.getParent();
+ Result = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpSpecConstantOp))
+ .addDef(WrapReg)
+ .addUse(GR.getSPIRVTypeID(OpType))
+ .addImm(static_cast<uint32_t>(SPIRV::Opcode::Bitcast))
+ .addUse(OpReg)
+ .constrainAllUses(TII, TRI, RBI);
+ if (!Result)
+ break;
+ }
+ return Result;
+}
+
bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
@@ -1666,17 +1720,21 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
case Intrinsic::spv_const_composite: {
// If no values are attached, the composite is null constant.
bool IsNull = I.getNumExplicitDefs() + 1 == I.getNumExplicitOperands();
- unsigned Opcode =
- IsNull ? SPIRV::OpConstantNull : SPIRV::OpConstantComposite;
+ // Select a proper instruction.
+ unsigned Opcode = SPIRV::OpConstantNull;
+ SmallVector<Register> CompositeArgs;
+ if (!IsNull) {
+ Opcode = SPIRV::OpConstantComposite;
+ if (!wrapIntoSpecConstantOp(I, CompositeArgs))
+ return false;
+ }
auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType));
// skip type MD node we already used when generated assign.type for this
if (!IsNull) {
- for (unsigned i = I.getNumExplicitDefs() + 1;
- i < I.getNumExplicitOperands(); ++i) {
- MIB.addUse(I.getOperand(i).getReg());
- }
+ for (Register OpReg : CompositeArgs)
+ MIB.addUse(OpReg);
}
return MIB.constrainAllUses(TII, TRI, RBI);
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index d547f91..1f0d8d8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -543,6 +543,7 @@ static void processSwitches(MachineFunction &MF, SPIRVGlobalRegistry *GR,
Register Dst = ICMP->getOperand(0).getReg();
MachineOperand &PredOp = ICMP->getOperand(1);
const auto CC = static_cast<CmpInst::Predicate>(PredOp.getPredicate());
+ (void)CC;
assert((CC == CmpInst::ICMP_EQ || CC == CmpInst::ICMP_ULE) &&
MRI.hasOneUse(Dst) && MRI.hasOneDef(CompareReg));
uint64_t Value = getIConstVal(ICMP->getOperand(3).getReg(), &MRI);
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 8dbbd90..ff102e3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -1611,3 +1611,4 @@ multiclass OpcodeOperand<bits<32> value> {
// TODO: implement other mnemonics.
defm InBoundsPtrAccessChain : OpcodeOperand<70>;
defm PtrCastToGeneric : OpcodeOperand<121>;
+defm Bitcast : OpcodeOperand<124>;