Skip to main content

turbomcp_protocol/types/
sampling.rs

//! LLM sampling types
//!
//! This module contains the types used for server-initiated LLM sampling:
//! - MCP 2025-11-25: basic text-based sampling
//! - MCP 2025-11-25 draft (SEP-1577): adds tool-calling support
6
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10use super::{content::Content, core::Role};
11
12use super::tools::Tool;
13
14/// Include context options for sampling
15#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
16#[serde(rename_all = "camelCase")]
17pub enum IncludeContext {
18    /// No context
19    None,
20    /// This server only
21    ThisServer,
22    /// All servers
23    AllServers,
24}
25
26/// Sampling message structure
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct SamplingMessage {
29    /// Message role
30    pub role: Role,
31    /// Message content
32    pub content: Content,
33    /// Optional message metadata
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub metadata: Option<HashMap<String, serde_json::Value>>,
36}
37
38/// Create message request (for LLM sampling)
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct CreateMessageRequest {
41    /// Messages to include in the sampling request
42    pub messages: Vec<SamplingMessage>,
43    /// Model preferences (optional)
44    #[serde(skip_serializing_if = "Option::is_none")]
45    pub model_preferences: Option<ModelPreferences>,
46    /// System prompt (optional)
47    #[serde(rename = "systemPrompt", skip_serializing_if = "Option::is_none")]
48    pub system_prompt: Option<String>,
49    /// Include context from other servers
50    #[serde(rename = "includeContext", skip_serializing_if = "Option::is_none")]
51    pub include_context: Option<IncludeContext>,
52    /// Temperature for sampling (0.0 to 2.0)
53    #[serde(skip_serializing_if = "Option::is_none")]
54    pub temperature: Option<f64>,
55    /// Maximum number of tokens to generate (required by MCP spec)
56    #[serde(rename = "maxTokens")]
57    pub max_tokens: u32,
58    /// Stop sequences
59    #[serde(rename = "stopSequences", skip_serializing_if = "Option::is_none")]
60    pub stop_sequences: Option<Vec<String>>,
61    /// Tools that the model may use during generation (MCP 2025-11-25 draft, SEP-1577)
62    /// The client MUST return an error if this field is provided but
63    /// ClientCapabilities.sampling.tools is not declared
64    #[serde(skip_serializing_if = "Option::is_none")]
65    pub tools: Option<Vec<Tool>>,
66    /// Controls how the model uses tools (MCP 2025-11-25 draft, SEP-1577)
67    /// The client MUST return an error if this field is provided but
68    /// ClientCapabilities.sampling.tools is not declared
69    /// Default is `{ mode: "auto" }`
70    #[serde(rename = "toolChoice", skip_serializing_if = "Option::is_none")]
71    pub tool_choice: Option<ToolChoice>,
72    /// Task metadata for task-augmented sampling (MCP 2025-11-25 draft, SEP-1686)
73    ///
74    /// When present, indicates the client should execute this sampling request as a long-running
75    /// task and return a CreateTaskResult instead of the immediate CreateMessageResult.
76    /// The actual result can be retrieved later via tasks/result.
77    #[serde(skip_serializing_if = "Option::is_none")]
78    pub task: Option<crate::types::tasks::TaskMetadata>,
79    /// Optional metadata per MCP 2025-11-25 specification
80    #[serde(skip_serializing_if = "Option::is_none")]
81    pub _meta: Option<serde_json::Value>,
82}
83
84/// Model hint for selection (MCP 2025-11-25 compliant)
85#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
86pub struct ModelHint {
87    /// Model name hint (substring matching)
88    /// Examples: "claude-3-5-sonnet", "sonnet", "claude"
89    #[serde(skip_serializing_if = "Option::is_none")]
90    pub name: Option<String>,
91}
92
93impl ModelHint {
94    /// Create a new model hint with a name
95    pub fn new(name: impl Into<String>) -> Self {
96        Self {
97            name: Some(name.into()),
98        }
99    }
100}
101
102/// Model preferences for sampling (MCP 2025-11-25 compliant)
103///
104/// The spec changed from tier-based to priority-based system.
105/// Priorities are 0.0-1.0 where 0 = not important, 1 = most important.
106#[derive(Debug, Clone, Serialize, Deserialize)]
107pub struct ModelPreferences {
108    /// Optional hints for model selection (evaluated in order)
109    #[serde(skip_serializing_if = "Option::is_none")]
110    pub hints: Option<Vec<ModelHint>>,
111
112    /// Cost priority (0.0 = not important, 1.0 = most important)
113    #[serde(rename = "costPriority", skip_serializing_if = "Option::is_none")]
114    pub cost_priority: Option<f64>,
115
116    /// Speed priority (0.0 = not important, 1.0 = most important)
117    #[serde(rename = "speedPriority", skip_serializing_if = "Option::is_none")]
118    pub speed_priority: Option<f64>,
119
120    /// Intelligence priority (0.0 = not important, 1.0 = most important)
121    #[serde(
122        rename = "intelligencePriority",
123        skip_serializing_if = "Option::is_none"
124    )]
125    pub intelligence_priority: Option<f64>,
126}
127
128/// Create message result
129#[derive(Debug, Clone, Serialize, Deserialize)]
130pub struct CreateMessageResult {
131    /// The role of the message (required by MCP specification)
132    pub role: super::core::Role,
133    /// The generated message content
134    pub content: Content,
135    /// Model used for generation (required by MCP specification)
136    pub model: String,
137    /// Stop reason (if applicable)
138    ///
139    /// Uses the StopReason enum with camelCase serialization for MCP 2025-11-25 compliance.
140    #[serde(rename = "stopReason", skip_serializing_if = "Option::is_none")]
141    pub stop_reason: Option<StopReason>,
142    /// Optional metadata per MCP 2025-11-25 specification
143    #[serde(skip_serializing_if = "Option::is_none")]
144    pub _meta: Option<serde_json::Value>,
145}
146
147/// Stop reason for generation
148///
149/// Per MCP 2025-11-25 spec, these values use camelCase serialization for interoperability.
150#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
151#[serde(rename_all = "camelCase")]
152pub enum StopReason {
153    /// Generation completed naturally
154    EndTurn,
155    /// Hit maximum token limit
156    MaxTokens,
157    /// Hit a stop sequence
158    StopSequence,
159    /// Content filtering triggered
160    ContentFilter,
161    /// Tool use required
162    ToolUse,
163}
164
165/// Usage statistics for sampling
166#[derive(Debug, Clone, Serialize, Deserialize)]
167pub struct UsageStats {
168    /// Input tokens consumed
169    #[serde(rename = "inputTokens", skip_serializing_if = "Option::is_none")]
170    pub input_tokens: Option<u32>,
171    /// Output tokens generated
172    #[serde(rename = "outputTokens", skip_serializing_if = "Option::is_none")]
173    pub output_tokens: Option<u32>,
174    /// Total tokens used
175    #[serde(rename = "totalTokens", skip_serializing_if = "Option::is_none")]
176    pub total_tokens: Option<u32>,
177}
178
179/// Tool choice mode (MCP 2025-11-25 draft, SEP-1577)
180#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
181#[serde(rename_all = "lowercase")]
182pub enum ToolChoiceMode {
183    /// Model decides whether to use tools (default)
184    #[default]
185    Auto,
186    /// Model MUST use at least one tool before completing
187    Required,
188    /// Model MUST NOT use any tools
189    None,
190}
191
192/// Controls tool selection behavior for sampling requests (MCP 2025-11-25 draft, SEP-1577)
193#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
194pub struct ToolChoice {
195    /// Controls the tool use ability of the model
196    /// - "auto": Model decides whether to use tools (default)
197    /// - "required": Model MUST use at least one tool before completing
198    /// - "none": Model MUST NOT use any tools
199    #[serde(skip_serializing_if = "Option::is_none")]
200    pub mode: Option<ToolChoiceMode>,
201}
202
203impl ToolChoice {
204    /// Create a new ToolChoice with auto mode
205    pub fn auto() -> Self {
206        Self {
207            mode: Some(ToolChoiceMode::Auto),
208        }
209    }
210
211    /// Create a new ToolChoice requiring tool use
212    pub fn required() -> Self {
213        Self {
214            mode: Some(ToolChoiceMode::Required),
215        }
216    }
217
218    /// Create a new ToolChoice forbidding tool use
219    pub fn none() -> Self {
220        Self {
221            mode: Some(ToolChoiceMode::None),
222        }
223    }
224}
225
226impl Default for ToolChoice {
227    fn default() -> Self {
228        Self::auto()
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    /// Every mode serializes to its lowercase wire name.
    #[test]
    fn test_tool_choice_mode_serialization() {
        let expectations = [
            (ToolChoiceMode::Auto, "\"auto\""),
            (ToolChoiceMode::Required, "\"required\""),
            (ToolChoiceMode::None, "\"none\""),
        ];

        for (mode, wire) in expectations.iter() {
            let encoded = serde_json::to_string(mode).unwrap();
            assert_eq!(encoded, *wire);
        }
    }

    /// Each named constructor sets the matching mode.
    #[test]
    fn test_tool_choice_constructors() {
        assert_eq!(ToolChoice::auto().mode, Some(ToolChoiceMode::Auto));
        assert_eq!(ToolChoice::required().mode, Some(ToolChoiceMode::Required));
        assert_eq!(ToolChoice::none().mode, Some(ToolChoiceMode::None));
    }

    /// `Default` yields an explicit `Auto` mode.
    #[test]
    fn test_tool_choice_default() {
        let ToolChoice { mode } = ToolChoice::default();
        assert_eq!(mode, Some(ToolChoiceMode::Auto));
    }

    /// The mode appears under the `mode` key in the JSON encoding.
    #[test]
    fn test_tool_choice_serialization() {
        let encoded = serde_json::to_string(&ToolChoice::required()).unwrap();
        assert!(encoded.contains("\"mode\":\"required\""));
    }
}