Skip to main content

tower_mcp/
async_task.rs

//! Async task management for long-running MCP operations
//!
//! This module provides task lifecycle management for operations that may take
//! longer than a typical request/response cycle. Tasks can be enqueued, tracked,
//! polled for status, and cancelled.
//!
//! # Example
//!
//! ```rust,ignore
//! use tower_mcp::async_task::TaskStore;
//! use tower_mcp::protocol::{CallToolResult, TaskStatus};
//!
//! // Create a task store
//! let store = TaskStore::new();
//!
//! // Enqueue a task; hand `token` to the spawned work so it can observe cancellation
//! let (task_id, token) = store.create_task("my-tool", serde_json::json!({"key": "value"}), None);
//!
//! // Get task status
//! let info = store.get_task(&task_id).expect("task was just created");
//! assert_eq!(info.status, TaskStatus::Working);
//!
//! // Mark task as complete
//! store.complete_task(&task_id, CallToolResult::text("done"));
//! ```
25
26use std::collections::HashMap;
27use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
28use std::sync::{Arc, RwLock};
29use std::time::{Duration, Instant};
30
31use crate::protocol::{CallToolResult, TaskInfo, TaskStatus};
32
/// How long a task is kept after reaching a terminal state: 300 seconds (5 minutes).
const DEFAULT_TTL_SECS: u64 = 5 * 60;

/// Polling interval, in seconds, suggested to clients that check task status (2 seconds).
const DEFAULT_POLL_INTERVAL_SECS: u64 = 2;
38
39/// Internal task representation with full state
40#[derive(Debug)]
41pub struct Task {
42    /// Unique task identifier
43    pub id: String,
44    /// Name of the tool being executed
45    pub tool_name: String,
46    /// Arguments passed to the tool
47    pub arguments: serde_json::Value,
48    /// Current task status
49    pub status: TaskStatus,
50    /// When the task was created
51    pub created_at: Instant,
52    /// ISO 8601 timestamp string
53    pub created_at_str: String,
54    /// Time-to-live in seconds (for cleanup after completion)
55    pub ttl: u64,
56    /// Suggested polling interval in seconds
57    pub poll_interval: u64,
58    /// Current progress (0.0 - 100.0)
59    pub progress: Option<f64>,
60    /// Human-readable status message
61    pub message: Option<String>,
62    /// The result of the tool call (when completed)
63    pub result: Option<CallToolResult>,
64    /// Error message (when failed)
65    pub error: Option<String>,
66    /// Cancellation token for aborting the task
67    pub cancellation_token: CancellationToken,
68    /// When the task reached terminal status (for TTL tracking)
69    pub completed_at: Option<Instant>,
70}
71
72impl Task {
73    /// Create a new task
74    fn new(id: String, tool_name: String, arguments: serde_json::Value, ttl: Option<u64>) -> Self {
75        let cancelled = Arc::new(AtomicBool::new(false));
76        Self {
77            id,
78            tool_name,
79            arguments,
80            status: TaskStatus::Working,
81            created_at: Instant::now(),
82            created_at_str: chrono_now_iso8601(),
83            ttl: ttl.unwrap_or(DEFAULT_TTL_SECS),
84            poll_interval: DEFAULT_POLL_INTERVAL_SECS,
85            progress: None,
86            message: Some("Task started".to_string()),
87            result: None,
88            error: None,
89            cancellation_token: CancellationToken { cancelled },
90            completed_at: None,
91        }
92    }
93
94    /// Convert to TaskInfo for API responses
95    pub fn to_info(&self) -> TaskInfo {
96        TaskInfo {
97            task_id: self.id.clone(),
98            status: self.status,
99            created_at: self.created_at_str.clone(),
100            ttl: Some(self.ttl),
101            poll_interval: Some(self.poll_interval),
102            progress: self.progress,
103            message: self.message.clone(),
104        }
105    }
106
107    /// Check if this task should be cleaned up (TTL expired)
108    pub fn is_expired(&self) -> bool {
109        if let Some(completed_at) = self.completed_at {
110            completed_at.elapsed() > Duration::from_secs(self.ttl)
111        } else {
112            false
113        }
114    }
115
116    /// Check if the task has been cancelled
117    pub fn is_cancelled(&self) -> bool {
118        self.cancellation_token.is_cancelled()
119    }
120}
121
/// A cheaply clonable, thread-safe cancellation flag shared between a task's
/// store entry and the work executing it.
///
/// All clones observe the same flag. Once [`cancel`](Self::cancel) has been
/// called, the flag stays set; no method clears it.
#[derive(Debug, Clone, Default)]
pub struct CancellationToken {
    cancelled: Arc<AtomicBool>,
}

impl CancellationToken {
    /// Create a new token that has not been cancelled.
    pub fn new() -> Self {
        Self::default()
    }

    /// Check if cancellation has been requested.
    ///
    /// Uses `Acquire` ordering, paired with the `Release` store in
    /// [`cancel`](Self::cancel). Once this returns `true`, writes the
    /// cancelling thread made before calling `cancel` are visible here.
    pub fn is_cancelled(&self) -> bool {
        self.cancelled.load(Ordering::Acquire)
    }

    /// Request cancellation. Calling it more than once has no further effect.
    pub fn cancel(&self) {
        self.cancelled.store(true, Ordering::Release);
    }
}
139
140/// Thread-safe task storage
141#[derive(Debug, Clone)]
142pub struct TaskStore {
143    tasks: Arc<RwLock<HashMap<String, Task>>>,
144    next_id: Arc<AtomicU64>,
145}
146
147impl Default for TaskStore {
148    fn default() -> Self {
149        Self::new()
150    }
151}
152
153impl TaskStore {
154    /// Create a new task store
155    pub fn new() -> Self {
156        Self {
157            tasks: Arc::new(RwLock::new(HashMap::new())),
158            next_id: Arc::new(AtomicU64::new(1)),
159        }
160    }
161
162    /// Generate a unique task ID
163    fn generate_id(&self) -> String {
164        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
165        format!("task-{}", id)
166    }
167
168    /// Create and store a new task
169    ///
170    /// Returns the task ID and a cancellation token for the spawned work.
171    pub fn create_task(
172        &self,
173        tool_name: &str,
174        arguments: serde_json::Value,
175        ttl: Option<u64>,
176    ) -> (String, CancellationToken) {
177        let id = self.generate_id();
178        let task = Task::new(id.clone(), tool_name.to_string(), arguments, ttl);
179        let token = task.cancellation_token.clone();
180
181        if let Ok(mut tasks) = self.tasks.write() {
182            tasks.insert(id.clone(), task);
183        }
184
185        (id, token)
186    }
187
188    /// Get task info by ID
189    pub fn get_task(&self, task_id: &str) -> Option<TaskInfo> {
190        if let Ok(tasks) = self.tasks.read() {
191            tasks.get(task_id).map(|t| t.to_info())
192        } else {
193            None
194        }
195    }
196
197    /// Get full task data by ID (for internal use)
198    pub fn get_task_full(
199        &self,
200        task_id: &str,
201    ) -> Option<(TaskStatus, Option<CallToolResult>, Option<String>)> {
202        if let Ok(tasks) = self.tasks.read() {
203            tasks
204                .get(task_id)
205                .map(|t| (t.status, t.result.clone(), t.error.clone()))
206        } else {
207            None
208        }
209    }
210
211    /// List all tasks, optionally filtered by status
212    pub fn list_tasks(&self, status_filter: Option<TaskStatus>) -> Vec<TaskInfo> {
213        if let Ok(tasks) = self.tasks.read() {
214            tasks
215                .values()
216                .filter(|t| status_filter.is_none() || status_filter == Some(t.status))
217                .map(|t| t.to_info())
218                .collect()
219        } else {
220            vec![]
221        }
222    }
223
224    /// Update task progress
225    pub fn update_progress(&self, task_id: &str, progress: f64, message: Option<String>) -> bool {
226        if let Ok(mut tasks) = self.tasks.write()
227            && let Some(task) = tasks.get_mut(task_id)
228            && !task.status.is_terminal()
229        {
230            task.progress = Some(progress);
231            if let Some(msg) = message {
232                task.message = Some(msg);
233            }
234            return true;
235        }
236        false
237    }
238
239    /// Mark a task as requiring input
240    pub fn require_input(&self, task_id: &str, message: &str) -> bool {
241        if let Ok(mut tasks) = self.tasks.write()
242            && let Some(task) = tasks.get_mut(task_id)
243            && !task.status.is_terminal()
244        {
245            task.status = TaskStatus::InputRequired;
246            task.message = Some(message.to_string());
247            return true;
248        }
249        false
250    }
251
252    /// Mark a task as completed with a result
253    pub fn complete_task(&self, task_id: &str, result: CallToolResult) -> bool {
254        if let Ok(mut tasks) = self.tasks.write()
255            && let Some(task) = tasks.get_mut(task_id)
256            && !task.status.is_terminal()
257        {
258            task.status = TaskStatus::Completed;
259            task.progress = Some(100.0);
260            task.message = Some("Task completed".to_string());
261            task.result = Some(result);
262            task.completed_at = Some(Instant::now());
263            return true;
264        }
265        false
266    }
267
268    /// Mark a task as failed with an error
269    pub fn fail_task(&self, task_id: &str, error: &str) -> bool {
270        if let Ok(mut tasks) = self.tasks.write()
271            && let Some(task) = tasks.get_mut(task_id)
272            && !task.status.is_terminal()
273        {
274            task.status = TaskStatus::Failed;
275            task.message = Some(format!("Task failed: {}", error));
276            task.error = Some(error.to_string());
277            task.completed_at = Some(Instant::now());
278            return true;
279        }
280        false
281    }
282
283    /// Cancel a task
284    pub fn cancel_task(&self, task_id: &str, reason: Option<&str>) -> Option<TaskStatus> {
285        if let Ok(mut tasks) = self.tasks.write()
286            && let Some(task) = tasks.get_mut(task_id)
287        {
288            // Signal cancellation
289            task.cancellation_token.cancel();
290
291            // If not already terminal, mark as cancelled
292            if !task.status.is_terminal() {
293                task.status = TaskStatus::Cancelled;
294                task.message = Some(
295                    reason
296                        .map(|r| format!("Cancelled: {}", r))
297                        .unwrap_or_else(|| "Task cancelled".to_string()),
298                );
299                task.completed_at = Some(Instant::now());
300            }
301            return Some(task.status);
302        }
303        None
304    }
305
306    /// Remove expired tasks (call periodically for cleanup)
307    pub fn cleanup_expired(&self) -> usize {
308        if let Ok(mut tasks) = self.tasks.write() {
309            let before = tasks.len();
310            tasks.retain(|_, t| !t.is_expired());
311            before - tasks.len()
312        } else {
313            0
314        }
315    }
316
317    /// Get the number of tasks in the store
318    #[cfg(test)]
319    pub fn len(&self) -> usize {
320        if let Ok(tasks) = self.tasks.read() {
321            tasks.len()
322        } else {
323            0
324        }
325    }
326
327    /// Check if the store is empty
328    #[cfg(test)]
329    pub fn is_empty(&self) -> bool {
330        self.len() == 0
331    }
332}
333
/// Format the current wall-clock time as a UTC ISO 8601 timestamp with
/// millisecond precision, e.g. `2024-05-01T12:34:56.789Z`.
///
/// A system clock set before the Unix epoch is reported as the epoch itself.
fn chrono_now_iso8601() -> String {
    use std::time::SystemTime;

    const SECS_PER_DAY: u64 = 86_400;

    let since_epoch = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .unwrap_or_default();
    let total_secs = since_epoch.as_secs();
    let millis = since_epoch.subsec_millis();

    // Split into whole days since 1970-01-01 and the time of day.
    let (day_count, secs_of_day) = (total_secs / SECS_PER_DAY, total_secs % SECS_PER_DAY);
    let hour = secs_of_day / 3600;
    let minute = secs_of_day % 3600 / 60;
    let second = secs_of_day % 60;

    let (year, month, day) = civil_date_from_epoch_days(day_count as i32);

    format!(
        "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}.{:03}Z",
        year, month, day, hour, minute, second, millis
    )
}

/// Convert a count of days since 1970-01-01 into a `(year, month, day)` triple,
/// with month and day both 1-based.
fn civil_date_from_epoch_days(mut days_left: i32) -> (i32, i32, i32) {
    let year_length = |y: i32| if is_leap_year(y) { 366 } else { 365 };

    // Step forward whole years from the epoch year.
    let mut year = 1970i32;
    while days_left >= year_length(year) {
        days_left -= year_length(year);
        year += 1;
    }

    // Then step through whole months within that year.
    let february = if is_leap_year(year) { 29 } else { 28 };
    let month_lengths: [i32; 12] = [31, february, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31];
    let mut month = 1;
    for length in month_lengths {
        if days_left < length {
            break;
        }
        days_left -= length;
        month += 1;
    }

    (year, month, days_left + 1)
}

/// Gregorian leap-year rule: divisible by 4, except centuries that are not
/// divisible by 400.
fn is_leap_year(year: i32) -> bool {
    match (year % 4, year % 100, year % 400) {
        (_, _, 0) => true,
        (_, 0, _) => false,
        (0, _, _) => true,
        _ => false,
    }
}
395
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_create_task() {
        let store = TaskStore::new();
        let (id, token) = store.create_task("test-tool", serde_json::json!({"a": 1}), None);

        assert!(id.starts_with("task-"));
        assert!(!token.is_cancelled());

        let info = store.get_task(&id).expect("task should exist");
        assert_eq!(info.task_id, id);
        assert_eq!(info.status, TaskStatus::Working);
    }

    #[test]
    fn test_task_lifecycle() {
        let store = TaskStore::new();
        let (id, _) = store.create_task("test-tool", serde_json::json!({}), None);

        // Update progress
        assert!(store.update_progress(&id, 50.0, Some("Halfway".to_string())));

        let info = store.get_task(&id).unwrap();
        assert_eq!(info.progress, Some(50.0));
        assert_eq!(info.message.as_deref(), Some("Halfway"));

        // Complete task
        assert!(store.complete_task(&id, CallToolResult::text("Done")));

        let info = store.get_task(&id).unwrap();
        assert_eq!(info.status, TaskStatus::Completed);
        assert_eq!(info.progress, Some(100.0));
    }

    #[test]
    fn test_task_cancellation() {
        let store = TaskStore::new();
        let (id, token) = store.create_task("test-tool", serde_json::json!({}), None);

        assert!(!token.is_cancelled());

        let status = store.cancel_task(&id, Some("User requested"));
        assert_eq!(status, Some(TaskStatus::Cancelled));
        assert!(token.is_cancelled());

        let info = store.get_task(&id).unwrap();
        assert_eq!(info.status, TaskStatus::Cancelled);
    }

    #[test]
    fn test_task_failure() {
        let store = TaskStore::new();
        let (id, _) = store.create_task("test-tool", serde_json::json!({}), None);

        assert!(store.fail_task(&id, "Something went wrong"));

        let info = store.get_task(&id).unwrap();
        assert_eq!(info.status, TaskStatus::Failed);
        assert!(info.message.as_ref().unwrap().contains("failed"));
    }

    #[test]
    fn test_list_tasks() {
        let store = TaskStore::new();
        store.create_task("tool1", serde_json::json!({}), None);
        store.create_task("tool2", serde_json::json!({}), None);
        let (id3, _) = store.create_task("tool3", serde_json::json!({}), None);

        // Complete one task
        store.complete_task(&id3, CallToolResult::text("Done"));

        // List all tasks
        let all = store.list_tasks(None);
        assert_eq!(all.len(), 3);

        // List only working tasks
        let working = store.list_tasks(Some(TaskStatus::Working));
        assert_eq!(working.len(), 2);

        // List only completed tasks
        let completed = store.list_tasks(Some(TaskStatus::Completed));
        assert_eq!(completed.len(), 1);
    }

    #[test]
    fn test_terminal_state_immutable() {
        let store = TaskStore::new();
        let (id, _) = store.create_task("test-tool", serde_json::json!({}), None);

        // Complete the task
        store.complete_task(&id, CallToolResult::text("Done"));

        // Try to update - should fail
        assert!(!store.update_progress(&id, 50.0, None));
        assert!(!store.fail_task(&id, "Error"));

        // Status should still be completed
        let info = store.get_task(&id).unwrap();
        assert_eq!(info.status, TaskStatus::Completed);
    }

    #[test]
    fn test_task_ids_unique() {
        let store = TaskStore::new();
        let (id1, _) = store.create_task("tool", serde_json::json!({}), None);
        let (id2, _) = store.create_task("tool", serde_json::json!({}), None);
        let (id3, _) = store.create_task("tool", serde_json::json!({}), None);

        assert_ne!(id1, id2);
        assert_ne!(id2, id3);
        assert_ne!(id1, id3);
    }

    #[test]
    fn test_get_task_full() {
        let store = TaskStore::new();
        let (id, _) = store.create_task("test-tool", serde_json::json!({}), None);

        // Complete with result
        let result = CallToolResult::text("The result");
        store.complete_task(&id, result);

        let (status, result, error) = store.get_task_full(&id).unwrap();
        assert_eq!(status, TaskStatus::Completed);
        assert!(result.is_some());
        assert!(error.is_none());
    }

    #[test]
    fn test_require_input() {
        let store = TaskStore::new();
        let (id, _) = store.create_task("test-tool", serde_json::json!({}), None);

        assert!(store.require_input(&id, "Need confirmation"));

        let info = store.get_task(&id).unwrap();
        assert_eq!(info.status, TaskStatus::InputRequired);
        assert_eq!(info.message.as_deref(), Some("Need confirmation"));

        // InputRequired is not terminal, so the task can still be completed
        assert!(store.complete_task(&id, CallToolResult::text("Done")));
        assert_eq!(store.get_task(&id).unwrap().status, TaskStatus::Completed);
    }

    #[test]
    fn test_cancel_terminal_task_keeps_status() {
        let store = TaskStore::new();
        let (id, token) = store.create_task("test-tool", serde_json::json!({}), None);

        assert!(store.complete_task(&id, CallToolResult::text("Done")));

        // Cancelling a finished task reports its existing status...
        assert_eq!(store.cancel_task(&id, None), Some(TaskStatus::Completed));
        assert_eq!(store.get_task(&id).unwrap().status, TaskStatus::Completed);
        // ...but still signals the token
        assert!(token.is_cancelled());
    }

    #[test]
    fn test_cancel_without_reason_message() {
        let store = TaskStore::new();
        let (id, _) = store.create_task("test-tool", serde_json::json!({}), None);

        assert_eq!(store.cancel_task(&id, None), Some(TaskStatus::Cancelled));
        let info = store.get_task(&id).unwrap();
        assert_eq!(info.message.as_deref(), Some("Task cancelled"));
    }

    #[test]
    fn test_unknown_task_id() {
        let store = TaskStore::new();
        let missing = "task-999";

        assert!(store.get_task(missing).is_none());
        assert!(store.get_task_full(missing).is_none());
        assert!(!store.update_progress(missing, 10.0, None));
        assert!(!store.require_input(missing, "x"));
        assert!(!store.complete_task(missing, CallToolResult::text("x")));
        assert!(!store.fail_task(missing, "x"));
        assert_eq!(store.cancel_task(missing, None), None);
    }

    #[test]
    fn test_cleanup_expired() {
        let store = TaskStore::new();
        let (done, _) = store.create_task("tool", serde_json::json!({}), Some(0));
        let (running, _) = store.create_task("tool", serde_json::json!({}), Some(0));
        assert!(store.complete_task(&done, CallToolResult::text("Done")));

        // Let the zero-second TTL elapse
        std::thread::sleep(std::time::Duration::from_millis(10));

        // Only the finished task expires; unfinished tasks are never removed
        assert_eq!(store.cleanup_expired(), 1);
        assert_eq!(store.len(), 1);
        assert!(!store.is_empty());
        assert!(store.get_task(&done).is_none());
        assert!(store.get_task(&running).is_some());
    }

    #[test]
    fn test_iso8601_timestamp() {
        let ts = chrono_now_iso8601();
        // Basic format check
        assert!(ts.ends_with('Z'));
        assert!(ts.contains('T'));
        assert_eq!(ts.len(), 24); // YYYY-MM-DDTHH:MM:SS.mmmZ
    }

    #[test]
    fn test_task_status_display() {
        assert_eq!(TaskStatus::Working.to_string(), "working");
        assert_eq!(TaskStatus::InputRequired.to_string(), "input_required");
        assert_eq!(TaskStatus::Completed.to_string(), "completed");
        assert_eq!(TaskStatus::Failed.to_string(), "failed");
        assert_eq!(TaskStatus::Cancelled.to_string(), "cancelled");
    }

    #[test]
    fn test_task_status_is_terminal() {
        assert!(!TaskStatus::Working.is_terminal());
        assert!(!TaskStatus::InputRequired.is_terminal());
        assert!(TaskStatus::Completed.is_terminal());
        assert!(TaskStatus::Failed.is_terminal());
        assert!(TaskStatus::Cancelled.is_terminal());
    }
}