//===------- CGHLSLBuiltins.cpp - Emit LLVM Code for HLSL builtins --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This contains code to emit HLSL Builtin calls as LLVM code.
//
//===----------------------------------------------------------------------===//

#include "CGBuiltin.h"
#include "CGHLSLRuntime.h"
#include "CodeGenFunction.h"

using namespace clang;
using namespace CodeGen;
using namespace llvm;

static Value *handleAsDoubleBuiltin(CodeGenFunction &CGF, const CallExpr *E) {
  assert((E->getArg(0)->getType()->hasUnsignedIntegerRepresentation() &&
          E->getArg(1)->getType()->hasUnsignedIntegerRepresentation()) &&
         "asdouble operands types mismatch");
  Value *OpLowBits = CGF.EmitScalarExpr(E->getArg(0));
  Value *OpHighBits = CGF.EmitScalarExpr(E->getArg(1));

  llvm::Type *ResultType = CGF.DoubleTy;
  int N = 1;
  if (auto *VTy = E->getArg(0)->getType()->getAs<clang::VectorType>()) {
    N = VTy->getNumElements();
    ResultType = llvm::FixedVectorType::get(CGF.DoubleTy, N);
  }

  if (CGF.CGM.getTarget().getTriple().isDXIL())
    return CGF.Builder.CreateIntrinsic(
        /*ReturnType=*/ResultType, Intrinsic::dx_asdouble,
        {OpLowBits, OpHighBits}, nullptr, "hlsl.asdouble");

  if (!E->getArg(0)->getType()->isVectorType()) {
    OpLowBits = CGF.Builder.CreateVectorSplat(1, OpLowBits);
    OpHighBits = CGF.Builder.CreateVectorSplat(1, OpHighBits);
  }

  llvm::SmallVector<int> Mask;
  for (int i = 0; i < N; i++) {
    Mask.push_back(i);
    Mask.push_back(i + N);
  }

  Value *BitVec = CGF.Builder.CreateShuffleVector(OpLowBits, OpHighBits, Mask);

  return CGF.Builder.CreateBitCast(BitVec, ResultType);
}

static Value *handleHlslClip(const CallExpr *E, CodeGenFunction *CGF) {
  Value *Op0 = CGF->EmitScalarExpr(E->getArg(0));

  Constant *FZeroConst = ConstantFP::getZero(CGF->FloatTy);
  Value *CMP;
  Value *LastInstr;

  if (const auto *VecTy =
          E->getArg(0)->getType()->getAs<clang::VectorType>()) {
    FZeroConst = ConstantVector::getSplat(
        ElementCount::getFixed(VecTy->getNumElements()), FZeroConst);
    auto *FCompInst = CGF->Builder.CreateFCmpOLT(Op0, FZeroConst);
    CMP = CGF->Builder.CreateIntrinsic(
        CGF->Builder.getInt1Ty(), CGF->CGM.getHLSLRuntime().getAnyIntrinsic(),
        {FCompInst});
  } else {
    CMP = CGF->Builder.CreateFCmpOLT(Op0, FZeroConst);
  }

  if (CGF->CGM.getTarget().getTriple().isDXIL()) {
    LastInstr = CGF->Builder.CreateIntrinsic(Intrinsic::dx_discard, {CMP});
  } else if (CGF->CGM.getTarget().getTriple().isSPIRV()) {
    BasicBlock *LT0 = CGF->createBasicBlock("lt0", CGF->CurFn);
    BasicBlock *End = CGF->createBasicBlock("end", CGF->CurFn);

    CGF->Builder.CreateCondBr(CMP, LT0, End);

    CGF->Builder.SetInsertPoint(LT0);

    CGF->Builder.CreateIntrinsic(Intrinsic::spv_discard, {});
    LastInstr = CGF->Builder.CreateBr(End);

    CGF->Builder.SetInsertPoint(End);
  } else {
    llvm_unreachable("Backend Codegen not supported.");
  }

  return LastInstr;
}

static Value *handleHlslSplitdouble(const CallExpr *E, CodeGenFunction *CGF) {
  Value *Op0 = CGF->EmitScalarExpr(E->getArg(0));
  const auto *OutArg1 = dyn_cast<HLSLOutArgExpr>(E->getArg(1));
  const auto *OutArg2 = dyn_cast<HLSLOutArgExpr>(E->getArg(2));

  CallArgList Args;
  LValue Op1TmpLValue =
      CGF->EmitHLSLOutArgExpr(OutArg1, Args, OutArg1->getType());
  LValue Op2TmpLValue =
      CGF->EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());

  if (CGF->getTarget().getCXXABI().areArgsDestroyedLeftToRightInCallee())
    Args.reverseWritebacks();

  Value *LowBits = nullptr;
  Value *HighBits = nullptr;
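  // DXIL has a dedicated dx.splitdouble intrinsic; for other targets the
  // split is open-coded below by bitcasting the double (or double vector) to
  // an i32 vector and extracting/shuffling the even (low) and odd (high)
  // lanes. For a scalar double input this produces IR roughly of the form
  // (illustrative only):
  //   %bc = bitcast double %x to <2 x i32>
  //   %lo = extractelement <2 x i32> %bc, i64 0
  //   %hi = extractelement <2 x i32> %bc, i64 1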
  if (CGF->CGM.getTarget().getTriple().isDXIL()) {

    llvm::Type *RetElementTy = CGF->Int32Ty;
    if (auto *Op0VecTy = E->getArg(0)->getType()->getAs<clang::VectorType>())
      RetElementTy = llvm::VectorType::get(
          CGF->Int32Ty, ElementCount::getFixed(Op0VecTy->getNumElements()));
    auto *RetTy = llvm::StructType::get(RetElementTy, RetElementTy);

    CallInst *CI = CGF->Builder.CreateIntrinsic(
        RetTy, Intrinsic::dx_splitdouble, {Op0}, nullptr, "hlsl.splitdouble");

    LowBits = CGF->Builder.CreateExtractValue(CI, 0);
    HighBits = CGF->Builder.CreateExtractValue(CI, 1);

  } else {
    // For non-DXIL targets we generate the instructions.

    if (!Op0->getType()->isVectorTy()) {
      FixedVectorType *DestTy = FixedVectorType::get(CGF->Int32Ty, 2);
      Value *Bitcast = CGF->Builder.CreateBitCast(Op0, DestTy);

      LowBits = CGF->Builder.CreateExtractElement(Bitcast, (uint64_t)0);
      HighBits = CGF->Builder.CreateExtractElement(Bitcast, 1);
    } else {

      int NumElements = 1;
      if (const auto *VecTy =
              E->getArg(0)->getType()->getAs<clang::VectorType>())
        NumElements = VecTy->getNumElements();

      FixedVectorType *Uint32VecTy =
          FixedVectorType::get(CGF->Int32Ty, NumElements * 2);
      Value *Uint32Vec = CGF->Builder.CreateBitCast(Op0, Uint32VecTy);
      if (NumElements == 1) {
        LowBits = CGF->Builder.CreateExtractElement(Uint32Vec, (uint64_t)0);
        HighBits = CGF->Builder.CreateExtractElement(Uint32Vec, 1);
      } else {
        SmallVector<int> EvenMask, OddMask;
        for (int I = 0, E = NumElements; I != E; ++I) {
          EvenMask.push_back(I * 2);
          OddMask.push_back(I * 2 + 1);
        }
        LowBits = CGF->Builder.CreateShuffleVector(Uint32Vec, EvenMask);
        HighBits = CGF->Builder.CreateShuffleVector(Uint32Vec, OddMask);
      }
    }
  }
  CGF->Builder.CreateStore(LowBits, Op1TmpLValue.getAddress());
  auto *LastInst =
      CGF->Builder.CreateStore(HighBits, Op2TmpLValue.getAddress());
  CGF->EmitWritebacks(Args);
  return LastInst;
}

// Return dot product intrinsic that corresponds to the QT scalar type
static Intrinsic::ID getDotProductIntrinsic(CGHLSLRuntime &RT, QualType QT) {
  if (QT->isFloatingType())
    return RT.getFDotIntrinsic();
  if (QT->isSignedIntegerType())
    return RT.getSDotIntrinsic();
  assert(QT->isUnsignedIntegerType());
  return RT.getUDotIntrinsic();
}

static Intrinsic::ID getFirstBitHighIntrinsic(CGHLSLRuntime &RT, QualType QT) {
  if (QT->hasSignedIntegerRepresentation()) {
    return RT.getFirstBitSHighIntrinsic();
  }

  assert(QT->hasUnsignedIntegerRepresentation());
  return RT.getFirstBitUHighIntrinsic();
}

// Return wave active sum that corresponds to the QT scalar type
static Intrinsic::ID getWaveActiveSumIntrinsic(llvm::Triple::ArchType Arch,
                                               CGHLSLRuntime &RT,
                                               QualType QT) {
  switch (Arch) {
  case llvm::Triple::spirv:
    return Intrinsic::spv_wave_reduce_sum;
  case llvm::Triple::dxil: {
    if (QT->isUnsignedIntegerType())
      return Intrinsic::dx_wave_reduce_usum;
    return Intrinsic::dx_wave_reduce_sum;
  }
  default:
    llvm_unreachable("Intrinsic WaveActiveSum"
                     " not supported by target architecture");
  }
}

// Return wave active max that corresponds to the QT scalar type
static Intrinsic::ID getWaveActiveMaxIntrinsic(llvm::Triple::ArchType Arch,
                                               CGHLSLRuntime &RT,
                                               QualType QT) {
  switch (Arch) {
  case llvm::Triple::spirv:
    if (QT->isUnsignedIntegerType())
      return Intrinsic::spv_wave_reduce_umax;
    return Intrinsic::spv_wave_reduce_max;
  case llvm::Triple::dxil: {
    if (QT->isUnsignedIntegerType())
      return Intrinsic::dx_wave_reduce_umax;
    return Intrinsic::dx_wave_reduce_max;
  }
  default:
    llvm_unreachable("Intrinsic WaveActiveMax"
                     " not supported by target architecture");
  }
}

// Returns the mangled name for a builtin function that the SPIR-V backend
// will expand into a spec constant.
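// Conceptually, this is the name that a declaration along the lines of
//   SpecConstantType __spirv_SpecConstant(int SpecId,
//                                         SpecConstantType DefaultValue);
// would mangle to; the exact string depends on the mangling scheme returned
// by ASTContext::createMangleContext().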
static std::string getSpecConstantFunctionName(clang::QualType SpecConstantType,
                                               ASTContext &Context) {
  // The parameter types for our conceptual intrinsic function.
  QualType ClangParamTypes[] = {Context.IntTy, SpecConstantType};

  // Create a temporary FunctionDecl for the builtin function. It won't be
  // added to the AST.
  FunctionProtoType::ExtProtoInfo EPI;
  QualType FnType =
      Context.getFunctionType(SpecConstantType, ClangParamTypes, EPI);
  DeclarationName FuncName = &Context.Idents.get("__spirv_SpecConstant");
  FunctionDecl *FnDeclForMangling = FunctionDecl::Create(
      Context, Context.getTranslationUnitDecl(), SourceLocation(),
      SourceLocation(), FuncName, FnType, /*TSI=*/nullptr, SC_Extern);

  // Attach the created parameter declarations to the function declaration.
  SmallVector<ParmVarDecl *> ParamDecls;
  for (QualType ParamType : ClangParamTypes) {
    ParmVarDecl *PD = ParmVarDecl::Create(
        Context, FnDeclForMangling, SourceLocation(), SourceLocation(),
        /*IdentifierInfo=*/nullptr, ParamType, /*TSI=*/nullptr, SC_None,
        /*DefaultArg=*/nullptr);
    ParamDecls.push_back(PD);
  }
  FnDeclForMangling->setParams(ParamDecls);

  // Get the mangled name.
  std::string Name;
  llvm::raw_string_ostream MangledNameStream(Name);
  std::unique_ptr<MangleContext> Mangler(Context.createMangleContext());
  Mangler->mangleName(FnDeclForMangling, MangledNameStream);
  MangledNameStream.flush();

  return Name;
}

Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
                                            const CallExpr *E,
                                            ReturnValueSlot ReturnValue) {
  if (!getLangOpts().HLSL)
    return nullptr;

  switch (BuiltinID) {
  case Builtin::BI__builtin_hlsl_adduint64: {
    Value *OpA = EmitScalarExpr(E->getArg(0));
    Value *OpB = EmitScalarExpr(E->getArg(1));
    QualType Arg0Ty = E->getArg(0)->getType();
    uint64_t NumElements =
        Arg0Ty->castAs<clang::VectorType>()->getNumElements();
    assert(Arg0Ty == E->getArg(1)->getType() &&
           "AddUint64 operand types must match");
    assert(Arg0Ty->hasIntegerRepresentation() &&
           "AddUint64 operands must have an integer representation");
    assert((NumElements == 2 || NumElements == 4) &&
           "AddUint64 operands must have 2 or 4 elements");

    llvm::Value *LowA;
    llvm::Value *HighA;
    llvm::Value *LowB;
    llvm::Value *HighB;

    // Obtain low and high words of inputs A and B
    if (NumElements == 2) {
      LowA = Builder.CreateExtractElement(OpA, (uint64_t)0, "LowA");
      HighA = Builder.CreateExtractElement(OpA, (uint64_t)1, "HighA");
      LowB = Builder.CreateExtractElement(OpB, (uint64_t)0, "LowB");
      HighB = Builder.CreateExtractElement(OpB, (uint64_t)1, "HighB");
    } else {
      LowA = Builder.CreateShuffleVector(OpA, {0, 2}, "LowA");
      HighA = Builder.CreateShuffleVector(OpA, {1, 3}, "HighA");
      LowB = Builder.CreateShuffleVector(OpB, {0, 2}, "LowB");
      HighB = Builder.CreateShuffleVector(OpB, {1, 3}, "HighB");
    }

    // Use an uadd_with_overflow to compute the sum of low words and obtain a
    // carry value
    llvm::Value *Carry;
    llvm::Value *LowSum = EmitOverflowIntrinsic(
        *this, Intrinsic::uadd_with_overflow, LowA, LowB, Carry);
    llvm::Value *ZExtCarry =
        Builder.CreateZExt(Carry, HighA->getType(), "CarryZExt");

    // Sum the high words and the carry
    llvm::Value *HighSum = Builder.CreateAdd(HighA, HighB, "HighSum");
    llvm::Value *HighSumPlusCarry =
        Builder.CreateAdd(HighSum, ZExtCarry, "HighSumPlusCarry");

    if (NumElements == 4) {
      return Builder.CreateShuffleVector(LowSum, HighSumPlusCarry, {0, 2, 1, 3},
                                         "hlsl.AddUint64");
    }
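    // Two-element case: assemble the result vector from the low-word sum and
    // the carried high-word sum.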
    llvm::Value *Result = PoisonValue::get(OpA->getType());
    Result = Builder.CreateInsertElement(Result, LowSum, (uint64_t)0,
                                         "hlsl.AddUint64.upto0");
    Result = Builder.CreateInsertElement(Result, HighSumPlusCarry, (uint64_t)1,
                                         "hlsl.AddUint64");
    return Result;
  }
  case Builtin::BI__builtin_hlsl_resource_getpointer: {
    Value *HandleOp = EmitScalarExpr(E->getArg(0));
    Value *IndexOp = EmitScalarExpr(E->getArg(1));
    llvm::Type *RetTy = ConvertType(E->getType());
    return Builder.CreateIntrinsic(
        RetTy, CGM.getHLSLRuntime().getCreateResourceGetPointerIntrinsic(),
        ArrayRef<Value *>{HandleOp, IndexOp});
  }
  case Builtin::BI__builtin_hlsl_resource_uninitializedhandle: {
    llvm::Type *HandleTy = CGM.getTypes().ConvertType(E->getType());
    return llvm::PoisonValue::get(HandleTy);
  }
  case Builtin::BI__builtin_hlsl_resource_handlefrombinding: {
    llvm::Type *HandleTy = CGM.getTypes().ConvertType(E->getType());
    Value *RegisterOp = EmitScalarExpr(E->getArg(1));
    Value *SpaceOp = EmitScalarExpr(E->getArg(2));
    Value *RangeOp = EmitScalarExpr(E->getArg(3));
    Value *IndexOp = EmitScalarExpr(E->getArg(4));
    Value *Name = EmitScalarExpr(E->getArg(5));
    // FIXME: NonUniformResourceIndex bit is not yet implemented
    // (llvm/llvm-project#135452)
    Value *NonUniform =
        llvm::ConstantInt::get(llvm::Type::getInt1Ty(getLLVMContext()), false);
    llvm::Intrinsic::ID IntrinsicID =
        CGM.getHLSLRuntime().getCreateHandleFromBindingIntrinsic();
    SmallVector<Value *> Args{SpaceOp, RegisterOp, RangeOp,
                              IndexOp, NonUniform, Name};
    return Builder.CreateIntrinsic(HandleTy, IntrinsicID, Args);
  }
  case Builtin::BI__builtin_hlsl_resource_handlefromimplicitbinding: {
    llvm::Type *HandleTy = CGM.getTypes().ConvertType(E->getType());
    Value *SpaceOp = EmitScalarExpr(E->getArg(1));
    Value *RangeOp = EmitScalarExpr(E->getArg(2));
    Value *IndexOp = EmitScalarExpr(E->getArg(3));
    Value *OrderID = EmitScalarExpr(E->getArg(4));
    Value *Name = EmitScalarExpr(E->getArg(5));
    // FIXME: NonUniformResourceIndex bit is not yet implemented
    // (llvm/llvm-project#135452)
    Value *NonUniform =
        llvm::ConstantInt::get(llvm::Type::getInt1Ty(getLLVMContext()), false);
    llvm::Intrinsic::ID IntrinsicID =
        CGM.getHLSLRuntime().getCreateHandleFromImplicitBindingIntrinsic();
    SmallVector<Value *> Args{OrderID, SpaceOp,    RangeOp,
                              IndexOp, NonUniform, Name};
    return Builder.CreateIntrinsic(HandleTy, IntrinsicID, Args);
  }
  case Builtin::BI__builtin_hlsl_all: {
    Value *Op0 = EmitScalarExpr(E->getArg(0));
    return Builder.CreateIntrinsic(
        /*ReturnType=*/llvm::Type::getInt1Ty(getLLVMContext()),
        CGM.getHLSLRuntime().getAllIntrinsic(), ArrayRef<Value *>{Op0}, nullptr,
        "hlsl.all");
  }
  case Builtin::BI__builtin_hlsl_and: {
    Value *Op0 = EmitScalarExpr(E->getArg(0));
    Value *Op1 = EmitScalarExpr(E->getArg(1));
    return Builder.CreateAnd(Op0, Op1, "hlsl.and");
  }
  case Builtin::BI__builtin_hlsl_or: {
    Value *Op0 = EmitScalarExpr(E->getArg(0));
    Value *Op1 = EmitScalarExpr(E->getArg(1));
    return Builder.CreateOr(Op0, Op1, "hlsl.or");
  }
  case Builtin::BI__builtin_hlsl_any: {
    Value *Op0 = EmitScalarExpr(E->getArg(0));
    return Builder.CreateIntrinsic(
        /*ReturnType=*/llvm::Type::getInt1Ty(getLLVMContext()),
        CGM.getHLSLRuntime().getAnyIntrinsic(), ArrayRef<Value *>{Op0}, nullptr,
        "hlsl.any");
  }
  case Builtin::BI__builtin_hlsl_asdouble:
    return handleAsDoubleBuiltin(*this, E);
  case Builtin::BI__builtin_hlsl_elementwise_clamp: {
    Value *OpX = EmitScalarExpr(E->getArg(0));
    Value *OpMin = EmitScalarExpr(E->getArg(1));
    Value *OpMax = EmitScalarExpr(E->getArg(2));

    QualType Ty = E->getArg(0)->getType();
    if (auto *VecTy = Ty->getAs<clang::VectorType>())
      Ty = VecTy->getElementType();

    Intrinsic::ID Intr;
    if (Ty->isFloatingType()) {
      Intr = CGM.getHLSLRuntime().getNClampIntrinsic();
    } else if (Ty->isUnsignedIntegerType()) {
      Intr = CGM.getHLSLRuntime().getUClampIntrinsic();
    } else {
      assert(Ty->isSignedIntegerType());
      Intr = CGM.getHLSLRuntime().getSClampIntrinsic();
    }
    return Builder.CreateIntrinsic(
        /*ReturnType=*/OpX->getType(), Intr,
        ArrayRef<Value *>{OpX, OpMin, OpMax}, nullptr, "hlsl.clamp");
  }
  case Builtin::BI__builtin_hlsl_crossf16:
  case Builtin::BI__builtin_hlsl_crossf32: {
    Value *Op0 = EmitScalarExpr(E->getArg(0));
    Value *Op1 = EmitScalarExpr(E->getArg(1));
    assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
           E->getArg(1)->getType()->hasFloatingRepresentation() &&
           "cross operands must have a float representation");
    // make sure each vector has exactly 3 elements
    assert(E->getArg(0)->getType()->castAs<clang::VectorType>()
                   ->getNumElements() == 3 &&
           E->getArg(1)->getType()->castAs<clang::VectorType>()
                   ->getNumElements() == 3 &&
           "input vectors must have 3 elements each");
    return Builder.CreateIntrinsic(
        /*ReturnType=*/Op0->getType(),
        CGM.getHLSLRuntime().getCrossIntrinsic(), ArrayRef<Value *>{Op0, Op1},
        nullptr, "hlsl.cross");
  }
  case Builtin::BI__builtin_hlsl_dot: {
    Value *Op0 = EmitScalarExpr(E->getArg(0));
    Value *Op1 = EmitScalarExpr(E->getArg(1));
    llvm::Type *T0 = Op0->getType();
    llvm::Type *T1 = Op1->getType();

    // If the arguments are scalars, just emit a multiply
    if (!T0->isVectorTy() && !T1->isVectorTy()) {
      if (T0->isFloatingPointTy())
        return Builder.CreateFMul(Op0, Op1, "hlsl.dot");

      if (T0->isIntegerTy())
        return Builder.CreateMul(Op0, Op1, "hlsl.dot");

      llvm_unreachable(
          "Scalar dot product is only supported on ints and floats.");
    }
    // For vectors, validate types and emit the appropriate intrinsic
    assert(CGM.getContext().hasSameUnqualifiedType(E->getArg(0)->getType(),
                                                   E->getArg(1)->getType()) &&
           "Dot product operands must have the same type.");
    auto *VecTy0 = E->getArg(0)->getType()->castAs<clang::VectorType>();
    assert(VecTy0 && "Dot product argument must be a vector.");

    return Builder.CreateIntrinsic(
        /*ReturnType=*/T0->getScalarType(),
        getDotProductIntrinsic(CGM.getHLSLRuntime(), VecTy0->getElementType()),
        ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.dot");
  }
  case Builtin::BI__builtin_hlsl_dot4add_i8packed: {
    Value *X = EmitScalarExpr(E->getArg(0));
    Value *Y = EmitScalarExpr(E->getArg(1));
    Value *Acc = EmitScalarExpr(E->getArg(2));

    Intrinsic::ID ID = CGM.getHLSLRuntime().getDot4AddI8PackedIntrinsic();
    // Note that the argument order disagrees between the builtin and the
    // intrinsic here.
    return Builder.CreateIntrinsic(
        /*ReturnType=*/Acc->getType(), ID, ArrayRef<Value *>{Acc, X, Y},
        nullptr, "hlsl.dot4add.i8packed");
  }
  case Builtin::BI__builtin_hlsl_dot4add_u8packed: {
    Value *X = EmitScalarExpr(E->getArg(0));
    Value *Y = EmitScalarExpr(E->getArg(1));
    Value *Acc = EmitScalarExpr(E->getArg(2));

    Intrinsic::ID ID = CGM.getHLSLRuntime().getDot4AddU8PackedIntrinsic();
    // Note that the argument order disagrees between the builtin and the
    // intrinsic here.
    return Builder.CreateIntrinsic(
        /*ReturnType=*/Acc->getType(), ID, ArrayRef<Value *>{Acc, X, Y},
        nullptr, "hlsl.dot4add.u8packed");
  }
  case Builtin::BI__builtin_hlsl_elementwise_firstbithigh: {
    Value *X = EmitScalarExpr(E->getArg(0));

    return Builder.CreateIntrinsic(
        /*ReturnType=*/ConvertType(E->getType()),
        getFirstBitHighIntrinsic(CGM.getHLSLRuntime(),
                                 E->getArg(0)->getType()),
        ArrayRef<Value *>{X}, nullptr, "hlsl.firstbithigh");
  }
  case Builtin::BI__builtin_hlsl_elementwise_firstbitlow: {
    Value *X = EmitScalarExpr(E->getArg(0));

    return Builder.CreateIntrinsic(
        /*ReturnType=*/ConvertType(E->getType()),
        CGM.getHLSLRuntime().getFirstBitLowIntrinsic(), ArrayRef<Value *>{X},
        nullptr, "hlsl.firstbitlow");
  }
  case Builtin::BI__builtin_hlsl_lerp: {
    Value *X = EmitScalarExpr(E->getArg(0));
    Value *Y = EmitScalarExpr(E->getArg(1));
    Value *S = EmitScalarExpr(E->getArg(2));
    if (!E->getArg(0)->getType()->hasFloatingRepresentation())
      llvm_unreachable("lerp operand must have a float representation");
    return Builder.CreateIntrinsic(
        /*ReturnType=*/X->getType(), CGM.getHLSLRuntime().getLerpIntrinsic(),
        ArrayRef<Value *>{X, Y, S}, nullptr, "hlsl.lerp");
  }
  case Builtin::BI__builtin_hlsl_normalize: {
    Value *X = EmitScalarExpr(E->getArg(0));

    assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
           "normalize operand must have a float representation");

    return Builder.CreateIntrinsic(
        /*ReturnType=*/X->getType(),
        CGM.getHLSLRuntime().getNormalizeIntrinsic(), ArrayRef<Value *>{X},
        nullptr, "hlsl.normalize");
  }
  case Builtin::BI__builtin_hlsl_elementwise_degrees: {
    Value *X = EmitScalarExpr(E->getArg(0));

    assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
           "degree operand must have a float representation");

    return Builder.CreateIntrinsic(
        /*ReturnType=*/X->getType(),
        CGM.getHLSLRuntime().getDegreesIntrinsic(), ArrayRef<Value *>{X},
        nullptr, "hlsl.degrees");
  }
  case Builtin::BI__builtin_hlsl_elementwise_frac: {
    Value *Op0 = EmitScalarExpr(E->getArg(0));
    if (!E->getArg(0)->getType()->hasFloatingRepresentation())
      llvm_unreachable("frac operand must have a float representation");
    return Builder.CreateIntrinsic(
        /*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getFracIntrinsic(),
        ArrayRef<Value *>{Op0}, nullptr, "hlsl.frac");
  }
  case Builtin::BI__builtin_hlsl_elementwise_isinf: {
    Value *Op0 = EmitScalarExpr(E->getArg(0));
    llvm::Type *Xty = Op0->getType();
    llvm::Type *retType = llvm::Type::getInt1Ty(this->getLLVMContext());
    if (Xty->isVectorTy()) {
      auto *XVecTy = E->getArg(0)->getType()->castAs<clang::VectorType>();
      retType = llvm::VectorType::get(
          retType, ElementCount::getFixed(XVecTy->getNumElements()));
    }
    if (!E->getArg(0)->getType()->hasFloatingRepresentation())
      llvm_unreachable("isinf operand must have a float representation");
    return Builder.CreateIntrinsic(retType, Intrinsic::dx_isinf,
                                   ArrayRef<Value *>{Op0}, nullptr,
                                   "dx.isinf");
  }
  case Builtin::BI__builtin_hlsl_mad: {
    Value *M = EmitScalarExpr(E->getArg(0));
    Value *A = EmitScalarExpr(E->getArg(1));
    Value *B = EmitScalarExpr(E->getArg(2));
    if (E->getArg(0)->getType()->hasFloatingRepresentation())
      return Builder.CreateIntrinsic(
          /*ReturnType=*/M->getType(), Intrinsic::fmuladd,
          ArrayRef<Value *>{M, A, B}, nullptr, "hlsl.fmad");

    if (E->getArg(0)->getType()->hasSignedIntegerRepresentation()) {
      if (CGM.getTarget().getTriple().getArch() == llvm::Triple::dxil)
        return Builder.CreateIntrinsic(
            /*ReturnType=*/M->getType(), Intrinsic::dx_imad,
            ArrayRef<Value *>{M, A, B}, nullptr, "dx.imad");

      Value *Mul = Builder.CreateNSWMul(M, A);
      return Builder.CreateNSWAdd(Mul, B);
    }
    assert(E->getArg(0)->getType()->hasUnsignedIntegerRepresentation());
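    // Unsigned integer mad: DXIL has a dedicated umad intrinsic; other
    // targets expand to a NUW multiply followed by a NUW add.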
    if (CGM.getTarget().getTriple().getArch() == llvm::Triple::dxil)
      return Builder.CreateIntrinsic(
          /*ReturnType=*/M->getType(), Intrinsic::dx_umad,
          ArrayRef<Value *>{M, A, B}, nullptr, "dx.umad");

    Value *Mul = Builder.CreateNUWMul(M, A);
    return Builder.CreateNUWAdd(Mul, B);
  }
  case Builtin::BI__builtin_hlsl_elementwise_rcp: {
    Value *Op0 = EmitScalarExpr(E->getArg(0));
    if (!E->getArg(0)->getType()->hasFloatingRepresentation())
      llvm_unreachable("rcp operand must have a float representation");
    llvm::Type *Ty = Op0->getType();
    llvm::Type *EltTy = Ty->getScalarType();
    Constant *One =
        Ty->isVectorTy()
            ? ConstantVector::getSplat(
                  ElementCount::getFixed(
                      cast<FixedVectorType>(Ty)->getNumElements()),
                  ConstantFP::get(EltTy, 1.0))
            : ConstantFP::get(EltTy, 1.0);
    return Builder.CreateFDiv(One, Op0, "hlsl.rcp");
  }
  case Builtin::BI__builtin_hlsl_elementwise_rsqrt: {
    Value *Op0 = EmitScalarExpr(E->getArg(0));
    if (!E->getArg(0)->getType()->hasFloatingRepresentation())
      llvm_unreachable("rsqrt operand must have a float representation");
    return Builder.CreateIntrinsic(
        /*ReturnType=*/Op0->getType(),
        CGM.getHLSLRuntime().getRsqrtIntrinsic(), ArrayRef<Value *>{Op0},
        nullptr, "hlsl.rsqrt");
  }
  case Builtin::BI__builtin_hlsl_elementwise_saturate: {
    Value *Op0 = EmitScalarExpr(E->getArg(0));
    assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
           "saturate operand must have a float representation");
    return Builder.CreateIntrinsic(
        /*ReturnType=*/Op0->getType(),
        CGM.getHLSLRuntime().getSaturateIntrinsic(), ArrayRef<Value *>{Op0},
        nullptr, "hlsl.saturate");
  }
  case Builtin::BI__builtin_hlsl_select: {
    Value *OpCond = EmitScalarExpr(E->getArg(0));
    RValue RValTrue = EmitAnyExpr(E->getArg(1));
    Value *OpTrue =
        RValTrue.isScalar()
            ? RValTrue.getScalarVal()
            : RValTrue.getAggregatePointer(E->getArg(1)->getType(), *this);
    RValue RValFalse = EmitAnyExpr(E->getArg(2));
    Value *OpFalse =
        RValFalse.isScalar()
            ? RValFalse.getScalarVal()
            : RValFalse.getAggregatePointer(E->getArg(2)->getType(), *this);
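    // If the result type is a vector but an operand was emitted as a scalar,
    // splat the scalar so the select sees matching vector types.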
    if (auto *VTy = E->getType()->getAs<clang::VectorType>()) {
      if (!OpTrue->getType()->isVectorTy())
        OpTrue =
            Builder.CreateVectorSplat(VTy->getNumElements(), OpTrue, "splat");
      if (!OpFalse->getType()->isVectorTy())
        OpFalse =
            Builder.CreateVectorSplat(VTy->getNumElements(), OpFalse, "splat");
    }

    Value *SelectVal =
        Builder.CreateSelect(OpCond, OpTrue, OpFalse, "hlsl.select");
    if (!RValTrue.isScalar())
      Builder.CreateStore(SelectVal, ReturnValue.getAddress(),
                          ReturnValue.isVolatile());

    return SelectVal;
  }
  case Builtin::BI__builtin_hlsl_step: {
    Value *Op0 = EmitScalarExpr(E->getArg(0));
    Value *Op1 = EmitScalarExpr(E->getArg(1));
    assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
           E->getArg(1)->getType()->hasFloatingRepresentation() &&
           "step operands must have a float representation");
    return Builder.CreateIntrinsic(
        /*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getStepIntrinsic(),
        ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.step");
  }
  case Builtin::BI__builtin_hlsl_wave_active_all_true: {
    Value *Op = EmitScalarExpr(E->getArg(0));
    assert(Op->getType()->isIntegerTy(1) &&
           "Intrinsic WaveActiveAllTrue operand must be a bool");

    Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveAllTrueIntrinsic();
    return EmitRuntimeCall(
        Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op});
  }
  case Builtin::BI__builtin_hlsl_wave_active_any_true: {
    Value *Op = EmitScalarExpr(E->getArg(0));
    assert(Op->getType()->isIntegerTy(1) &&
           "Intrinsic WaveActiveAnyTrue operand must be a bool");

    Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveAnyTrueIntrinsic();
    return EmitRuntimeCall(
        Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op});
  }
  case Builtin::BI__builtin_hlsl_wave_active_count_bits: {
    Value *OpExpr = EmitScalarExpr(E->getArg(0));
    Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveCountBitsIntrinsic();

    return EmitRuntimeCall(
        Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID),
        ArrayRef<Value *>{OpExpr});
  }
  case Builtin::BI__builtin_hlsl_wave_active_sum: {
    // Due to the use of variadic arguments, explicitly retrieve the argument.
    Value *OpExpr = EmitScalarExpr(E->getArg(0));
    Intrinsic::ID IID = getWaveActiveSumIntrinsic(
        getTarget().getTriple().getArch(), CGM.getHLSLRuntime(),
        E->getArg(0)->getType());

    return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
                               &CGM.getModule(), IID, {OpExpr->getType()}),
                           ArrayRef<Value *>{OpExpr}, "hlsl.wave.active.sum");
  }
  case Builtin::BI__builtin_hlsl_wave_active_max: {
    // Due to the use of variadic arguments, explicitly retrieve the argument.
    Value *OpExpr = EmitScalarExpr(E->getArg(0));
    Intrinsic::ID IID = getWaveActiveMaxIntrinsic(
        getTarget().getTriple().getArch(), CGM.getHLSLRuntime(),
        E->getArg(0)->getType());

    return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
                               &CGM.getModule(), IID, {OpExpr->getType()}),
                           ArrayRef<Value *>{OpExpr}, "hlsl.wave.active.max");
  }
  case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
    // We don't define a SPIR-V intrinsic; this is a SPIR-V built-in defined
    // in SPIRVBuiltins.td, so we manually pick the matching name for the
    // DirectX intrinsic and the demangled builtin name for SPIR-V.
    switch (CGM.getTarget().getTriple().getArch()) {
    case llvm::Triple::dxil:
      return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
          &CGM.getModule(), Intrinsic::dx_wave_getlaneindex));
    case llvm::Triple::spirv:
      return EmitRuntimeCall(CGM.CreateRuntimeFunction(
          llvm::FunctionType::get(IntTy, {}, false),
          "__hlsl_wave_get_lane_index", {}, false, true));
    default:
      llvm_unreachable(
          "Intrinsic WaveGetLaneIndex not supported by target architecture");
    }
  }
  case Builtin::BI__builtin_hlsl_wave_is_first_lane: {
    Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveIsFirstLaneIntrinsic();
    return EmitRuntimeCall(
        Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID));
  }
  case Builtin::BI__builtin_hlsl_wave_get_lane_count: {
    Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveGetLaneCountIntrinsic();
    return EmitRuntimeCall(
        Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID));
  }
  case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
    // Due to the use of variadic arguments, we must explicitly retrieve them
    // and create our function type.
    Value *OpExpr = EmitScalarExpr(E->getArg(0));
    Value *OpIndex = EmitScalarExpr(E->getArg(1));
    return EmitRuntimeCall(
        Intrinsic::getOrInsertDeclaration(
            &CGM.getModule(),
            CGM.getHLSLRuntime().getWaveReadLaneAtIntrinsic(),
            {OpExpr->getType()}),
        ArrayRef<Value *>{OpExpr, OpIndex}, "hlsl.wave.readlane");
  }
  case Builtin::BI__builtin_hlsl_elementwise_sign: {
    auto *Arg0 = E->getArg(0);
    Value *Op0 = EmitScalarExpr(Arg0);
    llvm::Type *Xty = Op0->getType();
    llvm::Type *retType = llvm::Type::getInt32Ty(this->getLLVMContext());
    if (Xty->isVectorTy()) {
      auto *XVecTy = Arg0->getType()->castAs<clang::VectorType>();
      retType = llvm::VectorType::get(
          retType, ElementCount::getFixed(XVecTy->getNumElements()));
    }
    assert((Arg0->getType()->hasFloatingRepresentation() ||
            Arg0->getType()->hasIntegerRepresentation()) &&
           "sign operand must have a float or int representation");

    if (Arg0->getType()->hasUnsignedIntegerRepresentation()) {
      Value *Cmp = Builder.CreateICmpEQ(Op0, ConstantInt::get(Xty, 0));
      return Builder.CreateSelect(Cmp, ConstantInt::get(retType, 0),
                                  ConstantInt::get(retType, 1), "hlsl.sign");
    }

    return Builder.CreateIntrinsic(
        retType, CGM.getHLSLRuntime().getSignIntrinsic(),
        ArrayRef<Value *>{Op0}, nullptr, "hlsl.sign");
  }
  case Builtin::BI__builtin_hlsl_elementwise_radians: {
    Value *Op0 = EmitScalarExpr(E->getArg(0));
    assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
           "radians operand must have a float representation");
    return Builder.CreateIntrinsic(
        /*ReturnType=*/Op0->getType(),
        CGM.getHLSLRuntime().getRadiansIntrinsic(), ArrayRef<Value *>{Op0},
        nullptr, "hlsl.radians");
  }
  case Builtin::BI__builtin_hlsl_buffer_update_counter: {
    Value *ResHandle = EmitScalarExpr(E->getArg(0));
    Value *Offset = EmitScalarExpr(E->getArg(1));
    Value *OffsetI8 = Builder.CreateIntCast(Offset, Int8Ty, true);
    return Builder.CreateIntrinsic(
        /*ReturnType=*/Offset->getType(),
        CGM.getHLSLRuntime().getBufferUpdateCounterIntrinsic(),
        ArrayRef<Value *>{ResHandle, OffsetI8}, nullptr);
  }
  case Builtin::BI__builtin_hlsl_elementwise_splitdouble: {
    assert((E->getArg(0)->getType()->hasFloatingRepresentation() &&
            E->getArg(1)->getType()->hasUnsignedIntegerRepresentation() &&
            E->getArg(2)->getType()->hasUnsignedIntegerRepresentation()) &&
           "asuint operands types mismatch");
    return handleHlslSplitdouble(E, this);
  }
  case Builtin::BI__builtin_hlsl_elementwise_clip:
    assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
           "clip operands types mismatch");
    return handleHlslClip(E, this);
  case Builtin::BI__builtin_hlsl_group_memory_barrier_with_group_sync: {
    Intrinsic::ID ID =
        CGM.getHLSLRuntime().getGroupMemoryBarrierWithGroupSyncIntrinsic();
    return EmitRuntimeCall(
        Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID));
  }
  case Builtin::BI__builtin_get_spirv_spec_constant_bool:
  case Builtin::BI__builtin_get_spirv_spec_constant_short:
  case Builtin::BI__builtin_get_spirv_spec_constant_ushort:
  case Builtin::BI__builtin_get_spirv_spec_constant_int:
  case Builtin::BI__builtin_get_spirv_spec_constant_uint:
  case Builtin::BI__builtin_get_spirv_spec_constant_longlong:
  case Builtin::BI__builtin_get_spirv_spec_constant_ulonglong:
  case Builtin::BI__builtin_get_spirv_spec_constant_half:
  case Builtin::BI__builtin_get_spirv_spec_constant_float:
  case Builtin::BI__builtin_get_spirv_spec_constant_double: {
    llvm::Function *SpecConstantFn = getSpecConstantFunction(E->getType());
    llvm::Value *SpecId = EmitScalarExpr(E->getArg(0));
    llvm::Value *DefaultVal = EmitScalarExpr(E->getArg(1));
    llvm::Value *Args[] = {SpecId, DefaultVal};
    return Builder.CreateCall(SpecConstantFn, Args);
  }
  }
  return nullptr;
}

llvm::Function *clang::CodeGen::CodeGenFunction::getSpecConstantFunction(
    const clang::QualType &SpecConstantType) {
  // Find or create the declaration for the function.
  llvm::Module *M = &CGM.getModule();
  std::string MangledName =
      getSpecConstantFunctionName(SpecConstantType, getContext());
  llvm::Function *SpecConstantFn = M->getFunction(MangledName);
  if (!SpecConstantFn) {
    llvm::Type *IntType = ConvertType(getContext().IntTy);
    llvm::Type *RetTy = ConvertType(SpecConstantType);
    llvm::Type *ArgTypes[] = {IntType, RetTy};
    llvm::FunctionType *FnTy = llvm::FunctionType::get(RetTy, ArgTypes, false);
    SpecConstantFn = llvm::Function::Create(
        FnTy, llvm::GlobalValue::ExternalLinkage, MangledName, M);
  }
  return SpecConstantFn;
}