turbomcp_client/llm/routing.rs

//! Request routing and provider selection

use crate::llm::core::{LLMError, LLMRequest, LLMResult};
use serde::{Deserialize, Serialize};

/// Strategies for routing requests to providers
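///
/// # Example
///
/// A minimal construction sketch; the provider names ("openai", "anthropic")
/// are placeholders that must match providers registered with the client, and
/// the import path assumes this module is publicly exported.
///
/// ```ignore
/// use turbomcp_client::llm::routing::RoutingStrategy;
///
/// // Always send requests to a single named provider.
/// let specific = RoutingStrategy::Specific {
///     provider: "openai".to_string(),
/// };
///
/// // Alternate between providers on successive requests.
/// let round_robin = RoutingStrategy::RoundRobin {
///     providers: vec!["openai".to_string(), "anthropic".to_string()],
/// };
/// ```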
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RoutingStrategy {
    /// Use specific provider
    Specific { provider: String },
    /// Round-robin between providers
    RoundRobin { providers: Vec<String> },
    /// Route based on request properties
    RuleBased { rules: Vec<RouteRule> },
    /// Route to least loaded provider
    LoadBalanced { providers: Vec<String> },
}

/// Rule for routing requests
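///
/// # Example
///
/// A sketch of a rule that sends any model whose name contains "gpt" to a
/// provider registered as "openai"; both names are illustrative, and the
/// import path assumes this module is publicly exported.
///
/// ```ignore
/// use turbomcp_client::llm::routing::{RouteCondition, RouteRule};
///
/// let rule = RouteRule {
///     condition: RouteCondition::ModelContains {
///         pattern: "gpt".to_string(),
///     },
///     provider: "openai".to_string(),
///     // Higher priority wins when several rules match.
///     priority: 10,
/// };
/// ```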
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouteRule {
    /// Condition to match
    pub condition: RouteCondition,
    /// Provider to route to
    pub provider: String,
    /// Rule priority (higher values take precedence)
    pub priority: i32,
}

/// Condition for route matching
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RouteCondition {
    /// Match a specific model exactly
    ModelEquals { model: String },
    /// Match models whose name contains a pattern
    ModelContains { pattern: String },
    /// Match a request metadata key/value pair
    MetadataEquals { key: String, value: String },
    /// Match streaming requests
    IsStreaming,
    /// Match requests that contain images
    HasImages,
    /// Always match
    Always,
}

/// Request router for intelligent provider selection
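///
/// # Example
///
/// A minimal usage sketch with the `Specific` strategy; the provider name is a
/// placeholder, and the import paths assume the `llm` modules are publicly
/// exported from the crate.
///
/// ```ignore
/// use turbomcp_client::llm::core::{LLMMessage, LLMRequest};
/// use turbomcp_client::llm::routing::{RequestRouter, RoutingStrategy};
///
/// let router = RequestRouter::new(RoutingStrategy::Specific {
///     provider: "openai".to_string(),
/// });
///
/// let request = LLMRequest::new("gpt-4", vec![LLMMessage::user("Hello")]);
/// assert_eq!(router.route_request(&request).unwrap(), "openai");
/// ```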
#[derive(Debug)]
pub struct RequestRouter {
    /// Strategy used to select a provider for each request
    strategy: RoutingStrategy,
    /// Cursor shared by the round-robin and load-balanced strategies
    round_robin_index: std::sync::Mutex<usize>,
}

impl RequestRouter {
    /// Create a new request router with the given strategy
    pub fn new(strategy: RoutingStrategy) -> Self {
        Self {
            strategy,
            round_robin_index: std::sync::Mutex::new(0),
        }
    }

    /// Route a request to determine which provider to use
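    ///
    /// # Errors
    ///
    /// Returns a configuration error if the round-robin or load-balanced
    /// strategy has an empty provider list, or if no rule matches under the
    /// rule-based strategy.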
    pub fn route_request(&self, request: &LLMRequest) -> LLMResult<String> {
        match &self.strategy {
            RoutingStrategy::Specific { provider } => Ok(provider.clone()),

            RoutingStrategy::RoundRobin { providers } => {
                if providers.is_empty() {
                    return Err(LLMError::configuration(
                        "No providers configured for round-robin",
                    ));
                }

                let mut index = self.round_robin_index.lock().unwrap();
                let provider = providers[*index % providers.len()].clone();
                *index = (*index + 1) % providers.len();
                Ok(provider)
            }

            RoutingStrategy::RuleBased { rules } => {
                let mut matching_rules: Vec<_> = rules
                    .iter()
                    .filter(|rule| self.matches_condition(&rule.condition, request))
                    .collect();

                // Sort by priority (descending)
                matching_rules.sort_by(|a, b| b.priority.cmp(&a.priority));

                matching_rules
                    .first()
                    .map(|rule| rule.provider.clone())
                    .ok_or_else(|| LLMError::configuration("No routing rules matched the request"))
            }

            RoutingStrategy::LoadBalanced { providers } => {
                if providers.is_empty() {
                    return Err(LLMError::configuration(
                        "No providers configured for load balancing",
                    ));
                }

                // TODO: Implement actual load balancing
                // For now, just use round-robin
                let mut index = self.round_robin_index.lock().unwrap();
                let provider = providers[*index % providers.len()].clone();
                *index = (*index + 1) % providers.len();
                Ok(provider)
            }
        }
    }

    /// Check whether a single routing condition matches the given request
    fn matches_condition(&self, condition: &RouteCondition, request: &LLMRequest) -> bool {
        match condition {
            RouteCondition::ModelEquals { model } => request.model == *model,

            RouteCondition::ModelContains { pattern } => request.model.contains(pattern),

            RouteCondition::MetadataEquals { key, value } => request
                .metadata
                .get(key)
                .and_then(|v| v.as_str())
                .map(|v| v == value)
                .unwrap_or(false),

            RouteCondition::IsStreaming => request.stream,

            RouteCondition::HasImages => request.messages.iter().any(|msg| msg.content.is_image()),

            RouteCondition::Always => true,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::core::LLMMessage;

    #[test]
    fn test_specific_routing() {
        let strategy = RoutingStrategy::Specific {
            provider: "openai".to_string(),
        };

        let router = RequestRouter::new(strategy);
        let request = LLMRequest::new("gpt-4", vec![LLMMessage::user("Hello")]);

        let result = router.route_request(&request).unwrap();
        assert_eq!(result, "openai");
    }

    #[test]
    fn test_round_robin_routing() {
        let strategy = RoutingStrategy::RoundRobin {
            providers: vec!["openai".to_string(), "anthropic".to_string()],
        };

        let router = RequestRouter::new(strategy);
        let request = LLMRequest::new("gpt-4", vec![LLMMessage::user("Hello")]);

        let result1 = router.route_request(&request).unwrap();
        let result2 = router.route_request(&request).unwrap();
        let result3 = router.route_request(&request).unwrap();

        assert_eq!(result1, "openai");
        assert_eq!(result2, "anthropic");
        assert_eq!(result3, "openai"); // Back to first
    }

    #[test]
    fn test_rule_based_routing() {
        let rules = vec![
            RouteRule {
                condition: RouteCondition::ModelContains {
                    pattern: "gpt".to_string(),
                },
                provider: "openai".to_string(),
                priority: 10,
            },
            RouteRule {
                condition: RouteCondition::ModelContains {
                    pattern: "claude".to_string(),
                },
                provider: "anthropic".to_string(),
                priority: 10,
            },
            RouteRule {
                condition: RouteCondition::Always,
                provider: "ollama".to_string(),
                priority: 1,
            },
        ];

        let strategy = RoutingStrategy::RuleBased { rules };
        let router = RequestRouter::new(strategy);

        // Test GPT model routing
        let gpt_request = LLMRequest::new("gpt-4", vec![LLMMessage::user("Hello")]);
        let result = router.route_request(&gpt_request).unwrap();
        assert_eq!(result, "openai");

        // Test Claude model routing
        let claude_request = LLMRequest::new("claude-3-sonnet", vec![LLMMessage::user("Hello")]);
        let result = router.route_request(&claude_request).unwrap();
        assert_eq!(result, "anthropic");

        // Test fallback routing
        let other_request = LLMRequest::new("llama2", vec![LLMMessage::user("Hello")]);
        let result = router.route_request(&other_request).unwrap();
        assert_eq!(result, "ollama");
    }

    #[test]
    fn test_streaming_condition() {
        let rules = vec![RouteRule {
            condition: RouteCondition::IsStreaming,
            provider: "streaming_provider".to_string(),
            priority: 10,
        }];

        let strategy = RoutingStrategy::RuleBased { rules };
        let router = RequestRouter::new(strategy);

        let streaming_request =
            LLMRequest::new("gpt-4", vec![LLMMessage::user("Hello")]).with_streaming(true);

        let result = router.route_request(&streaming_request).unwrap();
        assert_eq!(result, "streaming_provider");
    }
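
    // Additional coverage sketches: these exercise only behavior visible in
    // route_request above (empty provider lists are rejected, and the
    // load-balanced strategy currently falls back to round-robin).
    #[test]
    fn test_empty_round_robin_is_rejected() {
        let router = RequestRouter::new(RoutingStrategy::RoundRobin { providers: vec![] });
        let request = LLMRequest::new("gpt-4", vec![LLMMessage::user("Hello")]);

        assert!(router.route_request(&request).is_err());
    }

    #[test]
    fn test_load_balanced_round_robin_fallback() {
        let strategy = RoutingStrategy::LoadBalanced {
            providers: vec!["openai".to_string(), "anthropic".to_string()],
        };
        let router = RequestRouter::new(strategy);
        let request = LLMRequest::new("gpt-4", vec![LLMMessage::user("Hello")]);

        // Until real load balancing lands, providers are cycled in order.
        assert_eq!(router.route_request(&request).unwrap(), "openai");
        assert_eq!(router.route_request(&request).unwrap(), "anthropic");
    }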
}