diff options
Diffstat (limited to 'clang/lib/CIR/CodeGen/CIRGenFunction.h')
| -rw-r--r-- | clang/lib/CIR/CodeGen/CIRGenFunction.h | 81 |
1 files changed, 53 insertions, 28 deletions
diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.h b/clang/lib/CIR/CodeGen/CIRGenFunction.h index c3fcd1a6..e5cecaa5 100644 --- a/clang/lib/CIR/CodeGen/CIRGenFunction.h +++ b/clang/lib/CIR/CodeGen/CIRGenFunction.h @@ -1103,44 +1103,69 @@ public: // --- private: - // `returnBlock`, `returnLoc`, and all the functions that deal with them - // will change and become more complicated when `switch` statements are - // upstreamed. `case` statements within the `switch` are in the same scope - // but have their own regions. Therefore the LexicalScope will need to - // keep track of multiple return blocks. - mlir::Block *returnBlock = nullptr; - std::optional<mlir::Location> returnLoc; - - // See the comment on `getOrCreateRetBlock`. + // On switches we need one return block per region, since cases don't + // have their own scopes but are distinct regions nonetheless. + + // TODO: This implementation should change once we have support for early + // exits in MLIR structured control flow (llvm-project#161575) + llvm::SmallVector<mlir::Block *> retBlocks; + llvm::DenseMap<mlir::Block *, mlir::Location> retLocs; + llvm::DenseMap<cir::CaseOp, unsigned> retBlockInCaseIndex; + std::optional<unsigned> normalRetBlockIndex; + + // There's usually only one ret block per scope, but this needs to be + // get or create because of potential unreachable return statements, note + // that for those, all source location maps to the first one found. mlir::Block *createRetBlock(CIRGenFunction &cgf, mlir::Location loc) { - assert(returnBlock == nullptr && "only one return block per scope"); - // Create the cleanup block but don't hook it up just yet. + assert((isa_and_nonnull<cir::CaseOp>( + cgf.builder.getBlock()->getParentOp()) || + retBlocks.size() == 0) && + "only switches can hold more than one ret block"); + + // Create the return block but don't hook it up just yet. mlir::OpBuilder::InsertionGuard guard(cgf.builder); - returnBlock = - cgf.builder.createBlock(cgf.builder.getBlock()->getParent()); - updateRetLoc(returnBlock, loc); - return returnBlock; + auto *b = cgf.builder.createBlock(cgf.builder.getBlock()->getParent()); + retBlocks.push_back(b); + updateRetLoc(b, loc); + return b; } cir::ReturnOp emitReturn(mlir::Location loc); void emitImplicitReturn(); public: - mlir::Block *getRetBlock() { return returnBlock; } - mlir::Location getRetLoc(mlir::Block *b) { return *returnLoc; } - void updateRetLoc(mlir::Block *b, mlir::Location loc) { returnLoc = loc; } - - // Create the return block for this scope, or return the existing one. - // This get-or-create logic is necessary to handle multiple return - // statements within the same scope, which can happen if some of them are - // dead code or if there is a `goto` into the middle of the scope. + llvm::ArrayRef<mlir::Block *> getRetBlocks() { return retBlocks; } + mlir::Location getRetLoc(mlir::Block *b) { return retLocs.at(b); } + void updateRetLoc(mlir::Block *b, mlir::Location loc) { + retLocs.insert_or_assign(b, loc); + } + mlir::Block *getOrCreateRetBlock(CIRGenFunction &cgf, mlir::Location loc) { - if (returnBlock == nullptr) { - returnBlock = createRetBlock(cgf, loc); - return returnBlock; + // Check if we're inside a case region + if (auto caseOp = mlir::dyn_cast_if_present<cir::CaseOp>( + cgf.builder.getBlock()->getParentOp())) { + auto iter = retBlockInCaseIndex.find(caseOp); + if (iter != retBlockInCaseIndex.end()) { + // Reuse existing return block + mlir::Block *ret = retBlocks[iter->second]; + updateRetLoc(ret, loc); + return ret; + } + // Create new return block + mlir::Block *ret = createRetBlock(cgf, loc); + retBlockInCaseIndex[caseOp] = retBlocks.size() - 1; + return ret; } - updateRetLoc(returnBlock, loc); - return returnBlock; + + if (normalRetBlockIndex) { + mlir::Block *ret = retBlocks[*normalRetBlockIndex]; + updateRetLoc(ret, loc); + return ret; + } + + mlir::Block *ret = createRetBlock(cgf, loc); + normalRetBlockIndex = retBlocks.size() - 1; + return ret; } mlir::Block *getEntryBlock() { return entryBlock; } |
