Skip to main content

sentinel_common/
budget.rs

1//! Token budget management and cost attribution types.
2//!
3//! This module provides configuration types for:
4//! - Per-tenant token budgets with period-based limits
5//! - Cost attribution with per-model pricing
6//!
7//! # Token Budgets
8//!
9//! Token budgets allow tracking cumulative token usage per tenant over
10//! configurable periods (hourly, daily, monthly). This enables:
11//! - Quota enforcement for API consumers
12//! - Usage alerts at configurable thresholds
13//! - Optional rollover of unused tokens
14//!
15//! # Cost Attribution
16//!
17//! Cost attribution tracks the monetary cost of inference requests based
18//! on model-specific pricing for input and output tokens.
19
20use serde::{Deserialize, Serialize};
21
22// ============================================================================
23// Budget Configuration
24// ============================================================================
25
/// Token budget configuration for per-tenant usage tracking.
///
/// Budgets track cumulative token usage over a configurable period,
/// with optional alerts and enforcement.
///
/// When deserialized, every field except `limit` is optional: `period`
/// falls back to [`BudgetPeriod::Daily`] (the enum's `Default`),
/// `alert_thresholds` to `[0.80, 0.90, 0.95]`, `enforce` to `true`,
/// `rollover` to `false`, and `burst_allowance` to `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenBudgetConfig {
    /// Budget period (when the budget resets).
    #[serde(default)]
    pub period: BudgetPeriod,

    /// Total tokens allowed in the period. Required when deserializing.
    pub limit: u64,

    /// Alert thresholds as fractions of `limit` (e.g., `[0.80, 0.90, 0.95]`
    /// for 80%, 90% and 95%), not 0–100 percentages.
    /// Intended to trigger alerts when usage crosses these thresholds.
    #[serde(default = "default_alert_thresholds")]
    pub alert_thresholds: Vec<f64>,

    /// Whether to enforce the limit (block requests when exhausted).
    /// Defaults to `true`.
    #[serde(default = "default_true")]
    pub enforce: bool,

    /// Allow unused tokens to roll over to the next period.
    /// Defaults to `false`.
    #[serde(default)]
    pub rollover: bool,

    /// Allow burst usage above limit as a fraction of `limit` (soft limit).
    /// E.g., `0.10` allows 10% burst above the limit. `None` disables bursting.
    #[serde(default)]
    pub burst_allowance: Option<f64>,
}
57
/// Serde/`Default` fallback for `TokenBudgetConfig::alert_thresholds`:
/// alert at 80%, 90% and 95% of the limit (expressed as fractions).
fn default_alert_thresholds() -> Vec<f64> {
    [0.80, 0.90, 0.95].to_vec()
}
61
/// Serde default helper for boolean options that are on unless explicitly
/// disabled (used by `TokenBudgetConfig::enforce`).
fn default_true() -> bool {
    let enabled = true;
    enabled
}
65
66impl Default for TokenBudgetConfig {
67    fn default() -> Self {
68        Self {
69            period: BudgetPeriod::Daily,
70            limit: 1_000_000, // 1M tokens
71            alert_thresholds: default_alert_thresholds(),
72            enforce: true,
73            rollover: false,
74            burst_allowance: None,
75        }
76    }
77}
78
/// Budget period defining when the budget resets.
///
/// Variant names are serialized in `snake_case` (`hourly`, `daily`,
/// `monthly`, `custom`). The derived `Default` is [`BudgetPeriod::Daily`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum BudgetPeriod {
    /// Resets every hour (3,600 seconds).
    Hourly,
    /// Resets every day at midnight UTC (86,400 seconds).
    #[default]
    Daily,
    /// Resets on the first of each month at midnight UTC.
    ///
    /// Note: [`BudgetPeriod::as_secs`] approximates this as a fixed 30-day
    /// window (2,592,000 seconds), which does not track calendar month lengths.
    Monthly,
    /// Custom period in seconds.
    Custom {
        /// Period duration in seconds.
        /// NOTE(review): nothing here rejects `0`; any caller that divides by
        /// the period length should guard against it.
        seconds: u64,
    },
}
96
97impl BudgetPeriod {
98    /// Get the period duration in seconds.
99    pub fn as_secs(&self) -> u64 {
100        match self {
101            BudgetPeriod::Hourly => 3600,
102            BudgetPeriod::Daily => 86400,
103            BudgetPeriod::Monthly => 2_592_000, // 30 days
104            BudgetPeriod::Custom { seconds } => *seconds,
105        }
106    }
107}
108
109// ============================================================================
110// Cost Attribution Configuration
111// ============================================================================
112
/// Cost attribution configuration for tracking inference costs.
///
/// Allows per-model pricing with separate input/output token rates.
/// All prices in this config are expressed per million tokens.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostAttributionConfig {
    /// Whether cost attribution is enabled. Defaults to `false`.
    #[serde(default)]
    pub enabled: bool,

    /// Per-model pricing rules (evaluated in order, first match wins).
    /// Defaults to an empty list.
    #[serde(default)]
    pub pricing: Vec<ModelPricing>,

    /// Default cost per million input tokens (fallback when no rule applies).
    /// Defaults to `1.0`.
    #[serde(default = "default_input_cost")]
    pub default_input_cost: f64,

    /// Default cost per million output tokens (fallback when no rule applies).
    /// Defaults to `2.0`.
    #[serde(default = "default_output_cost")]
    pub default_output_cost: f64,

    /// Currency for cost values. Defaults to `"USD"`.
    #[serde(default = "default_currency")]
    pub currency: String,
}
138
/// Serde/`Default` fallback price per million input tokens.
fn default_input_cost() -> f64 {
    1.0_f64
}
142
/// Serde/`Default` fallback price per million output tokens.
fn default_output_cost() -> f64 {
    2.0_f64
}
146
/// Serde/`Default` fallback currency code for cost values.
fn default_currency() -> String {
    String::from("USD")
}
150
151impl Default for CostAttributionConfig {
152    fn default() -> Self {
153        Self {
154            enabled: false,
155            pricing: Vec::new(),
156            default_input_cost: default_input_cost(),
157            default_output_cost: default_output_cost(),
158            currency: default_currency(),
159        }
160    }
161}
162
/// Per-model pricing configuration.
///
/// The `model_pattern` supports glob-style matching with `*`:
/// - `gpt-4*` matches `gpt-4`, `gpt-4-turbo`, `gpt-4o`, etc.
/// - `claude-3-*` matches every model name starting with `claude-3-`
///
/// A pattern without `*` must equal the model name exactly; see
/// [`ModelPricing::matches`] for the precise rules.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelPricing {
    /// Model name or pattern (glob-style matching with `*`).
    pub model_pattern: String,

    /// Cost per million input tokens.
    pub input_cost_per_million: f64,

    /// Cost per million output tokens.
    pub output_cost_per_million: f64,

    /// Optional currency override (defaults to parent config currency).
    ///
    /// NOTE(review): [`ModelPricing::calculate_cost`] does not read this
    /// field; applying the override is left to the caller.
    #[serde(default)]
    pub currency: Option<String>,
}
183
184impl ModelPricing {
185    /// Create new model pricing with the given pattern and costs.
186    pub fn new(pattern: impl Into<String>, input_cost: f64, output_cost: f64) -> Self {
187        Self {
188            model_pattern: pattern.into(),
189            input_cost_per_million: input_cost,
190            output_cost_per_million: output_cost,
191            currency: None,
192        }
193    }
194
195    /// Check if this pricing rule matches the given model name.
196    pub fn matches(&self, model: &str) -> bool {
197        if self.model_pattern.contains('*') {
198            // Glob-style matching
199            let pattern = &self.model_pattern;
200            if pattern.starts_with('*') && pattern.ends_with('*') {
201                // *pattern* - contains
202                let inner = &pattern[1..pattern.len() - 1];
203                model.contains(inner)
204            } else if pattern.starts_with('*') {
205                // *pattern - ends with
206                model.ends_with(&pattern[1..])
207            } else if pattern.ends_with('*') {
208                // pattern* - starts with
209                model.starts_with(&pattern[..pattern.len() - 1])
210            } else {
211                // Complex pattern - split and match parts
212                let parts: Vec<&str> = pattern.split('*').collect();
213                if parts.is_empty() {
214                    return true;
215                }
216
217                let mut remaining = model;
218                for (i, part) in parts.iter().enumerate() {
219                    if part.is_empty() {
220                        continue;
221                    }
222                    if i == 0 {
223                        // First part must be prefix
224                        if !remaining.starts_with(part) {
225                            return false;
226                        }
227                        remaining = &remaining[part.len()..];
228                    } else if i == parts.len() - 1 {
229                        // Last part must be suffix
230                        if !remaining.ends_with(part) {
231                            return false;
232                        }
233                    } else {
234                        // Middle parts must exist
235                        if let Some(idx) = remaining.find(part) {
236                            remaining = &remaining[idx + part.len()..];
237                        } else {
238                            return false;
239                        }
240                    }
241                }
242                true
243            }
244        } else {
245            // Exact match
246            self.model_pattern == model
247        }
248    }
249
250    /// Calculate cost for the given token counts.
251    pub fn calculate_cost(&self, input_tokens: u64, output_tokens: u64) -> f64 {
252        let input_cost = (input_tokens as f64 / 1_000_000.0) * self.input_cost_per_million;
253        let output_cost = (output_tokens as f64 / 1_000_000.0) * self.output_cost_per_million;
254        input_cost + output_cost
255    }
256}
257
258// ============================================================================
259// Result Types
260// ============================================================================
261
/// Result of a budget check operation.
///
/// [`BudgetCheckResult::is_allowed`] treats both `Allowed` and `Soft` as
/// permitting the request; only `Exhausted` rejects it.
#[derive(Debug, Clone, PartialEq)]
pub enum BudgetCheckResult {
    /// Request is allowed within budget.
    Allowed {
        /// Tokens remaining after this request.
        remaining: u64,
    },
    /// Budget is exhausted; the request is not allowed.
    Exhausted {
        /// Seconds until the period resets (exposed via
        /// [`BudgetCheckResult::retry_after_secs`]).
        retry_after_secs: u64,
    },
    /// Request allowed via burst allowance (soft limit).
    Soft {
        /// Tokens remaining (negative means over budget).
        remaining: i64,
        /// Amount over the base limit.
        over_by: u64,
    },
}
283
284impl BudgetCheckResult {
285    /// Returns true if the request should be allowed.
286    pub fn is_allowed(&self) -> bool {
287        matches!(self, Self::Allowed { .. } | Self::Soft { .. })
288    }
289
290    /// Returns the retry-after value in seconds, or 0 if allowed.
291    pub fn retry_after_secs(&self) -> u64 {
292        match self {
293            Self::Exhausted { retry_after_secs } => *retry_after_secs,
294            _ => 0,
295        }
296    }
297}
298
/// Alert generated when budget threshold is crossed.
#[derive(Debug, Clone)]
pub struct BudgetAlert {
    /// Tenant/client identifier.
    pub tenant: String,
    /// Threshold that was crossed, as a fraction of the limit (e.g., 0.80 for
    /// 80%). Note that [`BudgetAlert::usage_percent`] reports on a 0–100 scale,
    /// so compare `usage_percent() / 100.0` against this value.
    pub threshold: f64,
    /// Current token usage.
    pub tokens_used: u64,
    /// Budget limit.
    pub tokens_limit: u64,
    /// Current period start time (Unix timestamp; presumably seconds —
    /// confirm with the producer).
    pub period_start: u64,
}
313
314impl BudgetAlert {
315    /// Get the usage percentage.
316    pub fn usage_percent(&self) -> f64 {
317        if self.tokens_limit == 0 {
318            return 0.0;
319        }
320        (self.tokens_used as f64 / self.tokens_limit as f64) * 100.0
321    }
322}
323
/// Current budget status for a tenant.
///
/// A plain snapshot with public fields: nothing in this type keeps the
/// fields consistent with one another, so producers are responsible for that.
#[derive(Debug, Clone)]
pub struct TenantBudgetStatus {
    /// Tokens used in current period.
    pub tokens_used: u64,
    /// Budget limit.
    pub tokens_limit: u64,
    /// Tokens remaining. Unsigned, so over-limit usage cannot be represented
    /// here (presumably saturates at 0 — confirm with the producer).
    pub tokens_remaining: u64,
    /// Usage percentage.
    /// NOTE(review): the scale is not fixed by this type; likely 0–100 like
    /// `BudgetAlert::usage_percent` — confirm with the producer.
    pub usage_percent: f64,
    /// Period start time (Unix timestamp).
    pub period_start: u64,
    /// Period end time (Unix timestamp).
    pub period_end: u64,
    /// Whether budget is exhausted.
    pub exhausted: bool,
}
342
/// Result of a cost calculation.
///
/// Prefer [`CostResult::new`], which derives `total_cost` from the input and
/// output costs; building the struct literally leaves that invariant to the
/// caller.
#[derive(Debug, Clone)]
pub struct CostResult {
    /// Cost for input tokens.
    pub input_cost: f64,
    /// Cost for output tokens.
    pub output_cost: f64,
    /// Total cost (input + output).
    pub total_cost: f64,
    /// Currency code (e.g. `"USD"`).
    pub currency: String,
    /// Model that was used.
    pub model: String,
    /// Number of input tokens.
    pub input_tokens: u64,
    /// Number of output tokens.
    pub output_tokens: u64,
}
361
362impl CostResult {
363    /// Create a new cost result.
364    pub fn new(
365        model: impl Into<String>,
366        input_tokens: u64,
367        output_tokens: u64,
368        input_cost: f64,
369        output_cost: f64,
370        currency: impl Into<String>,
371    ) -> Self {
372        Self {
373            input_cost,
374            output_cost,
375            total_cost: input_cost + output_cost,
376            currency: currency.into(),
377            model: model.into(),
378            input_tokens,
379            output_tokens,
380        }
381    }
382}
383
384// ============================================================================
385// Tests
386// ============================================================================
387
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_budget_period_as_secs() {
        assert_eq!(BudgetPeriod::Hourly.as_secs(), 3600);
        assert_eq!(BudgetPeriod::Daily.as_secs(), 86400);
        assert_eq!(BudgetPeriod::Monthly.as_secs(), 2_592_000);
        assert_eq!(BudgetPeriod::Custom { seconds: 7200 }.as_secs(), 7200);
    }

    #[test]
    fn test_budget_period_default_is_daily() {
        assert_eq!(BudgetPeriod::default(), BudgetPeriod::Daily);
    }

    #[test]
    fn test_model_pricing_exact_match() {
        let pricing = ModelPricing::new("gpt-4", 30.0, 60.0);
        assert!(pricing.matches("gpt-4"));
        assert!(!pricing.matches("gpt-4-turbo"));
        assert!(!pricing.matches("gpt-3.5"));
    }

    #[test]
    fn test_model_pricing_prefix_match() {
        let pricing = ModelPricing::new("gpt-4*", 30.0, 60.0);
        assert!(pricing.matches("gpt-4"));
        assert!(pricing.matches("gpt-4-turbo"));
        assert!(pricing.matches("gpt-4o"));
        assert!(!pricing.matches("gpt-3.5"));
    }

    #[test]
    fn test_model_pricing_suffix_match() {
        let pricing = ModelPricing::new("*-turbo", 30.0, 60.0);
        assert!(pricing.matches("gpt-4-turbo"));
        assert!(pricing.matches("gpt-3.5-turbo"));
        assert!(!pricing.matches("gpt-4"));
    }

    #[test]
    fn test_model_pricing_contains_match() {
        let pricing = ModelPricing::new("*claude*", 30.0, 60.0);
        assert!(pricing.matches("claude-3"));
        assert!(pricing.matches("anthropic-claude-3-opus"));
        assert!(!pricing.matches("gpt-4"));
    }

    #[test]
    fn test_model_pricing_middle_wildcard() {
        // A `*` in the middle must match any run while both ends stay anchored.
        let pricing = ModelPricing::new("claude-*-opus", 15.0, 75.0);
        assert!(pricing.matches("claude-3-opus"));
        assert!(!pricing.matches("claude-3-sonnet"));
        assert!(!pricing.matches("claude-opus"));
    }

    #[test]
    fn test_model_pricing_new_has_no_currency_override() {
        assert!(ModelPricing::new("gpt-4", 30.0, 60.0).currency.is_none());
    }

    #[test]
    fn test_model_pricing_calculate_cost() {
        let pricing = ModelPricing::new("gpt-4", 30.0, 60.0);

        // 1M input tokens = $30, 1M output tokens = $60
        let cost = pricing.calculate_cost(1_000_000, 1_000_000);
        assert!((cost - 90.0).abs() < 0.001);

        // 1000 input tokens, 500 output tokens
        let cost = pricing.calculate_cost(1000, 500);
        let expected = (1000.0 / 1_000_000.0) * 30.0 + (500.0 / 1_000_000.0) * 60.0;
        assert!((cost - expected).abs() < 0.0001);
    }

    #[test]
    fn test_model_pricing_zero_tokens_cost_nothing() {
        let pricing = ModelPricing::new("gpt-4", 30.0, 60.0);
        assert!(pricing.calculate_cost(0, 0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_budget_check_result_is_allowed() {
        assert!(BudgetCheckResult::Allowed { remaining: 1000 }.is_allowed());
        assert!(BudgetCheckResult::Soft { remaining: -100, over_by: 100 }.is_allowed());
        assert!(!BudgetCheckResult::Exhausted { retry_after_secs: 3600 }.is_allowed());
    }

    #[test]
    fn test_budget_check_result_retry_after_secs() {
        assert_eq!(
            BudgetCheckResult::Exhausted { retry_after_secs: 3600 }.retry_after_secs(),
            3600
        );
        assert_eq!(BudgetCheckResult::Allowed { remaining: 10 }.retry_after_secs(), 0);
        assert_eq!(
            BudgetCheckResult::Soft { remaining: -1, over_by: 1 }.retry_after_secs(),
            0
        );
    }

    #[test]
    fn test_budget_alert_usage_percent() {
        let alert = BudgetAlert {
            tenant: "test".to_string(),
            threshold: 0.80,
            tokens_used: 800_000,
            tokens_limit: 1_000_000,
            period_start: 0,
        };
        assert!((alert.usage_percent() - 80.0).abs() < 0.001);
    }

    #[test]
    fn test_budget_alert_usage_percent_zero_limit() {
        // A zero limit must not divide by zero.
        let alert = BudgetAlert {
            tenant: "test".to_string(),
            threshold: 0.80,
            tokens_used: 500,
            tokens_limit: 0,
            period_start: 0,
        };
        assert!(alert.usage_percent().abs() < f64::EPSILON);
    }

    #[test]
    fn test_cost_result_new() {
        let result = CostResult::new("gpt-4", 1000, 500, 0.03, 0.03, "USD");
        assert_eq!(result.model, "gpt-4");
        assert_eq!(result.input_tokens, 1000);
        assert_eq!(result.output_tokens, 500);
        assert!((result.total_cost - 0.06).abs() < 0.001);
    }

    #[test]
    fn test_token_budget_config_default() {
        let config = TokenBudgetConfig::default();
        assert_eq!(config.period, BudgetPeriod::Daily);
        assert_eq!(config.limit, 1_000_000);
        assert!(config.enforce);
        assert!(!config.rollover);
        assert!(config.burst_allowance.is_none());
        assert_eq!(config.alert_thresholds, vec![0.80, 0.90, 0.95]);
    }

    #[test]
    fn test_cost_attribution_config_default() {
        let config = CostAttributionConfig::default();
        assert!(!config.enabled);
        assert!(config.pricing.is_empty());
        assert!((config.default_input_cost - 1.0).abs() < 0.001);
        assert!((config.default_output_cost - 2.0).abs() < 0.001);
        assert_eq!(config.currency, "USD");
    }
}