1use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10use super::{content::Content, core::Role};
11
12#[cfg(feature = "mcp-sampling-tools")]
13use super::tools::Tool;
14
15#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
17#[serde(rename_all = "camelCase")]
18pub enum IncludeContext {
19 None,
21 ThisServer,
23 AllServers,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct SamplingMessage {
30 pub role: Role,
32 pub content: Content,
34 #[serde(skip_serializing_if = "Option::is_none")]
36 pub metadata: Option<HashMap<String, serde_json::Value>>,
37}
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct CreateMessageRequest {
42 pub messages: Vec<SamplingMessage>,
44 #[serde(skip_serializing_if = "Option::is_none")]
46 pub model_preferences: Option<ModelPreferences>,
47 #[serde(rename = "systemPrompt", skip_serializing_if = "Option::is_none")]
49 pub system_prompt: Option<String>,
50 #[serde(rename = "includeContext", skip_serializing_if = "Option::is_none")]
52 pub include_context: Option<IncludeContext>,
53 #[serde(skip_serializing_if = "Option::is_none")]
55 pub temperature: Option<f64>,
56 #[serde(rename = "maxTokens")]
58 pub max_tokens: u32,
59 #[serde(rename = "stopSequences", skip_serializing_if = "Option::is_none")]
61 pub stop_sequences: Option<Vec<String>>,
62 #[cfg(feature = "mcp-sampling-tools")]
66 #[serde(skip_serializing_if = "Option::is_none")]
67 pub tools: Option<Vec<Tool>>,
68 #[cfg(feature = "mcp-sampling-tools")]
73 #[serde(rename = "toolChoice", skip_serializing_if = "Option::is_none")]
74 pub tool_choice: Option<ToolChoice>,
75 #[cfg(feature = "mcp-tasks")]
81 #[serde(skip_serializing_if = "Option::is_none")]
82 pub task: Option<crate::types::tasks::TaskMetadata>,
83 #[serde(skip_serializing_if = "Option::is_none")]
85 pub _meta: Option<serde_json::Value>,
86}
87
88#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
90pub struct ModelHint {
91 #[serde(skip_serializing_if = "Option::is_none")]
94 pub name: Option<String>,
95}
96
97impl ModelHint {
98 pub fn new(name: impl Into<String>) -> Self {
100 Self {
101 name: Some(name.into()),
102 }
103 }
104}
105
106#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct ModelPreferences {
112 #[serde(skip_serializing_if = "Option::is_none")]
114 pub hints: Option<Vec<ModelHint>>,
115
116 #[serde(rename = "costPriority", skip_serializing_if = "Option::is_none")]
118 pub cost_priority: Option<f64>,
119
120 #[serde(rename = "speedPriority", skip_serializing_if = "Option::is_none")]
122 pub speed_priority: Option<f64>,
123
124 #[serde(
126 rename = "intelligencePriority",
127 skip_serializing_if = "Option::is_none"
128 )]
129 pub intelligence_priority: Option<f64>,
130}
131
132#[derive(Debug, Clone, Serialize, Deserialize)]
134pub struct CreateMessageResult {
135 pub role: super::core::Role,
137 pub content: Content,
139 pub model: String,
141 #[serde(rename = "stopReason", skip_serializing_if = "Option::is_none")]
145 pub stop_reason: Option<StopReason>,
146 #[serde(skip_serializing_if = "Option::is_none")]
148 pub _meta: Option<serde_json::Value>,
149}
150
151#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
155#[serde(rename_all = "camelCase")]
156pub enum StopReason {
157 EndTurn,
159 MaxTokens,
161 StopSequence,
163 ContentFilter,
165 ToolUse,
167}
168
169#[derive(Debug, Clone, Serialize, Deserialize)]
171pub struct UsageStats {
172 #[serde(rename = "inputTokens", skip_serializing_if = "Option::is_none")]
174 pub input_tokens: Option<u32>,
175 #[serde(rename = "outputTokens", skip_serializing_if = "Option::is_none")]
177 pub output_tokens: Option<u32>,
178 #[serde(rename = "totalTokens", skip_serializing_if = "Option::is_none")]
180 pub total_tokens: Option<u32>,
181}
182
/// Mode carried by a [`ToolChoice`].
///
/// Variants serialize to lowercase strings (`"auto"`, `"required"`,
/// `"none"`). `Auto` is the `Default` value. Only compiled with the
/// `mcp-sampling-tools` feature.
#[cfg(feature = "mcp-sampling-tools")]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ToolChoiceMode {
    /// Wire value `"auto"`; the default mode.
    #[default]
    #[serde(rename = "auto")]
    Auto,
    /// Wire value `"required"`.
    #[serde(rename = "required")]
    Required,
    /// Wire value `"none"`.
    #[serde(rename = "none")]
    None,
}
196
/// Tool-choice setting attached to a sampling request.
///
/// Wraps an optional [`ToolChoiceMode`]; the `mode` key is omitted from the
/// serialized form when it is `None`. Only compiled with the
/// `mcp-sampling-tools` feature.
#[cfg(feature = "mcp-sampling-tools")]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolChoice {
    /// Selected mode, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mode: Option<ToolChoiceMode>,
}
208
#[cfg(feature = "mcp-sampling-tools")]
impl ToolChoice {
    /// Returns a choice with `mode` set to [`ToolChoiceMode::Auto`].
    pub fn auto() -> Self {
        let mode = Some(ToolChoiceMode::Auto);
        Self { mode }
    }

    /// Returns a choice with `mode` set to [`ToolChoiceMode::Required`].
    pub fn required() -> Self {
        let mode = Some(ToolChoiceMode::Required);
        Self { mode }
    }

    /// Returns a choice with `mode` set to [`ToolChoiceMode::None`].
    ///
    /// Note that this is `Some(ToolChoiceMode::None)`, not an absent mode.
    pub fn none() -> Self {
        let mode = Some(ToolChoiceMode::None);
        Self { mode }
    }
}
232
#[cfg(feature = "mcp-sampling-tools")]
impl Default for ToolChoice {
    /// The default choice has its mode explicitly set to
    /// [`ToolChoiceMode::Auto`] (not left as `None`).
    fn default() -> Self {
        Self {
            mode: Some(ToolChoiceMode::Auto),
        }
    }
}
239
// Unit tests for the tool-choice types; built only for test runs with the
// `mcp-sampling-tools` feature enabled.
#[cfg(all(test, feature = "mcp-sampling-tools"))]
mod tests {
    use super::*;

    /// Every mode serializes to its lowercase JSON string.
    #[test]
    fn test_tool_choice_mode_serialization() {
        let expectations = [
            (ToolChoiceMode::Auto, "\"auto\""),
            (ToolChoiceMode::Required, "\"required\""),
            (ToolChoiceMode::None, "\"none\""),
        ];

        for &(mode, wire) in &expectations {
            let encoded = serde_json::to_string(&mode).unwrap();
            assert_eq!(encoded, wire);
        }
    }

    /// Each constructor sets the matching mode.
    #[test]
    fn test_tool_choice_constructors() {
        let cases = [
            (ToolChoice::auto(), ToolChoiceMode::Auto),
            (ToolChoice::required(), ToolChoiceMode::Required),
            (ToolChoice::none(), ToolChoiceMode::None),
        ];

        for (choice, expected) in &cases {
            assert_eq!(choice.mode, Some(*expected));
        }
    }

    /// `Default` yields the auto mode.
    #[test]
    fn test_tool_choice_default() {
        let fallback = ToolChoice::default();
        assert_eq!(fallback.mode, Some(ToolChoiceMode::Auto));
    }

    /// A serialized choice contains the `mode` key with the lowercase value.
    #[test]
    fn test_tool_choice_serialization() {
        let encoded = serde_json::to_string(&ToolChoice::required()).unwrap();
        assert!(encoded.contains("\"mode\":\"required\""));
    }
}
285}