1use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10use super::{content::Content, core::Role};
11
12#[cfg(feature = "mcp-sampling-tools")]
13use super::tools::Tool;
14
15#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
17#[serde(rename_all = "camelCase")]
18pub enum IncludeContext {
19 None,
21 ThisServer,
23 AllServers,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct SamplingMessage {
30 pub role: Role,
32 pub content: Content,
34 #[serde(skip_serializing_if = "Option::is_none")]
36 pub metadata: Option<HashMap<String, serde_json::Value>>,
37}
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct CreateMessageRequest {
42 pub messages: Vec<SamplingMessage>,
44 #[serde(skip_serializing_if = "Option::is_none")]
46 pub model_preferences: Option<ModelPreferences>,
47 #[serde(rename = "systemPrompt", skip_serializing_if = "Option::is_none")]
49 pub system_prompt: Option<String>,
50 #[serde(rename = "includeContext", skip_serializing_if = "Option::is_none")]
52 pub include_context: Option<IncludeContext>,
53 #[serde(skip_serializing_if = "Option::is_none")]
55 pub temperature: Option<f64>,
56 #[serde(rename = "maxTokens")]
58 pub max_tokens: u32,
59 #[serde(rename = "stopSequences", skip_serializing_if = "Option::is_none")]
61 pub stop_sequences: Option<Vec<String>>,
62 #[cfg(feature = "mcp-sampling-tools")]
66 #[serde(skip_serializing_if = "Option::is_none")]
67 pub tools: Option<Vec<Tool>>,
68 #[cfg(feature = "mcp-sampling-tools")]
73 #[serde(rename = "toolChoice", skip_serializing_if = "Option::is_none")]
74 pub tool_choice: Option<ToolChoice>,
75 #[serde(skip_serializing_if = "Option::is_none")]
81 pub task: Option<crate::types::tasks::TaskMetadata>,
82 #[serde(skip_serializing_if = "Option::is_none")]
84 pub _meta: Option<serde_json::Value>,
85}
86
87#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
89pub struct ModelHint {
90 #[serde(skip_serializing_if = "Option::is_none")]
93 pub name: Option<String>,
94}
95
96impl ModelHint {
97 pub fn new(name: impl Into<String>) -> Self {
99 Self {
100 name: Some(name.into()),
101 }
102 }
103}
104
105#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct ModelPreferences {
111 #[serde(skip_serializing_if = "Option::is_none")]
113 pub hints: Option<Vec<ModelHint>>,
114
115 #[serde(rename = "costPriority", skip_serializing_if = "Option::is_none")]
117 pub cost_priority: Option<f64>,
118
119 #[serde(rename = "speedPriority", skip_serializing_if = "Option::is_none")]
121 pub speed_priority: Option<f64>,
122
123 #[serde(
125 rename = "intelligencePriority",
126 skip_serializing_if = "Option::is_none"
127 )]
128 pub intelligence_priority: Option<f64>,
129}
130
131#[derive(Debug, Clone, Serialize, Deserialize)]
133pub struct CreateMessageResult {
134 pub role: super::core::Role,
136 pub content: Content,
138 pub model: String,
140 #[serde(rename = "stopReason", skip_serializing_if = "Option::is_none")]
144 pub stop_reason: Option<StopReason>,
145 #[serde(skip_serializing_if = "Option::is_none")]
147 pub _meta: Option<serde_json::Value>,
148}
149
150#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
154#[serde(rename_all = "camelCase")]
155pub enum StopReason {
156 EndTurn,
158 MaxTokens,
160 StopSequence,
162 ContentFilter,
164 ToolUse,
166}
167
168#[derive(Debug, Clone, Serialize, Deserialize)]
170pub struct UsageStats {
171 #[serde(rename = "inputTokens", skip_serializing_if = "Option::is_none")]
173 pub input_tokens: Option<u32>,
174 #[serde(rename = "outputTokens", skip_serializing_if = "Option::is_none")]
176 pub output_tokens: Option<u32>,
177 #[serde(rename = "totalTokens", skip_serializing_if = "Option::is_none")]
179 pub total_tokens: Option<u32>,
180}
181
/// Mode value of a [`ToolChoice`].
///
/// Serialized in lowercase: `"auto"`, `"required"`, `"none"`. `Auto` is the `Default`.
#[cfg(feature = "mcp-sampling-tools")]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoiceMode {
    /// `"auto"`; the default mode.
    #[default]
    Auto,
    /// `"required"`
    Required,
    /// `"none"`
    None,
}
195
/// The `toolChoice` object of a sampling request.
#[cfg(feature = "mcp-sampling-tools")]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolChoice {
    /// Selected mode. The `mode` key is omitted from the JSON when this is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mode: Option<ToolChoiceMode>,
}
207
#[cfg(feature = "mcp-sampling-tools")]
impl ToolChoice {
    /// Builds a choice with `mode` set to `Some(mode)`. Shared by the public constructors.
    fn from_mode(mode: ToolChoiceMode) -> Self {
        Self { mode: Some(mode) }
    }

    /// Choice with mode [`ToolChoiceMode::Auto`].
    pub fn auto() -> Self {
        Self::from_mode(ToolChoiceMode::Auto)
    }

    /// Choice with mode [`ToolChoiceMode::Required`].
    pub fn required() -> Self {
        Self::from_mode(ToolChoiceMode::Required)
    }

    /// Choice with mode [`ToolChoiceMode::None`].
    pub fn none() -> Self {
        Self::from_mode(ToolChoiceMode::None)
    }
}
231
#[cfg(feature = "mcp-sampling-tools")]
impl Default for ToolChoice {
    /// Uses the default [`ToolChoiceMode`] (`Auto`), matching [`ToolChoice::auto`].
    fn default() -> Self {
        Self {
            mode: Some(ToolChoiceMode::default()),
        }
    }
}
238
#[cfg(test)]
mod tests {
    use super::*;

    // Wire-format tests for types that exist regardless of the `mcp-sampling-tools`
    // feature. Previously the whole module was feature-gated, so these never ran.

    #[test]
    fn test_include_context_round_trip() {
        let cases = [
            (IncludeContext::None, "\"none\""),
            (IncludeContext::ThisServer, "\"thisServer\""),
            (IncludeContext::AllServers, "\"allServers\""),
        ];
        for (context, json) in cases {
            assert_eq!(serde_json::to_string(&context).unwrap(), json);
            assert_eq!(serde_json::from_str::<IncludeContext>(json).unwrap(), context);
        }
    }

    #[test]
    fn test_stop_reason_round_trip() {
        let cases = [
            (StopReason::EndTurn, "\"endTurn\""),
            (StopReason::MaxTokens, "\"maxTokens\""),
            (StopReason::StopSequence, "\"stopSequence\""),
            (StopReason::ContentFilter, "\"contentFilter\""),
            (StopReason::ToolUse, "\"toolUse\""),
        ];
        for (reason, json) in cases {
            assert_eq!(serde_json::to_string(&reason).unwrap(), json);
            assert_eq!(serde_json::from_str::<StopReason>(json).unwrap(), reason);
        }
    }

    #[test]
    fn test_model_preferences_omits_unset_fields() {
        let prefs = ModelPreferences {
            hints: Some(vec![ModelHint::new("claude")]),
            cost_priority: None,
            speed_priority: Some(0.5),
            intelligence_priority: None,
        };
        assert_eq!(
            serde_json::to_value(&prefs).unwrap(),
            serde_json::json!({ "hints": [{ "name": "claude" }], "speedPriority": 0.5 })
        );
    }

    // Tool-related tests only compile when the feature that defines those types is on.

    #[test]
    #[cfg(feature = "mcp-sampling-tools")]
    fn test_tool_choice_mode_serialization() {
        assert_eq!(
            serde_json::to_string(&ToolChoiceMode::Auto).unwrap(),
            "\"auto\""
        );
        assert_eq!(
            serde_json::to_string(&ToolChoiceMode::Required).unwrap(),
            "\"required\""
        );
        assert_eq!(
            serde_json::to_string(&ToolChoiceMode::None).unwrap(),
            "\"none\""
        );
    }

    #[test]
    #[cfg(feature = "mcp-sampling-tools")]
    fn test_tool_choice_constructors() {
        assert_eq!(ToolChoice::auto().mode, Some(ToolChoiceMode::Auto));
        assert_eq!(ToolChoice::required().mode, Some(ToolChoiceMode::Required));
        assert_eq!(ToolChoice::none().mode, Some(ToolChoiceMode::None));
    }

    #[test]
    #[cfg(feature = "mcp-sampling-tools")]
    fn test_tool_choice_default() {
        assert_eq!(ToolChoice::default().mode, Some(ToolChoiceMode::Auto));
    }

    #[test]
    #[cfg(feature = "mcp-sampling-tools")]
    fn test_tool_choice_serialization() {
        let json = serde_json::to_string(&ToolChoice::required()).unwrap();
        assert_eq!(json, r#"{"mode":"required"}"#);
    }

    #[test]
    #[cfg(feature = "mcp-sampling-tools")]
    fn test_tool_choice_deserialization() {
        let parsed: ToolChoice = serde_json::from_str(r#"{"mode":"none"}"#).unwrap();
        assert_eq!(parsed, ToolChoice::none());

        // A missing `mode` key yields `None` and is omitted again on serialization.
        let empty: ToolChoice = serde_json::from_str("{}").unwrap();
        assert_eq!(empty.mode, None);
        assert_eq!(serde_json::to_string(&empty).unwrap(), "{}");
    }
}