1use crate::session::Session;
2use serde::{Deserialize, Serialize};
3
// Fallback per-million-token prices (USD) behind `ModelPricing::default_sonnet_tier`.
//
// NOTE(review): these four values are identical to the "opus" entry in
// `pricing_for_model` ($15 / $75 / $18.75 / $1.50), even though the constructor
// is named after the sonnet tier. Existing tests depend on these exact numbers,
// so confirm against current published pricing before changing them.
const DEFAULT_INPUT_COST_PER_MILLION: f64 = 15.0;
const DEFAULT_OUTPUT_COST_PER_MILLION: f64 = 75.0;
const DEFAULT_CACHE_CREATION_COST_PER_MILLION: f64 = 18.75;
const DEFAULT_CACHE_READ_COST_PER_MILLION: f64 = 1.5;

/// Per-million-token prices, in USD, for one model family.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct ModelPricing {
    /// Price of one million regular input tokens.
    pub input_cost_per_million: f64,
    /// Price of one million output tokens.
    pub output_cost_per_million: f64,
    /// Price of one million cache-creation ("cache_write") input tokens.
    pub cache_creation_cost_per_million: f64,
    /// Price of one million cache-read input tokens.
    pub cache_read_cost_per_million: f64,
}

impl ModelPricing {
    /// Builds the fallback rate card from the `DEFAULT_*` constants.
    ///
    /// This is the pricing used when no model-specific rates are available, and
    /// the one returned for model names containing "sonnet".
    #[must_use]
    pub const fn default_sonnet_tier() -> Self {
        Self {
            cache_read_cost_per_million: DEFAULT_CACHE_READ_COST_PER_MILLION,
            cache_creation_cost_per_million: DEFAULT_CACHE_CREATION_COST_PER_MILLION,
            output_cost_per_million: DEFAULT_OUTPUT_COST_PER_MILLION,
            input_cost_per_million: DEFAULT_INPUT_COST_PER_MILLION,
        }
    }
}
28
29#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, PartialEq, Eq)]
30pub struct TokenUsage {
31 pub input_tokens: u32,
32 pub output_tokens: u32,
33 pub cache_creation_input_tokens: u32,
34 pub cache_read_input_tokens: u32,
35}
36
/// Estimated USD cost of a token-usage record, split by token category.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct UsageCostEstimate {
    /// Cost of the regular input tokens.
    pub input_cost_usd: f64,
    /// Cost of the output tokens.
    pub output_cost_usd: f64,
    /// Cost of the cache-creation input tokens.
    pub cache_creation_cost_usd: f64,
    /// Cost of the cache-read input tokens.
    pub cache_read_cost_usd: f64,
}

impl UsageCostEstimate {
    /// Sum of the four per-category costs.
    #[must_use]
    pub fn total_cost_usd(self) -> f64 {
        let Self {
            input_cost_usd,
            output_cost_usd,
            cache_creation_cost_usd,
            cache_read_cost_usd,
        } = self;
        // Added left to right in field order, so the float result is stable.
        input_cost_usd + output_cost_usd + cache_creation_cost_usd + cache_read_cost_usd
    }
}
54
55#[must_use]
56pub fn pricing_for_model(model: &str) -> Option<ModelPricing> {
57 let normalized = model.to_ascii_lowercase();
58 if normalized.contains("haiku") {
59 return Some(ModelPricing {
60 input_cost_per_million: 1.0,
61 output_cost_per_million: 5.0,
62 cache_creation_cost_per_million: 1.25,
63 cache_read_cost_per_million: 0.1,
64 });
65 }
66 if normalized.contains("opus") {
67 return Some(ModelPricing {
68 input_cost_per_million: 15.0,
69 output_cost_per_million: 75.0,
70 cache_creation_cost_per_million: 18.75,
71 cache_read_cost_per_million: 1.5,
72 });
73 }
74 if normalized.contains("sonnet") {
75 return Some(ModelPricing::default_sonnet_tier());
76 }
77 None
78}
79
80impl TokenUsage {
81 #[must_use]
82 pub fn total_tokens(self) -> u32 {
83 self.input_tokens
84 + self.output_tokens
85 + self.cache_creation_input_tokens
86 + self.cache_read_input_tokens
87 }
88
89 #[must_use]
90 pub fn estimate_cost_usd(self) -> UsageCostEstimate {
91 self.estimate_cost_usd_with_pricing(ModelPricing::default_sonnet_tier())
92 }
93
94 #[must_use]
95 pub fn estimate_cost_usd_with_pricing(self, pricing: ModelPricing) -> UsageCostEstimate {
96 UsageCostEstimate {
97 input_cost_usd: cost_for_tokens(self.input_tokens, pricing.input_cost_per_million),
98 output_cost_usd: cost_for_tokens(self.output_tokens, pricing.output_cost_per_million),
99 cache_creation_cost_usd: cost_for_tokens(
100 self.cache_creation_input_tokens,
101 pricing.cache_creation_cost_per_million,
102 ),
103 cache_read_cost_usd: cost_for_tokens(
104 self.cache_read_input_tokens,
105 pricing.cache_read_cost_per_million,
106 ),
107 }
108 }
109
110 #[must_use]
111 pub fn summary_lines(self, label: &str) -> Vec<String> {
112 self.summary_lines_for_model(label, None)
113 }
114
115 #[must_use]
116 pub fn summary_lines_for_model(self, label: &str, model: Option<&str>) -> Vec<String> {
117 let pricing = model.and_then(pricing_for_model);
118 let cost = pricing.map_or_else(
119 || self.estimate_cost_usd(),
120 |pricing| self.estimate_cost_usd_with_pricing(pricing),
121 );
122 let model_suffix =
123 model.map_or_else(String::new, |model_name| format!(" model={model_name}"));
124 let pricing_suffix = if pricing.is_some() {
125 ""
126 } else if model.is_some() {
127 " pricing=estimated-default"
128 } else {
129 ""
130 };
131 vec![
132 format!(
133 "{label}: total_tokens={} input={} output={} cache_write={} cache_read={} estimated_cost={}{}{}",
134 self.total_tokens(),
135 self.input_tokens,
136 self.output_tokens,
137 self.cache_creation_input_tokens,
138 self.cache_read_input_tokens,
139 format_usd(cost.total_cost_usd()),
140 model_suffix,
141 pricing_suffix,
142 ),
143 format!(
144 " cost breakdown: input={} output={} cache_write={} cache_read={}",
145 format_usd(cost.input_cost_usd),
146 format_usd(cost.output_cost_usd),
147 format_usd(cost.cache_creation_cost_usd),
148 format_usd(cost.cache_read_cost_usd),
149 ),
150 ]
151 }
152}
153
/// Converts a token count into USD at a price quoted per million tokens.
fn cost_for_tokens(tokens: u32, usd_per_million_tokens: f64) -> f64 {
    // Scale to millions first, then apply the price. The operation order is
    // kept so results are bit-for-bit stable.
    let millions_of_tokens = f64::from(tokens) / 1_000_000.0;
    usd_per_million_tokens * millions_of_tokens
}
157
/// Renders a USD amount with a leading `$` and exactly four decimal places.
#[must_use]
pub fn format_usd(amount: f64) -> String {
    format!("${:.4}", amount)
}
162
163#[derive(Debug, Clone, Default, PartialEq, Eq)]
164pub struct UsageTracker {
165 latest_turn: TokenUsage,
166 cumulative: TokenUsage,
167 turns: u32,
168}
169
170impl UsageTracker {
171 #[must_use]
172 pub fn new() -> Self {
173 Self::default()
174 }
175
176 #[must_use]
177 pub fn from_session(session: &Session) -> Self {
178 let mut tracker = Self::new();
179 for message in &session.messages {
180 if let Some(usage) = message.usage {
181 tracker.record(usage);
182 }
183 }
184 tracker
185 }
186
187 pub fn record(&mut self, usage: TokenUsage) {
188 self.latest_turn = usage;
189 self.cumulative.input_tokens += usage.input_tokens;
190 self.cumulative.output_tokens += usage.output_tokens;
191 self.cumulative.cache_creation_input_tokens += usage.cache_creation_input_tokens;
192 self.cumulative.cache_read_input_tokens += usage.cache_read_input_tokens;
193 self.turns += 1;
194 }
195
196 #[must_use]
197 pub fn current_turn_usage(&self) -> TokenUsage {
198 self.latest_turn
199 }
200
201 #[must_use]
202 pub fn cumulative_usage(&self) -> TokenUsage {
203 self.cumulative
204 }
205
206 #[must_use]
207 pub fn turns(&self) -> u32 {
208 self.turns
209 }
210}
211
#[cfg(test)]
mod tests {
    use super::{format_usd, pricing_for_model, TokenUsage, UsageTracker};
    use crate::session::{ContentBlock, ConversationMessage, MessageRole, Session};

    /// Usage record with only input/output tokens set; cache counters stay zero.
    fn io_only(input_tokens: u32, output_tokens: u32) -> TokenUsage {
        TokenUsage {
            input_tokens,
            output_tokens,
            ..TokenUsage::default()
        }
    }

    #[test]
    fn tracks_true_cumulative_usage() {
        let recorded = [
            TokenUsage {
                input_tokens: 10,
                output_tokens: 4,
                cache_creation_input_tokens: 2,
                cache_read_input_tokens: 1,
            },
            TokenUsage {
                input_tokens: 20,
                output_tokens: 6,
                cache_creation_input_tokens: 3,
                cache_read_input_tokens: 2,
            },
        ];
        let mut tracker = UsageTracker::new();
        recorded
            .iter()
            .copied()
            .for_each(|usage| tracker.record(usage));

        let latest = tracker.current_turn_usage();
        let totals = tracker.cumulative_usage();
        assert_eq!(tracker.turns(), 2);
        assert_eq!(latest.input_tokens, 20);
        assert_eq!(latest.output_tokens, 6);
        assert_eq!(totals.output_tokens, 10);
        assert_eq!(totals.input_tokens, 30);
        assert_eq!(totals.total_tokens(), 48);
    }

    #[test]
    fn computes_cost_summary_lines() {
        let usage = TokenUsage {
            input_tokens: 1_000_000,
            output_tokens: 500_000,
            cache_creation_input_tokens: 100_000,
            cache_read_input_tokens: 200_000,
        };

        let estimate = usage.estimate_cost_usd();
        assert_eq!(format_usd(estimate.input_cost_usd), "$15.0000");
        assert_eq!(format_usd(estimate.output_cost_usd), "$37.5000");

        let summary = usage.summary_lines_for_model("usage", Some("claude-sonnet-4-6"));
        assert!(summary[0].contains("estimated_cost=$54.6750"));
        assert!(summary[0].contains("model=claude-sonnet-4-6"));
        assert!(summary[1].contains("cache_read=$0.3000"));
    }

    #[test]
    fn supports_model_specific_pricing() {
        let usage = io_only(1_000_000, 500_000);

        let haiku = pricing_for_model("claude-haiku-4-5-20251213").expect("haiku pricing");
        let opus = pricing_for_model("claude-opus-4-6").expect("opus pricing");
        let formatted_total =
            |pricing| format_usd(usage.estimate_cost_usd_with_pricing(pricing).total_cost_usd());
        assert_eq!(formatted_total(haiku), "$3.5000");
        assert_eq!(formatted_total(opus), "$52.5000");
    }

    #[test]
    fn marks_unknown_model_pricing_as_fallback() {
        let summary = io_only(100, 100).summary_lines_for_model("usage", Some("custom-model"));
        assert!(summary[0].contains("pricing=estimated-default"));
    }

    #[test]
    fn reconstructs_usage_from_session_messages() {
        let reply = ConversationMessage {
            role: MessageRole::Assistant,
            blocks: vec![ContentBlock::Text {
                text: "done".to_string(),
            }],
            usage: Some(TokenUsage {
                input_tokens: 5,
                output_tokens: 2,
                cache_creation_input_tokens: 1,
                cache_read_input_tokens: 0,
            }),
        };
        let session = Session {
            version: 1,
            messages: vec![reply],
        };

        let tracker = UsageTracker::from_session(&session);
        assert_eq!(tracker.turns(), 1);
        assert_eq!(tracker.cumulative_usage().total_tokens(), 8);
    }
}