1use sentinel_config::{InferenceProvider, ModelRoutingConfig, ModelUpstreamMapping};
7
8#[derive(Debug, Clone)]
10pub struct ModelRoutingResult {
11 pub upstream: String,
13 pub provider: Option<InferenceProvider>,
15 pub is_default: bool,
17}
18
19pub fn find_upstream_for_model(
31 config: &ModelRoutingConfig,
32 model: &str,
33) -> Option<ModelRoutingResult> {
34 for mapping in &config.mappings {
36 if matches_model_pattern(&mapping.model_pattern, model) {
37 return Some(ModelRoutingResult {
38 upstream: mapping.upstream.clone(),
39 provider: mapping.provider,
40 is_default: false,
41 });
42 }
43 }
44
45 config.default_upstream.as_ref().map(|upstream| ModelRoutingResult {
47 upstream: upstream.clone(),
48 provider: None,
49 is_default: true,
50 })
51}
52
53fn matches_model_pattern(pattern: &str, model: &str) -> bool {
59 if pattern == model {
61 return true;
62 }
63
64 glob_match(pattern, model)
66}
67
68fn glob_match(pattern: &str, text: &str) -> bool {
79 let pattern_chars: Vec<char> = pattern.chars().collect();
80 let text_chars: Vec<char> = text.chars().collect();
81
82 glob_match_recursive(&pattern_chars, &text_chars, 0, 0)
83}
84
/// Matches `text[t_idx..]` against `pattern[p_idx..]`.
///
/// `*` in the pattern matches any run of characters (including none); every
/// other character must match literally. The whole remaining text must be
/// consumed for the match to succeed.
///
/// The name is kept for compatibility with existing callers, but the match is
/// performed iteratively: on a mismatch the matcher rewinds only to the most
/// recent `*` and lets it absorb one more character. That bounds the work by
/// roughly `pattern.len() * text.len()`, whereas naive recursion over every
/// split point is exponential in the number of `*`s — a concern because the
/// text side can come from request data.
fn glob_match_recursive(pattern: &[char], text: &[char], p_idx: usize, t_idx: usize) -> bool {
    // A start position past the end of the text leaves nothing for a `*` or a
    // literal to consume, so only an exhausted pattern can match.
    if t_idx > text.len() {
        return p_idx >= pattern.len();
    }

    let mut p = p_idx;
    let mut t = t_idx;
    // (pattern index just after the latest `*`, text index where that star's
    // match currently ends). `None` until a `*` has been seen.
    let mut resume: Option<(usize, usize)> = None;

    while t < text.len() {
        match pattern.get(p).copied() {
            Some('*') => {
                // Start by letting the star match nothing; remember where to
                // come back to if the rest of the pattern fails.
                resume = Some((p + 1, t));
                p += 1;
            }
            Some(literal) if literal == text[t] => {
                p += 1;
                t += 1;
            }
            _ => match resume {
                // Mismatch (or pattern exhausted with text left): extend the
                // latest star by one character and retry from just after it.
                Some((after_star, star_end)) => {
                    p = after_star;
                    t = star_end + 1;
                    resume = Some((after_star, star_end + 1));
                }
                None => return false,
            },
        }
    }

    // Text is exhausted: whatever remains of the pattern must be able to match
    // the empty string, i.e. consist solely of `*`.
    pattern
        .get(p..)
        .map_or(true, |rest| rest.iter().all(|&c| c == '*'))
}
109
110pub fn extract_model_from_headers(headers: &http::HeaderMap) -> Option<String> {
119 let header_names = ["x-model", "x-model-id"];
121
122 for name in header_names {
123 if let Some(value) = headers.get(name) {
124 if let Ok(model) = value.to_str() {
125 let model = model.trim();
126 if !model.is_empty() {
127 return Some(model.to_string());
128 }
129 }
130 }
131 }
132
133 None
134}
135
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a single mapping entry from borrowed strings.
    fn mapping(
        pattern: &str,
        upstream: &str,
        provider: Option<InferenceProvider>,
    ) -> ModelUpstreamMapping {
        ModelUpstreamMapping {
            model_pattern: pattern.to_string(),
            upstream: upstream.to_string(),
            provider,
        }
    }

    /// Routing table shared by most tests: an exact `gpt-4` entry ahead of the
    /// broader `gpt-4*` glob, plus per-family globs and a default upstream.
    fn create_test_config() -> ModelRoutingConfig {
        ModelRoutingConfig {
            mappings: vec![
                mapping("gpt-4", "openai-gpt4", Some(InferenceProvider::OpenAi)),
                mapping("gpt-4*", "openai-primary", Some(InferenceProvider::OpenAi)),
                mapping("gpt-3.5*", "openai-secondary", Some(InferenceProvider::OpenAi)),
                mapping("claude-*", "anthropic-backend", Some(InferenceProvider::Anthropic)),
                mapping("llama-*", "local-gpu", Some(InferenceProvider::Generic)),
            ],
            default_upstream: Some("openai-primary".to_string()),
        }
    }

    /// Builds a header map holding exactly the given name/value pairs,
    /// inserted in the order listed.
    fn headers_from(pairs: &[(&'static str, &str)]) -> http::HeaderMap {
        let mut map = http::HeaderMap::new();
        for &(name, value) in pairs {
            map.insert(name, value.parse().unwrap());
        }
        map
    }

    #[test]
    fn test_exact_match() {
        let routed = find_upstream_for_model(&create_test_config(), "gpt-4").unwrap();

        assert_eq!(routed.upstream, "openai-gpt4");
        assert!(!routed.is_default);
    }

    #[test]
    fn test_glob_suffix_match() {
        let config = create_test_config();

        let turbo = find_upstream_for_model(&config, "gpt-4-turbo").unwrap();
        assert_eq!(turbo.upstream, "openai-primary");
        assert!(!turbo.is_default);

        let omni = find_upstream_for_model(&config, "gpt-4o").unwrap();
        assert_eq!(omni.upstream, "openai-primary");
    }

    #[test]
    fn test_claude_models() {
        let config = create_test_config();

        let opus = find_upstream_for_model(&config, "claude-3-opus").unwrap();
        assert_eq!(opus.upstream, "anthropic-backend");
        assert_eq!(opus.provider, Some(InferenceProvider::Anthropic));

        let sonnet = find_upstream_for_model(&config, "claude-3.5-sonnet").unwrap();
        assert_eq!(sonnet.upstream, "anthropic-backend");
    }

    #[test]
    fn test_default_upstream() {
        let routed = find_upstream_for_model(&create_test_config(), "unknown-model").unwrap();

        assert_eq!(routed.upstream, "openai-primary");
        assert!(routed.is_default);
        assert!(routed.provider.is_none());
    }

    #[test]
    fn test_no_match_no_default() {
        let config = ModelRoutingConfig {
            mappings: vec![mapping("gpt-4", "openai", None)],
            default_upstream: None,
        };

        assert!(find_upstream_for_model(&config, "claude-3-opus").is_none());
    }

    #[test]
    fn test_first_match_wins() {
        // "gpt-4" satisfies both the exact entry and the "gpt-4*" glob; the
        // entry listed first must be the one returned.
        let routed = find_upstream_for_model(&create_test_config(), "gpt-4").unwrap();

        assert_eq!(routed.upstream, "openai-gpt4");
    }

    #[test]
    fn test_glob_match_patterns() {
        // (pattern, text, expected outcome)
        let cases = [
            // Trailing wildcard.
            ("gpt-4*", "gpt-4", true),
            ("gpt-4*", "gpt-4-turbo", true),
            ("gpt-4*", "gpt-4o", true),
            ("gpt-4*", "gpt-3.5-turbo", false),
            // Leading wildcard.
            ("*-turbo", "gpt-4-turbo", true),
            ("*-turbo", "gpt-3.5-turbo", true),
            ("*-turbo", "gpt-4", false),
            // Wildcard in the middle.
            ("claude-*-sonnet", "claude-3-sonnet", true),
            ("claude-*-sonnet", "claude-3.5-sonnet", true),
            ("claude-*-sonnet", "claude-3-opus", false),
            // Lone wildcard.
            ("*", "anything", true),
            ("*", "", true),
        ];

        for (pattern, text, expected) in cases {
            assert_eq!(
                glob_match(pattern, text),
                expected,
                "pattern {:?} against text {:?}",
                pattern,
                text
            );
        }
    }

    #[test]
    fn test_extract_model_from_headers() {
        // No headers at all.
        assert!(extract_model_from_headers(&headers_from(&[])).is_none());

        // Primary header only.
        assert_eq!(
            extract_model_from_headers(&headers_from(&[("x-model", "gpt-4")])),
            Some("gpt-4".to_string())
        );

        // Secondary header only.
        assert_eq!(
            extract_model_from_headers(&headers_from(&[("x-model-id", "claude-3-opus")])),
            Some("claude-3-opus".to_string())
        );

        // Both present: `x-model` takes priority over `x-model-id`.
        let both = headers_from(&[("x-model-id", "claude-3-opus"), ("x-model", "gpt-4")]);
        assert_eq!(
            extract_model_from_headers(&both),
            Some("gpt-4".to_string())
        );

        // Empty and whitespace-only values are treated as absent.
        assert!(extract_model_from_headers(&headers_from(&[("x-model", "")])).is_none());
        assert!(extract_model_from_headers(&headers_from(&[("x-model", " ")])).is_none());
    }
}