diff options
Diffstat (limited to 'clang/lib')
43 files changed, 1388 insertions, 923 deletions
diff --git a/clang/lib/AST/ByteCode/InterpBuiltin.cpp b/clang/lib/AST/ByteCode/InterpBuiltin.cpp index ff83c52..2d5ad4a 100644 --- a/clang/lib/AST/ByteCode/InterpBuiltin.cpp +++ b/clang/lib/AST/ByteCode/InterpBuiltin.cpp @@ -3471,7 +3471,7 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call, case Builtin::BI_lrotl: case Builtin::BI_rotl64: return interp__builtin_elementwise_int_binop( - S, OpPC, Call, [](const APSInt &Value, const APSInt &Amount) -> APInt { + S, OpPC, Call, [](const APSInt &Value, const APSInt &Amount) { return Value.rotl(Amount); }); @@ -3485,7 +3485,7 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call, case Builtin::BI_lrotr: case Builtin::BI_rotr64: return interp__builtin_elementwise_int_binop( - S, OpPC, Call, [](const APSInt &Value, const APSInt &Amount) -> APInt { + S, OpPC, Call, [](const APSInt &Value, const APSInt &Amount) { return Value.rotr(Amount); }); diff --git a/clang/lib/AST/CMakeLists.txt b/clang/lib/AST/CMakeLists.txt index d4fd7a7..fd50e95 100644 --- a/clang/lib/AST/CMakeLists.txt +++ b/clang/lib/AST/CMakeLists.txt @@ -66,6 +66,7 @@ add_clang_library(clangAST ExternalASTMerger.cpp ExternalASTSource.cpp FormatString.cpp + InferAlloc.cpp InheritViz.cpp ByteCode/BitcastBuffer.cpp ByteCode/ByteCodeEmitter.cpp diff --git a/clang/lib/AST/InferAlloc.cpp b/clang/lib/AST/InferAlloc.cpp new file mode 100644 index 0000000..e439ed4 --- /dev/null +++ b/clang/lib/AST/InferAlloc.cpp @@ -0,0 +1,201 @@ +//===--- InferAlloc.cpp - Allocation type inference -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements allocation-related type inference. +// +//===----------------------------------------------------------------------===// + +#include "clang/AST/InferAlloc.h" +#include "clang/AST/ASTContext.h" +#include "clang/AST/Decl.h" +#include "clang/AST/DeclCXX.h" +#include "clang/AST/Expr.h" +#include "clang/AST/Type.h" +#include "clang/Basic/IdentifierTable.h" +#include "llvm/ADT/SmallPtrSet.h" + +using namespace clang; +using namespace infer_alloc; + +static bool +typeContainsPointer(QualType T, + llvm::SmallPtrSet<const RecordDecl *, 4> &VisitedRD, + bool &IncompleteType) { + QualType CanonicalType = T.getCanonicalType(); + if (CanonicalType->isPointerType()) + return true; // base case + + // Look through typedef chain to check for special types. + for (QualType CurrentT = T; const auto *TT = CurrentT->getAs<TypedefType>(); + CurrentT = TT->getDecl()->getUnderlyingType()) { + const IdentifierInfo *II = TT->getDecl()->getIdentifier(); + // Special Case: Syntactically uintptr_t is not a pointer; semantically, + // however, very likely used as such. Therefore, classify uintptr_t as a + // pointer, too. + if (II && II->isStr("uintptr_t")) + return true; + } + + // The type is an array; check the element type. + if (const ArrayType *AT = dyn_cast<ArrayType>(CanonicalType)) + return typeContainsPointer(AT->getElementType(), VisitedRD, IncompleteType); + // The type is a struct, class, or union. + if (const RecordDecl *RD = CanonicalType->getAsRecordDecl()) { + if (!RD->isCompleteDefinition()) { + IncompleteType = true; + return false; + } + if (!VisitedRD.insert(RD).second) + return false; // already visited + // Check all fields. + for (const FieldDecl *Field : RD->fields()) { + if (typeContainsPointer(Field->getType(), VisitedRD, IncompleteType)) + return true; + } + // For C++ classes, also check base classes. + if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) { + // Polymorphic types require a vptr. + if (CXXRD->isDynamicClass()) + return true; + for (const CXXBaseSpecifier &Base : CXXRD->bases()) { + if (typeContainsPointer(Base.getType(), VisitedRD, IncompleteType)) + return true; + } + } + } + return false; +} + +/// Infer type from a simple sizeof expression. +static QualType inferTypeFromSizeofExpr(const Expr *E) { + const Expr *Arg = E->IgnoreParenImpCasts(); + if (const auto *UET = dyn_cast<UnaryExprOrTypeTraitExpr>(Arg)) { + if (UET->getKind() == UETT_SizeOf) { + if (UET->isArgumentType()) + return UET->getArgumentTypeInfo()->getType(); + else + return UET->getArgumentExpr()->getType(); + } + } + return QualType(); +} + +/// Infer type from an arithmetic expression involving a sizeof. For example: +/// +/// malloc(sizeof(MyType) + padding); // infers 'MyType' +/// malloc(sizeof(MyType) * 32); // infers 'MyType' +/// malloc(32 * sizeof(MyType)); // infers 'MyType' +/// malloc(sizeof(MyType) << 1); // infers 'MyType' +/// ... +/// +/// More complex arithmetic expressions are supported, but are a heuristic, e.g. +/// when considering allocations for structs with flexible array members: +/// +/// malloc(sizeof(HasFlexArray) + sizeof(int) * 32); // infers 'HasFlexArray' +/// +static QualType inferPossibleTypeFromArithSizeofExpr(const Expr *E) { + const Expr *Arg = E->IgnoreParenImpCasts(); + // The argument is a lone sizeof expression. + if (QualType T = inferTypeFromSizeofExpr(Arg); !T.isNull()) + return T; + if (const auto *BO = dyn_cast<BinaryOperator>(Arg)) { + // Argument is an arithmetic expression. Cover common arithmetic patterns + // involving sizeof. + switch (BO->getOpcode()) { + case BO_Add: + case BO_Div: + case BO_Mul: + case BO_Shl: + case BO_Shr: + case BO_Sub: + if (QualType T = inferPossibleTypeFromArithSizeofExpr(BO->getLHS()); + !T.isNull()) + return T; + if (QualType T = inferPossibleTypeFromArithSizeofExpr(BO->getRHS()); + !T.isNull()) + return T; + break; + default: + break; + } + } + return QualType(); +} + +/// If the expression E is a reference to a variable, infer the type from a +/// variable's initializer if it contains a sizeof. Beware, this is a heuristic +/// and ignores if a variable is later reassigned. For example: +/// +/// size_t my_size = sizeof(MyType); +/// void *x = malloc(my_size); // infers 'MyType' +/// +static QualType inferPossibleTypeFromVarInitSizeofExpr(const Expr *E) { + const Expr *Arg = E->IgnoreParenImpCasts(); + if (const auto *DRE = dyn_cast<DeclRefExpr>(Arg)) { + if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) { + if (const Expr *Init = VD->getInit()) + return inferPossibleTypeFromArithSizeofExpr(Init); + } + } + return QualType(); +} + +/// Deduces the allocated type by checking if the allocation call's result +/// is immediately used in a cast expression. For example: +/// +/// MyType *x = (MyType *)malloc(4096); // infers 'MyType' +/// +static QualType inferPossibleTypeFromCastExpr(const CallExpr *CallE, + const CastExpr *CastE) { + if (!CastE) + return QualType(); + QualType PtrType = CastE->getType(); + if (PtrType->isPointerType()) + return PtrType->getPointeeType(); + return QualType(); +} + +QualType infer_alloc::inferPossibleType(const CallExpr *E, + const ASTContext &Ctx, + const CastExpr *CastE) { + QualType AllocType; + // First check arguments. + for (const Expr *Arg : E->arguments()) { + AllocType = inferPossibleTypeFromArithSizeofExpr(Arg); + if (AllocType.isNull()) + AllocType = inferPossibleTypeFromVarInitSizeofExpr(Arg); + if (!AllocType.isNull()) + break; + } + // Then check later casts. + if (AllocType.isNull()) + AllocType = inferPossibleTypeFromCastExpr(E, CastE); + return AllocType; +} + +std::optional<llvm::AllocTokenMetadata> +infer_alloc::getAllocTokenMetadata(QualType T, const ASTContext &Ctx) { + llvm::AllocTokenMetadata ATMD; + + // Get unique type name. + PrintingPolicy Policy(Ctx.getLangOpts()); + Policy.SuppressTagKeyword = true; + Policy.FullyQualifiedName = true; + llvm::raw_svector_ostream TypeNameOS(ATMD.TypeName); + T.getCanonicalType().print(TypeNameOS, Policy); + + // Check if QualType contains a pointer. Implements a simple DFS to + // recursively check if a type contains a pointer type. + llvm::SmallPtrSet<const RecordDecl *, 4> VisitedRD; + bool IncompleteType = false; + ATMD.ContainsPointer = typeContainsPointer(T, VisitedRD, IncompleteType); + if (!ATMD.ContainsPointer && IncompleteType) + return std::nullopt; + + return ATMD; +} diff --git a/clang/lib/Analysis/FlowSensitive/Models/UncheckedStatusOrAccessModel.cpp b/clang/lib/Analysis/FlowSensitive/Models/UncheckedStatusOrAccessModel.cpp index f068be5..598d33a 100644 --- a/clang/lib/Analysis/FlowSensitive/Models/UncheckedStatusOrAccessModel.cpp +++ b/clang/lib/Analysis/FlowSensitive/Models/UncheckedStatusOrAccessModel.cpp @@ -137,6 +137,37 @@ static auto valueOperatorCall() { isStatusOrOperatorCallWithName("->"))); } +static clang::ast_matchers::TypeMatcher statusType() { + using namespace ::clang::ast_matchers; // NOLINT: Too many names + return hasCanonicalType(qualType(hasDeclaration(statusClass()))); +} + +static auto isComparisonOperatorCall(llvm::StringRef operator_name) { + using namespace ::clang::ast_matchers; // NOLINT: Too many names + return cxxOperatorCallExpr( + hasOverloadedOperatorName(operator_name), argumentCountIs(2), + hasArgument(0, anyOf(hasType(statusType()), hasType(statusOrType()))), + hasArgument(1, anyOf(hasType(statusType()), hasType(statusOrType())))); +} + +static auto isOkStatusCall() { + using namespace ::clang::ast_matchers; // NOLINT: Too many names + return callExpr(callee(functionDecl(hasName("::absl::OkStatus")))); +} + +static auto isNotOkStatusCall() { + using namespace ::clang::ast_matchers; // NOLINT: Too many names + return callExpr(callee(functionDecl(hasAnyName( + "::absl::AbortedError", "::absl::AlreadyExistsError", + "::absl::CancelledError", "::absl::DataLossError", + "::absl::DeadlineExceededError", "::absl::FailedPreconditionError", + "::absl::InternalError", "::absl::InvalidArgumentError", + "::absl::NotFoundError", "::absl::OutOfRangeError", + "::absl::PermissionDeniedError", "::absl::ResourceExhaustedError", + "::absl::UnauthenticatedError", "::absl::UnavailableError", + "::absl::UnimplementedError", "::absl::UnknownError")))); +} + static auto buildDiagnoseMatchSwitch(const UncheckedStatusOrAccessModelOptions &Options) { return CFGMatchSwitchBuilder<const Environment, @@ -312,6 +343,118 @@ static void transferStatusUpdateCall(const CXXMemberCallExpr *Expr, State.Env.setValue(locForOk(*ThisLoc), NewVal); } +static BoolValue *evaluateStatusEquality(RecordStorageLocation &LhsStatusLoc, + RecordStorageLocation &RhsStatusLoc, + Environment &Env) { + auto &A = Env.arena(); + // Logically, a Status object is composed of an error code that could take one + // of multiple possible values, including the "ok" value. We track whether a + // Status object has an "ok" value and represent this as an `ok` bit. Equality + // of Status objects compares their error codes. Therefore, merely comparing + // the `ok` bits isn't sufficient: when two Status objects are assigned non-ok + // error codes the equality of their respective error codes matters. Since we + // only track the `ok` bits, we can't make any conclusions about equality when + // we know that two Status objects have non-ok values. + + auto &LhsOkVal = valForOk(LhsStatusLoc, Env); + auto &RhsOkVal = valForOk(RhsStatusLoc, Env); + + auto &Res = Env.makeAtomicBoolValue(); + + // lhs && rhs => res (a.k.a. !res => !lhs || !rhs) + Env.assume(A.makeImplies(A.makeAnd(LhsOkVal.formula(), RhsOkVal.formula()), + Res.formula())); + // res => (lhs == rhs) + Env.assume(A.makeImplies( + Res.formula(), A.makeEquals(LhsOkVal.formula(), RhsOkVal.formula()))); + + return &Res; +} + +static BoolValue * +evaluateStatusOrEquality(RecordStorageLocation &LhsStatusOrLoc, + RecordStorageLocation &RhsStatusOrLoc, + Environment &Env) { + auto &A = Env.arena(); + // Logically, a StatusOr<T> object is composed of two values - a Status and a + // value of type T. Equality of StatusOr objects compares both values. + // Therefore, merely comparing the `ok` bits of the Status values isn't + // sufficient. When two StatusOr objects are engaged, the equality of their + // respective values of type T matters. Similarly, when two StatusOr objects + // have Status values that have non-ok error codes, the equality of the error + // codes matters. Since we only track the `ok` bits of the Status values, we + // can't make any conclusions about equality when we know that two StatusOr + // objects are engaged or when their Status values contain non-ok error codes. + auto &LhsOkVal = valForOk(locForStatus(LhsStatusOrLoc), Env); + auto &RhsOkVal = valForOk(locForStatus(RhsStatusOrLoc), Env); + auto &res = Env.makeAtomicBoolValue(); + + // res => (lhs == rhs) + Env.assume(A.makeImplies( + res.formula(), A.makeEquals(LhsOkVal.formula(), RhsOkVal.formula()))); + return &res; +} + +static BoolValue *evaluateEquality(const Expr *LhsExpr, const Expr *RhsExpr, + Environment &Env) { + // Check the type of both sides in case an operator== is added that admits + // different types. + if (isStatusOrType(LhsExpr->getType()) && + isStatusOrType(RhsExpr->getType())) { + auto *LhsStatusOrLoc = Env.get<RecordStorageLocation>(*LhsExpr); + if (LhsStatusOrLoc == nullptr) + return nullptr; + auto *RhsStatusOrLoc = Env.get<RecordStorageLocation>(*RhsExpr); + if (RhsStatusOrLoc == nullptr) + return nullptr; + + return evaluateStatusOrEquality(*LhsStatusOrLoc, *RhsStatusOrLoc, Env); + } + if (isStatusType(LhsExpr->getType()) && isStatusType(RhsExpr->getType())) { + auto *LhsStatusLoc = Env.get<RecordStorageLocation>(*LhsExpr); + if (LhsStatusLoc == nullptr) + return nullptr; + + auto *RhsStatusLoc = Env.get<RecordStorageLocation>(*RhsExpr); + if (RhsStatusLoc == nullptr) + return nullptr; + + return evaluateStatusEquality(*LhsStatusLoc, *RhsStatusLoc, Env); + } + return nullptr; +} + +static void transferComparisonOperator(const CXXOperatorCallExpr *Expr, + LatticeTransferState &State, + bool IsNegative) { + auto *LhsAndRhsVal = + evaluateEquality(Expr->getArg(0), Expr->getArg(1), State.Env); + if (LhsAndRhsVal == nullptr) + return; + + if (IsNegative) + State.Env.setValue(*Expr, State.Env.makeNot(*LhsAndRhsVal)); + else + State.Env.setValue(*Expr, *LhsAndRhsVal); +} + +static void transferOkStatusCall(const CallExpr *Expr, + const MatchFinder::MatchResult &, + LatticeTransferState &State) { + auto &OkVal = + initializeStatus(State.Env.getResultObjectLocation(*Expr), State.Env); + State.Env.assume(OkVal.formula()); +} + +static void transferNotOkStatusCall(const CallExpr *Expr, + const MatchFinder::MatchResult &, + LatticeTransferState &State) { + auto &OkVal = + initializeStatus(State.Env.getResultObjectLocation(*Expr), State.Env); + auto &A = State.Env.arena(); + State.Env.assume(A.makeNot(OkVal.formula())); +} + CFGMatchSwitch<LatticeTransferState> buildTransferMatchSwitch(ASTContext &Ctx, CFGMatchSwitchBuilder<LatticeTransferState> Builder) { @@ -325,6 +468,22 @@ buildTransferMatchSwitch(ASTContext &Ctx, transferStatusOkCall) .CaseOfCFGStmt<CXXMemberCallExpr>(isStatusMemberCallWithName("Update"), transferStatusUpdateCall) + .CaseOfCFGStmt<CXXOperatorCallExpr>( + isComparisonOperatorCall("=="), + [](const CXXOperatorCallExpr *Expr, const MatchFinder::MatchResult &, + LatticeTransferState &State) { + transferComparisonOperator(Expr, State, + /*IsNegative=*/false); + }) + .CaseOfCFGStmt<CXXOperatorCallExpr>( + isComparisonOperatorCall("!="), + [](const CXXOperatorCallExpr *Expr, const MatchFinder::MatchResult &, + LatticeTransferState &State) { + transferComparisonOperator(Expr, State, + /*IsNegative=*/true); + }) + .CaseOfCFGStmt<CallExpr>(isOkStatusCall(), transferOkStatusCall) + .CaseOfCFGStmt<CallExpr>(isNotOkStatusCall(), transferNotOkStatusCall) .Build(); } diff --git a/clang/lib/CIR/CodeGen/CIRGenAsm.cpp b/clang/lib/CIR/CodeGen/CIRGenAsm.cpp index 17dffb3..88a7e85 100644 --- a/clang/lib/CIR/CodeGen/CIRGenAsm.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenAsm.cpp @@ -117,9 +117,9 @@ mlir::LogicalResult CIRGenFunction::emitAsmStmt(const AsmStmt &s) { bool hasSideEffect = s.isVolatile() || s.getNumOutputs() == 0; - cir::InlineAsmOp ia = builder.create<cir::InlineAsmOp>( - getLoc(s.getAsmLoc()), resultType, operands, asmString, constraints, - hasSideEffect, inferFlavor(cgm, s), mlir::ArrayAttr()); + cir::InlineAsmOp ia = cir::InlineAsmOp::create( + builder, getLoc(s.getAsmLoc()), resultType, operands, asmString, + constraints, hasSideEffect, inferFlavor(cgm, s), mlir::ArrayAttr()); if (isGCCAsmGoto) { assert(!cir::MissingFeatures::asmGoto()); diff --git a/clang/lib/CIR/CodeGen/CIRGenBuilder.cpp b/clang/lib/CIR/CodeGen/CIRGenBuilder.cpp index 670a431..75355ee 100644 --- a/clang/lib/CIR/CodeGen/CIRGenBuilder.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenBuilder.cpp @@ -22,8 +22,8 @@ mlir::Value CIRGenBuilderTy::maybeBuildArrayDecay(mlir::Location loc, if (arrayTy) { const cir::PointerType flatPtrTy = getPointerTo(arrayTy.getElementType()); - return create<cir::CastOp>(loc, flatPtrTy, cir::CastKind::array_to_ptrdecay, - arrayPtr); + return cir::CastOp::create(*this, loc, flatPtrTy, + cir::CastKind::array_to_ptrdecay, arrayPtr); } assert(arrayPtrTy.getPointee() == eltTy && @@ -40,7 +40,7 @@ mlir::Value CIRGenBuilderTy::getArrayElement(mlir::Location arrayLocBegin, if (shouldDecay) basePtr = maybeBuildArrayDecay(arrayLocBegin, arrayPtr, eltTy); const mlir::Type flatPtrTy = basePtr.getType(); - return create<cir::PtrStrideOp>(arrayLocEnd, flatPtrTy, basePtr, idx); + return cir::PtrStrideOp::create(*this, arrayLocEnd, flatPtrTy, basePtr, idx); } cir::ConstantOp CIRGenBuilderTy::getConstInt(mlir::Location loc, @@ -60,14 +60,14 @@ cir::ConstantOp CIRGenBuilderTy::getConstInt(mlir::Location loc, cir::ConstantOp CIRGenBuilderTy::getConstInt(mlir::Location loc, mlir::Type t, uint64_t c) { assert(mlir::isa<cir::IntType>(t) && "expected cir::IntType"); - return create<cir::ConstantOp>(loc, cir::IntAttr::get(t, c)); + return cir::ConstantOp::create(*this, loc, cir::IntAttr::get(t, c)); } cir::ConstantOp clang::CIRGen::CIRGenBuilderTy::getConstFP(mlir::Location loc, mlir::Type t, llvm::APFloat fpVal) { assert(mlir::isa<cir::FPTypeInterface>(t) && "expected floating point type"); - return create<cir::ConstantOp>(loc, cir::FPAttr::get(t, fpVal)); + return cir::ConstantOp::create(*this, loc, cir::FPAttr::get(t, fpVal)); } void CIRGenBuilderTy::computeGlobalViewIndicesFromFlatOffset( diff --git a/clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp b/clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp index 798e9d9..27c4d11 100644 --- a/clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp @@ -46,9 +46,9 @@ static RValue emitBuiltinBitOp(CIRGenFunction &cgf, const CallExpr *e, Op op; if constexpr (std::is_same_v<Op, cir::BitClzOp> || std::is_same_v<Op, cir::BitCtzOp>) - op = builder.create<Op>(cgf.getLoc(e->getSourceRange()), arg, poisonZero); + op = Op::create(builder, cgf.getLoc(e->getSourceRange()), arg, poisonZero); else - op = builder.create<Op>(cgf.getLoc(e->getSourceRange()), arg); + op = Op::create(builder, cgf.getLoc(e->getSourceRange()), arg); mlir::Value result = op.getResult(); mlir::Type exprTy = cgf.convertType(e->getType()); @@ -67,8 +67,8 @@ RValue CIRGenFunction::emitRotate(const CallExpr *e, bool isRotateLeft) { // to the type of input when necessary. assert(!cir::MissingFeatures::msvcBuiltins()); - auto r = builder.create<cir::RotateOp>(getLoc(e->getSourceRange()), input, - amount, isRotateLeft); + auto r = cir::RotateOp::create(builder, getLoc(e->getSourceRange()), input, + amount, isRotateLeft); return RValue::get(r); } @@ -227,14 +227,14 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl &gd, unsigned builtinID, return RValue::get(nullptr); mlir::Value argValue = emitCheckedArgForAssume(e->getArg(0)); - builder.create<cir::AssumeOp>(loc, argValue); + cir::AssumeOp::create(builder, loc, argValue); return RValue::get(nullptr); } case Builtin::BI__builtin_assume_separate_storage: { mlir::Value value0 = emitScalarExpr(e->getArg(0)); mlir::Value value1 = emitScalarExpr(e->getArg(1)); - builder.create<cir::AssumeSepStorageOp>(loc, value0, value1); + cir::AssumeSepStorageOp::create(builder, loc, value0, value1); return RValue::get(nullptr); } @@ -363,8 +363,8 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl &gd, unsigned builtinID, probability); } - auto result = builder.create<cir::ExpectOp>( - loc, argValue.getType(), argValue, expectedValue, probAttr); + auto result = cir::ExpectOp::create(builder, loc, argValue.getType(), + argValue, expectedValue, probAttr); return RValue::get(result); } @@ -375,7 +375,7 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl &gd, unsigned builtinID, case Builtin::BI_byteswap_ulong: case Builtin::BI_byteswap_uint64: { mlir::Value arg = emitScalarExpr(e->getArg(0)); - return RValue::get(builder.create<cir::ByteSwapOp>(loc, arg)); + return RValue::get(cir::ByteSwapOp::create(builder, loc, arg)); } case Builtin::BI__builtin_bitreverse8: @@ -383,7 +383,7 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl &gd, unsigned builtinID, case Builtin::BI__builtin_bitreverse32: case Builtin::BI__builtin_bitreverse64: { mlir::Value arg = emitScalarExpr(e->getArg(0)); - return RValue::get(builder.create<cir::BitReverseOp>(loc, arg)); + return RValue::get(cir::BitReverseOp::create(builder, loc, arg)); } case Builtin::BI__builtin_rotateleft8: diff --git a/clang/lib/CIR/CodeGen/CIRGenCXX.cpp b/clang/lib/CIR/CodeGen/CIRGenCXX.cpp index 171ce1c..a3e2081 100644 --- a/clang/lib/CIR/CodeGen/CIRGenCXX.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenCXX.cpp @@ -53,7 +53,7 @@ static void emitDeclInit(CIRGenFunction &cgf, const VarDecl *varDecl, cgf.emitScalarInit(init, cgf.getLoc(varDecl->getLocation()), lv, false); break; case cir::TEK_Complex: - cgf.cgm.errorNYI(varDecl->getSourceRange(), "complex global initializer"); + cgf.emitComplexExprIntoLValue(init, lv, /*isInit=*/true); break; case cir::TEK_Aggregate: assert(!cir::MissingFeatures::aggValueSlotGC()); @@ -151,7 +151,7 @@ static void emitDeclDestroy(CIRGenFunction &cgf, const VarDecl *vd, // Don't confuse lexical cleanup. builder.clearInsertionPoint(); } else { - builder.create<cir::YieldOp>(addr.getLoc()); + cir::YieldOp::create(builder, addr.getLoc()); } } diff --git a/clang/lib/CIR/CodeGen/CIRGenCXXABI.cpp b/clang/lib/CIR/CodeGen/CIRGenCXXABI.cpp index eef3739..aa0182e 100644 --- a/clang/lib/CIR/CodeGen/CIRGenCXXABI.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenCXXABI.cpp @@ -70,8 +70,8 @@ cir::GlobalLinkageKind CIRGenCXXABI::getCXXDestructorLinkage( mlir::Value CIRGenCXXABI::loadIncomingCXXThis(CIRGenFunction &cgf) { ImplicitParamDecl *vd = getThisDecl(cgf); Address addr = cgf.getAddrOfLocalVar(vd); - return cgf.getBuilder().create<cir::LoadOp>( - cgf.getLoc(vd->getLocation()), addr.getElementType(), addr.getPointer()); + return cir::LoadOp::create(cgf.getBuilder(), cgf.getLoc(vd->getLocation()), + addr.getElementType(), addr.getPointer()); } void CIRGenCXXABI::setCXXABIThisValue(CIRGenFunction &cgf, diff --git a/clang/lib/CIR/CodeGen/CIRGenCall.cpp b/clang/lib/CIR/CodeGen/CIRGenCall.cpp index 61072f0..88aef89 100644 --- a/clang/lib/CIR/CodeGen/CIRGenCall.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenCall.cpp @@ -690,6 +690,22 @@ void CallArg::copyInto(CIRGenFunction &cgf, Address addr, isUsed = true; } +mlir::Value CIRGenFunction::emitRuntimeCall(mlir::Location loc, + cir::FuncOp callee, + ArrayRef<mlir::Value> args) { + // TODO(cir): set the calling convention to this runtime call. + assert(!cir::MissingFeatures::opFuncCallingConv()); + + cir::CallOp call = builder.createCallOp(loc, callee, args); + assert(call->getNumResults() <= 1 && + "runtime functions have at most 1 result"); + + if (call->getNumResults() == 0) + return nullptr; + + return call->getResult(0); +} + void CIRGenFunction::emitCallArg(CallArgList &args, const clang::Expr *e, clang::QualType argType) { assert(argType->isReferenceType() == e->isGLValue() && diff --git a/clang/lib/CIR/CodeGen/CIRGenClass.cpp b/clang/lib/CIR/CodeGen/CIRGenClass.cpp index 89f4926..5046e09 100644 --- a/clang/lib/CIR/CodeGen/CIRGenClass.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenClass.cpp @@ -725,8 +725,9 @@ void CIRGenFunction::emitCXXAggrConstructorCall( // Emit the constructor call that will execute for every array element. mlir::Value arrayOp = builder.createPtrBitcast(arrayBase.getPointer(), arrayTy); - builder.create<cir::ArrayCtor>( - *currSrcLoc, arrayOp, [&](mlir::OpBuilder &b, mlir::Location loc) { + cir::ArrayCtor::create( + builder, *currSrcLoc, arrayOp, + [&](mlir::OpBuilder &b, mlir::Location loc) { mlir::BlockArgument arg = b.getInsertionBlock()->addArgument(ptrToElmType, loc); Address curAddr = Address(arg, elementType, eltAlignment); @@ -738,7 +739,7 @@ void CIRGenFunction::emitCXXAggrConstructorCall( emitCXXConstructorCall(ctor, Ctor_Complete, /*ForVirtualBase=*/false, /*Delegating=*/false, currAVS, e); - builder.create<cir::YieldOp>(loc); + cir::YieldOp::create(builder, loc); }); } } diff --git a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp index 52021fc..9df88ad 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp @@ -251,8 +251,8 @@ void CIRGenFunction::emitStoreThroughLValue(RValue src, LValue dst, const mlir::Location loc = dst.getVectorPointer().getLoc(); const mlir::Value vector = builder.createLoad(loc, dst.getVectorAddress()); - const mlir::Value newVector = builder.create<cir::VecInsertOp>( - loc, vector, src.getValue(), dst.getVectorIdx()); + const mlir::Value newVector = cir::VecInsertOp::create( + builder, loc, vector, src.getValue(), dst.getVectorIdx()); builder.createStore(loc, newVector, dst.getVectorAddress()); return; } @@ -615,8 +615,8 @@ RValue CIRGenFunction::emitLoadOfLValue(LValue lv, SourceLocation loc) { if (lv.isVectorElt()) { const mlir::Value load = builder.createLoad(getLoc(loc), lv.getVectorAddress()); - return RValue::get(builder.create<cir::VecExtractOp>(getLoc(loc), load, - lv.getVectorIdx())); + return RValue::get(cir::VecExtractOp::create(builder, getLoc(loc), load, + lv.getVectorIdx())); } cgm.errorNYI(loc, "emitLoadOfLValue"); @@ -671,8 +671,8 @@ static LValue emitFunctionDeclLValue(CIRGenFunction &cgf, const Expr *e, mlir::Type fnTy = funcOp.getFunctionType(); mlir::Type ptrTy = cir::PointerType::get(fnTy); - mlir::Value addr = cgf.getBuilder().create<cir::GetGlobalOp>( - loc, ptrTy, funcOp.getSymName()); + mlir::Value addr = cir::GetGlobalOp::create(cgf.getBuilder(), loc, ptrTy, + funcOp.getSymName()); if (funcOp.getFunctionType() != cgf.convertType(fd->getType())) { fnTy = cgf.convertType(fd->getType()); @@ -1685,8 +1685,8 @@ CIRGenCallee CIRGenFunction::emitDirectCallee(const GlobalDecl &gd) { mlir::OpBuilder::InsertionGuard guard(builder); builder.setInsertionPointToStart(cgm.getModule().getBody()); - clone = builder.create<cir::FuncOp>(calleeFunc.getLoc(), fdInlineName, - calleeFunc.getFunctionType()); + clone = cir::FuncOp::create(builder, calleeFunc.getLoc(), fdInlineName, + calleeFunc.getFunctionType()); clone.setLinkageAttr(cir::GlobalLinkageKindAttr::get( &cgm.getMLIRContext(), cir::GlobalLinkageKind::InternalLinkage)); clone.setSymVisibility("private"); @@ -1778,8 +1778,8 @@ RValue CIRGenFunction::emitCall(clang::QualType calleeTy, mlir::Operation *fn = callee.getFunctionPointer(); mlir::Value addr; if (auto funcOp = mlir::dyn_cast<cir::FuncOp>(fn)) { - addr = builder.create<cir::GetGlobalOp>( - getLoc(e->getSourceRange()), + addr = cir::GetGlobalOp::create( + builder, getLoc(e->getSourceRange()), cir::PointerType::get(funcOp.getFunctionType()), funcOp.getSymName()); } else { addr = fn->getResult(0); @@ -1820,10 +1820,12 @@ CIRGenCallee CIRGenFunction::emitCallee(const clang::Expr *e) { // Resolve direct calls. const auto *funcDecl = cast<FunctionDecl>(declRef->getDecl()); return emitDirectCallee(funcDecl); - } else if (isa<MemberExpr>(e)) { - cgm.errorNYI(e->getSourceRange(), - "emitCallee: call to member function is NYI"); - return {}; + } else if (auto me = dyn_cast<MemberExpr>(e)) { + if (const auto *fd = dyn_cast<FunctionDecl>(me->getMemberDecl())) { + emitIgnoredExpr(me->getBase()); + return emitDirectCallee(fd); + } + // Else fall through to the indirect reference handling below. } else if (auto *pde = dyn_cast<CXXPseudoDestructorExpr>(e)) { return CIRGenCallee::forPseudoDestructor(pde); } @@ -1996,9 +1998,9 @@ cir::IfOp CIRGenFunction::emitIfOnBoolExpr( // Emit the code with the fully general case. mlir::Value condV = emitOpOnBoolExpr(loc, cond); - return builder.create<cir::IfOp>(loc, condV, elseLoc.has_value(), - /*thenBuilder=*/thenBuilder, - /*elseBuilder=*/elseBuilder); + return cir::IfOp::create(builder, loc, condV, elseLoc.has_value(), + /*thenBuilder=*/thenBuilder, + /*elseBuilder=*/elseBuilder); } /// TODO(cir): see EmitBranchOnBoolExpr for extra ideas). @@ -2020,18 +2022,17 @@ mlir::Value CIRGenFunction::emitOpOnBoolExpr(mlir::Location loc, mlir::Value condV = emitOpOnBoolExpr(loc, condOp->getCond()); mlir::Value ternaryOpRes = - builder - .create<cir::TernaryOp>( - loc, condV, /*thenBuilder=*/ - [this, trueExpr](mlir::OpBuilder &b, mlir::Location loc) { - mlir::Value lhs = emitScalarExpr(trueExpr); - b.create<cir::YieldOp>(loc, lhs); - }, - /*elseBuilder=*/ - [this, falseExpr](mlir::OpBuilder &b, mlir::Location loc) { - mlir::Value rhs = emitScalarExpr(falseExpr); - b.create<cir::YieldOp>(loc, rhs); - }) + cir::TernaryOp::create( + builder, loc, condV, /*thenBuilder=*/ + [this, trueExpr](mlir::OpBuilder &b, mlir::Location loc) { + mlir::Value lhs = emitScalarExpr(trueExpr); + cir::YieldOp::create(b, loc, lhs); + }, + /*elseBuilder=*/ + [this, falseExpr](mlir::OpBuilder &b, mlir::Location loc) { + mlir::Value rhs = emitScalarExpr(falseExpr); + cir::YieldOp::create(b, loc, rhs); + }) .getResult(); return emitScalarConversion(ternaryOpRes, condOp->getType(), @@ -2211,8 +2212,8 @@ Address CIRGenFunction::emitLoadOfReference(LValue refLVal, mlir::Location loc, cgm.errorNYI(loc, "load of volatile reference"); cir::LoadOp load = - builder.create<cir::LoadOp>(loc, refLVal.getAddress().getElementType(), - refLVal.getAddress().getPointer()); + cir::LoadOp::create(builder, loc, refLVal.getAddress().getElementType(), + refLVal.getAddress().getPointer()); assert(!cir::MissingFeatures::opTBAA()); diff --git a/clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp b/clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp index d8f4943..047f359 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp @@ -390,7 +390,7 @@ ComplexExprEmitter::VisitImaginaryLiteral(const ImaginaryLiteral *il) { } auto complexAttr = cir::ConstComplexAttr::get(realValueAttr, imagValueAttr); - return builder.create<cir::ConstantOp>(loc, complexAttr); + return cir::ConstantOp::create(builder, loc, complexAttr); } mlir::Value ComplexExprEmitter::VisitCallExpr(const CallExpr *e) { @@ -601,7 +601,7 @@ mlir::Value ComplexExprEmitter::emitBinAdd(const BinOpInfo &op) { if (mlir::isa<cir::ComplexType>(op.lhs.getType()) && mlir::isa<cir::ComplexType>(op.rhs.getType())) - return builder.create<cir::ComplexAddOp>(op.loc, op.lhs, op.rhs); + return cir::ComplexAddOp::create(builder, op.loc, op.lhs, op.rhs); if (mlir::isa<cir::ComplexType>(op.lhs.getType())) { mlir::Value real = builder.createComplexReal(op.loc, op.lhs); @@ -623,7 +623,7 @@ mlir::Value ComplexExprEmitter::emitBinSub(const BinOpInfo &op) { if (mlir::isa<cir::ComplexType>(op.lhs.getType()) && mlir::isa<cir::ComplexType>(op.rhs.getType())) - return builder.create<cir::ComplexSubOp>(op.loc, op.lhs, op.rhs); + return cir::ComplexSubOp::create(builder, op.loc, op.lhs, op.rhs); if (mlir::isa<cir::ComplexType>(op.lhs.getType())) { mlir::Value real = builder.createComplexReal(op.loc, op.lhs); @@ -664,7 +664,8 @@ mlir::Value ComplexExprEmitter::emitBinMul(const BinOpInfo &op) { mlir::isa<cir::ComplexType>(op.rhs.getType())) { cir::ComplexRangeKind rangeKind = getComplexRangeAttr(op.fpFeatures.getComplexRange()); - return builder.create<cir::ComplexMulOp>(op.loc, op.lhs, op.rhs, rangeKind); + return cir::ComplexMulOp::create(builder, op.loc, op.lhs, op.rhs, + rangeKind); } if (mlir::isa<cir::ComplexType>(op.lhs.getType())) { @@ -968,23 +969,22 @@ mlir::Value ComplexExprEmitter::VisitAbstractConditionalOperator( Expr *cond = e->getCond()->IgnoreParens(); mlir::Value condValue = cgf.evaluateExprAsBool(cond); - return builder - .create<cir::TernaryOp>( - loc, condValue, - /*thenBuilder=*/ - [&](mlir::OpBuilder &b, mlir::Location loc) { - eval.beginEvaluation(); - mlir::Value trueValue = Visit(e->getTrueExpr()); - b.create<cir::YieldOp>(loc, trueValue); - eval.endEvaluation(); - }, - /*elseBuilder=*/ - [&](mlir::OpBuilder &b, mlir::Location loc) { - eval.beginEvaluation(); - mlir::Value falseValue = Visit(e->getFalseExpr()); - b.create<cir::YieldOp>(loc, falseValue); - eval.endEvaluation(); - }) + return cir::TernaryOp::create( + builder, loc, condValue, + /*thenBuilder=*/ + [&](mlir::OpBuilder &b, mlir::Location loc) { + eval.beginEvaluation(); + mlir::Value trueValue = Visit(e->getTrueExpr()); + cir::YieldOp::create(b, loc, trueValue); + eval.endEvaluation(); + }, + /*elseBuilder=*/ + [&](mlir::OpBuilder &b, mlir::Location loc) { + eval.beginEvaluation(); + mlir::Value falseValue = Visit(e->getFalseExpr()); + cir::YieldOp::create(b, loc, falseValue); + eval.endEvaluation(); + }) .getResult(); } diff --git a/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp b/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp index 800262a..7de3dd0 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp @@ -179,8 +179,23 @@ bool ConstantAggregateBuilder::add(mlir::TypedAttr typedAttr, CharUnits offset, } // Uncommon case: constant overlaps what we've already created. - cgm.errorNYI("overlapping constants"); - return false; + std::optional<size_t> firstElemToReplace = splitAt(offset); + if (!firstElemToReplace) + return false; + + CharUnits cSize = getSize(typedAttr); + std::optional<size_t> lastElemToReplace = splitAt(offset + cSize); + if (!lastElemToReplace) + return false; + + assert((firstElemToReplace == lastElemToReplace || allowOverwrite) && + "unexpectedly overwriting field"); + + Element newElt(typedAttr, offset); + replace(elements, *firstElemToReplace, *lastElemToReplace, {newElt}); + size = std::max(size, offset + cSize); + naturalLayout = false; + return true; } bool ConstantAggregateBuilder::addBits(llvm::APInt bits, uint64_t offsetInBits, @@ -612,10 +627,7 @@ bool ConstRecordBuilder::applyZeroInitPadding(const ASTRecordLayout &layout, } bool ConstRecordBuilder::build(InitListExpr *ile, bool allowOverwrite) { - RecordDecl *rd = ile->getType() - ->castAs<clang::RecordType>() - ->getDecl() - ->getDefinitionOrSelf(); + RecordDecl *rd = ile->getType()->castAsRecordDecl(); const ASTRecordLayout &layout = cgm.getASTContext().getASTRecordLayout(rd); // Bail out if we have base classes. We could support these, but they only @@ -671,17 +683,14 @@ bool ConstRecordBuilder::build(InitListExpr *ile, bool allowOverwrite) { return false; } - mlir::TypedAttr eltInit; - if (init) - eltInit = mlir::cast<mlir::TypedAttr>( - emitter.tryEmitPrivateForMemory(init, field->getType())); - else - eltInit = mlir::cast<mlir::TypedAttr>(emitter.emitNullForMemory( - cgm.getLoc(ile->getSourceRange()), field->getType())); - - if (!eltInit) + mlir::Attribute eltInitAttr = + init ? emitter.tryEmitPrivateForMemory(init, field->getType()) + : emitter.emitNullForMemory(cgm.getLoc(ile->getSourceRange()), + field->getType()); + if (!eltInitAttr) return false; + mlir::TypedAttr eltInit = mlir::cast<mlir::TypedAttr>(eltInitAttr); if (!field->isBitField()) { // Handle non-bitfield members. if (!appendField(field, layout.getFieldOffset(index), eltInit, diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp index 33eb748..db6878d 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp @@ -164,22 +164,22 @@ public: mlir::Value VisitIntegerLiteral(const IntegerLiteral *e) { mlir::Type type = cgf.convertType(e->getType()); - return builder.create<cir::ConstantOp>( - cgf.getLoc(e->getExprLoc()), cir::IntAttr::get(type, e->getValue())); + return cir::ConstantOp::create(builder, cgf.getLoc(e->getExprLoc()), + cir::IntAttr::get(type, e->getValue())); } mlir::Value VisitFloatingLiteral(const FloatingLiteral *e) { mlir::Type type = cgf.convertType(e->getType()); assert(mlir::isa<cir::FPTypeInterface>(type) && "expect floating-point type"); - return builder.create<cir::ConstantOp>( - cgf.getLoc(e->getExprLoc()), cir::FPAttr::get(type, e->getValue())); + return cir::ConstantOp::create(builder, cgf.getLoc(e->getExprLoc()), + cir::FPAttr::get(type, e->getValue())); } mlir::Value VisitCharacterLiteral(const CharacterLiteral *e) { mlir::Type ty = cgf.convertType(e->getType()); auto init = cir::IntAttr::get(ty, e->getValue()); - return builder.create<cir::ConstantOp>(cgf.getLoc(e->getExprLoc()), init); + return cir::ConstantOp::create(builder, cgf.getLoc(e->getExprLoc()), init); } mlir::Value VisitCXXBoolLiteralExpr(const CXXBoolLiteralExpr *e) { @@ -227,7 +227,7 @@ public: const mlir::Location loc = cgf.getLoc(e->getSourceRange()); const mlir::Value vecValue = Visit(e->getBase()); const mlir::Value indexValue = Visit(e->getIdx()); - return cgf.builder.create<cir::VecExtractOp>(loc, vecValue, indexValue); + return cir::VecExtractOp::create(cgf.builder, loc, vecValue, indexValue); } // Just load the lvalue formed by the subscript expression. return emitLoadOfLValue(e); @@ -238,8 +238,8 @@ public: // The undocumented form of __builtin_shufflevector. mlir::Value inputVec = Visit(e->getExpr(0)); mlir::Value indexVec = Visit(e->getExpr(1)); - return cgf.builder.create<cir::VecShuffleDynamicOp>( - cgf.getLoc(e->getSourceRange()), inputVec, indexVec); + return cir::VecShuffleDynamicOp::create( + cgf.builder, cgf.getLoc(e->getSourceRange()), inputVec, indexVec); } mlir::Value vec1 = Visit(e->getExpr(0)); @@ -257,9 +257,10 @@ public: .getSExtValue())); } - return cgf.builder.create<cir::VecShuffleOp>( - cgf.getLoc(e->getSourceRange()), cgf.convertType(e->getType()), vec1, - vec2, cgf.builder.getArrayAttr(indices)); + return cir::VecShuffleOp::create(cgf.builder, + cgf.getLoc(e->getSourceRange()), + cgf.convertType(e->getType()), vec1, vec2, + cgf.builder.getArrayAttr(indices)); } mlir::Value VisitConvertVectorExpr(ConvertVectorExpr *e) { @@ -296,8 +297,8 @@ public: mlir::Value emitFloatToBoolConversion(mlir::Value src, mlir::Location loc) { cir::BoolType boolTy = builder.getBoolTy(); - return builder.create<cir::CastOp>(loc, boolTy, - cir::CastKind::float_to_bool, src); + return cir::CastOp::create(builder, loc, boolTy, + cir::CastKind::float_to_bool, src); } mlir::Value emitIntToBoolConversion(mlir::Value srcVal, mlir::Location loc) { @@ -307,8 +308,8 @@ public: // TODO: optimize this common case here or leave it for later // CIR passes? cir::BoolType boolTy = builder.getBoolTy(); - return builder.create<cir::CastOp>(loc, boolTy, cir::CastKind::int_to_bool, - srcVal); + return cir::CastOp::create(builder, loc, boolTy, cir::CastKind::int_to_bool, + srcVal); } /// Convert the specified expression value to a boolean (!cir.bool) truth @@ -411,7 +412,8 @@ public: } assert(castKind.has_value() && "Internal error: CastKind not set."); - return builder.create<cir::CastOp>(src.getLoc(), fullDstTy, *castKind, src); + return cir::CastOp::create(builder, src.getLoc(), fullDstTy, *castKind, + src); } mlir::Value @@ -658,9 +660,9 @@ public: mlir::Value emitUnaryOp(const UnaryOperator *e, cir::UnaryOpKind kind, mlir::Value input, bool nsw = false) { - return builder.create<cir::UnaryOp>( - cgf.getLoc(e->getSourceRange().getBegin()), input.getType(), kind, - input, nsw); + return cir::UnaryOp::create(builder, + cgf.getLoc(e->getSourceRange().getBegin()), + input.getType(), kind, input, nsw); } mlir::Value VisitUnaryNot(const UnaryOperator *e) { @@ -967,9 +969,9 @@ public: } else { // Other kinds of vectors. Element-wise comparison returning // a vector. - result = builder.create<cir::VecCmpOp>( - cgf.getLoc(boInfo.loc), cgf.convertType(boInfo.fullType), kind, - boInfo.lhs, boInfo.rhs); + result = cir::VecCmpOp::create(builder, cgf.getLoc(boInfo.loc), + cgf.convertType(boInfo.fullType), kind, + boInfo.lhs, boInfo.rhs); } } else if (boInfo.isFixedPointOp()) { assert(!cir::MissingFeatures::fixedPointType()); @@ -991,7 +993,7 @@ public: assert(e->getOpcode() == BO_EQ || e->getOpcode() == BO_NE); BinOpInfo boInfo = emitBinOps(e); - result = builder.create<cir::CmpOp>(loc, kind, boInfo.lhs, boInfo.rhs); + result = cir::CmpOp::create(builder, loc, kind, boInfo.lhs, boInfo.rhs); } return emitScalarConversion(result, cgf.getContext().BoolTy, e->getType(), @@ -1093,8 +1095,8 @@ public: CIRGenFunction::ConditionalEvaluation eval(cgf); mlir::Value lhsCondV = cgf.evaluateExprAsBool(e->getLHS()); - auto resOp = builder.create<cir::TernaryOp>( - loc, lhsCondV, /*trueBuilder=*/ + auto resOp = cir::TernaryOp::create( + builder, loc, lhsCondV, /*trueBuilder=*/ [&](mlir::OpBuilder &b, mlir::Location loc) { CIRGenFunction::LexicalScope lexScope{cgf, loc, b.getInsertionBlock()}; @@ -1139,8 +1141,8 @@ public: CIRGenFunction::ConditionalEvaluation eval(cgf); mlir::Value lhsCondV = cgf.evaluateExprAsBool(e->getLHS()); - auto resOp = builder.create<cir::TernaryOp>( - loc, lhsCondV, /*trueBuilder=*/ + auto resOp = cir::TernaryOp::create( + builder, loc, lhsCondV, /*trueBuilder=*/ [&](mlir::OpBuilder &b, mlir::Location loc) { CIRGenFunction::LexicalScope lexScope{cgf, loc, b.getInsertionBlock()}; @@ -1566,8 +1568,9 @@ static mlir::Value emitPointerArithmetic(CIRGenFunction &cgf, } assert(!cir::MissingFeatures::sanitizers()); - return cgf.getBuilder().create<cir::PtrStrideOp>( - cgf.getLoc(op.e->getExprLoc()), pointer.getType(), pointer, index); + return cir::PtrStrideOp::create(cgf.getBuilder(), + cgf.getLoc(op.e->getExprLoc()), + pointer.getType(), pointer, index); } mlir::Value ScalarExprEmitter::emitMul(const BinOpInfo &ops) { @@ -1609,19 +1612,19 @@ mlir::Value ScalarExprEmitter::emitMul(const BinOpInfo &ops) { return nullptr; } - return builder.create<cir::BinOp>(cgf.getLoc(ops.loc), - cgf.convertType(ops.fullType), - cir::BinOpKind::Mul, ops.lhs, ops.rhs); + return cir::BinOp::create(builder, cgf.getLoc(ops.loc), + cgf.convertType(ops.fullType), cir::BinOpKind::Mul, + ops.lhs, ops.rhs); } mlir::Value ScalarExprEmitter::emitDiv(const BinOpInfo &ops) { - return builder.create<cir::BinOp>(cgf.getLoc(ops.loc), - cgf.convertType(ops.fullType), - cir::BinOpKind::Div, ops.lhs, ops.rhs); + return cir::BinOp::create(builder, cgf.getLoc(ops.loc), + cgf.convertType(ops.fullType), cir::BinOpKind::Div, + ops.lhs, ops.rhs); } mlir::Value ScalarExprEmitter::emitRem(const BinOpInfo &ops) { - return builder.create<cir::BinOp>(cgf.getLoc(ops.loc), - cgf.convertType(ops.fullType), - cir::BinOpKind::Rem, ops.lhs, ops.rhs); + return cir::BinOp::create(builder, cgf.getLoc(ops.loc), + cgf.convertType(ops.fullType), cir::BinOpKind::Rem, + ops.lhs, ops.rhs); } mlir::Value ScalarExprEmitter::emitAdd(const BinOpInfo &ops) { @@ -1668,8 +1671,8 @@ mlir::Value ScalarExprEmitter::emitAdd(const BinOpInfo &ops) { return {}; } - return builder.create<cir::BinOp>(loc, cgf.convertType(ops.fullType), - cir::BinOpKind::Add, ops.lhs, ops.rhs); + return cir::BinOp::create(builder, loc, cgf.convertType(ops.fullType), + cir::BinOpKind::Add, ops.lhs, ops.rhs); } mlir::Value ScalarExprEmitter::emitSub(const BinOpInfo &ops) { @@ -1716,9 +1719,9 @@ mlir::Value ScalarExprEmitter::emitSub(const BinOpInfo &ops) { return {}; } - return builder.create<cir::BinOp>(cgf.getLoc(ops.loc), - cgf.convertType(ops.fullType), - cir::BinOpKind::Sub, ops.lhs, ops.rhs); + return cir::BinOp::create(builder, cgf.getLoc(ops.loc), + cgf.convertType(ops.fullType), + cir::BinOpKind::Sub, ops.lhs, ops.rhs); } // If the RHS is not a pointer, then we have normal pointer @@ -1796,19 +1799,19 @@ mlir::Value ScalarExprEmitter::emitShr(const BinOpInfo &ops) { } mlir::Value ScalarExprEmitter::emitAnd(const BinOpInfo &ops) { - return builder.create<cir::BinOp>(cgf.getLoc(ops.loc), - cgf.convertType(ops.fullType), - cir::BinOpKind::And, ops.lhs, ops.rhs); + return cir::BinOp::create(builder, cgf.getLoc(ops.loc), + cgf.convertType(ops.fullType), cir::BinOpKind::And, + ops.lhs, ops.rhs); } mlir::Value ScalarExprEmitter::emitXor(const BinOpInfo &ops) { - return builder.create<cir::BinOp>(cgf.getLoc(ops.loc), - cgf.convertType(ops.fullType), - cir::BinOpKind::Xor, ops.lhs, ops.rhs); + return cir::BinOp::create(builder, cgf.getLoc(ops.loc), + cgf.convertType(ops.fullType), cir::BinOpKind::Xor, + ops.lhs, ops.rhs); } mlir::Value ScalarExprEmitter::emitOr(const BinOpInfo &ops) { - return builder.create<cir::BinOp>(cgf.getLoc(ops.loc), - cgf.convertType(ops.fullType), - cir::BinOpKind::Or, ops.lhs, ops.rhs); + return cir::BinOp::create(builder, cgf.getLoc(ops.loc), + cgf.convertType(ops.fullType), cir::BinOpKind::Or, + ops.lhs, ops.rhs); } // Emit code for an explicit or implicit cast. Implicit @@ -2011,9 +2014,9 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *ce) { case CK_VectorSplat: { // Create a vector object and fill all elements with the same scalar value. assert(destTy->isVectorType() && "CK_VectorSplat to non-vector type"); - return builder.create<cir::VecSplatOp>( - cgf.getLoc(subExpr->getSourceRange()), cgf.convertType(destTy), - Visit(subExpr)); + return cir::VecSplatOp::create(builder, + cgf.getLoc(subExpr->getSourceRange()), + cgf.convertType(destTy), Visit(subExpr)); } case CK_FunctionToPointerDecay: return cgf.emitLValue(subExpr).getPointer(); @@ -2073,8 +2076,9 @@ mlir::Value ScalarExprEmitter::VisitInitListExpr(InitListExpr *e) { vectorType.getSize() - numInitElements, zeroValue); } - return cgf.getBuilder().create<cir::VecCreateOp>( - cgf.getLoc(e->getSourceRange()), vectorType, elements); + return cir::VecCreateOp::create(cgf.getBuilder(), + cgf.getLoc(e->getSourceRange()), vectorType, + elements); } // C++11 value-initialization for the scalar. @@ -2310,8 +2314,8 @@ mlir::Value ScalarExprEmitter::VisitAbstractConditionalOperator( mlir::Value condValue = Visit(condExpr); mlir::Value lhsValue = Visit(lhsExpr); mlir::Value rhsValue = Visit(rhsExpr); - return builder.create<cir::VecTernaryOp>(loc, condValue, lhsValue, - rhsValue); + return cir::VecTernaryOp::create(builder, loc, condValue, lhsValue, + rhsValue); } // If this is a really simple expression (like x ? 4 : 5), emit this as a @@ -2354,7 +2358,7 @@ mlir::Value ScalarExprEmitter::VisitAbstractConditionalOperator( if (branch) { yieldTy = branch.getType(); - b.create<cir::YieldOp>(loc, branch); + cir::YieldOp::create(b, loc, branch); } else { // If LHS or RHS is a throw or void expression we need to patch // arms as to properly match yield types. @@ -2362,17 +2366,16 @@ mlir::Value ScalarExprEmitter::VisitAbstractConditionalOperator( } }; - mlir::Value result = builder - .create<cir::TernaryOp>( - loc, condV, - /*trueBuilder=*/ - [&](mlir::OpBuilder &b, mlir::Location loc) { - emitBranch(b, loc, lhsExpr); - }, - /*falseBuilder=*/ - [&](mlir::OpBuilder &b, mlir::Location loc) { - emitBranch(b, loc, rhsExpr); - }) + mlir::Value result = cir::TernaryOp::create( + builder, loc, condV, + /*trueBuilder=*/ + [&](mlir::OpBuilder &b, mlir::Location loc) { + emitBranch(b, loc, lhsExpr); + }, + /*falseBuilder=*/ + [&](mlir::OpBuilder &b, mlir::Location loc) { + emitBranch(b, loc, rhsExpr); + }) .getResult(); if (!insertPoints.empty()) { @@ -2387,10 +2390,10 @@ mlir::Value ScalarExprEmitter::VisitAbstractConditionalOperator( // Block does not return: build empty yield. if (mlir::isa<cir::VoidType>(yieldTy)) { - builder.create<cir::YieldOp>(loc); + cir::YieldOp::create(builder, loc); } else { // Block returns: set null yield value. mlir::Value op0 = builder.getNullValue(yieldTy, loc); - builder.create<cir::YieldOp>(loc, op0); + cir::YieldOp::create(builder, loc, op0); } } } diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.cpp b/clang/lib/CIR/CodeGen/CIRGenFunction.cpp index d3c0d9f..58feb36 100644 --- a/clang/lib/CIR/CodeGen/CIRGenFunction.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenFunction.cpp @@ -264,11 +264,11 @@ void CIRGenFunction::LexicalScope::cleanup() { // If we now have one after `applyCleanup`, hook it up properly. if (!cleanupBlock && localScope->getCleanupBlock(builder)) { cleanupBlock = localScope->getCleanupBlock(builder); - builder.create<cir::BrOp>(insPt->back().getLoc(), cleanupBlock); + cir::BrOp::create(builder, insPt->back().getLoc(), cleanupBlock); if (!cleanupBlock->mightHaveTerminator()) { mlir::OpBuilder::InsertionGuard guard(builder); builder.setInsertionPointToEnd(cleanupBlock); - builder.create<cir::YieldOp>(localScope->endLoc); + cir::YieldOp::create(builder, localScope->endLoc); } } @@ -286,7 +286,7 @@ void CIRGenFunction::LexicalScope::cleanup() { } } - builder.create<cir::BrOp>(*returnLoc, returnBlock); + cir::BrOp::create(builder, *returnLoc, returnBlock); return; } } @@ -298,8 +298,8 @@ void CIRGenFunction::LexicalScope::cleanup() { // Ternary ops have to deal with matching arms for yielding types // and do return a value, it must do its own cir.yield insertion. if (!localScope->isTernary() && !insPt->mightHaveTerminator()) { - !retVal ? builder.create<cir::YieldOp>(localScope->endLoc) - : builder.create<cir::YieldOp>(localScope->endLoc, retVal); + !retVal ? cir::YieldOp::create(builder, localScope->endLoc) + : cir::YieldOp::create(builder, localScope->endLoc, retVal); } }; @@ -331,7 +331,7 @@ void CIRGenFunction::LexicalScope::cleanup() { // If there's a cleanup block, branch to it, nothing else to do. if (cleanupBlock) { - builder.create<cir::BrOp>(curBlock->back().getLoc(), cleanupBlock); + cir::BrOp::create(builder, curBlock->back().getLoc(), cleanupBlock); return; } @@ -349,12 +349,12 @@ cir::ReturnOp CIRGenFunction::LexicalScope::emitReturn(mlir::Location loc) { assert(fn && "emitReturn from non-function"); if (!fn.getFunctionType().hasVoidReturn()) { // Load the value from `__retval` and return it via the `cir.return` op. - auto value = builder.create<cir::LoadOp>( - loc, fn.getFunctionType().getReturnType(), *cgf.fnRetAlloca); - return builder.create<cir::ReturnOp>(loc, - llvm::ArrayRef(value.getResult())); + auto value = cir::LoadOp::create( + builder, loc, fn.getFunctionType().getReturnType(), *cgf.fnRetAlloca); + return cir::ReturnOp::create(builder, loc, + llvm::ArrayRef(value.getResult())); } - return builder.create<cir::ReturnOp>(loc); + return cir::ReturnOp::create(builder, loc); } // This is copied from CodeGenModule::MayDropFunctionReturn. This is a @@ -389,9 +389,9 @@ void CIRGenFunction::LexicalScope::emitImplicitReturn() { if (shouldEmitUnreachable) { assert(!cir::MissingFeatures::sanitizers()); if (cgf.cgm.getCodeGenOpts().OptimizationLevel == 0) - builder.create<cir::TrapOp>(localScope->endLoc); + cir::TrapOp::create(builder, localScope->endLoc); else - builder.create<cir::UnreachableOp>(localScope->endLoc); + cir::UnreachableOp::create(builder, localScope->endLoc); builder.clearInsertionPoint(); return; } @@ -561,8 +561,8 @@ cir::FuncOp CIRGenFunction::generateCode(clang::GlobalDecl gd, cir::FuncOp fn, if (!clone) { mlir::OpBuilder::InsertionGuard guard(builder); builder.setInsertionPoint(fn); - clone = builder.create<cir::FuncOp>(fn.getLoc(), fdInlineName, - fn.getFunctionType()); + clone = cir::FuncOp::create(builder, fn.getLoc(), fdInlineName, + fn.getFunctionType()); clone.setLinkage(cir::GlobalLinkageKind::InternalLinkage); clone.setSymVisibility("private"); clone.setInlineKind(cir::InlineKind::AlwaysInline); diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.h b/clang/lib/CIR/CodeGen/CIRGenFunction.h index e3b9b6a..5f9dbdc 100644 --- a/clang/lib/CIR/CodeGen/CIRGenFunction.h +++ b/clang/lib/CIR/CodeGen/CIRGenFunction.h @@ -1461,6 +1461,9 @@ public: void emitReturnOfRValue(mlir::Location loc, RValue rv, QualType ty); + mlir::Value emitRuntimeCall(mlir::Location loc, cir::FuncOp callee, + llvm::ArrayRef<mlir::Value> args = {}); + /// Emit the computation of the specified expression of scalar type. mlir::Value emitScalarExpr(const clang::Expr *e); diff --git a/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp b/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp index e620310..f7c4d18 100644 --- a/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp @@ -1809,8 +1809,8 @@ CIRGenItaniumCXXABI::getVTableAddressPoint(BaseSubobject base, mlir::OpBuilder &builder = cgm.getBuilder(); auto vtablePtrTy = cir::VPtrType::get(builder.getContext()); - return builder.create<cir::VTableAddrPointOp>( - cgm.getLoc(vtableClass->getSourceRange()), vtablePtrTy, + return cir::VTableAddrPointOp::create( + builder, cgm.getLoc(vtableClass->getSourceRange()), vtablePtrTy, mlir::FlatSymbolRefAttr::get(vtable.getSymNameAttr()), cir::AddressPointAttr::get(cgm.getBuilder().getContext(), addressPoint.VTableIndex, @@ -1874,6 +1874,15 @@ static cir::FuncOp getBadCastFn(CIRGenFunction &cgf) { return cgf.cgm.createRuntimeFunction(fnTy, "__cxa_bad_cast"); } +static void emitCallToBadCast(CIRGenFunction &cgf, mlir::Location loc) { + // TODO(cir): set the calling convention to the runtime function. + assert(!cir::MissingFeatures::opFuncCallingConv()); + + cgf.emitRuntimeCall(loc, getBadCastFn(cgf)); + cir::UnreachableOp::create(cgf.getBuilder(), loc); + cgf.getBuilder().clearInsertionPoint(); +} + // TODO(cir): This could be shared with classic codegen. static CharUnits computeOffsetHint(ASTContext &astContext, const CXXRecordDecl *src, @@ -1959,6 +1968,136 @@ static Address emitDynamicCastToVoid(CIRGenFunction &cgf, mlir::Location loc, return Address{ptr, src.getAlignment()}; } +static mlir::Value emitExactDynamicCast(CIRGenItaniumCXXABI &abi, + CIRGenFunction &cgf, mlir::Location loc, + QualType srcRecordTy, + QualType destRecordTy, + cir::PointerType destCIRTy, + bool isRefCast, Address src) { + // Find all the inheritance paths from SrcRecordTy to DestRecordTy. + const CXXRecordDecl *srcDecl = srcRecordTy->getAsCXXRecordDecl(); + const CXXRecordDecl *destDecl = destRecordTy->getAsCXXRecordDecl(); + CXXBasePaths paths(/*FindAmbiguities=*/true, /*RecordPaths=*/true, + /*DetectVirtual=*/false); + (void)destDecl->isDerivedFrom(srcDecl, paths); + + // Find an offset within `destDecl` where a `srcDecl` instance and its vptr + // might appear. + std::optional<CharUnits> offset; + for (const CXXBasePath &path : paths) { + // dynamic_cast only finds public inheritance paths. + if (path.Access != AS_public) + continue; + + CharUnits pathOffset; + for (const CXXBasePathElement &pathElement : path) { + // Find the offset along this inheritance step. + const CXXRecordDecl *base = + pathElement.Base->getType()->getAsCXXRecordDecl(); + if (pathElement.Base->isVirtual()) { + // For a virtual base class, we know that the derived class is exactly + // destDecl, so we can use the vbase offset from its layout. + const ASTRecordLayout &layout = + cgf.getContext().getASTRecordLayout(destDecl); + pathOffset = layout.getVBaseClassOffset(base); + } else { + const ASTRecordLayout &layout = + cgf.getContext().getASTRecordLayout(pathElement.Class); + pathOffset += layout.getBaseClassOffset(base); + } + } + + if (!offset) { + offset = pathOffset; + } else if (offset != pathOffset) { + // base appears in at least two different places. Find the most-derived + // object and see if it's a DestDecl. Note that the most-derived object + // must be at least as aligned as this base class subobject, and must + // have a vptr at offset 0. + src = emitDynamicCastToVoid(cgf, loc, srcRecordTy, src); + srcDecl = destDecl; + offset = CharUnits::Zero(); + break; + } + } + + CIRGenBuilderTy &builder = cgf.getBuilder(); + + if (!offset) { + // If there are no public inheritance paths, the cast always fails. + mlir::Value nullPtrValue = builder.getNullPtr(destCIRTy, loc); + if (isRefCast) { + mlir::Region *currentRegion = builder.getBlock()->getParent(); + emitCallToBadCast(cgf, loc); + + // The call to bad_cast will terminate the block. Create a new block to + // hold any follow up code. + builder.createBlock(currentRegion, currentRegion->end()); + } + + return nullPtrValue; + } + + // Compare the vptr against the expected vptr for the destination type at + // this offset. Note that we do not know what type src points to in the case + // where the derived class multiply inherits from the base class so we can't + // use getVTablePtr, so we load the vptr directly instead. + + mlir::Value expectedVPtr = + abi.getVTableAddressPoint(BaseSubobject(srcDecl, *offset), destDecl); + + // TODO(cir): handle address space here. + assert(!cir::MissingFeatures::addressSpace()); + mlir::Type vptrTy = expectedVPtr.getType(); + mlir::Type vptrPtrTy = builder.getPointerTo(vptrTy); + Address srcVPtrPtr(builder.createBitcast(src.getPointer(), vptrPtrTy), + src.getAlignment()); + mlir::Value srcVPtr = builder.createLoad(loc, srcVPtrPtr); + + // TODO(cir): decorate SrcVPtr with TBAA info. + assert(!cir::MissingFeatures::opTBAA()); + + mlir::Value success = + builder.createCompare(loc, cir::CmpOpKind::eq, srcVPtr, expectedVPtr); + + auto emitCastResult = [&] { + if (offset->isZero()) + return builder.createBitcast(src.getPointer(), destCIRTy); + + // TODO(cir): handle address space here. + assert(!cir::MissingFeatures::addressSpace()); + mlir::Type u8PtrTy = builder.getUInt8PtrTy(); + + mlir::Value strideToApply = + builder.getConstInt(loc, builder.getUInt64Ty(), -offset->getQuantity()); + mlir::Value srcU8Ptr = builder.createBitcast(src.getPointer(), u8PtrTy); + mlir::Value resultU8Ptr = cir::PtrStrideOp::create(builder, loc, u8PtrTy, + srcU8Ptr, strideToApply); + return builder.createBitcast(resultU8Ptr, destCIRTy); + }; + + if (isRefCast) { + mlir::Value failed = builder.createNot(success); + cir::IfOp::create(builder, loc, failed, /*withElseRegion=*/false, + [&](mlir::OpBuilder &, mlir::Location) { + emitCallToBadCast(cgf, loc); + }); + return emitCastResult(); + } + + return cir::TernaryOp::create( + builder, loc, success, + [&](mlir::OpBuilder &, mlir::Location) { + auto result = emitCastResult(); + builder.createYield(loc, result); + }, + [&](mlir::OpBuilder &, mlir::Location) { + mlir::Value nullPtrValue = builder.getNullPtr(destCIRTy, loc); + builder.createYield(loc, nullPtrValue); + }) + .getResult(); +} + static cir::DynamicCastInfoAttr emitDynamicCastInfo(CIRGenFunction &cgf, mlir::Location loc, QualType srcRecordTy, @@ -2000,8 +2139,27 @@ mlir::Value CIRGenItaniumCXXABI::emitDynamicCast(CIRGenFunction &cgf, // if the dynamic type of the pointer is exactly the destination type. if (destRecordTy->getAsCXXRecordDecl()->isEffectivelyFinal() && cgf.cgm.getCodeGenOpts().OptimizationLevel > 0) { - cgm.errorNYI(loc, "emitExactDynamicCast"); - return {}; + CIRGenBuilderTy &builder = cgf.getBuilder(); + // If this isn't a reference cast, check the pointer to see if it's null. + if (!isRefCast) { + mlir::Value srcPtrIsNull = builder.createPtrIsNull(src.getPointer()); + return cir::TernaryOp::create( + builder, loc, srcPtrIsNull, + [&](mlir::OpBuilder, mlir::Location) { + builder.createYield( + loc, builder.getNullPtr(destCIRTy, loc).getResult()); + }, + [&](mlir::OpBuilder &, mlir::Location) { + mlir::Value exactCast = emitExactDynamicCast( + *this, cgf, loc, srcRecordTy, destRecordTy, destCIRTy, + isRefCast, src); + builder.createYield(loc, exactCast); + }) + .getResult(); + } + + return emitExactDynamicCast(*this, cgf, loc, srcRecordTy, destRecordTy, + destCIRTy, isRefCast, src); } cir::DynamicCastInfoAttr castInfo = diff --git a/clang/lib/CIR/CodeGen/CIRGenModule.cpp b/clang/lib/CIR/CodeGen/CIRGenModule.cpp index 6b29373..46adfe2 100644 --- a/clang/lib/CIR/CodeGen/CIRGenModule.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenModule.cpp @@ -535,7 +535,7 @@ cir::GlobalOp CIRGenModule::createGlobalOp(CIRGenModule &cgm, builder.setInsertionPointToStart(cgm.getModule().getBody()); } - g = builder.create<cir::GlobalOp>(loc, name, t, isConstant); + g = cir::GlobalOp::create(builder, loc, name, t, isConstant); if (!insertPoint) cgm.lastGlobalOp = g; @@ -739,8 +739,8 @@ mlir::Value CIRGenModule::getAddrOfGlobalVar(const VarDecl *d, mlir::Type ty, cir::GlobalOp g = getOrCreateCIRGlobal(d, ty, isForDefinition); mlir::Type ptrTy = builder.getPointerTo(g.getSymType()); - return builder.create<cir::GetGlobalOp>(getLoc(d->getSourceRange()), ptrTy, - g.getSymName()); + return cir::GetGlobalOp::create(builder, getLoc(d->getSourceRange()), ptrTy, + g.getSymName()); } cir::GlobalViewAttr CIRGenModule::getAddrOfGlobalVarAttr(const VarDecl *d) { @@ -2176,7 +2176,7 @@ CIRGenModule::createCIRFunction(mlir::Location loc, StringRef name, if (cgf) builder.setInsertionPoint(cgf->curFn); - func = builder.create<cir::FuncOp>(loc, name, funcType); + func = cir::FuncOp::create(builder, loc, name, funcType); assert(!cir::MissingFeatures::opFuncAstDeclAttr()); diff --git a/clang/lib/CIR/CodeGen/CIRGenOpenACC.cpp b/clang/lib/CIR/CodeGen/CIRGenOpenACC.cpp index 5ba6bcb..e7bf3bc 100644 --- a/clang/lib/CIR/CodeGen/CIRGenOpenACC.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenOpenACC.cpp @@ -27,8 +27,8 @@ mlir::Value createBound(CIRGenFunction &cgf, CIRGen::CIRGenBuilderTy &builder, // Stride is always 1 in C/C++. mlir::Value stride = cgf.createOpenACCConstantInt(boundLoc, 64, 1); - auto bound = - builder.create<mlir::acc::DataBoundsOp>(boundLoc, lowerBound, upperBound); + auto bound = mlir::acc::DataBoundsOp::create(builder, boundLoc, lowerBound, + upperBound); bound.getStartIdxMutable().assign(startIdx); if (extent) bound.getExtentMutable().assign(extent); @@ -48,8 +48,8 @@ mlir::Value CIRGenFunction::emitOpenACCIntExpr(const Expr *intExpr) { ? mlir::IntegerType::SignednessSemantics::Signed : mlir::IntegerType::SignednessSemantics::Unsigned); - auto conversionOp = builder.create<mlir::UnrealizedConversionCastOp>( - exprLoc, targetType, expr); + auto conversionOp = mlir::UnrealizedConversionCastOp::create( + builder, exprLoc, targetType, expr); return conversionOp.getResult(0); } @@ -59,8 +59,8 @@ mlir::Value CIRGenFunction::createOpenACCConstantInt(mlir::Location loc, mlir::IntegerType ty = mlir::IntegerType::get(&getMLIRContext(), width, mlir::IntegerType::SignednessSemantics::Signless); - auto constOp = builder.create<mlir::arith::ConstantOp>( - loc, builder.getIntegerAttr(ty, value)); + auto constOp = mlir::arith::ConstantOp::create( + builder, loc, builder.getIntegerAttr(ty, value)); return constOp; } diff --git a/clang/lib/CIR/CodeGen/CIRGenOpenACCClause.cpp b/clang/lib/CIR/CodeGen/CIRGenOpenACCClause.cpp index 385f89c..5010137 100644 --- a/clang/lib/CIR/CodeGen/CIRGenOpenACCClause.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenOpenACCClause.cpp @@ -96,8 +96,8 @@ class OpenACCClauseCIREmitter final mlir::IntegerType targetType = mlir::IntegerType::get( &cgf.getMLIRContext(), /*width=*/1, mlir::IntegerType::SignednessSemantics::Signless); - auto conversionOp = builder.create<mlir::UnrealizedConversionCastOp>( - exprLoc, targetType, condition); + auto conversionOp = mlir::UnrealizedConversionCastOp::create( + builder, exprLoc, targetType, condition); return conversionOp.getResult(0); } @@ -107,8 +107,8 @@ class OpenACCClauseCIREmitter final mlir::IntegerType ty = mlir::IntegerType::get( &cgf.getMLIRContext(), width, mlir::IntegerType::SignednessSemantics::Signless); - auto constOp = builder.create<mlir::arith::ConstantOp>( - loc, builder.getIntegerAttr(ty, value)); + auto constOp = mlir::arith::ConstantOp::create( + builder, loc, builder.getIntegerAttr(ty, value)); return constOp; } @@ -217,8 +217,8 @@ class OpenACCClauseCIREmitter final cgf.getOpenACCDataOperandInfo(varOperand); auto beforeOp = - builder.create<BeforeOpTy>(opInfo.beginLoc, opInfo.varValue, structured, - implicit, opInfo.name, opInfo.bounds); + BeforeOpTy::create(builder, opInfo.beginLoc, opInfo.varValue, + structured, implicit, opInfo.name, opInfo.bounds); operation.getDataClauseOperandsMutable().append(beforeOp.getResult()); AfterOpTy afterOp; @@ -231,12 +231,12 @@ class OpenACCClauseCIREmitter final // Detach/Delete ops don't have the variable reference here, so they // take 1 fewer argument to their build function. afterOp = - builder.create<AfterOpTy>(opInfo.beginLoc, beforeOp, structured, - implicit, opInfo.name, opInfo.bounds); + AfterOpTy::create(builder, opInfo.beginLoc, beforeOp, structured, + implicit, opInfo.name, opInfo.bounds); } else { - afterOp = builder.create<AfterOpTy>( - opInfo.beginLoc, beforeOp, opInfo.varValue, structured, implicit, - opInfo.name, opInfo.bounds); + afterOp = AfterOpTy::create(builder, opInfo.beginLoc, beforeOp, + opInfo.varValue, structured, implicit, + opInfo.name, opInfo.bounds); } } @@ -258,8 +258,8 @@ class OpenACCClauseCIREmitter final CIRGenFunction::OpenACCDataOperandInfo opInfo = cgf.getOpenACCDataOperandInfo(varOperand); auto beforeOp = - builder.create<BeforeOpTy>(opInfo.beginLoc, opInfo.varValue, structured, - implicit, opInfo.name, opInfo.bounds); + BeforeOpTy::create(builder, opInfo.beginLoc, opInfo.varValue, + structured, implicit, opInfo.name, opInfo.bounds); operation.getDataClauseOperandsMutable().append(beforeOp.getResult()); // Set the 'rest' of the info for the operation. diff --git a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp index f486c46..1eb7199 100644 --- a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp @@ -91,8 +91,9 @@ mlir::LogicalResult CIRGenFunction::emitCompoundStmt(const CompoundStmt &s, SymTableScopeTy varScope(symbolTable); mlir::Location scopeLoc = getLoc(s.getSourceRange()); mlir::OpBuilder::InsertPoint scopeInsPt; - builder.create<cir::ScopeOp>( - scopeLoc, [&](mlir::OpBuilder &b, mlir::Type &type, mlir::Location loc) { + cir::ScopeOp::create( + builder, scopeLoc, + [&](mlir::OpBuilder &b, mlir::Type &type, mlir::Location loc) { scopeInsPt = b.saveInsertionPoint(); }); mlir::OpBuilder::InsertionGuard guard(builder); @@ -423,12 +424,12 @@ mlir::LogicalResult CIRGenFunction::emitIfStmt(const IfStmt &s) { // LexicalScope ConditionScope(*this, S.getCond()->getSourceRange()); // The if scope contains the full source range for IfStmt. mlir::Location scopeLoc = getLoc(s.getSourceRange()); - builder.create<cir::ScopeOp>( - scopeLoc, /*scopeBuilder=*/ - [&](mlir::OpBuilder &b, mlir::Location loc) { - LexicalScope lexScope{*this, scopeLoc, builder.getInsertionBlock()}; - res = ifStmtBuilder(); - }); + cir::ScopeOp::create(builder, scopeLoc, /*scopeBuilder=*/ + [&](mlir::OpBuilder &b, mlir::Location loc) { + LexicalScope lexScope{*this, scopeLoc, + builder.getInsertionBlock()}; + res = ifStmtBuilder(); + }); return res; } @@ -576,11 +577,11 @@ mlir::LogicalResult CIRGenFunction::emitLabel(const clang::LabelDecl &d) { mlir::OpBuilder::InsertionGuard guard(builder); labelBlock = builder.createBlock(builder.getBlock()->getParent()); } - builder.create<cir::BrOp>(getLoc(d.getSourceRange()), labelBlock); + cir::BrOp::create(builder, getLoc(d.getSourceRange()), labelBlock); } builder.setInsertionPointToEnd(labelBlock); - builder.create<cir::LabelOp>(getLoc(d.getSourceRange()), d.getName()); + cir::LabelOp::create(builder, getLoc(d.getSourceRange()), d.getName()); builder.setInsertionPointToEnd(labelBlock); // FIXME: emit debug info for labels, incrementProfileCounter @@ -617,7 +618,7 @@ CIRGenFunction::emitCaseDefaultCascade(const T *stmt, mlir::Type condType, const Stmt *sub = stmt->getSubStmt(); mlir::OpBuilder::InsertPoint insertPoint; - builder.create<CaseOp>(loc, value, kind, insertPoint); + CaseOp::create(builder, loc, value, kind, insertPoint); { mlir::OpBuilder::InsertionGuard guardSwitch(builder); @@ -789,16 +790,16 @@ CIRGenFunction::emitCXXForRangeStmt(const CXXForRangeStmt &s, mlir::LogicalResult res = mlir::success(); mlir::Location scopeLoc = getLoc(s.getSourceRange()); - builder.create<cir::ScopeOp>(scopeLoc, /*scopeBuilder=*/ - [&](mlir::OpBuilder &b, mlir::Location loc) { - // Create a cleanup scope for the condition - // variable cleanups. Logical equivalent from - // LLVM codegn for LexicalScope - // ConditionScope(*this, S.getSourceRange())... - LexicalScope lexScope{ - *this, loc, builder.getInsertionBlock()}; - res = forStmtBuilder(); - }); + cir::ScopeOp::create(builder, scopeLoc, /*scopeBuilder=*/ + [&](mlir::OpBuilder &b, mlir::Location loc) { + // Create a cleanup scope for the condition + // variable cleanups. Logical equivalent from + // LLVM codegn for LexicalScope + // ConditionScope(*this, S.getSourceRange())... + LexicalScope lexScope{*this, loc, + builder.getInsertionBlock()}; + res = forStmtBuilder(); + }); if (res.failed()) return res; @@ -841,7 +842,7 @@ mlir::LogicalResult CIRGenFunction::emitForStmt(const ForStmt &s) { // scalar type. condVal = evaluateExprAsBool(s.getCond()); } else { - condVal = b.create<cir::ConstantOp>(loc, builder.getTrueAttr()); + condVal = cir::ConstantOp::create(b, loc, builder.getTrueAttr()); } builder.createCondition(condVal); }, @@ -865,12 +866,12 @@ mlir::LogicalResult CIRGenFunction::emitForStmt(const ForStmt &s) { auto res = mlir::success(); auto scopeLoc = getLoc(s.getSourceRange()); - builder.create<cir::ScopeOp>(scopeLoc, /*scopeBuilder=*/ - [&](mlir::OpBuilder &b, mlir::Location loc) { - LexicalScope lexScope{ - *this, loc, builder.getInsertionBlock()}; - res = forStmtBuilder(); - }); + cir::ScopeOp::create(builder, scopeLoc, /*scopeBuilder=*/ + [&](mlir::OpBuilder &b, mlir::Location loc) { + LexicalScope lexScope{*this, loc, + builder.getInsertionBlock()}; + res = forStmtBuilder(); + }); if (res.failed()) return res; @@ -916,12 +917,12 @@ mlir::LogicalResult CIRGenFunction::emitDoStmt(const DoStmt &s) { mlir::LogicalResult res = mlir::success(); mlir::Location scopeLoc = getLoc(s.getSourceRange()); - builder.create<cir::ScopeOp>(scopeLoc, /*scopeBuilder=*/ - [&](mlir::OpBuilder &b, mlir::Location loc) { - LexicalScope lexScope{ - *this, loc, builder.getInsertionBlock()}; - res = doStmtBuilder(); - }); + cir::ScopeOp::create(builder, scopeLoc, /*scopeBuilder=*/ + [&](mlir::OpBuilder &b, mlir::Location loc) { + LexicalScope lexScope{*this, loc, + builder.getInsertionBlock()}; + res = doStmtBuilder(); + }); if (res.failed()) return res; @@ -972,12 +973,12 @@ mlir::LogicalResult CIRGenFunction::emitWhileStmt(const WhileStmt &s) { mlir::LogicalResult res = mlir::success(); mlir::Location scopeLoc = getLoc(s.getSourceRange()); - builder.create<cir::ScopeOp>(scopeLoc, /*scopeBuilder=*/ - [&](mlir::OpBuilder &b, mlir::Location loc) { - LexicalScope lexScope{ - *this, loc, builder.getInsertionBlock()}; - res = whileStmtBuilder(); - }); + cir::ScopeOp::create(builder, scopeLoc, /*scopeBuilder=*/ + [&](mlir::OpBuilder &b, mlir::Location loc) { + LexicalScope lexScope{*this, loc, + builder.getInsertionBlock()}; + res = whileStmtBuilder(); + }); if (res.failed()) return res; @@ -1048,8 +1049,8 @@ mlir::LogicalResult CIRGenFunction::emitSwitchStmt(const clang::SwitchStmt &s) { assert(!cir::MissingFeatures::insertBuiltinUnpredictable()); mlir::LogicalResult res = mlir::success(); - swop = builder.create<SwitchOp>( - getLoc(s.getBeginLoc()), condV, + swop = SwitchOp::create( + builder, getLoc(s.getBeginLoc()), condV, /*switchBuilder=*/ [&](mlir::OpBuilder &b, mlir::Location loc, mlir::OperationState &os) { curLexScope->setAsSwitch(); @@ -1067,12 +1068,12 @@ mlir::LogicalResult CIRGenFunction::emitSwitchStmt(const clang::SwitchStmt &s) { // The switch scope contains the full source range for SwitchStmt. mlir::Location scopeLoc = getLoc(s.getSourceRange()); mlir::LogicalResult res = mlir::success(); - builder.create<cir::ScopeOp>(scopeLoc, /*scopeBuilder=*/ - [&](mlir::OpBuilder &b, mlir::Location loc) { - LexicalScope lexScope{ - *this, loc, builder.getInsertionBlock()}; - res = switchStmtBuilder(); - }); + cir::ScopeOp::create(builder, scopeLoc, /*scopeBuilder=*/ + [&](mlir::OpBuilder &b, mlir::Location loc) { + LexicalScope lexScope{*this, loc, + builder.getInsertionBlock()}; + res = switchStmtBuilder(); + }); llvm::SmallVector<CaseOp> cases; swop.collectCases(cases); @@ -1096,7 +1097,7 @@ void CIRGenFunction::emitReturnOfRValue(mlir::Location loc, RValue rv, } mlir::Block *retBlock = curLexScope->getOrCreateRetBlock(*this, loc); assert(!cir::MissingFeatures::emitBranchThroughCleanup()); - builder.create<cir::BrOp>(loc, retBlock); + cir::BrOp::create(builder, loc, retBlock); if (ehStack.stable_begin() != currentCleanupStackDepth) cgm.errorNYI(loc, "return of r-value with cleanup stack"); } diff --git a/clang/lib/CIR/CodeGen/CIRGenStmtOpenACC.cpp b/clang/lib/CIR/CodeGen/CIRGenStmtOpenACC.cpp index 02bb46d..77e6f83 100644 --- a/clang/lib/CIR/CodeGen/CIRGenStmtOpenACC.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenStmtOpenACC.cpp @@ -30,7 +30,7 @@ mlir::LogicalResult CIRGenFunction::emitOpenACCOpAssociatedStmt( llvm::SmallVector<mlir::Type> retTy; llvm::SmallVector<mlir::Value> operands; - auto op = builder.create<Op>(start, retTy, operands); + auto op = Op::create(builder, start, retTy, operands); emitOpenACCClauses(op, dirKind, dirLoc, clauses); @@ -42,7 +42,7 @@ mlir::LogicalResult CIRGenFunction::emitOpenACCOpAssociatedStmt( LexicalScope ls{*this, start, builder.getInsertionBlock()}; res = emitStmt(associatedStmt, /*useCurrentScope=*/true); - builder.create<TermOp>(end); + TermOp::create(builder, end); } return res; } @@ -73,7 +73,7 @@ mlir::LogicalResult CIRGenFunction::emitOpenACCOpCombinedConstruct( llvm::SmallVector<mlir::Type> retTy; llvm::SmallVector<mlir::Value> operands; - auto computeOp = builder.create<Op>(start, retTy, operands); + auto computeOp = Op::create(builder, start, retTy, operands); computeOp.setCombinedAttr(builder.getUnitAttr()); mlir::acc::LoopOp loopOp; @@ -85,7 +85,7 @@ mlir::LogicalResult CIRGenFunction::emitOpenACCOpCombinedConstruct( builder.setInsertionPointToEnd(&block); LexicalScope ls{*this, start, builder.getInsertionBlock()}; - auto loopOp = builder.create<LoopOp>(start, retTy, operands); + auto loopOp = LoopOp::create(builder, start, retTy, operands); loopOp.setCombinedAttr(mlir::acc::CombinedConstructsTypeAttr::get( builder.getContext(), CombinedType<Op>::value)); @@ -99,14 +99,14 @@ mlir::LogicalResult CIRGenFunction::emitOpenACCOpCombinedConstruct( res = emitStmt(loopStmt, /*useCurrentScope=*/true); - builder.create<mlir::acc::YieldOp>(end); + mlir::acc::YieldOp::create(builder, end); } emitOpenACCClauses(computeOp, loopOp, dirKind, dirLoc, clauses); updateLoopOpParallelism(loopOp, /*isOrphan=*/false, dirKind); - builder.create<TermOp>(end); + TermOp::create(builder, end); } return res; @@ -118,7 +118,7 @@ Op CIRGenFunction::emitOpenACCOp( llvm::ArrayRef<const OpenACCClause *> clauses) { llvm::SmallVector<mlir::Type> retTy; llvm::SmallVector<mlir::Value> operands; - auto op = builder.create<Op>(start, retTy, operands); + auto op = Op::create(builder, start, retTy, operands); emitOpenACCClauses(op, dirKind, dirLoc, clauses); return op; @@ -197,8 +197,8 @@ CIRGenFunction::emitOpenACCWaitConstruct(const OpenACCWaitConstruct &s) { ? mlir::IntegerType::SignednessSemantics::Signed : mlir::IntegerType::SignednessSemantics::Unsigned); - auto conversionOp = builder.create<mlir::UnrealizedConversionCastOp>( - exprLoc, targetType, expr); + auto conversionOp = mlir::UnrealizedConversionCastOp::create( + builder, exprLoc, targetType, expr); return conversionOp.getResult(0); }; @@ -294,9 +294,9 @@ CIRGenFunction::emitOpenACCCacheConstruct(const OpenACCCacheConstruct &s) { CIRGenFunction::OpenACCDataOperandInfo opInfo = getOpenACCDataOperandInfo(var); - auto cacheOp = builder.create<CacheOp>( - opInfo.beginLoc, opInfo.varValue, - /*structured=*/false, /*implicit=*/false, opInfo.name, opInfo.bounds); + auto cacheOp = CacheOp::create(builder, opInfo.beginLoc, opInfo.varValue, + /*structured=*/false, /*implicit=*/false, + opInfo.name, opInfo.bounds); loopOp.getCacheOperandsMutable().append(cacheOp.getResult()); } diff --git a/clang/lib/CIR/CodeGen/CIRGenStmtOpenACCLoop.cpp b/clang/lib/CIR/CodeGen/CIRGenStmtOpenACCLoop.cpp index f3911ae..c5b89bd 100644 --- a/clang/lib/CIR/CodeGen/CIRGenStmtOpenACCLoop.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenStmtOpenACCLoop.cpp @@ -58,7 +58,7 @@ CIRGenFunction::emitOpenACCLoopConstruct(const OpenACCLoopConstruct &s) { mlir::Location end = getLoc(s.getSourceRange().getEnd()); llvm::SmallVector<mlir::Type> retTy; llvm::SmallVector<mlir::Value> operands; - auto op = builder.create<LoopOp>(start, retTy, operands); + auto op = LoopOp::create(builder, start, retTy, operands); // TODO(OpenACC): In the future we are going to need to come up with a // transformation here that can teach the acc.loop how to figure out the @@ -133,7 +133,7 @@ CIRGenFunction::emitOpenACCLoopConstruct(const OpenACCLoopConstruct &s) { ActiveOpenACCLoopRAII activeLoop{*this, &op}; stmtRes = emitStmt(s.getLoop(), /*useCurrentScope=*/true); - builder.create<mlir::acc::YieldOp>(end); + mlir::acc::YieldOp::create(builder, end); } return stmtRes; diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index fa180f5..2d2ef42 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -95,8 +95,8 @@ Operation *cir::CIRDialect::materializeConstant(mlir::OpBuilder &builder, mlir::Attribute value, mlir::Type type, mlir::Location loc) { - return builder.create<cir::ConstantOp>(loc, type, - mlir::cast<mlir::TypedAttr>(value)); + return cir::ConstantOp::create(builder, loc, type, + mlir::cast<mlir::TypedAttr>(value)); } //===----------------------------------------------------------------------===// @@ -184,7 +184,7 @@ static LogicalResult ensureRegionTerm(OpAsmParser &parser, Region ®ion, // Terminator was omitted correctly: recreate it. builder.setInsertionPointToEnd(&block); - builder.create<cir::YieldOp>(eLoc); + cir::YieldOp::create(builder, eLoc); return success(); } @@ -977,7 +977,7 @@ void cir::IfOp::print(OpAsmPrinter &p) { /// Default callback for IfOp builders. void cir::buildTerminatedBody(OpBuilder &builder, Location loc) { // add cir.yield to end of the block - builder.create<cir::YieldOp>(loc); + cir::YieldOp::create(builder, loc); } /// Given the region at `index`, or the parent operation if `index` is None, diff --git a/clang/lib/CIR/Dialect/IR/CIRMemorySlot.cpp b/clang/lib/CIR/Dialect/IR/CIRMemorySlot.cpp index 7e96ae9..66469e2 100644 --- a/clang/lib/CIR/Dialect/IR/CIRMemorySlot.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRMemorySlot.cpp @@ -34,8 +34,8 @@ llvm::SmallVector<MemorySlot> cir::AllocaOp::getPromotableSlots() { Value cir::AllocaOp::getDefaultValue(const MemorySlot &slot, OpBuilder &builder) { - return builder.create<cir::ConstantOp>(getLoc(), - cir::UndefAttr::get(slot.elemType)); + return cir::ConstantOp::create(builder, getLoc(), + cir::UndefAttr::get(slot.elemType)); } void cir::AllocaOp::handleBlockArgument(const MemorySlot &slot, diff --git a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp index 46bd186..21c96fe 100644 --- a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp +++ b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp @@ -100,8 +100,8 @@ struct CIRIfFlattening : public mlir::OpRewritePattern<cir::IfOp> { } rewriter.setInsertionPointToEnd(currentBlock); - rewriter.create<cir::BrCondOp>(loc, ifOp.getCondition(), thenBeforeBody, - elseBeforeBody); + cir::BrCondOp::create(rewriter, loc, ifOp.getCondition(), thenBeforeBody, + elseBeforeBody); if (!emptyElse) { rewriter.setInsertionPointToEnd(elseAfterBody); @@ -154,7 +154,7 @@ public: // Save stack and then branch into the body of the region. rewriter.setInsertionPointToEnd(currentBlock); assert(!cir::MissingFeatures::stackSaveOp()); - rewriter.create<cir::BrOp>(loc, mlir::ValueRange(), beforeBody); + cir::BrOp::create(rewriter, loc, mlir::ValueRange(), beforeBody); // Replace the scopeop return with a branch that jumps out of the body. // Stack restore before leaving the body region. @@ -195,26 +195,27 @@ public: cir::IntType sIntType = cir::IntType::get(op.getContext(), 32, true); cir::IntType uIntType = cir::IntType::get(op.getContext(), 32, false); - cir::ConstantOp rangeLength = rewriter.create<cir::ConstantOp>( - op.getLoc(), cir::IntAttr::get(sIntType, upperBound - lowerBound)); + cir::ConstantOp rangeLength = cir::ConstantOp::create( + rewriter, op.getLoc(), + cir::IntAttr::get(sIntType, upperBound - lowerBound)); - cir::ConstantOp lowerBoundValue = rewriter.create<cir::ConstantOp>( - op.getLoc(), cir::IntAttr::get(sIntType, lowerBound)); + cir::ConstantOp lowerBoundValue = cir::ConstantOp::create( + rewriter, op.getLoc(), cir::IntAttr::get(sIntType, lowerBound)); cir::BinOp diffValue = - rewriter.create<cir::BinOp>(op.getLoc(), sIntType, cir::BinOpKind::Sub, - op.getCondition(), lowerBoundValue); + cir::BinOp::create(rewriter, op.getLoc(), sIntType, cir::BinOpKind::Sub, + op.getCondition(), lowerBoundValue); // Use unsigned comparison to check if the condition is in the range. - cir::CastOp uDiffValue = rewriter.create<cir::CastOp>( - op.getLoc(), uIntType, CastKind::integral, diffValue); - cir::CastOp uRangeLength = rewriter.create<cir::CastOp>( - op.getLoc(), uIntType, CastKind::integral, rangeLength); - - cir::CmpOp cmpResult = rewriter.create<cir::CmpOp>( - op.getLoc(), cir::BoolType::get(op.getContext()), cir::CmpOpKind::le, - uDiffValue, uRangeLength); - rewriter.create<cir::BrCondOp>(op.getLoc(), cmpResult, rangeDestination, - defaultDestination); + cir::CastOp uDiffValue = cir::CastOp::create( + rewriter, op.getLoc(), uIntType, CastKind::integral, diffValue); + cir::CastOp uRangeLength = cir::CastOp::create( + rewriter, op.getLoc(), uIntType, CastKind::integral, rangeLength); + + cir::CmpOp cmpResult = cir::CmpOp::create( + rewriter, op.getLoc(), cir::BoolType::get(op.getContext()), + cir::CmpOpKind::le, uDiffValue, uRangeLength); + cir::BrCondOp::create(rewriter, op.getLoc(), cmpResult, rangeDestination, + defaultDestination); return resBlock; } @@ -262,7 +263,7 @@ public: rewriteYieldOp(rewriter, switchYield, exitBlock); rewriter.setInsertionPointToEnd(originalBlock); - rewriter.create<cir::BrOp>(op.getLoc(), swopBlock); + cir::BrOp::create(rewriter, op.getLoc(), swopBlock); } // Allocate required data structures (disconsider default case in @@ -331,8 +332,8 @@ public: mlir::Block *newBlock = rewriter.splitBlock(oldBlock, nextOp->getIterator()); rewriter.setInsertionPointToEnd(oldBlock); - rewriter.create<cir::BrOp>(nextOp->getLoc(), mlir::ValueRange(), - newBlock); + cir::BrOp::create(rewriter, nextOp->getLoc(), mlir::ValueRange(), + newBlock); rewriteYieldOp(rewriter, yieldOp, newBlock); } } @@ -346,7 +347,7 @@ public: // Create a branch to the entry of the inlined region. rewriter.setInsertionPointToEnd(oldBlock); - rewriter.create<cir::BrOp>(caseOp.getLoc(), &entryBlock); + cir::BrOp::create(rewriter, caseOp.getLoc(), &entryBlock); } // Remove all cases since we've inlined the regions. @@ -427,7 +428,7 @@ public: // Setup loop entry branch. rewriter.setInsertionPointToEnd(entry); - rewriter.create<cir::BrOp>(op.getLoc(), &op.getEntry().front()); + cir::BrOp::create(rewriter, op.getLoc(), &op.getEntry().front()); // Branch from condition region to body or exit. auto conditionOp = cast<cir::ConditionOp>(cond->getTerminator()); @@ -499,7 +500,7 @@ public: locs.push_back(loc); Block *continueBlock = rewriter.createBlock(remainingOpsBlock, op->getResultTypes(), locs); - rewriter.create<cir::BrOp>(loc, remainingOpsBlock); + cir::BrOp::create(rewriter, loc, remainingOpsBlock); Region &trueRegion = op.getTrueRegion(); Block *trueBlock = &trueRegion.front(); @@ -542,7 +543,7 @@ public: rewriter.inlineRegionBefore(falseRegion, continueBlock); rewriter.setInsertionPointToEnd(condBlock); - rewriter.create<cir::BrCondOp>(loc, op.getCond(), trueBlock, falseBlock); + cir::BrCondOp::create(rewriter, loc, op.getCond(), trueBlock, falseBlock); rewriter.replaceOp(op, continueBlock->getArguments()); diff --git a/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp b/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp index d99c362..cba0464 100644 --- a/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp +++ b/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp @@ -155,7 +155,7 @@ cir::FuncOp LoweringPreparePass::buildRuntimeFunction( cir::FuncOp f = dyn_cast_or_null<FuncOp>(SymbolTable::lookupNearestSymbolFrom( mlirModule, StringAttr::get(mlirModule->getContext(), name))); if (!f) { - f = builder.create<cir::FuncOp>(loc, name, type); + f = cir::FuncOp::create(builder, loc, name, type); f.setLinkageAttr( cir::GlobalLinkageKindAttr::get(builder.getContext(), linkage)); mlir::SymbolTable::setSymbolVisibility( @@ -400,12 +400,12 @@ buildRangeReductionComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc, builder.createYield(loc, result); }; - auto cFabs = builder.create<cir::FAbsOp>(loc, c); - auto dFabs = builder.create<cir::FAbsOp>(loc, d); + auto cFabs = cir::FAbsOp::create(builder, loc, c); + auto dFabs = cir::FAbsOp::create(builder, loc, d); cir::CmpOp cmpResult = builder.createCompare(loc, cir::CmpOpKind::ge, cFabs, dFabs); - auto ternary = builder.create<cir::TernaryOp>( - loc, cmpResult, trueBranchBuilder, falseBranchBuilder); + auto ternary = cir::TernaryOp::create(builder, loc, cmpResult, + trueBranchBuilder, falseBranchBuilder); return ternary.getResult(); } @@ -612,18 +612,17 @@ static mlir::Value lowerComplexMul(LoweringPreparePass &pass, mlir::Value resultRealAndImagAreNaN = builder.createLogicalAnd(loc, resultRealIsNaN, resultImagIsNaN); - return builder - .create<cir::TernaryOp>( - loc, resultRealAndImagAreNaN, - [&](mlir::OpBuilder &, mlir::Location) { - mlir::Value libCallResult = buildComplexBinOpLibCall( - pass, builder, &getComplexMulLibCallName, loc, complexTy, - lhsReal, lhsImag, rhsReal, rhsImag); - builder.createYield(loc, libCallResult); - }, - [&](mlir::OpBuilder &, mlir::Location) { - builder.createYield(loc, algebraicResult); - }) + return cir::TernaryOp::create( + builder, loc, resultRealAndImagAreNaN, + [&](mlir::OpBuilder &, mlir::Location) { + mlir::Value libCallResult = buildComplexBinOpLibCall( + pass, builder, &getComplexMulLibCallName, loc, complexTy, + lhsReal, lhsImag, rhsReal, rhsImag); + builder.createYield(loc, libCallResult); + }, + [&](mlir::OpBuilder &, mlir::Location) { + builder.createYield(loc, algebraicResult); + }) .getResult(); } @@ -920,15 +919,15 @@ static void lowerArrayDtorCtorIntoLoop(cir::CIRBaseBuilderTy &builder, loc, /*condBuilder=*/ [&](mlir::OpBuilder &b, mlir::Location loc) { - auto currentElement = b.create<cir::LoadOp>(loc, eltTy, tmpAddr); + auto currentElement = cir::LoadOp::create(b, loc, eltTy, tmpAddr); mlir::Type boolTy = cir::BoolType::get(b.getContext()); - auto cmp = builder.create<cir::CmpOp>(loc, boolTy, cir::CmpOpKind::ne, - currentElement, stop); + auto cmp = cir::CmpOp::create(builder, loc, boolTy, cir::CmpOpKind::ne, + currentElement, stop); builder.createCondition(cmp); }, /*bodyBuilder=*/ [&](mlir::OpBuilder &b, mlir::Location loc) { - auto currentElement = b.create<cir::LoadOp>(loc, eltTy, tmpAddr); + auto currentElement = cir::LoadOp::create(b, loc, eltTy, tmpAddr); cir::CallOp ctorCall; op->walk([&](cir::CallOp c) { ctorCall = c; }); diff --git a/clang/lib/CIR/Dialect/Transforms/LoweringPrepareItaniumCXXABI.cpp b/clang/lib/CIR/Dialect/Transforms/LoweringPrepareItaniumCXXABI.cpp index 11ce2a8..5a067f8 100644 --- a/clang/lib/CIR/Dialect/Transforms/LoweringPrepareItaniumCXXABI.cpp +++ b/clang/lib/CIR/Dialect/Transforms/LoweringPrepareItaniumCXXABI.cpp @@ -77,10 +77,11 @@ buildDynamicCastAfterNullCheck(cir::CIRBaseBuilderTy &builder, if (op.isRefCast()) { // Emit a cir.if that checks the casted value. mlir::Value castedValueIsNull = builder.createPtrIsNull(castedPtr); - builder.create<cir::IfOp>( - loc, castedValueIsNull, false, [&](mlir::OpBuilder &, mlir::Location) { - buildBadCastCall(builder, loc, castInfo.getBadCastFunc()); - }); + cir::IfOp::create(builder, loc, castedValueIsNull, false, + [&](mlir::OpBuilder &, mlir::Location) { + buildBadCastCall(builder, loc, + castInfo.getBadCastFunc()); + }); } // Note that castedPtr is a void*. Cast it to a pointer to the destination @@ -154,19 +155,19 @@ LoweringPrepareItaniumCXXABI::lowerDynamicCast(cir::CIRBaseBuilderTy &builder, return buildDynamicCastAfterNullCheck(builder, op); mlir::Value srcValueIsNotNull = builder.createPtrToBoolCast(srcValue); - return builder - .create<cir::TernaryOp>( - loc, srcValueIsNotNull, - [&](mlir::OpBuilder &, mlir::Location) { - mlir::Value castedValue = - op.isCastToVoid() - ? buildDynamicCastToVoidAfterNullCheck(builder, astCtx, op) - : buildDynamicCastAfterNullCheck(builder, op); - builder.createYield(loc, castedValue); - }, - [&](mlir::OpBuilder &, mlir::Location) { - builder.createYield( - loc, builder.getNullPtr(op.getType(), loc).getResult()); - }) + return cir::TernaryOp::create( + builder, loc, srcValueIsNotNull, + [&](mlir::OpBuilder &, mlir::Location) { + mlir::Value castedValue = + op.isCastToVoid() + ? buildDynamicCastToVoidAfterNullCheck(builder, astCtx, + op) + : buildDynamicCastAfterNullCheck(builder, op); + builder.createYield(loc, castedValue); + }, + [&](mlir::OpBuilder &, mlir::Location) { + builder.createYield( + loc, builder.getNullPtr(op.getType(), loc).getResult()); + }) .getResult(); } diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index bb75f2d..a30ae02 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -90,12 +90,12 @@ static mlir::Value createIntCast(mlir::OpBuilder &bld, mlir::Value src, mlir::Location loc = src.getLoc(); if (dstWidth > srcWidth && isSigned) - return bld.create<mlir::LLVM::SExtOp>(loc, dstTy, src); + return mlir::LLVM::SExtOp::create(bld, loc, dstTy, src); if (dstWidth > srcWidth) - return bld.create<mlir::LLVM::ZExtOp>(loc, dstTy, src); + return mlir::LLVM::ZExtOp::create(bld, loc, dstTy, src); if (dstWidth < srcWidth) - return bld.create<mlir::LLVM::TruncOp>(loc, dstTy, src); - return bld.create<mlir::LLVM::BitcastOp>(loc, dstTy, src); + return mlir::LLVM::TruncOp::create(bld, loc, dstTy, src); + return mlir::LLVM::BitcastOp::create(bld, loc, dstTy, src); } static mlir::LLVM::Visibility @@ -204,12 +204,12 @@ static mlir::Value getLLVMIntCast(mlir::ConversionPatternRewriter &rewriter, auto loc = llvmSrc.getLoc(); if (cirSrcWidth < cirDstIntWidth) { if (isUnsigned) - return rewriter.create<mlir::LLVM::ZExtOp>(loc, llvmDstIntTy, llvmSrc); - return rewriter.create<mlir::LLVM::SExtOp>(loc, llvmDstIntTy, llvmSrc); + return mlir::LLVM::ZExtOp::create(rewriter, loc, llvmDstIntTy, llvmSrc); + return mlir::LLVM::SExtOp::create(rewriter, loc, llvmDstIntTy, llvmSrc); } // Otherwise truncate - return rewriter.create<mlir::LLVM::TruncOp>(loc, llvmDstIntTy, llvmSrc); + return mlir::LLVM::TruncOp::create(rewriter, loc, llvmDstIntTy, llvmSrc); } class CIRAttrToValue { @@ -315,15 +315,17 @@ static mlir::LLVM::CallIntrinsicOp replaceOpWithCallLLVMIntrinsicOp( /// IntAttr visitor. mlir::Value CIRAttrToValue::visitCirAttr(cir::IntAttr intAttr) { mlir::Location loc = parentOp->getLoc(); - return rewriter.create<mlir::LLVM::ConstantOp>( - loc, converter->convertType(intAttr.getType()), intAttr.getValue()); + return mlir::LLVM::ConstantOp::create( + rewriter, loc, converter->convertType(intAttr.getType()), + intAttr.getValue()); } /// FPAttr visitor. mlir::Value CIRAttrToValue::visitCirAttr(cir::FPAttr fltAttr) { mlir::Location loc = parentOp->getLoc(); - return rewriter.create<mlir::LLVM::ConstantOp>( - loc, converter->convertType(fltAttr.getType()), fltAttr.getValue()); + return mlir::LLVM::ConstantOp::create( + rewriter, loc, converter->convertType(fltAttr.getType()), + fltAttr.getValue()); } /// ConstComplexAttr visitor. @@ -350,8 +352,8 @@ mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstComplexAttr complexAttr) { } mlir::Location loc = parentOp->getLoc(); - return rewriter.create<mlir::LLVM::ConstantOp>( - loc, converter->convertType(complexAttr.getType()), + return mlir::LLVM::ConstantOp::create( + rewriter, loc, converter->convertType(complexAttr.getType()), rewriter.getArrayAttr(components)); } @@ -359,15 +361,16 @@ mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstComplexAttr complexAttr) { mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstPtrAttr ptrAttr) { mlir::Location loc = parentOp->getLoc(); if (ptrAttr.isNullValue()) { - return rewriter.create<mlir::LLVM::ZeroOp>( - loc, converter->convertType(ptrAttr.getType())); + return mlir::LLVM::ZeroOp::create( + rewriter, loc, converter->convertType(ptrAttr.getType())); } mlir::DataLayout layout(parentOp->getParentOfType<mlir::ModuleOp>()); - mlir::Value ptrVal = rewriter.create<mlir::LLVM::ConstantOp>( - loc, rewriter.getIntegerType(layout.getTypeSizeInBits(ptrAttr.getType())), + mlir::Value ptrVal = mlir::LLVM::ConstantOp::create( + rewriter, loc, + rewriter.getIntegerType(layout.getTypeSizeInBits(ptrAttr.getType())), ptrAttr.getValue().getInt()); - return rewriter.create<mlir::LLVM::IntToPtrOp>( - loc, converter->convertType(ptrAttr.getType()), ptrVal); + return mlir::LLVM::IntToPtrOp::create( + rewriter, loc, converter->convertType(ptrAttr.getType()), ptrVal); } // ConstArrayAttr visitor @@ -378,10 +381,10 @@ mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstArrayAttr attr) { if (attr.hasTrailingZeros()) { mlir::Type arrayTy = attr.getType(); - result = rewriter.create<mlir::LLVM::ZeroOp>( - loc, converter->convertType(arrayTy)); + result = mlir::LLVM::ZeroOp::create(rewriter, loc, + converter->convertType(arrayTy)); } else { - result = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmTy); + result = mlir::LLVM::UndefOp::create(rewriter, loc, llvmTy); } // Iteratively lower each constant element of the array. @@ -390,7 +393,7 @@ mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstArrayAttr attr) { mlir::DataLayout dataLayout(parentOp->getParentOfType<mlir::ModuleOp>()); mlir::Value init = visit(elt); result = - rewriter.create<mlir::LLVM::InsertValueOp>(loc, result, init, idx); + mlir::LLVM::InsertValueOp::create(rewriter, loc, result, init, idx); } } else if (auto strAttr = mlir::dyn_cast<mlir::StringAttr>(attr.getElts())) { // TODO(cir): this diverges from traditional lowering. Normally the string @@ -399,10 +402,10 @@ mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstArrayAttr attr) { assert(arrayTy && "String attribute must have an array type"); mlir::Type eltTy = arrayTy.getElementType(); for (auto [idx, elt] : llvm::enumerate(strAttr)) { - auto init = rewriter.create<mlir::LLVM::ConstantOp>( - loc, converter->convertType(eltTy), elt); + auto init = mlir::LLVM::ConstantOp::create( + rewriter, loc, converter->convertType(eltTy), elt); result = - rewriter.create<mlir::LLVM::InsertValueOp>(loc, result, init, idx); + mlir::LLVM::InsertValueOp::create(rewriter, loc, result, init, idx); } } else { llvm_unreachable("unexpected ConstArrayAttr elements"); @@ -415,12 +418,13 @@ mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstArrayAttr attr) { mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstRecordAttr constRecord) { const mlir::Type llvmTy = converter->convertType(constRecord.getType()); const mlir::Location loc = parentOp->getLoc(); - mlir::Value result = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmTy); + mlir::Value result = mlir::LLVM::UndefOp::create(rewriter, loc, llvmTy); // Iteratively lower each constant element of the record. for (auto [idx, elt] : llvm::enumerate(constRecord.getMembers())) { mlir::Value init = visit(elt); - result = rewriter.create<mlir::LLVM::InsertValueOp>(loc, result, init, idx); + result = + mlir::LLVM::InsertValueOp::create(rewriter, loc, result, init, idx); } return result; @@ -447,8 +451,8 @@ mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstVectorAttr attr) { mlirValues.push_back(mlirAttr); } - return rewriter.create<mlir::LLVM::ConstantOp>( - loc, llvmTy, + return mlir::LLVM::ConstantOp::create( + rewriter, loc, llvmTy, mlir::DenseElementsAttr::get(mlir::cast<mlir::ShapedType>(llvmTy), mlirValues)); } @@ -483,8 +487,9 @@ mlir::Value CIRAttrToValue::visitCirAttr(cir::GlobalViewAttr globalAttr) { } mlir::Location loc = parentOp->getLoc(); - mlir::Value addrOp = rewriter.create<mlir::LLVM::AddressOfOp>( - loc, mlir::LLVM::LLVMPointerType::get(rewriter.getContext()), symName); + mlir::Value addrOp = mlir::LLVM::AddressOfOp::create( + rewriter, loc, mlir::LLVM::LLVMPointerType::get(rewriter.getContext()), + symName); if (globalAttr.getIndices()) { llvm::SmallVector<mlir::LLVM::GEPArg> indices; @@ -499,8 +504,9 @@ mlir::Value CIRAttrToValue::visitCirAttr(cir::GlobalViewAttr globalAttr) { } mlir::Type resTy = addrOp.getType(); mlir::Type eltTy = converter->convertType(sourceType); - addrOp = rewriter.create<mlir::LLVM::GEPOp>( - loc, resTy, eltTy, addrOp, indices, mlir::LLVM::GEPNoWrapFlags::none); + addrOp = + mlir::LLVM::GEPOp::create(rewriter, loc, resTy, eltTy, addrOp, indices, + mlir::LLVM::GEPNoWrapFlags::none); } // The incubator has handling here for the attribute having integer type, but @@ -517,8 +523,8 @@ mlir::Value CIRAttrToValue::visitCirAttr(cir::GlobalViewAttr globalAttr) { return addrOp; mlir::Type llvmDstTy = converter->convertType(globalAttr.getType()); - return rewriter.create<mlir::LLVM::BitcastOp>(parentOp->getLoc(), llvmDstTy, - addrOp); + return mlir::LLVM::BitcastOp::create(rewriter, parentOp->getLoc(), + llvmDstTy, addrOp); } llvm_unreachable("Expecting pointer or integer type for GlobalViewAttr"); @@ -557,8 +563,8 @@ mlir::Value CIRAttrToValue::visitCirAttr(cir::VTableAttr vtableArr) { /// ZeroAttr visitor. mlir::Value CIRAttrToValue::visitCirAttr(cir::ZeroAttr attr) { mlir::Location loc = parentOp->getLoc(); - return rewriter.create<mlir::LLVM::ZeroOp>( - loc, converter->convertType(attr.getType())); + return mlir::LLVM::ZeroOp::create(rewriter, loc, + converter->convertType(attr.getType())); } // This class handles rewriting initializer attributes for types that do not @@ -666,8 +672,8 @@ mlir::LogicalResult CIRToLLVMAssumeAlignedOpLowering::matchAndRewrite( mlir::LogicalResult CIRToLLVMAssumeSepStorageOpLowering::matchAndRewrite( cir::AssumeSepStorageOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { - auto cond = rewriter.create<mlir::LLVM::ConstantOp>(op.getLoc(), - rewriter.getI1Type(), 1); + auto cond = mlir::LLVM::ConstantOp::create(rewriter, op.getLoc(), + rewriter.getI1Type(), 1); rewriter.replaceOpWithNewOp<mlir::LLVM::AssumeOp>( op, cond, mlir::LLVM::AssumeSeparateStorageTag{}, adaptor.getPtr1(), adaptor.getPtr2()); @@ -914,28 +920,28 @@ mlir::LogicalResult CIRToLLVMAtomicFetchOpLowering::matchAndRewrite( mlir::LogicalResult CIRToLLVMBitClrsbOpLowering::matchAndRewrite( cir::BitClrsbOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { - auto zero = rewriter.create<mlir::LLVM::ConstantOp>( - op.getLoc(), adaptor.getInput().getType(), 0); - auto isNeg = rewriter.create<mlir::LLVM::ICmpOp>( - op.getLoc(), + auto zero = mlir::LLVM::ConstantOp::create(rewriter, op.getLoc(), + adaptor.getInput().getType(), 0); + auto isNeg = mlir::LLVM::ICmpOp::create( + rewriter, op.getLoc(), mlir::LLVM::ICmpPredicateAttr::get(rewriter.getContext(), mlir::LLVM::ICmpPredicate::slt), adaptor.getInput(), zero); - auto negOne = rewriter.create<mlir::LLVM::ConstantOp>( - op.getLoc(), adaptor.getInput().getType(), -1); - auto flipped = rewriter.create<mlir::LLVM::XOrOp>(op.getLoc(), - adaptor.getInput(), negOne); + auto negOne = mlir::LLVM::ConstantOp::create( + rewriter, op.getLoc(), adaptor.getInput().getType(), -1); + auto flipped = mlir::LLVM::XOrOp::create(rewriter, op.getLoc(), + adaptor.getInput(), negOne); - auto select = rewriter.create<mlir::LLVM::SelectOp>( - op.getLoc(), isNeg, flipped, adaptor.getInput()); + auto select = mlir::LLVM::SelectOp::create(rewriter, op.getLoc(), isNeg, + flipped, adaptor.getInput()); auto resTy = getTypeConverter()->convertType(op.getType()); - auto clz = rewriter.create<mlir::LLVM::CountLeadingZerosOp>( - op.getLoc(), resTy, select, /*is_zero_poison=*/false); + auto clz = mlir::LLVM::CountLeadingZerosOp::create( + rewriter, op.getLoc(), resTy, select, /*is_zero_poison=*/false); - auto one = rewriter.create<mlir::LLVM::ConstantOp>(op.getLoc(), resTy, 1); - auto res = rewriter.create<mlir::LLVM::SubOp>(op.getLoc(), clz, one); + auto one = mlir::LLVM::ConstantOp::create(rewriter, op.getLoc(), resTy, 1); + auto res = mlir::LLVM::SubOp::create(rewriter, op.getLoc(), clz, one); rewriter.replaceOp(op, res); return mlir::LogicalResult::success(); @@ -945,8 +951,8 @@ mlir::LogicalResult CIRToLLVMBitClzOpLowering::matchAndRewrite( cir::BitClzOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { auto resTy = getTypeConverter()->convertType(op.getType()); - auto llvmOp = rewriter.create<mlir::LLVM::CountLeadingZerosOp>( - op.getLoc(), resTy, adaptor.getInput(), op.getPoisonZero()); + auto llvmOp = mlir::LLVM::CountLeadingZerosOp::create( + rewriter, op.getLoc(), resTy, adaptor.getInput(), op.getPoisonZero()); rewriter.replaceOp(op, llvmOp); return mlir::LogicalResult::success(); } @@ -955,8 +961,8 @@ mlir::LogicalResult CIRToLLVMBitCtzOpLowering::matchAndRewrite( cir::BitCtzOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { auto resTy = getTypeConverter()->convertType(op.getType()); - auto llvmOp = rewriter.create<mlir::LLVM::CountTrailingZerosOp>( - op.getLoc(), resTy, adaptor.getInput(), op.getPoisonZero()); + auto llvmOp = mlir::LLVM::CountTrailingZerosOp::create( + rewriter, op.getLoc(), resTy, adaptor.getInput(), op.getPoisonZero()); rewriter.replaceOp(op, llvmOp); return mlir::LogicalResult::success(); } @@ -965,23 +971,24 @@ mlir::LogicalResult CIRToLLVMBitFfsOpLowering::matchAndRewrite( cir::BitFfsOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { auto resTy = getTypeConverter()->convertType(op.getType()); - auto ctz = rewriter.create<mlir::LLVM::CountTrailingZerosOp>( - op.getLoc(), resTy, adaptor.getInput(), /*is_zero_poison=*/true); + auto ctz = mlir::LLVM::CountTrailingZerosOp::create(rewriter, op.getLoc(), + resTy, adaptor.getInput(), + /*is_zero_poison=*/true); - auto one = rewriter.create<mlir::LLVM::ConstantOp>(op.getLoc(), resTy, 1); - auto ctzAddOne = rewriter.create<mlir::LLVM::AddOp>(op.getLoc(), ctz, one); + auto one = mlir::LLVM::ConstantOp::create(rewriter, op.getLoc(), resTy, 1); + auto ctzAddOne = mlir::LLVM::AddOp::create(rewriter, op.getLoc(), ctz, one); - auto zeroInputTy = rewriter.create<mlir::LLVM::ConstantOp>( - op.getLoc(), adaptor.getInput().getType(), 0); - auto isZero = rewriter.create<mlir::LLVM::ICmpOp>( - op.getLoc(), + auto zeroInputTy = mlir::LLVM::ConstantOp::create( + rewriter, op.getLoc(), adaptor.getInput().getType(), 0); + auto isZero = mlir::LLVM::ICmpOp::create( + rewriter, op.getLoc(), mlir::LLVM::ICmpPredicateAttr::get(rewriter.getContext(), mlir::LLVM::ICmpPredicate::eq), adaptor.getInput(), zeroInputTy); - auto zero = rewriter.create<mlir::LLVM::ConstantOp>(op.getLoc(), resTy, 0); - auto res = rewriter.create<mlir::LLVM::SelectOp>(op.getLoc(), isZero, zero, - ctzAddOne); + auto zero = mlir::LLVM::ConstantOp::create(rewriter, op.getLoc(), resTy, 0); + auto res = mlir::LLVM::SelectOp::create(rewriter, op.getLoc(), isZero, zero, + ctzAddOne); rewriter.replaceOp(op, res); return mlir::LogicalResult::success(); @@ -991,12 +998,12 @@ mlir::LogicalResult CIRToLLVMBitParityOpLowering::matchAndRewrite( cir::BitParityOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { auto resTy = getTypeConverter()->convertType(op.getType()); - auto popcnt = rewriter.create<mlir::LLVM::CtPopOp>(op.getLoc(), resTy, - adaptor.getInput()); + auto popcnt = mlir::LLVM::CtPopOp::create(rewriter, op.getLoc(), resTy, + adaptor.getInput()); - auto one = rewriter.create<mlir::LLVM::ConstantOp>(op.getLoc(), resTy, 1); + auto one = mlir::LLVM::ConstantOp::create(rewriter, op.getLoc(), resTy, 1); auto popcntMod2 = - rewriter.create<mlir::LLVM::AndOp>(op.getLoc(), popcnt, one); + mlir::LLVM::AndOp::create(rewriter, op.getLoc(), popcnt, one); rewriter.replaceOp(op, popcntMod2); return mlir::LogicalResult::success(); @@ -1006,8 +1013,8 @@ mlir::LogicalResult CIRToLLVMBitPopcountOpLowering::matchAndRewrite( cir::BitPopcountOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { auto resTy = getTypeConverter()->convertType(op.getType()); - auto llvmOp = rewriter.create<mlir::LLVM::CtPopOp>(op.getLoc(), resTy, - adaptor.getInput()); + auto llvmOp = mlir::LLVM::CtPopOp::create(rewriter, op.getLoc(), resTy, + adaptor.getInput()); rewriter.replaceOp(op, llvmOp); return mlir::LogicalResult::success(); } @@ -1067,8 +1074,8 @@ mlir::LogicalResult CIRToLLVMCastOpLowering::matchAndRewrite( } case cir::CastKind::int_to_bool: { mlir::Value llvmSrcVal = adaptor.getSrc(); - mlir::Value zeroInt = rewriter.create<mlir::LLVM::ConstantOp>( - castOp.getLoc(), llvmSrcVal.getType(), 0); + mlir::Value zeroInt = mlir::LLVM::ConstantOp::create( + rewriter, castOp.getLoc(), llvmSrcVal.getType(), 0); rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>( castOp, mlir::LLVM::ICmpPredicate::ne, llvmSrcVal, zeroInt); break; @@ -1132,8 +1139,8 @@ mlir::LogicalResult CIRToLLVMCastOpLowering::matchAndRewrite( auto kind = mlir::LLVM::FCmpPredicate::une; // Check if float is not equal to zero. - auto zeroFloat = rewriter.create<mlir::LLVM::ConstantOp>( - castOp.getLoc(), llvmSrcVal.getType(), + auto zeroFloat = mlir::LLVM::ConstantOp::create( + rewriter, castOp.getLoc(), llvmSrcVal.getType(), mlir::FloatAttr::get(llvmSrcVal.getType(), 0.0)); // Extend comparison result to either bool (C++) or int (C). @@ -1204,8 +1211,8 @@ mlir::LogicalResult CIRToLLVMCastOpLowering::matchAndRewrite( } case cir::CastKind::ptr_to_bool: { mlir::Value llvmSrcVal = adaptor.getSrc(); - mlir::Value zeroPtr = rewriter.create<mlir::LLVM::ZeroOp>( - castOp.getLoc(), llvmSrcVal.getType()); + mlir::Value zeroPtr = mlir::LLVM::ZeroOp::create(rewriter, castOp.getLoc(), + llvmSrcVal.getType()); rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>( castOp, mlir::LLVM::ICmpPredicate::ne, llvmSrcVal, zeroPtr); break; @@ -1275,10 +1282,10 @@ mlir::LogicalResult CIRToLLVMPtrStrideOpLowering::matchAndRewrite( // Rewrite the sub in front of extensions/trunc if (rewriteSub) { - index = rewriter.create<mlir::LLVM::SubOp>( - index.getLoc(), index.getType(), - rewriter.create<mlir::LLVM::ConstantOp>(index.getLoc(), - index.getType(), 0), + index = mlir::LLVM::SubOp::create( + rewriter, index.getLoc(), index.getType(), + mlir::LLVM::ConstantOp::create(rewriter, index.getLoc(), + index.getType(), 0), index); rewriter.eraseOp(sub); } @@ -1310,11 +1317,11 @@ mlir::LogicalResult CIRToLLVMBaseClassAddrOpLowering::matchAndRewrite( baseClassOp, resultType, byteType, derivedAddr, offset); } else { auto loc = baseClassOp.getLoc(); - mlir::Value isNull = rewriter.create<mlir::LLVM::ICmpOp>( - loc, mlir::LLVM::ICmpPredicate::eq, derivedAddr, - rewriter.create<mlir::LLVM::ZeroOp>(loc, derivedAddr.getType())); - mlir::Value adjusted = rewriter.create<mlir::LLVM::GEPOp>( - loc, resultType, byteType, derivedAddr, offset); + mlir::Value isNull = mlir::LLVM::ICmpOp::create( + rewriter, loc, mlir::LLVM::ICmpPredicate::eq, derivedAddr, + mlir::LLVM::ZeroOp::create(rewriter, loc, derivedAddr.getType())); + mlir::Value adjusted = mlir::LLVM::GEPOp::create( + rewriter, loc, resultType, byteType, derivedAddr, offset); rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>(baseClassOp, isNull, derivedAddr, adjusted); } @@ -1335,8 +1342,8 @@ mlir::LogicalResult CIRToLLVMAllocaOpLowering::matchAndRewrite( mlir::Value size = op.isDynamic() ? adaptor.getDynAllocSize() - : rewriter.create<mlir::LLVM::ConstantOp>( - op.getLoc(), + : mlir::LLVM::ConstantOp::create( + rewriter, op.getLoc(), typeConverter->convertType(rewriter.getIndexType()), 1); mlir::Type elementTy = convertTypeForMemory(*getTypeConverter(), dataLayout, op.getAllocaType()); @@ -1694,13 +1701,13 @@ mlir::LogicalResult CIRToLLVMPtrDiffOpLowering::matchAndRewrite( auto dstTy = mlir::cast<cir::IntType>(op.getType()); mlir::Type llvmDstTy = getTypeConverter()->convertType(dstTy); - auto lhs = rewriter.create<mlir::LLVM::PtrToIntOp>(op.getLoc(), llvmDstTy, - adaptor.getLhs()); - auto rhs = rewriter.create<mlir::LLVM::PtrToIntOp>(op.getLoc(), llvmDstTy, - adaptor.getRhs()); + auto lhs = mlir::LLVM::PtrToIntOp::create(rewriter, op.getLoc(), llvmDstTy, + adaptor.getLhs()); + auto rhs = mlir::LLVM::PtrToIntOp::create(rewriter, op.getLoc(), llvmDstTy, + adaptor.getRhs()); auto diff = - rewriter.create<mlir::LLVM::SubOp>(op.getLoc(), llvmDstTy, lhs, rhs); + mlir::LLVM::SubOp::create(rewriter, op.getLoc(), llvmDstTy, lhs, rhs); cir::PointerType ptrTy = op.getLhs().getType(); assert(!cir::MissingFeatures::llvmLoweringPtrDiffConsidersPointee()); @@ -1709,17 +1716,17 @@ mlir::LogicalResult CIRToLLVMPtrDiffOpLowering::matchAndRewrite( // Avoid silly division by 1. mlir::Value resultVal = diff.getResult(); if (typeSize != 1) { - auto typeSizeVal = rewriter.create<mlir::LLVM::ConstantOp>( - op.getLoc(), llvmDstTy, typeSize); + auto typeSizeVal = mlir::LLVM::ConstantOp::create(rewriter, op.getLoc(), + llvmDstTy, typeSize); if (dstTy.isUnsigned()) { auto uDiv = - rewriter.create<mlir::LLVM::UDivOp>(op.getLoc(), diff, typeSizeVal); + mlir::LLVM::UDivOp::create(rewriter, op.getLoc(), diff, typeSizeVal); uDiv.setIsExact(true); resultVal = uDiv.getResult(); } else { auto sDiv = - rewriter.create<mlir::LLVM::SDivOp>(op.getLoc(), diff, typeSizeVal); + mlir::LLVM::SDivOp::create(rewriter, op.getLoc(), diff, typeSizeVal); sDiv.setIsExact(true); resultVal = sDiv.getResult(); } @@ -1847,8 +1854,8 @@ mlir::LogicalResult CIRToLLVMFuncOpLowering::matchAndRewrite( SmallVector<mlir::NamedAttribute, 4> attributes; lowerFuncAttributes(op, /*filterArgAndResAttrs=*/false, attributes); - mlir::LLVM::LLVMFuncOp fn = rewriter.create<mlir::LLVM::LLVMFuncOp>( - loc, op.getName(), llvmFnTy, linkage, isDsoLocal, cconv, + mlir::LLVM::LLVMFuncOp fn = mlir::LLVM::LLVMFuncOp::create( + rewriter, loc, op.getName(), llvmFnTy, linkage, isDsoLocal, cconv, mlir::SymbolRefAttr(), attributes); assert(!cir::MissingFeatures::opFuncMultipleReturnVals()); @@ -1884,8 +1891,8 @@ mlir::LogicalResult CIRToLLVMGetGlobalOpLowering::matchAndRewrite( } mlir::Type type = getTypeConverter()->convertType(op.getType()); - mlir::Operation *newop = - rewriter.create<mlir::LLVM::AddressOfOp>(op.getLoc(), type, op.getName()); + mlir::Operation *newop = mlir::LLVM::AddressOfOp::create( + rewriter, op.getLoc(), type, op.getName()); assert(!cir::MissingFeatures::opGlobalThreadLocal()); @@ -1941,7 +1948,7 @@ CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal( setupRegionInitializedLLVMGlobalOp(op, rewriter); CIRAttrToValue valueConverter(op, rewriter, typeConverter); mlir::Value value = valueConverter.visit(init); - rewriter.create<mlir::LLVM::ReturnOp>(loc, value); + mlir::LLVM::ReturnOp::create(rewriter, loc, value); return mlir::success(); } @@ -2094,14 +2101,14 @@ mlir::LogicalResult CIRToLLVMUnaryOpLowering::matchAndRewrite( switch (op.getKind()) { case cir::UnaryOpKind::Inc: { assert(!isVector && "++ not allowed on vector types"); - auto one = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, 1); + auto one = mlir::LLVM::ConstantOp::create(rewriter, loc, llvmType, 1); rewriter.replaceOpWithNewOp<mlir::LLVM::AddOp>( op, llvmType, adaptor.getInput(), one, maybeNSW); return mlir::success(); } case cir::UnaryOpKind::Dec: { assert(!isVector && "-- not allowed on vector types"); - auto one = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, 1); + auto one = mlir::LLVM::ConstantOp::create(rewriter, loc, llvmType, 1); rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>(op, adaptor.getInput(), one, maybeNSW); return mlir::success(); @@ -2112,9 +2119,9 @@ mlir::LogicalResult CIRToLLVMUnaryOpLowering::matchAndRewrite( case cir::UnaryOpKind::Minus: { mlir::Value zero; if (isVector) - zero = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmType); + zero = mlir::LLVM::ZeroOp::create(rewriter, loc, llvmType); else - zero = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, 0); + zero = mlir::LLVM::ConstantOp::create(rewriter, loc, llvmType, 0); rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>( op, zero, adaptor.getInput(), maybeNSW); return mlir::success(); @@ -2128,9 +2135,9 @@ mlir::LogicalResult CIRToLLVMUnaryOpLowering::matchAndRewrite( std::vector<int32_t> values(numElements, -1); mlir::DenseIntElementsAttr denseVec = rewriter.getI32VectorAttr(values); minusOne = - rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, denseVec); + mlir::LLVM::ConstantOp::create(rewriter, loc, llvmType, denseVec); } else { - minusOne = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, -1); + minusOne = mlir::LLVM::ConstantOp::create(rewriter, loc, llvmType, -1); } rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(op, adaptor.getInput(), minusOne); @@ -2145,16 +2152,16 @@ mlir::LogicalResult CIRToLLVMUnaryOpLowering::matchAndRewrite( switch (op.getKind()) { case cir::UnaryOpKind::Inc: { assert(!isVector && "++ not allowed on vector types"); - mlir::LLVM::ConstantOp one = rewriter.create<mlir::LLVM::ConstantOp>( - loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0)); + mlir::LLVM::ConstantOp one = mlir::LLVM::ConstantOp::create( + rewriter, loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0)); rewriter.replaceOpWithNewOp<mlir::LLVM::FAddOp>(op, llvmType, one, adaptor.getInput()); return mlir::success(); } case cir::UnaryOpKind::Dec: { assert(!isVector && "-- not allowed on vector types"); - mlir::LLVM::ConstantOp minusOne = rewriter.create<mlir::LLVM::ConstantOp>( - loc, llvmType, rewriter.getFloatAttr(llvmType, -1.0)); + mlir::LLVM::ConstantOp minusOne = mlir::LLVM::ConstantOp::create( + rewriter, loc, llvmType, rewriter.getFloatAttr(llvmType, -1.0)); rewriter.replaceOpWithNewOp<mlir::LLVM::FAddOp>(op, llvmType, minusOne, adaptor.getInput()); return mlir::success(); @@ -2185,7 +2192,7 @@ mlir::LogicalResult CIRToLLVMUnaryOpLowering::matchAndRewrite( return op.emitError() << "Unsupported unary operation on boolean type"; case cir::UnaryOpKind::Not: { assert(!isVector && "NYI: op! on vector mask"); - auto one = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, 1); + auto one = mlir::LLVM::ConstantOp::create(rewriter, loc, llvmType, 1); rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(op, adaptor.getInput(), one); return mlir::success(); @@ -2404,6 +2411,15 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite( return mlir::success(); } + if (auto vptrTy = mlir::dyn_cast<cir::VPtrType>(type)) { + // !cir.vptr is a special case, but it's just a pointer to LLVM. + auto kind = convertCmpKindToICmpPredicate(cmpOp.getKind(), + /* isSigned=*/false); + rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>( + cmpOp, kind, adaptor.getLhs(), adaptor.getRhs()); + return mlir::success(); + } + if (mlir::isa<cir::FPTypeInterface>(type)) { mlir::LLVM::FCmpPredicate kind = convertCmpKindToFCmpPredicate(cmpOp.getKind()); @@ -2421,47 +2437,47 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite( mlir::Type complexElemTy = getTypeConverter()->convertType(complexType.getElementType()); - auto lhsReal = - rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 0); - auto lhsImag = - rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 1); - auto rhsReal = - rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 0); - auto rhsImag = - rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 1); + auto lhsReal = mlir::LLVM::ExtractValueOp::create( + rewriter, loc, complexElemTy, lhs, ArrayRef(int64_t{0})); + auto lhsImag = mlir::LLVM::ExtractValueOp::create( + rewriter, loc, complexElemTy, lhs, ArrayRef(int64_t{1})); + auto rhsReal = mlir::LLVM::ExtractValueOp::create( + rewriter, loc, complexElemTy, rhs, ArrayRef(int64_t{0})); + auto rhsImag = mlir::LLVM::ExtractValueOp::create( + rewriter, loc, complexElemTy, rhs, ArrayRef(int64_t{1})); if (cmpOp.getKind() == cir::CmpOpKind::eq) { if (complexElemTy.isInteger()) { - auto realCmp = rewriter.create<mlir::LLVM::ICmpOp>( - loc, mlir::LLVM::ICmpPredicate::eq, lhsReal, rhsReal); - auto imagCmp = rewriter.create<mlir::LLVM::ICmpOp>( - loc, mlir::LLVM::ICmpPredicate::eq, lhsImag, rhsImag); + auto realCmp = mlir::LLVM::ICmpOp::create( + rewriter, loc, mlir::LLVM::ICmpPredicate::eq, lhsReal, rhsReal); + auto imagCmp = mlir::LLVM::ICmpOp::create( + rewriter, loc, mlir::LLVM::ICmpPredicate::eq, lhsImag, rhsImag); rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(cmpOp, realCmp, imagCmp); return mlir::success(); } - auto realCmp = rewriter.create<mlir::LLVM::FCmpOp>( - loc, mlir::LLVM::FCmpPredicate::oeq, lhsReal, rhsReal); - auto imagCmp = rewriter.create<mlir::LLVM::FCmpOp>( - loc, mlir::LLVM::FCmpPredicate::oeq, lhsImag, rhsImag); + auto realCmp = mlir::LLVM::FCmpOp::create( + rewriter, loc, mlir::LLVM::FCmpPredicate::oeq, lhsReal, rhsReal); + auto imagCmp = mlir::LLVM::FCmpOp::create( + rewriter, loc, mlir::LLVM::FCmpPredicate::oeq, lhsImag, rhsImag); rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(cmpOp, realCmp, imagCmp); return mlir::success(); } if (cmpOp.getKind() == cir::CmpOpKind::ne) { if (complexElemTy.isInteger()) { - auto realCmp = rewriter.create<mlir::LLVM::ICmpOp>( - loc, mlir::LLVM::ICmpPredicate::ne, lhsReal, rhsReal); - auto imagCmp = rewriter.create<mlir::LLVM::ICmpOp>( - loc, mlir::LLVM::ICmpPredicate::ne, lhsImag, rhsImag); + auto realCmp = mlir::LLVM::ICmpOp::create( + rewriter, loc, mlir::LLVM::ICmpPredicate::ne, lhsReal, rhsReal); + auto imagCmp = mlir::LLVM::ICmpOp::create( + rewriter, loc, mlir::LLVM::ICmpPredicate::ne, lhsImag, rhsImag); rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(cmpOp, realCmp, imagCmp); return mlir::success(); } - auto realCmp = rewriter.create<mlir::LLVM::FCmpOp>( - loc, mlir::LLVM::FCmpPredicate::une, lhsReal, rhsReal); - auto imagCmp = rewriter.create<mlir::LLVM::FCmpOp>( - loc, mlir::LLVM::FCmpPredicate::une, lhsImag, rhsImag); + auto realCmp = mlir::LLVM::FCmpOp::create( + rewriter, loc, mlir::LLVM::FCmpPredicate::une, lhsReal, rhsReal); + auto imagCmp = mlir::LLVM::FCmpOp::create( + rewriter, loc, mlir::LLVM::FCmpPredicate::une, lhsImag, rhsImag); rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(cmpOp, realCmp, imagCmp); return mlir::success(); } @@ -2725,7 +2741,7 @@ static void buildCtorDtorList( index); } - builder.create<mlir::LLVM::ReturnOp>(loc, result); + mlir::LLVM::ReturnOp::create(builder, loc, result); } // The applyPartialConversion function traverses blocks in the dominance order, @@ -2904,7 +2920,7 @@ void createLLVMFuncOpIfNotExist(mlir::ConversionPatternRewriter &rewriter, if (!sourceSymbol) { mlir::OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(enclosingFnOp); - rewriter.create<mlir::LLVM::LLVMFuncOp>(srcOp->getLoc(), fnName, fnTy); + mlir::LLVM::LLVMFuncOp::create(rewriter, srcOp->getLoc(), fnName, fnTy); } } @@ -2983,12 +2999,12 @@ mlir::LogicalResult CIRToLLVMTrapOpLowering::matchAndRewrite( mlir::Location loc = op->getLoc(); rewriter.eraseOp(op); - rewriter.create<mlir::LLVM::Trap>(loc); + mlir::LLVM::Trap::create(rewriter, loc); // Note that the call to llvm.trap is not a terminator in LLVM dialect. // So we must emit an additional llvm.unreachable to terminate the current // block. - rewriter.create<mlir::LLVM::UnreachableOp>(loc); + mlir::LLVM::UnreachableOp::create(rewriter, loc); return mlir::success(); } @@ -3114,15 +3130,15 @@ mlir::LogicalResult CIRToLLVMVecCreateOpLowering::matchAndRewrite( const auto vecTy = mlir::cast<cir::VectorType>(op.getType()); const mlir::Type llvmTy = typeConverter->convertType(vecTy); const mlir::Location loc = op.getLoc(); - mlir::Value result = rewriter.create<mlir::LLVM::PoisonOp>(loc, llvmTy); + mlir::Value result = mlir::LLVM::PoisonOp::create(rewriter, loc, llvmTy); assert(vecTy.getSize() == op.getElements().size() && "cir.vec.create op count doesn't match vector type elements count"); for (uint64_t i = 0; i < vecTy.getSize(); ++i) { const mlir::Value indexValue = - rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), i); - result = rewriter.create<mlir::LLVM::InsertElementOp>( - loc, result, adaptor.getElements()[i], indexValue); + mlir::LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), i); + result = mlir::LLVM::InsertElementOp::create( + rewriter, loc, result, adaptor.getElements()[i], indexValue); } rewriter.replaceOp(op, result); @@ -3151,13 +3167,13 @@ mlir::LogicalResult CIRToLLVMVecCmpOpLowering::matchAndRewrite( mlir::Type elementType = elementTypeIfVector(op.getLhs().getType()); mlir::Value bitResult; if (auto intType = mlir::dyn_cast<cir::IntType>(elementType)) { - bitResult = rewriter.create<mlir::LLVM::ICmpOp>( - op.getLoc(), + bitResult = mlir::LLVM::ICmpOp::create( + rewriter, op.getLoc(), convertCmpKindToICmpPredicate(op.getKind(), intType.isSigned()), adaptor.getLhs(), adaptor.getRhs()); } else if (mlir::isa<cir::FPTypeInterface>(elementType)) { - bitResult = rewriter.create<mlir::LLVM::FCmpOp>( - op.getLoc(), convertCmpKindToFCmpPredicate(op.getKind()), + bitResult = mlir::LLVM::FCmpOp::create( + rewriter, op.getLoc(), convertCmpKindToFCmpPredicate(op.getKind()), adaptor.getLhs(), adaptor.getRhs()); } else { return op.emitError() << "unsupported type for VecCmpOp: " << elementType; @@ -3181,7 +3197,7 @@ mlir::LogicalResult CIRToLLVMVecSplatOpLowering::matchAndRewrite( cir::VectorType vecTy = op.getType(); mlir::Type llvmTy = typeConverter->convertType(vecTy); mlir::Location loc = op.getLoc(); - mlir::Value poison = rewriter.create<mlir::LLVM::PoisonOp>(loc, llvmTy); + mlir::Value poison = mlir::LLVM::PoisonOp::create(rewriter, loc, llvmTy); mlir::Value elementValue = adaptor.getValue(); if (elementValue.getDefiningOp<mlir::LLVM::PoisonOp>()) { @@ -3210,9 +3226,9 @@ mlir::LogicalResult CIRToLLVMVecSplatOpLowering::matchAndRewrite( } mlir::Value indexValue = - rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), 0); - mlir::Value oneElement = rewriter.create<mlir::LLVM::InsertElementOp>( - loc, poison, elementValue, indexValue); + mlir::LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), 0); + mlir::Value oneElement = mlir::LLVM::InsertElementOp::create( + rewriter, loc, poison, elementValue, indexValue); SmallVector<int32_t> zeroValues(vecTy.getSize(), 0); rewriter.replaceOpWithNewOp<mlir::LLVM::ShuffleVectorOp>(op, oneElement, poison, zeroValues); @@ -3260,31 +3276,32 @@ mlir::LogicalResult CIRToLLVMVecShuffleDynamicOpLowering::matchAndRewrite( mlir::cast<cir::VectorType>(op.getVec().getType()).getSize(); uint64_t maskBits = llvm::NextPowerOf2(numElements - 1) - 1; - mlir::Value maskValue = rewriter.create<mlir::LLVM::ConstantOp>( - loc, llvmIndexType, rewriter.getIntegerAttr(llvmIndexType, maskBits)); + mlir::Value maskValue = mlir::LLVM::ConstantOp::create( + rewriter, loc, llvmIndexType, + rewriter.getIntegerAttr(llvmIndexType, maskBits)); mlir::Value maskVector = - rewriter.create<mlir::LLVM::UndefOp>(loc, llvmIndexVecType); + mlir::LLVM::UndefOp::create(rewriter, loc, llvmIndexVecType); for (uint64_t i = 0; i < numElements; ++i) { mlir::Value idxValue = - rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), i); - maskVector = rewriter.create<mlir::LLVM::InsertElementOp>( - loc, maskVector, maskValue, idxValue); + mlir::LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), i); + maskVector = mlir::LLVM::InsertElementOp::create(rewriter, loc, maskVector, + maskValue, idxValue); } - mlir::Value maskedIndices = rewriter.create<mlir::LLVM::AndOp>( - loc, llvmIndexVecType, adaptor.getIndices(), maskVector); - mlir::Value result = rewriter.create<mlir::LLVM::UndefOp>( - loc, getTypeConverter()->convertType(op.getVec().getType())); + mlir::Value maskedIndices = mlir::LLVM::AndOp::create( + rewriter, loc, llvmIndexVecType, adaptor.getIndices(), maskVector); + mlir::Value result = mlir::LLVM::UndefOp::create( + rewriter, loc, getTypeConverter()->convertType(op.getVec().getType())); for (uint64_t i = 0; i < numElements; ++i) { mlir::Value iValue = - rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), i); - mlir::Value indexValue = rewriter.create<mlir::LLVM::ExtractElementOp>( - loc, maskedIndices, iValue); + mlir::LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), i); + mlir::Value indexValue = mlir::LLVM::ExtractElementOp::create( + rewriter, loc, maskedIndices, iValue); mlir::Value valueAtIndex = - rewriter.create<mlir::LLVM::ExtractElementOp>(loc, input, indexValue); - result = rewriter.create<mlir::LLVM::InsertElementOp>(loc, result, - valueAtIndex, iValue); + mlir::LLVM::ExtractElementOp::create(rewriter, loc, input, indexValue); + result = mlir::LLVM::InsertElementOp::create(rewriter, loc, result, + valueAtIndex, iValue); } rewriter.replaceOp(op, result); return mlir::success(); @@ -3294,10 +3311,10 @@ mlir::LogicalResult CIRToLLVMVecTernaryOpLowering::matchAndRewrite( cir::VecTernaryOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { // Convert `cond` into a vector of i1, then use that in a `select` op. - mlir::Value bitVec = rewriter.create<mlir::LLVM::ICmpOp>( - op.getLoc(), mlir::LLVM::ICmpPredicate::ne, adaptor.getCond(), - rewriter.create<mlir::LLVM::ZeroOp>( - op.getCond().getLoc(), + mlir::Value bitVec = mlir::LLVM::ICmpOp::create( + rewriter, op.getLoc(), mlir::LLVM::ICmpPredicate::ne, adaptor.getCond(), + mlir::LLVM::ZeroOp::create( + rewriter, op.getCond().getLoc(), typeConverter->convertType(op.getCond().getType()))); rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>( op, bitVec, adaptor.getLhs(), adaptor.getRhs()); @@ -3314,41 +3331,41 @@ mlir::LogicalResult CIRToLLVMComplexAddOpLowering::matchAndRewrite( auto complexType = mlir::cast<cir::ComplexType>(op.getLhs().getType()); mlir::Type complexElemTy = getTypeConverter()->convertType(complexType.getElementType()); - auto lhsReal = - rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 0); - auto lhsImag = - rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 1); - auto rhsReal = - rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 0); - auto rhsImag = - rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 1); + auto lhsReal = mlir::LLVM::ExtractValueOp::create( + rewriter, loc, complexElemTy, lhs, ArrayRef(int64_t{0})); + auto lhsImag = mlir::LLVM::ExtractValueOp::create( + rewriter, loc, complexElemTy, lhs, ArrayRef(int64_t{1})); + auto rhsReal = mlir::LLVM::ExtractValueOp::create( + rewriter, loc, complexElemTy, rhs, ArrayRef(int64_t{0})); + auto rhsImag = mlir::LLVM::ExtractValueOp::create( + rewriter, loc, complexElemTy, rhs, ArrayRef(int64_t{1})); mlir::Value newReal; mlir::Value newImag; if (complexElemTy.isInteger()) { - newReal = rewriter.create<mlir::LLVM::AddOp>(loc, complexElemTy, lhsReal, - rhsReal); - newImag = rewriter.create<mlir::LLVM::AddOp>(loc, complexElemTy, lhsImag, - rhsImag); + newReal = mlir::LLVM::AddOp::create(rewriter, loc, complexElemTy, lhsReal, + rhsReal); + newImag = mlir::LLVM::AddOp::create(rewriter, loc, complexElemTy, lhsImag, + rhsImag); } else { assert(!cir::MissingFeatures::fastMathFlags()); assert(!cir::MissingFeatures::fpConstraints()); - newReal = rewriter.create<mlir::LLVM::FAddOp>(loc, complexElemTy, lhsReal, - rhsReal); - newImag = rewriter.create<mlir::LLVM::FAddOp>(loc, complexElemTy, lhsImag, - rhsImag); + newReal = mlir::LLVM::FAddOp::create(rewriter, loc, complexElemTy, lhsReal, + rhsReal); + newImag = mlir::LLVM::FAddOp::create(rewriter, loc, complexElemTy, lhsImag, + rhsImag); } mlir::Type complexLLVMTy = getTypeConverter()->convertType(op.getResult().getType()); auto initialComplex = - rewriter.create<mlir::LLVM::PoisonOp>(op->getLoc(), complexLLVMTy); + mlir::LLVM::PoisonOp::create(rewriter, op->getLoc(), complexLLVMTy); - auto realComplex = rewriter.create<mlir::LLVM::InsertValueOp>( - op->getLoc(), initialComplex, newReal, 0); + auto realComplex = mlir::LLVM::InsertValueOp::create( + rewriter, op->getLoc(), initialComplex, newReal, ArrayRef(int64_t{0})); - rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(op, realComplex, - newImag, 1); + rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>( + op, realComplex, newImag, ArrayRef(int64_t{1})); return mlir::success(); } @@ -3359,13 +3376,15 @@ mlir::LogicalResult CIRToLLVMComplexCreateOpLowering::matchAndRewrite( mlir::Type complexLLVMTy = getTypeConverter()->convertType(op.getResult().getType()); auto initialComplex = - rewriter.create<mlir::LLVM::UndefOp>(op->getLoc(), complexLLVMTy); + mlir::LLVM::UndefOp::create(rewriter, op->getLoc(), complexLLVMTy); - auto realComplex = rewriter.create<mlir::LLVM::InsertValueOp>( - op->getLoc(), initialComplex, adaptor.getReal(), 0); + auto realComplex = mlir::LLVM::InsertValueOp::create( + rewriter, op->getLoc(), initialComplex, adaptor.getReal(), + ArrayRef(int64_t{0})); - auto complex = rewriter.create<mlir::LLVM::InsertValueOp>( - op->getLoc(), realComplex, adaptor.getImag(), 1); + auto complex = mlir::LLVM::InsertValueOp::create( + rewriter, op->getLoc(), realComplex, adaptor.getImag(), + ArrayRef(int64_t{1})); rewriter.replaceOp(op, complex); return mlir::success(); @@ -3395,41 +3414,41 @@ mlir::LogicalResult CIRToLLVMComplexSubOpLowering::matchAndRewrite( auto complexType = mlir::cast<cir::ComplexType>(op.getLhs().getType()); mlir::Type complexElemTy = getTypeConverter()->convertType(complexType.getElementType()); - auto lhsReal = - rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 0); - auto lhsImag = - rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 1); - auto rhsReal = - rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 0); - auto rhsImag = - rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 1); + auto lhsReal = mlir::LLVM::ExtractValueOp::create( + rewriter, loc, complexElemTy, lhs, ArrayRef(int64_t{0})); + auto lhsImag = mlir::LLVM::ExtractValueOp::create( + rewriter, loc, complexElemTy, lhs, ArrayRef(int64_t{1})); + auto rhsReal = mlir::LLVM::ExtractValueOp::create( + rewriter, loc, complexElemTy, rhs, ArrayRef(int64_t{0})); + auto rhsImag = mlir::LLVM::ExtractValueOp::create( + rewriter, loc, complexElemTy, rhs, ArrayRef(int64_t{1})); mlir::Value newReal; mlir::Value newImag; if (complexElemTy.isInteger()) { - newReal = rewriter.create<mlir::LLVM::SubOp>(loc, complexElemTy, lhsReal, - rhsReal); - newImag = rewriter.create<mlir::LLVM::SubOp>(loc, complexElemTy, lhsImag, - rhsImag); + newReal = mlir::LLVM::SubOp::create(rewriter, loc, complexElemTy, lhsReal, + rhsReal); + newImag = mlir::LLVM::SubOp::create(rewriter, loc, complexElemTy, lhsImag, + rhsImag); } else { assert(!cir::MissingFeatures::fastMathFlags()); assert(!cir::MissingFeatures::fpConstraints()); - newReal = rewriter.create<mlir::LLVM::FSubOp>(loc, complexElemTy, lhsReal, - rhsReal); - newImag = rewriter.create<mlir::LLVM::FSubOp>(loc, complexElemTy, lhsImag, - rhsImag); + newReal = mlir::LLVM::FSubOp::create(rewriter, loc, complexElemTy, lhsReal, + rhsReal); + newImag = mlir::LLVM::FSubOp::create(rewriter, loc, complexElemTy, lhsImag, + rhsImag); } mlir::Type complexLLVMTy = getTypeConverter()->convertType(op.getResult().getType()); auto initialComplex = - rewriter.create<mlir::LLVM::PoisonOp>(op->getLoc(), complexLLVMTy); + mlir::LLVM::PoisonOp::create(rewriter, op->getLoc(), complexLLVMTy); - auto realComplex = rewriter.create<mlir::LLVM::InsertValueOp>( - op->getLoc(), initialComplex, newReal, 0); + auto realComplex = mlir::LLVM::InsertValueOp::create( + rewriter, op->getLoc(), initialComplex, newReal, ArrayRef(int64_t{0})); - rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(op, realComplex, - newImag, 1); + rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>( + op, realComplex, newImag, ArrayRef(int64_t{1})); return mlir::success(); } @@ -3496,8 +3515,8 @@ mlir::LogicalResult CIRToLLVMSetBitfieldOpLowering::matchAndRewrite( if (storageSize != size) { assert(storageSize > size && "Invalid bitfield size."); - mlir::Value val = rewriter.create<mlir::LLVM::LoadOp>( - op.getLoc(), intType, adaptor.getAddr(), op.getAlignment(), + mlir::Value val = mlir::LLVM::LoadOp::create( + rewriter, op.getLoc(), intType, adaptor.getAddr(), op.getAlignment(), op.getIsVolatile()); srcVal = @@ -3510,11 +3529,11 @@ mlir::LogicalResult CIRToLLVMSetBitfieldOpLowering::matchAndRewrite( ~llvm::APInt::getBitsSet(srcWidth, offset, offset + size)); // Or together the unchanged values and the source value. - srcVal = rewriter.create<mlir::LLVM::OrOp>(op.getLoc(), val, srcVal); + srcVal = mlir::LLVM::OrOp::create(rewriter, op.getLoc(), val, srcVal); } - rewriter.create<mlir::LLVM::StoreOp>(op.getLoc(), srcVal, adaptor.getAddr(), - op.getAlignment(), op.getIsVolatile()); + mlir::LLVM::StoreOp::create(rewriter, op.getLoc(), srcVal, adaptor.getAddr(), + op.getAlignment(), op.getIsVolatile()); mlir::Type resultTy = getTypeConverter()->convertType(op.getType()); @@ -3587,10 +3606,10 @@ mlir::LogicalResult CIRToLLVMGetBitfieldOpLowering::matchAndRewrite( mlir::IntegerType intType = computeBitfieldIntType(storageType, context, storageSize); - mlir::Value val = rewriter.create<mlir::LLVM::LoadOp>( - op.getLoc(), intType, adaptor.getAddr(), op.getAlignment(), + mlir::Value val = mlir::LLVM::LoadOp::create( + rewriter, op.getLoc(), intType, adaptor.getAddr(), op.getAlignment(), op.getIsVolatile()); - val = rewriter.create<mlir::LLVM::BitcastOp>(op.getLoc(), intType, val); + val = mlir::LLVM::BitcastOp::create(rewriter, op.getLoc(), intType, val); if (info.getIsSigned()) { assert(static_cast<unsigned>(offset + size) <= storageSize); diff --git a/clang/lib/CIR/Lowering/LoweringHelpers.cpp b/clang/lib/CIR/Lowering/LoweringHelpers.cpp index d5f1324..0786579 100644 --- a/clang/lib/CIR/Lowering/LoweringHelpers.cpp +++ b/clang/lib/CIR/Lowering/LoweringHelpers.cpp @@ -148,37 +148,37 @@ lowerConstArrayAttr(cir::ConstArrayAttr constArr, mlir::Value getConstAPInt(mlir::OpBuilder &bld, mlir::Location loc, mlir::Type typ, const llvm::APInt &val) { - return bld.create<mlir::LLVM::ConstantOp>(loc, typ, val); + return mlir::LLVM::ConstantOp::create(bld, loc, typ, val); } mlir::Value getConst(mlir::OpBuilder &bld, mlir::Location loc, mlir::Type typ, unsigned val) { - return bld.create<mlir::LLVM::ConstantOp>(loc, typ, val); + return mlir::LLVM::ConstantOp::create(bld, loc, typ, val); } mlir::Value createShL(mlir::OpBuilder &bld, mlir::Value lhs, unsigned rhs) { if (!rhs) return lhs; mlir::Value rhsVal = getConst(bld, lhs.getLoc(), lhs.getType(), rhs); - return bld.create<mlir::LLVM::ShlOp>(lhs.getLoc(), lhs, rhsVal); + return mlir::LLVM::ShlOp::create(bld, lhs.getLoc(), lhs, rhsVal); } mlir::Value createAShR(mlir::OpBuilder &bld, mlir::Value lhs, unsigned rhs) { if (!rhs) return lhs; mlir::Value rhsVal = getConst(bld, lhs.getLoc(), lhs.getType(), rhs); - return bld.create<mlir::LLVM::AShrOp>(lhs.getLoc(), lhs, rhsVal); + return mlir::LLVM::AShrOp::create(bld, lhs.getLoc(), lhs, rhsVal); } mlir::Value createAnd(mlir::OpBuilder &bld, mlir::Value lhs, const llvm::APInt &rhs) { mlir::Value rhsVal = getConstAPInt(bld, lhs.getLoc(), lhs.getType(), rhs); - return bld.create<mlir::LLVM::AndOp>(lhs.getLoc(), lhs, rhsVal); + return mlir::LLVM::AndOp::create(bld, lhs.getLoc(), lhs, rhsVal); } mlir::Value createLShR(mlir::OpBuilder &bld, mlir::Value lhs, unsigned rhs) { if (!rhs) return lhs; mlir::Value rhsVal = getConst(bld, lhs.getLoc(), lhs.getType(), rhs); - return bld.create<mlir::LLVM::LShrOp>(lhs.getLoc(), lhs, rhsVal); + return mlir::LLVM::LShrOp::create(bld, lhs.getLoc(), lhs, rhsVal); } diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp index fd73314..301d577 100644 --- a/clang/lib/CodeGen/CGExpr.cpp +++ b/clang/lib/CodeGen/CGExpr.cpp @@ -29,6 +29,7 @@ #include "clang/AST/ASTLambda.h" #include "clang/AST/Attr.h" #include "clang/AST/DeclObjC.h" +#include "clang/AST/InferAlloc.h" #include "clang/AST/NSAPI.h" #include "clang/AST/ParentMapContext.h" #include "clang/AST/StmtVisitor.h" @@ -1273,194 +1274,39 @@ void CodeGenFunction::EmitBoundsCheckImpl(const Expr *E, llvm::Value *Bound, EmitCheck(std::make_pair(Check, CheckKind), CheckHandler, StaticData, Index); } -static bool -typeContainsPointer(QualType T, - llvm::SmallPtrSet<const RecordDecl *, 4> &VisitedRD, - bool &IncompleteType) { - QualType CanonicalType = T.getCanonicalType(); - if (CanonicalType->isPointerType()) - return true; // base case - - // Look through typedef chain to check for special types. - for (QualType CurrentT = T; const auto *TT = CurrentT->getAs<TypedefType>(); - CurrentT = TT->getDecl()->getUnderlyingType()) { - const IdentifierInfo *II = TT->getDecl()->getIdentifier(); - // Special Case: Syntactically uintptr_t is not a pointer; semantically, - // however, very likely used as such. Therefore, classify uintptr_t as a - // pointer, too. - if (II && II->isStr("uintptr_t")) - return true; - } - - // The type is an array; check the element type. - if (const ArrayType *AT = dyn_cast<ArrayType>(CanonicalType)) - return typeContainsPointer(AT->getElementType(), VisitedRD, IncompleteType); - // The type is a struct, class, or union. - if (const RecordDecl *RD = CanonicalType->getAsRecordDecl()) { - if (!RD->isCompleteDefinition()) { - IncompleteType = true; - return false; - } - if (!VisitedRD.insert(RD).second) - return false; // already visited - // Check all fields. - for (const FieldDecl *Field : RD->fields()) { - if (typeContainsPointer(Field->getType(), VisitedRD, IncompleteType)) - return true; - } - // For C++ classes, also check base classes. - if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) { - // Polymorphic types require a vptr. - if (CXXRD->isDynamicClass()) - return true; - for (const CXXBaseSpecifier &Base : CXXRD->bases()) { - if (typeContainsPointer(Base.getType(), VisitedRD, IncompleteType)) - return true; - } - } - } - return false; -} - -void CodeGenFunction::EmitAllocToken(llvm::CallBase *CB, QualType AllocType) { - assert(SanOpts.has(SanitizerKind::AllocToken) && - "Only needed with -fsanitize=alloc-token"); +llvm::MDNode *CodeGenFunction::buildAllocToken(QualType AllocType) { + auto ATMD = infer_alloc::getAllocTokenMetadata(AllocType, getContext()); + if (!ATMD) + return nullptr; llvm::MDBuilder MDB(getLLVMContext()); - - // Get unique type name. - PrintingPolicy Policy(CGM.getContext().getLangOpts()); - Policy.SuppressTagKeyword = true; - Policy.FullyQualifiedName = true; - SmallString<64> TypeName; - llvm::raw_svector_ostream TypeNameOS(TypeName); - AllocType.getCanonicalType().print(TypeNameOS, Policy); - auto *TypeNameMD = MDB.createString(TypeNameOS.str()); - - // Check if QualType contains a pointer. Implements a simple DFS to - // recursively check if a type contains a pointer type. - llvm::SmallPtrSet<const RecordDecl *, 4> VisitedRD; - bool IncompleteType = false; - const bool ContainsPtr = - typeContainsPointer(AllocType, VisitedRD, IncompleteType); - if (!ContainsPtr && IncompleteType) - return; - auto *ContainsPtrC = Builder.getInt1(ContainsPtr); + auto *TypeNameMD = MDB.createString(ATMD->TypeName); + auto *ContainsPtrC = Builder.getInt1(ATMD->ContainsPointer); auto *ContainsPtrMD = MDB.createConstant(ContainsPtrC); // Format: !{<type-name>, <contains-pointer>} - auto *MDN = - llvm::MDNode::get(CGM.getLLVMContext(), {TypeNameMD, ContainsPtrMD}); - CB->setMetadata(llvm::LLVMContext::MD_alloc_token, MDN); -} - -namespace { -/// Infer type from a simple sizeof expression. -QualType inferTypeFromSizeofExpr(const Expr *E) { - const Expr *Arg = E->IgnoreParenImpCasts(); - if (const auto *UET = dyn_cast<UnaryExprOrTypeTraitExpr>(Arg)) { - if (UET->getKind() == UETT_SizeOf) { - if (UET->isArgumentType()) - return UET->getArgumentTypeInfo()->getType(); - else - return UET->getArgumentExpr()->getType(); - } - } - return QualType(); -} - -/// Infer type from an arithmetic expression involving a sizeof. For example: -/// -/// malloc(sizeof(MyType) + padding); // infers 'MyType' -/// malloc(sizeof(MyType) * 32); // infers 'MyType' -/// malloc(32 * sizeof(MyType)); // infers 'MyType' -/// malloc(sizeof(MyType) << 1); // infers 'MyType' -/// ... -/// -/// More complex arithmetic expressions are supported, but are a heuristic, e.g. -/// when considering allocations for structs with flexible array members: -/// -/// malloc(sizeof(HasFlexArray) + sizeof(int) * 32); // infers 'HasFlexArray' -/// -QualType inferPossibleTypeFromArithSizeofExpr(const Expr *E) { - const Expr *Arg = E->IgnoreParenImpCasts(); - // The argument is a lone sizeof expression. - if (QualType T = inferTypeFromSizeofExpr(Arg); !T.isNull()) - return T; - if (const auto *BO = dyn_cast<BinaryOperator>(Arg)) { - // Argument is an arithmetic expression. Cover common arithmetic patterns - // involving sizeof. - switch (BO->getOpcode()) { - case BO_Add: - case BO_Div: - case BO_Mul: - case BO_Shl: - case BO_Shr: - case BO_Sub: - if (QualType T = inferPossibleTypeFromArithSizeofExpr(BO->getLHS()); - !T.isNull()) - return T; - if (QualType T = inferPossibleTypeFromArithSizeofExpr(BO->getRHS()); - !T.isNull()) - return T; - break; - default: - break; - } - } - return QualType(); + return llvm::MDNode::get(CGM.getLLVMContext(), {TypeNameMD, ContainsPtrMD}); } -/// If the expression E is a reference to a variable, infer the type from a -/// variable's initializer if it contains a sizeof. Beware, this is a heuristic -/// and ignores if a variable is later reassigned. For example: -/// -/// size_t my_size = sizeof(MyType); -/// void *x = malloc(my_size); // infers 'MyType' -/// -QualType inferPossibleTypeFromVarInitSizeofExpr(const Expr *E) { - const Expr *Arg = E->IgnoreParenImpCasts(); - if (const auto *DRE = dyn_cast<DeclRefExpr>(Arg)) { - if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) { - if (const Expr *Init = VD->getInit()) - return inferPossibleTypeFromArithSizeofExpr(Init); - } - } - return QualType(); +void CodeGenFunction::EmitAllocToken(llvm::CallBase *CB, QualType AllocType) { + assert(SanOpts.has(SanitizerKind::AllocToken) && + "Only needed with -fsanitize=alloc-token"); + CB->setMetadata(llvm::LLVMContext::MD_alloc_token, + buildAllocToken(AllocType)); } -/// Deduces the allocated type by checking if the allocation call's result -/// is immediately used in a cast expression. For example: -/// -/// MyType *x = (MyType *)malloc(4096); // infers 'MyType' -/// -QualType inferPossibleTypeFromCastExpr(const CallExpr *CallE, - const CastExpr *CastE) { - if (!CastE) - return QualType(); - QualType PtrType = CastE->getType(); - if (PtrType->isPointerType()) - return PtrType->getPointeeType(); - return QualType(); +llvm::MDNode *CodeGenFunction::buildAllocToken(const CallExpr *E) { + QualType AllocType = infer_alloc::inferPossibleType(E, getContext(), CurCast); + if (!AllocType.isNull()) + return buildAllocToken(AllocType); + return nullptr; } -} // end anonymous namespace void CodeGenFunction::EmitAllocToken(llvm::CallBase *CB, const CallExpr *E) { - QualType AllocType; - // First check arguments. - for (const Expr *Arg : E->arguments()) { - AllocType = inferPossibleTypeFromArithSizeofExpr(Arg); - if (AllocType.isNull()) - AllocType = inferPossibleTypeFromVarInitSizeofExpr(Arg); - if (!AllocType.isNull()) - break; - } - // Then check later casts. - if (AllocType.isNull()) - AllocType = inferPossibleTypeFromCastExpr(E, CurCast); - // Emit if we were able to infer the type. - if (!AllocType.isNull()) - EmitAllocToken(CB, AllocType); + assert(SanOpts.has(SanitizerKind::AllocToken) && + "Only needed with -fsanitize=alloc-token"); + if (llvm::MDNode *MDN = buildAllocToken(E)) + CB->setMetadata(llvm::LLVMContext::MD_alloc_token, MDN); } CodeGenFunction::ComplexPairTy CodeGenFunction:: diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h index 1f0be2d..8c4c1c8 100644 --- a/clang/lib/CodeGen/CodeGenFunction.h +++ b/clang/lib/CodeGen/CodeGenFunction.h @@ -3352,9 +3352,14 @@ public: SanitizerAnnotateDebugInfo(ArrayRef<SanitizerKind::SanitizerOrdinal> Ordinals, SanitizerHandler Handler); - /// Emit additional metadata used by the AllocToken instrumentation. + /// Build metadata used by the AllocToken instrumentation. + llvm::MDNode *buildAllocToken(QualType AllocType); + /// Emit and set additional metadata used by the AllocToken instrumentation. void EmitAllocToken(llvm::CallBase *CB, QualType AllocType); - /// Emit additional metadata used by the AllocToken instrumentation, + /// Build additional metadata used by the AllocToken instrumentation, + /// inferring the type from an allocation call expression. + llvm::MDNode *buildAllocToken(const CallExpr *E); + /// Emit and set additional metadata used by the AllocToken instrumentation, /// inferring the type from an allocation call expression. void EmitAllocToken(llvm::CallBase *CB, const CallExpr *E); diff --git a/clang/lib/Format/TokenAnnotator.cpp b/clang/lib/Format/TokenAnnotator.cpp index 25971d2..c97a9e8 100644 --- a/clang/lib/Format/TokenAnnotator.cpp +++ b/clang/lib/Format/TokenAnnotator.cpp @@ -3791,18 +3791,12 @@ static bool isFunctionDeclarationName(const LangOptions &LangOpts, if (Current.is(TT_FunctionDeclarationName)) return true; - if (!Current.Tok.getIdentifierInfo()) + if (Current.isNoneOf(tok::identifier, tok::kw_operator)) return false; const auto *Prev = Current.getPreviousNonComment(); assert(Prev); - if (Prev->is(tok::coloncolon)) - Prev = Prev->Previous; - - if (!Prev) - return false; - const auto &Previous = *Prev; if (const auto *PrevPrev = Previous.getPreviousNonComment(); @@ -3851,6 +3845,8 @@ static bool isFunctionDeclarationName(const LangOptions &LangOpts, // Find parentheses of parameter list. if (Current.is(tok::kw_operator)) { + if (Line.startsWith(tok::kw_friend)) + return true; if (Previous.Tok.getIdentifierInfo() && Previous.isNoneOf(tok::kw_return, tok::kw_co_return)) { return true; diff --git a/clang/lib/Frontend/ASTUnit.cpp b/clang/lib/Frontend/ASTUnit.cpp index d53b64a..6cc7094 100644 --- a/clang/lib/Frontend/ASTUnit.cpp +++ b/clang/lib/Frontend/ASTUnit.cpp @@ -512,152 +512,73 @@ namespace { /// Gathers information from ASTReader that will be used to initialize /// a Preprocessor. class ASTInfoCollector : public ASTReaderListener { - Preprocessor &PP; - ASTContext *Context; HeaderSearchOptions &HSOpts; + std::string &SpecificModuleCachePath; PreprocessorOptions &PPOpts; - LangOptions &LangOpt; + LangOptions &LangOpts; CodeGenOptions &CodeGenOpts; - std::shared_ptr<TargetOptions> &TargetOpts; - IntrusiveRefCntPtr<TargetInfo> &Target; + TargetOptions &TargetOpts; unsigned &Counter; - bool InitializedLanguage = false; - bool InitializedHeaderSearchPaths = false; public: - ASTInfoCollector(Preprocessor &PP, ASTContext *Context, - HeaderSearchOptions &HSOpts, PreprocessorOptions &PPOpts, - LangOptions &LangOpt, CodeGenOptions &CodeGenOpts, - std::shared_ptr<TargetOptions> &TargetOpts, - IntrusiveRefCntPtr<TargetInfo> &Target, unsigned &Counter) - : PP(PP), Context(Context), HSOpts(HSOpts), PPOpts(PPOpts), - LangOpt(LangOpt), CodeGenOpts(CodeGenOpts), TargetOpts(TargetOpts), - Target(Target), Counter(Counter) {} - - bool ReadLanguageOptions(const LangOptions &LangOpts, + ASTInfoCollector(HeaderSearchOptions &HSOpts, + std::string &SpecificModuleCachePath, + PreprocessorOptions &PPOpts, LangOptions &LangOpts, + CodeGenOptions &CodeGenOpts, TargetOptions &TargetOpts, + unsigned &Counter) + : HSOpts(HSOpts), SpecificModuleCachePath(SpecificModuleCachePath), + PPOpts(PPOpts), LangOpts(LangOpts), CodeGenOpts(CodeGenOpts), + TargetOpts(TargetOpts), Counter(Counter) {} + + bool ReadLanguageOptions(const LangOptions &NewLangOpts, StringRef ModuleFilename, bool Complain, bool AllowCompatibleDifferences) override { - if (InitializedLanguage) - return false; - - // FIXME: We did similar things in ReadHeaderSearchOptions too. But such - // style is not scaling. Probably we need to invite some mechanism to - // handle such patterns generally. - auto PICLevel = LangOpt.PICLevel; - auto PIE = LangOpt.PIE; - - LangOpt = LangOpts; - - LangOpt.PICLevel = PICLevel; - LangOpt.PIE = PIE; - - InitializedLanguage = true; - - updated(); + LangOpts = NewLangOpts; return false; } - bool ReadCodeGenOptions(const CodeGenOptions &CGOpts, + bool ReadCodeGenOptions(const CodeGenOptions &NewCodeGenOpts, StringRef ModuleFilename, bool Complain, bool AllowCompatibleDifferences) override { - this->CodeGenOpts = CGOpts; + CodeGenOpts = NewCodeGenOpts; return false; } - bool ReadHeaderSearchOptions(const HeaderSearchOptions &HSOpts, + bool ReadHeaderSearchOptions(const HeaderSearchOptions &NewHSOpts, StringRef ModuleFilename, - StringRef SpecificModuleCachePath, + StringRef NewSpecificModuleCachePath, bool Complain) override { - // llvm::SaveAndRestore doesn't support bit field. - auto ForceCheckCXX20ModulesInputFiles = - this->HSOpts.ForceCheckCXX20ModulesInputFiles; - llvm::SaveAndRestore X(this->HSOpts.UserEntries); - llvm::SaveAndRestore Y(this->HSOpts.SystemHeaderPrefixes); - llvm::SaveAndRestore Z(this->HSOpts.VFSOverlayFiles); - - this->HSOpts = HSOpts; - this->HSOpts.ForceCheckCXX20ModulesInputFiles = - ForceCheckCXX20ModulesInputFiles; - + HSOpts = NewHSOpts; + SpecificModuleCachePath = NewSpecificModuleCachePath; return false; } - bool ReadHeaderSearchPaths(const HeaderSearchOptions &HSOpts, + bool ReadHeaderSearchPaths(const HeaderSearchOptions &NewHSOpts, bool Complain) override { - if (InitializedHeaderSearchPaths) - return false; - - this->HSOpts.UserEntries = HSOpts.UserEntries; - this->HSOpts.SystemHeaderPrefixes = HSOpts.SystemHeaderPrefixes; - this->HSOpts.VFSOverlayFiles = HSOpts.VFSOverlayFiles; - - // Initialize the FileManager. We can't do this in update(), since that - // performs the initialization too late (once both target and language - // options are read). - PP.getFileManager().setVirtualFileSystem(createVFSFromOverlayFiles( - HSOpts.VFSOverlayFiles, PP.getDiagnostics(), - PP.getFileManager().getVirtualFileSystemPtr())); - - InitializedHeaderSearchPaths = true; - + HSOpts.UserEntries = NewHSOpts.UserEntries; + HSOpts.SystemHeaderPrefixes = NewHSOpts.SystemHeaderPrefixes; + HSOpts.VFSOverlayFiles = NewHSOpts.VFSOverlayFiles; return false; } - bool ReadPreprocessorOptions(const PreprocessorOptions &PPOpts, + bool ReadPreprocessorOptions(const PreprocessorOptions &NewPPOpts, StringRef ModuleFilename, bool ReadMacros, bool Complain, std::string &SuggestedPredefines) override { - this->PPOpts = PPOpts; + PPOpts = NewPPOpts; return false; } - bool ReadTargetOptions(const TargetOptions &TargetOpts, + bool ReadTargetOptions(const TargetOptions &NewTargetOpts, StringRef ModuleFilename, bool Complain, bool AllowCompatibleDifferences) override { - // If we've already initialized the target, don't do it again. - if (Target) - return false; - - this->TargetOpts = std::make_shared<TargetOptions>(TargetOpts); - Target = - TargetInfo::CreateTargetInfo(PP.getDiagnostics(), *this->TargetOpts); - - updated(); + TargetOpts = NewTargetOpts; return false; } void ReadCounter(const serialization::ModuleFile &M, - unsigned Value) override { - Counter = Value; - } - -private: - void updated() { - if (!Target || !InitializedLanguage) - return; - - // Inform the target of the language options. - // - // FIXME: We shouldn't need to do this, the target should be immutable once - // created. This complexity should be lifted elsewhere. - Target->adjust(PP.getDiagnostics(), LangOpt, /*AuxTarget=*/nullptr); - - // Initialize the preprocessor. - PP.Initialize(*Target); - - if (!Context) - return; - - // Initialize the ASTContext - Context->InitBuiltinTypes(*Target); - - // Adjust printing policy based on language options. - Context->setPrintingPolicy(PrintingPolicy(LangOpt)); - - // We didn't have access to the comment options when the ASTContext was - // constructed, so register them now. - Context->getCommentCommandTraits().registerCommentOptions( - LangOpt.CommentOpts); + unsigned NewCounter) override { + Counter = NewCounter; } }; @@ -812,7 +733,7 @@ std::unique_ptr<ASTUnit> ASTUnit::LoadFromASTFile( std::shared_ptr<DiagnosticOptions> DiagOpts, IntrusiveRefCntPtr<DiagnosticsEngine> Diags, const FileSystemOptions &FileSystemOpts, const HeaderSearchOptions &HSOpts, - const LangOptions *LangOpts, bool OnlyLocalDecls, + const LangOptions *ProvidedLangOpts, bool OnlyLocalDecls, CaptureDiagsKind CaptureDiagnostics, bool AllowASTWithCompilerErrors, bool UserFilesAreVolatile) { std::unique_ptr<ASTUnit> AST(new ASTUnit(true)); @@ -826,41 +747,71 @@ std::unique_ptr<ASTUnit> ASTUnit::LoadFromASTFile( ConfigureDiags(Diags, *AST, CaptureDiagnostics); - AST->LangOpts = LangOpts ? std::make_unique<LangOptions>(*LangOpts) - : std::make_unique<LangOptions>(); + std::unique_ptr<LangOptions> LocalLangOpts; + const LangOptions &LangOpts = [&]() -> const LangOptions & { + if (ProvidedLangOpts) + return *ProvidedLangOpts; + LocalLangOpts = std::make_unique<LangOptions>(); + return *LocalLangOpts; + }(); + + AST->LangOpts = std::make_unique<LangOptions>(LangOpts); AST->OnlyLocalDecls = OnlyLocalDecls; AST->CaptureDiagnostics = CaptureDiagnostics; AST->DiagOpts = DiagOpts; AST->Diagnostics = Diags; - AST->FileMgr = llvm::makeIntrusiveRefCnt<FileManager>(FileSystemOpts, VFS); AST->UserFilesAreVolatile = UserFilesAreVolatile; - AST->SourceMgr = llvm::makeIntrusiveRefCnt<SourceManager>( - AST->getDiagnostics(), AST->getFileManager(), UserFilesAreVolatile); - AST->ModCache = createCrossProcessModuleCache(); AST->HSOpts = std::make_unique<HeaderSearchOptions>(HSOpts); AST->HSOpts->ModuleFormat = std::string(PCHContainerRdr.getFormats().front()); - AST->HeaderInfo.reset(new HeaderSearch(AST->getHeaderSearchOpts(), - AST->getSourceManager(), - AST->getDiagnostics(), - AST->getLangOpts(), - /*Target=*/nullptr)); AST->PPOpts = std::make_shared<PreprocessorOptions>(); + AST->CodeGenOpts = std::make_unique<CodeGenOptions>(); + AST->TargetOpts = std::make_shared<TargetOptions>(); + + AST->ModCache = createCrossProcessModuleCache(); + + // Gather info for preprocessor construction later on. + std::string SpecificModuleCachePath; + unsigned Counter = 0; + // Using a temporary FileManager since the AST file might specify custom + // HeaderSearchOptions::VFSOverlayFiles that affect the underlying VFS. + FileManager TmpFileMgr(FileSystemOpts, VFS); + ASTInfoCollector Collector(*AST->HSOpts, SpecificModuleCachePath, + *AST->PPOpts, *AST->LangOpts, *AST->CodeGenOpts, + *AST->TargetOpts, Counter); + if (ASTReader::readASTFileControlBlock( + Filename, TmpFileMgr, *AST->ModCache, PCHContainerRdr, + /*FindModuleFileExtensions=*/true, Collector, + /*ValidateDiagnosticOptions=*/true, ASTReader::ARR_None)) { + AST->getDiagnostics().Report(diag::err_fe_unable_to_load_pch); + return nullptr; + } + + VFS = createVFSFromOverlayFiles(AST->HSOpts->VFSOverlayFiles, + *AST->Diagnostics, std::move(VFS)); - // Gather Info for preprocessor construction later on. + AST->FileMgr = llvm::makeIntrusiveRefCnt<FileManager>(FileSystemOpts, VFS); + + AST->SourceMgr = llvm::makeIntrusiveRefCnt<SourceManager>( + AST->getDiagnostics(), AST->getFileManager(), UserFilesAreVolatile); - HeaderSearch &HeaderInfo = *AST->HeaderInfo; + AST->HSOpts->PrebuiltModuleFiles = HSOpts.PrebuiltModuleFiles; + AST->HSOpts->PrebuiltModulePaths = HSOpts.PrebuiltModulePaths; + AST->HeaderInfo = std::make_unique<HeaderSearch>( + AST->getHeaderSearchOpts(), AST->getSourceManager(), + AST->getDiagnostics(), AST->getLangOpts(), + /*Target=*/nullptr); + AST->HeaderInfo->setModuleCachePath(SpecificModuleCachePath); AST->PP = std::make_shared<Preprocessor>( *AST->PPOpts, AST->getDiagnostics(), *AST->LangOpts, - AST->getSourceManager(), HeaderInfo, AST->ModuleLoader, + AST->getSourceManager(), *AST->HeaderInfo, AST->ModuleLoader, /*IILookup=*/nullptr, /*OwnsHeaderSearch=*/false); - Preprocessor &PP = *AST->PP; if (ToLoad >= LoadASTOnly) AST->Ctx = llvm::makeIntrusiveRefCnt<ASTContext>( - *AST->LangOpts, AST->getSourceManager(), PP.getIdentifierTable(), - PP.getSelectorTable(), PP.getBuiltinInfo(), + *AST->LangOpts, AST->getSourceManager(), AST->PP->getIdentifierTable(), + AST->PP->getSelectorTable(), AST->PP->getBuiltinInfo(), AST->getTranslationUnitKind()); DisableValidationForModuleKind disableValid = @@ -868,24 +819,60 @@ std::unique_ptr<ASTUnit> ASTUnit::LoadFromASTFile( if (::getenv("LIBCLANG_DISABLE_PCH_VALIDATION")) disableValid = DisableValidationForModuleKind::All; AST->Reader = llvm::makeIntrusiveRefCnt<ASTReader>( - PP, *AST->ModCache, AST->Ctx.get(), PCHContainerRdr, *AST->CodeGenOpts, - ArrayRef<std::shared_ptr<ModuleFileExtension>>(), + *AST->PP, *AST->ModCache, AST->Ctx.get(), PCHContainerRdr, + *AST->CodeGenOpts, ArrayRef<std::shared_ptr<ModuleFileExtension>>(), /*isysroot=*/"", /*DisableValidationKind=*/disableValid, AllowASTWithCompilerErrors); - unsigned Counter = 0; - AST->Reader->setListener(std::make_unique<ASTInfoCollector>( - *AST->PP, AST->Ctx.get(), *AST->HSOpts, *AST->PPOpts, *AST->LangOpts, - *AST->CodeGenOpts, AST->TargetOpts, AST->Target, Counter)); - - // Attach the AST reader to the AST context as an external AST - // source, so that declarations will be deserialized from the - // AST file as needed. + // Attach the AST reader to the AST context as an external AST source, so that + // declarations will be deserialized from the AST file as needed. // We need the external source to be set up before we read the AST, because // eagerly-deserialized declarations may use it. if (AST->Ctx) AST->Ctx->setExternalSource(AST->Reader); + AST->Target = + TargetInfo::CreateTargetInfo(AST->PP->getDiagnostics(), *AST->TargetOpts); + // Inform the target of the language options. + // + // FIXME: We shouldn't need to do this, the target should be immutable once + // created. This complexity should be lifted elsewhere. + AST->Target->adjust(AST->PP->getDiagnostics(), *AST->LangOpts, + /*AuxTarget=*/nullptr); + + // Initialize the preprocessor. + AST->PP->Initialize(*AST->Target); + + AST->PP->setCounterValue(Counter); + + if (AST->Ctx) { + // Initialize the ASTContext + AST->Ctx->InitBuiltinTypes(*AST->Target); + + // Adjust printing policy based on language options. + AST->Ctx->setPrintingPolicy(PrintingPolicy(*AST->LangOpts)); + + // We didn't have access to the comment options when the ASTContext was + // constructed, so register them now. + AST->Ctx->getCommentCommandTraits().registerCommentOptions( + AST->LangOpts->CommentOpts); + } + + // The temporary FileManager we used for ASTReader::readASTFileControlBlock() + // might have already read stdin, and reading it again will fail. Let's + // explicitly forward the buffer. + if (Filename == "-") + if (auto FE = llvm::expectedToOptional(TmpFileMgr.getSTDIN())) + if (auto BufRef = TmpFileMgr.getBufferForFile(*FE)) { + auto Buf = llvm::MemoryBuffer::getMemBufferCopy( + (*BufRef)->getBuffer(), (*BufRef)->getBufferIdentifier()); + AST->Reader->getModuleManager().addInMemoryBuffer("-", std::move(Buf)); + } + + // Reinstate the provided options that are relevant for reading AST files. + AST->HSOpts->ForceCheckCXX20ModulesInputFiles = + HSOpts.ForceCheckCXX20ModulesInputFiles; + switch (AST->Reader->ReadAST(Filename, serialization::MK_MainFile, SourceLocation(), ASTReader::ARR_None)) { case ASTReader::Success: @@ -901,11 +888,18 @@ std::unique_ptr<ASTUnit> ASTUnit::LoadFromASTFile( return nullptr; } - AST->OriginalSourceFile = std::string(AST->Reader->getOriginalSourceFile()); + // Now that we have successfully loaded the AST file, we can reinstate some + // options that the clients expect us to preserve (but would trip AST file + // validation, so we couldn't set them earlier). + AST->HSOpts->UserEntries = HSOpts.UserEntries; + AST->HSOpts->SystemHeaderPrefixes = HSOpts.SystemHeaderPrefixes; + AST->HSOpts->VFSOverlayFiles = HSOpts.VFSOverlayFiles; + AST->LangOpts->PICLevel = LangOpts.PICLevel; + AST->LangOpts->PIE = LangOpts.PIE; - PP.setCounterValue(Counter); + AST->OriginalSourceFile = std::string(AST->Reader->getOriginalSourceFile()); - Module *M = HeaderInfo.lookupModule(AST->getLangOpts().CurrentModule); + Module *M = AST->HeaderInfo->lookupModule(AST->getLangOpts().CurrentModule); if (M && AST->getLangOpts().isCompilingModule() && M->isNamedModule()) AST->Ctx->setCurrentNamedModule(M); @@ -915,13 +909,14 @@ std::unique_ptr<ASTUnit> ASTUnit::LoadFromASTFile( // Create a semantic analysis object and tell the AST reader about it. if (ToLoad >= LoadEverything) { - AST->TheSema.reset(new Sema(PP, *AST->Ctx, *AST->Consumer)); + AST->TheSema = std::make_unique<Sema>(*AST->PP, *AST->Ctx, *AST->Consumer); AST->TheSema->Initialize(); AST->Reader->InitializeSema(*AST->TheSema); } // Tell the diagnostic client that we have started a source file. - AST->getDiagnostics().getClient()->BeginSourceFile(PP.getLangOpts(), &PP); + AST->getDiagnostics().getClient()->BeginSourceFile(AST->PP->getLangOpts(), + AST->PP.get()); return AST; } diff --git a/clang/lib/Parse/Parser.cpp b/clang/lib/Parse/Parser.cpp index ec01faf..a6fc676 100644 --- a/clang/lib/Parse/Parser.cpp +++ b/clang/lib/Parse/Parser.cpp @@ -708,7 +708,7 @@ bool Parser::ParseTopLevelDecl(DeclGroupPtrTy &Result, } // Late template parsing can begin. - Actions.SetLateTemplateParser(LateTemplateParserCallback, nullptr, this); + Actions.SetLateTemplateParser(LateTemplateParserCallback, this); Actions.ActOnEndOfTranslationUnit(); //else don't tell Sema that we ended parsing: more input might come. return true; diff --git a/clang/lib/Sema/Sema.cpp b/clang/lib/Sema/Sema.cpp index 8ed3df7..23bf7f2 100644 --- a/clang/lib/Sema/Sema.cpp +++ b/clang/lib/Sema/Sema.cpp @@ -276,10 +276,9 @@ Sema::Sema(Preprocessor &pp, ASTContext &ctxt, ASTConsumer &consumer, Context(ctxt), Consumer(consumer), Diags(PP.getDiagnostics()), SourceMgr(PP.getSourceManager()), APINotes(SourceMgr, LangOpts), AnalysisWarnings(*this), ThreadSafetyDeclCache(nullptr), - LateTemplateParser(nullptr), LateTemplateParserCleanup(nullptr), - OpaqueParser(nullptr), CurContext(nullptr), ExternalSource(nullptr), - StackHandler(Diags), CurScope(nullptr), Ident_super(nullptr), - AMDGPUPtr(std::make_unique<SemaAMDGPU>(*this)), + LateTemplateParser(nullptr), OpaqueParser(nullptr), CurContext(nullptr), + ExternalSource(nullptr), StackHandler(Diags), CurScope(nullptr), + Ident_super(nullptr), AMDGPUPtr(std::make_unique<SemaAMDGPU>(*this)), ARMPtr(std::make_unique<SemaARM>(*this)), AVRPtr(std::make_unique<SemaAVR>(*this)), BPFPtr(std::make_unique<SemaBPF>(*this)), @@ -1248,9 +1247,6 @@ void Sema::ActOnEndOfTranslationUnit() { ? TUFragmentKind::Private : TUFragmentKind::Normal); - if (LateTemplateParserCleanup) - LateTemplateParserCleanup(OpaqueParser); - CheckDelayedMemberExceptionSpecs(); } else { // If we are building a TU prefix for serialization, it is safe to transfer diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp index 2990fd6..f99c01e 100644 --- a/clang/lib/Sema/SemaChecking.cpp +++ b/clang/lib/Sema/SemaChecking.cpp @@ -1498,6 +1498,24 @@ static void builtinAllocaAddrSpace(Sema &S, CallExpr *TheCall) { TheCall->setType(S.Context.getPointerType(RT)); } +static bool checkBuiltinInferAllocToken(Sema &S, CallExpr *TheCall) { + if (S.checkArgCountAtLeast(TheCall, 1)) + return true; + + for (Expr *Arg : TheCall->arguments()) { + // If argument is dependent on a template parameter, we can't resolve now. + if (Arg->isTypeDependent() || Arg->isValueDependent()) + continue; + // Reject void types. + QualType ArgTy = Arg->IgnoreParenImpCasts()->getType(); + if (ArgTy->isVoidType()) + return S.Diag(Arg->getBeginLoc(), diag::err_param_with_void_type); + } + + TheCall->setType(S.Context.UnsignedLongLongTy); + return false; +} + namespace { enum PointerAuthOpKind { PAO_Strip, @@ -2779,6 +2797,10 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID, builtinAllocaAddrSpace(*this, TheCall); } break; + case Builtin::BI__builtin_infer_alloc_token: + if (checkBuiltinInferAllocToken(*this, TheCall)) + return ExprError(); + break; case Builtin::BI__arithmetic_fence: if (BuiltinArithmeticFence(TheCall)) return ExprError(); diff --git a/clang/lib/Sema/SemaConcept.cpp b/clang/lib/Sema/SemaConcept.cpp index a1163e9..f04cc45 100644 --- a/clang/lib/Sema/SemaConcept.cpp +++ b/clang/lib/Sema/SemaConcept.cpp @@ -385,6 +385,28 @@ public: return inherited::TraverseStmt(E->getReplacement()); } + bool TraverseTemplateName(TemplateName Template) { + if (auto *TTP = dyn_cast_if_present<TemplateTemplateParmDecl>( + Template.getAsTemplateDecl()); + TTP && TTP->getDepth() < TemplateArgs.getNumLevels()) { + if (!TemplateArgs.hasTemplateArgument(TTP->getDepth(), + TTP->getPosition())) + return true; + + TemplateArgument Arg = TemplateArgs(TTP->getDepth(), TTP->getPosition()); + if (TTP->isParameterPack() && SemaRef.ArgPackSubstIndex) { + assert(Arg.getKind() == TemplateArgument::Pack && + "Missing argument pack"); + Arg = SemaRef.getPackSubstitutedTemplateArgument(Arg); + } + assert(!Arg.getAsTemplate().isNull() && + "Null template template argument"); + UsedTemplateArgs.push_back( + SemaRef.Context.getCanonicalTemplateArgument(Arg)); + } + return inherited::TraverseTemplateName(Template); + } + void VisitConstraint(const NormalizedConstraintWithParamMapping &Constraint) { if (!Constraint.hasParameterMapping()) { for (const auto &List : TemplateArgs) @@ -2678,8 +2700,9 @@ FormulaType SubsumptionChecker::Normalize(const NormalizedConstraint &NC) { }); if (Compound.getCompoundKind() == FormulaType::Kind) { + unsigned SizeLeft = Left.size(); Res = std::move(Left); - Res.reserve(Left.size() + Right.size()); + Res.reserve(SizeLeft + Right.size()); std::for_each(std::make_move_iterator(Right.begin()), std::make_move_iterator(Right.end()), Add); return Res; diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp index 8b3fd41..c1b5cb7 100644 --- a/clang/lib/Serialization/ASTReader.cpp +++ b/clang/lib/Serialization/ASTReader.cpp @@ -5811,7 +5811,13 @@ bool ASTReader::readASTFileControlBlock( // FIXME: This allows use of the VFS; we do not allow use of the // VFS when actually loading a module. - auto BufferOrErr = FileMgr.getBufferForFile(Filename); + auto Entry = + Filename == "-" ? FileMgr.getSTDIN() : FileMgr.getFileRef(Filename); + if (!Entry) { + llvm::consumeError(Entry.takeError()); + return true; + } + auto BufferOrErr = FileMgr.getBufferForFile(*Entry); if (!BufferOrErr) return true; OwnedBuffer = std::move(*BufferOrErr); diff --git a/clang/lib/Tooling/DependencyScanning/DependencyScannerImpl.cpp b/clang/lib/Tooling/DependencyScanning/DependencyScannerImpl.cpp index b0096d8..05d5669 100644 --- a/clang/lib/Tooling/DependencyScanning/DependencyScannerImpl.cpp +++ b/clang/lib/Tooling/DependencyScanning/DependencyScannerImpl.cpp @@ -382,7 +382,8 @@ DignosticsEngineWithDiagOpts::DignosticsEngineWithDiagOpts( std::pair<std::unique_ptr<driver::Driver>, std::unique_ptr<driver::Compilation>> buildCompilation(ArrayRef<std::string> ArgStrs, DiagnosticsEngine &Diags, - IntrusiveRefCntPtr<llvm::vfs::FileSystem> FS) { + IntrusiveRefCntPtr<llvm::vfs::FileSystem> FS, + llvm::BumpPtrAllocator &Alloc) { SmallVector<const char *, 256> Argv; Argv.reserve(ArgStrs.size()); for (const std::string &Arg : ArgStrs) @@ -393,7 +394,6 @@ buildCompilation(ArrayRef<std::string> ArgStrs, DiagnosticsEngine &Diags, "clang LLVM compiler", FS); Driver->setTitle("clang_based_tool"); - llvm::BumpPtrAllocator Alloc; bool CLMode = driver::IsClangCL( driver::getDriverMode(Argv[0], ArrayRef(Argv).slice(1))); diff --git a/clang/lib/Tooling/DependencyScanning/DependencyScannerImpl.h b/clang/lib/Tooling/DependencyScanning/DependencyScannerImpl.h index 71c6731..5657317 100644 --- a/clang/lib/Tooling/DependencyScanning/DependencyScannerImpl.h +++ b/clang/lib/Tooling/DependencyScanning/DependencyScannerImpl.h @@ -105,7 +105,8 @@ struct TextDiagnosticsPrinterWithOutput { std::pair<std::unique_ptr<driver::Driver>, std::unique_ptr<driver::Compilation>> buildCompilation(ArrayRef<std::string> ArgStrs, DiagnosticsEngine &Diags, - IntrusiveRefCntPtr<llvm::vfs::FileSystem> FS); + IntrusiveRefCntPtr<llvm::vfs::FileSystem> FS, + llvm::BumpPtrAllocator &Alloc); std::unique_ptr<CompilerInvocation> createCompilerInvocation(ArrayRef<std::string> CommandLine, diff --git a/clang/lib/Tooling/DependencyScanning/DependencyScanningWorker.cpp b/clang/lib/Tooling/DependencyScanning/DependencyScanningWorker.cpp index 9515421..0a1cf6b 100644 --- a/clang/lib/Tooling/DependencyScanning/DependencyScanningWorker.cpp +++ b/clang/lib/Tooling/DependencyScanning/DependencyScanningWorker.cpp @@ -78,8 +78,10 @@ static bool forEachDriverJob( IntrusiveRefCntPtr<llvm::vfs::FileSystem> FS, llvm::function_ref<bool(const driver::Command &Cmd)> Callback) { // Compilation holds a non-owning a reference to the Driver, hence we need to - // keep the Driver alive when we use Compilation. - auto [Driver, Compilation] = buildCompilation(ArgStrs, Diags, FS); + // keep the Driver alive when we use Compilation. Arguments to commands may be + // owned by Alloc when expanded from response files. + llvm::BumpPtrAllocator Alloc; + auto [Driver, Compilation] = buildCompilation(ArgStrs, Diags, FS, Alloc); if (!Compilation) return false; for (const driver::Command &Job : Compilation->getJobs()) { |