Mercurial > hg > CbC > CbC_llvm
diff clang-tools-extra/clangd/Quality.cpp @ 221:79ff65ed7e25
LLVM12 Original
author | Shinji KONO <kono@ie.u-ryukyu.ac.jp> |
---|---|
date | Tue, 15 Jun 2021 19:15:29 +0900 |
parents | 0572611fdcc8 |
children | 5f17cb93ff66 |
line wrap: on
line diff
--- a/clang-tools-extra/clangd/Quality.cpp Tue Jun 15 19:13:43 2021 +0900 +++ b/clang-tools-extra/clangd/Quality.cpp Tue Jun 15 19:15:29 2021 +0900 @@ -8,6 +8,7 @@ #include "Quality.h" #include "AST.h" +#include "CompletionModel.h" #include "FileDistance.h" #include "SourceCode.h" #include "URI.h" @@ -200,7 +201,7 @@ ReservedName = ReservedName || isReserved(IndexResult.Name); } -float SymbolQualitySignals::evaluate() const { +float SymbolQualitySignals::evaluateHeuristics() const { float Score = 1; // This avoids a sharp gradient for tail symbols, and also neatly avoids the @@ -252,10 +253,11 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const SymbolQualitySignals &S) { - OS << llvm::formatv("=== Symbol quality: {0}\n", S.evaluate()); + OS << llvm::formatv("=== Symbol quality: {0}\n", S.evaluateHeuristics()); OS << llvm::formatv("\tReferences: {0}\n", S.References); OS << llvm::formatv("\tDeprecated: {0}\n", S.Deprecated); OS << llvm::formatv("\tReserved name: {0}\n", S.ReservedName); + OS << llvm::formatv("\tImplementation detail: {0}\n", S.ImplementationDetail); OS << llvm::formatv("\tCategory: {0}\n", static_cast<int>(S.Category)); return OS; } @@ -292,6 +294,38 @@ if (!(IndexResult.Flags & Symbol::VisibleOutsideFile)) { Scope = AccessibleScope::FileScope; } + if (MainFileSignals) { + MainFileRefs = + std::max(MainFileRefs, + MainFileSignals->ReferencedSymbols.lookup(IndexResult.ID)); + ScopeRefsInFile = + std::max(ScopeRefsInFile, + MainFileSignals->RelatedNamespaces.lookup(IndexResult.Scope)); + } +} + +void SymbolRelevanceSignals::computeASTSignals( + const CodeCompletionResult &SemaResult) { + if (!MainFileSignals) + return; + if ((SemaResult.Kind != CodeCompletionResult::RK_Declaration) && + (SemaResult.Kind != CodeCompletionResult::RK_Pattern)) + return; + if (const NamedDecl *ND = SemaResult.getDeclaration()) { + auto ID = getSymbolID(ND); + if (!ID) + return; + MainFileRefs = + std::max(MainFileRefs, MainFileSignals->ReferencedSymbols.lookup(ID)); + if (const auto *NSD = dyn_cast<NamespaceDecl>(ND->getDeclContext())) { + if (NSD->isAnonymousNamespace()) + return; + std::string Scope = printNamespaceScope(*NSD); + if (!Scope.empty()) + ScopeRefsInFile = std::max( + ScopeRefsInFile, MainFileSignals->RelatedNamespaces.lookup(Scope)); + } + } } void SymbolRelevanceSignals::merge(const CodeCompletionResult &SemaCCResult) { @@ -313,6 +347,7 @@ InBaseClass |= SemaCCResult.InBaseClass; } + computeASTSignals(SemaCCResult); // Declarations are scoped, others (like macros) are assumed global. if (SemaCCResult.Declaration) Scope = std::min(Scope, computeScope(SemaCCResult.Declaration)); @@ -320,35 +355,51 @@ NeedsFixIts = !SemaCCResult.FixIts.empty(); } -static std::pair<float, unsigned> uriProximity(llvm::StringRef SymbolURI, - URIDistance *D) { - if (!D || SymbolURI.empty()) - return {0.f, 0u}; - unsigned Distance = D->distance(SymbolURI); +static float fileProximityScore(unsigned FileDistance) { + // Range: [0, 1] + // FileDistance = [0, 1, 2, 3, 4, .., FileDistance::Unreachable] + // Score = [1, 0.82, 0.67, 0.55, 0.45, .., 0] + if (FileDistance == FileDistance::Unreachable) + return 0; // Assume approximately default options are used for sensible scoring. - return {std::exp(Distance * -0.4f / FileDistanceOptions().UpCost), Distance}; + return std::exp(FileDistance * -0.4f / FileDistanceOptions().UpCost); } -static float scopeBoost(ScopeDistance &Distance, - llvm::Optional<llvm::StringRef> SymbolScope) { - if (!SymbolScope) - return 1; - auto D = Distance.distance(*SymbolScope); - if (D == FileDistance::Unreachable) +static float scopeProximityScore(unsigned ScopeDistance) { + // Range: [0.6, 2]. + // ScopeDistance = [0, 1, 2, 3, 4, 5, 6, 7, .., FileDistance::Unreachable] + // Score = [2.0, 1.55, 1.2, 0.93, 0.72, 0.65, 0.65, 0.65, .., 0.6] + if (ScopeDistance == FileDistance::Unreachable) return 0.6f; - return std::max(0.65, 2.0 * std::pow(0.6, D / 2.0)); + return std::max(0.65, 2.0 * std::pow(0.6, ScopeDistance / 2.0)); } static llvm::Optional<llvm::StringRef> wordMatching(llvm::StringRef Name, const llvm::StringSet<> *ContextWords) { if (ContextWords) - for (const auto& Word : ContextWords->keys()) + for (const auto &Word : ContextWords->keys()) if (Name.contains_lower(Word)) return Word; return llvm::None; } -float SymbolRelevanceSignals::evaluate() const { +SymbolRelevanceSignals::DerivedSignals +SymbolRelevanceSignals::calculateDerivedSignals() const { + DerivedSignals Derived; + Derived.NameMatchesContext = wordMatching(Name, ContextWords).hasValue(); + Derived.FileProximityDistance = !FileProximityMatch || SymbolURI.empty() + ? FileDistance::Unreachable + : FileProximityMatch->distance(SymbolURI); + if (ScopeProximityMatch) { + // For global symbol, the distance is 0. + Derived.ScopeProximityDistance = + SymbolScope ? ScopeProximityMatch->distance(*SymbolScope) : 0; + } + return Derived; +} + +float SymbolRelevanceSignals::evaluateHeuristics() const { + DerivedSignals Derived = calculateDerivedSignals(); float Score = 1; if (Forbidden) @@ -358,7 +409,7 @@ // File proximity scores are [0,1] and we translate them into a multiplier in // the range from 1 to 3. - Score *= 1 + 2 * std::max(uriProximity(SymbolURI, FileProximityMatch).first, + Score *= 1 + 2 * std::max(fileProximityScore(Derived.FileProximityDistance), SemaFileProximityScore); if (ScopeProximityMatch) @@ -366,10 +417,11 @@ // can be tricky (e.g. class/function scope). Set to the max boost as we // don't load top-level symbols from the preamble and sema results are // always in the accessible scope. - Score *= - SemaSaysInScope ? 2.0 : scopeBoost(*ScopeProximityMatch, SymbolScope); + Score *= SemaSaysInScope + ? 2.0 + : scopeProximityScore(Derived.ScopeProximityDistance); - if (wordMatching(Name, ContextWords)) + if (Derived.NameMatchesContext) Score *= 1.5; // Symbols like local variables may only be referenced within their scope. @@ -422,12 +474,27 @@ if (NeedsFixIts) Score *= 0.5f; + // Use a sigmoid style boosting function similar to `References`, which flats + // out nicely for large values. This avoids a sharp gradient for heavily + // referenced symbols. Use smaller gradient for ScopeRefsInFile since ideally + // MainFileRefs <= ScopeRefsInFile. + if (MainFileRefs >= 2) { + // E.g.: (2, 1.12), (9, 2.0), (48, 3.0). + float S = std::pow(MainFileRefs, -0.11); + Score *= 11.0 * (1 - S) / (1 + S) + 0.7; + } + if (ScopeRefsInFile >= 2) { + // E.g.: (2, 1.04), (14, 2.0), (109, 3.0), (400, 3.6). + float S = std::pow(ScopeRefsInFile, -0.10); + Score *= 10.0 * (1 - S) / (1 + S) + 0.7; + } + return Score; } llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const SymbolRelevanceSignals &S) { - OS << llvm::formatv("=== Symbol relevance: {0}\n", S.evaluate()); + OS << llvm::formatv("=== Symbol relevance: {0}\n", S.evaluateHeuristics()); OS << llvm::formatv("\tName: {0}\n", S.Name); OS << llvm::formatv("\tName match: {0}\n", S.NameMatch); if (S.ContextWords) @@ -437,6 +504,7 @@ OS << llvm::formatv("\tForbidden: {0}\n", S.Forbidden); OS << llvm::formatv("\tNeedsFixIts: {0}\n", S.NeedsFixIts); OS << llvm::formatv("\tIsInstanceMember: {0}\n", S.IsInstanceMember); + OS << llvm::formatv("\tInBaseClass: {0}\n", S.InBaseClass); OS << llvm::formatv("\tContext: {0}\n", getCompletionKindString(S.Context)); OS << llvm::formatv("\tQuery type: {0}\n", static_cast<int>(S.Query)); OS << llvm::formatv("\tScope: {0}\n", static_cast<int>(S.Scope)); @@ -445,17 +513,18 @@ OS << llvm::formatv("\tSymbol scope: {0}\n", S.SymbolScope ? *S.SymbolScope : "<None>"); + SymbolRelevanceSignals::DerivedSignals Derived = S.calculateDerivedSignals(); if (S.FileProximityMatch) { - auto Score = uriProximity(S.SymbolURI, S.FileProximityMatch); - OS << llvm::formatv("\tIndex URI proximity: {0} (distance={1})\n", - Score.first, Score.second); + unsigned Score = fileProximityScore(Derived.FileProximityDistance); + OS << llvm::formatv("\tIndex URI proximity: {0} (distance={1})\n", Score, + Derived.FileProximityDistance); } OS << llvm::formatv("\tSema file proximity: {0}\n", S.SemaFileProximityScore); OS << llvm::formatv("\tSema says in scope: {0}\n", S.SemaSaysInScope); if (S.ScopeProximityMatch) OS << llvm::formatv("\tIndex scope boost: {0}\n", - scopeBoost(*S.ScopeProximityMatch, S.SymbolScope)); + scopeProximityScore(Derived.ScopeProximityDistance)); OS << llvm::formatv( "\tType matched preferred: {0} (Context type: {1}, Symbol type: {2}\n", @@ -468,6 +537,65 @@ return SymbolQuality * SymbolRelevance; } +DecisionForestScores +evaluateDecisionForest(const SymbolQualitySignals &Quality, + const SymbolRelevanceSignals &Relevance, float Base) { + Example E; + E.setIsDeprecated(Quality.Deprecated); + E.setIsReservedName(Quality.ReservedName); + E.setIsImplementationDetail(Quality.ImplementationDetail); + E.setNumReferences(Quality.References); + E.setSymbolCategory(Quality.Category); + + SymbolRelevanceSignals::DerivedSignals Derived = + Relevance.calculateDerivedSignals(); + int NumMatch = 0; + if (Relevance.ContextWords) { + for (const auto &Word : Relevance.ContextWords->keys()) { + if (Relevance.Name.contains_lower(Word)) { + ++NumMatch; + } + } + } + E.setIsNameInContext(NumMatch > 0); + E.setNumNameInContext(NumMatch); + E.setFractionNameInContext( + Relevance.ContextWords && !Relevance.ContextWords->empty() + ? NumMatch * 1.0 / Relevance.ContextWords->size() + : 0); + E.setIsInBaseClass(Relevance.InBaseClass); + E.setFileProximityDistanceCost(Derived.FileProximityDistance); + E.setSemaFileProximityScore(Relevance.SemaFileProximityScore); + E.setSymbolScopeDistanceCost(Derived.ScopeProximityDistance); + E.setSemaSaysInScope(Relevance.SemaSaysInScope); + E.setScope(Relevance.Scope); + E.setContextKind(Relevance.Context); + E.setIsInstanceMember(Relevance.IsInstanceMember); + E.setHadContextType(Relevance.HadContextType); + E.setHadSymbolType(Relevance.HadSymbolType); + E.setTypeMatchesPreferred(Relevance.TypeMatchesPreferred); + + DecisionForestScores Scores; + // Exponentiating DecisionForest prediction makes the score of each tree a + // multiplciative boost (like NameMatch). This allows us to weigh the + // prediciton score and NameMatch appropriately. + Scores.ExcludingName = pow(Base, Evaluate(E)); + // Following cases are not part of the generated training dataset: + // - Symbols with `NeedsFixIts`. + // - Forbidden symbols. + // - Keywords: Dataset contains only macros and decls. + if (Relevance.NeedsFixIts) + Scores.ExcludingName *= 0.5; + if (Relevance.Forbidden) + Scores.ExcludingName *= 0; + if (Quality.Category == SymbolQualitySignals::Keyword) + Scores.ExcludingName *= 4; + + // NameMatch should be a multiplier on total score to support rescoring. + Scores.Total = Relevance.NameMatch * Scores.ExcludingName; + return Scores; +} + // Produces an integer that sorts in the same order as F. // That is: a < b <==> encodeFloat(a) < encodeFloat(b). static uint32_t encodeFloat(float F) {