Skip to main content

task_graph_mcp/db/
agents.rs

1//! Worker CRUD operations.
2
3use super::{now_ms, Database};
4use crate::types::{CleanupSummary, DisconnectSummary, Worker};
5use anyhow::{anyhow, Result};
6use rusqlite::{params, Connection};
7
/// Upper bound on the length of a caller-supplied worker ID.
///
/// Enforced with `str::len`, so the limit is a byte count; for ASCII IDs this
/// equals the number of characters.
pub const MAX_WORKER_ID_LEN: usize = 36;

/// Highest numeric suffix tried when de-duplicating a generated petname
/// (`base-2`, `base-3`, … `base-MAX_PETNAME_ATTEMPTS`) before switching to a
/// fallback naming scheme.
const MAX_PETNAME_ATTEMPTS: u32 = 100;
13
14/// Generate a unique petname-based worker ID.
15/// Tries base petname first, then appends numbers (e.g., "happy-turtle-2").
16fn generate_unique_petname(conn: &Connection) -> String {
17    let base = petname::petname(2, "-").unwrap_or_else(|| "worker".to_string());
18    
19    // Check if base name is available
20    let exists: bool = conn
21        .query_row("SELECT 1 FROM workers WHERE id = ?1", params![&base], |_| Ok(true))
22        .unwrap_or(false);
23    
24    if !exists {
25        return base;
26    }
27    
28    // Try appending numbers: happy-turtle-2, happy-turtle-3, etc.
29    for i in 2..=MAX_PETNAME_ATTEMPTS {
30        let candidate = format!("{}-{}", base, i);
31        let exists: bool = conn
32            .query_row("SELECT 1 FROM workers WHERE id = ?1", params![&candidate], |_| Ok(true))
33            .unwrap_or(false);
34        if !exists {
35            return candidate;
36        }
37    }
38    
39    // Fallback: generate a completely new petname with 3 words for uniqueness
40    petname::petname(3, "-").unwrap_or_else(|| format!("worker-{}", now_ms()))
41}
42
43/// Internal helper to get a worker using an existing connection (avoids deadlock).
44fn get_worker_internal(conn: &Connection, worker_id: &str) -> Result<Option<Worker>> {
45    let mut stmt = conn.prepare(
46        "SELECT id, tags, max_claims, registered_at, last_heartbeat
47         FROM workers WHERE id = ?1",
48    )?;
49
50    let result = stmt.query_row(params![worker_id], |row| {
51        let id: String = row.get(0)?;
52        let tags_json: String = row.get(1)?;
53        let max_claims: i32 = row.get(2)?;
54        let registered_at: i64 = row.get(3)?;
55        let last_heartbeat: i64 = row.get(4)?;
56
57        Ok((id, tags_json, max_claims, registered_at, last_heartbeat))
58    });
59
60    match result {
61        Ok((id, tags_json, max_claims, registered_at, last_heartbeat)) => {
62            let tags: Vec<String> = serde_json::from_str(&tags_json).unwrap_or_default();
63            Ok(Some(Worker {
64                id,
65                tags,
66                max_claims,
67                registered_at,
68                last_heartbeat,
69            }))
70        }
71        Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
72        Err(e) => Err(e.into()),
73    }
74}
75
76impl Database {
77    /// Register a new worker.
78    ///
79    /// If `worker_id` is provided, it must be at most 36 characters.
80    /// If not provided, a human-readable petname will be generated (e.g., "happy-turtle").
81    /// If `force` is true and the worker already exists, it will be re-registered
82    /// (useful for stuck worker recovery).
83    pub fn register_worker(
84        &self,
85        worker_id: Option<String>,
86        tags: Vec<String>,
87        force: bool,
88    ) -> Result<Worker> {
89        // Validate user-provided ID upfront (before acquiring connection)
90        let provided_id = match worker_id {
91            Some(id) => {
92                if id.len() > MAX_WORKER_ID_LEN {
93                    return Err(anyhow!(
94                        "Worker ID must be at most {} characters, got {}",
95                        MAX_WORKER_ID_LEN,
96                        id.len()
97                    ));
98                }
99                if id.is_empty() {
100                    return Err(anyhow!("Worker ID cannot be empty"));
101                }
102                Some(id)
103            }
104            None => None,
105        };
106        let now = now_ms();
107        let max_claims = i32::MAX; // Effectively unlimited until overclaiming becomes a problem
108        let tags_json = serde_json::to_string(&tags)?;
109
110        self.with_conn(|conn| {
111            // Generate ID inside connection to avoid race conditions
112            let id = match provided_id {
113                Some(id) => id,
114                None => generate_unique_petname(conn),
115            };
116
117            // Check if worker ID already exists
118            let exists: bool = conn
119                .query_row("SELECT 1 FROM workers WHERE id = ?1", params![&id], |_| Ok(true))
120                .unwrap_or(false);
121
122            // Get current max claim sequence + 1 to initialize poll position.
123            // This ensures first poll returns empty (no events since registration).
124            // The +1 is needed because we now query with `id >= last_seq`.
125            let current_max_sequence: i64 = conn
126                .query_row("SELECT COALESCE(MAX(id), 0) FROM claim_sequence", [], |row| row.get(0))
127                .unwrap_or(0);
128            let initial_sequence = current_max_sequence + 1;
129
130            if exists {
131                if force {
132                    // Force reconnection: update existing worker and reset poll position
133                    conn.execute(
134                        "UPDATE workers SET tags = ?1, max_claims = ?2, last_heartbeat = ?3, last_claim_sequence = ?4 WHERE id = ?5",
135                        params![tags_json, max_claims, now, initial_sequence, &id],
136                    )?;
137                } else {
138                    return Err(anyhow!("Worker ID '{}' already registered. Use force=true to reconnect.", id));
139                }
140            } else {
141                conn.execute(
142                    "INSERT INTO workers (id, tags, max_claims, registered_at, last_heartbeat, last_claim_sequence)
143                     VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
144                    params![&id, tags_json, max_claims, now, now, initial_sequence],
145                )?;
146            }
147
148            Ok(Worker {
149                id,
150                tags,
151                max_claims,
152                registered_at: now,
153                last_heartbeat: now,
154            })
155        })
156    }
157
158    /// Get a worker by ID.
159    pub fn get_worker(&self, worker_id: &str) -> Result<Option<Worker>> {
160        self.with_conn(|conn| {
161            let mut stmt = conn.prepare(
162                "SELECT id, tags, max_claims, registered_at, last_heartbeat
163                 FROM workers WHERE id = ?1",
164            )?;
165
166            let result = stmt.query_row(params![worker_id], |row| {
167                let id: String = row.get(0)?;
168                let tags_json: String = row.get(1)?;
169                let max_claims: i32 = row.get(2)?;
170                let registered_at: i64 = row.get(3)?;
171                let last_heartbeat: i64 = row.get(4)?;
172
173                Ok((id, tags_json, max_claims, registered_at, last_heartbeat))
174            });
175
176            match result {
177                Ok((id, tags_json, max_claims, registered_at, last_heartbeat)) => {
178                    let tags: Vec<String> = serde_json::from_str(&tags_json).unwrap_or_default();
179                    Ok(Some(Worker {
180                        id,
181                        tags,
182                        max_claims,
183                        registered_at,
184                        last_heartbeat,
185                    }))
186                }
187                Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
188                Err(e) => Err(e.into()),
189            }
190        })
191    }
192
193    /// Check if a worker exists. Returns error if not found.
194    pub fn require_worker(&self, worker_id: &str) -> Result<Worker> {
195        self.get_worker(worker_id)?
196            .ok_or_else(|| anyhow::anyhow!("Worker {} not found", worker_id))
197    }
198
199    /// Update a worker.
200    pub fn update_worker(
201        &self,
202        worker_id: &str,
203        tags: Option<Vec<String>>,
204        max_claims: Option<i32>,
205    ) -> Result<Worker> {
206        self.with_conn(|conn| {
207            let worker = get_worker_internal(conn, worker_id)?
208                .ok_or_else(|| anyhow!("Worker not found"))?;
209
210            let new_tags = tags.unwrap_or(worker.tags.clone());
211            let new_max_claims = max_claims.unwrap_or(worker.max_claims);
212            let tags_json = serde_json::to_string(&new_tags)?;
213
214            conn.execute(
215                "UPDATE workers SET tags = ?1, max_claims = ?2 WHERE id = ?3",
216                params![tags_json, new_max_claims, worker_id],
217            )?;
218
219            Ok(Worker {
220                id: worker_id.to_string(),
221                tags: new_tags,
222                max_claims: new_max_claims,
223                registered_at: worker.registered_at,
224                last_heartbeat: worker.last_heartbeat,
225            })
226        })
227    }
228
229    /// Update worker heartbeat.
230    pub fn heartbeat(&self, worker_id: &str) -> Result<i32> {
231        let now = now_ms();
232
233        self.with_conn(|conn| {
234            let updated = conn.execute(
235                "UPDATE workers SET last_heartbeat = ?1 WHERE id = ?2",
236                params![now, worker_id],
237            )?;
238
239            if updated == 0 {
240                return Err(anyhow!("Worker not found"));
241            }
242
243            // Return current claim count
244            let count: i32 = conn.query_row(
245                "SELECT COUNT(*) FROM tasks WHERE worker_id = ?1 AND status = 'in_progress'",
246                params![worker_id],
247                |row| row.get(0),
248            )?;
249
250            Ok(count)
251        })
252    }
253
254    /// Unregister a worker (releases all claims).
255    /// Returns a summary of released tasks and files.
256    pub fn unregister_worker(&self, worker_id: &str, final_status: &str) -> Result<DisconnectSummary> {
257        self.with_conn_mut(|conn| {
258            let tx = conn.transaction()?;
259
260            // Release all task claims, setting them to final_status
261            let tasks_released = tx.execute(
262                "UPDATE tasks SET worker_id = NULL, claimed_at = NULL, status = ?2
263                 WHERE worker_id = ?1",
264                params![worker_id, final_status],
265            )? as i32;
266
267            // Remove all file locks
268            let files_released = tx.execute(
269                "DELETE FROM file_locks WHERE worker_id = ?1",
270                params![worker_id],
271            )? as i32;
272
273            // Remove worker
274            tx.execute(
275                "DELETE FROM workers WHERE id = ?1",
276                params![worker_id],
277            )?;
278
279            tx.commit()?;
280            Ok(DisconnectSummary {
281                tasks_released,
282                files_released,
283                final_status: final_status.to_string(),
284            })
285        })
286    }
287
288    /// List all workers.
289    pub fn list_workers(&self) -> Result<Vec<Worker>> {
290        self.with_conn(|conn| {
291            let mut stmt = conn.prepare(
292                "SELECT id, tags, max_claims, registered_at, last_heartbeat
293                 FROM workers ORDER BY registered_at DESC",
294            )?;
295
296            let workers = stmt.query_map([], |row| {
297                let id: String = row.get(0)?;
298                let tags_json: String = row.get(1)?;
299                let max_claims: i32 = row.get(2)?;
300                let registered_at: i64 = row.get(3)?;
301                let last_heartbeat: i64 = row.get(4)?;
302
303                Ok((id, tags_json, max_claims, registered_at, last_heartbeat))
304            })?
305            .filter_map(|r| r.ok())
306            .map(|(id, tags_json, max_claims, registered_at, last_heartbeat)| {
307                let tags: Vec<String> = serde_json::from_str(&tags_json).unwrap_or_default();
308                Worker {
309                    id,
310                    tags,
311                    max_claims,
312                    registered_at,
313                    last_heartbeat,
314                }
315            })
316            .collect();
317
318            Ok(workers)
319        })
320    }
321
322    /// List all workers with extended info (claim count, current thought).
323    pub fn list_workers_info(&self) -> Result<Vec<crate::types::WorkerInfo>> {
324        self.with_conn(|conn| {
325            let mut stmt = conn.prepare(
326                "SELECT w.id, w.tags, w.max_claims, w.registered_at, w.last_heartbeat,
327                        (SELECT COUNT(*) FROM tasks WHERE worker_id = w.id AND status = 'in_progress') as claim_count,
328                        (SELECT current_thought FROM tasks WHERE worker_id = w.id AND status = 'in_progress' AND current_thought IS NOT NULL LIMIT 1) as current_thought
329                 FROM workers w ORDER BY w.registered_at DESC",
330            )?;
331
332            let workers = stmt.query_map([], |row| {
333                let id: String = row.get(0)?;
334                let tags_json: String = row.get(1)?;
335                let max_claims: i32 = row.get(2)?;
336                let registered_at: i64 = row.get(3)?;
337                let last_heartbeat: i64 = row.get(4)?;
338                let claim_count: i32 = row.get(5)?;
339                let current_thought: Option<String> = row.get(6)?;
340
341                Ok((id, tags_json, max_claims, registered_at, last_heartbeat, claim_count, current_thought))
342            })?
343            .filter_map(|r| r.ok())
344            .map(|(id, tags_json, max_claims, registered_at, last_heartbeat, claim_count, current_thought)| {
345                let tags: Vec<String> = serde_json::from_str(&tags_json).unwrap_or_default();
346                crate::types::WorkerInfo {
347                    id,
348                    tags,
349                    max_claims,
350                    claim_count,
351                    current_thought,
352                    registered_at,
353                    last_heartbeat,
354                }
355            })
356            .collect();
357
358            Ok(workers)
359        })
360    }
361
362    /// List workers with optional filters by tags, file claimed, or related task.
363    ///
364    /// - `tags`: Workers must have ALL of these tags
365    /// - `file`: Workers that have claimed this file
366    /// - `task_id`: Workers working on tasks related to this task
367    /// - `depth`: Task relationship depth (-3 to 3). Negative: ancestors, positive: descendants
368    pub fn list_workers_filtered(
369        &self,
370        tags: Option<&Vec<String>>,
371        file: Option<&str>,
372        task_id: Option<&str>,
373        depth: i32,
374    ) -> Result<Vec<crate::types::WorkerInfo>> {
375        self.with_conn(|conn| {
376            // Start with base query
377            let mut sql = String::from(
378                "SELECT DISTINCT w.id, w.tags, w.max_claims, w.registered_at, w.last_heartbeat,
379                        (SELECT COUNT(*) FROM tasks WHERE worker_id = w.id AND status = 'in_progress') as claim_count,
380                        (SELECT current_thought FROM tasks WHERE worker_id = w.id AND status = 'in_progress' AND current_thought IS NOT NULL LIMIT 1) as current_thought
381                 FROM workers w WHERE 1=1",
382            );
383            let mut params_vec: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
384
385            // Filter by file claim
386            if let Some(f) = file {
387                sql.push_str(" AND w.id IN (SELECT worker_id FROM file_locks WHERE file_path = ?)");
388                params_vec.push(Box::new(f.to_string()));
389            }
390
391            // Filter by related task (with depth traversal)
392            if let Some(tid) = task_id {
393                // Get all related task IDs at the given depth
394                let related_task_ids = Self::get_related_task_ids_internal(conn, tid, depth)?;
395                if !related_task_ids.is_empty() {
396                    let placeholders: Vec<String> = related_task_ids.iter().map(|_| "?".to_string()).collect();
397                    sql.push_str(&format!(
398                        " AND w.id IN (SELECT DISTINCT worker_id FROM tasks WHERE id IN ({}) AND worker_id IS NOT NULL)",
399                        placeholders.join(", ")
400                    ));
401                    for task in related_task_ids {
402                        params_vec.push(Box::new(task));
403                    }
404                } else {
405                    // No related tasks found, return empty result
406                    return Ok(Vec::new());
407                }
408            }
409
410            sql.push_str(" ORDER BY w.registered_at DESC");
411
412            let params_refs: Vec<&dyn rusqlite::ToSql> =
413                params_vec.iter().map(|b| b.as_ref()).collect();
414
415            let mut stmt = conn.prepare(&sql)?;
416            let workers: Vec<crate::types::WorkerInfo> = stmt
417                .query_map(params_refs.as_slice(), |row| {
418                    let id: String = row.get(0)?;
419                    let tags_json: String = row.get(1)?;
420                    let max_claims: i32 = row.get(2)?;
421                    let registered_at: i64 = row.get(3)?;
422                    let last_heartbeat: i64 = row.get(4)?;
423                    let claim_count: i32 = row.get(5)?;
424                    let current_thought: Option<String> = row.get(6)?;
425
426                    Ok((id, tags_json, max_claims, registered_at, last_heartbeat, claim_count, current_thought))
427                })?
428                .filter_map(|r| r.ok())
429                .map(|(id, tags_json, max_claims, registered_at, last_heartbeat, claim_count, current_thought)| {
430                    let tags: Vec<String> = serde_json::from_str(&tags_json).unwrap_or_default();
431                    crate::types::WorkerInfo {
432                        id,
433                        tags,
434                        max_claims,
435                        claim_count,
436                        current_thought,
437                        registered_at,
438                        last_heartbeat,
439                    }
440                })
441                .collect();
442
443            // Post-filter by tags (need to check ALL tags match)
444            let workers = if let Some(required_tags) = tags {
445                workers
446                    .into_iter()
447                    .filter(|w| required_tags.iter().all(|t| w.tags.contains(t)))
448                    .collect()
449            } else {
450                workers
451            };
452
453            Ok(workers)
454        })
455    }
456
457    /// Internal helper to get related task IDs at a given depth.
458    /// Negative depth: ancestors (parents/blockers), positive depth: descendants (children/blocked).
459    fn get_related_task_ids_internal(conn: &Connection, task_id: &str, depth: i32) -> Result<Vec<String>> {
460        use std::collections::HashSet;
461
462        let mut result = HashSet::new();
463        result.insert(task_id.to_string());
464
465        if depth == 0 {
466            return Ok(result.into_iter().collect());
467        }
468
469        let abs_depth = depth.abs();
470        let mut current_level: HashSet<String> = [task_id.to_string()].into_iter().collect();
471
472        for _ in 0..abs_depth {
473            if current_level.is_empty() {
474                break;
475            }
476
477            let mut next_level = HashSet::new();
478
479            for tid in &current_level {
480                let related: Vec<String> = if depth > 0 {
481                    // Descendants: tasks where this task is the from_task_id (children, blocked tasks)
482                    let mut stmt = conn.prepare(
483                        "SELECT to_task_id FROM dependencies WHERE from_task_id = ?1"
484                    )?;
485                    stmt.query_map(params![tid], |row| row.get(0))?
486                        .filter_map(|r| r.ok())
487                        .collect()
488                } else {
489                    // Ancestors: tasks where this task is the to_task_id (parents, blockers)
490                    let mut stmt = conn.prepare(
491                        "SELECT from_task_id FROM dependencies WHERE to_task_id = ?1"
492                    )?;
493                    stmt.query_map(params![tid], |row| row.get(0))?
494                        .filter_map(|r| r.ok())
495                        .collect()
496                };
497
498                for related_id in related {
499                    if !result.contains(&related_id) {
500                        next_level.insert(related_id.clone());
501                        result.insert(related_id);
502                    }
503                }
504            }
505
506            current_level = next_level;
507        }
508
509        Ok(result.into_iter().collect())
510    }
511
512    /// Get workers with stale heartbeats.
513    pub fn get_stale_workers(&self, timeout_seconds: i64) -> Result<Vec<Worker>> {
514        let cutoff = now_ms() - (timeout_seconds * 1000);
515
516        self.with_conn(|conn| {
517            let mut stmt = conn.prepare(
518                "SELECT id, tags, max_claims, registered_at, last_heartbeat
519                 FROM workers WHERE last_heartbeat < ?1",
520            )?;
521
522            let workers = stmt.query_map(params![cutoff], |row| {
523                let id: String = row.get(0)?;
524                let tags_json: String = row.get(1)?;
525                let max_claims: i32 = row.get(2)?;
526                let registered_at: i64 = row.get(3)?;
527                let last_heartbeat: i64 = row.get(4)?;
528
529                Ok((id, tags_json, max_claims, registered_at, last_heartbeat))
530            })?
531            .filter_map(|r| r.ok())
532            .map(|(id, tags_json, max_claims, registered_at, last_heartbeat)| {
533                let tags: Vec<String> = serde_json::from_str(&tags_json).unwrap_or_default();
534                Worker {
535                    id,
536                    tags,
537                    max_claims,
538                    registered_at,
539                    last_heartbeat,
540                }
541            })
542            .collect();
543
544            Ok(workers)
545        })
546    }
547
548    /// Cleanup stale workers by evicting them and releasing their claims.
549    /// Returns a summary of the cleanup operation.
550    pub fn cleanup_stale_workers(&self, timeout_seconds: i64, final_status: &str) -> Result<CleanupSummary> {
551        let stale_workers = self.get_stale_workers(timeout_seconds)?;
552        
553        let mut total_tasks_released = 0;
554        let mut total_files_released = 0;
555        let mut evicted_worker_ids = Vec::new();
556        
557        for worker in &stale_workers {
558            // Release file locks first
559            let _ = self.release_worker_locks(&worker.id);
560            
561            // Unregister the worker
562            if let Ok(summary) = self.unregister_worker(&worker.id, final_status) {
563                total_tasks_released += summary.tasks_released;
564                total_files_released += summary.files_released;
565                evicted_worker_ids.push(worker.id.clone());
566            }
567        }
568        
569        Ok(CleanupSummary {
570            workers_evicted: evicted_worker_ids.len() as i32,
571            tasks_released: total_tasks_released,
572            files_released: total_files_released,
573            final_status: final_status.to_string(),
574            evicted_worker_ids,
575        })
576    }
577
578    /// Get claim count for a worker.
579    pub fn get_claim_count(&self, worker_id: &str) -> Result<i32> {
580        self.with_conn(|conn| {
581            let count: i32 = conn.query_row(
582                "SELECT COUNT(*) FROM tasks WHERE worker_id = ?1 AND status = 'in_progress'",
583                params![worker_id],
584                |row| row.get(0),
585            )?;
586            Ok(count)
587        })
588    }
589}