 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 93
 mlir/test/Dialect/Linalg/invalid.mlir    | 32
 mlir/test/Dialect/Linalg/roundtrip.mlir  | 30
 3 files changed, 105 insertions(+), 50 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 1f50f23..a9ca0fb 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2744,6 +2744,7 @@ LogicalResult WinogradFilterTransformOp::verify() {
int64_t filterH = filterShape[1];
int64_t filterW = filterShape[2];
int64_t r = getR();
+ int64_t m = getM();
if (filterH != r && filterH != 1)
return emitOpError("expect filter height either equals to r or 1");
@@ -2752,6 +2753,17 @@ LogicalResult WinogradFilterTransformOp::verify() {
if (filterH == 1 && filterW == 1)
return emitOpError("expect either filter height or width equals to r");
+ SmallVector<int64_t> expectedOutputShape;
+ expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
+ expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
+ expectedOutputShape.push_back(filterShape[3]);
+ expectedOutputShape.push_back(filterShape[0]);
+
+ auto outputType = cast<ShapedType>(getOutput().getType());
+ ArrayRef<int64_t> outputShape = outputType.getShape();
+ if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
+ return emitOpError("the output shape is not expected");
+ }
return success();
}
@@ -2764,40 +2776,35 @@ LogicalResult WinogradInputTransformOp::verify() {
ArrayRef<int64_t> inputShape = inputType.getShape();
int64_t inputH = inputShape[1];
int64_t inputW = inputShape[2];
- auto outputType = cast<ShapedType>(getOutput().getType());
- ArrayRef<int64_t> outputShape = outputType.getShape();
- int64_t outputH = outputShape[0];
- int64_t outputW = outputShape[1];
- int64_t outputTileH = outputShape[2];
- int64_t outputTileW = outputShape[3];
int m = getM();
int r = getR();
+ int64_t tileSize = m + r - 1;
bool leftTransform = inputH != 1;
bool rightTransform = inputW != 1;
- if (!leftTransform && !rightTransform)
- return failure();
-
- if (leftTransform) {
- int64_t tileH = (inputH - (r - 1)) / m;
- if (inputH != tileH * m + (r - 1))
- return emitOpError("input height cannot be tiled in full tile size");
- if (tileH != outputTileH)
- return emitOpError("number of output height tiles is not correct");
- if (outputH != m + r - 1)
- return emitOpError("expect output height equals to tile size");
+ SmallVector<int64_t> expectedOutputShape(6, inputH);
+ if (ShapedType::isDynamic(inputH)) {
+ expectedOutputShape[0] = tileSize;
+ expectedOutputShape[2] = -1;
+ } else {
+ expectedOutputShape[0] = leftTransform ? tileSize : 1;
+ expectedOutputShape[2] = leftTransform ? (inputH - (r - 1)) / m : 1;
}
-
- if (rightTransform) {
- int64_t tileW = (inputW - (r - 1)) / m;
- if (inputW != tileW * m + (r - 1))
- return emitOpError("input width cannot be tiled in full tile size");
- if (tileW != outputTileW)
- return emitOpError("number of output width tiles is not correct");
- if (outputW != m + r - 1)
- return emitOpError("expect output width equals to tile size");
+ if (ShapedType::isDynamic(inputW)) {
+ expectedOutputShape[1] = tileSize;
+ expectedOutputShape[3] = -1;
+ } else {
+ expectedOutputShape[1] = rightTransform ? tileSize : 1;
+ expectedOutputShape[3] = rightTransform ? (inputW - (r - 1)) / m : 1;
}
+ expectedOutputShape[4] = inputShape[0];
+ expectedOutputShape[5] = inputShape[3];
+ auto outputType = cast<ShapedType>(getOutput().getType());
+ ArrayRef<int64_t> outputShape = outputType.getShape();
+ if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
+ return emitOpError("the output shape is not expected");
+ }
return success();
}
@@ -2812,32 +2819,34 @@ LogicalResult WinogradOutputTransformOp::verify() {
int64_t valueW = valueShape[1];
int64_t valueTileH = valueShape[2];
int64_t valueTileW = valueShape[3];
- auto outputType = cast<ShapedType>(getOutput().getType());
- ArrayRef<int64_t> outputShape = outputType.getShape();
- int64_t outputH = outputShape[1];
- int64_t outputW = outputShape[2];
int m = getM();
int r = getR();
bool leftTransform = valueH != 1;
bool rightTransform = valueW != 1;
- if (!leftTransform && !rightTransform)
- return failure();
-
- if (leftTransform) {
- if (valueH != m + r - 1)
+ SmallVector<int64_t> expectedOutputShape(4, valueH);
+ if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
+ expectedOutputShape[1] = -1;
+ } else {
+ if (valueH != (leftTransform ? m + r - 1 : 1))
return emitOpError("expect input height equals to input tile size");
- if (outputH != m * valueTileH)
- return emitOpError("expect output height aligned to output tile size");
+ expectedOutputShape[1] = (leftTransform ? m : 1) * valueTileH;
}
-
- if (rightTransform) {
- if (valueW != m + r - 1)
+ if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
+ expectedOutputShape[2] = -1;
+ } else {
+ if (valueW != (rightTransform ? m + r - 1 : 1))
return emitOpError("expect input width equals to input tile size");
- if (outputW != m * valueTileW)
- return emitOpError("expect output width aligned to output tile size");
+ expectedOutputShape[2] = (rightTransform ? m : 1) * valueTileW;
}
+ expectedOutputShape[0] = valueShape[4];
+ expectedOutputShape[3] = valueShape[5];
+ auto outputType = cast<ShapedType>(getOutput().getType());
+ ArrayRef<int64_t> outputShape = outputType.getShape();
+ if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
+ return emitOpError("the output shape is not expected");
+ }
return success();
}
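
A minimal standalone sketch of the expected-shape arithmetic the new
WinogradFilterTransformOp verifier performs (the helper name and the use of
plain std::vector are illustrative only, not part of the patch):

  #include <cstdint>
  #include <vector>

  // For a filter tensor<F x H x W x C> with tile size m and kernel size r,
  // the transformed filter is expected to be
  // tensor<(m+r-1 or 1) x (m+r-1 or 1) x C x F>.
  std::vector<int64_t>
  expectedFilterTransformShape(int64_t m, int64_t r,
                               const std::vector<int64_t> &filterShape) {
    int64_t filterH = filterShape[1];
    int64_t filterW = filterShape[2];
    int64_t alpha = m + r - 1;          // transformed tile size
    return {filterH == r ? alpha : 1,   // transformed height
            filterW == r ? alpha : 1,   // transformed width
            filterShape[3],             // channels C
            filterShape[0]};            // filter count F
  }

  // Example: m = 4, r = 3, filter tensor<2x3x3x5xf32>
  //   -> expected output tensor<6x6x5x2xf32>.
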
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index e54060d..b0cf274 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -806,8 +806,16 @@ func.func @winograd_filter_transform(%arg0: tensor<2x1x1x5xf32>, %arg1: tensor<6
// -----
+func.func @winograd_filter_dyn(%arg0: tensor<?x3x3x?xf32>, %arg1: tensor<6x5x?x?xf32>) -> tensor<6x5x?x?xf32> {
+ // expected-error @+1 {{the output shape is not expected}}
+ %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<?x3x3x?xf32>) outs(%arg1 : tensor<6x5x?x?xf32>) -> tensor<6x5x?x?xf32>
+ return %0 : tensor<6x5x?x?xf32>
+}
+
+// -----
+
func.func @winograd_input_transform_height(%arg0: tensor<2x13x14x5xf32>, %arg1: tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32> {
- // expected-error @+1 {{input height cannot be tiled in full tile size}}
+ // expected-error @+1 {{the output shape is not expected}}
%0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x13x14x5xf32>) outs(%arg1 : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
return %0 : tensor<6x6x3x3x2x5xf32>
}
@@ -815,7 +823,7 @@ func.func @winograd_input_transform_height(%arg0: tensor<2x13x14x5xf32>, %arg1:
// -----
func.func @winograd_input_transform_width(%arg0: tensor<2x14x13x5xf32>, %arg1: tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32> {
- // expected-error @+1 {{input width cannot be tiled in full tile size}}
+ // expected-error @+1 {{the output shape is not expected}}
%0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x14x13x5xf32>) outs(%arg1 : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
return %0 : tensor<6x6x3x3x2x5xf32>
}
@@ -823,7 +831,7 @@ func.func @winograd_input_transform_width(%arg0: tensor<2x14x13x5xf32>, %arg1: t
// -----
func.func @winograd_input_transform_output_tileH(%arg0: tensor<2x14x14x5xf32>, %arg1: tensor<6x6x2x3x2x5xf32>) -> tensor<6x6x2x3x2x5xf32> {
- // expected-error @+1 {{number of output height tiles is not correct}}
+ // expected-error @+1 {{the output shape is not expected}}
%0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x14x14x5xf32>) outs(%arg1 : tensor<6x6x2x3x2x5xf32>) -> tensor<6x6x2x3x2x5xf32>
return %0 : tensor<6x6x2x3x2x5xf32>
}
@@ -831,7 +839,7 @@ func.func @winograd_input_transform_output_tileH(%arg0: tensor<2x14x14x5xf32>, %
// -----
func.func @winograd_input_transform_output_tileW(%arg0: tensor<2x14x14x5xf32>, %arg1: tensor<6x6x3x2x2x5xf32>) -> tensor<6x6x3x2x2x5xf32> {
- // expected-error @+1 {{number of output width tiles is not correct}}
+ // expected-error @+1 {{the output shape is not expected}}
%0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x14x14x5xf32>) outs(%arg1 : tensor<6x6x3x2x2x5xf32>) -> tensor<6x6x3x2x2x5xf32>
return %0 : tensor<6x6x3x2x2x5xf32>
}
@@ -839,7 +847,7 @@ func.func @winograd_input_transform_output_tileW(%arg0: tensor<2x14x14x5xf32>, %
// -----
func.func @winograd_input_transform_output_height(%arg0: tensor<2x14x14x5xf32>, %arg1: tensor<5x6x3x3x2x5xf32>) -> tensor<5x6x3x3x2x5xf32> {
- // expected-error @+1 {{expect output height equals to tile size}}
+ // expected-error @+1 {{the output shape is not expected}}
%0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x14x14x5xf32>) outs(%arg1 : tensor<5x6x3x3x2x5xf32>) -> tensor<5x6x3x3x2x5xf32>
return %0 : tensor<5x6x3x3x2x5xf32>
}
@@ -847,13 +855,21 @@ func.func @winograd_input_transform_output_height(%arg0: tensor<2x14x14x5xf32>,
// -----
func.func @winograd_input_transform_output_width(%arg0: tensor<2x14x14x5xf32>, %arg1: tensor<6x5x3x3x2x5xf32>) -> tensor<6x5x3x3x2x5xf32> {
- // expected-error @+1 {{expect output width equals to tile size}}
+ // expected-error @+1 {{the output shape is not expected}}
%0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x14x14x5xf32>) outs(%arg1 : tensor<6x5x3x3x2x5xf32>) -> tensor<6x5x3x3x2x5xf32>
return %0 : tensor<6x5x3x3x2x5xf32>
}
// -----
+func.func @winograd_input_dyn(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<6x5x?x?x?x?xf32>) -> tensor<6x5x?x?x?x?xf32> {
+ // expected-error @+1 {{the output shape is not expected}}
+ %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<?x?x?x?xf32>) outs(%arg1 : tensor<6x5x?x?x?x?xf32>) -> tensor<6x5x?x?x?x?xf32>
+ return %0 : tensor<6x5x?x?x?x?xf32>
+}
+
+// -----
+
func.func @winograd_output_transform_input_height(%arg0: tensor<5x6x3x3x2x2xf32>, %arg1: tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32> {
// expected-error @+1 {{expect input height equals to input tile size}}
%0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<5x6x3x3x2x2xf32>) outs(%arg1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
@@ -871,7 +887,7 @@ func.func @winograd_output_transform_input_width(%arg0: tensor<6x5x3x3x2x2xf32>,
// -----
func.func @winograd_output_transform_output_height(%arg0: tensor<6x6x3x3x2x2xf32>, %arg1: tensor<2x11x12x2xf32>) -> tensor<2x11x12x2xf32> {
- // expected-error @+1 {{expect output height aligned to output tile size}}
+ // expected-error @+1 {{the output shape is not expected}}
%0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x3x3x2x2xf32>) outs(%arg1 : tensor<2x11x12x2xf32>) -> tensor<2x11x12x2xf32>
return %0 : tensor<2x11x12x2xf32>
}
@@ -879,7 +895,7 @@ func.func @winograd_output_transform_output_height(%arg0: tensor<6x6x3x3x2x2xf32
// -----
func.func @winograd_output_transform_output_width(%arg0: tensor<6x6x3x3x2x2xf32>, %arg1: tensor<2x12x11x2xf32>) -> tensor<2x12x11x2xf32> {
- // expected-error @+1 {{expect output width aligned to output tile size}}
+ // expected-error @+1 {{the output shape is not expected}}
%0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x3x3x2x2xf32>) outs(%arg1 : tensor<2x12x11x2xf32>) -> tensor<2x12x11x2xf32>
return %0 : tensor<2x12x11x2xf32>
}
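
The winograd_input_transform diagnostics above are all retargeted to the same
check: the verifier now derives the full expected 6-D output shape and compares
it with verifyCompatibleShape. A minimal sketch of that arithmetic
(illustrative helper, not the patch itself; -1 stands in for a dynamic
dimension):

  #include <cstdint>
  #include <vector>

  // Input is tensor<N x H x W x C>; the expected transformed shape is
  // tensor<alphaH x alphaW x tileH x tileW x N x C> with alpha = m + r - 1.
  std::vector<int64_t>
  expectedInputTransformShape(int64_t m, int64_t r,
                              const std::vector<int64_t> &inputShape) {
    int64_t inputH = inputShape[1], inputW = inputShape[2];
    int64_t tileSize = m + r - 1;
    std::vector<int64_t> shape(6);
    if (inputH < 0) {                   // dynamic height
      shape[0] = tileSize;
      shape[2] = -1;
    } else {
      bool leftTransform = inputH != 1;
      shape[0] = leftTransform ? tileSize : 1;
      shape[2] = leftTransform ? (inputH - (r - 1)) / m : 1;
    }
    if (inputW < 0) {                   // dynamic width
      shape[1] = tileSize;
      shape[3] = -1;
    } else {
      bool rightTransform = inputW != 1;
      shape[1] = rightTransform ? tileSize : 1;
      shape[3] = rightTransform ? (inputW - (r - 1)) / m : 1;
    }
    shape[4] = inputShape[0];           // batch N
    shape[5] = inputShape[3];           // channels C
    return shape;
  }

  // Example: m = 4, r = 3, input tensor<2x13x14x5xf32>:
  //   tileH = (13 - 2) / 4 = 2, tileW = (14 - 2) / 4 = 3, so the expected
  //   output is tensor<6x6x2x3x2x5xf32>; the tensor<6x6x3x3x2x5xf32> in
  //   @winograd_input_transform_height is therefore rejected with
  //   "the output shape is not expected".
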
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 49fbe13..146e978 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -634,3 +634,33 @@ func.func @winograd(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg
// CHECK: linalg.winograd_filter_transform m(4) r(3)
// CHECK: linalg.winograd_input_transform m(4) r(3)
// CHECK: linalg.winograd_output_transform m(4) r(3)
+
+// -----
+
+func.func @winograd_filter_dyn(%arg0: tensor<?x3x3x?xf32>, %arg1: tensor<6x6x?x?xf32>) -> tensor<6x6x?x?xf32> {
+ %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<?x3x3x?xf32>) outs(%arg1 : tensor<6x6x?x?xf32>) -> tensor<6x6x?x?xf32>
+ return %0 : tensor<6x6x?x?xf32>
+}
+
+// CHECK-LABEL: func @winograd_filter_dyn
+// CHECK: linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<?x3x3x?xf32>) outs(%arg1 : tensor<6x6x?x?xf32>) -> tensor<6x6x?x?xf32>
+
+// -----
+
+func.func @winograd_input_dyn(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<6x6x?x?x?x?xf32>) -> tensor<6x6x?x?x?x?xf32> {
+ %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<?x?x?x?xf32>) outs(%arg1 : tensor<6x6x?x?x?x?xf32>) -> tensor<6x6x?x?x?x?xf32>
+ return %0 : tensor<6x6x?x?x?x?xf32>
+}
+
+// CHECK-LABEL: func @winograd_input_dyn
+// CHECK: linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<?x?x?x?xf32>) outs(%arg1 : tensor<6x6x?x?x?x?xf32>) -> tensor<6x6x?x?x?x?xf32>
+
+// -----
+
+func.func @winograd_output_dyn(%arg0: tensor<6x6x?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x?x?x?x?xf32>) outs(%arg1 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
+}
+
+// CHECK-LABEL: func @winograd_output_dyn
+// CHECK: linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x?x?x?x?xf32>) outs(%arg1 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
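
For completeness, the analogous expected-shape arithmetic for
WinogradOutputTransformOp, again as an illustrative sketch rather than the
patch itself (-1 marks a dynamic dimension; the unchanged "expect input
height/width equals to input tile size" checks are omitted here):

  #include <cstdint>
  #include <vector>

  // Value is tensor<alphaH x alphaW x tileH x tileW x N x F>; the expected
  // output is tensor<N x (m * tileH or 1) x (m * tileW or 1) x F>.
  std::vector<int64_t>
  expectedOutputTransformShape(int64_t m, int64_t r,
                               const std::vector<int64_t> &valueShape) {
    int64_t valueH = valueShape[0], valueW = valueShape[1];
    int64_t valueTileH = valueShape[2], valueTileW = valueShape[3];
    std::vector<int64_t> shape(4);
    shape[0] = valueShape[4];           // batch N
    shape[3] = valueShape[5];           // filter count F
    if (valueH < 0 || valueTileH < 0)
      shape[1] = -1;
    else
      shape[1] = (valueH != 1 ? m : 1) * valueTileH;
    if (valueW < 0 || valueTileW < 0)
      shape[2] = -1;
    else
      shape[2] = (valueW != 1 ? m : 1) * valueTileW;
    return shape;
  }

  // Example: m = 4, r = 3, value tensor<6x6x3x3x2x2xf32>
  //   -> expected output tensor<2x12x12x2xf32>, so the tensor<2x11x12x2xf32>
  //   in @winograd_output_transform_output_height is rejected with
  //   "the output shape is not expected".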