1use std::path::PathBuf;
2
3use serde::{Deserialize, Serialize};
4
/// Serde default for [`CustomProviderCommandAuthConfig::timeout_ms`]:
/// five seconds, expressed in milliseconds.
fn default_auth_timeout_ms() -> u64 {
    const SECONDS: u64 = 5;
    SECONDS * 1_000
}
8
/// Serde default for [`CustomProviderCommandAuthConfig::refresh_interval_ms`]:
/// five minutes, expressed in milliseconds.
fn default_auth_refresh_interval_ms() -> u64 {
    const MINUTES: u64 = 5;
    MINUTES * 60 * 1_000
}
12
/// Command-based authentication settings for a custom provider.
///
/// Checked by `CustomProviderCommandAuthConfig::validate`, which is called from
/// [`CustomProviderConfig::validate`]. That check requires a non-blank `command`
/// and non-zero `timeout_ms` / `refresh_interval_ms`. How the command is
/// actually executed is defined elsewhere, not in this file.
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct CustomProviderCommandAuthConfig {
    /// Command to run for authentication; must not be blank.
    pub command: String,

    /// Arguments for `command`. Defaults to an empty list when omitted.
    /// NOTE(review): presumably passed to the command verbatim — confirm against the executor.
    #[serde(default)]
    pub args: Vec<String>,

    /// Optional working directory for `command`. Defaults to `None` when omitted.
    /// NOTE(review): semantics of `None` (inherit the current directory?) live in the executor.
    #[serde(default)]
    pub cwd: Option<PathBuf>,

    /// Timeout in milliseconds (per the `_ms` suffix).
    ///
    /// Defaults to 5 000 via `default_auth_timeout_ms` and must be greater
    /// than 0. Presumably bounds a single command run — TODO confirm.
    #[serde(default = "default_auth_timeout_ms")]
    pub timeout_ms: u64,

    /// Refresh interval in milliseconds (per the `_ms` suffix).
    ///
    /// Defaults to 300 000 (five minutes) via `default_auth_refresh_interval_ms`
    /// and must be greater than 0. Presumably the period after which the command
    /// is re-run to refresh credentials — TODO confirm.
    #[serde(default = "default_auth_refresh_interval_ms")]
    pub refresh_interval_ms: u64,
}
36
37impl CustomProviderCommandAuthConfig {
38 fn validate(&self, provider_name: &str) -> Result<(), String> {
39 if self.command.trim().is_empty() {
40 return Err(format!(
41 "custom_providers[{provider_name}]: `auth.command` must not be empty"
42 ));
43 }
44
45 if self.timeout_ms == 0 {
46 return Err(format!(
47 "custom_providers[{provider_name}]: `auth.timeout_ms` must be greater than 0"
48 ));
49 }
50
51 if self.refresh_interval_ms == 0 {
52 return Err(format!(
53 "custom_providers[{provider_name}]: `auth.refresh_interval_ms` must be greater than 0"
54 ));
55 }
56
57 Ok(())
58 }
59}
60
/// A user-defined provider entry from the `custom_providers` configuration.
///
/// Use [`CustomProviderConfig::validate`] to check an entry before relying on it.
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct CustomProviderConfig {
    /// Provider identifier.
    ///
    /// `validate` requires lowercase ASCII letters, digits, `-` or `_`, beginning
    /// and ending with a letter or digit, and rejects names of built-in providers.
    /// Also used to derive the default API-key environment variable
    /// (see [`CustomProviderConfig::resolved_api_key_env`]).
    pub name: String,

    /// Display name for the provider; `validate` requires it to be non-blank.
    pub display_name: String,

    /// Base URL for the provider (e.g. `https://llm.example/v1`).
    ///
    /// `validate` only requires it to be non-blank; its format is not checked here.
    pub base_url: String,

    /// Name of the environment variable holding the API key.
    ///
    /// Blank (the default) means "derive it from `name`". Must be blank when
    /// `auth` is set.
    #[serde(default)]
    pub api_key_env: String,

    /// Optional command-based authentication; omitted from serialized output
    /// when `None`. Cannot be combined with a non-blank `api_key_env`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub auth: Option<CustomProviderCommandAuthConfig>,

    /// Single model identifier.
    ///
    /// Only consulted by [`CustomProviderConfig::effective_models`] when
    /// `models` is empty. Defaults to an empty string.
    #[serde(default)]
    pub model: String,

    /// List of model identifiers; when non-empty it takes precedence over `model`.
    ///
    /// `validate` rejects blank entries. Defaults to an empty list.
    #[serde(default)]
    pub models: Vec<String>,
}
108
109impl CustomProviderConfig {
110 pub fn resolved_api_key_env(&self) -> String {
115 if !self.api_key_env.trim().is_empty() {
116 return self.api_key_env.clone();
117 }
118
119 let mut key = String::new();
120 for ch in self.name.chars() {
121 if ch.is_ascii_alphanumeric() {
122 key.push(ch.to_ascii_uppercase());
123 } else if !key.ends_with('_') {
124 key.push('_');
125 }
126 }
127 if !key.ends_with("_API_KEY") {
128 if !key.ends_with('_') {
129 key.push('_');
130 }
131 key.push_str("API_KEY");
132 }
133 key
134 }
135
136 pub fn uses_command_auth(&self) -> bool {
137 self.auth.is_some()
138 }
139
140 pub fn effective_models(&self) -> Vec<String> {
147 if !self.models.is_empty() {
148 return self
149 .models
150 .iter()
151 .map(|m| m.trim().to_string())
152 .filter(|m| !m.is_empty())
153 .collect();
154 }
155 let trimmed = self.model.trim();
156 if trimmed.is_empty() {
157 Vec::new()
158 } else {
159 vec![trimmed.to_string()]
160 }
161 }
162
163 pub fn validate(&self) -> Result<(), String> {
166 if self.name.trim().is_empty() {
167 return Err("custom_providers: `name` must not be empty".to_string());
168 }
169
170 if !is_valid_provider_name(&self.name) {
171 return Err(format!(
172 "custom_providers[{}]: `name` must use lowercase letters, digits, hyphens, or underscores",
173 self.name
174 ));
175 }
176
177 if self.display_name.trim().is_empty() {
178 return Err(format!(
179 "custom_providers[{}]: `display_name` must not be empty",
180 self.name
181 ));
182 }
183
184 if self.base_url.trim().is_empty() {
185 return Err(format!(
186 "custom_providers[{}]: `base_url` must not be empty",
187 self.name
188 ));
189 }
190
191 if let Some(auth) = &self.auth {
192 auth.validate(&self.name)?;
193 if !self.api_key_env.trim().is_empty() {
194 return Err(format!(
195 "custom_providers[{}]: `auth` cannot be combined with `api_key_env`",
196 self.name
197 ));
198 }
199 }
200
201 if self.models.iter().any(|m| m.trim().is_empty()) {
202 return Err(format!(
203 "custom_providers[{}]: `models` entries must not be empty",
204 self.name
205 ));
206 }
207
208 let reserved = [
209 "openai",
210 "anthropic",
211 "gemini",
212 "copilot",
213 "deepseek",
214 "openrouter",
215 "ollama",
216 "lmstudio",
217 "llamacpp",
218 "moonshot",
219 "zai",
220 "minimax",
221 "huggingface",
222 "openresponses",
223 ];
224 let lower = self.name.to_lowercase();
225 if reserved.contains(&lower.as_str()) {
226 return Err(format!(
227 "custom_providers[{}]: name collides with built-in provider",
228 self.name
229 ));
230 }
231
232 Ok(())
233 }
234}
235
/// Returns `true` if `name` is a well-formed provider identifier.
///
/// A valid name is non-empty and consists only of lowercase ASCII letters,
/// ASCII digits, `-` and `_`. Its first and last characters must be a letter
/// or a digit.
fn is_valid_provider_name(name: &str) -> bool {
    let allowed_at_edge = |b: u8| b.is_ascii_lowercase() || b.is_ascii_digit();
    let allowed_inside = |b: u8| allowed_at_edge(b) || b == b'-' || b == b'_';

    match name.as_bytes() {
        [] => false,
        [only] => allowed_at_edge(*only),
        [head, middle @ .., tail] => {
            allowed_at_edge(*head)
                && allowed_at_edge(*tail)
                && middle.iter().all(|&b| allowed_inside(b))
        }
    }
}
250
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use super::{
        CustomProviderCommandAuthConfig, CustomProviderConfig, default_auth_refresh_interval_ms,
        default_auth_timeout_ms,
    };

    /// A provider config that passes `validate`; tests override individual
    /// fields with struct-update syntax.
    fn base_config() -> CustomProviderConfig {
        CustomProviderConfig {
            name: "mycorp".to_string(),
            display_name: "MyCorp".to_string(),
            base_url: "https://llm.example/v1".to_string(),
            api_key_env: String::new(),
            auth: None,
            model: "gpt-5-mini".to_string(),
            models: Vec::new(),
        }
    }

    /// Valid command-auth settings using the serde defaults for both durations.
    fn command_auth() -> CustomProviderCommandAuthConfig {
        CustomProviderCommandAuthConfig {
            command: "print-token".to_string(),
            args: Vec::new(),
            cwd: None,
            timeout_ms: default_auth_timeout_ms(),
            refresh_interval_ms: default_auth_refresh_interval_ms(),
        }
    }

    #[test]
    fn validate_accepts_lowercase_provider_name() {
        let config = base_config();

        assert!(config.validate().is_ok());
        assert_eq!(config.resolved_api_key_env(), "MYCORP_API_KEY");
    }

    #[test]
    fn validate_rejects_invalid_provider_name() {
        let config = CustomProviderConfig {
            name: "My Corp".to_string(),
            display_name: "My Corp".to_string(),
            ..base_config()
        };

        let err = config.validate().expect_err("invalid name should fail");
        assert!(err.contains("must use lowercase letters, digits, hyphens, or underscores"));
    }

    #[test]
    fn validate_rejects_names_with_leading_or_trailing_separators() {
        for name in ["-mycorp", "mycorp-", "_mycorp", "mycorp_"] {
            let config = CustomProviderConfig {
                name: name.to_string(),
                ..base_config()
            };

            let err = config
                .validate()
                .expect_err("edge separator should fail");
            assert!(
                err.contains("must use lowercase letters, digits, hyphens, or underscores"),
                "{name}: {err}"
            );
        }
    }

    #[test]
    fn validate_rejects_reserved_provider_name() {
        let config = CustomProviderConfig {
            name: "openai".to_string(),
            ..base_config()
        };

        let err = config.validate().expect_err("reserved name should fail");
        assert!(err.contains("name collides with built-in provider"));
    }

    #[test]
    fn validate_rejects_auth_and_api_key_env_together() {
        let config = CustomProviderConfig {
            api_key_env: "MYCORP_API_KEY".to_string(),
            auth: Some(command_auth()),
            ..base_config()
        };

        let err = config.validate().expect_err("conflicting auth should fail");
        assert!(err.contains("`auth` cannot be combined with `api_key_env`"));
    }

    #[test]
    fn validate_accepts_command_auth_without_static_env_key() {
        let config = CustomProviderConfig {
            auth: Some(CustomProviderCommandAuthConfig {
                args: vec!["--json".to_string()],
                cwd: Some(PathBuf::from("/tmp")),
                timeout_ms: 1_000,
                refresh_interval_ms: 60_000,
                ..command_auth()
            }),
            ..base_config()
        };

        assert!(config.validate().is_ok());
        assert!(config.uses_command_auth());
    }

    #[test]
    fn validate_rejects_zero_auth_timeout() {
        let config = CustomProviderConfig {
            auth: Some(CustomProviderCommandAuthConfig {
                timeout_ms: 0,
                ..command_auth()
            }),
            ..base_config()
        };

        let err = config.validate().expect_err("zero timeout should fail");
        assert!(err.contains("`auth.timeout_ms` must be greater than 0"));
    }

    #[test]
    fn validate_rejects_zero_auth_refresh_interval() {
        let config = CustomProviderConfig {
            auth: Some(CustomProviderCommandAuthConfig {
                refresh_interval_ms: 0,
                ..command_auth()
            }),
            ..base_config()
        };

        let err = config
            .validate()
            .expect_err("zero refresh interval should fail");
        assert!(err.contains("`auth.refresh_interval_ms` must be greater than 0"));
    }

    #[test]
    fn validate_rejects_empty_model_entry_in_models_list() {
        let config = CustomProviderConfig {
            api_key_env: "MYCORP_API_KEY".to_string(),
            models: vec!["valid-model".to_string(), " ".to_string()],
            ..base_config()
        };

        let err = config
            .validate()
            .expect_err("blank models entry should fail");
        assert!(err.contains("`models` entries must not be empty"));
    }

    #[test]
    fn resolved_api_key_env_prefers_explicit_value() {
        let config = CustomProviderConfig {
            api_key_env: "CUSTOM_KEY".to_string(),
            ..base_config()
        };

        assert_eq!(config.resolved_api_key_env(), "CUSTOM_KEY");
    }

    #[test]
    fn resolved_api_key_env_normalizes_derived_name() {
        let hyphenated = CustomProviderConfig {
            name: "my-corp".to_string(),
            ..base_config()
        };
        let already_suffixed = CustomProviderConfig {
            name: "foo_api_key".to_string(),
            ..base_config()
        };

        assert_eq!(hyphenated.resolved_api_key_env(), "MY_CORP_API_KEY");
        assert_eq!(already_suffixed.resolved_api_key_env(), "FOO_API_KEY");
    }

    #[test]
    fn effective_models_uses_models_list_when_present() {
        let config = CustomProviderConfig {
            name: "atlascloud".to_string(),
            display_name: "Atlas Cloud".to_string(),
            base_url: "https://api.atlascloud.ai/v1".to_string(),
            api_key_env: "ATLASCLOUD_API_KEY".to_string(),
            auth: None,
            model: "deepseek-ai/deepseek-v4-flash".to_string(),
            models: vec![
                "deepseek-ai/deepseek-v4-flash".to_string(),
                "deepseek-ai/deepseek-v4-pro".to_string(),
                "deepseek-ai/DeepSeek-V3-0324".to_string(),
                "qwen/qwen3.6-35b-a3b".to_string(),
                "moonshotai/kimi-k2.6".to_string(),
            ],
        };

        assert_eq!(
            config.effective_models(),
            vec![
                "deepseek-ai/deepseek-v4-flash".to_string(),
                "deepseek-ai/deepseek-v4-pro".to_string(),
                "deepseek-ai/DeepSeek-V3-0324".to_string(),
                "qwen/qwen3.6-35b-a3b".to_string(),
                "moonshotai/kimi-k2.6".to_string(),
            ]
        );
    }

    #[test]
    fn effective_models_trims_entries_and_drops_blanks() {
        let config = CustomProviderConfig {
            models: vec![" a ".to_string(), String::new(), "b".to_string()],
            ..base_config()
        };

        assert_eq!(
            config.effective_models(),
            vec!["a".to_string(), "b".to_string()]
        );
    }

    #[test]
    fn effective_models_falls_back_to_single_model_field() {
        let config = CustomProviderConfig {
            model: "gpt-5-mini".to_string(),
            ..CustomProviderConfig::default()
        };

        assert_eq!(config.effective_models(), vec!["gpt-5-mini".to_string()]);
    }

    #[test]
    fn effective_models_is_empty_when_nothing_configured() {
        assert!(CustomProviderConfig::default().effective_models().is_empty());
    }
}