1use crate::daemon::event::{DaemonEvent, DaemonEventSender};
4use crate::hook::task::{InboxItem, Task, TaskStatus};
5use compact_str::CompactString;
6use std::{
7 collections::BTreeMap,
8 sync::{
9 Arc,
10 atomic::{AtomicU64, Ordering},
11 },
12 time::Duration,
13};
14use tokio::sync::{Mutex, broadcast, mpsc, oneshot, watch};
15use tokio::time::Instant;
16use wcore::protocol::message::{
17 ClientMessage, KillMsg, SendMsg, ServerMessage, TaskCompleted, TaskCreated, TaskEvent,
18 TaskInfo, TaskStatusChanged, client_message, server_message, task_event,
19};
20
/// In-memory registry of agent tasks.
///
/// Owns every [`Task`], hands out task ids, decides whether a submitted task
/// runs immediately or is queued, and publishes lifecycle changes as
/// [`TaskEvent`]s on a broadcast channel.
pub struct TaskRegistry {
    /// All known tasks keyed by id. Ids are handed out in increasing order,
    /// so iterating the map visits tasks oldest-first.
    tasks: BTreeMap<u64, Task>,
    /// Id that will be assigned to the next created task (starts at 1).
    next_id: AtomicU64,
    /// Number of `InProgress` tasks at which `submit` starts queueing instead
    /// of dispatching; also caps the per-agent lists of queued created tasks.
    pub max_concurrent: usize,
    /// Maximum number of tasks returned by `list`.
    pub viewable_window: usize,
    /// How long a watcher collects replies for a dispatched task before the
    /// task is failed with a timeout error.
    pub task_timeout: Duration,
    /// Channel used to send `DaemonEvent::Message` requests to the daemon.
    pub event_tx: DaemonEventSender,
    /// Sender side of the task-event broadcast; receivers come from `subscribe`.
    task_broadcast: broadcast::Sender<TaskEvent>,
}
36
37impl TaskRegistry {
38 pub fn new(
40 max_concurrent: usize,
41 viewable_window: usize,
42 task_timeout: Duration,
43 event_tx: DaemonEventSender,
44 ) -> Self {
45 let (task_broadcast, _) = broadcast::channel(64);
46 Self {
47 tasks: BTreeMap::new(),
48 next_id: AtomicU64::new(1),
49 max_concurrent,
50 viewable_window,
51 task_timeout,
52 event_tx,
53 task_broadcast,
54 }
55 }
56
57 pub fn subscribe(&self) -> broadcast::Receiver<TaskEvent> {
59 self.task_broadcast.subscribe()
60 }
61
62 fn task_info(task: &Task) -> TaskInfo {
64 TaskInfo {
65 id: task.id,
66 parent_id: task.parent_id,
67 agent: task.agent.to_string(),
68 status: task.status.to_string(),
69 description: task.description.clone(),
70 result: task.result.clone(),
71 error: task.error.clone(),
72 created_by: task.created_by.to_string(),
73 prompt_tokens: task.prompt_tokens,
74 completion_tokens: task.completion_tokens,
75 alive_secs: task.created_at.elapsed().as_secs(),
76 blocked_on: task.blocked_on.as_ref().map(|i| i.question.clone()),
77 }
78 }
79
80 pub fn create(
82 &mut self,
83 agent: CompactString,
84 description: String,
85 created_by: CompactString,
86 parent_id: Option<u64>,
87 status: TaskStatus,
88 spawned: bool,
89 ) -> u64 {
90 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
91 let (status_tx, _) = watch::channel(status);
92 let task = Task {
93 id,
94 parent_id,
95 session_id: None,
96 agent,
97 status,
98 created_by,
99 description,
100 result: None,
101 error: None,
102 blocked_on: None,
103 prompt_tokens: 0,
104 completion_tokens: 0,
105 created_at: Instant::now(),
106 abort_handle: None,
107 spawned,
108 status_tx,
109 };
110 self.tasks.insert(id, task);
111 if let Some(t) = self.tasks.get(&id) {
112 let _ = self.task_broadcast.send(TaskEvent {
113 event: Some(task_event::Event::Created(TaskCreated {
114 task: Some(Self::task_info(t)),
115 })),
116 });
117 }
118 id
119 }
120
121 pub fn get(&self, id: u64) -> Option<&Task> {
123 self.tasks.get(&id)
124 }
125
126 pub fn get_mut(&mut self, id: u64) -> Option<&mut Task> {
128 self.tasks.get_mut(&id)
129 }
130
131 pub fn set_status(&mut self, id: u64, status: TaskStatus) {
133 if let Some(task) = self.tasks.get_mut(&id) {
134 task.status = status;
135 let _ = task.status_tx.send(status);
136 let _ = self.task_broadcast.send(TaskEvent {
137 event: Some(task_event::Event::StatusChanged(TaskStatusChanged {
138 task_id: id,
139 status: status.to_string(),
140 blocked_on: task.blocked_on.as_ref().map(|i| i.question.clone()),
141 })),
142 });
143 }
144 }
145
146 pub fn remove(&mut self, id: u64) -> Option<Task> {
148 self.tasks.remove(&id)
149 }
150
151 pub fn list(
155 &self,
156 agent: Option<&str>,
157 status: Option<TaskStatus>,
158 parent_id: Option<Option<u64>>,
159 ) -> Vec<&Task> {
160 self.tasks
161 .values()
162 .rev()
163 .filter(|t| agent.is_none_or(|a| t.agent == a))
164 .filter(|t| status.is_none_or(|s| t.status == s))
165 .filter(|t| parent_id.is_none_or(|p| t.parent_id == p))
166 .take(self.viewable_window)
167 .collect()
168 }
169
170 pub fn active_count(&self) -> usize {
172 self.tasks
173 .values()
174 .filter(|t| t.status == TaskStatus::InProgress)
175 .count()
176 }
177
178 pub fn submit(
183 &mut self,
184 agent: CompactString,
185 message: String,
186 created_by: CompactString,
187 parent_id: Option<u64>,
188 registry: Arc<Mutex<TaskRegistry>>,
189 ) -> (u64, TaskStatus) {
190 let under_limit = self.active_count() < self.max_concurrent;
191 let initial_status = if under_limit {
192 TaskStatus::InProgress
193 } else {
194 TaskStatus::Queued
195 };
196
197 let task_id = self.create(
198 agent.clone(),
199 message.clone(),
200 created_by,
201 parent_id,
202 initial_status,
203 true,
204 );
205
206 if under_limit {
207 self.dispatch_task(task_id, agent, message, registry);
208 }
209
210 (task_id, initial_status)
211 }
212
213 fn dispatch_task(
215 &mut self,
216 task_id: u64,
217 agent: CompactString,
218 message: String,
219 registry: Arc<Mutex<TaskRegistry>>,
220 ) {
221 let (reply_tx, reply_rx) = mpsc::unbounded_channel();
222 let msg = ClientMessage::from(SendMsg {
223 agent: agent.to_string(),
224 content: message,
225 session: None,
226 sender: None,
227 });
228 let _ = self.event_tx.send(DaemonEvent::Message {
229 msg,
230 reply: reply_tx,
231 });
232
233 let event_tx = self.event_tx.clone();
234 let timeout = self.task_timeout;
235 let handle = tokio::spawn(task_watcher(task_id, reply_rx, registry, event_tx, timeout));
236 if let Some(task) = self.tasks.get_mut(&task_id) {
237 task.abort_handle = Some(handle.abort_handle());
238 }
239 }
240
241 pub fn complete(
243 &mut self,
244 task_id: u64,
245 result: Option<String>,
246 error: Option<String>,
247 registry: Arc<Mutex<TaskRegistry>>,
248 ) {
249 if let Some(task) = self.tasks.get_mut(&task_id) {
250 if error.is_some() {
251 task.status = TaskStatus::Failed;
252 task.error = error.clone();
253 let _ = task.status_tx.send(TaskStatus::Failed);
254 let _ = self.task_broadcast.send(TaskEvent {
255 event: Some(task_event::Event::Completed(TaskCompleted {
256 task_id,
257 status: TaskStatus::Failed.to_string(),
258 result: None,
259 error,
260 })),
261 });
262 } else {
263 task.status = TaskStatus::Finished;
264 task.result = result.clone();
265 let _ = task.status_tx.send(TaskStatus::Finished);
266 let _ = self.task_broadcast.send(TaskEvent {
267 event: Some(task_event::Event::Completed(TaskCompleted {
268 task_id,
269 status: TaskStatus::Finished.to_string(),
270 result,
271 error: None,
272 })),
273 });
274 }
275 }
276 self.promote_next(registry);
277 }
278
279 pub fn promote_next(&mut self, registry: Arc<Mutex<TaskRegistry>>) {
281 if self.active_count() >= self.max_concurrent {
282 return;
283 }
284 let next = self
286 .tasks
287 .values()
288 .find(|t| t.status == TaskStatus::Queued)
289 .map(|t| (t.id, t.agent.clone(), t.description.clone()));
290
291 if let Some((id, agent, message)) = next {
292 self.set_status(id, TaskStatus::InProgress);
293 self.dispatch_task(id, agent, message, registry);
294 }
295 }
296
297 pub fn block(&mut self, task_id: u64, question: String) -> Option<oneshot::Receiver<String>> {
301 let task = self.tasks.get_mut(&task_id)?;
302 let (tx, rx) = oneshot::channel();
303 task.blocked_on = Some(InboxItem {
304 question,
305 reply: tx,
306 });
307 task.status = TaskStatus::Blocked;
308 let _ = task.status_tx.send(TaskStatus::Blocked);
309 let _ = self.task_broadcast.send(TaskEvent {
310 event: Some(task_event::Event::StatusChanged(TaskStatusChanged {
311 task_id,
312 status: TaskStatus::Blocked.to_string(),
313 blocked_on: task.blocked_on.as_ref().map(|i| i.question.clone()),
314 })),
315 });
316 Some(rx)
317 }
318
319 pub fn approve(&mut self, task_id: u64, response: String) -> bool {
321 let Some(task) = self.tasks.get_mut(&task_id) else {
322 return false;
323 };
324 if task.status != TaskStatus::Blocked {
325 return false;
326 }
327 if let Some(inbox) = task.blocked_on.take() {
328 let _ = inbox.reply.send(response);
329 }
330 task.status = TaskStatus::InProgress;
331 let _ = task.status_tx.send(TaskStatus::InProgress);
332 let _ = self.task_broadcast.send(TaskEvent {
333 event: Some(task_event::Event::StatusChanged(TaskStatusChanged {
334 task_id,
335 status: TaskStatus::InProgress.to_string(),
336 blocked_on: None,
337 })),
338 });
339 true
340 }
341
342 pub fn subscribe_status(&self, task_id: u64) -> Option<watch::Receiver<TaskStatus>> {
344 self.tasks.get(&task_id).map(|t| t.status_tx.subscribe())
345 }
346
347 pub fn children(&self, parent_id: u64) -> Vec<&Task> {
349 self.tasks
350 .values()
351 .filter(|t| t.parent_id == Some(parent_id))
352 .collect()
353 }
354
355 pub fn find_by_session(&self, session_id: u64) -> Option<u64> {
357 self.tasks
358 .values()
359 .find(|t| t.session_id == Some(session_id))
360 .map(|t| t.id)
361 }
362
363 pub fn add_tokens(&mut self, task_id: u64, prompt: u64, completion: u64) {
365 if let Some(task) = self.tasks.get_mut(&task_id) {
366 task.prompt_tokens += prompt;
367 task.completion_tokens += completion;
368 }
369 }
370
371 pub fn queued_create_tasks(&self) -> BTreeMap<CompactString, Vec<(u64, String)>> {
376 let mut groups: BTreeMap<CompactString, Vec<(u64, String)>> = BTreeMap::new();
377 for task in self.tasks.values() {
378 if task.status == TaskStatus::Queued && !task.spawned {
379 let entry = groups.entry(task.agent.clone()).or_default();
380 if entry.len() < self.max_concurrent {
381 entry.push((task.id, task.description.clone()));
382 }
383 }
384 }
385 groups
386 }
387
388 pub fn queued_create_tasks_for(&self, agent: &str) -> Vec<(u64, String)> {
391 let mut entries = Vec::new();
392 for task in self.tasks.values() {
393 if task.status == TaskStatus::Queued && !task.spawned && task.agent == agent {
394 entries.push((task.id, task.description.clone()));
395 if entries.len() >= self.max_concurrent {
396 break;
397 }
398 }
399 }
400 entries
401 }
402}
403
404async fn task_watcher(
406 task_id: u64,
407 mut reply_rx: mpsc::UnboundedReceiver<ServerMessage>,
408 registry: Arc<Mutex<TaskRegistry>>,
409 event_tx: DaemonEventSender,
410 timeout: Duration,
411) {
412 let mut result_content: Option<String> = None;
413 let mut error_msg: Option<String> = None;
414 let mut session_id: Option<u64> = None;
415
416 let collect = async {
417 while let Some(msg) = reply_rx.recv().await {
418 match msg.msg {
419 Some(server_message::Msg::Response(resp)) => {
420 session_id = Some(resp.session);
421 result_content = Some(resp.content);
422 }
423 Some(server_message::Msg::Error(err)) => {
424 error_msg = Some(err.message);
425 }
426 _ => {}
427 }
428 }
429 };
430
431 if tokio::time::timeout(timeout, collect).await.is_err() {
432 error_msg = Some("task timed out".into());
433 }
434
435 if let Some(sid) = session_id {
437 let (reply_tx, _reply_rx) = mpsc::unbounded_channel();
438 let _ = event_tx.send(DaemonEvent::Message {
439 msg: ClientMessage {
440 msg: Some(client_message::Msg::Kill(KillMsg { session: sid })),
441 },
442 reply: reply_tx,
443 });
444 }
445
446 let reg = registry.clone();
448 let mut locked = registry.lock().await;
449 let child_sessions: Vec<u64> = locked
451 .children(task_id)
452 .iter()
453 .filter(|t| t.status == TaskStatus::Finished || t.status == TaskStatus::Failed)
454 .filter_map(|t| t.session_id)
455 .collect();
456 locked.complete(task_id, result_content, error_msg, reg);
457 drop(locked);
458
459 for sid in child_sessions {
461 let (reply_tx, _) = mpsc::unbounded_channel();
462 let _ = event_tx.send(DaemonEvent::Message {
463 msg: ClientMessage {
464 msg: Some(client_message::Msg::Kill(KillMsg { session: sid })),
465 },
466 reply: reply_tx,
467 });
468 }
469}