Skip to main content

systemprompt_agent/repository/task/
mutations.rs

1use sqlx::PgPool;
2use std::sync::Arc;
3use systemprompt_traits::RepositoryError;
4
5use crate::models::a2a::{Task, TaskState};
6
7pub const fn task_state_to_db_string(state: TaskState) -> &'static str {
8    match state {
9        TaskState::Pending | TaskState::Submitted => "submitted",
10        TaskState::Working => "working",
11        TaskState::InputRequired => "input-required",
12        TaskState::Completed => "completed",
13        TaskState::Canceled => "canceled",
14        TaskState::Failed => "failed",
15        TaskState::Rejected => "rejected",
16        TaskState::AuthRequired => "auth-required",
17        TaskState::Unknown => "unknown",
18    }
19}
20
21#[allow(missing_debug_implementations)]
22pub struct CreateTaskParams<'a> {
23    pub pool: &'a Arc<PgPool>,
24    pub task: &'a Task,
25    pub user_id: &'a systemprompt_identifiers::UserId,
26    pub session_id: &'a systemprompt_identifiers::SessionId,
27    pub trace_id: &'a systemprompt_identifiers::TraceId,
28    pub agent_name: &'a str,
29}
30
31pub async fn create_task(params: CreateTaskParams<'_>) -> Result<String, RepositoryError> {
32    let CreateTaskParams {
33        pool,
34        task,
35        user_id,
36        session_id,
37        trace_id,
38        agent_name,
39    } = params;
40    let metadata_json = task.metadata.as_ref().map_or_else(
41        || serde_json::json!({}),
42        |m| {
43            serde_json::to_value(m).unwrap_or_else(|e| {
44                tracing::warn!(error = %e, task_id = %task.id, "Failed to serialize task metadata");
45                serde_json::json!({})
46            })
47        },
48    );
49
50    let status = task_state_to_db_string(task.status.state);
51    let task_id_str = task.id.as_str();
52    let context_id_str = task.context_id.as_str();
53    let user_id_str = user_id.as_ref();
54    let session_id_str = session_id.as_ref();
55    let trace_id_str = trace_id.as_ref();
56
57    sqlx::query!(
58        r#"INSERT INTO agent_tasks (task_id, context_id, status, status_timestamp, user_id, session_id, trace_id, metadata, agent_name)
59        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"#,
60        task_id_str,
61        context_id_str,
62        status,
63        task.status.timestamp,
64        user_id_str,
65        session_id_str,
66        trace_id_str,
67        metadata_json,
68        agent_name
69    )
70    .execute(pool.as_ref())
71    .await
72    .map_err(RepositoryError::database)?;
73
74    Ok(task.id.to_string())
75}
76
77pub async fn track_agent_in_context(
78    pool: &Arc<PgPool>,
79    context_id: &systemprompt_identifiers::ContextId,
80    agent_name: &str,
81) -> Result<(), RepositoryError> {
82    let context_id_str = context_id.as_str();
83    sqlx::query!(
84        r#"INSERT INTO context_agents (context_id, agent_name) VALUES ($1, $2)
85        ON CONFLICT (context_id, agent_name) DO NOTHING"#,
86        context_id_str,
87        agent_name
88    )
89    .execute(pool.as_ref())
90    .await
91    .map_err(RepositoryError::database)?;
92
93    Ok(())
94}
95
96pub async fn update_task_state(
97    pool: &Arc<PgPool>,
98    task_id: &systemprompt_identifiers::TaskId,
99    state: TaskState,
100    timestamp: &chrono::DateTime<chrono::Utc>,
101) -> Result<(), RepositoryError> {
102    let status = task_state_to_db_string(state);
103    let task_id_str = task_id.as_str();
104
105    if state == TaskState::Completed {
106        sqlx::query!(
107            r#"UPDATE agent_tasks SET status = $1, status_timestamp = $2, updated_at = CURRENT_TIMESTAMP,
108            completed_at = CURRENT_TIMESTAMP,
109            started_at = COALESCE(started_at, CURRENT_TIMESTAMP),
110            execution_time_ms = EXTRACT(EPOCH FROM (CURRENT_TIMESTAMP - COALESCE(started_at, CURRENT_TIMESTAMP))) * 1000
111            WHERE task_id = $3"#,
112            status,
113            timestamp,
114            task_id_str
115        )
116        .execute(pool.as_ref())
117        .await
118        .map_err(RepositoryError::database)?;
119    } else if state == TaskState::Working {
120        sqlx::query!(
121            r#"UPDATE agent_tasks SET status = $1, status_timestamp = $2, updated_at = CURRENT_TIMESTAMP,
122            started_at = COALESCE(started_at, CURRENT_TIMESTAMP)
123            WHERE task_id = $3"#,
124            status,
125            timestamp,
126            task_id_str
127        )
128        .execute(pool.as_ref())
129        .await
130        .map_err(RepositoryError::database)?;
131    } else {
132        sqlx::query!(
133            r#"UPDATE agent_tasks SET status = $1, status_timestamp = $2, updated_at = CURRENT_TIMESTAMP WHERE task_id = $3"#,
134            status,
135            timestamp,
136            task_id_str
137        )
138        .execute(pool.as_ref())
139        .await
140        .map_err(RepositoryError::database)?;
141    }
142
143    Ok(())
144}
145
146pub async fn update_task_failed_with_error(
147    pool: &Arc<PgPool>,
148    task_id: &systemprompt_identifiers::TaskId,
149    error_message: &str,
150    timestamp: &chrono::DateTime<chrono::Utc>,
151) -> Result<(), RepositoryError> {
152    let task_id_str = task_id.as_str();
153
154    sqlx::query!(
155        r#"UPDATE agent_tasks SET
156            status = 'failed',
157            status_timestamp = $1,
158            error_message = $2,
159            updated_at = CURRENT_TIMESTAMP,
160            completed_at = CURRENT_TIMESTAMP,
161            started_at = COALESCE(started_at, CURRENT_TIMESTAMP),
162            execution_time_ms = EXTRACT(EPOCH FROM (CURRENT_TIMESTAMP - COALESCE(started_at, CURRENT_TIMESTAMP))) * 1000
163        WHERE task_id = $3"#,
164        timestamp,
165        error_message,
166        task_id_str
167    )
168    .execute(pool.as_ref())
169    .await
170    .map_err(RepositoryError::database)?;
171
172    Ok(())
173}