systemprompt_agent/repository/task/
mod.rs1pub mod constructor;
2mod mutations;
3mod queries;
4mod task_messages;
5mod task_updates;
6
7pub use constructor::TaskConstructor;
8pub use mutations::{
9 CreateTaskParams, create_task, task_state_to_db_string, track_agent_in_context,
10 update_task_failed_with_error, update_task_state,
11};
12pub use queries::{
13 TaskContextInfo, get_task, get_task_context_info, get_tasks_by_user_id, list_tasks_by_context,
14};
15pub use task_updates::UpdateTaskAndSaveMessagesParams;
16
17use crate::models::a2a::{Task, TaskState};
18use sqlx::PgPool;
19use std::sync::Arc;
20use systemprompt_database::DbPool;
21use systemprompt_identifiers::{SessionId, TraceId, UserId};
22use systemprompt_traits::{DynFileUploadProvider, DynSessionAnalyticsProvider, RepositoryError};
23
24#[allow(missing_debug_implementations)]
25pub struct RepoCreateTaskParams<'a> {
26 pub task: &'a Task,
27 pub user_id: &'a UserId,
28 pub session_id: &'a SessionId,
29 pub trace_id: &'a TraceId,
30 pub agent_name: &'a str,
31}
32
/// Repository for persisting and querying tasks (`crate::models::a2a::Task`).
///
/// The pools and the database handle are `Arc`s, so cloning them only bumps
/// reference counts. The optional providers are cloned through their own
/// `Clone` impls (presumably also reference-counted — confirm).
#[derive(Clone)]
pub struct TaskRepository {
    // Pool used by the read-side helpers (`get_task`, `list_tasks_by_context`,
    // `get_tasks_by_user_id`, `get_task_context_info`).
    pool: Arc<PgPool>,
    // Pool used by the mutating helpers (`create_task`, `track_agent_in_context`,
    // `update_task_state`, `update_task_failed_with_error`).
    write_pool: Arc<PgPool>,
    // Database handle both pools were taken from; passed to the query helpers
    // alongside `pool` and exposed crate-wide through `db_pool()`.
    db_pool: DbPool,
    // When set, `create_task` increments the session's task count through it
    // (best-effort: failures are only logged).
    pub(crate) session_analytics_provider: Option<DynSessionAnalyticsProvider>,
    // Not read in this file; `pub(crate)` so other modules of the crate can use
    // it. NOTE(review): presumably consumed by the message-saving code — confirm.
    pub(crate) file_upload_provider: Option<DynFileUploadProvider>,
}
41
42impl std::fmt::Debug for TaskRepository {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 f.debug_struct("TaskRepository")
45 .field("pool", &"<PgPool>")
46 .field("write_pool", &"<PgPool>")
47 .field("db_pool", &"<DbPool>")
48 .field(
49 "session_analytics_provider",
50 &self.session_analytics_provider.is_some(),
51 )
52 .field("file_upload_provider", &self.file_upload_provider.is_some())
53 .finish()
54 }
55}
56
57impl TaskRepository {
58 pub fn new(db: &DbPool) -> anyhow::Result<Self> {
59 let pool = db.pool_arc()?;
60 let write_pool = db.write_pool_arc()?;
61 Ok(Self {
62 pool,
63 write_pool,
64 db_pool: Arc::clone(db),
65 session_analytics_provider: None,
66 file_upload_provider: None,
67 })
68 }
69
70 #[must_use]
71 pub fn with_session_analytics_provider(
72 mut self,
73 provider: DynSessionAnalyticsProvider,
74 ) -> Self {
75 self.session_analytics_provider = Some(provider);
76 self
77 }
78
79 #[must_use]
80 pub fn with_file_upload_provider(mut self, provider: DynFileUploadProvider) -> Self {
81 self.file_upload_provider = Some(provider);
82 self
83 }
84
85 pub(crate) const fn db_pool(&self) -> &DbPool {
86 &self.db_pool
87 }
88
89 pub async fn create_task(
90 &self,
91 params: RepoCreateTaskParams<'_>,
92 ) -> Result<String, RepositoryError> {
93 let result = create_task(CreateTaskParams {
94 pool: &self.write_pool,
95 task: params.task,
96 user_id: params.user_id,
97 session_id: params.session_id,
98 trace_id: params.trace_id,
99 agent_name: params.agent_name,
100 })
101 .await?;
102
103 if let Some(ref provider) = self.session_analytics_provider {
104 if let Err(e) = provider.increment_task_count(params.session_id).await {
105 tracing::warn!(error = %e, "Failed to increment analytics task count");
106 }
107 }
108
109 Ok(result)
110 }
111
112 pub async fn get_task(
113 &self,
114 task_id: &systemprompt_identifiers::TaskId,
115 ) -> Result<Option<Task>, RepositoryError> {
116 get_task(&self.pool, &self.db_pool, task_id).await
117 }
118
119 pub async fn get_task_by_str(&self, task_id: &str) -> Result<Option<Task>, RepositoryError> {
120 let task_id_typed = systemprompt_identifiers::TaskId::new(task_id);
121 self.get_task(&task_id_typed).await
122 }
123
124 pub async fn list_tasks_by_context(
125 &self,
126 context_id: &systemprompt_identifiers::ContextId,
127 ) -> Result<Vec<Task>, RepositoryError> {
128 list_tasks_by_context(&self.pool, &self.db_pool, context_id).await
129 }
130
131 pub async fn list_tasks_by_context_str(
132 &self,
133 context_id: &str,
134 ) -> Result<Vec<Task>, RepositoryError> {
135 let context_id_typed = systemprompt_identifiers::ContextId::new(context_id);
136 self.list_tasks_by_context(&context_id_typed).await
137 }
138
139 pub async fn get_tasks_by_user_id(
140 &self,
141 user_id: &UserId,
142 limit: Option<i32>,
143 offset: Option<i32>,
144 ) -> Result<Vec<Task>, RepositoryError> {
145 get_tasks_by_user_id(&self.pool, &self.db_pool, user_id, limit, offset).await
146 }
147
148 pub async fn get_tasks_by_user_id_str(
149 &self,
150 user_id: &str,
151 limit: Option<i32>,
152 offset: Option<i32>,
153 ) -> Result<Vec<Task>, RepositoryError> {
154 let user_id_typed = UserId::new(user_id);
155 self.get_tasks_by_user_id(&user_id_typed, limit, offset)
156 .await
157 }
158
159 pub async fn track_agent_in_context(
160 &self,
161 context_id: &systemprompt_identifiers::ContextId,
162 agent_name: &str,
163 ) -> Result<(), RepositoryError> {
164 track_agent_in_context(&self.write_pool, context_id, agent_name).await
165 }
166
167 pub async fn track_agent_in_context_str(
168 &self,
169 context_id: &str,
170 agent_name: &str,
171 ) -> Result<(), RepositoryError> {
172 let context_id_typed = systemprompt_identifiers::ContextId::new(context_id);
173 self.track_agent_in_context(&context_id_typed, agent_name)
174 .await
175 }
176
177 pub async fn update_task_state(
178 &self,
179 task_id: &systemprompt_identifiers::TaskId,
180 state: TaskState,
181 timestamp: &chrono::DateTime<chrono::Utc>,
182 ) -> Result<(), RepositoryError> {
183 update_task_state(&self.write_pool, task_id, state, timestamp).await
184 }
185
186 pub async fn update_task_state_str(
187 &self,
188 task_id: &str,
189 state: TaskState,
190 timestamp: &chrono::DateTime<chrono::Utc>,
191 ) -> Result<(), RepositoryError> {
192 let task_id_typed = systemprompt_identifiers::TaskId::new(task_id);
193 self.update_task_state(&task_id_typed, state, timestamp)
194 .await
195 }
196
197 pub async fn update_task_failed_with_error(
198 &self,
199 task_id: &systemprompt_identifiers::TaskId,
200 error_message: &str,
201 timestamp: &chrono::DateTime<chrono::Utc>,
202 ) -> Result<(), RepositoryError> {
203 update_task_failed_with_error(&self.write_pool, task_id, error_message, timestamp).await
204 }
205
206 pub async fn get_task_context_info(
207 &self,
208 task_id: &systemprompt_identifiers::TaskId,
209 ) -> Result<Option<TaskContextInfo>, RepositoryError> {
210 get_task_context_info(&self.pool, task_id).await
211 }
212}