// turbomcp_server/sampling.rs
use std::fmt::Display;

use turbomcp_core::RequestContext;
use turbomcp_protocol::types::{
    CreateMessageRequest, CreateMessageResult, ElicitRequest, ElicitResult, ListRootsResult,
};

use crate::{ServerError, ServerResult};
11
12#[async_trait::async_trait]
16pub trait SamplingExt {
17 async fn create_message(
19 &self,
20 request: CreateMessageRequest,
21 ) -> ServerResult<CreateMessageResult>;
22
23 async fn elicit(&self, request: ElicitRequest) -> ServerResult<ElicitResult>;
25
26 async fn list_roots(&self) -> ServerResult<ListRootsResult>;
28}
29
30#[async_trait::async_trait]
31impl SamplingExt for RequestContext {
32 async fn create_message(
33 &self,
34 request: CreateMessageRequest,
35 ) -> ServerResult<CreateMessageResult> {
36 let capabilities = self
37 .server_capabilities()
38 .ok_or_else(|| ServerError::Handler {
39 message: "No server capabilities available for sampling requests".to_string(),
40 context: Some("sampling".to_string()),
41 })?;
42
43 let request_json = serde_json::to_value(request).map_err(|e| ServerError::Handler {
44 message: format!("Failed to serialize request: {}", e),
45 context: Some("sampling".to_string()),
46 })?;
47
48 let result_json = capabilities
49 .create_message(request_json)
50 .await
51 .map_err(|e| ServerError::Handler {
52 message: format!("Sampling request failed: {}", e),
53 context: Some("sampling".to_string()),
54 })?;
55
56 serde_json::from_value(result_json).map_err(|e| ServerError::Handler {
57 message: format!("Failed to deserialize response: {}", e),
58 context: Some("sampling".to_string()),
59 })
60 }
61
62 async fn elicit(&self, request: ElicitRequest) -> ServerResult<ElicitResult> {
63 let capabilities = self
64 .server_capabilities()
65 .ok_or_else(|| ServerError::Handler {
66 message: "No server capabilities available for elicitation requests".to_string(),
67 context: Some("elicitation".to_string()),
68 })?;
69
70 let request_json = serde_json::to_value(request).map_err(|e| ServerError::Handler {
71 message: format!("Failed to serialize request: {}", e),
72 context: Some("elicitation".to_string()),
73 })?;
74
75 let result_json =
76 capabilities
77 .elicit(request_json)
78 .await
79 .map_err(|e| ServerError::Handler {
80 message: format!("Elicitation request failed: {}", e),
81 context: Some("elicitation".to_string()),
82 })?;
83
84 serde_json::from_value(result_json).map_err(|e| ServerError::Handler {
85 message: format!("Failed to deserialize response: {}", e),
86 context: Some("elicitation".to_string()),
87 })
88 }
89
90 async fn list_roots(&self) -> ServerResult<ListRootsResult> {
91 let capabilities = self
92 .server_capabilities()
93 .ok_or_else(|| ServerError::Handler {
94 message: "No server capabilities available for roots listing".to_string(),
95 context: Some("roots".to_string()),
96 })?;
97
98 let result_json = capabilities
99 .list_roots()
100 .await
101 .map_err(|e| ServerError::Handler {
102 message: format!("Roots listing failed: {}", e),
103 context: Some("roots".to_string()),
104 })?;
105
106 serde_json::from_value(result_json).map_err(|e| ServerError::Handler {
107 message: format!("Failed to deserialize response: {}", e),
108 context: Some("roots".to_string()),
109 })
110 }
111}