aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/RISCV/RISCVISelLowering.cpp')
-rw-r--r--llvm/lib/Target/RISCV/RISCVISelLowering.cpp94
1 files changed, 78 insertions, 16 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 4845a9c..54845e5 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1618,6 +1618,12 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
}
}
+ // Customize load and store operation for bf16 if zfh isn't enabled.
+ if (Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh()) {
+ setOperationAction(ISD::LOAD, MVT::bf16, Custom);
+ setOperationAction(ISD::STORE, MVT::bf16, Custom);
+ }
+
// Function alignments.
const Align FunctionAlignment(Subtarget.hasStdExtZca() ? 2 : 4);
setMinFunctionAlignment(FunctionAlignment);
@@ -2319,6 +2325,10 @@ bool RISCVTargetLowering::isFPImmLegal(const APFloat &Imm, EVT VT,
if (getLegalZfaFPImm(Imm, VT) >= 0)
return true;
+ // Some constants can be produced by fli+fneg.
+ if (Imm.isNegative() && getLegalZfaFPImm(-Imm, VT) >= 0)
+ return true;
+
// Cannot create a 64 bit floating-point immediate value for rv32.
if (Subtarget.getXLen() < VT.getScalarSizeInBits()) {
// td can handle +0.0 or -0.0 already.
@@ -7212,6 +7222,47 @@ static SDValue SplitStrictFPVectorOp(SDValue Op, SelectionDAG &DAG) {
return DAG.getMergeValues({V, HiRes.getValue(1)}, DL);
}
+SDValue
+RISCVTargetLowering::lowerXAndesBfHCvtBFloat16Load(SDValue Op,
+ SelectionDAG &DAG) const {
+ assert(Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh() &&
+ "Unexpected bfloat16 load lowering");
+
+ SDLoc DL(Op);
+ LoadSDNode *LD = cast<LoadSDNode>(Op.getNode());
+ EVT MemVT = LD->getMemoryVT();
+ SDValue Load = DAG.getExtLoad(
+ ISD::ZEXTLOAD, DL, Subtarget.getXLenVT(), LD->getChain(),
+ LD->getBasePtr(),
+ EVT::getIntegerVT(*DAG.getContext(), MemVT.getSizeInBits()),
+ LD->getMemOperand());
+ // Using mask to make bf16 nan-boxing valid when we don't have flh
+ // instruction. -65536 would be treat as a small number and thus it can be
+ // directly used lui to get the constant.
+ SDValue mask = DAG.getSignedConstant(-65536, DL, Subtarget.getXLenVT());
+ SDValue OrSixteenOne =
+ DAG.getNode(ISD::OR, DL, Load.getValueType(), {Load, mask});
+ SDValue ConvertedResult =
+ DAG.getNode(RISCVISD::NDS_FMV_BF16_X, DL, MVT::bf16, OrSixteenOne);
+ return DAG.getMergeValues({ConvertedResult, Load.getValue(1)}, DL);
+}
+
+SDValue
+RISCVTargetLowering::lowerXAndesBfHCvtBFloat16Store(SDValue Op,
+ SelectionDAG &DAG) const {
+ assert(Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh() &&
+ "Unexpected bfloat16 store lowering");
+
+ StoreSDNode *ST = cast<StoreSDNode>(Op.getNode());
+ SDLoc DL(Op);
+ SDValue FMV = DAG.getNode(RISCVISD::NDS_FMV_X_ANYEXTBF16, DL,
+ Subtarget.getXLenVT(), ST->getValue());
+ return DAG.getTruncStore(
+ ST->getChain(), DL, FMV, ST->getBasePtr(),
+ EVT::getIntegerVT(*DAG.getContext(), ST->getMemoryVT().getSizeInBits()),
+ ST->getMemOperand());
+}
+
SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
SelectionDAG &DAG) const {
switch (Op.getOpcode()) {
@@ -7910,6 +7961,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
return DAG.getMergeValues({Pair, Chain}, DL);
}
+ if (VT == MVT::bf16)
+ return lowerXAndesBfHCvtBFloat16Load(Op, DAG);
+
// Handle normal vector tuple load.
if (VT.isRISCVVectorTuple()) {
SDLoc DL(Op);
@@ -7936,7 +7990,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
BasePtr, MachinePointerInfo(Load->getAddressSpace()), Align(8));
OutChains.push_back(LoadVal.getValue(1));
Ret = DAG.getNode(RISCVISD::TUPLE_INSERT, DL, VT, Ret, LoadVal,
- DAG.getVectorIdxConstant(i, DL));
+ DAG.getTargetConstant(i, DL, MVT::i32));
BasePtr = DAG.getNode(ISD::ADD, DL, XLenVT, BasePtr, VROffset, Flag);
}
return DAG.getMergeValues(
@@ -7994,6 +8048,10 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
{Store->getChain(), Lo, Hi, Store->getBasePtr()}, MVT::i64,
Store->getMemOperand());
}
+
+ if (VT == MVT::bf16)
+ return lowerXAndesBfHCvtBFloat16Store(Op, DAG);
+
// Handle normal vector tuple store.
if (VT.isRISCVVectorTuple()) {
SDLoc DL(Op);
@@ -8015,9 +8073,10 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
// Extract subregisters in a vector tuple and store them individually.
for (unsigned i = 0; i < NF; ++i) {
- auto Extract = DAG.getNode(RISCVISD::TUPLE_EXTRACT, DL,
- MVT::getScalableVectorVT(MVT::i8, NumElts),
- StoredVal, DAG.getVectorIdxConstant(i, DL));
+ auto Extract =
+ DAG.getNode(RISCVISD::TUPLE_EXTRACT, DL,
+ MVT::getScalableVectorVT(MVT::i8, NumElts), StoredVal,
+ DAG.getTargetConstant(i, DL, MVT::i32));
Ret = DAG.getStore(Chain, DL, Extract, BasePtr,
MachinePointerInfo(Store->getAddressSpace()),
Store->getBaseAlign(),
@@ -10934,9 +10993,9 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
Load->getMemoryVT(), Load->getMemOperand());
SmallVector<SDValue, 9> Results;
for (unsigned int RetIdx = 0; RetIdx < NF; RetIdx++) {
- SDValue SubVec =
- DAG.getNode(RISCVISD::TUPLE_EXTRACT, DL, ContainerVT,
- Result.getValue(0), DAG.getVectorIdxConstant(RetIdx, DL));
+ SDValue SubVec = DAG.getNode(RISCVISD::TUPLE_EXTRACT, DL, ContainerVT,
+ Result.getValue(0),
+ DAG.getTargetConstant(RetIdx, DL, MVT::i32));
Results.push_back(convertFromScalableVector(VT, SubVec, DAG, Subtarget));
}
Results.push_back(Result.getValue(1));
@@ -11023,7 +11082,7 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_VOID(SDValue Op,
RISCVISD::TUPLE_INSERT, DL, VecTupTy, StoredVal,
convertToScalableVector(
ContainerVT, FixedIntrinsic->getOperand(2 + i), DAG, Subtarget),
- DAG.getVectorIdxConstant(i, DL));
+ DAG.getTargetConstant(i, DL, MVT::i32));
SDValue Ops[] = {
FixedIntrinsic->getChain(),
@@ -12027,7 +12086,7 @@ SDValue RISCVTargetLowering::lowerVECTOR_DEINTERLEAVE(SDValue Op,
for (unsigned i = 0U; i < Factor; ++i)
Res[i] = DAG.getNode(RISCVISD::TUPLE_EXTRACT, DL, VecVT, Load,
- DAG.getVectorIdxConstant(i, DL));
+ DAG.getTargetConstant(i, DL, MVT::i32));
return DAG.getMergeValues(Res, DL);
}
@@ -12124,8 +12183,9 @@ SDValue RISCVTargetLowering::lowerVECTOR_INTERLEAVE(SDValue Op,
SDValue StoredVal = DAG.getUNDEF(VecTupTy);
for (unsigned i = 0; i < Factor; i++)
- StoredVal = DAG.getNode(RISCVISD::TUPLE_INSERT, DL, VecTupTy, StoredVal,
- Op.getOperand(i), DAG.getConstant(i, DL, XLenVT));
+ StoredVal =
+ DAG.getNode(RISCVISD::TUPLE_INSERT, DL, VecTupTy, StoredVal,
+ Op.getOperand(i), DAG.getTargetConstant(i, DL, MVT::i32));
SDValue Ops[] = {DAG.getEntryNode(),
DAG.getTargetConstant(IntrIds[Factor - 2], DL, XLenVT),
@@ -16073,7 +16133,7 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
uint64_t MulAmt = CNode->getZExtValue();
// Don't do this if the Xqciac extension is enabled and the MulAmt in simm12.
- if (Subtarget.hasVendorXqciac() && isInt<12>(MulAmt))
+ if (Subtarget.hasVendorXqciac() && isInt<12>(CNode->getSExtValue()))
return SDValue();
const bool HasShlAdd = Subtarget.hasStdExtZba() ||
@@ -16178,10 +16238,12 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
// 2^N - 3/5/9 --> (sub (shl X, C1), (shXadd X, x))
for (uint64_t Offset : {3, 5, 9}) {
if (isPowerOf2_64(MulAmt + Offset)) {
+ unsigned ShAmt = Log2_64(MulAmt + Offset);
+ if (ShAmt >= VT.getSizeInBits())
+ continue;
SDLoc DL(N);
SDValue Shift1 =
- DAG.getNode(ISD::SHL, DL, VT, X,
- DAG.getConstant(Log2_64(MulAmt + Offset), DL, VT));
+ DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShAmt, DL, VT));
SDValue Mul359 =
DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
DAG.getConstant(Log2_64(Offset - 1), DL, VT), X);
@@ -20690,7 +20752,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
SDValue Result = DAG.getUNDEF(VT);
for (unsigned i = 0; i < NF; ++i)
Result = DAG.getNode(RISCVISD::TUPLE_INSERT, DL, VT, Result, Splat,
- DAG.getVectorIdxConstant(i, DL));
+ DAG.getTargetConstant(i, DL, MVT::i32));
return Result;
}
// If this is a bitcast between a MVT::v4i1/v2i1/v1i1 and an illegal integer
@@ -24014,7 +24076,7 @@ bool RISCVTargetLowering::splitValueIntoRegisterParts(
#endif
Val = DAG.getNode(RISCVISD::TUPLE_INSERT, DL, PartVT, DAG.getUNDEF(PartVT),
- Val, DAG.getVectorIdxConstant(0, DL));
+ Val, DAG.getTargetConstant(0, DL, MVT::i32));
Parts[0] = Val;
return true;
}