diff options
author | Andrzej WarzyĆski <andrzej.warzynski@arm.com> | 2024-03-22 09:37:43 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-22 09:37:43 +0000 |
commit | 5f1b2cffe5fab0aa733fc8d5f1546c1c800faac4 (patch) | |
tree | 445850252858e318e984a43d7653a569049159bf /mlir | |
parent | c5f839bd58e7f888acc4cb39a18e9e5bbaa9fb0a (diff) | |
download | llvm-5f1b2cffe5fab0aa733fc8d5f1546c1c800faac4.zip llvm-5f1b2cffe5fab0aa733fc8d5f1546c1c800faac4.tar.gz llvm-5f1b2cffe5fab0aa733fc8d5f1546c1c800faac4.tar.bz2 |
[mlir][vector] Add support for masks in castAwayContractionLeadingOneDim (#81906)
Updates `castAwayContractionLeadingOneDim` to inherit from
`MaskableOpRewritePattern` so that this pattern can support masking.
Builds on top of #83827
Diffstat (limited to 'mlir')
4 files changed, 114 insertions, 56 deletions
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h index 08d3bb1..1f7d641 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h @@ -110,8 +110,10 @@ void transferOpflowOpt(RewriterBase &rewriter, Operation *rootOp); /// Cast away the leading unit dim, if exists, for the given contract op. /// Return success if the transformation applies; return failure otherwise. -LogicalResult castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, - RewriterBase &rewriter); +FailureOr<Value> +castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, + MaskingOpInterface maskingOp, + RewriterBase &rewriter); } // namespace vector } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h index 35e76a8..2c548fb 100644 --- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h +++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h @@ -127,8 +127,8 @@ SmallVector<OpFoldResult> getMixedSizesXfer(bool hasTensorSemantics, /// responsible for providing an updated ("rewritten") version of: /// a. the source Op when mask _is not_ present, /// b. the source Op and the masking Op when mask _is_ present. -/// Note that the return value from `matchAndRewriteMaskableOp` depends on the -/// case above. +/// To use this pattern, implement `matchAndRewriteMaskableOp`. Note that +/// the return value will depend on the case above. template <class SourceOp> struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> { using OpRewritePattern<SourceOp>::OpRewritePattern; @@ -162,9 +162,9 @@ private: } public: - // Matches SourceOp that can potentially be masked with `maskingOp`. If the - // latter is present, returns an updated masking op (with a replacement for - // `sourceOp` nested inside). Otherwise, returns an updated `sourceOp`. + // Matches `sourceOp` that can potentially be masked with `maskingOp`. If the + // latter is present, returns a replacement for `maskingOp`. Otherwise, + // returns a replacement for `sourceOp`. virtual FailureOr<Value> matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp, PatternRewriter &rewriter) const = 0; diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp index 74382b0..593c1e5 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -329,12 +329,10 @@ struct CastAwayTransferWriteLeadingOneDim } // namespace -LogicalResult +FailureOr<Value> mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, + MaskingOpInterface maskingOp, RewriterBase &rewriter) { - // TODO(#78787): Not supported masked op yet. - if (cast<MaskableOpInterface>(contractOp.getOperation()).isMasked()) - return failure(); VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType()); if (oldAccType == nullptr) return failure(); @@ -368,6 +366,7 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(), contractOp.getAcc()}; SmallVector<Value> newOperands; + auto loc = contractOp.getLoc(); for (const auto &it : llvm::enumerate(oldIndexingMaps)) { // Check if the dim to be dropped exists as a leading dim in the operand @@ -405,7 +404,7 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, map = AffineMap::get(map.getNumDims(), 0, transposeResults, contractOp.getContext()); operands[it.index()] = rewriter.create<vector::TransposeOp>( - contractOp.getLoc(), operands[it.index()], perm); + loc, operands[it.index()], perm); } } // We have taken care to have the dim to be dropped be @@ -429,18 +428,29 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, // Extract if its a valid extraction, otherwise use the operand // without extraction. newOperands.push_back( - validExtract ? rewriter.create<vector::ExtractOp>(contractOp.getLoc(), - operands[it.index()], - splatZero(dropDim)) + validExtract ? rewriter.create<vector::ExtractOp>( + loc, operands[it.index()], splatZero(dropDim)) : operands[it.index()]); } - auto newContractOp = rewriter.create<vector::ContractionOp>( - contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2], + + // Depending on whether this vector.contract is masked, the replacing Op + // should either be a new vector.contract Op or vector.mask Op. + Operation *newOp = rewriter.create<vector::ContractionOp>( + loc, newOperands[0], newOperands[1], newOperands[2], rewriter.getAffineMapArrayAttr(newIndexingMaps), rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind()); - rewriter.replaceOpWithNewOp<vector::BroadcastOp>( - contractOp, contractOp->getResultTypes()[0], newContractOp); - return success(); + + if (maskingOp) { + auto newMask = rewriter.create<vector::ExtractOp>(loc, maskingOp.getMask(), + splatZero(dropDim)); + + newOp = mlir::vector::maskOperation(rewriter, newOp, newMask); + } + + return rewriter + .create<vector::BroadcastOp>(loc, contractOp->getResultTypes()[0], + newOp->getResults()[0]) + .getResult(); } namespace { @@ -450,12 +460,14 @@ namespace { /// 1 dimensions. Also performs tranpose of lhs and rhs operands if required /// prior to extract. struct CastAwayContractionLeadingOneDim - : public OpRewritePattern<vector::ContractionOp> { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::ContractionOp contractOp, - PatternRewriter &rewriter) const override { - return castAwayContractionLeadingOneDim(contractOp, rewriter); + : public MaskableOpRewritePattern<vector::ContractionOp> { + using MaskableOpRewritePattern::MaskableOpRewritePattern; + + FailureOr<Value> + matchAndRewriteMaskableOp(vector::ContractionOp contractOp, + MaskingOpInterface maskingOp, + PatternRewriter &rewriter) const override { + return castAwayContractionLeadingOneDim(contractOp, maskingOp, rewriter); } }; diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir index af6e636..4ba51c5 100644 --- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir @@ -30,6 +30,80 @@ func.func @cast_away_contraction_leading_one_dims(%arg0: vector<1x16x8xf32>, %ar } // ----- +// CHECK: #[[$MAP_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)> +// CHECK: #[[$MAP_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)> +// CHECK: #[[$MAP_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK-LABEL: func.func @cast_away_contraction_leading_one_dim_under_const_mask +// CHECK: %[[MASK:.*]] = vector.constant_mask [15, 15, 8] : vector<16x16x8xi1> +// CHECK: %[[R0:.*]] = vector.extract %{{.*}}[0] : vector<16x8xf32> from vector<1x16x8xf32> +// CHECK: %[[R1:.*]] = vector.extract %{{.*}}[0] : vector<8x16xf32> from vector<1x8x16xf32> +// CHECK: %[[R2:.*]] = vector.extract %{{.*}}[0] : vector<16x16xf32> from vector<1x16x16xf32> +// CHECK: %[[CONTRACT:.*]] = vector.mask %[[MASK]] { +// CHECK-SAME: vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} +// CHECK-SAME: %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32> +// CHECK-SAME: } : vector<16x16x8xi1> -> vector<16x16xf32> +// CHECK: %[[RES:.*]] = vector.broadcast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32> +// CHECK: return %[[RES]] : vector<1x16x16xf32> + +#contraction_accesses0 = [ + affine_map<(l, i, j, k) -> (l, i, k)>, + affine_map<(l, i, j, k) -> (l, k, j)>, + affine_map<(l, i, j, k) -> (l, i, j)> +] +#contraction_trait0 = { + indexing_maps = #contraction_accesses0, + iterator_types = ["parallel", "parallel", "parallel", "reduction"] +} + +func.func @cast_away_contraction_leading_one_dim_under_const_mask(%arg0: vector<1x16x8xf32>, %arg1: vector<1x8x16xf32>, %arg2: vector<1x16x16xf32>) -> vector<1x16x16xf32> { + %mask = vector.constant_mask [1, 15, 15, 8] : vector<1x16x16x8xi1> + %0 = vector.mask %mask { + vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32> + } : vector<1x16x16x8xi1> -> vector<1x16x16xf32> + return %0 : vector<1x16x16xf32> +} + +// ----- +// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)> +// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)> +// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK-LABEL: func.func @cast_away_contraction_leading_one_dim_under_mask +// CHECK: %[[R0:.*]] = vector.extract %{{.*}} : vector<16x8xf32> from vector<1x16x8xf32> +// CHECK: %[[R1:.*]] = vector.extract %{{.*}} : vector<8x16xf32> from vector<1x8x16xf32> +// CHECK: %[[R2:.*]] = vector.extract %{{.*}} : vector<16x16xf32> from vector<1x16x16xf32> +// CHECK: %[[M:.*]] = vector.extract %{{.*}} : vector<16x16x8xi1> from vector<1x16x16x8xi1> +// CHECK: %[[CONTRACT:.*]] = vector.mask %[[M]] { +// CHECK-SAME: vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} +// CHECK-SAME: %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32> +// CHECK-SAME: } : vector<16x16x8xi1> -> vector<16x16xf32> +// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32> +// CHECK-NEXT: return %[[RES]] : vector<1x16x16xf32> + +#contraction_accesses0 = [ + affine_map<(l, i, j, k) -> (l, i, k)>, + affine_map<(l, i, j, k) -> (l, k, j)>, + affine_map<(l, i, j, k) -> (l, i, j)> +] +#contraction_trait0 = { + indexing_maps = #contraction_accesses0, + iterator_types = ["parallel", "parallel", "parallel", "reduction"] +} + +func.func @cast_away_contraction_leading_one_dim_under_mask( + %arg0: vector<1x16x8xf32>, + %arg1: vector<1x8x16xf32>, + %arg2: vector<1x16x16xf32>, + %mask: vector<1x16x16x8xi1>) -> vector<1x16x16xf32> { + %0 = vector.mask %mask { + vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32> + } : vector<1x16x16x8xi1> -> vector<1x16x16xf32> + return %0: vector<1x16x16xf32> +} + +// ----- + // CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1) -> (d1)> // CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1) -> (d1, d0)> // CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1) -> (d0)> @@ -164,36 +238,6 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra return %0: vector<1x1x2x16xf32> } -// ----- - -// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> -// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> -// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> - -// CHECK-LABEL: not_insert_cast_for_contraction_under_mask -// CHECK: %[[MASK:.+]] = vector.constant_mask -// CHECK: %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]] -// CHECK: %[[RET:.+]] = vector.mask %[[CASTED_MASK]] { -// CHECK-SAME: vector.contract {{.*}} : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32> } -// CHECK: return %[[RET]] : vector<1x16x16xf32> - -#contraction_accesses0 = [ - affine_map<(l, i, j, k) -> (l, i, k)>, - affine_map<(l, i, j, k) -> (l, k, j)>, - affine_map<(l, i, j, k) -> (l, i, j)> -] -#contraction_trait0 = { - indexing_maps = #contraction_accesses0, - iterator_types = ["parallel", "parallel", "parallel", "reduction"] -} - -func.func @not_insert_cast_for_contraction_under_mask(%arg0: vector<1x16x8xf32>, %arg1: vector<1x8x16xf32>, %arg2: vector<1x16x16xf32>) -> vector<1x16x16xf32> { - %mask = vector.constant_mask [1, 15, 15, 8] : vector<1x16x16x8xi1> - %0 = vector.mask %mask { - vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32> - } : vector<1x16x16x8xi1> -> vector<1x16x16xf32> - return %0 : vector<1x16x16xf32> -} // ----- // CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims |