thulp_mcp/
prompts.rs

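//! MCP Prompts support.
//!
//! This module provides support for the MCP prompts protocol methods:
//! - `prompts/list` - List available prompts
//! - `prompts/get` - Get a rendered prompt with arguments
//!
//! [`PromptsClient`] currently serves both methods from an in-process registry
//! of prompt definitions and renderers instead of forwarding them to an MCP server.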
6
use crate::Result;
use std::collections::{BTreeMap, HashMap};
use std::sync::{Arc, RwLock};
use thulp_core::{GetPromptResult, Prompt, PromptListResult, PromptMessage};
11
12/// Type alias for prompt renderer function.
13type PromptRenderer = Box<dyn Fn(&HashMap<String, String>) -> GetPromptResult + Send + Sync>;
14
15/// MCP Prompts client for managing and rendering prompts.
16pub struct PromptsClient {
17    /// Cached prompts
18    cache: RwLock<HashMap<String, Prompt>>,
19    /// Prompt renderers (name -> renderer function)
20    renderers: RwLock<HashMap<String, PromptRenderer>>,
21}
22
23impl PromptsClient {
24    /// Create a new prompts client.
25    pub fn new() -> Self {
26        Self {
27            cache: RwLock::new(HashMap::new()),
28            renderers: RwLock::new(HashMap::new()),
29        }
30    }
31
32    /// List all available prompts.
33    ///
34    /// In a full implementation, this would call `prompts/list` on the MCP server.
35    pub async fn list(&self) -> Result<PromptListResult> {
36        let cache = self.cache.read().unwrap();
37        Ok(PromptListResult {
38            prompts: cache.values().cloned().collect(),
39            next_cursor: None,
40        })
41    }
42
43    /// Get a rendered prompt with arguments.
44    ///
45    /// In a full implementation, this would call `prompts/get` on the MCP server.
46    pub async fn get(&self, name: &str, arguments: HashMap<String, String>) -> Result<GetPromptResult> {
47        let renderers = self.renderers.read().unwrap();
48        
49        if let Some(renderer) = renderers.get(name) {
50            Ok(renderer(&arguments))
51        } else {
52            // Default: return a simple message
53            let prompt = self.cache.read().unwrap().get(name).cloned();
54            let description = prompt.and_then(|p| p.description);
55            
56            Ok(GetPromptResult {
57                description,
58                messages: vec![PromptMessage::user_text(format!(
59                    "Prompt '{}' with args: {:?}",
60                    name, arguments
61                ))],
62            })
63        }
64    }
65
66    /// Register a prompt definition.
67    pub fn register(&self, prompt: Prompt) {
68        let mut cache = self.cache.write().unwrap();
69        cache.insert(prompt.name.clone(), prompt);
70    }
71
72    /// Register a prompt with a custom renderer.
73    pub fn register_with_renderer<F>(&self, prompt: Prompt, renderer: F)
74    where
75        F: Fn(&HashMap<String, String>) -> GetPromptResult + Send + Sync + 'static,
76    {
77        let name = prompt.name.clone();
78        self.register(prompt);
79        let mut renderers = self.renderers.write().unwrap();
80        renderers.insert(name, Box::new(renderer));
81    }
82
83    /// Get a prompt definition by name.
84    pub fn get_definition(&self, name: &str) -> Option<Prompt> {
85        self.cache.read().unwrap().get(name).cloned()
86    }
87
88    /// Clear all cached prompts.
89    pub fn clear(&self) {
90        self.cache.write().unwrap().clear();
91        self.renderers.write().unwrap().clear();
92    }
93}
94
95impl Default for PromptsClient {
96    fn default() -> Self {
97        Self::new()
98    }
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use thulp_core::PromptArgument;

    /// Builds the `greeting` prompt definition shared by several tests.
    fn greeting_prompt() -> Prompt {
        Prompt::builder("greeting")
            .description("A greeting prompt")
            .argument(PromptArgument::required("name", "Person to greet"))
            .build()
    }

    /// Renderer that always yields two messages, so its output can be told apart
    /// from the single-message fallback rendering.
    fn two_message_renderer(args: &HashMap<String, String>) -> GetPromptResult {
        let name = args.get("name").map(|s| s.as_str()).unwrap_or("World");
        GetPromptResult::new(vec![
            PromptMessage::user_text(format!("Hello, {}!", name)),
            PromptMessage::user_text("How are you?".to_string()),
        ])
    }

    #[tokio::test]
    async fn test_prompts_client_creation() {
        let client = PromptsClient::new();
        let result = client.list().await.unwrap();
        assert!(result.prompts.is_empty());
    }

    #[tokio::test]
    async fn test_register_prompt() {
        let client = PromptsClient::new();
        client.register(Prompt::new("test_prompt"));

        let result = client.list().await.unwrap();
        assert_eq!(result.prompts.len(), 1);
    }

    #[tokio::test]
    async fn test_get_prompt() {
        let client = PromptsClient::new();
        client.register(greeting_prompt());

        let args = HashMap::from([("name".to_string(), "Alice".to_string())]);
        let result = client.get("greeting", args).await.unwrap();

        assert_eq!(result.description, Some("A greeting prompt".to_string()));
    }

    #[tokio::test]
    async fn test_get_unknown_prompt_falls_back() {
        let client = PromptsClient::new();

        let result = client.get("missing", HashMap::new()).await.unwrap();

        assert!(result.description.is_none());
        assert_eq!(result.messages.len(), 1);
    }

    #[tokio::test]
    async fn test_get_with_renderer() {
        let client = PromptsClient::new();
        client.register_with_renderer(greeting_prompt(), two_message_renderer);

        let args = HashMap::from([("name".to_string(), "Alice".to_string())]);
        let result = client.get("greeting", args).await.unwrap();

        // The fallback path always yields one message; two proves the renderer ran.
        assert_eq!(result.messages.len(), 2);
    }

    #[tokio::test]
    async fn test_register_with_renderer_registers_definition() {
        let client = PromptsClient::new();
        client.register_with_renderer(greeting_prompt(), two_message_renderer);

        assert_eq!(client.list().await.unwrap().prompts.len(), 1);
        assert!(client.get_definition("greeting").is_some());
    }

    #[tokio::test]
    async fn test_get_definition() {
        let client = PromptsClient::new();
        client.register(Prompt::builder("test").title("Test Prompt").build());

        let def = client.get_definition("test");
        assert!(def.is_some());
        assert_eq!(def.unwrap().title, Some("Test Prompt".to_string()));
    }

    #[tokio::test]
    async fn test_clear() {
        let client = PromptsClient::new();
        client.register(Prompt::new("test"));
        client.register_with_renderer(greeting_prompt(), two_message_renderer);

        client.clear();

        assert!(client.list().await.unwrap().prompts.is_empty());
        assert!(client.get_definition("test").is_none());
        // With the renderer gone, `get` must use the single-message fallback.
        let result = client.get("greeting", HashMap::new()).await.unwrap();
        assert_eq!(result.messages.len(), 1);
        assert!(result.description.is_none());
    }
}
183}