1use std::collections::HashMap;
4use std::sync::Arc;
5use tokio::sync::{RwLock, broadcast, mpsc};
6use tracing::{debug, info, instrument, warn};
7
8use super::model::{Task, TaskId, TaskKind, TaskProgress, TaskStatus};
9
10pub struct TaskHandle {
12 pub id: TaskId,
14
15 control_tx: mpsc::Sender<TaskControl>,
17
18 output_tx: broadcast::Sender<TaskOutput>,
20}
21
22impl Clone for TaskHandle {
23 fn clone(&self) -> Self {
24 Self {
25 id: self.id,
26 control_tx: self.control_tx.clone(),
27 output_tx: self.output_tx.clone(),
28 }
29 }
30}
31
32impl TaskHandle {
33 pub fn subscribe(&self) -> broadcast::Receiver<TaskOutput> {
35 self.output_tx.subscribe()
36 }
37
38 pub async fn send_control(
40 &self,
41 cmd: TaskControl,
42 ) -> Result<(), mpsc::error::SendError<TaskControl>> {
43 self.control_tx.send(cmd).await
44 }
45
46 pub async fn pause(&self) -> Result<(), mpsc::error::SendError<TaskControl>> {
48 self.send_control(TaskControl::Pause).await
49 }
50
51 pub async fn resume(&self) -> Result<(), mpsc::error::SendError<TaskControl>> {
53 self.send_control(TaskControl::Resume).await
54 }
55
56 pub async fn cancel(&self) -> Result<(), mpsc::error::SendError<TaskControl>> {
58 self.send_control(TaskControl::Cancel).await
59 }
60
61 pub async fn foreground(&self) -> Result<(), mpsc::error::SendError<TaskControl>> {
63 self.send_control(TaskControl::Foreground).await
64 }
65
66 pub async fn background(&self) -> Result<(), mpsc::error::SendError<TaskControl>> {
68 self.send_control(TaskControl::Background).await
69 }
70
71 pub async fn send_input(
73 &self,
74 input: String,
75 ) -> Result<(), mpsc::error::SendError<TaskControl>> {
76 self.send_control(TaskControl::Input(input)).await
77 }
78}
79
/// Command carried by a task's control channel.
///
/// `TaskHandle` exposes one helper method per variant except `Progress`.
#[derive(Debug, Clone)]
pub enum TaskControl {
    /// Sent by `TaskHandle::pause`.
    Pause,
    /// Sent by `TaskHandle::resume`.
    Resume,
    /// Sent by `TaskHandle::cancel` and by `TaskManager::cancel`.
    Cancel,
    /// Sent by `TaskHandle::foreground`.
    Foreground,
    /// Sent by `TaskHandle::background`.
    Background,
    /// Sent by `TaskHandle::send_input`; carries the input string unchanged.
    Input(String),
    /// Carries a `TaskProgress` value. Nothing in this file sends it.
    // NOTE(review): presumably a progress report routed through the control
    // channel by code outside this file — confirm.
    Progress(TaskProgress),
}
98
/// Message published on a task's output broadcast channel.
///
/// `TaskManager::send_output` forwards each value to the task's own channel
/// and also wraps it in `TaskEvent::Output` on the manager-wide event channel.
/// The producers of the individual variants live outside this file.
#[derive(Debug, Clone)]
pub enum TaskOutput {
    /// Text payload. NOTE(review): presumably a chunk of the task's standard
    /// output — confirm against the producer.
    Stdout(String),
    /// Text payload. NOTE(review): presumably a chunk of the task's standard
    /// error — confirm against the producer.
    Stderr(String),
    /// Carries a `TaskProgress` value.
    Progress(TaskProgress),
    /// Carries a `TaskStatus` value.
    StatusChange(TaskStatus),
    /// Completion notice with two optional text fields.
    Completed {
        /// Optional summary text.
        summary: Option<String>,
        /// Optional output text.
        output: Option<String>,
    },
    /// Failure notice carrying an error message and a `retryable` flag.
    Failed { error: String, retryable: bool },
}
118
/// Event broadcast by `TaskManager` to every receiver obtained from
/// `TaskManager::subscribe`.
#[derive(Debug, Clone)]
pub enum TaskEvent {
    /// Emitted by `TaskManager::create` with a copy of the new task.
    Created(Task),
    /// Emitted by `TaskManager::update_status` with the previous and new status.
    StatusChanged {
        id: TaskId,
        old: TaskStatus,
        new: TaskStatus,
    },
    /// Emitted by `TaskManager::set_description` with the new description.
    DescriptionChanged { id: TaskId, description: String },
    /// Emitted by `TaskManager::send_output`, mirroring the per-task output.
    Output { id: TaskId, output: TaskOutput },
    /// Emitted by `TaskManager::set_foreground`, and by `TaskManager::start`
    /// when the task's session had no foreground entry.
    Foregrounded(TaskId),
    /// Emitted by `TaskManager::set_background`, and by
    /// `TaskManager::set_foreground` for the task that was displaced.
    Backgrounded(TaskId),
    /// Emitted by `TaskManager::complete` with a copy of the task.
    Completed(Task),
    /// Emitted by `TaskManager::fail` with a copy of the task.
    Failed(Task),
    /// Emitted by `TaskManager::cancel`.
    Cancelled(TaskId),
}
145
/// Registry of tasks together with their channels and per-session foreground
/// bookkeeping.
///
/// Every map sits behind its own async `RwLock`. Methods that need both
/// `tasks` and `foreground_tasks` acquire `tasks` first.
pub struct TaskManager {
    /// Every known task, keyed by its id.
    tasks: Arc<RwLock<HashMap<TaskId, Task>>>,

    /// Control-channel sender for each task, registered by `create`.
    controls: Arc<RwLock<HashMap<TaskId, mpsc::Sender<TaskControl>>>>,

    /// Output broadcast sender for each task, registered by `create`.
    outputs: Arc<RwLock<HashMap<TaskId, broadcast::Sender<TaskOutput>>>>,

    /// Manager-wide event channel; receivers are handed out by `subscribe`.
    events_tx: broadcast::Sender<TaskEvent>,

    /// Session key -> id of the task recorded as that session's foreground task.
    foreground_tasks: Arc<RwLock<HashMap<String, TaskId>>>,
}
163
164impl TaskManager {
165 pub fn new() -> Self {
167 let (events_tx, _) = broadcast::channel(256);
168
169 Self {
170 tasks: Arc::new(RwLock::new(HashMap::new())),
171 controls: Arc::new(RwLock::new(HashMap::new())),
172 outputs: Arc::new(RwLock::new(HashMap::new())),
173 events_tx,
174 foreground_tasks: Arc::new(RwLock::new(HashMap::new())),
175 }
176 }
177
178 pub fn subscribe(&self) -> broadcast::Receiver<TaskEvent> {
180 self.events_tx.subscribe()
181 }
182
183 #[instrument(skip(self), fields(kind = ?kind))]
185 pub async fn create(&self, kind: TaskKind, session_key: Option<String>) -> TaskHandle {
186 let task = Task::new(kind).with_session(session_key.clone().unwrap_or_default());
187 let id = task.id;
188
189 let (control_tx, _control_rx) = mpsc::channel(32);
191
192 let (output_tx, _output_rx) = broadcast::channel(256);
194
195 self.tasks.write().await.insert(id, task.clone());
197 self.controls.write().await.insert(id, control_tx.clone());
198 self.outputs.write().await.insert(id, output_tx.clone());
199
200 let _ = self.events_tx.send(TaskEvent::Created(task));
202
203 debug!(task_id = %id, "Task created");
204
205 TaskHandle {
206 id,
207 control_tx,
208 output_tx,
209 }
210 }
211
212 pub async fn get(&self, id: TaskId) -> Option<Task> {
214 self.tasks.read().await.get(&id).cloned()
215 }
216
217 pub async fn all(&self) -> Vec<Task> {
219 self.tasks.read().await.values().cloned().collect()
220 }
221
222 pub async fn for_session(&self, session_key: &str) -> Vec<Task> {
224 self.tasks
225 .read()
226 .await
227 .values()
228 .filter(|t| t.session_key.as_deref() == Some(session_key))
229 .cloned()
230 .collect()
231 }
232
233 pub async fn active(&self) -> Vec<Task> {
235 self.tasks
236 .read()
237 .await
238 .values()
239 .filter(|t| !t.status.is_terminal())
240 .cloned()
241 .collect()
242 }
243
244 pub async fn foreground_task(&self, session_key: &str) -> Option<Task> {
246 let fg_id = self
247 .foreground_tasks
248 .read()
249 .await
250 .get(session_key)
251 .copied()?;
252 self.get(fg_id).await
253 }
254
255 #[instrument(skip(self))]
257 pub async fn set_foreground(&self, id: TaskId) -> Result<(), String> {
258 let mut tasks = self.tasks.write().await;
259
260 let session = {
262 let task = tasks
263 .get(&id)
264 .ok_or_else(|| format!("Task {} not found", id))?;
265 task.session_key
266 .clone()
267 .ok_or_else(|| "Task has no session".to_string())?
268 };
269
270 if let Some(old_fg_id) = self.foreground_tasks.read().await.get(&session).copied() {
272 if old_fg_id != id {
273 if let Some(old_task) = tasks.get_mut(&old_fg_id) {
274 old_task.background();
275 let _ = self.events_tx.send(TaskEvent::Backgrounded(old_fg_id));
276 }
277 }
278 }
279
280 if let Some(task) = tasks.get_mut(&id) {
282 task.foreground();
283 }
284 self.foreground_tasks.write().await.insert(session, id);
285 let _ = self.events_tx.send(TaskEvent::Foregrounded(id));
286
287 info!(task_id = %id, "Task foregrounded");
288 Ok(())
289 }
290
291 #[instrument(skip(self))]
293 pub async fn set_background(&self, id: TaskId) -> Result<(), String> {
294 let mut tasks = self.tasks.write().await;
295 let task = tasks
296 .get_mut(&id)
297 .ok_or_else(|| format!("Task {} not found", id))?;
298
299 task.background();
300
301 if let Some(ref session) = task.session_key {
303 let mut fg = self.foreground_tasks.write().await;
304 if fg.get(session) == Some(&id) {
305 fg.remove(session);
306 }
307 }
308
309 let _ = self.events_tx.send(TaskEvent::Backgrounded(id));
310 info!(task_id = %id, "Task backgrounded");
311 Ok(())
312 }
313
314 #[instrument(skip(self))]
316 pub async fn update_status(&self, id: TaskId, new_status: TaskStatus) {
317 let mut tasks = self.tasks.write().await;
318 if let Some(task) = tasks.get_mut(&id) {
319 let old_status = task.status.clone();
320 task.status = new_status.clone();
321
322 if new_status.is_terminal() {
323 task.finished_at = Some(std::time::SystemTime::now());
324 }
325
326 let _ = self.events_tx.send(TaskEvent::StatusChanged {
327 id,
328 old: old_status,
329 new: new_status,
330 });
331 }
332 }
333
334 pub async fn start(&self, id: TaskId) {
336 let mut tasks = self.tasks.write().await;
337 if let Some(task) = tasks.get_mut(&id) {
338 task.start();
339
340 if let Some(ref session) = task.session_key {
342 let mut fg = self.foreground_tasks.write().await;
343 if !fg.contains_key(session) {
344 fg.insert(session.clone(), id);
345 let _ = self.events_tx.send(TaskEvent::Foregrounded(id));
346 }
347 }
348 }
349 }
350
351 pub async fn complete(&self, id: TaskId, summary: Option<String>) {
353 let task = {
354 let mut tasks = self.tasks.write().await;
355 if let Some(task) = tasks.get_mut(&id) {
356 task.complete(summary);
357
358 if let Some(ref session) = task.session_key {
360 self.foreground_tasks.write().await.remove(session);
361 }
362
363 Some(task.clone())
364 } else {
365 None
366 }
367 };
368
369 if let Some(t) = task {
370 let _ = self.events_tx.send(TaskEvent::Completed(t));
371 info!(task_id = %id, "Task completed");
372 }
373 }
374
375 pub async fn fail(&self, id: TaskId, error: String, retryable: bool) {
377 let task = {
378 let mut tasks = self.tasks.write().await;
379 if let Some(task) = tasks.get_mut(&id) {
380 task.fail(&error, retryable);
381
382 if let Some(ref session) = task.session_key {
384 self.foreground_tasks.write().await.remove(session);
385 }
386
387 Some(task.clone())
388 } else {
389 None
390 }
391 };
392
393 if let Some(t) = task {
394 let _ = self.events_tx.send(TaskEvent::Failed(t));
395 warn!(task_id = %id, error = %error, "Task failed");
396 }
397 }
398
399 pub async fn cancel(&self, id: TaskId) -> Result<(), String> {
401 if let Some(control_tx) = self.controls.read().await.get(&id) {
403 let _ = control_tx.send(TaskControl::Cancel).await;
404 }
405
406 let mut tasks = self.tasks.write().await;
407 if let Some(task) = tasks.get_mut(&id) {
408 task.cancel();
409
410 if let Some(ref session) = task.session_key {
412 self.foreground_tasks.write().await.remove(session);
413 }
414
415 let _ = self.events_tx.send(TaskEvent::Cancelled(id));
416 info!(task_id = %id, "Task cancelled");
417 Ok(())
418 } else {
419 Err(format!("Task {} not found", id))
420 }
421 }
422
423 pub async fn set_description(&self, id: TaskId, description: &str) -> Result<(), String> {
425 let mut tasks = self.tasks.write().await;
426 if let Some(task) = tasks.get_mut(&id) {
427 task.description = Some(description.to_string());
428 let _ = self.events_tx.send(TaskEvent::DescriptionChanged {
429 id,
430 description: description.to_string(),
431 });
432 info!(task_id = %id, description, "Task description updated");
433 Ok(())
434 } else {
435 Err(format!("Task {} not found", id))
436 }
437 }
438
439 pub async fn send_output(&self, id: TaskId, output: TaskOutput) {
441 if let Some(output_tx) = self.outputs.read().await.get(&id) {
442 let _ = output_tx.send(output.clone());
443 }
444
445 let _ = self.events_tx.send(TaskEvent::Output { id, output });
446 }
447
448 pub async fn cleanup_old(&self, max_age: std::time::Duration) {
450 let now = std::time::SystemTime::now();
451 let mut tasks = self.tasks.write().await;
452
453 let to_remove: Vec<TaskId> = tasks
454 .iter()
455 .filter(|(_, t)| {
456 if !t.status.is_terminal() {
457 return false;
458 }
459 if let Some(finished) = t.finished_at {
460 now.duration_since(finished).unwrap_or_default() > max_age
461 } else {
462 false
463 }
464 })
465 .map(|(id, _)| *id)
466 .collect();
467
468 for id in &to_remove {
469 tasks.remove(id);
470 self.controls.write().await.remove(id);
471 self.outputs.write().await.remove(id);
472 }
473
474 if !to_remove.is_empty() {
475 debug!(count = to_remove.len(), "Cleaned up old tasks");
476 }
477 }
478
479 pub async fn stats(&self) -> TaskStats {
481 let tasks = self.tasks.read().await;
482
483 let mut stats = TaskStats::default();
484 for task in tasks.values() {
485 stats.total += 1;
486 match &task.status {
487 TaskStatus::Pending => stats.pending += 1,
488 TaskStatus::Running { .. } => stats.running += 1,
489 TaskStatus::Background { .. } => stats.background += 1,
490 TaskStatus::Paused { .. } => stats.paused += 1,
491 TaskStatus::Completed { .. } => stats.completed += 1,
492 TaskStatus::Failed { .. } => stats.failed += 1,
493 TaskStatus::Cancelled => stats.cancelled += 1,
494 TaskStatus::WaitingForInput { .. } => stats.waiting_input += 1,
495 }
496 }
497
498 stats
499 }
500}
501
502impl Default for TaskManager {
503 fn default() -> Self {
504 Self::new()
505 }
506}
507
/// Per-status task counters; all fields start at zero via `Default`.
#[derive(Debug, Clone, Default)]
pub struct TaskStats {
    /// Number of tasks counted, regardless of status.
    pub total: usize,
    /// Tasks counted as pending.
    pub pending: usize,
    /// Tasks counted as running.
    pub running: usize,
    /// Tasks counted as background.
    pub background: usize,
    /// Tasks counted as paused.
    pub paused: usize,
    /// Tasks counted as completed.
    pub completed: usize,
    /// Tasks counted as failed.
    pub failed: usize,
    /// Tasks counted as cancelled.
    pub cancelled: usize,
    /// Tasks counted as waiting for input.
    pub waiting_input: usize,
}

impl TaskStats {
    /// Sum of the pending, running, background, paused and waiting-for-input
    /// counters; completed, failed and cancelled tasks are excluded.
    pub fn active(&self) -> usize {
        let live_counters = [
            self.pending,
            self.running,
            self.background,
            self.paused,
            self.waiting_input,
        ];
        live_counters.iter().sum()
    }
}