turbomcp_client/llm/routing.rs

use crate::llm::core::{LLMError, LLMRequest, LLMResult};
use serde::{Deserialize, Serialize};

/// Strategy used to select a provider for each request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RoutingStrategy {
    /// Always route to a single named provider.
    Specific { provider: String },
    /// Cycle through the given providers in order.
    RoundRobin { providers: Vec<String> },
    /// Route via the highest-priority rule whose condition matches.
    RuleBased { rules: Vec<RouteRule> },
    /// Spread requests across the given providers.
    LoadBalanced { providers: Vec<String> },
}

/// A single routing rule: when `condition` matches, route to `provider`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouteRule {
    pub condition: RouteCondition,
    pub provider: String,
    /// Higher values win when several rules match.
    pub priority: i32,
}

/// Predicate evaluated against an incoming request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RouteCondition {
    ModelEquals { model: String },
    ModelContains { pattern: String },
    MetadataEquals { key: String, value: String },
    IsStreaming,
    HasImages,
    Always,
}

/// Routes requests to providers according to a [`RoutingStrategy`].
#[derive(Debug)]
pub struct RequestRouter {
    strategy: RoutingStrategy,
    /// Shared cursor for the rotating strategies.
    round_robin_index: std::sync::Mutex<usize>,
}

impl RequestRouter {
    pub fn new(strategy: RoutingStrategy) -> Self {
        Self {
            strategy,
            round_robin_index: std::sync::Mutex::new(0),
        }
    }

    /// Returns the name of the provider that should handle `request`.
    pub fn route_request(&self, request: &LLMRequest) -> LLMResult<String> {
        match &self.strategy {
            RoutingStrategy::Specific { provider } => Ok(provider.clone()),

            RoutingStrategy::RoundRobin { providers } => {
                if providers.is_empty() {
                    return Err(LLMError::configuration(
                        "No providers configured for round-robin",
                    ));
                }

                let mut index = self.round_robin_index.lock().unwrap();
                let provider = providers[*index % providers.len()].clone();
                *index = (*index + 1) % providers.len();
                Ok(provider)
            }

            RoutingStrategy::RuleBased { rules } => {
                let mut matching_rules: Vec<_> = rules
                    .iter()
                    .filter(|rule| self.matches_condition(&rule.condition, request))
                    .collect();

                // Highest priority first.
                matching_rules.sort_by(|a, b| b.priority.cmp(&a.priority));

                matching_rules
                    .first()
                    .map(|rule| rule.provider.clone())
                    .ok_or_else(|| LLMError::configuration("No routing rules matched the request"))
            }

            RoutingStrategy::LoadBalanced { providers } => {
                if providers.is_empty() {
                    return Err(LLMError::configuration(
                        "No providers configured for load balancing",
                    ));
                }

                // Currently a plain round-robin rotation; a weighted or
                // least-loaded policy could replace this without changing the API.
                let mut index = self.round_robin_index.lock().unwrap();
                let provider = providers[*index % providers.len()].clone();
                *index = (*index + 1) % providers.len();
                Ok(provider)
            }
        }
    }

    /// Returns true if `request` satisfies `condition`.
    fn matches_condition(&self, condition: &RouteCondition, request: &LLMRequest) -> bool {
        match condition {
            RouteCondition::ModelEquals { model } => request.model == *model,

            RouteCondition::ModelContains { pattern } => request.model.contains(pattern),

            RouteCondition::MetadataEquals { key, value } => request
                .metadata
                .get(key)
                .and_then(|v| v.as_str())
                .map(|v| v == value)
                .unwrap_or(false),

            RouteCondition::IsStreaming => request.stream,

            RouteCondition::HasImages => request.messages.iter().any(|msg| msg.content.is_image()),

            RouteCondition::Always => true,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::core::LLMMessage;

    #[test]
    fn test_specific_routing() {
        let strategy = RoutingStrategy::Specific {
            provider: "openai".to_string(),
        };

        let router = RequestRouter::new(strategy);
        let request = LLMRequest::new("gpt-4", vec![LLMMessage::user("Hello")]);

        let result = router.route_request(&request).unwrap();
        assert_eq!(result, "openai");
    }

    #[test]
    fn test_round_robin_routing() {
        let strategy = RoutingStrategy::RoundRobin {
            providers: vec!["openai".to_string(), "anthropic".to_string()],
        };

        let router = RequestRouter::new(strategy);
        let request = LLMRequest::new("gpt-4", vec![LLMMessage::user("Hello")]);

        let result1 = router.route_request(&request).unwrap();
        let result2 = router.route_request(&request).unwrap();
        let result3 = router.route_request(&request).unwrap();

        assert_eq!(result1, "openai");
        assert_eq!(result2, "anthropic");
        assert_eq!(result3, "openai");
    }

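    // Added sketch: pins the empty-provider error path for round-robin,
    // using only the APIs already exercised above.
    #[test]
    fn test_round_robin_empty_providers_errors() {
        let strategy = RoutingStrategy::RoundRobin { providers: vec![] };
        let router = RequestRouter::new(strategy);
        let request = LLMRequest::new("gpt-4", vec![LLMMessage::user("Hello")]);

        assert!(router.route_request(&request).is_err());
    }
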
    #[test]
    fn test_rule_based_routing() {
        let rules = vec![
            RouteRule {
                condition: RouteCondition::ModelContains {
                    pattern: "gpt".to_string(),
                },
                provider: "openai".to_string(),
                priority: 10,
            },
            RouteRule {
                condition: RouteCondition::ModelContains {
                    pattern: "claude".to_string(),
                },
                provider: "anthropic".to_string(),
                priority: 10,
            },
            RouteRule {
                condition: RouteCondition::Always,
                provider: "ollama".to_string(),
                priority: 1,
            },
        ];

        let strategy = RoutingStrategy::RuleBased { rules };
        let router = RequestRouter::new(strategy);

        let gpt_request = LLMRequest::new("gpt-4", vec![LLMMessage::user("Hello")]);
        let result = router.route_request(&gpt_request).unwrap();
        assert_eq!(result, "openai");

        let claude_request = LLMRequest::new("claude-3-sonnet", vec![LLMMessage::user("Hello")]);
        let result = router.route_request(&claude_request).unwrap();
        assert_eq!(result, "anthropic");

        let other_request = LLMRequest::new("llama2", vec![LLMMessage::user("Hello")]);
        let result = router.route_request(&other_request).unwrap();
        assert_eq!(result, "ollama");
    }

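    // Added sketch: verifies that rule-based routing fails cleanly when no
    // rule matches, rather than silently picking a default provider.
    #[test]
    fn test_rule_based_no_match_errors() {
        let rules = vec![RouteRule {
            condition: RouteCondition::ModelEquals {
                model: "gpt-4".to_string(),
            },
            provider: "openai".to_string(),
            priority: 10,
        }];

        let router = RequestRouter::new(RoutingStrategy::RuleBased { rules });
        let request = LLMRequest::new("llama2", vec![LLMMessage::user("Hello")]);

        assert!(router.route_request(&request).is_err());
    }
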
    #[test]
    fn test_streaming_condition() {
        let rules = vec![RouteRule {
            condition: RouteCondition::IsStreaming,
            provider: "streaming_provider".to_string(),
            priority: 10,
        }];

        let strategy = RoutingStrategy::RuleBased { rules };
        let router = RequestRouter::new(strategy);

        let streaming_request =
            LLMRequest::new("gpt-4", vec![LLMMessage::user("Hello")]).with_streaming(true);

        let result = router.route_request(&streaming_request).unwrap();
        assert_eq!(result, "streaming_provider");
    }
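
    // Added sketch, hedged: assumes `LLMRequest.metadata` is a mutable
    // string-keyed map of `serde_json::Value`, as implied by the
    // `metadata.get(key).and_then(|v| v.as_str())` lookup in
    // `matches_condition`; adjust the setup if the real type differs.
    #[test]
    fn test_metadata_condition() {
        let rules = vec![RouteRule {
            condition: RouteCondition::MetadataEquals {
                key: "tier".to_string(),
                value: "premium".to_string(),
            },
            provider: "premium_provider".to_string(),
            priority: 10,
        }];

        let router = RequestRouter::new(RoutingStrategy::RuleBased { rules });

        let mut request = LLMRequest::new("gpt-4", vec![LLMMessage::user("Hello")]);
        request.metadata.insert(
            "tier".to_string(),
            serde_json::Value::String("premium".to_string()),
        );

        let result = router.route_request(&request).unwrap();
        assert_eq!(result, "premium_provider");
    }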
}