Skip to main content

track_core/
dispatch_repository.rs

1use std::path::PathBuf;
2
3use sqlx::Row;
4
5use crate::database::DatabaseContext;
6use crate::errors::{ErrorCode, TrackError};
7use crate::path_component::validate_single_normal_path_component;
8use crate::time_utils::{format_iso_8601_millis, now_utc, parse_iso_8601_millis};
9use crate::types::{DispatchStatus, RemoteAgentPreferredTool, Task, TaskDispatchRecord};
10
11#[derive(Debug, Clone)]
12pub struct DispatchRepository {
13    database: DatabaseContext,
14}
15
16impl DispatchRepository {
17    pub fn new(database_path: Option<PathBuf>) -> Result<Self, TrackError> {
18        let database = DatabaseContext::new(database_path)?;
19        database.initialize()?;
20
21        Ok(Self { database })
22    }
23
24    pub fn create_dispatch(
25        &self,
26        task: &Task,
27        remote_host: &str,
28        preferred_tool: RemoteAgentPreferredTool,
29    ) -> Result<TaskDispatchRecord, TrackError> {
30        let timestamp = now_utc();
31        let record = TaskDispatchRecord {
32            dispatch_id: format!("dispatch-{}", timestamp.unix_timestamp_nanos()),
33            task_id: task.id.clone(),
34            preferred_tool,
35            project: task.project.clone(),
36            status: DispatchStatus::Preparing,
37            created_at: timestamp,
38            updated_at: timestamp,
39            finished_at: None,
40            remote_host: remote_host.to_owned(),
41            branch_name: None,
42            worktree_path: None,
43            pull_request_url: None,
44            follow_up_request: None,
45            summary: None,
46            notes: None,
47            error_message: None,
48            review_request_head_oid: None,
49            review_request_user: None,
50        };
51
52        self.save_dispatch(&record)?;
53        Ok(record)
54    }
55
56    pub fn save_dispatch(&self, record: &TaskDispatchRecord) -> Result<(), TrackError> {
57        let record = record.clone();
58        validate_single_normal_path_component(
59            &record.dispatch_id,
60            "Dispatch id",
61            ErrorCode::InvalidPathComponent,
62        )?;
63        validate_single_normal_path_component(
64            &record.task_id,
65            "Task id",
66            ErrorCode::InvalidPathComponent,
67        )?;
68
69        self.database.run(move |connection| {
70            Box::pin(async move {
71                sqlx::query(
72                    r#"
73                    INSERT INTO task_dispatches (
74                        dispatch_id, task_id, preferred_tool, project, status, created_at, updated_at,
75                        finished_at, remote_host, branch_name, worktree_path, pull_request_url,
76                        follow_up_request, summary, notes, error_message, review_request_head_oid,
77                        review_request_user
78                    )
79                    VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, ?16, ?17, ?18)
80                    ON CONFLICT(dispatch_id) DO UPDATE SET
81                        task_id = excluded.task_id,
82                        preferred_tool = excluded.preferred_tool,
83                        project = excluded.project,
84                        status = excluded.status,
85                        created_at = excluded.created_at,
86                        updated_at = excluded.updated_at,
87                        finished_at = excluded.finished_at,
88                        remote_host = excluded.remote_host,
89                        branch_name = excluded.branch_name,
90                        worktree_path = excluded.worktree_path,
91                        pull_request_url = excluded.pull_request_url,
92                        follow_up_request = excluded.follow_up_request,
93                        summary = excluded.summary,
94                        notes = excluded.notes,
95                        error_message = excluded.error_message,
96                        review_request_head_oid = excluded.review_request_head_oid,
97                        review_request_user = excluded.review_request_user
98                    "#,
99                )
100                .bind(&record.dispatch_id)
101                .bind(&record.task_id)
102                .bind(record.preferred_tool.as_str())
103                .bind(&record.project)
104                .bind(record.status.as_str())
105                .bind(format_iso_8601_millis(record.created_at))
106                .bind(format_iso_8601_millis(record.updated_at))
107                .bind(record.finished_at.map(format_iso_8601_millis))
108                .bind(&record.remote_host)
109                .bind(record.branch_name.as_deref())
110                .bind(record.worktree_path.as_deref())
111                .bind(record.pull_request_url.as_deref())
112                .bind(record.follow_up_request.as_deref())
113                .bind(record.summary.as_deref())
114                .bind(record.notes.as_deref())
115                .bind(record.error_message.as_deref())
116                .bind(record.review_request_head_oid.as_deref())
117                .bind(record.review_request_user.as_deref())
118                .execute(&mut *connection)
119                .await
120                .map_err(|error| {
121                    TrackError::new(
122                        ErrorCode::DispatchWriteFailed,
123                        format!(
124                            "Could not save the dispatch record for task {}: {error}",
125                            record.task_id
126                        ),
127                    )
128                })?;
129
130                Ok(())
131            })
132        })
133    }
134
135    pub fn latest_dispatch_for_task(
136        &self,
137        task_id: &str,
138    ) -> Result<Option<TaskDispatchRecord>, TrackError> {
139        Ok(self.dispatches_for_task(task_id)?.into_iter().next())
140    }
141
142    pub fn dispatches_for_task(
143        &self,
144        task_id: &str,
145    ) -> Result<Vec<TaskDispatchRecord>, TrackError> {
146        let task_id = validate_single_normal_path_component(
147            task_id,
148            "Task id",
149            ErrorCode::InvalidPathComponent,
150        )?;
151
152        self.database.run(move |connection| {
153            Box::pin(async move {
154                let rows = sqlx::query(
155                    r#"
156                    SELECT *
157                    FROM task_dispatches
158                    WHERE task_id = ?1
159                    ORDER BY created_at DESC
160                    "#,
161                )
162                .bind(&task_id)
163                .fetch_all(&mut *connection)
164                .await
165                .map_err(|error| {
166                    TrackError::new(
167                        ErrorCode::DispatchWriteFailed,
168                        format!("Could not load dispatch history for task {task_id}: {error}"),
169                    )
170                })?;
171
172                rows.into_iter().map(task_dispatch_from_row).collect()
173            })
174        })
175    }
176
177    pub fn latest_dispatches_for_tasks(
178        &self,
179        task_ids: &[String],
180    ) -> Result<Vec<TaskDispatchRecord>, TrackError> {
181        let mut records = Vec::new();
182        for task_id in task_ids {
183            if let Some(record) = self.latest_dispatch_for_task(task_id)? {
184                records.push(record);
185            }
186        }
187
188        Ok(records)
189    }
190
191    pub fn list_dispatches(
192        &self,
193        limit: Option<usize>,
194    ) -> Result<Vec<TaskDispatchRecord>, TrackError> {
195        let limit = limit.map(|value| value as i64);
196        self.database.run(move |connection| {
197            Box::pin(async move {
198                let rows = if let Some(limit) = limit {
199                    sqlx::query(
200                        r#"
201                        SELECT *
202                        FROM task_dispatches
203                        ORDER BY created_at DESC
204                        LIMIT ?1
205                        "#,
206                    )
207                    .bind(limit)
208                    .fetch_all(&mut *connection)
209                    .await
210                } else {
211                    sqlx::query(
212                        r#"
213                        SELECT *
214                        FROM task_dispatches
215                        ORDER BY created_at DESC
216                        "#,
217                    )
218                    .fetch_all(&mut *connection)
219                    .await
220                }
221                .map_err(|error| {
222                    TrackError::new(
223                        ErrorCode::DispatchWriteFailed,
224                        format!("Could not list dispatch records: {error}"),
225                    )
226                })?;
227
228                rows.into_iter().map(task_dispatch_from_row).collect()
229            })
230        })
231    }
232
233    pub fn task_ids_with_history(&self) -> Result<Vec<String>, TrackError> {
234        self.database.run(move |connection| {
235            Box::pin(async move {
236                let rows = sqlx::query(
237                    r#"
238                    SELECT DISTINCT task_id
239                    FROM task_dispatches
240                    ORDER BY task_id ASC
241                    "#,
242                )
243                .fetch_all(&mut *connection)
244                .await
245                .map_err(|error| {
246                    TrackError::new(
247                        ErrorCode::DispatchWriteFailed,
248                        format!("Could not load task ids with dispatch history: {error}"),
249                    )
250                })?;
251
252                Ok(rows
253                    .into_iter()
254                    .map(|row| row.get::<String, _>("task_id"))
255                    .collect())
256            })
257        })
258    }
259
260    pub fn get_dispatch(
261        &self,
262        task_id: &str,
263        dispatch_id: &str,
264    ) -> Result<Option<TaskDispatchRecord>, TrackError> {
265        let task_id = validate_single_normal_path_component(
266            task_id,
267            "Task id",
268            ErrorCode::InvalidPathComponent,
269        )?;
270        let dispatch_id = validate_single_normal_path_component(
271            dispatch_id,
272            "Dispatch id",
273            ErrorCode::InvalidPathComponent,
274        )?;
275
276        self.database.run(move |connection| {
277            Box::pin(async move {
278                let row = sqlx::query(
279                    r#"
280                    SELECT *
281                    FROM task_dispatches
282                    WHERE task_id = ?1 AND dispatch_id = ?2
283                    "#,
284                )
285                .bind(&task_id)
286                .bind(&dispatch_id)
287                .fetch_optional(&mut *connection)
288                .await
289                .map_err(|error| {
290                    TrackError::new(
291                        ErrorCode::DispatchWriteFailed,
292                        format!(
293                            "Could not load the dispatch record {dispatch_id} for task {task_id}: {error}"
294                        ),
295                    )
296                })?;
297
298                row.map(task_dispatch_from_row).transpose()
299            })
300        })
301    }
302
303    pub fn delete_dispatch_history_for_task(&self, task_id: &str) -> Result<(), TrackError> {
304        let task_id = validate_single_normal_path_component(
305            task_id,
306            "Task id",
307            ErrorCode::InvalidPathComponent,
308        )?;
309
310        self.database.run(move |connection| {
311            Box::pin(async move {
312                sqlx::query("DELETE FROM task_dispatches WHERE task_id = ?1")
313                    .bind(&task_id)
314                    .execute(&mut *connection)
315                    .await
316                    .map_err(|error| {
317                        TrackError::new(
318                            ErrorCode::DispatchWriteFailed,
319                            format!(
320                                "Could not remove the dispatch history for task {task_id}: {error}"
321                            ),
322                        )
323                    })?;
324
325                Ok(())
326            })
327        })
328    }
329}
330
331fn task_dispatch_from_row(row: sqlx::sqlite::SqliteRow) -> Result<TaskDispatchRecord, TrackError> {
332    let dispatch_id = row.get::<String, _>("dispatch_id");
333    let created_at =
334        parse_iso_8601_millis(&row.get::<String, _>("created_at")).map_err(|error| {
335            TrackError::new(
336                ErrorCode::DispatchWriteFailed,
337                format!("Dispatch {dispatch_id} has an invalid created_at timestamp: {error}"),
338            )
339        })?;
340    let updated_at =
341        parse_iso_8601_millis(&row.get::<String, _>("updated_at")).map_err(|error| {
342            TrackError::new(
343                ErrorCode::DispatchWriteFailed,
344                format!("Dispatch {dispatch_id} has an invalid updated_at timestamp: {error}"),
345            )
346        })?;
347    let finished_at = row
348        .get::<Option<String>, _>("finished_at")
349        .map(|value| parse_iso_8601_millis(&value))
350        .transpose()
351        .map_err(|error| {
352            TrackError::new(
353                ErrorCode::DispatchWriteFailed,
354                format!("Dispatch {dispatch_id} has an invalid finished_at timestamp: {error}"),
355            )
356        })?;
357
358    Ok(TaskDispatchRecord {
359        dispatch_id,
360        task_id: row.get::<String, _>("task_id"),
361        preferred_tool: parse_preferred_tool(
362            row.try_get::<String, _>("preferred_tool")
363                .unwrap_or_else(|_| "codex".to_owned())
364                .as_str(),
365        )?,
366        project: row.get::<String, _>("project"),
367        status: parse_dispatch_status(row.get::<String, _>("status").as_str())?,
368        created_at,
369        updated_at,
370        finished_at,
371        remote_host: row.get::<String, _>("remote_host"),
372        branch_name: row.get::<Option<String>, _>("branch_name"),
373        worktree_path: row.get::<Option<String>, _>("worktree_path"),
374        pull_request_url: row.get::<Option<String>, _>("pull_request_url"),
375        follow_up_request: row.get::<Option<String>, _>("follow_up_request"),
376        summary: row.get::<Option<String>, _>("summary"),
377        notes: row.get::<Option<String>, _>("notes"),
378        error_message: row.get::<Option<String>, _>("error_message"),
379        review_request_head_oid: row.get::<Option<String>, _>("review_request_head_oid"),
380        review_request_user: row.get::<Option<String>, _>("review_request_user"),
381    })
382}
383
384fn parse_dispatch_status(value: &str) -> Result<DispatchStatus, TrackError> {
385    match value {
386        "preparing" => Ok(DispatchStatus::Preparing),
387        "running" => Ok(DispatchStatus::Running),
388        "succeeded" => Ok(DispatchStatus::Succeeded),
389        "canceled" => Ok(DispatchStatus::Canceled),
390        "failed" => Ok(DispatchStatus::Failed),
391        "blocked" => Ok(DispatchStatus::Blocked),
392        _ => Err(TrackError::new(
393            ErrorCode::DispatchWriteFailed,
394            format!("Dispatch status `{value}` is not valid."),
395        )),
396    }
397}
398
399fn parse_preferred_tool(value: &str) -> Result<RemoteAgentPreferredTool, TrackError> {
400    RemoteAgentPreferredTool::from_str(value).ok_or_else(|| {
401        TrackError::new(
402            ErrorCode::DispatchWriteFailed,
403            format!("Remote agent preferred tool `{value}` is not valid."),
404        )
405    })
406}