1use std::sync::Arc;
2
3use crate::{ModelSpec, StreamFn};
4
5type ExtraModelConnections = Vec<(ModelSpec, Arc<dyn StreamFn>)>;
6
7#[derive(Clone)]
8pub struct ModelConnection {
9 model: ModelSpec,
10 stream_fn: Arc<dyn StreamFn>,
11}
12
13impl ModelConnection {
14 #[must_use]
15 pub fn new(model: ModelSpec, stream_fn: Arc<dyn StreamFn>) -> Self {
16 Self { model, stream_fn }
17 }
18
19 #[must_use]
20 pub const fn model_spec(&self) -> &ModelSpec {
21 &self.model
22 }
23
24 #[must_use]
25 pub fn stream_fn(&self) -> Arc<dyn StreamFn> {
26 Arc::clone(&self.stream_fn)
27 }
28}
29
30pub struct ModelConnections {
31 primary_model: ModelSpec,
32 primary_stream_fn: Arc<dyn StreamFn>,
33 extra_models: ExtraModelConnections,
34}
35
36impl ModelConnections {
37 #[must_use]
38 pub fn new(primary: ModelConnection, extras: Vec<ModelConnection>) -> Self {
39 let ModelConnection {
40 model: primary_model,
41 stream_fn: primary_stream_fn,
42 } = primary;
43 let mut extra_models = Vec::new();
44
45 for connection in extras {
46 let model = connection.model.clone();
47 if model == primary_model || extra_models.iter().any(|(existing, _)| *existing == model)
48 {
49 continue;
50 }
51 extra_models.push((model, connection.stream_fn()));
52 }
53
54 Self {
55 primary_model,
56 primary_stream_fn,
57 extra_models,
58 }
59 }
60
61 #[must_use]
62 pub const fn primary_model(&self) -> &ModelSpec {
63 &self.primary_model
64 }
65
66 #[must_use]
67 pub fn primary_stream_fn(&self) -> Arc<dyn StreamFn> {
68 Arc::clone(&self.primary_stream_fn)
69 }
70
71 #[must_use]
72 pub fn extra_models(&self) -> &[(ModelSpec, Arc<dyn StreamFn>)] {
73 &self.extra_models
74 }
75
76 #[must_use]
77 pub fn into_parts(self) -> (ModelSpec, Arc<dyn StreamFn>, ExtraModelConnections) {
78 (
79 self.primary_model,
80 self.primary_stream_fn,
81 self.extra_models,
82 )
83 }
84
85 #[must_use]
87 pub const fn builder() -> ModelConnectionsBuilder {
88 ModelConnectionsBuilder::new()
89 }
90}
91
92pub struct ModelConnectionsBuilder {
98 primary: Option<ModelConnection>,
99 fallbacks: Vec<ModelConnection>,
100}
101
102impl Default for ModelConnectionsBuilder {
103 fn default() -> Self {
104 Self::new()
105 }
106}
107
108impl ModelConnectionsBuilder {
109 #[must_use]
111 pub const fn new() -> Self {
112 Self {
113 primary: None,
114 fallbacks: Vec::new(),
115 }
116 }
117
118 #[must_use]
120 pub fn primary(mut self, connection: ModelConnection) -> Self {
121 self.primary = Some(connection);
122 self
123 }
124
125 #[must_use]
127 pub fn fallback(mut self, connection: ModelConnection) -> Self {
128 self.fallbacks.push(connection);
129 self
130 }
131
132 #[must_use]
138 pub fn build(self) -> ModelConnections {
139 let primary = self
140 .primary
141 .expect("ModelConnectionsBuilder: primary connection is required");
142 ModelConnections::new(primary, self.fallbacks)
143 }
144}
145
#[cfg(test)]
mod tests {
    use std::pin::Pin;

    use futures::Stream;
    use tokio_util::sync::CancellationToken;

    use super::*;
    use crate::{AgentContext, AssistantMessageEvent, StreamOptions};

    /// Stream function stub whose returned stream yields no events.
    struct DummyStreamFn;

    impl StreamFn for DummyStreamFn {
        fn stream<'a>(
            &'a self,
            _model: &'a ModelSpec,
            _context: &'a AgentContext,
            _options: &'a StreamOptions,
            _cancellation_token: CancellationToken,
        ) -> Pin<Box<dyn Stream<Item = AssistantMessageEvent> + Send + 'a>> {
            Box::pin(futures::stream::empty())
        }
    }

    /// Fresh, independently allocated stub stream function.
    fn dummy_stream() -> Arc<dyn StreamFn> {
        Arc::new(DummyStreamFn)
    }

    /// Spec used as the primary model throughout these tests.
    fn claude() -> ModelSpec {
        ModelSpec::new("anthropic", "claude-sonnet-4-6")
    }

    /// Spec used as the first extra model.
    fn gpt() -> ModelSpec {
        ModelSpec::new("openai", "gpt-5.2")
    }

    /// Spec used as the second extra model.
    fn smol() -> ModelSpec {
        ModelSpec::new("local", "SmolLM3-3B-Q4_K_M")
    }

    /// Wraps `model` in a connection backed by a fresh stub stream function.
    fn connect(model: ModelSpec) -> ModelConnection {
        ModelConnection::new(model, dummy_stream())
    }

    #[test]
    fn into_parts_returns_correct_values() {
        let primary_model = claude();
        let extra_model = gpt();

        let primary = connect(primary_model.clone());
        let extras = vec![connect(extra_model.clone())];
        let (model, _stream_fn, extras) = ModelConnections::new(primary, extras).into_parts();

        assert_eq!(model, primary_model);
        assert_eq!(extras.len(), 1);
        assert_eq!(extras[0].0, extra_model);
    }

    #[test]
    fn model_connection_getters() {
        let model = ModelSpec::new("test", "test-model");
        let stream = dummy_stream();
        let conn = ModelConnection::new(model.clone(), Arc::clone(&stream));

        assert_eq!(conn.model_spec(), &model);
        // The getter must hand back the very same allocation it was given.
        assert!(Arc::ptr_eq(&conn.stream_fn(), &stream));
    }

    #[test]
    fn empty_extras() {
        let connections = ModelConnections::new(connect(claude()), Vec::new());

        assert!(connections.extra_models().is_empty());
        assert_eq!(connections.primary_model(), &claude());
    }

    #[test]
    fn all_extras_are_duplicates_of_primary() {
        let extras = vec![connect(claude()), connect(claude())];
        let connections = ModelConnections::new(connect(claude()), extras);

        assert!(connections.extra_models().is_empty());
    }

    #[test]
    fn model_connections_keep_primary_first_and_deduplicate_extras() {
        let extras = vec![
            connect(claude()),
            connect(gpt()),
            connect(gpt()),
            connect(smol()),
        ];
        let connections = ModelConnections::new(connect(claude()), extras);

        assert_eq!(connections.primary_model(), &claude());

        let kept = connections.extra_models();
        assert_eq!(kept.len(), 2);
        assert_eq!(kept[0].0, gpt());
        assert_eq!(kept[1].0, smol());
    }

    #[test]
    fn builder_primary_only() {
        let connections = ModelConnections::builder()
            .primary(connect(claude()))
            .build();

        assert_eq!(connections.primary_model(), &claude());
        assert!(connections.extra_models().is_empty());
    }

    #[test]
    fn builder_with_fallbacks() {
        let connections = ModelConnections::builder()
            .primary(connect(claude()))
            .fallback(connect(gpt()))
            .fallback(connect(smol()))
            .build();

        let kept = connections.extra_models();
        assert_eq!(kept.len(), 2);
        assert_eq!(kept[0].0, gpt());
    }

    #[test]
    #[should_panic(expected = "primary connection is required")]
    fn builder_panics_without_primary() {
        let _ = ModelConnections::builder().build();
    }
}