1use crate::Result;
8use std::collections::HashMap;
9use std::sync::RwLock;
10use thulp_core::{GetPromptResult, Prompt, PromptListResult, PromptMessage};
11
12type PromptRenderer = Box<dyn Fn(&HashMap<String, String>) -> GetPromptResult + Send + Sync>;
14
15pub struct PromptsClient {
17 cache: RwLock<HashMap<String, Prompt>>,
19 renderers: RwLock<HashMap<String, PromptRenderer>>,
21}
22
23impl PromptsClient {
24 pub fn new() -> Self {
26 Self {
27 cache: RwLock::new(HashMap::new()),
28 renderers: RwLock::new(HashMap::new()),
29 }
30 }
31
32 pub async fn list(&self) -> Result<PromptListResult> {
36 let cache = self.cache.read().unwrap();
37 Ok(PromptListResult {
38 prompts: cache.values().cloned().collect(),
39 next_cursor: None,
40 })
41 }
42
43 pub async fn get(&self, name: &str, arguments: HashMap<String, String>) -> Result<GetPromptResult> {
47 let renderers = self.renderers.read().unwrap();
48
49 if let Some(renderer) = renderers.get(name) {
50 Ok(renderer(&arguments))
51 } else {
52 let prompt = self.cache.read().unwrap().get(name).cloned();
54 let description = prompt.and_then(|p| p.description);
55
56 Ok(GetPromptResult {
57 description,
58 messages: vec![PromptMessage::user_text(format!(
59 "Prompt '{}' with args: {:?}",
60 name, arguments
61 ))],
62 })
63 }
64 }
65
66 pub fn register(&self, prompt: Prompt) {
68 let mut cache = self.cache.write().unwrap();
69 cache.insert(prompt.name.clone(), prompt);
70 }
71
72 pub fn register_with_renderer<F>(&self, prompt: Prompt, renderer: F)
74 where
75 F: Fn(&HashMap<String, String>) -> GetPromptResult + Send + Sync + 'static,
76 {
77 let name = prompt.name.clone();
78 self.register(prompt);
79 let mut renderers = self.renderers.write().unwrap();
80 renderers.insert(name, Box::new(renderer));
81 }
82
83 pub fn get_definition(&self, name: &str) -> Option<Prompt> {
85 self.cache.read().unwrap().get(name).cloned()
86 }
87
88 pub fn clear(&self) {
90 self.cache.write().unwrap().clear();
91 self.renderers.write().unwrap().clear();
92 }
93}
94
95impl Default for PromptsClient {
96 fn default() -> Self {
97 Self::new()
98 }
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Mutex};
    use thulp_core::PromptArgument;

    #[tokio::test]
    async fn test_prompts_client_creation() {
        let client = PromptsClient::new();
        let result = client.list().await.unwrap();
        assert!(result.prompts.is_empty());
    }

    #[tokio::test]
    async fn test_register_prompt() {
        let client = PromptsClient::new();
        let prompt = Prompt::new("test_prompt");
        client.register(prompt);

        let result = client.list().await.unwrap();
        assert_eq!(result.prompts.len(), 1);
    }

    #[tokio::test]
    async fn test_register_replaces_existing_definition() {
        let client = PromptsClient::new();
        client.register(Prompt::builder("p").description("old").build());
        client.register(Prompt::builder("p").description("new").build());

        assert_eq!(client.list().await.unwrap().prompts.len(), 1);
        assert_eq!(
            client.get_definition("p").unwrap().description,
            Some("new".to_string())
        );
    }

    #[tokio::test]
    async fn test_get_prompt() {
        let client = PromptsClient::new();
        let prompt = Prompt::builder("greeting")
            .description("A greeting prompt")
            .argument(PromptArgument::required("name", "Person to greet"))
            .build();

        client.register(prompt);

        let args = HashMap::from([("name".to_string(), "Alice".to_string())]);
        let result = client.get("greeting", args).await.unwrap();

        assert_eq!(result.description, Some("A greeting prompt".to_string()));
    }

    #[tokio::test]
    async fn test_get_unknown_prompt_returns_placeholder() {
        let client = PromptsClient::new();
        let result = client.get("missing", HashMap::new()).await.unwrap();

        // Unknown names are not an error: a single placeholder message, no description.
        assert!(result.description.is_none());
        assert_eq!(result.messages.len(), 1);
    }

    #[tokio::test]
    async fn test_get_with_renderer() {
        let client = PromptsClient::new();
        let prompt = Prompt::builder("greeting")
            .description("A greeting prompt")
            .argument(PromptArgument::required("name", "Person to greet"))
            .build();

        client.register_with_renderer(prompt, |args| {
            let name = args.get("name").map(|s| s.as_str()).unwrap_or("World");
            GetPromptResult::new(vec![PromptMessage::user_text(format!("Hello, {}!", name))])
        });

        let args = HashMap::from([("name".to_string(), "Alice".to_string())]);
        let result = client.get("greeting", args).await.unwrap();

        assert_eq!(result.messages.len(), 1);
    }

    #[tokio::test]
    async fn test_renderer_receives_arguments() {
        let client = PromptsClient::new();
        let seen = Arc::new(Mutex::new(None::<String>));
        let seen_in_renderer = Arc::clone(&seen);

        client.register_with_renderer(Prompt::new("greeting"), move |args| {
            *seen_in_renderer.lock().unwrap() = args.get("name").cloned();
            GetPromptResult::new(vec![PromptMessage::user_text("hi".to_string())])
        });

        let args = HashMap::from([("name".to_string(), "Alice".to_string())]);
        client.get("greeting", args).await.unwrap();

        assert_eq!(seen.lock().unwrap().as_deref(), Some("Alice"));
    }

    #[tokio::test]
    async fn test_get_definition() {
        let client = PromptsClient::new();
        let prompt = Prompt::builder("test").title("Test Prompt").build();

        client.register(prompt);

        let def = client.get_definition("test");
        assert!(def.is_some());
        assert_eq!(def.unwrap().title, Some("Test Prompt".to_string()));
    }

    #[test]
    fn test_get_definition_missing() {
        let client = PromptsClient::new();
        assert!(client.get_definition("nope").is_none());
    }

    #[tokio::test]
    async fn test_clear() {
        let client = PromptsClient::new();
        client.register(Prompt::new("test"));

        client.clear();

        assert!(client.list().await.unwrap().prompts.is_empty());
    }

    #[tokio::test]
    async fn test_clear_removes_renderers() {
        let client = PromptsClient::new();
        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);

        client.register_with_renderer(Prompt::new("greeting"), move |_| {
            counter.fetch_add(1, Ordering::SeqCst);
            GetPromptResult::new(vec![PromptMessage::user_text("hi".to_string())])
        });

        client.get("greeting", HashMap::new()).await.unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        client.clear();

        // After clearing, the renderer must not be invoked and the definition is gone.
        let result = client.get("greeting", HashMap::new()).await.unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert!(result.description.is_none());
        assert!(client.get_definition("greeting").is_none());
    }
}