1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4use thiserror::Error;
5
6#[derive(Debug, Clone, Default, Serialize, Deserialize)]
7pub struct RequestUsage {
8 pub input_tokens: u64,
9 pub output_tokens: u64,
10 pub cache_write_tokens: u64,
11 pub cache_read_tokens: u64,
12 pub input_audio_tokens: u64,
13 pub output_audio_tokens: u64,
14 pub details: HashMap<String, u64>,
15}
16
17impl RequestUsage {
18 pub fn total_tokens(&self) -> u64 {
19 self.input_tokens + self.output_tokens
20 }
21}
22
23#[derive(Debug, Clone, Default, Serialize, Deserialize)]
24pub struct RunUsage {
25 pub requests: u64,
26 pub tool_calls: u64,
27 pub input_tokens: u64,
28 pub output_tokens: u64,
29 pub cache_write_tokens: u64,
30 pub cache_read_tokens: u64,
31 pub input_audio_tokens: u64,
32 pub output_audio_tokens: u64,
33 pub details: HashMap<String, u64>,
34}
35
36impl RunUsage {
37 pub fn total_tokens(&self) -> u64 {
38 self.input_tokens + self.output_tokens
39 }
40
41 pub fn incr_request(&mut self, request: &RequestUsage) {
42 self.requests += 1;
43 self.input_tokens += request.input_tokens;
44 self.output_tokens += request.output_tokens;
45 self.cache_write_tokens += request.cache_write_tokens;
46 self.cache_read_tokens += request.cache_read_tokens;
47 self.input_audio_tokens += request.input_audio_tokens;
48 self.output_audio_tokens += request.output_audio_tokens;
49 for (k, v) in &request.details {
50 *self.details.entry(k.clone()).or_insert(0) += v;
51 }
52 }
53
54 pub fn incr_tool_call(&mut self) {
55 self.tool_calls += 1;
56 }
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct UsageLimits {
61 pub request_limit: Option<u64>,
62 pub tool_calls_limit: Option<u64>,
63 pub input_tokens_limit: Option<u64>,
64 pub output_tokens_limit: Option<u64>,
65 pub total_tokens_limit: Option<u64>,
66}
67
68impl Default for UsageLimits {
69 fn default() -> Self {
70 Self {
71 request_limit: Some(50),
72 tool_calls_limit: None,
73 input_tokens_limit: None,
74 output_tokens_limit: None,
75 total_tokens_limit: None,
76 }
77 }
78}
79
80impl UsageLimits {
81 pub fn check_request(&self, current_requests: u64) -> Result<(), UsageError> {
82 if let Some(limit) = self.request_limit
83 && current_requests >= limit
84 {
85 return Err(UsageError::RequestLimitExceeded { limit });
86 }
87 Ok(())
88 }
89
90 pub fn check_tool_call(&self, current_calls: u64) -> Result<(), UsageError> {
91 if let Some(limit) = self.tool_calls_limit
92 && current_calls >= limit
93 {
94 return Err(UsageError::ToolCallsLimitExceeded { limit });
95 }
96 Ok(())
97 }
98
99 pub fn check_after_response(&self, usage: &RunUsage) -> Result<(), UsageError> {
100 if let Some(limit) = self.input_tokens_limit
101 && usage.input_tokens > limit
102 {
103 return Err(UsageError::InputTokensLimitExceeded { limit });
104 }
105 if let Some(limit) = self.output_tokens_limit
106 && usage.output_tokens > limit
107 {
108 return Err(UsageError::OutputTokensLimitExceeded { limit });
109 }
110 if let Some(limit) = self.total_tokens_limit
111 && usage.total_tokens() > limit
112 {
113 return Err(UsageError::TotalTokensLimitExceeded { limit });
114 }
115 Ok(())
116 }
117}
118
119#[derive(Debug, Error)]
120pub enum UsageError {
121 #[error("request limit exceeded (limit {limit})")]
122 RequestLimitExceeded { limit: u64 },
123 #[error("tool call limit exceeded (limit {limit})")]
124 ToolCallsLimitExceeded { limit: u64 },
125 #[error("input token limit exceeded (limit {limit})")]
126 InputTokensLimitExceeded { limit: u64 },
127 #[error("output token limit exceeded (limit {limit})")]
128 OutputTokensLimitExceeded { limit: u64 },
129 #[error("total token limit exceeded (limit {limit})")]
130 TotalTokensLimitExceeded { limit: u64 },
131}