1use std::sync::Arc;
73use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
74
75use async_trait::async_trait;
76use tokio::sync::mpsc;
77
78use crate::error::{Error, Result};
79use crate::protocol::{
80 CreateMessageParams, CreateMessageResult, ElicitFormParams, ElicitRequestParams, ElicitResult,
81 ElicitUrlParams, LoggingMessageParams, ProgressParams, ProgressToken, RequestId,
82};
83
/// Notifications the server side can emit outside the normal request/response flow.
///
/// They travel over a [`NotificationSender`]. Within this file only `Progress` is
/// constructed (by `RequestContext::report_progress_sync`); the consumer of the channel
/// presumably forwards them to the client — confirm against the transport code.
#[derive(Debug, Clone)]
pub enum ServerNotification {
    /// Progress update for a request that carried a progress token.
    Progress(ProgressParams),
    /// A log message; presumably forwarded as a client-visible logging notification.
    LogMessage(LoggingMessageParams),
    /// A single resource changed.
    ResourceUpdated {
        /// URI identifying the resource that changed.
        uri: String,
    },
    /// The list of available resources changed.
    ResourcesListChanged,
}
99
100pub type NotificationSender = mpsc::Sender<ServerNotification>;
102
103pub type NotificationReceiver = mpsc::Receiver<ServerNotification>;
105
106pub fn notification_channel(buffer: usize) -> (NotificationSender, NotificationReceiver) {
108 mpsc::channel(buffer)
109}
110
/// Sends requests in the reverse direction: from the server to the connected client.
///
/// The `Send + Sync` supertraits let one implementation be shared across tasks through a
/// [`ClientRequesterHandle`].
#[async_trait]
pub trait ClientRequester: Send + Sync {
    /// Asks the client to create a message (sampling). [`ChannelClientRequester`] sends
    /// this as the `"sampling/createMessage"` method.
    async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult>;

    /// Asks the client to elicit input, either via a form or a URL
    /// (`ElicitRequestParams::Form` / `ElicitRequestParams::Url`).
    /// [`ChannelClientRequester`] sends this as the `"elicitation/create"` method.
    async fn elicit(&self, params: ElicitRequestParams) -> Result<ElicitResult>;
}

/// Shared, type-erased handle to a [`ClientRequester`] implementation.
pub type ClientRequesterHandle = Arc<dyn ClientRequester>;
138
/// A server-to-client request queued on an [`OutgoingRequestSender`], bundled with the
/// one-shot channel through which the reply must be delivered.
#[derive(Debug)]
pub struct OutgoingRequest {
    /// Request id allocated by the requester (see `ChannelClientRequester`).
    pub id: RequestId,
    /// Method name, e.g. `"sampling/createMessage"` or `"elicitation/create"`.
    pub method: String,
    /// Request parameters, already serialized to JSON.
    pub params: serde_json::Value,
    /// Destination for the client's result or error. Dropping this without sending makes
    /// the waiting `ChannelClientRequester` fail with a "channel closed" error.
    pub response_tx: tokio::sync::oneshot::Sender<Result<serde_json::Value>>,
}
151
152pub type OutgoingRequestSender = mpsc::Sender<OutgoingRequest>;
154
155pub type OutgoingRequestReceiver = mpsc::Receiver<OutgoingRequest>;
157
158pub fn outgoing_request_channel(buffer: usize) -> (OutgoingRequestSender, OutgoingRequestReceiver) {
160 mpsc::channel(buffer)
161}
162
163#[derive(Clone)]
165pub struct ChannelClientRequester {
166 request_tx: OutgoingRequestSender,
167 next_id: Arc<AtomicI64>,
168}
169
170impl ChannelClientRequester {
171 pub fn new(request_tx: OutgoingRequestSender) -> Self {
173 Self {
174 request_tx,
175 next_id: Arc::new(AtomicI64::new(1)),
176 }
177 }
178
179 fn next_request_id(&self) -> RequestId {
180 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
181 RequestId::Number(id)
182 }
183}
184
185#[async_trait]
186impl ClientRequester for ChannelClientRequester {
187 async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult> {
188 let id = self.next_request_id();
189 let params_json = serde_json::to_value(¶ms)
190 .map_err(|e| Error::Internal(format!("Failed to serialize params: {}", e)))?;
191
192 let (response_tx, response_rx) = tokio::sync::oneshot::channel();
193
194 let request = OutgoingRequest {
195 id: id.clone(),
196 method: "sampling/createMessage".to_string(),
197 params: params_json,
198 response_tx,
199 };
200
201 self.request_tx
202 .send(request)
203 .await
204 .map_err(|_| Error::Internal("Failed to send request: channel closed".to_string()))?;
205
206 let response = response_rx.await.map_err(|_| {
207 Error::Internal("Failed to receive response: channel closed".to_string())
208 })??;
209
210 serde_json::from_value(response)
211 .map_err(|e| Error::Internal(format!("Failed to deserialize response: {}", e)))
212 }
213
214 async fn elicit(&self, params: ElicitRequestParams) -> Result<ElicitResult> {
215 let id = self.next_request_id();
216 let params_json = serde_json::to_value(¶ms)
217 .map_err(|e| Error::Internal(format!("Failed to serialize params: {}", e)))?;
218
219 let (response_tx, response_rx) = tokio::sync::oneshot::channel();
220
221 let request = OutgoingRequest {
222 id: id.clone(),
223 method: "elicitation/create".to_string(),
224 params: params_json,
225 response_tx,
226 };
227
228 self.request_tx
229 .send(request)
230 .await
231 .map_err(|_| Error::Internal("Failed to send request: channel closed".to_string()))?;
232
233 let response = response_rx.await.map_err(|_| {
234 Error::Internal("Failed to receive response: channel closed".to_string())
235 })??;
236
237 serde_json::from_value(response)
238 .map_err(|e| Error::Internal(format!("Failed to deserialize response: {}", e)))
239 }
240}
241
/// Per-request state: the request id, an optional progress token, a cancellation flag,
/// and optional channels back to the client (presumably handed to request handlers —
/// confirm against the dispatcher).
///
/// Cloning copies the `Arc`, so all clones share one cancellation flag.
#[derive(Clone)]
pub struct RequestContext {
    // Id of the request this context belongs to.
    request_id: RequestId,
    // Progress token from the request; progress is only reported when this is set.
    progress_token: Option<ProgressToken>,
    // Cancellation flag shared with every clone and every derived `CancellationToken`.
    cancelled: Arc<AtomicBool>,
    // Channel for notifications such as progress; `None` disables progress reporting.
    notification_tx: Option<NotificationSender>,
    // Used for sampling/elicitation; `None` makes those calls return an error.
    client_requester: Option<ClientRequesterHandle>,
}
256
257impl std::fmt::Debug for RequestContext {
258 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
259 f.debug_struct("RequestContext")
260 .field("request_id", &self.request_id)
261 .field("progress_token", &self.progress_token)
262 .field("cancelled", &self.cancelled.load(Ordering::Relaxed))
263 .finish()
264 }
265}
266
267impl RequestContext {
268 pub fn new(request_id: RequestId) -> Self {
270 Self {
271 request_id,
272 progress_token: None,
273 cancelled: Arc::new(AtomicBool::new(false)),
274 notification_tx: None,
275 client_requester: None,
276 }
277 }
278
279 pub fn with_progress_token(mut self, token: ProgressToken) -> Self {
281 self.progress_token = Some(token);
282 self
283 }
284
285 pub fn with_notification_sender(mut self, tx: NotificationSender) -> Self {
287 self.notification_tx = Some(tx);
288 self
289 }
290
291 pub fn with_client_requester(mut self, requester: ClientRequesterHandle) -> Self {
293 self.client_requester = Some(requester);
294 self
295 }
296
297 pub fn request_id(&self) -> &RequestId {
299 &self.request_id
300 }
301
302 pub fn progress_token(&self) -> Option<&ProgressToken> {
304 self.progress_token.as_ref()
305 }
306
307 pub fn is_cancelled(&self) -> bool {
309 self.cancelled.load(Ordering::Relaxed)
310 }
311
312 pub fn cancel(&self) {
314 self.cancelled.store(true, Ordering::Relaxed);
315 }
316
317 pub fn cancellation_token(&self) -> CancellationToken {
319 CancellationToken {
320 cancelled: self.cancelled.clone(),
321 }
322 }
323
324 pub async fn report_progress(&self, progress: f64, total: Option<f64>, message: Option<&str>) {
328 let Some(token) = &self.progress_token else {
329 return;
330 };
331 let Some(tx) = &self.notification_tx else {
332 return;
333 };
334
335 let params = ProgressParams {
336 progress_token: token.clone(),
337 progress,
338 total,
339 message: message.map(|s| s.to_string()),
340 };
341
342 let _ = tx.try_send(ServerNotification::Progress(params));
344 }
345
346 pub fn report_progress_sync(&self, progress: f64, total: Option<f64>, message: Option<&str>) {
350 let Some(token) = &self.progress_token else {
351 return;
352 };
353 let Some(tx) = &self.notification_tx else {
354 return;
355 };
356
357 let params = ProgressParams {
358 progress_token: token.clone(),
359 progress,
360 total,
361 message: message.map(|s| s.to_string()),
362 };
363
364 let _ = tx.try_send(ServerNotification::Progress(params));
365 }
366
367 pub fn can_sample(&self) -> bool {
372 self.client_requester.is_some()
373 }
374
375 pub async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult> {
399 let requester = self.client_requester.as_ref().ok_or_else(|| {
400 Error::Internal("Sampling not available: no client requester configured".to_string())
401 })?;
402
403 requester.sample(params).await
404 }
405
406 pub fn can_elicit(&self) -> bool {
412 self.client_requester.is_some()
413 }
414
415 pub async fn elicit_form(&self, params: ElicitFormParams) -> Result<ElicitResult> {
447 let requester = self.client_requester.as_ref().ok_or_else(|| {
448 Error::Internal("Elicitation not available: no client requester configured".to_string())
449 })?;
450
451 requester.elicit(ElicitRequestParams::Form(params)).await
452 }
453
454 pub async fn elicit_url(&self, params: ElicitUrlParams) -> Result<ElicitResult> {
484 let requester = self.client_requester.as_ref().ok_or_else(|| {
485 Error::Internal("Elicitation not available: no client requester configured".to_string())
486 })?;
487
488 requester.elicit(ElicitRequestParams::Url(params)).await
489 }
490}
491
/// Cheap, cloneable view of a request's cancellation flag.
///
/// Every clone, and the `RequestContext` it came from, observes and sets the same flag.
#[derive(Debug, Clone)]
pub struct CancellationToken {
    cancelled: Arc<AtomicBool>,
}

impl CancellationToken {
    /// Reports whether cancellation has been requested through any holder of the flag.
    pub fn is_cancelled(&self) -> bool {
        let flag: &AtomicBool = &self.cancelled;
        flag.load(Ordering::Relaxed)
    }

    /// Requests cancellation; it is a sticky flag with no way to clear it here.
    pub fn cancel(&self) {
        let flag: &AtomicBool = &self.cancelled;
        flag.store(true, Ordering::Relaxed);
    }
}
509
510#[derive(Default)]
512pub struct RequestContextBuilder {
513 request_id: Option<RequestId>,
514 progress_token: Option<ProgressToken>,
515 notification_tx: Option<NotificationSender>,
516 client_requester: Option<ClientRequesterHandle>,
517}
518
519impl RequestContextBuilder {
520 pub fn new() -> Self {
522 Self::default()
523 }
524
525 pub fn request_id(mut self, id: RequestId) -> Self {
527 self.request_id = Some(id);
528 self
529 }
530
531 pub fn progress_token(mut self, token: ProgressToken) -> Self {
533 self.progress_token = Some(token);
534 self
535 }
536
537 pub fn notification_sender(mut self, tx: NotificationSender) -> Self {
539 self.notification_tx = Some(tx);
540 self
541 }
542
543 pub fn client_requester(mut self, requester: ClientRequesterHandle) -> Self {
545 self.client_requester = Some(requester);
546 self
547 }
548
549 pub fn build(self) -> RequestContext {
553 let mut ctx = RequestContext::new(self.request_id.expect("request_id is required"));
554 if let Some(token) = self.progress_token {
555 ctx = ctx.with_progress_token(token);
556 }
557 if let Some(tx) = self.notification_tx {
558 ctx = ctx.with_notification_sender(tx);
559 }
560 if let Some(requester) = self.client_requester {
561 ctx = ctx.with_client_requester(requester);
562 }
563 ctx
564 }
565}
566
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cancellation() {
        let ctx = RequestContext::new(RequestId::Number(1));
        assert!(!ctx.is_cancelled());

        let token = ctx.cancellation_token();
        assert!(!token.is_cancelled());

        ctx.cancel();
        assert!(ctx.is_cancelled());
        assert!(token.is_cancelled());
    }

    #[test]
    fn test_clone_shares_cancellation() {
        let ctx = RequestContext::new(RequestId::Number(1));
        let clone = ctx.clone();

        // Cancelling through a token derived from the clone must be visible on the original.
        clone.cancellation_token().cancel();
        assert!(ctx.is_cancelled());
    }

    #[tokio::test]
    async fn test_progress_reporting() {
        let (tx, mut rx) = notification_channel(10);

        let ctx = RequestContext::new(RequestId::Number(1))
            .with_progress_token(ProgressToken::Number(42))
            .with_notification_sender(tx);

        ctx.report_progress(50.0, Some(100.0), Some("Halfway"))
            .await;

        let notification = rx.recv().await.unwrap();
        match notification {
            ServerNotification::Progress(params) => {
                assert_eq!(params.progress, 50.0);
                assert_eq!(params.total, Some(100.0));
                assert_eq!(params.message.as_deref(), Some("Halfway"));
            }
            _ => panic!("Expected Progress notification"),
        }
    }

    #[tokio::test]
    async fn test_progress_no_token() {
        let (tx, mut rx) = notification_channel(10);

        let ctx = RequestContext::new(RequestId::Number(1)).with_notification_sender(tx);

        ctx.report_progress(50.0, Some(100.0), None).await;

        assert!(rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_progress_sync_drops_when_channel_full() {
        let (tx, mut rx) = notification_channel(1);

        let ctx = RequestContext::new(RequestId::Number(1))
            .with_progress_token(ProgressToken::Number(7))
            .with_notification_sender(tx);

        ctx.report_progress_sync(1.0, None, None);
        // Capacity is 1, so this second update must be dropped rather than block.
        ctx.report_progress_sync(2.0, None, None);

        match rx.try_recv() {
            Ok(ServerNotification::Progress(params)) => assert_eq!(params.progress, 1.0),
            _ => panic!("Expected the first Progress notification"),
        }
        assert!(rx.try_recv().is_err());
    }

    #[test]
    fn test_builder() {
        let (tx, _rx) = notification_channel(10);

        let ctx = RequestContextBuilder::new()
            .request_id(RequestId::String("req-1".to_string()))
            .progress_token(ProgressToken::String("prog-1".to_string()))
            .notification_sender(tx)
            .build();

        assert_eq!(ctx.request_id(), &RequestId::String("req-1".to_string()));
        assert!(ctx.progress_token().is_some());
    }

    #[test]
    fn test_can_sample_without_requester() {
        let ctx = RequestContext::new(RequestId::Number(1));
        assert!(!ctx.can_sample());
    }

    #[test]
    fn test_can_sample_with_requester() {
        let (request_tx, _rx) = outgoing_request_channel(10);
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));

        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
        assert!(ctx.can_sample());
    }

    #[tokio::test]
    async fn test_sample_without_requester_fails() {
        use crate::protocol::{CreateMessageParams, SamplingMessage};

        let ctx = RequestContext::new(RequestId::Number(1));
        let params = CreateMessageParams::new(vec![SamplingMessage::user("test")], 100);

        let result = ctx.sample(params).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Sampling not available")
        );
    }

    #[tokio::test]
    async fn test_channel_requester_forwards_requests() {
        use crate::protocol::{ElicitMode, SamplingMessage};

        let (request_tx, mut request_rx) = outgoing_request_channel(10);
        let requester = ChannelClientRequester::new(request_tx);

        // Fake client: record every request and answer each one with an error.
        let client = tokio::spawn(async move {
            let mut seen = Vec::new();
            while let Some(request) = request_rx.recv().await {
                seen.push((request.id.clone(), request.method.clone()));
                let _ = request
                    .response_tx
                    .send(Err(Error::Internal("client refused".to_string())));
            }
            seen
        });

        let sample_result = requester
            .sample(CreateMessageParams::new(
                vec![SamplingMessage::user("test")],
                100,
            ))
            .await;
        let elicit_result = requester
            .elicit(ElicitRequestParams::Url(ElicitUrlParams {
                mode: ElicitMode::Url,
                elicitation_id: "e-1".to_string(),
                message: "Please authorize".to_string(),
                url: "https://example.com/auth".to_string(),
                meta: None,
            }))
            .await;

        // Dropping the only sender closes the channel so the fake client finishes.
        drop(requester);
        let seen = client.await.unwrap();

        assert_eq!(
            seen,
            vec![
                (RequestId::Number(1), "sampling/createMessage".to_string()),
                (RequestId::Number(2), "elicitation/create".to_string()),
            ]
        );
        match sample_result {
            Err(e) => assert!(e.to_string().contains("client refused")),
            Ok(_) => panic!("Expected the client error to be propagated"),
        }
        match elicit_result {
            Err(e) => assert!(e.to_string().contains("client refused")),
            Ok(_) => panic!("Expected the client error to be propagated"),
        }
    }

    #[test]
    fn test_builder_with_client_requester() {
        let (request_tx, _rx) = outgoing_request_channel(10);
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));

        let ctx = RequestContextBuilder::new()
            .request_id(RequestId::Number(1))
            .client_requester(requester)
            .build();

        assert!(ctx.can_sample());
    }

    #[test]
    fn test_can_elicit_without_requester() {
        let ctx = RequestContext::new(RequestId::Number(1));
        assert!(!ctx.can_elicit());
    }

    #[test]
    fn test_can_elicit_with_requester() {
        let (request_tx, _rx) = outgoing_request_channel(10);
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));

        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
        assert!(ctx.can_elicit());
    }

    #[tokio::test]
    async fn test_elicit_form_without_requester_fails() {
        use crate::protocol::{ElicitFormSchema, ElicitMode};

        let ctx = RequestContext::new(RequestId::Number(1));
        let params = ElicitFormParams {
            mode: ElicitMode::Form,
            message: "Enter details".to_string(),
            requested_schema: ElicitFormSchema::new().string_field("name", None, true),
            meta: None,
        };

        let result = ctx.elicit_form(params).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Elicitation not available")
        );
    }

    #[tokio::test]
    async fn test_elicit_url_without_requester_fails() {
        use crate::protocol::ElicitMode;

        let ctx = RequestContext::new(RequestId::Number(1));
        let params = ElicitUrlParams {
            mode: ElicitMode::Url,
            elicitation_id: "test-123".to_string(),
            message: "Please authorize".to_string(),
            url: "https://example.com/auth".to_string(),
            meta: None,
        };

        let result = ctx.elicit_url(params).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Elicitation not available")
        );
    }
}