Skip to main content

turul_mcp_server/task/
tokio_executor.rs

1//! Tokio-based task executor — default in-process execution using tokio::spawn.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use tokio::sync::{RwLock, watch};
8use tracing::debug;
9
10use turul_mcp_protocol::TaskStatus;
11use turul_mcp_task_storage::{TaskOutcome, TaskStorageError};
12
13use crate::cancellation::CancellationHandle;
14use crate::task::executor::{BoxedTaskWork, TaskExecutor, TaskHandle};
15
16struct TokioTaskEntry {
17    cancellation: CancellationHandle,
18    status_tx: watch::Sender<TaskStatus>,
19}
20
21/// In-process task executor using Tokio runtime.
22pub struct TokioTaskExecutor {
23    entries: Arc<RwLock<HashMap<String, TokioTaskEntry>>>,
24}
25
26impl TokioTaskExecutor {
27    pub fn new() -> Self {
28        Self {
29            entries: Arc::new(RwLock::new(HashMap::new())),
30        }
31    }
32}
33
34impl Default for TokioTaskExecutor {
35    fn default() -> Self {
36        Self::new()
37    }
38}
39
40struct TokioTaskHandle {
41    cancellation: CancellationHandle,
42}
43
44impl TaskHandle for TokioTaskHandle {
45    fn cancel(&self) {
46        self.cancellation.cancel();
47    }
48    fn is_cancelled(&self) -> bool {
49        self.cancellation.is_cancelled()
50    }
51}
52
53#[async_trait]
54impl TaskExecutor for TokioTaskExecutor {
55    async fn start_task(
56        &self,
57        task_id: &str,
58        work: BoxedTaskWork,
59    ) -> Result<Box<dyn TaskHandle>, TaskStorageError> {
60        let cancellation = CancellationHandle::new();
61        let (status_tx, _) = watch::channel(TaskStatus::Working);
62
63        let entry = TokioTaskEntry {
64            cancellation: cancellation.clone(),
65            status_tx: status_tx.clone(),
66        };
67        self.entries
68            .write()
69            .await
70            .insert(task_id.to_string(), entry);
71
72        let cancel_clone = cancellation.clone();
73        let task_id_owned = task_id.to_string();
74        let entries = Arc::clone(&self.entries);
75
76        tokio::spawn(async move {
77            let outcome = tokio::select! {
78                result = (work)() => result,
79                _ = cancel_clone.cancelled() => {
80                    TaskOutcome::Error {
81                        code: -32800,
82                        message: "Task cancelled".to_string(),
83                        data: None,
84                    }
85                }
86            };
87
88            let terminal_status = match &outcome {
89                TaskOutcome::Success(_) => TaskStatus::Completed,
90                TaskOutcome::Error { .. } => TaskStatus::Failed,
91            };
92            if let Some(entry) = entries.read().await.get(&task_id_owned) {
93                let _ = entry.status_tx.send(terminal_status);
94            }
95            // Small delay to let watchers receive the notification before cleanup
96            tokio::task::yield_now().await;
97            entries.write().await.remove(&task_id_owned);
98
99            debug!(task_id = %task_id_owned, status = ?terminal_status, "Task execution completed");
100        });
101
102        Ok(Box::new(TokioTaskHandle { cancellation }))
103    }
104
105    async fn cancel_task(&self, task_id: &str) -> Result<(), TaskStorageError> {
106        if let Some(entry) = self.entries.read().await.get(task_id) {
107            entry.cancellation.cancel();
108            Ok(())
109        } else {
110            Err(TaskStorageError::TaskNotFound(task_id.to_string()))
111        }
112    }
113
114    async fn await_terminal(&self, task_id: &str) -> Option<TaskStatus> {
115        let mut rx = {
116            let entries = self.entries.read().await;
117            entries.get(task_id)?.status_tx.subscribe()
118        };
119        loop {
120            if rx.changed().await.is_err() {
121                // Sender dropped — task entry was cleaned up, meaning it completed
122                return None;
123            }
124            let status = *rx.borrow();
125            if turul_mcp_task_storage::is_terminal(status) {
126                return Some(status);
127            }
128        }
129    }
130}
131
#[cfg(test)]
mod tests {
    use super::*;

    /// A successful work future drives the task to `Completed`.
    #[tokio::test]
    async fn test_start_and_complete_task() {
        let executor = TokioTaskExecutor::new();
        let handle = executor
            .start_task(
                "task-1",
                Box::new(|| {
                    Box::pin(async { TaskOutcome::Success(serde_json::json!({"result": 42})) })
                }),
            )
            .await
            .unwrap();

        // Task should complete
        let status = executor.await_terminal("task-1").await;
        assert!(matches!(status, Some(TaskStatus::Completed)));
        assert!(!handle.is_cancelled());
    }

    /// Cancelling a long-running task flips the handle and drives the task to a
    /// terminal (`Failed`) state instead of leaving it `Working`.
    #[tokio::test]
    async fn test_cancel_task() {
        let executor = TokioTaskExecutor::new();
        let handle = executor
            .start_task(
                "task-2",
                Box::new(|| {
                    Box::pin(async {
                        // Simulate long-running work
                        tokio::time::sleep(std::time::Duration::from_secs(60)).await;
                        TaskOutcome::Success(serde_json::json!({}))
                    })
                }),
            )
            .await
            .unwrap();

        // Cancel it
        executor.cancel_task("task-2").await.unwrap();
        assert!(handle.is_cancelled());

        // The cancelled work must not keep the task alive.
        let status = executor.await_terminal("task-2").await;
        assert!(matches!(status, Some(TaskStatus::Failed)));
    }

    /// Cancelling an unknown id reports an error.
    #[tokio::test]
    async fn test_cancel_nonexistent_task() {
        let executor = TokioTaskExecutor::new();
        let result = executor.cancel_task("nonexistent").await;
        assert!(result.is_err());
    }

    /// Awaiting an unknown id returns immediately with `None`.
    #[tokio::test]
    async fn test_await_terminal_nonexistent() {
        let executor = TokioTaskExecutor::new();
        let result = executor.await_terminal("nonexistent").await;
        assert!(result.is_none());
    }
}