-rw-r--r--  llvm/lib/Target/RISCV/RISCVISelLowering.cpp                 76
-rw-r--r--  llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll  32
2 files changed, 92 insertions, 16 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index e74d184..19e4074 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13589,6 +13589,52 @@ static bool matchIndexAsShuffle(EVT VT, SDValue Index, SDValue Mask,
   return ActiveLanes.all();
 }
+/// Match the index of a gather or scatter operation as an operation
+/// with twice the element width and half the number of elements. This is
+/// generally profitable (if legal) because these operations are linear
+/// in VL, so even if we cause some extra VTYPE/VL toggles, we still
+/// come out ahead.
+static bool matchIndexAsWiderOp(EVT VT, SDValue Index, SDValue Mask,
+                                Align BaseAlign, const RISCVSubtarget &ST) {
+  if (!ISD::isConstantSplatVectorAllOnes(Mask.getNode()))
+    return false;
+  if (!ISD::isBuildVectorOfConstantSDNodes(Index.getNode()))
+    return false;
+
+  // Attempt a doubling. If we can use an element type 4x or 8x in
+  // size, this will happen via multiple iterations of the transform.
+  const unsigned NumElems = VT.getVectorNumElements();
+  if (NumElems % 2 != 0)
+    return false;
+
+  const unsigned ElementSize = VT.getScalarStoreSize();
+  const unsigned WiderElementSize = ElementSize * 2;
+  if (WiderElementSize > ST.getELen()/8)
+    return false;
+
+  if (!ST.enableUnalignedVectorMem() && BaseAlign < WiderElementSize)
+    return false;
+
+  for (unsigned i = 0; i < Index->getNumOperands(); i++) {
+    // TODO: We've found an active bit of UB, and could be
+    // more aggressive here if desired.
+    if (Index->getOperand(i)->isUndef())
+      return false;
+    // TODO: This offset check is too strict if we support fully
+    // misaligned memory operations.
+    uint64_t C = Index->getConstantOperandVal(i);
+    if (C % ElementSize != 0)
+      return false;
+    if (i % 2 == 0)
+      continue;
+    uint64_t Last = Index->getConstantOperandVal(i-1);
+    if (C != Last + ElementSize)
+      return false;
+  }
+  return true;
+}
+
+
 SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   SelectionDAG &DAG = DCI.DAG;
@@ -14020,6 +14066,36 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
           DAG.getVectorShuffle(VT, DL, Load, DAG.getUNDEF(VT), ShuffleMask);
       return DAG.getMergeValues({Shuffle, Load.getValue(1)}, DL);
     }
+
+    if (MGN->getExtensionType() == ISD::NON_EXTLOAD &&
+        matchIndexAsWiderOp(VT, Index, MGN->getMask(),
+                            MGN->getMemOperand()->getBaseAlign(), Subtarget)) {
+      SmallVector<SDValue> NewIndices;
+      for (unsigned i = 0; i < Index->getNumOperands(); i += 2)
+        NewIndices.push_back(Index.getOperand(i));
+      EVT IndexVT = Index.getValueType()
+                        .getHalfNumVectorElementsVT(*DAG.getContext());
+      Index = DAG.getBuildVector(IndexVT, DL, NewIndices);
+
+      unsigned ElementSize = VT.getScalarStoreSize();
+      EVT WideScalarVT = MVT::getIntegerVT(ElementSize * 8 * 2);
+      auto EltCnt = VT.getVectorElementCount();
+      assert(EltCnt.isKnownEven() && "Splitting vector, but not in half!");
+      EVT WideVT = EVT::getVectorVT(*DAG.getContext(), WideScalarVT,
+                                    EltCnt.divideCoefficientBy(2));
+      SDValue Passthru = DAG.getBitcast(WideVT, MGN->getPassThru());
+      EVT MaskVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
+                                    EltCnt.divideCoefficientBy(2));
+      SDValue Mask = DAG.getSplat(MaskVT, DL, DAG.getConstant(1, DL, MVT::i1));
+
+      SDValue Gather =
+          DAG.getMaskedGather(DAG.getVTList(WideVT, MVT::Other), WideVT, DL,
+                              {MGN->getChain(), Passthru, Mask, MGN->getBasePtr(),
+                               Index, ScaleOp},
+                              MGN->getMemOperand(), IndexType, ISD::NON_EXTLOAD);
+      SDValue Result = DAG.getBitcast(VT, Gather.getValue(0));
+      return DAG.getMergeValues({Result, Gather.getValue(1)}, DL);
+    }
     break;
   }
   case ISD::MSCATTER:{
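
For context, a minimal sketch of the kind of gather the new matchIndexAsWiderOp combine targets (hypothetical function name and indices chosen to mirror the strided test updated below, not taken from the commit itself). Eight i16 elements are requested with an all-ones mask and constant indices that form adjacent pairs, so the same data can be fetched as a <4 x i32> gather over half as many indices and bitcast back:

; Illustrative only: pairs (0,1), (4,5), (8,9), (12,13) of i16 elements,
; i.e. byte offsets 0, 8, 16, 24 once widened to 32-bit accesses.
; Base alignment of 4 keeps the widened accesses naturally aligned, which
; matchIndexAsWiderOp requires unless unaligned vector memory is enabled.
define <8 x i16> @strided_pairs_sketch(ptr %base) {
  %ptrs = getelementptr inbounds i16, ptr %base, <8 x i32> <i32 0, i32 1, i32 4, i32 5, i32 8, i32 9, i32 12, i32 13>
  %v = call <8 x i16> @llvm.masked.gather.v8i16.v8p0(<8 x ptr> %ptrs, i32 4, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <8 x i16> poison)
  ret <8 x i16> %v
}
declare <8 x i16> @llvm.masked.gather.v8i16.v8p0(<8 x ptr>, i32, <8 x i1>, <8 x i16>)

The rewrite itself happens on SelectionDAG MGATHER nodes rather than on IR; the IR above only shows the shape of the input pattern.
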
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll
index ac5c11c..130d2c7 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll
@@ -13024,19 +13024,19 @@ define <4 x i32> @mgather_narrow_edge_case(ptr %base) {
define <8 x i16> @mgather_strided_2xSEW(ptr %base) {
; RV32-LABEL: mgather_strided_2xSEW:
; RV32: # %bb.0:
-; RV32-NEXT: lui a1, %hi(.LCPI107_0)
-; RV32-NEXT: addi a1, a1, %lo(.LCPI107_0)
-; RV32-NEXT: vsetivli zero, 8, e16, m1, ta, ma
-; RV32-NEXT: vle8.v v9, (a1)
+; RV32-NEXT: vsetivli zero, 4, e8, mf4, ta, ma
+; RV32-NEXT: vid.v v8
+; RV32-NEXT: vsll.vi v9, v8, 3
+; RV32-NEXT: vsetvli zero, zero, e32, m1, ta, ma
; RV32-NEXT: vluxei8.v v8, (a0), v9
; RV32-NEXT: ret
;
; RV64V-LABEL: mgather_strided_2xSEW:
; RV64V: # %bb.0:
-; RV64V-NEXT: lui a1, %hi(.LCPI107_0)
-; RV64V-NEXT: addi a1, a1, %lo(.LCPI107_0)
-; RV64V-NEXT: vsetivli zero, 8, e16, m1, ta, ma
-; RV64V-NEXT: vle8.v v9, (a1)
+; RV64V-NEXT: vsetivli zero, 4, e8, mf4, ta, ma
+; RV64V-NEXT: vid.v v8
+; RV64V-NEXT: vsll.vi v9, v8, 3
+; RV64V-NEXT: vsetvli zero, zero, e32, m1, ta, ma
; RV64V-NEXT: vluxei8.v v8, (a0), v9
; RV64V-NEXT: ret
;
@@ -13141,19 +13141,19 @@ define <8 x i16> @mgather_strided_2xSEW(ptr %base) {
define <8 x i16> @mgather_gather_2xSEW(ptr %base) {
; RV32-LABEL: mgather_gather_2xSEW:
; RV32: # %bb.0:
-; RV32-NEXT: lui a1, %hi(.LCPI108_0)
-; RV32-NEXT: addi a1, a1, %lo(.LCPI108_0)
-; RV32-NEXT: vsetivli zero, 8, e16, m1, ta, ma
-; RV32-NEXT: vle8.v v9, (a1)
+; RV32-NEXT: lui a1, 82176
+; RV32-NEXT: addi a1, a1, 1024
+; RV32-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; RV32-NEXT: vmv.s.x v9, a1
; RV32-NEXT: vluxei8.v v8, (a0), v9
; RV32-NEXT: ret
;
; RV64V-LABEL: mgather_gather_2xSEW:
; RV64V: # %bb.0:
-; RV64V-NEXT: lui a1, %hi(.LCPI108_0)
-; RV64V-NEXT: addi a1, a1, %lo(.LCPI108_0)
-; RV64V-NEXT: vsetivli zero, 8, e16, m1, ta, ma
-; RV64V-NEXT: vle8.v v9, (a1)
+; RV64V-NEXT: lui a1, 82176
+; RV64V-NEXT: addiw a1, a1, 1024
+; RV64V-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; RV64V-NEXT: vmv.s.x v9, a1
; RV64V-NEXT: vluxei8.v v8, (a0), v9
; RV64V-NEXT: ret
;
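
The second updated test covers the case where the pairs are adjacent but not uniformly strided. A sketch of that shape (hypothetical body, chosen to be consistent with the byte offsets 0, 4, 16 and 20 packed into the lui/addi constant in the new check lines, not copied from the test itself):

; Illustrative only: pairs (0,1), (2,3), (8,9), (10,11) of i16 elements.
; matchIndexAsWiderOp only requires each odd index to equal the preceding
; index plus one element, so this still becomes a four-element i32 gather.
define <8 x i16> @paired_nonstrided_sketch(ptr %base) {
  %ptrs = getelementptr inbounds i16, ptr %base, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 8, i32 9, i32 10, i32 11>
  %v = call <8 x i16> @llvm.masked.gather.v8i16.v8p0(<8 x ptr> %ptrs, i32 4, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <8 x i16> poison)
  ret <8 x i16> %v
}
declare <8 x i16> @llvm.masked.gather.v8i16.v8p0(<8 x ptr>, i32, <8 x i1>, <8 x i16>)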