Skip to main content

systemprompt_api/routes/agent/
tasks.rs

1use axum::extract::{Path, Query, State};
2use axum::http::StatusCode;
3use axum::response::IntoResponse;
4use axum::{Extension, Json};
5use serde::Deserialize;
6use systemprompt_identifiers::{ContextId, TaskId, UserId};
7
8use systemprompt_agent::models::a2a::TaskState;
9use systemprompt_agent::repository::context::ContextRepository;
10use systemprompt_agent::repository::task::TaskRepository;
11use systemprompt_models::RequestContext;
12use systemprompt_runtime::AppContext;
13
14use crate::error::ApiHttpError;
15
16#[derive(Debug, Deserialize)]
17pub struct TaskFilterParams {
18    pub status: Option<String>,
19    pub limit: Option<u32>,
20}
21
22pub async fn list_tasks_by_context(
23    Extension(req_ctx): Extension<RequestContext>,
24    State(app_context): State<AppContext>,
25    Path(context_id): Path<String>,
26) -> Result<impl IntoResponse, ApiHttpError> {
27    tracing::debug!(context_id = %context_id, "Listing tasks");
28
29    let context_id_typed = ContextId::new(&context_id);
30
31    let context_repo = ContextRepository::new(app_context.db_pool())?;
32    context_repo
33        .validate_context_ownership(&context_id_typed, req_ctx.user_id())
34        .await?;
35
36    let task_repo = TaskRepository::new(app_context.db_pool())?;
37    let tasks = task_repo.list_tasks_by_context(&context_id_typed).await?;
38
39    tracing::debug!(context_id = %context_id, count = %tasks.len(), "Tasks listed");
40    Ok((StatusCode::OK, Json(tasks)))
41}
42
43pub async fn get_task(
44    Extension(req_ctx): Extension<RequestContext>,
45    State(app_context): State<AppContext>,
46    Path(task_id): Path<String>,
47) -> Result<impl IntoResponse, ApiHttpError> {
48    tracing::debug!(task_id = %task_id, "Retrieving task");
49
50    let task_repo = TaskRepository::new(app_context.db_pool())?;
51
52    let task_id_typed = TaskId::new(&task_id);
53    task_repo
54        .validate_task_ownership(&task_id_typed, req_ctx.user_id())
55        .await?;
56
57    let task = task_repo
58        .get_task(&task_id_typed)
59        .await?
60        .ok_or_else(|| ApiHttpError::not_found(format!("Task '{task_id}' not found")))?;
61
62    tracing::debug!("Task retrieved successfully");
63    Ok((StatusCode::OK, Json(task)))
64}
65
66pub async fn list_tasks_by_user(
67    Extension(req_ctx): Extension<RequestContext>,
68    State(app_context): State<AppContext>,
69    Query(params): Query<TaskFilterParams>,
70) -> Result<impl IntoResponse, ApiHttpError> {
71    let user_id = req_ctx.auth.actor.user_id.as_str();
72
73    tracing::debug!(user_id = %user_id, "Listing tasks");
74
75    let task_repo = TaskRepository::new(app_context.db_pool())?;
76
77    let task_state = params.status.as_ref().and_then(|s| match s.as_str() {
78        "submitted" => Some(TaskState::Submitted),
79        "working" => Some(TaskState::Working),
80        "input-required" => Some(TaskState::InputRequired),
81        "completed" => Some(TaskState::Completed),
82        "canceled" | "cancelled" => Some(TaskState::Canceled),
83        "failed" => Some(TaskState::Failed),
84        "rejected" => Some(TaskState::Rejected),
85        "auth-required" => Some(TaskState::AuthRequired),
86        _ => None,
87    });
88
89    let user_id_typed = UserId::new(user_id);
90    let mut tasks = task_repo
91        .get_tasks_by_user_id(&user_id_typed, params.limit.map(|l| l as i32), None)
92        .await?;
93
94    if let Some(state) = task_state {
95        tasks.retain(|t| t.status.state == state);
96    }
97
98    tracing::debug!(user_id = %user_id, count = %tasks.len(), "Tasks listed");
99    Ok((StatusCode::OK, Json(tasks)))
100}
101
102pub async fn get_messages_by_task(
103    Extension(req_ctx): Extension<RequestContext>,
104    State(app_context): State<AppContext>,
105    Path(task_id): Path<String>,
106) -> Result<impl IntoResponse, ApiHttpError> {
107    tracing::debug!(task_id = %task_id, "Retrieving messages");
108
109    let task_repo = TaskRepository::new(app_context.db_pool())?;
110
111    let task_id_typed = TaskId::new(&task_id);
112    task_repo
113        .validate_task_ownership(&task_id_typed, req_ctx.user_id())
114        .await?;
115
116    let messages = task_repo.get_messages_by_task(&task_id_typed).await?;
117
118    tracing::debug!(task_id = %task_id, count = %messages.len(), "Messages retrieved");
119    Ok((StatusCode::OK, Json(messages)))
120}
121
122pub async fn delete_task(
123    Extension(req_ctx): Extension<RequestContext>,
124    State(app_context): State<AppContext>,
125    Path(task_id): Path<String>,
126) -> Result<impl IntoResponse, ApiHttpError> {
127    tracing::debug!(task_id = %task_id, "Deleting task");
128
129    let task_repo = TaskRepository::new(app_context.db_pool())?;
130
131    let task_id_typed = TaskId::new(&task_id);
132    task_repo
133        .validate_task_ownership(&task_id_typed, req_ctx.user_id())
134        .await?;
135
136    task_repo.delete_task(&task_id_typed).await?;
137
138    tracing::debug!(task_id = %task_id, "Task deleted");
139    Ok(StatusCode::NO_CONTENT)
140}