diff clang-tools-extra/clangd/Quality.cpp @ 221:79ff65ed7e25

LLVM12 Original
author Shinji KONO <kono@ie.u-ryukyu.ac.jp>
date Tue, 15 Jun 2021 19:15:29 +0900
parents 0572611fdcc8
children 5f17cb93ff66
line wrap: on
line diff
--- a/clang-tools-extra/clangd/Quality.cpp	Tue Jun 15 19:13:43 2021 +0900
+++ b/clang-tools-extra/clangd/Quality.cpp	Tue Jun 15 19:15:29 2021 +0900
@@ -8,6 +8,7 @@
 
 #include "Quality.h"
 #include "AST.h"
+#include "CompletionModel.h"
 #include "FileDistance.h"
 #include "SourceCode.h"
 #include "URI.h"
@@ -200,7 +201,7 @@
   ReservedName = ReservedName || isReserved(IndexResult.Name);
 }
 
-float SymbolQualitySignals::evaluate() const {
+float SymbolQualitySignals::evaluateHeuristics() const {
   float Score = 1;
 
   // This avoids a sharp gradient for tail symbols, and also neatly avoids the
@@ -252,10 +253,11 @@
 
 llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
                               const SymbolQualitySignals &S) {
-  OS << llvm::formatv("=== Symbol quality: {0}\n", S.evaluate());
+  OS << llvm::formatv("=== Symbol quality: {0}\n", S.evaluateHeuristics());
   OS << llvm::formatv("\tReferences: {0}\n", S.References);
   OS << llvm::formatv("\tDeprecated: {0}\n", S.Deprecated);
   OS << llvm::formatv("\tReserved name: {0}\n", S.ReservedName);
+  OS << llvm::formatv("\tImplementation detail: {0}\n", S.ImplementationDetail);
   OS << llvm::formatv("\tCategory: {0}\n", static_cast<int>(S.Category));
   return OS;
 }
@@ -292,6 +294,38 @@
   if (!(IndexResult.Flags & Symbol::VisibleOutsideFile)) {
     Scope = AccessibleScope::FileScope;
   }
+  if (MainFileSignals) {
+    MainFileRefs =
+        std::max(MainFileRefs,
+                 MainFileSignals->ReferencedSymbols.lookup(IndexResult.ID));
+    ScopeRefsInFile =
+        std::max(ScopeRefsInFile,
+                 MainFileSignals->RelatedNamespaces.lookup(IndexResult.Scope));
+  }
+}
+
+void SymbolRelevanceSignals::computeASTSignals(
+    const CodeCompletionResult &SemaResult) {
+  if (!MainFileSignals)
+    return;
+  if ((SemaResult.Kind != CodeCompletionResult::RK_Declaration) &&
+      (SemaResult.Kind != CodeCompletionResult::RK_Pattern))
+    return;
+  if (const NamedDecl *ND = SemaResult.getDeclaration()) {
+    auto ID = getSymbolID(ND);
+    if (!ID)
+      return;
+    MainFileRefs =
+        std::max(MainFileRefs, MainFileSignals->ReferencedSymbols.lookup(ID));
+    if (const auto *NSD = dyn_cast<NamespaceDecl>(ND->getDeclContext())) {
+      if (NSD->isAnonymousNamespace())
+        return;
+      std::string Scope = printNamespaceScope(*NSD);
+      if (!Scope.empty())
+        ScopeRefsInFile = std::max(
+            ScopeRefsInFile, MainFileSignals->RelatedNamespaces.lookup(Scope));
+    }
+  }
 }
 
 void SymbolRelevanceSignals::merge(const CodeCompletionResult &SemaCCResult) {
@@ -313,6 +347,7 @@
     InBaseClass |= SemaCCResult.InBaseClass;
   }
 
+  computeASTSignals(SemaCCResult);
   // Declarations are scoped, others (like macros) are assumed global.
   if (SemaCCResult.Declaration)
     Scope = std::min(Scope, computeScope(SemaCCResult.Declaration));
@@ -320,35 +355,51 @@
   NeedsFixIts = !SemaCCResult.FixIts.empty();
 }
 
-static std::pair<float, unsigned> uriProximity(llvm::StringRef SymbolURI,
-                                               URIDistance *D) {
-  if (!D || SymbolURI.empty())
-    return {0.f, 0u};
-  unsigned Distance = D->distance(SymbolURI);
+static float fileProximityScore(unsigned FileDistance) {
+  // Range: [0, 1]
+  // FileDistance = [0, 1, 2, 3, 4, .., FileDistance::Unreachable]
+  // Score = [1, 0.82, 0.67, 0.55, 0.45, .., 0]
+  if (FileDistance == FileDistance::Unreachable)
+    return 0;
   // Assume approximately default options are used for sensible scoring.
-  return {std::exp(Distance * -0.4f / FileDistanceOptions().UpCost), Distance};
+  return std::exp(FileDistance * -0.4f / FileDistanceOptions().UpCost);
 }
 
-static float scopeBoost(ScopeDistance &Distance,
-                        llvm::Optional<llvm::StringRef> SymbolScope) {
-  if (!SymbolScope)
-    return 1;
-  auto D = Distance.distance(*SymbolScope);
-  if (D == FileDistance::Unreachable)
+static float scopeProximityScore(unsigned ScopeDistance) {
+  // Range: [0.6, 2].
+  // ScopeDistance = [0, 1, 2, 3, 4, 5, 6, 7, .., FileDistance::Unreachable]
+  // Score = [2.0, 1.55, 1.2, 0.93, 0.72, 0.65, 0.65, 0.65, .., 0.6]
+  if (ScopeDistance == FileDistance::Unreachable)
     return 0.6f;
-  return std::max(0.65, 2.0 * std::pow(0.6, D / 2.0));
+  return std::max(0.65, 2.0 * std::pow(0.6, ScopeDistance / 2.0));
 }
 
 static llvm::Optional<llvm::StringRef>
 wordMatching(llvm::StringRef Name, const llvm::StringSet<> *ContextWords) {
   if (ContextWords)
-    for (const auto& Word : ContextWords->keys())
+    for (const auto &Word : ContextWords->keys())
       if (Name.contains_lower(Word))
         return Word;
   return llvm::None;
 }
 
-float SymbolRelevanceSignals::evaluate() const {
+SymbolRelevanceSignals::DerivedSignals
+SymbolRelevanceSignals::calculateDerivedSignals() const {
+  DerivedSignals Derived;
+  Derived.NameMatchesContext = wordMatching(Name, ContextWords).hasValue();
+  Derived.FileProximityDistance = !FileProximityMatch || SymbolURI.empty()
+                                      ? FileDistance::Unreachable
+                                      : FileProximityMatch->distance(SymbolURI);
+  if (ScopeProximityMatch) {
+    // For global symbol, the distance is 0.
+    Derived.ScopeProximityDistance =
+        SymbolScope ? ScopeProximityMatch->distance(*SymbolScope) : 0;
+  }
+  return Derived;
+}
+
+float SymbolRelevanceSignals::evaluateHeuristics() const {
+  DerivedSignals Derived = calculateDerivedSignals();
   float Score = 1;
 
   if (Forbidden)
@@ -358,7 +409,7 @@
 
   // File proximity scores are [0,1] and we translate them into a multiplier in
   // the range from 1 to 3.
-  Score *= 1 + 2 * std::max(uriProximity(SymbolURI, FileProximityMatch).first,
+  Score *= 1 + 2 * std::max(fileProximityScore(Derived.FileProximityDistance),
                             SemaFileProximityScore);
 
   if (ScopeProximityMatch)
@@ -366,10 +417,11 @@
     // can be tricky (e.g. class/function scope). Set to the max boost as we
     // don't load top-level symbols from the preamble and sema results are
     // always in the accessible scope.
-    Score *=
-        SemaSaysInScope ? 2.0 : scopeBoost(*ScopeProximityMatch, SymbolScope);
+    Score *= SemaSaysInScope
+                 ? 2.0
+                 : scopeProximityScore(Derived.ScopeProximityDistance);
 
-  if (wordMatching(Name, ContextWords))
+  if (Derived.NameMatchesContext)
     Score *= 1.5;
 
   // Symbols like local variables may only be referenced within their scope.
@@ -422,12 +474,27 @@
   if (NeedsFixIts)
     Score *= 0.5f;
 
+  // Use a sigmoid style boosting function similar to `References`, which flats
+  // out nicely for large values. This avoids a sharp gradient for heavily
+  // referenced symbols. Use smaller gradient for ScopeRefsInFile since ideally
+  // MainFileRefs <= ScopeRefsInFile.
+  if (MainFileRefs >= 2) {
+    // E.g.: (2, 1.12), (9, 2.0), (48, 3.0).
+    float S = std::pow(MainFileRefs, -0.11);
+    Score *= 11.0 * (1 - S) / (1 + S) + 0.7;
+  }
+  if (ScopeRefsInFile >= 2) {
+    // E.g.: (2, 1.04), (14, 2.0), (109, 3.0), (400, 3.6).
+    float S = std::pow(ScopeRefsInFile, -0.10);
+    Score *= 10.0 * (1 - S) / (1 + S) + 0.7;
+  }
+
   return Score;
 }
 
 llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
                               const SymbolRelevanceSignals &S) {
-  OS << llvm::formatv("=== Symbol relevance: {0}\n", S.evaluate());
+  OS << llvm::formatv("=== Symbol relevance: {0}\n", S.evaluateHeuristics());
   OS << llvm::formatv("\tName: {0}\n", S.Name);
   OS << llvm::formatv("\tName match: {0}\n", S.NameMatch);
   if (S.ContextWords)
@@ -437,6 +504,7 @@
   OS << llvm::formatv("\tForbidden: {0}\n", S.Forbidden);
   OS << llvm::formatv("\tNeedsFixIts: {0}\n", S.NeedsFixIts);
   OS << llvm::formatv("\tIsInstanceMember: {0}\n", S.IsInstanceMember);
+  OS << llvm::formatv("\tInBaseClass: {0}\n", S.InBaseClass);
   OS << llvm::formatv("\tContext: {0}\n", getCompletionKindString(S.Context));
   OS << llvm::formatv("\tQuery type: {0}\n", static_cast<int>(S.Query));
   OS << llvm::formatv("\tScope: {0}\n", static_cast<int>(S.Scope));
@@ -445,17 +513,18 @@
   OS << llvm::formatv("\tSymbol scope: {0}\n",
                       S.SymbolScope ? *S.SymbolScope : "<None>");
 
+  SymbolRelevanceSignals::DerivedSignals Derived = S.calculateDerivedSignals();
   if (S.FileProximityMatch) {
-    auto Score = uriProximity(S.SymbolURI, S.FileProximityMatch);
-    OS << llvm::formatv("\tIndex URI proximity: {0} (distance={1})\n",
-                        Score.first, Score.second);
+    unsigned Score = fileProximityScore(Derived.FileProximityDistance);
+    OS << llvm::formatv("\tIndex URI proximity: {0} (distance={1})\n", Score,
+                        Derived.FileProximityDistance);
   }
   OS << llvm::formatv("\tSema file proximity: {0}\n", S.SemaFileProximityScore);
 
   OS << llvm::formatv("\tSema says in scope: {0}\n", S.SemaSaysInScope);
   if (S.ScopeProximityMatch)
     OS << llvm::formatv("\tIndex scope boost: {0}\n",
-                        scopeBoost(*S.ScopeProximityMatch, S.SymbolScope));
+                        scopeProximityScore(Derived.ScopeProximityDistance));
 
   OS << llvm::formatv(
       "\tType matched preferred: {0} (Context type: {1}, Symbol type: {2}\n",
@@ -468,6 +537,65 @@
   return SymbolQuality * SymbolRelevance;
 }
 
+DecisionForestScores
+evaluateDecisionForest(const SymbolQualitySignals &Quality,
+                       const SymbolRelevanceSignals &Relevance, float Base) {
+  Example E;
+  E.setIsDeprecated(Quality.Deprecated);
+  E.setIsReservedName(Quality.ReservedName);
+  E.setIsImplementationDetail(Quality.ImplementationDetail);
+  E.setNumReferences(Quality.References);
+  E.setSymbolCategory(Quality.Category);
+
+  SymbolRelevanceSignals::DerivedSignals Derived =
+      Relevance.calculateDerivedSignals();
+  int NumMatch = 0;
+  if (Relevance.ContextWords) {
+    for (const auto &Word : Relevance.ContextWords->keys()) {
+      if (Relevance.Name.contains_lower(Word)) {
+        ++NumMatch;
+      }
+    }
+  }
+  E.setIsNameInContext(NumMatch > 0);
+  E.setNumNameInContext(NumMatch);
+  E.setFractionNameInContext(
+      Relevance.ContextWords && !Relevance.ContextWords->empty()
+          ? NumMatch * 1.0 / Relevance.ContextWords->size()
+          : 0);
+  E.setIsInBaseClass(Relevance.InBaseClass);
+  E.setFileProximityDistanceCost(Derived.FileProximityDistance);
+  E.setSemaFileProximityScore(Relevance.SemaFileProximityScore);
+  E.setSymbolScopeDistanceCost(Derived.ScopeProximityDistance);
+  E.setSemaSaysInScope(Relevance.SemaSaysInScope);
+  E.setScope(Relevance.Scope);
+  E.setContextKind(Relevance.Context);
+  E.setIsInstanceMember(Relevance.IsInstanceMember);
+  E.setHadContextType(Relevance.HadContextType);
+  E.setHadSymbolType(Relevance.HadSymbolType);
+  E.setTypeMatchesPreferred(Relevance.TypeMatchesPreferred);
+
+  DecisionForestScores Scores;
+  // Exponentiating DecisionForest prediction makes the score of each tree a
+  // multiplciative boost (like NameMatch). This allows us to weigh the
+  // prediciton score and NameMatch appropriately.
+  Scores.ExcludingName = pow(Base, Evaluate(E));
+  // Following cases are not part of the generated training dataset:
+  //  - Symbols with `NeedsFixIts`.
+  //  - Forbidden symbols.
+  //  - Keywords: Dataset contains only macros and decls.
+  if (Relevance.NeedsFixIts)
+    Scores.ExcludingName *= 0.5;
+  if (Relevance.Forbidden)
+    Scores.ExcludingName *= 0;
+  if (Quality.Category == SymbolQualitySignals::Keyword)
+    Scores.ExcludingName *= 4;
+
+  // NameMatch should be a multiplier on total score to support rescoring.
+  Scores.Total = Relevance.NameMatch * Scores.ExcludingName;
+  return Scores;
+}
+
 // Produces an integer that sorts in the same order as F.
 // That is: a < b <==> encodeFloat(a) < encodeFloat(b).
 static uint32_t encodeFloat(float F) {