1use crate::daemon::event::{DaemonEvent, DaemonEventSender};
4use crate::hook::task::{InboxItem, Task, TaskStatus};
5use compact_str::CompactString;
6use std::{
7 collections::BTreeMap,
8 sync::{
9 Arc,
10 atomic::{AtomicU64, Ordering},
11 },
12 time::Duration,
13};
14use tokio::sync::{Mutex, broadcast, mpsc, oneshot, watch};
15use tokio::time::Instant;
16use wcore::protocol::message::{
17 TaskEvent,
18 client::ClientMessage,
19 server::{ServerMessage, TaskInfo},
20};
21
/// Tracks daemon tasks, their lifecycle state, and the concurrency limits
/// applied when dispatching them.
pub struct TaskRegistry {
    /// All known tasks keyed by id. `BTreeMap` keeps them in ascending id
    /// (creation) order, which `list` reverses for newest-first output.
    tasks: BTreeMap<u64, Task>,
    /// Source of task ids; starts at 1 and is bumped by `create`.
    next_id: AtomicU64,
    /// Maximum number of `InProgress` tasks before `submit` queues new work;
    /// also caps the per-agent batches returned by `queued_create_tasks*`.
    pub max_concurrent: usize,
    /// Maximum number of tasks returned by `list`.
    pub viewable_window: usize,
    /// Deadline each task watcher applies while collecting the daemon's reply.
    pub task_timeout: Duration,
    /// Sender into the daemon event loop, used to dispatch task messages and
    /// to request session kills.
    pub event_tx: DaemonEventSender,
    /// Fan-out of task lifecycle events to `subscribe` callers.
    task_broadcast: broadcast::Sender<TaskEvent>,
}
37
38impl TaskRegistry {
39 pub fn new(
41 max_concurrent: usize,
42 viewable_window: usize,
43 task_timeout: Duration,
44 event_tx: DaemonEventSender,
45 ) -> Self {
46 let (task_broadcast, _) = broadcast::channel(64);
47 Self {
48 tasks: BTreeMap::new(),
49 next_id: AtomicU64::new(1),
50 max_concurrent,
51 viewable_window,
52 task_timeout,
53 event_tx,
54 task_broadcast,
55 }
56 }
57
58 pub fn subscribe(&self) -> broadcast::Receiver<TaskEvent> {
60 self.task_broadcast.subscribe()
61 }
62
63 fn task_info(task: &Task) -> TaskInfo {
65 TaskInfo {
66 id: task.id,
67 parent_id: task.parent_id,
68 agent: task.agent.clone(),
69 status: task.status.to_string(),
70 description: task.description.clone(),
71 result: task.result.clone(),
72 error: task.error.clone(),
73 created_by: task.created_by.clone(),
74 prompt_tokens: task.prompt_tokens,
75 completion_tokens: task.completion_tokens,
76 alive_secs: task.created_at.elapsed().as_secs(),
77 blocked_on: task.blocked_on.as_ref().map(|i| i.question.clone()),
78 }
79 }
80
81 pub fn create(
83 &mut self,
84 agent: CompactString,
85 description: String,
86 created_by: CompactString,
87 parent_id: Option<u64>,
88 status: TaskStatus,
89 spawned: bool,
90 ) -> u64 {
91 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
92 let (status_tx, _) = watch::channel(status);
93 let task = Task {
94 id,
95 parent_id,
96 session_id: None,
97 agent,
98 status,
99 created_by,
100 description,
101 result: None,
102 error: None,
103 blocked_on: None,
104 prompt_tokens: 0,
105 completion_tokens: 0,
106 created_at: Instant::now(),
107 abort_handle: None,
108 spawned,
109 status_tx,
110 };
111 self.tasks.insert(id, task);
112 if let Some(t) = self.tasks.get(&id) {
113 let _ = self.task_broadcast.send(TaskEvent::Created {
114 task: Self::task_info(t),
115 });
116 }
117 id
118 }
119
120 pub fn get(&self, id: u64) -> Option<&Task> {
122 self.tasks.get(&id)
123 }
124
125 pub fn get_mut(&mut self, id: u64) -> Option<&mut Task> {
127 self.tasks.get_mut(&id)
128 }
129
130 pub fn set_status(&mut self, id: u64, status: TaskStatus) {
132 if let Some(task) = self.tasks.get_mut(&id) {
133 task.status = status;
134 let _ = task.status_tx.send(status);
135 let _ = self.task_broadcast.send(TaskEvent::StatusChanged {
136 task_id: id,
137 status: status.to_string(),
138 blocked_on: task.blocked_on.as_ref().map(|i| i.question.clone()),
139 });
140 }
141 }
142
143 pub fn remove(&mut self, id: u64) -> Option<Task> {
145 self.tasks.remove(&id)
146 }
147
148 pub fn list(
152 &self,
153 agent: Option<&str>,
154 status: Option<TaskStatus>,
155 parent_id: Option<Option<u64>>,
156 ) -> Vec<&Task> {
157 self.tasks
158 .values()
159 .rev()
160 .filter(|t| agent.is_none_or(|a| t.agent == a))
161 .filter(|t| status.is_none_or(|s| t.status == s))
162 .filter(|t| parent_id.is_none_or(|p| t.parent_id == p))
163 .take(self.viewable_window)
164 .collect()
165 }
166
167 pub fn active_count(&self) -> usize {
169 self.tasks
170 .values()
171 .filter(|t| t.status == TaskStatus::InProgress)
172 .count()
173 }
174
175 pub fn submit(
180 &mut self,
181 agent: CompactString,
182 message: String,
183 created_by: CompactString,
184 parent_id: Option<u64>,
185 registry: Arc<Mutex<TaskRegistry>>,
186 ) -> (u64, TaskStatus) {
187 let under_limit = self.active_count() < self.max_concurrent;
188 let initial_status = if under_limit {
189 TaskStatus::InProgress
190 } else {
191 TaskStatus::Queued
192 };
193
194 let task_id = self.create(
195 agent.clone(),
196 message.clone(),
197 created_by,
198 parent_id,
199 initial_status,
200 true,
201 );
202
203 if under_limit {
204 self.dispatch_task(task_id, agent, message, registry);
205 }
206
207 (task_id, initial_status)
208 }
209
210 fn dispatch_task(
212 &mut self,
213 task_id: u64,
214 agent: CompactString,
215 message: String,
216 registry: Arc<Mutex<TaskRegistry>>,
217 ) {
218 let (reply_tx, reply_rx) = mpsc::unbounded_channel();
219 let msg = ClientMessage::Send {
220 agent,
221 content: message,
222 session: None,
223 sender: None,
224 };
225 let _ = self.event_tx.send(DaemonEvent::Message {
226 msg,
227 reply: reply_tx,
228 });
229
230 let event_tx = self.event_tx.clone();
231 let timeout = self.task_timeout;
232 let handle = tokio::spawn(task_watcher(task_id, reply_rx, registry, event_tx, timeout));
233 if let Some(task) = self.tasks.get_mut(&task_id) {
234 task.abort_handle = Some(handle.abort_handle());
235 }
236 }
237
238 pub fn complete(
240 &mut self,
241 task_id: u64,
242 result: Option<String>,
243 error: Option<String>,
244 registry: Arc<Mutex<TaskRegistry>>,
245 ) {
246 if let Some(task) = self.tasks.get_mut(&task_id) {
247 if error.is_some() {
248 task.status = TaskStatus::Failed;
249 task.error = error.clone();
250 let _ = task.status_tx.send(TaskStatus::Failed);
251 let _ = self.task_broadcast.send(TaskEvent::Completed {
252 task_id,
253 status: TaskStatus::Failed.to_string(),
254 result: None,
255 error,
256 });
257 } else {
258 task.status = TaskStatus::Finished;
259 task.result = result.clone();
260 let _ = task.status_tx.send(TaskStatus::Finished);
261 let _ = self.task_broadcast.send(TaskEvent::Completed {
262 task_id,
263 status: TaskStatus::Finished.to_string(),
264 result,
265 error: None,
266 });
267 }
268 }
269 self.promote_next(registry);
270 }
271
272 pub fn promote_next(&mut self, registry: Arc<Mutex<TaskRegistry>>) {
274 if self.active_count() >= self.max_concurrent {
275 return;
276 }
277 let next = self
279 .tasks
280 .values()
281 .find(|t| t.status == TaskStatus::Queued)
282 .map(|t| (t.id, t.agent.clone(), t.description.clone()));
283
284 if let Some((id, agent, message)) = next {
285 self.set_status(id, TaskStatus::InProgress);
286 self.dispatch_task(id, agent, message, registry);
287 }
288 }
289
290 pub fn block(&mut self, task_id: u64, question: String) -> Option<oneshot::Receiver<String>> {
294 let task = self.tasks.get_mut(&task_id)?;
295 let (tx, rx) = oneshot::channel();
296 task.blocked_on = Some(InboxItem {
297 question,
298 reply: tx,
299 });
300 task.status = TaskStatus::Blocked;
301 let _ = task.status_tx.send(TaskStatus::Blocked);
302 let _ = self.task_broadcast.send(TaskEvent::StatusChanged {
303 task_id,
304 status: TaskStatus::Blocked.to_string(),
305 blocked_on: task.blocked_on.as_ref().map(|i| i.question.clone()),
306 });
307 Some(rx)
308 }
309
310 pub fn approve(&mut self, task_id: u64, response: String) -> bool {
312 let Some(task) = self.tasks.get_mut(&task_id) else {
313 return false;
314 };
315 if task.status != TaskStatus::Blocked {
316 return false;
317 }
318 if let Some(inbox) = task.blocked_on.take() {
319 let _ = inbox.reply.send(response);
320 }
321 task.status = TaskStatus::InProgress;
322 let _ = task.status_tx.send(TaskStatus::InProgress);
323 let _ = self.task_broadcast.send(TaskEvent::StatusChanged {
324 task_id,
325 status: TaskStatus::InProgress.to_string(),
326 blocked_on: None,
327 });
328 true
329 }
330
331 pub fn subscribe_status(&self, task_id: u64) -> Option<watch::Receiver<TaskStatus>> {
333 self.tasks.get(&task_id).map(|t| t.status_tx.subscribe())
334 }
335
336 pub fn children(&self, parent_id: u64) -> Vec<&Task> {
338 self.tasks
339 .values()
340 .filter(|t| t.parent_id == Some(parent_id))
341 .collect()
342 }
343
344 pub fn find_by_session(&self, session_id: u64) -> Option<u64> {
346 self.tasks
347 .values()
348 .find(|t| t.session_id == Some(session_id))
349 .map(|t| t.id)
350 }
351
352 pub fn add_tokens(&mut self, task_id: u64, prompt: u64, completion: u64) {
354 if let Some(task) = self.tasks.get_mut(&task_id) {
355 task.prompt_tokens += prompt;
356 task.completion_tokens += completion;
357 }
358 }
359
360 pub fn queued_create_tasks(&self) -> BTreeMap<CompactString, Vec<(u64, String)>> {
365 let mut groups: BTreeMap<CompactString, Vec<(u64, String)>> = BTreeMap::new();
366 for task in self.tasks.values() {
367 if task.status == TaskStatus::Queued && !task.spawned {
368 let entry = groups.entry(task.agent.clone()).or_default();
369 if entry.len() < self.max_concurrent {
370 entry.push((task.id, task.description.clone()));
371 }
372 }
373 }
374 groups
375 }
376
377 pub fn queued_create_tasks_for(&self, agent: &str) -> Vec<(u64, String)> {
380 let mut entries = Vec::new();
381 for task in self.tasks.values() {
382 if task.status == TaskStatus::Queued && !task.spawned && task.agent == agent {
383 entries.push((task.id, task.description.clone()));
384 if entries.len() >= self.max_concurrent {
385 break;
386 }
387 }
388 }
389 entries
390 }
391}
392
393async fn task_watcher(
395 task_id: u64,
396 mut reply_rx: mpsc::UnboundedReceiver<ServerMessage>,
397 registry: Arc<Mutex<TaskRegistry>>,
398 event_tx: DaemonEventSender,
399 timeout: Duration,
400) {
401 let mut result_content: Option<String> = None;
402 let mut error_msg: Option<String> = None;
403 let mut session_id: Option<u64> = None;
404
405 let collect = async {
406 while let Some(msg) = reply_rx.recv().await {
407 match msg {
408 ServerMessage::Response(resp) => {
409 session_id = Some(resp.session);
410 result_content = Some(resp.content);
411 }
412 ServerMessage::Error { message, .. } => {
413 error_msg = Some(message);
414 }
415 _ => {}
416 }
417 }
418 };
419
420 if tokio::time::timeout(timeout, collect).await.is_err() {
421 error_msg = Some("task timed out".into());
422 }
423
424 if let Some(sid) = session_id {
426 let (reply_tx, _reply_rx) = mpsc::unbounded_channel();
427 let _ = event_tx.send(DaemonEvent::Message {
428 msg: ClientMessage::Kill { session: sid },
429 reply: reply_tx,
430 });
431 }
432
433 let reg = registry.clone();
435 let mut locked = registry.lock().await;
436 let child_sessions: Vec<u64> = locked
438 .children(task_id)
439 .iter()
440 .filter(|t| t.status == TaskStatus::Finished || t.status == TaskStatus::Failed)
441 .filter_map(|t| t.session_id)
442 .collect();
443 locked.complete(task_id, result_content, error_msg, reg);
444 drop(locked);
445
446 for sid in child_sessions {
448 let (reply_tx, _) = mpsc::unbounded_channel();
449 let _ = event_tx.send(DaemonEvent::Message {
450 msg: ClientMessage::Kill { session: sid },
451 reply: reply_tx,
452 });
453 }
454}