rust_mcp_sdk/
utils.rs

1use crate::error::{McpSdkError, ProtocolErrorKind, SdkResult};
2use crate::schema::schema_utils::{ClientMessages, SdkError};
3use crate::schema::ProtocolVersion;
4use std::cmp::Ordering;
5use std::time::{Duration, SystemTime, UNIX_EPOCH};
6#[cfg(feature = "auth")]
7use url::Url;
8
9/// A guard type that automatically aborts a Tokio task when dropped.
10///
11/// This ensures that the associated task does not outlive the scope
12/// of this struct, preventing runaway or leaked background tasks.
13///
14pub struct AbortTaskOnDrop {
15    /// The handle used to abort the spawned Tokio task.
16    pub handle: tokio::task::AbortHandle,
17}
18
19impl Drop for AbortTaskOnDrop {
20    fn drop(&mut self) {
21        // Automatically abort the associated task when this guard is dropped.
22        self.handle.abort();
23    }
24}
25
/// Formats an assertion error message for an unsupported capability.
///
/// Builds a message stating that an entity (e.g. the server or the client) does not support a
/// capability that is required by a particular method.
///
/// # Arguments
/// - `entity`: The name of the entity (e.g., "Server" or "Client") that lacks support.
/// - `capability`: The name of the unsupported capability (e.g., "resources" or "tools").
/// - `method_name`: The name of the method requiring the capability.
///
/// # Returns
/// A string of the form `"{entity} does not support {capability} (required for {method_name})"`.
///
/// # Examples
/// ```ignore
/// let msg = format_assertion_message("Server", "resources", rust_mcp_schema::ListResourcesRequest::method_name());
/// assert_eq!(msg, "Server does not support resources (required for resources/list)");
/// ```
pub fn format_assertion_message(entity: &str, capability: &str, method_name: &str) -> String {
    format!("{entity} does not support {capability} (required for {method_name})")
}
47
/// Converts a Unix timestamp, in whole seconds since the Unix epoch, into a [`SystemTime`].
///
/// # Panics
///
/// Panics if the resulting point in time cannot be represented by [`SystemTime`] on the
/// current platform, because adding a [`Duration`] to a `SystemTime` panics on overflow.
pub fn unix_timestamp_to_systemtime(timestamp: u64) -> SystemTime {
    let since_epoch = Duration::from_secs(timestamp);
    UNIX_EPOCH + since_epoch
}
52
53/// Checks if the client and server protocol versions are compatible by ensuring they are equal.
54///
55/// This function compares the provided client and server protocol versions. If they are equal,
56/// it returns `Ok(())`, indicating compatibility. If they differ (either the client version is
57/// lower or higher than the server version), it returns an error with details about the
58/// incompatible versions.
59///
60/// # Arguments
61///
62/// * `client_protocol_version` - A string slice representing the client's protocol version.
63/// * `server_protocol_version` - A string slice representing the server's protocol version.
64///
65/// # Returns
66///
67/// * `Ok(())` if the versions are equal.
68/// * `Err(McpSdkError::IncompatibleProtocolVersion)` if the versions differ, containing the
69///   client and server versions as strings.
70///
71/// # Examples
72///
73/// ```
74/// use rust_mcp_sdk::mcp_client::ensure_server_protocole_compatibility;
75/// use rust_mcp_sdk::error::McpSdkError;
76///
77/// // Compatible versions
78/// let result = ensure_server_protocole_compatibility("2024_11_05", "2024_11_05");
79/// assert!(result.is_ok());
80///
81/// // Incompatible versions (requested < current)
82/// let result = ensure_server_protocole_compatibility("2024_11_05", "2025_03_26");
83/// assert!(matches!(
84///     result,
85///     Err(McpSdkError::Protocol{kind: rust_mcp_sdk::error::ProtocolErrorKind::IncompatibleVersion {requested, current}})
86///     if requested == "2024_11_05" && current == "2025_03_26"
87/// ));
88///
89/// // Incompatible versions (requested > current)
90/// let result = ensure_server_protocole_compatibility("2025_03_26", "2024_11_05");
91/// assert!(matches!(
92///     result,
93///     Err(McpSdkError::Protocol{kind: rust_mcp_sdk::error::ProtocolErrorKind::IncompatibleVersion {requested, current}})
94///     if requested == "2025_03_26" && current == "2024_11_05"
95/// ));
96/// ```
97#[allow(unused)]
98pub fn ensure_server_protocole_compatibility(
99    client_protocol_version: &str,
100    server_protocol_version: &str,
101) -> SdkResult<()> {
102    match client_protocol_version.cmp(server_protocol_version) {
103        Ordering::Less | Ordering::Greater => Err(McpSdkError::Protocol {
104            kind: ProtocolErrorKind::IncompatibleVersion {
105                requested: client_protocol_version.to_string(),
106                current: server_protocol_version.to_string(),
107            },
108        }),
109        Ordering::Equal => Ok(()),
110    }
111}
112
113/// Enforces protocol version compatibility on for MCP Server , allowing the client to use a lower or equal version.
114///
115/// This function compares the client and server protocol versions. If the client version is
116/// higher than the server version, it returns an error indicating incompatibility. If the
117/// versions are equal, it returns `Ok(None)`, indicating no downgrade is needed. If the client
118/// version is lower, it returns `Ok(Some(client_protocol_version))`, suggesting the server
119/// can use the client's version for compatibility.
120///
121/// # Arguments
122///
123/// * `client_protocol_version` - The client's protocol version.
124/// * `server_protocol_version` - The server's protocol version.
125///
126/// # Returns
127///
128/// * `Ok(None)` if the versions are equal, indicating no downgrade is needed.
129/// * `Ok(Some(client_protocol_version))` if the client version is lower, returning the client
130///   version to use for compatibility.
131/// * `Err(McpSdkError::IncompatibleProtocolVersion)` if the client version is higher, containing
132///   the client and server versions as strings.
133///
134/// # Examples
135///
136/// ```
137/// use rust_mcp_sdk::mcp_server::enforce_compatible_protocol_version;
138/// use rust_mcp_sdk::error::McpSdkError;
139///
140/// // Equal versions
141/// let result = enforce_compatible_protocol_version("2024_11_05", "2024_11_05");
142/// assert!(matches!(result, Ok(None)));
143///
144/// // Client version lower (downgrade allowed)
145/// let result = enforce_compatible_protocol_version("2024_11_05", "2025_03_26");
146/// assert!(matches!(result, Ok(Some(ref v)) if v == "2024_11_05"));
147///
148/// // Client version higher (incompatible)
149/// let result = enforce_compatible_protocol_version("2025_03_26", "2024_11_05");
150/// assert!(matches!(
151///     result,
152///     Err(McpSdkError::Protocol{kind: rust_mcp_sdk::error::ProtocolErrorKind::IncompatibleVersion {requested, current}})
153///     if requested == "2025_03_26" && current == "2024_11_05"
154/// ));
155/// ```
156#[allow(unused)]
157pub fn enforce_compatible_protocol_version(
158    client_protocol_version: &str,
159    server_protocol_version: &str,
160) -> SdkResult<Option<String>> {
161    match client_protocol_version.cmp(server_protocol_version) {
162        // if client protocol version is higher
163        Ordering::Greater => Err(McpSdkError::Protocol {
164            kind: ProtocolErrorKind::IncompatibleVersion {
165                requested: client_protocol_version.to_string(),
166                current: server_protocol_version.to_string(),
167            },
168        }),
169        Ordering::Equal => Ok(None),
170        Ordering::Less => {
171            // return the same version that was received from the client
172            Ok(Some(client_protocol_version.to_string()))
173        }
174    }
175}
176
177pub fn validate_mcp_protocol_version(mcp_protocol_version: &str) -> SdkResult<()> {
178    let _mcp_protocol_version =
179        ProtocolVersion::try_from(mcp_protocol_version).map_err(|err| McpSdkError::Protocol {
180            kind: ProtocolErrorKind::ParseError(err),
181        })?;
182    Ok(())
183}
184
/// Removes the query string and hash fragment from an endpoint, returning only its base path.
///
/// The path ends at the first `?` or `#`, whichever comes first. A fragment may itself contain
/// a literal `?`, and cutting at the earliest delimiter handles that correctly. If nothing is
/// left (for example, the input is empty or starts with `?` or `#`), `"/"` is returned.
///
/// # Arguments
/// * `endpoint` - The URL or endpoint to process (e.g. `"/messages?foo=bar#section1"`).
///
/// # Returns
/// A `String` containing the base path without query parameters or fragment
/// (e.g. `"/messages"`).
#[allow(unused)]
pub(crate) fn remove_query_and_hash(endpoint: &str) -> String {
    let path = endpoint
        .find(|c: char| c == '?' || c == '#')
        .map_or(endpoint, |end| &endpoint[..end]);

    if path.is_empty() {
        "/".to_string()
    } else {
        path.to_string()
    }
}
210
211/// Checks if the input string is valid JSON and represents an "initialize" method request.
212pub fn valid_initialize_method(json_str: &str) -> SdkResult<()> {
213    // Attempt to deserialize the input string into ClientMessages
214    let Ok(request) = serde_json::from_str::<ClientMessages>(json_str) else {
215        return Err(SdkError::bad_request()
216            .with_message("Bad Request: Session not found")
217            .into());
218    };
219
220    match request {
221        ClientMessages::Single(client_message) => {
222            if !client_message.is_initialize_request() {
223                return Err(SdkError::bad_request()
224                    .with_message("Bad Request: Session not found")
225                    .into());
226            }
227        }
228        ClientMessages::Batch(client_messages) => {
229            let count = client_messages
230                .iter()
231                .filter(|item| item.is_initialize_request())
232                .count();
233            if count > 1 {
234                return Err(SdkError::invalid_request()
235                    .with_message("Bad Request: Only one initialization request is allowed")
236                    .into());
237            }
238        }
239    };
240
241    Ok(())
242}
243
/// Appends the `/`-separated `segment` to the path of `base`, returning the joined URL.
///
/// A trailing empty path segment of `base` (left by a trailing `/`) is removed first, so the
/// base behaves like a directory. Empty pieces of `segment` (from leading, trailing or repeated
/// slashes) are skipped. As a result, `http://example.com/api` and `http://example.com/api/`,
/// joined with either `user/userinfo` or `/user/userinfo`, all yield
/// `http://example.com/api/user/userinfo`.
///
/// # Errors
///
/// Returns [`url::ParseError::RelativeUrlWithoutBase`] if `base` cannot be a base, since such
/// URLs have no path segments that could be extended.
#[cfg(feature = "auth")]
pub fn join_url(base: &Url, segment: &str) -> Result<Url, url::ParseError> {
    let mut joined = base.clone();

    // Scope the mutable path borrow so `joined` can be returned afterwards.
    {
        // `path_segments_mut` fails exactly when the URL cannot be a base,
        // which also enforces the "base must be absolute" requirement.
        let mut path = joined
            .path_segments_mut()
            .map_err(|_| url::ParseError::RelativeUrlWithoutBase)?;

        path.pop_if_empty();
        for part in segment.split('/').filter(|part| !part.is_empty()) {
            path.push(part);
        }
    }

    Ok(joined)
}
267
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_remove_query_and_hash() {
        assert_eq!(remove_query_and_hash("/messages"), "/messages");
        assert_eq!(
            remove_query_and_hash("/messages?foo=bar&baz=qux"),
            "/messages"
        );
        assert_eq!(remove_query_and_hash("/messages#section1"), "/messages");
        assert_eq!(
            remove_query_and_hash("/messages?key=value#section2"),
            "/messages"
        );
        assert_eq!(
            remove_query_and_hash("/messages#frag?not-a-query"),
            "/messages"
        );
        assert_eq!(remove_query_and_hash("/"), "/");
        assert_eq!(remove_query_and_hash(""), "/");
        assert_eq!(remove_query_and_hash("?only=query"), "/");
    }

    // `join_url` and `Url` only exist when the `auth` feature is enabled, so this test must be
    // gated the same way or `cargo test` fails to compile without that feature.
    #[cfg(feature = "auth")]
    #[test]
    fn test_join_url() {
        let expect = "http://example.com/api/user/userinfo";
        let result = join_url(
            &Url::parse("http://example.com/api").unwrap(),
            "/user/userinfo",
        )
        .unwrap();
        assert_eq!(result.to_string(), expect);

        let result = join_url(
            &Url::parse("http://example.com/api").unwrap(),
            "user/userinfo",
        )
        .unwrap();
        assert_eq!(result.to_string(), expect);

        let result = join_url(
            &Url::parse("http://example.com/api/").unwrap(),
            "/user/userinfo",
        )
        .unwrap();
        assert_eq!(result.to_string(), expect);

        let result = join_url(
            &Url::parse("http://example.com/api/").unwrap(),
            "user/userinfo",
        )
        .unwrap();
        assert_eq!(result.to_string(), expect);
    }
}