Skip to main content

tower_mcp/
context.rs

//! Request context for MCP handlers
//!
//! Provides progress reporting, cancellation support, and client request capabilities
//! for long-running operations.
//!
//! # Example
//!
//! ```rust,ignore
//! use tower_mcp::context::RequestContext;
//!
//! async fn long_running_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
//!     for i in 0..100 {
//!         // Check if cancelled
//!         if ctx.is_cancelled() {
//!             return Err(Error::tool("Operation cancelled"));
//!         }
//!
//!         // Report progress
//!         ctx.report_progress(i as f64, Some(100.0), Some("Processing...")).await;
//!
//!         do_work(i).await;
//!     }
//!     Ok(CallToolResult::text("Done!"))
//! }
//! ```
//!
//! # Sampling (LLM requests to client)
//!
//! ```rust,ignore
//! use tower_mcp::context::RequestContext;
//! use tower_mcp::{CreateMessageParams, SamplingMessage};
//!
//! async fn ai_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
//!     // Request LLM completion from the client
//!     let params = CreateMessageParams::new(
//!         vec![SamplingMessage::user("Summarize this text...")],
//!         500,
//!     );
//!
//!     let result = ctx.sample(params).await?;
//!     Ok(CallToolResult::text(format!("Summary: {:?}", result.content)))
//! }
//! ```
//!
//! # Elicitation (requesting user input)
//!
//! ```rust,ignore
//! use tower_mcp::context::RequestContext;
//! use tower_mcp::{ElicitFormParams, ElicitFormSchema, ElicitMode, ElicitAction};
//!
//! async fn interactive_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
//!     // Request user input via form
//!     let params = ElicitFormParams {
//!         mode: ElicitMode::Form,
//!         message: "Please provide additional details".to_string(),
//!         requested_schema: ElicitFormSchema::new()
//!             .string_field("name", Some("Your name"), true)
//!             .number_field("age", Some("Your age"), false),
//!         meta: None,
//!     };
//!
//!     let result = ctx.elicit_form(params).await?;
//!     if result.action == ElicitAction::Accept {
//!         // Use the form data
//!         Ok(CallToolResult::text(format!("Got: {:?}", result.content)))
//!     } else {
//!         Ok(CallToolResult::text("User declined"))
//!     }
//! }
//! ```
71
72use std::sync::Arc;
73use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
74
75use async_trait::async_trait;
76use tokio::sync::mpsc;
77
78use crate::error::{Error, Result};
79use crate::protocol::{
80    CreateMessageParams, CreateMessageResult, ElicitFormParams, ElicitRequestParams, ElicitResult,
81    ElicitUrlParams, LoggingMessageParams, ProgressParams, ProgressToken, RequestId,
82};
83
84/// A notification to be sent to the client
85#[derive(Debug, Clone)]
86pub enum ServerNotification {
87    /// Progress update for a request
88    Progress(ProgressParams),
89    /// Log message notification
90    LogMessage(LoggingMessageParams),
91    /// A subscribed resource has been updated
92    ResourceUpdated {
93        /// The URI of the updated resource
94        uri: String,
95    },
96    /// The list of available resources has changed
97    ResourcesListChanged,
98}
99
100/// Sender for server notifications
101pub type NotificationSender = mpsc::Sender<ServerNotification>;
102
103/// Receiver for server notifications
104pub type NotificationReceiver = mpsc::Receiver<ServerNotification>;
105
106/// Create a new notification channel
107pub fn notification_channel(buffer: usize) -> (NotificationSender, NotificationReceiver) {
108    mpsc::channel(buffer)
109}
110
// =============================================================================
// Client Requests (Server -> Client)
// =============================================================================
114
115/// Trait for sending requests from server to client
116///
117/// This enables bidirectional communication where the server can request
118/// actions from the client, such as sampling (LLM requests) and elicitation
119/// (user input requests).
120#[async_trait]
121pub trait ClientRequester: Send + Sync {
122    /// Send a sampling request to the client
123    ///
124    /// Returns the LLM completion result from the client.
125    async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult>;
126
127    /// Send an elicitation request to the client
128    ///
129    /// This requests user input from the client. The request can be either
130    /// form-based (structured input) or URL-based (redirect to external URL).
131    ///
132    /// Returns the elicitation result with the user's action and any submitted data.
133    async fn elicit(&self, params: ElicitRequestParams) -> Result<ElicitResult>;
134}
135
136/// A clonable handle to a client requester
137pub type ClientRequesterHandle = Arc<dyn ClientRequester>;
138
139/// Outgoing request to be sent to the client
140#[derive(Debug)]
141pub struct OutgoingRequest {
142    /// The JSON-RPC request ID
143    pub id: RequestId,
144    /// The method name
145    pub method: String,
146    /// The request parameters as JSON
147    pub params: serde_json::Value,
148    /// Channel to send the response back
149    pub response_tx: tokio::sync::oneshot::Sender<Result<serde_json::Value>>,
150}
151
152/// Sender for outgoing requests to the client
153pub type OutgoingRequestSender = mpsc::Sender<OutgoingRequest>;
154
155/// Receiver for outgoing requests (used by transport)
156pub type OutgoingRequestReceiver = mpsc::Receiver<OutgoingRequest>;
157
158/// Create a new outgoing request channel
159pub fn outgoing_request_channel(buffer: usize) -> (OutgoingRequestSender, OutgoingRequestReceiver) {
160    mpsc::channel(buffer)
161}
162
163/// A client requester implementation that sends requests through a channel
164#[derive(Clone)]
165pub struct ChannelClientRequester {
166    request_tx: OutgoingRequestSender,
167    next_id: Arc<AtomicI64>,
168}
169
170impl ChannelClientRequester {
171    /// Create a new channel-based client requester
172    pub fn new(request_tx: OutgoingRequestSender) -> Self {
173        Self {
174            request_tx,
175            next_id: Arc::new(AtomicI64::new(1)),
176        }
177    }
178
179    fn next_request_id(&self) -> RequestId {
180        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
181        RequestId::Number(id)
182    }
183}
184
185#[async_trait]
186impl ClientRequester for ChannelClientRequester {
187    async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult> {
188        let id = self.next_request_id();
189        let params_json = serde_json::to_value(&params)
190            .map_err(|e| Error::Internal(format!("Failed to serialize params: {}", e)))?;
191
192        let (response_tx, response_rx) = tokio::sync::oneshot::channel();
193
194        let request = OutgoingRequest {
195            id: id.clone(),
196            method: "sampling/createMessage".to_string(),
197            params: params_json,
198            response_tx,
199        };
200
201        self.request_tx
202            .send(request)
203            .await
204            .map_err(|_| Error::Internal("Failed to send request: channel closed".to_string()))?;
205
206        let response = response_rx.await.map_err(|_| {
207            Error::Internal("Failed to receive response: channel closed".to_string())
208        })??;
209
210        serde_json::from_value(response)
211            .map_err(|e| Error::Internal(format!("Failed to deserialize response: {}", e)))
212    }
213
214    async fn elicit(&self, params: ElicitRequestParams) -> Result<ElicitResult> {
215        let id = self.next_request_id();
216        let params_json = serde_json::to_value(&params)
217            .map_err(|e| Error::Internal(format!("Failed to serialize params: {}", e)))?;
218
219        let (response_tx, response_rx) = tokio::sync::oneshot::channel();
220
221        let request = OutgoingRequest {
222            id: id.clone(),
223            method: "elicitation/create".to_string(),
224            params: params_json,
225            response_tx,
226        };
227
228        self.request_tx
229            .send(request)
230            .await
231            .map_err(|_| Error::Internal("Failed to send request: channel closed".to_string()))?;
232
233        let response = response_rx.await.map_err(|_| {
234            Error::Internal("Failed to receive response: channel closed".to_string())
235        })??;
236
237        serde_json::from_value(response)
238            .map_err(|e| Error::Internal(format!("Failed to deserialize response: {}", e)))
239    }
240}
241
242/// Context for a request, providing progress, cancellation, and client request support
243#[derive(Clone)]
244pub struct RequestContext {
245    /// The request ID
246    request_id: RequestId,
247    /// Progress token (if provided by client)
248    progress_token: Option<ProgressToken>,
249    /// Cancellation flag
250    cancelled: Arc<AtomicBool>,
251    /// Channel for sending notifications
252    notification_tx: Option<NotificationSender>,
253    /// Handle for sending requests to the client (for sampling, etc.)
254    client_requester: Option<ClientRequesterHandle>,
255}
256
257impl std::fmt::Debug for RequestContext {
258    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
259        f.debug_struct("RequestContext")
260            .field("request_id", &self.request_id)
261            .field("progress_token", &self.progress_token)
262            .field("cancelled", &self.cancelled.load(Ordering::Relaxed))
263            .finish()
264    }
265}
266
267impl RequestContext {
268    /// Create a new request context
269    pub fn new(request_id: RequestId) -> Self {
270        Self {
271            request_id,
272            progress_token: None,
273            cancelled: Arc::new(AtomicBool::new(false)),
274            notification_tx: None,
275            client_requester: None,
276        }
277    }
278
279    /// Set the progress token
280    pub fn with_progress_token(mut self, token: ProgressToken) -> Self {
281        self.progress_token = Some(token);
282        self
283    }
284
285    /// Set the notification sender
286    pub fn with_notification_sender(mut self, tx: NotificationSender) -> Self {
287        self.notification_tx = Some(tx);
288        self
289    }
290
291    /// Set the client requester for server-to-client requests
292    pub fn with_client_requester(mut self, requester: ClientRequesterHandle) -> Self {
293        self.client_requester = Some(requester);
294        self
295    }
296
297    /// Get the request ID
298    pub fn request_id(&self) -> &RequestId {
299        &self.request_id
300    }
301
302    /// Get the progress token (if any)
303    pub fn progress_token(&self) -> Option<&ProgressToken> {
304        self.progress_token.as_ref()
305    }
306
307    /// Check if the request has been cancelled
308    pub fn is_cancelled(&self) -> bool {
309        self.cancelled.load(Ordering::Relaxed)
310    }
311
312    /// Mark the request as cancelled
313    pub fn cancel(&self) {
314        self.cancelled.store(true, Ordering::Relaxed);
315    }
316
317    /// Get a cancellation token that can be shared
318    pub fn cancellation_token(&self) -> CancellationToken {
319        CancellationToken {
320            cancelled: self.cancelled.clone(),
321        }
322    }
323
324    /// Report progress to the client
325    ///
326    /// This is a no-op if no progress token was provided or no notification sender is configured.
327    pub async fn report_progress(&self, progress: f64, total: Option<f64>, message: Option<&str>) {
328        let Some(token) = &self.progress_token else {
329            return;
330        };
331        let Some(tx) = &self.notification_tx else {
332            return;
333        };
334
335        let params = ProgressParams {
336            progress_token: token.clone(),
337            progress,
338            total,
339            message: message.map(|s| s.to_string()),
340        };
341
342        // Best effort - don't block if channel is full
343        let _ = tx.try_send(ServerNotification::Progress(params));
344    }
345
346    /// Report progress synchronously (non-async version)
347    ///
348    /// This is a no-op if no progress token was provided or no notification sender is configured.
349    pub fn report_progress_sync(&self, progress: f64, total: Option<f64>, message: Option<&str>) {
350        let Some(token) = &self.progress_token else {
351            return;
352        };
353        let Some(tx) = &self.notification_tx else {
354            return;
355        };
356
357        let params = ProgressParams {
358            progress_token: token.clone(),
359            progress,
360            total,
361            message: message.map(|s| s.to_string()),
362        };
363
364        let _ = tx.try_send(ServerNotification::Progress(params));
365    }
366
367    /// Check if sampling is available
368    ///
369    /// Returns true if a client requester is configured and the transport
370    /// supports bidirectional communication.
371    pub fn can_sample(&self) -> bool {
372        self.client_requester.is_some()
373    }
374
375    /// Request an LLM completion from the client
376    ///
377    /// This sends a `sampling/createMessage` request to the client and waits
378    /// for the response. The client is expected to forward this to an LLM
379    /// and return the result.
380    ///
381    /// Returns an error if sampling is not available (no client requester configured).
382    ///
383    /// # Example
384    ///
385    /// ```rust,ignore
386    /// use tower_mcp::{CreateMessageParams, SamplingMessage};
387    ///
388    /// async fn my_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
389    ///     let params = CreateMessageParams::new(
390    ///         vec![SamplingMessage::user("Summarize: ...")],
391    ///         500,
392    ///     );
393    ///
394    ///     let result = ctx.sample(params).await?;
395    ///     Ok(CallToolResult::text(format!("{:?}", result.content)))
396    /// }
397    /// ```
398    pub async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult> {
399        let requester = self.client_requester.as_ref().ok_or_else(|| {
400            Error::Internal("Sampling not available: no client requester configured".to_string())
401        })?;
402
403        requester.sample(params).await
404    }
405
406    /// Check if elicitation is available
407    ///
408    /// Returns true if a client requester is configured and the transport
409    /// supports bidirectional communication. Note that this only checks if
410    /// the mechanism is available, not whether the client supports elicitation.
411    pub fn can_elicit(&self) -> bool {
412        self.client_requester.is_some()
413    }
414
415    /// Request user input via a form from the client
416    ///
417    /// This sends an `elicitation/create` request to the client with a form schema.
418    /// The client renders the form to the user and returns their response.
419    ///
420    /// Returns an error if elicitation is not available (no client requester configured).
421    ///
422    /// # Example
423    ///
424    /// ```rust,ignore
425    /// use tower_mcp::{ElicitFormParams, ElicitFormSchema, ElicitMode, ElicitAction};
426    ///
427    /// async fn my_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
428    ///     let params = ElicitFormParams {
429    ///         mode: ElicitMode::Form,
430    ///         message: "Please enter your details".to_string(),
431    ///         requested_schema: ElicitFormSchema::new()
432    ///             .string_field("name", Some("Your name"), true),
433    ///         meta: None,
434    ///     };
435    ///
436    ///     let result = ctx.elicit_form(params).await?;
437    ///     match result.action {
438    ///         ElicitAction::Accept => {
439    ///             // Use result.content
440    ///             Ok(CallToolResult::text("Got your input!"))
441    ///         }
442    ///         _ => Ok(CallToolResult::text("User declined"))
443    ///     }
444    /// }
445    /// ```
446    pub async fn elicit_form(&self, params: ElicitFormParams) -> Result<ElicitResult> {
447        let requester = self.client_requester.as_ref().ok_or_else(|| {
448            Error::Internal("Elicitation not available: no client requester configured".to_string())
449        })?;
450
451        requester.elicit(ElicitRequestParams::Form(params)).await
452    }
453
454    /// Request user input via URL redirect from the client
455    ///
456    /// This sends an `elicitation/create` request to the client with a URL.
457    /// The client directs the user to the URL for out-of-band input collection.
458    /// The server receives the result via a callback notification.
459    ///
460    /// Returns an error if elicitation is not available (no client requester configured).
461    ///
462    /// # Example
463    ///
464    /// ```rust,ignore
465    /// use tower_mcp::{ElicitUrlParams, ElicitMode, ElicitAction};
466    ///
467    /// async fn my_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
468    ///     let params = ElicitUrlParams {
469    ///         mode: ElicitMode::Url,
470    ///         elicitation_id: "unique-id-123".to_string(),
471    ///         message: "Please authorize via the link".to_string(),
472    ///         url: "https://example.com/auth?id=unique-id-123".to_string(),
473    ///         meta: None,
474    ///     };
475    ///
476    ///     let result = ctx.elicit_url(params).await?;
477    ///     match result.action {
478    ///         ElicitAction::Accept => Ok(CallToolResult::text("Authorization complete!")),
479    ///         _ => Ok(CallToolResult::text("Authorization cancelled"))
480    ///     }
481    /// }
482    /// ```
483    pub async fn elicit_url(&self, params: ElicitUrlParams) -> Result<ElicitResult> {
484        let requester = self.client_requester.as_ref().ok_or_else(|| {
485            Error::Internal("Elicitation not available: no client requester configured".to_string())
486        })?;
487
488        requester.elicit(ElicitRequestParams::Url(params)).await
489    }
490}
491
/// Shareable handle to a request's cancellation flag.
///
/// Every clone, and the [`RequestContext`] the token was obtained from, reads
/// and writes the same underlying flag.
#[derive(Clone, Debug)]
pub struct CancellationToken {
    cancelled: Arc<AtomicBool>,
}

impl CancellationToken {
    /// Returns `true` once cancellation has been requested through any handle.
    pub fn is_cancelled(&self) -> bool {
        let flag: &AtomicBool = &self.cancelled;
        flag.load(Ordering::Relaxed)
    }

    /// Request cancellation. Calling this again has no further effect.
    pub fn cancel(&self) {
        let flag: &AtomicBool = &self.cancelled;
        flag.store(true, Ordering::Relaxed);
    }
}
509
510/// Builder for creating request contexts
511#[derive(Default)]
512pub struct RequestContextBuilder {
513    request_id: Option<RequestId>,
514    progress_token: Option<ProgressToken>,
515    notification_tx: Option<NotificationSender>,
516    client_requester: Option<ClientRequesterHandle>,
517}
518
519impl RequestContextBuilder {
520    /// Create a new builder
521    pub fn new() -> Self {
522        Self::default()
523    }
524
525    /// Set the request ID
526    pub fn request_id(mut self, id: RequestId) -> Self {
527        self.request_id = Some(id);
528        self
529    }
530
531    /// Set the progress token
532    pub fn progress_token(mut self, token: ProgressToken) -> Self {
533        self.progress_token = Some(token);
534        self
535    }
536
537    /// Set the notification sender
538    pub fn notification_sender(mut self, tx: NotificationSender) -> Self {
539        self.notification_tx = Some(tx);
540        self
541    }
542
543    /// Set the client requester for server-to-client requests
544    pub fn client_requester(mut self, requester: ClientRequesterHandle) -> Self {
545        self.client_requester = Some(requester);
546        self
547    }
548
549    /// Build the request context
550    ///
551    /// Panics if request_id is not set.
552    pub fn build(self) -> RequestContext {
553        let mut ctx = RequestContext::new(self.request_id.expect("request_id is required"));
554        if let Some(token) = self.progress_token {
555            ctx = ctx.with_progress_token(token);
556        }
557        if let Some(tx) = self.notification_tx {
558            ctx = ctx.with_notification_sender(tx);
559        }
560        if let Some(requester) = self.client_requester {
561            ctx = ctx.with_client_requester(requester);
562        }
563        ctx
564    }
565}
566
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a channel-backed requester handle. The receiver is returned so the
    /// caller can keep the channel open for the test's duration.
    fn channel_requester() -> (ClientRequesterHandle, OutgoingRequestReceiver) {
        let (request_tx, request_rx) = outgoing_request_channel(10);
        let handle: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));
        (handle, request_rx)
    }

    #[test]
    fn test_cancellation() {
        let ctx = RequestContext::new(RequestId::Number(1));
        let token = ctx.cancellation_token();

        // Neither the context nor its token start out cancelled.
        assert!(!ctx.is_cancelled());
        assert!(!token.is_cancelled());

        // Cancelling via the context is visible through the shared token.
        ctx.cancel();
        assert!(ctx.is_cancelled());
        assert!(token.is_cancelled());
    }

    #[tokio::test]
    async fn test_progress_reporting() {
        let (tx, mut rx) = notification_channel(10);
        let ctx = RequestContext::new(RequestId::Number(1))
            .with_progress_token(ProgressToken::Number(42))
            .with_notification_sender(tx);

        ctx.report_progress(50.0, Some(100.0), Some("Halfway")).await;

        let ServerNotification::Progress(params) = rx.recv().await.unwrap() else {
            panic!("Expected Progress notification");
        };
        assert_eq!(params.progress, 50.0);
        assert_eq!(params.total, Some(100.0));
        assert_eq!(params.message.as_deref(), Some("Halfway"));
    }

    #[tokio::test]
    async fn test_progress_no_token() {
        let (tx, mut rx) = notification_channel(10);
        // Without a progress token, reporting must not emit anything.
        let ctx = RequestContext::new(RequestId::Number(1)).with_notification_sender(tx);

        ctx.report_progress(50.0, Some(100.0), None).await;

        assert!(rx.try_recv().is_err(), "no notification should be queued");
    }

    #[test]
    fn test_builder() {
        let (tx, _rx) = notification_channel(10);
        let expected_id = RequestId::String("req-1".to_string());

        let ctx = RequestContextBuilder::new()
            .request_id(expected_id.clone())
            .progress_token(ProgressToken::String("prog-1".to_string()))
            .notification_sender(tx)
            .build();

        assert_eq!(ctx.request_id(), &expected_id);
        assert!(ctx.progress_token().is_some());
    }

    #[test]
    fn test_can_sample_without_requester() {
        assert!(!RequestContext::new(RequestId::Number(1)).can_sample());
    }

    #[test]
    fn test_can_sample_with_requester() {
        let (requester, _rx) = channel_requester();
        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
        assert!(ctx.can_sample());
    }

    #[tokio::test]
    async fn test_sample_without_requester_fails() {
        use crate::protocol::{CreateMessageParams, SamplingMessage};

        let ctx = RequestContext::new(RequestId::Number(1));
        let params = CreateMessageParams::new(vec![SamplingMessage::user("test")], 100);

        let err = ctx
            .sample(params)
            .await
            .expect_err("sampling without a client requester should fail");
        assert!(err.to_string().contains("Sampling not available"));
    }

    #[test]
    fn test_builder_with_client_requester() {
        let (requester, _rx) = channel_requester();

        let ctx = RequestContextBuilder::new()
            .request_id(RequestId::Number(1))
            .client_requester(requester)
            .build();

        assert!(ctx.can_sample());
    }

    #[test]
    fn test_can_elicit_without_requester() {
        assert!(!RequestContext::new(RequestId::Number(1)).can_elicit());
    }

    #[test]
    fn test_can_elicit_with_requester() {
        let (requester, _rx) = channel_requester();
        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
        assert!(ctx.can_elicit());
    }

    #[tokio::test]
    async fn test_elicit_form_without_requester_fails() {
        use crate::protocol::{ElicitFormSchema, ElicitMode};

        let ctx = RequestContext::new(RequestId::Number(1));
        let params = ElicitFormParams {
            mode: ElicitMode::Form,
            message: "Enter details".to_string(),
            requested_schema: ElicitFormSchema::new().string_field("name", None, true),
            meta: None,
        };

        let err = ctx
            .elicit_form(params)
            .await
            .expect_err("form elicitation without a client requester should fail");
        assert!(err.to_string().contains("Elicitation not available"));
    }

    #[tokio::test]
    async fn test_elicit_url_without_requester_fails() {
        use crate::protocol::ElicitMode;

        let ctx = RequestContext::new(RequestId::Number(1));
        let params = ElicitUrlParams {
            mode: ElicitMode::Url,
            elicitation_id: "test-123".to_string(),
            message: "Please authorize".to_string(),
            url: "https://example.com/auth".to_string(),
            meta: None,
        };

        let err = ctx
            .elicit_url(params)
            .await
            .expect_err("URL elicitation without a client requester should fail");
        assert!(err.to_string().contains("Elicitation not available"));
    }
}