Skip to main content

vtcode_core/mcp/
enhanced_config.rs

1//! Enhanced MCP Configuration
2//!
3//! This module provides enhanced configuration options for MCP with
4//! improved validation, security features, and better error handling.
5
6use serde::{Deserialize, Serialize};
7
8// Import canonical types from vtcode-config instead of defining locally
9pub use vtcode_config::mcp::{McpRateLimitConfig, McpValidationConfig};
10
11use tracing::{debug, warn};
12
13/// Enhanced security configuration for MCP
14#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
15#[derive(Debug, Clone, Deserialize, Serialize)]
16pub struct EnhancedMcpSecurityConfig {
17    /// Enable authentication for MCP server
18    #[serde(default = "default_auth_enabled")]
19    pub auth_enabled: bool,
20
21    /// API key environment variable name
22    #[serde(default)]
23    pub api_key_env: Option<String>,
24
25    /// Rate limiting configuration
26    #[serde(default)]
27    pub rate_limit: McpRateLimitConfig,
28
29    /// Tool call validation configuration
30    #[serde(default)]
31    pub validation: McpValidationConfig,
32}
33
34impl Default for EnhancedMcpSecurityConfig {
35    fn default() -> Self {
36        Self {
37            auth_enabled: default_auth_enabled(),
38            api_key_env: None,
39            rate_limit: McpRateLimitConfig::default(),
40            validation: McpValidationConfig::default(),
41        }
42    }
43}
44
// Serde `default = "..."` helpers: each returns the value used when the
// corresponding key is absent from the input.

/// Authentication is opt-in, so it starts out disabled.
fn default_auth_enabled() -> bool {
    false
}
49
50/// Enhanced MCP client configuration with validation
51#[derive(Debug, Clone)]
52pub struct ValidatedMcpClientConfig {
53    /// Original configuration
54    pub original: crate::config::mcp::McpClientConfig,
55    /// Enhanced security configuration
56    pub security: EnhancedMcpSecurityConfig,
57}
58
59impl ValidatedMcpClientConfig {
60    /// Create a new validated configuration from the original
61    pub fn new(original: crate::config::mcp::McpClientConfig) -> Self {
62        let security = EnhancedMcpSecurityConfig::default();
63        Self { original, security }
64    }
65
66    /// Validate the configuration and return any issues found
67    pub fn validate(&self) -> Vec<ValidationError> {
68        let mut errors = Vec::new();
69
70        // Validate server configuration if enabled
71        if self.original.server.enabled {
72            // Validate port range
73            if self.original.server.port == 0 {
74                errors.push(ValidationError::InvalidPort(
75                    self.original.server.port.into(),
76                ));
77            }
78
79            // Validate bind address
80            if self.original.server.bind_address.is_empty() {
81                errors.push(ValidationError::EmptyBindAddress);
82            }
83
84            // Validate security settings if auth is enabled
85            if self.security.auth_enabled && self.security.api_key_env.is_none() {
86                errors.push(ValidationError::MissingApiKeyEnv);
87            }
88        }
89
90        // Validate timeouts
91        if let Some(startup_timeout) = self.original.startup_timeout_seconds
92            && startup_timeout > 300
93        {
94            // Max 5 minutes
95            errors.push(ValidationError::InvalidStartupTimeout(startup_timeout));
96        }
97
98        if let Some(tool_timeout) = self.original.tool_timeout_seconds
99            && tool_timeout > 3600
100        {
101            // Max 1 hour
102            errors.push(ValidationError::InvalidToolTimeout(tool_timeout));
103        }
104
105        // Validate provider configurations
106        for provider in &self.original.providers {
107            if provider.name.is_empty() {
108                errors.push(ValidationError::EmptyProviderName);
109            }
110
111            // Validate max_concurrent_requests
112            if provider.max_concurrent_requests == 0 {
113                errors.push(ValidationError::InvalidMaxConcurrentRequests(
114                    provider.name.clone(),
115                    provider.max_concurrent_requests,
116                ));
117            }
118        }
119
120        errors
121    }
122
123    /// Check if the configuration is valid
124    pub fn is_valid(&self) -> bool {
125        self.validate().is_empty()
126    }
127
128    /// Log any validation warnings
129    pub fn log_warnings(&self) {
130        let errors = self.validate();
131        if !errors.is_empty() {
132            warn!("MCP configuration validation issues found:");
133            for error in errors {
134                warn!("  - {}", error);
135            }
136        } else {
137            debug!("MCP configuration validation passed");
138        }
139    }
140}
141
142/// Validation error types
143#[derive(Debug, Clone, thiserror::Error)]
144pub enum ValidationError {
145    #[error("Invalid server port: {0}")]
146    InvalidPort(u64),
147    #[error("Server bind address cannot be empty")]
148    EmptyBindAddress,
149    #[error("API key environment variable must be set when auth is enabled")]
150    MissingApiKeyEnv,
151    #[error("Startup timeout cannot exceed 300 seconds: {0}")]
152    InvalidStartupTimeout(u64),
153    #[error("Tool timeout cannot exceed 3600 seconds: {0}")]
154    InvalidToolTimeout(u64),
155    #[error("MCP provider name cannot be empty")]
156    EmptyProviderName,
157    #[error("Max concurrent requests must be greater than 0 for provider '{0}': {1}")]
158    InvalidMaxConcurrentRequests(String, usize),
159}
160
161/// Enhanced tool configuration
162#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
163#[derive(Debug, Clone, Deserialize, Serialize)]
164pub struct EnhancedMcpToolConfig {
165    /// Name of the tool to expose
166    pub name: String,
167    /// Whether the tool is enabled
168    #[serde(default = "default_tool_enabled")]
169    pub enabled: bool,
170    /// Optional description override
171    pub description: Option<String>,
172    /// Rate limiting for this specific tool
173    #[serde(default)]
174    pub rate_limit: Option<McpRateLimitConfig>,
175    /// Validation rules specific to this tool
176    #[serde(default)]
177    pub validation: Option<McpValidationConfig>,
178}
179
/// Serde default for `EnhancedMcpToolConfig::enabled`: listed tools are on
/// unless explicitly disabled.
fn default_tool_enabled() -> bool {
    true
}
183
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::mcp::{
        McpClientConfig, McpProviderConfig, McpServerConfig, McpStdioServerConfig,
        McpTransportConfig,
    };
    use hashbrown::HashMap;

    /// Builds a configuration that passes every validation rule, so each test
    /// can break exactly one field and assert on the resulting error.
    fn create_test_config() -> McpClientConfig {
        McpClientConfig {
            enabled: true,
            ui: Default::default(),
            providers: vec![McpProviderConfig {
                name: "test_provider".to_owned(),
                transport: McpTransportConfig::Stdio(McpStdioServerConfig {
                    command: "test_command".to_owned(),
                    args: vec![],
                    working_directory: None,
                }),
                env: HashMap::new(),
                enabled: true,
                max_concurrent_requests: 5,
                startup_timeout_ms: None,
            }],
            server: McpServerConfig {
                enabled: true,
                bind_address: "127.0.0.1".to_owned(),
                port: 3000,
                transport: crate::config::mcp::McpServerTransport::Sse,
                name: "test_server".to_owned(),
                version: "1.0.0".to_owned(),
                exposed_tools: vec![],
            },
            allowlist: Default::default(),
            requirements: Default::default(),
            max_concurrent_connections: 10,
            request_timeout_seconds: 30,
            retry_attempts: 3,
            startup_timeout_seconds: Some(60),
            tool_timeout_seconds: Some(300),
            experimental_use_rmcp_client: false,
            connection_pooling_enabled: true,
            tool_cache_capacity: 128,
            connection_timeout_seconds: 30,
            security: Default::default(),
            lifecycle: Default::default(),
        }
    }

    /// Validates `original` with the default security settings.
    fn validation_errors(original: McpClientConfig) -> Vec<ValidationError> {
        ValidatedMcpClientConfig::new(original).validate()
    }

    #[test]
    fn test_validated_config_creation() {
        let original = create_test_config();
        let validated = ValidatedMcpClientConfig::new(original);
        assert!(validated.is_valid());
    }

    #[test]
    fn test_max_port_is_valid() {
        let mut original = create_test_config();
        original.server.port = 65535; // Max valid port
        assert!(ValidatedMcpClientConfig::new(original).is_valid());
    }

    #[test]
    fn test_invalid_port_validation() {
        let mut original = create_test_config();
        original.server.port = 0;
        let errors = validation_errors(original);
        assert!(matches!(
            errors.as_slice(),
            [ValidationError::InvalidPort(0)]
        ));
    }

    #[test]
    fn test_empty_bind_address_validation() {
        let mut original = create_test_config();
        original.server.bind_address = String::new(); // Empty bind address
        let errors = validation_errors(original);
        assert!(matches!(
            errors.as_slice(),
            [ValidationError::EmptyBindAddress]
        ));
    }

    #[test]
    fn test_disabled_server_skips_server_checks() {
        let mut original = create_test_config();
        original.server.enabled = false;
        original.server.port = 0;
        original.server.bind_address = String::new();
        assert!(ValidatedMcpClientConfig::new(original).is_valid());
    }

    #[test]
    fn test_auth_requires_api_key_env() {
        let mut validated = ValidatedMcpClientConfig::new(create_test_config());
        validated.security.auth_enabled = true;
        assert!(matches!(
            validated.validate().as_slice(),
            [ValidationError::MissingApiKeyEnv]
        ));

        validated.security.api_key_env = Some("MCP_API_KEY".to_owned());
        assert!(validated.is_valid());
    }

    #[test]
    fn test_timeout_validation() {
        let mut original = create_test_config();
        original.startup_timeout_seconds = Some(400); // Too long
        let errors = validation_errors(original);
        assert!(matches!(
            errors.as_slice(),
            [ValidationError::InvalidStartupTimeout(400)]
        ));
    }

    #[test]
    fn test_tool_timeout_validation() {
        let mut original = create_test_config();
        original.tool_timeout_seconds = Some(3601); // Just over one hour
        let errors = validation_errors(original);
        assert!(matches!(
            errors.as_slice(),
            [ValidationError::InvalidToolTimeout(3601)]
        ));
    }

    #[test]
    fn test_timeout_limits_are_inclusive() {
        let mut original = create_test_config();
        original.startup_timeout_seconds = Some(300);
        original.tool_timeout_seconds = Some(3600);
        assert!(ValidatedMcpClientConfig::new(original).is_valid());
    }

    #[test]
    fn test_empty_provider_name_validation() {
        let mut original = create_test_config();
        original.providers[0].name = String::new();
        let errors = validation_errors(original);
        assert!(matches!(
            errors.as_slice(),
            [ValidationError::EmptyProviderName]
        ));
    }

    #[test]
    fn test_zero_concurrent_requests_validation() {
        let mut original = create_test_config();
        original.providers[0].max_concurrent_requests = 0; // Invalid
        let errors = validation_errors(original);
        assert!(matches!(
            errors.as_slice(),
            [ValidationError::InvalidMaxConcurrentRequests(name, 0)] if name == "test_provider"
        ));
    }
}