aboutsummaryrefslogtreecommitdiff
path: root/clang/lib/CIR/Dialect
diff options
context:
space:
mode:
Diffstat (limited to 'clang/lib/CIR/Dialect')
-rw-r--r--clang/lib/CIR/Dialect/IR/CIRDialect.cpp25
-rw-r--r--clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp125
2 files changed, 140 insertions, 10 deletions
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index b4c3704..fa180f5 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -1978,13 +1978,19 @@ void cir::TernaryOp::build(
result.addOperands(cond);
OpBuilder::InsertionGuard guard(builder);
Region *trueRegion = result.addRegion();
- Block *block = builder.createBlock(trueRegion);
+ builder.createBlock(trueRegion);
trueBuilder(builder, result.location);
Region *falseRegion = result.addRegion();
builder.createBlock(falseRegion);
falseBuilder(builder, result.location);
- auto yield = dyn_cast<YieldOp>(block->getTerminator());
+ // Get result type from whichever branch has a yield (the other may have
+ // unreachable from a throw expression)
+ auto yield =
+ dyn_cast_or_null<cir::YieldOp>(trueRegion->back().getTerminator());
+ if (!yield)
+ yield = dyn_cast_or_null<cir::YieldOp>(falseRegion->back().getTerminator());
+
assert((yield && yield.getNumOperands() <= 1) &&
"expected zero or one result type");
if (yield.getNumOperands() == 1)
@@ -2935,6 +2941,21 @@ mlir::LogicalResult cir::ThrowOp::verify() {
}
//===----------------------------------------------------------------------===//
+// AtomicFetchOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult cir::AtomicFetchOp::verify() {
+ if (getBinop() != cir::AtomicFetchKind::Add &&
+ getBinop() != cir::AtomicFetchKind::Sub &&
+ getBinop() != cir::AtomicFetchKind::Max &&
+ getBinop() != cir::AtomicFetchKind::Min &&
+ !mlir::isa<cir::IntType>(getVal().getType()))
+ return emitError("only atomic add, sub, max, and min operation could "
+ "operate on floating-point values");
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// TypeInfoAttr
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
index 26e5c05..46bd186 100644
--- a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
@@ -505,10 +505,19 @@ public:
Block *trueBlock = &trueRegion.front();
mlir::Operation *trueTerminator = trueRegion.back().getTerminator();
rewriter.setInsertionPointToEnd(&trueRegion.back());
- auto trueYieldOp = dyn_cast<cir::YieldOp>(trueTerminator);
- rewriter.replaceOpWithNewOp<cir::BrOp>(trueYieldOp, trueYieldOp.getArgs(),
- continueBlock);
+ // Handle both yield and unreachable terminators (throw expressions)
+ if (auto trueYieldOp = dyn_cast<cir::YieldOp>(trueTerminator)) {
+ rewriter.replaceOpWithNewOp<cir::BrOp>(trueYieldOp, trueYieldOp.getArgs(),
+ continueBlock);
+ } else if (isa<cir::UnreachableOp>(trueTerminator)) {
+ // Terminator is unreachable (e.g., from throw), just keep it
+ } else {
+ trueTerminator->emitError("unexpected terminator in ternary true region, "
+ "expected yield or unreachable, got: ")
+ << trueTerminator->getName();
+ return mlir::failure();
+ }
rewriter.inlineRegionBefore(trueRegion, continueBlock);
Block *falseBlock = continueBlock;
@@ -517,9 +526,19 @@ public:
falseBlock = &falseRegion.front();
mlir::Operation *falseTerminator = falseRegion.back().getTerminator();
rewriter.setInsertionPointToEnd(&falseRegion.back());
- auto falseYieldOp = dyn_cast<cir::YieldOp>(falseTerminator);
- rewriter.replaceOpWithNewOp<cir::BrOp>(falseYieldOp, falseYieldOp.getArgs(),
- continueBlock);
+
+ // Handle both yield and unreachable terminators (throw expressions)
+ if (auto falseYieldOp = dyn_cast<cir::YieldOp>(falseTerminator)) {
+ rewriter.replaceOpWithNewOp<cir::BrOp>(
+ falseYieldOp, falseYieldOp.getArgs(), continueBlock);
+ } else if (isa<cir::UnreachableOp>(falseTerminator)) {
+ // Terminator is unreachable (e.g., from throw), just keep it
+ } else {
+ falseTerminator->emitError("unexpected terminator in ternary false "
+ "region, expected yield or unreachable, got: ")
+ << falseTerminator->getName();
+ return mlir::failure();
+ }
rewriter.inlineRegionBefore(falseRegion, continueBlock);
rewriter.setInsertionPointToEnd(condBlock);
@@ -532,10 +551,100 @@ public:
}
};
+class CIRTryOpFlattening : public mlir::OpRewritePattern<cir::TryOp> {
+public:
+ using OpRewritePattern<cir::TryOp>::OpRewritePattern;
+
+ mlir::Block *buildTryBody(cir::TryOp tryOp,
+ mlir::PatternRewriter &rewriter) const {
+ // Split the current block before the TryOp to create the inlining
+ // point.
+ mlir::Block *beforeTryScopeBlock = rewriter.getInsertionBlock();
+ mlir::Block *afterTry =
+ rewriter.splitBlock(beforeTryScopeBlock, rewriter.getInsertionPoint());
+
+ // Inline body region.
+ mlir::Block *beforeBody = &tryOp.getTryRegion().front();
+ rewriter.inlineRegionBefore(tryOp.getTryRegion(), afterTry);
+
+ // Branch into the body of the region.
+ rewriter.setInsertionPointToEnd(beforeTryScopeBlock);
+ cir::BrOp::create(rewriter, tryOp.getLoc(), mlir::ValueRange(), beforeBody);
+ return afterTry;
+ }
+
+ void buildHandlers(cir::TryOp tryOp, mlir::PatternRewriter &rewriter,
+ mlir::Block *afterBody, mlir::Block *afterTry,
+ SmallVectorImpl<cir::CallOp> &callsToRewrite,
+ SmallVectorImpl<mlir::Block *> &landingPads) const {
+ // Replace the tryOp return with a branch that jumps out of the body.
+ rewriter.setInsertionPointToEnd(afterBody);
+
+ mlir::Block *beforeCatch = rewriter.getInsertionBlock();
+ rewriter.setInsertionPointToEnd(beforeCatch);
+
+ // Check if the terminator is a YieldOp because there could be another
+ // terminator, e.g. unreachable
+ if (auto tryBodyYield = dyn_cast<cir::YieldOp>(afterBody->getTerminator()))
+ rewriter.replaceOpWithNewOp<cir::BrOp>(tryBodyYield, afterTry);
+
+ mlir::ArrayAttr handlers = tryOp.getHandlerTypesAttr();
+ if (!handlers || handlers.empty())
+ return;
+
+ llvm_unreachable("TryOpFlattening buildHandlers with CallsOp is NYI");
+ }
+
+ mlir::LogicalResult
+ matchAndRewrite(cir::TryOp tryOp,
+ mlir::PatternRewriter &rewriter) const override {
+ mlir::OpBuilder::InsertionGuard guard(rewriter);
+ mlir::Block *afterBody = &tryOp.getTryRegion().back();
+
+ // Grab the collection of `cir.call exception`s to rewrite to
+ // `cir.try_call`.
+ llvm::SmallVector<cir::CallOp, 4> callsToRewrite;
+ tryOp.getTryRegion().walk([&](CallOp op) {
+ // Only grab calls within immediate closest TryOp scope.
+ if (op->getParentOfType<cir::TryOp>() != tryOp)
+ return;
+ assert(!cir::MissingFeatures::opCallExceptionAttr());
+ callsToRewrite.push_back(op);
+ });
+
+ if (!callsToRewrite.empty())
+ llvm_unreachable(
+ "TryOpFlattening with try block that contains CallOps is NYI");
+
+ // Build try body.
+ mlir::Block *afterTry = buildTryBody(tryOp, rewriter);
+
+ // Build handlers.
+ llvm::SmallVector<mlir::Block *, 4> landingPads;
+ buildHandlers(tryOp, rewriter, afterBody, afterTry, callsToRewrite,
+ landingPads);
+
+ rewriter.eraseOp(tryOp);
+
+ assert((landingPads.size() == callsToRewrite.size()) &&
+ "expected matching number of entries");
+
+ // Quick block cleanup: no indirection to the post try block.
+ auto brOp = dyn_cast<cir::BrOp>(afterTry->getTerminator());
+ if (brOp && brOp.getDest()->hasNoPredecessors()) {
+ mlir::Block *srcBlock = brOp.getDest();
+ rewriter.eraseOp(brOp);
+ rewriter.mergeBlocks(srcBlock, afterTry);
+ }
+
+ return mlir::success();
+ }
+};
+
void populateFlattenCFGPatterns(RewritePatternSet &patterns) {
patterns
.add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening,
- CIRSwitchOpFlattening, CIRTernaryOpFlattening>(
+ CIRSwitchOpFlattening, CIRTernaryOpFlattening, CIRTryOpFlattening>(
patterns.getContext());
}
@@ -549,7 +658,7 @@ void CIRFlattenCFGPass::runOnOperation() {
assert(!cir::MissingFeatures::ifOp());
assert(!cir::MissingFeatures::switchOp());
assert(!cir::MissingFeatures::tryOp());
- if (isa<IfOp, ScopeOp, SwitchOp, LoopOpInterface, TernaryOp>(op))
+ if (isa<IfOp, ScopeOp, SwitchOp, LoopOpInterface, TernaryOp, TryOp>(op))
ops.push_back(op);
});