Skip to main content

systemprompt_agent/repository/task/
queries.rs

1use crate::models::TaskRow;
2use sqlx::PgPool;
3use std::sync::Arc;
4use systemprompt_database::DbPool;
5use systemprompt_identifiers::{AgentName, ContextId, SessionId, TaskId, TraceId, UserId};
6use systemprompt_traits::RepositoryError;
7
8use super::constructor::TaskConstructor;
9use crate::models::a2a::Task;
10
11pub async fn get_task(
12    pool: &Arc<PgPool>,
13    db_pool: &DbPool,
14    task_id: &TaskId,
15) -> Result<Option<Task>, RepositoryError> {
16    let task_id_str = task_id.as_str();
17    let row = sqlx::query_as!(
18        TaskRow,
19        r#"SELECT
20            task_id as "task_id!: TaskId",
21            context_id as "context_id!: ContextId",
22            status as "status!",
23            status_timestamp,
24            user_id as "user_id?: UserId",
25            session_id as "session_id?: SessionId",
26            trace_id as "trace_id?: TraceId",
27            agent_name as "agent_name?: AgentName",
28            started_at,
29            completed_at,
30            execution_time_ms,
31            error_message,
32            metadata,
33            created_at as "created_at!",
34            updated_at as "updated_at!"
35        FROM agent_tasks WHERE task_id = $1"#,
36        task_id_str
37    )
38    .fetch_optional(pool.as_ref())
39    .await
40    .map_err(|e| RepositoryError::database(e))?;
41
42    let Some(_row) = row else {
43        return Ok(None);
44    };
45
46    let constructor = TaskConstructor::new(db_pool.clone());
47    let task = constructor.construct_task_from_task_id(task_id).await?;
48
49    Ok(Some(task))
50}
51
52pub async fn list_tasks_by_context(
53    pool: &Arc<PgPool>,
54    db_pool: &DbPool,
55    context_id: &ContextId,
56) -> Result<Vec<Task>, RepositoryError> {
57    let context_id_str = context_id.as_str();
58    let rows = sqlx::query_as!(
59        TaskRow,
60        r#"SELECT
61            task_id as "task_id!: TaskId",
62            context_id as "context_id!: ContextId",
63            status as "status!",
64            status_timestamp,
65            user_id as "user_id?: UserId",
66            session_id as "session_id?: SessionId",
67            trace_id as "trace_id?: TraceId",
68            agent_name as "agent_name?: AgentName",
69            started_at,
70            completed_at,
71            execution_time_ms,
72            error_message,
73            metadata,
74            created_at as "created_at!",
75            updated_at as "updated_at!"
76        FROM agent_tasks WHERE context_id = $1 ORDER BY created_at ASC"#,
77        context_id_str
78    )
79    .fetch_all(pool.as_ref())
80    .await
81    .map_err(|e| RepositoryError::database(e))?;
82
83    let constructor = TaskConstructor::new(db_pool.clone());
84    let mut tasks = Vec::new();
85
86    for row in rows {
87        tasks.push(
88            constructor
89                .construct_task_from_task_id(&row.task_id)
90                .await?,
91        );
92    }
93
94    Ok(tasks)
95}
96
97pub async fn get_tasks_by_user_id(
98    pool: &Arc<PgPool>,
99    db_pool: &DbPool,
100    user_id: &UserId,
101    limit: Option<i32>,
102    offset: Option<i32>,
103) -> Result<Vec<Task>, RepositoryError> {
104    let lim = limit.map(i64::from).unwrap_or(1000);
105    let off = offset.map(i64::from).unwrap_or(0);
106    let user_id_str = user_id.as_str();
107
108    let rows = sqlx::query_as!(
109        TaskRow,
110        r#"SELECT
111            task_id as "task_id!: TaskId",
112            context_id as "context_id!: ContextId",
113            status as "status!",
114            status_timestamp,
115            user_id as "user_id?: UserId",
116            session_id as "session_id?: SessionId",
117            trace_id as "trace_id?: TraceId",
118            agent_name as "agent_name?: AgentName",
119            started_at,
120            completed_at,
121            execution_time_ms,
122            error_message,
123            metadata,
124            created_at as "created_at!",
125            updated_at as "updated_at!"
126        FROM agent_tasks WHERE user_id = $1 ORDER BY created_at DESC LIMIT $2 OFFSET $3"#,
127        user_id_str,
128        lim,
129        off
130    )
131    .fetch_all(pool.as_ref())
132    .await
133    .map_err(|e| RepositoryError::database(e))?;
134
135    let constructor = TaskConstructor::new(db_pool.clone());
136    let mut tasks = Vec::new();
137
138    for row in &rows {
139        tasks.push(
140            constructor
141                .construct_task_from_task_id(&row.task_id)
142                .await?,
143        );
144    }
145
146    Ok(tasks)
147}
148
149#[derive(Debug, Clone)]
150pub struct TaskContextInfo {
151    pub context_id: String,
152    pub user_id: String,
153}
154
155impl TaskContextInfo {
156    pub fn context_id(&self) -> ContextId {
157        ContextId::new(&self.context_id)
158    }
159
160    pub fn user_id(&self) -> UserId {
161        UserId::new(&self.user_id)
162    }
163}
164
165pub async fn get_task_context_info(
166    pool: &Arc<PgPool>,
167    task_id: &TaskId,
168) -> Result<Option<TaskContextInfo>, RepositoryError> {
169    let task_id_str = task_id.as_str();
170    let row = sqlx::query!(
171        r#"SELECT
172            context_id as "context_id!",
173            user_id
174        FROM agent_tasks WHERE task_id = $1"#,
175        task_id_str
176    )
177    .fetch_optional(pool.as_ref())
178    .await
179    .map_err(|e| RepositoryError::database(e))?;
180
181    Ok(row.map(|r| TaskContextInfo {
182        context_id: r.context_id,
183        user_id: r.user_id.unwrap_or_else(String::new),
184    }))
185}