Skip to main content

thulp_mcp/
prompts.rs

//! MCP Prompts support.
//!
//! This module provides support for MCP prompts protocol methods:
//! - `prompts/list` - List available prompts
//! - `prompts/get` - Get a rendered prompt with arguments
//!
//! Prompts are currently served from an in-process registry rather than fetched from an MCP server.
6
use crate::Result;
use std::collections::{BTreeMap, HashMap};
use std::sync::{Arc, PoisonError, RwLock, RwLockReadGuard, RwLockWriteGuard};
use thulp_core::{GetPromptResult, Prompt, PromptListResult, PromptMessage};
11
12/// Type alias for prompt renderer function.
13type PromptRenderer = Box<dyn Fn(&HashMap<String, String>) -> GetPromptResult + Send + Sync>;
14
15/// MCP Prompts client for managing and rendering prompts.
16pub struct PromptsClient {
17    /// Cached prompts
18    cache: RwLock<HashMap<String, Prompt>>,
19    /// Prompt renderers (name -> renderer function)
20    renderers: RwLock<HashMap<String, PromptRenderer>>,
21}
22
23impl PromptsClient {
24    /// Create a new prompts client.
25    pub fn new() -> Self {
26        Self {
27            cache: RwLock::new(HashMap::new()),
28            renderers: RwLock::new(HashMap::new()),
29        }
30    }
31
32    /// List all available prompts.
33    ///
34    /// In a full implementation, this would call `prompts/list` on the MCP server.
35    pub async fn list(&self) -> Result<PromptListResult> {
36        let cache = self.cache.read().unwrap();
37        Ok(PromptListResult {
38            prompts: cache.values().cloned().collect(),
39            next_cursor: None,
40        })
41    }
42
43    /// Get a rendered prompt with arguments.
44    ///
45    /// In a full implementation, this would call `prompts/get` on the MCP server.
46    pub async fn get(
47        &self,
48        name: &str,
49        arguments: HashMap<String, String>,
50    ) -> Result<GetPromptResult> {
51        let renderers = self.renderers.read().unwrap();
52
53        if let Some(renderer) = renderers.get(name) {
54            Ok(renderer(&arguments))
55        } else {
56            // Default: return a simple message
57            let prompt = self.cache.read().unwrap().get(name).cloned();
58            let description = prompt.and_then(|p| p.description);
59
60            Ok(GetPromptResult {
61                description,
62                messages: vec![PromptMessage::user_text(format!(
63                    "Prompt '{}' with args: {:?}",
64                    name, arguments
65                ))],
66            })
67        }
68    }
69
70    /// Register a prompt definition.
71    pub fn register(&self, prompt: Prompt) {
72        let mut cache = self.cache.write().unwrap();
73        cache.insert(prompt.name.clone(), prompt);
74    }
75
76    /// Register a prompt with a custom renderer.
77    pub fn register_with_renderer<F>(&self, prompt: Prompt, renderer: F)
78    where
79        F: Fn(&HashMap<String, String>) -> GetPromptResult + Send + Sync + 'static,
80    {
81        let name = prompt.name.clone();
82        self.register(prompt);
83        let mut renderers = self.renderers.write().unwrap();
84        renderers.insert(name, Box::new(renderer));
85    }
86
87    /// Get a prompt definition by name.
88    pub fn get_definition(&self, name: &str) -> Option<Prompt> {
89        self.cache.read().unwrap().get(name).cloned()
90    }
91
92    /// Clear all cached prompts.
93    pub fn clear(&self) {
94        self.cache.write().unwrap().clear();
95        self.renderers.write().unwrap().clear();
96    }
97}
98
99impl Default for PromptsClient {
100    fn default() -> Self {
101        Self::new()
102    }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use thulp_core::PromptArgument;

    /// The `greeting` prompt shared by several tests.
    fn greeting_prompt() -> Prompt {
        Prompt::builder("greeting")
            .description("A greeting prompt")
            .argument(PromptArgument::required("name", "Person to greet"))
            .build()
    }

    /// Arguments `{ "name": "Alice" }`.
    fn alice_args() -> HashMap<String, String> {
        HashMap::from([("name".to_string(), "Alice".to_string())])
    }

    /// A renderer whose output is distinguishable from the built-in fallback.
    fn marker_renderer(_args: &HashMap<String, String>) -> GetPromptResult {
        GetPromptResult {
            description: Some("rendered".to_string()),
            messages: Vec::new(),
        }
    }

    #[tokio::test]
    async fn test_prompts_client_creation() {
        let client = PromptsClient::new();
        let result = client.list().await.unwrap();
        assert!(result.prompts.is_empty());
    }

    #[tokio::test]
    async fn test_register_prompt() {
        let client = PromptsClient::new();
        let prompt = Prompt::new("test_prompt");
        client.register(prompt);

        let result = client.list().await.unwrap();
        assert_eq!(result.prompts.len(), 1);
    }

    #[tokio::test]
    async fn test_register_replaces_existing_definition() {
        let client = PromptsClient::new();
        client.register(Prompt::builder("test").title("First").build());
        client.register(Prompt::builder("test").title("Second").build());

        assert_eq!(client.list().await.unwrap().prompts.len(), 1);
        assert_eq!(
            client.get_definition("test").unwrap().title,
            Some("Second".to_string())
        );
    }

    #[tokio::test]
    async fn test_get_prompt() {
        let client = PromptsClient::new();
        client.register(greeting_prompt());

        let result = client.get("greeting", alice_args()).await.unwrap();

        assert_eq!(result.description, Some("A greeting prompt".to_string()));
        assert_eq!(result.messages.len(), 1);
    }

    #[tokio::test]
    async fn test_get_unregistered_prompt_falls_back() {
        let client = PromptsClient::new();

        let result = client.get("missing", HashMap::new()).await.unwrap();

        assert_eq!(result.description, None);
        assert_eq!(result.messages.len(), 1);
    }

    #[tokio::test]
    async fn test_get_with_renderer() {
        let client = PromptsClient::new();
        client.register_with_renderer(greeting_prompt(), |args| {
            let name = args.get("name").map(|s| s.as_str()).unwrap_or("World");
            GetPromptResult::new(vec![PromptMessage::user_text(format!("Hello, {}!", name))])
        });

        let result = client.get("greeting", alice_args()).await.unwrap();

        assert_eq!(result.messages.len(), 1);
    }

    #[tokio::test]
    async fn test_renderer_output_is_returned() {
        let client = PromptsClient::new();
        client.register_with_renderer(greeting_prompt(), marker_renderer);

        let result = client.get("greeting", alice_args()).await.unwrap();

        // The fallback would report the definition's description and one message.
        assert_eq!(result.description, Some("rendered".to_string()));
        assert!(result.messages.is_empty());
    }

    #[tokio::test]
    async fn test_register_keeps_existing_renderer() {
        let client = PromptsClient::new();
        client.register_with_renderer(greeting_prompt(), marker_renderer);
        client.register(greeting_prompt());

        let result = client.get("greeting", alice_args()).await.unwrap();

        assert_eq!(result.description, Some("rendered".to_string()));
    }

    #[tokio::test]
    async fn test_get_definition() {
        let client = PromptsClient::new();
        client.register(Prompt::builder("test").title("Test Prompt").build());

        let def = client.get_definition("test");
        assert!(def.is_some());
        assert_eq!(def.unwrap().title, Some("Test Prompt".to_string()));
        assert!(client.get_definition("missing").is_none());
    }

    #[tokio::test]
    async fn test_clear() {
        let client = PromptsClient::new();
        client.register(Prompt::new("test"));

        client.clear();

        assert!(client.list().await.unwrap().prompts.is_empty());
    }

    #[tokio::test]
    async fn test_clear_removes_renderers() {
        let client = PromptsClient::new();
        client.register_with_renderer(greeting_prompt(), marker_renderer);

        client.clear();

        // With both the definition and renderer gone, `get` uses the fallback.
        let result = client.get("greeting", alice_args()).await.unwrap();
        assert_eq!(result.description, None);
        assert_eq!(result.messages.len(), 1);
        assert!(client.get_definition("greeting").is_none());
    }
}