Skip to main content

smcp_computer/inputs/
providers.rs

1/**
2* 文件名: providers
3* 作者: JQQ
4* 创建日期: 2025/12/15
5* 最后修改日期: 2025/12/15
6* 版权: 2023 JQQ. All rights reserved.
7* 依赖: tokio, async-trait
8* 描述: 输入提供者实现,支持CLI、环境变量等多种输入方式
9*/
10use super::model::*;
11use async_trait::async_trait;
12use std::env;
13use std::io::{self, Write};
14use std::process::Command;
15use std::time::Duration;
16use tokio::time::timeout;
17use tracing::{debug, warn};
18
19/// 输入提供者trait / Input provider trait
20#[async_trait]
21pub trait InputProvider: Send + Sync {
22    /// 获取输入 / Get input
23    async fn get_input(
24        &self,
25        request: &InputRequest,
26        context: &InputContext,
27    ) -> InputResult<InputResponse>;
28}
29
/// Input provider that prompts on stdout and reads the answer from stdin.
#[derive(Debug, Clone)]
pub struct CliInputProvider {
    /// Timeout applied to each individual read from stdin.
    timeout: Duration,
}
35
36impl CliInputProvider {
37    /// 创建新的CLI输入提供者 / Create new CLI input provider
38    pub fn new() -> Self {
39        Self {
40            timeout: Duration::from_secs(300), // 5分钟默认超时 / 5 minutes default timeout
41        }
42    }
43
44    /// 设置超时时间 / Set timeout duration
45    pub fn with_timeout(mut self, timeout: Duration) -> Self {
46        self.timeout = timeout;
47        self
48    }
49
50    /// 从标准输入读取字符串 / Read string from stdin
51    async fn read_string(&self, prompt: &str, password: bool) -> InputResult<String> {
52        let future = async {
53            print!("{}", prompt);
54            io::stdout().flush().map_err(InputError::IoError)?;
55
56            if password {
57                // 使用rpassword库读取密码 / Use rpassword library to read password
58                // 这里简化处理,实际应该使用rpassword / Simplified here, should use rpassword in practice
59                let mut input = String::new();
60                io::stdin()
61                    .read_line(&mut input)
62                    .map_err(InputError::IoError)?;
63                Ok(input.trim_end().to_string())
64            } else {
65                let mut input = String::new();
66                io::stdin()
67                    .read_line(&mut input)
68                    .map_err(InputError::IoError)?;
69                Ok(input.trim_end().to_string())
70            }
71        };
72
73        timeout(self.timeout, future)
74            .await
75            .map_err(|_| InputError::Timeout)?
76    }
77
78    /// 从标准输入读取选择 / Read pick from stdin
79    async fn read_pick(&self, prompt: &str, options: &[String]) -> InputResult<String> {
80        println!("{}", prompt);
81        for (i, option) in options.iter().enumerate() {
82            println!("  {}) {}", i + 1, option);
83        }
84
85        loop {
86            let input = self
87                .read_string("请输入选项编号 (Please enter option number): ", false)
88                .await?;
89
90            match input.parse::<usize>() {
91                Ok(n) if n >= 1 && n <= options.len() => {
92                    return Ok(options[n - 1].clone());
93                }
94                _ => {
95                    println!("无效选项,请重新输入 (Invalid option, please try again)");
96                }
97            }
98        }
99    }
100
101    /// 从标准输入读取数字 / Read number from stdin
102    async fn read_number(&self, prompt: &str) -> InputResult<i64> {
103        loop {
104            let input = self.read_string(prompt, false).await?;
105
106            match input.parse::<i64>() {
107                Ok(n) => return Ok(n),
108                _ => {
109                    println!("无效数字,请重新输入 (Invalid number, please try again)");
110                }
111            }
112        }
113    }
114
115    /// 从标准输入读取布尔值 / Read boolean from stdin
116    async fn read_bool(
117        &self,
118        prompt: &str,
119        true_label: Option<&str>,
120        false_label: Option<&str>,
121    ) -> InputResult<bool> {
122        let true_label = true_label.unwrap_or("是/yes");
123        let false_label = false_label.unwrap_or("否/no");
124
125        loop {
126            let input = self
127                .read_string(
128                    &format!("{} ({}/{}): ", prompt, true_label, false_label),
129                    false,
130                )
131                .await?;
132            let input = input.to_lowercase();
133
134            if input == "y" || input == "yes" || input == "是" {
135                return Ok(true);
136            } else if input == "n" || input == "no" || input == "否" {
137                return Ok(false);
138            } else {
139                println!("无效选项,请重新输入 (Invalid option, please try again)");
140            }
141        }
142    }
143
144    /// 验证输入 / Validate input
145    fn validate_input(&self, value: &str, validation: &Option<ValidationRule>) -> InputResult<()> {
146        if let Some(rule) = validation {
147            match rule {
148                ValidationRule::Regex { pattern, message } => {
149                    let regex = regex::Regex::new(pattern).map_err(|e| {
150                        InputError::ValidationFailed(format!("Invalid regex pattern: {}", e))
151                    })?;
152
153                    if !regex.is_match(value) {
154                        let msg = message
155                            .as_deref()
156                            .unwrap_or("输入格式不正确 (Input format is incorrect)");
157                        return Err(InputError::ValidationFailed(msg.to_string()));
158                    }
159                }
160                ValidationRule::Custom { .. } => {
161                    // 自定义验证需要在更高层实现 / Custom validation needs to be implemented at higher level
162                    warn!("Custom validation not implemented for CLI provider");
163                }
164            }
165        }
166        Ok(())
167    }
168}
169
170#[async_trait]
171impl InputProvider for CliInputProvider {
172    async fn get_input(
173        &self,
174        request: &InputRequest,
175        _context: &InputContext,
176    ) -> InputResult<InputResponse> {
177        let prompt = format!("{}: {}", request.title, request.description);
178
179        let value = match &request.input_type {
180            InputType::String {
181                password,
182                min_length,
183                max_length,
184            } => {
185                let input = self.read_string(&prompt, password.unwrap_or(false)).await?;
186
187                // 验证长度 / Validate length
188                if let Some(min) = min_length {
189                    if input.len() < *min {
190                        return Err(InputError::ValidationFailed(format!(
191                            "输入长度不能少于{}个字符 (Minimum length is {})",
192                            min, min
193                        )));
194                    }
195                }
196                if let Some(max) = max_length {
197                    if input.len() > *max {
198                        return Err(InputError::ValidationFailed(format!(
199                            "输入长度不能超过{}个字符 (Maximum length is {})",
200                            max, max
201                        )));
202                    }
203                }
204
205                // 验证格式 / Validate format
206                self.validate_input(&input, &request.validation)?;
207
208                InputValue::String(input)
209            }
210            InputType::PickString { options, .. } => {
211                let selected = self.read_pick(&prompt, options).await?;
212                InputValue::String(selected)
213            }
214            InputType::Number { min, max } => {
215                let num = self.read_number(&prompt).await?;
216
217                // 验证范围 / Validate range
218                if let Some(min_val) = min {
219                    if num < *min_val {
220                        return Err(InputError::ValidationFailed(format!(
221                            "数值不能小于{} (Minimum value is {})",
222                            min_val, min_val
223                        )));
224                    }
225                }
226                if let Some(max_val) = max {
227                    if num > *max_val {
228                        return Err(InputError::ValidationFailed(format!(
229                            "数值不能大于{} (Maximum value is {})",
230                            max_val, max_val
231                        )));
232                    }
233                }
234
235                InputValue::Number(num)
236            }
237            InputType::Bool {
238                true_label,
239                false_label,
240            } => {
241                let bool_val = self
242                    .read_bool(&prompt, true_label.as_deref(), false_label.as_deref())
243                    .await?;
244                InputValue::Bool(bool_val)
245            }
246            InputType::FilePath { must_exist, filter } => {
247                let path = self.read_string(&prompt, false).await?;
248
249                // 检查文件是否存在 / Check if file exists
250                if *must_exist && !std::path::Path::new(&path).exists() {
251                    return Err(InputError::ValidationFailed(
252                        "文件不存在 (File does not exist)".to_string(),
253                    ));
254                }
255
256                // 检查文件类型 / Check file type
257                if let Some(filter) = filter {
258                    if !path.ends_with(filter) {
259                        return Err(InputError::ValidationFailed(format!(
260                            "文件类型不匹配,期望: {} (File type mismatch, expected: {})",
261                            filter, filter
262                        )));
263                    }
264                }
265
266                InputValue::String(path)
267            }
268            InputType::Command { command, args } => {
269                debug!("Executing command: {} {:?}", command, args);
270                let output = if cfg!(target_os = "windows") {
271                    // Windows: Use cmd /C for shell mode
272                    let mut cmd = Command::new("cmd");
273                    cmd.arg("/C");
274                    cmd.arg(command);
275                    for arg in args {
276                        cmd.arg(arg);
277                    }
278                    cmd.output()
279                } else {
280                    // Unix: Use sh -c for shell mode
281                    let mut cmd = Command::new("sh");
282                    cmd.arg("-c");
283                    // Combine command and args into a single shell string
284                    let shell_cmd = if args.is_empty() {
285                        command.clone()
286                    } else {
287                        format!("{} {}", command, args.join(" "))
288                    };
289                    cmd.arg(&shell_cmd);
290                    cmd.output()
291                }
292                .map_err(|e| InputError::Other(format!("Command execution failed: {}", e)))?;
293
294                if !output.status.success() {
295                    let stderr = String::from_utf8_lossy(&output.stderr);
296                    return Err(InputError::Other(format!(
297                        "Command failed with exit code {}: {}",
298                        output.status.code().unwrap_or(-1),
299                        stderr
300                    )));
301                }
302
303                let result = String::from_utf8_lossy(&output.stdout).trim().to_string();
304                InputValue::String(result)
305            }
306        };
307
308        Ok(InputResponse {
309            id: request.id.clone(),
310            value,
311            cancelled: false,
312        })
313    }
314}
315
316impl Default for CliInputProvider {
317    fn default() -> Self {
318        Self::new()
319    }
320}
321
/// Input provider that resolves requests from process environment variables.
#[derive(Debug, Clone)]
pub struct EnvironmentInputProvider {
    /// Text placed in front of every environment variable name that is looked up.
    prefix: String,
}
327
328impl EnvironmentInputProvider {
329    /// 创建新的环境变量输入提供者 / Create new environment input provider
330    pub fn new() -> Self {
331        Self {
332            prefix: "A2C_SMCP_".to_string(),
333        }
334    }
335
336    /// 设置前缀 / Set prefix
337    pub fn with_prefix(mut self, prefix: String) -> Self {
338        self.prefix = prefix;
339        self
340    }
341
342    /// 构建环境变量名 / Build environment variable name
343    fn build_env_name(&self, id: &str, context: &InputContext) -> String {
344        let mut name = format!("{}{}", self.prefix, id.to_uppercase());
345
346        if let Some(server) = &context.server_name {
347            name = format!("{}_{}", name, server.to_uppercase());
348        }
349
350        if let Some(tool) = &context.tool_name {
351            name = format!("{}_{}", name, tool.to_uppercase());
352        }
353
354        name
355    }
356}
357
358#[async_trait]
359impl InputProvider for EnvironmentInputProvider {
360    async fn get_input(
361        &self,
362        request: &InputRequest,
363        context: &InputContext,
364    ) -> InputResult<InputResponse> {
365        let env_name = self.build_env_name(&request.id, context);
366
367        debug!("Looking for environment variable: {}", env_name);
368
369        match env::var(&env_name) {
370            Ok(value) => {
371                // 根据输入类型转换值 / Convert value based on input type
372                let converted_value = match &request.input_type {
373                    InputType::String { .. } => InputValue::String(value),
374                    InputType::PickString { .. } => InputValue::String(value),
375                    InputType::FilePath { .. } => InputValue::String(value),
376                    InputType::Command { .. } => InputValue::String(value),
377                    InputType::Number { .. } => {
378                        value.parse::<i64>().map(InputValue::Number).map_err(|_| {
379                            InputError::ValidationFailed(format!(
380                                "Invalid number in environment variable: {}",
381                                env_name
382                            ))
383                        })?
384                    }
385                    InputType::Bool { .. } => {
386                        let lower = value.to_lowercase();
387                        if lower == "true" || lower == "1" || lower == "yes" || lower == "是" {
388                            InputValue::Bool(true)
389                        } else if lower == "false" || lower == "0" || lower == "no" || lower == "否"
390                        {
391                            InputValue::Bool(false)
392                        } else {
393                            return Err(InputError::ValidationFailed(format!(
394                                "Invalid boolean value in environment variable: {}",
395                                env_name
396                            )));
397                        }
398                    }
399                };
400
401                Ok(InputResponse {
402                    id: request.id.clone(),
403                    value: converted_value,
404                    cancelled: false,
405                })
406            }
407            Err(env::VarError::NotPresent) => {
408                // 如果环境变量不存在,返回默认值或错误
409                // If environment variable doesn't exist, return default value or error
410                if let Some(default) = &request.default {
411                    Ok(InputResponse {
412                        id: request.id.clone(),
413                        value: default.clone(),
414                        cancelled: false,
415                    })
416                } else if request.required {
417                    Err(InputError::ValidationFailed(format!(
418                        "Required environment variable not found: {}",
419                        env_name
420                    )))
421                } else {
422                    Err(InputError::Cancelled)
423                }
424            }
425            Err(e) => Err(InputError::Other(format!(
426                "Environment variable error: {}",
427                e
428            ))),
429        }
430    }
431}
432
433impl Default for EnvironmentInputProvider {
434    fn default() -> Self {
435        Self::new()
436    }
437}
438
439/// 组合输入提供者 / Composite input provider
440pub struct CompositeInputProvider {
441    /// 提供者列表 / Provider list
442    providers: Vec<Box<dyn InputProvider>>,
443}
444
445impl CompositeInputProvider {
446    /// 创建新的组合输入提供者 / Create new composite input provider
447    pub fn new() -> Self {
448        Self {
449            providers: Vec::new(),
450        }
451    }
452
453    /// 添加提供者 / Add provider
454    pub fn add_provider(mut self, provider: Box<dyn InputProvider>) -> Self {
455        self.providers.push(provider);
456        self
457    }
458}
459
460#[async_trait]
461impl InputProvider for CompositeInputProvider {
462    async fn get_input(
463        &self,
464        request: &InputRequest,
465        context: &InputContext,
466    ) -> InputResult<InputResponse> {
467        // 依次尝试每个提供者 / Try each provider in order
468        for provider in &self.providers {
469            match provider.get_input(request, context).await {
470                Ok(response) => return Ok(response),
471                Err(InputError::Cancelled) => {
472                    // 取消错误继续尝试下一个提供者
473                    // Continue trying next provider for cancelled error
474                    continue;
475                }
476                Err(e) => {
477                    // 其他错误直接返回 / Return other errors directly
478                    return Err(e);
479                }
480            }
481        }
482
483        // 所有提供者都失败 / All providers failed
484        Err(InputError::Cancelled)
485    }
486}
487
488impl Default for CompositeInputProvider {
489    fn default() -> Self {
490        Self::new()
491    }
492}
493
#[cfg(test)]
mod tests {
    use super::*;

    /// A new CLI provider starts with the 300-second default timeout.
    #[tokio::test]
    async fn test_cli_provider_creation() {
        let cli = CliInputProvider::new();
        let seconds = cli.timeout.as_secs();
        assert_eq!(seconds, 300);
    }

    /// A new environment provider starts with the `A2C_SMCP_` prefix.
    #[tokio::test]
    async fn test_environment_provider_creation() {
        let env_provider = EnvironmentInputProvider::new();
        assert_eq!(env_provider.prefix, "A2C_SMCP_");
    }

    /// `with_prefix` replaces the default prefix.
    #[tokio::test]
    async fn test_environment_provider_custom_prefix() {
        let custom_prefix = "CUSTOM_".to_string();
        let env_provider = EnvironmentInputProvider::new().with_prefix(custom_prefix);
        assert_eq!(env_provider.prefix, "CUSTOM_");
    }
}