diff options
Diffstat (limited to 'llvm/lib/Target/RISCV/RISCVISelLowering.cpp')
-rw-r--r-- | llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 94 |
1 files changed, 78 insertions, 16 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 4845a9c..54845e5 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -1618,6 +1618,12 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, } } + // Customize load and store operation for bf16 if zfh isn't enabled. + if (Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh()) { + setOperationAction(ISD::LOAD, MVT::bf16, Custom); + setOperationAction(ISD::STORE, MVT::bf16, Custom); + } + // Function alignments. const Align FunctionAlignment(Subtarget.hasStdExtZca() ? 2 : 4); setMinFunctionAlignment(FunctionAlignment); @@ -2319,6 +2325,10 @@ bool RISCVTargetLowering::isFPImmLegal(const APFloat &Imm, EVT VT, if (getLegalZfaFPImm(Imm, VT) >= 0) return true; + // Some constants can be produced by fli+fneg. + if (Imm.isNegative() && getLegalZfaFPImm(-Imm, VT) >= 0) + return true; + // Cannot create a 64 bit floating-point immediate value for rv32. if (Subtarget.getXLen() < VT.getScalarSizeInBits()) { // td can handle +0.0 or -0.0 already. @@ -7212,6 +7222,47 @@ static SDValue SplitStrictFPVectorOp(SDValue Op, SelectionDAG &DAG) { return DAG.getMergeValues({V, HiRes.getValue(1)}, DL); } +SDValue +RISCVTargetLowering::lowerXAndesBfHCvtBFloat16Load(SDValue Op, + SelectionDAG &DAG) const { + assert(Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh() && + "Unexpected bfloat16 load lowering"); + + SDLoc DL(Op); + LoadSDNode *LD = cast<LoadSDNode>(Op.getNode()); + EVT MemVT = LD->getMemoryVT(); + SDValue Load = DAG.getExtLoad( + ISD::ZEXTLOAD, DL, Subtarget.getXLenVT(), LD->getChain(), + LD->getBasePtr(), + EVT::getIntegerVT(*DAG.getContext(), MemVT.getSizeInBits()), + LD->getMemOperand()); + // Using mask to make bf16 nan-boxing valid when we don't have flh + // instruction. -65536 would be treat as a small number and thus it can be + // directly used lui to get the constant. + SDValue mask = DAG.getSignedConstant(-65536, DL, Subtarget.getXLenVT()); + SDValue OrSixteenOne = + DAG.getNode(ISD::OR, DL, Load.getValueType(), {Load, mask}); + SDValue ConvertedResult = + DAG.getNode(RISCVISD::NDS_FMV_BF16_X, DL, MVT::bf16, OrSixteenOne); + return DAG.getMergeValues({ConvertedResult, Load.getValue(1)}, DL); +} + +SDValue +RISCVTargetLowering::lowerXAndesBfHCvtBFloat16Store(SDValue Op, + SelectionDAG &DAG) const { + assert(Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh() && + "Unexpected bfloat16 store lowering"); + + StoreSDNode *ST = cast<StoreSDNode>(Op.getNode()); + SDLoc DL(Op); + SDValue FMV = DAG.getNode(RISCVISD::NDS_FMV_X_ANYEXTBF16, DL, + Subtarget.getXLenVT(), ST->getValue()); + return DAG.getTruncStore( + ST->getChain(), DL, FMV, ST->getBasePtr(), + EVT::getIntegerVT(*DAG.getContext(), ST->getMemoryVT().getSizeInBits()), + ST->getMemOperand()); +} + SDValue RISCVTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { switch (Op.getOpcode()) { @@ -7910,6 +7961,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, return DAG.getMergeValues({Pair, Chain}, DL); } + if (VT == MVT::bf16) + return lowerXAndesBfHCvtBFloat16Load(Op, DAG); + // Handle normal vector tuple load. if (VT.isRISCVVectorTuple()) { SDLoc DL(Op); @@ -7936,7 +7990,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, BasePtr, MachinePointerInfo(Load->getAddressSpace()), Align(8)); OutChains.push_back(LoadVal.getValue(1)); Ret = DAG.getNode(RISCVISD::TUPLE_INSERT, DL, VT, Ret, LoadVal, - DAG.getVectorIdxConstant(i, DL)); + DAG.getTargetConstant(i, DL, MVT::i32)); BasePtr = DAG.getNode(ISD::ADD, DL, XLenVT, BasePtr, VROffset, Flag); } return DAG.getMergeValues( @@ -7994,6 +8048,10 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, {Store->getChain(), Lo, Hi, Store->getBasePtr()}, MVT::i64, Store->getMemOperand()); } + + if (VT == MVT::bf16) + return lowerXAndesBfHCvtBFloat16Store(Op, DAG); + // Handle normal vector tuple store. if (VT.isRISCVVectorTuple()) { SDLoc DL(Op); @@ -8015,9 +8073,10 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, // Extract subregisters in a vector tuple and store them individually. for (unsigned i = 0; i < NF; ++i) { - auto Extract = DAG.getNode(RISCVISD::TUPLE_EXTRACT, DL, - MVT::getScalableVectorVT(MVT::i8, NumElts), - StoredVal, DAG.getVectorIdxConstant(i, DL)); + auto Extract = + DAG.getNode(RISCVISD::TUPLE_EXTRACT, DL, + MVT::getScalableVectorVT(MVT::i8, NumElts), StoredVal, + DAG.getTargetConstant(i, DL, MVT::i32)); Ret = DAG.getStore(Chain, DL, Extract, BasePtr, MachinePointerInfo(Store->getAddressSpace()), Store->getBaseAlign(), @@ -10934,9 +10993,9 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op, Load->getMemoryVT(), Load->getMemOperand()); SmallVector<SDValue, 9> Results; for (unsigned int RetIdx = 0; RetIdx < NF; RetIdx++) { - SDValue SubVec = - DAG.getNode(RISCVISD::TUPLE_EXTRACT, DL, ContainerVT, - Result.getValue(0), DAG.getVectorIdxConstant(RetIdx, DL)); + SDValue SubVec = DAG.getNode(RISCVISD::TUPLE_EXTRACT, DL, ContainerVT, + Result.getValue(0), + DAG.getTargetConstant(RetIdx, DL, MVT::i32)); Results.push_back(convertFromScalableVector(VT, SubVec, DAG, Subtarget)); } Results.push_back(Result.getValue(1)); @@ -11023,7 +11082,7 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_VOID(SDValue Op, RISCVISD::TUPLE_INSERT, DL, VecTupTy, StoredVal, convertToScalableVector( ContainerVT, FixedIntrinsic->getOperand(2 + i), DAG, Subtarget), - DAG.getVectorIdxConstant(i, DL)); + DAG.getTargetConstant(i, DL, MVT::i32)); SDValue Ops[] = { FixedIntrinsic->getChain(), @@ -12027,7 +12086,7 @@ SDValue RISCVTargetLowering::lowerVECTOR_DEINTERLEAVE(SDValue Op, for (unsigned i = 0U; i < Factor; ++i) Res[i] = DAG.getNode(RISCVISD::TUPLE_EXTRACT, DL, VecVT, Load, - DAG.getVectorIdxConstant(i, DL)); + DAG.getTargetConstant(i, DL, MVT::i32)); return DAG.getMergeValues(Res, DL); } @@ -12124,8 +12183,9 @@ SDValue RISCVTargetLowering::lowerVECTOR_INTERLEAVE(SDValue Op, SDValue StoredVal = DAG.getUNDEF(VecTupTy); for (unsigned i = 0; i < Factor; i++) - StoredVal = DAG.getNode(RISCVISD::TUPLE_INSERT, DL, VecTupTy, StoredVal, - Op.getOperand(i), DAG.getConstant(i, DL, XLenVT)); + StoredVal = + DAG.getNode(RISCVISD::TUPLE_INSERT, DL, VecTupTy, StoredVal, + Op.getOperand(i), DAG.getTargetConstant(i, DL, MVT::i32)); SDValue Ops[] = {DAG.getEntryNode(), DAG.getTargetConstant(IntrIds[Factor - 2], DL, XLenVT), @@ -16073,7 +16133,7 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG, uint64_t MulAmt = CNode->getZExtValue(); // Don't do this if the Xqciac extension is enabled and the MulAmt in simm12. - if (Subtarget.hasVendorXqciac() && isInt<12>(MulAmt)) + if (Subtarget.hasVendorXqciac() && isInt<12>(CNode->getSExtValue())) return SDValue(); const bool HasShlAdd = Subtarget.hasStdExtZba() || @@ -16178,10 +16238,12 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG, // 2^N - 3/5/9 --> (sub (shl X, C1), (shXadd X, x)) for (uint64_t Offset : {3, 5, 9}) { if (isPowerOf2_64(MulAmt + Offset)) { + unsigned ShAmt = Log2_64(MulAmt + Offset); + if (ShAmt >= VT.getSizeInBits()) + continue; SDLoc DL(N); SDValue Shift1 = - DAG.getNode(ISD::SHL, DL, VT, X, - DAG.getConstant(Log2_64(MulAmt + Offset), DL, VT)); + DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShAmt, DL, VT)); SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X, DAG.getConstant(Log2_64(Offset - 1), DL, VT), X); @@ -20690,7 +20752,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, SDValue Result = DAG.getUNDEF(VT); for (unsigned i = 0; i < NF; ++i) Result = DAG.getNode(RISCVISD::TUPLE_INSERT, DL, VT, Result, Splat, - DAG.getVectorIdxConstant(i, DL)); + DAG.getTargetConstant(i, DL, MVT::i32)); return Result; } // If this is a bitcast between a MVT::v4i1/v2i1/v1i1 and an illegal integer @@ -24014,7 +24076,7 @@ bool RISCVTargetLowering::splitValueIntoRegisterParts( #endif Val = DAG.getNode(RISCVISD::TUPLE_INSERT, DL, PartVT, DAG.getUNDEF(PartVT), - Val, DAG.getVectorIdxConstant(0, DL)); + Val, DAG.getTargetConstant(0, DL, MVT::i32)); Parts[0] = Val; return true; } |