turul_mcp_server/
sampling.rs1use async_trait::async_trait;
6use turul_mcp_builders::prelude::*;
7use turul_mcp_protocol::{
8 McpResult,
9 sampling::{CreateMessageRequest, CreateMessageResult},
10};
11
12#[async_trait]
18pub trait McpSampling: SamplingDefinition + Send + Sync {
19 async fn sample(&self, request: CreateMessageRequest) -> McpResult<CreateMessageResult>;
24
25 fn can_handle(&self, _request: &CreateMessageRequest) -> bool {
30 true
31 }
32
33 fn priority(&self) -> u32 {
38 0
39 }
40
41 async fn validate_request(&self, request: &CreateMessageRequest) -> McpResult<()> {
45 if request.params.max_tokens == 0 {
47 return Err(turul_mcp_protocol::McpError::validation(
48 "max_tokens must be greater than 0",
49 ));
50 }
51 if request.params.max_tokens > 1000000 {
52 return Err(turul_mcp_protocol::McpError::validation(
53 "max_tokens exceeds maximum allowed value",
54 ));
55 }
56 Ok(())
57 }
58}
59
60pub fn sampling_to_params(
65 sampling: &dyn McpSampling,
66) -> turul_mcp_protocol::sampling::CreateMessageParams {
67 sampling.to_create_params()
68}
69
#[cfg(test)]
mod tests {
    use super::*;
    use turul_mcp_protocol::sampling::SamplingMessage;

    /// Minimal fixture that implements the sampling traits from plain fields.
    struct TestSampling {
        messages: Vec<SamplingMessage>,
        max_tokens: u32,
        temperature: Option<f64>,
    }

    impl HasSamplingConfig for TestSampling {
        fn max_tokens(&self) -> u32 {
            self.max_tokens
        }

        fn temperature(&self) -> Option<f64> {
            self.temperature
        }
    }

    impl HasSamplingContext for TestSampling {
        fn messages(&self) -> &[SamplingMessage] {
            &self.messages
        }
    }

    impl HasModelPreferences for TestSampling {}

    #[async_trait]
    impl McpSampling for TestSampling {
        /// Ignores the request and returns a fixed assistant text reply.
        async fn sample(&self, _request: CreateMessageRequest) -> McpResult<CreateMessageResult> {
            let response_message = SamplingMessage {
                role: turul_mcp_protocol::sampling::Role::Assistant,
                content: turul_mcp_protocol::prompts::ContentBlock::Text {
                    text: "Generated response".to_string(),
                    annotations: None,
                    meta: None,
                },
            };

            Ok(CreateMessageResult::new(response_message, "test-model"))
        }
    }

    /// Builds a fixture with no messages and the given configuration.
    fn fixture(max_tokens: u32, temperature: Option<f64>) -> TestSampling {
        TestSampling {
            messages: vec![],
            max_tokens,
            temperature,
        }
    }

    /// Wraps the fixture's create-message parameters in a request.
    fn request_for(sampling: &TestSampling) -> CreateMessageRequest {
        CreateMessageRequest {
            method: "sampling/createMessage".to_string(),
            params: sampling.to_create_params(),
        }
    }

    #[test]
    fn test_sampling_trait() {
        let sampling = fixture(100, Some(0.7));

        assert_eq!(sampling.max_tokens(), 100);
        assert_eq!(sampling.temperature(), Some(0.7));
    }

    #[tokio::test]
    async fn test_sampling_validation() {
        let sampling = fixture(0, None);
        let request = request_for(&sampling);

        let result = sampling.validate_request(&request).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_validation_rejects_excessive_max_tokens() {
        let sampling = fixture(100, None);
        let mut request = request_for(&sampling);
        // Set the field directly so the test exercises only the validator and
        // does not depend on how the parameters were built.
        request.params.max_tokens = 1_000_001;

        assert!(sampling.validate_request(&request).await.is_err());
    }

    #[tokio::test]
    async fn test_validation_accepts_in_range_max_tokens() {
        let sampling = fixture(100, None);
        let mut request = request_for(&sampling);

        // Lower boundary: the smallest value that is not rejected as zero.
        request.params.max_tokens = 1;
        assert!(sampling.validate_request(&request).await.is_ok());

        request.params.max_tokens = 100;
        assert!(sampling.validate_request(&request).await.is_ok());

        // Upper boundary: the limit itself is still allowed (check is `>`).
        request.params.max_tokens = 1_000_000;
        assert!(sampling.validate_request(&request).await.is_ok());
    }

    #[test]
    fn test_default_can_handle_and_priority() {
        let sampling = fixture(100, None);
        let request = request_for(&sampling);

        assert!(sampling.can_handle(&request));
        assert_eq!(sampling.priority(), 0);
    }
}