1use async_trait::async_trait;
7use ultrafast_mcp_core::{
8 error::{MCPError, MCPResult},
9 types::{
10 ServerInfo,
11 completion::{CompleteRequest, CompleteResponse},
12 elicitation::{ElicitationRequest, ElicitationResponse},
13 prompts::{GetPromptRequest, GetPromptResponse, ListPromptsRequest, ListPromptsResponse},
14 resources::{
15 ListResourceTemplatesRequest, ListResourceTemplatesResponse, ListResourcesRequest,
16 ListResourcesResponse, ReadResourceRequest, ReadResourceResponse,
17 },
18 sampling::{
19 ApprovalStatus, CostInfo, CreateMessageRequest, CreateMessageResponse, HumanFeedback,
20 IncludeContext, ResourceContextInfo, SamplingContent, SamplingContext, SamplingRequest,
21 SamplingResponse, SamplingRole, ServerContextInfo, StopReason, ToolContextInfo,
22 },
23 tools::{ListToolsRequest, ListToolsResponse, ToolCall, ToolResult},
24 },
25};
26
27#[async_trait]
29pub trait ToolHandler: Send + Sync {
30 async fn handle_tool_call(&self, call: ToolCall) -> MCPResult<ToolResult>;
32
33 async fn list_tools(&self, request: ListToolsRequest) -> MCPResult<ListToolsResponse>;
35}
36
37#[async_trait]
39pub trait ResourceHandler: Send + Sync {
40 async fn read_resource(&self, request: ReadResourceRequest) -> MCPResult<ReadResourceResponse>;
42
43 async fn list_resources(
45 &self,
46 request: ListResourcesRequest,
47 ) -> MCPResult<ListResourcesResponse>;
48
49 async fn list_resource_templates(
51 &self,
52 request: ListResourceTemplatesRequest,
53 ) -> MCPResult<ListResourceTemplatesResponse>;
54
55 async fn validate_resource_access(
59 &self,
60 uri: &str,
61 operation: ultrafast_mcp_core::types::roots::RootOperation,
62 roots: &[ultrafast_mcp_core::types::roots::Root],
63 ) -> MCPResult<()>;
64}
65
66#[async_trait]
68pub trait PromptHandler: Send + Sync {
69 async fn get_prompt(&self, request: GetPromptRequest) -> MCPResult<GetPromptResponse>;
71
72 async fn list_prompts(&self, request: ListPromptsRequest) -> MCPResult<ListPromptsResponse>;
74}
75
76#[async_trait]
78pub trait SamplingHandler: Send + Sync {
79 async fn create_message(
81 &self,
82 request: CreateMessageRequest,
83 ) -> MCPResult<CreateMessageResponse>;
84}
85
86#[async_trait]
88pub trait CompletionHandler: Send + Sync {
89 async fn complete(&self, request: CompleteRequest) -> MCPResult<CompleteResponse>;
91}
92
93#[async_trait]
95pub trait RootsHandler: Send + Sync {
96 async fn list_roots(&self) -> MCPResult<Vec<ultrafast_mcp_core::types::roots::Root>>;
98 async fn set_roots(&self, roots: Vec<ultrafast_mcp_core::types::roots::Root>) -> MCPResult<()> {
100 let _ = roots;
101 Err(MCPError::method_not_found(
102 "Dynamic roots update not implemented".to_string(),
103 ))
104 }
105}
106
107#[async_trait]
109pub trait ElicitationHandler: Send + Sync {
110 async fn handle_elicitation(
112 &self,
113 request: ElicitationRequest,
114 ) -> MCPResult<ElicitationResponse>;
115}
116
117#[async_trait]
119pub trait ResourceSubscriptionHandler: Send + Sync {
120 async fn subscribe(&self, uri: String) -> MCPResult<()>;
122
123 async fn unsubscribe(&self, uri: String) -> MCPResult<()>;
125
126 async fn notify_change(&self, uri: String, content: serde_json::Value) -> MCPResult<()>;
128}
129
130#[async_trait]
132pub trait AdvancedSamplingHandler: Send + Sync {
133 async fn collect_context(
135 &self,
136 include_context: &IncludeContext,
137 request: &SamplingRequest,
138 ) -> MCPResult<Option<SamplingContext>>;
139
140 async fn handle_human_approval(
142 &self,
143 request: &SamplingRequest,
144 response: &SamplingResponse,
145 ) -> MCPResult<ApprovalStatus>;
146
147 async fn process_human_feedback(
149 &self,
150 request: &SamplingRequest,
151 feedback: &HumanFeedback,
152 ) -> MCPResult<SamplingResponse>;
153
154 async fn estimate_cost(&self, request: &SamplingRequest) -> MCPResult<CostInfo>;
156
157 async fn validate_sampling_request(&self, request: &SamplingRequest) -> MCPResult<Vec<String>>;
159}
160
161pub struct DefaultAdvancedSamplingHandler {
163 server_info: ServerInfo,
164 tools: Vec<ToolContextInfo>,
165 resources: Vec<ResourceContextInfo>,
166}
167
168impl DefaultAdvancedSamplingHandler {
169 pub fn new(server_info: ServerInfo) -> Self {
170 Self {
171 server_info,
172 tools: Vec::new(),
173 resources: Vec::new(),
174 }
175 }
176
177 pub fn with_tools(mut self, tools: Vec<ToolContextInfo>) -> Self {
178 self.tools = tools;
179 self
180 }
181
182 pub fn with_resources(mut self, resources: Vec<ResourceContextInfo>) -> Self {
183 self.resources = resources;
184 self
185 }
186}
187
188#[async_trait]
189impl AdvancedSamplingHandler for DefaultAdvancedSamplingHandler {
190 async fn collect_context(
191 &self,
192 include_context: &IncludeContext,
193 _request: &SamplingRequest,
194 ) -> MCPResult<Option<SamplingContext>> {
195 match include_context {
196 IncludeContext::None => Ok(None),
197 IncludeContext::ThisServer => {
198 let server_info = ServerContextInfo {
199 name: self.server_info.name.clone(),
200 version: self.server_info.version.clone(),
201 description: self.server_info.description.clone(),
202 capabilities: vec![
203 "tools".to_string(),
204 "resources".to_string(),
205 "prompts".to_string(),
206 ],
207 };
208
209 Ok(Some(SamplingContext {
210 server_info: Some(server_info),
211 available_tools: Some(self.tools.clone()),
212 available_resources: Some(self.resources.clone()),
213 conversation_history: None,
214 user_preferences: None,
215 }))
216 }
217 IncludeContext::AllServers => {
218 let server_info = ServerContextInfo {
220 name: self.server_info.name.clone(),
221 version: self.server_info.version.clone(),
222 description: self.server_info.description.clone(),
223 capabilities: vec![
224 "tools".to_string(),
225 "resources".to_string(),
226 "prompts".to_string(),
227 ],
228 };
229
230 Ok(Some(SamplingContext {
231 server_info: Some(server_info),
232 available_tools: Some(self.tools.clone()),
233 available_resources: Some(self.resources.clone()),
234 conversation_history: None,
235 user_preferences: None,
236 }))
237 }
238 }
239 }
240
241 async fn handle_human_approval(
242 &self,
243 request: &SamplingRequest,
244 _response: &SamplingResponse,
245 ) -> MCPResult<ApprovalStatus> {
246 if let Some(hitl) = &request.human_in_the_loop {
248 if hitl.require_prompt_approval.unwrap_or(false) {
249 return Ok(ApprovalStatus::Pending);
251 }
252 if hitl.require_completion_approval.unwrap_or(false) {
253 return Ok(ApprovalStatus::Pending);
255 }
256 }
257
258 Ok(ApprovalStatus::Approved)
259 }
260
261 async fn process_human_feedback(
262 &self,
263 _request: &SamplingRequest,
264 feedback: &HumanFeedback,
265 ) -> MCPResult<SamplingResponse> {
266 Ok(SamplingResponse {
269 role: SamplingRole::Assistant,
270 content: SamplingContent::Text {
271 text: format!(
272 "Response modified based on feedback: {}",
273 feedback.reason.as_deref().unwrap_or("No reason provided")
274 ),
275 },
276 model: Some("human-modified".to_string()),
277 stop_reason: Some(StopReason::EndTurn),
278 approval_status: Some(ApprovalStatus::Modified),
279 request_id: None,
280 processing_time_ms: None,
281 cost_info: None,
282 included_context: None,
283 human_feedback: Some(feedback.clone()),
284 warnings: None,
285 })
286 }
287
288 async fn estimate_cost(&self, request: &SamplingRequest) -> MCPResult<CostInfo> {
289 let input_tokens = request
290 .estimate_input_tokens()
291 .map_err(MCPError::invalid_request)?;
292 let output_tokens = request.max_tokens.unwrap_or(1000);
293
294 let input_cost_cents = (input_tokens as f64 / 1000.0) * 0.2; let output_cost_cents = (output_tokens as f64 / 1000.0) * 1.2; let total_cost_cents = input_cost_cents + output_cost_cents;
298
299 Ok(CostInfo {
300 total_cost_cents,
301 input_cost_cents,
302 output_cost_cents,
303 input_tokens,
304 output_tokens,
305 model: "gpt-4".to_string(),
306 })
307 }
308
309 async fn validate_sampling_request(&self, request: &SamplingRequest) -> MCPResult<Vec<String>> {
310 let mut warnings = Vec::new();
311
312 if request.messages.is_empty() {
314 warnings.push("No messages provided for sampling".to_string());
315 }
316
317 if let Some(temp) = request.temperature {
318 if temp > 1.0 {
319 warnings.push(
320 "Temperature is very high, may produce unpredictable results".to_string(),
321 );
322 }
323 }
324
325 if let Some(max_tokens) = request.max_tokens {
326 if max_tokens > 10000 {
327 warnings.push("Very high max_tokens may be expensive".to_string());
328 }
329 }
330
331 if request.requires_human_approval() {
332 warnings.push("Human approval required - response may be delayed".to_string());
333 }
334
335 if request.requires_image_modality() {
336 warnings.push("Image modality detected - ensure model supports vision".to_string());
337 }
338
339 Ok(warnings)
340 }
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use ultrafast_mcp_core::types::roots::{
        Root, RootOperation, RootSecurityConfig, RootSecurityValidator,
    };

    /// Tool handler stub: every call yields one text item; the tool list is empty.
    struct MockToolHandler;

    #[async_trait]
    impl ToolHandler for MockToolHandler {
        async fn handle_tool_call(&self, _call: ToolCall) -> MCPResult<ToolResult> {
            let content = vec![ultrafast_mcp_core::types::tools::ToolContent::text(
                "mock result".to_string(),
            )];
            Ok(ToolResult {
                content,
                is_error: None,
            })
        }

        async fn list_tools(&self, _request: ListToolsRequest) -> MCPResult<ListToolsResponse> {
            Ok(ListToolsResponse {
                tools: Vec::new(),
                next_cursor: None,
            })
        }
    }

    /// Resource handler stub with a simplified root-access check.
    struct MockResourceHandler;

    #[async_trait]
    impl ResourceHandler for MockResourceHandler {
        async fn read_resource(
            &self,
            _request: ReadResourceRequest,
        ) -> MCPResult<ReadResourceResponse> {
            let item = ultrafast_mcp_core::types::resources::ResourceContent::text(
                "mock://resource".to_string(),
                "mock resource".to_string(),
            );
            Ok(ReadResourceResponse {
                contents: vec![item],
            })
        }

        async fn list_resources(
            &self,
            _request: ListResourcesRequest,
        ) -> MCPResult<ListResourcesResponse> {
            Ok(ListResourcesResponse {
                resources: Vec::new(),
                next_cursor: None,
            })
        }

        async fn list_resource_templates(
            &self,
            _request: ListResourceTemplatesRequest,
        ) -> MCPResult<ListResourceTemplatesResponse> {
            Ok(ListResourceTemplatesResponse {
                resource_templates: Vec::new(),
                next_cursor: None,
            })
        }

        async fn validate_resource_access(
            &self,
            uri: &str,
            operation: RootOperation,
            roots: &[Root],
        ) -> MCPResult<()> {
            // The first root whose URI is a prefix of `uri` decides. No match,
            // including an empty root list, means access is allowed.
            let root = match roots
                .iter()
                .find(|candidate| uri.starts_with(candidate.uri.as_str()))
            {
                Some(root) => root,
                None => return Ok(()),
            };

            // Only file:// roots are enforced; other schemes are informational.
            if !root.uri.starts_with("file://") || !uri.starts_with("file://") {
                return Ok(());
            }

            RootSecurityValidator::default()
                .validate_access(root, uri, operation)
                .map_err(|err| {
                    MCPError::Resource(ultrafast_mcp_core::error::ResourceError::AccessDenied(
                        format!("Root validation failed: {err}"),
                    ))
                })
        }
    }

    #[tokio::test]
    async fn test_tool_handler() {
        let call = ToolCall {
            name: "test".to_string(),
            arguments: Some(json!({"test": "data"})),
        };

        let outcome = MockToolHandler.handle_tool_call(call).await.unwrap();
        assert_eq!(outcome.content.len(), 1);
    }

    #[tokio::test]
    async fn test_resource_handler() {
        let request = ReadResourceRequest {
            uri: "test://resource".to_string(),
        };

        let outcome = MockResourceHandler.read_resource(request).await.unwrap();
        assert_eq!(outcome.contents.len(), 1);
    }

    #[tokio::test]
    async fn test_root_validation_informational() {
        let handler = MockResourceHandler;
        let static_uri = "test://static/resource/1";

        // No roots configured at all.
        let outcome = handler
            .validate_resource_access(static_uri, RootOperation::Read, &[])
            .await;
        assert!(
            outcome.is_ok(),
            "Should allow access when no roots are configured"
        );

        // A root exists but does not cover the requested URI.
        let unrelated_roots = [Root {
            uri: "file:///tmp".to_string(),
            name: Some("Test Root".to_string()),
            security: None,
        }];
        let outcome = handler
            .validate_resource_access(static_uri, RootOperation::Read, &unrelated_roots)
            .await;
        assert!(
            outcome.is_ok(),
            "Should allow access when no matching root is found (informational nature)"
        );

        // A matching file:// root that explicitly allows reads.
        let covering_roots = [Root {
            uri: "file:///tmp/static/".to_string(),
            name: Some("Test Root".to_string()),
            security: Some(RootSecurityConfig {
                allow_read: true,
                ..Default::default()
            }),
        }];
        let outcome = handler
            .validate_resource_access(
                "file:///tmp/static/resource/1",
                RootOperation::Read,
                &covering_roots,
            )
            .await;
        assert!(
            outcome.is_ok(),
            "Should allow access when matching root allows it"
        );
    }
}