aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/RISCV
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/RISCV')
-rw-r--r--llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp39
1 files changed, 39 insertions, 0 deletions
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index 563f3bb..d4124ae 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -167,6 +167,42 @@ static bool canUseShiftPair(Instruction *Inst, const APInt &Imm) {
return false;
}
+// If this is i64 AND is part of (X & -(1 << C1) & 0xffffffff) == C2 << C1),
+// DAGCombiner can convert this to (sraiw X, C1) == sext(C2) for RV64. On RV32,
+// the type will be split so only the lower 32 bits need to be compared using
+// (srai/srli X, C) == C2.
+static bool canUseShiftCmp(Instruction *Inst, const APInt &Imm) {
+ if (!Inst->hasOneUse())
+ return false;
+
+ // Look for equality comparison.
+ auto *Cmp = dyn_cast<ICmpInst>(*Inst->user_begin());
+ if (!Cmp || !Cmp->isEquality())
+ return false;
+
+ // Right hand side of comparison should be a constant.
+ auto *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
+ if (!C)
+ return false;
+
+ uint64_t Mask = Imm.getZExtValue();
+
+ // Mask should be of the form -(1 << C) in the lower 32 bits.
+ if (!isUInt<32>(Mask) || !isPowerOf2_32(-uint32_t(Mask)))
+ return false;
+
+ // Comparison constant should be a subset of Mask.
+ uint64_t CmpC = C->getZExtValue();
+ if ((CmpC & Mask) != CmpC)
+ return false;
+
+ // We'll need to sign extend the comparison constant and shift it right. Make
+ // sure the new constant can use addi/xori+seqz/snez.
+ unsigned ShiftBits = llvm::countr_zero(Mask);
+ int64_t NewCmpC = SignExtend64<32>(CmpC) >> ShiftBits;
+ return NewCmpC >= -2048 && NewCmpC <= 2048;
+}
+
InstructionCost RISCVTTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx,
const APInt &Imm, Type *Ty,
TTI::TargetCostKind CostKind,
@@ -224,6 +260,9 @@ InstructionCost RISCVTTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx,
if (Inst && Idx == 1 && Imm.getBitWidth() <= ST->getXLen() &&
canUseShiftPair(Inst, Imm))
return TTI::TCC_Free;
+ if (Inst && Idx == 1 && Imm.getBitWidth() == 64 &&
+ canUseShiftCmp(Inst, Imm))
+ return TTI::TCC_Free;
Takes12BitImm = true;
break;
case Instruction::Add: