1use std::path::PathBuf;
2
3use serde::{Deserialize, Serialize};
4
/// Serde default for `CustomProviderCommandAuthConfig::timeout_ms`:
/// five seconds, expressed in milliseconds.
fn default_auth_timeout_ms() -> u64 {
    const FIVE_SECONDS_MS: u64 = 5 * 1_000;
    FIVE_SECONDS_MS
}
8
/// Serde default for `CustomProviderCommandAuthConfig::refresh_interval_ms`:
/// five minutes, expressed in milliseconds.
fn default_auth_refresh_interval_ms() -> u64 {
    let minutes: u64 = 5;
    minutes * 60 * 1_000
}
12
13#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
15#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
16pub struct CustomProviderCommandAuthConfig {
17 pub command: String,
19
20 #[serde(default)]
22 pub args: Vec<String>,
23
24 #[serde(default)]
26 pub cwd: Option<PathBuf>,
27
28 #[serde(default = "default_auth_timeout_ms")]
30 pub timeout_ms: u64,
31
32 #[serde(default = "default_auth_refresh_interval_ms")]
34 pub refresh_interval_ms: u64,
35}
36
37impl CustomProviderCommandAuthConfig {
38 fn validate(&self, provider_name: &str) -> Result<(), String> {
39 if self.command.trim().is_empty() {
40 return Err(format!(
41 "custom_providers[{provider_name}]: `auth.command` must not be empty"
42 ));
43 }
44
45 if self.timeout_ms == 0 {
46 return Err(format!(
47 "custom_providers[{provider_name}]: `auth.timeout_ms` must be greater than 0"
48 ));
49 }
50
51 if self.refresh_interval_ms == 0 {
52 return Err(format!(
53 "custom_providers[{provider_name}]: `auth.refresh_interval_ms` must be greater than 0"
54 ));
55 }
56
57 Ok(())
58 }
59}
60
61#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
67#[derive(Debug, Clone, Default, Deserialize, Serialize)]
68pub struct CustomProviderConfig {
69 pub name: String,
72
73 pub display_name: String,
76
77 pub base_url: String,
80
81 #[serde(default)]
84 pub api_key_env: String,
85
86 #[serde(default, skip_serializing_if = "Option::is_none")]
88 pub auth: Option<CustomProviderCommandAuthConfig>,
89
90 #[serde(default)]
97 pub model: String,
98
99 #[serde(default)]
106 pub models: Vec<String>,
107}
108
109impl CustomProviderConfig {
110 pub fn resolved_api_key_env(&self) -> String {
115 if !self.api_key_env.trim().is_empty() {
116 return self.api_key_env.clone();
117 }
118
119 let mut key = String::new();
120 for ch in self.name.chars() {
121 if ch.is_ascii_alphanumeric() {
122 key.push(ch.to_ascii_uppercase());
123 } else if !key.ends_with('_') {
124 key.push('_');
125 }
126 }
127 if !key.ends_with("_API_KEY") {
128 if !key.ends_with('_') {
129 key.push('_');
130 }
131 key.push_str("API_KEY");
132 }
133 key
134 }
135
136 pub fn uses_command_auth(&self) -> bool {
137 self.auth.is_some()
138 }
139
140 pub fn effective_models(&self) -> Vec<String> {
147 if !self.models.is_empty() {
148 return self
149 .models
150 .iter()
151 .map(|m| m.trim().to_string())
152 .filter(|m| !m.is_empty())
153 .collect();
154 }
155 let trimmed = self.model.trim();
156 if trimmed.is_empty() {
157 Vec::new()
158 } else {
159 vec![trimmed.to_string()]
160 }
161 }
162
163 pub fn validate(&self) -> Result<(), String> {
166 if self.name.trim().is_empty() {
167 return Err("custom_providers: `name` must not be empty".to_string());
168 }
169
170 if !is_valid_provider_name(&self.name) {
171 return Err(format!(
172 "custom_providers[{}]: `name` must use lowercase letters, digits, hyphens, or underscores",
173 self.name
174 ));
175 }
176
177 if self.display_name.trim().is_empty() {
178 return Err(format!(
179 "custom_providers[{}]: `display_name` must not be empty",
180 self.name
181 ));
182 }
183
184 if self.base_url.trim().is_empty() {
185 return Err(format!(
186 "custom_providers[{}]: `base_url` must not be empty",
187 self.name
188 ));
189 }
190
191 if let Some(auth) = &self.auth {
192 auth.validate(&self.name)?;
193 if !self.api_key_env.trim().is_empty() {
194 return Err(format!(
195 "custom_providers[{}]: `auth` cannot be combined with `api_key_env`",
196 self.name
197 ));
198 }
199 }
200
201 if self.models.iter().any(|m| m.trim().is_empty()) {
202 return Err(format!(
203 "custom_providers[{}]: `models` entries must not be empty",
204 self.name
205 ));
206 }
207
208 let reserved = [
209 "openai",
210 "anthropic",
211 "gemini",
212 "copilot",
213 "deepseek",
214 "openrouter",
215 "ollama",
216 "lmstudio",
217 "moonshot",
218 "zai",
219 "minimax",
220 "huggingface",
221 "openresponses",
222 ];
223 let lower = self.name.to_lowercase();
224 if reserved.contains(&lower.as_str()) {
225 return Err(format!(
226 "custom_providers[{}]: name collides with built-in provider",
227 self.name
228 ));
229 }
230
231 Ok(())
232 }
233}
234
/// Returns `true` when `name` is a non-empty run of lowercase ASCII letters,
/// digits, `-` and `_` that starts and ends with a letter or digit.
fn is_valid_provider_name(name: &str) -> bool {
    fn is_lower_alnum(b: u8) -> bool {
        b.is_ascii_lowercase() || b.is_ascii_digit()
    }

    let bytes = name.as_bytes();
    match (bytes.first(), bytes.last()) {
        (Some(&head), Some(&tail)) => {
            is_lower_alnum(head)
                && is_lower_alnum(tail)
                && bytes
                    .iter()
                    .all(|&b| is_lower_alnum(b) || b == b'-' || b == b'_')
        }
        // Only reachable for the empty string.
        _ => false,
    }
}
249
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use super::{
        CustomProviderCommandAuthConfig, CustomProviderConfig, default_auth_refresh_interval_ms,
        default_auth_timeout_ms,
    };

    /// Minimal provider config that passes `validate()`; individual tests
    /// override fields with struct-update syntax.
    fn base_config() -> CustomProviderConfig {
        CustomProviderConfig {
            name: "mycorp".to_string(),
            display_name: "MyCorp".to_string(),
            base_url: "https://llm.example/v1".to_string(),
            model: "gpt-5-mini".to_string(),
            ..CustomProviderConfig::default()
        }
    }

    /// Command-auth block using the serde default timings.
    fn base_auth() -> CustomProviderCommandAuthConfig {
        CustomProviderCommandAuthConfig {
            command: "print-token".to_string(),
            args: Vec::new(),
            cwd: None,
            timeout_ms: default_auth_timeout_ms(),
            refresh_interval_ms: default_auth_refresh_interval_ms(),
        }
    }

    #[test]
    fn validate_accepts_lowercase_provider_name() {
        let config = base_config();

        assert!(config.validate().is_ok());
        assert_eq!(config.resolved_api_key_env(), "MYCORP_API_KEY");
    }

    #[test]
    fn validate_rejects_invalid_provider_name() {
        let config = CustomProviderConfig {
            name: "My Corp".to_string(),
            display_name: "My Corp".to_string(),
            ..base_config()
        };

        let err = config.validate().expect_err("invalid name should fail");
        assert!(err.contains("must use lowercase letters, digits, hyphens, or underscores"));
    }

    #[test]
    fn validate_rejects_names_with_leading_or_trailing_separator() {
        for name in ["-mycorp", "mycorp-", "_mycorp", "mycorp_"] {
            let config = CustomProviderConfig {
                name: name.to_string(),
                ..base_config()
            };

            let err = config
                .validate()
                .expect_err("separator at either edge should fail");
            assert!(
                err.contains("must use lowercase letters, digits, hyphens, or underscores"),
                "{name}: {err}"
            );
        }
    }

    #[test]
    fn validate_rejects_builtin_provider_name() {
        let config = CustomProviderConfig {
            name: "openai".to_string(),
            ..base_config()
        };

        let err = config.validate().expect_err("reserved name should fail");
        assert!(err.contains("name collides with built-in provider"));
    }

    #[test]
    fn validate_rejects_auth_and_api_key_env_together() {
        let config = CustomProviderConfig {
            api_key_env: "MYCORP_API_KEY".to_string(),
            auth: Some(base_auth()),
            ..base_config()
        };

        let err = config.validate().expect_err("conflicting auth should fail");
        assert!(err.contains("`auth` cannot be combined with `api_key_env`"));
    }

    #[test]
    fn validate_accepts_command_auth_without_static_env_key() {
        let config = CustomProviderConfig {
            auth: Some(CustomProviderCommandAuthConfig {
                args: vec!["--json".to_string()],
                cwd: Some(PathBuf::from("/tmp")),
                timeout_ms: 1_000,
                refresh_interval_ms: 60_000,
                ..base_auth()
            }),
            ..base_config()
        };

        assert!(config.validate().is_ok());
        assert!(config.uses_command_auth());
    }

    #[test]
    fn validate_rejects_blank_auth_command() {
        let config = CustomProviderConfig {
            auth: Some(CustomProviderCommandAuthConfig {
                command: "   ".to_string(),
                ..base_auth()
            }),
            ..base_config()
        };

        let err = config.validate().expect_err("blank auth command should fail");
        assert!(err.contains("`auth.command` must not be empty"));
    }

    #[test]
    fn validate_rejects_zero_auth_timings() {
        let zero_timeout = CustomProviderConfig {
            auth: Some(CustomProviderCommandAuthConfig {
                timeout_ms: 0,
                ..base_auth()
            }),
            ..base_config()
        };
        let err = zero_timeout
            .validate()
            .expect_err("zero timeout should fail");
        assert!(err.contains("`auth.timeout_ms` must be greater than 0"));

        let zero_refresh = CustomProviderConfig {
            auth: Some(CustomProviderCommandAuthConfig {
                refresh_interval_ms: 0,
                ..base_auth()
            }),
            ..base_config()
        };
        let err = zero_refresh
            .validate()
            .expect_err("zero refresh interval should fail");
        assert!(err.contains("`auth.refresh_interval_ms` must be greater than 0"));
    }

    #[test]
    fn validate_rejects_empty_model_entry_in_models_list() {
        let config = CustomProviderConfig {
            api_key_env: "MYCORP_API_KEY".to_string(),
            models: vec!["valid-model".to_string(), " ".to_string()],
            ..base_config()
        };

        let err = config
            .validate()
            .expect_err("blank models entry should fail");
        assert!(err.contains("`models` entries must not be empty"));
    }

    #[test]
    fn resolved_api_key_env_normalizes_provider_name() {
        let hyphenated = CustomProviderConfig {
            name: "my-corp".to_string(),
            ..base_config()
        };
        assert_eq!(hyphenated.resolved_api_key_env(), "MY_CORP_API_KEY");

        let already_suffixed = CustomProviderConfig {
            name: "mycorp_api_key".to_string(),
            ..base_config()
        };
        assert_eq!(already_suffixed.resolved_api_key_env(), "MYCORP_API_KEY");

        let explicit = CustomProviderConfig {
            api_key_env: "CUSTOM_KEY".to_string(),
            ..base_config()
        };
        assert_eq!(explicit.resolved_api_key_env(), "CUSTOM_KEY");
    }

    #[test]
    fn effective_models_uses_models_list_when_present() {
        let config = CustomProviderConfig {
            name: "atlascloud".to_string(),
            display_name: "Atlas Cloud".to_string(),
            base_url: "https://api.atlascloud.ai/v1".to_string(),
            api_key_env: "ATLASCLOUD_API_KEY".to_string(),
            auth: None,
            model: "deepseek-ai/deepseek-v4-flash".to_string(),
            models: vec![
                "deepseek-ai/deepseek-v4-flash".to_string(),
                "deepseek-ai/deepseek-v4-pro".to_string(),
                "deepseek-ai/DeepSeek-V3-0324".to_string(),
                "qwen/qwen3.6-35b-a3b".to_string(),
                "moonshotai/kimi-k2.6".to_string(),
            ],
        };

        assert_eq!(
            config.effective_models(),
            vec![
                "deepseek-ai/deepseek-v4-flash".to_string(),
                "deepseek-ai/deepseek-v4-pro".to_string(),
                "deepseek-ai/DeepSeek-V3-0324".to_string(),
                "qwen/qwen3.6-35b-a3b".to_string(),
                "moonshotai/kimi-k2.6".to_string(),
            ]
        );
    }

    #[test]
    fn effective_models_trims_list_entries() {
        let config = CustomProviderConfig {
            models: vec!["  spaced-model  ".to_string()],
            ..base_config()
        };

        assert_eq!(config.effective_models(), vec!["spaced-model".to_string()]);
    }

    #[test]
    fn effective_models_falls_back_to_single_model_field() {
        let config = CustomProviderConfig {
            model: "gpt-5-mini".to_string(),
            ..CustomProviderConfig::default()
        };

        assert_eq!(config.effective_models(), vec!["gpt-5-mini".to_string()]);
    }

    #[test]
    fn effective_models_is_empty_when_nothing_configured() {
        assert!(CustomProviderConfig::default().effective_models().is_empty());
    }
}