diff options
Diffstat (limited to 'llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp')
-rw-r--r-- | llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 43 |
1 files changed, 31 insertions, 12 deletions
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 56e0569..7cae94eb 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -1295,6 +1295,24 @@ public: return commonAlignment(InitialAlign, ElementSizeInBits / 8); } + IntegerType *getIndexType(Value *Ptr) const { + return cast<IntegerType>(DL.getIndexType(Ptr->getType())); + } + + Value *getIndex(Value *Ptr, uint64_t V) const { + return ConstantInt::get(getIndexType(Ptr), V); + } + + Value *castToIndexType(Value *Ptr, Value *V, IRBuilder<> &Builder) const { + assert(isa<IntegerType>(V->getType()) && + "Attempted to cast non-integral type to integer index"); + // In case the data layout's index type differs in width from the type of + // the value we're given, truncate or zero extend to the appropriate width. + // We zero extend here as indices are unsigned. + return Builder.CreateZExtOrTrunc(V, getIndexType(Ptr), + V->getName() + ".cast"); + } + /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between /// vectors. MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride, @@ -1304,6 +1322,7 @@ public: Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride()); Value *EltPtr = Ptr; MatrixTy Result; + Stride = castToIndexType(Ptr, Stride, Builder); for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) { Value *GEP = computeVectorAddr( EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I), @@ -1325,14 +1344,14 @@ public: ShapeInfo ResultShape, Type *EltTy, IRBuilder<> &Builder) { Value *Offset = Builder.CreateAdd( - Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); + Builder.CreateMul(J, getIndex(MatrixPtr, MatrixShape.getStride())), I); Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset); auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows * ResultShape.NumColumns); return loadMatrix(TileTy, TileStart, Align, - Builder.getInt64(MatrixShape.getStride()), IsVolatile, + getIndex(MatrixPtr, MatrixShape.getStride()), IsVolatile, ResultShape, Builder); } @@ -1363,14 +1382,15 @@ public: MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape, Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) { Value *Offset = Builder.CreateAdd( - Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); + Builder.CreateMul(J, getIndex(MatrixPtr, MatrixShape.getStride())), I); Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset); auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() * StoreVal.getNumColumns()); storeMatrix(TileTy, StoreVal, TileStart, MAlign, - Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder); + getIndex(MatrixPtr, MatrixShape.getStride()), IsVolatile, + Builder); } /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between @@ -1380,6 +1400,7 @@ public: IRBuilder<> &Builder) { auto *VType = cast<FixedVectorType>(Ty); Value *EltPtr = Ptr; + Stride = castToIndexType(Ptr, Stride, Builder); for (auto Vec : enumerate(StoreVal.vectors())) { Value *GEP = computeVectorAddr( EltPtr, @@ -2011,18 +2032,17 @@ public: const unsigned TileM = std::min(M - K, unsigned(TileSize)); MatrixTy A = loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(), - LShape, Builder.getInt64(I), Builder.getInt64(K), + LShape, getIndex(APtr, I), getIndex(APtr, K), {TileR, TileM}, EltType, Builder); MatrixTy B = loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(), - RShape, Builder.getInt64(K), Builder.getInt64(J), + RShape, getIndex(BPtr, K), getIndex(BPtr, J), {TileM, TileC}, EltType, Builder); emitMatrixMultiply(Res, A, B, Builder, true, false, getFastMathFlags(MatMul)); } storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M}, - Builder.getInt64(I), Builder.getInt64(J), EltType, - Builder); + getIndex(CPtr, I), getIndex(CPtr, J), EltType, Builder); } } @@ -2254,15 +2274,14 @@ public: /// Lower load instructions. MatrixTy VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr, IRBuilder<> &Builder) { - return LowerLoad(Inst, Ptr, Inst->getAlign(), - Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI, - Builder); + return LowerLoad(Inst, Ptr, Inst->getAlign(), getIndex(Ptr, SI.getStride()), + Inst->isVolatile(), SI, Builder); } MatrixTy VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal, Value *Ptr, IRBuilder<> &Builder) { return LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(), - Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI, + getIndex(Ptr, SI.getStride()), Inst->isVolatile(), SI, Builder); } |