Skip to main content

turbomcp_server/
context.rs

//! Request context for MCP handlers.
//!
//! This module provides a server-specific request context that extends the core
//! context with runtime features such as cancellation, timing, and structured headers.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::Instant;
9
10use tokio_util::sync::CancellationToken;
11use turbomcp_core::error::McpResult;
12use turbomcp_types::{CreateMessageRequest, CreateMessageResult, ElicitResult};
13use uuid::Uuid;
14
15// Re-export TransportType from core for unified type system (DRY)
16pub use turbomcp_core::context::TransportType;
17
18/// Trait for bidirectional session communication.
19#[async_trait::async_trait]
20pub trait McpSession: Send + Sync + std::fmt::Debug {
21    /// Send a request to the client and wait for a response.
22    async fn call(&self, method: &str, params: serde_json::Value) -> McpResult<serde_json::Value>;
23    /// Send a notification to the client.
24    async fn notify(&self, method: &str, params: serde_json::Value) -> McpResult<()>;
25}
26
27/// Context information for an MCP request.
28#[derive(Debug, Clone)]
29pub struct RequestContext {
30    /// Unique request identifier
31    request_id: String,
32    /// Transport type used for this request
33    transport: TransportType,
34    /// Time when request processing started
35    start_time: Instant,
36    /// HTTP headers (if applicable)
37    headers: Option<HashMap<String, String>>,
38    /// User ID (if authenticated)
39    user_id: Option<String>,
40    /// Session ID
41    session_id: Option<String>,
42    /// Client ID
43    client_id: Option<String>,
44    /// Custom metadata
45    metadata: HashMap<String, serde_json::Value>,
46    /// Cancellation token for cooperative cancellation
47    cancellation_token: Option<Arc<CancellationToken>>,
48    /// Session handle for bidirectional communication
49    session: Option<Arc<dyn McpSession>>,
50}
51
52impl Default for RequestContext {
53    fn default() -> Self {
54        Self::new()
55    }
56}
57
58impl RequestContext {
59    /// Create a new request context with a generated UUID.
60    #[must_use]
61    pub fn new() -> Self {
62        Self {
63            request_id: Uuid::new_v4().to_string(),
64            transport: TransportType::Stdio,
65            start_time: Instant::now(),
66            headers: None,
67            user_id: None,
68            session_id: None,
69            client_id: None,
70            metadata: HashMap::new(),
71            cancellation_token: None,
72            session: None,
73        }
74    }
75
76    /// Set the session handle for bidirectional communication.
77    #[must_use]
78    pub fn with_session(mut self, session: Arc<dyn McpSession>) -> Self {
79        self.session = Some(session);
80        self
81    }
82
83    /// Request user input via a form.
84    pub async fn elicit_form(
85        &self,
86        message: impl Into<String>,
87        schema: serde_json::Value,
88    ) -> McpResult<ElicitResult> {
89        let session = self.session.as_ref().ok_or_else(|| {
90            turbomcp_core::error::McpError::capability_not_supported(
91                "Server-to-client requests not available on this transport",
92            )
93        })?;
94
95        let params = serde_json::json!({
96            "mode": "form",
97            "message": message.into(),
98            "requestedSchema": schema
99        });
100
101        let result = session.call("elicitation/create", params).await?;
102        serde_json::from_value(result).map_err(|e| {
103            turbomcp_core::error::McpError::internal(format!(
104                "Failed to parse elicit result: {}",
105                e
106            ))
107        })
108    }
109
110    /// Request user action via a URL.
111    pub async fn elicit_url(
112        &self,
113        message: impl Into<String>,
114        url: impl Into<String>,
115        elicitation_id: impl Into<String>,
116    ) -> McpResult<ElicitResult> {
117        let session = self.session.as_ref().ok_or_else(|| {
118            turbomcp_core::error::McpError::capability_not_supported(
119                "Server-to-client requests not available on this transport",
120            )
121        })?;
122
123        let params = serde_json::json!({
124            "mode": "url",
125            "message": message.into(),
126            "url": url.into(),
127            "elicitationId": elicitation_id.into()
128        });
129
130        let result = session.call("elicitation/create", params).await?;
131        serde_json::from_value(result).map_err(|e| {
132            turbomcp_core::error::McpError::internal(format!(
133                "Failed to parse elicit result: {}",
134                e
135            ))
136        })
137    }
138
139    /// Request LLM sampling from the client.
140    pub async fn sample(&self, request: CreateMessageRequest) -> McpResult<CreateMessageResult> {
141        let session = self.session.as_ref().ok_or_else(|| {
142            turbomcp_core::error::McpError::capability_not_supported(
143                "Server-to-client requests not available on this transport",
144            )
145        })?;
146
147        let params = serde_json::to_value(request).map_err(|e| {
148            turbomcp_core::error::McpError::invalid_params(format!(
149                "Failed to serialize sampling request: {}",
150                e
151            ))
152        })?;
153
154        let result = session.call("sampling/createMessage", params).await?;
155        serde_json::from_value(result).map_err(|e| {
156            turbomcp_core::error::McpError::internal(format!(
157                "Failed to parse sampling result: {}",
158                e
159            ))
160        })
161    }
162
163    /// Create a new request context for STDIO transport.
164    #[must_use]
165    pub fn stdio() -> Self {
166        Self::new().with_transport(TransportType::Stdio)
167    }
168
169    /// Create a new request context for HTTP transport.
170    #[must_use]
171    pub fn http() -> Self {
172        Self::new().with_transport(TransportType::Http)
173    }
174
175    /// Create a new request context for WebSocket transport.
176    #[must_use]
177    pub fn websocket() -> Self {
178        Self::new().with_transport(TransportType::WebSocket)
179    }
180
181    /// Create a new request context for TCP transport.
182    #[must_use]
183    pub fn tcp() -> Self {
184        Self::new().with_transport(TransportType::Tcp)
185    }
186
187    /// Create a new request context for Unix socket transport.
188    #[must_use]
189    pub fn unix() -> Self {
190        Self::new().with_transport(TransportType::Unix)
191    }
192
193    /// Create a new request context for WASM transport.
194    #[must_use]
195    pub fn wasm() -> Self {
196        Self::new().with_transport(TransportType::Wasm)
197    }
198
199    /// Create a new request context for in-process channel transport.
200    #[must_use]
201    pub fn channel() -> Self {
202        Self::new().with_transport(TransportType::Channel)
203    }
204
205    /// Create a new request context with a specific request ID.
206    #[must_use]
207    pub fn with_id(id: impl Into<String>) -> Self {
208        Self {
209            request_id: id.into(),
210            ..Self::new()
211        }
212    }
213
214    /// Set the transport type.
215    #[must_use]
216    pub fn with_transport(mut self, transport: TransportType) -> Self {
217        self.transport = transport;
218        self
219    }
220
221    /// Set the HTTP headers.
222    #[must_use]
223    pub fn with_headers(mut self, headers: HashMap<String, String>) -> Self {
224        self.headers = Some(headers);
225        self
226    }
227
228    /// Set the user ID.
229    #[must_use]
230    pub fn with_user_id(mut self, user_id: impl Into<String>) -> Self {
231        self.user_id = Some(user_id.into());
232        self
233    }
234
235    /// Set the session ID.
236    #[must_use]
237    pub fn with_session_id(mut self, session_id: impl Into<String>) -> Self {
238        self.session_id = Some(session_id.into());
239        self
240    }
241
242    /// Set the client ID.
243    #[must_use]
244    pub fn with_client_id(mut self, client_id: impl Into<String>) -> Self {
245        self.client_id = Some(client_id.into());
246        self
247    }
248
249    /// Add a metadata key-value pair.
250    #[must_use]
251    pub fn with_metadata(
252        mut self,
253        key: impl Into<String>,
254        value: impl Into<serde_json::Value>,
255    ) -> Self {
256        self.metadata.insert(key.into(), value.into());
257        self
258    }
259
260    /// Set the cancellation token.
261    #[must_use]
262    pub fn with_cancellation_token(mut self, token: Arc<CancellationToken>) -> Self {
263        self.cancellation_token = Some(token);
264        self
265    }
266
267    /// Get the request ID.
268    #[must_use]
269    pub fn request_id(&self) -> &str {
270        &self.request_id
271    }
272
273    /// Get the transport type.
274    #[must_use]
275    pub fn transport(&self) -> TransportType {
276        self.transport
277    }
278
279    /// Get all HTTP headers.
280    #[must_use]
281    pub fn headers(&self) -> Option<&HashMap<String, String>> {
282        self.headers.as_ref()
283    }
284
285    /// Get a specific HTTP header (case-insensitive).
286    #[must_use]
287    pub fn header(&self, name: &str) -> Option<&str> {
288        let headers = self.headers.as_ref()?;
289        let name_lower = name.to_lowercase();
290        headers
291            .iter()
292            .find(|(key, _)| key.to_lowercase() == name_lower)
293            .map(|(_, value)| value.as_str())
294    }
295
296    /// Get the user ID.
297    #[must_use]
298    pub fn user_id(&self) -> Option<&str> {
299        self.user_id.as_deref()
300    }
301
302    /// Get the session ID.
303    #[must_use]
304    pub fn session_id(&self) -> Option<&str> {
305        self.session_id.as_deref()
306    }
307
308    /// Get the client ID.
309    #[must_use]
310    pub fn client_id(&self) -> Option<&str> {
311        self.client_id.as_deref()
312    }
313
314    /// Get a metadata value.
315    #[must_use]
316    pub fn get_metadata(&self, key: &str) -> Option<&serde_json::Value> {
317        self.metadata.get(key)
318    }
319
320    /// Get the elapsed time since request processing started.
321    #[must_use]
322    pub fn elapsed(&self) -> std::time::Duration {
323        self.start_time.elapsed()
324    }
325
326    /// Check if the request has been cancelled.
327    #[must_use]
328    pub fn is_cancelled(&self) -> bool {
329        self.cancellation_token
330            .as_ref()
331            .is_some_and(|t| t.is_cancelled())
332    }
333
334    /// Check if the user is authenticated.
335    #[must_use]
336    pub fn is_authenticated(&self) -> bool {
337        self.user_id.is_some()
338    }
339
340    /// Convert to the core RequestContext type.
341    ///
342    /// This method creates a minimal context compatible with the unified
343    /// `turbomcp_core::McpHandler` trait. Headers and auth fields are
344    /// encoded as metadata with standard prefixes.
345    #[must_use]
346    pub fn to_core_context(&self) -> turbomcp_core::context::RequestContext {
347        // TransportType is re-exported from core, so no conversion needed
348        let mut core_ctx =
349            turbomcp_core::context::RequestContext::new(&self.request_id, self.transport);
350
351        // Copy headers as metadata with "header:" prefix
352        if let Some(ref headers) = self.headers {
353            for (key, value) in headers {
354                core_ctx.insert_metadata(format!("header:{key}"), value.clone());
355            }
356        }
357
358        // Copy auth/session fields as metadata
359        if let Some(ref user_id) = self.user_id {
360            core_ctx.insert_metadata("user_id", user_id.clone());
361        }
362        if let Some(ref session_id) = self.session_id {
363            core_ctx.insert_metadata("session_id", session_id.clone());
364        }
365        if let Some(ref client_id) = self.client_id {
366            core_ctx.insert_metadata("client_id", client_id.clone());
367        }
368
369        core_ctx
370    }
371}
372
#[cfg(test)]
mod tests {
    use super::*;

    // A freshly created context has a non-empty ID, uses STDIO, and is not cancelled.
    #[test]
    fn test_new_context() {
        let fresh = RequestContext::new();
        assert!(!fresh.request_id().is_empty());
        assert_eq!(fresh.transport(), TransportType::Stdio);
        assert!(!fresh.is_cancelled());
    }

    // An explicitly supplied ID is returned verbatim.
    #[test]
    fn test_with_id() {
        let named = RequestContext::with_id("test-123");
        assert_eq!(named.request_id(), "test-123");
    }

    // The transport set through the builder is reported back, including its string form.
    #[test]
    fn test_transport_types() {
        let over_http = RequestContext::new().with_transport(TransportType::Http);
        let transport = over_http.transport();
        assert_eq!(transport, TransportType::Http);
        assert_eq!(transport.as_str(), "http");
    }

    // Header lookup ignores the case of the requested name.
    #[test]
    fn test_headers() {
        let supplied = HashMap::from([
            ("User-Agent".to_string(), "Test/1.0".to_string()),
            ("Content-Type".to_string(), "application/json".to_string()),
        ]);

        let ctx = RequestContext::new().with_headers(supplied);
        assert!(ctx.headers().is_some());

        let expectations = [
            ("user-agent", Some("Test/1.0")),
            ("USER-AGENT", Some("Test/1.0")),
            ("content-type", Some("application/json")),
            ("x-custom", None),
        ];
        for (name, expected) in expectations {
            assert_eq!(ctx.header(name), expected);
        }
    }

    // Setting a user ID makes the context report as authenticated.
    #[test]
    fn test_user_id() {
        let signed_in = RequestContext::new().with_user_id("user-123");
        assert_eq!(signed_in.user_id(), Some("user-123"));
        assert!(signed_in.is_authenticated());
    }

    // Metadata values round-trip by key; unknown keys yield `None`.
    #[test]
    fn test_metadata() {
        let ctx = RequestContext::new()
            .with_metadata("key1", "value1")
            .with_metadata("key2", serde_json::json!(42));

        let expected_text = serde_json::Value::String("value1".to_string());
        let expected_number = serde_json::json!(42);

        assert_eq!(ctx.get_metadata("key1"), Some(&expected_text));
        assert_eq!(ctx.get_metadata("key2"), Some(&expected_number));
        assert_eq!(ctx.get_metadata("key3"), None);
    }

    // Cancelling the shared token is observed through the context.
    #[test]
    fn test_cancellation() {
        let shared = Arc::new(CancellationToken::new());
        let ctx = RequestContext::new().with_cancellation_token(Arc::clone(&shared));

        assert!(!ctx.is_cancelled());
        shared.cancel();
        assert!(ctx.is_cancelled());
    }
}
444}