aboutsummaryrefslogtreecommitdiff
path: root/clang/unittests/AST/ASTPrint.h
diff options
context:
space:
mode:
Diffstat (limited to 'clang/unittests/AST/ASTPrint.h')
-rw-r--r--clang/unittests/AST/ASTPrint.h88
1 files changed, 52 insertions, 36 deletions
diff --git a/clang/unittests/AST/ASTPrint.h b/clang/unittests/AST/ASTPrint.h
index c3b6b84..0e35846 100644
--- a/clang/unittests/AST/ASTPrint.h
+++ b/clang/unittests/AST/ASTPrint.h
@@ -19,72 +19,88 @@
namespace clang {
-using PolicyAdjusterType =
- Optional<llvm::function_ref<void(PrintingPolicy &Policy)>>;
-
-static void PrintStmt(raw_ostream &Out, const ASTContext *Context,
- const Stmt *S, PolicyAdjusterType PolicyAdjuster) {
- assert(S != nullptr && "Expected non-null Stmt");
- PrintingPolicy Policy = Context->getPrintingPolicy();
- if (PolicyAdjuster)
- (*PolicyAdjuster)(Policy);
- S->printPretty(Out, /*Helper*/ nullptr, Policy);
-}
+using PrintingPolicyAdjuster = llvm::function_ref<void(PrintingPolicy &Policy)>;
+
+template <typename NodeType>
+using NodePrinter =
+ std::function<void(llvm::raw_ostream &Out, const ASTContext *Context,
+ const NodeType *Node,
+ PrintingPolicyAdjuster PolicyAdjuster)>;
+
+template <typename NodeType>
+using NodeFilter = std::function<bool(const NodeType *Node)>;
+template <typename NodeType>
class PrintMatch : public ast_matchers::MatchFinder::MatchCallback {
+ using PrinterT = NodePrinter<NodeType>;
+ using FilterT = NodeFilter<NodeType>;
+
SmallString<1024> Printed;
- unsigned NumFoundStmts;
- PolicyAdjusterType PolicyAdjuster;
+ unsigned NumFoundNodes;
+ PrinterT Printer;
+ FilterT Filter;
+ PrintingPolicyAdjuster PolicyAdjuster;
public:
- PrintMatch(PolicyAdjusterType PolicyAdjuster)
- : NumFoundStmts(0), PolicyAdjuster(PolicyAdjuster) {}
+ PrintMatch(PrinterT Printer, PrintingPolicyAdjuster PolicyAdjuster,
+ FilterT Filter)
+ : NumFoundNodes(0), Printer(std::move(Printer)),
+ Filter(std::move(Filter)), PolicyAdjuster(PolicyAdjuster) {}
void run(const ast_matchers::MatchFinder::MatchResult &Result) override {
- const Stmt *S = Result.Nodes.getNodeAs<Stmt>("id");
- if (!S)
+ const NodeType *N = Result.Nodes.getNodeAs<NodeType>("id");
+ if (!N || !Filter(N))
return;
- NumFoundStmts++;
- if (NumFoundStmts > 1)
+ NumFoundNodes++;
+ if (NumFoundNodes > 1)
return;
llvm::raw_svector_ostream Out(Printed);
- PrintStmt(Out, Result.Context, S, PolicyAdjuster);
+ Printer(Out, Result.Context, N, PolicyAdjuster);
}
StringRef getPrinted() const { return Printed; }
- unsigned getNumFoundStmts() const { return NumFoundStmts; }
+ unsigned getNumFoundNodes() const { return NumFoundNodes; }
};
-template <typename T>
-::testing::AssertionResult
-PrintedStmtMatches(StringRef Code, const std::vector<std::string> &Args,
- const T &NodeMatch, StringRef ExpectedPrinted,
- PolicyAdjusterType PolicyAdjuster = None) {
+template <typename NodeType, typename Matcher>
+::testing::AssertionResult PrintedNodeMatches(
+ StringRef Code, const std::vector<std::string> &Args,
+ const Matcher &NodeMatch, StringRef ExpectedPrinted, StringRef FileName,
+ NodePrinter<NodeType> Printer,
+ PrintingPolicyAdjuster PolicyAdjuster = nullptr, bool AllowError = false,
+ NodeFilter<NodeType> Filter = [](const NodeType *) { return true; }) {
- PrintMatch Printer(PolicyAdjuster);
+ PrintMatch<NodeType> Callback(Printer, PolicyAdjuster, Filter);
ast_matchers::MatchFinder Finder;
- Finder.addMatcher(NodeMatch, &Printer);
+ Finder.addMatcher(NodeMatch, &Callback);
std::unique_ptr<tooling::FrontendActionFactory> Factory(
tooling::newFrontendActionFactory(&Finder));
- if (!tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args))
+ bool ToolResult;
+ if (FileName.empty()) {
+ ToolResult = tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args);
+ } else {
+ ToolResult =
+ tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args, FileName);
+ }
+ if (!ToolResult && !AllowError)
return testing::AssertionFailure()
<< "Parsing error in \"" << Code.str() << "\"";
- if (Printer.getNumFoundStmts() == 0)
- return testing::AssertionFailure() << "Matcher didn't find any statements";
+ if (Callback.getNumFoundNodes() == 0)
+ return testing::AssertionFailure() << "Matcher didn't find any nodes";
- if (Printer.getNumFoundStmts() > 1)
+ if (Callback.getNumFoundNodes() > 1)
return testing::AssertionFailure()
- << "Matcher should match only one statement (found "
- << Printer.getNumFoundStmts() << ")";
+ << "Matcher should match only one node (found "
+ << Callback.getNumFoundNodes() << ")";
- if (Printer.getPrinted() != ExpectedPrinted)
+ if (Callback.getPrinted() != ExpectedPrinted)
return ::testing::AssertionFailure()
<< "Expected \"" << ExpectedPrinted.str() << "\", got \""
- << Printer.getPrinted().str() << "\"";
+ << Callback.getPrinted().str() << "\"";
return ::testing::AssertionSuccess();
}