Skip to main content

systemprompt_api/routes/agent/
tasks.rs

1use axum::extract::{Path, Query, State};
2use axum::http::StatusCode;
3use axum::response::IntoResponse;
4use axum::{Extension, Json};
5use serde::Deserialize;
6use systemprompt_identifiers::{ContextId, TaskId, UserId};
7use systemprompt_models::api::ApiError;
8
9use systemprompt_agent::models::a2a::TaskState;
10use systemprompt_agent::repository::task::TaskRepository;
11use systemprompt_models::RequestContext;
12use systemprompt_runtime::AppContext;
13
/// Optional query-string parameters accepted by `list_tasks_by_user`.
#[derive(Debug, Deserialize)]
pub struct TaskFilterParams {
    /// Task status name to filter on (for example `completed` or `working`).
    /// Values the handler does not recognize result in no status filtering.
    status: Option<String>,
    /// Limit value handed to the repository query when present.
    limit: Option<u32>,
}
19
20pub async fn list_tasks_by_context(
21    Extension(_req_ctx): Extension<RequestContext>,
22    State(app_context): State<AppContext>,
23    Path(context_id): Path<String>,
24) -> Result<impl IntoResponse, ApiError> {
25    tracing::debug!(context_id = %context_id, "Listing tasks");
26
27    let task_repo = TaskRepository::new(app_context.db_pool())
28        .map_err(|e| ApiError::internal_error(format!("Database error: {e}")))?;
29
30    let context_id_typed = ContextId::new(&context_id);
31    let tasks = task_repo
32        .list_tasks_by_context(&context_id_typed)
33        .await
34        .map_err(|e| {
35            tracing::error!(error = %e, "Failed to list tasks");
36            ApiError::internal_error("Failed to retrieve tasks")
37        })?;
38
39    tracing::debug!(context_id = %context_id, count = %tasks.len(), "Tasks listed");
40    Ok((StatusCode::OK, Json(tasks)))
41}
42
43pub async fn get_task(
44    Extension(_req_ctx): Extension<RequestContext>,
45    State(app_context): State<AppContext>,
46    Path(task_id): Path<String>,
47) -> Result<impl IntoResponse, ApiError> {
48    tracing::debug!(task_id = %task_id, "Retrieving task");
49
50    let task_repo = TaskRepository::new(app_context.db_pool())
51        .map_err(|e| ApiError::internal_error(format!("Database error: {e}")))?;
52
53    let task_id_typed = TaskId::new(&task_id);
54    match task_repo.get_task(&task_id_typed).await {
55        Ok(Some(task)) => {
56            tracing::debug!("Task retrieved successfully");
57            Ok((StatusCode::OK, Json(task)).into_response())
58        },
59        Ok(None) => {
60            tracing::debug!("Task not found");
61            Err(ApiError::not_found(format!("Task '{}' not found", task_id)))
62        },
63        Err(e) => {
64            tracing::error!(error = %e, "Failed to retrieve task");
65            Err(ApiError::internal_error("Failed to retrieve task"))
66        },
67    }
68}
69
70pub async fn list_tasks_by_user(
71    Extension(req_ctx): Extension<RequestContext>,
72    State(app_context): State<AppContext>,
73    Query(params): Query<TaskFilterParams>,
74) -> Result<impl IntoResponse, ApiError> {
75    let user_id = req_ctx.auth.user_id.as_str();
76
77    tracing::debug!(user_id = %user_id, "Listing tasks");
78
79    let task_repo = TaskRepository::new(app_context.db_pool())
80        .map_err(|e| ApiError::internal_error(format!("Database error: {e}")))?;
81
82    let task_state = params.status.as_ref().and_then(|s| match s.as_str() {
83        "submitted" => Some(TaskState::Submitted),
84        "working" => Some(TaskState::Working),
85        "input-required" => Some(TaskState::InputRequired),
86        "completed" => Some(TaskState::Completed),
87        "canceled" | "cancelled" => Some(TaskState::Canceled),
88        "failed" => Some(TaskState::Failed),
89        "rejected" => Some(TaskState::Rejected),
90        "auth-required" => Some(TaskState::AuthRequired),
91        _ => None,
92    });
93
94    let user_id_typed = UserId::new(user_id);
95    let mut tasks = task_repo
96        .get_tasks_by_user_id(&user_id_typed, params.limit.map(|l| l as i32), None)
97        .await
98        .map_err(|e| {
99            tracing::error!(error = %e, "Failed to list tasks");
100            ApiError::internal_error("Failed to retrieve tasks")
101        })?;
102
103    if let Some(state) = task_state {
104        tasks.retain(|t| t.status.state == state);
105    }
106
107    tracing::debug!(user_id = %user_id, count = %tasks.len(), "Tasks listed");
108    Ok((StatusCode::OK, Json(tasks)))
109}
110
111pub async fn get_messages_by_task(
112    Extension(_req_ctx): Extension<RequestContext>,
113    State(app_context): State<AppContext>,
114    Path(task_id): Path<String>,
115) -> Result<impl IntoResponse, ApiError> {
116    tracing::debug!(task_id = %task_id, "Retrieving messages");
117
118    let task_repo = TaskRepository::new(app_context.db_pool())
119        .map_err(|e| ApiError::internal_error(format!("Database error: {e}")))?;
120
121    let task_id_typed = TaskId::new(&task_id);
122    let messages = task_repo
123        .get_messages_by_task(&task_id_typed)
124        .await
125        .map_err(|e| {
126            tracing::error!(error = %e, "Failed to retrieve messages");
127            ApiError::internal_error("Failed to retrieve messages")
128        })?;
129
130    tracing::debug!(task_id = %task_id, count = %messages.len(), "Messages retrieved");
131    Ok((StatusCode::OK, Json(messages)))
132}
133
134pub async fn delete_task(
135    Extension(_req_ctx): Extension<RequestContext>,
136    State(app_context): State<AppContext>,
137    Path(task_id): Path<String>,
138) -> Result<impl IntoResponse, ApiError> {
139    tracing::debug!(task_id = %task_id, "Deleting task");
140
141    let task_repo = TaskRepository::new(app_context.db_pool())
142        .map_err(|e| ApiError::internal_error(format!("Database error: {e}")))?;
143
144    let task_id_typed = TaskId::new(&task_id);
145    task_repo.delete_task(&task_id_typed).await.map_err(|e| {
146        tracing::error!(error = %e, "Failed to delete task");
147        ApiError::internal_error("Failed to delete task")
148    })?;
149
150    tracing::debug!(task_id = %task_id, "Task deleted");
151    Ok(StatusCode::NO_CONTENT)
152}