aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp')
-rw-r--r-- llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp | 121
1 files changed, 70 insertions, 51 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 8c4b4f6..50a8754 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5632,75 +5632,94 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
TTI::TargetCostKind CostKind) const {
InstructionCost Invalid = InstructionCost::getInvalid();
- InstructionCost Cost(TTI::TCC_Basic);
if (CostKind != TTI::TCK_RecipThroughput)
return Invalid;
- // Sub opcodes currently only occur in chained cases.
- // Independent partial reduction subtractions are still costed as an add
+ if (VF.isFixed() && !ST->isSVEorStreamingSVEAvailable() &&
+ (!ST->isNeonAvailable() || !ST->hasDotProd()))
+ return Invalid;
+
if ((Opcode != Instruction::Add && Opcode != Instruction::Sub) ||
OpAExtend == TTI::PR_None)
return Invalid;
+ assert((BinOp || (OpBExtend == TTI::PR_None && !InputTypeB)) &&
+ (!BinOp || (OpBExtend != TTI::PR_None && InputTypeB)) &&
+ "Unexpected values for OpBExtend or InputTypeB");
+
// We only support multiply binary operations for now, and for muls we
// require the types being extended to be the same.
- // NOTE: For muls AArch64 supports lowering mixed extensions to a usdot but
- // only if the i8mm or sve/streaming features are available.
- if (BinOp && (*BinOp != Instruction::Mul || InputTypeA != InputTypeB ||
- OpBExtend == TTI::PR_None ||
- (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
- !ST->isSVEorStreamingSVEAvailable())))
+ if (BinOp && (*BinOp != Instruction::Mul || InputTypeA != InputTypeB))
return Invalid;
- assert((BinOp || (OpBExtend == TTI::PR_None && !InputTypeB)) &&
- "Unexpected values for OpBExtend or InputTypeB");
- EVT InputEVT = EVT::getEVT(InputTypeA);
- EVT AccumEVT = EVT::getEVT(AccumType);
+ bool IsUSDot = OpBExtend != TTI::PR_None && OpAExtend != OpBExtend;
+ if (IsUSDot && !ST->hasMatMulInt8())
+ return Invalid;
+
+ unsigned Ratio =
+ AccumType->getScalarSizeInBits() / InputTypeA->getScalarSizeInBits();
+ if (VF.getKnownMinValue() <= Ratio)
+ return Invalid;
+
+ VectorType *InputVectorType = VectorType::get(InputTypeA, VF);
+ VectorType *AccumVectorType =
+ VectorType::get(AccumType, VF.divideCoefficientBy(Ratio));
+ // We don't yet support all kinds of legalization.
+ auto TA = TLI->getTypeAction(AccumVectorType->getContext(),
+ EVT::getEVT(AccumVectorType));
+ switch (TA) {
+ default:
+ return Invalid;
+ case TargetLowering::TypeLegal:
+ case TargetLowering::TypePromoteInteger:
+ case TargetLowering::TypeSplitVector:
+ break;
+ }
+
+ // Check what kind of type-legalisation happens.
+ std::pair<InstructionCost, MVT> AccumLT =
+ getTypeLegalizationCost(AccumVectorType);
+ std::pair<InstructionCost, MVT> InputLT =
+ getTypeLegalizationCost(InputVectorType);
- unsigned VFMinValue = VF.getKnownMinValue();
+ InstructionCost Cost = InputLT.first * TTI::TCC_Basic;
- if (VF.isScalable()) {
- if (!ST->isSVEorStreamingSVEAvailable())
- return Invalid;
+ // Prefer using full types by costing half-full input types as more expensive.
+ if (TypeSize::isKnownLT(InputVectorType->getPrimitiveSizeInBits(),
+ TypeSize::getScalable(128)))
+ // FIXME: This can be removed after the cost of the extends are folded into
+ // the dot-product expression in VPlan, after landing:
+ // https://github.com/llvm/llvm-project/pull/147302
+ Cost *= 2;
- // Don't accept a partial reduction if the scaled accumulator is vscale x 1,
- // since we can't lower that type.
- unsigned Scale =
- AccumEVT.getScalarSizeInBits() / InputEVT.getScalarSizeInBits();
- if (VFMinValue == Scale)
- return Invalid;
+ if (ST->isSVEorStreamingSVEAvailable() && !IsUSDot) {
+ // i16 -> i64 is natively supported for udot/sdot
+ if (AccumLT.second.getScalarType() == MVT::i64 &&
+ InputLT.second.getScalarType() == MVT::i16)
+ return Cost;
+ // i8 -> i64 is supported with an extra level of extends
+ if (AccumLT.second.getScalarType() == MVT::i64 &&
+ InputLT.second.getScalarType() == MVT::i8)
+ // FIXME: This cost should probably be a little higher, e.g. Cost + 2
+ // because it requires two extra extends on the inputs. But if we'd change
+ // that now, a regular reduction would be cheaper because the costs of
+ // the extends in the IR are still counted. This can be fixed
+ // after https://github.com/llvm/llvm-project/pull/147302 has landed.
+ return Cost;
}
- if (VF.isFixed() &&
- (!ST->isNeonAvailable() || !ST->hasDotProd() || AccumEVT == MVT::i64))
- return Invalid;
- if (InputEVT == MVT::i8) {
- switch (VFMinValue) {
- default:
- return Invalid;
- case 8:
- if (AccumEVT == MVT::i32)
- Cost *= 2;
- else if (AccumEVT != MVT::i64)
- return Invalid;
- break;
- case 16:
- if (AccumEVT == MVT::i64)
- Cost *= 2;
- else if (AccumEVT != MVT::i32)
- return Invalid;
- break;
- }
- } else if (InputEVT == MVT::i16) {
- // FIXME: Allow i32 accumulator but increase cost, as we would extend
- // it to i64.
- if (VFMinValue != 8 || AccumEVT != MVT::i64)
- return Invalid;
- } else
- return Invalid;
+ // i8 -> i32 is natively supported for udot/sdot/usdot, both for NEON and SVE.
+ if (ST->isSVEorStreamingSVEAvailable() ||
+ (AccumLT.second.isFixedLengthVector() && ST->isNeonAvailable() &&
+ ST->hasDotProd())) {
+ if (AccumLT.second.getScalarType() == MVT::i32 &&
+ InputLT.second.getScalarType() == MVT::i8)
+ return Cost;
+ }
- return Cost;
+ // Add additional cost for the extends that would need to be inserted.
+ return Cost + 4;
}
InstructionCost