1use std::path::PathBuf;
2
3use serde::{Deserialize, Serialize};
4
/// Serde default for `CustomProviderCommandAuthConfig::timeout_ms`: 5 000, i.e. five seconds
/// expressed in the milliseconds the `_ms` field name indicates.
fn default_auth_timeout_ms() -> u64 {
    const MILLIS_PER_SECOND: u64 = 1_000;
    5 * MILLIS_PER_SECOND
}
8
/// Serde default for `CustomProviderCommandAuthConfig::refresh_interval_ms`: 300 000, i.e.
/// five minutes expressed in the milliseconds the `_ms` field name indicates.
fn default_auth_refresh_interval_ms() -> u64 {
    const MILLIS_PER_MINUTE: u64 = 60 * 1_000;
    5 * MILLIS_PER_MINUTE
}
12
/// Settings for obtaining a custom provider's credential by running an external command.
///
/// Only `command` is required when deserializing; every other field falls back to its
/// serde default. `validate` rejects a blank `command` and zero timing values.
///
/// NOTE(review): the code that spawns the command and consumes its output is not in this
/// file; the test fixtures call it `print-token`, suggesting it prints a token — confirm.
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct CustomProviderCommandAuthConfig {
    /// Program to execute. Must not be empty or whitespace-only.
    pub command: String,

    /// Arguments passed to `command`; an empty list when omitted.
    #[serde(default)]
    pub args: Vec<String>,

    /// Working directory for the command; `None` when omitted.
    /// NOTE(review): presumably `None` means "inherit the caller's directory" — confirm.
    #[serde(default)]
    pub cwd: Option<PathBuf>,

    /// Timeout for the command, in milliseconds (per the `_ms` suffix).
    /// Defaults to `default_auth_timeout_ms()` (5 000) and must be greater than 0.
    #[serde(default = "default_auth_timeout_ms")]
    pub timeout_ms: u64,

    /// Refresh interval, in milliseconds (per the `_ms` suffix); presumably how often the
    /// command is re-run to renew the credential. Defaults to
    /// `default_auth_refresh_interval_ms()` (300 000) and must be greater than 0.
    #[serde(default = "default_auth_refresh_interval_ms")]
    pub refresh_interval_ms: u64,
}
36
37impl CustomProviderCommandAuthConfig {
38 fn validate(&self, provider_name: &str) -> Result<(), String> {
39 if self.command.trim().is_empty() {
40 return Err(format!(
41 "custom_providers[{provider_name}]: `auth.command` must not be empty"
42 ));
43 }
44
45 if self.timeout_ms == 0 {
46 return Err(format!(
47 "custom_providers[{provider_name}]: `auth.timeout_ms` must be greater than 0"
48 ));
49 }
50
51 if self.refresh_interval_ms == 0 {
52 return Err(format!(
53 "custom_providers[{provider_name}]: `auth.refresh_interval_ms` must be greater than 0"
54 ));
55 }
56
57 Ok(())
58 }
59}
60
/// A user-defined model provider declared under `custom_providers`.
///
/// `name`, `display_name`, and `base_url` are required when deserializing; the remaining
/// fields fall back to their serde defaults. `validate` enforces the per-field rules
/// described below.
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CustomProviderConfig {
    /// Provider identifier. Must consist of lowercase ASCII letters, digits, `-` or `_`,
    /// start and end with a letter or digit, and not match a built-in provider name.
    /// `resolved_api_key_env` derives an environment-variable name from it.
    pub name: String,

    /// Name shown for the provider (presumably user-facing); must not be blank.
    pub display_name: String,

    /// Base URL of the provider endpoint; must not be blank. Its format is not checked here.
    pub base_url: String,

    /// Environment variable holding the API key. Empty (the default) means the name is
    /// derived from `name` by `resolved_api_key_env`. Cannot be combined with `auth`.
    #[serde(default)]
    pub api_key_env: String,

    /// Optional command-based credential source, mutually exclusive with `api_key_env`.
    /// Omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub auth: Option<CustomProviderCommandAuthConfig>,

    /// Model identifier; empty when omitted.
    /// NOTE(review): presumably the default model requested from this provider — confirm
    /// with the consumer, which is not in this file.
    #[serde(default)]
    pub model: String,
}
94
95impl CustomProviderConfig {
96 pub fn resolved_api_key_env(&self) -> String {
101 if !self.api_key_env.trim().is_empty() {
102 return self.api_key_env.clone();
103 }
104
105 let mut key = String::new();
106 for ch in self.name.chars() {
107 if ch.is_ascii_alphanumeric() {
108 key.push(ch.to_ascii_uppercase());
109 } else if !key.ends_with('_') {
110 key.push('_');
111 }
112 }
113 if !key.ends_with("_API_KEY") {
114 if !key.ends_with('_') {
115 key.push('_');
116 }
117 key.push_str("API_KEY");
118 }
119 key
120 }
121
122 pub fn uses_command_auth(&self) -> bool {
123 self.auth.is_some()
124 }
125
126 pub fn validate(&self) -> Result<(), String> {
129 if self.name.trim().is_empty() {
130 return Err("custom_providers: `name` must not be empty".to_string());
131 }
132
133 if !is_valid_provider_name(&self.name) {
134 return Err(format!(
135 "custom_providers[{}]: `name` must use lowercase letters, digits, hyphens, or underscores",
136 self.name
137 ));
138 }
139
140 if self.display_name.trim().is_empty() {
141 return Err(format!(
142 "custom_providers[{}]: `display_name` must not be empty",
143 self.name
144 ));
145 }
146
147 if self.base_url.trim().is_empty() {
148 return Err(format!(
149 "custom_providers[{}]: `base_url` must not be empty",
150 self.name
151 ));
152 }
153
154 if let Some(auth) = &self.auth {
155 auth.validate(&self.name)?;
156 if !self.api_key_env.trim().is_empty() {
157 return Err(format!(
158 "custom_providers[{}]: `auth` cannot be combined with `api_key_env`",
159 self.name
160 ));
161 }
162 }
163
164 let reserved = [
165 "openai",
166 "anthropic",
167 "gemini",
168 "copilot",
169 "deepseek",
170 "openrouter",
171 "ollama",
172 "lmstudio",
173 "moonshot",
174 "zai",
175 "minimax",
176 "huggingface",
177 "openresponses",
178 ];
179 let lower = self.name.to_lowercase();
180 if reserved.contains(&lower.as_str()) {
181 return Err(format!(
182 "custom_providers[{}]: name collides with built-in provider",
183 self.name
184 ));
185 }
186
187 Ok(())
188 }
189}
190
/// Reports whether `name` is an acceptable custom-provider identifier.
///
/// The name must be non-empty and made only of lowercase ASCII letters, ASCII digits,
/// `-` and `_`. Its first and last characters must be a lowercase letter or a digit.
fn is_valid_provider_name(name: &str) -> bool {
    fn is_lower_alnum(byte: u8) -> bool {
        byte.is_ascii_lowercase() || byte.is_ascii_digit()
    }

    let bytes = name.as_bytes();
    match (bytes.first(), bytes.last()) {
        (Some(&head), Some(&tail)) => {
            is_lower_alnum(head)
                && is_lower_alnum(tail)
                && bytes
                    .iter()
                    .all(|&byte| is_lower_alnum(byte) || byte == b'-' || byte == b'_')
        }
        // Only reachable for the empty string.
        _ => false,
    }
}
205
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use super::{
        CustomProviderCommandAuthConfig, CustomProviderConfig, default_auth_refresh_interval_ms,
        default_auth_timeout_ms,
    };

    /// A valid provider config using the environment-variable key path; individual tests
    /// override fields with struct-update syntax.
    fn base_config() -> CustomProviderConfig {
        CustomProviderConfig {
            name: "mycorp".to_string(),
            display_name: "MyCorp".to_string(),
            base_url: "https://llm.example/v1".to_string(),
            api_key_env: String::new(),
            auth: None,
            model: "gpt-5-mini".to_string(),
        }
    }

    /// A valid command-auth block using the serde default timings.
    fn base_auth() -> CustomProviderCommandAuthConfig {
        CustomProviderCommandAuthConfig {
            command: "print-token".to_string(),
            args: Vec::new(),
            cwd: None,
            timeout_ms: default_auth_timeout_ms(),
            refresh_interval_ms: default_auth_refresh_interval_ms(),
        }
    }

    #[test]
    fn validate_accepts_lowercase_provider_name() {
        let config = base_config();

        assert!(config.validate().is_ok());
        assert_eq!(config.resolved_api_key_env(), "MYCORP_API_KEY");
    }

    #[test]
    fn resolved_api_key_env_collapses_separators_and_avoids_double_suffix() {
        let separated = CustomProviderConfig {
            name: "my-corp_llm".to_string(),
            ..base_config()
        };
        assert_eq!(separated.resolved_api_key_env(), "MY_CORP_LLM_API_KEY");

        let already_suffixed = CustomProviderConfig {
            name: "corp-api-key".to_string(),
            ..base_config()
        };
        assert_eq!(already_suffixed.resolved_api_key_env(), "CORP_API_KEY");
    }

    #[test]
    fn resolved_api_key_env_prefers_explicit_value() {
        let config = CustomProviderConfig {
            api_key_env: "CORP_TOKEN".to_string(),
            ..base_config()
        };

        assert_eq!(config.resolved_api_key_env(), "CORP_TOKEN");
    }

    #[test]
    fn validate_rejects_invalid_provider_name() {
        let config = CustomProviderConfig {
            name: "My Corp".to_string(),
            display_name: "My Corp".to_string(),
            ..base_config()
        };

        let err = config.validate().expect_err("invalid name should fail");
        assert!(err.contains("must use lowercase letters, digits, hyphens, or underscores"));
    }

    #[test]
    fn validate_rejects_names_with_separator_at_edges() {
        for name in ["-mycorp", "mycorp-", "_mycorp", "mycorp_"] {
            let config = CustomProviderConfig {
                name: name.to_string(),
                ..base_config()
            };
            assert!(config.validate().is_err(), "`{name}` should be rejected");
        }
    }

    #[test]
    fn validate_rejects_blank_name() {
        let config = CustomProviderConfig {
            name: "   ".to_string(),
            ..base_config()
        };

        let err = config.validate().expect_err("blank name should fail");
        assert!(err.contains("`name` must not be empty"));
    }

    #[test]
    fn validate_rejects_blank_display_name_and_base_url() {
        let no_display = CustomProviderConfig {
            display_name: " ".to_string(),
            ..base_config()
        };
        let err = no_display.validate().expect_err("blank display_name should fail");
        assert!(err.contains("`display_name` must not be empty"));

        let no_url = CustomProviderConfig {
            base_url: String::new(),
            ..base_config()
        };
        let err = no_url.validate().expect_err("blank base_url should fail");
        assert!(err.contains("`base_url` must not be empty"));
    }

    #[test]
    fn validate_rejects_builtin_provider_name() {
        let config = CustomProviderConfig {
            name: "openai".to_string(),
            ..base_config()
        };

        let err = config.validate().expect_err("reserved name should fail");
        assert!(err.contains("name collides with built-in provider"));
    }

    #[test]
    fn validate_rejects_auth_and_api_key_env_together() {
        let config = CustomProviderConfig {
            api_key_env: "MYCORP_API_KEY".to_string(),
            auth: Some(base_auth()),
            ..base_config()
        };

        let err = config.validate().expect_err("conflicting auth should fail");
        assert!(err.contains("`auth` cannot be combined with `api_key_env`"));
    }

    #[test]
    fn validate_rejects_invalid_command_auth() {
        let cases = [
            (
                CustomProviderCommandAuthConfig {
                    command: "  ".to_string(),
                    ..base_auth()
                },
                "`auth.command` must not be empty",
            ),
            (
                CustomProviderCommandAuthConfig {
                    timeout_ms: 0,
                    ..base_auth()
                },
                "`auth.timeout_ms` must be greater than 0",
            ),
            (
                CustomProviderCommandAuthConfig {
                    refresh_interval_ms: 0,
                    ..base_auth()
                },
                "`auth.refresh_interval_ms` must be greater than 0",
            ),
        ];

        for (auth, expected) in cases {
            let config = CustomProviderConfig {
                auth: Some(auth),
                ..base_config()
            };
            let err = config.validate().expect_err("invalid auth should fail");
            assert!(err.contains(expected), "unexpected error: {err}");
        }
    }

    #[test]
    fn validate_accepts_command_auth_without_static_env_key() {
        let config = CustomProviderConfig {
            auth: Some(CustomProviderCommandAuthConfig {
                command: "print-token".to_string(),
                args: vec!["--json".to_string()],
                cwd: Some(PathBuf::from("/tmp")),
                timeout_ms: 1_000,
                refresh_interval_ms: 60_000,
            }),
            ..base_config()
        };

        assert!(config.validate().is_ok());
        assert!(config.uses_command_auth());
    }
}