systemprompt_agent/repository/context/
queries.rs1use chrono::{DateTime, Utc};
2
3use super::ContextRepository;
4use crate::models::context::{ContextStateEvent, UserContext, UserContextWithStats};
5use crate::repository::task::constructor::TaskConstructor;
6use systemprompt_identifiers::{ContextId, SessionId, TaskId, UserId};
7use systemprompt_traits::RepositoryError;
8
9impl ContextRepository {
10 pub async fn get_context(
11 &self,
12 context_id: &ContextId,
13 user_id: &UserId,
14 ) -> Result<UserContext, RepositoryError> {
15 let row = sqlx::query!(
16 r#"SELECT
17 context_id as "context_id!",
18 user_id as "user_id!",
19 name as "name!",
20 created_at as "created_at!",
21 updated_at as "updated_at!"
22 FROM user_contexts WHERE context_id = $1 AND user_id = $2"#,
23 context_id.as_str(),
24 user_id.as_str()
25 )
26 .fetch_one(&*self.pool)
27 .await
28 .map_err(|e| match e {
29 sqlx::Error::RowNotFound => RepositoryError::NotFound(format!(
30 "Context {} not found for user {}",
31 context_id, user_id
32 )),
33 _ => RepositoryError::database(e),
34 })?;
35
36 Ok(UserContext {
37 context_id: row.context_id.into(),
38 user_id: row.user_id.into(),
39 name: row.name,
40 created_at: row.created_at,
41 updated_at: row.updated_at,
42 })
43 }
44
45 pub async fn list_contexts_basic(
46 &self,
47 user_id: &UserId,
48 ) -> Result<Vec<UserContext>, RepositoryError> {
49 let rows = sqlx::query!(
50 r#"SELECT
51 context_id as "context_id!",
52 user_id as "user_id!",
53 name as "name!",
54 created_at as "created_at!",
55 updated_at as "updated_at!"
56 FROM user_contexts WHERE user_id = $1 ORDER BY updated_at DESC"#,
57 user_id.as_str()
58 )
59 .fetch_all(&*self.pool)
60 .await
61 .map_err(|e| RepositoryError::database(e))?;
62
63 Ok(rows
64 .into_iter()
65 .map(|r| UserContext {
66 context_id: r.context_id.into(),
67 user_id: r.user_id.into(),
68 name: r.name,
69 created_at: r.created_at,
70 updated_at: r.updated_at,
71 })
72 .collect())
73 }
74
75 pub async fn list_contexts_with_stats(
76 &self,
77 user_id: &UserId,
78 ) -> Result<Vec<UserContextWithStats>, RepositoryError> {
79 let rows = sqlx::query!(
80 r#"SELECT
81 c.context_id as "context_id!",
82 c.user_id as "user_id!",
83 c.name as "name!",
84 c.created_at as "created_at!",
85 c.updated_at as "updated_at!",
86 COALESCE(COUNT(DISTINCT t.task_id), 0)::bigint as "task_count!",
87 COALESCE(COUNT(DISTINCT m.id), 0)::bigint as "message_count!",
88 MAX(m.created_at) as last_message_at
89 FROM user_contexts c
90 LEFT JOIN agent_tasks t ON t.context_id = c.context_id
91 LEFT JOIN task_messages m ON m.task_id = t.task_id
92 WHERE c.user_id = $1
93 GROUP BY c.context_id
94 ORDER BY c.updated_at DESC"#,
95 user_id.as_str()
96 )
97 .fetch_all(&*self.pool)
98 .await
99 .map_err(|e| RepositoryError::database(e))?;
100
101 Ok(rows
102 .into_iter()
103 .map(|r| UserContextWithStats {
104 context_id: r.context_id.into(),
105 user_id: r.user_id.into(),
106 name: r.name,
107 created_at: r.created_at,
108 updated_at: r.updated_at,
109 task_count: r.task_count,
110 message_count: r.message_count,
111 last_message_at: r.last_message_at,
112 })
113 .collect())
114 }
115
116 pub async fn find_by_session_id(
117 &self,
118 session_id: &SessionId,
119 ) -> Result<Option<UserContext>, RepositoryError> {
120 let row = sqlx::query!(
121 r#"SELECT
122 context_id as "context_id!",
123 user_id as "user_id!",
124 name as "name!",
125 created_at as "created_at!",
126 updated_at as "updated_at!"
127 FROM user_contexts WHERE session_id = $1
128 ORDER BY created_at DESC LIMIT 1"#,
129 session_id.as_str()
130 )
131 .fetch_optional(&*self.pool)
132 .await
133 .map_err(RepositoryError::database)?;
134
135 Ok(row.map(|r| UserContext {
136 context_id: r.context_id.into(),
137 user_id: r.user_id.into(),
138 name: r.name,
139 created_at: r.created_at,
140 updated_at: r.updated_at,
141 }))
142 }
143
144 pub async fn get_context_events_since(
145 &self,
146 context_id: &ContextId,
147 last_seen: DateTime<Utc>,
148 ) -> Result<Vec<ContextStateEvent>, RepositoryError> {
149 let mut events = Vec::new();
150
151 let task_ids: Vec<String> = sqlx::query_scalar!(
152 r#"SELECT t.task_id as "task_id!" FROM agent_tasks t
153 WHERE t.context_id = $1 AND t.updated_at > $2
154 ORDER BY t.updated_at ASC"#,
155 context_id.as_str(),
156 last_seen
157 )
158 .fetch_all(&*self.pool)
159 .await
160 .map_err(|e| RepositoryError::database(e))?;
161
162 if !task_ids.is_empty() {
163 let constructor = TaskConstructor::new(&self.db_pool)?;
164 let task_ids_typed: Vec<TaskId> = task_ids.iter().map(|id| TaskId::new(id)).collect();
165 let tasks = constructor.construct_tasks_batch(&task_ids_typed).await?;
166
167 for task in tasks {
168 events.push(ContextStateEvent::TaskStatusChanged {
169 task,
170 context_id: context_id.clone(),
171 timestamp: Utc::now(),
172 });
173 }
174 }
175
176 let context_updates = sqlx::query!(
177 r#"SELECT
178 context_id as "context_id!",
179 name as "name!",
180 updated_at as "updated_at!"
181 FROM user_contexts
182 WHERE context_id = $1 AND updated_at > $2
183 ORDER BY updated_at ASC"#,
184 context_id.as_str(),
185 last_seen
186 )
187 .fetch_all(&*self.pool)
188 .await
189 .map_err(|e| RepositoryError::database(e))?;
190
191 for row in context_updates {
192 events.push(ContextStateEvent::ContextUpdated {
193 context_id: ContextId::new(row.context_id),
194 name: row.name,
195 timestamp: row.updated_at,
196 });
197 }
198
199 events.sort_by(|a, b| a.timestamp().cmp(&b.timestamp()));
200
201 Ok(events)
202 }
203}