aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp')
-rw-r--r--llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp42
1 files changed, 35 insertions, 7 deletions
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index ed4e2b1..3487e81 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -97,6 +97,12 @@ static cl::opt<MatrixLayoutTy> MatrixLayout(
static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
cl::init(false));
+static cl::opt<unsigned> SplitMatmulRemainderOverThreshold(
+ "matrix-split-matmul-remainder-over-threshold", cl::Hidden,
+ cl::desc("Illegal remainder vectors over this size in bits should be split "
+ "in the inner loop of matmul"),
+ cl::init(0));
+
/// Helper function to either return Scope, if it is a subprogram or the
/// attached subprogram for a local scope.
static DISubprogram *getSubprogram(DIScope *Scope) {
@@ -1720,6 +1726,31 @@ public:
ToRemove.push_back(MatMul);
}
+ /// Given \p Remainder iterations of the the matmul inner loop,
+ /// potentially lower \p Blocksize that is used for the underlying
+ /// vector.
+ unsigned capBlockSize(unsigned BlockSize, unsigned Remainder, Type *EltType) {
+ if (BlockSize <= Remainder)
+ return BlockSize;
+
+ // If the remainder is also a legal type just use it.
+ auto *VecTy = FixedVectorType::get(EltType, Remainder);
+ if (TTI.isTypeLegal(VecTy))
+ return Remainder;
+
+ // Similarly, if the vector is small enough that we don't want
+ // to split further.
+ if (VecTy->getPrimitiveSizeInBits() <= SplitMatmulRemainderOverThreshold)
+ return Remainder;
+
+ // Gradually lower the vectorization factor to cover the
+ // remainder.
+ do {
+ BlockSize /= 2;
+ } while (BlockSize > Remainder);
+ return BlockSize;
+ }
+
/// Compute \p Result += \p A * \p B for input matrices with left-associating
/// addition.
///
@@ -1757,10 +1788,8 @@ public:
bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));
for (unsigned I = 0; I < R; I += BlockSize) {
- // Gradually lower the vectorization factor to cover the remainder.
- while (I + BlockSize > R)
- BlockSize /= 2;
-
+ // Lower block size to make sure we stay within bounds.
+ BlockSize = capBlockSize(BlockSize, R - I, Result.getElementType());
Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
: nullptr;
for (unsigned K = 0; K < M; ++K) {
@@ -1785,9 +1814,8 @@ public:
unsigned BlockSize = VF;
bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
for (unsigned J = 0; J < C; J += BlockSize) {
- // Gradually lower the vectorization factor to cover the remainder.
- while (J + BlockSize > C)
- BlockSize /= 2;
+ // Lower the vectorization factor to cover the remainder.
+ BlockSize = capBlockSize(BlockSize, C - J, Result.getElementType());
Value *Sum = nullptr;
for (unsigned K = 0; K < M; ++K) {