systemprompt_agent/repository/context/
queries.rs1use chrono::{DateTime, Utc};
2
3use super::ContextRepository;
4use crate::models::context::{ContextStateEvent, UserContext, UserContextWithStats};
5use crate::repository::task::constructor::TaskConstructor;
6use systemprompt_identifiers::{ContextId, SessionId, TaskId, UserId};
7use systemprompt_traits::RepositoryError;
8
9impl ContextRepository {
10 pub async fn get_context(
11 &self,
12 context_id: &ContextId,
13 user_id: &UserId,
14 ) -> Result<UserContext, RepositoryError> {
15 let pool = self.get_pg_pool()?;
16
17 let row = sqlx::query!(
18 r#"SELECT
19 context_id as "context_id!",
20 user_id as "user_id!",
21 name as "name!",
22 created_at as "created_at!",
23 updated_at as "updated_at!"
24 FROM user_contexts WHERE context_id = $1 AND user_id = $2"#,
25 context_id.as_str(),
26 user_id.as_str()
27 )
28 .fetch_one(pool.as_ref())
29 .await
30 .map_err(|e| match e {
31 sqlx::Error::RowNotFound => RepositoryError::NotFound(format!(
32 "Context {} not found for user {}",
33 context_id, user_id
34 )),
35 _ => RepositoryError::database(e),
36 })?;
37
38 Ok(UserContext {
39 context_id: row.context_id.into(),
40 user_id: row.user_id.into(),
41 name: row.name,
42 created_at: row.created_at,
43 updated_at: row.updated_at,
44 })
45 }
46
47 pub async fn list_contexts_basic(
48 &self,
49 user_id: &UserId,
50 ) -> Result<Vec<UserContext>, RepositoryError> {
51 let pool = self.get_pg_pool()?;
52
53 let rows = sqlx::query!(
54 r#"SELECT
55 context_id as "context_id!",
56 user_id as "user_id!",
57 name as "name!",
58 created_at as "created_at!",
59 updated_at as "updated_at!"
60 FROM user_contexts WHERE user_id = $1 ORDER BY updated_at DESC"#,
61 user_id.as_str()
62 )
63 .fetch_all(pool.as_ref())
64 .await
65 .map_err(|e| RepositoryError::database(e))?;
66
67 Ok(rows
68 .into_iter()
69 .map(|r| UserContext {
70 context_id: r.context_id.into(),
71 user_id: r.user_id.into(),
72 name: r.name,
73 created_at: r.created_at,
74 updated_at: r.updated_at,
75 })
76 .collect())
77 }
78
79 pub async fn list_contexts_with_stats(
80 &self,
81 user_id: &UserId,
82 ) -> Result<Vec<UserContextWithStats>, RepositoryError> {
83 let pool = self.get_pg_pool()?;
84
85 let rows = sqlx::query!(
86 r#"SELECT
87 c.context_id as "context_id!",
88 c.user_id as "user_id!",
89 c.name as "name!",
90 c.created_at as "created_at!",
91 c.updated_at as "updated_at!",
92 COALESCE(COUNT(DISTINCT t.task_id), 0)::bigint as "task_count!",
93 COALESCE(COUNT(DISTINCT m.id), 0)::bigint as "message_count!",
94 MAX(m.created_at) as last_message_at
95 FROM user_contexts c
96 LEFT JOIN agent_tasks t ON t.context_id = c.context_id
97 LEFT JOIN task_messages m ON m.task_id = t.task_id
98 WHERE c.user_id = $1
99 GROUP BY c.context_id
100 ORDER BY c.updated_at DESC"#,
101 user_id.as_str()
102 )
103 .fetch_all(pool.as_ref())
104 .await
105 .map_err(|e| RepositoryError::database(e))?;
106
107 Ok(rows
108 .into_iter()
109 .map(|r| UserContextWithStats {
110 context_id: r.context_id.into(),
111 user_id: r.user_id.into(),
112 name: r.name,
113 created_at: r.created_at,
114 updated_at: r.updated_at,
115 task_count: r.task_count,
116 message_count: r.message_count,
117 last_message_at: r.last_message_at,
118 })
119 .collect())
120 }
121
122 pub async fn find_by_session_id(
123 &self,
124 session_id: &SessionId,
125 ) -> Result<Option<UserContext>, RepositoryError> {
126 let pool = self.get_pg_pool()?;
127
128 let row = sqlx::query!(
129 r#"SELECT
130 context_id as "context_id!",
131 user_id as "user_id!",
132 name as "name!",
133 created_at as "created_at!",
134 updated_at as "updated_at!"
135 FROM user_contexts WHERE session_id = $1
136 ORDER BY created_at DESC LIMIT 1"#,
137 session_id.as_str()
138 )
139 .fetch_optional(pool.as_ref())
140 .await
141 .map_err(RepositoryError::database)?;
142
143 Ok(row.map(|r| UserContext {
144 context_id: r.context_id.into(),
145 user_id: r.user_id.into(),
146 name: r.name,
147 created_at: r.created_at,
148 updated_at: r.updated_at,
149 }))
150 }
151
152 pub async fn get_context_events_since(
153 &self,
154 context_id: &ContextId,
155 last_seen: DateTime<Utc>,
156 ) -> Result<Vec<ContextStateEvent>, RepositoryError> {
157 let mut events = Vec::new();
158 let pool = self.get_pg_pool()?;
159
160 let task_ids: Vec<String> = sqlx::query_scalar!(
161 r#"SELECT t.task_id as "task_id!" FROM agent_tasks t
162 WHERE t.context_id = $1 AND t.updated_at > $2
163 ORDER BY t.updated_at ASC"#,
164 context_id.as_str(),
165 last_seen
166 )
167 .fetch_all(pool.as_ref())
168 .await
169 .map_err(|e| RepositoryError::database(e))?;
170
171 if !task_ids.is_empty() {
172 let constructor = TaskConstructor::new(self.db_pool.clone());
173 let task_ids_typed: Vec<TaskId> = task_ids.iter().map(|id| TaskId::new(id)).collect();
174 let tasks = constructor.construct_tasks_batch(&task_ids_typed).await?;
175
176 for task in tasks {
177 events.push(ContextStateEvent::TaskStatusChanged {
178 task,
179 context_id: context_id.clone(),
180 timestamp: Utc::now(),
181 });
182 }
183 }
184
185 let context_updates = sqlx::query!(
186 r#"SELECT
187 context_id as "context_id!",
188 name as "name!",
189 updated_at as "updated_at!"
190 FROM user_contexts
191 WHERE context_id = $1 AND updated_at > $2
192 ORDER BY updated_at ASC"#,
193 context_id.as_str(),
194 last_seen
195 )
196 .fetch_all(pool.as_ref())
197 .await
198 .map_err(|e| RepositoryError::database(e))?;
199
200 for row in context_updates {
201 events.push(ContextStateEvent::ContextUpdated {
202 context_id: ContextId::new(row.context_id),
203 name: row.name,
204 timestamp: row.updated_at,
205 });
206 }
207
208 events.sort_by(|a, b| a.timestamp().cmp(&b.timestamp()));
209
210 Ok(events)
211 }
212}