1use std::collections::BTreeMap;
4
5use serde::{Deserialize, Serialize};
6use thiserror::Error;
7
/// Pricing support (for example `pricing::CostBudget`, which `UsageLimits` can
/// carry); only compiled when the `pricing` cargo feature is enabled.
#[cfg(feature = "pricing")]
pub mod pricing;
10
/// Usage counters: requests, token counts, and tool calls.
///
/// All counters are `u64`. When deserializing, `cache_write_tokens`,
/// `cache_read_tokens` and `tool_calls` fall back to `0` if absent; the other
/// fields carry no `#[serde(default)]` and therefore must be present.
///
/// NOTE(review): nothing in this type enforces a relationship between
/// `total_tokens` and the other token counters — confirm with the producers.
#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
pub struct Usage {
    /// Request counter.
    pub requests: u64,
    /// Input-token counter.
    pub input_tokens: u64,
    /// Cache-write token counter; `0` when missing from the serialized form.
    #[serde(default)]
    pub cache_write_tokens: u64,
    /// Cache-read token counter; `0` when missing from the serialized form.
    #[serde(default)]
    pub cache_read_tokens: u64,
    /// Output-token counter.
    pub output_tokens: u64,
    /// Total-token counter (stored as given, not derived from other fields here).
    pub total_tokens: u64,
    /// Tool-call counter; `0` when missing from the serialized form.
    #[serde(default)]
    pub tool_calls: u64,
}
37
38impl Usage {
39 pub fn add_assign(&mut self, other: &Self) {
41 self.requests = self.requests.saturating_add(other.requests);
42 self.input_tokens = self.input_tokens.saturating_add(other.input_tokens);
43 self.cache_write_tokens = self
44 .cache_write_tokens
45 .saturating_add(other.cache_write_tokens);
46 self.cache_read_tokens = self
47 .cache_read_tokens
48 .saturating_add(other.cache_read_tokens);
49 self.output_tokens = self.output_tokens.saturating_add(other.output_tokens);
50 self.total_tokens = self.total_tokens.saturating_add(other.total_tokens);
51 self.tool_calls = self.tool_calls.saturating_add(other.tool_calls);
52 }
53
54 #[must_use]
56 pub const fn is_empty(&self) -> bool {
57 self.requests == 0
58 && self.input_tokens == 0
59 && self.cache_write_tokens == 0
60 && self.cache_read_tokens == 0
61 && self.output_tokens == 0
62 && self.total_tokens == 0
63 && self.tool_calls == 0
64 }
65
66 #[must_use]
68 pub const fn with_additional_tool_calls(mut self, tool_calls: u64) -> Self {
69 self.tool_calls = self.tool_calls.saturating_add(tool_calls);
70 self
71 }
72}
73
/// A price estimate held as a single `u64` amount.
///
/// The field name indicates micro-USD units. Ordering (`Ord`/`PartialOrd`) is
/// derived and therefore compares `amount_micros_usd`.
#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Ord, PartialOrd, Serialize)]
pub struct PricingEstimate {
    /// Estimated amount; `0` when missing from the serialized form.
    #[serde(default)]
    pub amount_micros_usd: u64,
}
81
82impl PricingEstimate {
83 #[must_use]
85 pub const fn from_micros_usd(amount_micros_usd: u64) -> Self {
86 Self { amount_micros_usd }
87 }
88
89 pub fn add_assign(&mut self, other: &Self) {
91 self.amount_micros_usd = self
92 .amount_micros_usd
93 .saturating_add(other.amount_micros_usd);
94 }
95
96 #[must_use]
98 pub const fn is_zero(&self) -> bool {
99 self.amount_micros_usd == 0
100 }
101}
102
/// One entry of a [`UsageSnapshot`]'s `entries` list: a [`Usage`] value tagged
/// with agent and model strings.
///
/// Optional fields are omitted from the serialized form when `None` and may be
/// absent on input.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct UsageSnapshotEntry {
    /// Agent identifier string.
    pub agent_id: String,
    /// Agent name string.
    pub agent_name: String,
    /// Model identifier string.
    pub model_id: String,
    /// Usage counters recorded for this entry.
    pub usage: Usage,
    /// Optional price estimate for this entry; skipped when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub estimate_pricing: Option<PricingEstimate>,
    /// Optional usage identifier; skipped when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub usage_id: Option<String>,
    /// Source label; deserializes to `"model_request"` when the field is absent.
    #[serde(default = "default_usage_source")]
    pub source: String,
}
124
/// Value type of [`UsageSnapshot`]'s `agent_usages` map: a [`Usage`] value with
/// agent-name and model strings.
///
/// It has the same fields as [`UsageSnapshotEntry`] except that there is no
/// `agent_id` field.
/// NOTE(review): presumably the map key supplies the agent id — confirm.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct UsageAgentTotal {
    /// Agent name string.
    pub agent_name: String,
    /// Model identifier string.
    pub model_id: String,
    /// Usage counters for this agent.
    pub usage: Usage,
    /// Optional price estimate; skipped when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub estimate_pricing: Option<PricingEstimate>,
    /// Optional usage identifier; skipped when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub usage_id: Option<String>,
    /// Source label; deserializes to `"model_request"` when the field is absent.
    #[serde(default = "default_usage_source")]
    pub source: String,
}
144
/// Serializable snapshot of usage for one `run_id`.
///
/// Only `run_id` is required on input; every other field has a serde default.
/// Optional values and empty collections are left out of the serialized form.
#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
pub struct UsageSnapshot {
    /// Run identifier string (required when deserializing).
    pub run_id: String,
    /// Optional most-recent usage value; skipped when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub latest_usage: Option<Usage>,
    /// Total usage; defaults to all-zero counters when absent. Always serialized.
    #[serde(default)]
    pub total_usage: Usage,
    /// Optional overall price estimate; skipped when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub estimate_pricing: Option<PricingEstimate>,
    /// Individual usage entries; skipped when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub entries: Vec<UsageSnapshotEntry>,
    /// Per-agent totals in key order; skipped when empty.
    /// NOTE(review): keys are presumably agent ids — confirm against the writer.
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub agent_usages: BTreeMap<String, UsageAgentTotal>,
    /// Per-model usage in key order; skipped when empty.
    /// NOTE(review): keys are presumably model ids — confirm against the writer.
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub model_usages: BTreeMap<String, Usage>,
    /// Per-model price estimates in key order; skipped when empty.
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub model_estimate_pricing: BTreeMap<String, PricingEstimate>,
}
179
/// Serde fallback for the `source` fields: yields `"model_request"`.
fn default_usage_source() -> String {
    String::from("model_request")
}
183
/// Optional upper bounds applied to a [`Usage`].
///
/// A limit of `None` disables the corresponding check. The checks in
/// `impl UsageLimits` compare with `>`, so a value equal to its limit passes.
/// Unset limits are omitted from the serialized form and may be absent on input.
#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
pub struct UsageLimits {
    /// Maximum request count; consulted by `check_before_request`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub request_limit: Option<u64>,
    /// Maximum `input_tokens`; consulted by `check_usage`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub input_tokens_limit: Option<u64>,
    /// Maximum `output_tokens`; consulted by `check_usage`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_tokens_limit: Option<u64>,
    /// Maximum `total_tokens`; consulted by `check_usage`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub total_tokens_limit: Option<u64>,
    /// Maximum `tool_calls`; consulted by `check_tool_calls`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_calls_limit: Option<u64>,
    /// Optional cost budget; consulted by `check_usage` and the estimate
    /// helpers. Only present when the `pricing` feature is enabled.
    #[cfg(feature = "pricing")]
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cost_budget: Option<pricing::CostBudget>,
}
207
208impl UsageLimits {
209 #[must_use]
211 pub const fn new() -> Self {
212 Self {
213 request_limit: None,
214 input_tokens_limit: None,
215 output_tokens_limit: None,
216 total_tokens_limit: None,
217 tool_calls_limit: None,
218 #[cfg(feature = "pricing")]
219 cost_budget: None,
220 }
221 }
222
223 #[must_use]
225 pub const fn with_request_limit(mut self, limit: u64) -> Self {
226 self.request_limit = Some(limit);
227 self
228 }
229
230 #[must_use]
232 pub const fn with_input_tokens_limit(mut self, limit: u64) -> Self {
233 self.input_tokens_limit = Some(limit);
234 self
235 }
236
237 #[must_use]
239 pub const fn with_output_tokens_limit(mut self, limit: u64) -> Self {
240 self.output_tokens_limit = Some(limit);
241 self
242 }
243
244 #[must_use]
246 pub const fn with_total_tokens_limit(mut self, limit: u64) -> Self {
247 self.total_tokens_limit = Some(limit);
248 self
249 }
250
251 #[must_use]
253 pub const fn with_tool_calls_limit(mut self, limit: u64) -> Self {
254 self.tool_calls_limit = Some(limit);
255 self
256 }
257
258 #[cfg(feature = "pricing")]
260 #[must_use]
261 pub const fn with_cost_budget(mut self, budget: pricing::CostBudget) -> Self {
262 self.cost_budget = Some(budget);
263 self
264 }
265
266 #[cfg(feature = "pricing")]
268 #[must_use]
269 pub fn estimate_cost_micros(&self, usage: &Usage) -> Option<u64> {
270 self.cost_budget
271 .as_ref()
272 .map(|budget| budget.estimate_micros(usage))
273 }
274
275 #[cfg(feature = "pricing")]
277 #[must_use]
278 pub fn estimate_pricing(&self, usage: &Usage) -> Option<PricingEstimate> {
279 self.cost_budget
280 .as_ref()
281 .map(|budget| budget.estimate_pricing(usage))
282 }
283
284 pub const fn check_before_request(&self, current: &Usage) -> Result<(), UsageLimitError> {
290 if let Some(limit) = self.request_limit {
291 let next = current.requests.saturating_add(1);
292 if next > limit {
293 return Err(UsageLimitError::NextRequest {
294 limit,
295 next_requests: next,
296 });
297 }
298 }
299 Ok(())
300 }
301
302 pub const fn check_tool_calls(&self, projected: &Usage) -> Result<(), UsageLimitError> {
308 if let Some(limit) = self.tool_calls_limit {
309 if projected.tool_calls > limit {
310 return Err(UsageLimitError::ToolCalls {
311 limit,
312 tool_calls: projected.tool_calls,
313 });
314 }
315 }
316 Ok(())
317 }
318
319 pub fn check_usage(&self, usage: &Usage) -> Result<(), UsageLimitError> {
325 check_limit(
326 UsageTokenKind::InputTokens,
327 self.input_tokens_limit,
328 usage.input_tokens,
329 )?;
330 check_limit(
331 UsageTokenKind::OutputTokens,
332 self.output_tokens_limit,
333 usage.output_tokens,
334 )?;
335 check_limit(
336 UsageTokenKind::TotalTokens,
337 self.total_tokens_limit,
338 usage.total_tokens,
339 )?;
340 #[cfg(feature = "pricing")]
341 if let Some(budget) = &self.cost_budget {
342 budget.check_usage(usage)?;
343 }
344 Ok(())
345 }
346}
347
/// Identifies which token counter a [`UsageLimitError::Token`] refers to.
///
/// Serialized in `snake_case` (`"input_tokens"`, `"output_tokens"`,
/// `"total_tokens"`); the `Display` impl in this file writes the same strings.
#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum UsageTokenKind {
    /// The `input_tokens` counter.
    InputTokens,
    /// The `output_tokens` counter.
    OutputTokens,
    /// The `total_tokens` counter.
    TotalTokens,
}
359
360impl std::fmt::Display for UsageTokenKind {
361 fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
362 let value = match self {
363 Self::InputTokens => "input_tokens",
364 Self::OutputTokens => "output_tokens",
365 Self::TotalTokens => "total_tokens",
366 };
367 formatter.write_str(value)
368 }
369}
370
/// Error reported when a [`UsageLimits`] check fails.
///
/// The type is serializable, and each variant carries the limit together with
/// the value that exceeded it.
#[derive(Clone, Debug, Error, Deserialize, Eq, PartialEq, Serialize)]
pub enum UsageLimitError {
    /// Returned by `UsageLimits::check_before_request`: one more request would
    /// exceed `request_limit`.
    #[error("the next request would exceed the request_limit of {limit} (next_requests={next_requests})")]
    NextRequest {
        /// The configured request limit.
        limit: u64,
        /// The request count after the next request (saturating).
        next_requests: u64,
    },
    /// Returned by `check_limit`: a token counter is greater than its limit.
    /// `kind` is rendered through its `Display` impl in the message.
    #[error("exceeded the {kind}_limit of {limit} ({kind}={actual})")]
    Token {
        /// Which token counter was checked.
        kind: UsageTokenKind,
        /// The configured limit for that counter.
        limit: u64,
        /// The observed counter value.
        actual: u64,
    },
    /// Cost limit exceeded (`pricing` feature only).
    /// NOTE(review): not constructed in this file; presumably produced by the
    /// `pricing` module's budget check — confirm.
    #[cfg(feature = "pricing")]
    #[error(
        "exceeded the total_cost_limit_micros of {limit_micros} (cost_micros={actual_micros})"
    )]
    Cost {
        /// The cost limit, in the units named by the field.
        limit_micros: u64,
        /// The cost that exceeded the limit.
        actual_micros: u64,
    },
    /// Returned by `UsageLimits::check_tool_calls`: the projected tool-call
    /// count is greater than `tool_calls_limit`.
    #[error("the next tool call(s) would exceed the tool_calls_limit of {limit} (tool_calls={tool_calls})")]
    ToolCalls {
        /// The configured tool-call limit.
        limit: u64,
        /// The projected tool-call count.
        tool_calls: u64,
    },
}
412
413const fn check_limit(
414 kind: UsageTokenKind,
415 limit: Option<u64>,
416 actual: u64,
417) -> Result<(), UsageLimitError> {
418 if let Some(limit) = limit {
419 if actual > limit {
420 return Err(UsageLimitError::Token {
421 kind,
422 limit,
423 actual,
424 });
425 }
426 }
427 Ok(())
428}
429
430pub fn add_optional_pricing(
432 total: &mut Option<PricingEstimate>,
433 estimate: Option<&PricingEstimate>,
434) {
435 if let Some(estimate) = estimate {
436 match total {
437 Some(total) => total.add_assign(estimate),
438 None => *total = Some(estimate.clone()),
439 }
440 }
441}
442
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Usage` in which every counter holds `value`.
    fn uniform_usage(value: u64) -> Usage {
        Usage {
            requests: value,
            input_tokens: value,
            cache_write_tokens: value,
            cache_read_tokens: value,
            output_tokens: value,
            total_tokens: value,
            tool_calls: value,
        }
    }

    /// Accumulation adds counter by counter; emptiness means all counters zero.
    #[test]
    fn usage_add_assign_and_empty_work() {
        let mut usage = Usage {
            requests: 1,
            input_tokens: 2,
            cache_write_tokens: 7,
            cache_read_tokens: 11,
            output_tokens: 3,
            total_tokens: 5,
            tool_calls: 1,
        };
        let delta = Usage {
            requests: 2,
            input_tokens: 4,
            cache_write_tokens: 13,
            cache_read_tokens: 17,
            output_tokens: 6,
            total_tokens: 10,
            tool_calls: 3,
        };
        usage.add_assign(&delta);

        let expected = Usage {
            requests: 3,
            input_tokens: 6,
            cache_write_tokens: 20,
            cache_read_tokens: 28,
            output_tokens: 9,
            total_tokens: 15,
            tool_calls: 4,
        };
        assert_eq!(usage, expected);
        assert_eq!(usage.clone().with_additional_tool_calls(2).tool_calls, 6);
        assert!(Usage::default().is_empty());
        assert!(!usage.is_empty());
    }

    /// Adding to counters already at `u64::MAX` must clamp, not wrap.
    #[test]
    fn usage_add_assign_saturates() {
        let mut usage = uniform_usage(u64::MAX);
        usage.add_assign(&uniform_usage(1));
        assert_eq!(usage, uniform_usage(u64::MAX));
    }

    /// A `Token` error survives a JSON round trip unchanged.
    #[test]
    fn usage_limit_error_token_kind_is_owned_ser_de_contract() {
        let error = UsageLimitError::Token {
            kind: UsageTokenKind::TotalTokens,
            limit: 5,
            actual: 6,
        };
        let value = serde_json::to_value(&error)
            .unwrap_or_else(|err| panic!("usage limit error should serialize: {err}"));
        let restored: UsageLimitError = serde_json::from_value(value)
            .unwrap_or_else(|err| panic!("usage limit error should deserialize: {err}"));
        assert_eq!(restored, error);
    }

    /// A snapshot without any pricing fields still deserializes.
    #[test]
    fn usage_snapshot_accepts_missing_pricing_fields() {
        let payload = serde_json::json!({
            "run_id": "run_1",
            "total_usage": {
                "requests": 1,
                "input_tokens": 2,
                "output_tokens": 3,
                "total_tokens": 5
            }
        });
        let snapshot: UsageSnapshot = serde_json::from_value(payload)
            .unwrap_or_else(|err| panic!("usage snapshot should deserialize: {err}"));
        assert_eq!(snapshot.run_id, "run_1");
        assert!(snapshot.estimate_pricing.is_none());
        assert!(snapshot.model_estimate_pricing.is_empty());
    }
}