Skip to main content

task_graph_mcp/db/
agents.rs

1//! Worker CRUD operations.
2
3use super::{Database, now_ms};
4use crate::types::{CleanupSummary, DisconnectSummary, Worker};
5use anyhow::{Result, anyhow};
6use rusqlite::{Connection, params};
7
/// Longest worker ID accepted by [`Database::register_worker`].
pub const MAX_WORKER_ID_LEN: usize = 36;

/// Largest numeric suffix `generate_unique_petname` tries (`<petname>-2`, `<petname>-3`, …,
/// up to and including this value) before abandoning the base petname and falling back.
const MAX_PETNAME_ATTEMPTS: u32 = 100;
13
14/// Generate a unique petname-based worker ID.
15/// Tries base petname first, then appends numbers (e.g., "happy-turtle-2").
16fn generate_unique_petname(conn: &Connection) -> String {
17    let base = petname::petname(2, "-").unwrap_or_else(|| "worker".to_string());
18
19    // Check if base name is available
20    let exists: bool = conn
21        .query_row(
22            "SELECT 1 FROM workers WHERE id = ?1",
23            params![&base],
24            |_| Ok(true),
25        )
26        .unwrap_or(false);
27
28    if !exists {
29        return base;
30    }
31
32    // Try appending numbers: happy-turtle-2, happy-turtle-3, etc.
33    for i in 2..=MAX_PETNAME_ATTEMPTS {
34        let candidate = format!("{}-{}", base, i);
35        let exists: bool = conn
36            .query_row(
37                "SELECT 1 FROM workers WHERE id = ?1",
38                params![&candidate],
39                |_| Ok(true),
40            )
41            .unwrap_or(false);
42        if !exists {
43            return candidate;
44        }
45    }
46
47    // Fallback: generate a completely new petname with 3 words for uniqueness
48    petname::petname(3, "-").unwrap_or_else(|| format!("worker-{}", now_ms()))
49}
50
51/// Internal helper to get a worker using an existing connection (avoids deadlock).
52fn get_worker_internal(conn: &Connection, worker_id: &str) -> Result<Option<Worker>> {
53    let mut stmt = conn.prepare(
54        "SELECT id, tags, max_claims, registered_at, last_heartbeat
55         FROM workers WHERE id = ?1",
56    )?;
57
58    let result = stmt.query_row(params![worker_id], |row| {
59        let id: String = row.get(0)?;
60        let tags_json: String = row.get(1)?;
61        let max_claims: i32 = row.get(2)?;
62        let registered_at: i64 = row.get(3)?;
63        let last_heartbeat: i64 = row.get(4)?;
64
65        Ok((id, tags_json, max_claims, registered_at, last_heartbeat))
66    });
67
68    match result {
69        Ok((id, tags_json, max_claims, registered_at, last_heartbeat)) => {
70            let tags: Vec<String> = serde_json::from_str(&tags_json).unwrap_or_default();
71            Ok(Some(Worker {
72                id,
73                tags,
74                max_claims,
75                registered_at,
76                last_heartbeat,
77            }))
78        }
79        Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
80        Err(e) => Err(e.into()),
81    }
82}
83
84impl Database {
85    /// Register a new worker.
86    ///
87    /// If `worker_id` is provided, it must be at most 36 characters.
88    /// If not provided, a human-readable petname will be generated (e.g., "happy-turtle").
89    /// If `force` is true and the worker already exists, it will be re-registered
90    /// (useful for stuck worker recovery).
91    pub fn register_worker(
92        &self,
93        worker_id: Option<String>,
94        tags: Vec<String>,
95        force: bool,
96    ) -> Result<Worker> {
97        // Validate user-provided ID upfront (before acquiring connection)
98        let provided_id = match worker_id {
99            Some(id) => {
100                if id.len() > MAX_WORKER_ID_LEN {
101                    return Err(anyhow!(
102                        "Worker ID must be at most {} characters, got {}",
103                        MAX_WORKER_ID_LEN,
104                        id.len()
105                    ));
106                }
107                if id.is_empty() {
108                    return Err(anyhow!("Worker ID cannot be empty"));
109                }
110                Some(id)
111            }
112            None => None,
113        };
114        let now = now_ms();
115        let max_claims = i32::MAX; // Effectively unlimited until overclaiming becomes a problem
116        let tags_json = serde_json::to_string(&tags)?;
117
118        self.with_conn(|conn| {
119            // Generate ID inside connection to avoid race conditions
120            let id = match provided_id {
121                Some(id) => id,
122                None => generate_unique_petname(conn),
123            };
124
125            // Check if worker ID already exists
126            let exists: bool = conn
127                .query_row("SELECT 1 FROM workers WHERE id = ?1", params![&id], |_| Ok(true))
128                .unwrap_or(false);
129
130            // Get current max claim sequence + 1 to initialize poll position.
131            // This ensures first poll returns empty (no events since registration).
132            // The +1 is needed because we now query with `id >= last_seq`.
133            let current_max_sequence: i64 = conn
134                .query_row("SELECT COALESCE(MAX(id), 0) FROM claim_sequence", [], |row| row.get(0))
135                .unwrap_or(0);
136            let initial_sequence = current_max_sequence + 1;
137
138            if exists {
139                if force {
140                    // Force reconnection: update existing worker and reset poll position
141                    conn.execute(
142                        "UPDATE workers SET tags = ?1, max_claims = ?2, last_heartbeat = ?3, last_claim_sequence = ?4 WHERE id = ?5",
143                        params![tags_json, max_claims, now, initial_sequence, &id],
144                    )?;
145                } else {
146                    return Err(anyhow!("Worker ID '{}' already registered. Use force=true to reconnect.", id));
147                }
148            } else {
149                conn.execute(
150                    "INSERT INTO workers (id, tags, max_claims, registered_at, last_heartbeat, last_claim_sequence)
151                     VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
152                    params![&id, tags_json, max_claims, now, now, initial_sequence],
153                )?;
154            }
155
156            Ok(Worker {
157                id,
158                tags,
159                max_claims,
160                registered_at: now,
161                last_heartbeat: now,
162            })
163        })
164    }
165
166    /// Get a worker by ID.
167    pub fn get_worker(&self, worker_id: &str) -> Result<Option<Worker>> {
168        self.with_conn(|conn| {
169            let mut stmt = conn.prepare(
170                "SELECT id, tags, max_claims, registered_at, last_heartbeat
171                 FROM workers WHERE id = ?1",
172            )?;
173
174            let result = stmt.query_row(params![worker_id], |row| {
175                let id: String = row.get(0)?;
176                let tags_json: String = row.get(1)?;
177                let max_claims: i32 = row.get(2)?;
178                let registered_at: i64 = row.get(3)?;
179                let last_heartbeat: i64 = row.get(4)?;
180
181                Ok((id, tags_json, max_claims, registered_at, last_heartbeat))
182            });
183
184            match result {
185                Ok((id, tags_json, max_claims, registered_at, last_heartbeat)) => {
186                    let tags: Vec<String> = serde_json::from_str(&tags_json).unwrap_or_default();
187                    Ok(Some(Worker {
188                        id,
189                        tags,
190                        max_claims,
191                        registered_at,
192                        last_heartbeat,
193                    }))
194                }
195                Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
196                Err(e) => Err(e.into()),
197            }
198        })
199    }
200
201    /// Check if a worker exists. Returns error if not found.
202    pub fn require_worker(&self, worker_id: &str) -> Result<Worker> {
203        self.get_worker(worker_id)?
204            .ok_or_else(|| anyhow::anyhow!("Worker {} not found", worker_id))
205    }
206
207    /// Update a worker.
208    pub fn update_worker(
209        &self,
210        worker_id: &str,
211        tags: Option<Vec<String>>,
212        max_claims: Option<i32>,
213    ) -> Result<Worker> {
214        self.with_conn(|conn| {
215            let worker =
216                get_worker_internal(conn, worker_id)?.ok_or_else(|| anyhow!("Worker not found"))?;
217
218            let new_tags = tags.unwrap_or(worker.tags.clone());
219            let new_max_claims = max_claims.unwrap_or(worker.max_claims);
220            let tags_json = serde_json::to_string(&new_tags)?;
221
222            conn.execute(
223                "UPDATE workers SET tags = ?1, max_claims = ?2 WHERE id = ?3",
224                params![tags_json, new_max_claims, worker_id],
225            )?;
226
227            Ok(Worker {
228                id: worker_id.to_string(),
229                tags: new_tags,
230                max_claims: new_max_claims,
231                registered_at: worker.registered_at,
232                last_heartbeat: worker.last_heartbeat,
233            })
234        })
235    }
236
237    /// Update worker heartbeat.
238    pub fn heartbeat(&self, worker_id: &str) -> Result<i32> {
239        let now = now_ms();
240
241        self.with_conn(|conn| {
242            let updated = conn.execute(
243                "UPDATE workers SET last_heartbeat = ?1 WHERE id = ?2",
244                params![now, worker_id],
245            )?;
246
247            if updated == 0 {
248                return Err(anyhow!("Worker not found"));
249            }
250
251            // Return current claim count
252            let count: i32 = conn.query_row(
253                "SELECT COUNT(*) FROM tasks WHERE worker_id = ?1 AND status = 'in_progress'",
254                params![worker_id],
255                |row| row.get(0),
256            )?;
257
258            Ok(count)
259        })
260    }
261
262    /// Unregister a worker (releases all claims).
263    /// Returns a summary of released tasks and files.
264    pub fn unregister_worker(
265        &self,
266        worker_id: &str,
267        final_status: &str,
268    ) -> Result<DisconnectSummary> {
269        self.with_conn_mut(|conn| {
270            let tx = conn.transaction()?;
271
272            // Release all task claims, setting them to final_status
273            let tasks_released = tx.execute(
274                "UPDATE tasks SET worker_id = NULL, claimed_at = NULL, status = ?2
275                 WHERE worker_id = ?1",
276                params![worker_id, final_status],
277            )? as i32;
278
279            // Remove all file locks
280            let files_released = tx.execute(
281                "DELETE FROM file_locks WHERE worker_id = ?1",
282                params![worker_id],
283            )? as i32;
284
285            // Remove worker
286            tx.execute("DELETE FROM workers WHERE id = ?1", params![worker_id])?;
287
288            tx.commit()?;
289            Ok(DisconnectSummary {
290                tasks_released,
291                files_released,
292                final_status: final_status.to_string(),
293            })
294        })
295    }
296
297    /// List all workers.
298    pub fn list_workers(&self) -> Result<Vec<Worker>> {
299        self.with_conn(|conn| {
300            let mut stmt = conn.prepare(
301                "SELECT id, tags, max_claims, registered_at, last_heartbeat
302                 FROM workers ORDER BY registered_at DESC",
303            )?;
304
305            let workers = stmt
306                .query_map([], |row| {
307                    let id: String = row.get(0)?;
308                    let tags_json: String = row.get(1)?;
309                    let max_claims: i32 = row.get(2)?;
310                    let registered_at: i64 = row.get(3)?;
311                    let last_heartbeat: i64 = row.get(4)?;
312
313                    Ok((id, tags_json, max_claims, registered_at, last_heartbeat))
314                })?
315                .filter_map(|r| r.ok())
316                .map(
317                    |(id, tags_json, max_claims, registered_at, last_heartbeat)| {
318                        let tags: Vec<String> =
319                            serde_json::from_str(&tags_json).unwrap_or_default();
320                        Worker {
321                            id,
322                            tags,
323                            max_claims,
324                            registered_at,
325                            last_heartbeat,
326                        }
327                    },
328                )
329                .collect();
330
331            Ok(workers)
332        })
333    }
334
335    /// List all workers with extended info (claim count, current thought).
336    pub fn list_workers_info(&self) -> Result<Vec<crate::types::WorkerInfo>> {
337        self.with_conn(|conn| {
338            let mut stmt = conn.prepare(
339                "SELECT w.id, w.tags, w.max_claims, w.registered_at, w.last_heartbeat,
340                        (SELECT COUNT(*) FROM tasks WHERE worker_id = w.id AND status = 'in_progress') as claim_count,
341                        (SELECT current_thought FROM tasks WHERE worker_id = w.id AND status = 'in_progress' AND current_thought IS NOT NULL LIMIT 1) as current_thought
342                 FROM workers w ORDER BY w.registered_at DESC",
343            )?;
344
345            let workers = stmt.query_map([], |row| {
346                let id: String = row.get(0)?;
347                let tags_json: String = row.get(1)?;
348                let max_claims: i32 = row.get(2)?;
349                let registered_at: i64 = row.get(3)?;
350                let last_heartbeat: i64 = row.get(4)?;
351                let claim_count: i32 = row.get(5)?;
352                let current_thought: Option<String> = row.get(6)?;
353
354                Ok((id, tags_json, max_claims, registered_at, last_heartbeat, claim_count, current_thought))
355            })?
356            .filter_map(|r| r.ok())
357            .map(|(id, tags_json, max_claims, registered_at, last_heartbeat, claim_count, current_thought)| {
358                let tags: Vec<String> = serde_json::from_str(&tags_json).unwrap_or_default();
359                crate::types::WorkerInfo {
360                    id,
361                    tags,
362                    max_claims,
363                    claim_count,
364                    current_thought,
365                    registered_at,
366                    last_heartbeat,
367                }
368            })
369            .collect();
370
371            Ok(workers)
372        })
373    }
374
375    /// List workers with optional filters by tags, file claimed, or related task.
376    ///
377    /// - `tags`: Workers must have ALL of these tags
378    /// - `file`: Workers that have claimed this file
379    /// - `task_id`: Workers working on tasks related to this task
380    /// - `depth`: Task relationship depth (-3 to 3). Negative: ancestors, positive: descendants
381    pub fn list_workers_filtered(
382        &self,
383        tags: Option<&Vec<String>>,
384        file: Option<&str>,
385        task_id: Option<&str>,
386        depth: i32,
387    ) -> Result<Vec<crate::types::WorkerInfo>> {
388        self.with_conn(|conn| {
389            // Start with base query
390            let mut sql = String::from(
391                "SELECT DISTINCT w.id, w.tags, w.max_claims, w.registered_at, w.last_heartbeat,
392                        (SELECT COUNT(*) FROM tasks WHERE worker_id = w.id AND status = 'in_progress') as claim_count,
393                        (SELECT current_thought FROM tasks WHERE worker_id = w.id AND status = 'in_progress' AND current_thought IS NOT NULL LIMIT 1) as current_thought
394                 FROM workers w WHERE 1=1",
395            );
396            let mut params_vec: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
397
398            // Filter by file claim
399            if let Some(f) = file {
400                sql.push_str(" AND w.id IN (SELECT worker_id FROM file_locks WHERE file_path = ?)");
401                params_vec.push(Box::new(f.to_string()));
402            }
403
404            // Filter by related task (with depth traversal)
405            if let Some(tid) = task_id {
406                // Get all related task IDs at the given depth
407                let related_task_ids = Self::get_related_task_ids_internal(conn, tid, depth)?;
408                if !related_task_ids.is_empty() {
409                    let placeholders: Vec<String> = related_task_ids.iter().map(|_| "?".to_string()).collect();
410                    sql.push_str(&format!(
411                        " AND w.id IN (SELECT DISTINCT worker_id FROM tasks WHERE id IN ({}) AND worker_id IS NOT NULL)",
412                        placeholders.join(", ")
413                    ));
414                    for task in related_task_ids {
415                        params_vec.push(Box::new(task));
416                    }
417                } else {
418                    // No related tasks found, return empty result
419                    return Ok(Vec::new());
420                }
421            }
422
423            sql.push_str(" ORDER BY w.registered_at DESC");
424
425            let params_refs: Vec<&dyn rusqlite::ToSql> =
426                params_vec.iter().map(|b| b.as_ref()).collect();
427
428            let mut stmt = conn.prepare(&sql)?;
429            let workers: Vec<crate::types::WorkerInfo> = stmt
430                .query_map(params_refs.as_slice(), |row| {
431                    let id: String = row.get(0)?;
432                    let tags_json: String = row.get(1)?;
433                    let max_claims: i32 = row.get(2)?;
434                    let registered_at: i64 = row.get(3)?;
435                    let last_heartbeat: i64 = row.get(4)?;
436                    let claim_count: i32 = row.get(5)?;
437                    let current_thought: Option<String> = row.get(6)?;
438
439                    Ok((id, tags_json, max_claims, registered_at, last_heartbeat, claim_count, current_thought))
440                })?
441                .filter_map(|r| r.ok())
442                .map(|(id, tags_json, max_claims, registered_at, last_heartbeat, claim_count, current_thought)| {
443                    let tags: Vec<String> = serde_json::from_str(&tags_json).unwrap_or_default();
444                    crate::types::WorkerInfo {
445                        id,
446                        tags,
447                        max_claims,
448                        claim_count,
449                        current_thought,
450                        registered_at,
451                        last_heartbeat,
452                    }
453                })
454                .collect();
455
456            // Post-filter by tags (need to check ALL tags match)
457            let workers = if let Some(required_tags) = tags {
458                workers
459                    .into_iter()
460                    .filter(|w| required_tags.iter().all(|t| w.tags.contains(t)))
461                    .collect()
462            } else {
463                workers
464            };
465
466            Ok(workers)
467        })
468    }
469
470    /// Internal helper to get related task IDs at a given depth.
471    /// Negative depth: ancestors (parents/blockers), positive depth: descendants (children/blocked).
472    fn get_related_task_ids_internal(
473        conn: &Connection,
474        task_id: &str,
475        depth: i32,
476    ) -> Result<Vec<String>> {
477        use std::collections::HashSet;
478
479        let mut result = HashSet::new();
480        result.insert(task_id.to_string());
481
482        if depth == 0 {
483            return Ok(result.into_iter().collect());
484        }
485
486        let abs_depth = depth.abs();
487        let mut current_level: HashSet<String> = [task_id.to_string()].into_iter().collect();
488
489        for _ in 0..abs_depth {
490            if current_level.is_empty() {
491                break;
492            }
493
494            let mut next_level = HashSet::new();
495
496            for tid in &current_level {
497                let related: Vec<String> = if depth > 0 {
498                    // Descendants: tasks where this task is the from_task_id (children, blocked tasks)
499                    let mut stmt = conn
500                        .prepare("SELECT to_task_id FROM dependencies WHERE from_task_id = ?1")?;
501                    stmt.query_map(params![tid], |row| row.get(0))?
502                        .filter_map(|r| r.ok())
503                        .collect()
504                } else {
505                    // Ancestors: tasks where this task is the to_task_id (parents, blockers)
506                    let mut stmt = conn
507                        .prepare("SELECT from_task_id FROM dependencies WHERE to_task_id = ?1")?;
508                    stmt.query_map(params![tid], |row| row.get(0))?
509                        .filter_map(|r| r.ok())
510                        .collect()
511                };
512
513                for related_id in related {
514                    if !result.contains(&related_id) {
515                        next_level.insert(related_id.clone());
516                        result.insert(related_id);
517                    }
518                }
519            }
520
521            current_level = next_level;
522        }
523
524        Ok(result.into_iter().collect())
525    }
526
527    /// Get workers with stale heartbeats.
528    pub fn get_stale_workers(&self, timeout_seconds: i64) -> Result<Vec<Worker>> {
529        let cutoff = now_ms() - (timeout_seconds * 1000);
530
531        self.with_conn(|conn| {
532            let mut stmt = conn.prepare(
533                "SELECT id, tags, max_claims, registered_at, last_heartbeat
534                 FROM workers WHERE last_heartbeat < ?1",
535            )?;
536
537            let workers = stmt
538                .query_map(params![cutoff], |row| {
539                    let id: String = row.get(0)?;
540                    let tags_json: String = row.get(1)?;
541                    let max_claims: i32 = row.get(2)?;
542                    let registered_at: i64 = row.get(3)?;
543                    let last_heartbeat: i64 = row.get(4)?;
544
545                    Ok((id, tags_json, max_claims, registered_at, last_heartbeat))
546                })?
547                .filter_map(|r| r.ok())
548                .map(
549                    |(id, tags_json, max_claims, registered_at, last_heartbeat)| {
550                        let tags: Vec<String> =
551                            serde_json::from_str(&tags_json).unwrap_or_default();
552                        Worker {
553                            id,
554                            tags,
555                            max_claims,
556                            registered_at,
557                            last_heartbeat,
558                        }
559                    },
560                )
561                .collect();
562
563            Ok(workers)
564        })
565    }
566
567    /// Cleanup stale workers by evicting them and releasing their claims.
568    /// Returns a summary of the cleanup operation.
569    pub fn cleanup_stale_workers(
570        &self,
571        timeout_seconds: i64,
572        final_status: &str,
573    ) -> Result<CleanupSummary> {
574        let stale_workers = self.get_stale_workers(timeout_seconds)?;
575
576        let mut total_tasks_released = 0;
577        let mut total_files_released = 0;
578        let mut evicted_worker_ids = Vec::new();
579
580        for worker in &stale_workers {
581            // Release file locks first
582            let _ = self.release_worker_locks(&worker.id);
583
584            // Unregister the worker
585            if let Ok(summary) = self.unregister_worker(&worker.id, final_status) {
586                total_tasks_released += summary.tasks_released;
587                total_files_released += summary.files_released;
588                evicted_worker_ids.push(worker.id.clone());
589            }
590        }
591
592        Ok(CleanupSummary {
593            workers_evicted: evicted_worker_ids.len() as i32,
594            tasks_released: total_tasks_released,
595            files_released: total_files_released,
596            final_status: final_status.to_string(),
597            evicted_worker_ids,
598        })
599    }
600
601    /// Get claim count for a worker.
602    pub fn get_claim_count(&self, worker_id: &str) -> Result<i32> {
603        self.with_conn(|conn| {
604            let count: i32 = conn.query_row(
605                "SELECT COUNT(*) FROM tasks WHERE worker_id = ?1 AND status = 'in_progress'",
606                params![worker_id],
607                |row| row.get(0),
608            )?;
609            Ok(count)
610        })
611    }
612}