rust_mcp_sdk/
utils.rs

1use crate::schema::schema_utils::{ClientMessages, SdkError};
2
3use crate::error::{McpSdkError, ProtocolErrorKind, SdkResult};
4use crate::schema::ProtocolVersion;
5use std::cmp::Ordering;
6
7/// A guard type that automatically aborts a Tokio task when dropped.
8///
9/// This ensures that the associated task does not outlive the scope
10/// of this struct, preventing runaway or leaked background tasks.
11///
12pub struct AbortTaskOnDrop {
13    /// The handle used to abort the spawned Tokio task.
14    pub handle: tokio::task::AbortHandle,
15}
16
17impl Drop for AbortTaskOnDrop {
18    fn drop(&mut self) {
19        // Automatically abort the associated task when this guard is dropped.
20        self.handle.abort();
21    }
22}
23
/// Formats an assertion error message for an unsupported capability.
///
/// Builds a string stating that `entity` (e.g. the server or the client) does not
/// support `capability`, which is required by the method `method_name`.
///
/// # Arguments
/// - `entity`: Name of the entity lacking support (e.g. `"Server"` or `"Client"`).
/// - `capability`: Name of the unsupported capability (e.g. `"tools"`, `"resources"`).
/// - `method_name`: Name of the method that requires the capability.
///
/// # Returns
/// A message of the form
/// `"{entity} does not support {capability} (required for {method_name})"`.
///
/// # Examples
/// ```ignore
/// let msg = format_assertion_message(
///     "Server",
///     "resources",
///     rust_mcp_schema::ListResourcesRequest::method_name(),
/// );
/// assert_eq!(msg, "Server does not support resources (required for resources/list)");
/// ```
pub fn format_assertion_message(entity: &str, capability: &str, method_name: &str) -> String {
    format!("{entity} does not support {capability} (required for {method_name})")
}
45
46/// Checks if the client and server protocol versions are compatible by ensuring they are equal.
47///
48/// This function compares the provided client and server protocol versions. If they are equal,
49/// it returns `Ok(())`, indicating compatibility. If they differ (either the client version is
50/// lower or higher than the server version), it returns an error with details about the
51/// incompatible versions.
52///
53/// # Arguments
54///
55/// * `client_protocol_version` - A string slice representing the client's protocol version.
56/// * `server_protocol_version` - A string slice representing the server's protocol version.
57///
58/// # Returns
59///
60/// * `Ok(())` if the versions are equal.
61/// * `Err(McpSdkError::IncompatibleProtocolVersion)` if the versions differ, containing the
62///   client and server versions as strings.
63///
64/// # Examples
65///
66/// ```
67/// use rust_mcp_sdk::mcp_client::ensure_server_protocole_compatibility;
68/// use rust_mcp_sdk::error::McpSdkError;
69///
70/// // Compatible versions
71/// let result = ensure_server_protocole_compatibility("2024_11_05", "2024_11_05");
72/// assert!(result.is_ok());
73///
74/// // Incompatible versions (requested < current)
75/// let result = ensure_server_protocole_compatibility("2024_11_05", "2025_03_26");
76/// assert!(matches!(
77///     result,
78///     Err(McpSdkError::Protocol{kind: rust_mcp_sdk::error::ProtocolErrorKind::IncompatibleVersion {requested, current}})
79///     if requested == "2024_11_05" && current == "2025_03_26"
80/// ));
81///
82/// // Incompatible versions (requested > current)
83/// let result = ensure_server_protocole_compatibility("2025_03_26", "2024_11_05");
84/// assert!(matches!(
85///     result,
86///     Err(McpSdkError::Protocol{kind: rust_mcp_sdk::error::ProtocolErrorKind::IncompatibleVersion {requested, current}})
87///     if requested == "2025_03_26" && current == "2024_11_05"
88/// ));
89/// ```
90#[allow(unused)]
91pub fn ensure_server_protocole_compatibility(
92    client_protocol_version: &str,
93    server_protocol_version: &str,
94) -> SdkResult<()> {
95    match client_protocol_version.cmp(server_protocol_version) {
96        Ordering::Less | Ordering::Greater => Err(McpSdkError::Protocol {
97            kind: ProtocolErrorKind::IncompatibleVersion {
98                requested: client_protocol_version.to_string(),
99                current: server_protocol_version.to_string(),
100            },
101        }),
102        Ordering::Equal => Ok(()),
103    }
104}
105
106/// Enforces protocol version compatibility on for MCP Server , allowing the client to use a lower or equal version.
107///
108/// This function compares the client and server protocol versions. If the client version is
109/// higher than the server version, it returns an error indicating incompatibility. If the
110/// versions are equal, it returns `Ok(None)`, indicating no downgrade is needed. If the client
111/// version is lower, it returns `Ok(Some(client_protocol_version))`, suggesting the server
112/// can use the client's version for compatibility.
113///
114/// # Arguments
115///
116/// * `client_protocol_version` - The client's protocol version.
117/// * `server_protocol_version` - The server's protocol version.
118///
119/// # Returns
120///
121/// * `Ok(None)` if the versions are equal, indicating no downgrade is needed.
122/// * `Ok(Some(client_protocol_version))` if the client version is lower, returning the client
123///   version to use for compatibility.
124/// * `Err(McpSdkError::IncompatibleProtocolVersion)` if the client version is higher, containing
125///   the client and server versions as strings.
126///
127/// # Examples
128///
129/// ```
130/// use rust_mcp_sdk::mcp_server::enforce_compatible_protocol_version;
131/// use rust_mcp_sdk::error::McpSdkError;
132///
133/// // Equal versions
134/// let result = enforce_compatible_protocol_version("2024_11_05", "2024_11_05");
135/// assert!(matches!(result, Ok(None)));
136///
137/// // Client version lower (downgrade allowed)
138/// let result = enforce_compatible_protocol_version("2024_11_05", "2025_03_26");
139/// assert!(matches!(result, Ok(Some(ref v)) if v == "2024_11_05"));
140///
141/// // Client version higher (incompatible)
142/// let result = enforce_compatible_protocol_version("2025_03_26", "2024_11_05");
143/// assert!(matches!(
144///     result,
145///     Err(McpSdkError::Protocol{kind: rust_mcp_sdk::error::ProtocolErrorKind::IncompatibleVersion {requested, current}})
146///     if requested == "2025_03_26" && current == "2024_11_05"
147/// ));
148/// ```
149#[allow(unused)]
150pub fn enforce_compatible_protocol_version(
151    client_protocol_version: &str,
152    server_protocol_version: &str,
153) -> SdkResult<Option<String>> {
154    match client_protocol_version.cmp(server_protocol_version) {
155        // if client protocol version is higher
156        Ordering::Greater => Err(McpSdkError::Protocol {
157            kind: ProtocolErrorKind::IncompatibleVersion {
158                requested: client_protocol_version.to_string(),
159                current: server_protocol_version.to_string(),
160            },
161        }),
162        Ordering::Equal => Ok(None),
163        Ordering::Less => {
164            // return the same version that was received from the client
165            Ok(Some(client_protocol_version.to_string()))
166        }
167    }
168}
169
170pub fn validate_mcp_protocol_version(mcp_protocol_version: &str) -> SdkResult<()> {
171    let _mcp_protocol_version =
172        ProtocolVersion::try_from(mcp_protocol_version).map_err(|err| McpSdkError::Protocol {
173            kind: ProtocolErrorKind::ParseError(err),
174        })?;
175    Ok(())
176}
177
/// Removes the query string and hash fragment from a URL, returning the base path.
///
/// Everything from the first `?` or `#` onward is discarded. A `?` that appears after a
/// `#` belongs to the fragment and is removed along with it. If nothing remains, the root
/// path `"/"` is returned instead of an empty string.
///
/// # Arguments
/// * `endpoint` - The URL or endpoint to process (e.g. `"/messages?foo=bar#section1"`).
///
/// # Returns
/// A `String` containing the base path without query parameters or fragment
/// (e.g. `"/messages"`), or `"/"` when that base path would be empty.
// NOTE(review): presumably allowed because this helper is unused in some feature
// configurations of the crate — confirm and narrow to `dead_code` if so.
#[allow(unused)]
pub(crate) fn remove_query_and_hash(endpoint: &str) -> String {
    // The base path ends at whichever delimiter comes first: a `#` always starts the
    // fragment, and a `?` before any `#` starts the query string.
    let path_end = endpoint
        .find(|c: char| c == '?' || c == '#')
        .unwrap_or(endpoint.len());
    // `find` returns a char-boundary byte index, so this slice cannot panic.
    let base_path = &endpoint[..path_end];

    if base_path.is_empty() {
        "/".to_string()
    } else {
        base_path.to_string()
    }
}
203
204/// Checks if the input string is valid JSON and represents an "initialize" method request.
205pub fn valid_initialize_method(json_str: &str) -> SdkResult<()> {
206    // Attempt to deserialize the input string into ClientMessages
207    let Ok(request) = serde_json::from_str::<ClientMessages>(json_str) else {
208        return Err(SdkError::bad_request()
209            .with_message("Bad Request: Session not found")
210            .into());
211    };
212
213    match request {
214        ClientMessages::Single(client_message) => {
215            if !client_message.is_initialize_request() {
216                return Err(SdkError::bad_request()
217                    .with_message("Bad Request: Session not found")
218                    .into());
219            }
220        }
221        ClientMessages::Batch(client_messages) => {
222            let count = client_messages
223                .iter()
224                .filter(|item| item.is_initialize_request())
225                .count();
226            if count > 1 {
227                return Err(SdkError::invalid_request()
228                    .with_message("Bad Request: Only one initialization request is allowed")
229                    .into());
230            }
231        }
232    };
233
234    Ok(())
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_remove_query_and_hash() {
        assert_eq!(remove_query_and_hash("/messages"), "/messages");
        assert_eq!(
            remove_query_and_hash("/messages?foo=bar&baz=qux"),
            "/messages"
        );
        assert_eq!(remove_query_and_hash("/messages#section1"), "/messages");
        assert_eq!(
            remove_query_and_hash("/messages?key=value#section2"),
            "/messages"
        );
        assert_eq!(remove_query_and_hash("/"), "/");
    }

    #[test]
    fn test_remove_query_and_hash_falls_back_to_root() {
        // An empty base path is normalized to "/".
        assert_eq!(remove_query_and_hash(""), "/");
        assert_eq!(remove_query_and_hash("?foo=bar"), "/");
        assert_eq!(remove_query_and_hash("#section1"), "/");
    }

    #[test]
    fn test_remove_query_and_hash_question_mark_inside_fragment() {
        // A `?` after `#` is part of the fragment, not the start of a query string.
        assert_eq!(remove_query_and_hash("/messages#frag?x=1"), "/messages");
    }

    #[test]
    fn test_format_assertion_message() {
        assert_eq!(
            format_assertion_message("Server", "resources", "resources/list"),
            "Server does not support resources (required for resources/list)"
        );
    }
}