1use std::sync::Arc;
73use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
74
75use async_trait::async_trait;
76use tokio::sync::mpsc;
77
78use crate::error::{Error, Result};
79use crate::protocol::{
80 CreateMessageParams, CreateMessageResult, ElicitFormParams, ElicitRequestParams, ElicitResult,
81 ElicitUrlParams, LoggingMessageParams, ProgressParams, ProgressToken, RequestId,
82};
83
/// Notification messages carried by the notification channel
/// (see [`notification_channel`]).
#[derive(Debug, Clone)]
pub enum ServerNotification {
    /// Progress update; produced by [`RequestContext::report_progress`] and
    /// [`RequestContext::report_progress_sync`].
    Progress(ProgressParams),
    /// Log message; produced by [`RequestContext::send_log`].
    LogMessage(LoggingMessageParams),
    /// Notification about a single resource. Not constructed in this file.
    ResourceUpdated {
        /// URI of the resource the notification refers to.
        uri: String,
    },
    /// NOTE(review): presumably signals that the set of available resources
    /// changed; not constructed in this file — confirm with the producer.
    ResourcesListChanged,
}
99
/// Sending half of the bounded tokio `mpsc` channel carrying [`ServerNotification`]s.
pub type NotificationSender = mpsc::Sender<ServerNotification>;

/// Receiving half of the bounded tokio `mpsc` channel carrying [`ServerNotification`]s.
pub type NotificationReceiver = mpsc::Receiver<ServerNotification>;
105
106pub fn notification_channel(buffer: usize) -> (NotificationSender, NotificationReceiver) {
108 mpsc::channel(buffer)
109}
110
/// Interface for issuing sampling and elicitation requests and awaiting the
/// typed result.
///
/// Implementations must be `Send + Sync`; [`ClientRequesterHandle`] shares
/// one behind an `Arc`.
#[async_trait]
pub trait ClientRequester: Send + Sync {
    /// Issues a sampling request built from `params` and resolves to its result.
    async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult>;

    /// Issues an elicitation request built from `params` and resolves to its result.
    async fn elicit(&self, params: ElicitRequestParams) -> Result<ElicitResult>;
}
135
/// Shared, type-erased handle to a [`ClientRequester`] implementation.
pub type ClientRequesterHandle = Arc<dyn ClientRequester>;
138
/// A request queued for sending, bundled with the one-shot channel on which
/// its outcome is handed back to the waiting caller.
#[derive(Debug)]
pub struct OutgoingRequest {
    /// Identifier assigned to this request.
    pub id: RequestId,
    /// Method name, e.g. `"sampling/createMessage"` or `"elicitation/create"`.
    pub method: String,
    /// Request parameters as a JSON value.
    pub params: serde_json::Value,
    /// One-shot sender on which whoever processes this request reports the
    /// JSON response (or an error) back to the caller awaiting it.
    pub response_tx: tokio::sync::oneshot::Sender<Result<serde_json::Value>>,
}
151
/// Sending half of the bounded tokio `mpsc` channel carrying [`OutgoingRequest`]s.
pub type OutgoingRequestSender = mpsc::Sender<OutgoingRequest>;

/// Receiving half of the bounded tokio `mpsc` channel carrying [`OutgoingRequest`]s.
pub type OutgoingRequestReceiver = mpsc::Receiver<OutgoingRequest>;
157
158pub fn outgoing_request_channel(buffer: usize) -> (OutgoingRequestSender, OutgoingRequestReceiver) {
160 mpsc::channel(buffer)
161}
162
/// [`ClientRequester`] implementation that queues each request on an
/// [`OutgoingRequestSender`] and waits for the reply on a per-request
/// one-shot channel.
#[derive(Clone)]
pub struct ChannelClientRequester {
    /// Channel on which outgoing requests are queued.
    request_tx: OutgoingRequestSender,
    /// Counter for numeric request ids. Held in an `Arc`, so clones of this
    /// requester draw ids from the same sequence.
    next_id: Arc<AtomicI64>,
}
169
170impl ChannelClientRequester {
171 pub fn new(request_tx: OutgoingRequestSender) -> Self {
173 Self {
174 request_tx,
175 next_id: Arc::new(AtomicI64::new(1)),
176 }
177 }
178
179 fn next_request_id(&self) -> RequestId {
180 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
181 RequestId::Number(id)
182 }
183}
184
185#[async_trait]
186impl ClientRequester for ChannelClientRequester {
187 async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult> {
188 let id = self.next_request_id();
189 let params_json = serde_json::to_value(¶ms)
190 .map_err(|e| Error::Internal(format!("Failed to serialize params: {}", e)))?;
191
192 let (response_tx, response_rx) = tokio::sync::oneshot::channel();
193
194 let request = OutgoingRequest {
195 id: id.clone(),
196 method: "sampling/createMessage".to_string(),
197 params: params_json,
198 response_tx,
199 };
200
201 self.request_tx
202 .send(request)
203 .await
204 .map_err(|_| Error::Internal("Failed to send request: channel closed".to_string()))?;
205
206 let response = response_rx.await.map_err(|_| {
207 Error::Internal("Failed to receive response: channel closed".to_string())
208 })??;
209
210 serde_json::from_value(response)
211 .map_err(|e| Error::Internal(format!("Failed to deserialize response: {}", e)))
212 }
213
214 async fn elicit(&self, params: ElicitRequestParams) -> Result<ElicitResult> {
215 let id = self.next_request_id();
216 let params_json = serde_json::to_value(¶ms)
217 .map_err(|e| Error::Internal(format!("Failed to serialize params: {}", e)))?;
218
219 let (response_tx, response_rx) = tokio::sync::oneshot::channel();
220
221 let request = OutgoingRequest {
222 id: id.clone(),
223 method: "elicitation/create".to_string(),
224 params: params_json,
225 response_tx,
226 };
227
228 self.request_tx
229 .send(request)
230 .await
231 .map_err(|_| Error::Internal("Failed to send request: channel closed".to_string()))?;
232
233 let response = response_rx.await.map_err(|_| {
234 Error::Internal("Failed to receive response: channel closed".to_string())
235 })??;
236
237 serde_json::from_value(response)
238 .map_err(|e| Error::Internal(format!("Failed to deserialize response: {}", e)))
239 }
240}
241
/// Per-request state: the request's identity, its cancellation flag, and the
/// optional channels used to send notifications and issue client requests.
///
/// Clones share the same cancellation flag and (until mutated through
/// `extensions_mut`) the same extensions, because both live behind `Arc`s.
#[derive(Clone)]
pub struct RequestContext {
    /// Identifier of the request this context belongs to.
    request_id: RequestId,
    /// Token attached to progress notifications; progress is not reported when `None`.
    progress_token: Option<ProgressToken>,
    /// Cancellation flag shared with clones and with issued `CancellationToken`s.
    cancelled: Arc<AtomicBool>,
    /// Channel for outgoing notifications; notifications are skipped when `None`.
    notification_tx: Option<NotificationSender>,
    /// Handle used for sampling and elicitation; both are unavailable when `None`.
    client_requester: Option<ClientRequesterHandle>,
    /// Type-keyed extension values attached to this request.
    extensions: Arc<Extensions>,
}
258
/// Type-keyed storage holding at most one value per Rust type.
///
/// Values are stored behind `Arc`, so cloning the container clones the
/// pointers, not the values.
#[derive(Clone, Default)]
pub struct Extensions {
    /// One entry per stored type, keyed by that type's `TypeId`.
    map: std::collections::HashMap<std::any::TypeId, Arc<dyn std::any::Any + Send + Sync>>,
}

impl Extensions {
    /// Creates an empty container.
    pub fn new() -> Self {
        Default::default()
    }

    /// Stores `val`, replacing any value of the same type already present.
    pub fn insert<T: Send + Sync + 'static>(&mut self, val: T) {
        let key = std::any::TypeId::of::<T>();
        let stored: Arc<dyn std::any::Any + Send + Sync> = Arc::new(val);
        self.map.insert(key, stored);
    }

    /// Returns a reference to the stored value of type `T`, if there is one.
    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
        let key = std::any::TypeId::of::<T>();
        let stored = self.map.get(&key)?;
        stored.downcast_ref::<T>()
    }

    /// Reports whether a value of type `T` is stored.
    pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
        let key = std::any::TypeId::of::<T>();
        self.map.contains_key(&key)
    }

    /// Copies every entry of `other` into `self`; on a type collision the
    /// entry from `other` wins.
    pub fn merge(&mut self, other: &Extensions) {
        let incoming = other
            .map
            .iter()
            .map(|(type_id, value)| (*type_id, Arc::clone(value)));
        self.map.extend(incoming);
    }
}

impl std::fmt::Debug for Extensions {
    /// Prints only the number of stored entries; the values themselves are
    /// type-erased and not required to be `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let entry_count = self.map.len();
        let mut out = f.debug_struct("Extensions");
        out.field("len", &entry_count);
        out.finish()
    }
}
312
313impl std::fmt::Debug for RequestContext {
314 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
315 f.debug_struct("RequestContext")
316 .field("request_id", &self.request_id)
317 .field("progress_token", &self.progress_token)
318 .field("cancelled", &self.cancelled.load(Ordering::Relaxed))
319 .finish()
320 }
321}
322
323impl RequestContext {
324 pub fn new(request_id: RequestId) -> Self {
326 Self {
327 request_id,
328 progress_token: None,
329 cancelled: Arc::new(AtomicBool::new(false)),
330 notification_tx: None,
331 client_requester: None,
332 extensions: Arc::new(Extensions::new()),
333 }
334 }
335
336 pub fn with_progress_token(mut self, token: ProgressToken) -> Self {
338 self.progress_token = Some(token);
339 self
340 }
341
342 pub fn with_notification_sender(mut self, tx: NotificationSender) -> Self {
344 self.notification_tx = Some(tx);
345 self
346 }
347
348 pub fn with_client_requester(mut self, requester: ClientRequesterHandle) -> Self {
350 self.client_requester = Some(requester);
351 self
352 }
353
354 pub fn with_extensions(mut self, extensions: Arc<Extensions>) -> Self {
358 self.extensions = extensions;
359 self
360 }
361
362 pub fn extension<T: Send + Sync + 'static>(&self) -> Option<&T> {
378 self.extensions.get::<T>()
379 }
380
381 pub fn extensions_mut(&mut self) -> &mut Extensions {
386 Arc::make_mut(&mut self.extensions)
387 }
388
389 pub fn extensions(&self) -> &Extensions {
391 &self.extensions
392 }
393
394 pub fn request_id(&self) -> &RequestId {
396 &self.request_id
397 }
398
399 pub fn progress_token(&self) -> Option<&ProgressToken> {
401 self.progress_token.as_ref()
402 }
403
404 pub fn is_cancelled(&self) -> bool {
406 self.cancelled.load(Ordering::Relaxed)
407 }
408
409 pub fn cancel(&self) {
411 self.cancelled.store(true, Ordering::Relaxed);
412 }
413
414 pub fn cancellation_token(&self) -> CancellationToken {
416 CancellationToken {
417 cancelled: self.cancelled.clone(),
418 }
419 }
420
421 pub async fn report_progress(&self, progress: f64, total: Option<f64>, message: Option<&str>) {
425 let Some(token) = &self.progress_token else {
426 return;
427 };
428 let Some(tx) = &self.notification_tx else {
429 return;
430 };
431
432 let params = ProgressParams {
433 progress_token: token.clone(),
434 progress,
435 total,
436 message: message.map(|s| s.to_string()),
437 };
438
439 let _ = tx.try_send(ServerNotification::Progress(params));
441 }
442
443 pub fn report_progress_sync(&self, progress: f64, total: Option<f64>, message: Option<&str>) {
447 let Some(token) = &self.progress_token else {
448 return;
449 };
450 let Some(tx) = &self.notification_tx else {
451 return;
452 };
453
454 let params = ProgressParams {
455 progress_token: token.clone(),
456 progress,
457 total,
458 message: message.map(|s| s.to_string()),
459 };
460
461 let _ = tx.try_send(ServerNotification::Progress(params));
462 }
463
464 pub fn send_log(&self, params: LoggingMessageParams) {
482 let Some(tx) = &self.notification_tx else {
483 return;
484 };
485
486 let _ = tx.try_send(ServerNotification::LogMessage(params));
487 }
488
489 pub fn can_sample(&self) -> bool {
494 self.client_requester.is_some()
495 }
496
497 pub async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult> {
521 let requester = self.client_requester.as_ref().ok_or_else(|| {
522 Error::Internal("Sampling not available: no client requester configured".to_string())
523 })?;
524
525 requester.sample(params).await
526 }
527
528 pub fn can_elicit(&self) -> bool {
534 self.client_requester.is_some()
535 }
536
537 pub async fn elicit_form(&self, params: ElicitFormParams) -> Result<ElicitResult> {
569 let requester = self.client_requester.as_ref().ok_or_else(|| {
570 Error::Internal("Elicitation not available: no client requester configured".to_string())
571 })?;
572
573 requester.elicit(ElicitRequestParams::Form(params)).await
574 }
575
576 pub async fn elicit_url(&self, params: ElicitUrlParams) -> Result<ElicitResult> {
606 let requester = self.client_requester.as_ref().ok_or_else(|| {
607 Error::Internal("Elicitation not available: no client requester configured".to_string())
608 })?;
609
610 requester.elicit(ElicitRequestParams::Url(params)).await
611 }
612
613 pub async fn confirm(&self, message: impl Into<String>) -> Result<bool> {
637 use crate::protocol::{ElicitAction, ElicitFormParams, ElicitFormSchema, ElicitMode};
638
639 let params = ElicitFormParams {
640 mode: ElicitMode::Form,
641 message: message.into(),
642 requested_schema: ElicitFormSchema::new().boolean_field_with_default(
643 "confirm",
644 Some("Confirm this action"),
645 true,
646 false,
647 ),
648 meta: None,
649 };
650
651 let result = self.elicit_form(params).await?;
652 Ok(result.action == ElicitAction::Accept)
653 }
654}
655
/// Cloneable handle onto a shared cancellation flag.
///
/// Every clone refers to the same flag, so cancelling through one handle is
/// visible through all of them.
#[derive(Clone, Debug)]
pub struct CancellationToken {
    /// The shared flag; `true` once cancellation was requested.
    cancelled: Arc<AtomicBool>,
}

impl CancellationToken {
    /// Reports whether cancellation has been requested.
    pub fn is_cancelled(&self) -> bool {
        let flag: &AtomicBool = &self.cancelled;
        flag.load(Ordering::Relaxed)
    }

    /// Requests cancellation by setting the shared flag. Calling it again
    /// leaves the flag set.
    pub fn cancel(&self) {
        let flag: &AtomicBool = &self.cancelled;
        flag.store(true, Ordering::Relaxed);
    }
}
673
/// Step-by-step builder for [`RequestContext`].
///
/// All fields start as `None`; `request_id` must be set before `build`.
#[derive(Default)]
pub struct RequestContextBuilder {
    /// Id of the request; mandatory.
    request_id: Option<RequestId>,
    /// Optional token for progress notifications.
    progress_token: Option<ProgressToken>,
    /// Optional channel for outgoing notifications.
    notification_tx: Option<NotificationSender>,
    /// Optional handle for sampling and elicitation requests.
    client_requester: Option<ClientRequesterHandle>,
}
682
683impl RequestContextBuilder {
684 pub fn new() -> Self {
686 Self::default()
687 }
688
689 pub fn request_id(mut self, id: RequestId) -> Self {
691 self.request_id = Some(id);
692 self
693 }
694
695 pub fn progress_token(mut self, token: ProgressToken) -> Self {
697 self.progress_token = Some(token);
698 self
699 }
700
701 pub fn notification_sender(mut self, tx: NotificationSender) -> Self {
703 self.notification_tx = Some(tx);
704 self
705 }
706
707 pub fn client_requester(mut self, requester: ClientRequesterHandle) -> Self {
709 self.client_requester = Some(requester);
710 self
711 }
712
713 pub fn build(self) -> RequestContext {
717 let mut ctx = RequestContext::new(self.request_id.expect("request_id is required"));
718 if let Some(token) = self.progress_token {
719 ctx = ctx.with_progress_token(token);
720 }
721 if let Some(tx) = self.notification_tx {
722 ctx = ctx.with_notification_sender(tx);
723 }
724 if let Some(requester) = self.client_requester {
725 ctx = ctx.with_client_requester(requester);
726 }
727 ctx
728 }
729}
730
// Unit tests for the request context, its builder, and the cancellation token.
#[cfg(test)]
mod tests {
    use super::*;

    // Cancelling through the context is visible on the context and on a
    // token obtained from it beforehand.
    #[test]
    fn test_cancellation() {
        let ctx = RequestContext::new(RequestId::Number(1));
        assert!(!ctx.is_cancelled());

        let token = ctx.cancellation_token();
        assert!(!token.is_cancelled());

        ctx.cancel();
        assert!(ctx.is_cancelled());
        assert!(token.is_cancelled());
    }

    // With both a token and a sender configured, a progress report arrives
    // on the channel with the values that were passed in.
    #[tokio::test]
    async fn test_progress_reporting() {
        let (tx, mut rx) = notification_channel(10);

        let ctx = RequestContext::new(RequestId::Number(1))
            .with_progress_token(ProgressToken::Number(42))
            .with_notification_sender(tx);

        ctx.report_progress(50.0, Some(100.0), Some("Halfway"))
            .await;

        let notification = rx.recv().await.unwrap();
        match notification {
            ServerNotification::Progress(params) => {
                assert_eq!(params.progress, 50.0);
                assert_eq!(params.total, Some(100.0));
                assert_eq!(params.message.as_deref(), Some("Halfway"));
            }
            _ => panic!("Expected Progress notification"),
        }
    }

    // Without a progress token nothing is queued, even though a sender exists.
    #[tokio::test]
    async fn test_progress_no_token() {
        let (tx, mut rx) = notification_channel(10);

        let ctx = RequestContext::new(RequestId::Number(1)).with_notification_sender(tx);

        ctx.report_progress(50.0, Some(100.0), None).await;

        assert!(rx.try_recv().is_err());
    }

    // The builder carries the request id and progress token into the context.
    #[test]
    fn test_builder() {
        let (tx, _rx) = notification_channel(10);

        let ctx = RequestContextBuilder::new()
            .request_id(RequestId::String("req-1".to_string()))
            .progress_token(ProgressToken::String("prog-1".to_string()))
            .notification_sender(tx)
            .build();

        assert_eq!(ctx.request_id(), &RequestId::String("req-1".to_string()));
        assert!(ctx.progress_token().is_some());
    }

    // A fresh context has no client requester, so sampling is unavailable.
    #[test]
    fn test_can_sample_without_requester() {
        let ctx = RequestContext::new(RequestId::Number(1));
        assert!(!ctx.can_sample());
    }

    // Attaching a requester makes sampling available.
    #[test]
    fn test_can_sample_with_requester() {
        let (request_tx, _rx) = outgoing_request_channel(10);
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));

        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
        assert!(ctx.can_sample());
    }

    // `sample` without a requester fails with the "Sampling not available" error.
    #[tokio::test]
    async fn test_sample_without_requester_fails() {
        use crate::protocol::{CreateMessageParams, SamplingMessage};

        let ctx = RequestContext::new(RequestId::Number(1));
        let params = CreateMessageParams::new(vec![SamplingMessage::user("test")], 100);

        let result = ctx.sample(params).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Sampling not available")
        );
    }

    // A requester supplied through the builder ends up on the built context.
    #[test]
    fn test_builder_with_client_requester() {
        let (request_tx, _rx) = outgoing_request_channel(10);
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));

        let ctx = RequestContextBuilder::new()
            .request_id(RequestId::Number(1))
            .client_requester(requester)
            .build();

        assert!(ctx.can_sample());
    }

    // A fresh context has no client requester, so elicitation is unavailable.
    #[test]
    fn test_can_elicit_without_requester() {
        let ctx = RequestContext::new(RequestId::Number(1));
        assert!(!ctx.can_elicit());
    }

    // Attaching a requester makes elicitation available.
    #[test]
    fn test_can_elicit_with_requester() {
        let (request_tx, _rx) = outgoing_request_channel(10);
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));

        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
        assert!(ctx.can_elicit());
    }

    // `elicit_form` without a requester fails with the "Elicitation not available" error.
    #[tokio::test]
    async fn test_elicit_form_without_requester_fails() {
        use crate::protocol::{ElicitFormSchema, ElicitMode};

        let ctx = RequestContext::new(RequestId::Number(1));
        let params = ElicitFormParams {
            mode: ElicitMode::Form,
            message: "Enter details".to_string(),
            requested_schema: ElicitFormSchema::new().string_field("name", None, true),
            meta: None,
        };

        let result = ctx.elicit_form(params).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Elicitation not available")
        );
    }

    // `elicit_url` without a requester fails with the "Elicitation not available" error.
    #[tokio::test]
    async fn test_elicit_url_without_requester_fails() {
        use crate::protocol::ElicitMode;

        let ctx = RequestContext::new(RequestId::Number(1));
        let params = ElicitUrlParams {
            mode: ElicitMode::Url,
            elicitation_id: "test-123".to_string(),
            message: "Please authorize".to_string(),
            url: "https://example.com/auth".to_string(),
            meta: None,
        };

        let result = ctx.elicit_url(params).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Elicitation not available")
        );
    }

    // `confirm` goes through `elicit_form`, so it fails the same way without a requester.
    #[tokio::test]
    async fn test_confirm_without_requester_fails() {
        let ctx = RequestContext::new(RequestId::Number(1));

        let result = ctx.confirm("Are you sure?").await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Elicitation not available")
        );
    }
}