"""Tests for lib/search.py — vector search and graph expansion.""" import json from pathlib import Path from unittest.mock import patch, MagicMock import pytest from lib.search import ( _parse_frontmatter_edges, _resolve_claim_path, graph_expand, search, search_qdrant, WIKI_LINK_RE, ) # ─── Fixtures ────────────────────────────────────────────────────────────── @pytest.fixture def repo(tmp_path): """Minimal KB repo structure with claim files.""" domains = tmp_path / "domains" # ai-alignment domain ai = domains / "ai-alignment" ai.mkdir(parents=True) (ai / "capability-scoping.md").write_text( "---\nname: AI agent capability scoping\ntype: claim\n" "supports:\n - capability-matched escalation\n" "related:\n - multi-agent coordination\n" "---\n\nBody text with [[wiki-linked claim]].\n" ) (ai / "capability-matched-escalation.md").write_text( "---\nname: capability-matched escalation\ntype: claim\n" "challenges:\n - unconstrained autonomy\n" "---\n\nEscalation body.\n" ) (ai / "multi-agent-coordination.md").write_text( "---\nname: multi-agent coordination\ntype: claim\n---\n\nCoord body.\n" ) (ai / "wiki-linked-claim.md").write_text( "---\nname: wiki-linked claim\ntype: claim\n---\n\nWiki body.\n" ) (ai / "unconstrained-autonomy.md").write_text( "---\nname: unconstrained autonomy\ntype: claim\n---\n\nAutonomy body.\n" ) # internet-finance domain fin = domains / "internet-finance" fin.mkdir(parents=True) (fin / "futarchy-governance.md").write_text( "---\nname: futarchy governance\ntype: claim\n" "depends_on:\n - prediction market accuracy\n" "---\n\nFutarchy body.\n" ) (fin / "prediction-market-accuracy.md").write_text( "---\nname: prediction market accuracy\ntype: claim\n---\n\nPM body.\n" ) return tmp_path @pytest.fixture def claim_no_frontmatter(tmp_path): """Claim file with no frontmatter.""" domains = tmp_path / "domains" / "misc" domains.mkdir(parents=True) f = domains / "bare.md" f.write_text("Just a body, no frontmatter.\n") return f # ─── _parse_frontmatter_edges ───────────────────────────────────────────── class TestParseFrontmatterEdges: def test_supports_and_related(self, repo): path = repo / "domains" / "ai-alignment" / "capability-scoping.md" edges = _parse_frontmatter_edges(path) assert edges["supports"] == ["capability-matched escalation"] assert edges["related"] == ["multi-agent coordination"] def test_challenges(self, repo): path = repo / "domains" / "ai-alignment" / "capability-matched-escalation.md" edges = _parse_frontmatter_edges(path) assert edges["challenges"] == ["unconstrained autonomy"] def test_depends_on(self, repo): path = repo / "domains" / "internet-finance" / "futarchy-governance.md" edges = _parse_frontmatter_edges(path) assert edges["depends_on"] == ["prediction market accuracy"] def test_wiki_links_extracted_separately(self, repo): path = repo / "domains" / "ai-alignment" / "capability-scoping.md" edges = _parse_frontmatter_edges(path) assert "wiki-linked claim" in edges["wiki_links"] # Wiki links should NOT appear in explicit related assert "wiki-linked claim" not in edges["related"] def test_wiki_links_deduped_from_explicit(self, tmp_path): """If a wiki link matches an explicit edge, it's excluded from wiki_links.""" d = tmp_path / "domains" / "test" d.mkdir(parents=True) f = d / "overlap.md" f.write_text( "---\nname: overlap test\nrelated:\n - shared target\n---\n\n" "Body with [[shared target]] link.\n" ) edges = _parse_frontmatter_edges(f) assert edges["related"] == ["shared target"] assert edges["wiki_links"] == [] def test_no_frontmatter(self, claim_no_frontmatter): edges = _parse_frontmatter_edges(claim_no_frontmatter) assert all(v == [] for v in edges.values()) def test_missing_file(self, tmp_path): edges = _parse_frontmatter_edges(tmp_path / "nonexistent.md") assert all(v == [] for v in edges.values()) def test_inline_list_format(self, tmp_path): """Handles YAML inline list: depends_on: ["a", "b"].""" d = tmp_path / "domains" / "test" d.mkdir(parents=True) f = d / "inline.md" f.write_text('---\nname: inline\ndepends_on: ["alpha", "beta"]\n---\n\nBody.\n') edges = _parse_frontmatter_edges(f) assert edges["depends_on"] == ["alpha", "beta"] # ─── _resolve_claim_path ────────────────────────────────────────────────── class TestResolveClaimPath: def test_slugified_name(self, repo): result = _resolve_claim_path("capability-scoping", repo) assert result is not None assert result.name == "capability-scoping.md" def test_name_with_spaces(self, repo): """Resolves 'multi-agent coordination' via slug matching.""" result = _resolve_claim_path("multi-agent coordination", repo) assert result is not None assert result.name == "multi-agent-coordination.md" def test_not_found(self, repo): result = _resolve_claim_path("nonexistent claim", repo) assert result is None def test_cross_domain_resolution(self, repo): """Can resolve a claim in a different domain subdirectory.""" result = _resolve_claim_path("futarchy-governance", repo) assert result is not None assert "internet-finance" in str(result) # ─── graph_expand ───────────────────────────────────────────────────────── class TestGraphExpand: def test_basic_expansion(self, repo): """Expanding from capability-scoping should find its edges.""" results = graph_expand( ["domains/ai-alignment/capability-scoping.md"], repo_root=repo, ) titles = [r["claim_title"] for r in results] assert "capability-matched escalation" in titles assert "multi-agent coordination" in titles def test_wiki_links_included(self, repo): results = graph_expand( ["domains/ai-alignment/capability-scoping.md"], repo_root=repo, ) titles = [r["claim_title"] for r in results] assert "wiki-linked claim" in titles def test_edge_weights(self, repo): """challenges edges get 1.5x weight, wiki_links get 0.5x.""" # Expand from capability-matched-escalation (has challenges edge) results = graph_expand( ["domains/ai-alignment/capability-matched-escalation.md"], repo_root=repo, ) challenges_result = [r for r in results if r["edge_type"] == "challenges"] assert len(challenges_result) == 1 assert challenges_result[0]["edge_weight"] == 1.5 def test_wiki_link_weight(self, repo): results = graph_expand( ["domains/ai-alignment/capability-scoping.md"], repo_root=repo, ) wiki_results = [r for r in results if r["edge_type"] == "wiki_links"] assert all(r["edge_weight"] == 0.5 for r in wiki_results) def test_depends_on_weight(self, repo): results = graph_expand( ["domains/internet-finance/futarchy-governance.md"], repo_root=repo, ) dep_results = [r for r in results if r["edge_type"] == "depends_on"] assert len(dep_results) == 1 assert dep_results[0]["edge_weight"] == 1.25 def test_sorted_by_weight_descending(self, repo): results = graph_expand( ["domains/ai-alignment/capability-scoping.md"], repo_root=repo, ) weights = [r["edge_weight"] for r in results] assert weights == sorted(weights, reverse=True) def test_seed_excluded_from_results(self, repo): seed = "domains/ai-alignment/capability-scoping.md" results = graph_expand([seed], repo_root=repo) paths = [r["claim_path"] for r in results] assert seed not in paths def test_seen_set_excludes(self, repo): """Paths in seen set are excluded from expansion results.""" already_matched = {"domains/ai-alignment/multi-agent-coordination.md"} results = graph_expand( ["domains/ai-alignment/capability-scoping.md"], repo_root=repo, seen=already_matched, ) paths = [r["claim_path"] for r in results] assert "domains/ai-alignment/multi-agent-coordination.md" not in paths def test_max_expanded_cap(self, repo): results = graph_expand( ["domains/ai-alignment/capability-scoping.md"], repo_root=repo, max_expanded=1, ) assert len(results) <= 1 def test_cap_cuts_lowest_weight(self, repo): """With cap=2, wiki_links (0.5x) should be cut before supports (1.0x).""" results = graph_expand( ["domains/ai-alignment/capability-scoping.md"], repo_root=repo, max_expanded=2, ) edge_types = [r["edge_type"] for r in results] assert "wiki_links" not in edge_types def test_nonexistent_seed(self, repo): results = graph_expand(["domains/nonexistent.md"], repo_root=repo) assert results == [] # ─── search_qdrant (mocked HTTP) ───────────────────────────────────────── class TestSearchQdrant: def _mock_qdrant_response(self, results): """Build a mock urllib response returning Qdrant results.""" resp = MagicMock() resp.read.return_value = json.dumps({"result": results}).encode() resp.__enter__ = lambda s: s resp.__exit__ = MagicMock(return_value=False) return resp @patch("lib.search.urllib.request.urlopen") def test_basic_search(self, mock_urlopen): mock_urlopen.return_value = self._mock_qdrant_response([ {"id": 1, "score": 0.85, "payload": { "claim_title": "test claim", "claim_path": "domains/test.md", "domain": "ai", "confidence": "high", }}, ]) results = search_qdrant([0.1] * 1536, limit=5) assert len(results) == 1 assert results[0]["score"] == 0.85 @patch("lib.search.urllib.request.urlopen") def test_domain_filter(self, mock_urlopen): mock_urlopen.return_value = self._mock_qdrant_response([]) search_qdrant([0.1] * 1536, domain="ai-alignment") # Verify the request body includes domain filter call_args = mock_urlopen.call_args req = call_args[0][0] body = json.loads(req.data) assert body["filter"]["must"][0]["key"] == "domain" assert body["filter"]["must"][0]["match"]["value"] == "ai-alignment" @patch("lib.search.urllib.request.urlopen") def test_exclude_filter(self, mock_urlopen): mock_urlopen.return_value = self._mock_qdrant_response([]) search_qdrant([0.1] * 1536, exclude=["domains/a.md", "domains/b.md"]) call_args = mock_urlopen.call_args req = call_args[0][0] body = json.loads(req.data) must_not = body["filter"]["must_not"] excluded_paths = [f["match"]["value"] for f in must_not] assert "domains/a.md" in excluded_paths assert "domains/b.md" in excluded_paths @patch("lib.search.urllib.request.urlopen") def test_no_filters_no_filter_key(self, mock_urlopen): mock_urlopen.return_value = self._mock_qdrant_response([]) search_qdrant([0.1] * 1536) call_args = mock_urlopen.call_args req = call_args[0][0] body = json.loads(req.data) assert "filter" not in body @patch("lib.search.urllib.request.urlopen") def test_http_failure_returns_empty(self, mock_urlopen): mock_urlopen.side_effect = Exception("connection refused") results = search_qdrant([0.1] * 1536) assert results == [] # ─── search() integration (mocked network) ─────────────────────────────── class TestSearch: @patch("lib.search.embed_query") @patch("lib.search.search_qdrant") def test_embedding_failure(self, mock_qdrant, mock_embed): mock_embed.return_value = None result = search("test query") assert result["error"] == "embedding_failed" assert result["direct_results"] == [] mock_qdrant.assert_not_called() @patch("lib.search.graph_expand") @patch("lib.search.search_qdrant") @patch("lib.search.embed_query") def test_full_pipeline(self, mock_embed, mock_qdrant, mock_expand): mock_embed.return_value = [0.1] * 1536 # Pass 1 returns one hit, pass 2 returns empty (nothing above lower threshold) mock_qdrant.side_effect = [ [{"id": 1, "score": 0.82, "payload": { "claim_title": "direct hit", "claim_path": "domains/hit.md", "domain": "ai", "confidence": "high", "snippet": "snippet text", "type": "claim", }}], [], # pass 2 returns nothing ] mock_expand.return_value = [ {"claim_path": "domains/expanded.md", "claim_title": "expanded", "edge_type": "supports", "edge_weight": 1.0, "from_claim": "domains/hit.md"}, ] # expand=True triggers two-pass: pass 1 + pass 2 + graph expansion result = search("test query", expand=True) assert len(result["direct_results"]) == 1 assert result["direct_results"][0]["claim_title"] == "direct hit" assert len(result["expanded_results"]) == 1 assert result["total"] == 2 @patch("lib.search.graph_expand") @patch("lib.search.search_qdrant") @patch("lib.search.embed_query") def test_exclude_passed_through(self, mock_embed, mock_qdrant, mock_expand): """Exclude list reaches both Qdrant and graph_expand.""" mock_embed.return_value = [0.1] * 1536 mock_qdrant.side_effect = [ [{"id": 1, "score": 0.8, "payload": { "claim_title": "hit", "claim_path": "domains/hit.md", }}], [], # pass 2 ] mock_expand.return_value = [] exclude = ["domains/already-matched.md"] search("query", expand=True, exclude=exclude) # Qdrant should get exclude (called twice with expand=True: pass 1 + pass 2) assert mock_qdrant.call_count == 2 # Both calls should include exclude for call in mock_qdrant.call_args_list: assert call.kwargs.get("exclude") == exclude \ or call[1].get("exclude") == exclude # graph_expand should get seen set containing exclude paths mock_expand.assert_called_once() call_kwargs = mock_expand.call_args[1] if mock_expand.call_args[1] else {} seen = call_kwargs.get("seen") assert seen is not None assert "domains/already-matched.md" in seen @patch("lib.search.search_qdrant") @patch("lib.search.embed_query") def test_no_expand_when_disabled(self, mock_embed, mock_qdrant): mock_embed.return_value = [0.1] * 1536 mock_qdrant.return_value = [] result = search("query", expand=False) assert result["expanded_results"] == [] # ─── WIKI_LINK_RE ───────────────────────────────────────────────────────── class TestWikiLinkRegex: def test_basic(self): assert WIKI_LINK_RE.findall("See [[some claim]]") == ["some claim"] def test_multiple(self): text = "Links to [[claim A]] and [[claim B]]" assert WIKI_LINK_RE.findall(text) == ["claim A", "claim B"] def test_no_nested(self): assert WIKI_LINK_RE.findall("[[outer [[inner]]]]") != ["outer [[inner]]"] # ─── Structural file exclusion ─────────────────────────────────────────── class TestStructuralFileExclusion: def test_map_excluded_from_expansion(self, tmp_path): """_map.md files should be skipped during graph expansion.""" domains = tmp_path / "domains" / "test" domains.mkdir(parents=True) (domains / "seed-claim.md").write_text( "---\nname: seed claim\ntype: claim\n" "related:\n - domain map\n---\nBody.\n" ) (domains / "_map.md").write_text( "---\nname: domain map\ntype: moc\n---\nDomain index.\n" ) seed = "domains/test/seed-claim.md" result = graph_expand([seed], repo_root=tmp_path) paths = [r["claim_path"] for r in result] assert "domains/test/_map.md" not in paths def test_overview_excluded_from_expansion(self, tmp_path): """_overview.md files should be skipped during graph expansion.""" domains = tmp_path / "domains" / "test" domains.mkdir(parents=True) (domains / "seed-claim.md").write_text( "---\nname: seed claim\ntype: claim\n" "related:\n - domain overview\n---\nBody.\n" ) (domains / "_overview.md").write_text( "---\nname: domain overview\ntype: moc\n---\nOverview.\n" ) seed = "domains/test/seed-claim.md" result = graph_expand([seed], repo_root=tmp_path) paths = [r["claim_path"] for r in result] assert "domains/test/_overview.md" not in paths def test_regular_claims_still_expand(self, repo): """Non-structural files should still expand normally.""" seed = "domains/ai-alignment/capability-scoping.md" result = graph_expand([seed], repo_root=repo) paths = [r["claim_path"] for r in result] assert len(paths) > 0 assert "domains/ai-alignment/capability-matched-escalation.md" in paths # ─── Dedup within vector results ───────────────────────────────────────── class TestSearchDedup: @patch("lib.search.search_qdrant") @patch("lib.search.embed_query") def test_duplicate_paths_deduped(self, mock_embed, mock_qdrant): """Duplicate claim_paths in Qdrant results should be collapsed.""" mock_embed.return_value = [0.1] * 1536 mock_qdrant.return_value = [ {"score": 0.9, "payload": {"claim_title": "Claim A", "claim_path": "domains/x/a.md", "domain": "x", "confidence": "high", "snippet": "..."}}, {"score": 0.85, "payload": {"claim_title": "Claim A dupe", "claim_path": "domains/x/a.md", "domain": "x", "confidence": "high", "snippet": "..."}}, {"score": 0.8, "payload": {"claim_title": "Claim B", "claim_path": "domains/x/b.md", "domain": "x", "confidence": "high", "snippet": "..."}}, ] result = search("test query", expand=False) paths = [r["claim_path"] for r in result["direct_results"]] assert paths == ["domains/x/a.md", "domains/x/b.md"] assert result["direct_results"][0]["claim_title"] == "Claim A" @patch("lib.search.search_qdrant") @patch("lib.search.embed_query") def test_structural_files_excluded_from_direct(self, mock_embed, mock_qdrant): """_map.md should be excluded from direct Qdrant results too.""" mock_embed.return_value = [0.1] * 1536 mock_qdrant.return_value = [ {"score": 0.9, "payload": {"claim_title": "Domain Map", "claim_path": "domains/x/_map.md", "domain": "x", "confidence": "", "snippet": ""}}, {"score": 0.85, "payload": {"claim_title": "Real Claim", "claim_path": "domains/x/real.md", "domain": "x", "confidence": "high", "snippet": "..."}}, ] result = search("test query", expand=False) paths = [r["claim_path"] for r in result["direct_results"]] assert "domains/x/_map.md" not in paths assert "domains/x/real.md" in paths # ─── Two-pass retrieval ────────────────────────────────────────────────── class TestTwoPassRetrieval: @patch("lib.search.graph_expand") @patch("lib.search.search_qdrant") @patch("lib.search.embed_query") def test_pass1_only_default(self, mock_embed, mock_qdrant, mock_expand): """Default search (expand=False) only calls Qdrant once with high threshold.""" mock_embed.return_value = [0.1] * 1536 mock_qdrant.return_value = [ {"score": 0.85, "payload": {"claim_title": "Hit", "claim_path": "d/a.md"}}, ] result = search("query") mock_qdrant.assert_called_once() # Should use PASS1_THRESHOLD (0.70) call_kwargs = mock_qdrant.call_args assert call_kwargs.kwargs.get("score_threshold") == 0.70 \ or call_kwargs[1].get("score_threshold") == 0.70 mock_expand.assert_not_called() assert len(result["direct_results"]) == 1 @patch("lib.search.graph_expand") @patch("lib.search.search_qdrant") @patch("lib.search.embed_query") def test_pass2_expands(self, mock_embed, mock_qdrant, mock_expand): """expand=True calls Qdrant twice (pass 1 + pass 2) and runs graph expansion.""" mock_embed.return_value = [0.1] * 1536 mock_qdrant.side_effect = [ [{"score": 0.85, "payload": {"claim_title": "P1", "claim_path": "d/a.md"}}], [{"score": 0.65, "payload": {"claim_title": "P2", "claim_path": "d/b.md"}}], ] mock_expand.return_value = [] result = search("query", expand=True) assert mock_qdrant.call_count == 2 # Pass 2 should use offset=5 and lower threshold pass2_call = mock_qdrant.call_args_list[1] assert pass2_call.kwargs.get("offset") == 5 \ or pass2_call[1].get("offset") == 5 assert len(result["direct_results"]) == 2 @patch("lib.search.graph_expand") @patch("lib.search.search_qdrant") @patch("lib.search.embed_query") def test_hard_cap_enforced(self, mock_embed, mock_qdrant, mock_expand): """Total results never exceed HARD_CAP (10).""" mock_embed.return_value = [0.1] * 1536 # 5 from pass 1, 5 from pass 2 p1 = [{"score": 0.9 - i * 0.02, "payload": { "claim_title": f"P1-{i}", "claim_path": f"d/p1-{i}.md" }} for i in range(5)] p2 = [{"score": 0.65 - i * 0.01, "payload": { "claim_title": f"P2-{i}", "claim_path": f"d/p2-{i}.md" }} for i in range(5)] mock_qdrant.side_effect = [p1, p2] # Graph expansion returns 5 more mock_expand.return_value = [ {"claim_path": f"d/exp-{i}.md", "claim_title": f"Exp-{i}", "edge_type": "related", "edge_weight": 1.0, "from_claim": "d/p1-0.md"} for i in range(5) ] result = search("query", expand=True) assert result["total"] <= 10 @patch("lib.search.search_qdrant") @patch("lib.search.embed_query") def test_sort_order_similarity_first(self, mock_embed, mock_qdrant): """Direct results are sorted by cosine similarity descending.""" mock_embed.return_value = [0.1] * 1536 mock_qdrant.return_value = [ {"score": 0.75, "payload": {"claim_title": "Low", "claim_path": "d/low.md"}}, {"score": 0.95, "payload": {"claim_title": "High", "claim_path": "d/high.md"}}, {"score": 0.85, "payload": {"claim_title": "Mid", "claim_path": "d/mid.md"}}, ] result = search("query") titles = [r["claim_title"] for r in result["direct_results"]] assert titles == ["High", "Mid", "Low"] @patch("lib.search.graph_expand") @patch("lib.search.search_qdrant") @patch("lib.search.embed_query") def test_challenges_before_other_expansion(self, mock_embed, mock_qdrant, mock_expand): """challenged_by claims appear before other expanded claims.""" mock_embed.return_value = [0.1] * 1536 mock_qdrant.side_effect = [ [{"score": 0.85, "payload": {"claim_title": "Seed", "claim_path": "d/seed.md"}}], [], ] mock_expand.return_value = [ {"claim_path": "d/related.md", "claim_title": "Related", "edge_type": "related", "edge_weight": 1.0, "from_claim": "d/seed.md"}, {"claim_path": "d/challenge.md", "claim_title": "Challenge", "edge_type": "challenges", "edge_weight": 1.5, "from_claim": "d/seed.md"}, ] result = search("query", expand=True) expanded_titles = [r["claim_title"] for r in result["expanded_results"]] assert expanded_titles.index("Challenge") < expanded_titles.index("Related")