systemprompt_agent/repository/task/
mutations.rs1use sqlx::PgPool;
2use std::sync::Arc;
3use systemprompt_traits::RepositoryError;
4
5use crate::models::a2a::{Task, TaskState};
6
7pub const fn task_state_to_db_string(state: TaskState) -> &'static str {
8 match state {
9 TaskState::Pending => "TASK_STATE_PENDING",
10 TaskState::Submitted => "TASK_STATE_SUBMITTED",
11 TaskState::Working => "TASK_STATE_WORKING",
12 TaskState::InputRequired => "TASK_STATE_INPUT_REQUIRED",
13 TaskState::Completed => "TASK_STATE_COMPLETED",
14 TaskState::Canceled => "TASK_STATE_CANCELED",
15 TaskState::Failed => "TASK_STATE_FAILED",
16 TaskState::Rejected => "TASK_STATE_REJECTED",
17 TaskState::AuthRequired => "TASK_STATE_AUTH_REQUIRED",
18 TaskState::Unknown => "TASK_STATE_UNKNOWN",
19 }
20}
21
22#[allow(missing_debug_implementations)]
23pub struct CreateTaskParams<'a> {
24 pub pool: &'a Arc<PgPool>,
25 pub task: &'a Task,
26 pub user_id: &'a systemprompt_identifiers::UserId,
27 pub session_id: &'a systemprompt_identifiers::SessionId,
28 pub trace_id: &'a systemprompt_identifiers::TraceId,
29 pub agent_name: &'a str,
30}
31
32pub async fn create_task(params: CreateTaskParams<'_>) -> Result<String, RepositoryError> {
33 let CreateTaskParams {
34 pool,
35 task,
36 user_id,
37 session_id,
38 trace_id,
39 agent_name,
40 } = params;
41 let metadata_json = task.metadata.as_ref().map_or_else(
42 || serde_json::json!({}),
43 |m| {
44 serde_json::to_value(m).unwrap_or_else(|e| {
45 tracing::warn!(error = %e, task_id = %task.id, "Failed to serialize task metadata");
46 serde_json::json!({})
47 })
48 },
49 );
50
51 let status = task_state_to_db_string(task.status.state);
52 let task_id_str = task.id.as_str();
53 let context_id_str = task.context_id.as_str();
54 let user_id_str = user_id.as_ref();
55 let session_id_str = session_id.as_ref();
56 let trace_id_str = trace_id.as_ref();
57
58 sqlx::query!(
59 r#"INSERT INTO agent_tasks (task_id, context_id, status, status_timestamp, user_id, session_id, trace_id, metadata, agent_name)
60 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"#,
61 task_id_str,
62 context_id_str,
63 status,
64 task.status.timestamp,
65 user_id_str,
66 session_id_str,
67 trace_id_str,
68 metadata_json,
69 agent_name
70 )
71 .execute(pool.as_ref())
72 .await
73 .map_err(RepositoryError::database)?;
74
75 Ok(task.id.to_string())
76}
77
78pub async fn track_agent_in_context(
79 pool: &Arc<PgPool>,
80 context_id: &systemprompt_identifiers::ContextId,
81 agent_name: &str,
82) -> Result<(), RepositoryError> {
83 let context_id_str = context_id.as_str();
84 sqlx::query!(
85 r#"INSERT INTO context_agents (context_id, agent_name) VALUES ($1, $2)
86 ON CONFLICT (context_id, agent_name) DO NOTHING"#,
87 context_id_str,
88 agent_name
89 )
90 .execute(pool.as_ref())
91 .await
92 .map_err(RepositoryError::database)?;
93
94 Ok(())
95}
96
97pub async fn update_task_state(
98 pool: &Arc<PgPool>,
99 task_id: &systemprompt_identifiers::TaskId,
100 state: TaskState,
101 timestamp: &chrono::DateTime<chrono::Utc>,
102) -> Result<(), RepositoryError> {
103 let status = task_state_to_db_string(state);
104 let task_id_str = task_id.as_str();
105
106 if state == TaskState::Completed {
107 sqlx::query!(
108 r#"UPDATE agent_tasks SET status = $1, status_timestamp = $2, updated_at = CURRENT_TIMESTAMP,
109 completed_at = CURRENT_TIMESTAMP,
110 started_at = COALESCE(started_at, CURRENT_TIMESTAMP),
111 execution_time_ms = EXTRACT(EPOCH FROM (CURRENT_TIMESTAMP - COALESCE(started_at, CURRENT_TIMESTAMP))) * 1000
112 WHERE task_id = $3"#,
113 status,
114 timestamp,
115 task_id_str
116 )
117 .execute(pool.as_ref())
118 .await
119 .map_err(RepositoryError::database)?;
120 } else if state == TaskState::Working {
121 sqlx::query!(
122 r#"UPDATE agent_tasks SET status = $1, status_timestamp = $2, updated_at = CURRENT_TIMESTAMP,
123 started_at = COALESCE(started_at, CURRENT_TIMESTAMP)
124 WHERE task_id = $3"#,
125 status,
126 timestamp,
127 task_id_str
128 )
129 .execute(pool.as_ref())
130 .await
131 .map_err(RepositoryError::database)?;
132 } else {
133 sqlx::query!(
134 r#"UPDATE agent_tasks SET status = $1, status_timestamp = $2, updated_at = CURRENT_TIMESTAMP WHERE task_id = $3"#,
135 status,
136 timestamp,
137 task_id_str
138 )
139 .execute(pool.as_ref())
140 .await
141 .map_err(RepositoryError::database)?;
142 }
143
144 Ok(())
145}
146
147pub async fn update_task_failed_with_error(
148 pool: &Arc<PgPool>,
149 task_id: &systemprompt_identifiers::TaskId,
150 error_message: &str,
151 timestamp: &chrono::DateTime<chrono::Utc>,
152) -> Result<(), RepositoryError> {
153 let task_id_str = task_id.as_str();
154
155 sqlx::query!(
156 r#"UPDATE agent_tasks SET
157 status = 'TASK_STATE_FAILED',
158 status_timestamp = $1,
159 error_message = $2,
160 updated_at = CURRENT_TIMESTAMP,
161 completed_at = CURRENT_TIMESTAMP,
162 started_at = COALESCE(started_at, CURRENT_TIMESTAMP),
163 execution_time_ms = EXTRACT(EPOCH FROM (CURRENT_TIMESTAMP - COALESCE(started_at, CURRENT_TIMESTAMP))) * 1000
164 WHERE task_id = $3"#,
165 timestamp,
166 error_message,
167 task_id_str
168 )
169 .execute(pool.as_ref())
170 .await
171 .map_err(RepositoryError::database)?;
172
173 Ok(())
174}