1use std::sync::Arc;
7
8use tracing::{debug, info};
9
10use turul_mcp_protocol::TaskStatus;
11use turul_mcp_task_storage::{
12 InMemoryTaskStorage, TaskOutcome, TaskRecord, TaskStorage, TaskStorageError,
13};
14
15use crate::task::executor::TaskExecutor;
16use crate::task::tokio_executor::TokioTaskExecutor;
17
18pub struct TaskRuntime {
27 storage: Arc<dyn TaskStorage>,
29 executor: Arc<dyn TaskExecutor>,
31 recovery_timeout_ms: u64,
33}
34
35impl TaskRuntime {
36 pub fn new(storage: Arc<dyn TaskStorage>, executor: Arc<dyn TaskExecutor>) -> Self {
38 Self {
39 storage,
40 executor,
41 recovery_timeout_ms: 300_000, }
43 }
44
45 pub fn with_default_executor(storage: Arc<dyn TaskStorage>) -> Self {
47 Self::new(storage, Arc::new(TokioTaskExecutor::new()))
48 }
49
50 pub fn with_recovery_timeout(mut self, timeout_ms: u64) -> Self {
52 self.recovery_timeout_ms = timeout_ms;
53 self
54 }
55
56 pub fn in_memory() -> Self {
58 Self::with_default_executor(Arc::new(InMemoryTaskStorage::new()))
59 }
60
61 pub fn storage(&self) -> &dyn TaskStorage {
63 self.storage.as_ref()
64 }
65
66 pub fn storage_arc(&self) -> Arc<dyn TaskStorage> {
68 Arc::clone(&self.storage)
69 }
70
71 pub fn executor(&self) -> &dyn TaskExecutor {
73 self.executor.as_ref()
74 }
75
76 pub async fn register_task(&self, task: TaskRecord) -> Result<TaskRecord, TaskStorageError> {
83 let task_id = task.task_id.clone();
84
85 let created = self.storage.create_task(task).await?;
87
88 debug!(task_id = %task_id, "Registered task in storage");
89
90 Ok(created)
91 }
92
93 pub async fn update_status(
95 &self,
96 task_id: &str,
97 new_status: TaskStatus,
98 status_message: Option<String>,
99 ) -> Result<TaskRecord, TaskStorageError> {
100 let updated = self
101 .storage
102 .update_task_status(task_id, new_status, status_message)
103 .await?;
104
105 Ok(updated)
106 }
107
108 pub async fn complete_task(
110 &self,
111 task_id: &str,
112 outcome: TaskOutcome,
113 status: TaskStatus,
114 status_message: Option<String>,
115 ) -> Result<(), TaskStorageError> {
116 self.storage.store_task_result(task_id, outcome).await?;
118
119 self.update_status(task_id, status, status_message).await?;
121
122 Ok(())
123 }
124
125 pub async fn cancel_task(&self, task_id: &str) -> Result<TaskRecord, TaskStorageError> {
127 if let Err(e) = self.executor.cancel_task(task_id).await {
129 debug!(task_id = %task_id, error = %e, "Executor cancel returned error (task may have already completed)");
130 }
131
132 self.update_status(
134 task_id,
135 TaskStatus::Cancelled,
136 Some("Cancelled by client".to_string()),
137 )
138 .await
139 }
140
141 pub async fn await_terminal(&self, task_id: &str) -> Option<TaskStatus> {
145 self.executor.await_terminal(task_id).await
146 }
147
148 pub async fn get_task(&self, task_id: &str) -> Result<Option<TaskRecord>, TaskStorageError> {
152 self.storage.get_task(task_id).await
153 }
154
155 pub async fn get_task_result(
157 &self,
158 task_id: &str,
159 ) -> Result<Option<TaskOutcome>, TaskStorageError> {
160 self.storage.get_task_result(task_id).await
161 }
162
163 pub async fn list_tasks(
165 &self,
166 cursor: Option<&str>,
167 limit: Option<u32>,
168 ) -> Result<turul_mcp_task_storage::TaskListPage, TaskStorageError> {
169 self.storage.list_tasks(cursor, limit).await
170 }
171
172 pub async fn list_tasks_for_session(
174 &self,
175 session_id: &str,
176 cursor: Option<&str>,
177 limit: Option<u32>,
178 ) -> Result<turul_mcp_task_storage::TaskListPage, TaskStorageError> {
179 self.storage
180 .list_tasks_for_session(session_id, cursor, limit)
181 .await
182 }
183
184 pub async fn recover_stuck_tasks(&self) -> Result<Vec<String>, TaskStorageError> {
188 let recovered = self
189 .storage
190 .recover_stuck_tasks(self.recovery_timeout_ms)
191 .await?;
192
193 if !recovered.is_empty() {
194 info!(
195 count = recovered.len(),
196 timeout_ms = self.recovery_timeout_ms,
197 "Recovered stuck tasks on startup"
198 );
199 }
200
201 Ok(recovered)
202 }
203
204 pub async fn maintenance(&self) -> Result<(), TaskStorageError> {
206 self.storage.maintenance().await
207 }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use turul_mcp_task_storage::{InMemoryTaskStorage, TaskOutcome, TaskRecord};

    /// Builds a task record in the `Working` state with a freshly generated id,
    /// session `"session-1"`, and both timestamps set to the current UTC time.
    fn create_working_task() -> TaskRecord {
        let task_id = InMemoryTaskStorage::generate_task_id();
        let created_at = chrono::Utc::now().to_rfc3339();
        let last_updated_at = chrono::Utc::now().to_rfc3339();

        TaskRecord {
            task_id,
            session_id: Some(String::from("session-1")),
            status: TaskStatus::Working,
            status_message: Some(String::from("Processing")),
            created_at,
            last_updated_at,
            ttl: Some(60_000),
            poll_interval: Some(5_000),
            original_method: String::from("tools/call"),
            original_params: None,
            result: None,
            meta: None,
        }
    }

    /// A registered task is returned by `register_task` and can be fetched back
    /// by id.
    #[tokio::test]
    async fn test_register_and_get_task() {
        let runtime = TaskRuntime::in_memory();
        let record = create_working_task();
        let expected_id = record.task_id.clone();

        let created = runtime.register_task(record).await.unwrap();
        assert_eq!(created.task_id, expected_id);
        assert_eq!(created.status, TaskStatus::Working);

        let fetched = runtime.get_task(&expected_id).await.unwrap().unwrap();
        assert_eq!(fetched.task_id, expected_id);
    }

    /// `update_status` returns a record carrying the new status.
    #[tokio::test]
    async fn test_update_status() {
        let runtime = TaskRuntime::in_memory();
        let record = create_working_task();
        let id = record.task_id.clone();

        runtime.register_task(record).await.unwrap();

        let message = Some(String::from("Done"));
        let updated = runtime
            .update_status(&id, TaskStatus::Completed, message)
            .await
            .unwrap();
        assert_eq!(updated.status, TaskStatus::Completed);
    }

    /// `complete_task` stores an outcome that `get_task_result` reads back.
    #[tokio::test]
    async fn test_complete_task() {
        let runtime = TaskRuntime::in_memory();
        let record = create_working_task();
        let id = record.task_id.clone();

        runtime.register_task(record).await.unwrap();

        let outcome = TaskOutcome::Success(serde_json::json!({"answer": 42}));
        runtime
            .complete_task(&id, outcome, TaskStatus::Completed, None)
            .await
            .unwrap();

        let stored = runtime.get_task_result(&id).await.unwrap().unwrap();
        if let TaskOutcome::Success(value) = stored {
            assert_eq!(value["answer"], 42);
        } else {
            panic!("Expected Success outcome");
        }
    }

    /// `cancel_task` returns a record in the `Cancelled` state.
    #[tokio::test]
    async fn test_cancel_task() {
        let runtime = TaskRuntime::in_memory();
        let record = create_working_task();
        let id = record.task_id.clone();

        runtime.register_task(record).await.unwrap();

        let cancelled = runtime.cancel_task(&id).await.unwrap();
        assert_eq!(cancelled.status, TaskStatus::Cancelled);
    }

    /// Two registered tasks both show up in an unpaginated listing.
    #[tokio::test]
    async fn test_list_tasks() {
        let runtime = TaskRuntime::in_memory();

        // Build both records up front, then register them in order.
        let records = vec![create_working_task(), create_working_task()];
        for record in records {
            runtime.register_task(record).await.unwrap();
        }

        let page = runtime.list_tasks(None, None).await.unwrap();
        assert_eq!(page.tasks.len(), 2);
    }
}