1use super::model::*;
11use async_trait::async_trait;
12use std::env;
13use std::io::{self, Write};
14use std::process::Command;
15use std::time::Duration;
16use tokio::time::timeout;
17use tracing::{debug, warn};
18
19#[async_trait]
21pub trait InputProvider: Send + Sync {
22 async fn get_input(
24 &self,
25 request: &InputRequest,
26 context: &InputContext,
27 ) -> InputResult<InputResponse>;
28}
29
/// Input provider that prompts on stdout and reads answers from stdin.
pub struct CliInputProvider {
    /// Upper bound on how long a single prompt waits for an answer.
    timeout: Duration,
}
35
36impl CliInputProvider {
37 pub fn new() -> Self {
39 Self {
40 timeout: Duration::from_secs(300), }
42 }
43
44 pub fn with_timeout(mut self, timeout: Duration) -> Self {
46 self.timeout = timeout;
47 self
48 }
49
50 async fn read_string(&self, prompt: &str, password: bool) -> InputResult<String> {
52 let future = async {
53 print!("{}", prompt);
54 io::stdout().flush().map_err(InputError::IoError)?;
55
56 if password {
57 let mut input = String::new();
60 io::stdin()
61 .read_line(&mut input)
62 .map_err(InputError::IoError)?;
63 Ok(input.trim_end().to_string())
64 } else {
65 let mut input = String::new();
66 io::stdin()
67 .read_line(&mut input)
68 .map_err(InputError::IoError)?;
69 Ok(input.trim_end().to_string())
70 }
71 };
72
73 timeout(self.timeout, future)
74 .await
75 .map_err(|_| InputError::Timeout)?
76 }
77
78 async fn read_pick(&self, prompt: &str, options: &[String]) -> InputResult<String> {
80 println!("{}", prompt);
81 for (i, option) in options.iter().enumerate() {
82 println!(" {}) {}", i + 1, option);
83 }
84
85 loop {
86 let input = self
87 .read_string("请输入选项编号 (Please enter option number): ", false)
88 .await?;
89
90 match input.parse::<usize>() {
91 Ok(n) if n >= 1 && n <= options.len() => {
92 return Ok(options[n - 1].clone());
93 }
94 _ => {
95 println!("无效选项,请重新输入 (Invalid option, please try again)");
96 }
97 }
98 }
99 }
100
101 async fn read_number(&self, prompt: &str) -> InputResult<i64> {
103 loop {
104 let input = self.read_string(prompt, false).await?;
105
106 match input.parse::<i64>() {
107 Ok(n) => return Ok(n),
108 _ => {
109 println!("无效数字,请重新输入 (Invalid number, please try again)");
110 }
111 }
112 }
113 }
114
115 async fn read_bool(
117 &self,
118 prompt: &str,
119 true_label: Option<&str>,
120 false_label: Option<&str>,
121 ) -> InputResult<bool> {
122 let true_label = true_label.unwrap_or("是/yes");
123 let false_label = false_label.unwrap_or("否/no");
124
125 loop {
126 let input = self
127 .read_string(
128 &format!("{} ({}/{}): ", prompt, true_label, false_label),
129 false,
130 )
131 .await?;
132 let input = input.to_lowercase();
133
134 if input == "y" || input == "yes" || input == "是" {
135 return Ok(true);
136 } else if input == "n" || input == "no" || input == "否" {
137 return Ok(false);
138 } else {
139 println!("无效选项,请重新输入 (Invalid option, please try again)");
140 }
141 }
142 }
143
144 fn validate_input(&self, value: &str, validation: &Option<ValidationRule>) -> InputResult<()> {
146 if let Some(rule) = validation {
147 match rule {
148 ValidationRule::Regex { pattern, message } => {
149 let regex = regex::Regex::new(pattern).map_err(|e| {
150 InputError::ValidationFailed(format!("Invalid regex pattern: {}", e))
151 })?;
152
153 if !regex.is_match(value) {
154 let msg = message
155 .as_deref()
156 .unwrap_or("输入格式不正确 (Input format is incorrect)");
157 return Err(InputError::ValidationFailed(msg.to_string()));
158 }
159 }
160 ValidationRule::Custom { .. } => {
161 warn!("Custom validation not implemented for CLI provider");
163 }
164 }
165 }
166 Ok(())
167 }
168}
169
170#[async_trait]
171impl InputProvider for CliInputProvider {
172 async fn get_input(
173 &self,
174 request: &InputRequest,
175 _context: &InputContext,
176 ) -> InputResult<InputResponse> {
177 let prompt = format!("{}: {}", request.title, request.description);
178
179 let value = match &request.input_type {
180 InputType::String {
181 password,
182 min_length,
183 max_length,
184 } => {
185 let input = self.read_string(&prompt, password.unwrap_or(false)).await?;
186
187 if let Some(min) = min_length {
189 if input.len() < *min {
190 return Err(InputError::ValidationFailed(format!(
191 "输入长度不能少于{}个字符 (Minimum length is {})",
192 min, min
193 )));
194 }
195 }
196 if let Some(max) = max_length {
197 if input.len() > *max {
198 return Err(InputError::ValidationFailed(format!(
199 "输入长度不能超过{}个字符 (Maximum length is {})",
200 max, max
201 )));
202 }
203 }
204
205 self.validate_input(&input, &request.validation)?;
207
208 InputValue::String(input)
209 }
210 InputType::PickString { options, .. } => {
211 let selected = self.read_pick(&prompt, options).await?;
212 InputValue::String(selected)
213 }
214 InputType::Number { min, max } => {
215 let num = self.read_number(&prompt).await?;
216
217 if let Some(min_val) = min {
219 if num < *min_val {
220 return Err(InputError::ValidationFailed(format!(
221 "数值不能小于{} (Minimum value is {})",
222 min_val, min_val
223 )));
224 }
225 }
226 if let Some(max_val) = max {
227 if num > *max_val {
228 return Err(InputError::ValidationFailed(format!(
229 "数值不能大于{} (Maximum value is {})",
230 max_val, max_val
231 )));
232 }
233 }
234
235 InputValue::Number(num)
236 }
237 InputType::Bool {
238 true_label,
239 false_label,
240 } => {
241 let bool_val = self
242 .read_bool(&prompt, true_label.as_deref(), false_label.as_deref())
243 .await?;
244 InputValue::Bool(bool_val)
245 }
246 InputType::FilePath { must_exist, filter } => {
247 let path = self.read_string(&prompt, false).await?;
248
249 if *must_exist && !std::path::Path::new(&path).exists() {
251 return Err(InputError::ValidationFailed(
252 "文件不存在 (File does not exist)".to_string(),
253 ));
254 }
255
256 if let Some(filter) = filter {
258 if !path.ends_with(filter) {
259 return Err(InputError::ValidationFailed(format!(
260 "文件类型不匹配,期望: {} (File type mismatch, expected: {})",
261 filter, filter
262 )));
263 }
264 }
265
266 InputValue::String(path)
267 }
268 InputType::Command { command, args } => {
269 debug!("Executing command: {} {:?}", command, args);
270 let output = if cfg!(target_os = "windows") {
271 let mut cmd = Command::new("cmd");
273 cmd.arg("/C");
274 cmd.arg(command);
275 for arg in args {
276 cmd.arg(arg);
277 }
278 cmd.output()
279 } else {
280 let mut cmd = Command::new("sh");
282 cmd.arg("-c");
283 let shell_cmd = if args.is_empty() {
285 command.clone()
286 } else {
287 format!("{} {}", command, args.join(" "))
288 };
289 cmd.arg(&shell_cmd);
290 cmd.output()
291 }
292 .map_err(|e| InputError::Other(format!("Command execution failed: {}", e)))?;
293
294 if !output.status.success() {
295 let stderr = String::from_utf8_lossy(&output.stderr);
296 return Err(InputError::Other(format!(
297 "Command failed with exit code {}: {}",
298 output.status.code().unwrap_or(-1),
299 stderr
300 )));
301 }
302
303 let result = String::from_utf8_lossy(&output.stdout).trim().to_string();
304 InputValue::String(result)
305 }
306 };
307
308 Ok(InputResponse {
309 id: request.id.clone(),
310 value,
311 cancelled: false,
312 })
313 }
314}
315
316impl Default for CliInputProvider {
317 fn default() -> Self {
318 Self::new()
319 }
320}
321
/// Input provider that answers requests from process environment variables.
pub struct EnvironmentInputProvider {
    /// Text placed in front of every environment variable name that is looked up.
    prefix: String,
}
327
328impl EnvironmentInputProvider {
329 pub fn new() -> Self {
331 Self {
332 prefix: "A2C_SMCP_".to_string(),
333 }
334 }
335
336 pub fn with_prefix(mut self, prefix: String) -> Self {
338 self.prefix = prefix;
339 self
340 }
341
342 fn build_env_name(&self, id: &str, context: &InputContext) -> String {
344 let mut name = format!("{}{}", self.prefix, id.to_uppercase());
345
346 if let Some(server) = &context.server_name {
347 name = format!("{}_{}", name, server.to_uppercase());
348 }
349
350 if let Some(tool) = &context.tool_name {
351 name = format!("{}_{}", name, tool.to_uppercase());
352 }
353
354 name
355 }
356}
357
358#[async_trait]
359impl InputProvider for EnvironmentInputProvider {
360 async fn get_input(
361 &self,
362 request: &InputRequest,
363 context: &InputContext,
364 ) -> InputResult<InputResponse> {
365 let env_name = self.build_env_name(&request.id, context);
366
367 debug!("Looking for environment variable: {}", env_name);
368
369 match env::var(&env_name) {
370 Ok(value) => {
371 let converted_value = match &request.input_type {
373 InputType::String { .. } => InputValue::String(value),
374 InputType::PickString { .. } => InputValue::String(value),
375 InputType::FilePath { .. } => InputValue::String(value),
376 InputType::Command { .. } => InputValue::String(value),
377 InputType::Number { .. } => {
378 value.parse::<i64>().map(InputValue::Number).map_err(|_| {
379 InputError::ValidationFailed(format!(
380 "Invalid number in environment variable: {}",
381 env_name
382 ))
383 })?
384 }
385 InputType::Bool { .. } => {
386 let lower = value.to_lowercase();
387 if lower == "true" || lower == "1" || lower == "yes" || lower == "是" {
388 InputValue::Bool(true)
389 } else if lower == "false" || lower == "0" || lower == "no" || lower == "否"
390 {
391 InputValue::Bool(false)
392 } else {
393 return Err(InputError::ValidationFailed(format!(
394 "Invalid boolean value in environment variable: {}",
395 env_name
396 )));
397 }
398 }
399 };
400
401 Ok(InputResponse {
402 id: request.id.clone(),
403 value: converted_value,
404 cancelled: false,
405 })
406 }
407 Err(env::VarError::NotPresent) => {
408 if let Some(default) = &request.default {
411 Ok(InputResponse {
412 id: request.id.clone(),
413 value: default.clone(),
414 cancelled: false,
415 })
416 } else if request.required {
417 Err(InputError::ValidationFailed(format!(
418 "Required environment variable not found: {}",
419 env_name
420 )))
421 } else {
422 Err(InputError::Cancelled)
423 }
424 }
425 Err(e) => Err(InputError::Other(format!(
426 "Environment variable error: {}",
427 e
428 ))),
429 }
430 }
431}
432
433impl Default for EnvironmentInputProvider {
434 fn default() -> Self {
435 Self::new()
436 }
437}
438
439pub struct CompositeInputProvider {
441 providers: Vec<Box<dyn InputProvider>>,
443}
444
445impl CompositeInputProvider {
446 pub fn new() -> Self {
448 Self {
449 providers: Vec::new(),
450 }
451 }
452
453 pub fn add_provider(mut self, provider: Box<dyn InputProvider>) -> Self {
455 self.providers.push(provider);
456 self
457 }
458}
459
460#[async_trait]
461impl InputProvider for CompositeInputProvider {
462 async fn get_input(
463 &self,
464 request: &InputRequest,
465 context: &InputContext,
466 ) -> InputResult<InputResponse> {
467 for provider in &self.providers {
469 match provider.get_input(request, context).await {
470 Ok(response) => return Ok(response),
471 Err(InputError::Cancelled) => {
472 continue;
475 }
476 Err(e) => {
477 return Err(e);
479 }
480 }
481 }
482
483 Err(InputError::Cancelled)
485 }
486}
487
488impl Default for CompositeInputProvider {
489 fn default() -> Self {
490 Self::new()
491 }
492}
493
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built CLI provider uses the 300-second timeout.
    #[tokio::test]
    async fn test_cli_provider_creation() {
        let cli = CliInputProvider::new();
        let seconds = cli.timeout.as_secs();
        assert_eq!(seconds, 300);
    }

    /// A freshly built environment provider uses the `A2C_SMCP_` prefix.
    #[tokio::test]
    async fn test_environment_provider_creation() {
        let env_provider = EnvironmentInputProvider::new();
        assert_eq!(env_provider.prefix, "A2C_SMCP_");
    }

    /// `with_prefix` replaces the built-in prefix.
    #[tokio::test]
    async fn test_environment_provider_custom_prefix() {
        let custom_prefix = String::from("CUSTOM_");
        let env_provider = EnvironmentInputProvider::new().with_prefix(custom_prefix);
        assert_eq!(env_provider.prefix, "CUSTOM_");
    }
}