"""Circuit breaker state machine — per-stage, backed by SQLite."""

import logging
from datetime import datetime, timezone
from typing import Optional

from . import config

logger = logging.getLogger("pipeline.breaker")

# States
CLOSED = "closed"
OPEN = "open"
HALFOPEN = "halfopen"


class CircuitBreaker:
    """Per-stage circuit breaker persisted in the ``circuit_breakers`` table.

    States:
        CLOSED:   normal operation.
        OPEN:     stage paused (``config.BREAKER_THRESHOLD`` consecutive
                  failures reached).
        HALFOPEN: cooldown expired; ``max_workers()`` reports 1 so a single
                  worker can probe for recovery.

    Every method re-reads the row from the database, so the instance holds
    no cached state of its own.

    NOTE(review): this class never calls ``commit()``; presumably the caller
    or the connection's transaction settings persist the writes — confirm.
    """

    # Columns fetched by _get_state, in SELECT order. Also used to label the
    # values when the connection returns plain tuple rows.
    _STATE_COLUMNS = (
        "state",
        "failures",
        "successes",
        "tripped_at",
        "last_success_at",
    )

    def __init__(self, name: str, conn):
        """Bind the breaker to a stage name and make sure its row exists.

        Args:
            name: Key of this breaker in ``circuit_breakers.name``.
            conn: Database connection exposing ``execute(sql, params)``
                (e.g. a ``sqlite3.Connection``).
        """
        self.name = name
        self.conn = conn
        self._ensure_row()

    def _ensure_row(self):
        """Insert this breaker's row if it is not already present.

        Only ``name`` is supplied; NOTE(review): this assumes the remaining
        columns have schema defaults — confirm against the table definition.
        """
        self.conn.execute(
            "INSERT OR IGNORE INTO circuit_breakers (name) VALUES (?)",
            (self.name,),
        )

    def _get_state(self) -> dict:
        """Return the persisted breaker state as a plain dict.

        Returns:
            A dict with keys ``state``, ``failures``, ``successes``,
            ``tripped_at`` and ``last_success_at``. If the row is missing,
            a default CLOSED state with zeroed counters is returned.
        """
        row = self.conn.execute(
            f"SELECT {', '.join(self._STATE_COLUMNS)} "
            "FROM circuit_breakers WHERE name = ?",
            (self.name,),
        ).fetchone()
        if row is None:
            return {
                "state": CLOSED,
                "failures": 0,
                "successes": 0,
                "tripped_at": None,
                "last_success_at": None,
            }
        if isinstance(row, (tuple, list)):
            # Connection without a mapping row factory: dict(row) would
            # raise, so pair the values with the selected column names.
            return dict(zip(self._STATE_COLUMNS, row))
        # Mapping-style rows (sqlite3.Row, dict row factories).
        return dict(row)

    def _set_state(
        self,
        state: str,
        failures: Optional[int] = None,
        successes: Optional[int] = None,
        tripped_at: Optional[str] = None,
        last_success_at: Optional[str] = None,
    ):
        """Write ``state`` and any provided fields to this breaker's row.

        ``last_update`` is always refreshed to the database's current time.
        An argument left as ``None`` means "leave that column untouched";
        consequently this method cannot reset a column to NULL.

        Args:
            state: New state (CLOSED, OPEN or HALFOPEN).
            failures: New failure count, if it should change.
            successes: New success count, if it should change.
            tripped_at: ISO-8601 timestamp of the trip, if it should change.
            last_success_at: ISO-8601 timestamp of the latest success, if it
                should change.
        """
        assignments = ["state = ?", "last_update = datetime('now')"]
        params = [state]
        optional_columns = (
            ("failures", failures),
            ("successes", successes),
            ("tripped_at", tripped_at),
            ("last_success_at", last_success_at),
        )
        for column, value in optional_columns:
            if value is not None:
                assignments.append(f"{column} = ?")
                params.append(value)
        params.append(self.name)
        # Column names above are hard-coded literals, so interpolating them
        # into the statement is safe; all values are bound as parameters.
        self.conn.execute(
            f"UPDATE circuit_breakers SET {', '.join(assignments)} WHERE name = ?",
            params,
        )

    def allow_request(self) -> bool:
        """Check whether requests are currently allowed.

        Returns:
            True when the breaker is CLOSED or HALFOPEN, or when it is OPEN
            and ``config.BREAKER_COOLDOWN`` seconds have elapsed since
            ``tripped_at`` (in which case it is moved to HALFOPEN first).
            False while OPEN and still cooling down, and also when OPEN with
            no ``tripped_at`` recorded — such a row stays blocked until
            ``reset()`` is called.
        """
        s = self._get_state()
        if s["state"] != OPEN:
            # CLOSED passes traffic; HALFOPEN lets the probe through.
            return True

        tripped_at = s["tripped_at"]
        if not tripped_at:
            return False

        tripped = datetime.fromisoformat(tripped_at)
        if tripped.tzinfo is None:
            # Naive timestamps are interpreted as UTC so they can be
            # compared with the timezone-aware "now" below.
            tripped = tripped.replace(tzinfo=timezone.utc)
        elapsed = (datetime.now(timezone.utc) - tripped).total_seconds()
        if elapsed < config.BREAKER_COOLDOWN:
            return False

        logger.info("Breaker %s: cooldown expired, entering HALFOPEN", self.name)
        self._set_state(HALFOPEN, successes=0)
        return True

    def max_workers(self) -> Optional[int]:
        """Return the worker cap imposed by the current state.

        Returns:
            1 while HALFOPEN (probe with a single worker); ``None`` in every
            other state, meaning the breaker imposes no restriction.
        """
        s = self._get_state()
        if s["state"] == HALFOPEN:
            return 1  # probe with single worker
        return None  # no restriction from breaker

    def record_success(self):
        """Record a successful cycle.

        Updates ``last_success_at`` for stall detection (Vida).

        HALFOPEN: the probe succeeded, so the breaker closes and both
        counters are zeroed. CLOSED: the failure streak is cleared if there
        was one. OPEN: the success is ignored and nothing is written.
        """
        s = self._get_state()
        now = datetime.now(timezone.utc).isoformat()
        if s["state"] == HALFOPEN:
            logger.info("Breaker %s: HALFOPEN probe succeeded, closing", self.name)
            self._set_state(CLOSED, failures=0, successes=0, last_success_at=now)
        elif s["state"] == CLOSED:
            if s["failures"] > 0:
                self._set_state(CLOSED, failures=0, last_success_at=now)
            else:
                self._set_state(CLOSED, last_success_at=now)

    def record_failure(self):
        """Record a failed cycle.

        HALFOPEN: the probe failed, so the breaker reopens with a fresh
        ``tripped_at`` (restarting the cooldown). CLOSED: the failure count
        is incremented and the breaker opens once it reaches
        ``config.BREAKER_THRESHOLD``. OPEN: only the failure count is
        incremented; ``tripped_at`` is left alone so the cooldown is not
        extended.
        """
        s = self._get_state()
        state = s["state"]
        new_failures = s["failures"] + 1
        if state == HALFOPEN:
            logger.warning("Breaker %s: HALFOPEN probe failed, reopening", self.name)
            self._set_state(
                OPEN,
                failures=new_failures,
                tripped_at=datetime.now(timezone.utc).isoformat(),
            )
        elif state == CLOSED:
            if new_failures >= config.BREAKER_THRESHOLD:
                logger.warning(
                    "Breaker %s: threshold reached (%d failures), opening",
                    self.name,
                    new_failures,
                )
                self._set_state(
                    OPEN,
                    failures=new_failures,
                    tripped_at=datetime.now(timezone.utc).isoformat(),
                )
            else:
                self._set_state(CLOSED, failures=new_failures)
        elif state == OPEN:
            self._set_state(OPEN, failures=new_failures)

    def reset(self):
        """Force the breaker back to CLOSED and zero both counters.

        ``tripped_at`` and ``last_success_at`` are left as they were.
        """
        logger.info("Breaker %s: force reset to CLOSED", self.name)
        self._set_state(CLOSED, failures=0, successes=0)