// ultrafast_mcp_server/handlers.rs

//! Handler traits for UltraFastServer.
//!
//! This module defines the trait interfaces that server implementations implement
//! to handle the different kinds of MCP requests (tools, resources, prompts, sampling,
//! completion, roots, elicitation and resource subscriptions).

use async_trait::async_trait;
use ultrafast_mcp_core::{
    error::{MCPError, MCPResult},
    types::{
        ServerInfo,
        completion::{CompleteRequest, CompleteResponse},
        elicitation::{ElicitationRequest, ElicitationResponse},
        prompts::{GetPromptRequest, GetPromptResponse, ListPromptsRequest, ListPromptsResponse},
        resources::{
            ListResourceTemplatesRequest, ListResourceTemplatesResponse, ListResourcesRequest,
            ListResourcesResponse, ReadResourceRequest, ReadResourceResponse,
        },
        sampling::{
            ApprovalStatus, CostInfo, CreateMessageRequest, CreateMessageResponse, HumanFeedback,
            IncludeContext, ResourceContextInfo, SamplingContent, SamplingContext, SamplingRequest,
            SamplingResponse, SamplingRole, ServerContextInfo, StopReason, ToolContextInfo,
        },
        tools::{ListToolsRequest, ListToolsResponse, ToolCall, ToolResult},
    },
};

27/// Tool handler trait for processing tool calls
28#[async_trait]
29pub trait ToolHandler: Send + Sync {
30    /// Handle a tool call request
31    async fn handle_tool_call(&self, call: ToolCall) -> MCPResult<ToolResult>;
32
33    /// List available tools
34    async fn list_tools(&self, request: ListToolsRequest) -> MCPResult<ListToolsResponse>;
35}
36
37/// Resource handler trait for managing resources
38#[async_trait]
39pub trait ResourceHandler: Send + Sync {
40    /// Read a resource
41    async fn read_resource(&self, request: ReadResourceRequest) -> MCPResult<ReadResourceResponse>;
42
43    /// List available resources
44    async fn list_resources(
45        &self,
46        request: ListResourcesRequest,
47    ) -> MCPResult<ListResourcesResponse>;
48
49    /// List resource templates
50    async fn list_resource_templates(
51        &self,
52        request: ListResourceTemplatesRequest,
53    ) -> MCPResult<ListResourceTemplatesResponse>;
54
55    /// Validate resource access against roots (optional implementation)
56    /// According to MCP specification, roots are informational and not strictly enforcing.
57    /// This method provides advisory validation but does not block access if no root matches.
58    async fn validate_resource_access(
59        &self,
60        uri: &str,
61        operation: ultrafast_mcp_core::types::roots::RootOperation,
62        roots: &[ultrafast_mcp_core::types::roots::Root],
63    ) -> MCPResult<()>;
64}
65
66/// Prompt handler trait for managing prompts
67#[async_trait]
68pub trait PromptHandler: Send + Sync {
69    /// Get a specific prompt
70    async fn get_prompt(&self, request: GetPromptRequest) -> MCPResult<GetPromptResponse>;
71
72    /// List available prompts
73    async fn list_prompts(&self, request: ListPromptsRequest) -> MCPResult<ListPromptsResponse>;
74}
75
76/// Sampling handler trait for LLM completions
77#[async_trait]
78pub trait SamplingHandler: Send + Sync {
79    /// Create a message using sampling
80    async fn create_message(
81        &self,
82        request: CreateMessageRequest,
83    ) -> MCPResult<CreateMessageResponse>;
84}
85
86/// Completion handler trait for autocompletion
87#[async_trait]
88pub trait CompletionHandler: Send + Sync {
89    /// Complete a request
90    async fn complete(&self, request: CompleteRequest) -> MCPResult<CompleteResponse>;
91}
92
93/// Roots handler trait for filesystem boundary management
94#[async_trait]
95pub trait RootsHandler: Send + Sync {
96    /// List available roots
97    async fn list_roots(&self) -> MCPResult<Vec<ultrafast_mcp_core::types::roots::Root>>;
98    /// Set/update the list of roots
99    async fn set_roots(&self, roots: Vec<ultrafast_mcp_core::types::roots::Root>) -> MCPResult<()> {
100        let _ = roots;
101        Err(MCPError::method_not_found(
102            "Dynamic roots update not implemented".to_string(),
103        ))
104    }
105}
106
107/// Elicitation handler trait for user input collection
108#[async_trait]
109pub trait ElicitationHandler: Send + Sync {
110    /// Handle an elicitation request
111    async fn handle_elicitation(
112        &self,
113        request: ElicitationRequest,
114    ) -> MCPResult<ElicitationResponse>;
115}
116
117/// Resource subscription handler trait
118#[async_trait]
119pub trait ResourceSubscriptionHandler: Send + Sync {
120    /// Subscribe to a resource
121    async fn subscribe(&self, uri: String) -> MCPResult<()>;
122
123    /// Unsubscribe from a resource
124    async fn unsubscribe(&self, uri: String) -> MCPResult<()>;
125
126    /// Notify about a resource change
127    async fn notify_change(&self, uri: String, content: serde_json::Value) -> MCPResult<()>;
128}
129
130/// Handler for advanced sampling features including context collection and human-in-the-loop
131#[async_trait]
132pub trait AdvancedSamplingHandler: Send + Sync {
133    /// Collect context information for sampling requests
134    async fn collect_context(
135        &self,
136        include_context: &IncludeContext,
137        request: &SamplingRequest,
138    ) -> MCPResult<Option<SamplingContext>>;
139
140    /// Handle human-in-the-loop approval workflow
141    async fn handle_human_approval(
142        &self,
143        request: &SamplingRequest,
144        response: &SamplingResponse,
145    ) -> MCPResult<ApprovalStatus>;
146
147    /// Process human feedback and modifications
148    async fn process_human_feedback(
149        &self,
150        request: &SamplingRequest,
151        feedback: &HumanFeedback,
152    ) -> MCPResult<SamplingResponse>;
153
154    /// Estimate cost for sampling request
155    async fn estimate_cost(&self, request: &SamplingRequest) -> MCPResult<CostInfo>;
156
157    /// Validate sampling request with advanced checks
158    async fn validate_sampling_request(&self, request: &SamplingRequest) -> MCPResult<Vec<String>>;
159}
160
161/// Default implementation of advanced sampling features
162pub struct DefaultAdvancedSamplingHandler {
163    server_info: ServerInfo,
164    tools: Vec<ToolContextInfo>,
165    resources: Vec<ResourceContextInfo>,
166}
167
168impl DefaultAdvancedSamplingHandler {
169    pub fn new(server_info: ServerInfo) -> Self {
170        Self {
171            server_info,
172            tools: Vec::new(),
173            resources: Vec::new(),
174        }
175    }
176
177    pub fn with_tools(mut self, tools: Vec<ToolContextInfo>) -> Self {
178        self.tools = tools;
179        self
180    }
181
182    pub fn with_resources(mut self, resources: Vec<ResourceContextInfo>) -> Self {
183        self.resources = resources;
184        self
185    }
186}
187
188#[async_trait]
189impl AdvancedSamplingHandler for DefaultAdvancedSamplingHandler {
190    async fn collect_context(
191        &self,
192        include_context: &IncludeContext,
193        _request: &SamplingRequest,
194    ) -> MCPResult<Option<SamplingContext>> {
195        match include_context {
196            IncludeContext::None => Ok(None),
197            IncludeContext::ThisServer => {
198                let server_info = ServerContextInfo {
199                    name: self.server_info.name.clone(),
200                    version: self.server_info.version.clone(),
201                    description: self.server_info.description.clone(),
202                    capabilities: vec![
203                        "tools".to_string(),
204                        "resources".to_string(),
205                        "prompts".to_string(),
206                    ],
207                };
208
209                Ok(Some(SamplingContext {
210                    server_info: Some(server_info),
211                    available_tools: Some(self.tools.clone()),
212                    available_resources: Some(self.resources.clone()),
213                    conversation_history: None,
214                    user_preferences: None,
215                }))
216            }
217            IncludeContext::AllServers => {
218                // In a real implementation, this would collect context from all connected servers
219                let server_info = ServerContextInfo {
220                    name: self.server_info.name.clone(),
221                    version: self.server_info.version.clone(),
222                    description: self.server_info.description.clone(),
223                    capabilities: vec![
224                        "tools".to_string(),
225                        "resources".to_string(),
226                        "prompts".to_string(),
227                    ],
228                };
229
230                Ok(Some(SamplingContext {
231                    server_info: Some(server_info),
232                    available_tools: Some(self.tools.clone()),
233                    available_resources: Some(self.resources.clone()),
234                    conversation_history: None,
235                    user_preferences: None,
236                }))
237            }
238        }
239    }
240
241    async fn handle_human_approval(
242        &self,
243        request: &SamplingRequest,
244        _response: &SamplingResponse,
245    ) -> MCPResult<ApprovalStatus> {
246        // Check if human approval is required
247        if let Some(hitl) = &request.human_in_the_loop {
248            if hitl.require_prompt_approval.unwrap_or(false) {
249                // In a real implementation, this would trigger a UI prompt for approval
250                return Ok(ApprovalStatus::Pending);
251            }
252            if hitl.require_completion_approval.unwrap_or(false) {
253                // In a real implementation, this would trigger a UI prompt for approval
254                return Ok(ApprovalStatus::Pending);
255            }
256        }
257
258        Ok(ApprovalStatus::Approved)
259    }
260
261    async fn process_human_feedback(
262        &self,
263        _request: &SamplingRequest,
264        feedback: &HumanFeedback,
265    ) -> MCPResult<SamplingResponse> {
266        // In a real implementation, this would process the human feedback
267        // and potentially modify the response based on the feedback
268        Ok(SamplingResponse {
269            role: SamplingRole::Assistant,
270            content: SamplingContent::Text {
271                text: format!(
272                    "Response modified based on feedback: {}",
273                    feedback.reason.as_deref().unwrap_or("No reason provided")
274                ),
275            },
276            model: Some("human-modified".to_string()),
277            stop_reason: Some(StopReason::EndTurn),
278            approval_status: Some(ApprovalStatus::Modified),
279            request_id: None,
280            processing_time_ms: None,
281            cost_info: None,
282            included_context: None,
283            human_feedback: Some(feedback.clone()),
284            warnings: None,
285        })
286    }
287
288    async fn estimate_cost(&self, request: &SamplingRequest) -> MCPResult<CostInfo> {
289        let input_tokens = request
290            .estimate_input_tokens()
291            .map_err(MCPError::invalid_request)?;
292        let output_tokens = request.max_tokens.unwrap_or(1000);
293
294        // Simple cost estimation: $0.002 per 1K input tokens, $0.012 per 1K output tokens
295        let input_cost_cents = (input_tokens as f64 / 1000.0) * 0.2; // $0.002 * 100 = 0.2 cents
296        let output_cost_cents = (output_tokens as f64 / 1000.0) * 1.2; // $0.012 * 100 = 1.2 cents
297        let total_cost_cents = input_cost_cents + output_cost_cents;
298
299        Ok(CostInfo {
300            total_cost_cents,
301            input_cost_cents,
302            output_cost_cents,
303            input_tokens,
304            output_tokens,
305            model: "gpt-4".to_string(),
306        })
307    }
308
309    async fn validate_sampling_request(&self, request: &SamplingRequest) -> MCPResult<Vec<String>> {
310        let mut warnings = Vec::new();
311
312        // Check for potential issues
313        if request.messages.is_empty() {
314            warnings.push("No messages provided for sampling".to_string());
315        }
316
317        if let Some(temp) = request.temperature {
318            if temp > 1.0 {
319                warnings.push(
320                    "Temperature is very high, may produce unpredictable results".to_string(),
321                );
322            }
323        }
324
325        if let Some(max_tokens) = request.max_tokens {
326            if max_tokens > 10000 {
327                warnings.push("Very high max_tokens may be expensive".to_string());
328            }
329        }
330
331        if request.requires_human_approval() {
332            warnings.push("Human approval required - response may be delayed".to_string());
333        }
334
335        if request.requires_image_modality() {
336            warnings.push("Image modality detected - ensure model supports vision".to_string());
337        }
338
339        Ok(warnings)
340    }
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use ultrafast_mcp_core::error::ResourceError;
    use ultrafast_mcp_core::types::{
        resources::ResourceContent,
        roots::{Root, RootOperation, RootSecurityConfig, RootSecurityValidator},
        tools::ToolContent,
    };

    /// Tool handler stub: every call yields one "mock result" text item; no tools are listed.
    struct MockToolHandler;

    #[async_trait]
    impl ToolHandler for MockToolHandler {
        async fn handle_tool_call(&self, _call: ToolCall) -> MCPResult<ToolResult> {
            let content = vec![ToolContent::text("mock result".to_string())];
            Ok(ToolResult {
                content,
                is_error: None,
            })
        }

        async fn list_tools(&self, _request: ListToolsRequest) -> MCPResult<ListToolsResponse> {
            Ok(ListToolsResponse {
                tools: Vec::new(),
                next_cursor: None,
            })
        }
    }

    /// Resource handler stub: serves one fixed text resource and empty listings.
    struct MockResourceHandler;

    #[async_trait]
    impl ResourceHandler for MockResourceHandler {
        async fn read_resource(
            &self,
            _request: ReadResourceRequest,
        ) -> MCPResult<ReadResourceResponse> {
            let item = ResourceContent::text(
                "mock://resource".to_string(),
                "mock resource".to_string(),
            );
            Ok(ReadResourceResponse {
                contents: vec![item],
            })
        }

        async fn list_resources(
            &self,
            _request: ListResourcesRequest,
        ) -> MCPResult<ListResourcesResponse> {
            Ok(ListResourcesResponse {
                resources: Vec::new(),
                next_cursor: None,
            })
        }

        async fn list_resource_templates(
            &self,
            _request: ListResourceTemplatesRequest,
        ) -> MCPResult<ListResourceTemplatesResponse> {
            Ok(ListResourceTemplatesResponse {
                resource_templates: Vec::new(),
                next_cursor: None,
            })
        }

        /// Advisory root check: the first root whose URI is a string prefix of `uri` decides.
        /// Only a `file://` root matched by a `file://` URI goes through
        /// `RootSecurityValidator`. Every other case allows access, including no roots
        /// and no match.
        async fn validate_resource_access(
            &self,
            uri: &str,
            operation: RootOperation,
            roots: &[Root],
        ) -> MCPResult<()> {
            let first_match = roots.iter().find(|candidate| uri.starts_with(&candidate.uri));
            match first_match {
                Some(root) if root.uri.starts_with("file://") && uri.starts_with("file://") => {
                    let validator = RootSecurityValidator::default();
                    validator.validate_access(root, uri, operation).map_err(|e| {
                        MCPError::Resource(ResourceError::AccessDenied(format!(
                            "Root validation failed: {e}"
                        )))
                    })
                }
                _ => Ok(()),
            }
        }
    }

    #[tokio::test]
    async fn test_tool_handler() {
        let handler = MockToolHandler;
        let call = ToolCall {
            name: "test".to_string(),
            arguments: Some(json!({"test": "data"})),
        };

        let outcome = handler.handle_tool_call(call).await.unwrap();
        assert_eq!(outcome.content.len(), 1);
    }

    #[tokio::test]
    async fn test_resource_handler() {
        let handler = MockResourceHandler;
        let request = ReadResourceRequest {
            uri: "test://resource".to_string(),
        };

        let outcome = handler.read_resource(request).await.unwrap();
        assert_eq!(outcome.contents.len(), 1);
    }

    #[tokio::test]
    async fn test_root_validation_informational() {
        let handler = MockResourceHandler;

        let no_roots: [Root; 0] = [];
        // Root that does not prefix the test:// URI, so it must not block access.
        let unrelated_root = [Root {
            uri: "file:///tmp".to_string(),
            name: Some("Test Root".to_string()),
            security: None,
        }];
        // file:// root that matches, so the validator logic is exercised.
        let matching_root = [Root {
            uri: "file:///tmp/static/".to_string(),
            name: Some("Test Root".to_string()),
            security: Some(RootSecurityConfig {
                allow_read: true,
                ..Default::default()
            }),
        }];

        let cases: [(&str, &[Root], &str); 3] = [
            (
                "test://static/resource/1",
                &no_roots[..],
                "Should allow access when no roots are configured",
            ),
            (
                "test://static/resource/1",
                &unrelated_root[..],
                "Should allow access when no matching root is found (informational nature)",
            ),
            (
                "file:///tmp/static/resource/1",
                &matching_root[..],
                "Should allow access when matching root allows it",
            ),
        ];

        for &(uri, roots, message) in cases.iter() {
            let outcome = handler
                .validate_resource_access(uri, RootOperation::Read, roots)
                .await;
            assert!(outcome.is_ok(), "{}", message);
        }
    }
}