Diffstat (limited to 'llvm/lib/Target/AArch64')
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64FrameLowering.cpp        |   2
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td          |  30
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp  | 121
3 files changed, 101 insertions, 52 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
index 8d6eb91..4357264d 100644
--- a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
@@ -282,7 +282,7 @@ static cl::opt<bool> OrderFrameObjects("aarch64-order-frame-objects",
 static cl::opt<bool>
     SplitSVEObjects("aarch64-split-sve-objects",
                     cl::desc("Split allocation of ZPR & PPR objects"),
-                    cl::init(false), cl::Hidden);
+                    cl::init(true), cl::Hidden);
 
 cl::opt<bool> EnableHomogeneousPrologEpilog(
     "homogeneous-prolog-epilog", cl::Hidden,
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 36c9cb6..bc6b931 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -1010,6 +1010,36 @@ let Predicates = [HasSVE_or_SME] in {
   defm SEL_ZPZZ   : sve_int_sel_vvv<"sel", vselect>;
 
   defm SPLICE_ZPZ : sve_int_perm_splice<"splice", AArch64splice>;
+
+  // mul x (splat -1) -> neg x
+  def : Pat<(nxv16i8 (AArch64mul_m1 nxv16i1:$Op1, nxv16i8:$Op2, (nxv16i8 (splat_vector (i32 -1))))),
+            (NEG_ZPmZ_B $Op2, $Op1, $Op2)>;
+  def : Pat<(nxv8i16 (AArch64mul_m1 nxv8i1:$Op1, nxv8i16:$Op2, (nxv8i16 (splat_vector (i32 -1))))),
+            (NEG_ZPmZ_H $Op2, $Op1, $Op2)>;
+  def : Pat<(nxv4i32 (AArch64mul_m1 nxv4i1:$Op1, nxv4i32:$Op2, (nxv4i32 (splat_vector (i32 -1))))),
+            (NEG_ZPmZ_S $Op2, $Op1, $Op2)>;
+  def : Pat<(nxv2i64 (AArch64mul_m1 nxv2i1:$Op1, nxv2i64:$Op2, (nxv2i64 (splat_vector (i64 -1))))),
+            (NEG_ZPmZ_D $Op2, $Op1, $Op2)>;
+
+  let AddedComplexity = 5 in {
+    def : Pat<(nxv16i8 (AArch64mul_p nxv16i1:$Op1, nxv16i8:$Op2, (nxv16i8 (splat_vector (i32 -1))))),
+              (NEG_ZPmZ_B_UNDEF $Op2, $Op1, $Op2)>;
+    def : Pat<(nxv8i16 (AArch64mul_p nxv8i1:$Op1, nxv8i16:$Op2, (nxv8i16 (splat_vector (i32 -1))))),
+              (NEG_ZPmZ_H_UNDEF $Op2, $Op1, $Op2)>;
+    def : Pat<(nxv4i32 (AArch64mul_p nxv4i1:$Op1, nxv4i32:$Op2, (nxv4i32 (splat_vector (i32 -1))))),
+              (NEG_ZPmZ_S_UNDEF $Op2, $Op1, $Op2)>;
+    def : Pat<(nxv2i64 (AArch64mul_p nxv2i1:$Op1, nxv2i64:$Op2, (nxv2i64 (splat_vector (i64 -1))))),
+              (NEG_ZPmZ_D_UNDEF $Op2, $Op1, $Op2)>;
+  }
+
+  def : Pat<(nxv16i8 (AArch64mul_m1 nxv16i1:$Op1, (nxv16i8 (splat_vector (i32 -1))), nxv16i8:$Op2)),
+            (NEG_ZPmZ_B (DUP_ZI_B -1, 0), $Op1, $Op2)>;
+  def : Pat<(nxv8i16 (AArch64mul_m1 nxv8i1:$Op1, (nxv8i16 (splat_vector (i32 -1))), nxv8i16:$Op2)),
+            (NEG_ZPmZ_H (DUP_ZI_H -1, 0), $Op1, $Op2)>;
+  def : Pat<(nxv4i32 (AArch64mul_m1 nxv4i1:$Op1, (nxv4i32 (splat_vector (i32 -1))), nxv4i32:$Op2)),
+            (NEG_ZPmZ_S (DUP_ZI_S -1, 0), $Op1, $Op2)>;
+  def : Pat<(nxv2i64 (AArch64mul_m1 nxv2i1:$Op1, (nxv2i64 (splat_vector (i64 -1))), nxv2i64:$Op2)),
+            (NEG_ZPmZ_D (DUP_ZI_D -1, 0), $Op1, $Op2)>;
 } // End HasSVE_or_SME
 
 // COMPACT - word and doubleword
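The TableGen patterns above fold a predicated multiply by a splatted -1 into a predicated NEG, reusing the multiplicand as the merge source, or materialising a DUP_ZI splat when the -1 operand is the one merged into. As a minimal sketch of the source shape that typically produces such nodes once the loop vectorizer targets SVE (this function is hypothetical, not part of the patch):

    #include <cstdint>

    // For two's-complement integers, x * -1 == -x, so the vectorized body
    // (a predicated MUL of the loaded vector by splat(-1)) can now be
    // selected as a single predicated NEG instead of a DUP + MUL pair.
    void negate_i32(int32_t *dst, const int32_t *src, int n) {
      for (int i = 0; i < n; ++i)
        dst[i] = src[i] * -1;
    }

Without these patterns, selection would otherwise splat -1 into a vector register and issue a MUL.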
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 8c4b4f6..50a8754 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5632,75 +5632,94 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
     TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
     TTI::TargetCostKind CostKind) const {
   InstructionCost Invalid = InstructionCost::getInvalid();
-  InstructionCost Cost(TTI::TCC_Basic);
 
   if (CostKind != TTI::TCK_RecipThroughput)
     return Invalid;
 
-  // Sub opcodes currently only occur in chained cases.
-  // Independent partial reduction subtractions are still costed as an add
+  if (VF.isFixed() && !ST->isSVEorStreamingSVEAvailable() &&
+      (!ST->isNeonAvailable() || !ST->hasDotProd()))
+    return Invalid;
+
   if ((Opcode != Instruction::Add && Opcode != Instruction::Sub) ||
       OpAExtend == TTI::PR_None)
     return Invalid;
 
+  assert((BinOp || (OpBExtend == TTI::PR_None && !InputTypeB)) &&
+         (!BinOp || (OpBExtend != TTI::PR_None && InputTypeB)) &&
+         "Unexpected values for OpBExtend or InputTypeB");
+
   // We only support multiply binary operations for now, and for muls we
   // require the types being extended to be the same.
-  // NOTE: For muls AArch64 supports lowering mixed extensions to a usdot but
-  // only if the i8mm or sve/streaming features are available.
-  if (BinOp && (*BinOp != Instruction::Mul || InputTypeA != InputTypeB ||
-                OpBExtend == TTI::PR_None ||
-                (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
-                 !ST->isSVEorStreamingSVEAvailable())))
+  if (BinOp && (*BinOp != Instruction::Mul || InputTypeA != InputTypeB))
     return Invalid;
 
-  assert((BinOp || (OpBExtend == TTI::PR_None && !InputTypeB)) &&
-         "Unexpected values for OpBExtend or InputTypeB");
-  EVT InputEVT = EVT::getEVT(InputTypeA);
-  EVT AccumEVT = EVT::getEVT(AccumType);
+  bool IsUSDot = OpBExtend != TTI::PR_None && OpAExtend != OpBExtend;
+  if (IsUSDot && !ST->hasMatMulInt8())
+    return Invalid;
+
+  unsigned Ratio =
+      AccumType->getScalarSizeInBits() / InputTypeA->getScalarSizeInBits();
+  if (VF.getKnownMinValue() <= Ratio)
+    return Invalid;
+
+  VectorType *InputVectorType = VectorType::get(InputTypeA, VF);
+  VectorType *AccumVectorType =
+      VectorType::get(AccumType, VF.divideCoefficientBy(Ratio));
+  // We don't yet support all kinds of legalization.
+  auto TA = TLI->getTypeAction(AccumVectorType->getContext(),
+                               EVT::getEVT(AccumVectorType));
+  switch (TA) {
+  default:
+    return Invalid;
+  case TargetLowering::TypeLegal:
+  case TargetLowering::TypePromoteInteger:
+  case TargetLowering::TypeSplitVector:
+    break;
+  }
+
+  // Check what kind of type-legalisation happens.
+  std::pair<InstructionCost, MVT> AccumLT =
+      getTypeLegalizationCost(AccumVectorType);
+  std::pair<InstructionCost, MVT> InputLT =
+      getTypeLegalizationCost(InputVectorType);
 
-  unsigned VFMinValue = VF.getKnownMinValue();
+  InstructionCost Cost = InputLT.first * TTI::TCC_Basic;
 
-  if (VF.isScalable()) {
-    if (!ST->isSVEorStreamingSVEAvailable())
-      return Invalid;
+  // Prefer using full types by costing half-full input types as more expensive.
+  if (TypeSize::isKnownLT(InputVectorType->getPrimitiveSizeInBits(),
+                          TypeSize::getScalable(128)))
+    // FIXME: This can be removed after the cost of the extends are folded into
+    // the dot-product expression in VPlan, after landing:
+    // https://github.com/llvm/llvm-project/pull/147302
+    Cost *= 2;
 
-    // Don't accept a partial reduction if the scaled accumulator is vscale x 1,
-    // since we can't lower that type.
-    unsigned Scale =
-        AccumEVT.getScalarSizeInBits() / InputEVT.getScalarSizeInBits();
-    if (VFMinValue == Scale)
-      return Invalid;
+  if (ST->isSVEorStreamingSVEAvailable() && !IsUSDot) {
+    // i16 -> i64 is natively supported for udot/sdot
+    if (AccumLT.second.getScalarType() == MVT::i64 &&
+        InputLT.second.getScalarType() == MVT::i16)
+      return Cost;
+    // i8 -> i64 is supported with an extra level of extends
+    if (AccumLT.second.getScalarType() == MVT::i64 &&
+        InputLT.second.getScalarType() == MVT::i8)
+      // FIXME: This cost should probably be a little higher, e.g. Cost + 2
+      // because it requires two extra extends on the inputs. But if we'd change
+      // that now, a regular reduction would be cheaper because the costs of
+      // the extends in the IR are still counted. This can be fixed
+      // after https://github.com/llvm/llvm-project/pull/147302 has landed.
+      return Cost;
   }
 
-  if (VF.isFixed() &&
-      (!ST->isNeonAvailable() || !ST->hasDotProd() || AccumEVT == MVT::i64))
-    return Invalid;
-
-  if (InputEVT == MVT::i8) {
-    switch (VFMinValue) {
-    default:
-      return Invalid;
-    case 8:
-      if (AccumEVT == MVT::i32)
-        Cost *= 2;
-      else if (AccumEVT != MVT::i64)
-        return Invalid;
-      break;
-    case 16:
-      if (AccumEVT == MVT::i64)
-        Cost *= 2;
-      else if (AccumEVT != MVT::i32)
-        return Invalid;
-      break;
-    }
-  } else if (InputEVT == MVT::i16) {
-    // FIXME: Allow i32 accumulator but increase cost, as we would extend
-    // it to i64.
-    if (VFMinValue != 8 || AccumEVT != MVT::i64)
-      return Invalid;
-  } else
-    return Invalid;
+  // i8 -> i32 is natively supported for udot/sdot/usdot, both for NEON and SVE.
+  if (ST->isSVEorStreamingSVEAvailable() ||
+      (AccumLT.second.isFixedLengthVector() && ST->isNeonAvailable() &&
+       ST->hasDotProd())) {
+    if (AccumLT.second.getScalarType() == MVT::i32 &&
+        InputLT.second.getScalarType() == MVT::i8)
+      return Cost;
+  }
 
-  return Cost;
+  // Add additional cost for the extends that would need to be inserted.
+  return Cost + 4;
 }
 
 InstructionCost
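The rewritten cost model reasons about the legalized input and accumulator types instead of enumerating VF/accumulator combinations. The cheapest case it recognises is the canonical dot product, sketched below for an i8 input with an i32 accumulator (the function is illustrative, not from the patch): Ratio = 32 / 8 = 4, so a VF of 16 gives a <16 x i8> input and a <4 x i32> accumulator, which maps directly onto udot/sdot (or usdot for mixed extends, when i8mm is available) and is costed at InputLT.first * TCC_Basic.

    #include <cstdint>

    // Canonical i8 -> i32 partial reduction: each 32-bit accumulator lane
    // sums Ratio = 4 widened 8-bit products, exactly what udot/sdot compute.
    int32_t dot_i8(const int8_t *a, const int8_t *b, int n) {
      int32_t sum = 0;
      for (int i = 0; i < n; ++i)
        sum += int32_t(a[i]) * int32_t(b[i]);
      return sum;
    }

The VF.getKnownMinValue() <= Ratio guard rejects shapes whose scaled accumulator would collapse to a single element (e.g. vscale x 1 x i32), which cannot be lowered, and the final return Cost + 4 charges for the extends that unsupported type combinations would require.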