Skip to main content

walrus_daemon/hook/task/
registry.rs

1//! Task registry — concurrency control, dispatch, and lifecycle.
2
3use crate::daemon::event::{DaemonEvent, DaemonEventSender};
4use crate::hook::task::{InboxItem, Task, TaskStatus};
5use compact_str::CompactString;
6use std::{
7    collections::BTreeMap,
8    sync::{
9        Arc,
10        atomic::{AtomicU64, Ordering},
11    },
12    time::Duration,
13};
14use tokio::sync::{Mutex, broadcast, mpsc, oneshot, watch};
15use tokio::time::Instant;
16use wcore::protocol::message::{
17    ClientMessage, KillMsg, SendMsg, ServerMessage, TaskCompleted, TaskCreated, TaskEvent,
18    TaskInfo, TaskStatusChanged, client_message, server_message, task_event,
19};
20
21/// In-memory task registry with concurrency control.
22pub struct TaskRegistry {
23    tasks: BTreeMap<u64, Task>,
24    next_id: AtomicU64,
25    /// Maximum number of concurrently InProgress tasks.
26    pub max_concurrent: usize,
27    /// Maximum number of tasks returned by `list()`.
28    pub viewable_window: usize,
29    /// Per-task execution timeout.
30    pub task_timeout: Duration,
31    /// Event channel for dispatching task execution.
32    pub event_tx: DaemonEventSender,
33    /// Broadcast channel for task lifecycle events (subscriptions).
34    task_broadcast: broadcast::Sender<TaskEvent>,
35}
36
37impl TaskRegistry {
38    /// Create a new registry with the given config and event sender.
39    pub fn new(
40        max_concurrent: usize,
41        viewable_window: usize,
42        task_timeout: Duration,
43        event_tx: DaemonEventSender,
44    ) -> Self {
45        let (task_broadcast, _) = broadcast::channel(64);
46        Self {
47            tasks: BTreeMap::new(),
48            next_id: AtomicU64::new(1),
49            max_concurrent,
50            viewable_window,
51            task_timeout,
52            event_tx,
53            task_broadcast,
54        }
55    }
56
57    /// Subscribe to task lifecycle events.
58    pub fn subscribe(&self) -> broadcast::Receiver<TaskEvent> {
59        self.task_broadcast.subscribe()
60    }
61
62    /// Build a `TaskInfo` snapshot from an internal `Task`.
63    fn task_info(task: &Task) -> TaskInfo {
64        TaskInfo {
65            id: task.id,
66            parent_id: task.parent_id,
67            agent: task.agent.to_string(),
68            status: task.status.to_string(),
69            description: task.description.clone(),
70            result: task.result.clone(),
71            error: task.error.clone(),
72            created_by: task.created_by.to_string(),
73            prompt_tokens: task.prompt_tokens,
74            completion_tokens: task.completion_tokens,
75            alive_secs: task.created_at.elapsed().as_secs(),
76            blocked_on: task.blocked_on.as_ref().map(|i| i.question.clone()),
77        }
78    }
79
80    /// Create a new task and insert it into the registry.
81    pub fn create(
82        &mut self,
83        agent: CompactString,
84        description: String,
85        created_by: CompactString,
86        parent_id: Option<u64>,
87        status: TaskStatus,
88        spawned: bool,
89    ) -> u64 {
90        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
91        let (status_tx, _) = watch::channel(status);
92        let task = Task {
93            id,
94            parent_id,
95            session_id: None,
96            agent,
97            status,
98            created_by,
99            description,
100            result: None,
101            error: None,
102            blocked_on: None,
103            prompt_tokens: 0,
104            completion_tokens: 0,
105            created_at: Instant::now(),
106            abort_handle: None,
107            spawned,
108            status_tx,
109        };
110        self.tasks.insert(id, task);
111        if let Some(t) = self.tasks.get(&id) {
112            let _ = self.task_broadcast.send(TaskEvent {
113                event: Some(task_event::Event::Created(TaskCreated {
114                    task: Some(Self::task_info(t)),
115                })),
116            });
117        }
118        id
119    }
120
121    /// Get a reference to a task by ID.
122    pub fn get(&self, id: u64) -> Option<&Task> {
123        self.tasks.get(&id)
124    }
125
126    /// Get a mutable reference to a task by ID.
127    pub fn get_mut(&mut self, id: u64) -> Option<&mut Task> {
128        self.tasks.get_mut(&id)
129    }
130
131    /// Update task status and notify watchers.
132    pub fn set_status(&mut self, id: u64, status: TaskStatus) {
133        if let Some(task) = self.tasks.get_mut(&id) {
134            task.status = status;
135            let _ = task.status_tx.send(status);
136            let _ = self.task_broadcast.send(TaskEvent {
137                event: Some(task_event::Event::StatusChanged(TaskStatusChanged {
138                    task_id: id,
139                    status: status.to_string(),
140                    blocked_on: task.blocked_on.as_ref().map(|i| i.question.clone()),
141                })),
142            });
143        }
144    }
145
146    /// Remove a task from the registry.
147    pub fn remove(&mut self, id: u64) -> Option<Task> {
148        self.tasks.remove(&id)
149    }
150
151    /// List tasks, most recent first, up to `viewable_window` entries.
152    ///
153    /// Optionally filters by agent, status, or parent_id.
154    pub fn list(
155        &self,
156        agent: Option<&str>,
157        status: Option<TaskStatus>,
158        parent_id: Option<Option<u64>>,
159    ) -> Vec<&Task> {
160        self.tasks
161            .values()
162            .rev()
163            .filter(|t| agent.is_none_or(|a| t.agent == a))
164            .filter(|t| status.is_none_or(|s| t.status == s))
165            .filter(|t| parent_id.is_none_or(|p| t.parent_id == p))
166            .take(self.viewable_window)
167            .collect()
168    }
169
170    /// Count of currently InProgress tasks (not Blocked).
171    pub fn active_count(&self) -> usize {
172        self.tasks
173            .values()
174            .filter(|t| t.status == TaskStatus::InProgress)
175            .count()
176    }
177
178    /// Submit a task for execution.
179    ///
180    /// If under the concurrency limit, dispatches immediately and spawns a
181    /// watcher. Otherwise, queues the task. Returns `(task_id, status)`.
182    pub fn submit(
183        &mut self,
184        agent: CompactString,
185        message: String,
186        created_by: CompactString,
187        parent_id: Option<u64>,
188        registry: Arc<Mutex<TaskRegistry>>,
189    ) -> (u64, TaskStatus) {
190        let under_limit = self.active_count() < self.max_concurrent;
191        let initial_status = if under_limit {
192            TaskStatus::InProgress
193        } else {
194            TaskStatus::Queued
195        };
196
197        let task_id = self.create(
198            agent.clone(),
199            message.clone(),
200            created_by,
201            parent_id,
202            initial_status,
203            true,
204        );
205
206        if under_limit {
207            self.dispatch_task(task_id, agent, message, registry);
208        }
209
210        (task_id, initial_status)
211    }
212
213    /// Dispatch a task: send the message via event channel and spawn a watcher.
214    fn dispatch_task(
215        &mut self,
216        task_id: u64,
217        agent: CompactString,
218        message: String,
219        registry: Arc<Mutex<TaskRegistry>>,
220    ) {
221        let (reply_tx, reply_rx) = mpsc::unbounded_channel();
222        let msg = ClientMessage::from(SendMsg {
223            agent: agent.to_string(),
224            content: message,
225            session: None,
226            sender: None,
227        });
228        let _ = self.event_tx.send(DaemonEvent::Message {
229            msg,
230            reply: reply_tx,
231        });
232
233        let event_tx = self.event_tx.clone();
234        let timeout = self.task_timeout;
235        let handle = tokio::spawn(task_watcher(task_id, reply_rx, registry, event_tx, timeout));
236        if let Some(task) = self.tasks.get_mut(&task_id) {
237            task.abort_handle = Some(handle.abort_handle());
238        }
239    }
240
241    /// Mark a task as Finished or Failed, then promote the next queued task.
242    pub fn complete(
243        &mut self,
244        task_id: u64,
245        result: Option<String>,
246        error: Option<String>,
247        registry: Arc<Mutex<TaskRegistry>>,
248    ) {
249        if let Some(task) = self.tasks.get_mut(&task_id) {
250            if error.is_some() {
251                task.status = TaskStatus::Failed;
252                task.error = error.clone();
253                let _ = task.status_tx.send(TaskStatus::Failed);
254                let _ = self.task_broadcast.send(TaskEvent {
255                    event: Some(task_event::Event::Completed(TaskCompleted {
256                        task_id,
257                        status: TaskStatus::Failed.to_string(),
258                        result: None,
259                        error,
260                    })),
261                });
262            } else {
263                task.status = TaskStatus::Finished;
264                task.result = result.clone();
265                let _ = task.status_tx.send(TaskStatus::Finished);
266                let _ = self.task_broadcast.send(TaskEvent {
267                    event: Some(task_event::Event::Completed(TaskCompleted {
268                        task_id,
269                        status: TaskStatus::Finished.to_string(),
270                        result,
271                        error: None,
272                    })),
273                });
274            }
275        }
276        self.promote_next(registry);
277    }
278
279    /// Promote the next queued task to InProgress if a slot is available.
280    pub fn promote_next(&mut self, registry: Arc<Mutex<TaskRegistry>>) {
281        if self.active_count() >= self.max_concurrent {
282            return;
283        }
284        // Find the oldest queued task.
285        let next = self
286            .tasks
287            .values()
288            .find(|t| t.status == TaskStatus::Queued)
289            .map(|t| (t.id, t.agent.clone(), t.description.clone()));
290
291        if let Some((id, agent, message)) = next {
292            self.set_status(id, TaskStatus::InProgress);
293            self.dispatch_task(id, agent, message, registry);
294        }
295    }
296
297    /// Block a task, setting status to Blocked and storing the inbox item.
298    ///
299    /// Returns a receiver that the tool call can await for the user's response.
300    pub fn block(&mut self, task_id: u64, question: String) -> Option<oneshot::Receiver<String>> {
301        let task = self.tasks.get_mut(&task_id)?;
302        let (tx, rx) = oneshot::channel();
303        task.blocked_on = Some(InboxItem {
304            question,
305            reply: tx,
306        });
307        task.status = TaskStatus::Blocked;
308        let _ = task.status_tx.send(TaskStatus::Blocked);
309        let _ = self.task_broadcast.send(TaskEvent {
310            event: Some(task_event::Event::StatusChanged(TaskStatusChanged {
311                task_id,
312                status: TaskStatus::Blocked.to_string(),
313                blocked_on: task.blocked_on.as_ref().map(|i| i.question.clone()),
314            })),
315        });
316        Some(rx)
317    }
318
319    /// Approve a blocked task, sending the response and resuming execution.
320    pub fn approve(&mut self, task_id: u64, response: String) -> bool {
321        let Some(task) = self.tasks.get_mut(&task_id) else {
322            return false;
323        };
324        if task.status != TaskStatus::Blocked {
325            return false;
326        }
327        if let Some(inbox) = task.blocked_on.take() {
328            let _ = inbox.reply.send(response);
329        }
330        task.status = TaskStatus::InProgress;
331        let _ = task.status_tx.send(TaskStatus::InProgress);
332        let _ = self.task_broadcast.send(TaskEvent {
333            event: Some(task_event::Event::StatusChanged(TaskStatusChanged {
334                task_id,
335                status: TaskStatus::InProgress.to_string(),
336                blocked_on: None,
337            })),
338        });
339        true
340    }
341
342    /// Subscribe to a task's status changes (for await_tasks).
343    pub fn subscribe_status(&self, task_id: u64) -> Option<watch::Receiver<TaskStatus>> {
344        self.tasks.get(&task_id).map(|t| t.status_tx.subscribe())
345    }
346
347    /// Get all child tasks of a given parent.
348    pub fn children(&self, parent_id: u64) -> Vec<&Task> {
349        self.tasks
350            .values()
351            .filter(|t| t.parent_id == Some(parent_id))
352            .collect()
353    }
354
355    /// Find a task by its session ID. Returns the task ID.
356    pub fn find_by_session(&self, session_id: u64) -> Option<u64> {
357        self.tasks
358            .values()
359            .find(|t| t.session_id == Some(session_id))
360            .map(|t| t.id)
361    }
362
363    /// Add token usage to a task.
364    pub fn add_tokens(&mut self, task_id: u64, prompt: u64, completion: u64) {
365        if let Some(task) = self.tasks.get_mut(&task_id) {
366            task.prompt_tokens += prompt;
367            task.completion_tokens += completion;
368        }
369    }
370
371    /// Collect queued `create_task` entries grouped by agent.
372    ///
373    /// Returns `(agent, [(task_id, description)])` pairs, capped at
374    /// `max_concurrent` tasks per agent to avoid context overflow.
375    pub fn queued_create_tasks(&self) -> BTreeMap<CompactString, Vec<(u64, String)>> {
376        let mut groups: BTreeMap<CompactString, Vec<(u64, String)>> = BTreeMap::new();
377        for task in self.tasks.values() {
378            if task.status == TaskStatus::Queued && !task.spawned {
379                let entry = groups.entry(task.agent.clone()).or_default();
380                if entry.len() < self.max_concurrent {
381                    entry.push((task.id, task.description.clone()));
382                }
383            }
384        }
385        groups
386    }
387
388    /// Collect queued `create_task` entries for a single agent, capped at
389    /// `max_concurrent`.
390    pub fn queued_create_tasks_for(&self, agent: &str) -> Vec<(u64, String)> {
391        let mut entries = Vec::new();
392        for task in self.tasks.values() {
393            if task.status == TaskStatus::Queued && !task.spawned && task.agent == agent {
394                entries.push((task.id, task.description.clone()));
395                if entries.len() >= self.max_concurrent {
396                    break;
397                }
398            }
399        }
400        entries
401    }
402}
403
404/// Watcher task: awaits reply messages with timeout, closes session, completes task.
405async fn task_watcher(
406    task_id: u64,
407    mut reply_rx: mpsc::UnboundedReceiver<ServerMessage>,
408    registry: Arc<Mutex<TaskRegistry>>,
409    event_tx: DaemonEventSender,
410    timeout: Duration,
411) {
412    let mut result_content: Option<String> = None;
413    let mut error_msg: Option<String> = None;
414    let mut session_id: Option<u64> = None;
415
416    let collect = async {
417        while let Some(msg) = reply_rx.recv().await {
418            match msg.msg {
419                Some(server_message::Msg::Response(resp)) => {
420                    session_id = Some(resp.session);
421                    result_content = Some(resp.content);
422                }
423                Some(server_message::Msg::Error(err)) => {
424                    error_msg = Some(err.message);
425                }
426                _ => {}
427            }
428        }
429    };
430
431    if tokio::time::timeout(timeout, collect).await.is_err() {
432        error_msg = Some("task timed out".into());
433    }
434
435    // Close the session to prevent accumulation.
436    if let Some(sid) = session_id {
437        let (reply_tx, _reply_rx) = mpsc::unbounded_channel();
438        let _ = event_tx.send(DaemonEvent::Message {
439            msg: ClientMessage {
440                msg: Some(client_message::Msg::Kill(KillMsg { session: sid })),
441            },
442            reply: reply_tx,
443        });
444    }
445
446    // Complete the task, auto-close sub-task sessions, and promote next queued.
447    let reg = registry.clone();
448    let mut locked = registry.lock().await;
449    // Collect finished sub-task session IDs for auto-close.
450    let child_sessions: Vec<u64> = locked
451        .children(task_id)
452        .iter()
453        .filter(|t| t.status == TaskStatus::Finished || t.status == TaskStatus::Failed)
454        .filter_map(|t| t.session_id)
455        .collect();
456    locked.complete(task_id, result_content, error_msg, reg);
457    drop(locked);
458
459    // Auto-close finished sub-task sessions outside the lock.
460    for sid in child_sessions {
461        let (reply_tx, _) = mpsc::unbounded_channel();
462        let _ = event_tx.send(DaemonEvent::Message {
463            msg: ClientMessage {
464                msg: Some(client_message::Msg::Kill(KillMsg { session: sid })),
465            },
466            reply: reply_tx,
467        });
468    }
469}