diff options
author | Kojo Acquah <KoolJBlack@users.noreply.github.com> | 2024-03-19 10:09:33 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-19 13:09:33 -0400 |
commit | fe84369cc6759194e006f3f624a064bce13c84d4 (patch) | |
tree | 7b3c2cd4511a83862860db403f2bda113d5e4ef1 | |
parent | ab76052fa9331f418d7911cafefabd4dd0c1941e (diff) | |
download | llvm-fe84369cc6759194e006f3f624a064bce13c84d4.zip llvm-fe84369cc6759194e006f3f624a064bce13c84d4.tar.gz llvm-fe84369cc6759194e006f3f624a064bce13c84d4.tar.bz2 |
[mlir][ArmNeon] Implements unrolling patterns for LowerContractionToSMMLAPattern (#84848)
This patch updates `LowerContractionToSMMLAPattern` to unroll larger vector contracts into multiple smmla instructions.
Now accepts up to [8,8,8] tiles (previously only [2,2,8]). The N/M dimensions must be powers of 2. `vector.extract_strided_slice`/`vector.insert_strided_slice` divide the contract into [2,2,8]-shaped tiles, each of which is lowered to a single smmla instruction and processed sequentially.
-rw-r--r-- | mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp | 112 | ||||
-rw-r--r-- | mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir | 94 |
2 files changed, 174 insertions, 32 deletions
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp index 47c8470..1f48d27 100644 --- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp +++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp @@ -16,7 +16,9 @@ #include "mlir/Dialect/ArmNeon/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineMap.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -36,8 +38,10 @@ static Type matchContainerType(Type element, Type container) { return element; } -/// Lowering from a single vector::contractOp directly to the arm neon smmla -/// intrinsic. The shapes of the contract and intrinsic must match. +/// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile +/// any vector.contract into multiple smmla instructions with unrolling so long +/// as [2,2,8] is a divisor of its shape. If no unrolling is necessary, a single +/// smmla instruction is emitted. class LowerContractionToSMMLAPattern : public OpRewritePattern<vector::ContractionOp> { public: @@ -45,10 +49,6 @@ public: LogicalResult matchAndRewrite(vector::ContractionOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - Value lhs = op.getLhs(); - Value rhs = op.getRhs(); - Value res = op.getAcc(); - // Check index maps that represent M N K in contract. auto indexingMaps = op.getIndexingMapsArray(); if (llvm::any_of(indexingMaps, [](mlir::AffineMap affineMap) { @@ -57,7 +57,6 @@ public: })) { return failure(); } - // Check iterator types for contract. 
auto iteratorTypes = op.getIteratorTypesArray(); if (iteratorTypes.size() != 3 || @@ -66,22 +65,24 @@ public: iteratorTypes[2] != vector::IteratorType::reduction) { return failure(); } - - // Check the tile size by mapping the dimensions of the contract. + // Infer tile sizes from operands; Note: RHS is not transposed. mlir::VectorType lhsType = op.getLhsType(); mlir::VectorType rhsType = op.getRhsType(); auto dimM = lhsType.getDimSize(0); auto dimN = rhsType.getDimSize(0); auto dimK = lhsType.getDimSize(1); - if (rhsType.getDimSize(1) != dimK || dimM != 2 || dimN != 2 || dimK != 8) { + + // Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for + // tiling. + if (dimM % 2 != 0 || dimN % 2 != 0 || dimK % 8 != 0) { return failure(); } // Check two extsi inputs Rhs Lhs for contract. arith::ExtSIOp origLhsExtOp = - dyn_cast_or_null<arith::ExtSIOp>(lhs.getDefiningOp()); + dyn_cast_or_null<arith::ExtSIOp>(op.getLhs().getDefiningOp()); arith::ExtSIOp origRhsExtOp = - dyn_cast_or_null<arith::ExtSIOp>(rhs.getDefiningOp()); + dyn_cast_or_null<arith::ExtSIOp>(op.getRhs().getDefiningOp()); if (!origLhsExtOp || !origRhsExtOp) { return failure(); } @@ -113,26 +114,73 @@ public: return failure(); } - // Collapse to 1D vectors required by smmla intrinsic - auto collapsedInputType = VectorType::get( - {16}, extsiLhs.getType().cast<ShapedType>().getElementType()); - auto collapsedOutputType = - VectorType::get({4}, res.getType().cast<ShapedType>().getElementType()); - auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>( - extsiLhs.getLoc(), collapsedInputType, extsiLhs); - auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>( - extsiRhs.getLoc(), collapsedInputType, extsiRhs); - auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>( - res.getLoc(), collapsedOutputType, res); - - // Replace the contract with a neon op - auto smmlaOp = rewriter.createOrFold<arm_neon::SmmlaOp>( - op.getLoc(), collapsedRes.getType(), collapsedRes, 
collapsedLhs, - collapsedRhs); - - // Reshape output back to 2D - rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(), - smmlaOp); + // Initial accumulator for the final result. This is the un-tiled result if + // tiling is done. + Value result = rewriter.create<arith::ConstantOp>( + loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType())); + + SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll(); + SmallVector<int64_t> smmlaShape{2, 2, 8}; + SmallVector<int64_t> loopOrder{0, 1, 2}; + for (SmallVector<int64_t> offsets : + StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) { + + // Helper to compute the new shape of each operand and extract the slice. + auto extractOperand = [&](Value operand, AffineMap permutationMap, + ArrayRef<int64_t> operandOffsets) { + SmallVector<int64_t> operandShape = + applyPermutationMap(permutationMap, ArrayRef<int64_t>(smmlaShape)); + SmallVector<int64_t> operandStrides(operandOffsets.size(), 1); + return rewriter.createOrFold<vector::ExtractStridedSliceOp>( + loc, operand, operandOffsets, operandShape, operandStrides); + }; + + // Extract tiled lhs, rhs, and acc + AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0]; + SmallVector<int64_t> lhsOffsets = + applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets)); + Value tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets); + AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1]; + SmallVector<int64_t> rhsOffsets = + applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets)); + Value tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets); + AffineMap accPermutationMap = op.getIndexingMapsArray()[2]; + SmallVector<int64_t> accOffsets = + applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets)); + Value tiledAcc = + extractOperand(op.getAcc(), accPermutationMap, accOffsets); + + // Collapse tiled operands to 1D vectors required by smmla intrinsic + auto collapsedInputType = 
VectorType::get( + tiledLhs.getType().cast<ShapedType>().getNumElements(), + tiledLhs.getType().cast<ShapedType>().getElementType()); + auto collapsedOutputType = VectorType::get( + {4}, tiledAcc.getType().cast<ShapedType>().getElementType()); + auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>( + tiledLhs.getLoc(), collapsedInputType, tiledLhs); + auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>( + tiledRhs.getLoc(), collapsedInputType, tiledRhs); + auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>( + tiledAcc.getLoc(), collapsedOutputType, tiledAcc); + + // Insert contract op + auto smmlaOp = rewriter.createOrFold<arm_neon::SmmlaOp>( + op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs, + collapsedRhs); + + // Reshape output back to 2D + Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>( + smmlaOp.getLoc(), tiledAcc.getType(), smmlaOp); + + // Insert the tiled result back into the non tiled result of the + // contract op. + SmallVector<int64_t> strides( + tiledRes.getType().cast<ShapedType>().getRank(), 1); + result = rewriter.createOrFold<vector::InsertStridedSliceOp>( + loc, tiledRes, result, accOffsets, strides); + } + + rewriter.replaceOp(op, result); return success(); } }; diff --git a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir index cba7b00..e2be8745 100644 --- a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir +++ b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir @@ -40,3 +40,97 @@ func.func @test_lower_vector_arm_neon_without_extsi(%lhs: vector<2x8xi32>, %rhs: %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32> return %res : vector<2x2xi32> } + +// ----- + +// CHECK-LABEL: 
test_lower_vector_arm_neon_unroll +// CHECK-SAME: %[[VAL_0:.*]]: vector<4x8xi8>, %[[VAL_1:.*]]: vector<4x8xi8>, %[[VAL_2:.*]]: vector<4x4xi32> +// CHECK-DAG: %[[VAL_3:.*]] = arith.constant dense<0> : vector<4x4xi32> +// CHECK-DAG: %[[VAL_4:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8> +// CHECK-DAG: %[[VAL_5:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8> +// CHECK-DAG: %[[VAL_6:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32> +// CHECK-DAG: %[[VAL_7:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8> +// CHECK-DAG: %[[VAL_8:.*]] = vector.shape_cast %[[VAL_5]] : vector<2x8xi8> to vector<16xi8> +// CHECK-DAG: %[[VAL_9:.*]] = vector.shape_cast %[[VAL_6]] : vector<2x2xi32> to vector<4xi32> +// CHECK-DAG: %[[VAL_10:.*]] = arm_neon.intr.smmla %[[VAL_9]], %[[VAL_7]], %[[VAL_8]] : vector<16xi8> to vector<4xi32> +// CHECK-DAG: %[[VAL_11:.*]] = vector.shape_cast %[[VAL_10]] : vector<4xi32> to vector<2x2xi32> +// CHECK-DAG: %[[VAL_12:.*]] = vector.insert_strided_slice %[[VAL_11]], %[[VAL_3]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32> +// CHECK-DAG: %[[VAL_13:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8> +// CHECK-DAG: %[[VAL_14:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8> +// CHECK-DAG: %[[VAL_15:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32> +// CHECK-DAG: %[[VAL_16:.*]] = vector.shape_cast %[[VAL_13]] : vector<2x8xi8> to vector<16xi8> +// CHECK-DAG: %[[VAL_17:.*]] = vector.shape_cast %[[VAL_14]] : 
vector<2x8xi8> to vector<16xi8> +// CHECK-DAG: %[[VAL_18:.*]] = vector.shape_cast %[[VAL_15]] : vector<2x2xi32> to vector<4xi32> +// CHECK-DAG: %[[VAL_19:.*]] = arm_neon.intr.smmla %[[VAL_18]], %[[VAL_16]], %[[VAL_17]] : vector<16xi8> to vector<4xi32> +// CHECK-DAG: %[[VAL_20:.*]] = vector.shape_cast %[[VAL_19]] : vector<4xi32> to vector<2x2xi32> +// CHECK-DAG: %[[VAL_21:.*]] = vector.insert_strided_slice %[[VAL_20]], %[[VAL_12]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32> +// CHECK-DAG: %[[VAL_22:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8> +// CHECK-DAG: %[[VAL_23:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8> +// CHECK-DAG: %[[VAL_24:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32> +// CHECK-DAG: %[[VAL_25:.*]] = vector.shape_cast %[[VAL_22]] : vector<2x8xi8> to vector<16xi8> +// CHECK-DAG: %[[VAL_26:.*]] = vector.shape_cast %[[VAL_23]] : vector<2x8xi8> to vector<16xi8> +// CHECK-DAG: %[[VAL_27:.*]] = vector.shape_cast %[[VAL_24]] : vector<2x2xi32> to vector<4xi32> +// CHECK-DAG: %[[VAL_28:.*]] = arm_neon.intr.smmla %[[VAL_27]], %[[VAL_25]], %[[VAL_26]] : vector<16xi8> to vector<4xi32> +// CHECK-DAG: %[[VAL_29:.*]] = vector.shape_cast %[[VAL_28]] : vector<4xi32> to vector<2x2xi32> +// CHECK-DAG: %[[VAL_30:.*]] = vector.insert_strided_slice %[[VAL_29]], %[[VAL_21]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32> +// CHECK-DAG: %[[VAL_31:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8> +// CHECK-DAG: %[[VAL_32:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8> +// CHECK-DAG: 
%[[VAL_33:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32> +// CHECK-DAG: %[[VAL_34:.*]] = vector.shape_cast %[[VAL_31]] : vector<2x8xi8> to vector<16xi8> +// CHECK-DAG: %[[VAL_35:.*]] = vector.shape_cast %[[VAL_32]] : vector<2x8xi8> to vector<16xi8> +// CHECK-DAG: %[[VAL_36:.*]] = vector.shape_cast %[[VAL_33]] : vector<2x2xi32> to vector<4xi32> +// CHECK-DAG: %[[VAL_37:.*]] = arm_neon.intr.smmla %[[VAL_36]], %[[VAL_34]], %[[VAL_35]] : vector<16xi8> to vector<4xi32> +// CHECK-DAG: %[[VAL_38:.*]] = vector.shape_cast %[[VAL_37]] : vector<4xi32> to vector<2x2xi32> +// CHECK-DAG: %[[VAL_39:.*]] = vector.insert_strided_slice %[[VAL_38]], %[[VAL_30]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32> +// CHECK-DAG: return %[[VAL_39]] : vector<4x4xi32> +// CHECK-DAG: } +func.func @test_lower_vector_arm_neon_unroll(%lhs: vector<4x8xi8>, %rhs: vector<4x8xi8>, %acc : vector<4x4xi32>) -> vector<4x4xi32> { + %lhs_extsi = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32> + %rhs_extsi = arith.extsi %rhs : vector<4x8xi8> to vector<4x8xi32> + %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<4x8xi32>, vector<4x8xi32> into vector<4x4xi32> + return %res : vector<4x4xi32> +} + +// ----- + +// CHECK-LABEL: func.func @test_lower_vector_arm_neon_mixed_unroll( +// CHECK-SAME: %[[VAL_0:.*]]: vector<4x8xi8>, +// CHECK-SAME: %[[VAL_1:.*]]: vector<2x8xi4>, +// CHECK-SAME: %[[VAL_2:.*]]: vector<4x2xi32>) -> vector<4x2xi32> { +// CHECK-DAG: %[[VAL_3:.*]] = arith.constant dense<0> : vector<4x2xi32> +// CHECK-DAG: %[[VAL_4:.*]] = arith.extsi %[[VAL_1]] : vector<2x8xi4> to vector<2x8xi8> +// CHECK-DAG: %[[VAL_5:.*]] = vector.extract_strided_slice %[[VAL_0]] 
{offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8> +// CHECK-DAG: %[[VAL_6:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xi32> to vector<2x2xi32> +// CHECK-DAG: %[[VAL_7:.*]] = vector.shape_cast %[[VAL_5]] : vector<2x8xi8> to vector<16xi8> +// CHECK-DAG: %[[VAL_8:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8> +// CHECK-DAG: %[[VAL_9:.*]] = vector.shape_cast %[[VAL_6]] : vector<2x2xi32> to vector<4xi32> +// CHECK-DAG: %[[VAL_10:.*]] = arm_neon.intr.smmla %[[VAL_9]], %[[VAL_7]], %[[VAL_8]] : vector<16xi8> to vector<4xi32> +// CHECK-DAG: %[[VAL_11:.*]] = vector.shape_cast %[[VAL_10]] : vector<4xi32> to vector<2x2xi32> +// CHECK-DAG: %[[VAL_12:.*]] = vector.insert_strided_slice %[[VAL_11]], %[[VAL_3]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x2xi32> +// CHECK-DAG: %[[VAL_13:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8> +// CHECK-DAG: %[[VAL_14:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xi32> to vector<2x2xi32> +// CHECK-DAG: %[[VAL_15:.*]] = vector.shape_cast %[[VAL_13]] : vector<2x8xi8> to vector<16xi8> +// CHECK-DAG: %[[VAL_16:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8> +// CHECK-DAG: %[[VAL_17:.*]] = vector.shape_cast %[[VAL_14]] : vector<2x2xi32> to vector<4xi32> +// CHECK-DAG: %[[VAL_18:.*]] = arm_neon.intr.smmla %[[VAL_17]], %[[VAL_15]], %[[VAL_16]] : vector<16xi8> to vector<4xi32> +// CHECK-DAG: %[[VAL_19:.*]] = vector.shape_cast %[[VAL_18]] : vector<4xi32> to vector<2x2xi32> +// CHECK-DAG: %[[VAL_20:.*]] = vector.insert_strided_slice %[[VAL_19]], %[[VAL_12]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x2xi32> +// CHECK-DAG: return %[[VAL_20]] : vector<4x2xi32> +// CHECK-DAG: } +func.func 
@test_lower_vector_arm_neon_mixed_unroll(%lhs: vector<4x8xi8>, %rhs: vector<2x8xi4>, %acc : vector<4x2xi32>) -> vector<4x2xi32> { + %lhs_extsi = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32> + %rhs_extsi = arith.extsi %rhs : vector<2x8xi4> to vector<2x8xi32> + %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<4x8xi32>, vector<2x8xi32> into vector<4x2xi32> + return %res : vector<4x2xi32> +} + +// ----- + +// CHECK-LABEL: func.func @test_lower_vector_arm_neon_unroll_incompatible_shape( +// CHECK-DAG: %[[result:.*]] = vector.contract +func.func @test_lower_vector_arm_neon_unroll_incompatible_shape(%lhs: vector<4x12xi8>, %rhs: vector<4x12xi8>, %acc : vector<4x4xi32>) -> vector<4x4xi32> { + %lhs_extsi = arith.extsi %lhs : vector<4x12xi8> to vector<4x12xi32> + %rhs_extsi = arith.extsi %rhs : vector<4x12xi8> to vector<4x12xi32> + %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<4x12xi32>, vector<4x12xi32> into vector<4x4xi32> + return %res : vector<4x4xi32> +} |