aboutsummaryrefslogtreecommitdiff
path: root/clang-tools-extra/clang-tidy/misc/CoroutineHostileRAIICheck.cpp
blob: 360335b86c6418edd8e6ee78cbd22bc4d23e8d48 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
//===--- CoroutineHostileRAIICheck.cpp - clang-tidy -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "CoroutineHostileRAIICheck.h"
#include "../utils/OptionsUtils.h"
#include "clang/AST/Attr.h"
#include "clang/AST/Decl.h"
#include "clang/AST/ExprCXX.h"
#include "clang/AST/Stmt.h"
#include "clang/AST/Type.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include "clang/ASTMatchers/ASTMatchers.h"
#include "clang/ASTMatchers/ASTMatchersInternal.h"
#include "clang/Basic/AttrKinds.h"
#include "clang/Basic/DiagnosticIDs.h"

using namespace clang::ast_matchers;
namespace clang::tidy::misc {
namespace {
using clang::ast_matchers::internal::BoundNodesTreeBuilder;

// Matches a statement (here: a suspension point) for which at least one
// earlier sibling statement, in any enclosing CompoundStmt, matches
// InnerMatcher.
//
// Starting at Node, the parent chain is walked upwards. Whenever the parent is
// a CompoundStmt, every sibling preceding the child on that path is tested
// against InnerMatcher. Siblings after the child are skipped: they come after
// the suspension and so do not persist across it. Every matching sibling is
// recorded as a separate match so each offending declaration is reported.
//
// The walk climbs through statements and through variable declarations. For a
// suspension inside an initializer (`auto X = co_await F();`) the parent map
// reports the VarDecl, not a Stmt, as the initializer's parent; the VarDecl's
// own parent is the enclosing DeclStmt. Any other kind of parent (e.g. a
// FunctionDecl) ends the walk.
AST_MATCHER_P(Stmt, forEachPrevStmt, ast_matchers::internal::Matcher<Stmt>,
              InnerMatcher) {
  ASTContext &Ctx = Finder->getASTContext();
  bool IsHostile = false;
  DynTypedNode Current = DynTypedNode::create(Node);
  while (true) {
    auto Parents = Ctx.getParents(Current);
    if (Parents.empty())
      break;
    const DynTypedNode Parent = *Parents.begin();
    if (const auto *PCS = Parent.get<CompoundStmt>()) {
      // Children of a CompoundStmt are statements, so Current is a Stmt here.
      const Stmt *Child = Current.get<Stmt>();
      for (const Stmt *Sibling : PCS->children()) {
        // Child contains the suspension. Siblings after Child do not persist
        // across this suspension.
        if (Sibling == Child)
          break;
        // In case of a match, add the bindings as a separate match. Also don't
        // clear the bindings if a match is not found (unlike
        // Matcher::matches).
        BoundNodesTreeBuilder SiblingBuilder;
        if (InnerMatcher.matches(*Sibling, Finder, &SiblingBuilder)) {
          Builder->addMatch(SiblingBuilder);
          IsHostile = true;
        }
      }
    }
    // Keep climbing through statements and through variable declarations
    // (initializer case); stop at anything else.
    if (!Parent.get<Stmt>() && !Parent.get<VarDecl>())
      break;
    Current = Parent;
  }
  return IsHostile;
}

// Matches a `co_await` expression whose awaited operand satisfies
// InnerMatcher. If the expression has no operand, it does not match.
AST_MATCHER_P(CoawaitExpr, awaitable, ast_matchers::internal::Matcher<Expr>,
              InnerMatcher) {
  const Expr *Operand = Node.getOperand();
  return Operand != nullptr && InnerMatcher.matches(*Operand, Finder, Builder);
}

auto typeWithNameIn(const std::vector<StringRef> &Names) {
  // The declaration behind the expression's canonical type (i.e. after
  // looking through typedefs/aliases) must carry one of the listed names.
  auto DeclNamedInList = namedDecl(hasAnyName(Names));
  auto CanonicalTypeMatches = hasCanonicalType(hasDeclaration(DeclNamedInList));
  return hasType(CanonicalTypeMatches);
}
} // namespace

// Reads the check's options from the clang-tidy configuration.
//
// "RAIITypesList": list of type names (parsed with parseStringList) whose
// variables are reported when they live across a suspension point; defaults
// to "std::lock_guard;std::scoped_lock".
// "AllowedAwaitablesList": list of awaitable type names; a `co_await` on an
// operand of one of these types is not treated as a suspension point by
// registerMatchers. Defaults to an empty list.
CoroutineHostileRAIICheck::CoroutineHostileRAIICheck(StringRef Name,
                                                     ClangTidyContext *Context)
    : ClangTidyCheck(Name, Context),
      RAIITypesList(utils::options::parseStringList(
          Options.get("RAIITypesList", "std::lock_guard;std::scoped_lock"))),
      AllowedAwaitablesList(utils::options::parseStringList(
          Options.get("AllowedAwaitablesList", ""))) {}

void CoroutineHostileRAIICheck::registerMatchers(MatchFinder *Finder) {
  // Variables whose canonical type is declared with the `scoped_lockable`
  // attribute.
  auto LockableVar = varDecl(hasType(hasCanonicalType(hasDeclaration(
                                 hasAttr(attr::Kind::ScopedLockable)))))
                         .bind("scoped-lockable");
  // Variables whose type is named in the RAIITypesList option.
  auto ListedRAIIVar = varDecl(typeWithNameIn(RAIITypesList)).bind("raii");
  // A declaration statement introducing variables of either kind above.
  auto HostileDeclStmt =
      declStmt(forEach(varDecl(anyOf(LockableVar, ListedRAIIVar))));

  // A suspension happens with co_await or co_yield. A co_await whose operand
  // has a type named in AllowedAwaitablesList is exempt.
  auto ExemptAwait = awaitable(typeWithNameIn(AllowedAwaitablesList));
  auto IsSuspension = anyOf(coawaitExpr(unless(ExemptAwait)), coyieldExpr());

  Finder->addMatcher(
      expr(IsSuspension, forEachPrevStmt(HostileDeclStmt)).bind("suspension"),
      this);
}

// Emits a warning for the matched hostile variable (lock-holding or listed
// RAII type), followed by a note pointing at the suspension point.
void CoroutineHostileRAIICheck::check(const MatchFinder::MatchResult &Result) {
  const auto &Bound = Result.Nodes;

  if (const auto *LockVar = Bound.getNodeAs<VarDecl>("scoped-lockable")) {
    diag(LockVar->getLocation(),
         "%0 holds a lock across a suspension point of coroutine and could be "
         "unlocked by a different thread")
        << LockVar;
  }

  if (const auto *RAIIVar = Bound.getNodeAs<VarDecl>("raii")) {
    diag(RAIIVar->getLocation(),
         "%0 persists across a suspension point of coroutine")
        << RAIIVar;
  }

  const auto *Suspension = Bound.getNodeAs<Expr>("suspension");
  if (!Suspension)
    return;
  diag(Suspension->getBeginLoc(), "suspension point is here",
       DiagnosticIDs::Note);
}

// Writes the current option values back into Opts (e.g. for
// --dump-config). The keys must match the ones read in the constructor so
// the configuration round-trips; the awaitables list was previously stored
// under "SafeAwaitableList", a key that is never read back.
void CoroutineHostileRAIICheck::storeOptions(
    ClangTidyOptions::OptionMap &Opts) {
  Options.store(Opts, "RAIITypesList",
                utils::options::serializeStringList(RAIITypesList));
  Options.store(Opts, "AllowedAwaitablesList",
                utils::options::serializeStringList(AllowedAwaitablesList));
}
} // namespace clang::tidy::misc