Skip to main content

systemprompt_agent/repository/task/
mod.rs

1pub mod constructor;
2mod mutations;
3mod queries;
4mod task_messages;
5mod task_updates;
6
7pub use constructor::TaskConstructor;
8pub use mutations::{
9    create_task, task_state_to_db_string, track_agent_in_context, update_task_failed_with_error,
10    update_task_state,
11};
12pub use queries::{
13    get_task, get_task_context_info, get_tasks_by_user_id, list_tasks_by_context, TaskContextInfo,
14};
15
16use crate::models::a2a::{Task, TaskState};
17use sqlx::PgPool;
18use std::sync::Arc;
19use systemprompt_database::DbPool;
20use systemprompt_traits::{DynFileUploadProvider, DynSessionAnalyticsProvider, RepositoryError};
21
22#[derive(Clone)]
23pub struct TaskRepository {
24    pool: Arc<PgPool>,
25    write_pool: Arc<PgPool>,
26    db_pool: DbPool,
27    pub(crate) session_analytics_provider: Option<DynSessionAnalyticsProvider>,
28    pub(crate) file_upload_provider: Option<DynFileUploadProvider>,
29}
30
31impl std::fmt::Debug for TaskRepository {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        f.debug_struct("TaskRepository")
34            .field("db_pool", &"<DbPool>")
35            .field(
36                "session_analytics_provider",
37                &self.session_analytics_provider.is_some(),
38            )
39            .field("file_upload_provider", &self.file_upload_provider.is_some())
40            .finish()
41    }
42}
43
44impl TaskRepository {
45    pub fn new(db: &DbPool) -> anyhow::Result<Self> {
46        let pool = db.pool_arc()?;
47        let write_pool = db.write_pool_arc()?;
48        Ok(Self {
49            pool,
50            write_pool,
51            db_pool: db.clone(),
52            session_analytics_provider: None,
53            file_upload_provider: None,
54        })
55    }
56
57    #[must_use]
58    pub fn with_session_analytics_provider(
59        mut self,
60        provider: DynSessionAnalyticsProvider,
61    ) -> Self {
62        self.session_analytics_provider = Some(provider);
63        self
64    }
65
66    #[must_use]
67    pub fn with_file_upload_provider(mut self, provider: DynFileUploadProvider) -> Self {
68        self.file_upload_provider = Some(provider);
69        self
70    }
71
72    pub(crate) fn db_pool(&self) -> &DbPool {
73        &self.db_pool
74    }
75
76    pub async fn create_task(
77        &self,
78        task: &Task,
79        user_id: &systemprompt_identifiers::UserId,
80        session_id: &systemprompt_identifiers::SessionId,
81        trace_id: &systemprompt_identifiers::TraceId,
82        agent_name: &str,
83    ) -> Result<String, RepositoryError> {
84        let result = create_task(
85            &self.write_pool,
86            task,
87            user_id,
88            session_id,
89            trace_id,
90            agent_name,
91        )
92        .await?;
93
94        if let Some(ref provider) = self.session_analytics_provider {
95            if let Err(e) = provider.increment_task_count(session_id).await {
96                tracing::warn!(error = %e, "Failed to increment analytics task count");
97            }
98        }
99
100        Ok(result)
101    }
102
103    pub async fn get_task(
104        &self,
105        task_id: &systemprompt_identifiers::TaskId,
106    ) -> Result<Option<Task>, RepositoryError> {
107        get_task(&self.pool, &self.db_pool, task_id).await
108    }
109
110    pub async fn get_task_by_str(&self, task_id: &str) -> Result<Option<Task>, RepositoryError> {
111        let task_id_typed = systemprompt_identifiers::TaskId::new(task_id);
112        self.get_task(&task_id_typed).await
113    }
114
115    pub async fn list_tasks_by_context(
116        &self,
117        context_id: &systemprompt_identifiers::ContextId,
118    ) -> Result<Vec<Task>, RepositoryError> {
119        list_tasks_by_context(&self.pool, &self.db_pool, context_id).await
120    }
121
122    pub async fn list_tasks_by_context_str(
123        &self,
124        context_id: &str,
125    ) -> Result<Vec<Task>, RepositoryError> {
126        let context_id_typed = systemprompt_identifiers::ContextId::new(context_id);
127        self.list_tasks_by_context(&context_id_typed).await
128    }
129
130    pub async fn get_tasks_by_user_id(
131        &self,
132        user_id: &systemprompt_identifiers::UserId,
133        limit: Option<i32>,
134        offset: Option<i32>,
135    ) -> Result<Vec<Task>, RepositoryError> {
136        get_tasks_by_user_id(&self.pool, &self.db_pool, user_id, limit, offset).await
137    }
138
139    pub async fn get_tasks_by_user_id_str(
140        &self,
141        user_id: &str,
142        limit: Option<i32>,
143        offset: Option<i32>,
144    ) -> Result<Vec<Task>, RepositoryError> {
145        let user_id_typed = systemprompt_identifiers::UserId::new(user_id);
146        self.get_tasks_by_user_id(&user_id_typed, limit, offset)
147            .await
148    }
149
150    pub async fn track_agent_in_context(
151        &self,
152        context_id: &systemprompt_identifiers::ContextId,
153        agent_name: &str,
154    ) -> Result<(), RepositoryError> {
155        track_agent_in_context(&self.write_pool, context_id, agent_name).await
156    }
157
158    pub async fn track_agent_in_context_str(
159        &self,
160        context_id: &str,
161        agent_name: &str,
162    ) -> Result<(), RepositoryError> {
163        let context_id_typed = systemprompt_identifiers::ContextId::new(context_id);
164        self.track_agent_in_context(&context_id_typed, agent_name)
165            .await
166    }
167
168    pub async fn update_task_state(
169        &self,
170        task_id: &systemprompt_identifiers::TaskId,
171        state: TaskState,
172        timestamp: &chrono::DateTime<chrono::Utc>,
173    ) -> Result<(), RepositoryError> {
174        update_task_state(&self.write_pool, task_id, state, timestamp).await
175    }
176
177    pub async fn update_task_state_str(
178        &self,
179        task_id: &str,
180        state: TaskState,
181        timestamp: &chrono::DateTime<chrono::Utc>,
182    ) -> Result<(), RepositoryError> {
183        let task_id_typed = systemprompt_identifiers::TaskId::new(task_id);
184        self.update_task_state(&task_id_typed, state, timestamp)
185            .await
186    }
187
188    pub async fn update_task_failed_with_error(
189        &self,
190        task_id: &systemprompt_identifiers::TaskId,
191        error_message: &str,
192        timestamp: &chrono::DateTime<chrono::Utc>,
193    ) -> Result<(), RepositoryError> {
194        update_task_failed_with_error(&self.write_pool, task_id, error_message, timestamp).await
195    }
196
197    pub async fn get_task_context_info(
198        &self,
199        task_id: &systemprompt_identifiers::TaskId,
200    ) -> Result<Option<TaskContextInfo>, RepositoryError> {
201        get_task_context_info(&self.pool, task_id).await
202    }
203}