1use crate::error::StorageError;
4
5use super::accounts::DEFAULT_ACCOUNT_ID;
6use super::DbPool;
7
8#[derive(Debug, serde::Serialize)]
10pub struct CostSummary {
11 pub cost_today: f64,
12 pub cost_7d: f64,
13 pub cost_30d: f64,
14 pub cost_all_time: f64,
15 pub calls_today: i64,
16 pub calls_7d: i64,
17 pub calls_30d: i64,
18 pub calls_all_time: i64,
19}
20
21#[derive(Debug, serde::Serialize)]
23pub struct DailyCostSummary {
24 pub date: String,
25 pub cost: f64,
26 pub calls: i64,
27 pub input_tokens: i64,
28 pub output_tokens: i64,
29}
30
31#[derive(Debug, serde::Serialize)]
33pub struct ModelCostBreakdown {
34 pub provider: String,
35 pub model: String,
36 pub cost: f64,
37 pub calls: i64,
38 pub input_tokens: i64,
39 pub output_tokens: i64,
40}
41
42#[derive(Debug, serde::Serialize)]
44pub struct TypeCostBreakdown {
45 pub generation_type: String,
46 pub cost: f64,
47 pub calls: i64,
48 pub avg_cost: f64,
49}
50
51#[allow(clippy::too_many_arguments)]
53pub async fn insert_llm_usage_for(
54 pool: &DbPool,
55 account_id: &str,
56 generation_type: &str,
57 provider: &str,
58 model: &str,
59 input_tokens: u32,
60 output_tokens: u32,
61 cost_usd: f64,
62) -> Result<(), StorageError> {
63 sqlx::query(
64 "INSERT INTO llm_usage (account_id, generation_type, provider, model, input_tokens, output_tokens, cost_usd)
65 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
66 )
67 .bind(account_id)
68 .bind(generation_type)
69 .bind(provider)
70 .bind(model)
71 .bind(input_tokens)
72 .bind(output_tokens)
73 .bind(cost_usd)
74 .execute(pool)
75 .await
76 .map_err(|e| StorageError::Query { source: e })?;
77 Ok(())
78}
79
80pub async fn insert_llm_usage(
82 pool: &DbPool,
83 generation_type: &str,
84 provider: &str,
85 model: &str,
86 input_tokens: u32,
87 output_tokens: u32,
88 cost_usd: f64,
89) -> Result<(), StorageError> {
90 insert_llm_usage_for(
91 pool,
92 DEFAULT_ACCOUNT_ID,
93 generation_type,
94 provider,
95 model,
96 input_tokens,
97 output_tokens,
98 cost_usd,
99 )
100 .await
101}
102
103pub async fn get_cost_summary_for(
105 pool: &DbPool,
106 account_id: &str,
107) -> Result<CostSummary, StorageError> {
108 let row: (f64, i64, f64, i64, f64, i64, f64, i64) = sqlx::query_as(
109 "SELECT
110 COALESCE(SUM(CASE WHEN created_at >= date('now') THEN cost_usd ELSE 0.0 END), 0.0),
111 COALESCE(SUM(CASE WHEN created_at >= date('now') THEN 1 ELSE 0 END), 0),
112 COALESCE(SUM(CASE WHEN created_at >= date('now', '-7 days') THEN cost_usd ELSE 0.0 END), 0.0),
113 COALESCE(SUM(CASE WHEN created_at >= date('now', '-7 days') THEN 1 ELSE 0 END), 0),
114 COALESCE(SUM(CASE WHEN created_at >= date('now', '-30 days') THEN cost_usd ELSE 0.0 END), 0.0),
115 COALESCE(SUM(CASE WHEN created_at >= date('now', '-30 days') THEN 1 ELSE 0 END), 0),
116 COALESCE(SUM(cost_usd), 0.0),
117 COUNT(*)
118 FROM llm_usage
119 WHERE account_id = ?",
120 )
121 .bind(account_id)
122 .fetch_one(pool)
123 .await
124 .map_err(|e| StorageError::Query { source: e })?;
125
126 Ok(CostSummary {
127 cost_today: row.0,
128 calls_today: row.1,
129 cost_7d: row.2,
130 calls_7d: row.3,
131 cost_30d: row.4,
132 calls_30d: row.5,
133 cost_all_time: row.6,
134 calls_all_time: row.7,
135 })
136}
137
138pub async fn get_cost_summary(pool: &DbPool) -> Result<CostSummary, StorageError> {
140 get_cost_summary_for(pool, DEFAULT_ACCOUNT_ID).await
141}
142
143pub async fn get_daily_costs_for(
145 pool: &DbPool,
146 account_id: &str,
147 days: u32,
148) -> Result<Vec<DailyCostSummary>, StorageError> {
149 let rows: Vec<(String, f64, i64, i64, i64)> = sqlx::query_as(
150 "SELECT
151 date(created_at) as day,
152 COALESCE(SUM(cost_usd), 0.0),
153 COUNT(*),
154 COALESCE(SUM(input_tokens), 0),
155 COALESCE(SUM(output_tokens), 0)
156 FROM llm_usage
157 WHERE account_id = ? AND created_at >= date('now', '-' || ? || ' days')
158 GROUP BY day
159 ORDER BY day",
160 )
161 .bind(account_id)
162 .bind(days)
163 .fetch_all(pool)
164 .await
165 .map_err(|e| StorageError::Query { source: e })?;
166
167 Ok(rows
168 .into_iter()
169 .map(
170 |(date, cost, calls, input_tokens, output_tokens)| DailyCostSummary {
171 date,
172 cost,
173 calls,
174 input_tokens,
175 output_tokens,
176 },
177 )
178 .collect())
179}
180
181pub async fn get_daily_costs(
183 pool: &DbPool,
184 days: u32,
185) -> Result<Vec<DailyCostSummary>, StorageError> {
186 get_daily_costs_for(pool, DEFAULT_ACCOUNT_ID, days).await
187}
188
189pub async fn get_model_breakdown_for(
191 pool: &DbPool,
192 account_id: &str,
193 days: u32,
194) -> Result<Vec<ModelCostBreakdown>, StorageError> {
195 let rows: Vec<(String, String, f64, i64, i64, i64)> = sqlx::query_as(
196 "SELECT
197 provider,
198 model,
199 COALESCE(SUM(cost_usd), 0.0),
200 COUNT(*),
201 COALESCE(SUM(input_tokens), 0),
202 COALESCE(SUM(output_tokens), 0)
203 FROM llm_usage
204 WHERE account_id = ? AND created_at >= date('now', '-' || ? || ' days')
205 GROUP BY provider, model
206 ORDER BY SUM(cost_usd) DESC",
207 )
208 .bind(account_id)
209 .bind(days)
210 .fetch_all(pool)
211 .await
212 .map_err(|e| StorageError::Query { source: e })?;
213
214 Ok(rows
215 .into_iter()
216 .map(
217 |(provider, model, cost, calls, input_tokens, output_tokens)| ModelCostBreakdown {
218 provider,
219 model,
220 cost,
221 calls,
222 input_tokens,
223 output_tokens,
224 },
225 )
226 .collect())
227}
228
229pub async fn get_model_breakdown(
231 pool: &DbPool,
232 days: u32,
233) -> Result<Vec<ModelCostBreakdown>, StorageError> {
234 get_model_breakdown_for(pool, DEFAULT_ACCOUNT_ID, days).await
235}
236
237pub async fn get_type_breakdown_for(
239 pool: &DbPool,
240 account_id: &str,
241 days: u32,
242) -> Result<Vec<TypeCostBreakdown>, StorageError> {
243 let rows: Vec<(String, f64, i64)> = sqlx::query_as(
244 "SELECT
245 generation_type,
246 COALESCE(SUM(cost_usd), 0.0),
247 COUNT(*)
248 FROM llm_usage
249 WHERE account_id = ? AND created_at >= date('now', '-' || ? || ' days')
250 GROUP BY generation_type
251 ORDER BY SUM(cost_usd) DESC",
252 )
253 .bind(account_id)
254 .bind(days)
255 .fetch_all(pool)
256 .await
257 .map_err(|e| StorageError::Query { source: e })?;
258
259 Ok(rows
260 .into_iter()
261 .map(|(generation_type, cost, calls)| {
262 let avg_cost = if calls > 0 { cost / calls as f64 } else { 0.0 };
263 TypeCostBreakdown {
264 generation_type,
265 cost,
266 calls,
267 avg_cost,
268 }
269 })
270 .collect())
271}
272
273pub async fn get_type_breakdown(
275 pool: &DbPool,
276 days: u32,
277) -> Result<Vec<TypeCostBreakdown>, StorageError> {
278 get_type_breakdown_for(pool, DEFAULT_ACCOUNT_ID, days).await
279}
280
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::init_test_db;

    // Two inserts must both show up in the all-time totals.
    #[tokio::test]
    async fn insert_and_query_summary() {
        let pool = init_test_db().await.expect("init db");

        let usages = [
            ("reply", 100_u32, 50_u32, 0.000045_f64),
            ("tweet", 200_u32, 80_u32, 0.000063_f64),
        ];
        for (generation_type, input_tokens, output_tokens, cost_usd) in usages {
            insert_llm_usage(
                &pool,
                generation_type,
                "openai",
                "gpt-4o-mini",
                input_tokens,
                output_tokens,
                cost_usd,
            )
            .await
            .expect("insert");
        }

        let summary = get_cost_summary(&pool).await.expect("summary");
        assert_eq!(summary.calls_all_time, 2);
        assert!(summary.cost_all_time > 0.0);
    }

    // Three rows across two distinct (provider, model) pairs collapse to two groups.
    #[tokio::test]
    async fn model_breakdown_groups_correctly() {
        let pool = init_test_db().await.expect("init db");

        let usages = [
            ("openai", "gpt-4o", 0.001_f64),
            ("openai", "gpt-4o", 0.001_f64),
            ("anthropic", "claude-sonnet", 0.002_f64),
        ];
        for (provider, model, cost_usd) in usages {
            insert_llm_usage(&pool, "reply", provider, model, 100, 50, cost_usd)
                .await
                .expect("insert");
        }

        let breakdown = get_model_breakdown(&pool, 30).await.expect("breakdown");
        assert_eq!(breakdown.len(), 2);
    }

    // Three distinct generation types yield three groups.
    #[tokio::test]
    async fn type_breakdown_groups_correctly() {
        let pool = init_test_db().await.expect("init db");

        for generation_type in ["reply", "tweet", "thread"] {
            insert_llm_usage(&pool, generation_type, "openai", "gpt-4o", 100, 50, 0.001)
                .await
                .expect("insert");
        }

        let breakdown = get_type_breakdown(&pool, 30).await.expect("breakdown");
        assert_eq!(breakdown.len(), 3);
    }

    // With no rows at all, the summary must report zero calls and zero cost.
    #[tokio::test]
    async fn empty_table_returns_zero_summary() {
        let pool = init_test_db().await.expect("init db");

        let summary = get_cost_summary(&pool).await.expect("summary");
        assert_eq!(summary.calls_all_time, 0);
        assert!(summary.cost_all_time.abs() < f64::EPSILON);
    }
}