//===-- SPIRVRegularizer.cpp - regularize IR for SPIR-V ---------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This pass implements regularization of LLVM IR for SPIR-V. The prototype of // the pass was taken from SPIRV-LLVM translator. // //===----------------------------------------------------------------------===// #include "SPIRV.h" #include "llvm/IR/Constants.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/PassManager.h" #include #define DEBUG_TYPE "spirv-regularizer" using namespace llvm; namespace { struct SPIRVRegularizer : public FunctionPass { public: static char ID; SPIRVRegularizer() : FunctionPass(ID) {} bool runOnFunction(Function &F) override; StringRef getPassName() const override { return "SPIR-V Regularizer"; } void getAnalysisUsage(AnalysisUsage &AU) const override { FunctionPass::getAnalysisUsage(AU); } private: void runLowerConstExpr(Function &F); void runLowerI1Comparisons(Function &F); }; } // namespace char SPIRVRegularizer::ID = 0; INITIALIZE_PASS(SPIRVRegularizer, DEBUG_TYPE, "SPIR-V Regularizer", false, false) // Since SPIR-V cannot represent constant expression, constant expressions // in LLVM IR need to be lowered to instructions. For each function, // the constant expressions used by instructions of the function are replaced // by instructions placed in the entry block since it dominates all other BBs. // Each constant expression only needs to be lowered once in each function // and all uses of it by instructions in that function are replaced by // one instruction. // TODO: remove redundant instructions for common subexpression. void SPIRVRegularizer::runLowerConstExpr(Function &F) { LLVMContext &Ctx = F.getContext(); std::list WorkList; for (auto &II : instructions(F)) WorkList.push_back(&II); auto FBegin = F.begin(); while (!WorkList.empty()) { Instruction *II = WorkList.front(); auto LowerOp = [&II, &FBegin, &F](Value *V) -> Value * { if (isa(V)) return V; auto *CE = cast(V); LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] " << *CE); auto ReplInst = CE->getAsInstruction(); auto InsPoint = II->getParent() == &*FBegin ? II : &FBegin->back(); ReplInst->insertBefore(InsPoint->getIterator()); LLVM_DEBUG(dbgs() << " -> " << *ReplInst << '\n'); std::vector Users; // Do not replace use during iteration of use. Do it in another loop. for (auto U : CE->users()) { LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] Use: " << *U << '\n'); auto InstUser = dyn_cast(U); // Only replace users in scope of current function. if (InstUser && InstUser->getParent()->getParent() == &F) Users.push_back(InstUser); } for (auto &User : Users) { if (ReplInst->getParent() == User->getParent() && User->comesBefore(ReplInst)) ReplInst->moveBefore(User->getIterator()); User->replaceUsesOfWith(CE, ReplInst); } return ReplInst; }; WorkList.pop_front(); auto LowerConstantVec = [&II, &LowerOp, &WorkList, &Ctx](ConstantVector *Vec, unsigned NumOfOp) -> Value * { if (std::all_of(Vec->op_begin(), Vec->op_end(), [](Value *V) { return isa(V) || isa(V); })) { // Expand a vector of constexprs and construct it back with // series of insertelement instructions. std::list OpList; std::transform(Vec->op_begin(), Vec->op_end(), std::back_inserter(OpList), [LowerOp](Value *V) { return LowerOp(V); }); Value *Repl = nullptr; unsigned Idx = 0; auto *PhiII = dyn_cast(II); Instruction *InsPoint = PhiII ? &PhiII->getIncomingBlock(NumOfOp)->back() : II; std::list ReplList; for (auto V : OpList) { if (auto *Inst = dyn_cast(V)) ReplList.push_back(Inst); Repl = InsertElementInst::Create( (Repl ? Repl : PoisonValue::get(Vec->getType())), V, ConstantInt::get(Type::getInt32Ty(Ctx), Idx++), "", InsPoint->getIterator()); } WorkList.splice(WorkList.begin(), ReplList); return Repl; } return nullptr; }; for (unsigned OI = 0, OE = II->getNumOperands(); OI != OE; ++OI) { auto *Op = II->getOperand(OI); if (auto *Vec = dyn_cast(Op)) { Value *ReplInst = LowerConstantVec(Vec, OI); if (ReplInst) II->replaceUsesOfWith(Op, ReplInst); } else if (auto CE = dyn_cast(Op)) { WorkList.push_front(cast(LowerOp(CE))); } else if (auto MDAsVal = dyn_cast(Op)) { auto ConstMD = dyn_cast(MDAsVal->getMetadata()); if (!ConstMD) continue; Constant *C = ConstMD->getValue(); Value *ReplInst = nullptr; if (auto *Vec = dyn_cast(C)) ReplInst = LowerConstantVec(Vec, OI); if (auto *CE = dyn_cast(C)) ReplInst = LowerOp(CE); if (!ReplInst) continue; Metadata *RepMD = ValueAsMetadata::get(ReplInst); Value *RepMDVal = MetadataAsValue::get(Ctx, RepMD); II->setOperand(OI, RepMDVal); WorkList.push_front(cast(ReplInst)); } } } } // Lower i1 comparisons with certain predicates to logical operations. // The backend treats i1 as boolean values, and SPIR-V only allows logical // operations for boolean values. This function lowers i1 comparisons with // certain predicates to logical operations to generate valid SPIR-V. void SPIRVRegularizer::runLowerI1Comparisons(Function &F) { for (auto &I : make_early_inc_range(instructions(F))) { auto *Cmp = dyn_cast(&I); if (!Cmp) continue; bool IsI1 = Cmp->getOperand(0)->getType()->isIntegerTy(1); if (!IsI1) continue; auto Pred = Cmp->getPredicate(); bool IsTargetPred = Pred >= ICmpInst::ICMP_UGT && Pred <= ICmpInst::ICMP_SLE; if (!IsTargetPred) continue; Value *P = Cmp->getOperand(0); Value *Q = Cmp->getOperand(1); IRBuilder<> Builder(Cmp); Value *Result = nullptr; switch (Pred) { case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_SLT: // Result = p & !q Result = Builder.CreateAnd(P, Builder.CreateNot(Q)); break; case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_SGT: // Result = q & !p Result = Builder.CreateAnd(Q, Builder.CreateNot(P)); break; case ICmpInst::ICMP_ULE: case ICmpInst::ICMP_SGE: // Result = q | !p Result = Builder.CreateOr(Q, Builder.CreateNot(P)); break; case ICmpInst::ICMP_UGE: case ICmpInst::ICMP_SLE: // Result = p | !q Result = Builder.CreateOr(P, Builder.CreateNot(Q)); break; default: llvm_unreachable("Unexpected predicate"); } Result->takeName(Cmp); Cmp->replaceAllUsesWith(Result); Cmp->eraseFromParent(); } } bool SPIRVRegularizer::runOnFunction(Function &F) { runLowerI1Comparisons(F); runLowerConstExpr(F); return true; } FunctionPass *llvm::createSPIRVRegularizerPass() { return new SPIRVRegularizer(); }