rust_mcp_sdk/task_store.rs
1mod in_memory_task_store;
2use async_trait::async_trait;
3use futures::Stream;
4pub use in_memory_task_store::*;
5use rust_mcp_schema::{
6 schema_utils::{
7 ClientJsonrpcRequest, ResultFromClient, ResultFromServer, ServerJsonrpcRequest,
8 },
9 ListTasksResult, RequestId, Task, TaskStatus, TaskStatusNotificationParams,
10};
11use std::{fmt::Debug, pin::Pin, sync::Arc};
12
13use crate::error::SdkResult;
14
15/// A stream of task status notifications, where each item contains the notification parameters
16/// and an optional session_id
17pub type TaskStatusStream =
18 Pin<Box<dyn Stream<Item = (TaskStatusNotificationParams, Option<String>)> + Send + 'static>>;
19
20#[async_trait]
21pub trait TaskStatusSignal: Send + Sync + 'static {
22 /// Publish a status change event
23 async fn publish_status_change(
24 &self,
25 event: TaskStatusNotificationParams,
26 session_id: Option<&String>,
27 );
28 /// Return a new independent stream of events
29 fn subscribe(&self) -> Option<TaskStatusStream> {
30 None
31 }
32}
33
34pub type TaskStatusCallback = Box<dyn Fn(&Task, Option<&String>) + Send + Sync + 'static>;
35
36pub struct CreateTaskOptions {
37 ///Actual retention duration from creation in milliseconds, None for unlimited.
38 pub ttl: Option<i64>,
39 pub poll_interval: ::std::option::Option<i64>,
40 ///Additional context to pass to the task store.
41 pub meta: Option<serde_json::Map<String, serde_json::Value>>,
42 // pub context: Option<HashMap<String, Box<dyn Any + Send>>>,
43}
44
45pub struct TaskCreator<Req, Res>
46where
47 Req: Debug + Clone + serde::Deserialize<'static> + serde::Serialize,
48 Res: Debug + Clone + serde::Deserialize<'static> + serde::Serialize,
49{
50 pub request_id: RequestId,
51 pub request: Req,
52 pub session_id: Option<String>,
53 pub task_store: Arc<dyn TaskStore<Req, Res>>,
54}
55
56impl<Req, Res> TaskCreator<Req, Res>
57where
58 Req: Debug + Clone + serde::Deserialize<'static> + serde::Serialize + 'static,
59 Res: Debug + Clone + serde::Deserialize<'static> + serde::Serialize + 'static,
60{
61 pub async fn create_task(self, task_params: CreateTaskOptions) -> Task {
62 self.task_store
63 .create_task(task_params, self.request_id, self.request, self.session_id)
64 .await
65 }
66}
67
68/// A trait for storing and managing long-running tasks, storing and retrieving task state and results.
69/// Tasks were introduced in MCP Protocol version 2025-11-25.
70/// For more details, see: <https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks>
71#[async_trait]
72pub trait TaskStore<Req, Res>: Send + Sync + TaskStatusSignal
73where
74 Req: Debug + Clone + serde::Deserialize<'static> + serde::Serialize,
75 Res: Debug + Clone + serde::Deserialize<'static> + serde::Serialize,
76{
77 /// Creates a new task with the given creation parameters and original request.
78 /// The implementation must generate a unique taskId and createdAt timestamp.
79 ///
80 /// TTL Management:
81 /// - The implementation receives the TTL suggested by the requestor via taskParams.ttl
82 /// - The implementation MAY override the requested TTL (e.g., to enforce limits)
83 /// - The actual TTL used MUST be returned in the Task object
84 /// - Null TTL indicates unlimited task lifetime (no automatic cleanup)
85 /// - Cleanup SHOULD occur automatically after TTL expires, regardless of task status
86 ///
87 /// # Arguments
88 /// * `task_params` - The task creation parameters from the request (ttl, pollInterval)
89 /// * `request_id` - The JSON-RPC request ID
90 /// * `request` - The original request that triggered task creation
91 /// * `session_id` - Optional session ID for binding the task to a specific session
92 ///
93 /// # Returns
94 /// The created task object
95 async fn create_task(
96 &self,
97 task_params: CreateTaskOptions,
98 request_id: RequestId,
99 request: Req,
100 session_id: Option<String>,
101 ) -> Task;
102
103 /// Begins active polling for task status updates in requestor mode.
104 /// This method spawns a long-running background task that drives the polling
105 /// schedule for all tasks managed by this store. It repeatedly invokes the
106 /// provided `get_task_callback` to query the **receiver** for the current status
107 /// of pending tasks.
108 ///
109 /// The polling loop should respect the `pollInterval` suggested by the receiver and
110 /// dynamically adjusts accordingly. Each task is polled until it reaches a
111 /// terminal state (`Completed`, `Failed`, or `Cancelled`), at which point it
112 /// is removed from the active polling schedule.
113 ///
114 /// This mechanism is used when the local side acts as the **requestor** in the
115 /// Model Context Protocol task flow — i.e., when a task-augmented request has
116 /// been sent to the remote side (the receiver) and the local side needs to
117 /// actively monitor progress via repeated `tasks/get` calls.
118 fn start_task_polling(&self, get_task_callback: TaskStatusPoller) -> SdkResult<()>;
119
120 /// Waits asynchronously for the result of a task.
121 ///
122 /// # Arguments
123 ///
124 /// * `task_id` - The unique identifier of the task whose result is awaited.
125 /// * `session_id` - Optional session identifier used to disambiguate or scope the task.
126 ///
127 /// # Returns
128 ///
129 /// * `Ok(Res)` if the task completes successfully and sends its result.
130 /// * `Err(SdkError)` if:
131 /// - the task does not exist,
132 /// - the task result channel is dropped before sending,
133 /// - or an internal error occurs.
134 ///
135 /// # Errors
136 ///
137 /// Returns an internal RPC error if the task does not exist or if the sender
138 /// side of the oneshot channel is dropped before producing a result.
139 async fn wait_for_task_result(
140 &self,
141 task_id: &str,
142 session_id: Option<String>,
143 ) -> SdkResult<(TaskStatus, Option<Res>)>;
144
145 /// Gets the current status of a task.
146 ///
147 /// # Arguments
148 /// * `task_id` - The task identifier
149 /// * `session_id` - Optional session ID for binding the query to a specific session
150 ///
151 /// # Returns
152 /// The task object, or None if it does not exist
153 async fn get_task(&self, task_id: &str, session_id: Option<String>) -> Option<Task>;
154
155 /// Stores the result of a task and sets its final status.
156 ///
157 /// # Arguments
158 /// * `task_id` - The task identifier
159 /// * `status` - The final status: 'completed' for success, 'failed' for errors
160 /// * `result` - The result to store
161 /// * `session_id` - Optional session ID for binding the operation to a specific session
162 async fn store_task_result(
163 &self,
164 task_id: &str,
165 status: TaskStatus,
166 result: Res,
167 session_id: Option<&String>,
168 ) -> ();
169
170 /// Retrieves the stored result of a task.
171 ///
172 /// # Arguments
173 /// * `task_id` - The task identifier
174 /// * `session_id` - Optional session ID for binding the query to a specific session
175 ///
176 /// # Returns
177 /// The stored result
178 async fn get_task_result(&self, task_id: &str, session_id: Option<String>) -> Option<Res>;
179
180 /// Updates a task's status (e.g., to 'cancelled', 'failed', 'completed').
181 ///
182 /// # Arguments
183 /// * `task_id` - The task identifier
184 /// * `status` - The new status
185 /// * `status_message` - Optional diagnostic message for failed tasks or other status information
186 /// * `session_id` - Optional session ID for binding the operation to a specific session
187 async fn update_task_status(
188 &self,
189 task_id: &str,
190 status: TaskStatus,
191 status_message: Option<String>,
192 session_id: Option<String>,
193 ) -> ();
194
195 /// Lists tasks, optionally starting from a pagination cursor.
196 ///
197 /// # Arguments
198 /// * `cursor` - Optional cursor for pagination
199 /// * `session_id` - Optional session ID for binding the query to a specific session
200 ///
201 /// # Returns
202 /// An object containing the tasks array and an optional nextCursor
203 async fn list_tasks(
204 &self,
205 cursor: Option<String>,
206 session_id: Option<String>,
207 ) -> ListTasksResult;
208}
209
210pub type ServerTaskCreator = TaskCreator<ClientJsonrpcRequest, ResultFromServer>;
211pub type ClientTaskCreator = TaskCreator<ServerJsonrpcRequest, ResultFromClient>;
212
213pub type ServerTaskStore = dyn TaskStore<ClientJsonrpcRequest, ResultFromServer>;
214pub type ClientTaskStore = dyn TaskStore<ServerJsonrpcRequest, ResultFromClient>;