summaryrefslogtreecommitdiffstats
path: root/include/clang/Tooling/Refactoring/RecursiveSymbolVisitor.h
blob: 8b01a61256f6bbd52771de9b0d8ee835061a3c71 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
//===--- RecursiveSymbolVisitor.h - Clang refactoring library -------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
///
/// \file
/// \brief A wrapper class around \c RecursiveASTVisitor that visits each
/// occurrence of a named symbol.
///
//===----------------------------------------------------------------------===//

#ifndef LLVM_CLANG_TOOLING_REFACTOR_RECURSIVE_SYMBOL_VISITOR_H
#define LLVM_CLANG_TOOLING_REFACTOR_RECURSIVE_SYMBOL_VISITOR_H

#include "clang/AST/AST.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/Lex/Lexer.h"

namespace clang {
namespace tooling {

/// Traverses the AST and visits the occurrence of each named symbol in the
/// given nodes.
///
/// This is a CRTP base: \p T must derive from \c RecursiveSymbolVisitor<T>.
/// Every occurrence found is forwarded, non-virtually through
/// \c static_cast<T *>(this), to \c T::visitSymbolOccurrence (or to the
/// default below if \p T does not shadow it).
///
/// Note that \c RecursiveASTVisitor is instantiated with this class, not with
/// \p T, so the traversal only dispatches the \c Visit* / \c Traverse*
/// methods declared here; \p T customizes behavior via
/// \c visitSymbolOccurrence.
///
/// Returning \c false from \c visitSymbolOccurrence makes the enclosing
/// \c Visit* / \c Traverse* method return \c false, which aborts the
/// traversal.
template <typename T>
class RecursiveSymbolVisitor
    : public RecursiveASTVisitor<RecursiveSymbolVisitor<T>> {
  using BaseType = RecursiveASTVisitor<RecursiveSymbolVisitor<T>>;

public:
  RecursiveSymbolVisitor(const SourceManager &SM, const LangOptions &LangOpts)
      : SM(SM), LangOpts(LangOpts) {}

  /// Called for every occurrence of a named symbol.
  ///
  /// \param ND The declaration the occurrence refers to. May be null, e.g.
  /// for a nested name specifier that does not name a namespace.
  /// \param NameRanges The source range(s) spelling the name. Every call made
  /// by this class passes exactly one range.
  ///
  /// \returns \c true to continue the traversal, \c false to stop it. The
  /// default implementation ignores the occurrence and continues.
  bool visitSymbolOccurrence(const NamedDecl *ND,
                             ArrayRef<SourceRange> NameRanges) {
    return true;
  }

  // Declaration visitors:

  /// Reports the name of every declaration at its declared location.
  bool VisitNamedDecl(const NamedDecl *D) {
    // NOTE(review): conversion functions are skipped, presumably because their
    // printed name (e.g. "operator int") is not a single token of that length
    // at getLocation().
    return isa<CXXConversionDecl>(D) ? true : visit(D, D->getLocation());
  }

  /// Reports the fields named in a constructor's explicitly written member
  /// initializers.
  bool VisitCXXConstructorDecl(const CXXConstructorDecl *CD) {
    for (const auto *Initializer : CD->inits()) {
      // Ignore implicit initializers.
      if (!Initializer->isWritten())
        continue;
      // Initializers that do not name a field (getMember() is null) are
      // skipped.
      if (const FieldDecl *FD = Initializer->getMember()) {
        if (!visit(FD, Initializer->getSourceLocation(),
                   Lexer::getLocForEndOfToken(Initializer->getSourceLocation(),
                                              0, SM, LangOpts)))
          return false;
      }
    }
    return true;
  }

  // Expression visitors:

  /// Reports a reference to a declaration, using the declaration found by
  /// name lookup.
  bool VisitDeclRefExpr(const DeclRefExpr *Expr) {
    return visit(Expr->getFoundDecl(), Expr->getLocation());
  }

  /// Reports the member named in a member access expression.
  bool VisitMemberExpr(const MemberExpr *Expr) {
    return visit(Expr->getFoundDecl().getDecl(), Expr->getMemberLoc());
  }

  // Other visitors:

  /// Reports the declaration named by the first token of a type: a template
  /// type parameter, the template of a template specialization, or a record
  /// whose own name is spelled here.
  bool VisitTypeLoc(const TypeLoc Loc) {
    const SourceLocation TypeBeginLoc = Loc.getBeginLoc();
    const SourceLocation TypeEndLoc =
        Lexer::getLocForEndOfToken(TypeBeginLoc, 0, SM, LangOpts);
    if (const auto *TemplateTypeParm =
            dyn_cast<TemplateTypeParmType>(Loc.getType())) {
      if (!visit(TemplateTypeParm->getDecl(), TypeBeginLoc, TypeEndLoc))
        return false;
    }
    if (const auto *TemplateSpecType =
            dyn_cast<TemplateSpecializationType>(Loc.getType())) {
      if (!visit(TemplateSpecType->getTemplateName().getAsTemplateDecl(),
                 TypeBeginLoc, TypeEndLoc))
        return false;
    }
    // Only attribute a record to this location when the TypeLoc spells the
    // record's own name. getAsCXXRecordDecl() looks through type sugar, so for
    // a typedef name, 'auto', decltype(...) or a qualified name it would
    // report the record at a token that is not the record's name (e.g. the
    // 'auto' in 'auto X = Foo();' would be reported as an occurrence of Foo).
    switch (Loc.getTypeLocClass()) {
    case TypeLoc::Record:
    case TypeLoc::InjectedClassName:
      return visit(Loc.getType()->getAsCXXRecordDecl(), TypeBeginLoc,
                   TypeEndLoc);
    default:
      return true;
    }
  }

  /// Reports a use of a typedef or alias name as an occurrence of the
  /// typedef/alias declaration itself, not of the type it stands for.
  bool VisitTypedefTypeLoc(TypedefTypeLoc TL) {
    const SourceLocation TypeBeginLoc = TL.getBeginLoc();
    const SourceLocation TypeEndLoc =
        Lexer::getLocForEndOfToken(TypeBeginLoc, 0, SM, LangOpts);
    return visit(TL.getTypedefNameDecl(), TypeBeginLoc, TypeEndLoc);
  }

  /// Reports the namespace named by each component of a nested name
  /// specifier, then continues the default traversal.
  bool TraverseNestedNameSpecifierLoc(NestedNameSpecifierLoc NNS) {
    // The base visitor will visit NNSL prefixes, so we should only look at
    // the current NNS.
    if (NNS) {
      // Null when this component does not name a namespace; it is still
      // forwarded to visitSymbolOccurrence.
      const NamespaceDecl *ND = NNS.getNestedNameSpecifier()->getAsNamespace();
      if (!visit(ND, NNS.getLocalBeginLoc(), NNS.getLocalEndLoc()))
        return false;
    }
    return BaseType::TraverseNestedNameSpecifierLoc(NNS);
  }

private:
  const SourceManager &SM;
  const LangOptions &LangOpts;

  /// Forwards one occurrence spanning [BeginLoc, EndLoc] to the derived
  /// visitor. \p ND may be null.
  bool visit(const NamedDecl *ND, SourceLocation BeginLoc,
             SourceLocation EndLoc) {
    return static_cast<T *>(this)->visitSymbolOccurrence(
        ND, SourceRange(BeginLoc, EndLoc));
  }

  /// Reports an occurrence of \p ND whose name is spelled starting at \p Loc.
  ///
  /// The end is computed from the length of ND->getNameAsString(), i.e. this
  /// assumes the name is spelled at \p Loc exactly as printed.
  ///
  /// NOTE(review): the end produced here is the location of the name's last
  /// character, whereas the callers of the three-argument overload pass the
  /// location just past the token (Lexer::getLocForEndOfToken). Consumers
  /// currently see both conventions.
  bool visit(const NamedDecl *ND, SourceLocation Loc) {
    // All current callers pass a non-null declaration, but the overload above
    // forwards null declarations too; do the same instead of dereferencing.
    if (!ND)
      return visit(ND, Loc, Loc);
    const size_t NameLength = ND->getNameAsString().size();
    // A declaration without a name yields an empty string; 'NameLength - 1'
    // would then wrap around and produce an end before the begin.
    const SourceLocation EndLoc =
        NameLength == 0 ? Loc : Loc.getLocWithOffset(NameLength - 1);
    return visit(ND, Loc, EndLoc);
  }
};

} // end namespace tooling
} // end namespace clang

#endif // LLVM_CLANG_TOOLING_REFACTOR_RECURSIVE_SYMBOL_VISITOR_H