systemprompt_agent/repository/task/
queries.rs1use crate::models::TaskRow;
2use sqlx::PgPool;
3use std::sync::Arc;
4use systemprompt_database::DbPool;
5use systemprompt_identifiers::{AgentName, ContextId, SessionId, TaskId, TraceId, UserId};
6use systemprompt_traits::RepositoryError;
7
8use super::constructor::TaskConstructor;
9use crate::models::a2a::Task;
10
11pub async fn get_task(
12 pool: &Arc<PgPool>,
13 db_pool: &DbPool,
14 task_id: &TaskId,
15) -> Result<Option<Task>, RepositoryError> {
16 let task_id_str = task_id.as_str();
17 let row = sqlx::query_as!(
18 TaskRow,
19 r#"SELECT
20 task_id as "task_id!: TaskId",
21 context_id as "context_id!: ContextId",
22 status as "status!",
23 status_timestamp,
24 user_id as "user_id?: UserId",
25 session_id as "session_id?: SessionId",
26 trace_id as "trace_id?: TraceId",
27 agent_name as "agent_name?: AgentName",
28 started_at,
29 completed_at,
30 execution_time_ms,
31 error_message,
32 metadata,
33 created_at as "created_at!",
34 updated_at as "updated_at!"
35 FROM agent_tasks WHERE task_id = $1"#,
36 task_id_str
37 )
38 .fetch_optional(pool.as_ref())
39 .await
40 .map_err(|e| RepositoryError::database(e))?;
41
42 let Some(_row) = row else {
43 return Ok(None);
44 };
45
46 let constructor = TaskConstructor::new(db_pool.clone());
47 let task = constructor.construct_task_from_task_id(task_id).await?;
48
49 Ok(Some(task))
50}
51
52pub async fn list_tasks_by_context(
53 pool: &Arc<PgPool>,
54 db_pool: &DbPool,
55 context_id: &ContextId,
56) -> Result<Vec<Task>, RepositoryError> {
57 let context_id_str = context_id.as_str();
58 let rows = sqlx::query_as!(
59 TaskRow,
60 r#"SELECT
61 task_id as "task_id!: TaskId",
62 context_id as "context_id!: ContextId",
63 status as "status!",
64 status_timestamp,
65 user_id as "user_id?: UserId",
66 session_id as "session_id?: SessionId",
67 trace_id as "trace_id?: TraceId",
68 agent_name as "agent_name?: AgentName",
69 started_at,
70 completed_at,
71 execution_time_ms,
72 error_message,
73 metadata,
74 created_at as "created_at!",
75 updated_at as "updated_at!"
76 FROM agent_tasks WHERE context_id = $1 ORDER BY created_at ASC"#,
77 context_id_str
78 )
79 .fetch_all(pool.as_ref())
80 .await
81 .map_err(|e| RepositoryError::database(e))?;
82
83 let constructor = TaskConstructor::new(db_pool.clone());
84 let mut tasks = Vec::new();
85
86 for row in rows {
87 tasks.push(
88 constructor
89 .construct_task_from_task_id(&row.task_id)
90 .await?,
91 );
92 }
93
94 Ok(tasks)
95}
96
97pub async fn get_tasks_by_user_id(
98 pool: &Arc<PgPool>,
99 db_pool: &DbPool,
100 user_id: &UserId,
101 limit: Option<i32>,
102 offset: Option<i32>,
103) -> Result<Vec<Task>, RepositoryError> {
104 let lim = limit.map(i64::from).unwrap_or(1000);
105 let off = offset.map(i64::from).unwrap_or(0);
106 let user_id_str = user_id.as_str();
107
108 let rows = sqlx::query_as!(
109 TaskRow,
110 r#"SELECT
111 task_id as "task_id!: TaskId",
112 context_id as "context_id!: ContextId",
113 status as "status!",
114 status_timestamp,
115 user_id as "user_id?: UserId",
116 session_id as "session_id?: SessionId",
117 trace_id as "trace_id?: TraceId",
118 agent_name as "agent_name?: AgentName",
119 started_at,
120 completed_at,
121 execution_time_ms,
122 error_message,
123 metadata,
124 created_at as "created_at!",
125 updated_at as "updated_at!"
126 FROM agent_tasks WHERE user_id = $1 ORDER BY created_at DESC LIMIT $2 OFFSET $3"#,
127 user_id_str,
128 lim,
129 off
130 )
131 .fetch_all(pool.as_ref())
132 .await
133 .map_err(|e| RepositoryError::database(e))?;
134
135 let constructor = TaskConstructor::new(db_pool.clone());
136 let mut tasks = Vec::new();
137
138 for row in &rows {
139 tasks.push(
140 constructor
141 .construct_task_from_task_id(&row.task_id)
142 .await?,
143 );
144 }
145
146 Ok(tasks)
147}
148
149#[derive(Debug, Clone)]
150pub struct TaskContextInfo {
151 pub context_id: String,
152 pub user_id: String,
153}
154
155impl TaskContextInfo {
156 pub fn context_id(&self) -> ContextId {
157 ContextId::new(&self.context_id)
158 }
159
160 pub fn user_id(&self) -> UserId {
161 UserId::new(&self.user_id)
162 }
163}
164
165pub async fn get_task_context_info(
166 pool: &Arc<PgPool>,
167 task_id: &TaskId,
168) -> Result<Option<TaskContextInfo>, RepositoryError> {
169 let task_id_str = task_id.as_str();
170 let row = sqlx::query!(
171 r#"SELECT
172 context_id as "context_id!",
173 user_id
174 FROM agent_tasks WHERE task_id = $1"#,
175 task_id_str
176 )
177 .fetch_optional(pool.as_ref())
178 .await
179 .map_err(|e| RepositoryError::database(e))?;
180
181 Ok(row.map(|r| TaskContextInfo {
182 context_id: r.context_id,
183 user_id: r.user_id.unwrap_or_else(String::new),
184 }))
185}