turul_mcp_server/
sampling.rs1use async_trait::async_trait;
6use turul_mcp_builders::prelude::*;
7use turul_mcp_protocol::{
8    McpResult,
9    sampling::{CreateMessageRequest, CreateMessageResult},
10};
11
12#[async_trait]
18pub trait McpSampling: SamplingDefinition + Send + Sync {
19    async fn sample(&self, request: CreateMessageRequest) -> McpResult<CreateMessageResult>;
24
25    fn can_handle(&self, _request: &CreateMessageRequest) -> bool {
30        true
31    }
32
33    fn priority(&self) -> u32 {
38        0
39    }
40
41    async fn validate_request(&self, request: &CreateMessageRequest) -> McpResult<()> {
45        if request.params.max_tokens == 0 {
47            return Err(turul_mcp_protocol::McpError::validation(
48                "max_tokens must be greater than 0",
49            ));
50        }
51        if request.params.max_tokens > 1000000 {
52            return Err(turul_mcp_protocol::McpError::validation(
53                "max_tokens exceeds maximum allowed value",
54            ));
55        }
56        Ok(())
57    }
58}
59
60pub fn sampling_to_params(
65    sampling: &dyn McpSampling,
66) -> turul_mcp_protocol::sampling::CreateMessageParams {
67    sampling.to_create_params()
68}
69
#[cfg(test)]
mod tests {
    use super::*;
    use turul_mcp_protocol::sampling::SamplingMessage;

    /// Minimal sampling definition used to exercise the `McpSampling` defaults.
    struct TestSampling {
        messages: Vec<SamplingMessage>,
        max_tokens: u32,
        temperature: Option<f64>,
    }

    impl TestSampling {
        /// Creates a handler with no messages, no temperature, and the given
        /// token budget.
        fn with_max_tokens(max_tokens: u32) -> Self {
            Self {
                messages: vec![],
                max_tokens,
                temperature: None,
            }
        }
    }

    impl HasSamplingConfig for TestSampling {
        fn max_tokens(&self) -> u32 {
            self.max_tokens
        }

        fn temperature(&self) -> Option<f64> {
            self.temperature
        }
    }

    impl HasSamplingContext for TestSampling {
        fn messages(&self) -> &[SamplingMessage] {
            &self.messages
        }
    }

    impl HasModelPreferences for TestSampling {}

    #[async_trait]
    impl McpSampling for TestSampling {
        async fn sample(&self, _request: CreateMessageRequest) -> McpResult<CreateMessageResult> {
            let response_message = SamplingMessage {
                role: turul_mcp_protocol::sampling::Role::Assistant,
                content: turul_mcp_protocol::prompts::ContentBlock::Text {
                    text: "Generated response".to_string(),
                    annotations: None,
                    meta: None,
                },
            };

            Ok(CreateMessageResult::new(response_message, "test-model"))
        }
    }

    /// Wraps the handler's own parameters in a `CreateMessageRequest`.
    fn request_for(sampling: &TestSampling) -> CreateMessageRequest {
        CreateMessageRequest {
            method: "sampling/createMessage".to_string(),
            params: sampling.to_create_params(),
        }
    }

    #[test]
    fn test_sampling_trait() {
        let sampling = TestSampling {
            messages: vec![],
            max_tokens: 100,
            temperature: Some(0.7),
        };

        assert_eq!(sampling.max_tokens(), 100);
        assert_eq!(sampling.temperature(), Some(0.7));
    }

    #[tokio::test]
    async fn test_sampling_validation() {
        // A zero token budget must be rejected.
        let sampling = TestSampling::with_max_tokens(0);
        let request = request_for(&sampling);

        let result = sampling.validate_request(&request).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_sampling_validation_rejects_excessive_max_tokens() {
        // One past the upper bound enforced by the default `validate_request`.
        let sampling = TestSampling::with_max_tokens(1_000_001);
        let request = request_for(&sampling);

        let result = sampling.validate_request(&request).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_sampling_validation_accepts_in_range_max_tokens() {
        // Both an ordinary value and the inclusive upper bound must pass.
        for max_tokens in [100, 1_000_000] {
            let sampling = TestSampling::with_max_tokens(max_tokens);
            let request = request_for(&sampling);

            let result = sampling.validate_request(&request).await;
            assert!(result.is_ok(), "max_tokens = {max_tokens} should be valid");
        }
    }
}
149}