Diffstat (limited to 'llvm/lib/Target/AArch64')
-rw-r--r--  llvm/lib/Target/AArch64/AArch64Combine.td                    |   1
-rw-r--r--  llvm/lib/Target/AArch64/AArch64ISelLowering.cpp              | 111
-rw-r--r--  llvm/lib/Target/AArch64/AArch64ISelLowering.h                |   1
-rw-r--r--  llvm/lib/Target/AArch64/AArch64InstrFormats.td               |   2
-rw-r--r--  llvm/lib/Target/AArch64/AArch64InstrInfo.cpp                 |  10
-rw-r--r--  llvm/lib/Target/AArch64/AArch64InstrInfo.td                  | 379
-rw-r--r--  llvm/lib/Target/AArch64/AArch64SchedNeoverseV2.td            |   6
-rw-r--r--  llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp       |  64
-rw-r--r--  llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp |  10
-rw-r--r--  llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp       |  16
-rw-r--r--  llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp    |  18
-rw-r--r--  llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp       |  21
-rw-r--r--  llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h         |  19
13 files changed, 444 insertions, 214 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64Combine.td b/llvm/lib/Target/AArch64/AArch64Combine.td index ecaeff7..b3ec65c 100644 --- a/llvm/lib/Target/AArch64/AArch64Combine.td +++ b/llvm/lib/Target/AArch64/AArch64Combine.td @@ -71,7 +71,6 @@ def AArch64PreLegalizerCombiner: GICombiner< "AArch64PreLegalizerCombinerImpl", [all_combines, icmp_redundant_trunc, fold_global_offset, - shuffle_to_extract, ext_addv_to_udot_addv, ext_uaddv_to_uaddlv, push_sub_through_zext, diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index a81de5c..d16b116 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -9002,12 +9002,12 @@ static void analyzeCallOperands(const AArch64TargetLowering &TLI, } static SMECallAttrs -getSMECallAttrs(const Function &Caller, const AArch64TargetLowering &TLI, +getSMECallAttrs(const Function &Caller, const RTLIB::RuntimeLibcallsInfo &RTLCI, const TargetLowering::CallLoweringInfo &CLI) { if (CLI.CB) - return SMECallAttrs(*CLI.CB, &TLI); + return SMECallAttrs(*CLI.CB, &RTLCI); if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee)) - return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(ES->getSymbol(), TLI)); + return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(ES->getSymbol(), RTLCI)); return SMECallAttrs(SMEAttrs(Caller), SMEAttrs(SMEAttrs::Normal)); } @@ -9029,7 +9029,8 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization( // SME Streaming functions are not eligible for TCO as they may require // the streaming mode or ZA to be restored after returning from the call. - SMECallAttrs CallAttrs = getSMECallAttrs(CallerF, *this, CLI); + SMECallAttrs CallAttrs = + getSMECallAttrs(CallerF, getRuntimeLibcallsInfo(), CLI); if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() || CallAttrs.requiresPreservingAllZAState() || CallAttrs.caller().hasStreamingBody()) @@ -9454,7 +9455,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, } // Determine whether we need any streaming mode changes. - SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), *this, CLI); + SMECallAttrs CallAttrs = + getSMECallAttrs(MF.getFunction(), getRuntimeLibcallsInfo(), CLI); std::optional<unsigned> ZAMarkerNode; bool UseNewSMEABILowering = getTM().useNewSMEABILowering(); @@ -19476,6 +19478,61 @@ static SDValue performMulVectorExtendCombine(SDNode *Mul, SelectionDAG &DAG) { Op1 ? Op1 : Mul->getOperand(1)); } +// Multiplying an RDSVL value by a constant can sometimes be done cheaper by +// folding a power-of-two factor of the constant into the RDSVL immediate and +// compensating with an extra shift. +// +// We rewrite: +// (mul (srl (rdsvl 1), w), x) +// to one of: +// (shl (rdsvl y), z) if z > 0 +// (srl (rdsvl y), abs(z)) if z < 0 +// where integers y, z satisfy x = y * 2^(w + z) and y ∈ [-32, 31]. 
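(Editorial aside, not part of the patch: a quick standalone check of the bound arithmetic described in the comment above. X and W below are stand-ins for the combine's ConstMultiplier and OperandShift; plain C++20, runnable on its own.)

#include <algorithm>
#include <bit>
#include <cassert>
#include <cstdlib>

int main() {
  const int X = 320;     // constant multiplier x
  const unsigned W = 0;  // pre-existing srl amount w applied to (rdsvl 1)
  const unsigned AbsX = std::abs(X);

  // z <= ctz(|x|) - w keeps y integral.
  const int UpperBound = std::countr_zero(AbsX) - static_cast<int>(W);
  // z >= ceil_log2(ceil(|x| / B)) - w keeps y within the RDSVL immediate range.
  const unsigned B = X < 0 ? 32 : 31;
  const unsigned CeilXOverB = (AbsX + B - 1) / B;
  const int LowerBound =
      static_cast<int>(std::bit_width(CeilXOverB - 1)) - static_cast<int>(W);
  assert(LowerBound <= UpperBound && "no valid (y, z) pair");

  // Prefer no extra shift when the interval allows it.
  const int Shift = std::min(std::max(0, LowerBound), UpperBound);
  const int RdsvlMul =
      static_cast<int>(AbsX >> (W + Shift)) * (X < 0 ? -1 : 1);

  // For X = 320, W = 0 this yields (shl (rdsvl 20), 4).
  assert(Shift == 4 && RdsvlMul == 20);
  assert(RdsvlMul >= -32 && RdsvlMul <= 31);
  assert((RdsvlMul << (W + Shift)) == X);
  return 0;
}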
+static SDValue performMulRdsvlCombine(SDNode *Mul, SelectionDAG &DAG) { + SDLoc DL(Mul); + EVT VT = Mul->getValueType(0); + SDValue MulOp0 = Mul->getOperand(0); + int ConstMultiplier = + cast<ConstantSDNode>(Mul->getOperand(1))->getSExtValue(); + if ((MulOp0->getOpcode() != ISD::SRL) || + (MulOp0->getOperand(0).getOpcode() != AArch64ISD::RDSVL)) + return SDValue(); + + unsigned AbsConstValue = abs(ConstMultiplier); + unsigned OperandShift = + cast<ConstantSDNode>(MulOp0->getOperand(1))->getZExtValue(); + + // z ≤ ctz(|x|) - w (largest extra shift we can take while keeping y + // integral) + int UpperBound = llvm::countr_zero(AbsConstValue) - OperandShift; + + // To keep y in range, with B = 31 for x > 0 and B = 32 for x < 0, we need: + // 2^(w + z) ≥ ceil(x / B) ⇒ z ≥ ceil_log2(ceil(x / B)) - w (LowerBound). + unsigned B = ConstMultiplier < 0 ? 32 : 31; + unsigned CeilAxOverB = (AbsConstValue + (B - 1)) / B; // ceil(|x|/B) + int LowerBound = llvm::Log2_32_Ceil(CeilAxOverB) - OperandShift; + + // No valid solution found. + if (LowerBound > UpperBound) + return SDValue(); + + // Any value of z in [LowerBound, UpperBound] is valid. Prefer no extra + // shift if possible. + int Shift = std::min(std::max(/*prefer*/ 0, LowerBound), UpperBound); + + // y = x / 2^(w + z) + int32_t RdsvlMul = (AbsConstValue >> (OperandShift + Shift)) * + (ConstMultiplier < 0 ? -1 : 1); + auto Rdsvl = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64, + DAG.getSignedConstant(RdsvlMul, DL, MVT::i32)); + + if (Shift == 0) + return Rdsvl; + return DAG.getNode(Shift < 0 ? ISD::SRL : ISD::SHL, DL, VT, Rdsvl, + DAG.getConstant(abs(Shift), DL, MVT::i32), + SDNodeFlags::Exact); +} + // Combine v4i32 Mul(And(Srl(X, 15), 0x10001), 0xffff) -> v8i16 CMLTz // Same for other types with equivalent constants. static SDValue performMulVectorCmpZeroCombine(SDNode *N, SelectionDAG &DAG) { @@ -19604,6 +19661,9 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG, if (!isa<ConstantSDNode>(N1)) return SDValue(); + if (SDValue Ext = performMulRdsvlCombine(N, DAG)) + return Ext; + ConstantSDNode *C = cast<ConstantSDNode>(N1); const APInt &ConstValue = C->getAPIntValue(); @@ -26665,11 +26725,34 @@ static SDValue performDUPCombine(SDNode *N, } if (N->getOpcode() == AArch64ISD::DUP) { + SDValue Op = N->getOperand(0); + + // Optimize DUP(extload/zextload i8/i16/i32) to avoid GPR->FPR transfer. + // For example: + // v4i32 = DUP (i32 (zextloadi8 addr)) + // => + // v4i32 = SCALAR_TO_VECTOR (i32 (zextloadi8 addr)) ; Matches to ldr b0 + // v4i32 = DUPLANE32 (v4i32), 0 + if (auto *LD = dyn_cast<LoadSDNode>(Op)) { + ISD::LoadExtType ExtType = LD->getExtensionType(); + EVT MemVT = LD->getMemoryVT(); + EVT ElemVT = VT.getVectorElementType(); + if ((ExtType == ISD::EXTLOAD || ExtType == ISD::ZEXTLOAD) && + (MemVT == MVT::i8 || MemVT == MVT::i16 || MemVT == MVT::i32) && + ElemVT != MemVT && LD->hasOneUse()) { + EVT Vec128VT = EVT::getVectorVT(*DCI.DAG.getContext(), ElemVT, + 128 / ElemVT.getSizeInBits()); + SDValue ScalarToVec = + DCI.DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, Vec128VT, Op); + return DCI.DAG.getNode(getDUPLANEOp(ElemVT), DL, VT, ScalarToVec, + DCI.DAG.getConstant(0, DL, MVT::i64)); + } + } + // If the instruction is known to produce a scalar in SIMD registers, we can // duplicate it across the vector lanes using DUPLANE instead of moving it // to a GPR first. 
For example, this allows us to handle: // v4i32 = DUP (i32 (FCMGT (f32, f32))) - SDValue Op = N->getOperand(0); // FIXME: Ideally, we should be able to handle all instructions that // produce a scalar value in FPRs. if (Op.getOpcode() == AArch64ISD::FCMEQ || @@ -29430,15 +29513,6 @@ void AArch64TargetLowering::insertSSPDeclarations(Module &M) const { TargetLowering::insertSSPDeclarations(M); } -Function *AArch64TargetLowering::getSSPStackGuardCheck(const Module &M) const { - // MSVC CRT has a function to validate security cookie. - RTLIB::LibcallImpl SecurityCheckCookieLibcall = - getLibcallImpl(RTLIB::SECURITY_CHECK_COOKIE); - if (SecurityCheckCookieLibcall != RTLIB::Unsupported) - return M.getFunction(getLibcallImplName(SecurityCheckCookieLibcall)); - return TargetLowering::getSSPStackGuardCheck(M); -} - Value * AArch64TargetLowering::getSafeStackPointerLocation(IRBuilderBase &IRB) const { // Android provides a fixed TLS slot for the SafeStack pointer. See the @@ -29447,11 +29521,6 @@ AArch64TargetLowering::getSafeStackPointerLocation(IRBuilderBase &IRB) const { if (Subtarget->isTargetAndroid()) return UseTlsOffset(IRB, 0x48); - // Fuchsia is similar. - // <zircon/tls.h> defines ZX_TLS_UNSAFE_SP_OFFSET with this value. - if (Subtarget->isTargetFuchsia()) - return UseTlsOffset(IRB, -0x8); - return TargetLowering::getSafeStackPointerLocation(IRB); } @@ -29769,7 +29838,7 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const { // Checks to allow the use of SME instructions if (auto *Base = dyn_cast<CallBase>(&Inst)) { - auto CallAttrs = SMECallAttrs(*Base, this); + auto CallAttrs = SMECallAttrs(*Base, &getRuntimeLibcallsInfo()); if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() || CallAttrs.requiresPreservingZT0() || CallAttrs.requiresPreservingAllZAState()) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h index 9495c9f..2cb8ed2 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -366,7 +366,6 @@ public: Value *getIRStackGuard(IRBuilderBase &IRB) const override; void insertSSPDeclarations(Module &M) const override; - Function *getSSPStackGuardCheck(const Module &M) const override; /// If the target has a standard location for the unsafe stack pointer, /// returns the address of that location. Otherwise, returns nullptr. 
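(Editorial aside on the DUP-of-extending-load combine above: a source-level pattern that produces the v4i32 = DUP (i32 (zextloadi8 addr)) shape is a splat of a zero-extended byte load. A minimal NEON sketch, compiled as C++; the function name is hypothetical and not taken from the patch's tests.)

#include <arm_neon.h>
#include <cstdint>

// Splat a zero-extended byte across a v4i32. With the combine above this can
// select to an "ldr b0"-style FPR load plus DUPLANE32, avoiding a GPR load
// followed by a GPR->FPR transfer.
uint32x4_t splat_loaded_byte(const uint8_t *p) {
  return vdupq_n_u32(*p);
}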
diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td index eab1627..58a53af 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td +++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td @@ -5298,7 +5298,7 @@ multiclass FPToIntegerUnscaled<bits<2> rmode, bits<3> opcode, string asm, } multiclass FPToIntegerSIMDScalar<bits<2> rmode, bits<3> opcode, string asm, - SDPatternOperator OpN = null_frag> { + SDPatternOperator OpN> { // double-precision to 32-bit SIMD/FPR def SDr : BaseFPToIntegerUnscaled<0b01, rmode, opcode, FPR64, FPR32, asm, [(set FPR32:$Rd, (i32 (OpN (f64 FPR64:$Rn))))]> { diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp index d5117da..457e540 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp @@ -5151,7 +5151,15 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB, // GPR32 zeroing if (AArch64::GPR32spRegClass.contains(DestReg) && SrcReg == AArch64::WZR) { - if (Subtarget.hasZeroCycleZeroingGPR32()) { + if (Subtarget.hasZeroCycleZeroingGPR64() && + !Subtarget.hasZeroCycleZeroingGPR32()) { + MCRegister DestRegX = RI.getMatchingSuperReg(DestReg, AArch64::sub_32, + &AArch64::GPR64spRegClass); + assert(DestRegX.isValid() && "Destination super-reg not valid"); + BuildMI(MBB, I, DL, get(AArch64::MOVZXi), DestRegX) + .addImm(0) + .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0)); + } else if (Subtarget.hasZeroCycleZeroingGPR32()) { BuildMI(MBB, I, DL, get(AArch64::MOVZWi), DestReg) .addImm(0) .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0)); diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td index b74ca79..b9e299e 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td @@ -4022,22 +4022,6 @@ defm LDRSW : LoadUI<0b10, 0, 0b10, GPR64, uimm12s4, "ldrsw", def : Pat<(i64 (zextloadi32 (am_indexed32 GPR64sp:$Rn, uimm12s4:$offset))), (SUBREG_TO_REG (i64 0), (LDRWui GPR64sp:$Rn, uimm12s4:$offset), sub_32)>; -// load zero-extended i32, bitcast to f64 -def : Pat<(f64 (bitconvert (i64 (zextloadi32 (am_indexed32 GPR64sp:$Rn, uimm12s4:$offset))))), - (SUBREG_TO_REG (i64 0), (LDRSui GPR64sp:$Rn, uimm12s4:$offset), ssub)>; -// load zero-extended i16, bitcast to f64 -def : Pat<(f64 (bitconvert (i64 (zextloadi16 (am_indexed16 GPR64sp:$Rn, uimm12s2:$offset))))), - (SUBREG_TO_REG (i64 0), (LDRHui GPR64sp:$Rn, uimm12s2:$offset), hsub)>; -// load zero-extended i8, bitcast to f64 -def : Pat<(f64 (bitconvert (i64 (zextloadi8 (am_indexed8 GPR64sp:$Rn, uimm12s1:$offset))))), - (SUBREG_TO_REG (i64 0), (LDRBui GPR64sp:$Rn, uimm12s1:$offset), bsub)>; -// load zero-extended i16, bitcast to f32 -def : Pat<(f32 (bitconvert (i32 (zextloadi16 (am_indexed16 GPR64sp:$Rn, uimm12s2:$offset))))), - (SUBREG_TO_REG (i32 0), (LDRHui GPR64sp:$Rn, uimm12s2:$offset), hsub)>; -// load zero-extended i8, bitcast to f32 -def : Pat<(f32 (bitconvert (i32 (zextloadi8 (am_indexed8 GPR64sp:$Rn, uimm12s1:$offset))))), - (SUBREG_TO_REG (i32 0), (LDRBui GPR64sp:$Rn, uimm12s1:$offset), bsub)>; - // Pre-fetch. def PRFMui : PrefetchUI<0b11, 0, 0b10, "prfm", [(AArch64Prefetch timm:$Rt, @@ -4389,6 +4373,64 @@ def : Pat <(v1i64 (scalar_to_vector (i64 (load (ro64.Xpat GPR64sp:$Rn, GPR64:$Rm, ro64.Xext:$extend))))), (LDRDroX GPR64sp:$Rn, GPR64:$Rm, ro64.Xext:$extend)>; +// Patterns for bitconvert or scalar_to_vector of load operations. 
+// Enables direct SIMD register loads for small integer types (i8/i16) that are +// naturally zero-extended to i32/i64. +multiclass ExtLoad8_16AllModes<ValueType OutTy, ValueType InnerTy, + SDPatternOperator OuterOp, + PatFrags LoadOp8, PatFrags LoadOp16> { + // 8-bit loads. + def : Pat<(OutTy (OuterOp (InnerTy (LoadOp8 (am_indexed8 GPR64sp:$Rn, uimm12s1:$offset))))), + (SUBREG_TO_REG (i64 0), (LDRBui GPR64sp:$Rn, uimm12s1:$offset), bsub)>; + def : Pat<(OutTy (OuterOp (InnerTy (LoadOp8 (am_unscaled8 GPR64sp:$Rn, simm9:$offset))))), + (SUBREG_TO_REG (i64 0), (LDURBi GPR64sp:$Rn, simm9:$offset), bsub)>; + def : Pat<(OutTy (OuterOp (InnerTy (LoadOp8 (ro8.Wpat GPR64sp:$Rn, GPR32:$Rm, ro8.Wext:$extend))))), + (SUBREG_TO_REG (i64 0), (LDRBroW GPR64sp:$Rn, GPR32:$Rm, ro8.Wext:$extend), bsub)>; + def : Pat<(OutTy (OuterOp (InnerTy (LoadOp8 (ro8.Xpat GPR64sp:$Rn, GPR64:$Rm, ro8.Xext:$extend))))), + (SUBREG_TO_REG (i64 0), (LDRBroX GPR64sp:$Rn, GPR64:$Rm, ro8.Xext:$extend), bsub)>; + + // 16-bit loads. + def : Pat<(OutTy (OuterOp (InnerTy (LoadOp16 (am_indexed16 GPR64sp:$Rn, uimm12s2:$offset))))), + (SUBREG_TO_REG (i64 0), (LDRHui GPR64sp:$Rn, uimm12s2:$offset), hsub)>; + def : Pat<(OutTy (OuterOp (InnerTy (LoadOp16 (am_unscaled16 GPR64sp:$Rn, simm9:$offset))))), + (SUBREG_TO_REG (i64 0), (LDURHi GPR64sp:$Rn, simm9:$offset), hsub)>; + def : Pat<(OutTy (OuterOp (InnerTy (LoadOp16 (ro16.Wpat GPR64sp:$Rn, GPR32:$Rm, ro16.Wext:$extend))))), + (SUBREG_TO_REG (i64 0), (LDRHroW GPR64sp:$Rn, GPR32:$Rm, ro16.Wext:$extend), hsub)>; + def : Pat<(OutTy (OuterOp (InnerTy (LoadOp16 (ro16.Xpat GPR64sp:$Rn, GPR64:$Rm, ro16.Xext:$extend))))), + (SUBREG_TO_REG (i64 0), (LDRHroX GPR64sp:$Rn, GPR64:$Rm, ro16.Xext:$extend), hsub)>; +} + +// Extended multiclass that includes 32-bit loads in addition to 8-bit and 16-bit. +multiclass ExtLoad8_16_32AllModes<ValueType OutTy, ValueType InnerTy, + SDPatternOperator OuterOp, + PatFrags LoadOp8, PatFrags LoadOp16, PatFrags LoadOp32> { + defm : ExtLoad8_16AllModes<OutTy, InnerTy, OuterOp, LoadOp8, LoadOp16>; + + // 32-bit loads. + def : Pat<(OutTy (OuterOp (InnerTy (LoadOp32 (am_indexed32 GPR64sp:$Rn, uimm12s4:$offset))))), + (SUBREG_TO_REG (i64 0), (LDRSui GPR64sp:$Rn, uimm12s4:$offset), ssub)>; + def : Pat<(OutTy (OuterOp (InnerTy (LoadOp32 (am_unscaled32 GPR64sp:$Rn, simm9:$offset))))), + (SUBREG_TO_REG (i64 0), (LDURSi GPR64sp:$Rn, simm9:$offset), ssub)>; + def : Pat<(OutTy (OuterOp (InnerTy (LoadOp32 (ro32.Wpat GPR64sp:$Rn, GPR32:$Rm, ro32.Wext:$extend))))), + (SUBREG_TO_REG (i64 0), (LDRSroW GPR64sp:$Rn, GPR32:$Rm, ro32.Wext:$extend), ssub)>; + def : Pat<(OutTy (OuterOp (InnerTy (LoadOp32 (ro32.Xpat GPR64sp:$Rn, GPR64:$Rm, ro32.Xext:$extend))))), + (SUBREG_TO_REG (i64 0), (LDRSroX GPR64sp:$Rn, GPR64:$Rm, ro32.Xext:$extend), ssub)>; +} + +// Instantiate bitconvert patterns for floating-point types. +defm : ExtLoad8_16AllModes<f32, i32, bitconvert, zextloadi8, zextloadi16>; +defm : ExtLoad8_16_32AllModes<f64, i64, bitconvert, zextloadi8, zextloadi16, zextloadi32>; + +// Instantiate scalar_to_vector patterns for all vector types. 
+defm : ExtLoad8_16AllModes<v16i8, i32, scalar_to_vector, zextloadi8, zextloadi16>; +defm : ExtLoad8_16AllModes<v16i8, i32, scalar_to_vector, extloadi8, extloadi16>; +defm : ExtLoad8_16AllModes<v8i16, i32, scalar_to_vector, zextloadi8, zextloadi16>; +defm : ExtLoad8_16AllModes<v8i16, i32, scalar_to_vector, extloadi8, extloadi16>; +defm : ExtLoad8_16AllModes<v4i32, i32, scalar_to_vector, zextloadi8, zextloadi16>; +defm : ExtLoad8_16AllModes<v4i32, i32, scalar_to_vector, extloadi8, extloadi16>; +defm : ExtLoad8_16_32AllModes<v2i64, i64, scalar_to_vector, zextloadi8, zextloadi16, zextloadi32>; +defm : ExtLoad8_16_32AllModes<v2i64, i64, scalar_to_vector, extloadi8, extloadi16, extloadi32>; + // Pre-fetch. defm PRFUM : PrefetchUnscaled<0b11, 0, 0b10, "prfum", [(AArch64Prefetch timm:$Rt, @@ -5253,113 +5295,10 @@ let Predicates = [HasNEON, HasFPRCVT] in{ defm FCVTNU : FPToIntegerSIMDScalar<0b01, 0b011, "fcvtnu", int_aarch64_neon_fcvtnu>; defm FCVTPS : FPToIntegerSIMDScalar<0b10, 0b010, "fcvtps", int_aarch64_neon_fcvtps>; defm FCVTPU : FPToIntegerSIMDScalar<0b10, 0b011, "fcvtpu", int_aarch64_neon_fcvtpu>; - defm FCVTZS : FPToIntegerSIMDScalar<0b10, 0b110, "fcvtzs">; - defm FCVTZU : FPToIntegerSIMDScalar<0b10, 0b111, "fcvtzu">; -} - - -// AArch64's FCVT instructions saturate when out of range. -multiclass FPToIntegerSatPats<SDNode to_int_sat, SDNode to_int_sat_gi, string INST> { - let Predicates = [HasFullFP16] in { - def : Pat<(i32 (to_int_sat f16:$Rn, i32)), - (!cast<Instruction>(INST # UWHr) f16:$Rn)>; - def : Pat<(i64 (to_int_sat f16:$Rn, i64)), - (!cast<Instruction>(INST # UXHr) f16:$Rn)>; - } - def : Pat<(i32 (to_int_sat f32:$Rn, i32)), - (!cast<Instruction>(INST # UWSr) f32:$Rn)>; - def : Pat<(i64 (to_int_sat f32:$Rn, i64)), - (!cast<Instruction>(INST # UXSr) f32:$Rn)>; - def : Pat<(i32 (to_int_sat f64:$Rn, i32)), - (!cast<Instruction>(INST # UWDr) f64:$Rn)>; - def : Pat<(i64 (to_int_sat f64:$Rn, i64)), - (!cast<Instruction>(INST # UXDr) f64:$Rn)>; - - let Predicates = [HasFullFP16] in { - def : Pat<(i32 (to_int_sat_gi f16:$Rn)), - (!cast<Instruction>(INST # UWHr) f16:$Rn)>; - def : Pat<(i64 (to_int_sat_gi f16:$Rn)), - (!cast<Instruction>(INST # UXHr) f16:$Rn)>; - } - def : Pat<(i32 (to_int_sat_gi f32:$Rn)), - (!cast<Instruction>(INST # UWSr) f32:$Rn)>; - def : Pat<(i64 (to_int_sat_gi f32:$Rn)), - (!cast<Instruction>(INST # UXSr) f32:$Rn)>; - def : Pat<(i32 (to_int_sat_gi f64:$Rn)), - (!cast<Instruction>(INST # UWDr) f64:$Rn)>; - def : Pat<(i64 (to_int_sat_gi f64:$Rn)), - (!cast<Instruction>(INST # UXDr) f64:$Rn)>; - - let Predicates = [HasFullFP16] in { - def : Pat<(i32 (to_int_sat (fmul f16:$Rn, fixedpoint_f16_i32:$scale), i32)), - (!cast<Instruction>(INST # SWHri) $Rn, $scale)>; - def : Pat<(i64 (to_int_sat (fmul f16:$Rn, fixedpoint_f16_i64:$scale), i64)), - (!cast<Instruction>(INST # SXHri) $Rn, $scale)>; - } - def : Pat<(i32 (to_int_sat (fmul f32:$Rn, fixedpoint_f32_i32:$scale), i32)), - (!cast<Instruction>(INST # SWSri) $Rn, $scale)>; - def : Pat<(i64 (to_int_sat (fmul f32:$Rn, fixedpoint_f32_i64:$scale), i64)), - (!cast<Instruction>(INST # SXSri) $Rn, $scale)>; - def : Pat<(i32 (to_int_sat (fmul f64:$Rn, fixedpoint_f64_i32:$scale), i32)), - (!cast<Instruction>(INST # SWDri) $Rn, $scale)>; - def : Pat<(i64 (to_int_sat (fmul f64:$Rn, fixedpoint_f64_i64:$scale), i64)), - (!cast<Instruction>(INST # SXDri) $Rn, $scale)>; - - let Predicates = [HasFullFP16] in { - def : Pat<(i32 (to_int_sat_gi (fmul f16:$Rn, fixedpoint_f16_i32:$scale))), - (!cast<Instruction>(INST # SWHri) $Rn, $scale)>; - 
def : Pat<(i64 (to_int_sat_gi (fmul f16:$Rn, fixedpoint_f16_i64:$scale))), - (!cast<Instruction>(INST # SXHri) $Rn, $scale)>; - } - def : Pat<(i32 (to_int_sat_gi (fmul f32:$Rn, fixedpoint_f32_i32:$scale))), - (!cast<Instruction>(INST # SWSri) $Rn, $scale)>; - def : Pat<(i64 (to_int_sat_gi (fmul f32:$Rn, fixedpoint_f32_i64:$scale))), - (!cast<Instruction>(INST # SXSri) $Rn, $scale)>; - def : Pat<(i32 (to_int_sat_gi (fmul f64:$Rn, fixedpoint_f64_i32:$scale))), - (!cast<Instruction>(INST # SWDri) $Rn, $scale)>; - def : Pat<(i64 (to_int_sat_gi (fmul f64:$Rn, fixedpoint_f64_i64:$scale))), - (!cast<Instruction>(INST # SXDri) $Rn, $scale)>; -} - -defm : FPToIntegerSatPats<fp_to_sint_sat, fp_to_sint_sat_gi, "FCVTZS">; -defm : FPToIntegerSatPats<fp_to_uint_sat, fp_to_uint_sat_gi, "FCVTZU">; - -multiclass FPToIntegerPats<SDNode to_int, SDNode to_int_sat, SDNode round, string INST> { - def : Pat<(i32 (to_int (round f32:$Rn))), - (!cast<Instruction>(INST # UWSr) f32:$Rn)>; - def : Pat<(i64 (to_int (round f32:$Rn))), - (!cast<Instruction>(INST # UXSr) f32:$Rn)>; - def : Pat<(i32 (to_int (round f64:$Rn))), - (!cast<Instruction>(INST # UWDr) f64:$Rn)>; - def : Pat<(i64 (to_int (round f64:$Rn))), - (!cast<Instruction>(INST # UXDr) f64:$Rn)>; - - // These instructions saturate like fp_to_[su]int_sat. - let Predicates = [HasFullFP16] in { - def : Pat<(i32 (to_int_sat (round f16:$Rn), i32)), - (!cast<Instruction>(INST # UWHr) f16:$Rn)>; - def : Pat<(i64 (to_int_sat (round f16:$Rn), i64)), - (!cast<Instruction>(INST # UXHr) f16:$Rn)>; - } - def : Pat<(i32 (to_int_sat (round f32:$Rn), i32)), - (!cast<Instruction>(INST # UWSr) f32:$Rn)>; - def : Pat<(i64 (to_int_sat (round f32:$Rn), i64)), - (!cast<Instruction>(INST # UXSr) f32:$Rn)>; - def : Pat<(i32 (to_int_sat (round f64:$Rn), i32)), - (!cast<Instruction>(INST # UWDr) f64:$Rn)>; - def : Pat<(i64 (to_int_sat (round f64:$Rn), i64)), - (!cast<Instruction>(INST # UXDr) f64:$Rn)>; + defm FCVTZS : FPToIntegerSIMDScalar<0b10, 0b110, "fcvtzs", any_fp_to_sint>; + defm FCVTZU : FPToIntegerSIMDScalar<0b10, 0b111, "fcvtzu", any_fp_to_uint>; } -defm : FPToIntegerPats<fp_to_sint, fp_to_sint_sat, fceil, "FCVTPS">; -defm : FPToIntegerPats<fp_to_uint, fp_to_uint_sat, fceil, "FCVTPU">; -defm : FPToIntegerPats<fp_to_sint, fp_to_sint_sat, ffloor, "FCVTMS">; -defm : FPToIntegerPats<fp_to_uint, fp_to_uint_sat, ffloor, "FCVTMU">; -defm : FPToIntegerPats<fp_to_sint, fp_to_sint_sat, ftrunc, "FCVTZS">; -defm : FPToIntegerPats<fp_to_uint, fp_to_uint_sat, ftrunc, "FCVTZU">; -defm : FPToIntegerPats<fp_to_sint, fp_to_sint_sat, fround, "FCVTAS">; -defm : FPToIntegerPats<fp_to_uint, fp_to_uint_sat, fround, "FCVTAU">; - let Predicates = [HasFullFP16] in { @@ -6567,8 +6506,8 @@ defm FCVTNU : SIMDFPTwoScalar< 1, 0, 0b11010, "fcvtnu", int_aarch64_neon_fcvtn defm FCVTPS : SIMDFPTwoScalar< 0, 1, 0b11010, "fcvtps", int_aarch64_neon_fcvtps>; defm FCVTPU : SIMDFPTwoScalar< 1, 1, 0b11010, "fcvtpu", int_aarch64_neon_fcvtpu>; def FCVTXNv1i64 : SIMDInexactCvtTwoScalar<0b10110, "fcvtxn">; -defm FCVTZS : SIMDFPTwoScalar< 0, 1, 0b11011, "fcvtzs">; -defm FCVTZU : SIMDFPTwoScalar< 1, 1, 0b11011, "fcvtzu">; +defm FCVTZS : SIMDFPTwoScalar< 0, 1, 0b11011, "fcvtzs", any_fp_to_sint>; +defm FCVTZU : SIMDFPTwoScalar< 1, 1, 0b11011, "fcvtzu", any_fp_to_uint>; defm FRECPE : SIMDFPTwoScalar< 0, 1, 0b11101, "frecpe">; defm FRECPX : SIMDFPTwoScalar< 0, 1, 0b11111, "frecpx">; defm FRSQRTE : SIMDFPTwoScalar< 1, 1, 0b11101, "frsqrte">; @@ -6588,6 +6527,7 @@ defm USQADD : SIMDTwoScalarBHSDTied< 1, 0b00011, "usqadd", // 
Floating-point conversion patterns. multiclass FPToIntegerSIMDScalarPatterns<SDPatternOperator OpN, string INST> { + let Predicates = [HasFPRCVT] in { def : Pat<(f32 (bitconvert (i32 (OpN (f64 FPR64:$Rn))))), (!cast<Instruction>(INST # SDr) FPR64:$Rn)>; def : Pat<(f32 (bitconvert (i32 (OpN (f16 FPR16:$Rn))))), @@ -6596,6 +6536,7 @@ multiclass FPToIntegerSIMDScalarPatterns<SDPatternOperator OpN, string INST> { (!cast<Instruction>(INST # DHr) FPR16:$Rn)>; def : Pat<(f64 (bitconvert (i64 (OpN (f32 FPR32:$Rn))))), (!cast<Instruction>(INST # DSr) FPR32:$Rn)>; + } def : Pat<(f32 (bitconvert (i32 (OpN (f32 FPR32:$Rn))))), (!cast<Instruction>(INST # v1i32) FPR32:$Rn)>; def : Pat<(f64 (bitconvert (i64 (OpN (f64 FPR64:$Rn))))), @@ -6610,6 +6551,8 @@ defm: FPToIntegerSIMDScalarPatterns<int_aarch64_neon_fcvtns, "FCVTNS">; defm: FPToIntegerSIMDScalarPatterns<int_aarch64_neon_fcvtnu, "FCVTNU">; defm: FPToIntegerSIMDScalarPatterns<int_aarch64_neon_fcvtps, "FCVTPS">; defm: FPToIntegerSIMDScalarPatterns<int_aarch64_neon_fcvtpu, "FCVTPU">; +defm: FPToIntegerSIMDScalarPatterns<any_fp_to_sint, "FCVTZS">; +defm: FPToIntegerSIMDScalarPatterns<any_fp_to_uint, "FCVTZU">; multiclass FPToIntegerIntPats<Intrinsic round, string INST> { let Predicates = [HasFullFP16] in { @@ -6666,6 +6609,196 @@ multiclass FPToIntegerIntPats<Intrinsic round, string INST> { defm : FPToIntegerIntPats<int_aarch64_neon_fcvtzs, "FCVTZS">; defm : FPToIntegerIntPats<int_aarch64_neon_fcvtzu, "FCVTZU">; +// AArch64's FCVT instructions saturate when out of range. +multiclass FPToIntegerSatPats<SDNode to_int_sat, SDNode to_int_sat_gi, string INST> { + let Predicates = [HasFullFP16] in { + def : Pat<(i32 (to_int_sat f16:$Rn, i32)), + (!cast<Instruction>(INST # UWHr) f16:$Rn)>; + def : Pat<(i64 (to_int_sat f16:$Rn, i64)), + (!cast<Instruction>(INST # UXHr) f16:$Rn)>; + } + def : Pat<(i32 (to_int_sat f32:$Rn, i32)), + (!cast<Instruction>(INST # UWSr) f32:$Rn)>; + def : Pat<(i64 (to_int_sat f32:$Rn, i64)), + (!cast<Instruction>(INST # UXSr) f32:$Rn)>; + def : Pat<(i32 (to_int_sat f64:$Rn, i32)), + (!cast<Instruction>(INST # UWDr) f64:$Rn)>; + def : Pat<(i64 (to_int_sat f64:$Rn, i64)), + (!cast<Instruction>(INST # UXDr) f64:$Rn)>; + + let Predicates = [HasFullFP16] in { + def : Pat<(i32 (to_int_sat_gi f16:$Rn)), + (!cast<Instruction>(INST # UWHr) f16:$Rn)>; + def : Pat<(i64 (to_int_sat_gi f16:$Rn)), + (!cast<Instruction>(INST # UXHr) f16:$Rn)>; + } + def : Pat<(i32 (to_int_sat_gi f32:$Rn)), + (!cast<Instruction>(INST # UWSr) f32:$Rn)>; + def : Pat<(i64 (to_int_sat_gi f32:$Rn)), + (!cast<Instruction>(INST # UXSr) f32:$Rn)>; + def : Pat<(i32 (to_int_sat_gi f64:$Rn)), + (!cast<Instruction>(INST # UWDr) f64:$Rn)>; + def : Pat<(i64 (to_int_sat_gi f64:$Rn)), + (!cast<Instruction>(INST # UXDr) f64:$Rn)>; + + // For global-isel we can use register classes to determine + // which FCVT instruction to use. 
+ let Predicates = [HasFPRCVT] in { + def : Pat<(i32 (to_int_sat_gi f16:$Rn)), + (!cast<Instruction>(INST # SHr) f16:$Rn)>; + def : Pat<(i64 (to_int_sat_gi f16:$Rn)), + (!cast<Instruction>(INST # DHr) f16:$Rn)>; + def : Pat<(i64 (to_int_sat_gi f32:$Rn)), + (!cast<Instruction>(INST # DSr) f32:$Rn)>; + def : Pat<(i32 (to_int_sat_gi f64:$Rn)), + (!cast<Instruction>(INST # SDr) f64:$Rn)>; + } + def : Pat<(i32 (to_int_sat_gi f32:$Rn)), + (!cast<Instruction>(INST # v1i32) f32:$Rn)>; + def : Pat<(i64 (to_int_sat_gi f64:$Rn)), + (!cast<Instruction>(INST # v1i64) f64:$Rn)>; + + let Predicates = [HasFPRCVT] in { + def : Pat<(f32 (bitconvert (i32 (to_int_sat f16:$Rn, i32)))), + (!cast<Instruction>(INST # SHr) f16:$Rn)>; + def : Pat<(f64 (bitconvert (i64 (to_int_sat f16:$Rn, i64)))), + (!cast<Instruction>(INST # DHr) f16:$Rn)>; + def : Pat<(f64 (bitconvert (i64 (to_int_sat f32:$Rn, i64)))), + (!cast<Instruction>(INST # DSr) f32:$Rn)>; + def : Pat<(f32 (bitconvert (i32 (to_int_sat f64:$Rn, i32)))), + (!cast<Instruction>(INST # SDr) f64:$Rn)>; + } + def : Pat<(f32 (bitconvert (i32 (to_int_sat f32:$Rn, i32)))), + (!cast<Instruction>(INST # v1i32) f32:$Rn)>; + def : Pat<(f64 (bitconvert (i64 (to_int_sat f64:$Rn, i64)))), + (!cast<Instruction>(INST # v1i64) f64:$Rn)>; + + let Predicates = [HasFullFP16] in { + def : Pat<(i32 (to_int_sat (fmul f16:$Rn, fixedpoint_f16_i32:$scale), i32)), + (!cast<Instruction>(INST # SWHri) $Rn, $scale)>; + def : Pat<(i64 (to_int_sat (fmul f16:$Rn, fixedpoint_f16_i64:$scale), i64)), + (!cast<Instruction>(INST # SXHri) $Rn, $scale)>; + } + def : Pat<(i32 (to_int_sat (fmul f32:$Rn, fixedpoint_f32_i32:$scale), i32)), + (!cast<Instruction>(INST # SWSri) $Rn, $scale)>; + def : Pat<(i64 (to_int_sat (fmul f32:$Rn, fixedpoint_f32_i64:$scale), i64)), + (!cast<Instruction>(INST # SXSri) $Rn, $scale)>; + def : Pat<(i32 (to_int_sat (fmul f64:$Rn, fixedpoint_f64_i32:$scale), i32)), + (!cast<Instruction>(INST # SWDri) $Rn, $scale)>; + def : Pat<(i64 (to_int_sat (fmul f64:$Rn, fixedpoint_f64_i64:$scale), i64)), + (!cast<Instruction>(INST # SXDri) $Rn, $scale)>; + + let Predicates = [HasFullFP16] in { + def : Pat<(i32 (to_int_sat_gi (fmul f16:$Rn, fixedpoint_f16_i32:$scale))), + (!cast<Instruction>(INST # SWHri) $Rn, $scale)>; + def : Pat<(i64 (to_int_sat_gi (fmul f16:$Rn, fixedpoint_f16_i64:$scale))), + (!cast<Instruction>(INST # SXHri) $Rn, $scale)>; + } + def : Pat<(i32 (to_int_sat_gi (fmul f32:$Rn, fixedpoint_f32_i32:$scale))), + (!cast<Instruction>(INST # SWSri) $Rn, $scale)>; + def : Pat<(i64 (to_int_sat_gi (fmul f32:$Rn, fixedpoint_f32_i64:$scale))), + (!cast<Instruction>(INST # SXSri) $Rn, $scale)>; + def : Pat<(i32 (to_int_sat_gi (fmul f64:$Rn, fixedpoint_f64_i32:$scale))), + (!cast<Instruction>(INST # SWDri) $Rn, $scale)>; + def : Pat<(i64 (to_int_sat_gi (fmul f64:$Rn, fixedpoint_f64_i64:$scale))), + (!cast<Instruction>(INST # SXDri) $Rn, $scale)>; +} + +defm : FPToIntegerSatPats<fp_to_sint_sat, fp_to_sint_sat_gi, "FCVTZS">; +defm : FPToIntegerSatPats<fp_to_uint_sat, fp_to_uint_sat_gi, "FCVTZU">; + +multiclass FPToIntegerPats<SDNode to_int, SDNode to_int_sat, SDNode to_int_sat_gi, SDNode round, string INST> { + def : Pat<(i32 (to_int (round f32:$Rn))), + (!cast<Instruction>(INST # UWSr) f32:$Rn)>; + def : Pat<(i64 (to_int (round f32:$Rn))), + (!cast<Instruction>(INST # UXSr) f32:$Rn)>; + def : Pat<(i32 (to_int (round f64:$Rn))), + (!cast<Instruction>(INST # UWDr) f64:$Rn)>; + def : Pat<(i64 (to_int (round f64:$Rn))), + (!cast<Instruction>(INST # UXDr) f64:$Rn)>; + + // For 
global-isel we can use register classes to determine + // which FCVT instruction to use. + let Predicates = [HasFPRCVT] in { + def : Pat<(i64 (to_int (round f32:$Rn))), + (!cast<Instruction>(INST # DSr) f32:$Rn)>; + def : Pat<(i32 (to_int (round f64:$Rn))), + (!cast<Instruction>(INST # SDr) f64:$Rn)>; + } + def : Pat<(i32 (to_int (round f32:$Rn))), + (!cast<Instruction>(INST # v1i32) f32:$Rn)>; + def : Pat<(i64 (to_int (round f64:$Rn))), + (!cast<Instruction>(INST # v1i64) f64:$Rn)>; + + let Predicates = [HasFPRCVT] in { + def : Pat<(f64 (bitconvert (i64 (to_int (round f32:$Rn))))), + (!cast<Instruction>(INST # DSr) f32:$Rn)>; + def : Pat<(f32 (bitconvert (i32 (to_int (round f64:$Rn))))), + (!cast<Instruction>(INST # SDr) f64:$Rn)>; + } + def : Pat<(f32 (bitconvert (i32 (to_int (round f32:$Rn))))), + (!cast<Instruction>(INST # v1i32) f32:$Rn)>; + def : Pat<(f64 (bitconvert (i64 (to_int (round f64:$Rn))))), + (!cast<Instruction>(INST # v1i64) f64:$Rn)>; + + // These instructions saturate like fp_to_[su]int_sat. + let Predicates = [HasFullFP16] in { + def : Pat<(i32 (to_int_sat (round f16:$Rn), i32)), + (!cast<Instruction>(INST # UWHr) f16:$Rn)>; + def : Pat<(i64 (to_int_sat (round f16:$Rn), i64)), + (!cast<Instruction>(INST # UXHr) f16:$Rn)>; + } + def : Pat<(i32 (to_int_sat (round f32:$Rn), i32)), + (!cast<Instruction>(INST # UWSr) f32:$Rn)>; + def : Pat<(i64 (to_int_sat (round f32:$Rn), i64)), + (!cast<Instruction>(INST # UXSr) f32:$Rn)>; + def : Pat<(i32 (to_int_sat (round f64:$Rn), i32)), + (!cast<Instruction>(INST # UWDr) f64:$Rn)>; + def : Pat<(i64 (to_int_sat (round f64:$Rn), i64)), + (!cast<Instruction>(INST # UXDr) f64:$Rn)>; + + // For global-isel we can use register classes to determine + // which FCVT instruction to use. + let Predicates = [HasFPRCVT] in { + def : Pat<(i32 (to_int_sat_gi (round f16:$Rn))), + (!cast<Instruction>(INST # SHr) f16:$Rn)>; + def : Pat<(i64 (to_int_sat_gi (round f16:$Rn))), + (!cast<Instruction>(INST # DHr) f16:$Rn)>; + def : Pat<(i64 (to_int_sat_gi (round f32:$Rn))), + (!cast<Instruction>(INST # DSr) f32:$Rn)>; + def : Pat<(i32 (to_int_sat_gi (round f64:$Rn))), + (!cast<Instruction>(INST # SDr) f64:$Rn)>; + } + def : Pat<(i32 (to_int_sat_gi (round f32:$Rn))), + (!cast<Instruction>(INST # v1i32) f32:$Rn)>; + def : Pat<(i64 (to_int_sat_gi (round f64:$Rn))), + (!cast<Instruction>(INST # v1i64) f64:$Rn)>; + + let Predicates = [HasFPRCVT] in { + def : Pat<(f32 (bitconvert (i32 (to_int_sat (round f16:$Rn), i32)))), + (!cast<Instruction>(INST # SHr) f16:$Rn)>; + def : Pat<(f64 (bitconvert (i64 (to_int_sat (round f16:$Rn), i64)))), + (!cast<Instruction>(INST # DHr) f16:$Rn)>; + def : Pat<(f64 (bitconvert (i64 (to_int_sat (round f32:$Rn), i64)))), + (!cast<Instruction>(INST # DSr) f32:$Rn)>; + def : Pat<(f32 (bitconvert (i32 (to_int_sat (round f64:$Rn), i32)))), + (!cast<Instruction>(INST # SDr) f64:$Rn)>; + } + def : Pat<(f32 (bitconvert (i32 (to_int_sat (round f32:$Rn), i32)))), + (!cast<Instruction>(INST # v1i32) f32:$Rn)>; + def : Pat<(f64 (bitconvert (i64 (to_int_sat (round f64:$Rn), i64)))), + (!cast<Instruction>(INST # v1i64) f64:$Rn)>; +} + +defm : FPToIntegerPats<fp_to_sint, fp_to_sint_sat, fp_to_sint_sat_gi, fceil, "FCVTPS">; +defm : FPToIntegerPats<fp_to_uint, fp_to_uint_sat, fp_to_uint_sat_gi, fceil, "FCVTPU">; +defm : FPToIntegerPats<fp_to_sint, fp_to_sint_sat, fp_to_sint_sat_gi, ffloor, "FCVTMS">; +defm : FPToIntegerPats<fp_to_uint, fp_to_uint_sat, fp_to_uint_sat_gi, ffloor, "FCVTMU">; +defm : FPToIntegerPats<fp_to_sint, fp_to_sint_sat, 
fp_to_sint_sat_gi, ftrunc, "FCVTZS">; +defm : FPToIntegerPats<fp_to_uint, fp_to_uint_sat, fp_to_uint_sat_gi, ftrunc, "FCVTZU">; +defm : FPToIntegerPats<fp_to_sint, fp_to_sint_sat, fp_to_sint_sat_gi, fround, "FCVTAS">; +defm : FPToIntegerPats<fp_to_uint, fp_to_uint_sat, fp_to_uint_sat_gi, fround, "FCVTAU">; + // f16 -> s16 conversions let Predicates = [HasFullFP16] in { def : Pat<(i16(fp_to_sint_sat_gi f16:$Rn)), (FCVTZSv1f16 f16:$Rn)>; diff --git a/llvm/lib/Target/AArch64/AArch64SchedNeoverseV2.td b/llvm/lib/Target/AArch64/AArch64SchedNeoverseV2.td index bdde8e3..2387f17 100644 --- a/llvm/lib/Target/AArch64/AArch64SchedNeoverseV2.td +++ b/llvm/lib/Target/AArch64/AArch64SchedNeoverseV2.td @@ -2762,11 +2762,11 @@ def : InstRW<[V2Write_11c_18L01_18V01], (instregex "^ST4[BHWD]_IMM$")>; def : InstRW<[V2Write_11c_18L01_18S_18V01], (instregex "^ST4[BHWD]$")>; // Non temporal store, scalar + imm -def : InstRW<[V2Write_2c_1L01_1V], (instregex "^STNT1[BHWD]_ZRI$")>; +def : InstRW<[V2Write_2c_1L01_1V01], (instregex "^STNT1[BHWD]_ZRI$")>; // Non temporal store, scalar + scalar -def : InstRW<[V2Write_2c_1L01_1S_1V], (instrs STNT1H_ZRR)>; -def : InstRW<[V2Write_2c_1L01_1V], (instregex "^STNT1[BWD]_ZRR$")>; +def : InstRW<[V2Write_2c_1L01_1S_1V01], (instrs STNT1H_ZRR)>; +def : InstRW<[V2Write_2c_1L01_1V01], (instregex "^STNT1[BWD]_ZRR$")>; // Scatter non temporal store, vector + scalar 32-bit element size def : InstRW<[V2Write_4c_4L01_4V01], (instregex "^STNT1[BHW]_ZZR_S")>; diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index 2053fc4..fede586 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -224,7 +224,8 @@ static cl::opt<bool> EnableScalableAutovecInStreamingMode( static bool isSMEABIRoutineCall(const CallInst &CI, const AArch64TargetLowering &TLI) { const auto *F = CI.getCalledFunction(); - return F && SMEAttrs(F->getName(), TLI).isSMEABIRoutine(); + return F && + SMEAttrs(F->getName(), TLI.getRuntimeLibcallsInfo()).isSMEABIRoutine(); } /// Returns true if the function has explicit operations that can only be @@ -355,7 +356,7 @@ AArch64TTIImpl::getInlineCallPenalty(const Function *F, const CallBase &Call, // change only once and avoid inlining of G into F. SMEAttrs FAttrs(*F); - SMECallAttrs CallAttrs(Call, getTLI()); + SMECallAttrs CallAttrs(Call, &getTLI()->getRuntimeLibcallsInfo()); if (SMECallAttrs(FAttrs, CallAttrs.callee()).requiresSMChange()) { if (F == Call.getCaller()) // (1) @@ -957,23 +958,50 @@ AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA, return TyL.first + ExtraCost; } case Intrinsic::get_active_lane_mask: { - auto *RetTy = dyn_cast<FixedVectorType>(ICA.getReturnType()); - if (RetTy) { - EVT RetVT = getTLI()->getValueType(DL, RetTy); - EVT OpVT = getTLI()->getValueType(DL, ICA.getArgTypes()[0]); - if (!getTLI()->shouldExpandGetActiveLaneMask(RetVT, OpVT) && - !getTLI()->isTypeLegal(RetVT)) { - // We don't have enough context at this point to determine if the mask - // is going to be kept live after the block, which will force the vXi1 - // type to be expanded to legal vectors of integers, e.g. v4i1->v4i32. - // For now, we just assume the vectorizer created this intrinsic and - // the result will be the input for a PHI. In this case the cost will - // be extremely high for fixed-width vectors. 
- // NOTE: getScalarizationOverhead returns a cost that's far too - // pessimistic for the actual generated codegen. In reality there are - // two instructions generated per lane. - return RetTy->getNumElements() * 2; + auto RetTy = cast<VectorType>(ICA.getReturnType()); + EVT RetVT = getTLI()->getValueType(DL, RetTy); + EVT OpVT = getTLI()->getValueType(DL, ICA.getArgTypes()[0]); + if (getTLI()->shouldExpandGetActiveLaneMask(RetVT, OpVT)) + break; + + if (RetTy->isScalableTy()) { + if (TLI->getTypeAction(RetTy->getContext(), RetVT) != + TargetLowering::TypeSplitVector) + break; + + auto LT = getTypeLegalizationCost(RetTy); + InstructionCost Cost = LT.first; + // When SVE2p1 or SME2 is available, we can halve getTypeLegalizationCost + // as get_active_lane_mask may lower to the sve_whilelo_x2 intrinsic, e.g. + // nxv32i1 = get_active_lane_mask(base, idx) -> + // {nxv16i1, nxv16i1} = sve_whilelo_x2(base, idx) + if (ST->hasSVE2p1() || ST->hasSME2()) { + Cost /= 2; + if (Cost == 1) + return Cost; } + + // If more than one whilelo intrinsic is required, include the extra cost + // required by the saturating add & select required to increment the + // start value after the first intrinsic call. + Type *OpTy = ICA.getArgTypes()[0]; + IntrinsicCostAttributes AddAttrs(Intrinsic::uadd_sat, OpTy, {OpTy, OpTy}); + InstructionCost SplitCost = getIntrinsicInstrCost(AddAttrs, CostKind); + Type *CondTy = OpTy->getWithNewBitWidth(1); + SplitCost += getCmpSelInstrCost(Instruction::Select, OpTy, CondTy, + CmpInst::ICMP_UGT, CostKind); + return Cost + (SplitCost * (Cost - 1)); + } else if (!getTLI()->isTypeLegal(RetVT)) { + // We don't have enough context at this point to determine if the mask + // is going to be kept live after the block, which will force the vXi1 + // type to be expanded to legal vectors of integers, e.g. v4i1->v4i32. + // For now, we just assume the vectorizer created this intrinsic and + // the result will be the input for a PHI. In this case the cost will + // be extremely high for fixed-width vectors. + // NOTE: getScalarizationOverhead returns a cost that's far too + // pessimistic for the actual generated codegen. In reality there are + // two instructions generated per lane. + return cast<FixedVectorType>(RetTy)->getNumElements() * 2; } break; } diff --git a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp index 3e55b76..14b0f9a 100644 --- a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp +++ b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp @@ -5126,23 +5126,13 @@ bool AArch64InstructionSelector::selectShuffleVector( MachineInstr &I, MachineRegisterInfo &MRI) { const LLT DstTy = MRI.getType(I.getOperand(0).getReg()); Register Src1Reg = I.getOperand(1).getReg(); - const LLT Src1Ty = MRI.getType(Src1Reg); Register Src2Reg = I.getOperand(2).getReg(); - const LLT Src2Ty = MRI.getType(Src2Reg); ArrayRef<int> Mask = I.getOperand(3).getShuffleMask(); MachineBasicBlock &MBB = *I.getParent(); MachineFunction &MF = *MBB.getParent(); LLVMContext &Ctx = MF.getFunction().getContext(); - // G_SHUFFLE_VECTOR is weird in that the source operands can be scalars, if - // it's originated from a <1 x T> type. Those should have been lowered into - // G_BUILD_VECTOR earlier. 
- if (!Src1Ty.isVector() || !Src2Ty.isVector()) { - LLVM_DEBUG(dbgs() << "Could not select a \"scalar\" G_SHUFFLE_VECTOR\n"); - return false; - } - unsigned BytesPerElt = DstTy.getElementType().getSizeInBits() / 8; SmallVector<Constant *, 64> CstIdxs; diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp index 05a4313..5f93847 100644 --- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp +++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp @@ -1201,25 +1201,17 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST) return llvm::is_contained( {v8s8, v16s8, v4s16, v8s16, v2s32, v4s32, v2s64}, DstTy); }) - // G_SHUFFLE_VECTOR can have scalar sources (from 1 x s vectors) or scalar - // destinations, we just want those lowered into G_BUILD_VECTOR or - // G_EXTRACT_ELEMENT. - .lowerIf([=](const LegalityQuery &Query) { - return !Query.Types[0].isVector() || !Query.Types[1].isVector(); - }) .moreElementsIf( [](const LegalityQuery &Query) { - return Query.Types[0].isVector() && Query.Types[1].isVector() && - Query.Types[0].getNumElements() > - Query.Types[1].getNumElements(); + return Query.Types[0].getNumElements() > + Query.Types[1].getNumElements(); }, changeTo(1, 0)) .moreElementsToNextPow2(0) .moreElementsIf( [](const LegalityQuery &Query) { - return Query.Types[0].isVector() && Query.Types[1].isVector() && - Query.Types[0].getNumElements() < - Query.Types[1].getNumElements(); + return Query.Types[0].getNumElements() < + Query.Types[1].getNumElements(); }, changeTo(0, 1)) .widenScalarOrEltToNextPow2OrMinSize(0, 8) diff --git a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp index 830a35bb..6d2d705 100644 --- a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp +++ b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp @@ -856,7 +856,9 @@ AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { break; } case TargetOpcode::G_FPTOSI_SAT: - case TargetOpcode::G_FPTOUI_SAT: { + case TargetOpcode::G_FPTOUI_SAT: + case TargetOpcode::G_FPTOSI: + case TargetOpcode::G_FPTOUI: { LLT DstType = MRI.getType(MI.getOperand(0).getReg()); if (DstType.isVector()) break; @@ -864,11 +866,19 @@ AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { OpRegBankIdx = {PMI_FirstFPR, PMI_FirstFPR}; break; } - OpRegBankIdx = {PMI_FirstGPR, PMI_FirstFPR}; + TypeSize DstSize = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI); + TypeSize SrcSize = getSizeInBits(MI.getOperand(1).getReg(), MRI, TRI); + if (((DstSize == SrcSize) || STI.hasFeature(AArch64::FeatureFPRCVT)) && + all_of(MRI.use_nodbg_instructions(MI.getOperand(0).getReg()), + [&](const MachineInstr &UseMI) { + return onlyUsesFP(UseMI, MRI, TRI) || + prefersFPUse(UseMI, MRI, TRI); + })) + OpRegBankIdx = {PMI_FirstFPR, PMI_FirstFPR}; + else + OpRegBankIdx = {PMI_FirstGPR, PMI_FirstFPR}; break; } - case TargetOpcode::G_FPTOSI: - case TargetOpcode::G_FPTOUI: case TargetOpcode::G_INTRINSIC_LRINT: case TargetOpcode::G_INTRINSIC_LLRINT: if (MRI.getType(MI.getOperand(0).getReg()).isVector()) diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp index d71f728..085c8588 100644 --- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp +++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp @@ -75,8 +75,8 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) { } void 
SMEAttrs::addKnownFunctionAttrs(StringRef FuncName, - const AArch64TargetLowering &TLI) { - RTLIB::LibcallImpl Impl = TLI.getSupportedLibcallImpl(FuncName); + const RTLIB::RuntimeLibcallsInfo &RTLCI) { + RTLIB::LibcallImpl Impl = RTLCI.getSupportedLibcallImpl(FuncName); if (Impl == RTLIB::Unsupported) return; unsigned KnownAttrs = SMEAttrs::Normal; @@ -124,21 +124,22 @@ bool SMECallAttrs::requiresSMChange() const { return true; } -SMECallAttrs::SMECallAttrs(const CallBase &CB, const AArch64TargetLowering *TLI) +SMECallAttrs::SMECallAttrs(const CallBase &CB, + const RTLIB::RuntimeLibcallsInfo *RTLCI) : CallerFn(*CB.getFunction()), CalledFn(SMEAttrs::Normal), Callsite(CB.getAttributes()), IsIndirect(CB.isIndirectCall()) { if (auto *CalledFunction = CB.getCalledFunction()) - CalledFn = SMEAttrs(*CalledFunction, TLI); - - // An `invoke` of an agnostic ZA function may not return normally (it may - // resume in an exception block). In this case, it acts like a private ZA - // callee and may require a ZA save to be set up before it is called. - if (isa<InvokeInst>(CB)) - CalledFn.set(SMEAttrs::ZA_State_Agnostic, /*Enable=*/false); + CalledFn = SMEAttrs(*CalledFunction, RTLCI); // FIXME: We probably should not allow SME attributes on direct calls but // clang duplicates streaming mode attributes at each callsite. assert((IsIndirect || ((Callsite.withoutPerCallsiteFlags() | CalledFn) == CalledFn)) && "SME attributes at callsite do not match declaration"); + + // An `invoke` of an agnostic ZA function may not return normally (it may + // resume in an exception block). In this case, it acts like a private ZA + // callee and may require a ZA save to be set up before it is called. + if (isa<InvokeInst>(CB)) + CalledFn.set(SMEAttrs::ZA_State_Agnostic, /*Enable=*/false); } diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h index d26e3cd..28c397e 100644 --- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h +++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h @@ -12,8 +12,9 @@ #include "llvm/IR/Function.h" namespace llvm { - -class AArch64TargetLowering; +namespace RTLIB { +struct RuntimeLibcallsInfo; +} class Function; class CallBase; @@ -52,14 +53,14 @@ public: SMEAttrs() = default; SMEAttrs(unsigned Mask) { set(Mask); } - SMEAttrs(const Function &F, const AArch64TargetLowering *TLI = nullptr) + SMEAttrs(const Function &F, const RTLIB::RuntimeLibcallsInfo *RTLCI = nullptr) : SMEAttrs(F.getAttributes()) { - if (TLI) - addKnownFunctionAttrs(F.getName(), *TLI); + if (RTLCI) + addKnownFunctionAttrs(F.getName(), *RTLCI); } SMEAttrs(const AttributeList &L); - SMEAttrs(StringRef FuncName, const AArch64TargetLowering &TLI) { - addKnownFunctionAttrs(FuncName, TLI); + SMEAttrs(StringRef FuncName, const RTLIB::RuntimeLibcallsInfo &RTLCI) { + addKnownFunctionAttrs(FuncName, RTLCI); }; void set(unsigned M, bool Enable = true) { @@ -157,7 +158,7 @@ public: private: void addKnownFunctionAttrs(StringRef FuncName, - const AArch64TargetLowering &TLI); + const RTLIB::RuntimeLibcallsInfo &RTLCI); void validate() const; }; @@ -175,7 +176,7 @@ public: SMEAttrs Callsite = SMEAttrs::Normal) : CallerFn(Caller), CalledFn(Callee), Callsite(Callsite) {} - SMECallAttrs(const CallBase &CB, const AArch64TargetLowering *TLI); + SMECallAttrs(const CallBase &CB, const RTLIB::RuntimeLibcallsInfo *RTLCI); SMEAttrs &caller() { return CallerFn; } SMEAttrs &callee() { return IsIndirect ? Callsite : CalledFn; } |
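(Editorial aside summarizing the SMEAttrs/SMECallAttrs migration above: construction now goes through RTLIB::RuntimeLibcallsInfo rather than AArch64TargetLowering. A minimal usage sketch; the helper name is hypothetical, include paths are assumed, and the constructors and query methods mirror those in the diff.)

#include "AArch64ISelLowering.h"
#include "Utils/AArch64SMEAttributes.h"
#include "llvm/IR/RuntimeLibcalls.h"

using namespace llvm;

// Decide whether a call site needs conservative SME handling, using the
// RuntimeLibcallsInfo-based constructors introduced by this change.
static bool needsConservativeSMELowering(const CallBase &CB,
                                         const AArch64TargetLowering &TLI) {
  const RTLIB::RuntimeLibcallsInfo &RTLCI = TLI.getRuntimeLibcallsInfo();
  SMECallAttrs CallAttrs(CB, &RTLCI);
  return CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
         CallAttrs.requiresPreservingAllZAState();
}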
