Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp | 108
1 file changed, 108 insertions, 0 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 12fdc1e..70d1d6a 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -390,6 +390,111 @@ static Optional<Instruction *> instCombineSVEDup(InstCombiner &IC,
   return IC.replaceInstUsesWith(II, Insert);
 }
 
+static Optional<Instruction *> instCombineSVECmpNE(InstCombiner &IC,
+                                                   IntrinsicInst &II) {
+  LLVMContext &Ctx = II.getContext();
+  IRBuilder<> Builder(Ctx);
+  Builder.SetInsertPoint(&II);
+
+  // Check that the predicate is all active
+  auto *Pg = dyn_cast<IntrinsicInst>(II.getArgOperand(0));
+  if (!Pg || Pg->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
+    return None;
+
+  const auto PTruePattern =
+      cast<ConstantInt>(Pg->getOperand(0))->getZExtValue();
+  if (PTruePattern != AArch64SVEPredPattern::all)
+    return None;
+
+  // Check that we have a compare of zero..
+  auto *DupX = dyn_cast<IntrinsicInst>(II.getArgOperand(2));
+  if (!DupX || DupX->getIntrinsicID() != Intrinsic::aarch64_sve_dup_x)
+    return None;
+
+  auto *DupXArg = dyn_cast<ConstantInt>(DupX->getArgOperand(0));
+  if (!DupXArg || !DupXArg->isZero())
+    return None;
+
+  // ..against a dupq
+  auto *DupQLane = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
+  if (!DupQLane ||
+      DupQLane->getIntrinsicID() != Intrinsic::aarch64_sve_dupq_lane)
+    return None;
+
+  // Where the dupq is a lane 0 replicate of a vector insert
+  if (!cast<ConstantInt>(DupQLane->getArgOperand(1))->isZero())
+    return None;
+
+  auto *VecIns = dyn_cast<IntrinsicInst>(DupQLane->getArgOperand(0));
+  if (!VecIns ||
+      VecIns->getIntrinsicID() != Intrinsic::experimental_vector_insert)
+    return None;
+
+  // Where the vector insert is a fixed constant vector insert into undef at
+  // index zero
+  if (!isa<UndefValue>(VecIns->getArgOperand(0)))
+    return None;
+
+  if (!cast<ConstantInt>(VecIns->getArgOperand(2))->isZero())
+    return None;
+
+  auto *ConstVec = dyn_cast<Constant>(VecIns->getArgOperand(1));
+  if (!ConstVec)
+    return None;
+
+  auto *VecTy = dyn_cast<FixedVectorType>(ConstVec->getType());
+  auto *OutTy = dyn_cast<ScalableVectorType>(II.getType());
+  if (!VecTy || !OutTy || VecTy->getNumElements() != OutTy->getMinNumElements())
+    return None;
+
+  unsigned NumElts = VecTy->getNumElements();
+  unsigned PredicateBits = 0;
+
+  // Expand intrinsic operands to a 16-bit byte level predicate
+  for (unsigned I = 0; I < NumElts; ++I) {
+    auto *Arg = dyn_cast<ConstantInt>(ConstVec->getAggregateElement(I));
+    if (!Arg)
+      return None;
+    if (!Arg->isZero())
+      PredicateBits |= 1 << (I * (16 / NumElts));
+  }
+
+  // If all bits are zero bail early with an empty predicate
+  if (PredicateBits == 0) {
+    auto *PFalse = Constant::getNullValue(II.getType());
+    PFalse->takeName(&II);
+    return IC.replaceInstUsesWith(II, PFalse);
+  }
+
+  // Calculate largest predicate type used (where byte predicate is largest)
+  unsigned Mask = 8;
+  for (unsigned I = 0; I < 16; ++I)
+    if ((PredicateBits & (1 << I)) != 0)
+      Mask |= (I % 8);
+
+  unsigned PredSize = Mask & -Mask;
+  auto *PredType = ScalableVectorType::get(
+      Type::getInt1Ty(Ctx), AArch64::SVEBitsPerBlock / (PredSize * 8));
+
+  // Ensure all relevant bits are set
+  for (unsigned I = 0; I < 16; I += PredSize)
+    if ((PredicateBits & (1 << I)) == 0)
+      return None;
+
+  auto *PTruePat =
+      ConstantInt::get(Type::getInt32Ty(Ctx), AArch64SVEPredPattern::all);
+  auto *PTrue = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue,
+                                        {PredType}, {PTruePat});
+  auto *ConvertToSVBool = Builder.CreateIntrinsic(
+      Intrinsic::aarch64_sve_convert_to_svbool, {PredType}, {PTrue});
+  auto *ConvertFromSVBool =
+      Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool,
+                              {II.getType()}, {ConvertToSVBool});
+
+  ConvertFromSVBool->takeName(&II);
+  return IC.replaceInstUsesWith(II, ConvertFromSVBool);
+}
+
 static Optional<Instruction *> instCombineSVELast(InstCombiner &IC,
                                                   IntrinsicInst &II) {
   Value *Pg = II.getArgOperand(0);
@@ -498,6 +603,9 @@ AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
     return instCombineConvertFromSVBool(IC, II);
   case Intrinsic::aarch64_sve_dup:
     return instCombineSVEDup(IC, II);
+  case Intrinsic::aarch64_sve_cmpne:
+  case Intrinsic::aarch64_sve_cmpne_wide:
+    return instCombineSVECmpNE(IC, II);
   case Intrinsic::aarch64_sve_rdffr:
     return instCombineRDFFR(IC, II);
   case Intrinsic::aarch64_sve_lasta:
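
Aside (not part of the commit): the core of the new combine is the predicate-bit bookkeeping that decides whether the compared constant mask can be re-expressed as an all-active predicate at some element width, which the patch then materialises as a ptrue of that width reinterpreted through convert.to/from.svbool. The standalone C++ sketch below mirrors that logic outside LLVM; the function name largestPredElementBytes, the std::vector<bool> input, and the 0 return value for "no fold" are illustrative choices, not APIs from the patch.

// Minimal sketch of the PredicateBits / Mask / PredSize computation from
// instCombineSVECmpNE, under the assumptions described above.
#include <cstdio>
#include <vector>

// Returns the byte width of the widest SVE predicate element that makes the
// given fixed-length mask an all-active pattern, or 0 when no fold applies.
static unsigned largestPredElementBytes(const std::vector<bool> &EltIsNonZero) {
  unsigned NumElts = static_cast<unsigned>(EltIsNonZero.size());
  if (NumElts == 0 || NumElts > 16 || 16 % NumElts != 0)
    return 0; // guard for the standalone sketch only

  // Expand the element flags to a 16-bit, byte-granular predicate.
  unsigned PredicateBits = 0;
  for (unsigned I = 0; I < NumElts; ++I)
    if (EltIsNonZero[I])
      PredicateBits |= 1u << (I * (16 / NumElts));

  // All-false mask: the patch replaces the compare with a null predicate;
  // here we simply report that no ptrue-based fold is needed.
  if (PredicateBits == 0)
    return 0;

  // Find the widest element size consistent with the positions of set bits
  // (a byte-granular predicate is the finest, 8-byte the coarsest).
  unsigned Mask = 8;
  for (unsigned I = 0; I < 16; ++I)
    if ((PredicateBits & (1u << I)) != 0)
      Mask |= (I % 8);
  unsigned PredSize = Mask & -Mask;

  // Every lane at that granularity must be active, or the combine bails.
  for (unsigned I = 0; I < 16; I += PredSize)
    if ((PredicateBits & (1u << I)) == 0)
      return 0;

  return PredSize;
}

int main() {
  // 4-element mask, every element non-zero: PredSize == 4, matching an
  // all-active <vscale x 4 x i1> ptrue (SVEBitsPerBlock / (4 * 8) == 4).
  std::printf("%u\n", largestPredElementBytes({true, true, true, true}));  // 4
  // A gap in the mask defeats the fold.
  std::printf("%u\n", largestPredElementBytes({true, false, true, true})); // 0
}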