diff options
Diffstat (limited to 'llvm/lib/Target/RISCV')
-rw-r--r-- | llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp | 39 |
1 files changed, 39 insertions, 0 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp index 563f3bb..d4124ae 100644 --- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp @@ -167,6 +167,42 @@ static bool canUseShiftPair(Instruction *Inst, const APInt &Imm) { return false; } +// If this is i64 AND is part of (X & -(1 << C1) & 0xffffffff) == C2 << C1), +// DAGCombiner can convert this to (sraiw X, C1) == sext(C2) for RV64. On RV32, +// the type will be split so only the lower 32 bits need to be compared using +// (srai/srli X, C) == C2. +static bool canUseShiftCmp(Instruction *Inst, const APInt &Imm) { + if (!Inst->hasOneUse()) + return false; + + // Look for equality comparison. + auto *Cmp = dyn_cast<ICmpInst>(*Inst->user_begin()); + if (!Cmp || !Cmp->isEquality()) + return false; + + // Right hand side of comparison should be a constant. + auto *C = dyn_cast<ConstantInt>(Cmp->getOperand(1)); + if (!C) + return false; + + uint64_t Mask = Imm.getZExtValue(); + + // Mask should be of the form -(1 << C) in the lower 32 bits. + if (!isUInt<32>(Mask) || !isPowerOf2_32(-uint32_t(Mask))) + return false; + + // Comparison constant should be a subset of Mask. + uint64_t CmpC = C->getZExtValue(); + if ((CmpC & Mask) != CmpC) + return false; + + // We'll need to sign extend the comparison constant and shift it right. Make + // sure the new constant can use addi/xori+seqz/snez. + unsigned ShiftBits = llvm::countr_zero(Mask); + int64_t NewCmpC = SignExtend64<32>(CmpC) >> ShiftBits; + return NewCmpC >= -2048 && NewCmpC <= 2048; +} + InstructionCost RISCVTTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx, const APInt &Imm, Type *Ty, TTI::TargetCostKind CostKind, @@ -224,6 +260,9 @@ InstructionCost RISCVTTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx, if (Inst && Idx == 1 && Imm.getBitWidth() <= ST->getXLen() && canUseShiftPair(Inst, Imm)) return TTI::TCC_Free; + if (Inst && Idx == 1 && Imm.getBitWidth() == 64 && + canUseShiftCmp(Inst, Imm)) + return TTI::TCC_Free; Takes12BitImm = true; break; case Instruction::Add: |