1use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
117use std::sync::{Arc, RwLock};
118
119use async_trait::async_trait;
120use tokio::sync::mpsc;
121
122use crate::error::{Error, Result};
123use crate::protocol::{
124 CallToolResult, CancelTaskParams, CreateMessageParams, CreateMessageResult, ElicitFormParams,
125 ElicitRequestParams, ElicitResult, ElicitUrlParams, GetTaskInfoParams, GetTaskResultParams,
126 ListTasksParams, ListTasksResult, LogLevel, LoggingMessageParams, ProgressParams,
127 ProgressToken, RequestId, TaskObject, TaskStatus,
128};
129
/// Notification payloads carried over the notification channel
/// (see [`NotificationSender`] / [`NotificationReceiver`]).
///
/// The enum is `#[non_exhaustive]`, so `match`es written outside this crate
/// must include a wildcard arm.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum ServerNotification {
    /// Progress report; carries the progress parameters.
    Progress(ProgressParams),
    /// Log message; carries the logging parameters.
    LogMessage(LoggingMessageParams),
    /// A single resource was updated.
    ResourceUpdated {
        /// URI of the resource that was updated.
        uri: String,
    },
    /// The list of resources changed.
    ResourcesListChanged,
    /// The list of tools changed.
    ToolsListChanged,
    /// The list of prompts changed.
    PromptsListChanged,
    /// A task's status changed; carries the task status parameters.
    TaskStatusChanged(crate::protocol::TaskStatusParams),
}
152
/// Sending half of a bounded tokio channel of [`ServerNotification`]s.
pub type NotificationSender = mpsc::Sender<ServerNotification>;

/// Receiving half of a bounded tokio channel of [`ServerNotification`]s.
pub type NotificationReceiver = mpsc::Receiver<ServerNotification>;
158
159pub fn notification_channel(buffer: usize) -> (NotificationSender, NotificationReceiver) {
161 mpsc::channel(buffer)
162}
163
164#[async_trait]
174pub trait ClientRequester: Send + Sync {
175 async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult>;
179
180 async fn elicit(&self, params: ElicitRequestParams) -> Result<ElicitResult>;
187
188 async fn request(
196 &self,
197 method: String,
198 params: serde_json::Value,
199 ) -> Result<serde_json::Value> {
200 let _ = (method, params);
201 Err(Error::Internal(
202 "ClientRequester does not support arbitrary requests".to_string(),
203 ))
204 }
205}
206
207pub type ClientRequesterHandle = Arc<dyn ClientRequester>;
209
/// A request queued for delivery, together with the one-shot channel on
/// which its outcome is reported back to the waiting caller.
#[derive(Debug)]
pub struct OutgoingRequest {
    /// Identifier assigned to this request.
    pub id: RequestId,
    /// Method name of the request (e.g. `"sampling/createMessage"`).
    pub method: String,
    /// JSON-encoded request parameters.
    pub params: serde_json::Value,
    /// One-shot sender used to deliver the result (or error) for this request.
    pub response_tx: tokio::sync::oneshot::Sender<Result<serde_json::Value>>,
}

/// Sending half of a bounded tokio channel of [`OutgoingRequest`]s.
pub type OutgoingRequestSender = mpsc::Sender<OutgoingRequest>;

/// Receiving half of a bounded tokio channel of [`OutgoingRequest`]s.
pub type OutgoingRequestReceiver = mpsc::Receiver<OutgoingRequest>;
228
229pub fn outgoing_request_channel(buffer: usize) -> (OutgoingRequestSender, OutgoingRequestReceiver) {
231 mpsc::channel(buffer)
232}
233
234#[derive(Clone)]
236pub struct ChannelClientRequester {
237 request_tx: OutgoingRequestSender,
238 next_id: Arc<AtomicI64>,
239}
240
241impl ChannelClientRequester {
242 pub fn new(request_tx: OutgoingRequestSender) -> Self {
244 Self {
245 request_tx,
246 next_id: Arc::new(AtomicI64::new(1)),
247 }
248 }
249
250 fn next_request_id(&self) -> RequestId {
251 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
252 RequestId::Number(id)
253 }
254}
255
256impl ChannelClientRequester {
257 async fn dispatch(&self, method: &str, params: serde_json::Value) -> Result<serde_json::Value> {
258 let id = self.next_request_id();
259 let (response_tx, response_rx) = tokio::sync::oneshot::channel();
260
261 let request = OutgoingRequest {
262 id,
263 method: method.to_string(),
264 params,
265 response_tx,
266 };
267
268 self.request_tx
269 .send(request)
270 .await
271 .map_err(|_| Error::Internal("Failed to send request: channel closed".to_string()))?;
272
273 response_rx.await.map_err(|_| {
274 Error::Internal("Failed to receive response: channel closed".to_string())
275 })?
276 }
277}
278
279#[async_trait]
280impl ClientRequester for ChannelClientRequester {
281 async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult> {
282 let params_json = serde_json::to_value(¶ms)
283 .map_err(|e| Error::Internal(format!("Failed to serialize params: {}", e)))?;
284 let response = self.dispatch("sampling/createMessage", params_json).await?;
285 serde_json::from_value(response)
286 .map_err(|e| Error::Internal(format!("Failed to deserialize response: {}", e)))
287 }
288
289 async fn elicit(&self, params: ElicitRequestParams) -> Result<ElicitResult> {
290 let params_json = serde_json::to_value(¶ms)
291 .map_err(|e| Error::Internal(format!("Failed to serialize params: {}", e)))?;
292 let response = self.dispatch("elicitation/create", params_json).await?;
293 serde_json::from_value(response)
294 .map_err(|e| Error::Internal(format!("Failed to deserialize response: {}", e)))
295 }
296
297 async fn request(
298 &self,
299 method: String,
300 params: serde_json::Value,
301 ) -> Result<serde_json::Value> {
302 self.dispatch(&method, params).await
303 }
304}
305
/// State associated with a single request: its id, optional progress token,
/// cancellation flag, and optional channels for notifications and
/// client-bound requests.
///
/// Cloning is cheap; the cancellation flag and the extension map are held
/// behind `Arc`s, so clones observe the same cancellation state.
#[derive(Clone)]
pub struct RequestContext {
    /// Identifier of the request this context belongs to.
    request_id: RequestId,
    /// Progress token, if any; progress reports are skipped when this is `None`.
    progress_token: Option<ProgressToken>,
    /// Cancellation flag shared with clones and with handed-out `CancellationToken`s.
    cancelled: Arc<AtomicBool>,
    /// Channel used for progress and log notifications; `None` disables both.
    notification_tx: Option<NotificationSender>,
    /// Handle used for sampling, elicitation and raw client requests, if configured.
    client_requester: Option<ClientRequesterHandle>,
    /// Type-keyed extension values attached to this request.
    extensions: Arc<Extensions>,
    /// Shared log-level threshold consulted by `send_log`; `None` disables filtering.
    min_log_level: Option<Arc<RwLock<LogLevel>>>,
}
324
/// Type-keyed map holding at most one value per concrete type.
///
/// Values are stored behind `Arc`, so cloning the map (or merging it into
/// another) shares the stored values instead of copying them.
#[derive(Clone, Default)]
pub struct Extensions {
    /// Stored values, keyed by the `TypeId` of their concrete type.
    map: std::collections::HashMap<std::any::TypeId, Arc<dyn std::any::Any + Send + Sync>>,
}

impl Extensions {
    /// Creates an empty map.
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `val`, replacing any previously stored value of the same type.
    pub fn insert<T: Send + Sync + 'static>(&mut self, val: T) {
        let erased: Arc<dyn std::any::Any + Send + Sync> = Arc::new(val);
        self.map.insert(std::any::TypeId::of::<T>(), erased);
    }

    /// Returns the stored value of type `T`, or `None` if there is none.
    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
        let key = std::any::TypeId::of::<T>();
        let stored = self.map.get(&key)?;
        stored.downcast_ref::<T>()
    }

    /// Reports whether a value of type `T` is stored.
    pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
        let key = std::any::TypeId::of::<T>();
        self.map.contains_key(&key)
    }

    /// Copies every entry of `other` into `self`; on a type collision the
    /// entry from `other` wins.
    pub fn merge(&mut self, other: &Extensions) {
        let incoming = other
            .map
            .iter()
            .map(|(type_id, value)| (*type_id, Arc::clone(value)));
        self.map.extend(incoming);
    }

    /// Number of stored values.
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// `true` when nothing is stored.
    pub fn is_empty(&self) -> bool {
        self.map.is_empty()
    }
}

/// Prints only the entry count, because the stored values are type-erased.
impl std::fmt::Debug for Extensions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let entry_count = self.map.len();
        let mut out = f.debug_struct("Extensions");
        out.field("len", &entry_count);
        out.finish()
    }
}
388
389impl std::fmt::Debug for RequestContext {
390 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
391 f.debug_struct("RequestContext")
392 .field("request_id", &self.request_id)
393 .field("progress_token", &self.progress_token)
394 .field("cancelled", &self.cancelled.load(Ordering::Relaxed))
395 .finish()
396 }
397}
398
399impl RequestContext {
400 pub fn new(request_id: RequestId) -> Self {
402 Self {
403 request_id,
404 progress_token: None,
405 cancelled: Arc::new(AtomicBool::new(false)),
406 notification_tx: None,
407 client_requester: None,
408 extensions: Arc::new(Extensions::new()),
409 min_log_level: None,
410 }
411 }
412
413 pub fn with_progress_token(mut self, token: ProgressToken) -> Self {
415 self.progress_token = Some(token);
416 self
417 }
418
419 pub fn with_notification_sender(mut self, tx: NotificationSender) -> Self {
421 self.notification_tx = Some(tx);
422 self
423 }
424
425 pub fn with_min_log_level(mut self, level: Arc<RwLock<LogLevel>>) -> Self {
430 self.min_log_level = Some(level);
431 self
432 }
433
434 pub fn with_client_requester(mut self, requester: ClientRequesterHandle) -> Self {
436 self.client_requester = Some(requester);
437 self
438 }
439
440 pub fn with_extensions(mut self, extensions: Arc<Extensions>) -> Self {
444 self.extensions = extensions;
445 self
446 }
447
448 pub fn extension<T: Send + Sync + 'static>(&self) -> Option<&T> {
464 self.extensions.get::<T>()
465 }
466
467 pub fn extensions_mut(&mut self) -> &mut Extensions {
472 Arc::make_mut(&mut self.extensions)
473 }
474
475 pub fn extensions(&self) -> &Extensions {
477 &self.extensions
478 }
479
480 #[cfg(feature = "stateless")]
508 pub fn per_request_meta(&self) -> Option<&crate::stateless::StatelessRequestMeta> {
509 self.extension::<crate::stateless::StatelessRequestMeta>()
510 }
511
512 pub fn request_id(&self) -> &RequestId {
514 &self.request_id
515 }
516
517 pub fn progress_token(&self) -> Option<&ProgressToken> {
519 self.progress_token.as_ref()
520 }
521
522 pub fn is_cancelled(&self) -> bool {
524 self.cancelled.load(Ordering::Relaxed)
525 }
526
527 pub fn cancel(&self) {
529 self.cancelled.store(true, Ordering::Relaxed);
530 }
531
532 pub fn cancellation_token(&self) -> CancellationToken {
534 CancellationToken {
535 cancelled: self.cancelled.clone(),
536 }
537 }
538
539 pub async fn report_progress(&self, progress: f64, total: Option<f64>, message: Option<&str>) {
543 let Some(token) = &self.progress_token else {
544 return;
545 };
546 let Some(tx) = &self.notification_tx else {
547 return;
548 };
549
550 let params = ProgressParams {
551 progress_token: token.clone(),
552 progress,
553 total,
554 message: message.map(|s| s.to_string()),
555 meta: None,
556 };
557
558 let _ = tx.try_send(ServerNotification::Progress(params));
560 }
561
562 pub fn report_progress_sync(&self, progress: f64, total: Option<f64>, message: Option<&str>) {
566 let Some(token) = &self.progress_token else {
567 return;
568 };
569 let Some(tx) = &self.notification_tx else {
570 return;
571 };
572
573 let params = ProgressParams {
574 progress_token: token.clone(),
575 progress,
576 total,
577 message: message.map(|s| s.to_string()),
578 meta: None,
579 };
580
581 let _ = tx.try_send(ServerNotification::Progress(params));
582 }
583
584 pub fn send_log(&self, params: LoggingMessageParams) {
601 let Some(tx) = &self.notification_tx else {
602 return;
603 };
604
605 if let Some(min_level) = &self.min_log_level
610 && let Ok(min) = min_level.read()
611 && params.level > *min
612 {
613 return;
614 }
615
616 let _ = tx.try_send(ServerNotification::LogMessage(params));
617 }
618
619 pub fn can_sample(&self) -> bool {
624 self.client_requester.is_some()
625 }
626
627 pub async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult> {
651 let requester = self.client_requester.as_ref().ok_or_else(|| {
652 Error::Internal("Sampling not available: no client requester configured".to_string())
653 })?;
654
655 requester.sample(params).await
656 }
657
658 pub fn can_elicit(&self) -> bool {
664 self.client_requester.is_some()
665 }
666
667 pub async fn elicit_form(&self, params: ElicitFormParams) -> Result<ElicitResult> {
699 let requester = self.client_requester.as_ref().ok_or_else(|| {
700 Error::Internal("Elicitation not available: no client requester configured".to_string())
701 })?;
702
703 requester.elicit(ElicitRequestParams::Form(params)).await
704 }
705
706 pub async fn elicit_url(&self, params: ElicitUrlParams) -> Result<ElicitResult> {
736 let requester = self.client_requester.as_ref().ok_or_else(|| {
737 Error::Internal("Elicitation not available: no client requester configured".to_string())
738 })?;
739
740 requester.elicit(ElicitRequestParams::Url(params)).await
741 }
742
743 pub async fn confirm(&self, message: impl Into<String>) -> Result<bool> {
767 use crate::protocol::{ElicitAction, ElicitFormParams, ElicitFormSchema, ElicitMode};
768
769 let params = ElicitFormParams {
770 mode: Some(ElicitMode::Form),
771 message: message.into(),
772 requested_schema: ElicitFormSchema::new().boolean_field_with_default(
773 "confirm",
774 Some("Confirm this action"),
775 true,
776 false,
777 ),
778 meta: None,
779 };
780
781 let result = self.elicit_form(params).await?;
782 Ok(result.action == ElicitAction::Accept)
783 }
784
785 pub async fn list_tasks(&self, status: Option<TaskStatus>) -> Result<ListTasksResult> {
795 let params = ListTasksParams {
796 status,
797 cursor: None,
798 meta: None,
799 };
800 let value = self
801 .request_raw("tasks/list", serde_json::to_value(¶ms)?)
802 .await?;
803 serde_json::from_value(value)
804 .map_err(|e| Error::Internal(format!("Failed to deserialize tasks/list: {e}")))
805 }
806
807 pub async fn get_task_info(&self, task_id: impl Into<String>) -> Result<TaskObject> {
812 let params = GetTaskInfoParams {
813 task_id: task_id.into(),
814 meta: None,
815 };
816 let value = self
817 .request_raw("tasks/get", serde_json::to_value(¶ms)?)
818 .await?;
819 serde_json::from_value(value)
820 .map_err(|e| Error::Internal(format!("Failed to deserialize tasks/get: {e}")))
821 }
822
823 pub async fn get_task_result(&self, task_id: impl Into<String>) -> Result<CallToolResult> {
831 let params = GetTaskResultParams {
832 task_id: task_id.into(),
833 meta: None,
834 };
835 let value = self
836 .request_raw("tasks/result", serde_json::to_value(¶ms)?)
837 .await?;
838 serde_json::from_value(value)
839 .map_err(|e| Error::Internal(format!("Failed to deserialize tasks/result: {e}")))
840 }
841
842 pub async fn cancel_task(
847 &self,
848 task_id: impl Into<String>,
849 reason: Option<String>,
850 ) -> Result<TaskObject> {
851 let params = CancelTaskParams {
852 task_id: task_id.into(),
853 reason,
854 meta: None,
855 };
856 let value = self
857 .request_raw("tasks/cancel", serde_json::to_value(¶ms)?)
858 .await?;
859 serde_json::from_value(value)
860 .map_err(|e| Error::Internal(format!("Failed to deserialize tasks/cancel: {e}")))
861 }
862
863 pub async fn request_raw(
869 &self,
870 method: &str,
871 params: serde_json::Value,
872 ) -> Result<serde_json::Value> {
873 let requester = self.client_requester.as_ref().ok_or_else(|| {
874 Error::Internal(
875 "Client request not available: no client requester configured".to_string(),
876 )
877 })?;
878 requester.request(method.to_string(), params).await
879 }
880}
881
/// Cloneable handle to a shared cancellation flag.
///
/// All clones, and whatever created the token, observe the same flag.
#[derive(Clone, Debug)]
pub struct CancellationToken {
    /// The shared flag; `true` once cancellation has been requested.
    cancelled: Arc<AtomicBool>,
}

impl CancellationToken {
    /// Reports whether cancellation has been requested.
    pub fn is_cancelled(&self) -> bool {
        let flag: &AtomicBool = &self.cancelled;
        flag.load(Ordering::Relaxed)
    }

    /// Requests cancellation. Calling it again has no further effect.
    pub fn cancel(&self) {
        let flag: &AtomicBool = &self.cancelled;
        flag.store(true, Ordering::Relaxed);
    }
}
899
900#[derive(Default)]
902pub struct RequestContextBuilder {
903 request_id: Option<RequestId>,
904 progress_token: Option<ProgressToken>,
905 notification_tx: Option<NotificationSender>,
906 client_requester: Option<ClientRequesterHandle>,
907 min_log_level: Option<Arc<RwLock<LogLevel>>>,
908}
909
910impl RequestContextBuilder {
911 pub fn new() -> Self {
913 Self::default()
914 }
915
916 pub fn request_id(mut self, id: RequestId) -> Self {
918 self.request_id = Some(id);
919 self
920 }
921
922 pub fn progress_token(mut self, token: ProgressToken) -> Self {
924 self.progress_token = Some(token);
925 self
926 }
927
928 pub fn notification_sender(mut self, tx: NotificationSender) -> Self {
930 self.notification_tx = Some(tx);
931 self
932 }
933
934 pub fn client_requester(mut self, requester: ClientRequesterHandle) -> Self {
936 self.client_requester = Some(requester);
937 self
938 }
939
940 pub fn min_log_level(mut self, level: Arc<RwLock<LogLevel>>) -> Self {
942 self.min_log_level = Some(level);
943 self
944 }
945
946 pub fn build(self) -> RequestContext {
950 let mut ctx = RequestContext::new(self.request_id.expect("request_id is required"));
951 if let Some(token) = self.progress_token {
952 ctx = ctx.with_progress_token(token);
953 }
954 if let Some(tx) = self.notification_tx {
955 ctx = ctx.with_notification_sender(tx);
956 }
957 if let Some(requester) = self.client_requester {
958 ctx = ctx.with_client_requester(requester);
959 }
960 if let Some(level) = self.min_log_level {
961 ctx = ctx.with_min_log_level(level);
962 }
963 ctx
964 }
965}
966
// Unit tests for the request context, its builder, log filtering, and the
// task helpers that round-trip through a mocked client.
#[cfg(test)]
mod tests {
    use super::*;

    // Cancelling through the context must be visible through the context
    // itself and through a previously obtained token.
    #[test]
    fn test_cancellation() {
        let ctx = RequestContext::new(RequestId::Number(1));
        assert!(!ctx.is_cancelled());

        let token = ctx.cancellation_token();
        assert!(!token.is_cancelled());

        ctx.cancel();
        assert!(ctx.is_cancelled());
        assert!(token.is_cancelled());
    }

    // With both a token and a sender configured, a progress report arrives
    // on the channel with the values that were passed in.
    #[tokio::test]
    async fn test_progress_reporting() {
        let (tx, mut rx) = notification_channel(10);

        let ctx = RequestContext::new(RequestId::Number(1))
            .with_progress_token(ProgressToken::Number(42))
            .with_notification_sender(tx);

        ctx.report_progress(50.0, Some(100.0), Some("Halfway"))
            .await;

        let notification = rx.recv().await.unwrap();
        match notification {
            ServerNotification::Progress(params) => {
                assert_eq!(params.progress, 50.0);
                assert_eq!(params.total, Some(100.0));
                assert_eq!(params.message.as_deref(), Some("Halfway"));
            }
            _ => panic!("Expected Progress notification"),
        }
    }

    // Without a progress token, reporting progress must send nothing.
    #[tokio::test]
    async fn test_progress_no_token() {
        let (tx, mut rx) = notification_channel(10);

        // No progress token is configured on this context.
        let ctx = RequestContext::new(RequestId::Number(1)).with_notification_sender(tx);

        ctx.report_progress(50.0, Some(100.0), None).await;

        // The channel must still be empty.
        assert!(rx.try_recv().is_err());
    }

    // The builder carries the request id and progress token into the context.
    #[test]
    fn test_builder() {
        let (tx, _rx) = notification_channel(10);

        let ctx = RequestContextBuilder::new()
            .request_id(RequestId::String("req-1".to_string()))
            .progress_token(ProgressToken::String("prog-1".to_string()))
            .notification_sender(tx)
            .build();

        assert_eq!(ctx.request_id(), &RequestId::String("req-1".to_string()));
        assert!(ctx.progress_token().is_some());
    }

    // A bare context has no requester, so sampling is unavailable.
    #[test]
    fn test_can_sample_without_requester() {
        let ctx = RequestContext::new(RequestId::Number(1));
        assert!(!ctx.can_sample());
    }

    // Attaching a requester makes sampling available.
    #[test]
    fn test_can_sample_with_requester() {
        let (request_tx, _rx) = outgoing_request_channel(10);
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));

        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
        assert!(ctx.can_sample());
    }

    // Sampling without a requester fails with the "Sampling not available" error.
    #[tokio::test]
    async fn test_sample_without_requester_fails() {
        use crate::protocol::{CreateMessageParams, SamplingMessage};

        let ctx = RequestContext::new(RequestId::Number(1));
        let params = CreateMessageParams::new(vec![SamplingMessage::user("test")], 100);

        let result = ctx.sample(params).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Sampling not available")
        );
    }

    // The builder forwards the client requester into the context.
    #[test]
    fn test_builder_with_client_requester() {
        let (request_tx, _rx) = outgoing_request_channel(10);
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));

        let ctx = RequestContextBuilder::new()
            .request_id(RequestId::Number(1))
            .client_requester(requester)
            .build();

        assert!(ctx.can_sample());
    }

    // A bare context has no requester, so elicitation is unavailable.
    #[test]
    fn test_can_elicit_without_requester() {
        let ctx = RequestContext::new(RequestId::Number(1));
        assert!(!ctx.can_elicit());
    }

    // Attaching a requester makes elicitation available.
    #[test]
    fn test_can_elicit_with_requester() {
        let (request_tx, _rx) = outgoing_request_channel(10);
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));

        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
        assert!(ctx.can_elicit());
    }

    // Form elicitation without a requester fails with "Elicitation not available".
    #[tokio::test]
    async fn test_elicit_form_without_requester_fails() {
        use crate::protocol::{ElicitFormSchema, ElicitMode};

        let ctx = RequestContext::new(RequestId::Number(1));
        let params = ElicitFormParams {
            mode: Some(ElicitMode::Form),
            message: "Enter details".to_string(),
            requested_schema: ElicitFormSchema::new().string_field("name", None, true),
            meta: None,
        };

        let result = ctx.elicit_form(params).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Elicitation not available")
        );
    }

    // URL elicitation without a requester fails with "Elicitation not available".
    #[tokio::test]
    async fn test_elicit_url_without_requester_fails() {
        use crate::protocol::ElicitMode;

        let ctx = RequestContext::new(RequestId::Number(1));
        let params = ElicitUrlParams {
            mode: Some(ElicitMode::Url),
            elicitation_id: "test-123".to_string(),
            message: "Please authorize".to_string(),
            url: "https://example.com/auth".to_string(),
            meta: None,
        };

        let result = ctx.elicit_url(params).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Elicitation not available")
        );
    }

    // `confirm` goes through form elicitation, so it fails the same way.
    #[tokio::test]
    async fn test_confirm_without_requester_fails() {
        let ctx = RequestContext::new(RequestId::Number(1));

        let result = ctx.confirm("Are you sure?").await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Elicitation not available")
        );
    }

    // With the threshold at Warning, Error and Warning pass while Info and
    // Debug are filtered out.
    #[tokio::test]
    async fn test_send_log_filtered_by_level() {
        let (tx, mut rx) = notification_channel(10);
        let min_level = Arc::new(RwLock::new(LogLevel::Warning));

        let ctx = RequestContext::new(RequestId::Number(1))
            .with_notification_sender(tx)
            .with_min_log_level(min_level.clone());

        // Error: expected to pass.
        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Error,
            serde_json::Value::Null,
        ));
        let msg = rx.try_recv();
        assert!(msg.is_ok(), "Error should pass through Warning filter");

        // Warning (equal to the threshold): expected to pass.
        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Warning,
            serde_json::Value::Null,
        ));
        let msg = rx.try_recv();
        assert!(msg.is_ok(), "Warning should pass through Warning filter");

        // Info: expected to be filtered.
        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Info,
            serde_json::Value::Null,
        ));
        let msg = rx.try_recv();
        assert!(msg.is_err(), "Info should be filtered by Warning filter");

        // Debug: expected to be filtered.
        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Debug,
            serde_json::Value::Null,
        ));
        let msg = rx.try_recv();
        assert!(msg.is_err(), "Debug should be filtered by Warning filter");
    }

    // The threshold is read on every call, so changing it through the shared
    // lock affects subsequent messages.
    #[tokio::test]
    async fn test_send_log_level_updates_dynamically() {
        let (tx, mut rx) = notification_channel(10);
        let min_level = Arc::new(RwLock::new(LogLevel::Error));

        let ctx = RequestContext::new(RequestId::Number(1))
            .with_notification_sender(tx)
            .with_min_log_level(min_level.clone());

        // Info is filtered while the threshold is Error.
        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Info,
            serde_json::Value::Null,
        ));
        assert!(
            rx.try_recv().is_err(),
            "Info should be filtered at Error level"
        );

        // Change the threshold through the shared lock.
        *min_level.write().unwrap() = LogLevel::Debug;

        // The same Info message now passes.
        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Info,
            serde_json::Value::Null,
        ));
        assert!(
            rx.try_recv().is_ok(),
            "Info should pass through after level changed to Debug"
        );
    }

    // Without a configured threshold nothing is filtered.
    #[tokio::test]
    async fn test_send_log_no_min_level_sends_all() {
        let (tx, mut rx) = notification_channel(10);

        // No minimum log level is configured on this context.
        let ctx = RequestContext::new(RequestId::Number(1)).with_notification_sender(tx);

        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Debug,
            serde_json::Value::Null,
        ));
        assert!(
            rx.try_recv().is_ok(),
            "Debug should pass when no min level is set"
        );
    }

    // Builds the JSON form of a task object with fixed timestamps.
    fn make_task_object(id: &str, status: TaskStatus) -> serde_json::Value {
        serde_json::json!({
            "taskId": id,
            "status": status,
            "createdAt": "2026-04-24T00:00:00Z",
            "lastUpdatedAt": "2026-04-24T00:00:00Z",
            "ttl": null
        })
    }

    // Spawns a task that answers every outgoing request with
    // `Ok(responder(method, params))` until the request channel closes.
    fn spawn_mock_client(
        mut rx: OutgoingRequestReceiver,
        responder: impl Fn(&str, serde_json::Value) -> serde_json::Value + Send + 'static,
    ) {
        tokio::spawn(async move {
            while let Some(req) = rx.recv().await {
                let response = responder(&req.method, req.params);
                // The requester may already be gone; ignore a failed send.
                let _ = req.response_tx.send(Ok(response));
            }
        });
    }

    // `get_task_info` sends `tasks/get` with the task id and decodes the reply.
    #[tokio::test]
    async fn test_get_task_info_round_trips() {
        let (tx, rx) = outgoing_request_channel(10);
        spawn_mock_client(rx, |method, params| {
            assert_eq!(method, "tasks/get");
            let task_id = params["taskId"].as_str().unwrap().to_string();
            make_task_object(&task_id, TaskStatus::Working)
        });
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(tx));
        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);

        let info = ctx.get_task_info("task-123").await.unwrap();
        assert_eq!(info.task_id, "task-123");
        assert!(matches!(info.status, TaskStatus::Working));
    }

    // `list_tasks` sends `tasks/list` with the status filter and decodes the list.
    #[tokio::test]
    async fn test_list_tasks_round_trips() {
        let (tx, rx) = outgoing_request_channel(10);
        spawn_mock_client(rx, |method, params| {
            assert_eq!(method, "tasks/list");
            assert_eq!(params["status"], serde_json::json!("working"));
            // Reply with two working tasks.
            serde_json::json!({
                "tasks": [
                    make_task_object("task-1", TaskStatus::Working),
                    make_task_object("task-2", TaskStatus::Working),
                ]
            })
        });
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(tx));
        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);

        let result = ctx.list_tasks(Some(TaskStatus::Working)).await.unwrap();
        assert_eq!(result.tasks.len(), 2);
        assert_eq!(result.tasks[0].task_id, "task-1");
    }

    // `cancel_task` sends `tasks/cancel` and forwards the optional reason.
    #[tokio::test]
    async fn test_cancel_task_forwards_reason() {
        let (tx, rx) = outgoing_request_channel(10);
        spawn_mock_client(rx, |method, params| {
            assert_eq!(method, "tasks/cancel");
            assert_eq!(params["reason"], serde_json::json!("user requested"));
            let task_id = params["taskId"].as_str().unwrap().to_string();
            make_task_object(&task_id, TaskStatus::Cancelled)
        });
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(tx));
        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);

        let task = ctx
            .cancel_task("task-99", Some("user requested".into()))
            .await
            .unwrap();
        assert_eq!(task.task_id, "task-99");
        assert!(matches!(task.status, TaskStatus::Cancelled));
    }

    // Task helpers without a requester fail with "Client request not available".
    #[tokio::test]
    async fn test_get_task_info_without_requester_fails() {
        let ctx = RequestContext::new(RequestId::Number(1));
        let result = ctx.get_task_info("task-1").await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Client request not available")
        );
    }

    // A requester that does not override `request` gets the trait's default
    // body, which rejects arbitrary requests.
    #[tokio::test]
    async fn test_default_request_impl_errors() {
        // Implements only the two required methods; neither is called here.
        struct OnlySampleAndElicit;

        #[async_trait]
        impl ClientRequester for OnlySampleAndElicit {
            async fn sample(&self, _: CreateMessageParams) -> Result<CreateMessageResult> {
                unreachable!()
            }
            async fn elicit(&self, _: ElicitRequestParams) -> Result<ElicitResult> {
                unreachable!()
            }
        }

        let requester: ClientRequesterHandle = Arc::new(OnlySampleAndElicit);
        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);

        let err = ctx.get_task_info("x").await.unwrap_err();
        assert!(err.to_string().contains("does not support arbitrary"));
    }
}