systemprompt_agent/repository/task/
mod.rs1pub mod constructor;
2mod mutations;
3mod queries;
4mod task_messages;
5mod task_updates;
6
7pub use constructor::TaskConstructor;
8pub use mutations::{
9 create_task, task_state_to_db_string, track_agent_in_context, update_task_failed_with_error,
10 update_task_state,
11};
12pub use queries::{
13 get_task, get_task_context_info, get_tasks_by_user_id, list_tasks_by_context, TaskContextInfo,
14};
15
16use crate::models::a2a::{Task, TaskState};
17use sqlx::PgPool;
18use std::sync::Arc;
19use systemprompt_database::DbPool;
20use systemprompt_traits::{
21 DynFileUploadProvider, DynSessionAnalyticsProvider, Repository as RepositoryTrait,
22 RepositoryError,
23};
24
25#[derive(Clone)]
26pub struct TaskRepository {
27 db_pool: DbPool,
28 pub(crate) session_analytics_provider: Option<DynSessionAnalyticsProvider>,
29 pub(crate) file_upload_provider: Option<DynFileUploadProvider>,
30}
31
32impl std::fmt::Debug for TaskRepository {
33 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34 f.debug_struct("TaskRepository")
35 .field("db_pool", &"<DbPool>")
36 .field(
37 "session_analytics_provider",
38 &self.session_analytics_provider.is_some(),
39 )
40 .field("file_upload_provider", &self.file_upload_provider.is_some())
41 .finish()
42 }
43}
44
45impl TaskRepository {
46 #[must_use]
47 pub fn new(db_pool: DbPool) -> Self {
48 Self {
49 db_pool,
50 session_analytics_provider: None,
51 file_upload_provider: None,
52 }
53 }
54
55 #[must_use]
56 pub fn with_session_analytics_provider(
57 mut self,
58 provider: DynSessionAnalyticsProvider,
59 ) -> Self {
60 self.session_analytics_provider = Some(provider);
61 self
62 }
63
64 #[must_use]
65 pub fn with_file_upload_provider(mut self, provider: DynFileUploadProvider) -> Self {
66 self.file_upload_provider = Some(provider);
67 self
68 }
69
70 pub(crate) fn get_pg_pool(&self) -> Result<Arc<PgPool>, RepositoryError> {
71 self.db_pool.as_ref().get_postgres_pool().ok_or_else(|| {
72 RepositoryError::InvalidData("PostgreSQL pool not available".to_string())
73 })
74 }
75
76 pub async fn create_task(
77 &self,
78 task: &Task,
79 user_id: &systemprompt_identifiers::UserId,
80 session_id: &systemprompt_identifiers::SessionId,
81 trace_id: &systemprompt_identifiers::TraceId,
82 agent_name: &str,
83 ) -> Result<String, RepositoryError> {
84 let pool = self.get_pg_pool()?;
85 let result = create_task(&pool, task, user_id, session_id, trace_id, agent_name).await?;
86
87 if let Some(ref provider) = self.session_analytics_provider {
88 if let Err(e) = provider.increment_task_count(session_id).await {
89 tracing::warn!(error = %e, "Failed to increment analytics task count");
90 }
91 }
92
93 Ok(result)
94 }
95
96 pub async fn get_task(
97 &self,
98 task_id: &systemprompt_identifiers::TaskId,
99 ) -> Result<Option<Task>, RepositoryError> {
100 let pool = self.get_pg_pool()?;
101 get_task(&pool, &self.db_pool, task_id).await
102 }
103
104 pub async fn get_task_by_str(&self, task_id: &str) -> Result<Option<Task>, RepositoryError> {
105 let task_id_typed = systemprompt_identifiers::TaskId::new(task_id);
106 self.get_task(&task_id_typed).await
107 }
108
109 pub async fn list_tasks_by_context(
110 &self,
111 context_id: &systemprompt_identifiers::ContextId,
112 ) -> Result<Vec<Task>, RepositoryError> {
113 let pool = self.get_pg_pool()?;
114 list_tasks_by_context(&pool, &self.db_pool, context_id).await
115 }
116
117 pub async fn list_tasks_by_context_str(
118 &self,
119 context_id: &str,
120 ) -> Result<Vec<Task>, RepositoryError> {
121 let context_id_typed = systemprompt_identifiers::ContextId::new(context_id);
122 self.list_tasks_by_context(&context_id_typed).await
123 }
124
125 pub async fn get_tasks_by_user_id(
126 &self,
127 user_id: &systemprompt_identifiers::UserId,
128 limit: Option<i32>,
129 offset: Option<i32>,
130 ) -> Result<Vec<Task>, RepositoryError> {
131 let pool = self.get_pg_pool()?;
132 get_tasks_by_user_id(&pool, &self.db_pool, user_id, limit, offset).await
133 }
134
135 pub async fn get_tasks_by_user_id_str(
136 &self,
137 user_id: &str,
138 limit: Option<i32>,
139 offset: Option<i32>,
140 ) -> Result<Vec<Task>, RepositoryError> {
141 let user_id_typed = systemprompt_identifiers::UserId::new(user_id);
142 self.get_tasks_by_user_id(&user_id_typed, limit, offset)
143 .await
144 }
145
146 pub async fn track_agent_in_context(
147 &self,
148 context_id: &systemprompt_identifiers::ContextId,
149 agent_name: &str,
150 ) -> Result<(), RepositoryError> {
151 let pool = self.get_pg_pool()?;
152 track_agent_in_context(&pool, context_id, agent_name).await
153 }
154
155 pub async fn track_agent_in_context_str(
156 &self,
157 context_id: &str,
158 agent_name: &str,
159 ) -> Result<(), RepositoryError> {
160 let context_id_typed = systemprompt_identifiers::ContextId::new(context_id);
161 self.track_agent_in_context(&context_id_typed, agent_name)
162 .await
163 }
164
165 pub async fn update_task_state(
166 &self,
167 task_id: &systemprompt_identifiers::TaskId,
168 state: TaskState,
169 timestamp: &chrono::DateTime<chrono::Utc>,
170 ) -> Result<(), RepositoryError> {
171 let pool = self.get_pg_pool()?;
172 update_task_state(&pool, task_id, state, timestamp).await
173 }
174
175 pub async fn update_task_state_str(
176 &self,
177 task_id: &str,
178 state: TaskState,
179 timestamp: &chrono::DateTime<chrono::Utc>,
180 ) -> Result<(), RepositoryError> {
181 let task_id_typed = systemprompt_identifiers::TaskId::new(task_id);
182 self.update_task_state(&task_id_typed, state, timestamp)
183 .await
184 }
185
186 pub async fn update_task_failed_with_error(
187 &self,
188 task_id: &systemprompt_identifiers::TaskId,
189 error_message: &str,
190 timestamp: &chrono::DateTime<chrono::Utc>,
191 ) -> Result<(), RepositoryError> {
192 let pool = self.get_pg_pool()?;
193 update_task_failed_with_error(&pool, task_id, error_message, timestamp).await
194 }
195
196 pub async fn get_task_context_info(
197 &self,
198 task_id: &systemprompt_identifiers::TaskId,
199 ) -> Result<Option<TaskContextInfo>, RepositoryError> {
200 let pool = self.get_pg_pool()?;
201 get_task_context_info(&pool, task_id).await
202 }
203}
204
205impl RepositoryTrait for TaskRepository {
206 type Pool = DbPool;
207 type Error = RepositoryError;
208
209 fn pool(&self) -> &Self::Pool {
210 &self.db_pool
211 }
212}