1use alloc::string::{String, ToString};
20use alloc::sync::Arc;
21use alloc::vec::Vec;
22
23use hashbrown::HashMap as HashbrownMap;
24use serde_json::Value;
25
26use crate::auth::Principal;
27use crate::error::{McpError, McpResult};
28use crate::session::McpSession;
29
30#[cfg(feature = "std")]
31use crate::session::Cancellable;
32
33#[cfg(feature = "std")]
34use std::time::Instant;
35
36use turbomcp_types::{ClientCapabilities, CreateMessageRequest, CreateMessageResult, ElicitResult};
37
38#[derive(
45 Debug, Clone, Copy, PartialEq, Eq, Hash, Default, serde::Serialize, serde::Deserialize,
46)]
47#[serde(rename_all = "lowercase")]
48#[non_exhaustive]
49pub enum TransportType {
50 #[default]
52 Stdio,
53 Http,
55 WebSocket,
57 Tcp,
59 Unix,
61 Wasm,
63 Channel,
65 Unknown,
67}
68
69impl TransportType {
70 #[inline]
72 pub fn is_network(&self) -> bool {
73 matches!(self, Self::Http | Self::WebSocket | Self::Tcp)
74 }
75
76 #[inline]
78 pub fn is_local(&self) -> bool {
79 matches!(self, Self::Stdio | Self::Unix | Self::Channel)
80 }
81
82 pub fn as_str(&self) -> &'static str {
84 match self {
85 Self::Stdio => "stdio",
86 Self::Http => "http",
87 Self::WebSocket => "websocket",
88 Self::Tcp => "tcp",
89 Self::Unix => "unix",
90 Self::Wasm => "wasm",
91 Self::Channel => "channel",
92 Self::Unknown => "unknown",
93 }
94 }
95}
96
97impl core::fmt::Display for TransportType {
98 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
99 write!(f, "{}", self.as_str())
100 }
101}
102
103#[derive(Debug, Clone, Default)]
115pub struct RequestContext {
116 pub request_id: String,
118
119 pub transport: TransportType,
121
122 pub user_id: Option<String>,
124
125 pub session_id: Option<String>,
128
129 pub client_id: Option<String>,
131
132 pub metadata: HashbrownMap<String, Value>,
134
135 pub principal: Option<Principal>,
137
138 pub session: Option<Arc<dyn McpSession>>,
144
145 pub headers: Option<HashbrownMap<String, String>>,
150
151 #[cfg(feature = "std")]
155 pub start_time: Option<Instant>,
156
157 #[cfg(feature = "std")]
163 pub cancellation_token: Option<Arc<dyn Cancellable>>,
164}
165
166impl RequestContext {
171 pub fn new() -> Self {
176 #[cfg(feature = "std")]
177 {
178 Self {
179 request_id: uuid::Uuid::new_v4().to_string(),
180 ..Default::default()
181 }
182 }
183 #[cfg(not(feature = "std"))]
184 {
185 Self::default()
186 }
187 }
188
189 pub fn with_id_and_transport(request_id: impl Into<String>, transport: TransportType) -> Self {
191 Self {
192 request_id: request_id.into(),
193 transport,
194 ..Default::default()
195 }
196 }
197
198 pub fn with_id(request_id: impl Into<String>) -> Self {
200 Self {
201 request_id: request_id.into(),
202 ..Default::default()
203 }
204 }
205
206 #[inline]
208 pub fn stdio() -> Self {
209 Self::new().with_transport(TransportType::Stdio)
210 }
211
212 #[inline]
214 pub fn http() -> Self {
215 Self::new().with_transport(TransportType::Http)
216 }
217
218 #[inline]
220 pub fn websocket() -> Self {
221 Self::new().with_transport(TransportType::WebSocket)
222 }
223
224 #[inline]
226 pub fn tcp() -> Self {
227 Self::new().with_transport(TransportType::Tcp)
228 }
229
230 #[inline]
232 pub fn unix() -> Self {
233 Self::new().with_transport(TransportType::Unix)
234 }
235
236 #[inline]
238 pub fn wasm() -> Self {
239 Self::new().with_transport(TransportType::Wasm)
240 }
241
242 #[inline]
244 pub fn channel() -> Self {
245 Self::new().with_transport(TransportType::Channel)
246 }
247}
248
249impl RequestContext {
254 #[must_use]
256 pub fn with_request_id(mut self, id: impl Into<String>) -> Self {
257 self.request_id = id.into();
258 self
259 }
260
261 #[must_use]
263 pub fn with_transport(mut self, transport: TransportType) -> Self {
264 self.transport = transport;
265 self
266 }
267
268 #[must_use]
270 pub fn with_user_id(mut self, user_id: impl Into<String>) -> Self {
271 self.user_id = Some(user_id.into());
272 self
273 }
274
275 #[must_use]
277 pub fn with_session_id(mut self, session_id: impl Into<String>) -> Self {
278 self.session_id = Some(session_id.into());
279 self
280 }
281
282 #[must_use]
284 pub fn with_client_id(mut self, client_id: impl Into<String>) -> Self {
285 self.client_id = Some(client_id.into());
286 self
287 }
288
289 #[must_use]
291 pub fn with_principal(mut self, principal: Principal) -> Self {
292 self.principal = Some(principal);
293 self
294 }
295
296 #[must_use]
301 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<Value>) -> Self {
302 self.metadata.insert(key.into(), value.into());
303 self
304 }
305
306 #[must_use]
308 pub fn with_session(mut self, session: Arc<dyn McpSession>) -> Self {
309 self.session = Some(session);
310 self
311 }
312
313 #[must_use]
318 pub fn with_headers(mut self, headers: HashbrownMap<String, String>) -> Self {
319 self.headers = Some(headers);
320 self
321 }
322
323 #[cfg(feature = "std")]
325 #[must_use]
326 pub fn with_start_time(mut self, start: Instant) -> Self {
327 self.start_time = Some(start);
328 self
329 }
330
331 #[cfg(feature = "std")]
333 #[must_use]
334 pub fn with_cancellation_token(mut self, token: Arc<dyn Cancellable>) -> Self {
335 self.cancellation_token = Some(token);
336 self
337 }
338}
339
340impl RequestContext {
345 pub fn insert_metadata(&mut self, key: impl Into<String>, value: impl Into<Value>) {
347 self.metadata.insert(key.into(), value.into());
348 }
349
350 pub fn set_principal(&mut self, principal: Principal) {
352 self.principal = Some(principal);
353 }
354
355 pub fn clear_principal(&mut self) {
357 self.principal = None;
358 }
359
360 pub fn set_session(&mut self, session: Arc<dyn McpSession>) {
362 self.session = Some(session);
363 }
364}
365
366impl RequestContext {
371 #[inline]
373 pub fn request_id(&self) -> &str {
374 &self.request_id
375 }
376
377 #[inline]
379 pub fn has_request_id(&self) -> bool {
380 !self.request_id.is_empty()
381 }
382
383 #[inline]
385 pub fn transport(&self) -> TransportType {
386 self.transport
387 }
388
389 #[inline]
391 pub fn user_id(&self) -> Option<&str> {
392 self.user_id.as_deref()
393 }
394
395 #[inline]
397 pub fn session_id(&self) -> Option<&str> {
398 self.session_id.as_deref()
399 }
400
401 #[inline]
403 pub fn client_id(&self) -> Option<&str> {
404 self.client_id.as_deref()
405 }
406
407 #[inline]
409 pub fn get_metadata(&self, key: &str) -> Option<&Value> {
410 self.metadata.get(key)
411 }
412
413 pub fn get_metadata_str(&self, key: &str) -> Option<&str> {
415 self.metadata.get(key).and_then(|v| v.as_str())
416 }
417
418 #[inline]
420 pub fn has_metadata(&self, key: &str) -> bool {
421 self.metadata.contains_key(key)
422 }
423
424 #[inline]
426 pub fn principal(&self) -> Option<&Principal> {
427 self.principal.as_ref()
428 }
429
430 pub fn is_authenticated(&self) -> bool {
436 self.principal.is_some() || self.user_id.is_some()
437 }
438
439 pub fn subject(&self) -> Option<&str> {
441 self.principal
442 .as_ref()
443 .map(|p| p.subject.as_str())
444 .or(self.user_id.as_deref())
445 }
446
447 #[inline]
449 pub fn session(&self) -> Option<&Arc<dyn McpSession>> {
450 self.session.as_ref()
451 }
452
453 #[inline]
455 pub fn has_session(&self) -> bool {
456 self.session.is_some()
457 }
458
459 #[inline]
461 pub fn headers(&self) -> Option<&HashbrownMap<String, String>> {
462 self.headers.as_ref()
463 }
464
465 pub fn header(&self, name: &str) -> Option<&str> {
467 let headers = self.headers.as_ref()?;
468 headers
469 .iter()
470 .find(|(k, _)| k.eq_ignore_ascii_case(name))
471 .map(|(_, v)| v.as_str())
472 }
473
474 #[cfg(feature = "std")]
476 pub fn elapsed(&self) -> Option<core::time::Duration> {
477 self.start_time.map(|t| t.elapsed())
478 }
479
480 #[cfg(feature = "std")]
482 pub fn is_cancelled(&self) -> bool {
483 self.cancellation_token
484 .as_ref()
485 .is_some_and(|c| c.is_cancelled())
486 }
487
488 pub fn roles(&self) -> Vec<String> {
492 if let Some(p) = &self.principal
493 && !p.roles.is_empty()
494 {
495 return p.roles.to_vec();
496 }
497
498 self.metadata
499 .get("auth")
500 .and_then(|auth| auth.get("roles"))
501 .and_then(|r| r.as_array())
502 .map(|arr| {
503 arr.iter()
504 .filter_map(|v| v.as_str().map(ToString::to_string))
505 .collect()
506 })
507 .unwrap_or_default()
508 }
509
510 pub fn has_any_role<S: AsRef<str>>(&self, required: &[S]) -> bool {
513 if required.is_empty() {
514 return true;
515 }
516 let roles = self.roles();
517 required
518 .iter()
519 .any(|need| roles.iter().any(|have| have == need.as_ref()))
520 }
521}
522
523impl RequestContext {
528 pub async fn sample(&self, request: CreateMessageRequest) -> McpResult<CreateMessageResult> {
533 let session = self.require_session("sampling/createMessage")?;
534 self.require_sampling_capability(session, &request).await?;
535 let params = serde_json::to_value(request).map_err(|e| {
536 McpError::invalid_params(alloc::format!("Failed to serialize sampling request: {e}"))
537 })?;
538 let result = session.call("sampling/createMessage", params).await?;
539 serde_json::from_value(result)
540 .map_err(|e| McpError::internal(alloc::format!("Failed to parse sampling result: {e}")))
541 }
542
543 pub async fn elicit_form(
545 &self,
546 message: impl Into<String>,
547 schema: Value,
548 ) -> McpResult<ElicitResult> {
549 let session = self.require_session("elicitation/create")?;
550 self.require_elicitation_capability(session, "form").await?;
551 let params = serde_json::json!({
552 "mode": "form",
553 "message": message.into(),
554 "requestedSchema": schema,
555 });
556 let result = session.call("elicitation/create", params).await?;
557 serde_json::from_value(result).map_err(|e| {
558 McpError::internal(alloc::format!("Failed to parse elicitation result: {e}"))
559 })
560 }
561
562 pub async fn elicit_url(
564 &self,
565 message: impl Into<String>,
566 url: impl Into<String>,
567 elicitation_id: impl Into<String>,
568 ) -> McpResult<ElicitResult> {
569 let session = self.require_session("elicitation/create")?;
570 self.require_elicitation_capability(session, "url").await?;
571 let params = serde_json::json!({
572 "mode": "url",
573 "message": message.into(),
574 "url": url.into(),
575 "elicitationId": elicitation_id.into(),
576 });
577 let result = session.call("elicitation/create", params).await?;
578 serde_json::from_value(result).map_err(|e| {
579 McpError::internal(alloc::format!("Failed to parse elicitation result: {e}"))
580 })
581 }
582
583 pub async fn notify_client(&self, method: impl AsRef<str>, params: Value) -> McpResult<()> {
585 let session = self.require_session(method.as_ref())?;
586 session.notify(method.as_ref(), params).await
587 }
588
589 fn require_session(&self, op: &str) -> McpResult<&Arc<dyn McpSession>> {
590 self.session.as_ref().ok_or_else(|| {
591 McpError::capability_not_supported(alloc::format!(
592 "Bidirectional session required for {op} but transport does not support it"
593 ))
594 })
595 }
596
597 async fn require_sampling_capability(
598 &self,
599 session: &Arc<dyn McpSession>,
600 request: &CreateMessageRequest,
601 ) -> McpResult<()> {
602 let Some(caps) = session.client_capabilities().await? else {
603 return Ok(());
604 };
605
606 let Some(sampling) = caps.sampling.as_ref() else {
607 return Err(McpError::capability_not_supported(
608 "client sampling capability required for sampling/createMessage",
609 ));
610 };
611
612 if (request.tools.is_some() || request.tool_choice.is_some()) && sampling.tools.is_none() {
613 return Err(McpError::capability_not_supported(
614 "client sampling.tools capability required for tool-enabled sampling/createMessage",
615 ));
616 }
617
618 if request.task.is_some() && !client_supports_task_sampling(&caps) {
619 return Err(McpError::capability_not_supported(
620 "client tasks.requests.sampling.createMessage capability required for task-augmented sampling/createMessage",
621 ));
622 }
623
624 Ok(())
625 }
626
627 async fn require_elicitation_capability(
628 &self,
629 session: &Arc<dyn McpSession>,
630 mode: &str,
631 ) -> McpResult<()> {
632 let Some(caps) = session.client_capabilities().await? else {
633 return Ok(());
634 };
635
636 let Some(elicitation) = caps.elicitation.as_ref() else {
637 return Err(McpError::capability_not_supported(
638 "client elicitation capability required for elicitation/create",
639 ));
640 };
641
642 let supported = match mode {
643 "form" => elicitation.supports_form(),
644 "url" => elicitation.supports_url(),
645 _ => false,
646 };
647
648 if supported {
649 Ok(())
650 } else {
651 Err(McpError::capability_not_supported(alloc::format!(
652 "client elicitation.{mode} capability required for elicitation/create"
653 )))
654 }
655 }
656}
657
658fn client_supports_task_sampling(caps: &ClientCapabilities) -> bool {
659 caps.tasks
660 .as_ref()
661 .and_then(|tasks| tasks.requests.as_ref())
662 .and_then(|requests| requests.sampling.as_ref())
663 .and_then(|sampling| sampling.create_message.as_ref())
664 .is_some()
665}
666
#[cfg(test)]
mod tests {
    use super::*;

    /// `Display` renders every variant as its lowercase name.
    #[test]
    fn test_transport_type_display() {
        let expected = [
            (TransportType::Stdio, "stdio"),
            (TransportType::Http, "http"),
            (TransportType::WebSocket, "websocket"),
            (TransportType::Tcp, "tcp"),
            (TransportType::Unix, "unix"),
            (TransportType::Wasm, "wasm"),
            (TransportType::Channel, "channel"),
            (TransportType::Unknown, "unknown"),
        ];
        for (transport, name) in expected {
            assert_eq!(transport.to_string(), name);
        }
    }

    /// Network and local classification of the transports.
    #[test]
    fn test_transport_type_classification() {
        let network = [
            TransportType::Http,
            TransportType::WebSocket,
            TransportType::Tcp,
        ];
        for transport in network {
            assert!(transport.is_network());
        }
        assert!(!TransportType::Stdio.is_network());

        let local = [
            TransportType::Stdio,
            TransportType::Unix,
            TransportType::Channel,
        ];
        for transport in local {
            assert!(transport.is_local());
        }
        assert!(!TransportType::Http.is_local());
    }

    /// Explicit id + transport construction leaves everything else default.
    #[test]
    fn test_request_context_new() {
        let ctx = RequestContext::with_id_and_transport("test-123", TransportType::Http);

        assert_eq!(ctx.request_id(), "test-123");
        assert_eq!(ctx.transport(), TransportType::Http);
        assert!(ctx.metadata.is_empty());
        assert!(!ctx.has_session());
    }

    /// Each factory sets the matching transport.
    #[test]
    fn test_request_context_factory_methods() {
        let cases = [
            (RequestContext::stdio(), TransportType::Stdio),
            (RequestContext::http(), TransportType::Http),
            (RequestContext::websocket(), TransportType::WebSocket),
            (RequestContext::tcp(), TransportType::Tcp),
            (RequestContext::unix(), TransportType::Unix),
            (RequestContext::wasm(), TransportType::Wasm),
            (RequestContext::channel(), TransportType::Channel),
        ];
        for (ctx, transport) in cases {
            assert_eq!(ctx.transport(), transport);
        }
    }

    /// Metadata set through the builder is readable through the accessors.
    #[test]
    fn test_request_context_metadata() {
        let base = RequestContext::with_id_and_transport("1", TransportType::Http);
        let ctx = base
            .with_metadata("key1", "value1")
            .with_metadata("count", 42);

        assert_eq!(ctx.get_metadata_str("key1"), Some("value1"));
        assert_eq!(ctx.get_metadata("count"), Some(&serde_json::json!(42)));
        assert_eq!(ctx.get_metadata("key3"), None);

        assert!(ctx.has_metadata("key1"));
        assert!(!ctx.has_metadata("key3"));
    }

    /// Identifier builders round-trip, and a `user_id` alone authenticates.
    #[test]
    fn test_request_context_ids() {
        let base = RequestContext::with_id_and_transport("r", TransportType::Http);
        let ctx = base
            .with_user_id("u")
            .with_session_id("s")
            .with_client_id("c");

        assert_eq!(ctx.user_id(), Some("u"));
        assert_eq!(ctx.session_id(), Some("s"));
        assert_eq!(ctx.client_id(), Some("c"));
        assert!(ctx.is_authenticated());
    }

    /// Attaching a principal drives authentication, subject and roles.
    #[test]
    fn test_request_context_principal() {
        let anonymous = RequestContext::with_id_and_transport("1", TransportType::Http);
        assert!(!anonymous.is_authenticated());
        assert!(anonymous.principal().is_none());
        assert!(anonymous.subject().is_none());

        let principal = Principal::new("user-123")
            .with_email("user@example.com")
            .with_role("admin");
        let ctx = anonymous.with_principal(principal);

        assert!(ctx.is_authenticated());
        assert_eq!(ctx.subject(), Some("user-123"));
        assert!(ctx.principal().unwrap().has_role("admin"));
        assert_eq!(ctx.roles(), alloc::vec![String::from("admin")]);
        assert!(ctx.has_any_role(&["admin"]));
        assert!(!ctx.has_any_role(&["root"]));
    }

    /// `Default` yields an empty id, stdio transport and nothing attached.
    #[test]
    fn test_request_context_default() {
        let ctx = RequestContext::default();

        assert!(ctx.request_id.is_empty());
        assert_eq!(ctx.transport, TransportType::Stdio);
        assert!(ctx.metadata.is_empty());
        assert!(!ctx.has_session());
    }

    /// Header lookup ignores ASCII case and reports missing names as `None`.
    #[test]
    fn test_request_context_headers() {
        let mut headers: HashbrownMap<String, String> = HashbrownMap::new();
        headers.insert(String::from("User-Agent"), String::from("Test/1.0"));

        let ctx =
            RequestContext::with_id_and_transport("1", TransportType::Http).with_headers(headers);

        for name in ["user-agent", "USER-AGENT"] {
            assert_eq!(ctx.header(name), Some("Test/1.0"));
        }
        assert_eq!(ctx.header("missing"), None);
    }

    /// Sampling without an attached session is rejected as unsupported.
    #[cfg(feature = "std")]
    #[tokio::test]
    async fn test_sampling_without_session_fails() {
        use turbomcp_types::CreateMessageRequest;

        let ctx = RequestContext::stdio();
        let outcome = ctx.sample(CreateMessageRequest::default()).await;
        let err = outcome.unwrap_err();

        assert_eq!(err.kind, crate::error::ErrorKind::CapabilityNotSupported);
    }
}