feat: two-pass retrieval with sort order and graph expansion #5

Merged
m3taversal merged 1 commit from epimetheus/two-pass-retrieval into main 2026-03-30 11:32:33 +00:00
4 changed files with 1142 additions and 110 deletions

View file

@ -13,11 +13,16 @@ import logging
import os import os
import sqlite3 import sqlite3
import statistics import statistics
import sys
import urllib.request import urllib.request
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
# Add pipeline lib to path so we can import shared modules
sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "pipeline"))
from aiohttp import web from aiohttp import web
from lib.search import search as kb_search, embed_query, search_qdrant
logger = logging.getLogger("argus") logger = logging.getLogger("argus")
@ -27,12 +32,7 @@ PORT = int(os.environ.get("ARGUS_PORT", "8081"))
REPO_DIR = Path(os.environ.get("REPO_DIR", "/opt/teleo-eval/workspaces/main")) REPO_DIR = Path(os.environ.get("REPO_DIR", "/opt/teleo-eval/workspaces/main"))
CLAIM_INDEX_URL = os.environ.get("CLAIM_INDEX_URL", "http://localhost:8080/claim-index") CLAIM_INDEX_URL = os.environ.get("CLAIM_INDEX_URL", "http://localhost:8080/claim-index")
# Search config # Search config — moved to lib/search.py (shared with Telegram bot + agents)
QDRANT_URL = os.environ.get("QDRANT_URL", "http://localhost:6333")
QDRANT_COLLECTION = os.environ.get("QDRANT_COLLECTION", "teleo-claims")
OPENROUTER_KEY_FILE = Path(os.environ.get("OPENROUTER_KEY_FILE", "/opt/teleo-eval/secrets/openrouter-key"))
EMBEDDING_MODEL = "text-embedding-3-small"
EMBEDDING_DIMS = 1536
# Auth config # Auth config
API_KEY_FILE = Path(os.environ.get("ARGUS_API_KEY_FILE", "/opt/teleo-eval/secrets/argus-api-key")) API_KEY_FILE = Path(os.environ.get("ARGUS_API_KEY_FILE", "/opt/teleo-eval/secrets/argus-api-key"))
@ -483,82 +483,7 @@ async def auth_middleware(request, handler):
# ─── Embedding + Search ────────────────────────────────────────────────────── # ─── Embedding + Search ──────────────────────────────────────────────────────
# Moved to lib/search.py — imported at top of file as kb_search, embed_query, search_qdrant
def _get_embedding_key() -> str | None:
"""Load OpenRouter API key for embeddings."""
return _load_secret(OPENROUTER_KEY_FILE)
def _embed_query(text: str, api_key: str) -> list[float] | None:
"""Embed a query string via OpenRouter (OpenAI-compatible endpoint).
Uses urllib to avoid adding httpx/openai as dependencies.
"""
payload = json.dumps({
"model": f"openai/{EMBEDDING_MODEL}",
"input": text,
}).encode()
req = urllib.request.Request(
"https://openrouter.ai/api/v1/embeddings",
data=payload,
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
)
try:
with urllib.request.urlopen(req, timeout=10) as resp:
data = json.loads(resp.read())
return data["data"][0]["embedding"]
except Exception as e:
logger.error("Embedding failed: %s", e)
return None
def _search_qdrant(vector: list[float], limit: int = 10,
domain: str | None = None, confidence: str | None = None,
exclude: list[str] | None = None) -> list[dict]:
"""Search Qdrant collection for nearest claims.
Uses urllib for zero-dependency Qdrant access (REST API).
"""
must_filters = []
if domain:
must_filters.append({"key": "domain", "match": {"value": domain}})
if confidence:
must_filters.append({"key": "confidence", "match": {"value": confidence}})
must_not_filters = []
if exclude:
for path in exclude:
must_not_filters.append({"key": "claim_path", "match": {"value": path}})
payload = {
"vector": vector,
"limit": limit,
"with_payload": True,
"score_threshold": 0.3,
}
if must_filters or must_not_filters:
payload["filter"] = {}
if must_filters:
payload["filter"]["must"] = must_filters
if must_not_filters:
payload["filter"]["must_not"] = must_not_filters
req = urllib.request.Request(
f"{QDRANT_URL}/collections/{QDRANT_COLLECTION}/points/search",
data=json.dumps(payload).encode(),
headers={"Content-Type": "application/json"},
)
try:
with urllib.request.urlopen(req, timeout=10) as resp:
data = json.loads(resp.read())
return data.get("result", [])
except Exception as e:
logger.error("Qdrant search failed: %s", e)
return []
# ─── Usage logging ─────────────────────────────────────────────────────────── # ─── Usage logging ───────────────────────────────────────────────────────────
@ -701,7 +626,7 @@ async def handle_api_domains(request):
async def handle_api_search(request): async def handle_api_search(request):
"""GET /api/search — semantic search over claims via Qdrant. """GET /api/search — semantic search over claims via Qdrant + graph expansion.
Query params: Query params:
q: search query (required) q: search query (required)
@ -709,6 +634,7 @@ async def handle_api_search(request):
confidence: filter by confidence level (optional) confidence: filter by confidence level (optional)
limit: max results, default 10 (optional) limit: max results, default 10 (optional)
exclude: comma-separated claim paths to exclude (optional) exclude: comma-separated claim paths to exclude (optional)
expand: enable graph expansion, default true (optional)
""" """
query = request.query.get("q", "").strip() query = request.query.get("q", "").strip()
if not query: if not query:
@ -719,36 +645,19 @@ async def handle_api_search(request):
limit = min(int(request.query.get("limit", "10")), 50) limit = min(int(request.query.get("limit", "10")), 50)
exclude_raw = request.query.get("exclude", "") exclude_raw = request.query.get("exclude", "")
exclude = [p.strip() for p in exclude_raw.split(",") if p.strip()] if exclude_raw else None exclude = [p.strip() for p in exclude_raw.split(",") if p.strip()] if exclude_raw else None
expand = request.query.get("expand", "true").lower() != "false"
# Embed the query # Use shared search library (Layer 1 + Layer 2)
api_key = _get_embedding_key() result = kb_search(query, expand=expand,
if not api_key: domain=domain, confidence=confidence, exclude=exclude)
return web.json_response({"error": "embedding service unavailable"}, status=503)
vector = _embed_query(query, api_key) if "error" in result:
if vector is None: error = result["error"]
return web.json_response({"error": "embedding failed"}, status=502) if error == "embedding_failed":
return web.json_response({"error": "embedding failed"}, status=502)
return web.json_response({"error": error}, status=500)
# Search Qdrant return web.json_response(result)
results = _search_qdrant(vector, limit=limit, domain=domain,
confidence=confidence, exclude=exclude)
# Format response
claims = []
for hit in results:
payload = hit.get("payload", {})
claims.append({
"claim_title": payload.get("claim_title", ""),
"claim_path": payload.get("claim_path", ""),
"similarity_score": round(hit.get("score", 0), 4),
"domain": payload.get("domain", ""),
"confidence": payload.get("confidence", ""),
"snippet": payload.get("snippet", "")[:200],
"depends_on": payload.get("depends_on", []),
"challenged_by": payload.get("challenged_by", []),
})
return web.json_response(claims)
async def handle_api_usage(request): async def handle_api_usage(request):

415
lib/search.py Normal file
View file

@ -0,0 +1,415 @@
"""Shared Qdrant vector search library for the Teleo knowledge base.
Provides embed + search + graph expansion as a reusable library.
Any consumer (Argus dashboard, Telegram bot, agent research) imports from here.
Layer 1: Qdrant vector search (semantic similarity)
Layer 2: Graph expansion (1-hop via frontmatter edges)
Layer 3: Left to the caller (agent context, domain filtering)
Owner: Epimetheus
"""
import json
import logging
import os
import re
from pathlib import Path
import urllib.request
from . import config
logger = logging.getLogger("pipeline.search")
# --- Config (all from environment or config.py defaults) ---
QDRANT_URL = os.environ.get("QDRANT_URL", "http://localhost:6333")
QDRANT_COLLECTION = os.environ.get("QDRANT_COLLECTION", "teleo-claims")
EMBEDDING_MODEL = "text-embedding-3-small"
_OPENROUTER_KEY: str | None = None
WIKI_LINK_RE = re.compile(r"\[\[([^\]]+)\]\]")
# Structural files that should never be included in graph expansion results.
# These are indexes/MOCs, not claims — expanding them pulls entire domains.
STRUCTURAL_FILES = {"_map.md", "_overview.md"}
def _get_api_key() -> str | None:
    """Load the OpenRouter API key, caching it after the first successful read.

    Resolution order:
      1. cached value from a previous call
      2. ``SECRETS_DIR/openrouter-key`` file (only if non-empty)
      3. ``OPENROUTER_API_KEY`` environment variable

    Returns:
        The key string, or None when no key is available anywhere.
    """
    global _OPENROUTER_KEY
    if _OPENROUTER_KEY:
        return _OPENROUTER_KEY
    key_file = config.SECRETS_DIR / "openrouter-key"
    if key_file.exists():
        # Fix: an existing but empty/whitespace-only key file used to return ""
        # and shadow the env-var fallback; only accept a non-empty file key.
        key = key_file.read_text().strip()
        if key:
            _OPENROUTER_KEY = key
            return _OPENROUTER_KEY
    _OPENROUTER_KEY = os.environ.get("OPENROUTER_API_KEY")
    return _OPENROUTER_KEY
# --- Layer 1: Vector search ---
def embed_query(text: str) -> list[float] | None:
    """Embed *text* via OpenRouter's OpenAI-compatible embeddings endpoint.

    Input is truncated to 8000 characters before embedding.

    Returns:
        A 1536-dim vector, or None when no API key is available or the
        request/response handling fails.
    """
    key = _get_api_key()
    if not key:
        logger.error("No OpenRouter API key available for embedding")
        return None

    body = {
        "model": f"openai/{EMBEDDING_MODEL}",
        "input": text[:8000],
    }
    request = urllib.request.Request(
        "https://openrouter.ai/api/v1/embeddings",
        data=json.dumps(body).encode(),
        headers={
            "Authorization": f"Bearer {key}",
            "Content-Type": "application/json",
        },
    )
    try:
        with urllib.request.urlopen(request, timeout=15) as response:
            parsed = json.loads(response.read())
            return parsed["data"][0]["embedding"]
    except Exception as exc:
        # Network errors and malformed responses are both non-fatal here.
        logger.error("Embedding failed: %s", exc)
        return None
def search_qdrant(vector: list[float], limit: int = 10,
                  domain: str | None = None, confidence: str | None = None,
                  exclude: list[str] | None = None,
                  score_threshold: float = 0.3,
                  offset: int = 0) -> list[dict]:
    """Query the Qdrant collection (REST API) for the claims nearest *vector*.

    Args:
        vector: Query embedding.
        limit: Maximum number of hits to return.
        domain: Restrict hits to this payload ``domain`` value.
        confidence: Restrict hits to this payload ``confidence`` value.
        exclude: Claim paths filtered out server-side via must_not.
        score_threshold: Minimum cosine score for a hit.
        offset: Skip first N results (Qdrant native offset for pagination).

    Returns:
        List of hits ``[{id, score, payload: {claim_path, claim_title, ...}}]``,
        or ``[]`` on any HTTP/parse failure.
    """
    must = [
        {"key": key, "match": {"value": value}}
        for key, value in (("domain", domain), ("confidence", confidence))
        if value
    ]
    must_not = [
        {"key": "claim_path", "match": {"value": path}}
        for path in (exclude or [])
    ]

    body: dict = {
        "vector": vector,
        "limit": limit,
        "with_payload": True,
        "score_threshold": score_threshold,
    }
    if offset > 0:
        body["offset"] = offset
    if must or must_not:
        qdrant_filter: dict = {}
        if must:
            qdrant_filter["must"] = must
        if must_not:
            qdrant_filter["must_not"] = must_not
        body["filter"] = qdrant_filter

    request = urllib.request.Request(
        f"{QDRANT_URL}/collections/{QDRANT_COLLECTION}/points/search",
        data=json.dumps(body).encode(),
        headers={"Content-Type": "application/json"},
    )
    try:
        with urllib.request.urlopen(request, timeout=10) as response:
            return json.loads(response.read()).get("result", [])
    except Exception as exc:
        logger.error("Qdrant search failed: %s", exc)
        return []
# --- Layer 2: Graph expansion ---
def _parse_frontmatter_edges(path: Path) -> dict:
    """Extract relationship edges from a claim file's YAML frontmatter.

    Accepts both inline lists (``depends_on: ["a", "b"]``) and multi-line
    YAML lists. Wiki links found in the body are reported under a separate
    ``wiki_links`` key — deduplicated against the explicit edges — so callers
    can weight them differently.

    Returns:
        ``{supports, challenges, depends_on, related, wiki_links}``, each a
        list of claim names; all lists empty when the file is unreadable or
        has no frontmatter.
    """
    edge_fields = ("supports", "challenges", "depends_on", "related")
    edges: dict = {field: [] for field in edge_fields}
    edges["wiki_links"] = []

    try:
        text = path.read_text(errors="replace")
    except Exception:
        return edges
    if not text.startswith("---"):
        return edges
    end = text.find("\n---", 3)
    if end == -1:
        return edges

    # Frontmatter is everything between the opening and closing "---" fences.
    try:
        import yaml
        parsed = yaml.safe_load(text[3:end])
        if isinstance(parsed, dict):
            for field in edge_fields:
                value = parsed.get(field)
                if isinstance(value, list):
                    edges[field] = [str(item).strip() for item in value if item]
                elif isinstance(value, str) and value.strip():
                    edges[field] = [value.strip()]
    except Exception:
        # Unparseable YAML — proceed with no explicit edges; wiki links
        # from the body are still extracted below.
        pass

    # Body wiki links become low-weight edges, minus anything already explicit.
    explicit = {name for field in edge_fields for name in edges[field]}
    for link in WIKI_LINK_RE.findall(text[end + 4:]):
        link = link.strip()
        if link and link not in explicit and link not in edges["wiki_links"]:
            edges["wiki_links"].append(link)
    return edges
def _resolve_claim_path(name: str, repo_root: Path) -> Path | None:
"""Resolve a claim name (from frontmatter edge or wiki link) to a file path.
Handles both naming conventions:
- "GLP-1 receptor agonists are..." "GLP-1 receptor agonists are....md" (spaces)
- "glp-1-persistence-drops..." "glp-1-persistence-drops....md" (slugified)
Checks domains/, core/, foundations/, decisions/ subdirectories.
"""
# Try exact name first (spaces in filename), then slugified
candidates = [name]
slug = name.lower().replace(" ", "-").replace("_", "-")
if slug != name:
candidates.append(slug)
for subdir in ["domains", "core", "foundations", "decisions"]:
base = repo_root / subdir
if not base.is_dir():
continue
for candidate_name in candidates:
for md in base.rglob(f"{candidate_name}.md"):
return md
return None
def graph_expand(seed_paths: list[str], repo_root: Path | None = None,
                 max_expanded: int = 30,
                 challenge_weight: float = 1.5,
                 seen: set[str] | None = None) -> list[dict]:
    """Layer 2: Expand seed claims 1-hop through knowledge graph edges.

    Traverses supports/challenges/depends_on/related/wiki_links edges found in
    each seed's frontmatter. Edge weights: challenges ``challenge_weight``
    (default 1.5), depends_on 1.25x, supports/related 1.0x, wiki_links 0.5x.
    Results are sorted by weight descending so the cap cuts low-value edges
    first.

    Args:
        seed_paths: Repo-relative claim paths to expand from.
        repo_root: KB root; defaults to ``config.MAIN_WORKTREE``.
        max_expanded: Hard cap on the number of expansions returned.
        challenge_weight: Weight assigned to ``challenges`` edges. (Fix: this
            parameter was previously accepted but ignored — the hard-coded
            1.5 was always used regardless of what the caller passed.)
        seen: Optional set of paths already matched (e.g. from keyword
            search) to exclude.

    Returns:
        List of {claim_path, claim_title, edge_type, edge_weight, from_claim},
        excluding seeds, ``seen`` entries, and structural index files.
    """
    edge_weights = {
        "challenges": challenge_weight,  # counterpoints get top priority
        "depends_on": 1.25,
        "supports": 1.0,
        "related": 1.0,
        "wiki_links": 0.5,               # implicit links: lowest confidence
    }
    root = repo_root or config.MAIN_WORKTREE
    all_expanded = []
    visited = set(seed_paths)
    if seen:
        visited.update(seen)
    for seed_path in seed_paths:
        full_path = root / seed_path
        if not full_path.exists():
            continue
        edges = _parse_frontmatter_edges(full_path)
        for edge_type, targets in edges.items():
            weight = edge_weights.get(edge_type, 1.0)
            for target_name in targets:
                target_path = _resolve_claim_path(target_name, root)
                if target_path is None:
                    continue
                rel_path = str(target_path.relative_to(root))
                if rel_path in visited:
                    continue
                # Skip structural files (MOCs/indexes) — they pull entire domains
                if target_path.name in STRUCTURAL_FILES:
                    continue
                visited.add(rel_path)
                # Prefer the target's frontmatter name/title over the raw edge text
                title = target_name
                try:
                    text = target_path.read_text(errors="replace")
                    if text.startswith("---"):
                        end = text.find("\n---", 3)
                        if end > 0:
                            import yaml
                            fm = yaml.safe_load(text[3:end])
                            if isinstance(fm, dict):
                                title = fm.get("name", fm.get("title", target_name))
                except Exception:
                    pass  # unreadable/unparseable target — keep the edge name
                all_expanded.append({
                    "claim_path": rel_path,
                    "claim_title": str(title),
                    "edge_type": edge_type,
                    "edge_weight": weight,
                    "from_claim": seed_path,
                })
    # Sort by weight descending so cap cuts lowest-value edges first
    all_expanded.sort(key=lambda x: x["edge_weight"], reverse=True)
    return all_expanded[:max_expanded]
# --- Combined search (Layer 1 + Layer 2) ---
# Default thresholds — calibrated with Leo's retrieval audits
PASS1_LIMIT = 5
PASS1_THRESHOLD = 0.70
PASS2_LIMIT = 5
PASS2_THRESHOLD = 0.60
HARD_CAP = 10
def _dedup_hits(hits: list[dict], seen: set[str]) -> list[dict]:
    """Normalize raw Qdrant hits into result dicts.

    Drops hits whose claim_path is already in *seen* (the set is mutated to
    record new paths) and hits pointing at structural index files. Snippets
    are truncated to 200 characters.
    """
    normalized = []
    for hit in hits:
        payload = hit.get("payload", {})
        path = payload.get("claim_path", "")
        if path in seen:
            continue
        if path.split("/")[-1] in STRUCTURAL_FILES:
            continue
        seen.add(path)
        normalized.append({
            "claim_title": payload.get("claim_title", ""),
            "claim_path": path,
            "score": round(hit.get("score", 0), 4),
            "domain": payload.get("domain", ""),
            "confidence": payload.get("confidence", ""),
            "snippet": payload.get("snippet", "")[:200],
            "type": payload.get("type", "claim"),
        })
    return normalized
def _sort_results(direct: list[dict], expanded: list[dict]) -> list[dict]:
"""Sort combined results: similarity desc → challenged_by → other expansion.
Sort order is load-bearing: LLMs have primacy bias, so best claims first.
"""
# Direct results already sorted by Qdrant (cosine desc)
sorted_direct = sorted(direct, key=lambda x: x.get("score", 0), reverse=True)
# Expansion: challenged_by first (counterpoints), then rest by weight
challenged = [e for e in expanded if e.get("edge_type") == "challenges"]
other_expanded = [e for e in expanded if e.get("edge_type") != "challenges"]
challenged.sort(key=lambda x: x.get("edge_weight", 0), reverse=True)
other_expanded.sort(key=lambda x: x.get("edge_weight", 0), reverse=True)
return sorted_direct + challenged + other_expanded
def search(query: str, expand: bool = False,
           domain: str | None = None, confidence: str | None = None,
           exclude: list[str] | None = None) -> dict:
    """Two-pass semantic search: embed query, search Qdrant, optionally expand.

    Pass 1 (expand=False, default): Top 5 claims from Qdrant, score >= 0.70.
    Sufficient for ~80% of queries. Fast and focused.

    Pass 2 (expand=True): Next 5 claims (lower threshold 0.60, offset past
    pass 1's hits) plus graph-expanded claims (challenged_by, related edges).
    Hard cap 10 total. Agents call this only when pass 1 didn't answer.

    Args:
        query: Natural-language search query.
        expand: Enable pass 2 and graph expansion.
        domain: Optional Qdrant payload filter.
        confidence: Optional Qdrant payload filter.
        exclude: Claim paths to exclude (also pre-seeds the dedup set).

    Returns {
        "query": str,
        "direct_results": [...],   # Layer 1 Qdrant hits (sorted by score desc)
        "expanded_results": [...], # Layer 2 graph expansion (challenges first)
        "total": int,
        "error": str,              # present only on failure
    }
    """
    vector = embed_query(query)
    if vector is None:
        return {"query": query, "direct_results": [], "expanded_results": [],
                "total": 0, "error": "embedding_failed"}

    # --- Pass 1: Top 5, high threshold ---
    hits = search_qdrant(vector, limit=PASS1_LIMIT, domain=domain,
                         confidence=confidence, exclude=exclude,
                         score_threshold=PASS1_THRESHOLD)

    seen_paths: set[str] = set()
    if exclude:
        seen_paths.update(exclude)
    direct = _dedup_hits(hits, seen_paths)

    expanded = []
    if expand:
        # --- Pass 2: Next batch from Qdrant (lower threshold) ---
        # Fix: offset by the number of hits pass 1 actually returned, not a
        # fixed PASS1_LIMIT. Qdrant's offset counts results above the pass-2
        # threshold (0.60); with a hard-coded offset of 5, any time pass 1
        # returned fewer than 5 hits, claims scoring 0.60–0.70 at ranks up
        # to 5 were silently skipped.
        pass2_hits = search_qdrant(vector, limit=PASS2_LIMIT, domain=domain,
                                   confidence=confidence, exclude=exclude,
                                   score_threshold=PASS2_THRESHOLD,
                                   offset=len(hits))
        direct.extend(_dedup_hits(pass2_hits, seen_paths))

        # Graph expansion on all direct results (pass 1 + pass 2 seeds)
        seed_paths = [r["claim_path"] for r in direct]
        remaining_cap = HARD_CAP - len(direct)
        if remaining_cap > 0:
            expanded = graph_expand(seed_paths, max_expanded=remaining_cap,
                                    seen=seen_paths)

    # Enforce hard cap across all results, best-first
    all_sorted = _sort_results(direct, expanded)[:HARD_CAP]

    # Split back into direct vs expanded for backward compat
    direct_paths = {r["claim_path"] for r in direct}
    final_direct = [r for r in all_sorted if r.get("claim_path") in direct_paths]
    final_expanded = [r for r in all_sorted if r.get("claim_path") not in direct_paths]

    return {
        "query": query,
        "direct_results": final_direct,
        "expanded_results": final_expanded,
        "total": len(all_sorted),
    }

View file

@ -621,3 +621,107 @@ def format_context_for_prompt(ctx: KBContext) -> str:
f"{ctx.stats.get('claims_matched', 0)} claims.") f"{ctx.stats.get('claims_matched', 0)} claims.")
return "\n".join(sections) return "\n".join(sections)
# --- Qdrant vector search integration ---
# Module-level import guard for lib.search (Fix 3: no per-call sys.path manipulation)
# Sentinel: stays None when lib.search cannot be imported; callers
# (retrieve_vector_context) check it and degrade to empty results.
_vector_search = None
try:
    import sys as _sys
    import os as _os
    # Pipeline root is two directories above this file; insert it into
    # sys.path once at import time so "lib.search" resolves as a package.
    _pipeline_root = _os.path.dirname(_os.path.dirname(_os.path.abspath(__file__)))
    if _pipeline_root not in _sys.path:
        _sys.path.insert(0, _pipeline_root)
    from lib.search import search as _vector_search
except ImportError:
    # Non-fatal: vector search is an optional layer over keyword retrieval.
    logger.warning("Qdrant search unavailable at module load (lib.search not found)")
def retrieve_vector_context(query: str,
                            keyword_paths: list[str] | None = None) -> tuple[str, dict]:
    """Semantic search via Qdrant — returns (formatted_text, metadata).

    Complements retrieve_context() (symbolic/keyword) with semantic similarity.
    Falls back gracefully (empty text, empty metadata) if Qdrant or the
    lib.search import is unavailable.

    Args:
        keyword_paths: Claim paths already matched by keyword search. These are
            excluded at the Qdrant query level AND from graph expansion to avoid
            duplicates in the prompt.

    Returns:
        (formatted_text, metadata_dict)
        metadata_dict: {direct_results: [...], expanded_results: [...],
                        layers_hit: [...], duration_ms: int}
    """
    import time as _time
    t0 = _time.monotonic()
    empty_meta = {"direct_results": [], "expanded_results": [],
                  "layers_hit": [], "duration_ms": 0}
    if _vector_search is None:
        return "", empty_meta
    try:
        results = _vector_search(query, expand=True,
                                 exclude=keyword_paths)
    except Exception as e:
        logger.warning("Qdrant search failed: %s", e)
        return "", empty_meta
    duration = int((_time.monotonic() - t0) * 1000)

    if results.get("error") or not results.get("direct_results"):
        return "", {**empty_meta, "duration_ms": duration,
                    "error": results.get("error")}

    layers_hit = ["qdrant"]
    if results.get("expanded_results"):
        layers_hit.append("graph")

    # Build structured metadata for audit
    meta = {
        "direct_results": [
            {"path": r["claim_path"], "title": r["claim_title"],
             "score": r["score"], "domain": r.get("domain", ""),
             "source": "qdrant"}
            for r in results["direct_results"]
        ],
        "expanded_results": [
            {"path": r["claim_path"], "title": r["claim_title"],
             "edge_type": r.get("edge_type", "related"),
             "from_claim": r.get("from_claim", ""), "source": "graph"}
            for r in results.get("expanded_results", [])
        ],
        "layers_hit": layers_hit,
        "duration_ms": duration,
    }

    # Build formatted text for prompt (Fix 4: subsection headers)
    sections = []
    sections.append("## Semantic Search Results (Qdrant)")
    sections.append("")
    sections.append("### Direct matches")
    for r in results["direct_results"]:
        score_pct = int(r["score"] * 100)
        line = f"- **{r['claim_title']}** ({score_pct}% match"
        if r.get("domain"):
            line += f", {r['domain']}"
        if r.get("confidence"):
            line += f", {r['confidence']}"
        line += ")"
        sections.append(line)
        if r.get("snippet"):
            sections.append(f" {r['snippet']}")
    if results.get("expanded_results"):
        sections.append("")
        sections.append("### Related claims (graph expansion)")
        for r in results["expanded_results"]:
            edge = r.get("edge_type", "related")
            weight = r.get("edge_weight", 1.0)
            weight_str = f" ×{weight}" if weight != 1.0 else ""
            # Fix: the provenance claim name used to be glued directly onto
            # the edge label ("challenges ×1.5claim-name"); add a separator
            # and omit it entirely when from_claim is empty.
            from_name = r.get("from_claim", "").split("/")[-1]
            from_str = f", from {from_name}" if from_name else ""
            sections.append(f"- {r['claim_title']} ({edge}{weight_str}{from_str})")

    return "\n".join(sections), meta

604
tests/test_search.py Normal file
View file

@ -0,0 +1,604 @@
"""Tests for lib/search.py — vector search and graph expansion."""
import json
from pathlib import Path
from unittest.mock import patch, MagicMock
import pytest
from lib.search import (
_parse_frontmatter_edges,
_resolve_claim_path,
graph_expand,
search,
search_qdrant,
WIKI_LINK_RE,
)
# ─── Fixtures ──────────────────────────────────────────────────────────────
@pytest.fixture
def repo(tmp_path):
    """Minimal KB repo structure with claim files.

    Layout (under tmp_path):
      domains/ai-alignment/     — small graph exercising supports/related/
                                  challenges edges plus one body [[wiki link]]
      domains/internet-finance/ — a depends_on edge between two claims

    Returns tmp_path, usable as repo_root for resolution/expansion calls.
    """
    domains = tmp_path / "domains"
    # ai-alignment domain
    ai = domains / "ai-alignment"
    ai.mkdir(parents=True)
    # Seed claim: explicit supports/related edges plus a wiki link in the body.
    (ai / "capability-scoping.md").write_text(
        "---\nname: AI agent capability scoping\ntype: claim\n"
        "supports:\n - capability-matched escalation\n"
        "related:\n - multi-agent coordination\n"
        "---\n\nBody text with [[wiki-linked claim]].\n"
    )
    # Target of the supports edge; itself carries a challenges edge.
    (ai / "capability-matched-escalation.md").write_text(
        "---\nname: capability-matched escalation\ntype: claim\n"
        "challenges:\n - unconstrained autonomy\n"
        "---\n\nEscalation body.\n"
    )
    # Leaf claims with no outgoing edges.
    (ai / "multi-agent-coordination.md").write_text(
        "---\nname: multi-agent coordination\ntype: claim\n---\n\nCoord body.\n"
    )
    (ai / "wiki-linked-claim.md").write_text(
        "---\nname: wiki-linked claim\ntype: claim\n---\n\nWiki body.\n"
    )
    (ai / "unconstrained-autonomy.md").write_text(
        "---\nname: unconstrained autonomy\ntype: claim\n---\n\nAutonomy body.\n"
    )
    # internet-finance domain — exercises cross-domain resolution + depends_on.
    fin = domains / "internet-finance"
    fin.mkdir(parents=True)
    (fin / "futarchy-governance.md").write_text(
        "---\nname: futarchy governance\ntype: claim\n"
        "depends_on:\n - prediction market accuracy\n"
        "---\n\nFutarchy body.\n"
    )
    (fin / "prediction-market-accuracy.md").write_text(
        "---\nname: prediction market accuracy\ntype: claim\n---\n\nPM body.\n"
    )
    return tmp_path
@pytest.fixture
def claim_no_frontmatter(tmp_path):
    """A claim file containing only a body — no YAML frontmatter at all."""
    misc_dir = tmp_path / "domains" / "misc"
    misc_dir.mkdir(parents=True)
    bare_file = misc_dir / "bare.md"
    bare_file.write_text("Just a body, no frontmatter.\n")
    return bare_file
# ─── _parse_frontmatter_edges ─────────────────────────────────────────────
class TestParseFrontmatterEdges:
    """Unit tests for _parse_frontmatter_edges (YAML edge extraction)."""

    def test_supports_and_related(self, repo):
        """Multi-line YAML lists for supports/related are parsed."""
        path = repo / "domains" / "ai-alignment" / "capability-scoping.md"
        edges = _parse_frontmatter_edges(path)
        assert edges["supports"] == ["capability-matched escalation"]
        assert edges["related"] == ["multi-agent coordination"]

    def test_challenges(self, repo):
        path = repo / "domains" / "ai-alignment" / "capability-matched-escalation.md"
        edges = _parse_frontmatter_edges(path)
        assert edges["challenges"] == ["unconstrained autonomy"]

    def test_depends_on(self, repo):
        path = repo / "domains" / "internet-finance" / "futarchy-governance.md"
        edges = _parse_frontmatter_edges(path)
        assert edges["depends_on"] == ["prediction market accuracy"]

    def test_wiki_links_extracted_separately(self, repo):
        """Body [[links]] land in wiki_links, not in the explicit edge lists."""
        path = repo / "domains" / "ai-alignment" / "capability-scoping.md"
        edges = _parse_frontmatter_edges(path)
        assert "wiki-linked claim" in edges["wiki_links"]
        # Wiki links should NOT appear in explicit related
        assert "wiki-linked claim" not in edges["related"]

    def test_wiki_links_deduped_from_explicit(self, tmp_path):
        """If a wiki link matches an explicit edge, it's excluded from wiki_links."""
        d = tmp_path / "domains" / "test"
        d.mkdir(parents=True)
        f = d / "overlap.md"
        f.write_text(
            "---\nname: overlap test\nrelated:\n - shared target\n---\n\n"
            "Body with [[shared target]] link.\n"
        )
        edges = _parse_frontmatter_edges(f)
        assert edges["related"] == ["shared target"]
        assert edges["wiki_links"] == []

    def test_no_frontmatter(self, claim_no_frontmatter):
        """A bare body file yields empty lists for every edge type."""
        edges = _parse_frontmatter_edges(claim_no_frontmatter)
        assert all(v == [] for v in edges.values())

    def test_missing_file(self, tmp_path):
        """Nonexistent paths are swallowed — the parser returns empty edges."""
        edges = _parse_frontmatter_edges(tmp_path / "nonexistent.md")
        assert all(v == [] for v in edges.values())

    def test_inline_list_format(self, tmp_path):
        """Handles YAML inline list: depends_on: ["a", "b"]."""
        d = tmp_path / "domains" / "test"
        d.mkdir(parents=True)
        f = d / "inline.md"
        f.write_text('---\nname: inline\ndepends_on: ["alpha", "beta"]\n---\n\nBody.\n')
        edges = _parse_frontmatter_edges(f)
        assert edges["depends_on"] == ["alpha", "beta"]
# ─── _resolve_claim_path ──────────────────────────────────────────────────
class TestResolveClaimPath:
    """Unit tests for _resolve_claim_path (name → file resolution)."""

    def test_slugified_name(self, repo):
        """An already-slugified name resolves directly."""
        result = _resolve_claim_path("capability-scoping", repo)
        assert result is not None
        assert result.name == "capability-scoping.md"

    def test_name_with_spaces(self, repo):
        """Resolves 'multi-agent coordination' via slug matching."""
        result = _resolve_claim_path("multi-agent coordination", repo)
        assert result is not None
        assert result.name == "multi-agent-coordination.md"

    def test_not_found(self, repo):
        """Unknown names resolve to None rather than raising."""
        result = _resolve_claim_path("nonexistent claim", repo)
        assert result is None

    def test_cross_domain_resolution(self, repo):
        """Can resolve a claim in a different domain subdirectory."""
        result = _resolve_claim_path("futarchy-governance", repo)
        assert result is not None
        assert "internet-finance" in str(result)
# ─── graph_expand ─────────────────────────────────────────────────────────
class TestGraphExpand:
    """Unit tests for graph_expand (1-hop edge traversal, weights, caps)."""

    def test_basic_expansion(self, repo):
        """Expanding from capability-scoping should find its edges."""
        results = graph_expand(
            ["domains/ai-alignment/capability-scoping.md"],
            repo_root=repo,
        )
        titles = [r["claim_title"] for r in results]
        assert "capability-matched escalation" in titles
        assert "multi-agent coordination" in titles

    def test_wiki_links_included(self, repo):
        """Body wiki links are traversed as (low-weight) edges."""
        results = graph_expand(
            ["domains/ai-alignment/capability-scoping.md"],
            repo_root=repo,
        )
        titles = [r["claim_title"] for r in results]
        assert "wiki-linked claim" in titles

    def test_edge_weights(self, repo):
        """challenges edges get 1.5x weight, wiki_links get 0.5x."""
        # Expand from capability-matched-escalation (has challenges edge)
        results = graph_expand(
            ["domains/ai-alignment/capability-matched-escalation.md"],
            repo_root=repo,
        )
        challenges_result = [r for r in results if r["edge_type"] == "challenges"]
        assert len(challenges_result) == 1
        assert challenges_result[0]["edge_weight"] == 1.5

    def test_wiki_link_weight(self, repo):
        """Every wiki_links edge carries the 0.5x discount weight."""
        results = graph_expand(
            ["domains/ai-alignment/capability-scoping.md"],
            repo_root=repo,
        )
        wiki_results = [r for r in results if r["edge_type"] == "wiki_links"]
        assert all(r["edge_weight"] == 0.5 for r in wiki_results)

    def test_depends_on_weight(self, repo):
        """depends_on edges carry the 1.25x weight."""
        results = graph_expand(
            ["domains/internet-finance/futarchy-governance.md"],
            repo_root=repo,
        )
        dep_results = [r for r in results if r["edge_type"] == "depends_on"]
        assert len(dep_results) == 1
        assert dep_results[0]["edge_weight"] == 1.25

    def test_sorted_by_weight_descending(self, repo):
        """Output ordering is monotone non-increasing in edge_weight."""
        results = graph_expand(
            ["domains/ai-alignment/capability-scoping.md"],
            repo_root=repo,
        )
        weights = [r["edge_weight"] for r in results]
        assert weights == sorted(weights, reverse=True)

    def test_seed_excluded_from_results(self, repo):
        """A seed never appears among its own expansions."""
        seed = "domains/ai-alignment/capability-scoping.md"
        results = graph_expand([seed], repo_root=repo)
        paths = [r["claim_path"] for r in results]
        assert seed not in paths

    def test_seen_set_excludes(self, repo):
        """Paths in seen set are excluded from expansion results."""
        already_matched = {"domains/ai-alignment/multi-agent-coordination.md"}
        results = graph_expand(
            ["domains/ai-alignment/capability-scoping.md"],
            repo_root=repo,
            seen=already_matched,
        )
        paths = [r["claim_path"] for r in results]
        assert "domains/ai-alignment/multi-agent-coordination.md" not in paths

    def test_max_expanded_cap(self, repo):
        """max_expanded bounds the result count."""
        results = graph_expand(
            ["domains/ai-alignment/capability-scoping.md"],
            repo_root=repo,
            max_expanded=1,
        )
        assert len(results) <= 1

    def test_cap_cuts_lowest_weight(self, repo):
        """With cap=2, wiki_links (0.5x) should be cut before supports (1.0x)."""
        results = graph_expand(
            ["domains/ai-alignment/capability-scoping.md"],
            repo_root=repo,
            max_expanded=2,
        )
        edge_types = [r["edge_type"] for r in results]
        assert "wiki_links" not in edge_types

    def test_nonexistent_seed(self, repo):
        """Missing seed files are skipped silently — no raise, no results."""
        results = graph_expand(["domains/nonexistent.md"], repo_root=repo)
        assert results == []
# ─── search_qdrant (mocked HTTP) ─────────────────────────────────────────
class TestSearchQdrant:
    """Tests for search_qdrant with urllib.request.urlopen mocked out."""

    def _mock_qdrant_response(self, results):
        """Build a mock urllib response returning Qdrant results."""
        resp = MagicMock()
        resp.read.return_value = json.dumps({"result": results}).encode()
        # urlopen is used as a context manager, so the mock must support
        # __enter__/__exit__ and hand back itself.
        resp.__enter__ = lambda s: s
        resp.__exit__ = MagicMock(return_value=False)
        return resp

    @patch("lib.search.urllib.request.urlopen")
    def test_basic_search(self, mock_urlopen):
        """Hits from the mocked REST response are returned as-is."""
        mock_urlopen.return_value = self._mock_qdrant_response([
            {"id": 1, "score": 0.85, "payload": {
                "claim_title": "test claim", "claim_path": "domains/test.md",
                "domain": "ai", "confidence": "high",
            }},
        ])
        results = search_qdrant([0.1] * 1536, limit=5)
        assert len(results) == 1
        assert results[0]["score"] == 0.85

    @patch("lib.search.urllib.request.urlopen")
    def test_domain_filter(self, mock_urlopen):
        """domain= becomes a 'must' match filter in the request body."""
        mock_urlopen.return_value = self._mock_qdrant_response([])
        search_qdrant([0.1] * 1536, domain="ai-alignment")
        # Verify the request body includes domain filter
        call_args = mock_urlopen.call_args
        req = call_args[0][0]
        body = json.loads(req.data)
        assert body["filter"]["must"][0]["key"] == "domain"
        assert body["filter"]["must"][0]["match"]["value"] == "ai-alignment"

    @patch("lib.search.urllib.request.urlopen")
    def test_exclude_filter(self, mock_urlopen):
        """exclude= paths become 'must_not' match filters."""
        mock_urlopen.return_value = self._mock_qdrant_response([])
        search_qdrant([0.1] * 1536, exclude=["domains/a.md", "domains/b.md"])
        call_args = mock_urlopen.call_args
        req = call_args[0][0]
        body = json.loads(req.data)
        must_not = body["filter"]["must_not"]
        excluded_paths = [f["match"]["value"] for f in must_not]
        assert "domains/a.md" in excluded_paths
        assert "domains/b.md" in excluded_paths

    @patch("lib.search.urllib.request.urlopen")
    def test_no_filters_no_filter_key(self, mock_urlopen):
        """No filters at all — the body omits the 'filter' key entirely."""
        mock_urlopen.return_value = self._mock_qdrant_response([])
        search_qdrant([0.1] * 1536)
        call_args = mock_urlopen.call_args
        req = call_args[0][0]
        body = json.loads(req.data)
        assert "filter" not in body

    @patch("lib.search.urllib.request.urlopen")
    def test_http_failure_returns_empty(self, mock_urlopen):
        """Connection errors are swallowed and yield an empty result list."""
        mock_urlopen.side_effect = Exception("connection refused")
        results = search_qdrant([0.1] * 1536)
        assert results == []
# ─── search() integration (mocked network) ───────────────────────────────
class TestSearch:
    """End-to-end search() behaviour with every network hop mocked."""

    @patch("lib.search.embed_query")
    @patch("lib.search.search_qdrant")
    def test_embedding_failure(self, mock_qdrant, mock_embed):
        mock_embed.return_value = None
        out = search("test query")
        assert out["error"] == "embedding_failed"
        assert out["direct_results"] == []
        # Without an embedding there is nothing to send to Qdrant.
        mock_qdrant.assert_not_called()

    @patch("lib.search.graph_expand")
    @patch("lib.search.search_qdrant")
    @patch("lib.search.embed_query")
    def test_full_pipeline(self, mock_embed, mock_qdrant, mock_expand):
        mock_embed.return_value = [0.1] * 1536
        pass1_hit = {"id": 1, "score": 0.82, "payload": {
            "claim_title": "direct hit", "claim_path": "domains/hit.md",
            "domain": "ai", "confidence": "high", "snippet": "snippet text",
            "type": "claim",
        }}
        # Pass 1 returns one hit, pass 2 returns empty (nothing above lower threshold)
        mock_qdrant.side_effect = [[pass1_hit], []]
        mock_expand.return_value = [
            {"claim_path": "domains/expanded.md", "claim_title": "expanded",
             "edge_type": "supports", "edge_weight": 1.0, "from_claim": "domains/hit.md"},
        ]
        # expand=True triggers two-pass: pass 1 + pass 2 + graph expansion
        out = search("test query", expand=True)
        assert len(out["direct_results"]) == 1
        assert out["direct_results"][0]["claim_title"] == "direct hit"
        assert len(out["expanded_results"]) == 1
        assert out["total"] == 2

    @patch("lib.search.graph_expand")
    @patch("lib.search.search_qdrant")
    @patch("lib.search.embed_query")
    def test_exclude_passed_through(self, mock_embed, mock_qdrant, mock_expand):
        """Exclude list reaches both Qdrant and graph_expand."""
        mock_embed.return_value = [0.1] * 1536
        mock_qdrant.side_effect = [
            [{"id": 1, "score": 0.8, "payload": {
                "claim_title": "hit", "claim_path": "domains/hit.md",
            }}],
            [],  # pass 2
        ]
        mock_expand.return_value = []
        exclude = ["domains/already-matched.md"]
        search("query", expand=True, exclude=exclude)
        # Qdrant should get exclude (called twice with expand=True: pass 1 + pass 2)
        assert mock_qdrant.call_count == 2
        for qdrant_call in mock_qdrant.call_args_list:
            assert qdrant_call.kwargs.get("exclude") == exclude
        # graph_expand should get seen set containing exclude paths
        mock_expand.assert_called_once()
        seen = mock_expand.call_args.kwargs.get("seen")
        assert seen is not None
        assert "domains/already-matched.md" in seen

    @patch("lib.search.search_qdrant")
    @patch("lib.search.embed_query")
    def test_no_expand_when_disabled(self, mock_embed, mock_qdrant):
        mock_embed.return_value = [0.1] * 1536
        mock_qdrant.return_value = []
        assert search("query", expand=False)["expanded_results"] == []
# ─── WIKI_LINK_RE ─────────────────────────────────────────────────────────
class TestWikiLinkRegex:
    """Sanity checks on the [[wiki link]] extraction pattern."""

    def test_basic(self):
        assert WIKI_LINK_RE.findall("See [[some claim]]") == ["some claim"]

    def test_multiple(self):
        sample = "Links to [[claim A]] and [[claim B]]"
        assert WIKI_LINK_RE.findall(sample) == ["claim A", "claim B"]

    def test_no_nested(self):
        # Nested brackets must not be captured as one big link target.
        assert WIKI_LINK_RE.findall("[[outer [[inner]]]]") != ["outer [[inner]]"]
# ─── Structural file exclusion ───────────────────────────────────────────
class TestStructuralFileExclusion:
    """Structural notes (_map.md / _overview.md) must never be expanded into."""

    def test_map_excluded_from_expansion(self, tmp_path):
        """_map.md files should be skipped during graph expansion."""
        claim_dir = tmp_path / "domains" / "test"
        claim_dir.mkdir(parents=True)
        (claim_dir / "seed-claim.md").write_text(
            "---\nname: seed claim\ntype: claim\n"
            "related:\n - domain map\n---\nBody.\n"
        )
        (claim_dir / "_map.md").write_text(
            "---\nname: domain map\ntype: moc\n---\nDomain index.\n"
        )
        expanded = graph_expand(["domains/test/seed-claim.md"], repo_root=tmp_path)
        reached = {entry["claim_path"] for entry in expanded}
        assert "domains/test/_map.md" not in reached

    def test_overview_excluded_from_expansion(self, tmp_path):
        """_overview.md files should be skipped during graph expansion."""
        claim_dir = tmp_path / "domains" / "test"
        claim_dir.mkdir(parents=True)
        (claim_dir / "seed-claim.md").write_text(
            "---\nname: seed claim\ntype: claim\n"
            "related:\n - domain overview\n---\nBody.\n"
        )
        (claim_dir / "_overview.md").write_text(
            "---\nname: domain overview\ntype: moc\n---\nOverview.\n"
        )
        expanded = graph_expand(["domains/test/seed-claim.md"], repo_root=tmp_path)
        reached = {entry["claim_path"] for entry in expanded}
        assert "domains/test/_overview.md" not in reached

    def test_regular_claims_still_expand(self, repo):
        """Non-structural files should still expand normally."""
        expanded = graph_expand(
            ["domains/ai-alignment/capability-scoping.md"], repo_root=repo
        )
        reached = [entry["claim_path"] for entry in expanded]
        assert len(reached) > 0
        assert "domains/ai-alignment/capability-matched-escalation.md" in reached
# ─── Dedup within vector results ─────────────────────────────────────────
class TestSearchDedup:
    """Direct vector-search results must be unique per claim_path."""

    @patch("lib.search.search_qdrant")
    @patch("lib.search.embed_query")
    def test_duplicate_paths_deduped(self, mock_embed, mock_qdrant):
        """Duplicate claim_paths in Qdrant results should be collapsed."""
        mock_embed.return_value = [0.1] * 1536
        mock_qdrant.return_value = [
            {"score": 0.9,
             "payload": {"claim_title": "Claim A", "claim_path": "domains/x/a.md",
                         "domain": "x", "confidence": "high", "snippet": "..."}},
            {"score": 0.85,
             "payload": {"claim_title": "Claim A dupe", "claim_path": "domains/x/a.md",
                         "domain": "x", "confidence": "high", "snippet": "..."}},
            {"score": 0.8,
             "payload": {"claim_title": "Claim B", "claim_path": "domains/x/b.md",
                         "domain": "x", "confidence": "high", "snippet": "..."}},
        ]
        direct = search("test query", expand=False)["direct_results"]
        assert [r["claim_path"] for r in direct] == ["domains/x/a.md", "domains/x/b.md"]
        # The first (highest-scoring) duplicate is the one kept.
        assert direct[0]["claim_title"] == "Claim A"

    @patch("lib.search.search_qdrant")
    @patch("lib.search.embed_query")
    def test_structural_files_excluded_from_direct(self, mock_embed, mock_qdrant):
        """_map.md should be excluded from direct Qdrant results too."""
        mock_embed.return_value = [0.1] * 1536
        mock_qdrant.return_value = [
            {"score": 0.9,
             "payload": {"claim_title": "Domain Map", "claim_path": "domains/x/_map.md",
                         "domain": "x", "confidence": "", "snippet": ""}},
            {"score": 0.85,
             "payload": {"claim_title": "Real Claim", "claim_path": "domains/x/real.md",
                         "domain": "x", "confidence": "high", "snippet": "..."}},
        ]
        direct = search("test query", expand=False)["direct_results"]
        direct_paths = [r["claim_path"] for r in direct]
        assert "domains/x/_map.md" not in direct_paths
        assert "domains/x/real.md" in direct_paths
# ─── Two-pass retrieval ──────────────────────────────────────────────────
class TestTwoPassRetrieval:
    """Behaviour of the pass-1/pass-2 retrieval split, caps, and sort order."""

    @patch("lib.search.graph_expand")
    @patch("lib.search.search_qdrant")
    @patch("lib.search.embed_query")
    def test_pass1_only_default(self, mock_embed, mock_qdrant, mock_expand):
        """Default search (expand=False) only calls Qdrant once with high threshold."""
        mock_embed.return_value = [0.1] * 1536
        mock_qdrant.return_value = [
            {"score": 0.85, "payload": {"claim_title": "Hit", "claim_path": "d/a.md"}},
        ]
        out = search("query")
        mock_qdrant.assert_called_once()
        # Should use PASS1_THRESHOLD (0.70)
        assert mock_qdrant.call_args.kwargs.get("score_threshold") == 0.70
        mock_expand.assert_not_called()
        assert len(out["direct_results"]) == 1

    @patch("lib.search.graph_expand")
    @patch("lib.search.search_qdrant")
    @patch("lib.search.embed_query")
    def test_pass2_expands(self, mock_embed, mock_qdrant, mock_expand):
        """expand=True calls Qdrant twice (pass 1 + pass 2) and runs graph expansion."""
        mock_embed.return_value = [0.1] * 1536
        mock_qdrant.side_effect = [
            [{"score": 0.85, "payload": {"claim_title": "P1", "claim_path": "d/a.md"}}],
            [{"score": 0.65, "payload": {"claim_title": "P2", "claim_path": "d/b.md"}}],
        ]
        mock_expand.return_value = []
        out = search("query", expand=True)
        assert mock_qdrant.call_count == 2
        # Pass 2 should use offset=5 and lower threshold
        second_call = mock_qdrant.call_args_list[1]
        assert second_call.kwargs.get("offset") == 5
        assert len(out["direct_results"]) == 2

    @patch("lib.search.graph_expand")
    @patch("lib.search.search_qdrant")
    @patch("lib.search.embed_query")
    def test_hard_cap_enforced(self, mock_embed, mock_qdrant, mock_expand):
        """Total results never exceed HARD_CAP (10)."""
        mock_embed.return_value = [0.1] * 1536
        # 5 from pass 1, 5 from pass 2
        pass1 = [
            {"score": 0.9 - i * 0.02,
             "payload": {"claim_title": f"P1-{i}", "claim_path": f"d/p1-{i}.md"}}
            for i in range(5)
        ]
        pass2 = [
            {"score": 0.65 - i * 0.01,
             "payload": {"claim_title": f"P2-{i}", "claim_path": f"d/p2-{i}.md"}}
            for i in range(5)
        ]
        mock_qdrant.side_effect = [pass1, pass2]
        # Graph expansion returns 5 more
        mock_expand.return_value = [
            {"claim_path": f"d/exp-{i}.md", "claim_title": f"Exp-{i}",
             "edge_type": "related", "edge_weight": 1.0, "from_claim": "d/p1-0.md"}
            for i in range(5)
        ]
        assert search("query", expand=True)["total"] <= 10

    @patch("lib.search.search_qdrant")
    @patch("lib.search.embed_query")
    def test_sort_order_similarity_first(self, mock_embed, mock_qdrant):
        """Direct results are sorted by cosine similarity descending."""
        mock_embed.return_value = [0.1] * 1536
        mock_qdrant.return_value = [
            {"score": 0.75, "payload": {"claim_title": "Low", "claim_path": "d/low.md"}},
            {"score": 0.95, "payload": {"claim_title": "High", "claim_path": "d/high.md"}},
            {"score": 0.85, "payload": {"claim_title": "Mid", "claim_path": "d/mid.md"}},
        ]
        ordered = [r["claim_title"] for r in search("query")["direct_results"]]
        assert ordered == ["High", "Mid", "Low"]

    @patch("lib.search.graph_expand")
    @patch("lib.search.search_qdrant")
    @patch("lib.search.embed_query")
    def test_challenges_before_other_expansion(self, mock_embed, mock_qdrant, mock_expand):
        """challenged_by claims appear before other expanded claims."""
        mock_embed.return_value = [0.1] * 1536
        mock_qdrant.side_effect = [
            [{"score": 0.85, "payload": {"claim_title": "Seed", "claim_path": "d/seed.md"}}],
            [],
        ]
        mock_expand.return_value = [
            {"claim_path": "d/related.md", "claim_title": "Related",
             "edge_type": "related", "edge_weight": 1.0, "from_claim": "d/seed.md"},
            {"claim_path": "d/challenge.md", "claim_title": "Challenge",
             "edge_type": "challenges", "edge_weight": 1.5, "from_claim": "d/seed.md"},
        ]
        out = search("query", expand=True)
        titles = [r["claim_title"] for r in out["expanded_results"]]
        assert titles.index("Challenge") < titles.index("Related")