systemprompt_agent/repository/task/
mod.rs1pub mod constructor;
2mod mutations;
3mod queries;
4mod task_messages;
5mod task_updates;
6
7pub use constructor::TaskConstructor;
8pub use mutations::{
9 create_task, task_state_to_db_string, track_agent_in_context, update_task_failed_with_error,
10 update_task_state,
11};
12pub use queries::{
13 get_task, get_task_context_info, get_tasks_by_user_id, list_tasks_by_context, TaskContextInfo,
14};
15
16use crate::models::a2a::{Task, TaskState};
17use sqlx::PgPool;
18use std::sync::Arc;
19use systemprompt_database::DbPool;
20use systemprompt_traits::{DynFileUploadProvider, DynSessionAnalyticsProvider, RepositoryError};
21
22#[derive(Clone)]
23pub struct TaskRepository {
24 pool: Arc<PgPool>,
25 write_pool: Arc<PgPool>,
26 db_pool: DbPool,
27 pub(crate) session_analytics_provider: Option<DynSessionAnalyticsProvider>,
28 pub(crate) file_upload_provider: Option<DynFileUploadProvider>,
29}
30
31impl std::fmt::Debug for TaskRepository {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 f.debug_struct("TaskRepository")
34 .field("db_pool", &"<DbPool>")
35 .field(
36 "session_analytics_provider",
37 &self.session_analytics_provider.is_some(),
38 )
39 .field("file_upload_provider", &self.file_upload_provider.is_some())
40 .finish()
41 }
42}
43
44impl TaskRepository {
45 pub fn new(db: &DbPool) -> anyhow::Result<Self> {
46 let pool = db.pool_arc()?;
47 let write_pool = db.write_pool_arc()?;
48 Ok(Self {
49 pool,
50 write_pool,
51 db_pool: db.clone(),
52 session_analytics_provider: None,
53 file_upload_provider: None,
54 })
55 }
56
57 #[must_use]
58 pub fn with_session_analytics_provider(
59 mut self,
60 provider: DynSessionAnalyticsProvider,
61 ) -> Self {
62 self.session_analytics_provider = Some(provider);
63 self
64 }
65
66 #[must_use]
67 pub fn with_file_upload_provider(mut self, provider: DynFileUploadProvider) -> Self {
68 self.file_upload_provider = Some(provider);
69 self
70 }
71
72 pub(crate) fn db_pool(&self) -> &DbPool {
73 &self.db_pool
74 }
75
76 pub async fn create_task(
77 &self,
78 task: &Task,
79 user_id: &systemprompt_identifiers::UserId,
80 session_id: &systemprompt_identifiers::SessionId,
81 trace_id: &systemprompt_identifiers::TraceId,
82 agent_name: &str,
83 ) -> Result<String, RepositoryError> {
84 let result = create_task(
85 &self.write_pool,
86 task,
87 user_id,
88 session_id,
89 trace_id,
90 agent_name,
91 )
92 .await?;
93
94 if let Some(ref provider) = self.session_analytics_provider {
95 if let Err(e) = provider.increment_task_count(session_id).await {
96 tracing::warn!(error = %e, "Failed to increment analytics task count");
97 }
98 }
99
100 Ok(result)
101 }
102
103 pub async fn get_task(
104 &self,
105 task_id: &systemprompt_identifiers::TaskId,
106 ) -> Result<Option<Task>, RepositoryError> {
107 get_task(&self.pool, &self.db_pool, task_id).await
108 }
109
110 pub async fn get_task_by_str(&self, task_id: &str) -> Result<Option<Task>, RepositoryError> {
111 let task_id_typed = systemprompt_identifiers::TaskId::new(task_id);
112 self.get_task(&task_id_typed).await
113 }
114
115 pub async fn list_tasks_by_context(
116 &self,
117 context_id: &systemprompt_identifiers::ContextId,
118 ) -> Result<Vec<Task>, RepositoryError> {
119 list_tasks_by_context(&self.pool, &self.db_pool, context_id).await
120 }
121
122 pub async fn list_tasks_by_context_str(
123 &self,
124 context_id: &str,
125 ) -> Result<Vec<Task>, RepositoryError> {
126 let context_id_typed = systemprompt_identifiers::ContextId::new(context_id);
127 self.list_tasks_by_context(&context_id_typed).await
128 }
129
130 pub async fn get_tasks_by_user_id(
131 &self,
132 user_id: &systemprompt_identifiers::UserId,
133 limit: Option<i32>,
134 offset: Option<i32>,
135 ) -> Result<Vec<Task>, RepositoryError> {
136 get_tasks_by_user_id(&self.pool, &self.db_pool, user_id, limit, offset).await
137 }
138
139 pub async fn get_tasks_by_user_id_str(
140 &self,
141 user_id: &str,
142 limit: Option<i32>,
143 offset: Option<i32>,
144 ) -> Result<Vec<Task>, RepositoryError> {
145 let user_id_typed = systemprompt_identifiers::UserId::new(user_id);
146 self.get_tasks_by_user_id(&user_id_typed, limit, offset)
147 .await
148 }
149
150 pub async fn track_agent_in_context(
151 &self,
152 context_id: &systemprompt_identifiers::ContextId,
153 agent_name: &str,
154 ) -> Result<(), RepositoryError> {
155 track_agent_in_context(&self.write_pool, context_id, agent_name).await
156 }
157
158 pub async fn track_agent_in_context_str(
159 &self,
160 context_id: &str,
161 agent_name: &str,
162 ) -> Result<(), RepositoryError> {
163 let context_id_typed = systemprompt_identifiers::ContextId::new(context_id);
164 self.track_agent_in_context(&context_id_typed, agent_name)
165 .await
166 }
167
168 pub async fn update_task_state(
169 &self,
170 task_id: &systemprompt_identifiers::TaskId,
171 state: TaskState,
172 timestamp: &chrono::DateTime<chrono::Utc>,
173 ) -> Result<(), RepositoryError> {
174 update_task_state(&self.write_pool, task_id, state, timestamp).await
175 }
176
177 pub async fn update_task_state_str(
178 &self,
179 task_id: &str,
180 state: TaskState,
181 timestamp: &chrono::DateTime<chrono::Utc>,
182 ) -> Result<(), RepositoryError> {
183 let task_id_typed = systemprompt_identifiers::TaskId::new(task_id);
184 self.update_task_state(&task_id_typed, state, timestamp)
185 .await
186 }
187
188 pub async fn update_task_failed_with_error(
189 &self,
190 task_id: &systemprompt_identifiers::TaskId,
191 error_message: &str,
192 timestamp: &chrono::DateTime<chrono::Utc>,
193 ) -> Result<(), RepositoryError> {
194 update_task_failed_with_error(&self.write_pool, task_id, error_message, timestamp).await
195 }
196
197 pub async fn get_task_context_info(
198 &self,
199 task_id: &systemprompt_identifiers::TaskId,
200 ) -> Result<Option<TaskContextInfo>, RepositoryError> {
201 get_task_context_info(&self.pool, task_id).await
202 }
203}