rust_mcp_sdk/utils.rs
1use crate::schema::schema_utils::{ClientMessages, SdkError};
2
3use crate::error::{McpSdkError, SdkResult};
4use crate::schema::ProtocolVersion;
5use std::cmp::Ordering;
6
7/// A guard type that automatically aborts a Tokio task when dropped.
8///
9/// This ensures that the associated task does not outlive the scope
10/// of this struct, preventing runaway or leaked background tasks.
11///
12pub struct AbortTaskOnDrop {
13 /// The handle used to abort the spawned Tokio task.
14 pub handle: tokio::task::AbortHandle,
15}
16
17impl Drop for AbortTaskOnDrop {
18 fn drop(&mut self) {
19 // Automatically abort the associated task when this guard is dropped.
20 self.handle.abort();
21 }
22}
23
/// Formats an assertion error message for an unsupported capability.
///
/// Builds a message stating that a specific entity (e.g. a server or client) lacks
/// support for a capability that a particular method requires.
///
/// # Arguments
/// - `entity`: The name of the entity (e.g. "Server" or "Client") that lacks support.
/// - `capability`: The name of the unsupported capability.
/// - `method_name`: The name of the method that requires the capability.
///
/// # Returns
/// A string of the form
/// `"{entity} does not support {capability} (required for {method_name})"`.
///
/// # Examples
/// ```ignore
/// let msg = format_assertion_message(
///     "Server",
///     "resources",
///     rust_mcp_schema::ListResourcesRequest::method_name(),
/// );
/// assert_eq!(msg, "Server does not support resources (required for resources/list)");
/// ```
pub fn format_assertion_message(entity: &str, capability: &str, method_name: &str) -> String {
    format!("{entity} does not support {capability} (required for {method_name})")
}
45
46/// Checks if the client and server protocol versions are compatible by ensuring they are equal.
47///
48/// This function compares the provided client and server protocol versions. If they are equal,
49/// it returns `Ok(())`, indicating compatibility. If they differ (either the client version is
50/// lower or higher than the server version), it returns an error with details about the
51/// incompatible versions.
52///
53/// # Arguments
54///
55/// * `client_protocol_version` - A string slice representing the client's protocol version.
56/// * `server_protocol_version` - A string slice representing the server's protocol version.
57///
58/// # Returns
59///
60/// * `Ok(())` if the versions are equal.
61/// * `Err(McpSdkError::IncompatibleProtocolVersion)` if the versions differ, containing the
62/// client and server versions as strings.
63///
64/// # Examples
65///
66/// ```
67/// use rust_mcp_sdk::mcp_client::ensure_server_protocole_compatibility;
68/// use rust_mcp_sdk::error::McpSdkError;
69///
70/// // Compatible versions
71/// let result = ensure_server_protocole_compatibility("2024_11_05", "2024_11_05");
72/// assert!(result.is_ok());
73///
74/// // Incompatible versions (client < server)
75/// let result = ensure_server_protocole_compatibility("2024_11_05", "2025_03_26");
76/// assert!(matches!(
77/// result,
78/// Err(McpSdkError::IncompatibleProtocolVersion(client, server))
79/// if client == "2024_11_05" && server == "2025_03_26"
80/// ));
81///
82/// // Incompatible versions (client > server)
83/// let result = ensure_server_protocole_compatibility("2025_03_26", "2024_11_05");
84/// assert!(matches!(
85/// result,
86/// Err(McpSdkError::IncompatibleProtocolVersion(client, server))
87/// if client == "2025_03_26" && server == "2024_11_05"
88/// ));
89/// ```
90#[allow(unused)]
91pub fn ensure_server_protocole_compatibility(
92 client_protocol_version: &str,
93 server_protocol_version: &str,
94) -> SdkResult<()> {
95 match client_protocol_version.cmp(server_protocol_version) {
96 Ordering::Less | Ordering::Greater => Err(McpSdkError::IncompatibleProtocolVersion(
97 client_protocol_version.to_string(),
98 server_protocol_version.to_string(),
99 )),
100 Ordering::Equal => Ok(()),
101 }
102}
103
104/// Enforces protocol version compatibility on for MCP Server , allowing the client to use a lower or equal version.
105///
106/// This function compares the client and server protocol versions. If the client version is
107/// higher than the server version, it returns an error indicating incompatibility. If the
108/// versions are equal, it returns `Ok(None)`, indicating no downgrade is needed. If the client
109/// version is lower, it returns `Ok(Some(client_protocol_version))`, suggesting the server
110/// can use the client's version for compatibility.
111///
112/// # Arguments
113///
114/// * `client_protocol_version` - The client's protocol version.
115/// * `server_protocol_version` - The server's protocol version.
116///
117/// # Returns
118///
119/// * `Ok(None)` if the versions are equal, indicating no downgrade is needed.
120/// * `Ok(Some(client_protocol_version))` if the client version is lower, returning the client
121/// version to use for compatibility.
122/// * `Err(McpSdkError::IncompatibleProtocolVersion)` if the client version is higher, containing
123/// the client and server versions as strings.
124///
125/// # Examples
126///
127/// ```
128/// use rust_mcp_sdk::mcp_server::enforce_compatible_protocol_version;
129/// use rust_mcp_sdk::error::McpSdkError;
130///
131/// // Equal versions
132/// let result = enforce_compatible_protocol_version("2024_11_05", "2024_11_05");
133/// assert!(matches!(result, Ok(None)));
134///
135/// // Client version lower (downgrade allowed)
136/// let result = enforce_compatible_protocol_version("2024_11_05", "2025_03_26");
137/// assert!(matches!(result, Ok(Some(ref v)) if v == "2024_11_05"));
138///
139/// // Client version higher (incompatible)
140/// let result = enforce_compatible_protocol_version("2025_03_26", "2024_11_05");
141/// assert!(matches!(
142/// result,
143/// Err(McpSdkError::IncompatibleProtocolVersion(client, server))
144/// if client == "2025_03_26" && server == "2024_11_05"
145/// ));
146/// ```
147#[allow(unused)]
148pub fn enforce_compatible_protocol_version(
149 client_protocol_version: &str,
150 server_protocol_version: &str,
151) -> SdkResult<Option<String>> {
152 match client_protocol_version.cmp(server_protocol_version) {
153 // if client protocol version is higher
154 Ordering::Greater => Err(McpSdkError::IncompatibleProtocolVersion(
155 client_protocol_version.to_string(),
156 server_protocol_version.to_string(),
157 )),
158 Ordering::Equal => Ok(None),
159 Ordering::Less => {
160 // return the same version that was received from the client
161 Ok(Some(client_protocol_version.to_string()))
162 }
163 }
164}
165
166pub fn validate_mcp_protocol_version(mcp_protocol_version: &str) -> SdkResult<()> {
167 let _mcp_protocol_version = ProtocolVersion::try_from(mcp_protocol_version)?;
168 Ok(())
169}
170
/// Removes the query string and hash fragment from a URL, returning only the path part.
///
/// Everything from the first `?` or `#` onward is discarded. If nothing remains (for
/// example the input is empty or begins with `?` or `#`), `"/"` is returned, so the
/// result is never an empty string.
///
/// # Arguments
/// * `endpoint` - The URL or endpoint to process (e.g. `"/messages?foo=bar#section1"`).
///
/// # Returns
/// An owned `String` with the query string and fragment removed, or `"/"` if that
/// leaves nothing.
///
/// # Examples
/// ```ignore
/// assert_eq!(remove_query_and_hash("/messages?foo=bar#section1"), "/messages");
/// assert_eq!(remove_query_and_hash("/messages#frag?not-a-query"), "/messages");
/// assert_eq!(remove_query_and_hash("?foo=bar"), "/");
/// ```
#[allow(unused)]
pub(crate) fn remove_query_and_hash(endpoint: &str) -> String {
    // In URL syntax the query (`?`) precedes the fragment (`#`), and any `?` after a `#`
    // belongs to the fragment. Cutting at whichever of the two appears first therefore
    // removes both in a single pass.
    let path = endpoint
        .find(|c: char| c == '?' || c == '#')
        .map_or(endpoint, |cut| &endpoint[..cut]);

    if path.is_empty() {
        "/".to_string()
    } else {
        path.to_string()
    }
}
196
197/// Checks if the input string is valid JSON and represents an "initialize" method request.
198pub fn valid_initialize_method(json_str: &str) -> SdkResult<()> {
199 // Attempt to deserialize the input string into ClientMessages
200 let Ok(request) = serde_json::from_str::<ClientMessages>(json_str) else {
201 return Err(SdkError::bad_request()
202 .with_message("Bad Request: Session not found")
203 .into());
204 };
205
206 match request {
207 ClientMessages::Single(client_message) => {
208 if !client_message.is_initialize_request() {
209 return Err(SdkError::bad_request()
210 .with_message("Bad Request: Session not found")
211 .into());
212 }
213 }
214 ClientMessages::Batch(client_messages) => {
215 let count = client_messages
216 .iter()
217 .filter(|item| item.is_initialize_request())
218 .count();
219 if count > 1 {
220 return Err(SdkError::invalid_request()
221 .with_message("Bad Request: Only one initialization request is allowed")
222 .into());
223 }
224 }
225 };
226
227 Ok(())
228}
229
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_remove_query_and_hash() {
        assert_eq!(remove_query_and_hash("/messages"), "/messages");
        assert_eq!(
            remove_query_and_hash("/messages?foo=bar&baz=qux"),
            "/messages"
        );
        assert_eq!(remove_query_and_hash("/messages#section1"), "/messages");
        assert_eq!(
            remove_query_and_hash("/messages?key=value#section2"),
            "/messages"
        );
        assert_eq!(remove_query_and_hash("/"), "/");
    }

    #[test]
    fn test_remove_query_and_hash_empty_path_defaults_to_root() {
        assert_eq!(remove_query_and_hash(""), "/");
        assert_eq!(remove_query_and_hash("?foo=bar"), "/");
        assert_eq!(remove_query_and_hash("#section"), "/");
    }

    #[test]
    fn test_remove_query_and_hash_question_mark_inside_fragment() {
        // A `?` after `#` is part of the fragment, not a query string.
        assert_eq!(
            remove_query_and_hash("/messages#frag?not-a-query"),
            "/messages"
        );
    }

    #[test]
    fn test_format_assertion_message() {
        assert_eq!(
            format_assertion_message("Server", "resources", "resources/list"),
            "Server does not support resources (required for resources/list)"
        );
    }

    #[test]
    fn test_protocol_version_checks() {
        assert!(ensure_server_protocole_compatibility("2024-11-05", "2024-11-05").is_ok());
        assert!(ensure_server_protocole_compatibility("2024-11-05", "2025-03-26").is_err());
        assert!(ensure_server_protocole_compatibility("2025-03-26", "2024-11-05").is_err());

        assert!(matches!(
            enforce_compatible_protocol_version("2024-11-05", "2024-11-05"),
            Ok(None)
        ));
        assert!(matches!(
            enforce_compatible_protocol_version("2024-11-05", "2025-03-26"),
            Ok(Some(ref v)) if v == "2024-11-05"
        ));
        assert!(enforce_compatible_protocol_version("2025-03-26", "2024-11-05").is_err());
    }

    #[test]
    fn test_valid_initialize_method_rejects_invalid_json() {
        assert!(valid_initialize_method("not json").is_err());
    }
}