// systemprompt_analytics/repository/conversations.rs

use std::sync::Arc;

use anyhow::{Context as _, Result};
use chrono::{DateTime, Utc};
use sqlx::PgPool;
use systemprompt_database::DbPool;

use crate::models::cli::{ConversationListRow, TimestampRow};

9#[derive(Debug)]
10pub struct ConversationAnalyticsRepository {
11    pool: Arc<PgPool>,
12}
13
14impl ConversationAnalyticsRepository {
15    pub fn new(db: &DbPool) -> Result<Self> {
16        let pool = db.pool_arc()?;
17        Ok(Self { pool })
18    }
19
20    pub async fn list_conversations(
21        &self,
22        start: DateTime<Utc>,
23        end: DateTime<Utc>,
24        limit: i64,
25    ) -> Result<Vec<ConversationListRow>> {
26        sqlx::query_as!(
27            ConversationListRow,
28            r#"
29            WITH agent_convs AS (
30                SELECT
31                    uc.context_id,
32                    uc.name,
33                    (SELECT COUNT(*) FROM agent_tasks at WHERE at.context_id = uc.context_id)::bigint as task_count,
34                    (SELECT COUNT(*) FROM task_messages tm
35                     JOIN agent_tasks at ON at.task_id = tm.task_id
36                     WHERE at.context_id = uc.context_id)::bigint as message_count,
37                    uc.created_at,
38                    uc.updated_at
39                FROM user_contexts uc
40                WHERE uc.created_at >= $1 AND uc.created_at < $2
41            ),
42            gateway_convs AS (
43                SELECT
44                    ar.session_id as context_id,
45                    NULL::text as name,
46                    0::bigint as task_count,
47                    COUNT(arm.id)::bigint as message_count,
48                    MIN(ar.created_at) as created_at,
49                    MAX(ar.created_at) as updated_at
50                FROM ai_requests ar
51                LEFT JOIN ai_request_messages arm ON arm.request_id = ar.id
52                WHERE ar.task_id IS NULL
53                  AND ar.session_id IS NOT NULL
54                  AND ar.created_at >= $1 AND ar.created_at < $2
55                  AND NOT EXISTS (
56                      SELECT 1 FROM user_contexts uc2 WHERE uc2.context_id = ar.session_id
57                  )
58                GROUP BY ar.session_id
59            )
60            SELECT
61                context_id as "context_id!",
62                name,
63                task_count as "task_count!",
64                message_count as "message_count!",
65                created_at as "created_at!",
66                updated_at as "updated_at!"
67            FROM (
68                SELECT * FROM agent_convs
69                UNION ALL
70                SELECT * FROM gateway_convs
71            ) combined
72            ORDER BY updated_at DESC
73            LIMIT $3
74            "#,
75            start,
76            end,
77            limit
78        )
79        .fetch_all(&*self.pool)
80        .await
81        .map_err(Into::into)
82    }
83
84    pub async fn get_context_count(&self, start: DateTime<Utc>, end: DateTime<Utc>) -> Result<i64> {
85        let count = sqlx::query_scalar!(
86            r#"SELECT COUNT(*)::bigint as "count!" FROM user_contexts WHERE created_at >= $1 AND created_at < $2"#,
87            start,
88            end
89        )
90        .fetch_one(&*self.pool)
91        .await?;
92        Ok(count)
93    }
94
95    pub async fn get_task_stats(
96        &self,
97        start: DateTime<Utc>,
98        end: DateTime<Utc>,
99    ) -> Result<(i64, Option<f64>)> {
100        let row = sqlx::query!(
101            r#"
102            SELECT COUNT(*)::bigint as "count!", AVG(execution_time_ms)::float8 as avg_time
103            FROM agent_tasks
104            WHERE started_at >= $1 AND started_at < $2
105            "#,
106            start,
107            end
108        )
109        .fetch_one(&*self.pool)
110        .await?;
111        Ok((row.count, row.avg_time))
112    }
113
114    pub async fn get_message_count(&self, start: DateTime<Utc>, end: DateTime<Utc>) -> Result<i64> {
115        let count = sqlx::query_scalar!(
116            r#"SELECT COUNT(*)::bigint as "count!" FROM task_messages WHERE created_at >= $1 AND created_at < $2"#,
117            start,
118            end
119        )
120        .fetch_one(&*self.pool)
121        .await?;
122        Ok(count)
123    }
124
125    pub async fn get_context_timestamps(
126        &self,
127        start: DateTime<Utc>,
128        end: DateTime<Utc>,
129    ) -> Result<Vec<TimestampRow>> {
130        sqlx::query_as!(
131            TimestampRow,
132            r#"
133            SELECT created_at as "timestamp!"
134            FROM user_contexts
135            WHERE created_at >= $1 AND created_at < $2
136            "#,
137            start,
138            end
139        )
140        .fetch_all(&*self.pool)
141        .await
142        .map_err(Into::into)
143    }
144
145    pub async fn get_task_timestamps(
146        &self,
147        start: DateTime<Utc>,
148        end: DateTime<Utc>,
149    ) -> Result<Vec<TimestampRow>> {
150        sqlx::query_as!(
151            TimestampRow,
152            r#"
153            SELECT started_at as "timestamp!"
154            FROM agent_tasks
155            WHERE started_at >= $1 AND started_at < $2
156            "#,
157            start,
158            end
159        )
160        .fetch_all(&*self.pool)
161        .await
162        .map_err(Into::into)
163    }
164
165    pub async fn get_message_timestamps(
166        &self,
167        start: DateTime<Utc>,
168        end: DateTime<Utc>,
169    ) -> Result<Vec<TimestampRow>> {
170        sqlx::query_as!(
171            TimestampRow,
172            r#"
173            SELECT created_at as "timestamp!"
174            FROM task_messages
175            WHERE created_at >= $1 AND created_at < $2
176            "#,
177            start,
178            end
179        )
180        .fetch_all(&*self.pool)
181        .await
182        .map_err(Into::into)
183    }
184}