diff options
author | Jon Roelofs <jonathan_roelofs@apple.com> | 2025-06-10 15:36:37 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-06-10 12:36:37 -0700 |
commit | 8345d62478054d4ab97c6f28cfea6d1ecca837da (patch) | |
tree | 234a323abd2b52204dd44061be763e838c19e2f4 | |
parent | 77da1257b61c728a4d35dc518bfb758d0b1ddf26 (diff) | |
download | llvm-8345d62478054d4ab97c6f28cfea6d1ecca837da.zip llvm-8345d62478054d4ab97c6f28cfea6d1ecca837da.tar.gz llvm-8345d62478054d4ab97c6f28cfea6d1ecca837da.tar.bz2 |
[Matrix] Hoist finalizeLowering into caller. NFC (#143038)
-rw-r--r-- | llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 144 |
1 files changed, 66 insertions, 78 deletions
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 2168308..eb81d2e 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -1134,26 +1134,28 @@ public: if (FusedInsts.count(Inst)) continue; - IRBuilder<> Builder(Inst); - const ShapeInfo &SI = ShapeMap.at(Inst); Value *Op1; Value *Op2; + MatrixTy Result; if (auto *BinOp = dyn_cast<BinaryOperator>(Inst)) - VisitBinaryOperator(BinOp, SI); + Result = VisitBinaryOperator(BinOp, SI); else if (auto *Cast = dyn_cast<CastInst>(Inst)) - VisitCastInstruction(Cast, SI); + Result = VisitCastInstruction(Cast, SI); else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst)) - VisitUnaryOperator(UnOp, SI); - else if (IntrinsicInst *Intr = dyn_cast<IntrinsicInst>(Inst)) - VisitIntrinsicInst(Intr, SI); + Result = VisitUnaryOperator(UnOp, SI); + else if (auto *Intr = dyn_cast<IntrinsicInst>(Inst)) + Result = VisitIntrinsicInst(Intr, SI); else if (match(Inst, m_Load(m_Value(Op1)))) - VisitLoad(cast<LoadInst>(Inst), SI, Op1, Builder); + Result = VisitLoad(cast<LoadInst>(Inst), SI, Op1); else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2)))) - VisitStore(cast<StoreInst>(Inst), SI, Op1, Op2, Builder); + Result = VisitStore(cast<StoreInst>(Inst), SI, Op1, Op2); else continue; + + IRBuilder<> Builder(Inst); + finalizeLowering(Inst, Result, Builder); Changed = true; } @@ -1193,25 +1195,24 @@ public: } /// Replace intrinsic calls. - void VisitIntrinsicInst(IntrinsicInst *Inst, const ShapeInfo &Shape) { - switch (Inst->getIntrinsicID()) { + MatrixTy VisitIntrinsicInst(IntrinsicInst *Inst, const ShapeInfo &SI) { + assert(Inst->getCalledFunction() && + Inst->getCalledFunction()->isIntrinsic()); + + switch (Inst->getCalledFunction()->getIntrinsicID()) { case Intrinsic::matrix_multiply: - LowerMultiply(Inst); - return; + return LowerMultiply(Inst); case Intrinsic::matrix_transpose: - LowerTranspose(Inst); - return; + return LowerTranspose(Inst); case Intrinsic::matrix_column_major_load: - LowerColumnMajorLoad(Inst); - return; + return LowerColumnMajorLoad(Inst); case Intrinsic::matrix_column_major_store: - LowerColumnMajorStore(Inst); - return; + return LowerColumnMajorStore(Inst); case Intrinsic::abs: case Intrinsic::fabs: { IRBuilder<> Builder(Inst); MatrixTy Result; - MatrixTy M = getMatrix(Inst->getOperand(0), Shape, Builder); + MatrixTy M = getMatrix(Inst->getOperand(0), SI, Builder); Builder.setFastMathFlags(getFastMathFlags(Inst)); for (auto &Vector : M.vectors()) { @@ -1229,16 +1230,14 @@ public: } } - finalizeLowering(Inst, - Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * - Result.getNumVectors()), - Builder); - return; + return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * + Result.getNumVectors()); } default: - llvm_unreachable( - "only intrinsics supporting shape info should be seen here"); + break; } + llvm_unreachable( + "only intrinsics supporting shape info should be seen here"); } /// Compute the alignment for a column/row \p Idx with \p Stride between them. @@ -1304,26 +1303,24 @@ public: } /// Lower a load instruction with shape information. - void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride, - bool IsVolatile, ShapeInfo Shape) { + MatrixTy LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, + Value *Stride, bool IsVolatile, ShapeInfo Shape) { IRBuilder<> Builder(Inst); - finalizeLowering(Inst, - loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile, - Shape, Builder), - Builder); + return loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile, Shape, + Builder); } /// Lowers llvm.matrix.column.major.load. /// /// The intrinsic loads a matrix from memory using a stride between columns. - void LowerColumnMajorLoad(CallInst *Inst) { + MatrixTy LowerColumnMajorLoad(CallInst *Inst) { assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && "Intrinsic only supports column-major layout!"); Value *Ptr = Inst->getArgOperand(0); Value *Stride = Inst->getArgOperand(1); - LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride, - cast<ConstantInt>(Inst->getArgOperand(2))->isOne(), - {Inst->getArgOperand(3), Inst->getArgOperand(4)}); + return LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride, + cast<ConstantInt>(Inst->getArgOperand(2))->isOne(), + {Inst->getArgOperand(3), Inst->getArgOperand(4)}); } /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p @@ -1366,28 +1363,27 @@ public: } /// Lower a store instruction with shape information. - void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A, - Value *Stride, bool IsVolatile, ShapeInfo Shape) { + MatrixTy LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, + MaybeAlign A, Value *Stride, bool IsVolatile, + ShapeInfo Shape) { IRBuilder<> Builder(Inst); auto StoreVal = getMatrix(Matrix, Shape, Builder); - finalizeLowering(Inst, - storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride, - IsVolatile, Builder), - Builder); + return storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride, IsVolatile, + Builder); } /// Lowers llvm.matrix.column.major.store. /// /// The intrinsic store a matrix back memory using a stride between columns. - void LowerColumnMajorStore(CallInst *Inst) { + MatrixTy LowerColumnMajorStore(CallInst *Inst) { assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && "Intrinsic only supports column-major layout!"); Value *Matrix = Inst->getArgOperand(0); Value *Ptr = Inst->getArgOperand(1); Value *Stride = Inst->getArgOperand(2); - LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride, - cast<ConstantInt>(Inst->getArgOperand(3))->isOne(), - {Inst->getArgOperand(4), Inst->getArgOperand(5)}); + return LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride, + cast<ConstantInt>(Inst->getArgOperand(3))->isOne(), + {Inst->getArgOperand(4), Inst->getArgOperand(5)}); } // Set elements I..I+NumElts-1 to Block @@ -2162,7 +2158,7 @@ public: } /// Lowers llvm.matrix.multiply. - void LowerMultiply(CallInst *MatMul) { + MatrixTy LowerMultiply(CallInst *MatMul) { IRBuilder<> Builder(MatMul); auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType(); ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); @@ -2184,11 +2180,11 @@ public: emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false, getFastMathFlags(MatMul)); - finalizeLowering(MatMul, Result, Builder); + return Result; } /// Lowers llvm.matrix.transpose. - void LowerTranspose(CallInst *Inst) { + MatrixTy LowerTranspose(CallInst *Inst) { MatrixTy Result; IRBuilder<> Builder(Inst); Value *InputVal = Inst->getArgOperand(0); @@ -2218,28 +2214,26 @@ public: // TODO: Improve estimate of operations needed for transposes. Currently we // just count the insertelement/extractelement instructions, but do not // account for later simplifications/combines. - finalizeLowering( - Inst, - Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns) - .addNumExposedTransposes(1), - Builder); + return Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns) + .addNumExposedTransposes(1); } /// Lower load instructions. - void VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr, - IRBuilder<> &Builder) { - LowerLoad(Inst, Ptr, Inst->getAlign(), Builder.getInt64(SI.getStride()), - Inst->isVolatile(), SI); + MatrixTy VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr) { + IRBuilder<> Builder(Inst); + return LowerLoad(Inst, Ptr, Inst->getAlign(), + Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI); } - void VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal, - Value *Ptr, IRBuilder<> &Builder) { - LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(), - Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI); + MatrixTy VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal, + Value *Ptr) { + IRBuilder<> Builder(Inst); + return LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(), + Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI); } /// Lower binary operators. - void VisitBinaryOperator(BinaryOperator *Inst, const ShapeInfo &SI) { + MatrixTy VisitBinaryOperator(BinaryOperator *Inst, const ShapeInfo &SI) { Value *Lhs = Inst->getOperand(0); Value *Rhs = Inst->getOperand(1); @@ -2258,14 +2252,12 @@ public: Result.addVector(Builder.CreateBinOp(Inst->getOpcode(), A.getVector(I), B.getVector(I))); - finalizeLowering(Inst, - Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * - Result.getNumVectors()), - Builder); + return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * + Result.getNumVectors()); } /// Lower unary operators. - void VisitUnaryOperator(UnaryOperator *Inst, const ShapeInfo &SI) { + MatrixTy VisitUnaryOperator(UnaryOperator *Inst, const ShapeInfo &SI) { Value *Op = Inst->getOperand(0); IRBuilder<> Builder(Inst); @@ -2288,14 +2280,12 @@ public: for (unsigned I = 0; I < SI.getNumVectors(); ++I) Result.addVector(BuildVectorOp(M.getVector(I))); - finalizeLowering(Inst, - Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * - Result.getNumVectors()), - Builder); + return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * + Result.getNumVectors()); } /// Lower cast instructions. - void VisitCastInstruction(CastInst *Inst, const ShapeInfo &Shape) { + MatrixTy VisitCastInstruction(CastInst *Inst, const ShapeInfo &Shape) { Value *Op = Inst->getOperand(0); IRBuilder<> Builder(Inst); @@ -2312,10 +2302,8 @@ public: for (auto &Vector : M.vectors()) Result.addVector(Builder.CreateCast(Inst->getOpcode(), Vector, NewVTy)); - finalizeLowering(Inst, - Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * - Result.getNumVectors()), - Builder); + return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) * + Result.getNumVectors()); } /// Helper to linearize a matrix expression tree into a string. Currently |