diff options
author | Diego Caballero <diegocaballero@google.com> | 2023-04-21 18:36:47 +0000 |
---|---|---|
committer | Diego Caballero <diegocaballero@google.com> | 2023-04-21 19:19:01 +0000 |
commit | c8557c7c3e98dddc0abb28913e06c8dae60a6c01 (patch) | |
tree | e9b114acab7695b805666c5ce7d1da7b6fd3bfd5 | |
parent | fa2a8c2e1cdf188ce985d69fca6f78866390b715 (diff) | |
download | llvm-c8557c7c3e98dddc0abb28913e06c8dae60a6c01.zip llvm-c8557c7c3e98dddc0abb28913e06c8dae60a6c01.tar.gz llvm-c8557c7c3e98dddc0abb28913e06c8dae60a6c01.tar.bz2 |
[MLIR][Vector] Enable masked vectorization of contraction ops
This patch enables the vectorization of contraction ops using vector
masking. Support for vectorizing contractions is already there so this
is just adding contraction ops to the list of supported ops in
`vectorizeDynamicLinalgOpPrecondition` and adding a test.
Reviewed By: hanchung, awarzynski
Differential Revision: https://reviews.llvm.org/D148865
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp | 3 | ||||
-rw-r--r-- | mlir/test/Dialect/Linalg/vectorization.mlir | 50 |
2 files changed, 47 insertions, 6 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index c7776f4..55beb2f8 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -1303,7 +1303,8 @@ static LogicalResult reductionPreconditions(LinalgOp op) { static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) { // TODO: Masking only supports dynamic generic ops for now. - if (!isa<linalg::GenericOp, linalg::FillOp, linalg::CopyOp>(op)) + if (!isa<linalg::GenericOp, linalg::FillOp, linalg::CopyOp, + linalg::ContractionOpInterface>(op.getOperation())) return failure(); LDBG("Dynamically-shaped op meets vectorization pre-conditions\n"); diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir index 910c815..f2f6fed 100644 --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -2832,7 +2832,7 @@ transform.sequence failures(propagate) { // CHECK-LABEL: func @test_masked_vectorize_pad func.func @test_masked_vectorize_pad( - %0 : tensor<?x?xf32>, %h0 : index, %h1 : index) + %0 : tensor<?x?xf32>, %h0 : index, %h1 : index) -> tensor<2x4xf32> { // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index @@ -2841,9 +2841,9 @@ func.func @test_masked_vectorize_pad( // CHECK: %[[d0:.*]] = tensor.dim {{.*}} : tensor<?x?xf32> // CHECK: %[[d1:.*]] = tensor.dim {{.*}} : tensor<?x?xf32> // CHECK: %[[mask:.*]] = vector.create_mask %[[d0]], %[[d1]] : vector<2x4xi1> - // CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] { - // CHECK-SAME: vector.transfer_read %{{.*}}[%[[c0]], %[[c0]]], %[[c42]] - // CHECK-SAME: {in_bounds = [true, true]} : tensor<?x?xf32>, vector<2x4xf32> + // CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] { + // CHECK-SAME: vector.transfer_read %{{.*}}[%[[c0]], %[[c0]]], %[[c42]] + // CHECK-SAME: {in_bounds = [true, true]} : tensor<?x?xf32>, vector<2x4xf32> // 
CHECK-SAME: } : vector<2x4xi1> -> vector<2x4xf32> // CHECK: vector.transfer_write %[[masked_read]], %[[empty]][%[[c0]], %[[c0]]] // CHECK-SAME: {in_bounds = [true, true]} : vector<2x4xf32>, tensor<2x4xf32> @@ -2857,7 +2857,47 @@ func.func @test_masked_vectorize_pad( transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): - %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 + %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!pdl.operation) -> !pdl.operation transform.structured.masked_vectorize %0 vector_sizes [2, 4] } + +// ----- + +func.func @vectorize_dynamic_matmul(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) { + linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>) + outs(%C: memref<?x?xf32>) + return +} + +// CHECK-LABEL: func.func @vectorize_dynamic_matmul( +// CHECK-SAME: %[[VAL_0:.*]]: memref<?x?xf32>, %[[VAL_1:.*]]: memref<?x?xf32>, %[[VAL_2:.*]]: memref<?x?xf32>) { +// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[VAL_4:.*]] = memref.dim %[[VAL_0]], %[[VAL_3]] : memref<?x?xf32> +// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[VAL_6:.*]] = memref.dim %[[VAL_1]], %[[VAL_5]] : memref<?x?xf32> +// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[VAL_8:.*]] = memref.dim %[[VAL_0]], %[[VAL_7]] : memref<?x?xf32> +// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_11:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_8]] : vector<8x4xi1> +// CHECK: %[[VAL_12:.*]] = vector.mask %[[VAL_11]] { vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_9]], %[[VAL_9]]], %[[VAL_10]] {in_bounds = [true, true, true], permutation_map = #map} : memref<?x?xf32>, vector<8x16x4xf32> } : vector<8x4xi1> -> vector<8x16x4xf32> +// CHECK: %[[VAL_13:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_14:.*]] = vector.create_mask %[[VAL_8]], %[[VAL_6]] : vector<4x16xi1> +// 
CHECK: %[[VAL_15:.*]] = vector.mask %[[VAL_14]] { vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_9]], %[[VAL_9]]], %[[VAL_13]] {in_bounds = [true, true, true], permutation_map = #map1} : memref<?x?xf32>, vector<8x16x4xf32> } : vector<4x16xi1> -> vector<8x16x4xf32> +// CHECK: %[[VAL_16:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_17:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]] : vector<8x16xi1> +// CHECK: %[[VAL_18:.*]] = vector.mask %[[VAL_17]] { vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_9]], %[[VAL_9]]], %[[VAL_16]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<8x16xf32> } : vector<8x16xi1> -> vector<8x16xf32> +// CHECK: %[[VAL_19:.*]] = arith.mulf %[[VAL_12]], %[[VAL_15]] : vector<8x16x4xf32> +// CHECK: %[[VAL_20:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]], %[[VAL_8]] : vector<8x16x4xi1> +// CHECK: %[[VAL_21:.*]] = vector.mask %[[VAL_20]] { vector.multi_reduction <add>, %[[VAL_19]], %[[VAL_18]] [2] : vector<8x16x4xf32> to vector<8x16xf32> } : vector<8x16x4xi1> -> vector<8x16xf32> +// CHECK: %[[VAL_22:.*]] = arith.constant 0 : index +// CHECK: vector.mask %[[VAL_17]] { vector.transfer_write %[[VAL_21]], %[[VAL_2]]{{\[}}%[[VAL_22]], %[[VAL_22]]] {in_bounds = [true, true]} : vector<8x16xf32>, memref<?x?xf32> } : vector<8x16xi1> +// CHECK: return +// CHECK: } + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!pdl.operation) -> !pdl.operation + transform.structured.masked_vectorize %0 vector_sizes [8, 16, 4] +} |