Skip to main content

tower_mcp/
async_task.rs

1//! Async task management for long-running MCP operations
2//!
3//! This module provides task lifecycle management for operations that may take
4//! longer than a typical request/response cycle. Tasks can be created via
5//! task-augmented `tools/call` requests, tracked, polled for status, and cancelled.
6//!
7//! # Example
8//!
//! ```rust,ignore
//! use tower_mcp::async_task::TaskStore;
//! use tower_mcp::protocol::CallToolResult;
//!
//! // Create a task store
//! let store = TaskStore::new();
//!
//! // Create a task; the returned token lets the worker observe cancellation
//! let (task_id, token) = store.create_task("my-tool", serde_json::json!({"key": "value"}), None);
//!
//! // Get task status
//! let info = store.get_task(&task_id);
//!
//! // Mark task as complete
//! store.complete_task(&task_id, CallToolResult::text("done"));
//! ```
24//! ```
25
26use std::collections::HashMap;
27use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
28use std::sync::{Arc, RwLock};
29use std::time::{Duration, Instant};
30
31use crate::protocol::{CallToolResult, TaskObject, TaskStatus};
32
/// How long a task is retained after reaching a terminal state before it
/// becomes eligible for cleanup: 5 minutes, expressed in milliseconds.
const DEFAULT_TTL_MS: u64 = 5 * 60 * 1_000;

/// Poll interval suggested to clients in task snapshots: 2 seconds,
/// expressed in milliseconds.
const DEFAULT_POLL_INTERVAL_MS: u64 = 2 * 1_000;
38
/// Internal task representation with full state.
///
/// Tasks are created by [`TaskStore::create_task`] and kept in the store's map.
/// The store's query methods hand out [`TaskObject`] snapshots built by
/// [`Task::to_task_object`] rather than the `Task` itself.
#[derive(Debug)]
pub struct Task {
    /// Unique task identifier (the store generates IDs of the form `task-<n>`)
    pub id: String,
    /// Name of the tool being executed
    pub tool_name: String,
    /// Arguments passed to the tool
    pub arguments: serde_json::Value,
    /// Current task status; a new task starts as `Working`
    pub status: TaskStatus,
    /// Monotonic instant at which the task was created
    pub created_at: Instant,
    /// Creation time as an ISO 8601 UTC string (reported as `created_at` in snapshots)
    pub created_at_str: String,
    /// ISO 8601 UTC timestamp of the last state change
    pub last_updated_at_str: String,
    /// Time-to-live in milliseconds, counted from `completed_at` (see `is_expired`)
    pub ttl: u64,
    /// Suggested polling interval in milliseconds
    pub poll_interval: u64,
    /// Human-readable status message
    pub status_message: Option<String>,
    /// The result of the tool call (set when completed)
    pub result: Option<CallToolResult>,
    /// Error message (set when failed)
    pub error: Option<String>,
    /// Cancellation token shared with the code executing the task
    pub cancellation_token: CancellationToken,
    /// When the task reached a terminal status; `None` while still running.
    /// TTL expiry is measured from this instant.
    pub completed_at: Option<Instant>,
    /// Notified (via `notify_waiters`) when the task reaches a terminal state
    pub completion_notify: Arc<tokio::sync::Notify>,
}
73
74impl Task {
75    /// Create a new task
76    fn new(id: String, tool_name: String, arguments: serde_json::Value, ttl: Option<u64>) -> Self {
77        let cancelled = Arc::new(AtomicBool::new(false));
78        let now_str = chrono_now_iso8601();
79        Self {
80            id,
81            tool_name,
82            arguments,
83            status: TaskStatus::Working,
84            created_at: Instant::now(),
85            created_at_str: now_str.clone(),
86            last_updated_at_str: now_str,
87            ttl: ttl.unwrap_or(DEFAULT_TTL_MS),
88            poll_interval: DEFAULT_POLL_INTERVAL_MS,
89            status_message: Some("Task started".to_string()),
90            result: None,
91            error: None,
92            cancellation_token: CancellationToken { cancelled },
93            completed_at: None,
94            completion_notify: Arc::new(tokio::sync::Notify::new()),
95        }
96    }
97
98    /// Convert to TaskObject for API responses
99    pub fn to_task_object(&self) -> TaskObject {
100        TaskObject {
101            task_id: self.id.clone(),
102            status: self.status,
103            status_message: self.status_message.clone(),
104            created_at: self.created_at_str.clone(),
105            last_updated_at: self.last_updated_at_str.clone(),
106            ttl: Some(self.ttl),
107            poll_interval: Some(self.poll_interval),
108            meta: None,
109        }
110    }
111
112    /// Check if this task should be cleaned up (TTL expired)
113    pub fn is_expired(&self) -> bool {
114        if let Some(completed_at) = self.completed_at {
115            completed_at.elapsed() > Duration::from_millis(self.ttl)
116        } else {
117            false
118        }
119    }
120
121    /// Check if the task has been cancelled
122    pub fn is_cancelled(&self) -> bool {
123        self.cancellation_token.is_cancelled()
124    }
125}
126
/// A shareable, cooperative cancellation flag for task management.
///
/// Every clone refers to the same underlying flag, so a cancellation requested
/// through any handle is visible through all of them. Nothing ever resets the
/// flag: once cancelled, a token stays cancelled.
#[derive(Debug, Clone)]
pub struct CancellationToken {
    cancelled: Arc<AtomicBool>,
}

impl CancellationToken {
    /// Returns `true` once `cancel` has been called on this token or any clone.
    pub fn is_cancelled(&self) -> bool {
        // Relaxed: the flag is a standalone signal; callers should not rely on
        // it to publish other data.
        let flag: &AtomicBool = &self.cancelled;
        flag.load(Ordering::Relaxed)
    }

    /// Request cancellation. Calling this more than once is harmless.
    pub fn cancel(&self) {
        let flag: &AtomicBool = &self.cancelled;
        flag.store(true, Ordering::Relaxed);
    }
}
144
/// Thread-safe, cloneable task storage.
///
/// Both fields sit behind `Arc`, so every clone of a store shares the same
/// task map and the same ID counter.
#[derive(Debug, Clone)]
pub struct TaskStore {
    /// All known tasks keyed by task ID, guarded by a reader-writer lock
    tasks: Arc<RwLock<HashMap<String, Task>>>,
    /// Counter used to mint `task-<n>` identifiers; starts at 1
    next_id: Arc<AtomicU64>,
}
151
152impl Default for TaskStore {
153    fn default() -> Self {
154        Self::new()
155    }
156}
157
158impl TaskStore {
159    /// Create a new task store
160    pub fn new() -> Self {
161        Self {
162            tasks: Arc::new(RwLock::new(HashMap::new())),
163            next_id: Arc::new(AtomicU64::new(1)),
164        }
165    }
166
167    /// Generate a unique task ID
168    fn generate_id(&self) -> String {
169        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
170        format!("task-{}", id)
171    }
172
173    /// Create and store a new task
174    ///
175    /// Returns the task ID and a cancellation token for the spawned work.
176    pub fn create_task(
177        &self,
178        tool_name: &str,
179        arguments: serde_json::Value,
180        ttl: Option<u64>,
181    ) -> (String, CancellationToken) {
182        let id = self.generate_id();
183        let task = Task::new(id.clone(), tool_name.to_string(), arguments, ttl);
184        let token = task.cancellation_token.clone();
185
186        if let Ok(mut tasks) = self.tasks.write() {
187            tasks.insert(id.clone(), task);
188        }
189
190        (id, token)
191    }
192
193    /// Get task object by ID
194    pub fn get_task(&self, task_id: &str) -> Option<TaskObject> {
195        if let Ok(tasks) = self.tasks.read() {
196            tasks.get(task_id).map(|t| t.to_task_object())
197        } else {
198            None
199        }
200    }
201
202    /// Get full task result by ID (task object, result, error)
203    pub fn get_task_result(
204        &self,
205        task_id: &str,
206    ) -> Option<(TaskObject, Option<CallToolResult>, Option<String>)> {
207        if let Ok(tasks) = self.tasks.read() {
208            tasks
209                .get(task_id)
210                .map(|t| (t.to_task_object(), t.result.clone(), t.error.clone()))
211        } else {
212            None
213        }
214    }
215
216    /// Wait for a task to reach a terminal state, then return its result.
217    ///
218    /// If the task is already terminal, returns immediately. Otherwise blocks
219    /// until the task completes, fails, or is cancelled.
220    pub async fn wait_for_completion(
221        &self,
222        task_id: &str,
223    ) -> Option<(TaskObject, Option<CallToolResult>, Option<String>)> {
224        // First check if already terminal and get the notify handle
225        let notify = {
226            let tasks = self.tasks.read().ok()?;
227            let task = tasks.get(task_id)?;
228            if task.status.is_terminal() {
229                return Some((
230                    task.to_task_object(),
231                    task.result.clone(),
232                    task.error.clone(),
233                ));
234            }
235            task.completion_notify.clone()
236        };
237
238        // Wait for completion notification
239        notify.notified().await;
240
241        // Read the result
242        self.get_task_result(task_id)
243    }
244
245    /// List all tasks, optionally filtered by status
246    pub fn list_tasks(&self, status_filter: Option<TaskStatus>) -> Vec<TaskObject> {
247        if let Ok(tasks) = self.tasks.read() {
248            tasks
249                .values()
250                .filter(|t| status_filter.is_none() || status_filter == Some(t.status))
251                .map(|t| t.to_task_object())
252                .collect()
253        } else {
254            vec![]
255        }
256    }
257
258    /// Mark a task as requiring input
259    pub fn require_input(&self, task_id: &str, message: &str) -> bool {
260        let Ok(mut tasks) = self.tasks.write() else {
261            return false;
262        };
263        let Some(task) = tasks.get_mut(task_id) else {
264            return false;
265        };
266        if task.status.is_terminal() {
267            return false;
268        }
269        task.status = TaskStatus::InputRequired;
270        task.status_message = Some(message.to_string());
271        task.last_updated_at_str = chrono_now_iso8601();
272        true
273    }
274
275    /// Mark a task as completed with a result
276    pub fn complete_task(&self, task_id: &str, result: CallToolResult) -> bool {
277        let Ok(mut tasks) = self.tasks.write() else {
278            return false;
279        };
280        let Some(task) = tasks.get_mut(task_id) else {
281            return false;
282        };
283        if task.status.is_terminal() {
284            return false;
285        }
286        task.status = TaskStatus::Completed;
287        task.status_message = Some("Task completed".to_string());
288        task.result = Some(result);
289        task.completed_at = Some(Instant::now());
290        task.last_updated_at_str = chrono_now_iso8601();
291        task.completion_notify.notify_waiters();
292        true
293    }
294
295    /// Mark a task as failed with an error
296    pub fn fail_task(&self, task_id: &str, error: &str) -> bool {
297        let Ok(mut tasks) = self.tasks.write() else {
298            return false;
299        };
300        let Some(task) = tasks.get_mut(task_id) else {
301            return false;
302        };
303        if task.status.is_terminal() {
304            return false;
305        }
306        task.status = TaskStatus::Failed;
307        task.status_message = Some(format!("Task failed: {}", error));
308        task.error = Some(error.to_string());
309        task.completed_at = Some(Instant::now());
310        task.last_updated_at_str = chrono_now_iso8601();
311        task.completion_notify.notify_waiters();
312        true
313    }
314
315    /// Cancel a task
316    pub fn cancel_task(&self, task_id: &str, reason: Option<&str>) -> Option<TaskObject> {
317        let mut tasks = self.tasks.write().ok()?;
318        let task = tasks.get_mut(task_id)?;
319
320        // Signal cancellation
321        task.cancellation_token.cancel();
322
323        // If not already terminal, mark as cancelled
324        if !task.status.is_terminal() {
325            task.status = TaskStatus::Cancelled;
326            task.status_message = Some(
327                reason
328                    .map(|r| format!("Cancelled: {}", r))
329                    .unwrap_or_else(|| "Task cancelled".to_string()),
330            );
331            task.completed_at = Some(Instant::now());
332            task.last_updated_at_str = chrono_now_iso8601();
333            task.completion_notify.notify_waiters();
334        }
335        Some(task.to_task_object())
336    }
337
338    /// Remove expired tasks (call periodically for cleanup)
339    pub fn cleanup_expired(&self) -> usize {
340        if let Ok(mut tasks) = self.tasks.write() {
341            let before = tasks.len();
342            tasks.retain(|_, t| !t.is_expired());
343            before - tasks.len()
344        } else {
345            0
346        }
347    }
348
349    /// Get the number of tasks in the store
350    #[cfg(test)]
351    pub fn len(&self) -> usize {
352        if let Ok(tasks) = self.tasks.read() {
353            tasks.len()
354        } else {
355            0
356        }
357    }
358
359    /// Check if the store is empty
360    #[cfg(test)]
361    pub fn is_empty(&self) -> bool {
362        self.len() == 0
363    }
364}
365
366/// Generate ISO 8601 timestamp for current time
367fn chrono_now_iso8601() -> String {
368    use std::time::SystemTime;
369
370    let now = SystemTime::now();
371    let duration = now
372        .duration_since(SystemTime::UNIX_EPOCH)
373        .unwrap_or_default();
374
375    let secs = duration.as_secs();
376    let millis = duration.subsec_millis();
377
378    // Simple ISO 8601 format (UTC)
379    // Calculate date/time components
380    let days = secs / 86400;
381    let remaining = secs % 86400;
382    let hours = remaining / 3600;
383    let remaining = remaining % 3600;
384    let minutes = remaining / 60;
385    let seconds = remaining % 60;
386
387    // Calculate year/month/day from days since epoch (1970-01-01)
388    // This is a simplified calculation that handles leap years
389    let mut year = 1970i32;
390    let mut remaining_days = days as i32;
391
392    loop {
393        let days_in_year = if is_leap_year(year) { 366 } else { 365 };
394        if remaining_days < days_in_year {
395            break;
396        }
397        remaining_days -= days_in_year;
398        year += 1;
399    }
400
401    let days_in_months: [i32; 12] = if is_leap_year(year) {
402        [31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
403    } else {
404        [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
405    };
406
407    let mut month = 1;
408    for days_in_month in days_in_months.iter() {
409        if remaining_days < *days_in_month {
410            break;
411        }
412        remaining_days -= days_in_month;
413        month += 1;
414    }
415
416    let day = remaining_days + 1;
417
418    format!(
419        "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}.{:03}Z",
420        year, month, day, hours, minutes, seconds, millis
421    )
422}
423
/// Gregorian leap-year rule: years divisible by 4 are leap years, except
/// century years, which are leap years only when divisible by 400.
fn is_leap_year(year: i32) -> bool {
    match (year % 4, year % 100, year % 400) {
        (_, _, 0) => true,
        (_, 0, _) => false,
        (0, _, _) => true,
        _ => false,
    }
}
427
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a store holding one freshly created task with empty arguments.
    fn store_with_task(tool: &str) -> (TaskStore, String, CancellationToken) {
        let store = TaskStore::new();
        let (id, token) = store.create_task(tool, serde_json::json!({}), None);
        (store, id, token)
    }

    #[test]
    fn test_create_task() {
        let store = TaskStore::new();
        let (id, token) = store.create_task("test-tool", serde_json::json!({"a": 1}), None);

        assert!(id.starts_with("task-"));
        assert!(!token.is_cancelled());

        let info = store.get_task(&id).expect("task should exist");
        assert_eq!(info.task_id, id);
        assert_eq!(info.status, TaskStatus::Working);
    }

    #[test]
    fn test_task_lifecycle() {
        let (store, id, _) = store_with_task("test-tool");

        assert!(store.complete_task(&id, CallToolResult::text("Done")));
        assert_eq!(store.get_task(&id).unwrap().status, TaskStatus::Completed);
    }

    #[test]
    fn test_task_cancellation() {
        let (store, id, token) = store_with_task("test-tool");
        assert!(!token.is_cancelled());

        let status = store
            .cancel_task(&id, Some("User requested"))
            .map(|task_obj| task_obj.status);
        assert_eq!(status, Some(TaskStatus::Cancelled));
        assert!(token.is_cancelled());

        assert_eq!(store.get_task(&id).unwrap().status, TaskStatus::Cancelled);
    }

    #[test]
    fn test_task_failure() {
        let (store, id, _) = store_with_task("test-tool");

        assert!(store.fail_task(&id, "Something went wrong"));

        let info = store.get_task(&id).unwrap();
        assert_eq!(info.status, TaskStatus::Failed);
        let message = info.status_message.as_deref().unwrap();
        assert!(message.contains("failed"));
    }

    #[test]
    fn test_list_tasks() {
        let store = TaskStore::new();
        let ids: Vec<String> = ["tool1", "tool2", "tool3"]
            .iter()
            .map(|&tool| store.create_task(tool, serde_json::json!({}), None).0)
            .collect();

        // Finish only the last one.
        store.complete_task(&ids[2], CallToolResult::text("Done"));

        assert_eq!(store.list_tasks(None).len(), 3);
        assert_eq!(store.list_tasks(Some(TaskStatus::Working)).len(), 2);
        assert_eq!(store.list_tasks(Some(TaskStatus::Completed)).len(), 1);
    }

    #[test]
    fn test_terminal_state_immutable() {
        let (store, id, _) = store_with_task("test-tool");
        store.complete_task(&id, CallToolResult::text("Done"));

        // A terminal task rejects further transitions and keeps its status.
        assert!(!store.fail_task(&id, "Error"));
        assert_eq!(store.get_task(&id).unwrap().status, TaskStatus::Completed);
    }

    #[test]
    fn test_task_ids_unique() {
        let store = TaskStore::new();
        let ids: std::collections::HashSet<String> = (0..3)
            .map(|_| store.create_task("tool", serde_json::json!({}), None).0)
            .collect();

        assert_eq!(ids.len(), 3);
    }

    #[test]
    fn test_get_task_result() {
        let (store, id, _) = store_with_task("test-tool");
        store.complete_task(&id, CallToolResult::text("The result"));

        let (task_obj, result, error) = store.get_task_result(&id).unwrap();
        assert_eq!(task_obj.status, TaskStatus::Completed);
        assert!(result.is_some());
        assert!(error.is_none());
    }

    #[test]
    fn test_iso8601_timestamp() {
        let ts = chrono_now_iso8601();
        // Shape: YYYY-MM-DDTHH:MM:SS.mmmZ
        assert_eq!(ts.len(), 24);
        assert!(ts.contains('T'));
        assert!(ts.ends_with('Z'));
    }

    #[test]
    fn test_task_status_display() {
        let cases = [
            (TaskStatus::Working, "working"),
            (TaskStatus::InputRequired, "input_required"),
            (TaskStatus::Completed, "completed"),
            (TaskStatus::Failed, "failed"),
            (TaskStatus::Cancelled, "cancelled"),
        ];
        for (status, text) in cases {
            assert_eq!(status.to_string(), text);
        }
    }

    #[test]
    fn test_task_status_is_terminal() {
        let cases = [
            (TaskStatus::Working, false),
            (TaskStatus::InputRequired, false),
            (TaskStatus::Completed, true),
            (TaskStatus::Failed, true),
            (TaskStatus::Cancelled, true),
        ];
        for (status, terminal) in cases {
            assert_eq!(status.is_terminal(), terminal);
        }
    }
}