Skip to main content

systemprompt_agent/repository/task/
mutations.rs

1use sqlx::PgPool;
2use std::sync::Arc;
3use systemprompt_traits::RepositoryError;
4
5use crate::models::a2a::{Task, TaskState};
6
7pub const fn task_state_to_db_string(state: TaskState) -> &'static str {
8    match state {
9        TaskState::Pending => "TASK_STATE_PENDING",
10        TaskState::Submitted => "TASK_STATE_SUBMITTED",
11        TaskState::Working => "TASK_STATE_WORKING",
12        TaskState::InputRequired => "TASK_STATE_INPUT_REQUIRED",
13        TaskState::Completed => "TASK_STATE_COMPLETED",
14        TaskState::Canceled => "TASK_STATE_CANCELED",
15        TaskState::Failed => "TASK_STATE_FAILED",
16        TaskState::Rejected => "TASK_STATE_REJECTED",
17        TaskState::AuthRequired => "TASK_STATE_AUTH_REQUIRED",
18        TaskState::Unknown => "TASK_STATE_UNKNOWN",
19    }
20}
21
22#[allow(missing_debug_implementations)]
23pub struct CreateTaskParams<'a> {
24    pub pool: &'a Arc<PgPool>,
25    pub task: &'a Task,
26    pub user_id: &'a systemprompt_identifiers::UserId,
27    pub session_id: &'a systemprompt_identifiers::SessionId,
28    pub trace_id: &'a systemprompt_identifiers::TraceId,
29    pub agent_name: &'a str,
30}
31
32pub async fn create_task(params: CreateTaskParams<'_>) -> Result<String, RepositoryError> {
33    let CreateTaskParams {
34        pool,
35        task,
36        user_id,
37        session_id,
38        trace_id,
39        agent_name,
40    } = params;
41    let metadata_json = task.metadata.as_ref().map_or_else(
42        || serde_json::json!({}),
43        |m| {
44            serde_json::to_value(m).unwrap_or_else(|e| {
45                tracing::warn!(error = %e, task_id = %task.id, "Failed to serialize task metadata");
46                serde_json::json!({})
47            })
48        },
49    );
50
51    let status = task_state_to_db_string(task.status.state);
52    let task_id_str = task.id.as_str();
53    let context_id_str = task.context_id.as_str();
54    let user_id_str = user_id.as_ref();
55    let session_id_str = session_id.as_ref();
56    let trace_id_str = trace_id.as_ref();
57
58    sqlx::query!(
59        r#"INSERT INTO agent_tasks (task_id, context_id, status, status_timestamp, user_id, session_id, trace_id, metadata, agent_name)
60        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"#,
61        task_id_str,
62        context_id_str,
63        status,
64        task.status.timestamp,
65        user_id_str,
66        session_id_str,
67        trace_id_str,
68        metadata_json,
69        agent_name
70    )
71    .execute(pool.as_ref())
72    .await
73    .map_err(RepositoryError::database)?;
74
75    Ok(task.id.to_string())
76}
77
78pub async fn track_agent_in_context(
79    pool: &Arc<PgPool>,
80    context_id: &systemprompt_identifiers::ContextId,
81    agent_name: &str,
82) -> Result<(), RepositoryError> {
83    let context_id_str = context_id.as_str();
84    sqlx::query!(
85        r#"INSERT INTO context_agents (context_id, agent_name) VALUES ($1, $2)
86        ON CONFLICT (context_id, agent_name) DO NOTHING"#,
87        context_id_str,
88        agent_name
89    )
90    .execute(pool.as_ref())
91    .await
92    .map_err(RepositoryError::database)?;
93
94    Ok(())
95}
96
97pub async fn update_task_state(
98    pool: &Arc<PgPool>,
99    task_id: &systemprompt_identifiers::TaskId,
100    state: TaskState,
101    timestamp: &chrono::DateTime<chrono::Utc>,
102) -> Result<(), RepositoryError> {
103    let status = task_state_to_db_string(state);
104    let task_id_str = task_id.as_str();
105
106    if state == TaskState::Completed {
107        sqlx::query!(
108            r#"UPDATE agent_tasks SET status = $1, status_timestamp = $2, updated_at = CURRENT_TIMESTAMP,
109            completed_at = CURRENT_TIMESTAMP,
110            started_at = COALESCE(started_at, CURRENT_TIMESTAMP),
111            execution_time_ms = EXTRACT(EPOCH FROM (CURRENT_TIMESTAMP - COALESCE(started_at, CURRENT_TIMESTAMP))) * 1000
112            WHERE task_id = $3"#,
113            status,
114            timestamp,
115            task_id_str
116        )
117        .execute(pool.as_ref())
118        .await
119        .map_err(RepositoryError::database)?;
120    } else if state == TaskState::Working {
121        sqlx::query!(
122            r#"UPDATE agent_tasks SET status = $1, status_timestamp = $2, updated_at = CURRENT_TIMESTAMP,
123            started_at = COALESCE(started_at, CURRENT_TIMESTAMP)
124            WHERE task_id = $3"#,
125            status,
126            timestamp,
127            task_id_str
128        )
129        .execute(pool.as_ref())
130        .await
131        .map_err(RepositoryError::database)?;
132    } else {
133        sqlx::query!(
134            r#"UPDATE agent_tasks SET status = $1, status_timestamp = $2, updated_at = CURRENT_TIMESTAMP WHERE task_id = $3"#,
135            status,
136            timestamp,
137            task_id_str
138        )
139        .execute(pool.as_ref())
140        .await
141        .map_err(RepositoryError::database)?;
142    }
143
144    Ok(())
145}
146
147pub async fn update_task_failed_with_error(
148    pool: &Arc<PgPool>,
149    task_id: &systemprompt_identifiers::TaskId,
150    error_message: &str,
151    timestamp: &chrono::DateTime<chrono::Utc>,
152) -> Result<(), RepositoryError> {
153    let task_id_str = task_id.as_str();
154
155    sqlx::query!(
156        r#"UPDATE agent_tasks SET
157            status = 'TASK_STATE_FAILED',
158            status_timestamp = $1,
159            error_message = $2,
160            updated_at = CURRENT_TIMESTAMP,
161            completed_at = CURRENT_TIMESTAMP,
162            started_at = COALESCE(started_at, CURRENT_TIMESTAMP),
163            execution_time_ms = EXTRACT(EPOCH FROM (CURRENT_TIMESTAMP - COALESCE(started_at, CURRENT_TIMESTAMP))) * 1000
164        WHERE task_id = $3"#,
165        timestamp,
166        error_message,
167        task_id_str
168    )
169    .execute(pool.as_ref())
170    .await
171    .map_err(RepositoryError::database)?;
172
173    Ok(())
174}