Skip to main content

synaptic_callbacks/
cost_tracking.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use synaptic_core::{CallbackHandler, RunEvent, SynapticError, TokenUsage};
6use tokio::sync::RwLock;
7
/// Pricing for one model, in US dollars per one million tokens.
///
/// Input (prompt) and output (completion) tokens carry separate rates.
#[derive(Clone, Debug)]
pub struct ModelPricing {
    /// USD charged per 1,000,000 input/prompt tokens.
    pub input_per_million: f64,
    /// USD charged per 1,000,000 output/completion tokens.
    pub output_per_million: f64,
}

impl ModelPricing {
    /// Build a pricing entry from its input and output rates (USD per 1M tokens).
    pub fn new(input_per_million: f64, output_per_million: f64) -> Self {
        ModelPricing {
            input_per_million,
            output_per_million,
        }
    }
}
25
/// Point-in-time copy of the tracker's running totals.
///
/// `estimated_cost_usd` only reflects usage recorded for models that had an
/// entry in the pricing table; unpriced models still count toward the token
/// and request totals.
#[derive(Clone, Debug, Default)]
pub struct UsageSnapshot {
    /// Sum of input/prompt tokens over every recorded call.
    pub total_input_tokens: u64,
    /// Sum of output/completion tokens over every recorded call.
    pub total_output_tokens: u64,
    /// Number of recorded calls.
    pub total_requests: u64,
    /// Estimated spend in USD across all priced models.
    pub estimated_cost_usd: f64,
    /// The same figures broken down by model name.
    pub per_model: HashMap<String, ModelUsage>,
}

/// Usage and estimated spend attributed to a single model.
#[derive(Clone, Debug, Default)]
pub struct ModelUsage {
    /// Input/prompt tokens recorded for this model.
    pub input_tokens: u64,
    /// Output/completion tokens recorded for this model.
    pub output_tokens: u64,
    /// Number of calls recorded for this model.
    pub requests: u64,
    /// Estimated spend in USD for this model (stays 0.0 if it is unpriced).
    pub cost_usd: f64,
}
45
46/// Internal state.
47struct CostState {
48    usage: UsageSnapshot,
49    pricing: HashMap<String, ModelPricing>,
50    budget_limit: Option<f64>,
51    current_model: String,
52}
53
54/// Callback handler that tracks token usage and estimated cost across model calls.
55///
56/// Supports per-model pricing tables and optional budget limits. Query the
57/// accumulated snapshot via [`snapshot()`](CostTrackingCallback::snapshot).
58pub struct CostTrackingCallback {
59    state: Arc<RwLock<CostState>>,
60}
61
62impl CostTrackingCallback {
63    /// Create a new cost tracker with the given pricing table.
64    pub fn new(pricing: HashMap<String, ModelPricing>) -> Self {
65        Self {
66            state: Arc::new(RwLock::new(CostState {
67                usage: UsageSnapshot::default(),
68                pricing,
69                budget_limit: None,
70                current_model: String::new(),
71            })),
72        }
73    }
74
75    /// Set a budget limit in USD. Returns error via callback when exceeded.
76    pub fn with_budget(self, limit_usd: f64) -> Self {
77        // We'll set it after creation since we can't await in a non-async fn
78        let state = self.state.clone();
79        tokio::spawn(async move {
80            state.write().await.budget_limit = Some(limit_usd);
81        });
82        self
83    }
84
85    /// Set the current model name for cost attribution.
86    pub async fn set_model(&self, model_name: &str) {
87        self.state.write().await.current_model = model_name.to_string();
88    }
89
90    /// Record token usage from a model response.
91    pub async fn record_usage(&self, usage: &TokenUsage) {
92        let mut state = self.state.write().await;
93        let model = state.current_model.clone();
94
95        // Look up pricing before mutating per_model
96        let cost = state.pricing.get(&model).map(|pricing| {
97            (usage.input_tokens as f64 / 1_000_000.0) * pricing.input_per_million
98                + (usage.output_tokens as f64 / 1_000_000.0) * pricing.output_per_million
99        });
100
101        state.usage.total_input_tokens += usage.input_tokens as u64;
102        state.usage.total_output_tokens += usage.output_tokens as u64;
103        state.usage.total_requests += 1;
104
105        let entry = state.usage.per_model.entry(model).or_default();
106        entry.input_tokens += usage.input_tokens as u64;
107        entry.output_tokens += usage.output_tokens as u64;
108        entry.requests += 1;
109
110        if let Some(cost) = cost {
111            entry.cost_usd += cost;
112            state.usage.estimated_cost_usd += cost;
113        }
114    }
115
116    /// Get a snapshot of accumulated usage and costs.
117    pub async fn snapshot(&self) -> UsageSnapshot {
118        self.state.read().await.usage.clone()
119    }
120
121    /// Check if the budget has been exceeded.
122    pub async fn is_over_budget(&self) -> bool {
123        let state = self.state.read().await;
124        if let Some(limit) = state.budget_limit {
125            state.usage.estimated_cost_usd > limit
126        } else {
127            false
128        }
129    }
130}
131
132/// Build a default pricing table for common models (approximate, Feb 2026).
133pub fn default_pricing() -> HashMap<String, ModelPricing> {
134    let mut m = HashMap::new();
135    // OpenAI
136    m.insert("gpt-4o".to_string(), ModelPricing::new(2.5, 10.0));
137    m.insert("gpt-4o-mini".to_string(), ModelPricing::new(0.15, 0.6));
138    m.insert("o1".to_string(), ModelPricing::new(15.0, 60.0));
139    m.insert("o3-mini".to_string(), ModelPricing::new(1.1, 4.4));
140    // Anthropic
141    m.insert(
142        "claude-sonnet-4-20250514".to_string(),
143        ModelPricing::new(3.0, 15.0),
144    );
145    m.insert(
146        "claude-haiku-4-5-20251001".to_string(),
147        ModelPricing::new(0.8, 4.0),
148    );
149    m.insert(
150        "claude-opus-4-20250514".to_string(),
151        ModelPricing::new(15.0, 75.0),
152    );
153    // Gemini
154    m.insert("gemini-2.0-flash".to_string(), ModelPricing::new(0.1, 0.4));
155    m.insert("gemini-2.0-pro".to_string(), ModelPricing::new(1.25, 10.0));
156    m
157}
158
159#[async_trait]
160impl CallbackHandler for CostTrackingCallback {
161    async fn on_event(&self, event: RunEvent) -> Result<(), SynapticError> {
162        if let RunEvent::LlmCalled { .. } = event {
163            // Cost is tracked via record_usage() which is called externally
164            // when the actual TokenUsage is available from the response.
165        }
166        Ok(())
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// Compare two dollar amounts, allowing for floating-point rounding.
    fn approx_eq(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 1e-12
    }

    #[tokio::test]
    async fn tracks_usage() {
        let pricing = default_pricing();
        let tracker = CostTrackingCallback::new(pricing);
        tracker.set_model("gpt-4o").await;

        let usage = TokenUsage {
            input_tokens: 1000,
            output_tokens: 500,
            total_tokens: 1500,
            input_details: None,
            output_details: None,
        };
        tracker.record_usage(&usage).await;

        let snap = tracker.snapshot().await;
        assert_eq!(snap.total_input_tokens, 1000);
        assert_eq!(snap.total_output_tokens, 500);
        assert_eq!(snap.total_requests, 1);
        // gpt-4o: 1000 * $2.50/M + 500 * $10/M = $0.0075
        assert!(
            approx_eq(snap.estimated_cost_usd, 0.0075),
            "got {}",
            snap.estimated_cost_usd
        );
    }

    #[tokio::test]
    async fn per_model_breakdown() {
        let pricing = default_pricing();
        let tracker = CostTrackingCallback::new(pricing);

        tracker.set_model("gpt-4o").await;
        tracker
            .record_usage(&TokenUsage {
                input_tokens: 100,
                output_tokens: 50,
                total_tokens: 0,
                input_details: None,
                output_details: None,
            })
            .await;

        tracker.set_model("gpt-4o-mini").await;
        tracker
            .record_usage(&TokenUsage {
                input_tokens: 200,
                output_tokens: 100,
                total_tokens: 0,
                input_details: None,
                output_details: None,
            })
            .await;

        let snap = tracker.snapshot().await;
        assert_eq!(snap.per_model.len(), 2);
        assert_eq!(snap.total_requests, 2);

        let gpt4o = &snap.per_model["gpt-4o"];
        assert_eq!(
            (gpt4o.input_tokens, gpt4o.output_tokens, gpt4o.requests),
            (100, 50, 1)
        );
        // 100 * $2.50/M + 50 * $10/M
        assert!(approx_eq(gpt4o.cost_usd, 0.00075), "got {}", gpt4o.cost_usd);

        let mini = &snap.per_model["gpt-4o-mini"];
        assert_eq!(
            (mini.input_tokens, mini.output_tokens, mini.requests),
            (200, 100, 1)
        );
        // 200 * $0.15/M + 100 * $0.60/M
        assert!(approx_eq(mini.cost_usd, 0.00009), "got {}", mini.cost_usd);

        assert!(
            approx_eq(snap.estimated_cost_usd, 0.00084),
            "got {}",
            snap.estimated_cost_usd
        );
    }

    #[tokio::test]
    async fn unpriced_model_counts_tokens_but_not_cost() {
        let tracker = CostTrackingCallback::new(default_pricing());
        tracker.set_model("not-a-real-model").await;
        tracker
            .record_usage(&TokenUsage {
                input_tokens: 10,
                output_tokens: 5,
                total_tokens: 15,
                input_details: None,
                output_details: None,
            })
            .await;

        let snap = tracker.snapshot().await;
        assert_eq!(snap.total_input_tokens, 10);
        assert_eq!(snap.total_output_tokens, 5);
        assert_eq!(snap.total_requests, 1);
        assert_eq!(snap.estimated_cost_usd, 0.0);
        let entry = &snap.per_model["not-a-real-model"];
        assert_eq!(entry.requests, 1);
        assert_eq!(entry.cost_usd, 0.0);
    }

    #[tokio::test]
    async fn no_budget_is_never_over_budget() {
        let tracker = CostTrackingCallback::new(default_pricing());
        tracker.set_model("gpt-4o").await;
        tracker
            .record_usage(&TokenUsage {
                input_tokens: 1_000_000,
                output_tokens: 1_000_000,
                total_tokens: 2_000_000,
                input_details: None,
                output_details: None,
            })
            .await;
        assert!(!tracker.is_over_budget().await);
    }

    #[test]
    fn default_pricing_has_models() {
        let p = default_pricing();
        assert!(p.contains_key("gpt-4o"));
        assert!(p.contains_key("claude-sonnet-4-20250514"));
    }
}