sentinel_proxy/proxy/
model_routing.rs

1//! Model-based routing for inference requests.
2//!
3//! Routes inference requests to different upstreams based on the model name.
4//! Supports glob patterns for flexible model matching (e.g., `gpt-4*`, `claude-*`).
5
6use sentinel_config::{InferenceProvider, ModelRoutingConfig, ModelUpstreamMapping};
7
8/// Result of model-based routing lookup.
9#[derive(Debug, Clone)]
10pub struct ModelRoutingResult {
11    /// Target upstream for this model
12    pub upstream: String,
13    /// Provider override if specified (for cross-provider routing)
14    pub provider: Option<InferenceProvider>,
15    /// Whether this was a default routing (no specific mapping matched)
16    pub is_default: bool,
17}
18
19/// Find the upstream for a given model name.
20///
21/// Checks mappings in order (first match wins). If no mapping matches,
22/// returns the default upstream if configured, otherwise None.
23///
24/// # Arguments
25/// * `config` - Model routing configuration
26/// * `model` - Model name to route
27///
28/// # Returns
29/// `Some(ModelRoutingResult)` if a matching upstream was found, `None` otherwise.
30pub fn find_upstream_for_model(
31    config: &ModelRoutingConfig,
32    model: &str,
33) -> Option<ModelRoutingResult> {
34    // Check mappings in order (first match wins)
35    for mapping in &config.mappings {
36        if matches_model_pattern(&mapping.model_pattern, model) {
37            return Some(ModelRoutingResult {
38                upstream: mapping.upstream.clone(),
39                provider: mapping.provider,
40                is_default: false,
41            });
42        }
43    }
44
45    // No mapping matched - use default if configured
46    config.default_upstream.as_ref().map(|upstream| ModelRoutingResult {
47        upstream: upstream.clone(),
48        provider: None,
49        is_default: true,
50    })
51}
52
53/// Check if a model name matches a pattern.
54///
55/// Supports:
56/// - Exact match: `"gpt-4"` matches `"gpt-4"`
57/// - Glob patterns with `*` wildcard: `"gpt-4*"` matches `"gpt-4"`, `"gpt-4-turbo"`, `"gpt-4o"`
58fn matches_model_pattern(pattern: &str, model: &str) -> bool {
59    // Exact match (fast path)
60    if pattern == model {
61        return true;
62    }
63
64    // Glob pattern matching
65    glob_match(pattern, model)
66}
67
/// Simple glob pattern matching for model names.
///
/// Supports:
/// - `*` matches any sequence of characters (including empty)
/// - All other characters match literally
///
/// # Examples
/// - `gpt-4*` matches `gpt-4`, `gpt-4-turbo`, `gpt-4o`
/// - `claude-*-sonnet` matches `claude-3-sonnet`, `claude-3.5-sonnet`
/// - `*-turbo` matches `gpt-4-turbo`, `gpt-3.5-turbo`
///
/// # Complexity
/// Iterative matcher with single-point backtracking. It is `O(pattern.len() *
/// text.len())` in the worst case, uses constant extra memory and never
/// recurses. `text` usually comes from client-supplied request headers, so a
/// pathological value must not be able to trigger exponential backtracking or
/// deep recursion.
fn glob_match(pattern: &str, text: &str) -> bool {
    // Byte-wise matching gives the same result as char-wise matching here:
    // `*` is ASCII, each literal run in `pattern` is whole UTF-8 sequences,
    // and UTF-8 is self-synchronizing. A literal run can therefore only ever
    // line up with `text` on character boundaries.
    let pat = pattern.as_bytes();
    let txt = text.as_bytes();

    let mut p = 0usize;
    let mut t = 0usize;
    // Index of the most recent `*` in `pat`, plus the `txt` index where the
    // text after that star is currently assumed to start. Only the latest star
    // ever needs to be retried: earlier stars can absorb anything a later
    // retry would need.
    let mut resume: Option<(usize, usize)> = None;

    while t < txt.len() {
        match pat.get(p) {
            Some(b'*') => {
                // Tentatively let the star match the empty string.
                resume = Some((p, t));
                p += 1;
            }
            Some(&c) if c == txt[t] => {
                p += 1;
                t += 1;
            }
            _ => match resume {
                // Mismatch: let the last star swallow one more byte and retry.
                Some((star_p, star_t)) => {
                    let next_t = star_t + 1;
                    resume = Some((star_p, next_t));
                    p = star_p + 1;
                    t = next_t;
                }
                None => return false,
            },
        }
    }

    // Text is exhausted; whatever remains of the pattern must be able to match
    // the empty string, i.e. consist only of `*`.
    pat[p..].iter().all(|&b| b == b'*')
}
109
110/// Extract model name from request headers.
111///
112/// Checks common model headers in order of precedence:
113/// 1. `x-model` - Explicit model header
114/// 2. `x-model-id` - Alternative model header
115///
116/// # Returns
117/// `Some(model_name)` if found in headers, `None` otherwise.
118pub fn extract_model_from_headers(headers: &http::HeaderMap) -> Option<String> {
119    // Check common model headers
120    let header_names = ["x-model", "x-model-id"];
121
122    for name in header_names {
123        if let Some(value) = headers.get(name) {
124            if let Ok(model) = value.to_str() {
125                let model = model.trim();
126                if !model.is_empty() {
127                    return Some(model.to_string());
128                }
129            }
130        }
131    }
132
133    None
134}
135
#[cfg(test)]
mod tests {
    use super::*;

    /// Config used by most routing tests. Mapping order matters: the exact
    /// `gpt-4` entry precedes the broader `gpt-4*` glob so first-match-wins can
    /// be observed.
    fn create_test_config() -> ModelRoutingConfig {
        ModelRoutingConfig {
            mappings: vec![
                ModelUpstreamMapping {
                    model_pattern: "gpt-4".to_string(),
                    upstream: "openai-gpt4".to_string(),
                    provider: Some(InferenceProvider::OpenAi),
                },
                ModelUpstreamMapping {
                    model_pattern: "gpt-4*".to_string(),
                    upstream: "openai-primary".to_string(),
                    provider: Some(InferenceProvider::OpenAi),
                },
                ModelUpstreamMapping {
                    model_pattern: "gpt-3.5*".to_string(),
                    upstream: "openai-secondary".to_string(),
                    provider: Some(InferenceProvider::OpenAi),
                },
                ModelUpstreamMapping {
                    model_pattern: "claude-*".to_string(),
                    upstream: "anthropic-backend".to_string(),
                    provider: Some(InferenceProvider::Anthropic),
                },
                ModelUpstreamMapping {
                    model_pattern: "llama-*".to_string(),
                    upstream: "local-gpu".to_string(),
                    provider: Some(InferenceProvider::Generic),
                },
            ],
            default_upstream: Some("openai-primary".to_string()),
        }
    }

    #[test]
    fn test_exact_match() {
        let config = create_test_config();

        // Exact match for "gpt-4" should match first (more specific)
        let result = find_upstream_for_model(&config, "gpt-4").unwrap();
        assert_eq!(result.upstream, "openai-gpt4");
        assert!(!result.is_default);
    }

    #[test]
    fn test_glob_suffix_match() {
        let config = create_test_config();

        // "gpt-4-turbo" should match "gpt-4*" pattern
        let result = find_upstream_for_model(&config, "gpt-4-turbo").unwrap();
        assert_eq!(result.upstream, "openai-primary");
        assert!(!result.is_default);

        // "gpt-4o" should match "gpt-4*" pattern
        let result = find_upstream_for_model(&config, "gpt-4o").unwrap();
        assert_eq!(result.upstream, "openai-primary");
    }

    #[test]
    fn test_secondary_and_local_routes() {
        let config = create_test_config();

        let result = find_upstream_for_model(&config, "gpt-3.5-turbo").unwrap();
        assert_eq!(result.upstream, "openai-secondary");
        assert_eq!(result.provider, Some(InferenceProvider::OpenAi));
        assert!(!result.is_default);

        let result = find_upstream_for_model(&config, "llama-3-70b").unwrap();
        assert_eq!(result.upstream, "local-gpu");
        assert_eq!(result.provider, Some(InferenceProvider::Generic));
    }

    #[test]
    fn test_claude_models() {
        let config = create_test_config();

        // All claude models should route to anthropic
        let result = find_upstream_for_model(&config, "claude-3-opus").unwrap();
        assert_eq!(result.upstream, "anthropic-backend");
        assert_eq!(result.provider, Some(InferenceProvider::Anthropic));

        let result = find_upstream_for_model(&config, "claude-3.5-sonnet").unwrap();
        assert_eq!(result.upstream, "anthropic-backend");
    }

    #[test]
    fn test_default_upstream() {
        let config = create_test_config();

        // Unknown model should fall back to default
        let result = find_upstream_for_model(&config, "unknown-model").unwrap();
        assert_eq!(result.upstream, "openai-primary");
        assert!(result.is_default);
        assert!(result.provider.is_none());
    }

    #[test]
    fn test_no_match_no_default() {
        let config = ModelRoutingConfig {
            mappings: vec![ModelUpstreamMapping {
                model_pattern: "gpt-4".to_string(),
                upstream: "openai".to_string(),
                provider: None,
            }],
            default_upstream: None,
        };

        // No match and no default should return None
        let result = find_upstream_for_model(&config, "claude-3-opus");
        assert!(result.is_none());

        // A matching mapping without a provider override passes `None` through.
        let result = find_upstream_for_model(&config, "gpt-4").unwrap();
        assert_eq!(result.upstream, "openai");
        assert!(result.provider.is_none());
        assert!(!result.is_default);
    }

    #[test]
    fn test_first_match_wins() {
        let config = create_test_config();

        // "gpt-4" exact match should win over "gpt-4*" glob
        let result = find_upstream_for_model(&config, "gpt-4").unwrap();
        assert_eq!(result.upstream, "openai-gpt4");
    }

    #[test]
    fn test_glob_match_patterns() {
        // Test various glob patterns
        assert!(glob_match("gpt-4*", "gpt-4"));
        assert!(glob_match("gpt-4*", "gpt-4-turbo"));
        assert!(glob_match("gpt-4*", "gpt-4o"));
        assert!(!glob_match("gpt-4*", "gpt-3.5-turbo"));

        assert!(glob_match("*-turbo", "gpt-4-turbo"));
        assert!(glob_match("*-turbo", "gpt-3.5-turbo"));
        assert!(!glob_match("*-turbo", "gpt-4"));

        assert!(glob_match("claude-*-sonnet", "claude-3-sonnet"));
        assert!(glob_match("claude-*-sonnet", "claude-3.5-sonnet"));
        assert!(!glob_match("claude-*-sonnet", "claude-3-opus"));

        assert!(glob_match("*", "anything"));
        assert!(glob_match("*", ""));
    }

    #[test]
    fn test_glob_edge_cases() {
        // Empty pattern / text
        assert!(glob_match("", ""));
        assert!(!glob_match("", "a"));
        assert!(!glob_match("a", ""));
        assert!(glob_match("**", ""));

        // Repeated wildcards and cases that require backtracking
        assert!(glob_match("a**b", "ab"));
        assert!(glob_match("a*b*c", "a-x-b-y-c"));
        assert!(!glob_match("a*b*c", "a-x-c-y-b"));
        assert!(glob_match("*aab", "aaab"));
        assert!(glob_match("*ab*ab", "aabab"));

        // Non-ASCII characters match literally
        assert!(glob_match("*é", "café"));
        assert!(glob_match("ü*ß", "übergroß"));
        assert!(!glob_match("é*", "e-acute"));
    }

    #[test]
    fn test_glob_many_wildcards() {
        // Inputs with several wildcards over a long, repetitive model name
        // (model names come from client-supplied headers).
        let text = "a".repeat(64);
        assert!(!glob_match("*a*a*a*b", &text));
        assert!(glob_match("*a*a*a*", &text));
    }

    #[test]
    fn test_extract_model_from_headers() {
        let mut headers = http::HeaderMap::new();

        // No headers
        assert!(extract_model_from_headers(&headers).is_none());

        // x-model header
        headers.insert("x-model", "gpt-4".parse().unwrap());
        assert_eq!(
            extract_model_from_headers(&headers),
            Some("gpt-4".to_string())
        );

        // x-model-id header (lower precedence)
        headers.clear();
        headers.insert("x-model-id", "claude-3-opus".parse().unwrap());
        assert_eq!(
            extract_model_from_headers(&headers),
            Some("claude-3-opus".to_string())
        );

        // Both headers - x-model takes precedence
        headers.insert("x-model", "gpt-4".parse().unwrap());
        assert_eq!(
            extract_model_from_headers(&headers),
            Some("gpt-4".to_string())
        );

        // Empty header value
        headers.clear();
        headers.insert("x-model", "".parse().unwrap());
        assert!(extract_model_from_headers(&headers).is_none());

        // Whitespace-only header value
        headers.clear();
        headers.insert("x-model", "   ".parse().unwrap());
        assert!(extract_model_from_headers(&headers).is_none());
    }

    #[test]
    fn test_extract_model_trims_and_falls_through() {
        let mut headers = http::HeaderMap::new();

        // Surrounding whitespace is trimmed
        headers.insert("x-model", "  gpt-4  ".parse().unwrap());
        assert_eq!(
            extract_model_from_headers(&headers),
            Some("gpt-4".to_string())
        );

        // Whitespace-only x-model falls through to x-model-id
        headers.clear();
        headers.insert("x-model", "   ".parse().unwrap());
        headers.insert("x-model-id", "llama-3".parse().unwrap());
        assert_eq!(
            extract_model_from_headers(&headers),
            Some("llama-3".to_string())
        );

        // A non-ASCII (opaque) x-model value is skipped in favour of x-model-id
        headers.clear();
        headers.insert("x-model", http::HeaderValue::from_bytes(b"\xffmodel").unwrap());
        headers.insert("x-model-id", "claude-3-opus".parse().unwrap());
        assert_eq!(
            extract_model_from_headers(&headers),
            Some("claude-3-opus".to_string())
        );
    }
}