1use crate::client::LlmClient;
14use crate::tool::ToolDef;
15use crate::types::{Message, SgrError, ToolCall};
16use serde_json::Value;
17
/// Tunable limits that decide when [`ModelRouter`] escalates a request.
///
/// Every threshold is compared strictly: a request is escalated only when the
/// relevant count is *greater than* the configured value.
#[derive(Clone, Debug)]
pub struct RouterConfig {
    /// Requests carrying more than this many messages are routed to the
    /// smart client.
    pub message_threshold: usize,
    /// Tool calls offering more than this many tools are routed to the
    /// smart client.
    pub tool_threshold: usize,
    /// When `true`, every request is routed to the smart client regardless
    /// of the thresholds above.
    pub always_smart: bool,
}
28
29impl Default for RouterConfig {
30 fn default() -> Self {
31 Self {
32 message_threshold: 10,
33 tool_threshold: 8,
34 always_smart: false,
35 }
36 }
37}
38
/// Outcome of a routing decision: which of the router's two clients should
/// serve the request.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ModelChoice {
    /// Dispatch to the router's `smart` client.
    Smart,
    /// Dispatch to the router's `fast` client.
    Fast,
}
45
46pub struct ModelRouter<S: LlmClient, F: LlmClient> {
48 smart: S,
49 fast: F,
50 config: RouterConfig,
51}
52
53impl<S: LlmClient, F: LlmClient> ModelRouter<S, F> {
54 pub fn new(smart: S, fast: F) -> Self {
55 Self {
56 smart,
57 fast,
58 config: RouterConfig::default(),
59 }
60 }
61
62 pub fn with_config(mut self, config: RouterConfig) -> Self {
63 self.config = config;
64 self
65 }
66
67 pub fn route_messages(&self, messages: &[Message]) -> ModelChoice {
69 if self.config.always_smart {
70 return ModelChoice::Smart;
71 }
72 if messages.len() > self.config.message_threshold {
73 return ModelChoice::Smart;
74 }
75 ModelChoice::Fast
76 }
77
78 pub fn route_tools(&self, messages: &[Message], tools: &[ToolDef]) -> ModelChoice {
80 if self.config.always_smart {
81 return ModelChoice::Smart;
82 }
83 if messages.len() > self.config.message_threshold {
84 return ModelChoice::Smart;
85 }
86 if tools.len() > self.config.tool_threshold {
87 return ModelChoice::Smart;
88 }
89 ModelChoice::Fast
90 }
91
92 pub fn route_structured(&self, messages: &[Message], _schema: &Value) -> ModelChoice {
94 if self.config.always_smart {
95 return ModelChoice::Smart;
96 }
97 if messages.len() > self.config.message_threshold {
99 return ModelChoice::Smart;
100 }
101 ModelChoice::Smart
103 }
104}
105
106#[async_trait::async_trait]
107impl<S: LlmClient, F: LlmClient> LlmClient for ModelRouter<S, F> {
108 async fn structured_call(
109 &self,
110 messages: &[Message],
111 schema: &Value,
112 ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
113 match self.route_structured(messages, schema) {
114 ModelChoice::Smart => self.smart.structured_call(messages, schema).await,
115 ModelChoice::Fast => self.fast.structured_call(messages, schema).await,
116 }
117 }
118
119 async fn tools_call(
120 &self,
121 messages: &[Message],
122 tools: &[ToolDef],
123 ) -> Result<Vec<ToolCall>, SgrError> {
124 match self.route_tools(messages, tools) {
125 ModelChoice::Smart => self.smart.tools_call(messages, tools).await,
126 ModelChoice::Fast => self.fast.tools_call(messages, tools).await,
127 }
128 }
129
130 async fn complete(&self, messages: &[Message]) -> Result<String, SgrError> {
131 match self.route_messages(messages) {
132 ModelChoice::Smart => self.smart.complete(messages).await,
133 ModelChoice::Fast => self.fast.complete(messages).await,
134 }
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// Inert client shared by every test: each call succeeds with an empty
    /// result. The tests only inspect routing decisions, so none of these
    /// methods is ever invoked.
    struct DummyClient;

    #[async_trait::async_trait]
    impl LlmClient for DummyClient {
        async fn structured_call(
            &self,
            _: &[Message],
            _: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            Ok((None, vec![], String::new()))
        }

        async fn tools_call(
            &self,
            _: &[Message],
            _: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            Ok(vec![])
        }

        async fn complete(&self, _: &[Message]) -> Result<String, SgrError> {
            Ok(String::new())
        }
    }

    /// Router over two dummy clients with the default configuration.
    fn default_router() -> ModelRouter<DummyClient, DummyClient> {
        ModelRouter::new(DummyClient, DummyClient)
    }

    /// Builds a conversation of `count` identical user messages.
    fn user_messages(count: usize) -> Vec<Message> {
        (0..count).map(|_| Message::user("hi")).collect()
    }

    /// Builds `count` distinct tool definitions with empty parameter objects.
    fn tool_defs(count: usize) -> Vec<ToolDef> {
        (0..count)
            .map(|i| ToolDef {
                name: format!("tool_{}", i),
                description: "test".into(),
                parameters: serde_json::json!({}),
            })
            .collect()
    }

    #[test]
    fn router_config_defaults() {
        let config = RouterConfig::default();
        assert!(!config.always_smart);
        assert_eq!(config.message_threshold, 10);
        assert_eq!(config.tool_threshold, 8);
    }

    #[test]
    fn route_messages_logic() {
        let router = default_router();

        assert_eq!(router.route_messages(&user_messages(3)), ModelChoice::Fast);
        assert_eq!(router.route_messages(&user_messages(15)), ModelChoice::Smart);

        // The threshold is strict: exactly 10 messages still routes fast.
        assert_eq!(router.route_messages(&user_messages(10)), ModelChoice::Fast);
        assert_eq!(router.route_messages(&user_messages(11)), ModelChoice::Smart);
    }

    #[test]
    fn route_tools_logic() {
        let router = default_router();
        let msgs = user_messages(1);

        assert_eq!(router.route_tools(&msgs, &tool_defs(3)), ModelChoice::Fast);
        assert_eq!(router.route_tools(&msgs, &tool_defs(12)), ModelChoice::Smart);

        // The threshold is strict: exactly 8 tools still routes fast.
        assert_eq!(router.route_tools(&msgs, &tool_defs(8)), ModelChoice::Fast);
        assert_eq!(router.route_tools(&msgs, &tool_defs(9)), ModelChoice::Smart);

        // A long conversation escalates even when no tools are offered.
        assert_eq!(
            router.route_tools(&user_messages(11), &[]),
            ModelChoice::Smart
        );
    }

    #[test]
    fn always_smart_overrides() {
        let router = default_router().with_config(RouterConfig {
            always_smart: true,
            ..Default::default()
        });

        let msgs = user_messages(1);
        assert_eq!(router.route_messages(&msgs), ModelChoice::Smart);
        assert_eq!(router.route_tools(&msgs, &[]), ModelChoice::Smart);
        assert_eq!(
            router.route_structured(&msgs, &serde_json::json!({})),
            ModelChoice::Smart
        );
    }

    #[test]
    fn structured_defaults_to_smart() {
        let router = default_router();
        let msgs = user_messages(1);
        assert_eq!(
            router.route_structured(&msgs, &serde_json::json!({})),
            ModelChoice::Smart
        );
    }
}