teleo-infrastructure/tests/test_pr_state.py
m3taversal 1f5eb324f3
Some checks are pending
CI / lint-and-test (push) Waiting to run
refactor: centralize PR state transitions in lib/pr_state.py
Replace 38 hand-crafted UPDATE prs SET status calls across evaluate.py
and merge.py with 7 centralized functions that enforce invariants:
- close_pr: always syncs Forgejo (opt-out for reconciliation)
- approve_pr: raises ValueError on empty domain (prevents NULL bugs)
- mark_merged: always sets merged_at, clears last_error
- mark_conflict: always increments merge_failures, sets merge_cycled
- mark_conflict_permanent: terminal conflict state
- reopen_pr: handles all reopen scenarios (transient, rejection, reeval)
- start_review: atomic claim with bool return

This eliminates the class of bugs that produced 3 incidents:
1. Domain NULL on musings bypass (7 PRs stuck, 20h zero throughput)
2. Forgejo ghost PRs (70 PRs open on Forgejo but closed in DB)
3. merge_cycled not set on various close paths

Also fixes 3 close paths in merge.py that updated the DB before calling
Forgejo (reversed order); close_pr now calls Forgejo first, then updates the DB.

Only remaining raw status transition: _claim_next_pr (approved→merging)
which is an atomic subquery and doesn't have invariant requirements.

20 new tests, 264 total passing, 0 regressions. Net -101 lines in
evaluate.py + merge.py.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-16 12:08:57 +01:00

336 lines
11 KiB
Python

"""Tests for lib/pr_state.py — centralized PR state transitions."""
import asyncio
import sqlite3
import sys
import os
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
# Mock heavy dependencies before importing pr_state
sys.modules.setdefault("aiohttp", MagicMock())
# Add lib parent to path so `lib.pr_state` resolves
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from lib.pr_state import (
approve_pr,
close_pr,
mark_conflict,
mark_conflict_permanent,
mark_merged,
reopen_pr,
start_review,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
def _make_db():
"""Create a minimal in-memory DB with the prs table."""
conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row
conn.execute("""
CREATE TABLE prs (
number INTEGER PRIMARY KEY,
source_path TEXT,
branch TEXT,
status TEXT NOT NULL DEFAULT 'open',
domain TEXT,
auto_merge INTEGER DEFAULT 0,
leo_verdict TEXT DEFAULT 'pending',
domain_verdict TEXT DEFAULT 'pending',
eval_attempts INTEGER DEFAULT 0,
eval_issues TEXT DEFAULT '[]',
merge_cycled INTEGER DEFAULT 0,
merge_failures INTEGER DEFAULT 0,
conflict_rebase_attempts INTEGER DEFAULT 0,
last_error TEXT,
merged_at TEXT,
last_attempt TEXT,
cost_usd REAL DEFAULT 0,
created_at TEXT DEFAULT (datetime('now'))
)
""")
return conn
def _insert_pr(conn, number=100, status="open", domain=None, **kwargs):
"""Insert a test PR row."""
cols = ["number", "status"]
vals = [number, status]
if domain is not None:
cols.append("domain")
vals.append(domain)
for k, v in kwargs.items():
cols.append(k)
vals.append(v)
placeholders = ", ".join(["?"] * len(vals))
col_names = ", ".join(cols)
conn.execute(f"INSERT INTO prs ({col_names}) VALUES ({placeholders})", vals)
conn.commit()
def _get_pr(conn, number=100):
"""Read a PR row back."""
return conn.execute("SELECT * FROM prs WHERE number = ?", (number,)).fetchone()
# ---------------------------------------------------------------------------
# close_pr
# ---------------------------------------------------------------------------
class TestClosePr:
    """Tests for close_pr: closing a PR with optional Forgejo sync and failure bookkeeping."""

    @staticmethod
    def _close_with_mocks(conn, number, mock_api, **kwargs):
        """Run ``close_pr`` to completion with Forgejo access patched out.

        ``lib.pr_state.forgejo_api`` is replaced by *mock_api*, and
        ``lib.pr_state.repo_path`` is replaced by a deterministic stub.
        Extra keyword arguments are forwarded to ``close_pr`` unchanged.
        """
        with patch("lib.pr_state.forgejo_api", mock_api), \
                patch("lib.pr_state.repo_path", lambda s: f"/repos/test/{s}"):
            asyncio.run(close_pr(conn, number, **kwargs))

    def test_close_calls_forgejo_and_updates_db(self):
        conn = _make_db()
        _insert_pr(conn, 42)
        mock_api = AsyncMock(return_value={})
        self._close_with_mocks(conn, 42, mock_api, last_error="test close")
        row = _get_pr(conn, 42)
        assert row["status"] == "closed"
        assert row["last_error"] == "test close"
        mock_api.assert_called_once()

    def test_close_calls_forgejo_before_db_update(self):
        # close_pr must call Forgejo first and write the DB afterwards.
        # When the API call happens, the row must still be in its pre-close state.
        conn = _make_db()
        _insert_pr(conn, 42)
        seen_status = []

        async def _fake_api(*args, **kwargs):
            seen_status.append(_get_pr(conn, 42)["status"])
            return {}

        mock_api = AsyncMock(side_effect=_fake_api)
        self._close_with_mocks(conn, 42, mock_api, last_error="ordered close")
        assert seen_status == ["open"]
        assert _get_pr(conn, 42)["status"] == "closed"

    def test_close_skips_forgejo_when_opted_out(self):
        conn = _make_db()
        _insert_pr(conn, 42)
        mock_api = AsyncMock(return_value={})
        self._close_with_mocks(conn, 42, mock_api,
                               last_error="reconciled", close_on_forgejo=False)
        row = _get_pr(conn, 42)
        assert row["status"] == "closed"
        assert row["last_error"] == "reconciled"
        mock_api.assert_not_called()

    def test_close_increments_merge_failures(self):
        conn = _make_db()
        _insert_pr(conn, 42, merge_failures=2)
        mock_api = AsyncMock(return_value={})
        self._close_with_mocks(conn, 42, mock_api,
                               merge_cycled=True, inc_merge_failures=True)
        row = _get_pr(conn, 42)
        assert row["status"] == "closed"
        assert row["merge_cycled"] == 1
        assert row["merge_failures"] == 3

    def test_close_without_last_error(self):
        conn = _make_db()
        _insert_pr(conn, 42, last_error="old error")
        mock_api = AsyncMock(return_value={})
        self._close_with_mocks(conn, 42, mock_api)
        row = _get_pr(conn, 42)
        assert row["status"] == "closed"
        # last_error not overwritten when not provided
        assert row["last_error"] == "old error"
# ---------------------------------------------------------------------------
# approve_pr
# ---------------------------------------------------------------------------
class TestApprovePr:
    """Tests for approve_pr: moving a PR to 'approved', which is refused without a domain."""

    def test_approve_sets_domain_and_auto_merge(self):
        conn = _make_db()
        _insert_pr(conn, 50)
        approve_pr(conn, 50, domain="internet-finance", auto_merge=1)
        row = _get_pr(conn, 50)
        assert row["status"] == "approved"
        assert row["domain"] == "internet-finance"
        assert row["auto_merge"] == 1

    def test_approve_raises_on_empty_domain(self):
        conn = _make_db()
        _insert_pr(conn, 50)
        with pytest.raises(ValueError, match="without domain"):
            approve_pr(conn, 50, domain="")
        # The guard exists to prevent NULL-domain approvals, so a rejected
        # call must leave the row exactly as it was.
        row = _get_pr(conn, 50)
        assert row["status"] == "open"
        assert row["domain"] is None

    def test_approve_raises_on_none_domain(self):
        conn = _make_db()
        _insert_pr(conn, 50)
        with pytest.raises(ValueError, match="without domain"):
            approve_pr(conn, 50, domain=None)
        row = _get_pr(conn, 50)
        assert row["status"] == "open"
        assert row["domain"] is None

    def test_approve_sets_verdicts(self):
        conn = _make_db()
        _insert_pr(conn, 50)
        approve_pr(conn, 50, domain="cross-domain", auto_merge=1,
                   leo_verdict="skipped", domain_verdict="skipped")
        row = _get_pr(conn, 50)
        assert row["leo_verdict"] == "skipped"
        assert row["domain_verdict"] == "skipped"

    def test_approve_preserves_existing_domain(self):
        conn = _make_db()
        _insert_pr(conn, 50, domain="living-agents")
        approve_pr(conn, 50, domain="cross-domain", auto_merge=0)
        row = _get_pr(conn, 50)
        # COALESCE(domain, ?) preserves existing domain
        assert row["domain"] == "living-agents"
# ---------------------------------------------------------------------------
# mark_merged
# ---------------------------------------------------------------------------
class TestMarkMerged:
    """Tests for mark_merged: it records merged_at and clears last_error."""

    def test_sets_merged_at_and_clears_error(self):
        db = _make_db()
        _insert_pr(db, 60, status="approved", last_error="some old error")
        mark_merged(db, 60)
        merged = _get_pr(db, 60)
        assert merged["status"] == "merged"
        assert merged["merged_at"] is not None
        assert merged["last_error"] is None
# ---------------------------------------------------------------------------
# mark_conflict
# ---------------------------------------------------------------------------
class TestMarkConflict:
    """Tests for mark_conflict: every call bumps merge_failures and sets merge_cycled."""

    def test_increments_failures_and_sets_cycled(self):
        conn = _make_db()
        _insert_pr(conn, 70, status="approved", merge_failures=0)
        mark_conflict(conn, 70, last_error="cherry-pick failed")
        row = _get_pr(conn, 70)
        assert row["status"] == "conflict"
        assert row["merge_cycled"] == 1
        assert row["merge_failures"] == 1
        assert row["last_error"] == "cherry-pick failed"

    def test_accumulates_failures(self):
        # The invariant is "always": starting from a default (open) row
        # must still set the status and merge_cycled, not just the counter.
        conn = _make_db()
        _insert_pr(conn, 70, merge_failures=5)
        mark_conflict(conn, 70)
        row = _get_pr(conn, 70)
        assert row["status"] == "conflict"
        assert row["merge_cycled"] == 1
        assert row["merge_failures"] == 6
# ---------------------------------------------------------------------------
# mark_conflict_permanent
# ---------------------------------------------------------------------------
class TestMarkConflictPermanent:
    """Tests for mark_conflict_permanent: it records the error and rebase attempt count."""

    def test_sets_status_and_attempts(self):
        db = _make_db()
        _insert_pr(db, 80, status="conflict")
        mark_conflict_permanent(
            db,
            80,
            last_error="rebase failed 3x",
            conflict_rebase_attempts=3,
        )
        pr = _get_pr(db, 80)
        expected = {
            "status": "conflict_permanent",
            "last_error": "rebase failed 3x",
            "conflict_rebase_attempts": 3,
        }
        assert {col: pr[col] for col in expected} == expected
# ---------------------------------------------------------------------------
# reopen_pr
# ---------------------------------------------------------------------------
class TestReopenPr:
    """Tests for reopen_pr: sending a PR back to 'open', optionally rewriting review state."""

    def test_simple_reopen(self):
        db = _make_db()
        _insert_pr(db, 90, status="reviewing")
        reopen_pr(db, 90)
        assert _get_pr(db, 90)["status"] == "open"

    def test_reopen_with_rejection(self):
        db = _make_db()
        _insert_pr(db, 90, status="reviewing")
        issues = '["factual_error"]'
        reopen_pr(db, 90, leo_verdict="skipped",
                  last_error="domain rejected", eval_issues=issues)
        pr = _get_pr(db, 90)
        expected = {
            "status": "open",
            "leo_verdict": "skipped",
            "last_error": "domain rejected",
            "eval_issues": issues,
        }
        assert {col: pr[col] for col in expected} == expected

    def test_reopen_dec_eval_attempts(self):
        db = _make_db()
        _insert_pr(db, 90, status="reviewing", eval_attempts=3)
        reopen_pr(db, 90, dec_eval_attempts=True)
        assert _get_pr(db, 90)["eval_attempts"] == 2

    def test_reopen_reset_for_reeval(self):
        db = _make_db()
        _insert_pr(db, 90, status="conflict",
                   leo_verdict="approve", domain_verdict="approve",
                   eval_attempts=2)
        reopen_pr(db, 90, reset_for_reeval=True, conflict_rebase_attempts=1)
        pr = _get_pr(db, 90)
        expected = {
            "status": "open",
            "leo_verdict": "pending",
            "domain_verdict": "pending",
            "eval_attempts": 0,
            "conflict_rebase_attempts": 1,
        }
        assert {col: pr[col] for col in expected} == expected
# ---------------------------------------------------------------------------
# start_review
# ---------------------------------------------------------------------------
class TestStartReview:
    """Tests for start_review: an atomic claim of an 'open' PR, with success returned as a bool."""

    def test_claims_open_pr(self):
        conn = _make_db()
        _insert_pr(conn, 100)
        assert start_review(conn, 100) is True
        row = _get_pr(conn, 100)
        assert row["status"] == "reviewing"

    def test_rejects_non_open_pr(self):
        conn = _make_db()
        _insert_pr(conn, 100, status="reviewing")
        assert start_review(conn, 100) is False
        # A failed claim must not touch the row.
        assert _get_pr(conn, 100)["status"] == "reviewing"

    def test_rejects_closed_pr(self):
        conn = _make_db()
        _insert_pr(conn, 100, status="closed")
        assert start_review(conn, 100) is False
        assert _get_pr(conn, 100)["status"] == "closed"

    def test_double_claim_fails(self):
        conn = _make_db()
        _insert_pr(conn, 100)
        assert start_review(conn, 100) is True
        assert start_review(conn, 100) is False
        assert _get_pr(conn, 100)["status"] == "reviewing"