diff options
Diffstat (limited to 'llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp')
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 321 |
1 files changed, 299 insertions, 22 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp index bc159d5..dc717a6 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp @@ -248,6 +248,22 @@ static InstrSignature instrToSignature(const MachineInstr &MI, Register DefReg; InstrSignature Signature{MI.getOpcode()}; for (unsigned i = 0; i < MI.getNumOperands(); ++i) { + // The only decorations that can be applied more than once to a given <id> + // or structure member are UserSemantic(5635), CacheControlLoadINTEL (6442), + // and CacheControlStoreINTEL (6443). For all the rest of decorations, we + // will only add to the signature the Opcode, the id to which it applies, + // and the decoration id, disregarding any decoration flags. This will + // ensure that any subsequent decoration with the same id will be deemed as + // a duplicate. Then, at the call site, we will be able to handle duplicates + // in the best way. + unsigned Opcode = MI.getOpcode(); + if ((Opcode == SPIRV::OpDecorate) && i >= 2) { + unsigned DecorationID = MI.getOperand(1).getImm(); + if (DecorationID != SPIRV::Decoration::UserSemantic && + DecorationID != SPIRV::Decoration::CacheControlLoadINTEL && + DecorationID != SPIRV::Decoration::CacheControlStoreINTEL) + continue; + } const MachineOperand &MO = MI.getOperand(i); size_t h; if (MO.isReg()) { @@ -559,8 +575,54 @@ static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI, MAI.setSkipEmission(&MI); InstrSignature MISign = instrToSignature(MI, MAI, true); auto FoundMI = IS.insert(std::move(MISign)); - if (!FoundMI.second) + if (!FoundMI.second) { + if (MI.getOpcode() == SPIRV::OpDecorate) { + assert(MI.getNumOperands() >= 2 && + "Decoration instructions must have at least 2 operands"); + assert(MSType == SPIRV::MB_Annotations && + "Only OpDecorate instructions can be duplicates"); + // For FPFastMathMode decoration, we need to merge the flags of the + // 
duplicate decoration with the original one, so we need to find the + // original instruction that has the same signature. For the rest of + // instructions, we will simply skip the duplicate. + if (MI.getOperand(1).getImm() != SPIRV::Decoration::FPFastMathMode) + return; // Skip duplicates of other decorations. + + const SPIRV::InstrList &Decorations = MAI.MS[MSType]; + for (const MachineInstr *OrigMI : Decorations) { + if (instrToSignature(*OrigMI, MAI, true) == MISign) { + assert(OrigMI->getNumOperands() == MI.getNumOperands() && + "Original instruction must have the same number of operands"); + assert( + OrigMI->getNumOperands() == 3 && + "FPFastMathMode decoration must have 3 operands for OpDecorate"); + unsigned OrigFlags = OrigMI->getOperand(2).getImm(); + unsigned NewFlags = MI.getOperand(2).getImm(); + if (OrigFlags == NewFlags) + return; // No need to merge, the flags are the same. + + // Emit warning about possible conflict between flags. + unsigned FinalFlags = OrigFlags | NewFlags; + llvm::errs() + << "Warning: Conflicting FPFastMathMode decoration flags " + "in instruction: " + << *OrigMI << "Original flags: " << OrigFlags + << ", new flags: " << NewFlags + << ". They will be merged on a best effort basis, but not " + "validated. Final flags: " + << FinalFlags << "\n"; + MachineInstr *OrigMINonConst = const_cast<MachineInstr *>(OrigMI); + MachineOperand &OrigFlagsOp = OrigMINonConst->getOperand(2); + OrigFlagsOp = + MachineOperand::CreateImm(static_cast<unsigned>(FinalFlags)); + return; // Merge done, so we found a duplicate; don't add it to MAI.MS + } + } + assert(false && "No original instruction found for the duplicate " + "OpDecorate, but we found one in IS."); + } return; // insert failed, so we found a duplicate; don't add it to MAI.MS + } // No duplicates, so add it. 
if (Append) MAI.MS[MSType].push_back(&MI); @@ -934,6 +996,11 @@ static void addOpDecorateReqs(const MachineInstr &MI, unsigned DecIndex, } else if (Dec == SPIRV::Decoration::FPMaxErrorDecorationINTEL) { Reqs.addRequirements(SPIRV::Capability::FPMaxErrorINTEL); Reqs.addExtension(SPIRV::Extension::SPV_INTEL_fp_max_error); + } else if (Dec == SPIRV::Decoration::FPFastMathMode) { + if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) { + Reqs.addRequirements(SPIRV::Capability::FloatControls2); + Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls2); + } } } @@ -1994,10 +2061,13 @@ static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI, // Collect requirements for OpExecutionMode instructions. auto Node = M.getNamedMetadata("spirv.ExecutionMode"); if (Node) { - bool RequireFloatControls = false, RequireFloatControls2 = false, + bool RequireFloatControls = false, RequireIntelFloatControls2 = false, + RequireKHRFloatControls2 = false, VerLower14 = !ST.isAtLeastSPIRVVer(VersionTuple(1, 4)); - bool HasFloatControls2 = + bool HasIntelFloatControls2 = ST.canUseExtension(SPIRV::Extension::SPV_INTEL_float_controls2); + bool HasKHRFloatControls2 = + ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2); for (unsigned i = 0; i < Node->getNumOperands(); i++) { MDNode *MDN = cast<MDNode>(Node->getOperand(i)); const MDOperand &MDOp = MDN->getOperand(1); @@ -2010,7 +2080,6 @@ static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI, switch (EM) { case SPIRV::ExecutionMode::DenormPreserve: case SPIRV::ExecutionMode::DenormFlushToZero: - case SPIRV::ExecutionMode::SignedZeroInfNanPreserve: case SPIRV::ExecutionMode::RoundingModeRTE: case SPIRV::ExecutionMode::RoundingModeRTZ: RequireFloatControls = VerLower14; @@ -2021,8 +2090,28 @@ static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI, case SPIRV::ExecutionMode::RoundingModeRTNINTEL: case SPIRV::ExecutionMode::FloatingPointModeALTINTEL: case 
SPIRV::ExecutionMode::FloatingPointModeIEEEINTEL: - if (HasFloatControls2) { - RequireFloatControls2 = true; + if (HasIntelFloatControls2) { + RequireIntelFloatControls2 = true; + MAI.Reqs.getAndAddRequirements( + SPIRV::OperandCategory::ExecutionModeOperand, EM, ST); + } + break; + case SPIRV::ExecutionMode::FPFastMathDefault: { + if (HasKHRFloatControls2) { + RequireKHRFloatControls2 = true; + MAI.Reqs.getAndAddRequirements( + SPIRV::OperandCategory::ExecutionModeOperand, EM, ST); + } + break; + } + case SPIRV::ExecutionMode::ContractionOff: + case SPIRV::ExecutionMode::SignedZeroInfNanPreserve: + if (HasKHRFloatControls2) { + RequireKHRFloatControls2 = true; + MAI.Reqs.getAndAddRequirements( + SPIRV::OperandCategory::ExecutionModeOperand, + SPIRV::ExecutionMode::FPFastMathDefault, ST); + } else { MAI.Reqs.getAndAddRequirements( SPIRV::OperandCategory::ExecutionModeOperand, EM, ST); } @@ -2037,8 +2126,10 @@ static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI, if (RequireFloatControls && ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls)) MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls); - if (RequireFloatControls2) + if (RequireIntelFloatControls2) MAI.Reqs.addExtension(SPIRV::Extension::SPV_INTEL_float_controls2); + if (RequireKHRFloatControls2) + MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls2); } for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) { const Function &F = *FI; @@ -2078,8 +2169,11 @@ static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI, } } -static unsigned getFastMathFlags(const MachineInstr &I) { +static unsigned getFastMathFlags(const MachineInstr &I, + const SPIRVSubtarget &ST) { unsigned Flags = SPIRV::FPFastMathMode::None; + bool CanUseKHRFloatControls2 = + ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2); if (I.getFlag(MachineInstr::MIFlag::FmNoNans)) Flags |= SPIRV::FPFastMathMode::NotNaN; if (I.getFlag(MachineInstr::MIFlag::FmNoInfs)) @@ 
-2088,12 +2182,45 @@ static unsigned getFastMathFlags(const MachineInstr &I) { Flags |= SPIRV::FPFastMathMode::NSZ; if (I.getFlag(MachineInstr::MIFlag::FmArcp)) Flags |= SPIRV::FPFastMathMode::AllowRecip; - if (I.getFlag(MachineInstr::MIFlag::FmReassoc)) - Flags |= SPIRV::FPFastMathMode::Fast; + if (I.getFlag(MachineInstr::MIFlag::FmContract) && CanUseKHRFloatControls2) + Flags |= SPIRV::FPFastMathMode::AllowContract; + if (I.getFlag(MachineInstr::MIFlag::FmReassoc)) { + if (CanUseKHRFloatControls2) + // LLVM reassoc maps to SPIRV transform, see + // https://github.com/KhronosGroup/SPIRV-Registry/issues/326 for details. + // Because we are enabling AllowTransform, we must enable AllowReassoc and + // AllowContract too, as required by SPIRV spec. Also, we used to map + // MIFlag::FmReassoc to FPFastMathMode::Fast, which should now instead be + // replaced by turning on all the other bits. Therefore, we're + // enabling every bit here except None and Fast. + Flags |= SPIRV::FPFastMathMode::NotNaN | SPIRV::FPFastMathMode::NotInf | + SPIRV::FPFastMathMode::NSZ | SPIRV::FPFastMathMode::AllowRecip | + SPIRV::FPFastMathMode::AllowTransform | + SPIRV::FPFastMathMode::AllowReassoc | + SPIRV::FPFastMathMode::AllowContract; + else + Flags |= SPIRV::FPFastMathMode::Fast; + } + + if (CanUseKHRFloatControls2) { + // Error out if SPIRV::FPFastMathMode::Fast is enabled. + assert(!(Flags & SPIRV::FPFastMathMode::Fast) && + "SPIRV::FPFastMathMode::Fast is deprecated and should not be used " + "anymore."); + + // Error out if AllowTransform is enabled without AllowReassoc and + // AllowContract. 
+ assert((!(Flags & SPIRV::FPFastMathMode::AllowTransform) || + ((Flags & SPIRV::FPFastMathMode::AllowReassoc && + Flags & SPIRV::FPFastMathMode::AllowContract))) && + "SPIRV::FPFastMathMode::AllowTransform requires AllowReassoc and " + "AllowContract flags to be enabled as well."); + } + return Flags; } -static bool isFastMathMathModeAvailable(const SPIRVSubtarget &ST) { +static bool isFastMathModeAvailable(const SPIRVSubtarget &ST) { if (ST.isKernel()) return true; if (ST.getSPIRVVersion() < VersionTuple(1, 2)) @@ -2101,9 +2228,10 @@ static bool isFastMathMathModeAvailable(const SPIRVSubtarget &ST) { return ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2); } -static void handleMIFlagDecoration(MachineInstr &I, const SPIRVSubtarget &ST, - const SPIRVInstrInfo &TII, - SPIRV::RequirementHandler &Reqs) { +static void handleMIFlagDecoration( + MachineInstr &I, const SPIRVSubtarget &ST, const SPIRVInstrInfo &TII, + SPIRV::RequirementHandler &Reqs, const SPIRVGlobalRegistry *GR, + SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec) { if (I.getFlag(MachineInstr::MIFlag::NoSWrap) && TII.canUseNSW(I) && getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand, SPIRV::Decoration::NoSignedWrap, ST, Reqs) @@ -2119,13 +2247,53 @@ static void handleMIFlagDecoration(MachineInstr &I, const SPIRVSubtarget &ST, buildOpDecorate(I.getOperand(0).getReg(), I, TII, SPIRV::Decoration::NoUnsignedWrap, {}); } - if (!TII.canUseFastMathFlags(I)) - return; - unsigned FMFlags = getFastMathFlags(I); - if (FMFlags == SPIRV::FPFastMathMode::None) + if (!TII.canUseFastMathFlags( + I, ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2))) return; - if (isFastMathMathModeAvailable(ST)) { + unsigned FMFlags = getFastMathFlags(I, ST); + if (FMFlags == SPIRV::FPFastMathMode::None) { + // We also need to check if any FPFastMathDefault info was set for the + // types used in this instruction. 
+ if (FPFastMathDefaultInfoVec.empty()) + return; + + // There are three types of instructions that can use fast math flags: + // 1. Arithmetic instructions (FAdd, FMul, FSub, FDiv, FRem, etc.) + // 2. Relational instructions (FCmp, FOrd, FUnord, etc.) + // 3. Extended instructions (ExtInst) + // For arithmetic instructions, the floating point type can be in the + // result type or in the operands, but they all must be the same. + // For the relational and logical instructions, the floating point type + // can only be in the operands 1 and 2, not the result type. Also, the + // operands must have the same type. For the extended instructions, the + // floating point type can be in the result type or in the operands. It's + // unclear if the operands and the result type must be the same. Let's + // assume they must be. Therefore, for 1. and 2., we can check the first + // operand type, and for 3. we can check the result type. + assert(I.getNumOperands() >= 3 && "Expected at least 3 operands"); + Register ResReg = I.getOpcode() == SPIRV::OpExtInst + ? I.getOperand(1).getReg() + : I.getOperand(2).getReg(); + SPIRVType *ResType = GR->getSPIRVTypeForVReg(ResReg, I.getMF()); + const Type *Ty = GR->getTypeForSPIRVType(ResType); + Ty = Ty->isVectorTy() ? cast<VectorType>(Ty)->getElementType() : Ty; + + // Match instruction type with the FPFastMathDefaultInfoVec. 
+ bool Emit = false; + for (SPIRV::FPFastMathDefaultInfo &Elem : FPFastMathDefaultInfoVec) { + if (Ty == Elem.Ty) { + FMFlags = Elem.FastMathFlags; + Emit = Elem.ContractionOff || Elem.SignedZeroInfNanPreserve || + Elem.FPFastMathDefault; + break; + } + } + + if (FMFlags == SPIRV::FPFastMathMode::None && !Emit) + return; + } + if (isFastMathModeAvailable(ST)) { Register DstReg = I.getOperand(0).getReg(); buildOpDecorate(DstReg, I, TII, SPIRV::Decoration::FPFastMathMode, {FMFlags}); @@ -2135,14 +2303,17 @@ static void handleMIFlagDecoration(MachineInstr &I, const SPIRVSubtarget &ST, // Walk all functions and add decorations related to MI flags. static void addDecorations(const Module &M, const SPIRVInstrInfo &TII, MachineModuleInfo *MMI, const SPIRVSubtarget &ST, - SPIRV::ModuleAnalysisInfo &MAI) { + SPIRV::ModuleAnalysisInfo &MAI, + const SPIRVGlobalRegistry *GR) { for (auto F = M.begin(), E = M.end(); F != E; ++F) { MachineFunction *MF = MMI->getMachineFunction(*F); if (!MF) continue; + for (auto &MBB : *MF) for (auto &MI : MBB) - handleMIFlagDecoration(MI, ST, TII, MAI.Reqs); + handleMIFlagDecoration(MI, ST, TII, MAI.Reqs, GR, + MAI.FPFastMathDefaultInfoMap[&(*F)]); } } @@ -2188,6 +2359,111 @@ static void patchPhis(const Module &M, SPIRVGlobalRegistry *GR, } } +static SPIRV::FPFastMathDefaultInfoVector &getOrCreateFPFastMathDefaultInfoVec( + const Module &M, SPIRV::ModuleAnalysisInfo &MAI, const Function *F) { + auto it = MAI.FPFastMathDefaultInfoMap.find(F); + if (it != MAI.FPFastMathDefaultInfoMap.end()) + return it->second; + + // If the map does not contain the entry, create a new one. Initialize it to + // contain all 3 elements sorted by bit width of target type: {half, float, + // double}. 
+ SPIRV::FPFastMathDefaultInfoVector FPFastMathDefaultInfoVec; + FPFastMathDefaultInfoVec.emplace_back(Type::getHalfTy(M.getContext()), + SPIRV::FPFastMathMode::None); + FPFastMathDefaultInfoVec.emplace_back(Type::getFloatTy(M.getContext()), + SPIRV::FPFastMathMode::None); + FPFastMathDefaultInfoVec.emplace_back(Type::getDoubleTy(M.getContext()), + SPIRV::FPFastMathMode::None); + return MAI.FPFastMathDefaultInfoMap[F] = std::move(FPFastMathDefaultInfoVec); +} + +static SPIRV::FPFastMathDefaultInfo &getFPFastMathDefaultInfo( + SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec, + const Type *Ty) { + size_t BitWidth = Ty->getScalarSizeInBits(); + int Index = + SPIRV::FPFastMathDefaultInfoVector::computeFPFastMathDefaultInfoVecIndex( + BitWidth); + assert(Index >= 0 && Index < 3 && + "Expected FPFastMathDefaultInfo for half, float, or double"); + assert(FPFastMathDefaultInfoVec.size() == 3 && + "Expected FPFastMathDefaultInfoVec to have exactly 3 elements"); + return FPFastMathDefaultInfoVec[Index]; +} + +static void collectFPFastMathDefaults(const Module &M, + SPIRV::ModuleAnalysisInfo &MAI, + const SPIRVSubtarget &ST) { + if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) + return; + + // Store the FPFastMathDefaultInfo in the FPFastMathDefaultInfoMap. + // We need the entry point (function) as the key, and the target + // type and flags as the value. + // We also need to check ContractionOff and SignedZeroInfNanPreserve + // execution modes, as they are now deprecated and must be replaced + // with FPFastMathDefaultInfo. 
+ auto Node = M.getNamedMetadata("spirv.ExecutionMode"); + if (!Node) + return; + + for (unsigned i = 0; i < Node->getNumOperands(); i++) { + MDNode *MDN = cast<MDNode>(Node->getOperand(i)); + assert(MDN->getNumOperands() >= 2 && "Expected at least 2 operands"); + const Function *F = cast<Function>( + cast<ConstantAsMetadata>(MDN->getOperand(0))->getValue()); + const auto EM = + cast<ConstantInt>( + cast<ConstantAsMetadata>(MDN->getOperand(1))->getValue()) + ->getZExtValue(); + if (EM == SPIRV::ExecutionMode::FPFastMathDefault) { + assert(MDN->getNumOperands() == 4 && + "Expected 4 operands for FPFastMathDefault"); + + const Type *T = cast<ValueAsMetadata>(MDN->getOperand(2))->getType(); + unsigned Flags = + cast<ConstantInt>( + cast<ConstantAsMetadata>(MDN->getOperand(3))->getValue()) + ->getZExtValue(); + SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec = + getOrCreateFPFastMathDefaultInfoVec(M, MAI, F); + SPIRV::FPFastMathDefaultInfo &Info = + getFPFastMathDefaultInfo(FPFastMathDefaultInfoVec, T); + Info.FastMathFlags = Flags; + Info.FPFastMathDefault = true; + } else if (EM == SPIRV::ExecutionMode::ContractionOff) { + assert(MDN->getNumOperands() == 2 && + "Expected no operands for ContractionOff"); + + // We need to save this info for every tracked FP type, i.e. {half, + // float, double}. + SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec = + getOrCreateFPFastMathDefaultInfoVec(M, MAI, F); + for (SPIRV::FPFastMathDefaultInfo &Info : FPFastMathDefaultInfoVec) { + Info.ContractionOff = true; + } + } else if (EM == SPIRV::ExecutionMode::SignedZeroInfNanPreserve) { + assert(MDN->getNumOperands() == 3 && + "Expected 1 operand for SignedZeroInfNanPreserve"); + unsigned TargetWidth = + cast<ConstantInt>( + cast<ConstantAsMetadata>(MDN->getOperand(2))->getValue()) + ->getZExtValue(); + // We need to save this info only for the FP type with TargetWidth. 
+ SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec = + getOrCreateFPFastMathDefaultInfoVec(M, MAI, F); + int Index = SPIRV::FPFastMathDefaultInfoVector:: + computeFPFastMathDefaultInfoVecIndex(TargetWidth); + assert(Index >= 0 && Index < 3 && + "Expected FPFastMathDefaultInfo for half, float, or double"); + assert(FPFastMathDefaultInfoVec.size() == 3 && + "Expected FPFastMathDefaultInfoVec to have exactly 3 elements"); + FPFastMathDefaultInfoVec[Index].SignedZeroInfNanPreserve = true; + } + } +} + struct SPIRV::ModuleAnalysisInfo SPIRVModuleAnalysis::MAI; void SPIRVModuleAnalysis::getAnalysisUsage(AnalysisUsage &AU) const { @@ -2209,7 +2485,8 @@ bool SPIRVModuleAnalysis::runOnModule(Module &M) { patchPhis(M, GR, *TII, MMI); addMBBNames(M, *TII, MMI, *ST, MAI); - addDecorations(M, *TII, MMI, *ST, MAI); + collectFPFastMathDefaults(M, MAI, *ST); + addDecorations(M, *TII, MMI, *ST, MAI, GR); collectReqs(M, MAI, MMI, *ST); |