Skip to main content

tower_mcp/
context.rs

1//! Request context for MCP handlers
2//!
3//! Provides progress reporting, cancellation support, and client request capabilities
4//! for long-running operations.
5//!
6//! # Example
7//!
8//! ```rust,ignore
9//! use tower_mcp::context::RequestContext;
10//!
11//! async fn long_running_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
12//!     for i in 0..100 {
13//!         // Check if cancelled
14//!         if ctx.is_cancelled() {
15//!             return Err(Error::tool("Operation cancelled"));
16//!         }
17//!
18//!         // Report progress
19//!         ctx.report_progress(i as f64, Some(100.0), Some("Processing...")).await;
20//!
21//!         do_work(i).await;
22//!     }
23//!     Ok(CallToolResult::text("Done!"))
24//! }
25//! ```
26//!
27//! # Sampling (LLM requests to client)
28//!
29//! ```rust,ignore
30//! use tower_mcp::context::RequestContext;
31//! use tower_mcp::{CreateMessageParams, SamplingMessage};
32//!
33//! async fn ai_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
34//!     // Request LLM completion from the client
35//!     let params = CreateMessageParams::new(
36//!         vec![SamplingMessage::user("Summarize this text...")],
37//!         500,
38//!     );
39//!
40//!     let result = ctx.sample(params).await?;
41//!     Ok(CallToolResult::text(format!("Summary: {:?}", result.content)))
42//! }
43//! ```
44//!
45//! # Elicitation (requesting user input)
46//!
47//! ```rust,ignore
48//! use tower_mcp::context::RequestContext;
49//! use tower_mcp::{ElicitFormParams, ElicitFormSchema, ElicitMode, ElicitAction};
50//!
51//! async fn interactive_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
52//!     // Request user input via form
53//!     let params = ElicitFormParams {
54//!         mode: Some(ElicitMode::Form),
55//!         message: "Please provide additional details".to_string(),
56//!         requested_schema: ElicitFormSchema::new()
57//!             .string_field("name", Some("Your name"), true)
58//!             .number_field("age", Some("Your age"), false),
59//!         meta: None,
60//!     };
61//!
62//!     let result = ctx.elicit_form(params).await?;
63//!     if result.action == ElicitAction::Accept {
64//!         // Use the form data
65//!         Ok(CallToolResult::text(format!("Got: {:?}", result.content)))
66//!     } else {
67//!         Ok(CallToolResult::text("User declined"))
68//!     }
69//! }
70//! ```
71
72use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
73use std::sync::{Arc, RwLock};
74
75use async_trait::async_trait;
76use tokio::sync::mpsc;
77
78use crate::error::{Error, Result};
79use crate::protocol::{
80    CreateMessageParams, CreateMessageResult, ElicitFormParams, ElicitRequestParams, ElicitResult,
81    ElicitUrlParams, LogLevel, LoggingMessageParams, ProgressParams, ProgressToken, RequestId,
82};
83
84/// A notification to be sent to the client
85#[derive(Debug, Clone)]
86#[non_exhaustive]
87pub enum ServerNotification {
88    /// Progress update for a request
89    Progress(ProgressParams),
90    /// Log message notification
91    LogMessage(LoggingMessageParams),
92    /// A subscribed resource has been updated
93    ResourceUpdated {
94        /// The URI of the updated resource
95        uri: String,
96    },
97    /// The list of available resources has changed
98    ResourcesListChanged,
99    /// The list of available tools has changed
100    ToolsListChanged,
101    /// The list of available prompts has changed
102    PromptsListChanged,
103    /// Task status has changed
104    TaskStatusChanged(crate::protocol::TaskStatusParams),
105}
106
107/// Sender for server notifications
108pub type NotificationSender = mpsc::Sender<ServerNotification>;
109
110/// Receiver for server notifications
111pub type NotificationReceiver = mpsc::Receiver<ServerNotification>;
112
113/// Create a new notification channel
114pub fn notification_channel(buffer: usize) -> (NotificationSender, NotificationReceiver) {
115    mpsc::channel(buffer)
116}
117
118// =============================================================================
119// Client Requests (Server -> Client)
120// =============================================================================
121
122/// Trait for sending requests from server to client
123///
124/// This enables bidirectional communication where the server can request
125/// actions from the client, such as sampling (LLM requests) and elicitation
126/// (user input requests).
127#[async_trait]
128pub trait ClientRequester: Send + Sync {
129    /// Send a sampling request to the client
130    ///
131    /// Returns the LLM completion result from the client.
132    async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult>;
133
134    /// Send an elicitation request to the client
135    ///
136    /// This requests user input from the client. The request can be either
137    /// form-based (structured input) or URL-based (redirect to external URL).
138    ///
139    /// Returns the elicitation result with the user's action and any submitted data.
140    async fn elicit(&self, params: ElicitRequestParams) -> Result<ElicitResult>;
141}
142
143/// A clonable handle to a client requester
144pub type ClientRequesterHandle = Arc<dyn ClientRequester>;
145
146/// Outgoing request to be sent to the client
147#[derive(Debug)]
148pub struct OutgoingRequest {
149    /// The JSON-RPC request ID
150    pub id: RequestId,
151    /// The method name
152    pub method: String,
153    /// The request parameters as JSON
154    pub params: serde_json::Value,
155    /// Channel to send the response back
156    pub response_tx: tokio::sync::oneshot::Sender<Result<serde_json::Value>>,
157}
158
159/// Sender for outgoing requests to the client
160pub type OutgoingRequestSender = mpsc::Sender<OutgoingRequest>;
161
162/// Receiver for outgoing requests (used by transport)
163pub type OutgoingRequestReceiver = mpsc::Receiver<OutgoingRequest>;
164
165/// Create a new outgoing request channel
166pub fn outgoing_request_channel(buffer: usize) -> (OutgoingRequestSender, OutgoingRequestReceiver) {
167    mpsc::channel(buffer)
168}
169
170/// A client requester implementation that sends requests through a channel
171#[derive(Clone)]
172pub struct ChannelClientRequester {
173    request_tx: OutgoingRequestSender,
174    next_id: Arc<AtomicI64>,
175}
176
177impl ChannelClientRequester {
178    /// Create a new channel-based client requester
179    pub fn new(request_tx: OutgoingRequestSender) -> Self {
180        Self {
181            request_tx,
182            next_id: Arc::new(AtomicI64::new(1)),
183        }
184    }
185
186    fn next_request_id(&self) -> RequestId {
187        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
188        RequestId::Number(id)
189    }
190}
191
192#[async_trait]
193impl ClientRequester for ChannelClientRequester {
194    async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult> {
195        let id = self.next_request_id();
196        let params_json = serde_json::to_value(&params)
197            .map_err(|e| Error::Internal(format!("Failed to serialize params: {}", e)))?;
198
199        let (response_tx, response_rx) = tokio::sync::oneshot::channel();
200
201        let request = OutgoingRequest {
202            id: id.clone(),
203            method: "sampling/createMessage".to_string(),
204            params: params_json,
205            response_tx,
206        };
207
208        self.request_tx
209            .send(request)
210            .await
211            .map_err(|_| Error::Internal("Failed to send request: channel closed".to_string()))?;
212
213        let response = response_rx.await.map_err(|_| {
214            Error::Internal("Failed to receive response: channel closed".to_string())
215        })??;
216
217        serde_json::from_value(response)
218            .map_err(|e| Error::Internal(format!("Failed to deserialize response: {}", e)))
219    }
220
221    async fn elicit(&self, params: ElicitRequestParams) -> Result<ElicitResult> {
222        let id = self.next_request_id();
223        let params_json = serde_json::to_value(&params)
224            .map_err(|e| Error::Internal(format!("Failed to serialize params: {}", e)))?;
225
226        let (response_tx, response_rx) = tokio::sync::oneshot::channel();
227
228        let request = OutgoingRequest {
229            id: id.clone(),
230            method: "elicitation/create".to_string(),
231            params: params_json,
232            response_tx,
233        };
234
235        self.request_tx
236            .send(request)
237            .await
238            .map_err(|_| Error::Internal("Failed to send request: channel closed".to_string()))?;
239
240        let response = response_rx.await.map_err(|_| {
241            Error::Internal("Failed to receive response: channel closed".to_string())
242        })??;
243
244        serde_json::from_value(response)
245            .map_err(|e| Error::Internal(format!("Failed to deserialize response: {}", e)))
246    }
247}
248
249/// Context for a request, providing progress, cancellation, and client request support
250#[derive(Clone)]
251pub struct RequestContext {
252    /// The request ID
253    request_id: RequestId,
254    /// Progress token (if provided by client)
255    progress_token: Option<ProgressToken>,
256    /// Cancellation flag
257    cancelled: Arc<AtomicBool>,
258    /// Channel for sending notifications
259    notification_tx: Option<NotificationSender>,
260    /// Handle for sending requests to the client (for sampling, etc.)
261    client_requester: Option<ClientRequesterHandle>,
262    /// Extensions for passing data from router/middleware to handlers
263    extensions: Arc<Extensions>,
264    /// Minimum log level set by the client (shared with router for dynamic updates)
265    min_log_level: Option<Arc<RwLock<LogLevel>>>,
266}
267
/// Type-keyed map used to hand arbitrary data to handlers.
///
/// Router-level state and middleware-injected values travel through this map.
/// Handlers read them back via the `Extension<T>` extractor. At most one value
/// per concrete type is stored.
#[derive(Clone, Default)]
pub struct Extensions {
    map: std::collections::HashMap<std::any::TypeId, Arc<dyn std::any::Any + Send + Sync>>,
}

impl Extensions {
    /// Create a map with no entries.
    pub fn new() -> Self {
        Extensions {
            map: std::collections::HashMap::new(),
        }
    }

    /// Store `val`, replacing any value of the same type `T` already present.
    pub fn insert<T: Send + Sync + 'static>(&mut self, val: T) {
        let key = std::any::TypeId::of::<T>();
        let erased: Arc<dyn std::any::Any + Send + Sync> = Arc::new(val);
        self.map.insert(key, erased);
    }

    /// Borrow the stored value of type `T`, or `None` if none was inserted.
    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
        let entry = self.map.get(&std::any::TypeId::of::<T>())?;
        entry.downcast_ref::<T>()
    }

    /// Report whether a value of type `T` is stored.
    pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
        let key = std::any::TypeId::of::<T>();
        self.map.contains_key(&key)
    }

    /// Copy every entry of `other` into `self`.
    ///
    /// When both maps hold a value of the same type, `other`'s value wins.
    /// Values are shared by reference count, not deep-cloned.
    pub fn merge(&mut self, other: &Extensions) {
        self.map.extend(
            other
                .map
                .iter()
                .map(|(type_id, value)| (*type_id, Arc::clone(value))),
        );
    }
}

impl std::fmt::Debug for Extensions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Stored values are type-erased and need not be Debug, so only the
        // number of entries is shown.
        let len = self.map.len();
        f.debug_struct("Extensions").field("len", &len).finish()
    }
}
321
322impl std::fmt::Debug for RequestContext {
323    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
324        f.debug_struct("RequestContext")
325            .field("request_id", &self.request_id)
326            .field("progress_token", &self.progress_token)
327            .field("cancelled", &self.cancelled.load(Ordering::Relaxed))
328            .finish()
329    }
330}
331
332impl RequestContext {
333    /// Create a new request context
334    pub fn new(request_id: RequestId) -> Self {
335        Self {
336            request_id,
337            progress_token: None,
338            cancelled: Arc::new(AtomicBool::new(false)),
339            notification_tx: None,
340            client_requester: None,
341            extensions: Arc::new(Extensions::new()),
342            min_log_level: None,
343        }
344    }
345
346    /// Set the progress token
347    pub fn with_progress_token(mut self, token: ProgressToken) -> Self {
348        self.progress_token = Some(token);
349        self
350    }
351
352    /// Set the notification sender
353    pub fn with_notification_sender(mut self, tx: NotificationSender) -> Self {
354        self.notification_tx = Some(tx);
355        self
356    }
357
358    /// Set the minimum log level for filtering outgoing log notifications
359    ///
360    /// This is shared with the router so that `logging/setLevel` updates
361    /// are immediately visible to all request contexts.
362    pub fn with_min_log_level(mut self, level: Arc<RwLock<LogLevel>>) -> Self {
363        self.min_log_level = Some(level);
364        self
365    }
366
367    /// Set the client requester for server-to-client requests
368    pub fn with_client_requester(mut self, requester: ClientRequesterHandle) -> Self {
369        self.client_requester = Some(requester);
370        self
371    }
372
373    /// Set the extensions for this request context.
374    ///
375    /// Extensions allow router-level state and middleware data to flow to handlers.
376    pub fn with_extensions(mut self, extensions: Arc<Extensions>) -> Self {
377        self.extensions = extensions;
378        self
379    }
380
381    /// Get a reference to a value from the extensions map.
382    ///
383    /// Returns `None` if no value of the given type has been inserted.
384    ///
385    /// # Example
386    ///
387    /// ```rust,ignore
388    /// #[derive(Clone)]
389    /// struct CurrentUser { id: String }
390    ///
391    /// // In a handler:
392    /// if let Some(user) = ctx.extension::<CurrentUser>() {
393    ///     println!("User: {}", user.id);
394    /// }
395    /// ```
396    pub fn extension<T: Send + Sync + 'static>(&self) -> Option<&T> {
397        self.extensions.get::<T>()
398    }
399
400    /// Get a mutable reference to the extensions.
401    ///
402    /// This allows middleware to insert data that handlers can access via
403    /// the `Extension<T>` extractor.
404    pub fn extensions_mut(&mut self) -> &mut Extensions {
405        Arc::make_mut(&mut self.extensions)
406    }
407
408    /// Get a reference to the extensions.
409    pub fn extensions(&self) -> &Extensions {
410        &self.extensions
411    }
412
413    /// Get the request ID
414    pub fn request_id(&self) -> &RequestId {
415        &self.request_id
416    }
417
418    /// Get the progress token (if any)
419    pub fn progress_token(&self) -> Option<&ProgressToken> {
420        self.progress_token.as_ref()
421    }
422
423    /// Check if the request has been cancelled
424    pub fn is_cancelled(&self) -> bool {
425        self.cancelled.load(Ordering::Relaxed)
426    }
427
428    /// Mark the request as cancelled
429    pub fn cancel(&self) {
430        self.cancelled.store(true, Ordering::Relaxed);
431    }
432
433    /// Get a cancellation token that can be shared
434    pub fn cancellation_token(&self) -> CancellationToken {
435        CancellationToken {
436            cancelled: self.cancelled.clone(),
437        }
438    }
439
440    /// Report progress to the client
441    ///
442    /// This is a no-op if no progress token was provided or no notification sender is configured.
443    pub async fn report_progress(&self, progress: f64, total: Option<f64>, message: Option<&str>) {
444        let Some(token) = &self.progress_token else {
445            return;
446        };
447        let Some(tx) = &self.notification_tx else {
448            return;
449        };
450
451        let params = ProgressParams {
452            progress_token: token.clone(),
453            progress,
454            total,
455            message: message.map(|s| s.to_string()),
456            meta: None,
457        };
458
459        // Best effort - don't block if channel is full
460        let _ = tx.try_send(ServerNotification::Progress(params));
461    }
462
463    /// Report progress synchronously (non-async version)
464    ///
465    /// This is a no-op if no progress token was provided or no notification sender is configured.
466    pub fn report_progress_sync(&self, progress: f64, total: Option<f64>, message: Option<&str>) {
467        let Some(token) = &self.progress_token else {
468            return;
469        };
470        let Some(tx) = &self.notification_tx else {
471            return;
472        };
473
474        let params = ProgressParams {
475            progress_token: token.clone(),
476            progress,
477            total,
478            message: message.map(|s| s.to_string()),
479            meta: None,
480        };
481
482        let _ = tx.try_send(ServerNotification::Progress(params));
483    }
484
485    /// Send a log message notification to the client
486    ///
487    /// This is a no-op if no notification sender is configured.
488    ///
489    /// # Example
490    ///
491    /// ```rust,ignore
492    /// use tower_mcp::protocol::{LoggingMessageParams, LogLevel};
493    ///
494    /// async fn my_tool(ctx: RequestContext) {
495    ///     ctx.send_log(
496    ///         LoggingMessageParams::new(LogLevel::Info, serde_json::json!("Processing..."))
497    ///             .with_logger("my-tool")
498    ///     );
499    /// }
500    /// ```
501    pub fn send_log(&self, params: LoggingMessageParams) {
502        let Some(tx) = &self.notification_tx else {
503            return;
504        };
505
506        // Filter by minimum log level set via logging/setLevel
507        // LogLevel derives Ord with Emergency < Alert < ... < Debug,
508        // so a message passes if its severity is at least the minimum
509        // (i.e., its ordinal is <= the minimum level's ordinal).
510        if let Some(min_level) = &self.min_log_level
511            && let Ok(min) = min_level.read()
512            && params.level > *min
513        {
514            return;
515        }
516
517        let _ = tx.try_send(ServerNotification::LogMessage(params));
518    }
519
520    /// Check if sampling is available
521    ///
522    /// Returns true if a client requester is configured and the transport
523    /// supports bidirectional communication.
524    pub fn can_sample(&self) -> bool {
525        self.client_requester.is_some()
526    }
527
528    /// Request an LLM completion from the client
529    ///
530    /// This sends a `sampling/createMessage` request to the client and waits
531    /// for the response. The client is expected to forward this to an LLM
532    /// and return the result.
533    ///
534    /// Returns an error if sampling is not available (no client requester configured).
535    ///
536    /// # Example
537    ///
538    /// ```rust,ignore
539    /// use tower_mcp::{CreateMessageParams, SamplingMessage};
540    ///
541    /// async fn my_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
542    ///     let params = CreateMessageParams::new(
543    ///         vec![SamplingMessage::user("Summarize: ...")],
544    ///         500,
545    ///     );
546    ///
547    ///     let result = ctx.sample(params).await?;
548    ///     Ok(CallToolResult::text(format!("{:?}", result.content)))
549    /// }
550    /// ```
551    pub async fn sample(&self, params: CreateMessageParams) -> Result<CreateMessageResult> {
552        let requester = self.client_requester.as_ref().ok_or_else(|| {
553            Error::Internal("Sampling not available: no client requester configured".to_string())
554        })?;
555
556        requester.sample(params).await
557    }
558
559    /// Check if elicitation is available
560    ///
561    /// Returns true if a client requester is configured and the transport
562    /// supports bidirectional communication. Note that this only checks if
563    /// the mechanism is available, not whether the client supports elicitation.
564    pub fn can_elicit(&self) -> bool {
565        self.client_requester.is_some()
566    }
567
568    /// Request user input via a form from the client
569    ///
570    /// This sends an `elicitation/create` request to the client with a form schema.
571    /// The client renders the form to the user and returns their response.
572    ///
573    /// Returns an error if elicitation is not available (no client requester configured).
574    ///
575    /// # Example
576    ///
577    /// ```rust,ignore
578    /// use tower_mcp::{ElicitFormParams, ElicitFormSchema, ElicitMode, ElicitAction};
579    ///
580    /// async fn my_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
581    ///     let params = ElicitFormParams {
582    ///         mode: Some(ElicitMode::Form),
583    ///         message: "Please enter your details".to_string(),
584    ///         requested_schema: ElicitFormSchema::new()
585    ///             .string_field("name", Some("Your name"), true),
586    ///         meta: None,
587    ///     };
588    ///
589    ///     let result = ctx.elicit_form(params).await?;
590    ///     match result.action {
591    ///         ElicitAction::Accept => {
592    ///             // Use result.content
593    ///             Ok(CallToolResult::text("Got your input!"))
594    ///         }
595    ///         _ => Ok(CallToolResult::text("User declined"))
596    ///     }
597    /// }
598    /// ```
599    pub async fn elicit_form(&self, params: ElicitFormParams) -> Result<ElicitResult> {
600        let requester = self.client_requester.as_ref().ok_or_else(|| {
601            Error::Internal("Elicitation not available: no client requester configured".to_string())
602        })?;
603
604        requester.elicit(ElicitRequestParams::Form(params)).await
605    }
606
607    /// Request user input via URL redirect from the client
608    ///
609    /// This sends an `elicitation/create` request to the client with a URL.
610    /// The client directs the user to the URL for out-of-band input collection.
611    /// The server receives the result via a callback notification.
612    ///
613    /// Returns an error if elicitation is not available (no client requester configured).
614    ///
615    /// # Example
616    ///
617    /// ```rust,ignore
618    /// use tower_mcp::{ElicitUrlParams, ElicitMode, ElicitAction};
619    ///
620    /// async fn my_tool(ctx: RequestContext, input: MyInput) -> Result<CallToolResult> {
621    ///     let params = ElicitUrlParams {
622    ///         mode: Some(ElicitMode::Url),
623    ///         elicitation_id: "unique-id-123".to_string(),
624    ///         message: "Please authorize via the link".to_string(),
625    ///         url: "https://example.com/auth?id=unique-id-123".to_string(),
626    ///         meta: None,
627    ///     };
628    ///
629    ///     let result = ctx.elicit_url(params).await?;
630    ///     match result.action {
631    ///         ElicitAction::Accept => Ok(CallToolResult::text("Authorization complete!")),
632    ///         _ => Ok(CallToolResult::text("Authorization cancelled"))
633    ///     }
634    /// }
635    /// ```
636    pub async fn elicit_url(&self, params: ElicitUrlParams) -> Result<ElicitResult> {
637        let requester = self.client_requester.as_ref().ok_or_else(|| {
638            Error::Internal("Elicitation not available: no client requester configured".to_string())
639        })?;
640
641        requester.elicit(ElicitRequestParams::Url(params)).await
642    }
643
644    /// Request simple confirmation from the user.
645    ///
646    /// This is a convenience method for simple yes/no confirmation dialogs.
647    /// It creates an elicitation form with a single boolean "confirm" field
648    /// and returns `true` if the user accepts, `false` otherwise.
649    ///
650    /// Returns an error if elicitation is not available (no client requester configured).
651    ///
652    /// # Example
653    ///
654    /// ```rust,ignore
655    /// use tower_mcp::{RequestContext, CallToolResult};
656    ///
657    /// async fn delete_item(ctx: RequestContext) -> Result<CallToolResult> {
658    ///     let confirmed = ctx.confirm("Are you sure you want to delete this item?").await?;
659    ///     if confirmed {
660    ///         // Perform deletion
661    ///         Ok(CallToolResult::text("Item deleted"))
662    ///     } else {
663    ///         Ok(CallToolResult::text("Deletion cancelled"))
664    ///     }
665    /// }
666    /// ```
667    pub async fn confirm(&self, message: impl Into<String>) -> Result<bool> {
668        use crate::protocol::{ElicitAction, ElicitFormParams, ElicitFormSchema, ElicitMode};
669
670        let params = ElicitFormParams {
671            mode: Some(ElicitMode::Form),
672            message: message.into(),
673            requested_schema: ElicitFormSchema::new().boolean_field_with_default(
674                "confirm",
675                Some("Confirm this action"),
676                true,
677                false,
678            ),
679            meta: None,
680        };
681
682        let result = self.elicit_form(params).await?;
683        Ok(result.action == ElicitAction::Accept)
684    }
685}
686
/// Cloneable handle for observing and requesting cancellation.
///
/// All clones share a single flag with the [`RequestContext`] they came from.
/// Cancelling through any one of them is therefore seen by all.
#[derive(Clone, Debug)]
pub struct CancellationToken {
    cancelled: Arc<AtomicBool>,
}

impl CancellationToken {
    /// Returns `true` once cancellation has been requested through any handle.
    pub fn is_cancelled(&self) -> bool {
        let flag: &AtomicBool = &self.cancelled;
        flag.load(Ordering::Relaxed)
    }

    /// Request cancellation; calling it again has no further effect.
    pub fn cancel(&self) {
        let flag: &AtomicBool = &self.cancelled;
        flag.store(true, Ordering::Relaxed);
    }
}
704
705/// Builder for creating request contexts
706#[derive(Default)]
707pub struct RequestContextBuilder {
708    request_id: Option<RequestId>,
709    progress_token: Option<ProgressToken>,
710    notification_tx: Option<NotificationSender>,
711    client_requester: Option<ClientRequesterHandle>,
712    min_log_level: Option<Arc<RwLock<LogLevel>>>,
713}
714
715impl RequestContextBuilder {
716    /// Create a new builder
717    pub fn new() -> Self {
718        Self::default()
719    }
720
721    /// Set the request ID
722    pub fn request_id(mut self, id: RequestId) -> Self {
723        self.request_id = Some(id);
724        self
725    }
726
727    /// Set the progress token
728    pub fn progress_token(mut self, token: ProgressToken) -> Self {
729        self.progress_token = Some(token);
730        self
731    }
732
733    /// Set the notification sender
734    pub fn notification_sender(mut self, tx: NotificationSender) -> Self {
735        self.notification_tx = Some(tx);
736        self
737    }
738
739    /// Set the client requester for server-to-client requests
740    pub fn client_requester(mut self, requester: ClientRequesterHandle) -> Self {
741        self.client_requester = Some(requester);
742        self
743    }
744
745    /// Set the minimum log level for filtering
746    pub fn min_log_level(mut self, level: Arc<RwLock<LogLevel>>) -> Self {
747        self.min_log_level = Some(level);
748        self
749    }
750
751    /// Build the request context
752    ///
753    /// Panics if request_id is not set.
754    pub fn build(self) -> RequestContext {
755        let mut ctx = RequestContext::new(self.request_id.expect("request_id is required"));
756        if let Some(token) = self.progress_token {
757            ctx = ctx.with_progress_token(token);
758        }
759        if let Some(tx) = self.notification_tx {
760            ctx = ctx.with_notification_sender(tx);
761        }
762        if let Some(requester) = self.client_requester {
763            ctx = ctx.with_client_requester(requester);
764        }
765        if let Some(level) = self.min_log_level {
766            ctx = ctx.with_min_log_level(level);
767        }
768        ctx
769    }
770}
771
772#[cfg(test)]
773mod tests {
774    use super::*;
775
776    #[test]
777    fn test_cancellation() {
778        let ctx = RequestContext::new(RequestId::Number(1));
779        assert!(!ctx.is_cancelled());
780
781        let token = ctx.cancellation_token();
782        assert!(!token.is_cancelled());
783
784        ctx.cancel();
785        assert!(ctx.is_cancelled());
786        assert!(token.is_cancelled());
787    }
788
789    #[tokio::test]
790    async fn test_progress_reporting() {
791        let (tx, mut rx) = notification_channel(10);
792
793        let ctx = RequestContext::new(RequestId::Number(1))
794            .with_progress_token(ProgressToken::Number(42))
795            .with_notification_sender(tx);
796
797        ctx.report_progress(50.0, Some(100.0), Some("Halfway"))
798            .await;
799
800        let notification = rx.recv().await.unwrap();
801        match notification {
802            ServerNotification::Progress(params) => {
803                assert_eq!(params.progress, 50.0);
804                assert_eq!(params.total, Some(100.0));
805                assert_eq!(params.message.as_deref(), Some("Halfway"));
806            }
807            _ => panic!("Expected Progress notification"),
808        }
809    }
810
811    #[tokio::test]
812    async fn test_progress_no_token() {
813        let (tx, mut rx) = notification_channel(10);
814
815        // No progress token - should be a no-op
816        let ctx = RequestContext::new(RequestId::Number(1)).with_notification_sender(tx);
817
818        ctx.report_progress(50.0, Some(100.0), None).await;
819
820        // Channel should be empty
821        assert!(rx.try_recv().is_err());
822    }
823
824    #[test]
825    fn test_builder() {
826        let (tx, _rx) = notification_channel(10);
827
828        let ctx = RequestContextBuilder::new()
829            .request_id(RequestId::String("req-1".to_string()))
830            .progress_token(ProgressToken::String("prog-1".to_string()))
831            .notification_sender(tx)
832            .build();
833
834        assert_eq!(ctx.request_id(), &RequestId::String("req-1".to_string()));
835        assert!(ctx.progress_token().is_some());
836    }
837
838    #[test]
839    fn test_can_sample_without_requester() {
840        let ctx = RequestContext::new(RequestId::Number(1));
841        assert!(!ctx.can_sample());
842    }
843
844    #[test]
845    fn test_can_sample_with_requester() {
846        let (request_tx, _rx) = outgoing_request_channel(10);
847        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));
848
849        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
850        assert!(ctx.can_sample());
851    }
852
853    #[tokio::test]
854    async fn test_sample_without_requester_fails() {
855        use crate::protocol::{CreateMessageParams, SamplingMessage};
856
857        let ctx = RequestContext::new(RequestId::Number(1));
858        let params = CreateMessageParams::new(vec![SamplingMessage::user("test")], 100);
859
860        let result = ctx.sample(params).await;
861        assert!(result.is_err());
862        assert!(
863            result
864                .unwrap_err()
865                .to_string()
866                .contains("Sampling not available")
867        );
868    }
869
870    #[test]
871    fn test_builder_with_client_requester() {
872        let (request_tx, _rx) = outgoing_request_channel(10);
873        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));
874
875        let ctx = RequestContextBuilder::new()
876            .request_id(RequestId::Number(1))
877            .client_requester(requester)
878            .build();
879
880        assert!(ctx.can_sample());
881    }
882
883    #[test]
884    fn test_can_elicit_without_requester() {
885        let ctx = RequestContext::new(RequestId::Number(1));
886        assert!(!ctx.can_elicit());
887    }
888
889    #[test]
890    fn test_can_elicit_with_requester() {
891        let (request_tx, _rx) = outgoing_request_channel(10);
892        let requester: ClientRequesterHandle = Arc::new(ChannelClientRequester::new(request_tx));
893
894        let ctx = RequestContext::new(RequestId::Number(1)).with_client_requester(requester);
895        assert!(ctx.can_elicit());
896    }
897
898    #[tokio::test]
899    async fn test_elicit_form_without_requester_fails() {
900        use crate::protocol::{ElicitFormSchema, ElicitMode};
901
902        let ctx = RequestContext::new(RequestId::Number(1));
903        let params = ElicitFormParams {
904            mode: Some(ElicitMode::Form),
905            message: "Enter details".to_string(),
906            requested_schema: ElicitFormSchema::new().string_field("name", None, true),
907            meta: None,
908        };
909
910        let result = ctx.elicit_form(params).await;
911        assert!(result.is_err());
912        assert!(
913            result
914                .unwrap_err()
915                .to_string()
916                .contains("Elicitation not available")
917        );
918    }
919
920    #[tokio::test]
921    async fn test_elicit_url_without_requester_fails() {
922        use crate::protocol::ElicitMode;
923
924        let ctx = RequestContext::new(RequestId::Number(1));
925        let params = ElicitUrlParams {
926            mode: Some(ElicitMode::Url),
927            elicitation_id: "test-123".to_string(),
928            message: "Please authorize".to_string(),
929            url: "https://example.com/auth".to_string(),
930            meta: None,
931        };
932
933        let result = ctx.elicit_url(params).await;
934        assert!(result.is_err());
935        assert!(
936            result
937                .unwrap_err()
938                .to_string()
939                .contains("Elicitation not available")
940        );
941    }
942
943    #[tokio::test]
944    async fn test_confirm_without_requester_fails() {
945        let ctx = RequestContext::new(RequestId::Number(1));
946
947        let result = ctx.confirm("Are you sure?").await;
948        assert!(result.is_err());
949        assert!(
950            result
951                .unwrap_err()
952                .to_string()
953                .contains("Elicitation not available")
954        );
955    }
956
957    #[tokio::test]
958    async fn test_send_log_filtered_by_level() {
959        let (tx, mut rx) = notification_channel(10);
960        let min_level = Arc::new(RwLock::new(LogLevel::Warning));
961
962        let ctx = RequestContext::new(RequestId::Number(1))
963            .with_notification_sender(tx)
964            .with_min_log_level(min_level.clone());
965
966        // Error is more severe than Warning — should pass through
967        ctx.send_log(LoggingMessageParams::new(
968            LogLevel::Error,
969            serde_json::Value::Null,
970        ));
971        let msg = rx.try_recv();
972        assert!(msg.is_ok(), "Error should pass through Warning filter");
973
974        // Warning is equal to min level — should pass through
975        ctx.send_log(LoggingMessageParams::new(
976            LogLevel::Warning,
977            serde_json::Value::Null,
978        ));
979        let msg = rx.try_recv();
980        assert!(msg.is_ok(), "Warning should pass through Warning filter");
981
982        // Info is less severe than Warning — should be filtered
983        ctx.send_log(LoggingMessageParams::new(
984            LogLevel::Info,
985            serde_json::Value::Null,
986        ));
987        let msg = rx.try_recv();
988        assert!(msg.is_err(), "Info should be filtered by Warning filter");
989
990        // Debug is less severe than Warning — should be filtered
991        ctx.send_log(LoggingMessageParams::new(
992            LogLevel::Debug,
993            serde_json::Value::Null,
994        ));
995        let msg = rx.try_recv();
996        assert!(msg.is_err(), "Debug should be filtered by Warning filter");
997    }
998
999    #[tokio::test]
1000    async fn test_send_log_level_updates_dynamically() {
1001        let (tx, mut rx) = notification_channel(10);
1002        let min_level = Arc::new(RwLock::new(LogLevel::Error));
1003
1004        let ctx = RequestContext::new(RequestId::Number(1))
1005            .with_notification_sender(tx)
1006            .with_min_log_level(min_level.clone());
1007
1008        // Info should be filtered at Error level
1009        ctx.send_log(LoggingMessageParams::new(
1010            LogLevel::Info,
1011            serde_json::Value::Null,
1012        ));
1013        assert!(
1014            rx.try_recv().is_err(),
1015            "Info should be filtered at Error level"
1016        );
1017
1018        // Dynamically update to Debug (most permissive)
1019        *min_level.write().unwrap() = LogLevel::Debug;
1020
1021        // Now Info should pass through
1022        ctx.send_log(LoggingMessageParams::new(
1023            LogLevel::Info,
1024            serde_json::Value::Null,
1025        ));
1026        assert!(
1027            rx.try_recv().is_ok(),
1028            "Info should pass through after level changed to Debug"
1029        );
1030    }
1031
1032    #[tokio::test]
1033    async fn test_send_log_no_min_level_sends_all() {
1034        let (tx, mut rx) = notification_channel(10);
1035
1036        // No min_log_level set — all messages should pass through
1037        let ctx = RequestContext::new(RequestId::Number(1)).with_notification_sender(tx);
1038
1039        ctx.send_log(LoggingMessageParams::new(
1040            LogLevel::Debug,
1041            serde_json::Value::Null,
1042        ));
1043        assert!(
1044            rx.try_recv().is_ok(),
1045            "Debug should pass when no min level is set"
1046        );
1047    }
1048}