turbomcp_protocol/types/
sampling.rs

//! LLM sampling types
//!
//! This module contains types for server-initiated LLM sampling:
//! - MCP 2025-11-25: basic text-based sampling
//! - MCP 2025-11-25 draft (SEP-1577): adds tool-calling support
6
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10use super::{content::Content, core::Role};
11
12#[cfg(feature = "mcp-sampling-tools")]
13use super::tools::Tool;
14
15/// Include context options for sampling
16#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
17#[serde(rename_all = "camelCase")]
18pub enum IncludeContext {
19    /// No context
20    None,
21    /// This server only
22    ThisServer,
23    /// All servers
24    AllServers,
25}
26
27/// Sampling message structure
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct SamplingMessage {
30    /// Message role
31    pub role: Role,
32    /// Message content
33    pub content: Content,
34    /// Optional message metadata
35    #[serde(skip_serializing_if = "Option::is_none")]
36    pub metadata: Option<HashMap<String, serde_json::Value>>,
37}
38
39/// Create message request (for LLM sampling)
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct CreateMessageRequest {
42    /// Messages to include in the sampling request
43    pub messages: Vec<SamplingMessage>,
44    /// Model preferences (optional)
45    #[serde(skip_serializing_if = "Option::is_none")]
46    pub model_preferences: Option<ModelPreferences>,
47    /// System prompt (optional)
48    #[serde(rename = "systemPrompt", skip_serializing_if = "Option::is_none")]
49    pub system_prompt: Option<String>,
50    /// Include context from other servers
51    #[serde(rename = "includeContext", skip_serializing_if = "Option::is_none")]
52    pub include_context: Option<IncludeContext>,
53    /// Temperature for sampling (0.0 to 2.0)
54    #[serde(skip_serializing_if = "Option::is_none")]
55    pub temperature: Option<f64>,
56    /// Maximum number of tokens to generate (required by MCP spec)
57    #[serde(rename = "maxTokens")]
58    pub max_tokens: u32,
59    /// Stop sequences
60    #[serde(rename = "stopSequences", skip_serializing_if = "Option::is_none")]
61    pub stop_sequences: Option<Vec<String>>,
62    /// Tools that the model may use during generation (MCP 2025-11-25 draft, SEP-1577)
63    /// The client MUST return an error if this field is provided but
64    /// ClientCapabilities.sampling.tools is not declared
65    #[cfg(feature = "mcp-sampling-tools")]
66    #[serde(skip_serializing_if = "Option::is_none")]
67    pub tools: Option<Vec<Tool>>,
68    /// Controls how the model uses tools (MCP 2025-11-25 draft, SEP-1577)
69    /// The client MUST return an error if this field is provided but
70    /// ClientCapabilities.sampling.tools is not declared
71    /// Default is `{ mode: "auto" }`
72    #[cfg(feature = "mcp-sampling-tools")]
73    #[serde(rename = "toolChoice", skip_serializing_if = "Option::is_none")]
74    pub tool_choice: Option<ToolChoice>,
75    /// Task metadata for task-augmented sampling (MCP 2025-11-25 draft, SEP-1686)
76    ///
77    /// When present, indicates the client should execute this sampling request as a long-running
78    /// task and return a CreateTaskResult instead of the immediate CreateMessageResult.
79    /// The actual result can be retrieved later via tasks/result.
80    #[serde(skip_serializing_if = "Option::is_none")]
81    pub task: Option<crate::types::tasks::TaskMetadata>,
82    /// Optional metadata per MCP 2025-11-25 specification
83    #[serde(skip_serializing_if = "Option::is_none")]
84    pub _meta: Option<serde_json::Value>,
85}
86
87/// Model hint for selection (MCP 2025-11-25 compliant)
88#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
89pub struct ModelHint {
90    /// Model name hint (substring matching)
91    /// Examples: "claude-3-5-sonnet", "sonnet", "claude"
92    #[serde(skip_serializing_if = "Option::is_none")]
93    pub name: Option<String>,
94}
95
96impl ModelHint {
97    /// Create a new model hint with a name
98    pub fn new(name: impl Into<String>) -> Self {
99        Self {
100            name: Some(name.into()),
101        }
102    }
103}
104
105/// Model preferences for sampling (MCP 2025-11-25 compliant)
106///
107/// The spec changed from tier-based to priority-based system.
108/// Priorities are 0.0-1.0 where 0 = not important, 1 = most important.
109#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct ModelPreferences {
111    /// Optional hints for model selection (evaluated in order)
112    #[serde(skip_serializing_if = "Option::is_none")]
113    pub hints: Option<Vec<ModelHint>>,
114
115    /// Cost priority (0.0 = not important, 1.0 = most important)
116    #[serde(rename = "costPriority", skip_serializing_if = "Option::is_none")]
117    pub cost_priority: Option<f64>,
118
119    /// Speed priority (0.0 = not important, 1.0 = most important)
120    #[serde(rename = "speedPriority", skip_serializing_if = "Option::is_none")]
121    pub speed_priority: Option<f64>,
122
123    /// Intelligence priority (0.0 = not important, 1.0 = most important)
124    #[serde(
125        rename = "intelligencePriority",
126        skip_serializing_if = "Option::is_none"
127    )]
128    pub intelligence_priority: Option<f64>,
129}
130
131/// Create message result
132#[derive(Debug, Clone, Serialize, Deserialize)]
133pub struct CreateMessageResult {
134    /// The role of the message (required by MCP specification)
135    pub role: super::core::Role,
136    /// The generated message content
137    pub content: Content,
138    /// Model used for generation (required by MCP specification)
139    pub model: String,
140    /// Stop reason (if applicable)
141    ///
142    /// Uses the StopReason enum with camelCase serialization for MCP 2025-11-25 compliance.
143    #[serde(rename = "stopReason", skip_serializing_if = "Option::is_none")]
144    pub stop_reason: Option<StopReason>,
145    /// Optional metadata per MCP 2025-11-25 specification
146    #[serde(skip_serializing_if = "Option::is_none")]
147    pub _meta: Option<serde_json::Value>,
148}
149
150/// Stop reason for generation
151///
152/// Per MCP 2025-11-25 spec, these values use camelCase serialization for interoperability.
153#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
154#[serde(rename_all = "camelCase")]
155pub enum StopReason {
156    /// Generation completed naturally
157    EndTurn,
158    /// Hit maximum token limit
159    MaxTokens,
160    /// Hit a stop sequence
161    StopSequence,
162    /// Content filtering triggered
163    ContentFilter,
164    /// Tool use required
165    ToolUse,
166}
167
168/// Usage statistics for sampling
169#[derive(Debug, Clone, Serialize, Deserialize)]
170pub struct UsageStats {
171    /// Input tokens consumed
172    #[serde(rename = "inputTokens", skip_serializing_if = "Option::is_none")]
173    pub input_tokens: Option<u32>,
174    /// Output tokens generated
175    #[serde(rename = "outputTokens", skip_serializing_if = "Option::is_none")]
176    pub output_tokens: Option<u32>,
177    /// Total tokens used
178    #[serde(rename = "totalTokens", skip_serializing_if = "Option::is_none")]
179    pub total_tokens: Option<u32>,
180}
181
/// How freely the model may call tools (MCP 2025-11-25 draft, SEP-1577).
///
/// Serialized as the lowercase strings `"auto"`, `"required"` and `"none"`.
#[cfg(feature = "mcp-sampling-tools")]
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
pub enum ToolChoiceMode {
    /// The model chooses whether to call tools; this is the `Default`
    #[default]
    #[serde(rename = "auto")]
    Auto,
    /// The model MUST call at least one tool before completing
    #[serde(rename = "required")]
    Required,
    /// The model MUST NOT call any tool
    #[serde(rename = "none")]
    None,
}
195
/// Tool-selection settings attached to a sampling request (MCP 2025-11-25 draft, SEP-1577).
#[cfg(feature = "mcp-sampling-tools")]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ToolChoice {
    /// Tool-use policy for the model; see [`ToolChoiceMode`] for the meaning of each value.
    ///
    /// When `None` the `mode` key is omitted, so the struct serializes as `{}`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mode: Option<ToolChoiceMode>,
}
207
#[cfg(feature = "mcp-sampling-tools")]
impl ToolChoice {
    /// A choice with mode [`ToolChoiceMode::Auto`]: the model decides whether to use tools.
    pub fn auto() -> Self {
        let mode = ToolChoiceMode::Auto;
        Self { mode: Some(mode) }
    }

    /// A choice with mode [`ToolChoiceMode::Required`]: at least one tool call is mandatory.
    pub fn required() -> Self {
        let mode = ToolChoiceMode::Required;
        Self { mode: Some(mode) }
    }

    /// A choice with mode [`ToolChoiceMode::None`]: tool calls are forbidden.
    pub fn none() -> Self {
        let mode = ToolChoiceMode::None;
        Self { mode: Some(mode) }
    }
}
231
#[cfg(feature = "mcp-sampling-tools")]
impl Default for ToolChoice {
    /// Produces an explicit `Some(ToolChoiceMode::default())` mode, which is `Auto`.
    fn default() -> Self {
        Self {
            mode: Some(ToolChoiceMode::default()),
        }
    }
}
238
#[cfg(test)]
#[cfg(feature = "mcp-sampling-tools")]
mod tests {
    use super::*;

    /// Every `ToolChoiceMode` maps to its lowercase wire string and back.
    #[test]
    fn test_tool_choice_mode_serialization() {
        let cases = [
            (ToolChoiceMode::Auto, "\"auto\""),
            (ToolChoiceMode::Required, "\"required\""),
            (ToolChoiceMode::None, "\"none\""),
        ];
        for (mode, wire) in cases {
            assert_eq!(serde_json::to_string(&mode).unwrap(), wire);
            assert_eq!(serde_json::from_str::<ToolChoiceMode>(wire).unwrap(), mode);
        }
    }

    #[test]
    fn test_tool_choice_constructors() {
        let auto = ToolChoice::auto();
        assert_eq!(auto.mode, Some(ToolChoiceMode::Auto));

        let required = ToolChoice::required();
        assert_eq!(required.mode, Some(ToolChoiceMode::Required));

        let none = ToolChoice::none();
        assert_eq!(none.mode, Some(ToolChoiceMode::None));
    }

    #[test]
    fn test_tool_choice_default() {
        let default = ToolChoice::default();
        assert_eq!(default.mode, Some(ToolChoiceMode::Auto));
        assert_eq!(ToolChoiceMode::default(), ToolChoiceMode::Auto);
    }

    /// Pins the exact JSON shape rather than a substring match.
    #[test]
    fn test_tool_choice_serialization() {
        let choice = ToolChoice::required();
        let json = serde_json::to_string(&choice).unwrap();
        assert_eq!(json, "{\"mode\":\"required\"}");
    }

    /// A `None` mode is skipped entirely, and an empty object parses back to `None`.
    #[test]
    fn test_tool_choice_without_mode_round_trip() {
        let choice = ToolChoice { mode: None };
        assert_eq!(serde_json::to_string(&choice).unwrap(), "{}");
        let parsed: ToolChoice = serde_json::from_str("{}").unwrap();
        assert_eq!(parsed, choice);
    }

    #[test]
    fn test_tool_choice_deserialization() {
        let parsed: ToolChoice = serde_json::from_str("{\"mode\":\"none\"}").unwrap();
        assert_eq!(parsed, ToolChoice::none());
    }

    /// Modes outside the defined set are rejected instead of silently mapped.
    #[test]
    fn test_tool_choice_rejects_unknown_mode() {
        assert!(serde_json::from_str::<ToolChoice>("{\"mode\":\"any\"}").is_err());
    }

    #[test]
    fn test_stop_reason_wire_format() {
        assert_eq!(
            serde_json::to_string(&StopReason::EndTurn).unwrap(),
            "\"endTurn\""
        );
        assert_eq!(
            serde_json::from_str::<StopReason>("\"toolUse\"").unwrap(),
            StopReason::ToolUse
        );
    }

    #[test]
    fn test_include_context_wire_format() {
        assert_eq!(
            serde_json::to_string(&IncludeContext::ThisServer).unwrap(),
            "\"thisServer\""
        );
        assert_eq!(
            serde_json::from_str::<IncludeContext>("\"allServers\"").unwrap(),
            IncludeContext::AllServers
        );
    }
}