Skip to main content

systemprompt_agent/repository/task/
task_updates.rs

1use super::{TaskRepository, task_state_to_db_string};
2use crate::models::a2a::{Message, Task, TaskState};
3use crate::repository::context::message::{
4    FileUploadContext, PersistMessageSqlxParams, get_next_sequence_number_sqlx,
5    persist_message_sqlx,
6};
7use systemprompt_traits::RepositoryError;
8
9#[allow(missing_debug_implementations)]
10pub struct UpdateTaskAndSaveMessagesParams<'a> {
11    pub task: &'a Task,
12    pub user_message: &'a Message,
13    pub agent_message: &'a Message,
14    pub user_id: Option<&'a systemprompt_identifiers::UserId>,
15    pub session_id: &'a systemprompt_identifiers::SessionId,
16    pub trace_id: &'a systemprompt_identifiers::TraceId,
17}
18
19impl TaskRepository {
20    pub async fn update_task_and_save_messages(
21        &self,
22        params: UpdateTaskAndSaveMessagesParams<'_>,
23    ) -> Result<Task, RepositoryError> {
24        let UpdateTaskAndSaveMessagesParams {
25            task,
26            user_message,
27            agent_message,
28            user_id,
29            session_id,
30            trace_id,
31        } = params;
32        let mut tx = self
33            .write_pool
34            .begin()
35            .await
36            .map_err(RepositoryError::database)?;
37
38        let status = task_state_to_db_string(task.status.state);
39        let metadata_json = task
40            .metadata
41            .as_ref().map_or_else(|| serde_json::json!({}), |m| {
42                serde_json::to_value(m).unwrap_or_else(|e| {
43                    tracing::warn!(error = %e, task_id = %task.id, "Failed to serialize task metadata");
44                    serde_json::json!({})
45                })
46            });
47
48        let task_id_str = task.id.as_str();
49        let is_completed = task.status.state == TaskState::Completed;
50
51        let result = if is_completed {
52            sqlx::query!(
53                r#"UPDATE agent_tasks SET
54                    status = $1,
55                    status_timestamp = $2,
56                    metadata = $3,
57                    updated_at = CURRENT_TIMESTAMP,
58                    completed_at = CURRENT_TIMESTAMP,
59                    started_at = COALESCE(started_at, CURRENT_TIMESTAMP),
60                    execution_time_ms = EXTRACT(EPOCH FROM (CURRENT_TIMESTAMP - COALESCE(started_at, CURRENT_TIMESTAMP))) * 1000
61                WHERE task_id = $4"#,
62                status,
63                task.status.timestamp,
64                metadata_json,
65                task_id_str
66            )
67            .execute(&mut *tx)
68            .await
69            .map_err(RepositoryError::database)?
70        } else {
71            sqlx::query!(
72                r#"UPDATE agent_tasks SET status = $1, status_timestamp = $2, metadata = $3, updated_at = CURRENT_TIMESTAMP WHERE task_id = $4"#,
73                status,
74                task.status.timestamp,
75                metadata_json,
76                task_id_str
77            )
78            .execute(&mut *tx)
79            .await
80            .map_err(RepositoryError::database)?
81        };
82
83        if result.rows_affected() == 0 {
84            return Err(RepositoryError::NotFound(format!(
85                "Task not found for update: {}",
86                task.id
87            )));
88        }
89
90        let upload_ctx = self
91            .file_upload_provider
92            .as_ref()
93            .map(|svc| FileUploadContext {
94                upload_provider: svc,
95                context_id: &task.context_id,
96                user_id,
97                session_id: Some(session_id),
98                trace_id: Some(trace_id),
99            });
100
101        let user_seq = get_next_sequence_number_sqlx(&mut tx, &task.id).await?;
102        persist_message_sqlx(PersistMessageSqlxParams {
103            tx: &mut tx,
104            message: user_message,
105            task_id: &task.id,
106            context_id: &task.context_id,
107            sequence_number: user_seq,
108            user_id,
109            session_id,
110            trace_id,
111            upload_ctx: upload_ctx.as_ref(),
112        })
113        .await?;
114
115        let agent_seq = get_next_sequence_number_sqlx(&mut tx, &task.id).await?;
116        persist_message_sqlx(PersistMessageSqlxParams {
117            tx: &mut tx,
118            message: agent_message,
119            task_id: &task.id,
120            context_id: &task.context_id,
121            sequence_number: agent_seq,
122            user_id,
123            session_id,
124            trace_id,
125            upload_ctx: upload_ctx.as_ref(),
126        })
127        .await?;
128
129        tx.commit().await.map_err(RepositoryError::database)?;
130
131        if let Some(ref analytics_provider) = self.session_analytics_provider {
132            for _ in 0..2 {
133                if let Err(e) = analytics_provider.increment_message_count(session_id).await {
134                    tracing::warn!(error = %e, "Failed to increment analytics message count");
135                }
136            }
137        }
138
139        let updated_task = self.get_task(&task.id).await?.ok_or_else(|| {
140            RepositoryError::NotFound(format!("Task not found after update: {}", task.id))
141        })?;
142
143        Ok(updated_task)
144    }
145
146    pub async fn delete_task(
147        &self,
148        task_id: &systemprompt_identifiers::TaskId,
149    ) -> Result<(), RepositoryError> {
150        let task_id_str = task_id.as_str();
151
152        sqlx::query!(
153            "DELETE FROM message_parts WHERE message_id IN (SELECT message_id FROM task_messages \
154             WHERE task_id = $1)",
155            task_id_str
156        )
157        .execute(&*self.write_pool)
158        .await
159        .map_err(RepositoryError::database)?;
160
161        sqlx::query!("DELETE FROM task_messages WHERE task_id = $1", task_id_str)
162            .execute(&*self.write_pool)
163            .await
164            .map_err(RepositoryError::database)?;
165
166        sqlx::query!(
167            "DELETE FROM task_execution_steps WHERE task_id = $1",
168            task_id_str
169        )
170        .execute(&*self.write_pool)
171        .await
172        .map_err(RepositoryError::database)?;
173
174        sqlx::query!("DELETE FROM agent_tasks WHERE task_id = $1", task_id_str)
175            .execute(&*self.write_pool)
176            .await
177            .map_err(RepositoryError::database)?;
178
179        Ok(())
180    }
181}