1use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::Instant;
9
10use tokio_util::sync::CancellationToken;
11use turbomcp_core::error::McpResult;
12use turbomcp_types::{CreateMessageRequest, CreateMessageResult, ElicitResult};
13use uuid::Uuid;
14
15pub use turbomcp_core::context::TransportType;
17
/// Handle to a live session over which JSON messages can be sent by method
/// name.
///
/// In this file it is stored on `RequestContext` and used to issue
/// `elicitation/create` and `sampling/createMessage` calls.
/// Implementations must be `Send + Sync` and `Debug` so the handle can be
/// shared behind an `Arc` and the owning context can derive `Debug`.
#[async_trait::async_trait]
pub trait McpSession: Send + Sync + std::fmt::Debug {
    /// Sends `method` with `params` and resolves to the JSON result, or to an
    /// error if the call fails.
    async fn call(&self, method: &str, params: serde_json::Value) -> McpResult<serde_json::Value>;
    /// Sends `method` with `params` without yielding a result payload; only
    /// success or failure is reported.
    async fn notify(&self, method: &str, params: serde_json::Value) -> McpResult<()>;
}
26
/// Per-request state: identity, transport, timing, optional headers, caller
/// identifiers, free-form metadata, an optional cancellation token and an
/// optional session handle.
///
/// The type is `Clone`; the cancellation token and the session are held in
/// `Arc`s, so clones share them rather than copying them.
#[derive(Debug, Clone)]
pub struct RequestContext {
    /// Identifier of this request.
    request_id: String,
    /// Transport the request is associated with.
    transport: TransportType,
    /// Instant captured when the context was constructed; basis for elapsed time.
    start_time: Instant,
    /// Request headers, if any were attached.
    headers: Option<HashMap<String, String>>,
    /// User identifier, if one was attached.
    user_id: Option<String>,
    /// Session identifier, if one was attached.
    session_id: Option<String>,
    /// Client identifier, if one was attached.
    client_id: Option<String>,
    /// Arbitrary JSON metadata keyed by string.
    metadata: HashMap<String, serde_json::Value>,
    /// Token that can be used to observe cancellation, if one was attached.
    cancellation_token: Option<Arc<CancellationToken>>,
    /// Session handle used for `elicitation/create` and
    /// `sampling/createMessage` calls; `None` when no session was attached.
    session: Option<Arc<dyn McpSession>>,
}
51
52impl Default for RequestContext {
53 fn default() -> Self {
54 Self::new()
55 }
56}
57
58impl RequestContext {
59 #[must_use]
61 pub fn new() -> Self {
62 Self {
63 request_id: Uuid::new_v4().to_string(),
64 transport: TransportType::Stdio,
65 start_time: Instant::now(),
66 headers: None,
67 user_id: None,
68 session_id: None,
69 client_id: None,
70 metadata: HashMap::new(),
71 cancellation_token: None,
72 session: None,
73 }
74 }
75
76 #[must_use]
78 pub fn with_session(mut self, session: Arc<dyn McpSession>) -> Self {
79 self.session = Some(session);
80 self
81 }
82
83 pub async fn elicit_form(
85 &self,
86 message: impl Into<String>,
87 schema: serde_json::Value,
88 ) -> McpResult<ElicitResult> {
89 let session = self.session.as_ref().ok_or_else(|| {
90 turbomcp_core::error::McpError::capability_not_supported(
91 "Server-to-client requests not available on this transport",
92 )
93 })?;
94
95 let params = serde_json::json!({
96 "mode": "form",
97 "message": message.into(),
98 "requestedSchema": schema
99 });
100
101 let result = session.call("elicitation/create", params).await?;
102 serde_json::from_value(result).map_err(|e| {
103 turbomcp_core::error::McpError::internal(format!(
104 "Failed to parse elicit result: {}",
105 e
106 ))
107 })
108 }
109
110 pub async fn elicit_url(
112 &self,
113 message: impl Into<String>,
114 url: impl Into<String>,
115 elicitation_id: impl Into<String>,
116 ) -> McpResult<ElicitResult> {
117 let session = self.session.as_ref().ok_or_else(|| {
118 turbomcp_core::error::McpError::capability_not_supported(
119 "Server-to-client requests not available on this transport",
120 )
121 })?;
122
123 let params = serde_json::json!({
124 "mode": "url",
125 "message": message.into(),
126 "url": url.into(),
127 "elicitationId": elicitation_id.into()
128 });
129
130 let result = session.call("elicitation/create", params).await?;
131 serde_json::from_value(result).map_err(|e| {
132 turbomcp_core::error::McpError::internal(format!(
133 "Failed to parse elicit result: {}",
134 e
135 ))
136 })
137 }
138
139 pub async fn sample(&self, request: CreateMessageRequest) -> McpResult<CreateMessageResult> {
141 let session = self.session.as_ref().ok_or_else(|| {
142 turbomcp_core::error::McpError::capability_not_supported(
143 "Server-to-client requests not available on this transport",
144 )
145 })?;
146
147 let params = serde_json::to_value(request).map_err(|e| {
148 turbomcp_core::error::McpError::invalid_params(format!(
149 "Failed to serialize sampling request: {}",
150 e
151 ))
152 })?;
153
154 let result = session.call("sampling/createMessage", params).await?;
155 serde_json::from_value(result).map_err(|e| {
156 turbomcp_core::error::McpError::internal(format!(
157 "Failed to parse sampling result: {}",
158 e
159 ))
160 })
161 }
162
163 #[must_use]
165 pub fn stdio() -> Self {
166 Self::new().with_transport(TransportType::Stdio)
167 }
168
169 #[must_use]
171 pub fn http() -> Self {
172 Self::new().with_transport(TransportType::Http)
173 }
174
175 #[must_use]
177 pub fn websocket() -> Self {
178 Self::new().with_transport(TransportType::WebSocket)
179 }
180
181 #[must_use]
183 pub fn tcp() -> Self {
184 Self::new().with_transport(TransportType::Tcp)
185 }
186
187 #[must_use]
189 pub fn unix() -> Self {
190 Self::new().with_transport(TransportType::Unix)
191 }
192
193 #[must_use]
195 pub fn wasm() -> Self {
196 Self::new().with_transport(TransportType::Wasm)
197 }
198
199 #[must_use]
201 pub fn channel() -> Self {
202 Self::new().with_transport(TransportType::Channel)
203 }
204
205 #[must_use]
207 pub fn with_id(id: impl Into<String>) -> Self {
208 Self {
209 request_id: id.into(),
210 ..Self::new()
211 }
212 }
213
214 #[must_use]
216 pub fn with_transport(mut self, transport: TransportType) -> Self {
217 self.transport = transport;
218 self
219 }
220
221 #[must_use]
223 pub fn with_headers(mut self, headers: HashMap<String, String>) -> Self {
224 self.headers = Some(headers);
225 self
226 }
227
228 #[must_use]
230 pub fn with_user_id(mut self, user_id: impl Into<String>) -> Self {
231 self.user_id = Some(user_id.into());
232 self
233 }
234
235 #[must_use]
237 pub fn with_session_id(mut self, session_id: impl Into<String>) -> Self {
238 self.session_id = Some(session_id.into());
239 self
240 }
241
242 #[must_use]
244 pub fn with_client_id(mut self, client_id: impl Into<String>) -> Self {
245 self.client_id = Some(client_id.into());
246 self
247 }
248
249 #[must_use]
251 pub fn with_metadata(
252 mut self,
253 key: impl Into<String>,
254 value: impl Into<serde_json::Value>,
255 ) -> Self {
256 self.metadata.insert(key.into(), value.into());
257 self
258 }
259
260 #[must_use]
262 pub fn with_cancellation_token(mut self, token: Arc<CancellationToken>) -> Self {
263 self.cancellation_token = Some(token);
264 self
265 }
266
267 #[must_use]
269 pub fn request_id(&self) -> &str {
270 &self.request_id
271 }
272
273 #[must_use]
275 pub fn transport(&self) -> TransportType {
276 self.transport
277 }
278
279 #[must_use]
281 pub fn headers(&self) -> Option<&HashMap<String, String>> {
282 self.headers.as_ref()
283 }
284
285 #[must_use]
287 pub fn header(&self, name: &str) -> Option<&str> {
288 let headers = self.headers.as_ref()?;
289 let name_lower = name.to_lowercase();
290 headers
291 .iter()
292 .find(|(key, _)| key.to_lowercase() == name_lower)
293 .map(|(_, value)| value.as_str())
294 }
295
296 #[must_use]
298 pub fn user_id(&self) -> Option<&str> {
299 self.user_id.as_deref()
300 }
301
302 #[must_use]
304 pub fn session_id(&self) -> Option<&str> {
305 self.session_id.as_deref()
306 }
307
308 #[must_use]
310 pub fn client_id(&self) -> Option<&str> {
311 self.client_id.as_deref()
312 }
313
314 #[must_use]
316 pub fn get_metadata(&self, key: &str) -> Option<&serde_json::Value> {
317 self.metadata.get(key)
318 }
319
320 #[must_use]
322 pub fn elapsed(&self) -> std::time::Duration {
323 self.start_time.elapsed()
324 }
325
326 #[must_use]
328 pub fn is_cancelled(&self) -> bool {
329 self.cancellation_token
330 .as_ref()
331 .is_some_and(|t| t.is_cancelled())
332 }
333
334 #[must_use]
336 pub fn is_authenticated(&self) -> bool {
337 self.user_id.is_some()
338 }
339
340 #[must_use]
346 pub fn to_core_context(&self) -> turbomcp_core::context::RequestContext {
347 let mut core_ctx =
349 turbomcp_core::context::RequestContext::new(&self.request_id, self.transport);
350
351 if let Some(ref headers) = self.headers {
353 for (key, value) in headers {
354 core_ctx.insert_metadata(format!("header:{key}"), value.clone());
355 }
356 }
357
358 if let Some(ref user_id) = self.user_id {
360 core_ctx.insert_metadata("user_id", user_id.clone());
361 }
362 if let Some(ref session_id) = self.session_id {
363 core_ctx.insert_metadata("session_id", session_id.clone());
364 }
365 if let Some(ref client_id) = self.client_id {
366 core_ctx.insert_metadata("client_id", client_id.clone());
367 }
368
369 core_ctx
370 }
371}
372
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh context has a non-empty id, defaults to stdio and is not cancelled.
    #[test]
    fn test_new_context() {
        let context = RequestContext::new();

        assert!(!context.request_id().is_empty());
        assert_eq!(context.transport(), TransportType::Stdio);
        assert!(!context.is_cancelled());
    }

    /// `with_id` uses the caller-supplied request id verbatim.
    #[test]
    fn test_with_id() {
        let context = RequestContext::with_id("test-123");

        assert_eq!(context.request_id(), "test-123");
    }

    /// `with_transport` overrides the default transport.
    #[test]
    fn test_transport_types() {
        let context = RequestContext::new().with_transport(TransportType::Http);
        let transport = context.transport();

        assert_eq!(transport, TransportType::Http);
        assert_eq!(transport.as_str(), "http");
    }

    /// Header lookup ignores the case of the requested name.
    #[test]
    fn test_headers() {
        let header_map = HashMap::from([
            ("User-Agent".to_string(), "Test/1.0".to_string()),
            ("Content-Type".to_string(), "application/json".to_string()),
        ]);

        let context = RequestContext::new().with_headers(header_map);

        assert!(context.headers().is_some());
        assert_eq!(context.header("user-agent"), Some("Test/1.0"));
        assert_eq!(context.header("USER-AGENT"), Some("Test/1.0"));
        assert_eq!(context.header("content-type"), Some("application/json"));
        assert_eq!(context.header("x-custom"), None);
    }

    /// Setting a user id marks the context as authenticated.
    #[test]
    fn test_user_id() {
        let context = RequestContext::new().with_user_id("user-123");

        assert_eq!(context.user_id(), Some("user-123"));
        assert!(context.is_authenticated());
    }

    /// Metadata entries round-trip; unknown keys yield `None`.
    #[test]
    fn test_metadata() {
        let context = RequestContext::new()
            .with_metadata("key1", "value1")
            .with_metadata("key2", serde_json::json!(42));

        let expected_text = serde_json::Value::String("value1".to_string());
        let expected_number = serde_json::json!(42);

        assert_eq!(context.get_metadata("key1"), Some(&expected_text));
        assert_eq!(context.get_metadata("key2"), Some(&expected_number));
        assert_eq!(context.get_metadata("key3"), None);
    }

    /// Cancelling the shared token is visible through the context.
    #[test]
    fn test_cancellation() {
        let shared_token = Arc::new(CancellationToken::new());
        let context = RequestContext::new().with_cancellation_token(Arc::clone(&shared_token));

        assert!(!context.is_cancelled());
        shared_token.cancel();
        assert!(context.is_cancelled());
    }
}