diff --git a/telegram/kb_retrieval.py b/telegram/kb_retrieval.py
index 9b0284b..ca4921a 100644
--- a/telegram/kb_retrieval.py
+++ b/telegram/kb_retrieval.py
@@ -175,8 +175,12 @@ class KBIndex:
         # Extract wiki-linked claim references from body
         related_claims = re.findall(r"\[\[([^\]]+)\]\]", body)
 
-        # Body excerpt — for decisions, lead with summary for better prompt context
-        if summary:
+        # Body excerpt — decisions get full body (typically 1-2K), entities get 500 chars
+        ft = fm.get("type")
+        if ft == "decision":
+            # Include full body for decision records — these are what users ask about
+            overview = body[:2000] if body else (summary or "")
+        elif summary:
             overview = f"{summary} "
             body_lines = [l for l in body.split("\n") if l.strip() and not l.startswith("#")]
             remaining = 500 - len(overview)
@@ -295,11 +299,12 @@ class KBIndex:
 
 
 def retrieve_context(query: str, repo_dir: str, index: KBIndex | None = None,
-                     max_claims: int = 8, max_positions: int = 3) -> KBContext:
+                     max_claims: int = 8, max_entities: int = 5,
+                     max_positions: int = 3) -> KBContext:
     """Main entry point: retrieve full KB context for a query.
 
     Three layers:
-    1. Entity resolution — match query tokens to entities
+    1. Entity resolution — match query tokens to entities, scored by relevance
     2. Claim search — substring + keyword matching on titles and descriptions
     3. Agent context — positions and beliefs referencing matched entities/claims
     """
@@ -314,31 +319,41 @@ def retrieve_context(query: str, repo_dir: str, index: KBIndex | None = None,
     query_tokens = _tokenize(query_lower)
 
     # ── Layer 1: Entity Resolution ──
-    matched_entity_indices = set()
+    # Score each entity by how many query tokens match its aliases/name
+    scored_entities: list[tuple[float, int]] = []  # (score, index)
+
+    # Build a set of candidate indices from alias map + substring matching
+    candidate_indices = set()
     for token in query_tokens:
-        # Direct alias match
         if token in index._entity_alias_map:
-            matched_entity_indices.update(index._entity_alias_map[token])
-        # Strip $ prefix for ticker lookup
+            candidate_indices.update(index._entity_alias_map[token])
         if token.startswith("$"):
             bare = token[1:]
             if bare in index._entity_alias_map:
-                matched_entity_indices.update(index._entity_alias_map[bare])
+                candidate_indices.update(index._entity_alias_map[bare])
 
-    # Also try substring match on entity names (e.g. "omnipair" in "OmniPair Protocol")
     for i, ent in enumerate(index._entities):
         for token in query_tokens:
             if len(token) >= 3 and token in ent["name"].lower():
-                matched_entity_indices.add(i)
+                candidate_indices.add(i)
 
-    for idx in matched_entity_indices:
+    # Score candidates by query token overlap
+    for idx in candidate_indices:
+        ent = index._entities[idx]
+        score = _score_entity(query_lower, query_tokens, ent)
+        if score > 0:
+            scored_entities.append((score, idx))
+
+    scored_entities.sort(key=lambda x: x[0], reverse=True)
+
+    for score, idx in scored_entities[:max_entities]:
         ent = index._entities[idx]
         ctx.entities.append(EntityMatch(
             name=ent["name"],
             path=ent["path"],
            entity_type=ent["type"],
             domain=ent["domain"],
-            overview=_sanitize_for_prompt(ent["overview"]),
+            overview=_sanitize_for_prompt(ent["overview"], max_len=2000),
             tags=ent["tags"],
             related_claims=ent["related_claims"],
         ))
@@ -415,6 +430,36 @@ def retrieve_context(query: str, repo_dir: str, index: KBIndex | None = None,
 # ─── Scoring ──────────────────────────────────────────────────────────
 
 
+def _score_entity(query_lower: str, query_tokens: list[str], entity: dict) -> float:
+    """Score an entity against a query. Higher = more relevant."""
+    name_lower = entity["name"].lower()
+    overview_lower = entity.get("overview", "").lower()
+    aliases = entity.get("aliases", [])
+    score = 0.0
+
+    for token in query_tokens:
+        if len(token) < 2:
+            continue
+        # Name match (highest signal)
+        if token in name_lower:
+            score += 3.0
+        # Alias match (tags, proposer, parent_entity, tickers)
+        elif any(token == a or token in a for a in aliases):
+            score += 1.0
+        # Overview match (body content)
+        elif token in overview_lower:
+            score += 0.5
+
+    # Boost multi-word name matches (e.g. "robin hanson" in entity name)
+    if len(query_tokens) >= 2:
+        bigrams = [f"{query_tokens[i]} {query_tokens[i+1]}" for i in range(len(query_tokens) - 1)]
+        for bg in bigrams:
+            if bg in name_lower:
+                score += 5.0
+
+    return score
+
+
 def _score_claim(query_lower: str, query_tokens: list[str], claim: dict,
                  entity_claim_titles: set[str]) -> float:
     """Score a claim against a query. Higher = more relevant."""
@@ -490,14 +535,14 @@ def _tokenize(text: str) -> list[str]:
     return [t for t in tokens if len(t) >= 2]
 
 
-def _sanitize_for_prompt(text: str) -> str:
+def _sanitize_for_prompt(text: str, max_len: int = 1000) -> str:
     """Sanitize content before injecting into LLM prompt (Ganymede: security)."""
     # Strip code blocks
     text = re.sub(r"```.*?```", "[code block removed]", text, flags=re.DOTALL)
     # Strip anything that looks like system instructions
     text = re.sub(r"(system:|assistant:|human:|<\|.*?\|>)", "", text, flags=re.IGNORECASE)
     # Truncate
-    return text[:1000]
+    return text[:max_len]
 
 
 def _extract_relevant_paragraphs(text: str, terms: set[str], max_paragraphs: int = 2) -> list[str]:
@@ -522,9 +567,13 @@ def format_context_for_prompt(ctx: KBContext) -> str:
 
     if ctx.entities:
         sections.append("## Matched Entities")
-        for ent in ctx.entities:
+        for i, ent in enumerate(ctx.entities):
             sections.append(f"**{ent.name}** ({ent.entity_type}, {ent.domain})")
-            sections.append(ent.overview)
+            # Top entity gets full content (up to 2000 chars), rest get truncated
+            if i == 0:
+                sections.append(ent.overview[:2000])
+            else:
+                sections.append(ent.overview[:500])
             if ent.related_claims:
                 sections.append("Related claims: " + ", ".join(ent.related_claims[:5]))
             sections.append("")