Skip to main content

swink_agent/
model_catalog.rs

1use std::sync::OnceLock;
2
3use serde::Deserialize;
4
5use crate::ModelSpec;
6use crate::types::{Cost, ModelCapabilities, Usage};
7
/// Classification of a catalog provider (the `kind` field of
/// [`ProviderCatalog`]).
///
/// Deserialized from the snake_case strings `remote` and `local`.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ProviderKind {
    /// Catalog value `remote` (presumably a hosted API — confirm against the
    /// provider adapters).
    Remote,
    /// Catalog value `local` (presumably a model run on this machine —
    /// confirm against the provider adapters).
    Local,
}
14
/// Authentication scheme a provider may declare in the catalog (the optional
/// `auth_mode` field of [`ProviderCatalog`]).
///
/// Deserialized from the snake_case strings `bearer`, `api_key_header` and
/// `aws_sigv4`. How each mode is applied to a request is decided elsewhere.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AuthMode {
    /// Catalog value `bearer`.
    Bearer,
    /// Catalog value `api_key_header`.
    ApiKeyHeader,
    /// Catalog value `aws_sigv4`.
    AwsSigv4,
}
22
/// API version tag a preset may declare (the optional `api_version` field of
/// [`PresetCatalog`]).
///
/// Deserialized from the snake_case strings `v1` and `v1beta`.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ApiVersion {
    /// Catalog value `v1`.
    V1,
    /// Catalog value `v1beta`.
    V1beta,
}
29
/// A single capability flag listed in a preset's `capabilities` array.
///
/// Deserialized from snake_case strings (`text`, `tools`, `thinking`,
/// `images_in`, `streaming`, `structured_output`).
/// [`CatalogPreset::model_capabilities`] translates these flags into a
/// [`ModelCapabilities`] value.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PresetCapability {
    /// Catalog value `text`; not mapped to any [`ModelCapabilities`] field.
    Text,
    /// Catalog value `tools`; sets `supports_tool_use`.
    Tools,
    /// Catalog value `thinking`; sets `supports_thinking`.
    Thinking,
    /// Catalog value `images_in`; sets `supports_vision`.
    ImagesIn,
    /// Catalog value `streaming`; sets `supports_streaming`.
    Streaming,
    /// Catalog value `structured_output`; sets `supports_structured_output`.
    StructuredOutput,
}
40
/// Release status a preset may declare (the optional `status` field of
/// [`PresetCatalog`]).
///
/// Deserialized from the snake_case strings `ga` and `preview`.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PresetStatus {
    /// Catalog value `ga` (presumably "generally available" — confirm).
    Ga,
    /// Catalog value `preview`.
    Preview,
}
47
/// One model preset as written in the catalog file, nested under a
/// [`ProviderCatalog`].
///
/// `Option` fields are `None` when the key is absent; fields marked
/// `#[serde(default)]` fall back to their type's default.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct PresetCatalog {
    /// Preset identifier; the key matched by [`ProviderCatalog::preset`].
    pub id: String,
    /// Human-readable name of the preset.
    pub display_name: String,
    /// Optional grouping label.
    pub group: Option<String>,
    /// Model identifier; the key matched by
    /// [`ModelCatalog::find_preset_by_model_id`].
    pub model_id: String,
    /// Optional API version tag.
    pub api_version: Option<ApiVersion>,
    /// Capability flags; empty when the key is omitted.
    #[serde(default)]
    pub capabilities: Vec<PresetCapability>,
    /// Optional release status.
    pub status: Option<PresetStatus>,
    /// Context window size in tokens, if declared.
    pub context_window_tokens: Option<u64>,
    /// Maximum number of output tokens, if declared.
    pub max_output_tokens: Option<u64>,
    /// `false` when the key is omitted.
    // NOTE(review): what "included by default" selects is decided by callers
    // outside this file — confirm before relying on it.
    #[serde(default)]
    pub include_by_default: bool,
    /// Optional repository identifier (presumably where a local model's
    /// weights are hosted — confirm).
    pub repo_id: Option<String>,
    /// Optional file name (presumably the weights file within `repo_id` —
    /// confirm).
    pub filename: Option<String>,
    /// Price per one million input tokens; read by [`calculate_cost`], which
    /// prices a missing rate as zero.
    #[serde(default)]
    pub cost_per_million_input: Option<f64>,
    /// Price per one million output tokens; read by [`calculate_cost`].
    #[serde(default)]
    pub cost_per_million_output: Option<f64>,
    /// Price per one million cache-read tokens; read by [`calculate_cost`].
    #[serde(default)]
    pub cost_per_million_cache_read: Option<f64>,
    /// Price per one million cache-write tokens; read by [`calculate_cost`].
    #[serde(default)]
    pub cost_per_million_cache_write: Option<f64>,
}
73
/// One provider entry as written in the catalog file, together with the
/// presets it offers.
///
/// `Option` fields are `None` when the key is absent; fields marked
/// `#[serde(default)]` fall back to their type's default.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct ProviderCatalog {
    /// Provider identifier; the key matched by [`ModelCatalog::provider`].
    pub key: String,
    /// Human-readable provider name.
    pub display_name: String,
    /// Provider classification.
    pub kind: ProviderKind,
    /// Authentication scheme, if declared.
    pub auth_mode: Option<AuthMode>,
    /// Name of the environment variable holding the credential, if declared.
    pub credential_env_var: Option<String>,
    /// Name of the environment variable that can supply a base URL, if
    /// declared.
    pub base_url_env_var: Option<String>,
    /// Base URL declared in the catalog, if any.
    pub default_base_url: Option<String>,
    /// `false` when the key is omitted.
    // NOTE(review): presumably means the caller must supply a base URL because
    // no usable default exists — confirm against the consumers of this flag.
    #[serde(default)]
    pub requires_base_url: bool,
    /// Name of the environment variable holding the region, if declared.
    pub region_env_var: Option<String>,
    /// Presets offered by this provider; empty when the key is omitted.
    #[serde(default)]
    pub presets: Vec<PresetCatalog>,
}
89
90impl ProviderCatalog {
91    #[must_use]
92    pub fn preset(&self, preset_id: &str) -> Option<&PresetCatalog> {
93        self.presets.iter().find(|preset| preset.id == preset_id)
94    }
95}
96
/// Root of the deserialized model catalog: the list of known providers, each
/// carrying its own presets.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct ModelCatalog {
    /// All providers in the catalog; empty when the key is omitted.
    #[serde(default)]
    pub providers: Vec<ProviderCatalog>,
}
102
103impl ModelCatalog {
104    #[must_use]
105    pub fn provider(&self, provider_key: &str) -> Option<&ProviderCatalog> {
106        self.providers
107            .iter()
108            .find(|provider| provider.key == provider_key)
109    }
110
111    /// Search across all providers for a preset matching the given `model_id`.
112    #[must_use]
113    pub fn find_preset_by_model_id(&self, model_id: &str) -> Option<CatalogPreset> {
114        for provider in &self.providers {
115            for preset in &provider.presets {
116                if preset.model_id == model_id {
117                    return self.preset(&provider.key, &preset.id);
118                }
119            }
120        }
121        None
122    }
123
124    #[must_use]
125    pub fn preset(&self, provider_key: &str, preset_id: &str) -> Option<CatalogPreset> {
126        let provider = self.provider(provider_key)?;
127        let preset = provider.preset(preset_id)?;
128        Some(CatalogPreset {
129            provider_key: provider.key.clone(),
130            provider_display_name: provider.display_name.clone(),
131            provider_kind: provider.kind.clone(),
132            preset_id: preset.id.clone(),
133            display_name: preset.display_name.clone(),
134            group: preset.group.clone(),
135            model_id: preset.model_id.clone(),
136            api_version: preset.api_version.clone(),
137            capabilities: preset.capabilities.clone(),
138            status: preset.status.clone(),
139            context_window_tokens: preset.context_window_tokens,
140            max_output_tokens: preset.max_output_tokens,
141            auth_mode: provider.auth_mode.clone(),
142            credential_env_var: provider.credential_env_var.clone(),
143            base_url_env_var: provider.base_url_env_var.clone(),
144            default_base_url: provider.default_base_url.clone(),
145            requires_base_url: provider.requires_base_url,
146            region_env_var: provider.region_env_var.clone(),
147            include_by_default: preset.include_by_default,
148            repo_id: preset.repo_id.clone(),
149            filename: preset.filename.clone(),
150            cost_per_million_input: preset.cost_per_million_input,
151            cost_per_million_output: preset.cost_per_million_output,
152            cost_per_million_cache_read: preset.cost_per_million_cache_read,
153            cost_per_million_cache_write: preset.cost_per_million_cache_write,
154        })
155    }
156}
157
/// Owned, flattened view of one preset merged with the metadata of the
/// provider that offers it.
///
/// Produced by [`ModelCatalog::preset`] and
/// [`ModelCatalog::find_preset_by_model_id`]; every field is a copy of the
/// corresponding [`ProviderCatalog`] or [`PresetCatalog`] field.
#[derive(Debug, Clone, PartialEq)]
pub struct CatalogPreset {
    /// Copied from [`ProviderCatalog::key`].
    pub provider_key: String,
    /// Copied from [`ProviderCatalog::display_name`].
    pub provider_display_name: String,
    /// Copied from [`ProviderCatalog::kind`].
    pub provider_kind: ProviderKind,
    /// Copied from [`PresetCatalog::id`].
    pub preset_id: String,
    /// Copied from [`PresetCatalog::display_name`].
    pub display_name: String,
    /// Copied from [`PresetCatalog::group`].
    pub group: Option<String>,
    /// Copied from [`PresetCatalog::model_id`].
    pub model_id: String,
    /// Copied from [`PresetCatalog::api_version`].
    pub api_version: Option<ApiVersion>,
    /// Copied from [`PresetCatalog::capabilities`].
    pub capabilities: Vec<PresetCapability>,
    /// Copied from [`PresetCatalog::status`].
    pub status: Option<PresetStatus>,
    /// Copied from [`PresetCatalog::context_window_tokens`].
    pub context_window_tokens: Option<u64>,
    /// Copied from [`PresetCatalog::max_output_tokens`].
    pub max_output_tokens: Option<u64>,
    /// Copied from [`ProviderCatalog::auth_mode`].
    pub auth_mode: Option<AuthMode>,
    /// Copied from [`ProviderCatalog::credential_env_var`].
    pub credential_env_var: Option<String>,
    /// Copied from [`ProviderCatalog::base_url_env_var`].
    pub base_url_env_var: Option<String>,
    /// Copied from [`ProviderCatalog::default_base_url`].
    pub default_base_url: Option<String>,
    /// Copied from [`ProviderCatalog::requires_base_url`].
    pub requires_base_url: bool,
    /// Copied from [`ProviderCatalog::region_env_var`].
    pub region_env_var: Option<String>,
    /// Copied from [`PresetCatalog::include_by_default`].
    pub include_by_default: bool,
    /// Copied from [`PresetCatalog::repo_id`].
    pub repo_id: Option<String>,
    /// Copied from [`PresetCatalog::filename`].
    pub filename: Option<String>,
    /// Copied from [`PresetCatalog::cost_per_million_input`].
    pub cost_per_million_input: Option<f64>,
    /// Copied from [`PresetCatalog::cost_per_million_output`].
    pub cost_per_million_output: Option<f64>,
    /// Copied from [`PresetCatalog::cost_per_million_cache_read`].
    pub cost_per_million_cache_read: Option<f64>,
    /// Copied from [`PresetCatalog::cost_per_million_cache_write`].
    pub cost_per_million_cache_write: Option<f64>,
}
186
187impl CatalogPreset {
188    /// Build a [`ModelCapabilities`] from the catalog's capability list and
189    /// token limits.
190    #[must_use]
191    pub fn model_capabilities(&self) -> ModelCapabilities {
192        let has = |cap: &PresetCapability| self.capabilities.contains(cap);
193        ModelCapabilities {
194            supports_thinking: has(&PresetCapability::Thinking),
195            supports_vision: has(&PresetCapability::ImagesIn),
196            supports_tool_use: has(&PresetCapability::Tools),
197            supports_streaming: has(&PresetCapability::Streaming),
198            supports_structured_output: has(&PresetCapability::StructuredOutput),
199            max_context_window: self.context_window_tokens,
200            max_output_tokens: self.max_output_tokens,
201        }
202    }
203
204    /// Create a [`ModelSpec`] pre-populated with capabilities from the catalog.
205    #[must_use]
206    pub fn model_spec(&self) -> ModelSpec {
207        ModelSpec::new(&self.provider_key, &self.model_id)
208            .with_capabilities(self.model_capabilities())
209    }
210}
211
212#[must_use]
213pub fn model_catalog() -> &'static ModelCatalog {
214    static MODEL_CATALOG: OnceLock<ModelCatalog> = OnceLock::new();
215    MODEL_CATALOG.get_or_init(|| {
216        toml::from_str(include_str!("model_catalog.toml"))
217            .expect("src/model_catalog.toml must be valid TOML")
218    })
219}
220
221/// Compute monetary cost from token usage using catalog pricing data.
222///
223/// Looks up the model by `model_id` across all providers. Returns
224/// `Cost::default()` if the model is not found or has no pricing data.
225#[must_use]
226pub fn calculate_cost(model_id: &str, usage: &Usage) -> Cost {
227    let Some(preset) = model_catalog().find_preset_by_model_id(model_id) else {
228        return Cost::default();
229    };
230
231    #[allow(clippy::cast_precision_loss)] // token counts fit comfortably in f64
232    let per_m = |tokens: u64, rate: Option<f64>| -> f64 {
233        rate.map_or(0.0, |r| tokens as f64 * r / 1_000_000.0)
234    };
235
236    let input = per_m(usage.input, preset.cost_per_million_input);
237    let output = per_m(usage.output, preset.cost_per_million_output);
238    let cache_read = per_m(usage.cache_read, preset.cost_per_million_cache_read);
239    let cache_write = per_m(usage.cache_write, preset.cost_per_million_cache_write);
240
241    Cost {
242        input,
243        output,
244        cache_read,
245        cache_write,
246        total: input + output + cache_read + cache_write,
247        ..Cost::default()
248    }
249}
250
// Unit tests for the catalog types and helpers. Every test reads the catalog
// returned by `model_catalog()`, so the expected values below mirror the
// contents of the embedded `model_catalog.toml`.
#[cfg(test)]
mod tests {
    use super::*;

    // Spot-checks several providers: their kind, selected presets, the number
    // of Google presets, and Bedrock's auth/region settings.
    #[test]
    fn catalog_loads_grouped_presets() {
        let catalog = model_catalog();
        let anthropic = catalog.provider("anthropic").unwrap();
        assert_eq!(anthropic.kind, ProviderKind::Remote);
        assert!(anthropic.preset("sonnet_46").is_some());

        let local = catalog.provider("local").unwrap();
        assert_eq!(local.kind, ProviderKind::Local);
        assert!(!local.preset("smollm3_3b").unwrap().include_by_default);
        assert!(local.preset("gemma4_e2b").unwrap().include_by_default);
        assert_eq!(
            local.preset("gemma4_e2b").unwrap().context_window_tokens,
            Some(128_000)
        );

        let google = catalog.provider("google").unwrap();
        assert_eq!(google.kind, ProviderKind::Remote);
        assert_eq!(google.presets.len(), 4);

        let bedrock = catalog.provider("bedrock").unwrap();
        assert_eq!(bedrock.auth_mode, Some(AuthMode::AwsSigv4));
        assert_eq!(bedrock.region_env_var.as_deref(), Some("AWS_REGION"));
    }

    // A flattened preset carries both its own fields and its provider's
    // credential, base-URL and auth settings.
    #[test]
    fn preset_lookup_returns_provider_metadata() {
        let preset = model_catalog().preset("openai", "gpt_5_4").unwrap();
        assert_eq!(preset.display_name, "OpenAI GPT-5.4");
        assert_eq!(preset.model_id, "gpt-5.4");
        assert_eq!(preset.credential_env_var.as_deref(), Some("OPENAI_API_KEY"));
        assert_eq!(preset.base_url_env_var.as_deref(), Some("OPENAI_BASE_URL"));
        assert_eq!(preset.auth_mode, Some(AuthMode::Bearer));
    }

    // The optional preset fields (API version, status, capability list, token
    // limits) round-trip through deserialization and flattening.
    #[test]
    fn google_preset_lookup_returns_extended_metadata() {
        let preset = model_catalog().preset("google", "gemini_3_flash").unwrap();
        assert_eq!(preset.display_name, "Google Gemini 3 Flash");
        assert_eq!(preset.model_id, "gemini-3-flash-preview");
        assert_eq!(preset.api_version, Some(ApiVersion::V1beta));
        assert_eq!(preset.status, Some(PresetStatus::Preview));
        assert_eq!(
            preset.capabilities,
            vec![
                PresetCapability::Text,
                PresetCapability::Tools,
                PresetCapability::Thinking,
                PresetCapability::ImagesIn,
                PresetCapability::Streaming,
                PresetCapability::StructuredOutput,
            ]
        );
        assert_eq!(preset.context_window_tokens, Some(1_000_000));
        assert_eq!(preset.max_output_tokens, Some(65536));
        assert_eq!(preset.credential_env_var.as_deref(), Some("GEMINI_API_KEY"));
        assert_eq!(preset.base_url_env_var.as_deref(), Some("GEMINI_BASE_URL"));
    }

    // Provider-level settings that only some providers use (required base URL,
    // region variable, preset group) reach the flattened preset.
    #[test]
    fn azure_and_bedrock_presets_expose_provider_specific_metadata() {
        let azure = model_catalog().preset("azure", "gpt_4o").unwrap();
        assert_eq!(azure.auth_mode, Some(AuthMode::ApiKeyHeader));
        assert!(azure.requires_base_url);
        assert_eq!(azure.base_url_env_var.as_deref(), Some("AZURE_BASE_URL"));

        let bedrock = model_catalog()
            .preset("bedrock", "anthropic_claude_sonnet_45")
            .unwrap();
        assert_eq!(bedrock.auth_mode, Some(AuthMode::AwsSigv4));
        assert_eq!(bedrock.region_env_var.as_deref(), Some("AWS_REGION"));
        assert_eq!(bedrock.group.as_deref(), Some("anthropic"));
    }

    // `model_capabilities` sets every flag and both token limits for a preset
    // that declares all capabilities.
    #[test]
    fn anthropic_preset_model_capabilities() {
        let preset = model_catalog().preset("anthropic", "sonnet_46").unwrap();
        let caps = preset.model_capabilities();
        assert!(caps.supports_thinking);
        assert!(caps.supports_vision);
        assert!(caps.supports_tool_use);
        assert!(caps.supports_streaming);
        assert!(caps.supports_structured_output);
        assert_eq!(caps.max_context_window, Some(200_000));
        assert_eq!(caps.max_output_tokens, Some(16384));
    }

    // `model_spec` hands the preset's capabilities to the resulting spec.
    #[test]
    fn model_spec_carries_capabilities_from_preset() {
        let preset = model_catalog().preset("anthropic", "opus_46").unwrap();
        let spec = preset.model_spec();
        let caps = spec.capabilities();
        assert!(caps.supports_thinking);
        assert!(caps.supports_vision);
        assert!(caps.supports_tool_use);
        assert_eq!(caps.max_context_window, Some(200_000));
        assert_eq!(caps.max_output_tokens, Some(32768));
    }

    // A preset that omits `thinking` reports it as unsupported while the other
    // declared flags stay set.
    #[test]
    fn openai_preset_no_thinking() {
        let preset = model_catalog().preset("openai", "gpt_5_4_mini").unwrap();
        let caps = preset.model_capabilities();
        assert!(!caps.supports_thinking);
        assert!(caps.supports_tool_use);
        assert!(caps.supports_vision);
        assert!(caps.supports_streaming);
        assert!(caps.supports_structured_output);
        assert_eq!(caps.max_context_window, Some(400_000));
    }

    // This local preset is expected to report streaming as its only supported
    // capability flag.
    #[test]
    fn local_preset_minimal_capabilities() {
        let preset = model_catalog().preset("local", "smollm3_3b").unwrap();
        let caps = preset.model_capabilities();
        assert!(!caps.supports_thinking);
        assert!(!caps.supports_vision);
        assert!(!caps.supports_tool_use);
        assert!(caps.supports_streaming);
        assert!(!caps.supports_structured_output);
        assert_eq!(caps.max_context_window, Some(8192));
        assert_eq!(caps.max_output_tokens, Some(2048));
    }

    // The Bedrock Sonnet preset is expected to lack structured output while
    // supporting the other four flags.
    #[test]
    fn bedrock_preset_capabilities() {
        let preset = model_catalog()
            .preset("bedrock", "anthropic_claude_sonnet_45")
            .unwrap();
        let caps = preset.model_capabilities();
        assert!(caps.supports_thinking);
        assert!(caps.supports_vision);
        assert!(caps.supports_tool_use);
        assert!(caps.supports_streaming);
        assert!(!caps.supports_structured_output);
    }

    // A spec built by hand (not from the catalog) starts with no capability
    // flags and no token limits.
    #[test]
    fn manual_model_spec_defaults_to_no_capabilities() {
        let spec = crate::ModelSpec::new("custom", "my-model");
        let caps = spec.capabilities();
        assert!(!caps.supports_thinking);
        assert!(!caps.supports_vision);
        assert!(!caps.supports_tool_use);
        assert!(!caps.supports_streaming);
        assert!(!caps.supports_structured_output);
        assert_eq!(caps.max_context_window, None);
        assert_eq!(caps.max_output_tokens, None);
    }

    // --- US4: Cost calculation tests ---

    // Test helper: build a `Usage` from the four token counts, with `total`
    // set to their sum and all remaining fields left at their defaults.
    fn usage(input: u64, output: u64, cache_read: u64, cache_write: u64) -> crate::types::Usage {
        crate::types::Usage {
            input,
            output,
            cache_read,
            cache_write,
            total: input + output + cache_read + cache_write,
            ..Default::default()
        }
    }

    // Input and output tokens are priced at the catalog's per-million rates and
    // summed into `total`.
    #[test]
    fn calculate_cost_known_model() {
        // Sonnet 4.6: input=$3/M, output=$15/M
        let cost = calculate_cost("claude-sonnet-4-6", &usage(1_000_000, 500_000, 0, 0));
        assert!((cost.input - 3.0).abs() < 0.001);
        assert!((cost.output - 7.5).abs() < 0.001);
        assert!((cost.total - 10.5).abs() < 0.001);
    }

    // A model id that is not in the catalog produces an all-zero cost.
    #[test]
    fn calculate_cost_unknown_model() {
        let cost = calculate_cost("nonexistent-model-xyz", &usage(1_000_000, 1_000_000, 0, 0));
        assert!((cost.input).abs() < 0.001);
        assert!((cost.output).abs() < 0.001);
        assert!((cost.total).abs() < 0.001);
    }

    // Zero tokens cost nothing even for a priced model.
    #[test]
    fn calculate_cost_zero_usage() {
        let cost = calculate_cost("claude-sonnet-4-6", &usage(0, 0, 0, 0));
        assert!((cost.total).abs() < 0.001);
    }

    // Cache reads and writes are priced with their own rates.
    #[test]
    fn calculate_cost_cache_tokens() {
        // Sonnet 4.6: cache_read=$0.30/M, cache_write=$3.75/M
        let cost = calculate_cost("claude-sonnet-4-6", &usage(0, 0, 2_000_000, 1_000_000));
        assert!((cost.cache_read - 0.60).abs() < 0.001);
        assert!((cost.cache_write - 3.75).abs() < 0.001);
        assert!((cost.total - 4.35).abs() < 0.001);
    }

    // A catalog entry without pricing fields is treated as free.
    #[test]
    fn calculate_cost_no_pricing_data() {
        // Local model has no pricing fields
        let cost = calculate_cost("SmolLM3-3B-Q4_K_M", &usage(1_000_000, 500_000, 0, 0));
        assert!((cost.total).abs() < 0.001);
    }

    // --- US5: Capability introspection tests ---

    // All five capability flags are reported for a fully-featured preset.
    #[test]
    fn capabilities_from_catalog_preset() {
        let preset = model_catalog().preset("anthropic", "sonnet_46").unwrap();
        let caps = preset.model_capabilities();
        assert!(caps.supports_thinking);
        assert!(caps.supports_vision);
        assert!(caps.supports_tool_use);
        assert!(caps.supports_streaming);
        assert!(caps.supports_structured_output);
    }

    // Token limits from the catalog surface through `model_capabilities`.
    #[test]
    fn capabilities_context_window_and_output() {
        let preset = model_catalog().preset("openai", "gpt_5_4").unwrap();
        let caps = preset.model_capabilities();
        assert_eq!(caps.max_context_window, Some(1_050_000));
        assert_eq!(caps.max_output_tokens, Some(128_000));
    }

    // Same check as above through `model_spec`, using a Google preset.
    #[test]
    fn model_spec_carries_capabilities() {
        let preset = model_catalog().preset("google", "gemini_3_flash").unwrap();
        let spec = preset.model_spec();
        let caps = spec.capabilities();
        assert!(caps.supports_thinking);
        assert!(caps.supports_vision);
        assert!(caps.supports_tool_use);
        assert_eq!(caps.max_context_window, Some(1_000_000));
    }

    // Lookup by model id resolves to the expected preset id and provider key.
    #[test]
    fn find_preset_by_model_id_works() {
        let preset = model_catalog()
            .find_preset_by_model_id("claude-sonnet-4-6")
            .unwrap();
        assert_eq!(preset.preset_id, "sonnet_46");
        assert_eq!(preset.provider_key, "anthropic");
    }

    // Lookup by an unknown model id yields `None`.
    #[test]
    fn find_preset_by_model_id_unknown_returns_none() {
        assert!(
            model_catalog()
                .find_preset_by_model_id("nonexistent")
                .is_none()
        );
    }
}