//===-- SPIRVPostLegalizer.cpp - amend info after legalization -*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // The pass partially applies pre-legalization logic to new instructions // inserted as a result of legalization: // - assigns SPIR-V types to registers for new instructions. // - inserts ASSIGN_TYPE pseudo-instructions required for type folding. // //===----------------------------------------------------------------------===// #include "SPIRV.h" #include "SPIRVSubtarget.h" #include "SPIRVUtils.h" #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h" #include "llvm/CodeGen/MachineFrameInfo.h" #include "llvm/IR/IntrinsicsSPIRV.h" #include "llvm/Support/Debug.h" #include #define DEBUG_TYPE "spirv-postlegalizer" using namespace llvm; namespace { class SPIRVPostLegalizer : public MachineFunctionPass { public: static char ID; SPIRVPostLegalizer() : MachineFunctionPass(ID) {} bool runOnMachineFunction(MachineFunction &MF) override; }; } // namespace namespace llvm { // Defined in SPIRVPreLegalizer.cpp. extern void updateRegType(Register Reg, Type *Ty, SPIRVType *SpirvTy, SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB, MachineRegisterInfo &MRI); extern void processInstr(MachineInstr &MI, MachineIRBuilder &MIB, MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR, SPIRVType *KnownResType); } // namespace llvm static SPIRVType *deduceIntTypeFromResult(Register ResVReg, MachineIRBuilder &MIB, SPIRVGlobalRegistry *GR) { const LLT &Ty = MIB.getMRI()->getType(ResVReg); return GR->getOrCreateSPIRVIntegerType(Ty.getScalarSizeInBits(), MIB); } static SPIRVType *deduceTypeFromSingleOperand(MachineInstr *I, MachineIRBuilder &MIB, SPIRVGlobalRegistry *GR, unsigned OpIdx) { Register OpReg = I->getOperand(OpIdx).getReg(); if (SPIRVType *OpType = GR->getSPIRVTypeForVReg(OpReg)) { if (SPIRVType *CompType = GR->getScalarOrVectorComponentType(OpType)) { Register ResVReg = I->getOperand(0).getReg(); const LLT &ResLLT = MIB.getMRI()->getType(ResVReg); if (ResLLT.isVector()) return GR->getOrCreateSPIRVVectorType(CompType, ResLLT.getNumElements(), MIB, false); return CompType; } } return nullptr; } static SPIRVType *deduceTypeFromOperandRange(MachineInstr *I, MachineIRBuilder &MIB, SPIRVGlobalRegistry *GR, unsigned StartOp, unsigned EndOp) { SPIRVType *ResType = nullptr; for (unsigned i = StartOp; i < EndOp; ++i) { if (SPIRVType *Type = deduceTypeFromSingleOperand(I, MIB, GR, i)) { #ifdef EXPENSIVE_CHECKS assert(!ResType || Type == ResType && "Conflicting type from operands."); ResType = Type; #else return Type; #endif } } return ResType; } static SPIRVType *deduceTypeFromResultRegister(MachineInstr *Use, Register UseRegister, SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB) { for (const MachineOperand &MO : Use->defs()) { if (!MO.isReg()) continue; if (SPIRVType *OpType = GR->getSPIRVTypeForVReg(MO.getReg())) { if (SPIRVType *CompType = GR->getScalarOrVectorComponentType(OpType)) { const LLT &ResLLT = MIB.getMRI()->getType(UseRegister); if (ResLLT.isVector()) return GR->getOrCreateSPIRVVectorType( CompType, ResLLT.getNumElements(), MIB, false); return CompType; } } } return nullptr; } static SPIRVType *deducePointerTypeFromResultRegister(MachineInstr *Use, Register UseRegister, SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB) { assert(Use->getOpcode() == TargetOpcode::G_LOAD || Use->getOpcode() == TargetOpcode::G_STORE); Register ValueReg = Use->getOperand(0).getReg(); SPIRVType *ValueType = GR->getSPIRVTypeForVReg(ValueReg); if (!ValueType) return nullptr; return GR->getOrCreateSPIRVPointerType(ValueType, MIB, SPIRV::StorageClass::Function); } static SPIRVType *deduceTypeFromPointerOperand(MachineInstr *Use, Register UseRegister, SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB) { assert(Use->getOpcode() == TargetOpcode::G_LOAD || Use->getOpcode() == TargetOpcode::G_STORE); Register PtrReg = Use->getOperand(1).getReg(); SPIRVType *PtrType = GR->getSPIRVTypeForVReg(PtrReg); if (!PtrType) return nullptr; return GR->getPointeeType(PtrType); } static SPIRVType *deduceTypeFromUses(Register Reg, MachineFunction &MF, SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB) { MachineRegisterInfo &MRI = MF.getRegInfo(); for (MachineInstr &Use : MRI.use_nodbg_instructions(Reg)) { SPIRVType *ResType = nullptr; LLVM_DEBUG(dbgs() << "Looking at use " << Use); switch (Use.getOpcode()) { case TargetOpcode::G_BUILD_VECTOR: case TargetOpcode::G_EXTRACT_VECTOR_ELT: case TargetOpcode::G_UNMERGE_VALUES: case TargetOpcode::G_ADD: case TargetOpcode::G_SUB: case TargetOpcode::G_MUL: case TargetOpcode::G_SDIV: case TargetOpcode::G_UDIV: case TargetOpcode::G_SREM: case TargetOpcode::G_UREM: case TargetOpcode::G_FADD: case TargetOpcode::G_FSUB: case TargetOpcode::G_FMUL: case TargetOpcode::G_FDIV: case TargetOpcode::G_FREM: case TargetOpcode::G_FMA: case TargetOpcode::COPY: case TargetOpcode::G_STRICT_FMA: ResType = deduceTypeFromResultRegister(&Use, Reg, GR, MIB); break; case TargetOpcode::G_LOAD: case TargetOpcode::G_STORE: if (Reg == Use.getOperand(1).getReg()) ResType = deducePointerTypeFromResultRegister(&Use, Reg, GR, MIB); else ResType = deduceTypeFromPointerOperand(&Use, Reg, GR, MIB); break; case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS: case TargetOpcode::G_INTRINSIC: { auto IntrinsicID = cast(Use).getIntrinsicID(); if (IntrinsicID == Intrinsic::spv_insertelt) { if (Reg == Use.getOperand(2).getReg()) ResType = deduceTypeFromResultRegister(&Use, Reg, GR, MIB); } else if (IntrinsicID == Intrinsic::spv_extractelt) { if (Reg == Use.getOperand(2).getReg()) ResType = deduceTypeFromResultRegister(&Use, Reg, GR, MIB); } break; } } if (ResType) { LLVM_DEBUG(dbgs() << "Deduced type from use " << *ResType); return ResType; } } return nullptr; } static SPIRVType *deduceResultTypeFromOperands(MachineInstr *I, SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB) { Register ResVReg = I->getOperand(0).getReg(); switch (I->getOpcode()) { case TargetOpcode::G_CONSTANT: case TargetOpcode::G_ANYEXT: case TargetOpcode::G_SEXT: case TargetOpcode::G_ZEXT: return deduceIntTypeFromResult(ResVReg, MIB, GR); case TargetOpcode::G_BUILD_VECTOR: return deduceTypeFromOperandRange(I, MIB, GR, 1, I->getNumOperands()); case TargetOpcode::G_SHUFFLE_VECTOR: return deduceTypeFromOperandRange(I, MIB, GR, 1, 3); default: if (I->getNumDefs() == 1 && I->getNumOperands() > 1 && I->getOperand(1).isReg()) return deduceTypeFromSingleOperand(I, MIB, GR, 1); return nullptr; } } static bool deduceAndAssignTypeForGUnmerge(MachineInstr *I, MachineFunction &MF, SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB) { MachineRegisterInfo &MRI = MF.getRegInfo(); Register SrcReg = I->getOperand(I->getNumOperands() - 1).getReg(); SPIRVType *ScalarType = nullptr; if (SPIRVType *DefType = GR->getSPIRVTypeForVReg(SrcReg)) { assert(DefType->getOpcode() == SPIRV::OpTypeVector); ScalarType = GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg()); } if (!ScalarType) { // If we could not deduce the type from the source, try to deduce it from // the uses of the results. for (unsigned i = 0; i < I->getNumDefs(); ++i) { Register DefReg = I->getOperand(i).getReg(); ScalarType = deduceTypeFromUses(DefReg, MF, GR, MIB); if (ScalarType) { ScalarType = GR->getScalarOrVectorComponentType(ScalarType); break; } } } if (!ScalarType) return false; for (unsigned i = 0; i < I->getNumOperands(); ++i) { Register DefReg = I->getOperand(i).getReg(); if (GR->getSPIRVTypeForVReg(DefReg)) continue; LLT DefLLT = MRI.getType(DefReg); SPIRVType *ResType = DefLLT.isVector() ? GR->getOrCreateSPIRVVectorType( ScalarType, DefLLT.getNumElements(), *I, *MF.getSubtarget().getInstrInfo()) : ScalarType; setRegClassType(DefReg, ResType, GR, &MRI, MF); } return true; } static bool deduceAndAssignSpirvType(MachineInstr *I, MachineFunction &MF, SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB) { LLVM_DEBUG(dbgs() << "\nProcessing instruction: " << *I); MachineRegisterInfo &MRI = MF.getRegInfo(); Register ResVReg = I->getOperand(0).getReg(); // G_UNMERGE_VALUES is handled separately because it has multiple definitions, // unlike the other instructions which have a single result register. The main // deduction logic is designed for the single-definition case. if (I->getOpcode() == TargetOpcode::G_UNMERGE_VALUES) return deduceAndAssignTypeForGUnmerge(I, MF, GR, MIB); LLVM_DEBUG(dbgs() << "Inferring type from operands\n"); SPIRVType *ResType = deduceResultTypeFromOperands(I, GR, MIB); if (!ResType) { LLVM_DEBUG(dbgs() << "Inferring type from uses\n"); ResType = deduceTypeFromUses(ResVReg, MF, GR, MIB); } if (!ResType) return false; LLVM_DEBUG(dbgs() << "Assigned type to " << *I << ": " << *ResType); GR->assignSPIRVTypeToVReg(ResType, ResVReg, MF); if (!MRI.getRegClassOrNull(ResVReg)) { LLVM_DEBUG(dbgs() << "Updating the register class.\n"); setRegClassType(ResVReg, ResType, GR, &MRI, *GR->CurMF, true); } return true; } static bool requiresSpirvType(MachineInstr &I, SPIRVGlobalRegistry *GR, MachineRegisterInfo &MRI) { LLVM_DEBUG(dbgs() << "Checking if instruction requires a SPIR-V type: " << I;); if (I.getNumDefs() == 0) { LLVM_DEBUG(dbgs() << "Instruction does not have a definition.\n"); return false; } if (!I.isPreISelOpcode()) { LLVM_DEBUG(dbgs() << "Instruction is not a generic instruction.\n"); return false; } Register ResultRegister = I.defs().begin()->getReg(); if (GR->getSPIRVTypeForVReg(ResultRegister)) { LLVM_DEBUG(dbgs() << "Instruction already has a SPIR-V type.\n"); if (!MRI.getRegClassOrNull(ResultRegister)) { LLVM_DEBUG(dbgs() << "Updating the register class.\n"); setRegClassType(ResultRegister, GR->getSPIRVTypeForVReg(ResultRegister), GR, &MRI, *GR->CurMF, true); } return false; } return true; } static void registerSpirvTypeForNewInstructions(MachineFunction &MF, SPIRVGlobalRegistry *GR) { MachineRegisterInfo &MRI = MF.getRegInfo(); SmallVector Worklist; for (MachineBasicBlock &MBB : MF) { for (MachineInstr &I : MBB) { if (requiresSpirvType(I, GR, MRI)) { Worklist.push_back(&I); } } } if (Worklist.empty()) { LLVM_DEBUG(dbgs() << "Initial worklist is empty.\n"); return; } LLVM_DEBUG(dbgs() << "Initial worklist:\n"; for (auto *I : Worklist) { I->dump(); }); bool Changed; do { Changed = false; SmallVector NextWorklist; for (MachineInstr *I : Worklist) { MachineIRBuilder MIB(*I); if (deduceAndAssignSpirvType(I, MF, GR, MIB)) { Changed = true; } else { NextWorklist.push_back(I); } } Worklist = std::move(NextWorklist); LLVM_DEBUG(dbgs() << "Worklist size: " << Worklist.size() << "\n"); } while (Changed); if (Worklist.empty()) return; for (auto *I : Worklist) { MachineIRBuilder MIB(*I); LLVM_DEBUG(dbgs() << "Assigning default type to results in " << *I); for (unsigned Idx = 0; Idx < I->getNumDefs(); ++Idx) { Register ResVReg = I->getOperand(Idx).getReg(); if (GR->getSPIRVTypeForVReg(ResVReg)) continue; const LLT &ResLLT = MRI.getType(ResVReg); SPIRVType *ResType = nullptr; if (ResLLT.isVector()) { SPIRVType *CompType = GR->getOrCreateSPIRVIntegerType( ResLLT.getElementType().getSizeInBits(), MIB); ResType = GR->getOrCreateSPIRVVectorType( CompType, ResLLT.getNumElements(), MIB, false); } else { ResType = GR->getOrCreateSPIRVIntegerType(ResLLT.getSizeInBits(), MIB); } setRegClassType(ResVReg, ResType, GR, &MRI, MF, true); } } } static bool hasAssignType(Register Reg, MachineRegisterInfo &MRI) { for (MachineInstr &UseInstr : MRI.use_nodbg_instructions(Reg)) { if (UseInstr.getOpcode() == SPIRV::ASSIGN_TYPE) { return true; } } return false; } static void generateAssignType(MachineInstr &MI, Register ResultRegister, SPIRVType *ResultType, SPIRVGlobalRegistry *GR, MachineRegisterInfo &MRI) { LLVM_DEBUG(dbgs() << " Adding ASSIGN_TYPE for ResultRegister: " << printReg(ResultRegister, MRI.getTargetRegisterInfo()) << " with type: " << *ResultType); MachineIRBuilder MIB(MI); updateRegType(ResultRegister, nullptr, ResultType, GR, MIB, MRI); // Tablegen definition assumes SPIRV::ASSIGN_TYPE pseudo-instruction is // present after each auto-folded instruction to take a type reference // from. Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(ResultRegister)); const auto *RegClass = GR->getRegClass(ResultType); MRI.setRegClass(NewReg, RegClass); MRI.setRegClass(ResultRegister, RegClass); GR->assignSPIRVTypeToVReg(ResultType, ResultRegister, MIB.getMF()); // This is to make it convenient for Legalizer to get the SPIRVType // when processing the actual MI (i.e. not pseudo one). GR->assignSPIRVTypeToVReg(ResultType, NewReg, MIB.getMF()); // Copy MIFlags from Def to ASSIGN_TYPE instruction. It's required to // keep the flags after instruction selection. const uint32_t Flags = MI.getFlags(); MIB.buildInstr(SPIRV::ASSIGN_TYPE) .addDef(ResultRegister) .addUse(NewReg) .addUse(GR->getSPIRVTypeID(ResultType)) .setMIFlags(Flags); for (unsigned I = 0, E = MI.getNumDefs(); I != E; ++I) { MachineOperand &MO = MI.getOperand(I); if (MO.getReg() == ResultRegister) { MO.setReg(NewReg); break; } } } static void ensureAssignTypeForTypeFolding(MachineFunction &MF, SPIRVGlobalRegistry *GR) { LLVM_DEBUG(dbgs() << "Entering ensureAssignTypeForTypeFolding for function " << MF.getName() << "\n"); MachineRegisterInfo &MRI = MF.getRegInfo(); for (MachineBasicBlock &MBB : MF) { for (MachineInstr &MI : MBB) { if (!isTypeFoldingSupported(MI.getOpcode())) continue; LLVM_DEBUG(dbgs() << "Processing instruction: " << MI); Register ResultRegister = MI.defs().begin()->getReg(); if (hasAssignType(ResultRegister, MRI)) { LLVM_DEBUG(dbgs() << " Instruction already has ASSIGN_TYPE\n"); continue; } SPIRVType *ResultType = GR->getSPIRVTypeForVReg(ResultRegister); assert(ResultType); generateAssignType(MI, ResultRegister, ResultType, GR, MRI); } } } // Do a preorder traversal of the CFG starting from the BB |Start|. // point. Calls |op| on each basic block encountered during the traversal. void visit(MachineFunction &MF, MachineBasicBlock &Start, std::function op) { std::stack ToVisit; SmallPtrSet Seen; ToVisit.push(&Start); Seen.insert(ToVisit.top()); while (ToVisit.size() != 0) { MachineBasicBlock *MBB = ToVisit.top(); ToVisit.pop(); op(MBB); for (auto Succ : MBB->successors()) { if (Seen.contains(Succ)) continue; ToVisit.push(Succ); Seen.insert(Succ); } } } // Do a preorder traversal of the CFG starting from the given function's entry // point. Calls |op| on each basic block encountered during the traversal. void visit(MachineFunction &MF, std::function op) { visit(MF, *MF.begin(), std::move(op)); } bool SPIRVPostLegalizer::runOnMachineFunction(MachineFunction &MF) { // Initialize the type registry. const SPIRVSubtarget &ST = MF.getSubtarget(); SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry(); GR->setCurrentFunc(MF); registerSpirvTypeForNewInstructions(MF, GR); ensureAssignTypeForTypeFolding(MF, GR); return true; } INITIALIZE_PASS(SPIRVPostLegalizer, DEBUG_TYPE, "SPIRV post legalizer", false, false) char SPIRVPostLegalizer::ID = 0; FunctionPass *llvm::createSPIRVPostLegalizerPass() { return new SPIRVPostLegalizer(); }