Skip to main content

stoat_core/
config.rs

1//! Configuration types for stoat.
2//!
3//! These types represent the TOML config file that drives stoat's behavior.
4//! All fields that are not required have sensible defaults.
5//!
6//! URL fields are parsed and validated at deserialization time using
7//! [`url::Url`]. The `listen` address is parsed as a [`SocketAddr`].
8
9use std::collections::HashMap;
10use std::net::{IpAddr, Ipv4Addr, SocketAddr};
11
12use serde::Deserialize;
13use url::Url;
14
/// Fallback bind address used when the config does not set `listen`:
/// the IPv4 loopback address with port `0` (automatic port assignment).
const DEFAULT_LISTEN: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0);

/// Fallback token storage path used when the config does not set
/// `token_file`. The leading `~` is stored literally; it is tilde-expanded
/// at runtime by the I/O layer.
const DEFAULT_TOKEN_FILE: &str = "~/.config/stoat/tokens.json";
20
21/// Top-level stoat configuration, deserialized from a TOML file.
22#[derive(Debug, Deserialize, PartialEq, Eq)]
23pub struct Config {
24    /// Address and port to listen on. Use port 0 for automatic assignment.
25    #[serde(default, deserialize_with = "deserialize_optional_socket_addr")]
26    listen: Option<SocketAddr>,
27
28    /// Path to the token storage file. Tilde (`~`) is expanded at runtime.
29    token_file: Option<String>,
30
31    /// Upstream API to proxy requests to.
32    pub upstream: Upstream,
33
34    /// OAuth PKCE configuration.
35    pub oauth: OAuth,
36
37    /// Request transformations applied to every proxied request.
38    pub translation: Option<Translation>,
39}
40
41impl Config {
42    /// Deserialize a [`Config`] from a TOML string.
43    ///
44    /// URL fields are validated during deserialization — invalid URLs will
45    /// produce an error. The `listen` address is validated as a
46    /// [`SocketAddr`].
47    ///
48    /// # Errors
49    ///
50    /// Returns a [`toml::de::Error`] if the input is not valid TOML, does
51    /// not match the expected schema, or contains invalid URLs or addresses.
52    pub fn from_toml(s: &str) -> Result<Self, toml::de::Error> {
53        toml::from_str(s)
54    }
55
56    /// The listen address, falling back to the default if not configured.
57    #[must_use]
58    pub fn listen_address(&self) -> SocketAddr {
59        self.listen.unwrap_or(DEFAULT_LISTEN)
60    }
61
62    /// The token file path, falling back to the default if not configured.
63    #[must_use]
64    pub fn token_file_path(&self) -> &str {
65        self.token_file.as_deref().unwrap_or(DEFAULT_TOKEN_FILE)
66    }
67}
68
69/// Upstream API target.
70#[derive(Debug, Deserialize, PartialEq, Eq)]
71pub struct Upstream {
72    /// Base URL of the upstream API.
73    pub base_url: Url,
74}
75
76/// Body format for token endpoint requests.
77///
78/// Controls whether the token exchange and refresh POST requests send
79/// `application/x-www-form-urlencoded` (the OAuth 2.0 RFC 6749 default)
80/// or `application/json` bodies.
81#[derive(Debug, Clone, Copy, Default, Deserialize, PartialEq, Eq)]
82#[serde(rename_all = "lowercase")]
83pub enum TokenFormat {
84    /// Form-encoded body (`application/x-www-form-urlencoded`). This is the
85    /// default per RFC 6749.
86    #[default]
87    Form,
88    /// JSON body (`application/json`). Required by some OAuth providers.
89    Json,
90}
91
92/// OAuth PKCE configuration.
93#[derive(Debug, Deserialize, PartialEq, Eq)]
94pub struct OAuth {
95    /// OAuth authorization endpoint.
96    pub authorize_url: Url,
97
98    /// OAuth token exchange and refresh endpoint.
99    pub token_url: Url,
100
101    /// OAuth client identifier.
102    pub client_id: String,
103
104    /// OAuth scopes to request.
105    pub scopes: Vec<String>,
106
107    /// Enable PKCE (S256). Defaults to `true` when not specified.
108    pkce: Option<bool>,
109
110    /// Redirect URI for the OAuth flow.
111    pub redirect_uri: Url,
112
113    /// Body format for the token endpoint. Defaults to `Form` when not
114    /// specified.
115    token_format: Option<TokenFormat>,
116}
117
118impl OAuth {
119    /// Whether PKCE is enabled, defaulting to `true`.
120    #[must_use]
121    pub fn pkce_enabled(&self) -> bool {
122        self.pkce.unwrap_or(true)
123    }
124
125    /// The token endpoint body format, defaulting to [`TokenFormat::Form`].
126    #[must_use]
127    pub fn token_format(&self) -> TokenFormat {
128        self.token_format.unwrap_or_default()
129    }
130}
131
132/// Request transformations applied to every proxied request.
133#[derive(Debug, Deserialize, PartialEq, Eq)]
134pub struct Translation {
135    /// Headers to remove from the incoming request before forwarding.
136    pub strip_headers: Option<Vec<String>>,
137
138    /// Headers to set on the outgoing request. Values support the
139    /// `{access_token}` template variable.
140    pub set_headers: Option<HashMap<String, String>>,
141
142    /// Query parameters to append to every outgoing request URL.
143    pub query_params: Option<HashMap<String, String>>,
144}
145
146/// Deserialize an optional `SocketAddr` from a TOML string value.
147fn deserialize_optional_socket_addr<'de, D>(deserializer: D) -> Result<Option<SocketAddr>, D::Error>
148where
149    D: serde::Deserializer<'de>,
150{
151    let value: Option<String> = Option::deserialize(deserializer)?;
152    value
153        .map(|s| s.parse().map_err(serde::de::Error::custom))
154        .transpose()
155}
156
/// Unit tests covering TOML deserialization of [`Config`]: required and
/// optional fields, defaults, URL/address validation, and unknown keys.
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use url::Url;

    use super::*;

    /// The full example config from `docs/src/project/configuration.md`.
    const FULL_CONFIG: &str = r#"
listen = "127.0.0.1:8080"
token_file = "~/.config/stoat/tokens.json"

[upstream]
base_url = "https://api.example.com"

[oauth]
authorize_url = "https://example.com/oauth/authorize"
token_url = "https://example.com/oauth/token"
client_id = "your-client-id"
scopes = ["scope1", "scope2"]
pkce = true
redirect_uri = "https://example.com/oauth/callback"

[translation]
strip_headers = ["x-api-key"]

[translation.query_params]
beta = "true"

[translation.set_headers]
Authorization = "Bearer {access_token}"
"#;

    /// Only the required fields — everything optional is omitted.
    const MINIMAL_CONFIG: &str = r#"
[upstream]
base_url = "https://api.example.com"

[oauth]
authorize_url = "https://example.com/oauth/authorize"
token_url = "https://example.com/oauth/token"
client_id = "your-client-id"
scopes = ["scope1"]
redirect_uri = "https://example.com/oauth/callback"
"#;

    #[test]
    fn deserialize_full_config() {
        let config = Config::from_toml(FULL_CONFIG).unwrap();

        assert_eq!(
            config.listen_address(),
            SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080),
        );
        assert_eq!(config.token_file_path(), "~/.config/stoat/tokens.json");
        assert_eq!(
            config.upstream.base_url,
            Url::parse("https://api.example.com").unwrap(),
        );
        assert_eq!(
            config.oauth.authorize_url,
            Url::parse("https://example.com/oauth/authorize").unwrap(),
        );
        assert_eq!(
            config.oauth.token_url,
            Url::parse("https://example.com/oauth/token").unwrap(),
        );
        assert_eq!(config.oauth.client_id, "your-client-id");
        assert_eq!(config.oauth.scopes, vec!["scope1", "scope2"]);
        assert!(config.oauth.pkce_enabled());
        assert_eq!(
            config.oauth.redirect_uri,
            Url::parse("https://example.com/oauth/callback").unwrap(),
        );

        let translation = config.translation.unwrap();
        assert_eq!(
            translation.strip_headers.unwrap(),
            vec!["x-api-key".to_owned()]
        );

        let set_headers = translation.set_headers.unwrap();
        assert_eq!(
            set_headers.get("Authorization").unwrap(),
            "Bearer {access_token}"
        );

        let query_params = translation.query_params.unwrap();
        assert_eq!(query_params.get("beta").unwrap(), "true");
    }

    #[test]
    fn deserialize_minimal_config() {
        let config = Config::from_toml(MINIMAL_CONFIG).unwrap();

        assert_eq!(
            config.upstream.base_url,
            Url::parse("https://api.example.com").unwrap(),
        );
        assert_eq!(config.oauth.client_id, "your-client-id");
        assert_eq!(config.oauth.scopes, vec!["scope1"]);
        assert!(config.translation.is_none());
    }

    #[test]
    fn default_listen_address() {
        let config = Config::from_toml(MINIMAL_CONFIG).unwrap();
        assert_eq!(
            config.listen_address(),
            SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
        );
    }

    #[test]
    fn custom_listen_address() {
        let toml = format!("listen = \"0.0.0.0:9999\"\n{MINIMAL_CONFIG}");
        let config = Config::from_toml(&toml).unwrap();
        assert_eq!(
            config.listen_address(),
            SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 9999),
        );
    }

    #[test]
    fn default_token_file() {
        let config = Config::from_toml(MINIMAL_CONFIG).unwrap();
        assert_eq!(config.token_file_path(), "~/.config/stoat/tokens.json");
    }

    #[test]
    fn custom_token_file() {
        let toml = format!("token_file = \"/tmp/tokens.json\"\n{MINIMAL_CONFIG}");
        let config = Config::from_toml(&toml).unwrap();
        assert_eq!(config.token_file_path(), "/tmp/tokens.json");
    }

    #[test]
    fn pkce_defaults_to_true() {
        let config = Config::from_toml(MINIMAL_CONFIG).unwrap();
        assert!(config.oauth.pkce_enabled());
    }

    #[test]
    fn pkce_explicit_false() {
        let toml = MINIMAL_CONFIG.replace(
            "redirect_uri = \"https://example.com/oauth/callback\"",
            "redirect_uri = \"https://example.com/oauth/callback\"\npkce = false",
        );
        let config = Config::from_toml(&toml).unwrap();
        assert!(!config.oauth.pkce_enabled());
    }

    #[test]
    fn missing_upstream_is_error() {
        let toml = r#"
[oauth]
authorize_url = "https://example.com/oauth/authorize"
token_url = "https://example.com/oauth/token"
client_id = "your-client-id"
scopes = ["scope1"]
redirect_uri = "https://example.com/oauth/callback"
"#;
        let err = Config::from_toml(toml).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("upstream"),
            "error should mention upstream: {msg}"
        );
    }

    #[test]
    fn missing_oauth_is_error() {
        let toml = r#"
[upstream]
base_url = "https://api.example.com"
"#;
        let err = Config::from_toml(toml).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("oauth"), "error should mention oauth: {msg}");
    }

    #[test]
    fn missing_oauth_client_id_is_error() {
        let toml = r#"
[upstream]
base_url = "https://api.example.com"

[oauth]
authorize_url = "https://example.com/oauth/authorize"
token_url = "https://example.com/oauth/token"
scopes = ["scope1"]
redirect_uri = "https://example.com/oauth/callback"
"#;
        let err = Config::from_toml(toml).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("client_id"),
            "error should mention client_id: {msg}"
        );
    }

    #[test]
    fn missing_oauth_scopes_is_error() {
        let toml = r#"
[upstream]
base_url = "https://api.example.com"

[oauth]
authorize_url = "https://example.com/oauth/authorize"
token_url = "https://example.com/oauth/token"
client_id = "your-client-id"
redirect_uri = "https://example.com/oauth/callback"
"#;
        let err = Config::from_toml(toml).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("scopes"), "error should mention scopes: {msg}");
    }

    #[test]
    fn empty_scopes_is_valid() {
        let toml = MINIMAL_CONFIG.replace("scopes = [\"scope1\"]", "scopes = []");
        let config = Config::from_toml(&toml).unwrap();
        assert!(config.oauth.scopes.is_empty());
    }

    #[test]
    fn translation_all_optional_fields() {
        let toml = format!("{MINIMAL_CONFIG}\n[translation]\n");
        let config = Config::from_toml(&toml).unwrap();
        let translation = config.translation.unwrap();
        assert!(translation.strip_headers.is_none());
        assert!(translation.set_headers.is_none());
        assert!(translation.query_params.is_none());
    }

    #[test]
    fn translation_strip_headers_only() {
        let toml = format!(
            "{MINIMAL_CONFIG}\n[translation]\nstrip_headers = [\"x-api-key\", \"x-custom\"]\n"
        );
        let config = Config::from_toml(&toml).unwrap();
        let translation = config.translation.unwrap();
        assert_eq!(
            translation.strip_headers.unwrap(),
            vec!["x-api-key".to_owned(), "x-custom".to_owned()]
        );
        assert!(translation.set_headers.is_none());
        assert!(translation.query_params.is_none());
    }

    #[test]
    fn translation_set_headers_only() {
        let toml = format!(
            "{MINIMAL_CONFIG}\n[translation.set_headers]\nAuthorization = \"Bearer {{access_token}}\"\n"
        );
        let config = Config::from_toml(&toml).unwrap();
        let translation = config.translation.unwrap();
        assert!(translation.strip_headers.is_none());
        let set_headers = translation.set_headers.unwrap();
        assert_eq!(
            set_headers.get("Authorization").unwrap(),
            "Bearer {access_token}"
        );
    }

    #[test]
    fn translation_query_params_only() {
        let toml = format!("{MINIMAL_CONFIG}\n[translation.query_params]\nbeta = \"true\"\n");
        let config = Config::from_toml(&toml).unwrap();
        let translation = config.translation.unwrap();
        assert!(translation.strip_headers.is_none());
        assert!(translation.set_headers.is_none());
        let query_params = translation.query_params.unwrap();
        assert_eq!(query_params.get("beta").unwrap(), "true");
    }

    #[test]
    fn invalid_upstream_url_is_error() {
        let toml = r#"
[upstream]
base_url = "not a valid url"

[oauth]
authorize_url = "https://example.com/oauth/authorize"
token_url = "https://example.com/oauth/token"
client_id = "your-client-id"
scopes = ["scope1"]
redirect_uri = "https://example.com/oauth/callback"
"#;
        assert!(Config::from_toml(toml).is_err());
    }

    #[test]
    fn invalid_oauth_url_is_error() {
        let toml = r#"
[upstream]
base_url = "https://api.example.com"

[oauth]
authorize_url = "not a url"
token_url = "https://example.com/oauth/token"
client_id = "your-client-id"
scopes = ["scope1"]
redirect_uri = "https://example.com/oauth/callback"
"#;
        assert!(Config::from_toml(toml).is_err());
    }

    #[test]
    fn empty_toml_is_error() {
        assert!(Config::from_toml("").is_err());
    }

    #[test]
    fn extra_fields_are_ignored() {
        // Forward compatibility: none of the config structs opt into
        // `#[serde(deny_unknown_fields)]`, so serde skips keys it does not
        // recognise instead of rejecting them. Pin that behavior here.

        // The key is prepended so it really is top-level; appended after
        // MINIMAL_CONFIG it would belong to the trailing `[oauth]` table.
        let top_level = format!("unknown_field = \"value\"\n{MINIMAL_CONFIG}");
        let config = Config::from_toml(&top_level).unwrap();
        assert_eq!(config.oauth.client_id, "your-client-id");

        // Unknown keys inside a known table (`oauth.unknown_field`) are
        // ignored as well.
        let nested = format!("{MINIMAL_CONFIG}\nunknown_field = \"value\"\n");
        let config = Config::from_toml(&nested).unwrap();
        assert_eq!(config.oauth.client_id, "your-client-id");
    }

    #[test]
    fn invalid_listen_address_is_error() {
        let toml = format!("listen = \"not-an-address\"\n{MINIMAL_CONFIG}");
        assert!(Config::from_toml(&toml).is_err());
    }

    #[test]
    fn token_format_defaults_to_form() {
        let config = Config::from_toml(MINIMAL_CONFIG).unwrap();
        assert_eq!(config.oauth.token_format(), TokenFormat::Form);
    }

    #[test]
    fn token_format_explicit_form() {
        let toml = MINIMAL_CONFIG.replace(
            "redirect_uri = \"https://example.com/oauth/callback\"",
            "redirect_uri = \"https://example.com/oauth/callback\"\ntoken_format = \"form\"",
        );
        let config = Config::from_toml(&toml).unwrap();
        assert_eq!(config.oauth.token_format(), TokenFormat::Form);
    }

    #[test]
    fn token_format_explicit_json() {
        let toml = MINIMAL_CONFIG.replace(
            "redirect_uri = \"https://example.com/oauth/callback\"",
            "redirect_uri = \"https://example.com/oauth/callback\"\ntoken_format = \"json\"",
        );
        let config = Config::from_toml(&toml).unwrap();
        assert_eq!(config.oauth.token_format(), TokenFormat::Json);
    }
}