turbomcp_server/
sampling.rs

//! Server-initiated sampling support for TurboMCP
//!
//! This module provides the [`SamplingExt`] extension trait, which lets tools send
//! sampling, elicitation, and roots-listing requests to clients, enabling
//! server-initiated LLM interactions.
5
6use crate::{ServerError, ServerResult};
7use turbomcp_core::RequestContext;
8use turbomcp_protocol::types::{
9    CreateMessageRequest, CreateMessageResult, ElicitRequest, ElicitResult, ListRootsResult,
10};
11
12/// Extension trait for RequestContext to provide sampling capabilities
13///
14/// Note: We use `async-trait` to address the async fn in trait warning
15#[async_trait::async_trait]
16pub trait SamplingExt {
17    /// Send a sampling/createMessage request to the client
18    async fn create_message(
19        &self,
20        request: CreateMessageRequest,
21    ) -> ServerResult<CreateMessageResult>;
22
23    /// Send an elicitation request to the client for user input
24    async fn elicit(&self, request: ElicitRequest) -> ServerResult<ElicitResult>;
25
26    /// List client's root capabilities
27    async fn list_roots(&self) -> ServerResult<ListRootsResult>;
28}
29
30#[async_trait::async_trait]
31impl SamplingExt for RequestContext {
32    async fn create_message(
33        &self,
34        request: CreateMessageRequest,
35    ) -> ServerResult<CreateMessageResult> {
36        let capabilities = self
37            .server_capabilities()
38            .ok_or_else(|| ServerError::Handler {
39                message: "No server capabilities available for sampling requests".to_string(),
40                context: Some("sampling".to_string()),
41            })?;
42
43        let request_json = serde_json::to_value(request).map_err(|e| ServerError::Handler {
44            message: format!("Failed to serialize request: {}", e),
45            context: Some("sampling".to_string()),
46        })?;
47
48        let result_json = capabilities
49            .create_message(request_json)
50            .await
51            .map_err(|e| ServerError::Handler {
52                message: format!("Sampling request failed: {}", e),
53                context: Some("sampling".to_string()),
54            })?;
55
56        serde_json::from_value(result_json).map_err(|e| ServerError::Handler {
57            message: format!("Failed to deserialize response: {}", e),
58            context: Some("sampling".to_string()),
59        })
60    }
61
62    async fn elicit(&self, request: ElicitRequest) -> ServerResult<ElicitResult> {
63        let capabilities = self
64            .server_capabilities()
65            .ok_or_else(|| ServerError::Handler {
66                message: "No server capabilities available for elicitation requests".to_string(),
67                context: Some("elicitation".to_string()),
68            })?;
69
70        let request_json = serde_json::to_value(request).map_err(|e| ServerError::Handler {
71            message: format!("Failed to serialize request: {}", e),
72            context: Some("elicitation".to_string()),
73        })?;
74
75        let result_json =
76            capabilities
77                .elicit(request_json)
78                .await
79                .map_err(|e| ServerError::Handler {
80                    message: format!("Elicitation request failed: {}", e),
81                    context: Some("elicitation".to_string()),
82                })?;
83
84        serde_json::from_value(result_json).map_err(|e| ServerError::Handler {
85            message: format!("Failed to deserialize response: {}", e),
86            context: Some("elicitation".to_string()),
87        })
88    }
89
90    async fn list_roots(&self) -> ServerResult<ListRootsResult> {
91        let capabilities = self
92            .server_capabilities()
93            .ok_or_else(|| ServerError::Handler {
94                message: "No server capabilities available for roots listing".to_string(),
95                context: Some("roots".to_string()),
96            })?;
97
98        let result_json = capabilities
99            .list_roots()
100            .await
101            .map_err(|e| ServerError::Handler {
102                message: format!("Roots listing failed: {}", e),
103                context: Some("roots".to_string()),
104            })?;
105
106        serde_json::from_value(result_json).map_err(|e| ServerError::Handler {
107            message: format!("Failed to deserialize response: {}", e),
108            context: Some("roots".to_string()),
109        })
110    }
111}