Skip to main content

vtcode_config/core/
custom_provider.rs

1use std::path::PathBuf;
2
3use serde::{Deserialize, Serialize};
4
/// Serde default for `auth.timeout_ms`: five seconds, expressed in milliseconds.
fn default_auth_timeout_ms() -> u64 {
    const SECONDS: u64 = 5;
    SECONDS * 1_000
}
8
/// Serde default for `auth.refresh_interval_ms`: five minutes, expressed in milliseconds.
fn default_auth_refresh_interval_ms() -> u64 {
    const MINUTES: u64 = 5;
    MINUTES * 60 * 1_000
}
12
13/// Command-backed bearer token configuration for a custom provider.
14#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
15#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
16pub struct CustomProviderCommandAuthConfig {
17    /// Command to execute. Bare names are resolved via `PATH`.
18    pub command: String,
19
20    /// Optional command arguments.
21    #[serde(default)]
22    pub args: Vec<String>,
23
24    /// Optional working directory for the token command.
25    #[serde(default)]
26    pub cwd: Option<PathBuf>,
27
28    /// Maximum time to wait for the command to complete successfully.
29    #[serde(default = "default_auth_timeout_ms")]
30    pub timeout_ms: u64,
31
32    /// Maximum age for the cached token before rerunning the command.
33    #[serde(default = "default_auth_refresh_interval_ms")]
34    pub refresh_interval_ms: u64,
35}
36
37impl CustomProviderCommandAuthConfig {
38    fn validate(&self, provider_name: &str) -> Result<(), String> {
39        if self.command.trim().is_empty() {
40            return Err(format!(
41                "custom_providers[{provider_name}]: `auth.command` must not be empty"
42            ));
43        }
44
45        if self.timeout_ms == 0 {
46            return Err(format!(
47                "custom_providers[{provider_name}]: `auth.timeout_ms` must be greater than 0"
48            ));
49        }
50
51        if self.refresh_interval_ms == 0 {
52            return Err(format!(
53                "custom_providers[{provider_name}]: `auth.refresh_interval_ms` must be greater than 0"
54            ));
55        }
56
57        Ok(())
58    }
59}
60
61/// Configuration for a user-defined OpenAI-compatible provider endpoint.
62///
63/// Allows users to define multiple named custom endpoints (e.g., corporate
64/// proxies) with distinct display names, so they can toggle between them
65/// and clearly see which endpoint is active.
66#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
67#[derive(Debug, Clone, Deserialize, Serialize)]
68pub struct CustomProviderConfig {
69    /// Stable provider key used for routing and persistence (e.g., "mycorp").
70    /// Must be lowercase alphanumeric with optional hyphens/underscores.
71    pub name: String,
72
73    /// Human-friendly label shown in the TUI header, footer, and model picker
74    /// (e.g., "MyCorporateName").
75    pub display_name: String,
76
77    /// Base URL of the OpenAI-compatible API endpoint
78    /// (e.g., `<https://llm.corp.example/v1>`).
79    pub base_url: String,
80
81    /// Environment variable name that holds the API key for this endpoint
82    /// (e.g., "MYCORP_API_KEY").
83    #[serde(default)]
84    pub api_key_env: String,
85
86    /// Optional command-backed bearer token configuration.
87    #[serde(default, skip_serializing_if = "Option::is_none")]
88    pub auth: Option<CustomProviderCommandAuthConfig>,
89
90    /// Default model to use with this endpoint (e.g., "gpt-5-mini").
91    #[serde(default)]
92    pub model: String,
93}
94
95impl CustomProviderConfig {
96    /// Resolve the API key environment variable used for this provider.
97    ///
98    /// Falls back to a derived `NAME_API_KEY`-style variable when the config
99    /// does not set `api_key_env`.
100    pub fn resolved_api_key_env(&self) -> String {
101        if !self.api_key_env.trim().is_empty() {
102            return self.api_key_env.clone();
103        }
104
105        let mut key = String::new();
106        for ch in self.name.chars() {
107            if ch.is_ascii_alphanumeric() {
108                key.push(ch.to_ascii_uppercase());
109            } else if !key.ends_with('_') {
110                key.push('_');
111            }
112        }
113        if !key.ends_with("_API_KEY") {
114            if !key.ends_with('_') {
115                key.push('_');
116            }
117            key.push_str("API_KEY");
118        }
119        key
120    }
121
122    pub fn uses_command_auth(&self) -> bool {
123        self.auth.is_some()
124    }
125
126    /// Validate that required fields are present and the name doesn't collide
127    /// with built-in provider keys.
128    pub fn validate(&self) -> Result<(), String> {
129        if self.name.trim().is_empty() {
130            return Err("custom_providers: `name` must not be empty".to_string());
131        }
132
133        if !is_valid_provider_name(&self.name) {
134            return Err(format!(
135                "custom_providers[{}]: `name` must use lowercase letters, digits, hyphens, or underscores",
136                self.name
137            ));
138        }
139
140        if self.display_name.trim().is_empty() {
141            return Err(format!(
142                "custom_providers[{}]: `display_name` must not be empty",
143                self.name
144            ));
145        }
146
147        if self.base_url.trim().is_empty() {
148            return Err(format!(
149                "custom_providers[{}]: `base_url` must not be empty",
150                self.name
151            ));
152        }
153
154        if let Some(auth) = &self.auth {
155            auth.validate(&self.name)?;
156            if !self.api_key_env.trim().is_empty() {
157                return Err(format!(
158                    "custom_providers[{}]: `auth` cannot be combined with `api_key_env`",
159                    self.name
160                ));
161            }
162        }
163
164        let reserved = [
165            "openai",
166            "anthropic",
167            "gemini",
168            "copilot",
169            "deepseek",
170            "openrouter",
171            "ollama",
172            "lmstudio",
173            "moonshot",
174            "zai",
175            "minimax",
176            "huggingface",
177            "openresponses",
178        ];
179        let lower = self.name.to_lowercase();
180        if reserved.contains(&lower.as_str()) {
181            return Err(format!(
182                "custom_providers[{}]: name collides with built-in provider",
183                self.name
184            ));
185        }
186
187        Ok(())
188    }
189}
190
/// Returns `true` when `name` is a well-formed custom provider key.
///
/// A valid key is non-empty and consists only of ASCII lowercase letters,
/// digits, `-`, and `_`. Its first and last characters must be a letter or
/// a digit.
fn is_valid_provider_name(name: &str) -> bool {
    fn is_lower_alnum(byte: u8) -> bool {
        byte.is_ascii_lowercase() || byte.is_ascii_digit()
    }

    fn is_allowed_inner(byte: u8) -> bool {
        is_lower_alnum(byte) || byte == b'-' || byte == b'_'
    }

    if name.is_empty() {
        return false;
    }

    let last_index = name.len() - 1;
    for (index, byte) in name.bytes().enumerate() {
        let on_edge = index == 0 || index == last_index;
        let acceptable = if on_edge {
            is_lower_alnum(byte)
        } else {
            is_allowed_inner(byte)
        };
        if !acceptable {
            return false;
        }
    }
    true
}
205
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use super::{
        CustomProviderCommandAuthConfig, CustomProviderConfig, default_auth_refresh_interval_ms,
        default_auth_timeout_ms,
    };

    /// A valid "mycorp" provider with no command auth; tests override the
    /// fields they care about via struct-update syntax.
    fn mycorp_config() -> CustomProviderConfig {
        CustomProviderConfig {
            name: "mycorp".to_string(),
            display_name: "MyCorp".to_string(),
            base_url: "https://llm.example/v1".to_string(),
            api_key_env: String::new(),
            auth: None,
            model: "gpt-5-mini".to_string(),
        }
    }

    #[test]
    fn validate_accepts_lowercase_provider_name() {
        let config = mycorp_config();

        assert!(config.validate().is_ok());
        assert_eq!(config.resolved_api_key_env(), "MYCORP_API_KEY");
    }

    #[test]
    fn validate_rejects_invalid_provider_name() {
        let config = CustomProviderConfig {
            name: "My Corp".to_string(),
            display_name: "My Corp".to_string(),
            ..mycorp_config()
        };

        let err = config.validate().expect_err("invalid name should fail");
        assert!(err.contains("must use lowercase letters, digits, hyphens, or underscores"));
    }

    #[test]
    fn validate_rejects_auth_and_api_key_env_together() {
        let auth = CustomProviderCommandAuthConfig {
            command: "print-token".to_string(),
            args: Vec::new(),
            cwd: None,
            timeout_ms: default_auth_timeout_ms(),
            refresh_interval_ms: default_auth_refresh_interval_ms(),
        };
        let config = CustomProviderConfig {
            api_key_env: "MYCORP_API_KEY".to_string(),
            auth: Some(auth),
            ..mycorp_config()
        };

        let err = config.validate().expect_err("conflicting auth should fail");
        assert!(err.contains("`auth` cannot be combined with `api_key_env`"));
    }

    #[test]
    fn validate_accepts_command_auth_without_static_env_key() {
        let auth = CustomProviderCommandAuthConfig {
            command: "print-token".to_string(),
            args: vec!["--json".to_string()],
            cwd: Some(PathBuf::from("/tmp")),
            timeout_ms: 1_000,
            refresh_interval_ms: 60_000,
        };
        let config = CustomProviderConfig {
            auth: Some(auth),
            ..mycorp_config()
        };

        assert!(config.validate().is_ok());
        assert!(config.uses_command_auth());
    }
}