diff options
Diffstat (limited to 'llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp')
-rw-r--r-- | llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 59 |
1 files changed, 44 insertions, 15 deletions
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 7cae94eb..3487e81 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -97,6 +97,12 @@ static cl::opt<MatrixLayoutTy> MatrixLayout( static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt", cl::init(false)); +static cl::opt<unsigned> SplitMatmulRemainderOverThreshold( + "matrix-split-matmul-remainder-over-threshold", cl::Hidden, + cl::desc("Illegal remainder vectors over this size in bits should be split " + "in the inner loop of matmul"), + cl::init(0)); + /// Helper function to either return Scope, if it is a subprogram or the /// attached subprogram for a local scope. static DISubprogram *getSubprogram(DIScope *Scope) { @@ -115,18 +121,16 @@ static bool isSplat(Value *V) { /// Match any mul operation (fp or integer). template <typename LTy, typename RTy> -auto m_AnyMul(const LTy &L, const RTy &R) { +static auto m_AnyMul(const LTy &L, const RTy &R) { return m_CombineOr(m_Mul(L, R), m_FMul(L, R)); } /// Match any add operation (fp or integer). template <typename LTy, typename RTy> -auto m_AnyAdd(const LTy &L, const RTy &R) { +static auto m_AnyAdd(const LTy &L, const RTy &R) { return m_CombineOr(m_Add(L, R), m_FAdd(L, R)); } -namespace { - // Given an element pointer \p BasePtr to the start of a (sub) matrix, compute // the start address of vector \p VecIdx with type (\p EltType x \p NumElements) // assuming \p Stride elements between start two consecutive vectors. @@ -167,9 +171,9 @@ namespace { // v_2_0 |v_2_1 |v_2_2 |v_2_3 // v_3_0 {v_3_1 {v_3_2 v_3_3 // -Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride, - unsigned NumElements, Type *EltType, - IRBuilder<> &Builder) { +static Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride, + unsigned NumElements, Type *EltType, + IRBuilder<> &Builder) { assert((!isa<ConstantInt>(Stride) || cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) && @@ -338,6 +342,8 @@ computeShapeInfoForInst(Instruction *I, return std::nullopt; } +namespace { + /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics. /// /// Currently, the lowering for each matrix intrinsic is done as follows: @@ -371,7 +377,8 @@ class LowerMatrixIntrinsics { LoopInfo *LI = nullptr; OptimizationRemarkEmitter *ORE = nullptr; - /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation. + /// Contains estimates of the number of operations (loads, stores, compute) + /// required to lower a matrix operation. struct OpInfoTy { /// Number of stores emitted to generate this matrix. unsigned NumStores = 0; @@ -1719,6 +1726,31 @@ public: ToRemove.push_back(MatMul); } + /// Given \p Remainder iterations of the the matmul inner loop, + /// potentially lower \p Blocksize that is used for the underlying + /// vector. + unsigned capBlockSize(unsigned BlockSize, unsigned Remainder, Type *EltType) { + if (BlockSize <= Remainder) + return BlockSize; + + // If the remainder is also a legal type just use it. + auto *VecTy = FixedVectorType::get(EltType, Remainder); + if (TTI.isTypeLegal(VecTy)) + return Remainder; + + // Similarly, if the vector is small enough that we don't want + // to split further. + if (VecTy->getPrimitiveSizeInBits() <= SplitMatmulRemainderOverThreshold) + return Remainder; + + // Gradually lower the vectorization factor to cover the + // remainder. + do { + BlockSize /= 2; + } while (BlockSize > Remainder); + return BlockSize; + } + /// Compute \p Result += \p A * \p B for input matrices with left-associating /// addition. /// @@ -1756,10 +1788,8 @@ public: bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J)); for (unsigned I = 0; I < R; I += BlockSize) { - // Gradually lower the vectorization factor to cover the remainder. - while (I + BlockSize > R) - BlockSize /= 2; - + // Lower block size to make sure we stay within bounds. + BlockSize = capBlockSize(BlockSize, R - I, Result.getElementType()); Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder) : nullptr; for (unsigned K = 0; K < M; ++K) { @@ -1784,9 +1814,8 @@ public: unsigned BlockSize = VF; bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I)); for (unsigned J = 0; J < C; J += BlockSize) { - // Gradually lower the vectorization factor to cover the remainder. - while (J + BlockSize > C) - BlockSize /= 2; + // Lower the vectorization factor to cover the remainder. + BlockSize = capBlockSize(BlockSize, C - J, Result.getElementType()); Value *Sum = nullptr; for (unsigned K = 0; K < M; ++K) { |