1use serde_json::Value;
7use std::collections::HashMap;
8
9use turul_mcp_protocol::prompts::ContentBlock;
11use turul_mcp_protocol::sampling::{
12 CreateMessageParams, CreateMessageRequest, ModelHint, ModelPreferences, Role, SamplingMessage,
13};
14
15pub struct MessageBuilder {
17 messages: Vec<SamplingMessage>,
18 model_preferences: Option<ModelPreferences>,
19 system_prompt: Option<String>,
20 include_context: Option<String>,
21 temperature: Option<f64>,
22 max_tokens: u32,
23 stop_sequences: Option<Vec<String>>,
24 metadata: Option<Value>,
25 meta: Option<HashMap<String, Value>>,
26}
27
28impl MessageBuilder {
29 pub fn new() -> Self {
31 Self {
32 messages: Vec::new(),
33 model_preferences: None,
34 system_prompt: None,
35 include_context: None,
36 temperature: None,
37 max_tokens: 1000, stop_sequences: None,
39 metadata: None,
40 meta: None,
41 }
42 }
43
44 pub fn max_tokens(mut self, max_tokens: u32) -> Self {
46 self.max_tokens = max_tokens;
47 self
48 }
49
50 pub fn message(mut self, message: SamplingMessage) -> Self {
52 self.messages.push(message);
53 self
54 }
55
56 pub fn system(self, content: impl Into<String>) -> Self {
58 self.system_prompt(content)
59 }
60
61 pub fn user_text(mut self, content: impl Into<String>) -> Self {
63 self.messages.push(SamplingMessage {
64 role: Role::User,
65 content: ContentBlock::text(content),
66 });
67 self
68 }
69
70 pub fn user_image(mut self, data: impl Into<String>, mime_type: impl Into<String>) -> Self {
72 self.messages.push(SamplingMessage {
73 role: Role::User,
74 content: ContentBlock::image(data, mime_type),
75 });
76 self
77 }
78
79 pub fn assistant_text(mut self, content: impl Into<String>) -> Self {
81 self.messages.push(SamplingMessage {
82 role: Role::Assistant,
83 content: ContentBlock::text(content),
84 });
85 self
86 }
87
88 pub fn model_preferences(mut self, preferences: ModelPreferences) -> Self {
90 self.model_preferences = Some(preferences);
91 self
92 }
93
94 pub fn with_model_preferences<F>(mut self, f: F) -> Self
96 where
97 F: FnOnce(ModelPreferencesBuilder) -> ModelPreferencesBuilder,
98 {
99 let builder = ModelPreferencesBuilder::new();
100 let preferences = f(builder).build();
101 self.model_preferences = Some(preferences);
102 self
103 }
104
105 pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
107 self.system_prompt = Some(prompt.into());
108 self
109 }
110
111 pub fn include_context(mut self, context: impl Into<String>) -> Self {
113 self.include_context = Some(context.into());
114 self
115 }
116
117 pub fn temperature(mut self, temperature: f64) -> Self {
119 self.temperature = Some(temperature.clamp(0.0, 2.0));
120 self
121 }
122
123 pub fn stop_sequences(mut self, sequences: Vec<String>) -> Self {
125 self.stop_sequences = Some(sequences);
126 self
127 }
128
129 pub fn stop_sequence(mut self, sequence: impl Into<String>) -> Self {
131 if let Some(ref mut sequences) = self.stop_sequences {
132 sequences.push(sequence.into());
133 } else {
134 self.stop_sequences = Some(vec![sequence.into()]);
135 }
136 self
137 }
138
139 pub fn metadata(mut self, metadata: Value) -> Self {
141 self.metadata = Some(metadata);
142 self
143 }
144
145 pub fn meta(mut self, meta: HashMap<String, Value>) -> Self {
147 self.meta = Some(meta);
148 self
149 }
150
151 pub fn build_params(self) -> CreateMessageParams {
153 let mut params = CreateMessageParams::new(self.messages, self.max_tokens);
154
155 if let Some(preferences) = self.model_preferences {
156 params = params.with_model_preferences(preferences);
157 }
158 if let Some(prompt) = self.system_prompt {
159 params = params.with_system_prompt(prompt);
160 }
161 if let Some(temp) = self.temperature {
162 params = params.with_temperature(temp);
163 }
164 if let Some(sequences) = self.stop_sequences {
165 params = params.with_stop_sequences(sequences);
166 }
167 if let Some(meta) = self.meta {
168 params = params.with_meta(meta);
169 }
170
171 params.include_context = self.include_context;
173 params.metadata = self.metadata;
174
175 params
176 }
177
178 pub fn build_request(self) -> CreateMessageRequest {
180 CreateMessageRequest {
181 method: "sampling/createMessage".to_string(),
182 params: self.build_params(),
183 }
184 }
185}
186
187impl Default for MessageBuilder {
188 fn default() -> Self {
189 Self::new()
190 }
191}
192
193pub struct ModelPreferencesBuilder {
195 hints: Vec<ModelHint>,
196 cost_priority: Option<f64>,
197 speed_priority: Option<f64>,
198 intelligence_priority: Option<f64>,
199}
200
201impl ModelPreferencesBuilder {
202 pub fn new() -> Self {
203 Self {
204 hints: Vec::new(),
205 cost_priority: None,
206 speed_priority: None,
207 intelligence_priority: None,
208 }
209 }
210
211 pub fn hint(mut self, hint: ModelHint) -> Self {
213 self.hints.push(hint);
214 self
215 }
216
217 pub fn prefer_claude_sonnet(self) -> Self {
219 self.hint(ModelHint::new("claude-3-5-sonnet-20241022"))
220 }
221
222 pub fn prefer_claude_haiku(self) -> Self {
224 self.hint(ModelHint::new("claude-3-5-haiku-20241022"))
225 }
226
227 pub fn prefer_gpt4o(self) -> Self {
229 self.hint(ModelHint::new("gpt-4o"))
230 }
231
232 pub fn prefer_gpt4o_mini(self) -> Self {
234 self.hint(ModelHint::new("gpt-4o-mini"))
235 }
236
237 pub fn prefer_fast(self) -> Self {
239 self.hint(ModelHint::new("claude-3-5-haiku-20241022"))
240 .hint(ModelHint::new("gpt-4o-mini"))
241 }
242
243 pub fn prefer_quality(self) -> Self {
245 self.hint(ModelHint::new("claude-3-5-sonnet-20241022"))
246 .hint(ModelHint::new("gpt-4o"))
247 }
248
249 pub fn cost_priority(mut self, priority: f64) -> Self {
251 self.cost_priority = Some(priority.clamp(0.0, 1.0));
252 self
253 }
254
255 pub fn speed_priority(mut self, priority: f64) -> Self {
257 self.speed_priority = Some(priority.clamp(0.0, 1.0));
258 self
259 }
260
261 pub fn intelligence_priority(mut self, priority: f64) -> Self {
263 self.intelligence_priority = Some(priority.clamp(0.0, 1.0));
264 self
265 }
266
267 pub fn build(self) -> ModelPreferences {
269 ModelPreferences {
270 hints: if self.hints.is_empty() {
271 None
272 } else {
273 Some(self.hints)
274 },
275 cost_priority: self.cost_priority,
276 speed_priority: self.speed_priority,
277 intelligence_priority: self.intelligence_priority,
278 }
279 }
280}
281
282impl Default for ModelPreferencesBuilder {
283 fn default() -> Self {
284 Self::new()
285 }
286}
287
288pub trait SamplingMessageExt {
290 fn user_text(content: impl Into<String>) -> SamplingMessage;
292 fn user_image(data: impl Into<String>, mime_type: impl Into<String>) -> SamplingMessage;
294 fn assistant_text(content: impl Into<String>) -> SamplingMessage;
296}
297
298impl SamplingMessageExt for SamplingMessage {
299 fn user_text(content: impl Into<String>) -> Self {
300 Self {
301 role: Role::User,
302 content: ContentBlock::text(content),
303 }
304 }
305
306 fn user_image(data: impl Into<String>, mime_type: impl Into<String>) -> Self {
307 Self {
308 role: Role::User,
309 content: ContentBlock::image(data, mime_type),
310 }
311 }
312
313 fn assistant_text(content: impl Into<String>) -> Self {
314 Self {
315 role: Role::Assistant,
316 content: ContentBlock::text(content),
317 }
318 }
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Core builder fields land on the params; message order and roles are preserved.
    #[test]
    fn test_message_builder_basic() {
        let params = MessageBuilder::new()
            .max_tokens(2000)
            .system("You are a helpful assistant.")
            .user_text("Hello, how are you?")
            .assistant_text("I'm doing well, thank you!")
            .temperature(0.7)
            .build_params();

        assert_eq!(params.messages.len(), 2);
        assert_eq!(params.max_tokens, 2000);
        assert_eq!(params.temperature, Some(0.7));
        assert_eq!(
            params.system_prompt,
            Some("You are a helpful assistant.".to_string())
        );

        assert_eq!(params.messages[0].role, Role::User);
        assert_eq!(params.messages[1].role, Role::Assistant);
        if let ContentBlock::Text { text, .. } = &params.messages[0].content {
            assert_eq!(text, "Hello, how are you?");
        } else {
            panic!("Expected text content");
        }
    }

    // The closure-based preferences builder forwards hints and priorities.
    #[test]
    fn test_message_builder_model_preferences() {
        let params = MessageBuilder::new()
            .user_text("Test message")
            .with_model_preferences(|prefs| {
                prefs
                    .prefer_claude_sonnet()
                    .cost_priority(0.8)
                    .speed_priority(0.6)
                    .intelligence_priority(0.9)
            })
            .build_params();

        let preferences = params
            .model_preferences
            .expect("Expected model preferences");
        assert_eq!(preferences.hints.as_ref().unwrap().len(), 1);
        assert_eq!(
            preferences.hints.as_ref().unwrap()[0],
            ModelHint::new("claude-3-5-sonnet-20241022")
        );
        assert_eq!(preferences.cost_priority, Some(0.8));
        assert_eq!(preferences.speed_priority, Some(0.6));
        assert_eq!(preferences.intelligence_priority, Some(0.9));
    }

    // Individually appended stop sequences accumulate in order.
    #[test]
    fn test_message_builder_stop_sequences() {
        let params = MessageBuilder::new()
            .user_text("Generate some code")
            .stop_sequence("```")
            .stop_sequence("\n\n")
            .build_params();

        let sequences = params.stop_sequences.expect("Expected stop sequences");
        assert_eq!(sequences.len(), 2);
        assert_eq!(sequences[0], "```");
        assert_eq!(sequences[1], "\n\n");
    }

    // `build_request` sets the method name and carries the built params, including metadata.
    #[test]
    fn test_message_builder_complete_request() {
        let request = MessageBuilder::new()
            .system_prompt("You are a coding assistant")
            .user_text("Write a function to calculate fibonacci numbers")
            .temperature(0.3)
            .max_tokens(500)
            .metadata(json!({"request_id": "12345"}))
            .build_request();

        assert_eq!(request.method, "sampling/createMessage");
        assert_eq!(request.params.max_tokens, 500);
        assert_eq!(request.params.temperature, Some(0.3));
        assert_eq!(
            request.params.system_prompt,
            Some("You are a coding assistant".to_string())
        );
        assert_eq!(
            request.params.metadata,
            Some(json!({"request_id": "12345"}))
        );
    }

    // The `prefer_fast` preset adds both fast-model hints.
    #[test]
    fn test_model_preferences_builder() {
        let preferences = ModelPreferencesBuilder::new()
            .prefer_fast()
            .cost_priority(0.9)
            .speed_priority(0.8)
            .build();

        let hints = preferences.hints.expect("Expected hints");
        assert_eq!(hints.len(), 2);
        assert!(hints.contains(&ModelHint::new("claude-3-5-haiku-20241022")));
        assert!(hints.contains(&ModelHint::new("gpt-4o-mini")));
        assert_eq!(preferences.cost_priority, Some(0.9));
        assert_eq!(preferences.speed_priority, Some(0.8));
    }

    // `SamplingMessageExt` constructors set the expected role and content.
    #[test]
    fn test_sampling_message_convenience_methods() {
        let user_msg = SamplingMessage::user_text("User input");
        assert_eq!(user_msg.role, Role::User);

        let assistant_msg = SamplingMessage::assistant_text("Assistant response");
        assert_eq!(assistant_msg.role, Role::Assistant);

        let image_msg = SamplingMessage::user_image("base64data", "image/png");
        assert_eq!(image_msg.role, Role::User);
        if let ContentBlock::Image {
            data, mime_type, ..
        } = &image_msg.content
        {
            assert_eq!(data, "base64data");
            assert_eq!(mime_type, "image/png");
        } else {
            panic!("Expected image content");
        }
    }

    // Out-of-range temperatures are clamped to the 0.0..=2.0 bounds.
    #[test]
    fn test_temperature_clamping() {
        let params = MessageBuilder::new()
            .user_text("Test")
            .temperature(5.0)
            .build_params();

        assert_eq!(params.temperature, Some(2.0));

        let params2 = MessageBuilder::new()
            .user_text("Test")
            .temperature(-1.0)
            .build_params();

        assert_eq!(params2.temperature, Some(0.0));
    }
}