1use super::model::*;
11use super::providers::{
12 CliInputProvider, CompositeInputProvider, EnvironmentInputProvider, InputProvider,
13};
14use std::collections::HashMap;
15use std::sync::Arc;
16use tokio::sync::RwLock;
17use tracing::{debug, error, info};
18
/// Resolves input requests through an [`InputProvider`], optionally caching
/// the resolved values so repeated requests do not hit the provider again.
pub struct InputHandler {
    /// Provider that actually obtains input values, shared as a trait object.
    provider: Arc<dyn InputProvider>,
    /// Previously resolved values, keyed by cache-key strings.
    cache: Arc<RwLock<HashMap<String, InputValue>>>,
    /// When `false`, `get_input` neither reads from nor writes to `cache`.
    enable_cache: bool,
}
28
29impl InputHandler {
30 pub fn new() -> Self {
32 let provider: Box<dyn InputProvider> = Box::new(
35 CompositeInputProvider::new()
36 .add_provider(Box::new(EnvironmentInputProvider::new()))
37 .add_provider(Box::new(CliInputProvider::new())),
38 );
39
40 Self {
41 provider: Arc::from(provider),
42 cache: Arc::new(RwLock::new(HashMap::new())),
43 enable_cache: true,
44 }
45 }
46
47 pub fn with_provider<P>(provider: P) -> Self
49 where
50 P: InputProvider + 'static,
51 {
52 Self {
53 provider: Arc::new(provider),
54 cache: Arc::new(RwLock::new(HashMap::new())),
55 enable_cache: true,
56 }
57 }
58
59 pub fn with_cache(mut self, enable: bool) -> Self {
61 self.enable_cache = enable;
62 self
63 }
64
65 pub async fn get_input(
67 &self,
68 request: InputRequest,
69 context: InputContext,
70 ) -> InputResult<InputResponse> {
71 debug!("Getting input for: {} (context: {:?})", request.id, context);
72
73 if self.enable_cache {
75 let cache_key = self.build_cache_key(&request.id, &context);
76 if let Some(value) = self.get_cached_value(&cache_key).await {
77 debug!("Using cached value for: {}", request.id);
78 return Ok(InputResponse {
79 id: request.id,
80 value,
81 cancelled: false,
82 });
83 }
84 }
85
86 let mut response = self.provider.get_input(&request, &context).await;
88
89 if response.is_err() && !request.required {
92 if let Some(default) = &request.default {
93 info!("Using default value for: {}", request.id);
94 response = Ok(InputResponse {
95 id: request.id.clone(),
96 value: default.clone(),
97 cancelled: false,
98 });
99 }
100 }
101
102 if self.enable_cache {
104 if let Ok(ref resp) = response {
105 if !resp.cancelled {
106 let cache_key = self.build_cache_key(&request.id, &context);
107 self.cache_value(cache_key, resp.value.clone()).await;
108 }
109 }
110 }
111
112 response
113 }
114
115 pub async fn get_inputs(
117 &self,
118 requests: Vec<InputRequest>,
119 context: InputContext,
120 ) -> InputResult<Vec<InputResponse>> {
121 let mut responses = Vec::new();
122
123 for request in requests {
124 match self.get_input(request, context.clone()).await {
125 Ok(response) => responses.push(response),
126 Err(e) => {
127 error!("Failed to get input: {}", e);
128 return Err(e);
129 }
130 }
131 }
132
133 Ok(responses)
134 }
135
136 pub async fn clear_cache(&self) {
138 self.cache.write().await.clear();
139 debug!("Input cache cleared");
140 }
141
142 pub async fn clear_cache_for(&self, id: &str, context: &InputContext) {
144 let cache_key = self.build_cache_key(id, context);
145 let mut cache = self.cache.write().await;
146 cache.remove(&cache_key);
147 debug!("Cleared cache for: {}", id);
148 }
149
150 fn build_cache_key(&self, id: &str, context: &InputContext) -> String {
152 let mut key = id.to_string();
153
154 if let Some(server) = &context.server_name {
155 key = format!("{}:{}", key, server);
156 }
157
158 if let Some(tool) = &context.tool_name {
159 key = format!("{}:{}", key, tool);
160 }
161
162 if !context.metadata.is_empty() {
164 let mut metadata_pairs: Vec<_> = context.metadata.iter().collect();
165 metadata_pairs.sort_by(|(k1, _), (k2, _)| k1.cmp(k2)); for (k, v) in metadata_pairs {
168 key = format!("{}:{}={}", key, k, v);
169 }
170 }
171
172 key
173 }
174
175 async fn get_cached_value(&self, key: &str) -> Option<InputValue> {
177 let cache = self.cache.read().await;
178 cache.get(key).cloned()
179 }
180
181 async fn cache_value(&self, key: String, value: InputValue) {
183 let mut cache = self.cache.write().await;
184 cache.insert(key, value);
185 }
186
187 pub async fn get_all_cached_values(&self) -> HashMap<String, InputValue> {
189 let cache = self.cache.read().await;
190 cache.clone()
191 }
192
193 pub async fn set_cached_value(&self, key: String, value: InputValue) {
195 self.cache_value(key, value).await;
196 }
197
198 pub async fn remove_cached_value(&self, key: &str) -> Option<InputValue> {
200 let mut cache = self.cache.write().await;
201 cache.remove(key)
202 }
203
204 pub async fn clear_all_cache(&self) {
206 self.cache.write().await.clear();
207 }
208
209 pub fn create_request_from_mcp_input(
211 &self,
212 mcp_input: &crate::mcp_clients::model::MCPServerInput,
213 default: Option<InputValue>,
214 ) -> InputRequest {
215 match mcp_input {
216 crate::mcp_clients::model::MCPServerInput::PromptString(input) => InputRequest {
217 id: input.id.clone(),
218 input_type: InputType::String {
219 password: input.password,
220 min_length: None,
221 max_length: None,
222 },
223 title: input.description.clone(),
224 description: input.description.clone(),
225 default,
226 required: true,
227 validation: None,
228 },
229 crate::mcp_clients::model::MCPServerInput::PickString(input) => InputRequest {
230 id: input.id.clone(),
231 input_type: InputType::PickString {
232 options: input.options.clone(),
233 multiple: false,
234 },
235 title: input.description.clone(),
236 description: input.description.clone(),
237 default,
238 required: true,
239 validation: None,
240 },
241 crate::mcp_clients::model::MCPServerInput::Command(input) => InputRequest {
242 id: input.id.clone(),
243 input_type: InputType::Command {
244 command: input.command.clone(),
245 args: input
246 .args
247 .as_ref()
248 .map(|m| {
249 let mut sorted_pairs: Vec<_> = m.iter().collect();
250 sorted_pairs.sort_by_key(|(k, _)| *k);
251 sorted_pairs.into_iter().map(|(_, v)| v.clone()).collect()
252 })
253 .unwrap_or_default(),
254 },
255 title: input.description.clone(),
256 description: input.description.clone(),
257 default,
258 required: true,
259 validation: None,
260 },
261 }
262 }
263
264 pub async fn handle_mcp_inputs(
266 &self,
267 inputs: &[crate::mcp_clients::model::MCPServerInput],
268 context: InputContext,
269 ) -> InputResult<HashMap<String, InputValue>> {
270 let mut results = HashMap::new();
271 let mut requests = Vec::new();
272
273 for input in inputs {
275 let request = self.create_request_from_mcp_input(input, None);
276 requests.push(request);
277 }
278
279 let responses = self.get_inputs(requests, context).await?;
281
282 for response in responses {
284 results.insert(response.id, response.value);
285 }
286
287 Ok(results)
288 }
289}
290
291impl Default for InputHandler {
292 fn default() -> Self {
293 Self::new()
294 }
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp_clients::model::*;

    /// A freshly constructed handler has caching enabled.
    #[tokio::test]
    async fn test_input_handler_creation() {
        let handler = InputHandler::new();
        assert!(handler.enable_cache);
    }

    /// Server and tool names are appended to the id, separated by colons.
    #[tokio::test]
    async fn test_cache_key_generation() {
        let handler = InputHandler::new();
        let context = InputContext::new()
            .with_server_name("test_server".to_string())
            .with_tool_name("test_tool".to_string());

        let key = handler.build_cache_key("test_input", &context);
        assert_eq!(key, "test_input:test_server:test_tool");
    }

    /// A value written to the cache can be read back under the same key.
    #[tokio::test]
    async fn test_cache_operations() {
        let handler = InputHandler::new();

        let key = "test_key";
        let value = InputValue::String("test_value".to_string());

        handler.cache_value(key.to_string(), value.clone()).await;
        let cached = handler.get_cached_value(key).await;

        assert_eq!(cached, Some(value));
    }

    /// A prompt-string MCP input maps to a required string request whose
    /// title and description both come from the MCP description.
    #[tokio::test]
    async fn test_create_request_from_mcp_input() {
        let handler = InputHandler::new();

        let mcp_input = MCPServerInput::PromptString(PromptStringInput {
            id: "test_input".to_string(),
            description: "Test input".to_string(),
            default: Some("default".to_string()),
            password: Some(false),
        });

        let request = handler.create_request_from_mcp_input(&mcp_input, None);

        assert_eq!(request.id, "test_input");
        assert_eq!(request.title, "Test input");
        assert_eq!(request.description, "Test input");
        assert!(request.required);

        match request.input_type {
            InputType::String { password, .. } => {
                assert_eq!(password, Some(false));
            }
            _ => panic!("Expected string input type"),
        }
    }
}