//===- DXILFlattenArrays.cpp - Flattens DXIL Arrays-----------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===---------------------------------------------------------------------===// /// /// \file This file contains a pass to flatten arrays for the DirectX Backend. /// //===----------------------------------------------------------------------===// #include "DXILFlattenArrays.h" #include "DirectX.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/STLExtras.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstVisitor.h" #include "llvm/IR/ReplaceConstant.h" #include "llvm/Support/Casting.h" #include "llvm/Support/MathExtras.h" #include "llvm/Transforms/Utils/Local.h" #include #include #include #include #define DEBUG_TYPE "dxil-flatten-arrays" using namespace llvm; namespace { class DXILFlattenArraysLegacy : public ModulePass { public: bool runOnModule(Module &M) override; DXILFlattenArraysLegacy() : ModulePass(ID) {} static char ID; // Pass identification. }; struct GEPInfo { ArrayType *RootFlattenedArrayType; Value *RootPointerOperand; SmallMapVector VariableOffsets; APInt ConstantOffset; }; class DXILFlattenArraysVisitor : public InstVisitor { public: DXILFlattenArraysVisitor( SmallDenseMap &GlobalMap) : GlobalMap(GlobalMap) {} bool visit(Function &F); // InstVisitor methods. They return true if the instruction was scalarized, // false if nothing changed. bool visitGetElementPtrInst(GetElementPtrInst &GEPI); bool visitAllocaInst(AllocaInst &AI); bool visitInstruction(Instruction &I) { return false; } bool visitSelectInst(SelectInst &SI) { return false; } bool visitICmpInst(ICmpInst &ICI) { return false; } bool visitFCmpInst(FCmpInst &FCI) { return false; } bool visitUnaryOperator(UnaryOperator &UO) { return false; } bool visitBinaryOperator(BinaryOperator &BO) { return false; } bool visitCastInst(CastInst &CI) { return false; } bool visitBitCastInst(BitCastInst &BCI) { return false; } bool visitInsertElementInst(InsertElementInst &IEI) { return false; } bool visitExtractElementInst(ExtractElementInst &EEI) { return false; } bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; } bool visitPHINode(PHINode &PHI) { return false; } bool visitLoadInst(LoadInst &LI); bool visitStoreInst(StoreInst &SI); bool visitCallInst(CallInst &ICI) { return false; } bool visitFreezeInst(FreezeInst &FI) { return false; } static bool isMultiDimensionalArray(Type *T); static std::pair getElementCountAndType(Type *ArrayTy); private: SmallVector PotentiallyDeadInstrs; SmallDenseMap GEPChainInfoMap; SmallDenseMap &GlobalMap; bool finish(); ConstantInt *genConstFlattenIndices(ArrayRef Indices, ArrayRef Dims, IRBuilder<> &Builder); Value *genInstructionFlattenIndices(ArrayRef Indices, ArrayRef Dims, IRBuilder<> &Builder); }; } // namespace bool DXILFlattenArraysVisitor::finish() { GEPChainInfoMap.clear(); RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs); return true; } bool DXILFlattenArraysVisitor::isMultiDimensionalArray(Type *T) { if (ArrayType *ArrType = dyn_cast(T)) return isa(ArrType->getElementType()); return false; } std::pair DXILFlattenArraysVisitor::getElementCountAndType(Type *ArrayTy) { unsigned TotalElements = 1; Type *CurrArrayTy = ArrayTy; while (auto *InnerArrayTy = dyn_cast(CurrArrayTy)) { TotalElements *= InnerArrayTy->getNumElements(); CurrArrayTy = InnerArrayTy->getElementType(); } return std::make_pair(TotalElements, CurrArrayTy); } ConstantInt *DXILFlattenArraysVisitor::genConstFlattenIndices( ArrayRef Indices, ArrayRef Dims, IRBuilder<> &Builder) { assert(Indices.size() == Dims.size() && "Indicies and dimmensions should be the same"); unsigned FlatIndex = 0; unsigned Multiplier = 1; for (int I = Indices.size() - 1; I >= 0; --I) { unsigned DimSize = Dims[I]; ConstantInt *CIndex = dyn_cast(Indices[I]); assert(CIndex && "This function expects all indicies to be ConstantInt"); FlatIndex += CIndex->getZExtValue() * Multiplier; Multiplier *= DimSize; } return Builder.getInt32(FlatIndex); } Value *DXILFlattenArraysVisitor::genInstructionFlattenIndices( ArrayRef Indices, ArrayRef Dims, IRBuilder<> &Builder) { if (Indices.size() == 1) return Indices[0]; Value *FlatIndex = Builder.getInt32(0); unsigned Multiplier = 1; for (int I = Indices.size() - 1; I >= 0; --I) { unsigned DimSize = Dims[I]; Value *VMultiplier = Builder.getInt32(Multiplier); Value *ScaledIndex = Builder.CreateMul(Indices[I], VMultiplier); FlatIndex = Builder.CreateAdd(FlatIndex, ScaledIndex); Multiplier *= DimSize; } return FlatIndex; } bool DXILFlattenArraysVisitor::visitLoadInst(LoadInst &LI) { unsigned NumOperands = LI.getNumOperands(); for (unsigned I = 0; I < NumOperands; ++I) { Value *CurrOpperand = LI.getOperand(I); ConstantExpr *CE = dyn_cast(CurrOpperand); if (CE && CE->getOpcode() == Instruction::GetElementPtr) { GetElementPtrInst *OldGEP = cast(CE->getAsInstruction()); OldGEP->insertBefore(LI.getIterator()); IRBuilder<> Builder(&LI); LoadInst *NewLoad = Builder.CreateLoad(LI.getType(), OldGEP, LI.getName()); NewLoad->setAlignment(LI.getAlign()); LI.replaceAllUsesWith(NewLoad); LI.eraseFromParent(); visitGetElementPtrInst(*OldGEP); return true; } } return false; } bool DXILFlattenArraysVisitor::visitStoreInst(StoreInst &SI) { unsigned NumOperands = SI.getNumOperands(); for (unsigned I = 0; I < NumOperands; ++I) { Value *CurrOpperand = SI.getOperand(I); ConstantExpr *CE = dyn_cast(CurrOpperand); if (CE && CE->getOpcode() == Instruction::GetElementPtr) { GetElementPtrInst *OldGEP = cast(CE->getAsInstruction()); OldGEP->insertBefore(SI.getIterator()); IRBuilder<> Builder(&SI); StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP); NewStore->setAlignment(SI.getAlign()); SI.replaceAllUsesWith(NewStore); SI.eraseFromParent(); visitGetElementPtrInst(*OldGEP); return true; } } return false; } bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) { if (!isMultiDimensionalArray(AI.getAllocatedType())) return false; ArrayType *ArrType = cast(AI.getAllocatedType()); IRBuilder<> Builder(&AI); auto [TotalElements, BaseType] = getElementCountAndType(ArrType); ArrayType *FattenedArrayType = ArrayType::get(BaseType, TotalElements); AllocaInst *FlatAlloca = Builder.CreateAlloca(FattenedArrayType, nullptr, AI.getName() + ".1dim"); FlatAlloca->setAlignment(AI.getAlign()); AI.replaceAllUsesWith(FlatAlloca); AI.eraseFromParent(); return true; } bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) { // Do not visit GEPs more than once if (GEPChainInfoMap.contains(cast(&GEP))) return false; Value *PtrOperand = GEP.getPointerOperand(); // It shouldn't(?) be possible for the pointer operand of a GEP to be a PHI // node unless HLSL has pointers. If this assumption is incorrect or HLSL gets // pointer types, then the handling of this case can be implemented later. assert(!isa(PtrOperand) && "Pointer operand of GEP should not be a PHI Node"); // Replace a GEP ConstantExpr pointer operand with a GEP instruction so that // it can be visited if (auto *PtrOpGEPCE = dyn_cast(PtrOperand); PtrOpGEPCE && PtrOpGEPCE->getOpcode() == Instruction::GetElementPtr) { GetElementPtrInst *OldGEPI = cast(PtrOpGEPCE->getAsInstruction()); OldGEPI->insertBefore(GEP.getIterator()); IRBuilder<> Builder(&GEP); SmallVector Indices(GEP.indices()); Value *NewGEP = Builder.CreateGEP(GEP.getSourceElementType(), OldGEPI, Indices, GEP.getName(), GEP.getNoWrapFlags()); assert(isa(NewGEP) && "Expected newly-created GEP to be an instruction"); GetElementPtrInst *NewGEPI = cast(NewGEP); GEP.replaceAllUsesWith(NewGEPI); GEP.eraseFromParent(); visitGetElementPtrInst(*OldGEPI); visitGetElementPtrInst(*NewGEPI); return true; } // Construct GEPInfo for this GEP GEPInfo Info; // Obtain the variable and constant byte offsets computed by this GEP const DataLayout &DL = GEP.getDataLayout(); unsigned BitWidth = DL.getIndexTypeSizeInBits(GEP.getType()); Info.ConstantOffset = {BitWidth, 0}; [[maybe_unused]] bool Success = GEP.collectOffset( DL, BitWidth, Info.VariableOffsets, Info.ConstantOffset); assert(Success && "Failed to collect offsets for GEP"); // If there is a parent GEP, inherit the root array type and pointer, and // merge the byte offsets. Otherwise, this GEP is itself the root of a GEP // chain and we need to deterine the root array type if (auto *PtrOpGEP = dyn_cast(PtrOperand)) { assert(GEPChainInfoMap.contains(PtrOpGEP) && "Expected parent GEP to be visited before this GEP"); GEPInfo &PGEPInfo = GEPChainInfoMap[PtrOpGEP]; Info.RootFlattenedArrayType = PGEPInfo.RootFlattenedArrayType; Info.RootPointerOperand = PGEPInfo.RootPointerOperand; for (auto &VariableOffset : PGEPInfo.VariableOffsets) Info.VariableOffsets.insert(VariableOffset); Info.ConstantOffset += PGEPInfo.ConstantOffset; } else { Info.RootPointerOperand = PtrOperand; // We should try to determine the type of the root from the pointer rather // than the GEP's source element type because this could be a scalar GEP // into an array-typed pointer from an Alloca or Global Variable. Type *RootTy = GEP.getSourceElementType(); if (auto *GlobalVar = dyn_cast(PtrOperand)) { if (GlobalMap.contains(GlobalVar)) GlobalVar = GlobalMap[GlobalVar]; Info.RootPointerOperand = GlobalVar; RootTy = GlobalVar->getValueType(); } else if (auto *Alloca = dyn_cast(PtrOperand)) RootTy = Alloca->getAllocatedType(); assert(!isMultiDimensionalArray(RootTy) && "Expected root array type to be flattened"); // If the root type is not an array, we don't need to do any flattening if (!isa(RootTy)) return false; Info.RootFlattenedArrayType = cast(RootTy); } // GEPs without users or GEPs with non-GEP users should be replaced such that // the chain of GEPs they are a part of are collapsed to a single GEP into a // flattened array. bool ReplaceThisGEP = GEP.users().empty(); for (Value *User : GEP.users()) if (!isa(User)) ReplaceThisGEP = true; if (ReplaceThisGEP) { unsigned BytesPerElem = DL.getTypeAllocSize(Info.RootFlattenedArrayType->getArrayElementType()); assert(isPowerOf2_32(BytesPerElem) && "Bytes per element should be a power of 2"); // Compute the 32-bit index for this flattened GEP from the constant and // variable byte offsets in the GEPInfo IRBuilder<> Builder(&GEP); Value *ZeroIndex = Builder.getInt32(0); uint64_t ConstantOffset = Info.ConstantOffset.udiv(BytesPerElem).getZExtValue(); assert(ConstantOffset < UINT32_MAX && "Constant byte offset for flat GEP index must fit within 32 bits"); Value *FlattenedIndex = Builder.getInt32(ConstantOffset); for (auto [VarIndex, Multiplier] : Info.VariableOffsets) { assert(Multiplier.getActiveBits() <= 32 && "The multiplier for a flat GEP index must fit within 32 bits"); assert(VarIndex->getType()->isIntegerTy(32) && "Expected i32-typed GEP indices"); Value *VI; if (Multiplier.getZExtValue() % BytesPerElem != 0) { // This can happen, e.g., with i8 GEPs. To handle this we just divide // by BytesPerElem using an instruction after multiplying VarIndex by // Multiplier. VI = Builder.CreateMul(VarIndex, Builder.getInt32(Multiplier.getZExtValue())); VI = Builder.CreateLShr(VI, Builder.getInt32(Log2_32(BytesPerElem))); } else VI = Builder.CreateMul( VarIndex, Builder.getInt32(Multiplier.getZExtValue() / BytesPerElem)); FlattenedIndex = Builder.CreateAdd(FlattenedIndex, VI); } // Construct a new GEP for the flattened array to replace the current GEP Value *NewGEP = Builder.CreateGEP( Info.RootFlattenedArrayType, Info.RootPointerOperand, {ZeroIndex, FlattenedIndex}, GEP.getName(), GEP.getNoWrapFlags()); // Replace the current GEP with the new GEP. Store GEPInfo into the map // for later use in case this GEP was not the end of the chain GEPChainInfoMap.insert({cast(NewGEP), std::move(Info)}); GEP.replaceAllUsesWith(NewGEP); GEP.eraseFromParent(); return true; } // This GEP is potentially dead at the end of the pass since it may not have // any users anymore after GEP chains have been collapsed. We retain store // GEPInfo for GEPs down the chain to use to compute their indices. GEPChainInfoMap.insert({cast(&GEP), std::move(Info)}); PotentiallyDeadInstrs.emplace_back(&GEP); return false; } bool DXILFlattenArraysVisitor::visit(Function &F) { bool MadeChange = false; ReversePostOrderTraversal RPOT(&F); for (BasicBlock *BB : make_early_inc_range(RPOT)) { for (Instruction &I : make_early_inc_range(*BB)) MadeChange |= InstVisitor::visit(I); } finish(); return MadeChange; } static void collectElements(Constant *Init, SmallVectorImpl &Elements) { // Base case: If Init is not an array, add it directly to the vector. auto *ArrayTy = dyn_cast(Init->getType()); if (!ArrayTy) { Elements.push_back(Init); return; } unsigned ArrSize = ArrayTy->getNumElements(); if (isa(Init)) { for (unsigned I = 0; I < ArrSize; ++I) Elements.push_back(Constant::getNullValue(ArrayTy->getElementType())); return; } // Recursive case: Process each element in the array. if (auto *ArrayConstant = dyn_cast(Init)) { for (unsigned I = 0; I < ArrayConstant->getNumOperands(); ++I) { collectElements(ArrayConstant->getOperand(I), Elements); } } else if (auto *DataArrayConstant = dyn_cast(Init)) { for (unsigned I = 0; I < DataArrayConstant->getNumElements(); ++I) { collectElements(DataArrayConstant->getElementAsConstant(I), Elements); } } else { llvm_unreachable( "Expected a ConstantArray or ConstantDataArray for array initializer!"); } } static Constant *transformInitializer(Constant *Init, Type *OrigType, ArrayType *FlattenedType, LLVMContext &Ctx) { // Handle ConstantAggregateZero (zero-initialized constants) if (isa(Init)) return ConstantAggregateZero::get(FlattenedType); // Handle UndefValue (undefined constants) if (isa(Init)) return UndefValue::get(FlattenedType); if (!isa(OrigType)) return Init; SmallVector FlattenedElements; collectElements(Init, FlattenedElements); assert(FlattenedType->getNumElements() == FlattenedElements.size() && "The number of collected elements should match the FlattenedType"); return ConstantArray::get(FlattenedType, FlattenedElements); } static void flattenGlobalArrays( Module &M, SmallDenseMap &GlobalMap) { LLVMContext &Ctx = M.getContext(); for (GlobalVariable &G : M.globals()) { Type *OrigType = G.getValueType(); if (!DXILFlattenArraysVisitor::isMultiDimensionalArray(OrigType)) continue; ArrayType *ArrType = cast(OrigType); auto [TotalElements, BaseType] = DXILFlattenArraysVisitor::getElementCountAndType(ArrType); ArrayType *FattenedArrayType = ArrayType::get(BaseType, TotalElements); // Create a new global variable with the updated type // Note: Initializer is set via transformInitializer GlobalVariable *NewGlobal = new GlobalVariable(M, FattenedArrayType, G.isConstant(), G.getLinkage(), /*Initializer=*/nullptr, G.getName() + ".1dim", &G, G.getThreadLocalMode(), G.getAddressSpace(), G.isExternallyInitialized()); // Copy relevant attributes NewGlobal->setUnnamedAddr(G.getUnnamedAddr()); if (G.getAlignment() > 0) { NewGlobal->setAlignment(G.getAlign()); } if (G.hasInitializer()) { Constant *Init = G.getInitializer(); Constant *NewInit = transformInitializer(Init, OrigType, FattenedArrayType, Ctx); NewGlobal->setInitializer(NewInit); } GlobalMap[&G] = NewGlobal; } } static bool flattenArrays(Module &M) { bool MadeChange = false; SmallDenseMap GlobalMap; flattenGlobalArrays(M, GlobalMap); DXILFlattenArraysVisitor Impl(GlobalMap); for (auto &F : make_early_inc_range(M.functions())) { if (F.isDeclaration()) continue; MadeChange |= Impl.visit(F); } for (auto &[Old, New] : GlobalMap) { Old->replaceAllUsesWith(New); Old->eraseFromParent(); MadeChange = true; } return MadeChange; } PreservedAnalyses DXILFlattenArrays::run(Module &M, ModuleAnalysisManager &) { bool MadeChanges = flattenArrays(M); if (!MadeChanges) return PreservedAnalyses::all(); PreservedAnalyses PA; return PA; } bool DXILFlattenArraysLegacy::runOnModule(Module &M) { return flattenArrays(M); } char DXILFlattenArraysLegacy::ID = 0; INITIALIZE_PASS_BEGIN(DXILFlattenArraysLegacy, DEBUG_TYPE, "DXIL Array Flattener", false, false) INITIALIZE_PASS_END(DXILFlattenArraysLegacy, DEBUG_TYPE, "DXIL Array Flattener", false, false) ModulePass *llvm::createDXILFlattenArraysLegacyPass() { return new DXILFlattenArraysLegacy(); }