systemprompt_agent/repository/task/
task_updates.rs1use super::{TaskRepository, task_state_to_db_string};
2use crate::models::a2a::{Message, Task, TaskState};
3use crate::repository::context::message::{
4 FileUploadContext, PersistMessageSqlxParams, get_next_sequence_number_sqlx,
5 persist_message_sqlx,
6};
7use systemprompt_traits::RepositoryError;
8
9#[allow(missing_debug_implementations)]
10pub struct UpdateTaskAndSaveMessagesParams<'a> {
11 pub task: &'a Task,
12 pub user_message: &'a Message,
13 pub agent_message: &'a Message,
14 pub user_id: Option<&'a systemprompt_identifiers::UserId>,
15 pub session_id: &'a systemprompt_identifiers::SessionId,
16 pub trace_id: &'a systemprompt_identifiers::TraceId,
17}
18
19impl TaskRepository {
20 pub async fn update_task_and_save_messages(
21 &self,
22 params: UpdateTaskAndSaveMessagesParams<'_>,
23 ) -> Result<Task, RepositoryError> {
24 let UpdateTaskAndSaveMessagesParams {
25 task,
26 user_message,
27 agent_message,
28 user_id,
29 session_id,
30 trace_id,
31 } = params;
32 let mut tx = self
33 .write_pool
34 .begin()
35 .await
36 .map_err(RepositoryError::database)?;
37
38 let status = task_state_to_db_string(task.status.state);
39 let metadata_json = task
40 .metadata
41 .as_ref().map_or_else(|| serde_json::json!({}), |m| {
42 serde_json::to_value(m).unwrap_or_else(|e| {
43 tracing::warn!(error = %e, task_id = %task.id, "Failed to serialize task metadata");
44 serde_json::json!({})
45 })
46 });
47
48 let task_id_str = task.id.as_str();
49 let is_completed = task.status.state == TaskState::Completed;
50
51 let result = if is_completed {
52 sqlx::query!(
53 r#"UPDATE agent_tasks SET
54 status = $1,
55 status_timestamp = $2,
56 metadata = $3,
57 updated_at = CURRENT_TIMESTAMP,
58 completed_at = CURRENT_TIMESTAMP,
59 started_at = COALESCE(started_at, CURRENT_TIMESTAMP),
60 execution_time_ms = EXTRACT(EPOCH FROM (CURRENT_TIMESTAMP - COALESCE(started_at, CURRENT_TIMESTAMP))) * 1000
61 WHERE task_id = $4"#,
62 status,
63 task.status.timestamp,
64 metadata_json,
65 task_id_str
66 )
67 .execute(&mut *tx)
68 .await
69 .map_err(RepositoryError::database)?
70 } else {
71 sqlx::query!(
72 r#"UPDATE agent_tasks SET status = $1, status_timestamp = $2, metadata = $3, updated_at = CURRENT_TIMESTAMP WHERE task_id = $4"#,
73 status,
74 task.status.timestamp,
75 metadata_json,
76 task_id_str
77 )
78 .execute(&mut *tx)
79 .await
80 .map_err(RepositoryError::database)?
81 };
82
83 if result.rows_affected() == 0 {
84 return Err(RepositoryError::NotFound(format!(
85 "Task not found for update: {}",
86 task.id
87 )));
88 }
89
90 let upload_ctx = self
91 .file_upload_provider
92 .as_ref()
93 .map(|svc| FileUploadContext {
94 upload_provider: svc,
95 context_id: &task.context_id,
96 user_id,
97 session_id: Some(session_id),
98 trace_id: Some(trace_id),
99 });
100
101 let user_seq = get_next_sequence_number_sqlx(&mut tx, &task.id).await?;
102 persist_message_sqlx(PersistMessageSqlxParams {
103 tx: &mut tx,
104 message: user_message,
105 task_id: &task.id,
106 context_id: &task.context_id,
107 sequence_number: user_seq,
108 user_id,
109 session_id,
110 trace_id,
111 upload_ctx: upload_ctx.as_ref(),
112 })
113 .await?;
114
115 let agent_seq = get_next_sequence_number_sqlx(&mut tx, &task.id).await?;
116 persist_message_sqlx(PersistMessageSqlxParams {
117 tx: &mut tx,
118 message: agent_message,
119 task_id: &task.id,
120 context_id: &task.context_id,
121 sequence_number: agent_seq,
122 user_id,
123 session_id,
124 trace_id,
125 upload_ctx: upload_ctx.as_ref(),
126 })
127 .await?;
128
129 tx.commit().await.map_err(RepositoryError::database)?;
130
131 if let Some(ref analytics_provider) = self.session_analytics_provider {
132 for _ in 0..2 {
133 if let Err(e) = analytics_provider.increment_message_count(session_id).await {
134 tracing::warn!(error = %e, "Failed to increment analytics message count");
135 }
136 }
137 }
138
139 let updated_task = self.get_task(&task.id).await?.ok_or_else(|| {
140 RepositoryError::NotFound(format!("Task not found after update: {}", task.id))
141 })?;
142
143 Ok(updated_task)
144 }
145
146 pub async fn delete_task(
147 &self,
148 task_id: &systemprompt_identifiers::TaskId,
149 ) -> Result<(), RepositoryError> {
150 let task_id_str = task_id.as_str();
151
152 sqlx::query!(
153 "DELETE FROM message_parts WHERE message_id IN (SELECT message_id FROM task_messages \
154 WHERE task_id = $1)",
155 task_id_str
156 )
157 .execute(&*self.write_pool)
158 .await
159 .map_err(RepositoryError::database)?;
160
161 sqlx::query!("DELETE FROM task_messages WHERE task_id = $1", task_id_str)
162 .execute(&*self.write_pool)
163 .await
164 .map_err(RepositoryError::database)?;
165
166 sqlx::query!(
167 "DELETE FROM task_execution_steps WHERE task_id = $1",
168 task_id_str
169 )
170 .execute(&*self.write_pool)
171 .await
172 .map_err(RepositoryError::database)?;
173
174 sqlx::query!("DELETE FROM agent_tasks WHERE task_id = $1", task_id_str)
175 .execute(&*self.write_pool)
176 .await
177 .map_err(RepositoryError::database)?;
178
179 Ok(())
180 }
181}