Skip to main content

walrus_daemon/hook/task/
registry.rs

1//! Task registry — concurrency control, dispatch, and lifecycle.
2
3use crate::daemon::event::{DaemonEvent, DaemonEventSender};
4use crate::hook::task::{InboxItem, Task, TaskStatus};
5use compact_str::CompactString;
6use std::{
7    collections::BTreeMap,
8    sync::{
9        Arc,
10        atomic::{AtomicU64, Ordering},
11    },
12    time::Duration,
13};
14use tokio::sync::{Mutex, mpsc, oneshot, watch};
15use tokio::time::Instant;
16use wcore::protocol::message::{client::ClientMessage, server::ServerMessage};
17
18/// In-memory task registry with concurrency control.
19pub struct TaskRegistry {
20    tasks: BTreeMap<u64, Task>,
21    next_id: AtomicU64,
22    /// Maximum number of concurrently InProgress tasks.
23    pub max_concurrent: usize,
24    /// Maximum number of tasks returned by `list()`.
25    pub viewable_window: usize,
26    /// Per-task execution timeout.
27    pub task_timeout: Duration,
28    /// Event channel for dispatching task execution.
29    pub event_tx: DaemonEventSender,
30}
31
32impl TaskRegistry {
33    /// Create a new registry with the given config and event sender.
34    pub fn new(
35        max_concurrent: usize,
36        viewable_window: usize,
37        task_timeout: Duration,
38        event_tx: DaemonEventSender,
39    ) -> Self {
40        Self {
41            tasks: BTreeMap::new(),
42            next_id: AtomicU64::new(1),
43            max_concurrent,
44            viewable_window,
45            task_timeout,
46            event_tx,
47        }
48    }
49
50    /// Create a new task and insert it into the registry.
51    pub fn create(
52        &mut self,
53        agent: CompactString,
54        description: String,
55        created_by: CompactString,
56        parent_id: Option<u64>,
57        status: TaskStatus,
58        spawned: bool,
59    ) -> u64 {
60        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
61        let (status_tx, _) = watch::channel(status);
62        let task = Task {
63            id,
64            parent_id,
65            session_id: None,
66            agent,
67            status,
68            created_by,
69            description,
70            result: None,
71            error: None,
72            blocked_on: None,
73            prompt_tokens: 0,
74            completion_tokens: 0,
75            created_at: Instant::now(),
76            abort_handle: None,
77            spawned,
78            status_tx,
79        };
80        self.tasks.insert(id, task);
81        id
82    }
83
84    /// Get a reference to a task by ID.
85    pub fn get(&self, id: u64) -> Option<&Task> {
86        self.tasks.get(&id)
87    }
88
89    /// Get a mutable reference to a task by ID.
90    pub fn get_mut(&mut self, id: u64) -> Option<&mut Task> {
91        self.tasks.get_mut(&id)
92    }
93
94    /// Update task status and notify watchers.
95    pub fn set_status(&mut self, id: u64, status: TaskStatus) {
96        if let Some(task) = self.tasks.get_mut(&id) {
97            task.status = status;
98            let _ = task.status_tx.send(status);
99        }
100    }
101
102    /// Remove a task from the registry.
103    pub fn remove(&mut self, id: u64) -> Option<Task> {
104        self.tasks.remove(&id)
105    }
106
107    /// List tasks, most recent first, up to `viewable_window` entries.
108    ///
109    /// Optionally filters by agent, status, or parent_id.
110    pub fn list(
111        &self,
112        agent: Option<&str>,
113        status: Option<TaskStatus>,
114        parent_id: Option<Option<u64>>,
115    ) -> Vec<&Task> {
116        self.tasks
117            .values()
118            .rev()
119            .filter(|t| agent.is_none_or(|a| t.agent == a))
120            .filter(|t| status.is_none_or(|s| t.status == s))
121            .filter(|t| parent_id.is_none_or(|p| t.parent_id == p))
122            .take(self.viewable_window)
123            .collect()
124    }
125
126    /// Count of currently InProgress tasks (not Blocked).
127    pub fn active_count(&self) -> usize {
128        self.tasks
129            .values()
130            .filter(|t| t.status == TaskStatus::InProgress)
131            .count()
132    }
133
134    /// Submit a task for execution.
135    ///
136    /// If under the concurrency limit, dispatches immediately and spawns a
137    /// watcher. Otherwise, queues the task. Returns `(task_id, status)`.
138    pub fn submit(
139        &mut self,
140        agent: CompactString,
141        message: String,
142        created_by: CompactString,
143        parent_id: Option<u64>,
144        registry: Arc<Mutex<TaskRegistry>>,
145    ) -> (u64, TaskStatus) {
146        let under_limit = self.active_count() < self.max_concurrent;
147        let initial_status = if under_limit {
148            TaskStatus::InProgress
149        } else {
150            TaskStatus::Queued
151        };
152
153        let task_id = self.create(
154            agent.clone(),
155            message.clone(),
156            created_by,
157            parent_id,
158            initial_status,
159            true,
160        );
161
162        if under_limit {
163            self.dispatch_task(task_id, agent, message, registry);
164        }
165
166        (task_id, initial_status)
167    }
168
169    /// Dispatch a task: send the message via event channel and spawn a watcher.
170    fn dispatch_task(
171        &mut self,
172        task_id: u64,
173        agent: CompactString,
174        message: String,
175        registry: Arc<Mutex<TaskRegistry>>,
176    ) {
177        let (reply_tx, reply_rx) = mpsc::unbounded_channel();
178        let msg = ClientMessage::Send {
179            agent,
180            content: message,
181            session: None,
182            sender: None,
183        };
184        let _ = self.event_tx.send(DaemonEvent::Message {
185            msg,
186            reply: reply_tx,
187        });
188
189        let event_tx = self.event_tx.clone();
190        let timeout = self.task_timeout;
191        let handle = tokio::spawn(task_watcher(task_id, reply_rx, registry, event_tx, timeout));
192        if let Some(task) = self.tasks.get_mut(&task_id) {
193            task.abort_handle = Some(handle.abort_handle());
194        }
195    }
196
197    /// Mark a task as Finished or Failed, then promote the next queued task.
198    pub fn complete(
199        &mut self,
200        task_id: u64,
201        result: Option<String>,
202        error: Option<String>,
203        registry: Arc<Mutex<TaskRegistry>>,
204    ) {
205        if let Some(task) = self.tasks.get_mut(&task_id) {
206            if error.is_some() {
207                task.status = TaskStatus::Failed;
208                task.error = error;
209                let _ = task.status_tx.send(TaskStatus::Failed);
210            } else {
211                task.status = TaskStatus::Finished;
212                task.result = result;
213                let _ = task.status_tx.send(TaskStatus::Finished);
214            }
215        }
216        self.promote_next(registry);
217    }
218
219    /// Promote the next queued task to InProgress if a slot is available.
220    pub fn promote_next(&mut self, registry: Arc<Mutex<TaskRegistry>>) {
221        if self.active_count() >= self.max_concurrent {
222            return;
223        }
224        // Find the oldest queued task.
225        let next = self
226            .tasks
227            .values()
228            .find(|t| t.status == TaskStatus::Queued)
229            .map(|t| (t.id, t.agent.clone(), t.description.clone()));
230
231        if let Some((id, agent, message)) = next {
232            self.set_status(id, TaskStatus::InProgress);
233            self.dispatch_task(id, agent, message, registry);
234        }
235    }
236
237    /// Block a task, setting status to Blocked and storing the inbox item.
238    ///
239    /// Returns a receiver that the tool call can await for the user's response.
240    pub fn block(&mut self, task_id: u64, question: String) -> Option<oneshot::Receiver<String>> {
241        let task = self.tasks.get_mut(&task_id)?;
242        let (tx, rx) = oneshot::channel();
243        task.blocked_on = Some(InboxItem {
244            question,
245            reply: tx,
246        });
247        task.status = TaskStatus::Blocked;
248        let _ = task.status_tx.send(TaskStatus::Blocked);
249        Some(rx)
250    }
251
252    /// Approve a blocked task, sending the response and resuming execution.
253    pub fn approve(&mut self, task_id: u64, response: String) -> bool {
254        let Some(task) = self.tasks.get_mut(&task_id) else {
255            return false;
256        };
257        if task.status != TaskStatus::Blocked {
258            return false;
259        }
260        if let Some(inbox) = task.blocked_on.take() {
261            let _ = inbox.reply.send(response);
262        }
263        task.status = TaskStatus::InProgress;
264        let _ = task.status_tx.send(TaskStatus::InProgress);
265        true
266    }
267
268    /// Subscribe to a task's status changes (for await_tasks).
269    pub fn subscribe_status(&self, task_id: u64) -> Option<watch::Receiver<TaskStatus>> {
270        self.tasks.get(&task_id).map(|t| t.status_tx.subscribe())
271    }
272
273    /// Get all child tasks of a given parent.
274    pub fn children(&self, parent_id: u64) -> Vec<&Task> {
275        self.tasks
276            .values()
277            .filter(|t| t.parent_id == Some(parent_id))
278            .collect()
279    }
280
281    /// Find a task by its session ID. Returns the task ID.
282    pub fn find_by_session(&self, session_id: u64) -> Option<u64> {
283        self.tasks
284            .values()
285            .find(|t| t.session_id == Some(session_id))
286            .map(|t| t.id)
287    }
288
289    /// Add token usage to a task.
290    pub fn add_tokens(&mut self, task_id: u64, prompt: u64, completion: u64) {
291        if let Some(task) = self.tasks.get_mut(&task_id) {
292            task.prompt_tokens += prompt;
293            task.completion_tokens += completion;
294        }
295    }
296
297    /// Collect queued `create_task` entries grouped by agent.
298    ///
299    /// Returns `(agent, [(task_id, description)])` pairs, capped at
300    /// `max_concurrent` tasks per agent to avoid context overflow.
301    pub fn queued_create_tasks(&self) -> BTreeMap<CompactString, Vec<(u64, String)>> {
302        let mut groups: BTreeMap<CompactString, Vec<(u64, String)>> = BTreeMap::new();
303        for task in self.tasks.values() {
304            if task.status == TaskStatus::Queued && !task.spawned {
305                let entry = groups.entry(task.agent.clone()).or_default();
306                if entry.len() < self.max_concurrent {
307                    entry.push((task.id, task.description.clone()));
308                }
309            }
310        }
311        groups
312    }
313
314    /// Collect queued `create_task` entries for a single agent, capped at
315    /// `max_concurrent`.
316    pub fn queued_create_tasks_for(&self, agent: &str) -> Vec<(u64, String)> {
317        let mut entries = Vec::new();
318        for task in self.tasks.values() {
319            if task.status == TaskStatus::Queued && !task.spawned && task.agent == agent {
320                entries.push((task.id, task.description.clone()));
321                if entries.len() >= self.max_concurrent {
322                    break;
323                }
324            }
325        }
326        entries
327    }
328}
329
330/// Watcher task: awaits reply messages with timeout, closes session, completes task.
331async fn task_watcher(
332    task_id: u64,
333    mut reply_rx: mpsc::UnboundedReceiver<ServerMessage>,
334    registry: Arc<Mutex<TaskRegistry>>,
335    event_tx: DaemonEventSender,
336    timeout: Duration,
337) {
338    let mut result_content: Option<String> = None;
339    let mut error_msg: Option<String> = None;
340    let mut session_id: Option<u64> = None;
341
342    let collect = async {
343        while let Some(msg) = reply_rx.recv().await {
344            match msg {
345                ServerMessage::Response(resp) => {
346                    session_id = Some(resp.session);
347                    result_content = Some(resp.content);
348                }
349                ServerMessage::Error { message, .. } => {
350                    error_msg = Some(message);
351                }
352                _ => {}
353            }
354        }
355    };
356
357    if tokio::time::timeout(timeout, collect).await.is_err() {
358        error_msg = Some("task timed out".into());
359    }
360
361    // Close the session to prevent accumulation.
362    if let Some(sid) = session_id {
363        let (reply_tx, _reply_rx) = mpsc::unbounded_channel();
364        let _ = event_tx.send(DaemonEvent::Message {
365            msg: ClientMessage::Kill { session: sid },
366            reply: reply_tx,
367        });
368    }
369
370    // Complete the task, auto-close sub-task sessions, and promote next queued.
371    let reg = registry.clone();
372    let mut locked = registry.lock().await;
373    // Collect finished sub-task session IDs for auto-close.
374    let child_sessions: Vec<u64> = locked
375        .children(task_id)
376        .iter()
377        .filter(|t| t.status == TaskStatus::Finished || t.status == TaskStatus::Failed)
378        .filter_map(|t| t.session_id)
379        .collect();
380    locked.complete(task_id, result_content, error_msg, reg);
381    drop(locked);
382
383    // Auto-close finished sub-task sessions outside the lock.
384    for sid in child_sessions {
385        let (reply_tx, _) = mpsc::unbounded_channel();
386        let _ = event_tx.send(DaemonEvent::Message {
387            msg: ClientMessage::Kill { session: sid },
388            reply: reply_tx,
389        });
390    }
391}