Skip to main content

turul_mcp_protocol_2025_11_25/
sampling.rs

1//! MCP Sampling Protocol Types
2//!
3//! This module defines types for sampling requests in MCP.
4
5use crate::content::ContentBlock;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8
9/// Sampling request parameters
10#[derive(Debug, Clone, Serialize, Deserialize)]
11#[serde(rename_all = "camelCase")]
12pub struct SamplingRequest {
13    /// The sampling method to use
14    pub method: String,
15    /// Parameters for the sampling method
16    #[serde(skip_serializing_if = "Option::is_none")]
17    pub params: Option<Value>,
18}
19
20/// Sampling response
21#[derive(Debug, Clone, Serialize, Deserialize)]
22#[serde(rename_all = "camelCase")]
23pub struct SamplingResult {
24    /// The sampled result
25    pub result: Value,
26}
27
28impl SamplingResult {
29    pub fn new(result: Value) -> Self {
30        Self { result }
31    }
32}
33
34/// Role enum for messages (per MCP 2025-11-25 spec — only "user" | "assistant")
35#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
36#[serde(rename_all = "lowercase")]
37pub enum Role {
38    User,
39    Assistant,
40}
41
42/// Model hint — an open-ended struct per MCP 2025-11-25 spec.
43///
44/// The `name` field can be any model identifier string. Clients use hints to
45/// express model preferences without restricting to a hardcoded set.
46#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
47#[serde(rename_all = "camelCase")]
48pub struct ModelHint {
49    /// Optional model name hint (e.g., "claude-3-5-sonnet-20241022", "gpt-4o")
50    #[serde(skip_serializing_if = "Option::is_none")]
51    pub name: Option<String>,
52}
53
54impl ModelHint {
55    pub fn new(name: impl Into<String>) -> Self {
56        Self {
57            name: Some(name.into()),
58        }
59    }
60}
61
62/// Model preferences (per MCP spec)
63#[derive(Debug, Clone, Serialize, Deserialize)]
64#[serde(rename_all = "camelCase")]
65pub struct ModelPreferences {
66    /// Optional hints about which models to use
67    #[serde(skip_serializing_if = "Option::is_none")]
68    pub hints: Option<Vec<ModelHint>>,
69    /// Optional cost priority (0.0-1.0)
70    #[serde(skip_serializing_if = "Option::is_none")]
71    pub cost_priority: Option<f64>,
72    /// Optional speed priority (0.0-1.0)
73    #[serde(skip_serializing_if = "Option::is_none")]
74    pub speed_priority: Option<f64>,
75    /// Optional intelligence priority (0.0-1.0)
76    #[serde(skip_serializing_if = "Option::is_none")]
77    pub intelligence_priority: Option<f64>,
78}
79
80/// Tool choice mode for sampling requests (per MCP 2025-11-25)
81#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
82#[serde(rename_all = "lowercase")]
83pub enum ToolChoiceMode {
84    /// Model decides whether to use tools
85    Auto,
86    /// Model must not use any tools
87    None,
88    /// Model must use at least one tool (MCP 2025-11-25: "required")
89    #[serde(alias = "any")]
90    Required,
91}
92
93/// Tool choice configuration for sampling requests (per MCP 2025-11-25)
94#[derive(Debug, Clone, Serialize, Deserialize)]
95#[serde(rename_all = "camelCase")]
96pub struct ToolChoice {
97    /// The mode for tool selection
98    pub mode: ToolChoiceMode,
99    /// Optional specific tool name to use (only meaningful with mode "required")
100    #[serde(skip_serializing_if = "Option::is_none")]
101    pub name: Option<String>,
102}
103
104impl ToolChoice {
105    pub fn auto() -> Self {
106        Self {
107            mode: ToolChoiceMode::Auto,
108            name: None,
109        }
110    }
111
112    pub fn none() -> Self {
113        Self {
114            mode: ToolChoiceMode::None,
115            name: None,
116        }
117    }
118
119    /// Create tool choice requiring at least one tool (MCP 2025-11-25: "required")
120    pub fn required() -> Self {
121        Self {
122            mode: ToolChoiceMode::Required,
123            name: None,
124        }
125    }
126
127    /// Compatibility alias for `required()` — MCP 2025-11-25 wire value is "required"
128    pub fn any() -> Self {
129        Self::required()
130    }
131
132    pub fn specific(name: impl Into<String>) -> Self {
133        Self {
134            mode: ToolChoiceMode::Required,
135            name: Some(name.into()),
136        }
137    }
138}
139
140/// Sampling message (per MCP spec)
141#[derive(Debug, Clone, Serialize, Deserialize)]
142#[serde(rename_all = "camelCase")]
143pub struct SamplingMessage {
144    /// Role of the message
145    pub role: Role,
146    /// Content of the message
147    pub content: ContentBlock,
148}
149
150/// Parameters for sampling/createMessage request (per MCP spec)
151#[derive(Debug, Clone, Serialize, Deserialize)]
152#[serde(rename_all = "camelCase")]
153pub struct CreateMessageParams {
154    /// Messages for context
155    pub messages: Vec<SamplingMessage>,
156    /// Optional model preferences
157    #[serde(skip_serializing_if = "Option::is_none")]
158    pub model_preferences: Option<ModelPreferences>,
159    /// Optional system prompt
160    #[serde(skip_serializing_if = "Option::is_none")]
161    pub system_prompt: Option<String>,
162    /// Optional include context
163    #[serde(skip_serializing_if = "Option::is_none")]
164    pub include_context: Option<String>,
165    /// Optional temperature
166    #[serde(skip_serializing_if = "Option::is_none")]
167    pub temperature: Option<f64>,
168    /// Maximum tokens (required field)
169    pub max_tokens: u32,
170    /// Optional stop sequences
171    #[serde(skip_serializing_if = "Option::is_none")]
172    pub stop_sequences: Option<Vec<String>>,
173    /// Optional metadata
174    #[serde(skip_serializing_if = "Option::is_none")]
175    pub metadata: Option<Value>,
176    /// Optional tools the LLM can use during sampling (MCP 2025-11-25)
177    #[serde(skip_serializing_if = "Option::is_none")]
178    pub tools: Option<Vec<crate::tools::Tool>>,
179    /// Optional tool choice configuration (MCP 2025-11-25)
180    #[serde(skip_serializing_if = "Option::is_none")]
181    pub tool_choice: Option<ToolChoice>,
182    /// Task metadata for task-augmented requests (MCP 2025-11-25)
183    #[serde(skip_serializing_if = "Option::is_none")]
184    pub task: Option<crate::tasks::TaskMetadata>,
185    /// Meta information (optional _meta field inside params)
186    #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
187    pub meta: Option<std::collections::HashMap<String, Value>>,
188}
189
190/// Complete sampling/createMessage request (matches TypeScript CreateMessageRequest interface)
191#[derive(Debug, Clone, Serialize, Deserialize)]
192#[serde(rename_all = "camelCase")]
193pub struct CreateMessageRequest {
194    /// Method name (always "sampling/createMessage")
195    pub method: String,
196    /// Request parameters
197    pub params: CreateMessageParams,
198}
199
200/// Result for sampling/createMessage (per MCP 2025-11-25 spec)
201///
202/// Flattened structure: { role, content, model, stopReason?, _meta? }
203#[derive(Debug, Clone, Serialize, Deserialize)]
204#[serde(rename_all = "camelCase")]
205pub struct CreateMessageResult {
206    /// Role of the generated message
207    pub role: Role,
208    /// Content of the generated message
209    pub content: ContentBlock,
210    /// Model used for generation
211    pub model: String,
212    /// Stop reason
213    #[serde(skip_serializing_if = "Option::is_none")]
214    pub stop_reason: Option<String>,
215    /// Meta information (follows MCP Result interface)
216    #[serde(
217        default,
218        skip_serializing_if = "Option::is_none",
219        alias = "_meta",
220        rename = "_meta"
221    )]
222    pub meta: Option<std::collections::HashMap<String, Value>>,
223}
224
225impl CreateMessageParams {
226    pub fn new(messages: Vec<SamplingMessage>, max_tokens: u32) -> Self {
227        Self {
228            messages,
229            model_preferences: None,
230            system_prompt: None,
231            include_context: None,
232            temperature: None,
233            max_tokens,
234            stop_sequences: None,
235            metadata: None,
236            tools: None,
237            tool_choice: None,
238            task: None,
239            meta: None,
240        }
241    }
242
243    pub fn with_task(mut self, task: crate::tasks::TaskMetadata) -> Self {
244        self.task = Some(task);
245        self
246    }
247
248    pub fn with_tools(mut self, tools: Vec<crate::tools::Tool>) -> Self {
249        self.tools = Some(tools);
250        self
251    }
252
253    pub fn with_tool_choice(mut self, tool_choice: ToolChoice) -> Self {
254        self.tool_choice = Some(tool_choice);
255        self
256    }
257
258    pub fn with_model_preferences(mut self, preferences: ModelPreferences) -> Self {
259        self.model_preferences = Some(preferences);
260        self
261    }
262
263    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
264        self.system_prompt = Some(prompt.into());
265        self
266    }
267
268    pub fn with_temperature(mut self, temperature: f64) -> Self {
269        self.temperature = Some(temperature);
270        self
271    }
272
273    pub fn with_stop_sequences(mut self, sequences: Vec<String>) -> Self {
274        self.stop_sequences = Some(sequences);
275        self
276    }
277
278    pub fn with_meta(mut self, meta: std::collections::HashMap<String, Value>) -> Self {
279        self.meta = Some(meta);
280        self
281    }
282}
283
284impl CreateMessageRequest {
285    pub fn new(messages: Vec<SamplingMessage>, max_tokens: u32) -> Self {
286        Self {
287            method: "sampling/createMessage".to_string(),
288            params: CreateMessageParams::new(messages, max_tokens),
289        }
290    }
291
292    pub fn with_model_preferences(mut self, preferences: ModelPreferences) -> Self {
293        self.params = self.params.with_model_preferences(preferences);
294        self
295    }
296
297    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
298        self.params = self.params.with_system_prompt(prompt);
299        self
300    }
301
302    pub fn with_temperature(mut self, temperature: f64) -> Self {
303        self.params = self.params.with_temperature(temperature);
304        self
305    }
306
307    pub fn with_stop_sequences(mut self, sequences: Vec<String>) -> Self {
308        self.params = self.params.with_stop_sequences(sequences);
309        self
310    }
311
312    pub fn with_tools(mut self, tools: Vec<crate::tools::Tool>) -> Self {
313        self.params = self.params.with_tools(tools);
314        self
315    }
316
317    pub fn with_tool_choice(mut self, tool_choice: ToolChoice) -> Self {
318        self.params = self.params.with_tool_choice(tool_choice);
319        self
320    }
321
322    pub fn with_meta(mut self, meta: std::collections::HashMap<String, Value>) -> Self {
323        self.params = self.params.with_meta(meta);
324        self
325    }
326}
327
328impl CreateMessageResult {
329    pub fn new(role: Role, content: ContentBlock, model: impl Into<String>) -> Self {
330        Self {
331            role,
332            content,
333            model: model.into(),
334            stop_reason: None,
335            meta: None,
336        }
337    }
338
339    pub fn with_stop_reason(mut self, reason: impl Into<String>) -> Self {
340        self.stop_reason = Some(reason.into());
341        self
342    }
343
344    pub fn with_meta(mut self, meta: std::collections::HashMap<String, Value>) -> Self {
345        self.meta = Some(meta);
346        self
347    }
348}
349
350// Trait implementations for sampling
351
352use crate::traits::*;
353use std::collections::HashMap;
354
355// Trait implementations for CreateMessageParams
356impl Params for CreateMessageParams {}
357
358impl HasCreateMessageParams for CreateMessageParams {
359    fn messages(&self) -> &Vec<SamplingMessage> {
360        &self.messages
361    }
362
363    fn model_preferences(&self) -> Option<&ModelPreferences> {
364        self.model_preferences.as_ref()
365    }
366
367    fn system_prompt(&self) -> Option<&String> {
368        self.system_prompt.as_ref()
369    }
370
371    fn include_context(&self) -> Option<&String> {
372        self.include_context.as_ref()
373    }
374
375    fn temperature(&self) -> Option<&f64> {
376        self.temperature.as_ref()
377    }
378
379    fn max_tokens(&self) -> u32 {
380        self.max_tokens
381    }
382
383    fn stop_sequences(&self) -> Option<&Vec<String>> {
384        self.stop_sequences.as_ref()
385    }
386
387    fn metadata(&self) -> Option<&Value> {
388        self.metadata.as_ref()
389    }
390}
391
392impl HasMetaParam for CreateMessageParams {
393    fn meta(&self) -> Option<&std::collections::HashMap<String, Value>> {
394        self.meta.as_ref()
395    }
396}
397
398// Trait implementations for CreateMessageRequest
399impl HasMethod for CreateMessageRequest {
400    fn method(&self) -> &str {
401        &self.method
402    }
403}
404
405impl HasParams for CreateMessageRequest {
406    fn params(&self) -> Option<&dyn Params> {
407        Some(&self.params)
408    }
409}
410
411// Trait implementations for CreateMessageResult
412impl HasData for CreateMessageResult {
413    fn data(&self) -> HashMap<String, Value> {
414        let mut data = HashMap::new();
415        data.insert(
416            "role".to_string(),
417            serde_json::to_value(&self.role).unwrap_or(Value::String("user".to_string())),
418        );
419        data.insert(
420            "content".to_string(),
421            serde_json::to_value(&self.content).unwrap_or(Value::Null),
422        );
423        data.insert("model".to_string(), Value::String(self.model.clone()));
424        if let Some(ref stop_reason) = self.stop_reason {
425            data.insert("stopReason".to_string(), Value::String(stop_reason.clone()));
426        }
427        data
428    }
429}
430
431impl HasMeta for CreateMessageResult {
432    fn meta(&self) -> Option<HashMap<String, Value>> {
433        self.meta.clone()
434    }
435}
436
437impl RpcResult for CreateMessageResult {}
438
439impl crate::traits::CreateMessageResult for CreateMessageResult {
440    fn role(&self) -> &Role {
441        &self.role
442    }
443
444    fn content(&self) -> &ContentBlock {
445        &self.content
446    }
447
448    fn model(&self) -> &String {
449        &self.model
450    }
451
452    fn stop_reason(&self) -> Option<&String> {
453        self.stop_reason.as_ref()
454    }
455}
456
457// ===========================================
458// === Fine-Grained Sampling Traits ===
459// ===========================================
460
461// ================== CONVENIENCE CONSTRUCTORS ==================
462
463impl ModelPreferences {
464    pub fn new() -> Self {
465        Self {
466            hints: None,
467            cost_priority: None,
468            speed_priority: None,
469            intelligence_priority: None,
470        }
471    }
472
473    pub fn with_hints(mut self, hints: Vec<ModelHint>) -> Self {
474        self.hints = Some(hints);
475        self
476    }
477
478    pub fn with_cost_priority(mut self, priority: f64) -> Self {
479        self.cost_priority = Some(priority);
480        self
481    }
482
483    pub fn with_speed_priority(mut self, priority: f64) -> Self {
484        self.speed_priority = Some(priority);
485        self
486    }
487
488    pub fn with_intelligence_priority(mut self, priority: f64) -> Self {
489        self.intelligence_priority = Some(priority);
490        self
491    }
492}
493
494impl Default for ModelPreferences {
495    fn default() -> Self {
496        Self::new()
497    }
498}
499
500impl SamplingMessage {
501    pub fn new(role: Role, content: ContentBlock) -> Self {
502        Self { role, content }
503    }
504
505    pub fn user_text(text: impl Into<String>) -> Self {
506        Self::new(Role::User, ContentBlock::text(text))
507    }
508
509    pub fn assistant_text(text: impl Into<String>) -> Self {
510        Self::new(Role::Assistant, ContentBlock::text(text))
511    }
512}
513
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tool_choice_mode_serializes_as_required() {
        let tc = ToolChoice::required();
        let json = serde_json::to_value(&tc).unwrap();
        assert_eq!(json["mode"], "required");
    }

    #[test]
    fn test_tool_choice_mode_deserializes_legacy_any() {
        let json = serde_json::json!({"mode": "any"});
        let tc: ToolChoice = serde_json::from_value(json).unwrap();
        assert_eq!(tc.mode, ToolChoiceMode::Required);
    }

    #[test]
    fn test_tool_choice_mode_deserializes_required() {
        let json = serde_json::json!({"mode": "required"});
        let tc: ToolChoice = serde_json::from_value(json).unwrap();
        assert_eq!(tc.mode, ToolChoiceMode::Required);
    }

    #[test]
    fn test_tool_choice_any_alias_returns_required() {
        let tc = ToolChoice::any();
        assert_eq!(tc.mode, ToolChoiceMode::Required);
    }

    #[test]
    fn test_tool_choice_specific_uses_required_mode() {
        let tc = ToolChoice::specific("my_tool");
        assert_eq!(tc.mode, ToolChoiceMode::Required);
        assert_eq!(tc.name, Some("my_tool".to_string()));
    }

    /// `auto` / `none` use their lowercase wire names and omit `name` when unset.
    #[test]
    fn test_tool_choice_auto_and_none_wire_format() {
        let auto = serde_json::to_value(ToolChoice::auto()).unwrap();
        assert_eq!(auto["mode"], "auto");
        assert!(auto.get("name").is_none());

        let none = serde_json::to_value(ToolChoice::none()).unwrap();
        assert_eq!(none["mode"], "none");
        assert!(none.get("name").is_none());
    }

    #[test]
    fn test_tool_choice_specific_serializes_name() {
        let json = serde_json::to_value(ToolChoice::specific("search")).unwrap();
        assert_eq!(json["mode"], "required");
        assert_eq!(json["name"], "search");
    }

    #[test]
    fn test_role_serializes_lowercase() {
        assert_eq!(serde_json::to_value(Role::User).unwrap(), "user");
        assert_eq!(serde_json::to_value(Role::Assistant).unwrap(), "assistant");
    }

    /// Unset optional params must not appear on the wire at all.
    #[test]
    fn test_create_message_params_omits_unset_optionals() {
        let params = CreateMessageParams::new(vec![SamplingMessage::user_text("hi")], 256);
        let json = serde_json::to_value(&params).unwrap();
        assert_eq!(json["maxTokens"], 256);
        assert_eq!(json["messages"][0]["role"], "user");
        for key in [
            "modelPreferences",
            "systemPrompt",
            "includeContext",
            "temperature",
            "stopSequences",
            "metadata",
            "tools",
            "toolChoice",
            "task",
            "_meta",
        ] {
            assert!(json.get(key).is_none(), "unexpected key {}", key);
        }
    }

    #[test]
    fn test_create_message_params_serializes_tool_choice() {
        let params = CreateMessageParams::new(vec![SamplingMessage::user_text("hi")], 16)
            .with_tool_choice(ToolChoice::specific("search"));
        let json = serde_json::to_value(&params).unwrap();
        assert_eq!(json["toolChoice"]["mode"], "required");
        assert_eq!(json["toolChoice"]["name"], "search");
    }

    /// `_meta` and `stopReason` survive a serialize/deserialize round trip.
    #[test]
    fn test_create_message_result_meta_round_trip() {
        let mut meta = std::collections::HashMap::new();
        meta.insert("traceId".to_string(), serde_json::json!("abc"));
        let result =
            CreateMessageResult::new(Role::Assistant, ContentBlock::text("hello"), "test-model")
                .with_stop_reason("endTurn")
                .with_meta(meta);

        let json = serde_json::to_value(&result).unwrap();
        assert_eq!(json["role"], "assistant");
        assert_eq!(json["model"], "test-model");
        assert_eq!(json["stopReason"], "endTurn");
        assert_eq!(json["_meta"]["traceId"], "abc");

        let back: CreateMessageResult = serde_json::from_value(json).unwrap();
        assert_eq!(back.role, Role::Assistant);
        assert_eq!(back.model, "test-model");
        assert_eq!(back.stop_reason.as_deref(), Some("endTurn"));
        assert_eq!(
            back.meta.as_ref().and_then(|m| m.get("traceId")),
            Some(&serde_json::json!("abc"))
        );
    }

    #[test]
    fn test_create_message_result_omits_unset_optionals() {
        let result = CreateMessageResult::new(Role::Assistant, ContentBlock::text("x"), "m");
        let json = serde_json::to_value(&result).unwrap();
        assert!(json.get("stopReason").is_none());
        assert!(json.get("_meta").is_none());
    }
}