1use super::{SharedThreadManager, ThreadId, ThreadStatus};
16use std::future::Future;
17use tokio::sync::oneshot;
18use tokio_util::sync::CancellationToken;
19use tracing::{debug, warn};
20
/// Outcome of a subtask, expressed without the task's value type.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare outcomes directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SubtaskResult {
    /// The subtask succeeded, optionally with a textual result.
    Ok(Option<String>),
    /// The subtask failed with the given error message.
    Err(String),
    /// The subtask was cancelled before it finished.
    Cancelled,
}
31
32pub struct SubtaskHandle<T: Send + 'static> {
37 pub thread_id: ThreadId,
39
40 cancel_token: CancellationToken,
42
43 result_rx: Option<oneshot::Receiver<Result<T, String>>>,
45
46 thread_mgr: SharedThreadManager,
48
49 _join_handle: Option<tokio::task::JoinHandle<()>>,
51}
52
53impl<T: Send + 'static> SubtaskHandle<T> {
54 pub async fn join(mut self) -> Result<T, String> {
59 let rx = self
60 .result_rx
61 .take()
62 .ok_or_else(|| "SubtaskHandle already joined".to_string())?;
63
64 match rx.await {
65 Ok(Ok(value)) => Ok(value),
66 Ok(Err(e)) => Err(e),
67 Err(_) => {
68 Err("Subtask channel closed unexpectedly".to_string())
70 }
71 }
72 }
73
74 pub fn cancel(&self) {
79 self.cancel_token.cancel();
80 }
81
82 pub fn is_cancelled(&self) -> bool {
84 self.cancel_token.is_cancelled()
85 }
86
87 pub async fn set_description(&self, description: impl Into<String>) {
89 let mut mgr = self.thread_mgr.write().await;
90 mgr.set_description(self.thread_id, description);
91 }
92
93 pub async fn set_status(&self, status: ThreadStatus) {
95 let mut mgr = self.thread_mgr.write().await;
96 mgr.set_status(self.thread_id, status);
97 }
98
99 pub fn cancel_token(&self) -> CancellationToken {
101 self.cancel_token.clone()
102 }
103}
104
105impl<T: Send + 'static> Drop for SubtaskHandle<T> {
106 fn drop(&mut self) {
107 if self.result_rx.is_some() {
112 self.cancel_token.cancel();
113 }
114 }
115}
116
/// Options describing a thread to register when spawning a subtask.
#[derive(Debug, Clone)]
pub struct SpawnOptions {
    /// Label recorded for the thread; also used as the fallback description.
    pub label: String,
    /// Optional human-readable description applied to the thread entry.
    pub description: Option<String>,
    /// Optional parent thread passed to the thread manager on creation.
    pub parent_id: Option<ThreadId>,
}
127
128impl SpawnOptions {
129 pub fn new(label: impl Into<String>) -> Self {
131 Self {
132 label: label.into(),
133 description: None,
134 parent_id: None,
135 }
136 }
137
138 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
140 self.description = Some(desc.into());
141 self
142 }
143
144 pub fn with_parent(mut self, parent_id: ThreadId) -> Self {
146 self.parent_id = Some(parent_id);
147 self
148 }
149}
150
151pub async fn spawn_subagent<F, Fut, T>(
174 thread_mgr: SharedThreadManager,
175 options: SpawnOptions,
176 task_fn: F,
177) -> SubtaskHandle<T>
178where
179 F: FnOnce(CancellationToken, SharedThreadManager) -> Fut + Send + 'static,
180 Fut: Future<Output = Result<T, String>> + Send + 'static,
181 T: Send + 'static,
182{
183 let cancel_token = CancellationToken::new();
184 let (result_tx, result_rx) = oneshot::channel();
185
186 let thread_id = {
188 let mut mgr = thread_mgr.write().await;
189 let id = mgr.create_subagent(
190 &options.label,
191 "subtask",
192 options.description.as_deref().unwrap_or(&options.label),
193 options.parent_id,
194 );
195 if let Some(desc) = &options.description {
196 mgr.set_description(id, desc);
197 }
198 id
199 };
200
201 debug!(thread_id = %thread_id, label = %options.label, "Spawning subagent subtask");
202
203 let token = cancel_token.clone();
205 let mgr = thread_mgr.clone();
206 let tid = thread_id;
207
208 let join_handle = tokio::spawn(async move {
209 let result = tokio::select! {
210 _ = token.cancelled() => {
211 Err("Cancelled".to_string())
212 }
213 res = task_fn(token.clone(), mgr.clone()) => {
214 res
215 }
216 };
217
218 {
220 let mut mgr_guard = mgr.write().await;
221 match &result {
222 Ok(_) => {
223 mgr_guard.complete(tid, Some("Completed".to_string()), None);
224 debug!(thread_id = %tid, "Subagent subtask completed");
225 }
226 Err(e) if e == "Cancelled" => {
227 mgr_guard.set_status(tid, ThreadStatus::Cancelled);
228 debug!(thread_id = %tid, "Subagent subtask cancelled");
229 }
230 Err(e) => {
231 mgr_guard.fail(tid, e);
232 warn!(thread_id = %tid, error = %e, "Subagent subtask failed");
233 }
234 }
235 }
236
237 let _ = result_tx.send(result);
239 });
240
241 SubtaskHandle {
242 thread_id,
243 cancel_token,
244 result_rx: Some(result_rx),
245 thread_mgr,
246 _join_handle: Some(join_handle),
247 }
248}
249
250pub async fn spawn_task<F, Fut, T>(
255 thread_mgr: SharedThreadManager,
256 options: SpawnOptions,
257 task_fn: F,
258) -> SubtaskHandle<T>
259where
260 F: FnOnce(CancellationToken, SharedThreadManager) -> Fut + Send + 'static,
261 Fut: Future<Output = Result<T, String>> + Send + 'static,
262 T: Send + 'static,
263{
264 let cancel_token = CancellationToken::new();
265 let (result_tx, result_rx) = oneshot::channel();
266
267 let thread_id = {
269 let mut mgr = thread_mgr.write().await;
270 let id = mgr.create_task(
271 &options.label,
272 options.description.as_deref().unwrap_or(&options.label),
273 options.parent_id,
274 );
275 if let Some(desc) = &options.description {
276 mgr.set_description(id, desc);
277 }
278 id
279 };
280
281 debug!(thread_id = %thread_id, label = %options.label, "Spawning one-shot task");
282
283 let token = cancel_token.clone();
284 let mgr = thread_mgr.clone();
285 let tid = thread_id;
286
287 let join_handle = tokio::spawn(async move {
288 let result = tokio::select! {
289 _ = token.cancelled() => {
290 Err("Cancelled".to_string())
291 }
292 res = task_fn(token.clone(), mgr.clone()) => {
293 res
294 }
295 };
296
297 {
299 let mut mgr_guard = mgr.write().await;
300 match &result {
301 Ok(_) => {
302 mgr_guard.complete(tid, Some("Completed".to_string()), None);
303 debug!(thread_id = %tid, "Task completed");
304 }
305 Err(e) if e == "Cancelled" => {
306 mgr_guard.set_status(tid, ThreadStatus::Cancelled);
307 debug!(thread_id = %tid, "Task cancelled");
308 }
309 Err(e) => {
310 mgr_guard.fail(tid, e);
311 warn!(thread_id = %tid, error = %e, "Task failed");
312 }
313 }
314 }
315
316 let _ = result_tx.send(result);
317 });
318
319 SubtaskHandle {
320 thread_id,
321 cancel_token,
322 result_rx: Some(result_rx),
323 thread_mgr,
324 _join_handle: Some(join_handle),
325 }
326}
327
328pub async fn spawn_background<F, Fut>(
333 thread_mgr: SharedThreadManager,
334 options: SpawnOptions,
335 task_fn: F,
336) -> SubtaskHandle<()>
337where
338 F: FnOnce(CancellationToken, SharedThreadManager) -> Fut + Send + 'static,
339 Fut: Future<Output = Result<(), String>> + Send + 'static,
340{
341 let cancel_token = CancellationToken::new();
342 let (result_tx, result_rx) = oneshot::channel();
343
344 let thread_id = {
345 let mut mgr = thread_mgr.write().await;
346 let id = mgr.create_background(
347 &options.label,
348 options.description.as_deref().unwrap_or(&options.label),
349 options.parent_id,
350 );
351 if let Some(desc) = &options.description {
352 mgr.set_description(id, desc);
353 }
354 id
355 };
356
357 debug!(thread_id = %thread_id, label = %options.label, "Spawning background thread");
358
359 let token = cancel_token.clone();
360 let mgr = thread_mgr.clone();
361 let tid = thread_id;
362
363 let join_handle = tokio::spawn(async move {
364 let result = tokio::select! {
365 _ = token.cancelled() => {
366 Err("Cancelled".to_string())
367 }
368 res = task_fn(token.clone(), mgr.clone()) => {
369 res
370 }
371 };
372
373 {
374 let mut mgr_guard = mgr.write().await;
375 match &result {
376 Ok(()) => {
377 mgr_guard.complete(tid, Some("Finished".to_string()), None);
378 debug!(thread_id = %tid, "Background thread finished");
379 }
380 Err(e) if e == "Cancelled" => {
381 mgr_guard.set_status(tid, ThreadStatus::Cancelled);
382 debug!(thread_id = %tid, "Background thread cancelled");
383 }
384 Err(e) => {
385 mgr_guard.fail(tid, e);
386 warn!(thread_id = %tid, error = %e, "Background thread failed");
387 }
388 }
389 }
390
391 let _ = result_tx.send(result);
392 });
393
394 SubtaskHandle {
395 thread_id,
396 cancel_token,
397 result_rx: Some(result_rx),
398 thread_mgr,
399 _join_handle: Some(join_handle),
400 }
401}
402
/// Registry of cancellable subtasks keyed by their thread id.
///
/// Entries are added with `register` and removed by `cancel`, `cancel_all`, or
/// `remove`; nothing here removes an entry when its task finishes on its own.
pub struct SubtaskRegistry {
    /// Cancellation token and label for each registered subtask.
    handles: std::collections::HashMap<ThreadId, RegistryEntry>,
}
410
/// Per-subtask bookkeeping stored in a `SubtaskRegistry`.
struct RegistryEntry {
    /// Clone of the subtask's cancellation token.
    cancel_token: CancellationToken,
    /// Label supplied at registration time and returned by `list`.
    label: String,
}
420
421impl SubtaskRegistry {
422 pub fn new() -> Self {
424 Self {
425 handles: std::collections::HashMap::new(),
426 }
427 }
428
429 pub fn register<T: Send + 'static>(
431 &mut self,
432 handle: &SubtaskHandle<T>,
433 label: impl Into<String>,
434 ) {
435 self.handles.insert(
436 handle.thread_id,
437 RegistryEntry {
438 cancel_token: handle.cancel_token.clone(),
439 label: label.into(),
440 },
441 );
442 }
443
444 pub fn cancel(&mut self, thread_id: &ThreadId) -> bool {
446 if let Some(entry) = self.handles.remove(thread_id) {
447 entry.cancel_token.cancel();
448 true
449 } else {
450 false
451 }
452 }
453
454 pub fn cancel_all(&mut self) {
456 for (_, entry) in self.handles.drain() {
457 entry.cancel_token.cancel();
458 }
459 }
460
461 pub fn list(&self) -> Vec<(ThreadId, String)> {
463 self.handles
464 .iter()
465 .map(|(id, entry)| (*id, entry.label.clone()))
466 .collect()
467 }
468
469 pub fn remove(&mut self, thread_id: &ThreadId) {
471 self.handles.remove(thread_id);
472 }
473
474 pub fn count(&self) -> usize {
476 self.handles.len()
477 }
478}
479
480impl Default for SubtaskRegistry {
481 fn default() -> Self {
482 Self::new()
483 }
484}
485
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::sync::RwLock;

    /// Creates an empty thread manager wrapped the same way `SharedThreadManager` is.
    fn make_thread_mgr() -> SharedThreadManager {
        Arc::new(RwLock::new(super::super::ThreadManager::new()))
    }

    /// Task body that idles until its token is cancelled, then succeeds.
    async fn wait_for_cancel(
        token: CancellationToken,
        _mgr: SharedThreadManager,
    ) -> Result<(), String> {
        token.cancelled().await;
        Ok(())
    }

    #[tokio::test]
    async fn test_spawn_and_join_subagent() {
        let mgr = make_thread_mgr();

        let handle = spawn_subagent(
            mgr.clone(),
            SpawnOptions::new("Test Subagent").with_description("Doing work"),
            |_token, _mgr| async move { Ok("result!".to_string()) },
        )
        .await;

        let thread_id = handle.thread_id;
        let result = handle.join().await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "result!");

        let mgr_guard = mgr.read().await;
        let thread = mgr_guard.get(thread_id).unwrap();
        assert!(thread.status.is_terminal());
    }

    #[tokio::test]
    async fn test_spawn_and_cancel() {
        let mgr = make_thread_mgr();

        let handle: SubtaskHandle<String> = spawn_subagent(
            mgr.clone(),
            SpawnOptions::new("Cancellable"),
            |token, _mgr| async move {
                // Block until the handle requests cancellation.
                token.cancelled().await;
                Err("Cancelled".to_string())
            },
        )
        .await;

        let thread_id = handle.thread_id;

        handle.cancel();
        assert!(handle.is_cancelled());

        let result = handle.join().await;
        assert_eq!(result, Err("Cancelled".to_string()));

        // The final status is written before the result is sent, so no sleep is
        // needed: once `join` returns, the manager already reflects the outcome.
        let mgr_guard = mgr.read().await;
        let thread = mgr_guard.get(thread_id).unwrap();
        assert!(thread.status.is_terminal());
        assert!(matches!(thread.status, ThreadStatus::Cancelled));
    }

    #[tokio::test]
    async fn test_spawn_task() {
        let mgr = make_thread_mgr();

        let handle = spawn_task(
            mgr.clone(),
            SpawnOptions::new("Quick Task"),
            |_token, _mgr| async move { Ok(42i64) },
        )
        .await;

        let result = handle.join().await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
    }

    #[tokio::test]
    async fn test_spawn_with_description_update() {
        let mgr = make_thread_mgr();

        let handle = spawn_subagent(
            mgr.clone(),
            SpawnOptions::new("Descriptive Task"),
            |_token, _mgr| async move {
                // Keep the task alive long enough for the description update below.
                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                Ok("done".to_string())
            },
        )
        .await;

        handle.set_description("Phase 2: Processing").await;

        let thread_id = handle.thread_id;
        {
            let mgr_guard = mgr.read().await;
            let thread = mgr_guard.get(thread_id).unwrap();
            assert_eq!(
                thread.description.as_deref(),
                Some("Phase 2: Processing")
            );
        }

        let _ = handle.join().await;
    }

    #[tokio::test]
    async fn test_subtask_failure() {
        let mgr = make_thread_mgr();

        let handle = spawn_subagent(
            mgr.clone(),
            SpawnOptions::new("Failing Task"),
            |_token, _mgr| async move {
                Err::<String, _>("something went wrong".to_string())
            },
        )
        .await;

        let thread_id = handle.thread_id;
        let result = handle.join().await;
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "something went wrong");

        let mgr_guard = mgr.read().await;
        let thread = mgr_guard.get(thread_id).unwrap();
        assert!(matches!(thread.status, ThreadStatus::Failed { .. }));
    }

    #[tokio::test]
    async fn test_drop_without_join_cancels() {
        let mgr = make_thread_mgr();

        let handle: SubtaskHandle<()> =
            spawn_background(mgr.clone(), SpawnOptions::new("Dropped"), wait_for_cancel).await;

        let token = handle.cancel_token();
        assert!(!token.is_cancelled());

        drop(handle);
        assert!(token.is_cancelled());
    }

    #[tokio::test]
    async fn test_subtask_registry() {
        let mut registry = SubtaskRegistry::new();
        assert_eq!(registry.count(), 0);
        assert!(registry.list().is_empty());

        let mgr = make_thread_mgr();
        let first =
            spawn_background(mgr.clone(), SpawnOptions::new("First"), wait_for_cancel).await;
        let second =
            spawn_background(mgr.clone(), SpawnOptions::new("Second"), wait_for_cancel).await;

        registry.register(&first, "first");
        registry.register(&second, "second");
        assert_eq!(registry.count(), 2);

        // Cancelling through the registry cancels the handle's token and forgets it.
        assert!(registry.cancel(&first.thread_id));
        assert!(first.is_cancelled());
        assert!(!second.is_cancelled());
        assert!(!registry.cancel(&first.thread_id));

        let listed = registry.list();
        assert_eq!(listed.len(), 1);
        assert!(listed[0].0 == second.thread_id);
        assert_eq!(listed[0].1, "second");

        registry.cancel_all();
        assert!(second.is_cancelled());
        assert_eq!(registry.count(), 0);
    }
}