Skip to main content

systemprompt_agent/repository/context/
queries.rs

1use chrono::{DateTime, Utc};
2
3use super::ContextRepository;
4use crate::models::context::{ContextStateEvent, UserContext, UserContextWithStats};
5use crate::repository::task::constructor::TaskConstructor;
6use systemprompt_identifiers::{ContextId, SessionId, TaskId, UserId};
7use systemprompt_traits::RepositoryError;
8
9impl ContextRepository {
10    pub async fn get_context(
11        &self,
12        context_id: &ContextId,
13        user_id: &UserId,
14    ) -> Result<UserContext, RepositoryError> {
15        let pool = self.get_pg_pool()?;
16
17        let row = sqlx::query!(
18            r#"SELECT
19                context_id as "context_id!",
20                user_id as "user_id!",
21                name as "name!",
22                created_at as "created_at!",
23                updated_at as "updated_at!"
24            FROM user_contexts WHERE context_id = $1 AND user_id = $2"#,
25            context_id.as_str(),
26            user_id.as_str()
27        )
28        .fetch_one(pool.as_ref())
29        .await
30        .map_err(|e| match e {
31            sqlx::Error::RowNotFound => RepositoryError::NotFound(format!(
32                "Context {} not found for user {}",
33                context_id, user_id
34            )),
35            _ => RepositoryError::database(e),
36        })?;
37
38        Ok(UserContext {
39            context_id: row.context_id.into(),
40            user_id: row.user_id.into(),
41            name: row.name,
42            created_at: row.created_at,
43            updated_at: row.updated_at,
44        })
45    }
46
47    pub async fn list_contexts_basic(
48        &self,
49        user_id: &UserId,
50    ) -> Result<Vec<UserContext>, RepositoryError> {
51        let pool = self.get_pg_pool()?;
52
53        let rows = sqlx::query!(
54            r#"SELECT
55                context_id as "context_id!",
56                user_id as "user_id!",
57                name as "name!",
58                created_at as "created_at!",
59                updated_at as "updated_at!"
60            FROM user_contexts WHERE user_id = $1 ORDER BY updated_at DESC"#,
61            user_id.as_str()
62        )
63        .fetch_all(pool.as_ref())
64        .await
65        .map_err(|e| RepositoryError::database(e))?;
66
67        Ok(rows
68            .into_iter()
69            .map(|r| UserContext {
70                context_id: r.context_id.into(),
71                user_id: r.user_id.into(),
72                name: r.name,
73                created_at: r.created_at,
74                updated_at: r.updated_at,
75            })
76            .collect())
77    }
78
79    pub async fn list_contexts_with_stats(
80        &self,
81        user_id: &UserId,
82    ) -> Result<Vec<UserContextWithStats>, RepositoryError> {
83        let pool = self.get_pg_pool()?;
84
85        let rows = sqlx::query!(
86            r#"SELECT
87                c.context_id as "context_id!",
88                c.user_id as "user_id!",
89                c.name as "name!",
90                c.created_at as "created_at!",
91                c.updated_at as "updated_at!",
92                COALESCE(COUNT(DISTINCT t.task_id), 0)::bigint as "task_count!",
93                COALESCE(COUNT(DISTINCT m.id), 0)::bigint as "message_count!",
94                MAX(m.created_at) as last_message_at
95            FROM user_contexts c
96            LEFT JOIN agent_tasks t ON t.context_id = c.context_id
97            LEFT JOIN task_messages m ON m.task_id = t.task_id
98            WHERE c.user_id = $1
99            GROUP BY c.context_id
100            ORDER BY c.updated_at DESC"#,
101            user_id.as_str()
102        )
103        .fetch_all(pool.as_ref())
104        .await
105        .map_err(|e| RepositoryError::database(e))?;
106
107        Ok(rows
108            .into_iter()
109            .map(|r| UserContextWithStats {
110                context_id: r.context_id.into(),
111                user_id: r.user_id.into(),
112                name: r.name,
113                created_at: r.created_at,
114                updated_at: r.updated_at,
115                task_count: r.task_count,
116                message_count: r.message_count,
117                last_message_at: r.last_message_at,
118            })
119            .collect())
120    }
121
122    pub async fn find_by_session_id(
123        &self,
124        session_id: &SessionId,
125    ) -> Result<Option<UserContext>, RepositoryError> {
126        let pool = self.get_pg_pool()?;
127
128        let row = sqlx::query!(
129            r#"SELECT
130                context_id as "context_id!",
131                user_id as "user_id!",
132                name as "name!",
133                created_at as "created_at!",
134                updated_at as "updated_at!"
135            FROM user_contexts WHERE session_id = $1
136            ORDER BY created_at DESC LIMIT 1"#,
137            session_id.as_str()
138        )
139        .fetch_optional(pool.as_ref())
140        .await
141        .map_err(RepositoryError::database)?;
142
143        Ok(row.map(|r| UserContext {
144            context_id: r.context_id.into(),
145            user_id: r.user_id.into(),
146            name: r.name,
147            created_at: r.created_at,
148            updated_at: r.updated_at,
149        }))
150    }
151
152    pub async fn get_context_events_since(
153        &self,
154        context_id: &ContextId,
155        last_seen: DateTime<Utc>,
156    ) -> Result<Vec<ContextStateEvent>, RepositoryError> {
157        let mut events = Vec::new();
158        let pool = self.get_pg_pool()?;
159
160        let task_ids: Vec<String> = sqlx::query_scalar!(
161            r#"SELECT t.task_id as "task_id!" FROM agent_tasks t
162             WHERE t.context_id = $1 AND t.updated_at > $2
163             ORDER BY t.updated_at ASC"#,
164            context_id.as_str(),
165            last_seen
166        )
167        .fetch_all(pool.as_ref())
168        .await
169        .map_err(|e| RepositoryError::database(e))?;
170
171        if !task_ids.is_empty() {
172            let constructor = TaskConstructor::new(self.db_pool.clone());
173            let task_ids_typed: Vec<TaskId> = task_ids.iter().map(|id| TaskId::new(id)).collect();
174            let tasks = constructor.construct_tasks_batch(&task_ids_typed).await?;
175
176            for task in tasks {
177                events.push(ContextStateEvent::TaskStatusChanged {
178                    task,
179                    context_id: context_id.clone(),
180                    timestamp: Utc::now(),
181                });
182            }
183        }
184
185        let context_updates = sqlx::query!(
186            r#"SELECT
187                context_id as "context_id!",
188                name as "name!",
189                updated_at as "updated_at!"
190            FROM user_contexts
191            WHERE context_id = $1 AND updated_at > $2
192            ORDER BY updated_at ASC"#,
193            context_id.as_str(),
194            last_seen
195        )
196        .fetch_all(pool.as_ref())
197        .await
198        .map_err(|e| RepositoryError::database(e))?;
199
200        for row in context_updates {
201            events.push(ContextStateEvent::ContextUpdated {
202                context_id: ContextId::new(row.context_id),
203                name: row.name,
204                timestamp: row.updated_at,
205            });
206        }
207
208        events.sort_by(|a, b| a.timestamp().cmp(&b.timestamp()));
209
210        Ok(events)
211    }
212}