ultrafast_mcp_server/
context.rs

1//! Context module for UltraFastServer
2//!
3//! This module provides the Context type that allows tools and handlers to interact
4//! with the server for progress tracking, logging, and other operations.
5
6use serde_json::Value;
7use std::borrow::Cow;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tracing::{debug, error, info, warn};
11
12use ultrafast_mcp_core::{
13    error::MCPResult,
14    protocol::jsonrpc::{JsonRpcMessage, JsonRpcRequest},
15    types::notifications::{LogLevel, LoggingMessageNotification, ProgressNotification},
16};
17
18/// Simple cancellation manager for tracking cancelled requests
19#[derive(Debug, Clone)]
20pub struct CancellationManager {
21    cancelled_requests: Arc<tokio::sync::RwLock<std::collections::HashSet<String>>>,
22}
23
24impl CancellationManager {
25    pub fn new() -> Self {
26        Self {
27            cancelled_requests: Arc::new(
28                tokio::sync::RwLock::new(std::collections::HashSet::new()),
29            ),
30        }
31    }
32
33    pub async fn cancel_request(&self, request_id: &str) {
34        let mut requests = self.cancelled_requests.write().await;
35        requests.insert(request_id.to_string());
36    }
37
38    pub async fn is_cancelled(&self, request_id: &str) -> bool {
39        let requests = self.cancelled_requests.read().await;
40        requests.contains(request_id)
41    }
42
43    pub async fn clear_cancelled(&self, request_id: &str) {
44        let mut requests = self.cancelled_requests.write().await;
45        requests.remove(request_id);
46    }
47}
48
49impl Default for CancellationManager {
50    fn default() -> Self {
51        Self::new()
52    }
53}
54
55/// Logger configuration for the context
56#[derive(Debug, Clone)]
57pub struct LoggerConfig {
58    /// Minimum log level to process
59    pub min_level: LogLevel,
60    /// Whether to send notifications to client
61    pub send_notifications: bool,
62    /// Whether to include structured output
63    pub structured_output: bool,
64    /// Maximum log message length
65    pub max_message_length: usize,
66    /// Whether to include timestamps
67    pub include_timestamps: bool,
68    /// Whether to include logger name
69    pub include_logger_name: bool,
70    /// Custom logger name
71    pub logger_name: Option<String>,
72}
73
74impl Default for LoggerConfig {
75    fn default() -> Self {
76        Self {
77            min_level: LogLevel::Info,
78            send_notifications: true,
79            structured_output: true,
80            max_message_length: 4096,
81            include_timestamps: true,
82            include_logger_name: true,
83            logger_name: None,
84        }
85    }
86}
87
88/// Notification sender for sending messages to the client
89type NotificationSender = Arc<
90    dyn Fn(
91            JsonRpcMessage,
92        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = MCPResult<()>> + Send>>
93        + Send
94        + Sync,
95>;
96
97/// Context for tool and handler execution
98///
99/// Provides access to server functionality like progress tracking, logging,
100/// and request metadata.
101#[derive(Clone)]
102pub struct Context {
103    session_id: Option<String>,
104    request_id: Option<String>,
105    metadata: HashMap<String, serde_json::Value>,
106    logger_config: LoggerConfig,
107    notification_sender: Option<NotificationSender>,
108    cancellation_manager: Option<Arc<CancellationManager>>,
109}
110
111impl std::fmt::Debug for Context {
112    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113        f.debug_struct("Context")
114            .field("session_id", &self.session_id)
115            .field("request_id", &self.request_id)
116            .field("metadata", &self.metadata)
117            .field("logger_config", &self.logger_config)
118            .field("notification_sender", &self.notification_sender.is_some())
119            .finish()
120    }
121}
122
123impl Context {
124    /// Create a new empty context
125    pub fn new() -> Self {
126        Self {
127            session_id: None,
128            request_id: None,
129            metadata: HashMap::new(),
130            logger_config: LoggerConfig::default(),
131            notification_sender: None,
132            cancellation_manager: None,
133        }
134    }
135
136    /// Create a context with session ID
137    pub fn with_session_id(mut self, session_id: String) -> Self {
138        self.session_id = Some(session_id);
139        self
140    }
141
142    /// Create a context with request ID
143    pub fn with_request_id(mut self, request_id: String) -> Self {
144        self.request_id = Some(request_id);
145        self
146    }
147
148    /// Add metadata to the context
149    pub fn with_metadata(mut self, key: String, value: serde_json::Value) -> Self {
150        self.metadata.insert(key, value);
151        self
152    }
153
154    /// Configure the logger for this context
155    pub fn with_logger_config(mut self, config: LoggerConfig) -> Self {
156        self.logger_config = config;
157        self
158    }
159
160    /// Set the notification sender
161    pub fn with_notification_sender(mut self, sender: NotificationSender) -> Self {
162        self.notification_sender = Some(sender);
163        self
164    }
165
166    /// Set the cancellation manager
167    pub fn with_cancellation_manager(mut self, manager: Arc<CancellationManager>) -> Self {
168        self.cancellation_manager = Some(manager);
169        self
170    }
171
172    /// Get the session ID
173    pub fn session_id(&self) -> Option<&str> {
174        self.session_id.as_deref()
175    }
176
177    /// Get the request ID
178    pub fn request_id(&self) -> Option<&str> {
179        self.request_id.as_deref()
180    }
181
182    /// Get metadata value
183    pub fn get_metadata(&self, key: &str) -> Option<&serde_json::Value> {
184        self.metadata.get(key)
185    }
186
187    /// Set the minimum log level
188    pub fn set_log_level(&mut self, level: LogLevel) {
189        self.logger_config.min_level = level;
190    }
191
192    /// Get the current minimum log level
193    pub fn get_log_level(&self) -> &LogLevel {
194        &self.logger_config.min_level
195    }
196
197    /// Check if a log level should be processed
198    fn should_log(&self, level: &LogLevel) -> bool {
199        let level_priority = log_level_priority(level);
200        let min_priority = log_level_priority(&self.logger_config.min_level);
201        level_priority >= min_priority
202    }
203
204    /// Send a progress update
205    ///
206    /// # Arguments
207    /// * `message` - Progress message
208    /// * `progress` - Current progress value (0.0 to 1.0)
209    /// * `total` - Optional total value
210    pub async fn progress(
211        &self,
212        message: &str,
213        progress: f64,
214        total: Option<f64>,
215    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
216        // Log progress for debugging
217        if let Some(total) = total {
218            info!(
219                "Progress: {} - {:.2}/{:.2} ({:.1}%)",
220                message,
221                progress,
222                total,
223                (progress / total) * 100.0
224            );
225        } else {
226            info!("Progress: {} - {:.2}", message, progress);
227        }
228
229        // Send progress notification if sender is available
230        if let Some(sender) = &self.notification_sender {
231            let progress_token = self
232                .request_id()
233                .map(|id| serde_json::Value::String(id.to_string()))
234                .unwrap_or(serde_json::Value::Null);
235
236            let mut notification = ProgressNotification::new(progress_token, progress)
237                .with_message(message.to_string());
238
239            if let Some(total) = total {
240                notification = notification.with_total(total);
241            }
242
243            let notification_request = JsonRpcRequest {
244                jsonrpc: Cow::Borrowed("2.0"),
245                id: None, // Notifications don't have IDs
246                method: "notifications/progress".to_string(),
247                params: Some(serde_json::to_value(notification)?),
248                meta: std::collections::HashMap::new(),
249            };
250
251            sender(JsonRpcMessage::Notification(notification_request)).await?;
252        }
253
254        Ok(())
255    }
256
257    /// Log a debug message
258    pub async fn log_debug(
259        &self,
260        message: &str,
261    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
262        self.log_with_level(LogLevel::Debug, message, None).await
263    }
264
265    /// Log a debug message with structured data
266    pub async fn log_debug_structured(
267        &self,
268        message: &str,
269        data: Value,
270    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
271        self.log_with_level(LogLevel::Debug, message, Some(data))
272            .await
273    }
274
275    /// Log an info message
276    pub async fn log_info(
277        &self,
278        message: &str,
279    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
280        self.log_with_level(LogLevel::Info, message, None).await
281    }
282
283    /// Log an info message with structured data
284    pub async fn log_info_structured(
285        &self,
286        message: &str,
287        data: Value,
288    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
289        self.log_with_level(LogLevel::Info, message, Some(data))
290            .await
291    }
292
293    /// Log a notice message
294    pub async fn log_notice(
295        &self,
296        message: &str,
297    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
298        self.log_with_level(LogLevel::Notice, message, None).await
299    }
300
301    /// Log a notice message with structured data
302    pub async fn log_notice_structured(
303        &self,
304        message: &str,
305        data: Value,
306    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
307        self.log_with_level(LogLevel::Notice, message, Some(data))
308            .await
309    }
310
311    /// Log a warning message
312    pub async fn log_warn(
313        &self,
314        message: &str,
315    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
316        self.log_with_level(LogLevel::Warning, message, None).await
317    }
318
319    /// Log a warning message with structured data
320    pub async fn log_warn_structured(
321        &self,
322        message: &str,
323        data: Value,
324    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
325        self.log_with_level(LogLevel::Warning, message, Some(data))
326            .await
327    }
328
329    /// Log an error message
330    pub async fn log_error(
331        &self,
332        message: &str,
333    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
334        self.log_with_level(LogLevel::Error, message, None).await
335    }
336
337    /// Log an error message with structured data
338    pub async fn log_error_structured(
339        &self,
340        message: &str,
341        data: Value,
342    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
343        self.log_with_level(LogLevel::Error, message, Some(data))
344            .await
345    }
346
347    /// Log a critical message
348    pub async fn log_critical(
349        &self,
350        message: &str,
351    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
352        self.log_with_level(LogLevel::Critical, message, None).await
353    }
354
355    /// Log a critical message with structured data
356    pub async fn log_critical_structured(
357        &self,
358        message: &str,
359        data: Value,
360    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
361        self.log_with_level(LogLevel::Critical, message, Some(data))
362            .await
363    }
364
365    /// Log an alert message
366    pub async fn log_alert(
367        &self,
368        message: &str,
369    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
370        self.log_with_level(LogLevel::Alert, message, None).await
371    }
372
373    /// Log an alert message with structured data
374    pub async fn log_alert_structured(
375        &self,
376        message: &str,
377        data: Value,
378    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
379        self.log_with_level(LogLevel::Alert, message, Some(data))
380            .await
381    }
382
383    /// Log an emergency message
384    pub async fn log_emergency(
385        &self,
386        message: &str,
387    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
388        self.log_with_level(LogLevel::Emergency, message, None)
389            .await
390    }
391
392    /// Log an emergency message with structured data
393    pub async fn log_emergency_structured(
394        &self,
395        message: &str,
396        data: Value,
397    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
398        self.log_with_level(LogLevel::Emergency, message, Some(data))
399            .await
400    }
401
402    /// Internal method to log with a specific level
403    async fn log_with_level(
404        &self,
405        level: LogLevel,
406        message: &str,
407        structured_data: Option<Value>,
408    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
409        // Check if this level should be logged
410        if !self.should_log(&level) {
411            return Ok(());
412        }
413
414        // Truncate message if too long
415        let truncated_message = if message.len() > self.logger_config.max_message_length {
416            let mut truncated = message[..self.logger_config.max_message_length - 3].to_string();
417            truncated.push_str("...");
418            truncated
419        } else {
420            message.to_string()
421        };
422
423        // Create structured log data
424        let log_data = if self.logger_config.structured_output {
425            let mut data_obj = serde_json::Map::new();
426
427            // Add basic message
428            data_obj.insert(
429                "message".to_string(),
430                Value::String(truncated_message.clone()),
431            );
432
433            // Add request context
434            if let Some(request_id) = &self.request_id {
435                data_obj.insert("request_id".to_string(), Value::String(request_id.clone()));
436            }
437
438            if let Some(session_id) = &self.session_id {
439                data_obj.insert("session_id".to_string(), Value::String(session_id.clone()));
440            }
441
442            // Add timestamp if configured
443            if self.logger_config.include_timestamps {
444                let timestamp = chrono::Utc::now().to_rfc3339();
445                data_obj.insert("timestamp".to_string(), Value::String(timestamp));
446            }
447
448            // Add logger name if configured
449            if self.logger_config.include_logger_name {
450                let logger_name = self
451                    .logger_config
452                    .logger_name
453                    .as_deref()
454                    .unwrap_or("ultrafast-mcp-server");
455                data_obj.insert("logger".to_string(), Value::String(logger_name.to_string()));
456            }
457
458            // Add level
459            data_obj.insert(
460                "level".to_string(),
461                Value::String(format!("{level:?}").to_lowercase()),
462            );
463
464            // Add any structured data
465            if let Some(data) = structured_data {
466                data_obj.insert("data".to_string(), data);
467            }
468
469            // Add metadata
470            if !self.metadata.is_empty() {
471                data_obj.insert(
472                    "metadata".to_string(),
473                    Value::Object(
474                        self.metadata
475                            .iter()
476                            .map(|(k, v)| (k.clone(), v.clone()))
477                            .collect(),
478                    ),
479                );
480            }
481
482            Value::Object(data_obj)
483        } else {
484            // Simple string message
485            Value::String(truncated_message.clone())
486        };
487
488        // Log to tracing system based on level
489        let request_context = self.request_id.as_deref().unwrap_or("unknown");
490        match level {
491            LogLevel::Debug => debug!("[{}] {}", request_context, truncated_message),
492            LogLevel::Info => info!("[{}] {}", request_context, truncated_message),
493            LogLevel::Notice => info!("[{}] NOTICE: {}", request_context, truncated_message),
494            LogLevel::Warning => warn!("[{}] {}", request_context, truncated_message),
495            LogLevel::Error => error!("[{}] {}", request_context, truncated_message),
496            LogLevel::Critical => error!("[{}] CRITICAL: {}", request_context, truncated_message),
497            LogLevel::Alert => error!("[{}] ALERT: {}", request_context, truncated_message),
498            LogLevel::Emergency => error!("[{}] EMERGENCY: {}", request_context, truncated_message),
499        }
500
501        // Send logging notification to client if configured and sender is available
502        if self.logger_config.send_notifications {
503            if let Some(sender) = &self.notification_sender {
504                let logger_name = self
505                    .logger_config
506                    .logger_name
507                    .as_deref()
508                    .unwrap_or("ultrafast-mcp-server");
509
510                let notification = LoggingMessageNotification::new(level, log_data)
511                    .with_logger(logger_name.to_string());
512
513                let notification_request = JsonRpcRequest {
514                    jsonrpc: Cow::Borrowed("2.0"),
515                    id: None, // Notifications don't have IDs
516                    method: "notifications/message".to_string(),
517                    params: Some(serde_json::to_value(notification)?),
518                    meta: std::collections::HashMap::new(),
519                };
520
521                // Send notification but don't fail if it doesn't work
522                if let Err(e) = sender(JsonRpcMessage::Notification(notification_request)).await {
523                    // Log the error but don't propagate it
524                    error!("Failed to send logging notification: {}", e);
525                }
526            }
527        }
528
529        Ok(())
530    }
531
532    /// Check if the current request has been cancelled
533    pub async fn is_cancelled(&self) -> bool {
534        if let Some(cancellation_manager) = &self.cancellation_manager {
535            if let Some(request_id) = &self.request_id {
536                cancellation_manager.is_cancelled(request_id).await
537            } else {
538                false
539            }
540        } else {
541            false
542        }
543    }
544}
545
546impl Default for Context {
547    fn default() -> Self {
548        Self::new()
549    }
550}
551
552/// Get numeric priority for log level (higher = more urgent)
553fn log_level_priority(level: &LogLevel) -> u8 {
554    match level {
555        LogLevel::Debug => 0,
556        LogLevel::Info => 1,
557        LogLevel::Notice => 2,
558        LogLevel::Warning => 3,
559        LogLevel::Error => 4,
560        LogLevel::Critical => 5,
561        LogLevel::Alert => 6,
562        LogLevel::Emergency => 7,
563    }
564}
565
566/// Structured logger builder for easy configuration
567pub struct ContextLogger {
568    config: LoggerConfig,
569}
570
571impl ContextLogger {
572    pub fn new() -> Self {
573        Self {
574            config: LoggerConfig::default(),
575        }
576    }
577
578    pub fn with_min_level(mut self, level: LogLevel) -> Self {
579        self.config.min_level = level;
580        self
581    }
582
583    pub fn with_notifications(mut self, send_notifications: bool) -> Self {
584        self.config.send_notifications = send_notifications;
585        self
586    }
587
588    pub fn with_structured_output(mut self, structured: bool) -> Self {
589        self.config.structured_output = structured;
590        self
591    }
592
593    pub fn with_max_message_length(mut self, length: usize) -> Self {
594        self.config.max_message_length = length;
595        self
596    }
597
598    pub fn with_timestamps(mut self, include: bool) -> Self {
599        self.config.include_timestamps = include;
600        self
601    }
602
603    pub fn with_logger_name(mut self, name: String) -> Self {
604        self.config.logger_name = Some(name);
605        self
606    }
607
608    pub fn build(self) -> LoggerConfig {
609        self.config
610    }
611}
612
613impl Default for ContextLogger {
614    fn default() -> Self {
615        Self::new()
616    }
617}
618
#[cfg(test)]
mod tests {
    use super::*;

    // Synchronous: nothing here awaits, so no async runtime is needed.
    #[test]
    fn test_context_creation() {
        let ctx = Context::new()
            .with_session_id("session-123".to_string())
            .with_request_id("request-456".to_string())
            .with_metadata("key".to_string(), serde_json::json!("value"));

        assert_eq!(ctx.session_id(), Some("session-123"));
        assert_eq!(ctx.request_id(), Some("request-456"));
        assert_eq!(ctx.get_metadata("key"), Some(&serde_json::json!("value")));
        assert_eq!(ctx.get_metadata("missing"), None);
    }

    #[tokio::test]
    async fn test_context_logging() {
        let ctx = Context::new()
            .with_request_id("test-request".to_string())
            .with_metadata("tool".to_string(), serde_json::json!("demo"));

        // None of these should panic or error, including a level below the default
        // minimum (Debug) and a structured payload.
        ctx.log_debug("Test debug message").await.unwrap();
        ctx.log_info("Test info message").await.unwrap();
        ctx.log_notice_structured("Test notice message", serde_json::json!({ "n": 1 }))
            .await
            .unwrap();
        ctx.log_warn("Test warning message").await.unwrap();
        ctx.log_error("Test error message").await.unwrap();
        ctx.log_critical("Test critical message").await.unwrap();
    }

    #[tokio::test]
    async fn test_context_progress() {
        let ctx = Context::new();

        // Test progress tracking
        ctx.progress("Starting operation", 0.0, Some(1.0))
            .await
            .unwrap();
        ctx.progress("Halfway done", 0.5, Some(1.0)).await.unwrap();
        ctx.progress("Completed", 1.0, Some(1.0)).await.unwrap();

        // Test without total
        ctx.progress("Indeterminate progress", 0.3, None)
            .await
            .unwrap();

        // A zero total must not fail.
        ctx.progress("Empty batch", 0.0, Some(0.0)).await.unwrap();
    }

    #[test]
    fn test_log_level_filtering() {
        let mut ctx = Context::new();

        // The default minimum level is Info.
        assert!(!ctx.should_log(&LogLevel::Debug));
        assert!(ctx.should_log(&LogLevel::Info));
        assert!(ctx.should_log(&LogLevel::Emergency));

        ctx.set_log_level(LogLevel::Error);
        assert!(!ctx.should_log(&LogLevel::Warning));
        assert!(ctx.should_log(&LogLevel::Error));
        assert!(ctx.should_log(&LogLevel::Critical));
    }

    #[tokio::test]
    async fn test_cancellation() {
        let manager = Arc::new(CancellationManager::new());
        let ctx = Context::new()
            .with_request_id("req-1".to_string())
            .with_cancellation_manager(Arc::clone(&manager));

        assert!(!ctx.is_cancelled().await);
        manager.cancel_request("req-1").await;
        assert!(ctx.is_cancelled().await);
        manager.clear_cancelled("req-1").await;
        assert!(!ctx.is_cancelled().await);

        // Without a request ID there is nothing to look up.
        manager.cancel_request("req-2").await;
        let anonymous = Context::new().with_cancellation_manager(manager);
        assert!(!anonymous.is_cancelled().await);

        // Without a manager a request is never cancelled.
        let unmanaged = Context::new().with_request_id("req-2".to_string());
        assert!(!unmanaged.is_cancelled().await);
    }

    #[test]
    fn test_logger_builder() {
        let config = ContextLogger::new()
            .with_min_level(LogLevel::Warning)
            .with_notifications(false)
            .with_structured_output(false)
            .with_max_message_length(128)
            .with_timestamps(false)
            .with_logger_name("custom".to_string())
            .build();

        assert!(matches!(config.min_level, LogLevel::Warning));
        assert!(!config.send_notifications);
        assert!(!config.structured_output);
        assert_eq!(config.max_message_length, 128);
        assert!(!config.include_timestamps);
        assert_eq!(config.logger_name.as_deref(), Some("custom"));
    }
}