Skip to main content

systemprompt_agent/repository/task/
mutations.rs

1use sqlx::PgPool;
2use std::sync::Arc;
3use systemprompt_traits::RepositoryError;
4
5use crate::models::a2a::{Task, TaskState};
6
7pub const fn task_state_to_db_string(state: TaskState) -> &'static str {
8    match state {
9        TaskState::Pending => "TASK_STATE_PENDING",
10        TaskState::Submitted => "TASK_STATE_SUBMITTED",
11        TaskState::Working => "TASK_STATE_WORKING",
12        TaskState::InputRequired => "TASK_STATE_INPUT_REQUIRED",
13        TaskState::Completed => "TASK_STATE_COMPLETED",
14        TaskState::Canceled => "TASK_STATE_CANCELED",
15        TaskState::Failed => "TASK_STATE_FAILED",
16        TaskState::Rejected => "TASK_STATE_REJECTED",
17        TaskState::AuthRequired => "TASK_STATE_AUTH_REQUIRED",
18        TaskState::Unknown => "TASK_STATE_UNKNOWN",
19    }
20}
21
22#[allow(missing_debug_implementations)]
23pub struct CreateTaskParams<'a> {
24    pub pool: &'a Arc<PgPool>,
25    pub task: &'a Task,
26    pub user_id: &'a systemprompt_identifiers::UserId,
27    pub session_id: &'a systemprompt_identifiers::SessionId,
28    pub trace_id: &'a systemprompt_identifiers::TraceId,
29    pub agent_name: &'a str,
30}
31
32pub async fn create_task(params: CreateTaskParams<'_>) -> Result<String, RepositoryError> {
33    let CreateTaskParams {
34        pool,
35        task,
36        user_id,
37        session_id,
38        trace_id,
39        agent_name,
40    } = params;
41    let metadata_json = task.metadata.as_ref().map_or_else(
42        || serde_json::json!({}),
43        |m| {
44            serde_json::to_value(m).unwrap_or_else(|e| {
45                tracing::warn!(error = %e, task_id = %task.id, "Failed to serialize task metadata");
46                serde_json::json!({})
47            })
48        },
49    );
50
51    let status = task_state_to_db_string(task.status.state);
52    let task_id_str = task.id.as_str();
53    let context_id_str = task.context_id.as_str();
54    let user_id_str = user_id.as_ref();
55    let session_id_str = session_id.as_ref();
56    let trace_id_str = trace_id.as_ref();
57
58    sqlx::query!(
59        r#"INSERT INTO agent_tasks (task_id, context_id, status, status_timestamp, user_id, session_id, trace_id, metadata, agent_name)
60        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"#,
61        task_id_str,
62        context_id_str,
63        status,
64        task.status.timestamp,
65        user_id_str,
66        session_id_str,
67        trace_id_str,
68        metadata_json,
69        agent_name
70    )
71    .execute(pool.as_ref())
72    .await
73    .map_err(RepositoryError::database)?;
74
75    Ok(task.id.to_string())
76}
77
78pub async fn track_agent_in_context(
79    pool: &Arc<PgPool>,
80    context_id: &systemprompt_identifiers::ContextId,
81    agent_name: &str,
82) -> Result<(), RepositoryError> {
83    let context_id_str = context_id.as_str();
84    sqlx::query!(
85        r#"INSERT INTO context_agents (context_id, agent_name) VALUES ($1, $2)
86        ON CONFLICT (context_id, agent_name) DO NOTHING"#,
87        context_id_str,
88        agent_name
89    )
90    .execute(pool.as_ref())
91    .await
92    .map_err(RepositoryError::database)?;
93
94    Ok(())
95}
96
97pub async fn update_task_state(
98    pool: &Arc<PgPool>,
99    task_id: &systemprompt_identifiers::TaskId,
100    state: TaskState,
101    timestamp: &chrono::DateTime<chrono::Utc>,
102) -> Result<(), RepositoryError> {
103    let status = task_state_to_db_string(state);
104    let task_id_str = task_id.as_str();
105
106    let mut tx = pool.begin().await.map_err(RepositoryError::database)?;
107
108    let current = sqlx::query!(
109        r#"SELECT status, version FROM agent_tasks WHERE task_id = $1 FOR UPDATE"#,
110        task_id_str
111    )
112    .fetch_optional(&mut *tx)
113    .await
114    .map_err(RepositoryError::database)?
115    .ok_or_else(|| RepositoryError::NotFound(format!("task {task_id_str}")))?;
116
117    let current_state: TaskState = current.status.parse().map_err(|e: String| {
118        RepositoryError::InvalidData(format!("unrecognised stored task state: {e}"))
119    })?;
120
121    if current_state == state {
122        tx.commit().await.map_err(RepositoryError::database)?;
123        return Ok(());
124    }
125
126    if !current_state.can_transition_to(&state) {
127        return Err(RepositoryError::ConstraintViolation(format!(
128            "invalid task state transition for {task_id_str}: {current_state:?} -> {state:?}"
129        )));
130    }
131
132    let expected_version = current.version;
133
134    let rows_affected = if state == TaskState::Completed {
135        sqlx::query!(
136            r#"UPDATE agent_tasks
137               SET status = $1,
138                   status_timestamp = $2,
139                   updated_at = CURRENT_TIMESTAMP,
140                   completed_at = CURRENT_TIMESTAMP,
141                   started_at = COALESCE(started_at, CURRENT_TIMESTAMP),
142                   execution_time_ms = EXTRACT(EPOCH FROM (CURRENT_TIMESTAMP - COALESCE(started_at, CURRENT_TIMESTAMP))) * 1000,
143                   version = version + 1
144               WHERE task_id = $3 AND version = $4"#,
145            status,
146            timestamp,
147            task_id_str,
148            expected_version
149        )
150        .execute(&mut *tx)
151        .await
152        .map_err(RepositoryError::database)?
153        .rows_affected()
154    } else if state == TaskState::Working {
155        sqlx::query!(
156            r#"UPDATE agent_tasks
157               SET status = $1,
158                   status_timestamp = $2,
159                   updated_at = CURRENT_TIMESTAMP,
160                   started_at = COALESCE(started_at, CURRENT_TIMESTAMP),
161                   version = version + 1
162               WHERE task_id = $3 AND version = $4"#,
163            status,
164            timestamp,
165            task_id_str,
166            expected_version
167        )
168        .execute(&mut *tx)
169        .await
170        .map_err(RepositoryError::database)?
171        .rows_affected()
172    } else {
173        sqlx::query!(
174            r#"UPDATE agent_tasks
175               SET status = $1,
176                   status_timestamp = $2,
177                   updated_at = CURRENT_TIMESTAMP,
178                   version = version + 1
179               WHERE task_id = $3 AND version = $4"#,
180            status,
181            timestamp,
182            task_id_str,
183            expected_version
184        )
185        .execute(&mut *tx)
186        .await
187        .map_err(RepositoryError::database)?
188        .rows_affected()
189    };
190
191    if rows_affected == 0 {
192        return Err(RepositoryError::ConstraintViolation(format!(
193            "stale task update for {task_id_str}: expected version {expected_version}"
194        )));
195    }
196
197    tx.commit().await.map_err(RepositoryError::database)?;
198    Ok(())
199}
200
201pub async fn update_task_failed_with_error(
202    pool: &Arc<PgPool>,
203    task_id: &systemprompt_identifiers::TaskId,
204    error_message: &str,
205    timestamp: &chrono::DateTime<chrono::Utc>,
206) -> Result<(), RepositoryError> {
207    let task_id_str = task_id.as_str();
208
209    let mut tx = pool.begin().await.map_err(RepositoryError::database)?;
210
211    let current = sqlx::query!(
212        r#"SELECT status, version FROM agent_tasks WHERE task_id = $1 FOR UPDATE"#,
213        task_id_str
214    )
215    .fetch_optional(&mut *tx)
216    .await
217    .map_err(RepositoryError::database)?
218    .ok_or_else(|| RepositoryError::NotFound(format!("task {task_id_str}")))?;
219
220    let current_state: TaskState = current.status.parse().map_err(|e: String| {
221        RepositoryError::InvalidData(format!("unrecognised stored task state: {e}"))
222    })?;
223
224    if current_state == TaskState::Failed {
225        tx.commit().await.map_err(RepositoryError::database)?;
226        return Ok(());
227    }
228
229    if !current_state.can_transition_to(&TaskState::Failed) {
230        return Err(RepositoryError::ConstraintViolation(format!(
231            "invalid task state transition for {task_id_str}: {current_state:?} -> Failed"
232        )));
233    }
234
235    let expected_version = current.version;
236
237    let rows_affected = sqlx::query!(
238        r#"UPDATE agent_tasks SET
239            status = 'TASK_STATE_FAILED',
240            status_timestamp = $1,
241            error_message = $2,
242            updated_at = CURRENT_TIMESTAMP,
243            completed_at = CURRENT_TIMESTAMP,
244            started_at = COALESCE(started_at, CURRENT_TIMESTAMP),
245            execution_time_ms = EXTRACT(EPOCH FROM (CURRENT_TIMESTAMP - COALESCE(started_at, CURRENT_TIMESTAMP))) * 1000,
246            version = version + 1
247        WHERE task_id = $3 AND version = $4"#,
248        timestamp,
249        error_message,
250        task_id_str,
251        expected_version
252    )
253    .execute(&mut *tx)
254    .await
255    .map_err(RepositoryError::database)?
256    .rows_affected();
257
258    if rows_affected == 0 {
259        return Err(RepositoryError::ConstraintViolation(format!(
260            "stale task update for {task_id_str}: expected version {expected_version}"
261        )));
262    }
263
264    tx.commit().await.map_err(RepositoryError::database)?;
265    Ok(())
266}