Skip to main content

systemprompt_agent/repository/task/
task_updates.rs

1use super::{task_state_to_db_string, TaskRepository};
2use crate::models::a2a::{Message, Task, TaskState};
3use crate::repository::context::message::{
4    get_next_sequence_number_sqlx, persist_message_sqlx, FileUploadContext,
5};
6use systemprompt_traits::RepositoryError;
7
8impl TaskRepository {
9    pub async fn update_task_and_save_messages(
10        &self,
11        task: &Task,
12        user_message: &Message,
13        agent_message: &Message,
14        user_id: Option<&systemprompt_identifiers::UserId>,
15        session_id: &systemprompt_identifiers::SessionId,
16        trace_id: &systemprompt_identifiers::TraceId,
17    ) -> Result<Task, RepositoryError> {
18        let pool = self.get_pg_pool()?;
19        let mut tx = pool
20            .begin()
21            .await
22            .map_err(|e| RepositoryError::database(e))?;
23
24        let status = task_state_to_db_string(task.status.state.clone());
25        let metadata_json = task
26            .metadata
27            .as_ref()
28            .map(|m| {
29                serde_json::to_value(m).unwrap_or_else(|e| {
30                    tracing::warn!(error = %e, task_id = %task.id, "Failed to serialize task metadata");
31                    serde_json::json!({})
32                })
33            })
34            .unwrap_or_else(|| serde_json::json!({}));
35
36        let task_id_str = task.id.as_str();
37        let is_completed = task.status.state == TaskState::Completed;
38
39        let result = if is_completed {
40            sqlx::query!(
41                r#"UPDATE agent_tasks SET
42                    status = $1,
43                    status_timestamp = $2,
44                    metadata = $3,
45                    updated_at = CURRENT_TIMESTAMP,
46                    completed_at = CURRENT_TIMESTAMP,
47                    started_at = COALESCE(started_at, CURRENT_TIMESTAMP),
48                    execution_time_ms = EXTRACT(EPOCH FROM (CURRENT_TIMESTAMP - COALESCE(started_at, CURRENT_TIMESTAMP))) * 1000
49                WHERE task_id = $4"#,
50                status,
51                task.status.timestamp,
52                metadata_json,
53                task_id_str
54            )
55            .execute(&mut *tx)
56            .await
57            .map_err(|e| RepositoryError::database(e))?
58        } else {
59            sqlx::query!(
60                r#"UPDATE agent_tasks SET status = $1, status_timestamp = $2, metadata = $3, updated_at = CURRENT_TIMESTAMP WHERE task_id = $4"#,
61                status,
62                task.status.timestamp,
63                metadata_json,
64                task_id_str
65            )
66            .execute(&mut *tx)
67            .await
68            .map_err(|e| RepositoryError::database(e))?
69        };
70
71        if result.rows_affected() == 0 {
72            return Err(RepositoryError::NotFound(format!(
73                "Task not found for update: {}",
74                task.id
75            )));
76        }
77
78        let upload_ctx = self
79            .file_upload_provider
80            .as_ref()
81            .map(|svc| FileUploadContext {
82                upload_provider: svc,
83                context_id: &task.context_id,
84                user_id,
85                session_id: Some(session_id),
86                trace_id: Some(trace_id),
87            });
88
89        let user_seq = get_next_sequence_number_sqlx(&mut tx, &task.id).await?;
90        persist_message_sqlx(
91            &mut tx,
92            user_message,
93            &task.id,
94            &task.context_id,
95            user_seq,
96            user_id,
97            session_id,
98            trace_id,
99            upload_ctx.as_ref(),
100        )
101        .await?;
102
103        let agent_seq = get_next_sequence_number_sqlx(&mut tx, &task.id).await?;
104        persist_message_sqlx(
105            &mut tx,
106            agent_message,
107            &task.id,
108            &task.context_id,
109            agent_seq,
110            user_id,
111            session_id,
112            trace_id,
113            upload_ctx.as_ref(),
114        )
115        .await?;
116
117        tx.commit()
118            .await
119            .map_err(|e| RepositoryError::database(e))?;
120
121        if let Some(ref analytics_provider) = self.session_analytics_provider {
122            for _ in 0..2 {
123                if let Err(e) = analytics_provider.increment_message_count(session_id).await {
124                    tracing::warn!(error = %e, "Failed to increment analytics message count");
125                }
126            }
127        }
128
129        let updated_task = self.get_task(&task.id).await?.ok_or_else(|| {
130            RepositoryError::NotFound(format!("Task not found after update: {}", task.id))
131        })?;
132
133        Ok(updated_task)
134    }
135
136    pub async fn delete_task(
137        &self,
138        task_id: &systemprompt_identifiers::TaskId,
139    ) -> Result<(), RepositoryError> {
140        let pool = self.get_pg_pool()?;
141        let task_id_str = task_id.as_str();
142
143        sqlx::query!(
144            "DELETE FROM message_parts WHERE message_id IN (SELECT message_id FROM task_messages \
145             WHERE task_id = $1)",
146            task_id_str
147        )
148        .execute(&*pool)
149        .await
150        .map_err(|e| RepositoryError::database(e))?;
151
152        sqlx::query!("DELETE FROM task_messages WHERE task_id = $1", task_id_str)
153            .execute(&*pool)
154            .await
155            .map_err(|e| RepositoryError::database(e))?;
156
157        sqlx::query!(
158            "DELETE FROM task_execution_steps WHERE task_id = $1",
159            task_id_str
160        )
161        .execute(&*pool)
162        .await
163        .map_err(|e| RepositoryError::database(e))?;
164
165        sqlx::query!("DELETE FROM agent_tasks WHERE task_id = $1", task_id_str)
166            .execute(&*pool)
167            .await
168            .map_err(|e| RepositoryError::database(e))?;
169
170        Ok(())
171    }
172}