Skip to main content

systemprompt_analytics/repository/
conversations.rs

1use crate::Result;
2use chrono::{DateTime, Utc};
3use sqlx::PgPool;
4use std::sync::Arc;
5use systemprompt_database::DbPool;
6
7use crate::models::cli::{ConversationListRow, GatewaySessionListRow, TimestampRow};
8
9#[derive(Debug)]
10pub struct ConversationAnalyticsRepository {
11    pool: Arc<PgPool>,
12}
13
14impl ConversationAnalyticsRepository {
15    pub fn new(db: &DbPool) -> Result<Self> {
16        let pool = db.pool_arc()?;
17        Ok(Self { pool })
18    }
19
20    pub async fn list_agent_contexts(
21        &self,
22        start: DateTime<Utc>,
23        end: DateTime<Utc>,
24        limit: i64,
25    ) -> Result<Vec<ConversationListRow>> {
26        sqlx::query_as!(
27            ConversationListRow,
28            r#"
29            SELECT
30                uc.context_id as "context_id!: systemprompt_identifiers::ContextId",
31                uc.name as "name?",
32                (SELECT COUNT(*) FROM agent_tasks at WHERE at.context_id = uc.context_id)::bigint as "task_count!",
33                (SELECT COUNT(*) FROM task_messages tm
34                 JOIN agent_tasks at ON at.task_id = tm.task_id
35                 WHERE at.context_id = uc.context_id)::bigint as "message_count!",
36                uc.created_at as "created_at!",
37                uc.updated_at as "updated_at!"
38            FROM user_contexts uc
39            WHERE uc.created_at >= $1 AND uc.created_at < $2
40            ORDER BY uc.updated_at DESC
41            LIMIT $3
42            "#,
43            start,
44            end,
45            limit
46        )
47        .fetch_all(&*self.pool)
48        .await
49        .map_err(Into::into)
50    }
51
52    pub async fn list_gateway_sessions(
53        &self,
54        start: DateTime<Utc>,
55        end: DateTime<Utc>,
56        limit: i64,
57    ) -> Result<Vec<GatewaySessionListRow>> {
58        sqlx::query_as!(
59            GatewaySessionListRow,
60            r#"
61            SELECT
62                ar.session_id as "session_id!: systemprompt_identifiers::SessionId",
63                COUNT(arm.id)::bigint as "message_count!",
64                MIN(ar.created_at) as "created_at!",
65                MAX(ar.created_at) as "updated_at!"
66            FROM ai_requests ar
67            LEFT JOIN ai_request_messages arm ON arm.request_id = ar.id
68            WHERE ar.task_id IS NULL
69              AND ar.session_id IS NOT NULL
70              AND ar.created_at >= $1 AND ar.created_at < $2
71              AND NOT EXISTS (
72                  SELECT 1 FROM user_contexts uc2 WHERE uc2.context_id::text = ar.session_id
73              )
74            GROUP BY ar.session_id
75            ORDER BY MAX(ar.created_at) DESC
76            LIMIT $3
77            "#,
78            start,
79            end,
80            limit
81        )
82        .fetch_all(&*self.pool)
83        .await
84        .map_err(Into::into)
85    }
86
87    pub async fn get_context_count(&self, start: DateTime<Utc>, end: DateTime<Utc>) -> Result<i64> {
88        let count = sqlx::query_scalar!(
89            r#"SELECT COUNT(*)::bigint as "count!" FROM user_contexts WHERE created_at >= $1 AND created_at < $2"#,
90            start,
91            end
92        )
93        .fetch_one(&*self.pool)
94        .await?;
95        Ok(count)
96    }
97
98    pub async fn get_task_stats(
99        &self,
100        start: DateTime<Utc>,
101        end: DateTime<Utc>,
102    ) -> Result<(i64, Option<f64>)> {
103        let row = sqlx::query!(
104            r#"
105            SELECT COUNT(*)::bigint as "count!", AVG(execution_time_ms)::float8 as avg_time
106            FROM agent_tasks
107            WHERE started_at >= $1 AND started_at < $2
108            "#,
109            start,
110            end
111        )
112        .fetch_one(&*self.pool)
113        .await?;
114        Ok((row.count, row.avg_time))
115    }
116
117    pub async fn get_message_count(&self, start: DateTime<Utc>, end: DateTime<Utc>) -> Result<i64> {
118        let count = sqlx::query_scalar!(
119            r#"SELECT COUNT(*)::bigint as "count!" FROM task_messages WHERE created_at >= $1 AND created_at < $2"#,
120            start,
121            end
122        )
123        .fetch_one(&*self.pool)
124        .await?;
125        Ok(count)
126    }
127
128    pub async fn get_context_timestamps(
129        &self,
130        start: DateTime<Utc>,
131        end: DateTime<Utc>,
132    ) -> Result<Vec<TimestampRow>> {
133        sqlx::query_as!(
134            TimestampRow,
135            r#"
136            SELECT created_at as "timestamp!"
137            FROM user_contexts
138            WHERE created_at >= $1 AND created_at < $2
139            "#,
140            start,
141            end
142        )
143        .fetch_all(&*self.pool)
144        .await
145        .map_err(Into::into)
146    }
147
148    pub async fn get_task_timestamps(
149        &self,
150        start: DateTime<Utc>,
151        end: DateTime<Utc>,
152    ) -> Result<Vec<TimestampRow>> {
153        sqlx::query_as!(
154            TimestampRow,
155            r#"
156            SELECT started_at as "timestamp!"
157            FROM agent_tasks
158            WHERE started_at >= $1 AND started_at < $2
159            "#,
160            start,
161            end
162        )
163        .fetch_all(&*self.pool)
164        .await
165        .map_err(Into::into)
166    }
167
168    pub async fn get_message_timestamps(
169        &self,
170        start: DateTime<Utc>,
171        end: DateTime<Utc>,
172    ) -> Result<Vec<TimestampRow>> {
173        sqlx::query_as!(
174            TimestampRow,
175            r#"
176            SELECT created_at as "timestamp!"
177            FROM task_messages
178            WHERE created_at >= $1 AND created_at < $2
179            "#,
180            start,
181            end
182        )
183        .fetch_all(&*self.pool)
184        .await
185        .map_err(Into::into)
186    }
187}