1pub mod request_state;
4
5mod builder;
6mod origins;
7mod plugins;
8mod secrets;
9
10use crate::auth::trusted_origins::{matches_origin_pattern, OriginMatchSettings};
11use crate::cookies::{AuthCookie, AuthCookies};
12use crate::db::{AuthSchema, DbAdapter, DbSchema};
13use crate::env::{env_var, logger::Logger};
14use crate::error::RustAuthError;
15use crate::options::{
16 BackgroundTaskFuture, BackgroundTaskRunner, DynamicRateLimitPathRule, HybridRateLimitOptions,
17 MissingIpPolicy, RateLimitPathRule, RateLimitStorageOption, RateLimitStore, RustAuthOptions,
18 SecondaryStorage,
19};
20use crate::plugin::{AuthPlugin, PluginErrorCode};
21use crate::rate_limit::GovernorMemoryRateLimitStore;
22use crate::session::{DbSessionStore, SessionStore};
23use crate::user::DbUserStore;
24use crate::verification::{DbVerificationStore, VerificationStore};
25use http::Request;
26#[cfg(feature = "oauth")]
27use rustauth_oauth::oauth2::SocialOAuthProvider;
28use std::collections::BTreeMap;
29use std::fmt;
30use std::future::Future;
31use std::pin::Pin;
32use std::sync::Arc;
33
34use std::time::Duration as StdDuration;
35use time::Duration;
36
37pub use builder::{
38 create_auth_context, create_auth_context_with_adapter, create_auth_context_with_environment,
39 create_auth_context_with_environment_and_adapter,
40};
41pub use secrets::SecretMaterial;
42
43use origins::push_trusted_origin;
44
45pub type ContextTelemetryFuture = Pin<Box<dyn Future<Output = ()> + Send + 'static>>;
46pub type ContextTelemetryPublisher =
47 Arc<dyn Fn(ContextTelemetryEvent) -> ContextTelemetryFuture + Send + Sync>;
48
49#[derive(Clone, Debug, PartialEq)]
50pub struct ContextTelemetryEvent {
51 pub event_type: String,
52 pub anonymous_id: Option<String>,
53 pub payload: serde_json::Value,
54}
55
56pub(super) fn noop_telemetry_publisher() -> ContextTelemetryPublisher {
57 Arc::new(|_| Box::pin(async move {}))
58}
59
60#[derive(Clone)]
61pub struct AuthContext {
62 pub app_name: String,
63 pub base_url: String,
64 pub base_path: String,
65 pub options: RustAuthOptions,
66 pub auth_cookies: AuthCookies,
67 pub session_config: SessionConfig,
68 pub secret: String,
69 pub secret_config: SecretMaterial,
70 pub password: PasswordContext,
71 pub rate_limit: RateLimitContext,
72 pub trusted_origins: Vec<String>,
73 pub disabled_paths: Vec<String>,
74 pub plugins: Vec<AuthPlugin>,
75 pub adapter: Option<Arc<dyn DbAdapter>>,
76 pub secondary_storage: Option<Arc<dyn SecondaryStorage>>,
77 pub background_tasks: Option<Arc<dyn BackgroundTaskRunner>>,
78 #[cfg(feature = "oauth")]
79 pub social_providers: BTreeMap<String, Arc<dyn SocialOAuthProvider>>,
80 pub db_schema: DbSchema,
81 pub plugin_error_codes: BTreeMap<String, PluginErrorCode>,
82 pub plugin_database_hooks: Vec<crate::plugin::PluginDatabaseHook>,
83 pub plugin_migrations: Vec<crate::plugin::PluginMigration>,
84 pub telemetry_publisher: ContextTelemetryPublisher,
85 pub logger: Logger,
86}
87
/// Snapshot of the environment-provided settings consumed when building an
/// auth context.
///
/// Every field is optional and `Default` leaves all of them unset. `Debug` is
/// not derived here; it is implemented by hand elsewhere in this file.
#[derive(Clone, Default, Eq, PartialEq)]
pub struct AuthEnvironment {
    /// Single secret value, if provided.
    pub rustauth_secret: Option<String>,
    /// Raw secrets string, if provided.
    // NOTE(review): presumably a list of several secrets in one string; the
    // format is not visible in this file — confirm.
    pub rustauth_secrets: Option<String>,
    /// Raw trusted-origins string, if provided.
    // NOTE(review): presumably a delimited list; the delimiter is not visible
    // in this file — confirm.
    pub rustauth_trusted_origins: Option<String>,
}
95
96impl fmt::Debug for AuthEnvironment {
97 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
98 formatter
99 .debug_struct("AuthEnvironment")
100 .field(
101 "rustauth_secret",
102 &self.rustauth_secret.as_ref().map(|_| "<redacted>"),
103 )
104 .field(
105 "rustauth_secrets",
106 &self.rustauth_secrets.as_ref().map(|_| "<redacted>"),
107 )
108 .field("rustauth_trusted_origins", &self.rustauth_trusted_origins)
109 .finish()
110 }
111}
112
113impl AuthEnvironment {
114 pub fn from_process() -> Self {
115 Self {
116 rustauth_secret: env_var("SECRET"),
117 rustauth_secrets: env_var("SECRETS"),
118 rustauth_trusted_origins: env_var("TRUSTED_ORIGINS"),
119 }
120 }
121}
122
123#[derive(Debug, Clone, PartialEq, Eq)]
124pub struct SessionConfig {
125 pub update_age: Duration,
126 pub expires_in: Duration,
127 pub fresh_age: Duration,
128 pub cookie_refresh_cache: bool,
129}
130
131#[derive(Clone)]
132pub struct PasswordContext {
133 pub config: PasswordPolicy,
134 pub hash: fn(&str) -> Result<String, RustAuthError>,
135 pub verify: fn(&str, &str) -> Result<bool, RustAuthError>,
136}
137
/// Length bounds for passwords.
// NOTE(review): whether the bounds are inclusive, and whether they count bytes
// or characters, is decided by the code that enforces them — confirm there.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PasswordPolicy {
    /// Lower length bound.
    pub min_password_length: usize,
    /// Upper length bound.
    pub max_password_length: usize,
}
143
144#[derive(Clone)]
145pub struct RateLimitContext {
146 pub enabled: bool,
147 pub window: Duration,
148 pub max: u64,
149 pub storage: RateLimitStorageOption,
150 pub custom_rules: Vec<RateLimitPathRule>,
151 pub dynamic_rules: Vec<DynamicRateLimitPathRule>,
152 pub plugin_rules: Vec<crate::plugin::PluginRateLimitRule>,
153 pub custom_store: Option<Arc<dyn RateLimitStore>>,
154 pub hybrid: HybridRateLimitOptions,
155 pub memory_cleanup_interval: Option<StdDuration>,
156 pub memory_store: Arc<GovernorMemoryRateLimitStore>,
157 pub missing_ip_policy: MissingIpPolicy,
158}
159
160impl AuthContext {
161 pub fn db_schema(&self) -> &DbSchema {
162 &self.db_schema
163 }
164
165 pub fn schema(&self) -> AuthSchema<'_> {
167 AuthSchema::new(&self.db_schema)
168 }
169
170 pub fn adapter(&self) -> Option<Arc<dyn DbAdapter>> {
171 self.adapter.clone()
172 }
173
174 pub fn require_adapter(&self) -> Result<Arc<dyn DbAdapter>, RustAuthError> {
176 self.adapter
177 .clone()
178 .ok_or_else(|| RustAuthError::InvalidConfig("database adapter is required".to_owned()))
179 }
180
181 pub fn adapter_ref(&self) -> Result<&dyn DbAdapter, RustAuthError> {
183 self.adapter
184 .as_deref()
185 .ok_or_else(|| RustAuthError::InvalidConfig("database adapter is required".to_owned()))
186 }
187
188 pub fn users(&self) -> Result<DbUserStore<'_>, RustAuthError> {
190 DbUserStore::from_context(self)
191 }
192
193 pub fn sessions(&self) -> Result<SessionStore<'_>, RustAuthError> {
195 SessionStore::new(self)
196 }
197
198 pub fn verifications(&self) -> Result<VerificationStore<'_>, RustAuthError> {
203 VerificationStore::new(self)
204 }
205
206 pub fn session_store(&self) -> Result<DbSessionStore<'_>, RustAuthError> {
208 DbSessionStore::from_context(self)
209 }
210
211 pub fn verification_store(&self) -> Result<DbVerificationStore<'_>, RustAuthError> {
216 DbVerificationStore::from_context(self)
217 }
218
219 pub fn create_auth_cookie(
224 &self,
225 name: &str,
226 max_age: Option<u64>,
227 ) -> Result<AuthCookie, RustAuthError> {
228 crate::cookies::create_auth_cookie(&self.options, name, max_age)
229 }
230
231 pub fn secondary_storage(&self) -> Option<Arc<dyn SecondaryStorage>> {
232 self.secondary_storage.clone()
233 }
234
235 pub fn run_background_task(&self, task: BackgroundTaskFuture) -> bool {
236 let Some(runner) = &self.background_tasks else {
237 return false;
238 };
239 runner.spawn(task);
240 true
241 }
242
243 pub async fn publish_telemetry(&self, event: ContextTelemetryEvent) {
244 (self.telemetry_publisher)(event).await;
245 }
246
247 #[cfg(feature = "oauth")]
248 pub fn social_provider(&self, id: &str) -> Option<Arc<dyn SocialOAuthProvider>> {
249 self.social_providers.get(id).cloned()
250 }
251
252 pub fn has_plugin(&self, id: &str) -> bool {
253 self.plugins.iter().any(|plugin| plugin.id == id)
254 }
255
256 pub fn plugin(&self, id: &str) -> Option<&AuthPlugin> {
257 self.plugins.iter().find(|plugin| plugin.id == id)
258 }
259
260 pub fn is_trusted_origin(&self, url: &str, settings: Option<OriginMatchSettings>) -> bool {
261 self.trusted_origins
262 .iter()
263 .any(|origin| matches_origin_pattern(url, origin, settings))
264 }
265
266 pub fn trusted_origins_for_request(
267 &self,
268 request: Option<&Request<Vec<u8>>>,
269 ) -> Result<Vec<String>, RustAuthError> {
270 let mut origins = self.trusted_origins.clone();
271 if let Some(provider) = self.options.trusted_origins.provider() {
272 for origin in provider.trusted_origins(request)? {
273 push_trusted_origin(&mut origins, origin);
274 }
275 }
276 Ok(origins)
277 }
278
279 pub fn trusted_providers_for_request(
280 &self,
281 request: Option<&Request<Vec<u8>>>,
282 ) -> Result<Vec<String>, RustAuthError> {
283 let linking = &self.options.account.account_linking;
284 let mut providers = linking.trusted_providers.clone();
285 if let Some(provider) = &linking.trusted_providers_provider {
286 for trusted in provider.trusted_providers()? {
287 if !providers.iter().any(|existing| existing == &trusted) {
288 providers.push(trusted);
289 }
290 }
291 }
292 if let Some(provider) = &linking.trusted_providers_request_provider {
293 for trusted in provider.trusted_providers_for_request(request)? {
294 if !providers.iter().any(|existing| existing == &trusted) {
295 providers.push(trusted);
296 }
297 }
298 }
299 Ok(providers)
300 }
301
302 pub fn is_trusted_origin_for_request(
303 &self,
304 url: &str,
305 settings: Option<OriginMatchSettings>,
306 request: Option<&Request<Vec<u8>>>,
307 ) -> Result<bool, RustAuthError> {
308 Ok(self
309 .trusted_origins_for_request(request)?
310 .iter()
311 .any(|origin| matches_origin_pattern(url, origin, settings)))
312 }
313}