aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/RISCV/RISCVISelLowering.cpp')
-rw-r--r--llvm/lib/Target/RISCV/RISCVISelLowering.cpp141
1 files changed, 137 insertions, 4 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 3918dd2..43e4f8e 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1618,6 +1618,12 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
}
}
+ // Customize load and store operation for bf16 if zfh isn't enabled.
+ if (Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh()) {
+ setOperationAction(ISD::LOAD, MVT::bf16, Custom);
+ setOperationAction(ISD::STORE, MVT::bf16, Custom);
+ }
+
// Function alignments.
const Align FunctionAlignment(Subtarget.hasStdExtZca() ? 2 : 4);
setMinFunctionAlignment(FunctionAlignment);
@@ -2733,6 +2739,27 @@ bool RISCVTargetLowering::isLegalElementTypeForRVV(EVT ScalarTy) const {
}
}
+bool RISCVTargetLowering::isLegalLoadStoreElementTypeForRVV(
+ EVT ScalarTy) const {
+ if (!ScalarTy.isSimple())
+ return false;
+ switch (ScalarTy.getSimpleVT().SimpleTy) {
+ case MVT::iPTR:
+ return Subtarget.is64Bit() ? Subtarget.hasVInstructionsI64() : true;
+ case MVT::i8:
+ case MVT::i16:
+ case MVT::i32:
+ case MVT::f16:
+ case MVT::bf16:
+ case MVT::f32:
+ return true;
+ case MVT::i64:
+ case MVT::f64:
+ return Subtarget.hasVInstructionsI64();
+ default:
+ return false;
+ }
+}
unsigned RISCVTargetLowering::combineRepeatedFPDivisors() const {
return NumRepeatedDivisors;
@@ -7216,6 +7243,47 @@ static SDValue SplitStrictFPVectorOp(SDValue Op, SelectionDAG &DAG) {
return DAG.getMergeValues({V, HiRes.getValue(1)}, DL);
}
+SDValue
+RISCVTargetLowering::lowerXAndesBfHCvtBFloat16Load(SDValue Op,
+ SelectionDAG &DAG) const {
+ assert(Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh() &&
+ "Unexpected bfloat16 load lowering");
+
+ SDLoc DL(Op);
+ LoadSDNode *LD = cast<LoadSDNode>(Op.getNode());
+ EVT MemVT = LD->getMemoryVT();
+ SDValue Load = DAG.getExtLoad(
+ ISD::ZEXTLOAD, DL, Subtarget.getXLenVT(), LD->getChain(),
+ LD->getBasePtr(),
+ EVT::getIntegerVT(*DAG.getContext(), MemVT.getSizeInBits()),
+ LD->getMemOperand());
+ // Using mask to make bf16 nan-boxing valid when we don't have flh
+ // instruction. -65536 would be treat as a small number and thus it can be
+ // directly used lui to get the constant.
+ SDValue mask = DAG.getSignedConstant(-65536, DL, Subtarget.getXLenVT());
+ SDValue OrSixteenOne =
+ DAG.getNode(ISD::OR, DL, Load.getValueType(), {Load, mask});
+ SDValue ConvertedResult =
+ DAG.getNode(RISCVISD::NDS_FMV_BF16_X, DL, MVT::bf16, OrSixteenOne);
+ return DAG.getMergeValues({ConvertedResult, Load.getValue(1)}, DL);
+}
+
+SDValue
+RISCVTargetLowering::lowerXAndesBfHCvtBFloat16Store(SDValue Op,
+ SelectionDAG &DAG) const {
+ assert(Subtarget.hasVendorXAndesBFHCvt() && !Subtarget.hasStdExtZfh() &&
+ "Unexpected bfloat16 store lowering");
+
+ StoreSDNode *ST = cast<StoreSDNode>(Op.getNode());
+ SDLoc DL(Op);
+ SDValue FMV = DAG.getNode(RISCVISD::NDS_FMV_X_ANYEXTBF16, DL,
+ Subtarget.getXLenVT(), ST->getValue());
+ return DAG.getTruncStore(
+ ST->getChain(), DL, FMV, ST->getBasePtr(),
+ EVT::getIntegerVT(*DAG.getContext(), ST->getMemoryVT().getSizeInBits()),
+ ST->getMemOperand());
+}
+
SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
SelectionDAG &DAG) const {
switch (Op.getOpcode()) {
@@ -7914,6 +7982,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
return DAG.getMergeValues({Pair, Chain}, DL);
}
+ if (VT == MVT::bf16)
+ return lowerXAndesBfHCvtBFloat16Load(Op, DAG);
+
// Handle normal vector tuple load.
if (VT.isRISCVVectorTuple()) {
SDLoc DL(Op);
@@ -7998,6 +8069,10 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
{Store->getChain(), Lo, Hi, Store->getBasePtr()}, MVT::i64,
Store->getMemOperand());
}
+
+ if (VT == MVT::bf16)
+ return lowerXAndesBfHCvtBFloat16Store(Op, DAG);
+
// Handle normal vector tuple store.
if (VT.isRISCVVectorTuple()) {
SDLoc DL(Op);
@@ -16079,7 +16154,7 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
uint64_t MulAmt = CNode->getZExtValue();
// Don't do this if the Xqciac extension is enabled and the MulAmt in simm12.
- if (Subtarget.hasVendorXqciac() && isInt<12>(MulAmt))
+ if (Subtarget.hasVendorXqciac() && isInt<12>(CNode->getSExtValue()))
return SDValue();
const bool HasShlAdd = Subtarget.hasStdExtZba() ||
@@ -16184,10 +16259,12 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
// 2^N - 3/5/9 --> (sub (shl X, C1), (shXadd X, x))
for (uint64_t Offset : {3, 5, 9}) {
if (isPowerOf2_64(MulAmt + Offset)) {
+ unsigned ShAmt = Log2_64(MulAmt + Offset);
+ if (ShAmt >= VT.getSizeInBits())
+ continue;
SDLoc DL(N);
SDValue Shift1 =
- DAG.getNode(ISD::SHL, DL, VT, X,
- DAG.getConstant(Log2_64(MulAmt + Offset), DL, VT));
+ DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShAmt, DL, VT));
SDValue Mul359 =
DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
DAG.getConstant(Log2_64(Offset - 1), DL, VT), X);
@@ -20766,6 +20843,62 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
}
break;
}
+ case RISCVISD::TUPLE_EXTRACT: {
+ EVT VT = N->getValueType(0);
+ SDValue Tuple = N->getOperand(0);
+ unsigned Idx = N->getConstantOperandVal(1);
+ if (!Tuple.hasOneUse() || Tuple.getOpcode() != ISD::INTRINSIC_W_CHAIN)
+ break;
+
+ unsigned NF = 0;
+ switch (Tuple.getConstantOperandVal(1)) {
+ default:
+ break;
+ case Intrinsic::riscv_vlseg2_mask:
+ case Intrinsic::riscv_vlseg3_mask:
+ case Intrinsic::riscv_vlseg4_mask:
+ case Intrinsic::riscv_vlseg5_mask:
+ case Intrinsic::riscv_vlseg6_mask:
+ case Intrinsic::riscv_vlseg7_mask:
+ case Intrinsic::riscv_vlseg8_mask:
+ NF = Tuple.getValueType().getRISCVVectorTupleNumFields();
+ break;
+ }
+
+ if (!NF || Subtarget.hasOptimizedSegmentLoadStore(NF))
+ break;
+
+ unsigned SEW = VT.getScalarSizeInBits();
+ assert(Log2_64(SEW) == Tuple.getConstantOperandVal(7) &&
+ "Type mismatch without bitcast?");
+ unsigned Stride = SEW / 8 * NF;
+ unsigned Offset = SEW / 8 * Idx;
+
+ SDValue Ops[] = {
+ /*Chain=*/Tuple.getOperand(0),
+ /*IntID=*/DAG.getTargetConstant(Intrinsic::riscv_vlse_mask, DL, XLenVT),
+ /*Passthru=*/Tuple.getOperand(2),
+ /*Ptr=*/
+ DAG.getNode(ISD::ADD, DL, XLenVT, Tuple.getOperand(3),
+ DAG.getConstant(Offset, DL, XLenVT)),
+ /*Stride=*/DAG.getConstant(Stride, DL, XLenVT),
+ /*Mask=*/Tuple.getOperand(4),
+ /*VL=*/Tuple.getOperand(5),
+ /*Policy=*/Tuple.getOperand(6)};
+
+ auto TupleMemSD = cast<MemIntrinsicSDNode>(Tuple);
+ // Match getTgtMemIntrinsic for non-unit stride case
+ EVT MemVT = TupleMemSD->getMemoryVT().getScalarType();
+ MachineFunction &MF = DAG.getMachineFunction();
+ MachineMemOperand *MMO = MF.getMachineMemOperand(
+ TupleMemSD->getMemOperand(), Offset, MemoryLocation::UnknownSize);
+
+ SDVTList VTs = DAG.getVTList({VT, MVT::Other});
+ SDValue Result = DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs,
+ Ops, MemVT, MMO);
+ DAG.ReplaceAllUsesOfValueWith(Tuple.getValue(1), Result.getValue(1));
+ return Result.getValue(0);
+ }
}
return SDValue();
@@ -24183,7 +24316,7 @@ bool RISCVTargetLowering::isLegalStridedLoadStore(EVT DataType,
return false;
EVT ScalarType = DataType.getScalarType();
- if (!isLegalElementTypeForRVV(ScalarType))
+ if (!isLegalLoadStoreElementTypeForRVV(ScalarType))
return false;
if (!Subtarget.enableUnalignedVectorMem() &&