diff options
Diffstat (limited to 'clang/lib/CIR/Dialect')
-rw-r--r-- | clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 33 | ||||
-rw-r--r-- | clang/lib/CIR/Dialect/IR/CIRMemorySlot.cpp | 4 | ||||
-rw-r--r-- | clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp | 178 | ||||
-rw-r--r-- | clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp | 41 | ||||
-rw-r--r-- | clang/lib/CIR/Dialect/Transforms/LoweringPrepareItaniumCXXABI.cpp | 37 |
5 files changed, 212 insertions, 81 deletions
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index b4c3704..2d2ef42 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -95,8 +95,8 @@ Operation *cir::CIRDialect::materializeConstant(mlir::OpBuilder &builder, mlir::Attribute value, mlir::Type type, mlir::Location loc) { - return builder.create<cir::ConstantOp>(loc, type, - mlir::cast<mlir::TypedAttr>(value)); + return cir::ConstantOp::create(builder, loc, type, + mlir::cast<mlir::TypedAttr>(value)); } //===----------------------------------------------------------------------===// @@ -184,7 +184,7 @@ static LogicalResult ensureRegionTerm(OpAsmParser &parser, Region ®ion, // Terminator was omitted correctly: recreate it. builder.setInsertionPointToEnd(&block); - builder.create<cir::YieldOp>(eLoc); + cir::YieldOp::create(builder, eLoc); return success(); } @@ -977,7 +977,7 @@ void cir::IfOp::print(OpAsmPrinter &p) { /// Default callback for IfOp builders. void cir::buildTerminatedBody(OpBuilder &builder, Location loc) { // add cir.yield to end of the block - builder.create<cir::YieldOp>(loc); + cir::YieldOp::create(builder, loc); } /// Given the region at `index`, or the parent operation if `index` is None, @@ -1978,13 +1978,19 @@ void cir::TernaryOp::build( result.addOperands(cond); OpBuilder::InsertionGuard guard(builder); Region *trueRegion = result.addRegion(); - Block *block = builder.createBlock(trueRegion); + builder.createBlock(trueRegion); trueBuilder(builder, result.location); Region *falseRegion = result.addRegion(); builder.createBlock(falseRegion); falseBuilder(builder, result.location); - auto yield = dyn_cast<YieldOp>(block->getTerminator()); + // Get result type from whichever branch has a yield (the other may have + // unreachable from a throw expression) + auto yield = + dyn_cast_or_null<cir::YieldOp>(trueRegion->back().getTerminator()); + if (!yield) + yield = dyn_cast_or_null<cir::YieldOp>(falseRegion->back().getTerminator()); + assert((yield && yield.getNumOperands() <= 1) && "expected zero or one result type"); if (yield.getNumOperands() == 1) @@ -2935,6 +2941,21 @@ mlir::LogicalResult cir::ThrowOp::verify() { } //===----------------------------------------------------------------------===// +// AtomicFetchOp +//===----------------------------------------------------------------------===// + +LogicalResult cir::AtomicFetchOp::verify() { + if (getBinop() != cir::AtomicFetchKind::Add && + getBinop() != cir::AtomicFetchKind::Sub && + getBinop() != cir::AtomicFetchKind::Max && + getBinop() != cir::AtomicFetchKind::Min && + !mlir::isa<cir::IntType>(getVal().getType())) + return emitError("only atomic add, sub, max, and min operation could " + "operate on floating-point values"); + return success(); +} + +//===----------------------------------------------------------------------===// // TypeInfoAttr //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/IR/CIRMemorySlot.cpp b/clang/lib/CIR/Dialect/IR/CIRMemorySlot.cpp index 7e96ae9..66469e2 100644 --- a/clang/lib/CIR/Dialect/IR/CIRMemorySlot.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRMemorySlot.cpp @@ -34,8 +34,8 @@ llvm::SmallVector<MemorySlot> cir::AllocaOp::getPromotableSlots() { Value cir::AllocaOp::getDefaultValue(const MemorySlot &slot, OpBuilder &builder) { - return builder.create<cir::ConstantOp>(getLoc(), - cir::UndefAttr::get(slot.elemType)); + return cir::ConstantOp::create(builder, getLoc(), + cir::UndefAttr::get(slot.elemType)); } void cir::AllocaOp::handleBlockArgument(const MemorySlot &slot, diff --git a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp index 26e5c05..21c96fe 100644 --- a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp +++ b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp @@ -100,8 +100,8 @@ struct CIRIfFlattening : public mlir::OpRewritePattern<cir::IfOp> { } rewriter.setInsertionPointToEnd(currentBlock); - rewriter.create<cir::BrCondOp>(loc, ifOp.getCondition(), thenBeforeBody, - elseBeforeBody); + cir::BrCondOp::create(rewriter, loc, ifOp.getCondition(), thenBeforeBody, + elseBeforeBody); if (!emptyElse) { rewriter.setInsertionPointToEnd(elseAfterBody); @@ -154,7 +154,7 @@ public: // Save stack and then branch into the body of the region. rewriter.setInsertionPointToEnd(currentBlock); assert(!cir::MissingFeatures::stackSaveOp()); - rewriter.create<cir::BrOp>(loc, mlir::ValueRange(), beforeBody); + cir::BrOp::create(rewriter, loc, mlir::ValueRange(), beforeBody); // Replace the scopeop return with a branch that jumps out of the body. // Stack restore before leaving the body region. @@ -195,26 +195,27 @@ public: cir::IntType sIntType = cir::IntType::get(op.getContext(), 32, true); cir::IntType uIntType = cir::IntType::get(op.getContext(), 32, false); - cir::ConstantOp rangeLength = rewriter.create<cir::ConstantOp>( - op.getLoc(), cir::IntAttr::get(sIntType, upperBound - lowerBound)); + cir::ConstantOp rangeLength = cir::ConstantOp::create( + rewriter, op.getLoc(), + cir::IntAttr::get(sIntType, upperBound - lowerBound)); - cir::ConstantOp lowerBoundValue = rewriter.create<cir::ConstantOp>( - op.getLoc(), cir::IntAttr::get(sIntType, lowerBound)); + cir::ConstantOp lowerBoundValue = cir::ConstantOp::create( + rewriter, op.getLoc(), cir::IntAttr::get(sIntType, lowerBound)); cir::BinOp diffValue = - rewriter.create<cir::BinOp>(op.getLoc(), sIntType, cir::BinOpKind::Sub, - op.getCondition(), lowerBoundValue); + cir::BinOp::create(rewriter, op.getLoc(), sIntType, cir::BinOpKind::Sub, + op.getCondition(), lowerBoundValue); // Use unsigned comparison to check if the condition is in the range. - cir::CastOp uDiffValue = rewriter.create<cir::CastOp>( - op.getLoc(), uIntType, CastKind::integral, diffValue); - cir::CastOp uRangeLength = rewriter.create<cir::CastOp>( - op.getLoc(), uIntType, CastKind::integral, rangeLength); - - cir::CmpOp cmpResult = rewriter.create<cir::CmpOp>( - op.getLoc(), cir::BoolType::get(op.getContext()), cir::CmpOpKind::le, - uDiffValue, uRangeLength); - rewriter.create<cir::BrCondOp>(op.getLoc(), cmpResult, rangeDestination, - defaultDestination); + cir::CastOp uDiffValue = cir::CastOp::create( + rewriter, op.getLoc(), uIntType, CastKind::integral, diffValue); + cir::CastOp uRangeLength = cir::CastOp::create( + rewriter, op.getLoc(), uIntType, CastKind::integral, rangeLength); + + cir::CmpOp cmpResult = cir::CmpOp::create( + rewriter, op.getLoc(), cir::BoolType::get(op.getContext()), + cir::CmpOpKind::le, uDiffValue, uRangeLength); + cir::BrCondOp::create(rewriter, op.getLoc(), cmpResult, rangeDestination, + defaultDestination); return resBlock; } @@ -262,7 +263,7 @@ public: rewriteYieldOp(rewriter, switchYield, exitBlock); rewriter.setInsertionPointToEnd(originalBlock); - rewriter.create<cir::BrOp>(op.getLoc(), swopBlock); + cir::BrOp::create(rewriter, op.getLoc(), swopBlock); } // Allocate required data structures (disconsider default case in @@ -331,8 +332,8 @@ public: mlir::Block *newBlock = rewriter.splitBlock(oldBlock, nextOp->getIterator()); rewriter.setInsertionPointToEnd(oldBlock); - rewriter.create<cir::BrOp>(nextOp->getLoc(), mlir::ValueRange(), - newBlock); + cir::BrOp::create(rewriter, nextOp->getLoc(), mlir::ValueRange(), + newBlock); rewriteYieldOp(rewriter, yieldOp, newBlock); } } @@ -346,7 +347,7 @@ public: // Create a branch to the entry of the inlined region. rewriter.setInsertionPointToEnd(oldBlock); - rewriter.create<cir::BrOp>(caseOp.getLoc(), &entryBlock); + cir::BrOp::create(rewriter, caseOp.getLoc(), &entryBlock); } // Remove all cases since we've inlined the regions. @@ -427,7 +428,7 @@ public: // Setup loop entry branch. rewriter.setInsertionPointToEnd(entry); - rewriter.create<cir::BrOp>(op.getLoc(), &op.getEntry().front()); + cir::BrOp::create(rewriter, op.getLoc(), &op.getEntry().front()); // Branch from condition region to body or exit. auto conditionOp = cast<cir::ConditionOp>(cond->getTerminator()); @@ -499,16 +500,25 @@ public: locs.push_back(loc); Block *continueBlock = rewriter.createBlock(remainingOpsBlock, op->getResultTypes(), locs); - rewriter.create<cir::BrOp>(loc, remainingOpsBlock); + cir::BrOp::create(rewriter, loc, remainingOpsBlock); Region &trueRegion = op.getTrueRegion(); Block *trueBlock = &trueRegion.front(); mlir::Operation *trueTerminator = trueRegion.back().getTerminator(); rewriter.setInsertionPointToEnd(&trueRegion.back()); - auto trueYieldOp = dyn_cast<cir::YieldOp>(trueTerminator); - rewriter.replaceOpWithNewOp<cir::BrOp>(trueYieldOp, trueYieldOp.getArgs(), - continueBlock); + // Handle both yield and unreachable terminators (throw expressions) + if (auto trueYieldOp = dyn_cast<cir::YieldOp>(trueTerminator)) { + rewriter.replaceOpWithNewOp<cir::BrOp>(trueYieldOp, trueYieldOp.getArgs(), + continueBlock); + } else if (isa<cir::UnreachableOp>(trueTerminator)) { + // Terminator is unreachable (e.g., from throw), just keep it + } else { + trueTerminator->emitError("unexpected terminator in ternary true region, " + "expected yield or unreachable, got: ") + << trueTerminator->getName(); + return mlir::failure(); + } rewriter.inlineRegionBefore(trueRegion, continueBlock); Block *falseBlock = continueBlock; @@ -517,13 +527,23 @@ public: falseBlock = &falseRegion.front(); mlir::Operation *falseTerminator = falseRegion.back().getTerminator(); rewriter.setInsertionPointToEnd(&falseRegion.back()); - auto falseYieldOp = dyn_cast<cir::YieldOp>(falseTerminator); - rewriter.replaceOpWithNewOp<cir::BrOp>(falseYieldOp, falseYieldOp.getArgs(), - continueBlock); + + // Handle both yield and unreachable terminators (throw expressions) + if (auto falseYieldOp = dyn_cast<cir::YieldOp>(falseTerminator)) { + rewriter.replaceOpWithNewOp<cir::BrOp>( + falseYieldOp, falseYieldOp.getArgs(), continueBlock); + } else if (isa<cir::UnreachableOp>(falseTerminator)) { + // Terminator is unreachable (e.g., from throw), just keep it + } else { + falseTerminator->emitError("unexpected terminator in ternary false " + "region, expected yield or unreachable, got: ") + << falseTerminator->getName(); + return mlir::failure(); + } rewriter.inlineRegionBefore(falseRegion, continueBlock); rewriter.setInsertionPointToEnd(condBlock); - rewriter.create<cir::BrCondOp>(loc, op.getCond(), trueBlock, falseBlock); + cir::BrCondOp::create(rewriter, loc, op.getCond(), trueBlock, falseBlock); rewriter.replaceOp(op, continueBlock->getArguments()); @@ -532,10 +552,100 @@ public: } }; +class CIRTryOpFlattening : public mlir::OpRewritePattern<cir::TryOp> { +public: + using OpRewritePattern<cir::TryOp>::OpRewritePattern; + + mlir::Block *buildTryBody(cir::TryOp tryOp, + mlir::PatternRewriter &rewriter) const { + // Split the current block before the TryOp to create the inlining + // point. + mlir::Block *beforeTryScopeBlock = rewriter.getInsertionBlock(); + mlir::Block *afterTry = + rewriter.splitBlock(beforeTryScopeBlock, rewriter.getInsertionPoint()); + + // Inline body region. + mlir::Block *beforeBody = &tryOp.getTryRegion().front(); + rewriter.inlineRegionBefore(tryOp.getTryRegion(), afterTry); + + // Branch into the body of the region. + rewriter.setInsertionPointToEnd(beforeTryScopeBlock); + cir::BrOp::create(rewriter, tryOp.getLoc(), mlir::ValueRange(), beforeBody); + return afterTry; + } + + void buildHandlers(cir::TryOp tryOp, mlir::PatternRewriter &rewriter, + mlir::Block *afterBody, mlir::Block *afterTry, + SmallVectorImpl<cir::CallOp> &callsToRewrite, + SmallVectorImpl<mlir::Block *> &landingPads) const { + // Replace the tryOp return with a branch that jumps out of the body. + rewriter.setInsertionPointToEnd(afterBody); + + mlir::Block *beforeCatch = rewriter.getInsertionBlock(); + rewriter.setInsertionPointToEnd(beforeCatch); + + // Check if the terminator is a YieldOp because there could be another + // terminator, e.g. unreachable + if (auto tryBodyYield = dyn_cast<cir::YieldOp>(afterBody->getTerminator())) + rewriter.replaceOpWithNewOp<cir::BrOp>(tryBodyYield, afterTry); + + mlir::ArrayAttr handlers = tryOp.getHandlerTypesAttr(); + if (!handlers || handlers.empty()) + return; + + llvm_unreachable("TryOpFlattening buildHandlers with CallsOp is NYI"); + } + + mlir::LogicalResult + matchAndRewrite(cir::TryOp tryOp, + mlir::PatternRewriter &rewriter) const override { + mlir::OpBuilder::InsertionGuard guard(rewriter); + mlir::Block *afterBody = &tryOp.getTryRegion().back(); + + // Grab the collection of `cir.call exception`s to rewrite to + // `cir.try_call`. + llvm::SmallVector<cir::CallOp, 4> callsToRewrite; + tryOp.getTryRegion().walk([&](CallOp op) { + // Only grab calls within immediate closest TryOp scope. + if (op->getParentOfType<cir::TryOp>() != tryOp) + return; + assert(!cir::MissingFeatures::opCallExceptionAttr()); + callsToRewrite.push_back(op); + }); + + if (!callsToRewrite.empty()) + llvm_unreachable( + "TryOpFlattening with try block that contains CallOps is NYI"); + + // Build try body. + mlir::Block *afterTry = buildTryBody(tryOp, rewriter); + + // Build handlers. + llvm::SmallVector<mlir::Block *, 4> landingPads; + buildHandlers(tryOp, rewriter, afterBody, afterTry, callsToRewrite, + landingPads); + + rewriter.eraseOp(tryOp); + + assert((landingPads.size() == callsToRewrite.size()) && + "expected matching number of entries"); + + // Quick block cleanup: no indirection to the post try block. + auto brOp = dyn_cast<cir::BrOp>(afterTry->getTerminator()); + if (brOp && brOp.getDest()->hasNoPredecessors()) { + mlir::Block *srcBlock = brOp.getDest(); + rewriter.eraseOp(brOp); + rewriter.mergeBlocks(srcBlock, afterTry); + } + + return mlir::success(); + } +}; + void populateFlattenCFGPatterns(RewritePatternSet &patterns) { patterns .add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening, - CIRSwitchOpFlattening, CIRTernaryOpFlattening>( + CIRSwitchOpFlattening, CIRTernaryOpFlattening, CIRTryOpFlattening>( patterns.getContext()); } @@ -549,7 +659,7 @@ void CIRFlattenCFGPass::runOnOperation() { assert(!cir::MissingFeatures::ifOp()); assert(!cir::MissingFeatures::switchOp()); assert(!cir::MissingFeatures::tryOp()); - if (isa<IfOp, ScopeOp, SwitchOp, LoopOpInterface, TernaryOp>(op)) + if (isa<IfOp, ScopeOp, SwitchOp, LoopOpInterface, TernaryOp, TryOp>(op)) ops.push_back(op); }); diff --git a/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp b/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp index d99c362..cba0464 100644 --- a/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp +++ b/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp @@ -155,7 +155,7 @@ cir::FuncOp LoweringPreparePass::buildRuntimeFunction( cir::FuncOp f = dyn_cast_or_null<FuncOp>(SymbolTable::lookupNearestSymbolFrom( mlirModule, StringAttr::get(mlirModule->getContext(), name))); if (!f) { - f = builder.create<cir::FuncOp>(loc, name, type); + f = cir::FuncOp::create(builder, loc, name, type); f.setLinkageAttr( cir::GlobalLinkageKindAttr::get(builder.getContext(), linkage)); mlir::SymbolTable::setSymbolVisibility( @@ -400,12 +400,12 @@ buildRangeReductionComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc, builder.createYield(loc, result); }; - auto cFabs = builder.create<cir::FAbsOp>(loc, c); - auto dFabs = builder.create<cir::FAbsOp>(loc, d); + auto cFabs = cir::FAbsOp::create(builder, loc, c); + auto dFabs = cir::FAbsOp::create(builder, loc, d); cir::CmpOp cmpResult = builder.createCompare(loc, cir::CmpOpKind::ge, cFabs, dFabs); - auto ternary = builder.create<cir::TernaryOp>( - loc, cmpResult, trueBranchBuilder, falseBranchBuilder); + auto ternary = cir::TernaryOp::create(builder, loc, cmpResult, + trueBranchBuilder, falseBranchBuilder); return ternary.getResult(); } @@ -612,18 +612,17 @@ static mlir::Value lowerComplexMul(LoweringPreparePass &pass, mlir::Value resultRealAndImagAreNaN = builder.createLogicalAnd(loc, resultRealIsNaN, resultImagIsNaN); - return builder - .create<cir::TernaryOp>( - loc, resultRealAndImagAreNaN, - [&](mlir::OpBuilder &, mlir::Location) { - mlir::Value libCallResult = buildComplexBinOpLibCall( - pass, builder, &getComplexMulLibCallName, loc, complexTy, - lhsReal, lhsImag, rhsReal, rhsImag); - builder.createYield(loc, libCallResult); - }, - [&](mlir::OpBuilder &, mlir::Location) { - builder.createYield(loc, algebraicResult); - }) + return cir::TernaryOp::create( + builder, loc, resultRealAndImagAreNaN, + [&](mlir::OpBuilder &, mlir::Location) { + mlir::Value libCallResult = buildComplexBinOpLibCall( + pass, builder, &getComplexMulLibCallName, loc, complexTy, + lhsReal, lhsImag, rhsReal, rhsImag); + builder.createYield(loc, libCallResult); + }, + [&](mlir::OpBuilder &, mlir::Location) { + builder.createYield(loc, algebraicResult); + }) .getResult(); } @@ -920,15 +919,15 @@ static void lowerArrayDtorCtorIntoLoop(cir::CIRBaseBuilderTy &builder, loc, /*condBuilder=*/ [&](mlir::OpBuilder &b, mlir::Location loc) { - auto currentElement = b.create<cir::LoadOp>(loc, eltTy, tmpAddr); + auto currentElement = cir::LoadOp::create(b, loc, eltTy, tmpAddr); mlir::Type boolTy = cir::BoolType::get(b.getContext()); - auto cmp = builder.create<cir::CmpOp>(loc, boolTy, cir::CmpOpKind::ne, - currentElement, stop); + auto cmp = cir::CmpOp::create(builder, loc, boolTy, cir::CmpOpKind::ne, + currentElement, stop); builder.createCondition(cmp); }, /*bodyBuilder=*/ [&](mlir::OpBuilder &b, mlir::Location loc) { - auto currentElement = b.create<cir::LoadOp>(loc, eltTy, tmpAddr); + auto currentElement = cir::LoadOp::create(b, loc, eltTy, tmpAddr); cir::CallOp ctorCall; op->walk([&](cir::CallOp c) { ctorCall = c; }); diff --git a/clang/lib/CIR/Dialect/Transforms/LoweringPrepareItaniumCXXABI.cpp b/clang/lib/CIR/Dialect/Transforms/LoweringPrepareItaniumCXXABI.cpp index 11ce2a8..5a067f8 100644 --- a/clang/lib/CIR/Dialect/Transforms/LoweringPrepareItaniumCXXABI.cpp +++ b/clang/lib/CIR/Dialect/Transforms/LoweringPrepareItaniumCXXABI.cpp @@ -77,10 +77,11 @@ buildDynamicCastAfterNullCheck(cir::CIRBaseBuilderTy &builder, if (op.isRefCast()) { // Emit a cir.if that checks the casted value. mlir::Value castedValueIsNull = builder.createPtrIsNull(castedPtr); - builder.create<cir::IfOp>( - loc, castedValueIsNull, false, [&](mlir::OpBuilder &, mlir::Location) { - buildBadCastCall(builder, loc, castInfo.getBadCastFunc()); - }); + cir::IfOp::create(builder, loc, castedValueIsNull, false, + [&](mlir::OpBuilder &, mlir::Location) { + buildBadCastCall(builder, loc, + castInfo.getBadCastFunc()); + }); } // Note that castedPtr is a void*. Cast it to a pointer to the destination @@ -154,19 +155,19 @@ LoweringPrepareItaniumCXXABI::lowerDynamicCast(cir::CIRBaseBuilderTy &builder, return buildDynamicCastAfterNullCheck(builder, op); mlir::Value srcValueIsNotNull = builder.createPtrToBoolCast(srcValue); - return builder - .create<cir::TernaryOp>( - loc, srcValueIsNotNull, - [&](mlir::OpBuilder &, mlir::Location) { - mlir::Value castedValue = - op.isCastToVoid() - ? buildDynamicCastToVoidAfterNullCheck(builder, astCtx, op) - : buildDynamicCastAfterNullCheck(builder, op); - builder.createYield(loc, castedValue); - }, - [&](mlir::OpBuilder &, mlir::Location) { - builder.createYield( - loc, builder.getNullPtr(op.getType(), loc).getResult()); - }) + return cir::TernaryOp::create( + builder, loc, srcValueIsNotNull, + [&](mlir::OpBuilder &, mlir::Location) { + mlir::Value castedValue = + op.isCastToVoid() + ? buildDynamicCastToVoidAfterNullCheck(builder, astCtx, + op) + : buildDynamicCastAfterNullCheck(builder, op); + builder.createYield(loc, castedValue); + }, + [&](mlir::OpBuilder &, mlir::Location) { + builder.createYield( + loc, builder.getNullPtr(op.getType(), loc).getResult()); + }) .getResult(); } |