Skip to main content

tower_mcp/
context.rs

1//! Request context for MCP handlers
2//!
3//! Provides progress reporting, cancellation support, and client request capabilities
4//! for long-running operations.
5//!
6//! # Example
7//!
8//! ```rust,ignore
9//! use tower_mcp::context::RequestContext;
10//!
11//! async fn long_running_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
12//!     for i in 0..100 {
13//!         // Check if cancelled
14//!         if ctx.is_cancelled() {
15//!             return Err(Error::tool("Operation cancelled"));
16//!         }
17//!
18//!         // Report progress
19//!         ctx.report_progress(i as f64, Some(100.0), Some("Processing...")).await;
20//!
21//!         do_work(i).await;
22//!     }
23//!     Ok(CallToolResult::text("Done!"))
24//! }
25//! ```
26//!
27//! # Sampling (LLM requests to client)
28//!
29//! ```rust,ignore
30//! use tower_mcp::context::RequestContext;
31//! use tower_mcp::{CreateMessageParams, SamplingMessage};
32//!
33//! async fn ai_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
34//!     // Request LLM completion from the client
35//!     let params = CreateMessageParams::new(
36//!         vec![SamplingMessage::user("Summarize this text...")],
37//!         500,
38//!     );
39//!
40//!     let result = ctx.sample(params).await?;
41//!     Ok(CallToolResult::text(format!("Summary: {:?}", result.content)))
42//! }
43//! ```
44//!
45//! # Elicitation (requesting user input)
46//!
47//! ```rust,ignore
48//! use tower_mcp::context::RequestContext;
49//! use tower_mcp::{ElicitFormParams, ElicitFormSchema, ElicitMode, ElicitAction};
50//!
51//! async fn interactive_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
52//!     // Request user input via form
53//!     let params = ElicitFormParams {
54//!         mode: Some(ElicitMode::Form),
55//!         message: "Please provide additional details".to_string(),
56//!         requested_schema: ElicitFormSchema::new()
57//!             .string_field("name", Some("Your name"), true)
58//!             .number_field("age", Some("Your age"), false),
59//!         meta: None,
60//!     };
61//!
62//!     let result = ctx.elicit_form(params).await?;
63//!     if result.action == ElicitAction::Accept {
64//!         // Use the form data
65//!         Ok(CallToolResult::text(format!("Got: {:?}", result.content)))
66//!     } else {
67//!         Ok(CallToolResult::text("User declined"))
68//!     }
69//! }
70//! ```
71//!
72//! # Stateless mode: per-request metadata (`stateless` feature)
73//!
74//! With the 2026-07-28 protocol, clients do not run an initialize handshake.
75//! Instead, every request carries the client's protocol version, identity, and
76//! capabilities in a `_meta` object. The HTTP transport extracts these fields
77//! and stashes them as a [`StatelessRequestMeta`](crate::stateless::StatelessRequestMeta)
78//! extension on the [`RequestContext`], accessible via
79//! [`ctx.per_request_meta()`](RequestContext::per_request_meta).
80//!
81//! `per_request_meta()` returns `Some` when:
82//! - The `stateless` feature is compiled in, AND
83//! - The request was dispatched by the HTTP transport, AND
84//! - The request's `_meta` contained at least one recognized field.
85//!
86//! It returns `None` for stdio/WebSocket transports, for 2025-11-25 session-based
87//! requests, and when the request carried no `_meta`.
88//!
89//! The [`StatelessRequestMeta`](crate::stateless::StatelessRequestMeta) struct
90//! provides:
91//!
92//! - `protocol_version` -- the `io.modelcontextprotocol/protocolVersion` field
93//! - `client_info` -- the `io.modelcontextprotocol/clientInfo` field (name, version)
94//! - `client_capabilities` -- the `io.modelcontextprotocol/clientCapabilities` field
95//! - `log_level` -- optional per-request log level override
96//! - `progress_token` -- optional progress token for progress notifications
97//!
98//! ```rust,ignore
99//! // Requires feature = ["stateless"]
100//! use tower_mcp::context::RequestContext;
101//!
102//! async fn my_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
103//!     if let Some(meta) = ctx.per_request_meta() {
104//!         // Available for 2026-07-28+ clients on the HTTP transport
105//!         if let Some(ref info) = meta.client_info {
106//!             tracing::info!(client = %info.name, version = %info.version, "request from");
107//!         }
108//!         if let Some(ref version) = meta.protocol_version {
109//!             tracing::debug!(protocol_version = %version);
110//!         }
111//!     }
112//!     Ok(CallToolResult::text("ok"))
113//! }
114//! ```
115
116use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
117use std::sync::{Arc, RwLock};
118
119use async_trait::async_trait;
120use tokio::sync::mpsc;
121
122use crate::error::{Error, Result};
123use crate::protocol::{
124    CallToolResult, CancelTaskParams, CreateMessageParams, CreateMessageResult, ElicitFormParams,
125    ElicitRequestParams, ElicitResult, ElicitUrlParams, GetTaskInfoParams, GetTaskResultParams,
126    ListTasksParams, ListTasksResult, LogLevel, LoggingMessageParams, ProgressParams,
127    ProgressToken, RequestId, TaskObject, TaskStatus,
128};
129
130/// A notification to be sent to the client
131#[derive(Debug, Clone)]
132#[non_exhaustive]
133pub enum ServerNotification {
134    /// Progress update for a request
135    Progress(ProgressParams),
136    /// Log message notification
137    LogMessage(LoggingMessageParams),
138    /// A subscribed resource has been updated
139    ResourceUpdated {
140        /// The URI of the updated resource
141        uri: String,
142    },
143    /// The list of available resources has changed
144    ResourcesListChanged,
145    /// The list of available tools has changed
146    ToolsListChanged,
147    /// The list of available prompts has changed
148    PromptsListChanged,
149    /// Task status has changed
150    TaskStatusChanged(crate::protocol::TaskStatusParams),
151}
152
153/// Sender for server notifications
154pub type NotificationSender = mpsc::Sender<ServerNotification>;
155
156/// Receiver for server notifications
157pub type NotificationReceiver = mpsc::Receiver<ServerNotification>;
158
159/// Create a new notification channel
160pub fn notification_channel(buffer: usize) -> (NotificationSender, NotificationReceiver) {
161    mpsc::channel(buffer)
162}
163
164// =============================================================================
165// Client Requests (Server -> Client)
166// =============================================================================
167
168/// Trait for sending requests from server to client
169///
170/// This enables bidirectional communication where the server can request
171/// actions from the client, such as sampling (LLM requests), elicitation
172/// (user input requests), and task polling (per SEP-1686).
173#[async_trait]
174pub trait ClientRequester: Send + Sync {
175    /// Send a sampling request to the client
176    ///
177    /// Returns the LLM completion result from the client.
178    async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult>;
179
180    /// Send an elicitation request to the client
181    ///
182    /// This requests user input from the client. The request can be either
183    /// form-based (structured input) or URL-based (redirect to external URL).
184    ///
185    /// Returns the elicitation result with the user's action and any submitted data.
186    async fn elicit(&self, params: ElicitRequestParams) -> Result<ElicitResult>;
187
188    /// Send a generic JSON-RPC request to the client.
189    ///
190    /// Used by typed helpers ([`RequestContext::get_task_info`] etc.) to
191    /// dispatch arbitrary request methods. The default implementation returns
192    /// an error so existing custom implementations of this trait keep
193    /// compiling; they only need to override this if they want to support
194    /// methods beyond `sample` and `elicit`.
195    async fn request(
196        &self,
197        method: String,
198        params: serde_json::Value,
199    ) -> Result<serde_json::Value> {
200        let _ = (method, params);
201        Err(Error::Internal(
202            "ClientRequester does not support arbitrary requests".to_string(),
203        ))
204    }
205}
206
207/// A clonable handle to a client requester
208pub type ClientRequesterHandle = Arc<dyn ClientRequester>;
209
210/// Outgoing request to be sent to the client
211#[derive(Debug)]
212pub struct OutgoingRequest {
213    /// The JSON-RPC request ID
214    pub id: RequestId,
215    /// The method name
216    pub method: String,
217    /// The request parameters as JSON
218    pub params: serde_json::Value,
219    /// Channel to send the response back
220    pub response_tx: tokio::sync::oneshot::Sender<Result<serde_json::Value>>,
221}
222
223/// Sender for outgoing requests to the client
224pub type OutgoingRequestSender = mpsc::Sender<OutgoingRequest>;
225
226/// Receiver for outgoing requests (used by transport)
227pub type OutgoingRequestReceiver = mpsc::Receiver<OutgoingRequest>;
228
229/// Create a new outgoing request channel
230pub fn outgoing_request_channel(buffer: usize) -> (OutgoingRequestSender, OutgoingRequestReceiver) {
231    mpsc::channel(buffer)
232}
233
234/// A client requester implementation that sends requests through a channel
235#[derive(Clone)]
236pub struct ChannelClientRequester {
237    request_tx: OutgoingRequestSender,
238    next_id: Arc<AtomicI64>,
239}
240
241impl ChannelClientRequester {
242    /// Create a new channel-based client requester
243    pub fn new(request_tx: OutgoingRequestSender) -> Self {
244        Self {
245            request_tx,
246            next_id: Arc::new(AtomicI64::new(1)),
247        }
248    }
249
250    fn next_request_id(&self) -> RequestId {
251        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
252        RequestId::Number(id)
253    }
254}
255
256impl ChannelClientRequester {
257    async fn dispatch(&self, method: &str, params: serde_json::Value) -> Result<serde_json::Value> {
258        let id = self.next_request_id();
259        let (response_tx, response_rx) = tokio::sync::oneshot::channel();
260
261        let request = OutgoingRequest {
262            id,
263            method: method.to_string(),
264            params,
265            response_tx,
266        };
267
268        self.request_tx
269            .send(request)
270            .await
271            .map_err(|_| Error::Internal("Failed to send request: channel closed".to_string()))?;
272
273        response_rx.await.map_err(|_| {
274            Error::Internal("Failed to receive response: channel closed".to_string())
275        })?
276    }
277}
278
279#[async_trait]
280impl ClientRequester for ChannelClientRequester {
281    async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult> {
282        let params_json = serde_json::to_value(&params)
283            .map_err(|e| Error::Internal(format!("Failed to serialize params: {}", e)))?;
284        let response = self.dispatch("sampling/createMessage", params_json).await?;
285        serde_json::from_value(response)
286            .map_err(|e| Error::Internal(format!("Failed to deserialize response: {}", e)))
287    }
288
289    async fn elicit(&self, params: ElicitRequestParams) -> Result<ElicitResult> {
290        let params_json = serde_json::to_value(&params)
291            .map_err(|e| Error::Internal(format!("Failed to serialize params: {}", e)))?;
292        let response = self.dispatch("elicitation/create", params_json).await?;
293        serde_json::from_value(response)
294            .map_err(|e| Error::Internal(format!("Failed to deserialize response: {}", e)))
295    }
296
297    async fn request(
298        &self,
299        method: String,
300        params: serde_json::Value,
301    ) -> Result<serde_json::Value> {
302        self.dispatch(&method, params).await
303    }
304}
305
306/// Context for a request, providing progress, cancellation, and client request support
307#[derive(Clone)]
308pub struct RequestContext {
309    /// The request ID
310    request_id: RequestId,
311    /// Progress token (if provided by client)
312    progress_token: Option<ProgressToken>,
313    /// Cancellation flag
314    cancelled: Arc<AtomicBool>,
315    /// Channel for sending notifications
316    notification_tx: Option<NotificationSender>,
317    /// Handle for sending requests to the client (for sampling, etc.)
318    client_requester: Option<ClientRequesterHandle>,
319    /// Extensions for passing data from router/middleware to handlers
320    extensions: Arc<Extensions>,
321    /// Minimum log level set by the client (shared with router for dynamic updates)
322    min_log_level: Option<Arc<RwLock<LogLevel>>>,
323}
324
325/// Type-erased extensions map for passing data to handlers.
326///
327/// Extensions allow router-level state and middleware-injected data to flow
328/// to tool handlers via the `Extension<T>` extractor.
329#[derive(Clone, Default)]
330pub struct Extensions {
331    map: std::collections::HashMap<std::any::TypeId, Arc<dyn std::any::Any + Send + Sync>>,
332}
333
334impl Extensions {
335    /// Create an empty extensions map.
336    pub fn new() -> Self {
337        Self::default()
338    }
339
340    /// Insert a value into the extensions map.
341    ///
342    /// If a value of the same type already exists, it is replaced.
343    pub fn insert<T: Send + Sync + 'static>(&mut self, val: T) {
344        self.map.insert(std::any::TypeId::of::<T>(), Arc::new(val));
345    }
346
347    /// Get a reference to a value in the extensions map.
348    ///
349    /// Returns `None` if no value of the given type has been inserted.
350    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
351        self.map
352            .get(&std::any::TypeId::of::<T>())
353            .and_then(|val| val.downcast_ref::<T>())
354    }
355
356    /// Check if the extensions map contains a value of the given type.
357    pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
358        self.map.contains_key(&std::any::TypeId::of::<T>())
359    }
360
361    /// Merge another extensions map into this one.
362    ///
363    /// Values from `other` will overwrite existing values of the same type.
364    pub fn merge(&mut self, other: &Extensions) {
365        for (k, v) in &other.map {
366            self.map.insert(*k, v.clone());
367        }
368    }
369
370    /// Returns the number of entries in the extensions map.
371    pub fn len(&self) -> usize {
372        self.map.len()
373    }
374
375    /// Returns `true` if the extensions map contains no entries.
376    pub fn is_empty(&self) -> bool {
377        self.map.is_empty()
378    }
379}
380
381impl std::fmt::Debug for Extensions {
382    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
383        f.debug_struct("Extensions")
384            .field("len", &self.map.len())
385            .finish()
386    }
387}
388
389impl std::fmt::Debug for RequestContext {
390    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
391        f.debug_struct("RequestContext")
392            .field("request_id", &self.request_id)
393            .field("progress_token", &self.progress_token)
394            .field("cancelled", &self.cancelled.load(Ordering::Relaxed))
395            .finish()
396    }
397}
398
399impl RequestContext {
400    /// Create a new request context
401    pub fn new(request_id: RequestId) -> Self {
402        Self {
403            request_id,
404            progress_token: None,
405            cancelled: Arc::new(AtomicBool::new(false)),
406            notification_tx: None,
407            client_requester: None,
408            extensions: Arc::new(Extensions::new()),
409            min_log_level: None,
410        }
411    }
412
413    /// Set the progress token
414    pub fn with_progress_token(mut self, token: ProgressToken) -> Self {
415        self.progress_token = Some(token);
416        self
417    }
418
419    /// Set the notification sender
420    pub fn with_notification_sender(mut self, tx: NotificationSender) -> Self {
421        self.notification_tx = Some(tx);
422        self
423    }
424
425    /// Set the minimum log level for filtering outgoing log notifications
426    ///
427    /// This is shared with the router so that `logging/setLevel` updates
428    /// are immediately visible to all request contexts.
429    pub fn with_min_log_level(mut self, level: Arc<RwLock<LogLevel>>) -> Self {
430        self.min_log_level = Some(level);
431        self
432    }
433
434    /// Set the client requester for server-to-client requests
435    pub fn with_client_requester(mut self, requester: ClientRequesterHandle) -> Self {
436        self.client_requester = Some(requester);
437        self
438    }
439
440    /// Set the extensions for this request context.
441    ///
442    /// Extensions allow router-level state and middleware data to flow to handlers.
443    pub fn with_extensions(mut self, extensions: Arc<Extensions>) -> Self {
444        self.extensions = extensions;
445        self
446    }
447
448    /// Get a reference to a value from the extensions map.
449    ///
450    /// Returns `None` if no value of the given type has been inserted.
451    ///
452    /// # Example
453    ///
454    /// ```rust,ignore
455    /// #[derive(Clone)]
456    /// struct CurrentUser { id: String }
457    ///
458    /// // In a handler:
459    /// if let Some(user) = ctx.extension::<CurrentUser>() {
460    ///     println!("User: {}", user.id);
461    /// }
462    /// ```
463    pub fn extension<T: Send + Sync + 'static>(&self) -> Option<&T> {
464        self.extensions.get::<T>()
465    }
466
467    /// Get a mutable reference to the extensions.
468    ///
469    /// This allows middleware to insert data that handlers can access via
470    /// the `Extension<T>` extractor.
471    pub fn extensions_mut(&mut self) -> &mut Extensions {
472        Arc::make_mut(&mut self.extensions)
473    }
474
475    /// Get a reference to the extensions.
476    pub fn extensions(&self) -> &Extensions {
477        &self.extensions
478    }
479
480    /// SEP-2575 per-request `_meta` (protocol version, client info, client
481    /// capabilities, log level) if the transport extracted it.
482    ///
483    /// Returns `Some` for 2026-07-28+ clients on the HTTP transport when the
484    /// request carried a `_meta` object with recognized fields. Returns `None`
485    /// when:
486    /// - The request had no `_meta` field, or
487    /// - The transport does not stash per-request metadata (only the HTTP
488    ///   transport currently does this), or
489    /// - The `stateless` feature is not compiled in.
490    ///
491    /// # Example
492    ///
493    /// ```rust,ignore
494    /// async fn my_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
495    ///     if let Some(meta) = ctx.per_request_meta() {
496    ///         // protocol_version, client_info, client_capabilities are all Option<_>
497    ///         if let Some(ref version) = meta.protocol_version {
498    ///             tracing::debug!(protocol_version = %version);
499    ///         }
500    ///         if let Some(ref info) = meta.client_info {
501    ///             tracing::info!(client = %info.name, version = %info.version);
502    ///         }
503    ///     }
504    ///     Ok(CallToolResult::text("ok"))
505    /// }
506    /// ```
507    #[cfg(feature = "stateless")]
508    pub fn per_request_meta(&self) -> Option<&crate::stateless::StatelessRequestMeta> {
509        self.extension::<crate::stateless::StatelessRequestMeta>()
510    }
511
512    /// Get the request ID
513    pub fn request_id(&self) -> &RequestId {
514        &self.request_id
515    }
516
517    /// Get the progress token (if any)
518    pub fn progress_token(&self) -> Option<&ProgressToken> {
519        self.progress_token.as_ref()
520    }
521
522    /// Check if the request has been cancelled
523    pub fn is_cancelled(&self) -> bool {
524        self.cancelled.load(Ordering::Relaxed)
525    }
526
527    /// Mark the request as cancelled
528    pub fn cancel(&self) {
529        self.cancelled.store(true, Ordering::Relaxed);
530    }
531
532    /// Get a cancellation token that can be shared
533    pub fn cancellation_token(&self) -> CancellationToken {
534        CancellationToken {
535            cancelled: self.cancelled.clone(),
536        }
537    }
538
539    /// Report progress to the client
540    ///
541    /// This is a no-op if no progress token was provided or no notification sender is configured.
542    pub async fn report_progress(&self, progress: f64, total: Option<f64>, message: Option<&str>) {
543        let Some(token) = &self.progress_token else {
544            return;
545        };
546        let Some(tx) = &self.notification_tx else {
547            return;
548        };
549
550        let params = ProgressParams {
551            progress_token: token.clone(),
552            progress,
553            total,
554            message: message.map(|s| s.to_string()),
555            meta: None,
556        };
557
558        // Best effort - don't block if channel is full
559        let _ = tx.try_send(ServerNotification::Progress(params));
560    }
561
562    /// Report progress synchronously (non-async version)
563    ///
564    /// This is a no-op if no progress token was provided or no notification sender is configured.
565    pub fn report_progress_sync(&self, progress: f64, total: Option<f64>, message: Option<&str>) {
566        let Some(token) = &self.progress_token else {
567            return;
568        };
569        let Some(tx) = &self.notification_tx else {
570            return;
571        };
572
573        let params = ProgressParams {
574            progress_token: token.clone(),
575            progress,
576            total,
577            message: message.map(|s| s.to_string()),
578            meta: None,
579        };
580
581        let _ = tx.try_send(ServerNotification::Progress(params));
582    }
583
584    /// Send a log message notification to the client
585    ///
586    /// This is a no-op if no notification sender is configured.
587    ///
588    /// # Example
589    ///
590    /// ```rust,ignore
591    /// use tower_mcp::protocol::{LoggingMessageParams, LogLevel};
592    ///
593    /// async fn my_tool(ctx: RequestContext) {
594    ///     ctx.send_log(
595    ///         LoggingMessageParams::new(LogLevel::Info, serde_json::json!("Processing..."))
596    ///             .with_logger("my-tool")
597    ///     );
598    /// }
599    /// ```
600    pub fn send_log(&self, params: LoggingMessageParams) {
601        let Some(tx) = &self.notification_tx else {
602            return;
603        };
604
605        // Filter by minimum log level set via logging/setLevel
606        // LogLevel derives Ord with Emergency < Alert < ... < Debug,
607        // so a message passes if its severity is at least the minimum
608        // (i.e., its ordinal is <= the minimum level's ordinal).
609        if let Some(min_level) = &self.min_log_level
610            && let Ok(min) = min_level.read()
611            && params.level > *min
612        {
613            return;
614        }
615
616        let _ = tx.try_send(ServerNotification::LogMessage(params));
617    }
618
619    /// Check if sampling is available
620    ///
621    /// Returns true if a client requester is configured and the transport
622    /// supports bidirectional communication.
623    pub fn can_sample(&self) -> bool {
624        self.client_requester.is_some()
625    }
626
627    /// Request an LLM completion from the client
628    ///
629    /// This sends a `sampling/createMessage` request to the client and waits
630    /// for the response. The client is expected to forward this to an LLM
631    /// and return the result.
632    ///
633    /// Returns an error if sampling is not available (no client requester configured).
634    ///
635    /// # Example
636    ///
637    /// ```rust,ignore
638    /// use tower_mcp::{CreateMessageParams, SamplingMessage};
639    ///
640    /// async fn my_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
641    ///     let params = CreateMessageParams::new(
642    ///         vec![SamplingMessage::user("Summarize: ...")],
643    ///         500,
644    ///     );
645    ///
646    ///     let result = ctx.sample(params).await?;
647    ///     Ok(CallToolResult::text(format!("{:?}", result.content)))
648    /// }
649    /// ```
650    pub async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult> {
651        let requester = self.client_requester.as_ref().ok_or_else(|| {
652            Error::Internal("Sampling not available: no client requester configured".to_string())
653        })?;
654
655        requester.sample(params).await
656    }
657
658    /// Check if elicitation is available
659    ///
660    /// Returns true if a client requester is configured and the transport
661    /// supports bidirectional communication. Note that this only checks if
662    /// the mechanism is available, not whether the client supports elicitation.
663    pub fn can_elicit(&self) -> bool {
664        self.client_requester.is_some()
665    }
666
667    /// Request user input via a form from the client
668    ///
669    /// This sends an `elicitation/create` request to the client with a form schema.
670    /// The client renders the form to the user and returns their response.
671    ///
672    /// Returns an error if elicitation is not available (no client requester configured).
673    ///
674    /// # Example
675    ///
676    /// ```rust,ignore
677    /// use tower_mcp::{ElicitFormParams, ElicitFormSchema, ElicitMode, ElicitAction};
678    ///
679    /// async fn my_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
680    ///     let params = ElicitFormParams {
681    ///         mode: Some(ElicitMode::Form),
682    ///         message: "Please enter your details".to_string(),
683    ///         requested_schema: ElicitFormSchema::new()
684    ///             .string_field("name", Some("Your name"), true),
685    ///         meta: None,
686    ///     };
687    ///
688    ///     let result = ctx.elicit_form(params).await?;
689    ///     match result.action {
690    ///         ElicitAction::Accept => {
691    ///             // Use result.content
692    ///             Ok(CallToolResult::text("Got your input!"))
693    ///         }
694    ///         _ => Ok(CallToolResult::text("User declined"))
695    ///     }
696    /// }
697    /// ```
698    pub async fn elicit_form(&self, params: ElicitFormParams) -> Result<ElicitResult> {
699        let requester = self.client_requester.as_ref().ok_or_else(|| {
700            Error::Internal("Elicitation not available: no client requester configured".to_string())
701        })?;
702
703        requester.elicit(ElicitRequestParams::Form(params)).await
704    }
705
706    /// Request user input via URL redirect from the client
707    ///
708    /// This sends an `elicitation/create` request to the client with a URL.
709    /// The client directs the user to the URL for out-of-band input collection.
710    /// The server receives the result via a callback notification.
711    ///
712    /// Returns an error if elicitation is not available (no client requester configured).
713    ///
714    /// # Example
715    ///
716    /// ```rust,ignore
717    /// use tower_mcp::{ElicitUrlParams, ElicitMode, ElicitAction};
718    ///
719    /// async fn my_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
720    ///     let params = ElicitUrlParams {
721    ///         mode: Some(ElicitMode::Url),
722    ///         elicitation_id: "unique-id-123".to_string(),
723    ///         message: "Please authorize via the link".to_string(),
724    ///         url: "https://example.com/auth?id=unique-id-123".to_string(),
725    ///         meta: None,
726    ///     };
727    ///
728    ///     let result = ctx.elicit_url(params).await?;
729    ///     match result.action {
730    ///         ElicitAction::Accept => Ok(CallToolResult::text("Authorization complete!")),
731    ///         _ => Ok(CallToolResult::text("Authorization cancelled"))
732    ///     }
733    /// }
734    /// ```
735    pub async fn elicit_url(&self, params: ElicitUrlParams) -> Result<ElicitResult> {
736        let requester = self.client_requester.as_ref().ok_or_else(|| {
737            Error::Internal("Elicitation not available: no client requester configured".to_string())
738        })?;
739
740        requester.elicit(ElicitRequestParams::Url(params)).await
741    }
742
743    /// Request simple confirmation from the user.
744    ///
745    /// This is a convenience method for simple yes/no confirmation dialogs.
746    /// It creates an elicitation form with a single boolean "confirm" field
747    /// and returns `true` if the user accepts, `false` otherwise.
748    ///
749    /// Returns an error if elicitation is not available (no client requester configured).
750    ///
751    /// # Example
752    ///
753    /// ```rust,ignore
754    /// use tower_mcp::{RequestContext, CallToolResult};
755    ///
756    /// async fn delete_item(ctx: RequestContext) -> Result<CallToolResult> {
757    ///     let confirmed = ctx.confirm("Are you sure you want to delete this item?").await?;
758    ///     if confirmed {
759    ///         // Perform deletion
760    ///         Ok(CallToolResult::text("Item deleted"))
761    ///     } else {
762    ///         Ok(CallToolResult::text("Deletion cancelled"))
763    ///     }
764    /// }
765    /// ```
766    pub async fn confirm(&self, message: impl Into<String>) -> Result<bool> {
767        use crate::protocol::{ElicitAction, ElicitFormParams, ElicitFormSchema, ElicitMode};
768
769        let params = ElicitFormParams {
770            mode: Some(ElicitMode::Form),
771            message: message.into(),
772            requested_schema: ElicitFormSchema::new().boolean_field_with_default(
773                "confirm",
774                Some("Confirm this action"),
775                true,
776                false,
777            ),
778            meta: None,
779        };
780
781        let result = self.elicit_form(params).await?;
782        Ok(result.action == ElicitAction::Accept)
783    }
784
785    /// List tasks tracked by the connected client (SEP-1686).
786    ///
787    /// Sends a `tasks/list` request to the client and returns the result.
788    /// Pass `Some(status)` to filter to a single status, or `None` for all
789    /// tasks. Pagination is exposed via [`ListTasksResult::next_cursor`];
790    /// use [`request_raw`](Self::request_raw) for cursor-driven calls.
791    ///
792    /// Returns an error if no client requester is configured or the client
793    /// does not advertise task support.
794    pub async fn list_tasks(&self, status: Option<TaskStatus>) -> Result<ListTasksResult> {
795        let params = ListTasksParams {
796            status,
797            cursor: None,
798            meta: None,
799        };
800        let value = self
801            .request_raw("tasks/list", serde_json::to_value(&params)?)
802            .await?;
803        serde_json::from_value(value)
804            .map_err(|e| Error::Internal(format!("Failed to deserialize tasks/list: {e}")))
805    }
806
807    /// Fetch metadata for a single task tracked by the client (SEP-1686).
808    ///
809    /// Sends a `tasks/get` request and returns the task object, including
810    /// the current status, timestamps, and TTL.
811    pub async fn get_task_info(&self, task_id: impl Into<String>) -> Result<TaskObject> {
812        let params = GetTaskInfoParams {
813            task_id: task_id.into(),
814            meta: None,
815        };
816        let value = self
817            .request_raw("tasks/get", serde_json::to_value(&params)?)
818            .await?;
819        serde_json::from_value(value)
820            .map_err(|e| Error::Internal(format!("Failed to deserialize tasks/get: {e}")))
821    }
822
823    /// Fetch the terminal result for a task tracked by the client (SEP-1686).
824    ///
825    /// Sends a `tasks/result` request. The client is expected to block until
826    /// the task reaches a terminal state and then return the underlying
827    /// `CallToolResult`. For long-running tasks, prefer polling with
828    /// [`get_task_info`](Self::get_task_info) and only call this once the
829    /// status is terminal.
830    pub async fn get_task_result(&self, task_id: impl Into<String>) -> Result<CallToolResult> {
831        let params = GetTaskResultParams {
832            task_id: task_id.into(),
833            meta: None,
834        };
835        let value = self
836            .request_raw("tasks/result", serde_json::to_value(&params)?)
837            .await?;
838        serde_json::from_value(value)
839            .map_err(|e| Error::Internal(format!("Failed to deserialize tasks/result: {e}")))
840    }
841
842    /// Cancel a task tracked by the client (SEP-1686).
843    ///
844    /// Sends a `tasks/cancel` request and returns the resulting task object,
845    /// which will reflect the cancelled status.
846    pub async fn cancel_task(
847        &self,
848        task_id: impl Into<String>,
849        reason: Option<String>,
850    ) -> Result<TaskObject> {
851        let params = CancelTaskParams {
852            task_id: task_id.into(),
853            reason,
854            meta: None,
855        };
856        let value = self
857            .request_raw("tasks/cancel", serde_json::to_value(&params)?)
858            .await?;
859        serde_json::from_value(value)
860            .map_err(|e| Error::Internal(format!("Failed to deserialize tasks/cancel: {e}")))
861    }
862
863    /// Send an arbitrary JSON-RPC request to the client.
864    ///
865    /// Escape hatch for methods not covered by the typed helpers (e.g. when
866    /// a `tasks/list` cursor needs to be passed). Most callers should prefer
867    /// the typed methods.
868    pub async fn request_raw(
869        &self,
870        method: &str,
871        params: serde_json::Value,
872    ) -> Result<serde_json::Value> {
873        let requester = self.client_requester.as_ref().ok_or_else(|| {
874            Error::Internal(
875                "Client request not available: no client requester configured".to_string(),
876            )
877        })?;
878        requester.request(method.to_string(), params).await
879    }
880}
881
/// A cloneable flag used to request and observe cancellation.
///
/// All clones share the same underlying flag, so cancelling through any clone
/// is visible through every other clone. This type exposes no way to clear the
/// flag once it has been set.
#[derive(Clone, Debug)]
pub struct CancellationToken {
    cancelled: Arc<AtomicBool>,
}

impl CancellationToken {
    /// Returns `true` once cancellation has been requested on this token or on
    /// any clone of it.
    ///
    /// The load uses `Acquire` ordering. After observing `true`, the caller is
    /// therefore guaranteed to see every write the cancelling thread made
    /// before it called [`cancel`](Self::cancel).
    pub fn is_cancelled(&self) -> bool {
        self.cancelled.load(Ordering::Acquire)
    }

    /// Requests cancellation.
    ///
    /// Calling this repeatedly is harmless; the flag simply stays set. The
    /// store uses `Release` ordering to pair with the `Acquire` load in
    /// [`is_cancelled`](Self::is_cancelled).
    pub fn cancel(&self) {
        self.cancelled.store(true, Ordering::Release);
    }
}
899
900/// Builder for creating request contexts
901#[derive(Default)]
902pub struct RequestContextBuilder {
903    request_id: Option<RequestId>,
904    progress_token: Option<ProgressToken>,
905    notification_tx: Option<NotificationSender>,
906    client_requester: Option<ClientRequesterHandle>,
907    min_log_level: Option<Arc<RwLock<LogLevel>>>,
908}
909
910impl RequestContextBuilder {
911    /// Create a new builder
912    pub fn new() -> Self {
913        Self::default()
914    }
915
916    /// Set the request ID
917    pub fn request_id(mut self, id: RequestId) -> Self {
918        self.request_id = Some(id);
919        self
920    }
921
922    /// Set the progress token
923    pub fn progress_token(mut self, token: ProgressToken) -> Self {
924        self.progress_token = Some(token);
925        self
926    }
927
928    /// Set the notification sender
929    pub fn notification_sender(mut self, tx: NotificationSender) -> Self {
930        self.notification_tx = Some(tx);
931        self
932    }
933
934    /// Set the client requester for server-to-client requests
935    pub fn client_requester(mut self, requester: ClientRequesterHandle) -> Self {
936        self.client_requester = Some(requester);
937        self
938    }
939
940    /// Set the minimum log level for filtering
941    pub fn min_log_level(mut self, level: Arc<RwLock<LogLevel>>) -> Self {
942        self.min_log_level = Some(level);
943        self
944    }
945
946    /// Build the request context
947    ///
948    /// Panics if request_id is not set.
949    pub fn build(self) -> RequestContext {
950        let mut ctx = RequestContext::new(self.request_id.expect("request_id is required"));
951        if let Some(token) = self.progress_token {
952            ctx = ctx.with_progress_token(token);
953        }
954        if let Some(tx) = self.notification_tx {
955            ctx = ctx.with_notification_sender(tx);
956        }
957        if let Some(requester) = self.client_requester {
958            ctx = ctx.with_client_requester(requester);
959        }
960        if let Some(level) = self.min_log_level {
961            ctx = ctx.with_min_log_level(level);
962        }
963        ctx
964    }
965}
966
// Unit tests for the request context, the cancellation token and the builder.
//
// The task-helper tests (`tasks/*`) use `spawn_mock_client` to play the client
// side of the outgoing-request channel.
#[cfg(test)]
mod tests {
    use super::*;

    /// Cancelling the context is observed both by the context itself and by a
    /// token that was obtained before cancellation.
    #[test]
    fn test_cancellation() {
        let ctx = RequestContext::new(RequestId::Number(1));
        assert!(!ctx.is_cancelled());

        let token = ctx.cancellation_token();
        assert!(!token.is_cancelled());

        ctx.cancel();
        assert!(ctx.is_cancelled());
        assert!(token.is_cancelled());
    }

    /// With both a progress token and a notification sender configured,
    /// `report_progress` emits a `Progress` notification carrying the
    /// progress, total and message it was given.
    #[tokio::test]
    async fn test_progress_reporting() {
        let (tx, mut rx) = notification_channel(10);

        let ctx = RequestContext::new(RequestId::Number(1))
            .with_progress_token(ProgressToken::Number(42))
            .with_notification_sender(tx);

        ctx.report_progress(50.0, Some(100.0), Some("Halfway"))
            .await;

        let notification = rx.recv().await.unwrap();
        match notification {
            ServerNotification::Progress(params) => {
                assert_eq!(params.progress, 50.0);
                assert_eq!(params.total, Some(100.0));
                assert_eq!(params.message.as_deref(), Some("Halfway"));
            }
            _ => panic!("Expected Progress notification"),
        }
    }

    /// Without a progress token, `report_progress` sends nothing even though
    /// a notification sender is configured.
    #[tokio::test]
    async fn test_progress_no_token() {
        let (tx, mut rx) = notification_channel(10);

        // No progress token - should be a no-op
        let ctx = RequestContext::new(RequestId::Number(1)).with_notification_sender(tx);

        ctx.report_progress(50.0, Some(100.0), None).await;

        // Channel should be empty
        assert!(rx.try_recv().is_err());
    }

    /// The builder forwards the request ID and progress token to the built
    /// context.
    #[test]
    fn test_builder() {
        let (tx, _rx) = notification_channel(10);

        let ctx = RequestContextBuilder::new()
            .request_id(RequestId::String("req-1".to_string()))
            .progress_token(ProgressToken::String("prog-1".to_string()))
            .notification_sender(tx)
            .build();

        assert_eq!(ctx.request_id(), &RequestId::String("req-1".to_string()));
        assert!(ctx.progress_token().is_some());
    }

    /// `can_sample` is false when no client requester is configured.
    #[test]
    fn test_can_sample_without_requester() {
        let ctx = RequestContext::new(RequestId::Number(1));
        assert!(!ctx.can_sample());
    }

    /// `can_sample` is true once a client requester is attached.
    #[test]
    fn test_can_sample_with_requester() {
        let (request_tx, _rx) = outgoing_request_channel(10);
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));

        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
        assert!(ctx.can_sample());
    }

    /// Without a client requester, `sample` fails with an error whose message
    /// contains "Sampling not available".
    #[tokio::test]
    async fn test_sample_without_requester_fails() {
        use crate::protocol::{CreateMessageParams, SamplingMessage};

        let ctx = RequestContext::new(RequestId::Number(1));
        let params = CreateMessageParams::new(vec![SamplingMessage::user("test")], 100);

        let result = ctx.sample(params).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Sampling not available")
        );
    }

    /// A requester set through the builder is visible via `can_sample`.
    #[test]
    fn test_builder_with_client_requester() {
        let (request_tx, _rx) = outgoing_request_channel(10);
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));

        let ctx = RequestContextBuilder::new()
            .request_id(RequestId::Number(1))
            .client_requester(requester)
            .build();

        assert!(ctx.can_sample());
    }

    /// `can_elicit` is false when no client requester is configured.
    #[test]
    fn test_can_elicit_without_requester() {
        let ctx = RequestContext::new(RequestId::Number(1));
        assert!(!ctx.can_elicit());
    }

    /// `can_elicit` is true once a client requester is attached.
    #[test]
    fn test_can_elicit_with_requester() {
        let (request_tx, _rx) = outgoing_request_channel(10);
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));

        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
        assert!(ctx.can_elicit());
    }

    /// Without a client requester, form elicitation fails with an error whose
    /// message contains "Elicitation not available".
    #[tokio::test]
    async fn test_elicit_form_without_requester_fails() {
        use crate::protocol::{ElicitFormSchema, ElicitMode};

        let ctx = RequestContext::new(RequestId::Number(1));
        let params = ElicitFormParams {
            mode: Some(ElicitMode::Form),
            message: "Enter details".to_string(),
            requested_schema: ElicitFormSchema::new().string_field("name", None, true),
            meta: None,
        };

        let result = ctx.elicit_form(params).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Elicitation not available")
        );
    }

    /// Without a client requester, URL elicitation fails with an error whose
    /// message contains "Elicitation not available".
    #[tokio::test]
    async fn test_elicit_url_without_requester_fails() {
        use crate::protocol::ElicitMode;

        let ctx = RequestContext::new(RequestId::Number(1));
        let params = ElicitUrlParams {
            mode: Some(ElicitMode::Url),
            elicitation_id: "test-123".to_string(),
            message: "Please authorize".to_string(),
            url: "https://example.com/auth".to_string(),
            meta: None,
        };

        let result = ctx.elicit_url(params).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Elicitation not available")
        );
    }

    /// `confirm` is built on elicitation, so without a requester it surfaces
    /// the same "Elicitation not available" error.
    #[tokio::test]
    async fn test_confirm_without_requester_fails() {
        let ctx = RequestContext::new(RequestId::Number(1));

        let result = ctx.confirm("Are you sure?").await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Elicitation not available")
        );
    }

    /// With a minimum level of `Warning`, messages at `Warning` or more severe
    /// are sent and less severe ones (`Info`, `Debug`) are dropped.
    #[tokio::test]
    async fn test_send_log_filtered_by_level() {
        let (tx, mut rx) = notification_channel(10);
        let min_level = Arc::new(RwLock::new(LogLevel::Warning));

        let ctx = RequestContext::new(RequestId::Number(1))
            .with_notification_sender(tx)
            .with_min_log_level(min_level.clone());

        // Error is more severe than Warning — should pass through
        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Error,
            serde_json::Value::Null,
        ));
        let msg = rx.try_recv();
        assert!(msg.is_ok(), "Error should pass through Warning filter");

        // Warning is equal to min level — should pass through
        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Warning,
            serde_json::Value::Null,
        ));
        let msg = rx.try_recv();
        assert!(msg.is_ok(), "Warning should pass through Warning filter");

        // Info is less severe than Warning — should be filtered
        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Info,
            serde_json::Value::Null,
        ));
        let msg = rx.try_recv();
        assert!(msg.is_err(), "Info should be filtered by Warning filter");

        // Debug is less severe than Warning — should be filtered
        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Debug,
            serde_json::Value::Null,
        ));
        let msg = rx.try_recv();
        assert!(msg.is_err(), "Debug should be filtered by Warning filter");
    }

    /// The minimum level is shared through an `Arc<RwLock<_>>`. Changing it
    /// after the context is built affects subsequent `send_log` calls.
    #[tokio::test]
    async fn test_send_log_level_updates_dynamically() {
        let (tx, mut rx) = notification_channel(10);
        let min_level = Arc::new(RwLock::new(LogLevel::Error));

        let ctx = RequestContext::new(RequestId::Number(1))
            .with_notification_sender(tx)
            .with_min_log_level(min_level.clone());

        // Info should be filtered at Error level
        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Info,
            serde_json::Value::Null,
        ));
        assert!(
            rx.try_recv().is_err(),
            "Info should be filtered at Error level"
        );

        // Dynamically update to Debug (most permissive)
        *min_level.write().unwrap() = LogLevel::Debug;

        // Now Info should pass through
        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Info,
            serde_json::Value::Null,
        ));
        assert!(
            rx.try_recv().is_ok(),
            "Info should pass through after level changed to Debug"
        );
    }

    /// With no minimum level configured, even `Debug` messages are sent.
    #[tokio::test]
    async fn test_send_log_no_min_level_sends_all() {
        let (tx, mut rx) = notification_channel(10);

        // No min_log_level set — all messages should pass through
        let ctx = RequestContext::new(RequestId::Number(1)).with_notification_sender(tx);

        ctx.send_log(LoggingMessageParams::new(
            LogLevel::Debug,
            serde_json::Value::Null,
        ));
        assert!(
            rx.try_recv().is_ok(),
            "Debug should pass when no min level is set"
        );
    }

    /// Builds a JSON task object in the camelCase wire shape used by these
    /// tests. Timestamps are fixed and the TTL is `null`.
    fn make_task_object(id: &str, status: TaskStatus) -> serde_json::Value {
        serde_json::json!({
            "taskId": id,
            "status": status,
            "createdAt": "2026-04-24T00:00:00Z",
            "lastUpdatedAt": "2026-04-24T00:00:00Z",
            "ttl": null
        })
    }

    /// Spawns a background task that plays the client. For every outgoing
    /// request it calls `responder(method, params)` and replies with
    /// `Ok(response)`, ignoring send failures.
    ///
    /// NOTE(review): an assertion that panics inside `responder` ends the
    /// mock task and drops that request's `response_tx`. The test then
    /// presumably fails at the awaited request's `unwrap()` rather than
    /// directly on the assertion message. Check the test output for the
    /// spawned task's panic message.
    fn spawn_mock_client(
        mut rx: OutgoingRequestReceiver,
        responder: impl Fn(&str, serde_json::Value) -> serde_json::Value + Send + 'static,
    ) {
        tokio::spawn(async move {
            while let Some(req) = rx.recv().await {
                let response = responder(&req.method, req.params);
                let _ = req.response_tx.send(Ok(response));
            }
        });
    }

    /// `get_task_info` sends `tasks/get` with the `taskId` and decodes the
    /// returned task object.
    #[tokio::test]
    async fn test_get_task_info_round_trips() {
        let (tx, rx) = outgoing_request_channel(10);
        spawn_mock_client(rx, |method, params| {
            assert_eq!(method, "tasks/get");
            let task_id = params["taskId"].as_str().unwrap().to_string();
            make_task_object(&task_id, TaskStatus::Working)
        });
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(tx));
        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);

        let info = ctx.get_task_info("task-123").await.unwrap();
        assert_eq!(info.task_id, "task-123");
        assert!(matches!(info.status, TaskStatus::Working));
    }

    /// `list_tasks` forwards the status filter and decodes the `tasks` array.
    #[tokio::test]
    async fn test_list_tasks_round_trips() {
        let (tx, rx) = outgoing_request_channel(10);
        spawn_mock_client(rx, |method, params| {
            assert_eq!(method, "tasks/list");
            // Status filter should be forwarded
            assert_eq!(params["status"], serde_json::json!("working"));
            serde_json::json!({
                "tasks": [
                    make_task_object("task-1", TaskStatus::Working),
                    make_task_object("task-2", TaskStatus::Working),
                ]
            })
        });
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(tx));
        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);

        let result = ctx.list_tasks(Some(TaskStatus::Working)).await.unwrap();
        assert_eq!(result.tasks.len(), 2);
        assert_eq!(result.tasks[0].task_id, "task-1");
    }

    /// `cancel_task` forwards the optional reason in the `tasks/cancel`
    /// params and decodes the returned task object.
    #[tokio::test]
    async fn test_cancel_task_forwards_reason() {
        let (tx, rx) = outgoing_request_channel(10);
        spawn_mock_client(rx, |method, params| {
            assert_eq!(method, "tasks/cancel");
            assert_eq!(params["reason"], serde_json::json!("user requested"));
            let task_id = params["taskId"].as_str().unwrap().to_string();
            make_task_object(&task_id, TaskStatus::Cancelled)
        });
        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(tx));
        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);

        let task = ctx
            .cancel_task("task-99", Some("user requested".into()))
            .await
            .unwrap();
        assert_eq!(task.task_id, "task-99");
        assert!(matches!(task.status, TaskStatus::Cancelled));
    }

    /// Without a client requester, task helpers fail with an error whose
    /// message contains "Client request not available".
    #[tokio::test]
    async fn test_get_task_info_without_requester_fails() {
        let ctx = RequestContext::new(RequestId::Number(1));
        let result = ctx.get_task_info("task-1").await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Client request not available")
        );
    }

    /// A requester that relies on the trait's default `request` method makes
    /// task helpers fail with a "does not support arbitrary" error.
    #[tokio::test]
    async fn test_default_request_impl_errors() {
        // A custom requester that only implements sample/elicit (not request)
        // should reject task helpers.
        struct OnlySampleAndElicit;

        #[async_trait]
        impl ClientRequester for OnlySampleAndElicit {
            async fn sample(&self, _: CreateMessageParams) -> Result<CreateMessageResult> {
                unreachable!()
            }
            async fn elicit(&self, _: ElicitRequestParams) -> Result<ElicitResult> {
                unreachable!()
            }
        }

        let requester: ClientRequesterHandle = Arc::new(OnlySampleAndElicit);
        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);

        let err = ctx.get_task_info("x").await.unwrap_err();
        assert!(err.to_string().contains("does not support arbitrary"));
    }
}