Skip to main content

systemprompt_agent/repository/task/
mod.rs

1pub mod constructor;
2mod mutations;
3mod queries;
4mod task_messages;
5mod task_updates;
6
7pub use constructor::TaskConstructor;
8pub use mutations::{
9    create_task, task_state_to_db_string, track_agent_in_context, update_task_failed_with_error,
10    update_task_state,
11};
12pub use queries::{
13    get_task, get_task_context_info, get_tasks_by_user_id, list_tasks_by_context, TaskContextInfo,
14};
15
16use crate::models::a2a::{Task, TaskState};
17use sqlx::PgPool;
18use std::sync::Arc;
19use systemprompt_database::DbPool;
20use systemprompt_traits::{
21    DynFileUploadProvider, DynSessionAnalyticsProvider, Repository as RepositoryTrait,
22    RepositoryError,
23};
24
25#[derive(Clone)]
26pub struct TaskRepository {
27    db_pool: DbPool,
28    pub(crate) session_analytics_provider: Option<DynSessionAnalyticsProvider>,
29    pub(crate) file_upload_provider: Option<DynFileUploadProvider>,
30}
31
32impl std::fmt::Debug for TaskRepository {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        f.debug_struct("TaskRepository")
35            .field("db_pool", &"<DbPool>")
36            .field(
37                "session_analytics_provider",
38                &self.session_analytics_provider.is_some(),
39            )
40            .field("file_upload_provider", &self.file_upload_provider.is_some())
41            .finish()
42    }
43}
44
45impl TaskRepository {
46    #[must_use]
47    pub fn new(db_pool: DbPool) -> Self {
48        Self {
49            db_pool,
50            session_analytics_provider: None,
51            file_upload_provider: None,
52        }
53    }
54
55    #[must_use]
56    pub fn with_session_analytics_provider(
57        mut self,
58        provider: DynSessionAnalyticsProvider,
59    ) -> Self {
60        self.session_analytics_provider = Some(provider);
61        self
62    }
63
64    #[must_use]
65    pub fn with_file_upload_provider(mut self, provider: DynFileUploadProvider) -> Self {
66        self.file_upload_provider = Some(provider);
67        self
68    }
69
70    pub(crate) fn get_pg_pool(&self) -> Result<Arc<PgPool>, RepositoryError> {
71        self.db_pool.as_ref().get_postgres_pool().ok_or_else(|| {
72            RepositoryError::InvalidData("PostgreSQL pool not available".to_string())
73        })
74    }
75
76    pub async fn create_task(
77        &self,
78        task: &Task,
79        user_id: &systemprompt_identifiers::UserId,
80        session_id: &systemprompt_identifiers::SessionId,
81        trace_id: &systemprompt_identifiers::TraceId,
82        agent_name: &str,
83    ) -> Result<String, RepositoryError> {
84        let pool = self.get_pg_pool()?;
85        let result = create_task(&pool, task, user_id, session_id, trace_id, agent_name).await?;
86
87        if let Some(ref provider) = self.session_analytics_provider {
88            if let Err(e) = provider.increment_task_count(session_id).await {
89                tracing::warn!(error = %e, "Failed to increment analytics task count");
90            }
91        }
92
93        Ok(result)
94    }
95
96    pub async fn get_task(
97        &self,
98        task_id: &systemprompt_identifiers::TaskId,
99    ) -> Result<Option<Task>, RepositoryError> {
100        let pool = self.get_pg_pool()?;
101        get_task(&pool, &self.db_pool, task_id).await
102    }
103
104    pub async fn get_task_by_str(&self, task_id: &str) -> Result<Option<Task>, RepositoryError> {
105        let task_id_typed = systemprompt_identifiers::TaskId::new(task_id);
106        self.get_task(&task_id_typed).await
107    }
108
109    pub async fn list_tasks_by_context(
110        &self,
111        context_id: &systemprompt_identifiers::ContextId,
112    ) -> Result<Vec<Task>, RepositoryError> {
113        let pool = self.get_pg_pool()?;
114        list_tasks_by_context(&pool, &self.db_pool, context_id).await
115    }
116
117    pub async fn list_tasks_by_context_str(
118        &self,
119        context_id: &str,
120    ) -> Result<Vec<Task>, RepositoryError> {
121        let context_id_typed = systemprompt_identifiers::ContextId::new(context_id);
122        self.list_tasks_by_context(&context_id_typed).await
123    }
124
125    pub async fn get_tasks_by_user_id(
126        &self,
127        user_id: &systemprompt_identifiers::UserId,
128        limit: Option<i32>,
129        offset: Option<i32>,
130    ) -> Result<Vec<Task>, RepositoryError> {
131        let pool = self.get_pg_pool()?;
132        get_tasks_by_user_id(&pool, &self.db_pool, user_id, limit, offset).await
133    }
134
135    pub async fn get_tasks_by_user_id_str(
136        &self,
137        user_id: &str,
138        limit: Option<i32>,
139        offset: Option<i32>,
140    ) -> Result<Vec<Task>, RepositoryError> {
141        let user_id_typed = systemprompt_identifiers::UserId::new(user_id);
142        self.get_tasks_by_user_id(&user_id_typed, limit, offset)
143            .await
144    }
145
146    pub async fn track_agent_in_context(
147        &self,
148        context_id: &systemprompt_identifiers::ContextId,
149        agent_name: &str,
150    ) -> Result<(), RepositoryError> {
151        let pool = self.get_pg_pool()?;
152        track_agent_in_context(&pool, context_id, agent_name).await
153    }
154
155    pub async fn track_agent_in_context_str(
156        &self,
157        context_id: &str,
158        agent_name: &str,
159    ) -> Result<(), RepositoryError> {
160        let context_id_typed = systemprompt_identifiers::ContextId::new(context_id);
161        self.track_agent_in_context(&context_id_typed, agent_name)
162            .await
163    }
164
165    pub async fn update_task_state(
166        &self,
167        task_id: &systemprompt_identifiers::TaskId,
168        state: TaskState,
169        timestamp: &chrono::DateTime<chrono::Utc>,
170    ) -> Result<(), RepositoryError> {
171        let pool = self.get_pg_pool()?;
172        update_task_state(&pool, task_id, state, timestamp).await
173    }
174
175    pub async fn update_task_state_str(
176        &self,
177        task_id: &str,
178        state: TaskState,
179        timestamp: &chrono::DateTime<chrono::Utc>,
180    ) -> Result<(), RepositoryError> {
181        let task_id_typed = systemprompt_identifiers::TaskId::new(task_id);
182        self.update_task_state(&task_id_typed, state, timestamp)
183            .await
184    }
185
186    pub async fn update_task_failed_with_error(
187        &self,
188        task_id: &systemprompt_identifiers::TaskId,
189        error_message: &str,
190        timestamp: &chrono::DateTime<chrono::Utc>,
191    ) -> Result<(), RepositoryError> {
192        let pool = self.get_pg_pool()?;
193        update_task_failed_with_error(&pool, task_id, error_message, timestamp).await
194    }
195
196    pub async fn get_task_context_info(
197        &self,
198        task_id: &systemprompt_identifiers::TaskId,
199    ) -> Result<Option<TaskContextInfo>, RepositoryError> {
200        let pool = self.get_pg_pool()?;
201        get_task_context_info(&pool, task_id).await
202    }
203}
204
205impl RepositoryTrait for TaskRepository {
206    type Pool = DbPool;
207    type Error = RepositoryError;
208
209    fn pool(&self) -> &Self::Pool {
210        &self.db_pool
211    }
212}