rust_mcp_sdk/utils.rs
1use std::cmp::Ordering;
2
3use crate::error::{McpSdkError, SdkResult};
4
/// Builds the message reported when an entity lacks a capability that a method requires.
///
/// The three arguments are inserted verbatim into the template
/// `<entity> does not support <capability> (required for <method_name>)`.
///
/// # Arguments
/// - `entity`: Name of the side lacking support (e.g., "Server" or "Client").
/// - `capability`: Name of the capability that is not supported.
/// - `method_name`: Name of the method that needs the capability.
///
/// # Returns
/// The assembled message as an owned `String`.
///
/// # Examples
/// ```ignore
/// let msg = format_assertion_message("Server", "resources", "resources/list");
/// assert_eq!(msg, "Server does not support resources (required for resources/list)");
/// ```
pub fn format_assertion_message(entity: &str, capability: &str, method_name: &str) -> String {
    // Fixed fragments interleaved with the caller-supplied parts, joined in order.
    [
        entity,
        " does not support ",
        capability,
        " (required for ",
        method_name,
        ")",
    ]
    .concat()
}
26
27/// Checks if the client and server protocol versions are compatible by ensuring they are equal.
28///
29/// This function compares the provided client and server protocol versions. If they are equal,
30/// it returns `Ok(())`, indicating compatibility. If they differ (either the client version is
31/// lower or higher than the server version), it returns an error with details about the
32/// incompatible versions.
33///
34/// # Arguments
35///
36/// * `client_protocol_version` - A string slice representing the client's protocol version.
37/// * `server_protocol_version` - A string slice representing the server's protocol version.
38///
39/// # Returns
40///
41/// * `Ok(())` if the versions are equal.
42/// * `Err(McpSdkError::IncompatibleProtocolVersion)` if the versions differ, containing the
43/// client and server versions as strings.
44///
45/// # Examples
46///
47/// ```
48/// use rust_mcp_sdk::mcp_client::ensure_server_protocole_compatibility;
49/// use rust_mcp_sdk::error::McpSdkError;
50///
51/// // Compatible versions
52/// let result = ensure_server_protocole_compatibility("2024_11_05", "2024_11_05");
53/// assert!(result.is_ok());
54///
55/// // Incompatible versions (client < server)
56/// let result = ensure_server_protocole_compatibility("2024_11_05", "2025_03_26");
57/// assert!(matches!(
58/// result,
59/// Err(McpSdkError::IncompatibleProtocolVersion(client, server))
60/// if client == "2024_11_05" && server == "2025_03_26"
61/// ));
62///
63/// // Incompatible versions (client > server)
64/// let result = ensure_server_protocole_compatibility("2025_03_26", "2024_11_05");
65/// assert!(matches!(
66/// result,
67/// Err(McpSdkError::IncompatibleProtocolVersion(client, server))
68/// if client == "2025_03_26" && server == "2024_11_05"
69/// ));
70/// ```
71#[allow(unused)]
72pub fn ensure_server_protocole_compatibility(
73 client_protocol_version: &str,
74 server_protocol_version: &str,
75) -> SdkResult<()> {
76 match client_protocol_version.cmp(server_protocol_version) {
77 Ordering::Less | Ordering::Greater => Err(McpSdkError::IncompatibleProtocolVersion(
78 client_protocol_version.to_string(),
79 server_protocol_version.to_string(),
80 )),
81 Ordering::Equal => Ok(()),
82 }
83}
84
85/// Enforces protocol version compatibility on for MCP Server , allowing the client to use a lower or equal version.
86///
87/// This function compares the client and server protocol versions. If the client version is
88/// higher than the server version, it returns an error indicating incompatibility. If the
89/// versions are equal, it returns `Ok(None)`, indicating no downgrade is needed. If the client
90/// version is lower, it returns `Ok(Some(client_protocol_version))`, suggesting the server
91/// can use the client's version for compatibility.
92///
93/// # Arguments
94///
95/// * `client_protocol_version` - The client's protocol version.
96/// * `server_protocol_version` - The server's protocol version.
97///
98/// # Returns
99///
100/// * `Ok(None)` if the versions are equal, indicating no downgrade is needed.
101/// * `Ok(Some(client_protocol_version))` if the client version is lower, returning the client
102/// version to use for compatibility.
103/// * `Err(McpSdkError::IncompatibleProtocolVersion)` if the client version is higher, containing
104/// the client and server versions as strings.
105///
106/// # Examples
107///
108/// ```
109/// use rust_mcp_sdk::mcp_server::enforce_compatible_protocol_version;
110/// use rust_mcp_sdk::error::McpSdkError;
111///
112/// // Equal versions
113/// let result = enforce_compatible_protocol_version("2024_11_05", "2024_11_05");
114/// assert!(matches!(result, Ok(None)));
115///
116/// // Client version lower (downgrade allowed)
117/// let result = enforce_compatible_protocol_version("2024_11_05", "2025_03_26");
118/// assert!(matches!(result, Ok(Some(ref v)) if v == "2024_11_05"));
119///
120/// // Client version higher (incompatible)
121/// let result = enforce_compatible_protocol_version("2025_03_26", "2024_11_05");
122/// assert!(matches!(
123/// result,
124/// Err(McpSdkError::IncompatibleProtocolVersion(client, server))
125/// if client == "2025_03_26" && server == "2024_11_05"
126/// ));
127/// ```
128#[allow(unused)]
129pub fn enforce_compatible_protocol_version(
130 client_protocol_version: &str,
131 server_protocol_version: &str,
132) -> SdkResult<Option<String>> {
133 match client_protocol_version.cmp(server_protocol_version) {
134 // if client protocol version is higher
135 Ordering::Greater => Err(McpSdkError::IncompatibleProtocolVersion(
136 client_protocol_version.to_string(),
137 server_protocol_version.to_string(),
138 )),
139 Ordering::Equal => Ok(None),
140 Ordering::Less => {
141 // return the same version that was received from the client
142 Ok(Some(client_protocol_version.to_string()))
143 }
144 }
145}
146
/// Strips the query string and the hash fragment from an endpoint, keeping only the
/// part that precedes them.
///
/// Everything from the first `?` or `#` (whichever comes first) to the end of the
/// input is discarded. For instance, `/messages?foo=bar#section1` becomes `/messages`.
///
/// # Arguments
/// * `endpoint` - The URL or endpoint to process (e.g., "/messages?foo=bar#section1")
///
/// # Returns
/// An owned `String` with the remaining base path, or `"/"` when nothing is left
/// (empty input, or input that starts with `?` or `#`).
#[allow(unused)]
pub(crate) fn remove_query_and_hash(endpoint: &str) -> String {
    // Cutting at the first '?' or '#' is the same as dropping the fragment first
    // and the query second: a '?' that appears after a '#' belongs to the fragment.
    let base = endpoint
        .find(|c: char| matches!(c, '?' | '#'))
        .map_or(endpoint, |cut| &endpoint[..cut]);

    // Fall back to the root path when no path is left.
    match base {
        "" => String::from("/"),
        path => path.to_owned(),
    }
}
172
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that `remove_query_and_hash` drops the query string and the fragment,
    /// and falls back to "/" when no path remains.
    #[test]
    fn test_remove_query_and_hash() {
        // (endpoint, expected base path)
        let cases = [
            ("/messages", "/messages"),
            ("/messages?foo=bar&baz=qux", "/messages"),
            ("/messages#section1", "/messages"),
            ("/messages?key=value#section2", "/messages"),
            ("/", "/"),
            // Nothing left after stripping: the root path is returned.
            ("", "/"),
            ("?foo=bar", "/"),
            ("#section1", "/"),
            // A '?' inside the fragment is part of the fragment, not a query.
            ("/messages#section1?not=query", "/messages"),
        ];

        for &(endpoint, expected) in cases.iter() {
            assert_eq!(
                remove_query_and_hash(endpoint),
                expected,
                "unexpected base path for endpoint {endpoint:?}"
            );
        }
    }
}