Skip to main content

split_brain_harness/
serve.rs

//! OpenAI-compatible HTTP proxy server.
//!
//! Exposes `POST /v1/chat/completions` so any OpenAI-speaking client
//! (LangChain, Continue.dev, Cursor, custom agents) can route through the
//! soul-injected telemetry pipeline with zero code changes. `GET /health`
//! (liveness) and `GET /metrics` (Prometheus counters) are served as well.
//!
//! Telemetry is returned two ways:
//!   1. The `content` field carries both the model's answer AND a
//!      `<!-- sbh-telemetry: {...} -->` HTML comment at the end.
//!   2. The `x-sbh-telemetry` response header carries the same JSON,
//!      URL-encoded (omitted when the encoded value is not a valid header value).
//!
//! Hardening:
//!   - `SBH_SERVE_KEY`      — require a Bearer token on `/v1/chat/completions`
//!                            and `/metrics` (`/health` stays unauthenticated)
//!   - `SBH_SERVE_RATE`     — max requests/min per IP (default 60)
//!   - `SBH_SERVE_MAX_BODY` — max body bytes (default 1 MiB)
//!
//! Multi-turn session tracking:
//!   Pass `x-sbh-session: <id>` on requests to link turns into a session
//!   (IDs must be 1–64 characters from `[A-Za-z0-9_-]`; otherwise a fresh ID
//!   is minted). The response echoes the session ID. If the manipulation_risk
//!   signal shows an upward trend across turns (slow-boil escalation), the
//!   response sets `x-sbh-session-alert: escalation_detected`. Sessions expire
//!   after 30 minutes of inactivity: an expired session is dropped when it is
//!   next accessed, and a background sweep evicts the rest every 5 minutes.
//!
//! Start with: `sbh serve [--listen <addr>]`   default: 127.0.0.1:8088
25use std::collections::{HashMap, VecDeque};
26use std::net::{IpAddr, SocketAddr};
27use std::sync::atomic::{AtomicU64, Ordering};
28use std::sync::{Arc, Mutex};
29use std::time::{Duration, Instant};
30
31use axum::{
32    extract::{ConnectInfo, DefaultBodyLimit, State},
33    http::{HeaderMap, HeaderValue, StatusCode},
34    response::IntoResponse,
35    routing::{get, post},
36    Json, Router,
37};
38use serde::{Deserialize, Serialize};
39
40use crate::{analyze, session_log, types::Config};
41use anyhow::Context as _;
42
43// ---------------------------------------------------------------------------
44// Request / response types (OpenAI wire format subset)
45// ---------------------------------------------------------------------------
46
/// Incoming `POST /v1/chat/completions` body (OpenAI wire format subset).
///
/// Only `model`, `messages` and `stream` are read; every other top-level
/// field is swallowed by `_extra` through `#[serde(flatten)]`.
#[derive(Debug, Deserialize)]
pub struct ChatRequest {
    /// Model name echoed back in the response; when absent, the handler
    /// falls back to the configured `model_name`.
    pub model: Option<String>,
    /// Conversation history. Only the last message with role `user` is analyzed.
    pub messages: Vec<ChatMessage>,
    /// Streaming flag, `false` when omitted. `true` is rejected with 400
    /// because this server only produces whole responses.
    #[serde(default)]
    pub stream: bool,
    // All other fields are accepted and ignored
    #[serde(flatten)]
    pub _extra: serde_json::Value,
}
57
58#[derive(Debug, Deserialize, Serialize, Clone)]
59pub struct ChatMessage {
60    pub role: String,
61    pub content: String,
62}
63
/// Non-streaming `chat.completion` response body (OpenAI wire format subset).
///
/// As built by the handler: `id` is `"sbh-"` followed by `monotonic_id()`,
/// `object` is always `"chat.completion"`, `created` comes from `unix_now()`
/// (presumably Unix seconds — confirm), and `choices` holds exactly one entry.
#[derive(Debug, Serialize)]
pub struct ChatResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    /// The request's `model`, or the configured `model_name` when it had none.
    pub model: String,
    pub choices: Vec<ChatChoice>,
    pub usage: Usage,
}
73
/// One completion choice. The handler always emits a single choice with
/// `index` 0, an `assistant` message, and `finish_reason` `"stop"`.
#[derive(Debug, Serialize)]
pub struct ChatChoice {
    pub index: u32,
    /// Assistant reply: analysis summary followed by the
    /// `<!-- sbh-telemetry: ... -->` comment.
    pub message: ChatMessage,
    pub finish_reason: String,
}
80
/// Token accounting in the OpenAI response shape.
///
/// These are estimates, not tokenizer output: the handler derives them by
/// dividing byte lengths (of the user input and of the telemetry JSON) by 4.
#[derive(Debug, Serialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
87
/// Envelope for OpenAI-style error responses: serializes as `{"error": {...}}`.
#[derive(Debug, Serialize)]
struct ErrorBody {
    error: ErrorDetail,
}
92
/// Error payload, serialized as `{"message": ..., "type": ...}`.
#[derive(Debug, Serialize)]
struct ErrorDetail {
    /// Human-readable description returned to the client.
    message: String,
    /// Machine-readable category such as `authentication_error`,
    /// `rate_limit_error` or `invalid_request_error`. Named `kind` in Rust
    /// because `type` is a keyword.
    #[serde(rename = "type")]
    kind: String,
}
99
100// ---------------------------------------------------------------------------
101// Session tracking — multi-turn manipulation detection
102// ---------------------------------------------------------------------------
103
/// Depth of each session's ring buffer: only this many recent turns feed
/// escalation detection.
const SESSION_MAX_TURNS: usize = 10;
/// Idle time after which a session counts as expired (30 minutes).
const SESSION_TTL: Duration = Duration::from_secs(1_800);
/// Upper bound on sessions held in memory at once. A request that would
/// create a session past this cap is refused, keeping the map bounded.
const SESSION_MAX_COUNT: usize = 10_000;
/// Period of the background task that drops expired sessions, so request
/// handling never has to walk the whole session map (5 minutes).
const SESSION_SWEEP_INTERVAL: Duration = Duration::from_secs(300);
112
113// ---------------------------------------------------------------------------
114// Rate limiter — 16-shard sliding window, no extra deps
115// ---------------------------------------------------------------------------
116
const RATE_LIMITER_SHARDS: usize = 16;
/// Hard cap on total tracked IPs across all shards. Beyond this, new IPs
/// are passed through untracked rather than allocating unbounded memory.
const MAX_TRACKED_IPS: usize = 50_000;
/// Each shard's share of `MAX_TRACKED_IPS`.
const MAX_IPS_PER_SHARD: usize = MAX_TRACKED_IPS / RATE_LIMITER_SHARDS;

/// Per-IP sliding-window (60 s) rate limiter split across
/// `RATE_LIMITER_SHARDS` independently locked maps, so concurrent requests
/// from different IPs rarely contend on the same mutex.
///
/// Each tracked IP maps to the timestamps of its *accepted* requests inside
/// the window; rejected requests are not recorded.
struct ShardedRateLimiter {
    shards: Box<[Mutex<HashMap<IpAddr, VecDeque<Instant>>>; RATE_LIMITER_SHARDS]>,
    /// Randomly keyed (per process) hasher used to pick a shard. With a fixed
    /// key the IP→shard mapping is predictable. A client controlling many
    /// addresses (e.g. an IPv6 prefix) could then aim them all at one shard,
    /// fill it with only `MAX_IPS_PER_SHARD` addresses, and push every other
    /// IP hashed there onto the untracked pass-through path.
    shard_hasher: std::collections::hash_map::RandomState,
}

impl ShardedRateLimiter {
    fn new() -> Self {
        Self {
            shards: Box::new(std::array::from_fn(|_| Mutex::new(HashMap::new()))),
            shard_hasher: std::collections::hash_map::RandomState::new(),
        }
    }

    /// Index of the shard that owns `ip`. It is stable for the lifetime of
    /// this limiter, because every `build_hasher` call reuses the same keys.
    fn shard_idx(&self, ip: IpAddr) -> usize {
        use std::hash::{BuildHasher, Hash, Hasher};

        let mut hasher = self.shard_hasher.build_hasher();
        ip.hash(&mut hasher);
        (hasher.finish() as usize) % RATE_LIMITER_SHARDS
    }

    /// Records a request from `ip` and returns whether it is allowed, i.e.
    /// whether `ip` has had fewer than `max_per_minute` accepted requests in
    /// the trailing 60 seconds.
    ///
    /// If `ip` is new and its shard already tracks `MAX_IPS_PER_SHARD`
    /// addresses, one entry whose newest timestamp is outside the window is
    /// evicted. If there is no such entry, the request is allowed without
    /// being tracked. A `max_per_minute` of 0 rejects every tracked request.
    fn check(&self, ip: IpAddr, max_per_minute: u32) -> bool {
        let idx = self.shard_idx(ip);
        let now = Instant::now();
        let window = Duration::from_secs(60);
        let mut shard = self.shards[idx].lock().unwrap_or_else(|e| e.into_inner());
        let is_new = !shard.contains_key(&ip);
        if is_new && shard.len() >= MAX_IPS_PER_SHARD {
            // Shard full — try to evict one expired entry first.
            // If none are expired, pass request through untracked: a sustained
            // attack filling all shards still hits per-session caps.
            let expired = shard
                .iter()
                .find(|(_, q)| q.back().map_or(true, |&t| now.duration_since(t) > window))
                .map(|(k, _)| *k);
            match expired {
                Some(evict) => {
                    shard.remove(&evict);
                }
                None => return true,
            }
        }
        let queue = shard.entry(ip).or_default();
        // Drop timestamps that have slid out of the window.
        while let Some(&front) = queue.front() {
            if now.duration_since(front) > window {
                queue.pop_front();
            } else {
                break;
            }
        }
        if queue.len() >= max_per_minute as usize {
            return false;
        }
        queue.push_back(now);
        true
    }
}
177
/// One analyzed turn in a session. Only the `manipulation_risk` label from
/// the turn's telemetry is kept, since that is all escalation detection reads.
#[derive(Debug, Clone)]
struct SessionTurn {
    /// Risk label from the analysis; `risk_score` maps "high"/"medium" to
    /// 2/1 and anything else to 0.
    manipulation_risk: String,
}
183
/// Ring buffer of the most recent turns for one session.
#[derive(Debug)]
struct SessionHistory {
    /// Oldest turn at the front; `push` caps the length at `SESSION_MAX_TURNS`.
    turns: VecDeque<SessionTurn>,
    /// Creation time or time of the latest `push`; compared against
    /// `SESSION_TTL` to decide whether the session has expired.
    last_seen: Instant,
}
190
191impl SessionHistory {
192    fn new() -> Self {
193        Self {
194            turns: VecDeque::new(),
195            last_seen: Instant::now(),
196        }
197    }
198
199    fn push(&mut self, risk: &str) {
200        let now = Instant::now();
201        self.last_seen = now;
202        if self.turns.len() >= SESSION_MAX_TURNS {
203            self.turns.pop_front();
204        }
205        self.turns.push_back(SessionTurn {
206            manipulation_risk: risk.to_string(),
207        });
208    }
209
210    /// Returns true when the current session shows an upward escalation in
211    /// manipulation_risk compared to the historical mean. Requires ≥3 turns.
212    ///
213    /// Algorithm: map risk to 0/1/2, compute mean of all-but-last turns.
214    /// Escalation fires when the latest turn scores above the historical mean
215    /// by more than 0.5 AND is not "low".
216    fn is_escalating(&self) -> bool {
217        if self.turns.len() < 3 {
218            return false;
219        }
220        let scores: Vec<f64> = self
221            .turns
222            .iter()
223            .map(|t| risk_score(&t.manipulation_risk))
224            .collect();
225        let n = scores.len();
226        let historical_mean: f64 = scores[..n - 1].iter().sum::<f64>() / (n - 1) as f64;
227        let current = scores[n - 1];
228        current > (historical_mean + 0.5) && current >= 1.0
229    }
230
231    fn turn_count(&self) -> usize {
232        self.turns.len()
233    }
234
235    /// Returns (trajectory, historical_mean) — the same values used by
236    /// `is_escalating`, exposed so the caller can write a session log entry.
237    fn risk_summary(&self) -> (Vec<String>, f64) {
238        let trajectory: Vec<String> = self
239            .turns
240            .iter()
241            .map(|t| t.manipulation_risk.clone())
242            .collect();
243        let n = trajectory.len();
244        if n < 2 {
245            return (trajectory, 0.0);
246        }
247        let scores: Vec<f64> = self
248            .turns
249            .iter()
250            .map(|t| risk_score(&t.manipulation_risk))
251            .collect();
252        let historical_mean = scores[..n - 1].iter().sum::<f64>() / (n - 1) as f64;
253        (trajectory, historical_mean)
254    }
255}
256
/// Maps a `manipulation_risk` label to a numeric score for trend analysis:
/// `"high"` → 2.0 and `"medium"` → 1.0. Anything else (including `"low"` and
/// unrecognized or differently-cased labels) → 0.0.
fn risk_score(risk: &str) -> f64 {
    const LEVELS: [(&str, f64); 2] = [("high", 2.0), ("medium", 1.0)];
    LEVELS
        .iter()
        .find(|(label, _)| *label == risk)
        .map_or(0.0, |&(_, score)| score)
}
264
265// ---------------------------------------------------------------------------
266// Witness status cache — polled every 30 s by a background task
267// ---------------------------------------------------------------------------
268
/// `witness status` last exited successfully.
const WITNESS_ACTIVE: u8 = 0;
/// `witness status` last failed, or could not be started.
const WITNESS_INACTIVE: u8 = 1;
/// Initial value; stays in place when no audit path is configured, because
/// the poller is then never started.
const WITNESS_UNCONFIGURED: u8 = 2;

/// Renders a cached witness status code as the `x-sbh-witness` header value.
/// Any code other than active/inactive (including `WITNESS_UNCONFIGURED`)
/// renders as `"not-configured"`.
fn witness_status_str(v: u8) -> &'static str {
    if v == WITNESS_ACTIVE {
        "active"
    } else if v == WITNESS_INACTIVE {
        "inactive"
    } else {
        "not-configured"
    }
}
280
281/// Spawn a background task that polls `witness status` once at startup and
282/// every 30 seconds thereafter. The result is stored in `cache` (an AtomicU8)
283/// so that the hot request path never blocks on a subprocess.
284///
285/// Only spawned when `audit_path` is Some — otherwise status is fixed to
286/// WITNESS_UNCONFIGURED.
287fn spawn_witness_poller(cache: Arc<std::sync::atomic::AtomicU8>) {
288    tokio::spawn(async move {
289        loop {
290            let result = tokio::process::Command::new("witness")
291                .arg("status")
292                .stdout(std::process::Stdio::null())
293                .stderr(std::process::Stdio::null())
294                .status()
295                .await;
296            let val = match result {
297                Ok(s) if s.success() => WITNESS_ACTIVE,
298                _ => WITNESS_INACTIVE,
299            };
300            cache.store(val, std::sync::atomic::Ordering::Relaxed);
301            tokio::time::sleep(tokio::time::Duration::from_secs(30)).await;
302        }
303    });
304}
305
306// ---------------------------------------------------------------------------
307// Metrics — lock-free counters, Prometheus text exposition
308// ---------------------------------------------------------------------------
309
/// Process-wide counters exported by `GET /metrics`.
///
/// Every field is an independent `AtomicU64`, so handlers update them through
/// a shared `Arc<Metrics>` without taking a lock.
#[derive(Default)]
pub struct Metrics {
    /// Requests that reached the `POST /v1/chat/completions` handler.
    pub requests_total: AtomicU64,
    /// Handler responses with status 200 OK.
    pub requests_ok_total: AtomicU64,
    /// Error responses (4xx/5xx) recorded by the handler.
    pub requests_error_total: AtomicU64,
    /// Requests rejected for a missing or wrong serve key.
    pub auth_failures_total: AtomicU64,
    /// Requests rejected by the per-IP rate limiter.
    pub rate_limit_total: AtomicU64,
    /// Slow-boil session escalation events detected.
    pub escalations_total: AtomicU64,
}

impl Metrics {
    /// Adds one to `counter`. `Relaxed` ordering suffices: the counters are
    /// independent and are only read for reporting.
    fn inc(counter: &AtomicU64) {
        counter.fetch_add(1, Ordering::Relaxed);
    }

    /// Renders all counters plus the `active_sessions` and `uptime_secs`
    /// gauges in Prometheus text exposition format. Each metric gets a
    /// `# HELP` line, a `# TYPE` line and one sample line, each terminated by `\n`.
    ///
    /// Lines are written straight into the output buffer with `writeln!`,
    /// avoiding the temporary `String` per line that `format!` would allocate.
    pub fn render(&self, active_sessions: usize, uptime_secs: u64) -> String {
        use std::fmt::Write as _;

        let load = |counter: &AtomicU64| counter.load(Ordering::Relaxed);
        let series: [(&str, &str, &str, u64); 8] = [
            (
                "sbh_requests_total",
                "counter",
                "Total POST /v1/chat/completions requests",
                load(&self.requests_total),
            ),
            (
                "sbh_requests_ok_total",
                "counter",
                "Requests that returned 200 OK",
                load(&self.requests_ok_total),
            ),
            (
                "sbh_requests_error_total",
                "counter",
                "Requests that returned 4xx or 5xx",
                load(&self.requests_error_total),
            ),
            (
                "sbh_auth_failures_total",
                "counter",
                "Requests rejected for missing/invalid auth key",
                load(&self.auth_failures_total),
            ),
            (
                "sbh_rate_limit_total",
                "counter",
                "Requests rejected by per-IP rate limiter",
                load(&self.rate_limit_total),
            ),
            (
                "sbh_escalations_total",
                "counter",
                "Slow-boil session escalation events detected",
                load(&self.escalations_total),
            ),
            (
                "sbh_active_sessions",
                "gauge",
                "Sessions currently held in memory",
                active_sessions as u64,
            ),
            (
                "sbh_uptime_seconds",
                "gauge",
                "Seconds since sbh serve started",
                uptime_secs,
            ),
        ];

        let mut out = String::with_capacity(512);
        for (name, kind, help, value) in series {
            // `fmt::Write` into a `String` cannot fail; the results are ignored.
            let _ = writeln!(out, "# HELP {name} {help}");
            let _ = writeln!(out, "# TYPE {name} {kind}");
            let _ = writeln!(out, "{name} {value}");
        }
        out
    }
}
385
386// ---------------------------------------------------------------------------
387// Server state
388// ---------------------------------------------------------------------------
389
/// Shared state handed to every route handler.
///
/// `Clone` is required by axum's `State` extractor. Every heavyweight or
/// shared-mutable member sits behind an `Arc`, so a clone mostly bumps
/// reference counts (`session_log_path` is an owned `Option<String>` and is copied).
#[derive(Clone)]
pub struct ServeState {
    /// Server configuration. The chat handler clones it per request so it can
    /// substitute a caller-supplied upstream API key.
    config: Arc<Config>,
    /// Per-IP sliding window — sharded to avoid global lock contention.
    rate_limiter: Arc<ShardedRateLimiter>,
    /// Per-session turn history for multi-turn escalation detection.
    /// Expired entries are removed lazily on access and by a periodic sweep task.
    sessions: Arc<Mutex<HashMap<String, SessionHistory>>>,
    /// Path of the session escalation log; an entry is appended on every
    /// escalation event. `None` disables logging.
    session_log_path: Option<String>,
    /// Prometheus-style counters, shared across handler clones.
    metrics: Arc<Metrics>,
    /// Timestamp of server start, used to compute uptime.
    start_time: Arc<Instant>,
    /// Cached witness status code (`WITNESS_ACTIVE` / `WITNESS_INACTIVE` /
    /// `WITNESS_UNCONFIGURED`), rendered by `witness_status_str` as
    /// "active" | "inactive" | "not-configured". Refreshed every 30 s by a
    /// background task, but only when an audit path is configured.
    witness_status: Arc<std::sync::atomic::AtomicU8>,
}
407
408// ---------------------------------------------------------------------------
409// Route handler
410// ---------------------------------------------------------------------------
411
412async fn chat_completions(
413    State(state): State<ServeState>,
414    ConnectInfo(remote_addr): ConnectInfo<SocketAddr>,
415    headers: HeaderMap,
416    Json(req): Json<ChatRequest>,
417) -> impl IntoResponse {
418    let config = &*state.config;
419    Metrics::inc(&state.metrics.requests_total);
420
421    // --- serve-level auth (checked before anything else) ---
422    if let Some(sk) = &config.serve_key {
423        let provided = headers
424            .get("authorization")
425            .and_then(|v| v.to_str().ok())
426            .map(|s| s.trim_start_matches("Bearer ").trim().to_string())
427            .unwrap_or_default();
428        if &provided != sk {
429            Metrics::inc(&state.metrics.auth_failures_total);
430            Metrics::inc(&state.metrics.requests_error_total);
431            let body = ErrorBody {
432                error: ErrorDetail {
433                    message: "Unauthorized: invalid or missing SBH serve key.".into(),
434                    kind: "authentication_error".into(),
435                },
436            };
437            return (
438                StatusCode::UNAUTHORIZED,
439                HeaderMap::new(),
440                Json(serde_json::to_value(body).unwrap_or_else(|_| serde_json::json!({"error":{"message":"serialization error","type":"internal_error"}}))),
441            )
442                .into_response();
443        }
444    }
445
446    // --- per-IP rate limit ---
447    let ip = remote_addr.ip();
448    if !state.rate_limiter.check(ip, config.serve_rate_limit) {
449        Metrics::inc(&state.metrics.rate_limit_total);
450        Metrics::inc(&state.metrics.requests_error_total);
451        let body = ErrorBody {
452            error: ErrorDetail {
453                message: format!(
454                    "Rate limit exceeded: max {} requests/min per IP.",
455                    config.serve_rate_limit
456                ),
457                kind: "rate_limit_error".into(),
458            },
459        };
460        return (
461            StatusCode::TOO_MANY_REQUESTS,
462            HeaderMap::new(),
463            Json(serde_json::to_value(body).unwrap_or_else(|_| serde_json::json!({"error":{"message":"serialization error","type":"internal_error"}}))),
464        )
465            .into_response();
466    }
467
468    // --- streaming not supported ---
469    if req.stream {
470        let body = ErrorBody {
471            error: ErrorDetail {
472                message: "sbh serve does not support streaming. Set stream=false.".into(),
473                kind: "unsupported_parameter".into(),
474            },
475        };
476        return (
477            StatusCode::BAD_REQUEST,
478            HeaderMap::new(),
479            Json(serde_json::to_value(body).unwrap_or_else(|_| serde_json::json!({"error":{"message":"serialization error","type":"internal_error"}}))),
480        )
481            .into_response();
482    }
483
484    // --- extract last user message ---
485    let user_input = req
486        .messages
487        .iter()
488        .rev()
489        .find(|m| m.role == "user")
490        .map(|m| m.content.as_str())
491        .unwrap_or("");
492
493    if user_input.is_empty() {
494        let body = ErrorBody {
495            error: ErrorDetail {
496                message: "No user message found in messages array.".into(),
497                kind: "invalid_request_error".into(),
498            },
499        };
500        return (
501            StatusCode::BAD_REQUEST,
502            HeaderMap::new(),
503            Json(serde_json::to_value(body).unwrap_or_else(|_| serde_json::json!({"error":{"message":"serialization error","type":"internal_error"}}))),
504        )
505            .into_response();
506    }
507
508    // --- optionally forward Authorization as upstream API key
509    //     (only when serve_key is NOT set — when serve_key is set, auth is
510    //      used for access control and must not leak to the upstream) ---
511    let mut cfg = (*state.config).clone();
512    if config.serve_key.is_none() {
513        if let Some(auth) = headers.get("authorization") {
514            if let Ok(val) = auth.to_str() {
515                let key = val.trim_start_matches("Bearer ").trim().to_string();
516                if !key.is_empty() {
517                    cfg.api_key = Some(key);
518                }
519            }
520        }
521    }
522
523    // --- session ID: validate client-supplied or mint a cryptographically random one ---
524    let session_id = headers
525        .get("x-sbh-session")
526        .and_then(|v| v.to_str().ok())
527        // Only accept IDs that are safe for HTTP headers and won't enable enumeration
528        .filter(|s| {
529            !s.is_empty()
530                && s.len() <= 64
531                && s.chars()
532                    .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
533        })
534        .map(|s| s.to_string())
535        .unwrap_or_else(mint_session_id);
536
537    // --- run the full harness pipeline ---
538    let result = match analyze(user_input, &cfg).await {
539        Ok(r) => r,
540        Err(e) => {
541            Metrics::inc(&state.metrics.requests_error_total);
542            let msg = e.to_string();
543            let (status, kind) = if msg.contains("input")
544                || msg.contains("null byte")
545                || msg.contains("too long")
546                || msg.contains("control char")
547            {
548                (StatusCode::BAD_REQUEST, "invalid_request_error")
549            } else {
550                (StatusCode::INTERNAL_SERVER_ERROR, "internal_error")
551            };
552            let body = ErrorBody {
553                error: ErrorDetail {
554                    message: msg,
555                    kind: kind.into(),
556                },
557            };
558            return (
559                status,
560                HeaderMap::new(),
561                Json(serde_json::to_value(body).unwrap_or_else(|_| serde_json::json!({"error":{"message":"serialization error","type":"internal_error"}}))),
562            )
563                .into_response();
564        }
565    };
566
567    // --- session tracking: push turn, check for escalation, evict stale ---
568    let (session_turn_count, session_escalating, session_log_info) = {
569        let mut sessions = state.sessions.lock().unwrap_or_else(|e| e.into_inner());
570        let now = Instant::now();
571        // Lazy TTL: evict only the accessed session if it has expired.
572        // Full map cleanup runs in the background sweeper — no O(N) walk per request.
573        if let Some(h) = sessions.get(&session_id) {
574            if now.duration_since(h.last_seen) >= SESSION_TTL {
575                sessions.remove(&session_id);
576            }
577        }
578        // Refuse new sessions beyond the cap to prevent memory DoS.
579        let is_new = !sessions.contains_key(&session_id);
580        if is_new && sessions.len() >= SESSION_MAX_COUNT {
581            drop(sessions);
582            Metrics::inc(&state.metrics.requests_error_total);
583            let body = ErrorBody {
584                error: ErrorDetail {
585                    message: "session capacity reached — retry later".into(),
586                    kind: "capacity_error".into(),
587                },
588            };
589            return (
590                StatusCode::SERVICE_UNAVAILABLE,
591                HeaderMap::new(),
592                Json(serde_json::to_value(body).unwrap_or_else(|_| serde_json::json!({}))),
593            )
594                .into_response();
595        }
596        let hist = sessions
597            .entry(session_id.clone())
598            .or_insert_with(SessionHistory::new);
599        hist.push(&result.telemetry.intent_matrix.manipulation_risk);
600        let escalating = hist.is_escalating();
601        let summary = if escalating {
602            Some(hist.risk_summary())
603        } else {
604            None
605        };
606        (hist.turn_count(), escalating, summary)
607    };
608
609    // --- write session log entry on escalation ---
610    if session_escalating {
611        Metrics::inc(&state.metrics.escalations_total);
612        if let (Some(ref log_path), Some((trajectory, historical_mean))) =
613            (&state.session_log_path, session_log_info)
614        {
615            let entry = session_log::SessionLogEntry::new(
616                session_id.clone(),
617                session_turn_count,
618                trajectory,
619                historical_mean,
620                &ip,
621                user_input,
622            );
623            if let Err(e) = session_log::append(log_path, &entry) {
624                eprintln!("sbh serve: session log write error: {e}");
625            }
626        }
627    }
628
629    // --- build response ---
630    let telemetry_json = serde_json::to_string(&result).unwrap_or_else(|_| "{}".into());
631    let content = format!(
632        "{}\n\n<!-- sbh-telemetry: {} -->",
633        summarize_result(&result),
634        telemetry_json,
635    );
636
637    let model_name = req.model.as_deref().unwrap_or(&cfg.model_name).to_string();
638
639    let response_body = ChatResponse {
640        id: format!("sbh-{}", monotonic_id()),
641        object: "chat.completion".into(),
642        created: unix_now(),
643        model: model_name,
644        choices: vec![ChatChoice {
645            index: 0,
646            message: ChatMessage {
647                role: "assistant".into(),
648                content,
649            },
650            finish_reason: "stop".into(),
651        }],
652        usage: Usage {
653            prompt_tokens: (user_input.len() / 4) as u32,
654            completion_tokens: (telemetry_json.len() / 4) as u32,
655            total_tokens: ((user_input.len() + telemetry_json.len()) / 4) as u32,
656        },
657    };
658
659    let mut resp_headers = HeaderMap::new();
660    if let Ok(encoded) = url_encode(&telemetry_json) {
661        if let Ok(val) = HeaderValue::from_str(&encoded) {
662            resp_headers.insert("x-sbh-telemetry", val);
663        }
664    }
665    resp_headers.insert(
666        "x-sbh-version",
667        HeaderValue::from_static(env!("CARGO_PKG_VERSION")),
668    );
669    // Witness status is refreshed every 30s by a background task — zero blocking here.
670    let witness_status = witness_status_str(
671        state
672            .witness_status
673            .load(std::sync::atomic::Ordering::Relaxed),
674    );
675    if let Ok(val) = HeaderValue::from_str(witness_status) {
676        resp_headers.insert("x-sbh-witness", val);
677    }
678    // Session headers
679    if let Ok(val) = HeaderValue::from_str(&session_id) {
680        resp_headers.insert("x-sbh-session", val);
681    }
682    if let Ok(val) = HeaderValue::from_str(&session_turn_count.to_string()) {
683        resp_headers.insert("x-sbh-session-turns", val);
684    }
685    if session_escalating {
686        resp_headers.insert(
687            "x-sbh-session-alert",
688            HeaderValue::from_static("escalation_detected"),
689        );
690    }
691
692    Metrics::inc(&state.metrics.requests_ok_total);
693    (
694        StatusCode::OK,
695        resp_headers,
696        Json(serde_json::to_value(response_body).unwrap_or_else(|_| serde_json::json!({"error":{"message":"serialization error","type":"internal_error"}}))),
697    )
698        .into_response()
699}
700
701// ---------------------------------------------------------------------------
702// Metrics endpoint — Prometheus text exposition format
703// ---------------------------------------------------------------------------
704
705async fn metrics_handler(State(state): State<ServeState>, headers: HeaderMap) -> impl IntoResponse {
706    // /metrics is protected by the same bearer key as the main endpoint.
707    // Without this, an unauthenticated observer can read request rates,
708    // escalation counts, and active session count.
709    if let Some(sk) = &state.config.serve_key {
710        let provided = headers
711            .get("authorization")
712            .and_then(|v| v.to_str().ok())
713            .map(|s| s.trim_start_matches("Bearer ").trim().to_string())
714            .unwrap_or_default();
715        if &provided != sk {
716            return (
717                StatusCode::UNAUTHORIZED,
718                [("content-type", "text/plain; charset=utf-8")],
719                "Unauthorized".to_string(),
720            );
721        }
722    }
723
724    let active_sessions = state
725        .sessions
726        .lock()
727        .unwrap_or_else(|e| e.into_inner())
728        .len();
729    let uptime_secs = state.start_time.elapsed().as_secs();
730    let body = state.metrics.render(active_sessions, uptime_secs);
731    (
732        StatusCode::OK,
733        [("content-type", "text/plain; version=0.0.4; charset=utf-8")],
734        body,
735    )
736}
737
738// ---------------------------------------------------------------------------
739// Health check
740// ---------------------------------------------------------------------------
741
742async fn health() -> impl IntoResponse {
743    Json(serde_json::json!({
744        "status": "ok",
745        "version": env!("CARGO_PKG_VERSION"),
746        "service": "split-brain-harness"
747    }))
748}
749
750// ---------------------------------------------------------------------------
751// Public entry point
752// ---------------------------------------------------------------------------
753
754pub async fn run_server(
755    listen: &str,
756    config: Config,
757    tls_cert: Option<&str>,
758    tls_key: Option<&str>,
759) -> anyhow::Result<()> {
760    let rate_limit = config.serve_rate_limit;
761    let max_body = config.serve_max_body_bytes;
762    let auth_enabled = config.serve_key.is_some();
763    let session_log_path = config.session_log_path.clone();
764    let context_path = config.context_path.clone();
765
766    let witness_cache = Arc::new(std::sync::atomic::AtomicU8::new(WITNESS_UNCONFIGURED));
767    if config.audit_path.is_some() {
768        spawn_witness_poller(Arc::clone(&witness_cache));
769    }
770
771    let sessions: Arc<Mutex<HashMap<String, SessionHistory>>> =
772        Arc::new(Mutex::new(HashMap::new()));
773
774    // Background task: sweep expired sessions every SESSION_SWEEP_INTERVAL.
775    // The hot path no longer calls retain() — this is the only full-map walk.
776    {
777        let sessions_sweep = Arc::clone(&sessions);
778        tokio::spawn(async move {
779            loop {
780                tokio::time::sleep(SESSION_SWEEP_INTERVAL).await;
781                let mut map = sessions_sweep.lock().unwrap_or_else(|e| e.into_inner());
782                let now = Instant::now();
783                map.retain(|_, h| now.duration_since(h.last_seen) < SESSION_TTL);
784            }
785        });
786    }
787
788    let state = ServeState {
789        config: Arc::new(config),
790        rate_limiter: Arc::new(ShardedRateLimiter::new()),
791        sessions,
792        session_log_path: session_log_path.clone(),
793        metrics: Arc::new(Metrics::default()),
794        start_time: Arc::new(Instant::now()),
795        witness_status: witness_cache,
796    };
797
798    let app = Router::new()
799        .route("/v1/chat/completions", post(chat_completions))
800        .route("/health", get(health))
801        .route("/metrics", get(metrics_handler))
802        .layer(DefaultBodyLimit::max(max_body))
803        .with_state(state);
804
805    let print_banner = |scheme: &str, addr: SocketAddr| {
806        eprintln!("sbh serve: listening on {scheme}://{addr}");
807        eprintln!("  POST /v1/chat/completions  — OpenAI-compatible harness proxy");
808        eprintln!("  GET  /health               — liveness check");
809        eprintln!("  GET  /metrics              — Prometheus counters");
810        eprintln!(
811            "  auth: {}  rate: {}/min/IP  max-body: {} bytes",
812            if auth_enabled { "enabled" } else { "disabled" },
813            rate_limit,
814            max_body,
815        );
816        match &session_log_path {
817            Some(p) => eprintln!("  session log: {p}"),
818            None => eprintln!("  session log: disabled (set SBH_SESSION_LOG or --session-log)"),
819        };
820        {
821            use crate::rag::ContextCorpus;
822            let embedded_count = ContextCorpus::embedded().len();
823            match context_path.as_deref() {
824                None => eprintln!("  context: {embedded_count} embedded docs (set SBH_CONTEXT_PATH to add operator docs)"),
825                Some(p) => match ContextCorpus::load(p) {
826                    Ok(extra) => eprintln!("  context: {} embedded + {} operator docs from {p}", embedded_count, extra.len()),
827                    Err(e) => eprintln!("  context: {p} load error — {e}"),
828                },
829            }
830        }
831    };
832
833    match (tls_cert, tls_key) {
834        (Some(cert), Some(key)) => {
835            use axum_server::tls_rustls::RustlsConfig;
836            let tls_config = RustlsConfig::from_pem_file(cert, key)
837                .await
838                .with_context(|| format!("TLS: failed to load cert={cert} key={key}"))?;
839            let addr: SocketAddr = listen
840                .parse()
841                .with_context(|| format!("invalid listen address: {listen}"))?;
842            print_banner("https", addr);
843            axum_server::bind_rustls(addr, tls_config)
844                .serve(app.into_make_service_with_connect_info::<SocketAddr>())
845                .await?;
846        }
847        (Some(_), None) => anyhow::bail!("--tls-cert requires --tls-key"),
848        (None, Some(_)) => anyhow::bail!("--tls-key requires --tls-cert"),
849        (None, None) => {
850            let listener = tokio::net::TcpListener::bind(listen).await?;
851            let addr = listener.local_addr()?;
852            print_banner("http", addr);
853            axum::serve(
854                listener,
855                app.into_make_service_with_connect_info::<SocketAddr>(),
856            )
857            .await?;
858        }
859    }
860    Ok(())
861}
862
863// ---------------------------------------------------------------------------
864// Helpers
865// ---------------------------------------------------------------------------
866
867fn summarize_result(result: &crate::types::HarnessResult) -> String {
868    let t = &result.telemetry;
869    let v = &result.verification;
870    format!(
871        "[SBH Analysis]\nEmotion: {} (intensity {:.2})\nManipulation risk: {}\nCoherence: {:.2}\nVerification: {} (confidence {:.2}){}",
872        t.affective_telemetry.primary_emotion,
873        t.affective_telemetry.emotional_intensity,
874        t.intent_matrix.manipulation_risk,
875        t.cognitive_state.coherence_rating,
876        if v.passed { "passed" } else { "flagged" },
877        v.confidence,
878        if v.stop_and_ask {
879            "\n⚠ stop_and_ask=true — low confidence, review before acting"
880        } else {
881            ""
882        },
883    )
884}
885
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch, rather
/// than panicking.
fn unix_now() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
892
/// Return the next value of a process-wide counter, starting at 1.
///
/// Each call yields a distinct, increasing value. `Relaxed` ordering is
/// sufficient because only the counter's own atomicity matters here; no other
/// memory is published through it.
fn monotonic_id() -> u64 {
    static NEXT_ID: AtomicU64 = AtomicU64::new(1);
    NEXT_ID.fetch_add(1, Ordering::Relaxed)
}
898
899/// Generate a cryptographically random session ID using OS entropy.
900/// Falls back to monotonic counter + timestamp mix if /dev/urandom is unavailable.
901fn mint_session_id() -> String {
902    // Read 16 bytes from /dev/urandom — available on all Linux targets.
903    let mut buf = [0u8; 16];
904    let ok = std::fs::File::open("/dev/urandom")
905        .and_then(|mut f| {
906            use std::io::Read;
907            f.read_exact(&mut buf)
908        })
909        .is_ok();
910    if ok {
911        format!(
912            "sbh-{:08x}{:08x}{:08x}{:08x}",
913            u32::from_le_bytes(buf[0..4].try_into().unwrap()),
914            u32::from_le_bytes(buf[4..8].try_into().unwrap()),
915            u32::from_le_bytes(buf[8..12].try_into().unwrap()),
916            u32::from_le_bytes(buf[12..16].try_into().unwrap()),
917        )
918    } else {
919        format!("sbh-s-{}-{}", monotonic_id(), unix_now())
920    }
921}
922
/// Percent-encode a string so it can be carried in an HTTP header value.
///
/// Encoding works byte-by-byte over the UTF-8 representation, so a multibyte
/// char such as `é` (UTF-8 `0xC3 0xA9`) becomes `%C3%A9`, not `%E9`.
///
/// These bytes pass through unchanged:
/// - ASCII letters and digits;
/// - the RFC 3986 unreserved marks `- _ . ~`;
/// - additionally `: / , [ ] { }`. These are reserved in URIs but legal in
///   header values, and leaving them raw keeps the encoded JSON telemetry
///   readable.
///
/// Every other byte becomes `%XX` with uppercase hex digits. This includes
/// `%`, space, quotes, control characters and all non-ASCII bytes.
///
/// This function never returns `Err`. The `Result` is kept only for signature
/// compatibility with existing callers.
fn url_encode(s: &str) -> Result<String, ()> {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    let mut out = String::with_capacity(s.len());
    for &byte in s.as_bytes() {
        let passthrough = matches!(
            byte,
            b'A'..=b'Z'
                | b'a'..=b'z'
                | b'0'..=b'9'
                | b'-'
                | b'_'
                | b'.'
                | b'~'
                | b':'
                | b'/'
                | b','
                | b'['
                | b']'
                | b'{'
                | b'}'
        );
        if passthrough {
            out.push(char::from(byte));
        } else {
            // Table lookup instead of `format!` avoids a heap allocation per byte.
            out.push('%');
            out.push(char::from(HEX[usize::from(byte >> 4)]));
            out.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    Ok(out)
}
948            b => out.push_str(&format!("%{b:02X}")),
949        }
950    }
951    Ok(out)
952}
953
954// ---------------------------------------------------------------------------
955// Tests
956// ---------------------------------------------------------------------------
957
#[cfg(test)]
mod tests {
    use super::*;

    // --- metrics ---

    #[test]
    fn metrics_render_contains_all_metric_names() {
        let m = Metrics::default();
        let out = m.render(0, 0);
        for name in &[
            "sbh_requests_total",
            "sbh_requests_ok_total",
            "sbh_requests_error_total",
            "sbh_auth_failures_total",
            "sbh_rate_limit_total",
            "sbh_escalations_total",
            "sbh_active_sessions",
            "sbh_uptime_seconds",
        ] {
            assert!(out.contains(name), "missing metric: {name}");
        }
    }

    #[test]
    fn metrics_render_prometheus_format() {
        let m = Metrics::default();
        let out = m.render(3, 42);
        assert!(out.contains("# HELP sbh_requests_total"));
        assert!(out.contains("# TYPE sbh_requests_total counter"));
        assert!(out.contains("sbh_requests_total 0\n"));
        assert!(out.contains("sbh_active_sessions 3\n"));
        assert!(out.contains("sbh_uptime_seconds 42\n"));
    }

    #[test]
    fn metrics_counters_increment_correctly() {
        let m = Metrics::default();
        Metrics::inc(&m.requests_total);
        Metrics::inc(&m.requests_total);
        Metrics::inc(&m.escalations_total);
        let out = m.render(0, 0);
        assert!(out.contains("sbh_requests_total 2\n"));
        assert!(out.contains("sbh_escalations_total 1\n"));
        assert!(out.contains("sbh_requests_ok_total 0\n"));
    }

    #[test]
    fn metrics_render_has_help_and_type_for_every_metric() {
        let m = Metrics::default();
        let out = m.render(0, 0);
        let help_count = out.lines().filter(|l| l.starts_with("# HELP")).count();
        let type_count = out.lines().filter(|l| l.starts_with("# TYPE")).count();
        assert_eq!(help_count, 8, "expected 8 # HELP lines");
        assert_eq!(type_count, 8, "expected 8 # TYPE lines");
    }

    // --- url_encode ---

    #[test]
    fn url_encode_spaces_and_quotes() {
        let s = r#"{"key": "val ue"}"#;
        let encoded = url_encode(s).unwrap();
        assert!(!encoded.contains(' '));
        assert!(!encoded.contains('"'));
        assert!(encoded.contains("%20"));
        assert!(encoded.contains("%22"));
    }

    #[test]
    fn url_encode_clean_string_unchanged() {
        let s = "hello-world_123";
        assert_eq!(url_encode(s).unwrap(), s);
    }

    #[test]
    fn url_encode_multibyte_utf8_is_encoded_per_byte() {
        // `é` is 0xC3 0xA9 in UTF-8; each byte is escaped, not the codepoint.
        assert_eq!(url_encode("é").unwrap(), "%C3%A9");
    }

    #[test]
    fn url_encode_escapes_percent_and_control_chars() {
        assert_eq!(url_encode("100%\n").unwrap(), "100%25%0A");
    }

    #[test]
    fn url_encode_keeps_json_punctuation_readable() {
        assert_eq!(
            url_encode(r#"{"a": [1,2]}"#).unwrap(),
            "{%22a%22:%20[1,2]}"
        );
    }

    #[test]
    fn url_encode_output_is_valid_header_value() {
        let encoded = url_encode("emoji 🚀 \"quoted\"\r\n").unwrap();
        assert!(HeaderValue::from_str(&encoded).is_ok(), "bad header: {encoded}");
    }

    // --- ids / time ---

    #[test]
    fn unix_now_is_nonzero() {
        assert!(unix_now() > 0);
    }

    #[test]
    fn monotonic_id_increases() {
        let a = monotonic_id();
        let b = monotonic_id();
        assert!(b > a);
    }

    #[test]
    fn mint_session_id_has_prefix_and_is_unique() {
        let a = mint_session_id();
        let b = mint_session_id();
        assert!(a.starts_with("sbh-"), "unexpected id: {a}");
        assert!(b.starts_with("sbh-"), "unexpected id: {b}");
        assert_ne!(a, b);
    }

    // --- sessions ---

    #[test]
    fn session_no_escalation_below_three_turns() {
        let mut h = SessionHistory::new();
        h.push("high");
        h.push("high");
        assert!(!h.is_escalating(), "need ≥3 turns before firing");
    }

    #[test]
    fn session_escalation_detected_on_slow_boil() {
        let mut h = SessionHistory::new();
        h.push("low");
        h.push("low");
        h.push("high");
        assert!(h.is_escalating(), "low→low→high is slow-boil escalation");
    }

    #[test]
    fn session_no_escalation_when_already_high() {
        let mut h = SessionHistory::new();
        h.push("high");
        h.push("high");
        h.push("high");
        // All turns already high — no upward delta
        assert!(!h.is_escalating());
    }

    #[test]
    fn session_no_escalation_medium_to_medium() {
        let mut h = SessionHistory::new();
        h.push("low");
        h.push("medium");
        h.push("medium");
        // medium is 1.0; historical mean 0.5 → delta 0.5, but not > 0.5
        assert!(!h.is_escalating());
    }

    #[test]
    fn session_escalation_low_to_high_five_turns() {
        let mut h = SessionHistory::new();
        for _ in 0..4 {
            h.push("low");
        }
        h.push("high");
        assert!(h.is_escalating());
    }

    #[test]
    fn session_ring_caps_at_max_turns() {
        let mut h = SessionHistory::new();
        for _ in 0..SESSION_MAX_TURNS + 5 {
            h.push("low");
        }
        assert_eq!(h.turn_count(), SESSION_MAX_TURNS);
    }

    #[test]
    fn risk_score_mapping() {
        assert_eq!(risk_score("low"), 0.0);
        assert_eq!(risk_score("medium"), 1.0);
        assert_eq!(risk_score("high"), 2.0);
        assert_eq!(risk_score("unknown"), 0.0);
    }

    // --- rate limiting ---

    #[test]
    fn rate_limit_allows_up_to_max() {
        let limiter = ShardedRateLimiter::new();
        let ip: IpAddr = "127.0.0.1".parse().unwrap();
        for _ in 0..5 {
            assert!(limiter.check(ip, 5));
        }
        assert!(!limiter.check(ip, 5));
    }

    #[test]
    fn rate_limit_different_ips_are_independent() {
        let limiter = ShardedRateLimiter::new();
        let ip1: IpAddr = "10.0.0.1".parse().unwrap();
        let ip2: IpAddr = "10.0.0.2".parse().unwrap();
        for _ in 0..3 {
            assert!(limiter.check(ip1, 3));
        }
        assert!(!limiter.check(ip1, 3));
        assert!(limiter.check(ip2, 3));
    }

    // --- summaries ---

    #[test]
    fn summarize_result_contains_key_fields() {
        use crate::types::*;
        let result = HarnessResult {
            telemetry: TelemetryResult {
                affective_telemetry: AfferentTelemetry {
                    primary_emotion: "neutral".into(),
                    emotional_intensity: 0.1,
                    structural_tone: vec!["analytical".into()],
                },
                intent_matrix: IntentMatrix {
                    stated_objective: "test query".into(),
                    subtextual_motive: "none".into(),
                    manipulation_risk: "low".into(),
                },
                cognitive_state: CognitiveState {
                    urgency_vector: 0.0,
                    coherence_rating: 0.9,
                },
            },
            verification: VerificationReport {
                passed: true,
                consistency_flags: vec![],
                unsupported_claims: vec![],
                assumptions: vec![],
                unresolved: vec![],
                confidence: 0.9,
                disagreement: Default::default(),
                stop_and_ask: false,
            },
            trace: vec![],
            capability_request: None,
            obfuscation: None,
        };
        let s = summarize_result(&result);
        assert!(s.contains("neutral"));
        assert!(s.contains("low"));
        assert!(s.contains("passed"));
    }
}