Skip to main content

systemprompt_analytics/repository/
requests.rs

1use crate::Result;
2use chrono::{DateTime, Utc};
3use sqlx::PgPool;
4use std::sync::Arc;
5use systemprompt_database::DbPool;
6use systemprompt_identifiers::UserId;
7
8use crate::models::cli::{ModelUsageRow, RequestListRow, RequestStatsRow, RequestTrendRow};
9
/// Read-only analytics queries over the `ai_requests` table.
///
/// Every method on this type runs its SQL against the shared Postgres
/// connection pool held in `pool`.
#[derive(Debug)]
pub struct RequestAnalyticsRepository {
    /// Reference-counted handle to the Postgres pool used by all queries.
    pool: Arc<PgPool>,
}
14
15impl RequestAnalyticsRepository {
16    pub fn new(db: &DbPool) -> Result<Self> {
17        let pool = db.pool_arc()?;
18        Ok(Self { pool })
19    }
20
21    pub async fn get_stats(
22        &self,
23        start: DateTime<Utc>,
24        end: DateTime<Utc>,
25        model_filter: Option<&str>,
26    ) -> Result<RequestStatsRow> {
27        if let Some(model) = model_filter {
28            let pattern = format!("%{}%", model);
29            sqlx::query_as!(
30                RequestStatsRow,
31                r#"
32                SELECT
33                    COUNT(*)::bigint as "total!",
34                    SUM(tokens_used)::bigint as "total_tokens",
35                    SUM(input_tokens)::bigint as "input_tokens",
36                    SUM(output_tokens)::bigint as "output_tokens",
37                    SUM(cost_microdollars)::bigint as "cost",
38                    AVG(latency_ms)::float8 as "avg_latency",
39                    COUNT(*) FILTER (WHERE cache_hit = true)::bigint as "cache_hits!"
40                FROM ai_requests
41                WHERE created_at >= $1 AND created_at < $2
42                  AND model ILIKE $3
43                "#,
44                start,
45                end,
46                pattern
47            )
48            .fetch_one(&*self.pool)
49            .await
50            .map_err(Into::into)
51        } else {
52            sqlx::query_as!(
53                RequestStatsRow,
54                r#"
55                SELECT
56                    COUNT(*)::bigint as "total!",
57                    SUM(tokens_used)::bigint as "total_tokens",
58                    SUM(input_tokens)::bigint as "input_tokens",
59                    SUM(output_tokens)::bigint as "output_tokens",
60                    SUM(cost_microdollars)::bigint as "cost",
61                    AVG(latency_ms)::float8 as "avg_latency",
62                    COUNT(*) FILTER (WHERE cache_hit = true)::bigint as "cache_hits!"
63                FROM ai_requests
64                WHERE created_at >= $1 AND created_at < $2
65                "#,
66                start,
67                end
68            )
69            .fetch_one(&*self.pool)
70            .await
71            .map_err(Into::into)
72        }
73    }
74
75    pub async fn list_models(
76        &self,
77        start: DateTime<Utc>,
78        end: DateTime<Utc>,
79        limit: i64,
80    ) -> Result<Vec<ModelUsageRow>> {
81        sqlx::query_as!(
82            ModelUsageRow,
83            r#"
84            SELECT
85                provider as "provider!",
86                model as "model!",
87                COUNT(*)::bigint as "request_count!",
88                SUM(tokens_used)::bigint as "total_tokens",
89                SUM(cost_microdollars)::bigint as "total_cost",
90                AVG(latency_ms)::float8 as "avg_latency"
91            FROM ai_requests
92            WHERE created_at >= $1 AND created_at < $2
93            GROUP BY provider, model
94            ORDER BY COUNT(*) DESC
95            LIMIT $3
96            "#,
97            start,
98            end,
99            limit
100        )
101        .fetch_all(&*self.pool)
102        .await
103        .map_err(Into::into)
104    }
105
106    pub async fn get_requests_for_trends(
107        &self,
108        start: DateTime<Utc>,
109        end: DateTime<Utc>,
110    ) -> Result<Vec<RequestTrendRow>> {
111        sqlx::query_as!(
112            RequestTrendRow,
113            r#"
114            SELECT
115                created_at as "created_at!",
116                tokens_used,
117                cost_microdollars,
118                latency_ms
119            FROM ai_requests
120            WHERE created_at >= $1 AND created_at < $2
121            ORDER BY created_at
122            "#,
123            start,
124            end
125        )
126        .fetch_all(&*self.pool)
127        .await
128        .map_err(Into::into)
129    }
130
131    pub async fn list_requests(
132        &self,
133        start: DateTime<Utc>,
134        end: DateTime<Utc>,
135        limit: i64,
136        model_filter: Option<&str>,
137    ) -> Result<Vec<RequestListRow>> {
138        if let Some(model) = model_filter {
139            let pattern = format!("%{}%", model);
140            sqlx::query_as!(
141                RequestListRow,
142                r#"
143                SELECT
144                    id as "id!",
145                    provider as "provider!",
146                    model as "model!",
147                    input_tokens,
148                    output_tokens,
149                    cost_microdollars,
150                    latency_ms,
151                    cache_hit,
152                    created_at as "created_at!",
153                    status as "status!",
154                    error_message,
155                    user_id as "user_id!: UserId"
156                FROM ai_requests
157                WHERE created_at >= $1 AND created_at < $2
158                  AND model ILIKE $3
159                ORDER BY created_at DESC
160                LIMIT $4
161                "#,
162                start,
163                end,
164                pattern,
165                limit
166            )
167            .fetch_all(&*self.pool)
168            .await
169            .map_err(Into::into)
170        } else {
171            sqlx::query_as!(
172                RequestListRow,
173                r#"
174                SELECT
175                    id as "id!",
176                    provider as "provider!",
177                    model as "model!",
178                    input_tokens,
179                    output_tokens,
180                    cost_microdollars,
181                    latency_ms,
182                    cache_hit,
183                    created_at as "created_at!",
184                    status as "status!",
185                    error_message,
186                    user_id as "user_id!: UserId"
187                FROM ai_requests
188                WHERE created_at >= $1 AND created_at < $2
189                ORDER BY created_at DESC
190                LIMIT $3
191                "#,
192                start,
193                end,
194                limit
195            )
196            .fetch_all(&*self.pool)
197            .await
198            .map_err(Into::into)
199        }
200    }
201}