From b859a9fc152f83d657b1b59a99ede2478e285b27 Mon Sep 17 00:00:00 2001 From: Kojo Acquah Date: Fri, 19 Apr 2024 07:49:25 -0700 Subject: [mlir][ArmNeon] Update `LowerContractionToSMMLAPattern` to support proper unrolling for k dimension (#88591) Fixes correctness issue with current smmla unrolling patterns whereby unrolling K dimension would only include the result from the last tile along K. Updates patterns to feed the smmla output of the previous tile into the next one along K. --- .../Transforms/LowerContractionToSMMLAPattern.cpp | 20 +++- mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir | 121 ++++++++++++++++++--- 2 files changed, 118 insertions(+), 23 deletions(-) diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp index 7e390aa..3635cd3 100644 --- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp +++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp @@ -133,6 +133,9 @@ public: smmlaShape.insert(smmlaShape.begin(), isVecmat ? 1 : 2); loopOrder.push_back(2); } + + // Keep track of the previous accumulator when tiling over K. + Value kAcc; for (SmallVector offsets : StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) { // Helper to compute the new shape of each operand and extract the slice. 
@@ -194,19 +197,26 @@ public: tiledRhs.getLoc(), collapsedInputType, tiledRhs); auto collapsedOutputType = VectorType::get(outputExpandedType.getNumElements(), accElementType); - auto collapsedRes = rewriter.createOrFold( - tiledAcc.getLoc(), collapsedOutputType, tiledAcc); + + bool initialKAcc = offsets.back() == 0; + Value collapsedRes; + if (!initialKAcc) { + collapsedRes = kAcc; + } else { + collapsedRes = rewriter.createOrFold( + tiledAcc.getLoc(), collapsedOutputType, tiledAcc); + } // Insert contract op - auto smmlaOp = rewriter.createOrFold( + kAcc = rewriter.createOrFold( op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs, collapsedRhs); // Reshape output back to 2D Value tiledRes = rewriter.createOrFold( - smmlaOp.getLoc(), tiledAcc.getType(), smmlaOp); + kAcc.getLoc(), tiledAcc.getType(), kAcc); - // With vecmat, only one row of tiled ACC can be inserted inot file result + // With vecmat, only one row of tiled ACC can be inserted into final result if (isVecmat) { tiledRes = rewriter.createOrFold(loc, tiledRes, 0); } diff --git a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir index c276a5b..297be91 100644 --- a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir +++ b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir @@ -1,6 +1,6 @@ // RUN: mlir-opt -test-lower-to-arm-neon -verify-diagnostics -split-input-file %s | FileCheck %s -// CHECK-LABEL: test_lower_vector_arm_neon_mixed_types +// CHECK-LABEL: vector_arm_neon_mixed_types // CHECK-SAME: %[[A0:.*]]: vector<2x8xi8>, %[[A1:.*]]: vector<2x8xi4>, %[[A2:.*]]: vector<2x2xi32> // CHECK-DAG: %[[D0:.*]] = arith.extsi %[[A1]] : vector<2x8xi4> to vector<2x8xi8> // CHECK-DAG: %[[D1:.*]] = vector.shape_cast %[[A0]] : vector<2x8xi8> to vector<16xi8> // CHECK-DAG: %[[D2:.*]] = vector.shape_cast %[[D0]] : vector<2x8xi4> to vector<16xi8> // CHECK-DAG: %[[D3:.*]] = vector.shape_cast %[[A2]] : vector<2x2xi32> to vector<4xi32> // CHECK-DAG: %[[D4:.*]] = arm_neon.intr.smmla %[[D3]], %[[D1]], %[[D2]] : vector<16xi8> to vector<4xi32> 
// CHECK-DAG: %[[D5:.*]] = vector.shape_cast %[[D4]] : vector<4xi32> to vector<2x2xi32> -func.func @test_lower_vector_arm_neon_mixed_types(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi4>, %acc : vector<2x2xi32>) -> vector<2x2xi32> { +func.func @vector_arm_neon_mixed_types(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi4>, %acc : vector<2x2xi32>) -> vector<2x2xi32> { %lhs_extsi = arith.extsi %lhs : vector<2x8xi8> to vector<2x8xi32> %rhs_extsi = arith.extsi %rhs : vector<2x8xi4> to vector<2x8xi32> %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %lhs_extsi, %rhs_extsi, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32> @@ -17,14 +17,14 @@ func.func @test_lower_vector_arm_neon_mixed_types(%lhs: vector<2x8xi8>, %rhs: ve // ----- -// CHECK-LABEL: test_lower_vector_arm_neon_same_types +// CHECK-LABEL: vector_arm_neon_same_types // CHECK-SAME: %[[A0:.*]]: vector<2x8xi8>, %[[A1:.*]]: vector<2x8xi8>, %[[A2:.*]]: vector<2x2xi32> // CHECK-DAG: %[[D0:.*]] = vector.shape_cast %[[A0]] : vector<2x8xi8> to vector<16xi8> // CHECK-DAG: %[[D1:.*]] = vector.shape_cast %[[A1]] : vector<2x8xi8> to vector<16xi8> // CHECK-DAG: %[[D2:.*]] = vector.shape_cast %[[A2]] : vector<2x2xi32> to vector<4xi32> // CHECK-DAG: %[[D3:.*]] = arm_neon.intr.smmla %[[D2]], %[[D0]], %[[D1]] : vector<16xi8> to vector<4xi32> // CHECK-DAG: %[[D4:.*]] = vector.shape_cast %[[D3]] : vector<4xi32> to vector<2x2xi32> -func.func @test_lower_vector_arm_neon_same_types(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi8>, %acc : vector<2x2xi32>) -> vector<2x2xi32> { +func.func @vector_arm_neon_same_types(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi8>, %acc : vector<2x2xi32>) -> vector<2x2xi32> { %lhs_extsi = arith.extsi %lhs : vector<2x8xi8> to vector<2x8xi32> %rhs_extsi = arith.extsi %rhs : vector<2x8xi8> to vector<2x8xi32> %res = vector.contract 
{indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %lhs_extsi, %rhs_extsi, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32> @@ -33,17 +33,17 @@ func.func @test_lower_vector_arm_neon_same_types(%lhs: vector<2x8xi8>, %rhs: vec // ----- -// CHECK-LABEL: test_lower_vector_arm_neon_without_extsi +// CHECK-LABEL: vector_arm_neon_without_extsi // CHECK-SAME: %[[A0:.*]]: vector<2x8xi32>, %[[A1:.*]]: vector<2x8xi32>, %[[A2:.*]]: vector<2x2xi32> // CHECK-DAG: %[[D0:.*]] = vector.contract -func.func @test_lower_vector_arm_neon_without_extsi(%lhs: vector<2x8xi32>, %rhs: vector<2x8xi32>, %acc : vector<2x2xi32>) -> vector<2x2xi32> { +func.func @vector_arm_neon_without_extsi(%lhs: vector<2x8xi32>, %rhs: vector<2x8xi32>, %acc : vector<2x2xi32>) -> vector<2x2xi32> { %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %lhs, %rhs, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32> return %res : vector<2x2xi32> } // ----- -// CHECK-LABEL: test_lower_vector_arm_neon_unroll +// CHECK-LABEL: vector_arm_neon_unroll // CHECK-SAME: %[[VAL_0:.*]]: vector<4x8xi8>, %[[VAL_1:.*]]: vector<4x8xi8>, %[[VAL_2:.*]]: vector<4x4xi32> // CHECK-DAG: %[[VAL_3:.*]] = arith.constant dense<0> : vector<4x4xi32> // CHECK-DAG: %[[VAL_4:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8> @@ -84,7 +84,7 @@ func.func @test_lower_vector_arm_neon_without_extsi(%lhs: vector<2x8xi32>, %rhs: // CHECK-DAG: %[[VAL_39:.*]] = vector.insert_strided_slice %[[VAL_38]], %[[VAL_30]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32> // CHECK-DAG: return %[[VAL_39]] 
: vector<4x4xi32> // CHECK-DAG: } -func.func @test_lower_vector_arm_neon_unroll(%lhs: vector<4x8xi8>, %rhs: vector<4x8xi8>, %acc : vector<4x4xi32>) -> vector<4x4xi32> { +func.func @vector_arm_neon_unroll(%lhs: vector<4x8xi8>, %rhs: vector<4x8xi8>, %acc : vector<4x4xi32>) -> vector<4x4xi32> { %lhs_extsi = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32> %rhs_extsi = arith.extsi %rhs : vector<4x8xi8> to vector<4x8xi32> %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %lhs_extsi, %rhs_extsi, %acc : vector<4x8xi32>, vector<4x8xi32> into vector<4x4xi32> @@ -93,7 +93,7 @@ func.func @test_lower_vector_arm_neon_unroll(%lhs: vector<4x8xi8>, %rhs: vector< // ----- -// CHECK-LABEL: func.func @test_lower_vector_arm_neon_mixed_unroll( +// CHECK-LABEL: func.func @vector_arm_neon_mixed_unroll( // CHECK-SAME: %[[VAL_0:.*]]: vector<4x8xi8>, // CHECK-SAME: %[[VAL_1:.*]]: vector<2x8xi4>, // CHECK-SAME: %[[VAL_2:.*]]: vector<4x2xi32>) -> vector<4x2xi32> { @@ -117,7 +117,7 @@ func.func @test_lower_vector_arm_neon_unroll(%lhs: vector<4x8xi8>, %rhs: vector< // CHECK-DAG: %[[VAL_20:.*]] = vector.insert_strided_slice %[[VAL_19]], %[[VAL_12]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x2xi32> // CHECK-DAG: return %[[VAL_20]] : vector<4x2xi32> // CHECK-DAG: } -func.func @test_lower_vector_arm_neon_mixed_unroll(%lhs: vector<4x8xi8>, %rhs: vector<2x8xi4>, %acc : vector<4x2xi32>) -> vector<4x2xi32> { +func.func @vector_arm_neon_mixed_unroll(%lhs: vector<4x8xi8>, %rhs: vector<2x8xi4>, %acc : vector<4x2xi32>) -> vector<4x2xi32> { %lhs_extsi = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32> %rhs_extsi = arith.extsi %rhs : vector<2x8xi4> to vector<2x8xi32> %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, 
affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %lhs_extsi, %rhs_extsi, %acc : vector<4x8xi32>, vector<2x8xi32> into vector<4x2xi32> @@ -126,9 +126,9 @@ func.func @test_lower_vector_arm_neon_mixed_unroll(%lhs: vector<4x8xi8>, %rhs: v // ----- -// CHECK-LABEL: func.func @test_lower_vector_arm_neon_unroll_incompatible_shape( +// CHECK-LABEL: func.func @vector_arm_neon_unroll_incompatible_shape( // CHECK-DAG: %[[result:.*]] = vector.contract -func.func @test_lower_vector_arm_neon_unroll_incompatible_shape(%lhs: vector<4x12xi8>, %rhs: vector<4x12xi8>, %acc : vector<4x4xi32>) -> vector<4x4xi32> { +func.func @vector_arm_neon_unroll_incompatible_shape(%lhs: vector<4x12xi8>, %rhs: vector<4x12xi8>, %acc : vector<4x4xi32>) -> vector<4x4xi32> { %lhs_extsi = arith.extsi %lhs : vector<4x12xi8> to vector<4x12xi32> %rhs_extsi = arith.extsi %rhs : vector<4x12xi8> to vector<4x12xi32> %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %lhs_extsi, %rhs_extsi, %acc : vector<4x12xi32>, vector<4x12xi32> into vector<4x4xi32> @@ -137,7 +137,7 @@ func.func @test_lower_vector_arm_neon_unroll_incompatible_shape(%lhs: vector<4x1 // ----- -// CHECK-LABEL: func.func @test_lower_vector_arm_neon_vecmat_unroll( +// CHECK-LABEL: func.func @vector_arm_neon_vecmat_unroll( // CHECK-SAME: %[[VAL_0:.*]]: vector<8xi8>, // CHECK-SAME: %[[VAL_1:.*]]: vector<8x8xi8>, // CHECK-SAME: %[[VAL_2:.*]]: vector<8xi32>) -> vector<8xi32> { @@ -190,7 +190,7 @@ func.func @test_lower_vector_arm_neon_unroll_incompatible_shape(%lhs: vector<4x1 // CHECK: %[[VAL_49:.*]] = vector.insert_strided_slice %[[VAL_48]], %[[VAL_38]] {offsets = [6], strides = [1]} : vector<2xi32> into vector<8xi32> // CHECK: return %[[VAL_49]] : vector<8xi32> // CHECK: } -func.func 
@test_lower_vector_arm_neon_vecmat_unroll(%lhs: vector<8xi8>, %rhs: vector<8x8xi8>, %acc : vector<8xi32>) -> vector<8xi32> { +func.func @vector_arm_neon_vecmat_unroll(%lhs: vector<8xi8>, %rhs: vector<8x8xi8>, %acc : vector<8xi32>) -> vector<8xi32> { %lhs_extsi= arith.extsi %lhs : vector<8xi8> to vector<8xi32> %rhs_extsi = arith.extsi %rhs : vector<8x8xi8> to vector<8x8xi32> %res = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind} %lhs_extsi, %rhs_extsi, %acc : vector<8xi32>, vector<8x8xi32> into vector<8xi32> @@ -199,7 +199,7 @@ func.func @test_lower_vector_arm_neon_vecmat_unroll(%lhs: vector<8xi8>, %rhs: ve // ----- -// CHECK-LABEL: func.func @test_lower_vector_arm_neon_vecmat_unroll_leading_dim( +// CHECK-LABEL: func.func @vector_arm_neon_vecmat_unroll_leading_dim( // CHECK-SAME: %[[VAL_0:.*]]: vector<1x8xi8>, // CHECK-SAME: %[[VAL_1:.*]]: vector<8x8xi8>, // CHECK-SAME: %[[VAL_2:.*]]: vector<1x8xi32>) -> vector<1x8xi32> { @@ -252,7 +252,7 @@ func.func @test_lower_vector_arm_neon_vecmat_unroll(%lhs: vector<8xi8>, %rhs: ve // CHECK: %[[VAL_49:.*]] = vector.insert_strided_slice %[[VAL_48]], %[[VAL_38]] {offsets = [0, 6], strides = [1]} : vector<2xi32> into vector<1x8xi32> // CHECK: return %[[VAL_49]] : vector<1x8xi32> // CHECK: } -func.func @test_lower_vector_arm_neon_vecmat_unroll_leading_dim(%lhs: vector<1x8xi8>, %rhs: vector<8x8xi8>, %acc : vector<1x8xi32>) -> vector<1x8xi32> { +func.func @vector_arm_neon_vecmat_unroll_leading_dim(%lhs: vector<1x8xi8>, %rhs: vector<8x8xi8>, %acc : vector<1x8xi32>) -> vector<1x8xi32> { %lhs_extsi= arith.extsi %lhs : vector<1x8xi8> to vector<1x8xi32> %rhs_extsi = arith.extsi %rhs : vector<8x8xi8> to vector<8x8xi32> %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = 
["parallel", "parallel", "reduction"], kind = #vector.kind} %lhs_extsi, %rhs_extsi, %acc : vector<1x8xi32>, vector<8x8xi32> into vector<1x8xi32> @@ -261,11 +261,96 @@ func.func @test_lower_vector_arm_neon_vecmat_unroll_leading_dim(%lhs: vector<1x8 // ----- -// CHECK-LABEL: func.func @test_lower_vector_arm_neon_matvec +// CHECK-LABEL: func.func @vector_arm_neon_matvec // CHECK-NOT: arm_neon.intr.smmla -func.func @test_lower_vector_arm_neon_matvec(%lhs: vector<8x8xi8>, %rhs: vector<8xi8>, %acc : vector<8xi32>) -> vector<8xi32> { +func.func @vector_arm_neon_matvec(%lhs: vector<8x8xi8>, %rhs: vector<8xi8>, %acc : vector<8xi32>) -> vector<8xi32> { %rhs_extsi= arith.extsi %rhs : vector<8xi8> to vector<8xi32> %lhs_extsi = arith.extsi %lhs : vector<8x8xi8> to vector<8x8xi32> %res = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind} %lhs_extsi, %rhs_extsi, %acc : vector<8x8xi32>, vector<8xi32> into vector<8xi32> return %res : vector<8xi32> } + + +// ----- + +// CHECK-LABEL: func.func @vector_arm_neon_k_unroll( +// CHECK-SAME: %[[VAL_0:.*]]: vector<2x16xi8>, +// CHECK-SAME: %[[VAL_1:.*]]: vector<2x16xi4>, +// CHECK-SAME: %[[VAL_2:.*]]: vector<2x2xi32>) -> vector<2x2xi32> { +// CHECK: %[[VAL_3:.*]] = arith.extsi %[[VAL_1]] : vector<2x16xi4> to vector<2x16xi8> +// CHECK: %[[VAL_4:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<2x16xi8> to vector<2x8xi8> +// CHECK: %[[VAL_5:.*]] = vector.extract_strided_slice %[[VAL_3]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<2x16xi8> to vector<2x8xi8> +// CHECK: %[[VAL_6:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8> +// CHECK: %[[VAL_7:.*]] = vector.shape_cast %[[VAL_5]] : vector<2x8xi8> to vector<16xi8> +// CHECK: %[[VAL_8:.*]] = vector.shape_cast %[[VAL_2]] : vector<2x2xi32> to vector<4xi32> +// 
CHECK: %[[KACC_0:.*]] = arm_neon.intr.smmla %[[VAL_8]], %[[VAL_6]], %[[VAL_7]] : vector<16xi8> to vector<4xi32> +// CHECK: %[[VAL_10:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 8], sizes = [2, 8], strides = [1, 1]} : vector<2x16xi8> to vector<2x8xi8> +// CHECK: %[[VAL_11:.*]] = vector.extract_strided_slice %[[VAL_3]] {offsets = [0, 8], sizes = [2, 8], strides = [1, 1]} : vector<2x16xi8> to vector<2x8xi8> +// CHECK: %[[VAL_12:.*]] = vector.shape_cast %[[VAL_10]] : vector<2x8xi8> to vector<16xi8> +// CHECK: %[[VAL_13:.*]] = vector.shape_cast %[[VAL_11]] : vector<2x8xi8> to vector<16xi8> +// CHECK: %[[KACC_1:.*]] = arm_neon.intr.smmla %[[KACC_0]], %[[VAL_12]], %[[VAL_13]] : vector<16xi8> to vector<4xi32> +// CHECK: %[[VAL_15:.*]] = vector.shape_cast %[[KACC_1]] : vector<4xi32> to vector<2x2xi32> +// CHECK: return %[[VAL_15]] : vector<2x2xi32> +// CHECK: } +func.func @vector_arm_neon_k_unroll(%lhs: vector<2x16xi8>, %rhs: vector<2x16xi4>, %acc : vector<2x2xi32>) -> vector<2x2xi32> { + %lhs_extsi = arith.extsi %lhs : vector<2x16xi8> to vector<2x16xi32> + %rhs_extsi = arith.extsi %rhs : vector<2x16xi4> to vector<2x16xi32> + %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %lhs_extsi, %rhs_extsi, %acc : vector<2x16xi32>, vector<2x16xi32> into vector<2x2xi32> + return %res : vector<2x2xi32> +} + +// ----- + +// CHECK-LABEL: func.func @vector_arm_neon_k_unroll_vecmat( +// CHECK-SAME: %[[VAL_0:.*]]: vector<1x32xi8>, +// CHECK-SAME: %[[VAL_1:.*]]: vector<2x32xi4>, +// CHECK-SAME: %[[VAL_2:.*]]: vector<1x2xi32>) -> vector<1x2xi32> { +// CHECK: %[[VAL_3:.*]] = arith.constant dense<0> : vector<2x2xi32> +// CHECK: %[[VAL_4:.*]] = arith.constant dense<0> : vector<2x8xi8> +// CHECK: %[[VAL_5:.*]] = arith.constant dense<0> : vector<1x2xi32> +// CHECK: %[[VAL_6:.*]] = arith.extsi 
%[[VAL_1]] : vector<2x32xi4> to vector<2x32xi8> +// CHECK: %[[VAL_7:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [1, 8], strides = [1, 1]} : vector<1x32xi8> to vector<1x8xi8> +// CHECK: %[[VAL_8:.*]] = vector.extract_strided_slice %[[VAL_6]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<2x32xi8> to vector<2x8xi8> +// CHECK: %[[VAL_9:.*]] = vector.insert_strided_slice %[[VAL_7]], %[[VAL_4]] {offsets = [0, 0], strides = [1, 1]} : vector<1x8xi8> into vector<2x8xi8> +// CHECK: %[[VAL_10:.*]] = vector.insert_strided_slice %[[VAL_2]], %[[VAL_3]] {offsets = [0, 0], strides = [1, 1]} : vector<1x2xi32> into vector<2x2xi32> +// CHECK: %[[VAL_11:.*]] = vector.shape_cast %[[VAL_9]] : vector<2x8xi8> to vector<16xi8> +// CHECK: %[[VAL_12:.*]] = vector.shape_cast %[[VAL_8]] : vector<2x8xi8> to vector<16xi8> +// CHECK: %[[VAL_13:.*]] = vector.shape_cast %[[VAL_10]] : vector<2x2xi32> to vector<4xi32> +// CHECK: %[[KACC_0:.*]] = arm_neon.intr.smmla %[[VAL_13]], %[[VAL_11]], %[[VAL_12]] : vector<16xi8> to vector<4xi32> +// CHECK: %[[VAL_15:.*]] = vector.shape_cast %[[KACC_0]] : vector<4xi32> to vector<2x2xi32> +// CHECK: %[[VAL_16:.*]] = vector.extract %[[VAL_15]][0] : vector<2xi32> from vector<2x2xi32> +// CHECK: %[[VAL_17:.*]] = vector.insert_strided_slice %[[VAL_16]], %[[VAL_5]] {offsets = [0, 0], strides = [1]} : vector<2xi32> into vector<1x2xi32> +// CHECK: %[[VAL_18:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 8], sizes = [1, 8], strides = [1, 1]} : vector<1x32xi8> to vector<1x8xi8> +// CHECK: %[[VAL_19:.*]] = vector.extract_strided_slice %[[VAL_6]] {offsets = [0, 8], sizes = [2, 8], strides = [1, 1]} : vector<2x32xi8> to vector<2x8xi8> +// CHECK: %[[VAL_20:.*]] = vector.insert_strided_slice %[[VAL_18]], %[[VAL_4]] {offsets = [0, 0], strides = [1, 1]} : vector<1x8xi8> into vector<2x8xi8> +// CHECK: %[[VAL_21:.*]] = vector.shape_cast %[[VAL_20]] : vector<2x8xi8> to vector<16xi8> +// CHECK: %[[VAL_22:.*]] = 
vector.shape_cast %[[VAL_19]] : vector<2x8xi8> to vector<16xi8> +// CHECK: %[[KACC_1:.*]] = arm_neon.intr.smmla %[[KACC_0]], %[[VAL_21]], %[[VAL_22]] : vector<16xi8> to vector<4xi32> +// CHECK: %[[VAL_24:.*]] = vector.shape_cast %[[KACC_1]] : vector<4xi32> to vector<2x2xi32> +// CHECK: %[[VAL_25:.*]] = vector.extract %[[VAL_24]][0] : vector<2xi32> from vector<2x2xi32> +// CHECK: %[[VAL_26:.*]] = vector.insert_strided_slice %[[VAL_25]], %[[VAL_17]] {offsets = [0, 0], strides = [1]} : vector<2xi32> into vector<1x2xi32> +// CHECK: %[[VAL_27:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 16], sizes = [1, 8], strides = [1, 1]} : vector<1x32xi8> to vector<1x8xi8> +// CHECK: %[[VAL_28:.*]] = vector.extract_strided_slice %[[VAL_6]] {offsets = [0, 16], sizes = [2, 8], strides = [1, 1]} : vector<2x32xi8> to vector<2x8xi8> +// CHECK: %[[VAL_29:.*]] = vector.insert_strided_slice %[[VAL_27]], %[[VAL_4]] {offsets = [0, 0], strides = [1, 1]} : vector<1x8xi8> into vector<2x8xi8> +// CHECK: %[[VAL_30:.*]] = vector.shape_cast %[[VAL_29]] : vector<2x8xi8> to vector<16xi8> +// CHECK: %[[VAL_31:.*]] = vector.shape_cast %[[VAL_28]] : vector<2x8xi8> to vector<16xi8> +// CHECK: %[[KACC_2:.*]] = arm_neon.intr.smmla %[[KACC_1]], %[[VAL_30]], %[[VAL_31]] : vector<16xi8> to vector<4xi32> +// CHECK: %[[VAL_33:.*]] = vector.shape_cast %[[KACC_2]] : vector<4xi32> to vector<2x2xi32> +// CHECK: %[[VAL_34:.*]] = vector.extract %[[VAL_33]][0] : vector<2xi32> from vector<2x2xi32> +// CHECK: %[[VAL_35:.*]] = vector.insert_strided_slice %[[VAL_34]], %[[VAL_26]] {offsets = [0, 0], strides = [1]} : vector<2xi32> into vector<1x2xi32> +// CHECK: %[[VAL_36:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 24], sizes = [1, 8], strides = [1, 1]} : vector<1x32xi8> to vector<1x8xi8> +// CHECK: %[[VAL_37:.*]] = vector.extract_strided_slice %[[VAL_6]] {offsets = [0, 24], sizes = [2, 8], strides = [1, 1]} : vector<2x32xi8> to vector<2x8xi8> +// CHECK: %[[VAL_38:.*]] = 
vector.insert_strided_slice %[[VAL_36]], %[[VAL_4]] {offsets = [0, 0], strides = [1, 1]} : vector<1x8xi8> into vector<2x8xi8> +// CHECK: %[[VAL_39:.*]] = vector.shape_cast %[[VAL_38]] : vector<2x8xi8> to vector<16xi8> +// CHECK: %[[VAL_40:.*]] = vector.shape_cast %[[VAL_37]] : vector<2x8xi8> to vector<16xi8> +// CHECK: %[[KACC_3:.*]] = arm_neon.intr.smmla %[[KACC_2]], %[[VAL_39]], %[[VAL_40]] : vector<16xi8> to vector<4xi32> +// CHECK: %[[VAL_42:.*]] = vector.shape_cast %[[KACC_3]] : vector<4xi32> to vector<2x2xi32> +// CHECK: %[[VAL_43:.*]] = vector.extract %[[VAL_42]][0] : vector<2xi32> from vector<2x2xi32> +// CHECK: %[[VAL_44:.*]] = vector.insert_strided_slice %[[VAL_43]], %[[VAL_35]] {offsets = [0, 0], strides = [1]} : vector<2xi32> into vector<1x2xi32> +// CHECK: return %[[VAL_44]] : vector<1x2xi32> +func.func @vector_arm_neon_k_unroll_vecmat(%lhs: vector<1x32xi8>, %rhs: vector<2x32xi4>, %acc : vector<1x2xi32>) -> vector<1x2xi32> { + %lhs_extsi = arith.extsi %lhs : vector<1x32xi8> to vector<1x32xi32> + %rhs_extsi = arith.extsi %rhs : vector<2x32xi4> to vector<2x32xi32> + %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %lhs_extsi, %rhs_extsi, %acc : vector<1x32xi32>, vector<2x32xi32> into vector<1x2xi32> + return %res : vector<1x2xi32> +} -- cgit v1.1