Skip to main content

serdes_ai_core/
usage.rs

1//! Token usage tracking for model requests.
2//!
3//! This module provides types for tracking token usage across requests and runs,
4//! as well as usage limit checking.
5
6use serde::{Deserialize, Serialize};
7
8use crate::errors::{UsageLimitExceeded, UsageLimitType};
9
/// Token usage for a single request.
///
/// Every field is optional so that a record can carry only the counters
/// that are actually known. Fields that are `None` are left out of the
/// serialized form (see the `skip_serializing_if` attributes below).
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct RequestUsage {
    /// Number of tokens in the request/prompt.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub request_tokens: Option<u64>,
    /// Number of tokens in the response/completion.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_tokens: Option<u64>,
    /// Total tokens (request + response).
    ///
    /// `with_tokens` and the `request_tokens` / `response_tokens` builder
    /// methods fill this in as the sum of the two counts. Because the field
    /// is public it can also be set directly.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub total_tokens: Option<u64>,
    /// Tokens used to create cache entries.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_creation_tokens: Option<u64>,
    /// Tokens read from cache.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_read_tokens: Option<u64>,
    /// Provider-specific usage details, stored as an arbitrary JSON value.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub details: Option<serde_json::Value>,
}
32
33impl RequestUsage {
34    /// Create a new empty usage record.
35    #[must_use]
36    pub fn new() -> Self {
37        Self::default()
38    }
39
40    /// Create usage with request and response tokens.
41    #[must_use]
42    pub fn with_tokens(request_tokens: u64, response_tokens: u64) -> Self {
43        Self {
44            request_tokens: Some(request_tokens),
45            response_tokens: Some(response_tokens),
46            total_tokens: Some(request_tokens + response_tokens),
47            ..Self::default()
48        }
49    }
50
51    /// Set request tokens.
52    #[must_use]
53    pub fn request_tokens(mut self, tokens: u64) -> Self {
54        self.request_tokens = Some(tokens);
55        self.recalculate_total();
56        self
57    }
58
59    /// Set response tokens.
60    #[must_use]
61    pub fn response_tokens(mut self, tokens: u64) -> Self {
62        self.response_tokens = Some(tokens);
63        self.recalculate_total();
64        self
65    }
66
67    /// Set cache creation tokens.
68    #[must_use]
69    pub fn cache_creation_tokens(mut self, tokens: u64) -> Self {
70        self.cache_creation_tokens = Some(tokens);
71        self
72    }
73
74    /// Set cache read tokens.
75    #[must_use]
76    pub fn cache_read_tokens(mut self, tokens: u64) -> Self {
77        self.cache_read_tokens = Some(tokens);
78        self
79    }
80
81    /// Set details.
82    #[must_use]
83    pub fn details(mut self, details: serde_json::Value) -> Self {
84        self.details = Some(details);
85        self
86    }
87
88    /// Merge another usage record into this one.
89    pub fn merge(&mut self, other: &RequestUsage) {
90        self.request_tokens = match (self.request_tokens, other.request_tokens) {
91            (Some(a), Some(b)) => Some(a + b),
92            (Some(a), None) => Some(a),
93            (None, Some(b)) => Some(b),
94            (None, None) => None,
95        };
96        self.response_tokens = match (self.response_tokens, other.response_tokens) {
97            (Some(a), Some(b)) => Some(a + b),
98            (Some(a), None) => Some(a),
99            (None, Some(b)) => Some(b),
100            (None, None) => None,
101        };
102        self.cache_creation_tokens = match (self.cache_creation_tokens, other.cache_creation_tokens)
103        {
104            (Some(a), Some(b)) => Some(a + b),
105            (Some(a), None) => Some(a),
106            (None, Some(b)) => Some(b),
107            (None, None) => None,
108        };
109        self.cache_read_tokens = match (self.cache_read_tokens, other.cache_read_tokens) {
110            (Some(a), Some(b)) => Some(a + b),
111            (Some(a), None) => Some(a),
112            (None, Some(b)) => Some(b),
113            (None, None) => None,
114        };
115        self.recalculate_total();
116    }
117
118    /// Recalculate total from request and response.
119    fn recalculate_total(&mut self) {
120        self.total_tokens = match (self.request_tokens, self.response_tokens) {
121            (Some(a), Some(b)) => Some(a + b),
122            (Some(a), None) => Some(a),
123            (None, Some(b)) => Some(b),
124            (None, None) => None,
125        };
126    }
127
128    /// Get total tokens, calculating if not set.
129    #[must_use]
130    pub fn total(&self) -> u64 {
131        self.total_tokens
132            .unwrap_or_else(|| self.request_tokens.unwrap_or(0) + self.response_tokens.unwrap_or(0))
133    }
134
135    /// Check if this usage record has any data.
136    #[must_use]
137    pub fn is_empty(&self) -> bool {
138        self.request_tokens.is_none()
139            && self.response_tokens.is_none()
140            && self.total_tokens.is_none()
141    }
142}
143
144impl std::ops::Add for RequestUsage {
145    type Output = Self;
146
147    fn add(mut self, rhs: Self) -> Self::Output {
148        self.merge(&rhs);
149        self
150    }
151}
152
153impl std::ops::AddAssign for RequestUsage {
154    fn add_assign(&mut self, rhs: Self) {
155        self.merge(&rhs);
156    }
157}
158
/// Accumulated usage for an entire run.
///
/// `add_request` keeps the three running totals in step with `requests`.
/// All fields are public, so code that mutates them directly is responsible
/// for keeping them consistent.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct RunUsage {
    /// Individual request usages, in the order they were added.
    pub requests: Vec<RequestUsage>,
    /// Total request tokens across all requests (a missing count adds 0).
    pub total_request_tokens: u64,
    /// Total response tokens across all requests (a missing count adds 0).
    pub total_response_tokens: u64,
    /// Total tokens across all requests; `add_request` adds each record's
    /// `RequestUsage::total()`.
    pub total_tokens: u64,
}
171
172impl RunUsage {
173    /// Create a new empty run usage.
174    #[must_use]
175    pub fn new() -> Self {
176        Self::default()
177    }
178
179    /// Add a request's usage.
180    pub fn add_request(&mut self, usage: RequestUsage) {
181        self.total_request_tokens += usage.request_tokens.unwrap_or(0);
182        self.total_response_tokens += usage.response_tokens.unwrap_or(0);
183        self.total_tokens += usage.total();
184        self.requests.push(usage);
185    }
186
187    /// Get the number of requests.
188    #[must_use]
189    pub fn request_count(&self) -> usize {
190        self.requests.len()
191    }
192
193    /// Check if there's no usage data.
194    #[must_use]
195    pub fn is_empty(&self) -> bool {
196        self.requests.is_empty()
197    }
198
199    /// Get average tokens per request.
200    #[must_use]
201    pub fn avg_tokens_per_request(&self) -> f64 {
202        if self.requests.is_empty() {
203            0.0
204        } else {
205            self.total_tokens as f64 / self.requests.len() as f64
206        }
207    }
208}
209
/// Usage limits for a run.
///
/// A limit that is `None` is not enforced and is left out of the serialized
/// form. `check` compares every limit against the totals accumulated in a
/// `RunUsage`, not against an individual request.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct UsageLimits {
    /// Maximum request tokens for the run (`check` compares this with
    /// `RunUsage::total_request_tokens`, the sum over all requests).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_request_tokens: Option<u64>,
    /// Maximum response tokens for the run (`check` compares this with
    /// `RunUsage::total_response_tokens`, the sum over all requests).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_response_tokens: Option<u64>,
    /// Maximum total tokens for the run.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_total_tokens: Option<u64>,
    /// Maximum number of requests.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_requests: Option<u64>,
}
226
227impl UsageLimits {
228    /// Create new empty limits.
229    #[must_use]
230    pub fn new() -> Self {
231        Self::default()
232    }
233
234    /// Set max request tokens.
235    #[must_use]
236    pub fn max_request_tokens(mut self, tokens: u64) -> Self {
237        self.max_request_tokens = Some(tokens);
238        self
239    }
240
241    /// Set max response tokens.
242    #[must_use]
243    pub fn max_response_tokens(mut self, tokens: u64) -> Self {
244        self.max_response_tokens = Some(tokens);
245        self
246    }
247
248    /// Set max total tokens.
249    #[must_use]
250    pub fn max_total_tokens(mut self, tokens: u64) -> Self {
251        self.max_total_tokens = Some(tokens);
252        self
253    }
254
255    /// Set max requests.
256    #[must_use]
257    pub fn max_requests(mut self, requests: u64) -> Self {
258        self.max_requests = Some(requests);
259        self
260    }
261
262    /// Check if usage exceeds limits.
263    ///
264    /// Returns `Ok(())` if within limits, or an error describing which limit was exceeded.
265    pub fn check(&self, usage: &RunUsage) -> Result<(), UsageLimitExceeded> {
266        if let Some(max) = self.max_request_tokens {
267            if usage.total_request_tokens > max {
268                return Err(UsageLimitExceeded::new(
269                    UsageLimitType::RequestTokens,
270                    usage.total_request_tokens,
271                    max,
272                ));
273            }
274        }
275
276        if let Some(max) = self.max_response_tokens {
277            if usage.total_response_tokens > max {
278                return Err(UsageLimitExceeded::new(
279                    UsageLimitType::ResponseTokens,
280                    usage.total_response_tokens,
281                    max,
282                ));
283            }
284        }
285
286        if let Some(max) = self.max_total_tokens {
287            if usage.total_tokens > max {
288                return Err(UsageLimitExceeded::new(
289                    UsageLimitType::TotalTokens,
290                    usage.total_tokens,
291                    max,
292                ));
293            }
294        }
295
296        if let Some(max) = self.max_requests {
297            let count = usage.request_count() as u64;
298            if count > max {
299                return Err(UsageLimitExceeded::new(
300                    UsageLimitType::Requests,
301                    count,
302                    max,
303                ));
304            }
305        }
306
307        Ok(())
308    }
309
310    /// Check if any limits are set.
311    #[must_use]
312    pub fn has_limits(&self) -> bool {
313        self.max_request_tokens.is_some()
314            || self.max_response_tokens.is_some()
315            || self.max_total_tokens.is_some()
316            || self.max_requests.is_some()
317    }
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed record carries no token data.
    #[test]
    fn test_request_usage_new() {
        assert!(RequestUsage::new().is_empty());
    }

    /// `with_tokens` stores both counts and their sum.
    #[test]
    fn test_request_usage_with_tokens() {
        let RequestUsage {
            request_tokens,
            response_tokens,
            total_tokens,
            ..
        } = RequestUsage::with_tokens(100, 50);

        assert_eq!(request_tokens, Some(100));
        assert_eq!(response_tokens, Some(50));
        assert_eq!(total_tokens, Some(150));
    }

    /// Merging adds the counters of both records together.
    #[test]
    fn test_request_usage_merge() {
        let mut merged = RequestUsage::with_tokens(100, 50);
        merged.merge(&RequestUsage::with_tokens(200, 100));

        assert_eq!(merged.request_tokens, Some(300));
        assert_eq!(merged.response_tokens, Some(150));
        assert_eq!(merged.total(), 450);
    }

    /// A run keeps per-category totals across every added request.
    #[test]
    fn test_run_usage() {
        let mut run = RunUsage::new();
        for (request, response) in [(100, 50), (200, 100)].iter().copied() {
            run.add_request(RequestUsage::with_tokens(request, response));
        }

        assert_eq!(run.request_count(), 2);
        assert_eq!(run.total_request_tokens, 300);
        assert_eq!(run.total_response_tokens, 150);
        assert_eq!(run.total_tokens, 450);
    }

    /// Usage below every configured limit is accepted.
    #[test]
    fn test_usage_limits_check_pass() {
        let mut run = RunUsage::new();
        run.add_request(RequestUsage::with_tokens(100, 50));

        let limits = UsageLimits::new().max_total_tokens(1000).max_requests(10);

        assert!(limits.check(&run).is_ok());
    }

    /// Exceeding the total-token limit is reported as such.
    #[test]
    fn test_usage_limits_check_fail() {
        let mut run = RunUsage::new();
        run.add_request(RequestUsage::with_tokens(100, 50));

        let limits = UsageLimits::new().max_total_tokens(100);

        let outcome = limits.check(&run);
        assert!(outcome.is_err());
        assert_eq!(outcome.unwrap_err().limit_type, UsageLimitType::TotalTokens);
    }

    /// Serializing to JSON and back yields an equal record.
    #[test]
    fn test_serde_roundtrip() {
        let original = RequestUsage::with_tokens(100, 50)
            .cache_creation_tokens(10)
            .cache_read_tokens(5);

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: RequestUsage = serde_json::from_str(&encoded).unwrap();

        assert_eq!(original, decoded);
    }
}