Skip to main content

tower_mcp/
async_task.rs

//! Async task management for long-running MCP operations
//!
//! This module provides task lifecycle management for operations that may take
//! longer than a typical request/response cycle. Tasks can be enqueued, tracked,
//! polled for status, and cancelled.
//!
//! # Example
//!
//! ```rust,ignore
//! use tower_mcp::async_task::TaskStore;
//! use tower_mcp::protocol::{CallToolResult, TaskStatus};
//!
//! // Create a task store
//! let store = TaskStore::new();
//!
//! // Enqueue a task; this returns the task ID and a cancellation token
//! let (task_id, cancel_token) =
//!     store.create_task("my-tool", serde_json::json!({"key": "value"}), None);
//!
//! // Get task status
//! let info = store.get_task(&task_id).expect("task was just created");
//! assert_eq!(info.status, TaskStatus::Working);
//!
//! // Mark task as complete
//! store.complete_task(&task_id, CallToolResult::text("done"));
//! ```
25
26use std::collections::HashMap;
27use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
28use std::sync::{Arc, RwLock};
29use std::time::{Duration, Instant};
30
31use crate::protocol::{CallToolResult, TaskInfo, TaskStatus};
32
/// Seconds a task is retained after reaching a terminal status when the
/// caller supplies no TTL of its own (5 minutes).
const DEFAULT_TTL_SECS: u64 = 5 * 60;

/// Polling interval, in seconds, suggested to clients for every new task.
const DEFAULT_POLL_INTERVAL_SECS: u64 = 2;
38
39/// Internal task representation with full state
40#[derive(Debug)]
41pub struct Task {
42    /// Unique task identifier
43    pub id: String,
44    /// Name of the tool being executed
45    pub tool_name: String,
46    /// Arguments passed to the tool
47    pub arguments: serde_json::Value,
48    /// Current task status
49    pub status: TaskStatus,
50    /// When the task was created
51    pub created_at: Instant,
52    /// ISO 8601 timestamp string
53    pub created_at_str: String,
54    /// Time-to-live in seconds (for cleanup after completion)
55    pub ttl: u64,
56    /// Suggested polling interval in seconds
57    pub poll_interval: u64,
58    /// Current progress (0.0 - 100.0)
59    pub progress: Option<f64>,
60    /// Human-readable status message
61    pub message: Option<String>,
62    /// The result of the tool call (when completed)
63    pub result: Option<CallToolResult>,
64    /// Error message (when failed)
65    pub error: Option<String>,
66    /// Cancellation token for aborting the task
67    pub cancellation_token: CancellationToken,
68    /// When the task reached terminal status (for TTL tracking)
69    pub completed_at: Option<Instant>,
70}
71
72impl Task {
73    /// Create a new task
74    fn new(id: String, tool_name: String, arguments: serde_json::Value, ttl: Option<u64>) -> Self {
75        let cancelled = Arc::new(AtomicBool::new(false));
76        Self {
77            id,
78            tool_name,
79            arguments,
80            status: TaskStatus::Working,
81            created_at: Instant::now(),
82            created_at_str: chrono_now_iso8601(),
83            ttl: ttl.unwrap_or(DEFAULT_TTL_SECS),
84            poll_interval: DEFAULT_POLL_INTERVAL_SECS,
85            progress: None,
86            message: Some("Task started".to_string()),
87            result: None,
88            error: None,
89            cancellation_token: CancellationToken { cancelled },
90            completed_at: None,
91        }
92    }
93
94    /// Convert to TaskInfo for API responses
95    pub fn to_info(&self) -> TaskInfo {
96        TaskInfo {
97            task_id: self.id.clone(),
98            status: self.status,
99            created_at: self.created_at_str.clone(),
100            ttl: Some(self.ttl),
101            poll_interval: Some(self.poll_interval),
102            progress: self.progress,
103            message: self.message.clone(),
104        }
105    }
106
107    /// Check if this task should be cleaned up (TTL expired)
108    pub fn is_expired(&self) -> bool {
109        if let Some(completed_at) = self.completed_at {
110            completed_at.elapsed() > Duration::from_secs(self.ttl)
111        } else {
112            false
113        }
114    }
115
116    /// Check if the task has been cancelled
117    pub fn is_cancelled(&self) -> bool {
118        self.cancellation_token.is_cancelled()
119    }
120}
121
/// Cloneable handle used to signal that a task should stop.
///
/// Every clone holds the same reference-counted flag, so a cancellation
/// requested through one handle is observed through all of them.
#[derive(Clone, Debug)]
pub struct CancellationToken {
    /// Shared flag; `true` once cancellation has been requested.
    cancelled: Arc<AtomicBool>,
}
127
128impl CancellationToken {
129    /// Check if cancellation has been requested
130    pub fn is_cancelled(&self) -> bool {
131        self.cancelled.load(Ordering::Relaxed)
132    }
133
134    /// Request cancellation
135    pub fn cancel(&self) {
136        self.cancelled.store(true, Ordering::Relaxed);
137    }
138}
139
140/// Thread-safe task storage
141#[derive(Debug, Clone)]
142pub struct TaskStore {
143    tasks: Arc<RwLock<HashMap<String, Task>>>,
144    next_id: Arc<AtomicU64>,
145}
146
147impl Default for TaskStore {
148    fn default() -> Self {
149        Self::new()
150    }
151}
152
153impl TaskStore {
154    /// Create a new task store
155    pub fn new() -> Self {
156        Self {
157            tasks: Arc::new(RwLock::new(HashMap::new())),
158            next_id: Arc::new(AtomicU64::new(1)),
159        }
160    }
161
162    /// Generate a unique task ID
163    fn generate_id(&self) -> String {
164        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
165        format!("task-{}", id)
166    }
167
168    /// Create and store a new task
169    ///
170    /// Returns the task ID and a cancellation token for the spawned work.
171    pub fn create_task(
172        &self,
173        tool_name: &str,
174        arguments: serde_json::Value,
175        ttl: Option<u64>,
176    ) -> (String, CancellationToken) {
177        let id = self.generate_id();
178        let task = Task::new(id.clone(), tool_name.to_string(), arguments, ttl);
179        let token = task.cancellation_token.clone();
180
181        if let Ok(mut tasks) = self.tasks.write() {
182            tasks.insert(id.clone(), task);
183        }
184
185        (id, token)
186    }
187
188    /// Get task info by ID
189    pub fn get_task(&self, task_id: &str) -> Option<TaskInfo> {
190        if let Ok(tasks) = self.tasks.read() {
191            tasks.get(task_id).map(|t| t.to_info())
192        } else {
193            None
194        }
195    }
196
197    /// Get full task data by ID (for internal use)
198    pub fn get_task_full(
199        &self,
200        task_id: &str,
201    ) -> Option<(TaskStatus, Option<CallToolResult>, Option<String>)> {
202        if let Ok(tasks) = self.tasks.read() {
203            tasks
204                .get(task_id)
205                .map(|t| (t.status, t.result.clone(), t.error.clone()))
206        } else {
207            None
208        }
209    }
210
211    /// List all tasks, optionally filtered by status
212    pub fn list_tasks(&self, status_filter: Option<TaskStatus>) -> Vec<TaskInfo> {
213        if let Ok(tasks) = self.tasks.read() {
214            tasks
215                .values()
216                .filter(|t| status_filter.is_none() || status_filter == Some(t.status))
217                .map(|t| t.to_info())
218                .collect()
219        } else {
220            vec![]
221        }
222    }
223
224    /// Update task progress
225    pub fn update_progress(&self, task_id: &str, progress: f64, message: Option<String>) -> bool {
226        let Ok(mut tasks) = self.tasks.write() else {
227            return false;
228        };
229        let Some(task) = tasks.get_mut(task_id) else {
230            return false;
231        };
232        if task.status.is_terminal() {
233            return false;
234        }
235        task.progress = Some(progress);
236        if let Some(msg) = message {
237            task.message = Some(msg);
238        }
239        true
240    }
241
242    /// Mark a task as requiring input
243    pub fn require_input(&self, task_id: &str, message: &str) -> bool {
244        let Ok(mut tasks) = self.tasks.write() else {
245            return false;
246        };
247        let Some(task) = tasks.get_mut(task_id) else {
248            return false;
249        };
250        if task.status.is_terminal() {
251            return false;
252        }
253        task.status = TaskStatus::InputRequired;
254        task.message = Some(message.to_string());
255        true
256    }
257
258    /// Mark a task as completed with a result
259    pub fn complete_task(&self, task_id: &str, result: CallToolResult) -> bool {
260        let Ok(mut tasks) = self.tasks.write() else {
261            return false;
262        };
263        let Some(task) = tasks.get_mut(task_id) else {
264            return false;
265        };
266        if task.status.is_terminal() {
267            return false;
268        }
269        task.status = TaskStatus::Completed;
270        task.progress = Some(100.0);
271        task.message = Some("Task completed".to_string());
272        task.result = Some(result);
273        task.completed_at = Some(Instant::now());
274        true
275    }
276
277    /// Mark a task as failed with an error
278    pub fn fail_task(&self, task_id: &str, error: &str) -> bool {
279        let Ok(mut tasks) = self.tasks.write() else {
280            return false;
281        };
282        let Some(task) = tasks.get_mut(task_id) else {
283            return false;
284        };
285        if task.status.is_terminal() {
286            return false;
287        }
288        task.status = TaskStatus::Failed;
289        task.message = Some(format!("Task failed: {}", error));
290        task.error = Some(error.to_string());
291        task.completed_at = Some(Instant::now());
292        true
293    }
294
295    /// Cancel a task
296    pub fn cancel_task(&self, task_id: &str, reason: Option<&str>) -> Option<TaskStatus> {
297        let mut tasks = self.tasks.write().ok()?;
298        let task = tasks.get_mut(task_id)?;
299
300        // Signal cancellation
301        task.cancellation_token.cancel();
302
303        // If not already terminal, mark as cancelled
304        if !task.status.is_terminal() {
305            task.status = TaskStatus::Cancelled;
306            task.message = Some(
307                reason
308                    .map(|r| format!("Cancelled: {}", r))
309                    .unwrap_or_else(|| "Task cancelled".to_string()),
310            );
311            task.completed_at = Some(Instant::now());
312        }
313        Some(task.status)
314    }
315
316    /// Remove expired tasks (call periodically for cleanup)
317    pub fn cleanup_expired(&self) -> usize {
318        if let Ok(mut tasks) = self.tasks.write() {
319            let before = tasks.len();
320            tasks.retain(|_, t| !t.is_expired());
321            before - tasks.len()
322        } else {
323            0
324        }
325    }
326
327    /// Get the number of tasks in the store
328    #[cfg(test)]
329    pub fn len(&self) -> usize {
330        if let Ok(tasks) = self.tasks.read() {
331            tasks.len()
332        } else {
333            0
334        }
335    }
336
337    /// Check if the store is empty
338    #[cfg(test)]
339    pub fn is_empty(&self) -> bool {
340        self.len() == 0
341    }
342}
343
344/// Generate ISO 8601 timestamp for current time
345fn chrono_now_iso8601() -> String {
346    use std::time::SystemTime;
347
348    let now = SystemTime::now();
349    let duration = now
350        .duration_since(SystemTime::UNIX_EPOCH)
351        .unwrap_or_default();
352
353    let secs = duration.as_secs();
354    let millis = duration.subsec_millis();
355
356    // Simple ISO 8601 format (UTC)
357    // Calculate date/time components
358    let days = secs / 86400;
359    let remaining = secs % 86400;
360    let hours = remaining / 3600;
361    let remaining = remaining % 3600;
362    let minutes = remaining / 60;
363    let seconds = remaining % 60;
364
365    // Calculate year/month/day from days since epoch (1970-01-01)
366    // This is a simplified calculation that handles leap years
367    let mut year = 1970i32;
368    let mut remaining_days = days as i32;
369
370    loop {
371        let days_in_year = if is_leap_year(year) { 366 } else { 365 };
372        if remaining_days < days_in_year {
373            break;
374        }
375        remaining_days -= days_in_year;
376        year += 1;
377    }
378
379    let days_in_months: [i32; 12] = if is_leap_year(year) {
380        [31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
381    } else {
382        [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
383    };
384
385    let mut month = 1;
386    for days_in_month in days_in_months.iter() {
387        if remaining_days < *days_in_month {
388            break;
389        }
390        remaining_days -= days_in_month;
391        month += 1;
392    }
393
394    let day = remaining_days + 1;
395
396    format!(
397        "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}.{:03}Z",
398        year, month, day, hours, minutes, seconds, millis
399    )
400}
401
/// Gregorian leap-year rule: century years are leap only when divisible by
/// 400; every other year is leap when divisible by 4.
fn is_leap_year(year: i32) -> bool {
    if year % 100 == 0 {
        year % 400 == 0
    } else {
        year % 4 == 0
    }
}
405
#[cfg(test)]
mod tests {
    use super::*;

    /// Empty JSON object used as tool arguments where the content is irrelevant.
    fn no_args() -> serde_json::Value {
        serde_json::json!({})
    }

    /// A new task is retrievable by ID, starts as `Working`, and is not cancelled.
    #[test]
    fn test_create_task() {
        let store = TaskStore::new();
        let (task_id, cancel) = store.create_task("test-tool", serde_json::json!({"a": 1}), None);

        assert!(task_id.starts_with("task-"));
        assert!(!cancel.is_cancelled());

        let snapshot = store.get_task(&task_id).expect("task should exist");
        assert_eq!(snapshot.task_id, task_id);
        assert_eq!(snapshot.status, TaskStatus::Working);
    }

    /// Progress updates are visible, and completion forces progress to 100.
    #[test]
    fn test_task_lifecycle() {
        let store = TaskStore::new();
        let (task_id, _cancel) = store.create_task("test-tool", no_args(), None);

        // Update progress
        let updated = store.update_progress(&task_id, 50.0, Some("Halfway".to_string()));
        assert!(updated);

        let midway = store.get_task(&task_id).unwrap();
        assert_eq!(midway.progress, Some(50.0));
        assert_eq!(midway.message.as_deref(), Some("Halfway"));

        // Complete task
        let completed = store.complete_task(&task_id, CallToolResult::text("Done"));
        assert!(completed);

        let finished = store.get_task(&task_id).unwrap();
        assert_eq!(finished.status, TaskStatus::Completed);
        assert_eq!(finished.progress, Some(100.0));
    }

    /// Cancelling through the store flips both the stored status and the token.
    #[test]
    fn test_task_cancellation() {
        let store = TaskStore::new();
        let (task_id, cancel) = store.create_task("test-tool", no_args(), None);

        assert!(!cancel.is_cancelled());

        let outcome = store.cancel_task(&task_id, Some("User requested"));
        assert_eq!(outcome, Some(TaskStatus::Cancelled));
        assert!(cancel.is_cancelled());

        let snapshot = store.get_task(&task_id).unwrap();
        assert_eq!(snapshot.status, TaskStatus::Cancelled);
    }

    /// A failed task reports `Failed` and a message mentioning the failure.
    #[test]
    fn test_task_failure() {
        let store = TaskStore::new();
        let (task_id, _cancel) = store.create_task("test-tool", no_args(), None);

        assert!(store.fail_task(&task_id, "Something went wrong"));

        let snapshot = store.get_task(&task_id).unwrap();
        assert_eq!(snapshot.status, TaskStatus::Failed);
        let message = snapshot.message.as_ref().unwrap();
        assert!(message.contains("failed"));
    }

    /// Listing honours the optional status filter.
    #[test]
    fn test_list_tasks() {
        let store = TaskStore::new();
        let _ = store.create_task("tool1", no_args(), None);
        let _ = store.create_task("tool2", no_args(), None);
        let (third_id, _cancel) = store.create_task("tool3", no_args(), None);

        // Complete one task
        store.complete_task(&third_id, CallToolResult::text("Done"));

        // List all tasks
        assert_eq!(store.list_tasks(None).len(), 3);

        // List only working tasks
        assert_eq!(store.list_tasks(Some(TaskStatus::Working)).len(), 2);

        // List only completed tasks
        assert_eq!(store.list_tasks(Some(TaskStatus::Completed)).len(), 1);
    }

    /// Once a task is terminal, further updates are rejected.
    #[test]
    fn test_terminal_state_immutable() {
        let store = TaskStore::new();
        let (task_id, _cancel) = store.create_task("test-tool", no_args(), None);

        // Complete the task
        store.complete_task(&task_id, CallToolResult::text("Done"));

        // Try to update - should fail
        let progress_accepted = store.update_progress(&task_id, 50.0, None);
        assert!(!progress_accepted);
        let failure_accepted = store.fail_task(&task_id, "Error");
        assert!(!failure_accepted);

        // Status should still be completed
        let snapshot = store.get_task(&task_id).unwrap();
        assert_eq!(snapshot.status, TaskStatus::Completed);
    }

    /// Consecutive tasks never share an ID.
    #[test]
    fn test_task_ids_unique() {
        let store = TaskStore::new();
        let (first, _) = store.create_task("tool", no_args(), None);
        let (second, _) = store.create_task("tool", no_args(), None);
        let (third, _) = store.create_task("tool", no_args(), None);

        assert_ne!(first, second);
        assert_ne!(second, third);
        assert_ne!(first, third);
    }

    /// The full view exposes the stored result and no error after completion.
    #[test]
    fn test_get_task_full() {
        let store = TaskStore::new();
        let (task_id, _cancel) = store.create_task("test-tool", no_args(), None);

        // Complete with result
        store.complete_task(&task_id, CallToolResult::text("The result"));

        let (status, stored_result, stored_error) = store.get_task_full(&task_id).unwrap();
        assert_eq!(status, TaskStatus::Completed);
        assert!(stored_result.is_some());
        assert!(stored_error.is_none());
    }

    /// The timestamp has the fixed-width `YYYY-MM-DDTHH:MM:SS.mmmZ` shape.
    #[test]
    fn test_iso8601_timestamp() {
        let stamp = chrono_now_iso8601();
        // Basic format check
        assert!(stamp.ends_with('Z'));
        assert!(stamp.contains('T'));
        assert_eq!(stamp.len(), 24); // YYYY-MM-DDTHH:MM:SS.mmmZ
    }

    /// Each status renders as its lowercase snake_case name.
    #[test]
    fn test_task_status_display() {
        let cases = [
            (TaskStatus::Working, "working"),
            (TaskStatus::InputRequired, "input_required"),
            (TaskStatus::Completed, "completed"),
            (TaskStatus::Failed, "failed"),
            (TaskStatus::Cancelled, "cancelled"),
        ];
        for &(status, rendered) in cases.iter() {
            assert_eq!(status.to_string(), rendered);
        }
    }

    /// `Working` and `InputRequired` are active; the other statuses are terminal.
    #[test]
    fn test_task_status_is_terminal() {
        let active = [TaskStatus::Working, TaskStatus::InputRequired];
        for &status in active.iter() {
            assert!(!status.is_terminal());
        }

        let terminal = [
            TaskStatus::Completed,
            TaskStatus::Failed,
            TaskStatus::Cancelled,
        ];
        for &status in terminal.iter() {
            assert!(status.is_terminal());
        }
    }
}