aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthias Springer <mspringer@nvidia.com>2025-02-19 11:15:44 +0100
committerMatthias Springer <mspringer@nvidia.com>2025-02-19 11:15:44 +0100
commit0cf2efbc18adf32edc19193a98cdec1c1b401c5f (patch)
tree7828f5cb9246d49ff0d49a7e9dea11bcbcc6d888
parent6812fc02fbb81d679f95d5c3e15768ae11e1bad8 (diff)
downloadllvm-users/matthias-springer/fix_scf_for_parser.zip
llvm-users/matthias-springer/fix_scf_for_parser.tar.gz
llvm-users/matthias-springer/fix_scf_for_parser.tar.bz2
-rw-r--r--mlir/lib/Dialect/SCF/IR/SCF.cpp21
1 files changed, 13 insertions, 8 deletions
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 4481417..1f70ad5 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -499,8 +499,20 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
else if (parser.parseType(type))
return failure();
- // Resolve input operands.
+ // Set block argument types, so that they are known when parsing the region.
regionArgs.front().type = type;
+ for (auto [iterArg, type] :
+ llvm::zip(llvm::drop_begin(regionArgs), result.types))
+ iterArg.type = type;
+
+ // Parse the body region.
+ Region *body = result.addRegion();
+ if (parser.parseRegion(*body, regionArgs))
+ return failure();
+ ForOp::ensureTerminator(*body, builder, result.location);
+
+ // Resolve input operands. This should be done after parsing the region to
+ // catch invalid IR where operands were defined inside of the region.
if (parser.resolveOperand(lb, type, result.operands) ||
parser.resolveOperand(ub, type, result.operands) ||
parser.resolveOperand(step, type, result.operands))
@@ -516,13 +528,6 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
}
}
- // Parse the body region.
- Region *body = result.addRegion();
- if (parser.parseRegion(*body, regionArgs))
- return failure();
-
- ForOp::ensureTerminator(*body, builder, result.location);
-
// Parse the optional attribute list.
if (parser.parseOptionalAttrDict(result.attributes))
return failure();