diff options
Diffstat (limited to 'llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp')
-rw-r--r-- | llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp | 151 |
1 files changed, 120 insertions, 31 deletions
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp index 5a2416de..dfaa145 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp @@ -2020,6 +2020,22 @@ bool AMDGPUDAGToDAGISel::SelectGlobalSAddr(SDNode *N, SDValue Addr, return true; } +bool AMDGPUDAGToDAGISel::SelectGlobalSAddrCPol(SDNode *N, SDValue Addr, + SDValue &SAddr, SDValue &VOffset, + SDValue &Offset, + SDValue &CPol) const { + bool ScaleOffset; + if (!SelectGlobalSAddr(N, Addr, SAddr, VOffset, Offset, ScaleOffset)) + return false; + + // We are assuming CPol is always the last operand of the intrinsic. + auto PassedCPol = + N->getConstantOperandVal(N->getNumOperands() - 1) & ~AMDGPU::CPol::SCAL; + CPol = CurDAG->getTargetConstant( + (ScaleOffset ? AMDGPU::CPol::SCAL : 0) | PassedCPol, SDLoc(), MVT::i32); + return true; +} + bool AMDGPUDAGToDAGISel::SelectGlobalSAddrGLC(SDNode *N, SDValue Addr, SDValue &SAddr, SDValue &VOffset, SDValue &Offset, @@ -3861,58 +3877,114 @@ bool AMDGPUDAGToDAGISel::SelectVOP3OpSelMods(SDValue In, SDValue &Src, return SelectVOP3Mods(In, Src, SrcMods); } +// Match lowered fpext from bf16 to f32. This is a bit operation extending +// a 16-bit value with 16-bit of zeroes at LSB: +// +// 1. (f32 (bitcast (build_vector (i16 0), (i16 (bitcast bf16:val))))) +// 2. (f32 (bitcast (and i32:val, 0xffff0000))) -> IsExtractHigh = true +// 3. (f32 (bitcast (shl i32:va, 16) -> IsExtractHigh = false +static SDValue matchBF16FPExtendLike(SDValue Op, bool &IsExtractHigh) { + if (Op.getValueType() != MVT::f32 || Op.getOpcode() != ISD::BITCAST) + return SDValue(); + Op = Op.getOperand(0); + + IsExtractHigh = false; + if (Op.getValueType() == MVT::v2i16 && Op.getOpcode() == ISD::BUILD_VECTOR) { + auto Low16 = dyn_cast<ConstantSDNode>(Op.getOperand(0)); + if (!Low16 || !Low16->isZero()) + return SDValue(); + Op = stripBitcast(Op.getOperand(1)); + if (Op.getValueType() != MVT::bf16) + return SDValue(); + return Op; + } + + if (Op.getValueType() != MVT::i32) + return SDValue(); + + if (Op.getOpcode() == ISD::AND) { + if (auto Mask = dyn_cast<ConstantSDNode>(Op.getOperand(1))) { + if (Mask->getZExtValue() == 0xffff0000) { + IsExtractHigh = true; + return Op.getOperand(0); + } + } + return SDValue(); + } + + if (Op.getOpcode() == ISD::SHL) { + if (auto Amt = dyn_cast<ConstantSDNode>(Op.getOperand(1))) { + if (Amt->getZExtValue() == 16) + return Op.getOperand(0); + } + } + + return SDValue(); +} + // The return value is not whether the match is possible (which it always is), // but whether or not it a conversion is really used. bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixModsImpl(SDValue In, SDValue &Src, - unsigned &Mods) const { + unsigned &Mods, + MVT VT) const { Mods = 0; SelectVOP3ModsImpl(In, Src, Mods); + bool IsExtractHigh = false; if (Src.getOpcode() == ISD::FP_EXTEND) { Src = Src.getOperand(0); - assert(Src.getValueType() == MVT::f16); - Src = stripBitcast(Src); + } else if (VT == MVT::bf16) { + SDValue B16 = matchBF16FPExtendLike(Src, IsExtractHigh); + if (!B16) + return false; + Src = B16; + } else + return false; - // Be careful about folding modifiers if we already have an abs. fneg is - // applied last, so we don't want to apply an earlier fneg. - if ((Mods & SISrcMods::ABS) == 0) { - unsigned ModsTmp; - SelectVOP3ModsImpl(Src, Src, ModsTmp); + if (Src.getValueType() != VT && + (VT != MVT::bf16 || Src.getValueType() != MVT::i32)) + return false; - if ((ModsTmp & SISrcMods::NEG) != 0) - Mods ^= SISrcMods::NEG; + Src = stripBitcast(Src); - if ((ModsTmp & SISrcMods::ABS) != 0) - Mods |= SISrcMods::ABS; - } + // Be careful about folding modifiers if we already have an abs. fneg is + // applied last, so we don't want to apply an earlier fneg. + if ((Mods & SISrcMods::ABS) == 0) { + unsigned ModsTmp; + SelectVOP3ModsImpl(Src, Src, ModsTmp); - // op_sel/op_sel_hi decide the source type and source. - // If the source's op_sel_hi is set, it indicates to do a conversion from fp16. - // If the sources's op_sel is set, it picks the high half of the source - // register. + if ((ModsTmp & SISrcMods::NEG) != 0) + Mods ^= SISrcMods::NEG; - Mods |= SISrcMods::OP_SEL_1; - if (isExtractHiElt(Src, Src)) { - Mods |= SISrcMods::OP_SEL_0; + if ((ModsTmp & SISrcMods::ABS) != 0) + Mods |= SISrcMods::ABS; + } - // TODO: Should we try to look for neg/abs here? - } + // op_sel/op_sel_hi decide the source type and source. + // If the source's op_sel_hi is set, it indicates to do a conversion from + // fp16. If the sources's op_sel is set, it picks the high half of the source + // register. - // Prevent unnecessary subreg COPY to VGPR_16 - if (Src.getOpcode() == ISD::TRUNCATE && - Src.getOperand(0).getValueType() == MVT::i32) { - Src = Src.getOperand(0); - } - return true; + Mods |= SISrcMods::OP_SEL_1; + if (IsExtractHigh || + (Src.getValueSizeInBits() == 16 && isExtractHiElt(Src, Src))) { + Mods |= SISrcMods::OP_SEL_0; + + // TODO: Should we try to look for neg/abs here? } - return false; + // Prevent unnecessary subreg COPY to VGPR_16 + if (Src.getOpcode() == ISD::TRUNCATE && + Src.getOperand(0).getValueType() == MVT::i32) { + Src = Src.getOperand(0); + } + return true; } bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixModsExt(SDValue In, SDValue &Src, SDValue &SrcMods) const { unsigned Mods = 0; - if (!SelectVOP3PMadMixModsImpl(In, Src, Mods)) + if (!SelectVOP3PMadMixModsImpl(In, Src, Mods, MVT::f16)) return false; SrcMods = CurDAG->getTargetConstant(Mods, SDLoc(In), MVT::i32); return true; @@ -3921,7 +3993,24 @@ bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixModsExt(SDValue In, SDValue &Src, bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixMods(SDValue In, SDValue &Src, SDValue &SrcMods) const { unsigned Mods = 0; - SelectVOP3PMadMixModsImpl(In, Src, Mods); + SelectVOP3PMadMixModsImpl(In, Src, Mods, MVT::f16); + SrcMods = CurDAG->getTargetConstant(Mods, SDLoc(In), MVT::i32); + return true; +} + +bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixBF16ModsExt(SDValue In, SDValue &Src, + SDValue &SrcMods) const { + unsigned Mods = 0; + if (!SelectVOP3PMadMixModsImpl(In, Src, Mods, MVT::bf16)) + return false; + SrcMods = CurDAG->getTargetConstant(Mods, SDLoc(In), MVT::i32); + return true; +} + +bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixBF16Mods(SDValue In, SDValue &Src, + SDValue &SrcMods) const { + unsigned Mods = 0; + SelectVOP3PMadMixModsImpl(In, Src, Mods, MVT::bf16); SrcMods = CurDAG->getTargetConstant(Mods, SDLoc(In), MVT::i32); return true; } |