From fc5c5a934d2560559221bcb334b14ef4aa96a2dd Mon Sep 17 00:00:00 2001 From: jyli0116 Date: Thu, 17 Jul 2025 14:43:58 +0100 Subject: [GlobalISel] Allow expansion of srem by constant in prelegalizer (#148845) This patch allows srem by a constant to be expanded more efficiently to avoid the need for expensive sdiv instructions. This is the last part of the patches which fixes #118090 --- llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp | 45 +++++++++++++++++--------- 1 file changed, 29 insertions(+), 16 deletions(-) (limited to 'llvm/lib/CodeGen') diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp index 3922eba..e8f513a 100644 --- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp +++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp @@ -5300,7 +5300,7 @@ bool CombinerHelper::matchSubAddSameReg(MachineInstr &MI, return false; } -MachineInstr *CombinerHelper::buildUDivorURemUsingMul(MachineInstr &MI) const { +MachineInstr *CombinerHelper::buildUDivOrURemUsingMul(MachineInstr &MI) const { unsigned Opcode = MI.getOpcode(); assert(Opcode == TargetOpcode::G_UDIV || Opcode == TargetOpcode::G_UREM); auto &UDivorRem = cast(MI); @@ -5468,7 +5468,7 @@ MachineInstr *CombinerHelper::buildUDivorURemUsingMul(MachineInstr &MI) const { return ret; } -bool CombinerHelper::matchUDivorURemByConst(MachineInstr &MI) const { +bool CombinerHelper::matchUDivOrURemByConst(MachineInstr &MI) const { unsigned Opcode = MI.getOpcode(); assert(Opcode == TargetOpcode::G_UDIV || Opcode == TargetOpcode::G_UREM); Register Dst = MI.getOperand(0).getReg(); @@ -5517,13 +5517,14 @@ bool CombinerHelper::matchUDivorURemByConst(MachineInstr &MI) const { MRI, RHS, [](const Constant *C) { return C && !C->isNullValue(); }); } -void CombinerHelper::applyUDivorURemByConst(MachineInstr &MI) const { - auto *NewMI = buildUDivorURemUsingMul(MI); +void CombinerHelper::applyUDivOrURemByConst(MachineInstr &MI) const { + auto *NewMI = buildUDivOrURemUsingMul(MI); replaceSingleDefInstWithReg(MI, NewMI->getOperand(0).getReg()); } -bool CombinerHelper::matchSDivByConst(MachineInstr &MI) const { - assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV"); +bool CombinerHelper::matchSDivOrSRemByConst(MachineInstr &MI) const { + unsigned Opcode = MI.getOpcode(); + assert(Opcode == TargetOpcode::G_SDIV || Opcode == TargetOpcode::G_SREM); Register Dst = MI.getOperand(0).getReg(); Register RHS = MI.getOperand(2).getReg(); LLT DstTy = MRI.getType(Dst); @@ -5543,7 +5544,8 @@ bool CombinerHelper::matchSDivByConst(MachineInstr &MI) const { return false; // If the sdiv has an 'exact' flag we can use a simpler lowering. - if (MI.getFlag(MachineInstr::MIFlag::IsExact)) { + if (Opcode == TargetOpcode::G_SDIV && + MI.getFlag(MachineInstr::MIFlag::IsExact)) { return matchUnaryPredicate( MRI, RHS, [](const Constant *C) { return C && !C->isNullValue(); }); } @@ -5559,23 +5561,28 @@ bool CombinerHelper::matchSDivByConst(MachineInstr &MI) const { if (!isLegal({TargetOpcode::G_SMULH, {DstTy}}) && !isLegalOrHasWidenScalar({TargetOpcode::G_MUL, {WideTy, WideTy}})) return false; + if (Opcode == TargetOpcode::G_SREM && + !isLegalOrBeforeLegalizer({TargetOpcode::G_SUB, {DstTy, DstTy}})) + return false; } return matchUnaryPredicate( MRI, RHS, [](const Constant *C) { return C && !C->isNullValue(); }); } -void CombinerHelper::applySDivByConst(MachineInstr &MI) const { - auto *NewMI = buildSDivUsingMul(MI); +void CombinerHelper::applySDivOrSRemByConst(MachineInstr &MI) const { + auto *NewMI = buildSDivOrSRemUsingMul(MI); replaceSingleDefInstWithReg(MI, NewMI->getOperand(0).getReg()); } -MachineInstr *CombinerHelper::buildSDivUsingMul(MachineInstr &MI) const { - assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV"); - auto &SDiv = cast(MI); - Register Dst = SDiv.getReg(0); - Register LHS = SDiv.getReg(1); - Register RHS = SDiv.getReg(2); +MachineInstr *CombinerHelper::buildSDivOrSRemUsingMul(MachineInstr &MI) const { + unsigned Opcode = MI.getOpcode(); + assert(MI.getOpcode() == TargetOpcode::G_SDIV || + Opcode == TargetOpcode::G_SREM); + auto &SDivorRem = cast(MI); + Register Dst = SDivorRem.getReg(0); + Register LHS = SDivorRem.getReg(1); + Register RHS = SDivorRem.getReg(2); LLT Ty = MRI.getType(Dst); LLT ScalarTy = Ty.getScalarType(); const unsigned EltBits = ScalarTy.getScalarSizeInBits(); @@ -5705,7 +5712,13 @@ MachineInstr *CombinerHelper::buildSDivUsingMul(MachineInstr &MI) const { auto SignShift = MIB.buildConstant(ShiftAmtTy, EltBits - 1); auto T = MIB.buildLShr(Ty, Q, SignShift); T = MIB.buildAnd(Ty, T, ShiftMask); - return MIB.buildAdd(Ty, Q, T); + auto ret = MIB.buildAdd(Ty, Q, T); + + if (Opcode == TargetOpcode::G_SREM) { + auto Prod = MIB.buildMul(Ty, ret, RHS); + return MIB.buildSub(Ty, LHS, Prod); + } + return ret; } bool CombinerHelper::matchDivByPow2(MachineInstr &MI, bool IsSigned) const { -- cgit v1.1