diff options
Diffstat (limited to 'llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp')
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp | 121 |
1 files changed, 70 insertions, 51 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index 8c4b4f6..50a8754 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -5632,75 +5632,94 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost( TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp, TTI::TargetCostKind CostKind) const { InstructionCost Invalid = InstructionCost::getInvalid(); - InstructionCost Cost(TTI::TCC_Basic); if (CostKind != TTI::TCK_RecipThroughput) return Invalid; - // Sub opcodes currently only occur in chained cases. - // Independent partial reduction subtractions are still costed as an add + if (VF.isFixed() && !ST->isSVEorStreamingSVEAvailable() && + (!ST->isNeonAvailable() || !ST->hasDotProd())) + return Invalid; + if ((Opcode != Instruction::Add && Opcode != Instruction::Sub) || OpAExtend == TTI::PR_None) return Invalid; + assert((BinOp || (OpBExtend == TTI::PR_None && !InputTypeB)) && + (!BinOp || (OpBExtend != TTI::PR_None && InputTypeB)) && + "Unexpected values for OpBExtend or InputTypeB"); + // We only support multiply binary operations for now, and for muls we // require the types being extended to be the same. - // NOTE: For muls AArch64 supports lowering mixed extensions to a usdot but - // only if the i8mm or sve/streaming features are available. - if (BinOp && (*BinOp != Instruction::Mul || InputTypeA != InputTypeB || - OpBExtend == TTI::PR_None || - (OpAExtend != OpBExtend && !ST->hasMatMulInt8() && - !ST->isSVEorStreamingSVEAvailable()))) + if (BinOp && (*BinOp != Instruction::Mul || InputTypeA != InputTypeB)) return Invalid; - assert((BinOp || (OpBExtend == TTI::PR_None && !InputTypeB)) && - "Unexpected values for OpBExtend or InputTypeB"); - EVT InputEVT = EVT::getEVT(InputTypeA); - EVT AccumEVT = EVT::getEVT(AccumType); + bool IsUSDot = OpBExtend != TTI::PR_None && OpAExtend != OpBExtend; + if (IsUSDot && !ST->hasMatMulInt8()) + return Invalid; + + unsigned Ratio = + AccumType->getScalarSizeInBits() / InputTypeA->getScalarSizeInBits(); + if (VF.getKnownMinValue() <= Ratio) + return Invalid; + + VectorType *InputVectorType = VectorType::get(InputTypeA, VF); + VectorType *AccumVectorType = + VectorType::get(AccumType, VF.divideCoefficientBy(Ratio)); + // We don't yet support all kinds of legalization. + auto TA = TLI->getTypeAction(AccumVectorType->getContext(), + EVT::getEVT(AccumVectorType)); + switch (TA) { + default: + return Invalid; + case TargetLowering::TypeLegal: + case TargetLowering::TypePromoteInteger: + case TargetLowering::TypeSplitVector: + break; + } + + // Check what kind of type-legalisation happens. + std::pair<InstructionCost, MVT> AccumLT = + getTypeLegalizationCost(AccumVectorType); + std::pair<InstructionCost, MVT> InputLT = + getTypeLegalizationCost(InputVectorType); - unsigned VFMinValue = VF.getKnownMinValue(); + InstructionCost Cost = InputLT.first * TTI::TCC_Basic; - if (VF.isScalable()) { - if (!ST->isSVEorStreamingSVEAvailable()) - return Invalid; + // Prefer using full types by costing half-full input types as more expensive. + if (TypeSize::isKnownLT(InputVectorType->getPrimitiveSizeInBits(), + TypeSize::getScalable(128))) + // FIXME: This can be removed after the cost of the extends are folded into + // the dot-product expression in VPlan, after landing: + // https://github.com/llvm/llvm-project/pull/147302 + Cost *= 2; - // Don't accept a partial reduction if the scaled accumulator is vscale x 1, - // since we can't lower that type. - unsigned Scale = - AccumEVT.getScalarSizeInBits() / InputEVT.getScalarSizeInBits(); - if (VFMinValue == Scale) - return Invalid; + if (ST->isSVEorStreamingSVEAvailable() && !IsUSDot) { + // i16 -> i64 is natively supported for udot/sdot + if (AccumLT.second.getScalarType() == MVT::i64 && + InputLT.second.getScalarType() == MVT::i16) + return Cost; + // i8 -> i64 is supported with an extra level of extends + if (AccumLT.second.getScalarType() == MVT::i64 && + InputLT.second.getScalarType() == MVT::i8) + // FIXME: This cost should probably be a little higher, e.g. Cost + 2 + // because it requires two extra extends on the inputs. But if we'd change + // that now, a regular reduction would be cheaper because the costs of + // the extends in the IR are still counted. This can be fixed + // after https://github.com/llvm/llvm-project/pull/147302 has landed. + return Cost; } - if (VF.isFixed() && - (!ST->isNeonAvailable() || !ST->hasDotProd() || AccumEVT == MVT::i64)) - return Invalid; - if (InputEVT == MVT::i8) { - switch (VFMinValue) { - default: - return Invalid; - case 8: - if (AccumEVT == MVT::i32) - Cost *= 2; - else if (AccumEVT != MVT::i64) - return Invalid; - break; - case 16: - if (AccumEVT == MVT::i64) - Cost *= 2; - else if (AccumEVT != MVT::i32) - return Invalid; - break; - } - } else if (InputEVT == MVT::i16) { - // FIXME: Allow i32 accumulator but increase cost, as we would extend - // it to i64. - if (VFMinValue != 8 || AccumEVT != MVT::i64) - return Invalid; - } else - return Invalid; + // i8 -> i32 is natively supported for udot/sdot/usdot, both for NEON and SVE. + if (ST->isSVEorStreamingSVEAvailable() || + (AccumLT.second.isFixedLengthVector() && ST->isNeonAvailable() && + ST->hasDotProd())) { + if (AccumLT.second.getScalarType() == MVT::i32 && + InputLT.second.getScalarType() == MVT::i8) + return Cost; + } - return Cost; + // Add additional cost for the extends that would need to be inserted. + return Cost + 4; } InstructionCost |