fix: score + rank entities, limit to top 5, full body for decisions
Before: the query "Robin Hanson MetaDAO proposal" returned 34 entities (39K chars) with the target record buried at position 13, because there was no relevance scoring. After: candidate entities are scored by query-token overlap (name match 3.0, alias match 1.0, overview match 0.5, plus a 5.0 bonus for each query bigram found in the name) and only the top 5 are returned. Decision records keep up to 2K chars of body instead of the 500-char truncation. In the prompt, the top result gets up to 2K chars and the rest get 500. Pentagon-Agent: Epimetheus <3D35839A-7722-4740-B93D-51157F7D5E70>
This commit is contained in:
parent
3ed0f20fa1
commit
089b4609d5
1 changed file with 66 additions and 17 deletions
|
|
@ -175,8 +175,12 @@ class KBIndex:
|
||||||
# Extract wiki-linked claim references from body
|
# Extract wiki-linked claim references from body
|
||||||
related_claims = re.findall(r"\[\[([^\]]+)\]\]", body)
|
related_claims = re.findall(r"\[\[([^\]]+)\]\]", body)
|
||||||
|
|
||||||
# Body excerpt — for decisions, lead with summary for better prompt context
|
# Body excerpt — decisions get full body (typically 1-2K), entities get 500 chars
|
||||||
if summary:
|
ft = fm.get("type")
|
||||||
|
if ft == "decision":
|
||||||
|
# Include full body for decision records — these are what users ask about
|
||||||
|
overview = body[:2000] if body else (summary or "")
|
||||||
|
elif summary:
|
||||||
overview = f"{summary} "
|
overview = f"{summary} "
|
||||||
body_lines = [l for l in body.split("\n") if l.strip() and not l.startswith("#")]
|
body_lines = [l for l in body.split("\n") if l.strip() and not l.startswith("#")]
|
||||||
remaining = 500 - len(overview)
|
remaining = 500 - len(overview)
|
||||||
|
|
@ -295,11 +299,12 @@ class KBIndex:
|
||||||
|
|
||||||
|
|
||||||
def retrieve_context(query: str, repo_dir: str, index: KBIndex | None = None,
|
def retrieve_context(query: str, repo_dir: str, index: KBIndex | None = None,
|
||||||
max_claims: int = 8, max_positions: int = 3) -> KBContext:
|
max_claims: int = 8, max_entities: int = 5,
|
||||||
|
max_positions: int = 3) -> KBContext:
|
||||||
"""Main entry point: retrieve full KB context for a query.
|
"""Main entry point: retrieve full KB context for a query.
|
||||||
|
|
||||||
Three layers:
|
Three layers:
|
||||||
1. Entity resolution — match query tokens to entities
|
1. Entity resolution — match query tokens to entities, scored by relevance
|
||||||
2. Claim search — substring + keyword matching on titles and descriptions
|
2. Claim search — substring + keyword matching on titles and descriptions
|
||||||
3. Agent context — positions and beliefs referencing matched entities/claims
|
3. Agent context — positions and beliefs referencing matched entities/claims
|
||||||
"""
|
"""
|
||||||
|
|
@ -314,31 +319,41 @@ def retrieve_context(query: str, repo_dir: str, index: KBIndex | None = None,
|
||||||
query_tokens = _tokenize(query_lower)
|
query_tokens = _tokenize(query_lower)
|
||||||
|
|
||||||
# ── Layer 1: Entity Resolution ──
|
# ── Layer 1: Entity Resolution ──
|
||||||
matched_entity_indices = set()
|
# Score each entity by how many query tokens match its aliases/name
|
||||||
|
scored_entities: list[tuple[float, int]] = [] # (score, index)
|
||||||
|
|
||||||
|
# Build a set of candidate indices from alias map + substring matching
|
||||||
|
candidate_indices = set()
|
||||||
for token in query_tokens:
|
for token in query_tokens:
|
||||||
# Direct alias match
|
|
||||||
if token in index._entity_alias_map:
|
if token in index._entity_alias_map:
|
||||||
matched_entity_indices.update(index._entity_alias_map[token])
|
candidate_indices.update(index._entity_alias_map[token])
|
||||||
# Strip $ prefix for ticker lookup
|
|
||||||
if token.startswith("$"):
|
if token.startswith("$"):
|
||||||
bare = token[1:]
|
bare = token[1:]
|
||||||
if bare in index._entity_alias_map:
|
if bare in index._entity_alias_map:
|
||||||
matched_entity_indices.update(index._entity_alias_map[bare])
|
candidate_indices.update(index._entity_alias_map[bare])
|
||||||
|
|
||||||
# Also try substring match on entity names (e.g. "omnipair" in "OmniPair Protocol")
|
|
||||||
for i, ent in enumerate(index._entities):
|
for i, ent in enumerate(index._entities):
|
||||||
for token in query_tokens:
|
for token in query_tokens:
|
||||||
if len(token) >= 3 and token in ent["name"].lower():
|
if len(token) >= 3 and token in ent["name"].lower():
|
||||||
matched_entity_indices.add(i)
|
candidate_indices.add(i)
|
||||||
|
|
||||||
for idx in matched_entity_indices:
|
# Score candidates by query token overlap
|
||||||
|
for idx in candidate_indices:
|
||||||
|
ent = index._entities[idx]
|
||||||
|
score = _score_entity(query_lower, query_tokens, ent)
|
||||||
|
if score > 0:
|
||||||
|
scored_entities.append((score, idx))
|
||||||
|
|
||||||
|
scored_entities.sort(key=lambda x: x[0], reverse=True)
|
||||||
|
|
||||||
|
for score, idx in scored_entities[:max_entities]:
|
||||||
ent = index._entities[idx]
|
ent = index._entities[idx]
|
||||||
ctx.entities.append(EntityMatch(
|
ctx.entities.append(EntityMatch(
|
||||||
name=ent["name"],
|
name=ent["name"],
|
||||||
path=ent["path"],
|
path=ent["path"],
|
||||||
entity_type=ent["type"],
|
entity_type=ent["type"],
|
||||||
domain=ent["domain"],
|
domain=ent["domain"],
|
||||||
overview=_sanitize_for_prompt(ent["overview"]),
|
overview=_sanitize_for_prompt(ent["overview"], max_len=2000),
|
||||||
tags=ent["tags"],
|
tags=ent["tags"],
|
||||||
related_claims=ent["related_claims"],
|
related_claims=ent["related_claims"],
|
||||||
))
|
))
|
||||||
|
|
@ -415,6 +430,36 @@ def retrieve_context(query: str, repo_dir: str, index: KBIndex | None = None,
|
||||||
# ─── Scoring ──────────────────────────────────────────────────────────
|
# ─── Scoring ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _score_entity(query_lower: str, query_tokens: list[str], entity: dict) -> float:
|
||||||
|
"""Score an entity against a query. Higher = more relevant."""
|
||||||
|
name_lower = entity["name"].lower()
|
||||||
|
overview_lower = entity.get("overview", "").lower()
|
||||||
|
aliases = entity.get("aliases", [])
|
||||||
|
score = 0.0
|
||||||
|
|
||||||
|
for token in query_tokens:
|
||||||
|
if len(token) < 2:
|
||||||
|
continue
|
||||||
|
# Name match (highest signal)
|
||||||
|
if token in name_lower:
|
||||||
|
score += 3.0
|
||||||
|
# Alias match (tags, proposer, parent_entity, tickers)
|
||||||
|
elif any(token == a or token in a for a in aliases):
|
||||||
|
score += 1.0
|
||||||
|
# Overview match (body content)
|
||||||
|
elif token in overview_lower:
|
||||||
|
score += 0.5
|
||||||
|
|
||||||
|
# Boost multi-word name matches (e.g. "robin hanson" in entity name)
|
||||||
|
if len(query_tokens) >= 2:
|
||||||
|
bigrams = [f"{query_tokens[i]} {query_tokens[i+1]}" for i in range(len(query_tokens) - 1)]
|
||||||
|
for bg in bigrams:
|
||||||
|
if bg in name_lower:
|
||||||
|
score += 5.0
|
||||||
|
|
||||||
|
return score
|
||||||
|
|
||||||
|
|
||||||
def _score_claim(query_lower: str, query_tokens: list[str], claim: dict,
|
def _score_claim(query_lower: str, query_tokens: list[str], claim: dict,
|
||||||
entity_claim_titles: set[str]) -> float:
|
entity_claim_titles: set[str]) -> float:
|
||||||
"""Score a claim against a query. Higher = more relevant."""
|
"""Score a claim against a query. Higher = more relevant."""
|
||||||
|
|
@ -490,14 +535,14 @@ def _tokenize(text: str) -> list[str]:
|
||||||
return [t for t in tokens if len(t) >= 2]
|
return [t for t in tokens if len(t) >= 2]
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_for_prompt(text: str) -> str:
|
def _sanitize_for_prompt(text: str, max_len: int = 1000) -> str:
|
||||||
"""Sanitize content before injecting into LLM prompt (Ganymede: security)."""
|
"""Sanitize content before injecting into LLM prompt (Ganymede: security)."""
|
||||||
# Strip code blocks
|
# Strip code blocks
|
||||||
text = re.sub(r"```.*?```", "[code block removed]", text, flags=re.DOTALL)
|
text = re.sub(r"```.*?```", "[code block removed]", text, flags=re.DOTALL)
|
||||||
# Strip anything that looks like system instructions
|
# Strip anything that looks like system instructions
|
||||||
text = re.sub(r"(system:|assistant:|human:|<\|.*?\|>)", "", text, flags=re.IGNORECASE)
|
text = re.sub(r"(system:|assistant:|human:|<\|.*?\|>)", "", text, flags=re.IGNORECASE)
|
||||||
# Truncate
|
# Truncate
|
||||||
return text[:1000]
|
return text[:max_len]
|
||||||
|
|
||||||
|
|
||||||
def _extract_relevant_paragraphs(text: str, terms: set[str], max_paragraphs: int = 2) -> list[str]:
|
def _extract_relevant_paragraphs(text: str, terms: set[str], max_paragraphs: int = 2) -> list[str]:
|
||||||
|
|
@ -522,9 +567,13 @@ def format_context_for_prompt(ctx: KBContext) -> str:
|
||||||
|
|
||||||
if ctx.entities:
|
if ctx.entities:
|
||||||
sections.append("## Matched Entities")
|
sections.append("## Matched Entities")
|
||||||
for ent in ctx.entities:
|
for i, ent in enumerate(ctx.entities):
|
||||||
sections.append(f"**{ent.name}** ({ent.entity_type}, {ent.domain})")
|
sections.append(f"**{ent.name}** ({ent.entity_type}, {ent.domain})")
|
||||||
sections.append(ent.overview)
|
# Top entity gets full content (up to 2000 chars), rest get truncated
|
||||||
|
if i == 0:
|
||||||
|
sections.append(ent.overview[:2000])
|
||||||
|
else:
|
||||||
|
sections.append(ent.overview[:500])
|
||||||
if ent.related_claims:
|
if ent.related_claims:
|
||||||
sections.append("Related claims: " + ", ".join(ent.related_claims[:5]))
|
sections.append("Related claims: " + ", ".join(ent.related_claims[:5]))
|
||||||
sections.append("")
|
sections.append("")
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue