rust_mcp_sdk/utils.rs
1use crate::error::{McpSdkError, ProtocolErrorKind, SdkResult};
2use crate::schema::schema_utils::{ClientMessages, SdkError};
3use crate::schema::ProtocolVersion;
4use std::cmp::Ordering;
5use std::time::{Duration, SystemTime, UNIX_EPOCH};
6#[cfg(feature = "auth")]
7use url::Url;
8
9/// A guard type that automatically aborts a Tokio task when dropped.
10///
11/// This ensures that the associated task does not outlive the scope
12/// of this struct, preventing runaway or leaked background tasks.
13///
14pub struct AbortTaskOnDrop {
15 /// The handle used to abort the spawned Tokio task.
16 pub handle: tokio::task::AbortHandle,
17}
18
19impl Drop for AbortTaskOnDrop {
20 fn drop(&mut self) {
21 // Automatically abort the associated task when this guard is dropped.
22 self.handle.abort();
23 }
24}
25
/// Builds the message reported when an entity lacks a capability a method needs.
///
/// The returned text has the shape
/// `"<entity> does not support <capability> (required for <method_name>)"`.
///
/// # Arguments
/// - `entity`: Name of the party missing the capability (e.g., "Server" or "Client").
/// - `capability`: Name of the capability that is not supported.
/// - `method_name`: Name of the method that requires the capability.
///
/// # Returns
/// The formatted message as an owned `String`.
///
/// # Examples
/// ```ignore
/// let msg = format_assertion_message("Server", "resources", "resources/list");
/// assert_eq!(msg, "Server does not support resources (required for resources/list)");
/// ```
pub fn format_assertion_message(entity: &str, capability: &str, method_name: &str) -> String {
    // Concatenate the fixed phrases with the three caller-supplied parts.
    [
        entity,
        " does not support ",
        capability,
        " (required for ",
        method_name,
        ")",
    ]
    .concat()
}
47
/// Converts a Unix timestamp into a [`SystemTime`].
///
/// # Arguments
/// * `timestamp` - Whole seconds elapsed since the Unix epoch.
///
/// # Returns
/// The point in time `timestamp` seconds after [`UNIX_EPOCH`].
///
/// # Panics
/// Adding a `Duration` to a `SystemTime` may panic if the result cannot be
/// represented by the platform's time type (only relevant for extreme values).
pub fn unix_timestamp_to_systemtime(timestamp: u64) -> SystemTime {
    let since_epoch = Duration::from_secs(timestamp);
    UNIX_EPOCH + since_epoch
}
52
53/// Checks if the client and server protocol versions are compatible by ensuring they are equal.
54///
55/// This function compares the provided client and server protocol versions. If they are equal,
56/// it returns `Ok(())`, indicating compatibility. If they differ (either the client version is
57/// lower or higher than the server version), it returns an error with details about the
58/// incompatible versions.
59///
60/// # Arguments
61///
62/// * `client_protocol_version` - A string slice representing the client's protocol version.
63/// * `server_protocol_version` - A string slice representing the server's protocol version.
64///
65/// # Returns
66///
67/// * `Ok(())` if the versions are equal.
68/// * `Err(McpSdkError::IncompatibleProtocolVersion)` if the versions differ, containing the
69/// client and server versions as strings.
70///
71/// # Examples
72///
73/// ```
74/// use rust_mcp_sdk::mcp_client::ensure_server_protocole_compatibility;
75/// use rust_mcp_sdk::error::McpSdkError;
76///
77/// // Compatible versions
78/// let result = ensure_server_protocole_compatibility("2024_11_05", "2024_11_05");
79/// assert!(result.is_ok());
80///
81/// // Incompatible versions (requested < current)
82/// let result = ensure_server_protocole_compatibility("2024_11_05", "2025_03_26");
83/// assert!(matches!(
84/// result,
85/// Err(McpSdkError::Protocol{kind: rust_mcp_sdk::error::ProtocolErrorKind::IncompatibleVersion {requested, current}})
86/// if requested == "2024_11_05" && current == "2025_03_26"
87/// ));
88///
89/// // Incompatible versions (requested > current)
90/// let result = ensure_server_protocole_compatibility("2025_03_26", "2024_11_05");
91/// assert!(matches!(
92/// result,
93/// Err(McpSdkError::Protocol{kind: rust_mcp_sdk::error::ProtocolErrorKind::IncompatibleVersion {requested, current}})
94/// if requested == "2025_03_26" && current == "2024_11_05"
95/// ));
96/// ```
97#[allow(unused)]
98pub fn ensure_server_protocole_compatibility(
99 client_protocol_version: &str,
100 server_protocol_version: &str,
101) -> SdkResult<()> {
102 match client_protocol_version.cmp(server_protocol_version) {
103 Ordering::Less | Ordering::Greater => Err(McpSdkError::Protocol {
104 kind: ProtocolErrorKind::IncompatibleVersion {
105 requested: client_protocol_version.to_string(),
106 current: server_protocol_version.to_string(),
107 },
108 }),
109 Ordering::Equal => Ok(()),
110 }
111}
112
113/// Enforces protocol version compatibility on for MCP Server , allowing the client to use a lower or equal version.
114///
115/// This function compares the client and server protocol versions. If the client version is
116/// higher than the server version, it returns an error indicating incompatibility. If the
117/// versions are equal, it returns `Ok(None)`, indicating no downgrade is needed. If the client
118/// version is lower, it returns `Ok(Some(client_protocol_version))`, suggesting the server
119/// can use the client's version for compatibility.
120///
121/// # Arguments
122///
123/// * `client_protocol_version` - The client's protocol version.
124/// * `server_protocol_version` - The server's protocol version.
125///
126/// # Returns
127///
128/// * `Ok(None)` if the versions are equal, indicating no downgrade is needed.
129/// * `Ok(Some(client_protocol_version))` if the client version is lower, returning the client
130/// version to use for compatibility.
131/// * `Err(McpSdkError::IncompatibleProtocolVersion)` if the client version is higher, containing
132/// the client and server versions as strings.
133///
134/// # Examples
135///
136/// ```
137/// use rust_mcp_sdk::mcp_server::enforce_compatible_protocol_version;
138/// use rust_mcp_sdk::error::McpSdkError;
139///
140/// // Equal versions
141/// let result = enforce_compatible_protocol_version("2024_11_05", "2024_11_05");
142/// assert!(matches!(result, Ok(None)));
143///
144/// // Client version lower (downgrade allowed)
145/// let result = enforce_compatible_protocol_version("2024_11_05", "2025_03_26");
146/// assert!(matches!(result, Ok(Some(ref v)) if v == "2024_11_05"));
147///
148/// // Client version higher (incompatible)
149/// let result = enforce_compatible_protocol_version("2025_03_26", "2024_11_05");
150/// assert!(matches!(
151/// result,
152/// Err(McpSdkError::Protocol{kind: rust_mcp_sdk::error::ProtocolErrorKind::IncompatibleVersion {requested, current}})
153/// if requested == "2025_03_26" && current == "2024_11_05"
154/// ));
155/// ```
156#[allow(unused)]
157pub fn enforce_compatible_protocol_version(
158 client_protocol_version: &str,
159 server_protocol_version: &str,
160) -> SdkResult<Option<String>> {
161 match client_protocol_version.cmp(server_protocol_version) {
162 // if client protocol version is higher
163 Ordering::Greater => Err(McpSdkError::Protocol {
164 kind: ProtocolErrorKind::IncompatibleVersion {
165 requested: client_protocol_version.to_string(),
166 current: server_protocol_version.to_string(),
167 },
168 }),
169 Ordering::Equal => Ok(None),
170 Ordering::Less => {
171 // return the same version that was received from the client
172 Ok(Some(client_protocol_version.to_string()))
173 }
174 }
175}
176
177pub fn validate_mcp_protocol_version(mcp_protocol_version: &str) -> SdkResult<()> {
178 let _mcp_protocol_version =
179 ProtocolVersion::try_from(mcp_protocol_version).map_err(|err| McpSdkError::Protocol {
180 kind: ProtocolErrorKind::ParseError(err),
181 })?;
182 Ok(())
183}
184
/// Strips the query string and hash fragment from an endpoint, leaving the base path.
///
/// Everything from the first `?` or `#` onward is discarded. If nothing remains,
/// `"/"` is returned.
///
/// # Arguments
/// * `endpoint` - The URL or endpoint to process (e.g., "/messages?foo=bar#section1")
///
/// # Returns
/// A `String` containing the base path without query parameters or fragment.
#[allow(unused)]
pub(crate) fn remove_query_and_hash(endpoint: &str) -> String {
    // The base path ends at whichever of `?` or `#` shows up first.
    let cut = endpoint
        .find(|c: char| c == '?' || c == '#')
        .unwrap_or(endpoint.len());

    match &endpoint[..cut] {
        // Nothing left before the query/fragment: fall back to the root path.
        "" => "/".to_string(),
        base_path => base_path.to_string(),
    }
}
210
211/// Checks if the input string is valid JSON and represents an "initialize" method request.
212pub fn valid_initialize_method(json_str: &str) -> SdkResult<()> {
213 // Attempt to deserialize the input string into ClientMessages
214 let Ok(request) = serde_json::from_str::<ClientMessages>(json_str) else {
215 return Err(SdkError::bad_request()
216 .with_message("Bad Request: Session not found")
217 .into());
218 };
219
220 match request {
221 ClientMessages::Single(client_message) => {
222 if !client_message.is_initialize_request() {
223 return Err(SdkError::bad_request()
224 .with_message("Bad Request: Session not found")
225 .into());
226 }
227 }
228 ClientMessages::Batch(client_messages) => {
229 let count = client_messages
230 .iter()
231 .filter(|item| item.is_initialize_request())
232 .count();
233 if count > 1 {
234 return Err(SdkError::invalid_request()
235 .with_message("Bad Request: Only one initialization request is allowed")
236 .into());
237 }
238 }
239 };
240
241 Ok(())
242}
243
/// Appends the path segments of `segment` to the path of `base`.
///
/// `base` is treated like a directory: a trailing empty path segment (as left by
/// a trailing `/`) is removed before appending, and empty segments in `segment`
/// (from leading, trailing or repeated `/`) are ignored. `base` is not modified;
/// the result is built on a clone.
///
/// # Errors
///
/// Returns `url::ParseError::RelativeUrlWithoutBase` when `base` cannot be a base
/// URL or does not expose mutable path segments.
#[cfg(feature = "auth")]
pub fn join_url(base: &Url, segment: &str) -> Result<Url, url::ParseError> {
    // Cheap up-front rejection of URLs that cannot carry path segments.
    if base.cannot_be_a_base() {
        return Err(url::ParseError::RelativeUrlWithoutBase);
    }

    // Work on a copy, since the caller only lends us the base URL.
    let mut joined = base.clone();
    {
        let mut path = joined
            .path_segments_mut()
            .map_err(|_| url::ParseError::RelativeUrlWithoutBase)?;

        // Drop a trailing empty segment so the base acts like a directory.
        path.pop_if_empty();

        // Skipping empty pieces also takes care of any leading slashes.
        path.extend(segment.split('/').filter(|piece| !piece.is_empty()));
    }

    Ok(joined)
}
267
#[cfg(test)]
mod tests {
    use super::*;

    /// Query strings and fragments are stripped; an empty result becomes "/".
    #[test]
    fn test_remove_query_and_hash() {
        assert_eq!(remove_query_and_hash("/messages"), "/messages");
        assert_eq!(
            remove_query_and_hash("/messages?foo=bar&baz=qux"),
            "/messages"
        );
        assert_eq!(remove_query_and_hash("/messages#section1"), "/messages");
        assert_eq!(
            remove_query_and_hash("/messages?key=value#section2"),
            "/messages"
        );
        assert_eq!(remove_query_and_hash("/"), "/");

        // Edge cases: nothing left before the query/fragment falls back to "/".
        assert_eq!(remove_query_and_hash(""), "/");
        assert_eq!(remove_query_and_hash("?foo=bar"), "/");
        assert_eq!(remove_query_and_hash("#section"), "/");

        // A `?` that appears after the `#` belongs to the fragment.
        assert_eq!(
            remove_query_and_hash("/messages#frag?not=query"),
            "/messages"
        );
    }

    /// Every combination of trailing slash on the base and leading slash on the
    /// segment yields the same joined URL.
    // `join_url` and `Url` only exist with the `auth` feature, so this test must be
    // gated the same way or the test module fails to compile without that feature.
    #[cfg(feature = "auth")]
    #[test]
    fn test_join_url() {
        let expect = "http://example.com/api/user/userinfo";

        let cases = [
            ("http://example.com/api", "/user/userinfo"),
            ("http://example.com/api", "user/userinfo"),
            ("http://example.com/api/", "/user/userinfo"),
            ("http://example.com/api/", "user/userinfo"),
        ];

        for (base, segment) in cases {
            let result = join_url(&Url::parse(base).unwrap(), segment).unwrap();
            assert_eq!(
                result.to_string(),
                expect,
                "base={base:?} segment={segment:?}"
            );
        }
    }
}