refactor: centralize PR state transitions in lib/pr_state.py
Some checks are pending
CI / lint-and-test (push) Waiting to run

Replace 38 hand-crafted UPDATE prs SET status calls across evaluate.py
and merge.py with 7 centralized functions that enforce invariants:
- close_pr: always syncs Forgejo (opt-out for reconciliation)
- approve_pr: raises ValueError on empty domain (prevents NULL bugs)
- mark_merged: always sets merged_at, clears last_error
- mark_conflict: always increments merge_failures, sets merge_cycled
- mark_conflict_permanent: terminal conflict state
- reopen_pr: handles all reopen scenarios (transient, rejection, reeval)
- start_review: atomic claim with bool return

This eliminates the class of bugs that produced 3 incidents:
1. Domain NULL on musings bypass (7 PRs stuck, 20h zero throughput)
2. Forgejo ghost PRs (70 PRs open on Forgejo but closed in DB)
3. Merge_cycled missing on various close paths

Also fixes: 3 close paths in merge.py had DB update before Forgejo call
(reversed order). close_pr does Forgejo first, then DB.

Only remaining raw status transition: _claim_next_pr (approved→merging)
which is an atomic subquery and doesn't have invariant requirements.

20 new tests, 264 total passing, 0 regressions. Net -101 lines in
evaluate.py + merge.py.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
m3taversal 2026-04-16 12:08:50 +01:00
parent d073e22e8d
commit 1f5eb324f3
4 changed files with 590 additions and 158 deletions

View file

@ -31,6 +31,7 @@ from .forgejo import get_agent_token, get_pr_diff, repo_path
from .merge import PIPELINE_OWNED_PREFIXES
from .llm import run_batch_domain_review, run_domain_review, run_leo_review, triage_pr
from .feedback import format_rejection_comment
from .pr_state import approve_pr, close_pr, reopen_pr, start_review
from .validate import load_existing_claims
logger = logging.getLogger("pipeline.evaluate")
@ -375,17 +376,7 @@ async def _terminate_pr(conn, pr_number: int, reason: str):
repo_path(f"issues/{pr_number}/comments"),
{"body": comment_body},
)
await forgejo_api(
"PATCH",
repo_path(f"pulls/{pr_number}"),
{"state": "closed"},
)
# Update PR status
conn.execute(
"UPDATE prs SET status = 'closed', last_error = ? WHERE number = ?",
(reason, pr_number),
)
await close_pr(conn, pr_number, last_error=reason)
# Tag source for re-extraction with feedback
cursor = conn.execute(
@ -506,11 +497,7 @@ async def evaluate_pr(conn, pr_number: int, tier: str = None) -> dict:
return {"pr": pr_number, "terminal": True, "reason": "eval_budget_exhausted"}
# Atomic claim — prevent concurrent workers from evaluating the same PR (Ganymede #11)
cursor = conn.execute(
"UPDATE prs SET status = 'reviewing' WHERE number = ? AND status = 'open'",
(pr_number,),
)
if cursor.rowcount == 0:
if not start_review(conn, pr_number):
logger.debug("PR #%d already claimed by another worker, skipping", pr_number)
return {"pr": pr_number, "skipped": True, "reason": "already_claimed"}
@ -533,8 +520,7 @@ async def evaluate_pr(conn, pr_number: int, tier: str = None) -> dict:
diff = await get_pr_diff(pr_number)
if not diff:
# Close PRs with no diff — stale branch, nothing to evaluate
await forgejo_api("PATCH", repo_path(f"pulls/{pr_number}"), {"state": "closed"})
conn.execute("UPDATE prs SET status='closed', last_error='closed: no diff against main (stale branch)' WHERE number = ?", (pr_number,))
await close_pr(conn, pr_number, last_error='closed: no diff against main (stale branch)')
return {"pr": pr_number, "skipped": True, "reason": "no_diff_closed"}
# Musings bypass
@ -545,12 +531,8 @@ async def evaluate_pr(conn, pr_number: int, tier: str = None) -> dict:
repo_path(f"issues/{pr_number}/comments"),
{"body": "Auto-approved: musings bypass eval per collective policy."},
)
conn.execute(
"""UPDATE prs SET status = 'approved', leo_verdict = 'skipped',
domain_verdict = 'skipped', domain = COALESCE(domain, 'cross-domain'),
auto_merge = 1 WHERE number = ?""",
(pr_number,),
)
approve_pr(conn, pr_number, domain='cross-domain', auto_merge=1,
leo_verdict='skipped', domain_verdict='skipped')
return {"pr": pr_number, "auto_approved": True, "reason": "musings_only"}
# Reweave bypass — reweave PRs only add frontmatter edges (supports/challenges/
@ -566,12 +548,8 @@ async def evaluate_pr(conn, pr_number: int, tier: str = None) -> dict:
repo_path(f"issues/{pr_number}/comments"),
{"body": "Auto-approved: reweave structural update (frontmatter edges only). Leo reviews manually."},
)
conn.execute(
"""UPDATE prs SET status = 'approved', leo_verdict = 'skipped',
domain_verdict = 'skipped', auto_merge = 1,
domain = COALESCE(domain, 'cross-domain') WHERE number = ?""",
(pr_number,),
)
approve_pr(conn, pr_number, domain='cross-domain', auto_merge=1,
leo_verdict='skipped', domain_verdict='skipped')
db.audit(
conn, "evaluate", "reweave_bypass",
json.dumps({"pr": pr_number, "branch": branch_name}),
@ -676,7 +654,7 @@ async def evaluate_pr(conn, pr_number: int, tier: str = None) -> dict:
if domain_review is None:
# OpenRouter failure (timeout, error) — revert to open for retry.
# NOT a rate limit — don't trigger 15-min backoff, just skip this PR.
conn.execute("UPDATE prs SET status = 'open' WHERE number = ?", (pr_number,))
reopen_pr(conn, pr_number)
if pr_cost > 0:
conn.execute("UPDATE prs SET cost_usd = cost_usd + ? WHERE number = ?", (pr_cost, pr_number))
return {"pr": pr_number, "skipped": True, "reason": "openrouter_failed"}
@ -700,13 +678,9 @@ async def evaluate_pr(conn, pr_number: int, tier: str = None) -> dict:
if domain_verdict == "request_changes":
logger.info("PR #%d: domain rejected, skipping Leo review", pr_number)
domain_issues = _parse_issues(domain_review) if domain_review else []
conn.execute(
"""UPDATE prs SET status = 'open', leo_verdict = 'skipped',
last_error = 'domain review requested changes',
eval_issues = ?
WHERE number = ?""",
(json.dumps(domain_issues), pr_number),
)
reopen_pr(conn, pr_number, leo_verdict='skipped',
last_error='domain review requested changes',
eval_issues=json.dumps(domain_issues))
db.audit(
conn, "evaluate", "domain_rejected", json.dumps({"pr": pr_number, "agent": agent, "issues": domain_issues})
)
@ -744,7 +718,7 @@ async def evaluate_pr(conn, pr_number: int, tier: str = None) -> dict:
if leo_review is None:
# DEEP: Opus rate limited (queue for later). STANDARD: OpenRouter failed (skip, retry next cycle).
conn.execute("UPDATE prs SET status = 'open' WHERE number = ?", (pr_number,))
reopen_pr(conn, pr_number)
if domain_verdict != "skipped":
pr_cost += costs.record_usage(
conn, config.EVAL_DOMAIN_MODEL, "eval_domain",
@ -793,10 +767,7 @@ async def evaluate_pr(conn, pr_number: int, tier: str = None) -> dict:
branch_name = branch_row["branch"] if branch_row else ""
is_agent_pr = not branch_name.startswith(PIPELINE_OWNED_PREFIXES)
conn.execute(
"UPDATE prs SET status = 'approved', domain = COALESCE(domain, ?), auto_merge = ? WHERE number = ?",
(domain, 1 if is_agent_pr else 0, pr_number),
)
approve_pr(conn, pr_number, domain=domain, auto_merge=1 if is_agent_pr else 0)
db.audit(
conn,
"evaluate",
@ -821,10 +792,7 @@ async def evaluate_pr(conn, pr_number: int, tier: str = None) -> dict:
if leo_verdict == "request_changes" and leo_review is not None:
all_issues.extend(_parse_issues(leo_review))
conn.execute(
"UPDATE prs SET status = 'open', eval_issues = ? WHERE number = ?",
(json.dumps(all_issues), pr_number),
)
reopen_pr(conn, pr_number, eval_issues=json.dumps(all_issues))
# Store feedback for re-extraction path
feedback = {"leo": leo_verdict, "domain": domain_verdict, "tier": tier, "issues": all_issues}
conn.execute(
@ -1050,11 +1018,7 @@ async def _run_batch_domain_eval(
pr_num = pr_row["number"]
# Atomic claim
cursor = conn.execute(
"UPDATE prs SET status = 'reviewing' WHERE number = ? AND status = 'open'",
(pr_num,),
)
if cursor.rowcount == 0:
if not start_review(conn, pr_num):
continue
# Increment eval_attempts — skip if merge-cycled (Ganymede+Rhea)
@ -1074,7 +1038,7 @@ async def _run_batch_domain_eval(
diff = await _get_pr_diff(pr_num)
if not diff:
conn.execute("UPDATE prs SET status = 'open' WHERE number = ?", (pr_num,))
reopen_pr(conn, pr_num)
continue
# Musings bypass
@ -1084,12 +1048,8 @@ async def _run_batch_domain_eval(
repo_path(f"issues/{pr_num}/comments"),
{"body": "Auto-approved: musings bypass eval per collective policy."},
)
conn.execute(
"UPDATE prs SET status = 'approved', leo_verdict = 'skipped', "
"domain_verdict = 'skipped', domain = COALESCE(domain, 'cross-domain'), "
"auto_merge = 1 WHERE number = ?",
(pr_num,),
)
approve_pr(conn, pr_num, domain='cross-domain', auto_merge=1,
leo_verdict='skipped', domain_verdict='skipped')
succeeded += 1
continue
@ -1130,11 +1090,7 @@ async def _run_batch_domain_eval(
running_bytes += p_bytes
overflow = [p for p in pr_diffs if p not in kept]
for p in overflow:
conn.execute(
"UPDATE prs SET status = 'open', eval_attempts = COALESCE(eval_attempts, 1) - 1 "
"WHERE number = ?",
(p["number"],),
)
reopen_pr(conn, p["number"], dec_eval_attempts=True)
claimed_prs.remove(p["number"])
logger.info(
"PR #%d: diff too large for batch (%d bytes total), deferring to next cycle",
@ -1166,7 +1122,7 @@ async def _run_batch_domain_eval(
# Complete failure — revert all to open
logger.warning("Batch domain review failed — reverting all PRs to open")
for pr_num in claimed_prs:
conn.execute("UPDATE prs SET status = 'open' WHERE number = ?", (pr_num,))
reopen_pr(conn, pr_num)
return 0, len(claimed_prs)
# Step 3: Parse + validate fan-out
@ -1198,17 +1154,13 @@ async def _run_batch_domain_eval(
if pr_num in fallback_prs:
# Revert — will be picked up by individual eval next cycle
conn.execute(
"UPDATE prs SET status = 'open', eval_attempts = COALESCE(eval_attempts, 1) - 1 "
"WHERE number = ?",
(pr_num,),
)
reopen_pr(conn, pr_num, dec_eval_attempts=True)
logger.info("PR #%d: batch fallback — will retry individually", pr_num)
continue
if pr_num not in valid_reviews:
# Should not happen, but safety
conn.execute("UPDATE prs SET status = 'open' WHERE number = ?", (pr_num,))
reopen_pr(conn, pr_num)
continue
review_text = valid_reviews[pr_num]
@ -1235,11 +1187,9 @@ async def _run_batch_domain_eval(
"SELECT eval_attempts FROM prs WHERE number = ?", (pr_num,)
).fetchone()["eval_attempts"] or 0)
conn.execute(
"UPDATE prs SET status = 'open', leo_verdict = 'skipped', "
"last_error = 'domain review requested changes', eval_issues = ? WHERE number = ?",
(json.dumps(domain_issues), pr_num),
)
reopen_pr(conn, pr_num, leo_verdict='skipped',
last_error='domain review requested changes',
eval_issues=json.dumps(domain_issues))
db.audit(
conn, "evaluate", "domain_rejected",
json.dumps({"pr": pr_num, "agent": agent, "issues": domain_issues, "batch": True}),
@ -1257,12 +1207,12 @@ async def _run_batch_domain_eval(
leo_review, leo_usage = await run_leo_review(review_diff, files, "STANDARD")
if leo_review is None:
conn.execute("UPDATE prs SET status = 'open' WHERE number = ?", (pr_num,))
reopen_pr(conn, pr_num)
logger.debug("PR #%d: Leo review failed, will retry next cycle", pr_num)
continue
if leo_review == "RATE_LIMITED":
conn.execute("UPDATE prs SET status = 'open' WHERE number = ?", (pr_num,))
reopen_pr(conn, pr_num)
logger.info("PR #%d: Leo rate limited, will retry next cycle", pr_num)
continue
@ -1292,7 +1242,7 @@ async def _run_batch_domain_eval(
pr_info = await forgejo_api("GET", repo_path(f"pulls/{pr_num}"))
pr_author = pr_info.get("user", {}).get("login", "") if pr_info else ""
await _post_formal_approvals(pr_num, pr_author)
conn.execute("UPDATE prs SET status = 'approved', domain = COALESCE(domain, ?), auto_merge = 1 WHERE number = ?", (domain or "cross-domain", pr_num,))
approve_pr(conn, pr_num, domain=domain or 'cross-domain', auto_merge=1)
db.audit(
conn, "evaluate", "approved",
json.dumps({"pr": pr_num, "tier": "STANDARD", "domain": domain,
@ -1303,10 +1253,7 @@ async def _run_batch_domain_eval(
all_issues = []
if leo_verdict == "request_changes":
all_issues.extend(_parse_issues(leo_review))
conn.execute(
"UPDATE prs SET status = 'open', eval_issues = ? WHERE number = ?",
(json.dumps(all_issues), pr_num),
)
reopen_pr(conn, pr_num, eval_issues=json.dumps(all_issues))
feedback = {"leo": leo_verdict, "domain": domain_verdict,
"tier": "STANDARD", "issues": all_issues}
conn.execute(
@ -1453,7 +1400,7 @@ async def evaluate_cycle(conn, max_workers=None) -> tuple[int, int]:
logger.exception("Batch eval failed for domain %s", domain)
# Revert all to open
for pr_row in batch_prs:
conn.execute("UPDATE prs SET status = 'open' WHERE number = ?", (pr_row["number"],))
reopen_pr(conn, pr_row["number"])
failed += len(batch_prs)
# Process individual PRs (DEEP, LIGHT, single-domain, fallback)
@ -1502,7 +1449,7 @@ async def evaluate_cycle(conn, max_workers=None) -> tuple[int, int]:
except Exception:
logger.exception("Failed to evaluate PR #%d", row["number"])
failed += 1
conn.execute("UPDATE prs SET status = 'open' WHERE number = ?", (row["number"],))
reopen_pr(conn, row["number"])
if succeeded or failed:
logger.info("Evaluate cycle: %d evaluated, %d errors", succeeded, failed)

View file

@ -24,6 +24,7 @@ from .db import classify_branch
from .dedup import dedup_evidence_blocks
from .domains import detect_domain_from_branch
from .forgejo import api as forgejo_api
from .pr_state import close_pr, mark_conflict, mark_conflict_permanent, mark_merged, reopen_pr
# Pipeline-owned branch prefixes — only these get auto-merged.
# Agent branches (theseus/*, rio/*, astra/*, etc.) stay approved but are NOT
@ -1431,10 +1432,7 @@ async def _merge_domain_queue(conn, domain: str) -> tuple[int, int]:
logger.error(
"PR #%d merge timed out after %ds — resetting to conflict (Rhea)", pr_num, MERGE_TIMEOUT_SECONDS
)
conn.execute(
"UPDATE prs SET status = 'conflict', merge_cycled = 1, merge_failures = COALESCE(merge_failures, 0) + 1, last_error = ? WHERE number = ?",
(f"merge timed out after {MERGE_TIMEOUT_SECONDS}s", pr_num),
)
mark_conflict(conn, pr_num, last_error=f"merge timed out after {MERGE_TIMEOUT_SECONDS}s")
db.audit(conn, "merge", "timeout", json.dumps({"pr": pr_num, "timeout_seconds": MERGE_TIMEOUT_SECONDS}))
failed += 1
continue
@ -1443,19 +1441,14 @@ async def _merge_domain_queue(conn, domain: str) -> tuple[int, int]:
logger.warning("PR #%d merge/cherry-pick failed: %s", pr_num, pick_msg)
# Reweave: close immediately, don't retry (Ship: same rationale as ff-push failure)
if branch.startswith("reweave/"):
conn.execute(
"UPDATE prs SET status = 'closed', merge_cycled = 1, merge_failures = COALESCE(merge_failures, 0) + 1, last_error = ? WHERE number = ?",
(f"reweave merge failed (closed, not retried): {pick_msg[:400]}", pr_num),
)
await forgejo_api("PATCH", repo_path(f"pulls/{pr_num}"), {"state": "closed"})
await close_pr(conn, pr_num,
last_error=f"reweave merge failed (closed, not retried): {pick_msg[:400]}",
merge_cycled=True, inc_merge_failures=True)
await forgejo_api("POST", repo_path(f"issues/{pr_num}/comments"),
{"body": f"Reweave merge failed — closing. Next nightly reweave will create a fresh branch.\n\nError: {pick_msg[:200]}"})
await _delete_remote_branch(branch)
else:
conn.execute(
"UPDATE prs SET status = 'conflict', merge_cycled = 1, merge_failures = COALESCE(merge_failures, 0) + 1, last_error = ? WHERE number = ?",
(pick_msg[:500], pr_num),
)
mark_conflict(conn, pr_num, last_error=pick_msg[:500])
db.audit(conn, "merge", "cherry_pick_failed", json.dumps({"pr": pr_num, "error": pick_msg[:200]}))
failed += 1
continue
@ -1465,10 +1458,7 @@ async def _merge_domain_queue(conn, domain: str) -> tuple[int, int]:
# The branch ref still points at old commits (not a descendant of main),
# so pushing branch_sha:main would fail as non-fast-forward.
if pick_msg in ("already merged (all commits empty)", "already up to date"):
conn.execute(
"UPDATE prs SET status = 'merged', merged_at = datetime('now'), last_error = NULL WHERE number = ?",
(pr_num,),
)
mark_merged(conn, pr_num)
db.audit(conn, "merge", "merged", json.dumps({"pr": pr_num, "branch": branch, "note": "content already on main"}))
leo_token = get_agent_token("leo")
await forgejo_api("POST", repo_path(f"issues/{pr_num}/comments"),
@ -1523,31 +1513,20 @@ async def _merge_domain_queue(conn, domain: str) -> tuple[int, int]:
# run creates a fresh branch from current main — retry is wasteful.
# (Ship: prevents reweave flood + wasted retry cycles)
if branch.startswith("reweave/"):
conn.execute(
"UPDATE prs SET status = 'closed', merge_cycled = 1, merge_failures = COALESCE(merge_failures, 0) + 1, last_error = ? WHERE number = ?",
(f"reweave merge failed (closed, not retried): {merge_msg[:400]}", pr_num),
)
await forgejo_api("PATCH", repo_path(f"pulls/{pr_num}"), {"state": "closed"})
await close_pr(conn, pr_num,
last_error=f"reweave merge failed (closed, not retried): {merge_msg[:400]}",
merge_cycled=True, inc_merge_failures=True)
await forgejo_api("POST", repo_path(f"issues/{pr_num}/comments"),
{"body": f"Reweave merge failed — closing. Next nightly reweave will create a fresh branch.\n\nError: {merge_msg[:200]}"})
await _delete_remote_branch(branch)
else:
conn.execute(
"UPDATE prs SET status = 'conflict', merge_cycled = 1, merge_failures = COALESCE(merge_failures, 0) + 1, last_error = ? WHERE number = ?",
(merge_msg[:500], pr_num),
)
mark_conflict(conn, pr_num, last_error=merge_msg[:500])
db.audit(conn, "merge", "merge_failed", json.dumps({"pr": pr_num, "error": merge_msg[:200]}))
failed += 1
continue
# Success — update status and cleanup
conn.execute(
"""UPDATE prs SET status = 'merged',
merged_at = datetime('now'),
last_error = NULL
WHERE number = ?""",
(pr_num,),
)
mark_merged(conn, pr_num)
db.audit(conn, "merge", "merged", json.dumps({"pr": pr_num, "branch": branch}))
logger.info("PR #%d merged successfully", pr_num)
@ -1624,10 +1603,7 @@ async def _reconcile_db_state(conn):
is_merged = pr_info.get("merged", False)
if is_merged and db_status != "merged":
conn.execute(
"UPDATE prs SET status = 'merged', merged_at = datetime('now') WHERE number = ?",
(pr_number,),
)
mark_merged(conn, pr_number)
reconciled += 1
continue
@ -1637,10 +1613,7 @@ async def _reconcile_db_state(conn):
# trigger discover_external_prs → new PR → fail → close → repeat)
if branch:
await _delete_remote_branch(branch)
conn.execute(
"UPDATE prs SET status = 'closed', last_error = 'reconciled: closed on Forgejo' WHERE number = ?",
(pr_number,),
)
await close_pr(conn, pr_number, last_error='reconciled: closed on Forgejo', close_on_forgejo=False)
reconciled += 1
continue
@ -1663,10 +1636,7 @@ async def _reconcile_db_state(conn):
repo_path(f"issues/{pr_number}/comments"),
body={"body": "Auto-closed: branch deleted from remote."},
)
conn.execute(
"UPDATE prs SET status = 'closed', last_error = 'reconciled: branch deleted' WHERE number = ?",
(pr_number,),
)
await close_pr(conn, pr_number, last_error='reconciled: branch deleted', close_on_forgejo=False)
logger.info("Ghost PR #%d: branch %s deleted, closing", pr_number, branch)
reconciled += 1
@ -1763,11 +1733,9 @@ async def _handle_permanent_conflicts(conn) -> int:
except Exception:
pass
await forgejo_api("PATCH", repo_path(f"pulls/{pr_number}"), {"state": "closed"})
conn.execute(
"UPDATE prs SET status = 'closed', last_error = 'conflict_permanent: closed + filed in archive' WHERE number = ?",
(pr_number,),
)
await close_pr(conn, pr_number,
last_error='conflict_permanent: closed + filed in archive',
close_on_forgejo=False) # Already closed at line 1718
handled += 1
logger.info("Permanent conflict handled: PR #%d closed, source filed", pr_number)
@ -1835,11 +1803,8 @@ async def _retry_conflict_prs(conn) -> tuple[int, int]:
# (Ship: prevents wasting 3 retry cycles on branches that can never cherry-pick)
if branch.startswith("reweave/"):
logger.info("Reweave PR #%d: skipping retry, closing + deleting branch", pr_number)
conn.execute(
"UPDATE prs SET status = 'closed', last_error = 'reweave: closed (retry skipped, next nightly creates fresh)' WHERE number = ?",
(pr_number,),
)
await forgejo_api("PATCH", repo_path(f"pulls/{pr_number}"), {"state": "closed"})
await close_pr(conn, pr_number,
last_error='reweave: closed (retry skipped, next nightly creates fresh)')
await forgejo_api("POST", repo_path(f"issues/{pr_number}/comments"),
{"body": "Reweave conflict — closing instead of retrying. Cherry-pick always fails on reweave branches (they modify existing files). Next nightly reweave will create a fresh branch from current main."})
await _delete_remote_branch(branch)
@ -1858,29 +1823,16 @@ async def _retry_conflict_prs(conn) -> tuple[int, int]:
if ok:
# Rebase succeeded — reset for re-eval (Ganymede: approvals are stale after rebase)
conn.execute(
"""UPDATE prs
SET status = 'open',
leo_verdict = 'pending',
domain_verdict = 'pending',
eval_attempts = 0,
conflict_rebase_attempts = ?
WHERE number = ?""",
(attempts + 1, pr_number),
)
reopen_pr(conn, pr_number, reset_for_reeval=True,
conflict_rebase_attempts=attempts + 1)
logger.info("Conflict resolved: PR #%d rebased successfully, reset for re-eval", pr_number)
resolved += 1
else:
new_attempts = attempts + 1
if new_attempts >= MAX_CONFLICT_REBASE_ATTEMPTS:
conn.execute(
"""UPDATE prs
SET status = 'conflict_permanent',
conflict_rebase_attempts = ?,
last_error = ?
WHERE number = ?""",
(new_attempts, f"rebase failed {MAX_CONFLICT_REBASE_ATTEMPTS}x: {msg[:200]}", pr_number),
)
mark_conflict_permanent(conn, pr_number,
last_error=f"rebase failed {MAX_CONFLICT_REBASE_ATTEMPTS}x: {msg[:200]}",
conflict_rebase_attempts=new_attempts)
logger.warning("Conflict permanent: PR #%d failed %d rebase attempts: %s",
pr_number, new_attempts, msg[:100])
else:

197
lib/pr_state.py Normal file
View file

@ -0,0 +1,197 @@
"""PR state transitions — single source of truth for all status changes.
Every UPDATE prs SET status = ... MUST go through this module.
Invariants enforced:
- close: always syncs Forgejo (opt-out for reconciliation only)
- approve: requires non-empty domain (ValueError)
- merged: always sets merged_at, clears last_error
- conflict: always increments merge_failures, sets merge_cycled
Why this exists: 36 hand-crafted status transitions across evaluate.py
and merge.py produced 3 incidents (domain NULL, Forgejo ghost PRs,
merge_cycled missing). Centralizing eliminates the entire class of
"forgot to update X in this one code path" bugs.
"""
import logging
from .forgejo import api as forgejo_api, repo_path
logger = logging.getLogger("pipeline.pr_state")
async def close_pr(
    conn,
    pr_number: int,
    *,
    last_error: str = None,
    merge_cycled: bool = False,
    inc_merge_failures: bool = False,
    close_on_forgejo: bool = True,
):
    """Transition a PR to 'closed', keeping Forgejo and the DB in sync.

    Forgejo is patched *before* the DB write, so a crash in between leaves
    the PR closed remotely rather than as a ghost (open on Forgejo, closed
    in the DB).

    Args:
        close_on_forgejo: False only when caller already closed on Forgejo
            (reconciliation, ghost PR cleanup after manual close).
    """
    if close_on_forgejo:
        await forgejo_api("PATCH", repo_path(f"pulls/{pr_number}"), {"state": "closed"})

    assignments = ["status = 'closed'"]
    values = []
    if last_error is not None:
        assignments.append("last_error = ?")
        values.append(last_error)
    if merge_cycled:
        assignments.append("merge_cycled = 1")
    if inc_merge_failures:
        assignments.append("merge_failures = COALESCE(merge_failures, 0) + 1")
    values.append(pr_number)
    conn.execute("UPDATE prs SET " + ", ".join(assignments) + " WHERE number = ?", values)
def approve_pr(
    conn,
    pr_number: int,
    *,
    domain: str,
    auto_merge: int = 0,
    leo_verdict: str = None,
    domain_verdict: str = None,
):
    """Approve a PR. Raises ValueError if domain is empty/None.

    The domain guard exists because an approved PR with a NULL domain can
    never be picked up by the per-domain merge queue. COALESCE preserves
    a domain already stored on the row; the argument is only a fallback.
    """
    if not domain:
        raise ValueError(f"Cannot approve PR #{pr_number} without domain")

    assignments = [
        "status = 'approved'",
        "domain = COALESCE(domain, ?)",
        "auto_merge = ?",
    ]
    values = [domain, auto_merge]
    for column, value in (("leo_verdict", leo_verdict), ("domain_verdict", domain_verdict)):
        if value is not None:
            assignments.append(f"{column} = ?")
            values.append(value)
    values.append(pr_number)
    conn.execute("UPDATE prs SET " + ", ".join(assignments) + " WHERE number = ?", values)
def mark_merged(conn, pr_number: int):
    """Mark PR as merged.

    Invariant: merged_at is always stamped and any stale last_error from a
    previous failed attempt is cleared.
    """
    sql = (
        "UPDATE prs SET status = 'merged', merged_at = datetime('now'), "
        "last_error = NULL WHERE number = ?"
    )
    conn.execute(sql, (pr_number,))
def mark_conflict(conn, pr_number: int, *, last_error: str = None):
    """Mark PR as conflict.

    Invariant: merge_failures is always incremented and merge_cycled set.
    NOTE: last_error is written unconditionally — omitting it clears any
    previous error message (stores NULL).
    """
    sql = (
        "UPDATE prs SET status = 'conflict', merge_cycled = 1, "
        "merge_failures = COALESCE(merge_failures, 0) + 1, "
        "last_error = ? WHERE number = ?"
    )
    conn.execute(sql, (last_error, pr_number))
def mark_conflict_permanent(
    conn,
    pr_number: int,
    *,
    last_error: str = None,
    conflict_rebase_attempts: int = None,
):
    """Mark PR as permanently conflicted (terminal — no more retries).

    Both keyword fields are optional; only supplied ones are written, so
    existing values survive when the caller omits them.
    """
    assignments = ["status = 'conflict_permanent'"]
    values = []
    for column, value in (
        ("last_error", last_error),
        ("conflict_rebase_attempts", conflict_rebase_attempts),
    ):
        if value is not None:
            assignments.append(f"{column} = ?")
            values.append(value)
    values.append(pr_number)
    conn.execute("UPDATE prs SET " + ", ".join(assignments) + " WHERE number = ?", values)
def reopen_pr(
    conn,
    pr_number: int,
    *,
    leo_verdict: str = None,
    domain_verdict: str = None,
    last_error: str = None,
    eval_issues: str = None,
    dec_eval_attempts: bool = False,
    reset_for_reeval: bool = False,
    conflict_rebase_attempts: int = None,
):
    """Set PR back to 'open'.

    Covers all reopen scenarios:
    - Transient failure (API error): no extra args
    - Rejection: leo_verdict + last_error + eval_issues
    - Batch overflow: dec_eval_attempts=True
    - Conflict resolved: reset_for_reeval=True

    When reset_for_reeval is True the verdict/error/issue keyword args are
    ignored: verdicts revert to 'pending' and eval_attempts to 0 (approvals
    are stale after a rebase). conflict_rebase_attempts is honoured in
    either mode.
    """
    assignments = ["status = 'open'"]
    values = []
    if reset_for_reeval:
        # Full re-evaluation from scratch — stale verdicts must not survive.
        assignments += [
            "leo_verdict = 'pending'",
            "domain_verdict = 'pending'",
            "eval_attempts = 0",
        ]
    else:
        for column, value in (
            ("leo_verdict", leo_verdict),
            ("domain_verdict", domain_verdict),
            ("last_error", last_error),
            ("eval_issues", eval_issues),
        ):
            if value is not None:
                assignments.append(f"{column} = ?")
                values.append(value)
        if dec_eval_attempts:
            # Give the attempt back (e.g. batch overflow deferred the PR).
            assignments.append("eval_attempts = COALESCE(eval_attempts, 1) - 1")
    if conflict_rebase_attempts is not None:
        assignments.append("conflict_rebase_attempts = ?")
        values.append(conflict_rebase_attempts)
    values.append(pr_number)
    conn.execute("UPDATE prs SET " + ", ".join(assignments) + " WHERE number = ?", values)
def start_review(conn, pr_number: int) -> bool:
    """Atomically claim PR for review (status open -> reviewing).

    Returns True if claimed, False if already claimed by another worker.
    The status filter in the WHERE clause makes the claim race-free: only
    one worker's UPDATE can match the row while it is still 'open'.
    """
    claimed = conn.execute(
        "UPDATE prs SET status = 'reviewing' WHERE number = ? AND status = 'open'",
        (pr_number,),
    ).rowcount
    return claimed > 0

336
tests/test_pr_state.py Normal file
View file

@ -0,0 +1,336 @@
"""Tests for lib/pr_state.py — centralized PR state transitions."""
import asyncio
import sqlite3
import sys
import os
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
# Mock heavy dependencies before importing pr_state
sys.modules.setdefault("aiohttp", MagicMock())
# Add lib parent to path so `lib.pr_state` resolves
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from lib.pr_state import (
approve_pr,
close_pr,
mark_conflict,
mark_conflict_permanent,
mark_merged,
reopen_pr,
start_review,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
def _make_db():
"""Create a minimal in-memory DB with the prs table."""
conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row
conn.execute("""
CREATE TABLE prs (
number INTEGER PRIMARY KEY,
source_path TEXT,
branch TEXT,
status TEXT NOT NULL DEFAULT 'open',
domain TEXT,
auto_merge INTEGER DEFAULT 0,
leo_verdict TEXT DEFAULT 'pending',
domain_verdict TEXT DEFAULT 'pending',
eval_attempts INTEGER DEFAULT 0,
eval_issues TEXT DEFAULT '[]',
merge_cycled INTEGER DEFAULT 0,
merge_failures INTEGER DEFAULT 0,
conflict_rebase_attempts INTEGER DEFAULT 0,
last_error TEXT,
merged_at TEXT,
last_attempt TEXT,
cost_usd REAL DEFAULT 0,
created_at TEXT DEFAULT (datetime('now'))
)
""")
return conn
def _insert_pr(conn, number=100, status="open", domain=None, **kwargs):
"""Insert a test PR row."""
cols = ["number", "status"]
vals = [number, status]
if domain is not None:
cols.append("domain")
vals.append(domain)
for k, v in kwargs.items():
cols.append(k)
vals.append(v)
placeholders = ", ".join(["?"] * len(vals))
col_names = ", ".join(cols)
conn.execute(f"INSERT INTO prs ({col_names}) VALUES ({placeholders})", vals)
conn.commit()
def _get_pr(conn, number=100):
"""Read a PR row back."""
return conn.execute("SELECT * FROM prs WHERE number = ?", (number,)).fetchone()
# ---------------------------------------------------------------------------
# close_pr
# ---------------------------------------------------------------------------
class TestClosePr:
    """close_pr keeps Forgejo and the DB in sync (opt-out for reconciliation only)."""

    def test_close_calls_forgejo_and_updates_db(self):
        db = _make_db()
        _insert_pr(db, 42)
        fake_api = AsyncMock(return_value={})
        with patch("lib.pr_state.forgejo_api", fake_api), \
                patch("lib.pr_state.repo_path", lambda s: f"/repos/test/{s}"):
            asyncio.run(close_pr(db, 42, last_error="test close"))
        pr = _get_pr(db, 42)
        assert pr["status"] == "closed"
        assert pr["last_error"] == "test close"
        # Exactly one Forgejo call: the PATCH that closes the PR.
        fake_api.assert_called_once()

    def test_close_skips_forgejo_when_opted_out(self):
        db = _make_db()
        _insert_pr(db, 42)
        fake_api = AsyncMock(return_value={})
        with patch("lib.pr_state.forgejo_api", fake_api), \
                patch("lib.pr_state.repo_path", lambda s: f"/repos/test/{s}"):
            asyncio.run(close_pr(db, 42, last_error="reconciled", close_on_forgejo=False))
        pr = _get_pr(db, 42)
        assert pr["status"] == "closed"
        fake_api.assert_not_called()

    def test_close_increments_merge_failures(self):
        db = _make_db()
        _insert_pr(db, 42, merge_failures=2)
        fake_api = AsyncMock(return_value={})
        with patch("lib.pr_state.forgejo_api", fake_api), \
                patch("lib.pr_state.repo_path", lambda s: f"/repos/test/{s}"):
            asyncio.run(close_pr(db, 42, merge_cycled=True, inc_merge_failures=True))
        pr = _get_pr(db, 42)
        assert pr["merge_cycled"] == 1
        assert pr["merge_failures"] == 3

    def test_close_without_last_error(self):
        db = _make_db()
        _insert_pr(db, 42, last_error="old error")
        fake_api = AsyncMock(return_value={})
        with patch("lib.pr_state.forgejo_api", fake_api), \
                patch("lib.pr_state.repo_path", lambda s: f"/repos/test/{s}"):
            asyncio.run(close_pr(db, 42))
        pr = _get_pr(db, 42)
        assert pr["status"] == "closed"
        # last_error is left untouched when the caller doesn't supply one
        assert pr["last_error"] == "old error"
# ---------------------------------------------------------------------------
# approve_pr
# ---------------------------------------------------------------------------
class TestApprovePr:
    """approve_pr enforces the non-empty-domain invariant."""

    def test_approve_sets_domain_and_auto_merge(self):
        db = _make_db()
        _insert_pr(db, 50)
        approve_pr(db, 50, domain="internet-finance", auto_merge=1)
        pr = _get_pr(db, 50)
        assert pr["status"] == "approved"
        assert pr["domain"] == "internet-finance"
        assert pr["auto_merge"] == 1

    def test_approve_raises_on_empty_domain(self):
        db = _make_db()
        _insert_pr(db, 50)
        with pytest.raises(ValueError, match="without domain"):
            approve_pr(db, 50, domain="")

    def test_approve_raises_on_none_domain(self):
        db = _make_db()
        _insert_pr(db, 50)
        with pytest.raises(ValueError, match="without domain"):
            approve_pr(db, 50, domain=None)

    def test_approve_sets_verdicts(self):
        db = _make_db()
        _insert_pr(db, 50)
        approve_pr(db, 50, domain="cross-domain", auto_merge=1,
                   leo_verdict="skipped", domain_verdict="skipped")
        pr = _get_pr(db, 50)
        assert pr["leo_verdict"] == "skipped"
        assert pr["domain_verdict"] == "skipped"

    def test_approve_preserves_existing_domain(self):
        db = _make_db()
        _insert_pr(db, 50, domain="living-agents")
        approve_pr(db, 50, domain="cross-domain", auto_merge=0)
        pr = _get_pr(db, 50)
        # COALESCE(domain, ?) keeps the previously stored domain.
        assert pr["domain"] == "living-agents"
# ---------------------------------------------------------------------------
# mark_merged
# ---------------------------------------------------------------------------
class TestMarkMerged:
    """mark_merged stamps merged_at and wipes any stale error."""

    def test_sets_merged_at_and_clears_error(self):
        db = _make_db()
        _insert_pr(db, 60, status="approved", last_error="some old error")
        mark_merged(db, 60)
        pr = _get_pr(db, 60)
        assert pr["status"] == "merged"
        # Exact timestamp doesn't matter — only that it was set.
        assert pr["merged_at"] is not None
        assert pr["last_error"] is None
# ---------------------------------------------------------------------------
# mark_conflict
# ---------------------------------------------------------------------------
class TestMarkConflict:
    """mark_conflict always bumps merge_failures and sets merge_cycled."""

    def test_increments_failures_and_sets_cycled(self):
        db = _make_db()
        _insert_pr(db, 70, status="approved", merge_failures=0)
        mark_conflict(db, 70, last_error="cherry-pick failed")
        pr = _get_pr(db, 70)
        assert pr["status"] == "conflict"
        assert pr["merge_cycled"] == 1
        assert pr["merge_failures"] == 1
        assert pr["last_error"] == "cherry-pick failed"

    def test_accumulates_failures(self):
        db = _make_db()
        _insert_pr(db, 70, merge_failures=5)
        mark_conflict(db, 70)
        pr = _get_pr(db, 70)
        assert pr["merge_failures"] == 6
# ---------------------------------------------------------------------------
# mark_conflict_permanent
# ---------------------------------------------------------------------------
class TestMarkConflictPermanent:
    """mark_conflict_permanent is the terminal conflict state."""

    def test_sets_status_and_attempts(self):
        db = _make_db()
        _insert_pr(db, 80, status="conflict")
        mark_conflict_permanent(db, 80,
                                last_error="rebase failed 3x",
                                conflict_rebase_attempts=3)
        pr = _get_pr(db, 80)
        assert pr["status"] == "conflict_permanent"
        assert pr["last_error"] == "rebase failed 3x"
        assert pr["conflict_rebase_attempts"] == 3
# ---------------------------------------------------------------------------
# reopen_pr
# ---------------------------------------------------------------------------
class TestReopenPr:
    """reopen_pr covers transient retry, rejection, overflow and re-eval reset."""

    def test_simple_reopen(self):
        db = _make_db()
        _insert_pr(db, 90, status="reviewing")
        reopen_pr(db, 90)
        assert _get_pr(db, 90)["status"] == "open"

    def test_reopen_with_rejection(self):
        db = _make_db()
        _insert_pr(db, 90, status="reviewing")
        reopen_pr(db, 90, leo_verdict="skipped",
                  last_error="domain rejected",
                  eval_issues='["factual_error"]')
        pr = _get_pr(db, 90)
        assert pr["status"] == "open"
        assert pr["leo_verdict"] == "skipped"
        assert pr["last_error"] == "domain rejected"
        assert pr["eval_issues"] == '["factual_error"]'

    def test_reopen_dec_eval_attempts(self):
        db = _make_db()
        _insert_pr(db, 90, status="reviewing", eval_attempts=3)
        reopen_pr(db, 90, dec_eval_attempts=True)
        # The claim's attempt is handed back for the deferred retry.
        assert _get_pr(db, 90)["eval_attempts"] == 2

    def test_reopen_reset_for_reeval(self):
        db = _make_db()
        _insert_pr(db, 90, status="conflict",
                   leo_verdict="approve", domain_verdict="approve",
                   eval_attempts=2)
        reopen_pr(db, 90, reset_for_reeval=True,
                  conflict_rebase_attempts=1)
        pr = _get_pr(db, 90)
        assert pr["status"] == "open"
        assert pr["leo_verdict"] == "pending"
        assert pr["domain_verdict"] == "pending"
        assert pr["eval_attempts"] == 0
        assert pr["conflict_rebase_attempts"] == 1
# ---------------------------------------------------------------------------
# start_review
# ---------------------------------------------------------------------------
class TestStartReview:
def test_claims_open_pr(self):
conn = _make_db()
_insert_pr(conn, 100)
assert start_review(conn, 100) is True
row = _get_pr(conn, 100)
assert row["status"] == "reviewing"
def test_rejects_non_open_pr(self):
conn = _make_db()
_insert_pr(conn, 100, status="reviewing")
assert start_review(conn, 100) is False
def test_double_claim_fails(self):
conn = _make_db()
_insert_pr(conn, 100)
assert start_review(conn, 100) is True
assert start_review(conn, 100) is False