diff options
Diffstat (limited to 'clang/unittests/AST/ASTPrint.h')
-rw-r--r-- | clang/unittests/AST/ASTPrint.h | 88 |
1 files changed, 52 insertions, 36 deletions
diff --git a/clang/unittests/AST/ASTPrint.h b/clang/unittests/AST/ASTPrint.h index c3b6b84..0e35846 100644 --- a/clang/unittests/AST/ASTPrint.h +++ b/clang/unittests/AST/ASTPrint.h @@ -19,72 +19,88 @@ namespace clang { -using PolicyAdjusterType = - Optional<llvm::function_ref<void(PrintingPolicy &Policy)>>; - -static void PrintStmt(raw_ostream &Out, const ASTContext *Context, - const Stmt *S, PolicyAdjusterType PolicyAdjuster) { - assert(S != nullptr && "Expected non-null Stmt"); - PrintingPolicy Policy = Context->getPrintingPolicy(); - if (PolicyAdjuster) - (*PolicyAdjuster)(Policy); - S->printPretty(Out, /*Helper*/ nullptr, Policy); -} +using PrintingPolicyAdjuster = llvm::function_ref<void(PrintingPolicy &Policy)>; + +template <typename NodeType> +using NodePrinter = + std::function<void(llvm::raw_ostream &Out, const ASTContext *Context, + const NodeType *Node, + PrintingPolicyAdjuster PolicyAdjuster)>; + +template <typename NodeType> +using NodeFilter = std::function<bool(const NodeType *Node)>; +template <typename NodeType> class PrintMatch : public ast_matchers::MatchFinder::MatchCallback { + using PrinterT = NodePrinter<NodeType>; + using FilterT = NodeFilter<NodeType>; + SmallString<1024> Printed; - unsigned NumFoundStmts; - PolicyAdjusterType PolicyAdjuster; + unsigned NumFoundNodes; + PrinterT Printer; + FilterT Filter; + PrintingPolicyAdjuster PolicyAdjuster; public: - PrintMatch(PolicyAdjusterType PolicyAdjuster) - : NumFoundStmts(0), PolicyAdjuster(PolicyAdjuster) {} + PrintMatch(PrinterT Printer, PrintingPolicyAdjuster PolicyAdjuster, + FilterT Filter) + : NumFoundNodes(0), Printer(std::move(Printer)), + Filter(std::move(Filter)), PolicyAdjuster(PolicyAdjuster) {} void run(const ast_matchers::MatchFinder::MatchResult &Result) override { - const Stmt *S = Result.Nodes.getNodeAs<Stmt>("id"); - if (!S) + const NodeType *N = Result.Nodes.getNodeAs<NodeType>("id"); + if (!N || !Filter(N)) return; - NumFoundStmts++; - if (NumFoundStmts > 1) + NumFoundNodes++; + if (NumFoundNodes > 1) return; llvm::raw_svector_ostream Out(Printed); - PrintStmt(Out, Result.Context, S, PolicyAdjuster); + Printer(Out, Result.Context, N, PolicyAdjuster); } StringRef getPrinted() const { return Printed; } - unsigned getNumFoundStmts() const { return NumFoundStmts; } + unsigned getNumFoundNodes() const { return NumFoundNodes; } }; -template <typename T> -::testing::AssertionResult -PrintedStmtMatches(StringRef Code, const std::vector<std::string> &Args, - const T &NodeMatch, StringRef ExpectedPrinted, - PolicyAdjusterType PolicyAdjuster = None) { +template <typename NodeType, typename Matcher> +::testing::AssertionResult PrintedNodeMatches( + StringRef Code, const std::vector<std::string> &Args, + const Matcher &NodeMatch, StringRef ExpectedPrinted, StringRef FileName, + NodePrinter<NodeType> Printer, + PrintingPolicyAdjuster PolicyAdjuster = nullptr, bool AllowError = false, + NodeFilter<NodeType> Filter = [](const NodeType *) { return true; }) { - PrintMatch Printer(PolicyAdjuster); + PrintMatch<NodeType> Callback(Printer, PolicyAdjuster, Filter); ast_matchers::MatchFinder Finder; - Finder.addMatcher(NodeMatch, &Printer); + Finder.addMatcher(NodeMatch, &Callback); std::unique_ptr<tooling::FrontendActionFactory> Factory( tooling::newFrontendActionFactory(&Finder)); - if (!tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args)) + bool ToolResult; + if (FileName.empty()) { + ToolResult = tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args); + } else { + ToolResult = + tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args, FileName); + } + if (!ToolResult && !AllowError) return testing::AssertionFailure() << "Parsing error in \"" << Code.str() << "\""; - if (Printer.getNumFoundStmts() == 0) - return testing::AssertionFailure() << "Matcher didn't find any statements"; + if (Callback.getNumFoundNodes() == 0) + return testing::AssertionFailure() << "Matcher didn't find any nodes"; - if (Printer.getNumFoundStmts() > 1) + if (Callback.getNumFoundNodes() > 1) return testing::AssertionFailure() - << "Matcher should match only one statement (found " - << Printer.getNumFoundStmts() << ")"; + << "Matcher should match only one node (found " + << Callback.getNumFoundNodes() << ")"; - if (Printer.getPrinted() != ExpectedPrinted) + if (Callback.getPrinted() != ExpectedPrinted) return ::testing::AssertionFailure() << "Expected \"" << ExpectedPrinted.str() << "\", got \"" - << Printer.getPrinted().str() << "\""; + << Callback.getPrinted().str() << "\""; return ::testing::AssertionSuccess(); } |