diff options
Diffstat (limited to 'clang-include-fixer/IncludeFixer.cpp')
-rw-r--r-- | clang-include-fixer/IncludeFixer.cpp | 444 |
1 files changed, 444 insertions, 0 deletions
diff --git a/clang-include-fixer/IncludeFixer.cpp b/clang-include-fixer/IncludeFixer.cpp new file mode 100644 index 00000000..d364021f --- /dev/null +++ b/clang-include-fixer/IncludeFixer.cpp @@ -0,0 +1,444 @@ +//===-- IncludeFixer.cpp - Include inserter based on sema callbacks -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "IncludeFixer.h" +#include "clang/Format/Format.h" +#include "clang/Frontend/CompilerInstance.h" +#include "clang/Lex/HeaderSearch.h" +#include "clang/Lex/Preprocessor.h" +#include "clang/Parse/ParseAST.h" +#include "clang/Sema/Sema.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#define DEBUG_TYPE "clang-include-fixer" + +using namespace clang; + +namespace clang { +namespace include_fixer { +namespace { +/// Manages the parse, gathers include suggestions. +class Action : public clang::ASTFrontendAction { +public: + explicit Action(SymbolIndexManager &SymbolIndexMgr, bool MinimizeIncludePaths) + : SemaSource(SymbolIndexMgr, MinimizeIncludePaths, + /*GenerateDiagnostics=*/false) {} + + std::unique_ptr<clang::ASTConsumer> + CreateASTConsumer(clang::CompilerInstance &Compiler, + StringRef InFile) override { + SemaSource.setFilePath(InFile); + return llvm::make_unique<clang::ASTConsumer>(); + } + + void ExecuteAction() override { + clang::CompilerInstance *Compiler = &getCompilerInstance(); + assert(!Compiler->hasSema() && "CI already has Sema"); + + // Set up our hooks into sema and parse the AST. + if (hasCodeCompletionSupport() && + !Compiler->getFrontendOpts().CodeCompletionAt.FileName.empty()) + Compiler->createCodeCompletionConsumer(); + + clang::CodeCompleteConsumer *CompletionConsumer = nullptr; + if (Compiler->hasCodeCompletionConsumer()) + CompletionConsumer = &Compiler->getCodeCompletionConsumer(); + + Compiler->createSema(getTranslationUnitKind(), CompletionConsumer); + SemaSource.setCompilerInstance(Compiler); + Compiler->getSema().addExternalSource(&SemaSource); + + clang::ParseAST(Compiler->getSema(), Compiler->getFrontendOpts().ShowStats, + Compiler->getFrontendOpts().SkipFunctionBodies); + } + + IncludeFixerContext + getIncludeFixerContext(const clang::SourceManager &SourceManager, + clang::HeaderSearch &HeaderSearch) const { + return SemaSource.getIncludeFixerContext(SourceManager, HeaderSearch, + SemaSource.getMatchedSymbols()); + } + +private: + IncludeFixerSemaSource SemaSource; +}; + +} // namespace + +IncludeFixerActionFactory::IncludeFixerActionFactory( + SymbolIndexManager &SymbolIndexMgr, + std::vector<IncludeFixerContext> &Contexts, StringRef StyleName, + bool MinimizeIncludePaths) + : SymbolIndexMgr(SymbolIndexMgr), Contexts(Contexts), + MinimizeIncludePaths(MinimizeIncludePaths) {} + +IncludeFixerActionFactory::~IncludeFixerActionFactory() = default; + +bool IncludeFixerActionFactory::runInvocation( + std::shared_ptr<clang::CompilerInvocation> Invocation, + clang::FileManager *Files, + std::shared_ptr<clang::PCHContainerOperations> PCHContainerOps, + clang::DiagnosticConsumer *Diagnostics) { + assert(Invocation->getFrontendOpts().Inputs.size() == 1); + + // Set up Clang. + clang::CompilerInstance Compiler(PCHContainerOps); + Compiler.setInvocation(std::move(Invocation)); + Compiler.setFileManager(Files); + + // Create the compiler's actual diagnostics engine. We want to drop all + // diagnostics here. + Compiler.createDiagnostics(new clang::IgnoringDiagConsumer, + /*ShouldOwnClient=*/true); + Compiler.createSourceManager(*Files); + + // We abort on fatal errors so don't let a large number of errors become + // fatal. A missing #include can cause thousands of errors. + Compiler.getDiagnostics().setErrorLimit(0); + + // Run the parser, gather missing includes. + auto ScopedToolAction = + llvm::make_unique<Action>(SymbolIndexMgr, MinimizeIncludePaths); + Compiler.ExecuteAction(*ScopedToolAction); + + Contexts.push_back(ScopedToolAction->getIncludeFixerContext( + Compiler.getSourceManager(), + Compiler.getPreprocessor().getHeaderSearchInfo())); + + // Technically this should only return true if we're sure that we have a + // parseable file. We don't know that though. Only inform users of fatal + // errors. + return !Compiler.getDiagnostics().hasFatalErrorOccurred(); +} + +static bool addDiagnosticsForContext(TypoCorrection &Correction, + const IncludeFixerContext &Context, + StringRef Code, SourceLocation StartOfFile, + ASTContext &Ctx) { + auto Reps = createIncludeFixerReplacements( + Code, Context, format::getLLVMStyle(), /*AddQualifiers=*/false); + if (!Reps || Reps->size() != 1) + return false; + + unsigned DiagID = Ctx.getDiagnostics().getCustomDiagID( + DiagnosticsEngine::Note, "Add '#include %0' to provide the missing " + "declaration [clang-include-fixer]"); + + // FIXME: Currently we only generate a diagnostic for the first header. Give + // the user choices. + const tooling::Replacement &Placed = *Reps->begin(); + + auto Begin = StartOfFile.getLocWithOffset(Placed.getOffset()); + auto End = Begin.getLocWithOffset(std::max(0, (int)Placed.getLength() - 1)); + PartialDiagnostic PD(DiagID, Ctx.getDiagAllocator()); + PD << Context.getHeaderInfos().front().Header + << FixItHint::CreateReplacement(CharSourceRange::getCharRange(Begin, End), + Placed.getReplacementText()); + Correction.addExtraDiagnostic(std::move(PD)); + return true; +} + +/// Callback for incomplete types. If we encounter a forward declaration we +/// have the fully qualified name ready. Just query that. +bool IncludeFixerSemaSource::MaybeDiagnoseMissingCompleteType( + clang::SourceLocation Loc, clang::QualType T) { + // Ignore spurious callbacks from SFINAE contexts. + if (CI->getSema().isSFINAEContext()) + return false; + + clang::ASTContext &context = CI->getASTContext(); + std::string QueryString = QualType(T->getUnqualifiedDesugaredType(), 0) + .getAsString(context.getPrintingPolicy()); + LLVM_DEBUG(llvm::dbgs() << "Query missing complete type '" << QueryString + << "'"); + // Pass an empty range here since we don't add qualifier in this case. + std::vector<find_all_symbols::SymbolInfo> MatchedSymbols = + query(QueryString, "", tooling::Range()); + + if (!MatchedSymbols.empty() && GenerateDiagnostics) { + TypoCorrection Correction; + FileID FID = CI->getSourceManager().getFileID(Loc); + StringRef Code = CI->getSourceManager().getBufferData(FID); + SourceLocation StartOfFile = + CI->getSourceManager().getLocForStartOfFile(FID); + addDiagnosticsForContext( + Correction, + getIncludeFixerContext(CI->getSourceManager(), + CI->getPreprocessor().getHeaderSearchInfo(), + MatchedSymbols), + Code, StartOfFile, CI->getASTContext()); + for (const PartialDiagnostic &PD : Correction.getExtraDiagnostics()) + CI->getSema().Diag(Loc, PD); + } + return true; +} + +/// Callback for unknown identifiers. Try to piece together as much +/// qualification as we can get and do a query. +clang::TypoCorrection IncludeFixerSemaSource::CorrectTypo( + const DeclarationNameInfo &Typo, int LookupKind, Scope *S, CXXScopeSpec *SS, + CorrectionCandidateCallback &CCC, DeclContext *MemberContext, + bool EnteringContext, const ObjCObjectPointerType *OPT) { + // Ignore spurious callbacks from SFINAE contexts. + if (CI->getSema().isSFINAEContext()) + return clang::TypoCorrection(); + + // We currently ignore the unidentified symbol which is not from the + // main file. + // + // However, this is not always true due to templates in a non-self contained + // header, consider the case: + // + // // header.h + // template <typename T> + // class Foo { + // T t; + // }; + // + // // test.cc + // // We need to add <bar.h> in test.cc instead of header.h. + // class Bar; + // Foo<Bar> foo; + // + // FIXME: Add the missing header to the header file where the symbol comes + // from. + if (!CI->getSourceManager().isWrittenInMainFile(Typo.getLoc())) + return clang::TypoCorrection(); + + std::string TypoScopeString; + if (S) { + // FIXME: Currently we only use namespace contexts. Use other context + // types for query. + for (const auto *Context = S->getEntity(); Context; + Context = Context->getParent()) { + if (const auto *ND = dyn_cast<NamespaceDecl>(Context)) { + if (!ND->getName().empty()) + TypoScopeString = ND->getNameAsString() + "::" + TypoScopeString; + } + } + } + + auto ExtendNestedNameSpecifier = [this](CharSourceRange Range) { + StringRef Source = + Lexer::getSourceText(Range, CI->getSourceManager(), CI->getLangOpts()); + + // Skip forward until we find a character that's neither identifier nor + // colon. This is a bit of a hack around the fact that we will only get a + // single callback for a long nested name if a part of the beginning is + // unknown. For example: + // + // llvm::sys::path::parent_path(...) + // ^~~~ ^~~ + // known + // ^~~~ + // unknown, last callback + // ^~~~~~~~~~~ + // no callback + // + // With the extension we get the full nested name specifier including + // parent_path. + // FIXME: Don't rely on source text. + const char *End = Source.end(); + while (isIdentifierBody(*End) || *End == ':') + ++End; + + return std::string(Source.begin(), End); + }; + + /// If we have a scope specification, use that to get more precise results. + std::string QueryString; + tooling::Range SymbolRange; + const auto &SM = CI->getSourceManager(); + auto CreateToolingRange = [&QueryString, &SM](SourceLocation BeginLoc) { + return tooling::Range(SM.getDecomposedLoc(BeginLoc).second, + QueryString.size()); + }; + if (SS && SS->getRange().isValid()) { + auto Range = CharSourceRange::getTokenRange(SS->getRange().getBegin(), + Typo.getLoc()); + + QueryString = ExtendNestedNameSpecifier(Range); + SymbolRange = CreateToolingRange(Range.getBegin()); + } else if (Typo.getName().isIdentifier() && !Typo.getLoc().isMacroID()) { + auto Range = + CharSourceRange::getTokenRange(Typo.getBeginLoc(), Typo.getEndLoc()); + + QueryString = ExtendNestedNameSpecifier(Range); + SymbolRange = CreateToolingRange(Range.getBegin()); + } else { + QueryString = Typo.getAsString(); + SymbolRange = CreateToolingRange(Typo.getLoc()); + } + + LLVM_DEBUG(llvm::dbgs() << "TypoScopeQualifiers: " << TypoScopeString + << "\n"); + std::vector<find_all_symbols::SymbolInfo> MatchedSymbols = + query(QueryString, TypoScopeString, SymbolRange); + + if (!MatchedSymbols.empty() && GenerateDiagnostics) { + TypoCorrection Correction(Typo.getName()); + Correction.setCorrectionRange(SS, Typo); + FileID FID = SM.getFileID(Typo.getLoc()); + StringRef Code = SM.getBufferData(FID); + SourceLocation StartOfFile = SM.getLocForStartOfFile(FID); + if (addDiagnosticsForContext( + Correction, getIncludeFixerContext( + SM, CI->getPreprocessor().getHeaderSearchInfo(), + MatchedSymbols), + Code, StartOfFile, CI->getASTContext())) + return Correction; + } + return TypoCorrection(); +} + +/// Get the minimal include for a given path. +std::string IncludeFixerSemaSource::minimizeInclude( + StringRef Include, const clang::SourceManager &SourceManager, + clang::HeaderSearch &HeaderSearch) const { + if (!MinimizeIncludePaths) + return Include; + + // Get the FileEntry for the include. + StringRef StrippedInclude = Include.trim("\"<>"); + const FileEntry *Entry = + SourceManager.getFileManager().getFile(StrippedInclude); + + // If the file doesn't exist return the path from the database. + // FIXME: This should never happen. + if (!Entry) + return Include; + + bool IsSystem; + std::string Suggestion = + HeaderSearch.suggestPathToFileForDiagnostics(Entry, &IsSystem); + + return IsSystem ? '<' + Suggestion + '>' : '"' + Suggestion + '"'; +} + +/// Get the include fixer context for the queried symbol. +IncludeFixerContext IncludeFixerSemaSource::getIncludeFixerContext( + const clang::SourceManager &SourceManager, + clang::HeaderSearch &HeaderSearch, + ArrayRef<find_all_symbols::SymbolInfo> MatchedSymbols) const { + std::vector<find_all_symbols::SymbolInfo> SymbolCandidates; + for (const auto &Symbol : MatchedSymbols) { + std::string FilePath = Symbol.getFilePath().str(); + std::string MinimizedFilePath = minimizeInclude( + ((FilePath[0] == '"' || FilePath[0] == '<') ? FilePath + : "\"" + FilePath + "\""), + SourceManager, HeaderSearch); + SymbolCandidates.emplace_back(Symbol.getName(), Symbol.getSymbolKind(), + MinimizedFilePath, Symbol.getContexts()); + } + return IncludeFixerContext(FilePath, QuerySymbolInfos, SymbolCandidates); +} + +std::vector<find_all_symbols::SymbolInfo> +IncludeFixerSemaSource::query(StringRef Query, StringRef ScopedQualifiers, + tooling::Range Range) { + assert(!Query.empty() && "Empty query!"); + + // Save all instances of an unidentified symbol. + // + // We use conservative behavior for detecting the same unidentified symbol + // here. The symbols which have the same ScopedQualifier and RawIdentifier + // are considered equal. So that clang-include-fixer avoids false positives, + // and always adds missing qualifiers to correct symbols. + if (!GenerateDiagnostics && !QuerySymbolInfos.empty()) { + if (ScopedQualifiers == QuerySymbolInfos.front().ScopedQualifiers && + Query == QuerySymbolInfos.front().RawIdentifier) { + QuerySymbolInfos.push_back({Query.str(), ScopedQualifiers, Range}); + } + return {}; + } + + LLVM_DEBUG(llvm::dbgs() << "Looking up '" << Query << "' at "); + LLVM_DEBUG(CI->getSourceManager() + .getLocForStartOfFile(CI->getSourceManager().getMainFileID()) + .getLocWithOffset(Range.getOffset()) + .print(llvm::dbgs(), CI->getSourceManager())); + LLVM_DEBUG(llvm::dbgs() << " ..."); + llvm::StringRef FileName = CI->getSourceManager().getFilename( + CI->getSourceManager().getLocForStartOfFile( + CI->getSourceManager().getMainFileID())); + + QuerySymbolInfos.push_back({Query.str(), ScopedQualifiers, Range}); + + // Query the symbol based on C++ name Lookup rules. + // Firstly, lookup the identifier with scoped namespace contexts; + // If that fails, falls back to look up the identifier directly. + // + // For example: + // + // namespace a { + // b::foo f; + // } + // + // 1. lookup a::b::foo. + // 2. lookup b::foo. + std::string QueryString = ScopedQualifiers.str() + Query.str(); + // It's unsafe to do nested search for the identifier with scoped namespace + // context, it might treat the identifier as a nested class of the scoped + // namespace. + std::vector<find_all_symbols::SymbolInfo> MatchedSymbols = + SymbolIndexMgr.search(QueryString, /*IsNestedSearch=*/false, FileName); + if (MatchedSymbols.empty()) + MatchedSymbols = + SymbolIndexMgr.search(Query, /*IsNestedSearch=*/true, FileName); + LLVM_DEBUG(llvm::dbgs() << "Having found " << MatchedSymbols.size() + << " symbols\n"); + // We store a copy of MatchedSymbols in a place where it's globally reachable. + // This is used by the standalone version of the tool. + this->MatchedSymbols = MatchedSymbols; + return MatchedSymbols; +} + +llvm::Expected<tooling::Replacements> createIncludeFixerReplacements( + StringRef Code, const IncludeFixerContext &Context, + const clang::format::FormatStyle &Style, bool AddQualifiers) { + if (Context.getHeaderInfos().empty()) + return tooling::Replacements(); + StringRef FilePath = Context.getFilePath(); + std::string IncludeName = + "#include " + Context.getHeaderInfos().front().Header + "\n"; + // Create replacements for the new header. + clang::tooling::Replacements Insertions; + auto Err = + Insertions.add(tooling::Replacement(FilePath, UINT_MAX, 0, IncludeName)); + if (Err) + return std::move(Err); + + auto CleanReplaces = cleanupAroundReplacements(Code, Insertions, Style); + if (!CleanReplaces) + return CleanReplaces; + + auto Replaces = std::move(*CleanReplaces); + if (AddQualifiers) { + for (const auto &Info : Context.getQuerySymbolInfos()) { + // Ignore the empty range. + if (Info.Range.getLength() > 0) { + auto R = tooling::Replacement( + {FilePath, Info.Range.getOffset(), Info.Range.getLength(), + Context.getHeaderInfos().front().QualifiedName}); + auto Err = Replaces.add(R); + if (Err) { + llvm::consumeError(std::move(Err)); + R = tooling::Replacement( + R.getFilePath(), Replaces.getShiftedCodePosition(R.getOffset()), + R.getLength(), R.getReplacementText()); + Replaces = Replaces.merge(tooling::Replacements(R)); + } + } + } + } + return formatReplacements(Code, Replaces, Style); +} + +} // namespace include_fixer +} // namespace clang |