diff options
Diffstat (limited to 'clang-tools-extra/clang-tidy/llvm/UseNewMLIROpBuilderCheck.cpp')
| -rw-r--r-- | clang-tools-extra/clang-tidy/llvm/UseNewMLIROpBuilderCheck.cpp | 68 |
1 files changed, 37 insertions, 31 deletions
diff --git a/clang-tools-extra/clang-tidy/llvm/UseNewMLIROpBuilderCheck.cpp b/clang-tools-extra/clang-tidy/llvm/UseNewMLIROpBuilderCheck.cpp index bd51cc5..0014153 100644 --- a/clang-tools-extra/clang-tidy/llvm/UseNewMLIROpBuilderCheck.cpp +++ b/clang-tools-extra/clang-tidy/llvm/UseNewMLIROpBuilderCheck.cpp @@ -18,18 +18,15 @@ #include "llvm/Support/FormatVariadic.h" namespace clang::tidy::llvm_check { -namespace { using namespace ::clang::ast_matchers; using namespace ::clang::transformer; -EditGenerator rewrite(RangeSelector Call, RangeSelector Builder, - RangeSelector CallArgs) { +static EditGenerator rewrite(RangeSelector Call, RangeSelector Builder) { // This is using an EditGenerator rather than ASTEdit as we want to warn even // if in macro. - return [Call = std::move(Call), Builder = std::move(Builder), - CallArgs = - std::move(CallArgs)](const MatchFinder::MatchResult &Result) + return [Call = std::move(Call), + Builder = std::move(Builder)](const MatchFinder::MatchResult &Result) -> Expected<SmallVector<transformer::Edit, 1>> { Expected<CharSourceRange> CallRange = Call(Result); if (!CallRange) @@ -54,7 +51,7 @@ EditGenerator rewrite(RangeSelector Call, RangeSelector Builder, auto NextToken = [&](std::optional<Token> CurrentToken) { if (!CurrentToken) return CurrentToken; - if (CurrentToken->getEndLoc() >= CallRange->getEnd()) + if (CurrentToken->is(clang::tok::eof)) return std::optional<Token>(); return clang::Lexer::findNextToken(CurrentToken->getLocation(), SM, LangOpts); @@ -68,9 +65,10 @@ EditGenerator rewrite(RangeSelector Call, RangeSelector Builder, return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument, "missing '<' token"); } + std::optional<Token> EndToken = NextToken(LessToken); - for (std::optional<Token> GreaterToken = NextToken(EndToken); - GreaterToken && GreaterToken->getKind() != clang::tok::greater; + std::optional<Token> GreaterToken = NextToken(EndToken); + for (; GreaterToken && GreaterToken->getKind() != clang::tok::greater; GreaterToken = NextToken(GreaterToken)) { EndToken = GreaterToken; } @@ -79,12 +77,21 @@ EditGenerator rewrite(RangeSelector Call, RangeSelector Builder, "missing '>' token"); } + std::optional<Token> ArgStart = NextToken(GreaterToken); + if (!ArgStart || ArgStart->getKind() != clang::tok::l_paren) { + return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument, + "missing '(' token"); + } + std::optional<Token> Arg = NextToken(ArgStart); + if (!Arg) { + return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument, + "unexpected end of file"); + } + const bool HasArgs = Arg->getKind() != clang::tok::r_paren; + Expected<CharSourceRange> BuilderRange = Builder(Result); if (!BuilderRange) return BuilderRange.takeError(); - Expected<CharSourceRange> CallArgsRange = CallArgs(Result); - if (!CallArgsRange) - return CallArgsRange.takeError(); // Helper for concatting below. auto GetText = [&](const CharSourceRange &Range) { @@ -93,43 +100,42 @@ EditGenerator rewrite(RangeSelector Call, RangeSelector Builder, Edit Replace; Replace.Kind = EditKind::Range; - Replace.Range = *CallRange; - std::string CallArgsStr; - // Only emit args if there are any. - if (auto CallArgsText = GetText(*CallArgsRange).ltrim(); - !CallArgsText.rtrim().empty()) { - CallArgsStr = llvm::formatv(", {}", CallArgsText); + Replace.Range.setBegin(CallRange->getBegin()); + Replace.Range.setEnd(ArgStart->getEndLoc()); + const Expr *BuilderExpr = Result.Nodes.getNodeAs<Expr>("builder"); + std::string BuilderText = GetText(*BuilderRange).str(); + if (BuilderExpr->getType()->isPointerType()) { + BuilderText = BuilderExpr->isImplicitCXXThis() + ? "*this" + : llvm::formatv("*{}", BuilderText).str(); } - Replace.Replacement = - llvm::formatv("{}::create({}{})", - GetText(CharSourceRange::getTokenRange( - LessToken->getEndLoc(), EndToken->getLastLoc())), - GetText(*BuilderRange), CallArgsStr); + const StringRef OpType = GetText(CharSourceRange::getTokenRange( + LessToken->getEndLoc(), EndToken->getLastLoc())); + Replace.Replacement = llvm::formatv("{}::create({}{}", OpType, BuilderText, + HasArgs ? ", " : ""); return SmallVector<Edit, 1>({Replace}); }; } -RewriteRuleWith<std::string> useNewMlirOpBuilderCheckRule() { +static RewriteRuleWith<std::string> useNewMlirOpBuilderCheckRule() { Stencil Message = cat("use 'OpType::create(builder, ...)' instead of " "'builder.create<OpType>(...)'"); // Match a create call on an OpBuilder. + auto BuilderType = cxxRecordDecl(isSameOrDerivedFrom("::mlir::OpBuilder")); ast_matchers::internal::Matcher<Stmt> Base = cxxMemberCallExpr( - on(expr(hasType( - cxxRecordDecl(isSameOrDerivedFrom("::mlir::OpBuilder")))) + on(expr(anyOf(hasType(BuilderType), hasType(pointsTo(BuilderType)))) .bind("builder")), - callee(cxxMethodDecl(hasTemplateArgument(0, templateArgument()))), - callee(cxxMethodDecl(hasName("create")))) + callee(cxxMethodDecl(hasTemplateArgument(0, templateArgument()), + hasName("create")))) .bind("call"); return applyFirst( // Attempt rewrite given an lvalue builder, else just warn. {makeRule(cxxMemberCallExpr(unless(on(cxxTemporaryObjectExpr())), Base), - rewrite(node("call"), node("builder"), callArgs("call")), - Message), + rewrite(node("call"), node("builder")), Message), makeRule(Base, noopEdit(node("call")), Message)}); } -} // namespace UseNewMlirOpBuilderCheck::UseNewMlirOpBuilderCheck(StringRef Name, ClangTidyContext *Context) |
