//===--- UseNewMLIROpBuilderCheck.cpp - clang-tidy ------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "UseNewMLIROpBuilderCheck.h" #include "clang/ASTMatchers/ASTMatchers.h" #include "clang/Basic/LLVM.h" #include "clang/Lex/Lexer.h" #include "clang/Tooling/Transformer/RangeSelector.h" #include "clang/Tooling/Transformer/RewriteRule.h" #include "clang/Tooling/Transformer/SourceCode.h" #include "clang/Tooling/Transformer/Stencil.h" #include "llvm/Support/Error.h" #include "llvm/Support/FormatVariadic.h" namespace clang::tidy::llvm_check { namespace { using namespace ::clang::ast_matchers; using namespace ::clang::transformer; EditGenerator rewrite(RangeSelector Call, RangeSelector Builder, RangeSelector CallArgs) { // This is using an EditGenerator rather than ASTEdit as we want to warn even // if in macro. return [Call = std::move(Call), Builder = std::move(Builder), CallArgs = std::move(CallArgs)](const MatchFinder::MatchResult &Result) -> Expected> { Expected CallRange = Call(Result); if (!CallRange) return CallRange.takeError(); SourceManager &SM = *Result.SourceManager; const LangOptions &LangOpts = Result.Context->getLangOpts(); SourceLocation Begin = CallRange->getBegin(); // This will result in just a warning and no edit. bool InMacro = CallRange->getBegin().isMacroID(); if (InMacro) { while (SM.isMacroArgExpansion(Begin)) Begin = SM.getImmediateExpansionRange(Begin).getBegin(); Edit WarnOnly; WarnOnly.Kind = EditKind::Range; WarnOnly.Range = CharSourceRange::getCharRange(Begin, Begin); return SmallVector({WarnOnly}); } // This will try to extract the template argument as written so that the // rewritten code looks closest to original. auto NextToken = [&](std::optional CurrentToken) { if (!CurrentToken) return CurrentToken; if (CurrentToken->getEndLoc() >= CallRange->getEnd()) return std::optional(); return clang::Lexer::findNextToken(CurrentToken->getLocation(), SM, LangOpts); }; std::optional LessToken = clang::Lexer::findNextToken(Begin, SM, LangOpts); while (LessToken && LessToken->getKind() != clang::tok::less) { LessToken = NextToken(LessToken); } if (!LessToken) { return llvm::make_error(llvm::errc::invalid_argument, "missing '<' token"); } std::optional EndToken = NextToken(LessToken); for (std::optional GreaterToken = NextToken(EndToken); GreaterToken && GreaterToken->getKind() != clang::tok::greater; GreaterToken = NextToken(GreaterToken)) { EndToken = GreaterToken; } if (!EndToken) { return llvm::make_error(llvm::errc::invalid_argument, "missing '>' token"); } Expected BuilderRange = Builder(Result); if (!BuilderRange) return BuilderRange.takeError(); Expected CallArgsRange = CallArgs(Result); if (!CallArgsRange) return CallArgsRange.takeError(); // Helper for concatting below. auto GetText = [&](const CharSourceRange &Range) { return clang::Lexer::getSourceText(Range, SM, LangOpts); }; Edit Replace; Replace.Kind = EditKind::Range; Replace.Range = *CallRange; std::string CallArgsStr; // Only emit args if there are any. if (auto CallArgsText = GetText(*CallArgsRange).ltrim(); !CallArgsText.rtrim().empty()) { CallArgsStr = llvm::formatv(", {}", CallArgsText); } Replace.Replacement = llvm::formatv("{}::create({}{})", GetText(CharSourceRange::getTokenRange( LessToken->getEndLoc(), EndToken->getLastLoc())), GetText(*BuilderRange), CallArgsStr); return SmallVector({Replace}); }; } RewriteRuleWith useNewMlirOpBuilderCheckRule() { Stencil message = cat("use 'OpType::create(builder, ...)' instead of " "'builder.create(...)'"); // Match a create call on an OpBuilder. ast_matchers::internal::Matcher base = cxxMemberCallExpr( on(expr(hasType( cxxRecordDecl(isSameOrDerivedFrom("::mlir::OpBuilder")))) .bind("builder")), callee(cxxMethodDecl(hasTemplateArgument(0, templateArgument()))), callee(cxxMethodDecl(hasName("create")))) .bind("call"); return applyFirst( // Attempt rewrite given an lvalue builder, else just warn. {makeRule(cxxMemberCallExpr(unless(on(cxxTemporaryObjectExpr())), base), rewrite(node("call"), node("builder"), callArgs("call")), message), makeRule(base, noopEdit(node("call")), message)}); } } // namespace UseNewMlirOpBuilderCheck::UseNewMlirOpBuilderCheck(StringRef Name, ClangTidyContext *Context) : TransformerClangTidyCheck(useNewMlirOpBuilderCheckRule(), Name, Context) { } } // namespace clang::tidy::llvm_check