1use super::{now_ms, Database};
4use crate::types::{CleanupSummary, DisconnectSummary, Worker};
5use anyhow::{anyhow, Result};
6use rusqlite::{params, Connection};
7
8pub const MAX_WORKER_ID_LEN: usize = 36;
10
11const MAX_PETNAME_ATTEMPTS: u32 = 100;
13
14fn generate_unique_petname(conn: &Connection) -> String {
17 let base = petname::petname(2, "-").unwrap_or_else(|| "worker".to_string());
18
19 let exists: bool = conn
21 .query_row("SELECT 1 FROM workers WHERE id = ?1", params![&base], |_| Ok(true))
22 .unwrap_or(false);
23
24 if !exists {
25 return base;
26 }
27
28 for i in 2..=MAX_PETNAME_ATTEMPTS {
30 let candidate = format!("{}-{}", base, i);
31 let exists: bool = conn
32 .query_row("SELECT 1 FROM workers WHERE id = ?1", params![&candidate], |_| Ok(true))
33 .unwrap_or(false);
34 if !exists {
35 return candidate;
36 }
37 }
38
39 petname::petname(3, "-").unwrap_or_else(|| format!("worker-{}", now_ms()))
41}
42
43fn get_worker_internal(conn: &Connection, worker_id: &str) -> Result<Option<Worker>> {
45 let mut stmt = conn.prepare(
46 "SELECT id, tags, max_claims, registered_at, last_heartbeat
47 FROM workers WHERE id = ?1",
48 )?;
49
50 let result = stmt.query_row(params![worker_id], |row| {
51 let id: String = row.get(0)?;
52 let tags_json: String = row.get(1)?;
53 let max_claims: i32 = row.get(2)?;
54 let registered_at: i64 = row.get(3)?;
55 let last_heartbeat: i64 = row.get(4)?;
56
57 Ok((id, tags_json, max_claims, registered_at, last_heartbeat))
58 });
59
60 match result {
61 Ok((id, tags_json, max_claims, registered_at, last_heartbeat)) => {
62 let tags: Vec<String> = serde_json::from_str(&tags_json).unwrap_or_default();
63 Ok(Some(Worker {
64 id,
65 tags,
66 max_claims,
67 registered_at,
68 last_heartbeat,
69 }))
70 }
71 Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
72 Err(e) => Err(e.into()),
73 }
74}
75
76impl Database {
77 pub fn register_worker(
84 &self,
85 worker_id: Option<String>,
86 tags: Vec<String>,
87 force: bool,
88 ) -> Result<Worker> {
89 let provided_id = match worker_id {
91 Some(id) => {
92 if id.len() > MAX_WORKER_ID_LEN {
93 return Err(anyhow!(
94 "Worker ID must be at most {} characters, got {}",
95 MAX_WORKER_ID_LEN,
96 id.len()
97 ));
98 }
99 if id.is_empty() {
100 return Err(anyhow!("Worker ID cannot be empty"));
101 }
102 Some(id)
103 }
104 None => None,
105 };
106 let now = now_ms();
107 let max_claims = i32::MAX; let tags_json = serde_json::to_string(&tags)?;
109
110 self.with_conn(|conn| {
111 let id = match provided_id {
113 Some(id) => id,
114 None => generate_unique_petname(conn),
115 };
116
117 let exists: bool = conn
119 .query_row("SELECT 1 FROM workers WHERE id = ?1", params![&id], |_| Ok(true))
120 .unwrap_or(false);
121
122 let current_max_sequence: i64 = conn
126 .query_row("SELECT COALESCE(MAX(id), 0) FROM claim_sequence", [], |row| row.get(0))
127 .unwrap_or(0);
128 let initial_sequence = current_max_sequence + 1;
129
130 if exists {
131 if force {
132 conn.execute(
134 "UPDATE workers SET tags = ?1, max_claims = ?2, last_heartbeat = ?3, last_claim_sequence = ?4 WHERE id = ?5",
135 params![tags_json, max_claims, now, initial_sequence, &id],
136 )?;
137 } else {
138 return Err(anyhow!("Worker ID '{}' already registered. Use force=true to reconnect.", id));
139 }
140 } else {
141 conn.execute(
142 "INSERT INTO workers (id, tags, max_claims, registered_at, last_heartbeat, last_claim_sequence)
143 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
144 params![&id, tags_json, max_claims, now, now, initial_sequence],
145 )?;
146 }
147
148 Ok(Worker {
149 id,
150 tags,
151 max_claims,
152 registered_at: now,
153 last_heartbeat: now,
154 })
155 })
156 }
157
158 pub fn get_worker(&self, worker_id: &str) -> Result<Option<Worker>> {
160 self.with_conn(|conn| {
161 let mut stmt = conn.prepare(
162 "SELECT id, tags, max_claims, registered_at, last_heartbeat
163 FROM workers WHERE id = ?1",
164 )?;
165
166 let result = stmt.query_row(params![worker_id], |row| {
167 let id: String = row.get(0)?;
168 let tags_json: String = row.get(1)?;
169 let max_claims: i32 = row.get(2)?;
170 let registered_at: i64 = row.get(3)?;
171 let last_heartbeat: i64 = row.get(4)?;
172
173 Ok((id, tags_json, max_claims, registered_at, last_heartbeat))
174 });
175
176 match result {
177 Ok((id, tags_json, max_claims, registered_at, last_heartbeat)) => {
178 let tags: Vec<String> = serde_json::from_str(&tags_json).unwrap_or_default();
179 Ok(Some(Worker {
180 id,
181 tags,
182 max_claims,
183 registered_at,
184 last_heartbeat,
185 }))
186 }
187 Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
188 Err(e) => Err(e.into()),
189 }
190 })
191 }
192
193 pub fn require_worker(&self, worker_id: &str) -> Result<Worker> {
195 self.get_worker(worker_id)?
196 .ok_or_else(|| anyhow::anyhow!("Worker {} not found", worker_id))
197 }
198
199 pub fn update_worker(
201 &self,
202 worker_id: &str,
203 tags: Option<Vec<String>>,
204 max_claims: Option<i32>,
205 ) -> Result<Worker> {
206 self.with_conn(|conn| {
207 let worker = get_worker_internal(conn, worker_id)?
208 .ok_or_else(|| anyhow!("Worker not found"))?;
209
210 let new_tags = tags.unwrap_or(worker.tags.clone());
211 let new_max_claims = max_claims.unwrap_or(worker.max_claims);
212 let tags_json = serde_json::to_string(&new_tags)?;
213
214 conn.execute(
215 "UPDATE workers SET tags = ?1, max_claims = ?2 WHERE id = ?3",
216 params![tags_json, new_max_claims, worker_id],
217 )?;
218
219 Ok(Worker {
220 id: worker_id.to_string(),
221 tags: new_tags,
222 max_claims: new_max_claims,
223 registered_at: worker.registered_at,
224 last_heartbeat: worker.last_heartbeat,
225 })
226 })
227 }
228
229 pub fn heartbeat(&self, worker_id: &str) -> Result<i32> {
231 let now = now_ms();
232
233 self.with_conn(|conn| {
234 let updated = conn.execute(
235 "UPDATE workers SET last_heartbeat = ?1 WHERE id = ?2",
236 params![now, worker_id],
237 )?;
238
239 if updated == 0 {
240 return Err(anyhow!("Worker not found"));
241 }
242
243 let count: i32 = conn.query_row(
245 "SELECT COUNT(*) FROM tasks WHERE worker_id = ?1 AND status = 'in_progress'",
246 params![worker_id],
247 |row| row.get(0),
248 )?;
249
250 Ok(count)
251 })
252 }
253
254 pub fn unregister_worker(&self, worker_id: &str, final_status: &str) -> Result<DisconnectSummary> {
257 self.with_conn_mut(|conn| {
258 let tx = conn.transaction()?;
259
260 let tasks_released = tx.execute(
262 "UPDATE tasks SET worker_id = NULL, claimed_at = NULL, status = ?2
263 WHERE worker_id = ?1",
264 params![worker_id, final_status],
265 )? as i32;
266
267 let files_released = tx.execute(
269 "DELETE FROM file_locks WHERE worker_id = ?1",
270 params![worker_id],
271 )? as i32;
272
273 tx.execute(
275 "DELETE FROM workers WHERE id = ?1",
276 params![worker_id],
277 )?;
278
279 tx.commit()?;
280 Ok(DisconnectSummary {
281 tasks_released,
282 files_released,
283 final_status: final_status.to_string(),
284 })
285 })
286 }
287
288 pub fn list_workers(&self) -> Result<Vec<Worker>> {
290 self.with_conn(|conn| {
291 let mut stmt = conn.prepare(
292 "SELECT id, tags, max_claims, registered_at, last_heartbeat
293 FROM workers ORDER BY registered_at DESC",
294 )?;
295
296 let workers = stmt.query_map([], |row| {
297 let id: String = row.get(0)?;
298 let tags_json: String = row.get(1)?;
299 let max_claims: i32 = row.get(2)?;
300 let registered_at: i64 = row.get(3)?;
301 let last_heartbeat: i64 = row.get(4)?;
302
303 Ok((id, tags_json, max_claims, registered_at, last_heartbeat))
304 })?
305 .filter_map(|r| r.ok())
306 .map(|(id, tags_json, max_claims, registered_at, last_heartbeat)| {
307 let tags: Vec<String> = serde_json::from_str(&tags_json).unwrap_or_default();
308 Worker {
309 id,
310 tags,
311 max_claims,
312 registered_at,
313 last_heartbeat,
314 }
315 })
316 .collect();
317
318 Ok(workers)
319 })
320 }
321
322 pub fn list_workers_info(&self) -> Result<Vec<crate::types::WorkerInfo>> {
324 self.with_conn(|conn| {
325 let mut stmt = conn.prepare(
326 "SELECT w.id, w.tags, w.max_claims, w.registered_at, w.last_heartbeat,
327 (SELECT COUNT(*) FROM tasks WHERE worker_id = w.id AND status = 'in_progress') as claim_count,
328 (SELECT current_thought FROM tasks WHERE worker_id = w.id AND status = 'in_progress' AND current_thought IS NOT NULL LIMIT 1) as current_thought
329 FROM workers w ORDER BY w.registered_at DESC",
330 )?;
331
332 let workers = stmt.query_map([], |row| {
333 let id: String = row.get(0)?;
334 let tags_json: String = row.get(1)?;
335 let max_claims: i32 = row.get(2)?;
336 let registered_at: i64 = row.get(3)?;
337 let last_heartbeat: i64 = row.get(4)?;
338 let claim_count: i32 = row.get(5)?;
339 let current_thought: Option<String> = row.get(6)?;
340
341 Ok((id, tags_json, max_claims, registered_at, last_heartbeat, claim_count, current_thought))
342 })?
343 .filter_map(|r| r.ok())
344 .map(|(id, tags_json, max_claims, registered_at, last_heartbeat, claim_count, current_thought)| {
345 let tags: Vec<String> = serde_json::from_str(&tags_json).unwrap_or_default();
346 crate::types::WorkerInfo {
347 id,
348 tags,
349 max_claims,
350 claim_count,
351 current_thought,
352 registered_at,
353 last_heartbeat,
354 }
355 })
356 .collect();
357
358 Ok(workers)
359 })
360 }
361
362 pub fn list_workers_filtered(
369 &self,
370 tags: Option<&Vec<String>>,
371 file: Option<&str>,
372 task_id: Option<&str>,
373 depth: i32,
374 ) -> Result<Vec<crate::types::WorkerInfo>> {
375 self.with_conn(|conn| {
376 let mut sql = String::from(
378 "SELECT DISTINCT w.id, w.tags, w.max_claims, w.registered_at, w.last_heartbeat,
379 (SELECT COUNT(*) FROM tasks WHERE worker_id = w.id AND status = 'in_progress') as claim_count,
380 (SELECT current_thought FROM tasks WHERE worker_id = w.id AND status = 'in_progress' AND current_thought IS NOT NULL LIMIT 1) as current_thought
381 FROM workers w WHERE 1=1",
382 );
383 let mut params_vec: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
384
385 if let Some(f) = file {
387 sql.push_str(" AND w.id IN (SELECT worker_id FROM file_locks WHERE file_path = ?)");
388 params_vec.push(Box::new(f.to_string()));
389 }
390
391 if let Some(tid) = task_id {
393 let related_task_ids = Self::get_related_task_ids_internal(conn, tid, depth)?;
395 if !related_task_ids.is_empty() {
396 let placeholders: Vec<String> = related_task_ids.iter().map(|_| "?".to_string()).collect();
397 sql.push_str(&format!(
398 " AND w.id IN (SELECT DISTINCT worker_id FROM tasks WHERE id IN ({}) AND worker_id IS NOT NULL)",
399 placeholders.join(", ")
400 ));
401 for task in related_task_ids {
402 params_vec.push(Box::new(task));
403 }
404 } else {
405 return Ok(Vec::new());
407 }
408 }
409
410 sql.push_str(" ORDER BY w.registered_at DESC");
411
412 let params_refs: Vec<&dyn rusqlite::ToSql> =
413 params_vec.iter().map(|b| b.as_ref()).collect();
414
415 let mut stmt = conn.prepare(&sql)?;
416 let workers: Vec<crate::types::WorkerInfo> = stmt
417 .query_map(params_refs.as_slice(), |row| {
418 let id: String = row.get(0)?;
419 let tags_json: String = row.get(1)?;
420 let max_claims: i32 = row.get(2)?;
421 let registered_at: i64 = row.get(3)?;
422 let last_heartbeat: i64 = row.get(4)?;
423 let claim_count: i32 = row.get(5)?;
424 let current_thought: Option<String> = row.get(6)?;
425
426 Ok((id, tags_json, max_claims, registered_at, last_heartbeat, claim_count, current_thought))
427 })?
428 .filter_map(|r| r.ok())
429 .map(|(id, tags_json, max_claims, registered_at, last_heartbeat, claim_count, current_thought)| {
430 let tags: Vec<String> = serde_json::from_str(&tags_json).unwrap_or_default();
431 crate::types::WorkerInfo {
432 id,
433 tags,
434 max_claims,
435 claim_count,
436 current_thought,
437 registered_at,
438 last_heartbeat,
439 }
440 })
441 .collect();
442
443 let workers = if let Some(required_tags) = tags {
445 workers
446 .into_iter()
447 .filter(|w| required_tags.iter().all(|t| w.tags.contains(t)))
448 .collect()
449 } else {
450 workers
451 };
452
453 Ok(workers)
454 })
455 }
456
457 fn get_related_task_ids_internal(conn: &Connection, task_id: &str, depth: i32) -> Result<Vec<String>> {
460 use std::collections::HashSet;
461
462 let mut result = HashSet::new();
463 result.insert(task_id.to_string());
464
465 if depth == 0 {
466 return Ok(result.into_iter().collect());
467 }
468
469 let abs_depth = depth.abs();
470 let mut current_level: HashSet<String> = [task_id.to_string()].into_iter().collect();
471
472 for _ in 0..abs_depth {
473 if current_level.is_empty() {
474 break;
475 }
476
477 let mut next_level = HashSet::new();
478
479 for tid in ¤t_level {
480 let related: Vec<String> = if depth > 0 {
481 let mut stmt = conn.prepare(
483 "SELECT to_task_id FROM dependencies WHERE from_task_id = ?1"
484 )?;
485 stmt.query_map(params![tid], |row| row.get(0))?
486 .filter_map(|r| r.ok())
487 .collect()
488 } else {
489 let mut stmt = conn.prepare(
491 "SELECT from_task_id FROM dependencies WHERE to_task_id = ?1"
492 )?;
493 stmt.query_map(params![tid], |row| row.get(0))?
494 .filter_map(|r| r.ok())
495 .collect()
496 };
497
498 for related_id in related {
499 if !result.contains(&related_id) {
500 next_level.insert(related_id.clone());
501 result.insert(related_id);
502 }
503 }
504 }
505
506 current_level = next_level;
507 }
508
509 Ok(result.into_iter().collect())
510 }
511
512 pub fn get_stale_workers(&self, timeout_seconds: i64) -> Result<Vec<Worker>> {
514 let cutoff = now_ms() - (timeout_seconds * 1000);
515
516 self.with_conn(|conn| {
517 let mut stmt = conn.prepare(
518 "SELECT id, tags, max_claims, registered_at, last_heartbeat
519 FROM workers WHERE last_heartbeat < ?1",
520 )?;
521
522 let workers = stmt.query_map(params![cutoff], |row| {
523 let id: String = row.get(0)?;
524 let tags_json: String = row.get(1)?;
525 let max_claims: i32 = row.get(2)?;
526 let registered_at: i64 = row.get(3)?;
527 let last_heartbeat: i64 = row.get(4)?;
528
529 Ok((id, tags_json, max_claims, registered_at, last_heartbeat))
530 })?
531 .filter_map(|r| r.ok())
532 .map(|(id, tags_json, max_claims, registered_at, last_heartbeat)| {
533 let tags: Vec<String> = serde_json::from_str(&tags_json).unwrap_or_default();
534 Worker {
535 id,
536 tags,
537 max_claims,
538 registered_at,
539 last_heartbeat,
540 }
541 })
542 .collect();
543
544 Ok(workers)
545 })
546 }
547
548 pub fn cleanup_stale_workers(&self, timeout_seconds: i64, final_status: &str) -> Result<CleanupSummary> {
551 let stale_workers = self.get_stale_workers(timeout_seconds)?;
552
553 let mut total_tasks_released = 0;
554 let mut total_files_released = 0;
555 let mut evicted_worker_ids = Vec::new();
556
557 for worker in &stale_workers {
558 let _ = self.release_worker_locks(&worker.id);
560
561 if let Ok(summary) = self.unregister_worker(&worker.id, final_status) {
563 total_tasks_released += summary.tasks_released;
564 total_files_released += summary.files_released;
565 evicted_worker_ids.push(worker.id.clone());
566 }
567 }
568
569 Ok(CleanupSummary {
570 workers_evicted: evicted_worker_ids.len() as i32,
571 tasks_released: total_tasks_released,
572 files_released: total_files_released,
573 final_status: final_status.to_string(),
574 evicted_worker_ids,
575 })
576 }
577
578 pub fn get_claim_count(&self, worker_id: &str) -> Result<i32> {
580 self.with_conn(|conn| {
581 let count: i32 = conn.query_row(
582 "SELECT COUNT(*) FROM tasks WHERE worker_id = ?1 AND status = 'in_progress'",
583 params![worker_id],
584 |row| row.get(0),
585 )?;
586 Ok(count)
587 })
588 }
589}