Skip to main content

rustyclaw_core/threads/
subtask.rs

//! Subtask abstraction — async spawn/join integrated with ThreadManager.
//!
//! This module provides the unified mechanism for running concurrent work:
//! - Subagents spawned by the main agent (joinable, return a result)
//! - Background tasks (long-running, may not return)
//! - One-shot tasks (quick async work that returns a result)
//!
//! All subtasks:
//! - Run as tokio tasks
//! - Create a thread entry in ThreadManager on spawn
//! - Update thread status on completion, failure, or cancellation
//! - Support cancellation via CancellationToken
//! - Support agent-settable descriptions (shown in sidebar)

15use super::{SharedThreadManager, ThreadId, ThreadStatus};
16use std::future::Future;
17use tokio::sync::oneshot;
18use tokio_util::sync::CancellationToken;
19use tracing::{debug, warn};
20
/// Final outcome of a subtask, expressed as a plain value that can be stored
/// or reported after the subtask has stopped.
#[derive(Clone, Debug)]
pub enum SubtaskResult {
    /// Finished normally; carries the result text if the subtask produced any.
    Ok(Option<String>),
    /// Finished with an error; carries the error message.
    Err(String),
    /// Stopped because it was cancelled.
    Cancelled,
}
31
32/// Handle to a running subtask. Allows joining, cancelling, and updating status.
33///
34/// The type parameter `T` is the return type of the async function.
35/// For subagents that return string results, use `SubtaskHandle<String>`.
36pub struct SubtaskHandle<T: Send + 'static> {
37    /// The thread ID in the ThreadManager.
38    pub thread_id: ThreadId,
39
40    /// Cancellation token — cancel the subtask by calling `.cancel()`.
41    cancel_token: CancellationToken,
42
43    /// Oneshot receiver for the result.
44    result_rx: Option<oneshot::Receiver<Result<T, String>>>,
45
46    /// Shared thread manager for status updates.
47    thread_mgr: SharedThreadManager,
48
49    /// The underlying tokio JoinHandle (for abort).
50    _join_handle: Option<tokio::task::JoinHandle<()>>,
51}
52
53impl<T: Send + 'static> SubtaskHandle<T> {
54    /// Wait for the subtask to complete and return its result.
55    ///
56    /// This consumes the handle. After joining, the thread status is updated
57    /// to Completed or Failed.
58    pub async fn join(mut self) -> Result<T, String> {
59        let rx = self
60            .result_rx
61            .take()
62            .ok_or_else(|| "SubtaskHandle already joined".to_string())?;
63
64        match rx.await {
65            Ok(Ok(value)) => Ok(value),
66            Ok(Err(e)) => Err(e),
67            Err(_) => {
68                // Sender dropped — task panicked or was aborted
69                Err("Subtask channel closed unexpectedly".to_string())
70            }
71        }
72    }
73
74    /// Cancel the subtask.
75    ///
76    /// This signals the cancellation token. The subtask's async function
77    /// should check `token.is_cancelled()` or use `token.cancelled().await`.
78    pub fn cancel(&self) {
79        self.cancel_token.cancel();
80    }
81
82    /// Check if the subtask has been cancelled.
83    pub fn is_cancelled(&self) -> bool {
84        self.cancel_token.is_cancelled()
85    }
86
87    /// Update the description shown in the sidebar for this subtask.
88    pub async fn set_description(&self, description: impl Into<String>) {
89        let mut mgr = self.thread_mgr.write().await;
90        mgr.set_description(self.thread_id, description);
91    }
92
93    /// Update the status of the subtask's thread.
94    pub async fn set_status(&self, status: ThreadStatus) {
95        let mut mgr = self.thread_mgr.write().await;
96        mgr.set_status(self.thread_id, status);
97    }
98
99    /// Get the cancellation token (for passing to subtask internals).
100    pub fn cancel_token(&self) -> CancellationToken {
101        self.cancel_token.clone()
102    }
103}
104
105impl<T: Send + 'static> Drop for SubtaskHandle<T> {
106    fn drop(&mut self) {
107        // If the handle is dropped without joining, cancel the subtask.
108        // We only signal via CancellationToken (cooperative) so the spawned
109        // task has a chance to update ThreadManager status before exiting.
110        // Do NOT call handle.abort() — that preempts the status-update code.
111        if self.result_rx.is_some() {
112            self.cancel_token.cancel();
113        }
114    }
115}
116
117/// Options for spawning a subtask.
118#[derive(Debug, Clone)]
119pub struct SpawnOptions {
120    /// Label shown in the sidebar.
121    pub label: String,
122    /// Initial description of what the subtask is doing.
123    pub description: Option<String>,
124    /// Parent thread that spawned this subtask (if any).
125    pub parent_id: Option<ThreadId>,
126}
127
128impl SpawnOptions {
129    /// Create spawn options with just a label.
130    pub fn new(label: impl Into<String>) -> Self {
131        Self {
132            label: label.into(),
133            description: None,
134            parent_id: None,
135        }
136    }
137
138    /// Set the initial description.
139    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
140        self.description = Some(desc.into());
141        self
142    }
143
144    /// Set the parent thread.
145    pub fn with_parent(mut self, parent_id: ThreadId) -> Self {
146        self.parent_id = Some(parent_id);
147        self
148    }
149}
150
151/// Spawn a subagent as an async subtask.
152///
153/// The `task_fn` receives a `CancellationToken` and a `SharedThreadManager`.
154/// It should check the token periodically and return a result.
155///
156/// Returns a `SubtaskHandle` that can be joined to get the result.
157///
158/// # Example
159/// ```ignore
160/// let handle = spawn_subagent(
161///     thread_mgr.clone(),
162///     SpawnOptions::new("Research Task")
163///         .with_description("Searching for information"),
164///     |token, mgr| async move {
165///         // Do async work, checking token.is_cancelled()
166///         Ok("result data".to_string())
167///     },
168/// ).await;
169///
170/// // Later, join to get the result
171/// let result = handle.join().await;
172/// ```
173pub async fn spawn_subagent<F, Fut, T>(
174    thread_mgr: SharedThreadManager,
175    options: SpawnOptions,
176    task_fn: F,
177) -> SubtaskHandle<T>
178where
179    F: FnOnce(CancellationToken, SharedThreadManager) -> Fut + Send + 'static,
180    Fut: Future<Output = Result<T, String>> + Send + 'static,
181    T: Send + 'static,
182{
183    let cancel_token = CancellationToken::new();
184    let (result_tx, result_rx) = oneshot::channel();
185
186    // Create thread entry in ThreadManager
187    let thread_id = {
188        let mut mgr = thread_mgr.write().await;
189        let id = mgr.create_subagent(
190            &options.label,
191            "subtask",
192            options.description.as_deref().unwrap_or(&options.label),
193            options.parent_id,
194        );
195        if let Some(desc) = &options.description {
196            mgr.set_description(id, desc);
197        }
198        id
199    };
200
201    debug!(thread_id = %thread_id, label = %options.label, "Spawning subagent subtask");
202
203    // Spawn the tokio task
204    let token = cancel_token.clone();
205    let mgr = thread_mgr.clone();
206    let tid = thread_id;
207
208    let join_handle = tokio::spawn(async move {
209        let result = tokio::select! {
210            _ = token.cancelled() => {
211                Err("Cancelled".to_string())
212            }
213            res = task_fn(token.clone(), mgr.clone()) => {
214                res
215            }
216        };
217
218        // Update thread status based on result
219        {
220            let mut mgr_guard = mgr.write().await;
221            match &result {
222                Ok(_) => {
223                    mgr_guard.complete(tid, Some("Completed".to_string()), None);
224                    debug!(thread_id = %tid, "Subagent subtask completed");
225                }
226                Err(e) if e == "Cancelled" => {
227                    mgr_guard.set_status(tid, ThreadStatus::Cancelled);
228                    debug!(thread_id = %tid, "Subagent subtask cancelled");
229                }
230                Err(e) => {
231                    mgr_guard.fail(tid, e);
232                    warn!(thread_id = %tid, error = %e, "Subagent subtask failed");
233                }
234            }
235        }
236
237        // Send result through oneshot channel
238        let _ = result_tx.send(result);
239    });
240
241    SubtaskHandle {
242        thread_id,
243        cancel_token,
244        result_rx: Some(result_rx),
245        thread_mgr,
246        _join_handle: Some(join_handle),
247    }
248}
249
250/// Spawn a one-shot background task.
251///
252/// Similar to `spawn_subagent` but creates a Task thread instead of SubAgent.
253/// Best for quick async work that returns a result.
254pub async fn spawn_task<F, Fut, T>(
255    thread_mgr: SharedThreadManager,
256    options: SpawnOptions,
257    task_fn: F,
258) -> SubtaskHandle<T>
259where
260    F: FnOnce(CancellationToken, SharedThreadManager) -> Fut + Send + 'static,
261    Fut: Future<Output = Result<T, String>> + Send + 'static,
262    T: Send + 'static,
263{
264    let cancel_token = CancellationToken::new();
265    let (result_tx, result_rx) = oneshot::channel();
266
267    // Create thread entry in ThreadManager
268    let thread_id = {
269        let mut mgr = thread_mgr.write().await;
270        let id = mgr.create_task(
271            &options.label,
272            options.description.as_deref().unwrap_or(&options.label),
273            options.parent_id,
274        );
275        if let Some(desc) = &options.description {
276            mgr.set_description(id, desc);
277        }
278        id
279    };
280
281    debug!(thread_id = %thread_id, label = %options.label, "Spawning one-shot task");
282
283    let token = cancel_token.clone();
284    let mgr = thread_mgr.clone();
285    let tid = thread_id;
286
287    let join_handle = tokio::spawn(async move {
288        let result = tokio::select! {
289            _ = token.cancelled() => {
290                Err("Cancelled".to_string())
291            }
292            res = task_fn(token.clone(), mgr.clone()) => {
293                res
294            }
295        };
296
297        // Update thread status
298        {
299            let mut mgr_guard = mgr.write().await;
300            match &result {
301                Ok(_) => {
302                    mgr_guard.complete(tid, Some("Completed".to_string()), None);
303                    debug!(thread_id = %tid, "Task completed");
304                }
305                Err(e) if e == "Cancelled" => {
306                    mgr_guard.set_status(tid, ThreadStatus::Cancelled);
307                    debug!(thread_id = %tid, "Task cancelled");
308                }
309                Err(e) => {
310                    mgr_guard.fail(tid, e);
311                    warn!(thread_id = %tid, error = %e, "Task failed");
312                }
313            }
314        }
315
316        let _ = result_tx.send(result);
317    });
318
319    SubtaskHandle {
320        thread_id,
321        cancel_token,
322        result_rx: Some(result_rx),
323        thread_mgr,
324        _join_handle: Some(join_handle),
325    }
326}
327
328/// Spawn a long-running background thread.
329///
330/// Unlike subagents and tasks, background threads don't have a natural
331/// return value. They run until cancelled.
332pub async fn spawn_background<F, Fut>(
333    thread_mgr: SharedThreadManager,
334    options: SpawnOptions,
335    task_fn: F,
336) -> SubtaskHandle<()>
337where
338    F: FnOnce(CancellationToken, SharedThreadManager) -> Fut + Send + 'static,
339    Fut: Future<Output = Result<(), String>> + Send + 'static,
340{
341    let cancel_token = CancellationToken::new();
342    let (result_tx, result_rx) = oneshot::channel();
343
344    let thread_id = {
345        let mut mgr = thread_mgr.write().await;
346        let id = mgr.create_background(
347            &options.label,
348            options.description.as_deref().unwrap_or(&options.label),
349            options.parent_id,
350        );
351        if let Some(desc) = &options.description {
352            mgr.set_description(id, desc);
353        }
354        id
355    };
356
357    debug!(thread_id = %thread_id, label = %options.label, "Spawning background thread");
358
359    let token = cancel_token.clone();
360    let mgr = thread_mgr.clone();
361    let tid = thread_id;
362
363    let join_handle = tokio::spawn(async move {
364        let result = tokio::select! {
365            _ = token.cancelled() => {
366                Err("Cancelled".to_string())
367            }
368            res = task_fn(token.clone(), mgr.clone()) => {
369                res
370            }
371        };
372
373        {
374            let mut mgr_guard = mgr.write().await;
375            match &result {
376                Ok(()) => {
377                    mgr_guard.complete(tid, Some("Finished".to_string()), None);
378                    debug!(thread_id = %tid, "Background thread finished");
379                }
380                Err(e) if e == "Cancelled" => {
381                    mgr_guard.set_status(tid, ThreadStatus::Cancelled);
382                    debug!(thread_id = %tid, "Background thread cancelled");
383                }
384                Err(e) => {
385                    mgr_guard.fail(tid, e);
386                    warn!(thread_id = %tid, error = %e, "Background thread failed");
387                }
388            }
389        }
390
391        let _ = result_tx.send(result);
392    });
393
394    SubtaskHandle {
395        thread_id,
396        cancel_token,
397        result_rx: Some(result_rx),
398        thread_mgr,
399        _join_handle: Some(join_handle),
400    }
401}
402
403/// Registry of active subtask handles for a session.
404///
405/// This is used by the gateway to track all running subtasks and
406/// allow the agent to list, join, or cancel them.
407pub struct SubtaskRegistry {
408    handles: std::collections::HashMap<ThreadId, RegistryEntry>,
409}
410
411/// An entry in the subtask registry.
412///
413/// Cancellation is cooperative: we signal via `CancellationToken` and the
414/// subtask's async function is expected to check `token.is_cancelled()` or
415/// `token.cancelled().await`.
416struct RegistryEntry {
417    cancel_token: CancellationToken,
418    label: String,
419}
420
421impl SubtaskRegistry {
422    /// Create a new empty registry.
423    pub fn new() -> Self {
424        Self {
425            handles: std::collections::HashMap::new(),
426        }
427    }
428
429    /// Register a subtask. Call this after spawning.
430    pub fn register<T: Send + 'static>(
431        &mut self,
432        handle: &SubtaskHandle<T>,
433        label: impl Into<String>,
434    ) {
435        self.handles.insert(
436            handle.thread_id,
437            RegistryEntry {
438                cancel_token: handle.cancel_token.clone(),
439                label: label.into(),
440            },
441        );
442    }
443
444    /// Cancel a subtask by thread ID (cooperative via `CancellationToken`).
445    pub fn cancel(&mut self, thread_id: &ThreadId) -> bool {
446        if let Some(entry) = self.handles.remove(thread_id) {
447            entry.cancel_token.cancel();
448            true
449        } else {
450            false
451        }
452    }
453
454    /// Cancel all subtasks (cooperative via `CancellationToken`).
455    pub fn cancel_all(&mut self) {
456        for (_, entry) in self.handles.drain() {
457            entry.cancel_token.cancel();
458        }
459    }
460
461    /// List active subtask thread IDs with labels.
462    pub fn list(&self) -> Vec<(ThreadId, String)> {
463        self.handles
464            .iter()
465            .map(|(id, entry)| (*id, entry.label.clone()))
466            .collect()
467    }
468
469    /// Remove a completed subtask from the registry.
470    pub fn remove(&mut self, thread_id: &ThreadId) {
471        self.handles.remove(thread_id);
472    }
473
474    /// Get the number of active subtasks.
475    pub fn count(&self) -> usize {
476        self.handles.len()
477    }
478}
479
480impl Default for SubtaskRegistry {
481    fn default() -> Self {
482        Self::new()
483    }
484}
485
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::sync::RwLock;

    fn make_thread_mgr() -> SharedThreadManager {
        Arc::new(RwLock::new(super::super::ThreadManager::new()))
    }

    #[tokio::test]
    async fn test_spawn_and_join_subagent() {
        let mgr = make_thread_mgr();

        let handle = spawn_subagent(
            mgr.clone(),
            SpawnOptions::new("Test Subagent").with_description("Doing work"),
            |_token, _mgr| async move { Ok("result!".to_string()) },
        )
        .await;

        let thread_id = handle.thread_id;
        let result = handle.join().await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "result!");

        // Thread should be marked as completed
        let mgr_guard = mgr.read().await;
        let thread = mgr_guard.get(thread_id).unwrap();
        assert!(thread.status.is_terminal());
    }

    #[tokio::test]
    async fn test_spawn_and_cancel() {
        let mgr = make_thread_mgr();

        let handle: SubtaskHandle<String> = spawn_subagent(
            mgr.clone(),
            SpawnOptions::new("Cancellable"),
            |token, _mgr| async move {
                // Wait until cancelled
                token.cancelled().await;
                Err("Cancelled".to_string())
            },
        )
        .await;

        let thread_id = handle.thread_id;

        // Cancel it
        handle.cancel();

        let result = handle.join().await;
        assert!(result.is_err());

        // No sleep needed: the spawned task records the final status before
        // it sends the result, so the status is visible once `join` returns.
        let mgr_guard = mgr.read().await;
        let thread = mgr_guard.get(thread_id).unwrap();
        assert!(thread.status.is_terminal());
    }

    #[tokio::test]
    async fn test_spawn_task() {
        let mgr = make_thread_mgr();

        let handle = spawn_task(
            mgr.clone(),
            SpawnOptions::new("Quick Task"),
            |_token, _mgr| async move { Ok(42i64) },
        )
        .await;

        let thread_id = handle.thread_id;
        let result = handle.join().await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);

        let mgr_guard = mgr.read().await;
        let thread = mgr_guard.get(thread_id).unwrap();
        assert!(thread.status.is_terminal());
    }

    #[tokio::test]
    async fn test_spawn_background_finishes() {
        let mgr = make_thread_mgr();

        let handle = spawn_background(
            mgr.clone(),
            SpawnOptions::new("Short Background"),
            |_token, _mgr| async move { Ok(()) },
        )
        .await;

        let thread_id = handle.thread_id;
        assert!(handle.join().await.is_ok());

        let mgr_guard = mgr.read().await;
        let thread = mgr_guard.get(thread_id).unwrap();
        assert!(thread.status.is_terminal());
    }

    #[tokio::test]
    async fn test_spawn_with_description_update() {
        let mgr = make_thread_mgr();

        let handle = spawn_subagent(
            mgr.clone(),
            SpawnOptions::new("Descriptive Task"),
            |_token, _mgr| async move {
                // Simulate work
                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                Ok("done".to_string())
            },
        )
        .await;

        // Update description while running
        handle.set_description("Phase 2: Processing").await;

        let thread_id = handle.thread_id;
        {
            let mgr_guard = mgr.read().await;
            let thread = mgr_guard.get(thread_id).unwrap();
            assert_eq!(
                thread.description.as_deref(),
                Some("Phase 2: Processing")
            );
        }

        let _ = handle.join().await;
    }

    #[tokio::test]
    async fn test_subtask_failure() {
        let mgr = make_thread_mgr();

        let handle = spawn_subagent(
            mgr.clone(),
            SpawnOptions::new("Failing Task"),
            |_token, _mgr| async move {
                Err::<String, _>("something went wrong".to_string())
            },
        )
        .await;

        let thread_id = handle.thread_id;
        let result = handle.join().await;
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "something went wrong");

        // Thread should be marked as failed
        let mgr_guard = mgr.read().await;
        let thread = mgr_guard.get(thread_id).unwrap();
        assert!(matches!(thread.status, ThreadStatus::Failed { .. }));
    }

    #[test]
    fn test_subtask_registry() {
        let registry = SubtaskRegistry::new();
        assert_eq!(registry.count(), 0);
        assert!(registry.list().is_empty());
    }

    #[tokio::test]
    async fn test_subtask_registry_register_and_cancel() {
        let mgr = make_thread_mgr();

        // Never completes on its own, so the only way out is cancellation
        // (no race between the work finishing and the token firing).
        let handle = spawn_background(
            mgr.clone(),
            SpawnOptions::new("Registry Bg"),
            |_token, _mgr| std::future::pending::<Result<(), String>>(),
        )
        .await;

        let mut registry = SubtaskRegistry::default();
        registry.register(&handle, "Registry Bg");
        assert_eq!(registry.count(), 1);

        let listed = registry.list();
        assert_eq!(listed.len(), 1);
        assert!(listed[0].0 == handle.thread_id);
        assert_eq!(listed[0].1, "Registry Bg");

        assert!(registry.cancel(&handle.thread_id));
        assert!(handle.is_cancelled());
        assert_eq!(registry.count(), 0);
        // Already removed: a second cancel reports nothing to do.
        assert!(!registry.cancel(&handle.thread_id));

        let result = handle.join().await;
        assert!(result.is_err());
    }
}