aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp')
-rw-r--r--mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp20
1 files changed, 15 insertions, 5 deletions
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 7e390aa..3635cd3 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -133,6 +133,9 @@ public:
smmlaShape.insert(smmlaShape.begin(), isVecmat ? 1 : 2);
loopOrder.push_back(2);
}
+
+ // Keep track of the previous accumulator when tiling over K.
+ Value kAcc;
for (SmallVector<int64_t> offsets :
StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
// Helper to compute the new shape of each operand and extract the slice.
@@ -194,19 +197,26 @@ public:
tiledRhs.getLoc(), collapsedInputType, tiledRhs);
auto collapsedOutputType =
VectorType::get(outputExpandedType.getNumElements(), accElementType);
- auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
- tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
+
+ bool initialKAcc = offsets.back() == 0;
+ Value collapsedRes;
+ if (!initialKAcc) {
+ collapsedRes = kAcc;
+ } else {
+ collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
+ tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
+ }
// Insert contract op
- auto smmlaOp = rewriter.createOrFold<arm_neon::SmmlaOp>(
+ kAcc = rewriter.createOrFold<arm_neon::SmmlaOp>(
op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
collapsedRhs);
// Reshape output back to 2D
Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
- smmlaOp.getLoc(), tiledAcc.getType(), smmlaOp);
+ kAcc.getLoc(), tiledAcc.getType(), kAcc);
- // With vecmat, only one row of tiled ACC can be inserted inot file result
+ // With vecmat, only one row of tiled ACC can be inserted into file result
if (isVecmat) {
tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
}