//===- DXILOpLower.cpp - Lowering LLVM intrinsic to DIXLOp function -------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// /// /// \file This file contains passes and utilities to lower llvm intrinsic call /// to DXILOp function call. //===----------------------------------------------------------------------===// #include "DXILConstants.h" #include "DirectX.h" #include "llvm/ADT/SmallVector.h" #include "llvm/CodeGen/Passes.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" #include "llvm/Pass.h" #include "llvm/Support/ErrorHandling.h" #define DEBUG_TYPE "dxil-op-lower" using namespace llvm; using namespace llvm::DXIL; constexpr StringLiteral DXILOpNamePrefix = "dx.op."; enum OverloadKind : uint16_t { VOID = 1, HALF = 1 << 1, FLOAT = 1 << 2, DOUBLE = 1 << 3, I1 = 1 << 4, I8 = 1 << 5, I16 = 1 << 6, I32 = 1 << 7, I64 = 1 << 8, UserDefineType = 1 << 9, ObjectType = 1 << 10, }; static const char *getOverloadTypeName(OverloadKind Kind) { switch (Kind) { case OverloadKind::HALF: return "f16"; case OverloadKind::FLOAT: return "f32"; case OverloadKind::DOUBLE: return "f64"; case OverloadKind::I1: return "i1"; case OverloadKind::I8: return "i8"; case OverloadKind::I16: return "i16"; case OverloadKind::I32: return "i32"; case OverloadKind::I64: return "i64"; case OverloadKind::VOID: case OverloadKind::ObjectType: case OverloadKind::UserDefineType: break; } llvm_unreachable("invalid overload type for name"); return "void"; } static OverloadKind getOverloadKind(Type *Ty) { Type::TypeID T = Ty->getTypeID(); switch (T) { case Type::VoidTyID: return OverloadKind::VOID; case Type::HalfTyID: return OverloadKind::HALF; case Type::FloatTyID: return OverloadKind::FLOAT; case Type::DoubleTyID: return OverloadKind::DOUBLE; case Type::IntegerTyID: { IntegerType *ITy = cast(Ty); unsigned Bits = ITy->getBitWidth(); switch (Bits) { case 1: return OverloadKind::I1; case 8: return OverloadKind::I8; case 16: return OverloadKind::I16; case 32: return OverloadKind::I32; case 64: return OverloadKind::I64; default: llvm_unreachable("invalid overload type"); return OverloadKind::VOID; } } case Type::PointerTyID: return OverloadKind::UserDefineType; case Type::StructTyID: return OverloadKind::ObjectType; default: llvm_unreachable("invalid overload type"); return OverloadKind::VOID; } } static std::string getTypeName(OverloadKind Kind, Type *Ty) { if (Kind < OverloadKind::UserDefineType) { return getOverloadTypeName(Kind); } else if (Kind == OverloadKind::UserDefineType) { StructType *ST = cast(Ty); return ST->getStructName().str(); } else if (Kind == OverloadKind::ObjectType) { StructType *ST = cast(Ty); return ST->getStructName().str(); } else { std::string Str; raw_string_ostream OS(Str); Ty->print(OS); return OS.str(); } } // Static properties. struct OpCodeProperty { DXIL::OpCode OpCode; // Offset in DXILOpCodeNameTable. unsigned OpCodeNameOffset; DXIL::OpCodeClass OpCodeClass; // Offset in DXILOpCodeClassNameTable. unsigned OpCodeClassNameOffset; uint16_t OverloadTys; llvm::Attribute::AttrKind FuncAttr; }; // Include getOpCodeClassName getOpCodeProperty and getOpCodeName which // generated by tableGen. #define DXIL_OP_OPERATION_TABLE #include "DXILOperation.inc" #undef DXIL_OP_OPERATION_TABLE static std::string constructOverloadName(OverloadKind Kind, Type *Ty, const OpCodeProperty &Prop) { if (Kind == OverloadKind::VOID) { return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str(); } return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." + getTypeName(Kind, Ty)) .str(); } static FunctionCallee createDXILOpFunction(DXIL::OpCode DXILOp, Function &F, Module &M) { const OpCodeProperty *Prop = getOpCodeProperty(DXILOp); // Get return type as overload type for DXILOp. // Only simple mapping case here, so return type is good enough. Type *OverloadTy = F.getReturnType(); OverloadKind Kind = getOverloadKind(OverloadTy); // FIXME: find the issue and report error in clang instead of check it in // backend. if ((Prop->OverloadTys & (uint16_t)Kind) == 0) { llvm_unreachable("invalid overload"); } std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop); assert(!M.getFunction(FnName) && "Function already exists"); auto &Ctx = M.getContext(); Type *OpCodeTy = Type::getInt32Ty(Ctx); SmallVector ArgTypes; // DXIL has i32 opcode as first arg. ArgTypes.emplace_back(OpCodeTy); FunctionType *FT = F.getFunctionType(); ArgTypes.append(FT->param_begin(), FT->param_end()); FunctionType *DXILOpFT = FunctionType::get(OverloadTy, ArgTypes, false); return M.getOrInsertFunction(FnName, DXILOpFT); } static void lowerIntrinsic(DXIL::OpCode DXILOp, Function &F, Module &M) { auto DXILOpFn = createDXILOpFunction(DXILOp, F, M); IRBuilder<> B(M.getContext()); Value *DXILOpArg = B.getInt32(static_cast(DXILOp)); for (User *U : make_early_inc_range(F.users())) { CallInst *CI = dyn_cast(U); if (!CI) continue; SmallVector Args; Args.emplace_back(DXILOpArg); Args.append(CI->arg_begin(), CI->arg_end()); B.SetInsertPoint(CI); CallInst *DXILCI = B.CreateCall(DXILOpFn, Args); LLVM_DEBUG(DXILCI->setName(getOpCodeName(DXILOp))); CI->replaceAllUsesWith(DXILCI); CI->eraseFromParent(); } if (F.user_empty()) F.eraseFromParent(); } static bool lowerIntrinsics(Module &M) { bool Updated = false; #define DXIL_OP_INTRINSIC_MAP #include "DXILOperation.inc" #undef DXIL_OP_INTRINSIC_MAP for (Function &F : make_early_inc_range(M.functions())) { if (!F.isDeclaration()) continue; Intrinsic::ID ID = F.getIntrinsicID(); if (ID == Intrinsic::not_intrinsic) continue; auto LowerIt = LowerMap.find(ID); if (LowerIt == LowerMap.end()) continue; lowerIntrinsic(LowerIt->second, F, M); Updated = true; } return Updated; } namespace { /// A pass that transforms external global definitions into declarations. class DXILOpLowering : public PassInfoMixin { public: PreservedAnalyses run(Module &M, ModuleAnalysisManager &) { if (lowerIntrinsics(M)) return PreservedAnalyses::none(); return PreservedAnalyses::all(); } }; } // namespace namespace { class DXILOpLoweringLegacy : public ModulePass { public: bool runOnModule(Module &M) override { return lowerIntrinsics(M); } StringRef getPassName() const override { return "DXIL Op Lowering"; } DXILOpLoweringLegacy() : ModulePass(ID) {} static char ID; // Pass identification. }; char DXILOpLoweringLegacy::ID = 0; } // end anonymous namespace INITIALIZE_PASS_BEGIN(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", false, false) INITIALIZE_PASS_END(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", false, false) ModulePass *llvm::createDXILOpLoweringLegacyPass() { return new DXILOpLoweringLegacy(); }