diff options
author | Andrzej Warzyński <andrzej.warzynski@arm.com> | 2023-12-06 21:35:03 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-12-06 21:35:03 +0000 |
commit | 03c2f5d8bbcf31239a631d9343ac7f4b6b3094c1 (patch) | |
tree | 80775ae1467c72091b3b1a282733db24b2aa85ab /mlir | |
parent | 98ce2debc6ff3f6d31d7b63eb54e10e88a84ee78 (diff) | |
download | llvm-03c2f5d8bbcf31239a631d9343ac7f4b6b3094c1.zip llvm-03c2f5d8bbcf31239a631d9343ac7f4b6b3094c1.tar.gz llvm-03c2f5d8bbcf31239a631d9343ac7f4b6b3094c1.tar.bz2 |
[mlir][linalg][conv] Flatten the channel dimension when vectorizing (#71918)
The current vectorization of 1D depthwise convolutions in Linalg is
_sub-optimal_ for tensor with a low number of channel dimensions, e.g.:
```mlir
linalg.depthwise_conv_1d_nwc_wc
{dilations = dense<1> : vector<1xi64>,
strides = dense<1> : vector<1xi64>}
ins(%input, %filter : tensor<1x8x3xi8>, tensor<1x3xi8>)
outs(%output : tensor<1x8x3xi8>) -> tensor<1x8x3xi8>
```
That's due to the fact that ultimately (i.e. at LLVM level),
vectorization happens along the trailing dimension (i.e. the channel
dimension). In this case it leads to vectors with 3 elements (or worse,
if there's e.g. only 1 channel dimension). For comparison, a 128 bit
wide vector registers can hold 16 x i8.
Instead, this patch adds an option to flatten/collapse the channel
dimension into the width dimension of the input/filter/output using
`vector.shape_cast` operation:
```mlir
%sc_input = vector.shape_cast %input : vector<1x8x3xi8> to vector<1x24xi8>
%sc_output = vector.shape_cast %output : vector<1x8x3xi8> to vector<1x24xi8>
%b_filter = vector.broadcast %filter : vector<3xi8> to vector<1x8x3xi8>
%sc_filter = vector.shape_cast %b_filter : vector<1x8x3xi8> to vector<1x24xi8>
```
This new vectorization mode is implemented in `depthwiseConv` by
inserting `vector.shape_cast` Ops before and after
`depthwiseConv1dSliceAsMulAcc` is invoked. It can be selected through
e.g. a transform dialect attribute:
```mlir
transform.structured.vectorize_children_and_apply_patterns %conv {flatten_1d_depthwise_conv}
```
A forthcoming patch will implement a strategy to automatically switch
between the two implementations, depending on the shape of the input
tensors.
Co-authored-by: Bradley Smith <bradley.smith@arm.com>
Diffstat (limited to 'mlir')
5 files changed, 388 insertions, 29 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 002926f..de65f31 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -2038,6 +2038,7 @@ def VectorizeChildrenAndApplyPatternsOp : let arguments = (ins TransformHandleTypeInterface:$target, UnitAttr:$vectorize_padding, UnitAttr:$vectorize_nd_extract, + UnitAttr:$flatten_1d_depthwise_conv, UnitAttr:$disable_multi_reduction_to_contract_patterns, UnitAttr:$disable_transfer_permutation_map_lowering_patterns); let results = (outs TransformHandleTypeInterface:$transformed); @@ -2049,7 +2050,8 @@ def VectorizeChildrenAndApplyPatternsOp : let builders = [ OpBuilder<(ins "Value":$target, CArg<"bool", "false">:$vectorizePadding, - CArg<"bool", "false">:$vectorizeNDExtract)>, + CArg<"bool", "false">:$vectorizeNDExtract, + CArg<"bool", "false">:$flatten1DDepthwise)> ]; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 6c4e16b..3f4dfe4 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -753,7 +753,8 @@ LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value /*buffer*/); LogicalResult vectorize(RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes = {}, ArrayRef<bool> inputScalableVecDims = {}, - bool vectorizeNDExtract = false); + bool vectorizeNDExtract = false, + bool flatten1DDepthwiseConv = false); /// Emit a suitable vector form for a Copy op with fully static shape. 
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp); diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 18ee36e..e371345 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -2946,7 +2946,7 @@ LogicalResult TileUsingForallOp::verify() { void transform::VectorizeChildrenAndApplyPatternsOp::build( OpBuilder &builder, OperationState &result, Value target, - bool vectorizePadding, bool vectorizeExtract) { + bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) { result.addOperands(target); if (vectorizePadding) { result.addAttribute( @@ -2960,6 +2960,12 @@ void transform::VectorizeChildrenAndApplyPatternsOp::build( result.name), builder.getUnitAttr()); } + if (flatten1DDepthwiseConv) { + result.addAttribute( + VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName( + result.name), + builder.getUnitAttr()); + } result.addTypes(transform::AnyOpType::get(builder.getContext())); } @@ -2968,22 +2974,29 @@ namespace { /// VectorizeChildrenAndApplyPatternsOp::applyToOne. 
struct VectorizationPattern : public RewritePattern { explicit VectorizationPattern(MLIRContext *context, - bool vectorizeExtract = false) + bool vectorizeExtract = false, + bool flattenConv = false) : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context), - vectorizeNDExtract(vectorizeExtract) {} + vectorizeNDExtract(vectorizeExtract), + flatten1DDepthwiseConv(flattenConv) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { LinalgOp linalgOp = dyn_cast<LinalgOp>(op); if (!linalgOp) return rewriter.notifyMatchFailure(op, "expected Linalg Op"); return vectorize(rewriter, linalgOp, /*inputVectorSizes=*/{}, - /*scalableVecDims=*/{}, vectorizeNDExtract); + /*scalableVecDims=*/{}, vectorizeNDExtract, + flatten1DDepthwiseConv); } private: /// Controls whether to vectorize `tensor.extract` when the input tensor is /// rank >= 2. bool vectorizeNDExtract = false; + /// Controls whether to "flatten" the channel dimension when vectorising 1D + /// depthwise convolutions. This should lead to better vectorization for + /// tensors with a low number of channel dimensions. 
+ bool flatten1DDepthwiseConv = false; }; } // namespace @@ -3000,7 +3013,8 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne( MLIRContext *ctx = getContext(); RewritePatternSet patterns(ctx); - patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract()); + patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(), + getFlatten_1dDepthwiseConv()); if (!getDisableTransferPermutationMapLoweringPatterns()) vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index f9a53a8..c21d007 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -44,8 +44,9 @@ using namespace mlir::linalg; #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") /// Try to vectorize `convOp` as a convolution. -static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter, - LinalgOp convOp); +static FailureOr<Operation *> +vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp, + bool flatten1DDepthwiseConv = false); /// Return the unique instance of OpType in `block` if it is indeed unique. /// Return null if none or more than 1 instances exist. @@ -1664,7 +1665,8 @@ static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) { LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes, ArrayRef<bool> inputScalableVecDims, - bool vectorizeNDExtract) { + bool vectorizeNDExtract, + bool flatten1DDepthwiseConv) { LDBG("Attempting to vectorize:\n" << *op << "\n"); LDBG("Input vector sizes: "); LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs())); @@ -1696,8 +1698,8 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op, // TODO: isaConvolutionOpInterface that can also infer from generic // features. Will require stride/dilation attributes inference. 
if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) { - FailureOr<Operation *> convOr = - vectorizeConvolution(rewriter, linalgOp); + FailureOr<Operation *> convOr = vectorizeConvolution( + rewriter, linalgOp, flatten1DDepthwiseConv); if (succeeded(convOr)) { llvm::append_range(results, (*convOr)->getResults()); return success(); @@ -2822,7 +2824,7 @@ struct Conv1DGenerator /// kw is always unrolled. /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is /// > 1. - FailureOr<Operation *> depthwiseConv() { + FailureOr<Operation *> depthwiseConv(bool flatten) { if (!valid) return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv"); @@ -2869,6 +2871,9 @@ struct Conv1DGenerator //===------------------------------------------------------------------===// // Unroll along kw and read slices of lhs and rhs. SmallVector<Value> lhsVals, rhsVals, resVals; + auto inOutSliceSizes = SmallVector<int64_t>{nSize, wSizeStep, cSize}; + auto inOutStrides = SmallVector<int64_t>{1, 1, 1}; + // Extract lhs slice of size {n, wSizeStep, c} // @ [0, sw * w + dw * kw, 0]. for (int64_t kw = 0; kw < kwSize; ++kw) { @@ -2876,8 +2881,7 @@ struct Conv1DGenerator lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>( loc, lhs, /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0}, - /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize}, - /*strides=*/ArrayRef<int64_t>{1, 1, 1})); + inOutSliceSizes, inOutStrides)); } } // Extract rhs slice of size {c} @ [kw]. 
@@ -2889,21 +2893,39 @@ struct Conv1DGenerator for (int64_t w = 0; w < wSize; w += wSizeStep) { resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>( loc, res, - /*offsets=*/ArrayRef<int64_t>{0, w, 0}, - /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize}, - /*strides=*/ArrayRef<int64_t>{1, 1, 1})); + /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes, + inOutStrides)); } auto linearIndex = [&](int64_t kw, int64_t w) { return kw * (wSize / wSizeStep) + w; }; + auto inOutFlattenSliceSizes = + SmallVector<int64_t>{nSize, wSizeStep * cSize}; + auto lhsCastType = VectorType::get(inOutFlattenSliceSizes, lhsEltType); + auto resCastType = VectorType::get(inOutFlattenSliceSizes, resEltType); // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c} for (int64_t kw = 0; kw < kwSize; ++kw) { for (int64_t w = 0; w < wSize; w += wSizeStep) { - resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, - lhsVals[linearIndex(kw, w)], - rhsVals[kw], resVals[w]); + Value lhsVal = lhsVals[linearIndex(kw, w)]; + Value resVal = resVals[w]; + ShapedType filterBCastTy = cast<ShapedType>(resVal.getType()); + if (flatten) { + // Flatten the input and filter vectors (collapse the channel + // dimension) + lhsVal = rewriter.create<vector::ShapeCastOp>( + loc, lhsCastType, lhsVals[linearIndex(kw, w)]); + resVal = rewriter.create<vector::ShapeCastOp>(loc, resCastType, + resVals[w]); + } + resVals[w] = depthwiseConv1dSliceAsMulAcc( + rewriter, loc, lhsVal, rhsVals[kw], resVal, filterBCastTy, flatten); + if (flatten) { + // Un-flatten the output vector (restore the channel dimension) + resVals[w] = rewriter.create<vector::ShapeCastOp>( + loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]); + } } } @@ -2936,9 +2958,13 @@ struct Conv1DGenerator .getOperation(); } - /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to MulAcc + /// Lower: + /// * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false) + /// * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = 
true) + /// to MulAcc. Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc, - Value lhs, Value rhs, Value res) { + Value lhs, Value rhs, Value res, + ShapedType bcastTy, bool flatten) { auto rhsTy = cast<ShapedType>(rhs.getType()); auto resTy = cast<ShapedType>(res.getType()); @@ -2946,7 +2972,13 @@ struct Conv1DGenerator lhs = promote(rewriter, loc, lhs, resTy); rhs = rewriter.create<vector::BroadcastOp>( - loc, resTy.clone(rhsTy.getElementType()), rhs); + loc, bcastTy.clone(rhsTy.getElementType()), rhs); + if (flatten) { + // Flatten the channel dimension + rhs = rewriter.create<vector::ShapeCastOp>( + loc, resTy.clone(rhsTy.getElementType()), rhs); + } + rhs = promote(rewriter, loc, rhs, resTy); if (!lhs || !rhs) @@ -3049,7 +3081,7 @@ struct Conv1DGenerator /// Entry point that transposes into the common form: /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}} - FailureOr<Operation *> generateDilatedConv() { + FailureOr<Operation *> generateDilatedConv(bool flatten = false) { AffineExpr n, w, c, kw; bindDims(ctx, n, w, c, kw); if (!iters({Par(), Par(), Par(), Red()})) @@ -3060,7 +3092,7 @@ struct Conv1DGenerator if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c}, /*rhsIndex*/ {kw, c}, /*resIndex*/ {n, w, c}})) - return depthwiseConv(); + return depthwiseConv(flatten); return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout"); } @@ -3125,8 +3157,9 @@ private: /// Helper function to vectorize a LinalgOp with convolution semantics. // TODO: extend the generic vectorization to support windows and drop this. -static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter, - LinalgOp op) { +static FailureOr<Operation *> +vectorizeConvolution(RewriterBase &rewriter, LinalgOp op, + bool flatten1DDepthwiseConv) { // The ConvolutionOpInterface gives us guarantees of existence for // strides/dilations. 
However, we do not need to rely on those, we can simply // use them if present, otherwise use the default and let the generic conv. @@ -3151,7 +3184,7 @@ static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter, res = e.generateNcwPooling(); if (succeeded(res)) return res; - return e.generateDilatedConv(); + return e.generateDilatedConv(flatten1DDepthwiseConv); } struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> { diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir new file mode 100644 index 0000000..a242d09 --- /dev/null +++ b/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir @@ -0,0 +1,309 @@ +// RUN: mlir-opt -split-input-file -transform-interpreter %s | FileCheck %s + +func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor(%input: tensor<1x8x3xi8>, + %filter: tensor<1x3xi8>, + %output: tensor<1x8x3xi8>) -> (tensor<1x8x3xi8>) { + %res = linalg.depthwise_conv_1d_nwc_wc + {dilations = dense<1> : vector<1xi64>, + strides = dense<1> : vector<1xi64>} + ins(%input, %filter : tensor<1x8x3xi8>, tensor<1x3xi8>) + outs(%output : tensor<1x8x3xi8>) -> tensor<1x8x3xi8> + return %res : tensor<1x8x3xi8> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op + transform.yield + } +} +// CHECK-LABEL: func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor +// CHECK-SAME: %[[INPUT:.*]]: tensor<1x8x3xi8>, +// CHECK-SAME: %[[FILTER:.*]]: tensor<1x3xi8>, +// CHECK-SAME: %[[OUTPUT:.*]]: tensor<1x8x3xi8>) -> 
tensor<1x8x3xi8> { + +// CHECK-DAG: %[[C0_IDX:.*]] = arith.constant 0 : index + +/// Read the whole data in one shot. +// CHECK: %[[V_INPUT_R:.*]] = vector.transfer_read %[[INPUT]][%[[C0_IDX]], %[[C0_IDX]], %[[C0_IDX]]] +// CHECK: %[[V_FILTER_R:.*]] = vector.transfer_read %[[FILTER]][%[[C0_IDX]], %[[C0_IDX]]] +// CHECK: %[[V_OUTPUT_R:.*]] = vector.transfer_read %[[OUTPUT]][%[[C0_IDX]], %[[C0_IDX]], %[[C0_IDX]]] + +// CHECK: %[[V_FILTER_0:.*]] = vector.extract %[[V_FILTER_R]][0] : vector<3xi8> from vector<1x3xi8> + +/// w == 0, kw = 0 +// CHECK: %[[SC_INPUT:.*]] = vector.shape_cast %[[V_INPUT_R]] : vector<1x8x3xi8> to vector<1x24xi8> +// CHECK: %[[SC_OUTPUT:.*]] = vector.shape_cast %[[V_OUTPUT_R]] : vector<1x8x3xi8> to vector<1x24xi8> +// CHECK: %[[B_FILTER:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<3xi8> to vector<1x8x3xi8> +// CHECK: %[[SC_FILTER:.*]] = vector.shape_cast %[[B_FILTER]] : vector<1x8x3xi8> to vector<1x24xi8> +// CHECK: %[[MULI:.*]] = arith.muli %[[SC_INPUT]], %[[SC_FILTER]] : vector<1x24xi8> +// CHECK: %[[ADDI:.*]] = arith.addi %[[MULI]], %[[SC_OUTPUT]] : vector<1x24xi8> + +// Write the result back in one shot. 
+// CHECK: %[[SC_ADDI:.*]] = vector.shape_cast %[[ADDI]] : vector<1x24xi8> to vector<1x8x3xi8> +// CHECK: vector.transfer_write %[[SC_ADDI]], %[[OUTPUT]][%[[C0_IDX]], %[[C0_IDX]], %[[C0_IDX]]] + +//------ + +func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2(%input: memref<3x5x4xf32>, + %filter: memref<2x4xf32>, + %output: memref<3x2x4xf32>) { + linalg.depthwise_conv_1d_nwc_wc + {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + ins(%input, %filter : memref<3x5x4xf32>, memref<2x4xf32>) + outs(%output : memref<3x2x4xf32>) + return +} + +// CHECK: func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2 +// CHECK-SAME: (%[[INPUT:[0-9a-z]+]]: memref<3x5x4xf32>, %[[FILTER:[0-9a-z]+]]: memref<2x4xf32>, %[[OUTPUT:[0-9a-z]+]]: memref<3x2x4xf32>) + +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32 + +/// Read the whole data in one shot. +// CHECK-DAG: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]] +// CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]] +// CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]] + +// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]] +// CHECK-SAME: {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32> +// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]] +// CHECK-SAME: {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32> + +// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<4xf32> from vector<2x4xf32> +// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<4xf32> from vector<2x4xf32> + + +/// w == 0, kw = 0 +// CHECK: %[[SC_V_INPUT_0:.*]] = vector.shape_cast %[[V_INPUT_0]] : vector<3x2x4xf32> to vector<3x8xf32> +// CHECK: %[[SC_V_OUTPUT_R:.*]] = 
vector.shape_cast %[[V_OUTPUT_R]] : vector<3x2x4xf32> to vector<3x8xf32> +// CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xf32> to vector<3x2x4xf32> +// CHECK: %[[SC_B_FILTER_0:.*]] = vector.shape_cast %[[B_FILTER_0]] : vector<3x2x4xf32> to vector<3x8xf32> +// CHECK: %[[FMA_0:.*]] = vector.fma %[[SC_V_INPUT_0]], %[[SC_B_FILTER_0]], %[[SC_V_OUTPUT_R]] : vector<3x8xf32> + +/// w == 0, kw = 1 +// CHECK: %[[SC_V_INPUT_1:.*]] = vector.shape_cast %[[V_INPUT_1]] : vector<3x2x4xf32> to vector<3x8xf32> +// CHECK: %[[B_V_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xf32> to vector<3x2x4xf32> +// CHECK: %[[SC_B_FILTER_1:.*]] = vector.shape_cast %[[B_V_FILTER_1]] : vector<3x2x4xf32> to vector<3x8xf32> +// CHECK: %[[FMA_1:.*]] = vector.fma %[[SC_V_INPUT_1]], %[[SC_B_FILTER_1]], %[[FMA_0]] : vector<3x8xf32> + +// Write the result back in one shot. +// CHECK: %[[SC_FMA_1:.*]] = vector.shape_cast %[[FMA_1]] : vector<3x8xf32> to vector<3x2x4xf32> +// CHECK: vector.transfer_write %[[SC_FMA_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]] + + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +func.func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref_dilation_2(%input: memref<3x5x4xi8>, + %filter: memref<2x4xi8>, + %output: memref<3x2x4xi32>) { + linalg.depthwise_conv_1d_nwc_wc + {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + ins(%input, %filter : memref<3x5x4xi8>, memref<2x4xi8>) + outs(%output : memref<3x2x4xi32>) + return +} + +// 
CHECK: func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref_dilation_2 +// CHECK-SAME: (%[[INPUT:[0-9a-z]+]]: memref<3x5x4xi8>, %[[FILTER:[0-9a-z]+]]: memref<2x4xi8>, %[[OUTPUT:[0-9a-z]+]]: memref<3x2x4xi32>) + +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index + +/// Read the whole data in one shot. +// CHECK-DAG: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]] +// CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]] +// CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]] + +// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]] +// CHECK-SAME: {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xi8> to vector<3x2x4xi8> +// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]] +// CHECK-SAME: {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xi8> to vector<3x2x4xi8> + +// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<4xi8> from vector<2x4xi8> +// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<4xi8> from vector<2x4xi8> + +/// w == 0, kw = 0 +// CHECK: %[[SC_V_INPUT_0:.*]] = vector.shape_cast %[[V_INPUT_0]] : vector<3x2x4xi8> to vector<3x8xi8> +// CHECK: %[[SC_V_OUTPUT_R:.*]] = vector.shape_cast %[[V_OUTPUT_R]] : vector<3x2x4xi32> to vector<3x8xi32> +// CHECK: %[[EXT_INPUT_0:.*]] = arith.extsi %[[SC_V_INPUT_0]] : vector<3x8xi8> to vector<3x8xi32> +// CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x2x4xi8> +// CHECK: %[[SC_B_FILTER_0:.*]] = vector.shape_cast %[[B_FILTER_0]] : vector<3x2x4xi8> to vector<3x8xi8> +// CHECK: %[[EXT_FILTER_0:.*]] = arith.extsi %[[SC_B_FILTER_0]] : vector<3x8xi8> to vector<3x8xi32> +// CHECK: %[[MUL_0:.*]] = arith.muli %[[EXT_INPUT_0]], %[[EXT_FILTER_0]] : vector<3x8xi32> +// CHECK: %[[ADD_0:.*]] = arith.addi %[[MUL_0]], %[[SC_V_OUTPUT_R]] : vector<3x8xi32> + +/// w == 0, 
kw = 1 +// CHECK: %[[SC_V_INPUT_1:.*]] = vector.shape_cast %[[V_INPUT_1]] : vector<3x2x4xi8> to vector<3x8xi8> +// CHECK: %[[EXT_INPUT_1:.*]] = arith.extsi %[[SC_V_INPUT_1]] : vector<3x8xi8> to vector<3x8xi32> +// CHECK: %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x2x4xi8> +// CHECK: %[[SC_B_FILTER_1:.*]] = vector.shape_cast %[[B_FILTER_1]] : vector<3x2x4xi8> to vector<3x8xi8> +// CHECK: %[[EXT_FILTER_1:.*]] = arith.extsi %[[SC_B_FILTER_1]] : vector<3x8xi8> to vector<3x8xi32> +// CHECK: %[[MUL_1:.*]] = arith.muli %[[EXT_INPUT_1]], %[[EXT_FILTER_1]] : vector<3x8xi32> +// CHECK: %[[ADD_1:.*]] = arith.addi %[[MUL_1]], %[[ADD_0]] : vector<3x8xi32> + +// Write the result back in one shot. +// CHECK: %[[SC_ADD_1:.*]] = vector.shape_cast %[[ADD_1]] : vector<3x8xi32> to vector<3x2x4xi32> +// CHECK: vector.transfer_write %[[SC_ADD_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]] + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +func.func @depthwise_conv1d_nwc_wc_3x9x4xi8_tensor_stride_2(%input: tensor<3x9x4xi8>, + %filter: tensor<3x4xi8>, + %output: tensor<3x3x4xi8>) -> tensor<3x3x4xi8> { + %res = linalg.depthwise_conv_1d_nwc_wc + {dilations = dense<1> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} + ins(%input, %filter : tensor<3x9x4xi8>, tensor<3x4xi8>) + outs(%output : tensor<3x3x4xi8>) -> tensor<3x3x4xi8> + return %res : tensor<3x3x4xi8> +} +// CHECK-LABEL: func.func @depthwise_conv1d_nwc_wc_3x9x4xi8_tensor_stride_2 +// CHECK-SAME: 
%[[INPUT:.*]]: tensor<3x9x4xi8>, +// CHECK-SAME: %[[FILTER:.*]]: tensor<3x4xi8>, +// CHECK-SAME: %[[OUTPUT:.*]]: tensor<3x3x4xi8>) -> tensor<3x3x4xi8> { + +// CHECK-DAG: %[[C0_IDX:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C0_I8:.*]] = arith.constant 0 : i8 + +/// Read the whole data in one shot. +// CHECK: %[[V_INPUT_R:.*]] = vector.transfer_read %[[INPUT]][%[[C0_IDX]], %[[C0_IDX]], %[[C0_IDX]]], %[[C0_I8]] +// CHECK: %[[V_FILTER_R:.*]] = vector.transfer_read %[[FILTER]][%[[C0_IDX]], %[[C0_IDX]]], %[[C0_I8]] +// CHECK: %[[V_OUTPUT_R:.*]] = vector.transfer_read %[[OUTPUT]][%[[C0_IDX]], %[[C0_IDX]], %[[C0_IDX]]], %[[C0_I8]] + +// CHECK: %[[V_INPUT_0:.*]] = vector.extract_strided_slice %[[V_INPUT_R]] +// CHECK-SAME: {offsets = [0, 0, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x7x4xi8> to vector<3x1x4xi8> +// CHECK: %[[V_INPUT_1:.*]] = vector.extract_strided_slice %[[V_INPUT_R]] +// CHECK-SAME: {offsets = [0, 2, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x7x4xi8> to vector<3x1x4xi8> +// CHECK: %[[V_INPUT_2:.*]] = vector.extract_strided_slice %[[V_INPUT_R]] +// CHECK-SAME: {offsets = [0, 4, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x7x4xi8> to vector<3x1x4xi8> +// CHECK: %[[V_INPUT_3:.*]] = vector.extract_strided_slice %[[V_INPUT_R]] +// CHECK-SAME: {offsets = [0, 1, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x7x4xi8> to vector<3x1x4xi8> +// CHECK: %[[V_INPUT_4:.*]] = vector.extract_strided_slice %[[V_INPUT_R]] +// CHECK-SAME: {offsets = [0, 3, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x7x4xi8> to vector<3x1x4xi8> +// CHECK: %[[V_INPUT_5:.*]] = vector.extract_strided_slice %[[V_INPUT_R]] +// CHECK-SAME: {offsets = [0, 5, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x7x4xi8> to vector<3x1x4xi8> +// CHECK: %[[V_INPUT_6:.*]] = vector.extract_strided_slice %[[V_INPUT_R]] +// CHECK-SAME: {offsets = [0, 2, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x7x4xi8> to vector<3x1x4xi8> +// CHECK: 
%[[V_INPUT_7:.*]] = vector.extract_strided_slice %[[V_INPUT_R]] +// CHECK-SAME: {offsets = [0, 4, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x7x4xi8> to vector<3x1x4xi8> +// CHECK: %[[V_INPUT_8:.*]] = vector.extract_strided_slice %[[V_INPUT_R]] +// CHECK-SAME: {offsets = [0, 6, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x7x4xi8> to vector<3x1x4xi8> + +// CHECK: %[[V_FILTER_0:.*]] = vector.extract %[[V_FILTER_R]][0] : vector<4xi8> from vector<3x4xi8> +// CHECK: %[[V_FILTER_1:.*]] = vector.extract %[[V_FILTER_R]][1] : vector<4xi8> from vector<3x4xi8> +// CHECK: %[[V_FILTER_2:.*]] = vector.extract %[[V_FILTER_R]][2] : vector<4xi8> from vector<3x4xi8> + +// CHECK: %[[V_OUTPUT_0:.*]] = vector.extract_strided_slice %[[V_OUTPUT_R]] +// CHECK-SAME: {offsets = [0, 0, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x3x4xi8> to vector<3x1x4xi8> +// CHECK: %[[V_OUTPUT_1:.*]] = vector.extract_strided_slice %[[V_OUTPUT_R]] +// CHECK-SAME: {offsets = [0, 1, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x3x4xi8> to vector<3x1x4xi8> +// CHECK: %[[V_OUTPUT_2:.*]] = vector.extract_strided_slice %[[V_OUTPUT_R]] +// CHECK-SAME: {offsets = [0, 2, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x3x4xi8> to vector<3x1x4xi8> + +/// w == 0, kw == 0 +// CHECK: %[[VAL_23:.*]] = vector.shape_cast %[[V_INPUT_0]] : vector<3x1x4xi8> to vector<3x4xi8> +// CHECK: %[[VAL_24:.*]] = vector.shape_cast %[[V_OUTPUT_0]] : vector<3x1x4xi8> to vector<3x4xi8> +// CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x1x4xi8> +// CHECK: %[[VAL_26:.*]] = vector.shape_cast %[[B_FILTER_0]] : vector<3x1x4xi8> to vector<3x4xi8> +// CHECK: %[[VAL_27:.*]] = arith.muli %[[VAL_23]], %[[VAL_26]] : vector<3x4xi8> +// CHECK: %[[VAL_28:.*]] = arith.addi %[[VAL_27]], %[[VAL_24]] : vector<3x4xi8> + +/// w == 1, kw == 0 +// CHECK: %[[VAL_29:.*]] = vector.shape_cast %[[V_INPUT_1]] : vector<3x1x4xi8> to vector<3x4xi8> +// CHECK: %[[VAL_30:.*]] = 
vector.shape_cast %[[V_OUTPUT_1]] : vector<3x1x4xi8> to vector<3x4xi8> +// CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x1x4xi8> +// CHECK: %[[VAL_32:.*]] = vector.shape_cast %[[B_FILTER_0]] : vector<3x1x4xi8> to vector<3x4xi8> +// CHECK: %[[VAL_33:.*]] = arith.muli %[[VAL_29]], %[[VAL_32]] : vector<3x4xi8> +// CHECK: %[[VAL_34:.*]] = arith.addi %[[VAL_33]], %[[VAL_30]] : vector<3x4xi8> + +/// w == 2, kw == 0 +// CHECK: %[[VAL_35:.*]] = vector.shape_cast %[[V_INPUT_2]] : vector<3x1x4xi8> to vector<3x4xi8> +// CHECK: %[[VAL_36:.*]] = vector.shape_cast %[[V_OUTPUT_2]] : vector<3x1x4xi8> to vector<3x4xi8> +// CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x1x4xi8> +// CHECK: %[[VAL_38:.*]] = vector.shape_cast %[[B_FILTER_0]] : vector<3x1x4xi8> to vector<3x4xi8> +// CHECK: %[[VAL_39:.*]] = arith.muli %[[VAL_35]], %[[VAL_38]] : vector<3x4xi8> +// CHECK: %[[VAL_40:.*]] = arith.addi %[[VAL_39]], %[[VAL_36]] : vector<3x4xi8> + +/// w == 3, kw == 1 +// CHECK: %[[VAL_41:.*]] = vector.shape_cast %[[V_INPUT_3]] : vector<3x1x4xi8> to vector<3x4xi8> +// CHECK: %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x1x4xi8> +// CHECK: %[[VAL_43:.*]] = vector.shape_cast %[[B_FILTER_1]] : vector<3x1x4xi8> to vector<3x4xi8> +// CHECK: %[[VAL_44:.*]] = arith.muli %[[VAL_41]], %[[VAL_43]] : vector<3x4xi8> +// CHECK: %[[VAL_45:.*]] = arith.addi %[[VAL_44]], %[[VAL_28]] : vector<3x4xi8> + +/// w == 4, kw == 1 +// CHECK: %[[VAL_46:.*]] = vector.shape_cast %[[V_INPUT_4]] : vector<3x1x4xi8> to vector<3x4xi8> +// CHECK: %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x1x4xi8> +// CHECK: %[[VAL_48:.*]] = vector.shape_cast %[[B_FILTER_1]] : vector<3x1x4xi8> to vector<3x4xi8> +// CHECK: %[[VAL_49:.*]] = arith.muli %[[VAL_46]], %[[VAL_48]] : vector<3x4xi8> +// CHECK: %[[VAL_50:.*]] = arith.addi %[[VAL_49]], %[[VAL_34]] : vector<3x4xi8> + +/// w == 5, kw == 1 
+// CHECK: %[[VAL_51:.*]] = vector.shape_cast %[[V_INPUT_5]] : vector<3x1x4xi8> to vector<3x4xi8> +// CHECK: %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x1x4xi8> +// CHECK: %[[VAL_53:.*]] = vector.shape_cast %[[B_FILTER_1]] : vector<3x1x4xi8> to vector<3x4xi8> +// CHECK: %[[VAL_54:.*]] = arith.muli %[[VAL_51]], %[[VAL_53]] : vector<3x4xi8> +// CHECK: %[[VAL_55:.*]] = arith.addi %[[VAL_54]], %[[VAL_40]] : vector<3x4xi8> + +/// w == 6, kw == 2 +// CHECK: %[[VAL_56:.*]] = vector.shape_cast %[[V_INPUT_6]] : vector<3x1x4xi8> to vector<3x4xi8> +// CHECK: %[[B_FILTER_2:.*]] = vector.broadcast %[[V_FILTER_2]] : vector<4xi8> to vector<3x1x4xi8> +// CHECK: %[[VAL_58:.*]] = vector.shape_cast %[[B_FILTER_2]] : vector<3x1x4xi8> to vector<3x4xi8> +// CHECK: %[[VAL_59:.*]] = arith.muli %[[VAL_56]], %[[VAL_58]] : vector<3x4xi8> +// CHECK: %[[VAL_60:.*]] = arith.addi %[[VAL_59]], %[[VAL_45]] : vector<3x4xi8> + +/// w == 7, kw == 2 +// CHECK: %[[VAL_61:.*]] = vector.shape_cast %[[VAL_60]] : vector<3x4xi8> to vector<3x1x4xi8> +// CHECK: %[[VAL_62:.*]] = vector.shape_cast %[[V_INPUT_7]] : vector<3x1x4xi8> to vector<3x4xi8> +// CHECK: %[[B_FILTER_2:.*]] = vector.broadcast %[[V_FILTER_2]] : vector<4xi8> to vector<3x1x4xi8> +// CHECK: %[[VAL_64:.*]] = vector.shape_cast %[[B_FILTER_2]] : vector<3x1x4xi8> to vector<3x4xi8> +// CHECK: %[[VAL_65:.*]] = arith.muli %[[VAL_62]], %[[VAL_64]] : vector<3x4xi8> +// CHECK: %[[VAL_66:.*]] = arith.addi %[[VAL_65]], %[[VAL_50]] : vector<3x4xi8> + +/// w == 8, kw == 2 +// CHECK: %[[VAL_67:.*]] = vector.shape_cast %[[VAL_66]] : vector<3x4xi8> to vector<3x1x4xi8> +// CHECK: %[[VAL_68:.*]] = vector.shape_cast %[[V_INPUT_8]] : vector<3x1x4xi8> to vector<3x4xi8> +// CHECK: %[[B_FILTER_2:.*]] = vector.broadcast %[[V_FILTER_2]] : vector<4xi8> to vector<3x1x4xi8> +// CHECK: %[[VAL_70:.*]] = vector.shape_cast %[[B_FILTER_2]] : vector<3x1x4xi8> to vector<3x4xi8> +// CHECK: %[[VAL_71:.*]] = arith.muli %[[VAL_68]], %[[VAL_70]] : 
vector<3x4xi8> +// CHECK: %[[VAL_72:.*]] = arith.addi %[[VAL_71]], %[[VAL_55]] : vector<3x4xi8> + +// Write the result back. +// CHECK: %[[VAL_73:.*]] = vector.shape_cast %[[VAL_72]] : vector<3x4xi8> to vector<3x1x4xi8> +// CHECK: %[[VAL_74:.*]] = vector.insert_strided_slice %[[VAL_61]], %[[V_OUTPUT_R]] +// CHECK-SAME: {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<3x1x4xi8> into vector<3x3x4xi8> +// CHECK: %[[VAL_75:.*]] = vector.insert_strided_slice %[[VAL_67]], %[[VAL_74]] +// CHECK-SAME: {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<3x1x4xi8> into vector<3x3x4xi8> +// CHECK: %[[VAL_76:.*]] = vector.insert_strided_slice %[[VAL_73]], %[[VAL_75]] +// CHECK-SAME: {offsets = [0, 2, 0], strides = [1, 1, 1]} : vector<3x1x4xi8> into vector<3x3x4xi8> +// CHECK: %[[VAL_77:.*]] = vector.transfer_write %[[VAL_76]], %[[OUTPUT]][%[[C0_IDX]], %[[C0_IDX]], %[[C0_IDX]]] + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + |