1use std::collections::BTreeMap;
2use std::sync::Arc;
3
4#[cfg(feature = "oauth")]
5use rustauth_oauth::oauth2::SocialOAuthProvider;
6
7use time::Duration;
8
9use crate::background::tokio::TokioBackgroundTaskRunner;
10use crate::cookies::get_cookies;
11use crate::crypto::password::{hash_password, verify_password};
12use crate::crypto::{build_secret_config, parse_secrets_env};
13use crate::db::RateLimitStorage as DbRateLimitStorage;
14use crate::db::{auth_schema, AuthSchemaOptions, DbAdapter, DbField, HookedAdapter};
15use crate::env::is_production_posture;
16use crate::env::logger::create_logger;
17use crate::error::RustAuthError;
18use crate::options::hooks::{plugin_after_hooks, plugin_before_hooks};
19use crate::options::RateLimitStore;
20use crate::options::{
21 plugin_database_hooks_from_init, BackgroundTaskRunner, ModelSchemaOptions,
22 RateLimitStorageOption, RustAuthOptions, SessionAdditionalField, UserAdditionalField,
23};
24use crate::plugin::AuthPlugin;
25use crate::rate_limit::{GovernorMemoryRateLimitStore, LegacyRateLimitStorageAdapter};
26
27use super::origins::resolve_trusted_origins;
28use super::plugins::initialize_plugins;
29use super::secrets::{resolve_legacy_secret, validate_secret, DEFAULT_SECRET};
30use super::{
31 noop_telemetry_publisher, AuthContext, AuthEnvironment, PasswordContext, PasswordPolicy,
32 RateLimitContext, SecretMaterial, SessionConfig,
33};
34
35pub fn create_auth_context(options: RustAuthOptions) -> Result<AuthContext, RustAuthError> {
36 create_auth_context_with_environment_and_adapter(options, AuthEnvironment::from_process(), None)
37}
38
39pub fn create_auth_context_with_adapter(
40 options: RustAuthOptions,
41 adapter: Arc<dyn DbAdapter>,
42) -> Result<AuthContext, RustAuthError> {
43 create_auth_context_with_environment_and_adapter(
44 options,
45 AuthEnvironment::from_process(),
46 Some(adapter),
47 )
48}
49
50pub fn create_auth_context_with_environment(
51 options: RustAuthOptions,
52 environment: AuthEnvironment,
53) -> Result<AuthContext, RustAuthError> {
54 create_auth_context_with_environment_and_adapter(options, environment, None)
55}
56
/// Builds an [`AuthContext`] from `options`, an explicit [`AuthEnvironment`], and an
/// optional database adapter.
///
/// `create_auth_context`, `create_auth_context_with_adapter` and
/// `create_auth_context_with_environment` all delegate to this function. In order, it:
///
/// 1. creates the logger and evaluates the production posture of `options`;
/// 2. resolves secret material: `options.secrets` when non-empty, otherwise secrets
///    parsed from `environment.rustauth_secrets`; when neither yields any secrets it
///    falls back to a single legacy secret or `DEFAULT_SECRET`;
/// 3. derives the base path/URL, trusted origins, cookies, social providers (with the
///    `oauth` feature) and the session, password and rate-limit settings;
/// 4. assembles the context, prepends a plugin carrying the global hooks (when any are
///    configured) and initializes plugins;
/// 5. wraps the adapter, if present, in a `HookedAdapter` when database hooks exist.
///
/// # Errors
///
/// Propagates errors from parsing the secrets environment value, secret validation,
/// building the rotating secret configuration, cookie setup, social-provider
/// registration, rate-limit storage validation and plugin initialization. Returns
/// `RustAuthError::InvalidSecretConfig` when the configured current secret version has
/// no entry in the key map.
pub fn create_auth_context_with_environment_and_adapter(
    options: RustAuthOptions,
    environment: AuthEnvironment,
    adapter: Option<Arc<dyn DbAdapter>>,
) -> Result<AuthContext, RustAuthError> {
    let logger = create_logger(options.logger.clone());
    let production_posture = is_production_posture(&options);
    // The environment value is parsed (and a parse error returned) even when
    // `options.secrets` is non-empty and the parsed result ends up unused.
    let env_secrets = parse_secrets_env(environment.rustauth_secrets.as_deref())?;
    // Explicitly configured secrets take precedence over environment-provided ones.
    let secrets = if options.secrets.is_empty() {
        env_secrets.unwrap_or_default()
    } else {
        options.secrets.clone()
    };
    let legacy_secret = resolve_legacy_secret(&options, &environment);

    let (secret, secret_config) = if secrets.is_empty() {
        // No versioned secrets: use a single legacy secret, or DEFAULT_SECRET as a
        // last resort. `validate_secret` decides whether that value is acceptable
        // for these options.
        let secret = legacy_secret.unwrap_or_else(|| DEFAULT_SECRET.to_owned());
        validate_secret(&secret, &options)?;
        (secret.clone(), SecretMaterial::Single(secret))
    } else {
        // Versioned secrets: the legacy secret (or "" when absent) is handed to
        // `build_secret_config` alongside them.
        let config = build_secret_config(&secrets, legacy_secret.as_deref().unwrap_or(""))?;
        // The "current" secret is the key registered under `current_version`.
        let current = config
            .keys
            .get(&config.current_version)
            .cloned()
            .ok_or_else(|| {
                RustAuthError::InvalidSecretConfig(format!(
                    "secret version {} not found in keys",
                    config.current_version
                ))
            })?;
        (current, SecretMaterial::Rotating(config))
    };

    // Routes are mounted under "/api/auth" unless configured otherwise.
    let base_path = options
        .base_path
        .clone()
        .unwrap_or_else(|| "/api/auth".to_owned());
    // An unset base URL becomes the empty string.
    let base_url = options.base_url.clone().unwrap_or_default();
    let trusted_origins = resolve_trusted_origins(&base_url, &options, &environment);
    let auth_cookies = get_cookies(&options)?;
    #[cfg(feature = "oauth")]
    let social_providers = resolve_social_providers(&options)?;
    // Session defaults: update age 24h, expiry 7 days, freshness window 1 day.
    let session_config = SessionConfig {
        update_age: options.session.update_age.unwrap_or(Duration::hours(24)),
        expires_in: options.session.expires_in.unwrap_or(Duration::days(7)),
        fresh_age: options.session.fresh_age.unwrap_or(Duration::days(1)),
        cookie_refresh_cache: options.session.cookie_cache.refresh_cache,
    };
    // Custom hash/verify functions override the crate's built-in password functions.
    let password = PasswordContext {
        config: PasswordPolicy {
            min_password_length: options.password.min_password_length,
            max_password_length: options.password.max_password_length,
        },
        hash: options.password.hash_password.unwrap_or(hash_password),
        verify: options.password.verify_password.unwrap_or(verify_password),
    };
    validate_rate_limit_storage(&options)?;
    let rate_limit = RateLimitContext {
        // When not set explicitly, rate limiting follows the production posture.
        enabled: options.rate_limit.enabled.unwrap_or(production_posture),
        window: options.rate_limit.window,
        max: options.rate_limit.max,
        storage: options.rate_limit.storage,
        custom_rules: options.rate_limit.custom_rules.clone(),
        dynamic_rules: options.rate_limit.dynamic_rules.clone(),
        plugin_rules: Vec::new(),
        // Prefer `custom_store`; otherwise adapt the legacy `custom_storage`
        // through `LegacyRateLimitStorageAdapter`.
        custom_store: options.rate_limit.custom_store.clone().or_else(|| {
            options.rate_limit.custom_storage.clone().map(|storage| {
                Arc::new(LegacyRateLimitStorageAdapter::new(storage)) as Arc<dyn RateLimitStore>
            })
        }),
        hybrid: options.rate_limit.hybrid.clone(),
        memory_cleanup_interval: options.rate_limit.memory_cleanup_interval,
        // An in-memory store is always constructed, regardless of `storage`.
        memory_store: Arc::new(GovernorMemoryRateLimitStore::with_cleanup_interval(
            options.rate_limit.memory_cleanup_interval,
        )),
        missing_ip_policy: options.rate_limit.missing_ip_policy,
    };

    let schema_options = schema_options_from_auth_options(&options);
    let app_name = options
        .app_name
        .clone()
        .unwrap_or_else(|| "RustAuth".to_owned());
    // Field order below matters: `options` is cloned into the context before
    // `disabled_paths` and `plugins` are moved out of it, and later fields only
    // touch parts of `options` that were not moved.
    let mut context =
        AuthContext {
            app_name,
            base_url,
            base_path,
            options: options.clone(),
            auth_cookies,
            session_config,
            secret,
            secret_config,
            password,
            rate_limit,
            trusted_origins,
            disabled_paths: options.disabled_paths,
            plugins: options.plugins,
            adapter,
            secondary_storage: options.secondary_storage.clone(),
            // Falls back to the Tokio-based runner when none is configured.
            background_tasks: options.advanced.background_tasks.clone().or_else(|| {
                Some(Arc::new(TokioBackgroundTaskRunner) as Arc<dyn BackgroundTaskRunner>)
            }),
            #[cfg(feature = "oauth")]
            social_providers,
            db_schema: auth_schema(schema_options),
            plugin_error_codes: BTreeMap::new(),
            // Hooks derived from `init_database_hooks` come first, followed by the
            // explicitly configured `database_hooks`.
            plugin_database_hooks: {
                let mut hooks = plugin_database_hooks_from_init(&options.init_database_hooks);
                hooks.extend(options.database_hooks.clone());
                hooks
            },
            plugin_migrations: Vec::new(),
            telemetry_publisher: noop_telemetry_publisher(),
            logger,
        };
    // Global hooks are inserted as plugin #0 before plugin initialization, so
    // `initialize_plugins` sees them along with the user-supplied plugins.
    apply_global_hooks(&mut context);
    initialize_plugins(&mut context)?;
    // Checked only after `initialize_plugins`, which has mutable access to the
    // context and may therefore have changed `plugin_database_hooks`.
    if !context.plugin_database_hooks.is_empty() {
        if let Some(adapter) = context.adapter.clone() {
            context.adapter = Some(Arc::new(HookedAdapter::with_logger(
                adapter,
                context.plugin_database_hooks.clone(),
                context.logger.clone(),
            )));
        }
    }
    Ok(context)
}
187
188fn apply_global_hooks(context: &mut AuthContext) {
189 let before = plugin_before_hooks(&context.options.hooks);
190 let after = plugin_after_hooks(&context.options.hooks);
191 if before.is_empty() && after.is_empty() {
192 return;
193 }
194 let mut plugin = AuthPlugin::new("__rustauth_global__");
195 plugin.hooks.before = before;
196 plugin.hooks.after = after;
197 context.plugins.insert(0, plugin);
198}
199
/// Collects `options.social_providers` into a map keyed by provider id.
///
/// # Errors
///
/// Fails on the first provider that `insert_social_provider` rejects (an empty id or a
/// duplicate id).
#[cfg(feature = "oauth")]
fn resolve_social_providers(
    options: &RustAuthOptions,
) -> Result<BTreeMap<String, Arc<dyn SocialOAuthProvider>>, RustAuthError> {
    options.social_providers.iter().try_fold(
        BTreeMap::new(),
        |mut registry, provider| -> Result<_, RustAuthError> {
            insert_social_provider(&mut registry, provider.clone())?;
            Ok(registry)
        },
    )
}
210
/// Registers `provider` in `providers` under the id returned by its `id()` method.
///
/// # Errors
///
/// Returns `RustAuthError::InvalidConfig` when the provider id is empty or
/// whitespace-only, or when a provider with the same id is already registered.
/// On error `providers` is left unchanged: an already-registered provider is never
/// replaced by the rejected duplicate.
#[cfg(feature = "oauth")]
pub(super) fn insert_social_provider(
    providers: &mut BTreeMap<String, Arc<dyn SocialOAuthProvider>>,
    provider: Arc<dyn SocialOAuthProvider>,
) -> Result<(), RustAuthError> {
    use std::collections::btree_map::Entry;

    let id = provider.id().to_owned();
    if id.trim().is_empty() {
        return Err(RustAuthError::InvalidConfig(
            "social provider id cannot be empty".to_owned(),
        ));
    }
    // Use the entry API rather than `insert(..).is_some()`: `insert` overwrites the
    // existing provider before the duplicate is reported, which would leave the
    // caller's map mutated on the error path.
    match providers.entry(id) {
        Entry::Vacant(slot) => {
            slot.insert(provider);
            Ok(())
        }
        Entry::Occupied(slot) => Err(RustAuthError::InvalidConfig(format!(
            "duplicate social provider `{}`",
            slot.key()
        ))),
    }
}
229
230fn validate_rate_limit_storage(options: &RustAuthOptions) -> Result<(), RustAuthError> {
231 if options.rate_limit.custom_store.is_some() || options.rate_limit.custom_storage.is_some() {
232 return Ok(());
233 }
234 if matches!(
235 options.rate_limit.storage,
236 RateLimitStorageOption::Database | RateLimitStorageOption::SecondaryStorage
237 ) {
238 return Err(RustAuthError::InvalidConfig(
239 "rate_limit.custom_store or rate_limit.custom_storage is required when using database or secondary-storage rate limiting without a concrete adapter".to_owned(),
240 ));
241 }
242 Ok(())
243}
244
245fn apply_model_schema(table: &mut crate::db::TableOptions, schema: &ModelSchemaOptions) {
246 table.name = schema.model_name.clone();
247 table.field_names = schema
248 .field_names
249 .iter()
250 .map(|(key, value)| (key.clone(), value.clone()))
251 .collect();
252}
253
254fn schema_options_from_auth_options(options: &RustAuthOptions) -> AuthSchemaOptions {
255 let mut schema_options = AuthSchemaOptions {
256 has_secondary_storage: options.secondary_storage.is_some(),
257 store_session_in_database: options.session.store_session_in_database,
258 rate_limit_storage: match options.rate_limit.storage {
259 RateLimitStorageOption::Memory => DbRateLimitStorage::Memory,
260 RateLimitStorageOption::Database => DbRateLimitStorage::Database,
261 RateLimitStorageOption::SecondaryStorage => DbRateLimitStorage::SecondaryStorage,
262 },
263 ..AuthSchemaOptions::default()
264 };
265 apply_model_schema(&mut schema_options.user, &options.user.schema);
266 apply_model_schema(&mut schema_options.session, &options.session.schema);
267 apply_model_schema(&mut schema_options.account, &options.account.schema);
268 apply_model_schema(
269 &mut schema_options.verification,
270 &options.verification.schema,
271 );
272 apply_model_schema(&mut schema_options.rate_limit, &options.rate_limit.schema);
273 for (name, field) in &options.user.additional_fields {
274 schema_options
275 .user
276 .additional_fields
277 .insert(name.clone(), user_additional_field_to_db_field(name, field));
278 }
279 for (name, field) in &options.session.additional_fields {
280 schema_options.session.additional_fields.insert(
281 name.clone(),
282 session_additional_field_to_db_field(name, field),
283 );
284 }
285 schema_options
286}
287
288pub(super) fn user_additional_field_to_db_field(
289 logical_name: &str,
290 field: &UserAdditionalField,
291) -> DbField {
292 additional_field_to_db_field(
293 logical_name,
294 field.db_name.as_deref(),
295 field.field_type.clone(),
296 field.required,
297 field.input,
298 field.returned,
299 )
300}
301
302pub(super) fn session_additional_field_to_db_field(
303 logical_name: &str,
304 field: &SessionAdditionalField,
305) -> DbField {
306 additional_field_to_db_field(
307 logical_name,
308 field.db_name.as_deref(),
309 field.field_type.clone(),
310 field.required,
311 field.input,
312 field.returned,
313 )
314}
315
316fn additional_field_to_db_field(
317 logical_name: &str,
318 db_name: Option<&str>,
319 field_type: crate::db::DbFieldType,
320 required: bool,
321 input: bool,
322 returned: bool,
323) -> DbField {
324 let mut field = DbField::new(db_name.unwrap_or(logical_name), field_type);
325 if !required {
326 field = field.optional();
327 }
328 if !input {
329 field = field.generated();
330 }
331 if !returned {
332 field = field.hidden();
333 }
334 field
335}