walrus_daemon/hook/system/task/
registry.rs1use crate::hook::system::task::{InboxItem, Task, TaskStatus};
7use compact_str::CompactString;
8use std::{
9 collections::BTreeMap,
10 sync::atomic::{AtomicU64, Ordering},
11};
12use tokio::sync::{broadcast, oneshot, watch};
13use tokio::time::Instant;
14use wcore::protocol::message::{
15 TaskCompleted, TaskCreated, TaskEvent, TaskInfo, TaskStatusChanged, task_event,
16};
17
18pub struct TaskRegistry {
20 tasks: BTreeMap<u64, Task>,
21 next_id: AtomicU64,
22 pub max_concurrent: usize,
24 pub viewable_window: usize,
26 task_broadcast: broadcast::Sender<TaskEvent>,
28}
29
30impl TaskRegistry {
31 pub fn new(max_concurrent: usize, viewable_window: usize) -> Self {
33 let (task_broadcast, _) = broadcast::channel(64);
34 Self {
35 tasks: BTreeMap::new(),
36 next_id: AtomicU64::new(1),
37 max_concurrent,
38 viewable_window,
39 task_broadcast,
40 }
41 }
42
43 pub fn subscribe(&self) -> broadcast::Receiver<TaskEvent> {
45 self.task_broadcast.subscribe()
46 }
47
48 pub fn task_info(task: &Task) -> TaskInfo {
50 TaskInfo {
51 id: task.id,
52 parent_id: task.parent_id,
53 agent: task.agent.to_string(),
54 status: task.status.to_string(),
55 description: task.description.clone(),
56 result: task.result.clone(),
57 error: task.error.clone(),
58 created_by: task.created_by.to_string(),
59 prompt_tokens: task.prompt_tokens,
60 completion_tokens: task.completion_tokens,
61 alive_secs: task.created_at.elapsed().as_secs(),
62 blocked_on: task.blocked_on.as_ref().map(|i| i.question.clone()),
63 }
64 }
65
66 pub fn create(
68 &mut self,
69 agent: CompactString,
70 description: String,
71 created_by: CompactString,
72 parent_id: Option<u64>,
73 status: TaskStatus,
74 ) -> u64 {
75 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
76 let (status_tx, _) = watch::channel(status);
77 let task = Task {
78 id,
79 parent_id,
80 session_id: None,
81 agent,
82 status,
83 created_by,
84 description,
85 result: None,
86 error: None,
87 blocked_on: None,
88 prompt_tokens: 0,
89 completion_tokens: 0,
90 created_at: Instant::now(),
91 abort_handle: None,
92 status_tx,
93 };
94 self.tasks.insert(id, task);
95 if let Some(t) = self.tasks.get(&id) {
96 let _ = self.task_broadcast.send(TaskEvent {
97 event: Some(task_event::Event::Created(TaskCreated {
98 task: Some(Self::task_info(t)),
99 })),
100 });
101 }
102 id
103 }
104
105 pub fn get(&self, id: u64) -> Option<&Task> {
107 self.tasks.get(&id)
108 }
109
110 pub fn get_mut(&mut self, id: u64) -> Option<&mut Task> {
112 self.tasks.get_mut(&id)
113 }
114
115 pub fn set_status(&mut self, id: u64, status: TaskStatus) {
119 if let Some(task) = self.tasks.get_mut(&id) {
120 task.status = status;
121 let _ = task.status_tx.send(status);
122 let _ = self.task_broadcast.send(TaskEvent {
123 event: Some(task_event::Event::StatusChanged(TaskStatusChanged {
124 task_id: id,
125 status: status.to_string(),
126 blocked_on: task.blocked_on.as_ref().map(|i| i.question.clone()),
127 })),
128 });
129 }
130 }
131
132 pub fn remove(&mut self, id: u64) -> Option<Task> {
134 self.tasks.remove(&id)
135 }
136
137 pub fn list(
139 &self,
140 agent: Option<&str>,
141 status: Option<TaskStatus>,
142 parent_id: Option<Option<u64>>,
143 ) -> Vec<&Task> {
144 self.tasks
145 .values()
146 .rev()
147 .filter(|t| agent.is_none_or(|a| t.agent == a))
148 .filter(|t| status.is_none_or(|s| t.status == s))
149 .filter(|t| parent_id.is_none_or(|p| t.parent_id == p))
150 .take(self.viewable_window)
151 .collect()
152 }
153
154 pub fn active_count(&self) -> usize {
156 self.tasks
157 .values()
158 .filter(|t| t.status == TaskStatus::InProgress)
159 .count()
160 }
161
162 pub fn has_slot(&self) -> bool {
164 self.active_count() < self.max_concurrent
165 }
166
167 pub fn complete(&mut self, task_id: u64, result: Option<String>, error: Option<String>) {
169 let status = if error.is_some() {
170 TaskStatus::Failed
171 } else {
172 TaskStatus::Finished
173 };
174 if let Some(task) = self.tasks.get_mut(&task_id) {
175 task.result = result.clone();
176 task.error = error.clone();
177 }
178 self.set_status(task_id, status);
179 let _ = self.task_broadcast.send(TaskEvent {
180 event: Some(task_event::Event::Completed(TaskCompleted {
181 task_id,
182 status: status.to_string(),
183 result,
184 error,
185 })),
186 });
187 }
188
189 pub fn promote_next(&mut self) -> Option<(u64, CompactString, String)> {
191 if !self.has_slot() {
192 return None;
193 }
194 let next = self
195 .tasks
196 .values()
197 .find(|t| t.status == TaskStatus::Queued)
198 .map(|t| (t.id, t.agent.clone(), t.description.clone()));
199 if let Some((id, _, _)) = &next {
200 self.set_status(*id, TaskStatus::InProgress);
201 }
202 next
203 }
204
205 pub fn block(&mut self, task_id: u64, question: String) -> Option<oneshot::Receiver<String>> {
207 let task = self.tasks.get_mut(&task_id)?;
208 let (tx, rx) = oneshot::channel();
209 task.blocked_on = Some(InboxItem {
210 question,
211 reply: tx,
212 });
213 self.set_status(task_id, TaskStatus::Blocked);
214 Some(rx)
215 }
216
217 pub fn approve(&mut self, task_id: u64, response: String) -> bool {
219 let Some(task) = self.tasks.get_mut(&task_id) else {
220 return false;
221 };
222 if task.status != TaskStatus::Blocked {
223 return false;
224 }
225 if let Some(inbox) = task.blocked_on.take() {
226 let _ = inbox.reply.send(response);
227 }
228 self.set_status(task_id, TaskStatus::InProgress);
229 true
230 }
231
232 pub fn kill(&mut self, task_id: u64) -> Option<tokio::task::AbortHandle> {
234 let task = self.tasks.get_mut(&task_id)?;
235 let handle = task.abort_handle.take();
236 if let Some(ref h) = handle {
237 h.abort();
238 }
239 task.error = Some("killed by user".into());
240 self.set_status(task_id, TaskStatus::Failed);
241 handle
242 }
243
244 pub fn subscribe_status(&self, task_id: u64) -> Option<watch::Receiver<TaskStatus>> {
246 self.tasks.get(&task_id).map(|t| t.status_tx.subscribe())
247 }
248
249 pub fn children(&self, parent_id: u64) -> Vec<&Task> {
251 self.tasks
252 .values()
253 .filter(|t| t.parent_id == Some(parent_id))
254 .collect()
255 }
256
257 pub fn find_by_session(&self, session_id: u64) -> Option<u64> {
259 self.tasks
260 .values()
261 .find(|t| t.session_id == Some(session_id))
262 .map(|t| t.id)
263 }
264
265 pub fn add_tokens(&mut self, task_id: u64, prompt: u64, completion: u64) {
267 if let Some(task) = self.tasks.get_mut(&task_id) {
268 task.prompt_tokens += prompt;
269 task.completion_tokens += completion;
270 }
271 }
272}