1use crate::daemon::event::{DaemonEvent, DaemonEventSender};
4use crate::hook::task::{InboxItem, Task, TaskStatus};
5use compact_str::CompactString;
6use std::{
7 collections::BTreeMap,
8 sync::{
9 Arc,
10 atomic::{AtomicU64, Ordering},
11 },
12 time::Duration,
13};
14use tokio::sync::{Mutex, mpsc, oneshot, watch};
15use tokio::time::Instant;
16use wcore::protocol::message::{client::ClientMessage, server::ServerMessage};
17
18pub struct TaskRegistry {
20 tasks: BTreeMap<u64, Task>,
21 next_id: AtomicU64,
22 pub max_concurrent: usize,
24 pub viewable_window: usize,
26 pub task_timeout: Duration,
28 pub event_tx: DaemonEventSender,
30}
31
32impl TaskRegistry {
33 pub fn new(
35 max_concurrent: usize,
36 viewable_window: usize,
37 task_timeout: Duration,
38 event_tx: DaemonEventSender,
39 ) -> Self {
40 Self {
41 tasks: BTreeMap::new(),
42 next_id: AtomicU64::new(1),
43 max_concurrent,
44 viewable_window,
45 task_timeout,
46 event_tx,
47 }
48 }
49
50 pub fn create(
52 &mut self,
53 agent: CompactString,
54 description: String,
55 created_by: CompactString,
56 parent_id: Option<u64>,
57 status: TaskStatus,
58 spawned: bool,
59 ) -> u64 {
60 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
61 let (status_tx, _) = watch::channel(status);
62 let task = Task {
63 id,
64 parent_id,
65 session_id: None,
66 agent,
67 status,
68 created_by,
69 description,
70 result: None,
71 error: None,
72 blocked_on: None,
73 prompt_tokens: 0,
74 completion_tokens: 0,
75 created_at: Instant::now(),
76 abort_handle: None,
77 spawned,
78 status_tx,
79 };
80 self.tasks.insert(id, task);
81 id
82 }
83
84 pub fn get(&self, id: u64) -> Option<&Task> {
86 self.tasks.get(&id)
87 }
88
89 pub fn get_mut(&mut self, id: u64) -> Option<&mut Task> {
91 self.tasks.get_mut(&id)
92 }
93
94 pub fn set_status(&mut self, id: u64, status: TaskStatus) {
96 if let Some(task) = self.tasks.get_mut(&id) {
97 task.status = status;
98 let _ = task.status_tx.send(status);
99 }
100 }
101
102 pub fn remove(&mut self, id: u64) -> Option<Task> {
104 self.tasks.remove(&id)
105 }
106
107 pub fn list(
111 &self,
112 agent: Option<&str>,
113 status: Option<TaskStatus>,
114 parent_id: Option<Option<u64>>,
115 ) -> Vec<&Task> {
116 self.tasks
117 .values()
118 .rev()
119 .filter(|t| agent.is_none_or(|a| t.agent == a))
120 .filter(|t| status.is_none_or(|s| t.status == s))
121 .filter(|t| parent_id.is_none_or(|p| t.parent_id == p))
122 .take(self.viewable_window)
123 .collect()
124 }
125
126 pub fn active_count(&self) -> usize {
128 self.tasks
129 .values()
130 .filter(|t| t.status == TaskStatus::InProgress)
131 .count()
132 }
133
134 pub fn submit(
139 &mut self,
140 agent: CompactString,
141 message: String,
142 created_by: CompactString,
143 parent_id: Option<u64>,
144 registry: Arc<Mutex<TaskRegistry>>,
145 ) -> (u64, TaskStatus) {
146 let under_limit = self.active_count() < self.max_concurrent;
147 let initial_status = if under_limit {
148 TaskStatus::InProgress
149 } else {
150 TaskStatus::Queued
151 };
152
153 let task_id = self.create(
154 agent.clone(),
155 message.clone(),
156 created_by,
157 parent_id,
158 initial_status,
159 true,
160 );
161
162 if under_limit {
163 self.dispatch_task(task_id, agent, message, registry);
164 }
165
166 (task_id, initial_status)
167 }
168
169 fn dispatch_task(
171 &mut self,
172 task_id: u64,
173 agent: CompactString,
174 message: String,
175 registry: Arc<Mutex<TaskRegistry>>,
176 ) {
177 let (reply_tx, reply_rx) = mpsc::unbounded_channel();
178 let msg = ClientMessage::Send {
179 agent,
180 content: message,
181 session: None,
182 sender: None,
183 };
184 let _ = self.event_tx.send(DaemonEvent::Message {
185 msg,
186 reply: reply_tx,
187 });
188
189 let event_tx = self.event_tx.clone();
190 let timeout = self.task_timeout;
191 let handle = tokio::spawn(task_watcher(task_id, reply_rx, registry, event_tx, timeout));
192 if let Some(task) = self.tasks.get_mut(&task_id) {
193 task.abort_handle = Some(handle.abort_handle());
194 }
195 }
196
197 pub fn complete(
199 &mut self,
200 task_id: u64,
201 result: Option<String>,
202 error: Option<String>,
203 registry: Arc<Mutex<TaskRegistry>>,
204 ) {
205 if let Some(task) = self.tasks.get_mut(&task_id) {
206 if error.is_some() {
207 task.status = TaskStatus::Failed;
208 task.error = error;
209 let _ = task.status_tx.send(TaskStatus::Failed);
210 } else {
211 task.status = TaskStatus::Finished;
212 task.result = result;
213 let _ = task.status_tx.send(TaskStatus::Finished);
214 }
215 }
216 self.promote_next(registry);
217 }
218
219 pub fn promote_next(&mut self, registry: Arc<Mutex<TaskRegistry>>) {
221 if self.active_count() >= self.max_concurrent {
222 return;
223 }
224 let next = self
226 .tasks
227 .values()
228 .find(|t| t.status == TaskStatus::Queued)
229 .map(|t| (t.id, t.agent.clone(), t.description.clone()));
230
231 if let Some((id, agent, message)) = next {
232 self.set_status(id, TaskStatus::InProgress);
233 self.dispatch_task(id, agent, message, registry);
234 }
235 }
236
237 pub fn block(&mut self, task_id: u64, question: String) -> Option<oneshot::Receiver<String>> {
241 let task = self.tasks.get_mut(&task_id)?;
242 let (tx, rx) = oneshot::channel();
243 task.blocked_on = Some(InboxItem {
244 question,
245 reply: tx,
246 });
247 task.status = TaskStatus::Blocked;
248 let _ = task.status_tx.send(TaskStatus::Blocked);
249 Some(rx)
250 }
251
252 pub fn approve(&mut self, task_id: u64, response: String) -> bool {
254 let Some(task) = self.tasks.get_mut(&task_id) else {
255 return false;
256 };
257 if task.status != TaskStatus::Blocked {
258 return false;
259 }
260 if let Some(inbox) = task.blocked_on.take() {
261 let _ = inbox.reply.send(response);
262 }
263 task.status = TaskStatus::InProgress;
264 let _ = task.status_tx.send(TaskStatus::InProgress);
265 true
266 }
267
268 pub fn subscribe_status(&self, task_id: u64) -> Option<watch::Receiver<TaskStatus>> {
270 self.tasks.get(&task_id).map(|t| t.status_tx.subscribe())
271 }
272
273 pub fn children(&self, parent_id: u64) -> Vec<&Task> {
275 self.tasks
276 .values()
277 .filter(|t| t.parent_id == Some(parent_id))
278 .collect()
279 }
280
281 pub fn find_by_session(&self, session_id: u64) -> Option<u64> {
283 self.tasks
284 .values()
285 .find(|t| t.session_id == Some(session_id))
286 .map(|t| t.id)
287 }
288
289 pub fn add_tokens(&mut self, task_id: u64, prompt: u64, completion: u64) {
291 if let Some(task) = self.tasks.get_mut(&task_id) {
292 task.prompt_tokens += prompt;
293 task.completion_tokens += completion;
294 }
295 }
296
297 pub fn queued_create_tasks(&self) -> BTreeMap<CompactString, Vec<(u64, String)>> {
302 let mut groups: BTreeMap<CompactString, Vec<(u64, String)>> = BTreeMap::new();
303 for task in self.tasks.values() {
304 if task.status == TaskStatus::Queued && !task.spawned {
305 let entry = groups.entry(task.agent.clone()).or_default();
306 if entry.len() < self.max_concurrent {
307 entry.push((task.id, task.description.clone()));
308 }
309 }
310 }
311 groups
312 }
313
314 pub fn queued_create_tasks_for(&self, agent: &str) -> Vec<(u64, String)> {
317 let mut entries = Vec::new();
318 for task in self.tasks.values() {
319 if task.status == TaskStatus::Queued && !task.spawned && task.agent == agent {
320 entries.push((task.id, task.description.clone()));
321 if entries.len() >= self.max_concurrent {
322 break;
323 }
324 }
325 }
326 entries
327 }
328}
329
330async fn task_watcher(
332 task_id: u64,
333 mut reply_rx: mpsc::UnboundedReceiver<ServerMessage>,
334 registry: Arc<Mutex<TaskRegistry>>,
335 event_tx: DaemonEventSender,
336 timeout: Duration,
337) {
338 let mut result_content: Option<String> = None;
339 let mut error_msg: Option<String> = None;
340 let mut session_id: Option<u64> = None;
341
342 let collect = async {
343 while let Some(msg) = reply_rx.recv().await {
344 match msg {
345 ServerMessage::Response(resp) => {
346 session_id = Some(resp.session);
347 result_content = Some(resp.content);
348 }
349 ServerMessage::Error { message, .. } => {
350 error_msg = Some(message);
351 }
352 _ => {}
353 }
354 }
355 };
356
357 if tokio::time::timeout(timeout, collect).await.is_err() {
358 error_msg = Some("task timed out".into());
359 }
360
361 if let Some(sid) = session_id {
363 let (reply_tx, _reply_rx) = mpsc::unbounded_channel();
364 let _ = event_tx.send(DaemonEvent::Message {
365 msg: ClientMessage::Kill { session: sid },
366 reply: reply_tx,
367 });
368 }
369
370 let reg = registry.clone();
372 let mut locked = registry.lock().await;
373 let child_sessions: Vec<u64> = locked
375 .children(task_id)
376 .iter()
377 .filter(|t| t.status == TaskStatus::Finished || t.status == TaskStatus::Failed)
378 .filter_map(|t| t.session_id)
379 .collect();
380 locked.complete(task_id, result_content, error_msg, reg);
381 drop(locked);
382
383 for sid in child_sessions {
385 let (reply_tx, _) = mpsc::unbounded_channel();
386 let _ = event_tx.send(DaemonEvent::Message {
387 msg: ClientMessage::Kill { session: sid },
388 reply: reply_tx,
389 });
390 }
391}