diff options
author | Noah Goldstein <goldstein.w.n@gmail.com> | 2023-09-19 12:18:51 -0500 |
---|---|---|
committer | Noah Goldstein <goldstein.w.n@gmail.com> | 2023-09-20 13:28:24 -0500 |
commit | 47c642f9a0e936822ce23bdb834bcc4c29ae6484 (patch) | |
tree | 6e1ea0215567b21fbdb160c66f011c4fb05c9d84 /llvm/lib | |
parent | 32a46919a2f3009d19a2de75d1dbb0f530aa19ce (diff) | |
download | llvm-47c642f9a0e936822ce23bdb834bcc4c29ae6484.zip llvm-47c642f9a0e936822ce23bdb834bcc4c29ae6484.tar.gz llvm-47c642f9a0e936822ce23bdb834bcc4c29ae6484.tar.bz2 |
[DAGCombiner] Fold IEEE `fmul`/`fdiv` by Pow2 to `add`/`sub` of exp
Note: This is moving D154678 which previously implemented this in
InstCombine. Concerns where brought up that this was de-canonicalizing
and really targeting a codegen improvement, so placing in DAGCombiner.
This implements:
```
(fmul C, (uitofp Pow2))
-> (bitcast_to_FP (add (bitcast_to_INT C), Log2(Pow2) << mantissa))
(fdiv C, (uitofp Pow2))
-> (bitcast_to_FP (sub (bitcast_to_INT C), Log2(Pow2) << mantissa))
```
The motivation is mostly fdiv where 2^(-p) is a fairly common
expression.
The patch is intentionally conservative about the transform, only
doing so if we:
1) have IEEE floats
2) C is normal
3) add/sub of max(Log2(Pow2)) stays in the min/max exponent
bounds.
Alive2 can't realistically prove this, but did test float16/float32
cases (within the bounds of the above rules) exhaustively.
Reviewed By: RKSimon
Differential Revision: https://reviews.llvm.org/D154805
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 303 | ||||
-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 18 | ||||
-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.h | 3 |
3 files changed, 288 insertions, 36 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 5088d59..b69bada 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -611,6 +611,7 @@ namespace { SDValue CombineExtLoad(SDNode *N); SDValue CombineZExtLogicopShiftLoad(SDNode *N); SDValue combineRepeatedFPDivisors(SDNode *N); + SDValue combineFMulOrFDivWithIntPow2(SDNode *N); SDValue mergeInsertEltWithShuffle(SDNode *N, unsigned InsIndex); SDValue combineInsertEltToShuffle(SDNode *N, unsigned InsIndex); SDValue combineInsertEltToLoad(SDNode *N, unsigned InsIndex); @@ -620,7 +621,10 @@ namespace { SDValue BuildUDIV(SDNode *N); SDValue BuildSREMPow2(SDNode *N); SDValue buildOptimizedSREM(SDValue N0, SDValue N1, SDNode *N); - SDValue BuildLogBase2(SDValue V, const SDLoc &DL); + SDValue BuildLogBase2(SDValue V, const SDLoc &DL, + bool KnownNeverZero = false, + bool InexpensiveOnly = false, + std::optional<EVT> OutVT = std::nullopt); SDValue BuildDivEstimate(SDValue N, SDValue Op, SDNodeFlags Flags); SDValue buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags); SDValue buildSqrtEstimate(SDValue Op, SDNodeFlags Flags); @@ -4389,12 +4393,12 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { // fold (mul x, (1 << c)) -> x << c if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) && - DAG.isKnownToBeAPowerOfTwo(N1) && (!VT.isVector() || Level <= AfterLegalizeVectorOps)) { - SDValue LogBase2 = BuildLogBase2(N1, DL); - EVT ShiftVT = getShiftAmountTy(N0.getValueType()); - SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT); - return DAG.getNode(ISD::SHL, DL, VT, N0, Trunc); + if (SDValue LogBase2 = BuildLogBase2(N1, DL)) { + EVT ShiftVT = getShiftAmountTy(N0.getValueType()); + SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT); + return DAG.getNode(ISD::SHL, DL, VT, N0, Trunc); + } } // fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c @@ -4916,31 +4920,31 @@ SDValue DAGCombiner::visitUDIVLike(SDValue N0, SDValue N1, SDNode *N) { EVT VT = N->getValueType(0); // fold (udiv x, (1 << c)) -> x >>u c - if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) && - DAG.isKnownToBeAPowerOfTwo(N1)) { - SDValue LogBase2 = BuildLogBase2(N1, DL); - AddToWorklist(LogBase2.getNode()); + if (isConstantOrConstantVector(N1, /*NoOpaques*/ true)) { + if (SDValue LogBase2 = BuildLogBase2(N1, DL)) { + AddToWorklist(LogBase2.getNode()); - EVT ShiftVT = getShiftAmountTy(N0.getValueType()); - SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT); - AddToWorklist(Trunc.getNode()); - return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc); + EVT ShiftVT = getShiftAmountTy(N0.getValueType()); + SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT); + AddToWorklist(Trunc.getNode()); + return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc); + } } // fold (udiv x, (shl c, y)) -> x >>u (log2(c)+y) iff c is power of 2 if (N1.getOpcode() == ISD::SHL) { SDValue N10 = N1.getOperand(0); - if (isConstantOrConstantVector(N10, /*NoOpaques*/ true) && - DAG.isKnownToBeAPowerOfTwo(N10)) { - SDValue LogBase2 = BuildLogBase2(N10, DL); - AddToWorklist(LogBase2.getNode()); - - EVT ADDVT = N1.getOperand(1).getValueType(); - SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ADDVT); - AddToWorklist(Trunc.getNode()); - SDValue Add = DAG.getNode(ISD::ADD, DL, ADDVT, N1.getOperand(1), Trunc); - AddToWorklist(Add.getNode()); - return DAG.getNode(ISD::SRL, DL, VT, N0, Add); + if (isConstantOrConstantVector(N10, /*NoOpaques*/ true)) { + if (SDValue LogBase2 = BuildLogBase2(N10, DL)) { + AddToWorklist(LogBase2.getNode()); + + EVT ADDVT = N1.getOperand(1).getValueType(); + SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ADDVT); + AddToWorklist(Trunc.getNode()); + SDValue Add = DAG.getNode(ISD::ADD, DL, ADDVT, N1.getOperand(1), Trunc); + AddToWorklist(Add.getNode()); + return DAG.getNode(ISD::SRL, DL, VT, N0, Add); + } } } @@ -5158,14 +5162,15 @@ SDValue DAGCombiner::visitMULHU(SDNode *N) { // fold (mulhu x, (1 << c)) -> x >> (bitwidth - c) if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) && - DAG.isKnownToBeAPowerOfTwo(N1) && hasOperation(ISD::SRL, VT)) { - unsigned NumEltBits = VT.getScalarSizeInBits(); - SDValue LogBase2 = BuildLogBase2(N1, DL); - SDValue SRLAmt = DAG.getNode( - ISD::SUB, DL, VT, DAG.getConstant(NumEltBits, DL, VT), LogBase2); - EVT ShiftVT = getShiftAmountTy(N0.getValueType()); - SDValue Trunc = DAG.getZExtOrTrunc(SRLAmt, DL, ShiftVT); - return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc); + hasOperation(ISD::SRL, VT)) { + if (SDValue LogBase2 = BuildLogBase2(N1, DL)) { + unsigned NumEltBits = VT.getScalarSizeInBits(); + SDValue SRLAmt = DAG.getNode( + ISD::SUB, DL, VT, DAG.getConstant(NumEltBits, DL, VT), LogBase2); + EVT ShiftVT = getShiftAmountTy(N0.getValueType()); + SDValue Trunc = DAG.getZExtOrTrunc(SRLAmt, DL, ShiftVT); + return DAG.getNode(ISD::SRL, DL, VT, N0, Trunc); + } } // If the type twice as wide is legal, transform the mulhu to a wider multiply @@ -16328,6 +16333,105 @@ SDValue DAGCombiner::visitFSUB(SDNode *N) { return SDValue(); } +// Transform IEEE Floats: +// (fmul C, (uitofp Pow2)) +// -> (bitcast_to_FP (add (bitcast_to_INT C), Log2(Pow2) << mantissa)) +// (fdiv C, (uitofp Pow2)) +// -> (bitcast_to_FP (sub (bitcast_to_INT C), Log2(Pow2) << mantissa)) +// +// The rationale is fmul/fdiv by a power of 2 is just change the exponent, so +// there is no need for more than an add/sub. +// +// This is valid under the following circumstances: +// 1) We are dealing with IEEE floats +// 2) C is normal +// 3) The fmul/fdiv add/sub will not go outside of min/max exponent bounds. +// TODO: Much of this could also be used for generating `ldexp` on targets the +// prefer it. +SDValue DAGCombiner::combineFMulOrFDivWithIntPow2(SDNode *N) { + EVT VT = N->getValueType(0); + SDValue ConstOp, Pow2Op; + + int Mantissa = -1; + auto GetConstAndPow2Ops = [&](unsigned ConstOpIdx) { + if (ConstOpIdx == 1 && N->getOpcode() == ISD::FDIV) + return false; + + ConstOp = peekThroughBitcasts(N->getOperand(ConstOpIdx)); + Pow2Op = N->getOperand(1 - ConstOpIdx); + if (Pow2Op.getOpcode() != ISD::UINT_TO_FP && + (Pow2Op.getOpcode() != ISD::SINT_TO_FP || + !DAG.computeKnownBits(Pow2Op).isNonNegative())) + return false; + + Pow2Op = Pow2Op.getOperand(0); + + // TODO(1): We may be able to include undefs. + // TODO(2): We could also handle non-splat vector types. + ConstantFPSDNode *CFP = + isConstOrConstSplatFP(ConstOp, /*AllowUndefs*/ false); + if (CFP == nullptr) + return false; + const APFloat &APF = CFP->getValueAPF(); + + // Make sure we have normal/ieee constant. + if (!APF.isNormal() || !APF.isIEEE()) + return false; + + // `Log2(Pow2Op) < Pow2Op.getScalarSizeInBits()`. + // TODO: We could use knownbits to make this bound more precise. + int MaxExpChange = Pow2Op.getValueType().getScalarSizeInBits(); + + // Make sure the floats exponent is within the bounds that this transform + // produces bitwise equals value. + int CurExp = ilogb(APF); + // FMul by pow2 will only increase exponent. + int MinExp = N->getOpcode() == ISD::FMUL ? CurExp : (CurExp - MaxExpChange); + // FDiv by pow2 will only decrease exponent. + int MaxExp = N->getOpcode() == ISD::FDIV ? CurExp : (CurExp + MaxExpChange); + if (MinExp <= APFloat::semanticsMinExponent(APF.getSemantics()) || + MaxExp >= APFloat::semanticsMaxExponent(APF.getSemantics())) + return false; + + // Finally make sure we actually know the mantissa for the float type. + Mantissa = APFloat::semanticsPrecision(APF.getSemantics()) - 1; + return Mantissa > 0; + }; + + if (!GetConstAndPow2Ops(0) && !GetConstAndPow2Ops(1)) + return SDValue(); + + if (!TLI.optimizeFMulOrFDivAsShiftAddBitcast(N, ConstOp, Pow2Op)) + return SDValue(); + + // Get log2 after all other checks have taken place. This is because + // BuildLogBase2 may create a new node. + SDLoc DL(N); + // Get Log2 type with same bitwidth as the float type (VT). + EVT NewIntVT = EVT::getIntegerVT(*DAG.getContext(), VT.getScalarSizeInBits()); + if (VT.isVector()) + NewIntVT = EVT::getVectorVT(*DAG.getContext(), NewIntVT, + VT.getVectorNumElements()); + + SDValue Log2 = BuildLogBase2(Pow2Op, DL, DAG.isKnownNeverZero(Pow2Op), + /*InexpensiveOnly*/ true, NewIntVT); + if (!Log2) + return SDValue(); + + // Perform actual transform. + SDValue MantissaShiftCnt = + DAG.getConstant(Mantissa, DL, getShiftAmountTy(NewIntVT)); + // TODO: Sometimes Log2 is of form `(X + C)`. `(X + C) << C1` should fold to + // `(X << C1) + (C << C1)`, but that isn't always the case because of the + // cast. We could implement that by handle here to handle the casts. + SDValue Shift = DAG.getNode(ISD::SHL, DL, NewIntVT, Log2, MantissaShiftCnt); + SDValue ResAsInt = + DAG.getNode(N->getOpcode() == ISD::FMUL ? ISD::ADD : ISD::SUB, DL, + NewIntVT, DAG.getBitcast(NewIntVT, ConstOp), Shift); + SDValue ResAsFP = DAG.getBitcast(VT, ResAsInt); + return ResAsFP; +} + SDValue DAGCombiner::visitFMUL(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); @@ -16468,6 +16572,11 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) { return Fused; } + // Don't do `combineFMulOrFDivWithIntPow2` until after FMUL -> FMA has been + // able to run. + if (SDValue R = combineFMulOrFDivWithIntPow2(N)) + return R; + return SDValue(); } @@ -16819,6 +16928,9 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) { return DAG.getNode(ISD::FDIV, SDLoc(N), VT, NegN0, NegN1); } + if (SDValue R = combineFMulOrFDivWithIntPow2(N)) + return R; + return SDValue(); } @@ -21861,7 +21973,7 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) { if (DAG.isKnownNeverZero(Index)) return DAG.getUNDEF(ScalarVT); - // Check if the result type doesn't match the inserted element type. + // Check if the result type doesn't match the inserted element type. // The inserted element and extracted element may have mismatched bitwidth. // As a result, EXTRACT_VECTOR_ELT may extend or truncate the extracted vector. SDValue InOp = VecOp.getOperand(0); @@ -27142,10 +27254,129 @@ SDValue DAGCombiner::BuildSREMPow2(SDNode *N) { return SDValue(); } +// This is basically just a port of takeLog2 from InstCombineMulDivRem.cpp +// +// Returns the node that represents `Log2(Op)`. This may create a new node. If +// we are unable to compute `Log2(Op)` its return `SDValue()`. +// +// All nodes will be created at `DL` and the output will be of type `VT`. +// +// This will only return `Log2(Op)` if we can prove `Op` is non-zero. Set +// `AssumeNonZero` if this function should simply assume (not require proving +// `Op` is non-zero). +static SDValue takeInexpensiveLog2(SelectionDAG &DAG, const SDLoc &DL, EVT VT, + SDValue Op, unsigned Depth, + bool AssumeNonZero) { + assert(VT.isInteger() && "Only integer types are supported!"); + + auto PeekThroughCastsAndTrunc = [](SDValue V) { + while (true) { + switch (V.getOpcode()) { + case ISD::TRUNCATE: + case ISD::ZERO_EXTEND: + V = V.getOperand(0); + break; + default: + return V; + } + } + }; + + if (VT.isScalableVector()) + return SDValue(); + + Op = PeekThroughCastsAndTrunc(Op); + + // Helper for determining whether a value is a power-2 constant scalar or a + // vector of such elements. + SmallVector<APInt> Pow2Constants; + auto IsPowerOfTwo = [&Pow2Constants](ConstantSDNode *C) { + if (C->isZero() || C->isOpaque()) + return false; + // TODO: We may also be able to support negative powers of 2 here. + if (C->getAPIntValue().isPowerOf2()) { + Pow2Constants.emplace_back(C->getAPIntValue()); + return true; + } + return false; + }; + + if (ISD::matchUnaryPredicate(Op, IsPowerOfTwo)) { + if (!VT.isVector()) + return DAG.getConstant(Pow2Constants.back().logBase2(), DL, VT); + // We need to create a build vector + SmallVector<SDValue> Log2Ops; + for (const APInt &Pow2 : Pow2Constants) + Log2Ops.emplace_back( + DAG.getConstant(Pow2.logBase2(), DL, VT.getScalarType())); + return DAG.getBuildVector(VT, DL, Log2Ops); + } + + if (Depth >= DAG.MaxRecursionDepth) + return SDValue(); + + auto CastToVT = [&](EVT NewVT, SDValue ToCast) { + EVT CurVT = ToCast.getValueType(); + ToCast = PeekThroughCastsAndTrunc(ToCast); + if (NewVT == CurVT) + return ToCast; + + if (NewVT.getSizeInBits() == CurVT.getSizeInBits()) + return DAG.getBitcast(NewVT, ToCast); + + return DAG.getZExtOrTrunc(ToCast, DL, NewVT); + }; + + // log2(X << Y) -> log2(X) + Y + if (Op.getOpcode() == ISD::SHL) { + // 1 << Y and X nuw/nsw << Y are all non-zero. + if (AssumeNonZero || Op->getFlags().hasNoUnsignedWrap() || + Op->getFlags().hasNoSignedWrap() || isOneConstant(Op.getOperand(0))) + if (SDValue LogX = takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(0), + Depth + 1, AssumeNonZero)) + return DAG.getNode(ISD::ADD, DL, VT, LogX, + CastToVT(VT, Op.getOperand(1))); + } + + // c ? X : Y -> c ? Log2(X) : Log2(Y) + if ((Op.getOpcode() == ISD::SELECT || Op.getOpcode() == ISD::VSELECT) && + Op.hasOneUse()) { + if (SDValue LogX = takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(1), + Depth + 1, AssumeNonZero)) + if (SDValue LogY = takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(2), + Depth + 1, AssumeNonZero)) + return DAG.getSelect(DL, VT, Op.getOperand(0), LogX, LogY); + } + + // log2(umin(X, Y)) -> umin(log2(X), log2(Y)) + // log2(umax(X, Y)) -> umax(log2(X), log2(Y)) + if ((Op.getOpcode() == ISD::UMIN || Op.getOpcode() == ISD::UMAX) && + Op.hasOneUse()) { + // Use AssumeNonZero as false here. Otherwise we can hit case where + // log2(umax(X, Y)) != umax(log2(X), log2(Y)) (because overflow). + if (SDValue LogX = + takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(0), Depth + 1, + /*AssumeNonZero*/ false)) + if (SDValue LogY = + takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(1), Depth + 1, + /*AssumeNonZero*/ false)) + return DAG.getNode(Op.getOpcode(), DL, VT, LogX, LogY); + } + + return SDValue(); +} + /// Determines the LogBase2 value for a non-null input value using the /// transform: LogBase2(V) = (EltBits - 1) - ctlz(V). -SDValue DAGCombiner::BuildLogBase2(SDValue V, const SDLoc &DL) { - EVT VT = V.getValueType(); +SDValue DAGCombiner::BuildLogBase2(SDValue V, const SDLoc &DL, + bool KnownNonZero, bool InexpensiveOnly, + std::optional<EVT> OutVT) { + EVT VT = OutVT ? *OutVT : V.getValueType(); + SDValue InexpensiveLogBase2 = + takeInexpensiveLog2(DAG, DL, VT, V, /*Depth*/ 0, KnownNonZero); + if (InexpensiveLogBase2 || InexpensiveOnly || !DAG.isKnownToBeAPowerOfTwo(V)) + return InexpensiveLogBase2; + SDValue Ctlz = DAG.getNode(ISD::CTLZ, DL, VT, V); SDValue Base = DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT); SDValue LogBase2 = DAG.getNode(ISD::SUB, DL, VT, Base, Ctlz); diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 0724879..5d5040e 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -22445,6 +22445,24 @@ bool X86TargetLowering::isXAndYEqZeroPreferableToXAndYEqY(ISD::CondCode Cond, return !VT.isVector() || Cond != ISD::CondCode::SETEQ; } +bool X86TargetLowering::optimizeFMulOrFDivAsShiftAddBitcast( + SDNode *N, SDValue, SDValue IntPow2) const { + if (N->getOpcode() == ISD::FDIV) + return true; + + EVT FPVT = N->getValueType(0); + EVT IntVT = IntPow2.getValueType(); + + // This indicates a non-free bitcast. + // TODO: This is probably overly conservative as we will need to scale the + // integer vector anyways for the int->fp cast. + if (FPVT.isVector() && + FPVT.getScalarSizeInBits() != IntVT.getScalarSizeInBits()) + return false; + + return true; +} + /// Check if replacement of SQRT with RSQRT should be disabled. bool X86TargetLowering::isFsqrtCheap(SDValue Op, SelectionDAG &DAG) const { EVT VT = Op.getValueType(); diff --git a/llvm/lib/Target/X86/X86ISelLowering.h b/llvm/lib/Target/X86/X86ISelLowering.h index 4d45a0a5..1c51a37 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.h +++ b/llvm/lib/Target/X86/X86ISelLowering.h @@ -1808,6 +1808,9 @@ namespace llvm { const SDLoc &dl, SelectionDAG &DAG, SDValue &X86CC) const; + bool optimizeFMulOrFDivAsShiftAddBitcast(SDNode *N, SDValue FPConst, + SDValue IntPow2) const override; + /// Check if replacement of SQRT with RSQRT should be disabled. bool isFsqrtCheap(SDValue Op, SelectionDAG &DAG) const override; |