diff options
| -rw-r--r-- | mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 11 | ||||
| -rw-r--r-- | mlir/test/Dialect/Linalg/invalid.mlir | 75 |
2 files changed, 83 insertions, 3 deletions
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 57d1266..0754bd9 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1356,8 +1356,12 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) { return failure(); if (payloadOpName.has_value()) { - addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs, - ArrayRef(result.operands).drop_back()); + if (!result.operands.empty()) + addBodyWithPayloadOp(parser, result, payloadOpName.value(), + payloadOpAttrs, + ArrayRef(result.operands).drop_back()); + else + result.addRegion(); } else { SmallVector<OpAsmParser::Argument> regionArgs; if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren, @@ -1739,7 +1743,8 @@ static void buildIdentityRegion(OpBuilder &builder, Location loc, ValueRange outputs) { buildGenericRegion(builder, loc, region, inputs, outputs, [](OpBuilder &b, Location loc, ValueRange args) { - b.create<linalg::YieldOp>(loc, args[0]); + if (!args.empty()) + b.create<linalg::YieldOp>(loc, args[0]); }); } diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index 44c81c3..213ef6c 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -455,6 +455,32 @@ func.func @map_input_output_shape_mismatch( // ----- +func.func @map_no_operands1() { + // expected-error @+1 {{'linalg.map' op expected 1 or more operands, but found 0}} + linalg.map { arith.addf } +} + +// ----- + +func.func @map_no_operands2() { + // expected-error @+1 {{'linalg.map' op expected 1 or more operands, but found 0}} + "linalg.map"() ({ + ^bb0: + }) : () -> () +} + +// ----- + +func.func @map_no_operands3( + %lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>) + -> tensor<64xf32> { + // expected-error @+1 {{cannot name an operation with no results}} + %add = linalg.map { arith.addf } + func.return %add : tensor<64xf32> +} + +// ----- + func.func @reduce_input_vs_init_dimension_mismatch( %input: tensor<16x32x64xf32>, %init: tensor<16x64xf32>) -> tensor<16x64xf32> { @@ -676,6 +702,30 @@ func.func @transpose_input_init_rank_mismatch(%input: tensor<16x32xf32>, // ----- +func.func @transpose_no_operands1() { + // expected-error @+1 {{'linalg.transpose' op expected 2 operands, but found 0}} + linalg.transpose permutation = [1, 0, 2] +} + +// ----- + +func.func @transpose_no_operands2() { + // expected-error @+1 {{'linalg.transpose' op expected 2 operands, but found 0}} + "linalg.transpose"() <{permutation = array<i64: 1, 0, 2>}> ({ + ^bb0: + }) : () -> () +} + +// ----- + +func.func @transpose_no_operands3() -> tensor<32x64x16xf32> { + // expected-error @+1 {{cannot name an operation with no results}} + %transpose = linalg.transpose permutation = [1, 0, 2] + func.return %transpose : tensor<32x64x16xf32> +} + +// ----- + func.func @broadcast_input_dims_rank_mismatch( %input: tensor<4x16xf32>, %init: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> { @@ -728,6 +778,31 @@ func.func @broadcast_size_1_extension_not_supported( // ----- +func.func @broadcast_no_operands1() { + // expected-error @+1 {{'linalg.broadcast' op expected 2 operands, but found 0}} + linalg.broadcast dimensions = [1] +} + +// ----- + +func.func @broadcast_no_operands2() { + // expected-error @+1 {{'linalg.broadcast' op expected 2 operands, but found 0}} + "linalg.broadcast"() <{dimensions = array<i64: 1>}> ({ + ^bb0: + }) : () -> () +} + +// ----- + +func.func @broadcast_no_operands3() + -> tensor<4x?x16xf32> { + // expected-error @+1 {{cannot name an operation with no results}} + %broadcast = linalg.broadcast dimensions = [1] + func.return %broadcast : tensor<32x64x16xf32> +} + +// ----- + func.func @missing_iterator_types() { // expected-error @below {{expected "iterator_types" array attribute}} linalg.generic {} ins() outs() |
