Diffstat (limited to 'llvm/lib/Target/RISCV/RISCVISelLowering.cpp')
-rw-r--r--  llvm/lib/Target/RISCV/RISCVISelLowering.cpp  141
1 file changed, 137 insertions(+), 4 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 3918dd2..43e4f8e 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1618,6 +1618,12 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     }
   }
 
+  // Customize load and store operations for bf16 if Zfh isn't enabled.
+  if (Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh()) {
+    setOperationAction(ISD::LOAD, MVT::bf16, Custom);
+    setOperationAction(ISD::STORE, MVT::bf16, Custom);
+  }
+
   // Function alignments.
   const Align FunctionAlignment(Subtarget.hasStdExtZca() ? 2 : 4);
   setMinFunctionAlignment(FunctionAlignment);
@@ -2733,6 +2739,27 @@ bool RISCVTargetLowering::isLegalElementTypeForRVV(EVT ScalarTy) const {
   }
 }
 
+bool RISCVTargetLowering::isLegalLoadStoreElementTypeForRVV(
+    EVT ScalarTy) const {
+  if (!ScalarTy.isSimple())
+    return false;
+  switch (ScalarTy.getSimpleVT().SimpleTy) {
+  case MVT::iPTR:
+    return Subtarget.is64Bit() ? Subtarget.hasVInstructionsI64() : true;
+  case MVT::i8:
+  case MVT::i16:
+  case MVT::i32:
+  case MVT::f16:
+  case MVT::bf16:
+  case MVT::f32:
+    return true;
+  case MVT::i64:
+  case MVT::f64:
+    return Subtarget.hasVInstructionsI64();
+  default:
+    return false;
+  }
+}
 
 unsigned RISCVTargetLowering::combineRepeatedFPDivisors() const {
   return NumRepeatedDivisors;
@@ -7216,6 +7243,47 @@ static SDValue SplitStrictFPVectorOp(SDValue Op, SelectionDAG &DAG) {
   return DAG.getMergeValues({V, HiRes.getValue(1)}, DL);
 }
 
+SDValue
+RISCVTargetLowering::lowerXAndesBfHCvtBFloat16Load(SDValue Op,
+                                                   SelectionDAG &DAG) const {
+  assert(Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh() &&
+         "Unexpected bfloat16 load lowering");
+
+  SDLoc DL(Op);
+  LoadSDNode *LD = cast<LoadSDNode>(Op.getNode());
+  EVT MemVT = LD->getMemoryVT();
+  SDValue Load = DAG.getExtLoad(
+      ISD::ZEXTLOAD, DL, Subtarget.getXLenVT(), LD->getChain(),
+      LD->getBasePtr(),
+      EVT::getIntegerVT(*DAG.getContext(), MemVT.getSizeInBits()),
+      LD->getMemOperand());
+  // OR in a mask to keep the bf16 value NaN-boxed when we don't have the
+  // flh instruction. -65536 is treated as a small constant, so a single
+  // lui is enough to materialize it.
+  SDValue mask = DAG.getSignedConstant(-65536, DL, Subtarget.getXLenVT());
+  SDValue OrSixteenOne =
+      DAG.getNode(ISD::OR, DL, Load.getValueType(), {Load, mask});
+  SDValue ConvertedResult =
+      DAG.getNode(RISCVISD::NDS_FMV_BF16_X, DL, MVT::bf16, OrSixteenOne);
+  return DAG.getMergeValues({ConvertedResult, Load.getValue(1)}, DL);
+}
+
+SDValue
+RISCVTargetLowering::lowerXAndesBfHCvtBFloat16Store(SDValue Op,
+                                                    SelectionDAG &DAG) const {
+  assert(Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh() &&
+         "Unexpected bfloat16 store lowering");
+
+  StoreSDNode *ST = cast<StoreSDNode>(Op.getNode());
+  SDLoc DL(Op);
+  SDValue FMV = DAG.getNode(RISCVISD::NDS_FMV_X_ANYEXTBF16, DL,
+                            Subtarget.getXLenVT(), ST->getValue());
+  return DAG.getTruncStore(
+      ST->getChain(), DL, FMV, ST->getBasePtr(),
+      EVT::getIntegerVT(*DAG.getContext(), ST->getMemoryVT().getSizeInBits()),
+      ST->getMemOperand());
+}
+
 SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
                                             SelectionDAG &DAG) const {
   switch (Op.getOpcode()) {
@@ -7914,6 +7982,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
       return DAG.getMergeValues({Pair, Chain}, DL);
     }
 
+    if (VT == MVT::bf16)
+      return lowerXAndesBfHCvtBFloat16Load(Op, DAG);
+
     // Handle normal vector tuple load.
     if (VT.isRISCVVectorTuple()) {
      SDLoc DL(Op);
@@ -7998,6 +8069,10 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
           {Store->getChain(), Lo, Hi, Store->getBasePtr()}, MVT::i64,
           Store->getMemOperand());
     }
+
+    if (VT == MVT::bf16)
+      return lowerXAndesBfHCvtBFloat16Store(Op, DAG);
+
     // Handle normal vector tuple store.
     if (VT.isRISCVVectorTuple()) {
       SDLoc DL(Op);
@@ -16079,7 +16154,7 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
 
   uint64_t MulAmt = CNode->getZExtValue();
   // Don't do this if the Xqciac extension is enabled and the MulAmt in simm12.
-  if (Subtarget.hasVendorXqciac() && isInt<12>(MulAmt))
+  if (Subtarget.hasVendorXqciac() && isInt<12>(CNode->getSExtValue()))
     return SDValue();
 
   const bool HasShlAdd = Subtarget.hasStdExtZba() ||
@@ -16184,10 +16259,12 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
   // 2^N - 3/5/9 --> (sub (shl X, C1), (shXadd X, x))
   for (uint64_t Offset : {3, 5, 9}) {
     if (isPowerOf2_64(MulAmt + Offset)) {
+      unsigned ShAmt = Log2_64(MulAmt + Offset);
+      if (ShAmt >= VT.getSizeInBits())
+        continue;
       SDLoc DL(N);
       SDValue Shift1 =
-          DAG.getNode(ISD::SHL, DL, VT, X,
-                      DAG.getConstant(Log2_64(MulAmt + Offset), DL, VT));
+          DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShAmt, DL, VT));
       SDValue Mul359 =
           DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
                       DAG.getConstant(Log2_64(Offset - 1), DL, VT), X);
@@ -20766,6 +20843,62 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     }
     break;
   }
+  case RISCVISD::TUPLE_EXTRACT: {
+    EVT VT = N->getValueType(0);
+    SDValue Tuple = N->getOperand(0);
+    unsigned Idx = N->getConstantOperandVal(1);
+    if (!Tuple.hasOneUse() || Tuple.getOpcode() != ISD::INTRINSIC_W_CHAIN)
+      break;
+
+    unsigned NF = 0;
+    switch (Tuple.getConstantOperandVal(1)) {
+    default:
+      break;
+    case Intrinsic::riscv_vlseg2_mask:
+    case Intrinsic::riscv_vlseg3_mask:
+    case Intrinsic::riscv_vlseg4_mask:
+    case Intrinsic::riscv_vlseg5_mask:
+    case Intrinsic::riscv_vlseg6_mask:
+    case Intrinsic::riscv_vlseg7_mask:
+    case Intrinsic::riscv_vlseg8_mask:
+      NF = Tuple.getValueType().getRISCVVectorTupleNumFields();
+      break;
+    }
+
+    if (!NF || Subtarget.hasOptimizedSegmentLoadStore(NF))
+      break;
+
+    unsigned SEW = VT.getScalarSizeInBits();
+    assert(Log2_64(SEW) == Tuple.getConstantOperandVal(7) &&
+           "Type mismatch without bitcast?");
+    unsigned Stride = SEW / 8 * NF;
+    unsigned Offset = SEW / 8 * Idx;
+
+    SDValue Ops[] = {
+        /*Chain=*/Tuple.getOperand(0),
+        /*IntID=*/DAG.getTargetConstant(Intrinsic::riscv_vlse_mask, DL, XLenVT),
+        /*Passthru=*/Tuple.getOperand(2),
+        /*Ptr=*/
+        DAG.getNode(ISD::ADD, DL, XLenVT, Tuple.getOperand(3),
+                    DAG.getConstant(Offset, DL, XLenVT)),
+        /*Stride=*/DAG.getConstant(Stride, DL, XLenVT),
+        /*Mask=*/Tuple.getOperand(4),
+        /*VL=*/Tuple.getOperand(5),
+        /*Policy=*/Tuple.getOperand(6)};
+
+    auto *TupleMemSD = cast<MemIntrinsicSDNode>(Tuple);
+    // Match getTgtMemIntrinsic for the non-unit-stride case.
+    EVT MemVT = TupleMemSD->getMemoryVT().getScalarType();
+    MachineFunction &MF = DAG.getMachineFunction();
+    MachineMemOperand *MMO = MF.getMachineMemOperand(
+        TupleMemSD->getMemOperand(), Offset, MemoryLocation::UnknownSize);
+
+    SDVTList VTs = DAG.getVTList({VT, MVT::Other});
+    SDValue Result = DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs,
+                                             Ops, MemVT, MMO);
+    DAG.ReplaceAllUsesOfValueWith(Tuple.getValue(1), Result.getValue(1));
+    return Result.getValue(0);
+  }
   }
 
   return SDValue();
@@ -24183,7 +24316,7 @@ bool RISCVTargetLowering::isLegalStridedLoadStore(EVT DataType,
     return false;
 
   EVT ScalarType = DataType.getScalarType();
-  if (!isLegalElementTypeForRVV(ScalarType))
+  if (!isLegalLoadStoreElementTypeForRVV(ScalarType))
     return false;
 
   if (!Subtarget.enableUnalignedVectorMem() &&
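A note on the NaN-boxing in lowerXAndesBfHCvtBFloat16Load above: the bf16 bits are fetched with a zero-extending 16-bit integer load, and the OR with -65536 (0xFFFF0000 on RV32; sign-extended to an all-ones upper half on RV64) forces every bit above bit 15 to one before NDS_FMV_BF16_X moves the pattern into an FP register, which is the boxed form the FP register file expects for a value narrower than FLEN. A minimal sketch of the bit manipulation, with illustrative names and assuming XLEN = 32:

    #include <cstdint>
    #include <cstdio>

    // Zero-extend the raw bf16 pattern (what the 16-bit extending load
    // produces), then OR with -65536 so all bits above the low 16 are ones.
    uint32_t nanBoxBF16(uint16_t RawBits) {
      uint32_t ZExt = RawBits;   // ISD::ZEXTLOAD result
      return ZExt | 0xFFFF0000u; // -65536 as a 32-bit pattern
    }

    int main() {
      uint16_t One = 0x3F80;                  // bf16 encoding of 1.0
      std::printf("%08x\n", nanBoxBF16(One)); // ffff3f80: NaN-boxed 1.0
      return 0;
    }

The store path is the mirror image: NDS_FMV_X_ANYEXTBF16 moves the bits back to an integer register and a 16-bit truncating store drops the boxing.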
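The expandMul change from isInt<12>(MulAmt) to isInt<12>(CNode->getSExtValue()) matters for negative constants of narrow types: MulAmt is the zero-extended value, so an i32 multiply by a small negative simm12 shows up as a large positive number and the Xqciac early exit never fired. A self-contained illustration, where isInt12 is a freestanding stand-in for llvm::isInt<12>:

    #include <cassert>
    #include <cstdint>

    // Stand-in for llvm::isInt<12>: does X fit in a signed 12-bit immediate?
    constexpr bool isInt12(int64_t X) { return X >= -2048 && X < 2048; }

    int main() {
      // An i32 multiply by -5: the 32-bit APInt holds 0xFFFFFFFB.
      int64_t ZExt = 0xFFFFFFFB; // getZExtValue(): 4294967291
      int64_t SExt = -5;         // getSExtValue()
      assert(!isInt12(ZExt));    // old check: misses the early exit
      assert(isInt12(SExt));     // new check: bails out as intended
      return 0;
    }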
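In the 2^N - 3/5/9 decomposition, the new ShAmt guard closes an overflow hole: for an i32 multiply by 0xFFFFFFFD (-3), MulAmt + 3 is 2^32, Log2_64 returns 32, and the old code built an i32 shift by the full bit width, which SelectionDAG treats as undefined. For in-range amounts the rewrite is plain arithmetic, e.g. 29 * X = (X << 5) - 3 * X because 29 + 3 = 32; a small check of that identity, with shXadd mirroring RISCVISD::SHL_ADD:

    #include <cassert>
    #include <cstdint>

    // (X << ShAmt) + Y, the shape of Zba's sh1add/sh2add/sh3add.
    uint64_t shXadd(uint64_t X, unsigned ShAmt, uint64_t Y) {
      return (X << ShAmt) + Y;
    }

    int main() {
      uint64_t X = 12345;
      uint64_t Shift1 = X << 5;          // C1 = Log2_64(29 + 3) = 5
      uint64_t Mul359 = shXadd(X, 1, X); // Log2_64(3 - 1) = 1, gives 3 * X
      assert(Shift1 - Mul359 == 29 * X); // (shl X, 5) - (sh1add X, X)
      return 0;
    }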
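The TUPLE_EXTRACT combine turns "masked segment load, then extract a single field" into one masked strided load when the subtarget reports no optimized segment load/store: with NF = 4 fields of SEW = 32 bits and Idx = 2, Stride is 16 bytes and Offset is 8 bytes, so the vlse touches only the extracted field; ReplaceAllUsesOfValueWith then rethreads the chain so the new load inherits the segment load's ordering edges. A scalar emulation of the resulting access pattern, with illustrative names:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Load field Idx from an array of NF-field structs of 32-bit members:
    // the strided-load equivalent of vlseg + TUPLE_EXTRACT.
    std::vector<uint32_t> loadField(const uint32_t *Base, std::size_t VL,
                                    unsigned NF, unsigned Idx) {
      const unsigned SEWBytes = 4;            // SEW / 8 with SEW = 32
      const unsigned Stride = SEWBytes * NF;  // bytes between adjacent lanes
      const unsigned Offset = SEWBytes * Idx; // byte offset of the field
      const char *Ptr = reinterpret_cast<const char *>(Base) + Offset;
      std::vector<uint32_t> Out(VL);
      for (std::size_t I = 0; I < VL; ++I)    // one strided element per lane
        Out[I] = *reinterpret_cast<const uint32_t *>(Ptr + I * Stride);
      return Out;
    }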