diff options
-rw-r--r-- | mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp | 9 | ||||
-rw-r--r-- | mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir | 66 |
2 files changed, 75 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp index 84294e4..e1ed5d8 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -223,6 +223,9 @@ struct CastAwayTransferReadLeadingOneDim LogicalResult matchAndRewrite(vector::TransferReadOp read, PatternRewriter &rewriter) const override { + // TODO(#78787): Not supported masked op yet. + if (cast<MaskableOpInterface>(read.getOperation()).isMasked()) + return failure(); // TODO: support 0-d corner case. if (read.getTransferRank() == 0) return failure(); @@ -274,6 +277,9 @@ struct CastAwayTransferWriteLeadingOneDim LogicalResult matchAndRewrite(vector::TransferWriteOp write, PatternRewriter &rewriter) const override { + // TODO(#78787): Not supported masked op yet. + if (cast<MaskableOpInterface>(write.getOperation()).isMasked()) + return failure(); // TODO: support 0-d corner case. if (write.getTransferRank() == 0) return failure(); @@ -325,6 +331,9 @@ struct CastAwayTransferWriteLeadingOneDim LogicalResult mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, RewriterBase &rewriter) { + // TODO(#78787): Not supported masked op yet. + if (cast<MaskableOpInterface>(contractOp.getOperation()).isMasked()) + return failure(); VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType()); if (oldAccType == nullptr) return failure(); diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir index 71dffca..f601be0 100644 --- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir @@ -165,6 +165,37 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra } // ----- + +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> +// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + +// CHECK-LABEL: not_insert_cast_for_contraction_under_mask +// CHECK: %[[MASK:.+]] = vector.constant_mask +// CHECK: %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]] +// CHECK: %[[RET:.+]] = vector.mask %[[CASTED_MASK]] { +// CHECK-SAME: vector.contract {{.*}} : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32> } +// CHECK: return %[[RET]] : vector<1x16x16xf32> + +#contraction_accesses0 = [ + affine_map<(l, i, j, k) -> (l, i, k)>, + affine_map<(l, i, j, k) -> (l, k, j)>, + affine_map<(l, i, j, k) -> (l, i, j)> +] +#contraction_trait0 = { + indexing_maps = #contraction_accesses0, + iterator_types = ["parallel", "parallel", "parallel", "reduction"] +} + +func.func @not_insert_cast_for_contraction_under_mask(%arg0: vector<1x16x8xf32>, %arg1: vector<1x8x16xf32>, %arg2: vector<1x16x16xf32>) -> vector<1x16x16xf32> { + %mask = vector.constant_mask [1, 15, 15, 8] : vector<1x16x16x8xi1> + %0 = vector.mask %mask { + vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32> + } : vector<1x16x16x8xi1> -> vector<1x16x16xf32> + return %0 : vector<1x16x16xf32> +} + +// ----- // CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims func.func @cast_away_extract_strided_slice_leading_one_dims(%arg0: vector<1x8x8xf16>) -> vector<1x1x8xf16> { // CHECK: %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<8x8xf16> from vector<1x8x8xf16> @@ -253,6 +284,24 @@ func.func @cast_away_nontrivial_map_masked_transfer_read(%arg0: memref<1x4x8xf16 // ----- +// CHECK-LABEL: func @not_insert_cast_fo4_transfer_read_under_mask +// CHECK: %[[MASK:.+]] = vector.constant_mask +// CHECK: %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]] +// CHECK: %[[RET:.+]] = vector.mask %[[CASTED_MASK]] { +// CHECK-SAME: vector.transfer_read {{.*}} : memref<1x1x4xf16>, vector<1x4xf16> } +// CHECK: return %[[RET]] : vector<1x4xf16> +func.func @not_insert_cast_fo4_transfer_read_under_mask(%arg0: memref<1x1x4xf16>) -> vector<1x4xf16> { + %c0 = arith.constant 0 : index + %f0 = arith.constant 0. : f16 + %mask = vector.constant_mask [1, 3] : vector<1x4xi1> + %ret = vector.mask %mask { + vector.transfer_read %arg0[%c0, %c0, %c0], %f0 {in_bounds = [true, true]} : memref<1x1x4xf16>, vector<1x4xf16> + } : vector<1x4xi1> -> vector<1x4xf16> + return %ret: vector<1x4xf16> +} + +// ----- + // CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims func.func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>) { // CHECK: %[[C0:.+]] = arith.constant 0 : index @@ -286,6 +335,23 @@ func.func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1 // ----- +// CHECK-LABEL: func @not_insert_cast_for_transfer_write_under_mask +// CHECK: %[[MASK:.+]] = vector.constant_mask +// CHECK: %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]] +// CHECK: vector.mask %[[CASTED_MASK]] { +// CHECK-SAME: vector.transfer_write {{.*}} : vector<1x4xf16>, memref<1x1x4xf16> } +// CHECK: return +func.func @not_insert_cast_for_transfer_write_under_mask(%arg0: memref<1x1x4xf16>, %arg1: vector<1x4xf16>) { + %c0 = arith.constant 0 : index + %mask = vector.constant_mask [1, 3] : vector<1x4xi1> + vector.mask %mask { + vector.transfer_write %arg1, %arg0[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x4xf16>, memref<1x1x4xf16> + } : vector<1x4xi1> + return +} + +// ----- + // CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d1)> // CHECK-LABEL: func @cast_away_nontrivial_map_masked_transfer_write func.func @cast_away_nontrivial_map_masked_transfer_write(%arg0: memref<1x4x8xf16>, %arg1: vector<1x1x4xf16>, %arg2: vector<1x4x1xi1>) { |