Skip to main content

swink_agent/
model_presets.rs

1use std::sync::Arc;
2
3use crate::{ModelSpec, StreamFn};
4
/// List of extra connections: each entry pairs a model spec with its stream function.
type ExtraModelConnections = Vec<(ModelSpec, Arc<dyn StreamFn>)>;
6
7#[derive(Clone)]
8pub struct ModelConnection {
9    model: ModelSpec,
10    stream_fn: Arc<dyn StreamFn>,
11}
12
13impl ModelConnection {
14    #[must_use]
15    pub fn new(model: ModelSpec, stream_fn: Arc<dyn StreamFn>) -> Self {
16        Self { model, stream_fn }
17    }
18
19    #[must_use]
20    pub const fn model_spec(&self) -> &ModelSpec {
21        &self.model
22    }
23
24    #[must_use]
25    pub fn stream_fn(&self) -> Arc<dyn StreamFn> {
26        Arc::clone(&self.stream_fn)
27    }
28}
29
30pub struct ModelConnections {
31    primary_model: ModelSpec,
32    primary_stream_fn: Arc<dyn StreamFn>,
33    extra_models: ExtraModelConnections,
34}
35
36impl ModelConnections {
37    #[must_use]
38    pub fn new(primary: ModelConnection, extras: Vec<ModelConnection>) -> Self {
39        let ModelConnection {
40            model: primary_model,
41            stream_fn: primary_stream_fn,
42        } = primary;
43        let mut extra_models = Vec::new();
44
45        for connection in extras {
46            let model = connection.model.clone();
47            if model == primary_model || extra_models.iter().any(|(existing, _)| *existing == model)
48            {
49                continue;
50            }
51            extra_models.push((model, connection.stream_fn()));
52        }
53
54        Self {
55            primary_model,
56            primary_stream_fn,
57            extra_models,
58        }
59    }
60
61    #[must_use]
62    pub const fn primary_model(&self) -> &ModelSpec {
63        &self.primary_model
64    }
65
66    #[must_use]
67    pub fn primary_stream_fn(&self) -> Arc<dyn StreamFn> {
68        Arc::clone(&self.primary_stream_fn)
69    }
70
71    #[must_use]
72    pub fn extra_models(&self) -> &[(ModelSpec, Arc<dyn StreamFn>)] {
73        &self.extra_models
74    }
75
76    #[must_use]
77    pub fn into_parts(self) -> (ModelSpec, Arc<dyn StreamFn>, ExtraModelConnections) {
78        (
79            self.primary_model,
80            self.primary_stream_fn,
81            self.extra_models,
82        )
83    }
84
85    /// Create a builder for constructing `ModelConnections` incrementally.
86    #[must_use]
87    pub const fn builder() -> ModelConnectionsBuilder {
88        ModelConnectionsBuilder::new()
89    }
90}
91
92/// Incrementally builds a [`ModelConnections`] value.
93///
94/// # Panics
95///
96/// [`build`](Self::build) panics if no primary connection has been set.
97pub struct ModelConnectionsBuilder {
98    primary: Option<ModelConnection>,
99    fallbacks: Vec<ModelConnection>,
100}
101
102impl Default for ModelConnectionsBuilder {
103    fn default() -> Self {
104        Self::new()
105    }
106}
107
108impl ModelConnectionsBuilder {
109    /// Create a new empty builder.
110    #[must_use]
111    pub const fn new() -> Self {
112        Self {
113            primary: None,
114            fallbacks: Vec::new(),
115        }
116    }
117
118    /// Set the primary model connection.
119    #[must_use]
120    pub fn primary(mut self, connection: ModelConnection) -> Self {
121        self.primary = Some(connection);
122        self
123    }
124
125    /// Add a fallback model connection.
126    #[must_use]
127    pub fn fallback(mut self, connection: ModelConnection) -> Self {
128        self.fallbacks.push(connection);
129        self
130    }
131
132    /// Build the final [`ModelConnections`].
133    ///
134    /// # Panics
135    ///
136    /// Panics if no primary connection was set via [`primary`](Self::primary).
137    #[must_use]
138    pub fn build(self) -> ModelConnections {
139        let primary = self
140            .primary
141            .expect("ModelConnectionsBuilder: primary connection is required");
142        ModelConnections::new(primary, self.fallbacks)
143    }
144}
145
146#[cfg(test)]
147mod tests {
148    use std::pin::Pin;
149
150    use futures::Stream;
151    use tokio_util::sync::CancellationToken;
152
153    use super::*;
154    use crate::{AgentContext, AssistantMessageEvent, StreamOptions};
155
156    struct DummyStreamFn;
157
158    impl StreamFn for DummyStreamFn {
159        fn stream<'a>(
160            &'a self,
161            _model: &'a ModelSpec,
162            _context: &'a AgentContext,
163            _options: &'a StreamOptions,
164            _cancellation_token: CancellationToken,
165        ) -> Pin<Box<dyn Stream<Item = AssistantMessageEvent> + Send + 'a>> {
166            Box::pin(futures::stream::empty())
167        }
168    }
169
170    fn dummy_stream() -> Arc<dyn StreamFn> {
171        Arc::new(DummyStreamFn)
172    }
173
174    #[test]
175    fn into_parts_returns_correct_values() {
176        let primary_model = ModelSpec::new("anthropic", "claude-sonnet-4-6");
177        let extra_model = ModelSpec::new("openai", "gpt-5.2");
178
179        let connections = ModelConnections::new(
180            ModelConnection::new(primary_model.clone(), dummy_stream()),
181            vec![ModelConnection::new(extra_model.clone(), dummy_stream())],
182        );
183
184        let (model, _stream_fn, extras) = connections.into_parts();
185        assert_eq!(model, primary_model);
186        assert_eq!(extras.len(), 1);
187        assert_eq!(extras[0].0, extra_model);
188    }
189
190    #[test]
191    fn model_connection_getters() {
192        let model = ModelSpec::new("test", "test-model");
193        let stream = dummy_stream();
194        let conn = ModelConnection::new(model.clone(), Arc::clone(&stream));
195
196        assert_eq!(conn.model_spec(), &model);
197        // stream_fn() returns a clone of the Arc
198        let sf = conn.stream_fn();
199        assert!(Arc::ptr_eq(&sf, &stream));
200    }
201
202    #[test]
203    fn empty_extras() {
204        let connections = ModelConnections::new(
205            ModelConnection::new(
206                ModelSpec::new("anthropic", "claude-sonnet-4-6"),
207                dummy_stream(),
208            ),
209            vec![],
210        );
211
212        assert_eq!(connections.extra_models().len(), 0);
213        assert_eq!(
214            connections.primary_model(),
215            &ModelSpec::new("anthropic", "claude-sonnet-4-6")
216        );
217    }
218
219    #[test]
220    fn all_extras_are_duplicates_of_primary() {
221        let primary = ModelSpec::new("anthropic", "claude-sonnet-4-6");
222        let connections = ModelConnections::new(
223            ModelConnection::new(primary.clone(), dummy_stream()),
224            vec![
225                ModelConnection::new(primary.clone(), dummy_stream()),
226                ModelConnection::new(primary, dummy_stream()),
227            ],
228        );
229
230        // All extras match primary, so they should be filtered out
231        assert_eq!(connections.extra_models().len(), 0);
232    }
233
234    #[test]
235    fn model_connections_keep_primary_first_and_deduplicate_extras() {
236        let connections = ModelConnections::new(
237            ModelConnection::new(
238                ModelSpec::new("anthropic", "claude-sonnet-4-6"),
239                dummy_stream(),
240            ),
241            vec![
242                ModelConnection::new(
243                    ModelSpec::new("anthropic", "claude-sonnet-4-6"),
244                    dummy_stream(),
245                ),
246                ModelConnection::new(ModelSpec::new("openai", "gpt-5.2"), dummy_stream()),
247                ModelConnection::new(ModelSpec::new("openai", "gpt-5.2"), dummy_stream()),
248                ModelConnection::new(ModelSpec::new("local", "SmolLM3-3B-Q4_K_M"), dummy_stream()),
249            ],
250        );
251
252        assert_eq!(
253            connections.primary_model(),
254            &ModelSpec::new("anthropic", "claude-sonnet-4-6")
255        );
256        assert_eq!(connections.extra_models().len(), 2);
257        assert_eq!(
258            connections.extra_models()[0].0,
259            ModelSpec::new("openai", "gpt-5.2")
260        );
261        assert_eq!(
262            connections.extra_models()[1].0,
263            ModelSpec::new("local", "SmolLM3-3B-Q4_K_M")
264        );
265    }
266
267    #[test]
268    fn builder_primary_only() {
269        let connections = ModelConnections::builder()
270            .primary(ModelConnection::new(
271                ModelSpec::new("anthropic", "claude-sonnet-4-6"),
272                dummy_stream(),
273            ))
274            .build();
275
276        assert_eq!(
277            connections.primary_model(),
278            &ModelSpec::new("anthropic", "claude-sonnet-4-6")
279        );
280        assert_eq!(connections.extra_models().len(), 0);
281    }
282
283    #[test]
284    fn builder_with_fallbacks() {
285        let connections = ModelConnections::builder()
286            .primary(ModelConnection::new(
287                ModelSpec::new("anthropic", "claude-sonnet-4-6"),
288                dummy_stream(),
289            ))
290            .fallback(ModelConnection::new(
291                ModelSpec::new("openai", "gpt-5.2"),
292                dummy_stream(),
293            ))
294            .fallback(ModelConnection::new(
295                ModelSpec::new("local", "SmolLM3-3B-Q4_K_M"),
296                dummy_stream(),
297            ))
298            .build();
299
300        assert_eq!(connections.extra_models().len(), 2);
301        assert_eq!(
302            connections.extra_models()[0].0,
303            ModelSpec::new("openai", "gpt-5.2")
304        );
305    }
306
307    #[test]
308    #[should_panic(expected = "primary connection is required")]
309    fn builder_panics_without_primary() {
310        let _ = ModelConnections::builder().build();
311    }
312}