Diffstat (limited to 'llvm/lib/Target/RISCV/RISCVISelLowering.cpp')
-rw-r--r-- | llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 342 |
1 file changed, 289 insertions, 53 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 3918dd2..adbfbeb 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1618,6 +1618,12 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     }
   }
 
+  // Customize load and store operation for bf16 if zfh isn't enabled.
+  if (Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh()) {
+    setOperationAction(ISD::LOAD, MVT::bf16, Custom);
+    setOperationAction(ISD::STORE, MVT::bf16, Custom);
+  }
+
   // Function alignments.
   const Align FunctionAlignment(Subtarget.hasStdExtZca() ? 2 : 4);
   setMinFunctionAlignment(FunctionAlignment);
@@ -1813,6 +1819,13 @@ bool RISCVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
   case Intrinsic::riscv_seg6_load_mask:
   case Intrinsic::riscv_seg7_load_mask:
   case Intrinsic::riscv_seg8_load_mask:
+  case Intrinsic::riscv_sseg2_load_mask:
+  case Intrinsic::riscv_sseg3_load_mask:
+  case Intrinsic::riscv_sseg4_load_mask:
+  case Intrinsic::riscv_sseg5_load_mask:
+  case Intrinsic::riscv_sseg6_load_mask:
+  case Intrinsic::riscv_sseg7_load_mask:
+  case Intrinsic::riscv_sseg8_load_mask:
     return SetRVVLoadStoreInfo(/*PtrOp*/ 0, /*IsStore*/ false,
                                /*IsUnitStrided*/ false, /*UsePtrVal*/ true);
   case Intrinsic::riscv_seg2_store_mask:
@@ -7216,6 +7229,47 @@ static SDValue SplitStrictFPVectorOp(SDValue Op, SelectionDAG &DAG) {
   return DAG.getMergeValues({V, HiRes.getValue(1)}, DL);
 }
 
+SDValue
+RISCVTargetLowering::lowerXAndesBfHCvtBFloat16Load(SDValue Op,
+                                                   SelectionDAG &DAG) const {
+  assert(Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh() &&
+         "Unexpected bfloat16 load lowering");
+
+  SDLoc DL(Op);
+  LoadSDNode *LD = cast<LoadSDNode>(Op.getNode());
+  EVT MemVT = LD->getMemoryVT();
+  SDValue Load = DAG.getExtLoad(
+      ISD::ZEXTLOAD, DL, Subtarget.getXLenVT(), LD->getChain(),
+      LD->getBasePtr(),
+      EVT::getIntegerVT(*DAG.getContext(), MemVT.getSizeInBits()),
+      LD->getMemOperand());
+  // Using mask to make bf16 nan-boxing valid when we don't have flh
+  // instruction. -65536 would be treat as a small number and thus it can be
+  // directly used lui to get the constant.
+  SDValue mask = DAG.getSignedConstant(-65536, DL, Subtarget.getXLenVT());
+  SDValue OrSixteenOne =
+      DAG.getNode(ISD::OR, DL, Load.getValueType(), {Load, mask});
+  SDValue ConvertedResult =
+      DAG.getNode(RISCVISD::NDS_FMV_BF16_X, DL, MVT::bf16, OrSixteenOne);
+  return DAG.getMergeValues({ConvertedResult, Load.getValue(1)}, DL);
+}
+
+SDValue
+RISCVTargetLowering::lowerXAndesBfHCvtBFloat16Store(SDValue Op,
+                                                    SelectionDAG &DAG) const {
+  assert(Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh() &&
+         "Unexpected bfloat16 store lowering");
+
+  StoreSDNode *ST = cast<StoreSDNode>(Op.getNode());
+  SDLoc DL(Op);
+  SDValue FMV = DAG.getNode(RISCVISD::NDS_FMV_X_ANYEXTBF16, DL,
+                            Subtarget.getXLenVT(), ST->getValue());
+  return DAG.getTruncStore(
+      ST->getChain(), DL, FMV, ST->getBasePtr(),
+      EVT::getIntegerVT(*DAG.getContext(), ST->getMemoryVT().getSizeInBits()),
+      ST->getMemOperand());
+}
+
 SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
                                             SelectionDAG &DAG) const {
   switch (Op.getOpcode()) {
@@ -7914,6 +7968,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
       return DAG.getMergeValues({Pair, Chain}, DL);
     }
 
+    if (VT == MVT::bf16)
+      return lowerXAndesBfHCvtBFloat16Load(Op, DAG);
+
     // Handle normal vector tuple load.
     if (VT.isRISCVVectorTuple()) {
       SDLoc DL(Op);
@@ -7998,6 +8055,10 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
          {Store->getChain(), Lo, Hi, Store->getBasePtr()},
          MVT::i64, Store->getMemOperand());
     }
+
+    if (VT == MVT::bf16)
+      return lowerXAndesBfHCvtBFloat16Store(Op, DAG);
+
     // Handle normal vector tuple store.
     if (VT.isRISCVVectorTuple()) {
       SDLoc DL(Op);
@@ -10884,6 +10945,97 @@ static inline SDValue getVCIXISDNodeVOID(SDValue &Op, SelectionDAG &DAG,
   return DAG.getNode(Type, SDLoc(Op), Op.getValueType(), Operands);
 }
 
+static SDValue
+lowerFixedVectorSegLoadIntrinsics(unsigned IntNo, SDValue Op,
+                                  const RISCVSubtarget &Subtarget,
+                                  SelectionDAG &DAG) {
+  bool IsStrided;
+  switch (IntNo) {
+  case Intrinsic::riscv_seg2_load_mask:
+  case Intrinsic::riscv_seg3_load_mask:
+  case Intrinsic::riscv_seg4_load_mask:
+  case Intrinsic::riscv_seg5_load_mask:
+  case Intrinsic::riscv_seg6_load_mask:
+  case Intrinsic::riscv_seg7_load_mask:
+  case Intrinsic::riscv_seg8_load_mask:
+    IsStrided = false;
+    break;
+  case Intrinsic::riscv_sseg2_load_mask:
+  case Intrinsic::riscv_sseg3_load_mask:
+  case Intrinsic::riscv_sseg4_load_mask:
+  case Intrinsic::riscv_sseg5_load_mask:
+  case Intrinsic::riscv_sseg6_load_mask:
+  case Intrinsic::riscv_sseg7_load_mask:
+  case Intrinsic::riscv_sseg8_load_mask:
+    IsStrided = true;
+    break;
+  default:
+    llvm_unreachable("unexpected intrinsic ID");
+  };
+
+  static const Intrinsic::ID VlsegInts[7] = {
+      Intrinsic::riscv_vlseg2_mask, Intrinsic::riscv_vlseg3_mask,
+      Intrinsic::riscv_vlseg4_mask, Intrinsic::riscv_vlseg5_mask,
+      Intrinsic::riscv_vlseg6_mask, Intrinsic::riscv_vlseg7_mask,
+      Intrinsic::riscv_vlseg8_mask};
+  static const Intrinsic::ID VlssegInts[7] = {
+      Intrinsic::riscv_vlsseg2_mask, Intrinsic::riscv_vlsseg3_mask,
+      Intrinsic::riscv_vlsseg4_mask, Intrinsic::riscv_vlsseg5_mask,
+      Intrinsic::riscv_vlsseg6_mask, Intrinsic::riscv_vlsseg7_mask,
+      Intrinsic::riscv_vlsseg8_mask};
+
+  SDLoc DL(Op);
+  unsigned NF = Op->getNumValues() - 1;
+  assert(NF >= 2 && NF <= 8 && "Unexpected seg number");
+  MVT XLenVT = Subtarget.getXLenVT();
+  MVT VT = Op->getSimpleValueType(0);
+  MVT ContainerVT = ::getContainerForFixedLengthVector(DAG, VT, Subtarget);
+  unsigned Sz = NF * ContainerVT.getVectorMinNumElements() *
+                ContainerVT.getScalarSizeInBits();
+  EVT VecTupTy = MVT::getRISCVVectorTupleVT(Sz, NF);
+
+  // Operands: (chain, int_id, pointer, mask, vl) or
+  // (chain, int_id, pointer, offset, mask, vl)
+  SDValue VL = Op.getOperand(Op.getNumOperands() - 1);
+  SDValue Mask = Op.getOperand(Op.getNumOperands() - 2);
+  MVT MaskVT = Mask.getSimpleValueType();
+  MVT MaskContainerVT =
+      ::getContainerForFixedLengthVector(DAG, MaskVT, Subtarget);
+  Mask = convertToScalableVector(MaskContainerVT, Mask, DAG, Subtarget);
+
+  SDValue IntID = DAG.getTargetConstant(
+      IsStrided ? VlssegInts[NF - 2] : VlsegInts[NF - 2], DL, XLenVT);
+  auto *Load = cast<MemIntrinsicSDNode>(Op);
+
+  SDVTList VTs = DAG.getVTList({VecTupTy, MVT::Other});
+  SmallVector<SDValue, 9> Ops = {
+      Load->getChain(),
+      IntID,
+      DAG.getUNDEF(VecTupTy),
+      Op.getOperand(2),
+      Mask,
+      VL,
+      DAG.getTargetConstant(
+          RISCVVType::TAIL_AGNOSTIC | RISCVVType::MASK_AGNOSTIC, DL, XLenVT),
+      DAG.getTargetConstant(Log2_64(VT.getScalarSizeInBits()), DL, XLenVT)};
+  // Insert the stride operand.
+  if (IsStrided)
+    Ops.insert(std::next(Ops.begin(), 4), Op.getOperand(3));
+
+  SDValue Result =
+      DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops,
+                              Load->getMemoryVT(), Load->getMemOperand());
+  SmallVector<SDValue, 9> Results;
+  for (unsigned int RetIdx = 0; RetIdx < NF; RetIdx++) {
+    SDValue SubVec = DAG.getNode(RISCVISD::TUPLE_EXTRACT, DL, ContainerVT,
+                                 Result.getValue(0),
+                                 DAG.getTargetConstant(RetIdx, DL, MVT::i32));
+    Results.push_back(convertFromScalableVector(VT, SubVec, DAG, Subtarget));
+  }
+  Results.push_back(Result.getValue(1));
+  return DAG.getMergeValues(Results, DL);
+}
+
 SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
                                                     SelectionDAG &DAG) const {
   unsigned IntNo = Op.getConstantOperandVal(1);
@@ -10896,57 +11048,16 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
   case Intrinsic::riscv_seg5_load_mask:
   case Intrinsic::riscv_seg6_load_mask:
   case Intrinsic::riscv_seg7_load_mask:
-  case Intrinsic::riscv_seg8_load_mask: {
-    SDLoc DL(Op);
-    static const Intrinsic::ID VlsegInts[7] = {
-        Intrinsic::riscv_vlseg2_mask, Intrinsic::riscv_vlseg3_mask,
-        Intrinsic::riscv_vlseg4_mask, Intrinsic::riscv_vlseg5_mask,
-        Intrinsic::riscv_vlseg6_mask, Intrinsic::riscv_vlseg7_mask,
-        Intrinsic::riscv_vlseg8_mask};
-    unsigned NF = Op->getNumValues() - 1;
-    assert(NF >= 2 && NF <= 8 && "Unexpected seg number");
-    MVT XLenVT = Subtarget.getXLenVT();
-    MVT VT = Op->getSimpleValueType(0);
-    MVT ContainerVT = getContainerForFixedLengthVector(VT);
-    unsigned Sz = NF * ContainerVT.getVectorMinNumElements() *
-                  ContainerVT.getScalarSizeInBits();
-    EVT VecTupTy = MVT::getRISCVVectorTupleVT(Sz, NF);
-
-    // Operands: (chain, int_id, pointer, mask, vl)
-    SDValue VL = Op.getOperand(Op.getNumOperands() - 1);
-    SDValue Mask = Op.getOperand(3);
-    MVT MaskVT = Mask.getSimpleValueType();
-    MVT MaskContainerVT =
-        ::getContainerForFixedLengthVector(DAG, MaskVT, Subtarget);
-    Mask = convertToScalableVector(MaskContainerVT, Mask, DAG, Subtarget);
-
-    SDValue IntID = DAG.getTargetConstant(VlsegInts[NF - 2], DL, XLenVT);
-    auto *Load = cast<MemIntrinsicSDNode>(Op);
+  case Intrinsic::riscv_seg8_load_mask:
+  case Intrinsic::riscv_sseg2_load_mask:
+  case Intrinsic::riscv_sseg3_load_mask:
+  case Intrinsic::riscv_sseg4_load_mask:
+  case Intrinsic::riscv_sseg5_load_mask:
+  case Intrinsic::riscv_sseg6_load_mask:
+  case Intrinsic::riscv_sseg7_load_mask:
+  case Intrinsic::riscv_sseg8_load_mask:
+    return lowerFixedVectorSegLoadIntrinsics(IntNo, Op, Subtarget, DAG);
 
-    SDVTList VTs = DAG.getVTList({VecTupTy, MVT::Other});
-    SDValue Ops[] = {
-        Load->getChain(),
-        IntID,
-        DAG.getUNDEF(VecTupTy),
-        Op.getOperand(2),
-        Mask,
-        VL,
-        DAG.getTargetConstant(
-            RISCVVType::TAIL_AGNOSTIC | RISCVVType::MASK_AGNOSTIC, DL, XLenVT),
-        DAG.getTargetConstant(Log2_64(VT.getScalarSizeInBits()), DL, XLenVT)};
-    SDValue Result =
-        DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops,
-                                Load->getMemoryVT(), Load->getMemOperand());
-    SmallVector<SDValue, 9> Results;
-    for (unsigned int RetIdx = 0; RetIdx < NF; RetIdx++) {
-      SDValue SubVec = DAG.getNode(RISCVISD::TUPLE_EXTRACT, DL, ContainerVT,
-                                   Result.getValue(0),
-                                   DAG.getTargetConstant(RetIdx, DL, MVT::i32));
-      Results.push_back(convertFromScalableVector(VT, SubVec, DAG, Subtarget));
-    }
-    Results.push_back(Result.getValue(1));
-    return DAG.getMergeValues(Results, DL);
-  }
   case Intrinsic::riscv_sf_vc_v_x_se:
     return getVCIXISDNodeWCHAIN(Op, DAG, RISCVISD::SF_VC_V_X_SE);
   case Intrinsic::riscv_sf_vc_v_i_se:
@@ -16079,7 +16190,7 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
   uint64_t MulAmt = CNode->getZExtValue();
 
   // Don't do this if the Xqciac extension is enabled and the MulAmt in simm12.
-  if (Subtarget.hasVendorXqciac() && isInt<12>(MulAmt))
+  if (Subtarget.hasVendorXqciac() && isInt<12>(CNode->getSExtValue()))
     return SDValue();
 
   const bool HasShlAdd = Subtarget.hasStdExtZba() ||
@@ -16184,10 +16295,12 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
     // 2^N - 3/5/9 --> (sub (shl X, C1), (shXadd X, x))
     for (uint64_t Offset : {3, 5, 9}) {
       if (isPowerOf2_64(MulAmt + Offset)) {
+        unsigned ShAmt = Log2_64(MulAmt + Offset);
+        if (ShAmt >= VT.getSizeInBits())
+          continue;
         SDLoc DL(N);
         SDValue Shift1 =
-            DAG.getNode(ISD::SHL, DL, VT, X,
-                        DAG.getConstant(Log2_64(MulAmt + Offset), DL, VT));
+            DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShAmt, DL, VT));
         SDValue Mul359 =
             DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
                         DAG.getConstant(Log2_64(Offset - 1), DL, VT), X);
@@ -20674,6 +20787,53 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
         return DAG.getAllOnesConstant(DL, VT);
       return DAG.getConstant(0, DL, VT);
     }
+    case Intrinsic::riscv_vsseg2_mask:
+    case Intrinsic::riscv_vsseg3_mask:
+    case Intrinsic::riscv_vsseg4_mask:
+    case Intrinsic::riscv_vsseg5_mask:
+    case Intrinsic::riscv_vsseg6_mask:
+    case Intrinsic::riscv_vsseg7_mask:
+    case Intrinsic::riscv_vsseg8_mask: {
+      SDValue Tuple = N->getOperand(2);
+      unsigned NF = Tuple.getValueType().getRISCVVectorTupleNumFields();
+
+      if (Subtarget.hasOptimizedSegmentLoadStore(NF) || !Tuple.hasOneUse() ||
+          Tuple.getOpcode() != RISCVISD::TUPLE_INSERT ||
+          !Tuple.getOperand(0).isUndef())
+        return SDValue();
+
+      SDValue Val = Tuple.getOperand(1);
+      unsigned Idx = Tuple.getConstantOperandVal(2);
+
+      unsigned SEW = Val.getValueType().getScalarSizeInBits();
+      assert(Log2_64(SEW) == N->getConstantOperandVal(6) &&
+             "Type mismatch without bitcast?");
+      unsigned Stride = SEW / 8 * NF;
+      unsigned Offset = SEW / 8 * Idx;
+
+      SDValue Ops[] = {
+          /*Chain=*/N->getOperand(0),
+          /*IntID=*/
+          DAG.getTargetConstant(Intrinsic::riscv_vsse_mask, DL, XLenVT),
+          /*StoredVal=*/Val,
+          /*Ptr=*/
+          DAG.getNode(ISD::ADD, DL, XLenVT, N->getOperand(3),
+                      DAG.getConstant(Offset, DL, XLenVT)),
+          /*Stride=*/DAG.getConstant(Stride, DL, XLenVT),
+          /*Mask=*/N->getOperand(4),
+          /*VL=*/N->getOperand(5)};
+
+      auto *OldMemSD = cast<MemIntrinsicSDNode>(N);
+      // Match getTgtMemIntrinsic for non-unit stride case
+      EVT MemVT = OldMemSD->getMemoryVT().getScalarType();
+      MachineFunction &MF = DAG.getMachineFunction();
+      MachineMemOperand *MMO = MF.getMachineMemOperand(
+          OldMemSD->getMemOperand(), Offset, MemoryLocation::UnknownSize);
+
+      SDVTList VTs = DAG.getVTList(MVT::Other);
+      return DAG.getMemIntrinsicNode(ISD::INTRINSIC_VOID, DL, VTs, Ops, MemVT,
+                                     MMO);
+    }
     }
   }
   case ISD::EXPERIMENTAL_VP_REVERSE:
@@ -20766,6 +20926,68 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     }
     break;
   }
+  case RISCVISD::TUPLE_EXTRACT: {
+    EVT VT = N->getValueType(0);
+    SDValue Tuple = N->getOperand(0);
+    unsigned Idx = N->getConstantOperandVal(1);
+    if (!Tuple.hasOneUse() || Tuple.getOpcode() != ISD::INTRINSIC_W_CHAIN)
+      break;
+
+    unsigned NF = 0;
+    switch (Tuple.getConstantOperandVal(1)) {
+    default:
+      break;
+    case Intrinsic::riscv_vlseg2_mask:
+    case Intrinsic::riscv_vlseg3_mask:
+    case Intrinsic::riscv_vlseg4_mask:
+    case Intrinsic::riscv_vlseg5_mask:
+    case Intrinsic::riscv_vlseg6_mask:
+    case Intrinsic::riscv_vlseg7_mask:
+    case Intrinsic::riscv_vlseg8_mask:
+      NF = Tuple.getValueType().getRISCVVectorTupleNumFields();
+      break;
+    }
+
+    if (!NF || Subtarget.hasOptimizedSegmentLoadStore(NF))
+      break;
+
+    unsigned SEW = VT.getScalarSizeInBits();
+    assert(Log2_64(SEW) == Tuple.getConstantOperandVal(7) &&
+           "Type mismatch without bitcast?");
+    unsigned Stride = SEW / 8 * NF;
+    unsigned Offset = SEW / 8 * Idx;
+
+    SDValue Ops[] = {
+        /*Chain=*/Tuple.getOperand(0),
+        /*IntID=*/DAG.getTargetConstant(Intrinsic::riscv_vlse_mask, DL, XLenVT),
+        /*Passthru=*/Tuple.getOperand(2),
+        /*Ptr=*/
+        DAG.getNode(ISD::ADD, DL, XLenVT, Tuple.getOperand(3),
+                    DAG.getConstant(Offset, DL, XLenVT)),
+        /*Stride=*/DAG.getConstant(Stride, DL, XLenVT),
+        /*Mask=*/Tuple.getOperand(4),
+        /*VL=*/Tuple.getOperand(5),
+        /*Policy=*/Tuple.getOperand(6)};
+
+    auto *TupleMemSD = cast<MemIntrinsicSDNode>(Tuple);
+    // Match getTgtMemIntrinsic for non-unit stride case
+    EVT MemVT = TupleMemSD->getMemoryVT().getScalarType();
+    MachineFunction &MF = DAG.getMachineFunction();
+    MachineMemOperand *MMO = MF.getMachineMemOperand(
+        TupleMemSD->getMemOperand(), Offset, MemoryLocation::UnknownSize);
+
+    SDVTList VTs = DAG.getVTList({VT, MVT::Other});
+    SDValue Result = DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs,
+                                             Ops, MemVT, MMO);
+    DAG.ReplaceAllUsesOfValueWith(Tuple.getValue(1), Result.getValue(1));
+    return Result.getValue(0);
+  }
+  case RISCVISD::TUPLE_INSERT: {
+    // tuple_insert tuple, undef, idx -> tuple
+    if (N->getOperand(1).isUndef())
+      return N->getOperand(0);
+    break;
+  }
   }
 
   return SDValue();
@@ -22290,6 +22512,7 @@ SDValue RISCVTargetLowering::LowerFormalArguments(
   case CallingConv::C:
   case CallingConv::Fast:
   case CallingConv::SPIR_KERNEL:
+  case CallingConv::PreserveMost:
   case CallingConv::GRAAL:
   case CallingConv::RISCV_VectorCall:
 #define CC_VLS_CASE(ABI_VLEN) case CallingConv::RISCV_VLSCall_##ABI_VLEN:
@@ -22559,8 +22782,14 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
   bool IsVarArg = CLI.IsVarArg;
   EVT PtrVT = getPointerTy(DAG.getDataLayout());
   MVT XLenVT = Subtarget.getXLenVT();
+  const CallBase *CB = CLI.CB;
 
   MachineFunction &MF = DAG.getMachineFunction();
+  MachineFunction::CallSiteInfo CSInfo;
+
+  // Set type id for call site info.
+  if (MF.getTarget().Options.EmitCallGraphSection && CB && CB->isIndirectCall())
+    CSInfo = MachineFunction::CallSiteInfo(*CB);
 
   // Analyze the operands of the call, assigning locations to each operand.
   SmallVector<CCValAssign, 16> ArgLocs;
@@ -22818,6 +23047,9 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
     if (CLI.CFIType)
       Ret.getNode()->setCFIType(CLI.CFIType->getZExtValue());
     DAG.addNoMergeSiteInfo(Ret.getNode(), CLI.NoMerge);
+    if (MF.getTarget().Options.EmitCallGraphSection && CB &&
+        CB->isIndirectCall())
+      DAG.addCallSiteInfo(Ret.getNode(), std::move(CSInfo));
     return Ret;
   }
 
@@ -22825,6 +23057,10 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
   Chain = DAG.getNode(CallOpc, DL, NodeTys, Ops);
   if (CLI.CFIType)
     Chain.getNode()->setCFIType(CLI.CFIType->getZExtValue());
+
+  if (MF.getTarget().Options.EmitCallGraphSection && CB && CB->isIndirectCall())
+    DAG.addCallSiteInfo(Chain.getNode(), std::move(CSInfo));
+
   DAG.addNoMergeSiteInfo(Chain.getNode(), CLI.NoMerge);
   Glue = Chain.getValue(1);
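
Editorial note, not part of the patch above: the bf16 load lowering NaN-boxes the loaded 16-bit pattern by OR-ing it with -65536 (0x...FFFF0000), so every bit above bit 15 becomes 1 before the value is moved into an FP register via RISCVISD::NDS_FMV_BF16_X. The standalone C++ sketch below only illustrates that arithmetic; the sample value 0x3FC0 (bf16 for 1.5) and the program itself are assumptions for demonstration, not LLVM code.

// Illustrative sketch of the NaN-boxing mask used in lowerXAndesBfHCvtBFloat16Load.
#include <cstdint>
#include <cstdio>

int main() {
  // Pretend this came from the zero-extending 16-bit integer load
  // emitted by the lowering; 0x3FC0 is the bf16 encoding of 1.5.
  uint64_t Loaded = 0x3FC0;
  // Same constant the lowering materializes: -65536 == 0x...FFFF0000,
  // chosen so a single lui can produce it.
  int64_t Mask = -65536;
  uint64_t NanBoxed = Loaded | static_cast<uint64_t>(Mask);
  std::printf("0x%016llx\n", static_cast<unsigned long long>(NanBoxed));
  // Prints 0xffffffffffff3fc0: the low 16 bits keep the bf16 payload,
  // every higher bit is 1, i.e. the NaN-boxed form expected in an FPR.
  return 0;
}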