Diffstat (limited to 'llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp')
-rw-r--r--  llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp | 184
1 file changed, 177 insertions(+), 7 deletions(-)
diff --git a/llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp b/llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp
index 7f35107..38c1f9868 100644
--- a/llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp
+++ b/llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp
@@ -139,20 +139,21 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
.clampScalar(0, s32, sXLen)
.minScalarSameAs(1, 0);
+ auto &ExtActions =
+ getActionDefinitionsBuilder({G_ZEXT, G_SEXT, G_ANYEXT})
+ .legalIf(all(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST),
+ typeIsLegalIntOrFPVec(1, IntOrFPVecTys, ST)));
if (ST.is64Bit()) {
- getActionDefinitionsBuilder({G_ZEXT, G_SEXT, G_ANYEXT})
- .legalFor({{sXLen, s32}})
- .maxScalar(0, sXLen);
-
+ ExtActions.legalFor({{sXLen, s32}});
getActionDefinitionsBuilder(G_SEXT_INREG)
.customFor({sXLen})
.maxScalar(0, sXLen)
.lower();
} else {
- getActionDefinitionsBuilder({G_ZEXT, G_SEXT, G_ANYEXT}).maxScalar(0, sXLen);
-
getActionDefinitionsBuilder(G_SEXT_INREG).maxScalar(0, sXLen).lower();
}
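+  // Extensions from i1 mask vectors are custom-lowered below; legalizeExt
+  // turns them into a G_SELECT between splat constants.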
+ ExtActions.customIf(typeIsLegalBoolVec(1, BoolVecTys, ST))
+ .maxScalar(0, sXLen);
// Merge/Unmerge
for (unsigned Op : {G_MERGE_VALUES, G_UNMERGE_VALUES}) {
@@ -235,7 +236,9 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
getActionDefinitionsBuilder(G_ICMP)
.legalFor({{sXLen, sXLen}, {sXLen, p0}})
- .widenScalarToNextPow2(1)
+ .legalIf(all(typeIsLegalBoolVec(0, BoolVecTys, ST),
+ typeIsLegalIntOrFPVec(1, IntOrFPVecTys, ST)))
+ .widenScalarOrEltToNextPow2OrMinSize(1, 8)
.clampScalar(1, sXLen, sXLen)
.clampScalar(0, sXLen, sXLen);
@@ -418,6 +421,29 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
.clampScalar(0, sXLen, sXLen)
.customFor({sXLen});
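+  // Splats of int/FP element vectors taking an sXLen scalar are legal
+  // directly; splats of i1 mask vectors are custom-lowered (see
+  // legalizeSplatVector).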
+ auto &SplatActions =
+ getActionDefinitionsBuilder(G_SPLAT_VECTOR)
+ .legalIf(all(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST),
+ typeIs(1, sXLen)))
+ .customIf(all(typeIsLegalBoolVec(0, BoolVecTys, ST), typeIs(1, s1)));
+ // Handle case of s64 element vectors on RV32. If the subtarget does not have
+ // f64, then try to lower it to G_SPLAT_VECTOR_SPLIT_I64_VL. If the subtarget
+ // does have f64, then we don't know whether the type is an f64 or an i64,
+ // so mark the G_SPLAT_VECTOR as legal and decide later what to do with it,
+ // depending on how the instructions it consumes are legalized. They are not
+ // legalized yet since legalization is in reverse postorder, so we cannot
+ // make the decision at this moment.
+ if (XLen == 32) {
+ if (ST.hasVInstructionsF64() && ST.hasStdExtD())
+ SplatActions.legalIf(all(
+ typeInSet(0, {nxv1s64, nxv2s64, nxv4s64, nxv8s64}), typeIs(1, s64)));
+ else if (ST.hasVInstructionsI64())
+ SplatActions.customIf(all(
+ typeInSet(0, {nxv1s64, nxv2s64, nxv4s64, nxv8s64}), typeIs(1, s64)));
+ }
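+  // For example, on RV32 with vector i64 support but no scalar f64, a
+  // G_SPLAT_VECTOR of nxv1s64 is custom-lowered by legalizeSplatVector into
+  // a G_SPLAT_VECTOR_SPLIT_I64_VL of the scalar's two s32 halves.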
+
+ SplatActions.clampScalar(1, sXLen, sXLen);
+
getLegacyLegalizerInfo().computeTables();
}
@@ -576,7 +602,145 @@ bool RISCVLegalizerInfo::legalizeVScale(MachineInstr &MI,
auto VScale = MIB.buildLShr(XLenTy, VLENB, MIB.buildConstant(XLenTy, 3));
MIB.buildMul(Dst, VScale, MIB.buildConstant(XLenTy, Val));
}
+ MI.eraseFromParent();
+ return true;
+}
+
+// Custom-lower extensions from mask vectors by using a vselect either with 1
+// for zero/any-extension or -1 for sign-extension:
+// (vXiN = (s|z)ext vXi1:vmask) -> (vXiN = vselect vmask, (-1 or 1), 0)
+// Note that any-extension is lowered identically to zero-extension.
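+// For example (illustrative MIR; the splat operands are G_CONSTANT results,
+// shown inline for brevity):
+//   %y:_(<vscale x 1 x s8>) = G_SEXT %m:_(<vscale x 1 x s1>)
+// is lowered to
+//   %y:_(<vscale x 1 x s8>) = G_SELECT %m, (G_SPLAT_VECTOR -1),
+//                                          (G_SPLAT_VECTOR 0)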
+bool RISCVLegalizerInfo::legalizeExt(MachineInstr &MI,
+ MachineIRBuilder &MIB) const {
+
+ unsigned Opc = MI.getOpcode();
+ assert(Opc == TargetOpcode::G_ZEXT || Opc == TargetOpcode::G_SEXT ||
+ Opc == TargetOpcode::G_ANYEXT);
+
+ MachineRegisterInfo &MRI = *MIB.getMRI();
+ Register Dst = MI.getOperand(0).getReg();
+ Register Src = MI.getOperand(1).getReg();
+
+ LLT DstTy = MRI.getType(Dst);
+ int64_t ExtTrueVal = Opc == TargetOpcode::G_SEXT ? -1 : 1;
+ LLT DstEltTy = DstTy.getElementType();
+ auto SplatZero = MIB.buildSplatVector(DstTy, MIB.buildConstant(DstEltTy, 0));
+ auto SplatTrue =
+ MIB.buildSplatVector(DstTy, MIB.buildConstant(DstEltTy, ExtTrueVal));
+ MIB.buildSelect(Dst, Src, SplatTrue, SplatZero);
+
+ MI.eraseFromParent();
+ return true;
+}
+
+/// Return the mask type suitable for masking the provided
+/// vector type. This is simply an i1 element type vector of the same
+/// (possibly scalable) length.
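+/// For example, <vscale x 4 x s32> maps to <vscale x 4 x s1>.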
+static LLT getMaskTypeFor(LLT VecTy) {
+ assert(VecTy.isVector());
+ ElementCount EC = VecTy.getElementCount();
+ return LLT::vector(EC, LLT::scalar(1));
+}
+
+/// Creates an all-ones mask suitable for masking a vector of type VecTy with
+/// vector length VL.
+static MachineInstrBuilder buildAllOnesMask(LLT VecTy, const SrcOp &VL,
+ MachineIRBuilder &MIB,
+ MachineRegisterInfo &MRI) {
+ LLT MaskTy = getMaskTypeFor(VecTy);
+ return MIB.buildInstr(RISCV::G_VMSET_VL, {MaskTy}, {VL});
+}
+
+/// Gets the two common "VL" operands: an all-ones mask and the vector length.
+/// VecTy is a scalable vector type.
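+/// Using X0 as the VL operand is the backend's convention for VLMAX, i.e.
+/// operate on all lanes of the scalable type.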
+static std::pair<MachineInstrBuilder, Register>
+buildDefaultVLOps(const DstOp &Dst, MachineIRBuilder &MIB,
+ MachineRegisterInfo &MRI) {
+ LLT VecTy = Dst.getLLTTy(MRI);
+ assert(VecTy.isScalableVector() && "Expecting scalable container type");
+ Register VL(RISCV::X0);
+ MachineInstrBuilder Mask = buildAllOnesMask(VecTy, VL, MIB, MRI);
+ return {Mask, VL};
+}
+
+static MachineInstrBuilder
+buildSplatPartsS64WithVL(const DstOp &Dst, const SrcOp &Passthru, Register Lo,
+ Register Hi, Register VL, MachineIRBuilder &MIB,
+ MachineRegisterInfo &MRI) {
+ // TODO: If the Hi bits of the splat are undefined, then it's fine to just
+ // splat Lo even if it might be sign extended. I don't think we have
+ // introduced a case where we build an s64 whose upper bits are undef
+ // yet.
+
+ // Fall back to a stack store and stride x0 vector load.
+ // TODO: need to lower G_SPLAT_VECTOR_SPLIT_I64_VL. This is done in
+ // preprocessDAG in SDAG.
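+  // A stride of x0 (i.e. zero) makes every element load from the same
+  // address, broadcasting the stored {Lo, Hi} pair to all lanes.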
+ return MIB.buildInstr(RISCV::G_SPLAT_VECTOR_SPLIT_I64_VL, {Dst},
+ {Passthru, Lo, Hi, VL});
+}
+
+static MachineInstrBuilder
+buildSplatSplitS64WithVL(const DstOp &Dst, const SrcOp &Passthru,
+ const SrcOp &Scalar, Register VL,
+ MachineIRBuilder &MIB, MachineRegisterInfo &MRI) {
+ assert(Scalar.getLLTTy(MRI) == LLT::scalar(64) && "Unexpected scalar type!");
+ auto Unmerge = MIB.buildUnmerge(LLT::scalar(32), Scalar);
+ return buildSplatPartsS64WithVL(Dst, Passthru, Unmerge.getReg(0),
+ Unmerge.getReg(1), VL, MIB, MRI);
+}
+
+// Lower splats of s1 types to G_ICMP. For each mask vector type, we have a
+// legal equivalently-sized i8 type, so we can use that as a go-between.
+// Splats of s1 types that have a constant value can be legalized as VMSET_VL
+// or VMCLR_VL.
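+// For a non-constant scalar %x, the promotion emits (in pseudo-MIR):
+//   icmp ne (splat_vector (and (zext %x), 1)), (splat_vector 0)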
+bool RISCVLegalizerInfo::legalizeSplatVector(MachineInstr &MI,
+ MachineIRBuilder &MIB) const {
+ assert(MI.getOpcode() == TargetOpcode::G_SPLAT_VECTOR);
+
+ MachineRegisterInfo &MRI = *MIB.getMRI();
+
+ Register Dst = MI.getOperand(0).getReg();
+ Register SplatVal = MI.getOperand(1).getReg();
+
+ LLT VecTy = MRI.getType(Dst);
+ LLT XLenTy(STI.getXLenVT());
+
+ // Handle case of s64 element vectors on RV32
+ if (XLenTy.getSizeInBits() == 32 &&
+ VecTy.getElementType().getSizeInBits() == 64) {
+ auto [_, VL] = buildDefaultVLOps(Dst, MIB, MRI);
+ buildSplatSplitS64WithVL(Dst, MIB.buildUndef(VecTy), SplatVal, VL, MIB,
+ MRI);
+ MI.eraseFromParent();
+ return true;
+ }
+
+ // All-zeros or all-ones splats are handled specially.
+ MachineInstr &SplatValMI = *MRI.getVRegDef(SplatVal);
+ if (isAllOnesOrAllOnesSplat(SplatValMI, MRI)) {
+ auto VL = buildDefaultVLOps(VecTy, MIB, MRI).second;
+ MIB.buildInstr(RISCV::G_VMSET_VL, {Dst}, {VL});
+ MI.eraseFromParent();
+ return true;
+ }
+ if (isNullOrNullSplat(SplatValMI, MRI)) {
+ auto VL = buildDefaultVLOps(VecTy, MIB, MRI).second;
+ MIB.buildInstr(RISCV::G_VMCLR_VL, {Dst}, {VL});
+ MI.eraseFromParent();
+ return true;
+ }
+ // Handle a non-constant mask splat (i.e. one not known to be all zeros or
+ // all ones) by promoting it to an s8 splat.
+ LLT InterEltTy = LLT::scalar(8);
+ LLT InterTy = VecTy.changeElementType(InterEltTy);
+ auto ZExtSplatVal = MIB.buildZExt(InterEltTy, SplatVal);
+ auto And =
+ MIB.buildAnd(InterEltTy, ZExtSplatVal, MIB.buildConstant(InterEltTy, 1));
+ auto LHS = MIB.buildSplatVector(InterTy, And);
+ auto ZeroSplat =
+ MIB.buildSplatVector(InterTy, MIB.buildConstant(InterEltTy, 0));
+ MIB.buildICmp(CmpInst::Predicate::ICMP_NE, Dst, LHS, ZeroSplat);
MI.eraseFromParent();
return true;
}
@@ -640,6 +804,12 @@ bool RISCVLegalizerInfo::legalizeCustom(
return legalizeVAStart(MI, MIRBuilder);
case TargetOpcode::G_VSCALE:
return legalizeVScale(MI, MIRBuilder);
+ case TargetOpcode::G_ZEXT:
+ case TargetOpcode::G_SEXT:
+ case TargetOpcode::G_ANYEXT:
+ return legalizeExt(MI, MIRBuilder);
+ case TargetOpcode::G_SPLAT_VECTOR:
+ return legalizeSplatVector(MI, MIRBuilder);
}
llvm_unreachable("expected switch to return");