Skip to main content

systemprompt_agent/repository/task/
mutations.rs

1use sqlx::PgPool;
2use std::sync::Arc;
3use systemprompt_traits::RepositoryError;
4
5use crate::models::a2a::{Task, TaskState};
6
7pub const fn task_state_to_db_string(state: TaskState) -> &'static str {
8    match state {
9        TaskState::Pending => "submitted",
10        TaskState::Submitted => "submitted",
11        TaskState::Working => "working",
12        TaskState::InputRequired => "input-required",
13        TaskState::Completed => "completed",
14        TaskState::Canceled => "canceled",
15        TaskState::Failed => "failed",
16        TaskState::Rejected => "rejected",
17        TaskState::AuthRequired => "auth-required",
18        TaskState::Unknown => "unknown",
19    }
20}
21
22pub async fn create_task(
23    pool: &Arc<PgPool>,
24    task: &Task,
25    user_id: &systemprompt_identifiers::UserId,
26    session_id: &systemprompt_identifiers::SessionId,
27    trace_id: &systemprompt_identifiers::TraceId,
28    agent_name: &str,
29) -> Result<String, RepositoryError> {
30    let metadata_json = task
31        .metadata
32        .as_ref()
33        .map(|m| {
34            serde_json::to_value(m).unwrap_or_else(|e| {
35                tracing::warn!(error = %e, task_id = %task.id, "Failed to serialize task metadata");
36                serde_json::json!({})
37            })
38        })
39        .unwrap_or_else(|| serde_json::json!({}));
40
41    let status = task_state_to_db_string(task.status.state.clone());
42    let task_id_str = task.id.as_str();
43    let context_id_str = task.context_id.as_str();
44    let user_id_str = user_id.as_ref();
45    let session_id_str = session_id.as_ref();
46    let trace_id_str = trace_id.as_ref();
47
48    sqlx::query!(
49        r#"INSERT INTO agent_tasks (task_id, context_id, status, status_timestamp, user_id, session_id, trace_id, metadata, agent_name)
50        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"#,
51        task_id_str,
52        context_id_str,
53        status,
54        task.status.timestamp,
55        user_id_str,
56        session_id_str,
57        trace_id_str,
58        metadata_json,
59        agent_name
60    )
61    .execute(pool.as_ref())
62    .await
63    .map_err(|e| RepositoryError::database(e))?;
64
65    Ok(task.id.to_string())
66}
67
68pub async fn track_agent_in_context(
69    pool: &Arc<PgPool>,
70    context_id: &systemprompt_identifiers::ContextId,
71    agent_name: &str,
72) -> Result<(), RepositoryError> {
73    let context_id_str = context_id.as_str();
74    sqlx::query!(
75        r#"INSERT INTO context_agents (context_id, agent_name) VALUES ($1, $2)
76        ON CONFLICT (context_id, agent_name) DO NOTHING"#,
77        context_id_str,
78        agent_name
79    )
80    .execute(pool.as_ref())
81    .await
82    .map_err(|e| RepositoryError::database(e))?;
83
84    Ok(())
85}
86
87pub async fn update_task_state(
88    pool: &Arc<PgPool>,
89    task_id: &systemprompt_identifiers::TaskId,
90    state: TaskState,
91    timestamp: &chrono::DateTime<chrono::Utc>,
92) -> Result<(), RepositoryError> {
93    let status = task_state_to_db_string(state);
94    let task_id_str = task_id.as_str();
95
96    if state == TaskState::Completed {
97        sqlx::query!(
98            r#"UPDATE agent_tasks SET status = $1, status_timestamp = $2, updated_at = CURRENT_TIMESTAMP,
99            completed_at = CURRENT_TIMESTAMP,
100            started_at = COALESCE(started_at, CURRENT_TIMESTAMP),
101            execution_time_ms = EXTRACT(EPOCH FROM (CURRENT_TIMESTAMP - COALESCE(started_at, CURRENT_TIMESTAMP))) * 1000
102            WHERE task_id = $3"#,
103            status,
104            timestamp,
105            task_id_str
106        )
107        .execute(pool.as_ref())
108        .await
109        .map_err(|e| RepositoryError::database(e))?;
110    } else if state == TaskState::Working {
111        sqlx::query!(
112            r#"UPDATE agent_tasks SET status = $1, status_timestamp = $2, updated_at = CURRENT_TIMESTAMP,
113            started_at = COALESCE(started_at, CURRENT_TIMESTAMP)
114            WHERE task_id = $3"#,
115            status,
116            timestamp,
117            task_id_str
118        )
119        .execute(pool.as_ref())
120        .await
121        .map_err(|e| RepositoryError::database(e))?;
122    } else {
123        sqlx::query!(
124            r#"UPDATE agent_tasks SET status = $1, status_timestamp = $2, updated_at = CURRENT_TIMESTAMP WHERE task_id = $3"#,
125            status,
126            timestamp,
127            task_id_str
128        )
129        .execute(pool.as_ref())
130        .await
131        .map_err(|e| RepositoryError::database(e))?;
132    }
133
134    Ok(())
135}
136
137pub async fn update_task_failed_with_error(
138    pool: &Arc<PgPool>,
139    task_id: &systemprompt_identifiers::TaskId,
140    error_message: &str,
141    timestamp: &chrono::DateTime<chrono::Utc>,
142) -> Result<(), RepositoryError> {
143    let task_id_str = task_id.as_str();
144
145    sqlx::query!(
146        r#"UPDATE agent_tasks SET
147            status = 'failed',
148            status_timestamp = $1,
149            error_message = $2,
150            updated_at = CURRENT_TIMESTAMP,
151            completed_at = CURRENT_TIMESTAMP,
152            started_at = COALESCE(started_at, CURRENT_TIMESTAMP),
153            execution_time_ms = EXTRACT(EPOCH FROM (CURRENT_TIMESTAMP - COALESCE(started_at, CURRENT_TIMESTAMP))) * 1000
154        WHERE task_id = $3"#,
155        timestamp,
156        error_message,
157        task_id_str
158    )
159    .execute(pool.as_ref())
160    .await
161    .map_err(|e| RepositoryError::database(e))?;
162
163    Ok(())
164}