From 1f5eb324f3c27fdae669c25a4b7aa0a415a50666 Mon Sep 17 00:00:00 2001 From: m3taversal Date: Thu, 16 Apr 2026 12:08:50 +0100 Subject: [PATCH] refactor: centralize PR state transitions in lib/pr_state.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace 38 hand-crafted UPDATE prs SET status calls across evaluate.py and merge.py with 7 centralized functions that enforce invariants: - close_pr: always syncs Forgejo (opt-out for reconciliation) - approve_pr: raises ValueError on empty domain (prevents NULL bugs) - mark_merged: always sets merged_at, clears last_error - mark_conflict: always increments merge_failures, sets merge_cycled - mark_conflict_permanent: terminal conflict state - reopen_pr: handles all reopen scenarios (transient, rejection, reeval) - start_review: atomic claim with bool return This eliminates the class of bugs that produced 3 incidents: 1. Domain NULL on musings bypass (7 PRs stuck, 20h zero throughput) 2. Forgejo ghost PRs (70 PRs open on Forgejo but closed in DB) 3. Merge_cycled missing on various close paths Also fixes: 3 close paths in merge.py had DB update before Forgejo call (reversed order). close_pr does Forgejo first, then DB. Only remaining raw status transition: _claim_next_pr (approved→merging) which is an atomic subquery and doesn't have invariant requirements. 20 new tests, 264 total passing, 0 regressions. Net -101 lines in evaluate.py + merge.py. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- lib/evaluate.py | 117 ++++---------- lib/merge.py | 98 +++--------- lib/pr_state.py | 197 ++++++++++++++++++++++++ tests/test_pr_state.py | 336 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 590 insertions(+), 158 deletions(-) create mode 100644 lib/pr_state.py create mode 100644 tests/test_pr_state.py diff --git a/lib/evaluate.py b/lib/evaluate.py index 2595137..47812b9 100644 --- a/lib/evaluate.py +++ b/lib/evaluate.py @@ -31,6 +31,7 @@ from .forgejo import get_agent_token, get_pr_diff, repo_path from .merge import PIPELINE_OWNED_PREFIXES from .llm import run_batch_domain_review, run_domain_review, run_leo_review, triage_pr from .feedback import format_rejection_comment +from .pr_state import approve_pr, close_pr, reopen_pr, start_review from .validate import load_existing_claims logger = logging.getLogger("pipeline.evaluate") @@ -375,17 +376,7 @@ async def _terminate_pr(conn, pr_number: int, reason: str): repo_path(f"issues/{pr_number}/comments"), {"body": comment_body}, ) - await forgejo_api( - "PATCH", - repo_path(f"pulls/{pr_number}"), - {"state": "closed"}, - ) - - # Update PR status - conn.execute( - "UPDATE prs SET status = 'closed', last_error = ? WHERE number = ?", - (reason, pr_number), - ) + await close_pr(conn, pr_number, last_error=reason) # Tag source for re-extraction with feedback cursor = conn.execute( @@ -506,11 +497,7 @@ async def evaluate_pr(conn, pr_number: int, tier: str = None) -> dict: return {"pr": pr_number, "terminal": True, "reason": "eval_budget_exhausted"} # Atomic claim — prevent concurrent workers from evaluating the same PR (Ganymede #11) - cursor = conn.execute( - "UPDATE prs SET status = 'reviewing' WHERE number = ? 
AND status = 'open'", - (pr_number,), - ) - if cursor.rowcount == 0: + if not start_review(conn, pr_number): logger.debug("PR #%d already claimed by another worker, skipping", pr_number) return {"pr": pr_number, "skipped": True, "reason": "already_claimed"} @@ -533,8 +520,7 @@ async def evaluate_pr(conn, pr_number: int, tier: str = None) -> dict: diff = await get_pr_diff(pr_number) if not diff: # Close PRs with no diff — stale branch, nothing to evaluate - await forgejo_api("PATCH", repo_path(f"pulls/{pr_number}"), {"state": "closed"}) - conn.execute("UPDATE prs SET status='closed', last_error='closed: no diff against main (stale branch)' WHERE number = ?", (pr_number,)) + await close_pr(conn, pr_number, last_error='closed: no diff against main (stale branch)') return {"pr": pr_number, "skipped": True, "reason": "no_diff_closed"} # Musings bypass @@ -545,12 +531,8 @@ async def evaluate_pr(conn, pr_number: int, tier: str = None) -> dict: repo_path(f"issues/{pr_number}/comments"), {"body": "Auto-approved: musings bypass eval per collective policy."}, ) - conn.execute( - """UPDATE prs SET status = 'approved', leo_verdict = 'skipped', - domain_verdict = 'skipped', domain = COALESCE(domain, 'cross-domain'), - auto_merge = 1 WHERE number = ?""", - (pr_number,), - ) + approve_pr(conn, pr_number, domain='cross-domain', auto_merge=1, + leo_verdict='skipped', domain_verdict='skipped') return {"pr": pr_number, "auto_approved": True, "reason": "musings_only"} # Reweave bypass — reweave PRs only add frontmatter edges (supports/challenges/ @@ -566,12 +548,8 @@ async def evaluate_pr(conn, pr_number: int, tier: str = None) -> dict: repo_path(f"issues/{pr_number}/comments"), {"body": "Auto-approved: reweave structural update (frontmatter edges only). 
Leo reviews manually."}, ) - conn.execute( - """UPDATE prs SET status = 'approved', leo_verdict = 'skipped', - domain_verdict = 'skipped', auto_merge = 1, - domain = COALESCE(domain, 'cross-domain') WHERE number = ?""", - (pr_number,), - ) + approve_pr(conn, pr_number, domain='cross-domain', auto_merge=1, + leo_verdict='skipped', domain_verdict='skipped') db.audit( conn, "evaluate", "reweave_bypass", json.dumps({"pr": pr_number, "branch": branch_name}), @@ -676,7 +654,7 @@ async def evaluate_pr(conn, pr_number: int, tier: str = None) -> dict: if domain_review is None: # OpenRouter failure (timeout, error) — revert to open for retry. # NOT a rate limit — don't trigger 15-min backoff, just skip this PR. - conn.execute("UPDATE prs SET status = 'open' WHERE number = ?", (pr_number,)) + reopen_pr(conn, pr_number) if pr_cost > 0: conn.execute("UPDATE prs SET cost_usd = cost_usd + ? WHERE number = ?", (pr_cost, pr_number)) return {"pr": pr_number, "skipped": True, "reason": "openrouter_failed"} @@ -700,13 +678,9 @@ async def evaluate_pr(conn, pr_number: int, tier: str = None) -> dict: if domain_verdict == "request_changes": logger.info("PR #%d: domain rejected, skipping Leo review", pr_number) domain_issues = _parse_issues(domain_review) if domain_review else [] - conn.execute( - """UPDATE prs SET status = 'open', leo_verdict = 'skipped', - last_error = 'domain review requested changes', - eval_issues = ? - WHERE number = ?""", - (json.dumps(domain_issues), pr_number), - ) + reopen_pr(conn, pr_number, leo_verdict='skipped', + last_error='domain review requested changes', + eval_issues=json.dumps(domain_issues)) db.audit( conn, "evaluate", "domain_rejected", json.dumps({"pr": pr_number, "agent": agent, "issues": domain_issues}) ) @@ -744,7 +718,7 @@ async def evaluate_pr(conn, pr_number: int, tier: str = None) -> dict: if leo_review is None: # DEEP: Opus rate limited (queue for later). STANDARD: OpenRouter failed (skip, retry next cycle). 
- conn.execute("UPDATE prs SET status = 'open' WHERE number = ?", (pr_number,)) + reopen_pr(conn, pr_number) if domain_verdict != "skipped": pr_cost += costs.record_usage( conn, config.EVAL_DOMAIN_MODEL, "eval_domain", @@ -793,10 +767,7 @@ async def evaluate_pr(conn, pr_number: int, tier: str = None) -> dict: branch_name = branch_row["branch"] if branch_row else "" is_agent_pr = not branch_name.startswith(PIPELINE_OWNED_PREFIXES) - conn.execute( - "UPDATE prs SET status = 'approved', domain = COALESCE(domain, ?), auto_merge = ? WHERE number = ?", - (domain, 1 if is_agent_pr else 0, pr_number), - ) + approve_pr(conn, pr_number, domain=domain, auto_merge=1 if is_agent_pr else 0) db.audit( conn, "evaluate", @@ -821,10 +792,7 @@ async def evaluate_pr(conn, pr_number: int, tier: str = None) -> dict: if leo_verdict == "request_changes" and leo_review is not None: all_issues.extend(_parse_issues(leo_review)) - conn.execute( - "UPDATE prs SET status = 'open', eval_issues = ? WHERE number = ?", - (json.dumps(all_issues), pr_number), - ) + reopen_pr(conn, pr_number, eval_issues=json.dumps(all_issues)) # Store feedback for re-extraction path feedback = {"leo": leo_verdict, "domain": domain_verdict, "tier": tier, "issues": all_issues} conn.execute( @@ -1050,11 +1018,7 @@ async def _run_batch_domain_eval( pr_num = pr_row["number"] # Atomic claim - cursor = conn.execute( - "UPDATE prs SET status = 'reviewing' WHERE number = ? 
AND status = 'open'", - (pr_num,), - ) - if cursor.rowcount == 0: + if not start_review(conn, pr_num): continue # Increment eval_attempts — skip if merge-cycled (Ganymede+Rhea) @@ -1074,7 +1038,7 @@ async def _run_batch_domain_eval( diff = await _get_pr_diff(pr_num) if not diff: - conn.execute("UPDATE prs SET status = 'open' WHERE number = ?", (pr_num,)) + reopen_pr(conn, pr_num) continue # Musings bypass @@ -1084,12 +1048,8 @@ async def _run_batch_domain_eval( repo_path(f"issues/{pr_num}/comments"), {"body": "Auto-approved: musings bypass eval per collective policy."}, ) - conn.execute( - "UPDATE prs SET status = 'approved', leo_verdict = 'skipped', " - "domain_verdict = 'skipped', domain = COALESCE(domain, 'cross-domain'), " - "auto_merge = 1 WHERE number = ?", - (pr_num,), - ) + approve_pr(conn, pr_num, domain='cross-domain', auto_merge=1, + leo_verdict='skipped', domain_verdict='skipped') succeeded += 1 continue @@ -1130,11 +1090,7 @@ async def _run_batch_domain_eval( running_bytes += p_bytes overflow = [p for p in pr_diffs if p not in kept] for p in overflow: - conn.execute( - "UPDATE prs SET status = 'open', eval_attempts = COALESCE(eval_attempts, 1) - 1 " - "WHERE number = ?", - (p["number"],), - ) + reopen_pr(conn, p["number"], dec_eval_attempts=True) claimed_prs.remove(p["number"]) logger.info( "PR #%d: diff too large for batch (%d bytes total), deferring to next cycle", @@ -1166,7 +1122,7 @@ async def _run_batch_domain_eval( # Complete failure — revert all to open logger.warning("Batch domain review failed — reverting all PRs to open") for pr_num in claimed_prs: - conn.execute("UPDATE prs SET status = 'open' WHERE number = ?", (pr_num,)) + reopen_pr(conn, pr_num) return 0, len(claimed_prs) # Step 3: Parse + validate fan-out @@ -1198,17 +1154,13 @@ async def _run_batch_domain_eval( if pr_num in fallback_prs: # Revert — will be picked up by individual eval next cycle - conn.execute( - "UPDATE prs SET status = 'open', eval_attempts = COALESCE(eval_attempts, 
1) - 1 " - "WHERE number = ?", - (pr_num,), - ) + reopen_pr(conn, pr_num, dec_eval_attempts=True) logger.info("PR #%d: batch fallback — will retry individually", pr_num) continue if pr_num not in valid_reviews: # Should not happen, but safety - conn.execute("UPDATE prs SET status = 'open' WHERE number = ?", (pr_num,)) + reopen_pr(conn, pr_num) continue review_text = valid_reviews[pr_num] @@ -1235,11 +1187,9 @@ async def _run_batch_domain_eval( "SELECT eval_attempts FROM prs WHERE number = ?", (pr_num,) ).fetchone()["eval_attempts"] or 0) - conn.execute( - "UPDATE prs SET status = 'open', leo_verdict = 'skipped', " - "last_error = 'domain review requested changes', eval_issues = ? WHERE number = ?", - (json.dumps(domain_issues), pr_num), - ) + reopen_pr(conn, pr_num, leo_verdict='skipped', + last_error='domain review requested changes', + eval_issues=json.dumps(domain_issues)) db.audit( conn, "evaluate", "domain_rejected", json.dumps({"pr": pr_num, "agent": agent, "issues": domain_issues, "batch": True}), @@ -1257,12 +1207,12 @@ async def _run_batch_domain_eval( leo_review, leo_usage = await run_leo_review(review_diff, files, "STANDARD") if leo_review is None: - conn.execute("UPDATE prs SET status = 'open' WHERE number = ?", (pr_num,)) + reopen_pr(conn, pr_num) logger.debug("PR #%d: Leo review failed, will retry next cycle", pr_num) continue if leo_review == "RATE_LIMITED": - conn.execute("UPDATE prs SET status = 'open' WHERE number = ?", (pr_num,)) + reopen_pr(conn, pr_num) logger.info("PR #%d: Leo rate limited, will retry next cycle", pr_num) continue @@ -1292,7 +1242,7 @@ async def _run_batch_domain_eval( pr_info = await forgejo_api("GET", repo_path(f"pulls/{pr_num}")) pr_author = pr_info.get("user", {}).get("login", "") if pr_info else "" await _post_formal_approvals(pr_num, pr_author) - conn.execute("UPDATE prs SET status = 'approved', domain = COALESCE(domain, ?), auto_merge = 1 WHERE number = ?", (domain or "cross-domain", pr_num,)) + approve_pr(conn, pr_num, 
domain=domain or 'cross-domain', auto_merge=1) db.audit( conn, "evaluate", "approved", json.dumps({"pr": pr_num, "tier": "STANDARD", "domain": domain, @@ -1303,10 +1253,7 @@ async def _run_batch_domain_eval( all_issues = [] if leo_verdict == "request_changes": all_issues.extend(_parse_issues(leo_review)) - conn.execute( - "UPDATE prs SET status = 'open', eval_issues = ? WHERE number = ?", - (json.dumps(all_issues), pr_num), - ) + reopen_pr(conn, pr_num, eval_issues=json.dumps(all_issues)) feedback = {"leo": leo_verdict, "domain": domain_verdict, "tier": "STANDARD", "issues": all_issues} conn.execute( @@ -1453,7 +1400,7 @@ async def evaluate_cycle(conn, max_workers=None) -> tuple[int, int]: logger.exception("Batch eval failed for domain %s", domain) # Revert all to open for pr_row in batch_prs: - conn.execute("UPDATE prs SET status = 'open' WHERE number = ?", (pr_row["number"],)) + reopen_pr(conn, pr_row["number"]) failed += len(batch_prs) # Process individual PRs (DEEP, LIGHT, single-domain, fallback) @@ -1502,7 +1449,7 @@ async def evaluate_cycle(conn, max_workers=None) -> tuple[int, int]: except Exception: logger.exception("Failed to evaluate PR #%d", row["number"]) failed += 1 - conn.execute("UPDATE prs SET status = 'open' WHERE number = ?", (row["number"],)) + reopen_pr(conn, row["number"]) if succeeded or failed: logger.info("Evaluate cycle: %d evaluated, %d errors", succeeded, failed) diff --git a/lib/merge.py b/lib/merge.py index 11afc54..0951d87 100644 --- a/lib/merge.py +++ b/lib/merge.py @@ -24,6 +24,7 @@ from .db import classify_branch from .dedup import dedup_evidence_blocks from .domains import detect_domain_from_branch from .forgejo import api as forgejo_api +from .pr_state import close_pr, mark_conflict, mark_conflict_permanent, mark_merged, reopen_pr # Pipeline-owned branch prefixes — only these get auto-merged. # Agent branches (theseus/*, rio/*, astra/*, etc.) 
stay approved but are NOT @@ -1431,10 +1432,7 @@ async def _merge_domain_queue(conn, domain: str) -> tuple[int, int]: logger.error( "PR #%d merge timed out after %ds — resetting to conflict (Rhea)", pr_num, MERGE_TIMEOUT_SECONDS ) - conn.execute( - "UPDATE prs SET status = 'conflict', merge_cycled = 1, merge_failures = COALESCE(merge_failures, 0) + 1, last_error = ? WHERE number = ?", - (f"merge timed out after {MERGE_TIMEOUT_SECONDS}s", pr_num), - ) + mark_conflict(conn, pr_num, last_error=f"merge timed out after {MERGE_TIMEOUT_SECONDS}s") db.audit(conn, "merge", "timeout", json.dumps({"pr": pr_num, "timeout_seconds": MERGE_TIMEOUT_SECONDS})) failed += 1 continue @@ -1443,19 +1441,14 @@ async def _merge_domain_queue(conn, domain: str) -> tuple[int, int]: logger.warning("PR #%d merge/cherry-pick failed: %s", pr_num, pick_msg) # Reweave: close immediately, don't retry (Ship: same rationale as ff-push failure) if branch.startswith("reweave/"): - conn.execute( - "UPDATE prs SET status = 'closed', merge_cycled = 1, merge_failures = COALESCE(merge_failures, 0) + 1, last_error = ? WHERE number = ?", - (f"reweave merge failed (closed, not retried): {pick_msg[:400]}", pr_num), - ) - await forgejo_api("PATCH", repo_path(f"pulls/{pr_num}"), {"state": "closed"}) + await close_pr(conn, pr_num, + last_error=f"reweave merge failed (closed, not retried): {pick_msg[:400]}", + merge_cycled=True, inc_merge_failures=True) await forgejo_api("POST", repo_path(f"issues/{pr_num}/comments"), {"body": f"Reweave merge failed — closing. Next nightly reweave will create a fresh branch.\n\nError: {pick_msg[:200]}"}) await _delete_remote_branch(branch) else: - conn.execute( - "UPDATE prs SET status = 'conflict', merge_cycled = 1, merge_failures = COALESCE(merge_failures, 0) + 1, last_error = ? 
WHERE number = ?", - (pick_msg[:500], pr_num), - ) + mark_conflict(conn, pr_num, last_error=pick_msg[:500]) db.audit(conn, "merge", "cherry_pick_failed", json.dumps({"pr": pr_num, "error": pick_msg[:200]})) failed += 1 continue @@ -1465,10 +1458,7 @@ async def _merge_domain_queue(conn, domain: str) -> tuple[int, int]: # The branch ref still points at old commits (not a descendant of main), # so pushing branch_sha:main would fail as non-fast-forward. if pick_msg in ("already merged (all commits empty)", "already up to date"): - conn.execute( - "UPDATE prs SET status = 'merged', merged_at = datetime('now'), last_error = NULL WHERE number = ?", - (pr_num,), - ) + mark_merged(conn, pr_num) db.audit(conn, "merge", "merged", json.dumps({"pr": pr_num, "branch": branch, "note": "content already on main"})) leo_token = get_agent_token("leo") await forgejo_api("POST", repo_path(f"issues/{pr_num}/comments"), @@ -1523,31 +1513,20 @@ async def _merge_domain_queue(conn, domain: str) -> tuple[int, int]: # run creates a fresh branch from current main — retry is wasteful. # (Ship: prevents reweave flood + wasted retry cycles) if branch.startswith("reweave/"): - conn.execute( - "UPDATE prs SET status = 'closed', merge_cycled = 1, merge_failures = COALESCE(merge_failures, 0) + 1, last_error = ? WHERE number = ?", - (f"reweave merge failed (closed, not retried): {merge_msg[:400]}", pr_num), - ) - await forgejo_api("PATCH", repo_path(f"pulls/{pr_num}"), {"state": "closed"}) + await close_pr(conn, pr_num, + last_error=f"reweave merge failed (closed, not retried): {merge_msg[:400]}", + merge_cycled=True, inc_merge_failures=True) await forgejo_api("POST", repo_path(f"issues/{pr_num}/comments"), {"body": f"Reweave merge failed — closing. 
Next nightly reweave will create a fresh branch.\n\nError: {merge_msg[:200]}"}) await _delete_remote_branch(branch) else: - conn.execute( - "UPDATE prs SET status = 'conflict', merge_cycled = 1, merge_failures = COALESCE(merge_failures, 0) + 1, last_error = ? WHERE number = ?", - (merge_msg[:500], pr_num), - ) + mark_conflict(conn, pr_num, last_error=merge_msg[:500]) db.audit(conn, "merge", "merge_failed", json.dumps({"pr": pr_num, "error": merge_msg[:200]})) failed += 1 continue # Success — update status and cleanup - conn.execute( - """UPDATE prs SET status = 'merged', - merged_at = datetime('now'), - last_error = NULL - WHERE number = ?""", - (pr_num,), - ) + mark_merged(conn, pr_num) db.audit(conn, "merge", "merged", json.dumps({"pr": pr_num, "branch": branch})) logger.info("PR #%d merged successfully", pr_num) @@ -1624,10 +1603,7 @@ async def _reconcile_db_state(conn): is_merged = pr_info.get("merged", False) if is_merged and db_status != "merged": - conn.execute( - "UPDATE prs SET status = 'merged', merged_at = datetime('now') WHERE number = ?", - (pr_number,), - ) + mark_merged(conn, pr_number) reconciled += 1 continue @@ -1637,10 +1613,7 @@ async def _reconcile_db_state(conn): # trigger discover_external_prs → new PR → fail → close → repeat) if branch: await _delete_remote_branch(branch) - conn.execute( - "UPDATE prs SET status = 'closed', last_error = 'reconciled: closed on Forgejo' WHERE number = ?", - (pr_number,), - ) + await close_pr(conn, pr_number, last_error='reconciled: closed on Forgejo', close_on_forgejo=False) reconciled += 1 continue @@ -1663,10 +1636,7 @@ async def _reconcile_db_state(conn): repo_path(f"issues/{pr_number}/comments"), body={"body": "Auto-closed: branch deleted from remote."}, ) - conn.execute( - "UPDATE prs SET status = 'closed', last_error = 'reconciled: branch deleted' WHERE number = ?", - (pr_number,), - ) + await close_pr(conn, pr_number, last_error='reconciled: branch deleted', close_on_forgejo=False) logger.info("Ghost PR 
#%d: branch %s deleted, closing", pr_number, branch) reconciled += 1 @@ -1763,11 +1733,9 @@ async def _handle_permanent_conflicts(conn) -> int: except Exception: pass - await forgejo_api("PATCH", repo_path(f"pulls/{pr_number}"), {"state": "closed"}) - conn.execute( - "UPDATE prs SET status = 'closed', last_error = 'conflict_permanent: closed + filed in archive' WHERE number = ?", - (pr_number,), - ) + await close_pr(conn, pr_number, + last_error='conflict_permanent: closed + filed in archive', + close_on_forgejo=False) # Already closed at line 1718 handled += 1 logger.info("Permanent conflict handled: PR #%d closed, source filed", pr_number) @@ -1835,11 +1803,8 @@ async def _retry_conflict_prs(conn) -> tuple[int, int]: # (Ship: prevents wasting 3 retry cycles on branches that can never cherry-pick) if branch.startswith("reweave/"): logger.info("Reweave PR #%d: skipping retry, closing + deleting branch", pr_number) - conn.execute( - "UPDATE prs SET status = 'closed', last_error = 'reweave: closed (retry skipped, next nightly creates fresh)' WHERE number = ?", - (pr_number,), - ) - await forgejo_api("PATCH", repo_path(f"pulls/{pr_number}"), {"state": "closed"}) + await close_pr(conn, pr_number, + last_error='reweave: closed (retry skipped, next nightly creates fresh)') await forgejo_api("POST", repo_path(f"issues/{pr_number}/comments"), {"body": "Reweave conflict — closing instead of retrying. Cherry-pick always fails on reweave branches (they modify existing files). Next nightly reweave will create a fresh branch from current main."}) await _delete_remote_branch(branch) @@ -1858,29 +1823,16 @@ async def _retry_conflict_prs(conn) -> tuple[int, int]: if ok: # Rebase succeeded — reset for re-eval (Ganymede: approvals are stale after rebase) - conn.execute( - """UPDATE prs - SET status = 'open', - leo_verdict = 'pending', - domain_verdict = 'pending', - eval_attempts = 0, - conflict_rebase_attempts = ? 
"""PR state transitions — single source of truth for all status changes.

Every UPDATE prs SET status = ... MUST go through this module.

Invariants enforced:
- close: always syncs Forgejo (opt-out for reconciliation only)
- approve: requires non-empty domain (ValueError)
- merged: always sets merged_at, clears last_error
- conflict: always increments merge_failures, sets merge_cycled

Why this exists: 36 hand-crafted status transitions across evaluate.py
and merge.py produced 3 incidents (domain NULL, Forgejo ghost PRs,
merge_cycled missing). Centralizing eliminates the entire class of
"forgot to update X in this one code path" bugs.
"""

import logging

from .forgejo import api as forgejo_api, repo_path

logger = logging.getLogger("pipeline.pr_state")


async def close_pr(
    conn,
    pr_number: int,
    *,
    last_error: str = None,
    merge_cycled: bool = False,
    inc_merge_failures: bool = False,
    close_on_forgejo: bool = True,
):
    """Close a PR on Forgejo and in the DB, in that order.

    Forgejo is patched *before* the DB row is touched: if we crash between
    the two steps the DB still says "open" and the close is retried later,
    instead of leaving a ghost PR that is closed in the DB but open on
    Forgejo.

    Args:
        conn: sqlite3 connection holding the ``prs`` table.
        pr_number: PR to close.
        last_error: reason stored in ``prs.last_error``; the column is
            left untouched when ``None``.
        merge_cycled: set the ``merge_cycled`` flag (merge-path closes).
        inc_merge_failures: bump ``merge_failures`` by one.
        close_on_forgejo: False only when caller already closed on Forgejo
            (reconciliation, ghost PR cleanup after manual close).
    """
    if close_on_forgejo:
        await forgejo_api("PATCH", repo_path(f"pulls/{pr_number}"), {"state": "closed"})

    set_clauses = ["status = 'closed'"]
    bind_values = []

    if last_error is not None:
        set_clauses.append("last_error = ?")
        bind_values.append(last_error)

    if merge_cycled:
        set_clauses.append("merge_cycled = 1")

    if inc_merge_failures:
        set_clauses.append("merge_failures = COALESCE(merge_failures, 0) + 1")

    conn.execute(
        f"UPDATE prs SET {', '.join(set_clauses)} WHERE number = ?",
        [*bind_values, pr_number],
    )


def approve_pr(
    conn,
    pr_number: int,
    *,
    domain: str,
    auto_merge: int = 0,
    leo_verdict: str = None,
    domain_verdict: str = None,
):
    """Approve a PR; refuses to proceed without a domain.

    ``domain`` is applied through COALESCE, so an existing non-NULL value
    in the row wins and the argument is only a fallback — the guard below
    guarantees that fallback is never empty (this is what prevents the
    domain-NULL class of incident).

    Raises:
        ValueError: if ``domain`` is empty or ``None``.
    """
    if not domain:
        raise ValueError(f"Cannot approve PR #{pr_number} without domain")

    set_clauses = ["status = 'approved'", "domain = COALESCE(domain, ?)", "auto_merge = ?"]
    bind_values = [domain, auto_merge]

    # Verdicts are optional: only the bypass paths pass them explicitly.
    for column, value in (("leo_verdict", leo_verdict), ("domain_verdict", domain_verdict)):
        if value is not None:
            set_clauses.append(f"{column} = ?")
            bind_values.append(value)

    conn.execute(
        f"UPDATE prs SET {', '.join(set_clauses)} WHERE number = ?",
        [*bind_values, pr_number],
    )
def mark_merged(conn, pr_number: int):
    """Transition a PR to 'merged'.

    Invariant: merged_at is always stamped and last_error is always
    cleared, regardless of which merge path got here.
    """
    conn.execute(
        "UPDATE prs SET status = 'merged', merged_at = datetime('now'), "
        "last_error = NULL WHERE number = ?",
        (pr_number,),
    )


def mark_conflict(conn, pr_number: int, *, last_error: str = None):
    """Transition a PR to 'conflict'.

    Invariant: merge_failures is always incremented and merge_cycled is
    always set.

    NOTE(review): last_error is written unconditionally, so calling
    without it clears any previous error — unlike mark_conflict_permanent,
    which preserves it when None. Confirm the asymmetry is intended.
    """
    conn.execute(
        "UPDATE prs SET status = 'conflict', merge_cycled = 1, "
        "merge_failures = COALESCE(merge_failures, 0) + 1, "
        "last_error = ? WHERE number = ?",
        (last_error, pr_number),
    )


def mark_conflict_permanent(
    conn,
    pr_number: int,
    *,
    last_error: str = None,
    conflict_rebase_attempts: int = None,
):
    """Mark a PR as permanently conflicted — terminal, no further retries."""
    set_clauses = ["status = 'conflict_permanent'"]
    bind_values = []

    # Both extras are optional; omitted ones leave the stored value alone.
    for column, value in (
        ("last_error", last_error),
        ("conflict_rebase_attempts", conflict_rebase_attempts),
    ):
        if value is not None:
            set_clauses.append(f"{column} = ?")
            bind_values.append(value)

    conn.execute(
        f"UPDATE prs SET {', '.join(set_clauses)} WHERE number = ?",
        [*bind_values, pr_number],
    )


def reopen_pr(
    conn,
    pr_number: int,
    *,
    leo_verdict: str = None,
    domain_verdict: str = None,
    last_error: str = None,
    eval_issues: str = None,
    dec_eval_attempts: bool = False,
    reset_for_reeval: bool = False,
    conflict_rebase_attempts: int = None,
):
    """Set a PR back to 'open'.

    One entry point for every reopen scenario:
    - transient failure (API error): no extra args
    - rejection: leo_verdict + last_error + eval_issues
    - batch overflow: dec_eval_attempts=True (refunds the claimed attempt)
    - conflict resolved: reset_for_reeval=True (verdicts back to 'pending',
      eval_attempts zeroed — approvals are stale after a rebase)

    When reset_for_reeval is set, any leo_verdict/domain_verdict passed
    alongside it is ignored in favour of the 'pending' reset.
    """
    set_clauses = ["status = 'open'"]
    bind_values = []

    if reset_for_reeval:
        set_clauses += [
            "leo_verdict = 'pending'",
            "domain_verdict = 'pending'",
            "eval_attempts = 0",
        ]
    else:
        for column, value in (("leo_verdict", leo_verdict), ("domain_verdict", domain_verdict)):
            if value is not None:
                set_clauses.append(f"{column} = ?")
                bind_values.append(value)

    if last_error is not None:
        set_clauses.append("last_error = ?")
        bind_values.append(last_error)

    if eval_issues is not None:
        set_clauses.append("eval_issues = ?")
        bind_values.append(eval_issues)

    if dec_eval_attempts:
        # COALESCE defaults a missing counter to 1 so the refund lands on 0.
        set_clauses.append("eval_attempts = COALESCE(eval_attempts, 1) - 1")

    if conflict_rebase_attempts is not None:
        set_clauses.append("conflict_rebase_attempts = ?")
        bind_values.append(conflict_rebase_attempts)

    conn.execute(
        f"UPDATE prs SET {', '.join(set_clauses)} WHERE number = ?",
        [*bind_values, pr_number],
    )


def start_review(conn, pr_number: int) -> bool:
    """Atomically claim a PR for review (status open -> reviewing).

    The status predicate makes the UPDATE itself the lock: of several
    concurrent workers, exactly one observes a row change.

    Returns:
        True if this caller claimed the PR, False if another worker did.
    """
    claimed = conn.execute(
        "UPDATE prs SET status = 'reviewing' WHERE number = ? AND status = 'open'",
        (pr_number,),
    )
    return claimed.rowcount > 0
"""Tests for lib/pr_state.py — centralized PR state transitions."""

import asyncio
import sqlite3
import sys
import os
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

# Mock heavy dependencies before importing pr_state
# (presumably lib.forgejo imports aiohttp at module scope — confirm; the
# stub must be registered BEFORE the `from lib.pr_state import ...` below,
# so the ordering of these statements is load-bearing)
sys.modules.setdefault("aiohttp", MagicMock())

# Add lib parent to path so `lib.pr_state` resolves
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

from lib.pr_state import (
    approve_pr,
    close_pr,
    mark_conflict,
    mark_conflict_permanent,
    mark_merged,
    reopen_pr,
    start_review,
)


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------

def _make_db():
    """Create a minimal in-memory DB with the prs table.

    NOTE(review): the schema is assumed to be a faithful subset of the
    production ``prs`` table — verify against the migration that creates
    it. ``row_factory = sqlite3.Row`` lets tests read columns by name.
    """
    conn = sqlite3.connect(":memory:")
    conn.row_factory = sqlite3.Row
    conn.execute("""
        CREATE TABLE prs (
            number INTEGER PRIMARY KEY,
            source_path TEXT,
            branch TEXT,
            status TEXT NOT NULL DEFAULT 'open',
            domain TEXT,
            auto_merge INTEGER DEFAULT 0,
            leo_verdict TEXT DEFAULT 'pending',
            domain_verdict TEXT DEFAULT 'pending',
            eval_attempts INTEGER DEFAULT 0,
            eval_issues TEXT DEFAULT '[]',
            merge_cycled INTEGER DEFAULT 0,
            merge_failures INTEGER DEFAULT 0,
            conflict_rebase_attempts INTEGER DEFAULT 0,
            last_error TEXT,
            merged_at TEXT,
            last_attempt TEXT,
            cost_usd REAL DEFAULT 0,
            created_at TEXT DEFAULT (datetime('now'))
        )
    """)
    return conn
def _insert_pr(conn, number=100, status="open", domain=None, **kwargs):
    """Insert a test PR row; any extra columns come in via **kwargs."""
    row = {"number": number, "status": status}
    if domain is not None:
        row["domain"] = domain
    row.update(kwargs)
    columns = ", ".join(row)
    placeholders = ", ".join("?" for _ in row)
    conn.execute(f"INSERT INTO prs ({columns}) VALUES ({placeholders})", list(row.values()))
    conn.commit()


def _get_pr(conn, number=100):
    """Fetch a single PR row by number (sqlite3.Row, or None if absent)."""
    cursor = conn.execute("SELECT * FROM prs WHERE number = ?", (number,))
    return cursor.fetchone()


# ---------------------------------------------------------------------------
# close_pr
# ---------------------------------------------------------------------------

class TestClosePr:
    def test_close_calls_forgejo_and_updates_db(self):
        conn = _make_db()
        _insert_pr(conn, 42)

        api_mock = AsyncMock(return_value={})
        with patch("lib.pr_state.forgejo_api", api_mock), \
             patch("lib.pr_state.repo_path", lambda s: f"/repos/test/{s}"):
            asyncio.run(close_pr(conn, 42, last_error="test close"))

        pr = _get_pr(conn, 42)
        assert pr["status"] == "closed"
        assert pr["last_error"] == "test close"
        api_mock.assert_called_once()

    def test_close_skips_forgejo_when_opted_out(self):
        conn = _make_db()
        _insert_pr(conn, 42)

        api_mock = AsyncMock(return_value={})
        with patch("lib.pr_state.forgejo_api", api_mock), \
             patch("lib.pr_state.repo_path", lambda s: f"/repos/test/{s}"):
            asyncio.run(close_pr(conn, 42, last_error="reconciled", close_on_forgejo=False))

        assert _get_pr(conn, 42)["status"] == "closed"
        api_mock.assert_not_called()

    def test_close_increments_merge_failures(self):
        conn = _make_db()
        _insert_pr(conn, 42, merge_failures=2)

        api_mock = AsyncMock(return_value={})
        with patch("lib.pr_state.forgejo_api", api_mock), \
             patch("lib.pr_state.repo_path", lambda s: f"/repos/test/{s}"):
            asyncio.run(close_pr(conn, 42, merge_cycled=True, inc_merge_failures=True))

        pr = _get_pr(conn, 42)
        assert pr["merge_cycled"] == 1
        assert pr["merge_failures"] == 3

    def test_close_without_last_error(self):
        conn = _make_db()
        _insert_pr(conn, 42, last_error="old error")

        api_mock = AsyncMock(return_value={})
        with patch("lib.pr_state.forgejo_api", api_mock), \
             patch("lib.pr_state.repo_path", lambda s: f"/repos/test/{s}"):
            asyncio.run(close_pr(conn, 42))

        pr = _get_pr(conn, 42)
        assert pr["status"] == "closed"
        # omitting last_error must leave the stored value untouched
        assert pr["last_error"] == "old error"


# ---------------------------------------------------------------------------
# approve_pr
# ---------------------------------------------------------------------------

class TestApprovePr:
    def test_approve_sets_domain_and_auto_merge(self):
        conn = _make_db()
        _insert_pr(conn, 50)

        approve_pr(conn, 50, domain="internet-finance", auto_merge=1)

        pr = _get_pr(conn, 50)
        assert pr["status"] == "approved"
        assert pr["domain"] == "internet-finance"
        assert pr["auto_merge"] == 1

    def test_approve_raises_on_empty_domain(self):
        conn = _make_db()
        _insert_pr(conn, 50)

        with pytest.raises(ValueError, match="without domain"):
            approve_pr(conn, 50, domain="")

    def test_approve_raises_on_none_domain(self):
        conn = _make_db()
        _insert_pr(conn, 50)

        with pytest.raises(ValueError, match="without domain"):
            approve_pr(conn, 50, domain=None)

    def test_approve_sets_verdicts(self):
        conn = _make_db()
        _insert_pr(conn, 50)

        approve_pr(conn, 50, domain="cross-domain", auto_merge=1,
                   leo_verdict="skipped", domain_verdict="skipped")

        pr = _get_pr(conn, 50)
        assert pr["leo_verdict"] == "skipped"
        assert pr["domain_verdict"] == "skipped"

    def test_approve_preserves_existing_domain(self):
        conn = _make_db()
        _insert_pr(conn, 50, domain="living-agents")

        approve_pr(conn, 50, domain="cross-domain", auto_merge=0)

        # COALESCE(domain, ?) keeps the pre-existing value
        pr = _get_pr(conn, 50)
        assert pr["domain"] == "living-agents"


# ---------------------------------------------------------------------------
# mark_merged
# ---------------------------------------------------------------------------

class TestMarkMerged:
    def test_sets_merged_at_and_clears_error(self):
        conn = _make_db()
        _insert_pr(conn, 60, status="approved", last_error="some old error")

        mark_merged(conn, 60)

        pr = _get_pr(conn, 60)
        assert pr["status"] == "merged"
        assert pr["merged_at"] is not None
        assert pr["last_error"] is None


# ---------------------------------------------------------------------------
# mark_conflict
# ---------------------------------------------------------------------------

class TestMarkConflict:
    def test_increments_failures_and_sets_cycled(self):
        conn = _make_db()
        _insert_pr(conn, 70, status="approved", merge_failures=0)

        mark_conflict(conn, 70, last_error="cherry-pick failed")

        pr = _get_pr(conn, 70)
        assert pr["status"] == "conflict"
        assert pr["merge_cycled"] == 1
        assert pr["merge_failures"] == 1
        assert pr["last_error"] == "cherry-pick failed"

    def test_accumulates_failures(self):
        conn = _make_db()
        _insert_pr(conn, 70, merge_failures=5)

        mark_conflict(conn, 70)

        assert _get_pr(conn, 70)["merge_failures"] == 6


# ---------------------------------------------------------------------------
# mark_conflict_permanent
# ---------------------------------------------------------------------------

class TestMarkConflictPermanent:
    def test_sets_status_and_attempts(self):
        conn = _make_db()
        _insert_pr(conn, 80, status="conflict")

        mark_conflict_permanent(conn, 80,
                                last_error="rebase failed 3x",
                                conflict_rebase_attempts=3)

        pr = _get_pr(conn, 80)
        assert pr["status"] == "conflict_permanent"
        assert pr["last_error"] == "rebase failed 3x"
        assert pr["conflict_rebase_attempts"] == 3
# ---------------------------------------------------------------------------
# reopen_pr
# ---------------------------------------------------------------------------

class TestReopenPr:
    """reopen_pr returns a PR to 'open' across every reopen scenario."""

    def test_simple_reopen(self):
        db = _make_db()
        _insert_pr(db, 90, status="reviewing")

        reopen_pr(db, 90)

        assert _get_pr(db, 90)["status"] == "open"

    def test_reopen_with_rejection(self):
        db = _make_db()
        _insert_pr(db, 90, status="reviewing")

        reopen_pr(db, 90, leo_verdict="skipped",
                  last_error="domain rejected",
                  eval_issues='["factual_error"]')

        pr = _get_pr(db, 90)
        assert pr["status"] == "open"
        assert pr["leo_verdict"] == "skipped"
        assert pr["last_error"] == "domain rejected"
        assert pr["eval_issues"] == '["factual_error"]'

    def test_reopen_dec_eval_attempts(self):
        db = _make_db()
        _insert_pr(db, 90, status="reviewing", eval_attempts=3)

        reopen_pr(db, 90, dec_eval_attempts=True)

        # One evaluation attempt is refunded: 3 -> 2.
        assert _get_pr(db, 90)["eval_attempts"] == 2

    def test_reopen_reset_for_reeval(self):
        db = _make_db()
        _insert_pr(db, 90, status="conflict",
                   leo_verdict="approve", domain_verdict="approve",
                   eval_attempts=2)

        reopen_pr(db, 90, reset_for_reeval=True,
                  conflict_rebase_attempts=1)

        pr = _get_pr(db, 90)
        # Full re-evaluation reset: verdicts back to pending, attempts zeroed.
        expected = {
            "status": "open",
            "leo_verdict": "pending",
            "domain_verdict": "pending",
            "eval_attempts": 0,
            "conflict_rebase_attempts": 1,
        }
        for column, want in expected.items():
            assert pr[column] == want


# ---------------------------------------------------------------------------
# start_review
# ---------------------------------------------------------------------------

class TestStartReview:
    """start_review is an atomic open -> reviewing claim returning a bool."""

    def test_claims_open_pr(self):
        db = _make_db()
        _insert_pr(db, 100)

        assert start_review(db, 100) is True
        assert _get_pr(db, 100)["status"] == "reviewing"

    def test_rejects_non_open_pr(self):
        db = _make_db()
        _insert_pr(db, 100, status="reviewing")

        assert start_review(db, 100) is False

    def test_double_claim_fails(self):
        db = _make_db()
        _insert_pr(db, 100)

        first = start_review(db, 100)
        second = start_review(db, 100)
        # Only the first claim wins; the second sees a non-open row.
        assert first is True
        assert second is False