Diffstat (limited to 'llvm/lib/Target/SPIRV')
-rw-r--r--   llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp              23
-rw-r--r--   llvm/lib/Target/SPIRV/SPIRVBuiltins.td               23
-rw-r--r--   llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp            4
-rw-r--r--   llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp       180
-rw-r--r--   llvm/lib/Target/SPIRV/SPIRVInstrInfo.td               3
-rw-r--r--   llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp   16
-rw-r--r--   llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp        245
-rw-r--r--   llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h            4
-rw-r--r--   llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp         7
-rw-r--r--   llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td        2
-rw-r--r--   llvm/lib/Target/SPIRV/SPIRVTargetTransformInfo.h      2
11 files changed, 492 insertions, 17 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 6ec7544..25cdf72 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -148,6 +148,7 @@ struct ConvertBuiltin {
   bool IsSaturated;
   bool IsRounded;
   bool IsBfloat16;
+  bool IsTF32;
   FPRoundingMode::FPRoundingMode RoundingMode;
 };
 
@@ -230,6 +231,7 @@ std::string lookupBuiltinNameHelper(StringRef DemangledCall,
   // - "__spirv_SubgroupImageMediaBlockReadINTEL"
   // - "__spirv_SubgroupImageMediaBlockWriteINTEL"
   // - "__spirv_Convert"
+  // - "__spirv_Round"
   // - "__spirv_UConvert"
   // - "__spirv_SConvert"
   // - "__spirv_FConvert"
@@ -242,7 +244,7 @@ std::string lookupBuiltinNameHelper(StringRef DemangledCall,
       "SDotKHR|SUDotKHR|SDotAccSatKHR|UDotAccSatKHR|SUDotAccSatKHR|"
       "ReadClockKHR|SubgroupBlockReadINTEL|SubgroupImageBlockReadINTEL|"
       "SubgroupImageMediaBlockReadINTEL|SubgroupImageMediaBlockWriteINTEL|"
-      "Convert|"
+      "Convert|Round|"
       "UConvert|SConvert|FConvert|SatConvert)[^_]*)(_R[^_]*_?(\\w+)?.*)?");
   std::smatch Match;
   if (std::regex_match(BuiltinName, Match, SpvWithR) && Match.size() > 1) {
@@ -697,7 +699,8 @@ static bool buildAtomicStoreInst(const SPIRV::IncomingCall *Call,
                                  MachineIRBuilder &MIRBuilder,
                                  SPIRVGlobalRegistry *GR) {
   if (Call->isSpirvOp())
-    return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call, Register(0));
+    return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call,
+                              Register(0));
 
   Register ScopeRegister =
       buildConstantIntReg32(SPIRV::Scope::Device, MIRBuilder, GR);
@@ -2677,8 +2680,20 @@ static bool generateConvertInst(const StringRef DemangledCall,
     }
   } else if (GR->isScalarOrVectorOfType(Call->ReturnRegister,
                                         SPIRV::OpTypeFloat)) {
-    // Float -> Float
-    Opcode = SPIRV::OpFConvert;
+    if (Builtin->IsTF32) {
+      const auto *ST = static_cast<const SPIRVSubtarget *>(
+          &MIRBuilder.getMF().getSubtarget());
+      if (!ST->canUseExtension(
+              SPIRV::Extension::SPV_INTEL_tensor_float32_conversion))
+        NeedExtMsg = "SPV_INTEL_tensor_float32_conversion";
+      IsRightComponentsNumber =
+          GR->getScalarOrVectorComponentCount(Call->Arguments[0]) ==
+          GR->getScalarOrVectorComponentCount(Call->ReturnRegister);
+      Opcode = SPIRV::OpRoundFToTF32INTEL;
+    } else {
+      // Float -> Float
+      Opcode = SPIRV::OpFConvert;
+    }
   }
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
index ea78dcd..d08560b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -1461,6 +1461,8 @@ class ConvertBuiltin<string name, InstructionSet set> {
   bit IsRounded = !not(!eq(!find(name, "_rt"), -1));
   bit IsBfloat16 = !or(!not(!eq(!find(name, "BF16"), -1)),
                        !not(!eq(!find(name, "bfloat16"), -1)));
+  bit IsTF32 = !or(!not(!eq(!find(name, "TF32"), -1)),
+                   !not(!eq(!find(name, "tensor_float32"), -1)));
   FPRoundingMode RoundingMode = !cond(!not(!eq(!find(name, "_rte"), -1)) : RTE,
                                       !not(!eq(!find(name, "_rtz"), -1)) : RTZ,
                                       !not(!eq(!find(name, "_rtp"), -1)) : RTP,
@@ -1472,7 +1474,7 @@ class ConvertBuiltin<string name, InstructionSet set> {
 def ConvertBuiltins : GenericTable {
   let FilterClass = "ConvertBuiltin";
   let Fields = ["Name", "Set", "IsDestinationSigned", "IsSaturated",
-                "IsRounded", "IsBfloat16", "RoundingMode"];
+                "IsRounded", "IsBfloat16", "IsTF32", "RoundingMode"];
   string TypeOf_Set = "InstructionSet";
   string TypeOf_RoundingMode = "FPRoundingMode";
 }
@@ -1556,6 +1558,25 @@ foreach conv = ["FToBF16INTEL", "BF16ToFINTEL"] in {
   def : ConvertBuiltin<!strconcat("__spirv_Convert", conv),
                        OpenCL_std>;
 }
 
+// cl_intel_tensor_float32_conversions / SPV_INTEL_tensor_float32_conversion
+// Multiclass used to define at the same time both a demangled builtin record
+// and a corresponding convert builtin record.
+multiclass DemangledTF32RoundBuiltin<string name1, string name2> {
+  // Create records for scalar and vector conversions.
+  foreach i = ["", "2", "3", "4", "8", "16"] in {
+    def : DemangledBuiltin<!strconcat("intel_round_", name1, i, name2, i), OpenCL_std, Convert, 1, 1>;
+    def : ConvertBuiltin<!strconcat("intel_round_", name1, i, name2, i), OpenCL_std>;
+  }
+}
+
+defm : DemangledTF32RoundBuiltin<"tensor_float32", "_as_float">;
+defm : DemangledTF32RoundBuiltin<"as_tensor_float32", "_float">;
+
+foreach conv = ["FToTF32INTEL"] in {
+  def : DemangledBuiltin<!strconcat("__spirv_Round", conv), OpenCL_std, Convert, 1, 1>;
+  def : ConvertBuiltin<!strconcat("__spirv_Round", conv), OpenCL_std>;
+}
+
 //===----------------------------------------------------------------------===//
 // Class defining a vector data load/store builtin record used for lowering
 // into OpExtInst instruction.
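With the records above in place, both the OpenCL-style intel_round_* spellings (scalar plus 2/3/4/8/16-wide vector arities) and the SPIR-V friendly wrapper resolve to OpRoundFToTF32INTEL. A minimal usage sketch follows; the wrapper declaration is an assumption about how a frontend might expose it, not part of this commit. TF32 keeps float's 8 exponent bits but only 10 explicit mantissa bits, so the result still travels in a 32-bit float container:

    // Hypothetical declaration; calls demangle to "__spirv_RoundFToTF32INTEL"
    // and lower to OpRoundFToTF32INTEL when the extension is enabled.
    float __spirv_RoundFToTF32INTEL(float);

    float tf32_mul(float a, float b) {
      // Round both operands to TF32 precision before multiplying,
      // mirroring what tensor hardware does internally.
      return __spirv_RoundFToTF32INTEL(a) * __spirv_RoundFToTF32INTEL(b);
    }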
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index 2726203..d9265f4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -102,7 +102,9 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
         {"SPV_INTEL_2d_block_io",
          SPIRV::Extension::Extension::SPV_INTEL_2d_block_io},
         {"SPV_INTEL_int4", SPIRV::Extension::Extension::SPV_INTEL_int4},
         {"SPV_KHR_float_controls2",
-         SPIRV::Extension::Extension::SPV_KHR_float_controls2}};
+         SPIRV::Extension::Extension::SPV_KHR_float_controls2},
+        {"SPV_INTEL_tensor_float32_conversion",
+         SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion}};
 
 bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
                                   StringRef ArgValue,
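The new map entry makes the feature opt-in, like the other extensions listed here; a usage sketch of enabling it on the llc command line (triple and input are placeholders):

    llc -O0 -mtriple=spirv64-unknown-unknown \
        --spirv-ext=+SPV_INTEL_tensor_float32_conversion in.ll -o -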
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 3c631ce..947b574 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -194,6 +194,42 @@ class SPIRVEmitIntrinsics
 
   void useRoundingMode(ConstrainedFPIntrinsic *FPI, IRBuilder<> &B);
 
+  // Tries to walk the type accessed by the given GEP instruction.
+  // For each nested type access, one of the two callbacks is called:
+  //  - OnLiteralIndexing, when the index is a known constant value.
+  //    Parameters:
+  //      PointedType: the pointed type resulting from this indexing.
+  //      Index: index of the element in the parent type.
+  //        If the parent type is an array, this is the index in the array.
+  //        If the parent type is a struct, this is the field index.
+  //  - OnDynamicIndexing, when the index is a non-constant value.
+  //    This callback is only called when indexing into an array.
+  //    Parameters:
+  //      ElementType: the type of the elements stored in the parent array.
+  //      Offset: the Value* containing the byte offset into the array.
+  // Returns true if an error occurred during the walk, false otherwise.
+  bool walkLogicalAccessChain(
+      GetElementPtrInst &GEP,
+      const std::function<void(Type *PointedType, uint64_t Index)>
+          &OnLiteralIndexing,
+      const std::function<void(Type *ElementType, Value *Offset)>
+          &OnDynamicIndexing);
+
+  // Returns the type accessed using the given GEP instruction by relying
+  // on the GEP type.
+  // FIXME: GEP types are not supposed to be used to retrieve the pointed
+  // type. This must be fixed.
+  Type *getGEPType(GetElementPtrInst *GEP);
+
+  // Returns the type accessed using the given GEP instruction by walking
+  // the source type using the GEP indices.
+  // FIXME: without help from the frontend, this method cannot reliably
+  // retrieve the stored type, nor robustly determine the depth of the
+  // type we are accessing.
+  Type *getGEPTypeLogical(GetElementPtrInst *GEP);
+
+  Instruction *buildLogicalAccessChainFromGEP(GetElementPtrInst &GEP);
+
 public:
   static char ID;
   SPIRVEmitIntrinsics(SPIRVTargetMachine *TM = nullptr)
@@ -246,6 +282,17 @@ bool expectIgnoredInIRTranslation(const Instruction *I) {
   }
 }
 
+// Returns the source pointer from `I`, ignoring intermediate ptrcast.
+Value *getPointerRoot(Value *I) {
+  if (auto *II = dyn_cast<IntrinsicInst>(I)) {
+    if (II->getIntrinsicID() == Intrinsic::spv_ptrcast) {
+      Value *V = II->getArgOperand(0);
+      return getPointerRoot(V);
+    }
+  }
+  return I;
+}
+
 } // namespace
 
 char SPIRVEmitIntrinsics::ID = 0;
@@ -555,7 +602,112 @@ void SPIRVEmitIntrinsics::maybeAssignPtrType(Type *&Ty, Value *Op, Type *RefTy,
     Ty = RefTy;
 }
 
-Type *getGEPType(GetElementPtrInst *Ref) {
+bool SPIRVEmitIntrinsics::walkLogicalAccessChain(
+    GetElementPtrInst &GEP,
+    const std::function<void(Type *, uint64_t)> &OnLiteralIndexing,
+    const std::function<void(Type *, Value *)> &OnDynamicIndexing) {
+  // We only rewrite i8* GEPs. Others should be left as-is.
+  // A valid i8* GEP must always have a single index.
+  assert(GEP.getSourceElementType() ==
+         IntegerType::getInt8Ty(CurrF->getContext()));
+  assert(GEP.getNumIndices() == 1);
+
+  auto &DL = CurrF->getDataLayout();
+  Value *Src = getPointerRoot(GEP.getPointerOperand());
+  Type *CurType = deduceElementType(Src, true);
+
+  Value *Operand = *GEP.idx_begin();
+  ConstantInt *CI = dyn_cast<ConstantInt>(Operand);
+  if (!CI) {
+    ArrayType *AT = dyn_cast<ArrayType>(CurType);
+    // Operand is not constant. Either we have an array and accept it, or we
+    // give up.
+    if (AT)
+      OnDynamicIndexing(AT->getElementType(), Operand);
+    return AT == nullptr;
+  }
+
+  assert(CI);
+  uint64_t Offset = CI->getZExtValue();
+
+  do {
+    if (ArrayType *AT = dyn_cast<ArrayType>(CurType)) {
+      uint32_t EltTypeSize = DL.getTypeSizeInBits(AT->getElementType()) / 8;
+      assert(Offset < AT->getNumElements() * EltTypeSize);
+      uint64_t Index = Offset / EltTypeSize;
+      Offset = Offset - (Index * EltTypeSize);
+      CurType = AT->getElementType();
+      OnLiteralIndexing(CurType, Index);
+    } else if (StructType *ST = dyn_cast<StructType>(CurType)) {
+      uint32_t StructSize = DL.getTypeSizeInBits(ST) / 8;
+      assert(Offset < StructSize);
+      (void)StructSize;
+      const auto &STL = DL.getStructLayout(ST);
+      unsigned Element = STL->getElementContainingOffset(Offset);
+      Offset -= STL->getElementOffset(Element);
+      CurType = ST->getElementType(Element);
+      OnLiteralIndexing(CurType, Element);
+    } else {
+      // Vector type indexing should not use GEP.
+      // So if we have an index left, something is wrong. Giving up.
+      return true;
+    }
+  } while (Offset > 0);
+
+  return false;
+}
+
+Instruction *
+SPIRVEmitIntrinsics::buildLogicalAccessChainFromGEP(GetElementPtrInst &GEP) {
+  auto &DL = CurrF->getDataLayout();
+  IRBuilder<> B(GEP.getParent());
+  B.SetInsertPoint(&GEP);
+
+  std::vector<Value *> Indices;
+  Indices.push_back(ConstantInt::get(
+      IntegerType::getInt32Ty(CurrF->getContext()), 0, /* Signed= */ false));
+  walkLogicalAccessChain(
+      GEP,
+      [&Indices, &B](Type *EltType, uint64_t Index) {
+        Indices.push_back(
+            ConstantInt::get(B.getInt64Ty(), Index, /* Signed= */ false));
+      },
+      [&Indices, &B, &DL](Type *EltType, Value *Offset) {
+        uint32_t EltTypeSize = DL.getTypeSizeInBits(EltType) / 8;
+        Value *Index = B.CreateUDiv(
+            Offset, ConstantInt::get(Offset->getType(), EltTypeSize,
+                                     /* Signed= */ false));
+        Indices.push_back(Index);
+      });
+
+  SmallVector<Type *, 2> Types = {GEP.getType(), GEP.getOperand(0)->getType()};
+  SmallVector<Value *, 4> Args;
+  Args.push_back(B.getInt1(GEP.isInBounds()));
+  Args.push_back(GEP.getOperand(0));
+  llvm::append_range(Args, Indices);
+  auto *NewI = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
+  replaceAllUsesWithAndErase(B, &GEP, NewI);
+  return NewI;
+}
+
+Type *SPIRVEmitIntrinsics::getGEPTypeLogical(GetElementPtrInst *GEP) {
+
+  Type *CurType = GEP->getResultElementType();
+
+  bool Interrupted = walkLogicalAccessChain(
+      *GEP, [&CurType](Type *EltType, uint64_t Index) { CurType = EltType; },
+      [&CurType](Type *EltType, Value *Index) { CurType = EltType; });
+
+  return Interrupted ? GEP->getResultElementType() : CurType;
+}
+
+Type *SPIRVEmitIntrinsics::getGEPType(GetElementPtrInst *Ref) {
+  if (Ref->getSourceElementType() ==
+          IntegerType::getInt8Ty(CurrF->getContext()) &&
+      TM->getSubtargetImpl()->isLogicalSPIRV()) {
+    return getGEPTypeLogical(Ref);
+  }
+
   Type *Ty = nullptr;
   // TODO: not sure if GetElementPtrInst::getTypeAtIndex() does anything
   // useful here
@@ -1395,6 +1547,13 @@ Instruction *SPIRVEmitIntrinsics::visitSwitchInst(SwitchInst &I) {
 }
 
 Instruction *SPIRVEmitIntrinsics::visitGetElementPtrInst(GetElementPtrInst &I) {
+  if (I.getSourceElementType() ==
+          IntegerType::getInt8Ty(CurrF->getContext()) &&
+      TM->getSubtargetImpl()->isLogicalSPIRV()) {
+    Instruction *Result = buildLogicalAccessChainFromGEP(I);
+    if (Result)
+      return Result;
+  }
+
   IRBuilder<> B(I.getParent());
   B.SetInsertPoint(&I);
   SmallVector<Type *, 2> Types = {I.getType(), I.getOperand(0)->getType()};
@@ -1588,7 +1747,24 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
   }
   if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(I)) {
     Value *Pointer = GEPI->getPointerOperand();
-    Type *OpTy = GEPI->getSourceElementType();
+    Type *OpTy = nullptr;
+
+    // Knowing the accessed type is mandatory for logical SPIR-V. Sadly,
+    // the GEP source element type should not be used for this purpose, and
+    // the alternative type-scavenging method is not fully working yet.
+    // Physical SPIR-V can work around this, but logical SPIR-V cannot,
+    // hence we still try to rely on the broken type scavenging for logical.
+    bool IsRewrittenGEP =
+        GEPI->getSourceElementType() == IntegerType::getInt8Ty(I->getContext());
+    if (IsRewrittenGEP && TM->getSubtargetImpl()->isLogicalSPIRV()) {
+      Value *Src = getPointerRoot(Pointer);
+      OpTy = GR->findDeducedElementType(Src);
+    }
+
+    // In all cases, fall back to the GEP type if type scavenging failed.
+    if (!OpTy)
+      OpTy = GEPI->getSourceElementType();
+
     replacePointerOperandWithPtrCast(I, Pointer, OpTy, 0, B);
     if (isNestedPointer(OpTy))
      insertTodoType(Pointer);
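A worked example of the rewrite implemented above (hypothetical IR; assumes the deduced element type of %p is the struct shown):

    ; %S = type { [4 x float], i32 }  ; 20 bytes; the frontend emitted:
    ;   %q = getelementptr inbounds i8, ptr %p, i64 12
    ; walkLogicalAccessChain splits offset 12 into field 0 of %S (byte 0),
    ; then array index 12 / sizeof(float) = 3. buildLogicalAccessChainFromGEP
    ; then emits the equivalent typed chain (a leading i32 0, then one index
    ; per nesting level):
    ;   %q = call ptr @llvm.spv.gep(i1 true, ptr %p, i32 0, i64 0, i64 3)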
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index 049ba02..f0b938d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -445,6 +445,9 @@ def OpCrossWorkgroupCastToPtrINTEL : UnOp<"OpCrossWorkgroupCastToPtrINTEL", 5938
 def OpConvertFToBF16INTEL : UnOp<"OpConvertFToBF16INTEL", 6116>;
 def OpConvertBF16ToFINTEL : UnOp<"OpConvertBF16ToFINTEL", 6117>;
 
+// SPV_INTEL_tensor_float32_conversion
+def OpRoundFToTF32INTEL : UnOp<"OpRoundFToTF32INTEL", 6426>;
+
 // 3.42.12 Composite Instructions
 
 def OpVectorExtractDynamic: Op<77, (outs ID:$res),
                   (ins TYPE:$type, vID:$vec, ID:$idx),
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
index 5cda6a0..7505507 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
@@ -74,17 +74,20 @@ class SPIRVLegalizePointerCast : public FunctionPass {
   // Returns the loaded value.
   Value *loadVectorFromVector(IRBuilder<> &B, FixedVectorType *SourceType,
                               FixedVectorType *TargetType, Value *Source) {
-    // We expect the codegen to avoid doing implicit bitcast from a load.
-    assert(TargetType->getElementType() == SourceType->getElementType());
-    assert(TargetType->getNumElements() < SourceType->getNumElements());
-
+    assert(TargetType->getNumElements() <= SourceType->getNumElements());
     LoadInst *NewLoad = B.CreateLoad(SourceType, Source);
     buildAssignType(B, SourceType, NewLoad);
+    Value *AssignValue = NewLoad;
+    if (TargetType->getElementType() != SourceType->getElementType()) {
+      AssignValue = B.CreateIntrinsic(Intrinsic::spv_bitcast,
+                                      {TargetType, SourceType}, {NewLoad});
+      buildAssignType(B, TargetType, AssignValue);
+    }
 
     SmallVector<int> Mask(/* Size= */ TargetType->getNumElements());
     for (unsigned I = 0; I < TargetType->getNumElements(); ++I)
       Mask[I] = I;
-    Value *Output = B.CreateShuffleVector(NewLoad, NewLoad, Mask);
+    Value *Output = B.CreateShuffleVector(AssignValue, AssignValue, Mask);
     buildAssignType(B, TargetType, Output);
     return Output;
   }
@@ -135,8 +138,9 @@ class SPIRVLegalizePointerCast : public FunctionPass {
       Output = loadFirstValueFromAggregate(B, SVT->getElementType(),
                                            OriginalOperand, LI);
     }
-    // Destination is a smaller vector than source.
+    // Destination is a smaller vector than source or different vector type.
     //   - float3 v3 = vector4;
+    //   - float4 v2 = int4;
     else if (SVT && DVT)
       Output = loadVectorFromVector(B, SVT, DVT, OriginalOperand);
     // Destination is the scalar type stored at the start of an aggregate.
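The relaxed loadVectorFromVector path above now also accepts equally-sized vectors with different element types. A sketch of the intended rewrite (hypothetical IR; intrinsic name mangling elided):

    ; loading a <4 x i32> through a pointer known to hold a <4 x float>:
    ;   %v = load <4 x i32>, ptr %p
    ; becomes a load of the stored type, a SPIR-V bitcast, and a shuffle
    ; (size-preserving here; it narrows when TargetType has fewer elements):
    ;   %l = load <4 x float>, ptr %p
    ;   %b = call <4 x i32> @llvm.spv.bitcast(<4 x float> %l)
    ;   %v = shufflevector <4 x i32> %b, <4 x i32> %b,
    ;                      <4 x i32> <i32 0, i32 1, i32 2, i32 3>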
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 721f64a..1995e0f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -335,6 +335,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
     getActionDefinitionsBuilder({G_SMULH, G_UMULH}).alwaysLegal();
   }
 
+  getActionDefinitionsBuilder(G_IS_FPCLASS).custom();
+
   getLegacyLegalizerInfo().computeTables();
   verify(*ST.getInstrInfo());
 }
@@ -355,9 +357,14 @@ static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpvType,
 bool SPIRVLegalizerInfo::legalizeCustom(
     LegalizerHelper &Helper, MachineInstr &MI,
     LostDebugLocObserver &LocObserver) const {
-  auto Opc = MI.getOpcode();
   MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
-  if (Opc == TargetOpcode::G_ICMP) {
+  switch (MI.getOpcode()) {
+  default:
+    // TODO: implement legalization for other opcodes.
+    return true;
+  case TargetOpcode::G_IS_FPCLASS:
+    return legalizeIsFPClass(Helper, MI, LocObserver);
+  case TargetOpcode::G_ICMP: {
     assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
     auto &Op0 = MI.getOperand(2);
     auto &Op1 = MI.getOperand(3);
@@ -378,6 +385,238 @@ bool SPIRVLegalizerInfo::legalizeCustom(
     }
     return true;
   }
-  // TODO: implement legalization for other opcodes.
+  }
+}
+
+// Note this code was copied from LegalizerHelper::lowerISFPCLASS and adjusted
+// to ensure that all instructions created during the lowering have SPIR-V
+// types assigned to them.
+bool SPIRVLegalizerInfo::legalizeIsFPClass(
+    LegalizerHelper &Helper, MachineInstr &MI,
+    LostDebugLocObserver &LocObserver) const {
+  auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
+  FPClassTest Mask = static_cast<FPClassTest>(MI.getOperand(2).getImm());
+
+  auto &MIRBuilder = Helper.MIRBuilder;
+  auto &MF = MIRBuilder.getMF();
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+
+  Type *LLVMDstTy =
+      IntegerType::get(MIRBuilder.getContext(), DstTy.getScalarSizeInBits());
+  if (DstTy.isVector())
+    LLVMDstTy = VectorType::get(LLVMDstTy, DstTy.getElementCount());
+  SPIRVType *SPIRVDstTy = GR->getOrCreateSPIRVType(
+      LLVMDstTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite,
+      /*EmitIR*/ true);
+
+  unsigned BitSize = SrcTy.getScalarSizeInBits();
+  const fltSemantics &Semantics = getFltSemanticForLLT(SrcTy.getScalarType());
+
+  LLT IntTy = LLT::scalar(BitSize);
+  Type *LLVMIntTy = IntegerType::get(MIRBuilder.getContext(), BitSize);
+  if (SrcTy.isVector()) {
+    IntTy = LLT::vector(SrcTy.getElementCount(), IntTy);
+    LLVMIntTy = VectorType::get(LLVMIntTy, SrcTy.getElementCount());
+  }
+  SPIRVType *SPIRVIntTy = GR->getOrCreateSPIRVType(
+      LLVMIntTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite,
+      /*EmitIR*/ true);
+
+  // Clang doesn't support capture of structured bindings:
+  LLT DstTyCopy = DstTy;
+  const auto assignSPIRVTy = [&](MachineInstrBuilder &&MI) {
+    // Assign this MI's (assumed only) destination to one of the two types we
+    // expect: either the G_IS_FPCLASS's destination type, or the integer type
+    // bitcast from the source type.
+    LLT MITy = MRI.getType(MI.getReg(0));
+    assert((MITy == IntTy || MITy == DstTyCopy) &&
+           "Unexpected LLT type while lowering G_IS_FPCLASS");
+    auto *SPVTy = MITy == IntTy ? SPIRVIntTy : SPIRVDstTy;
+    GR->assignSPIRVTypeToVReg(SPVTy, MI.getReg(0), MF);
+    return MI;
+  };
+
+  // Helper to build and assign a constant in one go.
+  const auto buildSPIRVConstant = [&](LLT Ty,
+                                      auto &&C) -> MachineInstrBuilder {
+    if (!Ty.isFixedVector())
+      return assignSPIRVTy(MIRBuilder.buildConstant(Ty, C));
+    auto ScalarC = MIRBuilder.buildConstant(Ty.getScalarType(), C);
+    assert((Ty == IntTy || Ty == DstTyCopy) &&
+           "Unexpected LLT type while lowering constant for G_IS_FPCLASS");
+    SPIRVType *VecEltTy = GR->getOrCreateSPIRVType(
+        (Ty == IntTy ? LLVMIntTy : LLVMDstTy)->getScalarType(), MIRBuilder,
+        SPIRV::AccessQualifier::ReadWrite,
+        /*EmitIR*/ true);
+    GR->assignSPIRVTypeToVReg(VecEltTy, ScalarC.getReg(0), MF);
+    return assignSPIRVTy(MIRBuilder.buildSplatBuildVector(Ty, ScalarC));
+  };
+
+  if (Mask == fcNone) {
+    MIRBuilder.buildCopy(DstReg, buildSPIRVConstant(DstTy, 0));
+    MI.eraseFromParent();
+    return true;
+  }
+  if (Mask == fcAllFlags) {
+    MIRBuilder.buildCopy(DstReg, buildSPIRVConstant(DstTy, 1));
+    MI.eraseFromParent();
+    return true;
+  }
+
+  // Note that rather than creating a COPY here (between a floating-point and
+  // integer type of the same size) we create a SPIR-V bitcast immediately. We
+  // can't create a G_BITCAST because the LLTs are the same, and we can't seem
+  // to correctly lower COPYs to SPIR-V bitcasts at this moment.
+  Register ResVReg = MRI.createGenericVirtualRegister(IntTy);
+  MRI.setRegClass(ResVReg, GR->getRegClass(SPIRVIntTy));
+  GR->assignSPIRVTypeToVReg(SPIRVIntTy, ResVReg, Helper.MIRBuilder.getMF());
+  auto AsInt = MIRBuilder.buildInstr(SPIRV::OpBitcast)
+                   .addDef(ResVReg)
+                   .addUse(GR->getSPIRVTypeID(SPIRVIntTy))
+                   .addUse(SrcReg);
+  AsInt = assignSPIRVTy(std::move(AsInt));
+
+  // Various masks.
+  APInt SignBit = APInt::getSignMask(BitSize);
+  APInt ValueMask = APInt::getSignedMaxValue(BitSize);     // All bits but sign.
+  APInt Inf = APFloat::getInf(Semantics).bitcastToAPInt(); // Exp and int bit.
+  APInt ExpMask = Inf;
+  APInt AllOneMantissa = APFloat::getLargest(Semantics).bitcastToAPInt() & ~Inf;
+  APInt QNaNBitMask =
+      APInt::getOneBitSet(BitSize, AllOneMantissa.getActiveBits() - 1);
+  APInt InversionMask = APInt::getAllOnes(DstTy.getScalarSizeInBits());
+
+  auto SignBitC = buildSPIRVConstant(IntTy, SignBit);
+  auto ValueMaskC = buildSPIRVConstant(IntTy, ValueMask);
+  auto InfC = buildSPIRVConstant(IntTy, Inf);
+  auto ExpMaskC = buildSPIRVConstant(IntTy, ExpMask);
+  auto ZeroC = buildSPIRVConstant(IntTy, 0);
+
+  auto Abs = assignSPIRVTy(MIRBuilder.buildAnd(IntTy, AsInt, ValueMaskC));
+  auto Sign = assignSPIRVTy(
+      MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_NE, DstTy, AsInt, Abs));
+
+  auto Res = buildSPIRVConstant(DstTy, 0);
+
+  const auto appendToRes = [&](MachineInstrBuilder &&ToAppend) {
+    Res = assignSPIRVTy(
+        MIRBuilder.buildOr(DstTyCopy, Res, assignSPIRVTy(std::move(ToAppend))));
+  };
+
+  // Tests that involve more than one class should be processed first.
+  if ((Mask & fcFinite) == fcFinite) {
+    // finite(V) ==> abs(V) u< exp_mask
+    appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, Abs,
+                                     ExpMaskC));
+    Mask &= ~fcFinite;
+  } else if ((Mask & fcFinite) == fcPosFinite) {
+    // finite(V) && V > 0 ==> V u< exp_mask
+    appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, AsInt,
+                                     ExpMaskC));
+    Mask &= ~fcPosFinite;
+  } else if ((Mask & fcFinite) == fcNegFinite) {
+    // finite(V) && V < 0 ==> abs(V) u< exp_mask && signbit == 1
+    auto Cmp = assignSPIRVTy(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT,
+                                                  DstTy, Abs, ExpMaskC));
+    appendToRes(MIRBuilder.buildAnd(DstTy, Cmp, Sign));
+    Mask &= ~fcNegFinite;
+  }
+
+  if (FPClassTest PartialCheck = Mask & (fcZero | fcSubnormal)) {
+    // fcZero | fcSubnormal => test all exponent bits are 0
+    // TODO: Handle sign bit specific cases
+    // TODO: Handle inverted case
+    if (PartialCheck == (fcZero | fcSubnormal)) {
+      auto ExpBits = assignSPIRVTy(MIRBuilder.buildAnd(IntTy, AsInt, ExpMaskC));
+      appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
+                                       ExpBits, ZeroC));
+      Mask &= ~PartialCheck;
+    }
+  }
+
+  // Check for individual classes.
+  if (FPClassTest PartialCheck = Mask & fcZero) {
+    if (PartialCheck == fcPosZero)
+      appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
+                                       AsInt, ZeroC));
+    else if (PartialCheck == fcZero)
+      appendToRes(
+          MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy, Abs, ZeroC));
+    else // fcNegZero
+      appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
+                                       AsInt, SignBitC));
+  }
+
+  if (FPClassTest PartialCheck = Mask & fcSubnormal) {
+    // issubnormal(V) ==> unsigned(abs(V) - 1) u< (all mantissa bits set)
+    // issubnormal(V) && V>0 ==> unsigned(V - 1) u< (all mantissa bits set)
+    auto V = (PartialCheck == fcPosSubnormal) ? AsInt : Abs;
+    auto OneC = buildSPIRVConstant(IntTy, 1);
+    auto VMinusOne = MIRBuilder.buildSub(IntTy, V, OneC);
+    auto SubnormalRes = assignSPIRVTy(
+        MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, VMinusOne,
+                             buildSPIRVConstant(IntTy, AllOneMantissa)));
+    if (PartialCheck == fcNegSubnormal)
+      SubnormalRes = MIRBuilder.buildAnd(DstTy, SubnormalRes, Sign);
+    appendToRes(std::move(SubnormalRes));
+  }
+
+  if (FPClassTest PartialCheck = Mask & fcInf) {
+    if (PartialCheck == fcPosInf)
+      appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
+                                       AsInt, InfC));
+    else if (PartialCheck == fcInf)
+      appendToRes(
+          MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy, Abs, InfC));
+    else { // fcNegInf
+      APInt NegInf = APFloat::getInf(Semantics, true).bitcastToAPInt();
+      auto NegInfC = buildSPIRVConstant(IntTy, NegInf);
+      appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
+                                       AsInt, NegInfC));
+    }
+  }
+
+  if (FPClassTest PartialCheck = Mask & fcNan) {
+    auto InfWithQnanBitC = buildSPIRVConstant(IntTy, Inf | QNaNBitMask);
+    if (PartialCheck == fcNan) {
+      // isnan(V) ==> abs(V) u> int(inf)
+      appendToRes(
+          MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGT, DstTy, Abs, InfC));
+    } else if (PartialCheck == fcQNan) {
+      // isquiet(V) ==> abs(V) u>= (unsigned(Inf) | quiet_bit)
+      appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGE, DstTy, Abs,
+                                       InfWithQnanBitC));
+    } else { // fcSNan
+      // issignaling(V) ==> abs(V) u> unsigned(Inf) &&
+      //                    abs(V) u< (unsigned(Inf) | quiet_bit)
+      auto IsNan = assignSPIRVTy(
+          MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGT, DstTy, Abs, InfC));
+      auto IsNotQnan = assignSPIRVTy(MIRBuilder.buildICmp(
+          CmpInst::Predicate::ICMP_ULT, DstTy, Abs, InfWithQnanBitC));
+      appendToRes(MIRBuilder.buildAnd(DstTy, IsNan, IsNotQnan));
+    }
+  }
+
+  if (FPClassTest PartialCheck = Mask & fcNormal) {
+    // isnormal(V) ==> (0 u< exp u< max_exp) ==> (unsigned(exp-1) u<
+    // (max_exp-1))
+    APInt ExpLSB = ExpMask & ~(ExpMask.shl(1));
+    auto ExpMinusOne = assignSPIRVTy(
+        MIRBuilder.buildSub(IntTy, Abs, buildSPIRVConstant(IntTy, ExpLSB)));
+    APInt MaxExpMinusOne = ExpMask - ExpLSB;
+    auto NormalRes = assignSPIRVTy(
+        MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, ExpMinusOne,
+                             buildSPIRVConstant(IntTy, MaxExpMinusOne)));
+    if (PartialCheck == fcNegNormal)
+      NormalRes = MIRBuilder.buildAnd(DstTy, NormalRes, Sign);
+    else if (PartialCheck == fcPosNormal) {
+      auto PosSign = assignSPIRVTy(MIRBuilder.buildXor(
+          DstTy, Sign, buildSPIRVConstant(DstTy, InversionMask)));
+      NormalRes = MIRBuilder.buildAnd(DstTy, NormalRes, PosSign);
+    }
+    appendToRes(std::move(NormalRes));
+  }
+
+  MIRBuilder.buildCopy(DstReg, Res);
+  MI.eraseFromParent();
   return true;
 }
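For a 32-bit float source, the masks computed above take these values (a worked reference, following directly from the IEEE-754 binary32 layout):

    // BitSize = 32, Semantics = IEEEsingle
    SignBit        = 0x80000000  // APInt::getSignMask(32)
    ValueMask      = 0x7FFFFFFF  // everything but the sign bit
    Inf = ExpMask  = 0x7F800000  // all exponent bits set
    AllOneMantissa = 0x007FFFFF  // mantissa of the largest finite value
    QNaNBitMask    = 0x00400000  // top mantissa bit marks a quiet NaN

So, for example, isnan(V) lowers to (V & 0x7FFFFFFF) u> 0x7F800000 and the fcInf test to (V & 0x7FFFFFFF) == 0x7F800000, matching the logic inherited from LegalizerHelper::lowerISFPCLASS.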
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h
index 6335f21..eeefa42 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h
@@ -30,6 +30,10 @@ public:
   bool legalizeCustom(LegalizerHelper &Helper, MachineInstr &MI,
                       LostDebugLocObserver &LocObserver) const override;
   SPIRVLegalizerInfo(const SPIRVSubtarget &ST);
+
+private:
+  bool legalizeIsFPClass(LegalizerHelper &Helper, MachineInstr &MI,
+                         LostDebugLocObserver &LocObserver) const;
 };
 } // namespace llvm
 #endif // LLVM_LIB_TARGET_SPIRV_SPIRVMACHINELEGALIZER_H
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index ad976e5..0cd9d78 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1564,6 +1564,13 @@ void addInstrRequirements(const MachineInstr &MI,
       Reqs.addCapability(SPIRV::Capability::BFloat16ConversionINTEL);
     }
     break;
+  case SPIRV::OpRoundFToTF32INTEL:
+    if (ST.canUseExtension(
+            SPIRV::Extension::SPV_INTEL_tensor_float32_conversion)) {
+      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_tensor_float32_conversion);
+      Reqs.addCapability(SPIRV::Capability::TensorFloat32RoundingINTEL);
+    }
+    break;
   case SPIRV::OpVariableLengthArrayINTEL:
   case SPIRV::OpSaveMemoryINTEL:
   case SPIRV::OpRestoreMemoryINTEL:
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 548e9b7..614e83a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -320,6 +320,7 @@ defm SPV_INTEL_subgroup_matrix_multiply_accumulate : ExtensionOperand<121>;
 defm SPV_INTEL_2d_block_io : ExtensionOperand<122>;
 defm SPV_INTEL_int4 : ExtensionOperand<123>;
 defm SPV_KHR_float_controls2 : ExtensionOperand<124>;
+defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define Capabilities enum values and at the same time
@@ -529,6 +530,7 @@ defm Subgroup2DBlockTransformINTEL : CapabilityOperand<6229, 0, 0, [SPV_INTEL_2d
 defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d_block_io], [Subgroup2DBlockIOINTEL]>;
 defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>;
 defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>;
+defm TensorFloat32RoundingINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define SourceLanguage enum values and at the same time
diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetTransformInfo.h b/llvm/lib/Target/SPIRV/SPIRVTargetTransformInfo.h
index 43bf6e9..60c4e2d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVTargetTransformInfo.h
+++ b/llvm/lib/Target/SPIRV/SPIRVTargetTransformInfo.h
@@ -59,6 +59,8 @@ public:
                                  Intrinsic::ID IID) const override;
   Value *rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
                                           Value *NewV) const override;
+
+  bool allowVectorElementIndexingUsingGEP() const override { return false; }
 };
 } // namespace llvm
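The TargetTransformInfo override above appears intended to keep middle-end passes from addressing individual vector elements through pointer arithmetic, which the logical-GEP handling in this commit cannot express. A sketch of the two shapes involved (hypothetical IR; the exact passes consulting this hook are not shown in this diff):

    ; shape avoided when the hook returns false: a GEP into a vector element
    ;   %ep = getelementptr <4 x float>, ptr %p, i64 0, i64 2
    ; preferred shape: whole-vector load plus element extract, which maps
    ; cleanly onto SPIR-V composite/vector instructions:
    ;   %v = load <4 x float>, ptr %p
    ;   %e = extractelement <4 x float> %v, i64 2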