Skip to main content

turul_mcp_server/task/
runtime.rs

1//! Task Runtime — bridges task storage with runtime execution state.
2//!
3//! `TaskRuntime` combines durable task storage (which persists across restarts)
4//! with a pluggable `TaskExecutor` that manages how task work is actually executed.
5
6use std::sync::Arc;
7
8use tracing::{debug, info};
9
10use turul_mcp_protocol::TaskStatus;
11use turul_mcp_task_storage::{
12    InMemoryTaskStorage, TaskOutcome, TaskRecord, TaskStorage, TaskStorageError,
13};
14
15use crate::task::executor::TaskExecutor;
16use crate::task::tokio_executor::TokioTaskExecutor;
17
18/// Bridges task storage with runtime execution state.
19///
20/// Owns both:
21/// - A `TaskStorage` backend (durable, serializable)
22/// - A `TaskExecutor` for running task work and managing cancellation
23///
24/// Lives in `turul-mcp-server` (not in `turul-mcp-task-storage`) because it combines
25/// backend-agnostic storage with executor-specific runtime primitives.
26pub struct TaskRuntime {
27    /// The durable storage backend
28    storage: Arc<dyn TaskStorage>,
29    /// The task executor for running work
30    executor: Arc<dyn TaskExecutor>,
31    /// Recovery timeout for stuck tasks (milliseconds)
32    recovery_timeout_ms: u64,
33}
34
35impl TaskRuntime {
36    /// Create a new task runtime with the given storage backend and executor.
37    pub fn new(storage: Arc<dyn TaskStorage>, executor: Arc<dyn TaskExecutor>) -> Self {
38        Self {
39            storage,
40            executor,
41            recovery_timeout_ms: 300_000, // 5 minutes default
42        }
43    }
44
45    /// Create a new task runtime with the given storage and the default `TokioTaskExecutor`.
46    pub fn with_default_executor(storage: Arc<dyn TaskStorage>) -> Self {
47        Self::new(storage, Arc::new(TokioTaskExecutor::new()))
48    }
49
50    /// Create with custom recovery timeout.
51    pub fn with_recovery_timeout(mut self, timeout_ms: u64) -> Self {
52        self.recovery_timeout_ms = timeout_ms;
53        self
54    }
55
56    /// Create a new task runtime with in-memory storage and the default `TokioTaskExecutor`.
57    pub fn in_memory() -> Self {
58        Self::with_default_executor(Arc::new(InMemoryTaskStorage::new()))
59    }
60
61    /// Get a reference to the underlying storage.
62    pub fn storage(&self) -> &dyn TaskStorage {
63        self.storage.as_ref()
64    }
65
66    /// Get a shared reference to the storage Arc.
67    pub fn storage_arc(&self) -> Arc<dyn TaskStorage> {
68        Arc::clone(&self.storage)
69    }
70
71    /// Get a reference to the executor.
72    pub fn executor(&self) -> &dyn TaskExecutor {
73        self.executor.as_ref()
74    }
75
76    // === Task Lifecycle ===
77
78    /// Register a new task in storage. Returns the created record.
79    ///
80    /// Does NOT start execution — call `executor().start_task()` separately
81    /// when the work is ready to run.
82    pub async fn register_task(&self, task: TaskRecord) -> Result<TaskRecord, TaskStorageError> {
83        let task_id = task.task_id.clone();
84
85        // Persist in storage
86        let created = self.storage.create_task(task).await?;
87
88        debug!(task_id = %task_id, "Registered task in storage");
89
90        Ok(created)
91    }
92
93    /// Update a task's status in storage.
94    pub async fn update_status(
95        &self,
96        task_id: &str,
97        new_status: TaskStatus,
98        status_message: Option<String>,
99    ) -> Result<TaskRecord, TaskStorageError> {
100        let updated = self
101            .storage
102            .update_task_status(task_id, new_status, status_message)
103            .await?;
104
105        Ok(updated)
106    }
107
108    /// Store a task's result and update status atomically.
109    pub async fn complete_task(
110        &self,
111        task_id: &str,
112        outcome: TaskOutcome,
113        status: TaskStatus,
114        status_message: Option<String>,
115    ) -> Result<(), TaskStorageError> {
116        // Store result first
117        self.storage.store_task_result(task_id, outcome).await?;
118
119        // Then update status
120        self.update_status(task_id, status, status_message).await?;
121
122        Ok(())
123    }
124
125    /// Cancel a task: delegate to executor AND update storage status.
126    pub async fn cancel_task(&self, task_id: &str) -> Result<TaskRecord, TaskStorageError> {
127        // Try to cancel via executor (ignore error if task not in executor — may have already completed)
128        if let Err(e) = self.executor.cancel_task(task_id).await {
129            debug!(task_id = %task_id, error = %e, "Executor cancel returned error (task may have already completed)");
130        }
131
132        // Update storage status
133        self.update_status(
134            task_id,
135            TaskStatus::Cancelled,
136            Some("Cancelled by client".to_string()),
137        )
138        .await
139    }
140
141    /// Wait until a task reaches terminal status via the executor.
142    ///
143    /// Returns `None` if the task is not tracked by the executor (already completed or not in-flight).
144    pub async fn await_terminal(&self, task_id: &str) -> Option<TaskStatus> {
145        self.executor.await_terminal(task_id).await
146    }
147
148    // === Delegation to storage ===
149
150    /// Get a task by ID from storage.
151    pub async fn get_task(&self, task_id: &str) -> Result<Option<TaskRecord>, TaskStorageError> {
152        self.storage.get_task(task_id).await
153    }
154
155    /// Get a task's stored result.
156    pub async fn get_task_result(
157        &self,
158        task_id: &str,
159    ) -> Result<Option<TaskOutcome>, TaskStorageError> {
160        self.storage.get_task_result(task_id).await
161    }
162
163    /// List tasks with pagination.
164    pub async fn list_tasks(
165        &self,
166        cursor: Option<&str>,
167        limit: Option<u32>,
168    ) -> Result<turul_mcp_task_storage::TaskListPage, TaskStorageError> {
169        self.storage.list_tasks(cursor, limit).await
170    }
171
172    /// List tasks for a specific session.
173    pub async fn list_tasks_for_session(
174        &self,
175        session_id: &str,
176        cursor: Option<&str>,
177        limit: Option<u32>,
178    ) -> Result<turul_mcp_task_storage::TaskListPage, TaskStorageError> {
179        self.storage
180            .list_tasks_for_session(session_id, cursor, limit)
181            .await
182    }
183
184    // === Recovery ===
185
186    /// Recover stuck tasks on startup. Called during server initialization.
187    pub async fn recover_stuck_tasks(&self) -> Result<Vec<String>, TaskStorageError> {
188        let recovered = self
189            .storage
190            .recover_stuck_tasks(self.recovery_timeout_ms)
191            .await?;
192
193        if !recovered.is_empty() {
194            info!(
195                count = recovered.len(),
196                timeout_ms = self.recovery_timeout_ms,
197                "Recovered stuck tasks on startup"
198            );
199        }
200
201        Ok(recovered)
202    }
203
204    /// Run periodic maintenance (TTL expiry, cleanup).
205    pub async fn maintenance(&self) -> Result<(), TaskStorageError> {
206        self.storage.maintenance().await
207    }
208}
209
#[cfg(test)]
mod tests {
    // The parent module already imports InMemoryTaskStorage, TaskOutcome,
    // TaskRecord and TaskStatus; the glob brings them into scope here.
    use super::*;

    /// Build a fresh `Working` task record owned by `session-1`.
    fn create_working_task() -> TaskRecord {
        // One timestamp for both fields so a fresh record is internally consistent.
        let now = chrono::Utc::now().to_rfc3339();
        TaskRecord {
            task_id: InMemoryTaskStorage::generate_task_id(),
            session_id: Some("session-1".to_string()),
            status: TaskStatus::Working,
            status_message: Some("Processing".to_string()),
            created_at: now.clone(),
            last_updated_at: now,
            ttl: Some(60_000),
            poll_interval: Some(5_000),
            original_method: "tools/call".to_string(),
            original_params: None,
            result: None,
            meta: None,
        }
    }

    #[tokio::test]
    async fn test_register_and_get_task() {
        let runtime = TaskRuntime::in_memory();
        let task = create_working_task();
        let task_id = task.task_id.clone();

        let created = runtime.register_task(task).await.unwrap();
        assert_eq!(created.task_id, task_id);
        assert_eq!(created.status, TaskStatus::Working);

        let fetched = runtime.get_task(&task_id).await.unwrap().unwrap();
        assert_eq!(fetched.task_id, task_id);
    }

    #[tokio::test]
    async fn test_get_missing_task_returns_none() {
        let runtime = TaskRuntime::in_memory();
        let missing = runtime.get_task("no-such-task").await.unwrap();
        assert!(missing.is_none());
    }

    #[tokio::test]
    async fn test_update_status() {
        let runtime = TaskRuntime::in_memory();
        let task = create_working_task();
        let task_id = task.task_id.clone();

        runtime.register_task(task).await.unwrap();

        let updated = runtime
            .update_status(&task_id, TaskStatus::Completed, Some("Done".to_string()))
            .await
            .unwrap();
        assert_eq!(updated.status, TaskStatus::Completed);
    }

    #[tokio::test]
    async fn test_complete_task() {
        let runtime = TaskRuntime::in_memory();
        let task = create_working_task();
        let task_id = task.task_id.clone();

        runtime.register_task(task).await.unwrap();

        let outcome = TaskOutcome::Success(serde_json::json!({"answer": 42}));
        runtime
            .complete_task(&task_id, outcome, TaskStatus::Completed, None)
            .await
            .unwrap();

        let result = runtime.get_task_result(&task_id).await.unwrap().unwrap();
        match result {
            TaskOutcome::Success(v) => assert_eq!(v["answer"], 42),
            _ => panic!("Expected Success outcome"),
        }

        // The status must have been persisted too, not just the result.
        let stored = runtime.get_task(&task_id).await.unwrap().unwrap();
        assert_eq!(stored.status, TaskStatus::Completed);
    }

    #[tokio::test]
    async fn test_cancel_task() {
        let runtime = TaskRuntime::in_memory();
        let task = create_working_task();
        let task_id = task.task_id.clone();

        runtime.register_task(task).await.unwrap();

        let cancelled = runtime.cancel_task(&task_id).await.unwrap();
        assert_eq!(cancelled.status, TaskStatus::Cancelled);

        // Cancellation must be visible to subsequent reads from storage.
        let stored = runtime.get_task(&task_id).await.unwrap().unwrap();
        assert_eq!(stored.status, TaskStatus::Cancelled);
    }

    #[tokio::test]
    async fn test_list_tasks() {
        let runtime = TaskRuntime::in_memory();

        runtime.register_task(create_working_task()).await.unwrap();
        runtime.register_task(create_working_task()).await.unwrap();

        let page = runtime.list_tasks(None, None).await.unwrap();
        assert_eq!(page.tasks.len(), 2);
    }

    #[tokio::test]
    async fn test_list_tasks_for_session() {
        let runtime = TaskRuntime::in_memory();

        runtime.register_task(create_working_task()).await.unwrap();
        runtime.register_task(create_working_task()).await.unwrap();

        let mine = runtime
            .list_tasks_for_session("session-1", None, None)
            .await
            .unwrap();
        assert_eq!(mine.tasks.len(), 2);

        let other = runtime
            .list_tasks_for_session("session-2", None, None)
            .await
            .unwrap();
        assert!(other.tasks.is_empty());
    }

    #[test]
    fn test_recovery_timeout_default_and_override() {
        let runtime = TaskRuntime::in_memory();
        assert_eq!(runtime.recovery_timeout_ms, 300_000);

        let runtime = runtime.with_recovery_timeout(1_000);
        assert_eq!(runtime.recovery_timeout_ms, 1_000);
    }
}