turul_mcp_server/task/
tokio_executor.rs1use std::collections::HashMap;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use tokio::sync::{RwLock, watch};
8use tracing::debug;
9
10use turul_mcp_protocol::TaskStatus;
11use turul_mcp_task_storage::{TaskOutcome, TaskStorageError};
12
13use crate::cancellation::CancellationHandle;
14use crate::task::executor::{BoxedTaskWork, TaskExecutor, TaskHandle};
15
16struct TokioTaskEntry {
17 cancellation: CancellationHandle,
18 status_tx: watch::Sender<TaskStatus>,
19}
20
21pub struct TokioTaskExecutor {
23 entries: Arc<RwLock<HashMap<String, TokioTaskEntry>>>,
24}
25
26impl TokioTaskExecutor {
27 pub fn new() -> Self {
28 Self {
29 entries: Arc::new(RwLock::new(HashMap::new())),
30 }
31 }
32}
33
34impl Default for TokioTaskExecutor {
35 fn default() -> Self {
36 Self::new()
37 }
38}
39
40struct TokioTaskHandle {
41 cancellation: CancellationHandle,
42}
43
44impl TaskHandle for TokioTaskHandle {
45 fn cancel(&self) {
46 self.cancellation.cancel();
47 }
48 fn is_cancelled(&self) -> bool {
49 self.cancellation.is_cancelled()
50 }
51}
52
53#[async_trait]
54impl TaskExecutor for TokioTaskExecutor {
55 async fn start_task(
56 &self,
57 task_id: &str,
58 work: BoxedTaskWork,
59 ) -> Result<Box<dyn TaskHandle>, TaskStorageError> {
60 let cancellation = CancellationHandle::new();
61 let (status_tx, _) = watch::channel(TaskStatus::Working);
62
63 let entry = TokioTaskEntry {
64 cancellation: cancellation.clone(),
65 status_tx: status_tx.clone(),
66 };
67 self.entries
68 .write()
69 .await
70 .insert(task_id.to_string(), entry);
71
72 let cancel_clone = cancellation.clone();
73 let task_id_owned = task_id.to_string();
74 let entries = Arc::clone(&self.entries);
75
76 tokio::spawn(async move {
77 let outcome = tokio::select! {
78 result = (work)() => result,
79 _ = cancel_clone.cancelled() => {
80 TaskOutcome::Error {
81 code: -32800,
82 message: "Task cancelled".to_string(),
83 data: None,
84 }
85 }
86 };
87
88 let terminal_status = match &outcome {
89 TaskOutcome::Success(_) => TaskStatus::Completed,
90 TaskOutcome::Error { .. } => TaskStatus::Failed,
91 };
92 if let Some(entry) = entries.read().await.get(&task_id_owned) {
93 let _ = entry.status_tx.send(terminal_status);
94 }
95 tokio::task::yield_now().await;
97 entries.write().await.remove(&task_id_owned);
98
99 debug!(task_id = %task_id_owned, status = ?terminal_status, "Task execution completed");
100 });
101
102 Ok(Box::new(TokioTaskHandle { cancellation }))
103 }
104
105 async fn cancel_task(&self, task_id: &str) -> Result<(), TaskStorageError> {
106 if let Some(entry) = self.entries.read().await.get(task_id) {
107 entry.cancellation.cancel();
108 Ok(())
109 } else {
110 Err(TaskStorageError::TaskNotFound(task_id.to_string()))
111 }
112 }
113
114 async fn await_terminal(&self, task_id: &str) -> Option<TaskStatus> {
115 let mut rx = {
116 let entries = self.entries.read().await;
117 entries.get(task_id)?.status_tx.subscribe()
118 };
119 loop {
120 if rx.changed().await.is_err() {
121 return None;
123 }
124 let status = *rx.borrow();
125 if turul_mcp_task_storage::is_terminal(status) {
126 return Some(status);
127 }
128 }
129 }
130}
131
#[cfg(test)]
mod tests {
    use super::*;

    /// A task whose work succeeds right away is reported as `Completed` to a
    /// waiter, and its handle is not marked as cancelled.
    #[tokio::test]
    async fn test_start_and_complete_task() {
        let exec = TokioTaskExecutor::new();
        let task_handle = exec
            .start_task(
                "task-1",
                Box::new(|| {
                    Box::pin(async { TaskOutcome::Success(serde_json::json!({"result": 42})) })
                }),
            )
            .await
            .unwrap();

        let terminal = exec.await_terminal("task-1").await;

        assert!(matches!(terminal, Some(TaskStatus::Completed)));
        assert!(!task_handle.is_cancelled());
    }

    /// Cancelling through the executor is visible on the handle that
    /// `start_task` returned.
    #[tokio::test]
    async fn test_cancel_task() {
        let exec = TokioTaskExecutor::new();
        let task_handle = exec
            .start_task(
                "task-2",
                Box::new(|| {
                    Box::pin(async {
                        // Long enough that the work cannot finish on its own
                        // before the cancellation below.
                        tokio::time::sleep(std::time::Duration::from_secs(60)).await;
                        TaskOutcome::Success(serde_json::json!({}))
                    })
                }),
            )
            .await
            .unwrap();

        exec.cancel_task("task-2").await.unwrap();

        assert!(task_handle.is_cancelled());
    }

    /// Cancelling an id that was never started yields an error.
    #[tokio::test]
    async fn test_cancel_nonexistent_task() {
        let exec = TokioTaskExecutor::new();
        let cancel_result = exec.cancel_task("nonexistent").await;
        assert!(cancel_result.is_err());
    }

    /// Waiting on an id that was never started yields `None`.
    #[tokio::test]
    async fn test_await_terminal_nonexistent() {
        let exec = TokioTaskExecutor::new();
        let terminal = exec.await_terminal("nonexistent").await;
        assert!(terminal.is_none());
    }
}