diff options
-rw-r--r-- | mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 93 | ||||
-rw-r--r-- | mlir/test/Dialect/Linalg/invalid.mlir | 32 | ||||
-rw-r--r-- | mlir/test/Dialect/Linalg/roundtrip.mlir | 30 |
3 files changed, 105 insertions, 50 deletions
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 1f50f23..a9ca0fb 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -2744,6 +2744,7 @@ LogicalResult WinogradFilterTransformOp::verify() { int64_t filterH = filterShape[1]; int64_t filterW = filterShape[2]; int64_t r = getR(); + int64_t m = getM(); if (filterH != r && filterH != 1) return emitOpError("expect filter height either equals to r or 1"); @@ -2752,6 +2753,17 @@ LogicalResult WinogradFilterTransformOp::verify() { if (filterH == 1 && filterW == 1) return emitOpError("expect either filter height or width equals to r"); + SmallVector<int64_t> expectedOutputShape; + expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1); + expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1); + expectedOutputShape.push_back(filterShape[3]); + expectedOutputShape.push_back(filterShape[0]); + + auto outputType = cast<ShapedType>(getOutput().getType()); + ArrayRef<int64_t> outputShape = outputType.getShape(); + if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) { + return emitOpError("the output shape is not expected"); + } return success(); } @@ -2764,40 +2776,35 @@ LogicalResult WinogradInputTransformOp::verify() { ArrayRef<int64_t> inputShape = inputType.getShape(); int64_t inputH = inputShape[1]; int64_t inputW = inputShape[2]; - auto outputType = cast<ShapedType>(getOutput().getType()); - ArrayRef<int64_t> outputShape = outputType.getShape(); - int64_t outputH = outputShape[0]; - int64_t outputW = outputShape[1]; - int64_t outputTileH = outputShape[2]; - int64_t outputTileW = outputShape[3]; int m = getM(); int r = getR(); + int64_t tileSize = m + r - 1; bool leftTransform = inputH != 1; bool rightTransform = inputW != 1; - if (!leftTransform && !rightTransform) - return failure(); - - if (leftTransform) { - int64_t tileH = (inputH - (r - 1)) / m; - if (inputH != tileH * m + (r - 1)) - return emitOpError("input height cannot be tiled in full tile size"); - if (tileH != outputTileH) - return emitOpError("number of output height tiles is not correct"); - if (outputH != m + r - 1) - return emitOpError("expect output height equals to tile size"); + SmallVector<int64_t> expectedOutputShape(6, inputH); + if (ShapedType::isDynamic(inputH)) { + expectedOutputShape[0] = tileSize; + expectedOutputShape[2] = -1; + } else { + expectedOutputShape[0] = leftTransform ? tileSize : 1; + expectedOutputShape[2] = leftTransform ? (inputH - (r - 1)) / m : 1; } - - if (rightTransform) { - int64_t tileW = (inputW - (r - 1)) / m; - if (inputW != tileW * m + (r - 1)) - return emitOpError("input width cannot be tiled in full tile size"); - if (tileW != outputTileW) - return emitOpError("number of output width tiles is not correct"); - if (outputW != m + r - 1) - return emitOpError("expect output width equals to tile size"); + if (ShapedType::isDynamic(inputW)) { + expectedOutputShape[1] = tileSize; + expectedOutputShape[3] = -1; + } else { + expectedOutputShape[1] = rightTransform ? tileSize : 1; + expectedOutputShape[3] = rightTransform ? (inputW - (r - 1)) / m : 1; } + expectedOutputShape[4] = inputShape[0]; + expectedOutputShape[5] = inputShape[3]; + auto outputType = cast<ShapedType>(getOutput().getType()); + ArrayRef<int64_t> outputShape = outputType.getShape(); + if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) { + return emitOpError("the output shape is not expected"); + } return success(); } @@ -2812,32 +2819,34 @@ LogicalResult WinogradOutputTransformOp::verify() { int64_t valueW = valueShape[1]; int64_t valueTileH = valueShape[2]; int64_t valueTileW = valueShape[3]; - auto outputType = cast<ShapedType>(getOutput().getType()); - ArrayRef<int64_t> outputShape = outputType.getShape(); - int64_t outputH = outputShape[1]; - int64_t outputW = outputShape[2]; int m = getM(); int r = getR(); bool leftTransform = valueH != 1; bool rightTransform = valueW != 1; - if (!leftTransform && !rightTransform) - return failure(); - - if (leftTransform) { - if (valueH != m + r - 1) + SmallVector<int64_t> expectedOutputShape(4, valueH); + if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) { + expectedOutputShape[1] = -1; + } else { + if (valueH != (leftTransform ? m + r - 1 : 1)) return emitOpError("expect input height equals to input tile size"); - if (outputH != m * valueTileH) - return emitOpError("expect output height aligned to output tile size"); + expectedOutputShape[1] = (leftTransform ? m : 1) * valueTileH; } - - if (rightTransform) { - if (valueW != m + r - 1) + if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) { + expectedOutputShape[2] = -1; + } else { + if (valueW != (rightTransform ? m + r - 1 : 1)) return emitOpError("expect input width equals to input tile size"); - if (outputW != m * valueTileW) - return emitOpError("expect output width aligned to output tile size"); + expectedOutputShape[2] = (rightTransform ? m : 1) * valueTileW; } + expectedOutputShape[0] = valueShape[4]; + expectedOutputShape[3] = valueShape[5]; + auto outputType = cast<ShapedType>(getOutput().getType()); + ArrayRef<int64_t> outputShape = outputType.getShape(); + if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) { + return emitOpError("the output shape is not expected"); + } return success(); } diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index e54060d..b0cf274 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -806,8 +806,16 @@ func.func @winograd_filter_transform(%arg0: tensor<2x1x1x5xf32>, %arg1: tensor<6 // ----- +func.func @winograd_filter_dyn(%arg0: tensor<?x3x3x?xf32>, %arg1: tensor<6x5x?x?xf32>) -> tensor<6x5x?x?xf32> { + // expected-error @+1 {{the output shape is not expected}} + %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<?x3x3x?xf32>) outs(%arg1 : tensor<6x5x?x?xf32>) -> tensor<6x5x?x?xf32> + return %0 : tensor<6x5x?x?xf32> +} + +// ----- + func.func @winograd_input_transform_height(%arg0: tensor<2x13x14x5xf32>, %arg1: tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32> { - // expected-error @+1 {{input height cannot be tiled in full tile size}} + // expected-error @+1 {{the output shape is not expected}} %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x13x14x5xf32>) outs(%arg1 : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32> return %0 : tensor<6x6x3x3x2x5xf32> } @@ -815,7 +823,7 @@ func.func @winograd_input_transform_height(%arg0: tensor<2x13x14x5xf32>, %arg1: // ----- func.func @winograd_input_transform_width(%arg0: tensor<2x14x13x5xf32>, %arg1: tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32> { - // expected-error @+1 {{input width cannot be tiled in full tile size}} + // expected-error @+1 {{the output shape is not expected}} %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x14x13x5xf32>) outs(%arg1 : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32> return %0 : tensor<6x6x3x3x2x5xf32> } @@ -823,7 +831,7 @@ func.func @winograd_input_transform_width(%arg0: tensor<2x14x13x5xf32>, %arg1: t // ----- func.func @winograd_input_transform_output_tileH(%arg0: tensor<2x14x14x5xf32>, %arg1: tensor<6x6x2x3x2x5xf32>) -> tensor<6x6x2x3x2x5xf32> { - // expected-error @+1 {{number of output height tiles is not correct}} + // expected-error @+1 {{the output shape is not expected}} %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x14x14x5xf32>) outs(%arg1 : tensor<6x6x2x3x2x5xf32>) -> tensor<6x6x2x3x2x5xf32> return %0 : tensor<6x6x2x3x2x5xf32> } @@ -831,7 +839,7 @@ func.func @winograd_input_transform_output_tileH(%arg0: tensor<2x14x14x5xf32>, % // ----- func.func @winograd_input_transform_output_tileW(%arg0: tensor<2x14x14x5xf32>, %arg1: tensor<6x6x3x2x2x5xf32>) -> tensor<6x6x3x2x2x5xf32> { - // expected-error @+1 {{number of output width tiles is not correct}} + // expected-error @+1 {{the output shape is not expected}} %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x14x14x5xf32>) outs(%arg1 : tensor<6x6x3x2x2x5xf32>) -> tensor<6x6x3x2x2x5xf32> return %0 : tensor<6x6x3x2x2x5xf32> } @@ -839,7 +847,7 @@ func.func @winograd_input_transform_output_tileW(%arg0: tensor<2x14x14x5xf32>, % // ----- func.func @winograd_input_transform_output_height(%arg0: tensor<2x14x14x5xf32>, %arg1: tensor<5x6x3x3x2x5xf32>) -> tensor<5x6x3x3x2x5xf32> { - // expected-error @+1 {{expect output height equals to tile size}} + // expected-error @+1 {{the output shape is not expected}} %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x14x14x5xf32>) outs(%arg1 : tensor<5x6x3x3x2x5xf32>) -> tensor<5x6x3x3x2x5xf32> return %0 : tensor<5x6x3x3x2x5xf32> } @@ -847,13 +855,21 @@ func.func @winograd_input_transform_output_height(%arg0: tensor<2x14x14x5xf32>, // ----- func.func @winograd_input_transform_output_width(%arg0: tensor<2x14x14x5xf32>, %arg1: tensor<6x5x3x3x2x5xf32>) -> tensor<6x5x3x3x2x5xf32> { - // expected-error @+1 {{expect output width equals to tile size}} + // expected-error @+1 {{the output shape is not expected}} %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x14x14x5xf32>) outs(%arg1 : tensor<6x5x3x3x2x5xf32>) -> tensor<6x5x3x3x2x5xf32> return %0 : tensor<6x5x3x3x2x5xf32> } // ----- +func.func @winograd_input_dyn(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<6x5x?x?x?x?xf32>) -> tensor<6x5x?x?x?x?xf32> { + // expected-error @+1 {{the output shape is not expected}} + %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<?x?x?x?xf32>) outs(%arg1 : tensor<6x5x?x?x?x?xf32>) -> tensor<6x5x?x?x?x?xf32> + return %0 : tensor<6x5x?x?x?x?xf32> +} + +// ----- + func.func @winograd_output_transform_input_height(%arg0: tensor<5x6x3x3x2x2xf32>, %arg1: tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32> { // expected-error @+1 {{expect input height equals to input tile size}} %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<5x6x3x3x2x2xf32>) outs(%arg1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32> @@ -871,7 +887,7 @@ func.func @winograd_output_transform_input_width(%arg0: tensor<6x5x3x3x2x2xf32>, // ----- func.func @winograd_output_transform_output_height(%arg0: tensor<6x6x3x3x2x2xf32>, %arg1: tensor<2x11x12x2xf32>) -> tensor<2x11x12x2xf32> { - // expected-error @+1 {{expect output height aligned to output tile size}} + // expected-error @+1 {{the output shape is not expected}} %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x3x3x2x2xf32>) outs(%arg1 : tensor<2x11x12x2xf32>) -> tensor<2x11x12x2xf32> return %0 : tensor<2x11x12x2xf32> } @@ -879,7 +895,7 @@ func.func @winograd_output_transform_output_height(%arg0: tensor<6x6x3x3x2x2xf32 // ----- func.func @winograd_output_transform_output_width(%arg0: tensor<6x6x3x3x2x2xf32>, %arg1: tensor<2x12x11x2xf32>) -> tensor<2x12x11x2xf32> { - // expected-error @+1 {{expect output width aligned to output tile size}} + // expected-error @+1 {{the output shape is not expected}} %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x3x3x2x2xf32>) outs(%arg1 : tensor<2x12x11x2xf32>) -> tensor<2x12x11x2xf32> return %0 : tensor<2x12x11x2xf32> } diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir index 49fbe13..146e978 100644 --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -634,3 +634,33 @@ func.func @winograd(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg // CHECK: linalg.winograd_filter_transform m(4) r(3) // CHECK: linalg.winograd_input_transform m(4) r(3) // CHECK: linalg.winograd_output_transform m(4) r(3) + +// ----- + +func.func @winograd_filter_dyn(%arg0: tensor<?x3x3x?xf32>, %arg1: tensor<6x6x?x?xf32>) -> tensor<6x6x?x?xf32> { + %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<?x3x3x?xf32>) outs(%arg1 : tensor<6x6x?x?xf32>) -> tensor<6x6x?x?xf32> + return %0 : tensor<6x6x?x?xf32> +} + +// CHECK-LABEL: func @winograd_filter_dyn +// CHECK: linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<?x3x3x?xf32>) outs(%arg1 : tensor<6x6x?x?xf32>) -> tensor<6x6x?x?xf32> + +// ----- + +func.func @winograd_input_dyn(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<6x6x?x?x?x?xf32>) -> tensor<6x6x?x?x?x?xf32> { + %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<?x?x?x?xf32>) outs(%arg1 : tensor<6x6x?x?x?x?xf32>) -> tensor<6x6x?x?x?x?xf32> + return %0 : tensor<6x6x?x?x?x?xf32> +} + +// CHECK-LABEL: func @winograd_input_dyn +// CHECK: linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<?x?x?x?xf32>) outs(%arg1 : tensor<6x6x?x?x?x?xf32>) -> tensor<6x6x?x?x?x?xf32> + +// ----- + +func.func @winograd_output_dyn(%arg0: tensor<6x6x?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> { + %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x?x?x?x?xf32>) outs(%arg1 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> + return %0 : tensor<?x?x?x?xf32> +} + +// CHECK-LABEL: func @winograd_output_dyn +// CHECK: linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x?x?x?x?xf32>) outs(%arg1 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> |