1use std::path::PathBuf;
2
3use sqlx::Row;
4
5use crate::database::DatabaseContext;
6use crate::errors::{ErrorCode, TrackError};
7use crate::path_component::validate_single_normal_path_component;
8use crate::time_utils::{format_iso_8601_millis, now_utc, parse_iso_8601_millis};
9use crate::types::{DispatchStatus, RemoteAgentPreferredTool, Task, TaskDispatchRecord};
10
11#[derive(Debug, Clone)]
12pub struct DispatchRepository {
13 database: DatabaseContext,
14}
15
16impl DispatchRepository {
17 pub fn new(database_path: Option<PathBuf>) -> Result<Self, TrackError> {
18 let database = DatabaseContext::new(database_path)?;
19 database.initialize()?;
20
21 Ok(Self { database })
22 }
23
24 pub fn create_dispatch(
25 &self,
26 task: &Task,
27 remote_host: &str,
28 preferred_tool: RemoteAgentPreferredTool,
29 ) -> Result<TaskDispatchRecord, TrackError> {
30 let timestamp = now_utc();
31 let record = TaskDispatchRecord {
32 dispatch_id: format!("dispatch-{}", timestamp.unix_timestamp_nanos()),
33 task_id: task.id.clone(),
34 preferred_tool,
35 project: task.project.clone(),
36 status: DispatchStatus::Preparing,
37 created_at: timestamp,
38 updated_at: timestamp,
39 finished_at: None,
40 remote_host: remote_host.to_owned(),
41 branch_name: None,
42 worktree_path: None,
43 pull_request_url: None,
44 follow_up_request: None,
45 summary: None,
46 notes: None,
47 error_message: None,
48 review_request_head_oid: None,
49 review_request_user: None,
50 };
51
52 self.save_dispatch(&record)?;
53 Ok(record)
54 }
55
56 pub fn save_dispatch(&self, record: &TaskDispatchRecord) -> Result<(), TrackError> {
57 let record = record.clone();
58 validate_single_normal_path_component(
59 &record.dispatch_id,
60 "Dispatch id",
61 ErrorCode::InvalidPathComponent,
62 )?;
63 validate_single_normal_path_component(
64 &record.task_id,
65 "Task id",
66 ErrorCode::InvalidPathComponent,
67 )?;
68
69 self.database.run(move |connection| {
70 Box::pin(async move {
71 sqlx::query(
72 r#"
73 INSERT INTO task_dispatches (
74 dispatch_id, task_id, preferred_tool, project, status, created_at, updated_at,
75 finished_at, remote_host, branch_name, worktree_path, pull_request_url,
76 follow_up_request, summary, notes, error_message, review_request_head_oid,
77 review_request_user
78 )
79 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, ?16, ?17, ?18)
80 ON CONFLICT(dispatch_id) DO UPDATE SET
81 task_id = excluded.task_id,
82 preferred_tool = excluded.preferred_tool,
83 project = excluded.project,
84 status = excluded.status,
85 created_at = excluded.created_at,
86 updated_at = excluded.updated_at,
87 finished_at = excluded.finished_at,
88 remote_host = excluded.remote_host,
89 branch_name = excluded.branch_name,
90 worktree_path = excluded.worktree_path,
91 pull_request_url = excluded.pull_request_url,
92 follow_up_request = excluded.follow_up_request,
93 summary = excluded.summary,
94 notes = excluded.notes,
95 error_message = excluded.error_message,
96 review_request_head_oid = excluded.review_request_head_oid,
97 review_request_user = excluded.review_request_user
98 "#,
99 )
100 .bind(&record.dispatch_id)
101 .bind(&record.task_id)
102 .bind(record.preferred_tool.as_str())
103 .bind(&record.project)
104 .bind(record.status.as_str())
105 .bind(format_iso_8601_millis(record.created_at))
106 .bind(format_iso_8601_millis(record.updated_at))
107 .bind(record.finished_at.map(format_iso_8601_millis))
108 .bind(&record.remote_host)
109 .bind(record.branch_name.as_deref())
110 .bind(record.worktree_path.as_deref())
111 .bind(record.pull_request_url.as_deref())
112 .bind(record.follow_up_request.as_deref())
113 .bind(record.summary.as_deref())
114 .bind(record.notes.as_deref())
115 .bind(record.error_message.as_deref())
116 .bind(record.review_request_head_oid.as_deref())
117 .bind(record.review_request_user.as_deref())
118 .execute(&mut *connection)
119 .await
120 .map_err(|error| {
121 TrackError::new(
122 ErrorCode::DispatchWriteFailed,
123 format!(
124 "Could not save the dispatch record for task {}: {error}",
125 record.task_id
126 ),
127 )
128 })?;
129
130 Ok(())
131 })
132 })
133 }
134
135 pub fn latest_dispatch_for_task(
136 &self,
137 task_id: &str,
138 ) -> Result<Option<TaskDispatchRecord>, TrackError> {
139 Ok(self.dispatches_for_task(task_id)?.into_iter().next())
140 }
141
142 pub fn dispatches_for_task(
143 &self,
144 task_id: &str,
145 ) -> Result<Vec<TaskDispatchRecord>, TrackError> {
146 let task_id = validate_single_normal_path_component(
147 task_id,
148 "Task id",
149 ErrorCode::InvalidPathComponent,
150 )?;
151
152 self.database.run(move |connection| {
153 Box::pin(async move {
154 let rows = sqlx::query(
155 r#"
156 SELECT *
157 FROM task_dispatches
158 WHERE task_id = ?1
159 ORDER BY created_at DESC
160 "#,
161 )
162 .bind(&task_id)
163 .fetch_all(&mut *connection)
164 .await
165 .map_err(|error| {
166 TrackError::new(
167 ErrorCode::DispatchWriteFailed,
168 format!("Could not load dispatch history for task {task_id}: {error}"),
169 )
170 })?;
171
172 rows.into_iter().map(task_dispatch_from_row).collect()
173 })
174 })
175 }
176
177 pub fn latest_dispatches_for_tasks(
178 &self,
179 task_ids: &[String],
180 ) -> Result<Vec<TaskDispatchRecord>, TrackError> {
181 let mut records = Vec::new();
182 for task_id in task_ids {
183 if let Some(record) = self.latest_dispatch_for_task(task_id)? {
184 records.push(record);
185 }
186 }
187
188 Ok(records)
189 }
190
191 pub fn list_dispatches(
192 &self,
193 limit: Option<usize>,
194 ) -> Result<Vec<TaskDispatchRecord>, TrackError> {
195 let limit = limit.map(|value| value as i64);
196 self.database.run(move |connection| {
197 Box::pin(async move {
198 let rows = if let Some(limit) = limit {
199 sqlx::query(
200 r#"
201 SELECT *
202 FROM task_dispatches
203 ORDER BY created_at DESC
204 LIMIT ?1
205 "#,
206 )
207 .bind(limit)
208 .fetch_all(&mut *connection)
209 .await
210 } else {
211 sqlx::query(
212 r#"
213 SELECT *
214 FROM task_dispatches
215 ORDER BY created_at DESC
216 "#,
217 )
218 .fetch_all(&mut *connection)
219 .await
220 }
221 .map_err(|error| {
222 TrackError::new(
223 ErrorCode::DispatchWriteFailed,
224 format!("Could not list dispatch records: {error}"),
225 )
226 })?;
227
228 rows.into_iter().map(task_dispatch_from_row).collect()
229 })
230 })
231 }
232
233 pub fn task_ids_with_history(&self) -> Result<Vec<String>, TrackError> {
234 self.database.run(move |connection| {
235 Box::pin(async move {
236 let rows = sqlx::query(
237 r#"
238 SELECT DISTINCT task_id
239 FROM task_dispatches
240 ORDER BY task_id ASC
241 "#,
242 )
243 .fetch_all(&mut *connection)
244 .await
245 .map_err(|error| {
246 TrackError::new(
247 ErrorCode::DispatchWriteFailed,
248 format!("Could not load task ids with dispatch history: {error}"),
249 )
250 })?;
251
252 Ok(rows
253 .into_iter()
254 .map(|row| row.get::<String, _>("task_id"))
255 .collect())
256 })
257 })
258 }
259
260 pub fn get_dispatch(
261 &self,
262 task_id: &str,
263 dispatch_id: &str,
264 ) -> Result<Option<TaskDispatchRecord>, TrackError> {
265 let task_id = validate_single_normal_path_component(
266 task_id,
267 "Task id",
268 ErrorCode::InvalidPathComponent,
269 )?;
270 let dispatch_id = validate_single_normal_path_component(
271 dispatch_id,
272 "Dispatch id",
273 ErrorCode::InvalidPathComponent,
274 )?;
275
276 self.database.run(move |connection| {
277 Box::pin(async move {
278 let row = sqlx::query(
279 r#"
280 SELECT *
281 FROM task_dispatches
282 WHERE task_id = ?1 AND dispatch_id = ?2
283 "#,
284 )
285 .bind(&task_id)
286 .bind(&dispatch_id)
287 .fetch_optional(&mut *connection)
288 .await
289 .map_err(|error| {
290 TrackError::new(
291 ErrorCode::DispatchWriteFailed,
292 format!(
293 "Could not load the dispatch record {dispatch_id} for task {task_id}: {error}"
294 ),
295 )
296 })?;
297
298 row.map(task_dispatch_from_row).transpose()
299 })
300 })
301 }
302
303 pub fn delete_dispatch_history_for_task(&self, task_id: &str) -> Result<(), TrackError> {
304 let task_id = validate_single_normal_path_component(
305 task_id,
306 "Task id",
307 ErrorCode::InvalidPathComponent,
308 )?;
309
310 self.database.run(move |connection| {
311 Box::pin(async move {
312 sqlx::query("DELETE FROM task_dispatches WHERE task_id = ?1")
313 .bind(&task_id)
314 .execute(&mut *connection)
315 .await
316 .map_err(|error| {
317 TrackError::new(
318 ErrorCode::DispatchWriteFailed,
319 format!(
320 "Could not remove the dispatch history for task {task_id}: {error}"
321 ),
322 )
323 })?;
324
325 Ok(())
326 })
327 })
328 }
329}
330
331fn task_dispatch_from_row(row: sqlx::sqlite::SqliteRow) -> Result<TaskDispatchRecord, TrackError> {
332 let dispatch_id = row.get::<String, _>("dispatch_id");
333 let created_at =
334 parse_iso_8601_millis(&row.get::<String, _>("created_at")).map_err(|error| {
335 TrackError::new(
336 ErrorCode::DispatchWriteFailed,
337 format!("Dispatch {dispatch_id} has an invalid created_at timestamp: {error}"),
338 )
339 })?;
340 let updated_at =
341 parse_iso_8601_millis(&row.get::<String, _>("updated_at")).map_err(|error| {
342 TrackError::new(
343 ErrorCode::DispatchWriteFailed,
344 format!("Dispatch {dispatch_id} has an invalid updated_at timestamp: {error}"),
345 )
346 })?;
347 let finished_at = row
348 .get::<Option<String>, _>("finished_at")
349 .map(|value| parse_iso_8601_millis(&value))
350 .transpose()
351 .map_err(|error| {
352 TrackError::new(
353 ErrorCode::DispatchWriteFailed,
354 format!("Dispatch {dispatch_id} has an invalid finished_at timestamp: {error}"),
355 )
356 })?;
357
358 Ok(TaskDispatchRecord {
359 dispatch_id,
360 task_id: row.get::<String, _>("task_id"),
361 preferred_tool: parse_preferred_tool(
362 row.try_get::<String, _>("preferred_tool")
363 .unwrap_or_else(|_| "codex".to_owned())
364 .as_str(),
365 )?,
366 project: row.get::<String, _>("project"),
367 status: parse_dispatch_status(row.get::<String, _>("status").as_str())?,
368 created_at,
369 updated_at,
370 finished_at,
371 remote_host: row.get::<String, _>("remote_host"),
372 branch_name: row.get::<Option<String>, _>("branch_name"),
373 worktree_path: row.get::<Option<String>, _>("worktree_path"),
374 pull_request_url: row.get::<Option<String>, _>("pull_request_url"),
375 follow_up_request: row.get::<Option<String>, _>("follow_up_request"),
376 summary: row.get::<Option<String>, _>("summary"),
377 notes: row.get::<Option<String>, _>("notes"),
378 error_message: row.get::<Option<String>, _>("error_message"),
379 review_request_head_oid: row.get::<Option<String>, _>("review_request_head_oid"),
380 review_request_user: row.get::<Option<String>, _>("review_request_user"),
381 })
382}
383
384fn parse_dispatch_status(value: &str) -> Result<DispatchStatus, TrackError> {
385 match value {
386 "preparing" => Ok(DispatchStatus::Preparing),
387 "running" => Ok(DispatchStatus::Running),
388 "succeeded" => Ok(DispatchStatus::Succeeded),
389 "canceled" => Ok(DispatchStatus::Canceled),
390 "failed" => Ok(DispatchStatus::Failed),
391 "blocked" => Ok(DispatchStatus::Blocked),
392 _ => Err(TrackError::new(
393 ErrorCode::DispatchWriteFailed,
394 format!("Dispatch status `{value}` is not valid."),
395 )),
396 }
397}
398
399fn parse_preferred_tool(value: &str) -> Result<RemoteAgentPreferredTool, TrackError> {
400 RemoteAgentPreferredTool::from_str(value).ok_or_else(|| {
401 TrackError::new(
402 ErrorCode::DispatchWriteFailed,
403 format!("Remote agent preferred tool `{value}` is not valid."),
404 )
405 })
406}