"""Tests for lib/pr_state.py — centralized PR state transitions.""" import asyncio import sqlite3 import sys import os from unittest.mock import AsyncMock, MagicMock, patch import pytest # Mock heavy dependencies before importing pr_state sys.modules.setdefault("aiohttp", MagicMock()) # Add lib parent to path so `lib.pr_state` resolves sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) from lib.pr_state import ( approve_pr, close_pr, mark_conflict, mark_conflict_permanent, mark_merged, reopen_pr, reset_for_reeval, start_fixing, start_review, ) # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- def _make_db(): """Create a minimal in-memory DB with the prs table.""" conn = sqlite3.connect(":memory:") conn.row_factory = sqlite3.Row conn.execute(""" CREATE TABLE prs ( number INTEGER PRIMARY KEY, source_path TEXT, branch TEXT, status TEXT NOT NULL DEFAULT 'open', domain TEXT, auto_merge INTEGER DEFAULT 0, leo_verdict TEXT DEFAULT 'pending', domain_verdict TEXT DEFAULT 'pending', eval_attempts INTEGER DEFAULT 0, eval_issues TEXT DEFAULT '[]', merge_cycled INTEGER DEFAULT 0, merge_failures INTEGER DEFAULT 0, conflict_rebase_attempts INTEGER DEFAULT 0, last_error TEXT, merged_at TEXT, last_attempt TEXT, tier0_pass INTEGER, fix_attempts INTEGER DEFAULT 0, cost_usd REAL DEFAULT 0, created_at TEXT DEFAULT (datetime('now')) ) """) return conn def _insert_pr(conn, number=100, status="open", domain=None, **kwargs): """Insert a test PR row.""" cols = ["number", "status"] vals = [number, status] if domain is not None: cols.append("domain") vals.append(domain) for k, v in kwargs.items(): cols.append(k) vals.append(v) placeholders = ", ".join(["?"] * len(vals)) col_names = ", ".join(cols) conn.execute(f"INSERT INTO prs ({col_names}) VALUES ({placeholders})", vals) conn.commit() def _get_pr(conn, number=100): """Read a PR row back.""" return conn.execute("SELECT * FROM prs WHERE number = ?", (number,)).fetchone() # --------------------------------------------------------------------------- # close_pr # --------------------------------------------------------------------------- class TestClosePr: def test_close_calls_forgejo_and_updates_db(self): conn = _make_db() _insert_pr(conn, 42) mock_api = AsyncMock(return_value={}) with patch("lib.pr_state.forgejo_api", mock_api), \ patch("lib.pr_state.repo_path", lambda s: f"/repos/test/{s}"): asyncio.run(close_pr(conn, 42, last_error="test close")) row = _get_pr(conn, 42) assert row["status"] == "closed" assert row["last_error"] == "test close" mock_api.assert_called_once() def test_close_skips_forgejo_when_opted_out(self): conn = _make_db() _insert_pr(conn, 42) mock_api = AsyncMock(return_value={}) with patch("lib.pr_state.forgejo_api", mock_api), \ patch("lib.pr_state.repo_path", lambda s: f"/repos/test/{s}"): asyncio.run(close_pr(conn, 42, last_error="reconciled", close_on_forgejo=False)) row = _get_pr(conn, 42) assert row["status"] == "closed" mock_api.assert_not_called() def test_close_increments_merge_failures(self): conn = _make_db() _insert_pr(conn, 42, merge_failures=2) mock_api = AsyncMock(return_value={}) with patch("lib.pr_state.forgejo_api", mock_api), \ patch("lib.pr_state.repo_path", lambda s: f"/repos/test/{s}"): asyncio.run(close_pr(conn, 42, merge_cycled=True, inc_merge_failures=True)) row = _get_pr(conn, 42) assert row["merge_cycled"] == 1 assert row["merge_failures"] == 3 def test_close_without_last_error(self): conn = _make_db() _insert_pr(conn, 42, last_error="old error") mock_api = AsyncMock(return_value={}) with patch("lib.pr_state.forgejo_api", mock_api), \ patch("lib.pr_state.repo_path", lambda s: f"/repos/test/{s}"): result = asyncio.run(close_pr(conn, 42)) assert result is True row = _get_pr(conn, 42) assert row["status"] == "closed" # last_error not overwritten when not provided assert row["last_error"] == "old error" def test_close_returns_false_on_forgejo_failure(self): """Critical: if Forgejo API fails, DB must NOT be updated (ghost PR prevention).""" conn = _make_db() _insert_pr(conn, 42) mock_api = AsyncMock(return_value=None) # forgejo_api returns None on failure with patch("lib.pr_state.forgejo_api", mock_api), \ patch("lib.pr_state.repo_path", lambda s: f"/repos/test/{s}"): result = asyncio.run(close_pr(conn, 42, last_error="should not persist")) assert result is False row = _get_pr(conn, 42) assert row["status"] == "open" # DB not touched assert row["last_error"] is None # last_error not set def test_close_skips_forgejo_check_when_opted_out(self): """close_on_forgejo=False always succeeds (caller already closed Forgejo).""" conn = _make_db() _insert_pr(conn, 42) mock_api = AsyncMock(return_value=None) # would fail if called with patch("lib.pr_state.forgejo_api", mock_api), \ patch("lib.pr_state.repo_path", lambda s: f"/repos/test/{s}"): result = asyncio.run(close_pr(conn, 42, close_on_forgejo=False)) assert result is True row = _get_pr(conn, 42) assert row["status"] == "closed" mock_api.assert_not_called() # --------------------------------------------------------------------------- # approve_pr # --------------------------------------------------------------------------- class TestApprovePr: def test_approve_sets_domain_and_auto_merge(self): conn = _make_db() _insert_pr(conn, 50) approve_pr(conn, 50, domain="internet-finance", auto_merge=1) row = _get_pr(conn, 50) assert row["status"] == "approved" assert row["domain"] == "internet-finance" assert row["auto_merge"] == 1 def test_approve_raises_on_empty_domain(self): conn = _make_db() _insert_pr(conn, 50) with pytest.raises(ValueError, match="without domain"): approve_pr(conn, 50, domain="") def test_approve_raises_on_none_domain(self): conn = _make_db() _insert_pr(conn, 50) with pytest.raises(ValueError, match="without domain"): approve_pr(conn, 50, domain=None) def test_approve_sets_verdicts(self): conn = _make_db() _insert_pr(conn, 50) approve_pr(conn, 50, domain="cross-domain", auto_merge=1, leo_verdict="skipped", domain_verdict="skipped") row = _get_pr(conn, 50) assert row["leo_verdict"] == "skipped" assert row["domain_verdict"] == "skipped" def test_approve_preserves_existing_domain(self): conn = _make_db() _insert_pr(conn, 50, domain="living-agents") approve_pr(conn, 50, domain="cross-domain", auto_merge=0) row = _get_pr(conn, 50) # COALESCE(domain, ?) preserves existing domain assert row["domain"] == "living-agents" # --------------------------------------------------------------------------- # mark_merged # --------------------------------------------------------------------------- class TestMarkMerged: def test_sets_merged_at_and_clears_error(self): conn = _make_db() _insert_pr(conn, 60, status="approved", last_error="some old error") mark_merged(conn, 60) row = _get_pr(conn, 60) assert row["status"] == "merged" assert row["merged_at"] is not None assert row["last_error"] is None # --------------------------------------------------------------------------- # mark_conflict # --------------------------------------------------------------------------- class TestMarkConflict: def test_increments_failures_and_sets_cycled(self): conn = _make_db() _insert_pr(conn, 70, status="approved", merge_failures=0) mark_conflict(conn, 70, last_error="cherry-pick failed") row = _get_pr(conn, 70) assert row["status"] == "conflict" assert row["merge_cycled"] == 1 assert row["merge_failures"] == 1 assert row["last_error"] == "cherry-pick failed" def test_accumulates_failures(self): conn = _make_db() _insert_pr(conn, 70, merge_failures=5) mark_conflict(conn, 70) row = _get_pr(conn, 70) assert row["merge_failures"] == 6 # --------------------------------------------------------------------------- # mark_conflict_permanent # --------------------------------------------------------------------------- class TestMarkConflictPermanent: def test_sets_status_and_attempts(self): conn = _make_db() _insert_pr(conn, 80, status="conflict") mark_conflict_permanent(conn, 80, last_error="rebase failed 3x", conflict_rebase_attempts=3) row = _get_pr(conn, 80) assert row["status"] == "conflict_permanent" assert row["last_error"] == "rebase failed 3x" assert row["conflict_rebase_attempts"] == 3 # --------------------------------------------------------------------------- # reopen_pr # --------------------------------------------------------------------------- class TestReopenPr: def test_simple_reopen(self): conn = _make_db() _insert_pr(conn, 90, status="reviewing") reopen_pr(conn, 90) row = _get_pr(conn, 90) assert row["status"] == "open" def test_reopen_with_rejection(self): conn = _make_db() _insert_pr(conn, 90, status="reviewing") reopen_pr(conn, 90, leo_verdict="skipped", last_error="domain rejected", eval_issues='["factual_error"]') row = _get_pr(conn, 90) assert row["status"] == "open" assert row["leo_verdict"] == "skipped" assert row["last_error"] == "domain rejected" assert row["eval_issues"] == '["factual_error"]' def test_reopen_dec_eval_attempts(self): conn = _make_db() _insert_pr(conn, 90, status="reviewing", eval_attempts=3) reopen_pr(conn, 90, dec_eval_attempts=True) row = _get_pr(conn, 90) assert row["eval_attempts"] == 2 def test_reopen_reset_for_reeval(self): conn = _make_db() _insert_pr(conn, 90, status="conflict", leo_verdict="approve", domain_verdict="approve", eval_attempts=2) reopen_pr(conn, 90, reset_for_reeval=True, conflict_rebase_attempts=1) row = _get_pr(conn, 90) assert row["status"] == "open" assert row["leo_verdict"] == "pending" assert row["domain_verdict"] == "pending" assert row["eval_attempts"] == 0 assert row["conflict_rebase_attempts"] == 1 # --------------------------------------------------------------------------- # start_review # --------------------------------------------------------------------------- class TestStartReview: def test_claims_open_pr(self): conn = _make_db() _insert_pr(conn, 100) assert start_review(conn, 100) is True row = _get_pr(conn, 100) assert row["status"] == "reviewing" def test_rejects_non_open_pr(self): conn = _make_db() _insert_pr(conn, 100, status="reviewing") assert start_review(conn, 100) is False def test_double_claim_fails(self): conn = _make_db() _insert_pr(conn, 100) assert start_review(conn, 100) is True assert start_review(conn, 100) is False # --------------------------------------------------------------------------- # start_fixing # --------------------------------------------------------------------------- class TestStartFixing: def test_claims_open_pr(self): conn = _make_db() _insert_pr(conn, 200) assert start_fixing(conn, 200) is True row = _get_pr(conn, 200) assert row["status"] == "fixing" assert row["fix_attempts"] == 1 def test_increments_fix_attempts(self): conn = _make_db() _insert_pr(conn, 200, fix_attempts=3) assert start_fixing(conn, 200) is True row = _get_pr(conn, 200) assert row["fix_attempts"] == 4 def test_sets_last_attempt(self): conn = _make_db() _insert_pr(conn, 200) start_fixing(conn, 200) row = _get_pr(conn, 200) assert row["last_attempt"] is not None def test_rejects_non_open_pr(self): conn = _make_db() _insert_pr(conn, 200, status="reviewing") assert start_fixing(conn, 200) is False def test_double_claim_fails(self): conn = _make_db() _insert_pr(conn, 200) assert start_fixing(conn, 200) is True assert start_fixing(conn, 200) is False # --------------------------------------------------------------------------- # reset_for_reeval # --------------------------------------------------------------------------- class TestResetForReeval: def test_resets_all_eval_state(self): conn = _make_db() _insert_pr(conn, 300, status="fixing", eval_attempts=3, leo_verdict="request_changes", domain_verdict="approve") conn.execute( "UPDATE prs SET eval_issues = ?, tier0_pass = 1, last_error = 'some error' WHERE number = 300", ('["broken_wiki_links"]',), ) reset_for_reeval(conn, 300) row = _get_pr(conn, 300) assert row["status"] == "open" assert row["eval_attempts"] == 0 assert row["eval_issues"] == "[]" assert row["tier0_pass"] is None assert row["domain_verdict"] == "pending" assert row["leo_verdict"] == "pending" assert row["last_error"] is None def test_preserves_non_eval_fields(self): conn = _make_db() _insert_pr(conn, 300, status="fixing", domain="internet-finance", fix_attempts=2) reset_for_reeval(conn, 300) row = _get_pr(conn, 300) assert row["domain"] == "internet-finance" assert row["fix_attempts"] == 2