Merge branch 'epimetheus/eval-cost-tracking'
commit 6361c7e9e8
2 changed files with 74 additions and 14 deletions
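The changes thread a running per-PR total, `pr_cost`, through `evaluate_pr`: each `costs.record_usage(...)` call is now expected to return the USD cost it just recorded, the returns are accumulated, and every exit path flushes the total into `prs.cost_usd`. A minimal sketch of that assumed `record_usage` contract follows (the pricing table, model name, and `usage` schema are illustrative placeholders, not the real `costs` module):

```python
# Sketch of the costs.record_usage contract this diff relies on: log one LLM call
# and return its USD cost so callers can accumulate a per-PR total.
# The pricing table and usage schema below are hypothetical placeholders.
_PRICES_PER_MTOK: dict[str, tuple[float, float]] = {
    "example/triage-model": (0.25, 1.25),  # ($ per 1M input tokens, $ per 1M output tokens)
}

def record_usage(conn, model: str, purpose: str, *, input_tokens: int = 0,
                 output_tokens: int = 0, backend: str = "openrouter") -> float:
    """Record one model call and return its cost in USD."""
    in_price, out_price = _PRICES_PER_MTOK.get(model, (0.0, 0.0))
    cost = (input_tokens * in_price + output_tokens * out_price) / 1_000_000
    conn.execute(
        "INSERT INTO usage (model, purpose, backend, input_tokens, output_tokens, cost_usd)"
        " VALUES (?, ?, ?, ?, ?, ?)",
        (model, purpose, backend, input_tokens, output_tokens, cost),
    )
    return cost
```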
@@ -493,6 +493,9 @@ async def _dispose_rejected_pr(conn, pr_number: int, eval_attempts: int, all_iss
 async def evaluate_pr(conn, pr_number: int, tier: str = None) -> dict:
     """Evaluate a single PR. Returns result dict."""
+    from . import costs
+    pr_cost = 0.0
+
     # Check eval attempt budget before claiming
     row = conn.execute("SELECT eval_attempts FROM prs WHERE number = ?", (pr_number,)).fetchone()
     eval_attempts = (row["eval_attempts"] or 0) if row else 0
 
@@ -608,10 +611,8 @@ async def evaluate_pr(conn, pr_number: int, tier: str = None) -> dict:
             json.dumps({"pr": pr_number, "tier": tier}),
         )
     else:
-        tier, triage_usage = await triage_pr(diff)
-        # Record triage cost
-        from . import costs
-        costs.record_usage(
+        tier, triage_usage, _triage_reason = await triage_pr(diff)
+        pr_cost += costs.record_usage(
             conn, config.TRIAGE_MODEL, "eval_triage",
             input_tokens=triage_usage.get("prompt_tokens", 0),
             output_tokens=triage_usage.get("completion_tokens", 0),
@@ -674,6 +675,8 @@ async def evaluate_pr(conn, pr_number: int, tier: str = None) -> dict:
         # OpenRouter failure (timeout, error) — revert to open for retry.
         # NOT a rate limit — don't trigger 15-min backoff, just skip this PR.
         conn.execute("UPDATE prs SET status = 'open' WHERE number = ?", (pr_number,))
+        if pr_cost > 0:
+            conn.execute("UPDATE prs SET cost_usd = cost_usd + ? WHERE number = ?", (pr_cost, pr_number))
         return {"pr": pr_number, "skipped": True, "reason": "openrouter_failed"}

     domain_verdict = _parse_verdict(domain_review, agent)
@@ -714,6 +717,15 @@ async def evaluate_pr(conn, pr_number: int, tier: str = None) -> dict:
         # Disposition: check if this PR should be terminated or kept open
         await _dispose_rejected_pr(conn, pr_number, eval_attempts, domain_issues)

+        if domain_verdict != "skipped":
+            pr_cost += costs.record_usage(
+                conn, config.EVAL_DOMAIN_MODEL, "eval_domain",
+                input_tokens=domain_usage.get("prompt_tokens", 0),
+                output_tokens=domain_usage.get("completion_tokens", 0),
+                backend="openrouter",
+            )
+        if pr_cost > 0:
+            conn.execute("UPDATE prs SET cost_usd = cost_usd + ? WHERE number = ?", (pr_cost, pr_number))
         return {
             "pr": pr_number,
             "domain_verdict": domain_verdict,
@@ -731,6 +743,15 @@ async def evaluate_pr(conn, pr_number: int, tier: str = None) -> dict:
     if leo_review is None:
         # DEEP: Opus rate limited (queue for later). STANDARD: OpenRouter failed (skip, retry next cycle).
         conn.execute("UPDATE prs SET status = 'open' WHERE number = ?", (pr_number,))
+        if domain_verdict != "skipped":
+            pr_cost += costs.record_usage(
+                conn, config.EVAL_DOMAIN_MODEL, "eval_domain",
+                input_tokens=domain_usage.get("prompt_tokens", 0),
+                output_tokens=domain_usage.get("completion_tokens", 0),
+                backend="openrouter",
+            )
+        if pr_cost > 0:
+            conn.execute("UPDATE prs SET cost_usd = cost_usd + ? WHERE number = ?", (pr_cost, pr_number))
         reason = "opus_rate_limited" if tier == "DEEP" else "openrouter_failed"
         return {"pr": pr_number, "skipped": True, "reason": reason}
 
@@ -834,10 +855,8 @@ async def evaluate_pr(conn, pr_number: int, tier: str = None) -> dict:
         await _dispose_rejected_pr(conn, pr_number, eval_attempts, all_issues)

     # Record cost (only for reviews that actually ran)
-    from . import costs
-
     if domain_verdict != "skipped":
-        costs.record_usage(
+        pr_cost += costs.record_usage(
             conn, config.EVAL_DOMAIN_MODEL, "eval_domain",
             input_tokens=domain_usage.get("prompt_tokens", 0),
             output_tokens=domain_usage.get("completion_tokens", 0),
@@ -845,15 +864,23 @@ async def evaluate_pr(conn, pr_number: int, tier: str = None) -> dict:
         )
     if leo_verdict not in ("skipped",):
         if tier == "DEEP":
-            costs.record_usage(conn, config.EVAL_LEO_MODEL, "eval_leo", backend="max")
+            pr_cost += costs.record_usage(
+                conn, config.EVAL_LEO_MODEL, "eval_leo",
+                input_tokens=leo_usage.get("prompt_tokens", 0),
+                output_tokens=leo_usage.get("completion_tokens", 0),
+                backend="max",
+            )
         else:
-            costs.record_usage(
+            pr_cost += costs.record_usage(
                 conn, config.EVAL_LEO_STANDARD_MODEL, "eval_leo",
                 input_tokens=leo_usage.get("prompt_tokens", 0),
                 output_tokens=leo_usage.get("completion_tokens", 0),
                 backend="openrouter",
             )

+    if pr_cost > 0:
+        conn.execute("UPDATE prs SET cost_usd = cost_usd + ? WHERE number = ?", (pr_cost, pr_number))
+
     return {
         "pr": pr_number,
         "tier": tier,
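With that contract in place, every exit path in `evaluate_pr` ends with the same flush idiom. A condensed form of the pattern, written as a hypothetical helper (the diff inlines it at each return instead):

```python
# Hypothetical helper condensing the flush idiom repeated on each exit path above:
# persist the accumulated spend exactly once, just before returning.
def _flush_pr_cost(conn, pr_number: int, pr_cost: float) -> None:
    if pr_cost > 0:
        conn.execute(
            "UPDATE prs SET cost_usd = cost_usd + ? WHERE number = ?",
            (pr_cost, pr_number),
        )
```

Incrementing `cost_usd` rather than overwriting it keeps lifetime totals correct when a skipped PR is re-evaluated on a later cycle. The remaining hunks touch the extraction module, which gains typed claim edges and a post-write connect step.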
@@ -37,6 +37,7 @@ from .domains import agent_for_domain
 from .extraction_prompt import build_extraction_prompt
 from .forgejo import api as forgejo_api
 from .llm import openrouter_call
+from .connect import connect_new_claims
 from .post_extract import load_existing_claims_from_repo, validate_and_fix_claims
 from .worktree_lock import async_main_worktree_lock
 
@@ -225,7 +226,29 @@ def _build_claim_content(claim: dict, agent: str) -> str:
     body = claim.get("body", "")
     scope = claim.get("scope", "")
     sourcer = claim.get("sourcer", "")
-    related = claim.get("related_claims", [])
+    related_claims = claim.get("related_claims", [])
+    connections = claim.get("connections", [])
+
+    edge_fields = {"supports": [], "challenges": [], "related": []}
+    for conn in connections:
+        target = conn.get("target", "")
+        rel = conn.get("relationship", "related")
+        if target and rel in edge_fields:
+            target = target.replace(".md", "")
+            if target not in edge_fields[rel]:
+                edge_fields[rel].append(target)
+    for r in related_claims[:5]:
+        r_clean = r.replace(".md", "")
+        if r_clean not in edge_fields["related"]:
+            edge_fields["related"].append(r_clean)
+
+    edge_lines = []
+    for edge_type in ("supports", "challenges", "related"):
+        targets = edge_fields[edge_type]
+        if targets:
+            edge_lines.append(f"{edge_type}:")
+            for t in targets:
+                edge_lines.append(f"  - {t}")

     lines = [
         "---",
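For a claim carrying one typed connection and a couple of legacy `related_claims`, the logic above strips `.md` suffixes, deduplicates, and renders like this (hypothetical claim data, results shown as comments):

```python
# Hypothetical input for the edge-building logic in _build_claim_content:
claim = {
    "connections": [
        {"target": "coffee-improves-recall.md", "relationship": "supports"},
        {"target": "caffeine-tolerance.md"},  # missing relationship defaults to "related"
    ],
    "related_claims": ["sleep-debt.md", "caffeine-tolerance.md"],  # legacy field, capped at 5
}
# edge_fields -> {"supports": ["coffee-improves-recall"],
#                 "challenges": [],
#                 "related": ["caffeine-tolerance", "sleep-debt"]}
# edge_lines renders into the frontmatter as:
#   supports:
#     - coffee-improves-recall
#   related:
#     - caffeine-tolerance
#     - sleep-debt
```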
@@ -242,10 +265,7 @@ def _build_claim_content(claim: dict, agent: str) -> str:
         lines.append(f"scope: {scope}")
     if sourcer:
         lines.append(f'sourcer: "{sourcer}"')
-    if related:
-        lines.append("related_claims:")
-        for r in related:
-            lines.append(f'  - "[[{r}]]"')
+    lines.extend(edge_lines)
     lines.append("---")
     lines.append("")
     lines.append(f"# {title}")
@@ -456,6 +476,19 @@ async def _extract_one_source(
             await _archive_source(source_path, domain, "null-result")
             return 0, 0

+        # Post-write: connect new claims to existing KB via vector search (non-fatal)
+        claim_paths = [str(worktree / f) for f in files_written if f.startswith("domains/")]
+        if claim_paths:
+            try:
+                connect_stats = connect_new_claims(claim_paths)
+                if connect_stats["connected"] > 0:
+                    logger.info(
+                        "Extract-connect: %d/%d claims → %d edges",
+                        connect_stats["connected"], len(claim_paths), connect_stats["edges_added"],
+                    )
+            except Exception:
+                logger.warning("Extract-connect failed (non-fatal)", exc_info=True)
+
         # Stage and commit
         for f in files_written:
             await _git("add", f, cwd=str(EXTRACT_WORKTREE))
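The call site above only reveals the shape of `connect_new_claims` indirectly. A plausible reading of its contract, inferred from the stats keys the logging reads (not confirmed by the diff; the real signature may differ):

```python
# Inferred (not confirmed) interface for connect_new_claims, based on the
# stats keys read by the logging above.
from typing import TypedDict

class ConnectStats(TypedDict):
    connected: int    # how many of the new claims gained at least one edge
    edges_added: int  # total edges written into the existing knowledge base

def connect_new_claims(claim_paths: list[str]) -> ConnectStats:
    """Link freshly extracted claims to existing KB notes via vector search."""
    raise NotImplementedError  # illustrative stub only
```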