Skip to main content

tuitbot_core/storage/
llm_usage.rs

1//! LLM usage tracking — stores per-call token counts and costs.
2
3use crate::error::StorageError;
4
5use super::accounts::DEFAULT_ACCOUNT_ID;
6use super::DbPool;
7
8/// Summary of costs across multiple time windows.
9#[derive(Debug, serde::Serialize)]
10pub struct CostSummary {
11    pub cost_today: f64,
12    pub cost_7d: f64,
13    pub cost_30d: f64,
14    pub cost_all_time: f64,
15    pub calls_today: i64,
16    pub calls_7d: i64,
17    pub calls_30d: i64,
18    pub calls_all_time: i64,
19}
20
21/// Daily cost aggregation for chart data.
22#[derive(Debug, serde::Serialize)]
23pub struct DailyCostSummary {
24    pub date: String,
25    pub cost: f64,
26    pub calls: i64,
27    pub input_tokens: i64,
28    pub output_tokens: i64,
29}
30
31/// Cost breakdown by provider + model.
32#[derive(Debug, serde::Serialize)]
33pub struct ModelCostBreakdown {
34    pub provider: String,
35    pub model: String,
36    pub cost: f64,
37    pub calls: i64,
38    pub input_tokens: i64,
39    pub output_tokens: i64,
40}
41
42/// Cost breakdown by generation type (reply/tweet/thread).
43#[derive(Debug, serde::Serialize)]
44pub struct TypeCostBreakdown {
45    pub generation_type: String,
46    pub cost: f64,
47    pub calls: i64,
48    pub avg_cost: f64,
49}
50
51/// Insert a new LLM usage record for a specific account.
52#[allow(clippy::too_many_arguments)]
53pub async fn insert_llm_usage_for(
54    pool: &DbPool,
55    account_id: &str,
56    generation_type: &str,
57    provider: &str,
58    model: &str,
59    input_tokens: u32,
60    output_tokens: u32,
61    cost_usd: f64,
62) -> Result<(), StorageError> {
63    sqlx::query(
64        "INSERT INTO llm_usage (account_id, generation_type, provider, model, input_tokens, output_tokens, cost_usd)
65         VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
66    )
67    .bind(account_id)
68    .bind(generation_type)
69    .bind(provider)
70    .bind(model)
71    .bind(input_tokens)
72    .bind(output_tokens)
73    .bind(cost_usd)
74    .execute(pool)
75    .await
76    .map_err(|e| StorageError::Query { source: e })?;
77    Ok(())
78}
79
80/// Insert a new LLM usage record.
81pub async fn insert_llm_usage(
82    pool: &DbPool,
83    generation_type: &str,
84    provider: &str,
85    model: &str,
86    input_tokens: u32,
87    output_tokens: u32,
88    cost_usd: f64,
89) -> Result<(), StorageError> {
90    insert_llm_usage_for(
91        pool,
92        DEFAULT_ACCOUNT_ID,
93        generation_type,
94        provider,
95        model,
96        input_tokens,
97        output_tokens,
98        cost_usd,
99    )
100    .await
101}
102
103/// Get cost summary across time windows for a specific account.
104pub async fn get_cost_summary_for(
105    pool: &DbPool,
106    account_id: &str,
107) -> Result<CostSummary, StorageError> {
108    let row: (f64, i64, f64, i64, f64, i64, f64, i64) = sqlx::query_as(
109        "SELECT
110            COALESCE(SUM(CASE WHEN created_at >= date('now') THEN cost_usd ELSE 0.0 END), 0.0),
111            COALESCE(SUM(CASE WHEN created_at >= date('now') THEN 1 ELSE 0 END), 0),
112            COALESCE(SUM(CASE WHEN created_at >= date('now', '-7 days') THEN cost_usd ELSE 0.0 END), 0.0),
113            COALESCE(SUM(CASE WHEN created_at >= date('now', '-7 days') THEN 1 ELSE 0 END), 0),
114            COALESCE(SUM(CASE WHEN created_at >= date('now', '-30 days') THEN cost_usd ELSE 0.0 END), 0.0),
115            COALESCE(SUM(CASE WHEN created_at >= date('now', '-30 days') THEN 1 ELSE 0 END), 0),
116            COALESCE(SUM(cost_usd), 0.0),
117            COUNT(*)
118        FROM llm_usage
119        WHERE account_id = ?",
120    )
121    .bind(account_id)
122    .fetch_one(pool)
123    .await
124    .map_err(|e| StorageError::Query { source: e })?;
125
126    Ok(CostSummary {
127        cost_today: row.0,
128        calls_today: row.1,
129        cost_7d: row.2,
130        calls_7d: row.3,
131        cost_30d: row.4,
132        calls_30d: row.5,
133        cost_all_time: row.6,
134        calls_all_time: row.7,
135    })
136}
137
138/// Get cost summary across time windows.
139pub async fn get_cost_summary(pool: &DbPool) -> Result<CostSummary, StorageError> {
140    get_cost_summary_for(pool, DEFAULT_ACCOUNT_ID).await
141}
142
143/// Get daily cost aggregation for chart data for a specific account.
144pub async fn get_daily_costs_for(
145    pool: &DbPool,
146    account_id: &str,
147    days: u32,
148) -> Result<Vec<DailyCostSummary>, StorageError> {
149    let rows: Vec<(String, f64, i64, i64, i64)> = sqlx::query_as(
150        "SELECT
151            date(created_at) as day,
152            COALESCE(SUM(cost_usd), 0.0),
153            COUNT(*),
154            COALESCE(SUM(input_tokens), 0),
155            COALESCE(SUM(output_tokens), 0)
156        FROM llm_usage
157        WHERE account_id = ? AND created_at >= date('now', '-' || ? || ' days')
158        GROUP BY day
159        ORDER BY day",
160    )
161    .bind(account_id)
162    .bind(days)
163    .fetch_all(pool)
164    .await
165    .map_err(|e| StorageError::Query { source: e })?;
166
167    Ok(rows
168        .into_iter()
169        .map(
170            |(date, cost, calls, input_tokens, output_tokens)| DailyCostSummary {
171                date,
172                cost,
173                calls,
174                input_tokens,
175                output_tokens,
176            },
177        )
178        .collect())
179}
180
181/// Get daily cost aggregation for chart data.
182pub async fn get_daily_costs(
183    pool: &DbPool,
184    days: u32,
185) -> Result<Vec<DailyCostSummary>, StorageError> {
186    get_daily_costs_for(pool, DEFAULT_ACCOUNT_ID, days).await
187}
188
189/// Get cost breakdown by provider + model for a specific account.
190pub async fn get_model_breakdown_for(
191    pool: &DbPool,
192    account_id: &str,
193    days: u32,
194) -> Result<Vec<ModelCostBreakdown>, StorageError> {
195    let rows: Vec<(String, String, f64, i64, i64, i64)> = sqlx::query_as(
196        "SELECT
197            provider,
198            model,
199            COALESCE(SUM(cost_usd), 0.0),
200            COUNT(*),
201            COALESCE(SUM(input_tokens), 0),
202            COALESCE(SUM(output_tokens), 0)
203        FROM llm_usage
204        WHERE account_id = ? AND created_at >= date('now', '-' || ? || ' days')
205        GROUP BY provider, model
206        ORDER BY SUM(cost_usd) DESC",
207    )
208    .bind(account_id)
209    .bind(days)
210    .fetch_all(pool)
211    .await
212    .map_err(|e| StorageError::Query { source: e })?;
213
214    Ok(rows
215        .into_iter()
216        .map(
217            |(provider, model, cost, calls, input_tokens, output_tokens)| ModelCostBreakdown {
218                provider,
219                model,
220                cost,
221                calls,
222                input_tokens,
223                output_tokens,
224            },
225        )
226        .collect())
227}
228
229/// Get cost breakdown by provider + model.
230pub async fn get_model_breakdown(
231    pool: &DbPool,
232    days: u32,
233) -> Result<Vec<ModelCostBreakdown>, StorageError> {
234    get_model_breakdown_for(pool, DEFAULT_ACCOUNT_ID, days).await
235}
236
237/// Get cost breakdown by generation type for a specific account.
238pub async fn get_type_breakdown_for(
239    pool: &DbPool,
240    account_id: &str,
241    days: u32,
242) -> Result<Vec<TypeCostBreakdown>, StorageError> {
243    let rows: Vec<(String, f64, i64)> = sqlx::query_as(
244        "SELECT
245            generation_type,
246            COALESCE(SUM(cost_usd), 0.0),
247            COUNT(*)
248        FROM llm_usage
249        WHERE account_id = ? AND created_at >= date('now', '-' || ? || ' days')
250        GROUP BY generation_type
251        ORDER BY SUM(cost_usd) DESC",
252    )
253    .bind(account_id)
254    .bind(days)
255    .fetch_all(pool)
256    .await
257    .map_err(|e| StorageError::Query { source: e })?;
258
259    Ok(rows
260        .into_iter()
261        .map(|(generation_type, cost, calls)| {
262            let avg_cost = if calls > 0 { cost / calls as f64 } else { 0.0 };
263            TypeCostBreakdown {
264                generation_type,
265                cost,
266                calls,
267                avg_cost,
268            }
269        })
270        .collect())
271}
272
273/// Get cost breakdown by generation type.
274pub async fn get_type_breakdown(
275    pool: &DbPool,
276    days: u32,
277) -> Result<Vec<TypeCostBreakdown>, StorageError> {
278    get_type_breakdown_for(pool, DEFAULT_ACCOUNT_ID, days).await
279}
280
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::init_test_db;

    /// Absolute tolerance for comparing costs summed by SQLite.
    const EPS: f64 = 1e-12;

    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < EPS
    }

    #[tokio::test]
    async fn insert_and_query_summary() {
        let pool = init_test_db().await.expect("init db");

        insert_llm_usage(&pool, "reply", "openai", "gpt-4o-mini", 100, 50, 0.000045)
            .await
            .expect("insert");

        insert_llm_usage(&pool, "tweet", "openai", "gpt-4o-mini", 200, 80, 0.000063)
            .await
            .expect("insert");

        let summary = get_cost_summary(&pool).await.expect("summary");
        assert_eq!(summary.calls_all_time, 2);
        assert!(approx_eq(summary.cost_all_time, 0.000045 + 0.000063));
        // Rows inserted moments ago must fall inside the rolling windows.
        assert_eq!(summary.calls_7d, 2);
        assert_eq!(summary.calls_30d, 2);
        assert!(approx_eq(summary.cost_30d, summary.cost_all_time));
    }

    #[tokio::test]
    async fn model_breakdown_groups_correctly() {
        let pool = init_test_db().await.expect("init db");

        insert_llm_usage(&pool, "reply", "openai", "gpt-4o", 100, 50, 0.001)
            .await
            .expect("insert");
        insert_llm_usage(&pool, "reply", "openai", "gpt-4o", 100, 50, 0.001)
            .await
            .expect("insert");
        insert_llm_usage(&pool, "reply", "anthropic", "claude-sonnet", 100, 50, 0.002)
            .await
            .expect("insert");

        let breakdown = get_model_breakdown(&pool, 30).await.expect("breakdown");
        assert_eq!(breakdown.len(), 2);

        // Both groups have identical total spend, so look entries up by name
        // instead of relying on their position.
        let gpt = breakdown
            .iter()
            .find(|b| b.provider == "openai" && b.model == "gpt-4o")
            .expect("gpt-4o group");
        assert_eq!(gpt.calls, 2);
        assert_eq!(gpt.input_tokens, 200);
        assert_eq!(gpt.output_tokens, 100);
        assert!(approx_eq(gpt.cost, 0.002));

        let claude = breakdown
            .iter()
            .find(|b| b.provider == "anthropic" && b.model == "claude-sonnet")
            .expect("claude-sonnet group");
        assert_eq!(claude.calls, 1);
        assert_eq!(claude.input_tokens, 100);
        assert_eq!(claude.output_tokens, 50);
        assert!(approx_eq(claude.cost, 0.002));
    }

    #[tokio::test]
    async fn type_breakdown_groups_correctly() {
        let pool = init_test_db().await.expect("init db");

        insert_llm_usage(&pool, "reply", "openai", "gpt-4o", 100, 50, 0.001)
            .await
            .expect("insert");
        insert_llm_usage(&pool, "tweet", "openai", "gpt-4o", 100, 50, 0.001)
            .await
            .expect("insert");
        insert_llm_usage(&pool, "thread", "openai", "gpt-4o", 100, 50, 0.001)
            .await
            .expect("insert");

        let breakdown = get_type_breakdown(&pool, 30).await.expect("breakdown");
        assert_eq!(breakdown.len(), 3);
        for entry in &breakdown {
            assert_eq!(entry.calls, 1, "calls for {}", entry.generation_type);
            assert!(approx_eq(entry.cost, 0.001));
            assert!(approx_eq(entry.avg_cost, 0.001));
        }
    }

    #[tokio::test]
    async fn daily_costs_aggregate_recent_calls() {
        let pool = init_test_db().await.expect("init db");

        insert_llm_usage(&pool, "reply", "openai", "gpt-4o", 100, 50, 0.001)
            .await
            .expect("insert");
        insert_llm_usage(&pool, "tweet", "openai", "gpt-4o", 200, 80, 0.002)
            .await
            .expect("insert");

        let daily = get_daily_costs(&pool, 7).await.expect("daily");
        assert!(!daily.is_empty());
        // Days are returned oldest first.
        assert!(daily.windows(2).all(|w| w[0].date < w[1].date));

        // Sum across returned days so the test does not depend on both inserts
        // landing on the same calendar day.
        let calls: i64 = daily.iter().map(|d| d.calls).sum();
        let input: i64 = daily.iter().map(|d| d.input_tokens).sum();
        let output: i64 = daily.iter().map(|d| d.output_tokens).sum();
        let cost: f64 = daily.iter().map(|d| d.cost).sum();
        assert_eq!(calls, 2);
        assert_eq!(input, 300);
        assert_eq!(output, 130);
        assert!(approx_eq(cost, 0.003));
    }

    #[tokio::test]
    async fn empty_table_returns_zero_summary() {
        let pool = init_test_db().await.expect("init db");

        let summary = get_cost_summary(&pool).await.expect("summary");
        assert_eq!(summary.calls_all_time, 0);
        assert_eq!(summary.calls_today, 0);
        assert!(approx_eq(summary.cost_all_time, 0.0));

        assert!(get_daily_costs(&pool, 30).await.expect("daily").is_empty());
        assert!(get_model_breakdown(&pool, 30)
            .await
            .expect("models")
            .is_empty());
        assert!(get_type_breakdown(&pool, 30)
            .await
            .expect("types")
            .is_empty());
    }
}