1use std::collections::HashMap;
10use std::net::{IpAddr, Ipv4Addr, SocketAddr};
11
12use serde::Deserialize;
13use url::Url;
14
/// IPv4 loopback address (`127.0.0.1`) used as the host part of [`DEFAULT_LISTEN`].
const DEFAULT_LISTEN_IP: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST);

/// Address handed out by [`Config::listen_address`] when the configuration has no `listen` key:
/// the loopback interface with port `0`.
///
/// NOTE(review): port `0` conventionally asks the OS for a free port at bind time — confirm that
/// the code binding this address relies on that.
const DEFAULT_LISTEN: SocketAddr = SocketAddr::new(DEFAULT_LISTEN_IP, 0);

/// Path handed out by [`Config::token_file_path`] when the configuration has no `token_file` key.
///
/// NOTE(review): the leading `~` is returned verbatim by this module; presumably the consumer
/// expands it — confirm at the call site.
const DEFAULT_TOKEN_FILE: &str = "~/.config/stoat/tokens.json";
20
21#[derive(Debug, Deserialize, PartialEq, Eq)]
23pub struct Config {
24 #[serde(default, deserialize_with = "deserialize_optional_socket_addr")]
26 listen: Option<SocketAddr>,
27
28 token_file: Option<String>,
30
31 pub upstream: Upstream,
33
34 pub oauth: OAuth,
36
37 pub translation: Option<Translation>,
39}
40
41impl Config {
42 pub fn from_toml(s: &str) -> Result<Self, toml::de::Error> {
53 toml::from_str(s)
54 }
55
56 #[must_use]
58 pub fn listen_address(&self) -> SocketAddr {
59 self.listen.unwrap_or(DEFAULT_LISTEN)
60 }
61
62 #[must_use]
64 pub fn token_file_path(&self) -> &str {
65 self.token_file.as_deref().unwrap_or(DEFAULT_TOKEN_FILE)
66 }
67}
68
69#[derive(Debug, Deserialize, PartialEq, Eq)]
71pub struct Upstream {
72 pub base_url: Url,
74}
75
76#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq)]
82#[serde(rename_all = "lowercase")]
83pub enum TokenFormat {
84 #[default]
87 Form,
88 Json,
90}
91
92#[derive(Debug, Deserialize, PartialEq, Eq)]
94pub struct OAuth {
95 pub authorize_url: Url,
97
98 pub token_url: Url,
100
101 pub client_id: String,
103
104 pub scopes: Vec<String>,
106
107 pkce: Option<bool>,
109
110 pub redirect_uri: Url,
112
113 token_format: Option<TokenFormat>,
116}
117
118impl OAuth {
119 #[must_use]
121 pub fn pkce_enabled(&self) -> bool {
122 self.pkce.unwrap_or(true)
123 }
124
125 #[must_use]
127 pub fn token_format(&self) -> TokenFormat {
128 self.token_format.unwrap_or_default()
129 }
130}
131
132#[derive(Debug, Deserialize, PartialEq, Eq)]
134pub struct Translation {
135 pub strip_headers: Option<Vec<String>>,
137
138 pub set_headers: Option<HashMap<String, String>>,
141
142 pub query_params: Option<HashMap<String, String>>,
144}
145
146fn deserialize_optional_socket_addr<'de, D>(deserializer: D) -> Result<Option<SocketAddr>, D::Error>
148where
149 D: serde::Deserializer<'de>,
150{
151 let value: Option<String> = Option::deserialize(deserializer)?;
152 value
153 .map(|s| s.parse().map_err(serde::de::Error::custom))
154 .transpose()
155}
156
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use url::Url;

    use super::*;

    /// Configuration that sets every supported table, including all three translation keys.
    const FULL_CONFIG: &str = r#"
listen = "127.0.0.1:8080"
token_file = "~/.config/stoat/tokens.json"

[upstream]
base_url = "https://api.example.com"

[oauth]
authorize_url = "https://example.com/oauth/authorize"
token_url = "https://example.com/oauth/token"
client_id = "your-client-id"
scopes = ["scope1", "scope2"]
pkce = true
redirect_uri = "https://example.com/oauth/callback"

[translation]
strip_headers = ["x-api-key"]

[translation.query_params]
beta = "true"

[translation.set_headers]
Authorization = "Bearer {access_token}"
"#;

    /// Smallest configuration that parses: only the required `[upstream]` and `[oauth]` keys.
    const MINIMAL_CONFIG: &str = r#"
[upstream]
base_url = "https://api.example.com"

[oauth]
authorize_url = "https://example.com/oauth/authorize"
token_url = "https://example.com/oauth/token"
client_id = "your-client-id"
scopes = ["scope1"]
redirect_uri = "https://example.com/oauth/callback"
"#;

    /// Returns [`MINIMAL_CONFIG`] with `setting` inserted as an extra line of the `[oauth]`
    /// table, directly after its `redirect_uri` line.
    fn minimal_with_oauth_setting(setting: &str) -> String {
        const ANCHOR: &str = "redirect_uri = \"https://example.com/oauth/callback\"";
        // Guard against the fixture drifting: without the anchor, `replace` would silently
        // return the fixture unchanged and the calling test would check the wrong input.
        assert!(
            MINIMAL_CONFIG.contains(ANCHOR),
            "MINIMAL_CONFIG no longer contains the redirect_uri anchor line"
        );
        MINIMAL_CONFIG.replace(ANCHOR, &format!("{ANCHOR}\n{setting}"))
    }

    #[test]
    fn deserialize_full_config() {
        let config = Config::from_toml(FULL_CONFIG).unwrap();

        assert_eq!(
            config.listen_address(),
            SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080),
        );
        assert_eq!(config.token_file_path(), "~/.config/stoat/tokens.json");
        assert_eq!(
            config.upstream.base_url,
            Url::parse("https://api.example.com").unwrap(),
        );
        assert_eq!(
            config.oauth.authorize_url,
            Url::parse("https://example.com/oauth/authorize").unwrap(),
        );
        assert_eq!(
            config.oauth.token_url,
            Url::parse("https://example.com/oauth/token").unwrap(),
        );
        assert_eq!(config.oauth.client_id, "your-client-id");
        assert_eq!(config.oauth.scopes, vec!["scope1", "scope2"]);
        assert!(config.oauth.pkce_enabled());
        assert_eq!(
            config.oauth.redirect_uri,
            Url::parse("https://example.com/oauth/callback").unwrap(),
        );

        let translation = config.translation.unwrap();
        assert_eq!(
            translation.strip_headers.unwrap(),
            vec!["x-api-key".to_owned()]
        );

        let set_headers = translation.set_headers.unwrap();
        assert_eq!(
            set_headers.get("Authorization").unwrap(),
            "Bearer {access_token}"
        );

        let query_params = translation.query_params.unwrap();
        assert_eq!(query_params.get("beta").unwrap(), "true");
    }

    #[test]
    fn deserialize_minimal_config() {
        let config = Config::from_toml(MINIMAL_CONFIG).unwrap();

        assert_eq!(
            config.upstream.base_url,
            Url::parse("https://api.example.com").unwrap(),
        );
        assert_eq!(config.oauth.client_id, "your-client-id");
        assert_eq!(config.oauth.scopes, vec!["scope1"]);
        assert!(config.translation.is_none());
    }

    #[test]
    fn default_listen_address() {
        let config = Config::from_toml(MINIMAL_CONFIG).unwrap();
        assert_eq!(
            config.listen_address(),
            SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
        );
    }

    #[test]
    fn custom_listen_address() {
        let toml = format!("listen = \"0.0.0.0:9999\"\n{MINIMAL_CONFIG}");
        let config = Config::from_toml(&toml).unwrap();
        assert_eq!(
            config.listen_address(),
            SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 9999),
        );
    }

    #[test]
    fn default_token_file() {
        let config = Config::from_toml(MINIMAL_CONFIG).unwrap();
        assert_eq!(config.token_file_path(), "~/.config/stoat/tokens.json");
    }

    #[test]
    fn custom_token_file() {
        let toml = format!("token_file = \"/tmp/tokens.json\"\n{MINIMAL_CONFIG}");
        let config = Config::from_toml(&toml).unwrap();
        assert_eq!(config.token_file_path(), "/tmp/tokens.json");
    }

    #[test]
    fn pkce_defaults_to_true() {
        let config = Config::from_toml(MINIMAL_CONFIG).unwrap();
        assert!(config.oauth.pkce_enabled());
    }

    #[test]
    fn pkce_explicit_false() {
        let toml = minimal_with_oauth_setting("pkce = false");
        let config = Config::from_toml(&toml).unwrap();
        assert!(!config.oauth.pkce_enabled());
    }

    #[test]
    fn missing_upstream_is_error() {
        let toml = r#"
[oauth]
authorize_url = "https://example.com/oauth/authorize"
token_url = "https://example.com/oauth/token"
client_id = "your-client-id"
scopes = ["scope1"]
redirect_uri = "https://example.com/oauth/callback"
"#;
        let err = Config::from_toml(toml).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("upstream"),
            "error should mention upstream: {msg}"
        );
    }

    #[test]
    fn missing_oauth_is_error() {
        let toml = r#"
[upstream]
base_url = "https://api.example.com"
"#;
        let err = Config::from_toml(toml).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("oauth"), "error should mention oauth: {msg}");
    }

    #[test]
    fn missing_oauth_client_id_is_error() {
        let toml = r#"
[upstream]
base_url = "https://api.example.com"

[oauth]
authorize_url = "https://example.com/oauth/authorize"
token_url = "https://example.com/oauth/token"
scopes = ["scope1"]
redirect_uri = "https://example.com/oauth/callback"
"#;
        let err = Config::from_toml(toml).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("client_id"),
            "error should mention client_id: {msg}"
        );
    }

    #[test]
    fn missing_oauth_scopes_is_error() {
        let toml = r#"
[upstream]
base_url = "https://api.example.com"

[oauth]
authorize_url = "https://example.com/oauth/authorize"
token_url = "https://example.com/oauth/token"
client_id = "your-client-id"
redirect_uri = "https://example.com/oauth/callback"
"#;
        let err = Config::from_toml(toml).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("scopes"), "error should mention scopes: {msg}");
    }

    #[test]
    fn empty_scopes_is_valid() {
        let toml = MINIMAL_CONFIG.replace("scopes = [\"scope1\"]", "scopes = []");
        let config = Config::from_toml(&toml).unwrap();
        assert!(config.oauth.scopes.is_empty());
    }

    #[test]
    fn translation_all_optional_fields() {
        let toml = format!("{MINIMAL_CONFIG}\n[translation]\n");
        let config = Config::from_toml(&toml).unwrap();
        let translation = config.translation.unwrap();
        assert!(translation.strip_headers.is_none());
        assert!(translation.set_headers.is_none());
        assert!(translation.query_params.is_none());
    }

    #[test]
    fn translation_strip_headers_only() {
        let toml = format!(
            "{MINIMAL_CONFIG}\n[translation]\nstrip_headers = [\"x-api-key\", \"x-custom\"]\n"
        );
        let config = Config::from_toml(&toml).unwrap();
        let translation = config.translation.unwrap();
        assert_eq!(
            translation.strip_headers.unwrap(),
            vec!["x-api-key".to_owned(), "x-custom".to_owned()]
        );
        assert!(translation.set_headers.is_none());
        assert!(translation.query_params.is_none());
    }

    #[test]
    fn translation_set_headers_only() {
        let toml = format!(
            "{MINIMAL_CONFIG}\n[translation.set_headers]\nAuthorization = \"Bearer {{access_token}}\"\n"
        );
        let config = Config::from_toml(&toml).unwrap();
        let translation = config.translation.unwrap();
        assert!(translation.strip_headers.is_none());
        let set_headers = translation.set_headers.unwrap();
        assert_eq!(
            set_headers.get("Authorization").unwrap(),
            "Bearer {access_token}"
        );
    }

    #[test]
    fn translation_query_params_only() {
        let toml = format!("{MINIMAL_CONFIG}\n[translation.query_params]\nbeta = \"true\"\n");
        let config = Config::from_toml(&toml).unwrap();
        let translation = config.translation.unwrap();
        assert!(translation.strip_headers.is_none());
        assert!(translation.set_headers.is_none());
        let query_params = translation.query_params.unwrap();
        assert_eq!(query_params.get("beta").unwrap(), "true");
    }

    #[test]
    fn invalid_upstream_url_is_error() {
        let toml = r#"
[upstream]
base_url = "not a valid url"

[oauth]
authorize_url = "https://example.com/oauth/authorize"
token_url = "https://example.com/oauth/token"
client_id = "your-client-id"
scopes = ["scope1"]
redirect_uri = "https://example.com/oauth/callback"
"#;
        assert!(Config::from_toml(toml).is_err());
    }

    #[test]
    fn invalid_oauth_url_is_error() {
        let toml = r#"
[upstream]
base_url = "https://api.example.com"

[oauth]
authorize_url = "not a url"
token_url = "https://example.com/oauth/token"
client_id = "your-client-id"
scopes = ["scope1"]
redirect_uri = "https://example.com/oauth/callback"
"#;
        assert!(Config::from_toml(toml).is_err());
    }

    #[test]
    fn empty_toml_is_error() {
        assert!(Config::from_toml("").is_err());
    }

    /// Unknown keys must not change the parsed configuration, wherever they appear.
    ///
    /// None of the config structs opt into `deny_unknown_fields`, so serde skips keys it does
    /// not recognise. If that attribute is ever added, this test fails and flags the change.
    #[test]
    fn extra_fields_are_ignored() {
        let expected = Config::from_toml(MINIMAL_CONFIG).unwrap();

        // Placed before any table header, the key belongs to the top level of the document.
        let top_level = format!("unknown_field = \"value\"\n{MINIMAL_CONFIG}");
        assert_eq!(Config::from_toml(&top_level).unwrap(), expected);

        // Appended after the fixture, the key lands inside the trailing `[oauth]` table.
        let in_table = format!("{MINIMAL_CONFIG}\nunknown_field = \"value\"\n");
        assert_eq!(Config::from_toml(&in_table).unwrap(), expected);
    }

    #[test]
    fn invalid_listen_address_is_error() {
        let toml = format!("listen = \"not-an-address\"\n{MINIMAL_CONFIG}");
        assert!(Config::from_toml(&toml).is_err());
    }

    #[test]
    fn token_format_defaults_to_form() {
        let config = Config::from_toml(MINIMAL_CONFIG).unwrap();
        assert_eq!(config.oauth.token_format(), TokenFormat::Form);
    }

    #[test]
    fn token_format_explicit_form() {
        let toml = minimal_with_oauth_setting("token_format = \"form\"");
        let config = Config::from_toml(&toml).unwrap();
        assert_eq!(config.oauth.token_format(), TokenFormat::Form);
    }

    #[test]
    fn token_format_explicit_json() {
        let toml = minimal_with_oauth_setting("token_format = \"json\"");
        let config = Config::from_toml(&toml).unwrap();
        assert_eq!(config.oauth.token_format(), TokenFormat::Json);
    }

    /// Only the lowercase variant names are accepted; anything else is a deserialization error.
    #[test]
    fn token_format_unknown_value_is_error() {
        let toml = minimal_with_oauth_setting("token_format = \"xml\"");
        assert!(Config::from_toml(&toml).is_err());
    }
}