turul_mcp_server/
sampling.rs

1//! MCP Sampling Trait
2//!
3//! This module defines the high-level trait for implementing MCP sampling.
4
5use async_trait::async_trait;
6use turul_mcp_builders::prelude::*;
7use turul_mcp_protocol::{
8    McpResult,
9    sampling::{CreateMessageRequest, CreateMessageResult},
10};
11
12/// High-level trait for implementing MCP sampling
13///
14/// McpSampling extends SamplingDefinition with execution capabilities.
15/// All metadata is provided by the SamplingDefinition trait, ensuring
16/// consistency between concrete Sampling structs and dynamic implementations.
17#[async_trait]
18pub trait McpSampling: SamplingDefinition + Send + Sync {
19    /// Create a message using the sampling model (per MCP spec)
20    ///
21    /// This method processes the sampling/createMessage request and returns
22    /// the generated message response.
23    async fn sample(&self, request: CreateMessageRequest) -> McpResult<CreateMessageResult>;
24
25    /// Optional: Check if this sampling handler can handle the given request
26    ///
27    /// This allows for conditional sampling based on model preferences,
28    /// context size, or other factors.
29    fn can_handle(&self, _request: &CreateMessageRequest) -> bool {
30        true
31    }
32
33    /// Optional: Get sampling priority for request routing
34    ///
35    /// Higher priority handlers are tried first when multiple handlers
36    /// can handle the same request.
37    fn priority(&self) -> u32 {
38        0
39    }
40
41    /// Optional: Validate the sampling request
42    ///
43    /// This method can perform additional validation beyond basic parameter checks.
44    async fn validate_request(&self, request: &CreateMessageRequest) -> McpResult<()> {
45        // Basic validation - ensure max_tokens is reasonable
46        if request.params.max_tokens == 0 {
47            return Err(turul_mcp_protocol::McpError::validation(
48                "max_tokens must be greater than 0",
49            ));
50        }
51        if request.params.max_tokens > 1000000 {
52            return Err(turul_mcp_protocol::McpError::validation(
53                "max_tokens exceeds maximum allowed value",
54            ));
55        }
56        Ok(())
57    }
58}
59
60/// Convert an McpSampling trait object to CreateMessageParams
61///
62/// This is a convenience function for converting sampling definitions
63/// to protocol parameters.
64pub fn sampling_to_params(
65    sampling: &dyn McpSampling,
66) -> turul_mcp_protocol::sampling::CreateMessageParams {
67    sampling.to_create_params()
68}
69
#[cfg(test)]
mod tests {
    use super::*;
    use turul_mcp_protocol::sampling::SamplingMessage;

    /// Minimal sampling handler used to exercise the trait's default methods.
    struct TestSampling {
        messages: Vec<SamplingMessage>,
        max_tokens: u32,
        temperature: Option<f64>,
    }

    impl TestSampling {
        /// Builds a handler with no context messages, no temperature and the
        /// given token budget.
        fn with_max_tokens(max_tokens: u32) -> Self {
            Self {
                messages: vec![],
                max_tokens,
                temperature: None,
            }
        }

        /// Wraps this handler's parameters in a `sampling/createMessage` request.
        fn request(&self) -> CreateMessageRequest {
            CreateMessageRequest {
                method: "sampling/createMessage".to_string(),
                params: self.to_create_params(),
            }
        }
    }

    // Implement fine-grained traits (MCP spec compliant)
    impl HasSamplingConfig for TestSampling {
        fn max_tokens(&self) -> u32 {
            self.max_tokens
        }

        fn temperature(&self) -> Option<f64> {
            self.temperature
        }
    }

    impl HasSamplingContext for TestSampling {
        fn messages(&self) -> &[SamplingMessage] {
            &self.messages
        }
    }

    impl HasModelPreferences for TestSampling {}

    // SamplingDefinition automatically implemented via blanket impl!

    #[async_trait]
    impl McpSampling for TestSampling {
        async fn sample(&self, _request: CreateMessageRequest) -> McpResult<CreateMessageResult> {
            // Simulate message generation
            let response_message = SamplingMessage {
                role: turul_mcp_protocol::sampling::Role::Assistant,
                content: turul_mcp_protocol::prompts::ContentBlock::Text {
                    text: "Generated response".to_string(),
                    annotations: None,
                    meta: None,
                },
            };

            Ok(CreateMessageResult::new(response_message, "test-model"))
        }
    }

    #[test]
    fn test_sampling_trait() {
        let sampling = TestSampling {
            messages: vec![],
            max_tokens: 100,
            temperature: Some(0.7),
        };

        assert_eq!(sampling.max_tokens(), 100);
        assert_eq!(sampling.temperature(), Some(0.7));
    }

    #[test]
    fn test_default_routing_hooks() {
        let sampling = TestSampling::with_max_tokens(100);
        let request = sampling.request();

        assert!(sampling.can_handle(&request));
        assert_eq!(sampling.priority(), 0);
    }

    #[tokio::test]
    async fn test_sampling_validation() {
        let sampling = TestSampling::with_max_tokens(0);

        let result = sampling.validate_request(&sampling.request()).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_sampling_validation_accepts_in_range() {
        // 1_000_000 is the inclusive default upper bound.
        for max_tokens in [1, 100, 1_000_000] {
            let sampling = TestSampling::with_max_tokens(max_tokens);
            let result = sampling.validate_request(&sampling.request()).await;
            assert!(result.is_ok(), "max_tokens = {max_tokens} should be accepted");
        }
    }

    #[tokio::test]
    async fn test_sampling_validation_rejects_above_limit() {
        let sampling = TestSampling::with_max_tokens(1_000_001);

        let result = sampling.validate_request(&sampling.request()).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_sample_produces_result() {
        let sampling = TestSampling::with_max_tokens(100);

        let result = sampling.sample(sampling.request()).await;
        assert!(result.is_ok());
    }
}
149}