Skip to main content

walrus_daemon/hook/task/
registry.rs

1//! Task registry — concurrency control, dispatch, and lifecycle.
2
3use crate::daemon::event::{DaemonEvent, DaemonEventSender};
4use crate::hook::task::{InboxItem, Task, TaskStatus};
5use compact_str::CompactString;
6use std::{
7    collections::BTreeMap,
8    sync::{
9        Arc,
10        atomic::{AtomicU64, Ordering},
11    },
12    time::Duration,
13};
14use tokio::sync::{Mutex, broadcast, mpsc, oneshot, watch};
15use tokio::time::Instant;
16use wcore::protocol::message::{
17    TaskEvent,
18    client::ClientMessage,
19    server::{ServerMessage, TaskInfo},
20};
21
22/// In-memory task registry with concurrency control.
23pub struct TaskRegistry {
24    tasks: BTreeMap<u64, Task>,
25    next_id: AtomicU64,
26    /// Maximum number of concurrently InProgress tasks.
27    pub max_concurrent: usize,
28    /// Maximum number of tasks returned by `list()`.
29    pub viewable_window: usize,
30    /// Per-task execution timeout.
31    pub task_timeout: Duration,
32    /// Event channel for dispatching task execution.
33    pub event_tx: DaemonEventSender,
34    /// Broadcast channel for task lifecycle events (subscriptions).
35    task_broadcast: broadcast::Sender<TaskEvent>,
36}
37
38impl TaskRegistry {
39    /// Create a new registry with the given config and event sender.
40    pub fn new(
41        max_concurrent: usize,
42        viewable_window: usize,
43        task_timeout: Duration,
44        event_tx: DaemonEventSender,
45    ) -> Self {
46        let (task_broadcast, _) = broadcast::channel(64);
47        Self {
48            tasks: BTreeMap::new(),
49            next_id: AtomicU64::new(1),
50            max_concurrent,
51            viewable_window,
52            task_timeout,
53            event_tx,
54            task_broadcast,
55        }
56    }
57
58    /// Subscribe to task lifecycle events.
59    pub fn subscribe(&self) -> broadcast::Receiver<TaskEvent> {
60        self.task_broadcast.subscribe()
61    }
62
63    /// Build a `TaskInfo` snapshot from an internal `Task`.
64    fn task_info(task: &Task) -> TaskInfo {
65        TaskInfo {
66            id: task.id,
67            parent_id: task.parent_id,
68            agent: task.agent.clone(),
69            status: task.status.to_string(),
70            description: task.description.clone(),
71            result: task.result.clone(),
72            error: task.error.clone(),
73            created_by: task.created_by.clone(),
74            prompt_tokens: task.prompt_tokens,
75            completion_tokens: task.completion_tokens,
76            alive_secs: task.created_at.elapsed().as_secs(),
77            blocked_on: task.blocked_on.as_ref().map(|i| i.question.clone()),
78        }
79    }
80
81    /// Create a new task and insert it into the registry.
82    pub fn create(
83        &mut self,
84        agent: CompactString,
85        description: String,
86        created_by: CompactString,
87        parent_id: Option<u64>,
88        status: TaskStatus,
89        spawned: bool,
90    ) -> u64 {
91        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
92        let (status_tx, _) = watch::channel(status);
93        let task = Task {
94            id,
95            parent_id,
96            session_id: None,
97            agent,
98            status,
99            created_by,
100            description,
101            result: None,
102            error: None,
103            blocked_on: None,
104            prompt_tokens: 0,
105            completion_tokens: 0,
106            created_at: Instant::now(),
107            abort_handle: None,
108            spawned,
109            status_tx,
110        };
111        self.tasks.insert(id, task);
112        if let Some(t) = self.tasks.get(&id) {
113            let _ = self.task_broadcast.send(TaskEvent::Created {
114                task: Self::task_info(t),
115            });
116        }
117        id
118    }
119
120    /// Get a reference to a task by ID.
121    pub fn get(&self, id: u64) -> Option<&Task> {
122        self.tasks.get(&id)
123    }
124
125    /// Get a mutable reference to a task by ID.
126    pub fn get_mut(&mut self, id: u64) -> Option<&mut Task> {
127        self.tasks.get_mut(&id)
128    }
129
130    /// Update task status and notify watchers.
131    pub fn set_status(&mut self, id: u64, status: TaskStatus) {
132        if let Some(task) = self.tasks.get_mut(&id) {
133            task.status = status;
134            let _ = task.status_tx.send(status);
135            let _ = self.task_broadcast.send(TaskEvent::StatusChanged {
136                task_id: id,
137                status: status.to_string(),
138                blocked_on: task.blocked_on.as_ref().map(|i| i.question.clone()),
139            });
140        }
141    }
142
143    /// Remove a task from the registry.
144    pub fn remove(&mut self, id: u64) -> Option<Task> {
145        self.tasks.remove(&id)
146    }
147
148    /// List tasks, most recent first, up to `viewable_window` entries.
149    ///
150    /// Optionally filters by agent, status, or parent_id.
151    pub fn list(
152        &self,
153        agent: Option<&str>,
154        status: Option<TaskStatus>,
155        parent_id: Option<Option<u64>>,
156    ) -> Vec<&Task> {
157        self.tasks
158            .values()
159            .rev()
160            .filter(|t| agent.is_none_or(|a| t.agent == a))
161            .filter(|t| status.is_none_or(|s| t.status == s))
162            .filter(|t| parent_id.is_none_or(|p| t.parent_id == p))
163            .take(self.viewable_window)
164            .collect()
165    }
166
167    /// Count of currently InProgress tasks (not Blocked).
168    pub fn active_count(&self) -> usize {
169        self.tasks
170            .values()
171            .filter(|t| t.status == TaskStatus::InProgress)
172            .count()
173    }
174
175    /// Submit a task for execution.
176    ///
177    /// If under the concurrency limit, dispatches immediately and spawns a
178    /// watcher. Otherwise, queues the task. Returns `(task_id, status)`.
179    pub fn submit(
180        &mut self,
181        agent: CompactString,
182        message: String,
183        created_by: CompactString,
184        parent_id: Option<u64>,
185        registry: Arc<Mutex<TaskRegistry>>,
186    ) -> (u64, TaskStatus) {
187        let under_limit = self.active_count() < self.max_concurrent;
188        let initial_status = if under_limit {
189            TaskStatus::InProgress
190        } else {
191            TaskStatus::Queued
192        };
193
194        let task_id = self.create(
195            agent.clone(),
196            message.clone(),
197            created_by,
198            parent_id,
199            initial_status,
200            true,
201        );
202
203        if under_limit {
204            self.dispatch_task(task_id, agent, message, registry);
205        }
206
207        (task_id, initial_status)
208    }
209
210    /// Dispatch a task: send the message via event channel and spawn a watcher.
211    fn dispatch_task(
212        &mut self,
213        task_id: u64,
214        agent: CompactString,
215        message: String,
216        registry: Arc<Mutex<TaskRegistry>>,
217    ) {
218        let (reply_tx, reply_rx) = mpsc::unbounded_channel();
219        let msg = ClientMessage::Send {
220            agent,
221            content: message,
222            session: None,
223            sender: None,
224        };
225        let _ = self.event_tx.send(DaemonEvent::Message {
226            msg,
227            reply: reply_tx,
228        });
229
230        let event_tx = self.event_tx.clone();
231        let timeout = self.task_timeout;
232        let handle = tokio::spawn(task_watcher(task_id, reply_rx, registry, event_tx, timeout));
233        if let Some(task) = self.tasks.get_mut(&task_id) {
234            task.abort_handle = Some(handle.abort_handle());
235        }
236    }
237
238    /// Mark a task as Finished or Failed, then promote the next queued task.
239    pub fn complete(
240        &mut self,
241        task_id: u64,
242        result: Option<String>,
243        error: Option<String>,
244        registry: Arc<Mutex<TaskRegistry>>,
245    ) {
246        if let Some(task) = self.tasks.get_mut(&task_id) {
247            if error.is_some() {
248                task.status = TaskStatus::Failed;
249                task.error = error.clone();
250                let _ = task.status_tx.send(TaskStatus::Failed);
251                let _ = self.task_broadcast.send(TaskEvent::Completed {
252                    task_id,
253                    status: TaskStatus::Failed.to_string(),
254                    result: None,
255                    error,
256                });
257            } else {
258                task.status = TaskStatus::Finished;
259                task.result = result.clone();
260                let _ = task.status_tx.send(TaskStatus::Finished);
261                let _ = self.task_broadcast.send(TaskEvent::Completed {
262                    task_id,
263                    status: TaskStatus::Finished.to_string(),
264                    result,
265                    error: None,
266                });
267            }
268        }
269        self.promote_next(registry);
270    }
271
272    /// Promote the next queued task to InProgress if a slot is available.
273    pub fn promote_next(&mut self, registry: Arc<Mutex<TaskRegistry>>) {
274        if self.active_count() >= self.max_concurrent {
275            return;
276        }
277        // Find the oldest queued task.
278        let next = self
279            .tasks
280            .values()
281            .find(|t| t.status == TaskStatus::Queued)
282            .map(|t| (t.id, t.agent.clone(), t.description.clone()));
283
284        if let Some((id, agent, message)) = next {
285            self.set_status(id, TaskStatus::InProgress);
286            self.dispatch_task(id, agent, message, registry);
287        }
288    }
289
290    /// Block a task, setting status to Blocked and storing the inbox item.
291    ///
292    /// Returns a receiver that the tool call can await for the user's response.
293    pub fn block(&mut self, task_id: u64, question: String) -> Option<oneshot::Receiver<String>> {
294        let task = self.tasks.get_mut(&task_id)?;
295        let (tx, rx) = oneshot::channel();
296        task.blocked_on = Some(InboxItem {
297            question,
298            reply: tx,
299        });
300        task.status = TaskStatus::Blocked;
301        let _ = task.status_tx.send(TaskStatus::Blocked);
302        let _ = self.task_broadcast.send(TaskEvent::StatusChanged {
303            task_id,
304            status: TaskStatus::Blocked.to_string(),
305            blocked_on: task.blocked_on.as_ref().map(|i| i.question.clone()),
306        });
307        Some(rx)
308    }
309
310    /// Approve a blocked task, sending the response and resuming execution.
311    pub fn approve(&mut self, task_id: u64, response: String) -> bool {
312        let Some(task) = self.tasks.get_mut(&task_id) else {
313            return false;
314        };
315        if task.status != TaskStatus::Blocked {
316            return false;
317        }
318        if let Some(inbox) = task.blocked_on.take() {
319            let _ = inbox.reply.send(response);
320        }
321        task.status = TaskStatus::InProgress;
322        let _ = task.status_tx.send(TaskStatus::InProgress);
323        let _ = self.task_broadcast.send(TaskEvent::StatusChanged {
324            task_id,
325            status: TaskStatus::InProgress.to_string(),
326            blocked_on: None,
327        });
328        true
329    }
330
331    /// Subscribe to a task's status changes (for await_tasks).
332    pub fn subscribe_status(&self, task_id: u64) -> Option<watch::Receiver<TaskStatus>> {
333        self.tasks.get(&task_id).map(|t| t.status_tx.subscribe())
334    }
335
336    /// Get all child tasks of a given parent.
337    pub fn children(&self, parent_id: u64) -> Vec<&Task> {
338        self.tasks
339            .values()
340            .filter(|t| t.parent_id == Some(parent_id))
341            .collect()
342    }
343
344    /// Find a task by its session ID. Returns the task ID.
345    pub fn find_by_session(&self, session_id: u64) -> Option<u64> {
346        self.tasks
347            .values()
348            .find(|t| t.session_id == Some(session_id))
349            .map(|t| t.id)
350    }
351
352    /// Add token usage to a task.
353    pub fn add_tokens(&mut self, task_id: u64, prompt: u64, completion: u64) {
354        if let Some(task) = self.tasks.get_mut(&task_id) {
355            task.prompt_tokens += prompt;
356            task.completion_tokens += completion;
357        }
358    }
359
360    /// Collect queued `create_task` entries grouped by agent.
361    ///
362    /// Returns `(agent, [(task_id, description)])` pairs, capped at
363    /// `max_concurrent` tasks per agent to avoid context overflow.
364    pub fn queued_create_tasks(&self) -> BTreeMap<CompactString, Vec<(u64, String)>> {
365        let mut groups: BTreeMap<CompactString, Vec<(u64, String)>> = BTreeMap::new();
366        for task in self.tasks.values() {
367            if task.status == TaskStatus::Queued && !task.spawned {
368                let entry = groups.entry(task.agent.clone()).or_default();
369                if entry.len() < self.max_concurrent {
370                    entry.push((task.id, task.description.clone()));
371                }
372            }
373        }
374        groups
375    }
376
377    /// Collect queued `create_task` entries for a single agent, capped at
378    /// `max_concurrent`.
379    pub fn queued_create_tasks_for(&self, agent: &str) -> Vec<(u64, String)> {
380        let mut entries = Vec::new();
381        for task in self.tasks.values() {
382            if task.status == TaskStatus::Queued && !task.spawned && task.agent == agent {
383                entries.push((task.id, task.description.clone()));
384                if entries.len() >= self.max_concurrent {
385                    break;
386                }
387            }
388        }
389        entries
390    }
391}
392
393/// Watcher task: awaits reply messages with timeout, closes session, completes task.
394async fn task_watcher(
395    task_id: u64,
396    mut reply_rx: mpsc::UnboundedReceiver<ServerMessage>,
397    registry: Arc<Mutex<TaskRegistry>>,
398    event_tx: DaemonEventSender,
399    timeout: Duration,
400) {
401    let mut result_content: Option<String> = None;
402    let mut error_msg: Option<String> = None;
403    let mut session_id: Option<u64> = None;
404
405    let collect = async {
406        while let Some(msg) = reply_rx.recv().await {
407            match msg {
408                ServerMessage::Response(resp) => {
409                    session_id = Some(resp.session);
410                    result_content = Some(resp.content);
411                }
412                ServerMessage::Error { message, .. } => {
413                    error_msg = Some(message);
414                }
415                _ => {}
416            }
417        }
418    };
419
420    if tokio::time::timeout(timeout, collect).await.is_err() {
421        error_msg = Some("task timed out".into());
422    }
423
424    // Close the session to prevent accumulation.
425    if let Some(sid) = session_id {
426        let (reply_tx, _reply_rx) = mpsc::unbounded_channel();
427        let _ = event_tx.send(DaemonEvent::Message {
428            msg: ClientMessage::Kill { session: sid },
429            reply: reply_tx,
430        });
431    }
432
433    // Complete the task, auto-close sub-task sessions, and promote next queued.
434    let reg = registry.clone();
435    let mut locked = registry.lock().await;
436    // Collect finished sub-task session IDs for auto-close.
437    let child_sessions: Vec<u64> = locked
438        .children(task_id)
439        .iter()
440        .filter(|t| t.status == TaskStatus::Finished || t.status == TaskStatus::Failed)
441        .filter_map(|t| t.session_id)
442        .collect();
443    locked.complete(task_id, result_content, error_msg, reg);
444    drop(locked);
445
446    // Auto-close finished sub-task sessions outside the lock.
447    for sid in child_sessions {
448        let (reply_tx, _) = mpsc::unbounded_channel();
449        let _ = event_tx.send(DaemonEvent::Message {
450            msg: ClientMessage::Kill { session: sid },
451            reply: reply_tx,
452        });
453    }
454}