Skip to main content

rustic_ai/
model_config.rs

1use std::collections::{HashMap, HashSet};
2
3use serde::Deserialize;
4use serde_json::{Map, Value};
5
6#[derive(Clone, Debug, PartialEq)]
7pub struct ResolvedModelConfig {
8    pub primary: String,
9    pub backup: Option<String>,
10    pub retry_limit: u32,
11    pub failover_on: HashSet<String>,
12    pub settings: Map<String, Value>,
13}
14
15#[derive(Clone, Debug, PartialEq, Deserialize)]
16pub struct CircuitBreakerConfig {
17    pub failure_threshold: u32,
18    pub recovery_timeout: u32,
19    pub window: u32,
20    pub trigger_on: Vec<String>,
21}
22
23impl Default for CircuitBreakerConfig {
24    fn default() -> Self {
25        Self {
26            failure_threshold: 3,
27            recovery_timeout: 60,
28            window: 300,
29            trigger_on: vec![
30                "http_403".to_string(),
31                "http_401".to_string(),
32                "connect_error".to_string(),
33                "http_5xx".to_string(),
34            ],
35        }
36    }
37}
38
39pub trait ModelConfigResolver: Send + Sync {
40    fn resolve_model_config(
41        &self,
42        agent_name: &str,
43        requested_model: Option<&str>,
44        environment: Option<&str>,
45    ) -> ResolvedModelConfig;
46
47    fn resolve_utility_config(
48        &self,
49        utility_name: &str,
50        environment: Option<&str>,
51    ) -> ResolvedModelConfig;
52
53    fn circuit_breaker_config(&self, _environment: Option<&str>) -> CircuitBreakerConfig {
54        CircuitBreakerConfig::default()
55    }
56}
57
58#[derive(Clone, Debug, Default, PartialEq)]
59pub struct ModelConfigEntry {
60    pub primary: Option<String>,
61    pub backup: Option<String>,
62    pub retry_limit: Option<u32>,
63    pub failover_on: Option<HashSet<String>>,
64    pub settings: Option<Map<String, Value>>,
65}
66
67impl ModelConfigEntry {
68    pub fn new(primary: impl Into<String>) -> Self {
69        Self {
70            primary: Some(primary.into()),
71            ..Default::default()
72        }
73    }
74
75    pub fn backup(mut self, backup: impl Into<String>) -> Self {
76        self.backup = Some(backup.into());
77        self
78    }
79
80    pub fn retry_limit(mut self, retry_limit: u32) -> Self {
81        self.retry_limit = Some(retry_limit);
82        self
83    }
84
85    pub fn failover_on<I, S>(mut self, values: I) -> Self
86    where
87        I: IntoIterator<Item = S>,
88        S: Into<String>,
89    {
90        let set = values.into_iter().map(Into::into).collect::<HashSet<_>>();
91        self.failover_on = Some(set);
92        self
93    }
94
95    pub fn setting(mut self, key: impl Into<String>, value: Value) -> Self {
96        self.settings
97            .get_or_insert_with(Map::new)
98            .insert(key.into(), value);
99        self
100    }
101}
102
103#[derive(Clone, Debug)]
104pub struct InMemoryResolver {
105    defaults: ModelConfigEntry,
106    agents: HashMap<String, ModelConfigEntry>,
107    environments: HashMap<String, ModelConfigEntry>,
108    utilities: HashMap<String, ModelConfigEntry>,
109    circuit_breaker: CircuitBreakerConfig,
110    fallback_model: Option<String>,
111}
112
113impl InMemoryResolver {
114    pub fn new(primary: impl Into<String>) -> Self {
115        Self {
116            defaults: ModelConfigEntry::new(primary),
117            agents: HashMap::new(),
118            environments: HashMap::new(),
119            utilities: HashMap::new(),
120            circuit_breaker: CircuitBreakerConfig::default(),
121            fallback_model: None,
122        }
123    }
124
125    pub fn with_defaults(mut self, defaults: ModelConfigEntry) -> Self {
126        self.defaults = defaults;
127        self
128    }
129
130    pub fn with_fallback_model(mut self, model: impl Into<String>) -> Self {
131        self.fallback_model = Some(model.into());
132        self
133    }
134
135    pub fn insert_agent(&mut self, name: impl Into<String>, entry: ModelConfigEntry) {
136        self.agents.insert(name.into(), entry);
137    }
138
139    pub fn insert_environment(&mut self, name: impl Into<String>, entry: ModelConfigEntry) {
140        self.environments.insert(name.into().to_lowercase(), entry);
141    }
142
143    pub fn insert_utility(&mut self, name: impl Into<String>, entry: ModelConfigEntry) {
144        self.utilities.insert(name.into(), entry);
145    }
146
147    pub fn set_circuit_breaker(&mut self, config: CircuitBreakerConfig) {
148        self.circuit_breaker = config;
149    }
150
151    fn resolve_entries(
152        &self,
153        name: &str,
154        environment: Option<&str>,
155        map: &HashMap<String, ModelConfigEntry>,
156    ) -> ModelConfigEntry {
157        let env_key = environment.map(|env| env.to_lowercase());
158        let env_entry = env_key
159            .as_ref()
160            .and_then(|key| self.environments.get(key))
161            .cloned()
162            .unwrap_or_default();
163        let name_entry = map.get(name).cloned().unwrap_or_default();
164        merge_entries(&[self.defaults.clone(), env_entry, name_entry])
165    }
166
167    fn build_resolved(
168        &self,
169        merged: ModelConfigEntry,
170        requested_model: Option<&str>,
171    ) -> ResolvedModelConfig {
172        let primary = requested_model
173            .map(str::to_string)
174            .or(merged.primary)
175            .or_else(|| self.fallback_model.clone())
176            .unwrap_or_default();
177        let retry_limit = merged.retry_limit.unwrap_or(0);
178        let failover_on = merged.failover_on.unwrap_or_default();
179        let settings = merged.settings.unwrap_or_default();
180
181        ResolvedModelConfig {
182            primary,
183            backup: merged.backup,
184            retry_limit,
185            failover_on,
186            settings,
187        }
188    }
189}
190
191impl ModelConfigResolver for InMemoryResolver {
192    fn resolve_model_config(
193        &self,
194        agent_name: &str,
195        requested_model: Option<&str>,
196        environment: Option<&str>,
197    ) -> ResolvedModelConfig {
198        let merged = self.resolve_entries(agent_name, environment, &self.agents);
199        self.build_resolved(merged, requested_model)
200    }
201
202    fn resolve_utility_config(
203        &self,
204        utility_name: &str,
205        environment: Option<&str>,
206    ) -> ResolvedModelConfig {
207        let merged = self.resolve_entries(utility_name, environment, &self.utilities);
208        self.build_resolved(merged, None)
209    }
210
211    fn circuit_breaker_config(&self, _environment: Option<&str>) -> CircuitBreakerConfig {
212        self.circuit_breaker.clone()
213    }
214}
215
216fn merge_entries(entries: &[ModelConfigEntry]) -> ModelConfigEntry {
217    let mut merged = ModelConfigEntry::default();
218    for entry in entries {
219        if let Some(primary) = &entry.primary {
220            merged.primary = Some(primary.clone());
221        }
222        if let Some(backup) = &entry.backup {
223            merged.backup = Some(backup.clone());
224        }
225        if let Some(retry_limit) = entry.retry_limit {
226            merged.retry_limit = Some(retry_limit);
227        }
228        if let Some(failover_on) = &entry.failover_on {
229            merged.failover_on = Some(failover_on.clone());
230        }
231        if let Some(settings) = &entry.settings {
232            let merged_settings = merged.settings.get_or_insert_with(Map::new);
233            for (key, value) in settings {
234                merged_settings.insert(key.clone(), value.clone());
235            }
236        }
237    }
238    merged
239}