aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp11
-rw-r--r--mlir/test/Dialect/Linalg/invalid.mlir75
2 files changed, 83 insertions, 3 deletions
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 57d1266..0754bd9 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1356,8 +1356,12 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
return failure();
if (payloadOpName.has_value()) {
- addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
- ArrayRef(result.operands).drop_back());
+ if (!result.operands.empty())
+ addBodyWithPayloadOp(parser, result, payloadOpName.value(),
+ payloadOpAttrs,
+ ArrayRef(result.operands).drop_back());
+ else
+ result.addRegion();
} else {
SmallVector<OpAsmParser::Argument> regionArgs;
if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
@@ -1739,7 +1743,8 @@ static void buildIdentityRegion(OpBuilder &builder, Location loc,
ValueRange outputs) {
buildGenericRegion(builder, loc, region, inputs, outputs,
[](OpBuilder &b, Location loc, ValueRange args) {
- b.create<linalg::YieldOp>(loc, args[0]);
+ if (!args.empty())
+ b.create<linalg::YieldOp>(loc, args[0]);
});
}
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 44c81c3..213ef6c 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -455,6 +455,32 @@ func.func @map_input_output_shape_mismatch(
// -----
+func.func @map_no_operands1() {
+ // expected-error @+1 {{'linalg.map' op expected 1 or more operands, but found 0}}
+ linalg.map { arith.addf }
+}
+
+// -----
+
+func.func @map_no_operands2() {
+ // expected-error @+1 {{'linalg.map' op expected 1 or more operands, but found 0}}
+ "linalg.map"() ({
+ ^bb0:
+ }) : () -> ()
+}
+
+// -----
+
+func.func @map_no_operands3(
+ %lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>)
+ -> tensor<64xf32> {
+ // expected-error @+1 {{cannot name an operation with no results}}
+ %add = linalg.map { arith.addf }
+ func.return %add : tensor<64xf32>
+}
+
+// -----
+
func.func @reduce_input_vs_init_dimension_mismatch(
%input: tensor<16x32x64xf32>,
%init: tensor<16x64xf32>) -> tensor<16x64xf32> {
@@ -676,6 +702,30 @@ func.func @transpose_input_init_rank_mismatch(%input: tensor<16x32xf32>,
// -----
+func.func @transpose_no_operands1() {
+ // expected-error @+1 {{'linalg.transpose' op expected 2 operands, but found 0}}
+ linalg.transpose permutation = [1, 0, 2]
+}
+
+// -----
+
+func.func @transpose_no_operands2() {
+ // expected-error @+1 {{'linalg.transpose' op expected 2 operands, but found 0}}
+ "linalg.transpose"() <{permutation = array<i64: 1, 0, 2>}> ({
+ ^bb0:
+ }) : () -> ()
+}
+
+// -----
+
+func.func @transpose_no_operands3() -> tensor<32x64x16xf32> {
+ // expected-error @+1 {{cannot name an operation with no results}}
+ %transpose = linalg.transpose permutation = [1, 0, 2]
+ func.return %transpose : tensor<32x64x16xf32>
+}
+
+// -----
+
func.func @broadcast_input_dims_rank_mismatch(
%input: tensor<4x16xf32>, %init: tensor<4x8x16xf32>)
-> tensor<4x8x16xf32> {
@@ -728,6 +778,31 @@ func.func @broadcast_size_1_extension_not_supported(
// -----
+func.func @broadcast_no_operands1() {
+ // expected-error @+1 {{'linalg.broadcast' op expected 2 operands, but found 0}}
+ linalg.broadcast dimensions = [1]
+}
+
+// -----
+
+func.func @broadcast_no_operands2() {
+ // expected-error @+1 {{'linalg.broadcast' op expected 2 operands, but found 0}}
+ "linalg.broadcast"() <{dimensions = array<i64: 1>}> ({
+ ^bb0:
+ }) : () -> ()
+}
+
+// -----
+
+func.func @broadcast_no_operands3()
+ -> tensor<4x?x16xf32> {
+ // expected-error @+1 {{cannot name an operation with no results}}
+ %broadcast = linalg.broadcast dimensions = [1]
+ func.return %broadcast : tensor<32x64x16xf32>
+}
+
+// -----
+
func.func @missing_iterator_types() {
// expected-error @below {{expected "iterator_types" array attribute}}
linalg.generic {} ins() outs()