Skip to main content

rust_mcp_sdk/
task_store.rs

1mod in_memory_task_store;
2use async_trait::async_trait;
3use futures::Stream;
4pub use in_memory_task_store::*;
5use rust_mcp_schema::{
6    schema_utils::{
7        ClientJsonrpcRequest, ResultFromClient, ResultFromServer, ServerJsonrpcRequest,
8    },
9    ListTasksResult, RequestId, Task, TaskStatus, TaskStatusNotificationParams,
10};
11use std::{fmt::Debug, pin::Pin, sync::Arc};
12
13use crate::error::SdkResult;
14
15/// A stream of task status notifications, where each item contains the notification parameters
16/// and an optional session_id
17pub type TaskStatusStream =
18    Pin<Box<dyn Stream<Item = (TaskStatusNotificationParams, Option<String>)> + Send + 'static>>;
19
20#[async_trait]
21pub trait TaskStatusSignal: Send + Sync + 'static {
22    /// Publish a status change event
23    async fn publish_status_change(
24        &self,
25        event: TaskStatusNotificationParams,
26        session_id: Option<&String>,
27    );
28    /// Return a new independent stream of events
29    fn subscribe(&self) -> Option<TaskStatusStream> {
30        None
31    }
32}
33
34pub type TaskStatusCallback = Box<dyn Fn(&Task, Option<&String>) + Send + Sync + 'static>;
35
36pub struct CreateTaskOptions {
37    ///Actual retention duration from creation in milliseconds, None for unlimited.
38    pub ttl: Option<i64>,
39    pub poll_interval: ::std::option::Option<i64>,
40    ///Additional context to pass to the task store.
41    pub meta: Option<serde_json::Map<String, serde_json::Value>>,
42    // pub context: Option<HashMap<String, Box<dyn Any + Send>>>,
43}
44
45pub struct TaskCreator<Req, Res>
46where
47    Req: Debug + Clone + serde::Deserialize<'static> + serde::Serialize,
48    Res: Debug + Clone + serde::Deserialize<'static> + serde::Serialize,
49{
50    pub request_id: RequestId,
51    pub request: Req,
52    pub session_id: Option<String>,
53    pub task_store: Arc<dyn TaskStore<Req, Res>>,
54}
55
56impl<Req, Res> TaskCreator<Req, Res>
57where
58    Req: Debug + Clone + serde::Deserialize<'static> + serde::Serialize + 'static,
59    Res: Debug + Clone + serde::Deserialize<'static> + serde::Serialize + 'static,
60{
61    pub async fn create_task(self, task_params: CreateTaskOptions) -> Task {
62        self.task_store
63            .create_task(task_params, self.request_id, self.request, self.session_id)
64            .await
65    }
66}
67
68/// A trait for storing and managing long-running tasks, storing and retrieving task state and results.
69/// Tasks were introduced in MCP Protocol version 2025-11-25.
70/// For more details, see: <https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks>
71#[async_trait]
72pub trait TaskStore<Req, Res>: Send + Sync + TaskStatusSignal
73where
74    Req: Debug + Clone + serde::Deserialize<'static> + serde::Serialize,
75    Res: Debug + Clone + serde::Deserialize<'static> + serde::Serialize,
76{
77    /// Creates a new task with the given creation parameters and original request.
78    /// The implementation must generate a unique taskId and createdAt timestamp.
79    ///
80    /// TTL Management:
81    /// - The implementation receives the TTL suggested by the requestor via taskParams.ttl
82    /// - The implementation MAY override the requested TTL (e.g., to enforce limits)
83    /// - The actual TTL used MUST be returned in the Task object
84    /// - Null TTL indicates unlimited task lifetime (no automatic cleanup)
85    /// - Cleanup SHOULD occur automatically after TTL expires, regardless of task status
86    ///
87    /// # Arguments
88    /// * `task_params` - The task creation parameters from the request (ttl, pollInterval)
89    /// * `request_id` - The JSON-RPC request ID
90    /// * `request` - The original request that triggered task creation
91    /// * `session_id` - Optional session ID for binding the task to a specific session
92    ///
93    /// # Returns
94    /// The created task object
95    async fn create_task(
96        &self,
97        task_params: CreateTaskOptions,
98        request_id: RequestId,
99        request: Req,
100        session_id: Option<String>,
101    ) -> Task;
102
103    /// Begins active polling for task status updates in requestor mode.
104    /// This method spawns a long-running background task that drives the polling
105    /// schedule for all tasks managed by this store. It repeatedly invokes the
106    /// provided `get_task_callback` to query the **receiver** for the current status
107    /// of pending tasks.
108    ///
109    /// The polling loop should respect the `pollInterval` suggested by the receiver and
110    /// dynamically adjusts accordingly. Each task is polled until it reaches a
111    /// terminal state (`Completed`, `Failed`, or `Cancelled`), at which point it
112    /// is removed from the active polling schedule.
113    ///
114    /// This mechanism is used when the local side acts as the **requestor** in the
115    /// Model Context Protocol task flow — i.e., when a task-augmented request has
116    /// been sent to the remote side (the receiver) and the local side needs to
117    /// actively monitor progress via repeated `tasks/get` calls.
118    fn start_task_polling(&self, get_task_callback: TaskStatusPoller) -> SdkResult<()>;
119
120    /// Waits asynchronously for the result of a task.
121    ///
122    /// # Arguments
123    ///
124    /// * `task_id` - The unique identifier of the task whose result is awaited.
125    /// * `session_id` - Optional session identifier used to disambiguate or scope the task.
126    ///
127    /// # Returns
128    ///
129    /// * `Ok(Res)` if the task completes successfully and sends its result.
130    /// * `Err(SdkError)` if:
131    ///   - the task does not exist,
132    ///   - the task result channel is dropped before sending,
133    ///   - or an internal error occurs.
134    ///
135    /// # Errors
136    ///
137    /// Returns an internal RPC error if the task does not exist or if the sender
138    /// side of the oneshot channel is dropped before producing a result.
139    async fn wait_for_task_result(
140        &self,
141        task_id: &str,
142        session_id: Option<String>,
143    ) -> SdkResult<(TaskStatus, Option<Res>)>;
144
145    /// Gets the current status of a task.
146    ///
147    /// # Arguments
148    /// * `task_id` - The task identifier
149    /// * `session_id` - Optional session ID for binding the query to a specific session
150    ///
151    /// # Returns
152    /// The task object, or None if it does not exist
153    async fn get_task(&self, task_id: &str, session_id: Option<String>) -> Option<Task>;
154
155    /// Stores the result of a task and sets its final status.
156    ///
157    /// # Arguments
158    /// * `task_id` - The task identifier
159    /// * `status` - The final status: 'completed' for success, 'failed' for errors
160    /// * `result` - The result to store
161    /// * `session_id` - Optional session ID for binding the operation to a specific session
162    async fn store_task_result(
163        &self,
164        task_id: &str,
165        status: TaskStatus,
166        result: Res,
167        session_id: Option<&String>,
168    ) -> ();
169
170    /// Retrieves the stored result of a task.
171    ///
172    /// # Arguments
173    /// * `task_id` - The task identifier
174    /// * `session_id` - Optional session ID for binding the query to a specific session
175    ///
176    /// # Returns
177    /// The stored result
178    async fn get_task_result(&self, task_id: &str, session_id: Option<String>) -> Option<Res>;
179
180    /// Updates a task's status (e.g., to 'cancelled', 'failed', 'completed').
181    ///
182    /// # Arguments
183    /// * `task_id` - The task identifier
184    /// * `status` - The new status
185    /// * `status_message` - Optional diagnostic message for failed tasks or other status information
186    /// * `session_id` - Optional session ID for binding the operation to a specific session
187    async fn update_task_status(
188        &self,
189        task_id: &str,
190        status: TaskStatus,
191        status_message: Option<String>,
192        session_id: Option<String>,
193    ) -> ();
194
195    /// Lists tasks, optionally starting from a pagination cursor.
196    ///
197    /// # Arguments
198    /// * `cursor` - Optional cursor for pagination
199    /// * `session_id` - Optional session ID for binding the query to a specific session
200    ///
201    /// # Returns
202    /// An object containing the tasks array and an optional nextCursor
203    async fn list_tasks(
204        &self,
205        cursor: Option<String>,
206        session_id: Option<String>,
207    ) -> ListTasksResult;
208}
209
210pub type ServerTaskCreator = TaskCreator<ClientJsonrpcRequest, ResultFromServer>;
211pub type ClientTaskCreator = TaskCreator<ServerJsonrpcRequest, ResultFromClient>;
212
213pub type ServerTaskStore = dyn TaskStore<ClientJsonrpcRequest, ResultFromServer>;
214pub type ClientTaskStore = dyn TaskStore<ServerJsonrpcRequest, ResultFromClient>;