systemprompt_agent/repository/task/
mutations.rs1use sqlx::PgPool;
2use std::sync::Arc;
3use systemprompt_traits::RepositoryError;
4
5use crate::models::a2a::{Task, TaskState};
6
7pub const fn task_state_to_db_string(state: TaskState) -> &'static str {
8 match state {
9 TaskState::Pending => "submitted",
10 TaskState::Submitted => "submitted",
11 TaskState::Working => "working",
12 TaskState::InputRequired => "input-required",
13 TaskState::Completed => "completed",
14 TaskState::Canceled => "canceled",
15 TaskState::Failed => "failed",
16 TaskState::Rejected => "rejected",
17 TaskState::AuthRequired => "auth-required",
18 TaskState::Unknown => "unknown",
19 }
20}
21
22pub async fn create_task(
23 pool: &Arc<PgPool>,
24 task: &Task,
25 user_id: &systemprompt_identifiers::UserId,
26 session_id: &systemprompt_identifiers::SessionId,
27 trace_id: &systemprompt_identifiers::TraceId,
28 agent_name: &str,
29) -> Result<String, RepositoryError> {
30 let metadata_json = task
31 .metadata
32 .as_ref()
33 .map(|m| {
34 serde_json::to_value(m).unwrap_or_else(|e| {
35 tracing::warn!(error = %e, task_id = %task.id, "Failed to serialize task metadata");
36 serde_json::json!({})
37 })
38 })
39 .unwrap_or_else(|| serde_json::json!({}));
40
41 let status = task_state_to_db_string(task.status.state.clone());
42 let task_id_str = task.id.as_str();
43 let context_id_str = task.context_id.as_str();
44 let user_id_str = user_id.as_ref();
45 let session_id_str = session_id.as_ref();
46 let trace_id_str = trace_id.as_ref();
47
48 sqlx::query!(
49 r#"INSERT INTO agent_tasks (task_id, context_id, status, status_timestamp, user_id, session_id, trace_id, metadata, agent_name)
50 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"#,
51 task_id_str,
52 context_id_str,
53 status,
54 task.status.timestamp,
55 user_id_str,
56 session_id_str,
57 trace_id_str,
58 metadata_json,
59 agent_name
60 )
61 .execute(pool.as_ref())
62 .await
63 .map_err(|e| RepositoryError::database(e))?;
64
65 Ok(task.id.to_string())
66}
67
68pub async fn track_agent_in_context(
69 pool: &Arc<PgPool>,
70 context_id: &systemprompt_identifiers::ContextId,
71 agent_name: &str,
72) -> Result<(), RepositoryError> {
73 let context_id_str = context_id.as_str();
74 sqlx::query!(
75 r#"INSERT INTO context_agents (context_id, agent_name) VALUES ($1, $2)
76 ON CONFLICT (context_id, agent_name) DO NOTHING"#,
77 context_id_str,
78 agent_name
79 )
80 .execute(pool.as_ref())
81 .await
82 .map_err(|e| RepositoryError::database(e))?;
83
84 Ok(())
85}
86
87pub async fn update_task_state(
88 pool: &Arc<PgPool>,
89 task_id: &systemprompt_identifiers::TaskId,
90 state: TaskState,
91 timestamp: &chrono::DateTime<chrono::Utc>,
92) -> Result<(), RepositoryError> {
93 let status = task_state_to_db_string(state);
94 let task_id_str = task_id.as_str();
95
96 if state == TaskState::Completed {
97 sqlx::query!(
98 r#"UPDATE agent_tasks SET status = $1, status_timestamp = $2, updated_at = CURRENT_TIMESTAMP,
99 completed_at = CURRENT_TIMESTAMP,
100 started_at = COALESCE(started_at, CURRENT_TIMESTAMP),
101 execution_time_ms = EXTRACT(EPOCH FROM (CURRENT_TIMESTAMP - COALESCE(started_at, CURRENT_TIMESTAMP))) * 1000
102 WHERE task_id = $3"#,
103 status,
104 timestamp,
105 task_id_str
106 )
107 .execute(pool.as_ref())
108 .await
109 .map_err(|e| RepositoryError::database(e))?;
110 } else if state == TaskState::Working {
111 sqlx::query!(
112 r#"UPDATE agent_tasks SET status = $1, status_timestamp = $2, updated_at = CURRENT_TIMESTAMP,
113 started_at = COALESCE(started_at, CURRENT_TIMESTAMP)
114 WHERE task_id = $3"#,
115 status,
116 timestamp,
117 task_id_str
118 )
119 .execute(pool.as_ref())
120 .await
121 .map_err(|e| RepositoryError::database(e))?;
122 } else {
123 sqlx::query!(
124 r#"UPDATE agent_tasks SET status = $1, status_timestamp = $2, updated_at = CURRENT_TIMESTAMP WHERE task_id = $3"#,
125 status,
126 timestamp,
127 task_id_str
128 )
129 .execute(pool.as_ref())
130 .await
131 .map_err(|e| RepositoryError::database(e))?;
132 }
133
134 Ok(())
135}
136
137pub async fn update_task_failed_with_error(
138 pool: &Arc<PgPool>,
139 task_id: &systemprompt_identifiers::TaskId,
140 error_message: &str,
141 timestamp: &chrono::DateTime<chrono::Utc>,
142) -> Result<(), RepositoryError> {
143 let task_id_str = task_id.as_str();
144
145 sqlx::query!(
146 r#"UPDATE agent_tasks SET
147 status = 'failed',
148 status_timestamp = $1,
149 error_message = $2,
150 updated_at = CURRENT_TIMESTAMP,
151 completed_at = CURRENT_TIMESTAMP,
152 started_at = COALESCE(started_at, CURRENT_TIMESTAMP),
153 execution_time_ms = EXTRACT(EPOCH FROM (CURRENT_TIMESTAMP - COALESCE(started_at, CURRENT_TIMESTAMP))) * 1000
154 WHERE task_id = $3"#,
155 timestamp,
156 error_message,
157 task_id_str
158 )
159 .execute(pool.as_ref())
160 .await
161 .map_err(|e| RepositoryError::database(e))?;
162
163 Ok(())
164}