Skip to main content

walrus_daemon/hook/system/task/
registry.rs

1//! Task registry — concurrency control and lifecycle.
2//!
3//! Pure data structure: no dispatch or spawning. Callers (hook, event loop)
4//! own task execution; the registry just tracks state and broadcasts events.
5
6use crate::hook::system::task::{InboxItem, Task, TaskStatus};
7use compact_str::CompactString;
8use std::{
9    collections::BTreeMap,
10    sync::atomic::{AtomicU64, Ordering},
11};
12use tokio::sync::{broadcast, oneshot, watch};
13use tokio::time::Instant;
14use wcore::protocol::message::{
15    TaskCompleted, TaskCreated, TaskEvent, TaskInfo, TaskStatusChanged, task_event,
16};
17
18/// In-memory task registry with concurrency control.
19pub struct TaskRegistry {
20    tasks: BTreeMap<u64, Task>,
21    next_id: AtomicU64,
22    /// Maximum number of concurrently InProgress tasks.
23    pub max_concurrent: usize,
24    /// Maximum number of tasks returned by `list()`.
25    pub viewable_window: usize,
26    /// Broadcast channel for task lifecycle events (subscriptions).
27    task_broadcast: broadcast::Sender<TaskEvent>,
28}
29
30impl TaskRegistry {
31    /// Create a new registry with the given config.
32    pub fn new(max_concurrent: usize, viewable_window: usize) -> Self {
33        let (task_broadcast, _) = broadcast::channel(64);
34        Self {
35            tasks: BTreeMap::new(),
36            next_id: AtomicU64::new(1),
37            max_concurrent,
38            viewable_window,
39            task_broadcast,
40        }
41    }
42
43    /// Subscribe to task lifecycle events.
44    pub fn subscribe(&self) -> broadcast::Receiver<TaskEvent> {
45        self.task_broadcast.subscribe()
46    }
47
48    /// Build a `TaskInfo` snapshot from an internal `Task`.
49    pub fn task_info(task: &Task) -> TaskInfo {
50        TaskInfo {
51            id: task.id,
52            parent_id: task.parent_id,
53            agent: task.agent.to_string(),
54            status: task.status.to_string(),
55            description: task.description.clone(),
56            result: task.result.clone(),
57            error: task.error.clone(),
58            created_by: task.created_by.to_string(),
59            prompt_tokens: task.prompt_tokens,
60            completion_tokens: task.completion_tokens,
61            alive_secs: task.created_at.elapsed().as_secs(),
62            blocked_on: task.blocked_on.as_ref().map(|i| i.question.clone()),
63        }
64    }
65
66    /// Create a new task and insert it into the registry. Returns task ID.
67    pub fn create(
68        &mut self,
69        agent: CompactString,
70        description: String,
71        created_by: CompactString,
72        parent_id: Option<u64>,
73        status: TaskStatus,
74    ) -> u64 {
75        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
76        let (status_tx, _) = watch::channel(status);
77        let task = Task {
78            id,
79            parent_id,
80            session_id: None,
81            agent,
82            status,
83            created_by,
84            description,
85            result: None,
86            error: None,
87            blocked_on: None,
88            prompt_tokens: 0,
89            completion_tokens: 0,
90            created_at: Instant::now(),
91            abort_handle: None,
92            status_tx,
93        };
94        self.tasks.insert(id, task);
95        if let Some(t) = self.tasks.get(&id) {
96            let _ = self.task_broadcast.send(TaskEvent {
97                event: Some(task_event::Event::Created(TaskCreated {
98                    task: Some(Self::task_info(t)),
99                })),
100            });
101        }
102        id
103    }
104
105    /// Get a reference to a task by ID.
106    pub fn get(&self, id: u64) -> Option<&Task> {
107        self.tasks.get(&id)
108    }
109
110    /// Get a mutable reference to a task by ID.
111    pub fn get_mut(&mut self, id: u64) -> Option<&mut Task> {
112        self.tasks.get_mut(&id)
113    }
114
115    /// Update task status and notify all watchers (watch + broadcast).
116    ///
117    /// This is the **single path** for all status transitions.
118    pub fn set_status(&mut self, id: u64, status: TaskStatus) {
119        if let Some(task) = self.tasks.get_mut(&id) {
120            task.status = status;
121            let _ = task.status_tx.send(status);
122            let _ = self.task_broadcast.send(TaskEvent {
123                event: Some(task_event::Event::StatusChanged(TaskStatusChanged {
124                    task_id: id,
125                    status: status.to_string(),
126                    blocked_on: task.blocked_on.as_ref().map(|i| i.question.clone()),
127                })),
128            });
129        }
130    }
131
132    /// Remove a task from the registry.
133    pub fn remove(&mut self, id: u64) -> Option<Task> {
134        self.tasks.remove(&id)
135    }
136
137    /// List tasks, most recent first, up to `viewable_window` entries.
138    pub fn list(
139        &self,
140        agent: Option<&str>,
141        status: Option<TaskStatus>,
142        parent_id: Option<Option<u64>>,
143    ) -> Vec<&Task> {
144        self.tasks
145            .values()
146            .rev()
147            .filter(|t| agent.is_none_or(|a| t.agent == a))
148            .filter(|t| status.is_none_or(|s| t.status == s))
149            .filter(|t| parent_id.is_none_or(|p| t.parent_id == p))
150            .take(self.viewable_window)
151            .collect()
152    }
153
154    /// Count of currently InProgress tasks (not Blocked).
155    pub fn active_count(&self) -> usize {
156        self.tasks
157            .values()
158            .filter(|t| t.status == TaskStatus::InProgress)
159            .count()
160    }
161
162    /// Whether a new task can be dispatched immediately.
163    pub fn has_slot(&self) -> bool {
164        self.active_count() < self.max_concurrent
165    }
166
167    /// Mark a task as Finished or Failed and broadcast a Completed event.
168    pub fn complete(&mut self, task_id: u64, result: Option<String>, error: Option<String>) {
169        let status = if error.is_some() {
170            TaskStatus::Failed
171        } else {
172            TaskStatus::Finished
173        };
174        if let Some(task) = self.tasks.get_mut(&task_id) {
175            task.result = result.clone();
176            task.error = error.clone();
177        }
178        self.set_status(task_id, status);
179        let _ = self.task_broadcast.send(TaskEvent {
180            event: Some(task_event::Event::Completed(TaskCompleted {
181                task_id,
182                status: status.to_string(),
183                result,
184                error,
185            })),
186        });
187    }
188
189    /// Find the next queued task and return its dispatch info, or `None`.
190    pub fn promote_next(&mut self) -> Option<(u64, CompactString, String)> {
191        if !self.has_slot() {
192            return None;
193        }
194        let next = self
195            .tasks
196            .values()
197            .find(|t| t.status == TaskStatus::Queued)
198            .map(|t| (t.id, t.agent.clone(), t.description.clone()));
199        if let Some((id, _, _)) = &next {
200            self.set_status(*id, TaskStatus::InProgress);
201        }
202        next
203    }
204
205    /// Block a task for user approval. Returns a receiver for the response.
206    pub fn block(&mut self, task_id: u64, question: String) -> Option<oneshot::Receiver<String>> {
207        let task = self.tasks.get_mut(&task_id)?;
208        let (tx, rx) = oneshot::channel();
209        task.blocked_on = Some(InboxItem {
210            question,
211            reply: tx,
212        });
213        self.set_status(task_id, TaskStatus::Blocked);
214        Some(rx)
215    }
216
217    /// Approve a blocked task, sending the response and resuming execution.
218    pub fn approve(&mut self, task_id: u64, response: String) -> bool {
219        let Some(task) = self.tasks.get_mut(&task_id) else {
220            return false;
221        };
222        if task.status != TaskStatus::Blocked {
223            return false;
224        }
225        if let Some(inbox) = task.blocked_on.take() {
226            let _ = inbox.reply.send(response);
227        }
228        self.set_status(task_id, TaskStatus::InProgress);
229        true
230    }
231
232    /// Kill a running or blocked task. Returns abort handle if it had one.
233    pub fn kill(&mut self, task_id: u64) -> Option<tokio::task::AbortHandle> {
234        let task = self.tasks.get_mut(&task_id)?;
235        let handle = task.abort_handle.take();
236        if let Some(ref h) = handle {
237            h.abort();
238        }
239        task.error = Some("killed by user".into());
240        self.set_status(task_id, TaskStatus::Failed);
241        handle
242    }
243
244    /// Subscribe to a task's status changes (for await_tasks).
245    pub fn subscribe_status(&self, task_id: u64) -> Option<watch::Receiver<TaskStatus>> {
246        self.tasks.get(&task_id).map(|t| t.status_tx.subscribe())
247    }
248
249    /// Get all child tasks of a given parent.
250    pub fn children(&self, parent_id: u64) -> Vec<&Task> {
251        self.tasks
252            .values()
253            .filter(|t| t.parent_id == Some(parent_id))
254            .collect()
255    }
256
257    /// Find a task by its session ID. Returns the task ID.
258    pub fn find_by_session(&self, session_id: u64) -> Option<u64> {
259        self.tasks
260            .values()
261            .find(|t| t.session_id == Some(session_id))
262            .map(|t| t.id)
263    }
264
265    /// Add token usage to a task.
266    pub fn add_tokens(&mut self, task_id: u64, prompt: u64, completion: u64) {
267        if let Some(task) = self.tasks.get_mut(&task_id) {
268            task.prompt_tokens += prompt;
269            task.completion_tokens += completion;
270        }
271    }
272}