1use crate::Result;
8use std::collections::HashMap;
9use std::sync::RwLock;
10use thulp_core::{GetPromptResult, Prompt, PromptListResult, PromptMessage};
11
12type PromptRenderer = Box<dyn Fn(&HashMap<String, String>) -> GetPromptResult + Send + Sync>;
14
15pub struct PromptsClient {
17 cache: RwLock<HashMap<String, Prompt>>,
19 renderers: RwLock<HashMap<String, PromptRenderer>>,
21}
22
23impl PromptsClient {
24 pub fn new() -> Self {
26 Self {
27 cache: RwLock::new(HashMap::new()),
28 renderers: RwLock::new(HashMap::new()),
29 }
30 }
31
32 pub async fn list(&self) -> Result<PromptListResult> {
36 let cache = self.cache.read().unwrap();
37 Ok(PromptListResult {
38 prompts: cache.values().cloned().collect(),
39 next_cursor: None,
40 })
41 }
42
43 pub async fn get(
47 &self,
48 name: &str,
49 arguments: HashMap<String, String>,
50 ) -> Result<GetPromptResult> {
51 let renderers = self.renderers.read().unwrap();
52
53 if let Some(renderer) = renderers.get(name) {
54 Ok(renderer(&arguments))
55 } else {
56 let prompt = self.cache.read().unwrap().get(name).cloned();
58 let description = prompt.and_then(|p| p.description);
59
60 Ok(GetPromptResult {
61 description,
62 messages: vec![PromptMessage::user_text(format!(
63 "Prompt '{}' with args: {:?}",
64 name, arguments
65 ))],
66 })
67 }
68 }
69
70 pub fn register(&self, prompt: Prompt) {
72 let mut cache = self.cache.write().unwrap();
73 cache.insert(prompt.name.clone(), prompt);
74 }
75
76 pub fn register_with_renderer<F>(&self, prompt: Prompt, renderer: F)
78 where
79 F: Fn(&HashMap<String, String>) -> GetPromptResult + Send + Sync + 'static,
80 {
81 let name = prompt.name.clone();
82 self.register(prompt);
83 let mut renderers = self.renderers.write().unwrap();
84 renderers.insert(name, Box::new(renderer));
85 }
86
87 pub fn get_definition(&self, name: &str) -> Option<Prompt> {
89 self.cache.read().unwrap().get(name).cloned()
90 }
91
92 pub fn clear(&self) {
94 self.cache.write().unwrap().clear();
95 self.renderers.write().unwrap().clear();
96 }
97}
98
99impl Default for PromptsClient {
100 fn default() -> Self {
101 Self::new()
102 }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use thulp_core::PromptArgument;

    /// Builds the "greeting" prompt shared by the `get` tests.
    fn greeting_prompt() -> Prompt {
        Prompt::builder("greeting")
            .description("A greeting prompt")
            .argument(PromptArgument::required("name", "Person to greet"))
            .build()
    }

    /// Argument map with a single `name = "Alice"` entry.
    fn alice_args() -> HashMap<String, String> {
        let mut args = HashMap::new();
        args.insert("name".to_string(), "Alice".to_string());
        args
    }

    /// A freshly created client lists no prompts.
    #[tokio::test]
    async fn test_prompts_client_creation() {
        let client = PromptsClient::new();
        let listing = client.list().await.unwrap();
        assert!(listing.prompts.is_empty());
    }

    /// Registering one prompt makes it show up in the listing.
    #[tokio::test]
    async fn test_register_prompt() {
        let client = PromptsClient::new();
        client.register(Prompt::new("test_prompt"));

        let listing = client.list().await.unwrap();
        assert_eq!(listing.prompts.len(), 1);
    }

    /// Without a renderer, `get` carries over the definition's description.
    #[tokio::test]
    async fn test_get_prompt() {
        let client = PromptsClient::new();
        client.register(greeting_prompt());

        let rendered = client.get("greeting", alice_args()).await.unwrap();

        assert_eq!(rendered.description.as_deref(), Some("A greeting prompt"));
    }

    /// With a renderer, `get` returns the renderer's messages.
    #[tokio::test]
    async fn test_get_with_renderer() {
        let client = PromptsClient::new();

        client.register_with_renderer(greeting_prompt(), |args| {
            let name = args.get("name").map_or("World", |s| s.as_str());
            let message = PromptMessage::user_text(format!("Hello, {}!", name));
            GetPromptResult::new(vec![message])
        });

        let rendered = client.get("greeting", alice_args()).await.unwrap();

        assert_eq!(rendered.messages.len(), 1);
    }

    /// `get_definition` returns the stored definition, including its title.
    #[tokio::test]
    async fn test_get_definition() {
        let client = PromptsClient::new();
        client.register(Prompt::builder("test").title("Test Prompt").build());

        let definition = client.get_definition("test");
        assert!(definition.is_some());

        let title = definition.and_then(|def| def.title);
        assert_eq!(title.as_deref(), Some("Test Prompt"));
    }

    /// `clear` removes previously registered prompts.
    #[tokio::test]
    async fn test_clear() {
        let client = PromptsClient::new();
        client.register(Prompt::new("test"));

        client.clear();

        let listing = client.list().await.unwrap();
        assert!(listing.prompts.is_empty());
    }
}