turbomcp_protocol/types/sampling.rs

//! LLM sampling types
//!
//! This module contains types for server-initiated LLM sampling:
//! - MCP 2025-11-25: basic text-based sampling
//! - MCP 2025-11-25 draft (SEP-1577): adds tool-calling support
6
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10use super::{content::Content, core::Role};
11
12#[cfg(feature = "mcp-sampling-tools")]
13use super::tools::Tool;
14
15/// Include context options for sampling
16#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
17#[serde(rename_all = "camelCase")]
18pub enum IncludeContext {
19    /// No context
20    None,
21    /// This server only
22    ThisServer,
23    /// All servers
24    AllServers,
25}
26
27/// Sampling message structure
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct SamplingMessage {
30    /// Message role
31    pub role: Role,
32    /// Message content
33    pub content: Content,
34    /// Optional message metadata
35    #[serde(skip_serializing_if = "Option::is_none")]
36    pub metadata: Option<HashMap<String, serde_json::Value>>,
37}
38
39/// Create message request (for LLM sampling)
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct CreateMessageRequest {
42    /// Messages to include in the sampling request
43    pub messages: Vec<SamplingMessage>,
44    /// Model preferences (optional)
45    #[serde(skip_serializing_if = "Option::is_none")]
46    pub model_preferences: Option<ModelPreferences>,
47    /// System prompt (optional)
48    #[serde(rename = "systemPrompt", skip_serializing_if = "Option::is_none")]
49    pub system_prompt: Option<String>,
50    /// Include context from other servers
51    #[serde(rename = "includeContext", skip_serializing_if = "Option::is_none")]
52    pub include_context: Option<IncludeContext>,
53    /// Temperature for sampling (0.0 to 2.0)
54    #[serde(skip_serializing_if = "Option::is_none")]
55    pub temperature: Option<f64>,
56    /// Maximum number of tokens to generate (required by MCP spec)
57    #[serde(rename = "maxTokens")]
58    pub max_tokens: u32,
59    /// Stop sequences
60    #[serde(rename = "stopSequences", skip_serializing_if = "Option::is_none")]
61    pub stop_sequences: Option<Vec<String>>,
62    /// Tools that the model may use during generation (MCP 2025-11-25 draft, SEP-1577)
63    /// The client MUST return an error if this field is provided but
64    /// ClientCapabilities.sampling.tools is not declared
65    #[cfg(feature = "mcp-sampling-tools")]
66    #[serde(skip_serializing_if = "Option::is_none")]
67    pub tools: Option<Vec<Tool>>,
68    /// Controls how the model uses tools (MCP 2025-11-25 draft, SEP-1577)
69    /// The client MUST return an error if this field is provided but
70    /// ClientCapabilities.sampling.tools is not declared
71    /// Default is `{ mode: "auto" }`
72    #[cfg(feature = "mcp-sampling-tools")]
73    #[serde(rename = "toolChoice", skip_serializing_if = "Option::is_none")]
74    pub tool_choice: Option<ToolChoice>,
75    /// Task metadata for task-augmented sampling (MCP 2025-11-25 draft, SEP-1686)
76    ///
77    /// When present, indicates the client should execute this sampling request as a long-running
78    /// task and return a CreateTaskResult instead of the immediate CreateMessageResult.
79    /// The actual result can be retrieved later via tasks/result.
80    #[cfg(feature = "mcp-tasks")]
81    #[serde(skip_serializing_if = "Option::is_none")]
82    pub task: Option<crate::types::tasks::TaskMetadata>,
83    /// Optional metadata per MCP 2025-11-25 specification
84    #[serde(skip_serializing_if = "Option::is_none")]
85    pub _meta: Option<serde_json::Value>,
86}
87
88/// Model hint for selection (MCP 2025-11-25 compliant)
89#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
90pub struct ModelHint {
91    /// Model name hint (substring matching)
92    /// Examples: "claude-3-5-sonnet", "sonnet", "claude"
93    #[serde(skip_serializing_if = "Option::is_none")]
94    pub name: Option<String>,
95}
96
97impl ModelHint {
98    /// Create a new model hint with a name
99    pub fn new(name: impl Into<String>) -> Self {
100        Self {
101            name: Some(name.into()),
102        }
103    }
104}
105
106/// Model preferences for sampling (MCP 2025-11-25 compliant)
107///
108/// The spec changed from tier-based to priority-based system.
109/// Priorities are 0.0-1.0 where 0 = not important, 1 = most important.
110#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct ModelPreferences {
112    /// Optional hints for model selection (evaluated in order)
113    #[serde(skip_serializing_if = "Option::is_none")]
114    pub hints: Option<Vec<ModelHint>>,
115
116    /// Cost priority (0.0 = not important, 1.0 = most important)
117    #[serde(rename = "costPriority", skip_serializing_if = "Option::is_none")]
118    pub cost_priority: Option<f64>,
119
120    /// Speed priority (0.0 = not important, 1.0 = most important)
121    #[serde(rename = "speedPriority", skip_serializing_if = "Option::is_none")]
122    pub speed_priority: Option<f64>,
123
124    /// Intelligence priority (0.0 = not important, 1.0 = most important)
125    #[serde(
126        rename = "intelligencePriority",
127        skip_serializing_if = "Option::is_none"
128    )]
129    pub intelligence_priority: Option<f64>,
130}
131
132/// Create message result
133#[derive(Debug, Clone, Serialize, Deserialize)]
134pub struct CreateMessageResult {
135    /// The role of the message (required by MCP specification)
136    pub role: super::core::Role,
137    /// The generated message content
138    pub content: Content,
139    /// Model used for generation (required by MCP specification)
140    pub model: String,
141    /// Stop reason (if applicable)
142    ///
143    /// Uses the StopReason enum with camelCase serialization for MCP 2025-11-25 compliance.
144    #[serde(rename = "stopReason", skip_serializing_if = "Option::is_none")]
145    pub stop_reason: Option<StopReason>,
146    /// Optional metadata per MCP 2025-11-25 specification
147    #[serde(skip_serializing_if = "Option::is_none")]
148    pub _meta: Option<serde_json::Value>,
149}
150
151/// Stop reason for generation
152///
153/// Per MCP 2025-11-25 spec, these values use camelCase serialization for interoperability.
154#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
155#[serde(rename_all = "camelCase")]
156pub enum StopReason {
157    /// Generation completed naturally
158    EndTurn,
159    /// Hit maximum token limit
160    MaxTokens,
161    /// Hit a stop sequence
162    StopSequence,
163    /// Content filtering triggered
164    ContentFilter,
165    /// Tool use required
166    ToolUse,
167}
168
169/// Usage statistics for sampling
170#[derive(Debug, Clone, Serialize, Deserialize)]
171pub struct UsageStats {
172    /// Input tokens consumed
173    #[serde(rename = "inputTokens", skip_serializing_if = "Option::is_none")]
174    pub input_tokens: Option<u32>,
175    /// Output tokens generated
176    #[serde(rename = "outputTokens", skip_serializing_if = "Option::is_none")]
177    pub output_tokens: Option<u32>,
178    /// Total tokens used
179    #[serde(rename = "totalTokens", skip_serializing_if = "Option::is_none")]
180    pub total_tokens: Option<u32>,
181}
182
183/// Tool choice mode (MCP 2025-11-25 draft, SEP-1577)
184#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
185#[serde(rename_all = "lowercase")]
186#[cfg(feature = "mcp-sampling-tools")]
187pub enum ToolChoiceMode {
188    /// Model decides whether to use tools (default)
189    #[default]
190    Auto,
191    /// Model MUST use at least one tool before completing
192    Required,
193    /// Model MUST NOT use any tools
194    None,
195}
196
/// Tool-selection policy for a sampling request (MCP 2025-11-25 draft, SEP-1577).
#[cfg(feature = "mcp-sampling-tools")]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolChoice {
    /// Tool-use ability granted to the model:
    /// - "auto": Model decides whether to use tools (default)
    /// - "required": Model MUST use at least one tool before completing
    /// - "none": Model MUST NOT use any tools
    ///
    /// Omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mode: Option<ToolChoiceMode>,
}

#[cfg(feature = "mcp-sampling-tools")]
impl ToolChoice {
    /// Shared constructor: every named constructor sets `mode` explicitly.
    fn with_mode(mode: ToolChoiceMode) -> Self {
        Self { mode: Some(mode) }
    }

    /// Policy that lets the model decide whether to call tools.
    pub fn auto() -> Self {
        Self::with_mode(ToolChoiceMode::Auto)
    }

    /// Policy that requires the model to call at least one tool.
    pub fn required() -> Self {
        Self::with_mode(ToolChoiceMode::Required)
    }

    /// Policy that forbids the model from calling tools.
    pub fn none() -> Self {
        Self::with_mode(ToolChoiceMode::None)
    }
}

#[cfg(feature = "mcp-sampling-tools")]
impl Default for ToolChoice {
    /// Same as [`ToolChoice::auto`]: `mode` is `Some(Auto)`, not `None`.
    fn default() -> Self {
        Self::auto()
    }
}
239
#[cfg(all(test, feature = "mcp-sampling-tools"))]
mod tests {
    use super::*;

    /// Each mode serializes to its lowercase JSON string.
    #[test]
    fn test_tool_choice_mode_serialization() {
        let cases = [
            (ToolChoiceMode::Auto, "\"auto\""),
            (ToolChoiceMode::Required, "\"required\""),
            (ToolChoiceMode::None, "\"none\""),
        ];
        for (mode, expected) in cases {
            let encoded = serde_json::to_string(&mode).unwrap();
            assert_eq!(encoded, expected);
        }
    }

    /// Named constructors set the matching mode.
    #[test]
    fn test_tool_choice_constructors() {
        let built = [
            (ToolChoice::auto(), ToolChoiceMode::Auto),
            (ToolChoice::required(), ToolChoiceMode::Required),
            (ToolChoice::none(), ToolChoiceMode::None),
        ];
        for (choice, expected) in built {
            assert_eq!(choice.mode, Some(expected));
        }
    }

    /// `Default` is auto mode, not an absent mode.
    #[test]
    fn test_tool_choice_default() {
        assert_eq!(ToolChoice::default().mode, Some(ToolChoiceMode::Auto));
    }

    /// The mode appears under the `mode` key in the serialized object.
    #[test]
    fn test_tool_choice_serialization() {
        let encoded = serde_json::to_string(&ToolChoice::required()).unwrap();
        assert!(encoded.contains("\"mode\":\"required\""));
    }
}