diff options
author | Igor Kirillov <igor.kirillov@arm.com> | 2023-04-17 18:24:45 +0000 |
---|---|---|
committer | Igor Kirillov <igor.kirillov@arm.com> | 2023-05-31 18:31:38 +0000 |
commit | 1a1e76100e3f99c2bf0babcab52da333c12631e2 (patch) | |
tree | 7ee4a15b3c41deb4d457a823d9a0296b177d0407 /llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp | |
parent | 1ca458f78e26e785b6eca2946a7558d8c39c7490 (diff) | |
download | llvm-1a1e76100e3f99c2bf0babcab52da333c12631e2.zip llvm-1a1e76100e3f99c2bf0babcab52da333c12631e2.tar.gz llvm-1a1e76100e3f99c2bf0babcab52da333c12631e2.tar.bz2 |
[CodeGen] Improve handling -Ofast generated code by ComplexDeinterleaving pass
Code generated with -Ofast and -O3 -ffp-contract=fast (add
-ffinite-math-only to enable vectorization) can differ significantly.
Code compiled with -O3 can be deinterleaved using patterns as the
instruction order is preserved. However, with the -Ofast flag, there
can be multiple changes in the computation sequence, and even the real
and imaginary parts may not be calculated in parallel.
For more details, refer to
llvm/test/CodeGen/AArch64/complex-deinterleaving-*-fast.ll and
llvm/test/CodeGen/AArch64/complex-deinterleaving-*-contract.ll tests.
This patch implements a more general approach and enables handling most
-Ofast cases.
Differential Revision: https://reviews.llvm.org/D148558
Diffstat (limited to 'llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp')
-rw-r--r-- | llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp | 582 |
1 files changed, 547 insertions, 35 deletions
diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp index 4351d68..ec7abb2 100644 --- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp +++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp @@ -143,6 +143,11 @@ public: Instruction *Real; Instruction *Imag; + // This two members are required exclusively for generating + // ComplexDeinterleavingOperation::Symmetric operations. + unsigned Opcode; + FastMathFlags Flags; + ComplexDeinterleavingRotation Rotation = ComplexDeinterleavingRotation::Rotation_0; SmallVector<RawNodePtr> Operands; @@ -186,8 +191,26 @@ public: class ComplexDeinterleavingGraph { public: + struct Product { + Instruction *Multiplier; + Instruction *Multiplicand; + bool IsPositive; + }; + + using Addend = std::pair<Instruction *, bool>; using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr; using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr; + + // Helper struct for holding info about potential partial multiplication + // candidates + struct PartialMulCandidate { + Instruction *Common; + NodePtr Node; + unsigned RealIdx; + unsigned ImagIdx; + bool IsNodeInverted; + }; + explicit ComplexDeinterleavingGraph(const TargetLowering *TL, const TargetLibraryInfo *TLI) : TL(TL), TLI(TLI) {} @@ -256,6 +279,40 @@ private: NodePtr identifyNode(Instruction *I, Instruction *J); + /// Determine if a sum of complex numbers can be formed from \p RealAddends + /// and \p ImagAddens. If \p Accumulator is not null, add the result to it. + /// Return nullptr if it is not possible to construct a complex number. + /// \p Flags are needed to generate symmetric Add and Sub operations. + NodePtr identifyAdditions(std::list<Addend> &RealAddends, + std::list<Addend> &ImagAddends, FastMathFlags Flags, + NodePtr Accumulator); + + /// Extract one addend that have both real and imaginary parts positive. + NodePtr extractPositiveAddend(std::list<Addend> &RealAddends, + std::list<Addend> &ImagAddends); + + /// Determine if sum of multiplications of complex numbers can be formed from + /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result + /// to it. Return nullptr if it is not possible to construct a complex number. + NodePtr identifyMultiplications(std::vector<Product> &RealMuls, + std::vector<Product> &ImagMuls, + NodePtr Accumulator); + + /// Go through pairs of multiplication (one Real and one Imag) and find all + /// possible candidates for partial multiplication and put them into \p + /// Candidates. Returns true if all Product has pair with common operand + bool collectPartialMuls(const std::vector<Product> &RealMuls, + const std::vector<Product> &ImagMuls, + std::vector<PartialMulCandidate> &Candidates); + + /// If the code is compiled with -Ofast or expressions have `reassoc` flag, + /// the order of complex computation operations may be significantly altered, + /// and the real and imaginary parts may not be executed in parallel. This + /// function takes this into consideration and employs a more general approach + /// to identify complex computations. Initially, it gathers all the addends + /// and multiplicands and then constructs a complex expression from them. + NodePtr identifyReassocNodes(Instruction *I, Instruction *J); + NodePtr identifyRoot(Instruction *I); /// Identifies the Deinterleave operation applied to a vector containing @@ -737,8 +794,16 @@ ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real, return nullptr; } + if (isa<FPMathOperator>(Real) && + Real->getFastMathFlags() != Imag->getFastMathFlags()) + return nullptr; + auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric, Real, Imag); + Node->Opcode = Real->getOpcode(); + if (isa<FPMathOperator>(Real)) + Node->Flags = Real->getFastMathFlags(); + Node->addOperand(Op0); if (Real->isBinaryOp()) Node->addOperand(Op1); @@ -754,29 +819,477 @@ ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) { return CN; } - NodePtr Node = identifyDeinterleave(Real, Imag); - if (Node) - return Node; + if (NodePtr CN = identifyDeinterleave(Real, Imag)) + return CN; auto *VTy = cast<VectorType>(Real->getType()); auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); - if (TL->isComplexDeinterleavingOperationSupported( - ComplexDeinterleavingOperation::CMulPartial, NewVTy) && - isInstructionPairMul(Real, Imag)) { - return identifyPartialMul(Real, Imag); + bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported( + ComplexDeinterleavingOperation::CMulPartial, NewVTy); + bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported( + ComplexDeinterleavingOperation::CAdd, NewVTy); + + if (HasCMulSupport && isInstructionPairMul(Real, Imag)) { + if (NodePtr CN = identifyPartialMul(Real, Imag)) + return CN; + } + + if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) { + if (NodePtr CN = identifyAdd(Real, Imag)) + return CN; + } + + if (HasCMulSupport && HasCAddSupport) { + if (NodePtr CN = identifyReassocNodes(Real, Imag)) + return CN; + } + + if (NodePtr CN = identifySymmetricOperation(Real, Imag)) + return CN; + + LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n"); + return nullptr; +} + +ComplexDeinterleavingGraph::NodePtr +ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real, + Instruction *Imag) { + if ((Real->getOpcode() != Instruction::FAdd && + Real->getOpcode() != Instruction::FSub && + Real->getOpcode() != Instruction::FNeg) || + (Imag->getOpcode() != Instruction::FAdd && + Imag->getOpcode() != Instruction::FSub && + Imag->getOpcode() != Instruction::FNeg)) + return nullptr; + + if (Real->getFastMathFlags() != Imag->getFastMathFlags()) { + LLVM_DEBUG( + dbgs() + << "The flags in Real and Imaginary instructions are not identical\n"); + return nullptr; + } + + FastMathFlags Flags = Real->getFastMathFlags(); + if (!Flags.allowReassoc()) { + LLVM_DEBUG( + dbgs() << "the 'Reassoc' attribute is missing in the FastMath flags\n"); + return nullptr; + } + + // Collect multiplications and addend instructions from the given instruction + // while traversing it operands. Additionally, verify that all instructions + // have the same fast math flags. + auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls, + std::list<Addend> &Addends) -> bool { + SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}}; + SmallPtrSet<Value *, 8> Visited; + while (!Worklist.empty()) { + auto [V, IsPositive] = Worklist.back(); + Worklist.pop_back(); + if (!Visited.insert(V).second) + continue; + + Instruction *I = dyn_cast<Instruction>(V); + if (!I) + return false; + + // If an instruction has more than one user, it indicates that it either + // has an external user, which will be later checked by the checkNodes + // function, or it is a subexpression utilized by multiple expressions. In + // the latter case, we will attempt to separately identify the complex + // operation from here in order to create a shared + // ComplexDeinterleavingCompositeNode. + if (I != Insn && I->getNumUses() > 1) { + LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n"); + Addends.emplace_back(I, IsPositive); + continue; + } + + if (I->getOpcode() == Instruction::FAdd) { + Worklist.emplace_back(I->getOperand(1), IsPositive); + Worklist.emplace_back(I->getOperand(0), IsPositive); + } else if (I->getOpcode() == Instruction::FSub) { + Worklist.emplace_back(I->getOperand(1), !IsPositive); + Worklist.emplace_back(I->getOperand(0), IsPositive); + } else if (I->getOpcode() == Instruction::FMul) { + auto *A = dyn_cast<Instruction>(I->getOperand(0)); + if (A && A->getOpcode() == Instruction::FNeg) { + A = dyn_cast<Instruction>(A->getOperand(0)); + IsPositive = !IsPositive; + } + if (!A) + return false; + auto *B = dyn_cast<Instruction>(I->getOperand(1)); + if (B && B->getOpcode() == Instruction::FNeg) { + B = dyn_cast<Instruction>(B->getOperand(0)); + IsPositive = !IsPositive; + } + if (!B) + return false; + Muls.push_back(Product{A, B, IsPositive}); + } else if (I->getOpcode() == Instruction::FNeg) { + Worklist.emplace_back(I->getOperand(0), !IsPositive); + } else { + Addends.emplace_back(I, IsPositive); + continue; + } + + if (I->getFastMathFlags() != Flags) { + LLVM_DEBUG(dbgs() << "The instruction's fast math flags are " + "inconsistent with the root instructions' flags: " + << *I << "\n"); + return false; + } + } + return true; + }; + + std::vector<Product> RealMuls, ImagMuls; + std::list<Addend> RealAddends, ImagAddends; + if (!Collect(Real, RealMuls, RealAddends) || + !Collect(Imag, ImagMuls, ImagAddends)) + return nullptr; + + if (RealAddends.size() != ImagAddends.size()) + return nullptr; + + NodePtr FinalNode; + if (!RealMuls.empty() || !ImagMuls.empty()) { + // If there are multiplicands, extract positive addend and use it as an + // accumulator + FinalNode = extractPositiveAddend(RealAddends, ImagAddends); + FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode); + if (!FinalNode) + return nullptr; } - if (TL->isComplexDeinterleavingOperationSupported( - ComplexDeinterleavingOperation::CAdd, NewVTy) && - isInstructionPairAdd(Real, Imag)) { - return identifyAdd(Real, Imag); + // Identify and process remaining additions + if (!RealAddends.empty() || !ImagAddends.empty()) { + FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode); + if (!FinalNode) + return nullptr; } - auto Symmetric = identifySymmetricOperation(Real, Imag); - LLVM_DEBUG(if (Symmetric == nullptr) dbgs() - << " - Not recognised as a valid pattern.\n"); - return Symmetric; + // Set the Real and Imag fields of the final node and submit it + FinalNode->Real = Real; + FinalNode->Imag = Imag; + submitCompositeNode(FinalNode); + return FinalNode; +} + +bool ComplexDeinterleavingGraph::collectPartialMuls( + const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls, + std::vector<PartialMulCandidate> &PartialMulCandidates) { + // Helper function to extract a common operand from two products + auto FindCommonInstruction = [](const Product &Real, + const Product &Imag) -> Instruction * { + if (Real.Multiplicand == Imag.Multiplicand || + Real.Multiplicand == Imag.Multiplier) + return Real.Multiplicand; + + if (Real.Multiplier == Imag.Multiplicand || + Real.Multiplier == Imag.Multiplier) + return Real.Multiplier; + + return nullptr; + }; + + // Iterating over real and imaginary multiplications to find common operands + // If a common operand is found, a partial multiplication candidate is created + // and added to the candidates vector The function returns false if no common + // operands are found for any product + for (unsigned i = 0; i < RealMuls.size(); ++i) { + bool FoundCommon = false; + for (unsigned j = 0; j < ImagMuls.size(); ++j) { + auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]); + if (!Common) + continue; + + auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier + : RealMuls[i].Multiplicand; + auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier + : ImagMuls[j].Multiplicand; + + bool Inverted = false; + auto Node = identifyNode(A, B); + if (!Node) { + std::swap(A, B); + Inverted = true; + Node = identifyNode(A, B); + } + if (!Node) + continue; + + FoundCommon = true; + PartialMulCandidates.push_back({Common, Node, i, j, Inverted}); + } + if (!FoundCommon) + return false; + } + return true; +} + +ComplexDeinterleavingGraph::NodePtr +ComplexDeinterleavingGraph::identifyMultiplications( + std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls, + NodePtr Accumulator = nullptr) { + if (RealMuls.size() != ImagMuls.size()) + return nullptr; + + std::vector<PartialMulCandidate> Info; + if (!collectPartialMuls(RealMuls, ImagMuls, Info)) + return nullptr; + + // Map to store common instruction to node pointers + std::map<Instruction *, NodePtr> CommonToNode; + std::vector<bool> Processed(Info.size(), false); + for (unsigned I = 0; I < Info.size(); ++I) { + if (Processed[I]) + continue; + + PartialMulCandidate &InfoA = Info[I]; + for (unsigned J = I + 1; J < Info.size(); ++J) { + if (Processed[J]) + continue; + + PartialMulCandidate &InfoB = Info[J]; + auto *InfoReal = &InfoA; + auto *InfoImag = &InfoB; + + auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common); + if (!NodeFromCommon) { + std::swap(InfoReal, InfoImag); + NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common); + } + if (!NodeFromCommon) + continue; + + CommonToNode[InfoReal->Common] = NodeFromCommon; + CommonToNode[InfoImag->Common] = NodeFromCommon; + Processed[I] = true; + Processed[J] = true; + } + } + + std::vector<bool> ProcessedReal(RealMuls.size(), false); + std::vector<bool> ProcessedImag(ImagMuls.size(), false); + NodePtr Result = Accumulator; + for (auto &PMI : Info) { + if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx]) + continue; + + auto It = CommonToNode.find(PMI.Common); + // TODO: Process independent complex multiplications. Cases like this: + // A.real() * B where both A and B are complex numbers. + if (It == CommonToNode.end()) { + LLVM_DEBUG({ + dbgs() << "Unprocessed independent partial multiplication:\n"; + for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]}) + dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier + << " multiplied by " << *Mul->Multiplicand << "\n"; + }); + return nullptr; + } + + auto &RealMul = RealMuls[PMI.RealIdx]; + auto &ImagMul = ImagMuls[PMI.ImagIdx]; + + auto NodeA = It->second; + auto NodeB = PMI.Node; + auto IsMultiplicandReal = PMI.Common == NodeA->Real; + // The following table illustrates the relationship between multiplications + // and rotations. If we consider the multiplication (X + iY) * (U + iV), we + // can see: + // + // Rotation | Real | Imag | + // ---------+--------+--------+ + // 0 | x * u | x * v | + // 90 | -y * v | y * u | + // 180 | -x * u | -x * v | + // 270 | y * v | -y * u | + // + // Check if the candidate can indeed be represented by partial + // multiplication + // TODO: Add support for multiplication by complex one + if ((IsMultiplicandReal && PMI.IsNodeInverted) || + (!IsMultiplicandReal && !PMI.IsNodeInverted)) + continue; + + // Determine the rotation based on the multiplications + ComplexDeinterleavingRotation Rotation; + if (IsMultiplicandReal) { + // Detect 0 and 180 degrees rotation + if (RealMul.IsPositive && ImagMul.IsPositive) + Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0; + else if (!RealMul.IsPositive && !ImagMul.IsPositive) + Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180; + else + continue; + + } else { + // Detect 90 and 270 degrees rotation + if (!RealMul.IsPositive && ImagMul.IsPositive) + Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90; + else if (RealMul.IsPositive && !ImagMul.IsPositive) + Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270; + else + continue; + } + + LLVM_DEBUG({ + dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n"; + dbgs().indent(4) << "X: " << *NodeA->Real << "\n"; + dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n"; + dbgs().indent(4) << "U: " << *NodeB->Real << "\n"; + dbgs().indent(4) << "V: " << *NodeB->Imag << "\n"; + dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n"; + }); + + NodePtr NodeMul = prepareCompositeNode( + ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr); + NodeMul->Rotation = Rotation; + NodeMul->addOperand(NodeA); + NodeMul->addOperand(NodeB); + if (Result) + NodeMul->addOperand(Result); + submitCompositeNode(NodeMul); + Result = NodeMul; + ProcessedReal[PMI.RealIdx] = true; + ProcessedImag[PMI.ImagIdx] = true; + } + + // Ensure all products have been processed, if not return nullptr. + if (!all_of(ProcessedReal, [](bool V) { return V; }) || + !all_of(ProcessedImag, [](bool V) { return V; })) { + + // Dump debug information about which partial multiplications are not + // processed. + LLVM_DEBUG({ + dbgs() << "Unprocessed products (Real):\n"; + for (size_t i = 0; i < ProcessedReal.size(); ++i) { + if (!ProcessedReal[i]) + dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-") + << *RealMuls[i].Multiplier << " multiplied by " + << *RealMuls[i].Multiplicand << "\n"; + } + dbgs() << "Unprocessed products (Imag):\n"; + for (size_t i = 0; i < ProcessedImag.size(); ++i) { + if (!ProcessedImag[i]) + dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-") + << *ImagMuls[i].Multiplier << " multiplied by " + << *ImagMuls[i].Multiplicand << "\n"; + } + }); + return nullptr; + } + + return Result; +} + +ComplexDeinterleavingGraph::NodePtr +ComplexDeinterleavingGraph::identifyAdditions(std::list<Addend> &RealAddends, + std::list<Addend> &ImagAddends, + FastMathFlags Flags, + NodePtr Accumulator = nullptr) { + if (RealAddends.size() != ImagAddends.size()) + return nullptr; + + NodePtr Result; + // If we have accumulator use it as first addend + if (Accumulator) + Result = Accumulator; + // Otherwise find an element with both positive real and imaginary parts. + else + Result = extractPositiveAddend(RealAddends, ImagAddends); + + if (!Result) + return nullptr; + + while (!RealAddends.empty()) { + auto ItR = RealAddends.begin(); + auto [R, IsPositiveR] = *ItR; + + bool FoundImag = false; + for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) { + auto [I, IsPositiveI] = *ItI; + ComplexDeinterleavingRotation Rotation; + if (IsPositiveR && IsPositiveI) + Rotation = ComplexDeinterleavingRotation::Rotation_0; + else if (!IsPositiveR && IsPositiveI) + Rotation = ComplexDeinterleavingRotation::Rotation_90; + else if (!IsPositiveR && !IsPositiveI) + Rotation = ComplexDeinterleavingRotation::Rotation_180; + else + Rotation = ComplexDeinterleavingRotation::Rotation_270; + + NodePtr AddNode; + if (Rotation == ComplexDeinterleavingRotation::Rotation_0 || + Rotation == ComplexDeinterleavingRotation::Rotation_180) { + AddNode = identifyNode(R, I); + } else { + AddNode = identifyNode(I, R); + } + if (AddNode) { + LLVM_DEBUG({ + dbgs() << "Identified addition:\n"; + dbgs().indent(4) << "X: " << *R << "\n"; + dbgs().indent(4) << "Y: " << *I << "\n"; + dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n"; + }); + + NodePtr TmpNode; + if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) { + TmpNode = prepareCompositeNode( + ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr); + TmpNode->Opcode = Instruction::FAdd; + TmpNode->Flags = Flags; + } else if (Rotation == + llvm::ComplexDeinterleavingRotation::Rotation_180) { + TmpNode = prepareCompositeNode( + ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr); + TmpNode->Opcode = Instruction::FSub; + TmpNode->Flags = Flags; + } else { + TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, + nullptr, nullptr); + TmpNode->Rotation = Rotation; + } + + TmpNode->addOperand(Result); + TmpNode->addOperand(AddNode); + submitCompositeNode(TmpNode); + Result = TmpNode; + RealAddends.erase(ItR); + ImagAddends.erase(ItI); + FoundImag = true; + break; + } + } + if (!FoundImag) + return nullptr; + } + return Result; +} + +ComplexDeinterleavingGraph::NodePtr +ComplexDeinterleavingGraph::extractPositiveAddend( + std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) { + for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) { + for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) { + auto [R, IsPositiveR] = *ItR; + auto [I, IsPositiveI] = *ItI; + if (IsPositiveR && IsPositiveI) { + auto Result = identifyNode(R, I); + if (Result) { + RealAddends.erase(ItR); + ImagAddends.erase(ItI); + return Result; + } + } + } + } + return nullptr; } bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) { @@ -1011,29 +1524,28 @@ ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real, return submitCompositeNode(PlaceholderNode); } -static Value *replaceSymmetricNode(IRBuilderBase &B, - ComplexDeinterleavingGraph::RawNodePtr Node, - Value *InputA, Value *InputB) { - Instruction *I = Node->Real; - if (I->isUnaryOp()) - assert(!InputB && - "Unary symmetric operations need one input, but two were provided."); - else if (I->isBinaryOp()) - assert(InputB && "Binary symmetric operations need two inputs, only one " - "was provided."); - - switch (I->getOpcode()) { +static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, + FastMathFlags Flags, Value *InputA, + Value *InputB) { + Value *I; + switch (Opcode) { case Instruction::FNeg: - return B.CreateFNegFMF(InputA, I); + I = B.CreateFNeg(InputA); + break; case Instruction::FAdd: - return B.CreateFAddFMF(InputA, InputB, I); + I = B.CreateFAdd(InputA, InputB); + break; case Instruction::FSub: - return B.CreateFSubFMF(InputA, InputB, I); + I = B.CreateFSub(InputA, InputB); + break; case Instruction::FMul: - return B.CreateFMulFMF(InputA, InputB, I); + I = B.CreateFMul(InputA, InputB); + break; + default: + llvm_unreachable("Incorrect symmetric opcode"); } - - return nullptr; + cast<Instruction>(I)->setFastMathFlags(Flags); + return I; } Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder, @@ -1048,13 +1560,13 @@ Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder, Value *Accumulator = Node->Operands.size() > 2 ? replaceNode(Builder, Node->Operands[2]) : nullptr; - if (Input1) assert(Input0->getType() == Input1->getType() && "Node inputs need to be of the same type"); if (Node->Operation == ComplexDeinterleavingOperation::Symmetric) - Node->ReplacementNode = replaceSymmetricNode(Builder, Node, Input0, Input1); + Node->ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, + Node->Flags, Input0, Input1); else Node->ReplacementNode = TL->createComplexDeinterleavingIR( Builder, Node->Operation, Node->Rotation, Input0, Input1, Accumulator); |