aboutsummaryrefslogtreecommitdiff
path: root/clang-tools-extra/clang-tidy/llvm/UseNewMLIROpBuilderCheck.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'clang-tools-extra/clang-tidy/llvm/UseNewMLIROpBuilderCheck.cpp')
-rw-r--r--clang-tools-extra/clang-tidy/llvm/UseNewMLIROpBuilderCheck.cpp139
1 files changed, 139 insertions, 0 deletions
diff --git a/clang-tools-extra/clang-tidy/llvm/UseNewMLIROpBuilderCheck.cpp b/clang-tools-extra/clang-tidy/llvm/UseNewMLIROpBuilderCheck.cpp
new file mode 100644
index 0000000..4722199
--- /dev/null
+++ b/clang-tools-extra/clang-tidy/llvm/UseNewMLIROpBuilderCheck.cpp
@@ -0,0 +1,139 @@
+//===--- UseNewMLIROpBuilderCheck.cpp - clang-tidy ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "UseNewMLIROpBuilderCheck.h"
+#include "clang/ASTMatchers/ASTMatchers.h"
+#include "clang/Basic/LLVM.h"
+#include "clang/Lex/Lexer.h"
+#include "clang/Tooling/Transformer/RangeSelector.h"
+#include "clang/Tooling/Transformer/RewriteRule.h"
+#include "clang/Tooling/Transformer/SourceCode.h"
+#include "clang/Tooling/Transformer/Stencil.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/FormatVariadic.h"
+
+namespace clang::tidy::llvm_check {
+namespace {
+
+using namespace ::clang::ast_matchers;
+using namespace ::clang::transformer;
+
+EditGenerator rewrite(RangeSelector Call, RangeSelector Builder,
+ RangeSelector CallArgs) {
+ // This is using an EditGenerator rather than ASTEdit as we want to warn even
+ // if in macro.
+ return [Call = std::move(Call), Builder = std::move(Builder),
+ CallArgs =
+ std::move(CallArgs)](const MatchFinder::MatchResult &Result)
+ -> Expected<SmallVector<transformer::Edit, 1>> {
+ Expected<CharSourceRange> CallRange = Call(Result);
+ if (!CallRange)
+ return CallRange.takeError();
+ SourceManager &SM = *Result.SourceManager;
+ const LangOptions &LangOpts = Result.Context->getLangOpts();
+ SourceLocation Begin = CallRange->getBegin();
+
+ // This will result in just a warning and no edit.
+ bool InMacro = CallRange->getBegin().isMacroID();
+ if (InMacro) {
+ while (SM.isMacroArgExpansion(Begin))
+ Begin = SM.getImmediateExpansionRange(Begin).getBegin();
+ Edit WarnOnly;
+ WarnOnly.Kind = EditKind::Range;
+ WarnOnly.Range = CharSourceRange::getCharRange(Begin, Begin);
+ return SmallVector<Edit, 1>({WarnOnly});
+ }
+
+ // This will try to extract the template argument as written so that the
+ // rewritten code looks closest to original.
+ auto NextToken = [&](std::optional<Token> CurrentToken) {
+ if (!CurrentToken)
+ return CurrentToken;
+ if (CurrentToken->getEndLoc() >= CallRange->getEnd())
+ return std::optional<Token>();
+ return clang::Lexer::findNextToken(CurrentToken->getLocation(), SM,
+ LangOpts);
+ };
+ std::optional<Token> LessToken =
+ clang::Lexer::findNextToken(Begin, SM, LangOpts);
+ while (LessToken && LessToken->getKind() != clang::tok::less) {
+ LessToken = NextToken(LessToken);
+ }
+ if (!LessToken) {
+ return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument,
+ "missing '<' token");
+ }
+ std::optional<Token> EndToken = NextToken(LessToken);
+ for (std::optional<Token> GreaterToken = NextToken(EndToken);
+ GreaterToken && GreaterToken->getKind() != clang::tok::greater;
+ GreaterToken = NextToken(GreaterToken)) {
+ EndToken = GreaterToken;
+ }
+ if (!EndToken) {
+ return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument,
+ "missing '>' token");
+ }
+
+ Expected<CharSourceRange> BuilderRange = Builder(Result);
+ if (!BuilderRange)
+ return BuilderRange.takeError();
+ Expected<CharSourceRange> CallArgsRange = CallArgs(Result);
+ if (!CallArgsRange)
+ return CallArgsRange.takeError();
+
+ // Helper for concatting below.
+ auto GetText = [&](const CharSourceRange &Range) {
+ return clang::Lexer::getSourceText(Range, SM, LangOpts);
+ };
+
+ Edit Replace;
+ Replace.Kind = EditKind::Range;
+ Replace.Range = *CallRange;
+ std::string CallArgsStr;
+ // Only emit args if there are any.
+ if (auto CallArgsText = GetText(*CallArgsRange).ltrim();
+ !CallArgsText.rtrim().empty()) {
+ CallArgsStr = llvm::formatv(", {}", CallArgsText);
+ }
+ Replace.Replacement =
+ llvm::formatv("{}::create({}{})",
+ GetText(CharSourceRange::getTokenRange(
+ LessToken->getEndLoc(), EndToken->getLastLoc())),
+ GetText(*BuilderRange), CallArgsStr);
+
+ return SmallVector<Edit, 1>({Replace});
+ };
+}
+
+RewriteRuleWith<std::string> useNewMlirOpBuilderCheckRule() {
+ Stencil message = cat("use 'OpType::create(builder, ...)' instead of "
+ "'builder.create<OpType>(...)'");
+ // Match a create call on an OpBuilder.
+ ast_matchers::internal::Matcher<Stmt> base =
+ cxxMemberCallExpr(
+ on(expr(hasType(
+ cxxRecordDecl(isSameOrDerivedFrom("::mlir::OpBuilder"))))
+ .bind("builder")),
+ callee(cxxMethodDecl(hasTemplateArgument(0, templateArgument()))),
+ callee(cxxMethodDecl(hasName("create"))))
+ .bind("call");
+ return applyFirst(
+ // Attempt rewrite given an lvalue builder, else just warn.
+ {makeRule(cxxMemberCallExpr(unless(on(cxxTemporaryObjectExpr())), base),
+ rewrite(node("call"), node("builder"), callArgs("call")),
+ message),
+ makeRule(base, noopEdit(node("call")), message)});
+}
+} // namespace
+
+UseNewMlirOpBuilderCheck::UseNewMlirOpBuilderCheck(StringRef Name,
+ ClangTidyContext *Context)
+ : TransformerClangTidyCheck(useNewMlirOpBuilderCheckRule(), Name, Context) {
+}
+
+} // namespace clang::tidy::llvm_check