systemprompt_agent/repository/task/
task_updates.rs1use super::{task_state_to_db_string, TaskRepository};
2use crate::models::a2a::{Message, Task, TaskState};
3use crate::repository::context::message::{
4 get_next_sequence_number_sqlx, persist_message_sqlx, FileUploadContext,
5};
6use systemprompt_traits::RepositoryError;
7
8impl TaskRepository {
9 pub async fn update_task_and_save_messages(
10 &self,
11 task: &Task,
12 user_message: &Message,
13 agent_message: &Message,
14 user_id: Option<&systemprompt_identifiers::UserId>,
15 session_id: &systemprompt_identifiers::SessionId,
16 trace_id: &systemprompt_identifiers::TraceId,
17 ) -> Result<Task, RepositoryError> {
18 let pool = self.get_pg_pool()?;
19 let mut tx = pool
20 .begin()
21 .await
22 .map_err(|e| RepositoryError::database(e))?;
23
24 let status = task_state_to_db_string(task.status.state.clone());
25 let metadata_json = task
26 .metadata
27 .as_ref()
28 .map(|m| {
29 serde_json::to_value(m).unwrap_or_else(|e| {
30 tracing::warn!(error = %e, task_id = %task.id, "Failed to serialize task metadata");
31 serde_json::json!({})
32 })
33 })
34 .unwrap_or_else(|| serde_json::json!({}));
35
36 let task_id_str = task.id.as_str();
37 let is_completed = task.status.state == TaskState::Completed;
38
39 let result = if is_completed {
40 sqlx::query!(
41 r#"UPDATE agent_tasks SET
42 status = $1,
43 status_timestamp = $2,
44 metadata = $3,
45 updated_at = CURRENT_TIMESTAMP,
46 completed_at = CURRENT_TIMESTAMP,
47 started_at = COALESCE(started_at, CURRENT_TIMESTAMP),
48 execution_time_ms = EXTRACT(EPOCH FROM (CURRENT_TIMESTAMP - COALESCE(started_at, CURRENT_TIMESTAMP))) * 1000
49 WHERE task_id = $4"#,
50 status,
51 task.status.timestamp,
52 metadata_json,
53 task_id_str
54 )
55 .execute(&mut *tx)
56 .await
57 .map_err(|e| RepositoryError::database(e))?
58 } else {
59 sqlx::query!(
60 r#"UPDATE agent_tasks SET status = $1, status_timestamp = $2, metadata = $3, updated_at = CURRENT_TIMESTAMP WHERE task_id = $4"#,
61 status,
62 task.status.timestamp,
63 metadata_json,
64 task_id_str
65 )
66 .execute(&mut *tx)
67 .await
68 .map_err(|e| RepositoryError::database(e))?
69 };
70
71 if result.rows_affected() == 0 {
72 return Err(RepositoryError::NotFound(format!(
73 "Task not found for update: {}",
74 task.id
75 )));
76 }
77
78 let upload_ctx = self
79 .file_upload_provider
80 .as_ref()
81 .map(|svc| FileUploadContext {
82 upload_provider: svc,
83 context_id: &task.context_id,
84 user_id,
85 session_id: Some(session_id),
86 trace_id: Some(trace_id),
87 });
88
89 let user_seq = get_next_sequence_number_sqlx(&mut tx, &task.id).await?;
90 persist_message_sqlx(
91 &mut tx,
92 user_message,
93 &task.id,
94 &task.context_id,
95 user_seq,
96 user_id,
97 session_id,
98 trace_id,
99 upload_ctx.as_ref(),
100 )
101 .await?;
102
103 let agent_seq = get_next_sequence_number_sqlx(&mut tx, &task.id).await?;
104 persist_message_sqlx(
105 &mut tx,
106 agent_message,
107 &task.id,
108 &task.context_id,
109 agent_seq,
110 user_id,
111 session_id,
112 trace_id,
113 upload_ctx.as_ref(),
114 )
115 .await?;
116
117 tx.commit()
118 .await
119 .map_err(|e| RepositoryError::database(e))?;
120
121 if let Some(ref analytics_provider) = self.session_analytics_provider {
122 for _ in 0..2 {
123 if let Err(e) = analytics_provider.increment_message_count(session_id).await {
124 tracing::warn!(error = %e, "Failed to increment analytics message count");
125 }
126 }
127 }
128
129 let updated_task = self.get_task(&task.id).await?.ok_or_else(|| {
130 RepositoryError::NotFound(format!("Task not found after update: {}", task.id))
131 })?;
132
133 Ok(updated_task)
134 }
135
136 pub async fn delete_task(
137 &self,
138 task_id: &systemprompt_identifiers::TaskId,
139 ) -> Result<(), RepositoryError> {
140 let pool = self.get_pg_pool()?;
141 let task_id_str = task_id.as_str();
142
143 sqlx::query!(
144 "DELETE FROM message_parts WHERE message_id IN (SELECT message_id FROM task_messages \
145 WHERE task_id = $1)",
146 task_id_str
147 )
148 .execute(&*pool)
149 .await
150 .map_err(|e| RepositoryError::database(e))?;
151
152 sqlx::query!("DELETE FROM task_messages WHERE task_id = $1", task_id_str)
153 .execute(&*pool)
154 .await
155 .map_err(|e| RepositoryError::database(e))?;
156
157 sqlx::query!(
158 "DELETE FROM task_execution_steps WHERE task_id = $1",
159 task_id_str
160 )
161 .execute(&*pool)
162 .await
163 .map_err(|e| RepositoryError::database(e))?;
164
165 sqlx::query!("DELETE FROM agent_tasks WHERE task_id = $1", task_id_str)
166 .execute(&*pool)
167 .await
168 .map_err(|e| RepositoryError::database(e))?;
169
170 Ok(())
171 }
172}