1use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
73use std::sync::{Arc, RwLock};
74
75use async_trait::async_trait;
76use tokio::sync::mpsc;
77
78use crate::error::{Error, Result};
79use crate::protocol::{
80 CreateMessageParams, CreateMessageResult, ElicitFormParams, ElicitRequestParams, ElicitResult,
81 ElicitUrlParams, LogLevel, LoggingMessageParams, ProgressParams, ProgressToken, RequestId,
82};
83
/// Notifications that can be queued for delivery to the client.
///
/// Marked `#[non_exhaustive]`, so matches outside this crate need a wildcard arm.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum ServerNotification {
    /// Progress update carrying the progress parameters to deliver.
    Progress(ProgressParams),
    /// Log message carrying the logging parameters to deliver.
    LogMessage(LoggingMessageParams),
    /// A single resource was updated.
    ResourceUpdated {
        /// URI of the resource that was updated.
        uri: String,
    },
    /// The resource list changed.
    ResourcesListChanged,
    /// The tool list changed.
    ToolsListChanged,
    /// The prompt list changed.
    PromptsListChanged,
    /// A task's status changed; carries the task status parameters.
    TaskStatusChanged(crate::protocol::TaskStatusParams),
}
106
107pub type NotificationSender = mpsc::Sender<ServerNotification>;
109
110pub type NotificationReceiver = mpsc::Receiver<ServerNotification>;
112
113pub fn notification_channel(buffer: usize) -> (NotificationSender, NotificationReceiver) {
115 mpsc::channel(buffer)
116}
117
/// Abstraction over "something that can send a request to the client and
/// await its answer".
///
/// Implementors must be `Send + Sync` so a handle can be shared across tasks.
#[async_trait]
pub trait ClientRequester: Send + Sync {
    /// Sends a sampling request built from `params` and returns the client's result.
    async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult>;

    /// Sends an elicitation request built from `params` and returns the client's result.
    async fn elicit(&self, params: ElicitRequestParams) -> Result<ElicitResult>;
}

/// Shared, type-erased handle to a [`ClientRequester`] implementation.
pub type ClientRequesterHandle = Arc<dyn ClientRequester>;
145
/// A request travelling from the server to the client, bundled with the
/// one-shot channel on which its response is to be delivered.
#[derive(Debug)]
pub struct OutgoingRequest {
    /// Identifier of this request.
    pub id: RequestId,
    /// Method name of the request (for example `"sampling/createMessage"`).
    pub method: String,
    /// Request parameters, already serialized to JSON.
    pub params: serde_json::Value,
    /// One-shot sender for the outcome: the JSON result or an error.
    pub response_tx: tokio::sync::oneshot::Sender<Result<serde_json::Value>>,
}
158
159pub type OutgoingRequestSender = mpsc::Sender<OutgoingRequest>;
161
162pub type OutgoingRequestReceiver = mpsc::Receiver<OutgoingRequest>;
164
165pub fn outgoing_request_channel(buffer: usize) -> (OutgoingRequestSender, OutgoingRequestReceiver) {
167 mpsc::channel(buffer)
168}
169
170#[derive(Clone)]
172pub struct ChannelClientRequester {
173 request_tx: OutgoingRequestSender,
174 next_id: Arc<AtomicI64>,
175}
176
177impl ChannelClientRequester {
178 pub fn new(request_tx: OutgoingRequestSender) -> Self {
180 Self {
181 request_tx,
182 next_id: Arc::new(AtomicI64::new(1)),
183 }
184 }
185
186 fn next_request_id(&self) -> RequestId {
187 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
188 RequestId::Number(id)
189 }
190}
191
192#[async_trait]
193impl ClientRequester for ChannelClientRequester {
194 async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult> {
195 let id = self.next_request_id();
196 let params_json = serde_json::to_value(¶ms)
197 .map_err(|e| Error::Internal(format!("Failed to serialize params: {}", e)))?;
198
199 let (response_tx, response_rx) = tokio::sync::oneshot::channel();
200
201 let request = OutgoingRequest {
202 id: id.clone(),
203 method: "sampling/createMessage".to_string(),
204 params: params_json,
205 response_tx,
206 };
207
208 self.request_tx
209 .send(request)
210 .await
211 .map_err(|_| Error::Internal("Failed to send request: channel closed".to_string()))?;
212
213 let response = response_rx.await.map_err(|_| {
214 Error::Internal("Failed to receive response: channel closed".to_string())
215 })??;
216
217 serde_json::from_value(response)
218 .map_err(|e| Error::Internal(format!("Failed to deserialize response: {}", e)))
219 }
220
221 async fn elicit(&self, params: ElicitRequestParams) -> Result<ElicitResult> {
222 let id = self.next_request_id();
223 let params_json = serde_json::to_value(¶ms)
224 .map_err(|e| Error::Internal(format!("Failed to serialize params: {}", e)))?;
225
226 let (response_tx, response_rx) = tokio::sync::oneshot::channel();
227
228 let request = OutgoingRequest {
229 id: id.clone(),
230 method: "elicitation/create".to_string(),
231 params: params_json,
232 response_tx,
233 };
234
235 self.request_tx
236 .send(request)
237 .await
238 .map_err(|_| Error::Internal("Failed to send request: channel closed".to_string()))?;
239
240 let response = response_rx.await.map_err(|_| {
241 Error::Internal("Failed to receive response: channel closed".to_string())
242 })??;
243
244 serde_json::from_value(response)
245 .map_err(|e| Error::Internal(format!("Failed to deserialize response: {}", e)))
246 }
247}
248
/// Per-request state handed to handlers: identity, cancellation flag and
/// optional channels back to the client.
///
/// Cloning is shallow for the shared parts: the cancellation flag and the
/// extensions sit behind `Arc`s.
#[derive(Clone)]
pub struct RequestContext {
    /// Id of the request this context belongs to.
    request_id: RequestId,
    /// Token attached to progress notifications; progress reporting is a no-op when `None`.
    progress_token: Option<ProgressToken>,
    /// Cancellation flag, shared with clones and with tokens from `cancellation_token`.
    cancelled: Arc<AtomicBool>,
    /// Channel for progress and log notifications; nothing is sent when `None`.
    notification_tx: Option<NotificationSender>,
    /// Handle used for sampling and elicitation requests; those fail when `None`.
    client_requester: Option<ClientRequesterHandle>,
    /// Type-keyed values attached to this request.
    extensions: Arc<Extensions>,
    /// Shared log-level threshold consulted by `send_log`; no filtering when `None`.
    min_log_level: Option<Arc<RwLock<LogLevel>>>,
}
267
/// Type-keyed collection of shared values: at most one value per concrete type.
///
/// Values sit behind `Arc`s, so cloning the collection does not clone the values.
#[derive(Clone, Default)]
pub struct Extensions {
    /// Stored values, keyed by the `TypeId` of their concrete type.
    map: std::collections::HashMap<std::any::TypeId, Arc<dyn std::any::Any + Send + Sync>>,
}

impl Extensions {
    /// Creates an empty collection.
    pub fn new() -> Self {
        Self {
            map: Default::default(),
        }
    }

    /// Stores `val`, replacing any previously stored value of the same type `T`.
    pub fn insert<T: Send + Sync + 'static>(&mut self, val: T) {
        let key = std::any::TypeId::of::<T>();
        let shared: Arc<dyn std::any::Any + Send + Sync> = Arc::new(val);
        self.map.insert(key, shared);
    }

    /// Returns the stored value of type `T`, or `None` if there is none.
    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
        let stored = self.map.get(&std::any::TypeId::of::<T>())?;
        stored.downcast_ref::<T>()
    }

    /// Reports whether a value of type `T` is stored.
    pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
        let key = std::any::TypeId::of::<T>();
        self.map.contains_key(&key)
    }

    /// Copies every entry of `other` into `self`; on a type clash the entry
    /// from `other` wins. Only the `Arc` handles are cloned, not the values.
    pub fn merge(&mut self, other: &Extensions) {
        let incoming = other
            .map
            .iter()
            .map(|(type_id, value)| (*type_id, Arc::clone(value)));
        self.map.extend(incoming);
    }
}

impl std::fmt::Debug for Extensions {
    /// Prints only the number of stored entries; the values are type-erased.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Extensions");
        out.field("len", &self.map.len());
        out.finish()
    }
}
321
322impl std::fmt::Debug for RequestContext {
323 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
324 f.debug_struct("RequestContext")
325 .field("request_id", &self.request_id)
326 .field("progress_token", &self.progress_token)
327 .field("cancelled", &self.cancelled.load(Ordering::Relaxed))
328 .finish()
329 }
330}
331
332impl RequestContext {
333 pub fn new(request_id: RequestId) -> Self {
335 Self {
336 request_id,
337 progress_token: None,
338 cancelled: Arc::new(AtomicBool::new(false)),
339 notification_tx: None,
340 client_requester: None,
341 extensions: Arc::new(Extensions::new()),
342 min_log_level: None,
343 }
344 }
345
346 pub fn with_progress_token(mut self, token: ProgressToken) -> Self {
348 self.progress_token = Some(token);
349 self
350 }
351
352 pub fn with_notification_sender(mut self, tx: NotificationSender) -> Self {
354 self.notification_tx = Some(tx);
355 self
356 }
357
358 pub fn with_min_log_level(mut self, level: Arc<RwLock<LogLevel>>) -> Self {
363 self.min_log_level = Some(level);
364 self
365 }
366
367 pub fn with_client_requester(mut self, requester: ClientRequesterHandle) -> Self {
369 self.client_requester = Some(requester);
370 self
371 }
372
373 pub fn with_extensions(mut self, extensions: Arc<Extensions>) -> Self {
377 self.extensions = extensions;
378 self
379 }
380
381 pub fn extension<T: Send + Sync + 'static>(&self) -> Option<&T> {
397 self.extensions.get::<T>()
398 }
399
400 pub fn extensions_mut(&mut self) -> &mut Extensions {
405 Arc::make_mut(&mut self.extensions)
406 }
407
408 pub fn extensions(&self) -> &Extensions {
410 &self.extensions
411 }
412
413 pub fn request_id(&self) -> &RequestId {
415 &self.request_id
416 }
417
418 pub fn progress_token(&self) -> Option<&ProgressToken> {
420 self.progress_token.as_ref()
421 }
422
423 pub fn is_cancelled(&self) -> bool {
425 self.cancelled.load(Ordering::Relaxed)
426 }
427
428 pub fn cancel(&self) {
430 self.cancelled.store(true, Ordering::Relaxed);
431 }
432
433 pub fn cancellation_token(&self) -> CancellationToken {
435 CancellationToken {
436 cancelled: self.cancelled.clone(),
437 }
438 }
439
440 pub async fn report_progress(&self, progress: f64, total: Option<f64>, message: Option<&str>) {
444 let Some(token) = &self.progress_token else {
445 return;
446 };
447 let Some(tx) = &self.notification_tx else {
448 return;
449 };
450
451 let params = ProgressParams {
452 progress_token: token.clone(),
453 progress,
454 total,
455 message: message.map(|s| s.to_string()),
456 meta: None,
457 };
458
459 let _ = tx.try_send(ServerNotification::Progress(params));
461 }
462
463 pub fn report_progress_sync(&self, progress: f64, total: Option<f64>, message: Option<&str>) {
467 let Some(token) = &self.progress_token else {
468 return;
469 };
470 let Some(tx) = &self.notification_tx else {
471 return;
472 };
473
474 let params = ProgressParams {
475 progress_token: token.clone(),
476 progress,
477 total,
478 message: message.map(|s| s.to_string()),
479 meta: None,
480 };
481
482 let _ = tx.try_send(ServerNotification::Progress(params));
483 }
484
485 pub fn send_log(&self, params: LoggingMessageParams) {
502 let Some(tx) = &self.notification_tx else {
503 return;
504 };
505
506 if let Some(min_level) = &self.min_log_level
511 && let Ok(min) = min_level.read()
512 && params.level > *min
513 {
514 return;
515 }
516
517 let _ = tx.try_send(ServerNotification::LogMessage(params));
518 }
519
520 pub fn can_sample(&self) -> bool {
525 self.client_requester.is_some()
526 }
527
528 pub async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult> {
552 let requester = self.client_requester.as_ref().ok_or_else(|| {
553 Error::Internal("Sampling not available: no client requester configured".to_string())
554 })?;
555
556 requester.sample(params).await
557 }
558
559 pub fn can_elicit(&self) -> bool {
565 self.client_requester.is_some()
566 }
567
568 pub async fn elicit_form(&self, params: ElicitFormParams) -> Result<ElicitResult> {
600 let requester = self.client_requester.as_ref().ok_or_else(|| {
601 Error::Internal("Elicitation not available: no client requester configured".to_string())
602 })?;
603
604 requester.elicit(ElicitRequestParams::Form(params)).await
605 }
606
607 pub async fn elicit_url(&self, params: ElicitUrlParams) -> Result<ElicitResult> {
637 let requester = self.client_requester.as_ref().ok_or_else(|| {
638 Error::Internal("Elicitation not available: no client requester configured".to_string())
639 })?;
640
641 requester.elicit(ElicitRequestParams::Url(params)).await
642 }
643
644 pub async fn confirm(&self, message: impl Into<String>) -> Result<bool> {
668 use crate::protocol::{ElicitAction, ElicitFormParams, ElicitFormSchema, ElicitMode};
669
670 let params = ElicitFormParams {
671 mode: Some(ElicitMode::Form),
672 message: message.into(),
673 requested_schema: ElicitFormSchema::new().boolean_field_with_default(
674 "confirm",
675 Some("Confirm this action"),
676 true,
677 false,
678 ),
679 meta: None,
680 };
681
682 let result = self.elicit_form(params).await?;
683 Ok(result.action == ElicitAction::Accept)
684 }
685}
686
/// Handle to a shared cancellation flag.
///
/// Clones observe and set the same flag.
#[derive(Clone, Debug)]
pub struct CancellationToken {
    /// The shared flag; `true` once cancellation was requested.
    cancelled: Arc<AtomicBool>,
}

impl CancellationToken {
    /// Reports whether cancellation has been requested.
    pub fn is_cancelled(&self) -> bool {
        let flag: &AtomicBool = &self.cancelled;
        flag.load(Ordering::Relaxed)
    }

    /// Requests cancellation by setting the shared flag.
    pub fn cancel(&self) {
        let flag: &AtomicBool = &self.cancelled;
        flag.store(true, Ordering::Relaxed);
    }
}
704
705#[derive(Default)]
707pub struct RequestContextBuilder {
708 request_id: Option<RequestId>,
709 progress_token: Option<ProgressToken>,
710 notification_tx: Option<NotificationSender>,
711 client_requester: Option<ClientRequesterHandle>,
712 min_log_level: Option<Arc<RwLock<LogLevel>>>,
713}
714
715impl RequestContextBuilder {
716 pub fn new() -> Self {
718 Self::default()
719 }
720
721 pub fn request_id(mut self, id: RequestId) -> Self {
723 self.request_id = Some(id);
724 self
725 }
726
727 pub fn progress_token(mut self, token: ProgressToken) -> Self {
729 self.progress_token = Some(token);
730 self
731 }
732
733 pub fn notification_sender(mut self, tx: NotificationSender) -> Self {
735 self.notification_tx = Some(tx);
736 self
737 }
738
739 pub fn client_requester(mut self, requester: ClientRequesterHandle) -> Self {
741 self.client_requester = Some(requester);
742 self
743 }
744
745 pub fn min_log_level(mut self, level: Arc<RwLock<LogLevel>>) -> Self {
747 self.min_log_level = Some(level);
748 self
749 }
750
751 pub fn build(self) -> RequestContext {
755 let mut ctx = RequestContext::new(self.request_id.expect("request_id is required"));
756 if let Some(token) = self.progress_token {
757 ctx = ctx.with_progress_token(token);
758 }
759 if let Some(tx) = self.notification_tx {
760 ctx = ctx.with_notification_sender(tx);
761 }
762 if let Some(requester) = self.client_requester {
763 ctx = ctx.with_client_requester(requester);
764 }
765 if let Some(level) = self.min_log_level {
766 ctx = ctx.with_min_log_level(level);
767 }
768 ctx
769 }
770}
771
/// Unit tests for the request context, its builder and its log filtering.
#[cfg(test)]
mod tests {
    use super::*;

    /// Cancelling via the context is visible on the context and on a token
    /// obtained before the cancellation.
    #[test]
    fn test_cancellation() {
        let ctx = RequestContext::new(RequestId::Number(1));
        assert!(!ctx.is_cancelled());

        let token = ctx.cancellation_token();
        assert!(!token.is_cancelled());

        ctx.cancel();
        assert!(ctx.is_cancelled());
        assert!(token.is_cancelled());
    }

    /// With a token and a channel configured, a progress report arrives as a
    /// `Progress` notification carrying the reported values.
    #[tokio::test]
    async fn test_progress_reporting() {
        let (tx, mut rx) = notification_channel(10);

        let ctx = RequestContext::new(RequestId::Number(1))
            .with_progress_token(ProgressToken::Number(42))
            .with_notification_sender(tx);

        ctx.report_progress(50.0, Some(100.0), Some("Halfway"))
            .await;

        let notification = rx.recv().await.unwrap();
        match notification {
            ServerNotification::Progress(params) => {
                assert_eq!(params.progress, 50.0);
                assert_eq!(params.total, Some(100.0));
                assert_eq!(params.message.as_deref(), Some("Halfway"));
            }
            _ => panic!("Expected Progress notification"),
        }
    }

    /// Without a progress token, reporting progress queues nothing.
    #[tokio::test]
    async fn test_progress_no_token() {
        let (tx, mut rx) = notification_channel(10);

        let ctx = RequestContext::new(RequestId::Number(1)).with_notification_sender(tx);

        ctx.report_progress(50.0, Some(100.0), None).await;

        // The channel must still be empty.
        assert!(rx.try_recv().is_err());
    }

    /// The builder carries the request id and progress token into the context.
    #[test]
    fn test_builder() {
        let (tx, _rx) = notification_channel(10);

        let ctx = RequestContextBuilder::new()
            .request_id(RequestId::String("req-1".to_string()))
            .progress_token(ProgressToken::String("prog-1".to_string()))
            .notification_sender(tx)
            .build();

        assert_eq!(ctx.request_id(), &RequestId::String("req-1".to_string()));
        assert!(ctx.progress_token().is_some());
    }

    /// A bare context cannot sample.
    #[test]
    fn test_can_sample_without_requester() {
        let ctx = RequestContext::new(RequestId::Number(1));
        assert!(!ctx.can_sample());
    }

    /// A context with a client requester can sample.
    #[test]
    fn test_can_sample_with_requester() {
        let (request_tx, _rx) = outgoing_request_channel(10);
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));

        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
        assert!(ctx.can_sample());
    }

    /// Sampling without a requester fails with the "Sampling not available" error.
    #[tokio::test]
    async fn test_sample_without_requester_fails() {
        use crate::protocol::{CreateMessageParams, SamplingMessage};

        let ctx = RequestContext::new(RequestId::Number(1));
        let params = CreateMessageParams::new(vec![SamplingMessage::user("test")], 100);

        let result = ctx.sample(params).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Sampling not available")
        );
    }

    /// A client requester set on the builder reaches the built context.
    #[test]
    fn test_builder_with_client_requester() {
        let (request_tx, _rx) = outgoing_request_channel(10);
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));

        let ctx = RequestContextBuilder::new()
            .request_id(RequestId::Number(1))
            .client_requester(requester)
            .build();

        assert!(ctx.can_sample());
    }

    /// A bare context cannot elicit.
    #[test]
    fn test_can_elicit_without_requester() {
        let ctx = RequestContext::new(RequestId::Number(1));
        assert!(!ctx.can_elicit());
    }

    /// A context with a client requester can elicit.
    #[test]
    fn test_can_elicit_with_requester() {
        let (request_tx, _rx) = outgoing_request_channel(10);
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));

        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
        assert!(ctx.can_elicit());
    }

    /// Form elicitation without a requester fails with the "Elicitation not available" error.
    #[tokio::test]
    async fn test_elicit_form_without_requester_fails() {
        use crate::protocol::{ElicitFormSchema, ElicitMode};

        let ctx = RequestContext::new(RequestId::Number(1));
        let params = ElicitFormParams {
            mode: Some(ElicitMode::Form),
            message: "Enter details".to_string(),
            requested_schema: ElicitFormSchema::new().string_field("name", None, true),
            meta: None,
        };

        let result = ctx.elicit_form(params).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Elicitation not available")
        );
    }

    /// URL elicitation without a requester fails with the "Elicitation not available" error.
    #[tokio::test]
    async fn test_elicit_url_without_requester_fails() {
        use crate::protocol::ElicitMode;

        let ctx = RequestContext::new(RequestId::Number(1));
        let params = ElicitUrlParams {
            mode: Some(ElicitMode::Url),
            elicitation_id: "test-123".to_string(),
            message: "Please authorize".to_string(),
            url: "https://example.com/auth".to_string(),
            meta: None,
        };

        let result = ctx.elicit_url(params).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Elicitation not available")
        );
    }

    /// `confirm` goes through form elicitation, so it fails the same way
    /// when no requester is configured.
    #[tokio::test]
    async fn test_confirm_without_requester_fails() {
        let ctx = RequestContext::new(RequestId::Number(1));

        let result = ctx.confirm("Are you sure?").await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Elicitation not available")
        );
    }

    /// With the threshold at `Warning`, `Error` and `Warning` messages are
    /// queued while `Info` and `Debug` messages are dropped.
    #[tokio::test]
    async fn test_send_log_filtered_by_level() {
        let (tx, mut rx) = notification_channel(10);
        let min_level = Arc::new(RwLock::new(LogLevel::Warning));

        let ctx = RequestContext::new(RequestId::Number(1))
            .with_notification_sender(tx)
            .with_min_log_level(min_level.clone());

        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Error,
            serde_json::Value::Null,
        ));
        let msg = rx.try_recv();
        assert!(msg.is_ok(), "Error should pass through Warning filter");

        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Warning,
            serde_json::Value::Null,
        ));
        let msg = rx.try_recv();
        assert!(msg.is_ok(), "Warning should pass through Warning filter");

        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Info,
            serde_json::Value::Null,
        ));
        let msg = rx.try_recv();
        assert!(msg.is_err(), "Info should be filtered by Warning filter");

        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Debug,
            serde_json::Value::Null,
        ));
        let msg = rx.try_recv();
        assert!(msg.is_err(), "Debug should be filtered by Warning filter");
    }

    /// Changing the shared threshold after the context was built changes
    /// what `send_log` lets through.
    #[tokio::test]
    async fn test_send_log_level_updates_dynamically() {
        let (tx, mut rx) = notification_channel(10);
        let min_level = Arc::new(RwLock::new(LogLevel::Error));

        let ctx = RequestContext::new(RequestId::Number(1))
            .with_notification_sender(tx)
            .with_min_log_level(min_level.clone());

        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Info,
            serde_json::Value::Null,
        ));
        assert!(
            rx.try_recv().is_err(),
            "Info should be filtered at Error level"
        );

        // Lower the threshold through the shared handle.
        *min_level.write().unwrap() = LogLevel::Debug;

        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Info,
            serde_json::Value::Null,
        ));
        assert!(
            rx.try_recv().is_ok(),
            "Info should pass through after level changed to Debug"
        );
    }

    /// Without a configured threshold nothing is filtered.
    #[tokio::test]
    async fn test_send_log_no_min_level_sends_all() {
        let (tx, mut rx) = notification_channel(10);

        let ctx = RequestContext::new(RequestId::Number(1)).with_notification_sender(tx);

        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Debug,
            serde_json::Value::Null,
        ));
        assert!(
            rx.try_recv().is_ok(),
            "Debug should pass when no min level is set"
        );
    }
}
1048}