Skip to main content

systemprompt_agent/repository/task/
queries.rs

1use crate::models::TaskRow;
2use sqlx::PgPool;
3use std::sync::Arc;
4use systemprompt_database::DbPool;
5use systemprompt_identifiers::{AgentName, ContextId, SessionId, TaskId, TraceId, UserId};
6use systemprompt_traits::RepositoryError;
7
8use super::constructor::TaskConstructor;
9use crate::models::a2a::Task;
10
11pub async fn get_task(
12    pool: &Arc<PgPool>,
13    db_pool: &DbPool,
14    task_id: &TaskId,
15) -> Result<Option<Task>, RepositoryError> {
16    let task_id_str = task_id.as_str();
17    let row = sqlx::query_as!(
18        TaskRow,
19        r#"SELECT
20            task_id as "task_id!: TaskId",
21            context_id as "context_id!: ContextId",
22            status as "status!",
23            status_timestamp,
24            user_id as "user_id?: UserId",
25            session_id as "session_id?: SessionId",
26            trace_id as "trace_id?: TraceId",
27            agent_name as "agent_name?: AgentName",
28            started_at,
29            completed_at,
30            execution_time_ms,
31            error_message,
32            metadata,
33            created_at as "created_at!",
34            updated_at as "updated_at!"
35        FROM agent_tasks WHERE task_id = $1"#,
36        task_id_str
37    )
38    .fetch_optional(pool.as_ref())
39    .await
40    .map_err(RepositoryError::database)?;
41
42    let Some(_row) = row else {
43        return Ok(None);
44    };
45
46    let constructor = TaskConstructor::new(db_pool)?;
47    let task = constructor.construct_task_from_task_id(task_id).await?;
48
49    Ok(Some(task))
50}
51
52pub async fn list_tasks_by_context(
53    pool: &Arc<PgPool>,
54    db_pool: &DbPool,
55    context_id: &ContextId,
56) -> Result<Vec<Task>, RepositoryError> {
57    let context_id_str = context_id.as_str();
58    let rows = sqlx::query_as!(
59        TaskRow,
60        r#"SELECT
61            task_id as "task_id!: TaskId",
62            context_id as "context_id!: ContextId",
63            status as "status!",
64            status_timestamp,
65            user_id as "user_id?: UserId",
66            session_id as "session_id?: SessionId",
67            trace_id as "trace_id?: TraceId",
68            agent_name as "agent_name?: AgentName",
69            started_at,
70            completed_at,
71            execution_time_ms,
72            error_message,
73            metadata,
74            created_at as "created_at!",
75            updated_at as "updated_at!"
76        FROM agent_tasks WHERE context_id = $1 ORDER BY created_at ASC"#,
77        context_id_str
78    )
79    .fetch_all(pool.as_ref())
80    .await
81    .map_err(RepositoryError::database)?;
82
83    let constructor = TaskConstructor::new(db_pool)?;
84    let task_ids: Vec<TaskId> = rows.iter().map(|r| r.task_id.clone()).collect();
85    let tasks = constructor.construct_tasks_batch(&task_ids).await?;
86
87    Ok(tasks)
88}
89
90pub async fn get_tasks_by_user_id(
91    pool: &Arc<PgPool>,
92    db_pool: &DbPool,
93    user_id: &UserId,
94    limit: Option<i32>,
95    offset: Option<i32>,
96) -> Result<Vec<Task>, RepositoryError> {
97    let lim = limit.map_or(1000, i64::from);
98    let off = offset.map_or(0, i64::from);
99    let user_id_str = user_id.as_str();
100
101    let rows = sqlx::query_as!(
102        TaskRow,
103        r#"SELECT
104            task_id as "task_id!: TaskId",
105            context_id as "context_id!: ContextId",
106            status as "status!",
107            status_timestamp,
108            user_id as "user_id?: UserId",
109            session_id as "session_id?: SessionId",
110            trace_id as "trace_id?: TraceId",
111            agent_name as "agent_name?: AgentName",
112            started_at,
113            completed_at,
114            execution_time_ms,
115            error_message,
116            metadata,
117            created_at as "created_at!",
118            updated_at as "updated_at!"
119        FROM agent_tasks WHERE user_id = $1 ORDER BY created_at DESC LIMIT $2 OFFSET $3"#,
120        user_id_str,
121        lim,
122        off
123    )
124    .fetch_all(pool.as_ref())
125    .await
126    .map_err(RepositoryError::database)?;
127
128    let constructor = TaskConstructor::new(db_pool)?;
129    let task_ids: Vec<TaskId> = rows.iter().map(|r| r.task_id.clone()).collect();
130    let tasks = constructor.construct_tasks_batch(&task_ids).await?;
131
132    Ok(tasks)
133}
134
135#[derive(Debug, Clone)]
136pub struct TaskContextInfo {
137    pub context_id: ContextId,
138    pub user_id: UserId,
139}
140
141pub async fn get_task_context_info(
142    pool: &Arc<PgPool>,
143    task_id: &TaskId,
144) -> Result<Option<TaskContextInfo>, RepositoryError> {
145    let task_id_str = task_id.as_str();
146    let row = sqlx::query!(
147        r#"SELECT
148            context_id as "context_id!: ContextId",
149            user_id as "user_id?: UserId"
150        FROM agent_tasks WHERE task_id = $1"#,
151        task_id_str
152    )
153    .fetch_optional(pool.as_ref())
154    .await
155    .map_err(RepositoryError::database)?;
156
157    Ok(row.map(|r| TaskContextInfo {
158        context_id: r.context_id,
159        user_id: r.user_id.unwrap_or_else(|| UserId::new("")),
160    }))
161}