diff options
author | Iman Hosseini <imanhosseini@google.com> | 2025-01-17 14:40:31 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-01-17 14:40:31 +0000 |
commit | 8ae1cb2bcb55293cce31bb75c38d6b4e8a13cc23 (patch) | |
tree | 434c76e3876d25161af7c99928acc9fc5b56e216 | |
parent | a18f4bdb18d59858e384540a62c9145c888cc9b2 (diff) | |
download | llvm-8ae1cb2bcb55293cce31bb75c38d6b4e8a13cc23.zip llvm-8ae1cb2bcb55293cce31bb75c38d6b4e8a13cc23.tar.gz llvm-8ae1cb2bcb55293cce31bb75c38d6b4e8a13cc23.tar.bz2 |
add power function to APInt (#122788)
I am trying to calculate power function for APFloat, APInt to constant
fold vector reductions: https://github.com/llvm/llvm-project/pull/122450
I need this utility to fold N `mul`s into power.
---------
Co-authored-by: ImanHosseini <imanhosseini.17@gmail.com>
Co-authored-by: Jakub Kuderski <kubakuderski@gmail.com>
-rw-r--r-- | llvm/include/llvm/ADT/APInt.h | 4 | ||||
-rw-r--r-- | llvm/lib/Support/APInt.cpp | 18 | ||||
-rw-r--r-- | llvm/unittests/ADT/APIntTest.cpp | 67 |
3 files changed, 89 insertions, 0 deletions
diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h index 225390f..02d58d8 100644 --- a/llvm/include/llvm/ADT/APInt.h +++ b/llvm/include/llvm/ADT/APInt.h @@ -2263,6 +2263,10 @@ APInt mulhs(const APInt &C1, const APInt &C2); /// Returns the high N bits of the multiplication result. APInt mulhu(const APInt &C1, const APInt &C2); +/// Compute X^N for N>=0. +/// 0^0 is supported and returns 1. +APInt pow(const APInt &X, int64_t N); + /// Compute GCD of two unsigned APInt values. /// /// This function returns the greatest common divisor of the two APInt values diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp index ea8295f..38cf485 100644 --- a/llvm/lib/Support/APInt.cpp +++ b/llvm/lib/Support/APInt.cpp @@ -3108,3 +3108,21 @@ APInt APIntOps::mulhu(const APInt &C1, const APInt &C2) { APInt C2Ext = C2.zext(FullWidth); return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth()); } + +APInt APIntOps::pow(const APInt &X, int64_t N) { + assert(N >= 0 && "negative exponents not supported."); + APInt Acc = APInt(X.getBitWidth(), 1); + if (N == 0) + return Acc; + APInt Base = X; + int64_t RemainingExponent = N; + while (RemainingExponent > 0) { + while (RemainingExponent % 2 == 0) { + Base *= Base; + RemainingExponent /= 2; + } + --RemainingExponent; + Acc *= Base; + } + return Acc; +}; diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp index 4d5553f..b14366e 100644 --- a/llvm/unittests/ADT/APIntTest.cpp +++ b/llvm/unittests/ADT/APIntTest.cpp @@ -29,6 +29,73 @@ TEST(APIntTest, ValueInit) { EXPECT_TRUE(!Zero.sext(64)); } +// Test that 0^5 == 0 +TEST(APIntTest, PowZeroTo5) { + APInt Zero = APInt::getZero(32); + EXPECT_TRUE(!Zero); + APInt ZeroTo5 = APIntOps::pow(Zero, 5); + EXPECT_TRUE(!ZeroTo5); +} + +// Test that 1^16 == 1 +TEST(APIntTest, PowOneTo16) { + APInt One(32, 1); + APInt OneTo16 = APIntOps::pow(One, 16); + EXPECT_EQ(One, OneTo16); +} + +// Test that 2^10 == 1024 +TEST(APIntTest, PowerTwoTo10) { + APInt Two(32, 2); + APInt TwoTo20 = APIntOps::pow(Two, 10); + APInt V_1024(32, 1024); + EXPECT_EQ(TwoTo20, V_1024); +} + +// Test that 3^3 == 27 +TEST(APIntTest, PowerThreeTo3) { + APInt Three(32, 3); + APInt ThreeTo3 = APIntOps::pow(Three, 3); + APInt V_27(32, 27); + EXPECT_EQ(ThreeTo3, V_27); +} + +// Test that SignedMaxValue^3 == SignedMaxValue +TEST(APIntTest, PowerSignedMaxValue) { + APInt SignedMaxValue = APInt::getSignedMaxValue(32); + APInt MaxTo3 = APIntOps::pow(SignedMaxValue, 3); + EXPECT_EQ(MaxTo3, SignedMaxValue); +} + +// Test that MaxValue^3 == MaxValue +TEST(APIntTest, PowerMaxValue) { + APInt MaxValue = APInt::getMaxValue(32); + APInt MaxTo3 = APIntOps::pow(MaxValue, 3); + EXPECT_EQ(MaxValue, MaxTo3); +} + +// Test that SignedMinValue^3 == 0 +TEST(APIntTest, PowerSignedMinValueTo3) { + APInt SignedMinValue = APInt::getSignedMinValue(32); + APInt MinTo3 = APIntOps::pow(SignedMinValue, 3); + EXPECT_TRUE(MinTo3.isZero()); +} + +// Test that SignedMinValue^1 == SignedMinValue +TEST(APIntTest, PowerSignedMinValueTo1) { + APInt SignedMinValue = APInt::getSignedMinValue(32); + APInt MinTo1 = APIntOps::pow(SignedMinValue, 1); + EXPECT_EQ(SignedMinValue, MinTo1); +} + +// Test that MaxValue^3 == MaxValue +TEST(APIntTest, ZeroToZero) { + APInt Zero = APInt::getZero(32); + APInt One(32, 1); + APInt ZeroToZero = APIntOps::pow(Zero, 0); + EXPECT_EQ(ZeroToZero, One); +} + // Test that APInt shift left works when bitwidth > 64 and shiftamt == 0 TEST(APIntTest, ShiftLeftByZero) { APInt One = APInt::getZero(65) + 1; |