rust_mcp_sdk/
utils.rs

1use crate::error::{McpSdkError, ProtocolErrorKind, SdkResult};
2use crate::schema::schema_utils::{ClientMessages, SdkError};
3use crate::schema::ProtocolVersion;
4use std::cmp::Ordering;
5
6/// A guard type that automatically aborts a Tokio task when dropped.
7///
8/// This ensures that the associated task does not outlive the scope
9/// of this struct, preventing runaway or leaked background tasks.
10///
11pub struct AbortTaskOnDrop {
12    /// The handle used to abort the spawned Tokio task.
13    pub handle: tokio::task::AbortHandle,
14}
15
16impl Drop for AbortTaskOnDrop {
17    fn drop(&mut self) {
18        // Automatically abort the associated task when this guard is dropped.
19        self.handle.abort();
20    }
21}
22
/// Formats an assertion error message for an unsupported capability.
///
/// Builds a message stating that `entity` (e.g. a server or a client) does not
/// support `capability`, which `method_name` requires.
///
/// # Arguments
/// - `entity`: The name of the entity (e.g., "Server" or "Client") that lacks support.
/// - `capability`: The name of the unsupported capability (e.g., "resources", "tools").
/// - `method_name`: The name of the method requiring the capability.
///
/// # Returns
/// A string of the form `"{entity} does not support {capability} (required for {method_name})"`.
///
/// # Examples
/// ```ignore
/// let msg = format_assertion_message("Server", "resources", rust_mcp_schema::ListResourcesRequest::method_name());
/// assert_eq!(msg, "Server does not support resources (required for resources/list)");
/// ```
pub fn format_assertion_message(entity: &str, capability: &str, method_name: &str) -> String {
    format!("{entity} does not support {capability} (required for {method_name})")
}
44
45/// Checks if the client and server protocol versions are compatible by ensuring they are equal.
46///
47/// This function compares the provided client and server protocol versions. If they are equal,
48/// it returns `Ok(())`, indicating compatibility. If they differ (either the client version is
49/// lower or higher than the server version), it returns an error with details about the
50/// incompatible versions.
51///
52/// # Arguments
53///
54/// * `client_protocol_version` - A string slice representing the client's protocol version.
55/// * `server_protocol_version` - A string slice representing the server's protocol version.
56///
57/// # Returns
58///
59/// * `Ok(())` if the versions are equal.
60/// * `Err(McpSdkError::IncompatibleProtocolVersion)` if the versions differ, containing the
61///   client and server versions as strings.
62///
63/// # Examples
64///
65/// ```
66/// use rust_mcp_sdk::mcp_client::ensure_server_protocole_compatibility;
67/// use rust_mcp_sdk::error::McpSdkError;
68///
69/// // Compatible versions
70/// let result = ensure_server_protocole_compatibility("2024_11_05", "2024_11_05");
71/// assert!(result.is_ok());
72///
73/// // Incompatible versions (requested < current)
74/// let result = ensure_server_protocole_compatibility("2024_11_05", "2025_03_26");
75/// assert!(matches!(
76///     result,
77///     Err(McpSdkError::Protocol{kind: rust_mcp_sdk::error::ProtocolErrorKind::IncompatibleVersion {requested, current}})
78///     if requested == "2024_11_05" && current == "2025_03_26"
79/// ));
80///
81/// // Incompatible versions (requested > current)
82/// let result = ensure_server_protocole_compatibility("2025_03_26", "2024_11_05");
83/// assert!(matches!(
84///     result,
85///     Err(McpSdkError::Protocol{kind: rust_mcp_sdk::error::ProtocolErrorKind::IncompatibleVersion {requested, current}})
86///     if requested == "2025_03_26" && current == "2024_11_05"
87/// ));
88/// ```
89#[allow(unused)]
90pub fn ensure_server_protocole_compatibility(
91    client_protocol_version: &str,
92    server_protocol_version: &str,
93) -> SdkResult<()> {
94    match client_protocol_version.cmp(server_protocol_version) {
95        Ordering::Less | Ordering::Greater => Err(McpSdkError::Protocol {
96            kind: ProtocolErrorKind::IncompatibleVersion {
97                requested: client_protocol_version.to_string(),
98                current: server_protocol_version.to_string(),
99            },
100        }),
101        Ordering::Equal => Ok(()),
102    }
103}
104
105/// Enforces protocol version compatibility on for MCP Server , allowing the client to use a lower or equal version.
106///
107/// This function compares the client and server protocol versions. If the client version is
108/// higher than the server version, it returns an error indicating incompatibility. If the
109/// versions are equal, it returns `Ok(None)`, indicating no downgrade is needed. If the client
110/// version is lower, it returns `Ok(Some(client_protocol_version))`, suggesting the server
111/// can use the client's version for compatibility.
112///
113/// # Arguments
114///
115/// * `client_protocol_version` - The client's protocol version.
116/// * `server_protocol_version` - The server's protocol version.
117///
118/// # Returns
119///
120/// * `Ok(None)` if the versions are equal, indicating no downgrade is needed.
121/// * `Ok(Some(client_protocol_version))` if the client version is lower, returning the client
122///   version to use for compatibility.
123/// * `Err(McpSdkError::IncompatibleProtocolVersion)` if the client version is higher, containing
124///   the client and server versions as strings.
125///
126/// # Examples
127///
128/// ```
129/// use rust_mcp_sdk::mcp_server::enforce_compatible_protocol_version;
130/// use rust_mcp_sdk::error::McpSdkError;
131///
132/// // Equal versions
133/// let result = enforce_compatible_protocol_version("2024_11_05", "2024_11_05");
134/// assert!(matches!(result, Ok(None)));
135///
136/// // Client version lower (downgrade allowed)
137/// let result = enforce_compatible_protocol_version("2024_11_05", "2025_03_26");
138/// assert!(matches!(result, Ok(Some(ref v)) if v == "2024_11_05"));
139///
140/// // Client version higher (incompatible)
141/// let result = enforce_compatible_protocol_version("2025_03_26", "2024_11_05");
142/// assert!(matches!(
143///     result,
144///     Err(McpSdkError::Protocol{kind: rust_mcp_sdk::error::ProtocolErrorKind::IncompatibleVersion {requested, current}})
145///     if requested == "2025_03_26" && current == "2024_11_05"
146/// ));
147/// ```
148#[allow(unused)]
149pub fn enforce_compatible_protocol_version(
150    client_protocol_version: &str,
151    server_protocol_version: &str,
152) -> SdkResult<Option<String>> {
153    match client_protocol_version.cmp(server_protocol_version) {
154        // if client protocol version is higher
155        Ordering::Greater => Err(McpSdkError::Protocol {
156            kind: ProtocolErrorKind::IncompatibleVersion {
157                requested: client_protocol_version.to_string(),
158                current: server_protocol_version.to_string(),
159            },
160        }),
161        Ordering::Equal => Ok(None),
162        Ordering::Less => {
163            // return the same version that was received from the client
164            Ok(Some(client_protocol_version.to_string()))
165        }
166    }
167}
168
169pub fn validate_mcp_protocol_version(mcp_protocol_version: &str) -> SdkResult<()> {
170    let _mcp_protocol_version =
171        ProtocolVersion::try_from(mcp_protocol_version).map_err(|err| McpSdkError::Protocol {
172            kind: ProtocolErrorKind::ParseError(err),
173        })?;
174    Ok(())
175}
176
/// Removes the query string and hash fragment from a URL, returning the base path.
///
/// Everything from the first `?` or `#` onward is discarded, whichever comes first.
/// If nothing is left (e.g. the input is empty or starts with `?` or `#`), `"/"` is
/// returned instead.
///
/// # Arguments
/// * `endpoint` - The URL or endpoint to process (e.g. `"/messages?foo=bar#section1"`).
///
/// # Returns
/// A `String` containing the base path without query parameters or fragment.
///
/// # Examples
/// ```ignore
/// assert_eq!(remove_query_and_hash("/messages?foo=bar#section1"), "/messages");
/// assert_eq!(remove_query_and_hash("?foo=bar"), "/");
/// ```
#[allow(unused)]
pub(crate) fn remove_query_and_hash(endpoint: &str) -> String {
    // The base path ends at whichever delimiter appears first. Both delimiters are
    // ASCII, so the byte index returned by `find` is always a valid char boundary.
    let base = endpoint
        .find(|c: char| c == '?' || c == '#')
        .map_or(endpoint, |idx| &endpoint[..idx]);

    if base.is_empty() {
        "/".to_string()
    } else {
        base.to_string()
    }
}
202
203/// Checks if the input string is valid JSON and represents an "initialize" method request.
204pub fn valid_initialize_method(json_str: &str) -> SdkResult<()> {
205    // Attempt to deserialize the input string into ClientMessages
206    let Ok(request) = serde_json::from_str::<ClientMessages>(json_str) else {
207        return Err(SdkError::bad_request()
208            .with_message("Bad Request: Session not found")
209            .into());
210    };
211
212    match request {
213        ClientMessages::Single(client_message) => {
214            if !client_message.is_initialize_request() {
215                return Err(SdkError::bad_request()
216                    .with_message("Bad Request: Session not found")
217                    .into());
218            }
219        }
220        ClientMessages::Batch(client_messages) => {
221            let count = client_messages
222                .iter()
223                .filter(|item| item.is_initialize_request())
224                .count();
225            if count > 1 {
226                return Err(SdkError::invalid_request()
227                    .with_message("Bad Request: Only one initialization request is allowed")
228                    .into());
229            }
230        }
231    };
232
233    Ok(())
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    /// Query strings and fragments are stripped; an empty remainder becomes "/".
    #[test]
    fn test_remove_query_and_hash() {
        assert_eq!(remove_query_and_hash("/messages"), "/messages");
        assert_eq!(
            remove_query_and_hash("/messages?foo=bar&baz=qux"),
            "/messages"
        );
        assert_eq!(remove_query_and_hash("/messages#section1"), "/messages");
        assert_eq!(
            remove_query_and_hash("/messages?key=value#section2"),
            "/messages"
        );
        // A fragment that itself contains '?' is removed entirely.
        assert_eq!(remove_query_and_hash("/messages#frag?x=1"), "/messages");
        assert_eq!(remove_query_and_hash("/"), "/");
        // Nothing left before the delimiter (or empty input) falls back to "/".
        assert_eq!(remove_query_and_hash(""), "/");
        assert_eq!(remove_query_and_hash("?foo=bar"), "/");
        assert_eq!(remove_query_and_hash("#section"), "/");
    }
}