1use super::{Database, now_ms};
4use crate::config::IdsConfig;
5use crate::types::{CleanupSummary, DisconnectSummary, Worker};
6use anyhow::{Result, anyhow};
7use petname::{Generator, Petnames};
8use rusqlite::{Connection, params};
9
/// Maximum permitted length of a caller-supplied worker ID.
///
/// Enforced by `Database::register_worker` before it touches the database.
pub const MAX_WORKER_ID_LEN: usize = 64;
12
13fn generate_agent_id(ids_config: &IdsConfig) -> String {
16 let words = ids_config.agent_id_words;
17 let case = ids_config.id_case;
18
19 let base = Petnames::medium()
21 .generate_one(words, "-")
22 .unwrap_or_else(|| format!("worker-{}", now_ms()));
23
24 case.convert(&base)
26}
27
28fn get_worker_internal(conn: &Connection, worker_id: &str) -> Result<Option<Worker>> {
30 let mut stmt = conn.prepare(
31 "SELECT id, tags, max_claims, registered_at, last_heartbeat, last_status, last_phase, workflow
32 FROM workers WHERE id = ?1",
33 )?;
34
35 let result = stmt.query_row(params![worker_id], |row| {
36 let id: String = row.get(0)?;
37 let tags_json: String = row.get(1)?;
38 let max_claims: i32 = row.get(2)?;
39 let registered_at: i64 = row.get(3)?;
40 let last_heartbeat: i64 = row.get(4)?;
41 let last_status: Option<String> = row.get(5)?;
42 let last_phase: Option<String> = row.get(6)?;
43 let workflow: Option<String> = row.get(7)?;
44
45 Ok((
46 id,
47 tags_json,
48 max_claims,
49 registered_at,
50 last_heartbeat,
51 last_status,
52 last_phase,
53 workflow,
54 ))
55 });
56
57 match result {
58 Ok((
59 id,
60 tags_json,
61 max_claims,
62 registered_at,
63 last_heartbeat,
64 last_status,
65 last_phase,
66 workflow,
67 )) => {
68 let tags: Vec<String> = serde_json::from_str(&tags_json).unwrap_or_default();
69 Ok(Some(Worker {
70 id,
71 tags,
72 max_claims,
73 registered_at,
74 last_heartbeat,
75 last_status,
76 last_phase,
77 workflow,
78 }))
79 }
80 Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
81 Err(e) => Err(e.into()),
82 }
83}
84
85impl Database {
86 pub fn register_worker(
94 &self,
95 worker_id: Option<String>,
96 tags: Vec<String>,
97 force: bool,
98 ids_config: &IdsConfig,
99 workflow: Option<String>,
100 ) -> Result<Worker> {
101 let provided_id = match worker_id {
103 Some(id) => {
104 if id.len() > MAX_WORKER_ID_LEN {
105 return Err(anyhow!(
106 "Worker ID must be at most {} characters, got {}",
107 MAX_WORKER_ID_LEN,
108 id.len()
109 ));
110 }
111 if id.is_empty() {
112 return Err(anyhow!("Worker ID cannot be empty"));
113 }
114 Some(id)
115 }
116 None => None,
117 };
118 let now = now_ms();
119 let max_claims = i32::MAX; let tags_json = serde_json::to_string(&tags)?;
121
122 self.with_conn(|conn| {
123 let id = match provided_id {
125 Some(id) => id,
126 None => generate_agent_id(ids_config),
127 };
128
129 let exists: bool = conn
131 .query_row("SELECT 1 FROM workers WHERE id = ?1", params![&id], |_| Ok(true))
132 .unwrap_or(false);
133
134 let current_max_sequence: i64 = conn
138 .query_row("SELECT COALESCE(MAX(id), 0) FROM claim_sequence", [], |row| row.get(0))
139 .unwrap_or(0);
140 let initial_sequence = current_max_sequence + 1;
141
142 if exists {
143 if force {
144 conn.execute(
146 "UPDATE workers SET tags = ?1, max_claims = ?2, last_heartbeat = ?3, last_claim_sequence = ?4, workflow = ?5 WHERE id = ?6",
147 params![tags_json, max_claims, now, initial_sequence, &workflow, &id],
148 )?;
149 } else {
150 return Err(anyhow!("Worker ID '{}' already registered. Use force=true to reconnect.", id));
151 }
152 } else {
153 conn.execute(
154 "INSERT INTO workers (id, tags, max_claims, registered_at, last_heartbeat, last_claim_sequence, workflow)
155 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
156 params![&id, tags_json, max_claims, now, now, initial_sequence, &workflow],
157 )?;
158 }
159
160 Ok(Worker {
161 id,
162 tags,
163 max_claims,
164 registered_at: now,
165 last_heartbeat: now,
166 last_status: None,
167 last_phase: None,
168 workflow,
169 })
170 })
171 }
172
173 pub fn get_worker(&self, worker_id: &str) -> Result<Option<Worker>> {
175 self.with_conn(|conn| {
176 let mut stmt = conn.prepare(
177 "SELECT id, tags, max_claims, registered_at, last_heartbeat, last_status, last_phase, workflow
178 FROM workers WHERE id = ?1",
179 )?;
180
181 let result = stmt.query_row(params![worker_id], |row| {
182 let id: String = row.get(0)?;
183 let tags_json: String = row.get(1)?;
184 let max_claims: i32 = row.get(2)?;
185 let registered_at: i64 = row.get(3)?;
186 let last_heartbeat: i64 = row.get(4)?;
187 let last_status: Option<String> = row.get(5)?;
188 let last_phase: Option<String> = row.get(6)?;
189 let workflow: Option<String> = row.get(7)?;
190
191 Ok((
192 id,
193 tags_json,
194 max_claims,
195 registered_at,
196 last_heartbeat,
197 last_status,
198 last_phase,
199 workflow,
200 ))
201 });
202
203 match result {
204 Ok((
205 id,
206 tags_json,
207 max_claims,
208 registered_at,
209 last_heartbeat,
210 last_status,
211 last_phase,
212 workflow,
213 )) => {
214 let tags: Vec<String> = serde_json::from_str(&tags_json).unwrap_or_default();
215 Ok(Some(Worker {
216 id,
217 tags,
218 max_claims,
219 registered_at,
220 last_heartbeat,
221 last_status,
222 last_phase,
223 workflow,
224 }))
225 }
226 Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
227 Err(e) => Err(e.into()),
228 }
229 })
230 }
231
232 pub fn require_worker(&self, worker_id: &str) -> Result<Worker> {
234 self.get_worker(worker_id)?
235 .ok_or_else(|| anyhow::anyhow!("Worker {} not found", worker_id))
236 }
237
238 pub fn update_worker(
240 &self,
241 worker_id: &str,
242 tags: Option<Vec<String>>,
243 max_claims: Option<i32>,
244 ) -> Result<Worker> {
245 self.with_conn(|conn| {
246 let worker =
247 get_worker_internal(conn, worker_id)?.ok_or_else(|| anyhow!("Worker not found"))?;
248
249 let new_tags = tags.unwrap_or(worker.tags.clone());
250 let new_max_claims = max_claims.unwrap_or(worker.max_claims);
251 let tags_json = serde_json::to_string(&new_tags)?;
252
253 conn.execute(
254 "UPDATE workers SET tags = ?1, max_claims = ?2 WHERE id = ?3",
255 params![tags_json, new_max_claims, worker_id],
256 )?;
257
258 Ok(Worker {
259 id: worker_id.to_string(),
260 tags: new_tags,
261 max_claims: new_max_claims,
262 registered_at: worker.registered_at,
263 last_heartbeat: worker.last_heartbeat,
264 last_status: worker.last_status,
265 last_phase: worker.last_phase,
266 workflow: worker.workflow,
267 })
268 })
269 }
270
271 pub fn update_worker_state(
274 &self,
275 worker_id: &str,
276 new_status: Option<&str>,
277 new_phase: Option<&str>,
278 ) -> Result<(Option<String>, Option<String>)> {
279 self.with_conn(|conn| {
280 let (old_status, old_phase): (Option<String>, Option<String>) = conn
282 .query_row(
283 "SELECT last_status, last_phase FROM workers WHERE id = ?1",
284 params![worker_id],
285 |row| Ok((row.get(0)?, row.get(1)?)),
286 )
287 .map_err(|e| match e {
288 rusqlite::Error::QueryReturnedNoRows => anyhow!("Worker not found"),
289 e => e.into(),
290 })?;
291
292 conn.execute(
294 "UPDATE workers SET last_status = ?1, last_phase = ?2 WHERE id = ?3",
295 params![new_status, new_phase, worker_id],
296 )?;
297
298 Ok((old_status, old_phase))
299 })
300 }
301
302 pub fn heartbeat(&self, worker_id: &str) -> Result<i32> {
304 let now = now_ms();
305
306 self.with_conn(|conn| {
307 let updated = conn.execute(
308 "UPDATE workers SET last_heartbeat = ?1 WHERE id = ?2",
309 params![now, worker_id],
310 )?;
311
312 if updated == 0 {
313 return Err(anyhow!("Worker not found"));
314 }
315
316 let count: i32 = conn.query_row(
318 "SELECT COUNT(*) FROM tasks WHERE worker_id = ?1 AND status = 'working'",
319 params![worker_id],
320 |row| row.get(0),
321 )?;
322
323 Ok(count)
324 })
325 }
326
327 pub fn unregister_worker(
330 &self,
331 worker_id: &str,
332 final_status: &str,
333 ) -> Result<DisconnectSummary> {
334 self.with_conn_mut(|conn| {
335 let tx = conn.transaction()?;
336
337 let tasks_released = tx.execute(
339 "UPDATE tasks SET worker_id = NULL, claimed_at = NULL, status = ?2
340 WHERE worker_id = ?1",
341 params![worker_id, final_status],
342 )? as i32;
343
344 let files_released = tx.execute(
346 "DELETE FROM file_locks WHERE worker_id = ?1",
347 params![worker_id],
348 )? as i32;
349
350 tx.execute("DELETE FROM workers WHERE id = ?1", params![worker_id])?;
352
353 tx.commit()?;
354 Ok(DisconnectSummary {
355 tasks_released,
356 files_released,
357 final_status: final_status.to_string(),
358 })
359 })
360 }
361
362 pub fn list_workers(&self) -> Result<Vec<Worker>> {
364 self.with_conn(|conn| {
365 let mut stmt = conn.prepare(
366 "SELECT id, tags, max_claims, registered_at, last_heartbeat, last_status, last_phase, workflow
367 FROM workers ORDER BY registered_at DESC",
368 )?;
369
370 let workers = stmt
371 .query_map([], |row| {
372 let id: String = row.get(0)?;
373 let tags_json: String = row.get(1)?;
374 let max_claims: i32 = row.get(2)?;
375 let registered_at: i64 = row.get(3)?;
376 let last_heartbeat: i64 = row.get(4)?;
377 let last_status: Option<String> = row.get(5)?;
378 let last_phase: Option<String> = row.get(6)?;
379 let workflow: Option<String> = row.get(7)?;
380
381 Ok((
382 id,
383 tags_json,
384 max_claims,
385 registered_at,
386 last_heartbeat,
387 last_status,
388 last_phase,
389 workflow,
390 ))
391 })?
392 .filter_map(|r| r.ok())
393 .map(
394 |(
395 id,
396 tags_json,
397 max_claims,
398 registered_at,
399 last_heartbeat,
400 last_status,
401 last_phase,
402 workflow,
403 )| {
404 let tags: Vec<String> =
405 serde_json::from_str(&tags_json).unwrap_or_default();
406 Worker {
407 id,
408 tags,
409 max_claims,
410 registered_at,
411 last_heartbeat,
412 last_status,
413 last_phase,
414 workflow,
415 }
416 },
417 )
418 .collect();
419
420 Ok(workers)
421 })
422 }
423
424 pub fn list_workers_info(&self) -> Result<Vec<crate::types::WorkerInfo>> {
426 self.with_conn(|conn| {
427 let mut stmt = conn.prepare(
428 "SELECT w.id, w.tags, w.max_claims, w.registered_at, w.last_heartbeat,
429 (SELECT COUNT(*) FROM tasks WHERE worker_id = w.id AND status = 'working') as claim_count,
430 (SELECT current_thought FROM tasks WHERE worker_id = w.id AND status = 'working' AND current_thought IS NOT NULL LIMIT 1) as current_thought,
431 w.last_status, w.last_phase, w.workflow
432 FROM workers w ORDER BY w.registered_at DESC",
433 )?;
434
435 let workers = stmt.query_map([], |row| {
436 let id: String = row.get(0)?;
437 let tags_json: String = row.get(1)?;
438 let max_claims: i32 = row.get(2)?;
439 let registered_at: i64 = row.get(3)?;
440 let last_heartbeat: i64 = row.get(4)?;
441 let claim_count: i32 = row.get(5)?;
442 let current_thought: Option<String> = row.get(6)?;
443 let last_status: Option<String> = row.get(7)?;
444 let last_phase: Option<String> = row.get(8)?;
445 let workflow: Option<String> = row.get(9)?;
446
447 Ok((id, tags_json, max_claims, registered_at, last_heartbeat, claim_count, current_thought, last_status, last_phase, workflow))
448 })?
449 .filter_map(|r| r.ok())
450 .map(|(id, tags_json, max_claims, registered_at, last_heartbeat, claim_count, current_thought, last_status, last_phase, workflow)| {
451 let tags: Vec<String> = serde_json::from_str(&tags_json).unwrap_or_default();
452 crate::types::WorkerInfo {
453 id,
454 tags,
455 max_claims,
456 claim_count,
457 current_thought,
458 registered_at,
459 last_heartbeat,
460 last_status,
461 last_phase,
462 workflow,
463 }
464 })
465 .collect();
466
467 Ok(workers)
468 })
469 }
470
471 pub fn list_workers_filtered(
478 &self,
479 tags: Option<&Vec<String>>,
480 file: Option<&str>,
481 task_id: Option<&str>,
482 depth: i32,
483 ) -> Result<Vec<crate::types::WorkerInfo>> {
484 self.with_conn(|conn| {
485 let mut sql = String::from(
487 "SELECT DISTINCT w.id, w.tags, w.max_claims, w.registered_at, w.last_heartbeat,
488 (SELECT COUNT(*) FROM tasks WHERE worker_id = w.id AND status = 'working') as claim_count,
489 (SELECT current_thought FROM tasks WHERE worker_id = w.id AND status = 'working' AND current_thought IS NOT NULL LIMIT 1) as current_thought,
490 w.last_status, w.last_phase, w.workflow
491 FROM workers w WHERE 1=1",
492 );
493 let mut params_vec: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
494
495 if let Some(f) = file {
497 sql.push_str(" AND w.id IN (SELECT worker_id FROM file_locks WHERE file_path = ?)");
498 params_vec.push(Box::new(f.to_string()));
499 }
500
501 if let Some(tid) = task_id {
503 let related_task_ids = Self::get_related_task_ids_internal(conn, tid, depth)?;
505 if !related_task_ids.is_empty() {
506 let placeholders: Vec<String> = related_task_ids.iter().map(|_| "?".to_string()).collect();
507 sql.push_str(&format!(
508 " AND w.id IN (SELECT DISTINCT worker_id FROM tasks WHERE id IN ({}) AND worker_id IS NOT NULL)",
509 placeholders.join(", ")
510 ));
511 for task in related_task_ids {
512 params_vec.push(Box::new(task));
513 }
514 } else {
515 return Ok(Vec::new());
517 }
518 }
519
520 sql.push_str(" ORDER BY w.registered_at DESC");
521
522 let params_refs: Vec<&dyn rusqlite::ToSql> =
523 params_vec.iter().map(|b| b.as_ref()).collect();
524
525 let mut stmt = conn.prepare(&sql)?;
526 let workers: Vec<crate::types::WorkerInfo> = stmt
527 .query_map(params_refs.as_slice(), |row| {
528 let id: String = row.get(0)?;
529 let tags_json: String = row.get(1)?;
530 let max_claims: i32 = row.get(2)?;
531 let registered_at: i64 = row.get(3)?;
532 let last_heartbeat: i64 = row.get(4)?;
533 let claim_count: i32 = row.get(5)?;
534 let current_thought: Option<String> = row.get(6)?;
535 let last_status: Option<String> = row.get(7)?;
536 let last_phase: Option<String> = row.get(8)?;
537 let workflow: Option<String> = row.get(9)?;
538
539 Ok((id, tags_json, max_claims, registered_at, last_heartbeat, claim_count, current_thought, last_status, last_phase, workflow))
540 })?
541 .filter_map(|r| r.ok())
542 .map(|(id, tags_json, max_claims, registered_at, last_heartbeat, claim_count, current_thought, last_status, last_phase, workflow)| {
543 let tags: Vec<String> = serde_json::from_str(&tags_json).unwrap_or_default();
544 crate::types::WorkerInfo {
545 id,
546 tags,
547 max_claims,
548 claim_count,
549 current_thought,
550 registered_at,
551 last_heartbeat,
552 last_status,
553 last_phase,
554 workflow,
555 }
556 })
557 .collect();
558
559 let workers = if let Some(required_tags) = tags {
561 workers
562 .into_iter()
563 .filter(|w| required_tags.iter().all(|t| w.tags.contains(t)))
564 .collect()
565 } else {
566 workers
567 };
568
569 Ok(workers)
570 })
571 }
572
573 fn get_related_task_ids_internal(
576 conn: &Connection,
577 task_id: &str,
578 depth: i32,
579 ) -> Result<Vec<String>> {
580 use std::collections::HashSet;
581
582 let mut result = HashSet::new();
583 result.insert(task_id.to_string());
584
585 if depth == 0 {
586 return Ok(result.into_iter().collect());
587 }
588
589 let abs_depth = depth.abs();
590 let mut current_level: HashSet<String> = [task_id.to_string()].into_iter().collect();
591
592 for _ in 0..abs_depth {
593 if current_level.is_empty() {
594 break;
595 }
596
597 let mut next_level = HashSet::new();
598
599 for tid in ¤t_level {
600 let related: Vec<String> = if depth > 0 {
601 let mut stmt = conn
603 .prepare("SELECT to_task_id FROM dependencies WHERE from_task_id = ?1")?;
604 stmt.query_map(params![tid], |row| row.get(0))?
605 .filter_map(|r| r.ok())
606 .collect()
607 } else {
608 let mut stmt = conn
610 .prepare("SELECT from_task_id FROM dependencies WHERE to_task_id = ?1")?;
611 stmt.query_map(params![tid], |row| row.get(0))?
612 .filter_map(|r| r.ok())
613 .collect()
614 };
615
616 for related_id in related {
617 if !result.contains(&related_id) {
618 next_level.insert(related_id.clone());
619 result.insert(related_id);
620 }
621 }
622 }
623
624 current_level = next_level;
625 }
626
627 Ok(result.into_iter().collect())
628 }
629
630 pub fn get_stale_workers(&self, timeout_seconds: i64) -> Result<Vec<Worker>> {
632 let cutoff = now_ms() - (timeout_seconds * 1000);
633
634 self.with_conn(|conn| {
635 let mut stmt = conn.prepare(
636 "SELECT id, tags, max_claims, registered_at, last_heartbeat, last_status, last_phase, workflow
637 FROM workers WHERE last_heartbeat < ?1",
638 )?;
639
640 let workers = stmt
641 .query_map(params![cutoff], |row| {
642 let id: String = row.get(0)?;
643 let tags_json: String = row.get(1)?;
644 let max_claims: i32 = row.get(2)?;
645 let registered_at: i64 = row.get(3)?;
646 let last_heartbeat: i64 = row.get(4)?;
647 let last_status: Option<String> = row.get(5)?;
648 let last_phase: Option<String> = row.get(6)?;
649 let workflow: Option<String> = row.get(7)?;
650
651 Ok((
652 id,
653 tags_json,
654 max_claims,
655 registered_at,
656 last_heartbeat,
657 last_status,
658 last_phase,
659 workflow,
660 ))
661 })?
662 .filter_map(|r| r.ok())
663 .map(
664 |(
665 id,
666 tags_json,
667 max_claims,
668 registered_at,
669 last_heartbeat,
670 last_status,
671 last_phase,
672 workflow,
673 )| {
674 let tags: Vec<String> =
675 serde_json::from_str(&tags_json).unwrap_or_default();
676 Worker {
677 id,
678 tags,
679 max_claims,
680 registered_at,
681 last_heartbeat,
682 last_status,
683 last_phase,
684 workflow,
685 }
686 },
687 )
688 .collect();
689
690 Ok(workers)
691 })
692 }
693
694 pub fn cleanup_stale_workers(
697 &self,
698 timeout_seconds: i64,
699 final_status: &str,
700 ) -> Result<CleanupSummary> {
701 let stale_workers = self.get_stale_workers(timeout_seconds)?;
702
703 let mut total_tasks_released = 0;
704 let mut total_files_released = 0;
705 let mut evicted_worker_ids = Vec::new();
706
707 for worker in &stale_workers {
708 let _ = self.release_worker_locks(&worker.id);
710
711 if let Ok(summary) = self.unregister_worker(&worker.id, final_status) {
713 total_tasks_released += summary.tasks_released;
714 total_files_released += summary.files_released;
715 evicted_worker_ids.push(worker.id.clone());
716 }
717 }
718
719 Ok(CleanupSummary {
720 workers_evicted: evicted_worker_ids.len() as i32,
721 tasks_released: total_tasks_released,
722 files_released: total_files_released,
723 final_status: final_status.to_string(),
724 evicted_worker_ids,
725 })
726 }
727
728 pub fn get_claim_count(&self, worker_id: &str) -> Result<i32> {
730 self.with_conn(|conn| {
731 let count: i32 = conn.query_row(
732 "SELECT COUNT(*) FROM tasks WHERE worker_id = ?1 AND status = 'working'",
733 params![worker_id],
734 |row| row.get(0),
735 )?;
736 Ok(count)
737 })
738 }
739}