diff options
Diffstat (limited to 'clang-tools-extra/clang-tidy/llvm/UseNewMLIROpBuilderCheck.cpp')
| -rw-r--r-- | clang-tools-extra/clang-tidy/llvm/UseNewMLIROpBuilderCheck.cpp | 68 | 
1 files changed, 37 insertions, 31 deletions
diff --git a/clang-tools-extra/clang-tidy/llvm/UseNewMLIROpBuilderCheck.cpp b/clang-tools-extra/clang-tidy/llvm/UseNewMLIROpBuilderCheck.cpp index bd51cc5..0014153 100644 --- a/clang-tools-extra/clang-tidy/llvm/UseNewMLIROpBuilderCheck.cpp +++ b/clang-tools-extra/clang-tidy/llvm/UseNewMLIROpBuilderCheck.cpp @@ -18,18 +18,15 @@  #include "llvm/Support/FormatVariadic.h"  namespace clang::tidy::llvm_check { -namespace {  using namespace ::clang::ast_matchers;  using namespace ::clang::transformer; -EditGenerator rewrite(RangeSelector Call, RangeSelector Builder, -                      RangeSelector CallArgs) { +static EditGenerator rewrite(RangeSelector Call, RangeSelector Builder) {    // This is using an EditGenerator rather than ASTEdit as we want to warn even    // if in macro. -  return [Call = std::move(Call), Builder = std::move(Builder), -          CallArgs = -              std::move(CallArgs)](const MatchFinder::MatchResult &Result) +  return [Call = std::move(Call), +          Builder = std::move(Builder)](const MatchFinder::MatchResult &Result)               -> Expected<SmallVector<transformer::Edit, 1>> {      Expected<CharSourceRange> CallRange = Call(Result);      if (!CallRange) @@ -54,7 +51,7 @@ EditGenerator rewrite(RangeSelector Call, RangeSelector Builder,      auto NextToken = [&](std::optional<Token> CurrentToken) {        if (!CurrentToken)          return CurrentToken; -      if (CurrentToken->getEndLoc() >= CallRange->getEnd()) +      if (CurrentToken->is(clang::tok::eof))          return std::optional<Token>();        return clang::Lexer::findNextToken(CurrentToken->getLocation(), SM,                                           LangOpts); @@ -68,9 +65,10 @@ EditGenerator rewrite(RangeSelector Call, RangeSelector Builder,        return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument,                                                   "missing '<' token");      } +      std::optional<Token> EndToken = NextToken(LessToken); -    for (std::optional<Token> GreaterToken = NextToken(EndToken); -         GreaterToken && GreaterToken->getKind() != clang::tok::greater; +    std::optional<Token> GreaterToken = NextToken(EndToken); +    for (; GreaterToken && GreaterToken->getKind() != clang::tok::greater;           GreaterToken = NextToken(GreaterToken)) {        EndToken = GreaterToken;      } @@ -79,12 +77,21 @@ EditGenerator rewrite(RangeSelector Call, RangeSelector Builder,                                                   "missing '>' token");      } +    std::optional<Token> ArgStart = NextToken(GreaterToken); +    if (!ArgStart || ArgStart->getKind() != clang::tok::l_paren) { +      return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument, +                                                 "missing '(' token"); +    } +    std::optional<Token> Arg = NextToken(ArgStart); +    if (!Arg) { +      return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument, +                                                 "unexpected end of file"); +    } +    const bool HasArgs = Arg->getKind() != clang::tok::r_paren; +      Expected<CharSourceRange> BuilderRange = Builder(Result);      if (!BuilderRange)        return BuilderRange.takeError(); -    Expected<CharSourceRange> CallArgsRange = CallArgs(Result); -    if (!CallArgsRange) -      return CallArgsRange.takeError();      // Helper for concatting below.      auto GetText = [&](const CharSourceRange &Range) { @@ -93,43 +100,42 @@ EditGenerator rewrite(RangeSelector Call, RangeSelector Builder,      Edit Replace;      Replace.Kind = EditKind::Range; -    Replace.Range = *CallRange; -    std::string CallArgsStr; -    // Only emit args if there are any. -    if (auto CallArgsText = GetText(*CallArgsRange).ltrim(); -        !CallArgsText.rtrim().empty()) { -      CallArgsStr = llvm::formatv(", {}", CallArgsText); +    Replace.Range.setBegin(CallRange->getBegin()); +    Replace.Range.setEnd(ArgStart->getEndLoc()); +    const Expr *BuilderExpr = Result.Nodes.getNodeAs<Expr>("builder"); +    std::string BuilderText = GetText(*BuilderRange).str(); +    if (BuilderExpr->getType()->isPointerType()) { +      BuilderText = BuilderExpr->isImplicitCXXThis() +                        ? "*this" +                        : llvm::formatv("*{}", BuilderText).str();      } -    Replace.Replacement = -        llvm::formatv("{}::create({}{})", -                      GetText(CharSourceRange::getTokenRange( -                          LessToken->getEndLoc(), EndToken->getLastLoc())), -                      GetText(*BuilderRange), CallArgsStr); +    const StringRef OpType = GetText(CharSourceRange::getTokenRange( +        LessToken->getEndLoc(), EndToken->getLastLoc())); +    Replace.Replacement = llvm::formatv("{}::create({}{}", OpType, BuilderText, +                                        HasArgs ? ", " : "");      return SmallVector<Edit, 1>({Replace});    };  } -RewriteRuleWith<std::string> useNewMlirOpBuilderCheckRule() { +static RewriteRuleWith<std::string> useNewMlirOpBuilderCheckRule() {    Stencil Message = cat("use 'OpType::create(builder, ...)' instead of "                          "'builder.create<OpType>(...)'");    // Match a create call on an OpBuilder. +  auto BuilderType = cxxRecordDecl(isSameOrDerivedFrom("::mlir::OpBuilder"));    ast_matchers::internal::Matcher<Stmt> Base =        cxxMemberCallExpr( -          on(expr(hasType( -                      cxxRecordDecl(isSameOrDerivedFrom("::mlir::OpBuilder")))) +          on(expr(anyOf(hasType(BuilderType), hasType(pointsTo(BuilderType))))                   .bind("builder")), -          callee(cxxMethodDecl(hasTemplateArgument(0, templateArgument()))), -          callee(cxxMethodDecl(hasName("create")))) +          callee(cxxMethodDecl(hasTemplateArgument(0, templateArgument()), +                               hasName("create"))))            .bind("call");    return applyFirst(        //  Attempt rewrite given an lvalue builder, else just warn.        {makeRule(cxxMemberCallExpr(unless(on(cxxTemporaryObjectExpr())), Base), -                rewrite(node("call"), node("builder"), callArgs("call")), -                Message), +                rewrite(node("call"), node("builder")), Message),         makeRule(Base, noopEdit(node("call")), Message)});  } -} // namespace  UseNewMlirOpBuilderCheck::UseNewMlirOpBuilderCheck(StringRef Name,                                                     ClangTidyContext *Context)  | 
