diff options
author | Vyacheslav Levytskyy <vyacheslav.levytskyy@intel.com> | 2024-07-11 07:16:29 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-07-11 07:16:29 +0200 |
commit | dbd00a5968d6c823d686714c91f2b4fcfd03797a (patch) | |
tree | 4580cc18a94806f0e6e32e5104029a09c8497548 /llvm/lib | |
parent | 4710e0f498cb661ca17c99cb174616102fcad923 (diff) | |
download | llvm-dbd00a5968d6c823d686714c91f2b4fcfd03797a.zip llvm-dbd00a5968d6c823d686714c91f2b4fcfd03797a.tar.gz llvm-dbd00a5968d6c823d686714c91f2b4fcfd03797a.tar.bz2 |
[SPIRV] Improve type inference of operand presented by opaque pointers and aggregate types (#98035)
This PR improves type inference of operand presented by opaque pointers
and aggregate types:
* tries to restore original function return type for aggregate types so
that it's possible to deduce a correct type during emit-intrinsics step
(see llvm/test/CodeGen/SPIRV/SpecConstants/restore-spec-type.ll for the
reproducer of the previously existed issue when spirv-val found a
mismatch between object and ptr types in OpStore due to the incorrect
aggregate types tracing),
* explores untyped pointer operands of store to deduce correct pointee
types,
* creates an extension type to track pointee types from emit-intrinsics
step and further instead of direct and naive usage of TypePointerType
that led previously to crashes due to ban of creation of Value of
TypePointerType type,
* tracks instructions with uncomplete type information and tries to
improve their type info after pass calculated types for all machine
functions (it doesn't traverse a code but rather checks only those
instructions which were tracked as uncompleted),
* address more cases of removing unnecessary bitcasts (see, for example,
changes in test/CodeGen/SPIRV/transcoding/OpGenericCastToPtr.ll where
`CHECK-SPIRV-NEXT` in LIT checks show absence of unneeded bitcasts and
unmangled/mangled versions have proper typing now with equivalent type
info),
* address more cases of well known types or relations between types
within instructions (see, for example, atomic*.ll test cases and
Event-related test cases for improved SPIR-V code generated by the
Backend),
* fix the issue of removing unneeded ptrcast instructions in
pre-legalizer pass that led to creation of new assign-type instructions
with the same argument as source in ptrcast and caused errors in type
inference (the reproducer `complex.ll` test case is added to the PR).
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp | 45 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVBuiltins.h | 2 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 379 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 14 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp | 2 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp | 5 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp | 23 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp | 5 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVUtils.h | 56 |
9 files changed, 444 insertions, 87 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp index 286bdb9..1609576 100644 --- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp @@ -169,21 +169,9 @@ using namespace InstructionSet; // TableGen records //===----------------------------------------------------------------------===// -/// Looks up the demangled builtin call in the SPIRVBuiltins.td records using -/// the provided \p DemangledCall and specified \p Set. -/// -/// The lookup follows the following algorithm, returning the first successful -/// match: -/// 1. Search with the plain demangled name (expecting a 1:1 match). -/// 2. Search with the prefix before or suffix after the demangled name -/// signyfying the type of the first argument. -/// -/// \returns Wrapper around the demangled call and found builtin definition. -static std::unique_ptr<const SPIRV::IncomingCall> -lookupBuiltin(StringRef DemangledCall, - SPIRV::InstructionSet::InstructionSet Set, - Register ReturnRegister, const SPIRVType *ReturnType, - const SmallVectorImpl<Register> &Arguments) { +namespace SPIRV { +/// Parses the name part of the demangled builtin call. +std::string lookupBuiltinNameHelper(StringRef DemangledCall) { const static std::string PassPrefix = "(anonymous namespace)::"; std::string BuiltinName; // Itanium Demangler result may have "(anonymous namespace)::" prefix @@ -215,6 +203,27 @@ lookupBuiltin(StringRef DemangledCall, BuiltinName = BuiltinName.substr(0, BuiltinName.find("_R")); } + return BuiltinName; +} +} // namespace SPIRV + +/// Looks up the demangled builtin call in the SPIRVBuiltins.td records using +/// the provided \p DemangledCall and specified \p Set. +/// +/// The lookup follows the following algorithm, returning the first successful +/// match: +/// 1. Search with the plain demangled name (expecting a 1:1 match). +/// 2. Search with the prefix before or suffix after the demangled name +/// signyfying the type of the first argument. +/// +/// \returns Wrapper around the demangled call and found builtin definition. +static std::unique_ptr<const SPIRV::IncomingCall> +lookupBuiltin(StringRef DemangledCall, + SPIRV::InstructionSet::InstructionSet Set, + Register ReturnRegister, const SPIRVType *ReturnType, + const SmallVectorImpl<Register> &Arguments) { + std::string BuiltinName = SPIRV::lookupBuiltinNameHelper(DemangledCall); + SmallVector<StringRef, 10> BuiltinArgumentTypes; StringRef BuiltinArgs = DemangledCall.slice(DemangledCall.find('(') + 1, DemangledCall.find(')')); @@ -2610,9 +2619,6 @@ Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall, // Unable to recognize SPIRV type name. return nullptr; - if (BaseType->isVoidTy()) - BaseType = Type::getInt8Ty(Ctx); - // Handle "typeN*" or "type vector[N]*". TypeStr.consume_back("*"); @@ -2621,7 +2627,8 @@ Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall, TypeStr.getAsInteger(10, VecElts); if (VecElts > 0) - BaseType = VectorType::get(BaseType, VecElts, false); + BaseType = VectorType::get( + BaseType->isVoidTy() ? Type::getInt8Ty(Ctx) : BaseType, VecElts, false); return BaseType; } diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h index 68bff60..d07fc7c 100644 --- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h +++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h @@ -19,6 +19,8 @@ namespace llvm { namespace SPIRV { +/// Parses the name part of the demangled builtin call. +std::string lookupBuiltinNameHelper(StringRef DemangledCall); /// Lowers a builtin function call using the provided \p DemangledCall skeleton /// and external instruction \p Set. /// diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp index 566eafd..d9864ab 100644 --- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp @@ -46,6 +46,10 @@ using namespace llvm; namespace llvm { +namespace SPIRV { +#define GET_BuiltinGroup_DECL +#include "SPIRVGenTables.inc" +} // namespace SPIRV void initializeSPIRVEmitIntrinsicsPass(PassRegistry &); } // namespace llvm @@ -69,22 +73,38 @@ class SPIRVEmitIntrinsics DenseSet<Instruction *> AggrStores; SPIRV::InstructionSet::InstructionSet InstrSet; + // a register of Instructions that don't have a complete type definition + SmallPtrSet<Value *, 8> UncompleteTypeInfo; + SmallVector<Instruction *> PostprocessWorklist; + + // well known result types of builtins + enum WellKnownTypes { Event }; + // deduce element type of untyped pointers Type *deduceElementType(Value *I, bool UnknownElemTypeI8); - Type *deduceElementTypeHelper(Value *I); - Type *deduceElementTypeHelper(Value *I, std::unordered_set<Value *> &Visited); + Type *deduceElementTypeHelper(Value *I, bool UnknownElemTypeI8); + Type *deduceElementTypeHelper(Value *I, std::unordered_set<Value *> &Visited, + bool UnknownElemTypeI8); Type *deduceElementTypeByValueDeep(Type *ValueTy, Value *Operand, - std::unordered_set<Value *> &Visited); + bool UnknownElemTypeI8); + Type *deduceElementTypeByValueDeep(Type *ValueTy, Value *Operand, + std::unordered_set<Value *> &Visited, + bool UnknownElemTypeI8); Type *deduceElementTypeByUsersDeep(Value *Op, - std::unordered_set<Value *> &Visited); + std::unordered_set<Value *> &Visited, + bool UnknownElemTypeI8); + void maybeAssignPtrType(Type *&Ty, Value *I, Type *RefTy, + bool UnknownElemTypeI8); // deduce nested types of composites - Type *deduceNestedTypeHelper(User *U); + Type *deduceNestedTypeHelper(User *U, bool UnknownElemTypeI8); Type *deduceNestedTypeHelper(User *U, Type *Ty, - std::unordered_set<Value *> &Visited); + std::unordered_set<Value *> &Visited, + bool UnknownElemTypeI8); // deduce Types of operands of the Instruction if possible - void deduceOperandElementType(Instruction *I); + void deduceOperandElementType(Instruction *I, Instruction *AskOp = 0, + Type *AskTy = 0, CallInst *AssignCI = 0); void preprocessCompositeConstants(IRBuilder<> &B); void preprocessUndefs(IRBuilder<> &B); @@ -151,6 +171,7 @@ public: bool runOnModule(Module &M) override; bool runOnFunction(Function &F); + bool postprocessTypes(); void getAnalysisUsage(AnalysisUsage &AU) const override { ModulePass::getAnalysisUsage(AU); @@ -223,6 +244,41 @@ static inline void reportFatalOnTokenType(const Instruction *I) { false); } +static bool IsKernelArgInt8(Function *F, StoreInst *SI) { + return SI && F->getCallingConv() == CallingConv::SPIR_KERNEL && + isPointerTy(SI->getValueOperand()->getType()) && + isa<Argument>(SI->getValueOperand()); +} + +// Maybe restore original function return type. +static inline Type *restoreMutatedType(SPIRVGlobalRegistry *GR, Instruction *I, + Type *Ty) { + CallInst *CI = dyn_cast<CallInst>(I); + if (!CI || CI->isIndirectCall() || CI->isInlineAsm() || + !CI->getCalledFunction() || CI->getCalledFunction()->isIntrinsic()) + return Ty; + if (Type *OriginalTy = GR->findMutated(CI->getCalledFunction())) + return OriginalTy; + return Ty; +} + +// Reconstruct type with nested element types according to deduced type info. +// Return nullptr if no detailed type info is available. +static inline Type *reconstructType(SPIRVGlobalRegistry *GR, Value *Op) { + Type *Ty = Op->getType(); + if (!isUntypedPointerTy(Ty)) + return Ty; + // try to find the pointee type + if (Type *NestedTy = GR->findDeducedElementType(Op)) + return getTypedPointerWrapper(NestedTy, getPointerAddressSpace(Ty)); + // not a pointer according to the type info (e.g., Event object) + CallInst *CI = GR->findAssignPtrTypeInstr(Op); + if (!CI) + return nullptr; + MetadataAsValue *MD = cast<MetadataAsValue>(CI->getArgOperand(1)); + return cast<ConstantAsMetadata>(MD->getMetadata())->getType(); +} + void SPIRVEmitIntrinsics::buildAssignType(IRBuilder<> &B, Type *Ty, Value *Arg) { Value *OfType = PoisonValue::get(Ty); @@ -263,15 +319,26 @@ void SPIRVEmitIntrinsics::updateAssignType(CallInst *AssignCI, Value *Arg, // Set element pointer type to the given value of ValueTy and tries to // specify this type further (recursively) by Operand value, if needed. +Type * +SPIRVEmitIntrinsics::deduceElementTypeByValueDeep(Type *ValueTy, Value *Operand, + bool UnknownElemTypeI8) { + std::unordered_set<Value *> Visited; + return deduceElementTypeByValueDeep(ValueTy, Operand, Visited, + UnknownElemTypeI8); +} + Type *SPIRVEmitIntrinsics::deduceElementTypeByValueDeep( - Type *ValueTy, Value *Operand, std::unordered_set<Value *> &Visited) { + Type *ValueTy, Value *Operand, std::unordered_set<Value *> &Visited, + bool UnknownElemTypeI8) { Type *Ty = ValueTy; if (Operand) { if (auto *PtrTy = dyn_cast<PointerType>(Ty)) { - if (Type *NestedTy = deduceElementTypeHelper(Operand, Visited)) - Ty = TypedPointerType::get(NestedTy, PtrTy->getAddressSpace()); + if (Type *NestedTy = + deduceElementTypeHelper(Operand, Visited, UnknownElemTypeI8)) + Ty = getTypedPointerWrapper(NestedTy, PtrTy->getAddressSpace()); } else { - Ty = deduceNestedTypeHelper(dyn_cast<User>(Operand), Ty, Visited); + Ty = deduceNestedTypeHelper(dyn_cast<User>(Operand), Ty, Visited, + UnknownElemTypeI8); } } return Ty; @@ -279,12 +346,12 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeByValueDeep( // Traverse User instructions to deduce an element pointer type of the operand. Type *SPIRVEmitIntrinsics::deduceElementTypeByUsersDeep( - Value *Op, std::unordered_set<Value *> &Visited) { + Value *Op, std::unordered_set<Value *> &Visited, bool UnknownElemTypeI8) { if (!Op || !isPointerTy(Op->getType())) return nullptr; - if (auto PType = dyn_cast<TypedPointerType>(Op->getType())) - return PType->getElementType(); + if (auto ElemTy = getPointeeType(Op->getType())) + return ElemTy; // maybe we already know operand's element type if (Type *KnownTy = GR->findDeducedElementType(Op)) @@ -292,7 +359,7 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeByUsersDeep( for (User *OpU : Op->users()) { if (Instruction *Inst = dyn_cast<Instruction>(OpU)) { - if (Type *Ty = deduceElementTypeHelper(Inst, Visited)) + if (Type *Ty = deduceElementTypeHelper(Inst, Visited, UnknownElemTypeI8)) return Ty; } } @@ -314,13 +381,27 @@ static Type *getPointeeTypeByCallInst(StringRef DemangledName, // Deduce and return a successfully deduced Type of the Instruction, // or nullptr otherwise. -Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(Value *I) { +Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(Value *I, + bool UnknownElemTypeI8) { std::unordered_set<Value *> Visited; - return deduceElementTypeHelper(I, Visited); + return deduceElementTypeHelper(I, Visited, UnknownElemTypeI8); +} + +void SPIRVEmitIntrinsics::maybeAssignPtrType(Type *&Ty, Value *Op, Type *RefTy, + bool UnknownElemTypeI8) { + if (isUntypedPointerTy(RefTy)) { + if (!UnknownElemTypeI8) + return; + if (auto *I = dyn_cast<Instruction>(Op)) { + UncompleteTypeInfo.insert(I); + PostprocessWorklist.push_back(I); + } + } + Ty = RefTy; } Type *SPIRVEmitIntrinsics::deduceElementTypeHelper( - Value *I, std::unordered_set<Value *> &Visited) { + Value *I, std::unordered_set<Value *> &Visited, bool UnknownElemTypeI8) { // allow to pass nullptr as an argument if (!I) return nullptr; @@ -338,34 +419,41 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper( Type *Ty = nullptr; // look for known basic patterns of type inference if (auto *Ref = dyn_cast<AllocaInst>(I)) { - Ty = Ref->getAllocatedType(); + maybeAssignPtrType(Ty, I, Ref->getAllocatedType(), UnknownElemTypeI8); } else if (auto *Ref = dyn_cast<GetElementPtrInst>(I)) { Ty = Ref->getResultElementType(); } else if (auto *Ref = dyn_cast<GlobalValue>(I)) { Ty = deduceElementTypeByValueDeep( Ref->getValueType(), - Ref->getNumOperands() > 0 ? Ref->getOperand(0) : nullptr, Visited); + Ref->getNumOperands() > 0 ? Ref->getOperand(0) : nullptr, Visited, + UnknownElemTypeI8); } else if (auto *Ref = dyn_cast<AddrSpaceCastInst>(I)) { - Ty = deduceElementTypeHelper(Ref->getPointerOperand(), Visited); + Type *RefTy = deduceElementTypeHelper(Ref->getPointerOperand(), Visited, + UnknownElemTypeI8); + maybeAssignPtrType(Ty, I, RefTy, UnknownElemTypeI8); } else if (auto *Ref = dyn_cast<BitCastInst>(I)) { if (Type *Src = Ref->getSrcTy(), *Dest = Ref->getDestTy(); isPointerTy(Src) && isPointerTy(Dest)) - Ty = deduceElementTypeHelper(Ref->getOperand(0), Visited); + Ty = deduceElementTypeHelper(Ref->getOperand(0), Visited, + UnknownElemTypeI8); } else if (auto *Ref = dyn_cast<AtomicCmpXchgInst>(I)) { Value *Op = Ref->getNewValOperand(); - Ty = deduceElementTypeByValueDeep(Op->getType(), Op, Visited); + if (isPointerTy(Op->getType())) + Ty = deduceElementTypeHelper(Op, Visited, UnknownElemTypeI8); } else if (auto *Ref = dyn_cast<AtomicRMWInst>(I)) { Value *Op = Ref->getValOperand(); - Ty = deduceElementTypeByValueDeep(Op->getType(), Op, Visited); + if (isPointerTy(Op->getType())) + Ty = deduceElementTypeHelper(Op, Visited, UnknownElemTypeI8); } else if (auto *Ref = dyn_cast<PHINode>(I)) { for (unsigned i = 0; i < Ref->getNumIncomingValues(); i++) { - Ty = deduceElementTypeByUsersDeep(Ref->getIncomingValue(i), Visited); + Ty = deduceElementTypeByUsersDeep(Ref->getIncomingValue(i), Visited, + UnknownElemTypeI8); if (Ty) break; } } else if (auto *Ref = dyn_cast<SelectInst>(I)) { for (Value *Op : {Ref->getTrueValue(), Ref->getFalseValue()}) { - Ty = deduceElementTypeByUsersDeep(Op, Visited); + Ty = deduceElementTypeByUsersDeep(Op, Visited, UnknownElemTypeI8); if (Ty) break; } @@ -384,10 +472,12 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper( if (Function *CalledF = CI->getCalledFunction()) { std::string DemangledName = getOclOrSpirvBuiltinDemangledName(CalledF->getName()); + if (DemangledName.length() > 0) + DemangledName = SPIRV::lookupBuiltinNameHelper(DemangledName); auto AsArgIt = ResTypeByArg.find(DemangledName); if (AsArgIt != ResTypeByArg.end()) { Ty = deduceElementTypeHelper(CI->getArgOperand(AsArgIt->second), - Visited); + Visited, UnknownElemTypeI8); } } } @@ -404,13 +494,15 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper( // Re-create a type of the value if it has untyped pointer fields, also nested. // Return the original value type if no corrections of untyped pointer // information is found or needed. -Type *SPIRVEmitIntrinsics::deduceNestedTypeHelper(User *U) { +Type *SPIRVEmitIntrinsics::deduceNestedTypeHelper(User *U, + bool UnknownElemTypeI8) { std::unordered_set<Value *> Visited; - return deduceNestedTypeHelper(U, U->getType(), Visited); + return deduceNestedTypeHelper(U, U->getType(), Visited, UnknownElemTypeI8); } Type *SPIRVEmitIntrinsics::deduceNestedTypeHelper( - User *U, Type *OrigTy, std::unordered_set<Value *> &Visited) { + User *U, Type *OrigTy, std::unordered_set<Value *> &Visited, + bool UnknownElemTypeI8) { if (!U) return OrigTy; @@ -432,10 +524,12 @@ Type *SPIRVEmitIntrinsics::deduceNestedTypeHelper( Type *Ty = OpTy; if (Op) { if (auto *PtrTy = dyn_cast<PointerType>(OpTy)) { - if (Type *NestedTy = deduceElementTypeHelper(Op, Visited)) + if (Type *NestedTy = + deduceElementTypeHelper(Op, Visited, UnknownElemTypeI8)) Ty = TypedPointerType::get(NestedTy, PtrTy->getAddressSpace()); } else { - Ty = deduceNestedTypeHelper(dyn_cast<User>(Op), OpTy, Visited); + Ty = deduceNestedTypeHelper(dyn_cast<User>(Op), OpTy, Visited, + UnknownElemTypeI8); } } Tys.push_back(Ty); @@ -451,10 +545,12 @@ Type *SPIRVEmitIntrinsics::deduceNestedTypeHelper( Type *OpTy = ArrTy->getElementType(); Type *Ty = OpTy; if (auto *PtrTy = dyn_cast<PointerType>(OpTy)) { - if (Type *NestedTy = deduceElementTypeHelper(Op, Visited)) + if (Type *NestedTy = + deduceElementTypeHelper(Op, Visited, UnknownElemTypeI8)) Ty = TypedPointerType::get(NestedTy, PtrTy->getAddressSpace()); } else { - Ty = deduceNestedTypeHelper(dyn_cast<User>(Op), OpTy, Visited); + Ty = deduceNestedTypeHelper(dyn_cast<User>(Op), OpTy, Visited, + UnknownElemTypeI8); } if (Ty != OpTy) { Type *NewTy = ArrayType::get(Ty, ArrTy->getNumElements()); @@ -467,10 +563,12 @@ Type *SPIRVEmitIntrinsics::deduceNestedTypeHelper( Type *OpTy = VecTy->getElementType(); Type *Ty = OpTy; if (auto *PtrTy = dyn_cast<PointerType>(OpTy)) { - if (Type *NestedTy = deduceElementTypeHelper(Op, Visited)) - Ty = TypedPointerType::get(NestedTy, PtrTy->getAddressSpace()); + if (Type *NestedTy = + deduceElementTypeHelper(Op, Visited, UnknownElemTypeI8)) + Ty = getTypedPointerWrapper(NestedTy, PtrTy->getAddressSpace()); } else { - Ty = deduceNestedTypeHelper(dyn_cast<User>(Op), OpTy, Visited); + Ty = deduceNestedTypeHelper(dyn_cast<User>(Op), OpTy, Visited, + UnknownElemTypeI8); } if (Ty != OpTy) { Type *NewTy = VectorType::get(Ty, VecTy->getElementCount()); @@ -484,16 +582,38 @@ Type *SPIRVEmitIntrinsics::deduceNestedTypeHelper( } Type *SPIRVEmitIntrinsics::deduceElementType(Value *I, bool UnknownElemTypeI8) { - if (Type *Ty = deduceElementTypeHelper(I)) + if (Type *Ty = deduceElementTypeHelper(I, UnknownElemTypeI8)) return Ty; - return UnknownElemTypeI8 ? IntegerType::getInt8Ty(I->getContext()) : nullptr; + if (!UnknownElemTypeI8) + return nullptr; + if (auto *Instr = dyn_cast<Instruction>(I)) { + UncompleteTypeInfo.insert(Instr); + PostprocessWorklist.push_back(Instr); + } + return IntegerType::getInt8Ty(I->getContext()); +} + +static inline Type *getAtomicElemTy(SPIRVGlobalRegistry *GR, Instruction *I, + Value *PointerOperand) { + Type *PointeeTy = GR->findDeducedElementType(PointerOperand); + if (PointeeTy && !isUntypedPointerTy(PointeeTy)) + return nullptr; + auto *PtrTy = dyn_cast<PointerType>(I->getType()); + if (!PtrTy) + return I->getType(); + if (Type *NestedTy = GR->findDeducedElementType(I)) + return getTypedPointerWrapper(NestedTy, PtrTy->getAddressSpace()); + return nullptr; } // If the Instruction has Pointer operands with unresolved types, this function // tries to deduce them. If the Instruction has Pointer operands with known // types which differ from expected, this function tries to insert a bitcast to // resolve the issue. -void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I) { +void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I, + Instruction *AskOp, + Type *AskTy, + CallInst *AskCI) { SmallVector<std::pair<Value *, unsigned>> Ops; Type *KnownElemTy = nullptr; // look for known basic patterns of type inference @@ -506,6 +626,51 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I) { if (isPointerTy(Op->getType())) Ops.push_back(std::make_pair(Op, i)); } + } else if (auto *Ref = dyn_cast<AddrSpaceCastInst>(I)) { + KnownElemTy = GR->findDeducedElementType(I); + if (!KnownElemTy) + return; + Ops.push_back(std::make_pair(Ref->getPointerOperand(), 0)); + } else if (auto *Ref = dyn_cast<GetElementPtrInst>(I)) { + KnownElemTy = Ref->getSourceElementType(); + if (isUntypedPointerTy(KnownElemTy)) + return; + Type *PointeeTy = GR->findDeducedElementType(Ref->getPointerOperand()); + if (PointeeTy && !isUntypedPointerTy(PointeeTy)) + return; + Ops.push_back(std::make_pair(Ref->getPointerOperand(), + GetElementPtrInst::getPointerOperandIndex())); + } else if (auto *Ref = dyn_cast<LoadInst>(I)) { + KnownElemTy = I->getType(); + if (isUntypedPointerTy(KnownElemTy)) + return; + Type *PointeeTy = GR->findDeducedElementType(Ref->getPointerOperand()); + if (PointeeTy && !isUntypedPointerTy(PointeeTy)) + return; + Ops.push_back(std::make_pair(Ref->getPointerOperand(), + LoadInst::getPointerOperandIndex())); + } else if (auto *Ref = dyn_cast<StoreInst>(I)) { + if (IsKernelArgInt8(Ref->getParent()->getParent(), Ref)) + return; + if (!(KnownElemTy = reconstructType(GR, Ref->getValueOperand()))) + return; + Type *PointeeTy = GR->findDeducedElementType(Ref->getPointerOperand()); + if (PointeeTy && !isUntypedPointerTy(PointeeTy)) + return; + Ops.push_back(std::make_pair(Ref->getPointerOperand(), + StoreInst::getPointerOperandIndex())); + } else if (auto *Ref = dyn_cast<AtomicCmpXchgInst>(I)) { + KnownElemTy = getAtomicElemTy(GR, I, Ref->getPointerOperand()); + if (!KnownElemTy) + return; + Ops.push_back(std::make_pair(Ref->getPointerOperand(), + AtomicCmpXchgInst::getPointerOperandIndex())); + } else if (auto *Ref = dyn_cast<AtomicRMWInst>(I)) { + KnownElemTy = getAtomicElemTy(GR, I, Ref->getPointerOperand()); + if (!KnownElemTy) + return; + Ops.push_back(std::make_pair(Ref->getPointerOperand(), + AtomicRMWInst::getPointerOperandIndex())); } else if (auto *Ref = dyn_cast<SelectInst>(I)) { if (!isPointerTy(I->getType()) || !(KnownElemTy = GR->findDeducedElementType(I))) @@ -565,6 +730,32 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I) { KnownElemTy = ElemTy; // src will rewrite dest if both are defined Ops.push_back(std::make_pair(Op, i)); } + } else if (Grp == SPIRV::Atomic || Grp == SPIRV::AtomicFloating) { + if (CI->arg_size() < 2) + return; + Value *Op = CI->getArgOperand(0); + if (!isPointerTy(Op->getType())) + return; + switch (Opcode) { + case SPIRV::OpAtomicLoad: + case SPIRV::OpAtomicCompareExchangeWeak: + case SPIRV::OpAtomicCompareExchange: + case SPIRV::OpAtomicExchange: + case SPIRV::OpAtomicIAdd: + case SPIRV::OpAtomicISub: + case SPIRV::OpAtomicOr: + case SPIRV::OpAtomicXor: + case SPIRV::OpAtomicAnd: + case SPIRV::OpAtomicUMin: + case SPIRV::OpAtomicUMax: + case SPIRV::OpAtomicSMin: + case SPIRV::OpAtomicSMax: { + KnownElemTy = getAtomicElemTy(GR, I, Op); + if (!KnownElemTy) + return; + Ops.push_back(std::make_pair(Op, 0)); + } break; + } } } } @@ -578,17 +769,18 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I) { IRBuilder<> B(Ctx); for (auto &OpIt : Ops) { Value *Op = OpIt.first; - if (Op->use_empty()) + if (Op->use_empty() || (AskOp && Op != AskOp)) continue; - Type *Ty = GR->findDeducedElementType(Op); + Type *Ty = AskOp ? AskTy : GR->findDeducedElementType(Op); if (Ty == KnownElemTy) continue; - Value *OpTyVal = Constant::getNullValue(KnownElemTy); + Value *OpTyVal = PoisonValue::get(KnownElemTy); Type *OpTy = Op->getType(); - if (!Ty) { + if (!Ty || AskTy || isUntypedPointerTy(Ty) || + UncompleteTypeInfo.contains(Op)) { GR->addDeducedElementType(Op, KnownElemTy); // check if there is existing Intrinsic::spv_assign_ptr_type instruction - CallInst *AssignCI = GR->findAssignPtrTypeInstr(Op); + CallInst *AssignCI = AskCI ? AskCI : GR->findAssignPtrTypeInstr(Op); if (AssignCI == nullptr) { Instruction *User = dyn_cast<Instruction>(Op->use_begin()->get()); setInsertPointSkippingPhis(B, User ? User->getNextNode() : I); @@ -719,7 +911,7 @@ void SPIRVEmitIntrinsics::preprocessCompositeConstants(IRBuilder<> &B) { I->replaceUsesOfWith(Op, CI); KeepInst = true; AggrConsts[CI] = AggrConst; - AggrConstTypes[CI] = deduceNestedTypeHelper(AggrConst); + AggrConstTypes[CI] = deduceNestedTypeHelper(AggrConst, false); } } if (!KeepInst) @@ -864,8 +1056,9 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast( Pointer = BC->getOperand(0); // Do not emit spv_ptrcast if Pointer's element type is ExpectedElementType - Type *PointerElemTy = deduceElementTypeHelper(Pointer); - if (PointerElemTy == ExpectedElementType) + Type *PointerElemTy = deduceElementTypeHelper(Pointer, false); + if (PointerElemTy == ExpectedElementType || + isEquivalentTypes(PointerElemTy, ExpectedElementType)) return; setInsertPointSkippingPhis(B, I); @@ -930,15 +1123,19 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I, IRBuilder<> &B) { // Handle basic instructions: StoreInst *SI = dyn_cast<StoreInst>(I); - if (SI && F->getCallingConv() == CallingConv::SPIR_KERNEL && - isPointerTy(SI->getValueOperand()->getType()) && - isa<Argument>(SI->getValueOperand())) { + if (IsKernelArgInt8(F, SI)) { return replacePointerOperandWithPtrCast( I, SI->getValueOperand(), IntegerType::getInt8Ty(F->getContext()), 0, B); } else if (SI) { - return replacePointerOperandWithPtrCast( - I, SI->getPointerOperand(), SI->getValueOperand()->getType(), 1, B); + Value *Op = SI->getValueOperand(); + Type *OpTy = Op->getType(); + if (auto *OpI = dyn_cast<Instruction>(Op)) + OpTy = restoreMutatedType(GR, OpI, OpTy); + if (OpTy == Op->getType()) + OpTy = deduceElementTypeByValueDeep(OpTy, Op, false); + return replacePointerOperandWithPtrCast(I, SI->getPointerOperand(), OpTy, 1, + B); } else if (LoadInst *LI = dyn_cast<LoadInst>(I)) { return replacePointerOperandWithPtrCast(I, LI->getPointerOperand(), LI->getType(), 0, B); @@ -978,7 +1175,7 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I, } else { for (User *U : CalledArg->users()) { if (Instruction *Inst = dyn_cast<Instruction>(U)) { - if ((ElemTy = deduceElementTypeHelper(Inst)) != nullptr) + if ((ElemTy = deduceElementTypeHelper(Inst, false)) != nullptr) break; } } @@ -1012,7 +1209,7 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I, if (!ExpectedType && !DemangledName.empty()) ExpectedType = SPIRV::parseBuiltinCallArgumentBaseType( DemangledName, OpIdx, I->getContext()); - if (!ExpectedType) + if (!ExpectedType || ExpectedType->isVoidTy()) continue; if (ExpectedType->isTargetExtTy()) @@ -1182,7 +1379,7 @@ void SPIRVEmitIntrinsics::processGlobalValue(GlobalVariable &GV, // Deduce element type and store results in Global Registry. // Result is ignored, because TypedPointerType is not supported // by llvm IR general logic. - deduceElementTypeHelper(&GV); + deduceElementTypeHelper(&GV, false); Constant *Init = GV.getInitializer(); Type *Ty = isAggrConstForceInt32(Init) ? B.getInt32Ty() : Init->getType(); Constant *Const = isAggrConstForceInt32(Init) ? B.getInt32(1) : Init; @@ -1216,9 +1413,39 @@ bool SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I, void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I, IRBuilder<> &B) { + // TODO: extend the list of functions with known result types + static StringMap<unsigned> ResTypeWellKnown = { + {"async_work_group_copy", WellKnownTypes::Event}, + {"async_work_group_strided_copy", WellKnownTypes::Event}, + {"__spirv_GroupAsyncCopy", WellKnownTypes::Event}}; + reportFatalOnTokenType(I); + + bool IsKnown = false; + if (auto *CI = dyn_cast<CallInst>(I)) { + if (!CI->isIndirectCall() && !CI->isInlineAsm() && + CI->getCalledFunction() && !CI->getCalledFunction()->isIntrinsic()) { + Function *CalledF = CI->getCalledFunction(); + std::string DemangledName = + getOclOrSpirvBuiltinDemangledName(CalledF->getName()); + if (DemangledName.length() > 0) + DemangledName = SPIRV::lookupBuiltinNameHelper(DemangledName); + auto ResIt = ResTypeWellKnown.find(DemangledName); + if (ResIt != ResTypeWellKnown.end()) { + IsKnown = true; + setInsertPointAfterDef(B, I); + switch (ResIt->second) { + case WellKnownTypes::Event: + buildAssignType(B, TargetExtType::get(I->getContext(), "spirv.Event"), + I); + break; + } + } + } + } + Type *Ty = I->getType(); - if (!Ty->isVoidTy() && !isPointerTy(Ty) && requireAssignType(I)) { + if (!IsKnown && !Ty->isVoidTy() && !isPointerTy(Ty) && requireAssignType(I)) { setInsertPointAfterDef(B, I); Type *TypeToAssign = Ty; if (auto *II = dyn_cast<IntrinsicInst>(I)) { @@ -1230,6 +1457,7 @@ void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I, TypeToAssign = It->second; } } + TypeToAssign = restoreMutatedType(GR, I, TypeToAssign); buildAssignType(B, TypeToAssign, I); } for (const auto &Op : I->operands()) { @@ -1343,7 +1571,7 @@ Type *SPIRVEmitIntrinsics::deduceFunParamElementType( return KnownTy; // try to deduce from the operand itself Visited.clear(); - if (Type *Ty = deduceElementTypeHelper(OpArg, Visited)) + if (Type *Ty = deduceElementTypeHelper(OpArg, Visited, false)) return Ty; // search in actual parameter's users for (User *OpU : OpArg->users()) { @@ -1351,7 +1579,7 @@ Type *SPIRVEmitIntrinsics::deduceFunParamElementType( if (!Inst || Inst == CI) continue; Visited.clear(); - if (Type *Ty = deduceElementTypeHelper(Inst, Visited)) + if (Type *Ty = deduceElementTypeHelper(Inst, Visited, false)) return Ty; } // check if it's a formal parameter of the outer function @@ -1480,12 +1708,39 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) { return true; } +// Try to deduce a better type for pointers to untyped ptr. +bool SPIRVEmitIntrinsics::postprocessTypes() { + bool Changed = false; + if (!GR) + return Changed; + for (auto IB = PostprocessWorklist.rbegin(), IE = PostprocessWorklist.rend(); + IB != IE; ++IB) { + CallInst *AssignCI = GR->findAssignPtrTypeInstr(*IB); + Type *KnownTy = GR->findDeducedElementType(*IB); + if (!KnownTy || !AssignCI || !isa<Instruction>(AssignCI->getArgOperand(0))) + continue; + Instruction *I = cast<Instruction>(AssignCI->getArgOperand(0)); + for (User *U : I->users()) { + Instruction *Inst = dyn_cast<Instruction>(U); + if (!Inst || isa<IntrinsicInst>(Inst)) + continue; + deduceOperandElementType(Inst, I, KnownTy, AssignCI); + if (KnownTy != GR->findDeducedElementType(I)) { + Changed = true; + break; + } + } + } + return Changed; +} + bool SPIRVEmitIntrinsics::runOnModule(Module &M) { bool Changed = false; - for (auto &F : M) { + UncompleteTypeInfo.clear(); + PostprocessWorklist.clear(); + for (auto &F : M) Changed |= runOnFunction(F); - } for (auto &F : M) { // check if function parameter types are set @@ -1497,6 +1752,8 @@ bool SPIRVEmitIntrinsics::runOnModule(Module &M) { } } + Changed |= postprocessTypes(); + return Changed; } diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h index a45e1cc..0e26b38 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h @@ -51,6 +51,10 @@ class SPIRVGlobalRegistry { // Maps Functions to their calls (in a form of the machine instruction, // OpFunctionCall) that happened before the definition is available DenseMap<const Function *, SmallPtrSet<MachineInstr *, 8>> ForwardCalls; + // map a Function to its original return type before the clone function was + // created during substitution of aggregate arguments + // (see `SPIRVPrepareFunctions::removeAggregateTypesFromSignature()`) + DenseMap<Value *, Type *> MutatedAggRet; // Look for an equivalent of the newType in the map. Return the equivalent // if it's found, otherwise insert newType to the map and return the type. @@ -163,6 +167,16 @@ public: return It == AssignPtrTypeInstr.end() ? nullptr : It->second; } + // A registry of mutated values + // (see `SPIRVPrepareFunctions::removeAggregateTypesFromSignature()`): + // - Add a record. + void addMutated(Value *Val, Type *Ty) { MutatedAggRet[Val] = Ty; } + // - Find a record. + Type *findMutated(const Value *Val) { + auto It = MutatedAggRet.find(Val); + return It == MutatedAggRet.end() ? nullptr : It->second; + } + // Deduced element types of untyped pointers and composites: // - Add a record to the map of deduced element types. void addDeducedElementType(Value *Val, Type *Ty) { DeducedElTys[Val] = Ty; } diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 9be736c..04def5e 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -2190,7 +2190,7 @@ bool SPIRVInstructionSelector::selectGlobalValue( // FIXME: don't use MachineIRBuilder here, replace it with BuildMI. MachineIRBuilder MIRBuilder(I); const GlobalValue *GV = I.getOperand(1).getGlobal(); - Type *GVType = GR.getDeducedGlobalValueType(GV); + Type *GVType = toTypedPointer(GR.getDeducedGlobalValueType(GV)); SPIRVType *PointerBaseType; if (GVType->isArrayTy()) { SPIRVType *ArrayElementType = diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp index 6c7c3af..e775f8c 100644 --- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp @@ -138,7 +138,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64, v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64}; - auto allFloatAndIntScalars = allIntScalars; + auto allFloatAndIntScalarsAndPtrs = {s8, s16, s32, s64, p0, p1, + p2, p3, p4, p5, p6}; auto allPtrs = {p0, p1, p2, p3, p4, p5, p6}; auto allWritablePtrs = {p0, p1, p3, p4, p5, p6}; @@ -238,7 +239,7 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { .legalForCartesianProduct(allFloatScalars, allWritablePtrs); getActionDefinitionsBuilder(G_ATOMICRMW_XCHG) - .legalForCartesianProduct(allFloatAndIntScalars, allWritablePtrs); + .legalForCartesianProduct(allFloatAndIntScalarsAndPtrs, allWritablePtrs); getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower(); // TODO: add proper legalization rules. diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp index 0ea2f17..099557a 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp @@ -151,6 +151,20 @@ foldConstantsIntoIntrinsics(MachineFunction &MF, MI->eraseFromParent(); } +static MachineInstr *findAssignTypeInstr(Register Reg, + MachineRegisterInfo *MRI) { + for (MachineRegisterInfo::use_instr_iterator I = MRI->use_instr_begin(Reg), + IE = MRI->use_instr_end(); + I != IE; ++I) { + MachineInstr *UseMI = &*I; + if ((isSpvIntrinsic(*UseMI, Intrinsic::spv_assign_ptr_type) || + isSpvIntrinsic(*UseMI, Intrinsic::spv_assign_type)) && + UseMI->getOperand(1).getReg() == Reg) + return UseMI; + } + return nullptr; +} + static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR, MachineIRBuilder MIB) { // Get access to information about available extensions @@ -177,9 +191,12 @@ static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR, BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(), addressSpaceToStorageClass(MI.getOperand(4).getImm(), *ST)); - // If the bitcast would be redundant, replace all uses with the source + // If the ptrcast would be redundant, replace all uses with the source // register. if (GR->getSPIRVTypeForVReg(Source) == AssignedPtrType) { + // Erase Def's assign type instruction if we are going to replace Def. + if (MachineInstr *AssignMI = findAssignTypeInstr(Def, MIB.getMRI())) + ToErase.push_back(AssignMI); MIB.getMRI()->replaceRegWith(Def, Source); } else { GR->assignSPIRVTypeToVReg(AssignedPtrType, Def, MF); @@ -224,8 +241,8 @@ static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR, case TargetOpcode::G_GLOBAL_VALUE: { MIB.setInsertPt(*MI->getParent(), MI); const GlobalValue *Global = MI->getOperand(1).getGlobal(); - Type *ElementTy = GR->getDeducedGlobalValueType(Global); - auto *Ty = TypedPointerType::get(toTypedPointer(ElementTy), + Type *ElementTy = toTypedPointer(GR->getDeducedGlobalValueType(Global)); + auto *Ty = TypedPointerType::get(ElementTy, Global->getType()->getAddressSpace()); SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB); break; diff --git a/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp b/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp index 7bee87d..29b8f8f 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp @@ -536,6 +536,11 @@ SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) { CI->mutateFunctionType(NewF->getFunctionType()); U->replaceUsesOfWith(F, NewF); } + + // register the mutation + if (RetType != F->getReturnType()) + TM.getSubtarget<SPIRVSubtarget>(*F).getSPIRVGlobalRegistry()->addMutated( + NewF, F->getReturnType()); return NewF; } diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h index 12725d6..c757af6 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.h +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h @@ -108,7 +108,7 @@ Type *parseBasicTypeName(StringRef &TypeName, LLVMContext &Ctx); // True if this is an instance of TypedPointerType. inline bool isTypedPointerTy(const Type *T) { - return T->getTypeID() == Type::TypedPointerTyID; + return T && T->getTypeID() == Type::TypedPointerTyID; } // True if this is an instance of PointerType. @@ -153,7 +153,61 @@ inline Type *reconstructFunctionType(Function *F) { return FunctionType::get(F->getReturnType(), ArgTys, F->isVarArg()); } +#define TYPED_PTR_TARGET_EXT_NAME "spirv.$TypedPointerType" +inline Type *getTypedPointerWrapper(Type *ElemTy, unsigned AS) { + return TargetExtType::get(ElemTy->getContext(), TYPED_PTR_TARGET_EXT_NAME, + {ElemTy}, {AS}); +} + +inline bool isTypedPointerWrapper(TargetExtType *ExtTy) { + return ExtTy->getName() == TYPED_PTR_TARGET_EXT_NAME && + ExtTy->getNumIntParameters() == 1 && + ExtTy->getNumTypeParameters() == 1; +} + +inline Type *applyWrappers(Type *Ty) { + if (auto *ExtTy = dyn_cast<TargetExtType>(Ty)) { + if (isTypedPointerWrapper(ExtTy)) + return TypedPointerType::get(applyWrappers(ExtTy->getTypeParameter(0)), + ExtTy->getIntParameter(0)); + } else if (auto *VecTy = dyn_cast<VectorType>(Ty)) { + Type *ElemTy = VecTy->getElementType(); + Type *NewElemTy = ElemTy->isTargetExtTy() ? applyWrappers(ElemTy) : ElemTy; + if (NewElemTy != ElemTy) + return VectorType::get(NewElemTy, VecTy->getElementCount()); + } + return Ty; +} + +inline Type *getPointeeType(Type *Ty) { + if (auto PType = dyn_cast<TypedPointerType>(Ty)) + return PType->getElementType(); + else if (auto *ExtTy = dyn_cast<TargetExtType>(Ty)) + if (isTypedPointerWrapper(ExtTy)) + return applyWrappers(ExtTy->getTypeParameter(0)); + return nullptr; +} + +inline bool isUntypedEquivalentToTyExt(Type *Ty1, Type *Ty2) { + if (!isUntypedPointerTy(Ty1) || !Ty2) + return false; + if (auto *ExtTy = dyn_cast<TargetExtType>(Ty2)) + if (isTypedPointerWrapper(ExtTy) && + ExtTy->getTypeParameter(0) == + IntegerType::getInt8Ty(Ty1->getContext()) && + ExtTy->getIntParameter(0) == cast<PointerType>(Ty1)->getAddressSpace()) + return true; + return false; +} + +inline bool isEquivalentTypes(Type *Ty1, Type *Ty2) { + return isUntypedEquivalentToTyExt(Ty1, Ty2) || + isUntypedEquivalentToTyExt(Ty2, Ty1); +} + inline Type *toTypedPointer(Type *Ty) { + if (Type *NewTy = applyWrappers(Ty); NewTy != Ty) + return NewTy; return isUntypedPointerTy(Ty) ? TypedPointerType::get(IntegerType::getInt8Ty(Ty->getContext()), getPointerAddressSpace(Ty)) |