//===- CombinerHelperCasts.cpp---------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements CombinerHelper for G_ANYEXT, G_SEXT, G_TRUNC, and // G_ZEXT // //===----------------------------------------------------------------------===// #include "llvm/CodeGen/GlobalISel/CombinerHelper.h" #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h" #include "llvm/CodeGen/GlobalISel/LegalizerInfo.h" #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" #include "llvm/CodeGen/GlobalISel/Utils.h" #include "llvm/CodeGen/LowLevelTypeUtils.h" #include "llvm/CodeGen/MachineOperand.h" #include "llvm/CodeGen/MachineRegisterInfo.h" #include "llvm/CodeGen/TargetOpcodes.h" #include "llvm/Support/Casting.h" #define DEBUG_TYPE "gi-combiner" using namespace llvm; bool CombinerHelper::matchSextOfTrunc(const MachineOperand &MO, BuildFnTy &MatchInfo) const { GSext *Sext = cast(getDefIgnoringCopies(MO.getReg(), MRI)); GTrunc *Trunc = cast(getDefIgnoringCopies(Sext->getSrcReg(), MRI)); Register Dst = Sext->getReg(0); Register Src = Trunc->getSrcReg(); LLT DstTy = MRI.getType(Dst); LLT SrcTy = MRI.getType(Src); // Combines without nsw trunc. if (!Trunc->getFlag(MachineInstr::NoSWrap)) { if (DstTy != SrcTy || !isLegalOrBeforeLegalizer({TargetOpcode::G_SEXT_INREG, {DstTy, SrcTy}})) return false; // Do this for 8 bit values and up. We don't want to do it for e.g. G_TRUNC // to i1. unsigned TruncWidth = MRI.getType(Trunc->getReg(0)).getScalarSizeInBits(); if (TruncWidth < 8) return false; MatchInfo = [=](MachineIRBuilder &B) { B.buildSExtInReg(Dst, Src, TruncWidth); }; return true; } // Combines for nsw trunc. if (DstTy == SrcTy) { MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Dst, Src); }; return true; } if (DstTy.getScalarSizeInBits() < SrcTy.getScalarSizeInBits() && isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {DstTy, SrcTy}})) { MatchInfo = [=](MachineIRBuilder &B) { B.buildTrunc(Dst, Src, MachineInstr::MIFlag::NoSWrap); }; return true; } if (DstTy.getScalarSizeInBits() > SrcTy.getScalarSizeInBits() && isLegalOrBeforeLegalizer({TargetOpcode::G_SEXT, {DstTy, SrcTy}})) { MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Dst, Src); }; return true; } return false; } bool CombinerHelper::matchZextOfTrunc(const MachineOperand &MO, BuildFnTy &MatchInfo) const { GZext *Zext = cast(getDefIgnoringCopies(MO.getReg(), MRI)); GTrunc *Trunc = cast(getDefIgnoringCopies(Zext->getSrcReg(), MRI)); Register Dst = Zext->getReg(0); Register Src = Trunc->getSrcReg(); LLT DstTy = MRI.getType(Dst); LLT SrcTy = MRI.getType(Src); if (DstTy == SrcTy) { MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Dst, Src); }; return true; } if (DstTy.getScalarSizeInBits() < SrcTy.getScalarSizeInBits() && isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {DstTy, SrcTy}})) { MatchInfo = [=](MachineIRBuilder &B) { B.buildTrunc(Dst, Src, MachineInstr::MIFlag::NoUWrap); }; return true; } if (DstTy.getScalarSizeInBits() > SrcTy.getScalarSizeInBits() && isLegalOrBeforeLegalizer({TargetOpcode::G_ZEXT, {DstTy, SrcTy}})) { MatchInfo = [=](MachineIRBuilder &B) { B.buildZExt(Dst, Src, MachineInstr::MIFlag::NonNeg); }; return true; } return false; } bool CombinerHelper::matchNonNegZext(const MachineOperand &MO, BuildFnTy &MatchInfo) const { GZext *Zext = cast(MRI.getVRegDef(MO.getReg())); Register Dst = Zext->getReg(0); Register Src = Zext->getSrcReg(); LLT DstTy = MRI.getType(Dst); LLT SrcTy = MRI.getType(Src); const auto &TLI = getTargetLowering(); // Convert zext nneg to sext if sext is the preferred form for the target. if (isLegalOrBeforeLegalizer({TargetOpcode::G_SEXT, {DstTy, SrcTy}}) && TLI.isSExtCheaperThanZExt(getMVTForLLT(SrcTy), getMVTForLLT(DstTy))) { MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Dst, Src); }; return true; } return false; } bool CombinerHelper::matchTruncateOfExt(const MachineInstr &Root, const MachineInstr &ExtMI, BuildFnTy &MatchInfo) const { const GTrunc *Trunc = cast(&Root); const GExtOp *Ext = cast(&ExtMI); if (!MRI.hasOneNonDBGUse(Ext->getReg(0))) return false; Register Dst = Trunc->getReg(0); Register Src = Ext->getSrcReg(); LLT DstTy = MRI.getType(Dst); LLT SrcTy = MRI.getType(Src); if (SrcTy == DstTy) { // The source and the destination are equally sized. We need to copy. MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Dst, Src); }; return true; } if (SrcTy.getScalarSizeInBits() < DstTy.getScalarSizeInBits()) { // If the source is smaller than the destination, we need to extend. if (!isLegalOrBeforeLegalizer({Ext->getOpcode(), {DstTy, SrcTy}})) return false; MatchInfo = [=](MachineIRBuilder &B) { B.buildInstr(Ext->getOpcode(), {Dst}, {Src}); }; return true; } if (SrcTy.getScalarSizeInBits() > DstTy.getScalarSizeInBits()) { // If the source is larger than the destination, then we need to truncate. if (!isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {DstTy, SrcTy}})) return false; MatchInfo = [=](MachineIRBuilder &B) { B.buildTrunc(Dst, Src); }; return true; } return false; } bool CombinerHelper::isCastFree(unsigned Opcode, LLT ToTy, LLT FromTy) const { const TargetLowering &TLI = getTargetLowering(); LLVMContext &Ctx = getContext(); switch (Opcode) { case TargetOpcode::G_ANYEXT: case TargetOpcode::G_ZEXT: return TLI.isZExtFree(FromTy, ToTy, Ctx); case TargetOpcode::G_TRUNC: return TLI.isTruncateFree(FromTy, ToTy, Ctx); default: return false; } } bool CombinerHelper::matchCastOfSelect(const MachineInstr &CastMI, const MachineInstr &SelectMI, BuildFnTy &MatchInfo) const { const GExtOrTruncOp *Cast = cast(&CastMI); const GSelect *Select = cast(&SelectMI); if (!MRI.hasOneNonDBGUse(Select->getReg(0))) return false; Register Dst = Cast->getReg(0); LLT DstTy = MRI.getType(Dst); LLT CondTy = MRI.getType(Select->getCondReg()); Register TrueReg = Select->getTrueReg(); Register FalseReg = Select->getFalseReg(); LLT SrcTy = MRI.getType(TrueReg); Register Cond = Select->getCondReg(); if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SELECT, {DstTy, CondTy}})) return false; if (!isCastFree(Cast->getOpcode(), DstTy, SrcTy)) return false; MatchInfo = [=](MachineIRBuilder &B) { auto True = B.buildInstr(Cast->getOpcode(), {DstTy}, {TrueReg}); auto False = B.buildInstr(Cast->getOpcode(), {DstTy}, {FalseReg}); B.buildSelect(Dst, Cond, True, False); }; return true; } bool CombinerHelper::matchExtOfExt(const MachineInstr &FirstMI, const MachineInstr &SecondMI, BuildFnTy &MatchInfo) const { const GExtOp *First = cast(&FirstMI); const GExtOp *Second = cast(&SecondMI); Register Dst = First->getReg(0); Register Src = Second->getSrcReg(); LLT DstTy = MRI.getType(Dst); LLT SrcTy = MRI.getType(Src); if (!MRI.hasOneNonDBGUse(Second->getReg(0))) return false; // ext of ext -> later ext if (First->getOpcode() == Second->getOpcode() && isLegalOrBeforeLegalizer({Second->getOpcode(), {DstTy, SrcTy}})) { if (Second->getOpcode() == TargetOpcode::G_ZEXT) { MachineInstr::MIFlag Flag = MachineInstr::MIFlag::NoFlags; if (Second->getFlag(MachineInstr::MIFlag::NonNeg)) Flag = MachineInstr::MIFlag::NonNeg; MatchInfo = [=](MachineIRBuilder &B) { B.buildZExt(Dst, Src, Flag); }; return true; } // not zext -> no flags MatchInfo = [=](MachineIRBuilder &B) { B.buildInstr(Second->getOpcode(), {Dst}, {Src}); }; return true; } // anyext of sext/zext -> sext/zext // -> pick anyext as second ext, then ext of ext if (First->getOpcode() == TargetOpcode::G_ANYEXT && isLegalOrBeforeLegalizer({Second->getOpcode(), {DstTy, SrcTy}})) { if (Second->getOpcode() == TargetOpcode::G_ZEXT) { MachineInstr::MIFlag Flag = MachineInstr::MIFlag::NoFlags; if (Second->getFlag(MachineInstr::MIFlag::NonNeg)) Flag = MachineInstr::MIFlag::NonNeg; MatchInfo = [=](MachineIRBuilder &B) { B.buildZExt(Dst, Src, Flag); }; return true; } MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Dst, Src); }; return true; } // sext/zext of anyext -> sext/zext // -> pick anyext as first ext, then ext of ext if (Second->getOpcode() == TargetOpcode::G_ANYEXT && isLegalOrBeforeLegalizer({First->getOpcode(), {DstTy, SrcTy}})) { if (First->getOpcode() == TargetOpcode::G_ZEXT) { MachineInstr::MIFlag Flag = MachineInstr::MIFlag::NoFlags; if (First->getFlag(MachineInstr::MIFlag::NonNeg)) Flag = MachineInstr::MIFlag::NonNeg; MatchInfo = [=](MachineIRBuilder &B) { B.buildZExt(Dst, Src, Flag); }; return true; } MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Dst, Src); }; return true; } return false; } bool CombinerHelper::matchCastOfBuildVector(const MachineInstr &CastMI, const MachineInstr &BVMI, BuildFnTy &MatchInfo) const { const GExtOrTruncOp *Cast = cast(&CastMI); const GBuildVector *BV = cast(&BVMI); if (!MRI.hasOneNonDBGUse(BV->getReg(0))) return false; Register Dst = Cast->getReg(0); // The type of the new build vector. LLT DstTy = MRI.getType(Dst); // The scalar or element type of the new build vector. LLT ElemTy = DstTy.getScalarType(); // The scalar or element type of the old build vector. LLT InputElemTy = MRI.getType(BV->getReg(0)).getElementType(); // Check legality of new build vector, the scalar casts, and profitability of // the many casts. if (!isLegalOrBeforeLegalizer( {TargetOpcode::G_BUILD_VECTOR, {DstTy, ElemTy}}) || !isLegalOrBeforeLegalizer({Cast->getOpcode(), {ElemTy, InputElemTy}}) || !isCastFree(Cast->getOpcode(), ElemTy, InputElemTy)) return false; MatchInfo = [=](MachineIRBuilder &B) { SmallVector Casts; unsigned Elements = BV->getNumSources(); for (unsigned I = 0; I < Elements; ++I) { auto CastI = B.buildInstr(Cast->getOpcode(), {ElemTy}, {BV->getSourceReg(I)}); Casts.push_back(CastI.getReg(0)); } B.buildBuildVector(Dst, Casts); }; return true; } bool CombinerHelper::matchNarrowBinop(const MachineInstr &TruncMI, const MachineInstr &BinopMI, BuildFnTy &MatchInfo) const { const GTrunc *Trunc = cast(&TruncMI); const GBinOp *BinOp = cast(&BinopMI); if (!MRI.hasOneNonDBGUse(BinOp->getReg(0))) return false; Register Dst = Trunc->getReg(0); LLT DstTy = MRI.getType(Dst); // Is narrow binop legal? if (!isLegalOrBeforeLegalizer({BinOp->getOpcode(), {DstTy}})) return false; MatchInfo = [=](MachineIRBuilder &B) { auto LHS = B.buildTrunc(DstTy, BinOp->getLHSReg()); auto RHS = B.buildTrunc(DstTy, BinOp->getRHSReg()); B.buildInstr(BinOp->getOpcode(), {Dst}, {LHS, RHS}); }; return true; } bool CombinerHelper::matchCastOfInteger(const MachineInstr &CastMI, APInt &MatchInfo) const { const GExtOrTruncOp *Cast = cast(&CastMI); APInt Input = getIConstantFromReg(Cast->getSrcReg(), MRI); LLT DstTy = MRI.getType(Cast->getReg(0)); if (!isConstantLegalOrBeforeLegalizer(DstTy)) return false; switch (Cast->getOpcode()) { case TargetOpcode::G_TRUNC: { MatchInfo = Input.trunc(DstTy.getScalarSizeInBits()); return true; } default: return false; } } bool CombinerHelper::matchRedundantSextInReg(MachineInstr &Root, MachineInstr &Other, BuildFnTy &MatchInfo) const { assert(Root.getOpcode() == TargetOpcode::G_SEXT_INREG && Other.getOpcode() == TargetOpcode::G_SEXT_INREG); unsigned RootWidth = Root.getOperand(2).getImm(); unsigned OtherWidth = Other.getOperand(2).getImm(); Register Dst = Root.getOperand(0).getReg(); Register OtherDst = Other.getOperand(0).getReg(); Register Src = Other.getOperand(1).getReg(); if (RootWidth >= OtherWidth) { // The root sext_inreg is entirely redundant because the other one // is narrower. if (!canReplaceReg(Dst, OtherDst, MRI)) return false; MatchInfo = [=](MachineIRBuilder &B) { Observer.changingAllUsesOfReg(MRI, Dst); MRI.replaceRegWith(Dst, OtherDst); Observer.finishedChangingAllUsesOfReg(); }; } else { // RootWidth < OtherWidth, rewrite this G_SEXT_INREG with the source of the // other G_SEXT_INREG. MatchInfo = [=](MachineIRBuilder &B) { B.buildSExtInReg(Dst, Src, RootWidth); }; } return true; }