// stakai/types/model.rs

//! Unified model types for all AI providers
//!
//! This module provides a single `Model` struct that replaces provider-specific
//! model enums (AnthropicModel, OpenAIModel, GeminiModel) and related types.

use serde::{Deserialize, Serialize};

8/// Unified model representation across all providers
9#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
10pub struct Model {
11    /// Model identifier sent to the API (e.g., "claude-sonnet-4-5-20250929")
12    pub id: String,
13    /// Human-readable name (e.g., "Claude Sonnet 4.5")
14    pub name: String,
15    /// Provider identifier (e.g., "anthropic", "openai", "google")
16    pub provider: String,
17    /// Extended thinking/reasoning support
18    pub reasoning: bool,
19    /// Pricing per 1M tokens (None for custom/unknown models)
20    #[serde(skip_serializing_if = "Option::is_none")]
21    pub cost: Option<ModelCost>,
22    /// Token limits
23    pub limit: ModelLimit,
24    /// Release date (YYYY-MM-DD format)
25    #[serde(skip_serializing_if = "Option::is_none")]
26    pub release_date: Option<String>,
27}
28
29impl Model {
30    /// Create a new model with all fields
31    pub fn new(
32        id: impl Into<String>,
33        name: impl Into<String>,
34        provider: impl Into<String>,
35        reasoning: bool,
36        cost: Option<ModelCost>,
37        limit: ModelLimit,
38    ) -> Self {
39        Self {
40            id: id.into(),
41            name: name.into(),
42            provider: provider.into(),
43            reasoning,
44            cost,
45            limit,
46            release_date: None,
47        }
48    }
49
50    /// Create a custom model with minimal info (no pricing)
51    pub fn custom(id: impl Into<String>, provider: impl Into<String>) -> Self {
52        let id = id.into();
53        Self {
54            name: id.clone(),
55            id,
56            provider: provider.into(),
57            reasoning: false,
58            cost: None,
59            limit: ModelLimit::default(),
60            release_date: None,
61        }
62    }
63
64    /// Check if this model has pricing information
65    pub fn has_pricing(&self) -> bool {
66        self.cost.is_some()
67    }
68
69    /// Get the display name (name field)
70    pub fn display_name(&self) -> &str {
71        &self.name
72    }
73
74    /// Get the model ID used for API calls
75    pub fn model_id(&self) -> &str {
76        &self.id
77    }
78
79    /// Get the provider name
80    pub fn provider_name(&self) -> &str {
81        &self.provider
82    }
83}
84
85impl std::fmt::Display for Model {
86    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87        write!(f, "{}", self.name)
88    }
89}
90
91/// Pricing information per 1M tokens
92#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
93pub struct ModelCost {
94    /// Cost per 1M input tokens
95    pub input: f64,
96    /// Cost per 1M output tokens
97    pub output: f64,
98    /// Cost per 1M cached input tokens (if supported)
99    #[serde(skip_serializing_if = "Option::is_none")]
100    pub cache_read: Option<f64>,
101    /// Cost per 1M tokens written to cache (if supported)
102    #[serde(skip_serializing_if = "Option::is_none")]
103    pub cache_write: Option<f64>,
104}
105
106impl ModelCost {
107    /// Create a new cost struct with basic input/output pricing
108    pub fn new(input: f64, output: f64) -> Self {
109        Self {
110            input,
111            output,
112            cache_read: None,
113            cache_write: None,
114        }
115    }
116
117    /// Create a cost struct with cache pricing
118    pub fn with_cache(input: f64, output: f64, cache_read: f64, cache_write: f64) -> Self {
119        Self {
120            input,
121            output,
122            cache_read: Some(cache_read),
123            cache_write: Some(cache_write),
124        }
125    }
126
127    /// Calculate cost for given token counts (in tokens, not millions)
128    pub fn calculate(&self, input_tokens: u64, output_tokens: u64) -> f64 {
129        let input_cost = (input_tokens as f64 / 1_000_000.0) * self.input;
130        let output_cost = (output_tokens as f64 / 1_000_000.0) * self.output;
131        input_cost + output_cost
132    }
133
134    /// Calculate cost with cache tokens
135    pub fn calculate_with_cache(
136        &self,
137        input_tokens: u64,
138        output_tokens: u64,
139        cache_read_tokens: u64,
140        cache_write_tokens: u64,
141    ) -> f64 {
142        let base_cost = self.calculate(input_tokens, output_tokens);
143        let cache_read_cost = self
144            .cache_read
145            .map(|rate| (cache_read_tokens as f64 / 1_000_000.0) * rate)
146            .unwrap_or(0.0);
147        let cache_write_cost = self
148            .cache_write
149            .map(|rate| (cache_write_tokens as f64 / 1_000_000.0) * rate)
150            .unwrap_or(0.0);
151        base_cost + cache_read_cost + cache_write_cost
152    }
153}
154
155/// Token limits for the model
156#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
157pub struct ModelLimit {
158    /// Maximum context window size in tokens
159    pub context: u64,
160    /// Maximum output tokens
161    pub output: u64,
162}
163
164impl ModelLimit {
165    /// Create a new limit struct
166    pub fn new(context: u64, output: u64) -> Self {
167        Self { context, output }
168    }
169}
170
171impl Default for ModelLimit {
172    fn default() -> Self {
173        Self {
174            context: 128_000,
175            output: 8_192,
176        }
177    }
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_creation() {
        let model = Model::new(
            "claude-sonnet-4-5-20250929",
            "Claude Sonnet 4.5",
            "anthropic",
            true,
            Some(ModelCost::with_cache(3.0, 15.0, 0.30, 3.75)),
            ModelLimit::new(200_000, 16_384),
        );

        assert_eq!(model.id, "claude-sonnet-4-5-20250929");
        assert_eq!(model.name, "Claude Sonnet 4.5");
        assert_eq!(model.provider, "anthropic");
        assert!(model.reasoning);
        assert!(model.has_pricing());
    }

    #[test]
    fn test_custom_model() {
        let model = Model::custom("llama3", "ollama");

        // Custom models mirror the id as the name and carry no pricing.
        assert_eq!(model.id, "llama3");
        assert_eq!(model.name, "llama3");
        assert_eq!(model.provider, "ollama");
        assert!(!model.reasoning);
        assert!(!model.has_pricing());
    }

    #[test]
    fn test_cost_calculation() {
        let cost = ModelCost::new(3.0, 15.0);

        // 1000 input tokens, 500 output tokens:
        // (1000/1M) * 3.0 + (500/1M) * 15.0 = 0.003 + 0.0075 = 0.0105
        let total = cost.calculate(1000, 500);
        assert!((total - 0.0105).abs() < 0.0001);
    }

    #[test]
    fn test_cost_with_cache() {
        let cost = ModelCost::with_cache(3.0, 15.0, 0.30, 3.75);

        // base:        0.0105
        // cache_read:  (2000/1M) * 0.30 = 0.0006
        // cache_write: (1000/1M) * 3.75 = 0.00375
        // total:       0.0105 + 0.0006 + 0.00375 = 0.01485
        let total = cost.calculate_with_cache(1000, 500, 2000, 1000);
        assert!((total - 0.01485).abs() < 0.0001);
    }

    #[test]
    fn test_model_display() {
        let model = Model::new(
            "gpt-5",
            "GPT-5",
            "openai",
            false,
            None,
            ModelLimit::default(),
        );

        // Display uses the human-readable name, not the API id.
        assert_eq!(format!("{}", model), "GPT-5");
    }

    #[test]
    fn test_serialization() {
        let model = Model::new(
            "claude-sonnet-4-5-20250929",
            "Claude Sonnet 4.5",
            "anthropic",
            true,
            Some(ModelCost::new(3.0, 15.0)),
            ModelLimit::new(200_000, 16_384),
        );

        let json = serde_json::to_string(&model).unwrap();
        assert!(json.contains("\"id\":\"claude-sonnet-4-5-20250929\""));
        assert!(json.contains("\"provider\":\"anthropic\""));

        // Round-trip must reproduce the original value exactly.
        let deserialized: Model = serde_json::from_str(&json).unwrap();
        assert_eq!(model, deserialized);
    }
}