feat: two-pass retrieval with sort order and graph expansion #5
4 changed files with 1142 additions and 110 deletions
@@ -13,11 +13,16 @@ import logging
import os
import sqlite3
import statistics
import sys
import urllib.request
from datetime import datetime, timezone
from pathlib import Path

# Add pipeline lib to path so we can import shared modules
sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "pipeline"))

from aiohttp import web
from lib.search import search as kb_search, embed_query, search_qdrant

logger = logging.getLogger("argus")

@@ -27,12 +32,7 @@ PORT = int(os.environ.get("ARGUS_PORT", "8081"))
REPO_DIR = Path(os.environ.get("REPO_DIR", "/opt/teleo-eval/workspaces/main"))
CLAIM_INDEX_URL = os.environ.get("CLAIM_INDEX_URL", "http://localhost:8080/claim-index")

# Search config
QDRANT_URL = os.environ.get("QDRANT_URL", "http://localhost:6333")
QDRANT_COLLECTION = os.environ.get("QDRANT_COLLECTION", "teleo-claims")
OPENROUTER_KEY_FILE = Path(os.environ.get("OPENROUTER_KEY_FILE", "/opt/teleo-eval/secrets/openrouter-key"))
EMBEDDING_MODEL = "text-embedding-3-small"
EMBEDDING_DIMS = 1536
# Search config — moved to lib/search.py (shared with Telegram bot + agents)

# Auth config
API_KEY_FILE = Path(os.environ.get("ARGUS_API_KEY_FILE", "/opt/teleo-eval/secrets/argus-api-key"))

@@ -483,82 +483,7 @@ async def auth_middleware(request, handler):

# ─── Embedding + Search ──────────────────────────────────────────────────────


def _get_embedding_key() -> str | None:
    """Load OpenRouter API key for embeddings."""
    return _load_secret(OPENROUTER_KEY_FILE)


def _embed_query(text: str, api_key: str) -> list[float] | None:
    """Embed a query string via OpenRouter (OpenAI-compatible endpoint).

    Uses urllib to avoid adding httpx/openai as dependencies.
    """
    payload = json.dumps({
        "model": f"openai/{EMBEDDING_MODEL}",
        "input": text,
    }).encode()
    req = urllib.request.Request(
        "https://openrouter.ai/api/v1/embeddings",
        data=payload,
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
    )
    try:
        with urllib.request.urlopen(req, timeout=10) as resp:
            data = json.loads(resp.read())
            return data["data"][0]["embedding"]
    except Exception as e:
        logger.error("Embedding failed: %s", e)
        return None


def _search_qdrant(vector: list[float], limit: int = 10,
                   domain: str | None = None, confidence: str | None = None,
                   exclude: list[str] | None = None) -> list[dict]:
    """Search Qdrant collection for nearest claims.

    Uses urllib for zero-dependency Qdrant access (REST API).
    """
    must_filters = []
    if domain:
        must_filters.append({"key": "domain", "match": {"value": domain}})
    if confidence:
        must_filters.append({"key": "confidence", "match": {"value": confidence}})

    must_not_filters = []
    if exclude:
        for path in exclude:
            must_not_filters.append({"key": "claim_path", "match": {"value": path}})

    payload = {
        "vector": vector,
        "limit": limit,
        "with_payload": True,
        "score_threshold": 0.3,
    }
    if must_filters or must_not_filters:
        payload["filter"] = {}
        if must_filters:
            payload["filter"]["must"] = must_filters
        if must_not_filters:
            payload["filter"]["must_not"] = must_not_filters

    req = urllib.request.Request(
        f"{QDRANT_URL}/collections/{QDRANT_COLLECTION}/points/search",
        data=json.dumps(payload).encode(),
        headers={"Content-Type": "application/json"},
    )
    try:
        with urllib.request.urlopen(req, timeout=10) as resp:
            data = json.loads(resp.read())
            return data.get("result", [])
    except Exception as e:
        logger.error("Qdrant search failed: %s", e)
        return []
# Moved to lib/search.py — imported at top of file as kb_search, embed_query, search_qdrant


# ─── Usage logging ───────────────────────────────────────────────────────────

@@ -701,7 +626,7 @@ async def handle_api_domains(request):


async def handle_api_search(request):
    """GET /api/search — semantic search over claims via Qdrant.
    """GET /api/search — semantic search over claims via Qdrant + graph expansion.

    Query params:
        q: search query (required)
@@ -709,6 +634,7 @@ async def handle_api_search(request):
        confidence: filter by confidence level (optional)
        limit: max results, default 10 (optional)
        exclude: comma-separated claim paths to exclude (optional)
        expand: enable graph expansion, default true (optional)
    """
    query = request.query.get("q", "").strip()
    if not query:
@@ -719,36 +645,19 @@
    limit = min(int(request.query.get("limit", "10")), 50)
    exclude_raw = request.query.get("exclude", "")
    exclude = [p.strip() for p in exclude_raw.split(",") if p.strip()] if exclude_raw else None
    expand = request.query.get("expand", "true").lower() != "false"

    # Embed the query
    api_key = _get_embedding_key()
    if not api_key:
        return web.json_response({"error": "embedding service unavailable"}, status=503)
    # Use shared search library (Layer 1 + Layer 2)
    result = kb_search(query, expand=expand,
                       domain=domain, confidence=confidence, exclude=exclude)

    vector = _embed_query(query, api_key)
    if vector is None:
        return web.json_response({"error": "embedding failed"}, status=502)
    if "error" in result:
        error = result["error"]
        if error == "embedding_failed":
            return web.json_response({"error": "embedding failed"}, status=502)
        return web.json_response({"error": error}, status=500)

    # Search Qdrant
    results = _search_qdrant(vector, limit=limit, domain=domain,
                             confidence=confidence, exclude=exclude)

    # Format response
    claims = []
    for hit in results:
        payload = hit.get("payload", {})
        claims.append({
            "claim_title": payload.get("claim_title", ""),
            "claim_path": payload.get("claim_path", ""),
            "similarity_score": round(hit.get("score", 0), 4),
            "domain": payload.get("domain", ""),
            "confidence": payload.get("confidence", ""),
            "snippet": payload.get("snippet", "")[:200],
            "depends_on": payload.get("depends_on", []),
            "challenged_by": payload.get("challenged_by", []),
        })

    return web.json_response(claims)
    return web.json_response(result)


async def handle_api_usage(request):

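Usage sketch for the reworked endpoint, following the query params documented above: the plain call is pass 1, and `expand=true` opts into pass 2 plus graph expansion. The base URL, auth header, and query text are illustrative assumptions, not part of this diff; the response keys come from lib/search.search().

```python
# Illustrative client only — base URL, auth scheme, and query are assumed, not shipped here.
import json
import urllib.parse
import urllib.request

BASE_URL = "http://localhost:8081/api/search"           # hypothetical deployment
HEADERS = {"Authorization": "Bearer <argus-api-key>"}    # auth scheme assumed

def api_search(query: str, expand: bool = False, **filters) -> dict:
    """Call /api/search; returns {query, direct_results, expanded_results, total}."""
    params = {"q": query, "expand": "true" if expand else "false", **filters}
    req = urllib.request.Request(f"{BASE_URL}?{urllib.parse.urlencode(params)}",
                                 headers=HEADERS)
    with urllib.request.urlopen(req, timeout=15) as resp:
        return json.loads(resp.read())

# Pass 1: top claims above the high threshold, no expansion.
first = api_search("capability scoping for autonomous agents", domain="ai-alignment")
# Pass 2: only if pass 1 didn't answer — lower threshold plus graph-expanded neighbours.
if first["total"] == 0:
    second = api_search("capability scoping for autonomous agents", expand=True)
```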
lib/search.py — new file, 415 lines
@@ -0,0 +1,415 @@
"""Shared Qdrant vector search library for the Teleo knowledge base.
|
||||
|
||||
Provides embed + search + graph expansion as a reusable library.
|
||||
Any consumer (Argus dashboard, Telegram bot, agent research) imports from here.
|
||||
|
||||
Layer 1: Qdrant vector search (semantic similarity)
|
||||
Layer 2: Graph expansion (1-hop via frontmatter edges)
|
||||
Layer 3: Left to the caller (agent context, domain filtering)
|
||||
|
||||
Owner: Epimetheus
|
||||
"""
|
||||
|
||||
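A sketch of how a consumer is expected to compose the three layers named in the docstring: Layers 1 and 2 are library calls, Layer 3 stays in the caller. The query handling, result caps, and the final confidence filter are illustrative choices by the caller, not module defaults.

```python
# Sketch only — caps and the final confidence filter are example caller-side choices.
from lib.search import embed_query, search_qdrant, graph_expand

def layered_lookup(question: str, domain: str | None = None) -> list[dict]:
    # Layer 1: semantic similarity over the Qdrant claim index.
    vector = embed_query(question)
    if vector is None:
        return []
    hits = search_qdrant(vector, limit=5, domain=domain)

    # Layer 2: 1-hop expansion through frontmatter edges of the direct hits.
    seeds = [h["payload"]["claim_path"] for h in hits]
    neighbours = graph_expand(seeds, max_expanded=5)

    # Layer 3 is left to the caller: e.g. a cautious agent drops low-confidence claims.
    direct = [h["payload"] for h in hits]
    return [c for c in direct + neighbours if c.get("confidence") != "low"]
```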
import json
import logging
import os
import re
from pathlib import Path

import urllib.request

from . import config

logger = logging.getLogger("pipeline.search")

# --- Config (all from environment or config.py defaults) ---
QDRANT_URL = os.environ.get("QDRANT_URL", "http://localhost:6333")
QDRANT_COLLECTION = os.environ.get("QDRANT_COLLECTION", "teleo-claims")
EMBEDDING_MODEL = "text-embedding-3-small"

_OPENROUTER_KEY: str | None = None

WIKI_LINK_RE = re.compile(r"\[\[([^\]]+)\]\]")

# Structural files that should never be included in graph expansion results.
# These are indexes/MOCs, not claims — expanding them pulls entire domains.
STRUCTURAL_FILES = {"_map.md", "_overview.md"}

def _get_api_key() -> str | None:
|
||||
"""Load OpenRouter API key (cached after first read)."""
|
||||
global _OPENROUTER_KEY
|
||||
if _OPENROUTER_KEY:
|
||||
return _OPENROUTER_KEY
|
||||
key_file = config.SECRETS_DIR / "openrouter-key"
|
||||
if key_file.exists():
|
||||
_OPENROUTER_KEY = key_file.read_text().strip()
|
||||
return _OPENROUTER_KEY
|
||||
_OPENROUTER_KEY = os.environ.get("OPENROUTER_API_KEY")
|
||||
return _OPENROUTER_KEY
|
||||
|
||||
|
||||
# --- Layer 1: Vector search ---
|
||||
|
||||
|
||||
def embed_query(text: str) -> list[float] | None:
|
||||
"""Embed a query string via OpenRouter (OpenAI-compatible endpoint).
|
||||
|
||||
Returns 1536-dim vector or None on failure.
|
||||
"""
|
||||
api_key = _get_api_key()
|
||||
if not api_key:
|
||||
logger.error("No OpenRouter API key available for embedding")
|
||||
return None
|
||||
|
||||
payload = json.dumps({
|
||||
"model": f"openai/{EMBEDDING_MODEL}",
|
||||
"input": text[:8000],
|
||||
}).encode()
|
||||
req = urllib.request.Request(
|
||||
"https://openrouter.ai/api/v1/embeddings",
|
||||
data=payload,
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=15) as resp:
|
||||
data = json.loads(resp.read())
|
||||
return data["data"][0]["embedding"]
|
||||
except Exception as e:
|
||||
logger.error("Embedding failed: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
def search_qdrant(vector: list[float], limit: int = 10,
|
||||
domain: str | None = None, confidence: str | None = None,
|
||||
exclude: list[str] | None = None,
|
||||
score_threshold: float = 0.3,
|
||||
offset: int = 0) -> list[dict]:
|
||||
"""Search Qdrant collection for nearest claims.
|
||||
|
||||
Args:
|
||||
offset: Skip first N results (Qdrant native offset for pagination).
|
||||
|
||||
Returns list of hits: [{id, score, payload: {claim_path, claim_title, ...}}]
|
||||
"""
|
||||
must_filters = []
|
||||
if domain:
|
||||
must_filters.append({"key": "domain", "match": {"value": domain}})
|
||||
if confidence:
|
||||
must_filters.append({"key": "confidence", "match": {"value": confidence}})
|
||||
|
||||
must_not_filters = []
|
||||
if exclude:
|
||||
for path in exclude:
|
||||
must_not_filters.append({"key": "claim_path", "match": {"value": path}})
|
||||
|
||||
body = {
|
||||
"vector": vector,
|
||||
"limit": limit,
|
||||
"with_payload": True,
|
||||
"score_threshold": score_threshold,
|
||||
}
|
||||
if offset > 0:
|
||||
body["offset"] = offset
|
||||
if must_filters or must_not_filters:
|
||||
body["filter"] = {}
|
||||
if must_filters:
|
||||
body["filter"]["must"] = must_filters
|
||||
if must_not_filters:
|
||||
body["filter"]["must_not"] = must_not_filters
|
||||
|
||||
req = urllib.request.Request(
|
||||
f"{QDRANT_URL}/collections/{QDRANT_COLLECTION}/points/search",
|
||||
data=json.dumps(body).encode(),
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=10) as resp:
|
||||
data = json.loads(resp.read())
|
||||
return data.get("result", [])
|
||||
except Exception as e:
|
||||
logger.error("Qdrant search failed: %s", e)
|
||||
return []
|
||||
|
||||
|
||||
# --- Layer 2: Graph expansion ---
|
||||
|
||||
|
||||
def _parse_frontmatter_edges(path: Path) -> dict:
|
||||
"""Extract relationship edges from a claim's frontmatter.
|
||||
|
||||
Handles both YAML formats:
|
||||
depends_on: ["item1", "item2"] (inline list)
|
||||
depends_on: (multi-line list)
|
||||
- item1
|
||||
- item2
|
||||
|
||||
Returns {supports: [...], challenges: [...], depends_on: [...], related: [...], wiki_links: [...]}.
|
||||
wiki_links are separated from explicit related edges for differential weighting.
|
||||
"""
|
||||
edges = {"supports": [], "challenges": [], "depends_on": [], "related": [], "wiki_links": []}
|
||||
try:
|
||||
text = path.read_text(errors="replace")
|
||||
except Exception:
|
||||
return edges
|
||||
|
||||
if not text.startswith("---"):
|
||||
return edges
|
||||
end = text.find("\n---", 3)
|
||||
if end == -1:
|
||||
return edges
|
||||
|
||||
fm_text = text[3:end]
|
||||
|
||||
# Use YAML parser for reliable edge extraction
|
||||
try:
|
||||
import yaml
|
||||
fm = yaml.safe_load(fm_text)
|
||||
if isinstance(fm, dict):
|
||||
for field in ("supports", "challenges", "depends_on", "related"):
|
||||
val = fm.get(field)
|
||||
if isinstance(val, list):
|
||||
edges[field] = [str(v).strip() for v in val if v]
|
||||
elif isinstance(val, str) and val.strip():
|
||||
edges[field] = [val.strip()]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Extract wiki links from body as separate edge type (lower weight)
|
||||
body = text[end + 4:]
|
||||
all_explicit = set()
|
||||
for field in ("supports", "challenges", "depends_on", "related"):
|
||||
all_explicit.update(edges[field])
|
||||
|
||||
wiki_links = WIKI_LINK_RE.findall(body)
|
||||
for link in wiki_links:
|
||||
link = link.strip()
|
||||
if link and link not in all_explicit and link not in edges["wiki_links"]:
|
||||
edges["wiki_links"].append(link)
|
||||
|
||||
return edges
|
||||
|
||||
|
||||
def _resolve_claim_path(name: str, repo_root: Path) -> Path | None:
|
||||
"""Resolve a claim name (from frontmatter edge or wiki link) to a file path.
|
||||
|
||||
Handles both naming conventions:
|
||||
- "GLP-1 receptor agonists are..." → "GLP-1 receptor agonists are....md" (spaces)
|
||||
- "glp-1-persistence-drops..." → "glp-1-persistence-drops....md" (slugified)
|
||||
|
||||
Checks domains/, core/, foundations/, decisions/ subdirectories.
|
||||
"""
|
||||
# Try exact name first (spaces in filename), then slugified
|
||||
candidates = [name]
|
||||
slug = name.lower().replace(" ", "-").replace("_", "-")
|
||||
if slug != name:
|
||||
candidates.append(slug)
|
||||
|
||||
for subdir in ["domains", "core", "foundations", "decisions"]:
|
||||
base = repo_root / subdir
|
||||
if not base.is_dir():
|
||||
continue
|
||||
for candidate_name in candidates:
|
||||
for md in base.rglob(f"{candidate_name}.md"):
|
||||
return md
|
||||
return None
|
||||
|
||||
|
||||
def graph_expand(seed_paths: list[str], repo_root: Path | None = None,
|
||||
max_expanded: int = 30,
|
||||
challenge_weight: float = 1.5,
|
||||
seen: set[str] | None = None) -> list[dict]:
|
||||
"""Layer 2: Expand seed claims 1-hop through knowledge graph edges.
|
||||
|
||||
Traverses supports/challenges/depends_on/related/wiki_links edges in frontmatter.
|
||||
Edge weights: challenges 1.5x, depends_on 1.25x, supports/related 1.0x, wiki_links 0.5x.
|
||||
Results sorted by weight descending so cap cuts low-value edges first.
|
||||
|
||||
Args:
|
||||
seen: Optional set of paths already matched (e.g. from keyword search) to exclude.
|
||||
|
||||
Returns list of {claim_path, claim_title, edge_type, edge_weight, from_claim}.
|
||||
Excludes claims already in seed_paths or seen set.
|
||||
"""
|
||||
EDGE_WEIGHTS = {
|
||||
"challenges": 1.5,
|
||||
"depends_on": 1.25,
|
||||
"supports": 1.0,
|
||||
"related": 1.0,
|
||||
"wiki_links": 0.5,
|
||||
}
|
||||
|
||||
root = repo_root or config.MAIN_WORKTREE
|
||||
all_expanded = []
|
||||
visited = set(seed_paths)
|
||||
if seen:
|
||||
visited.update(seen)
|
||||
|
||||
for seed_path in seed_paths:
|
||||
full_path = root / seed_path
|
||||
if not full_path.exists():
|
||||
continue
|
||||
|
||||
edges = _parse_frontmatter_edges(full_path)
|
||||
|
||||
for edge_type, targets in edges.items():
|
||||
weight = EDGE_WEIGHTS.get(edge_type, 1.0)
|
||||
|
||||
for target_name in targets:
|
||||
target_path = _resolve_claim_path(target_name, root)
|
||||
if target_path is None:
|
||||
continue
|
||||
|
||||
rel_path = str(target_path.relative_to(root))
|
||||
if rel_path in visited:
|
||||
continue
|
||||
# Skip structural files (MOCs/indexes) — they pull entire domains
|
||||
if target_path.name in STRUCTURAL_FILES:
|
||||
continue
|
||||
visited.add(rel_path)
|
||||
|
||||
# Read title from frontmatter
|
||||
title = target_name
|
||||
try:
|
||||
text = target_path.read_text(errors="replace")
|
||||
if text.startswith("---"):
|
||||
end = text.find("\n---", 3)
|
||||
if end > 0:
|
||||
import yaml
|
||||
fm = yaml.safe_load(text[3:end])
|
||||
if isinstance(fm, dict):
|
||||
title = fm.get("name", fm.get("title", target_name))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
all_expanded.append({
|
||||
"claim_path": rel_path,
|
||||
"claim_title": str(title),
|
||||
"edge_type": edge_type,
|
||||
"edge_weight": weight,
|
||||
"from_claim": seed_path,
|
||||
})
|
||||
|
||||
# Sort by weight descending so cap cuts lowest-value edges first
|
||||
all_expanded.sort(key=lambda x: x["edge_weight"], reverse=True)
|
||||
return all_expanded[:max_expanded]
|
||||
|
||||
|
||||
# --- Combined search (Layer 1 + Layer 2) ---
|
||||
|
||||
# Default thresholds — calibrated with Leo's retrieval audits
|
||||
PASS1_LIMIT = 5
|
||||
PASS1_THRESHOLD = 0.70
|
||||
PASS2_LIMIT = 5
|
||||
PASS2_THRESHOLD = 0.60
|
||||
HARD_CAP = 10
|
||||
|
||||
|
||||
def _dedup_hits(hits: list[dict], seen: set[str]) -> list[dict]:
|
||||
"""Filter Qdrant hits: dedup by claim_path, exclude structural files."""
|
||||
results = []
|
||||
for hit in hits:
|
||||
payload = hit.get("payload", {})
|
||||
claim_path = payload.get("claim_path", "")
|
||||
if claim_path in seen:
|
||||
continue
|
||||
if claim_path.split("/")[-1] in STRUCTURAL_FILES:
|
||||
continue
|
||||
seen.add(claim_path)
|
||||
results.append({
|
||||
"claim_title": payload.get("claim_title", ""),
|
||||
"claim_path": claim_path,
|
||||
"score": round(hit.get("score", 0), 4),
|
||||
"domain": payload.get("domain", ""),
|
||||
"confidence": payload.get("confidence", ""),
|
||||
"snippet": payload.get("snippet", "")[:200],
|
||||
"type": payload.get("type", "claim"),
|
||||
})
|
||||
return results
|
||||
|
||||
|
||||
def _sort_results(direct: list[dict], expanded: list[dict]) -> list[dict]:
|
||||
"""Sort combined results: similarity desc → challenged_by → other expansion.
|
||||
|
||||
Sort order is load-bearing: LLMs have primacy bias, so best claims first.
|
||||
"""
|
||||
# Direct results already sorted by Qdrant (cosine desc)
|
||||
sorted_direct = sorted(direct, key=lambda x: x.get("score", 0), reverse=True)
|
||||
|
||||
# Expansion: challenged_by first (counterpoints), then rest by weight
|
||||
challenged = [e for e in expanded if e.get("edge_type") == "challenges"]
|
||||
other_expanded = [e for e in expanded if e.get("edge_type") != "challenges"]
|
||||
challenged.sort(key=lambda x: x.get("edge_weight", 0), reverse=True)
|
||||
other_expanded.sort(key=lambda x: x.get("edge_weight", 0), reverse=True)
|
||||
|
||||
return sorted_direct + challenged + other_expanded
|
||||
|
||||
|
||||
def search(query: str, expand: bool = False,
|
||||
domain: str | None = None, confidence: str | None = None,
|
||||
exclude: list[str] | None = None) -> dict:
|
||||
"""Two-pass semantic search: embed query, search Qdrant, optionally expand.
|
||||
|
||||
Pass 1 (expand=False, default): Top 5 claims from Qdrant, score >= 0.70.
|
||||
Sufficient for ~80% of queries. Fast and focused.
|
||||
|
||||
Pass 2 (expand=True): Next 5 claims (offset=5, score >= 0.60) plus
|
||||
graph-expanded claims (challenged_by, related edges). Hard cap 10 total.
|
||||
Agent calls this only when pass 1 didn't answer the question.
|
||||
|
||||
Returns {
|
||||
"query": str,
|
||||
"direct_results": [...], # Layer 1 Qdrant hits (sorted by score desc)
|
||||
"expanded_results": [...], # Layer 2 graph expansion (challenges first)
|
||||
"total": int,
|
||||
}
|
||||
"""
|
||||
vector = embed_query(query)
|
||||
if vector is None:
|
||||
return {"query": query, "direct_results": [], "expanded_results": [],
|
||||
"total": 0, "error": "embedding_failed"}
|
||||
|
||||
# --- Pass 1: Top 5, high threshold ---
|
||||
hits = search_qdrant(vector, limit=PASS1_LIMIT, domain=domain,
|
||||
confidence=confidence, exclude=exclude,
|
||||
score_threshold=PASS1_THRESHOLD)
|
||||
|
||||
seen_paths: set[str] = set()
|
||||
if exclude:
|
||||
seen_paths.update(exclude)
|
||||
direct = _dedup_hits(hits, seen_paths)
|
||||
|
||||
expanded = []
|
||||
if expand:
|
||||
# --- Pass 2: Next 5 from Qdrant (lower threshold, offset) ---
|
||||
pass2_hits = search_qdrant(vector, limit=PASS2_LIMIT, domain=domain,
|
||||
confidence=confidence, exclude=exclude,
|
||||
score_threshold=PASS2_THRESHOLD,
|
||||
offset=PASS1_LIMIT)
|
||||
pass2_direct = _dedup_hits(pass2_hits, seen_paths)
|
||||
direct.extend(pass2_direct)
|
||||
|
||||
# Graph expansion on all direct results (pass 1 + pass 2 seeds)
|
||||
seed_paths = [r["claim_path"] for r in direct]
|
||||
remaining_cap = HARD_CAP - len(direct)
|
||||
if remaining_cap > 0:
|
||||
expanded = graph_expand(seed_paths, max_expanded=remaining_cap,
|
||||
seen=seen_paths)
|
||||
|
||||
# Enforce hard cap across all results
|
||||
all_sorted = _sort_results(direct, expanded)[:HARD_CAP]
|
||||
|
||||
# Split back into direct vs expanded for backward compat
|
||||
direct_paths = {r["claim_path"] for r in direct}
|
||||
final_direct = [r for r in all_sorted if r.get("claim_path") in direct_paths]
|
||||
final_expanded = [r for r in all_sorted if r.get("claim_path") not in direct_paths]

    return {
        "query": query,
        "direct_results": final_direct,
        "expanded_results": final_expanded,
        "total": len(all_sorted),
    }

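A sketch of the two-pass contract from the caller's side, as described in the search() docstring: take pass 1 by default and escalate to expand=True only when the first pass looks insufficient. The "insufficient" heuristic below is an example, not part of the library.

```python
# Sketch only — the escalation heuristic is illustrative; agents define their own.
from lib.search import search

def answer_context(question: str, domain: str | None = None) -> dict:
    result = search(question, domain=domain)          # pass 1: top 5, score >= 0.70
    if result.get("error"):
        return result
    scores = [r["score"] for r in result["direct_results"]]
    if len(scores) < 3 or (scores and max(scores) < 0.75):
        # pass 2: next 5 at >= 0.60 plus graph expansion, hard cap 10 total
        result = search(question, expand=True, domain=domain)
    return result
```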
@@ -621,3 +621,107 @@ def format_context_for_prompt(ctx: KBContext) -> str:
|
|||
f"{ctx.stats.get('claims_matched', 0)} claims.")
|
||||
|
||||
return "\n".join(sections)
|
||||
|
||||
|
||||
# --- Qdrant vector search integration ---
|
||||
|
||||
# Module-level import guard for lib.search (Fix 3: no per-call sys.path manipulation)
|
||||
_vector_search = None
|
||||
try:
|
||||
import sys as _sys
|
||||
import os as _os
|
||||
_pipeline_root = _os.path.dirname(_os.path.dirname(_os.path.abspath(__file__)))
|
||||
if _pipeline_root not in _sys.path:
|
||||
_sys.path.insert(0, _pipeline_root)
|
||||
from lib.search import search as _vector_search
|
||||
except ImportError:
|
||||
logger.warning("Qdrant search unavailable at module load (lib.search not found)")
|
||||
|
||||
|
||||
def retrieve_vector_context(query: str,
|
||||
keyword_paths: list[str] | None = None) -> tuple[str, dict]:
|
||||
"""Semantic search via Qdrant — returns (formatted_text, metadata).
|
||||
|
||||
Complements retrieve_context() (symbolic/keyword) with semantic similarity.
|
||||
Falls back gracefully if Qdrant is unavailable.
|
||||
|
||||
Args:
|
||||
keyword_paths: Claim paths already matched by keyword search. These are
|
||||
excluded at the Qdrant query level AND from graph expansion to avoid
|
||||
duplicates in the prompt.
|
||||
|
||||
Returns:
|
||||
(formatted_text, metadata_dict)
|
||||
metadata_dict: {direct_results: [...], expanded_results: [...],
|
||||
layers_hit: [...], duration_ms: int}
|
||||
"""
|
||||
import time as _time
|
||||
t0 = _time.monotonic()
|
||||
empty_meta = {"direct_results": [], "expanded_results": [],
|
||||
"layers_hit": [], "duration_ms": 0}
|
||||
|
||||
if _vector_search is None:
|
||||
return "", empty_meta
|
||||
|
||||
try:
|
||||
results = _vector_search(query, expand=True,
|
||||
exclude=keyword_paths)
|
||||
except Exception as e:
|
||||
logger.warning("Qdrant search failed: %s", e)
|
||||
return "", empty_meta
|
||||
|
||||
duration = int((_time.monotonic() - t0) * 1000)
|
||||
|
||||
if results.get("error") or not results.get("direct_results"):
|
||||
return "", {**empty_meta, "duration_ms": duration,
|
||||
"error": results.get("error")}
|
||||
|
||||
layers_hit = ["qdrant"]
|
||||
if results.get("expanded_results"):
|
||||
layers_hit.append("graph")
|
||||
|
||||
# Build structured metadata for audit
|
||||
meta = {
|
||||
"direct_results": [
|
||||
{"path": r["claim_path"], "title": r["claim_title"],
|
||||
"score": r["score"], "domain": r.get("domain", ""),
|
||||
"source": "qdrant"}
|
||||
for r in results["direct_results"]
|
||||
],
|
||||
"expanded_results": [
|
||||
{"path": r["claim_path"], "title": r["claim_title"],
|
||||
"edge_type": r.get("edge_type", "related"),
|
||||
"from_claim": r.get("from_claim", ""), "source": "graph"}
|
||||
for r in results.get("expanded_results", [])
|
||||
],
|
||||
"layers_hit": layers_hit,
|
||||
"duration_ms": duration,
|
||||
}
|
||||
|
||||
# Build formatted text for prompt (Fix 4: subsection headers)
|
||||
sections = []
|
||||
sections.append("## Semantic Search Results (Qdrant)")
|
||||
sections.append("")
|
||||
sections.append("### Direct matches")
|
||||
|
||||
for r in results["direct_results"]:
|
||||
score_pct = int(r["score"] * 100)
|
||||
line = f"- **{r['claim_title']}** ({score_pct}% match"
|
||||
if r.get("domain"):
|
||||
line += f", {r['domain']}"
|
||||
if r.get("confidence"):
|
||||
line += f", {r['confidence']}"
|
||||
line += ")"
|
||||
sections.append(line)
|
||||
if r.get("snippet"):
|
||||
sections.append(f" {r['snippet']}")

    if results.get("expanded_results"):
        sections.append("")
        sections.append("### Related claims (graph expansion)")
        for r in results["expanded_results"]:
            edge = r.get("edge_type", "related")
            weight_str = f" ×{r.get('edge_weight', 1.0)}" if r.get("edge_weight", 1.0) != 1.0 else ""
            sections.append(f"- {r['claim_title']} ({edge}{weight_str} → {r.get('from_claim', '').split('/')[-1]})")

    return "\n".join(sections), meta

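A sketch of how the prompt-assembly side might use the new helper alongside the existing keyword retrieval: keyword-matched paths go in as the exclusion list, and an empty string means Qdrant was unreachable or returned nothing. The wrapper name and the way keyword paths are obtained upstream are assumptions for illustration.

```python
# Sketch only — how a caller might append the vector layer to an existing prompt.
def append_vector_context(prompt_sections: list[str], query: str,
                          keyword_paths: list[str]) -> list[str]:
    # keyword_paths come from the existing symbolic/keyword retrieval (assumed upstream).
    vector_text, meta = retrieve_vector_context(query, keyword_paths=keyword_paths)
    if vector_text:  # empty string: Qdrant unavailable or nothing above threshold
        prompt_sections.append(vector_text)
    logger.info("vector retrieval: layers=%s duration=%dms",
                meta["layers_hit"], meta["duration_ms"])
    return prompt_sections
```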
tests/test_search.py — new file, 604 lines
@@ -0,0 +1,604 @@
"""Tests for lib/search.py — vector search and graph expansion."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from lib.search import (
|
||||
_parse_frontmatter_edges,
|
||||
_resolve_claim_path,
|
||||
graph_expand,
|
||||
search,
|
||||
search_qdrant,
|
||||
WIKI_LINK_RE,
|
||||
)
|
||||
|
||||
|
||||
# ─── Fixtures ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def repo(tmp_path):
|
||||
"""Minimal KB repo structure with claim files."""
|
||||
domains = tmp_path / "domains"
|
||||
|
||||
# ai-alignment domain
|
||||
ai = domains / "ai-alignment"
|
||||
ai.mkdir(parents=True)
|
||||
|
||||
(ai / "capability-scoping.md").write_text(
|
||||
"---\nname: AI agent capability scoping\ntype: claim\n"
|
||||
"supports:\n - capability-matched escalation\n"
|
||||
"related:\n - multi-agent coordination\n"
|
||||
"---\n\nBody text with [[wiki-linked claim]].\n"
|
||||
)
|
||||
(ai / "capability-matched-escalation.md").write_text(
|
||||
"---\nname: capability-matched escalation\ntype: claim\n"
|
||||
"challenges:\n - unconstrained autonomy\n"
|
||||
"---\n\nEscalation body.\n"
|
||||
)
|
||||
(ai / "multi-agent-coordination.md").write_text(
|
||||
"---\nname: multi-agent coordination\ntype: claim\n---\n\nCoord body.\n"
|
||||
)
|
||||
(ai / "wiki-linked-claim.md").write_text(
|
||||
"---\nname: wiki-linked claim\ntype: claim\n---\n\nWiki body.\n"
|
||||
)
|
||||
(ai / "unconstrained-autonomy.md").write_text(
|
||||
"---\nname: unconstrained autonomy\ntype: claim\n---\n\nAutonomy body.\n"
|
||||
)
|
||||
|
||||
# internet-finance domain
|
||||
fin = domains / "internet-finance"
|
||||
fin.mkdir(parents=True)
|
||||
|
||||
(fin / "futarchy-governance.md").write_text(
|
||||
"---\nname: futarchy governance\ntype: claim\n"
|
||||
"depends_on:\n - prediction market accuracy\n"
|
||||
"---\n\nFutarchy body.\n"
|
||||
)
|
||||
(fin / "prediction-market-accuracy.md").write_text(
|
||||
"---\nname: prediction market accuracy\ntype: claim\n---\n\nPM body.\n"
|
||||
)
|
||||
|
||||
return tmp_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def claim_no_frontmatter(tmp_path):
|
||||
"""Claim file with no frontmatter."""
|
||||
domains = tmp_path / "domains" / "misc"
|
||||
domains.mkdir(parents=True)
|
||||
f = domains / "bare.md"
|
||||
f.write_text("Just a body, no frontmatter.\n")
|
||||
return f
|
||||
|
||||
|
||||
# ─── _parse_frontmatter_edges ─────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestParseFrontmatterEdges:
|
||||
|
||||
def test_supports_and_related(self, repo):
|
||||
path = repo / "domains" / "ai-alignment" / "capability-scoping.md"
|
||||
edges = _parse_frontmatter_edges(path)
|
||||
assert edges["supports"] == ["capability-matched escalation"]
|
||||
assert edges["related"] == ["multi-agent coordination"]
|
||||
|
||||
def test_challenges(self, repo):
|
||||
path = repo / "domains" / "ai-alignment" / "capability-matched-escalation.md"
|
||||
edges = _parse_frontmatter_edges(path)
|
||||
assert edges["challenges"] == ["unconstrained autonomy"]
|
||||
|
||||
def test_depends_on(self, repo):
|
||||
path = repo / "domains" / "internet-finance" / "futarchy-governance.md"
|
||||
edges = _parse_frontmatter_edges(path)
|
||||
assert edges["depends_on"] == ["prediction market accuracy"]
|
||||
|
||||
def test_wiki_links_extracted_separately(self, repo):
|
||||
path = repo / "domains" / "ai-alignment" / "capability-scoping.md"
|
||||
edges = _parse_frontmatter_edges(path)
|
||||
assert "wiki-linked claim" in edges["wiki_links"]
|
||||
# Wiki links should NOT appear in explicit related
|
||||
assert "wiki-linked claim" not in edges["related"]
|
||||
|
||||
def test_wiki_links_deduped_from_explicit(self, tmp_path):
|
||||
"""If a wiki link matches an explicit edge, it's excluded from wiki_links."""
|
||||
d = tmp_path / "domains" / "test"
|
||||
d.mkdir(parents=True)
|
||||
f = d / "overlap.md"
|
||||
f.write_text(
|
||||
"---\nname: overlap test\nrelated:\n - shared target\n---\n\n"
|
||||
"Body with [[shared target]] link.\n"
|
||||
)
|
||||
edges = _parse_frontmatter_edges(f)
|
||||
assert edges["related"] == ["shared target"]
|
||||
assert edges["wiki_links"] == []
|
||||
|
||||
def test_no_frontmatter(self, claim_no_frontmatter):
|
||||
edges = _parse_frontmatter_edges(claim_no_frontmatter)
|
||||
assert all(v == [] for v in edges.values())
|
||||
|
||||
def test_missing_file(self, tmp_path):
|
||||
edges = _parse_frontmatter_edges(tmp_path / "nonexistent.md")
|
||||
assert all(v == [] for v in edges.values())
|
||||
|
||||
def test_inline_list_format(self, tmp_path):
|
||||
"""Handles YAML inline list: depends_on: ["a", "b"]."""
|
||||
d = tmp_path / "domains" / "test"
|
||||
d.mkdir(parents=True)
|
||||
f = d / "inline.md"
|
||||
f.write_text('---\nname: inline\ndepends_on: ["alpha", "beta"]\n---\n\nBody.\n')
|
||||
edges = _parse_frontmatter_edges(f)
|
||||
assert edges["depends_on"] == ["alpha", "beta"]
|
||||
|
||||
|
||||
# ─── _resolve_claim_path ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestResolveClaimPath:
|
||||
|
||||
def test_slugified_name(self, repo):
|
||||
result = _resolve_claim_path("capability-scoping", repo)
|
||||
assert result is not None
|
||||
assert result.name == "capability-scoping.md"
|
||||
|
||||
def test_name_with_spaces(self, repo):
|
||||
"""Resolves 'multi-agent coordination' via slug matching."""
|
||||
result = _resolve_claim_path("multi-agent coordination", repo)
|
||||
assert result is not None
|
||||
assert result.name == "multi-agent-coordination.md"
|
||||
|
||||
def test_not_found(self, repo):
|
||||
result = _resolve_claim_path("nonexistent claim", repo)
|
||||
assert result is None
|
||||
|
||||
def test_cross_domain_resolution(self, repo):
|
||||
"""Can resolve a claim in a different domain subdirectory."""
|
||||
result = _resolve_claim_path("futarchy-governance", repo)
|
||||
assert result is not None
|
||||
assert "internet-finance" in str(result)
|
||||
|
||||
|
||||
# ─── graph_expand ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGraphExpand:
|
||||
|
||||
def test_basic_expansion(self, repo):
|
||||
"""Expanding from capability-scoping should find its edges."""
|
||||
results = graph_expand(
|
||||
["domains/ai-alignment/capability-scoping.md"],
|
||||
repo_root=repo,
|
||||
)
|
||||
titles = [r["claim_title"] for r in results]
|
||||
assert "capability-matched escalation" in titles
|
||||
assert "multi-agent coordination" in titles
|
||||
|
||||
def test_wiki_links_included(self, repo):
|
||||
results = graph_expand(
|
||||
["domains/ai-alignment/capability-scoping.md"],
|
||||
repo_root=repo,
|
||||
)
|
||||
titles = [r["claim_title"] for r in results]
|
||||
assert "wiki-linked claim" in titles
|
||||
|
||||
def test_edge_weights(self, repo):
|
||||
"""challenges edges get 1.5x weight, wiki_links get 0.5x."""
|
||||
# Expand from capability-matched-escalation (has challenges edge)
|
||||
results = graph_expand(
|
||||
["domains/ai-alignment/capability-matched-escalation.md"],
|
||||
repo_root=repo,
|
||||
)
|
||||
challenges_result = [r for r in results if r["edge_type"] == "challenges"]
|
||||
assert len(challenges_result) == 1
|
||||
assert challenges_result[0]["edge_weight"] == 1.5
|
||||
|
||||
def test_wiki_link_weight(self, repo):
|
||||
results = graph_expand(
|
||||
["domains/ai-alignment/capability-scoping.md"],
|
||||
repo_root=repo,
|
||||
)
|
||||
wiki_results = [r for r in results if r["edge_type"] == "wiki_links"]
|
||||
assert all(r["edge_weight"] == 0.5 for r in wiki_results)
|
||||
|
||||
def test_depends_on_weight(self, repo):
|
||||
results = graph_expand(
|
||||
["domains/internet-finance/futarchy-governance.md"],
|
||||
repo_root=repo,
|
||||
)
|
||||
dep_results = [r for r in results if r["edge_type"] == "depends_on"]
|
||||
assert len(dep_results) == 1
|
||||
assert dep_results[0]["edge_weight"] == 1.25
|
||||
|
||||
def test_sorted_by_weight_descending(self, repo):
|
||||
results = graph_expand(
|
||||
["domains/ai-alignment/capability-scoping.md"],
|
||||
repo_root=repo,
|
||||
)
|
||||
weights = [r["edge_weight"] for r in results]
|
||||
assert weights == sorted(weights, reverse=True)
|
||||
|
||||
def test_seed_excluded_from_results(self, repo):
|
||||
seed = "domains/ai-alignment/capability-scoping.md"
|
||||
results = graph_expand([seed], repo_root=repo)
|
||||
paths = [r["claim_path"] for r in results]
|
||||
assert seed not in paths
|
||||
|
||||
def test_seen_set_excludes(self, repo):
|
||||
"""Paths in seen set are excluded from expansion results."""
|
||||
already_matched = {"domains/ai-alignment/multi-agent-coordination.md"}
|
||||
results = graph_expand(
|
||||
["domains/ai-alignment/capability-scoping.md"],
|
||||
repo_root=repo,
|
||||
seen=already_matched,
|
||||
)
|
||||
paths = [r["claim_path"] for r in results]
|
||||
assert "domains/ai-alignment/multi-agent-coordination.md" not in paths
|
||||
|
||||
def test_max_expanded_cap(self, repo):
|
||||
results = graph_expand(
|
||||
["domains/ai-alignment/capability-scoping.md"],
|
||||
repo_root=repo,
|
||||
max_expanded=1,
|
||||
)
|
||||
assert len(results) <= 1
|
||||
|
||||
def test_cap_cuts_lowest_weight(self, repo):
|
||||
"""With cap=2, wiki_links (0.5x) should be cut before supports (1.0x)."""
|
||||
results = graph_expand(
|
||||
["domains/ai-alignment/capability-scoping.md"],
|
||||
repo_root=repo,
|
||||
max_expanded=2,
|
||||
)
|
||||
edge_types = [r["edge_type"] for r in results]
|
||||
assert "wiki_links" not in edge_types
|
||||
|
||||
def test_nonexistent_seed(self, repo):
|
||||
results = graph_expand(["domains/nonexistent.md"], repo_root=repo)
|
||||
assert results == []
|
||||
|
||||
|
||||
# ─── search_qdrant (mocked HTTP) ─────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSearchQdrant:
|
||||
|
||||
def _mock_qdrant_response(self, results):
|
||||
"""Build a mock urllib response returning Qdrant results."""
|
||||
resp = MagicMock()
|
||||
resp.read.return_value = json.dumps({"result": results}).encode()
|
||||
resp.__enter__ = lambda s: s
|
||||
resp.__exit__ = MagicMock(return_value=False)
|
||||
return resp
|
||||
|
||||
@patch("lib.search.urllib.request.urlopen")
|
||||
def test_basic_search(self, mock_urlopen):
|
||||
mock_urlopen.return_value = self._mock_qdrant_response([
|
||||
{"id": 1, "score": 0.85, "payload": {
|
||||
"claim_title": "test claim", "claim_path": "domains/test.md",
|
||||
"domain": "ai", "confidence": "high",
|
||||
}},
|
||||
])
|
||||
results = search_qdrant([0.1] * 1536, limit=5)
|
||||
assert len(results) == 1
|
||||
assert results[0]["score"] == 0.85
|
||||
|
||||
@patch("lib.search.urllib.request.urlopen")
|
||||
def test_domain_filter(self, mock_urlopen):
|
||||
mock_urlopen.return_value = self._mock_qdrant_response([])
|
||||
search_qdrant([0.1] * 1536, domain="ai-alignment")
|
||||
# Verify the request body includes domain filter
|
||||
call_args = mock_urlopen.call_args
|
||||
req = call_args[0][0]
|
||||
body = json.loads(req.data)
|
||||
assert body["filter"]["must"][0]["key"] == "domain"
|
||||
assert body["filter"]["must"][0]["match"]["value"] == "ai-alignment"
|
||||
|
||||
@patch("lib.search.urllib.request.urlopen")
|
||||
def test_exclude_filter(self, mock_urlopen):
|
||||
mock_urlopen.return_value = self._mock_qdrant_response([])
|
||||
search_qdrant([0.1] * 1536, exclude=["domains/a.md", "domains/b.md"])
|
||||
call_args = mock_urlopen.call_args
|
||||
req = call_args[0][0]
|
||||
body = json.loads(req.data)
|
||||
must_not = body["filter"]["must_not"]
|
||||
excluded_paths = [f["match"]["value"] for f in must_not]
|
||||
assert "domains/a.md" in excluded_paths
|
||||
assert "domains/b.md" in excluded_paths
|
||||
|
||||
@patch("lib.search.urllib.request.urlopen")
|
||||
def test_no_filters_no_filter_key(self, mock_urlopen):
|
||||
mock_urlopen.return_value = self._mock_qdrant_response([])
|
||||
search_qdrant([0.1] * 1536)
|
||||
call_args = mock_urlopen.call_args
|
||||
req = call_args[0][0]
|
||||
body = json.loads(req.data)
|
||||
assert "filter" not in body
|
||||
|
||||
@patch("lib.search.urllib.request.urlopen")
|
||||
def test_http_failure_returns_empty(self, mock_urlopen):
|
||||
mock_urlopen.side_effect = Exception("connection refused")
|
||||
results = search_qdrant([0.1] * 1536)
|
||||
assert results == []
|
||||
|
||||
|
||||
# ─── search() integration (mocked network) ───────────────────────────────
|
||||
|
||||
|
||||
class TestSearch:
|
||||
|
||||
@patch("lib.search.embed_query")
|
||||
@patch("lib.search.search_qdrant")
|
||||
def test_embedding_failure(self, mock_qdrant, mock_embed):
|
||||
mock_embed.return_value = None
|
||||
result = search("test query")
|
||||
assert result["error"] == "embedding_failed"
|
||||
assert result["direct_results"] == []
|
||||
mock_qdrant.assert_not_called()
|
||||
|
||||
@patch("lib.search.graph_expand")
|
||||
@patch("lib.search.search_qdrant")
|
||||
@patch("lib.search.embed_query")
|
||||
def test_full_pipeline(self, mock_embed, mock_qdrant, mock_expand):
|
||||
mock_embed.return_value = [0.1] * 1536
|
||||
# Pass 1 returns one hit, pass 2 returns empty (nothing above lower threshold)
|
||||
mock_qdrant.side_effect = [
|
||||
[{"id": 1, "score": 0.82, "payload": {
|
||||
"claim_title": "direct hit", "claim_path": "domains/hit.md",
|
||||
"domain": "ai", "confidence": "high", "snippet": "snippet text",
|
||||
"type": "claim",
|
||||
}}],
|
||||
[], # pass 2 returns nothing
|
||||
]
|
||||
mock_expand.return_value = [
|
||||
{"claim_path": "domains/expanded.md", "claim_title": "expanded",
|
||||
"edge_type": "supports", "edge_weight": 1.0, "from_claim": "domains/hit.md"},
|
||||
]
|
||||
# expand=True triggers two-pass: pass 1 + pass 2 + graph expansion
|
||||
result = search("test query", expand=True)
|
||||
assert len(result["direct_results"]) == 1
|
||||
assert result["direct_results"][0]["claim_title"] == "direct hit"
|
||||
assert len(result["expanded_results"]) == 1
|
||||
assert result["total"] == 2
|
||||
|
||||
@patch("lib.search.graph_expand")
|
||||
@patch("lib.search.search_qdrant")
|
||||
@patch("lib.search.embed_query")
|
||||
def test_exclude_passed_through(self, mock_embed, mock_qdrant, mock_expand):
|
||||
"""Exclude list reaches both Qdrant and graph_expand."""
|
||||
mock_embed.return_value = [0.1] * 1536
|
||||
mock_qdrant.side_effect = [
|
||||
[{"id": 1, "score": 0.8, "payload": {
|
||||
"claim_title": "hit", "claim_path": "domains/hit.md",
|
||||
}}],
|
||||
[], # pass 2
|
||||
]
|
||||
mock_expand.return_value = []
|
||||
exclude = ["domains/already-matched.md"]
|
||||
search("query", expand=True, exclude=exclude)
|
||||
|
||||
# Qdrant should get exclude (called twice with expand=True: pass 1 + pass 2)
|
||||
assert mock_qdrant.call_count == 2
|
||||
# Both calls should include exclude
|
||||
for call in mock_qdrant.call_args_list:
|
||||
assert call.kwargs.get("exclude") == exclude \
|
||||
or call[1].get("exclude") == exclude
|
||||
|
||||
# graph_expand should get seen set containing exclude paths
|
||||
mock_expand.assert_called_once()
|
||||
call_kwargs = mock_expand.call_args[1] if mock_expand.call_args[1] else {}
|
||||
seen = call_kwargs.get("seen")
|
||||
assert seen is not None
|
||||
assert "domains/already-matched.md" in seen
|
||||
|
||||
@patch("lib.search.search_qdrant")
|
||||
@patch("lib.search.embed_query")
|
||||
def test_no_expand_when_disabled(self, mock_embed, mock_qdrant):
|
||||
mock_embed.return_value = [0.1] * 1536
|
||||
mock_qdrant.return_value = []
|
||||
result = search("query", expand=False)
|
||||
assert result["expanded_results"] == []
|
||||
|
||||
|
||||
# ─── WIKI_LINK_RE ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestWikiLinkRegex:
|
||||
|
||||
def test_basic(self):
|
||||
assert WIKI_LINK_RE.findall("See [[some claim]]") == ["some claim"]
|
||||
|
||||
def test_multiple(self):
|
||||
text = "Links to [[claim A]] and [[claim B]]"
|
||||
assert WIKI_LINK_RE.findall(text) == ["claim A", "claim B"]
|
||||
|
||||
def test_no_nested(self):
|
||||
assert WIKI_LINK_RE.findall("[[outer [[inner]]]]") != ["outer [[inner]]"]
|
||||
|
||||
|
||||
# ─── Structural file exclusion ───────────────────────────────────────────
|
||||
|
||||
|
||||
class TestStructuralFileExclusion:
|
||||
|
||||
def test_map_excluded_from_expansion(self, tmp_path):
|
||||
"""_map.md files should be skipped during graph expansion."""
|
||||
domains = tmp_path / "domains" / "test"
|
||||
domains.mkdir(parents=True)
|
||||
(domains / "seed-claim.md").write_text(
|
||||
"---\nname: seed claim\ntype: claim\n"
|
||||
"related:\n - domain map\n---\nBody.\n"
|
||||
)
|
||||
(domains / "_map.md").write_text(
|
||||
"---\nname: domain map\ntype: moc\n---\nDomain index.\n"
|
||||
)
|
||||
seed = "domains/test/seed-claim.md"
|
||||
result = graph_expand([seed], repo_root=tmp_path)
|
||||
paths = [r["claim_path"] for r in result]
|
||||
assert "domains/test/_map.md" not in paths
|
||||
|
||||
def test_overview_excluded_from_expansion(self, tmp_path):
|
||||
"""_overview.md files should be skipped during graph expansion."""
|
||||
domains = tmp_path / "domains" / "test"
|
||||
domains.mkdir(parents=True)
|
||||
(domains / "seed-claim.md").write_text(
|
||||
"---\nname: seed claim\ntype: claim\n"
|
||||
"related:\n - domain overview\n---\nBody.\n"
|
||||
)
|
||||
(domains / "_overview.md").write_text(
|
||||
"---\nname: domain overview\ntype: moc\n---\nOverview.\n"
|
||||
)
|
||||
seed = "domains/test/seed-claim.md"
|
||||
result = graph_expand([seed], repo_root=tmp_path)
|
||||
paths = [r["claim_path"] for r in result]
|
||||
assert "domains/test/_overview.md" not in paths
|
||||
|
||||
def test_regular_claims_still_expand(self, repo):
|
||||
"""Non-structural files should still expand normally."""
|
||||
seed = "domains/ai-alignment/capability-scoping.md"
|
||||
result = graph_expand([seed], repo_root=repo)
|
||||
paths = [r["claim_path"] for r in result]
|
||||
assert len(paths) > 0
|
||||
assert "domains/ai-alignment/capability-matched-escalation.md" in paths
|
||||
|
||||
|
||||
# ─── Dedup within vector results ─────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSearchDedup:
|
||||
|
||||
@patch("lib.search.search_qdrant")
|
||||
@patch("lib.search.embed_query")
|
||||
def test_duplicate_paths_deduped(self, mock_embed, mock_qdrant):
|
||||
"""Duplicate claim_paths in Qdrant results should be collapsed."""
|
||||
mock_embed.return_value = [0.1] * 1536
|
||||
mock_qdrant.return_value = [
|
||||
{"score": 0.9, "payload": {"claim_title": "Claim A", "claim_path": "domains/x/a.md",
|
||||
"domain": "x", "confidence": "high", "snippet": "..."}},
|
||||
{"score": 0.85, "payload": {"claim_title": "Claim A dupe", "claim_path": "domains/x/a.md",
|
||||
"domain": "x", "confidence": "high", "snippet": "..."}},
|
||||
{"score": 0.8, "payload": {"claim_title": "Claim B", "claim_path": "domains/x/b.md",
|
||||
"domain": "x", "confidence": "high", "snippet": "..."}},
|
||||
]
|
||||
result = search("test query", expand=False)
|
||||
paths = [r["claim_path"] for r in result["direct_results"]]
|
||||
assert paths == ["domains/x/a.md", "domains/x/b.md"]
|
||||
assert result["direct_results"][0]["claim_title"] == "Claim A"
|
||||
|
||||
@patch("lib.search.search_qdrant")
|
||||
@patch("lib.search.embed_query")
|
||||
def test_structural_files_excluded_from_direct(self, mock_embed, mock_qdrant):
|
||||
"""_map.md should be excluded from direct Qdrant results too."""
|
||||
mock_embed.return_value = [0.1] * 1536
|
||||
mock_qdrant.return_value = [
|
||||
{"score": 0.9, "payload": {"claim_title": "Domain Map", "claim_path": "domains/x/_map.md",
|
||||
"domain": "x", "confidence": "", "snippet": ""}},
|
||||
{"score": 0.85, "payload": {"claim_title": "Real Claim", "claim_path": "domains/x/real.md",
|
||||
"domain": "x", "confidence": "high", "snippet": "..."}},
|
||||
]
|
||||
result = search("test query", expand=False)
|
||||
paths = [r["claim_path"] for r in result["direct_results"]]
|
||||
assert "domains/x/_map.md" not in paths
|
||||
assert "domains/x/real.md" in paths
|
||||
|
||||
|
||||
# ─── Two-pass retrieval ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestTwoPassRetrieval:
|
||||
|
||||
@patch("lib.search.graph_expand")
|
||||
@patch("lib.search.search_qdrant")
|
||||
@patch("lib.search.embed_query")
|
||||
def test_pass1_only_default(self, mock_embed, mock_qdrant, mock_expand):
|
||||
"""Default search (expand=False) only calls Qdrant once with high threshold."""
|
||||
mock_embed.return_value = [0.1] * 1536
|
||||
mock_qdrant.return_value = [
|
||||
{"score": 0.85, "payload": {"claim_title": "Hit", "claim_path": "d/a.md"}},
|
||||
]
|
||||
result = search("query")
|
||||
mock_qdrant.assert_called_once()
|
||||
# Should use PASS1_THRESHOLD (0.70)
|
||||
call_kwargs = mock_qdrant.call_args
|
||||
assert call_kwargs.kwargs.get("score_threshold") == 0.70 \
|
||||
or call_kwargs[1].get("score_threshold") == 0.70
|
||||
mock_expand.assert_not_called()
|
||||
assert len(result["direct_results"]) == 1
|
||||
|
||||
@patch("lib.search.graph_expand")
|
||||
@patch("lib.search.search_qdrant")
|
||||
@patch("lib.search.embed_query")
|
||||
def test_pass2_expands(self, mock_embed, mock_qdrant, mock_expand):
|
||||
"""expand=True calls Qdrant twice (pass 1 + pass 2) and runs graph expansion."""
|
||||
mock_embed.return_value = [0.1] * 1536
|
||||
mock_qdrant.side_effect = [
|
||||
[{"score": 0.85, "payload": {"claim_title": "P1", "claim_path": "d/a.md"}}],
|
||||
[{"score": 0.65, "payload": {"claim_title": "P2", "claim_path": "d/b.md"}}],
|
||||
]
|
||||
mock_expand.return_value = []
|
||||
result = search("query", expand=True)
|
||||
assert mock_qdrant.call_count == 2
|
||||
# Pass 2 should use offset=5 and lower threshold
|
||||
pass2_call = mock_qdrant.call_args_list[1]
|
||||
assert pass2_call.kwargs.get("offset") == 5 \
|
||||
or pass2_call[1].get("offset") == 5
|
||||
assert len(result["direct_results"]) == 2
|
||||
|
||||
@patch("lib.search.graph_expand")
|
||||
@patch("lib.search.search_qdrant")
|
||||
@patch("lib.search.embed_query")
|
||||
def test_hard_cap_enforced(self, mock_embed, mock_qdrant, mock_expand):
|
||||
"""Total results never exceed HARD_CAP (10)."""
|
||||
mock_embed.return_value = [0.1] * 1536
|
||||
# 5 from pass 1, 5 from pass 2
|
||||
p1 = [{"score": 0.9 - i * 0.02, "payload": {
|
||||
"claim_title": f"P1-{i}", "claim_path": f"d/p1-{i}.md"
|
||||
}} for i in range(5)]
|
||||
p2 = [{"score": 0.65 - i * 0.01, "payload": {
|
||||
"claim_title": f"P2-{i}", "claim_path": f"d/p2-{i}.md"
|
||||
}} for i in range(5)]
|
||||
mock_qdrant.side_effect = [p1, p2]
|
||||
# Graph expansion returns 5 more
|
||||
mock_expand.return_value = [
|
||||
{"claim_path": f"d/exp-{i}.md", "claim_title": f"Exp-{i}",
|
||||
"edge_type": "related", "edge_weight": 1.0, "from_claim": "d/p1-0.md"}
|
||||
for i in range(5)
|
||||
]
|
||||
result = search("query", expand=True)
|
||||
assert result["total"] <= 10
|
||||
|
||||
@patch("lib.search.search_qdrant")
|
||||
@patch("lib.search.embed_query")
|
||||
def test_sort_order_similarity_first(self, mock_embed, mock_qdrant):
|
||||
"""Direct results are sorted by cosine similarity descending."""
|
||||
mock_embed.return_value = [0.1] * 1536
|
||||
mock_qdrant.return_value = [
|
||||
{"score": 0.75, "payload": {"claim_title": "Low", "claim_path": "d/low.md"}},
|
||||
{"score": 0.95, "payload": {"claim_title": "High", "claim_path": "d/high.md"}},
|
||||
{"score": 0.85, "payload": {"claim_title": "Mid", "claim_path": "d/mid.md"}},
|
||||
]
|
||||
result = search("query")
|
||||
titles = [r["claim_title"] for r in result["direct_results"]]
|
||||
assert titles == ["High", "Mid", "Low"]
|
||||
|
||||
@patch("lib.search.graph_expand")
|
||||
@patch("lib.search.search_qdrant")
|
||||
@patch("lib.search.embed_query")
|
||||
def test_challenges_before_other_expansion(self, mock_embed, mock_qdrant, mock_expand):
|
||||
"""challenged_by claims appear before other expanded claims."""
|
||||
mock_embed.return_value = [0.1] * 1536
|
||||
mock_qdrant.side_effect = [
|
||||
[{"score": 0.85, "payload": {"claim_title": "Seed", "claim_path": "d/seed.md"}}],
|
||||
[],
|
||||
]
|
||||
mock_expand.return_value = [
|
||||
{"claim_path": "d/related.md", "claim_title": "Related",
|
||||
"edge_type": "related", "edge_weight": 1.0, "from_claim": "d/seed.md"},
|
||||
{"claim_path": "d/challenge.md", "claim_title": "Challenge",
|
||||
"edge_type": "challenges", "edge_weight": 1.5, "from_claim": "d/seed.md"},
|
||||
]
|
||||
result = search("query", expand=True)
|
||||
expanded_titles = [r["claim_title"] for r in result["expanded_results"]]
|
||||
assert expanded_titles.index("Challenge") < expanded_titles.index("Related")
|
||||