From e17e6c25db80c9c82769110ecaf0b906337bbf02 Mon Sep 17 00:00:00 2001
From: m3taversal
Date: Sat, 28 Mar 2026 22:34:45 +0000
Subject: [PATCH] feat: two-pass retrieval with sort order and graph expansion
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

lib/search.py — shared search library:
- Pass 1 (default): top 5 from Qdrant, score >= 0.70, no expansion
- Pass 2 (expand=True): next 5 via offset=5, score >= 0.60, plus graph expansion from YAML frontmatter edges. Hard cap 10 total.
- Sort order: cosine desc → challenged_by → other graph-expanded
- result_type internal tag for stable sort (direct/challenge/graph)
- Module-level constants for easy threshold tuning post-calibration
- Structural file exclusion (_map.md, _overview.md)
- Within-vector dedup via _dedup_hits()

Caller updates:
- kb_retrieval.py: retrieve_vector_context() calls search(expand=True)
- diagnostics/app.py: search endpoint passes expand query param
- Argus imports from lib/search.py via sys.path (no longer owns search)

Tests: 5 new tests covering pass1-only, pass2 expansion, hard cap, sort order, challenges-before-other-expansion.

Co-Authored-By: Claude Opus 4.6 (1M context) --- diagnostics/app.py | 129 ++------- lib/search.py | 415 +++++++++++++++++++++++++++ telegram/kb_retrieval.py | 104 +++++++ tests/test_search.py | 604 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 1142 insertions(+), 110 deletions(-) create mode 100644 lib/search.py create mode 100644 tests/test_search.py diff --git a/diagnostics/app.py b/diagnostics/app.py index 04bb2f3..96beb2e 100644 --- a/diagnostics/app.py +++ b/diagnostics/app.py @@ -13,11 +13,16 @@ import logging import os import sqlite3 import statistics +import sys import urllib.request from datetime import datetime, timezone from pathlib import Path +# Add pipeline lib to path so we can import shared modules +sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "pipeline")) + from aiohttp import web +from lib.search import search as kb_search, embed_query, search_qdrant logger = logging.getLogger("argus") @@ -27,12 +32,7 @@ PORT = int(os.environ.get("ARGUS_PORT", "8081")) REPO_DIR = Path(os.environ.get("REPO_DIR", "/opt/teleo-eval/workspaces/main")) CLAIM_INDEX_URL = os.environ.get("CLAIM_INDEX_URL", "http://localhost:8080/claim-index") -# Search config -QDRANT_URL = os.environ.get("QDRANT_URL", "http://localhost:6333") -QDRANT_COLLECTION = os.environ.get("QDRANT_COLLECTION", "teleo-claims") -OPENROUTER_KEY_FILE = Path(os.environ.get("OPENROUTER_KEY_FILE", "/opt/teleo-eval/secrets/openrouter-key")) -EMBEDDING_MODEL = "text-embedding-3-small" -EMBEDDING_DIMS = 1536 +# Search config — moved to lib/search.py (shared with Telegram bot + agents) # Auth config API_KEY_FILE = Path(os.environ.get("ARGUS_API_KEY_FILE", "/opt/teleo-eval/secrets/argus-api-key")) @@ -483,82 +483,7 @@ async def auth_middleware(request, handler): # ─── Embedding + Search ────────────────────────────────────────────────────── - - -def _get_embedding_key() -> str | None: - """Load OpenRouter API key for embeddings.""" - return _load_secret(OPENROUTER_KEY_FILE) - - -def 
_embed_query(text: str, api_key: str) -> list[float] | None: - """Embed a query string via OpenRouter (OpenAI-compatible endpoint). - - Uses urllib to avoid adding httpx/openai as dependencies. - """ - payload = json.dumps({ - "model": f"openai/{EMBEDDING_MODEL}", - "input": text, - }).encode() - req = urllib.request.Request( - "https://openrouter.ai/api/v1/embeddings", - data=payload, - headers={ - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json", - }, - ) - try: - with urllib.request.urlopen(req, timeout=10) as resp: - data = json.loads(resp.read()) - return data["data"][0]["embedding"] - except Exception as e: - logger.error("Embedding failed: %s", e) - return None - - -def _search_qdrant(vector: list[float], limit: int = 10, - domain: str | None = None, confidence: str | None = None, - exclude: list[str] | None = None) -> list[dict]: - """Search Qdrant collection for nearest claims. - - Uses urllib for zero-dependency Qdrant access (REST API). - """ - must_filters = [] - if domain: - must_filters.append({"key": "domain", "match": {"value": domain}}) - if confidence: - must_filters.append({"key": "confidence", "match": {"value": confidence}}) - - must_not_filters = [] - if exclude: - for path in exclude: - must_not_filters.append({"key": "claim_path", "match": {"value": path}}) - - payload = { - "vector": vector, - "limit": limit, - "with_payload": True, - "score_threshold": 0.3, - } - if must_filters or must_not_filters: - payload["filter"] = {} - if must_filters: - payload["filter"]["must"] = must_filters - if must_not_filters: - payload["filter"]["must_not"] = must_not_filters - - req = urllib.request.Request( - f"{QDRANT_URL}/collections/{QDRANT_COLLECTION}/points/search", - data=json.dumps(payload).encode(), - headers={"Content-Type": "application/json"}, - ) - try: - with urllib.request.urlopen(req, timeout=10) as resp: - data = json.loads(resp.read()) - return data.get("result", []) - except Exception as e: - logger.error("Qdrant 
search failed: %s", e) - return [] +# Moved to lib/search.py — imported at top of file as kb_search, embed_query, search_qdrant # ─── Usage logging ─────────────────────────────────────────────────────────── @@ -701,7 +626,7 @@ async def handle_api_domains(request): async def handle_api_search(request): - """GET /api/search — semantic search over claims via Qdrant. + """GET /api/search — semantic search over claims via Qdrant + graph expansion. Query params: q: search query (required) @@ -709,6 +634,7 @@ async def handle_api_search(request): confidence: filter by confidence level (optional) limit: max results, default 10 (optional) exclude: comma-separated claim paths to exclude (optional) + expand: enable graph expansion, default true (optional) """ query = request.query.get("q", "").strip() if not query: @@ -719,36 +645,19 @@ async def handle_api_search(request): limit = min(int(request.query.get("limit", "10")), 50) exclude_raw = request.query.get("exclude", "") exclude = [p.strip() for p in exclude_raw.split(",") if p.strip()] if exclude_raw else None + expand = request.query.get("expand", "true").lower() != "false" - # Embed the query - api_key = _get_embedding_key() - if not api_key: - return web.json_response({"error": "embedding service unavailable"}, status=503) + # Use shared search library (Layer 1 + Layer 2) + result = kb_search(query, expand=expand, + domain=domain, confidence=confidence, exclude=exclude) - vector = _embed_query(query, api_key) - if vector is None: - return web.json_response({"error": "embedding failed"}, status=502) + if "error" in result: + error = result["error"] + if error == "embedding_failed": + return web.json_response({"error": "embedding failed"}, status=502) + return web.json_response({"error": error}, status=500) - # Search Qdrant - results = _search_qdrant(vector, limit=limit, domain=domain, - confidence=confidence, exclude=exclude) - - # Format response - claims = [] - for hit in results: - payload = hit.get("payload", {}) 
- claims.append({ - "claim_title": payload.get("claim_title", ""), - "claim_path": payload.get("claim_path", ""), - "similarity_score": round(hit.get("score", 0), 4), - "domain": payload.get("domain", ""), - "confidence": payload.get("confidence", ""), - "snippet": payload.get("snippet", "")[:200], - "depends_on": payload.get("depends_on", []), - "challenged_by": payload.get("challenged_by", []), - }) - - return web.json_response(claims) + return web.json_response(result) async def handle_api_usage(request): diff --git a/lib/search.py b/lib/search.py new file mode 100644 index 0000000..8796f6d --- /dev/null +++ b/lib/search.py @@ -0,0 +1,415 @@ +"""Shared Qdrant vector search library for the Teleo knowledge base. + +Provides embed + search + graph expansion as a reusable library. +Any consumer (Argus dashboard, Telegram bot, agent research) imports from here. + +Layer 1: Qdrant vector search (semantic similarity) +Layer 2: Graph expansion (1-hop via frontmatter edges) +Layer 3: Left to the caller (agent context, domain filtering) + +Owner: Epimetheus +""" + +import json +import logging +import os +import re +from pathlib import Path + +import urllib.request + +from . import config + +logger = logging.getLogger("pipeline.search") + +# --- Config (all from environment or config.py defaults) --- +QDRANT_URL = os.environ.get("QDRANT_URL", "http://localhost:6333") +QDRANT_COLLECTION = os.environ.get("QDRANT_COLLECTION", "teleo-claims") +EMBEDDING_MODEL = "text-embedding-3-small" + +_OPENROUTER_KEY: str | None = None + +WIKI_LINK_RE = re.compile(r"\[\[([^\]]+)\]\]") + +# Structural files that should never be included in graph expansion results. +# These are indexes/MOCs, not claims — expanding them pulls entire domains. 
+STRUCTURAL_FILES = {"_map.md", "_overview.md"} + + +def _get_api_key() -> str | None: + """Load OpenRouter API key (cached after first read).""" + global _OPENROUTER_KEY + if _OPENROUTER_KEY: + return _OPENROUTER_KEY + key_file = config.SECRETS_DIR / "openrouter-key" + if key_file.exists(): + _OPENROUTER_KEY = key_file.read_text().strip() + return _OPENROUTER_KEY + _OPENROUTER_KEY = os.environ.get("OPENROUTER_API_KEY") + return _OPENROUTER_KEY + + +# --- Layer 1: Vector search --- + + +def embed_query(text: str) -> list[float] | None: + """Embed a query string via OpenRouter (OpenAI-compatible endpoint). + + Returns 1536-dim vector or None on failure. + """ + api_key = _get_api_key() + if not api_key: + logger.error("No OpenRouter API key available for embedding") + return None + + payload = json.dumps({ + "model": f"openai/{EMBEDDING_MODEL}", + "input": text[:8000], + }).encode() + req = urllib.request.Request( + "https://openrouter.ai/api/v1/embeddings", + data=payload, + headers={ + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + }, + ) + try: + with urllib.request.urlopen(req, timeout=15) as resp: + data = json.loads(resp.read()) + return data["data"][0]["embedding"] + except Exception as e: + logger.error("Embedding failed: %s", e) + return None + + +def search_qdrant(vector: list[float], limit: int = 10, + domain: str | None = None, confidence: str | None = None, + exclude: list[str] | None = None, + score_threshold: float = 0.3, + offset: int = 0) -> list[dict]: + """Search Qdrant collection for nearest claims. + + Args: + offset: Skip first N results (Qdrant native offset for pagination). 
+ + Returns list of hits: [{id, score, payload: {claim_path, claim_title, ...}}] + """ + must_filters = [] + if domain: + must_filters.append({"key": "domain", "match": {"value": domain}}) + if confidence: + must_filters.append({"key": "confidence", "match": {"value": confidence}}) + + must_not_filters = [] + if exclude: + for path in exclude: + must_not_filters.append({"key": "claim_path", "match": {"value": path}}) + + body = { + "vector": vector, + "limit": limit, + "with_payload": True, + "score_threshold": score_threshold, + } + if offset > 0: + body["offset"] = offset + if must_filters or must_not_filters: + body["filter"] = {} + if must_filters: + body["filter"]["must"] = must_filters + if must_not_filters: + body["filter"]["must_not"] = must_not_filters + + req = urllib.request.Request( + f"{QDRANT_URL}/collections/{QDRANT_COLLECTION}/points/search", + data=json.dumps(body).encode(), + headers={"Content-Type": "application/json"}, + ) + try: + with urllib.request.urlopen(req, timeout=10) as resp: + data = json.loads(resp.read()) + return data.get("result", []) + except Exception as e: + logger.error("Qdrant search failed: %s", e) + return [] + + +# --- Layer 2: Graph expansion --- + + +def _parse_frontmatter_edges(path: Path) -> dict: + """Extract relationship edges from a claim's frontmatter. + + Handles both YAML formats: + depends_on: ["item1", "item2"] (inline list) + depends_on: (multi-line list) + - item1 + - item2 + + Returns {supports: [...], challenges: [...], depends_on: [...], related: [...], wiki_links: [...]}. + wiki_links are separated from explicit related edges for differential weighting. 
+ """ + edges = {"supports": [], "challenges": [], "depends_on": [], "related": [], "wiki_links": []} + try: + text = path.read_text(errors="replace") + except Exception: + return edges + + if not text.startswith("---"): + return edges + end = text.find("\n---", 3) + if end == -1: + return edges + + fm_text = text[3:end] + + # Use YAML parser for reliable edge extraction + try: + import yaml + fm = yaml.safe_load(fm_text) + if isinstance(fm, dict): + for field in ("supports", "challenges", "depends_on", "related"): + val = fm.get(field) + if isinstance(val, list): + edges[field] = [str(v).strip() for v in val if v] + elif isinstance(val, str) and val.strip(): + edges[field] = [val.strip()] + except Exception: + pass + + # Extract wiki links from body as separate edge type (lower weight) + body = text[end + 4:] + all_explicit = set() + for field in ("supports", "challenges", "depends_on", "related"): + all_explicit.update(edges[field]) + + wiki_links = WIKI_LINK_RE.findall(body) + for link in wiki_links: + link = link.strip() + if link and link not in all_explicit and link not in edges["wiki_links"]: + edges["wiki_links"].append(link) + + return edges + + +def _resolve_claim_path(name: str, repo_root: Path) -> Path | None: + """Resolve a claim name (from frontmatter edge or wiki link) to a file path. + + Handles both naming conventions: + - "GLP-1 receptor agonists are..." → "GLP-1 receptor agonists are....md" (spaces) + - "glp-1-persistence-drops..." → "glp-1-persistence-drops....md" (slugified) + + Checks domains/, core/, foundations/, decisions/ subdirectories. 
+ """ + # Try exact name first (spaces in filename), then slugified + candidates = [name] + slug = name.lower().replace(" ", "-").replace("_", "-") + if slug != name: + candidates.append(slug) + + for subdir in ["domains", "core", "foundations", "decisions"]: + base = repo_root / subdir + if not base.is_dir(): + continue + for candidate_name in candidates: + for md in base.rglob(f"{candidate_name}.md"): + return md + return None + + +def graph_expand(seed_paths: list[str], repo_root: Path | None = None, + max_expanded: int = 30, + challenge_weight: float = 1.5, + seen: set[str] | None = None) -> list[dict]: + """Layer 2: Expand seed claims 1-hop through knowledge graph edges. + + Traverses supports/challenges/depends_on/related/wiki_links edges in frontmatter. + Edge weights: challenges 1.5x, depends_on 1.25x, supports/related 1.0x, wiki_links 0.5x. + Results sorted by weight descending so cap cuts low-value edges first. + + Args: + seen: Optional set of paths already matched (e.g. from keyword search) to exclude. + + Returns list of {claim_path, claim_title, edge_type, edge_weight, from_claim}. + Excludes claims already in seed_paths or seen set. 
+ """ + EDGE_WEIGHTS = { + "challenges": 1.5, + "depends_on": 1.25, + "supports": 1.0, + "related": 1.0, + "wiki_links": 0.5, + } + + root = repo_root or config.MAIN_WORKTREE + all_expanded = [] + visited = set(seed_paths) + if seen: + visited.update(seen) + + for seed_path in seed_paths: + full_path = root / seed_path + if not full_path.exists(): + continue + + edges = _parse_frontmatter_edges(full_path) + + for edge_type, targets in edges.items(): + weight = EDGE_WEIGHTS.get(edge_type, 1.0) + + for target_name in targets: + target_path = _resolve_claim_path(target_name, root) + if target_path is None: + continue + + rel_path = str(target_path.relative_to(root)) + if rel_path in visited: + continue + # Skip structural files (MOCs/indexes) — they pull entire domains + if target_path.name in STRUCTURAL_FILES: + continue + visited.add(rel_path) + + # Read title from frontmatter + title = target_name + try: + text = target_path.read_text(errors="replace") + if text.startswith("---"): + end = text.find("\n---", 3) + if end > 0: + import yaml + fm = yaml.safe_load(text[3:end]) + if isinstance(fm, dict): + title = fm.get("name", fm.get("title", target_name)) + except Exception: + pass + + all_expanded.append({ + "claim_path": rel_path, + "claim_title": str(title), + "edge_type": edge_type, + "edge_weight": weight, + "from_claim": seed_path, + }) + + # Sort by weight descending so cap cuts lowest-value edges first + all_expanded.sort(key=lambda x: x["edge_weight"], reverse=True) + return all_expanded[:max_expanded] + + +# --- Combined search (Layer 1 + Layer 2) --- + +# Default thresholds — calibrated with Leo's retrieval audits +PASS1_LIMIT = 5 +PASS1_THRESHOLD = 0.70 +PASS2_LIMIT = 5 +PASS2_THRESHOLD = 0.60 +HARD_CAP = 10 + + +def _dedup_hits(hits: list[dict], seen: set[str]) -> list[dict]: + """Filter Qdrant hits: dedup by claim_path, exclude structural files.""" + results = [] + for hit in hits: + payload = hit.get("payload", {}) + claim_path = 
payload.get("claim_path", "") + if claim_path in seen: + continue + if claim_path.split("/")[-1] in STRUCTURAL_FILES: + continue + seen.add(claim_path) + results.append({ + "claim_title": payload.get("claim_title", ""), + "claim_path": claim_path, + "score": round(hit.get("score", 0), 4), + "domain": payload.get("domain", ""), + "confidence": payload.get("confidence", ""), + "snippet": payload.get("snippet", "")[:200], + "type": payload.get("type", "claim"), + }) + return results + + +def _sort_results(direct: list[dict], expanded: list[dict]) -> list[dict]: + """Sort combined results: similarity desc → challenged_by → other expansion. + + Sort order is load-bearing: LLMs have primacy bias, so best claims first. + """ + # Direct results already sorted by Qdrant (cosine desc) + sorted_direct = sorted(direct, key=lambda x: x.get("score", 0), reverse=True) + + # Expansion: challenged_by first (counterpoints), then rest by weight + challenged = [e for e in expanded if e.get("edge_type") == "challenges"] + other_expanded = [e for e in expanded if e.get("edge_type") != "challenges"] + challenged.sort(key=lambda x: x.get("edge_weight", 0), reverse=True) + other_expanded.sort(key=lambda x: x.get("edge_weight", 0), reverse=True) + + return sorted_direct + challenged + other_expanded + + +def search(query: str, expand: bool = False, + domain: str | None = None, confidence: str | None = None, + exclude: list[str] | None = None) -> dict: + """Two-pass semantic search: embed query, search Qdrant, optionally expand. + + Pass 1 (expand=False, default): Top 5 claims from Qdrant, score >= 0.70. + Sufficient for ~80% of queries. Fast and focused. + + Pass 2 (expand=True): Next 5 claims (offset=5, score >= 0.60) plus + graph-expanded claims (challenged_by, related edges). Hard cap 10 total. + Agent calls this only when pass 1 didn't answer the question. 
+ + Returns { + "query": str, + "direct_results": [...], # Layer 1 Qdrant hits (sorted by score desc) + "expanded_results": [...], # Layer 2 graph expansion (challenges first) + "total": int, + } + """ + vector = embed_query(query) + if vector is None: + return {"query": query, "direct_results": [], "expanded_results": [], + "total": 0, "error": "embedding_failed"} + + # --- Pass 1: Top 5, high threshold --- + hits = search_qdrant(vector, limit=PASS1_LIMIT, domain=domain, + confidence=confidence, exclude=exclude, + score_threshold=PASS1_THRESHOLD) + + seen_paths: set[str] = set() + if exclude: + seen_paths.update(exclude) + direct = _dedup_hits(hits, seen_paths) + + expanded = [] + if expand: + # --- Pass 2: Next 5 from Qdrant (lower threshold, offset) --- + pass2_hits = search_qdrant(vector, limit=PASS2_LIMIT, domain=domain, + confidence=confidence, exclude=exclude, + score_threshold=PASS2_THRESHOLD, + offset=PASS1_LIMIT) + pass2_direct = _dedup_hits(pass2_hits, seen_paths) + direct.extend(pass2_direct) + + # Graph expansion on all direct results (pass 1 + pass 2 seeds) + seed_paths = [r["claim_path"] for r in direct] + remaining_cap = HARD_CAP - len(direct) + if remaining_cap > 0: + expanded = graph_expand(seed_paths, max_expanded=remaining_cap, + seen=seen_paths) + + # Enforce hard cap across all results + all_sorted = _sort_results(direct, expanded)[:HARD_CAP] + + # Split back into direct vs expanded for backward compat + direct_paths = {r["claim_path"] for r in direct} + final_direct = [r for r in all_sorted if r.get("claim_path") in direct_paths] + final_expanded = [r for r in all_sorted if r.get("claim_path") not in direct_paths] + + return { + "query": query, + "direct_results": final_direct, + "expanded_results": final_expanded, + "total": len(all_sorted), + } diff --git a/telegram/kb_retrieval.py b/telegram/kb_retrieval.py index 218b84c..ac1b73f 100644 --- a/telegram/kb_retrieval.py +++ b/telegram/kb_retrieval.py @@ -621,3 +621,107 @@ def 
format_context_for_prompt(ctx: KBContext) -> str: f"{ctx.stats.get('claims_matched', 0)} claims.") return "\n".join(sections) + + +# --- Qdrant vector search integration --- + +# Module-level import guard for lib.search (Fix 3: no per-call sys.path manipulation) +_vector_search = None +try: + import sys as _sys + import os as _os + _pipeline_root = _os.path.dirname(_os.path.dirname(_os.path.abspath(__file__))) + if _pipeline_root not in _sys.path: + _sys.path.insert(0, _pipeline_root) + from lib.search import search as _vector_search +except ImportError: + logger.warning("Qdrant search unavailable at module load (lib.search not found)") + + +def retrieve_vector_context(query: str, + keyword_paths: list[str] | None = None) -> tuple[str, dict]: + """Semantic search via Qdrant — returns (formatted_text, metadata). + + Complements retrieve_context() (symbolic/keyword) with semantic similarity. + Falls back gracefully if Qdrant is unavailable. + + Args: + keyword_paths: Claim paths already matched by keyword search. These are + excluded at the Qdrant query level AND from graph expansion to avoid + duplicates in the prompt. 
+ + Returns: + (formatted_text, metadata_dict) + metadata_dict: {direct_results: [...], expanded_results: [...], + layers_hit: [...], duration_ms: int} + """ + import time as _time + t0 = _time.monotonic() + empty_meta = {"direct_results": [], "expanded_results": [], + "layers_hit": [], "duration_ms": 0} + + if _vector_search is None: + return "", empty_meta + + try: + results = _vector_search(query, expand=True, + exclude=keyword_paths) + except Exception as e: + logger.warning("Qdrant search failed: %s", e) + return "", empty_meta + + duration = int((_time.monotonic() - t0) * 1000) + + if results.get("error") or not results.get("direct_results"): + return "", {**empty_meta, "duration_ms": duration, + "error": results.get("error")} + + layers_hit = ["qdrant"] + if results.get("expanded_results"): + layers_hit.append("graph") + + # Build structured metadata for audit + meta = { + "direct_results": [ + {"path": r["claim_path"], "title": r["claim_title"], + "score": r["score"], "domain": r.get("domain", ""), + "source": "qdrant"} + for r in results["direct_results"] + ], + "expanded_results": [ + {"path": r["claim_path"], "title": r["claim_title"], + "edge_type": r.get("edge_type", "related"), + "from_claim": r.get("from_claim", ""), "source": "graph"} + for r in results.get("expanded_results", []) + ], + "layers_hit": layers_hit, + "duration_ms": duration, + } + + # Build formatted text for prompt (Fix 4: subsection headers) + sections = [] + sections.append("## Semantic Search Results (Qdrant)") + sections.append("") + sections.append("### Direct matches") + + for r in results["direct_results"]: + score_pct = int(r["score"] * 100) + line = f"- **{r['claim_title']}** ({score_pct}% match" + if r.get("domain"): + line += f", {r['domain']}" + if r.get("confidence"): + line += f", {r['confidence']}" + line += ")" + sections.append(line) + if r.get("snippet"): + sections.append(f" {r['snippet']}") + + if results.get("expanded_results"): + sections.append("") + 
sections.append("### Related claims (graph expansion)") + for r in results["expanded_results"]: + edge = r.get("edge_type", "related") + weight_str = f" ×{r.get('edge_weight', 1.0)}" if r.get("edge_weight", 1.0) != 1.0 else "" + sections.append(f"- {r['claim_title']} ({edge}{weight_str} → {r.get('from_claim', '').split('/')[-1]})") + + return "\n".join(sections), meta diff --git a/tests/test_search.py b/tests/test_search.py new file mode 100644 index 0000000..772b348 --- /dev/null +++ b/tests/test_search.py @@ -0,0 +1,604 @@ +"""Tests for lib/search.py — vector search and graph expansion.""" + +import json +from pathlib import Path +from unittest.mock import patch, MagicMock + +import pytest + +from lib.search import ( + _parse_frontmatter_edges, + _resolve_claim_path, + graph_expand, + search, + search_qdrant, + WIKI_LINK_RE, +) + + +# ─── Fixtures ────────────────────────────────────────────────────────────── + + +@pytest.fixture +def repo(tmp_path): + """Minimal KB repo structure with claim files.""" + domains = tmp_path / "domains" + + # ai-alignment domain + ai = domains / "ai-alignment" + ai.mkdir(parents=True) + + (ai / "capability-scoping.md").write_text( + "---\nname: AI agent capability scoping\ntype: claim\n" + "supports:\n - capability-matched escalation\n" + "related:\n - multi-agent coordination\n" + "---\n\nBody text with [[wiki-linked claim]].\n" + ) + (ai / "capability-matched-escalation.md").write_text( + "---\nname: capability-matched escalation\ntype: claim\n" + "challenges:\n - unconstrained autonomy\n" + "---\n\nEscalation body.\n" + ) + (ai / "multi-agent-coordination.md").write_text( + "---\nname: multi-agent coordination\ntype: claim\n---\n\nCoord body.\n" + ) + (ai / "wiki-linked-claim.md").write_text( + "---\nname: wiki-linked claim\ntype: claim\n---\n\nWiki body.\n" + ) + (ai / "unconstrained-autonomy.md").write_text( + "---\nname: unconstrained autonomy\ntype: claim\n---\n\nAutonomy body.\n" + ) + + # internet-finance domain + fin = 
domains / "internet-finance" + fin.mkdir(parents=True) + + (fin / "futarchy-governance.md").write_text( + "---\nname: futarchy governance\ntype: claim\n" + "depends_on:\n - prediction market accuracy\n" + "---\n\nFutarchy body.\n" + ) + (fin / "prediction-market-accuracy.md").write_text( + "---\nname: prediction market accuracy\ntype: claim\n---\n\nPM body.\n" + ) + + return tmp_path + + +@pytest.fixture +def claim_no_frontmatter(tmp_path): + """Claim file with no frontmatter.""" + domains = tmp_path / "domains" / "misc" + domains.mkdir(parents=True) + f = domains / "bare.md" + f.write_text("Just a body, no frontmatter.\n") + return f + + +# ─── _parse_frontmatter_edges ───────────────────────────────────────────── + + +class TestParseFrontmatterEdges: + + def test_supports_and_related(self, repo): + path = repo / "domains" / "ai-alignment" / "capability-scoping.md" + edges = _parse_frontmatter_edges(path) + assert edges["supports"] == ["capability-matched escalation"] + assert edges["related"] == ["multi-agent coordination"] + + def test_challenges(self, repo): + path = repo / "domains" / "ai-alignment" / "capability-matched-escalation.md" + edges = _parse_frontmatter_edges(path) + assert edges["challenges"] == ["unconstrained autonomy"] + + def test_depends_on(self, repo): + path = repo / "domains" / "internet-finance" / "futarchy-governance.md" + edges = _parse_frontmatter_edges(path) + assert edges["depends_on"] == ["prediction market accuracy"] + + def test_wiki_links_extracted_separately(self, repo): + path = repo / "domains" / "ai-alignment" / "capability-scoping.md" + edges = _parse_frontmatter_edges(path) + assert "wiki-linked claim" in edges["wiki_links"] + # Wiki links should NOT appear in explicit related + assert "wiki-linked claim" not in edges["related"] + + def test_wiki_links_deduped_from_explicit(self, tmp_path): + """If a wiki link matches an explicit edge, it's excluded from wiki_links.""" + d = tmp_path / "domains" / "test" + 
d.mkdir(parents=True) + f = d / "overlap.md" + f.write_text( + "---\nname: overlap test\nrelated:\n - shared target\n---\n\n" + "Body with [[shared target]] link.\n" + ) + edges = _parse_frontmatter_edges(f) + assert edges["related"] == ["shared target"] + assert edges["wiki_links"] == [] + + def test_no_frontmatter(self, claim_no_frontmatter): + edges = _parse_frontmatter_edges(claim_no_frontmatter) + assert all(v == [] for v in edges.values()) + + def test_missing_file(self, tmp_path): + edges = _parse_frontmatter_edges(tmp_path / "nonexistent.md") + assert all(v == [] for v in edges.values()) + + def test_inline_list_format(self, tmp_path): + """Handles YAML inline list: depends_on: ["a", "b"].""" + d = tmp_path / "domains" / "test" + d.mkdir(parents=True) + f = d / "inline.md" + f.write_text('---\nname: inline\ndepends_on: ["alpha", "beta"]\n---\n\nBody.\n') + edges = _parse_frontmatter_edges(f) + assert edges["depends_on"] == ["alpha", "beta"] + + +# ─── _resolve_claim_path ────────────────────────────────────────────────── + + +class TestResolveClaimPath: + + def test_slugified_name(self, repo): + result = _resolve_claim_path("capability-scoping", repo) + assert result is not None + assert result.name == "capability-scoping.md" + + def test_name_with_spaces(self, repo): + """Resolves 'multi-agent coordination' via slug matching.""" + result = _resolve_claim_path("multi-agent coordination", repo) + assert result is not None + assert result.name == "multi-agent-coordination.md" + + def test_not_found(self, repo): + result = _resolve_claim_path("nonexistent claim", repo) + assert result is None + + def test_cross_domain_resolution(self, repo): + """Can resolve a claim in a different domain subdirectory.""" + result = _resolve_claim_path("futarchy-governance", repo) + assert result is not None + assert "internet-finance" in str(result) + + +# ─── graph_expand ───────────────────────────────────────────────────────── + + +class TestGraphExpand: + + def 
test_basic_expansion(self, repo): + """Expanding from capability-scoping should find its edges.""" + results = graph_expand( + ["domains/ai-alignment/capability-scoping.md"], + repo_root=repo, + ) + titles = [r["claim_title"] for r in results] + assert "capability-matched escalation" in titles + assert "multi-agent coordination" in titles + + def test_wiki_links_included(self, repo): + results = graph_expand( + ["domains/ai-alignment/capability-scoping.md"], + repo_root=repo, + ) + titles = [r["claim_title"] for r in results] + assert "wiki-linked claim" in titles + + def test_edge_weights(self, repo): + """challenges edges get 1.5x weight, wiki_links get 0.5x.""" + # Expand from capability-matched-escalation (has challenges edge) + results = graph_expand( + ["domains/ai-alignment/capability-matched-escalation.md"], + repo_root=repo, + ) + challenges_result = [r for r in results if r["edge_type"] == "challenges"] + assert len(challenges_result) == 1 + assert challenges_result[0]["edge_weight"] == 1.5 + + def test_wiki_link_weight(self, repo): + results = graph_expand( + ["domains/ai-alignment/capability-scoping.md"], + repo_root=repo, + ) + wiki_results = [r for r in results if r["edge_type"] == "wiki_links"] + assert all(r["edge_weight"] == 0.5 for r in wiki_results) + + def test_depends_on_weight(self, repo): + results = graph_expand( + ["domains/internet-finance/futarchy-governance.md"], + repo_root=repo, + ) + dep_results = [r for r in results if r["edge_type"] == "depends_on"] + assert len(dep_results) == 1 + assert dep_results[0]["edge_weight"] == 1.25 + + def test_sorted_by_weight_descending(self, repo): + results = graph_expand( + ["domains/ai-alignment/capability-scoping.md"], + repo_root=repo, + ) + weights = [r["edge_weight"] for r in results] + assert weights == sorted(weights, reverse=True) + + def test_seed_excluded_from_results(self, repo): + seed = "domains/ai-alignment/capability-scoping.md" + results = graph_expand([seed], repo_root=repo) + paths 
= [r["claim_path"] for r in results] + assert seed not in paths + + def test_seen_set_excludes(self, repo): + """Paths in seen set are excluded from expansion results.""" + already_matched = {"domains/ai-alignment/multi-agent-coordination.md"} + results = graph_expand( + ["domains/ai-alignment/capability-scoping.md"], + repo_root=repo, + seen=already_matched, + ) + paths = [r["claim_path"] for r in results] + assert "domains/ai-alignment/multi-agent-coordination.md" not in paths + + def test_max_expanded_cap(self, repo): + results = graph_expand( + ["domains/ai-alignment/capability-scoping.md"], + repo_root=repo, + max_expanded=1, + ) + assert len(results) <= 1 + + def test_cap_cuts_lowest_weight(self, repo): + """With cap=2, wiki_links (0.5x) should be cut before supports (1.0x).""" + results = graph_expand( + ["domains/ai-alignment/capability-scoping.md"], + repo_root=repo, + max_expanded=2, + ) + edge_types = [r["edge_type"] for r in results] + assert "wiki_links" not in edge_types + + def test_nonexistent_seed(self, repo): + results = graph_expand(["domains/nonexistent.md"], repo_root=repo) + assert results == [] + + +# ─── search_qdrant (mocked HTTP) ───────────────────────────────────────── + + +class TestSearchQdrant: + + def _mock_qdrant_response(self, results): + """Build a mock urllib response returning Qdrant results.""" + resp = MagicMock() + resp.read.return_value = json.dumps({"result": results}).encode() + resp.__enter__ = lambda s: s + resp.__exit__ = MagicMock(return_value=False) + return resp + + @patch("lib.search.urllib.request.urlopen") + def test_basic_search(self, mock_urlopen): + mock_urlopen.return_value = self._mock_qdrant_response([ + {"id": 1, "score": 0.85, "payload": { + "claim_title": "test claim", "claim_path": "domains/test.md", + "domain": "ai", "confidence": "high", + }}, + ]) + results = search_qdrant([0.1] * 1536, limit=5) + assert len(results) == 1 + assert results[0]["score"] == 0.85 + + 
@patch("lib.search.urllib.request.urlopen") + def test_domain_filter(self, mock_urlopen): + mock_urlopen.return_value = self._mock_qdrant_response([]) + search_qdrant([0.1] * 1536, domain="ai-alignment") + # Verify the request body includes domain filter + call_args = mock_urlopen.call_args + req = call_args[0][0] + body = json.loads(req.data) + assert body["filter"]["must"][0]["key"] == "domain" + assert body["filter"]["must"][0]["match"]["value"] == "ai-alignment" + + @patch("lib.search.urllib.request.urlopen") + def test_exclude_filter(self, mock_urlopen): + mock_urlopen.return_value = self._mock_qdrant_response([]) + search_qdrant([0.1] * 1536, exclude=["domains/a.md", "domains/b.md"]) + call_args = mock_urlopen.call_args + req = call_args[0][0] + body = json.loads(req.data) + must_not = body["filter"]["must_not"] + excluded_paths = [f["match"]["value"] for f in must_not] + assert "domains/a.md" in excluded_paths + assert "domains/b.md" in excluded_paths + + @patch("lib.search.urllib.request.urlopen") + def test_no_filters_no_filter_key(self, mock_urlopen): + mock_urlopen.return_value = self._mock_qdrant_response([]) + search_qdrant([0.1] * 1536) + call_args = mock_urlopen.call_args + req = call_args[0][0] + body = json.loads(req.data) + assert "filter" not in body + + @patch("lib.search.urllib.request.urlopen") + def test_http_failure_returns_empty(self, mock_urlopen): + mock_urlopen.side_effect = Exception("connection refused") + results = search_qdrant([0.1] * 1536) + assert results == [] + + +# ─── search() integration (mocked network) ─────────────────────────────── + + +class TestSearch: + + @patch("lib.search.embed_query") + @patch("lib.search.search_qdrant") + def test_embedding_failure(self, mock_qdrant, mock_embed): + mock_embed.return_value = None + result = search("test query") + assert result["error"] == "embedding_failed" + assert result["direct_results"] == [] + mock_qdrant.assert_not_called() + + @patch("lib.search.graph_expand") + 
@patch("lib.search.search_qdrant") + @patch("lib.search.embed_query") + def test_full_pipeline(self, mock_embed, mock_qdrant, mock_expand): + mock_embed.return_value = [0.1] * 1536 + # Pass 1 returns one hit, pass 2 returns empty (nothing above lower threshold) + mock_qdrant.side_effect = [ + [{"id": 1, "score": 0.82, "payload": { + "claim_title": "direct hit", "claim_path": "domains/hit.md", + "domain": "ai", "confidence": "high", "snippet": "snippet text", + "type": "claim", + }}], + [], # pass 2 returns nothing + ] + mock_expand.return_value = [ + {"claim_path": "domains/expanded.md", "claim_title": "expanded", + "edge_type": "supports", "edge_weight": 1.0, "from_claim": "domains/hit.md"}, + ] + # expand=True triggers two-pass: pass 1 + pass 2 + graph expansion + result = search("test query", expand=True) + assert len(result["direct_results"]) == 1 + assert result["direct_results"][0]["claim_title"] == "direct hit" + assert len(result["expanded_results"]) == 1 + assert result["total"] == 2 + + @patch("lib.search.graph_expand") + @patch("lib.search.search_qdrant") + @patch("lib.search.embed_query") + def test_exclude_passed_through(self, mock_embed, mock_qdrant, mock_expand): + """Exclude list reaches both Qdrant and graph_expand.""" + mock_embed.return_value = [0.1] * 1536 + mock_qdrant.side_effect = [ + [{"id": 1, "score": 0.8, "payload": { + "claim_title": "hit", "claim_path": "domains/hit.md", + }}], + [], # pass 2 + ] + mock_expand.return_value = [] + exclude = ["domains/already-matched.md"] + search("query", expand=True, exclude=exclude) + + # Qdrant should get exclude (called twice with expand=True: pass 1 + pass 2) + assert mock_qdrant.call_count == 2 + # Both calls should include exclude + for call in mock_qdrant.call_args_list: + assert call.kwargs.get("exclude") == exclude \ + or call[1].get("exclude") == exclude + + # graph_expand should get seen set containing exclude paths + mock_expand.assert_called_once() + call_kwargs = mock_expand.call_args[1] 
if mock_expand.call_args[1] else {} + seen = call_kwargs.get("seen") + assert seen is not None + assert "domains/already-matched.md" in seen + + @patch("lib.search.search_qdrant") + @patch("lib.search.embed_query") + def test_no_expand_when_disabled(self, mock_embed, mock_qdrant): + mock_embed.return_value = [0.1] * 1536 + mock_qdrant.return_value = [] + result = search("query", expand=False) + assert result["expanded_results"] == [] + + +# ─── WIKI_LINK_RE ───────────────────────────────────────────────────────── + + +class TestWikiLinkRegex: + + def test_basic(self): + assert WIKI_LINK_RE.findall("See [[some claim]]") == ["some claim"] + + def test_multiple(self): + text = "Links to [[claim A]] and [[claim B]]" + assert WIKI_LINK_RE.findall(text) == ["claim A", "claim B"] + + def test_no_nested(self): + assert WIKI_LINK_RE.findall("[[outer [[inner]]]]") != ["outer [[inner]]"] + + +# ─── Structural file exclusion ─────────────────────────────────────────── + + +class TestStructuralFileExclusion: + + def test_map_excluded_from_expansion(self, tmp_path): + """_map.md files should be skipped during graph expansion.""" + domains = tmp_path / "domains" / "test" + domains.mkdir(parents=True) + (domains / "seed-claim.md").write_text( + "---\nname: seed claim\ntype: claim\n" + "related:\n - domain map\n---\nBody.\n" + ) + (domains / "_map.md").write_text( + "---\nname: domain map\ntype: moc\n---\nDomain index.\n" + ) + seed = "domains/test/seed-claim.md" + result = graph_expand([seed], repo_root=tmp_path) + paths = [r["claim_path"] for r in result] + assert "domains/test/_map.md" not in paths + + def test_overview_excluded_from_expansion(self, tmp_path): + """_overview.md files should be skipped during graph expansion.""" + domains = tmp_path / "domains" / "test" + domains.mkdir(parents=True) + (domains / "seed-claim.md").write_text( + "---\nname: seed claim\ntype: claim\n" + "related:\n - domain overview\n---\nBody.\n" + ) + (domains / "_overview.md").write_text( + 
"---\nname: domain overview\ntype: moc\n---\nOverview.\n" + ) + seed = "domains/test/seed-claim.md" + result = graph_expand([seed], repo_root=tmp_path) + paths = [r["claim_path"] for r in result] + assert "domains/test/_overview.md" not in paths + + def test_regular_claims_still_expand(self, repo): + """Non-structural files should still expand normally.""" + seed = "domains/ai-alignment/capability-scoping.md" + result = graph_expand([seed], repo_root=repo) + paths = [r["claim_path"] for r in result] + assert len(paths) > 0 + assert "domains/ai-alignment/capability-matched-escalation.md" in paths + + +# ─── Dedup within vector results ───────────────────────────────────────── + + +class TestSearchDedup: + + @patch("lib.search.search_qdrant") + @patch("lib.search.embed_query") + def test_duplicate_paths_deduped(self, mock_embed, mock_qdrant): + """Duplicate claim_paths in Qdrant results should be collapsed.""" + mock_embed.return_value = [0.1] * 1536 + mock_qdrant.return_value = [ + {"score": 0.9, "payload": {"claim_title": "Claim A", "claim_path": "domains/x/a.md", + "domain": "x", "confidence": "high", "snippet": "..."}}, + {"score": 0.85, "payload": {"claim_title": "Claim A dupe", "claim_path": "domains/x/a.md", + "domain": "x", "confidence": "high", "snippet": "..."}}, + {"score": 0.8, "payload": {"claim_title": "Claim B", "claim_path": "domains/x/b.md", + "domain": "x", "confidence": "high", "snippet": "..."}}, + ] + result = search("test query", expand=False) + paths = [r["claim_path"] for r in result["direct_results"]] + assert paths == ["domains/x/a.md", "domains/x/b.md"] + assert result["direct_results"][0]["claim_title"] == "Claim A" + + @patch("lib.search.search_qdrant") + @patch("lib.search.embed_query") + def test_structural_files_excluded_from_direct(self, mock_embed, mock_qdrant): + """_map.md should be excluded from direct Qdrant results too.""" + mock_embed.return_value = [0.1] * 1536 + mock_qdrant.return_value = [ + {"score": 0.9, "payload": 
{"claim_title": "Domain Map", "claim_path": "domains/x/_map.md", + "domain": "x", "confidence": "", "snippet": ""}}, + {"score": 0.85, "payload": {"claim_title": "Real Claim", "claim_path": "domains/x/real.md", + "domain": "x", "confidence": "high", "snippet": "..."}}, + ] + result = search("test query", expand=False) + paths = [r["claim_path"] for r in result["direct_results"]] + assert "domains/x/_map.md" not in paths + assert "domains/x/real.md" in paths + + +# ─── Two-pass retrieval ────────────────────────────────────────────────── + + +class TestTwoPassRetrieval: + + @patch("lib.search.graph_expand") + @patch("lib.search.search_qdrant") + @patch("lib.search.embed_query") + def test_pass1_only_default(self, mock_embed, mock_qdrant, mock_expand): + """Default search (expand=False) only calls Qdrant once with high threshold.""" + mock_embed.return_value = [0.1] * 1536 + mock_qdrant.return_value = [ + {"score": 0.85, "payload": {"claim_title": "Hit", "claim_path": "d/a.md"}}, + ] + result = search("query") + mock_qdrant.assert_called_once() + # Should use PASS1_THRESHOLD (0.70) + call_kwargs = mock_qdrant.call_args + assert call_kwargs.kwargs.get("score_threshold") == 0.70 \ + or call_kwargs[1].get("score_threshold") == 0.70 + mock_expand.assert_not_called() + assert len(result["direct_results"]) == 1 + + @patch("lib.search.graph_expand") + @patch("lib.search.search_qdrant") + @patch("lib.search.embed_query") + def test_pass2_expands(self, mock_embed, mock_qdrant, mock_expand): + """expand=True calls Qdrant twice (pass 1 + pass 2) and runs graph expansion.""" + mock_embed.return_value = [0.1] * 1536 + mock_qdrant.side_effect = [ + [{"score": 0.85, "payload": {"claim_title": "P1", "claim_path": "d/a.md"}}], + [{"score": 0.65, "payload": {"claim_title": "P2", "claim_path": "d/b.md"}}], + ] + mock_expand.return_value = [] + result = search("query", expand=True) + assert mock_qdrant.call_count == 2 + # Pass 2 should use offset=5 and lower threshold + pass2_call = 
mock_qdrant.call_args_list[1] + assert pass2_call.kwargs.get("offset") == 5 \ + or pass2_call[1].get("offset") == 5 + assert len(result["direct_results"]) == 2 + + @patch("lib.search.graph_expand") + @patch("lib.search.search_qdrant") + @patch("lib.search.embed_query") + def test_hard_cap_enforced(self, mock_embed, mock_qdrant, mock_expand): + """Total results never exceed HARD_CAP (10).""" + mock_embed.return_value = [0.1] * 1536 + # 5 from pass 1, 5 from pass 2 + p1 = [{"score": 0.9 - i * 0.02, "payload": { + "claim_title": f"P1-{i}", "claim_path": f"d/p1-{i}.md" + }} for i in range(5)] + p2 = [{"score": 0.65 - i * 0.01, "payload": { + "claim_title": f"P2-{i}", "claim_path": f"d/p2-{i}.md" + }} for i in range(5)] + mock_qdrant.side_effect = [p1, p2] + # Graph expansion returns 5 more + mock_expand.return_value = [ + {"claim_path": f"d/exp-{i}.md", "claim_title": f"Exp-{i}", + "edge_type": "related", "edge_weight": 1.0, "from_claim": "d/p1-0.md"} + for i in range(5) + ] + result = search("query", expand=True) + assert result["total"] <= 10 + + @patch("lib.search.search_qdrant") + @patch("lib.search.embed_query") + def test_sort_order_similarity_first(self, mock_embed, mock_qdrant): + """Direct results are sorted by cosine similarity descending.""" + mock_embed.return_value = [0.1] * 1536 + mock_qdrant.return_value = [ + {"score": 0.75, "payload": {"claim_title": "Low", "claim_path": "d/low.md"}}, + {"score": 0.95, "payload": {"claim_title": "High", "claim_path": "d/high.md"}}, + {"score": 0.85, "payload": {"claim_title": "Mid", "claim_path": "d/mid.md"}}, + ] + result = search("query") + titles = [r["claim_title"] for r in result["direct_results"]] + assert titles == ["High", "Mid", "Low"] + + @patch("lib.search.graph_expand") + @patch("lib.search.search_qdrant") + @patch("lib.search.embed_query") + def test_challenges_before_other_expansion(self, mock_embed, mock_qdrant, mock_expand): + """challenged_by claims appear before other expanded claims.""" + 
mock_embed.return_value = [0.1] * 1536 + mock_qdrant.side_effect = [ + [{"score": 0.85, "payload": {"claim_title": "Seed", "claim_path": "d/seed.md"}}], + [], + ] + mock_expand.return_value = [ + {"claim_path": "d/related.md", "claim_title": "Related", + "edge_type": "related", "edge_weight": 1.0, "from_claim": "d/seed.md"}, + {"claim_path": "d/challenge.md", "claim_title": "Challenge", + "edge_type": "challenges", "edge_weight": 1.5, "from_claim": "d/seed.md"}, + ] + result = search("query", expand=True) + expanded_titles = [r["claim_title"] for r in result["expanded_results"]] + assert expanded_titles.index("Challenge") < expanded_titles.index("Related")