systemprompt_agent/repository/task/
mutations.rs1use sqlx::PgPool;
2use std::sync::Arc;
3use systemprompt_traits::RepositoryError;
4
5use crate::models::a2a::{Task, TaskState};
6
7pub const fn task_state_to_db_string(state: TaskState) -> &'static str {
8 match state {
9 TaskState::Pending | TaskState::Submitted => "submitted",
10 TaskState::Working => "working",
11 TaskState::InputRequired => "input-required",
12 TaskState::Completed => "completed",
13 TaskState::Canceled => "canceled",
14 TaskState::Failed => "failed",
15 TaskState::Rejected => "rejected",
16 TaskState::AuthRequired => "auth-required",
17 TaskState::Unknown => "unknown",
18 }
19}
20
21#[allow(missing_debug_implementations)]
22pub struct CreateTaskParams<'a> {
23 pub pool: &'a Arc<PgPool>,
24 pub task: &'a Task,
25 pub user_id: &'a systemprompt_identifiers::UserId,
26 pub session_id: &'a systemprompt_identifiers::SessionId,
27 pub trace_id: &'a systemprompt_identifiers::TraceId,
28 pub agent_name: &'a str,
29}
30
31pub async fn create_task(params: CreateTaskParams<'_>) -> Result<String, RepositoryError> {
32 let CreateTaskParams {
33 pool,
34 task,
35 user_id,
36 session_id,
37 trace_id,
38 agent_name,
39 } = params;
40 let metadata_json = task.metadata.as_ref().map_or_else(
41 || serde_json::json!({}),
42 |m| {
43 serde_json::to_value(m).unwrap_or_else(|e| {
44 tracing::warn!(error = %e, task_id = %task.id, "Failed to serialize task metadata");
45 serde_json::json!({})
46 })
47 },
48 );
49
50 let status = task_state_to_db_string(task.status.state);
51 let task_id_str = task.id.as_str();
52 let context_id_str = task.context_id.as_str();
53 let user_id_str = user_id.as_ref();
54 let session_id_str = session_id.as_ref();
55 let trace_id_str = trace_id.as_ref();
56
57 sqlx::query!(
58 r#"INSERT INTO agent_tasks (task_id, context_id, status, status_timestamp, user_id, session_id, trace_id, metadata, agent_name)
59 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"#,
60 task_id_str,
61 context_id_str,
62 status,
63 task.status.timestamp,
64 user_id_str,
65 session_id_str,
66 trace_id_str,
67 metadata_json,
68 agent_name
69 )
70 .execute(pool.as_ref())
71 .await
72 .map_err(RepositoryError::database)?;
73
74 Ok(task.id.to_string())
75}
76
77pub async fn track_agent_in_context(
78 pool: &Arc<PgPool>,
79 context_id: &systemprompt_identifiers::ContextId,
80 agent_name: &str,
81) -> Result<(), RepositoryError> {
82 let context_id_str = context_id.as_str();
83 sqlx::query!(
84 r#"INSERT INTO context_agents (context_id, agent_name) VALUES ($1, $2)
85 ON CONFLICT (context_id, agent_name) DO NOTHING"#,
86 context_id_str,
87 agent_name
88 )
89 .execute(pool.as_ref())
90 .await
91 .map_err(RepositoryError::database)?;
92
93 Ok(())
94}
95
96pub async fn update_task_state(
97 pool: &Arc<PgPool>,
98 task_id: &systemprompt_identifiers::TaskId,
99 state: TaskState,
100 timestamp: &chrono::DateTime<chrono::Utc>,
101) -> Result<(), RepositoryError> {
102 let status = task_state_to_db_string(state);
103 let task_id_str = task_id.as_str();
104
105 if state == TaskState::Completed {
106 sqlx::query!(
107 r#"UPDATE agent_tasks SET status = $1, status_timestamp = $2, updated_at = CURRENT_TIMESTAMP,
108 completed_at = CURRENT_TIMESTAMP,
109 started_at = COALESCE(started_at, CURRENT_TIMESTAMP),
110 execution_time_ms = EXTRACT(EPOCH FROM (CURRENT_TIMESTAMP - COALESCE(started_at, CURRENT_TIMESTAMP))) * 1000
111 WHERE task_id = $3"#,
112 status,
113 timestamp,
114 task_id_str
115 )
116 .execute(pool.as_ref())
117 .await
118 .map_err(RepositoryError::database)?;
119 } else if state == TaskState::Working {
120 sqlx::query!(
121 r#"UPDATE agent_tasks SET status = $1, status_timestamp = $2, updated_at = CURRENT_TIMESTAMP,
122 started_at = COALESCE(started_at, CURRENT_TIMESTAMP)
123 WHERE task_id = $3"#,
124 status,
125 timestamp,
126 task_id_str
127 )
128 .execute(pool.as_ref())
129 .await
130 .map_err(RepositoryError::database)?;
131 } else {
132 sqlx::query!(
133 r#"UPDATE agent_tasks SET status = $1, status_timestamp = $2, updated_at = CURRENT_TIMESTAMP WHERE task_id = $3"#,
134 status,
135 timestamp,
136 task_id_str
137 )
138 .execute(pool.as_ref())
139 .await
140 .map_err(RepositoryError::database)?;
141 }
142
143 Ok(())
144}
145
146pub async fn update_task_failed_with_error(
147 pool: &Arc<PgPool>,
148 task_id: &systemprompt_identifiers::TaskId,
149 error_message: &str,
150 timestamp: &chrono::DateTime<chrono::Utc>,
151) -> Result<(), RepositoryError> {
152 let task_id_str = task_id.as_str();
153
154 sqlx::query!(
155 r#"UPDATE agent_tasks SET
156 status = 'failed',
157 status_timestamp = $1,
158 error_message = $2,
159 updated_at = CURRENT_TIMESTAMP,
160 completed_at = CURRENT_TIMESTAMP,
161 started_at = COALESCE(started_at, CURRENT_TIMESTAMP),
162 execution_time_ms = EXTRACT(EPOCH FROM (CURRENT_TIMESTAMP - COALESCE(started_at, CURRENT_TIMESTAMP))) * 1000
163 WHERE task_id = $3"#,
164 timestamp,
165 error_message,
166 task_id_str
167 )
168 .execute(pool.as_ref())
169 .await
170 .map_err(RepositoryError::database)?;
171
172 Ok(())
173}