author     Fraser Cormack <fraser@codeplay.com>   2021-02-22 16:51:24 +0000
committer  Fraser Cormack <fraser@codeplay.com>   2021-02-25 12:11:34 +0000
commit     84413e1947427a917a3e55abfc1f66c42adc751b
tree       df55c68fbeb08c22cfe8afac0ea2114d74458fc6
parent     3bc5ed38750c6a6daff39ad524b75e40c8c09183
[RISCV] Support fixed-length vector truncates
This patch extends support for our custom-lowering of scalable-vector
truncates to include those of fixed-length vectors. It does this by
co-opting the custom RISCVISD::TRUNCATE_VECTOR node and adding mask and
VL operands. This avoids unnecessary duplication of patterns and
inflation of the ISel table.
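
As a concrete illustration, taken from the test added below, a fixed-length
i32-to-i8 truncate now lowers to a chain of vnsrl.wi instructions, narrowing
the element width by one power of two per step:

    %a = load <4 x i32>, <4 x i32>* %x
    %b = trunc <4 x i32> %a to <4 x i8>
    ; selects to narrowing shifts by 0, halving SEW each time:
    ;   vnsrl.wi v26, v25, 0    ; <4 x i32> -> <4 x i16>
    ;   vnsrl.wi v25, v26, 0    ; <4 x i16> -> <4 x i8>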
Some truncates go through CONCAT_VECTORS, which currently isn't handled
efficiently: it is lowered through the stack. This can be improved upon in
the future.
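
The LMULMAX1 run of the new trunc_v8i8_v8i32 test shows the current cost:
each <4 x i32> half is narrowed separately and the two <4 x i8> halves are
concatenated through a stack slot (abbreviated from the CHECK lines below):

    vse8.v v25, (a0)    ; first narrowed half stored at sp+12
    ...
    vse8.v v26, (a0)    ; second narrowed half stored at sp+8
    vle8.v v25, (a0)    ; combined <8 x i8> reloaded from sp+8
    vse8.v v25, (a1)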
Reviewed By: craig.topper
Differential Revision: https://reviews.llvm.org/D97202
-rw-r--r--  llvm/lib/Target/RISCV/RISCVISelLowering.cpp                | 39
-rw-r--r--  llvm/lib/Target/RISCV/RISCVISelLowering.h                  |  8
-rw-r--r--  llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td         | 13
-rw-r--r--  llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td         | 18
-rw-r--r--  llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-exttrunc.ll  | 77
5 files changed, 130 insertions, 25 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index abd9a3d..f189667 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -446,7 +446,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::FP_TO_SINT, VT, Custom);
       setOperationAction(ISD::FP_TO_UINT, VT, Custom);
-      // Integer VTs are lowered as a series of "RISCVISD::TRUNCATE_VECTOR"
+      // Integer VTs are lowered as a series of "RISCVISD::TRUNCATE_VECTOR_VL"
       // nodes which truncate by one power of two at a time.
       setOperationAction(ISD::TRUNCATE, VT, Custom);
@@ -526,6 +526,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
         // By default everything must be expanded.
         for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op)
           setOperationAction(Op, VT, Expand);
+        for (MVT OtherVT : MVT::fixedlen_vector_valuetypes())
+          setTruncStoreAction(VT, OtherVT, Expand);
         // We use EXTRACT_SUBVECTOR as a "cast" from scalable to fixed.
         setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
@@ -571,6 +573,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
         setOperationAction(ISD::VSELECT, VT, Custom);
+        setOperationAction(ISD::TRUNCATE, VT, Custom);
         setOperationAction(ISD::ANY_EXTEND, VT, Custom);
         setOperationAction(ISD::SIGN_EXTEND, VT, Custom);
         setOperationAction(ISD::ZERO_EXTEND, VT, Custom);
@@ -1171,7 +1174,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
   }
   case ISD::TRUNCATE: {
     SDLoc DL(Op);
-    EVT VT = Op.getValueType();
+    MVT VT = Op.getSimpleValueType();
     // Only custom-lower vector truncates
     if (!VT.isVector())
       return Op;
@@ -1181,28 +1184,42 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
       return lowerVectorMaskTrunc(Op, DAG);
     // RVV only has truncates which operate from SEW*2->SEW, so lower arbitrary
-    // truncates as a series of "RISCVISD::TRUNCATE_VECTOR" nodes which
+    // truncates as a series of "RISCVISD::TRUNCATE_VECTOR_VL" nodes which
     // truncate by one power of two at a time.
-    EVT DstEltVT = VT.getVectorElementType();
+    MVT DstEltVT = VT.getVectorElementType();
     SDValue Src = Op.getOperand(0);
-    EVT SrcVT = Src.getValueType();
-    EVT SrcEltVT = SrcVT.getVectorElementType();
+    MVT SrcVT = Src.getSimpleValueType();
+    MVT SrcEltVT = SrcVT.getVectorElementType();
     assert(DstEltVT.bitsLT(SrcEltVT) &&
            isPowerOf2_64(DstEltVT.getSizeInBits()) &&
            isPowerOf2_64(SrcEltVT.getSizeInBits()) &&
           "Unexpected vector truncate lowering");
+    MVT ContainerVT = SrcVT;
+    if (SrcVT.isFixedLengthVector()) {
+      ContainerVT = RISCVTargetLowering::getContainerForFixedLengthVector(
+          DAG, SrcVT, Subtarget);
+      Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget);
+    }
+    SDValue Result = Src;
+    SDValue Mask, VL;
+    std::tie(Mask, VL) =
+        getDefaultVLOps(SrcVT, ContainerVT, DL, DAG, Subtarget);
     LLVMContext &Context = *DAG.getContext();
-    const ElementCount Count = SrcVT.getVectorElementCount();
+    const ElementCount Count = ContainerVT.getVectorElementCount();
     do {
-      SrcEltVT = EVT::getIntegerVT(Context, SrcEltVT.getSizeInBits() / 2);
+      SrcEltVT = MVT::getIntegerVT(SrcEltVT.getSizeInBits() / 2);
       EVT ResultVT = EVT::getVectorVT(Context, SrcEltVT, Count);
-      Result = DAG.getNode(RISCVISD::TRUNCATE_VECTOR, DL, ResultVT, Result);
+      Result = DAG.getNode(RISCVISD::TRUNCATE_VECTOR_VL, DL, ResultVT, Result,
+                           Mask, VL);
     } while (SrcEltVT != DstEltVT);
+    if (SrcVT.isFixedLengthVector())
+      Result = convertFromScalableVector(VT, Result, DAG, Subtarget);
+    return Result;
   }
   case ISD::ANY_EXTEND:
@@ -5437,7 +5454,9 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(VMV_X_S)
   NODE_NAME_CASE(SPLAT_VECTOR_I64)
   NODE_NAME_CASE(READ_VLENB)
-  NODE_NAME_CASE(TRUNCATE_VECTOR)
+  NODE_NAME_CASE(TRUNCATE_VECTOR_VL)
+  NODE_NAME_CASE(VLEFF)
+  NODE_NAME_CASE(VLEFF_MASK)
   NODE_NAME_CASE(VSLIDEUP_VL)
   NODE_NAME_CASE(VSLIDEDOWN_VL)
   NODE_NAME_CASE(VID_VL)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index dc7e05e..a75ebc3 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -105,8 +105,12 @@ enum NodeType : unsigned {
   SPLAT_VECTOR_I64,
   // Read VLENB CSR
   READ_VLENB,
-  // Truncates a RVV integer vector by one power-of-two.
-  TRUNCATE_VECTOR,
+  // Truncates a RVV integer vector by one power-of-two. Carries both an extra
+  // mask and VL operand.
+  TRUNCATE_VECTOR_VL,
+  // Unit-stride fault-only-first load
+  VLEFF,
+  VLEFF_MASK,
   // Matches the semantics of vslideup/vslidedown. The first operand is the
   // pass-thru operand, the second is the source vector, the third is the
   // XLenVT index (either constant or non-constant), the fourth is the mask
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
index ea0e5f1..c552865 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
@@ -28,10 +28,6 @@ def SDTSplatI64 : SDTypeProfile<1, 1, [
 def rv32_splat_i64 : SDNode<"RISCVISD::SPLAT_VECTOR_I64", SDTSplatI64>;
-def riscv_trunc_vector : SDNode<"RISCVISD::TRUNCATE_VECTOR",
-                                SDTypeProfile<1, 1,
-                                              [SDTCisVec<0>, SDTCisVec<1>]>>;
-
 // Give explicit Complexity to prefer simm5/uimm5.
 def SplatPat : ComplexPattern<vAny, 1, "selectVSplat", [splat_vector, rv32_splat_i64], [], 1>;
 def SplatPat_simm5 : ComplexPattern<vAny, 1, "selectVSplatSimm5", [splat_vector, rv32_splat_i64], [], 2>;
@@ -433,15 +429,6 @@ defm "" : VPatBinarySDNode_VV_VX_VI<shl, "PseudoVSLL", uimm5>;
 defm "" : VPatBinarySDNode_VV_VX_VI<srl, "PseudoVSRL", uimm5>;
 defm "" : VPatBinarySDNode_VV_VX_VI<sra, "PseudoVSRA", uimm5>;
-// 12.7. Vector Narrowing Integer Right Shift Instructions
-foreach vtiTofti = AllFractionableVF2IntVectors in {
-  defvar vti = vtiTofti.Vti;
-  defvar fti = vtiTofti.Fti;
-  def : Pat<(fti.Vector (riscv_trunc_vector (vti.Vector vti.RegClass:$rs1))),
-            (!cast<Instruction>("PseudoVNSRL_WI_"#fti.LMul.MX)
-                vti.RegClass:$rs1, 0, fti.AVL, fti.SEW)>;
-}
-
 // 12.8. Vector Integer Comparison Instructions
 defm "" : VPatIntegerSetCCSDNode_VV_VX_VI<SETEQ, "PseudoVMSEQ">;
 defm "" : VPatIntegerSetCCSDNode_VV_VX_VI<SETNE, "PseudoVMSNE">;
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 76eb5f6..2d5f8fa 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -148,6 +148,13 @@ def SDT_RISCVVEXTEND_VL : SDTypeProfile<1, 3, [SDTCisVec<0>,
 def riscv_sext_vl : SDNode<"RISCVISD::VSEXT_VL", SDT_RISCVVEXTEND_VL>;
 def riscv_zext_vl : SDNode<"RISCVISD::VZEXT_VL", SDT_RISCVVEXTEND_VL>;
+def riscv_trunc_vector_vl : SDNode<"RISCVISD::TRUNCATE_VECTOR_VL",
+                                   SDTypeProfile<1, 3, [SDTCisVec<0>,
+                                                        SDTCisVec<1>,
+                                                        SDTCisSameNumEltsAs<0, 2>,
+                                                        SDTCVecEltisVT<2, i1>,
+                                                        SDTCisVT<3, XLenVT>]>>;
+
 // Ignore the vl operand.
 def SplatFPOp : PatFrag<(ops node:$op),
                         (riscv_vfmv_v_f_vl node:$op, srcvalue)>;
@@ -443,6 +450,17 @@ defm "" : VPatBinaryVL_VV_VX_VI<riscv_shl_vl, "PseudoVSLL", uimm5>;
 defm "" : VPatBinaryVL_VV_VX_VI<riscv_srl_vl, "PseudoVSRL", uimm5>;
 defm "" : VPatBinaryVL_VV_VX_VI<riscv_sra_vl, "PseudoVSRA", uimm5>;
+// 12.7. Vector Narrowing Integer Right Shift Instructions
+foreach vtiTofti = AllFractionableVF2IntVectors in {
+  defvar vti = vtiTofti.Vti;
+  defvar fti = vtiTofti.Fti;
+  def : Pat<(fti.Vector (riscv_trunc_vector_vl (vti.Vector vti.RegClass:$rs1),
+                                               (vti.Mask true_mask),
+                                               (XLenVT (VLOp GPR:$vl)))),
+            (!cast<Instruction>("PseudoVNSRL_WI_"#fti.LMul.MX)
+                vti.RegClass:$rs1, 0, GPR:$vl, fti.SEW)>;
+}
+
 // 12.8. Vector Integer Comparison Instructions
 foreach vti = AllIntegerVectors in {
   defm "" : VPatIntegerSetCCVL_VV<vti, "PseudoVMSEQ", SETEQ>;
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-exttrunc.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-exttrunc.ll
index ad04c1a..e4e033a 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-exttrunc.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-exttrunc.ll
@@ -165,3 +165,80 @@ define void @sext_v32i8_v32i32(<32 x i8>* %x, <32 x i32>* %z) {
   store <32 x i32> %b, <32 x i32>* %z
   ret void
 }
+
+define void @trunc_v4i8_v4i32(<4 x i32>* %x, <4 x i8>* %z) {
+; CHECK-LABEL: trunc_v4i8_v4i32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli a2, 4, e32,m1,ta,mu
+; CHECK-NEXT:    vle32.v v25, (a0)
+; CHECK-NEXT:    vsetivli a0, 4, e16,mf2,ta,mu
+; CHECK-NEXT:    vnsrl.wi v26, v25, 0
+; CHECK-NEXT:    vsetivli a0, 4, e8,mf4,ta,mu
+; CHECK-NEXT:    vnsrl.wi v25, v26, 0
+; CHECK-NEXT:    vsetivli a0, 4, e8,m1,ta,mu
+; CHECK-NEXT:    vse8.v v25, (a1)
+; CHECK-NEXT:    ret
+  %a = load <4 x i32>, <4 x i32>* %x
+  %b = trunc <4 x i32> %a to <4 x i8>
+  store <4 x i8> %b, <4 x i8>* %z
+  ret void
+}
+
+define void @trunc_v8i8_v8i32(<8 x i32>* %x, <8 x i8>* %z) {
+; LMULMAX8-LABEL: trunc_v8i8_v8i32:
+; LMULMAX8:       # %bb.0:
+; LMULMAX8-NEXT:    vsetivli a2, 8, e32,m2,ta,mu
+; LMULMAX8-NEXT:    vle32.v v26, (a0)
+; LMULMAX8-NEXT:    vsetivli a0, 8, e16,m1,ta,mu
+; LMULMAX8-NEXT:    vnsrl.wi v25, v26, 0
+; LMULMAX8-NEXT:    vsetivli a0, 8, e8,mf2,ta,mu
+; LMULMAX8-NEXT:    vnsrl.wi v26, v25, 0
+; LMULMAX8-NEXT:    vsetivli a0, 8, e8,m1,ta,mu
+; LMULMAX8-NEXT:    vse8.v v26, (a1)
+; LMULMAX8-NEXT:    ret
+;
+; LMULMAX2-LABEL: trunc_v8i8_v8i32:
+; LMULMAX2:       # %bb.0:
+; LMULMAX2-NEXT:    vsetivli a2, 8, e32,m2,ta,mu
+; LMULMAX2-NEXT:    vle32.v v26, (a0)
+; LMULMAX2-NEXT:    vsetivli a0, 8, e16,m1,ta,mu
+; LMULMAX2-NEXT:    vnsrl.wi v25, v26, 0
+; LMULMAX2-NEXT:    vsetivli a0, 8, e8,mf2,ta,mu
+; LMULMAX2-NEXT:    vnsrl.wi v26, v25, 0
+; LMULMAX2-NEXT:    vsetivli a0, 8, e8,m1,ta,mu
+; LMULMAX2-NEXT:    vse8.v v26, (a1)
+; LMULMAX2-NEXT:    ret
+;
+; LMULMAX1-LABEL: trunc_v8i8_v8i32:
+; LMULMAX1:       # %bb.0:
+; LMULMAX1-NEXT:    addi sp, sp, -16
+; LMULMAX1-NEXT:    .cfi_def_cfa_offset 16
+; LMULMAX1-NEXT:    vsetivli a2, 4, e32,m1,ta,mu
+; LMULMAX1-NEXT:    addi a2, a0, 16
+; LMULMAX1-NEXT:    vle32.v v25, (a2)
+; LMULMAX1-NEXT:    vle32.v v26, (a0)
+; LMULMAX1-NEXT:    vsetivli a0, 4, e16,mf2,ta,mu
+; LMULMAX1-NEXT:    vnsrl.wi v27, v25, 0
+; LMULMAX1-NEXT:    vsetivli a0, 4, e8,mf4,ta,mu
+; LMULMAX1-NEXT:    vnsrl.wi v25, v27, 0
+; LMULMAX1-NEXT:    addi a0, sp, 12
+; LMULMAX1-NEXT:    vsetivli a2, 4, e8,m1,ta,mu
+; LMULMAX1-NEXT:    vse8.v v25, (a0)
+; LMULMAX1-NEXT:    vsetivli a0, 4, e16,mf2,ta,mu
+; LMULMAX1-NEXT:    vnsrl.wi v25, v26, 0
+; LMULMAX1-NEXT:    vsetivli a0, 4, e8,mf4,ta,mu
+; LMULMAX1-NEXT:    vnsrl.wi v26, v25, 0
+; LMULMAX1-NEXT:    vsetivli a0, 4, e8,m1,ta,mu
+; LMULMAX1-NEXT:    addi a0, sp, 8
+; LMULMAX1-NEXT:    vse8.v v26, (a0)
+; LMULMAX1-NEXT:    vsetivli a0, 8, e8,m1,ta,mu
+; LMULMAX1-NEXT:    addi a0, sp, 8
+; LMULMAX1-NEXT:    vle8.v v25, (a0)
+; LMULMAX1-NEXT:    vse8.v v25, (a1)
+; LMULMAX1-NEXT:    addi sp, sp, 16
+; LMULMAX1-NEXT:    ret
+  %a = load <8 x i32>, <8 x i32>* %x
+  %b = trunc <8 x i32> %a to <8 x i8>
+  store <8 x i8> %b, <8 x i8>* %z
+  ret void
+}
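
The RUN lines of fixed-vectors-int-exttrunc.ll fall outside this excerpt. As a
rough sketch of how the CHECK output above could be reproduced (the exact flag
spelling is assumed from other fixed-length RVV tests of this period, not
confirmed by this excerpt), the single-prefix checks correspond to an
invocation along these lines:

    llc -mtriple=riscv64 -mattr=+experimental-v \
        -riscv-v-vector-bits-min=128 \
        llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-exttrunc.ll -o -

while the LMULMAX8/LMULMAX2/LMULMAX1 prefixes presumably correspond to
different values of the -riscv-v-fixed-length-vector-lmul-max option.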