author     Diego Caballero <diegocaballero@google.com>  2023-04-21 18:36:47 +0000
committer  Diego Caballero <diegocaballero@google.com>  2023-04-21 19:19:01 +0000
commit     c8557c7c3e98dddc0abb28913e06c8dae60a6c01 (patch)
tree       e9b114acab7695b805666c5ce7d1da7b6fd3bfd5
parent     fa2a8c2e1cdf188ce985d69fca6f78866390b715 (diff)
[MLIR][Vector] Enable masked vectorization of contraction ops
This patch enables the vectorization of contraction ops using vector masking. Support for vectorizing contractions is already in place, so this change only adds contraction ops to the list of supported ops in `vectorizeDynamicLinalgOpPrecondition` and adds a test.

Reviewed By: hanchung, awarzynski

Differential Revision: https://reviews.llvm.org/D148865
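As a quick illustration of what this enables (a sketch, not part of the patch): any dynamically-shaped op implementing `linalg::ContractionOpInterface` now passes the masking precondition, so it can be driven through the same transform-dialect interface used by the matmul test below. The choice of `linalg.matvec` and the vector sizes `[8, 4]` here are illustrative assumptions, not something this commit adds or tests.

func.func @masked_vectorize_dynamic_matvec(%A: memref<?x?xf32>, %x: memref<?xf32>, %y: memref<?xf32>) {
  // Dynamically-shaped contraction (matrix-vector product); previously rejected
  // by the masking precondition, now accepted via ContractionOpInterface.
  linalg.matvec ins(%A, %x : memref<?x?xf32>, memref<?xf32>)
                outs(%y : memref<?xf32>)
  return
}

transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
  %0 = transform.structured.match ops{["linalg.matvec"]} in %arg1
    : (!pdl.operation) -> !pdl.operation
  // Vector sizes map to the (parallel m, reduction k) iteration space of matvec.
  transform.structured.masked_vectorize %0 vector_sizes [8, 4]
}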
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp3
-rw-r--r--mlir/test/Dialect/Linalg/vectorization.mlir50
2 files changed, 47 insertions, 6 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c7776f4..55beb2f8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1303,7 +1303,8 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
// TODO: Masking only supports dynamic generic ops for now.
- if (!isa<linalg::GenericOp, linalg::FillOp, linalg::CopyOp>(op))
+ if (!isa<linalg::GenericOp, linalg::FillOp, linalg::CopyOp,
+ linalg::ContractionOpInterface>(op.getOperation()))
return failure();
LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 910c815..f2f6fed 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -2832,7 +2832,7 @@ transform.sequence failures(propagate) {
// CHECK-LABEL: func @test_masked_vectorize_pad
func.func @test_masked_vectorize_pad(
- %0 : tensor<?x?xf32>, %h0 : index, %h1 : index)
+ %0 : tensor<?x?xf32>, %h0 : index, %h1 : index)
-> tensor<2x4xf32>
{
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
@@ -2841,9 +2841,9 @@ func.func @test_masked_vectorize_pad(
// CHECK: %[[d0:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
// CHECK: %[[d1:.*]] = tensor.dim {{.*}} : tensor<?x?xf32>
// CHECK: %[[mask:.*]] = vector.create_mask %[[d0]], %[[d1]] : vector<2x4xi1>
- // CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
- // CHECK-SAME: vector.transfer_read %{{.*}}[%[[c0]], %[[c0]]], %[[c42]]
- // CHECK-SAME: {in_bounds = [true, true]} : tensor<?x?xf32>, vector<2x4xf32>
+ // CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
+ // CHECK-SAME: vector.transfer_read %{{.*}}[%[[c0]], %[[c0]]], %[[c42]]
+ // CHECK-SAME: {in_bounds = [true, true]} : tensor<?x?xf32>, vector<2x4xf32>
// CHECK-SAME: } : vector<2x4xi1> -> vector<2x4xf32>
// CHECK: vector.transfer_write %[[masked_read]], %[[empty]][%[[c0]], %[[c0]]]
// CHECK-SAME: {in_bounds = [true, true]} : vector<2x4xf32>, tensor<2x4xf32>
@@ -2857,7 +2857,47 @@ func.func @test_masked_vectorize_pad(
transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
- %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+ %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
: (!pdl.operation) -> !pdl.operation
transform.structured.masked_vectorize %0 vector_sizes [2, 4]
}
+
+// -----
+
+func.func @vectorize_dynamic_matmul(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
+ linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
+ outs(%C: memref<?x?xf32>)
+ return
+}
+
+// CHECK-LABEL: func.func @vectorize_dynamic_matmul(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<?x?xf32>, %[[VAL_1:.*]]: memref<?x?xf32>, %[[VAL_2:.*]]: memref<?x?xf32>) {
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_4:.*]] = memref.dim %[[VAL_0]], %[[VAL_3]] : memref<?x?xf32>
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_6:.*]] = memref.dim %[[VAL_1]], %[[VAL_5]] : memref<?x?xf32>
+// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_8:.*]] = memref.dim %[[VAL_0]], %[[VAL_7]] : memref<?x?xf32>
+// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_11:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_8]] : vector<8x4xi1>
+// CHECK: %[[VAL_12:.*]] = vector.mask %[[VAL_11]] { vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_9]], %[[VAL_9]]], %[[VAL_10]] {in_bounds = [true, true, true], permutation_map = #map} : memref<?x?xf32>, vector<8x16x4xf32> } : vector<8x4xi1> -> vector<8x16x4xf32>
+// CHECK: %[[VAL_13:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_14:.*]] = vector.create_mask %[[VAL_8]], %[[VAL_6]] : vector<4x16xi1>
+// CHECK: %[[VAL_15:.*]] = vector.mask %[[VAL_14]] { vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_9]], %[[VAL_9]]], %[[VAL_13]] {in_bounds = [true, true, true], permutation_map = #map1} : memref<?x?xf32>, vector<8x16x4xf32> } : vector<4x16xi1> -> vector<8x16x4xf32>
+// CHECK: %[[VAL_16:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_17:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]] : vector<8x16xi1>
+// CHECK: %[[VAL_18:.*]] = vector.mask %[[VAL_17]] { vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_9]], %[[VAL_9]]], %[[VAL_16]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<8x16xf32> } : vector<8x16xi1> -> vector<8x16xf32>
+// CHECK: %[[VAL_19:.*]] = arith.mulf %[[VAL_12]], %[[VAL_15]] : vector<8x16x4xf32>
+// CHECK: %[[VAL_20:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]], %[[VAL_8]] : vector<8x16x4xi1>
+// CHECK: %[[VAL_21:.*]] = vector.mask %[[VAL_20]] { vector.multi_reduction <add>, %[[VAL_19]], %[[VAL_18]] [2] : vector<8x16x4xf32> to vector<8x16xf32> } : vector<8x16x4xi1> -> vector<8x16xf32>
+// CHECK: %[[VAL_22:.*]] = arith.constant 0 : index
+// CHECK: vector.mask %[[VAL_17]] { vector.transfer_write %[[VAL_21]], %[[VAL_2]]{{\[}}%[[VAL_22]], %[[VAL_22]]] {in_bounds = [true, true]} : vector<8x16xf32>, memref<?x?xf32> } : vector<8x16xi1>
+// CHECK: return
+// CHECK: }
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+ transform.structured.masked_vectorize %0 vector_sizes [8, 16, 4]
+}