zeph_llm/router/mod.rs

// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
// SPDX-License-Identifier: MIT OR Apache-2.0

//! Provider router: EMA-based, Thompson Sampling, Cascade, and PILOT Bandit strategies.
//!
//! # Security
//!
//! Thompson and Bandit state are loaded from user-controlled paths at startup. Files are
//! validated (finite floats, clamped range) and written with `0o600` permissions
//! on Unix. Do not store state files in world-writable directories.

pub mod asi;
pub mod bandit;
pub mod cascade;
pub mod reputation;
pub mod thompson;
pub mod triage;

use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};

use parking_lot::Mutex;

use crate::any::AnyProvider;
use crate::ema::EmaTracker;
use crate::embed::owned_strs;
use crate::error::LlmError;
use crate::provider::{ChatResponse, ChatStream, LlmProvider, Message, StatusTx, ToolDefinition};

use asi::AsiState;
use bandit::{BanditState, embedding_to_features};
use cascade::{CascadeState, ClassifierMode, heuristic_score};
use reputation::ReputationTracker;
use thompson::ThompsonState;
use zeph_common::math::cosine_similarity;

/// Simple bounded embedding cache for bandit feature vectors.
///
/// Keyed by `u64` hash of query text (using `std::hash`). Eviction is FIFO on insertion
/// order (not LRU) — acceptable for a routing cache where hot queries repeat often.
/// The `lru` crate is not in the workspace; a `HashMap` + insertion-order `VecDeque` avoids a new dep.
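///
/// Illustrative eviction order: with capacity 2, inserting entries for queries A, B,
/// then C evicts A (the oldest insertion), even if A was read most recently.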
#[derive(Debug)]
struct BanditEmbedCache {
    map: HashMap<u64, Vec<f32>>,
    order: std::collections::VecDeque<u64>,
    capacity: usize,
}

impl BanditEmbedCache {
    fn new(capacity: usize) -> Self {
        Self {
            map: HashMap::with_capacity(capacity),
            order: std::collections::VecDeque::with_capacity(capacity),
            capacity,
        }
    }

    fn get(&self, key: u64) -> Option<&Vec<f32>> {
        self.map.get(&key)
    }

    fn insert(&mut self, key: u64, value: Vec<f32>) {
        if self.map.contains_key(&key) {
            return;
        }
        if self.map.len() >= self.capacity
            && let Some(evict) = self.order.pop_front()
        {
            self.map.remove(&evict);
        }
        self.map.insert(key, value);
        self.order.push_back(key);
    }
}

impl Default for BanditEmbedCache {
    fn default() -> Self {
        Self::new(512)
    }
}

/// Routing strategy used by [`RouterProvider`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum RouterStrategy {
    /// Exponential moving average-based latency-aware ordering.
    #[default]
    Ema,
    /// Thompson Sampling with Beta distributions.
    Thompson,
    /// Cascade: try cheapest provider first, escalate on degenerate output.
    Cascade,
    /// PILOT: `LinUCB` contextual bandit with online learning and budget-aware selection.
    Bandit,
}

/// Configuration for PILOT bandit routing in `RouterProvider`.
///
/// See [`bandit`] module for the algorithm details and trade-offs.
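///
/// # Example
///
/// A minimal sketch overriding selected fields on top of the `Default` impl below
/// (illustrative values, not tuned recommendations):
///
/// ```ignore
/// let config = BanditRouterConfig {
///     alpha: 1.5,       // explore more aggressively than the 1.0 default
///     cost_weight: 0.2, // penalize expensive providers harder
///     ..BanditRouterConfig::default()
/// };
/// ```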
#[derive(Debug, Clone)]
#[allow(clippy::doc_markdown)] // PILOT, LinUCB, Thompson are proper nouns/acronyms
pub struct BanditRouterConfig {
    /// `LinUCB` exploration parameter. Higher = more exploration. Default: 1.0.
    pub alpha: f32,
    /// Feature vector dimension (first `dim` components of embedding). Default: 32.
    pub dim: usize,
    /// Cost penalty weight in the reward signal: `reward = quality - cost_weight * cost_fraction`.
    /// Default: 0.1. Increase to penalize expensive providers more aggressively.
    pub cost_weight: f32,
    /// Session-level decay factor: values < 1.0 cause re-exploration over time. Default: 1.0.
    pub decay_factor: f32,
    /// Minimum total updates before `LinUCB` takes over from Thompson fallback.
    /// Default: `10 * num_providers` (computed at construction time from provider count).
    pub warmup_queries: u64,
    /// Hard timeout for the embedding call (milliseconds). If exceeded, falls back
    /// to Thompson/uniform selection. Default: 50.
    pub embedding_timeout_ms: u64,
    /// Maximum number of cached embeddings (keyed by query string hash). Default: 512.
    pub cache_size: usize,
    /// MAR threshold: when `memory_hit_confidence >= this`, bias toward cheap providers.
    /// Default: 0.9. Set to 1.0 to disable MAR.
    pub memory_confidence_threshold: f32,
}

impl Default for BanditRouterConfig {
    fn default() -> Self {
        Self {
            alpha: 1.0,
            dim: 32,
            cost_weight: 0.1,
            decay_factor: 1.0,
            warmup_queries: 0, // overridden by with_bandit() based on provider count
            embedding_timeout_ms: 50,
            cache_size: 512,
            memory_confidence_threshold: 0.9,
        }
    }
}

/// Runtime ASI configuration passed to [`RouterProvider::with_asi`].
///
/// Mirrors `AsiRouterConfig` but lives in `zeph-llm` to avoid
/// a dependency on `zeph-config`. The bootstrap layer maps config → this struct.
#[derive(Debug, Clone)]
pub struct AsiRouterConfig {
    /// Sliding window size. Default: 5.
    pub window: usize,
    /// Coherence score threshold below which the provider is penalized. Default: 0.7.
    pub coherence_threshold: f32,
    /// Penalty weight added to Thompson beta on low coherence. Default: 0.3.
    pub penalty_weight: f32,
}

impl Default for AsiRouterConfig {
    fn default() -> Self {
        Self {
            window: 5,
            coherence_threshold: 0.7,
            penalty_weight: 0.3,
        }
    }
}

/// Configuration for cascade routing in `RouterProvider`.
#[derive(Debug, Clone)]
pub struct CascadeRouterConfig {
    pub quality_threshold: f64,
    pub max_escalations: u8,
    pub classifier_mode: ClassifierMode,
    pub window_size: usize,
    pub max_cascade_tokens: Option<u32>,
    /// LLM provider used for judge-mode quality scoring.
    /// Used when `classifier_mode = Judge`; the router falls back to heuristic scoring if `None`.
    pub summary_provider: Option<AnyProvider>,
    /// Explicit cost ordering of provider names (cheapest first).
    /// When set, providers are sorted by their position in this list at construction time.
    /// Providers not listed are appended after listed ones in original chain order.
    pub cost_tiers: Option<Vec<String>>,
}

impl Default for CascadeRouterConfig {
    fn default() -> Self {
        Self {
            quality_threshold: 0.5,
            max_escalations: 2,
            classifier_mode: ClassifierMode::Heuristic,
            window_size: 50,
            max_cascade_tokens: None,
            summary_provider: None,
            cost_tiers: None,
        }
    }
}

#[derive(Debug, Clone)]
pub struct RouterProvider {
    // Arc<[AnyProvider]> makes self.clone() O(1) for the providers field (atomic refcount
    // increment) instead of O(N * provider_size). This benefits ALL strategies since every
    // chat/chat_stream/embed/chat_with_tools call does `let router = self.clone()`.
    providers: Arc<[AnyProvider]>,
    status_tx: Option<StatusTx>,
    ema: Option<EmaTracker>,
    provider_order: Arc<Mutex<Vec<usize>>>,
    strategy: RouterStrategy,
    thompson: Option<Arc<Mutex<ThompsonState>>>,
    /// Path for persisting Thompson state. `None` disables persistence.
    thompson_state_path: Option<std::path::PathBuf>,
    /// Cascade routing state (quality history per provider).
    cascade_state: Option<Arc<Mutex<CascadeState>>>,
    /// Cascade routing configuration.
    cascade_config: Option<CascadeRouterConfig>,
    /// Bayesian reputation tracker (RAPS). None when disabled.
    reputation: Option<Arc<Mutex<ReputationTracker>>>,
    /// Path for persisting reputation state.
    reputation_state_path: Option<std::path::PathBuf>,
    /// Reputation weight in [0.0, 1.0] for routing score blend.
    reputation_weight: f64,
    /// Name of the sub-provider that served the most recent successful tool call.
    /// Used by `record_quality_outcome` to attribute quality to the right provider.
    last_active_provider: Arc<Mutex<Option<String>>>,
    /// PILOT bandit state.
    bandit: Option<Arc<Mutex<BanditState>>>,
    /// Path for persisting bandit state. `None` disables persistence.
    bandit_state_path: Option<std::path::PathBuf>,
    /// Bandit routing configuration.
    bandit_config: Option<BanditRouterConfig>,
    /// Dedicated embedding provider for bandit feature vectors.
    /// When `None`, bandit falls back to Thompson/uniform on embed failure.
    bandit_embedding_provider: Option<AnyProvider>,
    /// Bounded FIFO embedding cache (see `BanditEmbedCache`): maps query-string hash to
    /// feature vector. Shared across requests; keyed by `u64` hash of query text.
    bandit_embed_cache: Arc<Mutex<BanditEmbedCache>>,
    /// MAR signal: top-1 semantic memory recall score for the current turn.
    /// Set by the agent before each `chat`/`chat_stream` call; read by `bandit_select_provider`.
    last_memory_confidence: Arc<Mutex<Option<f32>>>,
    /// Maps provider name to model identifier for cost estimation.
    /// Built at construction time from `self.providers`.
    provider_models: Arc<std::collections::HashMap<String, String>>,
    /// Agent Stability Index state (session-only coherence tracking).
    asi: Option<Arc<Mutex<AsiState>>>,
    /// ASI configuration. `None` when ASI is disabled.
    asi_config: Option<AsiRouterConfig>,
    /// Embedding-based quality gate threshold. `None` = disabled.
    /// After provider selection, `cosine_similarity(query_emb, response_emb)` must be >= this
    /// value; otherwise the next provider in the ordered list is tried.
    quality_gate: Option<f32>,
    /// Monotonically increasing turn counter. Incremented once per top-level `chat()` call.
    /// Shared across clones so that concurrent sub-calls within the same turn see the same value.
    turn_counter: Arc<AtomicU64>,
    /// Turn ID of the last ASI embedding update. Used to debounce `spawn_asi_update` so that
    /// only one embed call fires per turn even when `chat()` is invoked N times concurrently.
    asi_last_turn: Arc<AtomicU64>,
    /// Semaphore limiting concurrent `embed_batch` calls. `None` = unlimited.
    embed_semaphore: Option<Arc<tokio::sync::Semaphore>>,
}

impl RouterProvider {
    #[must_use]
    pub fn new(providers: Vec<AnyProvider>) -> Self {
        let n = providers.len();
        let provider_models: std::collections::HashMap<String, String> = providers
            .iter()
            .map(|p| (p.name().to_owned(), p.model_identifier().to_owned()))
            .collect();
        Self {
            providers: Arc::from(providers),
            status_tx: None,
            ema: None,
            provider_order: Arc::new(Mutex::new((0..n).collect())),
            strategy: RouterStrategy::Ema,
            thompson: None,
            thompson_state_path: None,
            cascade_state: None,
            cascade_config: None,
            reputation: None,
            reputation_state_path: None,
            reputation_weight: 0.3,
            last_active_provider: Arc::new(Mutex::new(None)),
            bandit: None,
            bandit_state_path: None,
            bandit_config: None,
            bandit_embedding_provider: None,
            bandit_embed_cache: Arc::new(Mutex::new(BanditEmbedCache::default())),
            last_memory_confidence: Arc::new(Mutex::new(None)),
            provider_models: Arc::new(provider_models),
            asi: None,
            asi_config: None,
            quality_gate: None,
            turn_counter: Arc::new(AtomicU64::new(0)),
            asi_last_turn: Arc::new(AtomicU64::new(u64::MAX)),
            embed_semaphore: None,
        }
    }

    /// Set the maximum number of concurrent `embed_batch` calls.
    ///
    /// A value of 0 disables the semaphore (unlimited). Default is no semaphore.
    #[must_use]
    pub fn with_embed_concurrency(mut self, limit: usize) -> Self {
        self.embed_semaphore = if limit > 0 {
            Some(Arc::new(tokio::sync::Semaphore::new(limit)))
        } else {
            None
        };
        self
    }

    /// Set the MAR (Memory-Augmented Routing) signal for the current turn.
    ///
    /// Must be called before `chat` / `chat_stream` to influence bandit provider selection.
    /// Pass `None` to disable MAR for this turn.
    pub fn set_memory_confidence(&self, confidence: Option<f32>) {
        *self.last_memory_confidence.lock() = confidence;
    }

    /// Enable EMA-based adaptive provider ordering.
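    ///
    /// A minimal usage sketch (illustrative values, not tuned recommendations):
    ///
    /// ```ignore
    /// // alpha 0.3 smoothing, reorder every 10 recorded outcomes.
    /// let router = RouterProvider::new(providers).with_ema(0.3, 10);
    /// ```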
    #[must_use]
    pub fn with_ema(mut self, alpha: f64, reorder_interval: u64) -> Self {
        self.ema = Some(EmaTracker::new(alpha, reorder_interval));
        self
    }

    /// Enable Agent Stability Index (ASI) coherence tracking.
    ///
    /// When enabled, each successful response is embedded in a background task and added
    /// to a per-provider sliding window. The coherence score (cosine similarity of the
    /// latest embedding vs. window mean) penalizes Thompson/EMA routing priors for
    /// providers whose responses drift.
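    ///
    /// A minimal usage sketch with the documented defaults (window 5, threshold 0.7):
    ///
    /// ```ignore
    /// let router = RouterProvider::new(providers).with_asi(AsiRouterConfig::default());
    /// ```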
    #[must_use]
    pub fn with_asi(mut self, config: AsiRouterConfig) -> Self {
        self.asi = Some(Arc::new(Mutex::new(AsiState::default())));
        self.asi_config = Some(config);
        self
    }

    /// Enable embedding-based quality gate for Thompson/EMA routing.
    ///
    /// After provider selection, computes cosine similarity between the query embedding
    /// and the response embedding. If below `threshold`, tries the next provider in the
    /// ordered list. On full exhaustion, returns the best response seen (highest similarity).
    /// Fail-open: embedding errors disable the gate for that request.
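    ///
    /// A minimal usage sketch; 0.75 is an illustrative threshold, not a recommendation:
    ///
    /// ```ignore
    /// let router = RouterProvider::new(providers).with_quality_gate(0.75);
    /// ```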
    #[must_use]
    pub fn with_quality_gate(mut self, threshold: f32) -> Self {
        self.quality_gate = Some(threshold);
        self
    }

    /// Enable Thompson Sampling strategy.
    ///
    /// Loads existing state from `state_path` if present; falls back to uniform prior.
    /// Prunes stale entries for providers not in the current chain.
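    ///
    /// A minimal usage sketch; the path is illustrative, and `None` selects the default path:
    ///
    /// ```ignore
    /// let router = RouterProvider::new(providers)
    ///     .with_thompson(Some(Path::new("state/thompson.json")));
    /// ```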
    #[must_use]
    pub fn with_thompson(mut self, state_path: Option<&Path>) -> Self {
        self.strategy = RouterStrategy::Thompson;
        let path = state_path.map_or_else(ThompsonState::default_path, Path::to_path_buf);
        let mut state = ThompsonState::load(&path);
        // CRIT-3: prune orphan entries from previous configs.
        let known: std::collections::HashSet<String> =
            self.providers.iter().map(|p| p.name().to_owned()).collect();
        state.prune(&known);
        self.thompson = Some(Arc::new(Mutex::new(state)));
        self.thompson_state_path = Some(path);
        self
    }

    /// Enable PILOT bandit routing strategy (`LinUCB` contextual bandit).
    ///
    /// Loads existing state from `state_path` (or the default path). Applies session-level
    /// decay if `config.decay_factor < 1.0`, and prunes arms for removed providers.
    ///
    /// `embedding_provider` is used to obtain feature vectors for each query.
    /// When `None`, the bandit falls back to Thompson/uniform selection whenever an
    /// embedding cannot be obtained within `config.embedding_timeout_ms`.
    ///
    /// The `warmup_queries` default of `0` in `BanditRouterConfig` is overridden here to
    /// `10 * num_providers` to ensure sufficient initial exploration.
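    ///
    /// A minimal usage sketch: default config, default state path, and no embedding
    /// provider, so selection falls back to Thompson until an embedder is wired in:
    ///
    /// ```ignore
    /// let router = RouterProvider::new(providers)
    ///     .with_bandit(BanditRouterConfig::default(), None, None);
    /// ```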
    #[must_use]
    pub fn with_bandit(
        mut self,
        mut config: BanditRouterConfig,
        state_path: Option<&Path>,
        embedding_provider: Option<AnyProvider>,
    ) -> Self {
        self.strategy = RouterStrategy::Bandit;
        let n = self.providers.len();
        if config.warmup_queries == 0 {
            config.warmup_queries = u64::try_from(10 * n.max(1)).unwrap_or(100);
        }
        // Validate config bounds before loading state, so the state is never created
        // with a pre-clamp dim. Clamp to safe ranges with a warning.
        if config.alpha <= 0.0 {
            tracing::warn!(alpha = config.alpha, "bandit: alpha <= 0, clamping to 0.01");
            config.alpha = 0.01;
        }
        if config.dim == 0 || config.dim > 256 {
            tracing::warn!(
                dim = config.dim,
                "bandit: dim out of range [1, 256], clamping to 32"
            );
            config.dim = 32;
        }
        if config.decay_factor <= 0.0 || config.decay_factor > 1.0 {
            tracing::warn!(
                decay_factor = config.decay_factor,
                "bandit: decay_factor out of (0.0, 1.0], clamping to 1.0"
            );
            config.decay_factor = 1.0;
        }
        let cache_size = config.cache_size;
        let path = state_path.map_or_else(BanditState::default_path, Path::to_path_buf);
        let mut state = BanditState::load(&path);
        if state.dim == 0 {
            state = BanditState::new(config.dim);
        } else if state.dim != config.dim {
            // Config changed dim — reset state rather than use mismatched arms.
            tracing::warn!(
                old_dim = state.dim,
                new_dim = config.dim,
                "bandit: dim changed, resetting state"
            );
            state = BanditState::new(config.dim);
        }
        if config.decay_factor < 1.0 {
            state.apply_decay(config.decay_factor);
        }
        let known: std::collections::HashSet<String> =
            self.providers.iter().map(|p| p.name().to_owned()).collect();
        state.prune(&known);
        self.bandit = Some(Arc::new(Mutex::new(state)));
        self.bandit_state_path = Some(path);
        self.bandit_embed_cache = Arc::new(Mutex::new(BanditEmbedCache::new(cache_size)));
        self.bandit_embedding_provider = embedding_provider;
        // Initialize Thompson state for cold-start fallback (total_updates < warmup_queries).
        // Uses default uniform priors; no persistence path needed since it's a fallback only.
        self.thompson = Some(Arc::new(Mutex::new(ThompsonState::default())));
        self.bandit_config = Some(config);
        self
    }

    /// Persist current bandit state to disk. No-op if bandit strategy is not active.
    pub fn save_bandit_state(&self) {
        let (Some(bandit), Some(path)) = (&self.bandit, &self.bandit_state_path) else {
            return;
        };
        let state = bandit.lock();
        if let Err(e) = state.save(path) {
            tracing::warn!(error = %e, "failed to save bandit state");
        }
    }

    /// Return bandit diagnostic stats: `(provider_name, pulls, mean_reward)`.
    ///
    /// Returns an empty vec if bandit strategy is not active.
    #[must_use]
    pub fn bandit_stats(&self) -> Vec<(String, u64, f32)> {
        let Some(ref bandit) = self.bandit else {
            return vec![];
        };
        let state = bandit.lock();
        state.stats()
    }

    /// Enable Bayesian reputation scoring (RAPS).
    ///
    /// Loads existing state from `state_path` (or the default path), applies session-level
    /// decay, and prunes stale provider entries.
    ///
    /// No-op for Cascade routing (reputation is not used for cost-tier ordering).
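    ///
    /// A minimal usage sketch (illustrative values): 0.95 decay, 0.3 blend weight,
    /// and a floor of 5 observations before reputation influences routing:
    ///
    /// ```ignore
    /// let router = RouterProvider::new(providers)
    ///     .with_reputation(0.95, 0.3, 5, None);
    /// ```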
    #[must_use]
    pub fn with_reputation(
        mut self,
        decay_factor: f64,
        weight: f64,
        min_observations: u64,
        state_path: Option<&Path>,
    ) -> Self {
        let path = state_path.map_or_else(ReputationTracker::default_path, Path::to_path_buf);
        // Load persisted state, apply decay, and prune orphaned providers.
        let mut tracker = ReputationTracker::load(&path);
        let known: std::collections::HashSet<String> =
            self.providers.iter().map(|p| p.name().to_owned()).collect();
        tracker.apply_decay();
        tracker.prune(&known);
        // Overwrite config params (decay/min_obs may differ from the persisted defaults).
        let tracker = {
            let stats = tracker.stats();
            let mut t = ReputationTracker::new(decay_factor, min_observations);
            for (name, alpha, beta, _, obs) in stats {
                t.models.insert(
                    name,
                    reputation::ReputationEntry {
                        dist: thompson::BetaDist { alpha, beta },
                        observations: obs,
                    },
                );
            }
            t
        };
        self.reputation = Some(Arc::new(Mutex::new(tracker)));
        self.reputation_state_path = Some(path);
        self.reputation_weight = weight.clamp(0.0, 1.0);
        self
    }

    /// Record a quality outcome for the last active sub-provider (tool execution result).
    ///
    /// Call only for semantic failures (invalid tool args, parse errors).
    /// Do NOT call for network errors, rate limits, or transient I/O failures.
    /// No-op when reputation scoring is disabled, strategy is Cascade, or no tool call
    /// has been made yet in this session.
    ///
    /// The `_provider_name` parameter is ignored — quality is attributed to the sub-provider
    /// that served the most recent `chat_with_tools` call, tracked via `last_active_provider`.
    pub fn record_quality_outcome(&self, _provider_name: &str, success: bool) {
        if matches!(
            self.strategy,
            RouterStrategy::Cascade | RouterStrategy::Bandit
        ) {
            // Cascade: quality tracked via CascadeState.
            // Bandit: quality fed via bandit_record_reward() after each response.
            return;
        }
        let Some(ref reputation) = self.reputation else {
            return;
        };
        let active = self.last_active_provider.lock().clone();
        let Some(provider_name) = active else {
            return;
        };
        let mut tracker = reputation.lock();
        tracker.record_quality(&provider_name, success);
    }

    /// Persist current reputation state to disk. No-op if reputation is disabled.
    pub fn save_reputation_state(&self) {
        let (Some(reputation), Some(path)) = (&self.reputation, &self.reputation_state_path) else {
            return;
        };
        let state = reputation.lock();
        if let Err(e) = state.save(path) {
            tracing::warn!(error = %e, "failed to save reputation state");
        }
    }

    /// Return reputation stats for all tracked providers: (name, alpha, beta, mean, observations).
    #[must_use]
    pub fn reputation_stats(&self) -> Vec<(String, f64, f64, f64, u64)> {
        let Some(ref reputation) = self.reputation else {
            return vec![];
        };
        let tracker = reputation.lock();
        tracker.stats()
    }

    /// Enable Cascade routing strategy.
    ///
    /// Providers are tried in chain order (cheapest first). Each response is evaluated
    /// by the quality classifier; if it falls below `quality_threshold`, the next
    /// provider is tried. At most `max_escalations` quality-based escalations occur.
    ///
    /// Network/API errors do not count against the escalation budget.
    /// The best response seen so far is returned if all escalations are exhausted.
    ///
    /// When `config.cost_tiers` is set, providers are reordered once at construction
    /// time (no per-request cost). Providers absent from `cost_tiers` are appended
    /// after listed ones in original chain order. Unknown names in `cost_tiers` are
    /// silently ignored.
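    ///
    /// A minimal usage sketch; the tier names are hypothetical and must match actual
    /// provider names to take effect:
    ///
    /// ```ignore
    /// let router = RouterProvider::new(providers).with_cascade(CascadeRouterConfig {
    ///     cost_tiers: Some(vec!["cheap-local".into(), "premium-api".into()]),
    ///     ..CascadeRouterConfig::default()
    /// });
    /// ```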
    #[must_use]
    pub fn with_cascade(mut self, config: CascadeRouterConfig) -> Self {
        self.strategy = RouterStrategy::Cascade;

        if let Some(ref tiers) = config.cost_tiers
            && !tiers.is_empty()
        {
            let tier_pos: std::collections::HashMap<&str, usize> = tiers
                .iter()
                .enumerate()
                .map(|(i, n)| (n.as_str(), i))
                .collect();

            let before: Vec<_> = self.providers.iter().map(|p| p.name().to_owned()).collect();
            let mut indexed: Vec<(usize, AnyProvider)> =
                self.providers.iter().cloned().enumerate().collect();
            indexed.sort_by_key(|(orig_idx, p)| {
                tier_pos
                    .get(p.name())
                    .copied()
                    .map_or((1usize, *orig_idx), |t| (0, t))
            });
            let after: Vec<_> = indexed.iter().map(|(_, p)| p.name().to_owned()).collect();
            if before != after {
                tracing::debug!(
                    before = ?before,
                    after = ?after,
                    "cascade: providers reordered by cost_tiers"
                );
            }
            self.providers = Arc::from(indexed.into_iter().map(|(_, p)| p).collect::<Vec<_>>());
        }

        let window = config.window_size;
        self.cascade_state = Some(Arc::new(Mutex::new(CascadeState::new(window))));
        self.cascade_config = Some(config);
        self
    }

    /// Persist current Thompson state to disk.
    ///
    /// No-op if Thompson strategy is not active.
    ///
    /// # Note
    ///
    /// This performs synchronous I/O. Called at agent shutdown from an async context;
    /// acceptable since it runs after all in-flight requests have completed.
    // FIXME: if called mid-request, use `tokio::task::spawn_blocking` instead.
    pub fn save_thompson_state(&self) {
        let (Some(thompson), Some(path)) = (&self.thompson, &self.thompson_state_path) else {
            return;
        };
        let state = thompson.lock();
        if let Err(e) = state.save(path) {
            tracing::warn!(error = %e, "failed to save Thompson router state");
        }
    }

    /// Hash a query string to a `u64` cache key.
    fn query_hash(query: &str) -> u64 {
        use std::hash::{Hash as _, Hasher as _};
        let mut h = std::collections::hash_map::DefaultHasher::new();
        query.hash(&mut h);
        h.finish()
    }

    /// Fetch or compute the feature vector for `query` using the bandit embedding provider.
    ///
    /// Returns `None` if:
    /// - No embedding provider is configured.
    /// - The embedding call exceeds `embedding_timeout_ms`.
    /// - The embedding is shorter than `dim` or is all-zero.
    async fn bandit_features(&self, query: &str) -> Option<Vec<f32>> {
        let cfg = self.bandit_config.as_ref()?;
        let key = Self::query_hash(query);

        // Check cache first (no async needed).
        {
            let cache = self.bandit_embed_cache.lock();
            if let Some(cached) = cache.get(key) {
                return Some(cached.clone());
            }
        }

        let provider = self.bandit_embedding_provider.as_ref()?;
        let timeout = std::time::Duration::from_millis(cfg.embedding_timeout_ms);
        let embed_future = provider.embed(query);
        let embedding = match tokio::time::timeout(timeout, embed_future).await {
            Ok(Ok(emb)) => emb,
            Ok(Err(e)) => {
                tracing::debug!(error = %e, "bandit: embedding failed, falling back");
                return None;
            }
            Err(_) => {
                tracing::debug!(
                    timeout_ms = cfg.embedding_timeout_ms,
                    "bandit: embedding timed out, falling back"
                );
                return None;
            }
        };

        let features = embedding_to_features(&embedding, cfg.dim)?;

        // Insert into cache.
        {
            let mut cache = self.bandit_embed_cache.lock();
            cache.insert(key, features.clone());
        }
        Some(features)
    }

    /// Select a provider using `LinUCB` bandit, with Thompson fallback on cold start / missing features.
    ///
    /// Falls through to Thompson or first available provider when bandit cannot decide.
    /// Budget enforcement via global `CostTracker` is handled at the caller level.
    /// Per-provider budget fractions are intentionally NOT implemented (scope creep, see #2230).
    async fn bandit_select_provider(&self, query: &str) -> Option<AnyProvider> {
        let Some(ref bandit_arc) = self.bandit else {
            return self.providers.first().cloned();
        };
        let cfg = self.bandit_config.as_ref()?;

        let names: Vec<String> = self.providers.iter().map(|p| p.name().to_owned()).collect();

        // Try LinUCB selection with feature vector.
        if let Some(features) = self.bandit_features(query).await {
            let memory_confidence = self.last_memory_confidence.lock().as_ref().copied();
            let selected = {
                let state = bandit_arc.lock();
                state.select(
                    &names,
                    &features,
                    cfg.alpha,
                    cfg.warmup_queries,
                    &|_| true,
                    cfg.cost_weight,
                    &self.provider_models,
                    memory_confidence,
                    cfg.memory_confidence_threshold,
                )
            };
            if let Some(name) = selected {
                tracing::debug!(
                    provider = %name,
                    strategy = "bandit",
                    memory_confidence = ?memory_confidence,
                    "selected provider"
                );
                return self.providers.iter().find(|p| p.name() == name).cloned();
            }
        }

        // Fallback: Thompson sampling.
        if let Some(ref thompson) = self.thompson {
            let mut state = thompson.lock();
            if let Some(sel) = state.select(&names) {
                tracing::debug!(
                    provider = %sel.provider,
                    strategy = "bandit-fallback-thompson",
                    "selected provider"
                );
                return self
                    .providers
                    .iter()
                    .find(|p| p.name() == sel.provider)
                    .cloned();
            }
        }

        // Last resort: first provider.
        self.providers.first().cloned()
    }

    /// Record the bandit reward for a completed request.
    ///
    /// `quality_score`: heuristic quality in [0, 1] from `heuristic_score()`.
    /// `cost_fraction`: `request_cost_cents / max_daily_cents` (0 when budget is unlimited).
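    ///
    /// Illustrative arithmetic: with `cost_weight = 0.1`, `quality_score = 0.9`, and
    /// `cost_fraction = 0.5`, the reward is `0.9 - 0.1 * 0.5 = 0.85` before clamping.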
    fn bandit_record_reward(
        &self,
        provider_name: &str,
        features: &[f32],
        quality_score: f64,
        cost_fraction: f64,
    ) {
        let Some(ref bandit_arc) = self.bandit else {
            return;
        };
        let Some(cfg) = &self.bandit_config else {
            return;
        };
        #[allow(clippy::cast_possible_truncation)]
        let reward = (quality_score as f32) - cfg.cost_weight * (cost_fraction as f32);
        let reward = reward.clamp(-1.0, 1.0);
        let mut state = bandit_arc.lock();
        state.update(provider_name, features, reward);
        tracing::debug!(
            provider = provider_name,
            reward,
            quality = quality_score,
            "bandit: recorded reward"
        );
    }

    fn ordered_providers(&self) -> Vec<AnyProvider> {
        match self.strategy {
            RouterStrategy::Thompson => self.thompson_ordered_providers(),
            RouterStrategy::Ema => self.ema_ordered_providers(),
            // Cascade/Bandit: sync path used only for debug_request_json(); hot paths use
            // dedicated async selection methods. For Cascade, providers are sorted at
            // construction time.
            RouterStrategy::Cascade | RouterStrategy::Bandit => self.providers.to_vec(),
        }
    }

    fn ema_ordered_providers(&self) -> Vec<AnyProvider> {
        let order = self.provider_order.lock();
        let mut ordered: Vec<AnyProvider> = order
            .iter()
            .filter_map(|&i| self.providers.get(i).cloned())
            .collect();

        // CRIT-2 fix: apply reputation as a multiplicative adjustment to the EMA score,
        // not an additive term. This avoids unbounded score inflation.
        //
        // Adjustment formula: ema_score * (1 + weight * (rep_factor - 0.5) * 2)
        // where rep_factor in [0,1]: 0.5 = neutral, >0.5 = positive, <0.5 = negative.
        // CRIT-1 fix: reputation factor is sampled per-provider (each has its own Beta mean).
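        //
        // Worked example (illustrative numbers): with weight = 0.3 and rep_factor = 0.8,
        // adjustment = 1.0 + 0.3 * (0.8 - 0.5) * 2.0 = 1.18; rep_factor = 0.2 gives 0.82.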
        if let Some(ref reputation) = self.reputation
            && let Some(ref ema) = self.ema
        {
            let rep = reputation.lock();
            let w = self.reputation_weight;
            let snap = ema.snapshot();
            let mut scored: Vec<(usize, f64)> = ordered
                .iter()
                .enumerate()
                .map(|(idx, p)| {
                    let ema_score = snap
                        .get(p.name())
                        .map_or(0.0, |s| s.success_ema - s.latency_ema_ms / 10_000.0);
                    let score = if let Some(rep_factor) = rep.ema_reputation_factor(p.name()) {
                        // Multiplicative blend: neutral at rep_factor=0.5, range ±weight.
                        let adjustment = 1.0 + w * (rep_factor - 0.5) * 2.0;
                        ema_score * adjustment
                    } else {
                        ema_score
                    };
                    (idx, score)
                })
                .collect();
            scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
            let reordered: Vec<AnyProvider> = scored
                .into_iter()
                .filter_map(|(idx, _)| ordered.get(idx).cloned())
                .collect();
            ordered = reordered;
        }

        // ASI: re-score by down-weighting providers with low coherence.
        if let (Some(asi_arc), Some(asi_cfg)) = (&self.asi, &self.asi_config) {
            let asi: parking_lot::MutexGuard<'_, AsiState> = asi_arc.lock();
            let snap = self.ema.as_ref().map(EmaTracker::snapshot);
            let mut scored: Vec<(usize, f64)> = ordered
                .iter()
                .enumerate()
                .map(|(idx, p)| {
                    let coherence = asi.coherence(p.name());
                    if coherence < asi_cfg.coherence_threshold {
                        tracing::warn!(
                            provider = p.name(),
                            coherence,
                            threshold = asi_cfg.coherence_threshold,
                            "asi: coherence below threshold"
                        );
                    }
                    let base_score = snap
                        .as_ref()
                        .and_then(|s| s.get(p.name()))
                        .map_or(0.0, |s| s.success_ema - s.latency_ema_ms / 10_000.0);
                    // Multiply EMA score by coherence multiplier clamped to [0.5, 1.0].
                    let multiplier = (coherence / asi_cfg.coherence_threshold).clamp(0.5, 1.0);
                    let adjusted = base_score * f64::from(multiplier);
                    (idx, adjusted)
                })
                .collect();
            scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
            let reordered: Vec<AnyProvider> = scored
                .into_iter()
                .filter_map(|(idx, _)| ordered.get(idx).cloned())
                .collect();
            ordered = reordered;
        }

        if let Some(first) = ordered.first() {
            tracing::debug!(
                provider = %first.name(),
                strategy = "ema",
                "selected provider"
            );
        }
        ordered
    }

    fn thompson_ordered_providers(&self) -> Vec<AnyProvider> {
        let Some(ref thompson) = self.thompson else {
            return self.providers.to_vec();
        };
        let mut state = thompson.lock();
        let names: Vec<String> = self.providers.iter().map(|p| p.name().to_owned()).collect();

        // Compute per-provider prior overrides: start from base Beta distribution, apply
        // reputation shift (CRIT-3), then apply ASI coherence penalty.
        let has_reputation = self.reputation.is_some();
        let has_asi = self.asi.is_some() && self.asi_config.is_some();

        let selected = if has_reputation || has_asi {
            // Build overrides by composing reputation and ASI adjustments.
            let rep_guard = self.reputation.as_ref().map(|r| r.lock());
            let asi_guard: Option<parking_lot::MutexGuard<'_, AsiState>> =
                self.asi.as_ref().map(|a| a.lock());
            let w = self.reputation_weight;

            let overrides: std::collections::HashMap<String, (f64, f64)> = names
                .iter()
                .map(|name| {
                    let base = state.get_distribution(name);
                    // Apply reputation prior shift.
                    let (alpha, mut beta) = if let Some(ref rep) = rep_guard {
                        rep.shift_thompson_priors(name, base.alpha, base.beta, w)
                    } else {
                        (base.alpha, base.beta)
                    };
                    // Apply ASI coherence penalty: shift beta by penalty_weight * deficit.
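                    // Illustrative numbers: threshold 0.7 and coherence 0.5 give a
                    // deficit of 0.2; with penalty_weight 0.3, beta grows by 0.06.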
                    if let (Some(asi), Some(asi_cfg)) = (&asi_guard, &self.asi_config) {
                        let coherence = asi.coherence(name);
                        if coherence < asi_cfg.coherence_threshold {
                            tracing::warn!(
                                provider = name.as_str(),
                                coherence,
                                threshold = asi_cfg.coherence_threshold,
                                "asi: coherence below threshold"
                            );
                            let deficit = asi_cfg.coherence_threshold - coherence;
                            let penalty = f64::from(asi_cfg.penalty_weight * deficit);
                            beta += penalty;
                        }
                    }
                    (name.clone(), (alpha, beta))
                })
                .collect();

            drop(rep_guard);
            drop(asi_guard);
            state.select_with_priors(&names, &overrides)
        } else {
            state.select(&names)
        };

        if let Some(ref sel) = selected {
            tracing::debug!(
                provider = %sel.provider,
                strategy = "thompson",
                mode = if sel.exploit { "exploit" } else { "explore" },
                alpha = sel.alpha,
                beta = sel.beta,
                "selected provider"
            );
        }
        // Put selected provider first, keep rest in original order.
        let mut ordered = self.providers.to_vec();
        if let Some(ref sel) = selected
            && let Some(pos) = ordered.iter().position(|p| p.name() == sel.provider)
        {
            ordered.swap(0, pos);
        }
        ordered
    }

    /// Record availability outcome (network success/failure) for EMA or Thompson.
    ///
    /// For cascade routing, quality outcomes are tracked separately in `CascadeState`.
    /// Only availability outcomes (API up/down) are recorded here to avoid corrupting
    /// Thompson/EMA distributions with quality-based failures (HIGH-01).
    fn record_availability(&self, provider_name: &str, success: bool, latency_ms: u64) {
        match self.strategy {
            RouterStrategy::Thompson => {
                if let Some(ref thompson) = self.thompson {
                    let mut state = thompson.lock();
                    state.update(provider_name, success);
                }
            }
            RouterStrategy::Ema => {
                self.ema_record(provider_name, success, latency_ms);
            }
            RouterStrategy::Cascade | RouterStrategy::Bandit => {
                // Cascade does not use Thompson/EMA for ordering; no-op.
                // Bandit tracks rewards separately via bandit_record_reward().
            }
        }
    }

    fn ema_record(&self, provider_name: &str, success: bool, latency_ms: u64) {
        let Some(ref ema) = self.ema else {
            return;
        };
        ema.record(provider_name, success, latency_ms);
        let current_names: Vec<String> =
            self.providers.iter().map(|p| p.name().to_owned()).collect();
        if let Some(new_order_names) = ema.maybe_reorder(&current_names) {
            let name_to_idx: std::collections::HashMap<&str, usize> = self
                .providers
                .iter()
                .enumerate()
                .map(|(i, p)| (p.name(), i))
                .collect();
            let new_order: Vec<usize> = new_order_names
                .iter()
                .filter_map(|n| name_to_idx.get(n.as_str()).copied())
                .collect();
            let mut order = self.provider_order.lock();
            *order = new_order;
        }
    }

    /// Return a snapshot of Thompson distribution parameters for all tracked providers.
    ///
    /// Returns an empty vec if Thompson strategy is not active.
    #[must_use]
    pub fn thompson_stats(&self) -> Vec<(String, f64, f64)> {
        let Some(ref thompson) = self.thompson else {
            return vec![];
        };
        let state = thompson.lock();
        state.provider_stats()
    }

    pub fn set_status_tx(&mut self, tx: StatusTx) {
        if let Some(providers) = Arc::get_mut(&mut self.providers) {
            for p in providers {
                p.set_status_tx(tx.clone());
            }
        } else {
            // Defensive path: should never happen at bootstrap (refcount == 1).
            let mut v: Vec<_> = self.providers.iter().cloned().collect();
            for p in &mut v {
                p.set_status_tx(tx.clone());
            }
            self.providers = Arc::from(v);
        }
        self.status_tx = Some(tx);
    }

    /// Aggregate model lists from all sub-providers, deduplicating by id.
    ///
    /// Individual sub-provider errors are logged as warnings and skipped.
    ///
    /// # Errors
    ///
    /// Always succeeds (per-provider errors are swallowed).
    pub async fn list_models_remote(
        &self,
    ) -> Result<Vec<crate::model_cache::RemoteModelInfo>, LlmError> {
        let mut seen = std::collections::HashSet::new();
        let mut all = Vec::new();
        for p in self.providers.iter() {
            match p.list_models_remote().await {
                Ok(models) => {
                    for m in models {
                        if seen.insert(m.id.clone()) {
                            all.push(m);
                        }
                    }
                }
                Err(e) => {
                    tracing::warn!(error = %e, "router: list_models_remote sub-provider failed");
                }
            }
        }
        Ok(all)
    }

    /// Evaluate quality with heuristics only.
    fn evaluate_heuristic(response: &str, threshold: f64) -> cascade::QualityVerdict {
        let mut verdict = heuristic_score(response);
        verdict.should_escalate = verdict.score < threshold;
        verdict
    }

    /// Evaluate quality using the configured classifier mode.
    ///
    /// For `ClassifierMode::Judge`, calls the summary provider and falls back to heuristic
    /// on any error. For `ClassifierMode::Heuristic`, evaluates synchronously.
    async fn evaluate_quality(
        response: &str,
        threshold: f64,
        mode: ClassifierMode,
        summary_provider: Option<&AnyProvider>,
    ) -> cascade::QualityVerdict {
        if mode == ClassifierMode::Judge {
            if let Some(judge) = summary_provider {
                match cascade::judge_score(judge, response).await {
                    Some(score) => {
                        let should_escalate = score < threshold;
                        tracing::debug!(
                            score,
                            threshold,
                            should_escalate,
                            "cascade: judge scored response"
                        );
                        return cascade::QualityVerdict {
                            score,
                            should_escalate,
                            reason: format!("judge score: {score:.2}"),
                        };
                    }
                    None => {
                        tracing::warn!("cascade: judge call failed, falling back to heuristic");
                    }
                }
            } else {
                tracing::warn!(
                    "cascade: classifier_mode=judge but no summary_provider configured, \
                     using heuristic"
                );
            }
        }
        Self::evaluate_heuristic(response, threshold)
    }
}

const EMBED_MAX_RETRIES: u32 = 3;
const EMBED_BASE_DELAY_MS: u64 = 500;

impl RouterProvider {
    /// Spawn a background task to embed `response` and update the ASI window for `provider`.
    ///
    /// Fire-and-forget: routing is not blocked on the embed call. If the embed fails,
    /// the ASI window is not updated (no penalty for embed failure).
    ///
    /// `turn_id` is used to debounce: at most one ASI update fires per turn even when
    /// `chat()` is called N times concurrently (e.g., tool schema fetches). Subsequent
    /// calls within the same turn are silently dropped.
    fn spawn_asi_update(&self, provider: &str, response: String, turn_id: u64) {
        // Debounce: swap in turn_id; if the previous value equals turn_id, another call
        // already claimed this turn → drop silently. `swap` is atomic so exactly one
        // concurrent caller wins the "first for this turn" race.
        let prev = self.asi_last_turn.swap(turn_id, Ordering::AcqRel);
        if prev == turn_id {
            return;
        }

        let Some(ref asi_arc) = self.asi else { return };
        let Some(ref asi_cfg) = self.asi_config else {
            return;
        };
        let asi = Arc::clone(asi_arc);
        let router = self.clone();
        let window_size = asi_cfg.window;
        let provider_name = provider.to_owned();
        tokio::spawn(async move {
            match router.embed(&response).await {
                Ok(emb) => {
                    let mut state = asi.lock();
                    state.push_embedding(&provider_name, emb, window_size);
                }
                Err(e) => {
                    tracing::debug!(
                        provider = provider_name,
                        error = %e,
                        "asi: embed failed, skipping coherence update"
                    );
                }
            }
        });
    }
}

impl LlmProvider for RouterProvider {
    fn context_window(&self) -> Option<usize> {
        self.providers.first().and_then(LlmProvider::context_window)
    }

    fn chat(
        &self,
        messages: &[Message],
    ) -> impl std::future::Future<Output = Result<String, LlmError>> + Send {
        let status_tx = self.status_tx.clone();
        let messages = messages.to_vec();
        let router = self.clone();
        // TODO: DRY — `chat` and `chat_stream` share the same fallback loop pattern.
        // Refactor into a shared helper once the API stabilizes.
        Box::pin(async move {
            // Increment turn counter once per top-level chat() call. All concurrent sub-calls
            // (tool schema fetches, embed probes) that re-enter chat() will see the same
            // turn_id via the shared Arc<AtomicU64>, enabling ASI debounce.
            let turn_id = router.turn_counter.fetch_add(1, Ordering::Relaxed);

            if router.strategy == RouterStrategy::Cascade {
                // Cascade: pass Arc slice directly — providers are sorted at construction,
                // so no Vec allocation needed on the hot path.
                return router
                    .cascade_chat(&router.providers, &messages, status_tx)
                    .await;
            }
            if router.strategy == RouterStrategy::Bandit {
                return router.bandit_chat(&messages, status_tx).await;
            }
            let providers = router.ordered_providers();

            // Pre-compute query embedding once for quality gate (fail-open on error).
            let query_text = messages
                .last()
                .map(Message::to_llm_content)
                .unwrap_or_default();
            let query_embedding = if router.quality_gate.is_some() && !query_text.is_empty() {
                router.embed(query_text).await.ok()
            } else {
                None
            };

            // Best response seen so far (for quality gate exhaustion fallback, M2).
            let mut best_response: Option<(f32, String)> = None;

            for p in &providers {
                let start = std::time::Instant::now();
                match p.chat(&messages).await {
                    Ok(r) => {
                        router.record_availability(
                            p.name(),
                            true,
                            u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
                        );

                        // Quality gate: check response-query embedding similarity.
                        if let (Some(threshold), Some(qemb)) =
                            (router.quality_gate, &query_embedding)
                        {
                            let resp_emb = router.embed(&r).await.ok();
                            let similarity = resp_emb
                                .as_ref()
                                .map_or(threshold, |e| cosine_similarity(qemb, e)); // fail-open: None → treat as passing
                            if similarity < threshold {
                                tracing::info!(
                                    provider = p.name(),
                                    score = similarity,
                                    threshold,
                                    "thompson_quality_fallback"
                                );
                                // Track best response seen so far.
                                let is_better = best_response
                                    .as_ref()
                                    .is_none_or(|(best, _)| similarity > *best);
                                if is_better {
                                    best_response = Some((similarity, r.clone()));
                                }
                                // Spawn ASI update even on quality failure.
                                router.spawn_asi_update(p.name(), r, turn_id);
                                continue;
                            }
                        }

                        // Spawn ASI embedding update (fire-and-forget).
                        router.spawn_asi_update(p.name(), r.clone(), turn_id);

                        return Ok(r);
                    }
                    Err(e) => {
                        router.record_availability(
                            p.name(),
                            false,
                            u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
                        );
                        if let Some(ref tx) = status_tx {
                            let _ = tx.send(format!("router: {} failed, falling back", p.name()));
                        }
                        tracing::warn!(provider = p.name(), error = %e, "router fallback");
                    }
                }
            }

            // All providers exhausted by quality gate: return best response seen (M2).
            if let Some((_, response)) = best_response {
                return Ok(response);
            }

            Err(LlmError::NoProviders)
        })
    }

    fn chat_stream(
        &self,
        messages: &[Message],
    ) -> impl std::future::Future<Output = Result<ChatStream, LlmError>> + Send {
        let status_tx = self.status_tx.clone();
        let messages = messages.to_vec();
        let router = self.clone();
        Box::pin(async move {
            if router.strategy == RouterStrategy::Cascade {
                // Cascade: pass Arc slice directly — no Vec allocation on the hot path.
                return router
                    .cascade_chat_stream(&router.providers, &messages, status_tx)
                    .await;
            }
            if router.strategy == RouterStrategy::Bandit {
                // Bandit stream: select provider then stream from it.
                // Reward is not recorded for streams (stream completion is async);
                // this is a known pre-1.0 limitation — same as Thompson stream mode.
                let query = messages
                    .last()
                    .map(Message::to_llm_content)
1280                    .unwrap_or_default();
1281                let p = router
1282                    .bandit_select_provider(query)
1283                    .await
1284                    .ok_or(LlmError::NoProviders)?;
1285                return p.chat_stream(&messages).await;
1286            }
1287            let providers = router.ordered_providers();
1288            for p in &providers {
1289                let start = std::time::Instant::now();
1290                match p.chat_stream(&messages).await {
1291                    Ok(r) => {
1292                        // NOTE: success is recorded at stream-open time, not on stream
1293                        // completion. A provider that opens the stream but then fails
1294                        // mid-delivery still gets alpha += 1. This is a known pre-1.0
1295                        // limitation: fixing it requires wrapping ChatStream to intercept
1296                        // the completion/error signal, which adds latency on the hot path.
1297                        // Tracked in the adaptive-inference epic (CRIT-2); see the hedged sketch after this function.
1298                        router.record_availability(
1299                            p.name(),
1300                            true,
1301                            u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
1302                        );
1303                        return Ok(r);
1304                    }
1305                    Err(e) => {
1306                        router.record_availability(
1307                            p.name(),
1308                            false,
1309                            u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
1310                        );
1311                        if let Some(ref tx) = status_tx {
1312                            let _ = tx.send(format!("router: {} failed, falling back", p.name()));
1313                        }
1314                        tracing::warn!(provider = p.name(), error = %e, "router stream fallback");
1315                    }
1316                }
1317            }
1318            Err(LlmError::NoProviders)
1319        })
1320    }
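    // A hedged sketch (an assumption, not the shipped design) for the CRIT-2 fix
    // referenced above: record availability when the stream finishes rather than
    // when it opens, via a guard that reports on drop. All names below are
    // illustrative only.
    //
    //     struct StreamGuard {
    //         router: RouterProvider,
    //         provider_name: String,
    //         start: std::time::Instant,
    //         saw_error: bool,
    //     }
    //
    //     impl Drop for StreamGuard {
    //         fn drop(&mut self) {
    //             let latency = u64::try_from(self.start.elapsed().as_millis())
    //                 .unwrap_or(u64::MAX);
    //             self.router
    //                 .record_availability(&self.provider_name, !self.saw_error, latency);
    //         }
    //     }
    //
    // The returned ChatStream would then wrap the inner stream together with the
    // guard, setting `saw_error` whenever an Err chunk is yielded, so mid-delivery
    // failures no longer count as successes.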
1321
1322    fn supports_streaming(&self) -> bool {
1323        self.providers.iter().any(LlmProvider::supports_streaming)
1324    }
1325
1326    fn embed(
1327        &self,
1328        text: &str,
1329    ) -> impl std::future::Future<Output = Result<Vec<f32>, LlmError>> + Send {
1330        let providers = self.ordered_providers();
1331        let status_tx = self.status_tx.clone();
1332        let text = text.to_owned();
1333        let router = self.clone();
1334        Box::pin(async move {
1335            for p in &providers {
1336                if !p.supports_embeddings() {
1337                    continue;
1338                }
1339                let mut last_err: Option<LlmError> = None;
1340                for attempt in 0..=EMBED_MAX_RETRIES {
1341                    if attempt > 0 {
1342                        let delay = EMBED_BASE_DELAY_MS * (1u64 << (attempt - 1));
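                        // Hypothetical schedule: with a 500 ms base the waits are
                        // 500 ms, 1 s, 2 s for attempts 1..=3; actual values depend on
                        // EMBED_BASE_DELAY_MS and EMBED_MAX_RETRIES (defined elsewhere).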
1343                        tracing::warn!(
1344                            provider = p.name(),
1345                            attempt,
1346                            delay_ms = delay,
1347                            "embed: rate limited, retrying after backoff"
1348                        );
1349                        tokio::time::sleep(std::time::Duration::from_millis(delay)).await;
1350                    }
1351                    let start = std::time::Instant::now();
1352                    match p.embed(&text).await {
1353                        Ok(r) => {
1354                            router.record_availability(
1355                                p.name(),
1356                                true,
1357                                u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
1358                            );
1359                            return Ok(r);
1360                        }
1361                        Err(e) if e.is_invalid_input() => {
1362                            // The input itself is invalid — retrying on another provider
1363                            // would fail identically. Do not penalize provider reputation.
1364                            tracing::warn!(
1365                                provider = p.name(),
1366                                error = %e,
1367                                "embed: invalid input, not retrying on other providers"
1368                            );
1369                            return Err(e);
1370                        }
1371                        Err(e) if e.is_rate_limited() && attempt < EMBED_MAX_RETRIES => {
1372                            last_err = Some(e);
1373                        }
1374                        Err(e) => {
1375                            router.record_availability(
1376                                p.name(),
1377                                false,
1378                                u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
1379                            );
1380                            if let Some(ref tx) = status_tx {
1381                                let _ = tx.send(format!(
1382                                    "router: {} embed failed, falling back",
1383                                    p.name()
1384                                ));
1385                            }
1386                            tracing::warn!(provider = p.name(), error = %e, "router embed fallback");
1387                            last_err = Some(e);
1388                            break;
1389                        }
1390                    }
1391                }
1392                // All retries exhausted for this provider (rate-limited every time).
1393                if matches!(last_err, Some(ref e) if e.is_rate_limited()) {
1394                    router.record_availability(p.name(), false, 0);
1395                    if let Some(ref tx) = status_tx {
1396                        let _ = tx.send(format!(
1397                            "router: {} embed rate limited, falling back",
1398                            p.name()
1399                        ));
1400                    }
1401                    tracing::warn!(
1402                        provider = p.name(),
1403                        "embed: rate limit retries exhausted, falling back"
1404                    );
1405                }
1406            }
1407            Err(LlmError::NoProviders)
1408        })
1409    }
1410
1411    fn embed_batch(
1412        &self,
1413        texts: &[&str],
1414    ) -> impl std::future::Future<Output = Result<Vec<Vec<f32>>, LlmError>> + Send {
1415        let providers = self.ordered_providers();
1416        let status_tx = self.status_tx.clone();
1417        let owned = owned_strs(texts);
1418        let router = self.clone();
1419        let semaphore = self.embed_semaphore.clone();
1420        Box::pin(async move {
1421            // Acquire embed semaphore permit before any HTTP work to cap concurrency.
1422            let _permit = if let Some(ref sem) = semaphore {
1423                Some(sem.acquire().await.map_err(|_| LlmError::NoProviders)?)
1424            } else {
1425                None
1426            };
1427            let refs: Vec<&str> = owned.iter().map(String::as_str).collect();
1428            for p in &providers {
1429                if !p.supports_embeddings() {
1430                    continue;
1431                }
1432                let mut last_err: Option<LlmError> = None;
1433                for attempt in 0..=EMBED_MAX_RETRIES {
1434                    if attempt > 0 {
1435                        let delay = EMBED_BASE_DELAY_MS * (1u64 << (attempt - 1));
1436                        tracing::warn!(
1437                            provider = p.name(),
1438                            attempt,
1439                            delay_ms = delay,
1440                            "embed_batch: rate limited, retrying after backoff"
1441                        );
1442                        tokio::time::sleep(std::time::Duration::from_millis(delay)).await;
1443                    }
1444                    let start = std::time::Instant::now();
1445                    match p.embed_batch(&refs).await {
1446                        Ok(r) => {
1447                            router.record_availability(
1448                                p.name(),
1449                                true,
1450                                u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
1451                            );
1452                            return Ok(r);
1453                        }
1454                        Err(e) if e.is_invalid_input() => {
1455                            tracing::warn!(
1456                                provider = p.name(),
1457                                error = %e,
1458                                "embed_batch: invalid input, not retrying on other providers"
1459                            );
1460                            return Err(e);
1461                        }
1462                        Err(e) if e.is_rate_limited() && attempt < EMBED_MAX_RETRIES => {
1463                            last_err = Some(e);
1464                        }
1465                        Err(e) => {
1466                            router.record_availability(
1467                                p.name(),
1468                                false,
1469                                u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
1470                            );
1471                            if let Some(ref tx) = status_tx {
1472                                let _ = tx.send(format!(
1473                                    "router: {} embed_batch failed, falling back",
1474                                    p.name()
1475                                ));
1476                            }
1477                            tracing::warn!(
1478                                provider = p.name(),
1479                                error = %e,
1480                                "router embed_batch fallback"
1481                            );
1482                            last_err = Some(e);
1483                            break;
1484                        }
1485                    }
1486                }
1487                // All retries exhausted for this provider (rate-limited every time).
1488                if matches!(last_err, Some(ref e) if e.is_rate_limited()) {
1489                    router.record_availability(p.name(), false, 0);
1490                    if let Some(ref tx) = status_tx {
1491                        let _ = tx.send(format!(
1492                            "router: {} embed_batch rate limited, falling back",
1493                            p.name()
1494                        ));
1495                    }
1496                    tracing::warn!(
1497                        provider = p.name(),
1498                        "embed_batch: rate limit retries exhausted, falling back"
1499                    );
1500                }
1501            }
1502            Err(LlmError::NoProviders)
1503        })
1504    }
1505
1506    fn supports_embeddings(&self) -> bool {
1507        self.providers.iter().any(LlmProvider::supports_embeddings)
1508    }
1509
1510    #[allow(clippy::unnecessary_literal_bound)]
1511    fn name(&self) -> &str {
1512        "router"
1513    }
1514
1515    fn supports_tool_use(&self) -> bool {
1516        self.providers.iter().any(LlmProvider::supports_tool_use)
1517    }
1518
1519    fn list_models(&self) -> Vec<String> {
1520        self.providers
1521            .iter()
1522            .flat_map(super::provider::LlmProvider::list_models)
1523            .collect()
1524    }
1525
1526    #[allow(refining_impl_trait_reachable)]
1527    fn chat_with_tools(
1528        &self,
1529        messages: &[Message],
1530        tools: &[ToolDefinition],
1531    ) -> impl std::future::Future<Output = Result<ChatResponse, LlmError>> + Send {
1532        let messages = messages.to_vec();
1533        let tools = tools.to_vec();
1534        let status_tx = self.status_tx.clone();
1535        let router = self.clone();
1536        Box::pin(async move {
1537            // Bandit routing for tool calls: select a single provider, no quality escalation.
1538            if router.strategy == RouterStrategy::Bandit {
1539                let query = messages
1540                    .last()
1541                    .map(super::provider::Message::to_llm_content)
1542                    .unwrap_or_default();
1543                let p = router
1544                    .bandit_select_provider(query)
1545                    .await
1546                    .ok_or(LlmError::NoProviders)?;
1547                if !p.supports_tool_use() {
1548                    return Err(LlmError::NoProviders);
1549                }
1550                let result = p.chat_with_tools(&messages, &tools).await;
1551                if result.is_ok() {
1552                    *router.last_active_provider.lock() = Some(p.name().to_owned());
1553                }
1554                return result;
1555            }
1556
1557            // Cascade is intentionally skipped for tool calls: judging the quality of
1558            // a tool-call response (structured JSON with tool name + args) requires
1559            // different heuristics than free text, and skipping cascade here avoids
1560            // inappropriate escalation based on text signals (HIGH-04).
1561            let providers = router.ordered_providers();
1562            for p in &providers {
1563                if !p.supports_tool_use() {
1564                    continue;
1565                }
1566                let start = std::time::Instant::now();
1567                match p.chat_with_tools(&messages, &tools).await {
1568                    Ok(r) => {
1569                        router.record_availability(
1570                            p.name(),
1571                            true,
1572                            u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
1573                        );
1574                        // Track which sub-provider served this tool call for reputation attribution.
1575                        *router.last_active_provider.lock() = Some(p.name().to_owned());
1576                        return Ok(r);
1577                    }
1578                    Err(e) => {
1579                        router.record_availability(
1580                            p.name(),
1581                            false,
1582                            u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
1583                        );
1584                        if let Some(ref tx) = status_tx {
1585                            let _ = tx.send(format!(
1586                                "router: {} tool call failed, falling back",
1587                                p.name()
1588                            ));
1589                        }
1590                        tracing::warn!(provider = p.name(), error = %e, "router tool fallback");
1591                    }
1592                }
1593            }
1594            Err(LlmError::NoProviders)
1595        })
1596    }
1597
1598    fn debug_request_json(
1599        &self,
1600        messages: &[Message],
1601        tools: &[ToolDefinition],
1602        stream: bool,
1603    ) -> serde_json::Value {
1604        let candidate = if tools.is_empty() {
1605            self.ordered_providers().into_iter().next()
1606        } else {
1607            self.ordered_providers()
1608                .into_iter()
1609                .find(super::provider::LlmProvider::supports_tool_use)
1610        };
1611        candidate.map_or_else(
1612            || crate::provider::default_debug_request_json(messages, tools),
1613            |provider| provider.debug_request_json(messages, tools, stream),
1614        )
1615    }
1616
1617    fn last_cache_usage(&self) -> Option<(u64, u64)> {
1618        None
1619    }
1620}
1621
1622// ── Bandit routing helpers ────────────────────────────────────────────────────
1623
1624impl RouterProvider {
1625    /// Bandit `chat()` implementation: select provider, call, record reward.
1626    async fn bandit_chat(
1627        &self,
1628        messages: &[Message],
1629        status_tx: Option<StatusTx>,
1630    ) -> Result<String, LlmError> {
1631        let query = messages
1632            .last()
1633            .map(super::provider::Message::to_llm_content)
1634            .unwrap_or_default();
1635        let features = self.bandit_features(query.as_ref()).await;
1636
1637        let p = self
1638            .bandit_select_provider(query.as_ref())
1639            .await
1640            .ok_or(LlmError::NoProviders)?;
1641
1642        if let Some(ref tx) = status_tx {
1643            let _ = tx.send(format!("bandit: routing to {}", p.name()));
1644        }
1645
1646        let result = p.chat(messages).await;
1647        match &result {
1648            Ok(response) => {
1649                let verdict = heuristic_score(response);
1650                // Record reward even when embedding failed (use zero vector so the arm's
1651                // update count increments — prevents permanent cold-start on flaky embedders).
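                // (Side note: with a zero context vector a standard LinUCB update
                // leaves A and b unchanged, since both x·xᵀ and r·x vanish; only the
                // arm's bookkeeping advances. See the bandit module for the real update.)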
1652                let feat_ref: &[f32];
1653                let zero_vec: Vec<f32>;
1654                let dim = self.bandit_config.as_ref().map_or(32, |c| c.dim);
1655                if let Some(ref feat) = features {
1656                    feat_ref = feat;
1657                } else {
1658                    zero_vec = vec![0.0; dim];
1659                    feat_ref = &zero_vec;
1660                    tracing::debug!(
1661                        provider = p.name(),
1662                        "bandit: recording reward with zero features (embed unavailable)"
1663                    );
1664                }
1665                self.bandit_record_reward(p.name(), feat_ref, verdict.score, 0.0);
1666            }
1667            Err(e) => {
1668                tracing::warn!(provider = p.name(), error = %e, "bandit: provider failed");
1669            }
1670        }
1671        result
1672    }
1673}
1674
1675// ── Cascade routing helpers ───────────────────────────────────────────────────
1676
1677/// Outcome of evaluating one provider's response during cascade routing.
1678struct CascadeEvalResult {
1679    verdict: cascade::QualityVerdict,
1680    /// Updated token counter after adding this response's estimated cost.
1681    tokens_used: u32,
1682    /// Whether the token budget is now exhausted.
1683    budget_exhausted: bool,
1684}
1685
1686/// Evaluate a cascade response: score it, record the verdict in shared state, and
1687/// compute whether the token budget is exhausted.
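///
/// Worked example (hypothetical numbers, using the ~4-chars-per-token floor from
/// `zeph_common::text::estimate_tokens` as exercised in the tests below):
///
/// ```text
/// response: 26 chars                → 26 / 4 = 6 estimated tokens (min 1)
/// tokens_used = 10 (before) + 6     = 16
/// cfg.max_cascade_tokens = Some(16) → budget_exhausted (16 >= 16)
/// ```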
1688async fn cascade_evaluate_response(
1689    provider_name: &str,
1690    response: &str,
1691    cfg: &CascadeRouterConfig,
1692    cascade_state: &Mutex<CascadeState>,
1693    tokens_used_before: u32,
1694    log_prefix: &str,
1695) -> CascadeEvalResult {
1696    let estimated_tokens =
1697        u32::try_from(zeph_common::text::estimate_tokens(response).max(1)).unwrap_or(u32::MAX);
1698    let tokens_used = tokens_used_before.saturating_add(estimated_tokens);
1699
1700    let verdict = RouterProvider::evaluate_quality(
1701        response,
1702        cfg.quality_threshold,
1703        cfg.classifier_mode,
1704        cfg.summary_provider.as_ref(),
1705    )
1706    .await;
1707
1708    {
1709        let mut state = cascade_state.lock();
1710        state.record(provider_name, verdict.score);
1711    }
1712
1713    tracing::debug!(
1714        provider = %provider_name,
1715        score = verdict.score,
1716        threshold = cfg.quality_threshold,
1717        should_escalate = verdict.should_escalate,
1718        reason = %verdict.reason,
1719        "{log_prefix}: quality verdict"
1720    );
1721
1722    let budget_exhausted = cfg
1723        .max_cascade_tokens
1724        .is_some_and(|budget| tokens_used >= budget);
1725
1726    CascadeEvalResult {
1727        verdict,
1728        tokens_used,
1729        budget_exhausted,
1730    }
1731}
1732
1733impl RouterProvider {
1734    /// Cascade chat: try providers in order, escalate on degenerate output.
1735    ///
1736    /// Returns the best-seen response if all providers fail or budget is exhausted.
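    ///
    /// Per-provider decision (a sketch of the loop below): accept the response when
    /// the verdict does not call for escalation, or when escalation is blocked
    /// because this is the last provider, `escalations_remaining == 0`, or the token
    /// budget is exhausted; in the blocked case the best-seen response wins.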
1737    #[allow(clippy::too_many_lines)] // cascade loop: per-provider error/ok/budget/escalation branches are tightly coupled — extracting would obscure the control flow
1738    async fn cascade_chat(
1739        &self,
1740        providers: &[AnyProvider],
1741        messages: &[Message],
1742        status_tx: Option<StatusTx>,
1743    ) -> Result<String, LlmError> {
1744        let cfg = self
1745            .cascade_config
1746            .as_ref()
1747            .expect("cascade_config must be set");
1748        let cascade_state = self
1749            .cascade_state
1750            .as_ref()
1751            .expect("cascade_state must be set");
1752
1753        let mut escalations_remaining = cfg.max_escalations;
1754        let mut best: Option<(String, f64)> = None; // (response, score)
1755        let mut tokens_used: u32 = 0;
1756
1757        for (idx, p) in providers.iter().enumerate() {
1758            tracing::debug!(
1759                provider = %p.name(),
1760                attempt = idx + 1,
1761                total = providers.len(),
1762                classifier_mode = ?cfg.classifier_mode,
1763                quality_threshold = cfg.quality_threshold,
1764                "cascade: trying provider"
1765            );
1766            let start = std::time::Instant::now();
1767            match p.chat(messages).await {
1768                Err(e) => {
1769                    // Network/API error: record availability failure but don't consume escalation budget.
1770                    let latency = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
1771                    self.record_availability(p.name(), false, latency);
1772                    if let Some(tx) = &status_tx {
1773                        let _ = tx.send(format!("cascade: {} unavailable, trying next", p.name()));
1774                    }
1775                    tracing::warn!(provider = p.name(), error = %e, "cascade: provider error");
1776                }
1777                Ok(response) => {
1778                    let latency = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
1779
1780                    let eval = cascade_evaluate_response(
1781                        p.name(),
1782                        &response,
1783                        cfg,
1784                        cascade_state,
1785                        tokens_used,
1786                        "cascade",
1787                    )
1788                    .await;
1789                    tokens_used = eval.tokens_used;
1790                    let verdict = eval.verdict;
1791                    let budget_exhausted = eval.budget_exhausted;
1792
1793                    // Update best-seen response; skip empty strings to avoid silent failures.
1794                    let is_better = !response.is_empty()
1795                        && best
1796                            .as_ref()
1797                            .is_none_or(|(_, best_score)| verdict.score > *best_score);
1798                    if is_better {
1799                        tracing::debug!(
1800                            provider = %p.name(),
1801                            score = verdict.score,
1802                            "cascade: best_seen updated"
1803                        );
1804                        best = Some((response.clone(), verdict.score));
1805                    }
1806
1807                    let is_last = idx == providers.len() - 1;
1808
1809                    if !verdict.should_escalate
1810                        || is_last
1811                        || escalations_remaining == 0
1812                        || budget_exhausted
1813                    {
1814                        self.record_availability(p.name(), true, latency);
1815                        // When escalation is blocked (budget exhausted or escalation count
1816                        // at zero) and the current response would have triggered escalation,
1817                        // return the best-seen response instead of the current (possibly
1818                        // lower-quality) one.
1819                        if verdict.should_escalate
1820                            && (budget_exhausted || escalations_remaining == 0)
1821                        {
1822                            let best_response = best.take().map_or(response, |(r, _)| r);
1823                            tracing::info!(
1824                                tokens_used,
1825                                budget = cfg.max_cascade_tokens,
1826                                escalations_remaining,
1827                                "cascade: escalation blocked, returning best response"
1828                            );
1829                            return Ok(best_response);
1830                        }
1831                        return Ok(response);
1832                    }
1833
1834                    // Escalate: record availability success (provider worked, just low quality).
1835                    self.record_availability(p.name(), true, latency);
1836                    escalations_remaining -= 1;
1837
1838                    if let Some(tx) = &status_tx {
1839                        let _ = tx.send(format!(
1840                            "cascade: {} quality {:.2} < {:.2}, escalating ({} left)",
1841                            p.name(),
1842                            verdict.score,
1843                            cfg.quality_threshold,
1844                            escalations_remaining
1845                        ));
1846                    }
1847                    tracing::info!(
1848                        provider = %p.name(),
1849                        score = verdict.score,
1850                        threshold = cfg.quality_threshold,
1851                        escalations_remaining,
1852                        "cascade: escalating to next provider"
1853                    );
1854                }
1855            }
1856        }
1857
1858        // All providers tried — return best-seen response, or NoProviders if none worked.
1859        if let Some((_, score)) = &best {
1860            tracing::info!(
1861                score,
1862                "cascade: all providers exhausted, returning best-seen response"
1863            );
1864        } else {
1865            tracing::warn!("cascade: all providers failed, no response available");
1866        }
1867        best.map(|(r, _)| r).ok_or(LlmError::NoProviders)
1868    }
1869
1870    /// Cascade `chat_stream`: buffer cheap response, classify, escalate or replay.
1871    ///
1872    /// # Streaming latency trade-off
1873    ///
1874    /// The first N-1 providers are fully buffered before classification. If escalation
1875    /// occurs, the user experiences: cheap model's full response time + expensive model's
1876    /// TTFT. This is strictly worse than direct routing to the expensive model for
1877    /// hard queries. Acceptable for v1; see CRIT-01 in critic handoff for alternatives.
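    ///
    /// Illustrative timings (hypothetical): if the cheap model takes 3.0 s to finish
    /// its full response and the expensive model's TTFT is 0.5 s, perceived TTFT on
    /// escalation is ~3.5 s, versus 0.5 s when routing to the expensive model directly.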
1878    #[allow(clippy::too_many_lines)] // sequential cascade semantics: buffer→classify→escalate
1879    async fn cascade_chat_stream(
1880        &self,
1881        providers: &[AnyProvider],
1882        messages: &[Message],
1883        status_tx: Option<StatusTx>,
1884    ) -> Result<ChatStream, LlmError> {
1885        let cfg = self
1886            .cascade_config
1887            .as_ref()
1888            .expect("cascade_config must be set");
1889        let cascade_state = self
1890            .cascade_state
1891            .as_ref()
1892            .expect("cascade_state must be set");
1893
1894        let mut escalations_remaining = cfg.max_escalations;
1895        let mut tokens_used: u32 = 0;
1896        // Tracks the highest-scoring fully-buffered response seen so far.
1897        // Only populated from the early provider loop; the last provider streams
1898        // directly without buffering or scoring, so it never updates best_seen.
1899        let mut best_seen: Option<(String, f64)> = None;
1900
1901        // Try all providers except the last without consuming the escalation budget
1902        // for errors (only quality failures consume it).
1903        let (last, early) = providers.split_last().ok_or(LlmError::NoProviders)?;
1904
1905        for (idx, p) in early.iter().enumerate() {
1906            tracing::debug!(
1907                provider = %p.name(),
1908                attempt = idx + 1,
1909                total = providers.len(),
1910                classifier_mode = ?cfg.classifier_mode,
1911                quality_threshold = cfg.quality_threshold,
1912                "cascade stream: trying provider (buffered)"
1913            );
1914            // Buffer response to classify quality.
1915            let start = std::time::Instant::now();
1916            let stream = match p.chat_stream(messages).await {
1917                Err(e) => {
1918                    let latency = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
1919                    self.record_availability(p.name(), false, latency);
1920                    tracing::warn!(provider = p.name(), error = %e, "cascade stream: provider error");
1921                    if let Some(tx) = &status_tx {
1922                        let _ = tx.send(format!("cascade: {} unavailable, trying next", p.name()));
1923                    }
1924                    continue;
1925                }
1926                Ok(s) => s,
1927            };
1928
1929            // Collect the full stream.
1930            let buffered = collect_stream(stream).await;
1931            let latency = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
1932
1933            match buffered {
1934                Err(e) => {
1935                    // Stream failed mid-delivery; treat as availability failure.
1936                    self.record_availability(p.name(), false, latency);
1937                    tracing::warn!(provider = p.name(), error = %e, "cascade stream: stream error");
1938                }
1939                Ok(text) => {
1940                    let eval = cascade_evaluate_response(
1941                        p.name(),
1942                        &text,
1943                        cfg,
1944                        cascade_state,
1945                        tokens_used,
1946                        "cascade stream",
1947                    )
1948                    .await;
1949                    tokens_used = eval.tokens_used;
1950                    let verdict = eval.verdict;
1951                    let budget_exhausted = eval.budget_exhausted;
1952
1953                    // Track the best response seen so far across early providers.
1954                    // Skip empty strings to avoid returning silent failures on all-fail fallback.
1955                    let is_better = !text.is_empty()
1956                        && best_seen
1957                            .as_ref()
1958                            .is_none_or(|(_, best_score)| verdict.score > *best_score);
1959                    if is_better {
1960                        tracing::debug!(
1961                            provider = %p.name(),
1962                            score = verdict.score,
1963                            "cascade stream: best_seen updated"
1964                        );
1965                        best_seen = Some((text.clone(), verdict.score));
1966                    }
1967
1968                    if !verdict.should_escalate || escalations_remaining == 0 || budget_exhausted {
1969                        self.record_availability(p.name(), true, latency);
1970
1971                        // When escalation is blocked (budget exhausted or escalation count
1972                        // at zero) and the current response would have triggered escalation,
1973                        // return the best-seen response instead of the current (possibly
1974                        // lower-quality) one.
1975                        let response_text = if verdict.should_escalate
1976                            && (budget_exhausted || escalations_remaining == 0)
1977                        {
1978                            tracing::info!(
1979                                tokens_used,
1980                                budget = cfg.max_cascade_tokens,
1981                                escalations_remaining,
1982                                "cascade stream: escalation blocked, returning best response"
1983                            );
1984                            best_seen.take().map_or(text, |(r, _)| r)
1985                        } else {
1986                            text
1987                        };
1988
1989                        let stream: ChatStream = Box::pin(tokio_stream::once(Ok(
1990                            crate::provider::StreamChunk::Content(response_text),
1991                        )));
1992                        return Ok(stream);
1993                    }
1994
1995                    // Escalate.
1996                    self.record_availability(p.name(), true, latency);
1997                    escalations_remaining -= 1;
1998
1999                    if let Some(tx) = &status_tx {
2000                        let _ = tx.send(format!(
2001                            "cascade: {} quality {:.2} < {:.2}, escalating",
2002                            p.name(),
2003                            verdict.score,
2004                            cfg.quality_threshold,
2005                        ));
2006                    }
2007                    tracing::info!(
2008                        provider = %p.name(),
2009                        score = verdict.score,
2010                        threshold = cfg.quality_threshold,
2011                        escalations_remaining,
2012                        "cascade stream: escalating to next provider"
2013                    );
2014                }
2015            }
2016        }
2017
2018        // Last provider: stream directly without buffering.
2019        // Note: if the stream itself fails mid-delivery (after Ok(stream) is returned),
2020        // there is no fallback to best_seen — the caller receives a partial response.
2021        // This is a pre-existing limitation; fixing it would require wrapping the stream.
2022        tracing::debug!(
2023            provider = %last.name(),
2024            attempt = providers.len(),
2025            total = providers.len(),
2026            "cascade stream: trying last provider (streaming, no classification)"
2027        );
2028        let start = std::time::Instant::now();
2029        match last.chat_stream(messages).await {
2030            Ok(stream) => {
2031                self.record_availability(
2032                    last.name(),
2033                    true,
2034                    u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
2035                );
2036                Ok(stream)
2037            }
2038            Err(e) => {
2039                self.record_availability(
2040                    last.name(),
2041                    false,
2042                    u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
2043                );
2044                // If we have a best-seen response from an early provider, return it
2045                // instead of propagating the last provider's error.
2046                if let Some((best_text, _)) = best_seen {
2047                    tracing::info!(
2048                        "cascade stream: last provider failed, returning best-seen response"
2049                    );
2050                    let stream: ChatStream = Box::pin(tokio_stream::once(Ok(
2051                        crate::provider::StreamChunk::Content(best_text),
2052                    )));
2053                    return Ok(stream);
2054                }
2055                Err(e)
2056            }
2057        }
2058    }
2059}
2060
2061/// Maximum bytes buffered per stream in cascade routing (SEC-CASCADE-03).
2062const CASCADE_STREAM_MAX_BYTES: usize = 1024 * 1024; // 1 MiB
2063
2064/// Collect a `ChatStream` into a String, concatenating only `Content` chunks.
2065///
2066/// Returns `Err` if the accumulated buffer exceeds [`CASCADE_STREAM_MAX_BYTES`].
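///
/// Usage sketch (doc test ignored; in practice the `ChatStream` comes from a
/// provider's `chat_stream`):
///
/// ```rust,ignore
/// let stream = provider.chat_stream(&messages).await?;
/// let text = collect_stream(stream).await?; // Err if Content chunks exceed 1 MiB
/// ```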
2067async fn collect_stream(stream: ChatStream) -> Result<String, LlmError> {
2068    use tokio_stream::StreamExt as _;
2069
2070    let mut stream = stream;
2071    let mut buf = String::new();
2072    while let Some(chunk) = stream.next().await {
2073        match chunk? {
2074            crate::provider::StreamChunk::Content(c) => {
2075                if buf.len() + c.len() > CASCADE_STREAM_MAX_BYTES {
2076                    return Err(LlmError::Other(
2077                        "cascade: stream response exceeds 1 MiB buffer limit".into(),
2078                    ));
2079                }
2080                buf.push_str(&c);
2081            }
2082            crate::provider::StreamChunk::Thinking(_)
2083            | crate::provider::StreamChunk::Compaction(_)
2084            | crate::provider::StreamChunk::ToolUse(_) => {}
2085        }
2086    }
2087    Ok(buf)
2088}
2089
2090#[cfg(test)]
2091mod tests {
2092    use super::*;
2093    use crate::provider::Role;
2094
2095    #[test]
2096    fn empty_router_name() {
2097        let r = RouterProvider::new(vec![]);
2098        assert_eq!(r.name(), "router");
2099    }
2100
2101    #[test]
2102    fn empty_router_supports_nothing() {
2103        let r = RouterProvider::new(vec![]);
2104        assert!(!r.supports_streaming());
2105        assert!(!r.supports_embeddings());
2106        assert!(!r.supports_tool_use());
2107    }
2108
2109    #[test]
2110    fn empty_router_context_window_none() {
2111        let r = RouterProvider::new(vec![]);
2112        assert!(r.context_window().is_none());
2113    }
2114
2115    #[tokio::test]
2116    async fn empty_router_chat_returns_no_providers() {
2117        let r = RouterProvider::new(vec![]);
2118        let msgs = vec![Message::from_legacy(Role::User, "hello")];
2119        let err = r.chat(&msgs).await.unwrap_err();
2120        assert!(matches!(err, LlmError::NoProviders));
2121    }
2122
2123    #[tokio::test]
2124    async fn empty_router_chat_stream_returns_no_providers() {
2125        let r = RouterProvider::new(vec![]);
2126        let msgs = vec![Message::from_legacy(Role::User, "hello")];
2127        let result = r.chat_stream(&msgs).await;
2128        assert!(matches!(result, Err(LlmError::NoProviders)));
2129    }
2130
2131    #[tokio::test]
2132    async fn empty_router_embed_returns_no_providers() {
2133        let r = RouterProvider::new(vec![]);
2134        let err = r.embed("test").await.unwrap_err();
2135        assert!(matches!(err, LlmError::NoProviders));
2136    }
2137
2138    #[tokio::test]
2139    async fn empty_router_chat_with_tools_returns_no_providers() {
2140        let r = RouterProvider::new(vec![]);
2141        let msgs = vec![Message::from_legacy(Role::User, "hello")];
2142        let err = r.chat_with_tools(&msgs, &[]).await.unwrap_err();
2143        assert!(matches!(err, LlmError::NoProviders));
2144    }
2145
2146    #[tokio::test]
2147    async fn router_falls_back_on_unreachable() {
2148        use crate::ollama::OllamaProvider;
2149
2150        let p1 = AnyProvider::Ollama(OllamaProvider::new(
2151            "http://127.0.0.1:1",
2152            "m".into(),
2153            "e".into(),
2154        ));
2155        let p2 = AnyProvider::Ollama(OllamaProvider::new(
2156            "http://127.0.0.1:2",
2157            "m".into(),
2158            "e".into(),
2159        ));
2160        let r = RouterProvider::new(vec![p1, p2]);
2161        let msgs = vec![Message::from_legacy(Role::User, "hello")];
2162        let err = r.chat(&msgs).await.unwrap_err();
2163        assert!(matches!(err, LlmError::NoProviders));
2164    }
2165
2166    #[test]
2167    fn router_with_streaming_provider() {
2168        use crate::ollama::OllamaProvider;
2169
2170        let p = AnyProvider::Ollama(OllamaProvider::new(
2171            "http://127.0.0.1:1",
2172            "m".into(),
2173            "e".into(),
2174        ));
2175        let r = RouterProvider::new(vec![p]);
2176        assert!(r.supports_streaming());
2177        assert!(r.supports_embeddings());
2178    }
2179
2180    #[test]
2181    fn clone_preserves_providers() {
2182        use crate::ollama::OllamaProvider;
2183
2184        let p = AnyProvider::Ollama(OllamaProvider::new(
2185            "http://127.0.0.1:1",
2186            "m".into(),
2187            "e".into(),
2188        ));
2189        let r = RouterProvider::new(vec![p]);
2190        let c = r.clone();
2191        assert_eq!(c.providers.len(), 1);
2192        assert_eq!(c.name(), "router");
2193    }
2194
2195    #[test]
2196    fn last_cache_usage_returns_none() {
2197        let r = RouterProvider::new(vec![]);
2198        assert!(r.last_cache_usage().is_none());
2199    }
2200
2201    #[test]
2202    fn thompson_strategy_is_set() {
2203        let r = RouterProvider::new(vec![]).with_thompson(None);
2204        assert_eq!(r.strategy, RouterStrategy::Thompson);
2205        assert!(r.thompson.is_some());
2206    }
2207
2208    #[test]
2209    fn save_thompson_state_noop_without_thompson() {
2210        let r = RouterProvider::new(vec![]);
2211        r.save_thompson_state(); // should not panic
2212    }
2213
2214    #[test]
2215    fn thompson_ordered_providers_empty() {
2216        let r = RouterProvider::new(vec![]).with_thompson(None);
2217        let ordered = r.ordered_providers();
2218        assert!(ordered.is_empty());
2219    }
2220
2221    #[test]
2222    fn concurrent_record_outcome_does_not_deadlock() {
2223        use std::sync::Arc;
2224        let r = Arc::new(RouterProvider::new(vec![]).with_thompson(None));
2225        let handles: Vec<_> = (0..8)
2226            .map(|i| {
2227                let router = Arc::clone(&r);
2228                std::thread::spawn(move || {
2229                    router.record_availability(&format!("p{i}"), i % 2 == 0, 10);
2230                })
2231            })
2232            .collect();
2233        for h in handles {
2234            h.join().expect("thread panicked");
2235        }
2236        // If we reach here, no deadlock occurred.
2237        let stats = r.thompson_stats();
2238        assert_eq!(stats.len(), 8);
2239    }
2240
2241    // ── Cascade tests ──────────────────────────────────────────────────────────
2242
2243    #[test]
2244    fn cascade_strategy_is_set() {
2245        let r = RouterProvider::new(vec![]).with_cascade(CascadeRouterConfig::default());
2246        assert_eq!(r.strategy, RouterStrategy::Cascade);
2247        assert!(r.cascade_state.is_some());
2248        assert!(r.cascade_config.is_some());
2249    }
2250
2251    #[test]
2252    fn cascade_ordered_providers_preserves_chain_order() {
2253        use crate::ollama::OllamaProvider;
2254        let p1 = AnyProvider::Ollama(OllamaProvider::new(
2255            "http://127.0.0.1:1",
2256            "a".into(),
2257            String::new(),
2258        ));
2259        let p2 = AnyProvider::Ollama(OllamaProvider::new(
2260            "http://127.0.0.1:2",
2261            "b".into(),
2262            String::new(),
2263        ));
2264        let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig::default());
2265        let ordered = r.ordered_providers();
2266        assert_eq!(ordered.len(), 2);
2267    }
2268
2269    #[tokio::test]
2270    async fn cascade_empty_router_returns_no_providers() {
2271        let r = RouterProvider::new(vec![]).with_cascade(CascadeRouterConfig::default());
2272        let msgs = vec![Message::from_legacy(Role::User, "hello")];
2273        let err = r.chat(&msgs).await.unwrap_err();
2274        assert!(matches!(err, LlmError::NoProviders));
2275    }
2276
2277    #[tokio::test]
2278    async fn cascade_returns_best_seen_when_all_fail_after_good_response() {
2279        use crate::mock::MockProvider;
2280
2281        // Provider 1: returns low-quality response (short "ok", triggers escalation at 0.9 threshold)
2282        let cheap =
2283            AnyProvider::Mock(MockProvider::with_responses(vec!["ok".to_owned()]).with_delay(0));
2284        // Provider 2: fails with availability error
2285        let expensive = AnyProvider::Mock(MockProvider::failing());
2286
2287        let r = RouterProvider::new(vec![cheap, expensive]).with_cascade(CascadeRouterConfig {
2288            quality_threshold: 0.9, // high threshold ensures "ok" fails quality check
2289            max_escalations: 2,
2290            ..CascadeRouterConfig::default()
2291        });
2292        let msgs = vec![Message::from_legacy(Role::User, "hello")];
2293        // Should return "ok" from cheap provider (best-seen), not NoProviders.
2294        let result = r.chat(&msgs).await.unwrap();
2295        assert_eq!(result, "ok");
2296    }
2297
2298    #[tokio::test]
2299    async fn cascade_accepts_good_quality_response() {
2300        use crate::mock::MockProvider;
2301
2302        let good_response = "This is a comprehensive, well-structured response that provides \
2303            detailed information about the topic. It covers multiple aspects and explains \
2304            the reasoning clearly with proper sentence structure.";
2305
2306        let cheap = AnyProvider::Mock(
2307            MockProvider::with_responses(vec![good_response.to_owned()]).with_delay(0),
2308        );
2309        // second provider should never be called
2310        let expensive = AnyProvider::Mock(MockProvider::failing());
2311
2312        let r = RouterProvider::new(vec![cheap, expensive]).with_cascade(CascadeRouterConfig {
2313            quality_threshold: 0.5,
2314            max_escalations: 1,
2315            ..CascadeRouterConfig::default()
2316        });
2317        let msgs = vec![Message::from_legacy(Role::User, "explain something")];
2318        let result = r.chat(&msgs).await.unwrap();
2319        assert_eq!(result, good_response);
2320    }
2321
2322    #[tokio::test]
2323    async fn cascade_max_escalations_budget_exhausted_returns_last_attempted() {
2324        use crate::mock::MockProvider;
2325
2326        // p1 and p2 return the degenerate response "x"; the budget allows 1 escalation.
2327        // p1 -> escalate (budget 1 -> 0) -> p2 -> budget=0 -> accept p2's response (p3 never reached).
2328        let p1 =
2329            AnyProvider::Mock(MockProvider::with_responses(vec!["x".to_owned()]).with_delay(0));
2330        let p2 =
2331            AnyProvider::Mock(MockProvider::with_responses(vec!["x".to_owned()]).with_delay(0));
2332        let p3 = AnyProvider::Mock(MockProvider::failing()); // should never be reached
2333
2334        let r = RouterProvider::new(vec![p1, p2, p3]).with_cascade(CascadeRouterConfig {
2335            quality_threshold: 0.9,
2336            max_escalations: 1, // only 1 escalation allowed
2337            ..CascadeRouterConfig::default()
2338        });
2339        let msgs = vec![Message::from_legacy(Role::User, "test")];
2340        let result = r.chat(&msgs).await.unwrap();
2341        assert_eq!(result, "x");
2342    }
2343
2344    #[tokio::test]
2345    async fn cascade_token_budget_stops_escalation() {
2346        use crate::mock::MockProvider;
2347
2348        let p1 =
2349            AnyProvider::Mock(MockProvider::with_responses(vec!["x".to_owned()]).with_delay(0));
2350        let p2 = AnyProvider::Mock(MockProvider::failing()); // should not be reached
2351
2352        let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
2353            quality_threshold: 0.9, // "x" will fail quality
2354            max_escalations: 5,
2355            max_cascade_tokens: Some(1), // 1-token budget: exhausted after the first response ("x" is 1 char, 1/4 = 0, floored to the 1-token minimum)
2356            ..CascadeRouterConfig::default()
2357        });
2358        let msgs = vec![Message::from_legacy(Role::User, "test")];
2359        let result = r.chat(&msgs).await.unwrap();
2360        assert_eq!(result, "x"); // returned despite low quality due to token budget
2361    }
2362
2363    #[tokio::test]
2364    async fn cascade_budget_returns_best_seen_not_current() {
2365        use crate::mock::MockProvider;
2366
2367        // p1 returns a decent response, p2 returns a worse one but exhausts the budget.
2368        // With budget_exhausted, we should get the best-seen (p1) not the current (p2).
2369        let good_response = "This is a reasonable response with enough content to score well.";
2370        let bad_response = "x"; // degenerate, score << good_response
2371
2372        let p1 = AnyProvider::Mock(
2373            MockProvider::with_responses(vec![good_response.to_owned()]).with_delay(0),
2374        );
2375        let p2 = AnyProvider::Mock(
2376            MockProvider::with_responses(vec![bad_response.to_owned()]).with_delay(0),
2377        );
2378
2379        let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
2380            quality_threshold: 0.95, // both fail quality check but good > bad
2381            max_escalations: 5,
2382            max_cascade_tokens: Some(1), // budget exhausted after p1 (1 token min)
2383            ..CascadeRouterConfig::default()
2384        });
2385        let msgs = vec![Message::from_legacy(Role::User, "test")];
2386        // p1 exhausts the budget; should return p1's response (better), not p2's (worse).
2387        // Note: p2 is never reached: the budget check runs after p1's response, and
2388        // p1's ~16 estimated tokens exceed the 1-token budget, so best-seen (p1) is returned.
2389        let result = r.chat(&msgs).await.unwrap();
2390        // The result must not be the degenerate "x" response.
2391        assert_ne!(result, bad_response, "should return best-seen, not current");
2392    }
2393
2394    #[tokio::test]
2395    async fn cascade_escalations_exhausted_returns_best_seen_not_current() {
2396        use crate::mock::MockProvider;
2397
2398        // p1: decent response, fails quality at 0.95 → escalates (escalations_remaining: 1 → 0)
2399        // p2: degenerate "x", fails quality → escalations_remaining == 0 → blocked → best_seen wins
2400        let good_response = "This is a reasonable response with enough content to score well.";
2401        let bad_response = "x";
2402
2403        let p1 = AnyProvider::Mock(
2404            MockProvider::with_responses(vec![good_response.to_owned()]).with_delay(0),
2405        );
2406        let p2 = AnyProvider::Mock(
2407            MockProvider::with_responses(vec![bad_response.to_owned()]).with_delay(0),
2408        );
2409        let p3 = AnyProvider::Mock(MockProvider::failing()); // should not be reached
2410
2411        let r = RouterProvider::new(vec![p1, p2, p3]).with_cascade(CascadeRouterConfig {
2412            quality_threshold: 0.95, // both fail quality; p1 score > p2 score
2413            max_escalations: 1,      // p1 escalates (budget: 1→0), p2 is blocked
2414            ..CascadeRouterConfig::default()
2415        });
2416        let msgs = vec![Message::from_legacy(Role::User, "test")];
2417        let result = r.chat(&msgs).await.unwrap();
2418        assert_eq!(
2419            result, good_response,
2420            "should return best-seen (p1), not the degenerate current response (p2)"
2421        );
2422        assert_ne!(
2423            result, bad_response,
2424            "must not return degenerate p2 response"
2425        );
2426    }
2427
2428    #[tokio::test]
2429    async fn cascade_stream_escalations_exhausted_returns_best_seen_not_current() {
2430        use crate::mock::MockProvider;
2431
2432        // Same scenario as above but for cascade_chat_stream.
2433        // p1: decent response, fails quality at 0.95 → escalates (escalations_remaining: 1 → 0)
2434        // p2: degenerate "x", fails quality → escalations_remaining == 0 → return best_seen
2435        let good_response = "This is a reasonable response with enough content to score well.";
2436        let bad_response = "x";
2437
2438        let p1 = AnyProvider::Mock(
2439            MockProvider::with_responses(vec![good_response.to_owned()])
2440                .with_delay(0)
2441                .with_streaming(),
2442        );
2443        let p2 = AnyProvider::Mock(
2444            MockProvider::with_responses(vec![bad_response.to_owned()])
2445                .with_delay(0)
2446                .with_streaming(),
2447        );
2448        let p3 = AnyProvider::Mock(MockProvider::failing()); // last provider, should not be reached
2449
2450        let r = RouterProvider::new(vec![p1, p2, p3]).with_cascade(CascadeRouterConfig {
            quality_threshold: 0.95, // both fail quality; p1 score > p2 score
            max_escalations: 1,      // p1 escalates (budget: 1→0), p2 is blocked
            ..CascadeRouterConfig::default()
        });
        let msgs = vec![Message::from_legacy(Role::User, "test")];
        let stream = r.chat_stream(&msgs).await.unwrap();
        let collected = collect_stream(stream).await.unwrap();
        assert_eq!(
            collected, good_response,
            "should return best-seen (p1), not the degenerate current response (p2)"
        );
        assert_ne!(
            collected, bad_response,
            "must not return degenerate p2 response"
        );
    }

    #[tokio::test]
    async fn cascade_all_providers_fail_returns_no_providers() {
        use crate::mock::MockProvider;

        let p1 = AnyProvider::Mock(MockProvider::failing());
        let p2 = AnyProvider::Mock(MockProvider::failing());

        let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig::default());
        let msgs = vec![Message::from_legacy(Role::User, "test")];
        let err = r.chat(&msgs).await.unwrap_err();
        assert!(matches!(err, LlmError::NoProviders));
    }

    #[tokio::test]
    async fn cascade_stream_good_quality_no_escalation() {
        use crate::mock::MockProvider;

        let good = "This is a well-formed response with sufficient length and coherent structure.";
        let p1 = AnyProvider::Mock(
            MockProvider::with_responses(vec![good.to_owned()])
                .with_delay(0)
                .with_streaming(),
        );
        let p2 = AnyProvider::Mock(MockProvider::failing());

        let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
            quality_threshold: 0.5,
            max_escalations: 1,
            ..CascadeRouterConfig::default()
        });
        let msgs = vec![Message::from_legacy(Role::User, "q")];
        let stream = r.chat_stream(&msgs).await.unwrap();
        let collected = collect_stream(stream).await.unwrap();
        assert_eq!(collected, good);
    }

    #[tokio::test]
    async fn cascade_stream_escalates_to_last_provider() {
        use crate::mock::MockProvider;

        let bad = "x"; // low quality, should escalate
        let good = "This is the expensive model's comprehensive response.";
        let p1 = AnyProvider::Mock(
            MockProvider::with_responses(vec![bad.to_owned()])
                .with_delay(0)
                .with_streaming(),
        );
        let p2 = AnyProvider::Mock(
            MockProvider::with_responses(vec![good.to_owned()])
                .with_delay(0)
                .with_streaming(),
        );

        let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
            quality_threshold: 0.9, // "x" fails quality
            max_escalations: 1,
            ..CascadeRouterConfig::default()
        });
        let msgs = vec![Message::from_legacy(Role::User, "q")];
        let stream = r.chat_stream(&msgs).await.unwrap();
        let collected = collect_stream(stream).await.unwrap();
        assert_eq!(collected, good);
    }

    #[tokio::test]
    async fn cascade_stream_budget_returns_best_seen() {
        use crate::mock::MockProvider;

        // Three providers: early=[p1, p2], last=p3.
        // p1 returns a decent response (fails quality threshold at 0.95, triggers escalation).
        // Budget is set to 1 token, so it is exhausted immediately after p1 processes.
        // best_seen = p1's response; budget_exhausted + should_escalate → return best_seen.
        let good_response = "This is a reasonable response with enough content to score well.";
        let bad_response = "x"; // degenerate, score << good_response

        let p1 = AnyProvider::Mock(
            MockProvider::with_responses(vec![good_response.to_owned()])
                .with_delay(0)
                .with_streaming(),
        );
        let p2 = AnyProvider::Mock(
            MockProvider::with_responses(vec![bad_response.to_owned()])
                .with_delay(0)
                .with_streaming(),
        );
        let p3 = AnyProvider::Mock(MockProvider::failing()); // last provider, not reached

        let r = RouterProvider::new(vec![p1, p2, p3]).with_cascade(CascadeRouterConfig {
            quality_threshold: 0.95, // p1 fails quality check → triggers escalation path
            max_escalations: 5,
            max_cascade_tokens: Some(1), // budget exhausted after p1 (1 token min)
            ..CascadeRouterConfig::default()
        });
        let msgs = vec![Message::from_legacy(Role::User, "test")];
        let stream = r.chat_stream(&msgs).await.unwrap();
        let collected = collect_stream(stream).await.unwrap();
        // Must return best-seen (p1's good response).
        assert_eq!(
            collected, good_response,
            "should return best-seen p1 response when budget exhausted"
        );
    }

    #[tokio::test]
    async fn cascade_stream_budget_returns_best_seen_not_current() {
        use crate::mock::MockProvider;

        // Four providers: early=[p1, p2, p3], last=p4.
        // p1 returns a good response, fails quality at 0.95 (score ~0.6), escalates; budget not yet exhausted.
        // p2 returns a degenerate response "x", fails quality, exhausts the budget.
        // At budget exhaustion: best_seen = p1 (higher score), current = p2's "x".
        // Must return best_seen (p1), not current (p2).
        let good_response = "This is a reasonable response with enough content to score well.";
        let bad_response = "x"; // 1 char → estimated_tokens = max(1/4, 1) = 1

        let p1 = AnyProvider::Mock(
            MockProvider::with_responses(vec![good_response.to_owned()])
                .with_delay(0)
                .with_streaming(),
        );
        let p2 = AnyProvider::Mock(
            MockProvider::with_responses(vec![bad_response.to_owned()])
                .with_delay(0)
                .with_streaming(),
        );
        let p3 = AnyProvider::Mock(MockProvider::failing()); // early provider, not reached (budget exhausted at p2)
        let p4 = AnyProvider::Mock(MockProvider::failing()); // last provider, not reached

        // Budget math: p1 uses 16 tokens (64 chars / 4), p2 uses 1 → total 17.
        // A budget of 20 would survive p2, so use 17: p2's single token exhausts it exactly.
        let r = RouterProvider::new(vec![p1, p2, p3, p4]).with_cascade(CascadeRouterConfig {
            quality_threshold: 0.95, // both fail; p1 score > p2 score
            max_escalations: 5,
            max_cascade_tokens: Some(17), // p1 uses 16, p2 uses 1 → total 17 ≥ 17 after p2
            ..CascadeRouterConfig::default()
        });
        let msgs = vec![Message::from_legacy(Role::User, "test")];
        let stream = r.chat_stream(&msgs).await.unwrap();
        let collected = collect_stream(stream).await.unwrap();
        // Must return p1 (best_seen), not p2 (current at time of budget exhaustion).
        assert_eq!(
            collected, good_response,
            "should return best-seen (p1), not current degenerate (p2)"
        );
        assert_ne!(
            collected, bad_response,
            "must not return the degenerate p2 response"
        );
    }
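
    /// Illustrative sketch (assumption, not the router's code): the two budget
    /// tests above rely on the cascade estimating cost as `max(chars / 4, 1)`
    /// tokens per response and stopping escalation once the running total
    /// reaches `max_cascade_tokens`. The `estimated_tokens` helper below only
    /// documents that arithmetic; it is local to this sketch.
    #[test]
    fn sketch_budget_token_estimate() {
        fn estimated_tokens(text: &str) -> usize {
            (text.len() / 4).max(1)
        }
        // The degenerate "x" response still costs the 1-token minimum.
        assert_eq!(estimated_tokens("x"), 1);
        // The 64-char good_response costs 16 tokens, so 16 + 1 = 17 hits a budget of 17.
        let good = "This is a reasonable response with enough content to score well.";
        assert_eq!(estimated_tokens(good), 16);
    }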

    #[tokio::test]
    async fn cascade_stream_last_fails_returns_best_seen() {
        use crate::mock::MockProvider;

        // Two providers: early=[p1], last=p2.
        // p1 returns a low-quality response that triggers escalation.
        // p2 (last) fails with an error.
        // Should return p1's response (best-seen) instead of propagating the error.
        let low_quality = "ok"; // short, triggers escalation at 0.9 threshold
        let p1 = AnyProvider::Mock(
            MockProvider::with_responses(vec![low_quality.to_owned()])
                .with_delay(0)
                .with_streaming(),
        );
        let p2 = AnyProvider::Mock(MockProvider::failing()); // last provider fails

        let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
            quality_threshold: 0.9, // "ok" fails quality, triggers escalation
            max_escalations: 2,
            ..CascadeRouterConfig::default()
        });
        let msgs = vec![Message::from_legacy(Role::User, "hello")];
        let stream = r.chat_stream(&msgs).await.unwrap();
        let collected = collect_stream(stream).await.unwrap();
        assert_eq!(collected, low_quality);
    }

    #[tokio::test]
    async fn cascade_stream_all_fail_returns_error() {
        use crate::mock::MockProvider;

        // Two providers, both fail. No best_seen accumulated.
        // p1 is early (errors → continue), p2 is last (errors → propagated).
        // The last provider's error must be propagated, not swallowed.
        let p1 = AnyProvider::Mock(MockProvider::failing());
        let p2 = AnyProvider::Mock(MockProvider::failing());

        let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig::default());
        let msgs = vec![Message::from_legacy(Role::User, "test")];
        let result = r.chat_stream(&msgs).await;
        assert!(
            result.is_err(),
            "expected error when all providers fail with no best_seen"
        );
    }

    #[test]
    fn cascade_config_default_values() {
        let cfg = CascadeRouterConfig::default();
        assert!((cfg.quality_threshold - 0.5).abs() < f64::EPSILON);
        assert_eq!(cfg.max_escalations, 2);
        assert_eq!(cfg.window_size, 50);
        assert!(cfg.max_cascade_tokens.is_none());
        assert_eq!(cfg.classifier_mode, cascade::ClassifierMode::Heuristic);
    }

    #[test]
    fn evaluate_heuristic_empty_should_escalate_above_threshold() {
        let verdict = RouterProvider::evaluate_heuristic("", 0.05);
        // score = 0.0, threshold = 0.05 → should_escalate = true
        assert!(verdict.should_escalate);
    }

    #[test]
    fn evaluate_heuristic_good_response_does_not_escalate() {
        let text = "The answer to your question is straightforward. Consider the options and pick the best one.";
        let verdict = RouterProvider::evaluate_heuristic(text, 0.5);
        assert!(!verdict.should_escalate, "score={}", verdict.score);
    }
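
    /// Illustrative sketch (an inference, not a restatement of the cascade
    /// internals): the two heuristic tests above, together with the
    /// empty-response test below, suggest the verdict escalates on a strict
    /// `score < threshold` comparison: a 0.0 score escalates at threshold 0.05
    /// but not at threshold 0.0. `should_escalate` here is a local model only.
    #[test]
    fn sketch_escalation_rule_appears_strict() {
        fn should_escalate(score: f64, threshold: f64) -> bool {
            score < threshold
        }
        assert!(should_escalate(0.0, 0.05)); // empty response vs. low threshold → escalate
        assert!(!should_escalate(0.0, 0.0)); // threshold 0.0 → never escalates
    }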

    /// Empty string from the only provider must not be stored as `best_seen`.
    /// When all providers fail or return empty, the caller should get an error,
    /// not a silent empty response.
    #[tokio::test]
    async fn cascade_empty_response_not_stored_as_best_seen() {
        use crate::mock::MockProvider;

        // Single provider returns an empty string (score = 0.0). With
        // quality_threshold = 0.0 it won't escalate, so we can check the return value.
        let p = AnyProvider::Mock(MockProvider::with_responses(vec![String::new()]));
        let cfg = CascadeRouterConfig {
            quality_threshold: 0.0,
            ..Default::default()
        };
        let r = RouterProvider::new(vec![p]).with_cascade(cfg);
        let msgs = vec![Message::from_legacy(Role::User, "hi")];
        // The provider returns "" and the cascade must hand it back as-is: with a
        // single provider there is no best_seen to consult. The stream variant
        // below covers the case where "" must not be stored as best_seen.
        let result = r.chat(&msgs).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "");
    }

    /// When provider 1 returns empty and provider 2 fails, `best_seen` must not hold
    /// the empty string — the caller must get an error, not a silent empty response.
    #[tokio::test]
    async fn cascade_empty_best_seen_not_returned_on_all_fail() {
        use crate::mock::MockProvider;

        // p1: returns empty string (causes escalation with default threshold)
        // p2: hard error
        let p1 = AnyProvider::Mock(MockProvider::with_responses(vec![String::new()]));
        let p2 = AnyProvider::Mock(MockProvider::failing());

        let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig::default());
        let msgs = vec![Message::from_legacy(Role::User, "hi")];
        let result = r.chat(&msgs).await;
        // best_seen must NOT be the empty string; error must propagate.
        assert!(
            result.is_err(),
            "expected error, not silent empty string; got: {result:?}"
        );
    }

    /// Stream variant: empty string from early provider must not be stored as `best_seen`.
    #[tokio::test]
    async fn cascade_stream_empty_response_not_stored_as_best_seen() {
        use crate::mock::MockProvider;

        // p1 (early): returns "" — should NOT be stored as best_seen.
        // p2 (last): returns a real response.
        let p1 = AnyProvider::Mock(MockProvider::with_responses(vec![String::new()]));
        let p2 = AnyProvider::Mock(
            MockProvider::with_responses(vec!["real answer".to_owned()]).with_streaming(),
        );

        let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig::default());
        let msgs = vec![Message::from_legacy(Role::User, "hi")];
        let stream = r.chat_stream(&msgs).await.expect("should not error");
        let text = collect_stream(stream).await.expect("stream should succeed");
        assert_eq!(text, "real answer");
    }

    // ── Arc<[AnyProvider]> + cost_tiers tests ──────────────────────────────────

    #[test]
    fn arc_providers_clone_shares_allocation() {
        use crate::mock::MockProvider;
        let p = AnyProvider::Mock(MockProvider::default());
        let r = RouterProvider::new(vec![p]);
        let c = r.clone();
        // Both RouterProvider instances must share the same Arc allocation.
        assert!(Arc::ptr_eq(&r.providers, &c.providers));
    }

    #[test]
    fn cost_tiers_reorders_providers_at_construction() {
        use crate::mock::MockProvider;
        let p1 = AnyProvider::Mock(MockProvider::default().with_name("claude"));
        let p2 = AnyProvider::Mock(MockProvider::default().with_name("ollama"));
        let p3 = AnyProvider::Mock(MockProvider::default().with_name("openai"));
        let r = RouterProvider::new(vec![p1, p2, p3]).with_cascade(CascadeRouterConfig {
            cost_tiers: Some(vec!["ollama".into(), "claude".into()]),
            ..CascadeRouterConfig::default()
        });
        let names: Vec<&str> = r.providers.iter().map(LlmProvider::name).collect();
        // ollama first (tier 0), claude second (tier 1), openai last (unlisted, original idx 2)
        assert_eq!(names, vec!["ollama", "claude", "openai"]);
    }

    #[test]
    fn cost_tiers_none_preserves_chain_order() {
        use crate::mock::MockProvider;
        let p1 = AnyProvider::Mock(MockProvider::default().with_name("claude"));
        let p2 = AnyProvider::Mock(MockProvider::default().with_name("ollama"));
        let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
            cost_tiers: None,
            ..CascadeRouterConfig::default()
        });
        let names: Vec<&str> = r.providers.iter().map(LlmProvider::name).collect();
        assert_eq!(names, vec!["claude", "ollama"]);
    }

    #[test]
    fn cost_tiers_empty_vec_preserves_chain_order() {
        use crate::mock::MockProvider;
        let p1 = AnyProvider::Mock(MockProvider::default().with_name("claude"));
        let p2 = AnyProvider::Mock(MockProvider::default().with_name("ollama"));
        let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
            cost_tiers: Some(vec![]),
            ..CascadeRouterConfig::default()
        });
        let names: Vec<&str> = r.providers.iter().map(LlmProvider::name).collect();
        assert_eq!(names, vec!["claude", "ollama"]);
    }

    #[test]
    fn cost_tiers_unknown_name_ignored() {
        use crate::mock::MockProvider;
        let p1 = AnyProvider::Mock(MockProvider::default().with_name("ollama"));
        let p2 = AnyProvider::Mock(MockProvider::default().with_name("claude"));
        let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
            cost_tiers: Some(vec!["nonexistent".into(), "ollama".into()]),
            ..CascadeRouterConfig::default()
        });
        let names: Vec<&str> = r.providers.iter().map(LlmProvider::name).collect();
        // "nonexistent" ignored; "ollama" is tier 1 → first; "claude" unlisted → second
        assert_eq!(names, vec!["ollama", "claude"]);
    }

    #[test]
    fn cost_tiers_all_providers_listed() {
        use crate::mock::MockProvider;
        let p1 = AnyProvider::Mock(MockProvider::default().with_name("c"));
        let p2 = AnyProvider::Mock(MockProvider::default().with_name("b"));
        let p3 = AnyProvider::Mock(MockProvider::default().with_name("a"));
        let r = RouterProvider::new(vec![p1, p2, p3]).with_cascade(CascadeRouterConfig {
            cost_tiers: Some(vec!["a".into(), "b".into(), "c".into()]),
            ..CascadeRouterConfig::default()
        });
        let names: Vec<&str> = r.providers.iter().map(LlmProvider::name).collect();
        assert_eq!(names, vec!["a", "b", "c"]);
    }

    #[test]
    fn cost_tiers_duplicate_name_uses_last_position() {
        use crate::mock::MockProvider;
        let p1 = AnyProvider::Mock(MockProvider::default().with_name("ollama"));
        let p2 = AnyProvider::Mock(MockProvider::default().with_name("claude"));
        // "ollama" appears twice in tiers: HashMap overwrites → position 2.
        // claude=tier 0, ollama=tier 2 → claude before ollama.
        let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
            cost_tiers: Some(vec!["claude".into(), "ollama".into(), "ollama".into()]),
            ..CascadeRouterConfig::default()
        });
        let names: Vec<&str> = r.providers.iter().map(LlmProvider::name).collect();
        assert_eq!(names, vec!["claude", "ollama"]);
    }
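
    /// Illustrative sketch (assumption): the cost_tiers tests above are
    /// consistent with building a name → position map (so a duplicated name
    /// keeps its last position) and stably sorting providers by that position,
    /// with unlisted names ordered after all listed ones. This is a local model
    /// of the observed ordering, not the router's reorder code.
    #[test]
    fn sketch_cost_tier_ordering_model() {
        use std::collections::HashMap;

        let tiers = ["claude", "ollama", "ollama"]; // duplicate: last position wins
        let map: HashMap<&str, usize> =
            tiers.iter().enumerate().map(|(i, &n)| (n, i)).collect();
        let mut providers = vec!["ollama", "claude", "openai"];
        // A stable sort keeps unlisted providers in their original chain order, at the end.
        providers.sort_by_key(|n| map.get(n).copied().unwrap_or(usize::MAX));
        assert_eq!(providers, vec!["claude", "ollama", "openai"]);
    }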

    #[test]
    fn cost_tiers_empty_router_does_not_panic() {
        let r = RouterProvider::new(vec![]).with_cascade(CascadeRouterConfig {
            cost_tiers: Some(vec!["foo".into()]),
            ..CascadeRouterConfig::default()
        });
        assert_eq!(r.providers.len(), 0);
    }

    #[test]
    fn set_status_tx_works_with_arc() {
        use crate::mock::MockProvider;
        let p = AnyProvider::Mock(MockProvider::default());
        let mut r = RouterProvider::new(vec![p]);
        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
        r.set_status_tx(tx); // must not panic
    }

    #[tokio::test]
    async fn cascade_chat_with_tools_unaffected_by_cost_tiers() {
        use crate::mock::MockProvider;
        // chat_with_tools skips cascade entirely (HIGH-04). Verify that cost_tiers
        // ordering does not accidentally affect the non-cascade tool fallback path.
        let p1 = AnyProvider::Mock(MockProvider::failing().with_name("cheap"));
        let p2 = AnyProvider::Mock(MockProvider::failing().with_name("expensive"));
        let r = RouterProvider::new(vec![p1, p2]).with_cascade(CascadeRouterConfig {
            cost_tiers: Some(vec!["cheap".into()]),
            ..CascadeRouterConfig::default()
        });
        let msgs = vec![Message::from_legacy(Role::User, "hi")];
        // Both providers fail → NoProviders, not a cascade-specific error.
        let err = r.chat_with_tools(&msgs, &[]).await.unwrap_err();
        assert!(matches!(err, LlmError::NoProviders));
    }

    // ── Embed retry / rate-limit tests ────────────────────────────────────────

    /// Provider returns `RateLimited` twice then succeeds on the third attempt.
    /// The router must retry and return the successful embedding.
    #[tokio::test]
    async fn embed_retries_on_rate_limited_then_succeeds() {
        use crate::mock::MockProvider;

        let p = AnyProvider::Mock({
            let mut m = MockProvider::default()
                .with_errors(vec![LlmError::RateLimited, LlmError::RateLimited])
                .with_name("p1");
            m.supports_embeddings = true;
            m.embedding = vec![0.1, 0.2];
            m
        });
        let r = RouterProvider::new(vec![p]);
        let result = r.embed("text").await.unwrap();
        assert_eq!(result, vec![0.1, 0.2]);
    }

    /// When all retries (3) are exhausted on the first provider, the router falls
    /// back to the second provider and returns its embedding.
    #[tokio::test]
    async fn embed_falls_back_after_all_retries_exhausted() {
        use crate::mock::MockProvider;

        // p1: 4 RateLimited errors (attempt 0..=3 all fail)
        let p1 = AnyProvider::Mock({
            let mut m = MockProvider::default()
                .with_errors(vec![
                    LlmError::RateLimited,
                    LlmError::RateLimited,
                    LlmError::RateLimited,
                    LlmError::RateLimited,
                ])
                .with_name("p1");
            m.supports_embeddings = true;
            m
        });
        let p2 = AnyProvider::Mock({
            let mut m = MockProvider::default().with_name("p2");
            m.supports_embeddings = true;
            m.embedding = vec![9.0, 8.0];
            m
        });
        let r = RouterProvider::new(vec![p1, p2]);
        let result = r.embed("text").await.unwrap();
        assert_eq!(result, vec![9.0, 8.0]);
    }

    /// Provider returns `RateLimited` twice then succeeds via `embed_batch`.
    #[tokio::test]
    async fn embed_batch_retries_on_rate_limited_then_succeeds() {
        use crate::mock::MockProvider;

        let p = AnyProvider::Mock({
            let mut m = MockProvider::default()
                .with_errors(vec![LlmError::RateLimited, LlmError::RateLimited])
                .with_name("p1");
            m.supports_embeddings = true;
            m.embedding = vec![0.5, 0.6];
            m
        });
        let r = RouterProvider::new(vec![p]);
        let result = r.embed_batch(&["a", "b"]).await.unwrap();
        assert_eq!(result, vec![vec![0.5, 0.6], vec![0.5, 0.6]]);
    }

    /// When all `embed_batch` retries are exhausted on the first provider, the
    /// router falls back to the second provider.
    #[tokio::test]
    async fn embed_batch_falls_back_after_all_retries_exhausted() {
        use crate::mock::MockProvider;

        // p1 needs 4 errors per embed call * 1 text = 4 total (attempt 0..=3)
        let p1 = AnyProvider::Mock({
            let mut m = MockProvider::default()
                .with_errors(vec![
                    LlmError::RateLimited,
                    LlmError::RateLimited,
                    LlmError::RateLimited,
                    LlmError::RateLimited,
                ])
                .with_name("p1");
            m.supports_embeddings = true;
            m
        });
        let p2 = AnyProvider::Mock({
            let mut m = MockProvider::default().with_name("p2");
            m.supports_embeddings = true;
            m.embedding = vec![7.0, 8.0];
            m
        });
        let r = RouterProvider::new(vec![p1, p2]);
        let result = r.embed_batch(&["x"]).await.unwrap();
        assert_eq!(result, vec![vec![7.0, 8.0]]);
    }
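
    /// Illustrative sketch (assumption): the retry shape the four embed tests
    /// above exercise, modeled locally: up to four attempts (0..=3) per provider
    /// on a retryable error, then fall through to the next provider. `Err(true)`
    /// stands in for RateLimited and `Err(false)` for a hard failure; this is
    /// not the router's implementation.
    #[test]
    fn sketch_embed_retry_fallback_model() {
        fn run(providers: Vec<Vec<Result<&'static str, bool>>>) -> Option<&'static str> {
            for mut outcomes in providers {
                for _attempt in 0..=3 {
                    match outcomes.remove(0) {
                        Ok(v) => return Some(v),
                        Err(true) => continue, // retryable: try the same provider again
                        Err(false) => break,   // hard failure: move to the next provider
                    }
                }
            }
            None
        }
        // Two RateLimited then success on the third attempt: the same provider answers.
        assert_eq!(run(vec![vec![Err(true), Err(true), Ok("p1")]]), Some("p1"));
        // Four RateLimited exhaust p1's attempts; the loop falls back to p2.
        assert_eq!(
            run(vec![
                vec![Err(true), Err(true), Err(true), Err(true)],
                vec![Ok("p2")],
            ]),
            Some("p2")
        );
    }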

    // ── InvalidInput embed break tests ────────────────────────────────────────

    /// When a provider returns `InvalidInput` from `embed()`, the router must break
    /// the fallback loop immediately and return `InvalidInput` — not `NoProviders`.
    #[tokio::test]
    async fn embed_invalid_input_breaks_loop_and_returns_invalid_input() {
        use crate::mock::MockProvider;

        let p = AnyProvider::Mock(MockProvider::default().with_embed_invalid_input());
        let r = RouterProvider::new(vec![p]).with_thompson(None);
        let err = r.embed("some text").await.unwrap_err();
        assert!(
            matches!(err, LlmError::InvalidInput { .. }),
            "expected InvalidInput, got {err:?}"
        );
    }

    /// When a provider returns `InvalidInput`, the router must NOT fall through to
    /// the next provider — a second embed-capable provider must never be called.
    #[tokio::test]
    async fn embed_invalid_input_does_not_fall_through_to_second_provider() {
        use crate::mock::MockProvider;

        // p1 returns InvalidInput; p2 is a functioning embed provider.
        // If the loop falls through, p2 returns Ok — which would mean the error was
        // swallowed instead of breaking immediately.
        let p1 = AnyProvider::Mock(
            MockProvider::default()
                .with_embed_invalid_input()
                .with_name("p1"),
        );
        let p2 = AnyProvider::Mock({
            let mut m = MockProvider::default();
            m.supports_embeddings = true;
            m.name_override = Some("p2".into());
            m
        });

        let r = RouterProvider::new(vec![p1, p2]);
        let err = r.embed("test").await.unwrap_err();

        // The error must carry p1's name, proving p2 was never reached.
        assert!(
            matches!(&err, LlmError::InvalidInput { provider, .. } if provider == "p1"),
            "expected InvalidInput from p1, got {err:?}"
        );
    }

    /// The router skips providers that do not support embeddings and continues to
    /// the next one, returning a successful result from the first capable provider.
    #[tokio::test]
    async fn embed_skips_non_embedding_providers_and_falls_through() {
        use crate::mock::MockProvider;

        // p1 does not support embeddings — skipped by the loop guard.
        // p2 supports embeddings and returns successfully.
        let p1 = AnyProvider::Mock({
            let mut m = MockProvider::default().with_name("p1");
            m.supports_embeddings = false;
            m
        });
        let p2 = AnyProvider::Mock({
            let mut m = MockProvider::default().with_name("p2");
            m.supports_embeddings = true;
            m.embedding = vec![1.0, 2.0, 3.0];
            m
        });

        let r = RouterProvider::new(vec![p1, p2]);
        let result = r.embed("hello").await.unwrap();
        assert_eq!(result, vec![1.0, 2.0, 3.0]);
    }

    /// `InvalidInput` from embed does not call `record_availability` (no reputation penalty).
    /// We verify this indirectly: `thompson_stats` must show no entry for the provider
    /// after an `InvalidInput` embed, whereas a normal embed failure increments it.
    #[tokio::test]
    async fn embed_invalid_input_does_not_record_availability() {
        use crate::mock::MockProvider;

        let p = AnyProvider::Mock(
            MockProvider::default()
                .with_embed_invalid_input()
                .with_name("test-provider"),
        );
        let r = RouterProvider::new(vec![p]).with_thompson(None);
        let _ = r.embed("text").await;

        // record_availability is only called on success or generic error,
        // not on InvalidInput. So thompson_stats must have no entry for "test-provider".
        let stats = r.thompson_stats();
        let provider_in_stats = stats.iter().any(|(name, ..)| name == "test-provider");
        assert!(
            !provider_in_stats,
            "InvalidInput must not update provider reputation; stats: {stats:?}"
        );
    }

    // ── quality_gate tests ────────────────────────────────────────────────────

    /// `with_quality_gate()` happy path: when cosine similarity >= threshold the
    /// response is returned directly without falling back.
    #[tokio::test]
    async fn quality_gate_passes_when_similarity_above_threshold() {
        use crate::mock::MockProvider;

        // p1 returns a response; embed returns a unit vector so cosine similarity
        // with itself is 1.0 (>= any reasonable threshold).
        let p1 = AnyProvider::Mock({
            let mut m = MockProvider::with_responses(vec!["answer".to_owned()]).with_name("p1");
            m.supports_embeddings = true;
            m.embedding = vec![1.0, 0.0];
            m
        });
        let r = RouterProvider::new(vec![p1])
            .with_thompson(None)
            .with_quality_gate(0.5);
        let msgs = vec![Message::from_legacy(Role::User, "question")];
        let result = r.chat(&msgs).await.unwrap();
        assert_eq!(result, "answer");
    }

    /// `with_quality_gate()` exhaustion: when all providers fail the gate the router
    /// returns the best-seen response (highest similarity) rather than an error.
    #[tokio::test]
    async fn quality_gate_exhaustion_returns_best_seen() {
        use crate::mock::MockProvider;

        // p1 returns a response but embedding similarity is 0.0 (orthogonal vectors)
        // so it fails the gate (0.0 < 0.9). p2 fails entirely.
        // Expected: best_seen from p1 is returned.
        let p1 = AnyProvider::Mock({
            let mut m =
                MockProvider::with_responses(vec!["best_so_far".to_owned()]).with_name("p1");
            m.supports_embeddings = true;
            // query embed = [1,0], response embed = [0,1] → similarity = 0.0
            m.embedding = vec![0.0, 1.0];
            m
        });
        let p2 = AnyProvider::Mock(MockProvider::failing().with_name("p2"));
        let r = RouterProvider::new(vec![p1, p2])
            .with_thompson(None)
            .with_quality_gate(0.9);
        let msgs = vec![Message::from_legacy(Role::User, "question")];
        let result = r.chat(&msgs).await.unwrap();
        assert_eq!(result, "best_so_far");
    }
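
    /// Illustrative sketch (assumption): the two quality-gate tests above are
    /// consistent with embedding the query and the candidate response and
    /// accepting when their cosine similarity meets the threshold. The inline
    /// `cosine` below is a local stand-in written for this sketch; the router's
    /// actual helper and its signature are not restated here.
    #[test]
    fn sketch_quality_gate_cosine_model() {
        fn cosine(a: &[f32], b: &[f32]) -> f32 {
            let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
            let na = a.iter().map(|x| x * x).sum::<f32>().sqrt();
            let nb = b.iter().map(|x| x * x).sum::<f32>().sqrt();
            if na == 0.0 || nb == 0.0 { 0.0 } else { dot / (na * nb) }
        }
        // Identical unit vectors → similarity 1.0, which passes the 0.5 gate above.
        assert!(cosine(&[1.0, 0.0], &[1.0, 0.0]) >= 0.5);
        // Orthogonal vectors → similarity 0.0, which fails the 0.9 gate and sends
        // the router down the best_seen path.
        assert!(cosine(&[1.0, 0.0], &[0.0, 1.0]) < 0.9);
    }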

    // ── apply_routing_signals guard logic tests ───────────────────────────────

    /// `quality_gate = 5.0` (> 1.0) must be silently ignored — the field is left
    /// as `None` and no panic occurs.
    #[test]
    fn routing_signals_quality_gate_above_one_is_ignored() {
        // Build a RouterProvider directly and check that with_quality_gate is only
        // called for in-range values by replicating the guard from provider.rs.
        let threshold: f32 = 5.0;
        let mut router = RouterProvider::new(vec![]);
        if threshold.is_finite() && threshold > 0.0 && threshold <= 1.0 {
            router = router.with_quality_gate(threshold);
        }
        assert!(
            router.quality_gate.is_none(),
            "out-of-range quality_gate must not be wired; got {:?}",
            router.quality_gate
        );
    }

    /// `quality_gate = 0.8` (valid) must be wired into the router.
    #[test]
    fn routing_signals_quality_gate_valid_is_wired() {
        let threshold: f32 = 0.8;
        let mut router = RouterProvider::new(vec![]);
        if threshold.is_finite() && threshold > 0.0 && threshold <= 1.0 {
            router = router.with_quality_gate(threshold);
        }
        assert_eq!(
            router.quality_gate,
            Some(0.8),
            "valid quality_gate must be wired"
        );
    }

    // ── ASI debounce tests ────────────────────────────────────────────────────

    #[test]
    fn asi_debounce_same_turn_fires_once() {
        let router = RouterProvider::new(vec![]);
        let turn_id = 42u64;

        // First call: prev == u64::MAX (initial) → not equal to turn_id → proceeds (returns false)
        let prev1 = router.asi_last_turn.swap(turn_id, Ordering::AcqRel);
        let first_dropped = prev1 == turn_id;

        // Second call same turn: prev == turn_id → dropped
        let prev2 = router.asi_last_turn.swap(turn_id, Ordering::AcqRel);
        let second_dropped = prev2 == turn_id;

        assert!(!first_dropped, "first call in turn must not be dropped");
        assert!(second_dropped, "second call in same turn must be dropped");
    }

    #[test]
    fn asi_debounce_next_turn_fires_again() {
        let router = RouterProvider::new(vec![]);

        // Simulate turn 1
        let prev1 = router.asi_last_turn.swap(1u64, Ordering::AcqRel);
        assert_ne!(prev1, 1u64, "turn 1: initial value != 1, should proceed");

        // Simulate turn 2 — different turn_id
        let prev2 = router.asi_last_turn.swap(2u64, Ordering::AcqRel);
        let dropped = prev2 == 2u64;
        assert!(!dropped, "turn 2 must not be dropped (different turn_id)");
    }
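
    /// Illustrative sketch: the swap-based debounce exercised by the two ASI
    /// tests above, reduced to a bare atomic. `swap` returns the previous value,
    /// so a call is dropped exactly when the same turn id was already recorded;
    /// `u64::MAX` serves as the "no turn yet" sentinel, per the comment in the
    /// first test.
    #[test]
    fn sketch_swap_debounce_model() {
        use std::sync::atomic::AtomicU64;

        let last_turn = AtomicU64::new(u64::MAX); // sentinel: no turn recorded yet
        let fires = |turn: u64| last_turn.swap(turn, Ordering::AcqRel) != turn;
        assert!(fires(1)); // first call of turn 1 proceeds
        assert!(!fires(1)); // duplicate within turn 1 is debounced
        assert!(fires(2)); // a new turn fires again
    }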

    #[test]
    fn turn_counter_increments_across_clones() {
        let router = RouterProvider::new(vec![]);
        let clone = router.clone();

        let t0 = router.turn_counter.fetch_add(1, Ordering::Relaxed);
        let t1 = clone.turn_counter.fetch_add(1, Ordering::Relaxed);

        // Both clones share the same Arc<AtomicU64>
        assert_eq!(t1, t0 + 1, "cloned router shares turn_counter");
    }

    #[test]
    fn with_embed_concurrency_zero_means_no_semaphore() {
        let r = RouterProvider::new(vec![]).with_embed_concurrency(0);
        assert!(r.embed_semaphore.is_none(), "0 should disable semaphore");
    }

    #[test]
    fn with_embed_concurrency_positive_creates_semaphore() {
        let r = RouterProvider::new(vec![]).with_embed_concurrency(4);
        let sem = r.embed_semaphore.as_ref().expect("semaphore should exist");
        assert_eq!(sem.available_permits(), 4);
    }

    #[tokio::test]
    async fn embed_semaphore_limits_concurrency() {
        use std::sync::Arc as StdArc;
        use std::sync::atomic::{AtomicUsize, Ordering as AO};

        // Use a semaphore with 2 permits and verify that at most 2 tasks
        // hold a permit at the same time.
        let sem = Arc::new(tokio::sync::Semaphore::new(2));
        let concurrent_peak = StdArc::new(AtomicUsize::new(0));
        let active = StdArc::new(AtomicUsize::new(0));

        let mut handles = vec![];
        for _ in 0..6 {
            let sem_clone = sem.clone();
            let peak = concurrent_peak.clone();
            let active = active.clone();
            handles.push(tokio::spawn(async move {
                let _permit = sem_clone.acquire().await.unwrap();
                let cur = active.fetch_add(1, AO::SeqCst) + 1;
                // Track peak concurrent usage.
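                // Lock-free max: retry compare_exchange until the stored peak is
                // at least `cur`, reusing the value observed on a failed exchange.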
                let mut p = peak.load(AO::SeqCst);
                while p < cur {
                    match peak.compare_exchange(p, cur, AO::SeqCst, AO::SeqCst) {
                        Ok(_) => break,
                        Err(new) => p = new,
                    }
                }
                tokio::time::sleep(std::time::Duration::from_millis(5)).await;
                active.fetch_sub(1, AO::SeqCst);
            }));
        }
        for h in handles {
            h.await.unwrap();
        }
        assert!(
            concurrent_peak.load(AO::SeqCst) <= 2,
            "peak concurrency should not exceed semaphore limit"
        );
    }
}