aboutsummaryrefslogtreecommitdiff
path: root/clang-tools-extra/clang-tidy/llvm/UseNewMLIROpBuilderCheck.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'clang-tools-extra/clang-tidy/llvm/UseNewMLIROpBuilderCheck.cpp')
-rw-r--r--clang-tools-extra/clang-tidy/llvm/UseNewMLIROpBuilderCheck.cpp68
1 files changed, 37 insertions, 31 deletions
diff --git a/clang-tools-extra/clang-tidy/llvm/UseNewMLIROpBuilderCheck.cpp b/clang-tools-extra/clang-tidy/llvm/UseNewMLIROpBuilderCheck.cpp
index bd51cc5..0014153 100644
--- a/clang-tools-extra/clang-tidy/llvm/UseNewMLIROpBuilderCheck.cpp
+++ b/clang-tools-extra/clang-tidy/llvm/UseNewMLIROpBuilderCheck.cpp
@@ -18,18 +18,15 @@
#include "llvm/Support/FormatVariadic.h"
namespace clang::tidy::llvm_check {
-namespace {
using namespace ::clang::ast_matchers;
using namespace ::clang::transformer;
-EditGenerator rewrite(RangeSelector Call, RangeSelector Builder,
- RangeSelector CallArgs) {
+static EditGenerator rewrite(RangeSelector Call, RangeSelector Builder) {
// This is using an EditGenerator rather than ASTEdit as we want to warn even
// if in macro.
- return [Call = std::move(Call), Builder = std::move(Builder),
- CallArgs =
- std::move(CallArgs)](const MatchFinder::MatchResult &Result)
+ return [Call = std::move(Call),
+ Builder = std::move(Builder)](const MatchFinder::MatchResult &Result)
-> Expected<SmallVector<transformer::Edit, 1>> {
Expected<CharSourceRange> CallRange = Call(Result);
if (!CallRange)
@@ -54,7 +51,7 @@ EditGenerator rewrite(RangeSelector Call, RangeSelector Builder,
auto NextToken = [&](std::optional<Token> CurrentToken) {
if (!CurrentToken)
return CurrentToken;
- if (CurrentToken->getEndLoc() >= CallRange->getEnd())
+ if (CurrentToken->is(clang::tok::eof))
return std::optional<Token>();
return clang::Lexer::findNextToken(CurrentToken->getLocation(), SM,
LangOpts);
@@ -68,9 +65,10 @@ EditGenerator rewrite(RangeSelector Call, RangeSelector Builder,
return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument,
"missing '<' token");
}
+
std::optional<Token> EndToken = NextToken(LessToken);
- for (std::optional<Token> GreaterToken = NextToken(EndToken);
- GreaterToken && GreaterToken->getKind() != clang::tok::greater;
+ std::optional<Token> GreaterToken = NextToken(EndToken);
+ for (; GreaterToken && GreaterToken->getKind() != clang::tok::greater;
GreaterToken = NextToken(GreaterToken)) {
EndToken = GreaterToken;
}
@@ -79,12 +77,21 @@ EditGenerator rewrite(RangeSelector Call, RangeSelector Builder,
"missing '>' token");
}
+ std::optional<Token> ArgStart = NextToken(GreaterToken);
+ if (!ArgStart || ArgStart->getKind() != clang::tok::l_paren) {
+ return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument,
+ "missing '(' token");
+ }
+ std::optional<Token> Arg = NextToken(ArgStart);
+ if (!Arg) {
+ return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument,
+ "unexpected end of file");
+ }
+ const bool HasArgs = Arg->getKind() != clang::tok::r_paren;
+
Expected<CharSourceRange> BuilderRange = Builder(Result);
if (!BuilderRange)
return BuilderRange.takeError();
- Expected<CharSourceRange> CallArgsRange = CallArgs(Result);
- if (!CallArgsRange)
- return CallArgsRange.takeError();
// Helper for concatting below.
auto GetText = [&](const CharSourceRange &Range) {
@@ -93,43 +100,42 @@ EditGenerator rewrite(RangeSelector Call, RangeSelector Builder,
Edit Replace;
Replace.Kind = EditKind::Range;
- Replace.Range = *CallRange;
- std::string CallArgsStr;
- // Only emit args if there are any.
- if (auto CallArgsText = GetText(*CallArgsRange).ltrim();
- !CallArgsText.rtrim().empty()) {
- CallArgsStr = llvm::formatv(", {}", CallArgsText);
+ Replace.Range.setBegin(CallRange->getBegin());
+ Replace.Range.setEnd(ArgStart->getEndLoc());
+ const Expr *BuilderExpr = Result.Nodes.getNodeAs<Expr>("builder");
+ std::string BuilderText = GetText(*BuilderRange).str();
+ if (BuilderExpr->getType()->isPointerType()) {
+ BuilderText = BuilderExpr->isImplicitCXXThis()
+ ? "*this"
+ : llvm::formatv("*{}", BuilderText).str();
}
- Replace.Replacement =
- llvm::formatv("{}::create({}{})",
- GetText(CharSourceRange::getTokenRange(
- LessToken->getEndLoc(), EndToken->getLastLoc())),
- GetText(*BuilderRange), CallArgsStr);
+ const StringRef OpType = GetText(CharSourceRange::getTokenRange(
+ LessToken->getEndLoc(), EndToken->getLastLoc()));
+ Replace.Replacement = llvm::formatv("{}::create({}{}", OpType, BuilderText,
+ HasArgs ? ", " : "");
return SmallVector<Edit, 1>({Replace});
};
}
-RewriteRuleWith<std::string> useNewMlirOpBuilderCheckRule() {
+static RewriteRuleWith<std::string> useNewMlirOpBuilderCheckRule() {
Stencil Message = cat("use 'OpType::create(builder, ...)' instead of "
"'builder.create<OpType>(...)'");
// Match a create call on an OpBuilder.
+ auto BuilderType = cxxRecordDecl(isSameOrDerivedFrom("::mlir::OpBuilder"));
ast_matchers::internal::Matcher<Stmt> Base =
cxxMemberCallExpr(
- on(expr(hasType(
- cxxRecordDecl(isSameOrDerivedFrom("::mlir::OpBuilder"))))
+ on(expr(anyOf(hasType(BuilderType), hasType(pointsTo(BuilderType))))
.bind("builder")),
- callee(cxxMethodDecl(hasTemplateArgument(0, templateArgument()))),
- callee(cxxMethodDecl(hasName("create"))))
+ callee(cxxMethodDecl(hasTemplateArgument(0, templateArgument()),
+ hasName("create"))))
.bind("call");
return applyFirst(
// Attempt rewrite given an lvalue builder, else just warn.
{makeRule(cxxMemberCallExpr(unless(on(cxxTemporaryObjectExpr())), Base),
- rewrite(node("call"), node("builder"), callArgs("call")),
- Message),
+ rewrite(node("call"), node("builder")), Message),
makeRule(Base, noopEdit(node("call")), Message)});
}
-} // namespace
UseNewMlirOpBuilderCheck::UseNewMlirOpBuilderCheck(StringRef Name,
ClangTidyContext *Context)