systemprompt_agent/repository/task/
mutations.rs1use sqlx::PgPool;
2use std::sync::Arc;
3use systemprompt_traits::RepositoryError;
4
5use crate::models::a2a::{Task, TaskState};
6
7pub const fn task_state_to_db_string(state: TaskState) -> &'static str {
8 match state {
9 TaskState::Pending => "TASK_STATE_PENDING",
10 TaskState::Submitted => "TASK_STATE_SUBMITTED",
11 TaskState::Working => "TASK_STATE_WORKING",
12 TaskState::InputRequired => "TASK_STATE_INPUT_REQUIRED",
13 TaskState::Completed => "TASK_STATE_COMPLETED",
14 TaskState::Canceled => "TASK_STATE_CANCELED",
15 TaskState::Failed => "TASK_STATE_FAILED",
16 TaskState::Rejected => "TASK_STATE_REJECTED",
17 TaskState::AuthRequired => "TASK_STATE_AUTH_REQUIRED",
18 TaskState::Unknown => "TASK_STATE_UNKNOWN",
19 }
20}
21
/// Arguments for [`create_task`], bundled into one struct so the call site
/// names every field explicitly.
///
/// All fields are borrowed; `create_task` only reads them to build a single
/// `INSERT INTO agent_tasks` statement.
// NOTE(review): `Debug` is presumably not derived because some field type
// (e.g. `Task`) lacks a `Debug` impl. Confirm that; if every field is
// `Debug`, derive it and drop this `allow`.
#[allow(missing_debug_implementations)]
pub struct CreateTaskParams<'a> {
    /// Connection pool the insert is executed against.
    pub pool: &'a Arc<PgPool>,
    /// Task to persist. Its id, context id, status state, status timestamp
    /// and metadata are written.
    pub task: &'a Task,
    /// Stored in the `user_id` column.
    pub user_id: &'a systemprompt_identifiers::UserId,
    /// Stored in the `session_id` column.
    pub session_id: &'a systemprompt_identifiers::SessionId,
    /// Stored in the `trace_id` column.
    pub trace_id: &'a systemprompt_identifiers::TraceId,
    /// Stored in the `agent_name` column.
    pub agent_name: &'a str,
}
31
32pub async fn create_task(params: CreateTaskParams<'_>) -> Result<String, RepositoryError> {
33 let CreateTaskParams {
34 pool,
35 task,
36 user_id,
37 session_id,
38 trace_id,
39 agent_name,
40 } = params;
41 let metadata_json = task.metadata.as_ref().map_or_else(
42 || serde_json::json!({}),
43 |m| {
44 serde_json::to_value(m).unwrap_or_else(|e| {
45 tracing::warn!(error = %e, task_id = %task.id, "Failed to serialize task metadata");
46 serde_json::json!({})
47 })
48 },
49 );
50
51 let status = task_state_to_db_string(task.status.state);
52 let task_id_str = task.id.as_str();
53 let context_id_str = task.context_id.as_str();
54 let user_id_str = user_id.as_ref();
55 let session_id_str = session_id.as_ref();
56 let trace_id_str = trace_id.as_ref();
57
58 sqlx::query!(
59 r#"INSERT INTO agent_tasks (task_id, context_id, status, status_timestamp, user_id, session_id, trace_id, metadata, agent_name)
60 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"#,
61 task_id_str,
62 context_id_str,
63 status,
64 task.status.timestamp,
65 user_id_str,
66 session_id_str,
67 trace_id_str,
68 metadata_json,
69 agent_name
70 )
71 .execute(pool.as_ref())
72 .await
73 .map_err(RepositoryError::database)?;
74
75 Ok(task.id.to_string())
76}
77
78pub async fn track_agent_in_context(
79 pool: &Arc<PgPool>,
80 context_id: &systemprompt_identifiers::ContextId,
81 agent_name: &str,
82) -> Result<(), RepositoryError> {
83 let context_id_str = context_id.as_str();
84 sqlx::query!(
85 r#"INSERT INTO context_agents (context_id, agent_name) VALUES ($1, $2)
86 ON CONFLICT (context_id, agent_name) DO NOTHING"#,
87 context_id_str,
88 agent_name
89 )
90 .execute(pool.as_ref())
91 .await
92 .map_err(RepositoryError::database)?;
93
94 Ok(())
95}
96
97pub async fn update_task_state(
98 pool: &Arc<PgPool>,
99 task_id: &systemprompt_identifiers::TaskId,
100 state: TaskState,
101 timestamp: &chrono::DateTime<chrono::Utc>,
102) -> Result<(), RepositoryError> {
103 let status = task_state_to_db_string(state);
104 let task_id_str = task_id.as_str();
105
106 let mut tx = pool.begin().await.map_err(RepositoryError::database)?;
107
108 let current = sqlx::query!(
109 r#"SELECT status, version FROM agent_tasks WHERE task_id = $1 FOR UPDATE"#,
110 task_id_str
111 )
112 .fetch_optional(&mut *tx)
113 .await
114 .map_err(RepositoryError::database)?
115 .ok_or_else(|| RepositoryError::NotFound(format!("task {task_id_str}")))?;
116
117 let current_state: TaskState = current.status.parse().map_err(|e: String| {
118 RepositoryError::InvalidData(format!("unrecognised stored task state: {e}"))
119 })?;
120
121 if current_state == state {
122 tx.commit().await.map_err(RepositoryError::database)?;
123 return Ok(());
124 }
125
126 if !current_state.can_transition_to(&state) {
127 return Err(RepositoryError::ConstraintViolation(format!(
128 "invalid task state transition for {task_id_str}: {current_state:?} -> {state:?}"
129 )));
130 }
131
132 let expected_version = current.version;
133
134 let rows_affected = if state == TaskState::Completed {
135 sqlx::query!(
136 r#"UPDATE agent_tasks
137 SET status = $1,
138 status_timestamp = $2,
139 updated_at = CURRENT_TIMESTAMP,
140 completed_at = CURRENT_TIMESTAMP,
141 started_at = COALESCE(started_at, CURRENT_TIMESTAMP),
142 execution_time_ms = EXTRACT(EPOCH FROM (CURRENT_TIMESTAMP - COALESCE(started_at, CURRENT_TIMESTAMP))) * 1000,
143 version = version + 1
144 WHERE task_id = $3 AND version = $4"#,
145 status,
146 timestamp,
147 task_id_str,
148 expected_version
149 )
150 .execute(&mut *tx)
151 .await
152 .map_err(RepositoryError::database)?
153 .rows_affected()
154 } else if state == TaskState::Working {
155 sqlx::query!(
156 r#"UPDATE agent_tasks
157 SET status = $1,
158 status_timestamp = $2,
159 updated_at = CURRENT_TIMESTAMP,
160 started_at = COALESCE(started_at, CURRENT_TIMESTAMP),
161 version = version + 1
162 WHERE task_id = $3 AND version = $4"#,
163 status,
164 timestamp,
165 task_id_str,
166 expected_version
167 )
168 .execute(&mut *tx)
169 .await
170 .map_err(RepositoryError::database)?
171 .rows_affected()
172 } else {
173 sqlx::query!(
174 r#"UPDATE agent_tasks
175 SET status = $1,
176 status_timestamp = $2,
177 updated_at = CURRENT_TIMESTAMP,
178 version = version + 1
179 WHERE task_id = $3 AND version = $4"#,
180 status,
181 timestamp,
182 task_id_str,
183 expected_version
184 )
185 .execute(&mut *tx)
186 .await
187 .map_err(RepositoryError::database)?
188 .rows_affected()
189 };
190
191 if rows_affected == 0 {
192 return Err(RepositoryError::ConstraintViolation(format!(
193 "stale task update for {task_id_str}: expected version {expected_version}"
194 )));
195 }
196
197 tx.commit().await.map_err(RepositoryError::database)?;
198 Ok(())
199}
200
201pub async fn update_task_failed_with_error(
202 pool: &Arc<PgPool>,
203 task_id: &systemprompt_identifiers::TaskId,
204 error_message: &str,
205 timestamp: &chrono::DateTime<chrono::Utc>,
206) -> Result<(), RepositoryError> {
207 let task_id_str = task_id.as_str();
208
209 let mut tx = pool.begin().await.map_err(RepositoryError::database)?;
210
211 let current = sqlx::query!(
212 r#"SELECT status, version FROM agent_tasks WHERE task_id = $1 FOR UPDATE"#,
213 task_id_str
214 )
215 .fetch_optional(&mut *tx)
216 .await
217 .map_err(RepositoryError::database)?
218 .ok_or_else(|| RepositoryError::NotFound(format!("task {task_id_str}")))?;
219
220 let current_state: TaskState = current.status.parse().map_err(|e: String| {
221 RepositoryError::InvalidData(format!("unrecognised stored task state: {e}"))
222 })?;
223
224 if current_state == TaskState::Failed {
225 tx.commit().await.map_err(RepositoryError::database)?;
226 return Ok(());
227 }
228
229 if !current_state.can_transition_to(&TaskState::Failed) {
230 return Err(RepositoryError::ConstraintViolation(format!(
231 "invalid task state transition for {task_id_str}: {current_state:?} -> Failed"
232 )));
233 }
234
235 let expected_version = current.version;
236
237 let rows_affected = sqlx::query!(
238 r#"UPDATE agent_tasks SET
239 status = 'TASK_STATE_FAILED',
240 status_timestamp = $1,
241 error_message = $2,
242 updated_at = CURRENT_TIMESTAMP,
243 completed_at = CURRENT_TIMESTAMP,
244 started_at = COALESCE(started_at, CURRENT_TIMESTAMP),
245 execution_time_ms = EXTRACT(EPOCH FROM (CURRENT_TIMESTAMP - COALESCE(started_at, CURRENT_TIMESTAMP))) * 1000,
246 version = version + 1
247 WHERE task_id = $3 AND version = $4"#,
248 timestamp,
249 error_message,
250 task_id_str,
251 expected_version
252 )
253 .execute(&mut *tx)
254 .await
255 .map_err(RepositoryError::database)?
256 .rows_affected();
257
258 if rows_affected == 0 {
259 return Err(RepositoryError::ConstraintViolation(format!(
260 "stale task update for {task_id_str}: expected version {expected_version}"
261 )));
262 }
263
264 tx.commit().await.map_err(RepositoryError::database)?;
265 Ok(())
266}