Skip to main content

systemprompt_models/services/
ai.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
/// Serde default helper for boolean fields that are on unless the config
/// explicitly turns them off.
const fn default_true() -> bool {
    true
}
7
8#[derive(Debug, Clone, Default, Serialize, Deserialize)]
9pub struct AiConfig {
10    #[serde(default)]
11    pub default_provider: String,
12
13    #[serde(default)]
14    pub default_max_output_tokens: Option<u32>,
15
16    #[serde(default)]
17    pub sampling: SamplingConfig,
18
19    #[serde(default)]
20    pub providers: HashMap<String, AiProviderConfig>,
21
22    #[serde(default)]
23    pub tool_models: HashMap<String, ToolModelSettings>,
24
25    #[serde(default)]
26    pub mcp: McpConfig,
27
28    #[serde(default)]
29    pub history: HistoryConfig,
30}
31
32#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
33pub struct SamplingConfig {
34    #[serde(default)]
35    pub enable_smart_routing: bool,
36
37    #[serde(default)]
38    pub fallback_enabled: bool,
39}
40
41#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
42pub struct McpConfig {
43    #[serde(default)]
44    pub auto_discover: bool,
45
46    /// Resilience policy applied to outbound MCP tool RPCs (timeouts, retry,
47    /// circuit breaker, bulkhead).
48    #[serde(default = "default_mcp_resilience")]
49    pub resilience: ResilienceSettings,
50}
51
52impl Default for McpConfig {
53    fn default() -> Self {
54        Self {
55            auto_discover: false,
56            resilience: default_mcp_resilience(),
57        }
58    }
59}
60
61/// MCP defaults: tool RPCs are bounded at 30s rather than the 60s AI default.
62fn default_mcp_resilience() -> ResilienceSettings {
63    ResilienceSettings {
64        request_timeout_ms: 30_000,
65        connect_timeout_ms: 5_000,
66        ..ResilienceSettings::default()
67    }
68}
69
70/// Per-dependency resilience policy: timeouts, retry, circuit breaker,
71/// bulkhead.
72///
73/// Plain serde data loaded from profile config (all values in milliseconds or
74/// counts). Translated into the runtime form consumed by the resilience
75/// primitives in `systemprompt-database`.
76#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
77pub struct ResilienceSettings {
78    /// Timeout for a single (non-streaming) attempt.
79    #[serde(default = "default_request_timeout")]
80    pub request_timeout_ms: u64,
81
82    /// Connection-establishment timeout.
83    #[serde(default = "default_resilience_connect_timeout")]
84    pub connect_timeout_ms: u64,
85
86    /// Maximum gap between two chunks of a streaming response.
87    #[serde(default = "default_stream_idle_timeout")]
88    pub stream_idle_timeout_ms: u64,
89
90    /// Maximum attempts including the first. `1` disables retries.
91    #[serde(default = "default_retry_attempts")]
92    pub retry_attempts: u32,
93
94    /// Backoff before the first retry; doubles each subsequent attempt.
95    #[serde(default = "default_retry_base_delay")]
96    pub retry_base_delay_ms: u64,
97
98    /// Upper bound on a single backoff delay.
99    #[serde(default = "default_retry_max_delay")]
100    pub retry_max_delay_ms: u64,
101
102    /// Consecutive failures that trip the circuit breaker open.
103    #[serde(default = "default_breaker_threshold")]
104    pub breaker_failure_threshold: u32,
105
106    /// How long the breaker stays open before allowing a half-open probe.
107    #[serde(default = "default_breaker_cooldown")]
108    pub breaker_open_cooldown_ms: u64,
109
110    /// Concurrent probes admitted while the breaker is half-open.
111    #[serde(default = "default_half_open_probes")]
112    pub breaker_half_open_probes: u32,
113
114    /// Maximum in-flight calls to the dependency; further calls fast-fail.
115    #[serde(default = "default_max_concurrent")]
116    pub max_concurrent: usize,
117}
118
119impl Default for ResilienceSettings {
120    fn default() -> Self {
121        Self {
122            request_timeout_ms: default_request_timeout(),
123            connect_timeout_ms: default_resilience_connect_timeout(),
124            stream_idle_timeout_ms: default_stream_idle_timeout(),
125            retry_attempts: default_retry_attempts(),
126            retry_base_delay_ms: default_retry_base_delay(),
127            retry_max_delay_ms: default_retry_max_delay(),
128            breaker_failure_threshold: default_breaker_threshold(),
129            breaker_open_cooldown_ms: default_breaker_cooldown(),
130            breaker_half_open_probes: default_half_open_probes(),
131            max_concurrent: default_max_concurrent(),
132        }
133    }
134}
135
/// Fallback for `ResilienceSettings::request_timeout_ms`: one minute, in ms.
const fn default_request_timeout() -> u64 {
    60 * 1_000
}

/// Fallback for `ResilienceSettings::connect_timeout_ms`: ten seconds, in ms.
const fn default_resilience_connect_timeout() -> u64 {
    10 * 1_000
}

/// Fallback for `ResilienceSettings::stream_idle_timeout_ms`: one minute, in ms.
const fn default_stream_idle_timeout() -> u64 {
    60 * 1_000
}
147
/// Fallback for `ResilienceSettings::retry_attempts`: three attempts in total,
/// i.e. the first try plus up to two retries.
const fn default_retry_attempts() -> u32 {
    3
}

/// Fallback for `ResilienceSettings::retry_base_delay_ms`: 200 ms.
const fn default_retry_base_delay() -> u64 {
    200
}

/// Fallback for `ResilienceSettings::retry_max_delay_ms`: ten seconds, in ms.
const fn default_retry_max_delay() -> u64 {
    10 * 1_000
}
159
/// Fallback for `ResilienceSettings::breaker_failure_threshold`: five failures.
const fn default_breaker_threshold() -> u32 {
    5
}

/// Fallback for `ResilienceSettings::breaker_open_cooldown_ms`: thirty seconds, in ms.
const fn default_breaker_cooldown() -> u64 {
    30 * 1_000
}

/// Fallback for `ResilienceSettings::breaker_half_open_probes`: a single probe.
const fn default_half_open_probes() -> u32 {
    1
}
171
/// Fallback for `ResilienceSettings::max_concurrent`: sixteen in-flight calls.
const fn default_max_concurrent() -> usize {
    16
}
175
176#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
177pub struct HistoryConfig {
178    #[serde(default = "default_retention_days")]
179    pub retention_days: u32,
180
181    #[serde(default)]
182    pub log_tool_executions: bool,
183}
184
185impl Default for HistoryConfig {
186    fn default() -> Self {
187        Self {
188            retention_days: default_retention_days(),
189            log_tool_executions: false,
190        }
191    }
192}
193
/// Fallback for `HistoryConfig::retention_days`: thirty days.
const fn default_retention_days() -> u32 {
    const THIRTY_DAYS: u32 = 30;
    THIRTY_DAYS
}
197
198#[derive(Debug, Clone, Default, Serialize, Deserialize)]
199pub struct ToolModelSettings {
200    pub model: String,
201
202    #[serde(default)]
203    pub max_output_tokens: Option<u32>,
204}
205
206#[allow(clippy::struct_excessive_bools)]
207#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
208pub struct ModelCapabilities {
209    #[serde(default)]
210    pub vision: bool,
211
212    #[serde(default)]
213    pub audio_input: bool,
214
215    #[serde(default)]
216    pub video_input: bool,
217
218    #[serde(default)]
219    pub image_generation: bool,
220
221    #[serde(default)]
222    pub audio_generation: bool,
223
224    #[serde(default)]
225    pub streaming: bool,
226
227    #[serde(default)]
228    pub tools: bool,
229
230    #[serde(default)]
231    pub structured_output: bool,
232
233    #[serde(default)]
234    pub system_prompts: bool,
235
236    #[serde(default)]
237    pub image_resolution_config: bool,
238}
239
240#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
241pub struct ModelLimits {
242    #[serde(default)]
243    pub context_window: u32,
244
245    #[serde(default)]
246    pub max_output_tokens: u32,
247}
248
249#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
250pub struct ModelPricing {
251    #[serde(default)]
252    pub input_per_million: f64,
253
254    #[serde(default)]
255    pub output_per_million: f64,
256
257    #[serde(default)]
258    pub per_image_cents: Option<f64>,
259}
260
261#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
262pub struct ModelDefinition {
263    #[serde(default)]
264    pub capabilities: ModelCapabilities,
265
266    #[serde(default)]
267    pub limits: ModelLimits,
268
269    #[serde(default)]
270    pub pricing: ModelPricing,
271}
272
273#[derive(Debug, Clone, Serialize, Deserialize)]
274pub struct AiProviderConfig {
275    #[serde(default = "default_true")]
276    pub enabled: bool,
277
278    #[serde(default)]
279    pub api_key: String,
280
281    #[serde(default)]
282    pub endpoint: Option<String>,
283
284    #[serde(default)]
285    pub default_model: String,
286
287    #[serde(default)]
288    pub default_image_model: String,
289
290    #[serde(default)]
291    pub default_image_resolution: String,
292
293    #[serde(default)]
294    pub google_search_enabled: bool,
295
296    #[serde(default)]
297    pub models: HashMap<String, ModelDefinition>,
298
299    /// Resilience policy applied to outbound AI provider calls (timeouts,
300    /// retry, circuit breaker, bulkhead).
301    #[serde(default)]
302    pub resilience: ResilienceSettings,
303}
304
305impl Default for AiProviderConfig {
306    fn default() -> Self {
307        Self {
308            enabled: true,
309            api_key: String::new(),
310            endpoint: None,
311            default_model: String::new(),
312            default_image_model: String::new(),
313            default_image_resolution: String::new(),
314            google_search_enabled: false,
315            models: HashMap::new(),
316            resilience: ResilienceSettings::default(),
317        }
318    }
319}
320
321#[derive(Debug, Clone, Serialize, Deserialize)]
322pub struct ToolModelConfig {
323    pub provider: String,
324    pub model: String,
325
326    #[serde(skip_serializing_if = "Option::is_none")]
327    pub max_output_tokens: Option<u32>,
328
329    #[serde(skip_serializing_if = "Option::is_none")]
330    pub thinking_level: Option<String>,
331}