rust_mcp_sdk/utils.rs
1use crate::error::{McpSdkError, ProtocolErrorKind, SdkResult};
2use crate::schema::schema_utils::{ClientMessages, SdkError};
3use crate::schema::ProtocolVersion;
4use std::cmp::Ordering;
5
6/// A guard type that automatically aborts a Tokio task when dropped.
7///
8/// This ensures that the associated task does not outlive the scope
9/// of this struct, preventing runaway or leaked background tasks.
10///
11pub struct AbortTaskOnDrop {
12 /// The handle used to abort the spawned Tokio task.
13 pub handle: tokio::task::AbortHandle,
14}
15
16impl Drop for AbortTaskOnDrop {
17 fn drop(&mut self) {
18 // Automatically abort the associated task when this guard is dropped.
19 self.handle.abort();
20 }
21}
22
/// Formats an assertion error message for unsupported capabilities.
///
/// Builds a string stating that an entity (e.g. server or client) lacks a capability
/// that is required by a particular method.
///
/// # Arguments
/// - `entity`: The name of the entity (e.g., "Server" or "Client") that lacks support.
/// - `capability`: The name of the unsupported capability (e.g., "resources" or "tools").
/// - `method_name`: The name of the method requiring the capability.
///
/// # Returns
/// A string of the form `"{entity} does not support {capability} (required for {method_name})"`.
///
/// # Examples
/// ```ignore
/// let msg = format_assertion_message("Server", "resources", rust_mcp_schema::ListResourcesRequest::method_name());
/// assert_eq!(msg, "Server does not support resources (required for resources/list)");
/// ```
pub fn format_assertion_message(entity: &str, capability: &str, method_name: &str) -> String {
    format!("{entity} does not support {capability} (required for {method_name})")
}
44
45/// Checks if the client and server protocol versions are compatible by ensuring they are equal.
46///
47/// This function compares the provided client and server protocol versions. If they are equal,
48/// it returns `Ok(())`, indicating compatibility. If they differ (either the client version is
49/// lower or higher than the server version), it returns an error with details about the
50/// incompatible versions.
51///
52/// # Arguments
53///
54/// * `client_protocol_version` - A string slice representing the client's protocol version.
55/// * `server_protocol_version` - A string slice representing the server's protocol version.
56///
57/// # Returns
58///
59/// * `Ok(())` if the versions are equal.
60/// * `Err(McpSdkError::IncompatibleProtocolVersion)` if the versions differ, containing the
61/// client and server versions as strings.
62///
63/// # Examples
64///
65/// ```
66/// use rust_mcp_sdk::mcp_client::ensure_server_protocole_compatibility;
67/// use rust_mcp_sdk::error::McpSdkError;
68///
69/// // Compatible versions
70/// let result = ensure_server_protocole_compatibility("2024_11_05", "2024_11_05");
71/// assert!(result.is_ok());
72///
73/// // Incompatible versions (requested < current)
74/// let result = ensure_server_protocole_compatibility("2024_11_05", "2025_03_26");
75/// assert!(matches!(
76/// result,
77/// Err(McpSdkError::Protocol{kind: rust_mcp_sdk::error::ProtocolErrorKind::IncompatibleVersion {requested, current}})
78/// if requested == "2024_11_05" && current == "2025_03_26"
79/// ));
80///
81/// // Incompatible versions (requested > current)
82/// let result = ensure_server_protocole_compatibility("2025_03_26", "2024_11_05");
83/// assert!(matches!(
84/// result,
85/// Err(McpSdkError::Protocol{kind: rust_mcp_sdk::error::ProtocolErrorKind::IncompatibleVersion {requested, current}})
86/// if requested == "2025_03_26" && current == "2024_11_05"
87/// ));
88/// ```
89#[allow(unused)]
90pub fn ensure_server_protocole_compatibility(
91 client_protocol_version: &str,
92 server_protocol_version: &str,
93) -> SdkResult<()> {
94 match client_protocol_version.cmp(server_protocol_version) {
95 Ordering::Less | Ordering::Greater => Err(McpSdkError::Protocol {
96 kind: ProtocolErrorKind::IncompatibleVersion {
97 requested: client_protocol_version.to_string(),
98 current: server_protocol_version.to_string(),
99 },
100 }),
101 Ordering::Equal => Ok(()),
102 }
103}
104
105/// Enforces protocol version compatibility on for MCP Server , allowing the client to use a lower or equal version.
106///
107/// This function compares the client and server protocol versions. If the client version is
108/// higher than the server version, it returns an error indicating incompatibility. If the
109/// versions are equal, it returns `Ok(None)`, indicating no downgrade is needed. If the client
110/// version is lower, it returns `Ok(Some(client_protocol_version))`, suggesting the server
111/// can use the client's version for compatibility.
112///
113/// # Arguments
114///
115/// * `client_protocol_version` - The client's protocol version.
116/// * `server_protocol_version` - The server's protocol version.
117///
118/// # Returns
119///
120/// * `Ok(None)` if the versions are equal, indicating no downgrade is needed.
121/// * `Ok(Some(client_protocol_version))` if the client version is lower, returning the client
122/// version to use for compatibility.
123/// * `Err(McpSdkError::IncompatibleProtocolVersion)` if the client version is higher, containing
124/// the client and server versions as strings.
125///
126/// # Examples
127///
128/// ```
129/// use rust_mcp_sdk::mcp_server::enforce_compatible_protocol_version;
130/// use rust_mcp_sdk::error::McpSdkError;
131///
132/// // Equal versions
133/// let result = enforce_compatible_protocol_version("2024_11_05", "2024_11_05");
134/// assert!(matches!(result, Ok(None)));
135///
136/// // Client version lower (downgrade allowed)
137/// let result = enforce_compatible_protocol_version("2024_11_05", "2025_03_26");
138/// assert!(matches!(result, Ok(Some(ref v)) if v == "2024_11_05"));
139///
140/// // Client version higher (incompatible)
141/// let result = enforce_compatible_protocol_version("2025_03_26", "2024_11_05");
142/// assert!(matches!(
143/// result,
144/// Err(McpSdkError::Protocol{kind: rust_mcp_sdk::error::ProtocolErrorKind::IncompatibleVersion {requested, current}})
145/// if requested == "2025_03_26" && current == "2024_11_05"
146/// ));
147/// ```
148#[allow(unused)]
149pub fn enforce_compatible_protocol_version(
150 client_protocol_version: &str,
151 server_protocol_version: &str,
152) -> SdkResult<Option<String>> {
153 match client_protocol_version.cmp(server_protocol_version) {
154 // if client protocol version is higher
155 Ordering::Greater => Err(McpSdkError::Protocol {
156 kind: ProtocolErrorKind::IncompatibleVersion {
157 requested: client_protocol_version.to_string(),
158 current: server_protocol_version.to_string(),
159 },
160 }),
161 Ordering::Equal => Ok(None),
162 Ordering::Less => {
163 // return the same version that was received from the client
164 Ok(Some(client_protocol_version.to_string()))
165 }
166 }
167}
168
169pub fn validate_mcp_protocol_version(mcp_protocol_version: &str) -> SdkResult<()> {
170 let _mcp_protocol_version =
171 ProtocolVersion::try_from(mcp_protocol_version).map_err(|err| McpSdkError::Protocol {
172 kind: ProtocolErrorKind::ParseError(err),
173 })?;
174 Ok(())
175}
176
/// Removes the query string and hash fragment from a URL or endpoint, returning the base path.
///
/// Everything from the first `?` or `#` onward is discarded. If nothing remains (the input
/// is empty or starts with `?`/`#`), `"/"` is returned so the result is always a usable path.
///
/// # Arguments
/// * `endpoint` - The URL or endpoint to process (e.g. `"/messages?foo=bar#section1"`).
///
/// # Returns
/// An owned `String` containing the base path without query parameters or fragment.
///
/// # Examples
/// ```ignore
/// // Marked `ignore`: this helper is crate-private, so doctests cannot call it.
/// assert_eq!(remove_query_and_hash("/messages?key=value#section2"), "/messages");
/// assert_eq!(remove_query_and_hash("?only=query"), "/");
/// ```
// NOTE(review): `allow(unused)` presumably covers feature configurations where no caller
// is compiled in — confirm.
#[allow(unused)]
pub(crate) fn remove_query_and_hash(endpoint: &str) -> String {
    // Cutting at the first `?` or `#` is equivalent to stripping the fragment and then the
    // query: whichever delimiter appears first, everything after it is dropped.
    let base = endpoint
        .find(|c: char| c == '?' || c == '#')
        .map_or(endpoint, |idx| &endpoint[..idx]);

    if base.is_empty() {
        "/".to_string()
    } else {
        base.to_string()
    }
}
202
203/// Checks if the input string is valid JSON and represents an "initialize" method request.
204pub fn valid_initialize_method(json_str: &str) -> SdkResult<()> {
205 // Attempt to deserialize the input string into ClientMessages
206 let Ok(request) = serde_json::from_str::<ClientMessages>(json_str) else {
207 return Err(SdkError::bad_request()
208 .with_message("Bad Request: Session not found")
209 .into());
210 };
211
212 match request {
213 ClientMessages::Single(client_message) => {
214 if !client_message.is_initialize_request() {
215 return Err(SdkError::bad_request()
216 .with_message("Bad Request: Session not found")
217 .into());
218 }
219 }
220 ClientMessages::Batch(client_messages) => {
221 let count = client_messages
222 .iter()
223 .filter(|item| item.is_initialize_request())
224 .count();
225 if count > 1 {
226 return Err(SdkError::invalid_request()
227 .with_message("Bad Request: Only one initialization request is allowed")
228 .into());
229 }
230 }
231 };
232
233 Ok(())
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_remove_query_and_hash() {
        assert_eq!(remove_query_and_hash("/messages"), "/messages");
        assert_eq!(
            remove_query_and_hash("/messages?foo=bar&baz=qux"),
            "/messages"
        );
        assert_eq!(remove_query_and_hash("/messages#section1"), "/messages");
        assert_eq!(
            remove_query_and_hash("/messages?key=value#section2"),
            "/messages"
        );
        // A `?` inside the fragment is not a query delimiter.
        assert_eq!(
            remove_query_and_hash("/messages#frag?not=query"),
            "/messages"
        );
        assert_eq!(remove_query_and_hash("/"), "/");
        // Nothing left after stripping falls back to the root path.
        assert_eq!(remove_query_and_hash(""), "/");
        assert_eq!(remove_query_and_hash("?a=b"), "/");
        assert_eq!(remove_query_and_hash("#top"), "/");
    }

    #[test]
    fn test_format_assertion_message() {
        assert_eq!(
            format_assertion_message("Server", "resources", "resources/list"),
            "Server does not support resources (required for resources/list)"
        );
    }

    #[test]
    fn test_protocol_version_checks() {
        assert!(ensure_server_protocole_compatibility("2025-03-26", "2025-03-26").is_ok());
        assert!(ensure_server_protocole_compatibility("2024-11-05", "2025-03-26").is_err());
        assert!(ensure_server_protocole_compatibility("2025-03-26", "2024-11-05").is_err());

        assert!(matches!(
            enforce_compatible_protocol_version("2025-03-26", "2025-03-26"),
            Ok(None)
        ));
        assert!(matches!(
            enforce_compatible_protocol_version("2024-11-05", "2025-03-26"),
            Ok(Some(ref v)) if v == "2024-11-05"
        ));
        assert!(enforce_compatible_protocol_version("2025-03-26", "2024-11-05").is_err());
    }

    #[test]
    fn test_valid_initialize_method_rejects_invalid_json() {
        assert!(valid_initialize_method("not json").is_err());
    }
}