#!/usr/bin/env python3
"""Retrieval orchestration — keyword, vector, RRF merge, query decomposition.

All functions are stateless. LLM calls are injected via callback (llm_fn).
No Telegram types, no SQLite, no module-level state.

Extracted from bot.py (Ganymede decomposition spec).
"""
import logging
import re
import time
from typing import Any, Awaitable, Callable

from lib.config import (
    RETRIEVAL_RRF_K as RRF_K,
    RETRIEVAL_ENTITY_BOOST as ENTITY_BOOST,
    RETRIEVAL_MAX_RESULTS as MAX_RETRIEVAL_CLAIMS,
)

logger = logging.getLogger("tg.retrieval")

# Type alias for the LLM callback injected by bot.py:
# (model, prompt, max_tokens) → response text, or None on failure.
LLMFn = Callable[[str, str, int], Awaitable[str | None]]


def rrf_merge_context(kb_ctx: Any, vector_meta: dict, kb_read_dir: str) -> tuple[str, list[dict]]:
    """Merge keyword and vector retrieval into a single ranked claim list via RRF.

    Reciprocal Rank Fusion: RRF(d) = Σ 1/(k + rank_i(d)), with k = RRF_K
    (tuned for small result sets of 5-10 per source).  Claims wiki-linked from
    matched entities get their RRF score multiplied by ENTITY_BOOST.

    Args:
        kb_ctx: keyword-retrieval context exposing .entities, .claims,
            .positions, .belief_excerpts and .stats — or None / empty when
            keyword retrieval found nothing.
        vector_meta: merged vector-search metadata; "direct_results" is a list
            of dicts with keys path/title/domain/score, sorted by cosine desc.
        kb_read_dir: KB root directory, stripped from keyword claim paths so
            both sources dedup on the same relative path.

    Returns:
        (formatted_text, ranked_claims_for_audit) where the audit list carries
        path/title/score/rank/source per retained claim.
    """
    # Tolerate a missing keyword context: vector-only retrieval is still valid,
    # so fall back to empty collections instead of raising AttributeError.
    entities = kb_ctx.entities if kb_ctx else []
    keyword_claims = kb_ctx.claims if kb_ctx else []
    positions = kb_ctx.positions if kb_ctx else []
    belief_excerpts = kb_ctx.belief_excerpts if kb_ctx else []
    kb_stats = kb_ctx.stats if kb_ctx else {}

    # Claim titles wiki-linked from matched entities, lowercased so the boost
    # match below is case-insensitive.
    entity_linked_titles: set[str] = set()
    for ent in entities:
        for t in ent.related_claims:
            entity_linked_titles.add(t.lower())

    # --- Build per-claim RRF scores, keyed by claim path ---
    claim_map: dict[str, dict] = {}

    # Keyword claims (already sorted by keyword score desc).
    for rank, claim in enumerate(keyword_claims):
        p = claim.path
        if kb_read_dir and p.startswith(kb_read_dir):
            p = p[len(kb_read_dir):].lstrip("/")
        claim_map[p] = {
            "rrf_score": 1.0 / (RRF_K + rank),
            "title": claim.title,
            "domain": claim.domain,
            "confidence": claim.confidence,
            "description": claim.description,
            "source": "keyword",
            "vector_score": None,
        }

    # Vector results (already sorted by cosine desc).
    # NOTE(review): vector paths are used as-is — assumes the vector index
    # already stores paths relative to kb_read_dir; verify against the indexer.
    for rank, vr in enumerate(vector_meta.get("direct_results", [])):
        p = vr.get("path", "")
        rrf = 1.0 / (RRF_K + rank)
        if p in claim_map:
            # Seen by both retrievers: sum the RRF contributions.
            claim_map[p]["rrf_score"] += rrf
            claim_map[p]["source"] = "vector+keyword"
            claim_map[p]["vector_score"] = vr.get("score")
        else:
            claim_map[p] = {
                "rrf_score": rrf,
                "title": vr.get("title", ""),
                "domain": vr.get("domain", ""),
                "confidence": "",
                "description": "",
                "source": "vector",
                "vector_score": vr.get("score"),
            }

    # Entity-linked boost: claims referenced from a matched entity's wiki links
    # are likely on-topic even if ranked low by either retriever.
    if entity_linked_titles:
        for p, info in claim_map.items():
            if info["title"].lower() in entity_linked_titles:
                info["rrf_score"] *= ENTITY_BOOST
                info["source"] = info["source"] + "+entity"

    # Sort by fused score, best first.
    ranked = sorted(claim_map.items(), key=lambda x: x[1]["rrf_score"], reverse=True)

    # --- Format output ---
    sections = []

    # Entities section (keyword search is still best for entity resolution).
    if entities:
        sections.append("## Matched Entities")
        for i, ent in enumerate(entities):
            sections.append(f"**{ent.name}** ({ent.entity_type}, {ent.domain})")
            # Full overview for the top 3 entities; a short teaser after that.
            if i < 3:
                sections.append(ent.overview[:8000])
            else:
                sections.append(ent.overview[:500])
            if ent.related_claims:
                sections.append("Related claims: " + ", ".join(ent.related_claims[:5]))
        sections.append("")

    # Merged claims section (RRF-ranked).
    if ranked:
        sections.append("## Retrieved Claims")
        for path, info in ranked[:MAX_RETRIEVAL_CLAIMS]:
            line = f"- **{info['title']}**"
            meta_parts = []
            if info["confidence"]:
                meta_parts.append(f"confidence: {info['confidence']}")
            if info["domain"]:
                meta_parts.append(info["domain"])
            if info["vector_score"] is not None:
                meta_parts.append(f"{int(info['vector_score'] * 100)}% semantic match")
            if meta_parts:
                line += f" ({', '.join(meta_parts)})"
            sections.append(line)
            if info["description"]:
                sections.append(f"  {info['description']}")
        sections.append("")

    # Positions section.
    if positions:
        sections.append("## Agent Positions")
        for pos in positions:
            sections.append(f"**{pos.agent}**: {pos.title}")
            sections.append(pos.content[:200])
        sections.append("")

    # Beliefs section.
    if belief_excerpts:
        sections.append("## Relevant Beliefs")
        for exc in belief_excerpts:
            sections.append(exc)
        sections.append("")

    # Audit-friendly ranked list (mirrors what was shown, with scores).
    claims_audit = []
    for i, (path, info) in enumerate(ranked[:MAX_RETRIEVAL_CLAIMS]):
        claims_audit.append({
            "path": path,
            "title": info["title"],
            "score": round(info["rrf_score"], 4),
            "rank": i + 1,
            "source": info["source"],
        })

    if not sections:
        return "No relevant KB content found for this query.", claims_audit

    # Stats footer.  Use substring tests because the entity boost appends
    # "+entity" to the source label (e.g. "vector+keyword+entity"), which the
    # previous exact-match counts silently excluded.
    n_vector = sum(1 for _, v in ranked if "vector" in v["source"])
    n_keyword = sum(1 for _, v in ranked if "keyword" in v["source"])
    n_both = sum(1 for _, v in ranked if "vector+keyword" in v["source"])
    sections.append(f"---\nKB: {kb_stats.get('total_claims', '?')} claims, "
                    f"{kb_stats.get('total_entities', '?')} entities. "
                    f"Retrieved: {len(ranked)} claims (vector: {n_vector}, keyword: {n_keyword}, both: {n_both}).")

    return "\n".join(sections), claims_audit


async def reformulate_query(
    query: str,
    history: list[dict],
    llm_fn: LLMFn,
    model: str,
) -> str:
    """Rewrite conversational follow-ups into standalone search queries.

    Uses only the most recent exchange (truncated) as context.  If there's no
    conversation history, or the LLM call fails or returns something unusable
    (empty / ≤3 chars), the original query is returned unchanged.

    Args:
        query: the user's new message.
        history: prior exchanges; each dict may carry "user" and "bot" keys.
        llm_fn: injected LLM callback (model, prompt, max_tokens).
        model: model identifier passed through to llm_fn.

    Returns:
        A standalone search query (possibly the input verbatim).
    """
    if not history:
        return query

    try:
        last_exchange = history[-1]
        recent_context = ""
        if last_exchange.get("user"):
            recent_context += f"User: {last_exchange['user'][:300]}\n"
        if last_exchange.get("bot"):
            recent_context += f"Bot: {last_exchange['bot'][:300]}\n"

        reformulate_prompt = (
            f"A user is in a conversation. Given the recent exchange and their new message, "
            f"rewrite the new message as a STANDALONE search query that captures what they're "
            f"actually asking about. The query should work for semantic search — specific topics, "
            f"entities, and concepts.\n\n"
            f"Recent exchange:\n{recent_context}\n"
            f"New message: {query}\n\n"
            f"If the message is already a clear standalone question or topic, return it unchanged.\n"
            f"If it's a follow-up, correction, or reference to the conversation, rewrite it.\n\n"
            f"Return ONLY the rewritten query, nothing else. Max 30 words."
        )

        reformulated = await llm_fn(model, reformulate_prompt, 80)
        # Sanity check: ignore empty/near-empty responses and keep the original.
        if reformulated and reformulated.strip() and len(reformulated.strip()) > 3:
            logger.info("Query reformulated: '%s' → '%s'", query[:60], reformulated.strip()[:60])
            return reformulated.strip()
    except Exception as e:
        # Best-effort: reformulation is an enhancement, never a hard failure.
        logger.warning("Query reformulation failed: %s", e)

    return query


async def decompose_query(
    query: str,
    llm_fn: LLMFn,
    model: str,
) -> list[str]:
    """Split multi-part queries into focused sub-queries for vector search.

    Only attempts decomposition when the query is >8 words AND contains a
    conjunction or multiple question marks; otherwise (or on any failure)
    returns [query] unchanged.

    Args:
        query: the search query to (maybe) split.
        llm_fn: injected LLM callback (model, prompt, max_tokens).
        model: model identifier passed through to llm_fn.

    Returns:
        2-4 sub-queries, or the original query as a single-element list.
    """
    try:
        words = query.split()
        has_conjunction = any(w.lower() in ("and", "but", "also", "plus", "versus", "vs")
                              for w in words)
        has_question_marks = query.count("?") > 1

        if len(words) > 8 and (has_conjunction or has_question_marks):
            decompose_prompt = (
                f"Split this query into 2-3 focused search sub-queries. Each sub-query should "
                f"target one specific concept or question. Return one sub-query per line, nothing else.\n\n"
                f"Query: {query}\n\n"
                f"If the query is already focused on one topic, return it unchanged on a single line."
            )
            decomposed = await llm_fn(model, decompose_prompt, 150)
            if decomposed:
                # Strip list markers ("1.", "-", ")") the LLM may prepend.
                parts = [p.strip().lstrip("0123456789.-) ")
                         for p in decomposed.strip().split("\n") if p.strip()]
                # Accept only a genuine split; a 1-line or >4-line answer means
                # the model declined or rambled — fall through to [query].
                if 1 < len(parts) <= 4:
                    logger.info("Query decomposed: '%s' → %s", query[:60], parts)
                    return parts
    except Exception as e:
        # Best-effort: decomposition is an enhancement, never a hard failure.
        logger.warning("Query decomposition failed: %s", e)

    return [query]


def vector_search_merge(
    sub_queries: list[str],
    retrieve_vector_fn: Callable[[str], tuple[str, dict]],
) -> dict:
    """Run vector search on each sub-query, dedup by path (keep highest score).

    Args:
        sub_queries: one or more search queries.
        retrieve_vector_fn: callback returning (text, vector_meta) per query.

    Returns:
        Merged vector_meta dict with keys: direct_results, expanded_results,
        layers_hit, duration_ms — plus "errors" if any sub-query failed.
    """
    all_direct = []
    all_expanded = []
    layers = []
    total_duration = 0
    errors = []

    for sq in sub_queries:
        _, v_meta = retrieve_vector_fn(sq)
        all_direct.extend(v_meta.get("direct_results", []))
        all_expanded.extend(v_meta.get("expanded_results", []))
        layers.extend(v_meta.get("layers_hit", []))
        total_duration += v_meta.get("duration_ms", 0)
        if v_meta.get("error"):
            errors.append(v_meta["error"])

    # Dedup direct hits by path, keeping the best score across sub-queries.
    seen: dict[str, dict] = {}
    for vr in all_direct:
        p = vr.get("path", "")
        if p not in seen or vr.get("score", 0) > seen[p].get("score", 0):
            seen[p] = vr

    result = {
        "direct_results": list(seen.values()),
        "expanded_results": all_expanded,
        "layers_hit": list(set(layers)),
        "duration_ms": total_duration,
    }
    if errors:
        result["errors"] = errors
    return result


async def orchestrate_retrieval(
    text: str,
    search_query: str,
    kb_read_dir: str,
    kb_index: Any,
    llm_fn: LLMFn,
    triage_model: str,
    retrieve_context_fn: Callable,
    retrieve_vector_fn: Callable[[str], tuple[str, dict]],
    kb_scope: list[str] | None = None,
) -> dict:
    """Full retrieval pipeline: keyword → decompose → vector → RRF merge.

    Args:
        text: the user's original message (for audit logging).
        search_query: the (possibly reformulated) query actually searched.
        kb_read_dir: KB root directory.
        kb_index: keyword index handle, passed through to retrieve_context_fn.
        llm_fn: injected LLM callback used by query decomposition.
        triage_model: model identifier for decomposition.
        retrieve_context_fn: keyword retrieval callback.
        retrieve_vector_fn: vector retrieval callback.
        kb_scope: optional domain scope restriction.

    Returns:
        dict with keys: kb_context_text, claims_audit, retrieval_layers,
        vector_meta, tool_calls, kb_ctx.
    """
    tool_calls = []

    # 1. Keyword retrieval (entity resolution needs full context).
    t_kb = time.monotonic()
    kb_ctx = retrieve_context_fn(search_query, kb_read_dir, index=kb_index, kb_scope=kb_scope)
    kb_duration = int((time.monotonic() - t_kb) * 1000)
    retrieval_layers = ["keyword"] if (kb_ctx and (kb_ctx.entities or kb_ctx.claims)) else []
    tool_calls.append({
        "tool": "retrieve_context",
        "input": {"query": search_query[:200],
                  "original_query": text[:200] if search_query != text else None},
        "output": {"entities": len(kb_ctx.entities) if kb_ctx else 0,
                   "claims": len(kb_ctx.claims) if kb_ctx else 0},
        "duration_ms": kb_duration,
    })

    # 2. Query decomposition (logged only when it actually split the query).
    t_decompose = time.monotonic()
    sub_queries = await decompose_query(search_query, llm_fn, triage_model)
    decompose_duration = int((time.monotonic() - t_decompose) * 1000)
    if len(sub_queries) > 1:
        tool_calls.append({
            "tool": "query_decompose",
            "input": {"query": search_query[:200]},
            "output": {"sub_queries": sub_queries},
            "duration_ms": decompose_duration,
        })

    # 3. Vector search across sub-queries.
    vector_meta = vector_search_merge(sub_queries, retrieve_vector_fn)

    # 4. RRF merge of keyword + vector results.
    kb_context_text, claims_audit = rrf_merge_context(kb_ctx, vector_meta, kb_read_dir)
    retrieval_layers.extend(vector_meta.get("layers_hit", []))
    # NOTE(review): audit logs the original text here while the search actually
    # ran on sub_queries derived from search_query — confirm this is intended.
    tool_calls.append({
        "tool": "retrieve_qdrant_context",
        "input": {"query": text[:200]},
        "output": {"direct_hits": len(vector_meta.get("direct_results", [])),
                   "expanded": len(vector_meta.get("expanded_results", []))},
        "duration_ms": vector_meta.get("duration_ms", 0),
    })

    return {
        "kb_context_text": kb_context_text,
        "claims_audit": claims_audit,
        "retrieval_layers": retrieval_layers,
        "vector_meta": vector_meta,
        "tool_calls": tool_calls,
        "kb_ctx": kb_ctx,
    }