Skip to main content

systemprompt_analytics/repository/
requests.rs

1use anyhow::Result;
2use chrono::{DateTime, Utc};
3use sqlx::PgPool;
4use std::sync::Arc;
5use systemprompt_database::DbPool;
6
7use crate::models::cli::{ModelUsageRow, RequestListRow, RequestStatsRow, RequestTrendRow};
8
9#[derive(Debug)]
10pub struct RequestAnalyticsRepository {
11    pool: Arc<PgPool>,
12}
13
14impl RequestAnalyticsRepository {
15    pub fn new(db: &DbPool) -> Result<Self> {
16        let pool = db.pool_arc()?;
17        Ok(Self { pool })
18    }
19
20    pub async fn get_stats(
21        &self,
22        start: DateTime<Utc>,
23        end: DateTime<Utc>,
24        model_filter: Option<&str>,
25    ) -> Result<RequestStatsRow> {
26        if let Some(model) = model_filter {
27            let pattern = format!("%{}%", model);
28            sqlx::query_as!(
29                RequestStatsRow,
30                r#"
31                SELECT
32                    COUNT(*)::bigint as "total!",
33                    SUM(tokens_used)::bigint as "total_tokens",
34                    SUM(input_tokens)::bigint as "input_tokens",
35                    SUM(output_tokens)::bigint as "output_tokens",
36                    SUM(cost_cents)::bigint as "cost",
37                    AVG(latency_ms)::float8 as "avg_latency",
38                    COUNT(*) FILTER (WHERE cache_hit = true)::bigint as "cache_hits!"
39                FROM ai_requests
40                WHERE created_at >= $1 AND created_at < $2
41                  AND model ILIKE $3
42                "#,
43                start,
44                end,
45                pattern
46            )
47            .fetch_one(&*self.pool)
48            .await
49            .map_err(Into::into)
50        } else {
51            sqlx::query_as!(
52                RequestStatsRow,
53                r#"
54                SELECT
55                    COUNT(*)::bigint as "total!",
56                    SUM(tokens_used)::bigint as "total_tokens",
57                    SUM(input_tokens)::bigint as "input_tokens",
58                    SUM(output_tokens)::bigint as "output_tokens",
59                    SUM(cost_cents)::bigint as "cost",
60                    AVG(latency_ms)::float8 as "avg_latency",
61                    COUNT(*) FILTER (WHERE cache_hit = true)::bigint as "cache_hits!"
62                FROM ai_requests
63                WHERE created_at >= $1 AND created_at < $2
64                "#,
65                start,
66                end
67            )
68            .fetch_one(&*self.pool)
69            .await
70            .map_err(Into::into)
71        }
72    }
73
74    pub async fn list_models(
75        &self,
76        start: DateTime<Utc>,
77        end: DateTime<Utc>,
78        limit: i64,
79    ) -> Result<Vec<ModelUsageRow>> {
80        sqlx::query_as!(
81            ModelUsageRow,
82            r#"
83            SELECT
84                provider as "provider!",
85                model as "model!",
86                COUNT(*)::bigint as "request_count!",
87                SUM(tokens_used)::bigint as "total_tokens",
88                SUM(cost_cents)::bigint as "total_cost",
89                AVG(latency_ms)::float8 as "avg_latency"
90            FROM ai_requests
91            WHERE created_at >= $1 AND created_at < $2
92            GROUP BY provider, model
93            ORDER BY COUNT(*) DESC
94            LIMIT $3
95            "#,
96            start,
97            end,
98            limit
99        )
100        .fetch_all(&*self.pool)
101        .await
102        .map_err(Into::into)
103    }
104
105    pub async fn get_requests_for_trends(
106        &self,
107        start: DateTime<Utc>,
108        end: DateTime<Utc>,
109    ) -> Result<Vec<RequestTrendRow>> {
110        sqlx::query_as!(
111            RequestTrendRow,
112            r#"
113            SELECT
114                created_at as "created_at!",
115                tokens_used,
116                cost_cents,
117                latency_ms
118            FROM ai_requests
119            WHERE created_at >= $1 AND created_at < $2
120            ORDER BY created_at
121            "#,
122            start,
123            end
124        )
125        .fetch_all(&*self.pool)
126        .await
127        .map_err(Into::into)
128    }
129
130    pub async fn list_requests(
131        &self,
132        start: DateTime<Utc>,
133        end: DateTime<Utc>,
134        limit: i64,
135        model_filter: Option<&str>,
136    ) -> Result<Vec<RequestListRow>> {
137        if let Some(model) = model_filter {
138            let pattern = format!("%{}%", model);
139            sqlx::query_as!(
140                RequestListRow,
141                r#"
142                SELECT
143                    id as "id!",
144                    provider as "provider!",
145                    model as "model!",
146                    input_tokens,
147                    output_tokens,
148                    cost_cents,
149                    latency_ms,
150                    cache_hit,
151                    created_at as "created_at!"
152                FROM ai_requests
153                WHERE created_at >= $1 AND created_at < $2
154                  AND model ILIKE $3
155                ORDER BY created_at DESC
156                LIMIT $4
157                "#,
158                start,
159                end,
160                pattern,
161                limit
162            )
163            .fetch_all(&*self.pool)
164            .await
165            .map_err(Into::into)
166        } else {
167            sqlx::query_as!(
168                RequestListRow,
169                r#"
170                SELECT
171                    id as "id!",
172                    provider as "provider!",
173                    model as "model!",
174                    input_tokens,
175                    output_tokens,
176                    cost_cents,
177                    latency_ms,
178                    cache_hit,
179                    created_at as "created_at!"
180                FROM ai_requests
181                WHERE created_at >= $1 AND created_at < $2
182                ORDER BY created_at DESC
183                LIMIT $3
184                "#,
185                start,
186                end,
187                limit
188            )
189            .fetch_all(&*self.pool)
190            .await
191            .map_err(Into::into)
192        }
193    }
194}