Skip to main content

sgr_agent/
router.rs

//! Model router — routes each request to a smart or a fast model based on complexity.
//!
//! Wraps two `LlmClient` instances and selects which to use per call:
//! - **Smart model** (e.g. gemini-3.1-pro): for complex reasoning, many tools, long context
//! - **Fast model** (e.g. gemini-3.1-flash): for simple tool calls, short context
//!
//! Selection heuristics:
//! - `always_smart` set in the config → smart
//! - Message count > threshold → smart
//! - Tool count > threshold (tool calls only) → smart
//! - Structured (schema-constrained) calls → always smart; the schema itself is
//!   currently not inspected
//! - Otherwise → fast
12
13use crate::client::LlmClient;
14use crate::tool::ToolDef;
15use crate::types::{Message, SgrError, ToolCall};
16use serde_json::Value;
17
/// Tunable thresholds that decide when [`ModelRouter`] escalates to the smart model.
///
/// Both thresholds are exclusive: a request is escalated only when its count is
/// strictly greater than the configured value.
///
/// Derives `PartialEq`/`Eq` so that configurations can be compared directly,
/// e.g. in tests or when checking whether a configuration actually changed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RouterConfig {
    /// Conversations holding more messages than this are sent to the smart model.
    pub message_threshold: usize,
    /// Tool calls offering more tools than this are sent to the smart model.
    pub tool_threshold: usize,
    /// When `true`, every request goes to the smart model and the thresholds are ignored.
    pub always_smart: bool,
}
28
29impl Default for RouterConfig {
30    fn default() -> Self {
31        Self {
32            message_threshold: 10,
33            tool_threshold: 8,
34            always_smart: false,
35        }
36    }
37}
38
/// Outcome of a routing decision: which of the two wrapped clients serves a request.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ModelChoice {
    /// Send the request to the smart client.
    Smart,
    /// Send the request to the fast client.
    Fast,
}
45
46/// Dual-model router that selects smart or fast model per request.
47pub struct ModelRouter<S: LlmClient, F: LlmClient> {
48    smart: S,
49    fast: F,
50    config: RouterConfig,
51}
52
53impl<S: LlmClient, F: LlmClient> ModelRouter<S, F> {
54    pub fn new(smart: S, fast: F) -> Self {
55        Self {
56            smart,
57            fast,
58            config: RouterConfig::default(),
59        }
60    }
61
62    pub fn with_config(mut self, config: RouterConfig) -> Self {
63        self.config = config;
64        self
65    }
66
67    /// Decide which model to use based on request characteristics.
68    pub fn route_messages(&self, messages: &[Message]) -> ModelChoice {
69        if self.config.always_smart {
70            return ModelChoice::Smart;
71        }
72        if messages.len() > self.config.message_threshold {
73            return ModelChoice::Smart;
74        }
75        ModelChoice::Fast
76    }
77
78    /// Decide which model for tool calls based on tool count.
79    pub fn route_tools(&self, messages: &[Message], tools: &[ToolDef]) -> ModelChoice {
80        if self.config.always_smart {
81            return ModelChoice::Smart;
82        }
83        if messages.len() > self.config.message_threshold {
84            return ModelChoice::Smart;
85        }
86        if tools.len() > self.config.tool_threshold {
87            return ModelChoice::Smart;
88        }
89        ModelChoice::Fast
90    }
91
92    /// Decide which model for structured calls.
93    pub fn route_structured(&self, messages: &[Message], _schema: &Value) -> ModelChoice {
94        if self.config.always_smart {
95            return ModelChoice::Smart;
96        }
97        // Structured output with many messages → smart
98        if messages.len() > self.config.message_threshold {
99            return ModelChoice::Smart;
100        }
101        // Structured calls are generally harder → smart for safety
102        ModelChoice::Smart
103    }
104}
105
106#[async_trait::async_trait]
107impl<S: LlmClient, F: LlmClient> LlmClient for ModelRouter<S, F> {
108    async fn structured_call(
109        &self,
110        messages: &[Message],
111        schema: &Value,
112    ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
113        match self.route_structured(messages, schema) {
114            ModelChoice::Smart => self.smart.structured_call(messages, schema).await,
115            ModelChoice::Fast => self.fast.structured_call(messages, schema).await,
116        }
117    }
118
119    async fn tools_call(
120        &self,
121        messages: &[Message],
122        tools: &[ToolDef],
123    ) -> Result<Vec<ToolCall>, SgrError> {
124        match self.route_tools(messages, tools) {
125            ModelChoice::Smart => self.smart.tools_call(messages, tools).await,
126            ModelChoice::Fast => self.fast.tools_call(messages, tools).await,
127        }
128    }
129
130    async fn complete(&self, messages: &[Message]) -> Result<String, SgrError> {
131        match self.route_messages(messages) {
132            ModelChoice::Smart => self.smart.complete(messages).await,
133            ModelChoice::Fast => self.fast.complete(messages).await,
134        }
135    }
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// Inert client shared by every test: the routing tests only inspect the
    /// returned `ModelChoice` and never invoke the model methods.
    struct DummyClient;

    #[async_trait::async_trait]
    impl LlmClient for DummyClient {
        async fn structured_call(
            &self,
            _: &[Message],
            _: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            Ok((None, vec![], String::new()))
        }

        async fn tools_call(
            &self,
            _: &[Message],
            _: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            Ok(vec![])
        }

        async fn complete(&self, _: &[Message]) -> Result<String, SgrError> {
            Ok(String::new())
        }
    }

    /// Router over two dummy clients using the default configuration.
    fn default_router() -> ModelRouter<DummyClient, DummyClient> {
        ModelRouter::new(DummyClient, DummyClient)
    }

    /// Conversation made of `len` identical user messages.
    fn conversation(len: usize) -> Vec<Message> {
        (0..len).map(|_| Message::user("hi")).collect()
    }

    /// `count` placeholder tool definitions with empty parameter schemas.
    fn toolbox(count: usize) -> Vec<ToolDef> {
        (0..count)
            .map(|i| ToolDef {
                name: format!("tool_{}", i),
                description: "test".into(),
                parameters: serde_json::json!({}),
            })
            .collect()
    }

    #[test]
    fn router_config_defaults() {
        let config = RouterConfig::default();
        assert!(!config.always_smart);
        assert_eq!(config.message_threshold, 10);
        assert_eq!(config.tool_threshold, 8);
    }

    #[test]
    fn route_messages_logic() {
        let router = default_router();
        let limit = RouterConfig::default().message_threshold;

        // Short conversation → fast
        assert_eq!(router.route_messages(&conversation(3)), ModelChoice::Fast);

        // Long conversation → smart
        assert_eq!(router.route_messages(&conversation(15)), ModelChoice::Smart);

        // The threshold is exclusive: exactly `limit` messages still routes fast.
        assert_eq!(
            router.route_messages(&conversation(limit)),
            ModelChoice::Fast
        );
        assert_eq!(
            router.route_messages(&conversation(limit + 1)),
            ModelChoice::Smart
        );
    }

    #[test]
    fn route_tools_logic() {
        let router = default_router();
        let msgs = conversation(1);
        let limit = RouterConfig::default().tool_threshold;

        // Few tools → fast
        assert_eq!(router.route_tools(&msgs, &toolbox(3)), ModelChoice::Fast);

        // Many tools → smart
        assert_eq!(router.route_tools(&msgs, &toolbox(12)), ModelChoice::Smart);

        // The threshold is exclusive: exactly `limit` tools still routes fast.
        assert_eq!(router.route_tools(&msgs, &toolbox(limit)), ModelChoice::Fast);
        assert_eq!(
            router.route_tools(&msgs, &toolbox(limit + 1)),
            ModelChoice::Smart
        );

        // A long conversation escalates even when only a few tools are offered.
        assert_eq!(
            router.route_tools(&conversation(15), &toolbox(3)),
            ModelChoice::Smart
        );
    }

    #[test]
    fn always_smart_overrides() {
        let router = default_router().with_config(RouterConfig {
            always_smart: true,
            ..Default::default()
        });

        let msgs = conversation(1);
        assert_eq!(router.route_messages(&msgs), ModelChoice::Smart);
        assert_eq!(router.route_tools(&msgs, &[]), ModelChoice::Smart);
        assert_eq!(
            router.route_structured(&msgs, &serde_json::json!({})),
            ModelChoice::Smart
        );
    }

    #[test]
    fn structured_defaults_to_smart() {
        let router = default_router();
        let msgs = conversation(1);
        // Structured calls always prefer smart
        assert_eq!(
            router.route_structured(&msgs, &serde_json::json!({})),
            ModelChoice::Smart
        );
    }
}