//===- DXILOpBuilder.cpp - Helper class for build DIXLOp functions --------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// /// /// \file This file contains class to help build DXIL op functions. //===----------------------------------------------------------------------===// #include "DXILOpBuilder.h" #include "DXILConstants.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Module.h" #include "llvm/Support/DXILABI.h" #include "llvm/Support/ErrorHandling.h" using namespace llvm; using namespace llvm::dxil; constexpr StringLiteral DXILOpNamePrefix = "dx.op."; namespace { enum OverloadKind : uint16_t { VOID = 1, HALF = 1 << 1, FLOAT = 1 << 2, DOUBLE = 1 << 3, I1 = 1 << 4, I8 = 1 << 5, I16 = 1 << 6, I32 = 1 << 7, I64 = 1 << 8, UserDefineType = 1 << 9, ObjectType = 1 << 10, }; } // namespace static const char *getOverloadTypeName(OverloadKind Kind) { switch (Kind) { case OverloadKind::HALF: return "f16"; case OverloadKind::FLOAT: return "f32"; case OverloadKind::DOUBLE: return "f64"; case OverloadKind::I1: return "i1"; case OverloadKind::I8: return "i8"; case OverloadKind::I16: return "i16"; case OverloadKind::I32: return "i32"; case OverloadKind::I64: return "i64"; case OverloadKind::VOID: case OverloadKind::ObjectType: case OverloadKind::UserDefineType: break; } llvm_unreachable("invalid overload type for name"); return "void"; } static OverloadKind getOverloadKind(Type *Ty) { Type::TypeID T = Ty->getTypeID(); switch (T) { case Type::VoidTyID: return OverloadKind::VOID; case Type::HalfTyID: return OverloadKind::HALF; case Type::FloatTyID: return OverloadKind::FLOAT; case Type::DoubleTyID: return OverloadKind::DOUBLE; case Type::IntegerTyID: { IntegerType *ITy = cast(Ty); unsigned Bits = ITy->getBitWidth(); switch (Bits) { case 1: return OverloadKind::I1; case 8: return OverloadKind::I8; case 16: return OverloadKind::I16; case 32: return OverloadKind::I32; case 64: return OverloadKind::I64; default: llvm_unreachable("invalid overload type"); return OverloadKind::VOID; } } case Type::PointerTyID: return OverloadKind::UserDefineType; case Type::StructTyID: return OverloadKind::ObjectType; default: llvm_unreachable("invalid overload type"); return OverloadKind::VOID; } } static std::string getTypeName(OverloadKind Kind, Type *Ty) { if (Kind < OverloadKind::UserDefineType) { return getOverloadTypeName(Kind); } else if (Kind == OverloadKind::UserDefineType) { StructType *ST = cast(Ty); return ST->getStructName().str(); } else if (Kind == OverloadKind::ObjectType) { StructType *ST = cast(Ty); return ST->getStructName().str(); } else { std::string Str; raw_string_ostream OS(Str); Ty->print(OS); return OS.str(); } } // Static properties. struct OpCodeProperty { dxil::OpCode OpCode; // Offset in DXILOpCodeNameTable. unsigned OpCodeNameOffset; dxil::OpCodeClass OpCodeClass; // Offset in DXILOpCodeClassNameTable. unsigned OpCodeClassNameOffset; uint16_t OverloadTys; llvm::Attribute::AttrKind FuncAttr; int OverloadParamIndex; // parameter index which control the overload. // When < 0, should be only 1 overload type. unsigned NumOfParameters; // Number of parameters include return value. unsigned ParameterTableOffset; // Offset in ParameterTable. }; // Include getOpCodeClassName getOpCodeProperty, getOpCodeName and // getOpCodeParameterKind which generated by tableGen. #define DXIL_OP_OPERATION_TABLE #include "DXILOperation.inc" #undef DXIL_OP_OPERATION_TABLE static std::string constructOverloadName(OverloadKind Kind, Type *Ty, const OpCodeProperty &Prop) { if (Kind == OverloadKind::VOID) { return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str(); } return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." + getTypeName(Kind, Ty)) .str(); } static std::string constructOverloadTypeName(OverloadKind Kind, StringRef TypeName) { if (Kind == OverloadKind::VOID) return TypeName.str(); assert(Kind < OverloadKind::UserDefineType && "invalid overload kind"); return (Twine(TypeName) + getOverloadTypeName(Kind)).str(); } static StructType *getOrCreateStructType(StringRef Name, ArrayRef EltTys, LLVMContext &Ctx) { StructType *ST = StructType::getTypeByName(Ctx, Name); if (ST) return ST; return StructType::create(Ctx, EltTys, Name); } static StructType *getResRetType(Type *OverloadTy, LLVMContext &Ctx) { OverloadKind Kind = getOverloadKind(OverloadTy); std::string TypeName = constructOverloadTypeName(Kind, "dx.types.ResRet."); Type *FieldTypes[5] = {OverloadTy, OverloadTy, OverloadTy, OverloadTy, Type::getInt32Ty(Ctx)}; return getOrCreateStructType(TypeName, FieldTypes, Ctx); } static StructType *getHandleType(LLVMContext &Ctx) { return getOrCreateStructType("dx.types.Handle", PointerType::getUnqual(Ctx), Ctx); } static Type *getTypeFromParameterKind(ParameterKind Kind, Type *OverloadTy) { auto &Ctx = OverloadTy->getContext(); switch (Kind) { case ParameterKind::Void: return Type::getVoidTy(Ctx); case ParameterKind::Half: return Type::getHalfTy(Ctx); case ParameterKind::Float: return Type::getFloatTy(Ctx); case ParameterKind::Double: return Type::getDoubleTy(Ctx); case ParameterKind::I1: return Type::getInt1Ty(Ctx); case ParameterKind::I8: return Type::getInt8Ty(Ctx); case ParameterKind::I16: return Type::getInt16Ty(Ctx); case ParameterKind::I32: return Type::getInt32Ty(Ctx); case ParameterKind::I64: return Type::getInt64Ty(Ctx); case ParameterKind::Overload: return OverloadTy; case ParameterKind::ResourceRet: return getResRetType(OverloadTy, Ctx); case ParameterKind::DXILHandle: return getHandleType(Ctx); default: break; } llvm_unreachable("Invalid parameter kind"); return nullptr; } /// Construct DXIL function type. This is the type of a function with /// the following prototype /// OverloadType dx.op..(int opcode, ) /// are constructed from types in Prop. /// \param Prop Structure containing DXIL Operation properties based on /// its specification in DXIL.td. /// \param OverloadTy Return type to be used to construct DXIL function type. static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop, Type *ReturnTy, Type *OverloadTy) { SmallVector ArgTys; auto ParamKinds = getOpCodeParameterKind(*Prop); // Add ReturnTy as return type of the function ArgTys.emplace_back(ReturnTy); // Add DXIL Opcode value type viz., Int32 as first argument ArgTys.emplace_back(Type::getInt32Ty(OverloadTy->getContext())); // Add DXIL Operation parameter types as specified in DXIL properties for (unsigned I = 0; I < Prop->NumOfParameters; ++I) { ParameterKind Kind = ParamKinds[I]; ArgTys.emplace_back(getTypeFromParameterKind(Kind, OverloadTy)); } return FunctionType::get( ArgTys[0], ArrayRef(&ArgTys[1], ArgTys.size() - 1), false); } namespace llvm { namespace dxil { CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy, Type *OverloadTy, SmallVector Args) { const OpCodeProperty *Prop = getOpCodeProperty(OpCode); OverloadKind Kind = getOverloadKind(OverloadTy); if ((Prop->OverloadTys & (uint16_t)Kind) == 0) { report_fatal_error("Invalid Overload Type", /* gen_crash_diag=*/false); } std::string DXILFnName = constructOverloadName(Kind, OverloadTy, *Prop); FunctionCallee DXILFn; // Get the function with name DXILFnName, if one exists if (auto *Func = M.getFunction(DXILFnName)) { DXILFn = FunctionCallee(Func); } else { // Construct and add a function with name DXILFnName FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, ReturnTy, OverloadTy); DXILFn = M.getOrInsertFunction(DXILFnName, DXILOpFT); } return B.CreateCall(DXILFn, Args); } Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT) { const OpCodeProperty *Prop = getOpCodeProperty(OpCode); // If DXIL Op has no overload parameter, just return the // precise return type specified. if (Prop->OverloadParamIndex < 0) { auto &Ctx = FT->getContext(); switch (Prop->OverloadTys) { case OverloadKind::VOID: return Type::getVoidTy(Ctx); case OverloadKind::HALF: return Type::getHalfTy(Ctx); case OverloadKind::FLOAT: return Type::getFloatTy(Ctx); case OverloadKind::DOUBLE: return Type::getDoubleTy(Ctx); case OverloadKind::I1: return Type::getInt1Ty(Ctx); case OverloadKind::I8: return Type::getInt8Ty(Ctx); case OverloadKind::I16: return Type::getInt16Ty(Ctx); case OverloadKind::I32: return Type::getInt32Ty(Ctx); case OverloadKind::I64: return Type::getInt64Ty(Ctx); default: llvm_unreachable("invalid overload type"); return nullptr; } } // Prop->OverloadParamIndex is 0, overload type is FT->getReturnType(). Type *OverloadType = FT->getReturnType(); if (Prop->OverloadParamIndex != 0) { // Skip Return Type. OverloadType = FT->getParamType(Prop->OverloadParamIndex - 1); } auto ParamKinds = getOpCodeParameterKind(*Prop); auto Kind = ParamKinds[Prop->OverloadParamIndex]; // For ResRet and CBufferRet, OverloadTy is in field of StructType. if (Kind == ParameterKind::CBufferRet || Kind == ParameterKind::ResourceRet) { auto *ST = cast(OverloadType); OverloadType = ST->getElementType(0); } return OverloadType; } const char *DXILOpBuilder::getOpCodeName(dxil::OpCode DXILOp) { return ::getOpCodeName(DXILOp); } } // namespace dxil } // namespace llvm