1use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
73use std::sync::{Arc, RwLock};
74
75use async_trait::async_trait;
76use tokio::sync::mpsc;
77
78use crate::error::{Error, Result};
79use crate::protocol::{
80 CallToolResult, CancelTaskParams, CreateMessageParams, CreateMessageResult, ElicitFormParams,
81 ElicitRequestParams, ElicitResult, ElicitUrlParams, GetTaskInfoParams, GetTaskResultParams,
82 ListTasksParams, ListTasksResult, LogLevel, LoggingMessageParams, ProgressParams,
83 ProgressToken, RequestId, TaskObject, TaskStatus,
84};
85
86#[derive(Debug, Clone)]
88#[non_exhaustive]
89pub enum ServerNotification {
90 Progress(ProgressParams),
92 LogMessage(LoggingMessageParams),
94 ResourceUpdated {
96 uri: String,
98 },
99 ResourcesListChanged,
101 ToolsListChanged,
103 PromptsListChanged,
105 TaskStatusChanged(crate::protocol::TaskStatusParams),
107}
108
109pub type NotificationSender = mpsc::Sender<ServerNotification>;
111
112pub type NotificationReceiver = mpsc::Receiver<ServerNotification>;
114
115pub fn notification_channel(buffer: usize) -> (NotificationSender, NotificationReceiver) {
117 mpsc::channel(buffer)
118}
119
120#[async_trait]
130pub trait ClientRequester: Send + Sync {
131 async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult>;
135
136 async fn elicit(&self, params: ElicitRequestParams) -> Result<ElicitResult>;
143
144 async fn request(
152 &self,
153 method: String,
154 params: serde_json::Value,
155 ) -> Result<serde_json::Value> {
156 let _ = (method, params);
157 Err(Error::Internal(
158 "ClientRequester does not support arbitrary requests".to_string(),
159 ))
160 }
161}
162
163pub type ClientRequesterHandle = Arc<dyn ClientRequester>;
165
166#[derive(Debug)]
168pub struct OutgoingRequest {
169 pub id: RequestId,
171 pub method: String,
173 pub params: serde_json::Value,
175 pub response_tx: tokio::sync::oneshot::Sender<Result<serde_json::Value>>,
177}
178
179pub type OutgoingRequestSender = mpsc::Sender<OutgoingRequest>;
181
182pub type OutgoingRequestReceiver = mpsc::Receiver<OutgoingRequest>;
184
185pub fn outgoing_request_channel(buffer: usize) -> (OutgoingRequestSender, OutgoingRequestReceiver) {
187 mpsc::channel(buffer)
188}
189
190#[derive(Clone)]
192pub struct ChannelClientRequester {
193 request_tx: OutgoingRequestSender,
194 next_id: Arc<AtomicI64>,
195}
196
197impl ChannelClientRequester {
198 pub fn new(request_tx: OutgoingRequestSender) -> Self {
200 Self {
201 request_tx,
202 next_id: Arc::new(AtomicI64::new(1)),
203 }
204 }
205
206 fn next_request_id(&self) -> RequestId {
207 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
208 RequestId::Number(id)
209 }
210}
211
212impl ChannelClientRequester {
213 async fn dispatch(&self, method: &str, params: serde_json::Value) -> Result<serde_json::Value> {
214 let id = self.next_request_id();
215 let (response_tx, response_rx) = tokio::sync::oneshot::channel();
216
217 let request = OutgoingRequest {
218 id,
219 method: method.to_string(),
220 params,
221 response_tx,
222 };
223
224 self.request_tx
225 .send(request)
226 .await
227 .map_err(|_| Error::Internal("Failed to send request: channel closed".to_string()))?;
228
229 response_rx.await.map_err(|_| {
230 Error::Internal("Failed to receive response: channel closed".to_string())
231 })?
232 }
233}
234
235#[async_trait]
236impl ClientRequester for ChannelClientRequester {
237 async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult> {
238 let params_json = serde_json::to_value(¶ms)
239 .map_err(|e| Error::Internal(format!("Failed to serialize params: {}", e)))?;
240 let response = self.dispatch("sampling/createMessage", params_json).await?;
241 serde_json::from_value(response)
242 .map_err(|e| Error::Internal(format!("Failed to deserialize response: {}", e)))
243 }
244
245 async fn elicit(&self, params: ElicitRequestParams) -> Result<ElicitResult> {
246 let params_json = serde_json::to_value(¶ms)
247 .map_err(|e| Error::Internal(format!("Failed to serialize params: {}", e)))?;
248 let response = self.dispatch("elicitation/create", params_json).await?;
249 serde_json::from_value(response)
250 .map_err(|e| Error::Internal(format!("Failed to deserialize response: {}", e)))
251 }
252
253 async fn request(
254 &self,
255 method: String,
256 params: serde_json::Value,
257 ) -> Result<serde_json::Value> {
258 self.dispatch(&method, params).await
259 }
260}
261
262#[derive(Clone)]
264pub struct RequestContext {
265 request_id: RequestId,
267 progress_token: Option<ProgressToken>,
269 cancelled: Arc<AtomicBool>,
271 notification_tx: Option<NotificationSender>,
273 client_requester: Option<ClientRequesterHandle>,
275 extensions: Arc<Extensions>,
277 min_log_level: Option<Arc<RwLock<LogLevel>>>,
279}
280
/// Type-keyed map of arbitrary values attached to a request.
///
/// At most one value is stored per concrete type `T`. Values live behind
/// `Arc`, so cloning an `Extensions` (or merging one into another) shares the
/// stored values instead of copying them.
#[derive(Clone, Default)]
pub struct Extensions {
    map: std::collections::HashMap<std::any::TypeId, Arc<dyn std::any::Any + Send + Sync>>,
}

impl Extensions {
    /// Create an empty map.
    pub fn new() -> Self {
        Self {
            map: std::collections::HashMap::new(),
        }
    }

    /// Store `val`, replacing any previously stored value of the same type.
    pub fn insert<T: Send + Sync + 'static>(&mut self, val: T) {
        let key = std::any::TypeId::of::<T>();
        let erased: Arc<dyn std::any::Any + Send + Sync> = Arc::new(val);
        self.map.insert(key, erased);
    }

    /// Borrow the stored value of type `T`, if there is one.
    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
        let stored = self.map.get(&std::any::TypeId::of::<T>())?;
        // Deref through the `Arc` so `downcast_ref` runs on the trait object itself.
        (**stored).downcast_ref::<T>()
    }

    /// `true` if a value of type `T` is stored.
    pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
        // Entries are only ever keyed by their own `TypeId`, so presence of the
        // key and a successful downcast coincide.
        self.get::<T>().is_some()
    }

    /// Copy every entry of `other` into `self`; on a type clash `other` wins.
    /// The values themselves are shared (`Arc` clone), not deep-copied.
    pub fn merge(&mut self, other: &Extensions) {
        self.map
            .extend(other.map.iter().map(|(key, value)| (*key, Arc::clone(value))));
    }

    /// Number of stored values.
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// `true` if no values are stored.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

impl std::fmt::Debug for Extensions {
    /// Stored values are opaque `Any`s, so only the entry count is printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let len = self.len();
        f.debug_struct("Extensions").field("len", &len).finish()
    }
}
344
345impl std::fmt::Debug for RequestContext {
346 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
347 f.debug_struct("RequestContext")
348 .field("request_id", &self.request_id)
349 .field("progress_token", &self.progress_token)
350 .field("cancelled", &self.cancelled.load(Ordering::Relaxed))
351 .finish()
352 }
353}
354
355impl RequestContext {
356 pub fn new(request_id: RequestId) -> Self {
358 Self {
359 request_id,
360 progress_token: None,
361 cancelled: Arc::new(AtomicBool::new(false)),
362 notification_tx: None,
363 client_requester: None,
364 extensions: Arc::new(Extensions::new()),
365 min_log_level: None,
366 }
367 }
368
369 pub fn with_progress_token(mut self, token: ProgressToken) -> Self {
371 self.progress_token = Some(token);
372 self
373 }
374
375 pub fn with_notification_sender(mut self, tx: NotificationSender) -> Self {
377 self.notification_tx = Some(tx);
378 self
379 }
380
381 pub fn with_min_log_level(mut self, level: Arc<RwLock<LogLevel>>) -> Self {
386 self.min_log_level = Some(level);
387 self
388 }
389
390 pub fn with_client_requester(mut self, requester: ClientRequesterHandle) -> Self {
392 self.client_requester = Some(requester);
393 self
394 }
395
396 pub fn with_extensions(mut self, extensions: Arc<Extensions>) -> Self {
400 self.extensions = extensions;
401 self
402 }
403
404 pub fn extension<T: Send + Sync + 'static>(&self) -> Option<&T> {
420 self.extensions.get::<T>()
421 }
422
423 pub fn extensions_mut(&mut self) -> &mut Extensions {
428 Arc::make_mut(&mut self.extensions)
429 }
430
431 pub fn extensions(&self) -> &Extensions {
433 &self.extensions
434 }
435
436 pub fn request_id(&self) -> &RequestId {
438 &self.request_id
439 }
440
441 pub fn progress_token(&self) -> Option<&ProgressToken> {
443 self.progress_token.as_ref()
444 }
445
446 pub fn is_cancelled(&self) -> bool {
448 self.cancelled.load(Ordering::Relaxed)
449 }
450
451 pub fn cancel(&self) {
453 self.cancelled.store(true, Ordering::Relaxed);
454 }
455
456 pub fn cancellation_token(&self) -> CancellationToken {
458 CancellationToken {
459 cancelled: self.cancelled.clone(),
460 }
461 }
462
463 pub async fn report_progress(&self, progress: f64, total: Option<f64>, message: Option<&str>) {
467 let Some(token) = &self.progress_token else {
468 return;
469 };
470 let Some(tx) = &self.notification_tx else {
471 return;
472 };
473
474 let params = ProgressParams {
475 progress_token: token.clone(),
476 progress,
477 total,
478 message: message.map(|s| s.to_string()),
479 meta: None,
480 };
481
482 let _ = tx.try_send(ServerNotification::Progress(params));
484 }
485
486 pub fn report_progress_sync(&self, progress: f64, total: Option<f64>, message: Option<&str>) {
490 let Some(token) = &self.progress_token else {
491 return;
492 };
493 let Some(tx) = &self.notification_tx else {
494 return;
495 };
496
497 let params = ProgressParams {
498 progress_token: token.clone(),
499 progress,
500 total,
501 message: message.map(|s| s.to_string()),
502 meta: None,
503 };
504
505 let _ = tx.try_send(ServerNotification::Progress(params));
506 }
507
508 pub fn send_log(&self, params: LoggingMessageParams) {
525 let Some(tx) = &self.notification_tx else {
526 return;
527 };
528
529 if let Some(min_level) = &self.min_log_level
534 && let Ok(min) = min_level.read()
535 && params.level > *min
536 {
537 return;
538 }
539
540 let _ = tx.try_send(ServerNotification::LogMessage(params));
541 }
542
543 pub fn can_sample(&self) -> bool {
548 self.client_requester.is_some()
549 }
550
551 pub async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult> {
575 let requester = self.client_requester.as_ref().ok_or_else(|| {
576 Error::Internal("Sampling not available: no client requester configured".to_string())
577 })?;
578
579 requester.sample(params).await
580 }
581
582 pub fn can_elicit(&self) -> bool {
588 self.client_requester.is_some()
589 }
590
591 pub async fn elicit_form(&self, params: ElicitFormParams) -> Result<ElicitResult> {
623 let requester = self.client_requester.as_ref().ok_or_else(|| {
624 Error::Internal("Elicitation not available: no client requester configured".to_string())
625 })?;
626
627 requester.elicit(ElicitRequestParams::Form(params)).await
628 }
629
630 pub async fn elicit_url(&self, params: ElicitUrlParams) -> Result<ElicitResult> {
660 let requester = self.client_requester.as_ref().ok_or_else(|| {
661 Error::Internal("Elicitation not available: no client requester configured".to_string())
662 })?;
663
664 requester.elicit(ElicitRequestParams::Url(params)).await
665 }
666
667 pub async fn confirm(&self, message: impl Into<String>) -> Result<bool> {
691 use crate::protocol::{ElicitAction, ElicitFormParams, ElicitFormSchema, ElicitMode};
692
693 let params = ElicitFormParams {
694 mode: Some(ElicitMode::Form),
695 message: message.into(),
696 requested_schema: ElicitFormSchema::new().boolean_field_with_default(
697 "confirm",
698 Some("Confirm this action"),
699 true,
700 false,
701 ),
702 meta: None,
703 };
704
705 let result = self.elicit_form(params).await?;
706 Ok(result.action == ElicitAction::Accept)
707 }
708
709 pub async fn list_tasks(&self, status: Option<TaskStatus>) -> Result<ListTasksResult> {
719 let params = ListTasksParams {
720 status,
721 cursor: None,
722 meta: None,
723 };
724 let value = self
725 .request_raw("tasks/list", serde_json::to_value(¶ms)?)
726 .await?;
727 serde_json::from_value(value)
728 .map_err(|e| Error::Internal(format!("Failed to deserialize tasks/list: {e}")))
729 }
730
731 pub async fn get_task_info(&self, task_id: impl Into<String>) -> Result<TaskObject> {
736 let params = GetTaskInfoParams {
737 task_id: task_id.into(),
738 meta: None,
739 };
740 let value = self
741 .request_raw("tasks/get", serde_json::to_value(¶ms)?)
742 .await?;
743 serde_json::from_value(value)
744 .map_err(|e| Error::Internal(format!("Failed to deserialize tasks/get: {e}")))
745 }
746
747 pub async fn get_task_result(&self, task_id: impl Into<String>) -> Result<CallToolResult> {
755 let params = GetTaskResultParams {
756 task_id: task_id.into(),
757 meta: None,
758 };
759 let value = self
760 .request_raw("tasks/result", serde_json::to_value(¶ms)?)
761 .await?;
762 serde_json::from_value(value)
763 .map_err(|e| Error::Internal(format!("Failed to deserialize tasks/result: {e}")))
764 }
765
766 pub async fn cancel_task(
771 &self,
772 task_id: impl Into<String>,
773 reason: Option<String>,
774 ) -> Result<TaskObject> {
775 let params = CancelTaskParams {
776 task_id: task_id.into(),
777 reason,
778 meta: None,
779 };
780 let value = self
781 .request_raw("tasks/cancel", serde_json::to_value(¶ms)?)
782 .await?;
783 serde_json::from_value(value)
784 .map_err(|e| Error::Internal(format!("Failed to deserialize tasks/cancel: {e}")))
785 }
786
787 pub async fn request_raw(
793 &self,
794 method: &str,
795 params: serde_json::Value,
796 ) -> Result<serde_json::Value> {
797 let requester = self.client_requester.as_ref().ok_or_else(|| {
798 Error::Internal(
799 "Client request not available: no client requester configured".to_string(),
800 )
801 })?;
802 requester.request(method.to_string(), params).await
803 }
804}
805
806#[derive(Clone, Debug)]
808pub struct CancellationToken {
809 cancelled: Arc<AtomicBool>,
810}
811
812impl CancellationToken {
813 pub fn is_cancelled(&self) -> bool {
815 self.cancelled.load(Ordering::Relaxed)
816 }
817
818 pub fn cancel(&self) {
820 self.cancelled.store(true, Ordering::Relaxed);
821 }
822}
823
824#[derive(Default)]
826pub struct RequestContextBuilder {
827 request_id: Option<RequestId>,
828 progress_token: Option<ProgressToken>,
829 notification_tx: Option<NotificationSender>,
830 client_requester: Option<ClientRequesterHandle>,
831 min_log_level: Option<Arc<RwLock<LogLevel>>>,
832}
833
834impl RequestContextBuilder {
835 pub fn new() -> Self {
837 Self::default()
838 }
839
840 pub fn request_id(mut self, id: RequestId) -> Self {
842 self.request_id = Some(id);
843 self
844 }
845
846 pub fn progress_token(mut self, token: ProgressToken) -> Self {
848 self.progress_token = Some(token);
849 self
850 }
851
852 pub fn notification_sender(mut self, tx: NotificationSender) -> Self {
854 self.notification_tx = Some(tx);
855 self
856 }
857
858 pub fn client_requester(mut self, requester: ClientRequesterHandle) -> Self {
860 self.client_requester = Some(requester);
861 self
862 }
863
864 pub fn min_log_level(mut self, level: Arc<RwLock<LogLevel>>) -> Self {
866 self.min_log_level = Some(level);
867 self
868 }
869
870 pub fn build(self) -> RequestContext {
874 let mut ctx = RequestContext::new(self.request_id.expect("request_id is required"));
875 if let Some(token) = self.progress_token {
876 ctx = ctx.with_progress_token(token);
877 }
878 if let Some(tx) = self.notification_tx {
879 ctx = ctx.with_notification_sender(tx);
880 }
881 if let Some(requester) = self.client_requester {
882 ctx = ctx.with_client_requester(requester);
883 }
884 if let Some(level) = self.min_log_level {
885 ctx = ctx.with_min_log_level(level);
886 }
887 ctx
888 }
889}
890
891#[cfg(test)]
892mod tests {
893 use super::*;
894
895 #[test]
896 fn test_cancellation() {
897 let ctx = RequestContext::new(RequestId::Number(1));
898 assert!(!ctx.is_cancelled());
899
900 let token = ctx.cancellation_token();
901 assert!(!token.is_cancelled());
902
903 ctx.cancel();
904 assert!(ctx.is_cancelled());
905 assert!(token.is_cancelled());
906 }
907
908 #[tokio::test]
909 async fn test_progress_reporting() {
910 let (tx, mut rx) = notification_channel(10);
911
912 let ctx = RequestContext::new(RequestId::Number(1))
913 .with_progress_token(ProgressToken::Number(42))
914 .with_notification_sender(tx);
915
916 ctx.report_progress(50.0, Some(100.0), Some("Halfway"))
917 .await;
918
919 let notification = rx.recv().await.unwrap();
920 match notification {
921 ServerNotification::Progress(params) => {
922 assert_eq!(params.progress, 50.0);
923 assert_eq!(params.total, Some(100.0));
924 assert_eq!(params.message.as_deref(), Some("Halfway"));
925 }
926 _ => panic!("Expected Progress notification"),
927 }
928 }
929
930 #[tokio::test]
931 async fn test_progress_no_token() {
932 let (tx, mut rx) = notification_channel(10);
933
934 let ctx = RequestContext::new(RequestId::Number(1)).with_notification_sender(tx);
936
937 ctx.report_progress(50.0, Some(100.0), None).await;
938
939 assert!(rx.try_recv().is_err());
941 }
942
943 #[test]
944 fn test_builder() {
945 let (tx, _rx) = notification_channel(10);
946
947 let ctx = RequestContextBuilder::new()
948 .request_id(RequestId::String("req-1".to_string()))
949 .progress_token(ProgressToken::String("prog-1".to_string()))
950 .notification_sender(tx)
951 .build();
952
953 assert_eq!(ctx.request_id(), &RequestId::String("req-1".to_string()));
954 assert!(ctx.progress_token().is_some());
955 }
956
957 #[test]
958 fn test_can_sample_without_requester() {
959 let ctx = RequestContext::new(RequestId::Number(1));
960 assert!(!ctx.can_sample());
961 }
962
963 #[test]
964 fn test_can_sample_with_requester() {
965 let (request_tx, _rx) = outgoing_request_channel(10);
966 let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));
967
968 let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
969 assert!(ctx.can_sample());
970 }
971
972 #[tokio::test]
973 async fn test_sample_without_requester_fails() {
974 use crate::protocol::{CreateMessageParams, SamplingMessage};
975
976 let ctx = RequestContext::new(RequestId::Number(1));
977 let params = CreateMessageParams::new(vec![SamplingMessage::user("test")], 100);
978
979 let result = ctx.sample(params).await;
980 assert!(result.is_err());
981 assert!(
982 result
983 .unwrap_err()
984 .to_string()
985 .contains("Sampling not available")
986 );
987 }
988
989 #[test]
990 fn test_builder_with_client_requester() {
991 let (request_tx, _rx) = outgoing_request_channel(10);
992 let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));
993
994 let ctx = RequestContextBuilder::new()
995 .request_id(RequestId::Number(1))
996 .client_requester(requester)
997 .build();
998
999 assert!(ctx.can_sample());
1000 }
1001
1002 #[test]
1003 fn test_can_elicit_without_requester() {
1004 let ctx = RequestContext::new(RequestId::Number(1));
1005 assert!(!ctx.can_elicit());
1006 }
1007
1008 #[test]
1009 fn test_can_elicit_with_requester() {
1010 let (request_tx, _rx) = outgoing_request_channel(10);
1011 let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));
1012
1013 let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
1014 assert!(ctx.can_elicit());
1015 }
1016
1017 #[tokio::test]
1018 async fn test_elicit_form_without_requester_fails() {
1019 use crate::protocol::{ElicitFormSchema, ElicitMode};
1020
1021 let ctx = RequestContext::new(RequestId::Number(1));
1022 let params = ElicitFormParams {
1023 mode: Some(ElicitMode::Form),
1024 message: "Enter details".to_string(),
1025 requested_schema: ElicitFormSchema::new().string_field("name", None, true),
1026 meta: None,
1027 };
1028
1029 let result = ctx.elicit_form(params).await;
1030 assert!(result.is_err());
1031 assert!(
1032 result
1033 .unwrap_err()
1034 .to_string()
1035 .contains("Elicitation not available")
1036 );
1037 }
1038
1039 #[tokio::test]
1040 async fn test_elicit_url_without_requester_fails() {
1041 use crate::protocol::ElicitMode;
1042
1043 let ctx = RequestContext::new(RequestId::Number(1));
1044 let params = ElicitUrlParams {
1045 mode: Some(ElicitMode::Url),
1046 elicitation_id: "test-123".to_string(),
1047 message: "Please authorize".to_string(),
1048 url: "https://example.com/auth".to_string(),
1049 meta: None,
1050 };
1051
1052 let result = ctx.elicit_url(params).await;
1053 assert!(result.is_err());
1054 assert!(
1055 result
1056 .unwrap_err()
1057 .to_string()
1058 .contains("Elicitation not available")
1059 );
1060 }
1061
1062 #[tokio::test]
1063 async fn test_confirm_without_requester_fails() {
1064 let ctx = RequestContext::new(RequestId::Number(1));
1065
1066 let result = ctx.confirm("Are you sure?").await;
1067 assert!(result.is_err());
1068 assert!(
1069 result
1070 .unwrap_err()
1071 .to_string()
1072 .contains("Elicitation not available")
1073 );
1074 }
1075
1076 #[tokio::test]
1077 async fn test_send_log_filtered_by_level() {
1078 let (tx, mut rx) = notification_channel(10);
1079 let min_level = Arc::new(RwLock::new(LogLevel::Warning));
1080
1081 let ctx = RequestContext::new(RequestId::Number(1))
1082 .with_notification_sender(tx)
1083 .with_min_log_level(min_level.clone());
1084
1085 ctx.send_log(LoggingMessageParams::new(
1087 LogLevel::Error,
1088 serde_json::Value::Null,
1089 ));
1090 let msg = rx.try_recv();
1091 assert!(msg.is_ok(), "Error should pass through Warning filter");
1092
1093 ctx.send_log(LoggingMessageParams::new(
1095 LogLevel::Warning,
1096 serde_json::Value::Null,
1097 ));
1098 let msg = rx.try_recv();
1099 assert!(msg.is_ok(), "Warning should pass through Warning filter");
1100
1101 ctx.send_log(LoggingMessageParams::new(
1103 LogLevel::Info,
1104 serde_json::Value::Null,
1105 ));
1106 let msg = rx.try_recv();
1107 assert!(msg.is_err(), "Info should be filtered by Warning filter");
1108
1109 ctx.send_log(LoggingMessageParams::new(
1111 LogLevel::Debug,
1112 serde_json::Value::Null,
1113 ));
1114 let msg = rx.try_recv();
1115 assert!(msg.is_err(), "Debug should be filtered by Warning filter");
1116 }
1117
1118 #[tokio::test]
1119 async fn test_send_log_level_updates_dynamically() {
1120 let (tx, mut rx) = notification_channel(10);
1121 let min_level = Arc::new(RwLock::new(LogLevel::Error));
1122
1123 let ctx = RequestContext::new(RequestId::Number(1))
1124 .with_notification_sender(tx)
1125 .with_min_log_level(min_level.clone());
1126
1127 ctx.send_log(LoggingMessageParams::new(
1129 LogLevel::Info,
1130 serde_json::Value::Null,
1131 ));
1132 assert!(
1133 rx.try_recv().is_err(),
1134 "Info should be filtered at Error level"
1135 );
1136
1137 *min_level.write().unwrap() = LogLevel::Debug;
1139
1140 ctx.send_log(LoggingMessageParams::new(
1142 LogLevel::Info,
1143 serde_json::Value::Null,
1144 ));
1145 assert!(
1146 rx.try_recv().is_ok(),
1147 "Info should pass through after level changed to Debug"
1148 );
1149 }
1150
1151 #[tokio::test]
1152 async fn test_send_log_no_min_level_sends_all() {
1153 let (tx, mut rx) = notification_channel(10);
1154
1155 let ctx = RequestContext::new(RequestId::Number(1)).with_notification_sender(tx);
1157
1158 ctx.send_log(LoggingMessageParams::new(
1159 LogLevel::Debug,
1160 serde_json::Value::Null,
1161 ));
1162 assert!(
1163 rx.try_recv().is_ok(),
1164 "Debug should pass when no min level is set"
1165 );
1166 }
1167
1168 fn make_task_object(id: &str, status: TaskStatus) -> serde_json::Value {
1169 serde_json::json!({
1170 "taskId": id,
1171 "status": status,
1172 "createdAt": "2026-04-24T00:00:00Z",
1173 "lastUpdatedAt": "2026-04-24T00:00:00Z",
1174 "ttl": null
1175 })
1176 }
1177
1178 fn spawn_mock_client(
1179 mut rx: OutgoingRequestReceiver,
1180 responder: impl Fn(&str, serde_json::Value) -> serde_json::Value + Send + 'static,
1181 ) {
1182 tokio::spawn(async move {
1183 while let Some(req) = rx.recv().await {
1184 let response = responder(&req.method, req.params);
1185 let _ = req.response_tx.send(Ok(response));
1186 }
1187 });
1188 }
1189
1190 #[tokio::test]
1191 async fn test_get_task_info_round_trips() {
1192 let (tx, rx) = outgoing_request_channel(10);
1193 spawn_mock_client(rx, |method, params| {
1194 assert_eq!(method, "tasks/get");
1195 let task_id = params["taskId"].as_str().unwrap().to_string();
1196 make_task_object(&task_id, TaskStatus::Working)
1197 });
1198 let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(tx));
1199 let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
1200
1201 let info = ctx.get_task_info("task-123").await.unwrap();
1202 assert_eq!(info.task_id, "task-123");
1203 assert!(matches!(info.status, TaskStatus::Working));
1204 }
1205
1206 #[tokio::test]
1207 async fn test_list_tasks_round_trips() {
1208 let (tx, rx) = outgoing_request_channel(10);
1209 spawn_mock_client(rx, |method, params| {
1210 assert_eq!(method, "tasks/list");
1211 assert_eq!(params["status"], serde_json::json!("working"));
1213 serde_json::json!({
1214 "tasks": [
1215 make_task_object("task-1", TaskStatus::Working),
1216 make_task_object("task-2", TaskStatus::Working),
1217 ]
1218 })
1219 });
1220 let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(tx));
1221 let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
1222
1223 let result = ctx.list_tasks(Some(TaskStatus::Working)).await.unwrap();
1224 assert_eq!(result.tasks.len(), 2);
1225 assert_eq!(result.tasks[0].task_id, "task-1");
1226 }
1227
1228 #[tokio::test]
1229 async fn test_cancel_task_forwards_reason() {
1230 let (tx, rx) = outgoing_request_channel(10);
1231 spawn_mock_client(rx, |method, params| {
1232 assert_eq!(method, "tasks/cancel");
1233 assert_eq!(params["reason"], serde_json::json!("user requested"));
1234 let task_id = params["taskId"].as_str().unwrap().to_string();
1235 make_task_object(&task_id, TaskStatus::Cancelled)
1236 });
1237 let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(tx));
1238 let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
1239
1240 let task = ctx
1241 .cancel_task("task-99", Some("user requested".into()))
1242 .await
1243 .unwrap();
1244 assert_eq!(task.task_id, "task-99");
1245 assert!(matches!(task.status, TaskStatus::Cancelled));
1246 }
1247
1248 #[tokio::test]
1249 async fn test_get_task_info_without_requester_fails() {
1250 let ctx = RequestContext::new(RequestId::Number(1));
1251 let result = ctx.get_task_info("task-1").await;
1252 assert!(result.is_err());
1253 assert!(
1254 result
1255 .unwrap_err()
1256 .to_string()
1257 .contains("Client request not available")
1258 );
1259 }
1260
1261 #[tokio::test]
1262 async fn test_default_request_impl_errors() {
1263 struct OnlySampleAndElicit;
1266
1267 #[async_trait]
1268 impl ClientRequester for OnlySampleAndElicit {
1269 async fn sample(&self, _: CreateMessageParams) -> Result<CreateMessageResult> {
1270 unreachable!()
1271 }
1272 async fn elicit(&self, _: ElicitRequestParams) -> Result<ElicitResult> {
1273 unreachable!()
1274 }
1275 }
1276
1277 let requester: ClientRequesterHandle = Arc::new(OnlySampleAndElicit);
1278 let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
1279
1280 let err = ctx.get_task_info("x").await.unwrap_err();
1281 assert!(err.to_string().contains("does not support arbitrary"));
1282 }
1283}