synaptic_callbacks/
cost_tracking.rs1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use synaptic_core::{CallbackHandler, RunEvent, SynapticError, TokenUsage};
6use tokio::sync::RwLock;
7
/// Token prices for one model, in US dollars per one million tokens.
///
/// [`CostTrackingCallback`] turns recorded token counts into an estimated cost as
/// `tokens / 1_000_000 * price`, separately for input and output tokens.
#[derive(Debug, Clone, PartialEq)]
pub struct ModelPricing {
    /// Price (USD) per one million input tokens.
    pub input_per_million: f64,
    /// Price (USD) per one million output tokens.
    pub output_per_million: f64,
}
16
17impl ModelPricing {
18 pub fn new(input_per_million: f64, output_per_million: f64) -> Self {
19 Self {
20 input_per_million,
21 output_per_million,
22 }
23 }
24}
25
26#[derive(Debug, Clone, Default)]
28pub struct UsageSnapshot {
29 pub total_input_tokens: u64,
30 pub total_output_tokens: u64,
31 pub total_requests: u64,
32 pub estimated_cost_usd: f64,
33 pub per_model: HashMap<String, ModelUsage>,
35}
36
/// Usage and estimated cost accumulated for a single model.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct ModelUsage {
    /// Input tokens recorded for this model.
    pub input_tokens: u64,
    /// Output tokens recorded for this model.
    pub output_tokens: u64,
    /// Number of recorded calls for this model.
    pub requests: u64,
    /// Estimated cost in USD. Stays `0.0` when the model has no pricing entry.
    pub cost_usd: f64,
}
45
46struct CostState {
48 usage: UsageSnapshot,
49 pricing: HashMap<String, ModelPricing>,
50 budget_limit: Option<f64>,
51 current_model: String,
52}
53
54pub struct CostTrackingCallback {
59 state: Arc<RwLock<CostState>>,
60}
61
62impl CostTrackingCallback {
63 pub fn new(pricing: HashMap<String, ModelPricing>) -> Self {
65 Self {
66 state: Arc::new(RwLock::new(CostState {
67 usage: UsageSnapshot::default(),
68 pricing,
69 budget_limit: None,
70 current_model: String::new(),
71 })),
72 }
73 }
74
75 pub fn with_budget(self, limit_usd: f64) -> Self {
77 let state = self.state.clone();
79 tokio::spawn(async move {
80 state.write().await.budget_limit = Some(limit_usd);
81 });
82 self
83 }
84
85 pub async fn set_model(&self, model_name: &str) {
87 self.state.write().await.current_model = model_name.to_string();
88 }
89
90 pub async fn record_usage(&self, usage: &TokenUsage) {
92 let mut state = self.state.write().await;
93 let model = state.current_model.clone();
94
95 let cost = state.pricing.get(&model).map(|pricing| {
97 (usage.input_tokens as f64 / 1_000_000.0) * pricing.input_per_million
98 + (usage.output_tokens as f64 / 1_000_000.0) * pricing.output_per_million
99 });
100
101 state.usage.total_input_tokens += usage.input_tokens as u64;
102 state.usage.total_output_tokens += usage.output_tokens as u64;
103 state.usage.total_requests += 1;
104
105 let entry = state.usage.per_model.entry(model).or_default();
106 entry.input_tokens += usage.input_tokens as u64;
107 entry.output_tokens += usage.output_tokens as u64;
108 entry.requests += 1;
109
110 if let Some(cost) = cost {
111 entry.cost_usd += cost;
112 state.usage.estimated_cost_usd += cost;
113 }
114 }
115
116 pub async fn snapshot(&self) -> UsageSnapshot {
118 self.state.read().await.usage.clone()
119 }
120
121 pub async fn is_over_budget(&self) -> bool {
123 let state = self.state.read().await;
124 if let Some(limit) = state.budget_limit {
125 state.usage.estimated_cost_usd > limit
126 } else {
127 false
128 }
129 }
130}
131
132pub fn default_pricing() -> HashMap<String, ModelPricing> {
134 let mut m = HashMap::new();
135 m.insert("gpt-4o".to_string(), ModelPricing::new(2.5, 10.0));
137 m.insert("gpt-4o-mini".to_string(), ModelPricing::new(0.15, 0.6));
138 m.insert("o1".to_string(), ModelPricing::new(15.0, 60.0));
139 m.insert("o3-mini".to_string(), ModelPricing::new(1.1, 4.4));
140 m.insert(
142 "claude-sonnet-4-20250514".to_string(),
143 ModelPricing::new(3.0, 15.0),
144 );
145 m.insert(
146 "claude-haiku-4-5-20251001".to_string(),
147 ModelPricing::new(0.8, 4.0),
148 );
149 m.insert(
150 "claude-opus-4-20250514".to_string(),
151 ModelPricing::new(15.0, 75.0),
152 );
153 m.insert("gemini-2.0-flash".to_string(), ModelPricing::new(0.1, 0.4));
155 m.insert("gemini-2.0-pro".to_string(), ModelPricing::new(1.25, 10.0));
156 m
157}
158
159#[async_trait]
160impl CallbackHandler for CostTrackingCallback {
161 async fn on_event(&self, event: RunEvent) -> Result<(), SynapticError> {
162 if let RunEvent::LlmCalled { .. } = event {
163 }
166 Ok(())
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing dollar amounts computed in floating point.
    const EPS: f64 = 1e-12;

    #[tokio::test]
    async fn tracks_usage() {
        let tracker = CostTrackingCallback::new(default_pricing());
        tracker.set_model("gpt-4o").await;

        let usage = TokenUsage {
            input_tokens: 1000,
            output_tokens: 500,
            total_tokens: 1500,
            input_details: None,
            output_details: None,
        };
        tracker.record_usage(&usage).await;

        let snap = tracker.snapshot().await;
        assert_eq!(snap.total_input_tokens, 1000);
        assert_eq!(snap.total_output_tokens, 500);
        assert_eq!(snap.total_requests, 1);
        // gpt-4o: 1000 * $2.50/M + 500 * $10/M = $0.0025 + $0.005
        assert!((snap.estimated_cost_usd - 0.0075).abs() < EPS);
    }

    #[tokio::test]
    async fn per_model_breakdown() {
        let tracker = CostTrackingCallback::new(default_pricing());

        tracker.set_model("gpt-4o").await;
        tracker
            .record_usage(&TokenUsage {
                input_tokens: 100,
                output_tokens: 50,
                total_tokens: 150,
                input_details: None,
                output_details: None,
            })
            .await;

        tracker.set_model("gpt-4o-mini").await;
        tracker
            .record_usage(&TokenUsage {
                input_tokens: 200,
                output_tokens: 100,
                total_tokens: 300,
                input_details: None,
                output_details: None,
            })
            .await;

        let snap = tracker.snapshot().await;
        assert_eq!(snap.per_model.len(), 2);
        assert_eq!(snap.total_requests, 2);
        assert_eq!(snap.total_input_tokens, 300);
        assert_eq!(snap.total_output_tokens, 150);

        let big = &snap.per_model["gpt-4o"];
        assert_eq!(big.input_tokens, 100);
        assert_eq!(big.output_tokens, 50);
        assert_eq!(big.requests, 1);
        // 100 * $2.50/M + 50 * $10/M
        assert!((big.cost_usd - 0.00075).abs() < EPS);

        let mini = &snap.per_model["gpt-4o-mini"];
        assert_eq!(mini.input_tokens, 200);
        assert_eq!(mini.output_tokens, 100);
        assert_eq!(mini.requests, 1);
        // 200 * $0.15/M + 100 * $0.60/M
        assert!((mini.cost_usd - 0.00009).abs() < EPS);

        assert!((snap.estimated_cost_usd - 0.00084).abs() < EPS);
    }

    #[tokio::test]
    async fn unknown_model_counts_tokens_without_cost() {
        let tracker = CostTrackingCallback::new(default_pricing());
        tracker.set_model("no-such-model").await;
        tracker
            .record_usage(&TokenUsage {
                input_tokens: 10,
                output_tokens: 20,
                total_tokens: 30,
                input_details: None,
                output_details: None,
            })
            .await;

        let snap = tracker.snapshot().await;
        assert_eq!(snap.total_input_tokens, 10);
        assert_eq!(snap.total_output_tokens, 20);
        assert_eq!(snap.total_requests, 1);
        assert_eq!(snap.estimated_cost_usd, 0.0);

        let entry = &snap.per_model["no-such-model"];
        assert_eq!(entry.requests, 1);
        assert_eq!(entry.cost_usd, 0.0);
    }

    #[tokio::test]
    async fn without_budget_never_over() {
        let tracker = CostTrackingCallback::new(default_pricing());
        tracker.set_model("o1").await;
        tracker
            .record_usage(&TokenUsage {
                input_tokens: 1_000_000,
                output_tokens: 1_000_000,
                total_tokens: 2_000_000,
                input_details: None,
                output_details: None,
            })
            .await;

        assert!(!tracker.is_over_budget().await);
    }

    #[test]
    fn default_pricing_has_models() {
        let p = default_pricing();
        assert!(p.contains_key("gpt-4o"));
        assert!(p.contains_key("claude-sonnet-4-20250514"));
        assert!(p.contains_key("claude-haiku-4-5-20251001"));
        assert!(p.contains_key("gemini-2.0-flash"));
    }
}