Skip to main content

vellaveto_http_proxy/proxy/
gateway.rs

1// Copyright 2026 Paolo Vella
2// SPDX-License-Identifier: BUSL-1.1
3//
4// Use of this software is governed by the Business Source License
5// included in the LICENSE-BSL-1.1 file at the root of this repository.
6//
7// Change Date: Three years from the date of publication of this version.
8// Change License: MPL-2.0
9
10//! MCP Gateway Router — multi-backend tool routing with health tracking (Phase 20).
11//!
12//! The gateway routes tool calls to different upstream MCP servers based on
13//! tool name prefix matching. It maintains per-backend health state with
14//! configurable failure/success thresholds and supports session affinity.
15//!
16//! # Design Decisions
17//!
18//! - **Fail-closed**: When all matching backends are unhealthy, `route()` returns
19//!   `None` and the caller must deny the request.
20//! - **Longest prefix match**: Tool names are matched against prefixes sorted by
21//!   length descending, so `file_system_` beats `file_`.
22//! - **No rewrite of forwarding**: The router only resolves "which URL" — the
23//!   existing `forward_to_upstream()` function handles the actual HTTP request.
24
25use super::fallback;
26use std::collections::HashMap;
27use std::sync::RwLock;
28use vellaveto_config::{BackendConfig, GatewayConfig};
29use vellaveto_types::{BackendHealth, RoutingDecision, ToolConflict};
30
/// Maximum tool name length, in bytes, considered for routing.
/// `route()` truncates longer names (on a UTF-8 char boundary) before prefix
/// matching; `DiscoveredTools::validate()` rejects them outright.
33const MAX_TOOL_NAME_LEN: usize = 256;
34
35/// Internal health state for a single backend.
36#[derive(Debug)]
37struct BackendState {
38    url: String,
39    health: BackendHealth,
40    consecutive_failures: u32,
41    consecutive_successes: u32,
42}
43
44/// MCP Gateway Router.
45///
46/// Routes tool calls to upstream backends based on tool name prefix matching,
47/// with health-aware failover and session affinity support.
48impl std::fmt::Debug for GatewayRouter {
49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50        f.debug_struct("GatewayRouter")
51            .field("prefix_table_len", &self.prefix_table.len())
52            .field("default_backend_id", &self.default_backend_id)
53            .field("unhealthy_threshold", &self.unhealthy_threshold)
54            .field("healthy_threshold", &self.healthy_threshold)
55            .finish()
56    }
57}
58
59pub struct GatewayRouter {
60    /// Per-backend mutable health state, keyed by backend ID.
61    states: RwLock<HashMap<String, BackendState>>,
62    /// Prefix→backend_id routing table, sorted longest-first for greedy matching.
63    prefix_table: Vec<(String, String)>,
64    /// Backend ID of the default (catch-all) backend, if any.
65    default_backend_id: Option<String>,
66    /// Number of consecutive failures before marking a backend unhealthy.
67    unhealthy_threshold: u32,
68    /// Number of consecutive successes before restoring a backend to healthy.
69    healthy_threshold: u32,
70    /// Original backend configs, keyed by backend ID (Phase 29).
71    /// Used to retrieve `transport_urls` for cross-transport fallback.
72    backend_configs: HashMap<String, BackendConfig>,
73}
74
75impl GatewayRouter {
76    /// Build a router from gateway configuration.
77    ///
78    /// Returns an error if the configuration is structurally invalid
79    /// (duplicate IDs, empty backends, etc.).
80    pub fn from_config(config: &GatewayConfig) -> Result<Self, String> {
81        if config.backends.is_empty() {
82            return Err("gateway requires at least one backend".to_string());
83        }
84
85        let mut states = HashMap::new();
86        let mut prefix_table = Vec::new();
87        let mut default_backend_id = None;
88        let mut seen_ids = std::collections::HashSet::new();
89
90        for backend in &config.backends {
91            if !seen_ids.insert(&backend.id) {
92                return Err(format!("duplicate backend id '{}'", backend.id));
93            }
94
95            states.insert(
96                backend.id.clone(),
97                BackendState {
98                    url: backend.url.clone(),
99                    health: BackendHealth::Healthy,
100                    consecutive_failures: 0,
101                    consecutive_successes: 0,
102                },
103            );
104
105            if backend.tool_prefixes.is_empty() {
106                if default_backend_id.is_some() {
107                    return Err(
108                        "multiple default backends (empty tool_prefixes); at most one allowed"
109                            .to_string(),
110                    );
111                }
112                default_backend_id = Some(backend.id.clone());
113            } else {
114                for prefix in &backend.tool_prefixes {
115                    prefix_table.push((prefix.clone(), backend.id.clone()));
116                }
117            }
118        }
119
120        // Sort by prefix length descending for longest-prefix-first matching
121        prefix_table.sort_by(|a, b| b.0.len().cmp(&a.0.len()));
122
123        let backend_configs: HashMap<String, BackendConfig> = config
124            .backends
125            .iter()
126            .map(|b| (b.id.clone(), b.clone()))
127            .collect();
128
129        Ok(Self {
130            states: RwLock::new(states),
131            prefix_table,
132            default_backend_id,
133            unhealthy_threshold: config.unhealthy_threshold,
134            healthy_threshold: config.healthy_threshold,
135            backend_configs,
136        })
137    }
138
139    /// Route a tool call to the appropriate backend.
140    ///
141    /// Returns `None` when no healthy backend matches (fail-closed).
142    pub fn route(&self, tool_name: &str) -> Option<RoutingDecision> {
143        // SECURITY (FIND-R43-002): Truncate excessively long tool names using
144        // floor_char_boundary to avoid panicking on multi-byte UTF-8 boundaries.
145        let tool_name = if tool_name.len() > MAX_TOOL_NAME_LEN {
146            // Find the last valid UTF-8 char boundary at or before MAX_TOOL_NAME_LEN.
147            let mut end = MAX_TOOL_NAME_LEN;
148            while end > 0 && !tool_name.is_char_boundary(end) {
149                end -= 1;
150            }
151            &tool_name[..end]
152        } else {
153            tool_name
154        };
155
156        // SECURITY (FIND-041-011): Fail-closed on RwLock poisoning.
157        let states = match self.states.read() {
158            Ok(guard) => guard,
159            Err(e) => {
160                tracing::error!("Gateway states RwLock poisoned in route(): {}", e);
161                return None; // fail-closed: no healthy backend
162            }
163        };
164
165        // Try longest-prefix match first
166        for (prefix, backend_id) in &self.prefix_table {
167            if tool_name.starts_with(prefix.as_str()) {
168                if let Some(state) = states.get(backend_id) {
169                    if state.health != BackendHealth::Unhealthy {
170                        return Some(RoutingDecision {
171                            backend_id: backend_id.clone(),
172                            upstream_url: state.url.clone(),
173                        });
174                    }
175                }
176            }
177        }
178
179        // Fall back to default backend
180        if let Some(ref default_id) = self.default_backend_id {
181            if let Some(state) = states.get(default_id) {
182                if state.health != BackendHealth::Unhealthy {
183                    return Some(RoutingDecision {
184                        backend_id: default_id.clone(),
185                        upstream_url: state.url.clone(),
186                    });
187                }
188            }
189        }
190
191        None // fail-closed
192    }
193
194    /// Route with session affinity — prefer a previously used backend if healthy.
195    ///
196    /// `session_affinities` maps tool_name → backend_id from prior routing decisions.
197    pub fn route_with_affinity(
198        &self,
199        tool_name: &str,
200        session_affinities: &HashMap<String, String>,
201    ) -> Option<RoutingDecision> {
202        // Check if there's a session-affine backend for this tool
203        if let Some(affine_id) = session_affinities.get(tool_name) {
204            // SECURITY (FIND-041-011): Fail-closed on RwLock poisoning.
205            let states = match self.states.read() {
206                Ok(guard) => guard,
207                Err(e) => {
208                    tracing::error!(
209                        "Gateway states RwLock poisoned in route_with_affinity(): {}",
210                        e
211                    );
212                    return None; // fail-closed: no healthy backend
213                }
214            };
215            if let Some(state) = states.get(affine_id.as_str()) {
216                if state.health != BackendHealth::Unhealthy {
217                    return Some(RoutingDecision {
218                        backend_id: affine_id.clone(),
219                        upstream_url: state.url.clone(),
220                    });
221                }
222            }
223            // Affine backend is unhealthy — fall through to normal routing
224        }
225
226        self.route(tool_name)
227    }
228
229    /// Record a successful response from a backend.
230    ///
231    /// Transitions: Unhealthy→Degraded (after 1 success), Degraded→Healthy
232    /// (after `healthy_threshold` consecutive successes).
233    pub fn record_success(&self, backend_id: &str) {
234        // SECURITY (FIND-041-011): Fail-closed on RwLock poisoning.
235        let mut states = match self.states.write() {
236            Ok(guard) => guard,
237            Err(e) => {
238                tracing::error!("Gateway states RwLock poisoned in record_success(): {}", e);
239                return;
240            }
241        };
242        if let Some(state) = states.get_mut(backend_id) {
243            state.consecutive_failures = 0;
244
245            match state.health {
246                BackendHealth::Unhealthy => {
247                    // IMP-R118-011: Reset to 1 directly (Unhealthy resets counter).
248                    state.health = BackendHealth::Degraded;
249                    state.consecutive_successes = 1;
250                    tracing::info!(
251                        backend = %backend_id,
252                        "Gateway backend transitioning: unhealthy → degraded"
253                    );
254                }
255                BackendHealth::Degraded => {
256                    // SECURITY (FIND-R43-009): Use saturating_add to prevent integer overflow.
257                    state.consecutive_successes = state.consecutive_successes.saturating_add(1);
258                    if state.consecutive_successes >= self.healthy_threshold {
259                        state.health = BackendHealth::Healthy;
260                        state.consecutive_successes = 0;
261                        tracing::info!(
262                            backend = %backend_id,
263                            "Gateway backend transitioning: degraded → healthy"
264                        );
265                    }
266                }
267                BackendHealth::Healthy => {
268                    state.consecutive_successes = state.consecutive_successes.saturating_add(1);
269                }
270            }
271        }
272    }
273
274    /// Record a failed response from a backend.
275    ///
276    /// Marks the backend as Unhealthy after `unhealthy_threshold` consecutive failures.
277    pub fn record_failure(&self, backend_id: &str) {
278        // SECURITY (FIND-041-011): Fail-closed on RwLock poisoning.
279        let mut states = match self.states.write() {
280            Ok(guard) => guard,
281            Err(e) => {
282                tracing::error!("Gateway states RwLock poisoned in record_failure(): {}", e);
283                return;
284            }
285        };
286        if let Some(state) = states.get_mut(backend_id) {
287            state.consecutive_successes = 0;
288            // SECURITY (FIND-R43-009): Use saturating_add to prevent integer overflow.
289            state.consecutive_failures = state.consecutive_failures.saturating_add(1);
290
291            if state.consecutive_failures >= self.unhealthy_threshold
292                && state.health != BackendHealth::Unhealthy
293            {
294                tracing::warn!(
295                    backend = %backend_id,
296                    failures = state.consecutive_failures,
297                    "Gateway backend transitioning: {} → unhealthy",
298                    match state.health {
299                        BackendHealth::Healthy => "healthy",
300                        BackendHealth::Degraded => "degraded",
301                        BackendHealth::Unhealthy => "unhealthy",
302                    }
303                );
304                state.health = BackendHealth::Unhealthy;
305            }
306        }
307    }
308
309    /// Look up the original backend configuration by ID (Phase 29).
310    ///
311    /// Used by cross-transport fallback to retrieve `transport_urls`.
312    pub fn backend_config(&self, backend_id: &str) -> Option<&BackendConfig> {
313        self.backend_configs.get(backend_id)
314    }
315
316    /// Return a snapshot of all backend health states.
317    pub fn backend_health(&self) -> Vec<(String, String, BackendHealth)> {
318        // SECURITY (FIND-041-011): Fail-closed on RwLock poisoning.
319        let states = match self.states.read() {
320            Ok(guard) => guard,
321            Err(e) => {
322                tracing::error!("Gateway states RwLock poisoned in backend_health(): {}", e);
323                return Vec::new();
324            }
325        };
326        states
327            .iter()
328            .map(|(id, state)| (id.clone(), state.url.clone(), state.health))
329            .collect()
330    }
331
332    /// Number of configured backends.
333    pub fn backend_count(&self) -> usize {
334        // SECURITY (FIND-041-011): Fail-closed on RwLock poisoning.
335        let states = match self.states.read() {
336            Ok(guard) => guard,
337            Err(e) => {
338                tracing::error!("Gateway states RwLock poisoned in backend_count(): {}", e);
339                return 0;
340            }
341        };
342        states.len()
343    }
344}
345
346/// Maximum number of tool names per backend to prevent OOM from malicious backends.
347const MAX_TOOL_NAMES_PER_BACKEND: usize = 10_000;
348
/// Tool names advertised by one backend during discovery.
pub struct DiscoveredTools {
    /// ID of the backend that advertised the tools.
    pub backend_id: String,
    /// Names of the advertised tools, as reported by that backend.
    pub tool_names: Vec<String>,
}
354
355impl DiscoveredTools {
356    /// SECURITY (IMP-R118-005): Validate bounds on tool_names.
357    pub fn validate(&self) -> Result<(), String> {
358        if self.tool_names.len() > MAX_TOOL_NAMES_PER_BACKEND {
359            return Err(format!(
360                "DiscoveredTools.tool_names count {} exceeds max {}",
361                self.tool_names.len(),
362                MAX_TOOL_NAMES_PER_BACKEND,
363            ));
364        }
365        for name in &self.tool_names {
366            if name.len() > MAX_TOOL_NAME_LEN {
367                return Err(format!(
368                    "tool name length {} exceeds max {}",
369                    name.len(),
370                    MAX_TOOL_NAME_LEN,
371                ));
372            }
373        }
374        Ok(())
375    }
376}
377
378/// Detect tool name conflicts across multiple backends.
379///
380/// Returns a list of tool names that are advertised by more than one backend.
381pub fn detect_conflicts(discovered: &[DiscoveredTools]) -> Vec<ToolConflict> {
382    let mut tool_map: HashMap<&str, Vec<&str>> = HashMap::new();
383    for dt in discovered {
384        for name in &dt.tool_names {
385            tool_map
386                .entry(name.as_str())
387                .or_default()
388                .push(dt.backend_id.as_str());
389        }
390    }
391    tool_map
392        .into_iter()
393        .filter(|(_, backends)| backends.len() > 1)
394        .map(|(tool_name, backends)| ToolConflict {
395            tool_name: tool_name.to_string(),
396            backends: backends.into_iter().map(String::from).collect(),
397        })
398        .collect()
399}
400
401/// Spawn a background health checker that periodically pings backends.
402///
403/// Sends a JSON-RPC `ping` request to each backend URL and records
404/// success/failure based on the response status.
405pub fn spawn_health_checker(
406    gateway: std::sync::Arc<GatewayRouter>,
407    client: reqwest::Client,
408    interval_secs: u64,
409) -> tokio::task::JoinHandle<()> {
410    tokio::spawn(async move {
411        let mut interval = tokio::time::interval(std::time::Duration::from_secs(interval_secs));
412        // Skip the first immediate tick
413        interval.tick().await;
414
415        loop {
416            interval.tick().await;
417
418            let backends: Vec<(String, String)> = {
419                // SECURITY (FIND-041-011): Fail-closed on RwLock poisoning.
420                let states = match gateway.states.read() {
421                    Ok(guard) => guard,
422                    Err(e) => {
423                        tracing::error!("Gateway states RwLock poisoned in health_checker: {}", e);
424                        continue; // skip this health check cycle
425                    }
426                };
427                states
428                    .iter()
429                    .map(|(id, state)| (id.clone(), state.url.clone()))
430                    .collect()
431            };
432
433            for (backend_id, url) in &backends {
434                let ping_payload = serde_json::json!({
435                    "jsonrpc": "2.0",
436                    "id": "health",
437                    "method": "ping"
438                });
439                let ping_body = match serde_json::to_vec(&ping_payload) {
440                    Ok(body) => body,
441                    Err(e) => {
442                        tracing::error!(
443                            backend = %backend_id,
444                            error = %e,
445                            "Gateway health check: failed to serialize ping payload"
446                        );
447                        gateway.record_failure(backend_id);
448                        continue;
449                    }
450                };
451                let mut headers = reqwest::header::HeaderMap::new();
452                headers.insert(
453                    reqwest::header::CONTENT_TYPE,
454                    reqwest::header::HeaderValue::from_static("application/json"),
455                );
456
457                let result = fallback::forward_with_fallback(
458                    &client,
459                    url,
460                    bytes::Bytes::from(ping_body),
461                    &headers,
462                    0,
463                    std::time::Duration::from_secs(5),
464                )
465                .await;
466
467                match result {
468                    Ok(resp)
469                        if (200..300).contains(&resp.status)
470                            || (400..500).contains(&resp.status) =>
471                    {
472                        tracing::debug!(
473                            backend = %backend_id,
474                            status = resp.status,
475                            transport = ?resp.transport_used,
476                            fallback_attempts = resp.fallback_attempts,
477                            response_bytes = resp.response.len(),
478                            has_negotiation_history = resp.negotiation_history.is_some(),
479                            "Gateway health check: backend reachable"
480                        );
481                        // 2xx or 4xx = server is alive (even if it rejects the ping method)
482                        gateway.record_success(backend_id);
483                    }
484                    Ok(resp) => {
485                        tracing::debug!(
486                            backend = %backend_id,
487                            status = resp.status,
488                            "Gateway health check: server error"
489                        );
490                        gateway.record_failure(backend_id);
491                    }
492                    Err(e) => {
493                        tracing::debug!(
494                            backend = %backend_id,
495                            error = %e,
496                            "Gateway health check: connection failed"
497                        );
498                        gateway.record_failure(backend_id);
499                    }
500                }
501            }
502
503            // Update metrics
504            let health = gateway.backend_health();
505            let total = health.len();
506            let healthy_count = health
507                .iter()
508                .filter(|(_, _, h)| *h == BackendHealth::Healthy)
509                .count();
510
511            metrics::gauge!("vellaveto_gateway_backends_total").set(total as f64);
512            metrics::gauge!("vellaveto_gateway_backends_healthy").set(healthy_count as f64);
513        }
514    })
515}
516
517#[cfg(test)]
518mod tests {
519    use super::*;
520    use vellaveto_config::{BackendConfig, GatewayConfig};
521
522    fn test_config(backends: Vec<BackendConfig>) -> GatewayConfig {
523        GatewayConfig {
524            enabled: true,
525            backends,
526            health_check_interval_secs: 15,
527            unhealthy_threshold: 3,
528            healthy_threshold: 2,
529        }
530    }
531
532    fn backend(id: &str, url: &str, prefixes: &[&str]) -> BackendConfig {
533        BackendConfig {
534            id: id.to_string(),
535            url: url.to_string(),
536            tool_prefixes: prefixes.iter().map(|s| s.to_string()).collect(),
537            weight: 100,
538            transport_urls: std::collections::HashMap::new(),
539        }
540    }
541
542    fn default_backend(id: &str, url: &str) -> BackendConfig {
543        BackendConfig {
544            id: id.to_string(),
545            url: url.to_string(),
546            tool_prefixes: vec![],
547            weight: 100,
548            transport_urls: std::collections::HashMap::new(),
549        }
550    }
551
552    #[test]
553    fn test_router_from_config_valid() {
554        let config = test_config(vec![
555            backend("fs", "http://fs:8000", &["fs_", "file_"]),
556            default_backend("default", "http://default:8000"),
557        ]);
558        let router = GatewayRouter::from_config(&config).unwrap();
559        assert_eq!(router.backend_count(), 2);
560        assert_eq!(router.prefix_table.len(), 2);
561        assert_eq!(router.default_backend_id, Some("default".to_string()));
562    }
563
564    #[test]
565    fn test_router_from_config_empty_backends() {
566        let config = test_config(vec![]);
567        let err = GatewayRouter::from_config(&config).unwrap_err();
568        assert!(err.contains("at least one backend"), "got: {err}");
569    }
570
571    #[test]
572    fn test_router_from_config_duplicate_ids() {
573        let config = test_config(vec![
574            backend("dup", "http://a:8000", &["a_"]),
575            backend("dup", "http://b:8000", &["b_"]),
576        ]);
577        let err = GatewayRouter::from_config(&config).unwrap_err();
578        assert!(err.contains("duplicate backend id"), "got: {err}");
579    }
580
581    #[test]
582    fn test_router_from_config_multiple_defaults() {
583        let config = test_config(vec![
584            default_backend("d1", "http://a:8000"),
585            default_backend("d2", "http://b:8000"),
586        ]);
587        let err = GatewayRouter::from_config(&config).unwrap_err();
588        assert!(err.contains("multiple default backends"), "got: {err}");
589    }
590
591    #[test]
592    fn test_route_prefix_match() {
593        let config = test_config(vec![
594            backend("fs", "http://fs:8000", &["fs_"]),
595            default_backend("default", "http://default:8000"),
596        ]);
597        let router = GatewayRouter::from_config(&config).unwrap();
598
599        let decision = router.route("fs_read_file").unwrap();
600        assert_eq!(decision.backend_id, "fs");
601        assert_eq!(decision.upstream_url, "http://fs:8000");
602    }
603
604    #[test]
605    fn test_route_longest_prefix_wins() {
606        let config = test_config(vec![
607            backend("short", "http://short:8000", &["fs_"]),
608            backend("long", "http://long:8000", &["fs_read_"]),
609        ]);
610        let router = GatewayRouter::from_config(&config).unwrap();
611
612        // "fs_read_file" matches both "fs_" and "fs_read_", longest wins
613        let decision = router.route("fs_read_file").unwrap();
614        assert_eq!(decision.backend_id, "long");
615
616        // "fs_write_file" only matches "fs_"
617        let decision = router.route("fs_write_file").unwrap();
618        assert_eq!(decision.backend_id, "short");
619    }
620
621    #[test]
622    fn test_route_default_fallback() {
623        let config = test_config(vec![
624            backend("fs", "http://fs:8000", &["fs_"]),
625            default_backend("default", "http://default:8000"),
626        ]);
627        let router = GatewayRouter::from_config(&config).unwrap();
628
629        let decision = router.route("unknown_tool").unwrap();
630        assert_eq!(decision.backend_id, "default");
631    }
632
633    #[test]
634    fn test_route_no_match_no_default_returns_none() {
635        let config = test_config(vec![backend("fs", "http://fs:8000", &["fs_"])]);
636        let router = GatewayRouter::from_config(&config).unwrap();
637
638        assert!(router.route("unknown_tool").is_none());
639    }
640
641    #[test]
642    fn test_route_unhealthy_backend_skipped() {
643        let config = test_config(vec![
644            backend("fs", "http://fs:8000", &["fs_"]),
645            default_backend("default", "http://default:8000"),
646        ]);
647        let router = GatewayRouter::from_config(&config).unwrap();
648
649        // Mark fs as unhealthy
650        for _ in 0..3 {
651            router.record_failure("fs");
652        }
653
654        // Should fall through to default
655        let decision = router.route("fs_read_file").unwrap();
656        assert_eq!(decision.backend_id, "default");
657    }
658
659    #[test]
660    fn test_route_degraded_backend_included() {
661        let config = GatewayConfig {
662            unhealthy_threshold: 5, // higher so we don't cross it
663            ..test_config(vec![backend("fs", "http://fs:8000", &["fs_"])])
664        };
665        let router = GatewayRouter::from_config(&config).unwrap();
666
667        // Record some failures but not enough to be unhealthy
668        router.record_failure("fs");
669        router.record_failure("fs");
670
671        // Should still route (degraded but not unhealthy since threshold is 5)
672        let decision = router.route("fs_read_file").unwrap();
673        assert_eq!(decision.backend_id, "fs");
674    }
675
676    #[test]
677    fn test_record_failure_marks_unhealthy() {
678        let config = test_config(vec![backend("fs", "http://fs:8000", &["fs_"])]);
679        let router = GatewayRouter::from_config(&config).unwrap();
680
681        // threshold is 3
682        router.record_failure("fs");
683        router.record_failure("fs");
684        assert!(router.route("fs_tool").is_some()); // still healthy
685
686        router.record_failure("fs");
687        assert!(router.route("fs_tool").is_none()); // now unhealthy
688    }
689
690    #[test]
691    fn test_record_success_restores_from_degraded() {
692        let config = test_config(vec![backend("fs", "http://fs:8000", &["fs_"])]);
693        let router = GatewayRouter::from_config(&config).unwrap();
694
695        // Make unhealthy
696        for _ in 0..3 {
697            router.record_failure("fs");
698        }
699        assert!(router.route("fs_tool").is_none());
700
701        // One success transitions to degraded
702        router.record_success("fs");
703        assert!(router.route("fs_tool").is_some()); // degraded is routable
704
705        // Another success restores to healthy (healthy_threshold = 2)
706        router.record_success("fs");
707        let health = router.backend_health();
708        let fs_health = health.iter().find(|(id, _, _)| id == "fs").unwrap();
709        assert_eq!(fs_health.2, BackendHealth::Healthy);
710    }
711
712    #[test]
713    fn test_record_failure_resets_success_count() {
714        let config = test_config(vec![backend("fs", "http://fs:8000", &["fs_"])]);
715        let router = GatewayRouter::from_config(&config).unwrap();
716
717        // Make unhealthy, then start recovering
718        for _ in 0..3 {
719            router.record_failure("fs");
720        }
721        router.record_success("fs"); // degraded now, success_count = 1
722
723        // Failure resets success counter
724        router.record_failure("fs");
725
726        // Need healthy_threshold (2) more successes to restore
727        router.record_success("fs");
728        let health = router.backend_health();
729        let fs_health = health.iter().find(|(id, _, _)| id == "fs").unwrap();
730        assert_eq!(fs_health.2, BackendHealth::Degraded); // still degraded
731    }
732
733    #[test]
734    fn test_health_transition_unhealthy_to_degraded() {
735        let config = test_config(vec![backend("b", "http://b:8000", &["b_"])]);
736        let router = GatewayRouter::from_config(&config).unwrap();
737
738        for _ in 0..3 {
739            router.record_failure("b");
740        }
741
742        let health = router.backend_health();
743        assert_eq!(health[0].2, BackendHealth::Unhealthy);
744
745        router.record_success("b");
746        let health = router.backend_health();
747        assert_eq!(health[0].2, BackendHealth::Degraded);
748    }
749
750    #[test]
751    fn test_health_transition_degraded_to_healthy() {
752        let config = test_config(vec![backend("b", "http://b:8000", &["b_"])]);
753        let router = GatewayRouter::from_config(&config).unwrap();
754
755        for _ in 0..3 {
756            router.record_failure("b");
757        }
758        router.record_success("b"); // degraded
759        router.record_success("b"); // healthy (threshold = 2)
760
761        let health = router.backend_health();
762        assert_eq!(health[0].2, BackendHealth::Healthy);
763    }
764
765    #[test]
766    fn test_backend_health_returns_all() {
767        let config = test_config(vec![
768            backend("a", "http://a:8000", &["a_"]),
769            backend("b", "http://b:8000", &["b_"]),
770        ]);
771        let router = GatewayRouter::from_config(&config).unwrap();
772
773        let health = router.backend_health();
774        assert_eq!(health.len(), 2);
775    }
776
777    #[test]
778    fn test_backend_count() {
779        let config = test_config(vec![
780            backend("a", "http://a:8000", &["a_"]),
781            backend("b", "http://b:8000", &["b_"]),
782            backend("c", "http://c:8000", &["c_"]),
783        ]);
784        let router = GatewayRouter::from_config(&config).unwrap();
785        assert_eq!(router.backend_count(), 3);
786    }
787
788    #[test]
789    fn test_detect_conflicts_none() {
790        let discovered = vec![
791            DiscoveredTools {
792                backend_id: "a".to_string(),
793                tool_names: vec!["tool_a".to_string()],
794            },
795            DiscoveredTools {
796                backend_id: "b".to_string(),
797                tool_names: vec!["tool_b".to_string()],
798            },
799        ];
800        let conflicts = detect_conflicts(&discovered);
801        assert!(conflicts.is_empty());
802    }
803
804    #[test]
805    fn test_detect_conflicts_found() {
806        let discovered = vec![
807            DiscoveredTools {
808                backend_id: "a".to_string(),
809                tool_names: vec!["read_file".to_string(), "unique_a".to_string()],
810            },
811            DiscoveredTools {
812                backend_id: "b".to_string(),
813                tool_names: vec!["read_file".to_string(), "unique_b".to_string()],
814            },
815        ];
816        let conflicts = detect_conflicts(&discovered);
817        assert_eq!(conflicts.len(), 1);
818        assert_eq!(conflicts[0].tool_name, "read_file");
819        assert_eq!(conflicts[0].backends.len(), 2);
820    }
821
822    #[test]
823    fn test_route_with_affinity_prefers_known() {
824        let config = test_config(vec![
825            backend("a", "http://a:8000", &["fs_"]),
826            backend("b", "http://b:8000", &["fs_"]),
827            default_backend("default", "http://default:8000"),
828        ]);
829        let router = GatewayRouter::from_config(&config).unwrap();
830
831        let mut affinities = HashMap::new();
832        affinities.insert("fs_read".to_string(), "b".to_string());
833
834        let decision = router.route_with_affinity("fs_read", &affinities).unwrap();
835        assert_eq!(decision.backend_id, "b");
836    }
837
838    #[test]
839    fn test_route_with_affinity_falls_back_on_unhealthy() {
840        let config = test_config(vec![
841            backend("a", "http://a:8000", &["fs_"]),
842            default_backend("default", "http://default:8000"),
843        ]);
844        let router = GatewayRouter::from_config(&config).unwrap();
845
846        // Make "a" unhealthy
847        for _ in 0..3 {
848            router.record_failure("a");
849        }
850
851        let mut affinities = HashMap::new();
852        affinities.insert("fs_read".to_string(), "a".to_string());
853
854        // Should fall back to default since "a" is unhealthy
855        let decision = router.route_with_affinity("fs_read", &affinities).unwrap();
856        assert_eq!(decision.backend_id, "default");
857    }
858
859    #[test]
860    fn test_route_with_affinity_empty_affinities() {
861        let config = test_config(vec![
862            backend("fs", "http://fs:8000", &["fs_"]),
863            default_backend("default", "http://default:8000"),
864        ]);
865        let router = GatewayRouter::from_config(&config).unwrap();
866
867        let affinities = HashMap::new();
868        let decision = router.route_with_affinity("fs_read", &affinities).unwrap();
869        assert_eq!(decision.backend_id, "fs");
870    }
871
872    #[test]
873    fn test_route_truncates_long_tool_name() {
874        let config = test_config(vec![default_backend("default", "http://default:8000")]);
875        let router = GatewayRouter::from_config(&config).unwrap();
876
877        let long_name = "x".repeat(1000);
878        let decision = router.route(&long_name).unwrap();
879        assert_eq!(decision.backend_id, "default");
880    }
881
882    #[test]
883    fn test_route_empty_tool_name_uses_default() {
884        let config = test_config(vec![
885            backend("fs", "http://fs:8000", &["fs_"]),
886            default_backend("default", "http://default:8000"),
887        ]);
888        let router = GatewayRouter::from_config(&config).unwrap();
889
890        let decision = router.route("").unwrap();
891        assert_eq!(decision.backend_id, "default");
892    }
893
894    #[test]
895    fn test_gateway_config_validate_valid() {
896        let config = test_config(vec![
897            backend("a", "http://a:8000", &["a_"]),
898            default_backend("default", "http://default:8000"),
899        ]);
900        assert!(config.validate().is_ok());
901    }
902
903    #[test]
904    fn test_gateway_config_validate_empty_id() {
905        let config = test_config(vec![BackendConfig {
906            id: String::new(),
907            url: "http://a:8000".to_string(),
908            tool_prefixes: vec![],
909            weight: 100,
910            transport_urls: std::collections::HashMap::new(),
911        }]);
912        let err = config.validate().unwrap_err();
913        assert!(err.contains("id must not be empty"), "got: {err}");
914    }
915
916    #[test]
917    fn test_gateway_config_validate_zero_weight() {
918        let config = test_config(vec![BackendConfig {
919            id: "b".to_string(),
920            url: "http://a:8000".to_string(),
921            tool_prefixes: vec![],
922            weight: 0,
923            transport_urls: std::collections::HashMap::new(),
924        }]);
925        let err = config.validate().unwrap_err();
926        assert!(err.contains("weight must be >= 1"), "got: {err}");
927    }
928
929    #[test]
930    fn test_gateway_config_validate_interval_bounds() {
931        let mut config = test_config(vec![default_backend("d", "http://d:8000")]);
932        config.health_check_interval_secs = 1;
933        assert!(config.validate().unwrap_err().contains("[5, 300]"));
934
935        config.health_check_interval_secs = 500;
936        assert!(config.validate().unwrap_err().contains("[5, 300]"));
937    }
938
939    #[test]
940    fn test_gateway_config_disabled_skips_validation() {
941        let config = GatewayConfig {
942            enabled: false,
943            backends: vec![],
944            health_check_interval_secs: 0,
945            unhealthy_threshold: 0,
946            healthy_threshold: 0,
947        };
948        assert!(config.validate().is_ok());
949    }
950
951    #[test]
952    fn test_gateway_config_serde_roundtrip() {
953        let config = test_config(vec![
954            backend("a", "http://a:8000", &["prefix_"]),
955            default_backend("default", "http://d:8000"),
956        ]);
957        let json_str = serde_json::to_string(&config).unwrap();
958        let deserialized: GatewayConfig = serde_json::from_str(&json_str).unwrap();
959        assert_eq!(config, deserialized);
960    }
961
962    // ═══════════════════════════════════════════════════
963    // ADVERSARIAL AUDIT ROUND 43 TESTS
964    // ═══════════════════════════════════════════════════
965
966    /// FIND-R43-002: Tool names with multi-byte UTF-8 chars at the truncation boundary
967    /// must not panic. Before fix, `&tool_name[..256]` would panic if byte 256 fell
968    /// in the middle of a multi-byte character.
969    #[test]
970    fn test_r43_002_route_multibyte_utf8_no_panic() {
971        let config = test_config(vec![default_backend("default", "http://default:8000")]);
972        let router = GatewayRouter::from_config(&config).unwrap();
973
974        // Create a string that is >256 bytes with multi-byte chars.
975        // Each CJK character is 3 bytes in UTF-8. 90 chars = 270 bytes.
976        let multibyte_name = "\u{4e16}".repeat(90);
977        assert!(multibyte_name.len() > 256); // 270 bytes > 256
978        assert!(multibyte_name.len() < 1000);
979
980        // Should not panic — uses char-boundary-safe truncation now.
981        let decision = router.route(&multibyte_name);
982        assert_eq!(decision.unwrap().backend_id, "default");
983    }
984
985    /// FIND-R43-002: 4-byte UTF-8 chars (emoji) at boundary.
986    #[test]
987    fn test_r43_002_route_4byte_utf8_at_boundary() {
988        let config = test_config(vec![default_backend("default", "http://default:8000")]);
989        let router = GatewayRouter::from_config(&config).unwrap();
990
991        // 64 emoji (each 4 bytes) = 256 bytes exactly, then add one more to exceed.
992        let mut name = "\u{1F600}".repeat(65);
993        assert!(name.len() > 256); // 260 bytes
994                                   // This should not panic.
995        let decision = router.route(&name);
996        assert!(decision.is_some());
997
998        // Test with exactly 255 ASCII + one 2-byte char to land at byte 256 mid-char.
999        name = "x".repeat(255) + "\u{00E9}"; // e-acute is 2 bytes in UTF-8
1000        assert_eq!(name.len(), 257);
1001        let decision = router.route(&name);
1002        assert!(decision.is_some());
1003    }
1004}