1use serde::{Deserialize, Serialize};
7
8#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
10pub struct Model {
11 pub id: String,
13 pub name: String,
15 pub provider: String,
17 pub reasoning: bool,
19 #[serde(skip_serializing_if = "Option::is_none")]
21 pub cost: Option<ModelCost>,
22 pub limit: ModelLimit,
24 #[serde(skip_serializing_if = "Option::is_none")]
26 pub release_date: Option<String>,
27}
28
29impl Model {
30 pub fn new(
32 id: impl Into<String>,
33 name: impl Into<String>,
34 provider: impl Into<String>,
35 reasoning: bool,
36 cost: Option<ModelCost>,
37 limit: ModelLimit,
38 ) -> Self {
39 Self {
40 id: id.into(),
41 name: name.into(),
42 provider: provider.into(),
43 reasoning,
44 cost,
45 limit,
46 release_date: None,
47 }
48 }
49
50 pub fn custom(id: impl Into<String>, provider: impl Into<String>) -> Self {
52 let id = id.into();
53 Self {
54 name: id.clone(),
55 id,
56 provider: provider.into(),
57 reasoning: false,
58 cost: None,
59 limit: ModelLimit::default(),
60 release_date: None,
61 }
62 }
63
64 pub fn has_pricing(&self) -> bool {
66 self.cost.is_some()
67 }
68
69 pub fn display_name(&self) -> &str {
71 &self.name
72 }
73
74 pub fn model_id(&self) -> &str {
76 &self.id
77 }
78
79 pub fn provider_name(&self) -> &str {
81 &self.provider
82 }
83}
84
85impl std::fmt::Display for Model {
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 write!(f, "{}", self.name)
88 }
89}
90
91#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
93pub struct ModelCost {
94 pub input: f64,
96 pub output: f64,
98 #[serde(skip_serializing_if = "Option::is_none")]
100 pub cache_read: Option<f64>,
101 #[serde(skip_serializing_if = "Option::is_none")]
103 pub cache_write: Option<f64>,
104}
105
106impl ModelCost {
107 pub fn new(input: f64, output: f64) -> Self {
109 Self {
110 input,
111 output,
112 cache_read: None,
113 cache_write: None,
114 }
115 }
116
117 pub fn with_cache(input: f64, output: f64, cache_read: f64, cache_write: f64) -> Self {
119 Self {
120 input,
121 output,
122 cache_read: Some(cache_read),
123 cache_write: Some(cache_write),
124 }
125 }
126
127 pub fn calculate(&self, input_tokens: u64, output_tokens: u64) -> f64 {
129 let input_cost = (input_tokens as f64 / 1_000_000.0) * self.input;
130 let output_cost = (output_tokens as f64 / 1_000_000.0) * self.output;
131 input_cost + output_cost
132 }
133
134 pub fn calculate_with_cache(
136 &self,
137 input_tokens: u64,
138 output_tokens: u64,
139 cache_read_tokens: u64,
140 cache_write_tokens: u64,
141 ) -> f64 {
142 let base_cost = self.calculate(input_tokens, output_tokens);
143 let cache_read_cost = self
144 .cache_read
145 .map(|rate| (cache_read_tokens as f64 / 1_000_000.0) * rate)
146 .unwrap_or(0.0);
147 let cache_write_cost = self
148 .cache_write
149 .map(|rate| (cache_write_tokens as f64 / 1_000_000.0) * rate)
150 .unwrap_or(0.0);
151 base_cost + cache_read_cost + cache_write_cost
152 }
153}
154
155#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
157pub struct ModelLimit {
158 pub context: u64,
160 pub output: u64,
162}
163
164impl ModelLimit {
165 pub fn new(context: u64, output: u64) -> Self {
167 Self { context, output }
168 }
169}
170
171impl Default for ModelLimit {
172 fn default() -> Self {
173 Self {
174 context: 128_000,
175 output: 8_192,
176 }
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for cost comparisons. The inputs below give results that are exact to
    /// ~1e-17, so any real pricing mistake is many orders of magnitude larger than this.
    const EPS: f64 = 1e-12;

    /// Asserts that two computed costs agree to within `EPS`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPS,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_model_creation() {
        let model = Model::new(
            "claude-sonnet-4-5-20250929",
            "Claude Sonnet 4.5",
            "anthropic",
            true,
            Some(ModelCost::with_cache(3.0, 15.0, 0.30, 3.75)),
            ModelLimit::new(200_000, 16_384),
        );

        assert_eq!(model.id, "claude-sonnet-4-5-20250929");
        assert_eq!(model.name, "Claude Sonnet 4.5");
        assert_eq!(model.provider, "anthropic");
        assert!(model.reasoning);
        assert!(model.has_pricing());
        assert_eq!(model.limit, ModelLimit::new(200_000, 16_384));
        assert_eq!(model.release_date, None);
    }

    #[test]
    fn test_custom_model() {
        let model = Model::custom("llama3", "ollama");

        assert_eq!(model.id, "llama3");
        assert_eq!(model.name, "llama3");
        assert_eq!(model.provider, "ollama");
        assert!(!model.reasoning);
        assert!(!model.has_pricing());
        assert_eq!(model.limit, ModelLimit::default());
        assert_eq!(model.release_date, None);
    }

    #[test]
    fn test_accessors() {
        let model = Model::custom("llama3", "ollama");

        assert_eq!(model.model_id(), "llama3");
        assert_eq!(model.display_name(), "llama3");
        assert_eq!(model.provider_name(), "ollama");
    }

    #[test]
    fn test_cost_calculation() {
        let cost = ModelCost::new(3.0, 15.0);

        // 1000 * 3.0 / 1M + 500 * 15.0 / 1M = 0.003 + 0.0075
        assert_close(cost.calculate(1000, 500), 0.0105);
        assert_close(cost.calculate(0, 0), 0.0);
    }

    #[test]
    fn test_cost_with_cache() {
        let cost = ModelCost::with_cache(3.0, 15.0, 0.30, 3.75);

        // 0.003 + 0.0075 + 2000 * 0.30 / 1M + 1000 * 3.75 / 1M
        // = 0.003 + 0.0075 + 0.0006 + 0.00375 = 0.01485
        assert_close(cost.calculate_with_cache(1000, 500, 2000, 1000), 0.01485);
    }

    #[test]
    fn test_cost_with_cache_without_cache_rates() {
        let cost = ModelCost::new(3.0, 15.0);

        // With no cache rates configured, cache tokens add nothing to the total.
        assert_close(
            cost.calculate_with_cache(1000, 500, 2000, 1000),
            cost.calculate(1000, 500),
        );
    }

    #[test]
    fn test_default_limit() {
        assert_eq!(ModelLimit::default(), ModelLimit::new(128_000, 8_192));
    }

    #[test]
    fn test_model_display() {
        let model = Model::new(
            "gpt-5",
            "GPT-5",
            "openai",
            false,
            None,
            ModelLimit::default(),
        );

        assert_eq!(format!("{}", model), "GPT-5");
    }

    #[test]
    fn test_serialization() {
        let model = Model::new(
            "claude-sonnet-4-5-20250929",
            "Claude Sonnet 4.5",
            "anthropic",
            true,
            Some(ModelCost::new(3.0, 15.0)),
            ModelLimit::new(200_000, 16_384),
        );

        let json = serde_json::to_string(&model).unwrap();
        assert!(json.contains("\"id\":\"claude-sonnet-4-5-20250929\""));
        assert!(json.contains("\"provider\":\"anthropic\""));

        let deserialized: Model = serde_json::from_str(&json).unwrap();
        assert_eq!(model, deserialized);
    }

    #[test]
    fn test_serialization_skips_none_fields() {
        let json = serde_json::to_string(&Model::custom("llama3", "ollama")).unwrap();
        assert!(!json.contains("\"cost\""));
        assert!(!json.contains("\"release_date\""));

        let json = serde_json::to_string(&ModelCost::new(3.0, 15.0)).unwrap();
        assert!(!json.contains("\"cache_read\""));
        assert!(!json.contains("\"cache_write\""));
    }

    #[test]
    fn test_deserialization_with_optional_fields_missing() {
        let json = r#"{"id":"llama3","name":"llama3","provider":"ollama","reasoning":false,"limit":{"context":128000,"output":8192}}"#;

        let model: Model = serde_json::from_str(json).unwrap();
        assert_eq!(model, Model::custom("llama3", "ollama"));
    }
}