Skip to main content

systemprompt_agent/repository/context/
queries.rs

1use chrono::{DateTime, Utc};
2
3use super::ContextRepository;
4use crate::models::context::{ContextStateEvent, UserContext, UserContextWithStats};
5use crate::repository::task::constructor::TaskConstructor;
6use systemprompt_identifiers::{ContextId, SessionId, TaskId, UserId};
7use systemprompt_traits::RepositoryError;
8
9impl ContextRepository {
10    pub async fn get_context(
11        &self,
12        context_id: &ContextId,
13        user_id: &UserId,
14    ) -> Result<UserContext, RepositoryError> {
15        let row = sqlx::query!(
16            r#"SELECT
17                context_id as "context_id!",
18                user_id as "user_id!",
19                name as "name!",
20                created_at as "created_at!",
21                updated_at as "updated_at!"
22            FROM user_contexts WHERE context_id = $1 AND user_id = $2"#,
23            context_id.as_str(),
24            user_id.as_str()
25        )
26        .fetch_one(&*self.pool)
27        .await
28        .map_err(|e| match e {
29            sqlx::Error::RowNotFound => RepositoryError::NotFound(format!(
30                "Context {} not found for user {}",
31                context_id, user_id
32            )),
33            _ => RepositoryError::database(e),
34        })?;
35
36        Ok(UserContext {
37            context_id: row.context_id.into(),
38            user_id: row.user_id.into(),
39            name: row.name,
40            created_at: row.created_at,
41            updated_at: row.updated_at,
42        })
43    }
44
45    pub async fn list_contexts_basic(
46        &self,
47        user_id: &UserId,
48    ) -> Result<Vec<UserContext>, RepositoryError> {
49        let rows = sqlx::query!(
50            r#"SELECT
51                context_id as "context_id!",
52                user_id as "user_id!",
53                name as "name!",
54                created_at as "created_at!",
55                updated_at as "updated_at!"
56            FROM user_contexts WHERE user_id = $1 ORDER BY updated_at DESC"#,
57            user_id.as_str()
58        )
59        .fetch_all(&*self.pool)
60        .await
61        .map_err(|e| RepositoryError::database(e))?;
62
63        Ok(rows
64            .into_iter()
65            .map(|r| UserContext {
66                context_id: r.context_id.into(),
67                user_id: r.user_id.into(),
68                name: r.name,
69                created_at: r.created_at,
70                updated_at: r.updated_at,
71            })
72            .collect())
73    }
74
75    pub async fn list_contexts_with_stats(
76        &self,
77        user_id: &UserId,
78    ) -> Result<Vec<UserContextWithStats>, RepositoryError> {
79        let rows = sqlx::query!(
80            r#"SELECT
81                c.context_id as "context_id!",
82                c.user_id as "user_id!",
83                c.name as "name!",
84                c.created_at as "created_at!",
85                c.updated_at as "updated_at!",
86                COALESCE(COUNT(DISTINCT t.task_id), 0)::bigint as "task_count!",
87                COALESCE(COUNT(DISTINCT m.id), 0)::bigint as "message_count!",
88                MAX(m.created_at) as last_message_at
89            FROM user_contexts c
90            LEFT JOIN agent_tasks t ON t.context_id = c.context_id
91            LEFT JOIN task_messages m ON m.task_id = t.task_id
92            WHERE c.user_id = $1
93            GROUP BY c.context_id
94            ORDER BY c.updated_at DESC"#,
95            user_id.as_str()
96        )
97        .fetch_all(&*self.pool)
98        .await
99        .map_err(|e| RepositoryError::database(e))?;
100
101        Ok(rows
102            .into_iter()
103            .map(|r| UserContextWithStats {
104                context_id: r.context_id.into(),
105                user_id: r.user_id.into(),
106                name: r.name,
107                created_at: r.created_at,
108                updated_at: r.updated_at,
109                task_count: r.task_count,
110                message_count: r.message_count,
111                last_message_at: r.last_message_at,
112            })
113            .collect())
114    }
115
116    pub async fn find_by_session_id(
117        &self,
118        session_id: &SessionId,
119    ) -> Result<Option<UserContext>, RepositoryError> {
120        let row = sqlx::query!(
121            r#"SELECT
122                context_id as "context_id!",
123                user_id as "user_id!",
124                name as "name!",
125                created_at as "created_at!",
126                updated_at as "updated_at!"
127            FROM user_contexts WHERE session_id = $1
128            ORDER BY created_at DESC LIMIT 1"#,
129            session_id.as_str()
130        )
131        .fetch_optional(&*self.pool)
132        .await
133        .map_err(RepositoryError::database)?;
134
135        Ok(row.map(|r| UserContext {
136            context_id: r.context_id.into(),
137            user_id: r.user_id.into(),
138            name: r.name,
139            created_at: r.created_at,
140            updated_at: r.updated_at,
141        }))
142    }
143
144    pub async fn get_context_events_since(
145        &self,
146        context_id: &ContextId,
147        last_seen: DateTime<Utc>,
148    ) -> Result<Vec<ContextStateEvent>, RepositoryError> {
149        let mut events = Vec::new();
150
151        let task_ids: Vec<String> = sqlx::query_scalar!(
152            r#"SELECT t.task_id as "task_id!" FROM agent_tasks t
153             WHERE t.context_id = $1 AND t.updated_at > $2
154             ORDER BY t.updated_at ASC"#,
155            context_id.as_str(),
156            last_seen
157        )
158        .fetch_all(&*self.pool)
159        .await
160        .map_err(|e| RepositoryError::database(e))?;
161
162        if !task_ids.is_empty() {
163            let constructor = TaskConstructor::new(&self.db_pool)?;
164            let task_ids_typed: Vec<TaskId> = task_ids.iter().map(|id| TaskId::new(id)).collect();
165            let tasks = constructor.construct_tasks_batch(&task_ids_typed).await?;
166
167            for task in tasks {
168                events.push(ContextStateEvent::TaskStatusChanged {
169                    task,
170                    context_id: context_id.clone(),
171                    timestamp: Utc::now(),
172                });
173            }
174        }
175
176        let context_updates = sqlx::query!(
177            r#"SELECT
178                context_id as "context_id!",
179                name as "name!",
180                updated_at as "updated_at!"
181            FROM user_contexts
182            WHERE context_id = $1 AND updated_at > $2
183            ORDER BY updated_at ASC"#,
184            context_id.as_str(),
185            last_seen
186        )
187        .fetch_all(&*self.pool)
188        .await
189        .map_err(|e| RepositoryError::database(e))?;
190
191        for row in context_updates {
192            events.push(ContextStateEvent::ContextUpdated {
193                context_id: ContextId::new(row.context_id),
194                name: row.name,
195                timestamp: row.updated_at,
196            });
197        }
198
199        events.sort_by(|a, b| a.timestamp().cmp(&b.timestamp()));
200
201        Ok(events)
202    }
203}