4 files changed, 72 insertions, 23 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index abd0243..e6a1b07 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -754,20 +754,19 @@ private:
 /// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
 /// convolution ops.
+template <typename Conv2DOp, typename Conv1DOp>
 struct DownscaleSizeOneWindowed2DConvolution final
-    : public OpRewritePattern<Conv2DNhwcHwcfOp> {
+    : public OpRewritePattern<Conv2DOp> {
   DownscaleSizeOneWindowed2DConvolution(
       MLIRContext *context,
       LinalgTransformationFilter f = LinalgTransformationFilter(),
       PatternBenefit benefit = 1)
-      : OpRewritePattern<Conv2DNhwcHwcfOp>(context, benefit),
-        filter(std::move(f)) {}
+      : OpRewritePattern<Conv2DOp>(context, benefit), filter(std::move(f)) {}
 
-  FailureOr<Conv1DNwcWcfOp>
-  returningMatchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
-                           PatternRewriter &rewriter) const;
+  FailureOr<Conv1DOp> returningMatchAndRewrite(Conv2DOp convOp,
+                                               PatternRewriter &rewriter) const;
 
-  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
+  LogicalResult matchAndRewrite(Conv2DOp convOp,
                                 PatternRewriter &rewriter) const override {
     return returningMatchAndRewrite(convOp, rewriter);
   }
 
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 80e4d31..f4241f4 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -76,10 +76,18 @@ DiagnosedSilenceableFailure
 transform::DecomposeOp::applyToOne(linalg::LinalgOp target,
                                    SmallVectorImpl<Operation *> &results,
                                    transform::TransformState &state) {
-  FailureOr<LinalgOp> windowed =
-      tryApply<DownscaleSizeOneWindowed2DConvolution>(target);
-  if (succeeded(windowed)) {
-    results.push_back(*windowed);
+  FailureOr<LinalgOp> windowedNhwc =
+      tryApply<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp,
+                                                     Conv1DNwcWcfOp>>(target);
+  if (succeeded(windowedNhwc)) {
+    results.push_back(*windowedNhwc);
+    return DiagnosedSilenceableFailure(success());
+  }
+  FailureOr<LinalgOp> windowedNchw =
+      tryApply<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp,
+                                                     Conv1DNcwFcwOp>>(target);
+  if (succeeded(windowedNchw)) {
+    results.push_back(*windowedNchw);
     return DiagnosedSilenceableFailure(success());
   }
   FailureOr<LinalgOp> depthwise =
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 2fcbe68..1c4ceaa 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -828,9 +828,9 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
 // and then turning back to named ops. But for now it's fine to have a few
 // patterns matching special ops to get started.
 
-FailureOr<Conv1DNwcWcfOp>
-DownscaleSizeOneWindowed2DConvolution::returningMatchAndRewrite(
-    linalg::Conv2DNhwcHwcfOp convOp, PatternRewriter &rewriter) const {
+template <typename Conv2DOp, typename Conv1DOp>
+FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
+    returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const {
   if (failed(filter.checkAndNotify(rewriter, convOp)))
     return failure();
   if (convOp.hasBufferSemantics())
@@ -847,10 +847,30 @@ DownscaleSizeOneWindowed2DConvolution::returningMatchAndRewrite(
   auto kernelShape = kernelType.getShape();
   auto outputShape = outputType.getShape();
 
+  // Get domain indices based on conv2D layout.
+  int khIndex, kwIndex, ohIndex, owIndex;
+
+  TypeSwitch<Operation *>(convOp)
+      .Case([&](linalg::Conv2DNhwcHwcfOp op) {
+        khIndex = 0;
+        kwIndex = 1;
+        ohIndex = 1;
+        owIndex = 2;
+      })
+      .Case([&](linalg::Conv2DNchwFchwOp op) {
+        khIndex = 2;
+        kwIndex = 3;
+        ohIndex = 2;
+        owIndex = 3;
+      })
+      .Default([&](Operation *op) {
+        llvm_unreachable("unexpected conv2d operation.");
+      });
+
   // Only handle the case where at least one of the window dimensions is
   // of size 1. Other cases can rely on tiling to reduce to such cases.
-  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
-  int64_t ohSize = outputShape[1], owSize = outputShape[2];
+  int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
+  int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
   bool removeH = (khSize == 1 && ohSize == 1);
   bool removeW = (kwSize == 1 && owSize == 1);
   if (!removeH && !removeW)
@@ -860,11 +880,11 @@ DownscaleSizeOneWindowed2DConvolution::returningMatchAndRewrite(
   // dimension.
   using RTTBuilder = RankedTensorType::Builder;
   RankedTensorType newInputType =
-      RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
+      RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex));
   RankedTensorType newKernelType =
-      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
+      RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex));
   RankedTensorType newOutputType =
-      RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
+      RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex));
 
   // Rank-reduce operands.
   Location loc = convOp.getLoc();
@@ -877,16 +897,17 @@ DownscaleSizeOneWindowed2DConvolution::returningMatchAndRewrite(
 
   // Rank-reduce strides and dilations too.
   // TODO: dropDim 1-liner helper.
-  auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>());
+  auto strides =
+      llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
   strides.erase(strides.begin() + (removeH ? 0 : 1));
   auto stridesAttr = rewriter.getI64VectorAttr(strides);
 
   auto dilations =
-      llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>());
+      llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
   dilations.erase(dilations.begin() + (removeH ? 0 : 1));
   auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
 
-  auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
+  auto conv1DOp = rewriter.create<Conv1DOp>(
       loc, newOutputType, ValueRange{newInput, newKernel},
       ValueRange{newOutput}, stridesAttr, dilationsAttr);
 
@@ -973,7 +994,10 @@ DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
 void linalg::populateDecomposeConvolutionPatterns(
     RewritePatternSet &patterns, const LinalgTransformationFilter &filter,
     PatternBenefit benefit) {
-  patterns.add<DownscaleSizeOneWindowed2DConvolution,
+  patterns.add<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp,
+                                                     Conv1DNwcWcfOp>,
+               DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp,
+                                                     Conv1DNcwFcwOp>,
                DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(),
                                                   filter, benefit);
 }
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
index 988d706..7348289 100644
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -18,6 +18,24 @@ func.func @conv_2d_nhwc_hwcf(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?x?x
   return %0 : tensor<?x1x?x?xf32>
 }
 
+// CHECK-LABEL: @conv_2d_nchw_fchw
+// CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
+// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
+// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x?x1x?xf32>)
+func.func @conv_2d_nchw_fchw(%input: tensor<?x?x1x?xf32>, %filter: tensor<?x?x1x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.conv_1d_ncw_fcw
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  %0 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : tensor<2xi64>,
+                                 strides = dense<1> : tensor<2xi64>}
+         ins (%input, %filter: tensor<?x?x1x?xf32>, tensor<?x?x1x?xf32>)
+         outs (%init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<?x?x1x?xf32>
+}
+
 // CHECK-LABEL: @depthwise_conv_2d_nhwc_hwc
 // CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x113x96xf32>
 // CHECK-SAME: %[[ARG1:.+]]: tensor<1x3x96xf32>
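Usage note (not part of the commit): the sketch below shows how the two new instantiations of the templated pattern could be applied directly with the greedy rewrite driver, mirroring what populateDecomposeConvolutionPatterns registers after this change. The wrapper function decomposeSizeOneConvs is hypothetical; the pattern, op, and driver names come from the diff above and the standard MLIR APIs.

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical helper: applies both layout instantiations of the templated
// pattern to `op` (e.g. a func.func), rewriting 2-D convolutions with size-1
// window dimensions into their 1-D counterparts.
static LogicalResult decomposeSizeOneConvs(Operation *op) {
  MLIRContext *ctx = op->getContext();
  RewritePatternSet patterns(ctx);
  // Same two instantiations that populateDecomposeConvolutionPatterns now
  // adds: NHWC/HWCF -> NWC/WCF and NCHW/FCHW -> NCW/FCW.
  patterns.add<linalg::DownscaleSizeOneWindowed2DConvolution<
                   linalg::Conv2DNhwcHwcfOp, linalg::Conv1DNwcWcfOp>,
               linalg::DownscaleSizeOneWindowed2DConvolution<
                   linalg::Conv2DNchwFchwOp, linalg::Conv1DNcwFcwOp>>(ctx);
  return applyPatternsAndFoldGreedily(op, std::move(patterns));
}

As in the NHWC path, the NCHW instantiation only fires when a kernel window dimension and the matching output dimension are both statically 1 (the removeH/removeW conditions); other shapes are expected to be tiled down to that form first.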