1use serde_json::Value;
8use std::collections::HashMap;
9use std::future::Future;
10use std::pin::Pin;
11
12use turul_mcp_protocol::prompts::{
14 ContentBlock, GetPromptResult, HasPromptAnnotations, HasPromptArguments, HasPromptDescription,
15 HasPromptMeta, HasPromptMetadata, PromptArgument, PromptMessage,
16};
17
18pub type DynamicPromptFn = Box<
20 dyn Fn(
21 HashMap<String, String>,
22 ) -> Pin<Box<dyn Future<Output = Result<GetPromptResult, String>> + Send>>
23 + Send
24 + Sync,
25>;
26
27pub struct PromptBuilder {
29 name: String,
30 title: Option<String>,
31 description: Option<String>,
32 arguments: Vec<PromptArgument>,
33 messages: Vec<PromptMessage>,
34 meta: Option<HashMap<String, Value>>,
35 get_fn: Option<DynamicPromptFn>,
36}
37
38impl PromptBuilder {
39 pub fn new(name: impl Into<String>) -> Self {
41 Self {
42 name: name.into(),
43 title: None,
44 description: None,
45 arguments: Vec::new(),
46 messages: Vec::new(),
47 meta: None,
48 get_fn: None,
49 }
50 }
51
52 pub fn title(mut self, title: impl Into<String>) -> Self {
54 self.title = Some(title.into());
55 self
56 }
57
58 pub fn description(mut self, description: impl Into<String>) -> Self {
60 self.description = Some(description.into());
61 self
62 }
63
64 pub fn argument(mut self, argument: PromptArgument) -> Self {
66 self.arguments.push(argument);
67 self
68 }
69
70 pub fn string_argument(
72 mut self,
73 name: impl Into<String>,
74 description: impl Into<String>,
75 ) -> Self {
76 let arg = PromptArgument::new(name)
77 .with_description(description)
78 .required();
79 self.arguments.push(arg);
80 self
81 }
82
83 pub fn optional_string_argument(
85 mut self,
86 name: impl Into<String>,
87 description: impl Into<String>,
88 ) -> Self {
89 let arg = PromptArgument::new(name)
90 .with_description(description)
91 .optional();
92 self.arguments.push(arg);
93 self
94 }
95
96 pub fn message(mut self, message: PromptMessage) -> Self {
98 self.messages.push(message);
99 self
100 }
101
102 pub fn system_message(mut self, text: impl Into<String>) -> Self {
104 self.messages
106 .push(PromptMessage::user_text(format!("System: {}", text.into())));
107 self
108 }
109
110 pub fn user_message(mut self, text: impl Into<String>) -> Self {
112 self.messages.push(PromptMessage::user_text(text));
113 self
114 }
115
116 pub fn assistant_message(mut self, text: impl Into<String>) -> Self {
118 self.messages.push(PromptMessage::assistant_text(text));
119 self
120 }
121
122 pub fn user_image(mut self, data: impl Into<String>, mime_type: impl Into<String>) -> Self {
124 self.messages
125 .push(PromptMessage::user_image(data, mime_type));
126 self
127 }
128
129 pub fn template_user_message(mut self, template: impl Into<String>) -> Self {
131 self.messages.push(PromptMessage::user_text(template));
132 self
133 }
134
135 pub fn template_assistant_message(mut self, template: impl Into<String>) -> Self {
137 self.messages.push(PromptMessage::assistant_text(template));
138 self
139 }
140
141 pub fn meta(mut self, meta: HashMap<String, Value>) -> Self {
143 self.meta = Some(meta);
144 self
145 }
146
147 pub fn get<F, Fut>(mut self, f: F) -> Self
149 where
150 F: Fn(HashMap<String, String>) -> Fut + Send + Sync + 'static,
151 Fut: Future<Output = Result<GetPromptResult, String>> + Send + 'static,
152 {
153 self.get_fn = Some(Box::new(move |args| Box::pin(f(args))));
154 self
155 }
156
157 pub fn build(self) -> Result<DynamicPrompt, String> {
159 let get_fn = if let Some(f) = self.get_fn {
161 f
162 } else {
163 let messages = self.messages.clone();
165 let description = self.description.clone();
166 Box::new(move |args| {
167 let messages = messages.clone();
168 let description = description.clone();
169 Box::pin(async move {
170 let processed_messages = process_template_messages(messages, &args)?;
171 let mut result = GetPromptResult::new(processed_messages);
172 if let Some(desc) = description {
173 result = result.with_description(desc);
174 }
175 Ok(result)
176 })
177 as Pin<Box<dyn Future<Output = Result<GetPromptResult, String>> + Send>>
178 })
179 };
180
181 Ok(DynamicPrompt {
182 name: self.name,
183 title: self.title,
184 description: self.description,
185 arguments: self.arguments,
186 messages: self.messages,
187 meta: self.meta,
188 get_fn,
189 })
190 }
191}
192
193pub struct DynamicPrompt {
195 name: String,
196 title: Option<String>,
197 description: Option<String>,
198 arguments: Vec<PromptArgument>,
199 #[allow(dead_code)]
200 messages: Vec<PromptMessage>,
201 meta: Option<HashMap<String, Value>>,
202 get_fn: DynamicPromptFn,
203}
204
205impl DynamicPrompt {
206 pub async fn get(&self, args: HashMap<String, String>) -> Result<GetPromptResult, String> {
208 (self.get_fn)(args).await
209 }
210}
211
212impl HasPromptMetadata for DynamicPrompt {
214 fn name(&self) -> &str {
215 &self.name
216 }
217
218 fn title(&self) -> Option<&str> {
219 self.title.as_deref()
220 }
221}
222
223impl HasPromptDescription for DynamicPrompt {
224 fn description(&self) -> Option<&str> {
225 self.description.as_deref()
226 }
227}
228
229impl HasPromptArguments for DynamicPrompt {
230 fn arguments(&self) -> Option<&Vec<PromptArgument>> {
231 if self.arguments.is_empty() {
232 None
233 } else {
234 Some(&self.arguments)
235 }
236 }
237}
238
239impl HasPromptAnnotations for DynamicPrompt {
240 fn annotations(&self) -> Option<&turul_mcp_protocol::prompts::PromptAnnotations> {
241 None }
243}
244
245impl HasPromptMeta for DynamicPrompt {
246 fn prompt_meta(&self) -> Option<&HashMap<String, Value>> {
247 self.meta.as_ref()
248 }
249}
250
251fn process_template_messages(
255 messages: Vec<PromptMessage>,
256 args: &HashMap<String, String>,
257) -> Result<Vec<PromptMessage>, String> {
258 let mut processed = Vec::new();
259
260 for message in messages {
261 let processed_message = match message.content {
262 ContentBlock::Text { text, .. } => {
263 let processed_text = process_template_string(&text, args);
264 PromptMessage {
265 role: message.role,
266 content: ContentBlock::text(processed_text),
267 }
268 }
269 other_content => PromptMessage {
271 role: message.role,
272 content: other_content,
273 },
274 };
275 processed.push(processed_message);
276 }
277
278 Ok(processed)
279}
280
/// Replaces each `{key}` placeholder in `template` whose `key` is in `args` with the
/// corresponding value.
///
/// The template is scanned once, left to right. Substituted values are copied
/// verbatim and never rescanned. As a result:
/// - The output does not depend on `HashMap` iteration order.
/// - An argument value such as `"{other}"` cannot inject another argument.
///
/// Placeholders naming unknown keys, and unmatched braces, are kept literally.
/// `{{name}}` renders as `{value}`.
fn process_template_string(template: &str, args: &HashMap<String, String>) -> String {
    let mut result = String::with_capacity(template.len());
    let mut rest = template;

    // `{` and `}` are single-byte ASCII, so slicing at their byte offsets always
    // lands on a char boundary, even in multi-byte UTF-8 text.
    while let Some(open) = rest.find('{') {
        result.push_str(&rest[..open]);
        let after_open = &rest[open + 1..];

        // Candidate placeholder name: everything up to the next closing brace.
        let substitution = after_open.find('}').and_then(|close| {
            args.get(&after_open[..close])
                .map(|value| (value, &after_open[close + 1..]))
        });

        match substitution {
            Some((value, remainder)) => {
                result.push_str(value);
                rest = remainder;
            }
            None => {
                // Not a known placeholder: keep the `{` literally and resume right
                // after it, so an inner `{name}` (as in `{{name}}`) still matches.
                result.push('{');
                rest = after_open;
            }
        }
    }

    result.push_str(rest);
    result
}
292
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the text of a message, failing the test if its content is not text.
    fn text_of(message: &PromptMessage) -> &str {
        match &message.content {
            ContentBlock::Text { text, .. } => text,
            _ => panic!("Expected text content"),
        }
    }

    /// Builds an owned argument map from borrowed key/value pairs.
    fn args_from(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|(key, value)| (String::from(*key), String::from(*value)))
            .collect()
    }

    #[test]
    fn test_prompt_builder_basic() {
        let prompt = PromptBuilder::new("greeting_prompt")
            .title("Greeting Generator")
            .description("Generate personalized greetings")
            .string_argument("name", "The person's name")
            .user_message("Hello {name}! How are you today?")
            .build()
            .expect("Failed to build prompt");

        assert_eq!(prompt.name(), "greeting_prompt");
        assert_eq!(prompt.title(), Some("Greeting Generator"));
        assert_eq!(prompt.description(), Some("Generate personalized greetings"));
        assert_eq!(prompt.arguments().map(|declared| declared.len()), Some(1));
    }

    #[tokio::test]
    async fn test_prompt_builder_template_processing() {
        let prompt = PromptBuilder::new("conversation_starter")
            .description("Start a conversation with someone")
            .string_argument("name", "Person's name")
            .optional_string_argument("topic", "Optional conversation topic")
            .user_message("Hi {name}! Nice to meet you.")
            .template_assistant_message("Hello! What would you like to talk about?")
            .user_message("Let's discuss {topic}")
            .build()
            .expect("Failed to build prompt");

        let args = args_from(&[("name", "Alice"), ("topic", "Rust programming")]);
        let result = prompt.get(args).await.expect("Failed to get prompt");

        assert_eq!(result.messages.len(), 3);
        assert_eq!(text_of(&result.messages[0]), "Hi Alice! Nice to meet you.");
        assert_eq!(text_of(&result.messages[2]), "Let's discuss Rust programming");
    }

    #[tokio::test]
    async fn test_prompt_builder_custom_get_function() {
        let prompt = PromptBuilder::new("dynamic_prompt")
            .description("Dynamic prompt with custom logic")
            .string_argument("mood", "Current mood")
            .get(|args| async move {
                let mood = args.get("mood").map(String::as_str).unwrap_or("neutral");
                let reply = match mood {
                    "happy" => "That's wonderful! Tell me more about what's making you happy.",
                    "sad" => "I'm sorry to hear that. Would you like to talk about it?",
                    _ => "How are you feeling today?",
                };

                Ok(GetPromptResult::new(vec![
                    PromptMessage::user_text(format!("I'm feeling {}", mood)),
                    PromptMessage::assistant_text(reply),
                ])
                .with_description("Mood-based conversation"))
            })
            .build()
            .expect("Failed to build prompt");

        let args = args_from(&[("mood", "happy")]);
        let result = prompt.get(args).await.expect("Failed to get prompt");

        assert_eq!(result.messages.len(), 2);
        assert_eq!(result.description.as_deref(), Some("Mood-based conversation"));
        assert!(text_of(&result.messages[1]).contains("wonderful"));
    }

    #[test]
    fn test_prompt_builder_arguments() {
        let prompt = PromptBuilder::new("complex_prompt")
            .string_argument("required_arg", "This is required")
            .optional_string_argument("optional_arg", "This is optional")
            .argument(
                PromptArgument::new("custom_arg")
                    .with_title("Custom Argument")
                    .with_description("A custom argument")
                    .required(),
            )
            .build()
            .expect("Failed to build prompt");

        let declared = prompt.arguments().unwrap();
        let names: Vec<&str> = declared.iter().map(|arg| arg.name.as_str()).collect();

        assert_eq!(names, ["required_arg", "optional_arg", "custom_arg"]);
        assert_eq!(declared[0].required, Some(true));
        assert_eq!(declared[1].required, Some(false));
        assert_eq!(declared[2].title.as_deref(), Some("Custom Argument"));
    }

    #[test]
    fn test_template_string_processing() {
        let args = args_from(&[("name", "Alice"), ("place", "Wonderland")]);
        assert_eq!(
            process_template_string("Hello {name}, welcome to {place}!", &args),
            "Hello Alice, welcome to Wonderland!"
        );
    }
}