diff options
Diffstat (limited to 'llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp')
-rw-r--r-- | llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 42 |
1 files changed, 35 insertions, 7 deletions
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index ed4e2b1..3487e81 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -97,6 +97,12 @@ static cl::opt<MatrixLayoutTy> MatrixLayout( static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt", cl::init(false)); +static cl::opt<unsigned> SplitMatmulRemainderOverThreshold( + "matrix-split-matmul-remainder-over-threshold", cl::Hidden, + cl::desc("Illegal remainder vectors over this size in bits should be split " + "in the inner loop of matmul"), + cl::init(0)); + /// Helper function to either return Scope, if it is a subprogram or the /// attached subprogram for a local scope. static DISubprogram *getSubprogram(DIScope *Scope) { @@ -1720,6 +1726,31 @@ public: ToRemove.push_back(MatMul); } + /// Given \p Remainder iterations of the the matmul inner loop, + /// potentially lower \p Blocksize that is used for the underlying + /// vector. + unsigned capBlockSize(unsigned BlockSize, unsigned Remainder, Type *EltType) { + if (BlockSize <= Remainder) + return BlockSize; + + // If the remainder is also a legal type just use it. + auto *VecTy = FixedVectorType::get(EltType, Remainder); + if (TTI.isTypeLegal(VecTy)) + return Remainder; + + // Similarly, if the vector is small enough that we don't want + // to split further. + if (VecTy->getPrimitiveSizeInBits() <= SplitMatmulRemainderOverThreshold) + return Remainder; + + // Gradually lower the vectorization factor to cover the + // remainder. + do { + BlockSize /= 2; + } while (BlockSize > Remainder); + return BlockSize; + } + /// Compute \p Result += \p A * \p B for input matrices with left-associating /// addition. /// @@ -1757,10 +1788,8 @@ public: bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J)); for (unsigned I = 0; I < R; I += BlockSize) { - // Gradually lower the vectorization factor to cover the remainder. - while (I + BlockSize > R) - BlockSize /= 2; - + // Lower block size to make sure we stay within bounds. + BlockSize = capBlockSize(BlockSize, R - I, Result.getElementType()); Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder) : nullptr; for (unsigned K = 0; K < M; ++K) { @@ -1785,9 +1814,8 @@ public: unsigned BlockSize = VF; bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I)); for (unsigned J = 0; J < C; J += BlockSize) { - // Gradually lower the vectorization factor to cover the remainder. - while (J + BlockSize > C) - BlockSize /= 2; + // Lower the vectorization factor to cover the remainder. + BlockSize = capBlockSize(BlockSize, C - J, Result.getElementType()); Value *Sum = nullptr; for (unsigned K = 0; K < M; ++K) { |