Skip to main content

vtcode_config/core/
custom_provider.rs

1use std::path::PathBuf;
2
3use serde::{Deserialize, Serialize};
4
/// Serde default for `auth.timeout_ms`: five seconds, in milliseconds.
fn default_auth_timeout_ms() -> u64 {
    const FIVE_SECONDS_MS: u64 = 5 * 1_000;
    FIVE_SECONDS_MS
}
8
/// Serde default for `auth.refresh_interval_ms`: five minutes, in milliseconds.
fn default_auth_refresh_interval_ms() -> u64 {
    const FIVE_MINUTES_MS: u64 = 5 * 60 * 1_000;
    FIVE_MINUTES_MS
}
12
13/// Command-backed bearer token configuration for a custom provider.
14#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
15#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
16pub struct CustomProviderCommandAuthConfig {
17    /// Command to execute. Bare names are resolved via `PATH`.
18    pub command: String,
19
20    /// Optional command arguments.
21    #[serde(default)]
22    pub args: Vec<String>,
23
24    /// Optional working directory for the token command.
25    #[serde(default)]
26    pub cwd: Option<PathBuf>,
27
28    /// Maximum time to wait for the command to complete successfully.
29    #[serde(default = "default_auth_timeout_ms")]
30    pub timeout_ms: u64,
31
32    /// Maximum age for the cached token before rerunning the command.
33    #[serde(default = "default_auth_refresh_interval_ms")]
34    pub refresh_interval_ms: u64,
35}
36
37impl CustomProviderCommandAuthConfig {
38    fn validate(&self, provider_name: &str) -> Result<(), String> {
39        if self.command.trim().is_empty() {
40            return Err(format!(
41                "custom_providers[{provider_name}]: `auth.command` must not be empty"
42            ));
43        }
44
45        if self.timeout_ms == 0 {
46            return Err(format!(
47                "custom_providers[{provider_name}]: `auth.timeout_ms` must be greater than 0"
48            ));
49        }
50
51        if self.refresh_interval_ms == 0 {
52            return Err(format!(
53                "custom_providers[{provider_name}]: `auth.refresh_interval_ms` must be greater than 0"
54            ));
55        }
56
57        Ok(())
58    }
59}
60
61/// Configuration for a user-defined OpenAI-compatible provider endpoint.
62///
63/// Allows users to define multiple named custom endpoints (e.g., corporate
64/// proxies) with distinct display names, so they can toggle between them
65/// and clearly see which endpoint is active.
66#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
67#[derive(Debug, Clone, Default, Deserialize, Serialize)]
68pub struct CustomProviderConfig {
69    /// Stable provider key used for routing and persistence (e.g., "mycorp").
70    /// Must be lowercase alphanumeric with optional hyphens/underscores.
71    pub name: String,
72
73    /// Human-friendly label shown in the TUI header, footer, and model picker
74    /// (e.g., "MyCorporateName").
75    pub display_name: String,
76
77    /// Base URL of the OpenAI-compatible API endpoint
78    /// (e.g., `<https://llm.corp.example/v1>`).
79    pub base_url: String,
80
81    /// Environment variable name that holds the API key for this endpoint
82    /// (e.g., "MYCORP_API_KEY").
83    #[serde(default)]
84    pub api_key_env: String,
85
86    /// Optional command-backed bearer token configuration.
87    #[serde(default, skip_serializing_if = "Option::is_none")]
88    pub auth: Option<CustomProviderCommandAuthConfig>,
89
90    /// Default model to use with this endpoint (e.g., "gpt-5-mini").
91    ///
92    /// When [`models`](Self::models) is empty, this single model is what the
93    /// `/model` picker offers for this provider. When [`models`](Self::models)
94    /// is non-empty, this field is used as the default selection but the
95    /// picker lists every entry in [`models`](Self::models).
96    #[serde(default)]
97    pub model: String,
98
99    /// Optional list of additional model identifiers offered by the provider.
100    ///
101    /// Useful for OpenAI-compatible aggregators such as Atlas Cloud that
102    /// expose many models behind a single endpoint. When set, the `/model`
103    /// picker shows one entry per model. When empty, the picker falls back to
104    /// the single [`model`](Self::model) field.
105    #[serde(default)]
106    pub models: Vec<String>,
107}
108
109impl CustomProviderConfig {
110    /// Resolve the API key environment variable used for this provider.
111    ///
112    /// Falls back to a derived `NAME_API_KEY`-style variable when the config
113    /// does not set `api_key_env`.
114    pub fn resolved_api_key_env(&self) -> String {
115        if !self.api_key_env.trim().is_empty() {
116            return self.api_key_env.clone();
117        }
118
119        let mut key = String::new();
120        for ch in self.name.chars() {
121            if ch.is_ascii_alphanumeric() {
122                key.push(ch.to_ascii_uppercase());
123            } else if !key.ends_with('_') {
124                key.push('_');
125            }
126        }
127        if !key.ends_with("_API_KEY") {
128            if !key.ends_with('_') {
129                key.push('_');
130            }
131            key.push_str("API_KEY");
132        }
133        key
134    }
135
136    pub fn uses_command_auth(&self) -> bool {
137        self.auth.is_some()
138    }
139
140    /// Return the list of models the `/model` picker should offer for this
141    /// provider.
142    ///
143    /// If `models` is non-empty, every entry is returned (trimmed). Otherwise
144    /// the single `model` field is returned as a one-element list. An empty
145    /// `model` field with no `models` list yields an empty result.
146    pub fn effective_models(&self) -> Vec<String> {
147        if !self.models.is_empty() {
148            return self
149                .models
150                .iter()
151                .map(|m| m.trim().to_string())
152                .filter(|m| !m.is_empty())
153                .collect();
154        }
155        let trimmed = self.model.trim();
156        if trimmed.is_empty() {
157            Vec::new()
158        } else {
159            vec![trimmed.to_string()]
160        }
161    }
162
163    /// Validate that required fields are present and the name doesn't collide
164    /// with built-in provider keys.
165    pub fn validate(&self) -> Result<(), String> {
166        if self.name.trim().is_empty() {
167            return Err("custom_providers: `name` must not be empty".to_string());
168        }
169
170        if !is_valid_provider_name(&self.name) {
171            return Err(format!(
172                "custom_providers[{}]: `name` must use lowercase letters, digits, hyphens, or underscores",
173                self.name
174            ));
175        }
176
177        if self.display_name.trim().is_empty() {
178            return Err(format!(
179                "custom_providers[{}]: `display_name` must not be empty",
180                self.name
181            ));
182        }
183
184        if self.base_url.trim().is_empty() {
185            return Err(format!(
186                "custom_providers[{}]: `base_url` must not be empty",
187                self.name
188            ));
189        }
190
191        if let Some(auth) = &self.auth {
192            auth.validate(&self.name)?;
193            if !self.api_key_env.trim().is_empty() {
194                return Err(format!(
195                    "custom_providers[{}]: `auth` cannot be combined with `api_key_env`",
196                    self.name
197                ));
198            }
199        }
200
201        if self.models.iter().any(|m| m.trim().is_empty()) {
202            return Err(format!(
203                "custom_providers[{}]: `models` entries must not be empty",
204                self.name
205            ));
206        }
207
208        let reserved = [
209            "openai",
210            "anthropic",
211            "gemini",
212            "copilot",
213            "deepseek",
214            "openrouter",
215            "ollama",
216            "lmstudio",
217            "moonshot",
218            "zai",
219            "minimax",
220            "huggingface",
221            "openresponses",
222        ];
223        let lower = self.name.to_lowercase();
224        if reserved.contains(&lower.as_str()) {
225            return Err(format!(
226                "custom_providers[{}]: name collides with built-in provider",
227                self.name
228            ));
229        }
230
231        Ok(())
232    }
233}
234
/// Returns `true` when `name` is a well-formed custom provider key.
///
/// A valid key is non-empty and made only of lowercase ASCII letters, ASCII
/// digits, `-`, and `_`. Its first and last characters must be a letter or
/// digit, so separators can only appear in the middle.
fn is_valid_provider_name(name: &str) -> bool {
    fn is_lower_alnum(byte: u8) -> bool {
        byte.is_ascii_lowercase() || byte.is_ascii_digit()
    }

    let bytes = name.as_bytes();
    let (Some(&head), Some(&tail)) = (bytes.first(), bytes.last()) else {
        return false;
    };

    is_lower_alnum(head)
        && is_lower_alnum(tail)
        && bytes
            .iter()
            .all(|&byte| is_lower_alnum(byte) || byte == b'-' || byte == b'_')
}
249
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use super::{
        CustomProviderCommandAuthConfig, CustomProviderConfig, default_auth_refresh_interval_ms,
        default_auth_timeout_ms,
    };

    /// Baseline config that passes validation; tests override single fields.
    fn mycorp_config() -> CustomProviderConfig {
        CustomProviderConfig {
            name: "mycorp".to_string(),
            display_name: "MyCorp".to_string(),
            base_url: "https://llm.example/v1".to_string(),
            api_key_env: String::new(),
            auth: None,
            model: "gpt-5-mini".to_string(),
            models: Vec::new(),
        }
    }

    /// Valid command auth using the serde default durations.
    fn default_command_auth() -> CustomProviderCommandAuthConfig {
        CustomProviderCommandAuthConfig {
            command: "print-token".to_string(),
            args: Vec::new(),
            cwd: None,
            timeout_ms: default_auth_timeout_ms(),
            refresh_interval_ms: default_auth_refresh_interval_ms(),
        }
    }

    #[test]
    fn validate_accepts_lowercase_provider_name() {
        let config = mycorp_config();

        assert!(config.validate().is_ok());
        assert_eq!(config.resolved_api_key_env(), "MYCORP_API_KEY");
    }

    #[test]
    fn validate_rejects_invalid_provider_name() {
        let config = CustomProviderConfig {
            name: "My Corp".to_string(),
            display_name: "My Corp".to_string(),
            ..mycorp_config()
        };

        let err = config.validate().expect_err("invalid name should fail");
        assert!(err.contains("must use lowercase letters, digits, hyphens, or underscores"));
    }

    #[test]
    fn validate_rejects_names_with_separator_at_either_end() {
        for name in ["-mycorp", "mycorp_", "_"] {
            let config = CustomProviderConfig {
                name: name.to_string(),
                ..mycorp_config()
            };

            let err = config
                .validate()
                .expect_err("separator at the edge should fail");
            assert!(err.contains("must use lowercase letters"), "{name}: {err}");
        }
    }

    #[test]
    fn validate_rejects_builtin_provider_names() {
        let config = CustomProviderConfig {
            name: "openai".to_string(),
            ..mycorp_config()
        };

        let err = config.validate().expect_err("reserved name should fail");
        assert!(err.contains("name collides with built-in provider"));
    }

    #[test]
    fn validate_rejects_blank_display_name_and_base_url() {
        let no_display_name = CustomProviderConfig {
            display_name: "   ".to_string(),
            ..mycorp_config()
        };
        let err = no_display_name
            .validate()
            .expect_err("blank display_name should fail");
        assert!(err.contains("`display_name` must not be empty"));

        let no_base_url = CustomProviderConfig {
            base_url: String::new(),
            ..mycorp_config()
        };
        let err = no_base_url
            .validate()
            .expect_err("blank base_url should fail");
        assert!(err.contains("`base_url` must not be empty"));
    }

    #[test]
    fn validate_rejects_auth_and_api_key_env_together() {
        let config = CustomProviderConfig {
            api_key_env: "MYCORP_API_KEY".to_string(),
            auth: Some(default_command_auth()),
            ..mycorp_config()
        };

        let err = config.validate().expect_err("conflicting auth should fail");
        assert!(err.contains("`auth` cannot be combined with `api_key_env`"));
    }

    #[test]
    fn validate_rejects_invalid_command_auth_settings() {
        let cases = [
            (
                CustomProviderCommandAuthConfig {
                    command: "   ".to_string(),
                    ..default_command_auth()
                },
                "`auth.command` must not be empty",
            ),
            (
                CustomProviderCommandAuthConfig {
                    timeout_ms: 0,
                    ..default_command_auth()
                },
                "`auth.timeout_ms` must be greater than 0",
            ),
            (
                CustomProviderCommandAuthConfig {
                    refresh_interval_ms: 0,
                    ..default_command_auth()
                },
                "`auth.refresh_interval_ms` must be greater than 0",
            ),
        ];

        for (auth, expected) in cases {
            let config = CustomProviderConfig {
                auth: Some(auth),
                ..mycorp_config()
            };

            let err = config.validate().expect_err("invalid auth should fail");
            assert!(err.contains(expected), "{err}");
        }
    }

    #[test]
    fn validate_accepts_command_auth_without_static_env_key() {
        let config = CustomProviderConfig {
            auth: Some(CustomProviderCommandAuthConfig {
                command: "print-token".to_string(),
                args: vec!["--json".to_string()],
                cwd: Some(PathBuf::from("/tmp")),
                timeout_ms: 1_000,
                refresh_interval_ms: 60_000,
            }),
            ..mycorp_config()
        };

        assert!(config.validate().is_ok());
        assert!(config.uses_command_auth());
    }

    #[test]
    fn validate_rejects_empty_model_entry_in_models_list() {
        let config = CustomProviderConfig {
            api_key_env: "MYCORP_API_KEY".to_string(),
            models: vec!["valid-model".to_string(), "   ".to_string()],
            ..mycorp_config()
        };

        let err = config
            .validate()
            .expect_err("blank models entry should fail");
        assert!(err.contains("`models` entries must not be empty"));
    }

    #[test]
    fn effective_models_uses_models_list_when_present() {
        let listed = [
            "deepseek-ai/deepseek-v4-flash",
            "deepseek-ai/deepseek-v4-pro",
            "deepseek-ai/DeepSeek-V3-0324",
            "qwen/qwen3.6-35b-a3b",
            "moonshotai/kimi-k2.6",
        ];
        let config = CustomProviderConfig {
            name: "atlascloud".to_string(),
            display_name: "Atlas Cloud".to_string(),
            base_url: "https://api.atlascloud.ai/v1".to_string(),
            api_key_env: "ATLASCLOUD_API_KEY".to_string(),
            auth: None,
            model: "deepseek-ai/deepseek-v4-flash".to_string(),
            models: listed.iter().map(|m| m.to_string()).collect(),
        };

        let expected: Vec<String> = listed.iter().map(|m| m.to_string()).collect();
        assert_eq!(config.effective_models(), expected);
    }

    #[test]
    fn effective_models_trims_list_entries() {
        let config = CustomProviderConfig {
            models: vec!["  alpha ".to_string(), "beta".to_string()],
            ..mycorp_config()
        };

        assert_eq!(
            config.effective_models(),
            vec!["alpha".to_string(), "beta".to_string()]
        );
    }

    #[test]
    fn effective_models_falls_back_to_single_model_field() {
        let config = CustomProviderConfig {
            model: "gpt-5-mini".to_string(),
            ..CustomProviderConfig::default()
        };

        assert_eq!(config.effective_models(), vec!["gpt-5-mini".to_string()]);
    }

    #[test]
    fn effective_models_is_empty_without_any_model() {
        let config = CustomProviderConfig::default();

        assert!(config.effective_models().is_empty());
    }
}