systemprompt_agent/repository/task/
queries.rs1use crate::models::TaskRow;
2use sqlx::PgPool;
3use std::sync::Arc;
4use systemprompt_database::DbPool;
5use systemprompt_identifiers::{AgentName, ContextId, SessionId, TaskId, TraceId, UserId};
6use systemprompt_traits::RepositoryError;
7
8use super::constructor::TaskConstructor;
9use crate::models::a2a::Task;
10
11pub async fn get_task(
12 pool: &Arc<PgPool>,
13 db_pool: &DbPool,
14 task_id: &TaskId,
15) -> Result<Option<Task>, RepositoryError> {
16 let task_id_str = task_id.as_str();
17 let row = sqlx::query_as!(
18 TaskRow,
19 r#"SELECT
20 task_id as "task_id!: TaskId",
21 context_id as "context_id!: ContextId",
22 status as "status!",
23 status_timestamp,
24 user_id as "user_id?: UserId",
25 session_id as "session_id?: SessionId",
26 trace_id as "trace_id?: TraceId",
27 agent_name as "agent_name?: AgentName",
28 started_at,
29 completed_at,
30 execution_time_ms,
31 error_message,
32 metadata,
33 created_at as "created_at!",
34 updated_at as "updated_at!"
35 FROM agent_tasks WHERE task_id = $1"#,
36 task_id_str
37 )
38 .fetch_optional(pool.as_ref())
39 .await
40 .map_err(RepositoryError::database)?;
41
42 let Some(_row) = row else {
43 return Ok(None);
44 };
45
46 let constructor = TaskConstructor::new(db_pool)?;
47 let task = constructor.construct_task_from_task_id(task_id).await?;
48
49 Ok(Some(task))
50}
51
52pub async fn list_tasks_by_context(
53 pool: &Arc<PgPool>,
54 db_pool: &DbPool,
55 context_id: &ContextId,
56) -> Result<Vec<Task>, RepositoryError> {
57 let context_id_str = context_id.as_str();
58 let rows = sqlx::query_as!(
59 TaskRow,
60 r#"SELECT
61 task_id as "task_id!: TaskId",
62 context_id as "context_id!: ContextId",
63 status as "status!",
64 status_timestamp,
65 user_id as "user_id?: UserId",
66 session_id as "session_id?: SessionId",
67 trace_id as "trace_id?: TraceId",
68 agent_name as "agent_name?: AgentName",
69 started_at,
70 completed_at,
71 execution_time_ms,
72 error_message,
73 metadata,
74 created_at as "created_at!",
75 updated_at as "updated_at!"
76 FROM agent_tasks WHERE context_id = $1 ORDER BY created_at ASC"#,
77 context_id_str
78 )
79 .fetch_all(pool.as_ref())
80 .await
81 .map_err(RepositoryError::database)?;
82
83 let constructor = TaskConstructor::new(db_pool)?;
84 let task_ids: Vec<TaskId> = rows.iter().map(|r| r.task_id.clone()).collect();
85 let tasks = constructor.construct_tasks_batch(&task_ids).await?;
86
87 Ok(tasks)
88}
89
90pub async fn get_tasks_by_user_id(
91 pool: &Arc<PgPool>,
92 db_pool: &DbPool,
93 user_id: &UserId,
94 limit: Option<i32>,
95 offset: Option<i32>,
96) -> Result<Vec<Task>, RepositoryError> {
97 let lim = limit.map_or(1000, i64::from);
98 let off = offset.map_or(0, i64::from);
99 let user_id_str = user_id.as_str();
100
101 let rows = sqlx::query_as!(
102 TaskRow,
103 r#"SELECT
104 task_id as "task_id!: TaskId",
105 context_id as "context_id!: ContextId",
106 status as "status!",
107 status_timestamp,
108 user_id as "user_id?: UserId",
109 session_id as "session_id?: SessionId",
110 trace_id as "trace_id?: TraceId",
111 agent_name as "agent_name?: AgentName",
112 started_at,
113 completed_at,
114 execution_time_ms,
115 error_message,
116 metadata,
117 created_at as "created_at!",
118 updated_at as "updated_at!"
119 FROM agent_tasks WHERE user_id = $1 ORDER BY created_at DESC LIMIT $2 OFFSET $3"#,
120 user_id_str,
121 lim,
122 off
123 )
124 .fetch_all(pool.as_ref())
125 .await
126 .map_err(RepositoryError::database)?;
127
128 let constructor = TaskConstructor::new(db_pool)?;
129 let task_ids: Vec<TaskId> = rows.iter().map(|r| r.task_id.clone()).collect();
130 let tasks = constructor.construct_tasks_batch(&task_ids).await?;
131
132 Ok(tasks)
133}
134
135#[derive(Debug, Clone)]
136pub struct TaskContextInfo {
137 pub context_id: ContextId,
138 pub user_id: UserId,
139}
140
141pub async fn get_task_context_info(
142 pool: &Arc<PgPool>,
143 task_id: &TaskId,
144) -> Result<Option<TaskContextInfo>, RepositoryError> {
145 let task_id_str = task_id.as_str();
146 let row = sqlx::query!(
147 r#"SELECT
148 context_id as "context_id!: ContextId",
149 user_id as "user_id?: UserId"
150 FROM agent_tasks WHERE task_id = $1"#,
151 task_id_str
152 )
153 .fetch_optional(pool.as_ref())
154 .await
155 .map_err(RepositoryError::database)?;
156
157 Ok(row.map(|r| TaskContextInfo {
158 context_id: r.context_id,
159 user_id: r.user_id.unwrap_or_else(|| UserId::new("")),
160 }))
161}