aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/SPIRV
diff options
context:
space:
mode:
authorVyacheslav Levytskyy <vyacheslav.levytskyy@intel.com>2024-06-03 10:34:05 +0200
committerGitHub <noreply@github.com>2024-06-03 10:34:05 +0200
commitce73e17e3ab5ccfa33a977843e82a9bbfb6b4ce3 (patch)
tree7ee133b0b602741a7317072305cbedf6655da696 /llvm/lib/Target/SPIRV
parent264b1b24869eb45463a98d70e9b9e991092acc28 (diff)
downloadllvm-ce73e17e3ab5ccfa33a977843e82a9bbfb6b4ce3.zip
llvm-ce73e17e3ab5ccfa33a977843e82a9bbfb6b4ce3.tar.gz
llvm-ce73e17e3ab5ccfa33a977843e82a9bbfb6b4ce3.tar.bz2
[SPIR-V] Validate type of the last parameter of OpGroupWaitEvents (#93661)
This PR fixes invalid OpGroupWaitEvents emission to ensure that SPIR-V Backend inserts a bitcast before OpGroupWaitEvents if the last argument is a pointer that doesn't point to OpTypeEvent.
Diffstat (limited to 'llvm/lib/Target/SPIRV')
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp99
1 files changed, 73 insertions, 26 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 2bd22bb..5ccbaf1 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -104,6 +104,47 @@ SPIRVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
return std::make_pair(0u, RC);
}
+inline Register getTypeReg(MachineRegisterInfo *MRI, Register OpReg) {
+ SPIRVType *TypeInst = MRI->getVRegDef(OpReg);
+ return TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter
+ ? TypeInst->getOperand(1).getReg()
+ : OpReg;
+}
+
+static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
+ SPIRVGlobalRegistry &GR, MachineInstr &I,
+ Register OpReg, unsigned OpIdx,
+ SPIRVType *NewPtrType) {
+ Register NewReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
+ MachineIRBuilder MIB(I);
+ bool Res = MIB.buildInstr(SPIRV::OpBitcast)
+ .addDef(NewReg)
+ .addUse(GR.getSPIRVTypeID(NewPtrType))
+ .addUse(OpReg)
+ .constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(),
+ *STI.getRegBankInfo());
+ if (!Res)
+ report_fatal_error("insert validation bitcast: cannot constrain all uses");
+ MRI->setRegClass(NewReg, &SPIRV::IDRegClass);
+ GR.assignSPIRVTypeToVReg(NewPtrType, NewReg, MIB.getMF());
+ I.getOperand(OpIdx).setReg(NewReg);
+}
+
+static SPIRVType *createNewPtrType(SPIRVGlobalRegistry &GR, MachineInstr &I,
+ SPIRVType *OpType, bool ReuseType,
+ bool EmitIR, SPIRVType *ResType,
+ const Type *ResTy) {
+ SPIRV::StorageClass::StorageClass SC =
+ static_cast<SPIRV::StorageClass::StorageClass>(
+ OpType->getOperand(1).getImm());
+ MachineIRBuilder MIB(I);
+ SPIRVType *NewBaseType =
+ ReuseType ? ResType
+ : GR.getOrCreateSPIRVType(
+ ResTy, MIB, SPIRV::AccessQualifier::ReadWrite, EmitIR);
+ return GR.getOrCreateSPIRVPointerType(NewBaseType, MIB, SC);
+}
+
// Insert a bitcast before the instruction to keep SPIR-V code valid
// when there is a type mismatch between results and operand types.
static void validatePtrTypes(const SPIRVSubtarget &STI,
@@ -113,11 +154,7 @@ static void validatePtrTypes(const SPIRVSubtarget &STI,
// Get operand type
MachineFunction *MF = I.getParent()->getParent();
Register OpReg = I.getOperand(OpIdx).getReg();
- SPIRVType *TypeInst = MRI->getVRegDef(OpReg);
- Register OpTypeReg =
- TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter
- ? TypeInst->getOperand(1).getReg()
- : OpReg;
+ Register OpTypeReg = getTypeReg(MRI, OpReg);
SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
return;
@@ -134,30 +171,36 @@ static void validatePtrTypes(const SPIRVSubtarget &STI,
return;
// There is a type mismatch between results and operand types
// and we insert a bitcast before the instruction to keep SPIR-V code valid
- SPIRV::StorageClass::StorageClass SC =
- static_cast<SPIRV::StorageClass::StorageClass>(
- OpType->getOperand(1).getImm());
- MachineIRBuilder MIB(I);
- SPIRVType *NewBaseType =
- IsSameMF ? ResType
- : GR.getOrCreateSPIRVType(
- ResTy, MIB, SPIRV::AccessQualifier::ReadWrite, false);
- SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(NewBaseType, MIB, SC);
+ SPIRVType *NewPtrType =
+ createNewPtrType(GR, I, OpType, IsSameMF, false, ResType, ResTy);
if (!GR.isBitcastCompatible(NewPtrType, OpType))
report_fatal_error(
"insert validation bitcast: incompatible result and operand types");
- Register NewReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
- bool Res = MIB.buildInstr(SPIRV::OpBitcast)
- .addDef(NewReg)
- .addUse(GR.getSPIRVTypeID(NewPtrType))
- .addUse(OpReg)
- .constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(),
- *STI.getRegBankInfo());
- if (!Res)
- report_fatal_error("insert validation bitcast: cannot constrain all uses");
- MRI->setRegClass(NewReg, &SPIRV::IDRegClass);
- GR.assignSPIRVTypeToVReg(NewPtrType, NewReg, MIB.getMF());
- I.getOperand(OpIdx).setReg(NewReg);
+ doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
+}
+
+// Insert a bitcast before OpGroupWaitEvents if the last argument is a pointer
+// that doesn't point to OpTypeEvent.
+static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI,
+ MachineRegisterInfo *MRI,
+ SPIRVGlobalRegistry &GR,
+ MachineInstr &I) {
+ constexpr unsigned OpIdx = 2;
+ MachineFunction *MF = I.getParent()->getParent();
+ Register OpReg = I.getOperand(OpIdx).getReg();
+ Register OpTypeReg = getTypeReg(MRI, OpReg);
+ SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
+ if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
+ return;
+ SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
+ if (!ElemType || ElemType->getOpcode() == SPIRV::OpTypeEvent)
+ return;
+ // Insert a bitcast before the instruction to keep SPIR-V code valid.
+ LLVMContext &Context = MF->getMMI().getModule()->getContext();
+ SPIRVType *NewPtrType =
+ createNewPtrType(GR, I, OpType, false, true, nullptr,
+ TargetExtType::get(Context, "spirv.Event"));
+ doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
}
// Insert a bitcast before the function call instruction to keep SPIR-V code
@@ -336,6 +379,10 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
SPIRV::OpTypeBool))
MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
break;
+ case SPIRV::OpGroupWaitEvents:
+ // OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent>
+ validateGroupWaitEventsPtr(STI, MRI, GR, MI);
+ break;
case SPIRV::OpConstantI: {
SPIRVType *Type = GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg());
if (Type->getOpcode() != SPIRV::OpTypeInt && MI.getOperand(2).isImm() &&