diff options
author | Vyacheslav Levytskyy <vyacheslav.levytskyy@intel.com> | 2024-06-03 10:34:05 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-03 10:34:05 +0200 |
commit | ce73e17e3ab5ccfa33a977843e82a9bbfb6b4ce3 (patch) | |
tree | 7ee133b0b602741a7317072305cbedf6655da696 /llvm/lib/Target/SPIRV | |
parent | 264b1b24869eb45463a98d70e9b9e991092acc28 (diff) | |
download | llvm-ce73e17e3ab5ccfa33a977843e82a9bbfb6b4ce3.zip llvm-ce73e17e3ab5ccfa33a977843e82a9bbfb6b4ce3.tar.gz llvm-ce73e17e3ab5ccfa33a977843e82a9bbfb6b4ce3.tar.bz2 |
[SPIR-V] Validate type of the last parameter of OpGroupWaitEvents (#93661)
This PR fixes invalid OpGroupWaitEvents emission to ensure that SPIR-V
Backend inserts a bitcast before OpGroupWaitEvents if the last argument
is a pointer that doesn't point to OpTypeEvent.
Diffstat (limited to 'llvm/lib/Target/SPIRV')
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp | 99 |
1 files changed, 73 insertions, 26 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp index 2bd22bb..5ccbaf1 100644 --- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp @@ -104,6 +104,47 @@ SPIRVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, return std::make_pair(0u, RC); } +inline Register getTypeReg(MachineRegisterInfo *MRI, Register OpReg) { + SPIRVType *TypeInst = MRI->getVRegDef(OpReg); + return TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter + ? TypeInst->getOperand(1).getReg() + : OpReg; +} + +static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, + SPIRVGlobalRegistry &GR, MachineInstr &I, + Register OpReg, unsigned OpIdx, + SPIRVType *NewPtrType) { + Register NewReg = MRI->createGenericVirtualRegister(LLT::scalar(32)); + MachineIRBuilder MIB(I); + bool Res = MIB.buildInstr(SPIRV::OpBitcast) + .addDef(NewReg) + .addUse(GR.getSPIRVTypeID(NewPtrType)) + .addUse(OpReg) + .constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(), + *STI.getRegBankInfo()); + if (!Res) + report_fatal_error("insert validation bitcast: cannot constrain all uses"); + MRI->setRegClass(NewReg, &SPIRV::IDRegClass); + GR.assignSPIRVTypeToVReg(NewPtrType, NewReg, MIB.getMF()); + I.getOperand(OpIdx).setReg(NewReg); +} + +static SPIRVType *createNewPtrType(SPIRVGlobalRegistry &GR, MachineInstr &I, + SPIRVType *OpType, bool ReuseType, + bool EmitIR, SPIRVType *ResType, + const Type *ResTy) { + SPIRV::StorageClass::StorageClass SC = + static_cast<SPIRV::StorageClass::StorageClass>( + OpType->getOperand(1).getImm()); + MachineIRBuilder MIB(I); + SPIRVType *NewBaseType = + ReuseType ? ResType + : GR.getOrCreateSPIRVType( + ResTy, MIB, SPIRV::AccessQualifier::ReadWrite, EmitIR); + return GR.getOrCreateSPIRVPointerType(NewBaseType, MIB, SC); +} + // Insert a bitcast before the instruction to keep SPIR-V code valid // when there is a type mismatch between results and operand types. static void validatePtrTypes(const SPIRVSubtarget &STI, @@ -113,11 +154,7 @@ static void validatePtrTypes(const SPIRVSubtarget &STI, // Get operand type MachineFunction *MF = I.getParent()->getParent(); Register OpReg = I.getOperand(OpIdx).getReg(); - SPIRVType *TypeInst = MRI->getVRegDef(OpReg); - Register OpTypeReg = - TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter - ? TypeInst->getOperand(1).getReg() - : OpReg; + Register OpTypeReg = getTypeReg(MRI, OpReg); SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF); if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer) return; @@ -134,30 +171,36 @@ static void validatePtrTypes(const SPIRVSubtarget &STI, return; // There is a type mismatch between results and operand types // and we insert a bitcast before the instruction to keep SPIR-V code valid - SPIRV::StorageClass::StorageClass SC = - static_cast<SPIRV::StorageClass::StorageClass>( - OpType->getOperand(1).getImm()); - MachineIRBuilder MIB(I); - SPIRVType *NewBaseType = - IsSameMF ? ResType - : GR.getOrCreateSPIRVType( - ResTy, MIB, SPIRV::AccessQualifier::ReadWrite, false); - SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(NewBaseType, MIB, SC); + SPIRVType *NewPtrType = + createNewPtrType(GR, I, OpType, IsSameMF, false, ResType, ResTy); if (!GR.isBitcastCompatible(NewPtrType, OpType)) report_fatal_error( "insert validation bitcast: incompatible result and operand types"); - Register NewReg = MRI->createGenericVirtualRegister(LLT::scalar(32)); - bool Res = MIB.buildInstr(SPIRV::OpBitcast) - .addDef(NewReg) - .addUse(GR.getSPIRVTypeID(NewPtrType)) - .addUse(OpReg) - .constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(), - *STI.getRegBankInfo()); - if (!Res) - report_fatal_error("insert validation bitcast: cannot constrain all uses"); - MRI->setRegClass(NewReg, &SPIRV::IDRegClass); - GR.assignSPIRVTypeToVReg(NewPtrType, NewReg, MIB.getMF()); - I.getOperand(OpIdx).setReg(NewReg); + doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType); +} + +// Insert a bitcast before OpGroupWaitEvents if the last argument is a pointer +// that doesn't point to OpTypeEvent. +static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI, + MachineRegisterInfo *MRI, + SPIRVGlobalRegistry &GR, + MachineInstr &I) { + constexpr unsigned OpIdx = 2; + MachineFunction *MF = I.getParent()->getParent(); + Register OpReg = I.getOperand(OpIdx).getReg(); + Register OpTypeReg = getTypeReg(MRI, OpReg); + SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF); + if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer) + return; + SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg()); + if (!ElemType || ElemType->getOpcode() == SPIRV::OpTypeEvent) + return; + // Insert a bitcast before the instruction to keep SPIR-V code valid. + LLVMContext &Context = MF->getMMI().getModule()->getContext(); + SPIRVType *NewPtrType = + createNewPtrType(GR, I, OpType, false, true, nullptr, + TargetExtType::get(Context, "spirv.Event")); + doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType); } // Insert a bitcast before the function call instruction to keep SPIR-V code @@ -336,6 +379,10 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const { SPIRV::OpTypeBool)) MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual)); break; + case SPIRV::OpGroupWaitEvents: + // OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent> + validateGroupWaitEventsPtr(STI, MRI, GR, MI); + break; case SPIRV::OpConstantI: { SPIRVType *Type = GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg()); if (Type->getOpcode() != SPIRV::OpTypeInt && MI.getOperand(2).isImm() && |