turbomcp_protocol/context/
request.rs

//! Request and response context types for MCP request handling.
//!
//! This module contains the core context types used throughout the MCP protocol
//! implementation for tracking request metadata, response information, and analytics.
5
6use std::collections::HashMap;
7use std::fmt;
8use std::sync::Arc;
9use std::time::Instant;
10
11use chrono::{DateTime, Utc};
12use serde::{Deserialize, Serialize};
13use tokio_util::sync::CancellationToken;
14use uuid::Uuid;
15
16use super::capabilities::ServerToClientRequests;
17use crate::types::Timestamp;
18
19/// Context information for a single MCP request, carried through its entire lifecycle.
20///
21/// This struct contains essential metadata for processing, logging, and tracing a request,
22/// including unique identifiers, authentication information, and mechanisms for
23/// cancellation and server-initiated communication.
24#[derive(Clone)]
25pub struct RequestContext {
26    /// A unique identifier for the request, typically a UUID.
27    pub request_id: String,
28
29    /// The identifier for the user making the request, if authenticated.
30    pub user_id: Option<String>,
31
32    /// The identifier for the session to which this request belongs.
33    pub session_id: Option<String>,
34
35    /// The identifier for the client application making the request.
36    pub client_id: Option<String>,
37
38    /// The timestamp when the request was received.
39    pub timestamp: Timestamp,
40
41    /// The `Instant` when request processing started, used for performance tracking.
42    pub start_time: Instant,
43
44    /// A collection of custom metadata for application-specific use cases.
45    pub metadata: Arc<HashMap<String, serde_json::Value>>,
46
47    /// The tracing span associated with this request for observability.
48    #[cfg(feature = "tracing")]
49    pub span: Option<tracing::Span>,
50
51    /// A token that can be used to signal cancellation of the request.
52    pub cancellation_token: Option<Arc<CancellationToken>>,
53
54    /// An interface for making server-initiated requests back to the client (e.g., sampling, elicitation).
55    /// This is hidden from public docs as it's an internal detail injected by the server.
56    #[doc(hidden)]
57    pub(crate) server_to_client: Option<Arc<dyn ServerToClientRequests>>,
58}
59
60impl fmt::Debug for RequestContext {
61    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62        f.debug_struct("RequestContext")
63            .field("request_id", &self.request_id)
64            .field("user_id", &self.user_id)
65            .field("session_id", &self.session_id)
66            .field("client_id", &self.client_id)
67            .field("timestamp", &self.timestamp)
68            .field("metadata", &self.metadata)
69            .field("server_to_client", &self.server_to_client.is_some())
70            .finish()
71    }
72}
73
74/// Context information generated after processing a request, containing response details.
75#[derive(Debug, Clone)]
76pub struct ResponseContext {
77    /// The ID of the original request this response is for.
78    pub request_id: String,
79
80    /// The timestamp when the response was generated.
81    pub timestamp: Timestamp,
82
83    /// The total time taken to process the request.
84    pub duration: std::time::Duration,
85
86    /// The status of the response (e.g., Success, Error).
87    pub status: ResponseStatus,
88
89    /// A collection of custom metadata for the response.
90    pub metadata: Arc<HashMap<String, serde_json::Value>>,
91}
92
93/// Represents the status of an MCP response.
94#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
95pub enum ResponseStatus {
96    /// The request was processed successfully.
97    Success,
98    /// An error occurred during request processing.
99    Error {
100        /// A numeric code indicating the error type.
101        code: i32,
102        /// A human-readable message describing the error.
103        message: String,
104    },
105    /// The response is partial, indicating more data will follow (for streaming).
106    Partial,
107    /// The request was cancelled before completion.
108    Cancelled,
109}
110
111/// Contains analytics information for a single request, used for monitoring and debugging.
112#[derive(Debug, Clone, Serialize, Deserialize)]
113pub struct RequestInfo {
114    /// The timestamp when the request was received.
115    pub timestamp: DateTime<Utc>,
116    /// The identifier of the client that made the request.
117    pub client_id: String,
118    /// The name of the tool or method that was called.
119    pub method_name: String,
120    /// The parameters provided in the request, potentially sanitized for privacy.
121    pub parameters: serde_json::Value,
122    /// The total time taken to generate a response, in milliseconds.
123    pub response_time_ms: Option<u64>,
124    /// A boolean indicating whether the request was successful.
125    pub success: bool,
126    /// The error message, if the request failed.
127    pub error_message: Option<String>,
128    /// The HTTP status code, if the request was handled over HTTP.
129    pub status_code: Option<u16>,
130    /// Additional custom metadata for analytics.
131    pub metadata: HashMap<String, serde_json::Value>,
132}
133
134impl RequestContext {
135    /// Creates a new `RequestContext` with a generated UUIDv4 as the request ID.
136    #[must_use]
137    pub fn new() -> Self {
138        Self {
139            request_id: Uuid::new_v4().to_string(),
140            user_id: None,
141            session_id: None,
142            client_id: None,
143            timestamp: Timestamp::now(),
144            start_time: Instant::now(),
145            metadata: Arc::new(HashMap::new()),
146            #[cfg(feature = "tracing")]
147            span: None,
148            cancellation_token: None,
149            server_to_client: None,
150        }
151    }
152
153    /// Creates a new `RequestContext` with a specific request ID.
154    pub fn with_id(id: impl Into<String>) -> Self {
155        Self {
156            request_id: id.into(),
157            ..Self::new()
158        }
159    }
160
161    /// Sets the user ID for this context, returning the modified context.
162    ///
163    /// # Example
164    /// ```
165    /// # use turbomcp_protocol::context::RequestContext;
166    /// let ctx = RequestContext::new().with_user_id("user-123");
167    /// assert_eq!(ctx.user_id, Some("user-123".to_string()));
168    /// ```
169    #[must_use]
170    pub fn with_user_id(mut self, user_id: impl Into<String>) -> Self {
171        self.user_id = Some(user_id.into());
172        self
173    }
174
175    /// Sets the session ID for this context, returning the modified context.
176    #[must_use]
177    pub fn with_session_id(mut self, session_id: impl Into<String>) -> Self {
178        self.session_id = Some(session_id.into());
179        self
180    }
181
182    /// Sets the client ID for this context, returning the modified context.
183    #[must_use]
184    pub fn with_client_id(mut self, client_id: impl Into<String>) -> Self {
185        self.client_id = Some(client_id.into());
186        self
187    }
188
189    /// Adds a key-value pair to the metadata, returning the modified context.
190    ///
191    /// # Example
192    /// ```
193    /// # use turbomcp_protocol::context::RequestContext;
194    /// # use serde_json::json;
195    /// let ctx = RequestContext::new().with_metadata("tenant", json!("acme-corp"));
196    /// assert_eq!(ctx.get_metadata("tenant"), Some(&json!("acme-corp")));
197    /// ```
198    #[must_use]
199    pub fn with_metadata(
200        mut self,
201        key: impl Into<String>,
202        value: impl Into<serde_json::Value>,
203    ) -> Self {
204        Arc::make_mut(&mut self.metadata).insert(key.into(), value.into());
205        self
206    }
207
208    /// Retrieves a value from the metadata by key.
209    #[must_use]
210    pub fn get_metadata(&self, key: &str) -> Option<&serde_json::Value> {
211        self.metadata.get(key)
212    }
213
214    /// Returns the elapsed time since the request processing started.
215    #[must_use]
216    pub fn elapsed(&self) -> std::time::Duration {
217        self.start_time.elapsed()
218    }
219
220    /// Checks if the request has been marked for cancellation.
221    #[must_use]
222    pub fn is_cancelled(&self) -> bool {
223        self.cancellation_token
224            .as_ref()
225            .is_some_and(|token| token.is_cancelled())
226    }
227
228    /// Sets the server-to-client requests interface for this context.
229    ///
230    /// This enables tools to make server-initiated requests (sampling, elicitation, roots)
231    /// with full context propagation for tracing and attribution. This is typically called
232    /// by the server implementation.
233    #[must_use]
234    pub fn with_server_to_client(mut self, capabilities: Arc<dyn ServerToClientRequests>) -> Self {
235        self.server_to_client = Some(capabilities);
236        self
237    }
238
239    /// Sets the cancellation token for cooperative cancellation.
240    /// This is typically called by the server implementation.
241    #[must_use]
242    pub fn with_cancellation_token(mut self, token: Arc<CancellationToken>) -> Self {
243        self.cancellation_token = Some(token);
244        self
245    }
246
247    /// Returns the user ID from the request context, if available.
248    #[must_use]
249    pub fn user(&self) -> Option<&str> {
250        self.user_id.as_deref()
251    }
252
253    /// Checks if the request is from an authenticated client.
254    /// This is determined by metadata set during the authentication process.
255    #[must_use]
256    pub fn is_authenticated(&self) -> bool {
257        self.get_metadata("client_authenticated")
258            .and_then(|v| v.as_bool())
259            .unwrap_or(false)
260    }
261
262    /// Returns the user roles from the request context, if available.
263    /// Roles are typically populated from an authentication token.
264    #[must_use]
265    pub fn roles(&self) -> Vec<String> {
266        self.get_metadata("auth")
267            .and_then(|auth| auth.get("roles"))
268            .and_then(|roles| roles.as_array())
269            .map(|roles| {
270                roles
271                    .iter()
272                    .filter_map(|role| role.as_str().map(ToString::to_string))
273                    .collect()
274            })
275            .unwrap_or_default()
276    }
277
278    /// Checks if the user has any of the specified roles.
279    /// Returns `true` if the required roles list is empty or if the user has at least one of the roles.
280    pub fn has_any_role<S: AsRef<str>>(&self, required: &[S]) -> bool {
281        if required.is_empty() {
282            return true; // Empty requirement always passes
283        }
284
285        let user_roles = self.roles();
286        required
287            .iter()
288            .any(|required_role| user_roles.contains(&required_role.as_ref().to_string()))
289    }
290
291    /// Gets the server-to-client requests interface.
292    ///
293    /// Returns `None` if not configured (e.g., for unidirectional transports).
294    /// This is hidden from public docs as it's an internal detail for use by server tools.
295    #[doc(hidden)]
296    pub fn server_to_client(&self) -> Option<&Arc<dyn ServerToClientRequests>> {
297        self.server_to_client.as_ref()
298    }
299}
300
301impl Default for RequestContext {
302    fn default() -> Self {
303        Self::new()
304    }
305}
306
307impl ResponseContext {
308    /// Creates a new `ResponseContext` for a successful operation.
309    pub fn success(request_id: impl Into<String>, duration: std::time::Duration) -> Self {
310        Self {
311            request_id: request_id.into(),
312            timestamp: Timestamp::now(),
313            duration,
314            status: ResponseStatus::Success,
315            metadata: Arc::new(HashMap::new()),
316        }
317    }
318
319    /// Creates a new `ResponseContext` for a failed operation.
320    pub fn error(
321        request_id: impl Into<String>,
322        duration: std::time::Duration,
323        code: i32,
324        message: impl Into<String>,
325    ) -> Self {
326        Self {
327            request_id: request_id.into(),
328            timestamp: Timestamp::now(),
329            duration,
330            status: ResponseStatus::Error {
331                code,
332                message: message.into(),
333            },
334            metadata: Arc::new(HashMap::new()),
335        }
336    }
337}
338
339impl RequestInfo {
340    /// Creates a new `RequestInfo` for analytics.
341    #[must_use]
342    pub fn new(client_id: String, method_name: String, parameters: serde_json::Value) -> Self {
343        Self {
344            timestamp: Utc::now(),
345            client_id,
346            method_name,
347            parameters,
348            response_time_ms: None,
349            success: false,
350            error_message: None,
351            status_code: None,
352            metadata: HashMap::new(),
353        }
354    }
355
356    /// Marks the request as completed successfully and records the response time.
357    #[must_use]
358    pub const fn complete_success(mut self, response_time_ms: u64) -> Self {
359        self.response_time_ms = Some(response_time_ms);
360        self.success = true;
361        self.status_code = Some(200);
362        self
363    }
364
365    /// Marks the request as failed and records the response time and error message.
366    #[must_use]
367    pub fn complete_error(mut self, response_time_ms: u64, error: String) -> Self {
368        self.response_time_ms = Some(response_time_ms);
369        self.success = false;
370        self.error_message = Some(error);
371        self.status_code = Some(500);
372        self
373    }
374
375    /// Sets the HTTP status code for this request.
376    #[must_use]
377    pub const fn with_status_code(mut self, code: u16) -> Self {
378        self.status_code = Some(code);
379        self
380    }
381
382    /// Adds a key-value pair to the analytics metadata.
383    #[must_use]
384    pub fn with_metadata(mut self, key: String, value: serde_json::Value) -> Self {
385        self.metadata.insert(key, value);
386        self
387    }
388}
389
390/// An extension trait for `RequestContext` providing enhanced client ID handling.
391pub trait RequestContextExt {
392    /// Sets the client ID using the structured `ClientId` enum, which includes the method of identification.
393    #[must_use]
394    fn with_enhanced_client_id(self, client_id: super::client::ClientId) -> Self;
395
396    /// Extracts a client ID from headers or query parameters and sets it on the context.
397    #[must_use]
398    fn extract_client_id(
399        self,
400        extractor: &super::client::ClientIdExtractor,
401        headers: Option<&HashMap<String, String>>,
402        query_params: Option<&HashMap<String, String>>,
403    ) -> Self;
404
405    /// Gets the structured `ClientId` enum from the context, if available.
406    fn get_enhanced_client_id(&self) -> Option<super::client::ClientId>;
407}
408
409impl RequestContextExt for RequestContext {
410    fn with_enhanced_client_id(self, client_id: super::client::ClientId) -> Self {
411        self.with_client_id(client_id.as_str())
412            .with_metadata(
413                "client_id_method".to_string(),
414                serde_json::Value::String(client_id.auth_method().to_string()),
415            )
416            .with_metadata(
417                "client_authenticated".to_string(),
418                serde_json::Value::Bool(client_id.is_authenticated()),
419            )
420    }
421
422    fn extract_client_id(
423        self,
424        extractor: &super::client::ClientIdExtractor,
425        headers: Option<&HashMap<String, String>>,
426        query_params: Option<&HashMap<String, String>>,
427    ) -> Self {
428        let client_id = extractor.extract_client_id(headers, query_params);
429        self.with_enhanced_client_id(client_id)
430    }
431
432    fn get_enhanced_client_id(&self) -> Option<super::client::ClientId> {
433        self.client_id.as_ref().map(|id| {
434            let method = self
435                .get_metadata("client_id_method")
436                .and_then(|v| v.as_str())
437                .unwrap_or("header");
438
439            match method {
440                "bearer_token" => super::client::ClientId::Token(id.clone()),
441                "session_cookie" => super::client::ClientId::Session(id.clone()),
442                "query_param" => super::client::ClientId::QueryParam(id.clone()),
443                "user_agent" => super::client::ClientId::UserAgent(id.clone()),
444                "anonymous" => super::client::ClientId::Anonymous,
445                _ => super::client::ClientId::Header(id.clone()), // Default to header for "header" and unknown methods
446            }
447        })
448    }
449}