swink_agent/fallback.rs
1//! Model fallback configuration.
2//!
3//! [`ModelFallback`] defines an ordered list of fallback models to try when
4//! the primary model exhausts its retry budget. Each entry pairs a
5//! [`ModelSpec`] with its corresponding [`StreamFn`], allowing fallback
6//! across providers.
7
8use std::sync::Arc;
9
10use crate::stream::StreamFn;
11use crate::types::ModelSpec;
12
13/// An ordered sequence of fallback models to attempt when the primary model
14/// (and its retries) are exhausted.
15///
16/// The agent tries each model in order, applying the configured
17/// [`RetryStrategy`](crate::RetryStrategy) independently for each model.
18/// When all fallback models are also exhausted the error propagates normally.
19///
20/// # Example
21///
22/// ```rust,no_run
23/// use swink_agent::{ModelFallback, ModelSpec};
24/// # use std::sync::Arc;
25/// # fn make_stream_fn() -> Arc<dyn swink_agent::StreamFn> { todo!() }
26///
27/// let fallback = ModelFallback::new(vec![
28/// (ModelSpec::new("openai", "gpt-4o-mini"), make_stream_fn()),
29/// (ModelSpec::new("anthropic", "claude-3-haiku-20240307"), make_stream_fn()),
30/// ]);
31/// ```
32#[derive(Clone)]
33pub struct ModelFallback {
34 models: Vec<(ModelSpec, Arc<dyn StreamFn>)>,
35}
36
37impl ModelFallback {
38 /// Create a new fallback chain from an ordered list of model/stream pairs.
39 #[must_use]
40 pub fn new(models: Vec<(ModelSpec, Arc<dyn StreamFn>)>) -> Self {
41 Self { models }
42 }
43
44 /// Returns the fallback models in order.
45 #[must_use]
46 pub fn models(&self) -> &[(ModelSpec, Arc<dyn StreamFn>)] {
47 &self.models
48 }
49
50 /// Returns `true` if the fallback chain is empty.
51 #[must_use]
52 pub fn is_empty(&self) -> bool {
53 self.models.is_empty()
54 }
55
56 /// Returns the number of fallback models.
57 #[must_use]
58 pub fn len(&self) -> usize {
59 self.models.len()
60 }
61}
62
63impl std::fmt::Debug for ModelFallback {
64 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65 f.debug_struct("ModelFallback")
66 .field(
67 "models",
68 &self
69 .models
70 .iter()
71 .map(|(m, _)| format!("{}:{}", m.provider, m.model_id))
72 .collect::<Vec<_>>(),
73 )
74 .finish()
75 }
76}
77
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty chain reports zero length and exposes no entries.
    #[test]
    fn empty_fallback() {
        let fb = ModelFallback::new(vec![]);
        assert!(fb.is_empty());
        assert_eq!(fb.len(), 0);
        assert!(fb.models().is_empty());
    }

    /// Pins the exact `Debug` rendering rather than only checking the type name,
    /// so accidental changes to the format are caught.
    #[test]
    fn debug_format() {
        let fb = ModelFallback::new(vec![]);
        assert_eq!(format!("{fb:?}"), "ModelFallback { models: [] }");
    }

    /// Cloning an empty chain yields another empty chain.
    #[test]
    fn clone_preserves_empty_chain() {
        let fb = ModelFallback::new(vec![]);
        let copy = fb.clone();
        assert!(copy.is_empty());
        assert_eq!(copy.len(), fb.len());
    }
}