turbomcp_protocol/context/request.rs

1//! Request and response context types for MCP request handling.
2//!
3//! This module contains the core context types used throughout the MCP protocol
4//! implementation for tracking request metadata, response information, and analytics.
5
6use std::collections::HashMap;
7use std::fmt;
8use std::sync::Arc;
9use std::time::Instant;
10
11use chrono::{DateTime, Utc};
12use serde::{Deserialize, Serialize};
13use tokio_util::sync::CancellationToken;
14use uuid::Uuid;
15
16use super::capabilities::ServerToClientRequests;
17use crate::types::Timestamp;
18
19/// Context information for a single MCP request, carried through its entire lifecycle.
20///
21/// This struct contains essential metadata for processing, logging, and tracing a request,
22/// including unique identifiers, authentication information, and mechanisms for
23/// cancellation and server-initiated communication.
24#[derive(Clone)]
25pub struct RequestContext {
26    /// A unique identifier for the request, typically a UUID.
27    pub request_id: String,
28
29    /// The identifier for the user making the request, if authenticated.
30    pub user_id: Option<String>,
31
32    /// The identifier for the session to which this request belongs.
33    pub session_id: Option<String>,
34
35    /// The identifier for the client application making the request.
36    pub client_id: Option<String>,
37
38    /// The timestamp when the request was received.
39    pub timestamp: Timestamp,
40
41    /// The `Instant` when request processing started, used for performance tracking.
42    pub start_time: Instant,
43
44    /// A collection of custom metadata for application-specific use cases.
45    pub metadata: Arc<HashMap<String, serde_json::Value>>,
46
47    /// The tracing span associated with this request for observability.
48    #[cfg(feature = "tracing")]
49    pub span: Option<tracing::Span>,
50
51    /// A token that can be used to signal cancellation of the request.
52    pub cancellation_token: Option<Arc<CancellationToken>>,
53
54    /// An interface for making server-initiated requests back to the client (e.g., sampling, elicitation).
55    /// This is hidden from public docs as it's an internal detail injected by the server.
56    #[doc(hidden)]
57    pub(crate) server_to_client: Option<Arc<dyn ServerToClientRequests>>,
58}
59
60impl fmt::Debug for RequestContext {
61    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62        f.debug_struct("RequestContext")
63            .field("request_id", &self.request_id)
64            .field("user_id", &self.user_id)
65            .field("session_id", &self.session_id)
66            .field("client_id", &self.client_id)
67            .field("timestamp", &self.timestamp)
68            .field("metadata", &self.metadata)
69            .field("server_to_client", &self.server_to_client.is_some())
70            .finish()
71    }
72}
73
74/// Context information generated after processing a request, containing response details.
75#[derive(Debug, Clone)]
76pub struct ResponseContext {
77    /// The ID of the original request this response is for.
78    pub request_id: String,
79
80    /// The timestamp when the response was generated.
81    pub timestamp: Timestamp,
82
83    /// The total time taken to process the request.
84    pub duration: std::time::Duration,
85
86    /// The status of the response (e.g., Success, Error).
87    pub status: ResponseStatus,
88
89    /// A collection of custom metadata for the response.
90    pub metadata: Arc<HashMap<String, serde_json::Value>>,
91}
92
93/// Represents the status of an MCP response.
94#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
95pub enum ResponseStatus {
96    /// The request was processed successfully.
97    Success,
98    /// An error occurred during request processing.
99    Error {
100        /// A numeric code indicating the error type.
101        code: i32,
102        /// A human-readable message describing the error.
103        message: String,
104    },
105    /// The response is partial, indicating more data will follow (for streaming).
106    Partial,
107    /// The request was cancelled before completion.
108    Cancelled,
109}
110
111/// Contains analytics information for a single request, used for monitoring and debugging.
112#[derive(Debug, Clone, Serialize, Deserialize)]
113pub struct RequestInfo {
114    /// The timestamp when the request was received.
115    pub timestamp: DateTime<Utc>,
116    /// The identifier of the client that made the request.
117    pub client_id: String,
118    /// The name of the tool or method that was called.
119    pub method_name: String,
120    /// The parameters provided in the request, potentially sanitized for privacy.
121    pub parameters: serde_json::Value,
122    /// The total time taken to generate a response, in milliseconds.
123    pub response_time_ms: Option<u64>,
124    /// A boolean indicating whether the request was successful.
125    pub success: bool,
126    /// The error message, if the request failed.
127    pub error_message: Option<String>,
128    /// The HTTP status code, if the request was handled over HTTP.
129    pub status_code: Option<u16>,
130    /// Additional custom metadata for analytics.
131    pub metadata: HashMap<String, serde_json::Value>,
132}
133
134impl RequestContext {
135    /// Creates a new `RequestContext` with a generated UUIDv4 as the request ID.
136    #[must_use]
137    pub fn new() -> Self {
138        Self {
139            request_id: Uuid::new_v4().to_string(),
140            user_id: None,
141            session_id: None,
142            client_id: None,
143            timestamp: Timestamp::now(),
144            start_time: Instant::now(),
145            metadata: Arc::new(HashMap::new()),
146            #[cfg(feature = "tracing")]
147            span: None,
148            cancellation_token: None,
149            server_to_client: None,
150        }
151    }
152
153    /// Creates a new `RequestContext` with a specific request ID.
154    pub fn with_id(id: impl Into<String>) -> Self {
155        Self {
156            request_id: id.into(),
157            ..Self::new()
158        }
159    }
160
161    /// Sets the user ID for this context, returning the modified context.
162    ///
163    /// # Example
164    /// ```
165    /// # use turbomcp_protocol::context::RequestContext;
166    /// let ctx = RequestContext::new().with_user_id("user-123");
167    /// assert_eq!(ctx.user_id, Some("user-123".to_string()));
168    /// ```
169    #[must_use]
170    pub fn with_user_id(mut self, user_id: impl Into<String>) -> Self {
171        self.user_id = Some(user_id.into());
172        self
173    }
174
175    /// Sets the session ID for this context, returning the modified context.
176    #[must_use]
177    pub fn with_session_id(mut self, session_id: impl Into<String>) -> Self {
178        self.session_id = Some(session_id.into());
179        self
180    }
181
182    /// Sets the client ID for this context, returning the modified context.
183    #[must_use]
184    pub fn with_client_id(mut self, client_id: impl Into<String>) -> Self {
185        self.client_id = Some(client_id.into());
186        self
187    }
188
189    /// Adds a key-value pair to the metadata, returning the modified context.
190    ///
191    /// # Example
192    /// ```
193    /// # use turbomcp_protocol::context::RequestContext;
194    /// # use serde_json::json;
195    /// let ctx = RequestContext::new().with_metadata("tenant", json!("acme-corp"));
196    /// assert_eq!(ctx.get_metadata("tenant"), Some(&json!("acme-corp")));
197    /// ```
198    #[must_use]
199    pub fn with_metadata(
200        mut self,
201        key: impl Into<String>,
202        value: impl Into<serde_json::Value>,
203    ) -> Self {
204        Arc::make_mut(&mut self.metadata).insert(key.into(), value.into());
205        self
206    }
207
208    /// Retrieves a value from the metadata by key.
209    #[must_use]
210    pub fn get_metadata(&self, key: &str) -> Option<&serde_json::Value> {
211        self.metadata.get(key)
212    }
213
214    /// Returns the elapsed time since the request processing started.
215    #[must_use]
216    pub fn elapsed(&self) -> std::time::Duration {
217        self.start_time.elapsed()
218    }
219
220    /// Checks if the request has been marked for cancellation.
221    #[must_use]
222    pub fn is_cancelled(&self) -> bool {
223        self.cancellation_token
224            .as_ref()
225            .is_some_and(|token| token.is_cancelled())
226    }
227
228    /// Sets the server-to-client requests interface for this context.
229    ///
230    /// This enables tools to make server-initiated requests (sampling, elicitation, roots)
231    /// with full context propagation for tracing and attribution. This is typically called
232    /// by the server implementation.
233    #[must_use]
234    pub fn with_server_to_client(mut self, capabilities: Arc<dyn ServerToClientRequests>) -> Self {
235        self.server_to_client = Some(capabilities);
236        self
237    }
238
239    /// Sets the cancellation token for cooperative cancellation.
240    /// This is typically called by the server implementation.
241    #[must_use]
242    pub fn with_cancellation_token(mut self, token: Arc<CancellationToken>) -> Self {
243        self.cancellation_token = Some(token);
244        self
245    }
246
247    /// Returns the user ID from the request context, if available.
248    #[must_use]
249    pub fn user(&self) -> Option<&str> {
250        self.user_id.as_deref()
251    }
252
253    /// Checks if the request is from an authenticated client.
254    /// This is determined by metadata set during the authentication process.
255    #[must_use]
256    pub fn is_authenticated(&self) -> bool {
257        self.get_metadata("client_authenticated")
258            .and_then(|v| v.as_bool())
259            .unwrap_or(false)
260    }
261
262    /// Returns the user roles from the request context, if available.
263    /// Roles are typically populated from an authentication token.
264    #[must_use]
265    pub fn roles(&self) -> Vec<String> {
266        self.get_metadata("auth")
267            .and_then(|auth| auth.get("roles"))
268            .and_then(|roles| roles.as_array())
269            .map(|roles| {
270                roles
271                    .iter()
272                    .filter_map(|role| role.as_str().map(ToString::to_string))
273                    .collect()
274            })
275            .unwrap_or_default()
276    }
277
278    /// Checks if the user has any of the specified roles.
279    /// Returns `true` if the required roles list is empty or if the user has at least one of the roles.
280    pub fn has_any_role<S: AsRef<str>>(&self, required: &[S]) -> bool {
281        if required.is_empty() {
282            return true; // Empty requirement always passes
283        }
284
285        let user_roles = self.roles();
286        required
287            .iter()
288            .any(|required_role| user_roles.contains(&required_role.as_ref().to_string()))
289    }
290
291    /// Gets the server-to-client requests interface.
292    ///
293    /// Returns `None` if not configured (e.g., for unidirectional transports).
294    /// This is hidden from public docs as it's an internal detail for use by server tools.
295    #[doc(hidden)]
296    pub fn server_to_client(&self) -> Option<&Arc<dyn ServerToClientRequests>> {
297        self.server_to_client.as_ref()
298    }
299
300    /// Returns all HTTP headers from the request, if available.
301    ///
302    /// Headers are automatically extracted by HTTP and WebSocket transports and stored
303    /// in the context metadata. Returns `None` if not using an HTTP-based transport
304    /// or if headers were not extracted.
305    ///
306    /// # Example
307    /// ```
308    /// # use turbomcp_protocol::RequestContext;
309    /// # let ctx = RequestContext::new();
310    /// if let Some(headers) = ctx.headers() {
311    ///     for (name, value) in headers.iter() {
312    ///         println!("{}: {}", name, value);
313    ///     }
314    /// }
315    /// ```
316    #[must_use]
317    pub fn headers(&self) -> Option<HashMap<String, String>> {
318        self.get_metadata("http_headers")
319            .and_then(|v| serde_json::from_value(v.clone()).ok())
320    }
321
322    /// Returns a specific HTTP header value by name (case-insensitive).
323    ///
324    /// This method performs case-insensitive header lookup, as per HTTP specification.
325    /// Returns `None` if the header is not present or if not using an HTTP-based transport.
326    ///
327    /// # Example
328    /// ```
329    /// # use turbomcp_protocol::RequestContext;
330    /// # let ctx = RequestContext::new();
331    /// if let Some(user_agent) = ctx.header("user-agent") {
332    ///     println!("User-Agent: {}", user_agent);
333    /// }
334    /// ```
335    #[must_use]
336    pub fn header(&self, name: &str) -> Option<String> {
337        let headers = self.headers()?;
338        let name_lower = name.to_lowercase();
339
340        // HTTP headers are case-insensitive, so we need to search with lowercase comparison
341        headers
342            .iter()
343            .find(|(key, _)| key.to_lowercase() == name_lower)
344            .map(|(_, value)| value.clone())
345    }
346
347    /// Returns the transport type used for this request.
348    ///
349    /// Common transport types include: "http", "websocket", "stdio", "tcp", "unix".
350    /// Returns `None` if transport metadata is not set.
351    ///
352    /// # Example
353    /// ```
354    /// # use turbomcp_protocol::RequestContext;
355    /// # let ctx = RequestContext::new();
356    /// if let Some(transport) = ctx.transport() {
357    ///     println!("Request received via: {}", transport);
358    /// }
359    /// ```
360    #[must_use]
361    pub fn transport(&self) -> Option<String> {
362        self.get_metadata("transport")
363            .and_then(|v| v.as_str())
364            .map(|s| s.to_string())
365    }
366}
367
368impl Default for RequestContext {
369    fn default() -> Self {
370        Self::new()
371    }
372}
373
374impl ResponseContext {
375    /// Creates a new `ResponseContext` for a successful operation.
376    pub fn success(request_id: impl Into<String>, duration: std::time::Duration) -> Self {
377        Self {
378            request_id: request_id.into(),
379            timestamp: Timestamp::now(),
380            duration,
381            status: ResponseStatus::Success,
382            metadata: Arc::new(HashMap::new()),
383        }
384    }
385
386    /// Creates a new `ResponseContext` for a failed operation.
387    pub fn error(
388        request_id: impl Into<String>,
389        duration: std::time::Duration,
390        code: i32,
391        message: impl Into<String>,
392    ) -> Self {
393        Self {
394            request_id: request_id.into(),
395            timestamp: Timestamp::now(),
396            duration,
397            status: ResponseStatus::Error {
398                code,
399                message: message.into(),
400            },
401            metadata: Arc::new(HashMap::new()),
402        }
403    }
404}
405
406impl RequestInfo {
407    /// Creates a new `RequestInfo` for analytics.
408    #[must_use]
409    pub fn new(client_id: String, method_name: String, parameters: serde_json::Value) -> Self {
410        Self {
411            timestamp: Utc::now(),
412            client_id,
413            method_name,
414            parameters,
415            response_time_ms: None,
416            success: false,
417            error_message: None,
418            status_code: None,
419            metadata: HashMap::new(),
420        }
421    }
422
423    /// Marks the request as completed successfully and records the response time.
424    #[must_use]
425    pub const fn complete_success(mut self, response_time_ms: u64) -> Self {
426        self.response_time_ms = Some(response_time_ms);
427        self.success = true;
428        self.status_code = Some(200);
429        self
430    }
431
432    /// Marks the request as failed and records the response time and error message.
433    #[must_use]
434    pub fn complete_error(mut self, response_time_ms: u64, error: String) -> Self {
435        self.response_time_ms = Some(response_time_ms);
436        self.success = false;
437        self.error_message = Some(error);
438        self.status_code = Some(500);
439        self
440    }
441
442    /// Sets the HTTP status code for this request.
443    #[must_use]
444    pub const fn with_status_code(mut self, code: u16) -> Self {
445        self.status_code = Some(code);
446        self
447    }
448
449    /// Adds a key-value pair to the analytics metadata.
450    #[must_use]
451    pub fn with_metadata(mut self, key: String, value: serde_json::Value) -> Self {
452        self.metadata.insert(key, value);
453        self
454    }
455}
456
457/// An extension trait for `RequestContext` providing enhanced client ID handling.
458pub trait RequestContextExt {
459    /// Sets the client ID using the structured `ClientId` enum, which includes the method of identification.
460    #[must_use]
461    fn with_enhanced_client_id(self, client_id: super::client::ClientId) -> Self;
462
463    /// Extracts a client ID from headers or query parameters and sets it on the context.
464    #[must_use]
465    fn extract_client_id(
466        self,
467        extractor: &super::client::ClientIdExtractor,
468        headers: Option<&HashMap<String, String>>,
469        query_params: Option<&HashMap<String, String>>,
470    ) -> Self;
471
472    /// Gets the structured `ClientId` enum from the context, if available.
473    fn get_enhanced_client_id(&self) -> Option<super::client::ClientId>;
474}
475
476impl RequestContextExt for RequestContext {
477    fn with_enhanced_client_id(self, client_id: super::client::ClientId) -> Self {
478        self.with_client_id(client_id.as_str())
479            .with_metadata(
480                "client_id_method".to_string(),
481                serde_json::Value::String(client_id.auth_method().to_string()),
482            )
483            .with_metadata(
484                "client_authenticated".to_string(),
485                serde_json::Value::Bool(client_id.is_authenticated()),
486            )
487    }
488
489    fn extract_client_id(
490        self,
491        extractor: &super::client::ClientIdExtractor,
492        headers: Option<&HashMap<String, String>>,
493        query_params: Option<&HashMap<String, String>>,
494    ) -> Self {
495        let client_id = extractor.extract_client_id(headers, query_params);
496        self.with_enhanced_client_id(client_id)
497    }
498
499    fn get_enhanced_client_id(&self) -> Option<super::client::ClientId> {
500        self.client_id.as_ref().map(|id| {
501            let method = self
502                .get_metadata("client_id_method")
503                .and_then(|v| v.as_str())
504                .unwrap_or("header");
505
506            match method {
507                "bearer_token" => super::client::ClientId::Token(id.clone()),
508                "session_cookie" => super::client::ClientId::Session(id.clone()),
509                "query_param" => super::client::ClientId::QueryParam(id.clone()),
510                "user_agent" => super::client::ClientId::UserAgent(id.clone()),
511                "anonymous" => super::client::ClientId::Anonymous,
512                _ => super::client::ClientId::Header(id.clone()), // Default to header for "header" and unknown methods
513            }
514        })
515    }
516}
517
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_headers_returns_none_when_not_set() {
        let ctx = RequestContext::new();
        assert!(ctx.headers().is_none());
    }

    #[test]
    fn test_headers_returns_headers_when_set() {
        let mut headers_map = HashMap::new();
        headers_map.insert("user-agent".to_string(), "Test-Agent/1.0".to_string());
        headers_map.insert("content-type".to_string(), "application/json".to_string());

        let headers_json = serde_json::to_value(&headers_map).unwrap();
        let ctx = RequestContext::new().with_metadata("http_headers", headers_json);

        let headers = ctx.headers();
        assert!(headers.is_some());

        let headers = headers.unwrap();
        assert_eq!(headers.len(), 2);
        assert_eq!(
            headers.get("user-agent"),
            Some(&"Test-Agent/1.0".to_string())
        );
        assert_eq!(
            headers.get("content-type"),
            Some(&"application/json".to_string())
        );
    }

    #[test]
    fn test_header_case_insensitive_lookup() {
        let mut headers_map = HashMap::new();
        headers_map.insert("User-Agent".to_string(), "Test-Agent/1.0".to_string());
        headers_map.insert("Content-Type".to_string(), "application/json".to_string());

        let headers_json = serde_json::to_value(&headers_map).unwrap();
        let ctx = RequestContext::new().with_metadata("http_headers", headers_json);

        // Test case-insensitive lookup
        assert_eq!(ctx.header("user-agent"), Some("Test-Agent/1.0".to_string()));
        assert_eq!(ctx.header("USER-AGENT"), Some("Test-Agent/1.0".to_string()));
        assert_eq!(ctx.header("User-Agent"), Some("Test-Agent/1.0".to_string()));
        assert_eq!(
            ctx.header("content-type"),
            Some("application/json".to_string())
        );
        assert_eq!(
            ctx.header("CONTENT-TYPE"),
            Some("application/json".to_string())
        );
    }

    #[test]
    fn test_header_returns_none_when_not_found() {
        let mut headers_map = HashMap::new();
        headers_map.insert("user-agent".to_string(), "Test-Agent/1.0".to_string());

        let headers_json = serde_json::to_value(&headers_map).unwrap();
        let ctx = RequestContext::new().with_metadata("http_headers", headers_json);

        assert_eq!(ctx.header("x-custom-header"), None);
    }

    #[test]
    fn test_header_returns_none_when_headers_not_set() {
        let ctx = RequestContext::new();
        assert_eq!(ctx.header("user-agent"), None);
    }

    #[test]
    fn test_transport_returns_none_when_not_set() {
        let ctx = RequestContext::new();
        assert!(ctx.transport().is_none());
    }

    #[test]
    fn test_transport_returns_transport_type() {
        let ctx = RequestContext::new().with_metadata("transport", "http");

        assert_eq!(ctx.transport(), Some("http".to_string()));
    }

    #[test]
    fn test_multiple_transport_types() {
        let http_ctx = RequestContext::new().with_metadata("transport", "http");
        assert_eq!(http_ctx.transport(), Some("http".to_string()));

        let ws_ctx = RequestContext::new().with_metadata("transport", "websocket");
        assert_eq!(ws_ctx.transport(), Some("websocket".to_string()));

        let stdio_ctx = RequestContext::new().with_metadata("transport", "stdio");
        assert_eq!(stdio_ctx.transport(), Some("stdio".to_string()));
    }

    #[test]
    fn test_builders_and_with_id() {
        let ctx = RequestContext::with_id("req-1")
            .with_user_id("alice")
            .with_session_id("sess-1")
            .with_client_id("client-1");

        assert_eq!(ctx.request_id, "req-1");
        assert_eq!(ctx.user(), Some("alice"));
        assert_eq!(ctx.session_id.as_deref(), Some("sess-1"));
        assert_eq!(ctx.client_id.as_deref(), Some("client-1"));
        assert!(!ctx.is_cancelled());
    }

    #[test]
    fn test_roles_parsed_from_auth_metadata() {
        let ctx = RequestContext::new().with_metadata(
            "auth",
            serde_json::json!({ "roles": ["admin", "editor", 42] }),
        );
        // Non-string role entries are skipped.
        assert_eq!(ctx.roles(), vec!["admin".to_string(), "editor".to_string()]);
    }

    #[test]
    fn test_roles_empty_without_auth_metadata() {
        assert!(RequestContext::new().roles().is_empty());
    }

    #[test]
    fn test_has_any_role() {
        let ctx = RequestContext::new()
            .with_metadata("auth", serde_json::json!({ "roles": ["editor"] }));
        assert!(ctx.has_any_role(&["admin", "editor"]));
        assert!(!ctx.has_any_role(&["admin"]));

        // An empty requirement always passes, even for a context without roles.
        let nothing_required: [&str; 0] = [];
        assert!(RequestContext::new().has_any_role(&nothing_required));
    }

    #[test]
    fn test_is_authenticated_reads_metadata_flag() {
        assert!(!RequestContext::new().is_authenticated());

        let ctx = RequestContext::new().with_metadata("client_authenticated", true);
        assert!(ctx.is_authenticated());

        // A non-boolean value is not treated as authenticated.
        let ctx = RequestContext::new().with_metadata("client_authenticated", "yes");
        assert!(!ctx.is_authenticated());
    }

    #[test]
    fn test_is_cancelled_follows_token() {
        let token = Arc::new(CancellationToken::new());
        let ctx = RequestContext::new().with_cancellation_token(Arc::clone(&token));
        assert!(!ctx.is_cancelled());

        token.cancel();
        assert!(ctx.is_cancelled());
    }

    #[test]
    fn test_request_info_completion_sets_status() {
        let ok = RequestInfo::new("c".to_string(), "m".to_string(), serde_json::Value::Null)
            .complete_success(12);
        assert!(ok.success);
        assert_eq!(ok.response_time_ms, Some(12));
        assert_eq!(ok.status_code, Some(200));

        let err = RequestInfo::new("c".to_string(), "m".to_string(), serde_json::Value::Null)
            .complete_error(7, "boom".to_string());
        assert!(!err.success);
        assert_eq!(err.response_time_ms, Some(7));
        assert_eq!(err.error_message.as_deref(), Some("boom"));
        assert_eq!(err.status_code, Some(500));
    }
}