1use serde::{Deserialize, Serialize};
2use std::time::Duration;
3
/// How the `Domain` attribute of the session cookie is chosen.
///
/// Serialized with `#[serde(untagged)]`:
/// - a JSON string deserializes to [`CookieDomainPolicy::Explicit`];
/// - `null` deserializes to [`CookieDomainPolicy::Auto`], which also serializes as `null`.
///
/// Variant order matters for untagged deserialization: serde tries `Explicit` first.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum CookieDomainPolicy {
    /// Use this domain verbatim. An empty string is treated like `Auto` by
    /// [`CookieDomainPolicy::domain`].
    Explicit(String),
    /// No explicit domain (`domain()` returns `None`). This is the `Default` variant.
    /// NOTE(review): presumably yields a host-only cookie — confirm at the cookie builder.
    #[default]
    Auto,
}
21
22impl CookieDomainPolicy {
23 pub fn domain(&self) -> Option<&str> {
25 match self {
26 Self::Explicit(d) if !d.is_empty() => Some(d.as_str()),
27 _ => None,
28 }
29 }
30}
31
32impl From<Option<String>> for CookieDomainPolicy {
36 fn from(opt: Option<String>) -> Self {
37 match opt {
38 Some(s) if !s.is_empty() => Self::Explicit(s),
39 _ => Self::Auto,
40 }
41 }
42}
43
/// Top-level configuration shared by all yauth features.
///
/// Durations are (de)serialized as whole seconds. Fields carrying a serde
/// default may be omitted from the input; `Option` fields are `None` when
/// absent. Every other field is required when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct YAuthConfig {
    /// Base URL of the application (default `http://localhost:3000`).
    pub base_url: String,
    /// Name of the session cookie (default `"session"`).
    pub session_cookie_name: String,
    /// Session lifetime, serialized as whole seconds (default 7 days).
    #[serde(with = "duration_secs")]
    pub session_ttl: Duration,
    /// `Domain` policy for the session cookie; `Auto` when omitted.
    #[serde(default)]
    pub cookie_domain: CookieDomainPolicy,
    /// NOTE(review): presumably sets the cookie `Secure` flag — confirm at the cookie builder.
    pub secure_cookies: bool,
    /// Allowed request origins (default `["http://localhost:3000"]`). How they
    /// are matched is decided elsewhere.
    pub trusted_origins: Vec<String>,
    /// Outgoing mail settings; `None` when absent.
    pub smtp: Option<SmtpConfig>,
    /// `false` when omitted.
    /// NOTE(review): name suggests the first registered user becomes admin — confirm.
    #[serde(default)]
    pub auto_admin_first_user: bool,
    /// Optional alternative session lifetime in whole seconds. The name suggests a
    /// "remember me" login — confirm. Omitted from the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub remember_me_ttl: Option<DurationSecs>,
    /// IP / User-Agent binding settings; `SessionBindingConfig::default()` when omitted.
    #[serde(default)]
    pub session_binding: SessionBindingConfig,
    /// Whether new accounts may be created; `true` when omitted.
    #[serde(default = "default_true")]
    pub allow_signups: bool,
    /// Database schema name; `"public"` when omitted.
    #[serde(default = "default_schema")]
    pub db_schema: String,
}
86
/// `#[serde(default = "default_true")]` helper for flags that stay enabled
/// unless the input explicitly turns them off.
fn default_true() -> bool {
    true
}

/// `#[serde(default = "default_schema")]` helper for `YAuthConfig::db_schema`.
fn default_schema() -> String {
    String::from("public")
}
94
95impl Default for YAuthConfig {
96 fn default() -> Self {
97 Self {
98 base_url: "http://localhost:3000".into(),
99 session_cookie_name: "session".into(),
100 session_ttl: Duration::from_secs(7 * 24 * 3600),
101 cookie_domain: CookieDomainPolicy::Auto,
102 secure_cookies: false,
103 trusted_origins: vec!["http://localhost:3000".into()],
104 smtp: None,
105 auto_admin_first_user: false,
106 remember_me_ttl: None,
107 session_binding: SessionBindingConfig::default(),
108 allow_signups: true,
109 db_schema: "public".into(),
110 }
111 }
112}
113
114#[derive(Debug, Clone, Copy)]
116pub struct DurationSecs(pub Duration);
117
118impl DurationSecs {
119 pub fn as_duration(&self) -> Duration {
120 self.0
121 }
122}
123
124impl Serialize for DurationSecs {
125 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
126 serializer.serialize_u64(self.0.as_secs())
127 }
128}
129
130impl<'de> Deserialize<'de> for DurationSecs {
131 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
132 let secs = u64::deserialize(deserializer)?;
133 Ok(DurationSecs(Duration::from_secs(secs)))
134 }
135}
136
137#[derive(Debug, Clone, Serialize, Deserialize)]
141pub struct SessionBindingConfig {
142 pub bind_ip: bool,
143 pub bind_user_agent: bool,
144 #[serde(default = "default_binding_action")]
145 pub ip_mismatch_action: BindingAction,
146 #[serde(default = "default_binding_action")]
147 pub ua_mismatch_action: BindingAction,
148}
149
150fn default_binding_action() -> BindingAction {
151 BindingAction::Warn
152}
153
154impl Default for SessionBindingConfig {
155 fn default() -> Self {
156 Self {
157 bind_ip: false,
158 bind_user_agent: false,
159 ip_mismatch_action: BindingAction::Warn,
160 ua_mismatch_action: BindingAction::Warn,
161 }
162 }
163}
164
/// Reaction to a bound session attribute (IP or User-Agent) no longer matching.
///
/// Serialized as the bare variant name, e.g. `"Warn"` / `"Invalidate"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum BindingAction {
    /// NOTE(review): presumably logs the mismatch and keeps the session — confirm.
    Warn,
    /// NOTE(review): presumably ends the session — confirm.
    Invalidate,
}
170
171mod duration_secs {
172 use serde::{self, Deserialize, Deserializer, Serializer};
173 use std::time::Duration;
174
175 pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
176 where
177 S: Serializer,
178 {
179 serializer.serialize_u64(duration.as_secs())
180 }
181
182 pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
183 where
184 D: Deserializer<'de>,
185 {
186 let secs = u64::deserialize(deserializer)?;
187 Ok(Duration::from_secs(secs))
188 }
189}
190
/// Outgoing mail server settings. All fields are required when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SmtpConfig {
    /// SMTP server host name.
    pub host: String,
    /// SMTP server port.
    pub port: u16,
    /// NOTE(review): presumably the sender address used for outgoing mail — confirm.
    pub from: String,
}
197
/// Request rate limit: `max_requests` allowed per `window_secs` seconds.
///
/// NOTE(review): the windowing algorithm (fixed vs sliding) is implemented
/// elsewhere — confirm before documenting it further.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
    /// Maximum number of requests permitted within one window.
    pub max_requests: u32,
    /// Window length in seconds.
    pub window_secs: u64,
}
209
210impl Default for RateLimitConfig {
211 fn default() -> Self {
212 Self {
213 max_requests: 10,
214 window_secs: 60,
215 }
216 }
217}
218
#[cfg(feature = "email-password")]
/// Settings for email + password authentication.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmailPasswordConfig {
    /// Minimum password length (default 8).
    ///
    /// NOTE(review): overlaps with `password_policy.min_length`. Which one is
    /// enforced is decided elsewhere — confirm.
    pub min_password_length: usize,
    /// Whether email verification is required (default `true`).
    pub require_email_verification: bool,
    /// Default `true`. NOTE(review): presumably a Have I Been Pwned breach lookup — confirm.
    pub hibp_check: bool,
    /// Complexity rules; `PasswordPolicyConfig::default()` when omitted.
    #[serde(default)]
    pub password_policy: PasswordPolicyConfig,
    /// Rate limit settings. When the field is omitted this is `Some(RateLimitConfig::default())`.
    /// An explicit `null` gives `None`, presumably disabling the limit.
    #[serde(default = "default_rate_limit")]
    pub rate_limit: Option<RateLimitConfig>,
}
233
#[cfg(feature = "email-password")]
/// Serde default for `EmailPasswordConfig::rate_limit`: limiting is on, with
/// `RateLimitConfig::default()` values.
fn default_rate_limit() -> Option<RateLimitConfig> {
    let limits: RateLimitConfig = Default::default();
    Some(limits)
}
238
#[cfg(feature = "email-password")]
impl Default for EmailPasswordConfig {
    /// Default settings:
    /// - minimum length 8;
    /// - email verification and the HIBP check enabled;
    /// - default password policy;
    /// - the same rate limit serde applies when `rate_limit` is omitted.
    fn default() -> Self {
        Self {
            min_password_length: 8,
            require_email_verification: true,
            hibp_check: true,
            password_policy: Default::default(),
            rate_limit: default_rate_limit(),
        }
    }
}
251
#[cfg(feature = "email-password")]
/// Password complexity rules.
///
/// Any field missing from the input takes its value from
/// `PasswordPolicyConfig::default()` (container-level `#[serde(default)]`).
/// A partial policy such as `{"require_digit": true}` is therefore accepted,
/// just as omitting the whole policy from `EmailPasswordConfig` is.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct PasswordPolicyConfig {
    /// Minimum password length (default 8).
    #[serde(default = "default_min_password_length")]
    pub min_length: usize,
    /// Maximum password length (default 128).
    pub max_length: usize,
    /// Require at least one uppercase letter (default `false`).
    pub require_uppercase: bool,
    /// Require at least one lowercase letter (default `false`).
    pub require_lowercase: bool,
    /// Require at least one digit (default `false`).
    pub require_digit: bool,
    /// Require at least one special character (default `false`).
    pub require_special: bool,
    /// Reject passwords found in a common-password list (default `true`).
    pub disallow_common_passwords: bool,
    /// Number of previous passwords that may not be reused (default 0).
    /// NOTE(review): semantics inferred from the name — confirm.
    pub password_history_count: u32,
}
270
271#[cfg(feature = "email-password")]
/// Serde default for `PasswordPolicyConfig::min_length` (8).
fn default_min_password_length() -> usize {
    const DEFAULT_MIN_LENGTH: usize = 8;
    DEFAULT_MIN_LENGTH
}
275
#[cfg(feature = "email-password")]
impl Default for PasswordPolicyConfig {
    /// Permissive policy:
    /// - length bounds 8 and 128;
    /// - no character-class requirements;
    /// - common passwords disallowed;
    /// - no password history.
    ///
    /// The minimum length comes from the same helper serde uses.
    fn default() -> Self {
        Self {
            min_length: default_min_password_length(),
            max_length: 128,
            require_uppercase: false,
            require_lowercase: false,
            require_digit: false,
            require_special: false,
            disallow_common_passwords: true,
            password_history_count: 0,
        }
    }
}
291
#[cfg(feature = "passkey")]
/// Relying-party (RP) identity used for passkeys. All fields are required.
///
/// NOTE(review): presumably passed to the WebAuthn implementation — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PasskeyConfig {
    /// Relying-party identifier.
    pub rp_id: String,
    /// Relying-party origin.
    pub rp_origin: String,
    /// Human-readable relying-party name.
    pub rp_name: String,
}
299
#[cfg(feature = "mfa")]
/// Multi-factor authentication settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MfaConfig {
    /// Issuer name (default `"YAuth"`).
    /// NOTE(review): presumably shown in authenticator apps — confirm.
    pub issuer: String,
    /// Number of backup codes to generate (default 10).
    pub backup_code_count: usize,
}
306
307#[cfg(feature = "mfa")]
308impl Default for MfaConfig {
309 fn default() -> Self {
310 Self {
311 issuer: "YAuth".into(),
312 backup_code_count: 10,
313 }
314 }
315}
316
#[cfg(feature = "oauth")]
/// Settings for signing in through external OAuth providers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthConfig {
    /// Configured providers, one entry per external identity provider.
    pub providers: Vec<OAuthProviderConfig>,
}
322
#[cfg(feature = "oauth")]
/// Connection settings for one external OAuth provider.
///
/// `Debug` is implemented by hand so that `client_secret` is redacted. A
/// derived `Debug` would print the secret whenever the config is logged.
#[derive(Clone, Serialize, Deserialize)]
pub struct OAuthProviderConfig {
    /// Provider identifier.
    pub name: String,
    /// OAuth client ID issued by the provider.
    pub client_id: String,
    /// OAuth client secret issued by the provider (never printed by `Debug`).
    pub client_secret: String,
    /// Authorization endpoint URL.
    pub auth_url: String,
    /// Token endpoint URL.
    pub token_url: String,
    /// User-info endpoint URL.
    pub userinfo_url: String,
    /// Scopes to request.
    pub scopes: Vec<String>,
    /// Optional separate endpoint for fetching the user's email addresses.
    /// `None` when omitted.
    #[serde(default)]
    pub emails_url: Option<String>,
}

#[cfg(feature = "oauth")]
impl std::fmt::Debug for OAuthProviderConfig {
    /// Formats every field; `client_secret` is replaced with a placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuthProviderConfig")
            .field("name", &self.name)
            .field("client_id", &self.client_id)
            .field("client_secret", &"<redacted>")
            .field("auth_url", &self.auth_url)
            .field("token_url", &self.token_url)
            .field("userinfo_url", &self.userinfo_url)
            .field("scopes", &self.scopes)
            .field("emails_url", &self.emails_url)
            .finish()
    }
}
338
#[cfg(feature = "magic-link")]
/// Passwordless "magic link" login settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MagicLinkConfig {
    /// Link validity, serialized as whole seconds (default 5 minutes).
    #[serde(with = "duration_secs")]
    pub link_ttl: Duration,
    /// Default `true`.
    /// NOTE(review): presumably lets a link create an account for an unknown email — confirm.
    pub allow_signup: bool,
    /// Optional role name (default `None`).
    /// NOTE(review): presumably assigned to accounts created this way — confirm.
    pub default_role: Option<String>,
}
347
348#[cfg(feature = "magic-link")]
349impl Default for MagicLinkConfig {
350 fn default() -> Self {
351 Self {
352 link_ttl: Duration::from_secs(5 * 60),
353 allow_signup: true,
354 default_role: None,
355 }
356 }
357}
358
#[cfg(feature = "oauth2-server")]
/// Settings for acting as an OAuth 2.0 authorization server.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuth2ServerConfig {
    /// Issuer identifier (default `http://localhost:3000`).
    pub issuer: String,
    /// Authorization-code lifetime in whole seconds (default 60).
    /// Required when deserializing.
    #[serde(with = "duration_secs")]
    pub authorization_code_ttl: Duration,
    /// Advertised scopes; empty when omitted.
    #[serde(default)]
    pub scopes_supported: Vec<String>,
    /// Whether clients may register themselves dynamically.
    ///
    /// NOTE(review): omitting this field when deserializing yields `false`
    /// (`#[serde(default)]`), while `OAuth2ServerConfig::default()` sets `true`.
    /// Confirm which is intended; `false` is the more conservative choice.
    #[serde(default)]
    pub allow_dynamic_registration: bool,
    /// Device-code lifetime in whole seconds; 600 when omitted.
    #[serde(default = "default_device_code_ttl", with = "duration_secs")]
    pub device_code_ttl: Duration,
    /// Device-flow polling interval; 5 when omitted.
    /// NOTE(review): presumably seconds — confirm.
    #[serde(default = "default_device_poll_interval")]
    pub device_poll_interval: u32,
    /// Optional device-flow verification URI; `None` when omitted.
    #[serde(default)]
    pub device_verification_uri: Option<String>,
    /// Optional URL of an external consent UI; `None` when omitted.
    #[serde(default)]
    pub consent_ui_url: Option<String>,
}
387
388#[cfg(feature = "oauth2-server")]
/// Serde default for `OAuth2ServerConfig::device_code_ttl`: ten minutes.
fn default_device_code_ttl() -> Duration {
    let ten_minutes = 10 * 60;
    Duration::from_secs(ten_minutes)
}
392
393#[cfg(feature = "oauth2-server")]
/// Serde default for `OAuth2ServerConfig::device_poll_interval` (5).
fn default_device_poll_interval() -> u32 {
    const POLL_INTERVAL: u32 = 5;
    POLL_INTERVAL
}
397
#[cfg(feature = "oauth2-server")]
impl Default for OAuth2ServerConfig {
    /// Local-development defaults:
    /// - issuer `http://localhost:3000`;
    /// - one-minute authorization codes;
    /// - no advertised scopes;
    /// - dynamic registration enabled.
    ///
    /// The device-flow values come from the same helpers serde uses.
    /// NOTE(review): the serde default for `allow_dynamic_registration` is
    /// `false`, unlike the `true` set here.
    fn default() -> Self {
        Self {
            issuer: String::from("http://localhost:3000"),
            authorization_code_ttl: Duration::from_secs(60),
            scopes_supported: Vec::new(),
            allow_dynamic_registration: true,
            device_code_ttl: default_device_code_ttl(),
            device_poll_interval: default_device_poll_interval(),
            device_verification_uri: None,
            consent_ui_url: None,
        }
    }
}
413
#[cfg(feature = "bearer")]
/// Bearer-token (JWT) settings.
///
/// `Debug` is implemented by hand so that `jwt_secret` is redacted. A derived
/// `Debug` would print the signing secret whenever the config is logged.
#[derive(Clone, Serialize, Deserialize)]
pub struct BearerConfig {
    /// Secret used for JWTs (never printed by `Debug`).
    /// NOTE(review): presumably a symmetric signing key — confirm.
    pub jwt_secret: String,
    /// Access-token lifetime in whole seconds.
    #[serde(with = "duration_secs")]
    pub access_token_ttl: Duration,
    /// Refresh-token lifetime in whole seconds.
    #[serde(with = "duration_secs")]
    pub refresh_token_ttl: Duration,
    /// Optional audience value; `None` when omitted.
    #[serde(default)]
    pub audience: Option<String>,
}

#[cfg(feature = "bearer")]
impl std::fmt::Debug for BearerConfig {
    /// Formats every field; `jwt_secret` is replaced with a placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BearerConfig")
            .field("jwt_secret", &"<redacted>")
            .field("access_token_ttl", &self.access_token_ttl)
            .field("refresh_token_ttl", &self.refresh_token_ttl)
            .field("audience", &self.audience)
            .finish()
    }
}
426
#[cfg(feature = "account-lockout")]
/// Account lockout after repeated failed logins.
///
/// All durations are (de)serialized as whole seconds, and every field is
/// required when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccountLockoutConfig {
    /// Failed attempts that trigger a lockout (default 5).
    pub max_failed_attempts: u32,
    /// Base lockout duration (default 5 minutes).
    #[serde(with = "duration_secs")]
    pub lockout_duration: Duration,
    /// Default `true`.
    /// NOTE(review): presumably grows the lockout for repeated lockouts — confirm.
    pub exponential_backoff: bool,
    /// Upper bound on the lockout duration (default 24 hours).
    #[serde(with = "duration_secs")]
    pub max_lockout_duration: Duration,
    /// Window over which failed attempts are counted (default 15 minutes).
    #[serde(with = "duration_secs")]
    pub attempt_window: Duration,
    /// Default `true`.
    /// NOTE(review): presumably lifts the lock once the duration elapses — confirm.
    pub auto_unlock: bool,
}
449
450#[cfg(feature = "account-lockout")]
451impl Default for AccountLockoutConfig {
452 fn default() -> Self {
453 Self {
454 max_failed_attempts: 5,
455 lockout_duration: Duration::from_secs(300),
456 exponential_backoff: true,
457 max_lockout_duration: Duration::from_secs(86400),
458 attempt_window: Duration::from_secs(900),
459 auto_unlock: true,
460 }
461 }
462}
463
#[cfg(feature = "webhooks")]
/// Outgoing webhook delivery settings. Durations are whole seconds; every
/// field is required when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WebhookConfig {
    /// Maximum delivery retries (default 3).
    pub max_retries: u32,
    /// Delay between retries (default 30 s).
    #[serde(with = "duration_secs")]
    pub retry_delay: Duration,
    /// Per-request timeout (default 10 s).
    #[serde(with = "duration_secs")]
    pub timeout: Duration,
    /// Maximum number of webhooks (default 10).
    /// NOTE(review): whether this is per user or global is decided elsewhere — confirm.
    pub max_webhooks: usize,
}
477
478#[cfg(feature = "webhooks")]
479impl Default for WebhookConfig {
480 fn default() -> Self {
481 Self {
482 max_retries: 3,
483 retry_delay: Duration::from_secs(30),
484 timeout: Duration::from_secs(10),
485 max_webhooks: 10,
486 }
487 }
488}
489
#[cfg(feature = "oidc")]
/// OpenID Connect provider settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OidcConfig {
    /// Issuer identifier (default `http://localhost:3000`).
    pub issuer: String,
    /// ID-token lifetime in whole seconds (default one hour). Required when deserializing.
    #[serde(with = "duration_secs")]
    pub id_token_ttl: Duration,
    /// Advertised claims. When omitted: `sub`, `email`, `email_verified`, `name`.
    #[serde(default = "default_oidc_claims")]
    pub claims_supported: Vec<String>,
}
503
504#[cfg(feature = "oidc")]
/// Serde default for `OidcConfig::claims_supported`: `sub`, `email`,
/// `email_verified` and `name`, in that order.
fn default_oidc_claims() -> Vec<String> {
    ["sub", "email", "email_verified", "name"]
        .iter()
        .map(|&claim| String::from(claim))
        .collect()
}
513
#[cfg(feature = "oidc")]
impl Default for OidcConfig {
    /// Local issuer with one-hour ID tokens and the standard claim list.
    fn default() -> Self {
        let claims_supported = default_oidc_claims();
        Self {
            issuer: "http://localhost:3000".to_owned(),
            id_token_ttl: Duration::from_secs(60 * 60),
            claims_supported,
        }
    }
}
524
#[cfg(test)]
mod tests {
    //! Unit tests for config defaults and serde round-trips.

    use super::*;

    #[test]
    fn default_config_has_sane_values() {
        let config = YAuthConfig::default();
        assert_eq!(config.session_cookie_name, "session");
        assert_eq!(config.session_ttl, Duration::from_secs(604800));
        assert!(!config.secure_cookies);
        assert!(config.smtp.is_none());
        assert!(!config.auto_admin_first_user);
        assert!(config.allow_signups);
    }

    #[test]
    fn config_serialization_roundtrip() {
        let config = YAuthConfig::default();
        let json = serde_json::to_string(&config).unwrap();
        let parsed: YAuthConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.base_url, config.base_url);
        assert_eq!(parsed.session_ttl, config.session_ttl);
        // `Auto` is an untagged unit variant: it is written as `null` and must
        // come back as `Auto`.
        assert_eq!(parsed.cookie_domain, config.cookie_domain);
        assert_eq!(parsed.trusted_origins, config.trusted_origins);
        assert_eq!(parsed.db_schema, config.db_schema);
    }

    #[test]
    fn allow_signups_defaults_to_true_when_missing() {
        let json = r#"{"base_url":"http://localhost:3000","session_cookie_name":"session","session_ttl":604800,"cookie_domain":null,"secure_cookies":false,"trusted_origins":["http://localhost:3000"],"smtp":null,"auto_admin_first_user":false,"session_binding":{"bind_ip":false,"bind_user_agent":false,"ip_mismatch_action":"Warn","ua_mismatch_action":"Warn"}}"#;
        let config: YAuthConfig = serde_json::from_str(json).unwrap();
        assert!(config.allow_signups);
    }

    #[test]
    fn allow_signups_can_be_set_to_false() {
        let config = YAuthConfig {
            allow_signups: false,
            ..Default::default()
        };
        let json = serde_json::to_string(&config).unwrap();
        let parsed: YAuthConfig = serde_json::from_str(&json).unwrap();
        assert!(!parsed.allow_signups);
    }

    #[test]
    fn duration_serde_as_seconds() {
        let config = YAuthConfig {
            session_ttl: Duration::from_secs(3600),
            ..Default::default()
        };
        let json = serde_json::to_string(&config).unwrap();
        // Match the key as well: serde's built-in `Duration` format
        // (`{"secs":3600,"nanos":0}`) also contains "3600", so a bare
        // substring check could not tell the two encodings apart.
        assert!(
            json.contains(r#""session_ttl":3600"#),
            "session_ttl should be a bare integer of seconds: {json}"
        );
        let parsed: YAuthConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.session_ttl, Duration::from_secs(3600));
    }

    #[cfg(feature = "email-password")]
    #[test]
    fn email_password_config_defaults() {
        let config = EmailPasswordConfig::default();
        assert_eq!(config.min_password_length, 8);
        assert!(config.require_email_verification);
        assert!(config.hibp_check);
    }

    #[cfg(feature = "mfa")]
    #[test]
    fn mfa_config_defaults() {
        let config = MfaConfig::default();
        assert_eq!(config.issuer, "YAuth");
        assert_eq!(config.backup_code_count, 10);
    }

    #[test]
    fn session_binding_config_defaults() {
        let config = SessionBindingConfig::default();
        assert!(!config.bind_ip);
        assert!(!config.bind_user_agent);
        assert_eq!(config.ip_mismatch_action, BindingAction::Warn);
        assert_eq!(config.ua_mismatch_action, BindingAction::Warn);
    }

    #[test]
    fn session_binding_config_serialization_roundtrip() {
        let config = SessionBindingConfig {
            bind_ip: true,
            bind_user_agent: true,
            ip_mismatch_action: BindingAction::Invalidate,
            ua_mismatch_action: BindingAction::Warn,
        };
        let json = serde_json::to_string(&config).unwrap();
        let parsed: SessionBindingConfig = serde_json::from_str(&json).unwrap();
        assert!(parsed.bind_ip);
        assert!(parsed.bind_user_agent);
        assert_eq!(parsed.ip_mismatch_action, BindingAction::Invalidate);
        assert_eq!(parsed.ua_mismatch_action, BindingAction::Warn);
    }

    #[test]
    fn binding_action_equality() {
        assert_ne!(BindingAction::Warn, BindingAction::Invalidate);
        assert_eq!(BindingAction::Warn, BindingAction::Warn);
        assert_eq!(BindingAction::Invalidate, BindingAction::Invalidate);
    }

    #[test]
    fn duration_secs_serialization() {
        let ds = DurationSecs(Duration::from_secs(2592000));
        let json = serde_json::to_string(&ds).unwrap();
        assert_eq!(json, "2592000");
    }

    #[test]
    fn duration_secs_deserialization() {
        let ds: DurationSecs = serde_json::from_str("2592000").unwrap();
        assert_eq!(ds.0, Duration::from_secs(2592000));
    }

    #[test]
    fn duration_secs_as_duration() {
        let ds = DurationSecs(Duration::from_secs(42));
        assert_eq!(ds.as_duration(), Duration::from_secs(42));
    }

    #[test]
    fn yauth_config_with_remember_me_ttl_roundtrip() {
        let config = YAuthConfig {
            remember_me_ttl: Some(DurationSecs(Duration::from_secs(2592000))),
            ..Default::default()
        };
        let json = serde_json::to_string(&config).unwrap();
        // Pin both the key and the bare-integer encoding.
        assert!(
            json.contains(r#""remember_me_ttl":2592000"#),
            "remember_me_ttl should be a bare integer of seconds: {json}"
        );
        let parsed: YAuthConfig = serde_json::from_str(&json).unwrap();
        let ttl = parsed
            .remember_me_ttl
            .expect("remember_me_ttl should be Some");
        assert_eq!(ttl.0, Duration::from_secs(2592000));
    }

    #[test]
    fn yauth_config_remember_me_ttl_none_omitted() {
        let config = YAuthConfig {
            remember_me_ttl: None,
            ..Default::default()
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(!json.contains("remember_me_ttl"));
    }

    #[cfg(feature = "email-password")]
    #[test]
    fn password_policy_config_defaults() {
        let config = PasswordPolicyConfig::default();
        assert_eq!(config.min_length, 8);
        assert_eq!(config.max_length, 128);
        assert!(!config.require_uppercase);
        assert!(!config.require_lowercase);
        assert!(!config.require_digit);
        assert!(!config.require_special);
        assert!(config.disallow_common_passwords);
        assert_eq!(config.password_history_count, 0);
    }

    #[cfg(feature = "account-lockout")]
    #[test]
    fn account_lockout_config_defaults() {
        let config = AccountLockoutConfig::default();
        assert_eq!(config.max_failed_attempts, 5);
        assert_eq!(config.lockout_duration, Duration::from_secs(300));
        assert!(config.exponential_backoff);
        assert_eq!(config.max_lockout_duration, Duration::from_secs(86400));
        assert_eq!(config.attempt_window, Duration::from_secs(900));
        assert!(config.auto_unlock);
    }

    #[cfg(feature = "account-lockout")]
    #[test]
    fn account_lockout_config_serialization_roundtrip() {
        let config = AccountLockoutConfig::default();
        let json = serde_json::to_string(&config).unwrap();
        let parsed: AccountLockoutConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.max_failed_attempts, config.max_failed_attempts);
        assert_eq!(parsed.lockout_duration, config.lockout_duration);
        assert_eq!(parsed.exponential_backoff, config.exponential_backoff);
        assert_eq!(parsed.max_lockout_duration, config.max_lockout_duration);
        assert_eq!(parsed.attempt_window, config.attempt_window);
        assert_eq!(parsed.auto_unlock, config.auto_unlock);
    }

    #[cfg(feature = "webhooks")]
    #[test]
    fn webhook_config_defaults() {
        let config = WebhookConfig::default();
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.retry_delay, Duration::from_secs(30));
        assert_eq!(config.timeout, Duration::from_secs(10));
        assert_eq!(config.max_webhooks, 10);
    }

    #[cfg(feature = "oidc")]
    #[test]
    fn oidc_config_defaults() {
        let config = OidcConfig::default();
        assert_eq!(config.issuer, "http://localhost:3000");
        assert_eq!(config.id_token_ttl, Duration::from_secs(3600));
        assert!(config.claims_supported.contains(&"sub".to_string()));
        assert!(config.claims_supported.contains(&"email".to_string()));
        assert!(
            config
                .claims_supported
                .contains(&"email_verified".to_string())
        );
        assert!(config.claims_supported.contains(&"name".to_string()));
        assert_eq!(config.claims_supported.len(), 4);
    }
}