Skip to main content

systemprompt_agent/repository/task/
mod.rs

1pub mod constructor;
2mod mutations;
3mod queries;
4mod task_messages;
5mod task_updates;
6
7pub use constructor::TaskConstructor;
8pub use mutations::{
9    CreateTaskParams, create_task, task_state_to_db_string, track_agent_in_context,
10    update_task_failed_with_error, update_task_state,
11};
12pub use queries::{
13    TaskContextInfo, get_task, get_task_context_info, get_tasks_by_user_id, list_tasks_by_context,
14};
15pub use task_updates::UpdateTaskAndSaveMessagesParams;
16
17use crate::models::a2a::{Task, TaskState};
18use sqlx::PgPool;
19use std::sync::Arc;
20use systemprompt_database::DbPool;
21use systemprompt_identifiers::{SessionId, TraceId, UserId};
22use systemprompt_traits::{DynFileUploadProvider, DynSessionAnalyticsProvider, RepositoryError};
23
24#[allow(missing_debug_implementations)]
25pub struct RepoCreateTaskParams<'a> {
26    pub task: &'a Task,
27    pub user_id: &'a UserId,
28    pub session_id: &'a SessionId,
29    pub trace_id: &'a TraceId,
30    pub agent_name: &'a str,
31}
32
33#[derive(Clone)]
34pub struct TaskRepository {
35    pool: Arc<PgPool>,
36    write_pool: Arc<PgPool>,
37    db_pool: DbPool,
38    pub(crate) session_analytics_provider: Option<DynSessionAnalyticsProvider>,
39    pub(crate) file_upload_provider: Option<DynFileUploadProvider>,
40}
41
42impl std::fmt::Debug for TaskRepository {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        f.debug_struct("TaskRepository")
45            .field("pool", &"<PgPool>")
46            .field("write_pool", &"<PgPool>")
47            .field("db_pool", &"<DbPool>")
48            .field(
49                "session_analytics_provider",
50                &self.session_analytics_provider.is_some(),
51            )
52            .field("file_upload_provider", &self.file_upload_provider.is_some())
53            .finish()
54    }
55}
56
57impl TaskRepository {
58    pub fn new(db: &DbPool) -> anyhow::Result<Self> {
59        let pool = db.pool_arc()?;
60        let write_pool = db.write_pool_arc()?;
61        Ok(Self {
62            pool,
63            write_pool,
64            db_pool: Arc::clone(db),
65            session_analytics_provider: None,
66            file_upload_provider: None,
67        })
68    }
69
70    #[must_use]
71    pub fn with_session_analytics_provider(
72        mut self,
73        provider: DynSessionAnalyticsProvider,
74    ) -> Self {
75        self.session_analytics_provider = Some(provider);
76        self
77    }
78
79    #[must_use]
80    pub fn with_file_upload_provider(mut self, provider: DynFileUploadProvider) -> Self {
81        self.file_upload_provider = Some(provider);
82        self
83    }
84
85    pub(crate) const fn db_pool(&self) -> &DbPool {
86        &self.db_pool
87    }
88
89    pub async fn create_task(
90        &self,
91        params: RepoCreateTaskParams<'_>,
92    ) -> Result<String, RepositoryError> {
93        let result = create_task(CreateTaskParams {
94            pool: &self.write_pool,
95            task: params.task,
96            user_id: params.user_id,
97            session_id: params.session_id,
98            trace_id: params.trace_id,
99            agent_name: params.agent_name,
100        })
101        .await?;
102
103        if let Some(ref provider) = self.session_analytics_provider {
104            if let Err(e) = provider.increment_task_count(params.session_id).await {
105                tracing::warn!(error = %e, "Failed to increment analytics task count");
106            }
107        }
108
109        Ok(result)
110    }
111
112    pub async fn get_task(
113        &self,
114        task_id: &systemprompt_identifiers::TaskId,
115    ) -> Result<Option<Task>, RepositoryError> {
116        get_task(&self.pool, &self.db_pool, task_id).await
117    }
118
119    pub async fn list_tasks_by_context(
120        &self,
121        context_id: &systemprompt_identifiers::ContextId,
122    ) -> Result<Vec<Task>, RepositoryError> {
123        list_tasks_by_context(&self.pool, &self.db_pool, context_id).await
124    }
125
126    pub async fn get_tasks_by_user_id(
127        &self,
128        user_id: &UserId,
129        limit: Option<i32>,
130        offset: Option<i32>,
131    ) -> Result<Vec<Task>, RepositoryError> {
132        get_tasks_by_user_id(&self.pool, &self.db_pool, user_id, limit, offset).await
133    }
134
135    pub async fn track_agent_in_context(
136        &self,
137        context_id: &systemprompt_identifiers::ContextId,
138        agent_name: &str,
139    ) -> Result<(), RepositoryError> {
140        track_agent_in_context(&self.write_pool, context_id, agent_name).await
141    }
142
143    pub async fn update_task_state(
144        &self,
145        task_id: &systemprompt_identifiers::TaskId,
146        state: TaskState,
147        timestamp: &chrono::DateTime<chrono::Utc>,
148    ) -> Result<(), RepositoryError> {
149        update_task_state(&self.write_pool, task_id, state, timestamp).await
150    }
151
152    pub async fn update_task_failed_with_error(
153        &self,
154        task_id: &systemprompt_identifiers::TaskId,
155        error_message: &str,
156        timestamp: &chrono::DateTime<chrono::Utc>,
157    ) -> Result<(), RepositoryError> {
158        update_task_failed_with_error(&self.write_pool, task_id, error_message, timestamp).await
159    }
160
161    pub async fn get_task_context_info(
162        &self,
163        task_id: &systemprompt_identifiers::TaskId,
164    ) -> Result<Option<TaskContextInfo>, RepositoryError> {
165        get_task_context_info(&self.pool, task_id).await
166    }
167}