rust_mcp_sdk/
utils.rs

1use crate::schema::schema_utils::{ClientMessages, SdkError};
2
3use crate::error::{McpSdkError, SdkResult};
4use crate::schema::ProtocolVersion;
5use std::cmp::Ordering;
6
7/// A guard type that automatically aborts a Tokio task when dropped.
8///
9/// This ensures that the associated task does not outlive the scope
10/// of this struct, preventing runaway or leaked background tasks.
11///
12pub struct AbortTaskOnDrop {
13    /// The handle used to abort the spawned Tokio task.
14    pub handle: tokio::task::AbortHandle,
15}
16
17impl Drop for AbortTaskOnDrop {
18    fn drop(&mut self) {
19        // Automatically abort the associated task when this guard is dropped.
20        self.handle.abort();
21    }
22}
23
/// Formats an assertion error message for an unsupported capability.
///
/// Builds a string stating that a specific entity (e.g. server or client) lacks support for
/// a capability that a particular method requires.
///
/// # Arguments
/// - `entity`: Name of the entity (e.g. `"Server"` or `"Client"`) that lacks support.
/// - `capability`: Name of the unsupported capability (e.g. `"resources"`, `"tools"`).
/// - `method_name`: Name of the method that requires the capability.
///
/// # Returns
/// A message of the form `"{entity} does not support {capability} (required for {method_name})"`.
///
/// # Examples
/// ```ignore
/// let msg = format_assertion_message("Server", "resources", rust_mcp_schema::ListResourcesRequest::method_name());
/// assert_eq!(msg, "Server does not support resources (required for resources/list)");
/// ```
pub fn format_assertion_message(entity: &str, capability: &str, method_name: &str) -> String {
    format!("{entity} does not support {capability} (required for {method_name})")
}
45
46/// Checks if the client and server protocol versions are compatible by ensuring they are equal.
47///
48/// This function compares the provided client and server protocol versions. If they are equal,
49/// it returns `Ok(())`, indicating compatibility. If they differ (either the client version is
50/// lower or higher than the server version), it returns an error with details about the
51/// incompatible versions.
52///
53/// # Arguments
54///
55/// * `client_protocol_version` - A string slice representing the client's protocol version.
56/// * `server_protocol_version` - A string slice representing the server's protocol version.
57///
58/// # Returns
59///
60/// * `Ok(())` if the versions are equal.
61/// * `Err(McpSdkError::IncompatibleProtocolVersion)` if the versions differ, containing the
62///   client and server versions as strings.
63///
64/// # Examples
65///
66/// ```
67/// use rust_mcp_sdk::mcp_client::ensure_server_protocole_compatibility;
68/// use rust_mcp_sdk::error::McpSdkError;
69///
70/// // Compatible versions
71/// let result = ensure_server_protocole_compatibility("2024_11_05", "2024_11_05");
72/// assert!(result.is_ok());
73///
74/// // Incompatible versions (client < server)
75/// let result = ensure_server_protocole_compatibility("2024_11_05", "2025_03_26");
76/// assert!(matches!(
77///     result,
78///     Err(McpSdkError::IncompatibleProtocolVersion(client, server))
79///     if client == "2024_11_05" && server == "2025_03_26"
80/// ));
81///
82/// // Incompatible versions (client > server)
83/// let result = ensure_server_protocole_compatibility("2025_03_26", "2024_11_05");
84/// assert!(matches!(
85///     result,
86///     Err(McpSdkError::IncompatibleProtocolVersion(client, server))
87///     if client == "2025_03_26" && server == "2024_11_05"
88/// ));
89/// ```
90#[allow(unused)]
91pub fn ensure_server_protocole_compatibility(
92    client_protocol_version: &str,
93    server_protocol_version: &str,
94) -> SdkResult<()> {
95    match client_protocol_version.cmp(server_protocol_version) {
96        Ordering::Less | Ordering::Greater => Err(McpSdkError::IncompatibleProtocolVersion(
97            client_protocol_version.to_string(),
98            server_protocol_version.to_string(),
99        )),
100        Ordering::Equal => Ok(()),
101    }
102}
103
104/// Enforces protocol version compatibility on for MCP Server , allowing the client to use a lower or equal version.
105///
106/// This function compares the client and server protocol versions. If the client version is
107/// higher than the server version, it returns an error indicating incompatibility. If the
108/// versions are equal, it returns `Ok(None)`, indicating no downgrade is needed. If the client
109/// version is lower, it returns `Ok(Some(client_protocol_version))`, suggesting the server
110/// can use the client's version for compatibility.
111///
112/// # Arguments
113///
114/// * `client_protocol_version` - The client's protocol version.
115/// * `server_protocol_version` - The server's protocol version.
116///
117/// # Returns
118///
119/// * `Ok(None)` if the versions are equal, indicating no downgrade is needed.
120/// * `Ok(Some(client_protocol_version))` if the client version is lower, returning the client
121///   version to use for compatibility.
122/// * `Err(McpSdkError::IncompatibleProtocolVersion)` if the client version is higher, containing
123///   the client and server versions as strings.
124///
125/// # Examples
126///
127/// ```
128/// use rust_mcp_sdk::mcp_server::enforce_compatible_protocol_version;
129/// use rust_mcp_sdk::error::McpSdkError;
130///
131/// // Equal versions
132/// let result = enforce_compatible_protocol_version("2024_11_05", "2024_11_05");
133/// assert!(matches!(result, Ok(None)));
134///
135/// // Client version lower (downgrade allowed)
136/// let result = enforce_compatible_protocol_version("2024_11_05", "2025_03_26");
137/// assert!(matches!(result, Ok(Some(ref v)) if v == "2024_11_05"));
138///
139/// // Client version higher (incompatible)
140/// let result = enforce_compatible_protocol_version("2025_03_26", "2024_11_05");
141/// assert!(matches!(
142///     result,
143///     Err(McpSdkError::IncompatibleProtocolVersion(client, server))
144///     if client == "2025_03_26" && server == "2024_11_05"
145/// ));
146/// ```
147#[allow(unused)]
148pub fn enforce_compatible_protocol_version(
149    client_protocol_version: &str,
150    server_protocol_version: &str,
151) -> SdkResult<Option<String>> {
152    match client_protocol_version.cmp(server_protocol_version) {
153        // if client protocol version is higher
154        Ordering::Greater => Err(McpSdkError::IncompatibleProtocolVersion(
155            client_protocol_version.to_string(),
156            server_protocol_version.to_string(),
157        )),
158        Ordering::Equal => Ok(None),
159        Ordering::Less => {
160            // return the same version that was received from the client
161            Ok(Some(client_protocol_version.to_string()))
162        }
163    }
164}
165
166pub fn validate_mcp_protocol_version(mcp_protocol_version: &str) -> SdkResult<()> {
167    let _mcp_protocol_version = ProtocolVersion::try_from(mcp_protocol_version)?;
168    Ok(())
169}
170
/// Removes the query string and fragment from a URL or endpoint, returning only its base path.
///
/// Everything from the first `?` or `#` onwards is discarded, whichever delimiter comes first.
/// If nothing remains (an empty input, or one that starts with `?` or `#`), `"/"` is returned.
///
/// # Arguments
/// * `endpoint` - The URL or endpoint to process (e.g. `"/messages?foo=bar#section1"`).
///
/// # Returns
/// A `String` containing the base path without query parameters or fragment.
///
/// # Examples
/// ```ignore
/// assert_eq!(remove_query_and_hash("/messages?foo=bar#section1"), "/messages");
/// assert_eq!(remove_query_and_hash("?foo=bar"), "/");
/// ```
#[allow(unused)]
pub(crate) fn remove_query_and_hash(endpoint: &str) -> String {
    // The path ends at the earliest `?` or `#`. A `?` inside a fragment (or a `#` inside a
    // query) belongs to that trailing part, so a single split on both delimiters is enough.
    // `split` always yields at least one item, so the fallback is never actually used.
    let base_path = endpoint.split(['?', '#']).next().unwrap_or(endpoint);

    if base_path.is_empty() {
        "/".to_string()
    } else {
        base_path.to_string()
    }
}
196
197/// Checks if the input string is valid JSON and represents an "initialize" method request.
198pub fn valid_initialize_method(json_str: &str) -> SdkResult<()> {
199    // Attempt to deserialize the input string into ClientMessages
200    let Ok(request) = serde_json::from_str::<ClientMessages>(json_str) else {
201        return Err(SdkError::bad_request()
202            .with_message("Bad Request: Session not found")
203            .into());
204    };
205
206    match request {
207        ClientMessages::Single(client_message) => {
208            if !client_message.is_initialize_request() {
209                return Err(SdkError::bad_request()
210                    .with_message("Bad Request: Session not found")
211                    .into());
212            }
213        }
214        ClientMessages::Batch(client_messages) => {
215            let count = client_messages
216                .iter()
217                .filter(|item| item.is_initialize_request())
218                .count();
219            if count > 1 {
220                return Err(SdkError::invalid_request()
221                    .with_message("Bad Request: Only one initialization request is allowed")
222                    .into());
223            }
224        }
225    };
226
227    Ok(())
228}
229
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_remove_query_and_hash() {
        assert_eq!(remove_query_and_hash("/messages"), "/messages");
        assert_eq!(
            remove_query_and_hash("/messages?foo=bar&baz=qux"),
            "/messages"
        );
        assert_eq!(remove_query_and_hash("/messages#section1"), "/messages");
        assert_eq!(
            remove_query_and_hash("/messages?key=value#section2"),
            "/messages"
        );
        assert_eq!(remove_query_and_hash("/"), "/");
    }

    #[test]
    fn test_remove_query_and_hash_edge_cases() {
        // An empty base path is normalised to the root path.
        assert_eq!(remove_query_and_hash(""), "/");
        assert_eq!(remove_query_and_hash("?foo=bar"), "/");
        assert_eq!(remove_query_and_hash("#section"), "/");
        // The earliest delimiter wins, whichever one it is.
        assert_eq!(
            remove_query_and_hash("/messages#frag?not-a-query"),
            "/messages"
        );
        // Absolute URLs keep their scheme and authority.
        assert_eq!(
            remove_query_and_hash("http://localhost:8080/mcp?session=1"),
            "http://localhost:8080/mcp"
        );
    }

    #[test]
    fn test_format_assertion_message() {
        assert_eq!(
            format_assertion_message("Server", "resources", "resources/list"),
            "Server does not support resources (required for resources/list)"
        );
    }

    #[test]
    fn test_ensure_server_protocole_compatibility() {
        assert!(ensure_server_protocole_compatibility("2025-03-26", "2025-03-26").is_ok());
        assert!(matches!(
            ensure_server_protocole_compatibility("2024-11-05", "2025-03-26"),
            Err(McpSdkError::IncompatibleProtocolVersion(client, server))
            if client == "2024-11-05" && server == "2025-03-26"
        ));
        assert!(matches!(
            ensure_server_protocole_compatibility("2025-03-26", "2024-11-05"),
            Err(McpSdkError::IncompatibleProtocolVersion(client, server))
            if client == "2025-03-26" && server == "2024-11-05"
        ));
    }

    #[test]
    fn test_enforce_compatible_protocol_version() {
        assert!(matches!(
            enforce_compatible_protocol_version("2025-03-26", "2025-03-26"),
            Ok(None)
        ));
        assert!(matches!(
            enforce_compatible_protocol_version("2024-11-05", "2025-03-26"),
            Ok(Some(ref v)) if v == "2024-11-05"
        ));
        assert!(matches!(
            enforce_compatible_protocol_version("2025-03-26", "2024-11-05"),
            Err(McpSdkError::IncompatibleProtocolVersion(client, server))
            if client == "2025-03-26" && server == "2024-11-05"
        ));
    }

    #[test]
    fn test_valid_initialize_method_rejects_non_initialize_payloads() {
        // Not JSON at all.
        assert!(valid_initialize_method("not json").is_err());
        // Well-formed JSON-RPC, but not an initialize request.
        assert!(valid_initialize_method(r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#).is_err());
    }
}