Skip to main content

tower_mcp/
context.rs

1//! Request context for MCP handlers
2//!
3//! Provides progress reporting, cancellation support, and client request capabilities
4//! for long-running operations.
5//!
6//! # Example
7//!
8//! ```rust,ignore
9//! use tower_mcp::context::RequestContext;
10//!
11//! async fn long_running_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
12//!     for i in 0..100 {
13//!         // Check if cancelled
14//!         if ctx.is_cancelled() {
15//!             return Err(Error::tool("Operation cancelled"));
16//!         }
17//!
18//!         // Report progress
19//!         ctx.report_progress(i as f64, Some(100.0), Some("Processing...")).await;
20//!
21//!         do_work(i).await;
22//!     }
23//!     Ok(CallToolResult::text("Done!"))
24//! }
25//! ```
26//!
27//! # Sampling (LLM requests to client)
28//!
29//! ```rust,ignore
30//! use tower_mcp::context::RequestContext;
31//! use tower_mcp::{CreateMessageParams, SamplingMessage};
32//!
33//! async fn ai_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
34//!     // Request LLM completion from the client
35//!     let params = CreateMessageParams::new(
36//!         vec![SamplingMessage::user("Summarize this text...")],
37//!         500,
38//!     );
39//!
40//!     let result = ctx.sample(params).await?;
41//!     Ok(CallToolResult::text(format!("Summary: {:?}", result.content)))
42//! }
43//! ```
44//!
45//! # Elicitation (requesting user input)
46//!
47//! ```rust,ignore
48//! use tower_mcp::context::RequestContext;
49//! use tower_mcp::{ElicitFormParams, ElicitFormSchema, ElicitMode, ElicitAction};
50//!
51//! async fn interactive_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
52//!     // Request user input via form
53//!     let params = ElicitFormParams {
54//!         mode: Some(ElicitMode::Form),
55//!         message: "Please provide additional details".to_string(),
56//!         requested_schema: ElicitFormSchema::new()
57//!             .string_field("name", Some("Your name"), true)
58//!             .number_field("age", Some("Your age"), false),
59//!         meta: None,
60//!     };
61//!
62//!     let result = ctx.elicit_form(params).await?;
63//!     if result.action == ElicitAction::Accept {
64//!         // Use the form data
65//!         Ok(CallToolResult::text(format!("Got: {:?}", result.content)))
66//!     } else {
67//!         Ok(CallToolResult::text("User declined"))
68//!     }
69//! }
70//! ```
71
72use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
73use std::sync::{Arc, RwLock};
74
75use async_trait::async_trait;
76use tokio::sync::mpsc;
77
78use crate::error::{Error, Result};
79use crate::protocol::{
80    CallToolResult, CancelTaskParams, CreateMessageParams, CreateMessageResult, ElicitFormParams,
81    ElicitRequestParams, ElicitResult, ElicitUrlParams, GetTaskInfoParams, GetTaskResultParams,
82    ListTasksParams, ListTasksResult, LogLevel, LoggingMessageParams, ProgressParams,
83    ProgressToken, RequestId, TaskObject, TaskStatus,
84};
85
86/// A notification to be sent to the client
87#[derive(Debug, Clone)]
88#[non_exhaustive]
89pub enum ServerNotification {
90    /// Progress update for a request
91    Progress(ProgressParams),
92    /// Log message notification
93    LogMessage(LoggingMessageParams),
94    /// A subscribed resource has been updated
95    ResourceUpdated {
96        /// The URI of the updated resource
97        uri: String,
98    },
99    /// The list of available resources has changed
100    ResourcesListChanged,
101    /// The list of available tools has changed
102    ToolsListChanged,
103    /// The list of available prompts has changed
104    PromptsListChanged,
105    /// Task status has changed
106    TaskStatusChanged(crate::protocol::TaskStatusParams),
107}
108
109/// Sender for server notifications
110pub type NotificationSender = mpsc::Sender<ServerNotification>;
111
112/// Receiver for server notifications
113pub type NotificationReceiver = mpsc::Receiver<ServerNotification>;
114
115/// Create a new notification channel
116pub fn notification_channel(buffer: usize) -> (NotificationSender, NotificationReceiver) {
117    mpsc::channel(buffer)
118}
119
120// =============================================================================
121// Client Requests (Server -> Client)
122// =============================================================================
123
124/// Trait for sending requests from server to client
125///
126/// This enables bidirectional communication where the server can request
127/// actions from the client, such as sampling (LLM requests), elicitation
128/// (user input requests), and task polling (per SEP-1686).
129#[async_trait]
130pub trait ClientRequester: Send + Sync {
131    /// Send a sampling request to the client
132    ///
133    /// Returns the LLM completion result from the client.
134    async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult>;
135
136    /// Send an elicitation request to the client
137    ///
138    /// This requests user input from the client. The request can be either
139    /// form-based (structured input) or URL-based (redirect to external URL).
140    ///
141    /// Returns the elicitation result with the user's action and any submitted data.
142    async fn elicit(&self, params: ElicitRequestParams) -> Result<ElicitResult>;
143
144    /// Send a generic JSON-RPC request to the client.
145    ///
146    /// Used by typed helpers ([`RequestContext::get_task_info`] etc.) to
147    /// dispatch arbitrary request methods. The default implementation returns
148    /// an error so existing custom implementations of this trait keep
149    /// compiling; they only need to override this if they want to support
150    /// methods beyond `sample` and `elicit`.
151    async fn request(
152        &self,
153        method: String,
154        params: serde_json::Value,
155    ) -> Result<serde_json::Value> {
156        let _ = (method, params);
157        Err(Error::Internal(
158            "ClientRequester does not support arbitrary requests".to_string(),
159        ))
160    }
161}
162
163/// A clonable handle to a client requester
164pub type ClientRequesterHandle = Arc<dyn ClientRequester>;
165
166/// Outgoing request to be sent to the client
167#[derive(Debug)]
168pub struct OutgoingRequest {
169    /// The JSON-RPC request ID
170    pub id: RequestId,
171    /// The method name
172    pub method: String,
173    /// The request parameters as JSON
174    pub params: serde_json::Value,
175    /// Channel to send the response back
176    pub response_tx: tokio::sync::oneshot::Sender<Result<serde_json::Value>>,
177}
178
179/// Sender for outgoing requests to the client
180pub type OutgoingRequestSender = mpsc::Sender<OutgoingRequest>;
181
182/// Receiver for outgoing requests (used by transport)
183pub type OutgoingRequestReceiver = mpsc::Receiver<OutgoingRequest>;
184
185/// Create a new outgoing request channel
186pub fn outgoing_request_channel(buffer: usize) -> (OutgoingRequestSender, OutgoingRequestReceiver) {
187    mpsc::channel(buffer)
188}
189
190/// A client requester implementation that sends requests through a channel
191#[derive(Clone)]
192pub struct ChannelClientRequester {
193    request_tx: OutgoingRequestSender,
194    next_id: Arc<AtomicI64>,
195}
196
197impl ChannelClientRequester {
198    /// Create a new channel-based client requester
199    pub fn new(request_tx: OutgoingRequestSender) -> Self {
200        Self {
201            request_tx,
202            next_id: Arc::new(AtomicI64::new(1)),
203        }
204    }
205
206    fn next_request_id(&self) -> RequestId {
207        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
208        RequestId::Number(id)
209    }
210}
211
212impl ChannelClientRequester {
213    async fn dispatch(&self, method: &str, params: serde_json::Value) -> Result<serde_json::Value> {
214        let id = self.next_request_id();
215        let (response_tx, response_rx) = tokio::sync::oneshot::channel();
216
217        let request = OutgoingRequest {
218            id,
219            method: method.to_string(),
220            params,
221            response_tx,
222        };
223
224        self.request_tx
225            .send(request)
226            .await
227            .map_err(|_| Error::Internal("Failed to send request: channel closed".to_string()))?;
228
229        response_rx.await.map_err(|_| {
230            Error::Internal("Failed to receive response: channel closed".to_string())
231        })?
232    }
233}
234
235#[async_trait]
236impl ClientRequester for ChannelClientRequester {
237    async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult> {
238        let params_json = serde_json::to_value(&params)
239            .map_err(|e| Error::Internal(format!("Failed to serialize params: {}", e)))?;
240        let response = self.dispatch("sampling/createMessage", params_json).await?;
241        serde_json::from_value(response)
242            .map_err(|e| Error::Internal(format!("Failed to deserialize response: {}", e)))
243    }
244
245    async fn elicit(&self, params: ElicitRequestParams) -> Result<ElicitResult> {
246        let params_json = serde_json::to_value(&params)
247            .map_err(|e| Error::Internal(format!("Failed to serialize params: {}", e)))?;
248        let response = self.dispatch("elicitation/create", params_json).await?;
249        serde_json::from_value(response)
250            .map_err(|e| Error::Internal(format!("Failed to deserialize response: {}", e)))
251    }
252
253    async fn request(
254        &self,
255        method: String,
256        params: serde_json::Value,
257    ) -> Result<serde_json::Value> {
258        self.dispatch(&method, params).await
259    }
260}
261
262/// Context for a request, providing progress, cancellation, and client request support
263#[derive(Clone)]
264pub struct RequestContext {
265    /// The request ID
266    request_id: RequestId,
267    /// Progress token (if provided by client)
268    progress_token: Option<ProgressToken>,
269    /// Cancellation flag
270    cancelled: Arc<AtomicBool>,
271    /// Channel for sending notifications
272    notification_tx: Option<NotificationSender>,
273    /// Handle for sending requests to the client (for sampling, etc.)
274    client_requester: Option<ClientRequesterHandle>,
275    /// Extensions for passing data from router/middleware to handlers
276    extensions: Arc<Extensions>,
277    /// Minimum log level set by the client (shared with router for dynamic updates)
278    min_log_level: Option<Arc<RwLock<LogLevel>>>,
279}
280
/// Type-erased map holding at most one value per concrete type.
///
/// Router-level state and middleware-injected data travel to tool handlers
/// through this map (read via the `Extension<T>` extractor). Values are kept
/// behind `Arc`, so cloning the map or merging it into another one shares the
/// values instead of copying them.
#[derive(Clone, Default)]
pub struct Extensions {
    map: std::collections::HashMap<std::any::TypeId, Arc<dyn std::any::Any + Send + Sync>>,
}

impl Extensions {
    /// Create an empty extensions map.
    pub fn new() -> Self {
        Extensions {
            map: std::collections::HashMap::new(),
        }
    }

    /// Store `val`, keyed by its type.
    ///
    /// Any value of the same type that was stored earlier is replaced.
    pub fn insert<T: Send + Sync + 'static>(&mut self, val: T) {
        let key = std::any::TypeId::of::<T>();
        let erased: Arc<dyn std::any::Any + Send + Sync> = Arc::new(val);
        self.map.insert(key, erased);
    }

    /// Borrow the stored value of type `T`.
    ///
    /// Returns `None` when nothing of that type has been inserted.
    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
        let erased = self.map.get(&std::any::TypeId::of::<T>())?;
        erased.downcast_ref::<T>()
    }

    /// Report whether a value of type `T` is stored.
    pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
        let key = std::any::TypeId::of::<T>();
        self.map.contains_key(&key)
    }

    /// Copy every entry of `other` into this map.
    ///
    /// On a type collision the value from `other` wins. Values are shared
    /// (`Arc` clones), not duplicated.
    pub fn merge(&mut self, other: &Extensions) {
        self.map.extend(
            other
                .map
                .iter()
                .map(|(key, value)| (*key, Arc::clone(value))),
        );
    }

    /// Number of stored values.
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// `true` when no values are stored.
    pub fn is_empty(&self) -> bool {
        self.map.is_empty()
    }
}

// The stored values are type-erased and need not be `Debug`, so only the
// entry count is printed.
impl std::fmt::Debug for Extensions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let count = self.map.len();
        f.debug_struct("Extensions").field("len", &count).finish()
    }
}
344
345impl std::fmt::Debug for RequestContext {
346    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
347        f.debug_struct("RequestContext")
348            .field("request_id", &self.request_id)
349            .field("progress_token", &self.progress_token)
350            .field("cancelled", &self.cancelled.load(Ordering::Relaxed))
351            .finish()
352    }
353}
354
355impl RequestContext {
356    /// Create a new request context
357    pub fn new(request_id: RequestId) -> Self {
358        Self {
359            request_id,
360            progress_token: None,
361            cancelled: Arc::new(AtomicBool::new(false)),
362            notification_tx: None,
363            client_requester: None,
364            extensions: Arc::new(Extensions::new()),
365            min_log_level: None,
366        }
367    }
368
369    /// Set the progress token
370    pub fn with_progress_token(mut self, token: ProgressToken) -> Self {
371        self.progress_token = Some(token);
372        self
373    }
374
375    /// Set the notification sender
376    pub fn with_notification_sender(mut self, tx: NotificationSender) -> Self {
377        self.notification_tx = Some(tx);
378        self
379    }
380
381    /// Set the minimum log level for filtering outgoing log notifications
382    ///
383    /// This is shared with the router so that `logging/setLevel` updates
384    /// are immediately visible to all request contexts.
385    pub fn with_min_log_level(mut self, level: Arc<RwLock<LogLevel>>) -> Self {
386        self.min_log_level = Some(level);
387        self
388    }
389
390    /// Set the client requester for server-to-client requests
391    pub fn with_client_requester(mut self, requester: ClientRequesterHandle) -> Self {
392        self.client_requester = Some(requester);
393        self
394    }
395
396    /// Set the extensions for this request context.
397    ///
398    /// Extensions allow router-level state and middleware data to flow to handlers.
399    pub fn with_extensions(mut self, extensions: Arc<Extensions>) -> Self {
400        self.extensions = extensions;
401        self
402    }
403
404    /// Get a reference to a value from the extensions map.
405    ///
406    /// Returns `None` if no value of the given type has been inserted.
407    ///
408    /// # Example
409    ///
410    /// ```rust,ignore
411    /// #[derive(Clone)]
412    /// struct CurrentUser { id: String }
413    ///
414    /// // In a handler:
415    /// if let Some(user) = ctx.extension::<CurrentUser>() {
416    ///     println!("User: {}", user.id);
417    /// }
418    /// ```
419    pub fn extension<T: Send + Sync + 'static>(&self) -> Option<&T> {
420        self.extensions.get::<T>()
421    }
422
423    /// Get a mutable reference to the extensions.
424    ///
425    /// This allows middleware to insert data that handlers can access via
426    /// the `Extension<T>` extractor.
427    pub fn extensions_mut(&mut self) -> &mut Extensions {
428        Arc::make_mut(&mut self.extensions)
429    }
430
431    /// Get a reference to the extensions.
432    pub fn extensions(&self) -> &Extensions {
433        &self.extensions
434    }
435
436    /// Get the request ID
437    pub fn request_id(&self) -> &RequestId {
438        &self.request_id
439    }
440
441    /// Get the progress token (if any)
442    pub fn progress_token(&self) -> Option<&ProgressToken> {
443        self.progress_token.as_ref()
444    }
445
446    /// Check if the request has been cancelled
447    pub fn is_cancelled(&self) -> bool {
448        self.cancelled.load(Ordering::Relaxed)
449    }
450
451    /// Mark the request as cancelled
452    pub fn cancel(&self) {
453        self.cancelled.store(true, Ordering::Relaxed);
454    }
455
456    /// Get a cancellation token that can be shared
457    pub fn cancellation_token(&self) -> CancellationToken {
458        CancellationToken {
459            cancelled: self.cancelled.clone(),
460        }
461    }
462
463    /// Report progress to the client
464    ///
465    /// This is a no-op if no progress token was provided or no notification sender is configured.
466    pub async fn report_progress(&self, progress: f64, total: Option<f64>, message: Option<&str>) {
467        let Some(token) = &self.progress_token else {
468            return;
469        };
470        let Some(tx) = &self.notification_tx else {
471            return;
472        };
473
474        let params = ProgressParams {
475            progress_token: token.clone(),
476            progress,
477            total,
478            message: message.map(|s| s.to_string()),
479            meta: None,
480        };
481
482        // Best effort - don't block if channel is full
483        let _ = tx.try_send(ServerNotification::Progress(params));
484    }
485
486    /// Report progress synchronously (non-async version)
487    ///
488    /// This is a no-op if no progress token was provided or no notification sender is configured.
489    pub fn report_progress_sync(&self, progress: f64, total: Option<f64>, message: Option<&str>) {
490        let Some(token) = &self.progress_token else {
491            return;
492        };
493        let Some(tx) = &self.notification_tx else {
494            return;
495        };
496
497        let params = ProgressParams {
498            progress_token: token.clone(),
499            progress,
500            total,
501            message: message.map(|s| s.to_string()),
502            meta: None,
503        };
504
505        let _ = tx.try_send(ServerNotification::Progress(params));
506    }
507
508    /// Send a log message notification to the client
509    ///
510    /// This is a no-op if no notification sender is configured.
511    ///
512    /// # Example
513    ///
514    /// ```rust,ignore
515    /// use tower_mcp::protocol::{LoggingMessageParams, LogLevel};
516    ///
517    /// async fn my_tool(ctx: RequestContext) {
518    ///     ctx.send_log(
519    ///         LoggingMessageParams::new(LogLevel::Info, serde_json::json!("Processing..."))
520    ///             .with_logger("my-tool")
521    ///     );
522    /// }
523    /// ```
524    pub fn send_log(&self, params: LoggingMessageParams) {
525        let Some(tx) = &self.notification_tx else {
526            return;
527        };
528
529        // Filter by minimum log level set via logging/setLevel
530        // LogLevel derives Ord with Emergency < Alert < ... < Debug,
531        // so a message passes if its severity is at least the minimum
532        // (i.e., its ordinal is <= the minimum level's ordinal).
533        if let Some(min_level) = &self.min_log_level
534            && let Ok(min) = min_level.read()
535            && params.level > *min
536        {
537            return;
538        }
539
540        let _ = tx.try_send(ServerNotification::LogMessage(params));
541    }
542
543    /// Check if sampling is available
544    ///
545    /// Returns true if a client requester is configured and the transport
546    /// supports bidirectional communication.
547    pub fn can_sample(&self) -> bool {
548        self.client_requester.is_some()
549    }
550
551    /// Request an LLM completion from the client
552    ///
553    /// This sends a `sampling/createMessage` request to the client and waits
554    /// for the response. The client is expected to forward this to an LLM
555    /// and return the result.
556    ///
557    /// Returns an error if sampling is not available (no client requester configured).
558    ///
559    /// # Example
560    ///
561    /// ```rust,ignore
562    /// use tower_mcp::{CreateMessageParams, SamplingMessage};
563    ///
564    /// async fn my_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
565    ///     let params = CreateMessageParams::new(
566    ///         vec![SamplingMessage::user("Summarize: ...")],
567    ///         500,
568    ///     );
569    ///
570    ///     let result = ctx.sample(params).await?;
571    ///     Ok(CallToolResult::text(format!("{:?}", result.content)))
572    /// }
573    /// ```
574    pub async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult> {
575        let requester = self.client_requester.as_ref().ok_or_else(|| {
576            Error::Internal("Sampling not available: no client requester configured".to_string())
577        })?;
578
579        requester.sample(params).await
580    }
581
582    /// Check if elicitation is available
583    ///
584    /// Returns true if a client requester is configured and the transport
585    /// supports bidirectional communication. Note that this only checks if
586    /// the mechanism is available, not whether the client supports elicitation.
587    pub fn can_elicit(&self) -> bool {
588        self.client_requester.is_some()
589    }
590
591    /// Request user input via a form from the client
592    ///
593    /// This sends an `elicitation/create` request to the client with a form schema.
594    /// The client renders the form to the user and returns their response.
595    ///
596    /// Returns an error if elicitation is not available (no client requester configured).
597    ///
598    /// # Example
599    ///
600    /// ```rust,ignore
601    /// use tower_mcp::{ElicitFormParams, ElicitFormSchema, ElicitMode, ElicitAction};
602    ///
603    /// async fn my_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
604    ///     let params = ElicitFormParams {
605    ///         mode: Some(ElicitMode::Form),
606    ///         message: "Please enter your details".to_string(),
607    ///         requested_schema: ElicitFormSchema::new()
608    ///             .string_field("name", Some("Your name"), true),
609    ///         meta: None,
610    ///     };
611    ///
612    ///     let result = ctx.elicit_form(params).await?;
613    ///     match result.action {
614    ///         ElicitAction::Accept => {
615    ///             // Use result.content
616    ///             Ok(CallToolResult::text("Got your input!"))
617    ///         }
618    ///         _ => Ok(CallToolResult::text("User declined"))
619    ///     }
620    /// }
621    /// ```
622    pub async fn elicit_form(&self, params: ElicitFormParams) -> Result<ElicitResult> {
623        let requester = self.client_requester.as_ref().ok_or_else(|| {
624            Error::Internal("Elicitation not available: no client requester configured".to_string())
625        })?;
626
627        requester.elicit(ElicitRequestParams::Form(params)).await
628    }
629
630    /// Request user input via URL redirect from the client
631    ///
632    /// This sends an `elicitation/create` request to the client with a URL.
633    /// The client directs the user to the URL for out-of-band input collection.
634    /// The server receives the result via a callback notification.
635    ///
636    /// Returns an error if elicitation is not available (no client requester configured).
637    ///
638    /// # Example
639    ///
640    /// ```rust,ignore
641    /// use tower_mcp::{ElicitUrlParams, ElicitMode, ElicitAction};
642    ///
643    /// async fn my_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
644    ///     let params = ElicitUrlParams {
645    ///         mode: Some(ElicitMode::Url),
646    ///         elicitation_id: "unique-id-123".to_string(),
647    ///         message: "Please authorize via the link".to_string(),
648    ///         url: "https://example.com/auth?id=unique-id-123".to_string(),
649    ///         meta: None,
650    ///     };
651    ///
652    ///     let result = ctx.elicit_url(params).await?;
653    ///     match result.action {
654    ///         ElicitAction::Accept => Ok(CallToolResult::text("Authorization complete!")),
655    ///         _ => Ok(CallToolResult::text("Authorization cancelled"))
656    ///     }
657    /// }
658    /// ```
659    pub async fn elicit_url(&self, params: ElicitUrlParams) -> Result<ElicitResult> {
660        let requester = self.client_requester.as_ref().ok_or_else(|| {
661            Error::Internal("Elicitation not available: no client requester configured".to_string())
662        })?;
663
664        requester.elicit(ElicitRequestParams::Url(params)).await
665    }
666
667    /// Request simple confirmation from the user.
668    ///
669    /// This is a convenience method for simple yes/no confirmation dialogs.
670    /// It creates an elicitation form with a single boolean "confirm" field
671    /// and returns `true` if the user accepts, `false` otherwise.
672    ///
673    /// Returns an error if elicitation is not available (no client requester configured).
674    ///
675    /// # Example
676    ///
677    /// ```rust,ignore
678    /// use tower_mcp::{RequestContext, CallToolResult};
679    ///
680    /// async fn delete_item(ctx: RequestContext) -> Result<CallToolResult> {
681    ///     let confirmed = ctx.confirm("Are you sure you want to delete this item?").await?;
682    ///     if confirmed {
683    ///         // Perform deletion
684    ///         Ok(CallToolResult::text("Item deleted"))
685    ///     } else {
686    ///         Ok(CallToolResult::text("Deletion cancelled"))
687    ///     }
688    /// }
689    /// ```
690    pub async fn confirm(&self, message: impl Into<String>) -> Result<bool> {
691        use crate::protocol::{ElicitAction, ElicitFormParams, ElicitFormSchema, ElicitMode};
692
693        let params = ElicitFormParams {
694            mode: Some(ElicitMode::Form),
695            message: message.into(),
696            requested_schema: ElicitFormSchema::new().boolean_field_with_default(
697                "confirm",
698                Some("Confirm this action"),
699                true,
700                false,
701            ),
702            meta: None,
703        };
704
705        let result = self.elicit_form(params).await?;
706        Ok(result.action == ElicitAction::Accept)
707    }
708
709    /// List tasks tracked by the connected client (SEP-1686).
710    ///
711    /// Sends a `tasks/list` request to the client and returns the result.
712    /// Pass `Some(status)` to filter to a single status, or `None` for all
713    /// tasks. Pagination is exposed via [`ListTasksResult::next_cursor`];
714    /// use [`request_raw`](Self::request_raw) for cursor-driven calls.
715    ///
716    /// Returns an error if no client requester is configured or the client
717    /// does not advertise task support.
718    pub async fn list_tasks(&self, status: Option<TaskStatus>) -> Result<ListTasksResult> {
719        let params = ListTasksParams {
720            status,
721            cursor: None,
722            meta: None,
723        };
724        let value = self
725            .request_raw("tasks/list", serde_json::to_value(&params)?)
726            .await?;
727        serde_json::from_value(value)
728            .map_err(|e| Error::Internal(format!("Failed to deserialize tasks/list: {e}")))
729    }
730
731    /// Fetch metadata for a single task tracked by the client (SEP-1686).
732    ///
733    /// Sends a `tasks/get` request and returns the task object, including
734    /// the current status, timestamps, and TTL.
735    pub async fn get_task_info(&self, task_id: impl Into<String>) -> Result<TaskObject> {
736        let params = GetTaskInfoParams {
737            task_id: task_id.into(),
738            meta: None,
739        };
740        let value = self
741            .request_raw("tasks/get", serde_json::to_value(&params)?)
742            .await?;
743        serde_json::from_value(value)
744            .map_err(|e| Error::Internal(format!("Failed to deserialize tasks/get: {e}")))
745    }
746
747    /// Fetch the terminal result for a task tracked by the client (SEP-1686).
748    ///
749    /// Sends a `tasks/result` request. The client is expected to block until
750    /// the task reaches a terminal state and then return the underlying
751    /// `CallToolResult`. For long-running tasks, prefer polling with
752    /// [`get_task_info`](Self::get_task_info) and only call this once the
753    /// status is terminal.
754    pub async fn get_task_result(&self, task_id: impl Into<String>) -> Result<CallToolResult> {
755        let params = GetTaskResultParams {
756            task_id: task_id.into(),
757            meta: None,
758        };
759        let value = self
760            .request_raw("tasks/result", serde_json::to_value(&params)?)
761            .await?;
762        serde_json::from_value(value)
763            .map_err(|e| Error::Internal(format!("Failed to deserialize tasks/result: {e}")))
764    }
765
766    /// Cancel a task tracked by the client (SEP-1686).
767    ///
768    /// Sends a `tasks/cancel` request and returns the resulting task object,
769    /// which will reflect the cancelled status.
770    pub async fn cancel_task(
771        &self,
772        task_id: impl Into<String>,
773        reason: Option<String>,
774    ) -> Result<TaskObject> {
775        let params = CancelTaskParams {
776            task_id: task_id.into(),
777            reason,
778            meta: None,
779        };
780        let value = self
781            .request_raw("tasks/cancel", serde_json::to_value(&params)?)
782            .await?;
783        serde_json::from_value(value)
784            .map_err(|e| Error::Internal(format!("Failed to deserialize tasks/cancel: {e}")))
785    }
786
787    /// Send an arbitrary JSON-RPC request to the client.
788    ///
789    /// Escape hatch for methods not covered by the typed helpers (e.g. when
790    /// a `tasks/list` cursor needs to be passed). Most callers should prefer
791    /// the typed methods.
792    pub async fn request_raw(
793        &self,
794        method: &str,
795        params: serde_json::Value,
796    ) -> Result<serde_json::Value> {
797        let requester = self.client_requester.as_ref().ok_or_else(|| {
798            Error::Internal(
799                "Client request not available: no client requester configured".to_string(),
800            )
801        })?;
802        requester.request(method.to_string(), params).await
803    }
804}
805
/// Shareable handle for observing or requesting cancellation.
///
/// Every clone wraps the same flag, so a `cancel` through any clone is seen
/// by all of them. The same holds for the `RequestContext` the token was
/// obtained from.
#[derive(Clone, Debug)]
pub struct CancellationToken {
    cancelled: Arc<AtomicBool>,
}

impl CancellationToken {
    /// Return `true` once cancellation has been requested through any handle.
    pub fn is_cancelled(&self) -> bool {
        let flag: &AtomicBool = &self.cancelled;
        flag.load(Ordering::Relaxed)
    }

    /// Request cancellation; this flag is never reset once set.
    pub fn cancel(&self) {
        let flag: &AtomicBool = &self.cancelled;
        flag.store(true, Ordering::Relaxed);
    }
}
823
824/// Builder for creating request contexts
825#[derive(Default)]
826pub struct RequestContextBuilder {
827    request_id: Option<RequestId>,
828    progress_token: Option<ProgressToken>,
829    notification_tx: Option<NotificationSender>,
830    client_requester: Option<ClientRequesterHandle>,
831    min_log_level: Option<Arc<RwLock<LogLevel>>>,
832}
833
834impl RequestContextBuilder {
835    /// Create a new builder
836    pub fn new() -> Self {
837        Self::default()
838    }
839
840    /// Set the request ID
841    pub fn request_id(mut self, id: RequestId) -> Self {
842        self.request_id = Some(id);
843        self
844    }
845
846    /// Set the progress token
847    pub fn progress_token(mut self, token: ProgressToken) -> Self {
848        self.progress_token = Some(token);
849        self
850    }
851
852    /// Set the notification sender
853    pub fn notification_sender(mut self, tx: NotificationSender) -> Self {
854        self.notification_tx = Some(tx);
855        self
856    }
857
858    /// Set the client requester for server-to-client requests
859    pub fn client_requester(mut self, requester: ClientRequesterHandle) -> Self {
860        self.client_requester = Some(requester);
861        self
862    }
863
864    /// Set the minimum log level for filtering
865    pub fn min_log_level(mut self, level: Arc<RwLock<LogLevel>>) -> Self {
866        self.min_log_level = Some(level);
867        self
868    }
869
870    /// Build the request context
871    ///
872    /// Panics if request_id is not set.
873    pub fn build(self) -> RequestContext {
874        let mut ctx = RequestContext::new(self.request_id.expect("request_id is required"));
875        if let Some(token) = self.progress_token {
876            ctx = ctx.with_progress_token(token);
877        }
878        if let Some(tx) = self.notification_tx {
879            ctx = ctx.with_notification_sender(tx);
880        }
881        if let Some(requester) = self.client_requester {
882            ctx = ctx.with_client_requester(requester);
883        }
884        if let Some(level) = self.min_log_level {
885            ctx = ctx.with_min_log_level(level);
886        }
887        ctx
888    }
889}
890
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cancellation() {
        let ctx = RequestContext::new(RequestId::Number(1));
        assert!(!ctx.is_cancelled());

        let token = ctx.cancellation_token();
        assert!(!token.is_cancelled());

        ctx.cancel();
        assert!(ctx.is_cancelled());
        assert!(token.is_cancelled());
    }

    #[tokio::test]
    async fn test_progress_reporting() {
        let (tx, mut rx) = notification_channel(10);

        let ctx = RequestContext::new(RequestId::Number(1))
            .with_progress_token(ProgressToken::Number(42))
            .with_notification_sender(tx);

        ctx.report_progress(50.0, Some(100.0), Some("Halfway"))
            .await;

        let notification = rx.recv().await.unwrap();
        match notification {
            ServerNotification::Progress(params) => {
                assert_eq!(params.progress, 50.0);
                assert_eq!(params.total, Some(100.0));
                assert_eq!(params.message.as_deref(), Some("Halfway"));
            }
            _ => panic!("Expected Progress notification"),
        }
    }

    #[tokio::test]
    async fn test_progress_no_token() {
        let (tx, mut rx) = notification_channel(10);

        // No progress token - should be a no-op
        let ctx = RequestContext::new(RequestId::Number(1)).with_notification_sender(tx);

        ctx.report_progress(50.0, Some(100.0), None).await;

        // Channel should be empty
        assert!(rx.try_recv().is_err());
    }

    #[test]
    fn test_builder() {
        let (tx, _rx) = notification_channel(10);

        let ctx = RequestContextBuilder::new()
            .request_id(RequestId::String("req-1".to_string()))
            .progress_token(ProgressToken::String("prog-1".to_string()))
            .notification_sender(tx)
            .build();

        assert_eq!(ctx.request_id(), &RequestId::String("req-1".to_string()));
        assert!(ctx.progress_token().is_some());
    }

    #[test]
    fn test_can_sample_without_requester() {
        let ctx = RequestContext::new(RequestId::Number(1));
        assert!(!ctx.can_sample());
    }

    #[test]
    fn test_can_sample_with_requester() {
        let (request_tx, _rx) = outgoing_request_channel(10);
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));

        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
        assert!(ctx.can_sample());
    }

    #[tokio::test]
    async fn test_sample_without_requester_fails() {
        use crate::protocol::{CreateMessageParams, SamplingMessage};

        let ctx = RequestContext::new(RequestId::Number(1));
        let params = CreateMessageParams::new(vec![SamplingMessage::user("test")], 100);

        let result = ctx.sample(params).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Sampling not available")
        );
    }

    #[test]
    fn test_builder_with_client_requester() {
        let (request_tx, _rx) = outgoing_request_channel(10);
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));

        let ctx = RequestContextBuilder::new()
            .request_id(RequestId::Number(1))
            .client_requester(requester)
            .build();

        assert!(ctx.can_sample());
    }

    #[tokio::test]
    async fn test_builder_with_min_log_level() {
        let (tx, mut rx) = notification_channel(10);

        // The builder must forward the min log level to the built context
        let ctx = RequestContextBuilder::new()
            .request_id(RequestId::Number(1))
            .notification_sender(tx)
            .min_log_level(Arc::new(RwLock::new(LogLevel::Warning)))
            .build();

        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Info,
            serde_json::Value::Null,
        ));
        assert!(
            rx.try_recv().is_err(),
            "Info should be filtered by a builder-supplied Warning level"
        );

        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Error,
            serde_json::Value::Null,
        ));
        assert!(
            rx.try_recv().is_ok(),
            "Error should pass a builder-supplied Warning level"
        );
    }

    #[test]
    fn test_can_elicit_without_requester() {
        let ctx = RequestContext::new(RequestId::Number(1));
        assert!(!ctx.can_elicit());
    }

    #[test]
    fn test_can_elicit_with_requester() {
        let (request_tx, _rx) = outgoing_request_channel(10);
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));

        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
        assert!(ctx.can_elicit());
    }

    #[tokio::test]
    async fn test_elicit_form_without_requester_fails() {
        use crate::protocol::{ElicitFormSchema, ElicitMode};

        let ctx = RequestContext::new(RequestId::Number(1));
        let params = ElicitFormParams {
            mode: Some(ElicitMode::Form),
            message: "Enter details".to_string(),
            requested_schema: ElicitFormSchema::new().string_field("name", None, true),
            meta: None,
        };

        let result = ctx.elicit_form(params).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Elicitation not available")
        );
    }

    #[tokio::test]
    async fn test_elicit_url_without_requester_fails() {
        use crate::protocol::ElicitMode;

        let ctx = RequestContext::new(RequestId::Number(1));
        let params = ElicitUrlParams {
            mode: Some(ElicitMode::Url),
            elicitation_id: "test-123".to_string(),
            message: "Please authorize".to_string(),
            url: "https://example.com/auth".to_string(),
            meta: None,
        };

        let result = ctx.elicit_url(params).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Elicitation not available")
        );
    }

    #[tokio::test]
    async fn test_confirm_without_requester_fails() {
        let ctx = RequestContext::new(RequestId::Number(1));

        let result = ctx.confirm("Are you sure?").await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Elicitation not available")
        );
    }

    #[tokio::test]
    async fn test_send_log_filtered_by_level() {
        let (tx, mut rx) = notification_channel(10);
        let min_level = Arc::new(RwLock::new(LogLevel::Warning));

        let ctx = RequestContext::new(RequestId::Number(1))
            .with_notification_sender(tx)
            .with_min_log_level(min_level.clone());

        // Error is more severe than Warning — should pass through
        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Error,
            serde_json::Value::Null,
        ));
        let msg = rx.try_recv();
        assert!(msg.is_ok(), "Error should pass through Warning filter");

        // Warning is equal to min level — should pass through
        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Warning,
            serde_json::Value::Null,
        ));
        let msg = rx.try_recv();
        assert!(msg.is_ok(), "Warning should pass through Warning filter");

        // Info is less severe than Warning — should be filtered
        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Info,
            serde_json::Value::Null,
        ));
        let msg = rx.try_recv();
        assert!(msg.is_err(), "Info should be filtered by Warning filter");

        // Debug is less severe than Warning — should be filtered
        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Debug,
            serde_json::Value::Null,
        ));
        let msg = rx.try_recv();
        assert!(msg.is_err(), "Debug should be filtered by Warning filter");
    }

    #[tokio::test]
    async fn test_send_log_level_updates_dynamically() {
        let (tx, mut rx) = notification_channel(10);
        let min_level = Arc::new(RwLock::new(LogLevel::Error));

        let ctx = RequestContext::new(RequestId::Number(1))
            .with_notification_sender(tx)
            .with_min_log_level(min_level.clone());

        // Info should be filtered at Error level
        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Info,
            serde_json::Value::Null,
        ));
        assert!(
            rx.try_recv().is_err(),
            "Info should be filtered at Error level"
        );

        // Dynamically update to Debug (most permissive)
        *min_level.write().unwrap() = LogLevel::Debug;

        // Now Info should pass through
        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Info,
            serde_json::Value::Null,
        ));
        assert!(
            rx.try_recv().is_ok(),
            "Info should pass through after level changed to Debug"
        );
    }

    #[tokio::test]
    async fn test_send_log_no_min_level_sends_all() {
        let (tx, mut rx) = notification_channel(10);

        // No min_log_level set — all messages should pass through
        let ctx = RequestContext::new(RequestId::Number(1)).with_notification_sender(tx);

        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Debug,
            serde_json::Value::Null,
        ));
        assert!(
            rx.try_recv().is_ok(),
            "Debug should pass when no min level is set"
        );
    }

    /// Minimal JSON task object with the given ID and status, used as a
    /// canned client response in the task tests below.
    fn make_task_object(id: &str, status: TaskStatus) -> serde_json::Value {
        serde_json::json!({
            "taskId": id,
            "status": status,
            "createdAt": "2026-04-24T00:00:00Z",
            "lastUpdatedAt": "2026-04-24T00:00:00Z",
            "ttl": null
        })
    }

    /// Spawn a fake client that answers every outgoing request with
    /// `responder(method, params)`.
    fn spawn_mock_client(
        mut rx: OutgoingRequestReceiver,
        responder: impl Fn(&str, serde_json::Value) -> serde_json::Value + Send + 'static,
    ) {
        tokio::spawn(async move {
            while let Some(req) = rx.recv().await {
                let response = responder(&req.method, req.params);
                // Best-effort: ignore send failures (e.g. the requester side is gone)
                let _ = req.response_tx.send(Ok(response));
            }
        });
    }

    #[tokio::test]
    async fn test_get_task_info_round_trips() {
        let (tx, rx) = outgoing_request_channel(10);
        spawn_mock_client(rx, |method, params| {
            assert_eq!(method, "tasks/get");
            let task_id = params["taskId"].as_str().unwrap().to_string();
            make_task_object(&task_id, TaskStatus::Working)
        });
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(tx));
        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);

        let info = ctx.get_task_info("task-123").await.unwrap();
        assert_eq!(info.task_id, "task-123");
        assert!(matches!(info.status, TaskStatus::Working));
    }

    #[tokio::test]
    async fn test_list_tasks_round_trips() {
        let (tx, rx) = outgoing_request_channel(10);
        spawn_mock_client(rx, |method, params| {
            assert_eq!(method, "tasks/list");
            // Status filter should be forwarded
            assert_eq!(params["status"], serde_json::json!("working"));
            serde_json::json!({
                "tasks": [
                    make_task_object("task-1", TaskStatus::Working),
                    make_task_object("task-2", TaskStatus::Working),
                ]
            })
        });
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(tx));
        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);

        let result = ctx.list_tasks(Some(TaskStatus::Working)).await.unwrap();
        assert_eq!(result.tasks.len(), 2);
        assert_eq!(result.tasks[0].task_id, "task-1");
    }

    #[tokio::test]
    async fn test_cancel_task_forwards_reason() {
        let (tx, rx) = outgoing_request_channel(10);
        spawn_mock_client(rx, |method, params| {
            assert_eq!(method, "tasks/cancel");
            assert_eq!(params["reason"], serde_json::json!("user requested"));
            let task_id = params["taskId"].as_str().unwrap().to_string();
            make_task_object(&task_id, TaskStatus::Cancelled)
        });
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(tx));
        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);

        let task = ctx
            .cancel_task("task-99", Some("user requested".into()))
            .await
            .unwrap();
        assert_eq!(task.task_id, "task-99");
        assert!(matches!(task.status, TaskStatus::Cancelled));
    }

    #[tokio::test]
    async fn test_request_raw_round_trips() {
        let (tx, rx) = outgoing_request_channel(10);
        spawn_mock_client(rx, |method, params| {
            assert_eq!(method, "custom/echo");
            // Echo the params so the caller can verify they were forwarded untouched
            params
        });
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(tx));
        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);

        let value = ctx
            .request_raw("custom/echo", serde_json::json!({ "cursor": "abc" }))
            .await
            .unwrap();
        assert_eq!(value, serde_json::json!({ "cursor": "abc" }));
    }

    #[tokio::test]
    async fn test_request_raw_without_requester_fails() {
        let ctx = RequestContext::new(RequestId::Number(1));
        let result = ctx
            .request_raw("custom/method", serde_json::Value::Null)
            .await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Client request not available")
        );
    }

    #[tokio::test]
    async fn test_get_task_info_without_requester_fails() {
        let ctx = RequestContext::new(RequestId::Number(1));
        let result = ctx.get_task_info("task-1").await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Client request not available")
        );
    }

    #[tokio::test]
    async fn test_default_request_impl_errors() {
        // A custom requester that only implements sample/elicit (not request)
        // should reject task helpers.
        struct OnlySampleAndElicit;

        #[async_trait]
        impl ClientRequester for OnlySampleAndElicit {
            async fn sample(&self, _: CreateMessageParams) -> Result<CreateMessageResult> {
                unreachable!()
            }
            async fn elicit(&self, _: ElicitRequestParams) -> Result<ElicitResult> {
                unreachable!()
            }
        }

        let requester: ClientRequesterHandle = Arc::new(OnlySampleAndElicit);
        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);

        let err = ctx.get_task_info("x").await.unwrap_err();
        assert!(err.to_string().contains("does not support arbitrary"));
    }
}