1pub mod cache;
17pub mod store;
18pub mod validate;
19
20pub use cache::CachingRoutingStore;
21#[cfg(feature = "postgres")]
22pub use store::PostgresRoutingStore;
23pub use store::{InMemoryRoutingStore, NewRoute, RoutingStore, RoutingStoreError};
24pub use validate::{validate_capability, ValidationError};
25
26use serde::{Deserialize, Serialize};
27use uuid::Uuid;
28
29use tt_shared::{ChatCompletionRequest, RequestContext};
30
/// A single routing rule: the conditions in `when` decide whether the rule
/// applies to a request, and `then` carries the action attached to it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Route {
    /// Identifier of the route.
    pub id: Uuid,
    /// Name of the route; [`RoutingEngine::find_by_name`] compares it exactly.
    pub name: String,
    /// Evaluation priority. [`RoutingEngine`] orders routes by descending
    /// priority, so a higher value is tried before a lower one.
    pub priority: u32,
    /// Disabled routes are skipped by [`RoutingEngine::evaluate`],
    /// [`RoutingEngine::evaluate_with_cost`] and [`RoutingEngine::find_by_name`].
    pub enabled: bool,
    /// Conditions that must all hold for this route to match a request.
    pub when: RouteConditions,
    /// Action associated with this route.
    pub then: RouteAction,
}
49
/// Predicates a request must satisfy for a [`Route`] to match.
///
/// Every configured condition must hold (logical AND). A condition left at its
/// default (empty list or `None`) places no constraint, so
/// `RouteConditions::default()` matches every request. All fields default when
/// absent from the serialized form.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RouteConditions {
    /// The request's model must equal one of these names; empty accepts any model.
    #[serde(default)]
    pub model_in: Vec<String>,
    /// Matches only when the input-token estimate is strictly less than this value.
    #[serde(default)]
    pub input_tokens_lt: Option<u32>,
    /// Matches only when the input-token estimate is strictly greater than this value.
    #[serde(default)]
    pub input_tokens_gt: Option<u32>,
    /// The request context's tag must be present and equal to this string.
    #[serde(default)]
    pub tag_equals: Option<String>,
    /// `Some(true)` requires the request to contain images, `Some(false)`
    /// requires it not to (as reported by
    /// `tt_shared::capability_check::request_has_images`). Not serialized when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub has_images: Option<bool>,
    /// Same as `has_images`, but for audio input
    /// (`tt_shared::capability_check::request_has_audio`). Not serialized when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub has_audio: Option<bool>,
    /// At least one of these keywords must occur in the request's input text.
    /// Both sides are lowercased before the substring check. Not serialized when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub prompt_contains_any_of: Vec<String>,
    /// Matches only when a cost estimate is supplied and is strictly greater
    /// than this value; with no estimate the condition fails. Not serialized when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub estimated_cost_gt: Option<f64>,
    /// Matches only when a cost estimate is supplied and is strictly less than
    /// this value; with no estimate the condition fails. Not serialized when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub estimated_cost_lt: Option<f64>,
}
87
/// Action attached to a [`Route`].
///
/// Only `target_model` is required in the serialized form. The optional fields
/// are left out of the output when they hold their default value, and keys
/// this struct does not know are ignored when deserializing.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RouteAction {
    /// Model named by this action.
    // NOTE(review): presumably the model a matching request is sent to; the
    // consumer is outside this file — confirm.
    pub target_model: String,
    /// Additional model names. Defaults to empty when absent and is not
    /// serialized when empty.
    // NOTE(review): presumably tried in order if `target_model` fails — confirm
    // against the caller.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub fallbacks: Vec<String>,
    /// Defaults to `false` when absent; serialized only when `true`
    /// (`std::ops::Not::not` returns `true` for `false`, which triggers the skip).
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    pub disable_cache: bool,
    /// Optional cost value; not serialized when `None`.
    // NOTE(review): name suggests a per-request cost ceiling in USD; enforcement
    // is not visible in this file — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_cost_usd: Option<f64>,
}
113
/// In-memory collection of [`Route`]s that can be evaluated against a request.
#[derive(Debug, Clone, Default)]
pub struct RoutingEngine {
    // Ordered by descending `priority`; `with_routes` and `add` both sort to
    // keep it that way, and evaluation returns the first match in this order.
    routes: Vec<Route>,
}
120
121impl RoutingEngine {
122 pub fn new() -> Self {
125 Self::default()
126 }
127
128 pub fn with_routes(routes: impl IntoIterator<Item = Route>) -> Self {
131 let mut v: Vec<Route> = routes.into_iter().collect();
132 v.sort_by_key(|r| std::cmp::Reverse(r.priority));
133 Self { routes: v }
134 }
135
136 pub fn add(&mut self, route: Route) {
139 self.routes.push(route);
140 self.routes.sort_by_key(|r| std::cmp::Reverse(r.priority));
141 }
142
143 pub fn routes(&self) -> &[Route] {
145 &self.routes
146 }
147
148 pub fn evaluate(
155 &self,
156 req: &ChatCompletionRequest,
157 ctx: &RequestContext,
158 input_tokens_estimate: u32,
159 ) -> Option<&Route> {
160 self.evaluate_with_cost(req, ctx, input_tokens_estimate, None)
162 }
163
164 pub fn evaluate_with_cost(
170 &self,
171 req: &ChatCompletionRequest,
172 ctx: &RequestContext,
173 input_tokens_estimate: u32,
174 estimated_cost_usd: Option<f64>,
175 ) -> Option<&Route> {
176 self.routes
177 .iter()
178 .find(|r| r.enabled && matches(r, req, ctx, input_tokens_estimate, estimated_cost_usd))
179 }
180
181 pub fn find_by_name(&self, name: &str) -> Option<&Route> {
184 self.routes.iter().find(|r| r.enabled && r.name == name)
185 }
186}
187
188fn matches(
189 r: &Route,
190 req: &ChatCompletionRequest,
191 ctx: &RequestContext,
192 input_tokens: u32,
193 estimated_cost_usd: Option<f64>,
194) -> bool {
195 let c = &r.when;
196 if !c.model_in.is_empty() && !c.model_in.iter().any(|m| m == &req.model) {
197 return false;
198 }
199 if let Some(t) = c.input_tokens_lt {
200 if input_tokens >= t {
201 return false;
202 }
203 }
204 if let Some(t) = c.input_tokens_gt {
205 if input_tokens <= t {
206 return false;
207 }
208 }
209 if let Some(t) = c.estimated_cost_gt {
210 if !matches!(estimated_cost_usd, Some(cost) if cost > t) {
212 return false;
213 }
214 }
215 if let Some(t) = c.estimated_cost_lt {
216 if !matches!(estimated_cost_usd, Some(cost) if cost < t) {
217 return false;
218 }
219 }
220 if let Some(tag) = &c.tag_equals {
221 if ctx.tag.as_deref() != Some(tag.as_str()) {
222 return false;
223 }
224 }
225 if let Some(want) = c.has_images {
226 if tt_shared::capability_check::request_has_images(req) != want {
227 return false;
228 }
229 }
230 if let Some(want) = c.has_audio {
231 if tt_shared::capability_check::request_has_audio(req) != want {
232 return false;
233 }
234 }
235 if !c.prompt_contains_any_of.is_empty() {
236 let text = tt_shared::capability_check::request_input_text(req).to_lowercase();
237 if !c
238 .prompt_contains_any_of
239 .iter()
240 .any(|kw| text.contains(&kw.to_lowercase()))
241 {
242 return false;
243 }
244 }
245 true
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use tt_shared::{
        context::{ProviderCredentials, SecretString},
        messages::{ContentPart, ImageUrl, InputAudio},
        ChatCompletionRequest, Message, MessageContent,
    };

    // ---------------------------------------------------------------------
    // Fixtures
    // ---------------------------------------------------------------------

    /// Builds an enabled route whose only condition is the `model_in` list.
    fn make_route(name: &str, priority: u32, model_in: Vec<&str>, target: &str) -> Route {
        Route {
            id: Uuid::now_v7(),
            name: name.into(),
            priority,
            enabled: true,
            when: RouteConditions {
                model_in: model_in.into_iter().map(String::from).collect(),
                ..Default::default()
            },
            then: RouteAction {
                target_model: target.into(),
                fallbacks: Vec::new(),
                disable_cache: false,
                max_cost_usd: None,
            },
        }
    }

    /// Builds an enabled priority-10 route with exactly the given conditions.
    fn route_when(name: &str, target: &str, when: RouteConditions) -> Route {
        Route {
            when,
            ..make_route(name, 10, Vec::new(), target)
        }
    }

    /// Builds a request with one user message carrying `content`; every other
    /// field comes from deserializing a minimal request.
    fn make_req_with_content(model: &str, content: MessageContent) -> ChatCompletionRequest {
        ChatCompletionRequest {
            model: model.into(),
            messages: vec![Message::User {
                content,
                name: None,
            }],
            ..serde_json::from_str(r#"{"model":"placeholder","messages":[]}"#).unwrap()
        }
    }

    /// Request with a single user message containing `text`.
    fn make_req_text(model: &str, text: &str) -> ChatCompletionRequest {
        make_req_with_content(model, MessageContent::Text(text.into()))
    }

    /// Request with a single user message saying "hi".
    fn make_req(model: &str) -> ChatCompletionRequest {
        make_req_text(model, "hi")
    }

    /// Request with a single user message made of one content part.
    fn make_req_with_part(model: &str, part: ContentPart) -> ChatCompletionRequest {
        make_req_with_content(model, MessageContent::Parts(vec![part]))
    }

    fn image_part() -> ContentPart {
        ContentPart::ImageUrl {
            image_url: ImageUrl {
                url: "data:image/png;base64,abc".into(),
                detail: None,
            },
        }
    }

    fn audio_part() -> ContentPart {
        ContentPart::InputAudio {
            input_audio: InputAudio {
                data: "abc".into(),
                format: "wav".into(),
            },
        }
    }

    /// Request context with nil ids, empty credentials and the given tag.
    fn make_ctx(tag: Option<&str>) -> RequestContext {
        RequestContext {
            trace_id: Uuid::now_v7(),
            org_id: Uuid::nil(),
            api_key_id: Uuid::nil(),
            credentials: ProviderCredentials {
                api_key: SecretString::new(""),
                base_url: None,
                extra_headers: Vec::new(),
            },
            tag: tag.map(String::from),
            deadline: None,
        }
    }

    /// Whether any route matches `req` with no tag and a 100-token estimate.
    fn matched(eng: &RoutingEngine, req: &ChatCompletionRequest) -> bool {
        eng.evaluate(req, &make_ctx(None), 100).is_some()
    }

    // ---------------------------------------------------------------------
    // Engine construction, ordering and lookup
    // ---------------------------------------------------------------------

    #[test]
    fn find_by_name_matches_enabled_route_by_exact_name() {
        let enabled = make_route("alpha", 10, vec!["gpt-4o"], "gpt-4o-mini");
        let mut disabled = make_route("beta", 10, vec!["gpt-4o"], "gpt-4o-mini");
        disabled.enabled = false;
        let eng = RoutingEngine::with_routes(vec![enabled, disabled]);
        assert!(eng.find_by_name("alpha").is_some());
        assert_eq!(eng.find_by_name("alpha").unwrap().name, "alpha");
        assert!(
            eng.find_by_name("beta").is_none(),
            "disabled route not found"
        );
        assert!(eng.find_by_name("missing").is_none());
    }

    #[test]
    fn empty_engine_matches_nothing() {
        let eng = RoutingEngine::new();
        assert!(!matched(&eng, &make_req("gpt-4o")));
    }

    #[test]
    fn model_in_matches() {
        let eng = RoutingEngine::with_routes(vec![make_route(
            "to-mini",
            10,
            vec!["gpt-4o"],
            "gpt-4o-mini",
        )]);
        let m = eng
            .evaluate(&make_req("gpt-4o"), &make_ctx(None), 100)
            .expect("should match");
        assert_eq!(m.then.target_model, "gpt-4o-mini");
    }

    #[test]
    fn empty_model_in_matches_any_model() {
        let route = make_route("any", 10, vec![], "target");
        let eng = RoutingEngine::with_routes(vec![route]);
        assert!(matched(&eng, &make_req("claude-sonnet-4-6")));
    }

    #[test]
    fn priority_descending_first_match_wins() {
        let eng = RoutingEngine::with_routes(vec![
            make_route("low", 1, vec!["gpt-4o"], "low-target"),
            make_route("high", 100, vec!["gpt-4o"], "high-target"),
            make_route("mid", 50, vec!["gpt-4o"], "mid-target"),
        ]);
        let m = eng
            .evaluate(&make_req("gpt-4o"), &make_ctx(None), 100)
            .unwrap();
        assert_eq!(m.then.target_model, "high-target");
    }

    #[test]
    fn add_keeps_routes_sorted_by_descending_priority() {
        let mut eng = RoutingEngine::new();
        eng.add(make_route("low", 1, vec![], "low-target"));
        eng.add(make_route("high", 100, vec![], "high-target"));
        eng.add(make_route("mid", 50, vec![], "mid-target"));
        let names: Vec<&str> = eng.routes().iter().map(|r| r.name.as_str()).collect();
        assert_eq!(names, ["high", "mid", "low"]);
    }

    #[test]
    fn equal_priority_routes_keep_insertion_order() {
        // Both constructors use a stable sort, so ties resolve to the route
        // that was supplied first.
        let mut eng = RoutingEngine::with_routes(vec![
            make_route("first", 10, vec![], "a"),
            make_route("second", 10, vec![], "b"),
        ]);
        eng.add(make_route("third", 10, vec![], "c"));
        let names: Vec<&str> = eng.routes().iter().map(|r| r.name.as_str()).collect();
        assert_eq!(names, ["first", "second", "third"]);
        let hit = eng
            .evaluate(&make_req("gpt-4o"), &make_ctx(None), 100)
            .unwrap();
        assert_eq!(hit.name, "first");
    }

    #[test]
    fn disabled_route_skipped() {
        let mut route = make_route("disabled", 100, vec!["gpt-4o"], "never");
        route.enabled = false;
        let eng = RoutingEngine::with_routes(vec![
            route,
            make_route("enabled", 10, vec!["gpt-4o"], "winner"),
        ]);
        let m = eng
            .evaluate(&make_req("gpt-4o"), &make_ctx(None), 100)
            .unwrap();
        assert_eq!(m.then.target_model, "winner");
    }

    // ---------------------------------------------------------------------
    // Modality conditions
    // ---------------------------------------------------------------------

    #[test]
    fn has_images_true_matches_only_image_requests() {
        let eng = RoutingEngine::with_routes(vec![route_when(
            "vision",
            "vision-mini",
            RouteConditions {
                has_images: Some(true),
                ..Default::default()
            },
        )]);
        assert!(matched(&eng, &make_req_with_part("gpt-4o", image_part())));
        assert!(!matched(&eng, &make_req("gpt-4o")));
    }

    #[test]
    fn has_images_false_matches_only_non_image_requests() {
        let eng = RoutingEngine::with_routes(vec![route_when(
            "text",
            "cheap",
            RouteConditions {
                has_images: Some(false),
                ..Default::default()
            },
        )]);
        assert!(matched(&eng, &make_req("gpt-4o")));
        assert!(!matched(&eng, &make_req_with_part("gpt-4o", image_part())));
    }

    #[test]
    fn has_audio_true_matches_only_audio_requests() {
        let eng = RoutingEngine::with_routes(vec![route_when(
            "audio",
            "audio-model",
            RouteConditions {
                has_audio: Some(true),
                ..Default::default()
            },
        )]);
        assert!(matched(&eng, &make_req_with_part("gpt-4o", audio_part())));
        assert!(!matched(&eng, &make_req_with_part("gpt-4o", image_part())));
    }

    #[test]
    fn modality_anded_with_model_in() {
        let eng = RoutingEngine::with_routes(vec![route_when(
            "both",
            "vision-mini",
            RouteConditions {
                model_in: vec!["gpt-4o".into()],
                has_images: Some(true),
                ..Default::default()
            },
        )]);
        assert!(matched(&eng, &make_req_with_part("gpt-4o", image_part())));
        assert!(!matched(&eng, &make_req("gpt-4o")));
        assert!(!matched(&eng, &make_req_with_part("other", image_part())));
    }

    // ---------------------------------------------------------------------
    // Token, tag and prompt conditions
    // ---------------------------------------------------------------------

    #[test]
    fn token_lt_filters() {
        let eng = RoutingEngine::with_routes(vec![route_when(
            "short-only",
            "gpt-4o-mini",
            RouteConditions {
                model_in: vec!["gpt-4o".into()],
                input_tokens_lt: Some(500),
                ..Default::default()
            },
        )]);
        let hit = |tokens: u32| {
            eng.evaluate(&make_req("gpt-4o"), &make_ctx(None), tokens)
                .is_some()
        };
        assert!(hit(100));
        assert!(!hit(600));
    }

    #[test]
    fn token_gt_filters() {
        let eng = RoutingEngine::with_routes(vec![route_when(
            "long-only",
            "claude-opus-4-7",
            RouteConditions {
                model_in: vec!["gpt-4o".into()],
                input_tokens_gt: Some(1000),
                ..Default::default()
            },
        )]);
        let hit = |tokens: u32| {
            eng.evaluate(&make_req("gpt-4o"), &make_ctx(None), tokens)
                .is_some()
        };
        assert!(!hit(500));
        assert!(hit(1500));
    }

    #[test]
    fn token_thresholds_are_exclusive() {
        let below = RoutingEngine::with_routes(vec![route_when(
            "lt",
            "t",
            RouteConditions {
                input_tokens_lt: Some(500),
                ..Default::default()
            },
        )]);
        let above = RoutingEngine::with_routes(vec![route_when(
            "gt",
            "t",
            RouteConditions {
                input_tokens_gt: Some(1000),
                ..Default::default()
            },
        )]);
        let hit = |eng: &RoutingEngine, tokens: u32| {
            eng.evaluate(&make_req("gpt-4o"), &make_ctx(None), tokens)
                .is_some()
        };
        assert!(hit(&below, 499));
        assert!(!hit(&below, 500), "lt bound must be exclusive");
        assert!(!hit(&above, 1000), "gt bound must be exclusive");
        assert!(hit(&above, 1001));
    }

    #[test]
    fn tag_equals_filters() {
        let eng = RoutingEngine::with_routes(vec![route_when(
            "bg-only",
            "cheap-model",
            RouteConditions {
                tag_equals: Some("background".into()),
                ..Default::default()
            },
        )]);
        let hit = |tag: Option<&str>| {
            eng.evaluate(&make_req("gpt-4o"), &make_ctx(tag), 100)
                .is_some()
        };
        assert!(!hit(None));
        assert!(hit(Some("background")));
        assert!(!hit(Some("foreground")));
    }

    #[test]
    fn prompt_contains_matches_case_insensitive_any() {
        let eng = RoutingEngine::with_routes(vec![route_when(
            "topic",
            "local",
            RouteConditions {
                prompt_contains_any_of: vec!["confidential".into(), "salary".into()],
                ..Default::default()
            },
        )]);
        assert!(matched(
            &eng,
            &make_req_text("gpt-4o", "This is a Confidential memo")
        ));
        assert!(matched(&eng, &make_req_text("gpt-4o", "my SALARY is")));
        assert!(!matched(&eng, &make_req_text("gpt-4o", "the weather today")));
    }

    #[test]
    fn prompt_contains_anded_with_model_in() {
        let eng = RoutingEngine::with_routes(vec![route_when(
            "both",
            "local",
            RouteConditions {
                model_in: vec!["gpt-4o".into()],
                prompt_contains_any_of: vec!["confidential".into()],
                ..Default::default()
            },
        )]);
        assert!(matched(&eng, &make_req_text("gpt-4o", "confidential")));
        assert!(!matched(&eng, &make_req_text("gpt-4o", "hello")));
    }

    // ---------------------------------------------------------------------
    // Cost conditions
    // ---------------------------------------------------------------------

    #[test]
    fn cost_gt_matches_above_threshold_only() {
        let eng = RoutingEngine::with_routes(vec![route_when(
            "expensive",
            "cheaper",
            RouteConditions {
                estimated_cost_gt: Some(0.02),
                ..Default::default()
            },
        )]);
        let hit = |cost: Option<f64>| {
            eng.evaluate_with_cost(&make_req("gpt-4o"), &make_ctx(None), 100, cost)
                .is_some()
        };
        assert!(hit(Some(0.03)));
        assert!(!hit(Some(0.01)));
        assert!(!hit(None));
    }

    #[test]
    fn cost_lt_anded_with_model_in() {
        let eng = RoutingEngine::with_routes(vec![route_when(
            "cheap-small",
            "target",
            RouteConditions {
                model_in: vec!["gpt-4o".into()],
                estimated_cost_lt: Some(0.05),
                ..Default::default()
            },
        )]);
        let hit = |model: &str, cost: Option<f64>| {
            eng.evaluate_with_cost(&make_req(model), &make_ctx(None), 100, cost)
                .is_some()
        };
        assert!(hit("gpt-4o", Some(0.01)));
        assert!(!hit("gpt-4o", Some(0.09)));
        assert!(!hit("claude-x", Some(0.01)));
    }

    #[test]
    fn cost_thresholds_are_exclusive() {
        let above = RoutingEngine::with_routes(vec![route_when(
            "gt",
            "t",
            RouteConditions {
                estimated_cost_gt: Some(0.02),
                ..Default::default()
            },
        )]);
        let below = RoutingEngine::with_routes(vec![route_when(
            "lt",
            "t",
            RouteConditions {
                estimated_cost_lt: Some(0.05),
                ..Default::default()
            },
        )]);
        let hit = |eng: &RoutingEngine, cost: f64| {
            eng.evaluate_with_cost(&make_req("gpt-4o"), &make_ctx(None), 100, Some(cost))
                .is_some()
        };
        assert!(!hit(&above, 0.02), "gt bound must be exclusive");
        assert!(!hit(&below, 0.05), "lt bound must be exclusive");
    }

    #[test]
    fn cost_lt_requires_an_estimate() {
        let eng = RoutingEngine::with_routes(vec![route_when(
            "lt",
            "t",
            RouteConditions {
                estimated_cost_lt: Some(0.05),
                ..Default::default()
            },
        )]);
        assert!(eng
            .evaluate_with_cost(&make_req("gpt-4o"), &make_ctx(None), 100, None)
            .is_none());
        // `evaluate` supplies no estimate, so cost-conditioned routes never match.
        assert!(!matched(&eng, &make_req("gpt-4o")));
    }

    // ---------------------------------------------------------------------
    // RouteAction wire format
    // ---------------------------------------------------------------------

    #[test]
    fn route_action_minimal_serializes_without_new_fields() {
        let a = RouteAction {
            target_model: "x".into(),
            fallbacks: Vec::new(),
            disable_cache: false,
            max_cost_usd: None,
        };
        let json = serde_json::to_string(&a).unwrap();
        assert_eq!(
            json, r#"{"target_model":"x"}"#,
            "empty fallbacks must be omitted from JSON"
        );
    }

    #[test]
    fn route_action_backward_compat_deserialize() {
        let json = r#"{"target_model":"gpt-4o-mini"}"#;
        let a: RouteAction = serde_json::from_str(json).unwrap();
        assert_eq!(a.target_model, "gpt-4o-mini");
        assert!(a.fallbacks.is_empty(), "fallbacks must default to empty");
    }

    #[test]
    fn route_action_full_round_trip() {
        let original = RouteAction {
            target_model: "claude-haiku-4-5".into(),
            fallbacks: vec!["gpt-4o-mini".into(), "gemini-flash".into()],
            disable_cache: false,
            max_cost_usd: None,
        };
        let json = serde_json::to_string(&original).unwrap();
        assert!(
            json.contains("\"fallbacks\""),
            "fallbacks must be present: {json}"
        );
        let roundtripped: RouteAction = serde_json::from_str(&json).unwrap();
        assert_eq!(roundtripped.target_model, original.target_model);
        assert_eq!(roundtripped.fallbacks, original.fallbacks);
    }

    #[test]
    fn route_action_disable_cache_defaults_false_and_omits() {
        let a = RouteAction {
            target_model: "x".into(),
            fallbacks: Vec::new(),
            disable_cache: false,
            max_cost_usd: None,
        };
        assert_eq!(
            serde_json::to_string(&a).unwrap(),
            r#"{"target_model":"x"}"#
        );
        let parsed: RouteAction = serde_json::from_str(r#"{"target_model":"m"}"#).unwrap();
        assert!(!parsed.disable_cache);
        let b = RouteAction {
            disable_cache: true,
            ..a
        };
        assert!(serde_json::to_string(&b)
            .unwrap()
            .contains("\"disable_cache\":true"));
    }

    #[test]
    fn route_action_cross_type_wire_compat() {
        let plan_side_json = r#"{"target_model":"claude-3-5-haiku","fallbacks":["gpt-4o-mini"]}"#;
        let gateway_action: RouteAction = serde_json::from_str(plan_side_json).unwrap();
        assert_eq!(gateway_action.target_model, "claude-3-5-haiku");
        assert_eq!(gateway_action.fallbacks, vec!["gpt-4o-mini"]);
        let reemitted = serde_json::to_string(&gateway_action).unwrap();
        assert_eq!(reemitted, plan_side_json);
    }

    #[test]
    fn route_action_legacy_force_cache_layer_is_ignored() {
        let legacy =
            r#"{"target_model":"claude-3-5-haiku","fallbacks":["x"],"force_cache_layer":"l1"}"#;
        let a: RouteAction = serde_json::from_str(legacy).unwrap();
        assert_eq!(a.target_model, "claude-3-5-haiku");
        assert_eq!(a.fallbacks, vec!["x"]);
        let j = serde_json::to_string(&a).unwrap();
        assert!(
            !j.contains("force_cache_layer"),
            "obsolete key must not be re-emitted: {j}"
        );
    }

    #[test]
    fn max_cost_usd_round_trips_and_omits_when_none() {
        let mut a = make_route("x", 10, vec![], "gpt-4o-mini").then;
        assert!(!serde_json::to_string(&a).unwrap().contains("max_cost_usd"));
        a.max_cost_usd = Some(0.1);
        let j = serde_json::to_string(&a).unwrap();
        assert!(j.contains("\"max_cost_usd\":0.1"));
        let back: RouteAction = serde_json::from_str(&j).unwrap();
        assert_eq!(back.max_cost_usd, Some(0.1));
    }
}