teleo-infrastructure/tests/test_search.py
m3taversal e17e6c25db
Some checks failed
CI / lint-and-test (pull_request) Has been cancelled
feat: two-pass retrieval with sort order and graph expansion
lib/search.py — shared search library:
- Pass 1 (default): top 5 from Qdrant, score >= 0.70, no expansion
- Pass 2 (expand=True): next 5 via offset=5, score >= 0.60, plus
  graph expansion from YAML frontmatter edges. Hard cap 10 total.
- Sort order: cosine desc → challenged_by → other graph-expanded
- result_type internal tag for stable sort (direct/challenge/graph)
- Module-level constants for easy threshold tuning post-calibration
- Structural file exclusion (_map.md, _overview.md)
- Within-vector dedup via _dedup_hits()

Caller updates:
- kb_retrieval.py: retrieve_vector_context() calls search(expand=True)
- diagnostics/app.py: search endpoint passes expand query param
- Argus imports from lib/search.py via sys.path (no longer owns search)

Tests: 5 new tests covering pass1-only, pass2 expansion, hard cap,
sort order, challenges-before-other-expansion.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-28 22:34:45 +00:00

604 lines
25 KiB
Python

"""Tests for lib/search.py — vector search and graph expansion."""
import json
from pathlib import Path
from unittest.mock import patch, MagicMock
import pytest
from lib.search import (
_parse_frontmatter_edges,
_resolve_claim_path,
graph_expand,
search,
search_qdrant,
WIKI_LINK_RE,
)
# ─── Fixtures ──────────────────────────────────────────────────────────────
@pytest.fixture
def repo(tmp_path):
    """Minimal KB repo structure with claim files across two domains."""
    # Relative path under domains/ -> full file content (frontmatter + body).
    claim_files = {
        "ai-alignment/capability-scoping.md":
            "---\nname: AI agent capability scoping\ntype: claim\n"
            "supports:\n - capability-matched escalation\n"
            "related:\n - multi-agent coordination\n"
            "---\n\nBody text with [[wiki-linked claim]].\n",
        "ai-alignment/capability-matched-escalation.md":
            "---\nname: capability-matched escalation\ntype: claim\n"
            "challenges:\n - unconstrained autonomy\n"
            "---\n\nEscalation body.\n",
        "ai-alignment/multi-agent-coordination.md":
            "---\nname: multi-agent coordination\ntype: claim\n---\n\nCoord body.\n",
        "ai-alignment/wiki-linked-claim.md":
            "---\nname: wiki-linked claim\ntype: claim\n---\n\nWiki body.\n",
        "ai-alignment/unconstrained-autonomy.md":
            "---\nname: unconstrained autonomy\ntype: claim\n---\n\nAutonomy body.\n",
        "internet-finance/futarchy-governance.md":
            "---\nname: futarchy governance\ntype: claim\n"
            "depends_on:\n - prediction market accuracy\n"
            "---\n\nFutarchy body.\n",
        "internet-finance/prediction-market-accuracy.md":
            "---\nname: prediction market accuracy\ntype: claim\n---\n\nPM body.\n",
    }
    for rel_path, content in claim_files.items():
        target = tmp_path / "domains" / rel_path
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_text(content)
    return tmp_path
@pytest.fixture
def claim_no_frontmatter(tmp_path):
    """Claim file with no frontmatter."""
    misc_dir = tmp_path / "domains" / "misc"
    misc_dir.mkdir(parents=True)
    bare_file = misc_dir / "bare.md"
    bare_file.write_text("Just a body, no frontmatter.\n")
    return bare_file
# ─── _parse_frontmatter_edges ─────────────────────────────────────────────
class TestParseFrontmatterEdges:
    """_parse_frontmatter_edges() extracts typed edge lists from YAML frontmatter."""

    def test_supports_and_related(self, repo):
        claim = repo / "domains" / "ai-alignment" / "capability-scoping.md"
        parsed = _parse_frontmatter_edges(claim)
        assert parsed["supports"] == ["capability-matched escalation"]
        assert parsed["related"] == ["multi-agent coordination"]

    def test_challenges(self, repo):
        claim = repo / "domains" / "ai-alignment" / "capability-matched-escalation.md"
        parsed = _parse_frontmatter_edges(claim)
        assert parsed["challenges"] == ["unconstrained autonomy"]

    def test_depends_on(self, repo):
        claim = repo / "domains" / "internet-finance" / "futarchy-governance.md"
        parsed = _parse_frontmatter_edges(claim)
        assert parsed["depends_on"] == ["prediction market accuracy"]

    def test_wiki_links_extracted_separately(self, repo):
        claim = repo / "domains" / "ai-alignment" / "capability-scoping.md"
        parsed = _parse_frontmatter_edges(claim)
        assert "wiki-linked claim" in parsed["wiki_links"]
        # Wiki links should NOT appear in explicit related
        assert "wiki-linked claim" not in parsed["related"]

    def test_wiki_links_deduped_from_explicit(self, tmp_path):
        """If a wiki link matches an explicit edge, it's excluded from wiki_links."""
        claim_dir = tmp_path / "domains" / "test"
        claim_dir.mkdir(parents=True)
        claim = claim_dir / "overlap.md"
        claim.write_text(
            "---\nname: overlap test\nrelated:\n - shared target\n---\n\n"
            "Body with [[shared target]] link.\n"
        )
        parsed = _parse_frontmatter_edges(claim)
        assert parsed["related"] == ["shared target"]
        assert parsed["wiki_links"] == []

    def test_no_frontmatter(self, claim_no_frontmatter):
        parsed = _parse_frontmatter_edges(claim_no_frontmatter)
        assert all(edge_list == [] for edge_list in parsed.values())

    def test_missing_file(self, tmp_path):
        parsed = _parse_frontmatter_edges(tmp_path / "nonexistent.md")
        assert all(edge_list == [] for edge_list in parsed.values())

    def test_inline_list_format(self, tmp_path):
        """Handles YAML inline list: depends_on: ["a", "b"]."""
        claim_dir = tmp_path / "domains" / "test"
        claim_dir.mkdir(parents=True)
        claim = claim_dir / "inline.md"
        claim.write_text('---\nname: inline\ndepends_on: ["alpha", "beta"]\n---\n\nBody.\n')
        parsed = _parse_frontmatter_edges(claim)
        assert parsed["depends_on"] == ["alpha", "beta"]
# ─── _resolve_claim_path ──────────────────────────────────────────────────
class TestResolveClaimPath:
    """_resolve_claim_path() maps claim names to markdown files under domains/."""

    def test_slugified_name(self, repo):
        resolved = _resolve_claim_path("capability-scoping", repo)
        assert resolved is not None
        assert resolved.name == "capability-scoping.md"

    def test_name_with_spaces(self, repo):
        """Resolves 'multi-agent coordination' via slug matching."""
        resolved = _resolve_claim_path("multi-agent coordination", repo)
        assert resolved is not None
        assert resolved.name == "multi-agent-coordination.md"

    def test_not_found(self, repo):
        assert _resolve_claim_path("nonexistent claim", repo) is None

    def test_cross_domain_resolution(self, repo):
        """Can resolve a claim in a different domain subdirectory."""
        resolved = _resolve_claim_path("futarchy-governance", repo)
        assert resolved is not None
        assert "internet-finance" in str(resolved)
# ─── graph_expand ─────────────────────────────────────────────────────────
class TestGraphExpand:
    """graph_expand() walks frontmatter edges outward from seed claim paths."""

    def test_basic_expansion(self, repo):
        """Expanding from capability-scoping should find its edges."""
        hits = graph_expand(
            ["domains/ai-alignment/capability-scoping.md"], repo_root=repo,
        )
        found_titles = {h["claim_title"] for h in hits}
        assert "capability-matched escalation" in found_titles
        assert "multi-agent coordination" in found_titles

    def test_wiki_links_included(self, repo):
        hits = graph_expand(
            ["domains/ai-alignment/capability-scoping.md"], repo_root=repo,
        )
        assert "wiki-linked claim" in {h["claim_title"] for h in hits}

    def test_edge_weights(self, repo):
        """challenges edges get 1.5x weight, wiki_links get 0.5x."""
        # Expand from capability-matched-escalation (has challenges edge)
        hits = graph_expand(
            ["domains/ai-alignment/capability-matched-escalation.md"],
            repo_root=repo,
        )
        challenge_hits = [h for h in hits if h["edge_type"] == "challenges"]
        assert len(challenge_hits) == 1
        assert challenge_hits[0]["edge_weight"] == 1.5

    def test_wiki_link_weight(self, repo):
        hits = graph_expand(
            ["domains/ai-alignment/capability-scoping.md"], repo_root=repo,
        )
        for hit in hits:
            if hit["edge_type"] == "wiki_links":
                assert hit["edge_weight"] == 0.5

    def test_depends_on_weight(self, repo):
        hits = graph_expand(
            ["domains/internet-finance/futarchy-governance.md"], repo_root=repo,
        )
        dep_hits = [h for h in hits if h["edge_type"] == "depends_on"]
        assert len(dep_hits) == 1
        assert dep_hits[0]["edge_weight"] == 1.25

    def test_sorted_by_weight_descending(self, repo):
        hits = graph_expand(
            ["domains/ai-alignment/capability-scoping.md"], repo_root=repo,
        )
        weights = [h["edge_weight"] for h in hits]
        assert weights == sorted(weights, reverse=True)

    def test_seed_excluded_from_results(self, repo):
        seed = "domains/ai-alignment/capability-scoping.md"
        hits = graph_expand([seed], repo_root=repo)
        assert seed not in [h["claim_path"] for h in hits]

    def test_seen_set_excludes(self, repo):
        """Paths in seen set are excluded from expansion results."""
        already_matched = {"domains/ai-alignment/multi-agent-coordination.md"}
        hits = graph_expand(
            ["domains/ai-alignment/capability-scoping.md"],
            repo_root=repo,
            seen=already_matched,
        )
        returned_paths = [h["claim_path"] for h in hits]
        assert "domains/ai-alignment/multi-agent-coordination.md" not in returned_paths

    def test_max_expanded_cap(self, repo):
        hits = graph_expand(
            ["domains/ai-alignment/capability-scoping.md"],
            repo_root=repo,
            max_expanded=1,
        )
        assert len(hits) <= 1

    def test_cap_cuts_lowest_weight(self, repo):
        """With cap=2, wiki_links (0.5x) should be cut before supports (1.0x)."""
        hits = graph_expand(
            ["domains/ai-alignment/capability-scoping.md"],
            repo_root=repo,
            max_expanded=2,
        )
        assert "wiki_links" not in [h["edge_type"] for h in hits]

    def test_nonexistent_seed(self, repo):
        assert graph_expand(["domains/nonexistent.md"], repo_root=repo) == []
# ─── search_qdrant (mocked HTTP) ─────────────────────────────────────────
class TestSearchQdrant:
    """search_qdrant() request construction and error handling (mocked HTTP)."""

    def _mock_qdrant_response(self, results):
        """Build a mock urllib response returning Qdrant results."""
        mock_resp = MagicMock()
        mock_resp.read.return_value = json.dumps({"result": results}).encode()
        mock_resp.__enter__ = lambda s: s
        mock_resp.__exit__ = MagicMock(return_value=False)
        return mock_resp

    @staticmethod
    def _request_body(mock_urlopen):
        """Decode the JSON body of the Request object passed to urlopen."""
        request = mock_urlopen.call_args[0][0]
        return json.loads(request.data)

    @patch("lib.search.urllib.request.urlopen")
    def test_basic_search(self, mock_urlopen):
        payload = {
            "claim_title": "test claim", "claim_path": "domains/test.md",
            "domain": "ai", "confidence": "high",
        }
        mock_urlopen.return_value = self._mock_qdrant_response(
            [{"id": 1, "score": 0.85, "payload": payload}]
        )
        hits = search_qdrant([0.1] * 1536, limit=5)
        assert len(hits) == 1
        assert hits[0]["score"] == 0.85

    @patch("lib.search.urllib.request.urlopen")
    def test_domain_filter(self, mock_urlopen):
        mock_urlopen.return_value = self._mock_qdrant_response([])
        search_qdrant([0.1] * 1536, domain="ai-alignment")
        # Verify the request body includes domain filter
        domain_clause = self._request_body(mock_urlopen)["filter"]["must"][0]
        assert domain_clause["key"] == "domain"
        assert domain_clause["match"]["value"] == "ai-alignment"

    @patch("lib.search.urllib.request.urlopen")
    def test_exclude_filter(self, mock_urlopen):
        mock_urlopen.return_value = self._mock_qdrant_response([])
        search_qdrant([0.1] * 1536, exclude=["domains/a.md", "domains/b.md"])
        must_not = self._request_body(mock_urlopen)["filter"]["must_not"]
        excluded_paths = [clause["match"]["value"] for clause in must_not]
        assert "domains/a.md" in excluded_paths
        assert "domains/b.md" in excluded_paths

    @patch("lib.search.urllib.request.urlopen")
    def test_no_filters_no_filter_key(self, mock_urlopen):
        mock_urlopen.return_value = self._mock_qdrant_response([])
        search_qdrant([0.1] * 1536)
        assert "filter" not in self._request_body(mock_urlopen)

    @patch("lib.search.urllib.request.urlopen")
    def test_http_failure_returns_empty(self, mock_urlopen):
        mock_urlopen.side_effect = Exception("connection refused")
        assert search_qdrant([0.1] * 1536) == []
# ─── search() integration (mocked network) ───────────────────────────────
class TestSearch:
    """End-to-end search() behavior with embedding, Qdrant, and expansion mocked.

    Note: @patch decorators apply bottom-up, so the bottom-most decorator
    maps to the first injected mock argument.
    """

    @patch("lib.search.embed_query")
    @patch("lib.search.search_qdrant")
    def test_embedding_failure(self, mock_qdrant, mock_embed):
        # A failed embedding (None) short-circuits the pipeline: an error
        # payload comes back and Qdrant is never queried.
        mock_embed.return_value = None
        result = search("test query")
        assert result["error"] == "embedding_failed"
        assert result["direct_results"] == []
        mock_qdrant.assert_not_called()

    @patch("lib.search.graph_expand")
    @patch("lib.search.search_qdrant")
    @patch("lib.search.embed_query")
    def test_full_pipeline(self, mock_embed, mock_qdrant, mock_expand):
        mock_embed.return_value = [0.1] * 1536
        # Pass 1 returns one hit, pass 2 returns empty (nothing above lower threshold)
        mock_qdrant.side_effect = [
            [{"id": 1, "score": 0.82, "payload": {
                "claim_title": "direct hit", "claim_path": "domains/hit.md",
                "domain": "ai", "confidence": "high", "snippet": "snippet text",
                "type": "claim",
            }}],
            [],  # pass 2 returns nothing
        ]
        mock_expand.return_value = [
            {"claim_path": "domains/expanded.md", "claim_title": "expanded",
             "edge_type": "supports", "edge_weight": 1.0, "from_claim": "domains/hit.md"},
        ]
        # expand=True triggers two-pass: pass 1 + pass 2 + graph expansion
        result = search("test query", expand=True)
        assert len(result["direct_results"]) == 1
        assert result["direct_results"][0]["claim_title"] == "direct hit"
        assert len(result["expanded_results"]) == 1
        # total counts direct + expanded together.
        assert result["total"] == 2

    @patch("lib.search.graph_expand")
    @patch("lib.search.search_qdrant")
    @patch("lib.search.embed_query")
    def test_exclude_passed_through(self, mock_embed, mock_qdrant, mock_expand):
        """Exclude list reaches both Qdrant and graph_expand."""
        mock_embed.return_value = [0.1] * 1536
        mock_qdrant.side_effect = [
            [{"id": 1, "score": 0.8, "payload": {
                "claim_title": "hit", "claim_path": "domains/hit.md",
            }}],
            [],  # pass 2
        ]
        mock_expand.return_value = []
        exclude = ["domains/already-matched.md"]
        search("query", expand=True, exclude=exclude)
        # Qdrant should get exclude (called twice with expand=True: pass 1 + pass 2)
        assert mock_qdrant.call_count == 2
        # Both calls should include exclude; .kwargs and [1] are equivalent
        # access paths (the [1] form also works on older mock versions).
        for call in mock_qdrant.call_args_list:
            assert call.kwargs.get("exclude") == exclude \
                or call[1].get("exclude") == exclude
        # graph_expand should get seen set containing exclude paths
        mock_expand.assert_called_once()
        call_kwargs = mock_expand.call_args[1] if mock_expand.call_args[1] else {}
        seen = call_kwargs.get("seen")
        assert seen is not None
        assert "domains/already-matched.md" in seen

    @patch("lib.search.search_qdrant")
    @patch("lib.search.embed_query")
    def test_no_expand_when_disabled(self, mock_embed, mock_qdrant):
        # expand=False must leave expanded_results empty.
        mock_embed.return_value = [0.1] * 1536
        mock_qdrant.return_value = []
        result = search("query", expand=False)
        assert result["expanded_results"] == []
# ─── WIKI_LINK_RE ─────────────────────────────────────────────────────────
class TestWikiLinkRegex:
    """WIKI_LINK_RE extracts [[...]] wiki-link targets from body text."""

    def test_basic(self):
        matches = WIKI_LINK_RE.findall("See [[some claim]]")
        assert matches == ["some claim"]

    def test_multiple(self):
        sample = "Links to [[claim A]] and [[claim B]]"
        assert WIKI_LINK_RE.findall(sample) == ["claim A", "claim B"]

    def test_no_nested(self):
        # Nested brackets must not be captured as a single greedy match.
        found = WIKI_LINK_RE.findall("[[outer [[inner]]]]")
        assert found != ["outer [[inner]]"]
# ─── Structural file exclusion ───────────────────────────────────────────
class TestStructuralFileExclusion:
    """Structural files (_map.md, _overview.md) never surface via graph expansion."""

    def test_map_excluded_from_expansion(self, tmp_path):
        """_map.md files should be skipped during graph expansion."""
        test_dir = tmp_path / "domains" / "test"
        test_dir.mkdir(parents=True)
        (test_dir / "seed-claim.md").write_text(
            "---\nname: seed claim\ntype: claim\n"
            "related:\n - domain map\n---\nBody.\n"
        )
        (test_dir / "_map.md").write_text(
            "---\nname: domain map\ntype: moc\n---\nDomain index.\n"
        )
        hits = graph_expand(["domains/test/seed-claim.md"], repo_root=tmp_path)
        assert "domains/test/_map.md" not in [h["claim_path"] for h in hits]

    def test_overview_excluded_from_expansion(self, tmp_path):
        """_overview.md files should be skipped during graph expansion."""
        test_dir = tmp_path / "domains" / "test"
        test_dir.mkdir(parents=True)
        (test_dir / "seed-claim.md").write_text(
            "---\nname: seed claim\ntype: claim\n"
            "related:\n - domain overview\n---\nBody.\n"
        )
        (test_dir / "_overview.md").write_text(
            "---\nname: domain overview\ntype: moc\n---\nOverview.\n"
        )
        hits = graph_expand(["domains/test/seed-claim.md"], repo_root=tmp_path)
        assert "domains/test/_overview.md" not in [h["claim_path"] for h in hits]

    def test_regular_claims_still_expand(self, repo):
        """Non-structural files should still expand normally."""
        hits = graph_expand(
            ["domains/ai-alignment/capability-scoping.md"], repo_root=repo
        )
        expanded_paths = [h["claim_path"] for h in hits]
        assert len(expanded_paths) > 0
        assert "domains/ai-alignment/capability-matched-escalation.md" in expanded_paths
# ─── Dedup within vector results ─────────────────────────────────────────
class TestSearchDedup:
    """Dedup and structural-file filtering applied to direct Qdrant results."""

    @staticmethod
    def _hit(score, title, path, confidence="high", snippet="..."):
        """Shape a fake Qdrant hit dict (domain is fixed to 'x')."""
        return {"score": score, "payload": {
            "claim_title": title, "claim_path": path,
            "domain": "x", "confidence": confidence, "snippet": snippet,
        }}

    @patch("lib.search.search_qdrant")
    @patch("lib.search.embed_query")
    def test_duplicate_paths_deduped(self, mock_embed, mock_qdrant):
        """Duplicate claim_paths in Qdrant results should be collapsed."""
        mock_embed.return_value = [0.1] * 1536
        mock_qdrant.return_value = [
            self._hit(0.9, "Claim A", "domains/x/a.md"),
            self._hit(0.85, "Claim A dupe", "domains/x/a.md"),
            self._hit(0.8, "Claim B", "domains/x/b.md"),
        ]
        result = search("test query", expand=False)
        kept = result["direct_results"]
        assert [r["claim_path"] for r in kept] == ["domains/x/a.md", "domains/x/b.md"]
        # The higher-scoring duplicate wins.
        assert kept[0]["claim_title"] == "Claim A"

    @patch("lib.search.search_qdrant")
    @patch("lib.search.embed_query")
    def test_structural_files_excluded_from_direct(self, mock_embed, mock_qdrant):
        """_map.md should be excluded from direct Qdrant results too."""
        mock_embed.return_value = [0.1] * 1536
        mock_qdrant.return_value = [
            self._hit(0.9, "Domain Map", "domains/x/_map.md", confidence="", snippet=""),
            self._hit(0.85, "Real Claim", "domains/x/real.md"),
        ]
        result = search("test query", expand=False)
        returned_paths = [r["claim_path"] for r in result["direct_results"]]
        assert "domains/x/_map.md" not in returned_paths
        assert "domains/x/real.md" in returned_paths
# ─── Two-pass retrieval ──────────────────────────────────────────────────
class TestTwoPassRetrieval:
    """Two-pass retrieval: pass-1 threshold, pass-2 offset, hard cap, sort order."""

    @patch("lib.search.graph_expand")
    @patch("lib.search.search_qdrant")
    @patch("lib.search.embed_query")
    def test_pass1_only_default(self, mock_embed, mock_qdrant, mock_expand):
        """Default search (expand=False) only calls Qdrant once with high threshold."""
        mock_embed.return_value = [0.1] * 1536
        mock_qdrant.return_value = [
            {"score": 0.85, "payload": {"claim_title": "Hit", "claim_path": "d/a.md"}},
        ]
        result = search("query")
        mock_qdrant.assert_called_once()
        # Should use PASS1_THRESHOLD (0.70)
        call_kwargs = mock_qdrant.call_args
        # .kwargs and [1] are equivalent; the [1] form covers older mock APIs.
        assert call_kwargs.kwargs.get("score_threshold") == 0.70 \
            or call_kwargs[1].get("score_threshold") == 0.70
        mock_expand.assert_not_called()
        assert len(result["direct_results"]) == 1

    @patch("lib.search.graph_expand")
    @patch("lib.search.search_qdrant")
    @patch("lib.search.embed_query")
    def test_pass2_expands(self, mock_embed, mock_qdrant, mock_expand):
        """expand=True calls Qdrant twice (pass 1 + pass 2) and runs graph expansion."""
        mock_embed.return_value = [0.1] * 1536
        mock_qdrant.side_effect = [
            [{"score": 0.85, "payload": {"claim_title": "P1", "claim_path": "d/a.md"}}],
            [{"score": 0.65, "payload": {"claim_title": "P2", "claim_path": "d/b.md"}}],
        ]
        mock_expand.return_value = []
        result = search("query", expand=True)
        assert mock_qdrant.call_count == 2
        # Pass 2 should use offset=5 and lower threshold
        pass2_call = mock_qdrant.call_args_list[1]
        assert pass2_call.kwargs.get("offset") == 5 \
            or pass2_call[1].get("offset") == 5
        assert len(result["direct_results"]) == 2

    @patch("lib.search.graph_expand")
    @patch("lib.search.search_qdrant")
    @patch("lib.search.embed_query")
    def test_hard_cap_enforced(self, mock_embed, mock_qdrant, mock_expand):
        """Total results never exceed HARD_CAP (10)."""
        mock_embed.return_value = [0.1] * 1536
        # 5 from pass 1, 5 from pass 2
        p1 = [{"score": 0.9 - i * 0.02, "payload": {
            "claim_title": f"P1-{i}", "claim_path": f"d/p1-{i}.md"
        }} for i in range(5)]
        p2 = [{"score": 0.65 - i * 0.01, "payload": {
            "claim_title": f"P2-{i}", "claim_path": f"d/p2-{i}.md"
        }} for i in range(5)]
        mock_qdrant.side_effect = [p1, p2]
        # Graph expansion returns 5 more
        mock_expand.return_value = [
            {"claim_path": f"d/exp-{i}.md", "claim_title": f"Exp-{i}",
             "edge_type": "related", "edge_weight": 1.0, "from_claim": "d/p1-0.md"}
            for i in range(5)
        ]
        result = search("query", expand=True)
        # 15 candidates go in; at most 10 may come out.
        assert result["total"] <= 10

    @patch("lib.search.search_qdrant")
    @patch("lib.search.embed_query")
    def test_sort_order_similarity_first(self, mock_embed, mock_qdrant):
        """Direct results are sorted by cosine similarity descending."""
        mock_embed.return_value = [0.1] * 1536
        # Deliberately out of score order to exercise the sort.
        mock_qdrant.return_value = [
            {"score": 0.75, "payload": {"claim_title": "Low", "claim_path": "d/low.md"}},
            {"score": 0.95, "payload": {"claim_title": "High", "claim_path": "d/high.md"}},
            {"score": 0.85, "payload": {"claim_title": "Mid", "claim_path": "d/mid.md"}},
        ]
        result = search("query")
        titles = [r["claim_title"] for r in result["direct_results"]]
        assert titles == ["High", "Mid", "Low"]

    @patch("lib.search.graph_expand")
    @patch("lib.search.search_qdrant")
    @patch("lib.search.embed_query")
    def test_challenges_before_other_expansion(self, mock_embed, mock_qdrant, mock_expand):
        """challenged_by claims appear before other expanded claims."""
        mock_embed.return_value = [0.1] * 1536
        mock_qdrant.side_effect = [
            [{"score": 0.85, "payload": {"claim_title": "Seed", "claim_path": "d/seed.md"}}],
            [],
        ]
        # Non-challenge edge listed first; search() is expected to reorder.
        mock_expand.return_value = [
            {"claim_path": "d/related.md", "claim_title": "Related",
             "edge_type": "related", "edge_weight": 1.0, "from_claim": "d/seed.md"},
            {"claim_path": "d/challenge.md", "claim_title": "Challenge",
             "edge_type": "challenges", "edge_weight": 1.5, "from_claim": "d/seed.md"},
        ]
        result = search("query", expand=True)
        expanded_titles = [r["claim_title"] for r in result["expanded_results"]]
        assert expanded_titles.index("Challenge") < expanded_titles.index("Related")