1use serde::{Deserialize, Serialize};
7
8use crate::errors::{UsageLimitExceeded, UsageLimitType};
9
10#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
12pub struct RequestUsage {
13 #[serde(skip_serializing_if = "Option::is_none")]
15 pub request_tokens: Option<u64>,
16 #[serde(skip_serializing_if = "Option::is_none")]
18 pub response_tokens: Option<u64>,
19 #[serde(skip_serializing_if = "Option::is_none")]
21 pub total_tokens: Option<u64>,
22 #[serde(skip_serializing_if = "Option::is_none")]
24 pub cache_creation_tokens: Option<u64>,
25 #[serde(skip_serializing_if = "Option::is_none")]
27 pub cache_read_tokens: Option<u64>,
28 #[serde(skip_serializing_if = "Option::is_none")]
30 pub details: Option<serde_json::Value>,
31}
32
33impl RequestUsage {
34 #[must_use]
36 pub fn new() -> Self {
37 Self::default()
38 }
39
40 #[must_use]
42 pub fn with_tokens(request_tokens: u64, response_tokens: u64) -> Self {
43 Self {
44 request_tokens: Some(request_tokens),
45 response_tokens: Some(response_tokens),
46 total_tokens: Some(request_tokens + response_tokens),
47 ..Self::default()
48 }
49 }
50
51 #[must_use]
53 pub fn request_tokens(mut self, tokens: u64) -> Self {
54 self.request_tokens = Some(tokens);
55 self.recalculate_total();
56 self
57 }
58
59 #[must_use]
61 pub fn response_tokens(mut self, tokens: u64) -> Self {
62 self.response_tokens = Some(tokens);
63 self.recalculate_total();
64 self
65 }
66
67 #[must_use]
69 pub fn cache_creation_tokens(mut self, tokens: u64) -> Self {
70 self.cache_creation_tokens = Some(tokens);
71 self
72 }
73
74 #[must_use]
76 pub fn cache_read_tokens(mut self, tokens: u64) -> Self {
77 self.cache_read_tokens = Some(tokens);
78 self
79 }
80
81 #[must_use]
83 pub fn details(mut self, details: serde_json::Value) -> Self {
84 self.details = Some(details);
85 self
86 }
87
88 pub fn merge(&mut self, other: &RequestUsage) {
90 self.request_tokens = match (self.request_tokens, other.request_tokens) {
91 (Some(a), Some(b)) => Some(a + b),
92 (Some(a), None) => Some(a),
93 (None, Some(b)) => Some(b),
94 (None, None) => None,
95 };
96 self.response_tokens = match (self.response_tokens, other.response_tokens) {
97 (Some(a), Some(b)) => Some(a + b),
98 (Some(a), None) => Some(a),
99 (None, Some(b)) => Some(b),
100 (None, None) => None,
101 };
102 self.cache_creation_tokens = match (self.cache_creation_tokens, other.cache_creation_tokens)
103 {
104 (Some(a), Some(b)) => Some(a + b),
105 (Some(a), None) => Some(a),
106 (None, Some(b)) => Some(b),
107 (None, None) => None,
108 };
109 self.cache_read_tokens = match (self.cache_read_tokens, other.cache_read_tokens) {
110 (Some(a), Some(b)) => Some(a + b),
111 (Some(a), None) => Some(a),
112 (None, Some(b)) => Some(b),
113 (None, None) => None,
114 };
115 self.recalculate_total();
116 }
117
118 fn recalculate_total(&mut self) {
120 self.total_tokens = match (self.request_tokens, self.response_tokens) {
121 (Some(a), Some(b)) => Some(a + b),
122 (Some(a), None) => Some(a),
123 (None, Some(b)) => Some(b),
124 (None, None) => None,
125 };
126 }
127
128 #[must_use]
130 pub fn total(&self) -> u64 {
131 self.total_tokens
132 .unwrap_or_else(|| self.request_tokens.unwrap_or(0) + self.response_tokens.unwrap_or(0))
133 }
134
135 #[must_use]
137 pub fn is_empty(&self) -> bool {
138 self.request_tokens.is_none()
139 && self.response_tokens.is_none()
140 && self.total_tokens.is_none()
141 }
142}
143
144impl std::ops::Add for RequestUsage {
145 type Output = Self;
146
147 fn add(mut self, rhs: Self) -> Self::Output {
148 self.merge(&rhs);
149 self
150 }
151}
152
153impl std::ops::AddAssign for RequestUsage {
154 fn add_assign(&mut self, rhs: Self) {
155 self.merge(&rhs);
156 }
157}
158
159#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
161pub struct RunUsage {
162 pub requests: Vec<RequestUsage>,
164 pub total_request_tokens: u64,
166 pub total_response_tokens: u64,
168 pub total_tokens: u64,
170}
171
172impl RunUsage {
173 #[must_use]
175 pub fn new() -> Self {
176 Self::default()
177 }
178
179 pub fn add_request(&mut self, usage: RequestUsage) {
181 self.total_request_tokens += usage.request_tokens.unwrap_or(0);
182 self.total_response_tokens += usage.response_tokens.unwrap_or(0);
183 self.total_tokens += usage.total();
184 self.requests.push(usage);
185 }
186
187 #[must_use]
189 pub fn request_count(&self) -> usize {
190 self.requests.len()
191 }
192
193 #[must_use]
195 pub fn is_empty(&self) -> bool {
196 self.requests.is_empty()
197 }
198
199 #[must_use]
201 pub fn avg_tokens_per_request(&self) -> f64 {
202 if self.requests.is_empty() {
203 0.0
204 } else {
205 self.total_tokens as f64 / self.requests.len() as f64
206 }
207 }
208}
209
210#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
212pub struct UsageLimits {
213 #[serde(skip_serializing_if = "Option::is_none")]
215 pub max_request_tokens: Option<u64>,
216 #[serde(skip_serializing_if = "Option::is_none")]
218 pub max_response_tokens: Option<u64>,
219 #[serde(skip_serializing_if = "Option::is_none")]
221 pub max_total_tokens: Option<u64>,
222 #[serde(skip_serializing_if = "Option::is_none")]
224 pub max_requests: Option<u64>,
225}
226
227impl UsageLimits {
228 #[must_use]
230 pub fn new() -> Self {
231 Self::default()
232 }
233
234 #[must_use]
236 pub fn max_request_tokens(mut self, tokens: u64) -> Self {
237 self.max_request_tokens = Some(tokens);
238 self
239 }
240
241 #[must_use]
243 pub fn max_response_tokens(mut self, tokens: u64) -> Self {
244 self.max_response_tokens = Some(tokens);
245 self
246 }
247
248 #[must_use]
250 pub fn max_total_tokens(mut self, tokens: u64) -> Self {
251 self.max_total_tokens = Some(tokens);
252 self
253 }
254
255 #[must_use]
257 pub fn max_requests(mut self, requests: u64) -> Self {
258 self.max_requests = Some(requests);
259 self
260 }
261
262 pub fn check(&self, usage: &RunUsage) -> Result<(), UsageLimitExceeded> {
266 if let Some(max) = self.max_request_tokens {
267 if usage.total_request_tokens > max {
268 return Err(UsageLimitExceeded::new(
269 UsageLimitType::RequestTokens,
270 usage.total_request_tokens,
271 max,
272 ));
273 }
274 }
275
276 if let Some(max) = self.max_response_tokens {
277 if usage.total_response_tokens > max {
278 return Err(UsageLimitExceeded::new(
279 UsageLimitType::ResponseTokens,
280 usage.total_response_tokens,
281 max,
282 ));
283 }
284 }
285
286 if let Some(max) = self.max_total_tokens {
287 if usage.total_tokens > max {
288 return Err(UsageLimitExceeded::new(
289 UsageLimitType::TotalTokens,
290 usage.total_tokens,
291 max,
292 ));
293 }
294 }
295
296 if let Some(max) = self.max_requests {
297 let count = usage.request_count() as u64;
298 if count > max {
299 return Err(UsageLimitExceeded::new(
300 UsageLimitType::Requests,
301 count,
302 max,
303 ));
304 }
305 }
306
307 Ok(())
308 }
309
310 #[must_use]
312 pub fn has_limits(&self) -> bool {
313 self.max_request_tokens.is_some()
314 || self.max_response_tokens.is_some()
315 || self.max_total_tokens.is_some()
316 || self.max_requests.is_some()
317 }
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    // A freshly created record carries no counts.
    #[test]
    fn test_request_usage_new() {
        let usage = RequestUsage::new();
        assert!(usage.is_empty());
    }

    // `with_tokens` fills request, response and their sum.
    #[test]
    fn test_request_usage_with_tokens() {
        let usage = RequestUsage::with_tokens(100, 50);
        assert_eq!(usage.request_tokens, Some(100));
        assert_eq!(usage.response_tokens, Some(50));
        assert_eq!(usage.total_tokens, Some(150));
    }

    // The builder setters keep `total_tokens` equal to request + response.
    #[test]
    fn test_request_usage_builder_recalculates_total() {
        let usage = RequestUsage::new().request_tokens(10).response_tokens(5);
        assert_eq!(usage.total_tokens, Some(15));
        assert_eq!(usage.total(), 15);
    }

    // Merging adds the counters of both records.
    #[test]
    fn test_request_usage_merge() {
        let mut usage1 = RequestUsage::with_tokens(100, 50);
        let usage2 = RequestUsage::with_tokens(200, 100);
        usage1.merge(&usage2);
        assert_eq!(usage1.request_tokens, Some(300));
        assert_eq!(usage1.response_tokens, Some(150));
        assert_eq!(usage1.total(), 450);
    }

    // A run keeps totals across all added requests.
    #[test]
    fn test_run_usage() {
        let mut run = RunUsage::new();
        run.add_request(RequestUsage::with_tokens(100, 50));
        run.add_request(RequestUsage::with_tokens(200, 100));

        assert_eq!(run.request_count(), 2);
        assert_eq!(run.total_request_tokens, 300);
        assert_eq!(run.total_response_tokens, 150);
        assert_eq!(run.total_tokens, 450);
    }

    // Usage below every configured limit is accepted.
    #[test]
    fn test_usage_limits_check_pass() {
        let limits = UsageLimits::new().max_total_tokens(1000).max_requests(10);

        let mut run = RunUsage::new();
        run.add_request(RequestUsage::with_tokens(100, 50));

        assert!(limits.check(&run).is_ok());
    }

    // Usage exactly at a limit is still accepted; only exceeding it fails.
    #[test]
    fn test_usage_limits_check_boundary() {
        let limits = UsageLimits::new().max_total_tokens(150).max_requests(1);

        let mut run = RunUsage::new();
        run.add_request(RequestUsage::with_tokens(100, 50));

        assert!(limits.check(&run).is_ok());
    }

    // Exceeding the total-token limit reports that limit type.
    #[test]
    fn test_usage_limits_check_fail() {
        let limits = UsageLimits::new().max_total_tokens(100);

        let mut run = RunUsage::new();
        run.add_request(RequestUsage::with_tokens(100, 50));

        let result = limits.check(&run);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.limit_type, UsageLimitType::TotalTokens);
    }

    // Exceeding the request-count limit reports that limit type.
    #[test]
    fn test_usage_limits_check_max_requests() {
        let limits = UsageLimits::new().max_requests(1);

        let mut run = RunUsage::new();
        run.add_request(RequestUsage::with_tokens(1, 1));
        run.add_request(RequestUsage::with_tokens(1, 1));

        let err = limits.check(&run).unwrap_err();
        assert_eq!(err.limit_type, UsageLimitType::Requests);
    }

    // Serializing and deserializing yields an equal record.
    #[test]
    fn test_serde_roundtrip() {
        let usage = RequestUsage::with_tokens(100, 50)
            .cache_creation_tokens(10)
            .cache_read_tokens(5);
        let json = serde_json::to_string(&usage).unwrap();
        let parsed: RequestUsage = serde_json::from_str(&json).unwrap();
        assert_eq!(usage, parsed);
    }
}